CN110309922A - 一种网络模型训练方法和装置 - Google Patents

一种网络模型训练方法和装置 Download PDF

Info

Publication number
CN110309922A
CN110309922A CN201910527781.2A CN201910527781A CN110309922A CN 110309922 A CN110309922 A CN 110309922A CN 201910527781 A CN201910527781 A CN 201910527781A CN 110309922 A CN110309922 A CN 110309922A
Authority
CN
China
Prior art keywords
network model
loss function
sub
numerical value
training
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN201910527781.2A
Other languages
English (en)
Inventor
张文迪
崔正文
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Beijing QIYI Century Science and Technology Co Ltd
Original Assignee
Beijing QIYI Century Science and Technology Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Beijing QIYI Century Science and Technology Co Ltd filed Critical Beijing QIYI Century Science and Technology Co Ltd
Priority to CN201910527781.2A priority Critical patent/CN110309922A/zh
Publication of CN110309922A publication Critical patent/CN110309922A/zh
Pending legal-status Critical Current

Links

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N20/00Machine learning

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Software Systems (AREA)
  • Data Mining & Analysis (AREA)
  • Evolutionary Computation (AREA)
  • Medical Informatics (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Physics & Mathematics (AREA)
  • Computing Systems (AREA)
  • General Engineering & Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Mathematical Physics (AREA)
  • Artificial Intelligence (AREA)
  • Data Exchanges In Wide-Area Networks (AREA)

Abstract

本发明实施例提供了一种网络模型训练方法和装置,方法包括:将预设训练样本集中的训练样本,分别输入至目标网络模型所包含的各子网络模型,获取各子网络模型各自的损失函数的数值,以及原始损失函数的数值,其中,一个子网络模型的损失函数用于表示该子网络模型的实际输出结果与期望输出结果之间的差值,原始损失函数用于表示目标网络模型的实际输出结果与期望输出结果之间的差值,根据各子网络模型各自的损失函数的数值,以及原始损失函数的数值,对目标网络模型的模型参数进行调整,当达到预设停止训练条件时,停止模型训练,得到训练好的目标网络模型。基于上述处理,能够提高训练好的网络模型的有效性。

Description

一种网络模型训练方法和装置
技术领域
本发明涉及人工智能技术领域,特别是涉及一种网络模型训练方法和装置。
背景技术
随着人工智能技术的快速发展,神经网络模型(可以简称为网络模型)在系统辨识、模式识别、智能控制等领域有着广泛的应用前景。通常可以基于训练样本集,对预设的网络模型进行训练,得到训练好的网络模型,进而,可以将待检测样本输入至训练好的网络模型,得到网络模型的实际输出结果,实际输出结果为对待检测样本进行预测的预测结果。
一种实现方式中,网络模型具有初始的模型参数,在对网络模型进行训练的过程中,可以根据损失函数对网络模型的模型参数进行调整,损失函数可以用于表示网络模型的实际输出结果与期望输出结果之间的差值,对模型参数进行调整的目的是为了使损失函数的数值不断减小。当达到预设停止训练条件,得到训练好的网络模型。
然而,发明人在实现本发明的过程中发现,现有技术至少存在如下问题:
针对包含多个子网络模型的目标网络模型,由于各子网络模型的结构、特性之间存在差异,在根据目标网络模型的损失函数对目标网络模型进行训练的情况下,当停止训练目标网络模型时,一些子网络模型达到较好的收敛状态,另一些子网络模型可能并未达到收敛状态,进而会导致训练好的目标网络模型的有效性较低。
发明内容
本发明实施例的目的在于提供一种网络模型训练方法和装置,能够提高训练好的网络模型的有效性。具体技术方案如下:
第一方面,为了达到上述目的,本发明实施例公开了一种网络模型训练方法,所述方法包括:
将预设训练样本集中的训练样本,分别输入至目标网络模型所包含的各子网络模型;
获取所述各子网络模型各自的损失函数的数值,以及原始损失函数的数值,其中,一个子网络模型的损失函数用于表示该子网络模型的实际输出结果与期望输出结果之间的差值,所述原始损失函数用于表示所述目标网络模型的实际输出结果与期望输出结果之间的差值;
根据所述各子网络模型各自的损失函数的数值,以及所述原始损失函数的数值,对所述目标网络模型的模型参数进行调整;
当达到预设停止训练条件时,停止模型训练,得到训练好的目标网络模型。
可选的,所述根据所述各子网络模型各自的损失函数的数值,以及所述原始损失函数的数值,对所述目标网络模型的模型参数进行调整,包括:
根据所述各子网络模型各自的损失函数的数值与所述原始损失函数的数值的总和值,对所述目标网络模型的模型参数进行调整。
可选的,所述原始损失函数为所述目标网络模型的实际输出结果与期望输出结果的交叉熵,一个子网络模型的损失函数为该子网络模型的实际输出结果与期望输出结果的交叉熵。
可选的,所述目标网络模型为宽度和深度Wide&Deep网络模型。
可选的,所述预设停止训练条件为:
根据所述预设训练样本集对所述目标网络模型进行模型训练的次数,达到预设次数;
或者,
将测试样本输入至所述目标网络模型中,得到的所述目标损失函数的数值小于预设阈值。
第二方面,为了达到上述目的,本发明实施例公开了一种网络模型训练装置,所述装置包括:
第一处理模块,用于将预设训练样本集中的训练样本,分别输入至目标网络模型所包含的各子网络模型;
获取模块,用于获取所述各子网络模型各自的损失函数的数值,以及原始损失函数的数值,其中,一个子网络模型的损失函数用于表示该子网络模型的实际输出结果与期望输出结果之间的差值,所述原始损失函数用于表示所述目标网络模型的实际输出结果与期望输出结果之间的差值;
调整模块,用于根据所述各子网络模型各自的损失函数的数值,以及所述原始损失函数的数值,对所述目标网络模型的模型参数进行调整;
第二处理模块,用于当达到预设停止训练条件时,停止模型训练,得到训练好的目标网络模型。
可选的,所述调整模块,具体用于根据所述各子网络模型各自的损失函数的数值与所述原始损失函数的数值的总和值,对所述目标网络模型的模型参数进行调整。
可选的,所述原始损失函数为所述目标网络模型的实际输出结果与期望输出结果的交叉熵,一个子网络模型的损失函数为该子网络模型的实际输出结果与期望输出结果的交叉熵。
可选的,所述目标网络模型为宽度和深度Wide&Deep网络模型。
可选的,所述预设停止训练条件为:
根据所述预设训练样本集对所述目标网络模型进行模型训练的次数,达到预设次数;
或者,
将测试样本输入至所述目标网络模型中,得到的所述目标损失函数的数值小于预设阈值。
在本发明实施的又一方面,还提供了一种电子设备,所述电子设备包括处理器、通信接口、存储器和通信总线,其中,处理器,通信接口,存储器通过通信总线完成相互间的通信;
存储器,用于存放计算机程序;
处理器,用于执行存储器上所存放的程序时,实现上述任一所述的网络模型训练方法。
在本发明实施的又一方面,还提供了一种计算机可读存储介质,所述计算机可读存储介质中存储有指令,当其在计算机上运行时,使得计算机执行上述任一所述的网络模型训练方法。
在本发明实施的又一方面,本发明实施例还提供了一种包含指令的计算机程序产品,当其在计算机上运行时,使得计算机执行上述任一所述的网络模型训练方法。
本发明实施例提供了一种网络模型训练方法,可以将预设训练样本集中的训练样本,分别输入至目标网络模型所包含的各子网络模型,获取各子网络模型各自的损失函数的数值,以及原始损失函数的数值,其中,一个子网络模型的损失函数用于表示该子网络模型的实际输出结果与期望输出结果之间的差值,原始损失函数用于表示目标网络模型的实际输出结果与期望输出结果之间的差值,根据各子网络模型各自的损失函数的数值,以及原始损失函数的数值,对目标网络模型的模型参数进行调整,当达到预设停止训练条件时,停止模型训练,得到训练好的目标网络模型。由于根据各子网络模型各自的损失函数的数值,以及原始损失函数的数值,对目标网络模型的模型参数进行调整,因此,当达到预设停止训练条件时,各子网络模型和目标网络模型都达到较好的收敛状态,能够提高训练好的目标网络模型的有效性。
当然,实施本发明的任一产品或方法并不一定需要同时达到以上所述的所有优点。
附图说明
为了更清楚地说明本发明实施例或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍。
图1为本发明实施例提供的一种网络模型训练方法的流程图;
图2为本发明实施例提供的一种网络模型训练方法示例的流程图;
图3为本发明实施例提供的一种计算目标损失函数过程的示意图;
图4(a)为采用原始损失函数进行模型训练时,宽度和深度网络模型的训练曲线图;
图4(b)为采用原始损失函数进行模型训练时,深度子网络模型的训练曲线图;
图4(c)为采用原始损失函数进行模型训练时,宽度子网络模型的训练曲线图;
图5(a)为采用目标损失函数进行模型训练时,宽度和深度网络模型的训练曲线图;
图5(b)为采用目标损失函数进行模型训练时,深度子网络模型的训练曲线图;
图5(c)为采用目标损失函数进行模型训练时,宽度子网络模型的训练曲线图;
图6为本发明实施例提供的一种网络模型训练装置的结构图;
图7为本发明实施例提供的一种电子设备的结构图。
具体实施方式
下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行描述。
现有技术中,针对包含多个子网络模型的目标网络模型,由于各子网络模型的结构、特性之间存在差异,当根据目标网络模型的实际输出结果和期望输出结果,确定达到预设停止训练条件时,目标网络模型中的一些子网络模型可能并未达到收敛状态,进而会导致训练好的目标网络模型的有效性较低。
为了解决上述问题,本发明提供一种网络模型训练方法,该方法可以应用于电子设备,该电子设备可以是终端,也可以是服务器,该电子设备用于对网络模型进行训练。
电子设备可以将预设训练样本集中的训练样本,分别输入至目标网络模型所包含的各子网络模型,获取各子网络模型各自的损失函数的数值,以及原始损失函数的数值,其中,一个子网络模型的损失函数用于表示该子网络模型的实际输出结果与期望输出结果之间的差值,原始损失函数用于表示目标网络模型的实际输出结果与期望输出结果之间的差值。
然后,电子设备可以根据各子网络模型各自的损失函数的数值,以及原始损失函数的数值,对目标网络模型的模型参数进行调整。
当达到预设停止训练条件时,停止模型训练,得到训练好的目标网络模型。
由于电子设备根据各子网络模型各自的损失函数的数值,以及原始损失函数的数值,对目标网络模型的模型参数进行调整,因此,当达到预设停止训练条件时,各子网络模型和目标网络模型都达到较好的收敛状态,进而,能够提高训练好的目标网络模型的有效性。
下面以具体实施例对本发明进行详细介绍。
参见图1,图1为本发明实施例提供的一种网络模型训练方法的流程图,该方法可以包括以下步骤:
S101:将预设训练样本集中的训练样本,分别输入至目标网络模型所包含的各子网络模型。
在发明实施中,电子设备可以确定当前待训练的目标网络模型,并获取预设训练样本集,进而,电子设备可以将训练样本集中的训练样本包含的输入参数分别输入至目标网络模型所包含的各子网络模型,设置目标网络模型的输出为训练样本包含的对应的输出参数,以对目标网络模型进行训练。
本步骤中,电子设备可以依次将每一训练样本分别输入至各子网络模型,对目标网络模型进行模型训练,也可以依次将预设数目个训练样本分别输入各子网络模型,对目标网络模型进行模型训练,电子设备根据训练样本对目标网络模型进行模型训练的方式并不限于此。
S102:获取各子网络模型各自的损失函数的数值,以及原始损失函数的数值。
其中,一个子网络模型的损失函数用于表示该子网络模型的实际输出结果与期望输出结果之间的差值,原始损失函数用于表示目标网络模型的实际输出结果与期望输出结果之间的差值。
目标网络模型的实际输出结果可以为各子网络模型的实际输出结果的加权和,各子网络模型的实际输出结果各自的权重可以是在对目标网络模型进行模型训练的过程中确定出的。
在将训练样本输入至目标网络模型后,电子设备可以获取每一子网络模型的实际输出结果,并根据该子网络模型的期望输出结果,得到该子网络模型的损失函数的数值。
另外,电子设备还可以获取目标网络模型总的实际输出结果,并根据目标网络模型总的期望输出结果,得到目标网络模型的损失函数(即原始损失函数)的数值。
S103:根据各子网络模型各自的损失函数的数值,以及原始损失函数的数值,对目标网络模型的模型参数进行调整。
在发明实施中,在对目标网络模型进行模型训练的过程中,电子设备可以根据各子网络模型各自的损失函数的数值,以及原始损失函数的数值,对目标网络模型的模型参数进行调整,直至达到预设停止训练条件。
可选的,S103可以包括以下步骤:
根据各子网络模型各自的损失函数的数值与原始损失函数的数值的总和值,对目标网络模型的模型参数进行调整。
在发明实施中,电子设备可以将各子网络模型的损失函数与原始损失函数的总和,作为目标损失函数。
进而,在对目标网络模型进行模型训练的过程中,可以根据目标损失函数对目标网络模型的模型参数进行调整,以使在训练目标网络模型的过程中,目标损失函数的数值不断减少,直至达到预设停止训练条件。
S104:当达到预设停止训练条件时,停止模型训练,得到训练好的目标网络模型。
其中,预设停止训练条件可以由技术人员根据经验进行设置。
可选的,预设停止训练条件可以为:根据预设训练样本集对目标网络模型进行模型训练的次数,达到预设次数。
预设次数可以由技术人员根据经验进行设置,例如,预设次数可以为10000次,但并不限于此。
一种实现方式中,如果预设次数为10000次,预设训练样本集包含1000个样本,则电子设备可以根据每一样本对目标网络模型重复训练10次,直至根据所有的样本对目标网络模型训练结束,此时,对目标网络模型进行模型训练的次数达到预设次数(即10000次),电子设备可以确定达到预设停止训练条件,进而,完成模型训练,得到训练好的目标网络模型。
或者,预设停止训练条件也可以为:将测试样本输入至目标网络模型中,得到的目标损失函数的数值小于预设阈值。
其中,预设阈值可以由技术人员根据经验进行设置,例如,预设阈值可以为0.01,但并不限于此。
一种实现方式中,如果预设阈值为0.01,在根据训练样本集对目标网络模型进行模型训练的过程中,电子设备可以将测试样本输入至已训练的目标网络模型中,并判断此时目标损失函数的数值是否小于0.01。当电子设备判定目标损失函数的数值小于0.01时,电子设备可以确定当前达到预设停止训练条件,进而,完成模型训练,得到训练好的目标网络模型。
可见,基于本实施例提供的网络模型训练方法,由于电子设备根据各子网络模型各自的损失函数的数值,以及原始损失函数的数值,对目标网络模型的模型参数进行调整,因此,当达到预设停止训练条件时,各子网络模型和目标网络模型都达到较好的收敛状态,能够提高训练好的目标网络模型的有效性。
上述实施例中的各损失函数可以为相同的损失函数,也可以为不同的损失函数,例如,铰链损失函数(Hinge Loss Function)、交叉熵损失函数(Cross-entropy LossFunction)或其他损失函数。
可选的,为了提高对目标网络模型进行模型训练的效率,原始损失函数可以为目标网络模型的实际输出结果与期望输出结果的交叉熵,一个子网络模型的损失函数可以为该子网络模型的实际输出结果与期望输出结果的交叉熵。
一种实现方式中,电子设备可以计算各子网络模型的实际输出结果与预期输出结果之间的交叉熵。然后,电子设备可以计算目标网络模型的实际输出结果和预期输出结果的交叉熵。
进而,在对目标网络模型进行模型训练的过程中,电子设备可以根据得到的各交叉熵的总和值,对目标网络模型的模型参数进行调整。
在电子设备将测试样本输入至已训练的目标网络模型中,得到对应的各交叉熵的总和值后,电子设备可以判断得到的各交叉熵的总和值是否小于预设阈值。当电子设备判定得到的各交叉熵的总和值小于预设阈值时,电子设备可以确定当前达到预设停止训练条件,进而,完成模型训练,得到训练好的目标网络模型。
综上,基于本实施例的网络模型训练方法,各损失函数可以均为交叉熵损失函数,能够提高每一子网络模型的训练效率,进而提高目标网络模型的训练效率。
可选的,目标网络模型可以为Wide&Deep(宽度和深度)网络模型(一类用于分类和回归的网络模型),本申请中的宽度和深度网络模型均指Wide&Deep网络模型。子网络模型可以为宽度子网络模型,或者深度子网络模型。目标网络模型中的宽度子网络模型可以为一个,也可以为多个;目标网络模型中的深度子网络模型可以为一个,也可以为多个。
一种实现方式中,如果目标网络模型为宽度和深度网络模型,宽度和深度网络模型包含一个深度子网络模型和一个宽度子网络模型。
可以用Output-wide表示宽度子网络模型的实际输出结果,Label-wide表示宽度子网络模型的期望输出结果,则Loss-wide=H(Label-wide,Output-wide),Loss-wide表示宽度子网络模型的损失函数,H()可以表示交叉熵损失函数。
Output-deep表示深度子网络模型的实际输出结果,Label-deep表示深度子网络模型的期望输出结果,则Loss-deep=H(Label-deep,Output-deep),Loss-deep表示深度子网络模型的损失函数。
Output-all=Output-wide×W-wide+Output-deep×W-deep,Output-all表示宽度和深度网络模型的实际输出结果,W-wide表示宽度子网络模型的权重,W-deep表示深度子网络模型的权重,可以通过对宽度和深度网络模型的模型训练,确定W-wi de和W-deep。Loss-all=H(Label-all,Output-all),Loss-all表示原始损失函数,Label-al l表示宽度和深度网络模型的期望输出结果。
则可以得到Loss=Loss-all+Loss-deep+Loss-wide,Loss表示目标损失函数。电子设备可以在根据测试样本得到的Loss的数值小于预设阈值时,确定达到预设停止训练条件,进而,完成模型训练,得到训练好的宽度和深度网络模型。
参见图2,图2为本发明实施例提供的一种网络模型训练方法示例的流程图,该方法可以包括以下步骤:
S201:将预设训练样本集中的训练样本,分别输入至宽度和深度网络模型中的宽度子网络模型和深度子网络模型。
S202:获取宽度子网络模型的损失函数的数值、深度子网络模型的损失函数的数值,以及原始损失函数的数值。
其中,原始损失函数为宽度和深度网络模型的实际输出结果与期望输出结果的交叉熵,宽度子网络模型的损失函数为宽度子网络模型的实际输出结果与期望输出结果的交叉熵,深度子网络模型的损失函数为深度子网络模型的实际输出结果与期望输出结果的交叉熵。
S203:根据宽度子网络模型的损失函数的数值、深度子网络模型的损失函数的数值与原始损失函数的数值的总和值,对宽度和深度网络模型的模型参数进行调整。
S204:当达到预设停止训练条件时,停止模型训练,得到训练好的宽度和深度网络模型。
参见图3,图3为本发明实施例提供的一种计算目标损失函数过程的示意图,图3与图2的方法相对应。
图3中,宽度和深度网络模型可以包括宽度子网络模型和深度子网络模型,计算宽度子网络模型的实际输出结果与深度网络子模型的实际输出结果的加权和,得到宽度和深度网络模型的实际输出结果,并根据宽度和深度网络模型的预期输出结果,得到原始损失函数。根据宽度子网络模型的实际输出结果,得到宽度子网络模型的损失函数,根据深度子网络模型的实际输出结果,得到深度子网络模型的损失函数,进而,将宽度子网络模型的损失函数、深度子网络模型的损失函数和原始损失函数的总和,作为目标损失函数。
由于目标损失函数为宽度子网络模型的损失函数、深度子网络模型的损失函数和原始损失函数的总和,因此,在根据目标损失函数对宽度和深度网络模型的模型参数进行调整的过程中,达到预设停止训练条件时,宽度子网络模型、深度子网络模型,以及宽度和深度网络模型都达到较好的收敛状态,能够提高训练好的宽度和深度网络模型的有效性。
参见图4(a)为采用现有的损失函数(即原始损失函数)进行模型训练时,目标网络模型(即宽度和深度网络模型)的训练曲线图。
图4(b)为采用原始损失函数进行模型训练时,深度子网络模型的训练曲线图。
图4(c)为采用原始损失函数进行模型训练时,宽度子网络模型的训练曲线图。
图5(a)为采用目标损失函数进行模型训练时,宽度和深度网络模型的训练曲线图。
图5(b)为采用目标损失函数进行模型训练时,深度子网络模型的训练曲线图。
图5(c)为采用目标损失函数进行模型训练时,宽度子网络模型的训练曲线图。
上图中,带圆点的线为测试样本对应的曲线,不带圆点的线为训练样本对应的曲线,横坐标表示训练次数,纵坐标表示网络模型的精准度。
对比图4(a)和图5(a),采用目标损失函数进行模型训练时,宽度和深度网络模型整体的精准度有提高。
对比图4(b)和图5(b),采用原始损失函数进行模型训练时,深度子网络模型的精准度逐渐下降,而采用目标损失函数进行模型训练时,深度网络子模型的精准度则逐渐提升。
对比图4(c)和图5(c),采用目标损失函数进行模型训练时,能够在深度子网络模型具有较高精准度的前提下,保证宽度子网络模型也具有较高的精准度。
与图1的方法实施例相对应,参见图6,图6为本发明实施例提供的一种网络模型训练装置的结构图,所述装置可以包括:
第一处理模块601,用于将预设训练样本集中的训练样本,分别输入至目标网络模型所包含的各子网络模型;
获取模块602,用于获取所述各子网络模型各自的损失函数的数值,以及原始损失函数的数值,其中,一个子网络模型的损失函数用于表示该子网络模型的实际输出结果与期望输出结果之间的差值,所述原始损失函数用于表示所述目标网络模型的实际输出结果与期望输出结果之间的差值;
调整模块603,用于根据所述各子网络模型各自的损失函数的数值,以及所述原始损失函数的数值,对所述目标网络模型的模型参数进行调整;
第二处理模块604,用于当达到预设停止训练条件时,停止模型训练,得到训练好的目标网络模型。
可选的,所述调整模块603,具体用于根据所述各子网络模型各自的损失函数的数值与所述原始损失函数的数值的总和值,对所述目标网络模型的模型参数进行调整。
可选的,所述原始损失函数为所述目标网络模型的实际输出结果与期望输出结果的交叉熵,一个子网络模型的损失函数为该子网络模型的实际输出结果与期望输出结果的交叉熵。
可选的,所述目标网络模型为宽度和深度Wide&Deep网络模型。
可选的,所述预设停止训练条件为:
根据所述预设训练样本集对所述目标网络模型进行模型训练的次数,达到预设次数;
或者,
将测试样本输入至所述目标网络模型中,得到的所述目标损失函数的数值小于预设阈值。
可见,基于本发明实施例提供的网络模型训练装置,将预设训练样本集中的训练样本,分别输入至目标网络模型所包含的各子网络模型,获取各子网络模型各自的损失函数的数值,以及原始损失函数的数值,其中,一个子网络模型的损失函数用于表示该子网络模型的实际输出结果与期望输出结果之间的差值,原始损失函数用于表示目标网络模型的实际输出结果与期望输出结果之间的差值,根据各子网络模型各自的损失函数的数值,以及原始损失函数的数值,对目标网络模型的模型参数进行调整,当达到预设停止训练条件时,停止模型训练,得到训练好的目标网络模型。基于上述处理,能够提高训练好的网络模型的有效性。
本发明实施例还提供了一种电子设备,如图7所示,包括处理器701、通信接口702、存储器703和通信总线704,其中,处理器701,通信接口702,存储器703通过通信总线704完成相互间的通信,
存储器703,用于存放计算机程序;
处理器701,用于执行存储器703上所存放的程序时,实现本发明实施例提供的网络模型训练方法。
具体的,上述网络模型训练方法,包括:
将预设训练样本集中的训练样本,分别输入至目标网络模型所包含的各子网络模型;
获取所述各子网络模型各自的损失函数的数值,以及原始损失函数的数值,其中,一个子网络模型的损失函数用于表示该子网络模型的实际输出结果与期望输出结果之间的差值,所述原始损失函数用于表示所述目标网络模型的实际输出结果与期望输出结果之间的差值;
根据所述各子网络模型各自的损失函数的数值,以及所述原始损失函数的数值,对所述目标网络模型的模型参数进行调整;
当达到预设停止训练条件时,停止模型训练,得到训练好的目标网络模型。
上述电子设备提到的通信总线可以是外设部件互连标准(Peripheral ComponentInterconnect,简称PCI)总线或扩展工业标准结构(Extended Industry StandardArchitecture,简称EISA)总线等。该通信总线可以分为地址总线、数据总线、控制总线等。为便于表示,图中仅用一条粗线表示,但并不表示仅有一根总线或一种类型的总线。
通信接口用于上述电子设备与其他设备之间的通信。
存储器可以包括随机存取存储器(Random Access Memory,简称RAM),也可以包括非易失性存储器(non-volatile memory),例如至少一个磁盘存储器。可选的,存储器还可以是至少一个位于远离前述处理器的存储装置。
上述的处理器可以是通用处理器,包括中央处理器(Central Processing Unit,简称CPU)、网络处理器(Network Processor,简称NP)等;还可以是数字信号处理器(Digital Signal Processing,简称DSP)、专用集成电路(Application SpecificIntegrated Circuit,简称ASIC)、现场可编程门阵列(Field-Programmable Gate Array,简称FPGA)或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件。
本发明实施例提供的电子设备在对目标网络模型进行模型训练时,根据各子网络模型各自的损失函数的数值,以及原始损失函数的数值,对目标网络模型的模型参数进行调整,因此,当达到预设停止训练条件时,各子网络模型和目标网络模型都达到较好的收敛状态,能够提高训练好的目标网络模型的有效性。
本发明实施例还提供了一种计算机可读存储介质,该计算机可读存储介质中存储有指令,当其在计算机上运行时,使得计算机执行本发明实施例提供的网络模型训练方法。
具体的,上述网络模型训练方法,包括:
将预设训练样本集中的训练样本,分别输入至目标网络模型所包含的各子网络模型;
获取所述各子网络模型各自的损失函数的数值,以及原始损失函数的数值,其中,一个子网络模型的损失函数用于表示该子网络模型的实际输出结果与期望输出结果之间的差值,所述原始损失函数用于表示所述目标网络模型的实际输出结果与期望输出结果之间的差值;
根据所述各子网络模型各自的损失函数的数值,以及所述原始损失函数的数值,对所述目标网络模型的模型参数进行调整;
当达到预设停止训练条件时,停止模型训练,得到训练好的目标网络模型。
需要说明的是,上述网络模型训练方法的其他实现方式与前述方法实施例部分相同,这里不再赘述。
通过运行本发明实施例提供的计算机可读存储介质中存储的指令,根据各子网络模型各自的损失函数的数值,以及原始损失函数的数值,对目标网络模型的模型参数进行调整,因此,当达到预设停止训练条件时,各子网络模型和目标网络模型都达到较好的收敛状态,能够提高训练好的目标网络模型的有效性。
本发明实施例还提供了一种包含指令的计算机程序产品,当其在计算机上运行时,使得计算机执行本发明实施例提供的网络模型训练方法。
具体的,上述网络模型训练方法,包括:
将预设训练样本集中的训练样本,分别输入至目标网络模型所包含的各子网络模型;
获取所述各子网络模型各自的损失函数的数值,以及原始损失函数的数值,其中,一个子网络模型的损失函数用于表示该子网络模型的实际输出结果与期望输出结果之间的差值,所述原始损失函数用于表示所述目标网络模型的实际输出结果与期望输出结果之间的差值;
根据所述各子网络模型各自的损失函数的数值,以及所述原始损失函数的数值,对所述目标网络模型的模型参数进行调整;
当达到预设停止训练条件时,停止模型训练,得到训练好的目标网络模型。
需要说明的是,上述网络模型训练方法的其他实现方式与前述方法实施例部分相同,这里不再赘述。
通过运行本发明实施例提供的计算机程序产品,根据各子网络模型各自的损失函数的数值,以及原始损失函数的数值,对目标网络模型的模型参数进行调整,因此,当达到预设停止训练条件时,各子网络模型和目标网络模型都达到较好的收敛状态,能够提高训练好的目标网络模型的有效性。
在上述实施例中,可以全部或部分地通过软件、硬件、固件或者其任意组合来实现。当使用软件实现时,可以全部或部分地以计算机程序产品的形式实现。所述计算机程序产品包括一个或多个计算机指令。在计算机上加载和执行所述计算机程序指令时,全部或部分地产生按照本发明实施例所述的流程或功能。所述计算机可以是通用计算机、专用计算机、计算机网络、或者其他可编程装置。所述计算机指令可以存储在计算机可读存储介质中,或者从一个计算机可读存储介质向另一个计算机可读存储介质传输,例如,所述计算机指令可以从一个网站站点、计算机、服务器或数据中心通过有线(例如同轴电缆、光纤、数字用户线(DSL))或无线(例如红外、无线、微波等)方式向另一个网站站点、计算机、服务器或数据中心进行传输。所述计算机可读存储介质可以是计算机能够存取的任何可用介质或者是包含一个或多个可用介质集成的服务器、数据中心等数据存储设备。所述可用介质可以是磁性介质,(例如,软盘、硬盘、磁带)、光介质(例如,DVD)、或者半导体介质(例如固态硬盘Solid State Disk(SSD))等。
需要说明的是,在本文中,诸如第一和第二等之类的关系术语仅仅用来将一个实体或者操作与另一个实体或操作区分开来,而不一定要求或者暗示这些实体或操作之间存在任何这种实际的关系或者顺序。而且,术语“包括”、“包含”或者其任何其他变体意在涵盖非排他性的包含,从而使得包括一系列要素的过程、方法、物品或者设备不仅包括那些要素,而且还包括没有明确列出的其他要素,或者是还包括为这种过程、方法、物品或者设备所固有的要素。在没有更多限制的情况下,由语句“包括一个……”限定的要素,并不排除在包括所述要素的过程、方法、物品或者设备中还存在另外的相同要素。
本说明书中的各个实施例均采用相关的方式描述,各个实施例之间相同相似的部分互相参见即可,每个实施例重点说明的都是与其他实施例的不同之处。尤其,对于装置、电子设备、计算机可读存储介质、计算机程序产品实施例而言,由于其基本相似于方法实施例,所以描述的比较简单,相关之处参见方法实施例的部分说明即可。
以上所述仅为本发明的较佳实施例而已,并非用于限定本发明的保护范围。凡在本发明的精神和原则之内所作的任何修改、等同替换、改进等,均包含在本发明的保护范围内。

Claims (11)

1.一种网络模型训练方法,其特征在于,所述方法包括:
将预设训练样本集中的训练样本,分别输入至目标网络模型所包含的各子网络模型;
获取所述各子网络模型各自的损失函数的数值,以及原始损失函数的数值,其中,一个子网络模型的损失函数用于表示该子网络模型的实际输出结果与期望输出结果之间的差值,所述原始损失函数用于表示所述目标网络模型的实际输出结果与期望输出结果之间的差值;
根据所述各子网络模型各自的损失函数的数值,以及所述原始损失函数的数值,对所述目标网络模型的模型参数进行调整;
当达到预设停止训练条件时,停止模型训练,得到训练好的目标网络模型。
2.根据权利要求1所述的方法,其特征在于,所述根据所述各子网络模型各自的损失函数的数值,以及所述原始损失函数的数值,对所述目标网络模型的模型参数进行调整,包括:
根据所述各子网络模型各自的损失函数的数值与所述原始损失函数的数值的总和值,对所述目标网络模型的模型参数进行调整。
3.根据权利要求1所述的方法,其特征在于,所述原始损失函数为所述目标网络模型的实际输出结果与期望输出结果的交叉熵,一个子网络模型的损失函数为该子网络模型的实际输出结果与期望输出结果的交叉熵。
4.根据权利要求1所述的方法,其特征在于,所述目标网络模型为宽度和深度Wide&Deep网络模型。
5.根据权利要求1所述的方法,其特征在于,所述预设停止训练条件为:
根据所述预设训练样本集对所述目标网络模型进行模型训练的次数,达到预设次数;
或者,
将测试样本输入至所述目标网络模型中,得到的所述目标损失函数的数值小于预设阈值。
6.一种网络模型训练装置,其特征在于,所述装置包括:
第一处理模块,用于将预设训练样本集中的训练样本,分别输入至目标网络模型所包含的各子网络模型;
获取模块,用于获取所述各子网络模型各自的损失函数的数值,以及原始损失函数的数值,其中,一个子网络模型的损失函数用于表示该子网络模型的实际输出结果与期望输出结果之间的差值,所述原始损失函数用于表示所述目标网络模型的实际输出结果与期望输出结果之间的差值;
调整模块,用于根据所述各子网络模型各自的损失函数的数值,以及所述原始损失函数的数值,对所述目标网络模型的模型参数进行调整;
第二处理模块,用于当达到预设停止训练条件时,停止模型训练,得到训练好的目标网络模型。
7.根据权利要求6所述的装置,其特征在于,所述调整模块,具体用于根据所述各子网络模型各自的损失函数的数值与所述原始损失函数的数值的总和值,对所述目标网络模型的模型参数进行调整。
8.根据权利要求6所述的装置,其特征在于,所述原始损失函数为所述目标网络模型的实际输出结果与期望输出结果的交叉熵,一个子网络模型的损失函数为该子网络模型的实际输出结果与期望输出结果的交叉熵。
9.根据权利要求6所述的装置,其特征在于,所述目标网络模型为宽度和深度Wide&Deep网络模型。
10.根据权利要求6所述的装置,其特征在于,所述预设停止训练条件为:
根据所述预设训练样本集对所述目标网络模型进行模型训练的次数,达到预设次数;
或者,
将测试样本输入至所述目标网络模型中,得到的所述目标损失函数的数值小于预设阈值。
11.一种电子设备,其特征在于,包括处理器、通信接口、存储器和通信总线,其中,所述处理器,所述通信接口,所述存储器通过所述通信总线完成相互间的通信;
所述存储器,用于存放计算机程序;
所述处理器,用于执行所述存储器上所存放的程序时,实现权利要求1-5任一所述的方法步骤。
CN201910527781.2A 2019-06-18 2019-06-18 一种网络模型训练方法和装置 Pending CN110309922A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN201910527781.2A CN110309922A (zh) 2019-06-18 2019-06-18 一种网络模型训练方法和装置

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN201910527781.2A CN110309922A (zh) 2019-06-18 2019-06-18 一种网络模型训练方法和装置

Publications (1)

Publication Number Publication Date
CN110309922A true CN110309922A (zh) 2019-10-08

Family

ID=68077418

Family Applications (1)

Application Number Title Priority Date Filing Date
CN201910527781.2A Pending CN110309922A (zh) 2019-06-18 2019-06-18 一种网络模型训练方法和装置

Country Status (1)

Country Link
CN (1) CN110309922A (zh)

Cited By (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN111046027A (zh) * 2019-11-25 2020-04-21 北京百度网讯科技有限公司 时间序列数据的缺失值填充方法和装置
CN111091116A (zh) * 2019-12-31 2020-05-01 华南师范大学 一种用于判断心律失常的信号处理方法及系统
CN111310823A (zh) * 2020-02-12 2020-06-19 北京迈格威科技有限公司 目标分类方法、装置和电子系统
CN111626098A (zh) * 2020-04-09 2020-09-04 北京迈格威科技有限公司 模型的参数值更新方法、装置、设备及介质

Cited By (6)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN111046027A (zh) * 2019-11-25 2020-04-21 北京百度网讯科技有限公司 时间序列数据的缺失值填充方法和装置
CN111091116A (zh) * 2019-12-31 2020-05-01 华南师范大学 一种用于判断心律失常的信号处理方法及系统
CN111091116B (zh) * 2019-12-31 2021-05-18 华南师范大学 一种用于判断心律失常的信号处理方法及系统
CN111310823A (zh) * 2020-02-12 2020-06-19 北京迈格威科技有限公司 目标分类方法、装置和电子系统
CN111310823B (zh) * 2020-02-12 2024-03-29 北京迈格威科技有限公司 目标分类方法、装置和电子系统
CN111626098A (zh) * 2020-04-09 2020-09-04 北京迈格威科技有限公司 模型的参数值更新方法、装置、设备及介质

Similar Documents

Publication Publication Date Title
CN110309922A (zh) 一种网络模型训练方法和装置
CN109754105B (zh) 一种预测方法及终端、服务器
CN110766080B (zh) 一种标注样本确定方法、装置、设备及存储介质
CN113435247B (zh) 一种通信干扰智能识别方法、系统及终端
EP3729857A1 (en) Radio coverage map generation
US11721229B2 (en) Question correction method, device, electronic equipment and storage medium for oral calculation questions
CN109188410A (zh) 一种非视距场景下的距离校准方法、装置及设备
CN109977415A (zh) 一种文本纠错方法及装置
CN109981195B (zh) 无线信号强度的处理方法及装置
CN114936323B (zh) 图表示模型的训练方法、装置及电子设备
CN114520736A (zh) 一种物联网安全检测方法、装置、设备及存储介质
Arjona et al. Fast fuzzy anti‐collision protocol for the RFID standard EPC Gen‐2
CN111626360A (zh) 用于检测锅炉故障类型的方法、装置、设备和存储介质
CN108495265B (zh) 一种室内定位方法、装置及计算设备
CN111565065B (zh) 一种无人机基站部署方法、装置及电子设备
KR101846970B1 (ko) 전자전 위협신호의 분류를 위한 딥 신경망 학습장치 및 방법
CN109814067A (zh) 一种三维节点定位方法及装置
CN112329692B (zh) 一种有限样本条件下跨场景人体行为无线感知方法及装置
CN111368792B (zh) 特征点标注模型训练方法、装置、电子设备及存储介质
CN111310823B (zh) 目标分类方法、装置和电子系统
CN111585739B (zh) 一种相位调整方法及装置
CN114970495A (zh) 人名消歧方法、装置、电子设备及存储介质
CN107436788B (zh) 一种应用程序的卸载方法、装置及终端设备
CN111400677A (zh) 一种用户检测方法及装置
CN112926608A (zh) 一种图像分类方法、装置、电子设备及存储介质

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
RJ01 Rejection of invention patent application after publication
RJ01 Rejection of invention patent application after publication

Application publication date: 20191008