CN117494762A - 学生模型的训练方法、素材处理方法、装置及电子设备 - Google Patents
学生模型的训练方法、素材处理方法、装置及电子设备 Download PDFInfo
- Publication number
- CN117494762A CN117494762A CN202310773161.3A CN202310773161A CN117494762A CN 117494762 A CN117494762 A CN 117494762A CN 202310773161 A CN202310773161 A CN 202310773161A CN 117494762 A CN117494762 A CN 117494762A
- Authority
- CN
- China
- Prior art keywords
- model
- target
- student model
- training
- student
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000012549 training Methods 0.000 title claims abstract description 257
- 238000000034 method Methods 0.000 title claims abstract description 111
- 239000000463 material Substances 0.000 title claims abstract description 53
- 238000003672 processing method Methods 0.000 title claims abstract description 18
- 238000004821 distillation Methods 0.000 claims abstract description 166
- 238000012545 processing Methods 0.000 claims abstract description 51
- 230000008014 freezing Effects 0.000 claims abstract description 28
- 238000007710 freezing Methods 0.000 claims abstract description 28
- 230000008569 process Effects 0.000 claims description 32
- 230000005012 migration Effects 0.000 claims description 9
- 238000013508 migration Methods 0.000 claims description 9
- 230000015654 memory Effects 0.000 claims description 6
- 238000013519 translation Methods 0.000 claims description 6
- 238000010257 thawing Methods 0.000 claims description 4
- 230000006870 function Effects 0.000 description 34
- 238000010586 diagram Methods 0.000 description 10
- 230000000694 effects Effects 0.000 description 9
- 238000013140 knowledge distillation Methods 0.000 description 9
- 238000013528 artificial neural network Methods 0.000 description 5
- 238000010606 normalization Methods 0.000 description 5
- 230000006978 adaptation Effects 0.000 description 3
- 230000015572 biosynthetic process Effects 0.000 description 3
- 230000009467 reduction Effects 0.000 description 3
- 238000003786 synthesis reaction Methods 0.000 description 3
- 108091023037 Aptamer Proteins 0.000 description 2
- 230000003044 adaptive effect Effects 0.000 description 2
- 230000008901 benefit Effects 0.000 description 2
- 230000002457 bidirectional effect Effects 0.000 description 2
- 238000004364 calculation method Methods 0.000 description 2
- 230000001364 causal effect Effects 0.000 description 2
- 238000012512 characterization method Methods 0.000 description 2
- 230000003993 interaction Effects 0.000 description 2
- 230000007774 longterm Effects 0.000 description 2
- 238000003058 natural language processing Methods 0.000 description 2
- 238000007781 pre-processing Methods 0.000 description 2
- 239000013598 vector Substances 0.000 description 2
- 230000009471 action Effects 0.000 description 1
- 238000013459 approach Methods 0.000 description 1
- 230000005540 biological transmission Effects 0.000 description 1
- 230000008859 change Effects 0.000 description 1
- 230000006835 compression Effects 0.000 description 1
- 238000007906 compression Methods 0.000 description 1
- 238000013527 convolutional neural network Methods 0.000 description 1
- 238000000605 extraction Methods 0.000 description 1
- 230000007246 mechanism Effects 0.000 description 1
- 230000003287 optical effect Effects 0.000 description 1
- 230000000306 recurrent effect Effects 0.000 description 1
- 230000004044 response Effects 0.000 description 1
- 230000006403 short-term memory Effects 0.000 description 1
- 230000002123 temporal effect Effects 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
- G06N3/0455—Auto-encoder networks; Encoder-decoder networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02P—CLIMATE CHANGE MITIGATION TECHNOLOGIES IN THE PRODUCTION OR PROCESSING OF GOODS
- Y02P90/00—Enabling technologies with a potential contribution to greenhouse gas [GHG] emissions mitigation
- Y02P90/30—Computing systems specially adapted for manufacturing
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- General Health & Medical Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Evolutionary Computation (AREA)
- Artificial Intelligence (AREA)
- Molecular Biology (AREA)
- Computing Systems (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Mathematical Physics (AREA)
- Software Systems (AREA)
- Health & Medical Sciences (AREA)
- Image Analysis (AREA)
Abstract
本申请公开了一种学生模型的训练方法、素材处理方法、装置及电子设备,属于计算机领域。所述学生模型的训练方法包括:获取待训练的学生模型;通过已经训练好的目标教师模型对所述待训练的学生模型进行蒸馏训练,得到第一目标学生模型;其中,所述第一目标学生模型包括编码器和解码器;冻结所述目标教师模型的参数和所述第一目标学生模型中所述编码器的参数,通过所述目标教师模型对所述第一目标学生模型中所述解码器的参数进行蒸馏训练,得到第二目标学生模型。
Description
技术领域
本申请属于计算机领域,具体涉及一种学生模型的训练方法、素材处理方法、装置及电子设备。
背景技术
对于一项复杂任务来说,训练出来的模型通常复杂且笨重,此时模型的精度很高,但需要大量的计算资源以及庞大的数据集去支撑,无法部署到实际应用中。于是,为了创建一个能够媲美初始模型精度且便于实际应用的轻量级模型,提出来了一种被验证有效的模型压缩方法,知识蒸馏。
知识蒸馏中通过教师模型训练学生模型,具体地,保证数据集相同,将复杂、学习能力强的网络学到的特征表示“知识”蒸馏出来,传递给参数量小、学习能力弱的网络。
通过知识蒸馏的方式,学生模型获得了教师模型的特征提取能力,同时保证了轻量化易部署的特点。然而,为了保证得到的学生模型的预测准确性,通过传统知识蒸馏方式得到学生模型往往需要耗费较多的时间。
发明内容
本申请实施例提供一种学生模型的训练方法、素材处理方法、装置及电子设备,能够减少得到学生模型所耗费的时间。
第一方面,本申请实施例提供了一种学生模型的训练方法,该方法包括:
获取待训练的学生模型;
通过已经训练好的目标教师模型对所述待训练的学生模型进行蒸馏训练,得到第一目标学生模型;其中,所述第一目标学生模型包括编码器和解码器;
冻结所述目标教师模型的参数和所述第一目标学生模型中所述编码器的参数,通过所述目标教师模型对所述第一目标学生模型中所述解码器的参数进行蒸馏训练,得到第二目标学生模型。
第二方面,本申请实施例提供了一种素材处理方法,该方法包括:
获取待处理的素材;
向已经训练好的学生模型输入所述待处理的素材;
通过所述已经训练好的学生模型对所述待处理的素材进行文字识别处理或者语音识别处理,并输出所述待处理的素材的识别结果;
其中,所述已经训练好的学生模型是使用第一方面所述的训练方法而得到的。
第三方面,本申请实施例提供了一种素材处理装置,该装置包括:
获取模块,用于获取待处理的素材;
输入模块,用于向已经训练好的学生模型输入所述待处理的素材;
处理模块,用于通过所述已经训练好的学生模型对所述待处理的素材进行文字识别处理或者语音识别处理;
输出模块,用于输出所述待处理的素材的识别结果。
其中,所述已经训练好的学生模型是使用第一方面所述的训练方法而得到的。
第四方面,本申请实施例提供了一种电子设备,该电子设备包括处理器和存储器,所述存储器存储程序或指令,所述程序或指令被所述处理器执行时实现第一方面所述的方法的步骤。
第五方面,本申请实施例提供了一种可读存储介质,所述可读存储介质上存储程序或指令,所述程序或指令被处理器执行时实现第一方面所述的方法的步骤。
在本申请实施例中,获取待训练的学生模型;通过已经训练好的目标教师模型对所述待训练的学生模型进行蒸馏训练,得到第一目标学生模型;其中,所述第一目标学生模型包括编码器和解码器;冻结所述目标教师模型的参数和所述第一目标学生模型中所述编码器的参数,通过所述目标教师模型对所述第一目标学生模型中所述解码器的参数进行蒸馏训练,得到第二目标学生模型。如此,通过已经训练好的目标教师模型对所述待训练的学生模型进行蒸馏训练,使得得到的第一目标学生模型获得了教师模型的预测能力,可以在一定程度上保证学生模型的预测准确性,之后在冻结所述目标教师模型的参数和所述第一目标学生模型中所述编码器的参数的情况下,通过单独对学生模型的解码器进行蒸馏训练,而非对编码器和解码器均进行蒸馏训练,这样,相较于传统对包括编码器和解码器的整个学生模型进行蒸馏训练的方式,由于只对解码器进行蒸馏训练的时间会少于对编码器和解码器二者进行蒸馏训练的时间,进而可以减小得到学生模型所耗费的时间。
附图说明
图1是本申请实施例提供的一种学生模型的训练方法的示意图;
图2是本申请实施例提供的另一种学生模型的训练方法的流程图;
图3是本申请实施例提供的另一种学生模型的训练方法的流程图;
图4是本申请实施例提供的另一种学生模型的训练方法的流程图;
图5-1是本申请实施例提供的另一种学生模型的训练方法的流程图;
图5-2是本申请实施例提供的另一种学生模型的训练方法的流程图;
图6是本申请实施例提供的一种学生模型的训练方法的示意图;
图7是本申请实施例提供的一种素材处理方法的流程图;
图8是本申请实施例提供的一种语音处理方法的示意图;
图9是本申请实施例提供的一种素材处理装置的结构框图;
图10是本申请实施例提供的一种电子设备的结构框图。
具体实施方式
下面将结合本申请实施例中的附图,对本申请实施例中的技术方案进行清楚地描述,显然,所描述的实施例是本申请一部分实施例,而不是全部的实施例。基于本申请中的实施例,本领域普通技术人员获得的所有其他实施例,都属于本申请保护的范围。
本申请的说明书和权利要求书中的术语“第一”、“第二”等是用于区别类似的对象,而不用于描述特定的顺序或先后次序。应该理解这样使用的数据在适当情况下可以互换,以便本申请的实施例能够以除了在这里图示或描述的那些以外的顺序实施,且“第一”、“第二”等所区分的对象通常为一类,并不限定对象的个数,例如第一对象可以是一个,也可以是多个。此外,说明书以及权利要求中“和/或”表示所连接对象的至少其中之一,字符“/”,一般表示前后关联对象是一种“或”的关系。
在介绍本申请实施例提供的数据处理方法之前,先对本申请实施例提供的数据处理方法中涉及到的一些名词进行阐释。
知识蒸馏:一个学生网络从一个教师网络中学习知识,以此得到教师网络的精华部分的一种方法。经过蒸馏后的模型,可以有效学习到教师模型的丰富知识,提升识别准确率。
教师网络:知识蒸馏中的概念,与之对应的是学生网络,知识蒸馏中,使用教师网络训练学生网络,以此达到提高学生网络准确率等的效果。
转换器(transformer):一种基于自注意力机制的时序模型,可以有效对编码器部分的时序信息进行编码,其对时序信息的处理能力远远好于长短期记忆网络,且速度快,在自然语言处理,计算机视觉,机器翻译,语音识别等领域中获得了广泛应用。
适应器(conformer):可包括转换器和卷积神经网络,转换器擅长捕获基于内容的全局交互,而卷积神经网络则有效地利用了局部特征,将转换器和卷积神经网络相结合,使得模型对长时全局交互信息和局部特征都有比较好的建模。
连接时序分类(Connectionist Temporal Classification,CTC):可以理解为基于神经网络的时序类分类。CTC是一种计算损失值的方法,优点是可以对没有对齐的数据进行自动对齐,主要用于没有事先对齐的序列化数据训练上,比如语音识别、文本识别等等。
转换器双向编码器(Bidirectional Encoder Representation fromTransformers,BERT):BERT是一个预训练的语言表征模型。它不采用以往传统的单向语言模型或者把两个单向语言模型进行浅层拼接的方法进行预训练,而是采用新的掩码语言模型进行训练,可以生成深度的双向语言表征,在11个自然语言处理任务中获得了新的最先进的结果。
长短期记忆网络(Long Short-Term Memory,LSTM)是一种时间循环神经网络,是为了解决一般的循环神经网络(Recurrent Neural Network,RNN)存在的长期依赖问题而专门设计出来的,所有的RNN都具有一种重复神经网络模块的链式形式。
在本申请实施例中,学生模型可以包括编码器和解码器。其中,解码器可以使用转换器,编码器可以采用适应器。当然,在一些实施例中,解码器除了可以使用转换器之外,还可以使用LSTM。自然语言模型除了可以使用BERT外,还可以使用GPT-4预训练模型。
在蒸馏训练的场景下,为了保证得到的学生模型的预测准确性,传统知识蒸馏方式往往会对学生模型的整体进行蒸馏训练,也就是说,在整个蒸馏训练过程中从始至终都对学生模型的编码器和解码器进行充分的蒸馏训练,这样一来,得到的学生模型往往具有较高的预测准确性。然而,这样会导致得到学生模型所耗费的时间过多。因而,需要一种新的学生模型的训练方法,能够减少得到学生模型所耗费的时间。
本申请实施例提供的学生模型的训练方法,获取待训练的学生模型;通过已经训练好的目标教师模型对所述待训练的学生模型进行蒸馏训练,得到第一目标学生模型;其中,所述第一目标学生模型包括编码器和解码器;冻结所述目标教师模型的参数和所述第一目标学生模型中所述编码器的参数,通过所述目标教师模型对所述第一目标学生模型中所述解码器的参数进行蒸馏训练,得到第二目标学生模型。如此,通过已经训练好的目标教师模型对所述待训练的学生模型进行蒸馏训练,使得得到的第一目标学生模型获得了教师模型的预测能力,可以在一定程度上保证学生模型的预测准确性,之后在冻结所述目标教师模型的参数和所述第一目标学生模型中所述编码器的参数的情况下,通过单独对学生模型的解码器进行蒸馏训练,而非对编码器和解码器均进行蒸馏训练,这样,相较于传统对包括编码器和解码器的整个学生模型进行蒸馏训练的方式,由于只对解码器进行蒸馏训练的时间会少于对编码器和解码器二者进行蒸馏训练的时间,进而可以减小得到学生模型所耗费的时间。
在本申请实施例提供的学生模型的训练方法中,除了对整个学生模型进行蒸馏训练(可视为第一段蒸馏训练),以及对学生模型的解码器进行蒸馏训练(可视为第二段蒸馏训练)之外,还可以在得到第二目标学生模型之后,解冻所述第二目标学生模型中所述编码器的参数,通过所述目标教师模型对所述第二目标学生模型进行蒸馏训练(可视为第三段蒸馏训练),得到第三目标学生模型。如此,通过第三段蒸馏训练可以保证经过第一段蒸馏训练后的编码器能够更好地适配第二段蒸馏训练后的解码器,实现编码器与解码器相互协调配合,提升预测准确性。
同时,为了进一步提升学生模型的预测准确性,本申请实施例提供的学生模型的训练方法还可引入自然语言模型(例如BERT或者GPT-4),引入的自然语言模型可以为经过领域微调迁移训练的模型,以更好地匹配所要应用的领域。如此,引入的自然语言模型和教师模型可以共同对学生模型的学习训练进行监督,从而进一步提高学生模型的预测准确率。
另外,在本申请实施例中,在对整个学生模型进行蒸馏训练(可视为第一段蒸馏训练)得到第一目标学生模型之后,还可以减少所述第一目标学生模型中解码器的层数,得到层数减少的目标解码器,后续在第二段蒸馏训练和第三段蒸馏训练中即可针对目标解码器(即层数已减少的解码器)进行蒸馏训练。如此,通过将解码器的层数变少,可以降低网络宽度,使得重打分的时间相对较短,从而更好地满足实时性需求。由于解码器的层数减少可能会影响学生模型的预测准确性,本申请实施例通过单独对学生模型的解码器部分进行蒸馏训练,使得学生模型的解码器网络有更好的信息输入,提升预测准确性。
由上可知,本申请实施例提供的学生模型的训练方法通过减少解码器的层数,可以满足实时要求,同时通过多段训练、多段蒸馏的方式,提高了学生模型的预测准确率,实现既具有较高的准确率又具有实时性的效果。经过训练之后的学生模型即可应用于各种应用场景,例如语音识别场景、图像文字识别场景或者机器翻译场景等。
需了解的是,本申请实施例提供的学生模型的训练方法和素材处理方法均可由电子设备执行。其中,电子设备上可部署教师模型和学生模型,还可部署自然语言模型。电子设备可以是终端设备,例如手机等,也可以是网络侧设备,例如目标服务器等。在本申请实施例提供的学生模型的训练方法和素材处理方法由目标服务器执行的情况下,目标服务器可以是一台服务器,也可以是一个服务器集群(例如,分布式服务器集群),服务器集群中的服务器可相互配合执行本申请实施例提供的学生模型的训练方法和素材处理方法中的各个步骤。
下面结合附图,通过具体的实施例及其应用场景对本申请实施例提供的一种学生模型的训练方法进行详细地说明。
图1是本申请实施例提供的一种学生模型的训练方法的示意图。如图1所示,本申请实施例提供的学生模型的训练方法可涉及到教师模型、学生模型以及自然语言模型。其中,教师模型和自然语言模型用于监督学生模型的学习训练。需了解的是,图1中的自然语言模型是可选的。在一些实施例中,也可以没有自然语言模型。
在本申请实施例中,根据教师模型对学生模型进行蒸馏训练,在训练过程中,可以使用教师模型监督学生模型的训练,同时,还引入自然语言模型共同对学生模型的学习训练进行监督。在训练完成后,得到的学生模型(即,后文提到第二目标学生模型或者第三目标学生模型)可对输入数据进行推理,输出结果。在本申请实施例中,通过知识蒸馏的方式,使得学生模型获得了教师模型的能力,在实际应用中,达到教师模型的大部分效果,输出有效的结果。
在本申请实施例中,蒸馏过程可以是教师模型参数冻结,只做推理操作,学生模型既做推理操作,又做反向传播训练操作。蒸馏的训练过程可以是让学生模型的输出向量分布,在交叉熵损失函数的作用下,让学生模型不断的逼近教师模型的输出,学习教师模型的正收益信息。
在本申请实施例中,学生模型可以包括编码器和解码器。其中,解码器可以使用转换器,编码器可以采用适应器。当然,在一些实施例中,解码器除了可以使用转换器之外,还可以使用LSTM。自然语言模型除了可以使用BERT外,还可以使用GPT-4预训练模型。
图2是本申请实施例提供的一种学生模型的训练方法的流程图。如图2所示,本申请实施例提供的学生模型的训练方法,包括:
步骤210:获取待训练的学生模型;
此步骤中,所述待训练的学生模型可以是初始学生模型,也可以是经过预训练的学生模型。
在本申请的一个实施例中,待训练的学生模型为经过预训练的学生模型。相应地,所述获取待训练的学生模型可包括:对初始学生模型进行预训练,得到经过预训练后收敛的模型,并将经过预训练后收敛的模型作为待训练的学生模型。
在具体的训练过程中,使用训练数据对初始学生模型进行预训练,直至模型收敛,保存学生模型,得到待训练的学生模型。训练时长可以是1万小时。训练损失函数可以为L0=0.3*C+0.7*A。其中,C为CTC损失,A为注意力损失。
在本申请实施例中,还可以对初始教师模型进行预训练,模型收敛后得到目标教师模型。在训练时,初始教师模型和初始学生模型的编码器部分和解码器部分模型都可以进行训练,其中,编码器部分可使用带有因果卷积的适应器,解码器部分使用转换器,训练时采用动态块,让教师-学生模型适应不定时长的输入。
步骤220:通过已经训练好的目标教师模型对所述待训练的学生模型进行蒸馏训练,得到第一目标学生模型;其中,所述第一目标学生模型包括编码器和解码器;
此步骤中,可冻结教师模型的部分参数或全部参数,加载学生模型,不冻结学生模型的参数,对学生模型进行蒸馏训练。同时,所述第一目标学生模型可以为尚未收敛的模型。
步骤230:冻结所述目标教师模型的参数和所述第一目标学生模型中所述编码器的参数,通过所述目标教师模型对所述第一目标学生模型中所述解码器的参数进行蒸馏训练,得到第二目标学生模型。
此步骤中,通过冻结学生模型的编码器部分,实现对解码器部分的蒸馏训练。其中,所述第二目标学生模型可以为尚未收敛的模型,也可以为经过蒸馏训练后收敛的模型。
需了解的是,在本申请实施例中,可根据训练集和教师模型对学生模型进行蒸馏训练。具体过程可以是:从训练集中获取第一训练样本作为目标样本;根据所述目标样本和所述教师模型,对所述学生模型进行蒸馏损失计算,得到所述目标样本的蒸馏损失值;将所述目标样本输入所述学生模型进行交叉熵损失计算,得到所述目标样本的交叉熵损失值;根据所述目标样本的所述蒸馏损失值和所述交叉熵损失值,对所述学生模型的网络参数进行更新,更新后的所述学生模型用于下一次进行蒸馏训练;重复执行上面的整个过程,直至达到蒸馏训练结束条件,将达到所述蒸馏训练结束条件的所述学生模型作为目标模型(即第一目标学生模型或者第二目标学生模型)。
在本申请实施例中,获取待训练的学生模型;通过已经训练好的目标教师模型对所述待训练的学生模型进行蒸馏训练,得到第一目标学生模型;其中,所述第一目标学生模型包括编码器和解码器;冻结所述目标教师模型的参数和所述第一目标学生模型中所述编码器的参数,通过所述目标教师模型对所述第一目标学生模型中所述解码器的参数进行蒸馏训练,得到第二目标学生模型。如此,通过已经训练好的目标教师模型对所述待训练的学生模型进行蒸馏训练,使得得到的第一目标学生模型获得了教师模型的预测能力,可以在一定程度上保证学生模型的预测准确性,之后在冻结所述目标教师模型的参数和所述第一目标学生模型中所述编码器的参数的情况下,通过单独对学生模型的解码器进行蒸馏训练,而非对编码器和解码器均进行蒸馏训练,这样,相较于传统对包括编码器和解码器的整个学生模型进行蒸馏训练的方式,由于只对解码器进行蒸馏训练的时间会少于对编码器和解码器二者进行蒸馏训练的时间,进而可以减小得到学生模型所耗费的时间。
图3是本申请实施例提供的一种学生模型的训练方法的流程图。如图3所示,本申请实施例提供的学生模型的训练方法,包括:
步骤310:获取待训练的学生模型;
步骤320:通过已经训练好的目标教师模型对所述待训练的学生模型进行蒸馏训练,得到第一目标学生模型;其中,所述第一目标学生模型包括编码器和解码器;
步骤330:冻结所述目标教师模型的参数和所述第一目标学生模型中所述编码器的参数,通过所述目标教师模型对所述第一目标学生模型中所述解码器的参数进行蒸馏训练,得到第二目标学生模型。
步骤340:解冻所述第二目标学生模型中所述编码器的参数,通过所述目标教师模型对所述第二目标学生模型进行蒸馏训练,得到第三目标学生模型;其中,对所述第二目标学生模型进行蒸馏训练的过程中的学习率小于,对所述第一目标学生模型中所述解码器的参数进行蒸馏训练的过程中的学习率。
其中,所述第三目标学生模型可以为所述第二目标学生模型经过蒸馏训练后收敛的模型。此步骤中,加载第二目标学生模型,解冻对应的编码器部分,降低学习率,为的是让编码器部分在参数变化不大的情况下,去自动适应解码器部分,完成学生模型训练后的编码器与解码器的适配,使得二者更加贴合。
其中,对第二目标学生模型进行蒸馏训练的过程中的学习率可以是,对所述第一目标学生模型中所述解码器的参数进行蒸馏训练的过程中的学习率的十分之一或者更低。
在本申请实施例中,通过已经训练好的目标教师模型对所述待训练的学生模型进行蒸馏训练,使得得到的第一目标学生模型获得了教师模型的预测能力,可以在一定程度上保证学生模型的预测准确性,之后在冻结所述目标教师模型的参数和所述第一目标学生模型中所述编码器的参数的情况下,通过单独对学生模型的解码器进行蒸馏训练,而非对编码器和解码器均进行蒸馏训练,这样,相较于传统对包括编码器和解码器的整个学生模型进行蒸馏训练的方式,由于只对解码器进行蒸馏训练的时间会少于对编码器和解码器二者进行蒸馏训练的时间,进而可以减小得到学生模型所耗费的时间。同时,解冻学生模型的编码器部分,降低学习率,有利于解码器部分与编码器部分进行适配,提升学生模型的预测准确率。
图4是本申请实施例提供的一种学生模型的训练方法的流程图。如图4所示,本申请实施例提供的学生模型的训练方法,包括:
步骤410:获取待训练的学生模型;
步骤420:冻结所述目标教师模型的所有参数,不冻结所述待训练的学生模型的参数,对所述待训练的学生模型进行蒸馏训练,得到第一目标学生模型;其中,所述第一目标学生模型包括编码器和解码器;
步骤430:冻结所述目标教师模型的参数和所述第一目标学生模型中所述编码器的参数,通过所述目标教师模型对所述第一目标学生模型中所述解码器的参数进行蒸馏训练,得到第二目标学生模型。
步骤410和步骤430的阐释可参考如图2中相应步骤的描述。
在步骤420中,冻结教师模型的所有参数,使用教师模型对整个学生模型进行蒸馏训练。
在本申请实施例中,通过已经训练好的目标教师模型对所述待训练的学生模型进行蒸馏训练,使得得到的第一目标学生模型获得了教师模型的预测能力,可以在一定程度上保证学生模型的预测准确性,之后在冻结所述目标教师模型的参数和所述第一目标学生模型中所述编码器的参数的情况下,通过单独对学生模型的解码器进行蒸馏训练,而非对编码器和解码器均进行蒸馏训练,这样,相较于传统对包括编码器和解码器的整个学生模型进行蒸馏训练的方式,由于只对解码器进行蒸馏训练的时间会少于对编码器和解码器二者进行蒸馏训练的时间,进而可以减小得到学生模型所耗费的时间。同时,冻结教师模型的参数,对学生模型进行蒸馏训练,可以避免教师模型的参数的调整对训练结果的影响,达到尽快将教师模型的知识传递给学生模型的效果。
在本申请的一个实施例中,如图2至图4所示,本申请实施例提供的学生模型的训练方法中所述对所述待训练的学生模型进行蒸馏训练,得到第一目标学生模型可以是通过如下损失函数对所述待训练的学生模型进行蒸馏训练,得到第一目标学生模型:L1=a*L0+b*D1;
其中,L1为损失函数,a和b为系数,a+b=1,且a大于0.5且小于1,L0为所述待训练的学生模型的损失,D1为所述目标教师模型的蒸馏损失。
此损失函数是针对学生模型,L0为待训练的学生模型的损失,可以是CTC损失和注意力损失的和;D1为目标教师模型的蒸馏损失,在训练过程中由目标教师模型传递到学生模型中。在一个示例中,D1可以是目标教师模型的CTC蒸馏损失和解码器损失之和。
其中,为了保证第一目标学生模型的效果,针对待训练的学生模型的损失系数a可大于0.5,同时考虑到教师模型的蒸馏损失,a小于1,使得教师模型的蒸馏损失对学生模型的影响受到监督。
例如,在一种情况下,a=0.9,b=0.1。此时学生模型的损失函数为L1=0.9*L0+0.1*D1。
在本申请实施例中,相较于传统只对整个学生模型进行蒸馏训练的方式,通过对整个学生模型及学生模型的解码器分别进行蒸馏训练,通过多段训练、多段蒸馏的方式,可以在保证学生模型的预测准确率的同时减小得到学生模型所耗费的时间。同时,通过具体的函数计算第一目标学生模型的损失,考虑了目标教师模型传递的损失和待训练的学生模型自身的损失对第一目标学生模型的影响。
在本申请的一个实施例中,如图2至图4所示,本申请实施例提供的学生模型的训练方法中所述通过所述目标教师模型对所述第一目标学生模型中所述解码器的参数进行蒸馏训练,得到第二目标学生模型,可以是通过如下损失函数对所述第一目标学生模型中所述解码器的参数进行蒸馏训练,得到第二目标学生模型:L2=c*L1+d*D1;
其中,L2为损失函数,c和d为系数,c+d=1,且c大于0.5且小于1,L1为所述第一目标学生模型的损失,D1为所述目标教师模型的蒸馏损失。
此损失函数是针对学生模型,L1为第一学生模型的损失,可以是CTC损失和注意力损失的和;D1为目标教师模型的蒸馏损失,在训练过程中由目标教师模型传递到学生模型中。在一个示例中,D1可以是目标教师模型的CTC蒸馏损失和解码器损失之和。
其中,为了保证第二目标学生模型的效果,针对第一目标学生模型的损失系数c可大于0.5,同时考虑到目标教师模型的蒸馏损失,c小于1,使得目标教师模型的蒸馏损失对学生模型的影响受到监督。
例如,在一种情况下,c=0.8,d=0.2。此时学生模型的损失函数为L2=0.8*L1+0.2*D1。
在本申请实施例中,相较于传统只对整个学生模型进行蒸馏训练的方式,通过对整个学生模型及学生模型的解码器分别进行蒸馏训练,通过多段训练、多段蒸馏的方式,可以在保证学生模型的预测准确率的同时减小得到学生模型所耗费的时间。同时,通过具体的函数计算第二目标学生模型的损失,考虑了目标教师模型传递的损失和第一目标学生模型自身的损失对第二目标学生模型的影响。
图5-1是本申请实施例提供的一种学生模型的训练方法的流程图。如图5-1所示,本申请实施例提供的学生模型的训练方法可包括:
步骤5110:获取待训练的学生模型;
步骤5120:通过已经训练好的目标教师模型对所述待训练的学生模型进行蒸馏训练,得到第一目标学生模型;其中,所述第一目标学生模型包括编码器和解码器。
步骤5110和步骤5120的阐释可参考如图2中相应步骤的描述。
需了解的是,本申请实施例在得到第一目标学生模型之后,可以减少解码器的层数。
步骤5130:获取经过领域微调迁移训练的目标自然语言模型;
此步骤中,目标自然语言模型可以是BERT模型。以语音领域为例,可以加载原始BERT预训练模型,使用外呼领域的文本数据进行领域微调迁移,待微调训练完毕后,保存领域迁移后形成的领域自适应的BERT模型。
步骤5140:冻结所述目标自然语言模型的参数,冻结所述目标教师模型的参数和所述第一目标学生模型中所述编码器的参数,通过所述目标自然语言模型和所述目标教师模型对所述第一目标学生模型中解码器的参数进行蒸馏训练,得到第二目标学生模型;其中,所述目标自然语言模型和所述目标教师模型用于对所述第一目标学生模型的学习训练进行监督。其中,所述第二目标学生模型可以为所述第一目标学生模型经过蒸馏训练后收敛的模型。
此步骤中,可以利用降层前学生模型解码器网络训练形成的学生模型的编码器网络、教师模型网络以及经过领域适应后的BERT模型,对降层后学生模型的解码器网络进行蒸馏训练,使得学生模型的解码器网络有更好的信息输入,防止过拟合,提升预测准确性能。
在本申请实施例中,相较于传统只对整个学生模型进行蒸馏训练的方式,通过对整个学生模型及学生模型的解码器分别进行蒸馏训练,通过多段训练、多段蒸馏的方式,可以在保证学生模型的预测准确率的同时减小得到学生模型所耗费的时间。同时,使用领域自适应的BERT模型监督学生模型解码器网络的训练,提升对学生模型的蒸馏效果。
在本申请的一个实施例中,本申请实施例提供的学生模型的训练方法中通过所述目标教师模型对所述第一目标学生模型中所述解码器的参数进行蒸馏训练,得到第二目标学生模型,可以是通过目标自然语言模型和目标教师模型共同对第一目标学生模型中的解码器的参数进行蒸馏训练。
具体地,可通过如下损失函数对所述第一目标学生模型中所述解码器的参数进行蒸馏训练,得到第二目标学生模型:L2=e*L1+f*(D1+D2);
其中,L2为损失函数,e和f为系数,e+f=1,且e大于0.5且小于1,L1为所述第一目标学生模型的损失,D1为所述目标教师模型的蒸馏损失,D2为所述目标自然语言模型的蒸馏损失。
其中,为了保证第二目标学生模型的效果,针对第一目标学生模型的损失系数e可大于0.5,同时考虑到目标教师模型和目标自然语言模型的蒸馏损失,e小于1,使得目标教师模型的蒸馏损失对学生模型的影响受到监督。
例如,在一种情况下,e=0.8,f=0.2。此时学生模型的损失函数为L2=0.8*L1+0.2*(D1+D2)。
在本申请实施例中,相较于传统只对整个学生模型进行蒸馏训练的方式,通过对整个学生模型及学生模型的解码器分别进行蒸馏训练,通过多段训练、多段蒸馏的方式,可以在保证学生模型的预测准确率的同时减小得到学生模型所耗费的时间。同时,通过具体的函数计算第二目标学生模型的损失,考虑了目标教师模型传递的损失、目标自然语言模型传递的损失和第一目标学生模型自身的损失对第二目标学生模型的影响。
图5-2是本申请实施例提供的一种学生模型的训练方法的流程图。如图5-2所示,本申请实施例提供的学生模型的训练方法可包括:
步骤5210:对初始学生模型进行预训练,得到经过预训练后收敛的模型,并将经过预训练后收敛的模型作为待训练的学生模型;
步骤5220:通过已经训练好的目标教师模型对所述待训练的学生模型进行蒸馏训练,得到第一目标学生模型;其中,所述第一目标学生模型包括编码器和解码器;
步骤5230:获取经过领域微调迁移训练的目标自然语言模型;
步骤5240:减少所述第一目标学生模型中所述解码器的层数,得到层数减少的目标解码器;
步骤5250:冻结所述目标自然语言模型的参数,冻结所述目标教师模型的参数和所述第一目标学生模型中所述编码器的参数,通过所述目标自然语言模型和所述目标教师模型对所述第一目标学生模型中层数减少的所述目标解码器的参数进行蒸馏训练,得到第二目标学生模型;其中,所述目标自然语言模型和所述目标教师模型用于对所述第一目标学生模型的学习训练进行监督;
步骤5260:解冻所述第二目标学生模型中所述编码器的参数,通过所述目标教师模型对所述第二目标学生模型进行蒸馏训练,得到第三目标学生模型;其中,对所述第二目标学生模型进行蒸馏训练的过程中的学习率小于,对所述第一目标学生模型中所述解码器的参数进行蒸馏训练的过程中的学习率。
步骤5210、步骤5220、步骤5230、步骤5250和步骤5260的阐释可参考如图2-图5-1中相应步骤的描述。
在步骤5220中,所述第一目标学生模型可以为所述待训练的学生模型经过蒸馏训练后收敛的模型。在步骤5250中,所述第二目标学生模型可以为所述第一目标学生模型经过蒸馏训练后收敛的模型。在步骤5260中,所述第三目标学生模型可以为所述第二目标学生模型经过蒸馏训练后收敛的模型。
在步骤5240中,第一目标学生模型中解码器的层数可以是12层、10层、8层或者6层,层数减少后的目标解码器的层数可以是4层、3层或者2层。通过减少第一目标学生模型中解码器的层数,降低学生模型的资源占用和时延,提高学生模型的综合性能。在本申请实施例中,目标解码器的层数为大于1的任何数目。在解码器层数降低的同时可降低网络的宽度,例如,网络宽度由2048降低到512。
在本申请实施例中,通过降低学生模型的解码器层数,可以减少资源占用,提升学生模型的实时性能。同时,由于解码器的层数减少可能会影响学生模型的预测准确性,本申请实施例通过单独对学生模型的解码器部分进行蒸馏训练,使得学生模型的解码器网络有更好的信息输入,提升预测准确性。进而,通过对整个学生模型及学生模型的解码器分别进行蒸馏训练,通过多段训练、多段蒸馏的方式,可以在保证实时性的同时提高学生模型的预测准确率。
上述所有实施例中,所获得的学生模型可以用于语音识别、图像文字识别或者机器翻译。即,上述任一项本申请实施例中所得到的训练好的学生模型不仅可以用于语音识别中,还可以用于图像文字识别或者机器翻译等领域。
下面以语音领域为例,对本申请实施例提供的学生模型的训练过程进行进一步阐释。图6是本申请实施例提供的一种学生模型的训练方法的宏观构思的示意图。如图6所示,左侧的编码层和解码层属于教师模型,右侧的编码层和解码层属于学生模型。教师模型的编码层的输入和学生模型的编码层的输入可相同,均为预处理语音数据。在训练过程中,编码层和解码层的输入和输出均为向量。教师模型的解码层的输出可经由线性层处理之后送至归一化层,归一化层例如为softmax函数,经归一化层处理之后可输出软标签。学生模型的解码层的输出可经由线性层处理之后送至归一化层,归一化层例如为softmax函数,经归一化层处理之后可输出硬标签。在训练过程中,可基于教师模型的编码层的输出计算CTC蒸馏损失,可基于教师模型的归一化层输出的软标签计算解码器蒸馏损失。同时,可基于学生模型的编码层的输出计算CTC损失,并可基于学生模型的解码层的输出计算注意力损失。进而,可基于计算得到的CTC蒸馏损失、解码器蒸馏损失、CTC损失和注意力损失,对学生模型的网络参数进行更新。
需了解的是,本申请实施例还可引入自然语言模型。如图6所示,自然语言模型的输出可送至归一化层,归一化层例如为softmax函数,经归一化层处理之后可输出软标签,进而可基于软标签计算自然语言模型蒸馏损失。可基于计算得到的CTC蒸馏损失、解码器蒸馏损失、CTC损失、注意力损失以及自然语言模型蒸馏损失,对学生模型的网络参数进行更新。
本申请实施例中详细的学生模型的训练过程可如下:
使用训练数据对初始教师模型和初始学生模型进行预训练,直至模型收敛,保存教师模型和学生模型,得到待训练的学生模型和已经训练好的目标教师模型。训练时长可以是1万小时。训练损失函数可以为L0=0.3*C+0.7*A。其中,C为CTC损失,A为注意力损失。在预训练的过程中,教师模型和学生模型的编码器部分和解码器部分模型都进行训练,更新参数,其中,编码器部分使用带有因果卷积的适应器,解码器部分使用转换器,训练时采用动态块,让教师-学生模型适应不定时长的输入。
其中,教师模型可以是融合了外部语言模型知识的端到端语音识别框架(RNN-T模型)。自然语音模型可以是BERT模型或是生成式模型中的GPT-4(聊天机器人ChatGPT发布的语言模型)等。
在预训练完成后,冻结教师模型的参数,加载学生模型,不冻结学生模型的参数,对学生模型进行蒸馏训练。
可通过如下损失函数对所述待训练的学生模型进行蒸馏训练,得到第一目标学生模型:L1=a*L0+b*D1;其中,L1为损失函数,a和b为系数,a+b=1,且a大于0.5且小于1,L0为所述待训练的学生模型的损失,D1为所述目标教师模型的蒸馏损失。此损失函数仅针对学生模型,L0为待训练的学生模型的损失,可以是CTC损失和注意力损失的和;D1为目标教师模型的蒸馏损失,在训练过程中由目标教师模型传递到学生模型中,可以是目标教师模型的CTC蒸馏损失和解码器损失之和。
在得到第一学生目标模型后,可加载自然语言模型,所述自然语言模型可以是原始BERT预训练模型,使用外呼领域的文本数据进行领域微调迁移,待微调训练完毕后,保存领域迁移后形成的领域自适应的BERT模型;
加载领域自适应BERT模型,冻结BERT模型的参数、冻结教师模型的参数、冻结学生模型的编码器部分相关的参数,将学生模型的解码器部分,从6层网络降低到2层网络,每层网络的宽度由2048降低到512,以降低网络推理资源占用,通过如下损失函数对所述第一目标学生模型中所述解码器的参数进行蒸馏训练,得到第二目标学生模型:L2=e*L1+f*(D1+D2);其中,L2为损失函数,e和f为系数,e+f=1,且e大于0.5且小于1,L1为所述第一目标学生模型的损失,D1为所述目标教师模型的蒸馏损失,D2为所述目标自然语言模型的蒸馏损失。
在得到第二学生目标模型后,解冻所述第二目标学生模型中所述编码器的参数,调小学习率至原来的1/10,通过所述目标教师模型对所述第二目标学生模型进行蒸馏训练,得到第三目标学生模型;待模型收敛后,保存训练好的学生模型。
此时,学生模型的训练完成,得到训练好的学生模型。
学生模型的应用过程可如下:当在线上应用时,输入预处理语音数据,进入学生模型的编码器部分网络,将输出结果分别进行CTC结果搜索,并输入到解码器部分,将解码器输出的后验概率和CTC搜索结果进行重打分,得到重打分结果,得到识别后的文字序列。
在本申请实施例中,解码器层数的减少和网络宽度的降低,使得学生模型重打分的时间缩短,可满足实时要求;同时,采用了多段分开蒸馏训练的方式,利用不同教师模型的信息和学生模型降层前的解码器部分对编码器网络进行训练;利用学生模型降层前的解码器部分所训练形成的编码器网络、教师模型网络以及经过领域适应后的BERT模型,对学生模型降层后的解码器网络进行蒸馏训练,使得学生模型降层后的解码器网络有更好的信息输入,防止过拟合,提升预测准确性能。进而,通过多段训练、多段蒸馏的方式,可以在保证实时性的同时提高学生模型的预测准确率。
图7是本申请实施例提供的一种素材处理方法的示意图。如图7所示,本申请实施例提供的素材处理方法,包括:
步骤710:获取待处理的素材;
步骤720:向已经训练好的学生模型输入所述待处理的素材;
步骤730:通过所述已经训练好的学生模型对所述待处理的素材进行文字识别处理或者语音识别处理,并输出所述待处理的素材的识别结果。
在本申请实施例中,所述已经训练好的学生模型是使用上述任一实施例提供的学生模型的训练方法而训练得到的。
在本申请实施例中,使用训练好的学生模型处理素材,进行文字识别处理或者语音识别处理,输出识别结果,对训练好的学生模型进行实际应用。
图8是本申请实施例提供的一种素材处理方法在实际应用中的示意图。图8所示的素材处理方法应用于语音场景中。在图8所示的语音处理过程中可涉及语音采集、语音识别、意图理解、文本生成和语音合成这几个阶段。其中,语音识别阶段可使用训练得到的学生模型,例如上文提到的第二目标学生模型或者第三目标学生模型。
图8所示的语音处理方法的具体过程可如下:在语音采集阶段,接收电话用户端实时传入的语音流信号;在语音识别阶段,当接收到电话用户端实时传入的语音流信号时,利用上文训练好的学生模型(例如图5-2所示训练方法得到的第三目标学生模型)对语音流信号进行语音识别;若在语音识别过程中有新语音输入,则重新进行语音识别,直至语音识别过程中没有新语音输入;若在语音识别过程中没有新语音输入,则对语音信号进行解码,经过快速解码后,在意图理解阶段对解码后的语音进行意图理解;在文本生成阶段,根据意图理解的结果进行判断,将判断逻辑对应到要回答的回复中,生成待合成文本;然后在语音合成阶段将待合成文本进行语音合成,进行语音应答,完成一轮回复。重复执行上面对实时传入的语音流信号进行识别至完成一轮答复的过程,直至未接收到实时传入的语音流信号。
本申请实施例可应用于智能外呼场景下,通过使用训练好的学生模型处理语音素材,进行语音识别处理,输出识别结果,可以大大降低语音识别解码延迟,在语音识别的过程中实现既满足实时要求性要求又保证准确率的效果。
图9是本申请实施例提供的一种素材处理装置的结构框图。如图9所示,本申请实施例提供的一种素材处理装置900包括:获取模块910、发送模块920、处理模块930;
所述获取模块910,用于获取待处理的素材;
所述输入模块920,用于向已经训练好的学生模型输入所述待处理的素材;
所述处理模块930,用于通过所述已经训练好的学生模型对所述待处理的素材进行文字识别处理或者语音识别处理;
输出模块940,用于输出所述待处理的素材的识别结果。
其中,所述已经训练好的学生模型是使用根据上述任一项本申请实施例所述的学生模型的训练方法而得到的。
在本申请实施例中,使用训练好的学生模型处理素材,进行文字识别处理或者语音识别处理,输出识别结果,对训练好的学生模型进行实际应用。
相应地,本申请实施例还提供了一种学生模型的训练装置,该装置包括获取模块与处理模块。
所述获取模块,用于用于获取待训练的学生模型;
所述处理模块,用于通过已经训练好的目标教师模型对所述待训练的学生模型进行蒸馏训练,得到第一目标学生模型;其中,所述第一目标学生模型为所述待训练的学生模型经过蒸馏训练后收敛的模型,所述第一目标学生模型包括编码器和解码器;
所述处理模块,还用于冻结所述目标教师模型的参数和所述第一目标学生模型中所述编码器的参数,通过所述目标教师模型对所述第一目标学生模型中所述解码器的参数进行蒸馏训练,得到第二目标学生模型;其中,所述第二目标学生模型为所述第一目标学生模型经过蒸馏训练后收敛的模型。
在一个实施例中,在所述得到第二目标学生模型之后,所述处理模块还用于解冻所述第二目标学生模型中所述编码器的参数,通过所述目标教师模型对所述第二目标学生模型进行蒸馏训练,得到第三目标学生模型;其中,所述第三目标学生模型为所述第二目标学生模型经过蒸馏训练后收敛的模型;对所述第二目标学生模型进行蒸馏训练的过程中的学习率小于,对所述第一目标学生模型中所述解码器的参数进行蒸馏训练的过程中的学习率。
在一个实施例中,所述通过已经训练好的目标教师模型对所述待训练的学生模型进行蒸馏训练,得到第一目标学生模型中,所述处理模块还用于冻结所述目标教师模型的所有参数,不冻结所述待训练的学生模型的参数,对所述待训练的学生模型进行蒸馏训练,得到第一目标学生模型。
在一个实施例中,所述对所述待训练的学生模型进行蒸馏训练,得到第一目标学生模型,所述处理模块还用于通过如下损失函数对所述待训练的学生模型进行蒸馏训练,得到第一目标学生模型:L1=a*L0+b*D1;其中,L1为损失函数,a和b为系数,a+b=1,且a大于0.5且小于1,L0为所述待训练的学生模型的损失,D1为所述目标教师模型的蒸馏损失。
在一个实施例中,所述通过所述目标教师模型对所述第一目标学生模型中所述解码器的参数进行蒸馏训练,得到第二目标学生模型中,所述处理模块还用于通过如下损失函数对所述第一目标学生模型中所述解码器的参数进行蒸馏训练,得到第二目标学生模型:L2=c*L1+d*D1;其中,L2为损失函数,c和d为系数,c+d=1,且c大于0.5且小于1,L1为所述第一目标学生模型的损失,D1为所述目标教师模型的蒸馏损失。
在一个实施例中,所述获取模块还用于获取经过领域微调迁移训练的目标自然语言模型;所述冻结所述目标教师模型的参数和所述第一目标学生模型中所述编码器的参数,通过所述目标教师模型对所述第一目标学生模型中所述解码器的参数进行蒸馏训练,得到第二目标学生模型中,所述处理模块还用于冻结所述目标自然语言模型的参数,冻结所述目标教师模型的参数和所述第一目标学生模型中所述编码器的参数,通过所述目标自然语言模型和所述目标教师模型对所述第一目标学生模型中所述解码器的参数进行蒸馏训练,得到第二目标学生模型;其中,所述目标自然语言模型和所述目标教师模型用于对所述第一目标学生模型的学习训练进行监督。
在一个实施例中,所述通过所述目标自然语言模型和所述目标教师模型对所述第一目标学生模型中所述解码器的参数进行蒸馏训练,得到第二目标学生模型中,所述处理模块还用于通过如下损失函数对所述第一目标学生模型中所述解码器的参数进行蒸馏训练,得到第二目标学生模型:L2=e*L1+f*(D1+D2);其中,L2为损失函数,e和f为系数,e+f=1,且e大于0.5且小于1,L1为所述第一目标学生模型的损失,D1为所述目标教师模型的蒸馏损失,D2为所述目标自然语言模型的蒸馏损失。
在一个实施例中,在得到第一目标学生模型之后,所述处理模块还用于减少所述第一目标学生模型中所述解码器的层数,得到层数减少的目标解码器;所述对所述第一目标学生模型中所述解码器的参数进行蒸馏训练中,所述处理模块还用于对所述第一目标学生模型中层数减少的所述目标解码器的参数进行蒸馏训练。
在一个实施例中,所述获取待训练的学生模型中,所述处理模块还用于对初始学生模型进行预训练,得到经过预训练后收敛的模型,并将经过预训练后收敛的模型作为待训练的学生模型。
在一个实施例中,所述学生模型用于语音识别、图像文字识别或者机器翻译。
如图10所示,本申请实施例提供了一种电子设备1000,所述电子设备可以为各种类型的计算机,终端,以及其他可能的设备等。
所述电子设备1000包括:处理器1010和存储器1020,所述存储器1020存储程序,所述程序被所述处理器1010执行时实现上文所描述的任一种方法的步骤。举例而言,所述程序被处理器1010执行时实现根据如下过程:获取待训练的学生模型;通过已经训练好的目标教师模型对所述待训练的学生模型进行蒸馏训练,得到第一目标学生模型;其中,所述第一目标学生模型包括编码器和解码器;冻结所述目标教师模型的参数和所述第一目标学生模型中所述编码器的参数,通过所述目标教师模型对所述第一目标学生模型中所述解码器的参数进行蒸馏训练,得到第二目标学生模型。在本申请实施例中,通过已经训练好的目标教师模型对所述待训练的学生模型进行蒸馏训练,使得得到的第一目标学生模型获得了教师模型的预测能力,可以在一定程度上保证学生模型的预测准确性,之后在冻结所述目标教师模型的参数和所述第一目标学生模型中所述编码器的参数的情况下,通过单独对学生模型的解码器进行蒸馏训练,而非对编码器和解码器均进行蒸馏训练,这样,相较于传统对包括编码器和解码器的整个学生模型进行蒸馏训练的方式,由于只对解码器进行蒸馏训练的时间会少于对编码器和解码器二者进行蒸馏训练的时间,进而可以减小得到学生模型所耗费的时间。
又举例而言,所述程序被处理器1010执行时也可实现根据如下过程:获取待处理的素材;向已经训练好的学生模型输入所述待处理的素材;通过所述已经训练好的学生模型对所述待处理的素材进行文字识别处理或者语音识别处理,并输出所述待处理的素材的识别结果;其中,所述已经训练好的学生模型是使用根据上述任一项本申请实施例所述学生模型的训练方法而得到的。在本申请实施例中,使用训练好的学生模型处理素材,进行文字识别处理或者语音识别处理,输出识别结果,对训练好的学生模型进行实际应用。
本申请实施例还提供一种可读存储介质,所述可读存储介质上存储程序或指令,所述程序或指令被处理器执行时实现如上述任一项本申请实施例所述方法的步骤。
需要说明的是,在本文中,术语“包括”、“包含”或者其任何其他变体意在涵盖非排他性的包含,从而使得包括一系列要素的过程、方法、物品或者装置不仅包括那些要素,而且还包括没有明确列出的其他要素,或者是还包括为这种过程、方法、物品或者装置所固有的要素。在没有更多限制的情况下,由语句“包括一个……”限定的要素,并不排除在包括该要素的过程、方法、物品或者装置中还存在另外的相同要素。此外,需要指出的是,本申请实施方式中的方法和装置的范围不限按示出或讨论的顺序来执行功能,还可包括根据所涉及的功能按基本同时的方式或按相反的顺序来执行功能,例如,可以按不同于所描述的次序来执行所描述的方法,并且还可以添加、省去、或组合各种步骤。另外,参照某些示例所描述的特征可在其他示例中被组合。
通过以上的实施方式的描述,本领域的技术人员可以清楚地了解到上述实施例方法可借助软件加必需的通用硬件平台的方式来实现,当然也可以通过硬件,但很多情况下前者是更佳的实施方式。基于这样的理解,本申请的技术方案本质上或者说对现有技术做出贡献的部分可以以计算机软件产品的形式体现出来,该计算机软件产品存储在一个存储介质(如ROM/RAM、磁碟、光盘)中,包括若干指令用以使得一台终端(可以是手机,计算机,服务器,或者网络设备等)执行本申请各个实施例所述的方法。
上面结合附图对本申请的实施例进行了描述,但是本申请并不局限于上述的具体实施方式,上述的具体实施方式仅仅是示意性的,而不是限制性的,本领域的普通技术人员在本申请的启示下,在不脱离本申请宗旨和权利要求所保护的范围情况下,还可做出很多形式,均属于本申请的保护之内。
Claims (10)
1.一种学生模型的训练方法,其特征在于,包括:
获取待训练的学生模型;
通过已经训练好的目标教师模型对所述待训练的学生模型进行蒸馏训练,得到第一目标学生模型;其中,所述第一目标学生模型包括编码器和解码器;
冻结所述目标教师模型的参数和所述第一目标学生模型中所述编码器的参数,通过所述目标教师模型对所述第一目标学生模型中所述解码器的参数进行蒸馏训练,得到第二目标学生模型。
2.根据权利要求1所述的方法,其特征在于,在所述得到第二目标学生模型之后,所述方法还包括:
解冻所述第二目标学生模型中所述编码器的参数,通过所述目标教师模型对所述第二目标学生模型进行蒸馏训练,得到第三目标学生模型;
其中,对所述第二目标学生模型进行蒸馏训练的过程中的学习率小于,对所述第一目标学生模型中所述解码器的参数进行蒸馏训练的过程中的学习率。
3.根据权利要求1所述的方法,其特征在于,所述通过已经训练好的目标教师模型对所述待训练的学生模型进行蒸馏训练,得到第一目标学生模型,包括:
冻结所述目标教师模型的所有参数,不冻结所述待训练的学生模型的参数,对所述待训练的学生模型进行蒸馏训练,得到第一目标学生模型。
4.根据权利要求1所述的方法,其特征在于,所述方法还包括:
获取经过领域微调迁移训练的目标自然语言模型;
所述冻结所述目标教师模型的参数和所述第一目标学生模型中所述编码器的参数,通过所述目标教师模型对所述第一目标学生模型中所述解码器的参数进行蒸馏训练,得到第二目标学生模型,包括:
冻结所述目标自然语言模型的参数,冻结所述目标教师模型的参数和所述第一目标学生模型中所述编码器的参数,通过所述目标自然语言模型和所述目标教师模型对所述第一目标学生模型中所述解码器的参数进行蒸馏训练,得到第二目标学生模型;
其中,所述目标自然语言模型和所述目标教师模型用于对所述第一目标学生模型的学习训练进行监督。
5.根据权利要求1-4任一项所述的方法,其特征在于,在得到第一目标学生模型之后,所述方法还包括:
减少所述第一目标学生模型中所述解码器的层数,得到层数减少的目标解码器;
所述对所述第一目标学生模型中所述解码器的参数进行蒸馏训练,包括:
对所述第一目标学生模型中层数减少的所述目标解码器的参数进行蒸馏训练。
6.根据权利要求1所述的方法,其特征在于,所述学生模型用于语音识别、图像文字识别或者机器翻译。
7.一种素材处理方法,其特征在于,包括:
获取待处理的素材;
向已经训练好的学生模型输入所述待处理的素材;
通过所述已经训练好的学生模型对所述待处理的素材进行文字识别处理或者语音识别处理,并输出所述待处理的素材的识别结果;
其中,所述已经训练好的学生模型是使用根据权利要求1-6任一项所述的训练方法而得到的。
8.一种素材处理装置,其特征在于,包括:
获取模块,用于获取待处理的素材;
输入模块,用于向已经训练好的学生模型输入所述待处理的素材;
处理模块,用于通过所述已经训练好的学生模型对所述待处理的素材进行文字识别处理或者语音识别处理;
输出模块,用于输出所述待处理的素材的识别结果。
其中,所述已经训练好的学生模型是使用根据权利要求1-6任一项所述的训练方法而得到的。
9.一种电子设备,其特征在于,包括处理器和存储器,所述存储器存储程序或指令,所述程序或指令被所述处理器执行时实现如权利要求1-7任一项所述的方法的步骤。
10.一种可读存储介质,其特征在于,所述可读存储介质上存储程序或指令,所述程序或指令被处理器执行时实现如权利要求1-7任一项所述的方法的步骤。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310773161.3A CN117494762A (zh) | 2023-06-27 | 2023-06-27 | 学生模型的训练方法、素材处理方法、装置及电子设备 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310773161.3A CN117494762A (zh) | 2023-06-27 | 2023-06-27 | 学生模型的训练方法、素材处理方法、装置及电子设备 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN117494762A true CN117494762A (zh) | 2024-02-02 |
Family
ID=89677044
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202310773161.3A Pending CN117494762A (zh) | 2023-06-27 | 2023-06-27 | 学生模型的训练方法、素材处理方法、装置及电子设备 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN117494762A (zh) |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117708336A (zh) * | 2024-02-05 | 2024-03-15 | 南京邮电大学 | 一种基于主题增强和知识蒸馏的多策略情感分析方法 |
-
2023
- 2023-06-27 CN CN202310773161.3A patent/CN117494762A/zh active Pending
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117708336A (zh) * | 2024-02-05 | 2024-03-15 | 南京邮电大学 | 一种基于主题增强和知识蒸馏的多策略情感分析方法 |
CN117708336B (zh) * | 2024-02-05 | 2024-04-19 | 南京邮电大学 | 一种基于主题增强和知识蒸馏的多策略情感分析方法 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN110046221B (zh) | 一种机器对话方法、装置、计算机设备及存储介质 | |
ALIAS PARTH GOYAL et al. | Z-forcing: Training stochastic recurrent networks | |
CN109977212A (zh) | 对话机器人的回复内容生成方法和终端设备 | |
WO2020155619A1 (zh) | 带情感的机器聊天方法、装置、计算机设备及存储介质 | |
CN117494762A (zh) | 学生模型的训练方法、素材处理方法、装置及电子设备 | |
CN113988086A (zh) | 对话处理方法及装置 | |
CN117350304B (zh) | 一种多轮对话上下文向量增强方法及系统 | |
CN117634459A (zh) | 目标内容生成及模型训练方法、装置、系统、设备及介质 | |
CN116863920B (zh) | 基于双流自监督网络的语音识别方法、装置、设备及介质 | |
CN117787380A (zh) | 模型获取方法、装置、介质及设备 | |
CN117112766A (zh) | 视觉对话方法、装置、电子设备和计算机可读存储介质 | |
CN115064173B (zh) | 语音识别方法、装置、电子设备及计算机可读介质 | |
CN116343760A (zh) | 基于联邦学习的语音识别方法、系统和计算机设备 | |
Huang et al. | Flow of renyi information in deep neural networks | |
CN116310643A (zh) | 视频处理模型训练方法、装置以及设备 | |
CN113159168B (zh) | 基于冗余词删除的预训练模型加速推理方法和系统 | |
US11941508B2 (en) | Dialog system with adaptive recurrent hopping and dual context encoding | |
CN111310460B (zh) | 语句的调整方法及装置 | |
CN113849641A (zh) | 一种跨领域层次关系的知识蒸馏方法和系统 | |
CN114792388A (zh) | 图像描述文字生成方法、装置及计算机可读存储介质 | |
CN114911911A (zh) | 一种多轮对话方法、装置及电子设备 | |
CN112434143A (zh) | 基于gru单元隐藏状态约束的对话方法、存储介质及系统 | |
CN115329952B (zh) | 一种模型压缩方法、装置和可读存储介质 | |
CN116797829B (zh) | 一种模型生成方法、图像分类方法、装置、设备及介质 | |
CN114638365B (zh) | 一种机器阅读理解推理方法及装置、电子设备、存储介质 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |