Batched Conjugate Gradients (mBCG)
- Batched Conjugate Gradients (mBCG) is a matrix-oriented extension of the classic CG method that solves multiple symmetric positive-definite systems in parallel to accelerate Gaussian process inference.
- mBCG leverages batched matrix operations and GPU-friendly algorithms to reduce the dominant cost of GP inference from O(n³) to O(n²) per iteration, and supports effective preconditioning.
- Empirical benchmarks demonstrate up to 32× speedup and seamless integration with libraries like GPyTorch, facilitating scalable and hardware-efficient GP training and inference.
The batched conjugate gradients (mBCG) algorithm is a matrix-oriented extension of the classic conjugate gradients (CG) method. It solves multiple symmetric positive-definite linear systems for a shared matrix and several right-hand sides (RHS) in parallel. mBCG is fundamental to Blackbox Matrix–Matrix (BBMM) inference, enabling scalable and hardware-efficient Gaussian process (GP) inference by leveraging modern GPU architectures and reducing the computational bottleneck of standard GP methods from O(n³) to O(n²) per iteration. mBCG also underpins efficient stochastic estimation of traces and log-determinants crucial for kernel hyperparameter optimization and is closely related to cooperative and block CG methods for parallel architectures (Gardner et al., 2018, Bhaya et al., 2012).
1. Classical and Batched Conjugate Gradients
The standard CG algorithm solves Ax = b, where A is an n×n symmetric positive-definite matrix. It iteratively builds approximations in the Krylov subspace K_i(A, b) = span{b, Ab, …, A^(i−1)b}, minimizing the quadratic f(x) = ½xᵀAx − bᵀx. CG is optimal for a single RHS; however, modern applications such as GP inference require solutions for multiple RHS. mBCG extends CG to matrix equations AX = B, where B is n×t, by maintaining batched iterates X, R, P (each n×t) and performing simultaneous updates:
- Batched matrix–matrix multiplies: a single product W = AP updates all t search directions at once.
- Per-column step sizes: α_j = ρ_j / σ_j, where ρ_j = r_jᵀr_j and σ_j = p_jᵀAp_j, applied as diagonal scalings diag(α).
- Elementwise and diagonal-matrix updates for all batched variables.
These operations allow t independent systems to be solved efficiently within a single kernel on GPU hardware. Since t is typically small (10–20), all batched operations are highly parallelizable (Gardner et al., 2018).
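Concretely, the per-column quantities map directly onto broadcasted array primitives. A minimal NumPy sketch (the helper name colwise_dot mirrors the pseudocode of Section 2; shapes and data here are illustrative, not from the source):

```python
import numpy as np

def colwise_dot(U, V):
    """Per-column inner products: entry j is U[:, j] . V[:, j]."""
    return np.einsum('ij,ij->j', U, V)

rng = np.random.default_rng(0)
n, t = 6, 3
P = rng.standard_normal((n, t))   # t search directions, one per system
W = rng.standard_normal((n, t))   # stand-in for the batched product A @ P

sigma = colwise_dot(P, W)         # all t step-size denominators at once
alpha = 1.0 / sigma               # illustrative per-column step sizes
# Right-multiplying by diag(alpha) scales column j by alpha[j]; with
# broadcasting this is an elementwise product against a row vector.
scaled = P * alpha                # same as P @ np.diag(alpha), no t x t matmul
```

Broadcasting avoids materializing the diagonal matrix, so each update stays a single fused elementwise kernel on GPU-style hardware.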
2. Algorithmic Formulation and Pseudocode
mBCG takes as input a black-box kernel matmul routine M(·) (computing products with the kernel matrix A) and a matrix B of RHS vectors. The pseudocode emphasizes GPU-friendly, batched operations:
```
X = zeros(n, t)
R = B - M(X)
P = R
rho = colwise_dot(R, R)
for i in range(p):
    W = M(P)
    sigma = colwise_dot(P, W)
    alpha = rho / sigma
    X = X + P * diag(alpha)
    R = R - W * diag(alpha)
    rho_new = colwise_dot(R, R)
    beta = rho_new / rho
    P = R + P * diag(beta)
    rho = rho_new
```
Here, colwise_dot computes per-column inner products. All matrix–matrix multiplications and batchwise vector operations are GPU-amenable (Gardner et al., 2018).
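The pseudocode translates nearly line-for-line into runnable NumPy. The following is a sketch for illustration (the function name mbcg, the convergence check, and the dense test matrix are assumptions, not from the source):

```python
import numpy as np

def mbcg(matmul, B, p, tol=1e-10):
    """Batched CG sketch: solve A X = B for all t columns at once, given only
    a black-box routine `matmul` computing A @ V for an SPD matrix A."""
    n, t = B.shape
    X = np.zeros((n, t))
    R = B - matmul(X)
    P = R.copy()
    rho = np.einsum('ij,ij->j', R, R)       # per-column squared residual norms
    for _ in range(p):
        W = matmul(P)
        sigma = np.einsum('ij,ij->j', P, W)
        alpha = rho / sigma                 # per-column step sizes
        X += P * alpha                      # broadcast = right-multiply by diag(alpha)
        R -= W * alpha
        rho_new = np.einsum('ij,ij->j', R, R)
        if np.all(np.sqrt(rho_new) < tol):  # all t systems converged
            break
        beta = rho_new / rho
        P = R + P * beta
        rho = rho_new
    return X

# Usage on a random well-conditioned SPD system with t = 4 right-hand sides.
rng = np.random.default_rng(1)
n, t = 20, 4
G = rng.standard_normal((n, n))
A = G @ G.T + n * np.eye(n)
B = rng.standard_normal((n, t))
X = mbcg(lambda V: A @ V, B, p=n)
```

Only the matmul closure touches A, which is what lets the kernel matrix remain a black box (dense, structured, or implicit).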
3. Preconditioning and Convergence Acceleration
To improve convergence, mBCG applies preconditioning, most effectively implemented via a low-rank pivoted Cholesky factorization: the rank-k factor L_k yields the preconditioner P̂_k = L_kL_kᵀ + σ²I, where σ² is the observational noise variance.
Preconditioned mBCG solves P̂_k⁻¹AX = P̂_k⁻¹B, where applications of P̂_k and of P̂_k⁻¹ (via the Woodbury matrix identity) can be computed in O(nk) and O(nk²) time, respectively. The preconditioned updates require additional matrix solves with P̂_k, which are batched across columns.
Preconditioning, especially with ranks k ≤ 10, reduces the required number of CG iterations by roughly an order of magnitude for highly correlated kernels such as deep RBF or Matérn, typically lowering iteration counts from 20 to 2–4 for competitive inference accuracy (Gardner et al., 2018).
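The key preconditioner operation, applying (LLᵀ + σ²I)⁻¹ to a batch of columns, can be sketched with the Woodbury matrix inversion lemma. Here `L` is a random stand-in for the rank-k pivoted Cholesky factor and `precond_solve` is a hypothetical name:

```python
import numpy as np

def precond_solve(L, sigma2, V):
    """Apply (L @ L.T + sigma2 * I)^(-1) to each column of V via the Woodbury
    identity: only a k x k system is factorized, never an n x n one."""
    k = L.shape[1]
    inner = sigma2 * np.eye(k) + L.T @ L        # small k x k Gram system
    return (V - L @ np.linalg.solve(inner, L.T @ V)) / sigma2

rng = np.random.default_rng(2)
n, k, t = 30, 5, 3
L = rng.standard_normal((n, k))                 # stand-in low-rank factor
sigma2 = 0.5
V = rng.standard_normal((n, t))
sol = precond_solve(L, sigma2, V)
```

Since only a k×k linear system is solved, the per-column cost is linear in n, which is why a modest rank adds negligible overhead to each mBCG iteration.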
4. Extraction of Gaussian Process Inference Quantities
A core property of mBCG is that it enables computation of all terms needed for GP training and inference in a single call:
- Solves: A⁻¹y for posterior means, where A is the (noise-added) kernel matrix.
- Stochastic trace estimates: tr(A⁻¹ ∂A/∂θ) via random probe vectors z_i and their corresponding solutions A⁻¹z_i (Hutchinson-style estimation).
- Log-determinant: approximated by stochastic Lanczos quadrature using the CG coefficients (the α_j and β_j), which define a tridiagonal matrix T̃_i for each probe. The log-determinant is then estimated as log|A| ≈ (1/t) Σ_i ‖z_i‖² e₁ᵀ log(T̃_i) e₁.
Consequently, there is no need for separate Lanczos or additional iterative runs (Gardner et al., 2018).
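The CG-to-Lanczos translation behind the log-determinant estimate can be sketched for a single probe as follows (the helper `cg_tridiag` and the test matrix are assumptions for this sketch; the coefficient-to-tridiagonal mapping is the standard one):

```python
import numpy as np

def cg_tridiag(A, z, m):
    """Run m CG iterations on A x = z, collecting the step sizes alpha_j and
    beta_j; these implicitly define the Lanczos tridiagonal matrix T."""
    x = np.zeros_like(z)
    r = z - A @ x
    p = r.copy()
    rho = r @ r
    alphas, betas = [], []
    for _ in range(m):
        w = A @ p
        alpha = rho / (p @ w)
        x += alpha * p
        r -= alpha * w
        rho_new = r @ r
        beta = rho_new / rho
        alphas.append(alpha)
        betas.append(beta)
        p = r + beta * p
        rho = rho_new
    # Standard CG-to-Lanczos translation of the coefficients.
    T = np.zeros((m, m))
    for j in range(m):
        T[j, j] = 1.0 / alphas[j] + (betas[j - 1] / alphas[j - 1] if j > 0 else 0.0)
        if j + 1 < m:
            T[j, j + 1] = T[j + 1, j] = np.sqrt(betas[j]) / alphas[j]
    return T

# One probe: ||z||^2 * e1^T log(T) e1 approximates z^T log(A) z; averaging this
# quantity over t probes estimates log|A| = tr(log A).
rng = np.random.default_rng(4)
n = 8
G = rng.standard_normal((n, n))
A = G @ G.T + n * np.eye(n)
z = rng.standard_normal(n)
T = cg_tridiag(A, z / np.linalg.norm(z), m=n)
evals, evecs = np.linalg.eigh(T)
quad = (z @ z) * float(evecs[0] @ (np.log(evals) * evecs[0]))
```

Because the α_j and β_j fall out of the CG recurrence itself, the tridiagonal matrices come at essentially zero extra cost per probe.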
5. Relation to Cooperative and Block CG Methods
mBCG is related to block and cooperative CG paradigms, such as cCG (Bhaya et al., 2012), in which t "agents" simultaneously maintain their iterates, directions, and residuals, leveraging matrix-valued step sizes α and β for coordination. In cCG, inner Gram matrices PᵀAP and RᵀR (size t×t) are computed at every iteration, and updates occur by linear combinations across all agents. This yields finite termination in at most ⌈n/t⌉ steps and nearly t-fold speedup in wall-clock time on multicore hardware for large n, under exact arithmetic and full-rank assumptions. cCG reduces per-agent flop counts roughly by a factor of t at the optimal choice of t, compared to O(n³) for serial CG (Bhaya et al., 2012). A plausible implication is that mBCG, as a GPU-focused batched method, realizes similar efficiencies by organizing all probe systems as "agents" over matrix operations.
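A serial sketch of this pattern, block CG with t×t matrix-valued step sizes in the style of cCG (the function `block_cg`, the stopping rule, and the test problem are illustrative assumptions):

```python
import numpy as np

def block_cg(A, B, iters, tol=1e-10):
    """Block CG with t x t matrix-valued step sizes, the serial analogue of the
    cooperative scheme: all t 'agents' share search information each iteration."""
    n, t = B.shape
    X = np.zeros((n, t))
    R = B - A @ X
    P = R.copy()
    for _ in range(iters):
        AP = A @ P
        gram_old = R.T @ R                         # t x t Gram matrix
        # Matrix-valued step size: solve (P^T A P) alpha = R^T R.
        alpha = np.linalg.solve(P.T @ AP, gram_old)
        X += P @ alpha
        R -= AP @ alpha
        if np.linalg.norm(R) < tol * np.linalg.norm(B):
            break                                  # avoid a singular Gram solve
        beta = np.linalg.solve(gram_old, R.T @ R)
        P = R + P @ beta
    return X

rng = np.random.default_rng(3)
n, t = 12, 3
G = rng.standard_normal((n, n))
A = G @ G.T + n * np.eye(n)
B = rng.standard_normal((n, t))
X = block_cg(A, B, iters=8)
```

The contrast with mBCG is visible in the step sizes: here α and β are dense t×t solves coupling the systems, whereas mBCG keeps them diagonal so the columns evolve independently in batched kernels.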
6. Complexity, Implementation, and Empirical Benchmarks
The computational cost per mBCG iteration is dominated by the (potentially black-box) kernel matrix–matrix multiply AP, which costs O(n²t) if A is dense. Memory requirements are O(nt) for storing the batched variables. The total cost is O(pn²t) for p CG iterations and t probe vectors. In comparison, Cholesky decomposition costs O(n³), plus O(n²) per subsequent solve.
When kernel structure enables further acceleration (e.g., SKI or SoR), the cost per iteration is further reduced. Pivoted Cholesky preconditioning imposes negligible overhead for small ranks k.
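For a sense of scale, a back-of-the-envelope comparison of the leading cost terms (the values of n, p, and t are example choices for illustration, not benchmarks from the source):

```python
# Back-of-the-envelope leading-term flop counts; constants and lower-order
# terms are ignored, and n, p, t are hypothetical example values.
n, p, t = 50_000, 20, 16

mbcg_flops = p * n**2 * t      # p iterations, each a dense (n x n)(n x t) matmul
chol_flops = n**3 // 3         # Cholesky factorization leading term

print(f"mBCG ~ {mbcg_flops:.2e} flops vs Cholesky ~ {chol_flops:.2e} flops")
```

Under these example numbers the factorization-free path wins by well over an order of magnitude, before accounting for the superior GPU utilization of matrix–matrix products.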
Empirically, GPyTorch’s mBCG implementation achieves:
- 20–32× speedup for exact GPs
- 10–15× for SGPR
- 20× for SKI on high-end GPUs, outperforming previous CPU and GPU implementations (Gardner et al., 2018).
7. Implementation in GPyTorch and Practical Considerations
In GPyTorch, kernel objects provide routines for both matrix–vector and matrix–matrix products with the kernel matrix. All batched operations reside on the GPU, with minimal data transfer between CPU and device. Pivoted-Cholesky preconditioning is performed efficiently for low ranks and is automatically differentiable via PyTorch’s autograd mechanism. This permits backpropagation of the GP marginal likelihood through kernel and preconditioner computations “for free.” The main mBCG loop is expressed using high-level batched matrix algebra (e.g., torch.bmm, einsum), aligning well with hardware accelerators (Gardner et al., 2018).
References:
- (Gardner et al., 2018) "GPyTorch: Blackbox Matrix-Matrix Gaussian Process Inference with GPU Acceleration"
- (Bhaya et al., 2012) "A cooperative conjugate gradient method for linear systems permitting multithread implementation of low complexity"