PyBird-JAX: Accelerated inference in large-scale structure with model-independent emulation of one-loop galaxy power spectra
Abstract: We present $\texttt{PyBird-JAX}$, a differentiable, $\texttt{JAX}$-based implementation of $\texttt{PyBird}$, using internal neural network emulators to accelerate computationally costly operations for rapid large-scale structure (LSS) analysis. $\texttt{PyBird-JAX}$ computes one-loop EFTofLSS predictions for redshift-space galaxy power spectrum multipoles in 1.2 ms on a CPU and 0.2 ms on a GPU, achieving 3-4 orders of magnitude speed-up over $\texttt{PyBird}$. The emulators take a compact spline-based representation of the input linear power spectrum $P(k)$ as feature vectors, making the approach applicable to a wide range of cosmological models. We rigorously validate its accuracy against large-volume simulations and on BOSS data, including cosmologies not explicitly represented in the training set. Leveraging automatic differentiation, $\texttt{PyBird-JAX}$ supports Fisher forecasting, Taylor expansion of model predictions, gradient-based searches, and vectorised ensemble sampling. Interfaced with a variety of samplers and Boltzmann solvers, $\texttt{PyBird-JAX}$ provides a high-performance, end-to-end inference pipeline. Combined with a symbolic-$P(k)$ generator, a typical Stage-4 LSS MCMC converges in minutes on a GPU. Our results demonstrate that $\texttt{PyBird-JAX}$ delivers the precision and speed required for upcoming LSS surveys, opening the door to accelerated cosmological inference with minimal accuracy loss and no pretraining. In a companion paper [1], we put $\texttt{PyBird-JAX}$ to use in achieving LSS marginalised constraints free from volume projection effects through non-flat measures.
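As an illustration of how automatic differentiation enables Fisher forecasting, the following is a minimal JAX sketch under stated assumptions. It does not use the $\texttt{PyBird-JAX}$ API: `toy_model`, its two parameters, the damping scale, and the 5% diagonal errors are all hypothetical stand-ins for a real power spectrum prediction and covariance.

```python
import jax
import jax.numpy as jnp

# Double precision keeps the Fisher matrix inversion well behaved.
jax.config.update("jax_enable_x64", True)


def toy_model(theta, k):
    """Hypothetical stand-in for a power spectrum prediction (not PyBird-JAX)."""
    amp, tilt = theta[0], theta[1]
    return amp * k**tilt * jnp.exp(-((k / 0.3) ** 2))


def fisher_matrix(theta, k, sigma):
    """Fisher matrix F = J^T C^{-1} J for a Gaussian likelihood with diagonal C."""
    # Forward-mode Jacobian of the data vector w.r.t. parameters: shape (n_k, n_params).
    jac = jax.jacfwd(toy_model)(theta, k)
    return jac.T @ (jac / sigma[:, None] ** 2)


k = jnp.linspace(0.01, 0.2, 50)         # assumed wavenumber bins
theta_fid = jnp.array([1.0e4, -1.2])    # assumed fiducial (amplitude, tilt)
sigma = 0.05 * toy_model(theta_fid, k)  # assumed 5% error per bin

fisher = fisher_matrix(theta_fid, k, sigma)
param_cov = jnp.linalg.inv(fisher)       # forecast parameter covariance
print(jnp.sqrt(jnp.diag(param_cov)))     # forecast 1-sigma errors
```

Because the Jacobian comes from `jax.jacfwd` rather than finite differences, no step size has to be tuned, and the same pattern applies to any differentiable model function.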