CN109583594A - 深度学习训练方法、装置、设备及可读存储介质 - Google Patents
深度学习训练方法、装置、设备及可读存储介质 Download PDFInfo
- Publication number
- CN109583594A CN109583594A CN201811369102.5A CN201811369102A CN109583594A CN 109583594 A CN109583594 A CN 109583594A CN 201811369102 A CN201811369102 A CN 201811369102A CN 109583594 A CN109583594 A CN 109583594A
- Authority
- CN
- China
- Prior art keywords
- model
- training
- source domain
- initial parameter
- parameter
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Granted
Links
Landscapes
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
Abstract
本发明实施例提供一种深度学习训练方法、装置、设备及可读存储介质。本发明实施例的方法通过将源域数训据集拆分成多个源域数据组,在对源域模型的每一轮模型训练中,均从训练数据组中随机抽取的多个小样本训练集,作为本轮的训练数据进行模型训练,得到各小样本训练集的模型参数;根据各小样本训练集的模型参数更新源域模型的初始参数,根据更新后的初始参数能得到本轮训练后的新的模型;由于每轮模型训练均重新从训练数据组中随机抽取的多个小样本训练集,作为新的训练数据,使得每轮模型训练所使用的训练数据均不相同,这样可以起到丰富训练数据的效果,即使在源域训练数据集中的样本数据较小的情况下,也可以实现训练出效果很好的模型。
Description
技术领域
本发明实施例涉及深度学习技术领域,尤其涉及一种深度学习训练方法、装置、设备及可读存储介质。
背景技术
深度学习(deep learning)已经广泛应用于各个领域,已经可以像人类一样识别与认知,甚至解决各类问题的能力在某些方面已超越了人类。
深度学习要求大体量训练数据,还需要有足够量包括标注数据的标签样本作为数据基础进行深度模型的训练。但在某些领域,由于样本采集困难、标签分析代价大等原因,通常标签样本很难获取,标签样本缺乏,小样本问题严重,导致训练出的深度模型效果差。
发明内容
本发明实施例提供一种深度学习训练方法、装置、设备及可读存储介质,用以解决在某些领域,由于样本采集困难、标签分析代价大等原因,通常标签样本很难获取,标签样本缺乏,小样本问题严重,导致训练出的深度模型效果差的问题。
本发明实施例的一个方面是提供一种深度学习训练方法,包括:
对从训练数据组中随机抽取的多个小样本训练集分别进行模型训练,得到各小样本训练集的模型参数;
根据所述各小样本训练集的模型参数,更新源域模型的初始参数;
验证根据更新后的初始参数得到的模型是否符合预置条件;
若不符合,则跳转执行对从训练数据组中随机抽取的多个小样本训练集分别进行模型训练的步骤;
若符合,则将更新后的初始参数确定为所述源域模型的最终参数得到源域模型。
本发明实施例的另一个方面是提供一种深度学习训练装置,包括:
训练模块,用于对从训练数据组中随机抽取的多个小样本训练集分别进行模型训练,得到各小样本训练集的模型参数;
参数更新模块,用于根据所述各小样本训练集的模型参数,更新源域模型的初始参数;
验证模块,用于:
验证根据更新后的初始参数得到的模型是否符合预置条件;
若不符合,则跳转执行对从训练数据组中随机抽取的多个小样本训练集分别进行模型训练的步骤;
若符合,则将更新后的初始参数确定为所述源域模型的最终参数得到源域模型。
本发明实施例的另一个方面是提供一种深度学习训练设备,其特征在于,包括:
存储器,处理器,以及存储在所述存储器上并可在所述处理器上运行的计算机程序,
所述处理器运行所述计算机程序时实现上述所述的方法。
本发明实施例的另一个方面是提供一种计算机可读存储介质,存储有计算机程序,
所述计算机程序被处理器执行时实现上述所述的方法。
本发明实施例提供的深度学习训练方法、装置、设备及可读存储介质,通过将源域数训据集拆分成多个源域数据组,在对源域模型的每一轮模型训练中,均从训练数据组中随机抽取的多个小样本训练集,作为本轮的训练数据进行模型训练,得到本轮的各小样本训练集的模型参数;并根据各小样本训练集的模型参数,更新源域模型的初始参数,根据更新后的初始参数能得到本轮训练后的新的模型;由于每轮模型训练均重新从训练数据组中随机抽取的多个小样本训练集,作为新的训练数据,使得每轮模型训练所使用的训练数据均不相同,这样可以起到丰富训练数据的效果,即使在源域训练数据集中的样本数据较小的情况下,也可以实现训练出效果很好的模型。
附图说明
图1为本发明实施例一提供的深度学习训练方法流程图;
图2为本发明实施例一提供的深度学习训练方法整体流程示意图;
图3为本发明实施例二提供的深度学习训练方法流程图;
图4为本发明实施例二提供的一种二层循环的流程示意图;
图5为本发明实施例三提供的深度学习训练装置的结构示意图;
图6为本发明实施例五提供的深度学习训练设备的结构示意图。
通过上述附图,已示出本发明明确的实施例,后文中将有更详细的描述。这些附图和文字描述并不是为了通过任何方式限制本发明实施例构思的范围,而是通过参考特定实施例为本领域技术人员说明本发明的概念。
具体实施方式
这里将详细地对示例性实施例进行说明,其示例表示在附图中。下面的描述涉及附图时,除非另有表示,不同附图中的相同数字表示相同或相似的要素。以下示例性实施例中所描述的实施方式并不代表与本发明实施例相一致的所有实施方式。相反,它们仅是与如所附权利要求书中所详述的、本发明实施例的一些方面相一致的装置和方法的例子。
首先对本发明实施例所涉及的名词进行解释:
迁移学习:给的源域数据和源域任务,目标域数据和目标域任务,迁移学习就是研究如何利用源域数据和源域任务来帮助改善目标域数据的学习任务效果。一般源域数据与目标域数据不同,且源域任务和目标域任务不同。
此外,术语“第一”、“第二”等仅用于描述目的,而不能理解为指示或暗示相对重要性或者隐含指明所指示的技术特征的数量。在以下各实施例的描述中,“多个”的含义是两个以上,除非另有明确具体的限定。
下面这几个具体的实施例可以相互结合,对于相同或相似的概念或过程可能在某些实施例中不再赘述。下面将结合附图,对本发明的实施例进行描述。
实施例一
图1为本发明实施例一提供的深度学习训练方法流程图;图2为本发明实施例一提供的深度学习训练方法整体流程示意图。本发明实施例针对在某些领域,由于样本采集困难、标签分析代价大等原因,通常标签样本很难获取,标签样本缺乏,小样本问题严重,导致训练出的深度模型效果差的问题,提供了深度学习训练方法。
如图1和图2所示,该方法具体步骤如下:
步骤S101、对从训练数据组中随机抽取的多个小样本训练集分别进行模型训练,得到各小样本训练集的模型参数。
首先获取源域训练数据集,并对源域训练数据集进行数据预处理,得到预处理后的源域训练数据集。为了适合带有学习策略的深度学习训练,对预处理以后的源域训练数据集进行分组处理,将源域数训据集拆分成多个源域数据组。
本实施例中,在对源域模型的每一轮迭代训练中,均从训练数据组中随机抽取的多个小样本训练集,作为本轮的训练数据进行模型训练,得到本轮的各小样本训练集的模型参数。由于每轮模型训练均重新从训练数据组中随机抽取的多个小样本训练集,作为新的训练数据,使得每轮模型训练所使用的训练数据均不相同。
步骤S102、根据各小样本训练集的模型参数,更新源域模型的初始参数。
在得到各小样本训练集的模型参数之后,根据预设的学习策略,生成本次迭代训练的最终模型参数。
其中,预设的学习策略具体包括如何根据本次更新前源域模型的初始参数,以及各小样本训练集的模型参数,更新源域模型的初始参数。
预设的学习策略采用了小样本多组联合训练方法,采用多个小样本训练集分别独立完成对源域模型的训练,得到多个训练后的源域模型,从而得到各小样本训练集的模型参数;综合得到的多组模型参数,更新源域模型的初始参数,使得训练后的源域模型具有更好的泛化能力。其中,模型参数是指源域模型中需要训练的一组参数。例如,模型参数可以模型中的权重参数等等。
步骤S103、验证根据更新后的初始参数得到的模型是否符合预置条件。
在更新更新源域模型的初始参数之后,通过验证根据更新后的初始参数得到的模型是否符合预置条件,来验证是否可以结束模型训练。
本实施例中,验证根据更新后的初始参数得到的模型是否符合预置条件至少包括:验证根据更新后的初始参数得到的模型是否收敛。只有在根据更新后的初始参数得到的模型收敛时,才有可能将更新后的初始参数作为最终参数。如果根据更新后的初始参数得到的模型不收敛,则不会将更新后的初始参数作为最终参数。
若该步骤中验证结果为不符合预置条件,则开启新一轮的模型训练,跳转执行步骤S101,对从训练数据组中随机抽取的多个小样本训练集分别进行模型训练。
步骤S104、若符合,则将更新后的初始参数确定为源域模型的最终参数得到源域模型。
若步骤S103中验证结果为符合预置条件,则确定根据更新后的初始参数得到的模型能够满足需求,执行步骤S104,将更新后的初始参数确定为源域模型的最终参数得到源域模型,模型训练结束。
本发明实施例通过将源域数训据集拆分成多个源域数据组,在对源域模型的每一轮模型训练中,均从训练数据组中随机抽取的多个小样本训练集,作为本轮的训练数据进行模型训练,得到本轮的各小样本训练集的模型参数;并根据各小样本训练集的模型参数,更新源域模型的初始参数,根据更新后的初始参数能得到本轮训练后的新的模型;由于每轮模型训练均重新从训练数据组中随机抽取的多个小样本训练集,作为新的训练数据,使得每轮模型训练所使用的训练数据均不相同,这样可以起到丰富训练数据的效果,即使在源域训练数据集中的样本数据较小的情况下,也可以实现训练出效果很好的模型。
实施例二
图3为本发明实施例二提供的深度学习训练方法流程图。在上述实施例一的基础上,本实施例中,从训练数据组中随机抽取的多个小样本训练集,包括:对源域训练数据集进行分组处理,得到多个训练数据组;重复从多个训练数据组中分别抽取预设数量的训练数据的过程,得到多个小样本训练集。如图3所示,该方法具体步骤如下:
步骤S201、对源域训练数据集进行分组处理,得到多个训练数据组。
首先获取源域训练数据集,并对源域训练数据集进行数据预处理,得到预处理后的源域训练数据集。为了适合带有学习策略的深度学习训练,对预处理以后的源域训练数据集进行分组处理,将源域数训据集拆分成多个源域数据组。
通常,深度学习可以用于解决两类问题:一类为分类问题,另一类为回归分析问题。用于解决分类问题的深度学习任务的类型为分类学习任务,用于解决回归分析问题的深度学习任务的类型为回归学习任务。
可选的,对预处理以后的源域训练数据集进行分组处理的分组原则取决于源域学习任务。若源域模型对应的学习任务为分类学习任务,则将源域训练数据集分成的训练数据组的数量等于学习任务的类别数量。若源域模型对应的学习任务为回归学习任务,则将源域训练数据集分成的训练数据组的数量等于学习任务的可变参数的数量。。
可选的,在对预处理以后的源域训练数据集进行分组处理时,每个源域数据组中的样本数量的差值小于预设差值,以使每个源域数据组中的样本数量要尽量相同。其中,预设差值可以由技术人员根据实际需要进行设定,本实施例此处不做具体限定。
可选的,数据预处理的具体处理内容可以依据实际任务需要进行设定;或者可以采用现有技术中的深度学习方法中的数据预处理的方法实现,例如,图像亮度、饱和度、对比度变化等图像增量处理,归一化等标准化处理,等等,本实施例此处不做具体限定。
步骤S202、重复从多个训练数据组中分别抽取预设数量的训练数据的过程,得到多个小样本训练集。
本实施例中,用k表示每轮训练中得到的小样本训练集的数量。从每个训练数据组中随机抽取预设数量个训练样本,组合在一起,并随机排列,得到一个小样本训练集;重复操作k次,生成k个小样本训练集。也就是,从多个训练数据组中,各随机抽取等量的训练数据,组成一个小样本训练集;重复操作k次,生成k个小样本训练集。
可选的,k的取值可以为大于等于5且小于等于10,可以对训练数据组进行合理的抽样得到小样本训练集,以使对训练模型的训练效果更优。
其中,预设数量可以由技术人员根据训练数据组中样本总个数以及实际需要进行设定,本实施例此处不做具体限定。
可选的,可以设置预设数量小于训练数据组中样本总数的五分之一;也即是,小样本训练集中每类样本个数小于对应训练数据组中样本总数的五分之一。
步骤S203、对多个小样本训练集分别进行模型训练,得到各小样本训练集的模型参数。
具体的,分别采用每个小样本训练集,对源域模型进行预设循环次数的训练,得到该小样本训练集对应的模型参数;用k个小样本训练集对训练模型进行训练之后,得到k个小样本训练集的模型参数。
本实施例中,用每个小样本训练集对训练模型进行训练时,可以预先设定训练时的循环次数,当循环次数达到预设循环次数时,结束该小样本训练集对训练模型的训练。其中,预设循环次数可以由技术人员根据实际需要进行设定,本实施例此处不做具体限定。
可选的,训练循环的次数的取值可以为大于等于10,且小于等于20。
可选的,在采用每个小样本训练集对训练模型进行训练时,可以对训练模型进行批量训练,得到一组中间模型参数,以提高训练效率。具体的,可以预先设置训练批量的大小,每次从小样本训练集中抽取与训练批量等量的训练样本,对训练模型进行批量训练。通过多次批量训练,遍历小样本训练集中的所有训练样本,完成对训练模型的训练。其中,训练批量的大小可以由技术人员根据实际需要进行设定,本实施例此处不做局限定。
另外,采用小样本训练集对训练模型进行批量训练的方法可以采用现有技术中进行批量训练的方法实现,本实施例此处不再赘述。
步骤S204、根据各小样本训练集的模型参数,更新源域模型的初始参数。
在得到各小样本训练集的模型参数之后,根据预设的学习策略,更新源域模型的初始参数。
为了使训练模型具有更好的泛化能力,学习策略采取综合各小样本训练集的训练成果,根据各小样本训练集的模型参数,更新源域模型的初始参数,可以采用如下方式实现:
采用如下公式一,计算得到各小样本训练集的模型参数的平均值:
其中,θi表示第i个小样本训练集的模型参数。
进一步地,根据各小样本训练集的模型参数,采用如下公式二,更新源域模型的初始参数:
其中,0f表示本次更新后源域模型的初始参数,00表示本次更新前源域模型的初始参数,表示各小样本训练集的模型参数的平均值,a表示衰变系数。
衰变系数可以采用如下公式三计算得到:
其中,α0为衰变系数的预设初始值,N为预设的源域模型初始参数更新的总次数,j为当前源域模型的初始参数的更新次数,j为正整数。
在更新更新源域模型的初始参数之后,通过验证根据更新后的初始参数得到的模型是否符合预置条件,来验证是否可以结束模型训练。具体可以通过以下步骤S205-S208来验证根据更新后的初始参数得到的模型是否符合预置条件。
步骤S205、采用验证集对根据更新后的初始参数得到的模型进行模型预测,得到预测结果。
步骤S206、将预测结果与验证集对应的结果进行比较,确定根据更新后的初始参数得到的模型的准确率。
其中,验证集包括多个样本,以及每个样本对应的结果。通过将预测结果与验证集对应的结果进行比较,可以计算出预测结果相对于验证集对应的结果的正确率,得到根据更新后的初始参数得到的模型的准确率。
在得到根据更新后的初始参数得到的模型的准确率之后,比较根据更新后的初始参数得到的模型的准确率和预设准确率阈值的大小。
步骤S207、如果根据更新后的初始参数得到的模型的准确率小于准确率阈值,则确定根据更新后的初始参数得到的模型不符合预置条件。
如果根据更新后的初始参数得到的模型的准确率小于准确率阈值,则确定根据更新后的初始参数得到的模型不符合预置条件。这时,根据更新后的初始参数得到的模型不能满足需求,需继续执行步骤S202,开启新一轮的模型训练。
其中,准确率阈值可以由技术人员根据实际应用场景和经验进行设定,本实施例此处不做具体限定。
步骤S208、若准确率大于或者等于准确率阈值,则比较根据更新后的初始参数得到的模型的准确率与根据更新前的初始参数得到的模型的准确率的大小。
若准确率大于或者等于准确率阈值,则说明根据更新后的初始参数得到的模型的准确率能够满足需求,此时,可以通过比较根据更新后的初始参数得到的模型的准确率与根据更新前的初始参数得到的模型的准确率的大小,来确定本轮的模型训练使得源域模型的初始参数更优。
步骤S209、若根据更新后的初始参数得到的模型的准确率大于或者等于根据更新前的初始参数得到的模型的准确率,继续执行步骤S202。
若根据更新后的初始参数得到的模型的准确率大于或者等于根据更新前的初始参数得到的模型的准确率,说明本轮的模型训练使得源域模型的初始参数更优,那么继续执行步骤S202,启动下一轮的模型训练,以继续优化源域模型的初始参数。
其中,准确率阈值可以由技术人员根据实际应用场景和经验进行设定,本实施例此处不做具体限定。
步骤S210、若根据更新后的初始参数得到的模型的准确率小于根据更新前的初始参数得到的模型的准确率,则确定更新前的初始参数确定为源域模型的最终参数得到源域模型。
若根据更新后的初始参数得到的模型的准确率小于根据更新前的初始参数得到的模型的准确率,说明本轮的模型训练并没有使得源域模型的初始参数更差,那么将不再进行下一轮的模型训练。此时,将效果更优的更新前的初始参数确定为源域模型的最终参数得到源域模型。
本实施例的一种可行实施方式中,可以采用两层循环实现上述采用多个训练数据组,对训练模型进行带有学习策略的深度学习训练,得到训练后的模型的过程。
具体的,如图4所示,带有学习策略的模型训练过程包括内循环和外循环两部分:内循环的循环次数可以设置为k,内循环负责生成k个小样本训练集,并基于k个小样本训练集对训练模型进行批量的模型训练,得到k个小样本训练集型参数,将得到的k个小样本训练集的模型参数输出到外循环的生成学习策略模块。外循环的迭代次数可以设置为N,外循环的生成学习策略模块负责收集每次内循环训练后生成模型参数,然后基于预设的学习策略更新源域模型的初始参数,验证根据更新后的初始参数得到的模型是否符合预置条件;若不符合,则将更新后的源域模型的初始参数作为下一次内循环的初始参数,启动下一次内循环;直到外循环达到设置的迭代次数N,或者根据更新后的初始参数得到的模型符合预置条件为止,训练结束。
本实施例的另一实施方式中,在确定源域模型的最终参数得到源域模型之后,还可以将源域模型迁移到目标域,作为目标域模型。将源域模型的最终参数作为目标域模型的初始参数,使得目标域模型具有更优的初始参数。然后,获取预处理后的目标域训练数据集,并将目标域训练数据集进行分组处理,得到目标域的多个训练数据组。基于目标域的多个训练数据组,采用上述任一实施例提供的深度学习训练方法进行目标域模型的训练,确定目标域模型的最终参数得到目标域模型。
这种模型迁移适用于源域学习任务与目标域学习任务具有相同的类型的场景。例如,源域学习任务为3分类任务,目标域学习任务也为3分类任务,那么,可以实现整个模型的迁移,将源域的训练模型迁移到目标域。
本发明实施例通过将源域数训据集拆分成多个源域数据组,在对源域模型的每一轮模型训练中,均从训练数据组中随机抽取的多个小样本训练集,作为本轮的训练数据进行模型训练,得到本轮的各小样本训练集的模型参数;并根据各小样本训练集的模型参数,更新源域模型的初始参数,根据更新后的初始参数能得到本轮训练后的新的模型;由于每轮模型训练均重新从训练数据组中随机抽取的多个小样本训练集,作为新的训练数据,使得每轮模型训练所使用的训练数据均不相同,这样可以起到丰富训练数据的效果,即使在源域训练数据集中的样本数据较小的情况下,也可以实现训练出效果很好的模型。
实施例三
图5为本发明实施例三提供的深度学习训练装置的结构示意图。本发明实施例提供的深度学习训练装置可以执行深度学习训练方法实施例提供的处理流程。如图5所示,该深度学习训练装置30包括:训练模块301,参数更新模块302和验证模块303。
具体地,训练模块301用于对从训练数据组中随机抽取的多个小样本训练集分别进行模型训练,得到各小样本训练集的模型参数。
参数更新模块302用于根据各小样本训练集的模型参数,更新源域模型的初始参数。
验证模块303用于:
验证根据更新后的初始参数得到的模型是否符合预置条件;若不符合,则跳转执行对从训练数据组中随机抽取的多个小样本训练集分别进行模型训练的步骤;若符合,则将更新后的初始参数确定为源域模型的最终参数得到源域模型。
本发明实施例提供的装置可以具体用于执行上述实施例一所提供的方法实施例,具体功能此处不再赘述。
本发明实施例通过将源域数训据集拆分成多个源域数据组,在对源域模型的每一轮模型训练中,均从训练数据组中随机抽取的多个小样本训练集,作为本轮的训练数据进行模型训练,得到本轮的各小样本训练集的模型参数;并根据各小样本训练集的模型参数,更新源域模型的初始参数,根据更新后的初始参数能得到本轮训练后的新的模型;由于每轮模型训练均重新从训练数据组中随机抽取的多个小样本训练集,作为新的训练数据,使得每轮模型训练所使用的训练数据均不相同,这样可以起到丰富训练数据的效果,即使在源域训练数据集中的样本数据较小的情况下,也可以实现训练出效果很好的模型。
实施例四
在上述实施例三的基础上,本实施例中,训练模块还用于:
对源域训练数据集进行分组处理,得到多个训练数据组;重复从多个训练数据组中分别抽取预设数量的训练数据的过程,得到多个小样本训练集。
可选的,训练模块还用于:
若源域模型对应的学习任务为分类学习任务,则将源域训练数据集分成的训练数据组的数量等于学习任务的类别数量;若源域模型对应的学习任务为回归学习任务,则将源域训练数据集分成的训练数据组的数量等于学习任务的可变参数的数量。
可选的,参数更新模块还用于:
根据各小样本训练集的模型参数,采用如下公式,更新源域模型的初始参数:
其中,0f表示本次更新后源域模型的初始参数,00表示本次更新前源域模型的初始参数,表示各小样本训练集的模型参数的平均值,a表示衰变系数。
衰变系数为:
其中,α0为衰变系数的预设初始值,N为预设的源域模型初始参数更新的总次数,j为当前源域模型的初始参数的更新次数,j为正整数。
可选的,验证模块还用于:
采用验证集对根据更新后的初始参数得到的模型进行模型预测,得到预测结果;将预测结果与验证集对应的结果进行比较,确定根据更新后的初始参数得到的模型的准确率;若准确率小于准确率阈值,则确定根据更新后的初始参数得到的模型不符合预置条件;若准确率大于或者等于准确率阈值,则确定根据更新后的初始参数得到的模型符合预置条件。
可选的,验证模块还用于:
若准确率大于或者等于准确率阈值,则比较根据更新后的初始参数得到的模型的准确率与根据更新前的初始参数得到的模型的准确率的大小;若根据更新后的初始参数得到的模型的准确率大于或者等于根据更新前的初始参数得到的模型的准确率,则跳转执行对从训练数据组中随机抽取的多个小样本训练集分别进行模型训练的步骤;若根据更新后的初始参数得到的模型的准确率小于根据更新前的初始参数得到的模型的准确率,则确定更新前的初始参数确定为源域模型的最终参数得到源域模型。
本发明实施例提供的装置可以具体用于执行上述实施例二所提供的方法实施例,具体功能此处不再赘述。
本发明实施例通过将源域数训据集拆分成多个源域数据组,在对源域模型的每一轮模型训练中,均从训练数据组中随机抽取的多个小样本训练集,作为本轮的训练数据进行模型训练,得到本轮的各小样本训练集的模型参数;并根据各小样本训练集的模型参数,更新源域模型的初始参数,根据更新后的初始参数能得到本轮训练后的新的模型;由于每轮模型训练均重新从训练数据组中随机抽取的多个小样本训练集,作为新的训练数据,使得每轮模型训练所使用的训练数据均不相同,这样可以起到丰富训练数据的效果,即使在源域训练数据集中的样本数据较小的情况下,也可以实现训练出效果很好的模型。
实施例五
图6为本发明实施例五提供的深度学习训练设备的结构示意图。如图6所示,该深度学习训练设备60包括:处理器601,存储器602,以及存储在存储器602上并可由处理器601执行的计算机程序。
处理器601在执行存储在存储器602上的计算机程序时实现上述任一方法实施例提供的深度学习训练方法。
本发明实施例通过将源域数训据集拆分成多个源域数据组,在对源域模型的每一轮模型训练中,均从训练数据组中随机抽取的多个小样本训练集,作为本轮的训练数据进行模型训练,得到本轮的各小样本训练集的模型参数;并根据各小样本训练集的模型参数,更新源域模型的初始参数,根据更新后的初始参数能得到本轮训练后的新的模型;由于每轮模型训练均重新从训练数据组中随机抽取的多个小样本训练集,作为新的训练数据,使得每轮模型训练所使用的训练数据均不相同,这样可以起到丰富训练数据的效果,即使在源域训练数据集中的样本数据较小的情况下,也可以实现训练出效果很好的模型。
另外,本发明实施例还提供一种计算机可读存储介质,存储有计算机程序,计算机程序被处理器执行时实现上述任一方法实施例提供的深度学习训练方法。
在本发明所提供的几个实施例中,应该理解到,所揭露的装置和方法,可以通过其它的方式实现。例如,以上所描述的装置实施例仅仅是示意性的,例如,所述单元的划分,仅仅为一种逻辑功能划分,实际实现时可以有另外的划分方式,例如多个单元或组件可以结合或者可以集成到另一个系统,或一些特征可以忽略,或不执行。另一点,所显示或讨论的相互之间的耦合或直接耦合或通信连接可以是通过一些接口,装置或单元的间接耦合或通信连接,可以是电性,机械或其它的形式。
所述作为分离部件说明的单元可以是或者也可以不是物理上分开的,作为单元显示的部件可以是或者也可以不是物理单元,即可以位于一个地方,或者也可以分布到多个网络单元上。可以根据实际的需要选择其中的部分或者全部单元来实现本实施例方案的目的。
另外,在本发明各个实施例中的各功能单元可以集成在一个处理单元中,也可以是各个单元单独物理存在,也可以两个或两个以上单元集成在一个单元中。上述集成的单元既可以采用硬件的形式实现,也可以采用硬件加软件功能单元的形式实现。
上述以软件功能单元的形式实现的集成的单元,可以存储在一个计算机可读取存储介质中。上述软件功能单元存储在一个存储介质中,包括若干指令用以使得一台计算机设备(可以是个人计算机,服务器,或者网络设备等)或处理器(processor)执行本发明各个实施例所述方法的部分步骤。而前述的存储介质包括:U盘、移动硬盘、只读存储器(Read-Only Memory,ROM)、随机存取存储器(Random Access Memory,RAM)、磁碟或者光盘等各种可以存储程序代码的介质。
本领域技术人员可以清楚地了解到,为描述的方便和简洁,仅以上述各功能模块的划分进行举例说明,实际应用中,可以根据需要而将上述功能分配由不同的功能模块完成,即将装置的内部结构划分成不同的功能模块,以完成以上描述的全部或者部分功能。上述描述的装置的具体工作过程,可以参考前述方法实施例中的对应过程,在此不再赘述。
本领域技术人员在考虑说明书及实践这里公开的发明后,将容易想到本发明的其它实施方案。本发明旨在涵盖本发明的任何变型、用途或者适应性变化,这些变型、用途或者适应性变化遵循本发明的一般性原理并包括本发明未公开的本技术领域中的公知常识或惯用技术手段。说明书和实施例仅被视为示例性的,本发明的真正范围和精神由下面的权利要求书指出。
应当理解的是,本发明并不局限于上面已经描述并在附图中示出的精确结构,并且可以在不脱离其范围进行各种修改和改变。本发明的范围仅由所附的权利要求书来限制。
Claims (10)
1.一种深度学习模型训练方法,其特征在于,包括:
对从训练数据组中随机抽取的多个小样本训练集分别进行模型训练,得到各小样本训练集的模型参数;
根据所述各小样本训练集的模型参数,更新源域模型的初始参数;
验证根据更新后的初始参数得到的模型是否符合预置条件;
若不符合,则跳转执行对从训练数据组中随机抽取的多个小样本训练集分别进行模型训练的步骤;
若符合,则将更新后的初始参数确定为所述源域模型的最终参数得到源域模型。
2.根据权利要求1所述的方法,其特征在于,所述对从训练数据组中随机抽取的多个小样本训练集分别进行模型训练,包括:
对源域训练数据集进行分组处理,得到多个训练数据组;
重复从所述多个训练数据组中分别抽取预设数量的训练数据的过程,得到多个小样本训练集。
3.根据权利要求2所述的方法,其特征在于,所述对源域训练数据集进行分组处理,得到多个训练数据组,包括:
若所述源域模型对应的学习任务为分类学习任务,则将所述源域训练数据集分成的训练数据组的数量等于所述学习任务的类别数量;
若所述源域模型对应的学习任务为回归学习任务,则将所述源域训练数据集分成的训练数据组的数量等于所述学习任务的可变参数的数量。
4.根据权利要求1所述的方法,其特征在于,所述根据所述各小样本训练集的模型参数,更新源域模型的初始参数,包括:
根据所述各小样本训练集的模型参数,采用如下公式,更新源域模型的初始参数:
其中,θf表示本次更新后源域模型的初始参数,θ0表示本次更新前源域模型的初始参数,表示所述各小样本训练集的模型参数的平均值,α表示衰变系数。
5.根据权利要求4所述的方法,其特征在于,所述衰变系数为:
其中,α0为衰变系数的预设初始值,N为预设的源域模型初始参数更新的总次数,j为当前源域模型的初始参数的更新次数,j为正整数。
6.根据权利要求1所述的方法,其特征在于,所述验证根据更新后的初始参数得到的模型是否符合预置条件,包括:
采用验证集对根据更新后的初始参数得到的模型进行模型预测,得到预测结果;
将所述预测结果与所述验证集对应的结果进行比较,确定根据所述更新后的初始参数得到的模型的准确率;
若所述准确率小于准确率阈值,则确定根据更新后的初始参数得到的模型不符合预置条件;
若所述准确率大于或者等于所述准确率阈值,则确定根据更新后的初始参数得到的模型符合预置条件。
7.根据权利要求6所述的方法,其特征在于,所述确定根据更新后的初始参数得到的模型符合预置条件之前,包括:
若所述准确率大于或者等于所述准确率阈值,则比较根据更新后的初始参数得到的模型的准确率与根据更新前的初始参数得到的模型的准确率的大小;
若根据更新后的初始参数得到的模型的准确率大于或者等于根据更新前的初始参数得到的模型的准确率,则跳转执行对从训练数据组中随机抽取的多个小样本训练集分别进行模型训练的步骤;
若根据更新后的初始参数得到的模型的准确率小于根据更新前的初始参数得到的模型的准确率,则确定所述更新前的初始参数确定为所述源域模型的最终参数得到源域模型。
8.一种深度学习训练装置,其特征在于,包括:
训练模块,用于对从训练数据组中随机抽取的多个小样本训练集分别进行模型训练,得到各小样本训练集的模型参数;
参数更新模块,用于根据所述各小样本训练集的模型参数,更新源域模型的初始参数;
验证模块,用于:
验证根据更新后的初始参数得到的模型是否符合预置条件;
若不符合,则跳转执行对从训练数据组中随机抽取的多个小样本训练集分别进行模型训练的步骤;
若符合,则将更新后的初始参数确定为所述源域模型的最终参数得到源域模型。
9.一种深度学习训练设备,其特征在于,包括:
存储器,处理器,以及存储在所述存储器上并可在所述处理器上运行的计算机程序,
所述处理器运行所述计算机程序时实现如权利要求1-8中任一项所述的方法。
10.一种计算机可读存储介质,其特征在于,存储有计算机程序,
所述计算机程序被处理器执行时实现如权利要求1-8中任一项所述的方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN201811369102.5A CN109583594B (zh) | 2018-11-16 | 2018-11-16 | 深度学习训练方法、装置、设备及可读存储介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN201811369102.5A CN109583594B (zh) | 2018-11-16 | 2018-11-16 | 深度学习训练方法、装置、设备及可读存储介质 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN109583594A true CN109583594A (zh) | 2019-04-05 |
CN109583594B CN109583594B (zh) | 2021-03-30 |
Family
ID=65923032
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN201811369102.5A Active CN109583594B (zh) | 2018-11-16 | 2018-11-16 | 深度学习训练方法、装置、设备及可读存储介质 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN109583594B (zh) |
Cited By (7)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN110188829A (zh) * | 2019-05-31 | 2019-08-30 | 北京市商汤科技开发有限公司 | 神经网络的训练方法、目标识别的方法及相关产品 |
CN110751183A (zh) * | 2019-09-24 | 2020-02-04 | 东软集团股份有限公司 | 影像数据分类模型的生成方法、影像数据分类方法及装置 |
CN111310905A (zh) * | 2020-05-11 | 2020-06-19 | 创新奇智(南京)科技有限公司 | 神经网络模型训练方法、装置及暖通系统能效优化方法 |
CN113254435A (zh) * | 2021-07-15 | 2021-08-13 | 北京电信易通信息技术股份有限公司 | 一种数据增强方法及系统 |
WO2022027806A1 (zh) * | 2020-08-07 | 2022-02-10 | 深圳先进技术研究院 | 深度学习模型的参数重用方法、装置、终端及存储介质 |
CN114127698A (zh) * | 2019-07-18 | 2022-03-01 | 日本电信电话株式会社 | 学习装置、检测系统、学习方法以及学习程序 |
CN114898178A (zh) * | 2022-05-10 | 2022-08-12 | 支付宝(杭州)信息技术有限公司 | 图像识别神经网络模型的训练方法及系统 |
Citations (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN107368892A (zh) * | 2017-06-07 | 2017-11-21 | 无锡小天鹅股份有限公司 | 基于机器学习的模型训练方法和装置 |
CN107704926A (zh) * | 2017-11-23 | 2018-02-16 | 清华大学 | 一种大数据跨领域分析的深度迁移学习方法 |
CN107943911A (zh) * | 2017-11-20 | 2018-04-20 | 北京大学深圳研究院 | 数据抽取方法、装置、计算机设备及可读存储介质 |
US20180292220A1 (en) * | 2017-04-05 | 2018-10-11 | International Business Machines Corporation | Deep learning allergen mapping |
CN108764486A (zh) * | 2018-05-23 | 2018-11-06 | 哈尔滨工业大学 | 一种基于集成学习的特征选择方法及装置 |
-
2018
- 2018-11-16 CN CN201811369102.5A patent/CN109583594B/zh active Active
Patent Citations (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US20180292220A1 (en) * | 2017-04-05 | 2018-10-11 | International Business Machines Corporation | Deep learning allergen mapping |
CN107368892A (zh) * | 2017-06-07 | 2017-11-21 | 无锡小天鹅股份有限公司 | 基于机器学习的模型训练方法和装置 |
CN107943911A (zh) * | 2017-11-20 | 2018-04-20 | 北京大学深圳研究院 | 数据抽取方法、装置、计算机设备及可读存储介质 |
CN107704926A (zh) * | 2017-11-23 | 2018-02-16 | 清华大学 | 一种大数据跨领域分析的深度迁移学习方法 |
CN108764486A (zh) * | 2018-05-23 | 2018-11-06 | 哈尔滨工业大学 | 一种基于集成学习的特征选择方法及装置 |
Non-Patent Citations (1)
Title |
---|
张雁: "基于机器学习的遥感图像分类研究", 《中国博士学位论文全文数据库》 * |
Cited By (10)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN110188829A (zh) * | 2019-05-31 | 2019-08-30 | 北京市商汤科技开发有限公司 | 神经网络的训练方法、目标识别的方法及相关产品 |
CN110188829B (zh) * | 2019-05-31 | 2022-01-28 | 北京市商汤科技开发有限公司 | 神经网络的训练方法、目标识别的方法及相关产品 |
CN114127698A (zh) * | 2019-07-18 | 2022-03-01 | 日本电信电话株式会社 | 学习装置、检测系统、学习方法以及学习程序 |
CN110751183A (zh) * | 2019-09-24 | 2020-02-04 | 东软集团股份有限公司 | 影像数据分类模型的生成方法、影像数据分类方法及装置 |
CN111310905A (zh) * | 2020-05-11 | 2020-06-19 | 创新奇智(南京)科技有限公司 | 神经网络模型训练方法、装置及暖通系统能效优化方法 |
CN111310905B (zh) * | 2020-05-11 | 2020-08-18 | 创新奇智(南京)科技有限公司 | 神经网络模型训练方法、装置及暖通系统能效优化方法 |
WO2022027806A1 (zh) * | 2020-08-07 | 2022-02-10 | 深圳先进技术研究院 | 深度学习模型的参数重用方法、装置、终端及存储介质 |
CN113254435A (zh) * | 2021-07-15 | 2021-08-13 | 北京电信易通信息技术股份有限公司 | 一种数据增强方法及系统 |
CN113254435B (zh) * | 2021-07-15 | 2021-10-29 | 北京电信易通信息技术股份有限公司 | 一种数据增强方法及系统 |
CN114898178A (zh) * | 2022-05-10 | 2022-08-12 | 支付宝(杭州)信息技术有限公司 | 图像识别神经网络模型的训练方法及系统 |
Also Published As
Publication number | Publication date |
---|---|
CN109583594B (zh) | 2021-03-30 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN109583594A (zh) | 深度学习训练方法、装置、设备及可读存储介质 | |
CN111553480B (zh) | 图像数据处理方法、装置、计算机可读介质及电子设备 | |
Gu et al. | A new deep learning method based on AlexNet model and SSD model for tennis ball recognition | |
CN109325516B (zh) | 一种面向图像分类的集成学习方法及装置 | |
CN106897714A (zh) | 一种基于卷积神经网络的视频动作检测方法 | |
CN104933428B (zh) | 一种基于张量描述的人脸识别方法及装置 | |
CN110135582B (zh) | 神经网络训练、图像处理方法及装置、存储介质 | |
CN111739115B (zh) | 基于循环一致性的无监督人体姿态迁移方法、系统及装置 | |
CN105787557A (zh) | 一种计算机智能识别的深层神经网络结构设计方法 | |
CN110110861A (zh) | 确定模型超参数及模型训练的方法和装置、存储介质 | |
CN105989849A (zh) | 一种语音增强方法、语音识别方法、聚类方法及装置 | |
US11907821B2 (en) | Population-based training of machine learning models | |
CN110363239A (zh) | 一种面向多模态数据的小样本机器学习方法、系统和介质 | |
CN108805149A (zh) | 一种视觉同步定位与地图构建的回环检测方法及装置 | |
CN108647571A (zh) | 视频动作分类模型训练方法、装置及视频动作分类方法 | |
CN109886343A (zh) | 图像分类方法及装置、设备、存储介质 | |
CN109514553A (zh) | 一种机器人移动控制的方法、系统及设备 | |
CN111598213A (zh) | 网络训练方法、数据识别方法、装置、设备和介质 | |
CN114881225A (zh) | 输变电巡检模型网络结构搜索方法、系统及存储介质 | |
CN108549857A (zh) | 事件检测模型训练方法、装置及事件检测方法 | |
CN113822434A (zh) | 用于知识蒸馏的模型选择学习 | |
CN110210419A (zh) | 高分辨率遥感图像的场景识别系统及模型生成方法 | |
CN109471951A (zh) | 基于神经网络的歌词生成方法、装置、设备和存储介质 | |
CN117744759A (zh) | 文本信息的识别方法、装置、存储介质及电子设备 | |
CN105117330B (zh) | Cnn代码测试方法及装置 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |