CN116595167A - 一种基于集成知识蒸馏网络的意图识别方法 - Google Patents
一种基于集成知识蒸馏网络的意图识别方法 Download PDFInfo
- Publication number
- CN116595167A CN116595167A CN202310318318.3A CN202310318318A CN116595167A CN 116595167 A CN116595167 A CN 116595167A CN 202310318318 A CN202310318318 A CN 202310318318A CN 116595167 A CN116595167 A CN 116595167A
- Authority
- CN
- China
- Prior art keywords
- model
- data
- models
- training
- integrated
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000000034 method Methods 0.000 title claims abstract description 27
- 238000013140 knowledge distillation Methods 0.000 title claims abstract description 20
- 238000012549 training Methods 0.000 claims abstract description 55
- 230000002708 enhancing effect Effects 0.000 claims abstract description 7
- 230000006870 function Effects 0.000 claims description 24
- 238000013519 translation Methods 0.000 claims description 16
- 238000011176 pooling Methods 0.000 claims description 10
- 230000010354 integration Effects 0.000 claims description 9
- 238000013459 approach Methods 0.000 claims description 7
- 230000008569 process Effects 0.000 claims description 6
- 230000002238 attenuated effect Effects 0.000 claims description 5
- 238000002922 simulated annealing Methods 0.000 claims description 5
- 238000005096 rolling process Methods 0.000 claims description 3
- 230000000694 effects Effects 0.000 description 5
- 230000002457 bidirectional effect Effects 0.000 description 4
- 102100033814 Alanine aminotransferase 2 Human genes 0.000 description 2
- 101000779415 Homo sapiens Alanine aminotransferase 2 Proteins 0.000 description 2
- 238000004364 calculation method Methods 0.000 description 2
- 238000006243 chemical reaction Methods 0.000 description 2
- 238000003780 insertion Methods 0.000 description 2
- 230000037431 insertion Effects 0.000 description 2
- 238000004519 manufacturing process Methods 0.000 description 2
- 230000011218 segmentation Effects 0.000 description 2
- 238000009825 accumulation Methods 0.000 description 1
- 230000004075 alteration Effects 0.000 description 1
- 238000013473 artificial intelligence Methods 0.000 description 1
- 230000008859 change Effects 0.000 description 1
- 238000013135 deep learning Methods 0.000 description 1
- 230000007547 defect Effects 0.000 description 1
- 238000011161 development Methods 0.000 description 1
- 230000018109 developmental process Effects 0.000 description 1
- 230000006872 improvement Effects 0.000 description 1
- 238000010801 machine learning Methods 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 238000006467 substitution reaction Methods 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F16/00—Information retrieval; Database structures therefor; File system structures therefor
- G06F16/30—Information retrieval; Database structures therefor; File system structures therefor of unstructured textual data
- G06F16/35—Clustering; Classification
- G06F16/353—Clustering; Classification into predefined classes
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F16/00—Information retrieval; Database structures therefor; File system structures therefor
- G06F16/30—Information retrieval; Database structures therefor; File system structures therefor of unstructured textual data
- G06F16/33—Querying
- G06F16/332—Query formulation
- G06F16/3329—Natural language query formulation or dialogue systems
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F16/00—Information retrieval; Database structures therefor; File system structures therefor
- G06F16/30—Information retrieval; Database structures therefor; File system structures therefor of unstructured textual data
- G06F16/33—Querying
- G06F16/3331—Query processing
- G06F16/3332—Query translation
- G06F16/3337—Translation of the query language, e.g. Chinese to English
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F40/00—Handling natural language data
- G06F40/20—Natural language analysis
- G06F40/279—Recognition of textual entities
- G06F40/289—Phrasal analysis, e.g. finite state techniques or chunking
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F40/00—Handling natural language data
- G06F40/30—Semantic analysis
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/042—Knowledge-based neural networks; Logical representations of neural networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/0464—Convolutional networks [CNN, ConvNet]
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/096—Transfer learning
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02D—CLIMATE CHANGE MITIGATION TECHNOLOGIES IN INFORMATION AND COMMUNICATION TECHNOLOGIES [ICT], I.E. INFORMATION AND COMMUNICATION TECHNOLOGIES AIMING AT THE REDUCTION OF THEIR OWN ENERGY USE
- Y02D10/00—Energy efficient computing, e.g. low power processors, power management or thermal management
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- General Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- Computational Linguistics (AREA)
- Artificial Intelligence (AREA)
- Data Mining & Analysis (AREA)
- Mathematical Physics (AREA)
- General Health & Medical Sciences (AREA)
- Health & Medical Sciences (AREA)
- Evolutionary Computation (AREA)
- Life Sciences & Earth Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Molecular Biology (AREA)
- Computing Systems (AREA)
- Software Systems (AREA)
- Databases & Information Systems (AREA)
- Audiology, Speech & Language Pathology (AREA)
- Human Computer Interaction (AREA)
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
Abstract
本发明涉及计算机技术领域,尤其是一种基于集成知识蒸馏网络的意图识别方法,包括如下步骤:步骤1、获得数据,并对数据进行增强,形成增强后的数据;步骤2、结合增强后的数据,训练RoBERTa‑wwm‑ext‑zh‑Large‑RCNN模型,XLNET‑zh‑mid‑CNN模型和Multilingual T5‑RCNN模型;步骤3、集成上述是三个模型;步骤4、将上述三个模型集成的知识蒸馏到的TextCNN模型,进行意图识别的语义信息捕捉。本发明克服了意图识别系统的语义信息捕捉能力不足的问题,提升意图识别系统的性能。
Description
技术领域
本发明涉及的计算机技术领域,尤其是一种基于集成知识蒸馏网络的意图识别方法。
背景技术
随着数据快速积累、计算机运算能力的提升,人工智能发展环境也发生巨大的变化。智能人机对话系统中,能够准确地识别出用户的意图对于理解用户提出的问题以及欲求得到的帮助具有重要的意义。目前,用户意图识别主要使用规则匹配、基于特征的机器学习和基于深度学习模型,但目前的意图识别往往存在没有进行语义层面深层判断的问题,捕捉语义信息的能力不足,输入的轻微扰动很容易导致模型对意图的判断出错。本发明要解决如何提升意图识别的语义信息捕捉能力。
目前,常用的语言训练模型有RoBERTa-wwm-ext-zh-Large-RCNN,XLNET-zh-mid-CNN和Multilingual T5-RCNN等大模型,但是在生产环境中,特别是实时性要求高的环境中,大模型虽然精度高,但是用来预测的耗时非常严重,难以用于实时性有要求的场景,大模型集成进行预测的耗时更是无法接受。
发明内容
本发明针对上述问题,提出了一种基基于集成知识蒸馏网络的意图识别方法,克服了意图识别系统的语义信息捕捉能力不足的问题,提升意图识别系统的性能。
本发明提供如下技术方案:一种基于集成知识蒸馏网络的意图识别方法,包括如下步骤:
步骤1、获得数据,并对数据进行增强,形成增强后的数据;
步骤2、结合增强后的数据,训练多个语言模型,提高精度;
步骤3、集成上述是语言模型,提高模型预测的精度;
步骤4、将集成上述语言模型的知识蒸馏到的TextCNN模型,进行意图识别的语义信息捕捉。
在步骤1中,采用三种数据增强的方式对N条数据进行增强,第一种方式通过回译的方式进行数据增强,调用有道翻译api,对N条数据首先通过有道翻译翻译从中文翻译为英文,再将翻译得到的英文数据通过有道翻译翻译为中文,得到N条增强数据;第二种方式使用EDA的方式进行数据增强,得到4N条增强数据;第三种方式使用MLM(Mask LanguageModel)进行数据增强,获得N条增强数据;上述三种方式共计得到6N条增强数据。
步骤2中的语言模型为三个,分别为RoBERTa-wwm-ext-zh-Large-RCNN模型,XLNET-zh-mid-CNN模型和Multilingual T5-RCNN模型,步骤1中获得的N条数据和增强后得到的6N条增强数据,共7N条数据训练RoBERTa-wwm-ext-zh-Large-RCNN模型、XLNET-zh-mid-CNN模型和Multilingual T5-RCNN模型。
使用EDA的方式进行数据增强时,步骤101、同义词替换(SR:Synonyms Replace),将分词后的句子,在不考虑stopwords的情况下,在句子里随机抽取n个词,1<=n<=3,然后从同义词词典中随机抽取同义词,进行替换;步骤102、随机插入(RI:Randomly Insert),在分词后不考虑stopwords的情况下,随机抽取一个词,然后在该词的同义词集合里随机选择一个,插入原句字中的随机位置,该过程可以重复m次,1<=m<=3,步骤103、随机交换(RS:Randomly Swap),句子中随机选择两个词位置交换;该过程进行1次。步骤104、进行随机删除(RD:Randomly Delete),句子中的每个词,以概率p随机删除,p=0.1,若没有词被删除则再重复一次,直至出现删除词,得到4N条增强数据。
使用MLM进行数据增强时,首先将句子进行分词,随机选择某个词用MASK标记替换,将替换后的句子输入RoBERTa-zh-Large模型,输出的句子若和原句子相同,则重新选择MASK再次输出,直到得到不同的句子,即为增强后的句子,得到N条增强数据。
RoBERTa-wwm-ext-zh-Large-RCNN模型为RoBERTa-Large模型连接RCNN模型,训练时,首先数据通过RoBERTa-wwm-ext模型得到每个字符的编码,即通过RoBERTa-wwm-ext将数据每个字符进行Embedding,得到的Embedding向量再通过一个双向的LSTM模型,得到的编码再通过卷积核大小为2,3,4的CNN模型进行卷积池化操作,最后拼接在一起通过一个全连接层输出分类的结果;在训练时候,微调RoBERTa-wwm-ext模型采用的学习率lr=5e-6,在训练RCNN层的时候使用的学习率lr=1e-3,因为RCNN需要重头开始学,RoBERTa-wwm-ext参数只需要微调,对学习率采用指数衰减的模拟退火算法进行衰减,当模型接近收敛的时候,停止RoBERTa-wwm-ext参数的微调更新,只更新RCNN参数,训练到收敛为止。
XLNET-zh-mid-CNN模型训练,先将数据通过XLNET-zh-mid模型得到每个字符的编码,也就是把模型当做编码器,得到每个字符的Embedding表示,然后通过一层self-attention层计算自注意力,通过卷积核大小为2,3,4的CNN模型进行卷积和池化操作,最后拼接在一起通过一个全连接层得到输出分类的结果;在训练网络时,XLNET-zh-mid使用的学习率lr=1e-5,使用指数衰减的模拟退火算法进行衰减,模型在接近收敛时,停止更新XLNET-zh-mid的参数,只更新CNN的参数,直到模型收敛。
Multilingual T5-RCNN模型(简称mT5)训练采用bert4keras提供的默认模型加载及训练方式。在数据经过mT5模型输出后,经过一层self-attenion计算自注意力,然后再通过一个双向的LSTM层,得到的编码在通过卷积核大小为2,3,4的CNN模型进行卷积池化操作,最后拼接在一起通过一个全连接层输出分类的结果。学习率mT5使用lr=2e10-6,RCNN层使用lr=1e10-3。训练的方式是采用Seq2Seq的方式进行训练的,在任务场景中,通过转化为Seq2Seq的训练方式,比如输入:识别该内容是否是购买意图:我很喜欢这条裙子。输出:是(否)。通过转化为Seq2Seq方式训练有监督的意图识别任务,这种转化的思想和GPT2GPT3的思想是一致的,都是希望用文字把任务表达出来,然后转化为文字预测。按照上述方式训练直至模型收敛。
训练了上述3个模型,通过细致的调参将每个模型调至最好的效果。为了进一步挖掘3个模型的潜力,得到精度最高的模型,采用模型集成的方式,集合3个原本已经高精度的大型模型之力,进一步提高模型预测的精度,步骤3中,集成时将三个模型RoBERTa-wwm-ext-zh-Large-RCNN模型、XLNET-zh-mid-CNN模型和Multilingual T5-RCNN模型分别设为A,B,C,根据三个模型的预测结果分为两种情况,情况一:模型A,B,C的预测结果一致(比如都判断为某个意图),这时将模型A,B,C输出的logits的结果[a1,a2]、[b1,b2]、[c1,c2],取均值得到[(a1+b1+c1)/3,(a2+b2+c2)/3];情况二:模型A,B,C的预测结果不一致(比如A,B判断为一个意图,C判断为另一个意图),这时取模型A,B输出的logits的结果[a1,a2]、[b1,b2],取均值得到[(a1+b1)/2,(a2+b2)/2];按照上述两种情况,将训练数据依次使用三个模型进行预测并集成,得到每条训练数据的集成模型logits分布并保存下来,并且分别记录三个大模型各自的logits并保存下来,上述预测和集成都是在离线的环境下进行计算得到的。
从所述三个模型知识蒸馏到TextCNN的步骤包括:对于每条训练数据,在计算loss的时候,首先将该条数据离线计算出的集成模型的logits带入带有温度参数T的softmax函数,softmax函数如下,
其中z表示集成模型的logits,T表示温度,qi是为第i个节点的输出值,通过softmax函数获得集成模型带温度参数T的softmax值记为S_T;
分别将RoBERTa-wwm-ext-zh-Large-RCNN模型,XLNET-zh-mid-CNN模型和Multilingual T5-RCNN模型各自的logits值也带入softmax函数,分别计算出每个模型的带温度参数T的softmax值记为SA_T,SB_T,SC_T;
然后计算训练数据通过TextCNN模型之后的logits通过softmax函数得到的softmax值记为S,每条训练数据的真实标签的onthot表示记为y;
训练模型的损失函数用如下表示:
L=aH(y,S)+b1MSE(S,S_T)+b2MSE(S,SA_T)+b3MSE(S,SB_T)+b4MSE(S,SC_T),
其中H表示交叉熵损失函数,MSE(Mean Squared Error)表示均方误差损失函数,a、b1、b2、b3、b4为常数系数。
附图说明
图1为本发明具体实施方式的流程图。
具体实施方式
下面将结合本发明实施例中的附图,对本发明具体实施方式中的技术方案进行清楚、完整地描述,显然,所描述的具体实施方式仅仅是本发明一种具体实施方式,而不是全部的具体实施方式。基于本发明中的具体实施方式,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他具体实施方式,都属于本发明保护的范围。
如图1所示,本方案的基于集成知识蒸馏网络的意图识别方法,包括如下步骤:步骤1、获得数据,并对数据进行增强,形成增强后的数据;步骤2、结合增强后的数据,训练RoBERTa-wwm-ext-zh-Large-RCNN模型,XLNET-zh-mid-CNN模型和Multilingual T5-RCNN模型,提高精度;步骤3、集成上述是三个模型,提高模型预测的精度;步骤4、将上述三个模型集成的知识蒸馏到的TextCNN模型,进行意图识别的语义信息捕捉。
在步骤1中,由于业务场景常常会遇到数据量不足的情况,使用三种数据增强的方式对数据进行增强,对于约6000千条数据,以及特定业务场景的两千条数据,共8000多条数据进行数据增强,这里记为N条数据,采用三种数据增强的方式对N条数据进行增强,第一种方式通过回译的方式进行数据增强,调用有道翻译api,对N条数据首先通过有道翻译翻译从中文翻译为英文,再将翻译得到的英文数据通过有道翻译翻译为中文,得到N条增强数据;第二种方式使用EDA的方式进行数据增强,得到4N条增强数据;第三种方式使用MLM(Mask Language Model)进行数据增强,获得N条增强数据;上述三种方式共计得到6N条增强数据。
使用EDA的方式进行数据增强时,步骤101、同义词替换(SR:Synonyms Replace),将分词后的句子,在不考虑stopwords的情况下,在句子里随机抽取n个词,1<=n<=3,然后从同义词词典中随机抽取同义词,进行替换;步骤102、随机插入(RI:Randomly Insert),在分词后不考虑stopwords的情况下,随机抽取一个词,然后在该词的同义词集合里随机选择一个,插入原句字中的随机位置,该过程可以重复m次,1<=m<=3,步骤103、随机交换(RS:Randomly Swap),句子中随机选择两个词位置交换;该过程进行1次。步骤104、进行随机删除(RD:Randomly Delete),句子中的每个词,以概率p随机删除,p=0.1,若没有词被删除则再重复一次,直至出现删除词。这样可以得到4N条增强数据。
使用MLM进行数据增强时,首先将句子进行分词,随机选择某个词用MASK标记替换,将替换后的句子输入RoBERTa-zh-Large模型,输出的句子若和原句子相同,则重新选择MASK再次输出,直到得到不同的句子即为增强后的句子。通过这种方式得到N条增强数据。
步骤1中获得的N条数据和增强后得到的6N条增强数据,共7N条数据训练RoBERTa-wwm-ext-zh-Large-RCNN模型、XLNET-zh-mid-CNN模型和Multilingual T5-RCNN模型,即训练这三个大型模型。
RoBERTa-wwm-ext-zh-Large-RCNN模型为RoBERTa-Large模型连接RCNN模型,训练时,首先数据通过RoBERTa-wwm-ext模型得到每个字符的编码,即通过RoBERTa-wwm-ext将数据每个字符进行Embedding,得到的Embedding向量再通过一个双向的LSTM模型,得到的编码再通过卷积核大小为2,3,4的CNN模型进行卷积池化操作,最后拼接在一起通过一个全连接层输出分类的结果;在训练时候,微调RoBERTa-wwm-ext模型采用的学习率lr=5e-6,在训练RCNN层的时候使用的学习率lr=1e-3,因为RCNN需要重头开始学,RoBERTa-wwm-ext参数只需要微调,对学习率采用指数衰减的模拟退火算法进行衰减,当模型接近收敛的时候,停止RoBERTa-wwm-ext参数的微调更新,只更新RCNN参数,训练到收敛为止。
XLNET-zh-mid-CNN模型训练,先将数据通过XLNET-zh-mid模型得到每个字符的编码,也就是把模型当做编码器,得到每个字符的Embedding表示,然后通过一层self-attention层计算自注意力,通过卷积核大小为2,3,4的CNN模型进行卷积和池化操作,最后拼接在一起通过一个全连接层得到输出分类的结果;在训练网络时,XLNET-zh-mid使用的学习率lr=1e-5,使用指数衰减的模拟退火算法进行衰减,模型在接近收敛时,停止更新XLNET-zh-mid的参数,只更新CNN的参数,直到模型收敛。
Multilingual T5-RCNN模型(简称mT5)训练采用bert4keras提供的默认模型加载及训练方式。在数据经过mT5模型输出后,经过一层self-attenion计算自注意力,然后再通过一个双向的LSTM层,得到的编码在通过卷积核大小为2,3,4的CNN模型进行卷积池化操作,最后拼接在一起通过一个全连接层输出分类的结果。学习率mT5使用lr=2e10-6,RCNN层使用lr=1e10-3。训练的方式是采用Seq2Seq的方式进行训练的,在任务场景中,通过转化为Seq2Seq的训练方式,比如输入:识别该内容是否是购买意图:我很喜欢这条裙子。输出:是(否)。通过转化为Seq2Seq方式训练有监督的意图识别任务,这种转化的思想和GPT2GPT3的思想是一致的,都是希望用文字把任务表达出来,然后转化为文字预测。按照上述方式训练直至模型收敛。
训练了上述3个模型,通过细致的调参将每个模型调至最好的效果。为了进一步挖掘3个模型的潜力,得到精度最高的模型,采用模型集成的方式,集合3个原本已经高精度的大型模型之力,进一步提高模型预测的精度,步骤3中,集成时将三个模型RoBERTa-wwm-ext-zh-Large-RCNN模型、XLNET-zh-mid-CNN模型和Multilingual T5-RCNN模型分别设为A,B,C,根据三个模型的预测结果分为两种情况,情况一:模型A,B,C的预测结果一致(比如都判断为某个意图),这时将模型A,B,C输出的logits的结果[a1,a2]、[b1,b2]、[c1,c2],取均值得到[(a1+b1+c1)/3,(a2+b2+c2)/3];情况二:模型A,B,C的预测结果不一致(比如A,B判断为一个意图,C判断为另一个意图),这时取模型A,B输出的logits的结果[a1,a2]、[b1,b2],取均值得到[(a1+b1)/2,(a2+b2)/2];按照上述两种情况,将训练数据依次使用三个模型进行预测并集成,得到每条训练数据的集成模型logits分布并保存下来,并且分别记录三个大模型各自的logits并保存下来,上述预测和集成都是在离线的环境下进行计算得到的。
如上分别训练三个大模型,调参到最优状态,并将三个大的模型进行集成的方式,可以使模型的预测精度达到极优的状态。但是,大模型集成进行预测的耗时巨大。因此这里采用的方式是用训练数据训练一个效果还不错的小模型,用小模型部署在实际的生产环境进行预测推断,但是小模型的效果对比大模型有一定的差距,为了尽最大可能弥补这种差距,这里使用集成三个最优大模型的知识蒸馏到小模型的方式,弥补小模型精度的不足。选择TextCNN作为小模型,TextCNN模型的预测推理速度快,而且CNN可以并行计算,提高了模型的推理速度。从所述三个模型知识蒸馏到TextCNN的步骤包括:对于每条训练数据,在计算loss的时候,首先将该条数据离线计算出的集成模型的logits带入带有温度参数T的softmax函数,softmax函数如下,
其中z表示集成模型的logits,T表示温度,qi是为第i个节点的输出值;通过softmax函数获得集成模型带温度参数T的softmax值记为S_T;
分别将RoBERTa-wwm-ext-zh-Large-RCNN模型,XLNET-zh-mid-CNN模型和Multilingual T5-RCNN模型各自的logits值也带入softmax函数,分别计算出每个模型的带温度参数T的softmax值记为SA_T,SB_T,SC_T;
然后计算训练数据通过TextCNN模型之后的logits通过softmax函数得到的softmax值记为S,每条训练数据的真实标签的onthot表示记为y;
训练模型的损失函数用如下表示:
L=aH(y,S)+b1MSE(S,S_T)+b2MSE(S,SA_T)+b3MSE(S,SB_T)+b4MSE(S,SC_T),
其中H表示交叉熵损失函数,MSE(Mean Squared Error)表示均方误差损失函数,a、b1、b2、b3、b4为常数系数。
在训练的过程中,温度参数T取值3,这样尽可能拉大每个意图概率值的差异;a取值0.1,这样让模型更多的从集成模型和各自的大模型学习知识,获得更好的学习效果。b1,b2,b3,b4的取值由以下方式获得:集成模型的参数b1取值为0.6,模型A,B,C的预测结果如果一致(比如都判断为某个意图),则选择A,B,C预测结果确信度最高的模型,即预测正确的概率值最大的那个模型,比如A最大,则b2取值0.3,b3,b4取值为0;模型A,B,C的预测结果如果不一致(比如A,B判断为某个意图,C判断为另一个意图),则选择A,B预测结果确信度最高的模型,即预测正确的概率值最大的那个模型,比如A最大,则b2取值0.3,b3,b4取值为0。按照如上方式,针对每条数据使用集成模型和确信度最高的模型的知识蒸馏到小模型。学习率取lr=1e-3,并采用warmup的方式预热学习率,让模型有一个更好的收敛。以上参数是通过调参得到的最优参数,按照上述方式和参数设置训练TextCNN模型直至收敛。通过上述方式知识蒸馏得到的TextCNN模型比普通的TextCNN模型准确率高7个百分点达到89%,比上述集成模型的准确率仅少了3个百分点。而且TextCNN模型由于可并行计算的特性推理一条数据的耗时仅为0.008ms,而大模型推理一条数据耗时80ms,但是经过蒸馏的TextCNN模型的精度已经逼近了集成过得大模型。所以知识蒸馏的小模型精度已经完全合格,速度也达到了可以适用实时性任务场景的要求了。
尽管已经示出和描述了本发明的具体实施方式,对于本领域的普通技术人员而言,可以理解在不脱离本发明的原理和精神的情况下可以对这些具体实施方式进行多种变化、修改、替换和变型,本发明的范围由所附权利要求及其等同物限定。
Claims (10)
1.一种基于集成知识蒸馏网络的意图识别方法,其特征在于包括如下步骤:
步骤1、获得数据,并对数据进行增强,形成增强后的数据;
步骤2、结合增强后的数据,训练多个语言模型;
步骤3、集成上述语言模型;
步骤4、将集成上述语言模型的知识蒸馏到的TextCNN模型,进行意图识别的语义信息捕捉。
2.根据权利要求1所述的基于集成知识蒸馏网络的意图识别方法,其特征在于,
步骤1中,采用三种数据增强的方式对获得的N条数据进行增强,第一种方式通过回译的方式进行数据增强,第二种方式使用EDA的方式进行数据增强;第三种方式使用MLM进行数据增强。
3.根据权利要求2所述的基于集成知识蒸馏网络的意图识别方法,其特征在于,
采用第一种方式通过回译的方式进行数据增强时,调用有道翻译api,对N条数据首先通过有道翻译翻译从中文翻译为英文,再将翻译得到的英文数据通过有道翻译翻译为中文,得到N条增强数据。
4.根据权利要求3所述的基于集成知识蒸馏网络的意图识别方法,其特征在于,
使用EDA的方式进行数据增强时,步骤101、同义词替换,将分词后的句子,在句子里随机抽取n个词,1<=n<=3,然后从同义词词典中随机抽取同义词,进行替换;步骤102、随机插入,随机抽取一个词,然后在该词的同义词集合里随机选择一个,插入原句字中的随机位置,该过程可以重复m次,1<=m<=3,步骤103、随机交换,句子中随机选择两个词位置交换;步骤104、进行随机删除,句子中的每个词,以概率p随机删除,p=0.1,若没有词被删除则再重复一次,直至出现删除词,得到4N条增强数据。
5.根据权利要求4所述的基于集成知识蒸馏网络的意图识别方法,其特征在于,
使用MLM进行数据增强时,首先将句子进行分词,随机选择某个词用MASK标记替换,将替换后的句子输入RoBERTa-zh-Large模型,输出的句子若和原句子相同,则重新选择MASK再次输出,直到得到不同的句子,得到N条增强数据。
6.根据权利要求5所述的基于集成知识蒸馏网络的意图识别方法,其特征在于,
步骤2中的语言模型为三个,分别为RoBERTa-wwm-ext-zh-Large-RCNN模型,XLNET-zh-mid-CNN模型和Multilingual T5-RCNN模型,步骤1中获得的N条数据和增强后得到的6N条增强数据,共7N条数据训练RoBERTa-wwm-ext-zh-Large-RCNN模型、XLNET-zh-mid-CNN模型和Multilingual T5-RCNN模型。
7.根据权利要求6所述的基于集成知识蒸馏网络的意图识别方法,其特征在于,
RoBERTa-wwm-ext-zh-Large-RCNN模型训练时,首先数据通过RoBERTa-wwm-ext模型得到每个字符的编码,得到的编码再通过卷积核大小为2,3,4的CNN模型进行卷积池化操作,最后拼接在一起通过一个全连接层输出分类的结果;在训练时候,微调RoBERTa-wwm-ext模型采用的学习率lr=5e-6,在训练RCNN层的时候使用的学习率lr=1e-3,当模型接近收敛的时候,停止RoBERTa-wwm-ext参数的微调更新,只更新RCNN参数,训练到收敛为止。
8.根据权利要求7所述的基于集成知识蒸馏网络的意图识别方法,其特征在于,
XLNET-zh-mid-CNN模型训练,先将数据通过XLNET-zh-mid模型得到每个字符的编码,得到每个字符的Embedding表示,然后通过一层self-attention层计算自注意力,通过卷积核大小为2,3,4的CNN模型进行卷积和池化操作,最后拼接在一起通过一个全连接层得到输出分类的结果;在训练网络时,XLNET-zh-mid使用的学习率lr=1e-5,使用指数衰减的模拟退火算法进行衰减,模型在接近收敛时,停止更新XLNET-zh-mid的参数,只更新CNN的参数,直到模型收敛;
MultilingualT5-RCNN模型训练采用bert4keras提供的默认模型加载及训练方式。
9.根据权利要求8所述的基于集成知识蒸馏网络的意图识别方法,其特征在于,
步骤3中,集成语言模型时,将三个模型RoBERTa-wwm-ext-zh-Large-RCNN模型、XLNET-zh-mid-CNN模型和Multilingual T5-RCNN模型分别设为A,B,C,根据三个模型的预测结果分为两种情况,情况一:模型A,B,C的预测结果一致,则将模型A,B,C输出的logits结果取均值;情况二:模型A,B,C的预测结果不一致,这时取结果相同的两个模型输出的logits结果取均值;按照上述两种情况,将训练数据依次使用三个模型进行预测并集成,得到每条训练数据的集成模型logits分布并保存下来,并且分别记录三个模型各自的logits并保存下来,此步骤的预测和集成都是在离线的环境下进行计算得到的。
10.根据权利要求9所述的基于集成知识蒸馏网络的意图识别方法,其特征在于,
步骤4中进行知识蒸馏时:对于每条训练数据,在计算loss的时候,首先将该条数据离线计算出的集成模型的logits带入带有温度参数T的softmax函数,softmax函数如下,
其中z表示集成模型的logits,T表示温度,qi是为第i个节点的输出值,通过softmax函数获得集成模型带温度参数T的softmax值记为S_T;
分别将RoBERTa-wwm-ext-zh-Large-RCNN模型,XLNET-zh-mid-CNN模型和Multilingual T5-RCNN模型各自的logits值也带入softmax函数,分别计算出每个模型的带温度参数T的softmax值记为SA_T,SB_T,SC_T;
然后计算训练数据通过TextCNN模型之后的logits通过softmax函数得到的softmax值记为S,每条训练数据的真实标签的onthot表示记为y;
训练模型的损失函数用如下表示:
L=aH(y,S)+b1MSE(S,S_T)+b2MSE(S,SA_T)+b3MSE(S,SB_T)+b4MSE(S,SC_T),
其中H表示交叉熵损失函数,MSE表示均方误差损失函数,a、b1、b2、b3、b4为常数系数。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310318318.3A CN116595167A (zh) | 2023-03-29 | 2023-03-29 | 一种基于集成知识蒸馏网络的意图识别方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310318318.3A CN116595167A (zh) | 2023-03-29 | 2023-03-29 | 一种基于集成知识蒸馏网络的意图识别方法 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN116595167A true CN116595167A (zh) | 2023-08-15 |
Family
ID=87605108
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202310318318.3A Pending CN116595167A (zh) | 2023-03-29 | 2023-03-29 | 一种基于集成知识蒸馏网络的意图识别方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN116595167A (zh) |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117807235A (zh) * | 2024-01-17 | 2024-04-02 | 长春大学 | 一种基于模型内部特征蒸馏的文本分类方法 |
-
2023
- 2023-03-29 CN CN202310318318.3A patent/CN116595167A/zh active Pending
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117807235A (zh) * | 2024-01-17 | 2024-04-02 | 长春大学 | 一种基于模型内部特征蒸馏的文本分类方法 |
CN117807235B (zh) * | 2024-01-17 | 2024-05-10 | 长春大学 | 一种基于模型内部特征蒸馏的文本分类方法 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
WO2022037256A1 (zh) | 文本语句处理方法、装置、计算机设备和存储介质 | |
CN111914091B (zh) | 一种基于强化学习的实体和关系联合抽取方法 | |
CN110134946B (zh) | 一种针对复杂数据的机器阅读理解方法 | |
CN113326731B (zh) | 一种基于动量网络指导的跨域行人重识别方法 | |
CN111460824B (zh) | 一种基于对抗迁移学习的无标注命名实体识别方法 | |
CN114757182A (zh) | 一种改进训练方式的bert短文本情感分析方法 | |
CN110348012B (zh) | 确定目标字符的方法、装置、存储介质及电子装置 | |
CN116595167A (zh) | 一种基于集成知识蒸馏网络的意图识别方法 | |
CN114417872A (zh) | 一种合同文本命名实体识别方法及系统 | |
CN113779988A (zh) | 一种通信领域过程类知识事件抽取方法 | |
CN114064856A (zh) | 一种基于XLNet-BiGRU文本纠错方法 | |
WO2022227297A1 (zh) | 一种信息分类方法及装置、信息分类模型训练方法及装置 | |
CN111309921A (zh) | 一种文本三元组抽取方法及抽取系统 | |
CN114528387A (zh) | 基于对话流自举的深度学习对话策略模型构建方法和系统 | |
CN114328939A (zh) | 基于大数据的自然语言处理模型构建方法 | |
CN114416981A (zh) | 一种长文本的分类方法、装置、设备及存储介质 | |
CN116522165B (zh) | 一种基于孪生结构的舆情文本匹配系统及方法 | |
CN115186670B (zh) | 一种基于主动学习的领域命名实体识别方法及系统 | |
CN116743605A (zh) | 一种网络服务质量预测方法及装置 | |
CN115797952A (zh) | 基于深度学习的手写英文行识别方法及系统 | |
CN115495579A (zh) | 5g通信助理文本分类的方法、装置、电子设备及存储介质 | |
CN115565177A (zh) | 文字识别模型训练、文字识别方法、装置、设备及介质 | |
CN114781356A (zh) | 一种基于输入共享的文本摘要生成方法 | |
CN114021658A (zh) | 一种命名实体识别模型的训练方法、应用方法及其系统 | |
CN115270780B (zh) | 一种术语识别方法 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |