CN111144565B - 基于一致性训练的自监督领域自适应深度学习方法 - Google Patents
基于一致性训练的自监督领域自适应深度学习方法 Download PDFInfo
- Publication number
- CN111144565B CN111144565B CN201911372719.7A CN201911372719A CN111144565B CN 111144565 B CN111144565 B CN 111144565B CN 201911372719 A CN201911372719 A CN 201911372719A CN 111144565 B CN111144565 B CN 111144565B
- Authority
- CN
- China
- Prior art keywords
- self
- training
- training set
- task
- consistency
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
- G06F18/2155—Generating training patterns; Bootstrap methods, e.g. bagging or boosting characterised by the incorporation of unlabelled data, e.g. multiple instance learning [MIL], semi-supervised techniques using expectation-maximisation [EM] or naïve labelling
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
Abstract
本发明公开了一种基于一致性训练的自监督领域自适应深度学习方法。该方法首先构建一个数据增强变换集合,对每一个变换定义一个标签。针对源域样本和其对应的类别标签,构建分类任务;对源域和目标域样本应用所述数据增强变换,通过最小化预测该变换类别的误差,构建自监督学习任务;针对源域和目标域样本,通过最小化变换后的样本和原始样本在分类任务上的输出的KL散度(Kullback‑Leibler Divergence),构建一致性训练任务;构建一个多任务学习网络,将所述的分类、自监督学习和一致性训练任务进行联合训练。该方法无需对目标域样本进行标注,能有效地学习目标域特征表示,提升目标域上样本分类和识别的性能。本申请还公开了一种领域自适应深度学习可读存储介质,同样具有上述有益效果。
Description
技术领域
本发明属于新一代信息技术领域,具体涉及领域自适应深度学习图像分类方法及可读存储介质。
背景技术
机器学习特别是深度学习模型通常需要大量的标注样本来进行监督学习,比如图像、文本等的分类和识别需要收集大量的样本,同时还需要标注每一个样本的对应的类别。当模型在标注数据上训练完成之后,将其应用到测试数据上。当测试数据与训练数据具有相同的分布时,监督学习是一种非常有效的方法。然而实际应用中通常会出现测试数据与训练数据分布不同的情况,从而使得模型在测试数据集上的性能急剧下降。
领域自适应(domain adaptation)是解决上述由于训练和测试数据分布差异引起模型性能下降问题的一类技术方法。通常将训练数据集称为源领域,测试数据集称为目标领域。源领域的数据是带有标注信息的,而目标领域的数据通常是没有标注信息的。领域自适应技术旨在将源领域的监督信息迁移到目标领域,提升目标领域上任务的性能。目前基于深度神经网络的领域自适应学习大多数是通过领域对抗训练来学习跨领域不变的特征表示,从而提升目标领域上的任务的性能的。然而领域对抗训练需要优化一对相互对抗的目标函数,训练过程的收敛比较困难,很难得到最优的模型。
发明内容
本发明要解决的技术问题是领域对抗训练时优化一对相互对抗的目标函数,训练过程的收敛困难,难以获取适合的模型。
本发明为解决上述技术问题,提供基于一致性训练的自监督领域自适应深度学习图像分类方法,该方法提供一种非对抗式的训练方法,以提高目标领域上任务的性能,具体的技术方案如下:
S1:构建一个多任务学习深度神经网络,包含一个参数为θe的特征提取网E,参数为θm主分类网M,以及参数为θp的图像增强变换预测网P;
S2:将源域图像xs和其类别标签y组成分类任务训练集Ds={(xs,y)|y∈[0,C]},其中C是类别数;
S3:构建一组图像增强变换集合G={g(x,r)|r∈[0,R)},每一个图像增强变换g(x,r)对应一个变换类别标签r;
S6:针对步骤S2中分类任务训练集Ds以及步骤S1中的特征提取网E和主分类网M,构建有监督学习任务,其训练损失函数为:
S7:针对步骤S4中自监督学习训练集D*以及步骤S1中的特征提取网E和数据据增强变换预测网P,构建自监督学习任务,其训练损失函数为:
S8:针对步骤S5中一致性训练集Dc以及步骤S1中的特征提取网E和主分类网M,构建一致性学习任务,通过KL散度(Kullback-Leibler Divergence)距离构建其训练损失函数:
其中DKL为KL散度距离;
S9:将步骤S6、S7以及S8中的损失函数加权求和,得到总的训练损失函数:
Ltotal=LM+λ1LP+λ2LC (4)
其中λ1和λ2为加权系数,可通过交叉验证选取合适的值;
S10:通过最小化步骤S9中的损失函数Ltotal,得到训练后优化的参数θe、θp以及θm;
S11:对目标域测试样本,使用步骤S10中优化后的参数,通过公式
y~t=argmax[M(E(xt))] (5)
得到其预测的样本类别,实现深度学习模型在目标域上的领域自适应。
本发明还提供一种可读存储介质,该可读存储介质上存储有程序,当该程序被处理器执行时能够实现步骤S1-S11的基于一致性训练的自监督领域自适应深度学习方法。
相对于现有技术,本发明的有效收益如下:
1、本发明提供的领域自适应深度学习方法,通过图像增强来构建一致性训练和自监督训练,通过多任务学习框架联合源领域标注样本的监督学习来学习适应目标领域的特征表示,从而实现领域自适应。
2、本发明该不依赖人工标注来构建目标领域训练集,通过目标域样本的一致性训练和自监督学习,建立适应目标领域任务的特征表示,从而提高目标领域上任务的性能。
3、本发明还提供一种领域自适应深度学习可读存储介质,该可读存储介质上存储有程序,当该程序被处理器执行时同样具有上述有益效果。
附图说明
图1是本发明实施例的基于一致性训练的自监督领域自适应深度学习训练过程的流程示意图。
具体实施方式
以下结合说明书附图和图像分类领域自适应学习实例对本发明作进一步的详细描述,但并不因此而限制本发明的保护范围。
图1给出了本发明实施例的基于一致性训练的自监督领域自适应深度学习训练流程示意图。以图像分类领域自适应学习主要包括以下步骤:
S1:构建一个多任务学习深度神经网络,包含一个参数为θe的特征提取网E,参数为θm图像分类网M,以及参数为θp的图像增强变换预测网P;
本实施例中S1中的图像增强变换采用图像旋转操作。
S2:将源域图像xs和其类别标签y组成分类任务训练集Ds={(xs,y)|y∈[0,C]},其中C是图像类别数目;
S3:构建一组基于图像旋转的图像增强变换集合G={g(x,r)|r∈[0,R)},每一个图像增强变换g(x,r)对应一个变换类别标签r,本实例采用三种不同角度旋转(即R=3),分别为90°、180°和270°旋转,对应的变换标签为0,1和2;
S6:针对步骤S2中分类任务训练集Ds以及步骤S1中的特征提取网E和主分类网M,构建有监督学习任务,其训练损失函数为:
S7:针对步骤S4中自监督学习训练集D*以及步骤S1中的特征提取网E和基于图像旋转的数据据增强变换预测网P,构建自监督学习任务,其训练损失函数为:
S8:针对步骤S5中一致性训练集Dc以及步骤S1中的特征提取网E和主分类网M,构建一致性学习任务,通过KL散度(Kullback-Leibler Divergence)距离构建其训练损失函数:
其中DKL为KL散度距离;
S9:将步骤S6、S7以及S8中的损失函数加权求和,得到总的训练损失函数:
Ltotal=LM+λ1LP+λ2LC (4)
其中λ1和λ2为加权系数,可通过交叉验证选取合适的值,本实例种λ1和λ2分别为0.7和1.0;
S10:通过SGD或Adam等优化算法最小化步骤S9中的损失函数Ltotal,得到训练后优化的参数θe、θp以及θm;
S11:对目标域测试图像,使用步骤S10中优化后的参数,通过公式
得到其预测的样本类别,从而实现深度学习模型在目标域上的领域自适应。
下面对本申请实施例提供的可读存储介质进行介绍,下文描述的可读存储介质与上文描述的领域自适应深度学习方法可相互对应参照。
本申请公开的一种可读存储介质,其上存储有程序,程序被处理器执行时实现领域自适应深度学习方法的步骤。
所属领域的技术人员可以清楚地了解到,为描述的方便和简洁,上述描述的可读存储介质中程序的流程,可以参考前述方法实施例中的对应过程,在此不再赘述。
为了更好地说明本发明的技术效果,以图像分类任务为例,发明人在PACS数据集上进行了图像分类领域自适应学习的实验。PACS是一个公开数据集,包含来自4个域(ArtPaintings、Cartoon、Sketches以及Photo)的图像,每个域包含7个类别。测试结果如下所示:
表1:
Art paint. | cartoon | sketches | photo | Avg. | |
SRC | 79.3 | 76.8 | 64.4 | 96.4 | 79.2 |
Jigsaw | 84.9 | 83.9 | 69.0 | 93.9 | 82.9 |
Rot | 88.7 | 86.4 | 74.9 | 98.0 | 87.0 |
Ours | 89.9 | 87.7 | 75.1 | 97.9 | 87.7 |
表1中从第二列到最第五列,每一列分别表示以该列列名对应的域作为目标域,其他三个域合并作为源域,最后一列Avg.表示平均分类准确率。表1对比了三种方法,包括只采用源领域样本训练的方法SRC,以及两种基于自监督的领域自适应学习方法Jigsaw和Rot。从表1中可以看出,只采用源领域样本训练的方法,由于没有做领域自适应学习,在目标域上性能最差。采用基于自监督学习的领域自适应学习,可以得到较好的自适应学习效果。本发明通过一致性训练的自监督学习,进一步提高了领域自适应性能,得到了更适应目标领域的特征表示,因此在目标领域上的分类准确率达到了最高。
虽然本发明已通过优选实施例进行了描述,然而本发明并非局限于这里描述的实施例,在不脱离本发明范围的情况下还包括所做出的各种改变以及变化。
Claims (2)
1.基于一致性训练的自监督领域自适应深度学习图像分类方法,其特征在于,该方法包括:
S1:构建一个多任务学习深度神经网络,包含一个参数为θe的图像特征提取网E,参数为θm主分类网M,以及参数为θp的图像增强变换预测网P;
S2:将源域图像xs和其类别标签y组成分类任务训练集Ds={(xs,y)|y∈[0,C]},其中C是类别数;
S3:构建一组基于图像旋转的图像增强变换集合G={g(x,r)|r∈[0,R)},每一个图像增强变换g(x,r)对应一个变换类别标签r,采用三种不同角度旋转,即R=3,分别进行90°、180°和270°旋转,对应的变换标签为0,1和2;
S6:针对步骤S2中分类任务训练集Ds以及步骤S1中的特征提取网E和主分类网M,构建有监督学习任务,其训练损失函数为:
S7:针对步骤S4中自监督学习训练集D*以及步骤S1中的特征提取网E和数据据增强变换预测网P,构建自监督学习任务,其训练损失函数为:
S8:针对步骤S5中一致性训练集Dc以及步骤S1中的特征提取网E和主分类网M,构建一致性学习任务,通过KL散度(Kullback-Leibler Divergence)距离构建其训练损失函数:
其中DKL为KL散度距离;
S9:将步骤S6、S7以及S8中的损失函数加权求和,得到总的训练损失函数:
Ltotal=LM+λ1LP+λ2LC (4)
其中λ1和λ2为加权系数,可通过交叉验证选取合适的值;
S10:通过最小化步骤S9中的损失函数Ltotal,得到训练后优化的参数θe、θp以及θm;
S11:对目标域测试图像,使用步骤S10中优化后的参数,通过公式
得到其预测的图像类别,实现深度学习模型在目标域上的领域自适应。
2.一种可读存储介质,该可读存储介质上存储有程序,其特征在于,当该程序被处理器执行时能够实现如权利要求1所述的基于一致性训练的自监督领域自适应深度学习图像分类方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN201911372719.7A CN111144565B (zh) | 2019-12-27 | 2019-12-27 | 基于一致性训练的自监督领域自适应深度学习方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN201911372719.7A CN111144565B (zh) | 2019-12-27 | 2019-12-27 | 基于一致性训练的自监督领域自适应深度学习方法 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN111144565A CN111144565A (zh) | 2020-05-12 |
CN111144565B true CN111144565B (zh) | 2020-10-27 |
Family
ID=70520810
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN201911372719.7A Active CN111144565B (zh) | 2019-12-27 | 2019-12-27 | 基于一致性训练的自监督领域自适应深度学习方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN111144565B (zh) |
Families Citing this family (7)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US11586982B2 (en) * | 2019-09-18 | 2023-02-21 | Samsung Electronics Co., Ltd. | Electronic and atomic structure computation utilizing machine learning |
US11537898B2 (en) | 2019-10-02 | 2022-12-27 | Samsung Electronics Co., Ltd. | Generative structure-property inverse computational co-design of materials |
CN112529913A (zh) * | 2020-12-14 | 2021-03-19 | 北京达佳互联信息技术有限公司 | 图像分割模型训练方法、图像处理方法及装置 |
CN112712003B (zh) * | 2020-12-25 | 2022-07-26 | 华南理工大学 | 一种用于骨骼动作序列识别的联合标签数据增强方法 |
CN113313166B (zh) * | 2021-05-28 | 2022-07-26 | 华南理工大学 | 基于特征一致性学习的船舶目标自动标注方法 |
CN113792758B (zh) * | 2021-08-18 | 2023-11-07 | 中国矿业大学 | 一种基于自监督学习和聚类的滚动轴承故障诊断方法 |
CN114490950B (zh) * | 2022-04-07 | 2022-07-12 | 联通(广东)产业互联网有限公司 | 编码器模型的训练方法及存储介质、相似度预测方法及系统 |
Citations (6)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN103729648A (zh) * | 2014-01-07 | 2014-04-16 | 中国科学院计算技术研究所 | 领域自适应模式识别方法及系统 |
CN107103364A (zh) * | 2017-03-28 | 2017-08-29 | 上海大学 | 一种基于多源域的任务拆分迁移学习预测方法 |
CN107451616A (zh) * | 2017-08-01 | 2017-12-08 | 西安电子科技大学 | 基于深度半监督迁移学习的多光谱遥感图像地物分类方法 |
CN108962224A (zh) * | 2018-07-19 | 2018-12-07 | 苏州思必驰信息科技有限公司 | 口语理解和语言模型联合建模方法、对话方法及系统 |
CN109919209A (zh) * | 2019-02-26 | 2019-06-21 | 中国人民解放军军事科学院国防科技创新研究院 | 一种领域自适应深度学习方法及可读存储介质 |
CN110163286A (zh) * | 2019-05-24 | 2019-08-23 | 常熟理工学院 | 一种基于混合池化的领域自适应图像分类方法 |
Family Cites Families (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US20180330205A1 (en) * | 2017-05-15 | 2018-11-15 | Siemens Aktiengesellschaft | Domain adaptation and fusion using weakly supervised target-irrelevant data |
-
2019
- 2019-12-27 CN CN201911372719.7A patent/CN111144565B/zh active Active
Patent Citations (6)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN103729648A (zh) * | 2014-01-07 | 2014-04-16 | 中国科学院计算技术研究所 | 领域自适应模式识别方法及系统 |
CN107103364A (zh) * | 2017-03-28 | 2017-08-29 | 上海大学 | 一种基于多源域的任务拆分迁移学习预测方法 |
CN107451616A (zh) * | 2017-08-01 | 2017-12-08 | 西安电子科技大学 | 基于深度半监督迁移学习的多光谱遥感图像地物分类方法 |
CN108962224A (zh) * | 2018-07-19 | 2018-12-07 | 苏州思必驰信息科技有限公司 | 口语理解和语言模型联合建模方法、对话方法及系统 |
CN109919209A (zh) * | 2019-02-26 | 2019-06-21 | 中国人民解放军军事科学院国防科技创新研究院 | 一种领域自适应深度学习方法及可读存储介质 |
CN110163286A (zh) * | 2019-05-24 | 2019-08-23 | 常熟理工学院 | 一种基于混合池化的领域自适应图像分类方法 |
Non-Patent Citations (3)
Title |
---|
Domain Adaptation of Deformable Part-Based Models;Jiaolong Xu 等;《IEEE TRANSACTIONS ON PATTERN ANALYSIS AND MACHINE INTELLIGENCE》;20140603;第36卷(第12期);2367-2380 * |
Jes'us Andr'es-Ferrer 等.Efficient language model adaptation with Noise Contrastive Estimation and Kullback-Leibler regularization.《Interspeech 2018》.2018,3368-3372. * |
基于最小化最大平均差异损失的无监督领域自适应;桂存斌;《中国优秀硕士学位论文全文数据库 信息科技辑》;20190815(第8期);I138-1139 * |
Also Published As
Publication number | Publication date |
---|---|
CN111144565A (zh) | 2020-05-12 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN111144565B (zh) | 基于一致性训练的自监督领域自适应深度学习方法 | |
CN108399428B (zh) | 一种基于迹比准则的三元组损失函数设计方法 | |
CN108776975B (zh) | 一种基于半监督特征和滤波器联合学习的视觉跟踪方法 | |
CN111127364B (zh) | 图像数据增强策略选择方法及人脸识别图像数据增强方法 | |
CN113837205B (zh) | 用于图像特征表示生成的方法、设备、装置和介质 | |
Siivola et al. | Good practices for Bayesian optimization of high dimensional structured spaces | |
CN113010683B (zh) | 基于改进图注意力网络的实体关系识别方法及系统 | |
CN114341886A (zh) | 用于识别无线电技术的神经网络 | |
CN107292323B (zh) | 用于训练混合模型的方法和设备 | |
Khakzar et al. | Learning interpretable features via adversarially robust optimization | |
KR20220024990A (ko) | L2TL(Learning to Transfer Learn)을 위한 프레임워크 | |
Suzuki et al. | Adversarial transformations for semi-supervised learning | |
Zhu et al. | Structured sparse low-rank regression model for brain-wide and genome-wide associations | |
WO2017188048A1 (ja) | 作成装置、作成プログラム、および作成方法 | |
Lonij et al. | Open-world visual recognition using knowledge graphs | |
KR101700030B1 (ko) | 사전 정보를 이용한 영상 물체 탐색 방법 및 이를 수행하는 장치 | |
Zhou et al. | Multi-kernel graph fusion for spectral clustering | |
de Boer et al. | Non-Gaussian Normative Modelling With Hierarchical Bayesian Regression | |
CN115797642B (zh) | 基于一致性正则化与半监督领域自适应图像语义分割算法 | |
CN108009586B (zh) | 封顶概念分解方法及图像聚类方法 | |
CN113537389B (zh) | 基于模型嵌入的鲁棒图像分类方法和装置 | |
Öfverstedt et al. | INSPIRE: Intensity and spatial information-based deformable image registration | |
CN111967499B (zh) | 基于自步学习的数据降维方法 | |
US11915120B2 (en) | Flexible parameter sharing for multi-task learning | |
CN114399025A (zh) | 一种图神经网络解释方法、系统、终端以及存储介质 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |