CN111160553B - 一种新的领域自适应学习方法 - Google Patents
一种新的领域自适应学习方法 Download PDFInfo
- Publication number
- CN111160553B CN111160553B CN201911342565.7A CN201911342565A CN111160553B CN 111160553 B CN111160553 B CN 111160553B CN 201911342565 A CN201911342565 A CN 201911342565A CN 111160553 B CN111160553 B CN 111160553B
- Authority
- CN
- China
- Prior art keywords
- classification
- task
- target field
- model
- loss
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/084—Backpropagation, e.g. using gradient descent
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N20/00—Machine learning
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Software Systems (AREA)
- Physics & Mathematics (AREA)
- Mathematical Physics (AREA)
- Artificial Intelligence (AREA)
- Data Mining & Analysis (AREA)
- Evolutionary Computation (AREA)
- Computing Systems (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Medical Informatics (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Health & Medical Sciences (AREA)
- Life Sciences & Earth Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- General Health & Medical Sciences (AREA)
- Molecular Biology (AREA)
- Image Analysis (AREA)
Abstract
提本发明属于新一代信息技术领域自适应学习技术领域,提出了一种新的领域自适应学习方法,该方法面向图像分类任务,不需要复杂的对抗学习,而是通过一个目标领域图像旋转预测的辅助分类任务和目标领域无标记样本插值后预测结果一致性约束,构建多任务学习模型,最终学习得到适用于目标领域的特征和适用于目标领域数据分布的分类模型。本发明不依赖目标领域样本标注的情况下学习得到适用于目标领域的分类器,大大降低了在测试数据分布发生变化时手工标注样本的压力。本发明结合目标领域样本的插值一致性先验和目标领域无标记样本的旋转角度预测这一辅助任务进行多任务学习,既能学习到适用于目标领域的特征,又能确保分类边界在目标领域数据分布中处在合适位置,能够有效提高目标领域的分类性能。
Description
技术领域
本发明涉及领域自适应深度学习领域,特别是面向图像分类的领域自适应深度学习方法。
背景技术
目前大多数深度学习方法采用监督学习,通过手工标记大量的样本进行模型训练。但是,手工标注样本十分耗费体力,成本高昂。此外,标记的训练样本和真实的测试样本之间很有可能存在分布不同的问题,这种情况下,训练的模型在测试数据上的性能往往会急剧下降。
领域自适应学习就是一种为了解决由于训练数据和测试数据的分布不同导致机器学习性能下降而提出了一种迁移学习方法。领域自适应学习利用源领域的标注数据学习得到目标领域依然适用的模型。根据目标领域数据是否有标注信息,领域自适应学习可以分为有监督领域自适应学习、半监督领域自适应学习和无监督领域自适应学习。无监督领域自适应学习由于完全不依赖目标领域数据标注信息而应用更加广泛。近年来,深度学习快速发展并在计算机视觉领域取得空前成功。最近提出的领域自适应学习方法也大都采用深度神经网络模型,这些深度领域自适应学习方法可以分为两类,一类是基于最小化差异(discrepancy)的方法,这些方法通过最小化源领域和目标领域的特征之间的差异实现领域不变特征学习。另一类方法是基于对抗学习(adversarial learning)的方法,这类方法通过最小最大化博弈,学习一个领域判别器实现对源领域和目标领域的鉴别,同时学习一个特征提取器(生成器)迷惑之前的领域判别器,当最小最大化优化达到均衡时可以实现领域特征的对齐。这两类方法存在的问题是优化目标和训练过程较为复杂。
自监督学习是近年来发展迅速的一类机器学习方法,它通过设置不依赖手工标注的辅助任务,学习得到适用于下游任务的特征。文献(Revisiting Self-SupervisedVisual Representation Learning,Alexander Kolesnikov,CVPR2019)证明,自监督学习是一种有效的特征学习方法。基于自监督学习的思想,文献(Self-Supervised DomainAdaptation for Computer Vision Tasks,Jiaolong Xu,IEEE Access 2019(7):156694-156706)和专利(201910139916.8)提出一种自监督领域自适应学习方法,利用目标领域的图像旋转预测这一辅助任务学习适用于目标领域的特征,能够有效提升模型在目标领域数据上的性能。
文献(Self-Ensembling For Visual Domain Adaptation,Geoff French,ICLR18)提出了一种基于自集成的领域自适应方法,这种方法利用训练过程中不同迭代获得的模型的参数均值(自集成)作为教师模型,同时对目标领域的无标记样本进行随机增广,利用增广后样本在教师模型和学生模型上的预测的一致性作为监督信号,以学习适用于目标领域的模型。
上述文献记载的方法挖掘了目标领域无标记样本的自监督信息,通过辅助任务构建多任务学习系统,能够学习到适用于目标领域的特征,但是这两种方法没有显式地考虑目标领域的聚类假设,也就是在目标领域的数据分布中,相近的样本很可能具有相同的类别,导致学习得到的分类边界可能存在不合理现象。
发明内容
本发明的目的是解决自监督领域自适应学习方法中欠缺对目标领域的聚类假设,从而造成学习得到的分类边界不合理的技术问题。
为达到上述目的解决上述技术问题,本发明提出一种新的领域自适应学习方法,该方法的技术方案包括如下步骤:
S1.准备源领域有标记样本集Ds(x,y)和目标领域无标记样本集Dt(x);
S4.构建由源领域有监督分类任务、目标领域无标记样本插值一致性任务和目标领域样本旋转预测任务组成的多任务学习模型并在Ds(x,y)和Dt(x)上进行训练,以获取主任务分类模型的最优参数θ*;
S4.1.确定训练的迭代次数T、移动平均系数α、[0,1]之间的随机分布Q;
S4.2.初始化网络参数Θ={θv,θc,θa},初始化主任务分类模型θ={θv,θc}的移动均值:θ′∶=θ;
S4.3.利用随机梯度下降法进行迭代,更新模型参数;
S4.3.2.计算源领域小批量样本的主任务分类损失
其中损失函数可采用交叉熵损失
进行计算。
S4.3.4.利用主分类网络的均值教师模型计算目标领域样本的伪标记
S4.3.5.从随机分布Q中采样插值系数λ;
S4.3.6.计算样本和预测的插值,样本插值的计算方法为:
预测的伪标记插值结果为:
S4.3.7.计算插值一致性损失
具体可采用均值平方误差
进行计算。
S4.3.9.计算目标领域样本的辅助分类任务损失
其中损失函数可采用交叉熵损失
进行计算。
S4.3.10.根据主任务分类损失、插值一致性损失和辅助任务分类损失计算总损失:
S4.3.11计算总损失L相对模型参数Θ的梯度;
S4.3.12.更新主任务分类模型参数的移动平均值
θ′∶=αθ′+(1-α)θ; (12)
S4.3.13.利用随机梯度下降法更新模型参数Θ;
与现有技术相比,本发明有效收益在于:
(1)本发明在不依赖目标领域样本标注的情况下学习得到适用于目标领域的分类器,大大降低了在测试数据分布发生变化时手工标注样本的压力。
(2)本发明结合目标领域样本的插值一致性先验和目标领域无标记样本的旋转角度预测这一辅助任务进行多任务学习,既能学习到适用于目标领域的特征,又能确保分类边界在目标领域数据分布中处在合适位置,能够有效提高目标领域的分类性能。
附图说明
图1是本发明的流程示意图;
图2是本发明中主任务分类模型和辅助任务分类模型示意图;
图3是本发明中多任务学习的损失函数示意图。
具体实施方式
本发明提出一种新的领域自适应学习方法,该方法面向图像分类任务,不需要复杂的对抗学习,而是通过一个目标领域图像旋转预测的辅助分类任务和目标领域无标记样本插值后预测结果一致性约束,构建多任务学习模型,最终学习得到适用于目标领域的特征和适用于目标领域数据分布的分类模型。
下面结合附图和实施例对本发明作进一步的详细描述,本实例采用MNIST数据集作为源领域数据集,MNIST数据集为来自NIST(National Institute of Standards andTechnology,美国国家标准与技术研究所)的手写数据集,包括0~9一共10个类别,图像为分辨率为28*28;采用USPS(US Postal Servers,美国邮政服务)手写数据集作为目标领域数据集,该数据集同样包含0~9共10个类别,图像分辨率为16*16。本发明的实现流程如附图1所示。
第一步,收集源领域有标记训练样本集Ds(x,y)和目标领域无标记训练样本集Dt(x),其中标记y采用one-hot向量表示,为了使神经网络分类模型能有同时应用于源领域和目标领域,需要对图像进行预处理,因此将USPS数据集的图像尺寸利用双线性插值上采样至与MNIST数据集一致;
第二步,如图2所示,根据主分类任务中输入图像的尺寸和类别数等特点,构造深度卷积神经网络分类模型,本实施例中,输入图像分辨率较小,类别为0~9一共10类手写数字,任务相对较为简单,所以可以采取层数较少的网络模型,例如可以采用7层的LeNet5网络,前6层为特征提取网络(参数为θv),最后一层为10类softmax输出层(参数为θc),并将该级联网络模型记为fθ(x),其中θ={θv,θc};
第四步,构建由源领域有监督分类任务、目标领域无标记样本插值一致性任务和目标领域样本旋转预测任务组成的多任务学习模型并在Ds(x,y)和Dt(x)上进行训练,获取主任务分类模型最优参数该步骤通过以下分步骤实现。
1、确定训练的迭代次数T,移动平均系数α,[0,1]之间的随机分布Q;
2、随机初始化模型参数Θ={θv,θc,θa},初始化θ的移动均值:θ′∶=θ;
3、利用随机梯度下降法进行T次迭代,更新模型参数Θ;
上述第3小步中,需要通过T次迭代对模型参数进行训练,训练的损失函数计算如图3所示,每次迭代t包括如下步骤:
2)利用前向传播计算源领域小批量样本的主任务分类损失
可选地,主任务分类损失可以采用交叉熵损失函数
其中M为主分类任务的总类别数,对于MNIST手写数字识别,M=10;
4)利用主分类模型的均值教师模型计算目标领域样本的伪标记
此处采用均值教师模型的作用是相比当前迭代的模型,均值教师模型可以获得更加温和的正则;
5)从随机分布Q采样插值系数λ;
6)计算样本和伪标记的插值,样本插值的计算方法为:
伪标记的插值结果为:
7)根据样本插值和伪标记的插值结果计算插值一致性损失:
可选地,一致性损失采用均值平方误差
9)利用旋转后的图像在辅助分类网络上前向传播,计算辅助分类任务损失
可选地,辅助分类任务损失可以采用交叉熵损失函数
10)根据主任务分类损失、插值一致性损失和辅助任务分类损失计算总损失
11)利用反向传播算法计算总损失相对模型参数Θ的梯度
12)更新主任务模型参数的移动平均值
θ′∶=αθ′+(1-α)θ (12)
13)利用随机梯度下降法更新模型参数Θ,参数更新方式为:
其中γt为当前迭代t的学习率。
第五步,选取最优模型应用于目标领域的图像并计算分类结果。
虽然本发明已通过实施例进行了描述,然而本发明并非局限于这里所描述的实施例,在不脱离本发明所做出的各种改变以及变化仍属于本发明的范围。
Claims (5)
1.一种新的领域自适应学习方法,其特征在于,包含如下步骤:
S1.准备源领域有标记样本集Ds(x,y)和目标领域无标记样本集Dt(x);
S4.构建由源领域有监督分类任务、目标领域无标记样本插值一致性任务和目标领域样本旋转预测任务组成的多任务学习模型并在Ds(x,y)和Dt(x)上进行训练,以获取主任务分类模型的最优参数θ*;
S4.1.确定训练的迭代次数T,移动平均系数α,[0,1]之间的随机分布Q;
S4.2.初始化网络参数Θ={θv,θc,θa},初始化主任务分类模型θ={θv,θc}的移动均值:θ'∶=θ;
S4.3.利用随机梯度下降法进行迭代,更新模型参数;
S4.3.2.计算源领域小批量样本的主任务分类损失:
S4.3.4.利用主分类网络的均值教师模型计算目标领域样本的伪标记
S4.3.5.从随机分布Q采样插值系数λ;
S4.3.6.计算样本和预测的插值,样本插值的计算方法为:
预测的伪标记插值结果为:
S4.3.7.计算插值一致性损失
S4.3.9.计算目标领域样本的辅助分类任务损失
S4.3.10.根据主任务分类损失、插值一致性损失和辅助任务分类损失计算总损失:
S4.3.11计算总损失L相对模型参数Θ的梯度;
S4.3.12.更新主任务分类模型参数的移动平均值
θ′:=αθ′+(1-α)θ; (12)
S4.3.13.利用随机梯度下降法更新模型参数Θ;
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN201911342565.7A CN111160553B (zh) | 2019-12-23 | 2019-12-23 | 一种新的领域自适应学习方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN201911342565.7A CN111160553B (zh) | 2019-12-23 | 2019-12-23 | 一种新的领域自适应学习方法 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN111160553A CN111160553A (zh) | 2020-05-15 |
CN111160553B true CN111160553B (zh) | 2022-10-25 |
Family
ID=70558212
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN201911342565.7A Active CN111160553B (zh) | 2019-12-23 | 2019-12-23 | 一种新的领域自适应学习方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN111160553B (zh) |
Families Citing this family (7)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN111797935B (zh) * | 2020-07-13 | 2023-10-31 | 扬州大学 | 基于群体智能的半监督深度网络图片分类方法 |
CN112116441B (zh) * | 2020-10-13 | 2024-03-12 | 腾讯科技(深圳)有限公司 | 金融风险分类模型的训练方法、分类方法、装置及设备 |
CN112288004A (zh) * | 2020-10-28 | 2021-01-29 | 香港中文大学(深圳) | 一种无需一致性约束的半监督方法及移动终端 |
GB2608344A (en) | 2021-01-12 | 2022-12-28 | Zhejiang Lab | Domain-invariant feature-based meta-knowledge fine-tuning method and platform |
CN112364945B (zh) * | 2021-01-12 | 2021-04-16 | 之江实验室 | 一种基于域-不变特征的元-知识微调方法及平台 |
CN112949786B (zh) * | 2021-05-17 | 2021-08-06 | 腾讯科技(深圳)有限公司 | 数据分类识别方法、装置、设备及可读存储介质 |
CN114220016B (zh) * | 2022-02-22 | 2022-06-03 | 山东融瓴科技集团有限公司 | 面向开放场景下的无人机航拍图像的域自适应识别方法 |
Family Cites Families (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN109919209B (zh) * | 2019-02-26 | 2020-06-19 | 中国人民解放军军事科学院国防科技创新研究院 | 一种领域自适应深度学习方法及可读存储介质 |
CN110175982B (zh) * | 2019-04-16 | 2021-11-02 | 浙江大学城市学院 | 一种基于目标检测的缺陷检测方法 |
CN110580496A (zh) * | 2019-07-11 | 2019-12-17 | 南京邮电大学 | 一种基于熵最小化的深度迁移学习系统及方法 |
-
2019
- 2019-12-23 CN CN201911342565.7A patent/CN111160553B/zh active Active
Also Published As
Publication number | Publication date |
---|---|
CN111160553A (zh) | 2020-05-15 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN111160553B (zh) | 一种新的领域自适应学习方法 | |
CN109919108B (zh) | 基于深度哈希辅助网络的遥感图像快速目标检测方法 | |
CN108596053B (zh) | 一种基于ssd和车辆姿态分类的车辆检测方法和系统 | |
CN110414377B (zh) | 一种基于尺度注意力网络的遥感图像场景分类方法 | |
CN107180426B (zh) | 基于可迁移的多模型集成的计算机辅助肺结节分类装置 | |
CN109523013B (zh) | 基于浅层卷积神经网络的空气颗粒物污染程度估计方法 | |
CN111126488B (zh) | 一种基于双重注意力的图像识别方法 | |
CN109671070B (zh) | 一种基于特征加权和特征相关性融合的目标检测方法 | |
CN108876796A (zh) | 一种基于全卷积神经网络和条件随机场的道路分割系统及方法 | |
CN110633708A (zh) | 一种基于全局模型和局部优化的深度网络显著性检测方法 | |
CN114283287B (zh) | 基于自训练噪声标签纠正的鲁棒领域自适应图像学习方法 | |
CN109743642B (zh) | 基于分层循环神经网络的视频摘要生成方法 | |
Lin et al. | Ru-net: Regularized unrolling network for scene graph generation | |
CN112347970B (zh) | 一种基于图卷积神经网络的遥感影像地物识别方法 | |
CN109446894B (zh) | 基于概率分割及高斯混合聚类的多光谱图像变化检测方法 | |
CN111985581A (zh) | 一种基于样本级注意力网络的少样本学习方法 | |
Isobe et al. | Deep convolutional encoder-decoder network with model uncertainty for semantic segmentation | |
CN112364791B (zh) | 一种基于生成对抗网络的行人重识别方法和系统 | |
CN110443257B (zh) | 一种基于主动学习的显著性检测方法 | |
CN104091038A (zh) | 基于大间隔分类准则的多示例学习特征加权方法 | |
CN106056165A (zh) | 一种基于超像素关联性增强Adaboost分类学习的显著性检测方法 | |
CN116110022A (zh) | 基于响应知识蒸馏的轻量化交通标志检测方法及系统 | |
CN112084897A (zh) | 一种gs-ssd的交通大场景车辆目标快速检测方法 | |
CN115995040A (zh) | 一种基于多尺度网络的sar图像小样本目标识别方法 | |
CN113989256A (zh) | 遥感图像建筑物的检测模型优化方法及检测方法、装置 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |