Output Embedding Centering for Stable LLM Pretraining

📄 arXiv: 2601.02031v1 📥 PDF

作者: Felix Stollenwerk, Anna Lokrantz, Niclas Hertzberg

分类: cs.LG, cs.AI, cs.CL

发布日期: 2026-01-05

备注: 11 pages, 5 figures


💡 一句话要点

提出输出嵌入中心化(OEC)方法,解决LLM预训练中输出Logit发散问题。

🎯 匹配领域: 支柱九:具身大模型 (Embodied Foundation Models)

关键词: 大型语言模型 预训练 训练稳定性 输出嵌入 中心化

📋 核心要点

  1. 大型语言模型预训练易出现输出logit发散,现有z-loss方法仅缓解症状。
  2. 提出输出嵌入中心化(OEC)方法,从输出嵌入几何角度抑制logit发散。
  3. OEC的μ-centering和μ-loss变体均优于z-loss,对超参数不敏感。

📝 摘要(中文)

大型语言模型的预训练不仅成本高昂,而且容易出现某些训练不稳定现象。一种常见的、在训练后期使用较大学习率时发生的特定不稳定现象是输出logit发散。目前最广泛使用的缓解策略z-loss仅仅解决了症状,而没有解决问题的根本原因。本文从输出嵌入几何的角度分析了这种不稳定性,并确定了其原因。基于此,我们提出输出嵌入中心化(OEC)作为一种新的缓解策略,并证明它可以抑制输出logit发散。OEC可以通过两种不同的方式实现,一种是确定性操作,称为μ-centering,另一种是正则化方法,称为μ-loss。实验表明,在训练稳定性和学习率敏感性方面,这两种变体都优于z-loss。特别是,即使在z-loss失败的情况下,它们也能确保训练在较大的学习率下收敛。此外,我们发现μ-loss对正则化超参数的调整远不如z-loss敏感。

🔬 方法详解

问题定义:论文旨在解决大型语言模型预训练过程中,由于输出层logit发散导致的训练不稳定问题。现有方法,如z-loss,虽然可以缓解症状,但未能从根本上解决问题,并且对超参数敏感。

核心思路:论文的核心思路是分析输出嵌入的几何结构,发现logit发散的根本原因是输出嵌入没有以零为中心。因此,通过将输出嵌入中心化,可以有效地抑制logit发散,从而提高训练的稳定性。

技术框架:论文提出了两种实现输出嵌入中心化的方法:μ-centering和μ-loss。μ-centering是一种确定性操作,直接对输出嵌入进行中心化处理。μ-loss是一种正则化方法,通过在损失函数中添加一个正则项,促使输出嵌入趋向于中心化。整体框架是在标准的语言模型预训练流程中,加入OEC模块,可以是μ-centering或μ-loss。

关键创新:论文的关键创新在于从输出嵌入几何的角度分析了logit发散的原因,并提出了输出嵌入中心化(OEC)这一根本性的解决方案。与z-loss等仅缓解症状的方法不同,OEC直接解决了logit发散的根源。

关键设计:μ-centering通过计算输出嵌入的均值,然后从每个嵌入向量中减去该均值来实现中心化。μ-loss则是在标准交叉熵损失函数的基础上,添加一个正则化项,该正则项惩罚输出嵌入的均值偏离零的情况。正则化系数μ控制了中心化的强度。论文中详细描述了μ-centering和μ-loss的具体计算公式和实现细节。

🖼️ 关键图片

fig_0
fig_1
fig_2

📊 实验亮点

实验结果表明,OEC的两种变体(μ-centering和μ-loss)在训练稳定性和学习率敏感性方面均优于z-loss。特别是在使用较大学习率时,z-loss无法收敛,而OEC仍然可以保证训练的稳定进行。此外,μ-loss对正则化超参数的敏感度明显低于z-loss,降低了调参的难度。

🎯 应用场景

该研究成果可应用于各种大型语言模型的预训练过程,提高训练的稳定性和效率,尤其是在使用较大学习率的情况下。这有助于降低预训练的成本,并加速新模型的开发和部署。此外,该方法还可以推广到其他类型的神经网络模型中,以解决类似的训练不稳定问题。

📄 摘要(原文)

Pretraining of large language models is not only expensive but also prone to certain training instabilities. A specific instability that often occurs for large learning rates at the end of training is output logit divergence. The most widely used mitigation strategy, z-loss, merely addresses the symptoms rather than the underlying cause of the problem. In this paper, we analyze the instability from the perspective of the output embeddings' geometry and identify its cause. Based on this, we propose output embedding centering (OEC) as a new mitigation strategy, and prove that it suppresses output logit divergence. OEC can be implemented in two different ways, as a deterministic operation called μ-centering, or a regularization method called μ-loss. Our experiments show that both variants outperform z-loss in terms of training stability and learning rate sensitivity. In particular, they ensure that training converges even for large learning rates when z-loss fails. Furthermore, we find that μ-loss is significantly less sensitive to regularization hyperparameter tuning than z-loss.