CN117010480A - 模型训练方法、装置、设备、存储介质及程序产品 - Google Patents
模型训练方法、装置、设备、存储介质及程序产品 Download PDFInfo
- Publication number
- CN117010480A CN117010480A CN202211017332.1A CN202211017332A CN117010480A CN 117010480 A CN117010480 A CN 117010480A CN 202211017332 A CN202211017332 A CN 202211017332A CN 117010480 A CN117010480 A CN 117010480A
- Authority
- CN
- China
- Prior art keywords
- model
- sample data
- target
- training
- loss value
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000012549 training Methods 0.000 title claims abstract description 208
- 238000000034 method Methods 0.000 title claims abstract description 109
- 238000003860 storage Methods 0.000 title claims abstract description 34
- 238000012545 processing Methods 0.000 claims description 17
- 230000004927 fusion Effects 0.000 claims description 11
- 238000004458 analytical method Methods 0.000 claims description 10
- 238000009826 distribution Methods 0.000 claims description 10
- 239000006185 dispersion Substances 0.000 claims description 8
- 230000004044 response Effects 0.000 claims description 4
- 238000010801 machine learning Methods 0.000 abstract description 8
- 230000007704 transition Effects 0.000 abstract description 6
- 230000008569 process Effects 0.000 description 52
- 238000005516 engineering process Methods 0.000 description 24
- 239000013598 vector Substances 0.000 description 15
- 230000000750 progressive effect Effects 0.000 description 14
- 230000006870 function Effects 0.000 description 12
- 238000013507 mapping Methods 0.000 description 12
- 238000013473 artificial intelligence Methods 0.000 description 10
- 238000010586 diagram Methods 0.000 description 9
- 230000006978 adaptation Effects 0.000 description 8
- 238000004364 calculation method Methods 0.000 description 5
- 238000007405 data analysis Methods 0.000 description 5
- 238000013508 migration Methods 0.000 description 4
- 230000005012 migration Effects 0.000 description 4
- 238000012546 transfer Methods 0.000 description 4
- 230000009466 transformation Effects 0.000 description 4
- 230000003044 adaptive effect Effects 0.000 description 3
- 238000004422 calculation algorithm Methods 0.000 description 3
- 238000004891 communication Methods 0.000 description 3
- 238000004590 computer program Methods 0.000 description 3
- 230000003287 optical effect Effects 0.000 description 3
- 230000011218 segmentation Effects 0.000 description 3
- 239000007787 solid Substances 0.000 description 3
- 241001465754 Metazoa Species 0.000 description 2
- 238000013135 deep learning Methods 0.000 description 2
- 238000000605 extraction Methods 0.000 description 2
- 230000010354 integration Effects 0.000 description 2
- 210000001525 retina Anatomy 0.000 description 2
- 238000013459 approach Methods 0.000 description 1
- 238000013528 artificial neural network Methods 0.000 description 1
- 238000013475 authorization Methods 0.000 description 1
- 230000006399 behavior Effects 0.000 description 1
- 230000009286 beneficial effect Effects 0.000 description 1
- 230000033228 biological regulation Effects 0.000 description 1
- 238000012937 correction Methods 0.000 description 1
- 238000005520 cutting process Methods 0.000 description 1
- 238000013461 design Methods 0.000 description 1
- 238000011161 development Methods 0.000 description 1
- 230000018109 developmental process Effects 0.000 description 1
- 230000000694 effects Effects 0.000 description 1
- 239000000284 extract Substances 0.000 description 1
- 230000006698 induction Effects 0.000 description 1
- 230000003993 interaction Effects 0.000 description 1
- 238000012804 iterative process Methods 0.000 description 1
- 230000003902 lesion Effects 0.000 description 1
- 238000007726 management method Methods 0.000 description 1
- 238000004519 manufacturing process Methods 0.000 description 1
- 238000003058 natural language processing Methods 0.000 description 1
- 230000006855 networking Effects 0.000 description 1
- 208000014081 polyp of colon Diseases 0.000 description 1
- 238000011084 recovery Methods 0.000 description 1
- 230000002787 reinforcement Effects 0.000 description 1
- 238000011160 research Methods 0.000 description 1
- 230000002207 retinal effect Effects 0.000 description 1
- 230000004256 retinal image Effects 0.000 description 1
- 239000004065 semiconductor Substances 0.000 description 1
- 238000013526 transfer learning Methods 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/764—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/77—Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
- G06V10/774—Generating sets of training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/77—Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
- G06V10/80—Fusion, i.e. combining data from various sources at the sensor level, preprocessing level, feature extraction level or classification level
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/82—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Evolutionary Computation (AREA)
- Physics & Mathematics (AREA)
- Artificial Intelligence (AREA)
- Computer Vision & Pattern Recognition (AREA)
- General Physics & Mathematics (AREA)
- General Health & Medical Sciences (AREA)
- Health & Medical Sciences (AREA)
- Computing Systems (AREA)
- Software Systems (AREA)
- Multimedia (AREA)
- Medical Informatics (AREA)
- Databases & Information Systems (AREA)
- Biomedical Technology (AREA)
- Mathematical Physics (AREA)
- General Engineering & Computer Science (AREA)
- Molecular Biology (AREA)
- Data Mining & Analysis (AREA)
- Computational Linguistics (AREA)
- Biophysics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Image Analysis (AREA)
Abstract
本申请公开了一种模型训练方法、装置、设备、存储介质及程序产品,涉及机器学习领域。该方法包括:获取第一领域的样本数据;通过源模型对样本数据进行预测,输出得到第一伪标签,以及,通过动力模型对变形样本数据进行预测,输出得到第二伪标签;通过目标候选模型对样本数据进行预测,得到样本预测结果;基于第一伪标签、第二伪标签与样本预测结果之间差异确定损失值;基于损失值对目标候选模型进行迭代训练,得到目标模型。通过增加动力模型的过渡,使样本数据从源域通过动力模型过渡至目标模型对应的目标域,提高了模型训练效率和准确率。
Description
技术领域
本申请实施例涉及机器学习领域,特别涉及一种模型训练方法、装置、设备、存储介质及程序产品。
背景技术
在机器学习领域,可以将已训练的模型进行领域迁移,从而将训练的效果从源领域模型迁移至目标领域模型。示意性的,源领域模型为视网膜眼底分割模型,通过对源领域模型的领域迁移,将其迁移至结肠息肉分割模型中。
相关技术中,在模型领域迁移的过程中,利用源模型产生伪标签,对后续目标模型的预测结果进行约束,从而完成目标模型在源模型基础上的训练过程。
然而,由于伪标签是在目标模型训练之前产生的,在训练过程中无法在伪标签以外继续应用源模型的知识,从而导致源模型知识的应用效率较低,目标模型的训练效率较低。
发明内容
本申请实施例提供了一种模型训练方法、装置、设备、存储介质及程序产品,能够提高模型迁移的效率和准确率。所述技术方案如下。
一方面,提供了一种模型训练方法,所述方法包括:
获取样本数据,所述样本数据是第一领域中采集的用于对目标模型进行训练的数据;
通过源模型对所述样本数据进行预测,输出得到第一伪标签,以及,通过动力模型对变形样本数据进行预测,输出得到第二伪标签,所述源模型为预先训练得到的针对第二领域进行数据预测的模型,所述动力模型是从待训练的目标候选模型变形得到的模型,所述变形样本数据是对所述样本数据变形得到的数据;
通过所述目标候选模型对所述样本数据进行预测,得到样本预测结果;
基于所述第一伪标签与所述样本预测结果之间的第一差异,和所述第二伪标签与所述样本预测结果之间的第二差异确定损失值;
基于所述损失值对所述目标候选模型进行迭代训练,得到所述目标模型,所述目标模型用于对所述第一领域的数据进行预测。
在一个可选的实施例中,所述基于所述第一伪标签与所述样本预测结果之间的第一差异,和所述第二伪标签与所述样本预测结果之间的第二差异确定损失值,包括:
确定所述第一伪标签与所述样本预测结果之间的第一损失值;
确定所述第二伪标签与所述样本预测结果之间的第二损失值;
对所述第一损失值和所述第二损失值进行加权融合,得到所述损失值。
在一个可选的实施例中,所述获取样本数据之后,还包括:
对所述样本数据在所述源模型和所述目标候选模型中的输出结果之间进行离散度分析,得到所述样本数据对应的学习复杂度;
基于所述学习复杂度对所述样本数据进行权重分配,得到权重参数;
所述对所述第一损失值和所述第二损失值进行加权融合,得到所述损失值,包括:
基于所述权重参数对所述第一损失值和所述第二损失值进行加权融合,得到所述损失值。
在一个可选的实施例中,所述基于所述学习复杂度对所述样本数据进行权重分配,得到权重参数,包括:
获取当前迭代循环次序与预设迭代循环次数之间的比值;
基于所述比值得到所述候选权重参数;
基于所述学习复杂度和所述候选权重参数对所述样本数据进行权重分配,得到所述权重参数。
在一个可选的实施例中,所述基于所述权重参数对所述第一损失值和所述第二损失值进行加权融合,得到所述损失值,包括:
获取所述候选权重参数和所述权重参数与所述第一损失值的第一乘积;
获取预设阈值与候选权重参数的第一差值,以及所述第一差值与所述第二损失值的第二乘积;
将所述第一乘积和所述第二乘积之和作为所述损失值。
在一个可选的实施例中,所述通过源模型和动力模型对所述样本数据进行预测之前,还包括:
通过预设变形参数对所述目标候选模型进行变形处理,得到所述动力模型。
在一个可选的实施例中,所述通过预设变形参数对所述目标候选模型进行变形处理,得到所述动力模型,包括:
将第i次迭代中的动力模型的模型参数与所述预设变形参数相乘,得到第一乘积参数;
获取预设参数与所述预设变形参数的第二差值;
将第i次迭代后训练得到的目标模型的模型参数与所述第二差值相乘,得到第二乘积参数;
将所述第一乘积参数和所述第二乘积参数之和作为第i+1次迭代中的动力模型的模型参数。
另一方面,提供了一种模型训练装置,所述装置包括:
获取模块,用于获取样本数据,所述样本数据是第一领域中采集的用于对目标模型进行训练的数据;
预测模块,用于通过源模型对所述样本数据进行预测,输出得到第一伪标签,以及,通过动力模型对变形样本数据进行预测,输出得到第二伪标签,所述源模型为预先训练得到的针对第二领域进行数据预测的模型,所述动力模型是从待训练的目标候选模型变形得到的模型,所述变形样本数据是对所述样本数据变形得到的数据;
所述预测模块,还用于通过所述目标候选模型对所述样本数据进行预测,得到样本预测结果;
确定模块,用于基于所述第一伪标签与所述样本预测结果之间的第一差异,和所述第二伪标签与所述样本预测结果之间的第二差异确定损失值;
训练模块,用于基于所述损失值对所述目标候选模型进行迭代训练,得到所述目标模型,所述目标模型用于对所述第一领域的数据进行预测。
另一方面,提供了一种计算机设备,所述计算机设备包括处理器和存储器,所述存储器中存储有至少一条指令、至少一段程序、代码集或指令集,所述至少一条指令、所述至少一段程序、所述代码集或指令集由所述处理器加载并执行以实现如上述本申请实施例中任一所述模型训练方法。
另一方面,提供了一种计算机可读存储介质,所述存储介质中存储有至少一条指令、至少一段程序、代码集或指令集,所述至少一条指令、所述至少一段程序、所述代码集或指令集由处理器加载并执行以实现如上述本申请实施例中任一所述的模型训练方法。
另一方面,提供了一种计算机程序产品,该计算机程序产品包括计算机指令,该计算机指令存储在计算机可读存储介质中。计算机设备的处理器从计算机可读存储介质读取该计算机指令,处理器执行该计算机指令,使得该计算机设备执行上述实施例中任一所述的模型训练方法。
本申请实施例提供的技术方案带来的有益效果至少包括:
通过将目标候选模型进行变形得到动力模型,从而提供了从源模型到目标模型,以及从源模型迁移到动力模型的渐进式训练过程,辅助完成从源域到目标域的训练过程,训练过程稳定适应从源模型到目标模型的渐进,从而控制模型训练平滑地从源域转移至目标域,也即,通过增加动力模型的过渡,使样本数据从源域通过动力模型过渡至目标模型对应的目标域,提高了模型训练效率和准确率。另外,由于动力模型是由目标模型变形得到的,用于对变形样本数据进行预测,并得到对应的第二伪标签,由于变形样本数据存在变形情况,通过动力模型辅助目标模型在训练过程中对数据的变形复原预测能力,进一步提高了目标模型的训练准确率。
附图说明
为了更清楚地说明本申请实施例中的技术方案,下面将对实施例描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本申请的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1示出了本申请一个示例性实施例提供的无源领域自适应过程的示意图;
图2是本申请一个示例性实施例提供的实施环境示意图;
图3是本申请一个示例性实施例提供的模型训练方法的流程图;
图4是本申请另一个示例性实施例提供的模型训练方法的流程图;
图5是基于图4示出的实施例提供的基础模型的结构示意图;
图6是基于图4示出的实施例提供的空洞卷积对应的示意图;
图7是本申请另一个示例性实施例提供的模型训练方法的流程图;
图8是本申请一个示例性实施例提供的模型训练装置的结构框图;
图9是本申请另一个示例性实施例提供的模型训练装置的结构框图;
图10是本申请一个示例性实施例提供的服务器的结构框图。
具体实施方式
为使本申请的目的、技术方案和优点更加清楚,下面将结合附图对本申请实施方式作进一步地详细描述。
首先,针对本申请实施例中涉及的名词进行简单介绍。
人工智能(Artificial Intelligence,AI):是利用数字计算机或者数字计算机控制的机器模拟、延伸和扩展人的智能,感知环境、获取知识并使用知识获得最佳结果的理论、方法、技术及应用系统。换句话说,人工智能是计算机科学的一个综合技术,它企图了解智能的实质,并生产出一种新的能以人类智能相似的方式做出反应的智能机器。人工智能也就是研究各种智能机器的设计原理与实现方法,使机器具有感知、推理与决策的功能。
人工智能技术是一门综合学科,涉及领域广泛,既有硬件层面的技术也有软件层面的技术。人工智能基础技术一般包括如传感器、专用人工智能芯片、云计算、分布式存储、大特征表示的提取技术、操作/交互系统、机电一体化等技术。人工智能软件技术主要包括计算机视觉技术、语音处理技术、自然语言处理技术以及机器学习/深度学习等几大方向。
机器学习(Machine Learning,ML):是一门多领域交叉学科,涉及概率论、统计学、逼近论、凸分析、算法复杂度理论等多门学科。专门研究计算机怎样模拟或实现人类的学习行为,以获取新的知识或技能,重新组织已有的知识结构使之不断改善自身的性能。机器学习是人工智能的核心,是使计算机具有智能的根本途径,其应用遍及人工智能的各个领域。机器学习和深度学习通常包括人工神经网络、置信网络、强化学习、迁移学习、归纳学习、示教学习等技术。
领域自适应:是指将分布不同的源域数据和目标域数据映射到同一个特征空间中,使源域数据和目标域数据在特征空间中的距离符合距离要求,从而,在特征空间中对源领域数据训练得到的函数就能够迁移至目标领域数据上进行使用,提高了目标领域上数据的预测准确率和效率。
本申请实施例中所涉及的领域自适应为无源领域自适应,也即在领域自适应的基础上,设定源数据无法直接获取,而进行领域自适应的过程。
然而,相关技术中无源领域自适应框架主要集中在对目标数据进行伪标签的修正,而未考虑学习过程。
如图1所示,以上述数据实现为图像数据为例,本申请实施例中涉及的无源领域自适应过程中主要包括两个部分中的至少一个部分。
一、从易到难110
请参考图1,其示出了本申请实施例中从易到难110过程的示意图。如图1所示,将样本图像输入源模型111和目标模型112后,根据源模型111输出的结果和目标模型112输出的结果进行离散度分析,最终得到每个样本图像对应的权重,该权重代表了样本图像对应的识别复杂度。其中,复杂度越高的图像对应的权重越低,复杂度越低的图像对应的权重越高。
二、从源到目标120
在获取标注有权重的样本图像后,根据样本图像对应的权重情况对样本图像按权重从高到低依次输入进行训练。其中,目标模型112是基于源模型111初始化得到的模型,对目标模型112进行变形处理,得到动力模型113。将样本图像输入源模型111、目标模型112,以及将样本图像变形得到的图像输入动力模型113,从而根据源模型111输出的第一伪标签、目标模型112输出的样本预测结果以及动力模型113输出的第二伪标签计算损失值,对目标模型112进行迭代训练。
需要说明的是,本申请所涉及的信息(包括但不限于用户设备信息、用户个人信息等)、数据(包括但不限于用于分析的数据、存储的数据、展示的数据等)以及信号,均为经用户单独授权或者经过各方充分授权的,且相关数据的收集、使用和处理需要遵守相关国家和地区的相关法律法规和标准。例如,本申请中涉及到的源数据、目标数据都是在充分授权的情况下获取的。
其次,对本申请实施例中涉及的实施环境进行说明,示意性的,请参考图2,该实施环境中涉及终端210、服务器220,终端210和服务器220之间通过通信网络230连接。
在一些实施例中,终端210用于向服务器220发送数据。在一些实施例中,终端210中安装有具有数据分析功能(如:数据类别预测、数据识别等)的应用程序,示意性的,终端210中安装有图像识别应用程序。可选地,终端210中安装有搜索引擎程序、旅游应用程序、生活辅助应用程序、即时通讯应用程序、视频类程序、游戏类程序等,本申请实施例对此不加以限定。
服务器220用于基于源模型对目标模型进行训练,从而通过目标模型向终端210提供数据分析功能。其中,目标模型是针对第一领域进行数据分析的模型,源模型是针对第二领域进行数据分析的模型,如:目标模型用于针对视网膜眼底分割图像中由设备A采集的图像进行分割,源模型用于针对视网膜眼底分割图像中由设备B采集的图像进行分割;或者,目标模型用于针对森林中采集的动物图像进行识别,源模型用于针对草原上采集的动物图像进行识别。可选地,服务器220首先获取样本数据。根据源模型和待训练的目标模型对样本数据的识别结果进行离散度分析,得到各样本数据分别对应的权重。其中,目标模型是对源模型进行初始化得到的模型。第一领域和第二领域属于相同类型的领域,可选地,源模型和目标模型的预测结果范围相同,示意性的,源模型用于对图像在类型库A中进行图像类型识别,则目标模型也用于对图像在类型库A中进行图像类型识别,但源模型针对的源域图像和目标模型针对的目标域图像对应的采集方式不同,如:源域图像和目标域图像对应的采集设备不同,或者,对应的采集场景不同,本实施例对此不加以限定。
另外,对目标模型进行变形处理后,得到动力模型,根据源模型、目标模型以及动力模型对样本数据的识别结果,计算损失值,其中,将源模型和动力模型的识别结果作为伪标签,对目标模型的识别结果进行约束。
根据损失值对目标模型进行迭代训练,得到最终应用于数据分析的目标模型。
上述终端可以是手机、平板电脑、台式电脑、便携式笔记本电脑、智能电视、车载终端、智能家居设备等多种形式的终端设备,本申请实施例对此不加以限定。
值得注意的是,上述服务器可以是独立的物理服务器,也可以是多个物理服务器构成的服务器集群或者分布式系统,还可以是提供云服务、云数据库、云计算、云函数、云存储、网络服务、云通信、中间件服务、域名服务、安全服务、内容分发网络(Content DeliveryNetwork,CDN)、以及大数据和人工智能平台等基础云计算服务的云服务器。
其中,云技术(Cloud technology)是指在广域网或局域网内将硬件、软件、网络等系列资源统一起来,实现数据的计算、储存、处理和共享的一种托管技术。云技术基于云计算商业模式应用的网络技术、信息技术、整合技术、管理平台技术、应用技术等的总称,可以组成资源池,按需所用,灵活便利。云计算技术将变成重要支撑。技术网络系统的后台服务需要大量的计算、存储资源,如视频网站、图片类网站和更多的门户网站。伴随着互联网行业的高度发展和应用,将来每个物品都有可能存在自己的识别标志,都需要传输到后台系统进行逻辑处理,不同程度级别的数据将会分开处理,各类行业数据皆需要强大的系统后盾支撑,只能通过云计算来实现。
在一些实施例中,上述服务器还可以实现为区块链系统中的节点。
结合上述名词简介和应用场景,对本申请提供的模型训练方法进行说明,该方法可以由服务器或者终端执行,也可以由服务器和终端共同执行,本申请实施例中,以该方法由服务器执行为例进行说明,如图3所示,该方法包括如下步骤。
步骤301,获取样本数据,样本数据是第一领域中采集的用于对目标模型进行训练的数据。
可选地,该样本数据是针对第一领域获取的公开数据集中的数据,用于对目标模型进行无源领域自适应训练。
可选地,服务器可以接收终端上传的样本数据;或者,服务器可以从其他服务器获取公开数据集中的公开数据。
示意性的,样本数据是针对视网膜图像眼底分割领域获取的公开数据集中的数据。
步骤302,通过源模型对样本数据进行预测,输出得到第一伪标签,以及,通过动力模型对变形样本数据进行预测,输出得到第二伪标签。
其中,源模型为预先训练得到的针对第二领域进行数据预测的模型,动力模型是从待训练的目标候选模型变形得到的模型。变形样本数据是对样本数据变形得到的数据。
可选地,动力模型用于训练目标候选模型对变形数据的还原能力。
第一领域和第二领域属于相同类型的领域,可选地,源模型和目标模型的预测结果范围相同,示意性的,源模型用于对图像在类型库A中进行图像类型识别,则目标模型也用于对图像在类型库A中进行图像类型识别,但源模型针对的源域图像和目标模型针对的目标域图像对应的采集方式不同,如:源域图像和目标域图像对应的采集设备不同,或者,对应的采集场景不同,本实施例对此不加以限定。
源模型是在第二领域中预先经过训练的模型,可选地,源模型为训练结果符合训练要求的模型,如:源模型是训练迭代次数达到预设次数的模型;或者,源模型是训练过程中,损失值收敛的模型。
为了让模型更平滑的从源领域迁移到目标领域,本申请实施例中提出“从源到目标”的课程学习算法,在不同的迁移阶段使用来自不同的模型生成的伪标签。本申请实施例中,将样本数据输入三分支的网络结构,这个结构包括源模型、目标模型和动力模型,这三个模型有相同的网络结构。源模型的参数是冻结的,目标模型是对源模型经过初始化得到的模型,其中,初始化方式包括AdaBN初始化方式或者其他初始化方式,动力模型是将待训练的目标候选模型根据预设变形方式进行变形得到的模型,可选地,动力模型由目标模型进行参数变形得到。
可选地,通过预设变形参数对目标候选模型进行变形处理,得到动力模型。示意性的,动力模型的变形过程如下公式一所示:
公式一:fm←τfm’+(1-τ)ft
其中,τ为预设变形参数,fm是当前更新得到的动力模型,fm’为最近一次迭代得到的动力模型,如:根据第i次迭代更新的动力模型,变形得到第i+1轮迭代更新中的动力模型,ft为目标候选模型。
通过源模型、目标候选模型和动力模型的三分支结构,从源模型递进到目标模型的课程学习和自监督学习任务完成更新。其中,通过动力模型和目标候选模型之间的预测情况完整自监督学习任务,通过源模型、动力模型和目标候选模型的预测情况完成课程学习。本申请实施例中所涉及的课程学习是指目标模型对源模型特征空间的学习过程。
可选地,第一伪标签可以是软标签或者硬标签;第二伪标签可以是软标签或者硬标签。其中,第一伪标签和第二伪标签的形式相同或者不相同。其中,软标签是指针对预测的各个类型赋概率值的标签,如:类型A对应概率为0.95,类型B对应概率为0.6;硬标签是指通过二值化表达数据是否数据类型的标签,如:类型A对应的硬标签为1,表示数据属于类型A,类型B对应的硬标签为0,表示数据不属于类型B。
步骤303,通过目标候选模型对样本数据进行预测,得到样本预测结果。
目标候选模型为当前待训练的模型,且目标候选模型是针对第一领域的数据进行训练的模型,当目标候选模型训练完成后,即得到目标模型。也即,目标候选模型为模型参数待调整的模型,当目标候选模型的模型参数调整完毕后,冻结模型参数,将冻结模型参数的目标候选模型作为目标模型。
将样本数据输入目标候选模型后,通过目标候选模型对样本数据进行预测,输出得到样本预测结果。其中,目标候选模型对样本数据的预测包括:分类预测、识别预测、图像处理结果预测等,本申请实施例对此不加以限定。
步骤304,基于第一伪标签与样本预测结果之间的第一差异,和第二伪标签与样本预测结果之间的第二差异确定损失值。
样本预测结果包括硬标签结果和软标签结果中的至少一种。
基于第一伪标签和第二伪标签的形式,与对应形式的样本预测结果进行差异比对。以第一伪标签为例,若第一伪标签实现为硬标签形式,则将样本预测结果的硬标签结果与第一伪标签进行差异比对;若第一伪标签实现为软标签形式,则将样本预测结果的软标签结果与第一伪标签进行差异比对。
在一些实施例中,通过第一预设损失函数对第一伪标签与样本预测结果之间的第一差异进行确定;以及,通过第二预设损失函数对第二伪标签与样本预测结果之间的第二差异进行确定。
根据第一差异和第二差异确定损失值,并基于损失值对目标候选模型进行训练。
步骤305,基于损失值对目标候选模型进行迭代训练,得到目标模型。
目标模型用于对第一领域的数据进行预测,也即,目标模型用于将数据的特征表示映射至与源模型相同的特征空间中,从而根据特征表示在特征空间中的映射情况得到数据的预测结果。
可选地,每轮迭代训练中,通过计算得到的损失值对该轮迭代中的目标候选模型进行训练,得到下一轮迭代中的目标候选模型,并继续进行后续训练,直至得到目标模型。
示意性的,在第i轮迭代训练中,使用第i轮迭代训练得到的目标候选模型,将第i轮迭代训练得到的目标候选模型变形后,得到第i+1轮迭代训练中的动力模型,根据源模型、第i轮迭代训练中的动力模型和第i轮迭代训练得到的目标候选模型确定损失值,并基于损失值对第i轮迭代训练得到的目标候选模型进行训练,得到第i轮迭代训练得到的目标候选模型。重复迭代对目标候选模型进行训练,直至训练符合训练要求后,得到目标模型。
综上所述,本实施例提供的方法,通过将目标候选模型进行变形得到动力模型,从而提供了从源模型到目标模型,以及从源模型迁移到动力模型的渐进式训练过程,辅助完成从源域到目标域的训练过程,使训练过程稳定适应从源模型到目标模型,从而控制模型训练平滑地从源域转移至目标域,也即,通过增加动力模型的过渡,使样本数据从源域通过动力模型过渡至目标模型对应的目标域,提高了模型训练效率和准确率。
在一个可选的实施例中,迭代循环训练过程包括目标候选模型的训练,和目标候选模型训练后生成的动力模型。图4是本申请另一个示例性实施例提供的模型训练方法的流程图,该方法可以由服务器或者终端执行,也可以由服务器和终端共同执行,本申请实施例中,以该方法由服务器执行为例进行说明,如图4所示,该方法包括如下步骤。
步骤401,获取样本数据,样本数据是第一领域中采集的用于对目标模型进行训练的数据。
可选地,该样本数据是针对第一领域获取的公开数据集中的数据,用于对目标模型进行无源领域自适应训练。
步骤402,通过源模型对样本数据进行预测,输出得到第一伪标签,以及,通过第i次迭代得到的动力模型对变形样本数据进行预测,输出得到第二伪标签,i为正整数。
其中,源模型为预先训练得到的针对第二领域进行数据预测的模型,动力模型是从待训练的目标候选模型变形得到的模型。
可选地,动力模型是第i次迭代得到的目标候选模型变形得到的模型。在一些实施例中,第i+1次迭代得到的动力模型是基于第i次迭代中的动力模型和第i次迭代后得到的目标候选模型变形得到的。如上公式一所示,fm是当前第i+1次迭代中更新得到的动力模型,fm’为第i次迭代中更新得到的动力模型,ft为第i次迭代训练得到的目标候选模型。
在目标候选模型的迭代循环训练过程中,通过源模型对样本数据进行预测,以及通过第i+1次循环迭代中的动力模型对变形样本数据进行预测,根据预测的结果训练第i次循环迭代训练得到的目标候选模型,得到第i+1次循环迭代训练得到的目标候选模型,并根据第i+1次循环迭代训练得到的目标候选模型对第i+1次循环迭代中的动力模型进行变形,得到第i+2次循环迭代中的动力模型,由此往复,直至训练得到目标模型。
其中,变形样本数据是对源模型输入的样本数据进行变形得到的数据,以样本数据为图像为例,也即,对样本数据进行图像变换,得到变形样本数据。其中图像变换方式包括水平翻转、垂直翻转、剪裁等变换方式中的至少一种。
将图像变换后的变形样本数据输入动力模型,由动力模型输出第二伪标签。
源模型是预先训练得到的将样本数据的特征表示映射至预设特征空间的模型,该预设特征空间是源模型在针对第二领域数据进行训练时生成的特征空间,该预设特征空间中包括各个预测分类对应的特征中心向量,各个预测分类对应的特征中心向量用于与样本数据的特征表示在特征空间中的映射结果进行匹配。
源模型提取样本数据的特征表示后,将特征表示映射至预设特征空间中,得到样本数据在预设特征空间中的特征向量,将该特征向量与各预测分类对应的特征中心向量进行匹配,从而确定与该样本数据对应的第一伪标签。当第一伪标签为硬标签时,将与特征向量距离最近的特征中心向量对应的预测分类作为硬标签;当第一伪标签为软便签时,根据各个特征中心向量与特征向量的距离确定预测分类对应样本数据的概率,作为软标签。
而动力模型是在第i次迭代得到的目标候选模型基础上变形得到的模型,对于特征的映射、特征空间的参数与源模型存在不同。动力模型得到第二伪标签的过程与源模型得到第一伪标签的过程类似,提取变形样本数据的特征表示后,将变形样本数据的特征表示映射至动力模型对应的特征空间中,与动力模型的特征空间中各预测分类的特征中心向量进行匹配,从而得到第二伪标签。当第二伪标签为硬标签时,将与变形样本数据的特征表示距离最近的特征中心向量对应的预测分类作为硬标签;当第二伪标签为软便签时,根据各个特征中心向量与变形样本数据的特征表示的距离确定预测分类对应样本数据的概率,作为软标签。
可选地,通过动力模型对变形样本数据进行预测,输出得到第二候选伪标签,对第二候选伪标签进行逆变形处理,得到第二伪标签,逆变形处理的变形方式与变形样本数据的变形方式相反。如:变形样本数据时样本数据通过水平翻转得到的,则逆变形是指将候选伪标签重新进行水平翻转得到第二伪标签的变形方式。
步骤403,通过第i次迭代训练得到的目标候选模型对样本数据进行预测,得到样本预测结果。
目标候选模型为当前待训练的模型,且目标候选模型是针对第一领域的数据进行训练的模型,当目标候选模型训练完成后,即得到目标模型。也即,目标候选模型为模型参数待调整的模型,当目标候选模型的模型参数调整完毕后,冻结模型参数,将冻结模型参数的目标候选模型作为目标模型。
将样本数据输入第i次迭代训练得到的目标候选模型后,通过第i次迭代训练得到的目标候选模型对样本数据进行预测,输出得到样本预测结果。其中,目标候选模型对样本数据的预测包括:分类预测、识别预测、图像处理结果预测等,本申请实施例对此不加以限定。
步骤404,基于第一伪标签与样本预测结果之间的第一差异,和第二伪标签与样本预测结果之间的第二差异确定损失值。
样本预测结果包括硬标签结果和软标签结果中的至少一种。该损失值为第i+1次得带中的损失值,也即,通过第i次迭代训练得到的目标候选模型进行预测后,用于对第i次迭代训练得到的目标候选模型进行评估的损失值。
在一些实施例中,通过第一预设损失函数对第一伪标签与样本预测结果之间的第一差异进行确定;以及,通过第二预设损失函数对第二伪标签与样本预测结果之间的第二差异进行确定。
根据第一差异和第二差异确定损失值,并基于损失值对目标候选模型进行训练。
首先,对动力模型输入的第二伪标签与样本预测结果之间的第二差异进行说明。根据第二伪标签与样本预测结果之间的第二差异确定第二损失值,计算该第二损失值对应的损失函数如下公式二和公式三所示:
公式二:ypsd=softmax(T-1(fm(T(xt))))
公式三:
其中,ypsd是指动力模型预测得到的第二伪标签,T是指对样本图像的变形操作,fm是指动力模型,T-1是指对预测结果的逆变形。是第二损失值,W和H是指样本图像的宽和高,(u,v)是样本图像中像素点的坐标,pt是指目标候选模型输出的预测结果。
在从源到目标课程学习中,从源模型生成的第一伪标签ysrc开始学习,逐渐将学习对象迁移到动力模型生成的第二伪标签ypsd,通过该损失函数如下公式四所示:
公式四:
其中,为总损失值,/>为源模型对应的第一损失值,ω和α为权重参数。
步骤405,基于第i+1次迭代中的损失值对第i次迭代得到的目标候选模型进行训练,得到第i+1次迭代后的目标候选模型。
可选地,基于第i+1次迭代中的损失值对第i次迭代得到的目标候选模型进行模型参数的调整,得到第i+1次迭代后的目标候选模型。
目标模型用于对第一领域的数据进行预测,也即,目标模型用于将数据的特征表示映射至与源模型相同的特征空间中,从而根据特征表示在特征空间中的映射情况得到数据的预测结果。
步骤406,响应于目标候选模型符合训练要求,将目标候选模型确定为目标模型。
可选地,每轮迭代训练中,通过计算得到的损失值对该轮迭代中的目标候选模型进行训练,得到下一轮迭代中的目标候选模型,并继续进行后续训练,直至得到目标模型。
示意性的,在第i轮迭代训练中,使用第i轮迭代训练得到的目标候选模型,将第i轮迭代训练得到的目标候选模型变形后,得到第i+1轮迭代训练中的动力模型,根据源模型、第i轮迭代训练中的动力模型和第i轮迭代训练得到的目标候选模型确定损失值,并基于损失值对第i轮迭代训练得到的目标候选模型进行训练,得到第i轮迭代训练得到的目标候选模型。重复迭代对目标候选模型进行训练,直至训练符合训练要求后,得到目标模型。
可选地,训练要求包括迭代次数、损失值收敛情况等要求中的至少一种。
可选地,本申请实施例中的源模型、目标模型和动力模型都采用相同的基础模型,示意性的,使用DeepLab-V3作为基础模型,模型结构如图5所示。该模型500包含多个不同尺度的空洞卷积510,可以增加模型多尺度特征的提取能力,帮助其更好地识别图片中的病灶。
空洞卷积510的结构如图6所示。相比正常卷积600(图6中最左),空洞卷积510在各自卷积参数间留有空隙,可以达到增加卷积感受野的目的。
综上所述,本实施例提供的方法,通过将目标候选模型进行变形得到动力模型,从而提供了从源模型到目标模型,以及从源模型迁移到动力模型的渐进式训练过程,辅助完成从源域到目标域的训练过程,使训练过程稳定适应从源模型到目标模型,从而控制模型训练平滑地从源域转移至目标域,也即,通过增加动力模型的过渡,使样本数据从源域通过动力模型过渡至目标模型对应的目标域,提高了模型训练效率和准确率。
本实施例提供的方法,通过动力模型对变形样本数据进行预测,以及通过源模型对样本数据进行预测,由动力模型辅助目标模型训练数据复原能力,以及通过源模型提高目标模型在特征空间的映射准确率,提高了目标模型的训练效率和准确率。
本实施例提供的方法,通过对样本数据进行变形处理,并在动力模型对变形样本数据预测得到第二候选伪标签时,通过逆变形对第二候选伪标签进行再次变形,从而得到第二伪标签与样本预测结果进行比对,提高了数据比对精确度和效率。
本实施例提供的方法,通过在每一轮迭代中训练得到的目标候选模型继续变形得到下一轮的动力模型,从而在动力模型和目标候选模型的轮替迭代更新提高了目标模型的训练效率和准确率。
在一个可选的实施例中,各个样本数据还包括复杂度,根据复杂度对目标模型进行由易到难的训练。图7是本申请一个示例性实施例提供的模型训练方法的流程图,该方法可以由服务器或者终端执行,也可以由服务器和终端共同执行,本申请实施例中,以该方法由服务器执行为例进行说明,如图7所示,该方法包括如下步骤。
步骤701,获取样本数据,样本数据是第一领域中采集的用于对目标模型进行训练的数据。
可选地,该样本数据是针对第一领域获取的公开数据集中的数据,用于对目标模型进行无源领域自适应训练。
步骤702,对样本数据在源模型和目标候选模型中的输出结果之间进行离散度分析,得到样本数据对应的学习复杂度。
也即,将样本数据输入源模型,输出第一预测结果,以及将样本数据输入目标候选模型,输出样本预测结果。基于第一预测结果和样本预测结果对样本数据进行离散度的分析,得到学习复杂度。
可选地,以样本数据实现为样本图像为例,该学习复杂度的预测方式如下公式五所示:
公式五:
其中,d表示学习复杂度,KL表示离散度算法,ps为源模型输出的预测概率,pt表示目标模型输出的概率。该概率表示像素点(u,v)属于某个分类的概率。
步骤703,基于学习复杂度对样本数据进行权重分配,得到权重参数。
可选地,在每个批次(batch)的训练中,为了实现目标模型从简单样本到困难样本的学习,根据学习复杂度d对样本进行权重分配,权重分配方式如下公式六所示:
公式六:
其中,ωb为权重参数,δ为预设常数,B为batch中样本数据总数量,b为batch第b个样本数据,α为候选权重参数。
可选地,权重参数和候选权重参数的计算过程中,获取当前迭代循环次序与预设迭代循环次数之间的比值;基于比值得到候选权重参数;基于学习复杂度和候选权重参数对样本数据进行权重分配,得到权重参数。
其中,候选权重参数α的计算方式如下公式七所示:
公式七:
其中,R是指迭代训练的第R次,Rmax是指迭代训练的最大次数。
可选地,学习复杂度越高的样本数据,对应的权重参数越小;学习复杂度越低的样本数据,对应的权重参数越大。
可选地,在获取样本数据对目标候选模型进行训练时,首先选择权重参数大的样本数据,再选择权重参数小的样本数据,对目标候选模型进行训练。
步骤704,通过源模型对样本数据进行预测,输出得到第一伪标签,以及,通过第i次迭代得到的动力模型对变形样本数据进行预测,输出得到第二伪标签,i为正整数。
可选地,由于上述步骤702中执行了输出第一预测结果的过程,该第一预测结果和第一伪标签相同,或者第一伪标签是基于第一预测结果得到的伪标签,故步骤704中源模型的预测过程可以省略。
其中,源模型为预先训练得到的针对第二领域进行数据预测的模型,动力模型是从待训练的目标候选模型变形得到的模型。
可选地,动力模型是第i次迭代得到的目标候选模型变形得到的模型。在一些实施例中,第i+1次迭代得到的动力模型是基于第i次迭代中的动力模型和第i次迭代后得到的目标候选模型变形得到的。
源模型是预先训练得到的将样本数据的特征表示映射至预设特征空间的模型,该预设特征空间是源模型在针对第二领域数据进行训练时生成的特征空间,该预设特征空间中包括各个预测分类对应的特征中心向量,各个预测分类对应的特征中心向量用于与样本数据的特征表示在特征空间中的映射结果进行匹配。
而动力模型是在第i次迭代得到的目标候选模型基础上变形得到的模型,对于特征的映射、特征空间的参数与源模型存在不同。
步骤705,通过第i次迭代训练得到的目标候选模型对样本数据进行预测,得到样本预测结果。
可选地,由于上述步骤702中执行了输出样本预测结果的过程,步骤705的预测过程可以省略。
目标候选模型为当前待训练的模型,且目标候选模型是针对第一领域的数据进行训练的模型,当目标候选模型训练完成后,即得到目标模型。也即,目标候选模型为模型参数待调整的模型,当目标候选模型的模型参数调整完毕后,冻结模型参数,将冻结模型参数的目标候选模型作为目标模型。
将样本数据输入第i次迭代训练得到的目标候选模型后,通过第i次迭代训练得到的目标候选模型对样本数据进行预测,输出得到样本预测结果。其中,目标候选模型对样本数据的预测包括:分类预测、识别预测、图像处理结果预测等,本申请实施例对此不加以限定。
步骤706,确定第一伪标签与样本预测结果之间的第一损失值。
步骤707,确定第二伪标签与样本预测结果之间的第二损失值。
可选地,上述第一损失值和第二损失值的计算过程在上述步骤404中已进行了详细说明,此处不再赘述。
步骤708,基于权重参数对第一损失值和第二损失值进行加权融合,得到损失值。
也即,上述公式四中的权重参数ω和候选权重参数α即为上述公式六和公式七中基于迭代过程和样本数据计算得到的参数。
可选地,获取候选权重参数和权重参数与第一损失值的第一乘积;获取预设阈值与候选权重参数的第一差值,以及第一差值与第二损失值的第二乘积;将第一乘积和第二乘积之和作为损失值。如上公式四所示,则预设阈值为1。可选地,候选权重参数的取值在0-1之间。
步骤709,基于第i+1次迭代中的损失值对第i次迭代得到的目标候选模型进行训练,得到第i+1次迭代后的目标候选模型。
可选地,基于第i+1次迭代中的损失值对第i次迭代得到的目标候选模型进行模型参数的调整,得到第i+1次迭代后的目标候选模型。
目标模型用于对第一领域的数据进行预测,也即,目标模型用于将数据的特征表示映射至与源模型相同的特征空间中,从而根据特征表示在特征空间中的映射情况得到数据的预测结果。
步骤710,响应于目标候选模型符合训练要求,将目标候选模型确定为目标模型。
可选地,每轮迭代训练中,通过计算得到的损失值对该轮迭代中的目标候选模型进行训练,得到下一轮迭代中的目标候选模型,并继续进行后续训练,直至得到目标模型。
示意性的,在第i轮迭代训练中,使用第i轮迭代训练得到的目标候选模型,将第i轮迭代训练得到的目标候选模型变形后,得到第i+1轮迭代训练中的动力模型,根据源模型、第i轮迭代训练中的动力模型和第i轮迭代训练得到的目标候选模型确定损失值,并基于损失值对第i轮迭代训练得到的目标候选模型进行训练,得到第i轮迭代训练得到的目标候选模型。重复迭代对目标候选模型进行训练,直至训练符合训练要求后,得到目标模型。
综上所述,本实施例提供的方法,通过将目标候选模型进行变形得到动力模型,从而提供了从源模型到目标模型,以及从源模型迁移到动力模型的渐进式训练过程,辅助完成从源域到目标域的训练过程,使训练过程稳定适应从源模型到目标模型,从而控制模型训练平滑地从源域转移至目标域,也即,通过增加动力模型的过渡,使样本数据从源域通过动力模型过渡至目标模型对应的目标域,提高了模型训练效率和准确率。
本实施例提供的方法,通过对样本数据计算学习复杂度,从而确定样本数据对应的权重参数,在选择样本数据进行训练时,首先选择学习复杂度低的数据进行训练,再选择学习复杂度高的数据进行训练,从而实现从源领域到目标领域的渐进式学习过程。
本实施例提供的方法,通过对样本数据计算学习复杂度,从而确定样本数据对应的权重参数,并将权重参数应用于损失值的计算过程中,避免不同复杂度的样本数据对模型进行同样步长的训练,提高了模型的训练准确率。
图8是本申请一个示例性实施例提供的模型训练装置的结构框图,如图8所示,该装置包括:
获取模块810,用于获取样本数据,所述样本数据是第一领域中采集的用于对目标模型进行训练的数据;
预测模块820,用于通过源模型对所述样本数据进行预测,输出得到第一伪标签,以及,通过动力模型对变形样本数据进行预测,输出得到第二伪标签,所述源模型为预先训练得到的针对第二领域进行数据预测的模型,所述动力模型是从待训练的目标候选模型变形得到的模型,所述变形样本数据是对所述样本数据变形得到的数据;
所述预测模块820,还用于通过所述目标候选模型对所述样本数据进行预测,得到样本预测结果;
确定模块830,用于基于所述第一伪标签与所述样本预测结果之间的第一差异,和所述第二伪标签与所述样本预测结果之间的第二差异确定损失值;
训练模块840,用于基于所述损失值对所述目标候选模型进行迭代训练,得到所述目标模型,所述目标模型用于对所述第一领域的数据进行预测。
在一个可选的实施例中,所述动力模型是第i次迭代得到的目标候选模型变形得到的模型;
所述训练模块840,还用于基于第i+1次迭代中的损失值对第i次迭代得到的目标候选模型进行训练,得到第i+1次迭代后的目标候选模型,i为正整数;
所述确定模块830,还用于响应于所述目标候选模型符合训练要求,将所述目标候选模型确定为所述目标模型。
在一个可选的实施例中,所述预测模块820,还用于通过所述源模型对所述样本数据进行预测,输出得到第一伪标签,以及,通过第i次迭代得到的动力模型对所述变形样本数据进行预测,输出得到所述第二伪标签;
所述预测模块820,还用于通过第i次迭代得到的所述目标候选模型对所述样本数据进行预测,得到第i次迭代中的所述样本预测结果;
在一个可选的实施例中,所述预测模块820,还用于通过所述动力模型对所述变形样本数据进行预测,输出得到第二候选伪标签;
如图9所示,该装置还包括:
变形模块850,用于对所述第二候选伪标签进行逆变形处理,得到所述第二伪标签,所述逆变形处理的变形方式与所述变形样本数据的变形方式相反。
在一个可选的实施例中,所述确定模块830,还用于确定所述第一伪标签与所述样本预测结果之间的第一损失值;
所述确定模块830,还用于确定所述第二伪标签与所述样本预测结果之间的第二损失值;
所述装置还包括:
融合模块860,用于对所述第一损失值和所述第二损失值进行加权融合,得到所述损失值。
在一个可选的实施例中,所述装置还包括:
分析模块870,用于对所述样本数据在所述源模型和所述目标候选模型中的输出结果之间进行离散度分析,得到所述样本数据对应的学习复杂度;
分配模块880,用于基于所述学习复杂度对所述样本数据进行权重分配,得到权重参数;
所述融合模块860,还用于基于所述权重参数对所述第一损失值和所述第二损失值进行加权融合,得到所述损失值。
在一个可选的实施例中,所述分配模块880,还用于获取当前迭代循环次序与预设迭代循环次数之间的比值;基于所述比值得到所述候选权重参数;基于所述学习复杂度和所述候选权重参数对所述样本数据进行权重分配,得到所述权重参数。
在一个可选的实施例中,所述融合模块860,还用于获取所述候选权重参数和所述权重参数与所述第一损失值的第一乘积;获取预设阈值与候选权重参数的第一差值,以及所述第一差值与所述第二损失值的第二乘积;将所述第一乘积和所述第二乘积之和作为所述损失值。
在一个可选的实施例中,所述装置还包括:
变形模块850,用于通过预设变形参数对所述目标候选模型进行变形处理,得到所述动力模型。
在一个可选的实施例中,所述变形模块850,还用于将第i次迭代中的动力模型的模型参数与所述预设变形参数相乘,得到第一乘积参数;获取预设参数与所述预设变形参数的第二差值;将第i次迭代后训练得到的目标模型的模型参数与所述第二差值相乘,得到第二乘积参数;将所述第一乘积参数和所述第二乘积参数之和作为第i+1次迭代中的动力模型的模型参数。
综上所述,本实施例提供的装置,通过将目标候选模型进行变形得到动力模型,从而提供了从源模型到目标模型,以及从源模型迁移到动力模型的渐进式训练过程,辅助完成从源域到目标域的训练过程,训练过程稳定适应从源模型到目标模型的渐进,从而控制模型训练平滑地从源域转移至目标域,也即,通过增加动力模型的过渡,使样本数据从源域通过动力模型过渡至目标模型对应的目标域,提高了模型训练效率和准确率。
需要说明的是:上述实施例提供的模型训练装置,仅以上述各功能模块的划分进行举例说明,实际应用中,可以根据需要而将上述功能分配由不同的功能模块完成,即将设备的内部结构划分成不同的功能模块,以完成以上描述的全部或者部分功能。另外,上述实施例提供的模型训练装置与模型训练方法实施例属于同一构思,其具体实现过程详见方法实施例,这里不再赘述。
图10示出了本申请一个示例性实施例提供的服务器的结构示意图。该服务器可以是如图2所示的服务器。
具体来讲:服务器1000包括中央处理单元(Central Processing Unit,CPU)1001、包括随机存取存储器(Random Access Memory,RAM)1002和只读存储器(Read OnlyMemory,ROM)1003的系统存储器1004,以及连接系统存储器1004和中央处理单元1001的系统总线1005。服务器1000还包括用于存储操作系统1013、应用程序1014和其他程序模块1015的大容量存储设备1006。
大容量存储设备1006通过连接到系统总线1005的大容量存储控制器(未示出)连接到中央处理单元1001。大容量存储设备1006及其相关联的计算机可读介质为服务器1000提供非易失性存储。也就是说,大容量存储设备1006可以包括诸如硬盘或者紧凑型光盘只读存储器(Compact Disc Read Only Memory,CD-ROM)驱动器之类的计算机可读介质(未示出)。
不失一般性,计算机可读介质可以包括计算机存储介质和通信介质。计算机存储介质包括以用于存储诸如计算机可读指令、数据结构、程序模块或其他数据等信息的任何方法或技术实现的易失性和非易失性、可移动和不可移动介质。计算机存储介质包括RAM、ROM、可擦除可编程只读存储器(Erasable Programmable Read Only Memory,EPROM)、带电可擦可编程只读存储器(Electrically Erasable Programmable Read Only Memory,EEPROM)、闪存或其他固态存储技术,CD-ROM、数字通用光盘(Digital Versatile Disc,DVD)或其他光学存储、磁带盒、磁带、磁盘存储或其他磁性存储设备。当然,本领域技术人员可知计算机存储介质不局限于上述几种。上述的系统存储器1004和大容量存储设备1006可以统称为存储器。
根据本申请的各种实施例,服务器1000还可以通过诸如因特网等网络连接到网络上的远程计算机运行。也即服务器1000可以通过连接在系统总线1005上的网络接口单元1011连接到网络1012,或者说,也可以使用网络接口单元1011来连接到其他类型的网络或远程计算机系统(未示出)。
上述存储器还包括一个或者一个以上的程序,一个或者一个以上程序存储于存储器中,被配置由CPU执行。
本申请的实施例还提供了一种计算机设备,该计算机设备可以实现为如图2所示的终端或者服务器。该计算机设备包括处理器和存储器,该存储器中存储有至少一条指令、至少一段程序、代码集或指令集,至少一条指令、至少一段程序、代码集或指令集由处理器加载并执行以实现上述各方法实施例提供的模型训练方法。
本申请的实施例还提供了一种计算机可读存储介质,该计算机可读存储介质上存储有至少一条指令、至少一段程序、代码集或指令集,至少一条指令、至少一段程序、代码集或指令集由处理器加载并执行,以实现上述各方法实施例提供的模型训练方法。
本申请的实施例还提供了一种计算机程序产品或计算机程序,该计算机程序产品或计算机程序包括计算机指令,该计算机指令存储在计算机可读存储介质中。计算机设备的处理器从计算机可读存储介质读取该计算机指令,处理器执行该计算机指令,使得该计算机设备执行上述实施例中任一所述的模型训练方法。
可选地,该计算机可读存储介质可以包括:只读存储器(ROM,Read Only Memory)、随机存取记忆体(RAM,Random Access Memory)、固态硬盘(SSD,Solid State Drives)或光盘等。其中,随机存取记忆体可以包括电阻式随机存取记忆体(ReRAM,Resistance RandomAccess Memory)和动态随机存取存储器(DRAM,Dynamic Random Access Memory)。上述本申请实施例序号仅仅为了描述,不代表实施例的优劣。
本领域普通技术人员可以理解实现上述实施例的全部或部分步骤可以通过硬件来完成,也可以通过程序来指令相关的硬件完成,所述的程序可以存储于一种计算机可读存储介质中,上述提到的存储介质可以是只读存储器,磁盘或光盘等。
以上所述仅为本申请的可选实施例,并不用以限制本申请,凡在本申请的精神和原则之内,所作的任何修改、等同替换、改进等,均应包含在本申请的保护范围之内。
Claims (10)
1.一种模型训练方法,其特征在于,所述方法包括:
获取样本数据,所述样本数据是第一领域中采集的用于对目标模型进行训练的数据;
通过源模型对所述样本数据进行预测,输出得到第一伪标签,以及,通过动力模型对变形样本数据进行预测,输出得到第二伪标签,所述源模型为预先训练得到的针对第二领域进行数据预测的模型,所述动力模型是从待训练的目标候选模型变形得到的模型,所述变形样本数据是对所述样本数据变形得到的数据;
通过所述目标候选模型对所述样本数据进行预测,得到样本预测结果;
基于所述第一伪标签与所述样本预测结果之间的第一差异,和所述第二伪标签与所述样本预测结果之间的第二差异确定损失值;
基于所述损失值对所述目标候选模型进行迭代训练,得到所述目标模型,所述目标模型用于对所述第一领域的数据进行预测。
2.根据权利要求1所述的方法,其特征在于,所述动力模型是第i次迭代得到的目标候选模型变形得到的模型;
所述基于所述损失值对所述目标候选模型进行迭代训练,得到所述目标模型,包括:
基于第i+1次迭代中的损失值对第i次迭代得到的目标候选模型进行训练,得到第i+1次迭代后的目标候选模型,i为正整数;
响应于所述目标候选模型符合训练要求,将所述目标候选模型确定为所述目标模型。
3.根据权利要求2所述的方法,其特征在于,所述通过源模型对所述样本数据进行预测,输出得到第一伪标签,以及,通过动力模型变形样本数据进行预测,输出得到第二伪标签,包括:
通过所述源模型对所述样本数据进行预测,输出得到第一伪标签,以及,通过第i次迭代得到的动力模型对所述变形样本数据进行预测,输出得到所述第二伪标签;
所述通过所述目标候选模型对所述样本数据进行预测,得到样本预测结果,包括:
通过第i次迭代得到的所述目标候选模型对所述样本数据进行预测,得到第i次迭代中的所述样本预测结果。
4.根据权利要求1至3任一所述的方法,其特征在于,所述通过动力模型对变形样本数据进行预测,输出得到第二伪标签,包括:
通过所述动力模型对所述变形样本数据进行预测,输出得到第二候选伪标签;
对所述第二候选伪标签进行逆变形处理,得到所述第二伪标签,所述逆变形处理的变形方式与所述变形样本数据的变形方式相反。
5.根据权利要求1至3任一所述的方法,其特征在于,所述基于所述第一伪标签与所述样本预测结果之间的第一差异,和所述第二伪标签与所述样本预测结果之间的第二差异确定损失值,包括:
确定所述第一伪标签与所述样本预测结果之间的第一损失值;
确定所述第二伪标签与所述样本预测结果之间的第二损失值;
对所述第一损失值和所述第二损失值进行加权融合,得到所述损失值。
6.根据权利要求5所述的方法,其特征在于,所述获取样本数据之后,还包括:
对所述样本数据在所述源模型和所述目标候选模型中的输出结果之间进行离散度分析,得到所述样本数据对应的学习复杂度;
基于所述学习复杂度对所述样本数据进行权重分配,得到权重参数;
所述对所述第一损失值和所述第二损失值进行加权融合,得到所述损失值,包括:
基于所述权重参数对所述第一损失值和所述第二损失值进行加权融合,得到所述损失值。
7.根据权利要求6所述的方法,其特征在于,所述基于所述学习复杂度对所述样本数据进行权重分配,得到权重参数,包括:
获取当前迭代循环次序与预设迭代循环次数之间的比值;
基于所述比值得到所述候选权重参数;
基于所述学习复杂度和所述候选权重参数对所述样本数据进行权重分配,得到所述权重参数。
8.一种模型训练装置,其特征在于,所述装置包括:
获取模块,用于获取样本数据,所述样本数据是第一领域中采集的用于对目标模型进行训练的数据;
预测模块,用于通过源模型对所述样本数据进行预测,输出得到第一伪标签,以及,通过动力模型对变形样本数据进行预测,输出得到第二伪标签,所述源模型为预先训练得到的针对第二领域进行数据预测的模型,所述动力模型是从待训练的目标候选模型变形得到的模型,所述变形样本数据是对所述样本数据变形得到的数据;
所述预测模块,还用于通过所述目标候选模型对所述样本数据进行预测,得到样本预测结果;
确定模块,用于基于所述第一伪标签与所述样本预测结果之间的第一差异,和所述第二伪标签与所述样本预测结果之间的第二差异确定损失值;
训练模块,用于基于所述损失值对所述目标候选模型进行迭代训练,得到所述目标模型,所述目标模型用于对所述第一领域的数据进行预测。
9.一种计算机设备,其特征在于,所述计算机设备包括处理器和存储器,所述存储器中存储有至少一条指令、至少一段程序、代码集或指令集,所述至少一条指令、所述至少一段程序、所述代码集或指令集由所述处理器加载并执行以实现如权利要求1至7任一所述的模型训练方法。
10.一种计算机可读存储介质,其特征在于,所述存储介质中存储有至少一条指令、至少一段程序、代码集或指令集,所述至少一条指令、所述至少一段程序、所述代码集或指令集由处理器加载并执行以实现如权利要求1至7任一所述的模型训练方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202211017332.1A CN117010480A (zh) | 2022-08-23 | 2022-08-23 | 模型训练方法、装置、设备、存储介质及程序产品 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202211017332.1A CN117010480A (zh) | 2022-08-23 | 2022-08-23 | 模型训练方法、装置、设备、存储介质及程序产品 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN117010480A true CN117010480A (zh) | 2023-11-07 |
Family
ID=88566086
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202211017332.1A Pending CN117010480A (zh) | 2022-08-23 | 2022-08-23 | 模型训练方法、装置、设备、存储介质及程序产品 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN117010480A (zh) |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN118072127A (zh) * | 2024-04-18 | 2024-05-24 | 海马云(天津)信息技术有限公司 | 一种图像生成模型的训练方法及相关装置 |
-
2022
- 2022-08-23 CN CN202211017332.1A patent/CN117010480A/zh active Pending
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN118072127A (zh) * | 2024-04-18 | 2024-05-24 | 海马云(天津)信息技术有限公司 | 一种图像生成模型的训练方法及相关装置 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
WO2022042002A1 (zh) | 一种半监督学习模型的训练方法、图像处理方法及设备 | |
CN111382868B (zh) | 神经网络结构搜索方法和神经网络结构搜索装置 | |
CN109993102B (zh) | 相似人脸检索方法、装置及存储介质 | |
CN110659723B (zh) | 基于人工智能的数据处理方法、装置、介质及电子设备 | |
CN112418292B (zh) | 一种图像质量评价的方法、装置、计算机设备及存储介质 | |
CN111651671B (zh) | 用户对象推荐方法、装置、计算机设备和存储介质 | |
CN111950596A (zh) | 一种用于神经网络的训练方法以及相关设备 | |
CN110889450B (zh) | 超参数调优、模型构建方法和装置 | |
CN114298122B (zh) | 数据分类方法、装置、设备、存储介质及计算机程序产品 | |
CN113158554B (zh) | 模型优化方法、装置、计算机设备及存储介质 | |
CN113807399A (zh) | 一种神经网络训练方法、检测方法以及装置 | |
CN113806582B (zh) | 图像检索方法、装置、电子设备和存储介质 | |
CN112199600A (zh) | 目标对象识别方法和装置 | |
CN112819024B (zh) | 模型处理方法、用户数据处理方法及装置、计算机设备 | |
CN113609337A (zh) | 图神经网络的预训练方法、训练方法、装置、设备及介质 | |
CN115879508A (zh) | 一种数据处理方法及相关装置 | |
EP3732632A1 (en) | Neural network training using the soft nearest neighbor loss | |
CN113128526B (zh) | 图像识别方法、装置、电子设备和计算机可读存储介质 | |
CN113762331A (zh) | 关系型自蒸馏方法、装置和系统及存储介质 | |
CN112070205A (zh) | 一种多损失模型获取方法以及装置 | |
CN115631008B (zh) | 商品推荐方法、装置、设备及介质 | |
CN112507912B (zh) | 一种识别违规图片的方法及装置 | |
CN113822293A (zh) | 用于图数据的模型处理方法、装置、设备及存储介质 | |
CN117010480A (zh) | 模型训练方法、装置、设备、存储介质及程序产品 | |
CN114528491A (zh) | 信息处理方法、装置、计算机设备和存储介质 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |