CN116134454A - 用于使用知识蒸馏训练神经网络模型的方法和系统 - Google Patents
用于使用知识蒸馏训练神经网络模型的方法和系统 Download PDFInfo
- Publication number
- CN116134454A CN116134454A CN202180054067.1A CN202180054067A CN116134454A CN 116134454 A CN116134454 A CN 116134454A CN 202180054067 A CN202180054067 A CN 202180054067A CN 116134454 A CN116134454 A CN 116134454A
- Authority
- CN
- China
- Prior art keywords
- teacher
- student
- hidden
- hidden layer
- layers
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000000034 method Methods 0.000 title claims abstract description 119
- 238000013140 knowledge distillation Methods 0.000 title claims abstract description 98
- 238000012549 training Methods 0.000 title claims abstract description 94
- 238000003062 neural network model Methods 0.000 title description 94
- 230000001537 neural effect Effects 0.000 claims abstract description 24
- 230000006870 function Effects 0.000 claims description 68
- 238000012545 processing Methods 0.000 claims description 42
- 238000013507 mapping Methods 0.000 claims description 34
- 230000015654 memory Effects 0.000 claims description 22
- 230000008569 process Effects 0.000 claims description 10
- 230000004927 fusion Effects 0.000 claims description 9
- 239000011159 matrix material Substances 0.000 claims description 5
- 238000012512 characterization method Methods 0.000 claims 1
- 238000012546 transfer Methods 0.000 abstract description 5
- 238000010801 machine learning Methods 0.000 description 39
- 238000004821 distillation Methods 0.000 description 19
- 238000013528 artificial neural network Methods 0.000 description 13
- 238000004364 calculation method Methods 0.000 description 11
- 239000013598 vector Substances 0.000 description 9
- 238000010586 diagram Methods 0.000 description 6
- 238000013519 translation Methods 0.000 description 6
- 238000004891 communication Methods 0.000 description 4
- 238000013459 approach Methods 0.000 description 3
- 230000002457 bidirectional effect Effects 0.000 description 3
- 238000005516 engineering process Methods 0.000 description 3
- 230000003287 optical effect Effects 0.000 description 3
- 238000000926 separation method Methods 0.000 description 3
- 238000004422 calculation algorithm Methods 0.000 description 2
- 238000013461 design Methods 0.000 description 2
- 238000010295 mobile communication Methods 0.000 description 2
- 230000000306 recurrent effect Effects 0.000 description 2
- 238000011160 research Methods 0.000 description 2
- 230000004044 response Effects 0.000 description 2
- 238000004088 simulation Methods 0.000 description 2
- 230000002776 aggregation Effects 0.000 description 1
- 238000004220 aggregation Methods 0.000 description 1
- 238000010420 art technique Methods 0.000 description 1
- 230000008901 benefit Effects 0.000 description 1
- 238000013145 classification model Methods 0.000 description 1
- 230000006835 compression Effects 0.000 description 1
- 238000007906 compression Methods 0.000 description 1
- 238000013135 deep learning Methods 0.000 description 1
- 238000011156 evaluation Methods 0.000 description 1
- 230000006872 improvement Effects 0.000 description 1
- 230000001788 irregular Effects 0.000 description 1
- 230000002093 peripheral effect Effects 0.000 description 1
- 230000001172 regenerating effect Effects 0.000 description 1
- 238000012552 review Methods 0.000 description 1
- 230000006403 short-term memory Effects 0.000 description 1
- 239000007787 solid Substances 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/088—Non-supervised learning, e.g. competitive learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- General Health & Medical Sciences (AREA)
- Computing Systems (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Data Mining & Analysis (AREA)
- Evolutionary Computation (AREA)
- Life Sciences & Earth Sciences (AREA)
- Molecular Biology (AREA)
- Artificial Intelligence (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Mathematical Physics (AREA)
- Software Systems (AREA)
- Health & Medical Sciences (AREA)
- Image Analysis (AREA)
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
Abstract
描述了用于将神经模型的训练知识从复杂模型(教师)转移到较不复杂模型(学生)的无关组合知识蒸馏(CKD)方法。除了训练学生以生成近似于教师最终输出和训练输入的真值两者的最终输出外,该方法还通过训练学生的隐藏层来最大化知识转移,以生成输出,该输出近似于映射到针对给定训练输入的学生隐藏层中的每一层的教师隐藏层的子集的表征。
Description
相关申请的交叉申请
本申请要求2020年9月9日提交的名称为“用于使用知识蒸馏训练神经网络模型的方法和系统(METHOD AND SYSTEM FOR TRAINING A NEURAL NETWORK MODEL USINGKNOWLEDGE DISTILLATION)”、申请号为63/076,335的美国临时专利申请和2021年9月8日提交的名称为“用于使用知识蒸馏训练神经网络模型的方法和系统(METHOD AND SYSTEM FORTRAINING A NEURAL NETWORK MODEL USING KNOWLEDGE DISTILLATION)”、申请号为17/469,573的美国专利申请的在先申请优先权和权益,这些申请的内容通过引用并入本文。
技术领域
本申请涉及用于训练机器学习模型的方法和系统,尤其涉及用于使用知识蒸馏训练深度神经网络的方法和系统。
背景技术
机器学习模型针对每个接收的输入推断(即预测)特定输出。推断的(即预测的)特定输出可以以la可以属于的形式出现。例如,机器学习模型可以基于接收的图像推断(即预测)特定输出,推断的(即预测的)输出包括一组类别中的每个类别的概率分数,其中每个分数表示图像类似于属于该特定类别的对象的概率。
机器学习模型是使用学习算法进行学习的,如随机梯度下降。使用此类技术学习的机器学习模型是近似于该输入到输出过程的深度人工神经网络。用于近似机器学习模型的深度人工神经网络包括输入层、一个或多个隐藏层、以及输出层,其中所有隐藏层都具有参数,并且非线性应用于这些参数。用于近似机器学习模型的深度人工神经网络通常被称为神经网络模型。
知识蒸馏(Knowledge distillation,KD)是神经网络压缩技术,通过该技术,复杂神经网络模型的学习参数或知识被转移到较不复杂的神经网络模型,该神经网络模型能够以较少的计算资源成本和时间作出与复杂模型相当的推断(即预测)。在此,复杂神经网络模型是指具有相对高数量的计算资源(如GPU/CPU功率和计算机内存空间)的神经网络模型和/或包括相对高数量的隐藏层的那些神经网络模型。为了KD的目的,复杂神经网络模型有时被称为教师神经网络模型(teacher neural network model,T)或简称教师。教师的典型缺点是,其可能需要显著的计算资源,这些计算资源在消费电子设备(如移动通信设备或边缘计算设备)中可能不可用。此外,由于教师神经网络模型本身的复杂度,教师神经网络模型通常需要显著量的时间来推断(即预测)针对输入的特定输出,并且因此教师神经网络模型可能不适合部署到消费计算设备以在其中使用。因此,KD技术应用于提取或蒸馏教师神经网络模型的学习参数或知识,并将此类知识传授给具有更快的推断时间和降低的计算资源和内存空间成本的较不复杂的神经网络模型,这可能会在消费计算设备(如边缘设备)上花费更少的精力。较不复杂的神经网络模型通常被称为学生神经网络模型(studentneural network model,S)或简称学生。
现有技术的KD技术仅考虑针对接收的输入的特定输出的最终推断(即预测)以计算损失函数,因此现有技术的KD技术不能处理从教师的隐藏层到学生的隐藏层的知识转移。因此,可以提高KD技术的准确性,尤其是对于具有多个深度隐藏层的教师和学生神经网络模型。
患者知识蒸馏(patient knowledge distillation,PKD)关注该问题,并引入了层到层的成本函数,还被称为内部层映射。教师神经网络的隐藏层的输出还用于训练学生神经网络模型的一个或多个隐藏层,而不是仅匹配学生和教师神经网络模型的推断的(即预测的)输出。隐藏层可以指神经网络中的内部层。具体地,PKD选择由教师神经网络模型的隐藏层生成的输出的子集,并使用由教师神经网络模型的隐藏层生成的输出来训练学生神经网络的一个或多个隐藏层,如图1所示。具体地,图1示出了示意图,其中具有n=3个内部层的教师神经网络模型100(图1右侧所示的神经网络)与PKD一起用于训练具有m=2个内部层的学生神经网络110。当n>m时,跳过由虚线指示的所示实施例中的教师神经网络神经网络模型的内部层之一,使得教师神经网络模型的剩余内部层中的每个内部层直接用于训练对应的学生神经网络模型110的内部层之一。如图1所示,不仅由学生神经网络模型110和教师神经网络模型100推断(即预测)的最终输出用于计算PKD中KD损失的损失,而且教师和学生神经网络模型的内部层的输出也被匹配,从而学生神经网络模型110可以从教师神经网络模型100内部的信息流中学习。
然而,不存在明确的方法来决定跳过教师的哪些隐藏层,以及保留教师的哪些隐藏层以进行蒸馏。因此,当将n层教师神经网络模型提取成m层学生神经网络模型时,跳过的隐藏层可能会存在显著的信息损失。当n>>m时,信息损失变得更明显。
因此,期望提供知识蒸馏方法的改进,以最小化教师神经网络模型的跳过的内部(即隐藏)层的信息损失。
发明内容
本公开提供用于使用知识蒸馏(KD)训练深度神经网络模型的方法和系统,该方法和系统将教师神经网络模型的跳过的内部层的信息损失最小化。
在一些方面,本公开描述了KD方法,其将教师神经网络模型的内部(即隐藏)层映射到学生神经网络模型的对应的内部层(即隐藏层),以最小化信息损失。
在一些方面,本公开描述了KD方法,其包括教师模型和学生模型的内部层(即隐藏层)的自动映射和层选择。
在一些方面,本公开描述了KD方法,其对层映射采用组合方法,其中n层教师神经网络模型的一个或多个内部(即隐藏)层映射到m层学生神经网络模型的一个或多个层(即隐藏层),其中n>m,从而可以最小化信息损失。
在一些方面,本公开描述了KD方法,其采用组合方法将教师神经网络模型的一个或多个隐藏层映射到学生神经网络模型的内部层(即,隐藏层),与教师和学生神经网络模型的架构无关。例如,本公开的KD方法的各个方面使得能够将知识从教师神经网络模型的内部(即隐藏)层(如变压器模型(transformer model))蒸馏到学生神经网络模型的一个或多个层(如递归神经网络模型)。
在一些方面,本公开描述了KD方法,其可以在使用通用语言理解评估(generallanguage understanding evaluation,GLUE)基准评估的训练的变压器的双向编码器表示(bidirectional encoder representations for transformers,BERT)深度学习学生神经网络模型的性能方面进行改进。
在一些方面,本公开描述了KD方法,其可以在神经机器翻译模型方面进行改进。
在一些方面,本公开描述了知识蒸馏方法,其可能能够将多个教师神经网络模型的内部层(即隐藏层)映射到单个学生神经网络模型。映射内部层(即隐藏层)涉及将教师神经网络模型的一个或多个内部(隐藏)层与单个学生神经网络模型的一个内部(隐藏)层相关联。
在一些方面,本文描述的KD方法可以用于训练可以部署到边缘设备的学生神经网络模型。
在一些方面,本文描述的方法可以聚合不同的信息源,包括多语言/多域语言理解/处理/翻译模型。
在一些方面,本文描述的KD方法可以是与任务无关的,使得其可以适用于训练针对任何特定任务的模型,包括如对象分类等计算机视觉任务。
在一些方面,对于服务外部用户的服务器端模型,可以通过本文描述的KD方法组合多个模型,最终训练的教师模型可以被上传到服务器。通过能够从不同的教师神经网络模型执行知识蒸馏,该方法可以免疫任何对抗性攻击。
根据本公开的第一方面的第一实施例,提供了从具有多个教师隐藏层的教师机器学习模型向具有多个学生隐藏层的学生机器学习模型进行知识蒸馏的方法。该方法包括训练教师机器学习模型,其中该教师机器学习模型被配置为接收输入并生成教师输出,以及在多个训练输入上训练学生机器学习模型,其中学生机器学习模型也被配置为接收输入并生成对应的输出。训练学生机器学习模型包括使用教师机器学习模型处理每个训练输入,以针对训练输入生成教师输出,其中多个教师隐藏层中的每一层生成教师隐藏层输出。训练学生机器学习模型还包括将多个教师隐藏层的子集映射到多个学生隐藏层中的每一层。训练学生机器学习模型还包括计算多个教师隐藏层的子集的教师隐藏层输出的表征,该多个教师隐藏层的子集映射到多个学生隐藏层中的每一层。训练学生机器学习模型还包括训练学生以针对训练输入中的每一个生成近似于针对训练输入的教师输出的学生输出,其中多个学生隐藏层中的每一层针对训练输入中的每一个,被训练为生成学生隐藏层输出,该学生隐藏层输出近似于映射到多个学生隐藏层中每一个的多个教师隐藏层的子集的表征。
在第一方面的第一实施例的一些示例或所有示例中,该方法还包括训练学生,以针对训练输入中的每一个,生成近似于训练输入的真值的学生输出。
在第一方面的第一实施例的一些示例或所有示例中,训练学生以生成近似于针对训练输入的教师输出的学生输出还包括:计算学生输出和教师输出之间的知识蒸馏(KD)损失;计算学生输出和上述真值之间的标准损失;计算多个学生隐藏层中的每一层和映射到多个学生层中的每一层的教师隐藏层的子集之间的组合KD(combinatorial KD,CKD)损失;将总损失计算为KD损失、标准损失和CKD损失的加权平均;以及,调整学生的参数,以最小化总损失。
在一些示例中,CKD损失通过以下方式计算:
其中,是CKD损失,MSE()是均方误差函数,是学生的第i个隐藏层,fi T是通过fi T=F(HT(i))计算的教师的第i个隐藏层,其中F()是教师的第一隐藏层和第三隐藏层提供的融合函数,F()的输出被映射到学生的第二隐藏层Hs和HT分别是学生和教师的所有隐藏层的集合,HT(i)是被选择以待映射到学生的第i个隐藏层的教师隐藏层的子集,并且M()是映射函数,其采用引用学生的隐藏层的索引,并为教师返回一组索引。
在一些示例或所有示例中,融合函数F()包括级联操作,后跟线性映射层。
在一些示例或所有示例中,融合函数F()由以下定义:
其中,“;”是级联算子,mul()是矩阵乘法运算,并且W和b是可学习参数。在该示例中,我们仅考虑了两个层,即来自教师侧的层3和1,但这可以扩展到任何数量的层。
在一些示例或所有示例中,映射函数M()定义了用于组合教师的隐藏层的组合策略。
在一些示例或所有示例中,上述映射还包括,定义用于将该教师隐藏层映射到多个学生隐藏层中的每一层的组合策略。
在一些示例或所有示例中,该组合策略为重叠组合、常规组合、跳过组合和交叉组合中的任一种。
在一些示例或所有示例中,上述映射还包括,针对学生隐藏层中的每一层,将注意力权重分配给多个教师隐藏层的子集中的每一个的教师隐藏层输出。
在一些示例中,CKD可以因注意力或其它形式的组合来增强,在本文中标注为CKD*。CKD*损失通过以下方式计算:
注意力权重(∈ij)的总和为1。
在一些示例中,注意力权重(∈ij)通过以下方式计算:
在一些示例中,注意力权重(∈ij)通过以下方式计算:
根据本公开的第一方面的第二实施例,提供了从各自具有多个教师隐藏层的多个教师机器学习模型向具有多个学生隐藏层的学生机器学习模型进行知识蒸馏的方法。该方法包括:训练多个教师机器学习模型,其中该多个教师机器学习模型中的每一个被配置为接收输入并生成教师输出;以及,在多个训练输入上训练学生机器学习模型,其中该学生机器学习模型也被配置为接收输入并生成学生输出。训练学生机器学习模型包括使用多个教师机器学习模型处理每个训练输入,以针对训练输入生成多个教师输出,多个教师机器学习模型中的每一个的多个教师隐藏层中的每一层生成教师隐藏层输出。训练学生机器学习模型还包括将多个教师机器学习模型的多个教师隐藏层的子集映射到多个学生隐藏层中的每一层。训练学生机器学习模型还包括:计算多个教师隐藏层的子集的教师隐藏层输出的表征,该多个教师隐藏层的子集映射到多个学生隐藏层中的每一层。训练学生机器学习模型还包括:训练学生以针对训练输入中的每一个生成近似于针对训练输入的教师输出的学生输出,其中多个学生隐藏层中的每一层针对训练输入中的每一个,被训练为生成学生隐藏层输出,该学生隐藏输出近似于映射到多个学生隐藏层中的每一层的多个教师隐藏层的子集的表征。
根据本公开的另一方面,提供了计算设备,该计算设备包括处理器和存储器,该存储器上有形地存储有用于由处理器执行的可执行指令。响应于处理器的执行,可执行指令使计算设备执行上述和本文描述的方法。
根据本公开的另一方面,提供了非瞬时性机器可读存储介质,其上有形地存储有用于由计算设备的处理器执行的可执行指令。响应于处理器的执行,可执行指令使计算设备执行上述和本文描述的方法。
对于本领域的普通技术人员来说,在审阅以下特定实施方式的描述后,本公开的其它方面和特征将变得显而易见。
附图说明
现在将通过示例的方式参考示出本申请示例实施例的附图。
图1示出了现有技术PKD方法的示意图;
图2示出了根据本公开的用于使用KD训练神经网络模型的机器学习系统的框图;
图3A示出了还被称为交叉组合的第一组合策略;
图3B示出了还被称为常规组合的第二组合策略;
图3C示出了还被称为跳过组合的第三组合策略;
图3D示出了还被称为重叠组合的第四组合策略;
图4示出了根据本公开的不具有注意力的示例知识蒸馏方法的流程图;
图5示出了在图4的步骤406处确定最终损失值的示例方法的流程图;
图6示出了图4中方法的部分的示例伪代码;
图7示出了具有注意力的增强CKD的高级示意性架构;
图8示出了根据本公开的具有注意力的示例知识蒸馏方法的流程图;
图9示出了表格,示出了在执行通用语言理解评估(GLUE)基准时各种KD模型的模拟结果;
图10示出了根据本公开的用于从多个教师向一个教师进行知识蒸馏的示例方法的流程图;
图11示出了用于计算注意力权重的示例方法的流程图,该示例方法可以在图10所示出的方法的步骤1006处实现;并且
图12示出了可以用于实现本文公开的实施例的示例简化处理系统的框图。
具体实施方式
本公开结合附图进行,其中示出了技术方案的实施例。然而,可以使用许多不同的实施例,并且因此不应将描述解释为限于本文中阐述的实施例。相反,提供这些实施例使得本申请将是彻底和完整的。在可能的情况下,在附图和以下描述中使用相同的附图标记来指代相同的元件,并且在替代实施例中使用素数表示法来指示相同的元件、操作或步骤。所示系统和设备的功能元件的单独框或所示分离不一定需要这些功能的物理分离,因为这些元件之间的通信可以在没有任何此类物理分离的情况下通过消息传递、函数调用、共享内存空间等方式发生。因此,尽管本文为了便于解释,它们被示出为分离的,但是这些功能不需要在物理或逻辑上分离的平台中实现。不同的设备可以具有不同的设计,使得尽管一些设备在固定功能硬件中实现一些功能,但其它设备可以在可编程处理器中实现这些功能,该处理器具有从机器可读存储介质获得的代码。最后,以单数提及的元件可以是复数,反之亦然,除非上下文明确或固有地指示。
本文中阐述的实施例表示足以实践请求保护的主题的信息,并示出了实践此类主题的方法。根据附图阅读以下描述之后,本领域技术人员会理解请求保护的主题的概念,并会认识到这些概念的应用在本文中并没有特别提及。应当理解,这些概念和应用落入本公开和所附权利要求书的范围之内。
此外,应当理解,本文公开的执行指令的任何模块、组件或设备可以包括或以其它方式接入一个或多个非瞬时性计算机/处理器可读存储介质,该介质用于存储信息,如计算机/处理器可读指令、数据结构、程序模块和/或其它数据。非瞬时性计算机/处理器可读存储介质的示例的非穷举式清单包括磁带盒、磁带、磁盘存储或其它磁存储设备,光盘,如光盘只读存储器(compact disc read-only memory,CD-ROM)、数字视频盘或数字多功能盘(即digital versatile disc,DVD)、蓝光盘TM或其它光存储器,在任何方法或技术中实现的易失性和非易失性、可移动和不可移动介质,随机存取存储器(random-access memory,RAM),只读存储器(read-only memory,ROM),电可擦除可编程只读存储器(electricallyerasable programmable read-only memory,EEPROM),闪存或其它存储技术。任何此类非瞬时性计算机/处理器存储介质可以是设备的一部分,或可访问该设备或可与该设备连接。用于实现本文描述的应用或模块的计算机/处理器可读/可执行指令可以由此类非瞬时性计算机/处理器可读存储介质存储或以其它方式保存。
以下是下面描述中可能使用的首字母缩略词和相关定义的部分列表:
NMT:神经机器翻译(neural machine translation)
BERT:变压器的双向编码器表示(bidirectional encoder representation fromtransformers)
KD:知识蒸馏(knowledge distillation)
PKD:患者知识蒸馏(patient knowledge distillation)
S:学生(student)
T:老师(teacher)
RKD:常规知识蒸馏(regular knowledge distillation)
CKD:组合知识蒸馏(combinatorial knowledge distillation)
RC:常规组合(regular combination)
OC:重叠组合(overlap combination)
SC:跳过组合(skip combination)
CC:交叉组合(cross combination)
本文描述了组合知识蒸馏(CKD)方法,用于改进作为在训练学生神经网络模型期间跳过教师神经网络模型的一个或多个隐藏层的结果的信息损失。教师神经网络模型和学生神经网络模型被训练用于特定任务,如对象分类、神经机器翻译等。本描述还描述了知识蒸馏方法的示例实施例,该方法可以使得为了KD的目的,将教师神经网络模型的一个或多个隐藏层映射到单个学生神经网络模型的隐藏层。映射涉及将教师神经网络模型的一个或多个层关联到学生神经网络模型的隐藏层。
1(y=v)是指示函数,如果神经网络模型y针对训练数据样本生成的推断的(即预测的)输出等于真值v,则输出“1”,否则其输出“0”。变量(x,y)是包括在包括几个训练数据样本的训练数据集中的训练数据样本的元组,其中x是输入,并且y是真值输出。参数θS和|V|分别是神经网络模型的参数集合和输出数量。该损失可以被称为标准损失。
然而,当执行传统的神经网络模型训练时,当指示函数1(y=v)返回零时,神经网络模型不会接收到由神经网络模型推断(即预测)的针对不正确输出的任何反馈。没有负反馈来惩罚由神经网络模型推断(即预测)的不正确输出y,可能需要更长的时间来训练神经网络模型。通过将KD应用于训练神经网络模型的过程可部分解决该问题,其中呈等式(1)形式的损失函数被扩展了附加项,如等式(2)所示出的:
其中,由神经网络模型推断(即预测)的输出由于其自身相对于真值的标准损失而受到惩罚,但也由于由q(y=v|x;θT)给出的教师模型的隐藏层生成的输出而受到损失,这可以被称为KD损失。在等式(2)中,损失函数或KD损失的第一分量,即q((y=v|x;θT)),当其与由教师模型的softmax函数作出的推断(即预测)(也被称为软标签)进行比较时通常被称为软损失。其余的损失分量,如标准损失,被称为硬损失。因此,在典型KD方法中,根据等式(3),总体损失函数包括至少两个损失项,即标准损失和KD损失:
其中α是具有值0≤α≤1的超参数。超参数是神经网络模型外部的配置,并且其值无法从数据中估计。
图2示出了根据本公开的用于使用KD训练神经网络模型的机器学习系统200的框图,也可以被称为组合知识蒸馏(CKD)。机器学习系统200可以由电子设备(未示出)的一个或多个物理处理单元实现,例如由一个或多个处理单元执行计算机可读指令(其可以存储在机器人电子设备的存储器中)以执行本文描述的方法。
在所示的实施例中,包括输入张量(x)和对应的输出张量(y)的元组202可以提供给n层教师神经网络模型204,以下称为教师204,以及m层学生神经网络模型206,以下称为学生206,其中n>m。元组202是训练数据样本,其是包括多个元组202的训练数据集的一部分。通常,教师204和学生206中的每一个都被配置为基于元组202的输入张量(x)和神经网络的参数集合推断(即预测)输出张量(y)的神经网络。教师204可以是复杂神经网络模型。学生206可以不如教师204复杂(n>m或具有较少隐藏层和较少模型参数),使得学生206需要比教师204更少的计算资源成本,并且可以在相当短的时间内推断(即预测)针对特定任务的输出。
在一些实施例中,教师204可以使用监督或无监督学习算法在包括多个元组202的训练数据集上训练,以学习教师204的参数和学生206的参数。教师204可以被训练以用于分类任务,使得教师204的推断的(即预测的)输出是包括一组类别中每个类别的概率值的张量。例如,如果输入张量(x)包括包含手写数字的图像,则推断的(即预测的)输出(y)可以包括属于类别中每个类别的手写数字的概率评分,如数字“0”至“9”。在一些实施例中,教师204在训练学生206之前被训练。在一些其它实施例中,教师204和学生模型206被同时训练。
在一些实施例中,教师204是单个神经网络模型。在一些其它实施例中,教师204是集成神经网络模型,其是已经单独地训练的多个单独神经网络模型的编译,其中单个神经网络模型的推断的(即预测的)输出被组合以生成教师204的输出。在一些其他实施例中,集成神经网络模型中的神经网络模型包括推断(即预测)针对一组类别中的每个类别的输出的分类模型,以及仅针对类别的相应子集生成评分的一个或多个专业模型。
在所示的实施例中,推断的(即预测的)输出以及教师204和学生206的隐藏层的输出被提供给损失计算模块208。如图所示,损失计算模块208包括KD损失计算子模块210、CKD损失计算子模块211和标准损失计算子模块212。KD损失计算子模块210比较教师204和学生206的推断的(即预测的)输出,以计算KD损失值,如下文更详细描述的。CKD损失计算子模块211将教师204的隐藏层的子集映射到学生206的隐藏层,并确定由学生206的隐藏层生成的输出和由映射到学生206的隐藏层的教师204的隐藏层的子集生成的输出的表征之间的CKD损失。标准损失计算子模块212将学生206的推断的(即预测的)输出与元组202的真值进行比较,以计算标准损失值,如下文更详细地描述的。损失计算模块208还通过损失函数计算KD损失值、CKD损失值和标准损失值,以生成交叉熵值,该交叉熵值被反向传播到学生206,用于调整学生206的参数(即学生神经网络模型的参数)。在学生206被训练之后,其可以被部署到计算设备(未示出)上,并用于进行预测。在一些实施例中,部署学生206的计算设备是能够以更短的运行时间执行学生206的低容量、低复杂度计算设备。在一些实施例中,部署学生206的计算设备是移动通信设备,如智能手机、智能手表、笔记本电脑或平板电脑。
如上所述,在PKD中,找到教师204的可跳过隐藏层是主要挑战之一。由于教师204的隐藏层和学生206的隐藏层之间不存在一一对应关系,因此在蒸馏过程中,如PKD的现有技术跳过了教师204的隐藏层中的一些隐藏层。因此,利用CKD的机器学习系统200可能能够融合或组合教师204的隐藏层,例如通过CKD损失计算子模块211,并受益于存储在教师204的所有隐藏层中的所有或大多数学习参数。
在一些实施例中,根据本公开的CKD损失函数可以数学地表述为等式(4A):
其中,fi T是通过fi T=F(HT(i))计算的教师的第i个隐藏层,其中F()是融合函数,Hs和HT分别指示学生206和教师204的所有隐藏层的集合。参数HT(i)是选择待映射到学生206的第i个隐藏层的教师204的隐藏层的子集。函数MSE()是均方误差函数,并且是学生206的第i个隐藏层。MSE只是CKD损失函数的许多可能实现之一,并且可以使用任何其它合适的方法,如有效的矩阵范数。
在PKD中,fi T是教师204的第i个隐藏层,而相比之下,在根据本公开的CKD的一些实施例中,fi T是根据等式(4B)通过融合函数F()应用于教师204的隐藏层的选择子集的组合的结果:
在一些实施例中,教师204的所选子集HT(i)经由映射函数M()定义,该映射函数采用引用学生206的隐藏层的索引,并为教师204返回一组索引。基于M()返回的索引,教师204的对应的隐藏层被组合以用于知识蒸馏过程。作为非限制性示例,针对索引为2,函数M(2)可以返回索引{1,3}。因此,融合函数F()由教师204的第一隐藏层和第三隐藏层提供,并且F()的输出被映射到学生206的第二隐藏层。
在一些实施例中,融合函数F()包括级联操作,后跟线性映射层。在上述其中索引=2且函数M(2)可以返回索引{1,3}的示例中,融合函数F()可以呈以下形式:
映射函数M()定义了用于组合教师204的隐藏层的组合策略。图3A至图3D示出了可以通过映射函数M()实现的一些示例组合策略。特别地,图3A至图3D中的每个图示出了包括5个隐藏层的教师204和包括2个隐藏层的学生206之间的层组合策略。图3A示出了还被称为交叉组合(cross combination,CC)的第一组合策略,其中教师204的每第m隐藏层被组合用于蒸馏到学生206的m层中的对应的一个层。在所示的示例中,针对包括2个隐藏层的学生206,教师204的第一隐藏层第三隐藏层和第五隐藏层被组合用于蒸馏到学生206的第一隐藏层并且教师204的第二隐藏层和第四隐藏层被组合用于蒸馏到学生206的第二隐藏层图3B示出了还被称为常规组合(regularcombination,RC)的第二组合策略,其中教师204的近似相等数量的连续隐藏层被组合用于蒸馏到学生206的对应的隐藏层。在所示的实施例中,教师204的前三个隐藏层被组合用于蒸馏到学生206的第一隐藏层并且教师204的第四隐藏层和第五隐藏层被组合用于蒸馏到学生206的第二隐藏层应当理解,针对n层教师204(即包括n个隐藏层的教师204)和m层学生206(即包括m个隐藏层的学生206),其中n是m的倍数,教师204的偶数隐藏层可以针对学生206的每个隐藏层组合。可替代地,如果n不是m的精确倍数,则学生206的隐藏层的选择数量可以与教师204的更多组合隐藏层相关联以进行蒸馏。图3C示出了还被称为跳过组合(skip combination,SC)的第三组合策略,其中跳过教师204的一些隐藏层以进行蒸馏。在一些实施例中,可以跳过教师204的每第(m+1)个隐藏层。在一些其它实施例中,跳过教师204的一个或多个隐藏层,使得教师204的相等数量的隐藏层被组合以用于蒸馏到学生206的隐藏层之一。在一些其它实施例中,隐藏层可以以常规或非常规的间隔跳过,并且然后可以细分和组合剩余的隐藏层以进行蒸馏。可以应用确定跳过间隔的任何其它方法。在所示的示例中,针对具有2个隐藏层的学生206,跳过教师204的第三隐藏层使得教师204的第一隐藏层和第二隐藏层被组合用于蒸馏到学生206的第一隐藏层并且教师204的第四隐藏层和第五隐藏层被组合用于蒸馏到学生206的第二隐藏层图3D示出了还被称为重叠组合(overlap combination,OC)的第四组合策略,其中教师204的一个或多个隐藏层被组合成用于蒸馏到学生206的多个隐藏层的多组隐藏层。在所示的实施例中,教师204的第三隐藏层被组合用于蒸馏学生206的第一隐藏层和第二隐藏层两者。除了本文描述的四种组合策略之外,还可以应用任何其它合适的组合策略。CKD中的组合策略的类型可以在从教师204的隐藏层的不同配置中蒸馏方面提供灵活性。组合策略可以手动确定(即,不具有注意力),或由机器学习系统200自动确定(即,具有注意力)。
图4示出了根据本公开的不具有注意力的示例知识蒸馏方法400的流程图。
在步骤402处,在初始化损失值之后,如通过将它们设置为0(例如),预定义的层映射函数M()用于确定教师204的哪些隐藏层待被组合并与学生206的哪些隐藏层相关联。在此,组合层意味着组合由这些隐藏层生成的输出。
在步骤406处,学生204和教师206均被提供有相同的输入元组202,并且部分地基于基于元组202的输入和教师204的隐藏层的输出的教师204的推断的(即预测的)输出来计算最终损失值。
图5示出了在步骤406处确定最终损失值的示例方法500的流程图。
在步骤502处,例如根据等式(4),针对教师204的隐藏层中的每一层计算均方误差(MSE)损失(也被称为)。应当理解,基于MSE损失的LCKD是该损失函数的许多可能实现之一,并且可以使用任何有效的矩阵范数来代替MSE。
在步骤508处,根据等式(5),将最终损失计算为在步骤502处、504处和506处确定的损失的加权值:
其中α,β,和η是示出每个损失贡献的系数。
返回参考图4,在步骤408处,相对于损失值的梯度被计算并被用于更新学生206的对推断的(即预测的)输出y有贡献的所有参数。
在步骤410处,所有损失值再次被初始化回其相应的初始值,如零,并重复步骤404至410以进行进一步迭代,直到满足完成条件,如损失值降到可接受的阈值以下。图6示出了方法400的部分的示例伪代码。
在一些情况下,手动定义的组合策略可能不呈现针对教师204的不同隐藏层定义功能M(i)的最优组合方法。因此,本公开的一些其他实施例提供了具有注意力的增强CKD方法,其可以自动定义用于组合教师204的隐藏层的最优策略。在至少一个方面,基于注意力的KD解决了搜索待组合的教师204的隐藏层的代价高昂的问题。具体地,PKD和其它类似的方法可能需要训练几个学生206以搜索在训练期间应跳过的教师204的隐藏层,并且在教师204是深度神经网络模型的情况下,(如果有的话)寻找最佳/最优解决方案可能是耗时的。
在一些实施例中,可以在学生206的相应的隐藏层中的每个隐藏层(HS)和教师204的隐藏层(HT)之间学习注意力权重。每个注意力权重可以是有多少教师204的隐藏层对学生206的给定隐藏层的知识蒸馏过程有贡献的指示。然后,机器学习系统200可以优化注意力权重,以试图实现教师204的隐藏层和学生206的隐藏层之间的最优知识蒸馏。
图7示出了具有注意力的增强CKD的高级示意性架构,或本文可以被标注为CKD*。如图7所示出的,组合策略可能不需要经由M()手动定义用于CKD*中的知识蒸馏,而是CKD*将教师204的所有隐藏层(至)考虑在内,并将注意力权重(ε11、ε12、ε13、ε14、和ε15)分配到教师204的隐藏层(至)中的每一层。注意力权重指示由教师204的隐藏层生成的特定输出的贡献量,该输出在蒸馏期间待用于学生206的给定隐藏层。教师204的每个隐藏层的输出张量可以应用其对应的注意力权重,以计算教师204的所有隐藏层的输出的加权平均。加权平均是可以连接到学生206的隐藏层以用于在其之间进行知识蒸馏的张量。总损失根据等式(6)表示:
其中MSE损失是损失函数的许多可能实现之一,并且也可以使用任何其它合适的损失函数,如KL散度。fi *T是针对教师204的第i个隐藏层的教师204的隐藏层(HT)的基于注意力的组合表征,并且,针对学生206和教师204具有相同维度的实施例,fi *T可以根据等式7A确定:
其中,是教师204的第i个隐藏层的权重值,指示教师204的不同隐藏层在知识传输过程中对学生206的每个特定隐藏层的贡献量,以及因此相对重要性。注意力权重(∈ij)总和应为1,并可以根据等式(8A)计算:
基于点积的能量函数可以允许待通过附加映射层处理的两个隐藏层的输出之间的任何潜在维度失配,如下文更详细地描述的。
图8示出了根据本公开的具有注意力的示例知识蒸馏方法800的流程图。
在步骤802处,初始化损失值,如通过将其设置为0(例如)。与方法400不同,层映射函数M()不需要被明确地定义。相反,针对学生206的每个隐藏层,权重值被分配给教师204的隐藏层的子集。在一些实施例中,教师204的隐藏层的子集可以包括教师204的所有隐藏层。从概念上讲,权重值可以用作教师204的隐藏层到学生206的隐藏层中的每一层的隐式映射。
在步骤806处,学生206和教师204均被提供有相同元组202的输入向量(x),并且例如根据方法500,部分地基于教师和学生206的基于元组202的输入向量(x)的推断的(即预测的)输出计算最终CKD*损失值。
在步骤808处,计算相对于损失值的梯度,并且计算的梯度用于更新学生206的所有参数。
在步骤810处,所有损失值再次初始化回其相应的初始值,如零,并重复步骤804至810进行进一步迭代,直到满足完成条件,如损失值降到可接受的阈值以下。
应当理解,尽管本文描述了利用教师204的所有隐藏层的加权平均将知识转移到学生206的每个隐藏层的CKD*的实施例,但CKD*的其它实施例可以利用教师204的部分隐藏层的加权平均。另外,可以存在重叠的教师的隐藏层的子集,这些隐藏层被组合用于学生206的不同隐藏层。例如,教师204的隐藏层中的一半可以用于在没有重叠的情况下组合学生206的隐藏层中的一半。可替代地,教师204的隐藏层中的三分之二(即教师204的隐藏层中的第一1/3隐藏层和第二1/3隐藏层)可以用于组合学生的隐藏层中的一半,并且教师204的隐藏层中的另外三分之二(第二1/3隐藏层和第三1/3隐藏层)可以用于在部分重叠的情况下组合学生206的隐藏层中的另一半。
有利地,上述CKD*方法可以实现自动教师隐藏层映射选择,其可以实现最优知识转移。在一些实施例中,CKD*方法可以在BERT和神经机器翻译模型的性能方面进行改进。图9示出了表格,示出在执行本领域已知的通用语言理解评估(GLUE)基准时各种KD模型的模拟结果,包括具有105,000个数据点的问题自然语言推断(question natural languageinference,QNLI)、具有3,700个数据点的微软研究释义语料库(microsoft researchparaphrase corpus,MRPC)和具有2,500个数据点的识别文本蕴涵(recognizing textualentailment,RTE)。在图9所示出的表格中,“T”表示教师,其是基于BERT的模型,具有12个隐藏层,12个注意力头,并且隐藏尺寸为768。所有学生206都是BERT_4模型,具有4个隐藏层,12个注意力头,并且隐藏尺寸为768。在学生206中,“NKD”表示无KD,“KD”表示常规的现有技术KD。表格的最后三列分别是应用了具有无注意力重叠(即T[1,2,3,4]->S1、T[5,6,7,8]->S2、T[9,10,11,12]->S3)、部分注意力重叠(即T[1,2,3,4,5]->S1、T[5,6,7,8,9]->S2、T[9,10,11,12]->S3),和完全注意力重叠(即,教师204的所有12个隐藏层用于梳理学生206的每个隐藏层)的CKD*的学生。如图所示,在QNLI基准中,具有无注意力重叠的CKD*实现了评分为87.11,优于其余学生,在RTE基准中,具有完全注意力重叠的CKD*实现了评分为67.15,优于所有其他学生206。在MRPC基准中,具有完全注意力重叠的CKD*评分为80.72,优于所有其他学生。
在一些其他实施例中,本文描述的基于注意力的CKD*方法也可以应用于多教师/多任务KD场景。具体地,CKD*方法可以被应用来组合来自不同教师204的不同隐藏层,以将知识蒸馏到一个学生206中,而不是通过注意力来组合单个教师204的不同隐藏层。至少一些现有技术KD方法在不同的训练迭代中迭代多个教师204,并独立地考虑它们,如Clark,K.、Luong,M.T.、Khandelwal,U.、Manning,C.D.和Le,Q.V.的“瓶颈注意力模块!用于自然语言理解的再生多任务网络(Bam!born-again multi-task networks for naturallanguage understanding)(2019)”中公开的。例如,用K个不同的训练数据集 训练K个不同的教师204,其中Nq指定第q教师204的训练数据集中的训练数据样本的数量。上述用于训练具有多个教师204的学生206的现有技术方案,并且特别是第q教师Tq和学生S之间的KD损失可以在数学上表征如下:
其中LKD可以是任何损失函数,如库尔贝克·莱布勒(Kullback–Leibler,KL)散度或均方误差。为了计算训练数据集的数据样本被发送到第q教师Tq和学生S,以获得其相应的推断的(即预测的)输出。
对于根据本公开的实施例,上述多教师根据等式(10)通过CKD/CKD*方法扩展:
其中,每个教师Tp的隐藏层的权重值∈pq可以根据等式(11A)确定:
在一些实施例中,Φ(.)函数是两个输入向量(x)用于测量两个输入向量(x)之间的能量或相似性的点积,如可以根据等式11(B)计算的:
在一些其他实施例中,该Φ(.)函数是神经网络或用于测量两个输入向量(x)的能量的任何其它合适的函数。
图10示出了根据本公开的用于从多个教师204到一个学生206的知识蒸馏的示例方法1000的流程图。
在步骤1002处,初始化每个教师204的参数,包括每个教师204的隐藏层的注意力权重。
其中LKD可以是任何合适的损失函数,包括KL散度或均方误差(mean squarederror,MSE)。为了计算将训练数据集的数据样本作为输入提供给第q教师204和学生206,以获得第q教师204和学生206的推断的(即预测的)输出。
在步骤1006处,针对每个训练数据集q,计算所有K名教师204的注意力权重(即,{∈1q,∈2q,…,∈Kq}的集合,其中)。图11示出了可以在步骤1006处实现的用于计算注意力权重的示例方法1100的流程图。
返回参考图10,在步骤1008处,针对所有K名教师204的权重KD损失被总计为教师204在每个时间步骤处的总KD损失。应当理解,尽管本文描述的实施例包括损失,但在一些其它实施例中,取决于每个问题的设计,训练除的损失函数也是适用的。
在步骤1010处,计算相对于教师204的参数的总KD损失梯度,并更新学生206的参数。
在一些实施例中,方法1000将如本文描述的CKD*扩展到多个教师知识蒸馏场景。这可以允许多任务蒸馏、多语言蒸馏、多检查点蒸馏和任何其它期望从多个教师蒸馏到单个学生的应用。
参考图12,可以用于实现本文公开的实施例的示例简化处理系统1200的框图,并提供了更高级别的实现示例。教师204和学生206以及包括在机器学习系统200中的其它功能可以使用示例处理系统1200或处理系统1200的变体来实现。处理系统1200可以是终端,例如桌面终端、平板电脑、笔记本电脑、AR/VR或车载终端,或可以是服务器、云端或任何合适的处理系统。可以使用适合于实现本公开中描述的实施例的其它处理系统,这些系统可以包括与下面讨论的那些组件不同的组件。虽然图12示出了每个组件的单个实例,但是在处理系统1200中可能存在每个组件的多个实例。
处理系统1200可以包括一个或多个处理设备1202,如图形处理单元、处理器、微处理器、专用集成电路(application-specific integrated circuit,ASIC)、现场可编程门阵列(field-programmable gate array,FPGA)、专用逻辑电路、加速器、张量处理单元(tensor processing unit,TPU)、神经处理单元(neural processing unit,NPU),或它们的组合。处理系统1200还可以包括一个或多个输入/输出(input/output,I/O)接口1204,其可以实现与一个或多个适当的输入设备1214和/或输出设备1216介接。处理系统1200可以包括一个或多个网络接口1206用于与网络进行有线通信或无线通信。
处理系统1200还可以包括一个或多个存储单元1208,该一个或多个存储单元可以包括如固态驱动器、硬盘驱动器、磁盘驱动器和/或光盘驱动器等大容量存储单元。处理系统1200可以包括一个或多个存储器1210,该一个或多个存储器可以包括易失性或非易失性存储器(例如,闪存、随机存取存储器(random access memory,RAM)和/或只读存储器(read-only memory,ROM))。存储器1210的非瞬时性存储器可以存储用于由处理设备1202执行的指令,如用于执行本公开中描述的示例,例如用于机器学习系统200的CKD/CKD*指令和数据1212。存储器1210可以包括其它软件指令,如用于实现处理系统1200和其它应用/功能的操作系统。在一些示例中,一个或多个数据集和/或模块可以由外部存储器(例如,与处理系统1200进行有线通信或无线通信的外部驱动器)提供,或可以由瞬时性或非瞬时性计算机可读介质提供。非瞬时性计算机可读介质的示例包括RAM、ROM、可擦除可编程ROM(EPROM)、电可擦除可编程ROM(electrically erasable programmable ROM,EEPROM)、闪存、CD-ROM或其它便携式存储装置。
处理系统1200还可以包括总线1218,该总线提供处理系统1200的组件之间的通信,包括处理设备1202、I/O接口1204、网络接口1206、存储单元1208和/或存储器1210。总线1218可以是任何合适的总线架构,例如包括存储器总线、外围总线或视频总线。
教师204和学生206的计算可以由处理系统1200的任何合适的处理设备1202或其变体执行。此外,教师204和学生206可以是任何合适的神经网络模型,包括如递归神经网络模型、长短期记忆(long short-term memory,LSTM)神经网络模型等变体。
综述
尽管本公开可以描述具有一定顺序的步骤的方法和过程,但是可以适当地省略或改变方法和过程中的一个或多个步骤。在适当情况下,一个或多个步骤可以按所描述的顺序以外的顺序执行。
尽管本公开可以在方法方面至少部分地进行描述,但本领域普通技术人员将理解,本公开还针对用于执行所描述方法的至少一些方面和特征的各种组件,无论是通过硬件组件、软件或这两者的任何组合。相应地,本公开的技术方案可以以软件产品的形式体现。合适的软件产品可以存储在预先记录的存储设备或其它类似的非易失性或非瞬时性计算机可读介质中,例如包括DVD、CD-ROM、USB闪存盘、可移动硬盘或其它存储介质。软件产品包括有形地存储在其上的指令,这些指令使得处理设备(例如,个人计算机、服务器或网络设备)能够执行本文中公开的方法的示例。
在不脱离权利要求书的主题的前提下,本公开可以通过其它特定形式来体现。所描述的示例实施例在所有方面都应被视为仅是示意性的,而非限制性的。可以组合从一个或多个上述实施例中选择的特征,以创建未明确描述的可选实施例,在本公开的范围内可以理解适合于此类组合的特征。
还公开了所公开范围内的所有值和子范围。此外,尽管本文所公开和示出的系统、设备和过程可以包括特定数量的元件/组件,但可以修改这些系统、设备和组件,以包括更多或更少此类元件/组件。例如,尽管所公开的任何元件/组件可以引用为单数,但是可以修改本文所公开的实施例以包括多个此类元件/组件。本文描述的主题旨在覆盖和涵盖所有适当的技术变更。
Claims (22)
1.一种从具有多个教师隐藏层的教师神经模型向具有多个学生隐藏层的学生神经模型进行知识蒸馏的方法,所述方法包括:
训练所述教师神经模型,其中所述教师神经模型被配置为接收输入并生成教师输出;以及
在多个训练输入上训练所述学生神经模型,其中所述学生神经模型也被配置为接收输入并生成对应的输出,包括:
使用所述教师神经模型处理每个训练输入,以针对所述训练输入生成所述教师输出,所述多个教师隐藏层中的每一层生成教师隐藏层输出;
将所述多个教师隐藏层的子集映射到所述多个学生隐藏层中的每一层;
计算映射到所述多个学生隐藏层中的每一层的所述多个教师隐藏层的子集的所述教师隐藏层输出的表征;以及
训练所述学生以针对所述训练输入中的每一个生成近似于针对所述训练输入的所述教师输出的学生输出,其中,针对所述训练输入中的每一个,所述多个学生隐藏层中的每一层被训练为生成学生隐藏层输出,所述学生隐藏层输出近似于映射到所述多个学生隐藏层中的每一层的所述多个教师隐藏层的子集的所述表征。
2.根据权利要求1所述的方法,还包括训练所述学生,以针对所述训练输入中的每一个,生成近似于所述训练输入的真值的学生输出。
3.根据权利要求2所述的方法,其中训练所述学生以生成近似于针对所述训练输入的所述教师输出的学生输出还包括:
计算所述学生输出和所述教师输出之间的知识蒸馏KD损失;
计算所述学生输出和所述真值之间的标准损失;
计算所述多个学生隐藏层中的每一层和映射到所述多个学生层中的每一层的所述教师隐藏层的子集之间的组合KD CKD损失;
将总损失计算为所述KD损失、所述标准损失和所述CKD损失的加权平均;以及
调整所述学生的参数,以最小化所述总损失。
5.根据权利要求4所述的方法,其中所述融合函数F()包括级联操作,后跟线性映射层。
7.根据权利要求4所述的方法,其中所述映射函数M()定义用于组合所述教师的隐藏层的组合策略。
8.根据权利要求7所述的方法,其中所述组合策略为重叠组合、常规组合、跳过组合和交叉组合中的任一种。
9.根据权利要求1至8中任一项所述的方法,其中所述映射还包括,定义用于将所述教师隐藏层映射到所述多个学生隐藏层中的每一层的组合策略。
10.根据权利要求9所述的方法,其中所述组合策略为重叠组合、常规组合、跳过组合和交叉组合中的任一种。
11.根据权利要求1至10中任一项所述的方法,其中所述映射还包括,针对所述学生隐藏层中的每一层,将注意力权重分配给所述多个教师隐藏层的子集中的每一个的所述教师隐藏层输出。
20.一种从各自具有多个教师隐藏层的多个教师神经模型向具有多个学生隐藏层的学生神经模型进行知识蒸馏的方法,所述方法包括:
推断所述多个教师神经模型,其中所述多个教师神经模型中的每一个被配置为接收输入并生成教师输出;以及
在多个训练输入上训练所述学生神经模型,其中所述学生神经模型也被配置为接收输入并生成学生输出,包括:
使用所述多个教师神经模型处理每个训练输入,以针对所述训练输入生成多个教师输出,所述多个教师神经模型中的每一个的多个教师隐藏层中的每一层生成教师隐藏层输出;
将所述多个教师神经模型的所述多个教师隐藏层的子集映射到所述多个学生隐藏层中的每一层;
计算所述多个教师隐藏层的子集的教师隐藏层输出的表征,所述多个教师隐藏层的子集映射到所述多个学生隐藏层中的每一层;以及
训练所述学生以针对所述训练输入中的每一个生成近似于针对所述训练输入的所述教师输出的学生输出,其中所述多个学生隐藏层中的每一层针对所述训练输入中的每一个,被训练为生成学生隐藏层输出,所述学生隐藏层输出近似于映射到所述多个学生隐藏层中的每一层的所述多个教师隐藏层的子集的所述表征。
21.一种处理系统,包括:
处理设备;
存储器,所述存储器存储有指令,所述指令当由所述处理设备执行时,使所述处理系统执行根据权利要求1至20中任一项所述的方法。
22.一种计算机可读介质,包括指令,所述指令当由处理系统的处理设备执行时,使所述处理系统执行根据权利要求1至20中任一项所述的方法。
Applications Claiming Priority (5)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
US202063076335P | 2020-09-09 | 2020-09-09 | |
US63/076,335 | 2020-09-09 | ||
US17/469,573 | 2021-09-08 | ||
US17/469,573 US20220076136A1 (en) | 2020-09-09 | 2021-09-08 | Method and system for training a neural network model using knowledge distillation |
PCT/CN2021/117532 WO2022052997A1 (en) | 2020-09-09 | 2021-09-09 | Method and system for training neural network model using knowledge distillation |
Publications (1)
Publication Number | Publication Date |
---|---|
CN116134454A true CN116134454A (zh) | 2023-05-16 |
Family
ID=80469799
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202180054067.1A Pending CN116134454A (zh) | 2020-09-09 | 2021-09-09 | 用于使用知识蒸馏训练神经网络模型的方法和系统 |
Country Status (3)
Country | Link |
---|---|
US (1) | US20220076136A1 (zh) |
CN (1) | CN116134454A (zh) |
WO (1) | WO2022052997A1 (zh) |
Families Citing this family (13)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US11599794B1 (en) * | 2021-10-20 | 2023-03-07 | Moffett International Co., Limited | System and method for training sample generator with few-shot learning |
CN114663941B (zh) * | 2022-03-17 | 2024-09-06 | 深圳数联天下智能科技有限公司 | 特征检测方法、模型合并方法、设备和介质 |
WO2023203775A1 (ja) * | 2022-04-22 | 2023-10-26 | 株式会社ソシオネクスト | ニューラルネットワーク生成方法 |
CN114913400B (zh) * | 2022-05-25 | 2024-07-16 | 天津大学 | 基于知识蒸馏的海洋大数据协同表示的预警方法及装置 |
CN115082690B (zh) * | 2022-07-12 | 2023-03-28 | 北京百度网讯科技有限公司 | 目标识别方法、目标识别模型训练方法及装置 |
CN114998570B (zh) * | 2022-07-19 | 2023-03-28 | 上海闪马智能科技有限公司 | 一种对象检测框的确定方法、装置、存储介质及电子装置 |
CN115223049B (zh) * | 2022-09-20 | 2022-12-13 | 山东大学 | 面向电力场景边缘计算大模型压缩的知识蒸馏与量化方法 |
CN115271272B (zh) * | 2022-09-29 | 2022-12-27 | 华东交通大学 | 多阶特征优化与混合型知识蒸馏的点击率预测方法与系统 |
CN115471645A (zh) * | 2022-11-15 | 2022-12-13 | 南京信息工程大学 | 一种基于u型学生网络的知识蒸馏异常检测方法 |
CN116416212B (zh) * | 2023-02-03 | 2023-12-08 | 中国公路工程咨询集团有限公司 | 路面破损检测神经网络训练方法及路面破损检测神经网络 |
CN116028891B (zh) * | 2023-02-16 | 2023-07-14 | 之江实验室 | 一种基于多模型融合的工业异常检测模型训练方法和装置 |
CN116205290B (zh) * | 2023-05-06 | 2023-09-15 | 之江实验室 | 一种基于中间特征知识融合的知识蒸馏方法和装置 |
CN117315516B (zh) * | 2023-11-30 | 2024-02-27 | 华侨大学 | 基于多尺度注意力相似化蒸馏的无人机检测方法及装置 |
Family Cites Families (6)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
KR102492318B1 (ko) * | 2015-09-18 | 2023-01-26 | 삼성전자주식회사 | 모델 학습 방법 및 장치, 및 데이터 인식 방법 |
US11410029B2 (en) * | 2018-01-02 | 2022-08-09 | International Business Machines Corporation | Soft label generation for knowledge distillation |
CN109165738B (zh) * | 2018-09-19 | 2021-09-14 | 北京市商汤科技开发有限公司 | 神经网络模型的优化方法及装置、电子设备和存储介质 |
CN111105008A (zh) * | 2018-10-29 | 2020-05-05 | 富士通株式会社 | 模型训练方法、数据识别方法和数据识别装置 |
CN111242297A (zh) * | 2019-12-19 | 2020-06-05 | 北京迈格威科技有限公司 | 基于知识蒸馏的模型训练方法、图像处理方法及装置 |
CN111611377B (zh) * | 2020-04-22 | 2021-10-29 | 淮阴工学院 | 基于知识蒸馏的多层神经网络语言模型训练方法与装置 |
-
2021
- 2021-09-08 US US17/469,573 patent/US20220076136A1/en active Pending
- 2021-09-09 WO PCT/CN2021/117532 patent/WO2022052997A1/en active Application Filing
- 2021-09-09 CN CN202180054067.1A patent/CN116134454A/zh active Pending
Also Published As
Publication number | Publication date |
---|---|
WO2022052997A1 (en) | 2022-03-17 |
US20220076136A1 (en) | 2022-03-10 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN116134454A (zh) | 用于使用知识蒸馏训练神经网络模型的方法和系统 | |
CN111291183B (zh) | 利用文本分类模型进行分类预测的方法及装置 | |
US11755885B2 (en) | Joint learning of local and global features for entity linking via neural networks | |
US10803591B2 (en) | 3D segmentation with exponential logarithmic loss for highly unbalanced object sizes | |
US8234228B2 (en) | Method for training a learning machine having a deep multi-layered network with labeled and unlabeled training data | |
EP3029606A2 (en) | Method and apparatus for image classification with joint feature adaptation and classifier learning | |
US11537930B2 (en) | Information processing device, information processing method, and program | |
CN116171446A (zh) | 通过对抗学习和知识蒸馏训练神经网络模型的方法及系统 | |
US20230222326A1 (en) | Method and system for training a neural network model using gradual knowledge distillation | |
CN109344404A (zh) | 情境感知的双重注意力自然语言推理方法 | |
Bagherzadeh et al. | A review of various semi-supervised learning models with a deep learning and memory approach | |
CN109214006A (zh) | 图像增强的层次化语义表示的自然语言推理方法 | |
Tang et al. | Modelling student behavior using granular large scale action data from a MOOC | |
CN113837370A (zh) | 用于训练基于对比学习的模型的方法和装置 | |
Triebel et al. | Driven learning for driving: How introspection improves semantic mapping | |
CN113761868B (zh) | 文本处理方法、装置、电子设备及可读存储介质 | |
CN113609337A (zh) | 图神经网络的预训练方法、训练方法、装置、设备及介质 | |
US11610393B2 (en) | Knowledge distillation for neural networks using multiple augmentation strategies | |
Bucher et al. | Semantic bottleneck for computer vision tasks | |
CN112819024A (zh) | 模型处理方法、用户数据处理方法及装置、计算机设备 | |
CN112131345A (zh) | 文本质量的识别方法、装置、设备及存储介质 | |
CN115238169A (zh) | 一种慕课可解释推荐方法、终端设备及存储介质 | |
US11948387B2 (en) | Optimized policy-based active learning for content detection | |
CN111783473B (zh) | 医疗问答中最佳答案的识别方法、装置和计算机设备 | |
US20220138425A1 (en) | Acronym definition network |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |