CN114511083A - 一种模型的训练方法、装置、存储介质及电子装置 - Google Patents
一种模型的训练方法、装置、存储介质及电子装置 Download PDFInfo
- Publication number
- CN114511083A CN114511083A CN202210407353.8A CN202210407353A CN114511083A CN 114511083 A CN114511083 A CN 114511083A CN 202210407353 A CN202210407353 A CN 202210407353A CN 114511083 A CN114511083 A CN 114511083A
- Authority
- CN
- China
- Prior art keywords
- model
- initial
- evaluation index
- determining
- training
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000012549 training Methods 0.000 title claims abstract description 180
- 238000000034 method Methods 0.000 title claims abstract description 51
- 238000011156 evaluation Methods 0.000 claims abstract description 211
- 238000001514 detection method Methods 0.000 claims description 17
- 238000004590 computer program Methods 0.000 claims description 16
- 238000005516 engineering process Methods 0.000 abstract description 4
- 230000000694 effects Effects 0.000 abstract description 3
- 238000004821 distillation Methods 0.000 description 23
- 238000013145 classification model Methods 0.000 description 10
- 230000005540 biological transmission Effects 0.000 description 6
- 238000010586 diagram Methods 0.000 description 6
- 238000005070 sampling Methods 0.000 description 6
- 238000012545 processing Methods 0.000 description 5
- 238000005457 optimization Methods 0.000 description 4
- 238000012795 verification Methods 0.000 description 3
- 238000004364 calculation method Methods 0.000 description 2
- 238000004891 communication Methods 0.000 description 2
- 230000006870 function Effects 0.000 description 2
- 238000013140 knowledge distillation Methods 0.000 description 2
- 238000012986 modification Methods 0.000 description 2
- 230000004048 modification Effects 0.000 description 2
- 230000003287 optical effect Effects 0.000 description 2
- 238000009825 accumulation Methods 0.000 description 1
- 230000006835 compression Effects 0.000 description 1
- 238000007906 compression Methods 0.000 description 1
- 238000013135 deep learning Methods 0.000 description 1
- 238000009795 derivation Methods 0.000 description 1
- 238000000605 extraction Methods 0.000 description 1
- 230000002349 favourable effect Effects 0.000 description 1
- 238000007429 general method Methods 0.000 description 1
- 230000000977 initiatory effect Effects 0.000 description 1
- 230000009191 jumping Effects 0.000 description 1
- 238000010295 mobile communication Methods 0.000 description 1
- 238000003058 natural language processing Methods 0.000 description 1
- 239000002245 particle Substances 0.000 description 1
- 238000013138 pruning Methods 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- General Health & Medical Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Evolutionary Computation (AREA)
- Artificial Intelligence (AREA)
- Molecular Biology (AREA)
- Computing Systems (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Mathematical Physics (AREA)
- Software Systems (AREA)
- Health & Medical Sciences (AREA)
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
Abstract
本发明实施例提供了一种模型的训练方法、装置、存储介质及电子装置,其中,该方法包括:基于初始网络模型中包括的第一初始模型的第一训练精度确定第一初始模型的第一评价指标,以及基于初始网络模型中包括的第二初始模型的第二训练精度确定第二初始模型的第二评价指标;在第一评价指标以及第二评价指标中存在大于预定阈值的评价指标的情况下,将第一评价指标以及第二评价指标中包括的最小评价指标对应的初始模型确定为待优化的目标初始模型;利用训练数据以及其他初始模型训练目标初始模型,得到目标网络模型。通过本发明,解决了相关技术中存在的模型训练时间长、效率低的问题,达到提高模型的训练效率的效果。
Description
技术领域
本发明实施例涉及计算机领域,具体而言,涉及一种模型的训练方法、装置、存储介质及电子装置。
背景技术
近年来,深度学习发展迅猛,被应用到计算机视觉、语音识别以及自然语言处理上。随着数据规模积累的足够大,场景越加复杂,往往会使用到更加复杂的模型结构,这不仅仅需要更大的算力需求,同时也对存储带来了一定的挑战。由此以衍生了模型优化、压缩、剪枝、蒸馏等技术方案,而模型蒸馏就这些方案中比较常用的一种。
蒸馏方案主要以一个优秀的Teacher 模型为基础,通过知识蒸馏将Teacher 模型优秀的表现传递给Student模型,从而得到一个又好又快的Student模型。然而,在相关技术中,Teacher模型是提前训练好的,且Teacher模型的精度要高于Student模型,由此导致了模型训练时间长、效率低的问题。
针对相关技术中存在的上述问题,目前尚未提出有效的解决方案。
发明内容
本发明实施例提供了一种模型的训练方法、装置、存储介质及电子装置,以至少解决相关技术中存在的模型训练时间长、效率低的问题。
根据本发明的一个实施例,提供了一种模型的训练方法,包括:基于初始网络模型中包括的第一初始模型的第一训练精度确定所述第一初始模型的第一评价指标,以及基于所述初始网络模型中包括的第二初始模型的第二训练精度确定所述第二初始模型的第二评价指标,其中,所述初始网络模型为经过训练得到的网络模型;在所述第一评价指标以及所述第二评价指标中存在大于预定阈值的评价指标的情况下,将所述第一评价指标以及所述第二评价指标中包括的最小评价指标对应的初始模型确定为待优化的目标初始模型;利用训练数据以及其他初始模型训练所述目标初始模型,得到目标网络模型,其中,所述其他初始模型为所述第一评价指标以及所述第二评价指标中包括的最大评价指标对应的初始模型。
根据本发明的另一个实施例,提供了一种模型的训练装置,包括:第一确定模块,用于基于初始网络模型中包括的第一初始模型的第一训练精度确定所述第一初始模型的第一评价指标,以及基于所述初始网络模型中包括的第二初始模型的第二训练精度确定所述第二初始模型的第二评价指标,其中,所述初始网络模型为经过训练得到的网络模型;第二确定模块,用于在所述第一评价指标以及所述第二评价指标中存在大于预定阈值的评价指标的情况下,将所述第一评价指标以及所述第二评价指标中包括的最小评价指标对应的初始模型确定为待优化的目标初始模型;训练模块,用于利用训练数据以及其他初始模型训练所述目标初始模型,得到目标网络模型,其中,所述其他初始模型为所述第一评价指标以及所述第二评价指标中包括的最大评价指标对应的初始模型。
根据本发明的又一个实施例,还提供了一种计算机可读存储介质,所述计算机可读存储介质中存储有计算机程序,其中,所述计算机程序被处理器执行时实现上述任一项中所述的方法的步骤。
根据本发明的又一个实施例,还提供了一种电子装置,包括存储器和处理器,所述存储器中存储有计算机程序,所述处理器被设置为运行所述计算机程序以执行上述任一项方法实施例中的步骤。
通过本发明,根据初始网络模型中包括的第一初始网络模型的第一训练精度确定第一初始模型的第一评价指标,根据初始网络模型中包括的第二初始网络模型的第二训练精度确定第二初始模型的第二评价指标。在第评价指标以及第二评价指标中存在大于预定阈值的评价指标的情况下,将第一评价指标以及第二评价指标中包括的最小评价指标对应的初始模型确定为待优化的目标初始模型,利用训练数据以及第一评价指标和第二评价指标中包括的最大评价指标对应的其他初始模型训练目标初始模型,以得到目标网络模型。由于在训练时,可以同时训练第一初始模型和第二初始模型,并确定第一初始模型的第一评价指标和第二初始模型的第二评价指标,在第一评价指标和第二评价指标中存在大于预定阈值的评价指标时,利用最大评价指标对应的其他初始模型和训练数据对目标初始模型进行训练优化,实现了同时训练第一初始模型和第二初始模型,并在在第一评价指标和第二评价指标中存在大于预定阈值的评价指标时,指定用其他初始模型训练目标初始模型,无需预先训练其他初始模型。因此,可以解决相关技术中存在的模型训练时间长、效率低的问题,达到提高模型的训练效率的效果。
附图说明
图1是本发明实施例的一种模型的训练方法的移动终端的硬件结构框图;
图2是根据本发明实施例的模型的训练方法的流程图;
图3是根据本发明具体实施例的检测模型训练方法流程图;
图4是根据本发明具体实施例的分类模型训练方法流程图;
图5是根据本发明实施例的模型的训练装置的结构框图。
具体实施方式
下文中将参考附图并结合实施例来详细说明本发明的实施例。
需要说明的是,本发明的说明书和权利要求书及上述附图中的术语“第一”、“第二”等是用于区别类似的对象,而不必用于描述特定的顺序或先后次序。
蒸馏方案主要以一个优秀的Teacher 模型为基础,通过知识蒸馏将Teacher 模型优秀的表现传递给Student模型,从而得到一个又好又快的Student模型。这种方案存在以下几点不足:
1)Teacher 一般需要在蒸馏前预先训练好;
2)Teacher 模型的精度对蒸馏结果的影响较大,Student模型精度上限取决于Teacher 模型的最高精度;
3)传统蒸馏Teacher模型和Studnet模型的身份是固定不变的,Teacher 模型需要更高的精度,因此Teacher 模型的结构要比Student模型更大更复杂;
4)传统蒸馏方法对分类模型和检测模型上有较大的差异,一般方法不可通用。
针对相关技术中存在的上述问题,提出以下实施例:
本申请实施例中所提供的方法实施例可以在移动终端、计算机终端或者类似的运算装置中执行。以运行在移动终端上为例,图1是本发明实施例的一种模型的训练方法的移动终端的硬件结构框图。如图1所示,移动终端可以包括一个或多个(图1中仅示出一个)处理器102(处理器102可以包括但不限于微处理器MCU或可编程逻辑器件FPGA等的处理装置)和用于存储数据的存储器104,其中,上述移动终端还可以包括用于通信功能的传输设备106以及输入输出设备108。本领域普通技术人员可以理解,图1所示的结构仅为示意,其并不对上述移动终端的结构造成限定。例如,移动终端还可包括比图1中所示更多或者更少的组件,或者具有与图1所示不同的配置。
存储器104可用于存储计算机程序,例如,应用软件的软件程序以及模块,如本发明实施例中的模型的训练方法对应的计算机程序,处理器102通过运行存储在存储器104内的计算机程序,从而执行各种功能应用以及数据处理,即实现上述的方法。存储器104可包括高速随机存储器,还可包括非易失性存储器,如一个或者多个磁性存储装置、闪存、或者其他非易失性固态存储器。在一些实例中,存储器104可进一步包括相对于处理器102远程设置的存储器,这些远程存储器可以通过网络连接至移动终端。上述网络的实例包括但不限于互联网、企业内部网、局域网、移动通信网及其组合。
传输设备106用于经由一个网络接收或者发送数据。上述的网络具体实例可包括移动终端的通信供应商提供的无线网络。在一个实例中,传输设备106包括一个网络适配器(Network Interface Controller,简称为NIC),其可通过基站与其他网络设备相连从而可与互联网进行通讯。在一个实例中,传输设备106可以为射频(Radio Frequency,简称为RF)模块,其用于通过无线方式与互联网进行通讯。
在本实施例中提供了一种运行于模型的训练方法,图2是根据本发明实施例的模型的训练方法的流程图,如图2所示,该流程包括如下步骤:
步骤S202,基于初始网络模型中包括的第一初始模型的第一训练精度确定所述第一初始模型的第一评价指标,以及基于所述初始网络模型中包括的第二初始模型的第二训练精度确定所述第二初始模型的第二评价指标,其中,所述初始网络模型为经过训练得到的网络模型;
步骤S204,在所述第一评价指标以及所述第二评价指标中存在大于预定阈值的评价指标的情况下,将所述第一评价指标以及所述第二评价指标中包括的最小评价指标对应的初始模型确定为待优化的目标初始模型;
步骤S206,利用训练数据以及其他初始模型训练所述目标初始模型,得到目标网络模型,其中,所述其他初始模型为所述第一评价指标以及所述第二评价指标中包括的最大评价指标对应的初始模型。
在上述实施例中,初始网络模型中可以包括第一初始模型和第二初始模型,第一初始模型和第二初始模型可以是相同结构、相同精度的模型,还可以是不同结构、不同精度的模型。在初始网络模型搭建完成后,可以利用训练数据训练初始网络模型。在每训练预定次数后,可以利用验证数据集验证初始网络模型中包括的第一初始模型和第二初始模型的训练精度,即确定第一训练精度和第二训练精度。根据第一训练精度确定第一评价指标,根据第二训练精度确定第二评价指标。其中,可以根据初始网络模型的类型确定根据第一训练精度确定第一评价指标的方式,以及确定根据第二训练精度确定第二评价指标的方式。初始网络模型的类型可以包括分类模型、检测模型等。
在上述实施例中,可以每训练完一个epoch后通过模型当前状态参数推理验证集数据集,结合评价指标ρ选择指标高的模型作为Teacher模型,并在下一个epoch训练时固定Teacher模型的参数只优化更新Student模型的参数。
在上述实施例中,当第一评价指标以及第二评价指标中存在大于预定阈值的评价指标时,将最小的评价指标对应的初始模型确定为目标网络模型。其中,预定阈值可以是预先设定的阈值,当存在评价指标大于预定阈值时,可以认为大于预定阈值的初始模型的训练精度交好,可以作为老师模型,用于训练学生模型。
在上述实施例中,第一评价指标以及第二评价指标中存在大于预定阈值的评价指标包括:第一评价指标大于预定阈值,第二评价指标小于或等于预定阈值;第二评价指标大于预定阈值,第一评价指标小于或等于预定阈值;第一评价指标以及第二评价指标均大于预定阈值。
当第一评价指标大于预定阈值,第二评价指标小于或等于预定阈值时,此时第一评价指标为最大评价指标,可以将第一初始模型确定为其他初始模型,第二初始模型确定为目标初始模型,即将第一初始模型确定为老师模型,将第二初始模型确定为学生模型。
当第二评价指标大于预定阈值,第一评价指标小于或等于预定阈值时,此时,第二评价指标为最大评价指标,可以将第二初始模型确定为其他初始模型,第一初始模型确定为目标初始模型,即将第二初始模型确定为老师模型,将第一初始模型确定为学生模型。
当第一评价指标以及第二评价指标均大于预定阈值时,可以将最大评价指标对应的初始模型确定为其他初始模型,将最小评价指标对应的初始模型确定为目标初始模型。
例如,当第一评价指标大于第二评价指标时,将第一初始模型确定为其他初始模型,第二初始模型确定为目标初始模型。其中,预定阈值可以是0.6,该取值仅是一种示例性说明,预定阈值还可以取0.5、0.7、0.8等,本发明对此不作限制。
在上述实施例中,其他初始模型即为老师模型,目标初始模型即为学生模型,因此,可以利用训练数据以及老师模型对学生模型进行训练,不断更新迭代目标初始模型,在目标初始模型的损失值满足预定条件时,则退出训练,将训练后的初始网络模型确定为目标网络模型。
在上述实施例中,初始时同时训练第一初始模型和第二初始模型,此时不指定老师模型,当第一评价指标和第二评价指标中存在大于预定阈值的评价指标时,指定老师模型和学生模型。实现了同步训练Teacher老师模型和Student学生模型,不需要预先训练老师模型。
可选地,上述步骤的执行主体可以是处理器,或者其他的具备类似处理能力的设备,还可以是至少集成有数据处理设备的机器,其中,数据处理设备可以包括计算机、手机等终端,但不限于此。
通过本发明,根据初始网络模型中包括的第一初始网络模型的第一训练精度确定第一初始模型的第一评价指标,根据初始网络模型中包括的第二初始网络模型的第二训练精度确定第二初始模型的第二评价指标。在第评价指标以及第二评价指标中存在大于预定阈值的评价指标的情况下,将第一评价指标以及第二评价指标中包括的最小评价指标对应的初始模型确定为待优化的目标初始模型,利用训练数据以及第一评价指标和第二评价指标中包括的最大评价指标对应的其他初始模型训练目标初始模型,以得到目标网络模型。由于在训练时,可以同时训练第一初始模型和第二初始模型,并确定第一初始模型的第一评价指标和第二初始模型的第二评价指标,在第一评价指标和第二评价指标中存在大于预定阈值的评价指标时,利用最大评价指标对应的其他初始模型和训练数据对目标初始模型进行训练优化,实现了同时训练第一初始模型和第二初始模型,并在在第一评价指标和第二评价指标中存在大于预定阈值的评价指标时,指定用其他初始模型训练目标初始模型,无需预先训练其他初始模型。因此,可以解决相关技术中存在的模型训练时间长、效率低的问题,达到提高模型的训练效率的效果。
在一个示例性实施例中,利用训练数据以及其他初始模型训练所述目标初始模型,得到目标网络模型包括:将训练数据输入至所述其他初始模型中,确定所述其他初始模型输出的第一特征;将所述训练数据以及所述第一特征输入至所述目标初始模型中,确定所述目标初始模型的第一损失值;基于所述第一损失值迭代更新所述目标初始模型的网络参数,得到所述目标网络模型。在本实施例中,可以将训练数据输入至目标初始模型中,确定目标初始模型输出的特征,将其他初始模型,即老师模型的输出特征输入至目标初始模型,即学生模型中,并根据目标初始模型输出的特征以及第一特征确定目标初始模型的第一损失值,根据第一损失值迭代更新目标初始模型的网络参数,得到目标网络模型。
在一个示例性实施例中,将所述训练数据以及所述第一特征输入至所述目标初始模型中,确定所述目标初始模型的第一损失值包括:确定所述第一特征中包括的每个特征层的第一子特征与所述目标初始模型输出的第二子特征之间的损失值,得到多个第二损失值,其中,所述第二子特征与所述第一子特征处于相同的特征层;确定多个所述第二损失值的第一和值;确定所述第二损失值对应的目标权重;确定所述第一和值与所述目标权重的第一乘积;基于所述目标初始模型输出的特征与所述其他初始模型输出的特征确定第三损失值;将所述第一乘积与所述第三损失值的第二和值确定为所述第一损失值。在本实施例中,在将训练输入至目标初始模型后,目标初始模型可以按照不同的采样倍率提取训练数据的特征图,不同的采样倍率可以对应一个特征层,如采样倍率为8、16、32,则可以得到采样倍率为8层的特征,采样倍率为16时的特征,采样倍率为32时的特征,将每个倍率下的特征确定为每个特征层的子特征。
在上述实施例中,可以确定第一特征中包括的每个特征层对应的第一子特征,确定目标初始模型输出的每个特征层对应的第二子特征。分别确定相同特征层对应的第一子特征和第二子特征之间的损失值,得到多个第二损失值,并确定多个第二损失值的第一和值。即。其中,表示蒸馏损失,表示目标初始模型输出的每个特征层对应的第二子特征,表示其他初始模型输出的每个特征层对应的第一子特征。
在上述实施例中,第一损失值可以表示为。其中,ω表示目标权重。 表示训练损失,即第三损失值。可以根据初始网络模型的分类确定第三损失值的计算方式。例如,当初始网络模型为分类模型时,第三损失值可以表示为;当初始网络模型为检测模型时,第三损失值可以表示为。
在上述实施例中,还可以确定第三损失值对应的训练权重,确定第三损失值与训练权重的乘积,将该乘积与第一乘积的和值确定为第一损失值。其中,训练权重可以为1,还可以是其他取值,本发明对此不做限制。当训练权重为1时,第一损失值则为第一乘积与第三损失值的和值。
在上述实施例中,目标权重和训练权重可以是预先确定的权重,还可以是根据第一评价指标和第二评价指标确定的权重。
在一个示例性实施例中,确定所述第二损失值对应的目标权重包括:确定所述目标初始模型的评价指标与第一参数的第二乘积;确定所述第二乘积与第二参数的第一差值;将所述第一差值与第三参数的比值确定为所述目标权重。在本实施例中,目标权重可以表示为,其中,ρ表示目标初始模型的评价指标,第一参数可以为5,第二参数可以为3,第三参数可以为2。需要说明的是,上述参数的取值仅是一种示例性说明,第一参数、第二参数以及第三参数还可以取其他值,本发明对此不作限制。
在一个示例性实施例中,在基于所述初始网络模型中包括的第二初始模型的第二训练精度确定所述第二初始模型的第二评价指标之后,所述方法还包括:在所述第一评价指标以及所述第二评价指标均小于或等于预定阈值的情况下,将所述第一初始模型以及所述第二初始模型确定为所述目标初始模型;利用训练数据迭代更新所述目标初始模型的网络参数,得到更新后的所述初始网络模型。
在本实施例中,在得到第一评价指标和第二评价指标之后,当第一评价指标和第二评价指标均小于或等于预定阈值时,可以认为第一初始模型和第二初始模型的训练精度均不满足要求,因此,可以再次训练第一初始模型和第二初始模型,利用训练数据迭代更新目标初始模型的网络参数,得到更新后的初始网络模型。再利用验证数据集验证第一初始模型以及第二初始模型的第一训练精度以及第二训练精度,并根据第一训练精度确定第一评价指标,根据第二训练精度确定第二评价指标。确定第一评价指标与第二评价指标与预定阈值的大小关系,当存在大于预定阈值的评价指标时,则指定最大的评价指标对应的初始模型为其他初始模型,最小的评价指标对应的初始模型为目标初始模型。并利用其他初始模型和训练数据对目标初始模型进行优化训练。当第一评价指标和第二评价指标仍小于或等于预定阈值,则迭代更新第一初始模型和第二初始模型的网络参数,得到更新后的初始网络模型。不断执行上述步骤,直到第一评价指标和第二评价指标中存在大于预定阈值的评价指标时,利用其他初始模型和训练数据对目标初始模型进行优化训练。
在一个示例性实施例中,利用训练数据迭代更新所述目标初始模型的网络参数,得到更新后的所述初始网络模型包括:将所述训练数据输入至所述第一初始模型和所述第二初始模型中,基于所述第一初始模型输出的特征以及所述第二初始模型输出的特征确定所述初始网络模型的第四损失值;基于所述第四损失值迭代更新所述第一初始模型以及所述第二初始模型的网络参数,得到更新后的所述初始网络模型。在本实施例中,可以根据初始网络模型的类型确定第四损失值的计算方式。当初始网络模型的类型为分类网络模型时,第四损失值可以表示为,当初始网络模型的类型为检测网络模型时,第四损失值可以表示为。
在一个示例性实施例中,基于初始网络模型中包括的第一初始模型的第一训练精度确定所述第一初始模型的第一评价指标包括:在所述初始网络模型为分类网络模型的情况下,将所述第一训练精度确定为所述第一评价指标;在所述初始网络模型为检测网络模型的情况下,确定所述第一训练精度与第四参数的第三和值,将所述第三和值与第五参数的比值确定为所述第一评价指标;基于所述初始网络模型中包括的第二初始模型的第二训练精度确定所述第二初始模型的第二评价指标包括:在所述初始网络模型为分类网络模型的情况下,将所述第二训练精度确定为所述第二评价指标;在所述初始网络模型为检测网络模型的情况下,确定所述第二训练精度与第六参数的第四和值,将所述第四和值与第七参数的比值确定为所述第一评价指标。在本实施例中,第一评价指标和第二评价指标的确定方式均与初始网络模型的类型有关。当初始网络模型为分类网络模型的情况下,第一评价指标和第二评价指标可以表示为P,其中,P表示分类模型预测的平均精度。当初始网络模型为检测网络模型的情况下,第一评价指标和第二评价指标可以表示为。其中,P表示模型预测正样本的平均精度,I表示预测正样本框和gt框(真实样本框)的iou(交并比),即第六参数,第七参数可以是2。。
下面结合具体实施方式对模型的训练方法进行说明:
图3是根据本发明具体实施例的检测模型训练方法流程图,图4是根据本发明具体实施例的分类模型训练方法流程图,参见附图3-4可知,不同类型的模型可以采样相同的训练方法。数据进入模型,初始状态下评价指标为0,低于蒸馏的阈值(对应于上述预定阈值,可以人工设定为0.6),为训练模式。即不设定Teacher和Student身份,只优化训练损失l_train,同时训练两个模型;当训练到一定程度后通过验证集对当前模型进行评估(这里可以每个epoch做一次评估也可以多个epoch做一次评估,可以由训练者自由设定参数),当模型的评价指标高于蒸馏阈值时开始选择评价指标更高的模型为Teacher模型,在下一个epoch开始时固定Teacher的训练参数,同时优化蒸馏损失l_kd和训练损失l_train。
具体过程如下:
1)数据输入,实现数据集加载并按指定输入要求输入到模型中。
2)特征提取,通过backbone网络实现对数据抽象特征的提取。
3)模型选择 ,每训练完一个epoch后通过模型当前状态参数推理验证集数据集,结合评价指标ρ选择指标高的模型作为Teacher模型,并在下一个epoch训练时固定Teacher模型的参数只优化更新Student模型的参数。评价指标ρ由公式1和公式2生成:
4)损失计算,结合l_kd(蒸馏损失)和l_train(训练损失)组成最终优化损失l。损失的组成由公式3-7所示:
公式3表示当评价指标小于0.6时只优化训练损失,即此时只有模型训练没有模型蒸馏。当模型训练到一定程度,评价指标大于0.6时,蒸馏损失才得到优化,模型进入蒸馏状态。当两个模型同时训练时会首先判断各自当评价指标是否大于0.6,当指标都大于等于0.6时选择评价指标更高的模型作为Teacher模型,此时固定Teacher模型的参数;当指标都小于0.6时则只训练模型不做蒸馏。其中,数值0.6是超参数,可以根据训练模型的情况调节。公式4表示l_kd 的损失权重的推导过程,。公式5和公式6分别表示分类模型的损失由常规的分类损失构成,检测模型的损失由检测框的回归损失和分类损失之和构成。公式7表示蒸馏损失,和分别表示Student模型和Teacher模型的第i层特征。
在前述实施例中,将模型训练和模型蒸馏结合在一起,边训练边蒸馏,在训练过程中自动评估并择优选择Teacher模型,使得Teacher模型和Student模型身份可以随着训练动态转换。Student模型在训练过程中接受Teacher的蒸馏信息同时也在参考数据标签的信息,所以最终模型的最高精度不再受Teacher的最高精度限制。同时由于模型训练和蒸馏相结合的方案,有助于模型跳出模型的局部最优区间,从而达到更高的精度。由于Teacher模型和Student模型的身份在训练过程中是动态变化的,所以不需要强制要求Teacher模型要比Student模型的结构更大更复杂。同时适用于分类模型和检测模型的蒸馏。
通过以上的实施方式的描述,本领域的技术人员可以清楚地了解到根据上述实施例的方法可借助软件加必需的通用硬件平台的方式来实现,当然也可以通过硬件,但很多情况下前者是更佳的实施方式。基于这样的理解,本发明的技术方案本质上或者说对现有技术做出贡献的部分可以以软件产品的形式体现出来,该计算机软件产品存储在一个存储介质(如ROM/RAM、磁碟、光盘)中,包括若干指令用以使得一台终端设备(可以是手机,计算机,服务器,或者网络设备等)执行本发明各个实施例所述的方法。
在本实施例中还提供了一种模型的训练装置,该装置用于实现上述实施例及优选实施方式,已经进行过说明的不再赘述。如以下所使用的,术语“模块”可以实现预定功能的软件和/或硬件的组合。尽管以下实施例所描述的装置较佳地以软件来实现,但是硬件,或者软件和硬件的组合的实现也是可能并被构想的。
图5是根据本发明实施例的模型的训练装置的结构框图,如图5所示,该装置包括:
第一确定模块52,用于基于初始网络模型中包括的第一初始模型的第一训练精度确定所述第一初始模型的第一评价指标,以及基于所述初始网络模型中包括的第二初始模型的第二训练精度确定所述第二初始模型的第二评价指标,其中,所述初始网络模型为经过训练得到的网络模型;
第二确定模块54,用于在所述第一评价指标以及所述第二评价指标中存在大于预定阈值的评价指标的情况下,将所述第一评价指标以及所述第二评价指标中包括的最小评价指标对应的初始模型确定为待优化的目标初始模型;
训练模块56,用于利用训练数据以及其他初始模型训练所述目标初始模型,得到目标网络模型,其中,所述其他初始模型为所述第一评价指标以及所述第二评价指标中包括的最大评价指标对应的初始模型。
在一个示例性实施例中,训练模块56可以通过如下方式实现利用训练数据以及其他初始模型训练所述目标初始模型,得到目标网络模型:将训练数据输入至所述其他初始模型中,确定所述其他初始模型输出的第一特征;将所述训练数据以及所述第一特征输入至所述目标初始模型中,确定所述目标初始模型的第一损失值;基于所述第一损失值迭代更新所述目标初始模型的网络参数,得到所述目标网络模型。
在一个示例性实施例中,训练模块56可以通过如下方式实现将所述训练数据以及所述第一特征输入至所述目标初始模型中,确定所述目标初始模型的第一损失值:确定所述第一特征中包括的每个特征层的第一子特征与所述目标初始模型输出的第二子特征之间的损失值,得到多个第二损失值,其中,所述第二子特征与所述第一子特征处于相同的特征层;确定多个所述第二损失值的第一和值;确定所述第二损失值对应的目标权重;确定所述第一和值与所述目标权重的第一乘积;基于所述目标初始模型输出的特征与所述其他初始模型输出的特征确定第三损失值;将所述第一乘积与所述第三损失值的第二和值确定为所述第一损失值。
在一个示例性实施例中,训练模块56可以通过如下方式实现确定所述第二损失值对应的目标权重:确定所述目标初始模型的评价指标与第一参数的第二乘积;确定所述第二乘积与第二参数的第一差值;将所述第一差值与第三参数的比值确定为所述目标权重。
在一个示例性实施例中,所述装置还可以用于在基于所述初始网络模型中包括的第二初始模型的第二训练精度确定所述第二初始模型的第二评价指标之后,在所述第一评价指标以及所述第二评价指标均小于或等于预定阈值的情况下,将所述第一初始模型以及所述第二初始模型确定为所述目标初始模型;利用训练数据迭代更新所述目标初始模型的网络参数,得到更新后的所述初始网络模型。
在一个示例性实施例中,所述装置可以通过如下方式实现利用训练数据迭代更新所述目标初始模型的网络参数,得到更新后的所述初始网络模型:将所述训练数据输入至所述第一初始模型和所述第二初始模型中,基于所述第一初始模型输出的特征以及所述第二初始模型输出的特征确定所述初始网络模型的第四损失值;基于所述第四损失值迭代更新所述第一初始模型以及所述第二初始模型的网络参数,得到更新后的所述初始网络模型。
在一个示例性实施例中,第一确定模块52可以通过如下方式实现基于初始网络模型中包括的第一初始模型的第一训练精度确定所述第一初始模型的第一评价指标:在所述初始网络模型为分类网络模型的情况下,将所述第一训练精度确定为所述第一评价指标;在所述初始网络模型为检测网络模型的情况下,确定所述第一训练精度与第四参数的第三和值,将所述第三和值与第五参数的比值确定为所述第一评价指标;第一确定模块52可以通过如下方式实现基于所述初始网络模型中包括的第二初始模型的第二训练精度确定所述第二初始模型的第二评价指标:在所述初始网络模型为分类网络模型的情况下,将所述第二训练精度确定为所述第二评价指标;在所述初始网络模型为检测网络模型的情况下,确定所述第二训练精度与第六参数的第四和值,将所述第四和值与第七参数的比值确定为所述第一评价指标。
需要说明的是,上述各个模块是可以通过软件或硬件来实现的,对于后者,可以通过以下方式实现,但不限于此:上述模块均位于同一处理器中;或者,上述各个模块以任意组合的形式分别位于不同的处理器中。
本发明的实施例还提供了一种计算机可读存储介质,所述计算机可读存储介质中存储有计算机程序,其中,所述计算机程序被处理器执行时实现上述任一项中所述的方法的步骤。
在一个示例性实施例中,上述计算机可读存储介质可以包括但不限于:U盘、只读存储器(Read-Only Memory,简称为ROM)、随机存取存储器(Random Access Memory,简称为RAM)、移动硬盘、磁碟或者光盘等各种可以存储计算机程序的介质。
本发明的实施例还提供了一种电子装置,包括存储器和处理器,该存储器中存储有计算机程序,该处理器被设置为运行计算机程序以执行上述任一项方法实施例中的步骤。
在一个示例性实施例中,上述电子装置还可以包括传输设备以及输入输出设备,其中,该传输设备和上述处理器连接,该输入输出设备和上述处理器连接。
本实施例中的具体示例可以参考上述实施例及示例性实施方式中所描述的示例,本实施例在此不再赘述。
显然,本领域的技术人员应该明白,上述的本发明的各模块或各步骤可以用通用的计算装置来实现,它们可以集中在单个的计算装置上,或者分布在多个计算装置所组成的网络上,它们可以用计算装置可执行的程序代码来实现,从而,可以将它们存储在存储装置中由计算装置来执行,并且在某些情况下,可以以不同于此处的顺序执行所示出或描述的步骤,或者将它们分别制作成各个集成电路模块,或者将它们中的多个模块或步骤制作成单个集成电路模块来实现。这样,本发明不限制于任何特定的硬件和软件结合。
以上所述仅为本发明的优选实施例而已,并不用于限制本发明,对于本领域的技术人员来说,本发明可以有各种更改和变化。凡在本发明的原则之内,所作的任何修改、等同替换、改进等,均应包含在本发明的保护范围之内。
Claims (10)
1.一种模型的训练方法,其特征在于,包括:
基于初始网络模型中包括的第一初始模型的第一训练精度确定所述第一初始模型的第一评价指标,以及基于所述初始网络模型中包括的第二初始模型的第二训练精度确定所述第二初始模型的第二评价指标,其中,所述初始网络模型为经过训练得到的网络模型;
在所述第一评价指标以及所述第二评价指标中存在大于预定阈值的评价指标的情况下,将所述第一评价指标以及所述第二评价指标中包括的最小评价指标对应的初始模型确定为待优化的目标初始模型;
利用训练数据以及其他初始模型训练所述目标初始模型,得到目标网络模型,其中,所述其他初始模型为所述第一评价指标以及所述第二评价指标中包括的最大评价指标对应的初始模型。
2.根据权利要求1所述的方法,其特征在于,利用训练数据以及其他初始模型训练所述目标初始模型,得到目标网络模型包括:
将训练数据输入至所述其他初始模型中,确定所述其他初始模型输出的第一特征;
将所述训练数据以及所述第一特征输入至所述目标初始模型中,确定所述目标初始模型的第一损失值;
基于所述第一损失值迭代更新所述目标初始模型的网络参数,得到所述目标网络模型。
3.根据权利要求2所述的方法,其特征在于,将所述训练数据以及所述第一特征输入至所述目标初始模型中,确定所述目标初始模型的第一损失值包括:
确定所述第一特征中包括的每个特征层的第一子特征与所述目标初始模型输出的第二子特征之间的损失值,得到多个第二损失值,其中,所述第二子特征与所述第一子特征处于相同的特征层;
确定多个所述第二损失值的第一和值;
确定所述第二损失值对应的目标权重;
确定所述第一和值与所述目标权重的第一乘积;
基于所述目标初始模型输出的特征与所述其他初始模型输出的特征确定第三损失值;
将所述第一乘积与所述第三损失值的第二和值确定为所述第一损失值。
4.根据权利要求3所述的方法,其特征在于,确定所述第二损失值对应的目标权重包括:
确定所述目标初始模型的评价指标与第一参数的第二乘积;
确定所述第二乘积与第二参数的第一差值;
将所述第一差值与第三参数的比值确定为所述目标权重。
5.根据权利要求1所述的方法,其特征在于,在基于所述初始网络模型中包括的第二初始模型的第二训练精度确定所述第二初始模型的第二评价指标之后,所述方法还包括:
在所述第一评价指标以及所述第二评价指标均小于或等于预定阈值的情况下,将所述第一初始模型以及所述第二初始模型确定为所述目标初始模型;
利用训练数据迭代更新所述目标初始模型的网络参数,得到更新后的所述初始网络模型。
6.根据权利要求5所述的方法,其特征在于,利用训练数据迭代更新所述目标初始模型的网络参数,得到更新后的所述初始网络模型包括:
将所述训练数据输入至所述第一初始模型和所述第二初始模型中,基于所述第一初始模型输出的特征以及所述第二初始模型输出的特征确定所述初始网络模型的第四损失值;
基于所述第四损失值迭代更新所述第一初始模型以及所述第二初始模型的网络参数,得到更新后的所述初始网络模型。
7.根据权利要求1所述的方法,其特征在于,
基于初始网络模型中包括的第一初始模型的第一训练精度确定所述第一初始模型的第一评价指标包括:在所述初始网络模型为分类网络模型的情况下,将所述第一训练精度确定为所述第一评价指标;在所述初始网络模型为检测网络模型的情况下,确定所述第一训练精度与第四参数的第三和值,将所述第三和值与第五参数的比值确定为所述第一评价指标;
基于所述初始网络模型中包括的第二初始模型的第二训练精度确定所述第二初始模型的第二评价指标包括:在所述初始网络模型为分类网络模型的情况下,将所述第二训练精度确定为所述第二评价指标;在所述初始网络模型为检测网络模型的情况下,确定所述第二训练精度与第六参数的第四和值,将所述第四和值与第七参数的比值确定为所述第一评价指标。
8.一种模型的训练装置,其特征在于,包括:
第一确定模块,用于基于初始网络模型中包括的第一初始模型的第一训练精度确定所述第一初始模型的第一评价指标,以及基于所述初始网络模型中包括的第二初始模型的第二训练精度确定所述第二初始模型的第二评价指标,其中,所述初始网络模型为经过训练得到的网络模型;
第二确定模块,用于在所述第一评价指标以及所述第二评价指标中存在大于预定阈值的评价指标的情况下,将所述第一评价指标以及所述第二评价指标中包括的最小评价指标对应的初始模型确定为待优化的目标初始模型;
训练模块,用于利用训练数据以及其他初始模型训练所述目标初始模型,得到目标网络模型,其中,所述其他初始模型为所述第一评价指标以及所述第二评价指标中包括的最大评价指标对应的初始模型。
9.一种计算机可读存储介质,其特征在于,所述计算机可读存储介质中存储有计算机程序,其中,所述计算机程序被处理器执行时实现所述权利要求1至7任一项中所述的方法的步骤。
10.一种电子装置,包括存储器和处理器,其特征在于,所述存储器中存储有计算机程序,所述处理器被设置为运行所述计算机程序以执行所述权利要求1至7任一项中所述的方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210407353.8A CN114511083A (zh) | 2022-04-19 | 2022-04-19 | 一种模型的训练方法、装置、存储介质及电子装置 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210407353.8A CN114511083A (zh) | 2022-04-19 | 2022-04-19 | 一种模型的训练方法、装置、存储介质及电子装置 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN114511083A true CN114511083A (zh) | 2022-05-17 |
Family
ID=81555025
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202210407353.8A Pending CN114511083A (zh) | 2022-04-19 | 2022-04-19 | 一种模型的训练方法、装置、存储介质及电子装置 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN114511083A (zh) |
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN114821247A (zh) * | 2022-06-30 | 2022-07-29 | 杭州闪马智擎科技有限公司 | 一种模型的训练方法、装置、存储介质及电子装置 |
CN114998570A (zh) * | 2022-07-19 | 2022-09-02 | 上海闪马智能科技有限公司 | 一种对象检测框的确定方法、装置、存储介质及电子装置 |
-
2022
- 2022-04-19 CN CN202210407353.8A patent/CN114511083A/zh active Pending
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN114821247A (zh) * | 2022-06-30 | 2022-07-29 | 杭州闪马智擎科技有限公司 | 一种模型的训练方法、装置、存储介质及电子装置 |
CN114998570A (zh) * | 2022-07-19 | 2022-09-02 | 上海闪马智能科技有限公司 | 一种对象检测框的确定方法、装置、存储介质及电子装置 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN108804641B (zh) | 一种文本相似度的计算方法、装置、设备和存储介质 | |
CN110366734B (zh) | 优化神经网络架构 | |
CN108681750A (zh) | Gbdt模型的特征解释方法和装置 | |
CN114511083A (zh) | 一种模型的训练方法、装置、存储介质及电子装置 | |
CN111079780A (zh) | 空间图卷积网络的训练方法、电子设备及存储介质 | |
CN110874528B (zh) | 文本相似度的获取方法及装置 | |
CN110929532B (zh) | 数据处理方法、装置、设备及存储介质 | |
CN111259647A (zh) | 基于人工智能的问答文本匹配方法、装置、介质及电子设备 | |
CN113961765B (zh) | 基于神经网络模型的搜索方法、装置、设备和介质 | |
CN112307048B (zh) | 语义匹配模型训练方法、匹配方法、装置、设备及存储介质 | |
CN112182214A (zh) | 一种数据分类方法、装置、设备及介质 | |
CN106570197A (zh) | 基于迁移学习的搜索排序方法和装置 | |
US20230385317A1 (en) | Information Retrieval Method, Related System, and Storage Medium | |
CN114511042A (zh) | 一种模型的训练方法、装置、存储介质及电子装置 | |
CN115393633A (zh) | 数据处理方法、电子设备、存储介质及程序产品 | |
CN117973492A (zh) | 一种语言模型的微调方法、装置、电子设备及介质 | |
CN110262906B (zh) | 接口标签推荐方法、装置、存储介质和电子设备 | |
CN112070205A (zh) | 一种多损失模型获取方法以及装置 | |
CN104572820B (zh) | 模型的生成方法及装置、重要度获取方法及装置 | |
CN111126617A (zh) | 一种选择融合模型权重参数的方法、装置及设备 | |
CN115983362A (zh) | 一种量化方法、推荐方法以及装置 | |
CN113407806B (zh) | 网络结构搜索方法、装置、设备及计算机可读存储介质 | |
CN114357219A (zh) | 一种面向移动端实例级图像检索方法及装置 | |
CN111026661B (zh) | 一种软件易用性全面测试方法及系统 | |
CN109492046A (zh) | 集成模型的失效判断方法及装置和计算机可读存储介质 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |