CN111767711B - 基于知识蒸馏的预训练语言模型的压缩方法及平台 - Google Patents

基于知识蒸馏的预训练语言模型的压缩方法及平台 Download PDF

Info

Publication number
CN111767711B
CN111767711B CN202010910566.3A CN202010910566A CN111767711B CN 111767711 B CN111767711 B CN 111767711B CN 202010910566 A CN202010910566 A CN 202010910566A CN 111767711 B CN111767711 B CN 111767711B
Authority
CN
China
Prior art keywords
model
teacher
student
module
network
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN202010910566.3A
Other languages
English (en)
Other versions
CN111767711A (zh
Inventor
王宏升
单海军
鲍虎军
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Zhejiang Lab
Original Assignee
Zhejiang Lab
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Zhejiang Lab filed Critical Zhejiang Lab
Priority to CN202010910566.3A priority Critical patent/CN111767711B/zh
Publication of CN111767711A publication Critical patent/CN111767711A/zh
Application granted granted Critical
Publication of CN111767711B publication Critical patent/CN111767711B/zh
Priority to PCT/CN2020/138019 priority patent/WO2021248868A1/zh
Priority to JP2022570419A priority patent/JP7381813B2/ja
Priority to GB2214161.8A priority patent/GB2608919A/en
Priority to US17/483,805 priority patent/US11341326B2/en
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F40/00Handling natural language data
    • G06F40/20Natural language analysis
    • G06F40/205Parsing
    • G06F40/211Syntactic parsing, e.g. based on context-free grammar [CFG] or unification grammars
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F40/00Handling natural language data
    • G06F40/20Natural language analysis
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F40/00Handling natural language data
    • G06F40/30Semantic analysis
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N20/00Machine learning
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N5/00Computing arrangements using knowledge-based models
    • G06N5/02Knowledge representation; Symbolic representation

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Artificial Intelligence (AREA)
  • Computational Linguistics (AREA)
  • Data Mining & Analysis (AREA)
  • General Health & Medical Sciences (AREA)
  • Health & Medical Sciences (AREA)
  • Software Systems (AREA)
  • Evolutionary Computation (AREA)
  • Mathematical Physics (AREA)
  • Computing Systems (AREA)
  • Audiology, Speech & Language Pathology (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Biophysics (AREA)
  • Molecular Biology (AREA)
  • Biomedical Technology (AREA)
  • Medical Informatics (AREA)
  • Evolutionary Biology (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Management, Administration, Business Operations System, And Electronic Commerce (AREA)
  • Machine Translation (AREA)

Abstract

本发明公开了一种基于知识蒸馏的预训练语言模型的压缩方法及平台,该方法首先设计一种普适的特征迁移的知识蒸馏策略,在教师模型的知识蒸馏到学生模型的过程中,将学生模型每一层的特征映射逼近教师的特征,重点关注小样本在教师模型中间层特征表达能力,并利用这些特征指导学生模型;然后利用教师模型的自注意力分布具有检测词语之间语义和句法的能力构建一种基于自注意力交叉知识蒸馏方法;最后为了提升学习模型训练前期的学习质量和训练后期的泛化能力,设计了一种基于伯努利概率分布的线性迁移策略逐渐完成从教师到学生的特征映射和自注意分布的知识迁移。通过本发明,将面向多任务的预训练语言模型进行自动压缩,提高语言模型的压缩效率。

Description

基于知识蒸馏的预训练语言模型的压缩方法及平台
技术领域
本发明属于面向多任务的预训练语言模型自动压缩领域,尤其涉及一种基于知识蒸馏的预训练语言模型的压缩方法及平台。
背景技术
随着智能设备的普及,大规模语言模型在智能手机、可穿戴设备等嵌入式设备上的应用越来越常见,然而深度学习网络规模却在不断增大,计算复杂度随之增高,严重限制了其在手机等智能设备上的应用,如今的应对方法还是单向地从教师模型的知识蒸馏到学生模型的压缩方法,但是小样本在大规模语言模型压缩过程中难泛化的问题依然存在。
发明内容
本发明的目的在于针对现有技术的不足,提供一种基于知识蒸馏的预训练语言模型的压缩方法及平台。本发明基于知识蒸馏的预训练语言模型压缩,设计一个与任务无关的小模型去学习一个大模型的表达能力,压缩出某一类任务通用的架构,充分利用已压缩好的模型架构,提高模型压缩效率。具体地,通过特征映射知识蒸馏模块、自注意力交叉知识蒸馏和基于伯努利概率分布的线性迁移策略,实现了教师模型和学生模型在训练过程中渐进式地相互学习,从而提高了小样本情况下学生模型训练前期的学习质量和训练后期的泛化能力。
本发明的目的是通过以下技术方案来实现的:一种基于知识蒸馏的预训练语言模型的压缩方法,该方法对BERT模型进行压缩,包括特征映射知识蒸馏模块、自注意力交叉知识蒸馏模块和基于伯努利概率分布的线性学习模块;其中,原始的模型为教师模型,压缩后的模型为学生模型;特征映射知识蒸馏模块基于一种特征迁移的知识蒸馏策略,在教师模型的知识蒸馏到学生模型的过程中,将学生模型每一层的特征映射逼近教师模型的特征映射,学生模型关注教师模型的中间层特征,并利用这些中间层特征指导学生模型;自注意力交叉知识蒸馏模块通过交叉连接教师模型和学生模型的自注意力模块,通过在网络自注意层上进行凸组合交叉连接的方式,实现教师模型和学生模型的深度相互学习;基于伯努利概率分布的线性学习模块逐渐完成从教师模型到学生模型的特征映射和自注意分布的知识迁移。
进一步地,所述特征映射知识蒸馏模块中增加层间归一化以稳定层间训练损失;训练学生网络时,最小化特征图转换中均值和方差两个统计差异。
进一步地,所述自注意力交叉知识蒸馏模块的迁移目标函数是最小化学生模型和教师模型的注意力分布之间的相对熵。
进一步地,所述自注意力交叉知识蒸馏模块包括以下三个阶段:
第一阶段:教师网络的自注意力单元输入学生网络,并最小化迁移目标函数,具体地,将教师网络的自注意力单元当作基本真值,在网络自注意力单元位置输入学生网络,学生网络接受正确的监督信号以对后续层进行训练,避免估计误差过大并传播的现象;
第二阶段:学生网络的自注意力单元输入教师网络,并最小化迁移目标函数;由于估计误差在学生网络上逐层传播,导致在同一层位置上学生网络输入和教师网络输入存在差异;将学生网络自注意力单元输入给教师网络,实现了在相同输入前提下让学生网络模仿教师网络的输出行为;
第三阶段:在网络自注意力单元上将所述第一阶段和第二阶段的迁移目标函数进行凸组合,实现交叉迁移的蒸馏策略。
进一步地,所述基于伯努利概率分布的线性学习模块用于为驱动特征映射知识蒸馏模块和自注意力交叉知识蒸馏模块设置不同的线性迁移概率,包括以下两个步骤:
步骤一:特征映射知识蒸馏模块和自注意力交叉知识蒸馏模块均采用伯努利概率分布的迁移概率,即假设当前迁移第i个模块,先通过一个伯努利分布,采样一个随机变量X,X为0或1;当随机变量为1时代表当前模块进行迁移学习,否则不进行;
步骤二:虽然步骤一中设置一个恒定的迁移概率p,可以满足压缩模型的需要,但 是线性学习驱动的迁移概率有助于逐步迁移模型中的编码器模块,本步骤设计了一个线性 学习驱动的迁移概率
Figure 539959DEST_PATH_IMAGE001
来动态调整步骤一中的迁移概率p,即
Figure 253837DEST_PATH_IMAGE002
其中,
Figure 315465DEST_PATH_IMAGE001
表示当前迁移模块的迁移概率,第i个模块迁移对应当前训练第i步, b表示未训练时的初始迁移概率;
Figure 507412DEST_PATH_IMAGE003
为大于0的动态值,且满足在训练增加至1000步、5000步、 10000步、30000步时,相应地,
Figure 852943DEST_PATH_IMAGE001
逐渐增加至0.25、0.5、0.75、1.00。
进一步地,所述初始迁移概率b取值范围在0.1至0.3之间。
一种根据上述方法的基于知识蒸馏的预训练语言模型的压缩平台,该平台包括以下组件:
数据加载组件:用于获取面向多任务的BERT预训练语言模型及其训练样本;所述训练样本是满足监督学习任务的有标签的文本样本;
压缩组件:用于将面向多任务的大规模语言模型进行压缩,包括教师模型微调模块、教师-学生模型蒸馏模块和学生模型微调模块;其中,教师模型微调模块负责加载BERT预训练模型,将训练样本输入包含下游任务的BERT模型进行微调,输出教师模型;教师-学生模型蒸馏模块利用所述教师模型微调模块获得的教师模型,通过所述特征映射知识蒸馏模块、自注意力交叉知识蒸馏模块和基于伯努利概率分布的线性学习模块,逐渐完成从教师到学生的特征映射和自注意分布的知识蒸馏,更新学生网络的各个单元模块的权重参数;学生模型微调模块是将学生网络所有编码器单元模块重新组合成完整的编码器,并利用教师网络的特征层和输出层对下游任务场景进行微调,输出微调好的学生模型,作为最终的压缩模型;
推理组件:利用所述压缩组件输出的压缩模型在实际场景的数据集上对自然语言处理下游任务进行推理。
进一步地,所述压缩组件将所述压缩模型输出到指定的容器,供用户下载,并呈现压缩前后模型大小的对比信息;通过推理组件利用压缩模型对自然语言处理下游任务进行推理,并呈现压缩前后推理速度的对比信息。
本发明的有益效果是:本发明是基于知识蒸馏的预训练语言模型压缩,设计一个与任务无关的小模型去学习一个大模型的表达能力,压缩出某一类任务通用的架构,充分利用已压缩好的模型架构,提高模型压缩效率。
本发明推动了大规模深度学习语言模型在内存小、资源受限等端侧设备上的部署进程。以BERT为代表的大规模自然语言处理预训练模型显著提升了自然语言处理任务的效果,促进了自然语言处理领域的发展。尽管BERT等模型效果很好,但是如果一个对话机器人一秒钟只能处理一条信息很难满足实际场景需求,而且数十亿级别参数的模型,超大规模的GPU机器学习集群和超长的模型训练时间,给模型的落地带来了阻碍。本发明就是为了解决上述工业落地面临的痛点,利用本发明所述的基于知识蒸馏的预训练语言模型的压缩平台,压缩出某一类自然语言处理任务的通用的架构,充分利用已压缩好的模型架构,可以在保证现有模型的性能和精度基本不变的前提下,减少计算量、缩小模型体积、加快模型推理速率,并且可将大规模自然语言处理模型部署在内存小、资源受限等端侧设备上进行部署,推动了通用深度语言模型在工业界的落地进程。
附图说明
图1是本发明基于知识蒸馏的预训练语言模型的压缩方法及平台的整体架构图;
图2是自注意力单元交叉知识蒸馏过程示意图。
具体实施方式
如图1所示,一种基于知识蒸馏的预训练语言模型的压缩方法包括特征映射知识蒸馏模块、自注意力交叉知识蒸馏模块和基于伯努利概率分布的线性学习模块。其中,特征映射知识蒸馏模块是一种普适的特征迁移的知识蒸馏策略,在教师模型的知识蒸馏到学生模型的过程中,将学生模型每一层的特征映射逼近教师的特征,学生模型更多地关注教师模型的中间层特征,并利用这些特征指导学生模型。自注意力交叉知识蒸馏模块,即通过交叉连接教师和学生网络的自注意力模块,通过在网络自注意层上进行凸组合交叉连接的方式,实现教师模型和学生模型的深度相互学习。基于伯努利概率分布的线性学习模块逐渐完成从教师到学生的特征映射和自注意分布的知识迁移,为驱动特征映射知识蒸馏模块和自注意力交叉知识蒸馏模块设置不同的线性迁移概率。
本发明一种基于知识蒸馏的预训练语言模型的压缩方法,对BERT(BidirectionalEncoder Representations from Transformers,来自变换器的双向编码器表征量)模型进行压缩,利用已压缩的模型架构,提高压缩效率。本发明将编码器单元作为模块的基本单元;将原始的模型简称为教师模型,压缩后的模型简称为学生模型。假设模型层数的压缩比为2,即压缩一半的层数。原始教师模型为12层,压缩后为6层,那么对于学生模型来说,一共6个模块,每个模块包含一个编码器单元。对于教师模型,我们将12层分隔成6个模块,每个模块包含两个编码器单元,此时可以将教师模型和学生模型建立一对一的映射关系,之后就可以进行正式的压缩步骤了;整个过程都是在具体某个自然语言处理任务的下游任务的微调阶段实施,而不是在预训练阶段。为了加速整个训练过程,使用教师模型的部分权重来初始化学生模型的所有单元模块,即将教师模型前六层的编码器单元权重与学生模型的六层编码器单元权重共享。
本发明的压缩方法整个过程分为三个阶段,第一个阶段是微调教师模型。首先需要使用12层原始BERT模型微调出一个教师模型;第二个阶段是特征映射知识蒸馏和自注意力交叉知识蒸馏阶段,这一阶段同时考虑了教师模型和学生模型,让两个模型都参与到训练中;第三个阶段是对学生模型单独微调,目的是为了让所有学生模型的模块完整参与到训练任务中;具体过程如下:
步骤一:加载预训练BERT模型和数据集,微调教师模型;所述BERT模型可包含具体某个自然语言处理下游任务。
步骤二:如图1所示,冻结学生网络自注意力单元模块的权重参数,利用伯努利概率分布的线性学习策略完成教师模型到学生模型的特征映射知识蒸馏过程,更新学生网络其它单元模块的权重参数,包括以下子步骤:
(2.1)假设当前迁移第i个特征映射模块,先通过一个伯努利分布,采样一个随机变量X(X为0或1),当随机变量为1时代表当前模块进行迁移学习,对当前教师网络的特征映射单元进行线性迁移,否则不进行。
(2.2)考虑到线性学习驱动的迁移概率可以逐步迁移模型中的特征映射模块,本 步骤设计了一个线性学习驱动的迁移概率
Figure 222875DEST_PATH_IMAGE004
来动态调整步骤(2.1)中的迁移概率,即
Figure 21067DEST_PATH_IMAGE005
其中,
Figure 751126DEST_PATH_IMAGE006
表示当前迁移模块线性学习驱动的迁移概率,第i个模块迁移对应当 前训练第i步,
Figure 967475DEST_PATH_IMAGE007
表示初始( i为0时)的迁移概率,其取值范围在0.1至0.3之间。
Figure 491997DEST_PATH_IMAGE008
取大于0的 动态值,且满足在训练步数增加至1000步、5000步、10000步、30000步时,相应地,
Figure 43064DEST_PATH_IMAGE004
逐 渐增加至0.25、0.5、0.75、1.00。
(2.3)教师模型和学生模型的特征映射之间的均方误差被用作知识迁移目标函数,并增加层间归一化以稳定层间训练损失;训练学生网络时,最小化特征图转换中均值和方差两个统计差异。
步骤三:如图2所示,自注意力交叉知识蒸馏阶段逐渐完成从教师到学生的自注意分布的知识蒸馏,更新学生网络的各个单元模块的权重参数,将教师和学生网络的自注意力单元进行凸组合交叉连接;其中,迁移目标函数是最小化学生模型和教师模型的注意力分布之间的相对熵;包括以下子步骤:
(3.1)教师网络的自注意力单元输入学生网络,并最小化迁移目标函数,具体地,将教师网络的自注意力单元当作基本真值,在网络自注意力单元位置输入学生网络,学生网络接受正确的监督信号以对后续层进行训练,避免估计误差过大并传播的现象。当然不会将每个教师网络的自注意力单元输入学生网络,基于伯努利概率分布的线性学习的教师网络的自注意力单元迁移策略,包括以下子步骤:
(3.1.1)假设当前迁移第i个模块,先通过一个伯努利分布,采样一个随机变量X(X为0或1),当随机变量为1时代表当前模块进行迁移学习,对当前教师网络的自注意力单元进行线性迁移,否则不进行。
(3.1.2)虽然步骤(3.1.1)中设置一个恒定的迁移概率
Figure 311234DEST_PATH_IMAGE009
,可以满足压缩模型的需 要,但是线性学习驱动的迁移概率有助于逐步迁移模型中的编码器模块,本步骤设计了一 个线性学习驱动的迁移概率
Figure 116510DEST_PATH_IMAGE010
来动态调整步骤(3.2.1)中的迁移概率
Figure 343092DEST_PATH_IMAGE009
,即
Figure 115876DEST_PATH_IMAGE005
其中,
Figure 672891DEST_PATH_IMAGE006
表示当前迁移模块线性学习驱动的迁移概率,第i个模块迁移对应当 前训练第i步,
Figure 113099DEST_PATH_IMAGE007
表示初始(i为0时)的迁移概率,其取值范围在0.1至0.3之间。
Figure 979424DEST_PATH_IMAGE008
取大于0的动 态值,且满足在训练步数增加至1000步、5000步、10000步、30000步时,相应地,
Figure 239504DEST_PATH_IMAGE004
逐渐 增加至0.25、0.5、0.75、1.00。
(3.2)学生网络的自注意力单元输入教师网络,并最小化迁移目标函数。由于估计误差在学生网络上逐层传播,导致在同一层位置上学生网络输入和教师网络输入存在较大差异。将学生网络自注意力单元输入给教师网络,实现了在相同输入前提下让学生网络模仿教师网络的输出行为,同时,基于伯努利概率分布的线性学习的学生网络的自注意力单元迁移策略与步骤(3.1)中教师网络的自注意力单元迁移策略相同。
(3.3)将步骤(3.1)和(3.2)的迁移目标函数进行凸组合,实现交叉迁移的蒸馏策略。整个综合模型仍然使用下游任务的目标损失进行训练。这里需要关注一个训练细节:考虑到教师网络的权重在步骤一的原始微调阶段已经达到一个较为稳定的状态,如果此时让其参与到教师-学生网络的整合训练中,反而会导致遗忘问题。另外,步骤三的目的是让学生网络的各个单元模块尽量得到更新,如果让教师网络参与到梯度更新,可能会让学生网络的单元模块被忽略。冻结教师网络的权重也能提升整个模型训练的效率。基于以上考虑,在梯度传递的时候,所有属于教师网络的权重参数都冻结不参与梯度计算,学生网络的相关单元模块的权重参数参与梯度更新。
步骤四:单独微调学生模型。步骤三结束后,由于每步训练时,只有部分不同的学生网络的单元模块参与到训练中,学生网络所有的单元模块并没有整合到一起参与到任务训练中,因此需要添加一个单独微调学生模型的过程。将学生网络所有编码器单元模块重新组合成完整的编码器,并利用教师网络的特征层和输出层对下游任务场景进行微调,最终输出压缩模型,用于推理下游任务。
本发明一种基于知识蒸馏的预训练语言模型的压缩平台包括:
数据加载组件:用于获取登陆用户上传的待压缩的包含具体自然语言处理下游任务的BERT模型和面向多任务的预训练语言模型的训练样本,所述训练样本是满足监督学习任务的带标签的文本样本。
压缩组件:用于将面向多任务的大规模语言模型进行压缩,包括教师模型微调模块、教师-学生模型蒸馏模块和学生模型微调模块。
教师模型微调模块负责加载BERT预训练模型,并且将所述训练样本输入教师模型(包含下游任务的BERT模型)进行微调,输出教师模型;
教师-学生模型蒸馏模块利用所述教师模型微调模块获得的教师模型,基于特征映射知识蒸馏和自注意力交叉知识蒸馏和基于伯努利概率分布的线性学习,更新学生模型的各个单元模块的权重参数;
学生模型微调模块基于所述知识蒸馏所得的学生模型进行微调,是将学生网络所有编码器单元模块重新组合成完整的编码器,并利用教师网络的特征层和输出层对下游任务场景进行微调,输出最终微调好的学生模型,即登陆用户需求的包含下游任务的预训练语言模型压缩模型。将所述压缩模型输出到指定的容器,可供所述登陆用户下载,并在所述平台的输出压缩模型的页面呈现压缩前后模型大小的对比信息。
推理组件:登陆用户从所述平台获取压缩模型,用户利用所述压缩组件输出的压缩模型在实际场景的数据集上对登陆用户上传的自然语言处理下游任务的新数据进行推理。并在所述平台的压缩模型推理页面呈现压缩前后推理速度的对比信息。
本发明可根据登陆用户上传的包含具体某个自然语言处理下游任务的BERT模型进行压缩,登陆用户可以下载所述平台生成的已压缩的模型架构,并在终端上进行部署。也可以直接在所述平台上对自然语言处理下游任务进行推理。
本发明设计了自注意力交叉知识蒸馏策略,充分利用教师模型的自注意力分布具有检测词语之间语义和句法的能力,训练前期,学生网络接受教师网络自注意层的监督信号以对后续层进行训练,从而可避免估计误差过大并传播的现象。训练后期,学生网络自注意层输入给教师网络,从而在相同输入前提下让学生网络模仿教师网络的输出行为。网络自注意层上进行凸组合交叉知识蒸馏的策略促使教师模型和学生模型深度相互学习。这样的特性极大地提升了小样本情况下大规模语言压缩模型的泛化能力。此外,通过基于伯努利概率分布的线性学习驱动编码器模块迁移的策略,在训练初始阶段,可以使得更多教师模型的编码器模块参与学习,将更多教师模型的特征映射和自注意力知识参与进来,提升整个学生模型的质量,得到更小的损失函数值,从而使得整个训练过程平滑,避免了模型前期学习过程中过于震荡的现象。在训练后期,当学生模型整体性能具有比较好的表现,此时让更多学生模型的知识参与学习,使得学生模型逐渐摆脱对教师模型的依赖,使得模型整体能够平稳过度到学生模型的微调阶段,提升整个模型的泛化能力。
下面将以电影评论进行情感分类任务对本发明的技术方案做进一步的详细描述。
通过所述平台的数据加载组件获取登陆用户上传的单个句子的文本分类任务的BERT模型和情感分析数据集SST-2;
通过所述平台加载BERT预训练模型,对包含文本分类任务的BERT模型进行微调,获得教师模型;
通过所述平台的压缩组件,逐渐完成从教师到学生的特征映射和自注意分布的知识蒸馏,更新学生网络的各个单元模块的权重参数;
基于所述知识蒸馏所得的学生模型进行微调,将学生网络所有编码器单元模块重新组合成完整的编码器,并利用教师网络的特征层和输出层对下游任务场景进行微调,最终,平台输出登陆用户需求的包含文本分类任务的BERT模型的压缩模型。
将所述压缩模型输出到指定的容器,可供所述登陆用户下载,并在所述平台的输出压缩模型的页面呈现压缩前后模型大小的对比信息,压缩前模型大小为110M,压缩后为66M,压缩了40%。如下表1所示。
表1:文本分类任务BERT模型压缩前后对比信息
文本分类任务(SST-2)(包含67K个样本) 压缩前 压缩后 对比
模型大小 110M 66M 压缩40%
推理精度 91.5% 91.8% 提升0.3%
通过所述平台的推理组件,利用所述平台输出的压缩模型对登陆用户上传的SST-2测试集数据进行推理,并在所述平台的压缩模型推理页面呈现压缩后比压缩前推理速度加快1.95倍,并且推理精度从压缩前的91.5%提升为91.8%。

Claims (6)

1.一种基于知识蒸馏的预训练语言模型的压缩方法,其特征在于,该方法对BERT模型进行压缩,包括特征映射知识蒸馏模块、自注意力交叉知识蒸馏模块和基于伯努利概率分布的线性学习模块;其中,原始的BERT模型为教师模型,压缩后的BERT模型为学生模型;特征映射知识蒸馏模块基于一种特征迁移的知识蒸馏策略,在教师模型的知识蒸馏到学生模型的过程中,将学生模型每一层的特征映射逼近教师模型的特征映射,学生模型关注教师模型的中间层特征,并利用这些中间层特征指导学生模型;自注意力交叉知识蒸馏模块通过交叉连接教师模型和学生模型的自注意力单元,通过在网络自注意层上进行凸组合交叉连接的方式,实现教师模型和学生模型的深度相互学习;基于伯努利概率分布的线性学习模块逐渐完成从教师模型到学生模型的特征映射和自注意分布的知识迁移;
所述自注意力交叉知识蒸馏模块包括以下三个阶段:
第一阶段:教师网络的自注意力单元输入学生网络,并最小化迁移目标函数,具体地,将教师网络的自注意力单元当作基本真值,在网络自注意力单元位置输入学生网络,学生网络接受正确的监督信号以对后续层进行训练,避免估计误差过大并传播的现象;
第二阶段:学生网络的自注意力单元输入教师网络,并最小化迁移目标函数;由于估计误差在学生网络上逐层传播,导致在同一层位置上学生网络输入和教师网络输入存在差异;将学生网络自注意力单元输入给教师网络,实现了在相同输入前提下让学生网络模仿教师网络的输出行为;
第三阶段:在网络自注意力单元上将所述第一阶段和第二阶段的迁移目标函数进行凸组合,实现交叉迁移的蒸馏策略;
所述基于伯努利概率分布的线性学习模块用于为驱动特征映射知识蒸馏模块和自注意力交叉知识蒸馏模块设置不同的线性迁移概率,包括以下两个步骤:
步骤一:特征映射知识蒸馏模块和自注意力交叉知识蒸馏模块均采用伯努利概率分布的迁移概率,即假设当前迁移第i个模块,先通过一个伯努利分布,采样一个随机变量X,X为0或1;当随机变量为1时代表当前模块进行迁移学习,否则不进行;
步骤二:虽然步骤一中设置一个恒定的迁移概率p,可以满足压缩模型的需要,但是线性学习驱动的迁移概率有助于逐步迁移模型中的编码器模块,本步骤设计了一个线性学习驱动的迁移概率plinear来动态调整步骤一中的迁移概率p,即
plinear=min(1,k*i+b)
其中,plinear表示当前迁移模块的迁移概率,第i个模块迁移对应当前训练第i步,b表示未训练时的初始迁移概率;k为大于0的动态值,且满足在训练增加至1000步、5000步、10000步、30000步时,相应地,plinear逐渐增加至0.25、0.5、0.75、1.00。
2.根据权利要求1所述基于知识蒸馏的预训练语言模型的压缩方法,其特征在于,所述特征映射知识蒸馏模块中增加层间归一化以稳定层间训练损失;训练学生网络时,最小化特征图转换中均值和方差两个统计差异。
3.根据权利要求1所述基于知识蒸馏的预训练语言模型的压缩方法,其特征在于,所述自注意力交叉知识蒸馏模块的迁移目标函数是最小化学生模型和教师模型的注意力分布之间的相对熵。
4.根据权利要求1所述基于知识蒸馏的预训练语言模型的压缩方法,其特征在于,所述初始迁移概率b取值范围在0.1至0.3之间。
5.一种根据权利要求1所述基于知识蒸馏的预训练语言模型的压缩方法的平台,其特征在于,该平台包括以下组件:
数据加载组件:用于获取面向多任务的BERT模型及其训练样本;所述训练样本是满足监督学习任务的有标签的文本样本;
压缩组件:用于将面向多任务的大规模语言模型进行压缩,包括教师模型微调模块、教师-学生模型蒸馏模块和学生模型微调模块;其中,教师模型微调模块负责加载BERT模型,将训练样本输入包含下游任务的BERT模型进行微调,输出教师模型;教师-学生模型蒸馏模块利用所述教师模型微调模块获得的教师模型,通过所述特征映射知识蒸馏模块、自注意力交叉知识蒸馏模块和基于伯努利概率分布的线性学习模块,逐渐完成从教师到学生的特征映射和自注意分布的知识蒸馏,更新学生网络的各个单元模块的权重参数;学生模型微调模块是将学生网络所有编码器单元模块重新组合成完整的编码器,并利用教师网络的特征层和输出层对下游任务场景进行微调,输出微调好的学生模型,作为最终的压缩模型;
推理组件:利用所述压缩组件输出的压缩模型在实际场景的数据集上对自然语言处理下游任务进行推理。
6.根据权利要求5所述平台,其特征在于,所述压缩组件将所述压缩模型输出到指定的容器,供用户下载,并呈现压缩前后模型大小的对比信息;通过推理组件利用压缩模型对自然语言处理下游任务进行推理,并呈现压缩前后推理速度的对比信息。
CN202010910566.3A 2020-09-02 2020-09-02 基于知识蒸馏的预训练语言模型的压缩方法及平台 Active CN111767711B (zh)

Priority Applications (5)

Application Number Priority Date Filing Date Title
CN202010910566.3A CN111767711B (zh) 2020-09-02 2020-09-02 基于知识蒸馏的预训练语言模型的压缩方法及平台
PCT/CN2020/138019 WO2021248868A1 (zh) 2020-09-02 2020-12-21 基于知识蒸馏的预训练语言模型的压缩方法及平台
JP2022570419A JP7381813B2 (ja) 2020-09-02 2020-12-21 知識蒸留に基づく予めトレーニング言語モデルの圧縮方法及びプラットフォーム
GB2214161.8A GB2608919A (en) 2020-09-02 2020-12-21 Knowledge distillation-based compression method for pre-trained language model, and platform
US17/483,805 US11341326B2 (en) 2020-09-02 2021-09-24 Compression method and platform of pre-training language model based on knowledge distillation

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202010910566.3A CN111767711B (zh) 2020-09-02 2020-09-02 基于知识蒸馏的预训练语言模型的压缩方法及平台

Publications (2)

Publication Number Publication Date
CN111767711A CN111767711A (zh) 2020-10-13
CN111767711B true CN111767711B (zh) 2020-12-08

Family

ID=72729279

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202010910566.3A Active CN111767711B (zh) 2020-09-02 2020-09-02 基于知识蒸馏的预训练语言模型的压缩方法及平台

Country Status (5)

Country Link
US (1) US11341326B2 (zh)
JP (1) JP7381813B2 (zh)
CN (1) CN111767711B (zh)
GB (1) GB2608919A (zh)
WO (1) WO2021248868A1 (zh)

Families Citing this family (61)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN111767711B (zh) 2020-09-02 2020-12-08 之江实验室 基于知识蒸馏的预训练语言模型的压缩方法及平台
GB2609768A (en) * 2020-11-02 2023-02-15 Zhejiang Lab Multi-task language model-oriented meta-knowledge fine tuning method and platform
CN112418291B (zh) * 2020-11-17 2024-07-26 平安科技(深圳)有限公司 一种应用于bert模型的蒸馏方法、装置、设备及存储介质
CN112529178B (zh) * 2020-12-09 2024-04-09 中国科学院国家空间科学中心 一种适用于无预选框检测模型的知识蒸馏方法及系统
CN112464959B (zh) * 2020-12-12 2023-12-19 中南民族大学 基于注意力和多重知识迁移的植物表型检测系统及其方法
JP7381814B2 (ja) * 2020-12-15 2023-11-16 之江実験室 マルチタスク向けの予めトレーニング言語モデルの自動圧縮方法及びプラットフォーム
CN112232511B (zh) * 2020-12-15 2021-03-30 之江实验室 面向多任务的预训练语言模型自动压缩方法及平台
CN112580783B (zh) * 2020-12-16 2024-03-22 浙江工业大学 一种高维深度学习模型向低维迁移知识的跨维度知识迁移方法
CN112613273B (zh) * 2020-12-16 2022-09-23 上海交通大学 多语言bert序列标注模型的压缩方法及系统
GB2610319A (en) * 2020-12-17 2023-03-01 Zhejiang Lab Automatic compression method and platform for multilevel knowledge distillation-based pre-trained language model
CN112241455B (zh) * 2020-12-17 2021-05-04 之江实验室 基于多层级知识蒸馏预训练语言模型自动压缩方法及平台
CN112613559B (zh) * 2020-12-23 2023-04-07 电子科技大学 基于相互学习的图卷积神经网络节点分类方法、存储介质和终端
CN112365385B (zh) * 2021-01-18 2021-06-01 深圳市友杰智新科技有限公司 基于自注意力的知识蒸馏方法、装置和计算机设备
CN113159168B (zh) * 2021-04-19 2022-09-02 清华大学 基于冗余词删除的预训练模型加速推理方法和系统
US11977842B2 (en) * 2021-04-30 2024-05-07 Intuit Inc. Methods and systems for generating mobile enabled extraction models
CN113177415B (zh) * 2021-04-30 2024-06-07 科大讯飞股份有限公司 语义理解方法、装置、电子设备和存储介质
CN113222123B (zh) * 2021-06-15 2024-08-09 深圳市商汤科技有限公司 模型训练方法、装置、设备及计算机存储介质
CN113420123A (zh) * 2021-06-24 2021-09-21 中国科学院声学研究所 语言模型的训练方法、nlp任务处理方法及装置
US11763082B2 (en) 2021-07-12 2023-09-19 International Business Machines Corporation Accelerating inference of transformer-based models
CN113592007B (zh) * 2021-08-05 2022-05-31 哈尔滨理工大学 一种基于知识蒸馏的不良图片识别系统、方法、计算机及存储介质
CN113836311A (zh) * 2021-08-12 2021-12-24 中国科学技术大学 解耦的试题表征及应用方法
CN113849641B (zh) * 2021-09-26 2023-10-24 中山大学 一种跨领域层次关系的知识蒸馏方法和系统
CN113887610B (zh) * 2021-09-29 2024-02-02 内蒙古工业大学 基于交叉注意力蒸馏Transformer的花粉图像分类方法
CN113887230B (zh) * 2021-09-30 2024-06-25 北京熵简科技有限公司 一种面向金融场景的端到端自然语言处理训练系统与方法
US11450225B1 (en) * 2021-10-14 2022-09-20 Quizlet, Inc. Machine grading of short answers with explanations
EP4224379A4 (en) * 2021-12-03 2024-02-14 Contemporary Amperex Technology Co., Limited METHOD AND SYSTEM FOR RAPID ANOMALY DETECTION BASED ON CONTRASTIVE IMAGE DISTILLATION
CN114240892B (zh) * 2021-12-17 2024-07-02 华中科技大学 一种基于知识蒸馏的无监督工业图像异常检测方法及系统
CN114461871B (zh) * 2021-12-21 2023-03-28 北京达佳互联信息技术有限公司 推荐模型训练方法、对象推荐方法、装置及存储介质
CN114004315A (zh) * 2021-12-31 2022-02-01 北京泰迪熊移动科技有限公司 一种基于小样本进行增量学习的方法及装置
CN114708467B (zh) * 2022-01-27 2023-10-13 西安交通大学 基于知识蒸馏的不良场景识别方法及系统及设备
CN114863248B (zh) * 2022-03-02 2024-04-26 武汉大学 一种基于深监督自蒸馏的图像目标检测方法
CN114972839B (zh) * 2022-03-30 2024-06-25 天津大学 一种基于在线对比蒸馏网络的广义持续分类方法
CN114580571B (zh) * 2022-04-01 2023-05-23 南通大学 一种基于迁移互学习的小样本电力设备图像分类方法
CN114972904B (zh) * 2022-04-18 2024-05-31 北京理工大学 一种基于对抗三元组损失的零样本知识蒸馏方法及系统
CN114882397B (zh) * 2022-04-25 2024-07-05 国网江苏省电力有限公司电力科学研究院 一种基于交叉注意机制动态知识传播的危险车辆识别方法
CN114880347A (zh) * 2022-04-27 2022-08-09 北京理工大学 一种基于深度学习的自然语言转化为sql语句的方法
CN114819148B (zh) * 2022-05-17 2024-07-02 西安电子科技大学 基于不确定性估计知识蒸馏的语言模型压缩方法
CN114969332A (zh) * 2022-05-18 2022-08-30 北京百度网讯科技有限公司 训练文本审核模型的方法和装置
CN115064155B (zh) * 2022-06-09 2024-09-06 福州大学 一种基于知识蒸馏的端到端语音识别增量学习方法及系统
CN115309849A (zh) * 2022-06-27 2022-11-08 北京邮电大学 一种基于知识蒸馏的特征提取方法、装置及数据分类方法
CN115131627B (zh) * 2022-07-01 2024-02-20 贵州大学 一种轻量化植物病虫害目标检测模型的构建和训练方法
CN115019183B (zh) * 2022-07-28 2023-01-20 北京卫星信息工程研究所 基于知识蒸馏和图像重构的遥感影像模型迁移方法
CN115457006B (zh) * 2022-09-23 2023-08-22 华能澜沧江水电股份有限公司 基于相似一致性自蒸馏的无人机巡检缺陷分类方法及装置
CN115272981A (zh) * 2022-09-26 2022-11-01 山东大学 云边共学习输电巡检方法与系统
CN115511059B (zh) * 2022-10-12 2024-02-09 北华航天工业学院 一种基于卷积神经网络通道解耦的网络轻量化方法
CN115423540B (zh) * 2022-11-04 2023-02-03 中邮消费金融有限公司 一种基于强化学习的金融模型知识蒸馏方法及装置
CN116110022B (zh) * 2022-12-10 2023-09-05 河南工业大学 基于响应知识蒸馏的轻量化交通标志检测方法及系统
CN115797976B (zh) * 2023-01-12 2023-05-30 广州紫为云科技有限公司 一种低分辨率的实时手势识别方法
CN117152788B (zh) * 2023-05-08 2024-10-01 东莞理工学院 基于知识蒸馏与多任务自监督学习的骨架行为识别方法
CN116340779A (zh) * 2023-05-30 2023-06-27 北京智源人工智能研究院 一种下一代通用基础模型的训练方法、装置和电子设备
CN116415005B (zh) * 2023-06-12 2023-08-18 中南大学 一种面向学者学术网络构建的关系抽取方法
CN116542321B (zh) * 2023-07-06 2023-09-01 中科南京人工智能创新研究院 基于扩散模型的图像生成模型压缩和加速方法及系统
CN116776744B (zh) * 2023-08-15 2023-10-31 工业云制造(四川)创新中心有限公司 一种基于增强现实的装备制造控制方法及电子设备
CN117436014A (zh) * 2023-10-06 2024-01-23 重庆邮电大学 一种基于掩膜生成蒸馏的心音信号异常检测算法
CN117009830B (zh) * 2023-10-07 2024-02-13 之江实验室 一种基于嵌入特征正则化的知识蒸馏方法和系统
CN117612247B (zh) * 2023-11-03 2024-07-30 重庆利龙中宝智能技术有限公司 一种基于知识蒸馏的动静态手势识别方法
CN117197590B (zh) * 2023-11-06 2024-02-27 山东智洋上水信息技术有限公司 一种基于神经架构搜索与知识蒸馏的图像分类方法及装置
CN117668622B (zh) * 2024-02-01 2024-05-10 山东能源数智云科技有限公司 设备故障诊断模型的训练方法、故障诊断方法及装置
CN117892139B (zh) * 2024-03-14 2024-05-14 中国医学科学院医学信息研究所 基于层间比对的大语言模型训练和使用方法及相关装置
CN118520904B (zh) * 2024-07-25 2024-10-15 山东浪潮科学研究院有限公司 基于大语言模型的识别训练方法、识别方法
CN118585794A (zh) * 2024-08-02 2024-09-03 四川龙裕天凌电子科技有限公司 一种高频高速多层混压电路板制作方法

Family Cites Families (14)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US10575788B2 (en) * 2016-10-18 2020-03-03 Arizona Board Of Regents On Behalf Of Arizona State University Compressive sensing of quasi-periodic signals using generative models
US11210467B1 (en) * 2017-04-13 2021-12-28 Snap Inc. Machine learned language modeling and identification
CN107247989B (zh) * 2017-06-15 2020-11-24 北京图森智途科技有限公司 一种实时的计算机视觉处理方法及装置
CN108830288A (zh) * 2018-04-25 2018-11-16 北京市商汤科技开发有限公司 图像处理方法、神经网络的训练方法、装置、设备及介质
CN110232203B (zh) * 2019-04-22 2020-03-03 山东大学 知识蒸馏优化rnn短期停电预测方法、存储介质及设备
CN110147836B (zh) * 2019-05-13 2021-07-02 腾讯科技(深圳)有限公司 模型训练方法、装置、终端及存储介质
CN110097178A (zh) * 2019-05-15 2019-08-06 电科瑞达(成都)科技有限公司 一种基于熵注意的神经网络模型压缩与加速方法
CN110880036B (zh) * 2019-11-20 2023-10-13 腾讯科技(深圳)有限公司 神经网络压缩方法、装置、计算机设备及存储介质
CN111062489B (zh) 2019-12-11 2023-10-20 北京知道创宇信息技术股份有限公司 一种基于知识蒸馏的多语言模型压缩方法、装置
US11797862B2 (en) * 2020-01-22 2023-10-24 Google Llc Extreme language model compression with optimal sub-words and shared projections
CN111461226A (zh) * 2020-04-01 2020-07-28 深圳前海微众银行股份有限公司 对抗样本生成方法、装置、终端及可读存储介质
CN115699029A (zh) * 2020-06-05 2023-02-03 华为技术有限公司 利用神经网络中的后向传递知识改进知识蒸馏
CN111767110B (zh) * 2020-07-01 2023-06-23 广州视源电子科技股份有限公司 图像处理方法、装置、系统、电子设备及存储介质
CN111767711B (zh) * 2020-09-02 2020-12-08 之江实验室 基于知识蒸馏的预训练语言模型的压缩方法及平台

Also Published As

Publication number Publication date
GB202214161D0 (en) 2022-11-09
CN111767711A (zh) 2020-10-13
GB2608919A (en) 2023-01-18
US20220067274A1 (en) 2022-03-03
WO2021248868A1 (zh) 2021-12-16
JP2023523644A (ja) 2023-06-06
US11341326B2 (en) 2022-05-24
JP7381813B2 (ja) 2023-11-16
GB2608919A9 (en) 2023-05-10

Similar Documents

Publication Publication Date Title
CN111767711B (zh) 基于知识蒸馏的预训练语言模型的压缩方法及平台
CN103679139B (zh) 基于粒子群优化bp网络的人脸识别方法
CN106952646A (zh) 一种基于自然语言的机器人交互方法和系统
CN107342078A (zh) 对话策略优化的冷启动系统和方法
CN105704013A (zh) 基于上下文的话题更新数据处理方法及装置
CN113627545B (zh) 一种基于同构多教师指导知识蒸馏的图像分类方法及系统
CN113392640B (zh) 一种标题确定方法、装置、设备及存储介质
CN102522010A (zh) 家庭语言学习方法、装置及系统
Gao Role of 5G network technology and artificial intelligence for research and reform of English situational teaching in higher vocational colleges
CN113343796B (zh) 一种基于知识蒸馏的雷达信号调制方式识别方法
CN114564513A (zh) 基于神经网络的海雾预测方法、装置、设备及存储介质
CN112989843B (zh) 意图识别方法、装置、计算设备及存储介质
CN109299805B (zh) 一种基于人工智能的在线教育课程请求处理方法
CN116521831A (zh) 集成自然语言理解算法和Web3.0技术的聊天机器人及方法
CN115187863A (zh) 多层次自适应知识蒸馏的轻量化高分遥感场景分类方法
CN115273828A (zh) 语音意图识别模型的训练方法、装置及电子设备
CN116362987A (zh) 基于多层级知识蒸馏的去雾模型压缩方法
CN118170890B (zh) 一种回复文本的生成方法和相关装置
CN117689041B (zh) 云端一体化的嵌入式大语言模型训练方法及语言问答方法
CN117786560B (zh) 一种基于多粒度级联森林的电梯故障分类方法及电子设备
Yang Evaluation method of English online course teaching effect based on ResNet algorithm
Liu Optimization and Comparative Modeling of Technology English Automatic Translation Algorithms for Mobile Platforms
CN202795705U (zh) 遥控器及智能控制系统
CN116681075A (zh) 文本识别方法以及相关设备
CN116975654A (zh) 对象互动方法、装置、电子设备、存储介质及程序产品

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant