CN116128047A - 一种基于对抗网络的迁移学习方法 - Google Patents
一种基于对抗网络的迁移学习方法 Download PDFInfo
- Publication number
- CN116128047A CN116128047A CN202211579284.5A CN202211579284A CN116128047A CN 116128047 A CN116128047 A CN 116128047A CN 202211579284 A CN202211579284 A CN 202211579284A CN 116128047 A CN116128047 A CN 116128047A
- Authority
- CN
- China
- Prior art keywords
- target
- source
- representing
- risk value
- task
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Granted
Links
- 238000013508 migration Methods 0.000 title claims abstract description 31
- 238000000034 method Methods 0.000 title claims abstract description 20
- 230000005012 migration Effects 0.000 title claims abstract description 16
- 238000013526 transfer learning Methods 0.000 claims abstract description 16
- 230000014509 gene expression Effects 0.000 claims description 16
- 238000004364 calculation method Methods 0.000 claims description 12
- 230000006870 function Effects 0.000 claims description 6
- 238000012549 training Methods 0.000 claims description 6
- 238000010276 construction Methods 0.000 claims description 4
- 230000008485 antagonism Effects 0.000 abstract description 3
- 238000010801 machine learning Methods 0.000 abstract description 3
- 230000005764 inhibitory process Effects 0.000 abstract description 2
- 230000009286 beneficial effect Effects 0.000 description 4
- 230000008901 benefit Effects 0.000 description 3
- 238000011160 research Methods 0.000 description 2
- 238000013459 approach Methods 0.000 description 1
- 230000007423 decrease Effects 0.000 description 1
- 230000007547 defect Effects 0.000 description 1
- 238000003745 diagnosis Methods 0.000 description 1
- 238000010586 diagram Methods 0.000 description 1
- 230000008030 elimination Effects 0.000 description 1
- 238000003379 elimination reaction Methods 0.000 description 1
- 235000003642 hunger Nutrition 0.000 description 1
- 206010027175 memory impairment Diseases 0.000 description 1
- 230000037351 starvation Effects 0.000 description 1
- 238000012360 testing method Methods 0.000 description 1
- 238000012546 transfer Methods 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/40—Extraction of image or video features
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/764—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/82—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- General Physics & Mathematics (AREA)
- Evolutionary Computation (AREA)
- General Health & Medical Sciences (AREA)
- Health & Medical Sciences (AREA)
- Artificial Intelligence (AREA)
- Software Systems (AREA)
- Computing Systems (AREA)
- Multimedia (AREA)
- Biomedical Technology (AREA)
- Molecular Biology (AREA)
- General Engineering & Computer Science (AREA)
- Computational Linguistics (AREA)
- Mathematical Physics (AREA)
- Biophysics (AREA)
- Data Mining & Analysis (AREA)
- Life Sciences & Earth Sciences (AREA)
- Medical Informatics (AREA)
- Databases & Information Systems (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Image Analysis (AREA)
- Information Retrieval, Db Structures And Fs Structures Therefor (AREA)
- Testing, Inspecting, Measuring Of Stereoscopic Televisions And Televisions (AREA)
Abstract
本发明公开了一种基于对抗网络的迁移学习方法,属于机器学习技术领域,包括获取源域数据集和目标域数据集;分别利用标签1和标签0对应预标记源域数据集和目标域数据集中的样本图像;基于预标记后的样本图像,构建对抗迁移学习框架;基于对抗迁移学习框架,完成迁移学习;本发明通过引入对抗性学习从而实现对过拟合的抑制,解决了免由于目标任务的数据集不足而导致的过拟合问题的问题。
Description
技术领域
本发明属于机器学习技术领域,尤其涉及一种基于对抗网络的迁移学习方法。
背景技术
迁移学习是机器学习中的一个重要研究问题。它的目标是将源任务中学到的知识或模式应用到不同但相关的目标任务中。迁移学习的好处在于它可以将在大规模数据上训练的模型中学到的知识迁移到不同的任务场景中,以适应不同情况的需要。更重要的是,传统的模型训练依赖于大量的数据,但在某些情况下,数据不足是不可避免的,而迁移学习可以解决这个问题。鉴于迁移学习的诸多优点,它已广泛应用于文本分类、图像分类、语音识别和故障诊断等多个领域。
然而,迁移学习也面临一些问题,如灾难性遗忘和负迁移等等。针对灾难性遗忘,现有技术提出了DELTA和L2-SP等正则化方法,也有现有技术对负迁移进行了处理和解决,但当目标任务数据集极小时还面临严重的过拟合问题,且当前还缺少专门的研究。
发明内容
针对现有技术中的上述不足,本发明提供的一种基于对抗网络的迁移学习方法,通过引入对抗性学习从而实现对过拟合的抑制,解决了由于目标任务的数据集不足而导致的过拟合问题。
为了达到上述发明目的,本发明采用的技术方案为:
本发明提供一种基于对抗网络的迁移学习方法,包括如下步骤:
S1、获取源域数据集和目标域数据集;
S2、分别利用标签1和标签0对应预标记源域数据集和目标域数据集中的样本图像;
S3、基于预标记后的样本图像,构建对抗迁移学习框架;
S4、基于对抗迁移学习框架,完成迁移学习。
本发明的有益效果为:本发明提供的一种基于对抗网络的迁移学习方法,提出了一种在目标任务数据集不足情况下的对抗迁移学习框架,由一个特征提取器、一个鉴别器和两个分类器组成,通过提取器和鉴别器之间的对抗性训练,取出更多的域不变特征,从而抑制了过拟合问题。
进一步地,所述步骤S3包括如下步骤:
S31、利用特征提取器提取源域数据集和目标域数据集中预标记后的各样本图像的图像特征;
S32、将各图像特征均分别传输至标签0分类器、标签1分类器和鉴别器;
S33、利用标签1分类器根据图像特征,得到源域分类损失;
S34、利用标签0分类器根据图像特征,得到目标域分类损失;
S35、利用鉴别器根据图像特征,得到判别损失,完成构建对抗迁移学习框架。
采用上述进一步方案的有益效果为:通过据样本的预标记标签,将提取出的图像特征发送到相应的分类器进行分类,得到最终的图像识别结果,实现对抗迁移学习框架的构建。
进一步地,所述鉴别器的计算表达式如下:
采用上述进一步方案的有益效果为:提供鉴别器的计算表达式,相较于传统生成对抗网络更能精确计算距离,避免梯度消失和模式坍塌的问题。
进一步地,所述步骤S4包括如下步骤:
S41、将学习问题进行风险值最小化,并基于对抗迁移学习框架,得到源任务和目标任务的风险值;
S42、分别定义源任务和目标任务的损失为源任务经验风险值和目标任务经验风险值;
S43、基于源任务和目标任务的风险值,得到组合风险值;
S44、基于源任务经验风险值和目标任务经验风险值,得到组合风险的经验风险值;
S45、基于组合风险值和组合风险的经验风险值,完成迁移学习。
采用上述进一步方案的有益效果为:基于对抗学习框架通过加权分类代价值,计算得到组合风险值和组合风险的经验风险值,实现迁移学习。
进一步地,所述步骤S41中源任务和目标任务的风险值的计算表达式分别如下:
其中,x′表示特征空间X的样本图像,y′表示样本空间Y的样本图像,和分别表示源任务fs和目标任务ft的风险值,fs和ft分别表示源任务和目标任务的目标预测函数,表示源任务在联合分布Pr(x′,y′)上的预测损失,表示目标任务在联合分布Pr(x′,y′)上的预测损失,L(·)表示负对数似然损失函数。
进一步地,所述步骤S42中源任务经验风险值和目标任务经验风险值的计算表达式分别如下:
其中,losss和losst分别表示源任务经验风险值和目标任务经验风险值,表示定义,和分别表示源任务和目标任务的损失,和分别表示源域数据集和目标域数据集的训练样本树,x′i表示第i个特征空间X的样本图像输入,和分别表示源域数据集的第i个样本图像和目标域数据集的第i个样本图像。
进一步地,所述步骤S43中组合风险值的计算表达式如下:
其中,Rcombined表示组合风险值,p((x′,y′)∈(X×Ys))表示样本图像组(x′,y′)来自于特征空间X和源域数据集Ys的概率,p((x′,y′)∈(X×Yt))表示样本图像组(x′,y′)来自于特征空间X和目标域数据集Ys的概率,ws表示样本来自源域数据集的概率系数,wt表示样本来自目标域数据集的概率系数。
进一步地,所述步骤S44中组合风险的经验风险值的计算表达式如下:
附图说明
图1为本发明实施例中基于对抗网络的迁移学习方法的步骤流程图。
图2为本发明实施例中对抗迁移学习框架的示意图。
具体实施方式
下面对本发明的具体实施方式进行描述,以便于本技术领域的技术人员理解本发明,但应该清楚,本发明不限于具体实施方式的范围,对本技术领域的普通技术人员来讲,只要各种变化在所附的权利要求限定和确定的本发明的精神和范围内,这些变化是显而易见的,一切利用本发明构思的发明创造均在保护之列。
如图1所示,在本发明的一个实施例中,本发明提供一种基于对抗网络的迁移学习方法,包括如下步骤:
S1、获取源域数据集和目标域数据集;
S2、分别利用标签1和标签0对应预标记源域数据集和目标域数据集中的样本图像;
所述源域数据集中的样本图像采用标签1预标记,目标域数据集中的样本图像采用标签0预标记;
S3、基于预标记后的样本图像,构建对抗迁移学习框架;
如图2所示,所述步骤S3包括如下步骤:
S31、利用特征提取器提取源域数据集和目标域数据集中预标记后的各样本图像的图像特征;
S32、将各图像特征均分别传输至标签0分类器、标签1分类器和鉴别器;
S33、利用标签1分类器根据图像特征,得到源域分类损失;
S34、利用标签0分类器根据图像特征,得到目标域分类损失;
S35、利用鉴别器根据图像特征,得到判别损失,完成构建对抗迁移学习框架;
所述鉴别器的计算表达式如下:
S4、基于对抗迁移学习框架,完成迁移学习;
所述步骤S4包括如下步骤:
S41、将学习问题进行风险值最小化,并基于对抗迁移学习框架,得到源任务和目标任务的风险值;
所述步骤S41中源任务和目标任务的风险值的计算表达式分别如下:
其中,x′表示特征空间X的样本图像,y′表示样本空间Y的样本图像,和分别表示源任务fs和目标任务ft的风险值,fs和ft分别表示源任务和目标任务的目标预测函数,表示源任务在联合分布Pr(x′,y′)上的预测损失,表示目标任务在联合分布Pr(x′,y′)上的预测损失,L(·)表示负对数似然损失函数;
S42、分别定义源任务和目标任务的损失为源任务经验风险值和目标任务经验风险值;
所述步骤S42中源任务经验风险值和目标任务经验风险值的计算表达式分别如下:
其中,losss和losst分别表示源任务经验风险值和目标任务经验风险值,表示定义,和分别表示源任务和目标任务的损失,和分别表示源域数据集和目标域数据集的训练样本树,xi′表示第i个特征空间X的样本图像输入,yi′s和yi′t分别表示源域数据集的第i个样本图像和目标域数据集的第i个样本图像;
S43、基于源任务和目标任务的风险值,得到组合风险值;
所述步骤S43中组合风险值的计算表达式如下:
其中,Rcombined表示组合风险值,p((x′,y′)∈(X×Ys))表示样本图像组(x′,y′)来自于特征空间X和源域数据集Ys的概率,p((x′,y′)∈(X×Yt))表示样本图像组(x′,y′)来自于特征空间X和目标域数据集Ys的概率,ws表示样本来自源域数据集的概率系数,wt表示样本来自目标域数据集的概率系数;
S44、基于源任务经验风险值和目标任务经验风险值,得到组合风险的经验风险值;
所述步骤S44中组合风险的经验风险值的计算表达式如下:
S45、基于组合风险值和组合风险的经验风险值, 完成迁移学习。
本实施例中使用EMNIST数字数据集作为源域,并使用EMNIST字母数据集作为目标域,数字数据集中每个类的样本数固定在1000,在字母数据集中,每个类的样本数量取值有5种情况,分别为20、50、100、150和200;
本实施例中对抗迁移学习框架的目标是通过大量的手写数字样本来提高少样本情况下的手写字母识别性能;为了评估本方法的性能,本实施例中将试验结果与几种流行的算法进行了比较,包括基线、微调、普通的对抗网络以及BSS;本实施例中采用的基线模型是直接在手写字母数据集上进行训练的;
不同模型在字母数据集样本量不同情况下的训练结果如表1所示:
表1
根据表1能够得到,在不同的样本量下,本发明所提出的对抗迁移学习框架的性能都优于其他四种比较模型,这表明,当目标任务的数据集非常小时,本发明所提出的对抗迁移学习框架能够抑制过拟合问题;此外,随着字母数据集样本量的减小,本方案的优势变得越来越明显。特别是当每个类的样本数为20时,本方案提出的迁移学习框架与其他比较模型的准确率差异最大。
Claims (8)
1.一种基于对抗网络的迁移学习方法,其特征在于,包括如下步骤:
S1、获取源域数据集和目标域数据集;
S2、分别利用标签1和标签0对应预标记源域数据集和目标域数据集中的样本图像;
S3、基于预标记后的样本图像,构建对抗迁移学习框架;
S4、基于对抗迁移学习框架,完成迁移学习。
2.根据权利要求1所述的一种基于对抗网络的迁移学习方法,其特征在于,所述步骤S3包括如下步骤:
S31、利用特征提取器提取源域数据集和目标域数据集中预标记后的各样本图像的图像特征;
S32、将各图像特征均分别传输至标签0分类器、标签1分类器和鉴别器;
S33、利用标签1分类器根据图像特征,得到源域分类损失;
S34、利用标签0分类器根据图像特征,得到目标域分类损失;
S35、利用鉴别器根据图像特征,得到判别损失,完成构建对抗迁移学习框架。
4.根据权利要求3所述的一种基于对抗网络的迁移学习方法,其特征在于,所述步骤S4包括如下步骤:
S41、将学习问题进行风险值最小化,并基于对抗迁移学习框架,得到源任务和目标任务的风险值;
S42、分别定义源任务和目标任务的损失为源任务经验风险值和目标任务经验风险值;
S43、基于源任务和目标任务的风险值,得到组合风险值;
S44、基于源任务经验风险值和目标任务经验风险值,得到组合风险的经验风险值;
S45、基于组合风险值和组合风险的经验风险值,完成迁移学习。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202211579284.5A CN116128047B (zh) | 2022-12-08 | 2022-12-08 | 一种基于对抗网络的迁移学习方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202211579284.5A CN116128047B (zh) | 2022-12-08 | 2022-12-08 | 一种基于对抗网络的迁移学习方法 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN116128047A true CN116128047A (zh) | 2023-05-16 |
CN116128047B CN116128047B (zh) | 2023-11-14 |
Family
ID=86296454
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202211579284.5A Active CN116128047B (zh) | 2022-12-08 | 2022-12-08 | 一种基于对抗网络的迁移学习方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN116128047B (zh) |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117275220A (zh) * | 2023-08-31 | 2023-12-22 | 云南云岭高速公路交通科技有限公司 | 基于非完备数据的山区高速公路实时事故风险预测方法 |
Citations (9)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US20130132315A1 (en) * | 2010-07-22 | 2013-05-23 | Jose Carlos Principe | Classification using correntropy |
CN107895177A (zh) * | 2017-11-17 | 2018-04-10 | 南京邮电大学 | 一种保持图像分类稀疏结构的迁移分类学习方法 |
CN112131967A (zh) * | 2020-09-01 | 2020-12-25 | 河海大学 | 基于多分类器对抗迁移学习的遥感场景分类方法 |
AU2020103905A4 (en) * | 2020-12-04 | 2021-02-11 | Chongqing Normal University | Unsupervised cross-domain self-adaptive medical image segmentation method based on deep adversarial learning |
CN113361566A (zh) * | 2021-05-17 | 2021-09-07 | 长春工业大学 | 用对抗性学习和判别性学习来迁移生成式对抗网络的方法 |
US20210383265A1 (en) * | 2018-09-28 | 2021-12-09 | Nec Corporation | Empirical risk estimation system, empirical risk estimation method, and empirical risk estimation program |
US20210390355A1 (en) * | 2020-06-13 | 2021-12-16 | Zhejiang University | Image classification method based on reliable weighted optimal transport (rwot) |
US20220058839A1 (en) * | 2018-12-31 | 2022-02-24 | Oregon Health & Science University | Translation of images of stained biological material |
CN114492574A (zh) * | 2021-12-22 | 2022-05-13 | 中国矿业大学 | 基于高斯均匀混合模型的伪标签损失无监督对抗域适应图片分类方法 |
-
2022
- 2022-12-08 CN CN202211579284.5A patent/CN116128047B/zh active Active
Patent Citations (9)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US20130132315A1 (en) * | 2010-07-22 | 2013-05-23 | Jose Carlos Principe | Classification using correntropy |
CN107895177A (zh) * | 2017-11-17 | 2018-04-10 | 南京邮电大学 | 一种保持图像分类稀疏结构的迁移分类学习方法 |
US20210383265A1 (en) * | 2018-09-28 | 2021-12-09 | Nec Corporation | Empirical risk estimation system, empirical risk estimation method, and empirical risk estimation program |
US20220058839A1 (en) * | 2018-12-31 | 2022-02-24 | Oregon Health & Science University | Translation of images of stained biological material |
US20210390355A1 (en) * | 2020-06-13 | 2021-12-16 | Zhejiang University | Image classification method based on reliable weighted optimal transport (rwot) |
CN112131967A (zh) * | 2020-09-01 | 2020-12-25 | 河海大学 | 基于多分类器对抗迁移学习的遥感场景分类方法 |
AU2020103905A4 (en) * | 2020-12-04 | 2021-02-11 | Chongqing Normal University | Unsupervised cross-domain self-adaptive medical image segmentation method based on deep adversarial learning |
CN113361566A (zh) * | 2021-05-17 | 2021-09-07 | 长春工业大学 | 用对抗性学习和判别性学习来迁移生成式对抗网络的方法 |
CN114492574A (zh) * | 2021-12-22 | 2022-05-13 | 中国矿业大学 | 基于高斯均匀混合模型的伪标签损失无监督对抗域适应图片分类方法 |
Non-Patent Citations (2)
Title |
---|
ZHUN DENG 等: "Adversarial Training Helps Transfer Learning via Better Representations", 35TH CONFERENCE ON NEURAL INFORMATION PROCESSING SYSTEMS, pages 1 - 13 * |
袁珑 等: "面向目标检测的对抗样本综述", 中国图象图形学报, pages 2873 - 2896 * |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117275220A (zh) * | 2023-08-31 | 2023-12-22 | 云南云岭高速公路交通科技有限公司 | 基于非完备数据的山区高速公路实时事故风险预测方法 |
Also Published As
Publication number | Publication date |
---|---|
CN116128047B (zh) | 2023-11-14 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN109886121A (zh) | 一种遮挡鲁棒的人脸关键点定位方法 | |
CN109993201B (zh) | 一种图像处理方法、装置和可读存储介质 | |
CN109993100B (zh) | 基于深层特征聚类的人脸表情识别的实现方法 | |
CN113963165B (zh) | 一种基于自监督学习的小样本图像分类方法及系统 | |
CN105184772A (zh) | 一种基于超像素的自适应彩色图像分割方法 | |
CN116701725B (zh) | 基于深度学习的工程师人员数据画像处理方法 | |
CN116128047B (zh) | 一种基于对抗网络的迁移学习方法 | |
CN117197904B (zh) | 人脸活体检测模型的训练方法、人脸活体检测方法及装置 | |
Chen et al. | Offline handwritten digits recognition using machine learning | |
CN111191033A (zh) | 一种基于分类效用的开集分类方法 | |
Sahoo et al. | Indian sign language recognition using skin color detection | |
CN117611838A (zh) | 一种基于自适应超图卷积网络的多标签图像分类方法 | |
CN117435982A (zh) | 一种多维度快速识别网络水军的方法 | |
CN105844299B (zh) | 一种基于词袋模型的图像分类方法 | |
CN109145749B (zh) | 一种跨数据集的面部表情识别模型构建及识别方法 | |
CN111401485A (zh) | 实用的纹理分类方法 | |
CN116051924A (zh) | 一种图像对抗样本的分治防御方法 | |
CN116091828A (zh) | 一种染色体图像可解释分析方法、装置、设备及存储介质 | |
CN111931767B (zh) | 一种基于图片信息度的多模型目标检测方法、装置、系统及存储介质 | |
Tomar et al. | A Comparative Analysis of Activation Function, Evaluating their Accuracy and Efficiency when Applied to Miscellaneous Datasets | |
CN113297376A (zh) | 基于元学习的法律案件风险点识别方法及系统 | |
CN113255768A (zh) | 一种提升卷积神经网络鲁棒性能的方法 | |
CN112200216B (zh) | 汉字识别方法、装置、计算机设备和存储介质 | |
CN113610121B (zh) | 一种跨域任务深度学习识别方法 | |
CN110211149B (zh) | 一种基于背景感知的尺度自适应核相关滤波跟踪方法 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |