CN108256561B - 一种基于对抗学习的多源域适应迁移方法及系统 - Google Patents

一种基于对抗学习的多源域适应迁移方法及系统 Download PDF

Info

Publication number
CN108256561B
CN108256561B CN201711468680.XA CN201711468680A CN108256561B CN 108256561 B CN108256561 B CN 108256561B CN 201711468680 A CN201711468680 A CN 201711468680A CN 108256561 B CN108256561 B CN 108256561B
Authority
CN
China
Prior art keywords
domain
target
source domain
path
source
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN201711468680.XA
Other languages
English (en)
Other versions
CN108256561A (zh
Inventor
林倞
陈子良
王可泽
许瑞嘉
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
National Sun Yat Sen University
Original Assignee
National Sun Yat Sen University
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by National Sun Yat Sen University filed Critical National Sun Yat Sen University
Priority to CN201711468680.XA priority Critical patent/CN108256561B/zh
Publication of CN108256561A publication Critical patent/CN108256561A/zh
Application granted granted Critical
Publication of CN108256561B publication Critical patent/CN108256561B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • G06F18/2148Generating training patterns; Bootstrap methods, e.g. bagging or boosting characterised by the process organisation or structure, e.g. boosting cascade

Abstract

本发明公开了一种基于对抗学习的多源域适应迁移方法及系统,所述方法包括如下步骤:步骤一,使用各源域数据进行预训练并初始化目标模型的表示网络和分类器;步骤二,使用多源域数据与目标域数据进行多路对抗,更新目标模型的表示网络和多路判别器;步骤三,计算每个源域与目标域之间的对抗分数;步骤四,基于各源域的分类器和对抗分数对目标域进行分类;步骤五,选取高置信度的目标域伪样本微调目标模型的表示网络和分类器;步骤六,返回步骤二,进行步骤二‑五,直至模型收敛或达到最大迭代次数时停止训练,本发明可不再依赖单一源域标签集合与目标域一致的假设,并且可有效地避免多源域适应过程中存在的负迁移现象。

Description

一种基于对抗学习的多源域适应迁移方法及系统
技术领域
本发明涉及机器学习技术领域,特别是涉及一种基于对抗学习的多源域适应迁移方法及系统。
背景技术
随着大规模数据的不断产生和依靠人力进行信息标注的困难,域适应迁移方法逐渐成为机器学习领域中一项非常重要的研究课题。域适应学习旨在适配不同领域数据间的特征分布,提升不同领域间分类器迁移后的性能表现,解决目标域数据缺乏标注信息的难题。域适应迁移方法同时也是工业界的一项关键技术手段,在人脸识别、自动驾驶、医学影像等诸多领域均有重要应用。
目前,绝大部分的域适应学习方法主要关注在单一源域的迁移过程上,并依赖于单一源域标签集合与目标域一致的假设。Yaroslav Ganin等人在文献“Domain-Adversarial Training of Neural Networks”(Journal of Machine LearningResearch,2016,17(59):1-35)中公开了一种针对图像分类的单源域适应方法,其通过引入域间分类器对源域与目标域图像的特征分布进行对抗学习,得到一种领域无关的特征表示,提高目标域图像在迁移后的分类性能。然而,该类方法在现实场景中缺乏通用性,而且无法处理源域数据标签空间与目标域不一致的情形。
此外,Hongfu Liu等人在文献“Structure-Preserved Multi-source DomainAdaptation”(In IEEE 16th International Conference on Data Mining(ICDM),pages1059–1064.IEEE,2016)中提出一种保持多源域数据整体结构的方法进行目标任务的迁移,但该类方法往往忽略了不同领域数据间的差异性,无法避免多源域适应中存在的负迁移现象。
发明内容
为克服上述现有技术存在的不足,本发明之目的在于提供一种基于对抗学习的多源域适应迁移方法及系统,以将现有的一类基于对抗学习的单源域适应过程推广到多源域适应,不再依赖单一源域标签集合与目标域一致的假设,并且可有效地避免多源域适应过程中存在的负迁移现象。
为达上述及其它目的,本发明提出一种基于对抗学习的多源域适应迁移方法,包括如下步骤:
步骤一,使用各源域数据进行预训练并初始化目标模型的表示网络和分类器;
步骤二,使用多源域数据与目标域数据进行多路对抗,更新目标模型的表示网络和多路判别器;
步骤三,计算每个源域与目标域之间的对抗分数;
步骤四,基于各源域的分类器和对抗分数对目标域进行分类;
步骤五,选取高置信度的目标域伪样本微调目标模型的表示网络和分类器;
步骤六,返回步骤二,进行步骤二-五,直至模型收敛或达到最大迭代次数时停止训练。
进一步地,步骤一进一步包括:
输入带标记的N个源域数据集以及输入无标记的目标域数据集;
使用所有的源域数据集对领域无关的表示网络F和领域相关的多路分类器C进行目标模型的预训练。
进一步地,所述使用所有的源域数据集对领域无关的表示网络F和领域相关的多路分类器C进行目标模型的预训练的步骤具体为根据如下优化目标:
Figure BDA0001531599170000031
更新目标模型中表示网络和多路分类器的参数,其中
Figure BDA0001531599170000032
表示多路分类的损失函数,
Figure BDA0001531599170000033
表示具体选取的损失函数类型,
Figure BDA0001531599170000034
表示第sj路分类器,E表示所有样本损失值的期望,F(x)表示图像x经过表示网络F后的特征编码。
进一步地,步骤二进一步包括:
使用表示网络对多源域和目标域的图像进行特征提取;
将每一源域和目标域分别组成一对,输入多路判别器网络D进行判定训练,更新目标模型的表示网络和多路判别器。
进一步地,所述多路判别器网络D的更新策略为尽可能区分开输入特征是来自源域还是目标域,表示网络的更新策略是尽可能混淆特征,使得判别器网络无法区分输入特征是来自源域还是目标域。
进一步地,于步骤二中,更新多路判别器和表示网络的损失函数使用其最小二乘表示进行优化。
进一步地,于步骤三中,累加每一路判别器的损失值作为对应源域与目标域的对抗分数。
进一步地,于步骤四中,根据步骤三获得的对抗分数以及目标模型的表示网络F和多路分类器C对目标域的样本进行分类,并赋予伪标签。
进一步地,于步骤五中,在步骤四的基础上选取置信度大于设定阈值的样本组成目标域伪样本集合,并对目标模型的多路分类器进行微调,以获取在目标域上更加有效可分的特征编码。
为达到上述目的,本发明还提供一种基于对抗学习的多源域适应迁移系统,包括:
预训练单元,用于使用各源域数据进行预训练并初始化目标模型的表示网络和分类器;
多路对抗单元,用于使用多源域数据与目标域数据进行多路对抗,更新目标模型的表示网络和多路判别器;
对抗分数计算单元,用于计算每个源域与目标域之间的对抗分数;
分类单元,用于基于各源域的分类器和对抗分数对目标域进行分类;
微调单元,用于选取高置信度的目标域伪样本微调目标模型的表示网络和分类器,并返回所述多路对抗单元进行训练,直至模型收敛或达到最大迭代次数时停止训练。
与现有技术相比,本发明将现有的单源域适应过程推广到多源域适应,使之不再依赖单一源域标签集合与目标域一致的假设,在现实场景中具有更强的通用性。此外,由于本发明基于对抗学习对不同领域间的特征进行适配,有效避免了负迁移现象的产生,比较明显地提升了域适应后的分类性能。
附图说明
图1为本发明一种基于对抗学习的多源域适应迁移方法的步骤流程图。
图2为本发明具体实施例以两个源域为例的基于对抗学习的多源域适应迁移方法的流程图。
图3为本发明具体实施例以两个源域为例的网络框架示意图。
图4为本发明具体实施例中展示两个源域(A、D)迁移到目标域(W)在域适应前后的可视化效果图。
图5为本发明一种基于对抗学习的多源域适应迁移系统的系统架构图。
具体实施方式
以下通过特定的具体实例并结合附图说明本发明的实施方式,本领域技术人员可由本说明书所揭示的内容轻易地了解本发明的其它优点与功效。本发明亦可通过其它不同的具体实例加以施行或应用,本说明书中的各项细节亦可基于不同观点与应用,在不背离本发明的精神下进行各种修饰与变更。
图1为本发明一种基于对抗学习的多源域适应迁移方法的步骤流程图,图2为本发明具体实施例之基于对抗学习的多源域适应迁移方法的流程图。如图1及图2所示,本发明一种基于对抗学习的多源域适应迁移方法,包括如下步骤:
步骤101,使用各源域数据进行预训练并初始化目标模型的表示网络和分类器。
具体地,步骤101进一步包括:
步骤S100,输入带标记的N个源域数据集,其分布表示为
Figure BDA0001531599170000062
其中sj表示第j个源域,x和y分别表示样本图像和对应标签。假定各源域的数据集合
Figure BDA0001531599170000063
采样自不同的分布,其中
Figure BDA0001531599170000064
Figure BDA0001531599170000065
分别表示来自源域sj的图像和对应标签,同时,输入无标记的目标域数据集,其分布记为pt(x,y),对应图像集合记为
Figure BDA0001531599170000066
在本发明具体实施例中,以两个源域为例,即输入源域S1和S2的图像和对应标签,输入目标域T的图像;
步骤S101,使用所有的源域数据集对领域无关的表示网络F和领域相关的多路分类器C进行目标模型的预训练,即根据如下优化目标更新目标模型中表示网络F和多路分类器C的参数:
Figure BDA0001531599170000061
其中
Figure BDA0001531599170000067
表示多路分类的损失函数,而
Figure BDA0001531599170000068
表示具体选取的损失函数类型,
Figure BDA0001531599170000069
表示第sj路分类器,E表示所有样本损失值的期望,F(x)表示图像x经过表示网络F后的特征编码。
在本发明具体实施例中,所述多源域数据的标签集合的交集总和等于目标域的标签集合,即
Figure BDA0001531599170000073
步骤102,使用多源域数据与目标域数据进行多路对抗,更新目标模型的表示网络和多路判别器。具体地,固定当前多路分类器C的参数,引入目标域图像数据进行多路对抗,步骤102进一步包括:
步骤S200,使用表示网络F对多源域和目标域的图像进行特征提取,在本发明具体实施例中,得到源域S1、S2以及目标域T的特征表示;
步骤S201,将每一源域sj和目标域t分别组成一对,例如S1和T,S2和T,输入多路判别器网络D进行判定训练,更新目标模型的表示网络和多路判别器。在本发明具体实施例中,多路判别器网络D的更新策略是尽可能区分开输入特征是来自源域还是目标域;而表示网络的更新策略是尽可能混淆特征,使得判别器网络无法区分输入特征是来自源域还是目标域。这一对抗过程使用公式表示如下:
Figure BDA0001531599170000071
其中分类损失函数
Figure BDA0001531599170000074
如公式(1)中所示(但是分类器C的参数不更新),而对抗损失函数
Figure BDA0001531599170000075
表示为:
Figure BDA0001531599170000072
其中
Figure BDA0001531599170000076
表示第sj路判别器,E表示对应损失值的期望,F(x)表示图像x经过表示网络F后的特征编码。
优选地,于步骤S201中,多路对抗过程将回传困难样本的梯度用于更新目标模型的表示网络F。具体地,在所有源域
Figure BDA0001531599170000083
中选择
Figure BDA0001531599170000084
使得
Figure BDA0001531599170000081
并回传源域
Figure BDA0001531599170000085
与目标域的对抗损失更新表示网络,其中M是当前迭代中的样本数量。
优选地,为使对抗的训练过程稳定,上述步骤102中更新多路判别器和表示网络的损失函数使用其最小二乘表示进行优化,即使用如下函数:
Figure BDA0001531599170000086
优化多路判别器,使用
Figure BDA0001531599170000087
优化表示网络。
步骤103,计算每个源域与目标域之间的对抗分数。在本发明具体实施例中,累加每一路判别器的损失值作为对应源域与目标域的对抗分数(表征域间相似性)。
步骤104,基于各源域的分类器和对抗分数对目标域进行分类。
具体地说,根据步骤103获得的对抗分数以及目标模型的表示网络F和多路分类器C对目标域的样本进行分类,并赋予伪标签。特别地,对于目标域中的第i个样本
Figure BDA0001531599170000088
目标模型将其标记为第c类标签的置信度为
Figure BDA0001531599170000082
其中
Figure BDA0001531599170000089
表示第sj路分类器将样本
Figure BDA00015315991700000810
分类为第c类标签的概率,
Figure BDA0001531599170000091
表示目标域与源域sk通过步骤103计算得到的对抗分数,
Figure BDA0001531599170000092
表示第c类标签属于源域sj时对应的第sj路分类器才会参与计算该类标签的置信度。
直观上讲,目标模型通过表示网络F对图像进行特征提取,并利用多路分类器对特征进行分类,以对抗分数作为权重对分类结果进行加权平均,对抗分数越大,表明相应的源域与目标域越相近,则该路分类器的分类结果更可靠。
步骤105,选取高置信度的目标域伪样本微调目标模型的表示网络和分类器。
在本发明具体实施例中,在步骤104的基础上选取置信度大于设定阈值的样本组成目标域伪样本集合
Figure BDA0001531599170000093
并对目标模型的多路分类器进行微调,以获取在目标域上更加有效可分的特征编码。具体地,基于优化目标:
Figure BDA0001531599170000094
更新目标模型的表示网络F和多路分类器C,
Figure BDA0001531599170000095
表示源域
Figure BDA0001531599170000096
的标签集合包含伪标签
Figure BDA0001531599170000097
时,才对相应的第
Figure BDA0001531599170000098
路分类器进行更新。
步骤106,返回步骤102,进行步骤102-105,直至模型收敛或达到最大迭代次数时停止训练。
以下将配合图2通过具体实施例来说明本发明:在本发明具体实施例中,以两个源域为例,调用开源深度学习框架Pytorch,开源机器学习库Scikit-learn中的可视化工具t-SNE,具体过程如下:
(1)源域与目标域图像的特征提取(图3左虚线框)
输入带标记的N(这里取N=2进行示意)个源域数据集(分别对应图4的A和D),其分布表示为
Figure BDA0001531599170000101
其中sj表示第j个源域,x和y分别表示样本图像和对应标签。假定各源域的数据集合
Figure BDA0001531599170000102
采样自不同的分布,其中
Figure BDA0001531599170000103
Figure BDA0001531599170000104
分别表示来自源域sj的图像和对应标签。同时,输入无标记的目标域数据集(对应图4的W),其分布记为pt(x,y),对应图像集合记为
Figure BDA0001531599170000105
在每一次迭代中,各个源域与目标域均随机采样相同数量的训练样本,并通过参数共享的表示网络F进行特征表示。
(2)源域与目标域图像特征的多路对抗(图3中虚线框)
基于上述提取的图像特征,将每一源域sj和目标域t分别组成一对,输入多路判别器网络D进行判定。判别器网络的更新策略是尽可能区分开输入特征是来自源域还是目标域;而表示网络的更新策略是尽可能混淆特征,使得判别器网络无法区分输入特征是来自源域还是目标域。
由于在对抗学习的训练过程中容易产生梯度弥散的问题,为克服这一难题,更新多路判别器和表示网络的损失函数将采用其最小二乘表示进行优化,即使用
Figure BDA0001531599170000106
优化多路判别器,使用
Figure BDA0001531599170000107
优化表示网络,其中
Figure BDA0001531599170000108
表示第sj路判别器。
由于多源域适应学习中存在负迁移的不利现象,本发明在多路对抗过程中回传困难样本的梯度用于更新目标模型的表示网络。具体地,在所有源域
Figure BDA0001531599170000109
中选择
Figure BDA00015315991700001010
使得
Figure BDA00015315991700001011
并回传源域
Figure BDA0001531599170000111
与目标域的对抗损失更新表示网络,其中M是当前迭代中的样本数量。
与此同时,本发明累加每一路判别器的损失值作为对应源域与目标域的对抗分数,用以表征域间相似性。判别器的损失值越大,则表明对应源域的特征与目标域越混淆、越相近。
(3)目标域样本的多路分类(图3右虚线框)
根据(2)获得的对抗分数以及目标模型的表示网络F和多路分类器C对目标域的样本进行分类,并赋予伪标签。特别地,对于目标域中的第i个样本
Figure BDA0001531599170000112
目标模型将其标记为第c类标签的置信度为
Figure BDA0001531599170000113
其中
Figure BDA0001531599170000114
表示第sj路分类器将样本
Figure BDA0001531599170000115
分类为第c类标签的概率,
Figure BDA0001531599170000116
表示目标域与源域sk在多路对抗过程中计算得到的对抗分数,
Figure BDA0001531599170000117
表示第c类标签属于源域sj时对应的第sj路分类器才会参与计算该类标签的置信度。直观上讲,目标模型通过表示网络F对图像进行特征提取,并利用多路分类器对特征进行分类,以对抗分数作为权重对分类结果进行加权平均,对抗分数越大,表明相应的源域与目标域越相近,则该路分类器的分类结果更可靠。在此基础上,选取置信度大于设定阈值的样本组成目标域伪样本集合
Figure BDA0001531599170000118
并对目标模型的多路分类器进行微调,以获取在目标域上更加有效可分的特征编码。
图4展示了两个源域(A、D)迁移到目标域(W)在域适应前后的可视化效果,不同的图标形状表示不同的类别。为直观显示起见,我们将两个源域与目标域的特征进行逐对展示。通过图4(3)对比图4(1)、图4(4)对比图4(2)不难发现,使用了本发明的多源域适应迁移方法后,不同类别的类间距扩大,可分性更强,进而有助于提高目标域图像的分类精度。同时图4(4)对比图4(3)可以表明,D→W的域适应效果要好于A→W,而这也与对抗分数的高低相一致,表明本发明的方法能够区分不同领域间的差异性,避免在域间适应过程中发生负迁移的不利现象。
图5为本发明一种基于对抗学习的多源域适应迁移系统的系统架构图。如图5所示,本发明一种基于对抗学习的多源域适应迁移系统,包括:
预训练单元501,用于使用各源域数据进行预训练并初始化目标模型的表示网络和分类器。
具体地,预训练单元501进一步包括:
输入模块,用于输入带标记的N个源域数据集,其分布表示为
Figure BDA0001531599170000121
其中sj表示第j个源域,x和y分别表示样本图像和对应标签。假定各源域的数据集合
Figure BDA0001531599170000122
采样自不同的分布,其中
Figure BDA0001531599170000123
Figure BDA0001531599170000124
分别表示来自源域sj的图像和对应标签,同时,输入单元还输入无标记的目标域数据集,其分布记为pt(x,y),对应图像集合记为
Figure BDA0001531599170000125
在本发明具体实施例中,以两个源域为例,即输入源域S1和S2的图像和对应标签,输入目标域T的图像;
预训练模块,用于使用所有的源域数据集对领域无关的表示网络F和领域相关的多路分类器C进行目标模型的预训练,即根据优化目标
Figure BDA0001531599170000131
更新目标模型中表示网络和多路分类器的参数,其中
Figure BDA0001531599170000132
表示选取的损失函数类型,
Figure BDA0001531599170000133
表示第sj路分类器。
在本发明具体实施例中,所述多源域数据的标签集合的交集总和等于目标域的标签集合,即
Figure BDA0001531599170000134
多路对抗单元502,用于使用多源域数据与目标域数据进行多路对抗,更新目标模型的表示网络和多路判别器。具体地,多路对抗单元502固定当前多路分类器C的参数,引入目标域图像数据进行多路对抗,多路对抗单元502进一步包括:
特征提取模块,用于使用表示网络F对多源域和目标域的图像进行特征提取,在本发明具体实施例中,得到源域S1、S2以及目标域T的特征表示;
训练更新模块,用于将每一源域sj和目标域t分别组成一对,例如S1和T,S2和T,输入多路判别器网络D进行判定训练,更新目标模型的表示网络和多路判别器。在本发明具体实施例中,多路判别器网络D的更新策略是尽可能区分开输入特征是来自源域还是目标域;而表示网络的更新策略是尽可能混淆特征,使得判别器网络无法区分输入特征是来自源域还是目标域。
优选地,于多路对抗单元502中,多路对抗过程将回传困难样本的梯度用于更新目标模型的表示网络。
优选地,为使对抗的训练过程稳定,上述多路对抗单元502中更新多路判别器和表示网络的损失函数使用其最小二乘表示进行优化。
对抗分数计算单元503,用于计算每个源域与目标域之间的对抗分数。在本发明具体实施例中,对抗分数计算单元503累加每一路判别器的损失值作为对应源域与目标域的对抗分数(表征域间相似性)。
分类单元504,用于基于各源域的分类器和对抗分数对目标域进行分类。
具体地说,根据对抗分数计算单元503获得的对抗分数以及目标模型的表示网络F和多路分类器C对目标域的样本进行分类,并赋予伪标签。
直观上讲,目标模型通过表示网络F对图像进行特征提取,并利用多路分类器对特征进行分类,以对抗分数作为权重对分类结果进行加权平均,对抗分数越大,表明相应的源域与目标域越相近,则该路分类器的分类结果更可靠。
微调单元505,用于选取高置信度的目标域伪样本微调目标模型的表示网络和分类器,并返回多路对抗单元502进行训练,直至模型收敛或达到最大迭代次数时停止训练。
在本发明具体实施例中,微调单元505在分类单元504的基础上选取置信度大于设定阈值的样本组成目标域伪样本集合
Figure BDA0001531599170000141
并对目标模型的多路分类器进行微调,以获取在目标域上更加有效可分的特征编码。
可见,本发明将现有的单源域适应过程推广到多源域适应,使之不再依赖单一源域标签集合与目标域一致的假设,在现实场景中具有更强的通用性。此外,由于本发明基于对抗学习对不同领域间的特征进行适配,有效避免了负迁移现象的产生,比较明显地提升了域适应后的分类性能。
上述实施例仅例示性说明本发明的原理及其功效,而非用于限制本发明。任何本领域技术人员均可在不违背本发明的精神及范畴下,对上述实施例进行修饰与改变。因此,本发明的权利保护范围,应如权利要求书所列。

Claims (8)

1.一种基于对抗学习的多源域适应迁移方法,包括如下步骤:
步骤一,获取多个源域的带标记的源域数据以及无标记的目标域数据,并使用各源域的源域数据对目标模型的表示网络和多路分类器进行预训练及初始化,各源域数据包括图像数据和对应标签,所述目标域数据包括图像数据;所述步骤一还包括,
输入带标记的N个源域数据集以及输入无标记的目标域数据集,
使用所有的源域数据集对领域无关的表示网络F和领域相关的多路分类器C进行目标模型的预训练,所述预训练的步骤具体为根据如下优化目标
Figure FDA0002461077390000011
更新目标模型中表示网络F和多路分类器C的参数,其中
Figure FDA0002461077390000012
表示多路分类的损失函数,
Figure FDA0002461077390000013
表示具体选取的损失函数类型,
Figure FDA0002461077390000014
表示第sj路分类器,E表示所有样本损失值的期望,F(x)表示图像x经过表示网络F后的特征编码;
步骤二,固定当前多路分类器的参数,引入目标域数据,使用所述多个源域的源域数据与目标域数据进行多路对抗,更新所述目标模型的表示网络和多路判别器;
步骤三,基于每一路判别器的损失值计算对应源域与目标域之间的对抗分数;
步骤四,基于各源域的多路分类器和对抗分数对目标域的样本进行分类,赋予伪标签;
步骤五,选取高置信度的目标域伪样本微调所述目标模型的表示网络和多路分类器,获取在目标域上更加有效可分的特征编码;
步骤六,返回步骤二,进行步骤二-五,直至模型收敛或达到最大迭代次数时停止训练。
2.如权利要求1所述的一种基于对抗学习的多源域适应迁移方法,其特征在于,步骤二进一步包括:
使用表示网络对多源域和目标域的图像进行特征提取;
将每一源域和目标域分别组成一对,输入多路判别器网络D进行判定训练,更新目标模型的表示网络和多路判别器。
3.如权利要求2所述的一种基于对抗学习的多源域适应迁移方法,其特征在于:所述多路判别器网络D的更新策略为尽可能区分开输入特征是来自源域还是目标域,表示网络的更新策略是尽可能混淆特征,使得判别器网络无法区分输入特征是来自源域还是目标域。
4.如权利要求3所述的一种基于对抗学习的多源域适应迁移方法,其特征在于:于步骤二中,更新多路判别器和表示网络的损失函数使用其最小二乘表示进行优化。
5.如权利要求4所述的一种基于对抗学习的多源域适应迁移方法,其特征在于:于步骤三中,累加每一路判别器的损失值作为对应源域与目标域的对抗分数。
6.如权利要求1所述的一种基于对抗学习的多源域适应迁移方法,其特征在于:于步骤四中,根据步骤三获得的对抗分数以及目标模型的表示网络F和多路分类器C对目标域的样本进行分类,并赋予伪标签。
7.如权利要求1所述的一种基于对抗学习的多源域适应迁移方法,其特征在于:于步骤五中,在步骤四的基础上选取置信度大于设定阈值的样本组成目标域伪样本集合,并对目标模型的多路分类器进行微调,以获取在目标域上更加有效可分的特征编码。
8.一种基于对抗学习的多源域适应迁移系统,包括:
预训练单元,用于获取多个源域的带标记的源域数据以及无标记的目标域数据,并使用各源域的源域数据对化目标模型的表示网络和多路分类器进行预训练及初始化,各源域数据包括图像数据和对应标签,所述目标域数据包括图像数据,所述预训练的步骤具体为根据如下优化目标:
Figure FDA0002461077390000031
更新目标模型中表示网络F和多路分类器C的参数,其中
Figure FDA0002461077390000032
表示多路分类的损失函数,
Figure FDA0002461077390000033
表示具体选取的损失函数类型,
Figure FDA0002461077390000034
表示第sj路分类器,E表示所有样本损失值的期望,F(x)表示图像x经过表示网络F后的特征编码
多路对抗单元,用于通过固定当前多路分类器的参数,引入目标域数据,使用所述多个源域的源域数据与目标域数据进行多路对抗,更新所述目标模型的表示网络和多路判别器;
对抗分数计算单元,用于基于每一路判别器的损失值计算对应源域与目标域之间的对抗分数;
分类单元,用于基于各源域的多路分类器和对抗分数对目标域的样本进行分类,赋予伪标签;
微调单元,用于选取高置信度的目标域伪样本微调所述目标模型的表示网络和多路分类器,获取在目标域上更加有效可分的特征编码,并返回所述多路对抗单元进行训练,直至模型收敛或达到最大迭代次数时停止训练。
CN201711468680.XA 2017-12-29 2017-12-29 一种基于对抗学习的多源域适应迁移方法及系统 Active CN108256561B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN201711468680.XA CN108256561B (zh) 2017-12-29 2017-12-29 一种基于对抗学习的多源域适应迁移方法及系统

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN201711468680.XA CN108256561B (zh) 2017-12-29 2017-12-29 一种基于对抗学习的多源域适应迁移方法及系统

Publications (2)

Publication Number Publication Date
CN108256561A CN108256561A (zh) 2018-07-06
CN108256561B true CN108256561B (zh) 2020-06-16

Family

ID=62724910

Family Applications (1)

Application Number Title Priority Date Filing Date
CN201711468680.XA Active CN108256561B (zh) 2017-12-29 2017-12-29 一种基于对抗学习的多源域适应迁移方法及系统

Country Status (1)

Country Link
CN (1) CN108256561B (zh)

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US11875270B2 (en) 2020-12-08 2024-01-16 International Business Machines Corporation Adversarial semi-supervised one-shot learning

Families Citing this family (34)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN109710636B (zh) * 2018-11-13 2022-10-21 广东工业大学 一种基于深度迁移学习的无监督工业系统异常检测方法
CN109523018B (zh) * 2019-01-08 2022-10-18 重庆邮电大学 一种基于深度迁移学习的图片分类方法
CN109948648B (zh) * 2019-01-31 2023-04-07 中山大学 一种基于元对抗学习的多目标域适应迁移方法及系统
CN110569985A (zh) * 2019-03-09 2019-12-13 华南理工大学 基于在线和离线决策集成学习的在线异构迁移学习的方法
CN110348579B (zh) * 2019-05-28 2023-08-29 北京理工大学 一种领域自适应迁移特征方法及系统
CN110188829B (zh) * 2019-05-31 2022-01-28 北京市商汤科技开发有限公司 神经网络的训练方法、目标识别的方法及相关产品
US11113829B2 (en) * 2019-08-20 2021-09-07 GM Global Technology Operations LLC Domain adaptation for analysis of images
CN110674849B (zh) * 2019-09-02 2021-06-18 昆明理工大学 基于多源域集成迁移的跨领域情感分类方法
CN110807194A (zh) * 2019-10-17 2020-02-18 新华三信息安全技术有限公司 一种webshell检测方法及装置
CN111523680B (zh) * 2019-12-23 2023-05-12 中山大学 一种基于Fredholm学习和对抗学习的域适应方法
CN111209935B (zh) * 2019-12-26 2022-03-25 武汉安视感知科技有限公司 基于自适应域转移的无监督目标检测方法及系统
CN111161239B (zh) * 2019-12-27 2024-02-27 上海联影智能医疗科技有限公司 医学图像分析方法、装置、存储介质及计算机设备
CN111275092B (zh) * 2020-01-17 2022-05-13 电子科技大学 一种基于无监督域适应的图像分类方法
CN111340819B (zh) * 2020-02-10 2023-09-12 腾讯科技(深圳)有限公司 图像分割方法、装置和存储介质
CN111310852B (zh) * 2020-03-08 2022-08-12 桂林电子科技大学 一种图像分类方法及系统
CN111444952B (zh) * 2020-03-24 2024-02-20 腾讯科技(深圳)有限公司 样本识别模型的生成方法、装置、计算机设备和存储介质
CN111444951B (zh) * 2020-03-24 2024-02-20 腾讯科技(深圳)有限公司 样本识别模型的生成方法、装置、计算机设备和存储介质
CN111382568B (zh) * 2020-05-29 2020-09-11 腾讯科技(深圳)有限公司 分词模型的训练方法和装置、存储介质和电子设备
CN111723691B (zh) * 2020-06-03 2023-10-17 合肥的卢深视科技有限公司 一种三维人脸识别方法、装置、电子设备及存储介质
CN111610768B (zh) * 2020-06-10 2021-03-19 中国矿业大学 基于相似度多源域迁移学习策略的间歇过程质量预测方法
CN111950608B (zh) * 2020-06-12 2021-05-04 中国科学院大学 一种基于对比损失的域自适应物体检测方法
CN111882055B (zh) * 2020-06-15 2022-08-05 电子科技大学 一种基于CycleGAN与伪标签的目标检测自适应模型的构建方法
CN111860677B (zh) * 2020-07-29 2023-11-21 湖南科技大学 一种基于部分域对抗的滚动轴承迁移学习故障诊断方法
CN112215405B (zh) * 2020-09-23 2024-04-16 国网甘肃省电力公司电力科学研究院 一种基于dann域适应学习的非侵入式居民用电负荷分解方法
CN112766334B (zh) * 2021-01-08 2022-06-21 厦门大学 一种基于伪标签域适应的跨域图像分类方法
CN112906857B (zh) * 2021-01-21 2024-03-19 商汤国际私人有限公司 一种网络训练方法及装置、电子设备和存储介质
CN112836795B (zh) * 2021-01-27 2023-08-18 西安理工大学 一种多源非均衡域自适应方法
CN112990387B (zh) * 2021-05-17 2021-07-20 腾讯科技(深圳)有限公司 模型优化方法、相关设备及存储介质
CN113468323B (zh) * 2021-06-01 2023-07-18 成都数之联科技股份有限公司 争议焦点类别及相似判断方法及系统及装置及推荐方法
CN113486827B (zh) * 2021-07-13 2023-12-08 上海中科辰新卫星技术有限公司 基于域对抗与自监督的多源遥感影像迁移学习方法
CN113762466B (zh) * 2021-08-02 2023-06-20 国网河南省电力公司信息通信公司 电力物联网流量分类方法及装置
CN114841137A (zh) * 2022-04-18 2022-08-02 北京百度网讯科技有限公司 模型获取方法、装置、电子设备及存储介质
CN114998602B (zh) * 2022-08-08 2022-12-30 中国科学技术大学 基于低置信度样本对比损失的域适应学习方法及系统
CN116580255B (zh) * 2023-07-13 2023-09-26 华南师范大学 多源域多目标域自适应方法、装置与电子设备

Citations (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN103649294A (zh) * 2011-04-29 2014-03-19 贝克顿·迪金森公司 多路分类系统和方法
CN106056043A (zh) * 2016-05-19 2016-10-26 中国科学院自动化研究所 基于迁移学习的动物行为识别方法和装置
CN107103364A (zh) * 2017-03-28 2017-08-29 上海大学 一种基于多源域的任务拆分迁移学习预测方法

Patent Citations (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN103649294A (zh) * 2011-04-29 2014-03-19 贝克顿·迪金森公司 多路分类系统和方法
CN106056043A (zh) * 2016-05-19 2016-10-26 中国科学院自动化研究所 基于迁移学习的动物行为识别方法和装置
CN107103364A (zh) * 2017-03-28 2017-08-29 上海大学 一种基于多源域的任务拆分迁移学习预测方法

Non-Patent Citations (2)

* Cited by examiner, † Cited by third party
Title
Pattern classification and clustering: A review of partially supervised learning approaches;Friedhelm Schwenker;《Elsevier》;20141231;第4-14页 *
平行学习—机器学习的一个新型理论框架;李力;《自动化学报》;20170131;第1-7页 *

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US11875270B2 (en) 2020-12-08 2024-01-16 International Business Machines Corporation Adversarial semi-supervised one-shot learning

Also Published As

Publication number Publication date
CN108256561A (zh) 2018-07-06

Similar Documents

Publication Publication Date Title
CN108256561B (zh) 一种基于对抗学习的多源域适应迁移方法及系统
Shu et al. Transferable curriculum for weakly-supervised domain adaptation
Hao et al. An end-to-end architecture for class-incremental object detection with knowledge distillation
Sukhbaatar et al. Learning from noisy labels with deep neural networks
US10719780B2 (en) Efficient machine learning method
Grubb et al. Speedboost: Anytime prediction with uniform near-optimality
EP3767536A1 (en) Latent code for unsupervised domain adaptation
CN114492574A (zh) 基于高斯均匀混合模型的伪标签损失无监督对抗域适应图片分类方法
Bochinski et al. Deep active learning for in situ plankton classification
Wang et al. Towards realistic predictors
CN113469186B (zh) 一种基于少量点标注的跨域迁移图像分割方法
CN111832511A (zh) 一种增强样本数据的无监督行人重识别方法
CN113128478B (zh) 模型训练方法、行人分析方法、装置、设备及存储介质
De Rosa et al. Online open world recognition
CN114090780B (zh) 一种基于提示学习的快速图片分类方法
CN110705591A (zh) 一种基于最优子空间学习的异构迁移学习方法
CN113035311A (zh) 一种基于多模态注意力机制的医学图像报告自动生成方法
CN104376308B (zh) 一种基于多任务学习的人体动作识别方法
CN104680193A (zh) 基于快速相似性网络融合算法的在线目标分类方法与系统
WO2015146113A1 (ja) 識別辞書学習システム、識別辞書学習方法および記録媒体
CN110991500A (zh) 一种基于嵌套式集成深度支持向量机的小样本多分类方法
CN114863176A (zh) 基于目标域移动机制的多源域自适应方法
Nikpour et al. Deep reinforcement learning in human activity recognition: A survey
Zhang et al. Long-tailed classification with gradual balanced loss and adaptive feature generation
Mund et al. Active online confidence boosting for efficient object classification

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
CB03 Change of inventor or designer information

Inventor after: Lin Jing

Inventor after: Chen Ziliang

Inventor after: Wang Keze

Inventor after: Xu Ruijia

Inventor before: Lin Jing

Inventor before: Chen Ziliang

Inventor before: Wang Keze

Inventor before: Xu Ruijia

CB03 Change of inventor or designer information
GR01 Patent grant
GR01 Patent grant