LLM REgression with a Latent Iterative State Head
作者: Yiheng Su, Matthew Lease
分类: cs.CL, cs.LG
发布日期: 2026-04-01
💡 一句话要点
提出RELISH,一种用于LLM文本回归的轻量级迭代状态头
🎯 匹配领域: 支柱九:具身大模型 (Embodied Foundation Models)
关键词: 文本回归 大型语言模型 交叉注意力 迭代状态 参数效率
📋 核心要点
- 现有LLM回归方法通常效率较低,要么需要解码文本,要么聚合多个输出,参数量大。
- RELISH通过迭代细化潜在状态,并使用交叉注意力机制,直接从冻结的LLM表示中预测标量值。
- 实验表明,RELISH在多个数据集和LLM上优于现有基线,同时保持了极高的参数效率。
📝 摘要(中文)
本文提出了一种名为RELISH(具有潜在迭代状态头的回归)的新型轻量级架构,专为大型语言模型的文本回归而设计。RELISH没有将数值目标解码为文本或聚合多个生成的输出,而是通过迭代地细化学习到的潜在状态,并通过对token级别表示的交叉注意力,直接从冻结的LLM表示中预测标量值,然后使用线性回归器将最终状态映射到点估计。在五个数据集、四个LLM骨干网络和两种LLM训练方案中,RELISH始终优于来自所有三个主要LLM回归系列的先前基线,包括自回归解码、回归感知推理和现有的预测头方法。尽管取得了这些收益,RELISH仍然具有很高的参数效率,在冻结的LLM骨干网络中仅需要3.4-3.7M个可训练参数(仅增加0.01-0.04%的额外开销),远低于基于LoRA的替代方案,后者随模型大小而增长(0.26-0.42%)。
🔬 方法详解
问题定义:论文旨在解决大型语言模型(LLM)在文本回归任务中的效率问题。现有的方法,如自回归解码、回归感知推理以及现有的预测头方法,要么计算成本高昂,要么需要大量的参数调整,限制了它们在资源受限场景下的应用。
核心思路:RELISH的核心思路是通过学习一个潜在的迭代状态来逐步逼近目标值。该潜在状态通过与LLM的token级别表示进行交叉注意力交互,从而提取与回归任务相关的特征,并进行迭代更新。最终,该状态被映射到一个标量值,作为回归的预测结果。
技术框架:RELISH的整体架构包括以下几个主要模块:1) 冻结的LLM骨干网络,用于提取token级别的表示;2) 一个可学习的潜在状态向量;3) 一个交叉注意力模块,用于将LLM的表示与潜在状态进行融合;4) 一个线性回归器,用于将最终的潜在状态映射到标量预测值。整个流程是,首先LLM提取文本特征,然后潜在状态通过交叉注意力迭代更新,最后线性回归器输出预测值。
关键创新:RELISH的关键创新在于其迭代的潜在状态更新机制和轻量级的设计。与直接解码或聚合输出的方法不同,RELISH通过迭代地细化潜在状态,能够更有效地利用LLM的知识,同时避免了生成文本的开销。此外,RELISH的参数量非常小,使其易于部署在资源受限的环境中。
关键设计:RELISH的关键设计包括:1) 使用交叉注意力机制来融合LLM的表示和潜在状态,使得潜在状态能够关注到与回归任务相关的token;2) 迭代更新潜在状态,逐步逼近目标值;3) 使用线性回归器作为最终的预测层,保证了模型的简单性和可解释性。损失函数通常采用均方误差(MSE)等回归任务常用的损失函数。
🖼️ 关键图片
📊 实验亮点
RELISH在五个数据集、四个LLM骨干网络和两种LLM训练方案中,始终优于现有的LLM回归基线方法。更重要的是,RELISH仅需3.4-3.7M可训练参数,相比于LoRA等方法,参数效率显著提升,仅增加了0.01-0.04%的额外开销。
🎯 应用场景
RELISH可应用于情感分析、文本可读性评估、论文评分等多种文本回归任务。其轻量级和高效的特性使其特别适用于移动设备、嵌入式系统等资源受限的场景。未来,RELISH可以扩展到其他模态的数据回归任务,例如图像或音频。
📄 摘要(原文)
We present RELISH (REgression with a Latent Iterative State Head), a novel, lightweight architecture designed for text regression with large language models. Rather than decoding numeric targets as text or aggregating multiple generated outputs, RELISH predicts scalar values directly from frozen LLM representations by iteratively refining a learned latent state through cross-attention over token-level representations, and then mapping the final state to a point estimate with a linear regressor. Across five datasets, four LLM backbones, and two LLM training regimes, RELISH consistently outperforms prior baselines from all three major LLM regression families, including autoregressive decoding, regression-aware inference, and existing predictive head methods. Despite these gains, RELISH remains highly parameter-efficient, requiring only 3.4-3.7M trainable parameters across frozen LLM backbones (only 0.01-0.04% additional overhead), far less than LoRA-based alternatives that grow with model size (0.26-0.42%).