
FlashAttention-4: Algorithm and Kernel Pipelining Co-Design for Asymmetric Hardware Scaling

Published 5 Mar 2026 in cs.CL | (2603.05451v1)

Abstract: Attention, as a core layer of the ubiquitous Transformer architecture, is the bottleneck for LLMs and long-context applications. While FlashAttention-3 optimized attention for Hopper GPUs through asynchronous execution and warp specialization, it primarily targets the H100 architecture. The AI industry has rapidly transitioned to deploying Blackwell-based systems such as the B200 and GB200, which exhibit fundamentally different performance characteristics due to asymmetric hardware scaling: tensor core throughput doubles while other functional units (shared memory bandwidth, exponential units) scale more slowly or remain unchanged. We develop several techniques to address these shifting bottlenecks on Blackwell GPUs: (1) redesigned pipelines that exploit fully asynchronous MMA operations and larger tile sizes, (2) software-emulated exponential and conditional softmax rescaling that reduces non-matmul operations, and (3) leveraging tensor memory and the 2-CTA MMA mode to reduce shared memory traffic and atomic adds in the backward pass. We demonstrate that our method, FlashAttention-4, achieves up to 1.3× speedup over cuDNN 9.13 and 2.7× over Triton on B200 GPUs with BF16, reaching up to 1613 TFLOPs/s (71% utilization). Beyond algorithmic innovations, we implement FlashAttention-4 entirely in CuTe-DSL embedded in Python, achieving 20–30× faster compile times compared to traditional C++ template-based approaches while maintaining full expressivity.

Summary

  • The paper details a hardware-aware co-design that pipelines asynchronous MMA operations and emulates exponential functions to address GPU scaling bottlenecks.
  • It integrates conditional softmax rescaling and TMEM-based tile scheduling to reduce memory traffic, warp divergence, and atomic contention.
  • Empirical results demonstrate up to 1.3× to 2.7× performance gains over prior methods, underscoring the value of algorithm-hardware co-design for modern AI accelerators.


Motivation and Problem Formulation

The proliferation of Transformer-based models has generated substantial demand for efficient GPU-accelerated attention mechanisms, particularly as hardware architectures exhibit increasingly asymmetric scaling patterns. On NVIDIA Blackwell GPUs (B200, GB200), tensor core throughput has doubled relative to Hopper H100, while other resources such as shared memory bandwidth and exponential unit throughput scale more slowly or remain unchanged. This asymmetry creates new bottlenecks in attention workloads, fundamentally altering performance regimes and necessitating algorithm-hardware co-design.

FlashAttention-4 directly addresses these bottlenecks by redesigning the attention pipeline, exploiting Blackwell GPU features, and mitigating limiting factors outside of pure matrix multiply-accumulate (MMA) work. Key innovations include asynchronous MMA operation pipelining, software-emulated exponential computation, conditional softmax rescaling, and reductions in shared memory traffic and global atomic adds. All kernel development is realized in CuTe-DSL embedded in Python, which accelerates compilation (20–30×) and fosters greater extensibility.

Hardware-Aware Algorithm Design

Asynchronous MMA and Tile Scheduling

The introduction of tensor memory (TMEM) in Blackwell enables fully asynchronous MMA instructions. Where Hopper held MMA accumulators in registers (dramatically constraining tile sizes and register allocation), Blackwell's tensor cores write to TMEM directly, supporting larger tiles and more aggressive pipelining. FlashAttention-4 employs overlapping pipelines—while one tile executes tensor core operations, the other computes softmax—maximizing resource utilization.
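A rough intuition for why this overlap pays off can be expressed with a toy two-stage latency model (the stage timings below are illustrative placeholders, not measured kernel numbers): once the pipeline is full, the overlapped schedule is bound by the slower stage rather than by the sum of both stages.

```python
def makespan(n_tiles: int, t_mma: float, t_softmax: float, overlapped: bool) -> float:
    """Toy latency model for a two-stage tile pipeline.

    Without overlap, each tile pays for its MMA and its softmax in
    sequence. With overlap, the softmax of tile i runs while the MMA of
    tile i+1 executes, so the steady state is bound by the slower stage.
    """
    if not overlapped:
        return n_tiles * (t_mma + t_softmax)
    # Fill the pipeline with one MMA, run n-1 overlapped steps, then drain.
    return t_mma + (n_tiles - 1) * max(t_mma, t_softmax) + t_softmax
```

For example, with 8 tiles at 2 time units of MMA and 3 units of softmax each, the serial schedule costs 40 units while the overlapped one costs 26.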

Software-Emulated Exponential Functions

Softmax computation is a critical non-MMA bottleneck, limited by the modest throughput of the multi-function unit (MUFU): 16 ops/clock/SM vs. 8192 ops/clock/SM for MMA. FlashAttention-4 addresses this by distributing exponential computations across both the MUFU and floating-point FMA units via polynomial approximation. Range reduction (Cody-Waite) and bit manipulation yield efficient 2^x evaluation. Empirical tests demonstrate that a degree-3 polynomial matches MUFU hardware to within 1 BF16 ULP on 99% of inputs; BF16 quantization dominates the error profile for all polynomial degrees ≥3.
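The idea can be sketched in a few lines of Python (the Taylor coefficients and rounding scheme here are illustrative stand-ins, not the paper's tuned constants): split x into integer and fractional parts, evaluate a short polynomial on the fraction with FMA-friendly operations, and apply the integer part by writing IEEE-754 exponent bits directly.

```python
import math
import struct

def exp2_poly3(x: float) -> float:
    """Approximate 2**x without a special-function unit.

    Range reduction: x = n + r with n an integer and r in [-0.5, 0.5].
    A degree-3 polynomial evaluates 2**r, then 2**n is applied by
    constructing its IEEE-754 bit pattern (assumes n stays in the
    normal exponent range; coefficients are plain Taylor terms, not
    minimax-tuned as in the paper).
    """
    n = math.floor(x + 0.5)        # round to nearest integer
    r = x - n                      # fractional remainder in [-0.5, 0.5]
    ln2 = 0.6931471805599453
    # Horner evaluation of 1 + t + t^2/2 + t^3/6 with t = r*ln2.
    p = 1.0 + r * (ln2 + r * (ln2 * ln2 / 2.0 + r * ln2 ** 3 / 6.0))
    # 2**n via exponent bits: biased exponent (n + 127) in a float32.
    bits = (n + 127) << 23
    scale = struct.unpack("f", struct.pack("I", bits))[0]
    return p * scale
```

Even this untuned polynomial stays within about 10^-3 relative error over the reduced range, which is below BF16's ~2^-8 precision; the kernel splits work so some rows use this FMA path while others use the MUFU.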

Conditional Softmax Rescaling

FlashAttention's online softmax involves recursive renormalization steps for numerical stability, but rescaling is only necessary upon encountering a new maximum. FlashAttention-4 introduces conditional rescaling, skipping unnecessary renormalizations and tolerating "slack" within a threshold, with a final normalization step ensuring exact outputs. This reduces vector multiplications, warp divergence, and register pressure without compromising correctness.
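A sketch of the idea for a single attention row (the threshold, block size, and scalar accumulators below are illustrative, not the kernel's layout): the exponent shift m only guards against overflow, so the accumulators need rescaling only when a block's maximum exceeds m by more than a slack threshold, and the final division restores the exact result regardless of when rescaling happened.

```python
import numpy as np

def online_attn_conditional(s, v, tau=8.0, block=4):
    """Compute sum(softmax(s) * v) for one row, blockwise.

    m is the exponent shift currently baked into the accumulators; it is
    advanced (and the accumulators rescaled) only when a block's max
    exceeds m by more than tau, since exp(s - m) stays safe up to
    exp(tau). The final division normalizes exactly either way.
    """
    m = -np.inf   # current shift
    l = 0.0       # running sum of exp(s - m)
    acc = 0.0     # running sum of exp(s - m) * v
    for i in range(0, len(s), block):
        sb, vb = s[i:i + block], v[i:i + block]
        mb = sb.max()
        if mb > m + tau:                       # rescale only past the slack
            scale = np.exp(m - mb) if np.isfinite(m) else 0.0
            l, acc, m = l * scale, acc * scale, mb
        p = np.exp(sb - m)
        l += p.sum()
        acc += (p * vb).sum()
    return acc / l
```

With tau = 0 this degenerates to the standard online softmax that rescales on every new maximum; larger tau trades a slightly wider exponent range for fewer rescaling steps.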

Backward Pass Pipeline and Memory Optimization

Pipeline Overlap and Resource Partitioning

Backward computation is dominated by five MMA operations (recomputing S, plus the gradients flowing through the QK and PV matmuls). FlashAttention-4 exploits TMEM and aggressive tile partitioning to overlap MMA and non-MMA work, minimizing serialization and hiding softmax latency. Figure 1 illustrates the backward computational graph.

Figure 1: FlashAttention-4 backward computation graph (5 MMA operations + 2 elementwise operations), showing the 1-CTA MMA mode software pipeline order across the prologue, main loop, and tail.

TMEM partitioning is carefully managed: the memory can only fit four 128×128 tiles, leading to sharing between S, P and dP, dS, dQ. The pipeline is ordered for overlap and resource conservation.

2-CTA Mode: Traffic and Atomic Reduction

Blackwell’s 2-CTA MMA mode enables CTA pairs to jointly execute MMAs and share TMEM. In the backward dQ step, DSMEM is used to exchange half of the dS tile across CTAs, allowing each to process an M/2 × 2N operand with a doubled reduction, halving the atomic operations. This alleviates shared memory contention and atomic bottlenecks.

Figure 2: In the 2-CTA backward dQ step, the CTA pair uses DSMEM to exchange half of the dS tile so each CTA forms an M/2 × 2N operand and runs a CTA-pair UMMA with a doubled reduction.
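The arithmetic behind the halved atomics can be checked numerically: concatenating the exchanged dS halves into an M/2 × 2N operand and multiplying by the stacked K tiles yields the same dQ contribution as two separate per-tile accumulations. The shapes below are hypothetical, and the real exchange goes through DSMEM, not NumPy reshapes.

```python
import numpy as np

def dq_two_atomics(dS_a, dS_b, K_a, K_b):
    # Baseline: each CTA separately accumulates its tile's contribution
    # into dQ, costing one atomic reduction per CTA (two in total).
    return dS_a @ K_a + dS_b @ K_b

def dq_2cta(dS_a, dS_b, K_a, K_b):
    # 2-CTA style: the pair exchanges half of each dS tile so every CTA
    # owns an (M/2 x 2N) operand against the stacked (2N x d) K tiles,
    # folding both contributions into a single doubled reduction per
    # half of the rows.
    M = dS_a.shape[0]
    top = np.concatenate([dS_a[: M // 2], dS_b[: M // 2]], axis=1)  # CTA a
    bot = np.concatenate([dS_a[M // 2:], dS_b[M // 2:]], axis=1)    # CTA b
    K_pair = np.concatenate([K_a, K_b], axis=0)                     # (2N, d)
    return np.concatenate([top @ K_pair, bot @ K_pair], axis=0)
```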

Deterministic Backward Execution

To ensure reproducible gradient computation (essential for RL and debugging), FlashAttention-4 implements deterministic reduction using semaphore locks and careful CTA-order scheduling. Swizzling over heads/batches and shortest-processing-time-first (SPT) scheduling are used to minimize stalls and achieve high deterministic throughput.
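The core of the deterministic mode fits in a few lines: floating-point addition is not associative, so bit-reproducibility comes from forcing partial gradients to be accumulated in one fixed CTA order. This is only a sketch of the ordering constraint; the kernel enforces it with semaphores across concurrent CTAs, not a sequential loop.

```python
import numpy as np

def deterministic_accumulate(partials):
    """Add per-CTA partial gradients in a fixed order (CTA 0, 1, 2, ...),
    mimicking a semaphore chain where CTA i may add only after CTA i-1
    finishes. Because float addition is non-associative, fixing this
    order is what makes the reduced gradient bit-identical across runs.
    """
    total = np.zeros_like(partials[0])
    for p in partials:          # always the same order
        total = total + p
    return total
```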

Scheduling and Framework Implementation

Longest-Processing-Time-First Scheduling

Work tiles for attention (especially with causal masking or variable sequence lengths) exhibit load imbalance. By sorting and processing tiles in LPT order (with batch as the outer dimension and head swizzling to minimize L2 cache thrashing), FlashAttention-4 achieves improved performance—empirical benchmarks show 4–8% FLOPS gains for MHA and up to 14% for MQA.
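A minimal version of the ordering looks like this (the per-tile cost estimate is a stand-in; the kernel derives it from tile shape, masking, and sequence length):

```python
def lpt_order(tile_costs):
    """Return tile indices longest-processing-time-first, so the largest
    tiles (e.g. the heaviest blocks of a causal layout) launch early and
    short tiles fill the tail instead of straggling behind it."""
    return sorted(range(len(tile_costs)), key=lambda i: -tile_costs[i])
```

For estimated costs [3, 10, 1, 7] this yields launch order [1, 3, 0, 2].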

CuTe-DSL: Python-Based Kernel Assembly

The entire kernel suite is written in CuTe-DSL, providing CUTLASS-equivalent expressivity in Python and yielding 20–30× faster compile times than C++ templates. Primitive/flexible abstractions support rapid prototyping (e.g., block-sparse variants, FlexAttention), orthogonal features (masking, varlen, scheduling), and composable optimizations, lowering the entry barrier for GPU attention research.

Empirical Results

FlashAttention-4 achieves up to 1.3× speedup over cuDNN 9.13 and 2.7× over Triton for BF16 workloads on B200, reaching 1613 TFLOPs/s (71% of theoretical peak). The pipeline and scheduling improvements result in robust performance across sequence lengths, head dimensions, and masking types. Deterministic backward passes retain up to 75% of the speed of nondeterministic variants.

Implications and Future Directions

FlashAttention-4 exemplifies the necessity of algorithm-hardware co-design in response to rapidly evolving accelerators. As future GPU generations further increase MMA throughput relative to memory and non-matmul units, attention algorithms must integrate pipelined scheduling, operator emulation, and conditional computation. The Pythonic kernel framework foreshadows democratization of GPU programming and faster iteration cycles for algorithmic innovation.

Beyond Blackwell, analogous pipelining/memory optimization concepts may transfer to TPUs, IPUs, ASICs, and future neural accelerators. The trend toward hardware-aware algorithm adaptation will accelerate as device heterogeneity rises, and will influence all stages of AI model training and inference, especially in the context of long-sequence modeling and multimodal systems.

Conclusion

FlashAttention-4 provides robust, hardware-scaled attention kernels by co-designing algorithmic and kernel strategies with direct reference to hardware bottlenecks and constraints. The method leverages asynchronous computation, memory-aware scheduling, operator emulation, and conditional processing, yielding significant speedups and utilization on Blackwell GPUs. The modular, Python-based framework empowers rapid research and application deployment. These contributions will inform future attention kernel development as hardware asymmetry intensifies and deep learning workloads continue to scale (2603.05451).

Whiteboard

Explain it Like I'm 14

FlashAttention‑4: A simple explanation

What is this paper about?

This paper is about making the “attention” part of Transformer models (the kind used in ChatGPT and many AI apps) run much faster on the newest NVIDIA Blackwell GPUs (like the B200 and GB200). The authors introduce FlashAttention‑4, a set of tricks and a new software design that better fits how these new chips work, so AI models can handle longer inputs and run more efficiently.

What questions were the researchers trying to answer?

They focused on three big questions:

  • How do we speed up attention on new GPUs where some parts got much faster (matrix math) but others didn’t (memory and special functions like exp)?
  • How can we rearrange the work so the fast parts and slow parts run in parallel and don’t wait on each other?
  • Can we make it easier and quicker for developers to build and compile these fast kernels without diving deep into complex C++ code?

How did they do it? (Methods in everyday language)

Think of a GPU like a busy factory:

  • Tensor cores are super‑fast machines that do big matrix multiplications (like multiplying huge number grids).
  • Memory is like conveyor belts and shelves that bring parts to the machines.
  • The exp function (used in softmax) is like a special tool that’s slower and limited in number.

On Blackwell GPUs, the machines (tensor cores) got much faster, but the belts and shelves (memory) and the special tool (exp) didn’t speed up as much. That creates a traffic jam unless you redesign the factory’s assembly line.

Here are the main tricks they used:

  • Overlapping tasks with a new “pipeline”: They organized the work so that while one part is doing matrix math, another part is doing softmax, and another is moving data. Like two teams working on different halves of the same job so no one stands idle.
  • Using “tensor memory” (TMEM): Blackwell chips have a new on‑chip storage area right next to the tensor cores. It’s like adding small shelves beside each machine so parts don’t have to be passed around as much. This reduces traffic and lets them use larger working chunks (“tiles”), which keeps the fast machines busier.
  • Faster softmax by “emulating exp”: The softmax step needs lots of exp calculations, and the exp units are slow. The authors approximate exp using regular math units with short polynomial formulas—think of it as a quick shortcut that’s very close to the real answer. They split the work: some exps use hardware, some use the shortcut, so they don’t overload the slow tool. Tests show the shortcut is accurate enough for BF16 precision used in these models.
  • “Conditional rescaling” in softmax: Normally, softmax rescales intermediate results every time it sees a bigger number, which costs extra work. The authors only rescale when it really matters (when the increase is big enough). This avoids many unnecessary steps while keeping the final answer correct.
  • Teaming up two thread blocks (2‑CTA mode): Blackwell lets two groups of threads act like a single bigger team for a matrix multiply. Each team loads half the data, which cuts shared‑memory traffic and reduces the number of “atomic adds” (a slow way for many threads to safely update the same value). This speeds up the backward pass (the training step where gradients are computed).
  • Better scheduling: They process the longest or heaviest tasks first (a classic scheduling trick) and smartly order batches/heads to keep caches warm. This reduces waiting and improves overall throughput.
  • Deterministic mode: They provide a reproducible option for training (important for debugging and RL), by controlling the order of reductions even if it’s a bit slower.
  • Built in CuTe‑DSL (Python): Instead of writing everything in heavy C++ templates, they use a Python‑embedded DSL that still produces low‑level GPU code. This makes compiling much faster (seconds instead of nearly a minute) and easier for developers to experiment with.

What did they find, and why is it important?

They measured performance on NVIDIA B200 GPUs and found:

  • Up to 1.3× faster than cuDNN 9.13 and up to 2.7× faster than a Triton implementation (BF16).
  • Up to 1613 TFLOPs/s achieved, which is about 71% of the chip’s theoretical peak—very high utilization for real workloads.
  • Compiles 20–30× faster than prior C++ template approaches, making iteration and development much quicker.
  • Particularly strong results for long sequences, where attention is often the main bottleneck.

Why it matters:

  • Faster attention means faster training and inference for LLMs and other Transformer‑based systems.
  • Better use of the newest GPUs lowers costs and energy for the same amount of work.
  • Handling longer contexts more efficiently enables features like reading long documents or large codebases.

What’s the bigger impact?

  • Practical speedups on modern datacenter GPUs: The work targets the hardware most AI labs deploy, not just consumer cards. This makes it immediately useful for large‑scale AI systems.
  • Future‑ready design: By co‑designing the algorithm with the new hardware features (like TMEM and 2‑CTA MMA), the approach adapts to “asymmetric scaling,” where some parts of the chip get faster than others. This idea will likely matter even more in future GPU generations.
  • Easier innovation: Writing kernels in a Python DSL means researchers can prototype new attention variants faster, lowering barriers for the community.
  • Open source: The code is released for others to use and build on, speeding up progress across AI research and industry.

In short, FlashAttention‑4 reorganizes the attention “assembly line” to fit Blackwell GPUs: it keeps the fast parts busy, relieves pressure on slower parts, and brings big, practical speedups—especially for long‑context models—while making it easier for developers to keep improving.

Knowledge Gaps

Unresolved gaps, limitations, and open questions

Below is a concise, actionable list of what remains uncertain or unexplored in the paper and could guide future research.

  • Breadth of evaluation
    • Missing systematic benchmarks across precisions beyond BF16 (e.g., FP16, FP32, FP8/FP4, INT8/INT4), including accuracy and performance trade-offs for each.
    • Lack of comprehensive workload coverage (e.g., cross-attention vs self-attention, decode vs prefill, short vs very long sequences, small-batch regimes, head dimension d not divisible by 128/256).
    • No end-to-end model training or inference throughput results (tokens/s, time-to-train) demonstrating real-world gains over strong baselines (e.g., PyTorch SDPA/cuDNN, Triton) across diverse LLM sizes and tasks.
    • Missing ablations quantifying the individual contribution of each innovation (asynchronous pipeline, exp emulation, conditional rescaling, TMEM allocation, 2-CTA mode, scheduling) across shapes and masks.
  • Numerical accuracy and stability
    • No analysis of forward/backward numerical error and stability for the polynomial EXP emulation during training (only BF16 forward error over [0,1) is reported; backward and full softmax range analysis is absent).
    • No quantified impact of partial EXP emulation on training convergence, final model quality, and stability for long-context training (e.g., degradation thresholds, failure modes).
    • Conditional softmax rescaling lacks theoretical guarantees and empirical bounds on overflow/underflow risk and accuracy loss as a function of the threshold τ; no systematic guidance on τ selection per precision and workload.
    • Absence of error propagation analysis when combining polynomial EXP emulation with BF16 rounding and softmax normalization, especially for extreme logits and very long sequences.
    • No evaluation of numerical behavior for common training features (e.g., dropout in attention, ALiBi/rotary positional encodings, mixed-precision loss scaling).
  • Backward pass specifics
    • Unclear whether EXP emulation is applied in the backward pass and, if so, its accuracy, register pressure, and performance implications; if not, remaining MUFU bottlenecks in backward are unaddressed.
    • No quantified overheads or contention analysis for DSMEM exchanges in 2-CTA dQ, including latency hiding limits, bank conflicts, and scalability to many concurrent CTAs.
    • No sensitivity study of atomic reduction patterns and their interaction with different masking modes (causal, varlen) and with MQA/GQA, including correctness and performance regressions on corner cases.
  • 2-CTA mode and portability constraints
    • 2-CTA MMA requires fixed CTA pairing and cluster scheduling; the paper does not quantify scheduling fragility, preemption sensitivity, or interactions with multi-tenant/MIG/MPS environments.
    • Lack of fallback strategy and performance characterization when 2-CTA mode cannot be used (e.g., due to resource fragmentation, dynamic shape changes).
    • Portability beyond Blackwell is not demonstrated (e.g., Hopper/H200/B300/GB300), including how kernels adapt to different MMA tile shapes, TMEM sizes, and MUFU throughput (e.g., B300’s doubled MUFU potentially obviates EXP emulation).
  • TMEM and register pressure trade-offs
    • No detailed occupancy study quantifying how TMEM allocations, larger accumulator tiles, and register usage (especially with EXP emulation) affect active warps and overall throughput.
    • No guidance on TMEM allocation policies (e.g., when S/P/dP/dQ should share TMEM vs SMEM, fragmentation risks, dynamic reuse strategies across variable shapes).
    • Missing exploration of head dimensions >128 or not multiples of 128, and the resulting TMEM/SMEM/register trade-offs and pipeline adjustments.
  • Scheduling and load balancing
    • LPT scheduling decisions lack quantitative cache modeling (e.g., L2 hit-rate improvements, cache thrash thresholds) and robust tuning guidelines across architectures and batch/head counts.
    • For varlen workloads, preprocessing/sorting overheads and amortization across iterations are not measured (including dynamic batch composition changes in continuous batching).
    • No analysis of scheduling under cluster constraints required by 2-CTA mode (e.g., degraded LPT effectiveness due to cluster-level placement).
  • Determinism and reproducibility
    • Deterministic mode overhead is not quantified (absolute and relative), nor stress-tested under worst-case contention patterns.
    • Open question whether lock-free deterministic strategies (e.g., segmented reductions, hierarchical reductions via TMEM/SMEM) could reduce overheads.
    • No guarantees or tests for determinism across different driver versions, JIT/PTX compilation, or varying CTA scheduling orders.
  • Auto-tuning and adaptivity
    • The fraction of rows using EXP emulation and the τ threshold for conditional rescaling are manually tuned; no auto-tuner or runtime adaptation strategy is provided to optimize for shape/precision/mask and hardware state.
    • No generalized policy for tile size, pipeline depth, and warpgroup roles that adapts to register/TMEM pressure and workload characteristics at runtime.
    • Missing robustness analysis to runtime variability (e.g., DVFS, thermal throttling) and its impact on choosing emulation fractions and pipeline overlap.
  • Compatibility and integration
    • Integration with major frameworks is mentioned but not evaluated; missing cost of JIT (first-run latency), kernel caching, and multi-kernel precompilation strategies in production settings.
    • Interoperability with existing quantization toolchains (e.g., FP8/INT8 models) and mixed-precision training is untested; unclear if EXP emulation and conditional rescaling maintain calibration/scale consistency.
    • No results on block-sparse, sliding-window, or top-k attention variants despite claiming flexibility; performance/accuracy impacts are unknown.
  • System-level and energy efficiency
    • No measurements of energy efficiency (Joules per token), power draw, or thermal behavior versus baselines; unclear if reduced non-MMA bottlenecks translate to better energy proportionality.
    • Multi-GPU and multi-node scaling is not addressed (e.g., overlap with NCCL collectives, pipeline/tensor parallelism, impact on end-to-end throughput).
  • Methodology transparency
    • Benchmark setup details are incomplete: precise shapes, masks, batch/head configurations, varlen distributions, memory footprints, and measurement methodology (warmup, synchronization, TFLOPs computation) are not fully specified for reproducibility.
    • Lack of statistical reporting (variance/confidence intervals) and sensitivity analyses (e.g., to driver versions, ptxas revisions, CuTe-DSL codegen changes).
  • Theoretical guarantees and proofs
    • No formal analysis bounding the error introduced by conditional rescaling across an entire attention sweep, nor a proof of safety margins per precision.
    • No formal model quantifying roofline overlap limits with fully asynchronous MMAs and TMEM-mediated pipelines, to guide tile and pipeline choices analytically.
  • Future hardware evolution
    • How the design adapts to B300/GB300 (e.g., doubled MUFU throughput) is not analyzed; conditions under which EXP emulation becomes counterproductive are unspecified.
    • Unclear how to retarget the co-designed pipeline to non-NVIDIA accelerators (AMD, TPU, custom ASICs) with different memory hierarchies and MMA semantics.

Practical Applications

Immediate Applications

Below are actionable, sector-linked uses you can deploy now, derived from the paper’s findings and engineering methods. Each item includes potential tools/workflows and key dependencies.

Industry

  • Higher-throughput LLM pretraining and fine-tuning on Blackwell (B200/GB200)
    • What: Swap in FlashAttention-4 (FA-4) kernels to reduce time-to-train and GPU-hours on BF16 workloads; expect up to ~1.3× over cuDNN 9.13 and ~2.7× over Triton (GPU DSL) baselines reported in the paper.
    • Tools/workflows: PyTorch custom ops/extensions; integration in Megatron-LM, DeepSpeed, Hugging Face Transformers, TensorRT-LLM backends; CI benchmarks for model variants (MHA/GQA/MQA).
    • Dependencies/assumptions: Access to Blackwell GPUs; BF16/FP16 training; kernel drop-in compatibility with your stack; correctness/accuracy validation on your data; sufficient SMEM/TMEM configuration.
  • Lower-latency, higher-throughput long-context inference
    • What: Serve 100k+ token contexts for chat, RAG, codebase analysis, and enterprise document QA with lower cost/latency via FA-4’s pipelining and throughput gains.
    • Tools/workflows: vLLM and other serving stacks; batching with LPT varlen scheduling; prefill/decode pipeline tuning; KV-cache management.
    • Dependencies/assumptions: Blackwell deployment; ability to replace attention kernels in the serving path; stable long-context models; sequence-length-aware schedulers.
  • Energy and TCO reductions for AI cloud services
    • What: Cut energy per token/training step by improving utilization (reported up to 71% of peak tensor core FLOPs), translating to lower TCO and improved sustainability metrics.
    • Tools/workflows: Datacenter telemetry (FLOPs/s, power draw); cost dashboards; SLA tuning based on realized throughput.
    • Dependencies/assumptions: Workloads dominated by attention; ability to instrument and attribute savings; consistent long-running jobs.
  • Reproducible RL/optimization pipelines via deterministic backward mode
    • What: Use FA-4’s deterministic mode (semaphore-based reductions) to ensure reproducible gradient updates in RL and sensitive training pipelines.
    • Tools/workflows: Toggle deterministic mode for regression tests and formal evals; seed control; audit logs; config-managed runs.
    • Dependencies/assumptions: Acceptable performance overhead; cluster scheduling that tolerates determinism-induced stalls; familiarity with lock semantics.
  • Throughput gains in GQA/MQA with fewer atomics (2-CTA mode)
    • What: Deploy FA-4’s 2-CTA MMA mode to reduce shared memory traffic and halve dQ global atomic reductions in backward, improving scaling for grouped-query attention.
    • Tools/workflows: Model configs using GQA/MQA; profiler verification of shared-memory/atomic bottlenecks; kernel parameter sweeps.
    • Dependencies/assumptions: Blackwell 2-CTA capability; cluster launch configurations that co-locate CTA pairs; DSMEM availability.
  • Robust mixed prefill/decode serving with LPT scheduling
    • What: Apply longest-processing-time-first (LPT) orderings to reduce tail latency and improve GPU occupancy in mixed varlen batches.
    • Tools/workflows: Preprocessing kernel to sort batches by estimated tile runtime; cache virtual→actual batch index maps; L2-aware head swizzling.
    • Dependencies/assumptions: Access to per-batch length metadata; L2 capacity considerations; scheduler integration.
  • Faster kernel iteration and CI/CD for GPU teams
    • What: Leverage 20–30× faster JIT compile times (CuTe-DSL in Python) to shorten kernel development cycles, enabling per-PR kernel tests and rapid prototyping.
    • Tools/workflows: Python-only kernel repos; per-commit microbenchmarks; parameterized autotuning; PTX escape hatches in CuTe-DSL.
    • Dependencies/assumptions: PTXAS toolchain compatibility; Blackwell toolchain versions; developer familiarity with CuTe-DSL.

Academia

  • Rapid prototyping of attention variants and structured sparsity
    • What: Build FlexAttention or block-sparse attention on FA-4 without modifying the framework core; probe algorithm-hardware co-design ideas easily.
    • Tools/workflows: CuTe-DSL notebooks; open-source FA-4 repo; microbench harnesses; ablations on tile shapes/scheduling.
    • Dependencies/assumptions: Availability of Blackwell GPUs for peak benefits; reproducibility checks under deterministic mode.
  • Teaching and curricula on modern GPU systems for ML
    • What: Use FA-4 as a case study on TMEM, 2-CTA MMA, warp specialization, polynomial exp emulation, and scheduling; create lab assignments on asymmetric scaling.
    • Tools/workflows: Annotated kernels; profiling labs; “feeds-and-speeds” roofline projects.
    • Dependencies/assumptions: GPU access (or simulator for conceptual labs); updated course materials; institutional compute allocations.

Policy

  • Procurement guidance focused on energy efficiency and reproducibility
    • What: Encourage adoption of kernel stacks with deterministic modes and high realized utilization for public-sector/regulated deployments.
    • Tools/workflows: RFP criteria referencing realized TFLOPs/s; deterministic reproducibility checklists; auditability requirements.
    • Dependencies/assumptions: Independent benchmarking; alignment with sustainability mandates; acceptance of modest performance loss in deterministic mode.
  • Open-source-first practices to reduce lock-in risk
    • What: Favor open-source kernels like FA-4 for critical AI infrastructure to ensure transparency and portability across vendors.
    • Tools/workflows: OSS compliance scans; fork-and-verify practices; community benchmarking.
    • Dependencies/assumptions: Ongoing maintenance and community health; compatibility with agency security policies.

Daily Life

  • Faster, cheaper AI assistants and creative tools with longer context
    • What: End users benefit from lower latency and richer context windows in assistants, coding copilots, and document/video analysis.
    • Tools/workflows: Service upgrades by providers; longer-context model releases; cost pass-through to pricing tiers.
    • Dependencies/assumptions: Providers migrate to Blackwell and integrate FA-4; models stable at long context.
  • More reliable “replayable” experiments and A/Bs in AI features
    • What: Deterministic training variants improve reproducibility in feature experimentation pipelines used in consumer apps.
    • Tools/workflows: Versioned training runs; saved seeds/configs; reproducibility gates before rollout.
    • Dependencies/assumptions: Teams accept throughput trade-offs for critical evaluation stages.

Long-Term Applications

The following opportunities likely require further research, scaling, cross-stack integration, or future hardware availability.

Industry

  • Cross-architecture portability of asymmetric-scaling-aware attention
    • What: Generalize FA-4 design patterns (TMEM-aware tiling, 2-CTA decomposition, pipelined softmax) to upcoming B300/GB300 and to other vendors (e.g., AMD MI-series) with differing memory hierarchies.
    • Tools/workflows: Hardware abstraction layers; autotuners for tile sizes/modes; per-arch exp-emulation fractions.
    • Dependencies/assumptions: Equivalent features (tensor memory, async MMAs, CTA-pair modes) or viable emulations; vendor cooperation.
  • Compiler-level integration and autotuning in mainstream frameworks
    • What: Incorporate FA-4 schedules into PyTorch Inductor, TVM, XLA; auto-select tile sizes, exp-emu ratios, and scheduling (LPT/SPT) by workload characteristics.
    • Tools/workflows: Cost models using roofline predictors; ahead-of-time candidate generation and runtime selection.
    • Dependencies/assumptions: Compiler backends expose TMEM/2-CTA; stable kernel IRs; profiling hooks.
  • Approximate math micro-libraries beyond exp for non-MMA bottlenecks
    • What: Extend polynomial/FMA-based emulation to log, GELU/SiLU, and softmax variants to relieve MUFU pressure in other layers (e.g., diffusion, recsys).
    • Tools/workflows: Error-controlled approximations by dtype; mixed MUFU/FMA allocation; safe defaults per model.
    • Dependencies/assumptions: Tight accuracy budgets; careful register pressure management; end-to-end validation.
  • Datacenter schedulers co-designed with kernel-level LPT/SPT
    • What: Elevate LPT/SPT heuristics to cluster schedulers for co-scheduling batch shapes and heads across GPUs to improve cache locality and utilization.
    • Tools/workflows: Job-level metadata; cache-aware placement; multi-tenant fairness policies integrating kernel runtime predictions.
    • Dependencies/assumptions: Cross-layer telemetry; scheduler extensibility; acceptance of more complex placement logic.
  • Generalized reduction/atomics minimization patterns
    • What: Apply 2-CTA-style reductions and DSMEM exchanges to other reduction-heavy kernels (e.g., MoE gating, attention variants, gradient accumulation).
    • Tools/workflows: Library of reduction templates; CTA-pair-aware launchers; reproducibility toggles.
    • Dependencies/assumptions: Hardware support for cross-CTA memory; memory ordering guarantees; maintainable complexity.
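The 2-CTA-style reduction pattern can be illustrated abstractly: instead of two workers both issuing atomic adds over the full output, each worker keeps ownership of one half, receives the partner's partials for that half (standing in for a DSMEM exchange), and performs one deterministic add before a plain store. A toy sketch, with Python lists standing in for tiles:

```python
def paired_reduction(partials_a, partials_b):
    """Sketch of a 2-CTA-style reduction between two paired workers.

    Each worker holds partial sums over the full output. Worker A keeps
    the first half, worker B the second; each receives the other's
    partials for its owned slice and does a single in-register add,
    replacing two sets of atomic adds with one ordinary store each.
    """
    n = len(partials_a)
    half = n // 2
    # Worker A owns [0, half): folds in B's contribution for that slice.
    out_a = [a + b for a, b in zip(partials_a[:half], partials_b[:half])]
    # Worker B owns [half, n): folds in A's contribution for that slice.
    out_b = [b + a for b, a in zip(partials_b[half:], partials_a[half:])]
    return out_a + out_b  # fixed combination order => deterministic
```

Because the combination order is fixed, the result is bitwise reproducible, which is exactly the property the "reproducibility toggles" bullet above depends on.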
  • Edge and consumer GPU spillover
    • What: Adapt FA-4 ideas to consumer Blackwell and future mobile/edge NPUs with smaller on-chip memories and different interconnects.
    • Tools/workflows: Tighter tiling; mixed-precision softmax strategies; memory-aware context partitioning.
    • Dependencies/assumptions: Hardware feature parity or substitutes; on-device memory budgets; battery/thermal constraints.

Academia

  • Algorithm–hardware co-design beyond attention
    • What: Apply the paper’s roofline-driven design to LayerNorm, RMSNorm, optimizers, KV-cache transforms, and MoE dispatch to alleviate non-MMA bottlenecks.
    • Tools/workflows: Unified bottleneck analyzers; prototype libraries in CuTe-DSL; empirical scaling studies.
    • Dependencies/assumptions: Access to evolving hardware; community benchmarks.
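The roofline-driven analysis named above reduces to a small calculation: a kernel is bound by whichever of its compute time and memory time is larger, and the arithmetic intensity (FLOPs per byte) relative to the machine's ridge point picks the regime. A minimal sketch; the peak numbers in the usage example are illustrative placeholders, not measured figures:

```python
def roofline_time(flops, bytes_moved, peak_flops, peak_bw):
    """Minimal roofline estimate.

    t_compute = flops / peak_flops, t_memory = bytes / peak_bw;
    the kernel takes max(t_compute, t_memory), and the intensity
    flops/bytes versus the ridge point peak_flops/peak_bw tells
    you which ceiling you are under.
    """
    t_compute = flops / peak_flops
    t_memory = bytes_moved / peak_bw
    intensity = flops / bytes_moved
    ridge = peak_flops / peak_bw   # intensity where the ceilings meet
    bound = "compute" if intensity >= ridge else "memory"
    return max(t_compute, t_memory), bound

# Illustrative use with B200-like BF16 peak (2.25 PFLOPS, per the summary)
# and an assumed ~8 TB/s HBM bandwidth:
t, bound = roofline_time(1e12, 1e9, 2.25e15, 8.0e12)
```

The same analyzer applied per functional unit (MUFU ops/s, SMEM bytes/s, not just HBM) is what surfaces the asymmetric-scaling bottlenecks the paper targets.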
  • Automated kernel synthesis with learned schedulers
    • What: Train ML models to predict optimal pipelining/tiling/exp-emu fractions given hardware counters and problem shapes (closed-loop autotuning).
    • Tools/workflows: Data collection across shapes/dtypes; reinforcement/autotuning loops; explainable schedule selection.
    • Dependencies/assumptions: Sufficient training data; stable APIs to emit kernels; guardrails to avoid pathological schedules.
  • Standard benchmarks for asymmetric-scaling research
    • What: Establish community suites that stress MUFU/SMEM vs MMA to catalyze co-design research and fair comparisons.
    • Tools/workflows: Open benchmark kernels; reference configs; reproducibility kits with deterministic modes.
    • Dependencies/assumptions: Broad participation; funding for continuous benchmarking.

Policy

  • Reproducibility and auditability standards for AI training
    • What: Encourage “deterministic-capable” kernels for critical domains (healthcare, finance, public sector) with published performance/variance trade-offs.
    • Tools/workflows: Certification processes; determinism reporting in model cards; audit trails for training runs.
    • Dependencies/assumptions: Sector-specific regulations; acceptance of overhead in regulated contexts.
  • Energy/transparency reporting for AI infrastructure
    • What: Adopt metrics like realized TFLOPs/s, SMEM/MUFU utilization, and energy per token to guide incentives and carbon accounting.
    • Tools/workflows: Standardized telemetry; third-party verification; policy-aligned disclosures.
    • Dependencies/assumptions: Agreement on metrics; data-sharing frameworks.

Daily Life

  • Richer, persistent-context AI experiences
    • What: As long-context models become affordable to serve, personal assistants retain more history, supporting multi-document reasoning, long-term projects, and complex video/code workflows.
    • Tools/workflows: Session memory services; privacy-preserving context stores; cost-adaptive context windows.
    • Dependencies/assumptions: Provider migration to efficient attention; user consent and data governance.
  • Broader access via cost deflation
    • What: Savings from higher utilization can translate into lower subscription tiers or more generous free quotas for AI features in productivity apps and developer tools.
    • Tools/workflows: Cost-to-price pass-through policies; usage-based caps; QoS adjustments by model size/context.
    • Dependencies/assumptions: Market competition; sustained efficiency gains at scale.

Glossary

  • 2-CTA MMA mode: A Blackwell tensor core mode where a pair of CTAs cooperatively executes a single MMA, sharing tensor memory and splitting operands to cut shared-memory traffic. "leveraging tensor memory and the 2-CTA MMA mode to reduce shared memory traffic and atomic adds in the backward pass."
  • acquire-release semantics: Memory-ordering guarantees ensuring that operations before a release are visible to threads after an acquire on the same synchronization object. "required for correct acquire-release semantics"
  • atomic add: A hardware-supported, indivisible addition to a memory location used for parallel reductions, which can introduce contention and nondeterminism. "reduce shared memory traffic and atomic adds in the backward pass."
  • BF16: Brain floating point format with 16 bits (8-bit exponent, 7-bit mantissa), used for high-throughput, lower-precision compute. "on B200 GPUs with BF16, reaching up to 1613 TFLOPs/s (71% utilization)."
  • Blackwell: NVIDIA GPU architecture succeeding Hopper, with asymmetric scaling and new features like TMEM and larger MMA tiles. "Blackwell introduces a new memory level called tensor memory (TMEM), a 256 KB on-chip memory per SM specifically designed for storing intermediate results of tensor core operations."
  • causal masking: A masking strategy that prevents attention to future tokens by zeroing out scores above the diagonal. "For causal masking, we additionally launch KV blocks in descending order, traverse query blocks in ascending order starting from the diagonal, and order the dQ reductions by descending query block index."
  • Cody-Waite range reduction: A classical technique to reduce argument range before polynomial approximation of transcendental functions. "We use the classical range reduction technique (Cody-Waite)"
  • cooperative thread array (CTA): A threadblock; a group of threads that run concurrently on an SM and cooperate via shared memory and synchronization. "threadblocks (i.e. cooperative thread arrays or CTAs)"
  • cuDNN: NVIDIA’s CUDA Deep Neural Network library providing optimized kernels for DL primitives. "achieves up to 1.3× speedup over cuDNN 9.13"
  • CuTe-DSL: A Python-embedded domain-specific language that lowers to PTX for GPU kernels, enabling fast JIT compilation and CUTLASS-level expressivity. "Beyond algorithmic innovations, we implement FlashAttention-4 entirely in CuTe-DSL embedded in Python"
  • CUTLASS: NVIDIA’s C++ template library for high-performance GEMM and tensor operations on GPUs. "The CuTe-DSL programming model is isomorphic to CUTLASS C++"
  • deterministic execution mode: An execution setting that enforces a fixed order of operations (e.g., reductions) to ensure reproducible results. "We also implement a deterministic execution mode with minimal performance overhead, enabling reproducible training for reinforcement learning applications."
  • distributed shared memory (DSMEM): Shared memory accessible across CTAs within a cluster, enabling inter-CTA data exchange. "we use distributed shared memory (DSMEM) to exchange half of the dS"
  • exponential unit: Specialized hardware (MUFU) for transcendental ops like exp; a bottleneck relative to tensor core throughput. "The exponential unit computes elementwise operations required for the softmax computation."
  • FMA (fused multiply–add): An operation that performs a multiply and add in one instruction with a single rounding, used for fast polynomial evaluation. "we implement a software emulation of 2^x using floating-point FMA units"
  • FP8: 8-bit floating point format used for low-precision, high-throughput compute and memory efficiency. "incorporating FP8 support."
  • GEMM: General Matrix Multiply; the core dense linear algebra operation underlying attention matmuls. "Even with improved pipelining and with two of the ten GEMM operands resident in tensor memory, shared memory bandwidth still dominates the backward pass."
  • GPC (Graphics Processing Cluster): A hardware grouping of SMs; CTAs in a cluster are co-scheduled on the same GPC. "CTAs in the same cluster are co-scheduled on the same GPC."
  • GMEM (global memory): Off-chip device memory (HBM) accessible by all SMs; highest capacity, lowest on-chip bandwidth. "Global memory (GMEM), also known as HBM, is the off-chip DRAM that is accessible to all streaming multiprocessors (SMs)."
  • GQA (Grouped Query Attention): An attention variant that shares keys/values across groups of query heads to reduce memory/computation. "and dK/dV in the case of GQA"
  • HBM: High Bandwidth Memory; off-chip DRAM with very high bandwidth used as GPU global memory. "Global memory (GMEM), also known as HBM, is the off-chip DRAM"
  • Hopper: NVIDIA GPU architecture preceding Blackwell; many FA-3 optimizations target Hopper H100. "Blackwell B200 doubles the tensor core throughput compared to Hopper H100 (2.25 PFLOPS vs. 1 PFLOPS for FP16/BF16)"
  • Horner's method: An efficient scheme for evaluating polynomials that minimizes multiplications, ideal for FMA pipelines. "The polynomial evaluation uses Horner's method with FMA instructions"
  • JIT compilation: Just-in-time compilation that compiles code at runtime, speeding iteration and reducing build times. "By embedding CuTe-DSL in Python with just-in-time (JIT) compilation, FlashAttention-4 achieves faster build times"
  • L2 cache: On-chip cache between SMs and HBM that transparently caches global memory lines. "Data from GMEM are transparently cached in an on-chip L2 cache."
  • longest-processing-time-first (LPT) scheduling: A heuristic scheduling policy that assigns longest jobs first to balance load and minimize makespan. "we use the classical idea of longest-processing-time-first (LPT) scheduling"
  • matrix multiply-accumulate (MMA): Tensor core instruction that performs matrix multiplications with accumulation into an output tile. "Each MMA tensor core instruction processes 128 × N tiles (typically N = 128 or 256)"
  • MUFU: Multi-Function Unit; hardware unit implementing transcendental functions like exp/log with limited throughput. "The multifunction unit (MUFU) on B200 and GB200 can perform 16 ops / clock / SM"
  • MUFU.EX2: The hardware exponential base-2 instruction executed on MUFU. "with the remaining entries computed via hardware MUFU.EX2."
  • Multi-Head Attention (MHA): Attention mechanism with multiple parallel heads to capture diverse relations. "for BF16 and head dimension 128 we obtain 4-8% FLOPS gain for MHA"
  • Multi-Query Attention (MQA): Attention variant sharing keys/values across many query heads to save memory and compute. "for MQA or GQA"
  • occupancy (GPU): The degree to which SM resources are utilized by active warps/CTAs; higher occupancy can hide latency. "improving GPU occupancy."
  • online softmax: A blockwise softmax algorithm that maintains running statistics (max, sum) to ensure stability without full materialization. "FlashAttention online softmax."
  • polynomial approximation: Approximating functions (e.g., exp) by polynomials over a reduced range for higher throughput on FMA units. "and then the polynomial approximation~\citep{muller2018handbook}."
  • PTX: NVIDIA’s virtual GPU ISA to which DSLs/compilers lower before producing device-specific SASS. "lowers to PTX, then uses the PTX compiler (ptxas) to finally produce the assembly code (SASS)."
  • quantization: Reducing numeric precision (e.g., INT8/FP8/FP4) to accelerate compute and reduce memory bandwidth. "achieves speedups through INT8 quantization"
  • register pressure: High demand for registers that can constrain kernel scheduling and cause spills to slower memory. "This alleviates the extreme register pressure that plagued Hopper kernels"
  • register spilling: When insufficient registers force temporary values to be stored in slower memory, hurting performance. "preventing register spills is critical."
  • roofline analysis: A performance model relating compute and memory ceilings to identify bottlenecks. "We first do a roofline analysis"
  • SASS: NVIDIA’s low-level GPU assembly produced by ptxas from PTX. "to finally produce the assembly code (SASS)."
  • semaphore lock: A synchronization primitive used to serialize access (e.g., reductions) across CTAs for determinism. "serialize the global reductions using a semaphore lock."
  • shared memory (SMEM): Programmer-managed on-chip scratchpad memory within an SM with high bandwidth and banked access. "called shared memory (SMEM) on the chip."
  • shortest-processing-time-first (SPT) schedule: A scheduling policy prioritizing the shortest tasks first to reduce waiting; used for ordering deterministic reductions. "This “shortest-processing-time-first” (SPT) schedule ensures that no CTA is stalled on its first dQ write."
  • softmax rescaling: Renormalization step in online softmax when the running maximum increases; can be conditionally skipped. "We also introduce conditional softmax rescaling that skips unnecessary rescaling operations."
  • Sollya: A tool for rigorous polynomial approximation and floating-point error analysis. "calculated using the Sollya software package"
  • tensor cores: Specialized GPU units for high-throughput matrix operations (MMA) with mixed precision. "Blackwell features fifth-generation tensor cores"
  • tensor memory (TMEM): Blackwell’s on-SM memory for tensor core accumulators and intermediate results, enabling fully asynchronous MMAs. "Blackwell introduces a new memory level called tensor memory (TMEM)"
  • Tensor Memory Accelerator (TMA): Hardware path/instruction set for high-throughput tensor memory transfers used in pipelined kernels. "only the TMA load running significantly out of turn."
  • threadblock cluster: A grouping of CTAs co-scheduled to enable features like DSMEM and 2-CTA MMAs. "threadblock clusters, and grids."
  • Triton: A GPU programming system for writing high-performance kernels in Python. "2.7× over Triton"
  • varlen (variable sequence length): Handling batches with heterogeneous sequence lengths that induce load imbalance. "In many situations, such as with causal masking or variable sequence length (varlen), the attention kernel is naturally load-imbalanced"
  • warp: A group of 32 threads that execute in lockstep on an SM. "warps (32 threads)"
  • warp divergence: When threads in a warp take different control paths, reducing parallel efficiency. "to avoid warp divergence"
  • warp specialization: Dividing warps into roles (e.g., producers/consumers) to increase overlap and pipeline efficiency. "Hardware support for asynchrony allows for warp-specialized kernels"
  • warp-synchronous: Operations or memory semantics that are synchronized at the warp level rather than CTA-wide. "TMEM is warp-synchronous and tightly coupled with the tensor cores"
  • warpgroup: A group of four contiguous warps used as a scheduling/cooperation unit. "warpgroups (4 contiguous warps)"
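Several glossary entries above (online softmax, softmax rescaling, conditional rescaling) describe one algorithm, which is compact enough to sketch. This is a scalar toy, not the paper's kernel: it processes score/value blocks while carrying a running max and denominator, and only rescales prior state when a new block raises the running max — the conditional rescaling the paper uses to skip unnecessary work.

```python
import math

def online_softmax_weighted_sum(score_blocks, value_blocks):
    """Blockwise (online) softmax over a sequence of blocks.

    Maintains a running max m, running denominator l, and running
    weighted accumulator acc. Previous state is rescaled by
    exp(m_old - m_new) only when a block raises the max (conditional
    rescaling). Returns sum_i softmax(scores)_i * values_i without
    ever materializing the full softmax.
    """
    m = float("-inf")  # running max
    l = 0.0            # running sum of exp(score - m)
    acc = 0.0          # running weighted sum of values
    for scores, values in zip(score_blocks, value_blocks):
        block_max = max(scores)
        if block_max > m:  # conditional rescaling: skip when max unchanged
            scale = math.exp(m - block_max) if m != float("-inf") else 0.0
            l *= scale
            acc *= scale
            m = block_max
        for s, v in zip(scores, values):
            p = math.exp(s - m)
            l += p
            acc += p * v
    return acc / l
```

In the real kernel the scalars are tiles, the rescale is a vector multiply, and skipping it saves both FMA and MUFU work; the numerics, however, are exactly this recurrence.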

Open Problems

We found no open problems mentioned in this paper.
