Global Neighbor Sampling with Caching
- The paper introduces a global caching mechanism that reuses neighbor samples to accelerate computations in both Monte Carlo PDE solvers and graph neural network training.
- It employs efficient cache construction and kernel-based reweighting strategies to minimize redundant computations and reduce estimator variance.
- It demonstrates significant practical improvements, achieving up to 2×–14× speedups while maintaining accuracy in complex simulation and training tasks.
Global neighbor sampling with caching is a family of algorithms designed to accelerate sampling-based computations in large-scale problems, ranging from stochastic PDE solvers to distributed and mixed CPU–GPU training of graph neural networks (GNNs). These approaches use a global cache of samples or node data to reduce redundant computation, minimize data movement, or decrease estimator variance, in contrast to purely local or pointwise methods.
1. Fundamental Concepts and Definitions
Global neighbor sampling with caching operates on the core principle of constructing a reusable global set of samples (or node data) that efficiently supports repeated evaluation or learning queries across a domain or a graph. In the canonical setting of solving Laplace’s equation via stochastic representations on a domain $\Omega$, the method builds a cache of spatial “centers” coupled with stored walk data. Each walk samples a Brownian motion path outward from its center, storing both the first-exit and boundary-hit points, enabling flexible reuse through kernel-based reweighting to estimate the solution at nearby locations. In large-scale GNN training (e.g., SALIENT++ and GNS), the cache consists of selected high-utility node features, which reside in GPU memory and support repeated, importance-weight-corrected neighbor sampling for minibatch computation (Czekanski et al., 2024, Kaler et al., 2023, Dong et al., 2021).
2. Formal Problem Statement and Theoretical Basis
In the Monte Carlo Laplace setting, the problem is to approximate the solution of $\Delta u = 0$ in $\Omega$ with boundary data $u = g$ on $\partial\Omega$, via the stochastic representation $u(x) = \mathbb{E}[g(B_\tau)]$, where $B_t$ is Brownian motion started at $x$ and $\tau$ its first exit time from $\Omega$ (Theorem 2.1 in (Czekanski et al., 2024)). The global neighbor sampling approach constructs a grid-aligned covering of centers $\{c_i\}$ and, for each $c_i$, stores independent Walk-on-Spheres (WOS) trajectories $(y_i^{(j)}, z_i^{(j)})$, where $y_i^{(j)}$ is the first exit point from the sphere around $c_i$ and $z_i^{(j)}$ the eventual boundary hit. At a query point $x$ sufficiently deep in $\Omega$, the estimator is
$$\hat{u}(x) = \frac{\sum_{i,j} w_{i,j}(x)\, g(z_i^{(j)})}{\sum_{i,j} w_{i,j}(x)},$$
with local reweighting $w_{i,j}(x) \propto K(x, y_i^{(j)})$, where $K$ is the Poisson kernel on the sphere around $c_i$.
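The Poisson-kernel reweighting admits a compact closed form: for a sphere of radius $r$ around a center $c$, the ratio of the Poisson kernel at $x$ to the uniform exit density at $c$ equals $1$ when $x = c$ and tilts toward exit points near $x$ otherwise. A minimal sketch (symbol names are illustrative, not taken from the paper):

```python
import math

def poisson_weight(x, c, y, r, d=2):
    """Importance weight for reusing a walk cached at center c to estimate at x.

    The cached walk's first exit point y is uniform on the sphere of radius r
    around c; multiplying by the Poisson kernel ratio P(x, y) / P(c, y)
    makes the cached sample unbiased for any query point x inside the sphere.
    """
    xc2 = sum((xi - ci) ** 2 for xi, ci in zip(x, c))      # |x - c|^2
    xy = math.sqrt(sum((xi - yi) ** 2 for xi, yi in zip(x, y)))  # |x - y|
    return (r ** 2 - xc2) * r ** (d - 2) / xy ** d
```

At the center itself the weight collapses to $1$, recovering the plain WOS average; as $x$ approaches the sphere boundary the weights become more skewed, which is the variance–radius tradeoff discussed below.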
In GNN training, the analogous problem is to minimize the data movement and memory footprint in mini-batched, multi-hop neighbor sampling. Vertex-wise inclusion probabilities (VIPs) are computed for each vertex as the probability that it appears in a $k$-hop sample rooted at a particular partition. The static caching policy selects the highest-VIP nodes for replication on each worker, enabling efficient local and remote sampling (Kaler et al., 2023, Dong et al., 2021).
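A rough proxy for these inclusion probabilities can be computed by propagating mass from a partition's training nodes through the normalized adjacency. The sketch below is a simplified stand-in, not the exact VIP recursion of SALIENT++:

```python
import numpy as np

def vip_proxy(adj, train_nodes, hops=2):
    """Approximate vertex-inclusion-probability scores by propagating mass
    from a partition's training nodes for `hops` steps and accumulating it.
    Higher scores mark vertices more likely to appear in k-hop mini-batches.
    """
    n = adj.shape[0]
    deg = adj.sum(axis=1, keepdims=True)
    P = adj / np.maximum(deg, 1.0)         # row-stochastic transition matrix
    mass = np.zeros(n)
    mass[train_nodes] = 1.0 / len(train_nodes)
    score = mass.copy()
    for _ in range(hops):
        mass = mass @ P                    # one hop of neighbor expansion
        score += mass
    return score
```

Ranking vertices by such a score and caching the top fraction per worker is the essence of the static policy.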
3. Algorithmic Structure and Pseudocode
Monte Carlo PDE (Laplace) Global Neighbor Sampling
Cache Construction (Offline):
- Cover $\Omega$ with an $h$-spaced grid and select the grid points inside the domain as centers $\{c_i\}$.
- For each $c_i$, run WOS walks and cache the resulting first-exit/boundary-hit pairs $(y_i^{(j)}, z_i^{(j)})$.
Query (Online):
- For a query point $x$, find the nearby cached centers via spatial indexing.
- Aggregate the cached walk outcomes with Poisson-kernel reweighting to form the estimator $\hat{u}(x)$.
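The two phases can be sketched end-to-end. The following is a minimal, self-contained illustration on the unit disk, with the grid covering reduced to explicit center tuples for brevity; it is not the authors' implementation, and the 2-D Poisson kernel ratio is used for the reweighting:

```python
import math
import random

def wos_walk(start, g, eps=1e-3, rng=random):
    """One Walk-on-Spheres trajectory on the unit disk (illustrative domain).

    Returns (first_exit_point, boundary_value): the exit point of the first
    sphere step, and g evaluated at the eventual boundary hit."""
    x, y = start
    first_exit = None
    while True:
        r = 1.0 - math.hypot(x, y)         # distance to the unit circle
        if r < eps:                        # close enough: project to boundary
            n = math.hypot(x, y)
            return first_exit, g(x / n, y / n)
        theta = rng.uniform(0.0, 2.0 * math.pi)
        x, y = x + r * math.cos(theta), y + r * math.sin(theta)
        if first_exit is None:
            first_exit = (x, y)

def build_cache(centers, g, walks=2000, rng=random):
    """Offline phase: cache (first-exit, boundary-value) pairs per center."""
    return {c: [wos_walk(c, g, rng=rng) for _ in range(walks)] for c in centers}

def query(x, cache, centers):
    """Online phase: Poisson-kernel-reweighted average of cached walks.

    For a center c with first-step radius r, the 2-D Poisson kernel ratio
    (r^2 - |x-c|^2) / |x-y|^2 makes each cached boundary value unbiased
    for the solution at the query point x."""
    num = den = 0.0
    for c in centers:
        r = 1.0 - math.hypot(*c)           # first-step sphere radius at c
        if math.dist(x, c) >= r:
            continue                       # x lies outside this center's sphere
        for y_exit, gz in cache[c]:
            w = (r * r - math.dist(x, c) ** 2) / math.dist(x, y_exit) ** 2
            num += w * gz
            den += w
    return num / den
```

Querying with the harmonic function $g(x, y) = x$ on the boundary recovers $u(x, y) = x$ inside the disk up to Monte Carlo error, which is a quick sanity check on the reweighting.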
Distributed/Mixed CPU–GPU GNN Training
Cache Construction:
- Compute per-node sampling probabilities (degree-based or via propagation of initial distribution from training nodes).
- Sample a global cache of nodes to reside in GPU memory, proportional to their utility.
- For each mini-batch, perform in-GPU neighbor sampling, preferring cached nodes and applying importance correction where needed (Dong et al., 2021).
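The cache-preferring step with importance correction can be illustrated in isolation. This is a hedged sketch, not the GNS implementation; the `boost` knob and function names are assumptions:

```python
import numpy as np

def biased_neighbor_sample(neighbors, cached, fanout, boost=4.0, rng=None):
    """Sample `fanout` neighbors with replacement, preferring cached nodes.

    Cached neighbors get `boost` times the base sampling probability; each
    draw carries the importance weight p_uniform / p_biased so that a
    weighted aggregation over the sample stays unbiased.
    """
    rng = rng or np.random.default_rng()
    nbrs = np.asarray(neighbors)
    raw = np.where(np.isin(nbrs, cached), boost, 1.0)
    p = raw / raw.sum()                    # biased sampling distribution
    idx = rng.choice(len(nbrs), size=fanout, p=p)
    uniform = 1.0 / len(nbrs)
    weights = uniform / p[idx]             # importance correction per draw
    return nbrs[idx], weights
```

Cached draws receive down-weighted contributions (they are over-sampled) and uncached draws up-weighted ones, so the expected aggregate matches uniform neighbor sampling.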
SALIENT++ / VIP-Based Distributed Caching:
- Given graph partitions and mini-batch scheme, compute the exact VIPs for each remote vertex.
- Statically rank remote vertices by VIP and cache the highest-VIP candidates up to the allowed replication factor.
- The cache remains static per epoch (or several epochs). Queries to uncached nodes are rare and handled asynchronously, overlapping communication with computation.
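The ranking-and-selection step itself is simple; a sketch of the static policy (local vertices need no replication, so only remote vertices compete for the budget):

```python
import numpy as np

def select_cache(vip_scores, local_nodes, budget):
    """Static cache policy: replicate the `budget` highest-VIP remote vertices.

    `vip_scores[v]` estimates the probability vertex v appears in a mini-batch
    sample rooted at this partition; `local_nodes` are already resident.
    """
    remote = np.setdiff1d(np.arange(len(vip_scores)), local_nodes)
    order = remote[np.argsort(vip_scores[remote])[::-1]]   # descending VIP
    return set(order[:budget].tolist())
```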
Pseudocode snippets outlining these phases are provided in the original sources (Czekanski et al., 2024, Kaler et al., 2023, Dong et al., 2021).
4. Data Structures, Sampling Strategies, and Complexity
The cache for Laplace’s equation consists of the per-center walk tuples, indexed spatially (e.g., by a kd-tree, or a uniform grid with cell size matched to the query radius) to support efficient retrieval of the cached centers near each query point. In GNN methods, the GPU cache includes a dense feature tensor indexed by position, adjacency lists of cached neighbors for each node, and mapping tables for gather/scatter during neighbor aggregation (Dong et al., 2021).
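A uniform grid with cell size equal to the query radius is the simplest such index, since a radius query then only inspects the surrounding block of cells. A minimal 2-D sketch (class and method names are illustrative):

```python
import math
from collections import defaultdict

class GridIndex:
    """Uniform-grid spatial index for cached centers: cell size equals the
    query radius, so a radius query only inspects the 3x3 surrounding cells."""

    def __init__(self, radius):
        self.r = radius
        self.cells = defaultdict(list)

    def _cell(self, p):
        return tuple(int(math.floor(c / self.r)) for c in p)

    def insert(self, p):
        self.cells[self._cell(p)].append(p)

    def neighbors(self, q):
        cx, cy = self._cell(q)
        out = []
        for dx in (-1, 0, 1):
            for dy in (-1, 0, 1):
                for p in self.cells.get((cx + dx, cy + dy), []):
                    if math.dist(p, q) <= self.r:
                        out.append(p)
        return out
```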
Neighbor selection strategies include:
- Equal-weight scheme: restricts the neighbor set to a radius within which per-walk variance bounds hold, then averages with equal weights.
- Inverse-variance weighting: expands the set to all nearby centers, with weights chosen to minimize the variance of the combined estimate (Czekanski et al., 2024).
- VIP-based selection (GNN): statically weights and ranks cache candidates by their expected inclusion probability, minimizing communication (Kaler et al., 2023).
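For independent unbiased estimators, weighting each by the inverse of its variance is the classical minimum-variance combination, which is the principle behind the second scheme. A small sketch:

```python
def inverse_variance_combine(estimates, variances):
    """Combine independent unbiased per-center estimates with weights
    proportional to 1/variance (the minimum-variance linear combination).

    Returns (combined_estimate, variance_of_combined_estimate)."""
    weights = [1.0 / v for v in variances]
    total = sum(weights)
    value = sum(w * e for w, e in zip(weights, estimates)) / total
    combined_var = 1.0 / total             # always <= min(variances)
    return value, combined_var
```

A high-variance far-away center thus contributes little but never hurts, which is why the full-kernel scheme can safely include all neighbors.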
Offline cache-construction cost scales with the number of centers times the walks per center, each WOS walk terminating after $O(\log(1/\varepsilon))$ sphere steps for boundary tolerance $\varepsilon$; the online query cost is just the neighbor aggregation, so the end-to-end cost grows linearly in the number of evaluations (Czekanski et al., 2024). For GNNs, the cache holds only a small fraction of all node features, and the communication reduction is proportional to the cache hit rate (Dong et al., 2021).
5. Variance Reduction and Statistical Guarantees
In the Laplace setting, the variance of the reused-walk estimator via global neighbor sampling is provably reduced: with $k$ effective neighbor centers contributing walks, variance drops by a factor on the order of $k$ compared to pointwise estimation (Theorem 4.3, (Czekanski et al., 2024)), with the constant controlled by a bound on the range of $g$. Lemma 4.2 quantifies the variance–neighbor-radius tradeoff, while cache covering and neighbor selection determine the effective $k$ for variance stacking.
For GNNs, theoretical analysis shows that once the cache size and sampling fan-out satisfy the condition stated in (Dong et al., 2021), the mean-squared error of gradients under cached sampling matches the order of that from the full node-wise sampler. The convergence rate of stochastic gradient descent is thus preserved for a sufficiently large cache.
6. Empirical and Practical Assessment
Empirical evaluations demonstrate the variance and error reduction in Laplace-PDE applications: on the tested domains, equal-weighted and variance-weighted estimators achieve markedly lower error than the original WOS under a fixed walk budget, and error remains stable as the number of evaluation points grows when walk reuse is employed (Czekanski et al., 2024).
In large-scale GNN settings, global neighbor sampling with caching yields significant speedups over node-wise sampling and over layer-based sampling such as LADIES, with negligible or no loss in model accuracy. High cache hit rates are typical at replication fractions around $0.3$, and communication overhead becomes negligible due to pipelined overlap (Kaler et al., 2023, Dong et al., 2021). Table-based results in the primary references summarize epoch times, F1 scores, and the dependence on cache size across several publicly available datasets.
| Application | Method (Cache) | Speedup | Accuracy Impact |
|---|---|---|---|
| Laplace equation | Global neighbor sampling (WOS) | Lower error at fixed budget | Reduced variance |
| GNN (Products) | GNS, SALIENT++ | Up to 2×–14× | Matches baseline F1 |
7. Trade-offs and Tuning Considerations
Key tuning parameters for successful deployment are cache size (e.g., the spatial grid spacing $h$, or the replication fraction in distributed GNN training) and the neighbor-selection radius or weight function. Decreasing $h$ increases cache memory and offline cost but allows more accurate near-boundary queries and greater variance reduction. Expanding the neighbor radius (or broadening the weight function) raises the effective number of walks participating in each estimate and thus reduces variance, with diminishing returns once per-walk variance grows. In distributed GNN contexts, the replication fraction directly controls the memory–bandwidth trade-off and is typically set to achieve a high cache hit rate.
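The replication-fraction choice can be guided by a simple calculation: under the assumption (hypothetical, for illustration) that access frequency is proportional to VIP score, caching the top fraction of vertices captures the corresponding share of accesses:

```python
import numpy as np

def hit_rate_curve(vip_scores, fractions):
    """Expected cache hit rate vs. replication fraction, assuming accesses
    are proportional to VIP scores: caching the top-scoring fraction of
    vertices captures that fraction's share of the total score mass."""
    s = np.sort(np.asarray(vip_scores, dtype=float))[::-1]
    cum = np.cumsum(s) / s.sum()
    n = len(s)
    return {f: float(cum[max(int(f * n), 1) - 1]) for f in fractions}
```

Because VIP distributions on real graphs are heavily skewed, such a curve typically saturates early, which is why modest replication fractions already yield high hit rates.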
By appropriate parameter selection, global neighbor sampling with caching transforms methods with otherwise superlinear or communication-bound scaling into computationally efficient, linear-in-query (or epoch) algorithms with substantial variance and bandwidth reduction (Czekanski et al., 2024, Kaler et al., 2023, Dong et al., 2021).