CN114818902A - 基于知识蒸馏的文本分类方法及系统 - Google Patents

基于知识蒸馏的文本分类方法及系统 Download PDF

Info

Publication number
CN114818902A
CN114818902A CN202210421020.0A CN202210421020A CN114818902A CN 114818902 A CN114818902 A CN 114818902A CN 202210421020 A CN202210421020 A CN 202210421020A CN 114818902 A CN114818902 A CN 114818902A
Authority
CN
China
Prior art keywords
model
language model
training
classification
teacher
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202210421020.0A
Other languages
English (en)
Inventor
张烈帅
赵振修
魏静如
解一豪
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Inspur Cloud Information Technology Co Ltd
Original Assignee
Inspur Cloud Information Technology Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Inspur Cloud Information Technology Co Ltd filed Critical Inspur Cloud Information Technology Co Ltd
Priority to CN202210421020.0A priority Critical patent/CN114818902A/zh
Publication of CN114818902A publication Critical patent/CN114818902A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • G06F18/241Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
    • G06F18/2415Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches based on parametric or probabilistic models, e.g. based on likelihood ratio or false acceptance rate versus a false rejection rate
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F40/00Handling natural language data
    • G06F40/20Natural language analysis
    • G06F40/279Recognition of textual entities
    • G06F40/284Lexical analysis, e.g. tokenisation or collocates
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/042Knowledge-based neural networks; Logical representations of neural networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/088Non-supervised learning, e.g. competitive learning

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • Data Mining & Analysis (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • General Health & Medical Sciences (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Computational Linguistics (AREA)
  • Evolutionary Computation (AREA)
  • Health & Medical Sciences (AREA)
  • Computing Systems (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Molecular Biology (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Evolutionary Biology (AREA)
  • Probability & Statistics with Applications (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Audiology, Speech & Language Pathology (AREA)
  • Machine Translation (AREA)

Abstract

本发明公开了基于知识蒸馏的文本分类方法及系统,属于自然语言处理技术领域,本发明要解决的技术问题为如何利用知识蒸馏,并借助复杂模型的精度优势得到精度相当的轻量级模型,采用的技术方案为:该方法具体如下:获取无监督语料并对无监督语料进行数据预处理;基于大规模无监督语料训练得到教师语言模型;使用针对具体分类任务的有监督训练语料对教师语言模型通过fine‑tuning进行分类任务训练,获得训练好的教师语言模型;根据具体分类任务和训练好的教师语言模型构造学生模型;根据教师语言模型的中间层输出和最终输出,构造损失函数,对学生模型进行训练,获取最终的学生模型;使用最终的学生模型进行文本分类的预测:输入新数据进行分类结构的预测。

Description

基于知识蒸馏的文本分类方法及系统
技术领域
本发明涉及自然语言处理技术领域,具体地说是一种基于知识蒸馏的文本分类方法及系统。
背景技术
在自然语言处理(NLP)领域,文本分类任务有广泛的应用,比如:垃圾过滤,新闻分类,情感分析等等。
自从BERT横空问世,使用预训练语言模型在下游任务通过fine-tuning已经成为越来越成为自然语言处理领域的范式,在自然语言任务中获取优异的效果。但这种效果带来的代价是,常用的预训练语言模型,如BERT、GPT等都是在大量的语料基础上通过复杂的网络结构训练得来,在参数存储和推理速度等方面都对硬件计算资源带来极大的要求。在资源不足的场景,特别是在万物互联的背景下,边缘端的推理服务无法满足性能的要求。
把复杂模型或者多个模型Ensemble(Teacher)学到的知识迁移到另一个轻量级模型(Student)上叫知识蒸馏。其目的是使模型变轻量的同时(方便部署),尽量不损失性能。故如何利用知识蒸馏,并借助复杂模型的精度优势得到精度相当的轻量级模型是目前亟待解决的技术问题。
发明内容
本发明的技术任务是提供一种基于知识蒸馏的文本分类方法及系统,来解决如何利用知识蒸馏,并借助复杂模型的精度优势得到精度相当的轻量级模型的问题。
本发明的技术任务是按以下方式实现的,一种基于知识蒸馏的文本分类方法,该方法具体如下:
获取无监督语料(数据1)并对无监督语料进行数据预处理;
基于大规模无监督语料训练得到教师语言模型(模型T);
使用针对具体分类任务的有监督训练语料对教师语言模型(模型T)通过fine-tuning进行分类任务训练,获得训练好的教师语言模型(模型T);
根据具体分类任务和训练好的教师语言模型(模型T)构造学生模型(模型S);
根据教师语言模型(模型T)的中间层输出和最终输出,构造损失函数,对学生模型(模型S)进行训练,获取最终的学生模型(模型S);
使用最终的学生模型(模型S)进行文本分类的预测:经过前面的训练过程,即获得了最终的模型S。模型S相对模型T,模型结构简化,参数大大减小,可以较大的提升预测效率,减小对硬件资源的依赖,对于边缘设备等可以更加方便的进行部署,输入新数据进行分类结构的预测。
作为优选,教师语言模型(模型T)设定为语言模型,在训练时直接使用无监督语料,即正常文本语言文字;
无监督语料是从任意的文章、著作、互联网博客或新闻进行搜集获取;从泛化角度考虑,收集不同领域及不同来源的语料数据;从性能角度考虑,语料数据大小为1G以上;
对无监督语料进行数据预处理具体为:
根据需要去除通用词;
自定义预处理函数去除字符;
对于BERT具有特定tokenizer方式的教师语言模型(模型T),使用对应的tokenizer函数进行处理。
作为优选,教师语言模型(模型T)采用BERT语言模型,BERT语言模型包括输入层、编码层和输出层;输入层用于词嵌入;编码层包括多层tansformer层,tansformer层用于编码;
BERT语言模型训练具体如下:
构建基于BERT的词嵌入网络向量表征信息,具体如下:
构建基于每个词的词语向量;
构建基于每个语句的段向量;
构建基于每个词的位置向量;
将词语向量、段向量和位置向量叠加,形成BERT的输入;
根据需要选择中间的tansformer层对BERT的输入进行编码;
将编码后的信息通过输出层输出,输出层包括对next sentence的预测和token(包括masked token的预测);
通过迭代,不断进行参数更新和模型评估,获取满足评估条件的教师语言模型(模型T)。
作为优选,针对具体的分类任务,具体任务数据为监督数据,监督数据包括原始文本和分类标签;
分类任务训练是针对具体任务数据对教师语言模型(模型T)进行微调,具体如下:
输入具体任务数据,构建基于BERT的分类模型,获取的模型参数为基础参数进行1个或多个epoch迭代,以获取基准的分类模型,即最终的T模型;
在训练时,为解决分类中可能的类别不平衡问题,使用focal loss函数,通过修改交叉熵函数,通过增加类别权重和样本难度权重调因子,提升模型精确度。
作为优选,学生模型(模型S)是基于教师语言模型(模型T)并选择每隔2层、3层或4层transformer抽取一层transformer的方式构造。
更优地,学生模型(模型S)是基于具体任务数据进行训练,具体如下:
构造损失函数;
在训练过程中,添加梯度扰动:通过梯度扰动,更新参数时,在原本的梯度基础上加入梯度叠加,增加模型的泛化性能,提高模型在新数据上的预测准确率;其中,使用基于L2范数的梯度叠加,公式如下:
Figure BDA0003607650430000031
Figure BDA0003607650430000032
g表示原始梯度;emb′表示经过扰动后的输出;g表示扰动后的梯度值;
其中,训练过程分为两个阶段:
①、将f和s置零,即针对网络中间层进行拟合,使S学生模型能够学习到教师语言模型的transformer结构参数;
②、适当减小m和c的值,并提高f和s的值,使S学生模型和教师语言模型在保持结构参数的情况下,学习对特定任务的预测。
更优地,构造损失函数具体如下:
(1)、针对标签的focal loss,公式如下:
Lf=-(1-pt)γlog(pt);
其中,pt表示分对的概率,γ用于调制难例,增加错误分类的重要性;
(2)、针对教师语言模型预测结果的软化softmax损失,以使模型更好的学习到数据的分布情况,公式如下:
Ls=-∑pilogsi
其中,pi和si分别为学习模型和教师模型的软化概率;
其中,软化概率分布定义如下:
Figure BDA0003607650430000041
其中,z为网络输出;T为调节因子;
(3)、针对对应学生模型与教师语言模型transformer层之间的MSE损失,公式如下:
Lm=∑MSE(trsS,trsT);
其中,trs为transformer的输出;
(4)、针对对应学生模型与教师语言模型transformer层之间的COS损失,公式如下:
Lc=∑COS(trsS,trsT);
其中COS损失定义如下:
Figure BDA0003607650430000051
即最终损失函数为损失函数加权:
L=f*Lf+s*Ls+m*Lm+c*Lc
其中,f、s、m、c分别为加权因子。
一种基于知识蒸馏的文本分类系统,该系统包括,
获取模块一,用于获取无监督语料(数据1)并对无监督语料进行数据预处理;
训练模块一,用于基于大规模无监督语料训练得到教师语言模型(模型T);
训练模块二,用于使用针对具体分类任务的有监督训练语料(数据2)对教师语言模型(模型T)通过fine-tuning进行分类任务训练,获得训练好的教师语言模型(模型T);
构造模块,用于根据具体分类任务和训练好的教师语言模型(模型T)构造学生模型(模型S);
获取模块二,用于根据教师语言模型(模型T)的中间层输出和最终输出,构造损失函数,对学生模型(模型S)进行训练,获取最终的学生模型(模型S);
预测模块,用于输入新数据,使用最终的学生模型(模型S)进行文本分类的预测。
一种电子设备,包括:存储器和至少一个处理器;
其中,所述存储器上存储有计算机程序;
所述至少一个处理器执行所述存储器存储的计算机程序,使得所述至少一个处理器执行如上述的基于知识蒸馏的文本分类方法。
一种计算机可读存储介质,所述计算机可读存储介质中存储有计算机程序,所述计算机程序可被处理器执行以实现如上述的基于知识蒸馏的文本分类方法。
本发明的基于知识蒸馏的文本分类方法及系统具有以下优点:
(一)本发明采用知识蒸馏,优化模型结构,减小模型大小,尽可能保留与模型T相当的准确率;
(二)本发明通过对教师模型和学生模型的构造和训练,在保留模型分类精度的同时,简化模型结构,以减少模型参数量,增加模型推理速度,使模型适应资源不足的场景,如边缘侧设备推理;
(三)本发明通过知识蒸馏,简化模型结构,减少模型参数,便于模型在边缘设备等硬件资源不充足的条件下部署使用;模型训练过程中通过损失函数的改进和训练过程的改进,有利于提升学生模型的准确率;
(四)本发明包括教师模型T和学生模型S,模型T为基础语言模型,基于大规模无监督语料1练而成;对于特定的文本分类任务,针对带标签的训练数据2进行微调;学生模型S为基于模型T结构和带标签数据2训练得到,模型简化,参数减少,适用于边缘端等资源不够充足的场景;
(五)本发明将训练过程分为两个阶段,以更好的拟合T模型结构参数,保证最终的结果准确率;
(六)本发明在训练过程中加入梯度扰动,以增强模型的泛化性能。
附图说明
下面结合附图对本发明进一步说明。
附图1为BERT的整体模型结构示意图;
附图2为构建基于BERT的词嵌入网络向量表征信息的示意图;
附图3为构造模型S的示意图;
附图4为梯度扰动的流程示意图;
附图5为基于知识蒸馏的文本分类方法的流程框图。
具体实施方式
参照说明书附图和具体实施例对本发明的基于知识蒸馏的文本分类方法及系统作以下详细地说明。
实施例1:
如附图5所示,本实施例提供了一种基于知识蒸馏的文本分类方法,该方法具体如下:
S1、获取无监督语料(数据1)并对无监督语料进行数据预处理;
S2、基于大规模无监督语料训练得到教师语言模型(模型T);
S3、使用针对具体分类任务的有监督训练语料对教师语言模型(模型T)通过fine-tuning进行分类任务训练,获得训练好的教师语言模型(模型T);
S4、根据具体分类任务和训练好的教师语言模型(模型T)构造学生模型(模型S);
S5、根据教师语言模型(模型T)的中间层输出和最终输出,构造损失函数,对学生模型(模型S)进行训练,获取最终的学生模型(模型S);
S6、使用最终的学生模型(模型S)进行文本分类的预测:经过前面的训练过程,即获得了最终的模型S。模型S相对模型T,模型结构简化,参数大大减小,可以较大的提升预测效率,减小对硬件资源的依赖,对于边缘设备等可以更加方便的进行部署,输入新数据进行分类结构的预测。
本实施例中的教师语言模型(模型T)设定为语言模型,在训练时直接使用无监督语料,即正常文本语言文字;
本实施例中的无监督语料是从任意的文章、著作、互联网博客或新闻进行搜集获取;从泛化角度考虑,收集不同领域及不同来源的语料数据;从性能角度考虑,语料数据大小为1G以上;
本实施例步骤S1中的对无监督语料进行数据预处理具体为:
S101、根据需要去除通用词;
S102、自定义预处理函数去除字符;
S103、对于BERT具有特定tokenizer方式的教师语言模型(模型T),使用对应的tokenizer函数进行处理。
如附图1所示,本实施例步骤S2中的教师语言模型(模型T)采用BERT语言模型,BERT语言模型包括输入层、编码层和输出层;输入层用于词嵌入;编码层包括多层tansformer层,tansformer层用于编码;
BERT语言模型训练具体如下:
S201、构建基于BERT的词嵌入网络向量表征信息,如附图2所示,具体如下:
S20101、构建基于每个词的词语向量;
S20102、构建基于每个语句的段向量;
S20103、构建基于每个词的位置向量;
S20104、将词语向量、段向量和位置向量叠加,形成BERT的输入;
S202、根据需要选择中间的tansformer层对BERT的输入进行编码;
S203、将编码后的信息通过输出层输出,输出层包括对next sentence的预测和token(包括masked token的预测);
S204、通过迭代,不断进行参数更新和模型评估,获取满足评估条件的教师语言模型(模型T)。
本实施例中的针对具体的分类任务,具体任务数据为监督数据,监督数据包括原始文本和分类标签;
分类任务训练是针对具体任务数据对教师语言模型(模型T)进行微调,具体如下:
(1)、输入具体任务数据,构建基于BERT的分类模型,获取的模型参数为基础参数进行1个或多个epoch迭代,以获取基准的分类模型,即最终的T模型;
(2)、在训练时,为解决分类中可能的类别不平衡问题,使用focal loss函数,通过修改交叉熵函数,通过增加类别权重和样本难度权重调因子,提升模型精确度。
本实施例中的学生模型(模型S)是基于教师语言模型(模型T)并选择每隔2层、3层或4层transformer抽取一层transformer的方式构造。以12层BERT模型为例,构造S模型时,可以选择每隔2层、3层或4层transformer抽取一层transformer的方案,进行S模型的构造。如附图3所示,为保证预测一致性,transformer层的词向量维数应保持一致。
本实施例步骤S5中的学生模型(模型S)是基于具体任务数据进行训练,具体如下:
S501、构造损失函数;
S502、如附图4所示,在训练过程中,添加梯度扰动:通过梯度扰动,更新参数时,在原本的梯度基础上加入梯度叠加,增加模型的泛化性能,提高模型在新数据上的预测准确率;其中,使用基于L2范数的梯度叠加,公式如下:
Figure BDA0003607650430000091
Figure BDA0003607650430000092
g表示原始梯度;emb′表示经过扰动后的输出;g表示扰动后的梯度值;
其中,训练过程分为两个阶段:
①、将f和s置零,即针对网络中间层进行拟合,使S学生模型能够学习到教师语言模型的transformer结构参数;
②、适当减小m和c的值,并提高f和s的值,使S学生模型和教师语言模型在保持结构参数的情况下,学习对特定任务的预测。
本实施例步骤S501中的构造损失函数具体如下:
(1)、针对标签的focal loss,公式如下:
Lf=-(1-pt)γlog(pt);
其中,pt表示分对的概率,γ用于调制难例,增加错误分类的重要性;
(2)、针对教师语言模型预测结果的软化softmax损失,以使模型更好的学习到数据的分布情况,公式如下:
Ls=-∑pilogsi
其中,pi和si分别为学习模型和教师模型的软化概率;
其中,软化概率分布定义如下:
Figure BDA0003607650430000101
其中,z为网络输出;T为调节因子;
(3)、针对对应学生模型与教师语言模型transformer层之间的MSE损失,公式如下:
Lm=∑MSE(trsS,trsT);
其中,trs为transformer的输出;
(4)、针对对应学生模型与教师语言模型transformer层之间的COS损失,公式如下:
Lc=∑COS(trsS,trsT);
其中COS损失定义如下:
Figure BDA0003607650430000102
即最终损失函数为损失函数加权:
L=f*Lf+s*Ls+m*Lm+c*Lc
其中,f、s、m、c分别为加权因子。
实施例2:
本实施例提供了一种基于知识蒸馏的文本分类系统,该系统包括,
获取模块一,用于获取无监督语料(数据1)并对无监督语料进行数据预处理;
训练模块一,用于基于大规模无监督语料训练得到教师语言模型(模型T);
训练模块二,用于使用针对具体分类任务的有监督训练语料(数据2)对教师语言模型(模型T)通过fine-tuning进行分类任务训练,获得训练好的教师语言模型(模型T);
构造模块,用于根据具体分类任务和训练好的教师语言模型(模型T)构造学生模型(模型S);
获取模块二,用于根据教师语言模型(模型T)的中间层输出和最终输出,构造损失函数,对学生模型(模型S)进行训练,获取最终的学生模型(模型S);
预测模块,用于输入新数据,使用最终的学生模型(模型S)进行文本分类的预测。
实施例3:
本实施例还提供了一种电子设备,包括:存储器和处理器;
其中,存储器存储计算机执行指令;
处理器执行所述存储器存储的计算机执行指令,使得处理器执行本发明任一实施例中的基于知识蒸馏的文本分类方法。
处理器可以是中央处理单元(,CPU),还可以是其他通用处理器、数字信号处理器(DSP)、专用集成电路(ASIC)、现成可编程门阵列(FPGA)或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件等。通过处理器可以是微处理器或者该处理器也可以是任何常规的处理器等。
存储器可用于储存计算机程序和/或模块,处理器通过运行或执行存储在存储器内的计算机程序和/或模块,以及调用存储在存储器内的数据,实现电子设备的各种功能。存储器可主要包括存储程序区和存储数据区,其中,存储程序区可存储操作系统、至少一个功能所需的应用程序等;存储数据区可存储根据终端的使用所创建的数据等。此外,存储器还可以包括高速随机存取存储器,还可以包括非易失性存储器,例如硬盘、内存、插接式硬盘,只能存储卡(SMC),安全数字(SD)卡,闪存卡、至少一个磁盘存储期间、闪存器件、或其他易失性固态存储器件。
实施例4:
本实施例还提供了一种计算机可读存储介质,其中存储有多条指令,指令由处理器加载,使处理器执行本发明任一实施例中的基于知识蒸馏的文本分类方法。具体地,可以提供配有存储介质的系统或者装置,在该存储介质上存储着实现上述实施例中任一实施例的功能的软件程序代码,且使该系统或者装置的计算机(或CPU或MPU)读出并执行存储在存储介质中的程序代码。
在这种情况下,从存储介质读取的程序代码本身可实现上述实施例中任何一项实施例的功能,因此程序代码和存储程序代码的存储介质构成了本发明的一部分。
用于提供程序代码的存储介质实施例包括软盘、硬盘、磁光盘、光盘(如CD-ROM、CD-R、CD-RW、DVD-ROM、DVD-RYM、DVD-RW、DVD+RW)、磁带、非易失性存储卡和ROM。可选择地,可以由通信网络从服务器计算机上下载程序代码。
此外,应该清楚的是,不仅可以通过执行计算机所读出的程序代码,而且可以通过基于程序代码的指令使计算机上操作的操作系统等来完成部分或者全部的实际操作,从而实现上述实施例中任意一项实施例的功能。
此外,可以理解的是,将由存储介质读出的程序代码写到插入计算机内的扩展板中所设置的存储器中或者写到与计算机相连接的扩展单元中设置的存储器中,随后基于程序代码的指令使安装在扩展板或者扩展单元上的CPU等来执行部分和全部实际操作,从而实现上述实施例中任一实施例的功能。
最后应说明的是:以上各实施例仅用以说明本发明的技术方案,而非对其限制;尽管参照前述各实施例对本发明进行了详细的说明,本领域的普通技术人员应当理解:其依然可以对前述各实施例所记载的技术方案进行修改,或者对其中部分或者全部技术特征进行等同替换;而这些修改或者替换,并不使相应技术方案的本质脱离本发明各实施例技术方案的范围。

Claims (10)

1.一种基于知识蒸馏的文本分类方法,其特征在于,该方法具体如下:
获取无监督语料并对无监督语料进行数据预处理;
基于大规模无监督语料训练得到教师语言模型;
使用针对具体分类任务的有监督训练语料对教师语言模型通过fine-tuning进行分类任务训练,获得训练好的教师语言模型;
根据具体分类任务和训练好的教师语言模型构造学生模型;
根据教师语言模型的中间层输出和最终输出,构造损失函数,对学生模型进行训练,获取最终的学生模型;
使用最终的学生模型进行文本分类的预测:输入新数据进行分类结构的预测。
2.根据权利要求1所述的基于知识蒸馏的文本分类方法,其特征在于,教师语言模型设定为语言模型,在训练时直接使用无监督语料,即正常文本语言文字;
无监督语料是从任意的文章、著作、互联网博客或新闻进行搜集获取;从泛化角度考虑,收集不同领域及不同来源的语料数据;从性能角度考虑,语料数据大小为1G以上;
对无监督语料进行数据预处理具体为:
根据需要去除通用词;
自定义预处理函数去除字符;
对于BERT具有特定tokenizer方式的教师语言模型,使用对应的tokenizer函数进行处理。
3.根据权利要求1所述的基于知识蒸馏的文本分类方法,其特征在于,教师语言模型采用BERT语言模型,BERT语言模型包括输入层、编码层和输出层;输入层用于词嵌入;编码层包括多层tansformer层,tansformer层用于编码;
BERT语言模型训练具体如下:
构建基于BERT的词嵌入网络向量表征信息,具体如下:
构建基于每个词的词语向量;
构建基于每个语句的段向量;
构建基于每个词的位置向量;
将词语向量、段向量和位置向量叠加,形成BERT的输入;
根据需要选择中间的tansformer层对BERT的输入进行编码;
将编码后的信息通过输出层输出,输出层包括对next sentence的预测和token;
通过迭代,不断进行参数更新和模型评估,获取满足评估条件的教师语言模型。
4.根据权利要求1所述的基于知识蒸馏的文本分类方法,其特征在于,针对具体的分类任务,具体任务数据为监督数据,监督数据包括原始文本和分类标签;
分类任务训练是针对具体任务数据对教师语言模型进行微调,具体如下:
输入具体任务数据,构建基于BERT的分类模型,获取的模型参数为基础参数进行1个或多个epoch迭代,以获取基准的分类模型,即最终的T模型;
在训练时,使用focal loss函数,通过修改交叉熵函数,通过增加类别权重和样本难度权重调因子,提升模型精确度。
5.根据权利要求1所述的基于知识蒸馏的文本分类方法,其特征在于,学生模型是基于教师语言模型并选择每隔2层、3层或4层transformer抽取一层transformer的方式构造。
6.根据权利要求1-5中任一所述的基于知识蒸馏的文本分类方法,其特征在于,学生模型是基于具体任务数据进行训练,具体如下:
构造损失函数;
在训练过程中,添加梯度扰动:通过梯度扰动,更新参数时,在原本的梯度基础上加入梯度叠加,增加模型的泛化性能,提高模型在新数据上的预测准确率;其中,使用基于L2范数的梯度叠加,公式如下:
Figure FDA0003607650420000031
Figure FDA0003607650420000032
g表示原始梯度;emb′表示经过扰动后的输出;g表示扰动后的梯度值;
其中,训练过程分为两个阶段:
①、将f和s置零,即针对网络中间层进行拟合,使S学生模型能够学习到教师语言模型的transformer结构参数;
②、适当减小m和c的值,并提高f和s的值,使S学生模型和教师语言模型在保持结构参数的情况下,学习对特定任务的预测。
7.根据权利要求6所述的基于知识蒸馏的文本分类方法,其特征在于,构造损失函数具体如下:
(1)、针对标签的focal loss,公式如下:
Lf=-(1-pt)γlog(pt);
其中,pt表示分对的概率,γ用于调制难例,增加错误分类的重要性;
(2)、针对教师语言模型预测结果的软化softmax损失,以使模型更好的学习到数据的分布情况,公式如下:
Ls=-∑pilogsi
其中,pi和si分别为学习模型和教师模型的软化概率;
其中,软化概率分布定义如下:
Figure FDA0003607650420000041
其中,z为网络输出;T为调节因子;
(3)、针对对应学生模型与教师语言模型transformer层之间的MSE损失,公式如下:
Lm=∑MSE(trsS,trsT);
其中,trs为transformer的输出;
(4)、针对对应学生模型与教师语言模型transformer层之间的COS损失,公式如下:
Lc=∑COS(trsS,trsT);
其中COS损失定义如下:
Figure FDA0003607650420000042
即最终损失函数为损失函数加权:
L=f*Lf+s*Ls+m*Lm+c*Lc
其中,f、s、m、c分别为加权因子。
8.一种基于知识蒸馏的文本分类系统,其特征在于,该系统包括,
获取模块一,用于获取无监督语料并对无监督语料进行数据预处理;
训练模块一,用于基于大规模无监督语料训练得到教师语言模型;
训练模块二,用于使用针对具体分类任务的有监督训练语料对教师语言模型通过fine-tuning进行分类任务训练,获得训练好的教师语言模型;
构造模块,用于根据具体分类任务和训练好的教师语言模型构造学生模型;
获取模块二,用于根据教师语言模型的中间层输出和最终输出,构造损失函数,对学生模型进行训练,获取最终的学生模型;
预测模块,用于输入新数据,使用最终的学生模型进行文本分类的预测。
9.一种电子设备,其特征在于,包括:存储器和至少一个处理器;
其中,所述存储器上存储有计算机程序;
所述至少一个处理器执行所述存储器存储的计算机程序,使得所述至少一个处理器执行如权利要求1至7任一项所述的基于知识蒸馏的文本分类方法。
10.一种计算机可读存储介质,其特征在于,所述计算机可读存储介质中存储有计算机程序,所述计算机程序可被处理器执行以实现如权利要求1至7中任一项所述的基于知识蒸馏的文本分类方法。
CN202210421020.0A 2022-04-21 2022-04-21 基于知识蒸馏的文本分类方法及系统 Pending CN114818902A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202210421020.0A CN114818902A (zh) 2022-04-21 2022-04-21 基于知识蒸馏的文本分类方法及系统

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202210421020.0A CN114818902A (zh) 2022-04-21 2022-04-21 基于知识蒸馏的文本分类方法及系统

Publications (1)

Publication Number Publication Date
CN114818902A true CN114818902A (zh) 2022-07-29

Family

ID=82505399

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202210421020.0A Pending CN114818902A (zh) 2022-04-21 2022-04-21 基于知识蒸馏的文本分类方法及系统

Country Status (1)

Country Link
CN (1) CN114818902A (zh)

Cited By (13)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115879446A (zh) * 2022-12-30 2023-03-31 北京百度网讯科技有限公司 文本处理方法、深度学习模型训练方法、装置以及设备
CN116187322A (zh) * 2023-03-15 2023-05-30 深圳市迪博企业风险管理技术有限公司 一种基于动量蒸馏的内控合规检测方法及系统
CN116340779A (zh) * 2023-05-30 2023-06-27 北京智源人工智能研究院 一种下一代通用基础模型的训练方法、装置和电子设备
CN116362351A (zh) * 2023-05-29 2023-06-30 深圳须弥云图空间科技有限公司 利用噪声扰动训练预训练语言模型的方法及装置
CN116595130A (zh) * 2023-07-18 2023-08-15 深圳须弥云图空间科技有限公司 基于小语言模型的多种任务下的语料扩充方法及装置
CN116629346A (zh) * 2023-07-24 2023-08-22 成都云栈科技有限公司 一种用于实验室知识传承的模型训练方法及装置
CN116663678A (zh) * 2023-06-20 2023-08-29 北京智谱华章科技有限公司 面向超大规模模型的蒸馏优化方法、装置、介质及设备
CN116861302A (zh) * 2023-09-05 2023-10-10 吉奥时空信息技术股份有限公司 一种案件自动分类分拨方法
CN117236409A (zh) * 2023-11-16 2023-12-15 中电科大数据研究院有限公司 基于大模型的小模型训练方法、装置、系统和存储介质
CN117725960A (zh) * 2024-02-18 2024-03-19 智慧眼科技股份有限公司 基于知识蒸馏的语言模型训练方法、文本分类方法及设备
CN117807235A (zh) * 2024-01-17 2024-04-02 长春大学 一种基于模型内部特征蒸馏的文本分类方法
CN117933364A (zh) * 2024-03-20 2024-04-26 烟台海颐软件股份有限公司 基于跨语言知识迁移和经验驱动的电力行业模型训练方法
CN117992598A (zh) * 2024-04-07 2024-05-07 同盾科技有限公司 基于大模型的需求响应方法、装置、介质及设备

Cited By (22)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115879446A (zh) * 2022-12-30 2023-03-31 北京百度网讯科技有限公司 文本处理方法、深度学习模型训练方法、装置以及设备
CN115879446B (zh) * 2022-12-30 2024-01-12 北京百度网讯科技有限公司 文本处理方法、深度学习模型训练方法、装置以及设备
CN116187322A (zh) * 2023-03-15 2023-05-30 深圳市迪博企业风险管理技术有限公司 一种基于动量蒸馏的内控合规检测方法及系统
CN116187322B (zh) * 2023-03-15 2023-07-25 深圳市迪博企业风险管理技术有限公司 一种基于动量蒸馏的内控合规检测方法及系统
CN116362351A (zh) * 2023-05-29 2023-06-30 深圳须弥云图空间科技有限公司 利用噪声扰动训练预训练语言模型的方法及装置
CN116362351B (zh) * 2023-05-29 2023-09-26 深圳须弥云图空间科技有限公司 利用噪声扰动训练预训练语言模型的方法及装置
CN116340779A (zh) * 2023-05-30 2023-06-27 北京智源人工智能研究院 一种下一代通用基础模型的训练方法、装置和电子设备
CN116663678A (zh) * 2023-06-20 2023-08-29 北京智谱华章科技有限公司 面向超大规模模型的蒸馏优化方法、装置、介质及设备
CN116595130A (zh) * 2023-07-18 2023-08-15 深圳须弥云图空间科技有限公司 基于小语言模型的多种任务下的语料扩充方法及装置
CN116595130B (zh) * 2023-07-18 2024-02-20 深圳须弥云图空间科技有限公司 基于小语言模型的多种任务下的语料扩充方法及装置
CN116629346B (zh) * 2023-07-24 2023-10-20 成都云栈科技有限公司 一种语言模型训练方法及装置
CN116629346A (zh) * 2023-07-24 2023-08-22 成都云栈科技有限公司 一种用于实验室知识传承的模型训练方法及装置
CN116861302A (zh) * 2023-09-05 2023-10-10 吉奥时空信息技术股份有限公司 一种案件自动分类分拨方法
CN116861302B (zh) * 2023-09-05 2024-01-23 吉奥时空信息技术股份有限公司 一种案件自动分类分拨方法
CN117236409A (zh) * 2023-11-16 2023-12-15 中电科大数据研究院有限公司 基于大模型的小模型训练方法、装置、系统和存储介质
CN117236409B (zh) * 2023-11-16 2024-02-27 中电科大数据研究院有限公司 基于大模型的小模型训练方法、装置、系统和存储介质
CN117807235A (zh) * 2024-01-17 2024-04-02 长春大学 一种基于模型内部特征蒸馏的文本分类方法
CN117807235B (zh) * 2024-01-17 2024-05-10 长春大学 一种基于模型内部特征蒸馏的文本分类方法
CN117725960A (zh) * 2024-02-18 2024-03-19 智慧眼科技股份有限公司 基于知识蒸馏的语言模型训练方法、文本分类方法及设备
CN117933364A (zh) * 2024-03-20 2024-04-26 烟台海颐软件股份有限公司 基于跨语言知识迁移和经验驱动的电力行业模型训练方法
CN117933364B (zh) * 2024-03-20 2024-06-04 烟台海颐软件股份有限公司 基于跨语言知识迁移和经验驱动的电力行业模型训练方法
CN117992598A (zh) * 2024-04-07 2024-05-07 同盾科技有限公司 基于大模型的需求响应方法、装置、介质及设备

Similar Documents

Publication Publication Date Title
CN114818902A (zh) 基于知识蒸馏的文本分类方法及系统
US11455527B2 (en) Classification of sparsely labeled text documents while preserving semantics
KR102008845B1 (ko) 비정형 데이터의 카테고리 자동분류 방법
CN110580292A (zh) 一种文本标签生成方法、装置和计算机可读存储介质
US11481552B2 (en) Generative-discriminative language modeling for controllable text generation
US20220147814A1 (en) Task specific processing of regulatory content
CN112100401B (zh) 面向科技服务的知识图谱构建方法、装置、设备及存储介质
Bokka et al. Deep Learning for Natural Language Processing: Solve your natural language processing problems with smart deep neural networks
CN113378573A (zh) 面向内容大数据的小样本关系抽取方法和装置
CN114218932A (zh) 基于故障因果图谱的航空故障文本摘要生成方法及其装置
CN111709225B (zh) 一种事件因果关系判别方法、装置和计算机可读存储介质
US20240020486A1 (en) Systems and methods for finetuning with learned hidden representations of parameter changes
Susanto et al. Semantic parsing with neural hybrid trees
CN115270797A (zh) 一种基于自训练半监督学习的文本实体抽取方法及系统
CN115496072A (zh) 一种基于对比学习的关系抽取方法
Michel et al. Identification of Decision Rules from Legislative Documents Using Machine Learning and Natural Language Processing.
CN110674642A (zh) 一种用于含噪稀疏文本的语义关系抽取方法
Zhang et al. Learned adapters are better than manually designed adapters
Luo et al. Semi-supervised teacher-student architecture for relation extraction
CN111813941A (zh) 结合rpa和ai的文本分类方法、装置、设备及介质
US20230168989A1 (en) BUSINESS LANGUAGE PROCESSING USING LoQoS AND rb-LSTM
US20230162490A1 (en) Systems and methods for vision-language distribution alignment
CN115713082A (zh) 一种命名实体识别方法、装置、设备及存储介质
CN112256841B (zh) 文本匹配和对抗文本识别方法、装置及设备
Sharma et al. Weighted Ensemble LSTM Model with Word Embedding Attention for E-Commerce Product Recommendation

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination