CN113255763B - 基于知识蒸馏的模型训练方法、装置、终端及存储介质 - Google Patents

基于知识蒸馏的模型训练方法、装置、终端及存储介质 Download PDF

Info

Publication number
CN113255763B
CN113255763B CN202110558102.5A CN202110558102A CN113255763B CN 113255763 B CN113255763 B CN 113255763B CN 202110558102 A CN202110558102 A CN 202110558102A CN 113255763 B CN113255763 B CN 113255763B
Authority
CN
China
Prior art keywords
model
representation
output
distillation
knowledge distillation
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN202110558102.5A
Other languages
English (en)
Other versions
CN113255763A (zh
Inventor
于凤英
王健宗
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Ping An Technology Shenzhen Co Ltd
Original Assignee
Ping An Technology Shenzhen Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Ping An Technology Shenzhen Co Ltd filed Critical Ping An Technology Shenzhen Co Ltd
Priority to CN202110558102.5A priority Critical patent/CN113255763B/zh
Publication of CN113255763A publication Critical patent/CN113255763A/zh
Application granted granted Critical
Publication of CN113255763B publication Critical patent/CN113255763B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • Evolutionary Computation (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • Computing Systems (AREA)
  • Software Systems (AREA)
  • Molecular Biology (AREA)
  • Computational Linguistics (AREA)
  • Biophysics (AREA)
  • Biomedical Technology (AREA)
  • Mathematical Physics (AREA)
  • General Health & Medical Sciences (AREA)
  • Health & Medical Sciences (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Evolutionary Biology (AREA)
  • Feedback Control In General (AREA)
  • Image Analysis (AREA)

Abstract

本发明公开了一种基于知识蒸馏的模型训练方法、装置、终端及存储介质,其中方法包括:获取预先训练好的第一模型、待训练的第二模型和训练样本集,第一模型对应的模态与第二模型对应的模态不同;将训练样本输入至第一模型和第二模型,得到第一模型嵌入层的第一表示和第一模型中间层的第一输出、以及第二模型嵌入层的第二表示和第二模型中间层的第二输出;利用第一表示和第二表示进行对比自监督学习,再根据对比自监督学习结果更新第二模型嵌入层的参数,利用第一输出和第二输出进行知识蒸馏,再根据知识蒸馏结果更新第二模型中间层的参数,得到训练好的第二模型。通过上述方式,本发明能够针对不同模态的模型进行知识蒸馏,以快速完成模型的训练。

Description

基于知识蒸馏的模型训练方法、装置、终端及存储介质
技术领域
本申请涉及人工智能技术领域,特别是涉及一种基于知识蒸馏的模型训练方法、装置、终端及存储介质。
背景技术
随着人工智能识别的发展,普遍采用模型进行数据处理、图像识别等。通常地,对于不同应用场景有定制化模型需求的时候,技术人员选择的模型训练方式大致有两种:一、使用通用数据集训练好的通用模型修改最后输出层的类别数量,然后使用自己的数据集对模型参数进行重新调整;二、自己设计结构简单的模型,使用自己的数据集从头训练模型参数。其中,前一种方法使用已训练好的模型参数继续训练,能够加快训练收敛,也能保证模型精度,但是模型较大参数众多,对于小分类任务来说“大材小用”,而且不利于部署在计算力有限的终端设备上。第二种方法可以定制结构简单参数较少的小模型,但是模型参数需要从头训练,势必会减慢收敛速度,也不能保证模型的精度,而且训练数据集较小的情况下,还容易造成模型过拟合。
目前,针对于上述问题提出了一种知识蒸馏技术。知识蒸馏是一种模型压缩方法,在教师-学生框架中,将复杂、学习能力强的教师模型学到的特征表示“知识”蒸馏出来,传递给参数量小、学习能力弱的学生模型。简单的说就是用新的小模型去学习大模型的预测结果,复杂模型或者组合模型的中“知识”通过合适的方式迁移到一个相对简单模型之中,进而方便模型推广部署。
但是,现有的知识蒸馏均是针对于相同模态的教师模型和学生模型,而无法实现不同模态模型之间的知识蒸馏。
发明内容
本申请提供一种基于知识蒸馏的模型训练方法、装置、终端及存储介质,以解决现有的知识蒸馏技术无法实现不同模态模型之间的知识蒸馏的问题。
为解决上述技术问题,本申请采用的一个技术方案是:提供一种基于知识蒸馏的模型训练方法,包括:获取预先训练好的第一模型、待训练的第二模型和训练样本集,第一模型对应第一模态,第二模型对应第二模态,第一模态与第二模态不同;将训练样本分别输入至第一模型和第二模型,得到第一模型嵌入层的第一表示和第一模型中间层的第一输出、以及第二模型嵌入层的第二表示和第二模型中间层的第二输出;利用第一表示和第二表示进行对比自监督学习,得到对比自监督学习结果;利用第一输出和第二输出进行知识蒸馏,得到知识蒸馏结果;根据对比自监督学习结果更新第二模型嵌入层的参数,并根据知识蒸馏结果更新第二模型中间层的参数,得到训练好的第二模型。
作为本申请的进一步改进,根据对比自监督学习结果更新第二模型嵌入层的参数,包括:基于对比自监督学习结果计算第一表示与第二表示之间的互信息,互信息的计算公式为:
Figure BDA0003077901880000021
其中,lMI为所述互信息,P(p,g|c=1)为所述第一表示与所述第二表示对应于相同的答题式的条件概率,P(p,g|c=0)为所述第一表示与所述第二表示对应于不同的答题式的条件概率,E表示求均值,sp为softplus函数,Ptext为所述第一表示,Pgraph为所述第二表示,T(Ptext,Pgraph)是对比自监督学习的结果;
根据互信息更新第二模型嵌入层的参数。
作为本申请的进一步改进,利用第一输出和第二输出进行知识蒸馏,得到知识蒸馏结果,包括:计算第一输出与第二输出之间的蒸馏损失值,蒸馏损失值的计算公式为:
Figure BDA0003077901880000031
其中,lKD为蒸馏损失值,T为预先设定的softmax函数的超参数,KL指KL散度计算,σ为softmax函数,ztext为第一输出,zgraph为第二输出;
将蒸馏损失值作为知识蒸馏结果。
作为本申请的进一步改进,获取预先训练好的第一模型、待训练的第二模型和训练样本集之后,还包括:将训练样本集转换为符合第一模态的第一样本集、以及符合第二模态的第二样本集。
作为本申请的进一步改进,将训练样本分别输入至第一模型和第二模型,包括:将第一样本集中的目标第一样本输入至第一模型;将第二样本集中与目标第一样本对应的目标第二样本输入至第二模型。
作为本申请的进一步改进,第一模型为图神经网络分类模型,第二模型为文本分类模型,文本分类模型用于对输入文本中的实体之间的关系进行推理。
为解决上述技术问题,本申请采用的另一个技术方案是:提供一种基于知识蒸馏的模型训练装置,包括:获取模块,用于获取预先训练好的第一模型、待训练的第二模型和训练样本集,第一模型对应第一模态,第二模型对应第二模态,第一模态与第二模态不同;输入模块,用于将训练样本分别输入至第一模型和第二模型,得到第一模型嵌入层的第一表示和第一模型中间层的第一输出、以及第二模型嵌入层的第二表示和第二模型中间层的第二输出;学习模块,用于利用第一表示和第二表示进行对比自监督学习,得到对比自监督学习结果;蒸馏模块,用于利用第一输出和第二输出进行知识蒸馏,得到知识蒸馏结果;更新模块,用于根据对比自监督学习结果更新第二模型嵌入层的参数,并根据知识蒸馏结果更新第二模型中间层的参数,得到训练好的第二模型。
作为本申请的进一步改进,装置还包括:转换模块,用于将训练样本集转换为符合第一模态的第一样本集、以及符合第二模态的第二样本集。
为解决上述技术问题,本申请采用的再一个技术方案是:提供一种终端,该终端包括处理器、与处理器耦接的存储器,存储器中存储有程序指令,程序指令被处理器执行时,使得处理器执行如上述中任一项基于知识蒸馏的模型训练方法的步骤。
为解决上述技术问题,本申请采用的再一个技术方案是:提供一种存储介质,存储有能够实现上述基于知识蒸馏的模型训练方法的程序文件。
本申请的有益效果是:本申请的基于知识蒸馏的模型训练方法通过在基于知识蒸馏来训练跨模态的模型时,获取第一模型和第二模型各自在嵌入层的表示,以及中间层的输出,再根据两者在嵌入层的表示和中间层的输出对第二模型嵌入层和中间层的参数进行调整,以完成对第二模型的训练,其利用嵌入层的表示来调整参数,从而使得不同模态的第一模型、第二模型之间的潜在空间得以对齐,完成了跨模态模型之间的知识转移。
附图说明
图1是本发明第一实施例的基于知识蒸馏的模型训练方法的流程示意图;
图2是本发明第二实施例的基于知识蒸馏的模型训练方法的流程示意图;
图3是本发明第一实施例的基于知识蒸馏的模型训练装置的功能模块示意图;
图4是本发明第二实施例的基于知识蒸馏的模型训练装置的功能模块示意图;
图5是本发明实施例的终端的结构示意图;
图6是本发明实施例的存储介质的结构示意图。
具体实施方式
下面将结合本申请实施例中的附图,对本申请实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅是本申请的一部分实施例,而不是全部的实施例。基于本申请中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本申请保护的范围。
本申请中的术语“第一”、“第二”、“第三”仅用于描述目的,而不能理解为指示或暗示相对重要性或者隐含指明所指示的技术特征的数量。由此,限定有“第一”、“第二”、“第三”的特征可以明示或者隐含地包括至少一个该特征。本申请的描述中,“多个”的含义是至少两个,例如两个,三个等,除非另有明确具体的限定。本申请实施例中所有方向性指示(诸如上、下、左、右、前、后……)仅用于解释在某一特定姿态(如附图所示)下各部件之间的相对位置关系、运动情况等,如果该特定姿态发生改变时,则该方向性指示也相应地随之改变。此外,术语“包括”和“具有”以及它们任何变形,意图在于覆盖不排他的包含。例如包含了一系列步骤或单元的过程、方法、系统、产品或设备没有限定于已列出的步骤或单元,而是可选地还包括没有列出的步骤或单元,或可选地还包括对于这些过程、方法、产品或设备固有的其它步骤或单元。
在本文中提及“实施例”意味着,结合实施例描述的特定特征、结构或特性可以包含在本申请的至少一个实施例中。在说明书中的各个位置出现该短语并不一定均是指相同的实施例,也不是与其它实施例互斥的独立的或备选的实施例。本领域技术人员显式地和隐式地理解的是,本文所描述的实施例可以与其它实施例相结合。
图1是本发明第一实施例的基于知识蒸馏的模型训练方法的流程示意图。需注意的是,若有实质上相同的结果,本发明的方法并不以图1所示的流程顺序为限。如图1所示,该方法包括步骤:
步骤S101:获取预先训练好的第一模型、待训练的第二模型和训练样本集,第一模型对应第一模态,第二模型对应第二模态,第一模态与第二模态不同。
需要说明的是,模态是指文本、图像、视频、音频、传感器数据、3D等。本实施例中的第一模态和第二模态是两个不同的模态,第一模型对应第一模态是指第一模型所针对的输入数据是第一模态的数据,第二模型对应第二模态是指第二模型所针对的输入数据是第二模态的数据,例如,第一模型的输入数据是图像,第二模型的输入数据是文本。
在步骤S101中,该第一模型为预先训练好的模型,在知识蒸馏过程中作为教师模型,该第二模型为未经过训练的模型,在知识蒸馏过程中作为学生模型。需要理解的是,教师模型和学生模型需要是解决相同问题的模型,例如,当教师模型是分类模型是,学生模型也是分类模型。本实施例中,首先获取已经训练好的作为教师的第一模型和作为学生的第二模型,以及作为训练样本集,该训练样本集中包括多个训练样本。
优选地,本实施例中,第一模型为图神经网络分类模型,第二模型为文本分类模型,文本分类模型用于对输入文本中的实体之间的关系进行推理。
具体地,当第二模型训练好之后,若输入第二模型的输入文本中包括两个或两个以上的主体时,第二模型对该两个或两个以上的主体进行关系推理,得到该两个或两个以上主体相互之间的关系。例如,输入文本中包括梅琳达·盖茨和西雅图两个实体,在推理两者之间的关系时,可得到一条路径“梅琳达·盖茨-配偶-比尔·盖茨-主席-微软-总部在-西雅图”,从而推测得到梅琳达·盖茨可能居住在西雅图。
步骤S102:将训练样本分别输入至第一模型和第二模型,得到第一模型嵌入层的第一表示和第一模型中间层的第一输出、以及第二模型嵌入层的第二表示和第二模型中间层的第二输出。
需要说明的是,第一模型包括编码层、嵌入层、中间层和输出层,作为教师模型,该第一模型可以为参数复杂的大模型,第二模型的结构与第一模型类似,同样包括编码层、嵌入层、中间层和输出层。
在步骤S102中,首先从训练样本集中选取一个目标训练样本,将该目标训练样本输入至第一模型,然后从第一模型的嵌入层中获取第一表示,并从第一模型的中间层获取第一输出。然后,将该目标训练样本输入至第二模型,然后从第二模型的嵌入层中获取第二表示,并从第二模型的中间层获取第二输出。
步骤S103:利用第一表示和第二表示进行对比自监督学习,得到对比自监督学习结果。
需要说明的是,对于跨模态的问题,通常需要将不同模态数据嵌入到一个公共表示空间中,以便进行对齐。而在本实施例中,为了完成跨模态模型之间的知识转移,则需要对齐第一模型与第二模型之间的潜在空间和预测结果。
在步骤S103中,在得到第一模型嵌入层的第一表示和第二模型嵌入层的第二表示后,根据第一表示和第二表示对第二模型嵌入层的参数进行调整,具体通过利用第一表示和第二表示进行对比自监督学习,再根据对比自监督学习结果更新第二模型嵌入层的参数,以对齐第一模型和第二模型的潜在空间。需要说明的是,对比自监督学习是通过学习对使两种事物相似或不同的东西进行编码来构建表示。本实施例中,在获取到第一模型嵌入层的第一表示和第二模型嵌入层的第二表示后,利用第一表示和第二表示进行对比自监督学习,具体地,以同一实例的一个目标第一表示作为正对,并将同一实例中的其他第一表示作为负对,将与目标第一表示对应的第二表示作为“锚”数据点,而对比自监督学习的目的学习编辑器f:
score(f(Pgraph),f(Ptext))>>score(f(Pgraph),f(P′text));
其中,Ptext是正对的第一表示,Pgraph是正对的第二表示,P′text是负对的第一表示,score函数是一个度量两个特征之间相似性的指标。
步骤S104:利用第一输出和第二输出进行知识蒸馏,得到知识蒸馏结果。
在步骤S104中,在得到第一模型中间层的第一输出、第二模型中间层的第二输出后,根据第一输出与第二输出对第二模型的中间层的参数进行调整,具体通过利用第一输出和第二输出进行知识蒸馏,再根据知识蒸馏结果更新第二模型中间层的参数,以对齐第一模型和第二模型的预测结果。
具体地,在得到第一模型中间层的第一输出和第二模型中间层的第二输出后,基于第一输出和第二输出进行知识蒸馏操作,以将第一模型中间层的参数映射至第二模型的中间层。
其中,利用第一输出和第二输出进行知识蒸馏,得到知识蒸馏结果,具体包括:
(1)计算第一输出与第二输出之间的蒸馏损失值。
其中,蒸馏损失值的计算公式为:
Figure BDA0003077901880000081
其中,lKD为蒸馏损失值,T为预先设定的softmax函数的超参数,KL指KL散度计算,σ为softmax函数,ztext为第一输出,zgraph为第二输出。
(2)将蒸馏损失值作为知识蒸馏结果。
步骤S105:根据对比自监督学习结果更新第二模型嵌入层的参数,并根据知识蒸馏结果更新第二模型中间层的参数,得到训练好的第二模型。
在步骤S105中,在得到对比自监督学习结果和知识蒸馏结果后,根据对比自监督学习结果更新第二模型嵌入层的参数,同时根据知识蒸馏结果更新第二模型中间层的参数。
其中,根据对比自监督学习结果更新第二模型嵌入层的参数具体包括:
(1)基于对比自监督学习结果计算第一表示与第二表示之间的互信息。
其中,互信息的计算公式为:
Figure BDA0003077901880000082
其中,lMI为互信息,P(p,g|c=1)为第一表示与第二表示对应于相同的答题式的条件概率,P(p,g|c=0)为第一表示与第二表示对应于不同的答题式的条件概率,E表示求均值,sp为softplus函数,Ptext为第一表示,Pgraph为第二表示,T(Ptext,Pgraph)是对比自监督学习的结果。
(2)根据互信息更新第二模型嵌入层的参数。
本实施例中,在计算得到第一表示和第二表示之间的互信息后,基于该互信息更新第二模型嵌入层的参数,从而以最大限度地提高第一模型嵌入层的表示与第二模型嵌入层的表示之间的互信息。在得到蒸馏损失值后,根据该蒸馏损失值更新第二模型中间层的参数,目的是使得第一模型中间层的输出与第二模型中间层的输出越来越小。
需要理解的是,训练样本集中包括多个训练样本。本实施例中,首先将训练样本集中的第一个训练样本输入至第一模型和第二模型中,得到第一模型嵌入层的第一表示和中间层的第一输出、以及第二模型嵌入层的第二表示和中间层的第二输出,再基于第一表示、第二表示、第一输出、第二输出调整第二模型的嵌入层和中间层的参数,然后再从训练样本集中选出第二个训练样本并入职至第一模型和第二模型中,以对第二模型中的嵌入层和中间层的参数再次进行调整,逐个利用训练样本集中的训练样本对第二模型进行训练,直至第二模型的精度达到预设精度要求,或者是所有的训练样本被训练完,得到最终训练好的第二模型。
本发明第一实施例的基于知识蒸馏的模型训练方法通过在基于知识蒸馏来训练跨模态的模型时,获取第一模型和第二模型各自在嵌入层的表示,以及中间层的输出,再根据两者在嵌入层的表示和中间层的输出对第二模型嵌入层和中间层的参数进行调整,以完成对第二模型的训练,其利用嵌入层的表示来调整参数,从而使得不同模态的第一模型、第二模型之间的潜在空间得以对齐,完成了跨模态模型之间的知识转移。
图2是本发明第二实施例的基于知识蒸馏的模型训练方法的流程示意图。需注意的是,若有实质上相同的结果,本发明的方法并不以图2所示的流程顺序为限。如图2所示,该方法包括步骤:
步骤S201:获取预先训练好的第一模型、待训练的第二模型和训练样本集,第一模型对应第一模态,第二模型对应第二模态,第一模态与第二模态不同。
在本实施例中,图2中的步骤S201和图1中的步骤S101类似,为简约起见,在此不再赘述。
步骤S202:将训练样本集转换为符合第一模态的第一样本集、以及符合第二模态的第二样本集。
在步骤S202中,需要理解的是,本实施例的第一模型和第二模型所针对的输入数据的模态不相同,因此,在输入样本数据进行训练之前,需要将训练样本转换为符合各自模态的样本集。因此,在得到训练样本集后,对训练样本集进行转换,以得到符合第一模态的第一样本集和符合第二模态的第二样本集。例如,当第一模型为图神经网络分类模型、第二模型为文本分类模型时,则训练样本集需要转换为图片格式的第一样本集和文本格式的第二样本集。
步骤S203:将训练样本分别输入至第一模型和第二模型,得到第一模型嵌入层的第一表示和第一模型中间层的第一输出、以及第二模型嵌入层的第二表示和第二模型中间层的第二输出。
具体地,将训练样本分别输入至第一模型和第二模型的步骤具体包括:
1、将第一样本集中的目标第一样本输入至第一模型。
2、将第二样本集中与目标第一样本对应的目标第二样本输入至第二模型。
具体地,将与第一模型模态对应的第一样本集输入至第一模型,且将与第二模型模态对应的第二样本集输入至第二模型。并且,第一样本集中输入第一模型的训练样本与第二样本集中输入第二模型的训练样本相互对应。
步骤S204:利用第一表示和第二表示进行对比自监督学习,得到对比自监督学习结果。
在本实施例中,图2中的步骤S204和图1中的步骤S103类似,为简约起见,在此不再赘述。
步骤S205:利用第一输出和第二输出进行知识蒸馏,得到知识蒸馏结果。
在本实施例中,图2中的步骤S205和图1中的步骤S104类似,为简约起见,在此不再赘述。
步骤S206:根据对比自监督学习结果更新第二模型嵌入层的参数,并根据知识蒸馏结果更新第二模型中间层的参数,得到训练好的第二模型
在本实施例中,图2中的步骤S206和图1中的步骤S105类似,为简约起见,在此不再赘述。
本发明第二实施例的基于知识蒸馏的模型训练方法在第一实施例的基础上,通过将训练样本转换模态,得到符合第一模型输入数据要求第一样本集和符合第二模型的输入数据要求的第二样本集,以方便输入对应模态的训练样本至模型中,完成跨模态模型的训练。
图3是本发明实施例的基于知识蒸馏的模型训练装置的功能模块示意图。如图3所示,该装置30包括获取模块31、输入模块32和学习模块33、蒸馏模块34和更新模块35。
获取模块31,用于获取预先训练好的第一模型、待训练的第二模型和训练样本集,第一模型对应第一模态,第二模型对应第二模态,第一模态与第二模态不同;
输入模块32,用于将训练样本分别输入至第一模型和第二模型,得到第一模型嵌入层的第一表示和第一模型中间层的第一输出、以及第二模型嵌入层的第二表示和第二模型中间层的第二输出;
学习模块33,用于利用第一表示和第二表示进行对比自监督学习,得到对比自监督学习结果;
蒸馏模块34,用于利用第一输出和第二输出进行知识蒸馏,得到知识蒸馏结果;
更新模块35,用于根据对比自监督学习结果更新第二模型嵌入层的参数,并根据知识蒸馏结果更新第二模型中间层的参数,得到训练好的第二模型。
可选地,更新模块35执行根据对比自监督学习结果更新第二模型嵌入层的参数的操作还可以为:
基于自对比学习结果计算第一表示与第二表示之间的互信息,互信息的计算公式为:
Figure BDA0003077901880000111
Figure BDA0003077901880000121
其中,lMI为互信息,P(p,g|c=1)为第一表示与第二表示对应于相同的答题式的条件概率,P(p,g|c=0)为第一表示与第二表示对应于不同的答题式的条件概率,E表示求均值,sp为softplus函数,Ptext为第一表示,Pgraph为第二表示,T(Ptext,Pgraph)是自对比学习的结果;
根据互信息更新第二模型嵌入层的参数。
可选地,蒸馏模块34执行利用第一输出和第二输出进行知识蒸馏,得到知识蒸馏结果的操作还可以为:
计算第一输出与第二输出之间的蒸馏损失值,蒸馏损失值的计算公式为:
Figure BDA0003077901880000122
其中,lKD为蒸馏损失值,T为预先设定的softmax函数的超参数,KL指KL散度计算,σ为softmax函数,ztext为第一输出,zgraph为第二输出;
将蒸馏损失值作为知识蒸馏结果。
可选地,如图4所示,该装置30还包括转换模块36,在获取模块31执行获取预先训练好的第一模型、待训练的第二模型和训练样本集的操作之后,转换模块36用于将训练样本集转换为符合第一模态的第一样本集、以及符合第二模态的第二样本集。
可选地,输入模块32执行将训练样本分别输入至第一模型和第二模型的操作还可以为:将第一样本集中的目标第一样本输入至第一模型;将第二样本集中与目标第一样本对应的目标第二样本输入至第二模型。
可选地,第一模型为图神经网络分类模型,第二模型为文本分类模型,文本分类模型用于对输入文本中的实体之间的关系进行推理。
关于上述实施例基于知识蒸馏的模型训练装置中各模块实现技术方案的其他细节,可参见上述实施例中的基于知识蒸馏的模型训练方法中的描述,此处不再赘述。
需要说明的是,本说明书中的各个实施例均采用递进的方式描述,每个实施例重点说明的都是与其他实施例的不同之处,各个实施例之间相同相似的部分互相参见即可。对于装置类实施例而言,由于其与方法实施例基本相似,所以描述的比较简单,相关之处参见方法实施例的部分说明即可。
请参阅图5,图5为本发明实施例的终端的结构示意图。如图5所示,该终端50包括处理器51及和处理器51耦接的存储器52,存储器52中存储有程序指令,程序指令被处理器51执行时,使得处理器51执行上述任一实施例所述的基于知识蒸馏的模型训练方法的步骤。
其中,处理器51还可以称为CPU(Central Processing Unit,中央处理单元)。处理器51可能是一种集成电路芯片,具有信号的处理能力。处理器51还可以是通用处理器、数字信号处理器(DSP)、专用集成电路(ASIC)、现场可编程门阵列(FPGA)或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件。通用处理器可以是微处理器或者该处理器也可以是任何常规的处理器等。
参阅图6,图6为本发明实施例的存储介质的结构示意图。本发明实施例的存储介质存储有能够实现上述所有方法的程序文件61,其中,该程序文件61可以以软件产品的形式存储在上述存储介质中,包括若干指令用以使得一台计算机设备(可以是个人计算机,服务器,或者网络设备等)或处理器(processor)执行本申请各个实施方式所述方法的全部或部分步骤。而前述的存储介质包括:U盘、移动硬盘、只读存储器(ROM,Read-Only Memory)、随机存取存储器(RAM,Random Access Memory)、磁碟或者光盘等各种可以存储程序代码的介质,或者是计算机、服务器、手机、平板等终端设备。
在本申请所提供的几个实施例中,应该理解到,所揭露的终端,装置和方法,可以通过其它的方式实现。例如,以上所描述的装置实施例仅仅是示意性的,例如,单元的划分,仅仅为一种逻辑功能划分,实际实现时可以有另外的划分方式,例如多个单元或组件可以结合或者可以集成到另一个系统,或一些特征可以忽略,或不执行。另一点,所显示或讨论的相互之间的耦合或直接耦合或通信连接可以是通过一些接口,装置或单元的间接耦合或通信连接,可以是电性,机械或其它的形式。
另外,在本发明各个实施例中的各功能单元可以集成在一个处理单元中,也可以是各个单元单独物理存在,也可以两个或两个以上单元集成在一个单元中。上述集成的单元既可以采用硬件的形式实现,也可以采用软件功能单元的形式实现。以上仅为本申请的实施方式,并非因此限制本申请的专利范围,凡是利用本申请说明书及附图内容所作的等效结构或等效流程变换,或直接或间接运用在其他相关的技术领域,均同理包括在本申请的专利保护范围内。

Claims (8)

1.一种基于知识蒸馏的模型训练方法,其特征在于,包括:
获取预先训练好的第一模型、待训练的第二模型和训练样本集,所述第一模型对应第一模态,所述第二模型对应第二模态,所述第一模态与所述第二模态不同;
将所述训练样本分别输入至所述第一模型和所述第二模型,得到所述第一模型嵌入层的第一表示和所述第一模型中间层的第一输出、以及所述第二模型嵌入层的第二表示和所述第二模型中间层的第二输出;
利用所述第一表示和所述第二表示进行对比自监督学习,得到对比自监督学习结果;
利用所述第一输出和所述第二输出进行知识蒸馏,得到知识蒸馏结果;
根据所述对比自监督学习结果更新所述第二模型嵌入层的参数,并根据所述知识蒸馏结果更新所述第二模型中间层的参数,得到训练好的第二模型;
所述根据所述对比自监督学习结果更新所述第二模型嵌入层的参数,包括:
基于所述对比自监督学习结果计算所述第一表示与所述第二表示之间的互信息,所述互信息的计算公式为:
Figure FDA0004213603060000011
其中,lMI为所述互信息,P(p,g|c=1)为所述第一表示与所述第二表示对应于相同的答题式的条件概率,P(p,g|c=0)为所述第一表示与所述第二表示对应于不同的答题式的条件概率,E表示求均值,sp为softplus函数,Ptext为所述第一表示,Pgraph为所述第二表示,T(Ptext,Pgraph)是对比自监督学习的结果;
根据所述互信息更新所述第二模型嵌入层的参数;
所述利用所述第一输出和所述第二输出进行知识蒸馏,得到知识蒸馏结果,包括:
计算所述第一输出与所述第二输出之间的蒸馏损失值,所述蒸馏损失值的计算公式为:
Figure FDA0004213603060000012
其中,lKD为所述蒸馏损失值,T为预先设定的softmax函数的超参数,KL指KL散度计算,σ为softmax函数,ztext为所述第一输出,zgraph为所述第二输出;
将所述蒸馏损失值作为所述知识蒸馏结果。
2.根据权利要求1所述的基于知识蒸馏的模型训练方法,其特征在于,所述获取预先训练好的第一模型、待训练的第二模型和训练样本集之后,还包括:
将所述训练样本集转换为符合所述第一模态的第一样本集、以及符合所述第二模态的第二样本集。
3.根据权利要求2所述的基于知识蒸馏的模型训练方法,其特征在于,所述将所述训练样本分别输入至所述第一模型和所述第二模型,包括:
将所述第一样本集中的目标第一样本输入至所述第一模型;
将所述第二样本集中与所述目标第一样本对应的目标第二样本输入至所述第二模型。
4.根据权利要求1所述的基于知识蒸馏的模型训练方法,其特征在于,所述第一模型为图神经网络分类模型,所述第二模型为文本分类模型,所述文本分类模型用于对输入文本中的实体之间的关系进行推理。
5.一种基于知识蒸馏的模型训练装置,其特征在于,包括:
获取模块,用于获取预先训练好的第一模型、待训练的第二模型和训练样本集,所述第一模型对应第一模态,所述第二模型对应第二模态,所述第一模态与所述第二模态不同;
输入模块,用于将所述训练样本分别输入至所述第一模型和所述第二模型,得到所述第一模型嵌入层的第一表示和所述第一模型中间层的第一输出、以及所述第二模型嵌入层的第二表示和所述第二模型中间层的第二输出;
学习模块,用于利用所述第一表示和所述第二表示进行对比自监督学习,得到对比自监督学习结果;
蒸馏模块,用于利用所述第一输出和所述第二输出进行知识蒸馏,得到知识蒸馏结果;
更新模块,用于根据所述对比自监督学习结果更新所述第二模型嵌入层的参数,并根据所述知识蒸馏结果更新所述第二模型中间层的参数,得到训练好的第二模型;
所述根据所述对比自监督学习结果更新所述第二模型嵌入层的参数,包括:
基于所述对比自监督学习结果计算所述第一表示与所述第二表示之间的互信息,所述互信息的计算公式为:
Figure FDA0004213603060000031
其中,lMI为所述互信息,P(p,g|x=1)为所述第一表示与所述第二表示对应于相同的答题式的条件概率,P(p,g|c=0)为所述第一表示与所述第二表示对应于不同的答题式的条件概率,E表示求均值,sp为softplus函数,Ptext为所述第一表示,Pgraph为所述第二表示,T(Ptext,Pgraph)是对比自监督学习的结果;
根据所述互信息更新所述第二模型嵌入层的参数;
所述利用所述第一输出和所述第二输出进行知识蒸馏,得到知识蒸馏结果,包括:
计算所述第一输出与所述第二输出之间的蒸馏损失值,所述蒸馏损失值的计算公式为:
Figure FDA0004213603060000032
其中,lKD为所述蒸馏损失值,T为预先设定的softmax函数的超参数,KL指KL散度计算,σ为softmax函数,ztext为所述第一输出,zgraph为所述第二输出;
将所述蒸馏损失值作为所述知识蒸馏结果。
6.根据权利要求5所述的基于知识蒸馏的模型训练装置,其特征在于,所述装置还包括:
转换模块,用于将所述训练样本集转换为符合第一模态的第一样本集、以及符合第二模态的第二样本集。
7.一种终端,其特征在于,所述终端包括处理器、与所述处理器耦接的存储器,所述存储器中存储有程序指令,所述程序指令被所述处理器执行时,使得所述处理器执行如权利要求1-4中任一项权利要求所述的基于知识蒸馏的模型训练方法的步骤。
8.一种存储介质,其特征在于,存储有能够实现如权利要求1-4中任一项所述的基于知识蒸馏的模型训练方法的程序文件。
CN202110558102.5A 2021-05-21 2021-05-21 基于知识蒸馏的模型训练方法、装置、终端及存储介质 Active CN113255763B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202110558102.5A CN113255763B (zh) 2021-05-21 2021-05-21 基于知识蒸馏的模型训练方法、装置、终端及存储介质

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202110558102.5A CN113255763B (zh) 2021-05-21 2021-05-21 基于知识蒸馏的模型训练方法、装置、终端及存储介质

Publications (2)

Publication Number Publication Date
CN113255763A CN113255763A (zh) 2021-08-13
CN113255763B true CN113255763B (zh) 2023-06-09

Family

ID=77183645

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202110558102.5A Active CN113255763B (zh) 2021-05-21 2021-05-21 基于知识蒸馏的模型训练方法、装置、终端及存储介质

Country Status (1)

Country Link
CN (1) CN113255763B (zh)

Families Citing this family (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113487614B (zh) * 2021-09-08 2021-11-30 四川大学 胎儿超声标准切面图像识别网络模型的训练方法和装置
CN114595780B (zh) * 2022-03-15 2022-12-20 百度在线网络技术(北京)有限公司 图文处理模型训练及图文处理方法、装置、设备及介质
CN114386880B (zh) * 2022-03-22 2022-07-08 北京骑胜科技有限公司 确定订单数量的模型训练方法、订单数量确定方法及装置
CN115471799B (zh) * 2022-09-21 2024-04-30 首都师范大学 一种利用姿态估计和数据增强的车辆重识别方法及系统

Citations (6)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN111242297A (zh) * 2019-12-19 2020-06-05 北京迈格威科技有限公司 基于知识蒸馏的模型训练方法、图像处理方法及装置
CN111611377A (zh) * 2020-04-22 2020-09-01 淮阴工学院 基于知识蒸馏的多层神经网络语言模型训练方法与装置
CN112232397A (zh) * 2020-09-30 2021-01-15 上海眼控科技股份有限公司 图像分类模型的知识蒸馏方法、装置和计算机设备
CN112381209A (zh) * 2020-11-13 2021-02-19 平安科技(深圳)有限公司 一种模型压缩方法、系统、终端及存储介质
CN112508169A (zh) * 2020-11-13 2021-03-16 华为技术有限公司 知识蒸馏方法和系统
CN112733550A (zh) * 2020-12-31 2021-04-30 科大讯飞股份有限公司 基于知识蒸馏的语言模型训练方法、文本分类方法及装置

Patent Citations (6)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN111242297A (zh) * 2019-12-19 2020-06-05 北京迈格威科技有限公司 基于知识蒸馏的模型训练方法、图像处理方法及装置
CN111611377A (zh) * 2020-04-22 2020-09-01 淮阴工学院 基于知识蒸馏的多层神经网络语言模型训练方法与装置
CN112232397A (zh) * 2020-09-30 2021-01-15 上海眼控科技股份有限公司 图像分类模型的知识蒸馏方法、装置和计算机设备
CN112381209A (zh) * 2020-11-13 2021-02-19 平安科技(深圳)有限公司 一种模型压缩方法、系统、终端及存储介质
CN112508169A (zh) * 2020-11-13 2021-03-16 华为技术有限公司 知识蒸馏方法和系统
CN112733550A (zh) * 2020-12-31 2021-04-30 科大讯飞股份有限公司 基于知识蒸馏的语言模型训练方法、文本分类方法及装置

Also Published As

Publication number Publication date
CN113255763A (zh) 2021-08-13

Similar Documents

Publication Publication Date Title
CN113255763B (zh) 基于知识蒸馏的模型训练方法、装置、终端及存储介质
US11651214B2 (en) Multimodal data learning method and device
JP2023545543A (ja) 情報生成方法、装置、コンピュータ機器、記憶媒体及びコンピュータプログラム
CN113761153A (zh) 基于图片的问答处理方法、装置、可读介质及电子设备
CN111782826A (zh) 知识图谱的信息处理方法、装置、设备及存储介质
CN114388064A (zh) 用于蛋白质表征学习的多模态信息融合方法、系统、终端及存储介质
CN116564338B (zh) 语音动画生成方法、装置、电子设备和介质
CN112669215A (zh) 一种训练文本图像生成模型、文本图像生成的方法和装置
CN111653274A (zh) 唤醒词识别的方法、装置及存储介质
CN113628059A (zh) 一种基于多层图注意力网络的关联用户识别方法及装置
CN110929532B (zh) 数据处理方法、装置、设备及存储介质
CN112668608A (zh) 一种图像识别方法、装置、电子设备及存储介质
CN116821287A (zh) 基于知识图谱和大语言模型的用户心理画像系统及方法
CN116956116A (zh) 文本的处理方法和装置、存储介质及电子设备
CN111737439A (zh) 一种问题生成方法及装置
CN115081615A (zh) 一种神经网络的训练方法、数据的处理方法以及设备
CN114328943A (zh) 基于知识图谱的问题回答方法、装置、设备及存储介质
CN115712739B (zh) 舞蹈动作生成方法、计算机设备及存储介质
CN116401364A (zh) 语言模型的训练方法、电子设备、存储介质及产品
CN112989088B (zh) 一种基于强化学习的视觉关系实例学习方法
CN115116444A (zh) 一种语音识别文本的处理方法、装置、设备及存储介质
US11145414B2 (en) Dialogue flow using semantic simplexes
CN113761149A (zh) 对话信息处理方法、装置、计算机设备及存储介质
CN113569867A (zh) 一种图像处理方法、装置、计算机设备及存储介质
CN112052680A (zh) 问题生成方法、装置、设备及存储介质

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant