CuTile-Based Flash Attention on NVIDIA GB10
- CuTile-based Flash Attention is an optimization that adapts FlashAttention to a tile-centric model, fusing softmax and matrix multiplication without materializing the dense attention matrix.
- It employs Sawtooth Wavefront Reordering to alternate K/V tile scan directions, effectively halving L2 cache reuse distance and reducing non-compulsory cache misses by up to 67%.
- Empirical evaluations on NVIDIA GB10 demonstrate throughput improvements of up to 60% for causal and non-causal workloads, enabling efficient large language model performance.
CuTile-based Flash Attention refers to the adaptation of the FlashAttention algorithm in NVIDIA’s high-level tile-centric tensor-core programming model, CuTile, and its associated performance optimizations on modern GPU architectures, notably NVIDIA Grace Blackwell (GB10). FlashAttention fuses the softmax and matrix-multiplication operations underlying Transformer attention, streaming query (Q) tiles only once and repeatedly streaming key (K) and value (V) tiles through on-chip SRAM. Unlike traditional attention implementations, FlashAttention does not materialize the dense attention matrix, leading to significant memory and compute advantages. CuTile-based FlashAttention formalizes this logic using CuTile’s abstractions, producing high-throughput kernels suitable for LLM workloads. The key challenge is maximizing cache efficiency for the K/V working set, particularly as sequence lengths and head dimensions increase, potentially exceeding last-level cache (L2) capacity.
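Before the CuTile specifics, the fused streaming computation itself can be illustrated with a plain NumPy sketch of the online-softmax tile loop (a reference model of the algorithm, not the CuTile kernel; function and variable names are illustrative):

```python
import numpy as np

def flash_attention_reference(Q, K, V, tile=64):
    """Streaming attention: each Q tile scans the K/V tiles once while
    carrying running softmax statistics, so the dense S x S score matrix
    is never materialized. Reference sketch only."""
    S, d = Q.shape
    out = np.zeros_like(Q)
    for q0 in range(0, S, tile):
        q = Q[q0:q0 + tile]                    # current query tile
        m = np.full(q.shape[0], -np.inf)       # running row maxima
        l = np.zeros(q.shape[0])               # running softmax denominators
        acc = np.zeros((q.shape[0], d))        # unnormalized output accumulator
        for k0 in range(0, S, tile):           # stream K/V tiles
            s = q @ K[k0:k0 + tile].T / np.sqrt(d)
            m_new = np.maximum(m, s.max(axis=1))
            p = np.exp(s - m_new[:, None])     # rescaled partial softmax
            scale = np.exp(m - m_new)          # correct previous partials
            l = l * scale + p.sum(axis=1)
            acc = acc * scale[:, None] + p @ V[k0:k0 + tile]
            m = m_new
        out[q0:q0 + tile] = acc / l[:, None]
    return out
```

The inner K/V scan is exactly the loop that Sawtooth Wavefront Reordering later re-orders; the online rescaling is what makes the fusion possible without storing the score matrix.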
1. Architectural Context and Motivation
FlashAttention’s fused softmax and matmul kernel design enables streaming access patterns: Q-tiles are processed sequentially, while each CTA (Cooperative Thread Array) streams K/V tiles repeatedly from global memory into shared SRAM. CuTile, NVIDIA’s tile-centric programming model, enables description of such tensor-core kernels at a high level and facilitates mapping of the split-Q FlashAttention approach.
On NVIDIA GB10—comprising 48 streaming multiprocessors (SMs) and a unified 24 MiB L2 cache—L1 and texture cache hit rates for streaming attention are negligible (<0.02%), making L2 cache utilization the principal concern. When the K/V working set exceeds approximately 20 MiB (corresponding to sequence lengths ≈ 80,000), L2 misses escalate markedly. Under cyclic inner-loop traversal (where all CTAs process K/V tiles in ascending order), the least-recently-used (LRU) reuse distance between accesses of the same line equals the entire K/V working set, causing capacity misses once the working set exceeds L2 capacity. Empirical SM-scaling studies show the L2 hit rate falling as more SMs participate, indicative of a synchronous wavefront fill followed by limited reuse as the memory footprint grows (Zhu et al., 22 Jan 2026).
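The ≈ 80,000 sequence-length threshold follows from a back-of-envelope working-set calculation; fp16 (2-byte) elements and head dimension 64 are assumptions consistent with the configuration reported later in this article:

```python
# K/V working set versus GB10's unified 24 MiB L2.
# Assumptions: fp16 (2 B) elements, head dimension 64.
L2_BYTES = 24 * 2**20

def kv_working_set_bytes(seq_len, head_dim=64, elem_bytes=2):
    return 2 * seq_len * head_dim * elem_bytes  # K plus V

for s in (32_000, 80_000, 131_072):
    ws = kv_working_set_bytes(s)
    print(f"S={s:>7}: {ws / 2**20:6.1f} MiB, exceeds L2: {ws > L2_BYTES}")
```

At S = 80,000 the K/V working set lands near 19.5 MiB, matching the ~20 MiB threshold at which L2 misses escalate.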
2. Analytical Modeling of L2 Access Patterns
CuTile-based FlashAttention can be analytically modeled to predict L2 sector accesses as a function of sequence length ($S$), tile size ($T$), head dimension ($d$), element size ($e$, in bytes), and cache sector size ($c$, in bytes). For a single batch and single head, with $N = S/T$ tiles along the sequence dimension, the streaming pattern yields the following first-order totals for L2 sector accesses:
- Without causal masking, every Q-tile scans all $N$ K/V tiles: $A_{\mathrm{nc}} = \frac{S d e}{c}\left(1 + \frac{2S}{T}\right)$
- With causal masking, Q-tile $i$ scans only K/V tiles $0, \ldots, i$: $A_{\mathrm{c}} = \frac{S d e}{c}\left(2 + \frac{S}{T}\right)$
The source's closed-form equations (Eq. (2) and Eq. (3)), which take this form, achieve <1% mean absolute percentage error versus measured values and facilitate system-level performance projections.
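The analytical model can be sketched in code directly from the streaming pattern described above (Q streamed once; each Q-tile scanning all K/V tiles, or only the first $i + 1$ under causal masking). This is a first-order reconstruction from the text's description; the source's exact Eq. (2)/(3) may carry additional correction terms:

```python
def l2_sectors(S, T=64, d=64, e=2, c=32, causal=False):
    """First-order L2 sector-access count for one batch and one head.

    Reconstructed from the streaming pattern (Q streamed once, K/V tiles
    re-streamed once per Q-tile). Defaults mirror the configuration used
    later (64x64 tiles, d = 64); fp16 elements and 32 B sectors are
    assumptions.
    """
    n = S // T                          # tiles along the sequence
    q_sectors = S * d * e // c          # Q is streamed exactly once
    kv_tile_sectors = T * d * e // c    # sectors in one K or one V tile
    # Causal masking: Q-tile i scans K/V tiles 0..i, so each of K and V
    # is visited n*(n+1)/2 tile-times instead of n*n.
    kv_tile_visits = n * (n + 1) // 2 if causal else n * n
    return q_sectors + 2 * kv_tile_visits * kv_tile_sectors
```

Because the K/V term grows quadratically in $S/T$ while the Q term is linear, the model makes explicit why the K/V working set dominates L2 traffic at long sequence lengths.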
3. Sawtooth Wavefront Reordering
A principal innovation for reducing non-compulsory L2 cache misses is Sawtooth Wavefront Reordering. Instead of the conventional cyclic scan—where each Q-tile iteration traverses K/V tiles in the same direction (e.g., tile 0 up to tile N_KV - 1)—the sawtooth strategy alternates the direction of the inner K/V scan between successive Q-tiles: even-numbered iterations traverse forward, odd-numbered iterations traverse backward. The result is a “zig-zag” pattern over the K/V tiles, so that adjacent CTAs on the same or neighboring SMs reuse recently cached lines before eviction.
Pseudocode (algorithmically equivalent to Algorithm 4 in (Zhu et al., 22 Jan 2026)):
```
i_local ← 0
for each query tile q in Q_seq do
    if i_local mod 2 == 0 then
        start ← 0;        end ← N_KV;  step ← +1    // forward scan
    else
        start ← N_KV - 1; end ← -1;    step ← -1    // backward scan
    end if
    for j ← start; j ≠ end; j ← j + step do         // end is exclusive
        load K_j, V_j from global memory
        compute Attention(q, K_j, V_j)
    end for
    i_local ← i_local + 1                           // advance once per Q-tile
end for
```
This reordering approximately halves the LRU reuse distance, which now becomes at most half the K/V working set for the majority of accesses. The benefit is pronounced when the working set size approaches or exceeds L2 capacity.
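The halved reuse distance can be demonstrated with a toy LRU simulation over tile IDs (illustrative tile counts and cache capacity, not GB10's actual geometry). When the K/V working set exceeds cache capacity, the cyclic scan thrashes on every access, while the sawtooth scan retains roughly half the working set across each direction reversal:

```python
from collections import OrderedDict

def lru_misses(access_stream, capacity):
    """Count misses for a fully associative LRU cache of `capacity` tiles."""
    cache, misses = OrderedDict(), 0
    for t in access_stream:
        if t in cache:
            cache.move_to_end(t)            # hit: refresh recency
        else:
            misses += 1
            if len(cache) >= capacity:
                cache.popitem(last=False)   # evict least recently used
        cache[t] = True
    return misses

def kv_stream(n_q, n_kv, sawtooth):
    """Tile-ID access stream: one K/V scan per Q-tile."""
    for i in range(n_q):
        order = range(n_kv)
        if sawtooth and i % 2 == 1:
            order = reversed(range(n_kv))   # alternate scan direction
        yield from order

n_q, n_kv, cap = 16, 100, 60   # working set (100 tiles) exceeds cache (60)
print(lru_misses(kv_stream(n_q, n_kv, False), cap))  # cyclic:   1600 misses
print(lru_misses(kv_stream(n_q, n_kv, True), cap))   # sawtooth:  700 misses
```

In this toy setting the cyclic scan misses on all 1,600 accesses, whereas the sawtooth scan misses only on the compulsory first pass plus the 40-tile overhang per subsequent scan, illustrating the reuse-distance argument above.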
4. Implementation in CUDA and CuTile
For persistent-CTA grid-stride looping in CUDA, Sawtooth Wavefront Reordering is implemented by introducing a per-CTA counter (i_local) that determines scan direction via bit-wise operations in registers or shared memory. No additional synchronization or barriers are required beyond the original FlashAttention kernel.
Example code segment within the main loop:
```cuda
int i_local = 0;
while (more_q_tiles) {                    // persistent-CTA loop (placeholder condition)
    bool forward = ((i_local & 1) == 0);
    int j_start = forward ? 0 : N_KV - 1;
    int j_end   = forward ? N_KV : -1;    // exclusive bound
    int j_step  = forward ? +1 : -1;
    for (int j = j_start; j != j_end; j += j_step) {
        // load K/V tile j into shared memory, perform WMMA operations
    }
    i_local++;
}
```
For the CuTile Python kernel, the .scan() method over K/V tiles is replaced by logic generating the appropriate alternation:
```python
for t in tile_range(0, num_tiles, step=2):
    tile_order = [t, t + 1] if (t // 2) % 2 == 0 else [t + 1, t]
    for j in tile_order:
        ...  # load and compute
```
Typical tile size is 64×64, with batch size 8, sequence length up to 128k, and head dimension 64.
5. Empirical Performance Evaluation
Comprehensive performance assessments were conducted for both CUDA and CuTile implementations:
| Version | L2 Misses (non-causal) | L2 Misses (causal) | Throughput |
|---|---|---|---|
| Standard CuTile | ~370 M | – | 61 TFLOPS |
| Sawtooth CuTile | ~120 M (-67%) | – | 69 TFLOPS (+13%) |
| Standard CuTile (causal) | – | – | 41 TFLOPS |
| Sawtooth CuTile (causal) | – | – | 66 TFLOPS (+60%) |
On CUDA, non-compulsory L2 misses dropped by approximately 50% for batch sizes 1, 2, 4, and 8. Throughput improved from approximately 1.3 TFLOPS to 2.4 TFLOPS (+85%).
The advantage of sawtooth reordering is most pronounced when the K/V working set approaches or exceeds L2 capacity. When the working set sits well below L2 capacity, or when L1 becomes effective (non-streaming scenarios), the gains diminish (Zhu et al., 22 Jan 2026).
6. Trade-offs, Limitations, and Applicability
Overheads introduced by Sawtooth Wavefront Reordering are negligible in practice. Only minor additional instructions are required for direction bit computation and index range adjustment. The approach is most effective when all adjacent CTAs follow the same parity (even/odd) pattern—a property naturally present in persistent grid-stride implementations.
CuTile’s compiler may subdivide tiles ≥128 if they do not fit into L1/Tex, modifying the scan pattern and reducing or eliminating the benefit of the sawtooth technique. In scenarios where the K/V working set is much smaller than L2 cache, or where L1 cache is utilized due to different access patterns, sawtooth ordering confers little or no benefit.
7. Conclusions and Future Directions
CuTile-based FlashAttention on NVIDIA GB10, with Sawtooth Wavefront Reordering, demonstrates that systematic restructuring of memory access patterns can substantially improve L2 cache reuse and, consequently, kernel throughput for LLM attention workloads. The primary contributions include empirical analysis of L2 performance, an analytical model for L2 accesses, and the sawtooth reordering technique that halves approximate reuse distance. The optimization is validated in both CUDA and CuTile environments, with consistent reductions in L2 misses (up to 67%) and throughput gains (up to 60%).
Prospective directions involve generalizing sawtooth reordering to other streaming kernel domains, compiler automation for detection and injection of alternating scan patterns, exploration of more complex SM scheduling policies, and integration with dynamic tile sizing and/or hardware-level directives for fine-grained cache occupancy management (Zhu et al., 22 Jan 2026).