CN112633385A - 一种模型训练的方法、数据生成的方法以及装置 - Google Patents
一种模型训练的方法、数据生成的方法以及装置 Download PDFInfo
- Publication number
- CN112633385A CN112633385A CN202011567739.2A CN202011567739A CN112633385A CN 112633385 A CN112633385 A CN 112633385A CN 202011567739 A CN202011567739 A CN 202011567739A CN 112633385 A CN112633385 A CN 112633385A
- Authority
- CN
- China
- Prior art keywords
- model
- data
- loss value
- target
- generation
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000012549 training Methods 0.000 title claims abstract description 295
- 238000000034 method Methods 0.000 title claims abstract description 147
- 238000013145 classification model Methods 0.000 claims description 144
- 230000015654 memory Effects 0.000 claims description 80
- 238000004590 computer program Methods 0.000 claims description 8
- 238000013473 artificial intelligence Methods 0.000 abstract description 6
- 230000008569 process Effects 0.000 description 36
- 238000004891 communication Methods 0.000 description 33
- 238000012545 processing Methods 0.000 description 26
- 238000010586 diagram Methods 0.000 description 24
- 241000282326 Felis catus Species 0.000 description 13
- 238000013528 artificial neural network Methods 0.000 description 12
- 239000011159 matrix material Substances 0.000 description 12
- 230000006870 function Effects 0.000 description 11
- 230000006978 adaptation Effects 0.000 description 7
- 238000004364 calculation method Methods 0.000 description 6
- 230000003287 optical effect Effects 0.000 description 6
- 230000003044 adaptive effect Effects 0.000 description 5
- 238000013527 convolutional neural network Methods 0.000 description 5
- 230000009466 transformation Effects 0.000 description 5
- 230000003068 static effect Effects 0.000 description 4
- 241000287828 Gallus gallus Species 0.000 description 3
- 230000008878 coupling Effects 0.000 description 3
- 238000010168 coupling process Methods 0.000 description 3
- 238000005859 coupling reaction Methods 0.000 description 3
- 238000013461 design Methods 0.000 description 3
- 230000000694 effects Effects 0.000 description 3
- 238000012360 testing method Methods 0.000 description 3
- 238000012546 transfer Methods 0.000 description 3
- 230000004913 activation Effects 0.000 description 2
- 230000009286 beneficial effect Effects 0.000 description 2
- 239000000872 buffer Substances 0.000 description 2
- 238000006243 chemical reaction Methods 0.000 description 2
- 238000001514 detection method Methods 0.000 description 2
- 238000005516 engineering process Methods 0.000 description 2
- 238000012886 linear function Methods 0.000 description 2
- 238000012423 maintenance Methods 0.000 description 2
- 238000007726 management method Methods 0.000 description 2
- 238000010606 normalization Methods 0.000 description 2
- 238000010428 oil painting Methods 0.000 description 2
- 230000002093 peripheral effect Effects 0.000 description 2
- 230000001172 regenerating effect Effects 0.000 description 2
- 239000000654 additive Substances 0.000 description 1
- 230000000996 additive effect Effects 0.000 description 1
- 230000003042 antagnostic effect Effects 0.000 description 1
- 238000013459 approach Methods 0.000 description 1
- 238000004422 calculation algorithm Methods 0.000 description 1
- 125000004122 cyclic group Chemical group 0.000 description 1
- 238000013500 data storage Methods 0.000 description 1
- 238000011161 development Methods 0.000 description 1
- 239000000835 fiber Substances 0.000 description 1
- 230000003993 interaction Effects 0.000 description 1
- 238000005457 optimization Methods 0.000 description 1
- 238000010422 painting Methods 0.000 description 1
- 230000008447 perception Effects 0.000 description 1
- 238000003672 processing method Methods 0.000 description 1
- 230000000306 recurrent effect Effects 0.000 description 1
- 238000011160 research Methods 0.000 description 1
- 238000005070 sampling Methods 0.000 description 1
- 239000004065 semiconductor Substances 0.000 description 1
- 239000007787 solid Substances 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- Evolutionary Computation (AREA)
- Life Sciences & Earth Sciences (AREA)
- Artificial Intelligence (AREA)
- General Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- General Health & Medical Sciences (AREA)
- Software Systems (AREA)
- Molecular Biology (AREA)
- Computing Systems (AREA)
- Biophysics (AREA)
- Biomedical Technology (AREA)
- Mathematical Physics (AREA)
- Computational Linguistics (AREA)
- Health & Medical Sciences (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Bioinformatics & Computational Biology (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Evolutionary Biology (AREA)
- Information Retrieval, Db Structures And Fs Structures Therefor (AREA)
Abstract
本申请公开了一种模型训练的方法,涉及人工智能领域,包括:根据未携带标签的源域数据、目标域有标签数据对目标模型进行迭代训练,其中,每次迭代训练,包括:将未携带标签的源域数据输入至第一生成模型,以输出第一生成结果。将源域数据未携带的标签和第一生成结果进行组合,以获取第一组合数据。将第一组合数据以及目标域有标签数据输入至第一判别模型中,以输出第一判别结果。根据第一判别结果获取第一损失值。固定第一生成模型的参数,根据第一损失值更新第一判别模型,或者固定第一判别模型的参数,根据第一损失值更新第一生成模型的参数。通过本申请提供的方案,可以根据源域的数据获取大量优质的目标域数据。
Description
技术领域
本申请涉及人工智能技术领域,具体涉及一种模型训练的方法、数据生成的方法以及装置。
背景技术
人工智能(artificial intelligence,AI)是利用数字计算机或者数字计算机控制的机器模拟、延伸和扩展人的智能,感知环境、获取知识并使用知识获得最佳结果的理论、方法、技术及应用系统。换句话说,人工智能是计算机科学的一个分支,它企图了解智能的实质,并生产出一种新的能以人类智能相似的方式做出反应的智能机器。人工智能也就是研究各种智能机器的设计原理与实现方法,使机器具有感知、推理与决策的功能。
近些年来,在人工智能领域,使用领域自适应策略来解决图像分类、检测等任务成为热点。领域自适应是指利用数据、任务或模型之间的相似性,将在源领域学习过的模型,应用于目标领域的一种学习过程。通过源领域的数据生成大量优质的目标域数据,有利于根据获取到的大量优质的目标域数据训练获取性能更好的模型,因此如何根据源领域的数据生成大量优质的目标域数据亟待解决。
发明内容
本申请实施例提供一种模型训练的方法、数据生成的方法以及装置,以通过源领域的数据生成大量优质的目标域数据。
为达到上述目的,本申请实施例提供如下技术方案:
本申请第一方面提供一种模型训练的方法,可以包括:根据未携带标签的源域数据、目标域有标签数据对目标模型进行迭代训练,目标模型可以包括第一生成模型和第一判别模型。第一生成模型和第一判别模型可以是神经网络,例如:残差神经网络(residualneural network,Resnet),比如具体可以是Resnet50。第一生成模型和第一判别模型还可以是其他类型的神经网络,或者第一生成模型和第一判别模型还可以是VGG、AlexNet,或者第一生成模型和第一判别模型还可以是例如全连接网络,全连接网络的层数可以选择(例如但不限于:九层的全连接网络),或者第一生成模型和第一判别模型还可以是卷积神经网络。需要说明的是,第一生成模型和第一判别模型可以是相同类型的神经网络,比如二者都采用Resnet50;第一生成模型和第一判别模型也可以是不同类型的神经网络,比如第一生成模型采用Resnet50,第一判别模型采用AlexNet。本申请中的源域数据可以是图片数据、文本数据或者语音数据,本申请实施例对源域数据的类型并不进行限定。本申请中的目标域数据与源语数据的数据类型相同。其中,每次迭代训练,可以包括:将未携带标签的源域数据输入至第一生成模型,以输出第一生成结果。将源域数据未携带的标签和第一生成结果进行组合,以获取第一组合数据。以图片数据为例进行说明,假设源域数据包括图片1(标签为“猫”)。将图片1(不携带标签“猫”)输入至第一生成模型。针对图片1,第一生成模型生成1个生成图片,假设为生成图片1。将源域数据未携带的标签和第一生成结果进行组合,以获取第一组合数据,可以理解为,将生成图片1和图片1未携带的标签“猫”进行组合,以获取一组组合数据。将第一组合数据以及目标域有标签数据输入至第一判别模型中,以输出第一判别结果。目标域的数据包括有标签的数据和无标签的数据。需要说明的是,未携带标签的源域数据代表源域数据是有标签的,只是输入至第一生成模型中时没有携带,而目标域无标签的数据代表该目标域的数据是没有标签的。根据第一判别结果获取第一损失值。固定第一生成模型的参数,根据第一损失值更新第一判别模型,或者固定第一判别模型的参数,根据第一损失值更新第一生成模型。第一判别模型用于判别输入至第一判别模型的数据是真实的目标域数据还是第一生成器生成的数据。第一判别模型的训练目标为判别第一组合数据为负样本,判别目标域有标签数据为正样本。训练第一判别模型的目的在于,使第一判别模型可以更准确的判别获取到的数据哪些是生成器生成的数据,哪些是目标域有标签的数据。迭代训练的过程可以看做根据第一损失值对第一生成模型和第一判别模型进行对抗训练的训练过程。对抗训练是指第一生成模型和第一判别模型形成一种动态的“博弈过程”,相互对抗,相互促进。可以理解为,第一生成模型的训练是了使第一判别模型无法正确判断接收的数据是来自第一生成模型还是来自真实的目标域数据,第一判别模型的训练是为了准确分辨接收的数据是来自第一生成模型还是来自真实的目标域数据。重复执行上述迭代训练的过程,直至满足预设的停止条件,其中,停止条件包括但不限于满足预设的训练次数、第一生成模型满足收敛条件、第一判别模型无法正确判断接收的数据是来自第一生成模型还是来自真实的目标域数据。本申请提供的一种模型训练的方法,根据未携带标签的源域数据对目标模型进行迭代训练。使目标模型中的第一生成模型只能根据源域数据本身的特征生成第一生成结果。避免在生成第一生成的结果中,产生“标签主导”的问题,忽略了源域数据本身的特征,比如忽略了图片的语义信息。因此通过本申请提供的训练方法获取的模型,可以根据源域的数据获取大量优质的目标域数据。
在一种可能的实施方式中,目标模型还可以包括第二生成模型,该方法还可以包括:将第一生成结果输入至第二生成模型,以输出第二生成结果。根据第二生成结果和未携带标签的源域数据之间的差异获取第二损失值。根据第一损失值更新第一生成模型,可以包括:根据第一损失值和第二损失值更新第一生成模型,第二损失值还用于更新第二生成模型。在这种实施方式中,给出了一种具体的,在进行对抗训练时,如何根据第一损失值对第一生成模型进行更新的方案,具体的根据第一损失值和二损失值更新第一生成模型,使输出的目标数据与输入的源域数据是对应的、有关联的。
在一种可能的实施方式中,根据未携带标签的源域数据、目标域有标签数据对目标模型进行迭代训练,可以包括:根据未携带标签的源域数据、目标域有标签数据以及目标域无标签数据对目标模型进行迭代训练。在本申请提供的方案的一个典型的适用场景中,训练一个模型,比如训练一个用于执行分类任务的模型,一般需要大量有标注的样本。而对于某一些领域,比如医疗领域、工业视觉领域,获取有标注的样本并不容易,具体可能表现在成本高昂,费时费力。通过本申请提供的方案,可以根据目标域少量有标注的样本,目标域大量无标注的样本、以及源域大量有标注的样本,获取目标域大量有标注的样本,降低成本。在本申请提供的方案的另一个典型的适用场景中,对于要从其他企业采购训练数据的企业,为了保证训练数据供应的稳定性,该企业可能会从多个其他企业采购训练数据。来自不同企业的训练数据可以分别看做一个数据域,来自不同企业的训练数据可能并不满足独立同分布,即使从同一个企业采购的训练数据,该训练数据也可能是该企业不同代设备采集的数据,比如有些数据是第一代产品采集的数据,有些数据是第二代产品采集的数据,进而导致采集的数据不满足独立同分布。而训练一个模型的训练数据要求满足独立同分布,此外,训练数据如果和测试数据的分布差异较大,模型的性能往往也会显著下降。本申请提供的可以很好的解决这一问题。将从其他企业采购的训练数据看做有标注的目标域数据,通过本申请提供的方案,无需从其他企业采购大量的有标注的目标域数据,只需要采购少量的有标注的目标域数据,根据本申请提供的方案可以根据该少量的有标注的目标域数据,以及一些无标注的目标域数据,以及源域数据生成大量的有标注的目标域数据,进而节省了成本,还使得生成的大量的有标注的目标域数据均满足独立同分布,提升模型训练的效率。
在一种可能的实施方式中,目标模型还可以包括第二判别模型,每次迭代训练还可以包括:将第一生成结果和预设标签进行组合,以获取第二组合数据。将目标域无标签数据和预设标签进行组合,以获取第三组合数据。将第二组合数据以及第三组合数据输入至第二判别模型中,以输出第二判别结果。根据第二判别结果获取第三损失值。根据第一损失值更新第一判别模型,可以包括:根据第一损失值和第三损失值更新第一判别模型。根据第一损失值更新第一生成模型,可以包括:根据第一损失值和第三损失值更新第一生成模型。由于目标域有标签的数据的数目过少,在训练过程中,能够提供给判别模型的信息有限,不利于模型的训练。在这种实施方式中,在训练过程中利用了目标域无标签数据,将目标域无标签数据和预设标签进行组合,可以获取大量的目标域有标签的数据,进而可以获取大量真实的目标域有标签数据。
在一种可能的实施方式中,目标模型还可以包括第一分类模型,每次迭代训练还可以包括:将目标域无标签数据输入至第一分类模型中,以输出第一预测标签。将目标域无标签数据和第一预测标签进行组合,以获取第四组合数据。将第四组合数据输入至第一判别模型中,以输出第三判别结果。根据第三判别结果获取第四损失值。根据第一损失值更新第一判别模型,可以包括:固定第一生成模型的参数和第一分类模型的参数,根据第一损失值和第四损失值更新第一判别模型。每次迭代训练还可以包括:固定第一判别模型的参数,根据第四损失值更新第一分类模型。在这种实施方式中,还可以在目标网络中引入分类器,以根据目标域无标签数据获取更多的训练数据。
在一种可能的实施方式中,目标模型还可以包括第二分类模型,每次迭代训练还可以包括:将第一生成结果和目标域有标签数据输入至第二分类模型中,以输出第二预测标签。根据第二预测标签获取第五损失值。根据第四损失值更新第一分类模型,可以包括:根据第四损失值和第五损失值更新第一分类模型,第四损失值和第五损失值还用于更新第二分类模型。在这种实施方式中,通过增加用于训练第一分类模型的训练样本,提升第一分类模型的训练效率,进而提升训练后的目标模型的性能。
本申请第二方面提供一种数据生成的方法,可以包括:获取未携带标签的源域数据。将未携带标签的源域数据输入至目标生成模型中,以获取目标域数据。其中,目标生成模型是通过未携带标签的源域训练数据、目标域有标签训练数据对目标模型进行迭代训练后获取的,目标模型可以包括第一生成模型和第一判别模型,目标生成模型是训练后的第一生成模型,目标生成模型的参数是通过固定第一判别模型的参数,通过第一损失值更新第一生成模型的参数获取的,第一损失值还用于固定第一生成模型的参数时,更新第一判别模型,第一损失值是根据第一判别结果获取的,第一判别结果是根据第一组合训练数据以及目标域有标签训练数据输入至第一判别模型中获取的,第一组合数据是将源域训练数据未携带的标签和第一生成结果进行组合后获取的,第一生成结果是将未携带标签的源域训练数据输入至第一生成模型后获取的。根据本申请提供的一种数据生成方法,可以根据源域数据获取大量的、优质的目标领域的数据。此外,还可以根据获取到的目标领域的数据训练新的模型,比如通过获取到的目标领域的数据训练新的分类模型。
在一种可能的实施方式中,目标模型还可以包括第二生成模型,目标生成模型的参数具体是通过固定第一判别模型的参数,通过第一损失值和第二损失值更新第一生成模型的参数获取的,第二损失值是根据第二生成结果和未携带标签的源域训练数据之间的差异获取的,第二生成结果是将第一生成结果输入至第二生成模型获取的,第二损失值还用于更新第二生成模型。
在一种可能的实施方式中,目标生成模型具体是通过未携带标签的源域训练数据、目标域有标签训练数据以及目标域无标签训练数据对目标模型进行迭代训练后获取的。
在一种可能的实施方式中,目标模型还可以包括第二判别模型,目标生成模型的参数具体是通过固定第一判别模型的参数,第一损失值和第三损失值更新第一生成模型的参数获取的,第三损失值是通过第二判别结果获取的,第二判别结果是将第二组合数据以及第三组合数据输入至第二判别模型中获取的,第二组合数据是将第一生成结果和预设标签进行组合后获取的,第三组合数据是将标域无标签训练数据和预设标签进行组合后获取的,第一损失值和第三损失值还用于更新第一判别模型。
在一种可能的实施方式中,还可以包括:将目标域训练数据输入至目标分类模型中,以获取预测结果,其中,目标分类模型是固定第一判别模型的参数,通过第四损失值更新第一分类模型获取的,第四损失值是通过第三判别结果获取的,第三判别结果是通过将第四组合数据输入至第一判别模型中获取的,第四组合数据是将目标域无标签训练数据和第一预测标签进行组合后获取的,第一预测标签是将目标域无标签训练数据输入至第一分类模型中获取的,第四损失值还用于固定第一生成模型的参数和第一分类模型的参数时,更新第一判别模型。
在一种可能的实施方式中,目标模型还可以包括第二分类模型,目标分类模型具体是固定第一判别模型的参数,通过第四损失值和第五损失值更新第一分类模型获取的,第五损失值是通过第二预测标签获取的,第二预测标签是通过将第一生成结果和目标域有标签训练数据输入至第二分类模型中获取的,第四损失值和第五损失值还用于更新第二分类模型。
本申请第三方面提供一种模型训练的装置,可以包括:训练模块,用于根据未携带标签的源域数据、目标域有标签数据对目标模型进行迭代训练,目标模型可以包括第一生成模型和第一判别模型,训练模块可以包括输入模块、组合模块、获取模块以及更新模块,其中,每次迭代训练时,输入模块,用于将未携带标签的源域数据输入至第一生成模型,以输出第一生成结果。组合模块,用于将源域数据未携带的标签和第一生成结果进行组合,以获取第一组合数据。输入模块,还用于将第一组合数据以及目标域有标签数据输入至第一判别模型中,以输出第一判别结果。获取模块,用于根据第一判别结果获取第一损失值。更新模块,用于固定第一生成模型的参数,根据第一损失值更新第一判别模型,或者固定第一判别模型的参数,根据第一损失值更新第一生成模型。
在一种可能的实施方式中,目标模型还可以包括第二生成模型,输入模块,还用于将第一生成结果输入至第二生成模型,以输出第二生成结果。获取模块,还用于根据第二生成结果和未携带标签的源域数据之间的差异获取第二损失值。更新模块,具体用于根据第一损失值和第二损失值更新第一生成模型,第二损失值还用于更新第二生成模型。
在一种可能的实施方式中,训练模块,具体用于:根据未携带标签的源域数据、目标域有标签数据以及目标域无标签数据对目标模型进行迭代训练。
在一种可能的实施方式中,目标模型还可以包括第二判别模型,组合模块,还用于将第一生成结果和预设标签进行组合,以获取第二组合数据,将目标域无标签数据和预设标签进行组合,以获取第三组合数据。输入模块,还用于将第二组合数据以及第三组合数据输入至第二判别模型中,以输出第二判别结果。获取模块,还用于根据第二判别结果获取第三损失值。更新模块,具体用于根据第一损失值和第三损失值更新第一判别模型。更新模块,具体用于根据第一损失值和第三损失值更新第一生成模型。
在一种可能的实施方式中,目标模型还可以包括第一分类模型,输入模块,还用于将目标域无标签数据输入至第一分类模型中,以输出第一预测标签。组合模块,还用于将目标域无标签数据和第一预测标签进行组合,以获取第四组合数据。输入模块,还用于将第四组合数据输入至第一判别模型中,以输出第三判别结果。获取模块,还用于根据第三判别结果获取第四损失值。更新模块,具体用于固定第一生成模型的参数和第一分类模型的参数,根据第一损失值和第四损失值更新第一判别模型。更新模块,还用于固定第一判别模型的参数,根据第四损失值更新第一分类模型。
在一种可能的实施方式中,目标模型还可以包括第二分类模型,输入模块,还用于将第一生成结果和目标域有标签数据输入至第二分类模型中,以输出第二预测标签。获取模块,还用于根据第二预测标签获取第五损失值。更新模块,具体用于根据第四损失值和第五损失值更新第一分类模型,第四损失值和第五损失值还用于更新第二分类模型。
本申请第四方面提供一种数据生成的装置,可以包括:获取模块,用于获取未携带标签的源域数据。生成模块,用于将未携带标签的源域数据输入至目标生成模型中,以获取目标域数据。其中,目标生成模型是通过未携带标签的源域训练数据、目标域有标签训练数据对目标模型进行迭代训练后获取的,目标模型可以包括第一生成模型和第一判别模型,目标生成模型是训练后的第一生成模型,目标生成模型的参数是通过固定第一判别模型的参数,通过第一损失值更新第一生成模型的参数获取的,第一损失值还用于固定第一生成模型的参数时,更新第一判别模型,第一损失值是根据第一判别结果获取的,第一判别结果是根据第一组合训练数据以及目标域有标签训练数据输入至第一判别模型中获取的,第一组合数据是将源域训练数据未携带的标签和第一生成结果进行组合后获取的,第一生成结果是将未携带标签的源域训练数据输入至第一生成模型后获取的。
在一种可能的实施方式中,目标模型还可以包括第二生成模型,目标生成模型的参数具体是通过固定第一判别模型的参数,通过第一损失值和第二损失值更新第一生成模型的参数获取的,第二损失值是根据第二生成结果和未携带标签的源域训练数据之间的差异获取的,第二生成结果是将第一生成结果输入至第二生成模型获取的,第二损失值还用于更新第二生成模型。
在一种可能的实施方式中,目标生成模型具体是通过未携带标签的源域训练数据、目标域有标签训练数据以及目标域无标签训练数据对目标模型进行迭代训练后获取的。
在一种可能的实施方式中,目标模型还可以包括第二判别模型,目标生成模型的参数具体是通过固定第一判别模型的参数,第一损失值和第三损失值更新第一生成模型的参数获取的,第三损失值是通过第二判别结果获取的,第二判别结果是将第二组合数据以及第三组合数据输入至第二判别模型中获取的,第二组合数据是将第一生成结果和预设标签进行组合后获取的,第三组合数据是将标域无标签训练数据和预设标签进行组合后获取的,第一损失值和第三损失值还用于更新第一判别模型。
在一种可能的实施方式中,还可以包括分类模块,分类模块,用于将目标域训练数据输入至目标分类模型中,以获取预测结果,其中,目标分类模型是固定第一判别模型的参数,通过第四损失值更新第一分类模型获取的,第四损失值是通过第三判别结果获取的,第三判别结果是通过将第四组合数据输入至第一判别模型中获取的,第四组合数据是将目标域无标签训练数据和第一预测标签进行组合后获取的,第一预测标签是将目标域无标签训练数据输入至第一分类模型中获取的,第四损失值还用于固定第一生成模型的参数和第一分类模型的参数时,更新第一判别模型。
在一种可能的实施方式中,目标模型还可以包括第二分类模型,目标分类模型具体是固定第一判别模型的参数,通过第四损失值和第五损失值更新第一分类模型获取的,第五损失值是通过第二预测标签获取的,第二预测标签是通过将第一生成结果和目标域有标签训练数据输入至第二分类模型中获取的,第四损失值和第五损失值还用于更新第二分类模型。
本申请第五方面提供一种模型训练的装置,可以包括:存储器,用于存储计算机可读指令。还可以包括,与存储器耦合的处理器,用于执行存储器中的计算机可读指令从而执行如权利要求1至6任一项所描述的方法。
本申请第六方面提供一种数据生成的装置,可以包括:存储器,用于存储计算机可读指令。还可以包括,与存储器耦合的处理器,用于执行存储器中的计算机可读指令从而执行如权利要求7至12任一项所描述的方法。
本申请第七方面提供一种计算机可读存储介质,当指令在计算机装置上运行时,使得计算机装置执行如第一方面所描述的方法。
本申请第八方面提供一种计算机可读存储介质,当指令在计算机装置上运行时,使得计算机装置执行如第二方面所描述的方法。
本申请第九方面提供一种计算机程序产品,当在计算机上运行时,使得计算机可以执行如权第一方面所描述的方法。
本申请第十方面提供一种计算机程序产品,当在计算机上运行时,使得计算机可以执行如第二方面所描述的方法。
本申请第十一方面提供一种芯片,芯片与存储器耦合,用于执行存储器中存储的程序,以执行如第一方面所描述的方法。
本申请第十二方面提供一种芯片,芯片与存储器耦合,用于执行存储器中存储的程序,以执行如第二方面所描述的方法。
本申请提供的方案,使目标模型中的第一生成模型只能根据源域数据本身的特征生成第一生成结果。避免在生成第一生成的结果中,产生“标签主导”的问题,忽略了源域数据本身的特征,比如忽略了图片的语义信息。此外,本申请提供了多种训练第一生成模型的方案,包括根据源域数据、目标域有标签数据以及目标域无标签数据训练模型,本申请提供的方案提供更多的信息用于模型训练,提升训练后的模型的性能,解决了现有的方案,在训练过程中,只考虑了源域有标签的数据和目标域有标签的数据,对目标域有标签的数据存在很深的依赖的问题。因此通过本申请提供的训练方法获取的模型,可以根据源域的数据获取大量优质的目标域数据。
附图说明
图1为一种生成式对抗网络的架构示意图;
图2为一种环形生成对抗网络的结构示意图;
图3一种低资源域自适应的增强环形对抗网络的架构示意图;
图4为本申请实施例提供的一种模型训练的方法的流程示意图;
图5为本申请实施例提供的一种目标模型的架构示意图;
图6为本申请实施例提供的一种模型训练的方法的流程示意图;
图7为本申请实施例提供的另一种目标模型的架构示意图;
图8为本申请实施例提供的一种模型训练的方法的流程示意图;
图9为本申请实施例提供的另一种目标模型的架构示意图;
图10为本申请实施例提供的一种模型训练的方法的流程示意图;
图11为本申请实施例提供的另一种目标模型的架构示意图;
图12为本申请实施例提供的一种模型训练的方法的流程示意图;
图13为本申请实施例提供的另一种目标模型的架构示意图;
图14为本申请实施例提供的一种数据生成方法的流程示意图;
图15为本申请实施例提供的一种数据生成模型的架构示意图;
图16为本申请实施例提供的一种计算机设备的结构示意图;
图17为本申请实施例提供的另一种计算机设备的结构示意图;
图18为本申请实施例提供的芯片的一种结构示意图。
具体实施方式
下面结合附图,对本申请的实施例进行描述,显然,所描述的实施例仅仅是本申请一部分的实施例,而不是全部的实施例。本领域普通技术人员可知,随着技术的发展和新场景的出现,本申请实施例提供的技术方案对于类似的技术问题,同样适用。
本申请提供一种模型训练的方法、数据生成的方法以及设备。通过本申请提供的一种模型训练的方法得到的模型,可以根据源领域的数据获取大量优质的目标域数据。
由于本申请涉及大量领域自适应相关的知识,为了更好的理解本申请提供的方案,下面对本申请涉及的与领域自适应相关的术语进行介绍。
1、领域自适应
领域自适应是指利用数据、任务或模型之间的相似性,将在源领域(以下简称为源域)学习过的模型,应用于目标领域(以下简称为目标域)的一种学习过程。在本申请中,可以将领域视为数据域,可以将源域看作是能够提供大量有标签数据的域,将目标域看作是不能够提供大量有标签数据的域,即目标域提供的少量有标签的数据不足以训练好一个模型。需要说明的是,源域和目标域需要有一定相似性,但同时存在差异,导致运用源域数据训练的模型不能直接用于目标域预测。其中,相似性可以表现在源域和目标域的标签集合是相同的,差异可以表现在数据的风格不同,比如源域是真实拍摄的图片数据,目标域是油画风格的图片数据;再比如,差异可以表现在数据的来源不同,比如源域是来自A设备采集的数据,目标域是来自B设备采集的设备,A设备和B设备可能是不同厂家的设备,或者是同一个厂家不同型号的设备;再比如在电信运维网络中,电信运维过程中会产生故障检测数据,则来自于不同地区的数据可以分别看做一个域。
2、生成式对抗网络(generative adversarial networks,GAN)
通过对抗训练的方式来解决领域自适应问题是目前最常用的方法。参阅图1,为一种生成式对抗网络的架构示意图。如图1所示,生成式对抗网络通常包括两个部分,一个是生成模型(本申请有时也称之为生成器),一个是判别模型(本申请有时也称之为判别器)。通过这两个模型互相对抗训练,从而产生更好的输出。其中,生成模型捕捉真实训练样本的潜在分布,并生成新的样本。判别模型是一个二分类器,用于判别输入样本是真实样本还是生成样本。通过迭代优化生成模型和判别模型,当判别模型无法正确判别输入样本的数据来源时,可以认为这个生成模型已经学到了真实训练数据的分布。
3、环形生成对抗网络(cycle generative adversarial networks,CycleGAN)
对于生成式对抗网络GAN,训练后的生成器输出什么样的数据是无法预知的,生成的数据可能和输入的数据是毫无关联的,为了能够使输出的数据与输入的数据是对应的、有关联的,环形生成对抗网络CycleGAN应运而生。参阅图2,为一种环形生成对抗网络的结构示意图。如图2所示,环形生成对抗网络CycleGAN通常包括两个生成器和两个判别器。如图2所示,将源域数据作为训练数据输入至生成模型1中,生成模型1针对获取的源域数据,输出生成数据1,这一过程可以理解为源域数据通过生成模型1生成目标域数据。将生成数据1输入至生成模型2中,以输出生成数据2,这一过程可以理解为将生成的目标域数据重新生成源域数据。根据生成数据2(生成的源域数据)和作为训练数据输入至生成模型1的源域数据之间的差异确定损失值,根据该损失值更新生成器1和生成器2,以使输出的数据与输入的数据是对应的、有关联的。此外,判别器1用于判断是生成的目标域数据还是真实的目标域数据,判别器2用于判断是生成的源域数据还是真实的源域数据。在训练过程中,初始化生成器和判别器的参数后,每一次迭代过程中,首先固定生成器(生成器1和生成器2)的参数,只更新判别器(判别器1和判别器2的参数)的参数。判别器的训练目标是,如果输入是来自于真实域数据(真实源域数据、真实的目标域数据),则给高分;如果是生成器产生的数据(生成数据1和生成数据2),则给低分。接下来,固定住判别器的参数,更新生成器,这一阶段判别器的参数已经固定住了,生成器(生成器1和生成器2)需要调整自己的参数,使得生成器输出的生成数据,被判别器判别的分数越大越好。
4、低资源域自适应的增强环形对抗网络(augmented cyclic adversariallearning for low resource domain adaptation,ACAL)
虽然CycleGAN可以在一定程度上解决输出的数据与输入的数据是对应的,有关联的,然而,CycleGAN无法保证生成器生成的数据和输入至生成器的数据的标签是一致的。比如当输入至生成器1的源域数据是真实拍摄风格的猫,生成器1生成的目标域数据可能是油画风格的狗,经过生成器2之后,又重新生成真实拍摄风格的猫。在这种情况下,生成器的输出结果已经可以迷惑判别器,即判别器1认为生成器1输出的数据是真实的目标域数据,判别器2会认为生成器2输出的数据是真实的源域数据。这是因为,对于CycleGAN,在训练过程中没有考虑标签的影响。ACAL是基于CycleGAN的一种方案,可以解决CycleGAN无法保证生成器生成的数据和输入至生成器的数据的标签不一致的这一问题。参照图3,为一种低资源域自适应的增强环形对抗网络的架构示意图。如图3所示,低资源域自适应的增强环形对抗网络ACAL在CycleGAN的基础上,加入了两个分类器。其中,分类器1用于获取生成器1生成的目标域数据的类别,分类器2用于获取生成器2生成的源域数据的类别,通过生成的目标域数据的类别和生成的源域数据的类别之间的差异获取损失值,通过损失值更新分类器和生成器。
申请人发现,目前已有的解决领域自适应问题的方案至少存在以下两个方面的问题:
1)现有的方案存在“标签主导”的问题。以ACAL为例进行说明,输入至生成器1的源域数据要包括标签,比如源域数据是图片数据,则输入至生成器1的数据为“图片+标签”的组合形式。这会导致生成器1在生成目标域数据的时候,会以标签为主导,仅生成对应标签的目标域数据,而忽略了源域数据的语义信息。具体的,对于具有同一标签的大量源域数据,生成器1只需要生成少量与该大量源域数据有相同标签的目标域数据,就可以迷惑判别器。比如,源域数据的风格为真实拍摄风格,目标域数据的风格为油画风格,输入至生成器1的数据包括5种不同种类的真实拍摄风格的猫,5种不同种类的真实拍摄风格的猫的标签都为“猫”。则针对这5种不同种类的真实拍摄风格的猫,生成器1可能只会生成一种油画风格的猫。申请人发现了这一问题:现有的解决领域自适应的问题的方案,无法根据源域数据生成大量的目标域数据。
2)现有的方案,在训练过程中,只考虑了源域有标签的数据和目标域有标签的数据,对目标域有标签的数据存在很深的依赖。然而在实际情况中,目标域有标签的数据是少量的,目标域无标签的数据才是大量存在的,如何在训练过程中,利用目标域无标签的数据,提供更多的信息用于模型训练,提升训练后的模型的性能,尚未解决。
为了解决上述问题,本申请实施例提供一种模型训练的方法以及一种数据生成的方法,通过本申请提供的方案,解决现有技术中存在的“标签主导”的问题,根据源域数据、目标域有标签数据以及目标域无标签数据训练模型,通过本申请提供的训练方法获取的模型,可以根据源域的数据获取大量优质的目标域数据。
一、训练模型
参阅图4,为本申请实施例提供的一种模型训练的方法的流程示意图。
如图4所示,本申请提供的一种模型训练的方法,可以包括以下步骤:
401、将未携带标签的源域数据输入至第一生成模型,以输出第一生成结果。
图4所示的方法可以适用在图5所示的目标模型中。参阅图5,为本申请实施例提供的一种目标模型的架构示意图。如图5所示,本申请实施例提供的一种目标模型包括一个生成模型(第一生成模型)和一个判别模型(第一判别模型)。其中,生成模型包括一个输入,一个输出,判别模型包括两个输入和一个输出。在图4所描述的方法中,生成模型的输入为未携带标签的源域数据。生成模型的输出为第一生成结果。第一判别模型的一个输入是第一组合数据,另一个输入是目标域有标签数据,参照步骤403进行理解。本申请中的源域数据可以是图片数据、文本数据或者语音数据,本申请实施例对源域数据的类型并不进行限定。本申请中的目标域数据与源语数据的数据类型相同。上文介绍到可以将源域看作是能够提供大量有标签数据的域,以源域数据是图片数据为例进行说明,假设源域数据包括第一图片数据和第二图片数据,其中第一图片数据包括第一标签,第二图片数据包括第二标签。将未携带标签的源域数据输入至第一生成模型,可以理解为将第一图片数据(不携带第一标签)和第二图片数据(不携带第二标签)输入至第一生成模型。由于输入至第一生成模型的源域数据未携带标签,所以第一生成模型无法根据源域数据的标签生成第一生成结果,只能根据源域数据的特征,比如根据图片的语义信息生成第一生成结果。通过这样的输入设计,使第一生成模型只能根据源域数据本身的特征生成第一生成结果,避免在生成第一生成的结果中,产生“标签主导”的问题,忽略了源域数据本身的特征,比如忽略了图片的语义信息。此外,可以将未携带标签的源域数据分批次输入至第一生成模型中。比如,每次输入预设数目的未携带标签的源域数据至第一生成模型中。
在一个可能的实施方式中,第一生成模型可采用GAN、CycleGAN等已有的对抗网络中的生成器。
在一个可能的实施方式中,第一生成模型可以是深度神经网络(deep neuralnetworks,DNN)。
此外,本申请有时也将标签称为标注,二者表示相同的意思。
402、将源域数据未携带的标签和第一生成结果进行组合,以获取第一组合数据。
以图片数据为例进行说明,假设源域数据包括图片1(标签为“猫”)、图片2(标签为“狗”)、图片3(标签为“鸡”)。通过执行步骤401,将图片1(不携带标签“猫”)、图片2(不携带标签“狗”)、图片3(不携带标签“鸡”)输入至第一生成模型。针对图片1、图片2以及图片3,第一生成模型分别生成3个生成图片,假设分别为生成图片1、生成图片2以及生成图片3。将源域数据未携带的标签和第一生成结果进行组合,以获取第一组合数据,可以理解为,将生成图片1和图片1未携带的标签“猫”进行组合,以获取一组组合数据,将生成图片2和图片2未携带的标签“狗”进行组合,以获取另一组组合数据,将生成图片3和图片3未携带的标签“鸡”进行组合,以获取另一组组合数据。以上,一共获取了三组组合数据。
403、将第一组合数据以及目标域有标签数据输入至第一判别模型中,以输出第一判别结果。
目标域的数据包括有标签的数据和无标签的数据。需要说明的是,未携带标签的源域数据代表源域数据是有标签的,只是输入至第一生成模型中时没有携带,而目标域无标签的数据代表该目标域的数据是没有标签的。第一判别模型用于判别输入至第一判别模型的数据是真实的目标域数据还是第一生成器生成的数据。第一判别模型的训练目标为判别第一组合数据为负样本,判别目标域有标签数据为正样本。其中,第一判别模型判别为负样本,则第一判别模型可以输出“0”或者“false”,第一判别模型判别为正样本,则第一判别模型可以输出“1”或者“ture”。训练第一判别模型的目的在于,使第一判别模型可以更准确的判别获取到的数据哪些是生成器生成的数据,哪些是目标域有标签的数据。
在一个可能的实施方式中,第一判别模型可采用GAN、CycleGAN等已有的对抗网络中的判别器。
在一个可能的实施方式中,第一判别模型可以是深度神经网络(deep neuralnetworks,DNN)。
404、根据第一判别结果获取第一损失值。
根据第一判别结果和第一判别模型的训练目标之间的差异获取第一损失值。
在一个可能的实施方式中,将第一组合数据输入至第一判别模型中,以输出第一判别结果时,第一损失值可以参照公式1-1进行理解:
其中,E代表期望,x代表未携带标签的数据,y代表x未携带的标签,ps(x,y)表示源域未携带标签的数据和未携带的标签之间的概率联合分布,Dt()表示第一判别器的输入,GS→T()表示第一生成模型的输入。
在一个可能的实施方式中,将目标域有标签数据输入至第一判别模型中,以输出第一判别结果时,第一损失值可以参照公式1-2进行理解:
405、根据第一损失值对第一生成模型和第一判别模型进行对抗训练。
对抗训练是指第一生成模型和第一判别模型形成一种动态的“博弈过程”,相互对抗,相互促进。可以理解为,第一生成模型的训练是了使第一判别模型无法正确判断接收的数据是来自第一生成模型还是来自真实的目标域数据,第一判别模型的训练是为了准确分辨接收的数据是来自第一生成模型还是来自真实的目标域数据。
重复执行上述步骤401至步骤405,直至满足预设的停止条件,其中,停止条件包括但不限于满足预设的训练次数、第一生成模型满足收敛条件、第一判别模型无法正确判断接收的数据是来自第一生成模型还是来自真实的目标域数据。
在一次迭代训练的过程中,可以先固定第一生成模型的参数,即先使第一生成模型的参数保持不变,根据第一损失值更新第一判别模型,以得到第一更新判别模型(更新后的第一判别模型)。或者在一次迭代训练中,可以先固定第一判别模型的参数,根据第一损失值更新第一生成模型,以得到第一更新生成模型(更新后的第一生成模型)。交替执行更新第一判别模型和更新第一生成模型的步骤。由于第一生成模型要对抗第一判别模型训练,具体的根据第一变换损失值L1’更新第一生成模型,可以参照公式1-3进行理解,第一变换损失值L1’是第一损失值的相反数:
对上面的迭代训练举例说明,将未携带标签的源域数据(比如第一源域数据)输入至第一生成模型中,通过执行步骤401至步骤404获取了第一损失值(以下称为第一损失值a)。固定第一生成模型的参数,根据第一损失值a更新第一判别模型(以下将此处更新后得到的第一判别模型称为第一判别模型a)。固定第一判别模型a,将未携带标签的源域数据(比如第二源域数据)输入至第一生成模型中,通过执行步骤401至步骤404获取了另一个第一损失值(以下称为第一损失值b),根据第一损失值获取了第一变换损失值b’。根据第一变换损失值b’更新第一生成模型(以下将此处更新后得到的第一生成模型称为第一生成模型b)。将未携带标签的源域数据(比如第三源域数据)输入至第一生成模型中,通过执行步骤401至步骤404获取了另一个第一损失值(以下称为第一损失值c)。固定第一生成模型b的参数,根据第一损失值c更新第一判别模型a(以下将此处更新后得到的第一判别模型称为第一判别模型b)。固定第一判别模型b,将未携带标签的源域数据(比如第四源域数据)输入至第一生成模型中,通过执行步骤401至步骤404获取了另一个第一损失值(以下称为第一损失值d),根据第一损失值d获取了第一变换损失值d’,根据第一变换损失值d’更新第一生成模型b,依次类推,每次输入新的未携带标签的源域数据至第一生成模型中后,根据获取到的第一损失值交替更新第一判别模型和第一生成模型,直至满足预设的停止条件。需要说明的是,上述例子中,每次输入新的未携带标签的源域数据至第一生成模型中后,根据获取到的第一损失值交替更新第一判别模型和第一生成模型。在一些可能的实施方式中,可以每多次输入新的未携带标签的源域数据至第一生成模型中后,根据获取到的第一损失值交替更新第一判别模型和第一生成模型。比如,第一生成模型接收到第一源域数据至第三源域数据时,都是固定第一生成模型的参数,更新第一判别模型,接收到第四源域数据时,固定更新后的第一判别模型,更新第一生成模型。
由图4对应的实施例可知,本申请提供的一种模型训练的方法,根据未携带标签的源域数据对目标模型进行迭代训练。图4对应的实施例使目标模型中的第一生成模型只能根据源域数据本身的特征生成第一生成结果。避免在生成第一生成的结果中,产生“标签主导”的问题,忽略了源域数据本身的特征,比如忽略了图片的语义信息。因此通过本申请提供的训练方法获取的模型,可以根据源域的数据获取大量优质的目标域数据。
参阅图6,为本申请实施例提供的一种模型训练的方法的流程示意图。
如图6所示,本申请提供的一种模型训练的方法,可以包括以下步骤:
601、将未携带标签的源域数据输入至第一生成模型,以输出第一生成结果。
602、将源域数据未携带的标签和第一生成结果进行组合,以获取第一组合数据。
603、将第一组合数据以及目标域有标签数据输入至第一判别模型中,以输出第一判别结果。
604、根据第一判别结果获取第一损失值。
步骤601至步骤604可以参照图4对应的实施例中的步骤401至步骤404进行理解,这里不再重复赘述。
605、将第一生成结果输入至第二生成模型,以输出第二生成结果。
图6对应的实施例可以适用在图7对应的目标模型中,参阅图7,为本申请实施例提供的另一种目标模型的架构示意图。如图7所示,本申请实施例提供的一种目标模型包括两个生成模型和一个判别模型。将未携带标签的源域数据作为训练数据输入至第一生成模型中,第一生成模型针对获取的未携带标签的源域数据,输出生成数据1,这一过程可以理解为将未携带标签的源域数据通过第一生成模型生成目标域数据。将生成数据1输入至第二生成模型中,以输出生成数据2,这一过程可以理解为将生成的目标域数据重新生成未携带标签的源域数据。
606、根据第二生成结果和未携带标签的源域数据之间的差异获取第二损失值。
根据生成数据2(生成的未携带标签的源域数据)和作为训练数据输入至第一生成模型的未携带标签的源域数据之间的差异确定损失值,根据该损失值更新第一生成模型和第二生成模型。示例性的,在一个可能的实施方式中,可以参照公式1-4进行理解:
其中,E代表期望,x代表未携带标签的数据,y代表x未携带的标签,ps(x,y)表示源域的数据和标签之间的概率联合分布,GT→S()()表示第二生成模型的输入,pt(x,y)表示目标域的数据和标签之间的概率联合分布,GS→T()表示第一生成模型的输入。
607、固定第一判别模型的参数,根据第一损失值更新第一生成模型的参数。
固定第一判别模型的参数,根据第一损失值和第二损失值更新第一生成模型的参数。具体的,根据第一变换损失值和第二损失值更新第一生成模型的参数。
在对目标模型进行迭代训练的过程中,如果当前迭代训练是固定第一判别模型的参数,对第一生成模型的参数进行更新,则根据第一损失值和第二损失值更新第一生成模型的参数。其中,第二损失值还用于更新第二生成模型。在一个可能的实施方式中,根据第一损失值和第二损失值更新第一生成模型的参数可以理解为对第一损失值的相反数和第二损失值进行加权处理,根据加权处理后获取的损失值更新第一生成模型的参数和第二生成模型的参数。
图6对应的实施例,给出了一种具体的,在进行对抗训练时,如何根据第一损失值对第一生成模型进行更新的方案,下面对进行对抗训练时,如何根据第一损失值对第一判别模型进行更新进行说明。
参阅图8,为本申请实施例提供的一种模型训练的方法的流程示意图。
如图8所示,本申请提供的一种模型训练的方法,可以包括以下步骤:
801、将未携带标签的源域数据输入至第一生成模型,以输出第一生成结果。
802、将源域数据未携带的标签和第一生成结果进行组合,以获取第一组合数据。
803、将第一组合数据以及目标域有标签数据输入至第一判别模型中,以输出第一判别结果。
804、根据第一判别结果获取第一损失值。
步骤801至步骤804可以参照图4对应的实施例中的步骤401至步骤404进行理解,这里不再重复赘述。
805、将第一生成结果和预设标签进行组合,以获取第二组合数据。
为了增加训练样本的数量,获取更多的训练样本,训练样本数目的提升有助于提升训练获取的模型的性能。假设预设标签可以为“unknown label”,第一生成结果包括生成图片1,生成图片2,生成图片3。则将生成图片1和“unknown label”进行组合,获取一组第二组合数据,则将生成图片2和“unknown label”进行组合,获取另一组第二组合数据,则将生成图片3和“unknown label”进行组合,获取另一组第二组合数据。
806、将目标域无标签数据和预设标签进行组合,以获取第三组合数据。
由于目标域有标签的数据的数目过少,在训练过程中,能够提供给判别模型的信息有限,不利于模型的训练。本申请提供的方案在训练过程中利用了目标域无标签数据,将目标域无标签数据和预设标签进行组合,可以获取大量的目标域有标签的数据,进而可以获取大量真实的目标域有标签数据。通过在训练过程中,增加大量真实的目标域有标签数据作为训练样本,提升训练模型的效率,使训练后的模型性能更优,具体的,使训练后的判别模型和生成模型的性能更优。将目标域无标签数据和预设标签进行组合,以获取第三组合数据与获取第二组合数据的过程相似,比如,目标域无标签数据包括数据1、数据2、数据3,则将数据1和“unknown label”进行组合,获取一组第三组合数据,则将数据2和“unknownlabel”进行组合,获取另一组第三组合数据,则将数据3和“unknown label”进行组合,获取另一组第三组合数据。
807、将第二组合数据以及第三组合数据输入至第二判别模型中,以输出第二判别结果。
图8对应的实施例可以适用在图9所示的目标模型中,参阅图9,为本申请实施例提供的另一种目标模型的架构示意图。如图9所示,本申请实施例提供的一种目标模型包括至少一个生成模型和两个判别模型。其中,包括一个生成模型时,可以参照图5所示的一种目标模型中的第一生成模型进行理解,包括两个生成模型时,可以参照图7所示的一种目标模型中的第一生成模型和第二生成模型进行理解,这里不再重复赘述。第二判别模型用于判别输入至第二判别模型的数据是真实的目标域数据还是生成器生成的数据,其中真实的目标域数据包括目标域无标签数据和预设标签进行组合后获取的组合数据。第二判别模型的训练目标为判别第二组合数据为负样本,判别第三组合数据为正样本。其中,第二判别模型判别为负样本,则第二判别模型可以输出“0”或者“false”,第二判别模型判别为正样本,则第二判别模型可以输出“1”或者“ture”。第二判别模型的训练目标为判别第二组合数据为负样本可以理解为第二判别模型的训练目标是判别第二组合数据是生成器生成的数据,第二判别模型的训练目标为判别第三组合数据可以理解为判别第三组合数据是真实的目标域数据。训练第二判别模型的目的在于,使第二判别模型可以更准确的判别获取到的数据哪些是生成器生成的数据,哪些是目标域有标签的数据。
808、根据第二判别结果获取第三损失值。
根据第二判别结果和第二判别模型的训练目标之间的差异获取第三损失值。
在一个可能的实施方式中,将第二组合数据输入至第二判别模型时,第二判别模型的训练目标为判别第二组合数据为负样本,根据第二判别结果获取第三损失值可以参照公式1-5进行理解:
在一个可能的实施方式中,将第三组合数据输入至第二判别模型时,第二判别模型的训练目标为判别第三组合数据为正样本,根据第二判别结果获取第三损失值可以参照公式1-6进行理解:
其中,E代表期望,x代表未携带标签的数据,pu(x)表示目标域中无标注数据的边缘概率分布,Dt()表示第一判别器的输入,K表示源域/目标域中标签集合中包括的标签数量。pu(x)表示源域中未携带标签的数据的边缘概率分布。
809、固定第一生成模型的参数,根据第一损失值和第三损失值更新第一判别模型。
在对目标模型进行迭代训练的过程中,如果当前迭代训练是固定第一生成模型的参数,对第一判别模型进行更新,则可以根据第一损失值和第三损失值更新第一判别模型。第一损失值和第三损失值还用于更新第二判别模型,既第一判别模型和第二判别模型共享模型参数。在一个可能的实施方式中,根据第一损失值和第三损失值更新第一判别模型的参数可以理解为对第一损失值和第三损失值进行加权处理,根据加权处理后获取的损失值更新第一判别模型的参数。
810、固定第一判别模型的参数,根据第一损失值和第三损失值更新第一生成模型的参数。
在对目标模型进行迭代训练的过程中,如果当前迭代训练是固定第一判别模型的参数,对第一生成模型进行更新,则可以根据第一损失值和第三损失值更新第一生成模型,比如根据第一损失值的相反数、第三损失值的相反数更新第一判别第一生成模型。在一个可能的实施方式中,可以根据第一损失值的相反数、第二损失值以及第三损失值的相反数更新第一生成模型。
图8对应的实施例,给出了一种具体的,在进行对抗训练时,如何根据第一损失值对判别模型和生成模型进行更新的方案。在一些可能的实施方式中,还可以在目标网络中引入分类器,以根据目标域无标签数据获取更多的训练数据。
参阅图10,为本申请实施例提供的一种模型训练的方法的流程示意图。
如图10所示,本申请提供的一种模型训练的方法,可以包括以下步骤:
1001、将未携带标签的源域数据输入至第一生成模型,以输出第一生成结果。
1002、将源域数据未携带的标签和第一生成结果进行组合,以获取第一组合数据。
1003、将第一组合数据以及目标域有标签数据输入至第一判别模型中,以输出第一判别结果。
1004、根据第一判别结果获取第一损失值。
步骤1001至步骤1004可以参照图4对应的实施例中的步骤401至步骤404进行理解,这里不再重复赘述。
1005、将目标域无标签数据输入至第一分类模型中,以输出第一预测标签。
图10对应的实施例可以适用在图11所示的目标模型中,参阅图11,为本申请实施例提供的另一种目标模型的架构示意图。如图11所示,本申请实施例提供的一种目标模型包括至少一个生成模型、至少一个判别模型以及一个分类模型。其中,包括一个生成模型时,可以参照图5所示的一种目标模型中的第一生成模型进行理解,包括两个生成模型时,可以参照图7所示的一种目标模型中的第一生成模型和第二生成模型进行理解,这里不再重复赘述。包括两个判别模型时,其中一个判别模型可以参照图9对应的实施例中的第二判别模型进行理解,这里不再重复赘述。在一个可能的实施方式中,第一分类模型可以是卷积神经网络(convolutional neural networks,CNN)。
1006、将目标域无标签数据和第一预测标签进行组合,以获取第四组合数据。
假设目标域无标签数据a输入至第一分类模型中,第一分类模型针对目标域无标签数据a,输出第一预测标签a,则将目标域无标签数据和第一预测标签a进行组合,以获取一组第四组合数据。再比如,目标域无标签数据b输入至第一分类模型中,第一分类模型针对目标域无标签数据b,输出第一预测标签b,则将目标域无标签数据和第一预测标签b进行组合,以获取一组第四组合数据。
1007、将第四组合数据输入至第一判别模型中,以输出第三判别结果。
上文几个实施例已经介绍了第一判别模型用于判别输入至第一判别模型的数据是真实的目标域数据还是生成器生成的数据,在图10对应的实施例中,训练目标为第四组合数据有预设概率被判别为是真实的目标域数据,还有预设概率被判别是生成器生成的数据。比如,在一种可能的实施方式中,第一判别模型的训练目标还包括有50%的概率判别第四组合数据是真实的目标域数据,有50%的概率判别第四组合数据是生成器生成的数据。再比如,在一个可能的实施方式中,第一判别模型的训练目标与第一分类模型的准确度相关,比如第一分类模型的准确度未达到60%,则第一判别模型以判别第四组合数据是生成器生成的数据为训练目标,若第一分类模型的准确度达到60%,则第一判别模型以判别第四组合数据是真实的目标域数据为训练目标。
1008、根据第三判别结果获取第四损失值。
根据第三判别结果和第一判别模型的训练目标之间的差异获取第四损失值,步骤1108中的第一判别模型的训练目标参照步骤1107中描述的训练目标进行理解。示例性的,可以参照公式1-7进行理解:
其中,μ表示0至1之间的一个自然数,E代表期望,x代表未携带标签的数据,y代表x未携带的标签,ps(x,y)表示源域未携带标签的数据和未携带的标签之间的概率联合分布,Dt()表示第一判别器的输入,C1(x)代表第一分类模型的输入。
1009、根据第四损失值更新第一分类模型。
在一个可能的实施方式中,分类模型和生成模型同时更新,则固定第一判别模型的参数,根据第四损失值的相反数更新第一分类模型。若包括两个判别模型,则固定第一判别模型的参数和第二判别模型的参数,根据第四损失值的相反数更新第一分类模型。
1010、固定第一生成模型的参数、第一分类模型的参数根据第一损失值和第四损失值更新第一判别模型。
在一个可能的实施方式中,固定第一生成模型的参数和第一分类模型的参数,根据第一损失值和第四损失值更新第一判别模型。比如对第一损失值和第四损失值进行加权处理,根据加权处理后获取的损失值更新第一判别模型。
在一个可能的实施方式中,固定第一生成模型的参数和第一分类模型的参数,根据第一损失值、第三损失值和第四损失值更新第一判别模型。比如对第一损失值、第三损失值以及第四损失值进行加权处理,根据加权处理后获取的损失值更新第一判别模型。其中,第三损失值的获取过程可以参照图8对应的实施例中的步骤805至808进行理解,这里不再重复赘述。
在一些可能的实施方式中,提升第一分类模型的训练效率,提升训练后的目标模型的性能,还可以增加用于训练第一分类模型的训练样本。
参阅图12,为本申请实施例提供的一种模型训练的方法的流程示意图。
如图12所示,本申请提供的一种模型训练的方法,可以包括以下步骤:
1201、将未携带标签的源域数据输入至第一生成模型,以输出第一生成结果。
1202、将源域数据未携带的标签和第一生成结果进行组合,以获取第一组合数据。
1203、将第一组合数据以及目标域有标签数据输入至第一判别模型中,以输出第一判别结果。
1204、根据第一判别结果获取第一损失值。
步骤801至步骤804可以参照图4对应的实施例中的步骤401至步骤404进行理解,这里不再重复赘述。
1205、将目标域无标签数据输入至第一分类模型中,以输出第一预测标签。
1206、将目标域无标签数据和第一预测标签进行组合,以获取第四组合数据。
1207、将第四组合数据输入至第一判别模型中,以输出第三判别结果。
1208、根据第三判别结果获取第四损失值。
步骤1205至步骤1208可以参照图10对应的实施例中的步骤1005至步骤1008进行理解,这里不再重复赘述。
1209、将第一生成结果和目标域有标签数据输入至第二分类模型中,以输出第二预测标签。
图12对应的实施例可以适用在图13所示的目标模型中,参阅图13,为本申请实施例提供的另一种目标模型的架构示意图。如图13所示,本申请实施例提供的一种目标模型包括至少一个生成模型、至少一个判别模型以及两个分类模型。其中,包括一个生成模型时,可以参照图5所示的一种目标模型中的第一生成模型进行理解,包括两个生成模型时,可以参照图7所示的一种目标模型中的第一生成模型和第二生成模型进行理解,这里不再重复赘述。包括一个判别模型时,可以参照图5、图9以及图11中的第一判别模型进行理解,包括两个判别模型时,可以参照图9以及图11中的第一判别模型和第二判别模型进行理解,这里不再重复赘述。在一个可能的实施方式中,第二分类模型可以是卷积神经网络(convolutional neural networks,CNN)。
1210、根据第二预测标签获取第五损失值。
根据第二预测结果和源域数据未携带的标签之间的差异获取第五损失值,此外,根据第二预测结果和目标域真实的标签之间的差异获取第五损失值。当第二分类模型的输入是第一生成结果时,根据第二预测结果和目源域数据未携带的标签之间的差异获取第五损失值。比如,第一生成模型的输入包括未携带标签的源域数据为源域数据a,源域数据a未携带的标签为标签a,生成器根据源域数据a生成第一生成结果,假设是第一生成结果a,将第一生成结果a输入至第二分类模型,以获取第二预测标签,比如是预测标签a,则根据预测标签a和标签a之间的差异获取第五损失值。示例性的,可以参照公式1-8进行理解:
其中,K表示源域/目标域中标签集合中包括的标签数量,E代表期望,x代表未携带标签的数据,pu(x)表示目标域中无标注数据的边缘概率分布,C1()表示第一分类器的输入,其中C1(i|x)表示输入x,输出第i类的概率。
当第二分类模型的输入是目标域有标签数据,根据第二预测结果和目标域真实的标签之间的差异获取第五损失值。比如,第二分类模型的输入包括目标域数据b(该数据b的标签为标签b),假设第二分类模型针对目标域数据b获取的第二预测结果为预测标签b,则根据预测标签b和标签b之间的差异获取第五损失值。示例性的,可以参照公式1-9进行理解:
其中,μ表示0至1之间的一个自然数,E代表期望,x代表未携带标签的数据,y代表x未携带的标签,ps(x,y)表示源域未携带标签的数据和未携带的标签之间的概率联合分布,C2(x)代表第二分类模型的输入。
1211、根据第五损失值更新第二分类模型。
在一个可能的实施方式中,分类模型和生成模型同时更新,则固定第一判别模型的参数,根据第五损失值更新第二分类模型。若包括两个判别模型,则固定第一判别模型的参数和第二判别模型的参数,根据第五损失值更新第二分类模型。
在一个可能的实施方式中,分类模型可以延后于生成模型更新,比如生成模型(第一生成模型,或者第一生成模型和第二生成模型)已经迭代训练了预设次数后,分类模型开始与生成模型同步更新。具体的,对第二分类模型进行迭代训练,包括固定第一判别模型的参数,根据第五损失值更新第二分类模型。若包括两个判别模型,固定第一判别模型的参数和第二判别模型的参数,根据第五损失值更新第二分类模型。
1212、根据第四损失值和第五损失值更新第一分类模型。
在一个可能的实施方式中,对第四损失值和第五损失值进行加权处理,根据加权处理后获取的损失值更新第一分类模型。在这种实施方式中,第一分类模型和第二分类模型共享模型参数。
以上对本申请实施例提供的一种模型训练的方法进行了介绍,下面对如何应用上述训练好的模型执行数据生成任务进行说明。
二、通过训练好的模型执行数据生成任务。
参阅图14,为本申请实施例提供的一种数据生成方法的流程示意图。
如图14所示,本申请提供的一种数据生成方法,可以包括以下步骤:
1401、获取未携带标签的源域数据。
1402、将该未携带标签的源域数据输入至第一生成模型中,以获取目标域数据。
其中,第一生成模型是通过图4、图6、图8、图10以及图12中对应的任意一个实施例中所描述的训练方法得到的训练后的第一生成模型。
需要说明的是,根据本申请提供的一种数据生成方法,可以根据源域数据获取大量的、优质的目标领域的数据。此外,还可以根据获取到的目标领域的数据训练新的模型,比如通过获取到的目标领域的数据训练新的分类模型。在一个可能的实施方式中,图14对应的实施例中描述的训练后的第二分类模型可以用于对目标领域的数据执行分类任务。
参阅图15,为本申请实施例提供的一种数据生成模型的架构示意图,图14对应的实施例所描述的方法可以适用在图15所示的数据生成模型中。该数据生成模型包括第一生成模型,用于生成目标域的数据。在一个可能的方式中还可以包括第二分类模型,根据生成的目标域的数据训练第二分类模型,使训练后的第二分类模型可以针对目标域的数据,执行分类任务。该第一生成模型的输入数据是未携带标签的源域数据,输出是包括标签的目标域数据。其中,第一生成模型是通过图4、图6、图8、图10以及图12中对应的任意一个实施例中所描述的训练方法得到的训练后的第一生成模型。第二分类模型是通过图14对应的实施例所描述的训练方法得到的训练后的第二分类模型。
在本申请提供的方案的一个典型的适用场景中,训练一个模型,比如训练一个用于执行分类任务的模型,一般需要大量有标注的样本。而对于某一些领域,比如医疗领域、工业视觉领域,获取有标注的样本并不容易,具体可能表现在成本高昂,费时费力。通过本申请提供的方案,可以根据目标域少量有标注的样本,目标域大量无标注的样本、以及源域大量有标注的样本,获取目标域大量有标注的样本,降低成本。
在本申请提供的方案的另一个典型的适用场景中,对于要从其他企业采购训练数据的企业,为了保证训练数据供应的稳定性,该企业可能会从多个其他企业采购训练数据。来自不同企业的训练数据可以分别看做一个数据域,来自不同企业的训练数据可能并不满足独立同分布,即使从同一个企业采购的训练数据,该训练数据也可能是该企业不同代设备采集的数据,比如有些数据是第一代产品采集的数据,有些数据是第二代产品采集的数据,进而导致采集的数据不满足独立同分布。而训练一个模型的训练数据要求满足独立同分布,此外,训练数据如果和测试数据的分布差异较大,模型的性能往往也会显著下降。本申请提供的可以很好的解决这一问题。将从其他企业采购的训练数据看做有标注的目标域数据,通过本申请提供的方案,无需从其他企业采购大量的有标注的目标域数据,只需要采购少量的有标注的目标域数据,根据本申请提供的方案可以根据该少量的有标注的目标域数据,以及一些无标注的目标域数据,以及源域数据生成大量的有标注的目标域数据,进而节省了成本,还使得生成的大量的有标注的目标域数据均满足独立同分布,提升模型训练的效率。
上文结合附图对本申请实施例的数据生成方法进行了详细的介绍,下面结合具体的实验数据展示本申请提供的方案的优势。本申请在实验过程中采用了常用数据集SVHN(以下简称为S数据集)、MNIST(以下简称为M数据集)以及USPS(以下简称为U数据集)。这三种数据集中包括的数据均为图片数据。第一种方案和第二种方案,为现有的根据源域数据生成目标域数据的方案,根据不同方案获取的目标域数据分别训练分类模型,比较分类模型的准确率,进而展示不同方案的优势。如表1所示,“S-M”表示将S数据集作为源域数据,将M数据集作为目标域数据。“每个类别的有标签数据”为q,表示使目标域数据集中的每一类别上有q张有标签的图片,该类别上的其他图片没有标签,将该q张有标签的图片作为目标域有标签数据,该类包括的其他图片作为目标域无标签数据。表格中的其他数值表示准确率,准确率越高,表示分类器的性能越高。如表1中所示,通过本申请提供的方案生成的数据训练得到的分类器(可以理解为上述实施例描述的第二分类模型),对于目标领域的数据的分类效果要远高于其他几种现有的方案。表1中所示的测试结果为低资源的情况,表2为高资源时的测试结果,其中高资源和低资源是根据目标领域中包括的无标注数据的数目确定的。高资源是指目标域中包括大量无标注的数据,低资源是指目标域中包括的无标注数据是少量的,或者有限制的。参阅表2,“每个类别的无标签数据”为p,表示使目标域数据集中的每一类别上有p张无标签的图片。由表2可见,本申请提供的方案,针对目标域数据有大量无标注数据时,通过本申请提供的方案可以得到性能更好的分类模型。参阅表3,当目标域无标注数据越来越多时,通过本申请提供的方案训练得到的分类模型表现会持续变好,而目前最好的基于ACAL方案生成目标域数据的方案,并不能保证这样的趋势。
表1:
表2:
表3
以上对本申请提供的一种模型训练的方法以及一种数据生成的方法进行了介绍,通过本申请实施例提供的方案。可以理解的是,上述训练装置以及翻译设备为了实现上述功能,其包含了执行各个功能相应的硬件结构和/或软件模块。本领域技术人员应该很容易意识到,结合本文中所公开的实施例描述的各示例的模块及算法步骤,本申请能够以硬件或硬件和计算机软件的结合形式来实现。某个功能究竟以硬件还是计算机软件驱动硬件的方式来执行,取决于技术方案的特定应用和设计约束条件。专业技术人员可以对每个特定的应用来使用不同方法来实现所描述的功能,但是这种实现不应认为超出本申请的范围。
从硬件结构上来描述,图4、图6、图8、图10、图12中的执行主体可以由一个实体设备实现,也可以由多个实体设备共同实现,还可以是一个实体设备内的一个逻辑功能模块,本申请实施例对此不作具体限定。图14中的执行主体可以由一个实体设备实现,也可以由多个实体设备共同实现,还可以是一个实体设备内的一个逻辑功能模块,本申请实施例对此不作具体限定。
下面基于前述的一种模型训练的方法以及一种数据生成的方法,对本申请提供的模型训练的装置、数据生成装置进行阐述,模型训练的装置用于执行前述图4、图6、图8、图10、图12对应的方法的步骤。数据生成装置用于执行前述图14对应的方法的步骤。
例如,模型训练的装置或者数据生成装置可以通过图16中的计算机设备来实现,图16所示为本申请实施例提供的计算机设备的硬件结构示意图。包括:通信接口1601和处理器1602,还可以包括存储器1603。
通信接口1601可以使用任何收发器一类的装置,用于与其他设备或通信网络通信,在本方案中,端侧设备可以利用通信接口1601与服务器进行通信,比如上传模型或者下载模型。在一个可能的实施方式中,通信接口1601可以采用以太网,无线接入网(radioaccess network,RAN),无线局域网(wireless local area networks,WLAN)等技术与服务器进行通信。
处理器1602包括但不限于中央处理器(central processing unit,CPU),网络处理器(network processor,NP),专用集成电路(application-specific integratedcircuit,ASIC)或者可编程逻辑器件(programmable logic device,PLD)中的一个或多个。上述PLD可以是复杂可编程逻辑器件(complex programmable logic device,CPLD),现场可编程逻辑门阵列(field-programmable gate array,FPGA),通用阵列逻辑(genericarray logic,GAL)或其任意组合。处理器1602负责通信线路1604和通常的处理,还可以提供各种功能,包括定时,外围接口,电压调节,电源管理以及其他控制功能。
存储器1603可以是只读存储器(read-only memory,ROM)或可存储静态信息和指令的其他类型的静态存储设备,随机存取存储器(random access memory,RAM)或者可存储信息和指令的其他类型的动态存储设备,也可以是电可擦可编程只读存储器(electrically er服务器able programmable read-only memory,EEPROM)、只读光盘(compact disc read-only memory,CD-ROM)或其他光盘存储、光碟存储(包括压缩光碟、激光碟、光碟、数字通用光碟、蓝光光碟等)、磁盘存储介质或者其他磁存储设备、或者能够用于携带或存储具有指令或数据结构形式的期望的程序代码并能够由计算机存取的任何其他介质,但不限于此。存储器可以是独立存在,通过通信线路1604与处理器1602相连接。存储器1603也可以和处理器1602集成在一起。如果存储器1603和处理器1602是相互独立的器件,存储器1603和处理器1602相连,例如存储器1603和处理器1602可以通过通信线路通信。通信接口1601和处理器1602可以通过通信线路通信,通信接口1601也可以与处理器1602直连。
通信线路1604可以包括任意数量的互联的总线和桥,通信线路1604将包括由处理器1602代表的一个或多个处理器1602和存储器1603代表的存储器的各种电路链接在一起。通信线路1604还可以将诸如外围设备、稳压器和功率管理电路等之类的各种其他电路链接在一起,这些都是本领域所公知的,因此,本申请不再对其进行进一步描述。
在一个可能的实施方式中,该计算机设备是模型训练的装置,该模型训练的装置包括存储器,用于存储计算机可读指令。还可以包括和存储器耦合的通信接口和处理器。通信接口用于获取训练数据,训练数据可以包括未携带标签的源域数据和目标域有标签数据。处理器用于执行所述存储器中的计算机可读指令从而执行图4对应的实施例中的步骤401至步骤404。
在一个可能的实施方式中,该计算机设备是模型训练的装置,该模型训练的装置包括存储器,用于存储计算机可读指令。还可以包括和存储器耦合的通信接口和处理器。通信接口用于获取训练数据,训练数据可以包括未携带标签的源域数据和目标域有标签数据。处理器用于执行所述存储器中的计算机可读指令从而执行图6对应的实施例中的步骤601至步骤607。
在一个可能的实施方式中,该计算机设备是模型训练的装置,该模型训练的装置包括存储器,用于存储计算机可读指令。还可以包括和存储器耦合的通信接口和处理器。通信接口用于获取训练数据,其中训练数据可以包括未携带标签的源域数据、目标域有标签数据以及目标域无标签数据。处理器用于执行所述存储器中的计算机可读指令从而执行图8对应的实施例中的步骤801至步骤810。
在一个可能的实施方式中,该计算机设备是模型训练的装置,该模型训练的装置包括存储器,用于存储计算机可读指令。还可以包括和存储器耦合的通信接口和处理器。通信接口用于获取训练数据,其中训练数据可以包括未携带标签的源域数据、目标域有标签数据以及目标域无标签数据。处理器用于执行所述存储器中的计算机可读指令从而执行图10对应的实施例中的步骤1001至步骤1010。
在一个可能的实施方式中,该计算机设备是模型训练的装置,该模型训练的装置包括存储器,用于存储计算机可读指令。还可以包括和存储器耦合的通信接口和处理器。通信接口用于获取训练数据,其中训练数据可以包括未携带标签的源域数据、目标域有标签数据以及目标域无标签数据。处理器用于执行所述存储器中的计算机可读指令从而执行图12对应的实施例中的步骤1201至步骤1212。
在一个可能的实施方式中,该计算机设备是数据生成装置,该数据生成装置包括存储器,用于存储计算机可读指令。还可以包括和存储器耦合的通信接口和处理器。通信接口用于获取源域数据,用于执行图14对应的实施例中的步骤1401。处理器用于执行所述存储器中的计算机可读指令从而执行图14对应的实施例中的步骤1202。
在本申请实施例中,可以将通信接口视为计算机设备的收发模块1701,将具有处理功能的处理器视为计算机设备的处理模块1702,将存储器视为计算机设备的存储模块(图中未示出)。参阅图17,为本申请实施例提供的一种计算机设备的结构示意图。
本申请实施例对模型的名称并不进行限定,比如,当计算机设备是模型训练的装置时,该处理模块1702可以看做训练模块,或者可以将处理模块看做输入模块、组合模块、获取模块以及更新模块。其中,在一种可能的实施方式中,模型训练的装置可以包括:训练模块,用于根据未携带标签的源域数据、目标域有标签数据对目标模型进行迭代训练,目标模型可以包括第一生成模型和第一判别模型,训练模块可以包括输入模块、组合模块、获取模块以及更新模块,其中,每次迭代训练时,输入模块,用于将未携带标签的源域数据输入至第一生成模型,以输出第一生成结果。组合模块,用于将源域数据未携带的标签和第一生成结果进行组合,以获取第一组合数据。输入模块,还用于将第一组合数据以及目标域有标签数据输入至第一判别模型中,以输出第一判别结果。获取模块,用于根据第一判别结果获取第一损失值。更新模块,用于固定第一生成模型的参数,根据第一损失值更新第一判别模型,或者固定第一判别模型的参数,根据第一损失值更新第一生成模型。
在一种可能的实施方式中,目标模型还可以包括第二生成模型,输入模块,还用于将第一生成结果输入至第二生成模型,以输出第二生成结果。获取模块,还用于根据第二生成结果和未携带标签的源域数据之间的差异获取第二损失值。更新模块,具体用于根据第一损失值和第二损失值更新第一生成模型,第二损失值还用于更新第二生成模型。
在一种可能的实施方式中,训练模块,具体用于:根据未携带标签的源域数据、目标域有标签数据以及目标域无标签数据对目标模型进行迭代训练。
在一种可能的实施方式中,目标模型还可以包括第二判别模型,组合模块,还用于将第一生成结果和预设标签进行组合,以获取第二组合数据,将目标域无标签数据和预设标签进行组合,以获取第三组合数据。输入模块,还用于将第二组合数据以及第三组合数据输入至第二判别模型中,以输出第二判别结果。获取模块,还用于根据第二判别结果获取第三损失值。更新模块,具体用于根据第一损失值和第三损失值更新第一判别模型。更新模块,具体用于根据第一损失值和第三损失值更新第一生成模型。
在一种可能的实施方式中,目标模型还可以包括第一分类模型,输入模块,还用于将目标域无标签数据输入至第一分类模型中,以输出第一预测标签。组合模块,还用于将目标域无标签数据和第一预测标签进行组合,以获取第四组合数据。输入模块,还用于将第四组合数据输入至第一判别模型中,以输出第三判别结果。获取模块,还用于根据第三判别结果获取第四损失值。更新模块,具体用于固定第一生成模型的参数和第一分类模型的参数,根据第一损失值和第四损失值更新第一判别模型。更新模块,还用于固定第一判别模型的参数,根据第四损失值更新第一分类模型。
在一种可能的实施方式中,目标模型还可以包括第二分类模型,输入模块,还用于将第一生成结果和目标域有标签数据输入至第二分类模型中,以输出第二预测标签。获取模块,还用于根据第二预测标签获取第五损失值。更新模块,具体用于根据第四损失值和第五损失值更新第一分类模型,第四损失值和第五损失值还用于更新第二分类模型。
当计算机设备是数据生成的装置时,可以将处理模块1701看做获取模块,将该处理模块1702可以看做生成模块,在一个可能的实施方式中,该数据生成的装置可以包括:获取模块,用于获取未携带标签的源域数据。生成模块,用于将未携带标签的源域数据输入至目标生成模型中,以获取目标域数据。其中,目标生成模型是通过未携带标签的源域训练数据、目标域有标签训练数据对目标模型进行迭代训练后获取的,目标模型可以包括第一生成模型和第一判别模型,目标生成模型是训练后的第一生成模型,目标生成模型的参数是通过固定第一判别模型的参数,通过第一损失值更新第一生成模型的参数获取的,第一损失值还用于固定第一生成模型的参数时,更新第一判别模型,第一损失值是根据第一判别结果获取的,第一判别结果是根据第一组合训练数据以及目标域有标签训练数据输入至第一判别模型中获取的,第一组合数据是将源域训练数据未携带的标签和第一生成结果进行组合后获取的,第一生成结果是将未携带标签的源域训练数据输入至第一生成模型后获取的。
应当理解,上述仅为本申请实施例提供的一个例子,并且,模型训练的装置/数据生成装置可具有比示出的部件更多或更少的部件,可以组合两个或更多个部件,或者可具有部件的不同配置实现。
在上述实施例中,可以全部或部分地通过软件、硬件、固件或者其任意组合来实现。当使用软件实现时,可以全部或部分地以计算机程序产品的形式实现。
本申请实施例提供的模型训练的装置/数据生成装置可以为芯片,芯片包括:处理单元和通信单元,所述处理单元例如可以是处理器,所述通信单元例如可以是输入/输出接口、管脚或电路等。该模型训练的装置为芯片时,该处理单元可执行存储单元存储的计算机执行指令,以使芯片执行上述图4、图6、图8、图10、图12所示实施例描述的训练模型的方法。在另一个可能的实施方式中,该数据生成装置为芯片时,以使芯片执行上述图14所示实施例描述的数据生成的方法。可选地,所述存储单元为所述芯片内的存储单元,如寄存器、缓存等,所述存储单元还可以是所述无线接入设备端内的位于所述芯片外部的存储单元,如只读存储器(read-only memory,ROM)或可存储静态信息和指令的其他类型的静态存储设备,随机存取存储器(random access memory,RAM)等。
具体地,前述的处理单元或者处理器可以是中央处理器(central processingunit,CPU)、神经网络处理器(neural-network processing unit,NPU)、图形处理器(graphics processing unit,GPU)、数字信号处理器(digital signal processor,DSP)、专用集成电路(application specific integrated circuit,ASIC)或现场可编程逻辑门阵列(field programmable gate array,FPGA)或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件等。通用处理器可以是微处理器或者也可以是任何常规的处理器等。
具体的,请参阅图18,图18为本申请实施例提供的芯片的一种结构示意图,所述芯片可以表现为神经网络处理器NPU180,NPU180作为协处理器挂载到主CPU(Host CPU)上,由Host CPU分配任务。NPU的核心部分为运算电路1803,通过控制器1804控制运算电路1803提取存储器中的矩阵数据并进行乘法运算。
在一些实现中,运算电路1803内部包括多个处理单元(process engine,PE)。在一些实现中,运算电路1803是二维脉动阵列。运算电路1803还可以是一维脉动阵列或者能够执行例如乘法和加法这样的数学运算的其它电子线路。在一些实现中,运算电路1803是通用的矩阵处理器。
举例来说,假设有输入矩阵A,权重矩阵B,输出矩阵C。运算电路从权重存储器1802中取矩阵B相应的数据,并缓存在运算电路中每一个PE上。运算电路从输入存储器1801中取矩阵A数据与矩阵B进行矩阵运算,得到的矩阵的部分结果或最终结果,保存在累加器(accumulator)1808中。
统一存储器1806用于存放输入数据以及输出数据。权重数据直接通过存储单元访问控制器(direct memory access controller,DMAC)1805,DMAC被搬运到权重存储器1802中。输入数据也通过DMAC被搬运到统一存储器1806中。
总线接口单元(bus interface unit,BIU)1810,用于AXI总线与DMAC和取指存储器(Instruction Fetch Buffer,IFB)1809的交互。
总线接口单元1810(bus interface unit,BIU),用于取指存储器1809从外部存储器获取指令,还用于存储单元访问控制器1805从外部存储器获取输入矩阵A或者权重矩阵B的原数据。
DMAC主要用于将外部存储器DDR中的输入数据搬运到统一存储器1806或将权重数据搬运到权重存储器1802中或将输入数据数据搬运到输入存储器1801中。
向量计算单元1807包括多个运算处理单元,在需要的情况下,对运算电路的输出做进一步处理,如向量乘,向量加,指数运算,对数运算,大小比较等等。主要用于神经网络中非卷积/全连接层网络计算,如批归一化(batch normalization),像素级求和,对特征平面进行上采样等。
在一些实现中,向量计算单元1807能将经处理的输出的向量存储到统一存储器1806。例如,向量计算单元1807可以将线性函数和/或非线性函数应用到运算电路1803的输出,例如对卷积层提取的特征平面进行线性插值,再例如累加值的向量,用以生成激活值。在一些实现中,向量计算单元1807生成归一化的值、像素级求和的值,或二者均有。在一些实现中,处理过的输出的向量能够用作到运算电路1803的激活输入,例如用于在神经网络中的后续层中的使用。
控制器1804连接的取指存储器(instruction fetch buffer)1809,用于存储控制器1804使用的指令。
统一存储器1806,输入存储器1801,权重存储器1802以及取指存储器1809均为On-Chip存储器。外部存储器私有于该NPU硬件架构。
其中,循环神经网络中各层的运算可以由运算电路1803或向量计算单元1807执行。
其中,上述任一处提到的处理器,可以是一个通用中央处理器,微处理器,ASIC,或一个或多个用于控制上述图4、图6、图8、图10、图12的方法的程序执行的集成电路,或者在另一个可能的实施方式中,控制上述图14的方法的程序执行的集成电路。
另外需说明的是,以上所描述的装置实施例仅仅是示意性的,其中所述作为分离部件说明的单元可以是或者也可以不是物理上分开的,作为单元显示的部件可以是或者也可以不是物理单元,即可以位于一个地方,或者也可以分布到多个网络单元上。可以根据实际的需要选择其中的部分或者全部模块来实现本实施例方案的目的。另外,本申请提供的装置实施例附图中,模块之间的连接关系表示它们之间具有通信连接,具体可以实现为一条或多条通信总线或信号线。
通过以上的实施方式的描述,所属领域的技术人员可以清楚地了解到本申请可借助软件加必需的通用硬件的方式来实现,当然也可以通过专用硬件包括专用集成电路、专用CPU、专用存储器、专用元器件等来实现。一般情况下,凡由计算机程序完成的功能都可以很容易地用相应的硬件来实现,而且,用来实现同一功能的具体硬件结构也可以是多种多样的,例如模拟电路、数字电路或专用电路等。但是,对本申请而言更多情况下软件程序实现是更佳的实施方式。基于这样的理解,本申请的技术方案本质上或者说对现有技术做出贡献的部分可以以软件产品的形式体现出来,该计算机软件产品存储在可读取的存储介质中,如计算机的软盘、U盘、移动硬盘、只读存储器(read only memory,ROM)、随机存取存储器(random access memory,RAM)、磁碟或者光盘等,包括若干指令用以使得一台计算机设备(可以是个人计算机,服务器,或者网络设备等)执行本申请各个实施例所述的方法。
在上述实施例中,可以全部或部分地通过软件、硬件、固件或者其任意组合来实现。当使用软件实现时,可以全部或部分地以计算机程序产品的形式实现。
本申请实施例中还提供一种计算机可读存储介质,该计算机可读存储介质中存储有用于训练模型的程序,当其在计算机上运行时,使得计算机执行如前述图9或图10所示实施例描述的方法中的步骤。
本申请实施例中还提供一种计算机可读存储介质,该计算机可读存储介质中存储有用于数据处理的程序,当其在计算机上运行时,使得计算机执行如前述图4、图6、图8、图10、图12所示实施例描述的方法中的步骤。或者使得计算机执行如前述图14所示实施例描述的方法中的步骤。
本申请实施例还提供一种数字处理芯片。该数字处理芯片中集成了用于实现上述处理器,或者处理器的功能的电路和一个或者多个接口。当该数字处理芯片中集成了存储器时,该数字处理芯片可以完成前述实施例中的任一个或多个实施例的方法步骤。当该数字处理芯片中未集成存储器时,可以通过通信接口与外置的存储器连接。该数字处理芯片根据外置的存储器中存储的程序代码来实现上述实施例中模型训练的装置/数据生成装置执行的动作。
本申请实施例中还提供一种计算机程序产品,所述计算机程序产品包括一个或多个计算机指令。在计算机上加载和执行所述计算机程序指令时,全部或部分地产生按照本申请实施例所述的流程或功能。所述计算机可以是通用计算机、专用计算机、计算机网络、或者其他可编程装置。所述计算机指令可以存储在计算机可读存储介质中,或者从一个计算机可读存储介质向另一计算机可读存储介质传输,例如,所述计算机指令可以从一个网站站点、计算机、服务器或数据中心通过有线(例如同轴电缆、光纤、数字用户线(DSL))或无线(例如红外、无线、微波等)方式向另一个网站站点、计算机、服务器或数据中心进行传输。所述计算机可读存储介质可以是计算机能够存储的任何可用介质或者是包含一个或多个可用介质集成的服务器、数据中心等数据存储设备。所述可用介质可以是磁性介质,(例如,软盘、硬盘、磁带)、光介质(例如,DVD)、或者半导体介质(例如固态硬盘Solid State Disk(SSD))等。
本领域普通技术人员可以理解上述实施例的各种方法中的全部或部分步骤是可以通过程序来指令相关的硬件来完成,该程序可以存储于一计算机可读存储介质中,存储介质可以包括:ROM、RAM、磁盘或光盘等。
以上对本申请实施例所提供的模型的训练方法、数据处理方法以及相关设备进行了详细介绍,本文中应用了具体个例对本申请的原理及实施方式进行了阐述,以上实施例的说明只是用于帮助理解本申请的方法及其核心思想;同时,对于本领域的一般技术人员,依据本申请的思想,在具体实施方式及应用范围上均会有改变之处,综上所述,本说明书内容不应理解为对本申请的限制。
本申请的说明书和权利要求书及上述附图中的术语“第一”,“第二”等是用于区别类似的对象,而不必用于描述特定的顺序或先后次序。应该理解这样使用的数据在适当情况下可以互换,以便这里描述的实施例能够以除了在这里图示或描述的内容以外的顺序实施。本申请中术语“和/或”,仅仅是一种描述关联对象的关联关系,表示可以存在三种关系,例如,A和/或B,可以表示:单独存在A,同时存在A和B,单独存在B这三种情况,另外,本文中字符“/”,一般表示前后关联对象是一种“或”的关系。此外,术语“包括”和“具有”以及他们的任何变形,意图在于覆盖不排他的包含,例如,包含了一系列步骤或模块的过程,方法,系统,产品或设备不必限于清楚地列出的那些步骤或模块,而是可包括没有清楚地列出的或对于这些过程,方法,产品或设备固有的其它步骤或模块。在本申请中出现的对步骤进行的命名或者编号,并不意味着必须按照命名或者编号所指示的时间/逻辑先后顺序执行方法流程中的步骤,已经命名或者编号的流程步骤可以根据要实现的技术目的变更执行次序,只要能达到相同或者相类似的技术效果即可。本申请中所出现的模块的划分,是一种逻辑上的划分,实际应用中实现时可以有另外的划分方式,例如多个模块可以结合成或集成在另一个系统中,或一些特征可以忽略,或不执行,另外,所显示的或讨论的相互之间的耦合或直接耦合或通信连接可以是通过一些端口,模块之间的间接耦合或通信连接可以是电性或其他类似的形式,本申请中均不作限定。并且,作为分离部件说明的模块或子模块可以是也可以不是物理上的分离,可以是也可以不是物理模块,或者可以分布到多个电路模块中,可以根据实际的需要选择其中的部分或全部模块来实现本申请方案的目的。
Claims (32)
1.一种模型训练的方法,其特征在于,包括:
根据未携带标签的源域数据、目标域有标签数据对目标模型进行迭代训练,所述目标模型包括第一生成模型和第一判别模型,其中,每次迭代训练,包括:
将未携带标签的源域数据输入至第一生成模型,以输出第一生成结果;
将所述源域数据未携带的标签和所述第一生成结果进行组合,以获取第一组合数据;
将所述第一组合数据以及所述目标域有标签数据输入至所述第一判别模型中,以输出第一判别结果;
根据所述第一判别结果获取第一损失值;
固定所述第一生成模型的参数,根据所述第一损失值更新所述第一判别模型,或者固定所述第一判别模型的参数,根据所述第一损失值更新所述第一生成模型。
2.根据权利要求1所述的方法,其特征在于,所述目标模型还包括第二生成模型,所述方法还包括:
将所述第一生成结果输入至第二生成模型,以输出第二生成结果;
根据所述第二生成结果和所述未携带标签的源域数据之间的差异获取第二损失值;
所述根据所述第一损失值更新所述第一生成模型,包括:
根据所述第一损失值和所述第二损失值更新所述第一生成模型,所述第二损失值还用于更新所述第二生成模型。
3.根据权利要求1或2所述的方法,其特征在于,所述根据未携带标签的源域数据、目标域有标签数据对目标模型进行迭代训练,包括:
根据所述未携带标签的源域数据、目标域有标签数据以及目标域无标签数据对目标模型进行迭代训练。
4.根据权利要求1至3任一项所述的方法,其特征在于,所述目标模型还包括第二判别模型,所述每次迭代训练还包括:
将所述第一生成结果和预设标签进行组合,以获取第二组合数据;
将目标域无标签数据和所述预设标签进行组合,以获取第三组合数据;
将所述第二组合数据以及所述第三组合数据输入至第二判别模型中,以输出第二判别结果;
根据所述第二判别结果获取第三损失值;
所述根据所述第一损失值更新所述第一判别模型,包括:
根据所述第一损失值和所述第三损失值更新所述第一判别模型;
所述根据所述第一损失值更新所述第一生成模型,包括:
根据所述第一损失值和所述第三损失值更新所述第一生成模型。
5.根据权利要求2至4任一项所述的方法,其特征在于,所述目标模型还包括第一分类模型,所述每次迭代训练还包括:
将所述目标域无标签数据输入至第一分类模型中,以输出第一预测标签;
将所述目标域无标签数据和所述第一预测标签进行组合,以获取第四组合数据;
将所述第四组合数据输入至所述第一判别模型中,以输出第三判别结果;
根据所述第三判别结果获取第四损失值;
所述根据所述第一损失值更新所述第一判别模型,包括:
固定所述第一生成模型的参数和所述第一分类模型的参数,根据所述第一损失值和所述第四损失值更新所述第一判别模型;
所述每次迭代训练还包括:
固定所述第一判别模型的参数,根据所述第四损失值更新所述第一分类模型。
6.根据权利要求5所述的方法,其特征在于,所述目标模型还包括第二分类模型,所述每次迭代训练还包括:
将所述第一生成结果和所述目标域有标签数据输入至第二分类模型中,以输出第二预测标签;
根据所述第二预测标签获取第五损失值;
所述根据所述第四损失值更新所述第一分类模型,包括:
根据所述第四损失值和所述第五损失值更新所述第一分类模型,所述第四损失值和所述第五损失值还用于更新所述第二分类模型。
7.一种数据生成的方法,其特征在于,包括:
获取未携带标签的源域数据;
将所述未携带标签的源域数据输入至目标生成模型中,以获取目标域数据;
其中,所述目标生成模型是通过未携带标签的源域训练数据、目标域有标签训练数据对目标模型进行迭代训练后获取的,所述目标模型包括第一生成模型和第一判别模型,所述目标生成模型是训练后的所述第一生成模型,所述目标生成模型的参数是通过固定所述第一判别模型的参数,通过第一损失值更新所述第一生成模型的参数获取的,所述第一损失值还用于固定所述第一生成模型的参数时,更新所述第一判别模型,所述第一损失值是根据第一判别结果获取的,所述第一判别结果是根据第一组合训练数据以及所述目标域有标签训练数据输入至所述第一判别模型中获取的,所述第一组合数据是将所述源域训练数据未携带的标签和所述第一生成结果进行组合后获取的,所述第一生成结果是将所述未携带标签的源域训练数据输入至所述第一生成模型后获取的。
8.根据权利要求7所述的方法,其特征在于,所述目标模型还包括第二生成模型,所述目标生成模型的参数具体是通过固定所述第一判别模型的参数,通过第一损失值和第二损失值更新所述第一生成模型的参数获取的,所述第二损失值是根据第二生成结果和所述未携带标签的源域训练数据之间的差异获取的,所述第二生成结果是将所述第一生成结果输入至所述第二生成模型获取的,所述第二损失值还用于更新所述第二生成模型。
9.根据权利要求7或8所述的方法,其特征在于,所述目标生成模型具体是通过未携带标签的源域训练数据、目标域有标签训练数据以及目标域无标签训练数据对目标模型进行迭代训练后获取的。
10.根据权利要求7至9任一项所述的方法,其特征在于,所述目标模型还包括第二判别模型,所述目标生成模型的参数具体是通过固定所述第一判别模型的参数,所述第一损失值和第三损失值更新所述第一生成模型的参数获取的,所述第三损失值是通过第二判别结果获取的,所述第二判别结果是将第二组合数据以及第三组合数据输入至所述第二判别模型中获取的,所述第二组合数据是将所述第一生成结果和预设标签进行组合后获取的,所述第三组合数据是将所述标域无标签训练数据和所述预设标签进行组合后获取的,所述第一损失值和所述第三损失值还用于更新所述第一判别模型。
11.根据权利要求8至10任一项所述的方法,其特征在于,还包括:
将所述目标域训练数据输入至目标分类模型中,以获取预测结果,其中,所述目标分类模型是固定所述第一判别模型的参数,通过第四损失值更新第一分类模型获取的,所述第四损失值是通过第三判别结果获取的,所述第三判别结果是通过将第四组合数据输入至所述第一判别模型中获取的,所述第四组合数据是将所述目标域无标签训练数据和所述第一预测标签进行组合后获取的,所述第一预测标签是将所述目标域无标签训练数据输入至所述第一分类模型中获取的,所述第四损失值还用于固定所述第一生成模型的参数和所述第一分类模型的参数时,更新所述第一判别模型。
12.根据权利要求11所述的方法,其特征在于,所述目标模型还包括第二分类模型,所述目标分类模型具体是固定所述第一判别模型的参数,通过第四损失值和第五损失值更新所述第一分类模型获取的,所述第五损失值是通过第二预测标签获取的,所述第二预测标签是通过将所述第一生成结果和所述目标域有标签训练数据输入至第二分类模型中获取的,所述第四损失值和所述第五损失值还用于更新所述第二分类模型。
13.一种模型训练的装置,其特征在于,包括:
训练模块,用于根据未携带标签的源域数据、目标域有标签数据对目标模型进行迭代训练,所述目标模型包括第一生成模型和第一判别模型,所述训练模块包括输入模块、组合模块、获取模块以及更新模块,其中,每次迭代训练时,
所述输入模块,用于将未携带标签的源域数据输入至第一生成模型,以输出第一生成结果;
所述组合模块,用于将所述源域数据未携带的标签和所述第一生成结果进行组合,以获取第一组合数据;
所述输入模块,还用于将所述第一组合数据以及所述目标域有标签数据输入至所述第一判别模型中,以输出第一判别结果;
所述获取模块,用于根据所述第一判别结果获取第一损失值;
所述更新模块,用于固定所述第一生成模型的参数,根据所述第一损失值更新所述第一判别模型,或者固定所述第一判别模型的参数,根据所述第一损失值更新所述第一生成模型。
14.根据权利要求13所述的装置,其特征在于,所述目标模型还包括第二生成模型,
所述输入模块,还用于将所述第一生成结果输入至第二生成模型,以输出第二生成结果;
所述获取模块,还用于根据所述第二生成结果和所述未携带标签的源域数据之间的差异获取第二损失值;
所述更新模块,具体用于根据所述第一损失值和所述第二损失值更新所述第一生成模型,所述第二损失值还用于更新所述第二生成模型。
15.根据权利要求13或14所述的装置,其特征在于,所述训练模块,具体用于:
根据所述未携带标签的源域数据、目标域有标签数据以及目标域无标签数据对目标模型进行迭代训练。
16.根据权利要求13至15任一项所述的装置,其特征在于,所述目标模型还包括第二判别模型,
所述组合模块,还用于将所述第一生成结果和预设标签进行组合,以获取第二组合数据,将目标域无标签数据和所述预设标签进行组合,以获取第三组合数据;
所述输入模块,还用于将所述第二组合数据以及所述第三组合数据输入至第二判别模型中,以输出第二判别结果;
所述获取模块,还用于根据所述第二判别结果获取第三损失值;
所述更新模块,具体用于根据所述第一损失值和所述第三损失值更新所述第一判别模型;
所述更新模块,具体用于根据所述第一损失值和所述第三损失值更新所述第一生成模型。
17.根据权利要求14至16任一项所述的装置,其特征在于,所述目标模型还包括第一分类模型,
所述输入模块,还用于将所述目标域无标签数据输入至第一分类模型中,以输出第一预测标签;
所述组合模块,还用于将所述目标域无标签数据和所述第一预测标签进行组合,以获取第四组合数据;
所述输入模块,还用于将所述第四组合数据输入至所述第一判别模型中,以输出第三判别结果;
所述获取模块,还用于根据所述第三判别结果获取第四损失值;
所述更新模块,具体用于固定所述第一生成模型的参数和所述第一分类模型的参数,根据所述第一损失值和所述第四损失值更新所述第一判别模型;
所述更新模块,还用于固定所述第一判别模型的参数,根据所述第四损失值更新所述第一分类模型。
18.根据权利要求17所述的装置,其特征在于,所述目标模型还包括第二分类模型,
所述输入模块,还用于将所述第一生成结果和所述目标域有标签数据输入至第二分类模型中,以输出第二预测标签;
所述获取模块,还用于根据所述第二预测标签获取第五损失值;
所述更新模块,具体用于根据所述第四损失值和所述第五损失值更新所述第一分类模型,所述第四损失值和所述第五损失值还用于更新所述第二分类模型。
19.一种数据生成的装置,其特征在于,包括:
获取模块,用于获取未携带标签的源域数据;
生成模块,用于将所述未携带标签的源域数据输入至目标生成模型中,以获取目标域数据;
其中,所述目标生成模型是通过未携带标签的源域训练数据、目标域有标签训练数据对目标模型进行迭代训练后获取的,所述目标模型包括第一生成模型和第一判别模型,所述目标生成模型是训练后的所述第一生成模型,所述目标生成模型的参数是通过固定所述第一判别模型的参数,通过第一损失值更新所述第一生成模型的参数获取的,所述第一损失值还用于固定所述第一生成模型的参数时,更新所述第一判别模型,所述第一损失值是根据第一判别结果获取的,所述第一判别结果是根据第一组合训练数据以及所述目标域有标签训练数据输入至所述第一判别模型中获取的,所述第一组合数据是将所述源域训练数据未携带的标签和所述第一生成结果进行组合后获取的,所述第一生成结果是将所述未携带标签的源域训练数据输入至所述第一生成模型后获取的。
20.根据权利要求19所述的装置,其特征在于,所述目标模型还包括第二生成模型,所述目标生成模型的参数具体是通过固定所述第一判别模型的参数,通过第一损失值和第二损失值更新所述第一生成模型的参数获取的,所述第二损失值是根据第二生成结果和所述未携带标签的源域训练数据之间的差异获取的,所述第二生成结果是将所述第一生成结果输入至所述第二生成模型获取的,所述第二损失值还用于更新所述第二生成模型。
21.根据权利要求19或20所述的装置,其特征在于,所述目标生成模型具体是通过未携带标签的源域训练数据、目标域有标签训练数据以及目标域无标签训练数据对目标模型进行迭代训练后获取的。
22.根据权利要求19至21任一项所述的装置,其特征在于,所述目标模型还包括第二判别模型,所述目标生成模型的参数具体是通过固定所述第一判别模型的参数,所述第一损失值和第三损失值更新所述第一生成模型的参数获取的,所述第三损失值是通过第二判别结果获取的,所述第二判别结果是将第二组合数据以及第三组合数据输入至所述第二判别模型中获取的,所述第二组合数据是将所述第一生成结果和预设标签进行组合后获取的,所述第三组合数据是将所述标域无标签训练数据和所述预设标签进行组合后获取的,所述第一损失值和所述第三损失值还用于更新所述第一判别模型。
23.根据权利要求20至22任一项所述的装置,其特征在于,还包括分类模块,
所述分类模块,用于将所述目标域训练数据输入至目标分类模型中,以获取预测结果,其中,所述目标分类模型是固定所述第一判别模型的参数,通过第四损失值更新第一分类模型获取的,所述第四损失值是通过第三判别结果获取的,所述第三判别结果是通过将第四组合数据输入至所述第一判别模型中获取的,所述第四组合数据是将所述目标域无标签训练数据和所述第一预测标签进行组合后获取的,所述第一预测标签是将所述目标域无标签训练数据输入至所述第一分类模型中获取的,所述第四损失值还用于固定所述第一生成模型的参数和所述第一分类模型的参数时,更新所述第一判别模型。
24.根据权利要求23所述的装置,其特征在于,所述目标模型还包括第二分类模型,所述目标分类模型具体是固定所述第一判别模型的参数,通过第四损失值和第五损失值更新所述第一分类模型获取的,所述第五损失值是通过第二预测标签获取的,所述第二预测标签是通过将所述第一生成结果和所述目标域有标签训练数据输入至第二分类模型中获取的,所述第四损失值和所述第五损失值还用于更新所述第二分类模型。
25.一种模型训练的装置,其特征在于,包括:
存储器,用于存储计算机可读指令;
还包括,与所述存储器耦合的处理器,用于执行所述存储器中的计算机可读指令从而执行如权利要求1至6任一项所描述的方法。
26.一种数据生成的装置,其特征在于,包括:
存储器,用于存储计算机可读指令;
还包括,与所述存储器耦合的处理器,用于执行所述存储器中的计算机可读指令从而执行如权利要求7至12任一项所描述的方法。
27.一种计算机可读存储介质,其特征在于,当指令在计算机装置上运行时,使得所述计算机装置执行如权利要求1至6任一项所描述的方法。
28.一种计算机可读存储介质,其特征在于,当指令在计算机装置上运行时,使得所述计算机装置执行如权利要求7至12任一项所描述的方法。
29.一种计算机程序产品,当在计算机上运行时,使得计算机可以执行如权利要求1至6任一所描述的方法。
30.一种计算机程序产品,当在计算机上运行时,使得计算机可以执行如权利要求7至12任一所描述的方法。
31.一种芯片,其特征在于,所述芯片与存储器耦合,用于执行所述存储器中存储的程序,以执行如权利要求1至6任一项所述的方法。
32.一种芯片,其特征在于,所述芯片与存储器耦合,用于执行所述存储器中存储的程序,以执行如权利要求7至12任一项所述的方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202011567739.2A CN112633385A (zh) | 2020-12-25 | 2020-12-25 | 一种模型训练的方法、数据生成的方法以及装置 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202011567739.2A CN112633385A (zh) | 2020-12-25 | 2020-12-25 | 一种模型训练的方法、数据生成的方法以及装置 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN112633385A true CN112633385A (zh) | 2021-04-09 |
Family
ID=75325446
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202011567739.2A Pending CN112633385A (zh) | 2020-12-25 | 2020-12-25 | 一种模型训练的方法、数据生成的方法以及装置 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN112633385A (zh) |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN115761239A (zh) * | 2023-01-09 | 2023-03-07 | 深圳思谋信息科技有限公司 | 一种语义分割方法及相关装置 |
Citations (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN108460415A (zh) * | 2018-02-28 | 2018-08-28 | 国信优易数据有限公司 | 伪标签生成模型训练方法及伪标签生成方法 |
CN108898218A (zh) * | 2018-05-24 | 2018-11-27 | 阿里巴巴集团控股有限公司 | 一种神经网络模型的训练方法、装置、及计算机设备 |
CN110148142A (zh) * | 2019-05-27 | 2019-08-20 | 腾讯科技(深圳)有限公司 | 图像分割模型的训练方法、装置、设备和存储介质 |
CN111477212A (zh) * | 2019-01-04 | 2020-07-31 | 阿里巴巴集团控股有限公司 | 内容识别、模型训练、数据处理方法、系统及设备 |
-
2020
- 2020-12-25 CN CN202011567739.2A patent/CN112633385A/zh active Pending
Patent Citations (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN108460415A (zh) * | 2018-02-28 | 2018-08-28 | 国信优易数据有限公司 | 伪标签生成模型训练方法及伪标签生成方法 |
CN108898218A (zh) * | 2018-05-24 | 2018-11-27 | 阿里巴巴集团控股有限公司 | 一种神经网络模型的训练方法、装置、及计算机设备 |
CN111477212A (zh) * | 2019-01-04 | 2020-07-31 | 阿里巴巴集团控股有限公司 | 内容识别、模型训练、数据处理方法、系统及设备 |
CN110148142A (zh) * | 2019-05-27 | 2019-08-20 | 腾讯科技(深圳)有限公司 | 图像分割模型的训练方法、装置、设备和存储介质 |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN115761239A (zh) * | 2023-01-09 | 2023-03-07 | 深圳思谋信息科技有限公司 | 一种语义分割方法及相关装置 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
Oyedotun et al. | Deep learning in vision-based static hand gesture recognition | |
CN111797893B (zh) | 一种神经网络的训练方法、图像分类系统及相关设备 | |
WO2022012407A1 (zh) | 一种用于神经网络的训练方法以及相关设备 | |
WO2022068623A1 (zh) | 一种模型训练方法及相关设备 | |
JP7178513B2 (ja) | ディープラーニングに基づく中国語単語分割方法、装置、記憶媒体及びコンピュータ機器 | |
Imani et al. | Fach: Fpga-based acceleration of hyperdimensional computing by reducing computational complexity | |
CN113449858A (zh) | 一种神经网络模型的处理方法以及相关设备 | |
CN113139664B (zh) | 一种跨模态的迁移学习方法 | |
CN108171328B (zh) | 一种神经网络处理器和采用其执行的卷积运算方法 | |
WO2022012668A1 (zh) | 一种训练集处理方法和装置 | |
CN110968235B (zh) | 信号处理装置及相关产品 | |
CN114925320B (zh) | 一种数据处理方法及相关装置 | |
US11847555B2 (en) | Constraining neural networks for robustness through alternative encoding | |
CN112633385A (zh) | 一种模型训练的方法、数据生成的方法以及装置 | |
CN116401552A (zh) | 一种分类模型的训练方法及相关装置 | |
Wu et al. | A novel method of data and feature enhancement for few-shot image classification | |
Vishnu et al. | Mobile application-based virtual assistant using deep learning | |
Kriete et al. | Models of Cognition: Neurological Possiblity Does Not Indicate Neurological Plausibility | |
Luenemann et al. | Capturing neural-networks as synchronous dataflow graphs | |
CN112817560B (zh) | 一种基于表函数的计算任务处理方法、系统及计算机可读存储介质 | |
CN116560731A (zh) | 一种数据处理方法及其相关装置 | |
CN114462526A (zh) | 一种分类模型训练方法、装置、计算机设备及存储介质 | |
CN110069770B (zh) | 一种数据处理系统、方法及计算机设备 | |
JP7103987B2 (ja) | 情報処理装置、情報処理方法、及びプログラム | |
Zhou et al. | Efficient image evidence analysis of cnn classification results |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |