CN117725960A - 基于知识蒸馏的语言模型训练方法、文本分类方法及设备 - Google Patents
基于知识蒸馏的语言模型训练方法、文本分类方法及设备 Download PDFInfo
- Publication number
- CN117725960A CN117725960A CN202410179392.6A CN202410179392A CN117725960A CN 117725960 A CN117725960 A CN 117725960A CN 202410179392 A CN202410179392 A CN 202410179392A CN 117725960 A CN117725960 A CN 117725960A
- Authority
- CN
- China
- Prior art keywords
- language model
- text
- loss function
- classification
- constructing
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000000034 method Methods 0.000 title claims abstract description 62
- 238000012549 training Methods 0.000 title claims abstract description 50
- 238000013140 knowledge distillation Methods 0.000 title claims abstract description 24
- 230000006870 function Effects 0.000 claims abstract description 55
- 239000013598 vector Substances 0.000 claims abstract description 47
- 239000011159 matrix material Substances 0.000 claims description 32
- 238000010276 construction Methods 0.000 claims description 19
- 238000004590 computer program Methods 0.000 claims description 10
- 230000008569 process Effects 0.000 claims description 9
- 238000003860 storage Methods 0.000 claims description 9
- 238000004821 distillation Methods 0.000 claims description 8
- 238000010606 normalization Methods 0.000 claims description 7
- 238000009826 distribution Methods 0.000 claims description 6
- 238000013145 classification model Methods 0.000 claims description 5
- 239000000203 mixture Substances 0.000 claims description 5
- 230000009467 reduction Effects 0.000 claims description 3
- 238000010586 diagram Methods 0.000 description 7
- 230000000694 effects Effects 0.000 description 7
- 238000012545 processing Methods 0.000 description 4
- 238000004364 calculation method Methods 0.000 description 3
- 238000004891 communication Methods 0.000 description 3
- 230000006835 compression Effects 0.000 description 2
- 238000007906 compression Methods 0.000 description 2
- 238000002372 labelling Methods 0.000 description 2
- 230000003287 optical effect Effects 0.000 description 2
- 238000003491 array Methods 0.000 description 1
- 235000019800 disodium phosphate Nutrition 0.000 description 1
- 239000000835 fiber Substances 0.000 description 1
- 230000004927 fusion Effects 0.000 description 1
- 230000010365 information processing Effects 0.000 description 1
- 230000003993 interaction Effects 0.000 description 1
- 238000004519 manufacturing process Methods 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 230000003068 static effect Effects 0.000 description 1
Landscapes
- Machine Translation (AREA)
Abstract
本发明公开了一种基于知识蒸馏的语言模型训练方法、文本分类方法及设备,包括:获取样本数据集,采用初始语言模型对样本数据集进行文本编码,得到句向量XE;基于句向量XE构建每个batch样本的软标签;构建损失函数,损失函数包含分类损失和差异损失;采用损失函数和样本数据集对初始语言模型进行训练,得到目标语言模型,使得该语言模型针对标准不完全的样本具有良好的识别效果,提升识别和分类的精准性。
Description
技术领域
本发明涉及数据处理领域,尤其涉及一种基于知识蒸馏的语言模型训练方法、文本分类方法及设备。
背景技术
目前业界中的医疗文本分类体系通常是单标签体系,即一个样本只有一个标签;但是在现实中一个医疗的文本,比如用户的query可能包含多个意图或类别;同时标注过程中不可避免会有标注噪音,比如错标的情况,这种错误标注会影响模型的效果;同时,真实的医疗文本分布中存在着样本不均衡的情况,即有些类别的样本多,有些类别的样本少,上述三种情形均会导致标注不完全,使得模型训练效果不佳,文本分类效果达不到要求。
发明内容
本发明实施例提供一种基于知识蒸馏的语言模型训练方法、文本分类方法及设备,以提高文本分类的精准性。
为了解决上述技术问题,本申请实施例提供一种基于知识蒸馏的语言模型训练方法,所述基于知识蒸馏的语言模型训练方法包括:
获取样本数据集,所述样本数据集中句子的输入为,其中,/>为句子长度x为对应的字特征,句子对应的标注为/>,/>,表示标签是单标签并且总标签类别有m个;
采用初始语言模型对所述样本数据集进行文本编码,得到句向量XE;
基于所述句向量XE构建每个batch样本的软标签;
构建损失函数,所述损失函数包含分类损失和差异损失;
采用所述损失函数和所述样本数据集对所述初始语言模型进行训练,得到目标语言模型。
可选地,所述初始语言模型的文本的编码器模型为采取Bert-base预训练模型,其中,词向量维度为768、隐藏层维度大小等于768、文本输入最大长度为512、由12个transformer层构成,每层的Multi-head Attention中包含12个head,样本数据集经过mean_pool操作得到句子的句向量为/>,其中/>,。
可选地,所述基于所述句向量XE构建每个batch样本的软标签包括:
根据句向量构建样本间的相似度概率矩阵A,/>;
基于相似度概率矩阵A,对样本中每个样本计算得到除本身外的预测概率的加权和,其中/>为超参数用来权衡原始预测分值和batch内融合概率的信息量,/>表示N个样本的M个标签类别的预测概率分值,当前经过一次传播得到一次传播软标签Q经过t次传播得到软标签/>。
可选地,所述根据句向量构建样本间的相似度概率矩阵A包括:
对句向量XE进行L2标准化得到标准化后的向量;
采用如下公式计算样本间的相似度概率矩阵A:
;
其中,dot()为点积操作,为行列为/>的对角矩阵,/>为常量,用于使得对角值为极小值,最后经过/>得到行之和为1的相似度概率矩阵A。
可选地,所述构建损失函数包括:
构建分类模型本身的标准的交叉熵损失函数,用来拟合学习标签信息;
构建用于减少和P分布间差异的KL散度损失函数/>,用于学习软标签信息;
最终损失函数如下:
;
其中,CE()为交叉熵损失函数;LK()为KL散度损失函数;r为调节权重比例的超参数。
为了解决上述技术问题,本申请实施例还提供一种文本分类方法,包括:
获取待分类的文本数据;
将所述待分类的文本数据输入到目标语言模型进行分类识别,得到分类结果。
为了解决上述技术问题,本申请实施例还提供一种基于知识蒸馏的语言模型训练装置,包括:
样本获取模块,用于获取样本数据集,所述样本数据集中句子的输入为,其中/>为句子长度x为对应的字特征,句子对应的标注为,/>,表示标签/>是单标签并且总标签类别有m个;
文本编码模块,用于采用初始语言模型对所述样本数据集进行文本编码,得到句向量XE;
软标签构建模块,用于基于所述句向量XE构建每个batch样本的软标签;
损失构建模块,用于构建损失函数,所述损失函数包含分类损失和差异损失;
模型训练模块,用于采用所述损失函数和所述样本数据集对所述初始语言模型进行训练,得到目标语言模型。
可选地,所述软标签构建模块包括:
矩阵构建单元,用于根据句向量构建样本间的相似度概率矩阵A,/>;
软标签生成单元,用于基于相似度概率矩阵A,对样本中每个样本计算得到除本身外的预测概率的加权和,其中/>为超参数用来权衡原始预测分值和batch内融合概率的信息量,/>表示N个样本的M个标签类别的预测概率分值,当前经过一次传播得到一次传播软标签Q经过t次传播得到软标签/>。
可选地,所述矩阵构建单元包括:
标准化子单元,用于对句向量XE进行L2标准化得到标准化后的向量;
计算子单元,用于采用如下公式计算样本间的相似度概率矩阵A:
;
其中,dot()为点积操作,为行列为/>的对角矩阵,/>为常量,用于使得对角值为极小值,最后经过/>得到行之和为1的相似度概率矩阵A。
可选地,所述损失构建模块包括:
第一构建单元,用于构建分类模型本身的标准的交叉熵损失函数,用来拟合学习标签信息;
第二构建单元,用于构建用于减少和P分布间差异的KL散度损失函数/>,用于学习软标签信息;
损失计算单元,用于构建最终损失函数如下:
;
其中,CE()为交叉熵损失函数;LK()为KL散度损失函数;r为调节权重比例的超参数。
为了解决上述技术问题,本申请实施例还提供一种文本分类装置,包括:
文本获取模块,用于获取待分类的文本数据;
文本分类模块,用于将所述待分类的文本数据输入到目标语言模型进行分类识别,得到分类结果。
为了解决上述技术问题,本申请实施例还提供一种计算机设备,包括存储器、处理器以及存储在所述存储器中并可在所述处理器上运行的计算机程序,所述处理器执行所述计算机程序时实现上述基于知识蒸馏的语言模型训练方法的步骤。
为了解决上述技术问题,本申请实施例还提供一种计算机可读存储介质,所述计算机可读存储介质存储有计算机程序,所述计算机程序被处理器执行时实现上述基于知识蒸馏的语言模型训练方法的步骤。
本发明实施例提供的基于知识蒸馏的语言模型训练方法、文本分类方法、装置、计算机设备及存储介质,通过获取样本数据集,采用初始语言模型对样本数据集进行文本编码,得到句向量XE;基于句向量XE构建每个batch样本的软标签;构建损失函数,损失函数包含分类损失和差异损失;采用损失函数和样本数据集对初始语言模型进行训练,得到目标语言模型,使得该语言模型针对标准不完全的样本具有良好的识别效果,提升识别和分类的精准性。
附图说明
为了更清楚地说明本发明实施例的技术方案,下面将对本发明实施例的描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动性的前提下,还可以根据这些附图获得其他的附图。
图1是本申请可以应用于其中的示例性系统架构图;
图2是本申请的基于知识蒸馏的语言模型训练方法的一个实施例的流程图;
图3是本申请的文本分类方法的一个实施例的流程图;
图4是根据本申请的基于知识蒸馏的语言模型训练装置的一个实施例的结构示意图;
图5是根据本申请的文本分类装置的一个实施例的结构示意图;
图6是根据本申请的计算机设备的一个实施例的结构示意图。
具体实施方式
除非另有定义,本文所使用的所有的技术和科学术语与属于本申请的技术领域的技术人员通常理解的含义相同;本文中在申请的说明书中所使用的术语只是为了描述具体的实施例的目的,不是旨在于限制本申请;本申请的说明书和权利要求书及上述附图说明中的术语“包括”和“具有”以及它们的任何变形,意图在于覆盖不排他的包含。本申请的说明书和权利要求书或上述附图中的术语“第一”、“第二”等是用于区别不同对象,而不是用于描述特定顺序。
在本文中提及“实施例”意味着,结合实施例描述的特定特征、结构或特性可以包含在本申请的至少一个实施例中。在说明书中的各个位置出现该短语并不一定均是指相同的实施例,也不是与其它实施例互斥的独立的或备选的实施例。本领域技术人员显式地和隐式地理解的是,本文所描述的实施例可以与其它实施例相结合。
下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有作出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
请参阅图1,如图1所示,系统架构100可以包括终端设备101、102、103,网络104和服务器105。网络104用以在终端设备101、102、103和服务器105之间提供通信链路的介质。网络104可以包括各种连接类型,例如有线、无线通信链路或者光纤电缆等等。
用户可以使用终端设备101、102、103通过网络104与服务器105交互,以接收或发送消息等。
终端设备101、102、103可以是具有显示屏并且支持网页浏览的各种电子设备,包括但不限于智能手机、平板电脑、电子书阅读器、MP3播放器( Moving Picture ExpertsGroup Audio Layer III,动态影像专家压缩标准音频层面3 )、MP4( Moving PictureExperts Group Audio Layer IV,动态影像专家压缩标准音频层面4 )播放器、膝上型便携计算机和台式计算机等等。
服务器105可以是提供各种服务的服务器,例如对终端设备101、102、103上显示的页面提供支持的后台服务器。
需要说明的是,本申请实施例所提供的基于知识蒸馏的语言模型训练方法由服务器执行,相应地,基于知识蒸馏的语言模型训练装置设置于服务器中。
应该理解,图1中的终端设备、网络和服务器的数目仅仅是示意性的。根据实现需要,可以具有任意数目的终端设备、网络和服务器,本申请实施例中的终端设备101、102、103具体可以对应的是实际生产中的应用系统。
请参阅图2,图2示出本发明实施例提供的一种基于知识蒸馏的语言模型训练方法,以该方法应用在图1中的服务端为例进行说明,详述如下:
S201:获取样本数据集,样本数据集中句子的输入为,其中/>为句子长度x为对应的字特征,句子对应的标注为/>,/>,表示标签/>是单标签并且总标签类别有m个。
S202:采用初始语言模型对样本数据集进行文本编码,得到句向量XE。
在本实施例一具体可选实施方式中,初始语言模型的文本的编码器模型为采取Bert-base预训练模型,其中,词向量维度为768、隐藏层维度大小等于768、文本输入最大长度为512、由12个transformer层构成,每层的Multi-head Attention中包含12个head,样本数据集经过mean_pool操作得到句子的句向量为/>,其中/>,/>。
S203:基于句向量XE构建每个batch样本的软标签。
在本实施例一具体可选实施方式中,基于句向量XE构建每个batch样本的软标签包括:
根据句向量构建样本间的相似度概率矩阵A,/>;
基于相似度概率矩阵A,对样本中每个样本计算得到除本身外的预测概率的加权和,其中/>为超参数用来权衡原始预测分值和batch内融合概率的信息量,/>表示N个样本的M个标签类别的预测概率分值,当前经过一次传播得到一次传播软标签Q经过t次传播得到软标签/>。
进一步地,根据句向量构建样本间的相似度概率矩阵A包括:
对句向量XE进行L2标准化得到标准化后的向量;
采用如下公式计算样本间的相似度概率矩阵A:
;
其中,dot()为点积操作,为行列为/>的对角矩阵,/>为常量,用于使得对角值为极小值,最后经过/>得到行之和为1的相似度概率矩阵A。
具体地,当前经过一次传播得到一次传播软标签经过/>次传播可以得到更加准确的软标签/>,该软标签/>包含了其余样本的预测信息,可以用来缓解多标签、样本不均衡以及噪音问题,因为综合多个样本预测得到/>有多个标签的信息并且不是one-hot的形式更加接近于多标签的信息形式,同时由于多个样本的综合类似于mixup做样本不均衡一样可以减少样本不均衡的影响,同时减少了错误标注的影响。实现根据同一个batch内其他样本的标注作为辅助监督信息来对当前进行进行监督训练,利用batch内所有样本形成的标签知识融合来提升分类效果。
需要说明的是,本实施例提出了一种有区别于传统Teacher蒸馏方式,构建一种基于样本的自增量方式来增强模型效果,同时可以扩展到同个模型不同checkpoint和Batch综合考虑的蒸馏方式,从多种蒸馏软目标的获取来使得模型学习更多的信息。
S204:构建损失函数,损失函数包含分类损失和差异损失。
在本实施例一具体可选实施方式中,构建损失函数包括:
构建分类模型本身的标准的交叉熵损失函数,用来拟合学习标签信息;
构建用于减少和P分布间差异的KL散度损失函数/>,用于学习软标签信息;
最终损失函数如下:
;
其中,CE()为交叉熵损失函数;LK()为KL散度损失函数;r为调节权重比例的超参数。
S205:采用损失函数和样本数据集对初始语言模型进行训练,得到目标语言模型。
本实施例中,获取样本数据集,采用初始语言模型对样本数据集进行文本编码,得到句向量XE;基于句向量XE构建每个batch样本的软标签;构建损失函数,损失函数包含分类损失和差异损失;采用损失函数和样本数据集对初始语言模型进行训练,得到目标语言模型,使得该语言模型针对标准不完全的样本具有良好的识别效果,提升识别和分类的精准性。
请参阅图3,图3示出本发明实施例提供的一种文本分类方法,以该方法应用在图1中的服务端为例进行说明,详述如下:
S206:获取待分类的文本数据。
S207:将待分类的文本数据输入到目标语言模型进行分类识别,得到分类结果。
应理解,上述实施例中各步骤的序号的大小并不意味着执行顺序的先后,各过程的执行顺序应以其功能和内在逻辑确定,而不应对本发明实施例的实施过程构成任何限定。
图4示出与上述实施例基于知识蒸馏的语言模型训练方法一一对应的基于知识蒸馏的语言模型训练装置的原理框图。如图4所示,该基于知识蒸馏的语言模型训练装置包括样本获取模块31、文本编码模块32、软标签构建模块33、损失构建模块34和模型训练模块35。各功能模块详细说明如下:
样本获取模块31,用于获取样本数据集,样本数据集中句子的输入为,其中/>为句子长度x为对应的字特征,句子对应的标注为,/>,表示标签/>是单标签并且总标签类别有m个;
文本编码模块32,用于采用初始语言模型对样本数据集进行文本编码,得到句向量XE;
软标签构建模块33,用于基于句向量XE构建每个batch样本的软标签;
损失构建模块34,用于构建损失函数,损失函数包含分类损失和差异损失;
模型训练模块35,用于采用损失函数和样本数据集对初始语言模型进行训练,得到目标语言模型。
可选地,软标签构建模块包括:
矩阵构建单元,用于根据句向量构建样本间的相似度概率矩阵A,/>;
软标签生成单元,用于基于相似度概率矩阵A,对样本中每个样本计算得到除本身外的预测概率的加权和,其中/>为超参数用来权衡原始预测分值和batch内融合概率的信息量,/>表示N个样本的M个标签类别的预测概率分值,当前经过一次传播得到一次传播软标签Q经过t次传播得到软标签/>。
可选地,矩阵构建单元包括:
标准化子单元,用于对句向量XE进行L2标准化得到标准化后的向量;
计算子单元,用于采用如下公式计算样本间的相似度概率矩阵A:
;
其中,dot()为点积操作,为行列为/>的对角矩阵,/>为常量,用于使得对角值为极小值,最后经过/>得到行之和为1的相似度概率矩阵A。
可选地,损失构建模块包括:
第一构建单元,用于构建分类模型本身的标准的交叉熵损失函数,用来拟合学习标签信息;
第二构建单元,用于构建用于减少和P分布间差异的KL散度损失函数/>,用于学习软标签信息;
损失计算单元,用于构建最终损失函数如下:
;
其中,CE()为交叉熵损失函数;LK()为KL散度损失函数;r为调节权重比例的超参数。
图5示出与上述实施例文本分类方法一一对应的文本分类装置的原理框图。如图5所示,该基于文本分类装置包括文本获取模块36和文本分类模块37。各功能模块详细说明如下:
文本获取模块36,用于获取待分类的文本数据;
文本分类模块37,用于将所述待分类的文本数据输入到目标语言模型进行分类识别,得到分类结果。
关于基于知识蒸馏的语言模型训练装置的具体限定可以参见上文中对于基于知识蒸馏的语言模型训练方法的限定,关于文本分类装置的具体限定可以参见上文中对于文本分类方法的限定在此不再赘述。上述基于知识蒸馏的语言模型训练装置、文本分类装置中的各个模块可全部或部分通过软件、硬件及其组合来实现。上述各模块可以硬件形式内嵌于或独立于计算机设备中的处理器中,也可以以软件形式存储于计算机设备中的存储器中,以便于处理器调用执行以上各个模块对应的操作。
为解决上述技术问题,本申请实施例还提供计算机设备。具体请参阅图6,图6为本实施例计算机设备基本结构框图。
所述计算机设备4包括通过系统总线相互通信连接存储器41、处理器42、网络接口43。需要指出的是,图中仅示出了具有组件连接存储器41、处理器42、网络接口43的计算机设备4,但是应理解的是,并不要求实施所有示出的组件,可以替代的实施更多或者更少的组件。其中,本技术领域技术人员可以理解,这里的计算机设备是一种能够按照事先设定或存储的指令,自动进行数值计算和/或信息处理的设备,其硬件包括但不限于微处理器、专用集成电路(Application Specific Integrated Circuit,ASIC)、可编程门阵列(Field-Programmable Gate Array,FPGA)、数字处理器 (Digital Signal Processor,DSP)、嵌入式设备等。
所述计算机设备可以是桌上型计算机、笔记本、掌上电脑及云端服务器等计算设备。所述计算机设备可以与用户通过键盘、鼠标、遥控器、触摸板或声控设备等方式进行人机交互。
所述存储器41至少包括一种类型的可读存储介质,所述可读存储介质包括闪存、硬盘、多媒体卡、卡型存储器(例如,SD或D界面显示存储器等)、随机访问存储器(RAM)、静态随机访问存储器(SRAM)、只读存储器(ROM)、电可擦除可编程只读存储器(EEPROM)、可编程只读存储器(PROM)、磁性存储器、磁盘、光盘等。在一些实施例中,所述存储器41可以是所述计算机设备4的内部存储单元,例如该计算机设备4的硬盘或内存。在另一些实施例中,所述存储器41也可以是所述计算机设备4的外部存储设备,例如该计算机设备4上配备的插接式硬盘,智能存储卡(Smart Media Card, SMC),安全数字(Secure Digital, SD)卡,闪存卡(Flash Card)等。当然,所述存储器41还可以既包括所述计算机设备4的内部存储单元也包括其外部存储设备。本实施例中,所述存储器41通常用于存储安装于所述计算机设备4的操作系统和各类应用软件,例如基于知识蒸馏的语言模型训练程序代码、文本分类的程序代码等。此外,所述存储器41还可以用于暂时地存储已经输出或者将要输出的各类数据。
所述处理器42在一些实施例中可以是中央处理器(Central Processing Unit,CPU)、控制器、微控制器、微处理器、或其他数据处理芯片。该处理器42通常用于控制所述计算机设备4的总体操作。本实施例中,所述处理器42用于运行所述存储器41中存储的程序代码或者处理数据,例如运行文本分类的程序代码,或者,运行基于知识蒸馏的语言模型训练方法的步骤。
所述网络接口43可包括无线网络接口或有线网络接口,该网络接口43通常用于在所述计算机设备4与其他电子设备之间建立通信连接。
本申请还提供了另一种实施方式,即提供一种计算机可读存储介质,所述计算机可读存储介质存储有界面显示程序,所述界面显示程序可被至少一个处理器执行,以使所述至少一个处理器执行如上述的基于知识蒸馏的语言模型训练方法的步骤,或者,以使所述至少一个处理器执行如上述的文本分类方法的步骤。
通过以上的实施方式的描述,本领域的技术人员可以清楚地了解到上述实施例方法可借助软件加必需的通用硬件平台的方式来实现,当然也可以通过硬件,但很多情况下前者是更佳的实施方式。基于这样的理解,本申请的技术方案本质上或者说对现有技术做出贡献的部分可以以软件产品的形式体现出来,该计算机软件产品存储在一个存储介质(如ROM/RAM、磁碟、光盘)中,包括若干指令用以使得一台终端设备(可以是手机,计算机,服务器,空调器,或者网络设备等)执行本申请各个实施例所述的方法。
显然,以上所描述的实施例仅仅是本申请一部分实施例,而不是全部的实施例,附图中给出了本申请的较佳实施例,但并不限制本申请的专利范围。本申请可以以许多不同的形式来实现,相反地,提供这些实施例的目的是使对本申请的公开内容的理解更加透彻全面。尽管参照前述实施例对本申请进行了详细的说明,对于本领域的技术人员来而言,其依然可以对前述各具体实施方式所记载的技术方案进行修改,或者对其中部分技术特征进行等效替换。凡是利用本申请说明书及附图内容所做的等效结构,直接或间接运用在其他相关的技术领域,均同理在本申请专利保护范围之内。
Claims (10)
1.一种基于知识蒸馏的语言模型训练方法,其特征在于,包括:
获取样本数据集,所述样本数据集中句子的输入为,其中/>为句子长度x为对应的字特征,句子对应的标注为/>,/>,表示标签/>是单标签并且总标签类别有m个;
采用初始语言模型对所述样本数据集进行文本编码,得到句向量XE;
基于所述句向量XE构建每个batch样本的软标签;
构建损失函数,所述损失函数包含分类损失和差异损失;
采用所述损失函数和所述样本数据集对所述初始语言模型进行训练,得到目标语言模型。
2. 如权利要求1所述的基于知识蒸馏的语言模型训练方法,其特征在于,所述初始语言模型的文本的编码器模型为采取Bert-base预训练模型,其中,词向量维度为768、隐藏层维度大小等于768、文本输入最大长度为512、由12个transformer层构成,每层的Multi-head Attention中包含12个head,样本数据集经过mean_pool操作得到句子的句向量为/>,其中/>,/>。
3.如权利要求1所述的基于知识蒸馏的语言模型训练方法,其特征在于,所述基于所述句向量XE构建每个batch样本的软标签包括:
根据句向量构建样本间的相似度概率矩阵A,/>;
基于相似度概率矩阵A,对样本中每个样本计算得到除本身外的预测概率的加权和,其中/>为超参数用来权衡原始预测分值和batch内融合概率的信息量,/>表示N个样本的M个标签类别的预测概率分值,当前经过一次传播得到一次传播软标签Q经过t次传播得到软标签/>。
4.如权利要求3所述的基于知识蒸馏的语言模型训练方法,其特征在于,所述根据句向量构建样本间的相似度概率矩阵A包括:
对句向量XE进行L2标准化得到标准化后的向量;
采用如下公式计算样本间的相似度概率矩阵A:
;
其中,dot()为点积操作,为行列为/>的对角矩阵,/>为常量,用于使得对角值为极小值,最后经过/>得到行之和为1的相似度概率矩阵A。
5.如权利要求1所述的基于知识蒸馏的语言模型训练方法,其特征在于,所述构建损失函数包括:
构建分类模型本身的标准的交叉熵损失函数,用来拟合学习标签信息;
构建用于减少和P分布间差异的KL散度损失函数/>,用于学习软标签信息;
最终损失函数如下:
;
其中,CE()为交叉熵损失函数;LK()为KL散度损失函数;r为调节权重比例的超参数。
6.一种文本分类方法,其特征在于,包括:
获取待分类的文本数据;
将所述待分类的文本数据输入到目标语言模型进行分类识别,得到分类结果,其中,所述目标语言模型根据权利要求1至5任一项所述的基于知识蒸馏的语言模型训练方法训练得到。
7.一种基于知识蒸馏的语言模型训练装置,其特征在于,包括:
样本获取模块,用于获取样本数据集,所述样本数据集中句子的输入为,其中/>为句子长度x为对应的字特征,句子对应的标注为,/>,表示标签/>是单标签并且总标签类别有m个;
文本编码模块,用于采用初始语言模型对所述样本数据集进行文本编码,得到句向量XE;
软标签构建模块,用于基于所述句向量XE构建每个batch样本的软标签;
损失构建模块,用于构建损失函数,所述损失函数包含分类损失和差异损失;
模型训练模块,用于采用所述损失函数和所述样本数据集对所述初始语言模型进行训练,得到目标语言模型。
8.一种文本分类装置,其特征在于,包括:
文本获取模块,用于获取待分类的文本数据;
文本分类模块,用于将所述待分类的文本数据输入到目标语言模型进行分类识别,得到分类结果,其中,所述目标语言模型根据权利要求1至5任一项所述的基于知识蒸馏的语言模型训练方法训练得到。
9.一种计算机设备,包括存储器、处理器以及存储在所述存储器中并可在所述处理器上运行的计算机程序,其特征在于,所述处理器执行所述计算机程序时实现如权利要求1至5任一项所述的基于知识蒸馏的语言模型训练方法,或者,所述处理器执行所述计算机程序时实现如权利要求6所述的文本分类方法。
10.一种计算机可读存储介质,所述计算机可读存储介质存储有计算机程序,其特征在于,所述计算机程序被处理器执行时实现如权利要求1至5任一项所述的基于知识蒸馏的语言模型训练方法,或者,所述计算机程序被处理器执行时实现如权利要求6所述的文本分类方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202410179392.6A CN117725960A (zh) | 2024-02-18 | 2024-02-18 | 基于知识蒸馏的语言模型训练方法、文本分类方法及设备 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202410179392.6A CN117725960A (zh) | 2024-02-18 | 2024-02-18 | 基于知识蒸馏的语言模型训练方法、文本分类方法及设备 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN117725960A true CN117725960A (zh) | 2024-03-19 |
Family
ID=90209275
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202410179392.6A Pending CN117725960A (zh) | 2024-02-18 | 2024-02-18 | 基于知识蒸馏的语言模型训练方法、文本分类方法及设备 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN117725960A (zh) |
Citations (7)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN112347763A (zh) * | 2020-12-03 | 2021-02-09 | 云知声智能科技股份有限公司 | 针对预训练语言模型bert的知识蒸馏方法、装置及系统 |
CN112613273A (zh) * | 2020-12-16 | 2021-04-06 | 上海交通大学 | 多语言bert序列标注模型的压缩方法及系统 |
CN112733550A (zh) * | 2020-12-31 | 2021-04-30 | 科大讯飞股份有限公司 | 基于知识蒸馏的语言模型训练方法、文本分类方法及装置 |
CN113673254A (zh) * | 2021-08-23 | 2021-11-19 | 东北林业大学 | 基于相似度保持的知识蒸馏的立场检测方法 |
CN114818902A (zh) * | 2022-04-21 | 2022-07-29 | 浪潮云信息技术股份公司 | 基于知识蒸馏的文本分类方法及系统 |
US20220343139A1 (en) * | 2021-04-15 | 2022-10-27 | Peyman PASSBAN | Methods and systems for training a neural network model for mixed domain and multi-domain tasks |
CN116205290A (zh) * | 2023-05-06 | 2023-06-02 | 之江实验室 | 一种基于中间特征知识融合的知识蒸馏方法和装置 |
-
2024
- 2024-02-18 CN CN202410179392.6A patent/CN117725960A/zh active Pending
Patent Citations (7)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN112347763A (zh) * | 2020-12-03 | 2021-02-09 | 云知声智能科技股份有限公司 | 针对预训练语言模型bert的知识蒸馏方法、装置及系统 |
CN112613273A (zh) * | 2020-12-16 | 2021-04-06 | 上海交通大学 | 多语言bert序列标注模型的压缩方法及系统 |
CN112733550A (zh) * | 2020-12-31 | 2021-04-30 | 科大讯飞股份有限公司 | 基于知识蒸馏的语言模型训练方法、文本分类方法及装置 |
US20220343139A1 (en) * | 2021-04-15 | 2022-10-27 | Peyman PASSBAN | Methods and systems for training a neural network model for mixed domain and multi-domain tasks |
CN113673254A (zh) * | 2021-08-23 | 2021-11-19 | 东北林业大学 | 基于相似度保持的知识蒸馏的立场检测方法 |
CN114818902A (zh) * | 2022-04-21 | 2022-07-29 | 浪潮云信息技术股份公司 | 基于知识蒸馏的文本分类方法及系统 |
CN116205290A (zh) * | 2023-05-06 | 2023-06-02 | 之江实验室 | 一种基于中间特征知识融合的知识蒸馏方法和装置 |
Non-Patent Citations (2)
Title |
---|
"DistilBERT, adistilledversionofBERT:smaller, faster, cheaperandlighter", ARXIV, 1 May 2020 (2020-05-01), pages 1 - 5 * |
苑婧: "融合多教师模型的知识蒸馏文本分类", 电子技术应用, vol. 49, no. 11, 6 November 2023 (2023-11-06), pages 42 - 48 * |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN110532381B (zh) | 一种文本向量获取方法、装置、计算机设备及存储介质 | |
WO2021218028A1 (zh) | 基于人工智能的面试内容精炼方法、装置、设备及介质 | |
CN112287069B (zh) | 基于语音语义的信息检索方法、装置及计算机设备 | |
CN113158656B (zh) | 讽刺内容识别方法、装置、电子设备以及存储介质 | |
CN111694937A (zh) | 基于人工智能的面试方法、装置、计算机设备及存储介质 | |
US11947920B2 (en) | Man-machine dialogue method and system, computer device and medium | |
CN112836521A (zh) | 问答匹配方法、装置、计算机设备及存储介质 | |
CN112084752A (zh) | 基于自然语言的语句标注方法、装置、设备及存储介质 | |
CN114218945A (zh) | 实体识别方法、装置、服务器及存储介质 | |
CN115687934A (zh) | 意图识别方法、装置、计算机设备及存储介质 | |
CN115757731A (zh) | 对话问句改写方法、装置、计算机设备及存储介质 | |
CN116881446A (zh) | 一种语义分类方法、装置、设备及其存储介质 | |
CN116341646A (zh) | Bert模型的预训练方法、装置、电子设备及存储介质 | |
CN113627197B (zh) | 文本的意图识别方法、装置、设备及存储介质 | |
CN117725960A (zh) | 基于知识蒸馏的语言模型训练方法、文本分类方法及设备 | |
CN115238077A (zh) | 基于人工智能的文本分析方法、装置、设备及存储介质 | |
CN114218356A (zh) | 基于人工智能的语义识别方法、装置、设备及存储介质 | |
CN109933788B (zh) | 类型确定方法、装置、设备和介质 | |
CN112949320B (zh) | 基于条件随机场的序列标注方法、装置、设备及介质 | |
CN112732913B (zh) | 一种非均衡样本的分类方法、装置、设备及存储介质 | |
CN111680513B (zh) | 特征信息的识别方法、装置及计算机可读存储介质 | |
CN114462411B (zh) | 命名实体识别方法、装置、设备及存储介质 | |
CN112949320A (zh) | 基于条件随机场的序列标注方法、装置、设备及介质 | |
CN114860909A (zh) | 一种基于文章的回答推荐方法、装置、设备及介质 | |
CN116796730A (zh) | 基于人工智能的文本纠错方法、装置、设备及存储介质 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |