FlashAttention-4: Algorithm and Kernel Pipelining Co-Design for Asymmetric Hardware Scaling
作者: Ted Zadouri, Markus Hoehnerbach, Jay Shah, Timmy Liu, Vijay Thakkar, Tri Dao
分类: cs.CL
发布日期: 2026-03-05
💡 一句话要点
FlashAttention-4:面向非对称硬件扩展的算法与Kernel流水线协同设计
🎯 匹配领域: 支柱九:具身大模型 (Embodied Foundation Models)
关键词: FlashAttention Transformer Blackwell GPU 算法优化 Kernel优化 CuTe-DSL 高性能计算 深度学习
📋 核心要点
- 现有Attention机制在Blackwell GPU等非对称硬件上存在瓶颈,原因是tensor core吞吐量提升与其他单元不匹配。
- FlashAttention-4通过重新设计的流水线、软件模拟指数运算和利用tensor memory等技术,优化了Blackwell GPU上的Attention计算。
- 实验表明,FlashAttention-4在B200 GPU上相比cuDNN和Triton有显著加速,并实现了更高的硬件利用率和更快的编译速度。
📝 摘要(中文)
Attention机制是Transformer架构的核心,也是大型语言模型和长上下文应用中的瓶颈。FlashAttention-3通过异步执行和warp specialization优化了Hopper GPU上的attention,但主要针对H100架构。AI行业已迅速转向基于Blackwell的系统,如B200和GB200,由于非对称硬件扩展,这些系统表现出根本不同的性能特征:tensor core吞吐量翻倍,而其他功能单元(共享内存带宽、指数单元)扩展较慢或保持不变。我们开发了几种技术来解决Blackwell GPU上这些变化的瓶颈:(1)重新设计的流水线,利用完全异步的MMA操作和更大的tile size;(2)软件模拟的指数和条件softmax重新缩放,减少了非matmul操作;(3)利用tensor memory和2-CTA MMA模式,减少了反向传播中的共享内存流量和原子加法。我们的方法FlashAttention-4在B200 GPU上使用BF16时,比cuDNN 9.13快1.3倍,比Triton快2.7倍,达到高达1613 TFLOPs/s(71%利用率)。除了算法创新,我们完全在Python中嵌入的CuTe-DSL中实现了FlashAttention-4,与传统的基于C++模板的方法相比,编译时间快20-30倍,同时保持了完全的表达能力。
🔬 方法详解
问题定义:论文旨在解决在NVIDIA Blackwell架构(如B200)上,由于tensor core吞吐量大幅提升,而共享内存带宽和指数运算单元等资源增长缓慢或停滞,导致现有Attention机制性能受限的问题。现有方法,如FlashAttention-3,针对H100架构优化,无法充分利用Blackwell架构的优势。
核心思路:FlashAttention-4的核心思路是通过算法和kernel的协同设计,充分利用Blackwell架构的tensor core,同时减少对共享内存带宽和指数运算单元的依赖。具体来说,通过重新设计的流水线实现完全异步的MMA操作,并采用软件模拟的方式减少非matmul操作的开销。
技术框架:FlashAttention-4的整体框架包括以下几个关键模块:(1) 重新设计的流水线,优化数据流和计算顺序,实现完全异步的MMA操作;(2) 软件模拟的指数和条件softmax重新缩放,减少对硬件指数运算单元的依赖;(3) 利用tensor memory和2-CTA MMA模式,减少共享内存的访问和原子加法操作。
关键创新:FlashAttention-4的关键创新在于针对Blackwell架构的非对称硬件扩展特性,提出了算法和kernel的协同优化策略。与FlashAttention-3等现有方法相比,FlashAttention-4更加注重减少对共享内存带宽和指数运算单元的依赖,从而更好地利用Blackwell架构的tensor core。此外,使用CuTe-DSL实现了快速编译。
关键设计:FlashAttention-4的关键设计包括:(1) 采用更大的tile size,以充分利用tensor core的计算能力;(2) 使用软件模拟的指数运算,避免硬件指数运算单元的瓶颈;(3) 利用tensor memory和2-CTA MMA模式,减少共享内存的访问和原子加法操作;(4) 使用CuTe-DSL进行kernel开发,实现快速编译和灵活优化。
🖼️ 关键图片
📊 实验亮点
FlashAttention-4在B200 GPU上使用BF16时,相比cuDNN 9.13实现了高达1.3倍的加速,相比Triton实现了高达2.7倍的加速。在B200 GPU上达到了高达1613 TFLOPs/s的计算吞吐量,硬件利用率达到71%。此外,FlashAttention-4使用CuTe-DSL实现,编译时间比传统的C++模板方法快20-30倍。
🎯 应用场景
FlashAttention-4的潜在应用领域包括大型语言模型、长文本处理、图像识别、语音识别等。通过提高Attention机制的计算效率,FlashAttention-4可以加速这些应用的训练和推理过程,降低计算成本,并支持更大规模的模型和更长的上下文处理。该研究的实际价值在于提升AI模型的性能和效率,推动AI技术在各个领域的应用。未来,FlashAttention-4的优化思路可以推广到其他硬件平台和计算密集型任务中。
📄 摘要(原文)
Attention, as a core layer of the ubiquitous Transformer architecture, is the bottleneck for large language models and long-context applications. While FlashAttention-3 optimized attention for Hopper GPUs through asynchronous execution and warp specialization, it primarily targets the H100 architecture. The AI industry has rapidly transitioned to deploying Blackwell-based systems such as the B200 and GB200, which exhibit fundamentally different performance characteristics due to asymmetric hardware scaling: tensor core throughput doubles while other functional units (shared memory bandwidth, exponential units) scale more slowly or remain unchanged. We develop several techniques to address these shifting bottlenecks on Blackwell GPUs: (1) redesigned pipelines that exploit fully asynchronous MMA operations and larger tile sizes, (2) software-emulated exponential and conditional softmax rescaling that reduces non-matmul operations, and (3) leveraging tensor memory and the 2-CTA MMA mode to reduce shared memory traffic and atomic adds in the backward pass. We demonstrate that our method, FlashAttention-4, achieves up to 1.3$\times$ speedup over cuDNN 9.13 and 2.7$\times$ over Triton on B200 GPUs with BF16, reaching up to 1613 TFLOPs/s (71% utilization). Beyond algorithmic innovations, we implement FlashAttention-4 entirely in CuTe-DSL embedded in Python, achieving 20-30$\times$ faster compile times compared to traditional C++ template-based approaches while maintaining full expressivity.