Mixed Distillation Helps Smaller Language Model Better Reasoning
作者: Chenglin Li, Qianglong Chen, Liangyue Li, Caiyu Wang, Yicheng Li, Zulong Chen, Yin Zhang
分类: cs.CL, cs.AI
发布日期: 2023-12-17 (更新: 2024-02-25)
备注: Working in Progress, 17 pages, 16 figures
💡 一句话要点
提出混合蒸馏框架,提升小模型在推理任务上的性能,超越GPT-3.5-Turbo。
🎯 匹配领域: 支柱二:RL算法与架构 (RL & Architecture) 支柱九:具身大模型 (Embodied Foundation Models)
关键词: 知识蒸馏 大型语言模型 推理能力 思维链 思维程序 混合蒸馏 模型压缩
📋 核心要点
- 大型语言模型部署成本高昂,知识蒸馏是降低成本的有效方法,但现有方法在推理任务上效果不佳。
- 混合蒸馏框架结合了思维程序(PoT)和思维链(CoT)的优势,通过多种提示技术将推理能力迁移到小模型。
- 实验表明,混合蒸馏显著提升了小模型在推理任务上的准确性和泛化性,甚至超越了GPT-3.5-Turbo。
📝 摘要(中文)
大型语言模型(LLMs)在自然语言处理(NLP)任务中表现出色,但其部署因计算和内存需求高而面临挑战。最近的研究集中于通过从LLM中进行知识蒸馏来增强较小的模型,并取得了可喜的成果。然而,这些模型通常难以与LLM的性能相匹配,尤其是在需要推理的任务中。本文提出了混合蒸馏(MD)框架,该框架利用LLM中思维程序(PoT)和思维链(CoT)能力的优势,结合多种提示技术,并将这些能力提炼到较小的模型中。实验结果表明,MD显著提高了较小模型在各种任务中的单路径和多路径推理能力。在推理任务的准确性和通用性方面,它生成的模型超过了两个单独蒸馏模型的综合性能。值得注意的是,使用MD的LLaMA2-7B和CodeLlama-7B分别取得了(84.5%)和(85.5%)的显著提升,在SVAMP基准测试中优于GPT-3.5-Turbo (2.5%)和(3.5%)。
🔬 方法详解
问题定义:现有的大型语言模型(LLMs)虽然在各种NLP任务中表现出色,但其巨大的计算和内存需求使其难以在资源受限的环境中部署。知识蒸馏是一种将LLM的知识迁移到较小模型的方法,但现有方法在需要复杂推理的任务中,小模型的性能仍然与LLM存在较大差距。因此,如何有效地将LLM的推理能力迁移到小模型是一个关键问题。
核心思路:本文的核心思路是利用LLM中不同的推理方法,如思维链(CoT)和思维程序(PoT),并将这些方法结合起来,通过混合蒸馏的方式,让小模型学习到更全面的推理能力。通过结合多种提示技术,可以更有效地提取LLM的知识,并将其迁移到小模型中。
技术框架:混合蒸馏(MD)框架主要包含以下几个阶段:首先,使用不同的提示技术(包括CoT和PoT)生成LLM的推理过程。然后,将这些推理过程作为训练数据,用于训练小模型。在训练过程中,使用混合损失函数,同时考虑预测结果的准确性和推理过程的相似性。最后,对训练好的小模型进行评估,验证其在推理任务上的性能。
关键创新:该方法最重要的创新点在于混合蒸馏的思想,即同时利用多种推理方法(CoT和PoT)进行知识蒸馏。与传统的只使用单一推理方法相比,混合蒸馏可以使小模型学习到更全面的推理能力,从而提高其在复杂推理任务中的性能。此外,该方法还提出了一种新的混合损失函数,可以更好地指导小模型的训练。
关键设计:在提示技术方面,论文使用了多种CoT和PoT的变体,以生成更丰富的推理过程。在损失函数方面,论文结合了交叉熵损失和KL散度损失,以同时优化预测结果的准确性和推理过程的相似性。具体的参数设置和网络结构取决于所使用的小模型,例如LLaMA2-7B和CodeLlama-7B。
📊 实验亮点
实验结果表明,使用混合蒸馏(MD)框架的LLaMA2-7B和CodeLlama-7B在SVAMP基准测试中分别取得了84.5%和85.5%的准确率,超越了GPT-3.5-Turbo的性能(分别高出2.5%和3.5%)。这表明混合蒸馏可以显著提高小模型在推理任务上的性能,使其能够与更大的模型相媲美。
🎯 应用场景
该研究成果可应用于各种需要推理能力的自然语言处理任务,例如问答系统、文本摘要、机器翻译等。通过将大型语言模型的推理能力迁移到小型模型,可以在资源受限的环境中部署高性能的NLP应用,例如移动设备、嵌入式系统等。此外,该方法还可以用于提高模型的鲁棒性和泛化能力。
📄 摘要(原文)
While large language models (LLMs) have demonstrated exceptional performance in recent natural language processing (NLP) tasks, their deployment poses substantial challenges due to high computational and memory demands in real-world applications. Recent studies have focused on enhancing smaller models through knowledge distillation from LLMs, yielding promising results. However, these models often struggle to match the performance of LLMs, especially in tasks that require reasoning. In this work, we introduce Mixed Distillation (MD) framework, which capitalizes on the strengths of Program of Thought (PoT) and Chain of Thought (CoT) capabilities within LLMs, combining multiple prompting techniques and distilling these capabilities into smaller models. Our experimental results show that MD significantly enhances the single-path and multi-path reasoning ability of smaller models in various tasks. In terms of accuracy and generality of reasoning tasks, the model generated by it exceeds the comprehensive performance of two individually distilled models. Notably, LLaMA2-7B and CodeLlama-7B using MD achieved remarkable improvements of (84.5%) and (85.5%), respectively, outperforming GPT-3.5-Turbo by (2.5%) and (3.5%), on the SVAMP benchmark.