LoLA: Low-Rank Linear Attention With Sparse Caching
作者: Luke McDermott, Robert W. Heath, Rahul Parhi
分类: cs.CL, cs.LG
发布日期: 2025-05-29 (更新: 2025-09-30)
💡 一句话要点
提出LoLA以提升线性注意力的关联记忆能力
🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture)
关键词: 线性注意力 关联记忆 终身学习 自然语言处理 模型优化 内存管理 推理效率
📋 核心要点
- 现有的变换器模型在推理时,随着上下文长度的增加,计算成本急剧上升,限制了其在终身学习中的应用。
- LoLA通过将过去的键值对分配到不同的内存系统中,提升了线性注意力的关联记忆能力,且无需额外训练。
- 实验结果显示,LoLA在关键检索任务中准确率从0.6%提升至97.4%,并且在内存使用上显著优于现有模型。
📝 摘要(中文)
随着上下文长度的增加,变换器推理的每个标记成本也随之增加,这限制了其在终身上下文学习中的应用。线性注意力作为一种高效的替代方案,能够在无限上下文长度下保持恒定的内存占用,但在内存容量上存在不足。本文提出了LoLA,这是一种无需训练的线性注意力增强方法,旨在提升关联记忆的回忆能力。LoLA将过去的键值对分配到三个内存系统中:最近的对在局部滑动窗口缓存中,难以记忆的对在稀疏全局缓存中,以及通用的对在线性注意力的递归隐藏状态中。通过消融实验,我们证明了自回忆误差度量在有效管理长期关联记忆中的重要性。在关键检索任务中,LoLA将基础模型的性能从0.6%提升至97.4%的准确率,并且在4K上下文长度下使用的缓存比Llama-3.1 8B小4.6倍。LoLA在零-shot常识推理任务中也优于其他1B和8B参数的亚平方模型。
🔬 方法详解
问题定义:本文旨在解决变换器模型在推理时因上下文长度增加而导致的计算成本上升问题,特别是在终身学习场景中的应用限制。现有的线性注意力虽然在内存占用上表现良好,但在长期记忆容量方面仍显不足。
核心思路:LoLA通过将过去的键值对分配到三个不同的内存系统中,增强了线性注意力的关联记忆能力。具体而言,最近的键值对存储在局部滑动窗口缓存中,难以记忆的对存储在稀疏全局缓存中,而通用的对则保留在线性注意力的递归隐藏状态中。
技术框架:LoLA的整体架构包括三个主要模块:局部滑动窗口缓存、稀疏全局缓存和递归隐藏状态。每个模块负责不同类型的键值对存储和管理,从而提高了模型的记忆能力和推理效率。
关键创新:LoLA的主要创新在于其训练-free的设计,能够在不增加训练成本的情况下,显著提升线性注意力的记忆能力。这种设计使得模型在处理长上下文时,能够有效地管理和回忆重要信息。
关键设计:在LoLA中,缓存的大小和结构经过精心设计,以确保在保持较小内存占用的同时,最大化信息的回忆能力。具体的参数设置和损失函数的选择也经过实验验证,以确保模型在不同任务中的最佳性能。
📊 实验亮点
LoLA在关键检索任务中的表现显著提升,准确率从0.6%跃升至97.4%。此外,其在4K上下文长度下的缓存使用量比Llama-3.1 8B小4.6倍,显示出更高的内存效率。LoLA在零-shot常识推理任务中也超越了其他1B和8B参数的亚平方模型,展现了其优越的性能。
🎯 应用场景
LoLA的研究成果在多个领域具有广泛的应用潜力,尤其是在需要处理长上下文的自然语言处理任务中,如对话系统、文本生成和知识检索等。通过提升模型的记忆能力,LoLA能够更好地支持终身学习和动态知识更新,推动智能系统的进一步发展。
📄 摘要(原文)
The per-token cost of transformer inference scales with context length, preventing its application to lifelong in-context learning. Linear attention is an efficient alternative that maintains a constant memory footprint, even on infinite context lengths. While this is a potential candidate for lifelong learning, it falls short in memory capacity. In this paper, we propose LoLA, a training-free augmentation to linear attention that boosts associative recall. LoLA distributes past key-value pairs from context into three memory systems: (i) recent pairs in a local sliding window cache; (ii) difficult-to-memorize pairs in a sparse, global cache; and (iii) generic pairs in the recurrent hidden state of linear attention. We show through ablations that our self-recall error metric is crucial to efficiently manage long-term associative memories. On pass-key retrieval tasks, LoLA improves the base model's performance from 0.6% to 97.4% accuracy. This is achieved with a 4.6x smaller cache than Llama-3.1 8B on 4K context length. LoLA also outperforms other 1B and 8B parameter subquadratic models on zero-shot commonsense reasoning tasks.