CN110210486B - 一种基于素描标注信息的生成对抗迁移学习方法 - Google Patents
一种基于素描标注信息的生成对抗迁移学习方法 Download PDFInfo
- Publication number
- CN110210486B CN110210486B CN201910401740.9A CN201910401740A CN110210486B CN 110210486 B CN110210486 B CN 110210486B CN 201910401740 A CN201910401740 A CN 201910401740A CN 110210486 B CN110210486 B CN 110210486B
- Authority
- CN
- China
- Prior art keywords
- network
- depth
- image
- edge
- sketch
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
- 238000000034 method Methods 0.000 title claims abstract description 40
- 238000013526 transfer learning Methods 0.000 title claims description 7
- 230000011218 segmentation Effects 0.000 claims abstract description 106
- 238000002372 labelling Methods 0.000 claims abstract description 62
- 239000011159 matrix material Substances 0.000 claims abstract description 56
- 238000012549 training Methods 0.000 claims abstract description 48
- 238000013508 migration Methods 0.000 claims abstract description 37
- 230000005012 migration Effects 0.000 claims abstract description 11
- 230000006870 function Effects 0.000 claims description 53
- 238000011478 gradient descent method Methods 0.000 claims description 10
- 238000010276 construction Methods 0.000 claims description 6
- 238000011176 pooling Methods 0.000 claims description 6
- 238000013461 design Methods 0.000 claims description 4
- 238000005457 optimization Methods 0.000 claims description 4
- 238000012216 screening Methods 0.000 claims description 3
- 239000000463 material Substances 0.000 claims description 2
- OAICVXFJPJFONN-UHFFFAOYSA-N Phosphorus Chemical compound [P] OAICVXFJPJFONN-UHFFFAOYSA-N 0.000 claims 1
- 238000009826 distribution Methods 0.000 abstract description 10
- 238000004088 simulation Methods 0.000 description 7
- 238000010586 diagram Methods 0.000 description 6
- 238000013507 mapping Methods 0.000 description 3
- 238000005070 sampling Methods 0.000 description 3
- 238000012360 testing method Methods 0.000 description 3
- 230000007547 defect Effects 0.000 description 2
- 101100460704 Aspergillus sp. (strain MF297-2) notI gene Proteins 0.000 description 1
- 238000013459 approach Methods 0.000 description 1
- 230000009286 beneficial effect Effects 0.000 description 1
- 239000002131 composite material Substances 0.000 description 1
- 238000013135 deep learning Methods 0.000 description 1
- 230000000694 effects Effects 0.000 description 1
- 238000002474 experimental method Methods 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 238000012546 transfer Methods 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
- G06F18/241—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/20—Image preprocessing
- G06V10/26—Segmentation of patterns in the image field; Cutting or merging of image elements to establish the pattern region, e.g. clustering-based techniques; Detection of occlusion
- G06V10/267—Segmentation of patterns in the image field; Cutting or merging of image elements to establish the pattern region, e.g. clustering-based techniques; Detection of occlusion by performing operations on regions, e.g. growing, shrinking or watersheds
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- General Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- Evolutionary Computation (AREA)
- Life Sciences & Earth Sciences (AREA)
- Artificial Intelligence (AREA)
- General Engineering & Computer Science (AREA)
- Computing Systems (AREA)
- Software Systems (AREA)
- Molecular Biology (AREA)
- Computational Linguistics (AREA)
- Biophysics (AREA)
- Biomedical Technology (AREA)
- Mathematical Physics (AREA)
- General Health & Medical Sciences (AREA)
- Health & Medical Sciences (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Bioinformatics & Computational Biology (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Evolutionary Biology (AREA)
- Multimedia (AREA)
- Image Analysis (AREA)
Abstract
本发明公开了一种基于素描标注信息的生成对抗迁移学习方法,获取初始素描图,构造形式为“源域图像‑源域图像边缘标注图”的成对数据集;构造基于素描标注信息的边缘分割深度网络并训练;基于矩阵范数选择目标域样本;构造并训练基于素描标注信息的生成对抗迁移学习网络,该网络包括深度生成器网络、深度判别器网络、基于素描标注信息的边缘分割深度网络和深度分类器网络;输入目标域图像,得到目标域图像的分类结果;本发明利用源域数据及目标域数据结构的相似性,通过结构约束,生成确定标签的符合目标域分布的样本,从而进行标签的传递,实现跨域分类。提高了分类准确率,实现了跨域分类任务。
Description
技术领域
本发明属于图像分类技术领域,具体涉及一种基于素描标注信息的生成对抗迁移学习方法,可用于跨域图像分类。
背景技术
深度学习在图像分类问题上已经取得了显著的成果,在传统的深度学习框架下,学习的任务就是在给定有标签的训练数据集上学习到一个分类网络,参数越多的模型复杂度越高,也越能解决更为复杂的分类问题,但实际上,随着大数据时代的来临,获取数据的成本越来越低,标定数据标签的成本却没有降低,使得深度学习网络在处理这部分数据时遇到阻碍。深度迁移学习打破了传统的框架,利用已标记好类别的数据域,通过寻找两个域间的相似性,进行知识的迁移,从而完成标签的传递。
目前的深度迁移学习方法主要分为基于微调和领域自适应两种。基于微调的深度迁移学习主要是利用已训练好的成熟网络,针对特定的任务,固定网络相关层,修改输出层以满足任务的需要。由于预训练好的模型都是在较大数据集上进行的,无形中扩充了训练数据,使得训练的模型更具泛化力,但其无法解决训练数据与测试数据分布不相同的问题。领域自适应则通过提取域不变特征来达成目的,即假设数据分布不一致,但在特征空间中分布一致来完成迁移任务。该类方法存在的问题是,由于特征空间的抽象性,其实无法判断提取的特征是否为域不变特征。
发明内容
本发明所要解决的技术问题在于针对上述现有技术中的不足,提供一种基于素描标注信息的生成对抗迁移学习方法,通过约束源域和目标域间的结构信息,利用生成对抗网络,生成确定标签的符合目标域分布并与源域边缘结构相似的样本,从而实现目标域的图像分类。
本发明采用以下技术方案:
一种基于素描标注信息的生成对抗迁移学习方法,包括如下步骤:
S4、构造基于素描标注信息的生成对抗迁移学习网络,其中,对抗迁移学习网络包括深度生成器网络G、深度判别器网络D、基于素描标注信息的边缘分割深度网络T和深度分类器网络C,将源域图像及对应的标签和步骤S3中得到的目标域图像及对应的伪标签分批次输入对抗迁移学习网络中进行训练,每批次的大小为K;
具体的,步骤S1中,利用Primal Sketch算法得到源域图像对应的初始素描图,初始素描图的大小和源域图像大小相同,构造与第k张源域图像初始素描图大小一致的矩阵矩阵中对应于素描线段上素描点位置处的元素置为1,矩阵中其它位置处的元素置为0,将矩阵作为边缘标注图,从而获得形式为“源域图像-源域图像边缘标注图”的成对数据集其中,
具体的,步骤S2具体为:
首先,构造基于素描标注信息的边缘分割深度网络T,边缘分割深度网络T包括依次相连接的输入层、第一卷积层、第二卷积层、第三卷积层,第四卷积层、第一反卷积层、第二反卷积层、第五卷积层、第六卷积层和输出层;其中,边缘分割深度网络T的输入为每次从步骤S1中构造的形式为“源域图像-源域图像边缘标注图”的成对数据集中随机抽取的K对数据,K为每批次的大小,输出为得到的边缘分割图其中,
其次,对基于素描标注信息的边缘分割深度网络T进行训练,训练具体为:
构造边缘分割损失函数LT,其优化目标为:
其中,θT表示该基于素描标注信息的边缘分割深度网络T的参数,为第k张源域图像经过基于素描标注信息的边缘分割深度网络T的输出,表示该输出图像第(i,j)位置实际是边缘像素点的概率,表示和第k张源域图像边缘标注图对应的(i,j)位置希望是边缘像素点的概率,ω为权重参数,K为每批次的大小,N为图像大小;
具体的,步骤S3具体为:
具体的,步骤S4具体为:
构造基于素描标注信息的生成对抗迁移学习网络并进行训练,对抗迁移学习网络包括深度生成器网络G、深度判别器网络D、基于素描标注信息的边缘分割深度网络T和深度分类器网络C;
深度生成器网络G包括依次相连接的输入层、第一卷积层、第二卷积层、第三卷积层,第四卷积层、第一反卷积层、第二反卷积层、第五卷积层、第六卷积层和输出层;深度生成器网络G的输入为每次从源域图像及对应的标签中随机抽取的K个数据,K为每批次的大小,输出为图像及对应标签
深度判别器网络D包括依次相连接的输入层、第一卷积层、第二卷积层、第三卷积层,第四卷积层、第五卷积层、第一全连接层和二分类器;深度判别器网络D的输入为每次分别从目标域图像和深度生成器网络G的输出图像中随机抽取的K个数据,K为每批次的大小,用来判别深度生成器网络G的输出是来源于目标域还是源域;
深度分类器网络C包括依次相连接的输入层、第一卷积层、第一最大池化层、第二卷积层、第三卷积层、第二最大池化层、第一全连接层和第二全连接层;其中,深度分类器网络C的输入为每次从深度生成器网络G的输出图像及对应标签中随机抽取的K个数据,K为每批次的大小,输出为图像被判别为对应标签的概率。
进一步的,对构造的基于素描标注信息的生成对抗迁移学习网络进行训练的具体方法是:
首先,构造生成对抗损失函数LGAN、针对源域图像和深度生成器网络G的输出图像的边缘结构损失函数L1、针对目标域图像和深度生成器网络G的输出图像对应类别的边缘结构损失函数L2、分类器损失LC和整体损失函数L;
最终得到上述三个网络训练后权值。
更进一步的,整体损失函数L由生成对抗损失函数LGAN、针对源域图像和深度生成器网络G的输出图像的边缘结构损失函数L1、针对目标域图像和深度生成器网络G的输出图像对应类别的边缘结构损失函数L2和分类器损失函数LC加权求和得到:
其中,θG表示深度生成器网络G的参数,θD表示深度判别器网络D的参数,θC表示深度分类器C的参数,λ1、λ2、λ3和λ4均为超参数,具体地:
生成对抗损失函数LGAN表示为:
其中,θG表示深度生成器网络G的参数,θD表示深度判别器网络D的参数,D(·)表示判别器的输出,G(·)表示生成器的输出,K为每批次的大小;
其中,θG为深度生成器网络G的参数,表示第k张源域图像经过深度生成器网络G得到的输出图像,表示该输出图像经过基于素描标注信息的边缘分割深度网络T得到的边缘分割图,表示第k张源域图像经过基于素描标注信息的边缘分割深度网络T得到的边缘分割图,表示该边缘分割图中第n个边缘像素点对应的掩模矩阵,Nk为该边缘分割图中边缘像素点的个数,*为Hadamard积,表示矩阵对应位置元素相乘,||·||F表示矩阵的F范数。
其中,θG为深度生成器网络G的参数,表示第k张源域图像对应标签,表示第k张目标域图像对应伪标签,表示第k张源域图像经过深度生成器网络G得到的输出图像,表示该输出图像经过基于素描标注信息的边缘分割深度网络T得到的边缘分割图,表示第k张目标域图像经过基于素描标注信息的边缘分割深度网络T得到的边缘分割图,表示该边缘分割图中第n个边缘像素点对应的掩模矩阵,Nk为该边缘分割图中边缘像素点的个数,*为Hadamard积,表示矩阵对应位置元素相乘,K为每批次的大小,||·||F表示矩阵的F范数,其中,构造方式如下:
进一步的,训练深度分类器网络C时,分类器损失函数LC的表达式为:
与现有技术相比,本发明至少具有以下有益效果:
本发明一种基于素描标注信息的生成对抗迁移学习方法,构建基于素描标注信息的边缘分割深度网络,通过寻找到的源域与目标域样本结构的相似性,生成确定标签的符合目标域分布并与源域边缘结构相似的样本,生成的图像作为训练集,目标域图像作为测试集,此时训练数据与测试数据满足深度网络所要求的分布一致性,训练好的分类器可以对目标域进行良好的分类,从而获取没有标签信息的目标域图像的分类结果,完成跨域分类任务。
进一步的,利用Primal Sketch算法得到源域图像对应的初始素描图去标注图像的边缘信息,构造的数据集省去了人工标注边缘的成本。
进一步的,构造的基于素描标注信息的边缘分割网络利用S1得到的数据集训练,使用训练好的该网络可以直接得到输入图像的边缘结构信息,克服了现有模型不可微的缺点,使得本发明可以将结构约束应用到梯度下降优化算法中,使得整体网络得到更好的训练。
进一步的,采用F范数对目标域数据进行筛选及赋予伪标签操作,可以对生成器网络输出和目标域对应类别间的数据做结构约束,从而使生成器网络输出结构更相似于目标域数据的样本,有利于分类器网络对目标域数据的分类。
进一步的,本发明同时使用生成对抗损失函数、针对源域图像和深度生成器网络的输出图像的边缘结构损失函数、针对目标域图像和深度生成器网络的输出图像对应类别的边缘结构损失函数和分类器损失对生成器网络进行约束,克服了现有生成模型仅使用生成对抗损失,却没有关注到图像本身的结构信息,使得本发明可以生成确定类别的,符合目标域数据分布且结构与目标域数据相似的样本,提高了跨域分类的准确率。
综上所述,本发明通过约束生成器网络输出结果和源域数据及目标域数据结构的相似性,利用源域数据的标签信息,有效地对属于不同分布的无标签的目标域数据进行分类,提高了分类准确率,实现了跨域分类任务。
下面通过附图和实施例,对本发明的技术方案做进一步的详细描述。
附图说明
图1为本发明中基于素描标注信息的生成对抗迁移学习方法的网络框图;
图2为本发明中边缘标注示意图;
图3为本发明中基于素描标注信息的边缘分割深度网络T架构示意图;
图4为本发明中基于素描标注信息的生成对抗迁移学习网络架构示意图;其中,其中,(a)是深度生成器网络G的网络架构;(b)是深度判别器网络D的网络架构;(c)是深度分类器网络C的网络架构;
图5为基于素描标注信息的边缘分割深度网络源域图像仿真结果图,按列查看,从左往右第一列是源域图像,第二列是源域图像初始素描图,第三列是源域图像边缘分割概率图,第四列是源域图像边缘分割结果图;
图6为边缘分割网络目标域图像仿真结果图,按列查看,从左往右第一列是目标域图像,第二列是目标域图像边缘分割概率图,第三列是目标域图像边缘分割结果图;
图7为生成器网络输出结果图,按列查看,从左往右第一列为源域图像,第二列为深度生成器网络G的输出图像。
具体实施方式
本发明提供了一种基于素描标注信息的生成对抗迁移学习方法,获取初始素描图,构造形式为“源域图像-源域图像边缘标注图”的成对数据集;构造基于素描标注信息的边缘分割深度网络并训练;基于矩阵范数选择目标域样本;构造并训练基于素描标注信息的生成对抗迁移学习网络,该网络包括深度生成器网络、深度判别器网络、基于素描标注信息的边缘分割深度网络和深度分类器网络;输入目标域图像,得到目标域图像的分类结果;本发明利用源域数据及目标域数据结构的相似性,通过结构约束,生成确定标签的符合目标域分布的样本,从而进行标签的传递,实现跨域分类。
请参阅图1,本发明一种基于素描标注信息的生成对抗迁移学习方法,包括以下步骤:
请参阅图2,利用Primal Sketch算法得到源域图像对应的初始素描图,初始素描图的大小和源域图像大小相同,初始素描图中的素描线段描述了原图像中亮度发生突变的地方,而图像中的边界部分正是亮度发生突变的地方,因此,在本发明中,用素描线段上素描点的位置来表示对应图像上属于边缘的像素点的位置。
具体地,构造与第k张源域图像初始素描图大小一致的矩阵只让矩阵中对应于素描线段上素描点位置处的元素置为1,矩阵中其它位置处的元素置为0,称这样的矩阵为边缘标注图,从而获得形式为“源域图像-源域图像边缘标注图”的成对数据集其中,
S2、构造基于素描标注信息的边缘分割深度网络T,将步骤S1中的成对数据集分批次输入该网络中进行训练,每批次的大小为K;输出为得到的边缘分割图其中,且随机初始化该网络中各个卷积核的参数,得到初始化后的边缘分割网络;
请参阅图3,构造基于素描标注信息的边缘分割深度网络T,该边缘分割深度网络T包括依次相连接的输入层、第一卷积层、第二卷积层、第三卷积层,第四卷积层、第一反卷积层、第二反卷积层、第五卷积层、第六卷积层和输出层,输入层为28×28×1大小的原图及边缘图,输入层和输出层间的8层的滤波器尺寸分别为3,3,3,3,3,3,3,3,步长分别为1,2,1,2,2,2,1,1,特征映射图数目分别为64,128,256,256,256,128,64,2,1,输出层大小为28×28×1的二值图。
构造边缘分割损失函数LT,其优化目标为:
其中,θT表示该基于素描标注信息的边缘分割深度网络T的参数,为第k张源域图像经过基于素描标注信息的边缘分割深度网络T的输出,表示该输出图像第(i,j)位置实际的是边缘像素点的概率,表示和第k张源域图像边缘标注图对应的(i,j)位置希望是边缘像素点的概率,ω为权重参数(取5),K为每批次的大小,N为图像大小。
通过边缘分割损失函数LT,采用批随机梯度下降方法,对边缘分割深度网络T进行训练,得到训练后的网络权值,训练的具体方法是:
S201、设置训练批次大小n=64和迭代次数t=30,以及损失函数中包含的权重参数ω=5;
S203、通过批随机梯度下降的方法更新边缘分割网络T:
S204、重复S202至S203,直到达到迭代次数t;
S205、输出训练完成的边缘分割网络T的权值θT。
S3、样本选择;
S4、构造并训练基于素描标注信息的生成对抗迁移学习网络;
对抗迁移学习网络包括深度生成器网络G、深度判别器网络D、基于素描标注信息的边缘分割深度网络T和深度分类器网络C,如图4所示,将源域图像及对应的标签和步骤S3中得到的目标域图像及对应的伪标签分批次输入该网络中进行训练,每批次的大小为K;
深度生成器网络G包括依次相连接的输入层、第一卷积层、第二卷积层、第三卷积层,第四卷积层、第一反卷积层、第二反卷积层、第五卷积层、第六卷积层和输出层,输入层为28×28×1大小的源域图像,输入层和输出层间的8层的滤波器尺寸分别为3,3,3,3,3,3,3,3,步长分别为1,2,1,2,2,2,1,1,特征映射图数目分别为64,128,256,256,256,128,64,1,1,输出层大小为28×28×1的灰度图像;
深度判别器网络D包括依次相连接的输入层、第一卷积层、第二卷积层、第三卷积层,第四卷积层、第五卷积层、第一全连接层和二分类器,输入层的输入为28×28×1大小的图像,输入层和输出层间的5层的滤波器尺寸分别为5,5,2,2,2,步长分别为2,2,2,2,2,特征映射图数目分别为64,128,256,512,1024,全连接层的节点为100,1,输出一个标量;
深度判别器网络D的输入为每次分别从目标域图像和深度生成器网络G的输出图像中随机抽取的K个数据,K为每批次的大小,用来判别深度生成器网络G的输出是来源于目标域还是源域,随机初始化该网络中各个卷积核的参数,得到初始化后的网络;
深度分类器网络C包括依次相连接的输入层、第一卷积层、第一最大池化层、第二卷积层、第三卷积层、第二最大池化层、第一全连接层和第二全连接层,输入层的输入为28×28×1大小的图像,输入层和输出层间的5层的滤波器尺寸分别为5,2,3,3,2,步长分别为1,2,1,1,2,特征映射图数目分别为32,32,64,64,64,全连接层的节点为256,10;
对构造的基于素描标注信息的生成对抗迁移学习网络进行训练的具体方法是:
构造生成对抗损失函数LGAN、针对源域图像和深度生成器网络G的输出图像的边缘结构损失函数L1、针对目标域图像和深度生成器网络G的输出图像对应类别的边缘结构损失函数L2、分类器损失LC和整体损失函数L;通过上面构造的损失函数、源域图像及对应的标签和步骤S3挑选出的目标域图像及对应的伪标签结合批随机梯度下降方法,对深度生成器网络G、深度判别器网络D和深度分类器网络C依次进行交替训练,最终得到上述三个网络训练后权值。
整体损失函数L由生成对抗损失函数LGAN、针对源域图像和深度生成器网络G的输出图像的边缘结构损失函数L1、针对目标域图像和深度生成器网络G的输出图像对应类别的边缘结构损失函数L2和分类器损失函数LC加权求和得到:
其中,θG表示深度生成器网络G的参数,θD表示深度判别器网络D的参数,θC表示深度分类器C的参数,λ1、λ2、λ3和λ4均为超参数,具体地:
其中,θG表示深度生成器网络G的参数,θD表示深度判别器网络D的参数,D(·)表示判别器的输出,G(·)表示生成器的输出,K为每批次的大小。
其中,θG为深度生成器网络G的参数,表示第k张源域图像经过深度生成器网络G得到的输出图像,表示该输出图像经过基于素描标注信息的边缘分割深度网络T得到的边缘分割图,表示第k张源域图像经过基于素描标注信息的边缘分割深度网络T得到的边缘分割图,表示该边缘分割图中第n个边缘像素点对应的掩模矩阵,Nk为该边缘分割图中边缘像素点的个数,*为Hadamard积,表示矩阵对应位置元素相乘,||·||F表示矩阵的F范数。
其中,θG为深度生成器网络G的参数,表示第k张源域图像对应标签,表示第k张目标域图像对应伪标签,表示第k张源域图像经过深度生成器网络G得到的输出图像,表示该输出图像经过基于素描标注信息的边缘分割深度网络T得到的边缘分割图,表示第k张目标域图像经过基于素描标注信息的边缘分割深度网络T得到的边缘分割图,表示该边缘分割图中第n个边缘像素点对应的掩模矩阵,Nk为该边缘分割图中边缘像素点的个数,*为Hadamard积,表示矩阵对应位置元素相乘,K为每批次的大小,||·||F表示矩阵的F范数。
分类器损失函数LC的表达式为:
借助批随机梯度下降方法,对深度判别器网络D、深度生成器网络G和深度分类器网络C依次进行交替训练的具体步骤为:
S401、设置训练批次大小n=64和迭代次数t=1000,以及损失函数中包含的四种权重参数λ1=0.25,λ2=0.2,λ3=0.2,λ4=0.35;
S404、通过批随机梯度下降的方法更新判别器网络D:
S405、通过批随机梯度下降的方法更新生成器网络G:
S406、通过批随机梯度下降的方法更新分类器网络C:
S407、重复S402至S406,直到达到迭代次数t;
S408、输出训练完成的深度生成器网络G的权值θG、深度判别器网络D的权值θD和深度分类器网络C的权值θC。
为使本发明实施例的目的、技术方案和优点更加清楚,下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例是本发明一部分实施例,而不是全部的实施例。通常在此处附图中的描述和所示的本发明实施例的组件可以通过各种不同的配置来布置和设计。因此,以下对在附图中提供的本发明的实施例的详细描述并非旨在限制要求保护的本发明的范围,而是仅仅表示本发明的选定实施例。基于本发明中的实施例,本领域普通技术人员在没有作出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
下面结合仿真图对本发明的效果做进一步的描述。
1.仿真条件:
本发明仿真的硬件平台为:HP Z840;软件平台为:Tensorflow;本发明使用的源域图像为MNIST手写体数据集,包含训练集样本60000张,目标域图像为USPS,包含训练集样本7291张。
2.仿真内容与结果:
用本发明方法在上述仿真条件下进行实验,首先从“源域图像-源域图像边缘标注图”的成对数据集中选取成对样本对基于素描标注信息的边缘分割深度网络进行训练。得到如图5的结果,从左往右第一列是源域图像,第二列是源域图像初始素描图,第三列是源域图像边缘分割概率图,第四列是源域图像边缘分割结果图。以及图6的结果,从左往右第一列是目标域图像,第二列是目标域图像边缘分割概率图,第三列是目标域图像边缘分割结果图。再通过选取的目标域样本,每类500张及源域图像对基于素描标注信息的生成对抗迁移学习网络进行训练,如图7为对应的生成器结果,按列查看,从左往右第一列为源域图像,第二列为深度生成器网络的输出图像。
从图7可以看出生成器生成的图像和目标域的分布具有很大的相似性,同时,保持了结构的一致性。最终的分类结果如下:
表1
方法 | 分类准确率 |
只用源域图像训练分类器 | 75.1% |
本发明方法(去掉边缘分割网络) | 69.6% |
本发明方法 | 93.2% |
从表1的结果看本发明的方法取得了较好的分类结果。
综上所述,本发明使用的基于素描标注信息的生成对抗迁移学习方法,能够利用源域的标签信息及源域和目标域样本结构的相似性,有效的对无标签信息的目标域图像进行分类,通过对图像通过生成器前后的结构及源域和目标域图像结构的约束,达到了生成确定标签的符合目标域分布且结构一致样本的目的。
以上内容仅为说明本发明的技术思想,不能以此限定本发明的保护范围,凡是按照本发明提出的技术思想,在技术方案基础上所做的任何改动,均落入本发明权利要求书的保护范围之内。
Claims (10)
1.一种基于素描标注信息的生成对抗迁移学习方法,其特征在于,包括如下步骤:
S4、构造基于素描标注信息的生成对抗迁移学习网络,其中,对抗迁移学习网络包括深度生成器网络G、深度判别器网络D、基于素描标注信息的边缘分割深度网络T和深度分类器网络C,将源域图像及对应的标签和步骤S3中得到的目标域图像及对应的伪标签分批次输入对抗迁移学习网络中进行训练,每批次的大小为K;
3.根据权利要求1所述的基于素描标注信息的生成对抗迁移学习方法,其特征在于,步骤S2具体为:
首先,构造基于素描标注信息的边缘分割深度网络T,边缘分割深度网络T包括依次相连接的输入层、第一卷积层、第二卷积层、第三卷积层,第四卷积层、第一反卷积层、第二反卷积层、第五卷积层、第六卷积层和输出层;其中,边缘分割深度网络T的输入为每次从步骤S1中构造的形式为“源域图像-源域图像边缘标注图”的成对数据集中随机抽取的K对数据,K为每批次的大小,输出为得到的边缘分割图其中,
其次,对基于素描标注信息的边缘分割深度网络T进行训练,训练具体为:
构造边缘分割损失函数LT,其优化目标为:
其中,θT表示该基于素描标注信息的边缘分割深度网络T的参数,为第k张源域图像经过基于素描标注信息的边缘分割深度网络T的输出,表示该输出图像第(i,j)位置实际是边缘像素点的概率,表示和第k张源域图像边缘标注图对应的(i,j)位置希望是边缘像素点的概率,ω为权重参数,K为每批次的大小,N为图像大小;
5.根据权利要求1所述的基于素描标注信息的生成对抗迁移学习方法,其特征在于,步骤S4具体为:
构造基于素描标注信息的生成对抗迁移学习网络并进行训练,对抗迁移学习网络包括深度生成器网络G、深度判别器网络D、基于素描标注信息的边缘分割深度网络T和深度分类器网络C;
深度生成器网络G包括依次相连接的输入层、第一卷积层、第二卷积层、第三卷积层,第四卷积层、第一反卷积层、第二反卷积层、第五卷积层、第六卷积层和输出层;深度生成器网络G的输入为每次从源域图像及对应的标签中随机抽取的K个数据,K为每批次的大小,输出为图像及对应标签
深度判别器网络D包括依次相连接的输入层、第一卷积层、第二卷积层、第三卷积层,第四卷积层、第五卷积层、第一全连接层和二分类器;深度判别器网络D的输入为每次分别从目标域图像和深度生成器网络G的输出图像中随机抽取的K个数据,K为每批次的大小,用来判别深度生成器网络G的输出是来源于目标域还是源域;
7.根据权利要求6所述的基于素描标注信息的生成对抗迁移学习方法,其特征在于,整体损失函数L由生成对抗损失函数LGAN、针对源域图像和深度生成器网络G的输出图像的边缘结构损失函数L1、针对目标域图像和深度生成器网络G的输出图像对应类别的边缘结构损失函数L2和分类器损失函数LC加权求和得到:
其中,θG表示深度生成器网络G的参数,θD表示深度判别器网络D的参数,θC表示深度分类器C的参数,λ1、λ2、λ3和λ4均为超参数,具体地:
生成对抗损失函数LGAN表示为:
其中,θG表示深度生成器网络G的参数,θD表示深度判别器网络D的参数,D(·)表示判别器的输出,G(·)表示生成器的输出,K为每批次的大小;
其中,θG为深度生成器网络G的参数,表示第k张源域图像对应标签,表示第k张目标域图像对应伪标签,表示第k张源域图像经过深度生成器网络G得到的输出图像,表示该输出图像经过基于素描标注信息的边缘分割深度网络T得到的边缘分割图,表示第k张目标域图像经过基于素描标注信息的边缘分割深度网络T得到的边缘分割图,表示该边缘分割图中第n个边缘像素点对应的掩模矩阵,Nk为该边缘分割图中边缘像素点的个数,*为Hadamard积,表示矩阵对应位置元素相乘,K为每批次的大小,||·||F表示矩阵的F范数,其中,构造方式如下:
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN201910401740.9A CN110210486B (zh) | 2019-05-15 | 2019-05-15 | 一种基于素描标注信息的生成对抗迁移学习方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN201910401740.9A CN110210486B (zh) | 2019-05-15 | 2019-05-15 | 一种基于素描标注信息的生成对抗迁移学习方法 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN110210486A CN110210486A (zh) | 2019-09-06 |
CN110210486B true CN110210486B (zh) | 2021-01-01 |
Family
ID=67787245
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN201910401740.9A Active CN110210486B (zh) | 2019-05-15 | 2019-05-15 | 一种基于素描标注信息的生成对抗迁移学习方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN110210486B (zh) |
Families Citing this family (20)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN110555060B (zh) * | 2019-09-09 | 2023-05-02 | 山东省计算中心(国家超级计算济南中心) | 基于成对样本匹配的迁移学习方法 |
CN110610207B (zh) * | 2019-09-10 | 2022-11-25 | 重庆邮电大学 | 一种基于迁移学习的小样本sar图像舰船分类方法 |
CN112733864B (zh) * | 2019-09-16 | 2023-10-31 | 北京迈格威科技有限公司 | 模型训练方法、目标检测方法、装置、设备及存储介质 |
CN110705406B (zh) * | 2019-09-20 | 2022-11-15 | 五邑大学 | 基于对抗迁移学习的人脸美丽预测方法及装置 |
CN110837850B (zh) * | 2019-10-23 | 2022-06-21 | 浙江大学 | 一种基于对抗学习损失函数的无监督域适应方法 |
CN111242157A (zh) * | 2019-11-22 | 2020-06-05 | 北京理工大学 | 联合深度注意力特征和条件对抗的无监督域自适应方法 |
CN111292384B (zh) * | 2020-01-16 | 2022-05-20 | 西安交通大学 | 基于生成式对抗网络的跨域多样性图像生成方法及系统 |
CN111275713B (zh) * | 2020-02-03 | 2022-04-12 | 武汉大学 | 一种基于对抗自集成网络的跨域语义分割方法 |
CN111414888A (zh) * | 2020-03-31 | 2020-07-14 | 杭州博雅鸿图视频技术有限公司 | 低分辨率人脸识别方法、系统、装置及存储介质 |
CN111598914B (zh) * | 2020-05-12 | 2022-05-06 | 湖南大学 | 一种基于不确定性引导的自适应图像分割方法 |
CN111598160B (zh) * | 2020-05-14 | 2023-04-07 | 腾讯科技(深圳)有限公司 | 图像分类模型的训练方法、装置、计算机设备及存储介质 |
CN111654368B (zh) * | 2020-06-03 | 2021-10-08 | 电子科技大学 | 一种基于深度学习生成对抗网络的密钥生成方法 |
CN111783980B (zh) * | 2020-06-28 | 2023-04-07 | 大连理工大学 | 基于双重协作生成式对抗网络的排序学习方法 |
CN113807529A (zh) * | 2020-07-31 | 2021-12-17 | 北京沃东天骏信息技术有限公司 | 机器学习模型的训练方法和装置、图像的分类方法和装置 |
CN112633579B (zh) * | 2020-12-24 | 2024-01-12 | 中国科学技术大学 | 一种基于域对抗的交通流迁移预测方法 |
CN112699809B (zh) * | 2020-12-31 | 2023-08-01 | 深圳数联天下智能科技有限公司 | 痘痘类别识别方法、装置、计算机设备及存储介质 |
CN112861977B (zh) * | 2021-02-19 | 2024-01-26 | 中国人民武装警察部队工程大学 | 迁移学习数据处理方法、系统、介质、设备、终端及应用 |
CN113298855B (zh) * | 2021-05-27 | 2021-12-28 | 广州柏视医疗科技有限公司 | 基于自动勾画的图像配准方法 |
CN114612961B (zh) * | 2022-02-15 | 2023-04-07 | 哈尔滨工业大学(深圳) | 一种多源跨域表情识别方法、装置及存储介质 |
CN115392484B (zh) * | 2022-08-25 | 2024-07-02 | 上海人工智能创新中心 | 用于计算机视觉任务中深度学习算法的数据传递方法 |
Citations (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN109190684A (zh) * | 2018-08-15 | 2019-01-11 | 西安电子科技大学 | 基于素描及结构生成对抗网络的sar图像样本生成方法 |
CN109359541A (zh) * | 2018-09-17 | 2019-02-19 | 南京邮电大学 | 一种基于深度迁移学习的素描人脸识别方法 |
CN109711426A (zh) * | 2018-11-16 | 2019-05-03 | 中山大学 | 一种基于gan和迁移学习的病理图片分类装置及方法 |
WO2019090023A1 (en) * | 2017-11-03 | 2019-05-09 | General Electric Company | System and method for interactive representation learning transfer through deep learning of feature ontologies |
Family Cites Families (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN109190750B (zh) * | 2018-07-06 | 2021-06-08 | 国家计算机网络与信息安全管理中心 | 基于对抗生成网络的小样本生成方法及装置 |
-
2019
- 2019-05-15 CN CN201910401740.9A patent/CN110210486B/zh active Active
Patent Citations (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
WO2019090023A1 (en) * | 2017-11-03 | 2019-05-09 | General Electric Company | System and method for interactive representation learning transfer through deep learning of feature ontologies |
CN109190684A (zh) * | 2018-08-15 | 2019-01-11 | 西安电子科技大学 | 基于素描及结构生成对抗网络的sar图像样本生成方法 |
CN109359541A (zh) * | 2018-09-17 | 2019-02-19 | 南京邮电大学 | 一种基于深度迁移学习的素描人脸识别方法 |
CN109711426A (zh) * | 2018-11-16 | 2019-05-03 | 中山大学 | 一种基于gan和迁移学习的病理图片分类装置及方法 |
Non-Patent Citations (2)
Title |
---|
Transfer Feature Learning with Joint Distribution Adaptation;Mingsheng L.等;《2013 IEEE International Conference on Computer Vision》;20140303;第2202-2207页 * |
基于生成对抗网络的迁移学习算法研究;臧文华;《中国优秀硕士学位论文全文数据库 信息科技辑》;20180915;第I140-10页 * |
Also Published As
Publication number | Publication date |
---|---|
CN110210486A (zh) | 2019-09-06 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN110210486B (zh) | 一种基于素描标注信息的生成对抗迁移学习方法 | |
CN111368896B (zh) | 基于密集残差三维卷积神经网络的高光谱遥感图像分类方法 | |
CN110516596B (zh) | 基于Octave卷积的空谱注意力高光谱图像分类方法 | |
CN111079795B (zh) | 基于cnn的分片多尺度特征融合的图像分类方法 | |
CN107563422A (zh) | 一种基于半监督卷积神经网络的极化sar分类方法 | |
CN112347970B (zh) | 一种基于图卷积神经网络的遥感影像地物识别方法 | |
CN106203523A (zh) | 基于梯度提升决策树半监督算法融合的高光谱图像分类 | |
CN108846426A (zh) | 基于深度双向lstm孪生网络的极化sar分类方法 | |
CN111882040A (zh) | 基于通道数量搜索的卷积神经网络压缩方法 | |
CN110619059B (zh) | 一种基于迁移学习的建筑物标定方法 | |
CN110728295B (zh) | 半监督式的地貌分类模型训练和地貌图构建方法 | |
CN107292336A (zh) | 一种基于dcgan的极化sar图像分类方法 | |
CN106203625A (zh) | 一种基于多重预训练的深层神经网络训练方法 | |
CN106651887A (zh) | 一种基于卷积神经网络的图像像素分类方法 | |
CN105989336A (zh) | 基于带权重的解卷积深度网络学习的场景识别方法 | |
CN104992183A (zh) | 自然场景中的显著目标的自动检测方法 | |
CN109816030A (zh) | 一种基于受限玻尔兹曼机的图像分类方法及装置 | |
CN112416293A (zh) | 一种神经网络增强方法、系统及其应用 | |
Tang et al. | Target Category Agnostic Knowledge Distillation With Frequency-Domain Supervision | |
CN109118483A (zh) | 一种标签质量检测方法及装置 | |
CN114898417B (zh) | 一种基于协调注意力深度神经网络的菊头蝠识别方法 | |
CN112508958B (zh) | 一种轻量多尺度的生物医学图像分割方法 | |
CN111639659A (zh) | 一种水下沉底小目标融合分类方法 | |
CN109345537A (zh) | 基于高阶多尺度crf半监督的sar图像分割方法 | |
CN114627005A (zh) | 一种雨密度分类引导的双阶段单幅图像去雨方法 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |