Multi-Head Low-Rank Attention

📄 arXiv: 2603.02188v1 📥 PDF

作者: Songtao Liu, Hongwu Peng, Zhiwei Zhang, Zhengyu Chen, Yue Guo

分类: cs.LG

发布日期: 2026-03-02

备注: Accepted by ICLR 2026

🔗 代码/项目: GITHUB | HUGGINGFACE


💡 一句话要点

提出多头低秩注意力(MLRA),解决大模型长文本推理中KV缓存的张量并行瓶颈。

🎯 匹配领域: 支柱九:具身大模型 (Embodied Foundation Models)

关键词: 多头注意力 低秩分解 长文本推理 张量并行 KV缓存 大语言模型 分布式解码

📋 核心要点

  1. 现有MLA方法在长文本推理中减少KV缓存大小,但在张量并行解码时存在单头无法分割的瓶颈。
  2. MLRA通过引入多头低秩结构,使得潜在状态可以被分割,从而支持高效的张量并行解码。
  3. 实验表明,MLRA在困惑度和下游任务性能上达到SOTA,并实现了比MLA快2.8倍的解码速度。

📝 摘要(中文)

大型语言模型中的长上下文推理受到解码阶段Key-Value(KV)缓存加载的瓶颈限制。生成过程的顺序性要求在每一步重复地将KV缓存从片外高带宽存储器(HBM)传输到片上静态随机存取存储器(SRAM)。虽然多头潜在注意力(MLA)显著减少了KV缓存的总大小,但它在通过张量并行(TP)进行分布式解码时存在分片瓶颈。由于其单个潜在头无法被分割,每个设备被迫为每个token冗余地加载完整的KV缓存,消耗过多的内存流量并降低了TP的优势,例如权重分片。在这项工作中,我们提出了多头低秩注意力(MLRA),它支持可分割的潜在状态,从而实现高效的四路TP解码。大量实验表明,MLRA实现了最先进的困惑度和下游任务性能,同时还提供了比MLA快2.8倍的解码速度。

🔬 方法详解

问题定义:论文旨在解决大型语言模型在长文本推理时,使用张量并行(TP)进行分布式解码时,由于多头潜在注意力(MLA)的单头结构无法分割,导致每个设备都需要冗余加载完整KV缓存,造成内存带宽浪费和TP效率降低的问题。现有MLA方法虽然减少了KV缓存大小,但牺牲了并行性。

核心思路:论文的核心思路是将MLA的单头结构扩展为多头低秩结构。通过引入多个低秩矩阵,将潜在状态分解为多个可独立计算和分割的部分,从而使得KV缓存可以在TP中进行有效的分片,避免冗余加载。

技术框架:MLRA沿用了Transformer的整体架构,主要改进在于注意力机制部分。在计算注意力权重时,首先将query和key投影到多个低秩子空间,然后分别在这些子空间内计算注意力权重。最后,将各个子空间的注意力结果进行融合,得到最终的注意力输出。整体流程与标准注意力机制类似,但增加了低秩分解和多头融合的步骤。

关键创新:MLRA的关键创新在于将低秩分解与多头注意力机制相结合,实现了潜在状态的可分割性。与MLA相比,MLRA通过多头结构支持张量并行,避免了单头结构的瓶颈。与标准多头注意力相比,MLRA通过低秩分解降低了计算复杂度,提高了效率。

关键设计:MLRA的关键设计包括低秩矩阵的维度选择、多头的数量、以及融合各个子空间注意力结果的方式。论文可能采用了可学习的权重来融合不同子空间的注意力结果。此外,损失函数的设计可能也考虑了低秩约束,以保证低秩分解的有效性。具体的参数设置和网络结构细节需要在论文原文中查找。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

MLRA在实验中取得了显著的性能提升。相较于MLA,MLRA实现了2.8倍的解码速度提升,同时在困惑度和下游任务性能上达到了最先进水平。这些结果表明,MLRA在提高长文本推理效率的同时,保持了甚至提升了模型的准确性。

🎯 应用场景

MLRA可应用于需要处理长文本序列的各种自然语言处理任务,例如机器翻译、文本摘要、问答系统和对话生成。尤其是在资源受限的设备上部署大型语言模型时,MLRA能够有效降低内存需求和提高推理速度,具有重要的实际应用价值。未来,MLRA可以进一步扩展到其他模态,例如图像和视频处理。

📄 摘要(原文)

Long-context inference in large language models is bottlenecked by Key--Value (KV) cache loading during the decoding stage, where the sequential nature of generation requires repeatedly transferring the KV cache from off-chip High-Bandwidth Memory (HBM) to on-chip Static Random-Access Memory (SRAM) at each step. While Multi-Head Latent Attention (MLA) significantly reduces the total KV cache size, it suffers from a sharding bottleneck during distributed decoding via Tensor Parallelism (TP). Since its single latent head cannot be partitioned, each device is forced to redundantly load the complete KV cache for every token, consuming excessive memory traffic and diminishing TP benefits like weight sharding. In this work, we propose Multi-Head Low-Rank Attention (MLRA), which enables partitionable latent states for efficient 4-way TP decoding. Extensive experiments show that MLRA achieves state-of-the-art perplexity and downstream task performance, while also delivering a 2.8$\times$ decoding speedup over MLA. Code is available at https://github.com/SongtaoLiu0823/MLRA. Pretrained weights, along with the training and evaluation data, are available at https://huggingface.co/Soughing/MLRA.