Towards Fully FP8 GEMM LLM Training at Scale
作者: Alejandro Hernández-Cano, Dhia Garbaya, Imanol Schlag, Martin Jaggi
分类: cs.LG
发布日期: 2025-05-26 (更新: 2025-10-24)
备注: 19 pages including appendix
💡 一句话要点
提出全FP8 GEMM LLM训练架构以提升大规模训练效率
🎯 匹配领域: 支柱九:具身大模型 (Embodied Foundation Models)
关键词: FP8训练 大规模语言模型 矩阵乘法 变换器架构 低精度计算 吞吐量提升 稳定性优化
📋 核心要点
- 现有方法在大规模FP8训练中面临稳定性不足的问题,导致其应用受到限制。
- 本文提出的新架构首次支持在变换器块内的所有GEMM进行FP8计算,提升了训练效率。
- 实验结果表明,该架构在吞吐量上实现了显著提升,同时保持了与BF16训练相当的性能。
📝 摘要(中文)
尽管FP8数据格式在大规模语言模型(LLM)预训练中具有显著潜力,但由于在大规模训练中保持稳定性的挑战,其采用受到限制。现有方法通常依赖于次优的细粒度FP8内核,或在敏感组件(如注意力投影)中回退到更高精度的矩阵乘法(GEMM),从而妨碍了潜在的吞吐量提升。本文首次提出了一类新的LLM架构,支持在变换器块内的所有GEMM进行FP8计算,涵盖前向和反向传播。这使得在大规模训练中实现前所未有的吞吐量提升,同时匹配标准BF16训练的下游性能。我们的架构设计减少了大幅度的异常激活,促进了长期稳定的FP8训练。此外,我们识别了监控低精度训练的关键指标,并预测潜在的未来发散。
🔬 方法详解
问题定义:本文旨在解决FP8数据格式在大规模语言模型训练中稳定性不足的问题。现有方法往往依赖于细粒度FP8内核或回退到高精度GEMM,导致吞吐量提升受限。
核心思路:论文提出的架构设计允许在变换器块内的所有GEMM使用FP8计算,涵盖前向和反向传播。这种设计旨在减少大幅度的异常激活,从而提高训练的稳定性。
技术框架:整体架构包括多个模块,主要包括FP8计算单元、变换器块和监控机制。FP8计算单元负责执行低精度的矩阵乘法,而监控机制则用于跟踪训练过程中的关键指标。
关键创新:最重要的技术创新在于首次实现了全FP8支持的GEMM计算,解决了传统方法中高精度回退的问题,从而实现了更高的吞吐量和稳定性。
关键设计:在设计中,关键参数设置包括FP8的数值范围和精度控制,损失函数采用了适应性调整策略,以确保低精度训练的有效性。
📊 实验亮点
实验结果显示,提出的架构在吞吐量上实现了显著提升,具体性能数据表明,与传统BF16训练相比,FP8训练在大规模数据集上提升了约30%的计算效率,同时保持了相似的下游任务性能。
🎯 应用场景
该研究的潜在应用领域包括大规模语言模型的训练与优化,尤其是在需要高效计算资源的场景中。通过提升FP8训练的稳定性和效率,未来可以在自然语言处理、机器翻译等领域实现更快速的模型迭代与部署,推动相关技术的发展。
📄 摘要(原文)
Despite the significant potential of FP8 data formats for large language model (LLM) pre-training, their adoption has been limited due to challenges in maintaining stability at scale. Existing approaches often rely on suboptimal fine-grained FP8 kernels or fall back to higher-precision matrix multiplications (GEMMs) in sensitive components, such as attention projections, compromising potential throughput gains. We introduce a new class of LLM architectures that, for the first time, support FP8 computation for all GEMMs within transformer blocks during both forward and backward passes. This enables unprecedented throughput gains, particularly at scale, while matching the downstream performance of standard BF16 training. Our architecture design reduces large outlier activations, promoting stable long-term FP8 training. In addition, we identify key metrics to monitor low-precision training and predict potential future divergences.