CN116342942A - 基于多级域适应弱监督学习的跨域目标检测方法 - Google Patents
基于多级域适应弱监督学习的跨域目标检测方法 Download PDFInfo
- Publication number
- CN116342942A CN116342942A CN202310258566.3A CN202310258566A CN116342942A CN 116342942 A CN116342942 A CN 116342942A CN 202310258566 A CN202310258566 A CN 202310258566A CN 116342942 A CN116342942 A CN 116342942A
- Authority
- CN
- China
- Prior art keywords
- domain
- target
- network
- style
- data
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000001514 detection method Methods 0.000 title claims abstract description 118
- 230000006978 adaptation Effects 0.000 title claims abstract description 16
- 238000012549 training Methods 0.000 claims abstract description 45
- 238000000034 method Methods 0.000 claims abstract description 26
- 230000003044 adaptive effect Effects 0.000 claims abstract description 20
- 230000005012 migration Effects 0.000 claims abstract description 19
- 238000013508 migration Methods 0.000 claims abstract description 19
- 230000008569 process Effects 0.000 claims abstract description 18
- 238000000605 extraction Methods 0.000 claims description 22
- 230000006870 function Effects 0.000 claims description 19
- 238000011176 pooling Methods 0.000 claims description 15
- 239000013598 vector Substances 0.000 claims description 12
- 208000037170 Delayed Emergence from Anesthesia Diseases 0.000 claims description 7
- 238000013507 mapping Methods 0.000 claims description 5
- 238000005070 sampling Methods 0.000 claims description 5
- 238000013459 approach Methods 0.000 claims description 4
- 230000004913 activation Effects 0.000 claims description 3
- 239000000284 extract Substances 0.000 claims description 3
- 210000002569 neuron Anatomy 0.000 claims description 3
- 238000010606 normalization Methods 0.000 claims description 3
- 230000009467 reduction Effects 0.000 claims description 3
- 230000001172 regenerating effect Effects 0.000 claims description 3
- 230000007480 spreading Effects 0.000 claims description 3
- 238000013527 convolutional neural network Methods 0.000 description 11
- 238000013528 artificial neural network Methods 0.000 description 4
- 238000010586 diagram Methods 0.000 description 4
- 230000000694 effects Effects 0.000 description 3
- 238000013135 deep learning Methods 0.000 description 2
- 238000013461 design Methods 0.000 description 2
- 230000004807 localization Effects 0.000 description 2
- 238000013526 transfer learning Methods 0.000 description 2
- 230000009471 action Effects 0.000 description 1
- 230000009286 beneficial effect Effects 0.000 description 1
- 230000008901 benefit Effects 0.000 description 1
- 210000000988 bone and bone Anatomy 0.000 description 1
- 230000008859 change Effects 0.000 description 1
- 238000006243 chemical reaction Methods 0.000 description 1
- 238000012937 correction Methods 0.000 description 1
- 238000005286 illumination Methods 0.000 description 1
- 238000003709 image segmentation Methods 0.000 description 1
- 230000006872 improvement Effects 0.000 description 1
- 238000002372 labelling Methods 0.000 description 1
- 238000010801 machine learning Methods 0.000 description 1
- 230000002093 peripheral effect Effects 0.000 description 1
- 238000011084 recovery Methods 0.000 description 1
- 238000011160 research Methods 0.000 description 1
- 238000012546 transfer Methods 0.000 description 1
- 230000009466 transformation Effects 0.000 description 1
- 238000000844 transformation Methods 0.000 description 1
- 229940088594 vitamin Drugs 0.000 description 1
- 229930003231 vitamin Natural products 0.000 description 1
- 235000013343 vitamin Nutrition 0.000 description 1
- 239000011782 vitamin Substances 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/764—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/0464—Convolutional networks [CNN, ConvNet]
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/0475—Generative networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/048—Activation functions
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/0895—Weakly supervised learning, e.g. semi-supervised or self-supervised learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/094—Adversarial learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/82—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V2201/00—Indexing scheme relating to image or video recognition or understanding
- G06V2201/07—Target detection
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Evolutionary Computation (AREA)
- Artificial Intelligence (AREA)
- Health & Medical Sciences (AREA)
- Computing Systems (AREA)
- General Physics & Mathematics (AREA)
- General Health & Medical Sciences (AREA)
- Software Systems (AREA)
- Biomedical Technology (AREA)
- Life Sciences & Earth Sciences (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Data Mining & Analysis (AREA)
- Molecular Biology (AREA)
- General Engineering & Computer Science (AREA)
- Mathematical Physics (AREA)
- Medical Informatics (AREA)
- Databases & Information Systems (AREA)
- Multimedia (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Image Analysis (AREA)
Abstract
本发明公开了基于多级域适应弱监督学习的跨域目标检测方法,利用MUNIT风格迁移由源域DS生成接近目标域DT的中间域DG数据,用源域DS数据集预训练得到目标检测模型,用其为目标域DT和中间域DG数据打上伪标签,实现一种弱监督跨域迁移条件,有助于跨域检测;多层次使用域自适应分类器,在图像级既保证全局领域特征对齐,又保证局部领域特征的对齐,且不改变源域与目标域中数据之间的区别信息,增强了自适应模型的鲁棒性。在实例级也进行域特征对齐,针对目标检测这个特定任务做出改善;训练过程采取源域DS到中间域DG,再到目标域DT的顺序渐进地适应域差异,由目标检测损失和域迁移损失共同使网络收敛,提高检测模型的性能。
Description
技术领域
本发明属于机器学习中的迁移学习技术领域,具体涉及基于多级域适应弱监督学习的跨域目标检测方法。
背景技术
在计算机视觉领域中,目标检测是计算机视觉领域中的一项基本任务,它由图像分类任务发展而来,区别在于不再仅仅只对一张图像中的单一类型目标进行分类,而是要同时完成一张图像里可能存在的多个目标的分类和定位,其中分类是指给目标分配类别标签,定位是指确定目标的外围矩形框的顶点坐标。因此,目标检测任务更具有挑战性,也有着更广阔的应用前景,比如自动驾驶、人脸识别、行人检测、医疗检测等等。同时,目标检测也可以作为图像分割、图像描述、目标跟踪、动作识别等更复杂的计算机视觉任务的研究基础。现有的图像目标检测其通常可以被分为两类:一类是two-stage检测器,最具代表性的是Faster R-CNN。另一种是one-stage检测器,如YOLO、SSD。two-stage检测器具有较高的定位和目标识别精度,而one-stage检测器具有较高的推理速度。
近年来,使用有监督深度学习的对象检测已经显示出令人印象深刻的结果,但它在跨领域环境中仍然具有挑战性。对于许多实际任务来说,基于深度学习的目标检测器需要大量带边界框和类标注的样本,标注大规模数据集以训练卷积神经网络的成本高得令人望而却步且耗时,同时不同域中的光照、风格、尺度和外观等的变化也会严重影响检测器的性能。
目前域自适应学习已经成为解决数据标注和领域偏移问题的有效手段。域自适应学习利用已有标注的与目标数据相似的数据集,例如具有相同的类别,来作为源域,通过与未标注的目标域数据进行显式的数据特征对齐,利用源域和目标域同时进行迁移学习,进而获得在目标域上表现尚可的模型。现有的域自适应任务尚存在许多不足之处。具体的,第一,现有迁移学习大多应用在图片分类领域,目标检测领域应用较少;第二,现有的域自适现有技术在无监督领域自适应过程中,无法适应前后数据集差异过大的情况,对数据集要求很高,并且领域自适应后模型精度不高,模型泛化性不强。第三,应用于目标检测的域自适应的方法都使用对抗性训练来对齐跨域转换的全局特征,并实现图像信息传递。然而,这种方法不能有效地匹配局部特征的分布,导致跨域对象检测的改进有限。
发明内容
本发明的目的是提供基于多级域适应弱监督学习的跨域目标检测方法,解决了实际应用情况中当源域具实例级标签,而目标域中只有样本级标签时,训练后的目标检测模型跨域检测准确率低的问题。
本发明所采用的技术方案是,基于多级域适应弱监督学习的跨域目标检测方法,具体按照以下步骤实施:
步骤1、获取源域DS数据和目标域DT数据,构成目标检测数据集,构建MUNIT网络,利用源域DS数据和目标域DT的数据集对MUNIT网络进行训练并生成介于目标域和源域之间的中间域数据集DG;
步骤2、构建Faster RCNN网络作为目标网络,采用源域Ds的数据作为训练集对目标检测器进行训练并得到初步预训练的检测模型,将中间域DG和目标域DT数据送入检测网络中打上伪标签;
步骤3、构建图像级和实际级的域分类器,加入步骤2初步预训练的检测模型中,得到具有域自适应的目标检测器;
步骤4、按照源域DS、中间域DG、目标域DT的顺序渐进将数据集输入具有域自适应的目标检测器进行训练,渐进地适应域差异,得到训练好的目标检测模型;
步骤5、采用训练好的目标检测模型对目标域的数据集进行目标检测,得到检测结果。
本发明的特点还在于:
步骤1中MUNIT网络包括生成器和判别器,生成器包括风格编码器、内容编码器、特征交叉模块、解码器,生成器利用源域DS数据生成接近目标域DS的中间域DG数据,判别器用于判断输入的数据是真实的源域DS数据还是生成的中间域DG数据。
步骤1具体过程为:
步骤1.1、把源域Ds和目标域DT的数据集以不同的风格和内容的分类标准将其分为相应类数,从源域Ds中选定一个子集X1,并从目标域DT选定中也抽取一个子集X2;
步骤1.2、将子集X1和X2分别输入内风格编码器、内容编码器进行降维操作,得到两种图像特征向量;通过若干个卷积层对其中一种图像特征向量进行降采样,使用深度残差网络会用到的残差块生成低维的内容编码;先通过若干个卷积层对另一种图像特征向量进行降采样,然后经过一个全局池化层和一个全连接层,最后生成低维的风格编码;
步骤1.3、将子集X1的低维的内容编码与子集X2的低维的风格编码融合,产生风格迁移的图像编码特征,向风格迁移的图像编码特征中加入高斯噪声后进行交叉,得到新结合的编码特征,对新结合的编码特征用解码器升维生成结果图像;
步骤1.4、将生成结果图像根据风格编码器和内容编码器再次分解成两个编码特征,对于低维的内容编码、低维的风格编码计算误差反向传播,重新调整MUNIT网络参数;将生成的结果图像分别输入风格编码器和内容编码器,生成新的风格编码和新的内容编码,计算步骤1.2中低维的内容编码、低维的风格编码与新的风格编码和新的内容编码之间的差距损失,设置偏差阈值,当差距损失超过偏差阈值时,将这个损失反向传播,重新调整MUNIT网络参数;
步骤1.5、将子集X1、子集X2、步骤1.3中生成结果图像输入GAN网络,进行判别和对抗训练;利用梯度反转来更新对抗损失;训练后的MUNIT神经网络使用步骤1.1中分类好的源域DS和目标域DT的数据集来生成中间域DG数据集。
步骤1.3中将子集X1的低维的内容编码与子集X2的低维的风格编码融合具体过程为:
将子集X2的低维的风格编码先由多层感知器动态生成参数,再经过自适应实例规范化层,得到可进行融合的低维的风格编码,将可进行融合的低维的风格编码和子集X1的低维的内容编码在残差块中进行糅合,然后进行上采样得到风格迁移的图像编码特征。
GAN网络的损失函数表示为:
其中,图像的风格特征和内容特征x服从p(x)分布,G为图像的域内生成器,EC(x),Es(x)分别为图像的内容和风格特征;
风格重建损失和内容重建损失:
接着是对抗损失,就是用判别器判断真假,GAN网络的定义:
GAN网络用到两组生成器和判别器,MUNIT网络的目标函数就是上述几个损失的和:
其中E表示编码器,G表示生成器,D表示判别器。
Faster RCNN网络包括候选检测框生成网络和Fast R-CNN网络,Fast R-CNN网络由特征提取部分、Roi池化层、分类器3个部分构成,特征提取部分,用于提取整张图片的特征,得到特征图,候选检测框生成网络通过softmax函数判断锚框属于正例或者反例,再利用边框回归修正锚框获得候选区域,Roi池化层收集输入的特征图和候选区域,提取候选框特征图,送入分类器判定目标类别。
步骤2中具体过程为:
步骤2.1、构建Faster RCNN网络作为目标网络,采用源域Ds的数据作为训练集,特征提取部分使用训练集中数据训练Faster RCNN网络完成初始化权重,对源域Ds的样本抽样后的256个正负例anchor框训练候选检测框生成网络、Fast R-CNN网络中特征提取部分,其中,特征提取部分权重参与调整;
步骤2.2、使用步骤2.1训练好的候选检测框生成网络,生成正例预测框,供分类器网络进行训练;此时特征提取部分权值也使用源域Ds的样本进行训练更新网络参数,将得到的候选检测框生成网络和Fast R-CNN网络作为目标检测器;
步骤2.3、再次生成利用源域Ds的实例级标签输入目标检测器,得到候选区域的特征图,将候选区域的特征图送入分类器中进行类别判断与检测框回归,获得前向传播的损失,反传梯度,更新权重参数,减小损失,得到一个初步的预训练检测模型;
步骤2.4、利用初步的预训练检测模型对中间域DG样本和目标域DT的样本进行检测,对检测结果打上伪标签,即图像中物体的位置框坐标和类别,保存目标检测结果。
预训练检测模型的损失函数表示为:
其中为小批量中锚点的索引,pi是锚点/>作为目标的预测概率,/>为真值,当anchor为正时,/>为1,当anchor为负时,/>为0,ti是预测边界框的四个参数化坐标的向量,/>是与正锚框相关联的真实框的坐标,LC是两个类别的分类损失,Lr是边界框回归的损失,{pi},{ti}分别表示分类层和回归层的输出。
域分类器的结构由梯度反转层(GRL)、多层的全连接层后接Relu激活函数,最后接含一个神经元全连接层加Logistic损失函数构成,且所有层的权重因子设置为相等,域分类器用于判断输入的特征是来自源域DS还是来自目标域DT或中间域DG,来自源域DS则为1,为0。
步骤3具体过程为:
首先,图像级全局域判别器在图像阶段自适应使用特征提取部分最后一个卷积层之后的特征映射来对齐不同域的全局特征分布;
其次,建立多个图像级局部域判别器在卷积网络中提取多个中间层的输出特征映射来监督中间层的局部特征对齐;
最后,在目标检测模型Roi池化层提取后的特征进行特征对齐,放置实例级域分类器。
本发明有益效果是:
1.利用MUNIT风格迁移的方法由源域DS生成接近目标域DT的中间域DG数据,同时用源域DS数据集预训练得到的目标检测模型,用其为目标域DT和中间域DG数据打上伪标签,实现一种弱监督跨域迁移条件,有助于跨域检测。
2.多层次使用域自适应分类器,在图像级既保证全局领域特征对齐,又保证局部领域特征的对齐,同时且不改变源域与目标域中数据之间的区别信息,增强了自适应模型的鲁棒性。在实例级也进行域特征对齐,针对目标检测这个特定任务做出改善。
3.训练过程采取源域DS到中间域DG,再到目标域DT的顺序渐进地适应域差异,由目标检测损失和域迁移损失共同使网络收敛,逐步提高检测模型的性能。
附图说明
图1是本发明基于领域自适应的目标检测方法的网络结构图;
图2是本发明MUNIT网络的网络结构示意图;
图3是本发明中带有域自适应的Faster RCNN的网络结构示意图;
图4是本发明中网络训练过程的网络结构示意图。
图5是未加入域自适应检测效果图对比图;
图6是加入域自适应检测效果图对比图。
具体实施方式
下面结合附图及具体实施方式对本发明进行详细说明。
本发明基于多级域适应弱监督学习的跨域目标检测方法,使用的网络结构如图1所示,具体按照以下步骤实施:
步骤1、获取源域DS数据和目标域DT数据,构成目标检测数据集,构建MUNIT网络,利用源域DS数据和目标域DT的数据集对MUNIT网络进行训练并生成介于目标域和源域之间的中间域数据集DG;
如图2所示,MUNIT网络包括生成器和判别器,生成器包括风格编码器、内容编码器、特征交叉模块、解码器,生成器利用源域DS数据生成接近目标域DS的中间域DG数据,判别器用于判断输入的数据是真实的源域DS数据还是生成的中间域DG数据。
而判别器的作用是负责判断输入的数据是真实的源域DS数据还是生成的中间域DG数据。生成器要不断优化自己生成的数据让判别网络判断不出来,判别网络也要优化自己让自己判断得更准确,二者关系形成对抗(即对抗网络)。值得注意的是,MUNIT的生成器有自己的特色,它由特定神经网络组合而成的风格编码器、内容编码器和解码器三部分构成,其目的是输入一张图片,在保证图片主体内容不变的前提下,给图片换一种风格模式,反之亦行。
步骤1具体过程为:
步骤1.1、把源域DS和目标域DT的数据集以不同的风格和内容的分类标准将其分为相应类数,从源域Ds中选定一个子集X1,并从目标域DT选定中也抽取一个子集X2;两个子集的图像或内容风格具有一定相似性但相互之间又具有差异性。利用两类子集设计并预训练好一个具备编码内容和风格以及解码恢复功能的神经网络,过程中需要保证同一张图片T1可以被还原出相似的域中图片T1,即所谓的让T1的域内部重构的损失降到最小;
步骤1.2、将子集X1和X2分别输入内风格编码器、内容编码器进行降维操作,得到两种图像特征向量;通过若干个卷积层对其中一种图像特征向量进行降采样,使用深度残差网络会用到的残差块(Residual Blocks)生成低维的内容编码;先通过若干个卷积层对另一种图像特征向量进行降采样,然后经过一个全局池化层和一个全连接层,最后生成低维的风格编码;
步骤1.3、将子集X1的低维的内容编码与子集X2的低维的风格编码融合,产生风格迁移的图像编码特征,向风格迁移的图像编码特征中加入高斯噪声后进行交叉,目的是提升网络的鲁棒性,得到新结合的编码特征,对新结合的编码特征用解码器升维生成结果图像;
将子集X1的低维的内容编码与子集X2的低维的风格编码融合具体过程为:
将子集X2的低维的风格编码先由多层感知器动态生成参数,再经过自适应实例规范化层,得到可进行融合的低维的风格编码,将可进行融合的低维的风格编码和子集X1的低维的内容编码在残差块中进行糅合,然后进行上采样得到风格迁移的图像编码特征。
步骤1.4、将生成结果图像根据风格编码器和内容编码器再次分解成两个编码特征,对于低维的内容编码、低维的风格编码计算误差反向传播,重新调整MUNIT网络参数;将生成的结果图像分别输入风格编码器和内容编码器,生成新的风格编码和新的内容编码,计算步骤1.2中低维的内容编码、低维的风格编码与新的风格编码和新的内容编码之间的差距损失,设置偏差阈值,当差距损失超过偏差阈值时,将这个损失反向传播,重新调整MUNIT网络参数;
步骤1.5、将子集X1、子集X2、步骤1.3中生成结果图像输入GAN网络,进行判别和对抗训练;利用梯度反转来更新对抗损失;训练后的MUNIT神经网络使用步骤1.1中分类好的源域DS和目标域DT的数据集来生成中间域DG数据集。
GAN网络的损失函数表示为:
其中,图像的风格特征和内容特征x服从p(x)分布,G为图像的域内生成器,EC(x),Es(x)分别为图像的内容和风格特征;
接着是对抗损失,就是用判别器判断真假,GAN网络的定义:
GAN网络用到两组生成器和判别器,MUNIT网络的目标函数就是上述几个损失的和:
其中E表示编码器,G表示生成器,D表示判别器。
步骤2、构建Faster RCNN网络作为目标网络:Faster RCNN网络包括候选检测框生成网络(RPN)和Fast R-CNN网络,Fast R-CNN网络由特征提取部分(backbone)、Roi池化层(Roi Pooling)、分类器(Classification)3个部分构成,特征提取部分,用于提取整张图片的特征,得到特征图,例如VGG16,去除其中的全连接层,只留下卷基层,输出下采样后的特征图。用一串卷积层和池化层从原图中提取出特征图;候选检测框生成网络通过softmax函数判断锚框属于正例或者反例,再利用边框回归修正锚框获得候选区域,Roi池化层收集输入的特征图和候选区域,提取候选框特征图,送入分类器判定目标类别。利用候选框特征图计算候选区域的类别,同时再次边框回归获得检测框最终的精确位置。
步骤2具体过程为:
步骤2.1、构建Faster RCNN网络作为目标网络,采用源域Ds的数据作为训练集,特征提取部分使用训练集中数据训练Faster RCNN网络完成初始化权重,对源域Ds的样本抽样后的256个正负例anchor框训练候选检测框生成网络、Fast R-CNN网络中特征提取部分,其中,特征提取部分权重参与调整;
步骤2.2、使用步骤2.1训练好的候选检测框生成网络,生成正例预测框,供分类器网络进行训练;此时特征提取部分权值也使用源域Ds的样本进行训练更新网络参数,将得到的候选检测框生成网络和Fast R-CNN网络作为目标检测器;
步骤2.3、再次生成利用源域Ds的实例级标签输入目标检测器,得到候选区域的特征图,将候选区域的特征图送入分类器中进行类别判断与检测框回归,获得前向传播的损失,反传梯度,更新权重参数,减小损失,得到一个初步的预训练检测模型;
步骤2.4、利用初步的预训练检测模型对中间域DG样本和目标域DT的样本进行检测,对检测结果打上伪标签,即图像中物体的位置框坐标和类别,保存目标检测结果。
预训练检测模型的损失函数表示为:
其中为小批量中锚点的索引,pi是锚点/>作为目标的预测概率,/>为真值,当anchor为正时,/>为1,当anchor为负时,/>为0,ti是预测边界框的四个参数化坐标的向量,/>是与正锚框相关联的真实框的坐标,LC是两个类别的分类损失,Lr是边界框回归的损失,{pi},{ti}分别表示分类层和回归层的输出。
步骤3、构建图像级和实际级的域分类器,加入步骤2初步预训练的检测模型中,得到具有域自适应的目标检测器,结构如图3所示;
域分类器的结构由梯度反转层(GRL)、多层的全连接层后接Relu激活函数,最后接含一个神经元全连接层加Logistic损失函数构成,且所有层的权重因子设置为相等,域分类器用于判断输入的特征是来自源域DS还是来自目标域DT或中间域DG,来自源域DS则为1,为0。
步骤3中各类域判别器在Faster R-CNN中的位置是由其判别目的所决定的,首先图像级全局域判别器在图像阶段自适应使用特征提取部分最后一个卷积层之后的特征映射,为了来对齐不同域的全局特征分布。其次,采用了分层自适应的思想,建立多个图像级局部域判别器在卷积网络中提取多个中间层的输出特征映射,来监督中间层的局部特征对齐。因为仅有全局域判别器会忽略了局部特征的对齐,使得某些域敏感的局部特征削弱了自适应模型的泛化能力。最后,在目标检测模型Roi池化层提取后的特征进行特征对齐,放置实例级域分类器。因为目标检测的任务是找出图片中物体的位置并识别出类型,主要关注的是物体所在区域的特征,而目标检测模型Roi池化层提取后的特征就是物体所在预测框内的特征。
步骤4、按照源域DS、中间域DG、目标域DT的顺序渐进将数据集输入具有域自适应的目标检测器进行训练,如图4所示,渐进地适应域差异,利用梯度进行方向传播梯度更新,不断更新网络参数,减小网络损失,得到训练好的目标检测模型;
该过程中的总体损失是步骤2中的目标检测损失和域适应损失的总和。预适应损失又包括全局域适应损失、多级局部损失和实例级损失。具体表达式如下所示:
其中,j表示第i个图像中的第j个区域提议来自目标域的概率。
总体损失表达式为:
L=Ldet+λ(Lmulti+Lins+Limg) (13)
其中λ为域适应损失所占比例,Ldet为步骤2中的目标损失。
步骤5、采用训练好的目标检测模型对目标域的数据集进行目标检测,得到检测结果。
使用时,将同类型的其他目标域数据集的数据输入到训练好的目标检测器中,能够输出检测结果。
本发明中通过步骤1的设计利生成了介于目标域和源域之间的中间域数据集DG,拉近目标域和源域的特征差异,以此来解决域自适应中无法适应前后数据集差异过大的问题;通过步骤3的设计在目标检测模型当中加入多级域分类器,克服了需要同时实现跨域转换的全局特征和局部特征对齐的困难,最终达到能够在相似目标域可以检测目标对象的效果。对于同一个未使用过的样本图片,使用没有域自适应的目标检测模型和有域自适应的目标检测模型分别进行检测得到的检测结果如图5和图6所示,根据图5、图6对比可知,图6中的检测结果更好,即图片中的对象更多的被检测出来,定位框也更准确。
由此可知,有域自适应的目标检测模型检测结果更好,即图片中的对象更多的被检测出来,定位框也更准确。
通过上述方式,本发明基于多级域适应弱监督学习的跨域目标检测方法,利用MUNIT风格迁移的方法由源域DS生成接近目标域DT的中间域DG数据,同时用源域DS数据集预训练得到的目标检测模型,用其为目标域DT和中间域DG数据打上伪标签,实现一种弱监督跨域迁移条件,有助于跨域检测。多层次使用域自适应分类器,在图像级既保证全局领域特征对齐,又保证局部领域特征的对齐,同时且不改变源域与目标域中数据之间的区别信息,增强了自适应模型的鲁棒性。在实例级也进行域特征对齐,针对目标检测这个特定任务做出改善;训练过程采取源域DS到中间域DG,再到目标域DT的顺序渐进地适应域差异,由目标检测损失和域迁移损失共同使网络收敛,逐步提高检测模型的性能。
Claims (10)
1.基于多级域适应弱监督学习的跨域目标检测方法,其特征在于,具体按照以下步骤实施:
步骤1、获取源域DS数据和目标域DT数据,构成目标检测数据集,构建MUNIT网络,利用源域DS数据和目标域DT的数据集对MUNIT网络进行训练并生成介于目标域和源域之间的中间域数据集DG;
步骤2、构建Faster RCNN网络作为目标网络,采用源域Ds的数据作为训练集对目标检测器进行训练并得到初步预训练的检测模型,将中间域DG和目标域DT数据送入检测网络中打上伪标签;
步骤3、构建图像级和实际级的域分类器,加入步骤2初步预训练的检测模型中,得到具有域自适应的目标检测器;
步骤4、按照源域DS、中间域DG、目标域DT的顺序渐进将数据集输入具有域自适应的目标检测器进行训练,渐进地适应域差异,得到训练好的目标检测模型;
步骤5、采用训练好的目标检测模型对目标域的数据集进行目标检测,得到检测结果。
2.根据权利要求1所述基于多级域适应弱监督学习的跨域目标检测方法,其特征在于,步骤1中所述MUNIT网络包括生成器和判别器,所述生成器包括风格编码器、内容编码器、特征交叉模块、解码器,所述生成器利用源域DS数据生成接近目标域DS的中间域DG数据,所述判别器用于判断输入的数据是真实的源域DS数据还是生成的中间域DG数据。
3.根据权利要求2所述基于多级域适应弱监督学习的跨域目标检测方法,其特征在于,步骤1具体过程为:
步骤1.1、把源域Ds和目标域DT的数据集以不同的风格和内容的分类标准将其分为相应类数,从源域Ds中选定一个子集X1,并从目标域DT选定中也抽取一个子集X2;
步骤1.2、将子集X1和X2分别输入内风格编码器、内容编码器进行降维操作,得到两种图像特征向量;通过若干个卷积层对其中一种图像特征向量进行降采样,使用深度残差网络会用到的残差块生成低维的内容编码;先通过若干个卷积层对另一种图像特征向量进行降采样,然后经过一个全局池化层和一个全连接层,最后生成低维的风格编码;
步骤1.3、将子集X1的低维的内容编码与子集X2的低维的风格编码融合,产生风格迁移的图像编码特征,向风格迁移的图像编码特征中加入高斯噪声后进行交叉,得到新结合的编码特征,对新结合的编码特征用解码器升维生成结果图像;
步骤1.4、将生成结果图像根据风格编码器和内容编码器再次分解成两个编码特征,对于低维的内容编码、低维的风格编码计算误差反向传播,重新调整MUNIT网络参数;将生成的结果图像分别输入风格编码器和内容编码器,生成新的风格编码和新的内容编码,计算步骤1.2中低维的内容编码、低维的风格编码与新的风格编码和新的内容编码之间的差距损失,设置偏差阈值,当差距损失超过偏差阈值时,将这个损失反向传播,重新调整MUNIT网络参数;
步骤1.5、将子集X1、子集X2、步骤1.3中生成结果图像输入GAN网络,进行判别和对抗训练;利用梯度反转来更新对抗损失;训练后的MUNIT神经网络使用步骤1.1中分类好的源域DS和目标域DT的数据集来生成中间域DG数据集。
4.根据权利要求2所述基于多级域适应弱监督学习的跨域目标检测方法,其特征在于,步骤1.3中所述将子集X1的低维的内容编码与子集X2的低维的风格编码融合具体过程为:
将子集X2的低维的风格编码先由多层感知器动态生成参数,再经过自适应实例规范化层,得到可进行融合的低维的风格编码,将可进行融合的低维的风格编码和子集X1的低维的内容编码在残差块中进行糅合,然后进行上采样得到风格迁移的图像编码特征。
6.根据权利要求2所述基于多级域适应弱监督学习的跨域目标检测方法,其特征在于,所述Faster RCNN网络包括候选检测框生成网络和Fast R-CNN网络,所述Fast R-CNN网络由特征提取部分、Roi池化层、分类器3个部分构成,所述特征提取部分,用于提取整张图片的特征,得到特征图,所述候选检测框生成网络通过softmax函数判断锚框属于正例或者反例,再利用边框回归修正锚框获得候选区域,所述Roi池化层收集输入的特征图和候选区域,提取候选框特征图,送入分类器判定目标类别。
7.根据权利要求5所述基于多级域适应弱监督学习的跨域目标检测方法,其特征在于,步骤2中具体过程为:
步骤2.1、构建Faster RCNN网络作为目标网络,采用源域Ds的数据作为训练集,特征提取部分使用训练集中数据训练Faster RCNN网络完成初始化权重,对源域Ds的样本抽样后的256个正负例anchor框训练候选检测框生成网络、Fast R-CNN网络中特征提取部分,其中,特征提取部分权重参与调整;
步骤2.2、使用步骤2.1训练好的候选检测框生成网络,生成正例预测框,供分类器网络进行训练;此时特征提取部分权值也使用源域Ds的样本进行训练更新网络参数,将得到的候选检测框生成网络和Fast R-CNN网络作为目标检测器;
步骤2.3、再次生成利用源域Ds的实例级标签输入目标检测器,得到候选区域的特征图,将候选区域的特征图送入分类器中进行类别判断与检测框回归,获得前向传播的损失,反传梯度,更新权重参数,减小损失,得到一个初步的预训练检测模型;
步骤2.4、利用初步的预训练检测模型对中间域DG样本和目标域DT的样本进行检测,对检测结果打上伪标签,即图像中物体的位置框坐标和类别,保存目标检测结果。
9.根据权利要求1所述基于多级域适应弱监督学习的跨域目标检测方法,其特征在于,所述域分类器的结构由梯度反转层(GRL)、多层的全连接层后接Relu激活函数,最后接含一个神经元全连接层加Logistic损失函数构成,且所有层的权重因子设置为相等,所述域分类器用于判断输入的特征是来自源域DS还是来自目标域DT或中间域DG,来自源域DS则为1,为0。
10.根据权利要求1所述基于多级域适应弱监督学习的跨域目标检测方法,其特征在于,步骤3具体过程为:
首先,图像级全局域判别器在图像阶段自适应使用特征提取部分最后一个卷积层之后的特征映射来对齐不同域的全局特征分布;
其次,建立多个图像级局部域判别器在卷积网络中提取多个中间层的输出特征映射来监督中间层的局部特征对齐;
最后,在目标检测模型Roi池化层提取后的特征进行特征对齐,放置实例级域分类器。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310258566.3A CN116342942A (zh) | 2023-03-16 | 2023-03-16 | 基于多级域适应弱监督学习的跨域目标检测方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310258566.3A CN116342942A (zh) | 2023-03-16 | 2023-03-16 | 基于多级域适应弱监督学习的跨域目标检测方法 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN116342942A true CN116342942A (zh) | 2023-06-27 |
Family
ID=86888875
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202310258566.3A Pending CN116342942A (zh) | 2023-03-16 | 2023-03-16 | 基于多级域适应弱监督学习的跨域目标检测方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN116342942A (zh) |
Cited By (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116778277A (zh) * | 2023-07-20 | 2023-09-19 | 湖南大学无锡智能控制研究院 | 基于渐进式信息解耦的跨域模型训练方法 |
CN117456309A (zh) * | 2023-12-20 | 2024-01-26 | 合肥综合性国家科学中心人工智能研究院(安徽省人工智能实验室) | 基于中间域引导与度量学习约束的跨域目标识别方法 |
CN117576453A (zh) * | 2023-11-14 | 2024-02-20 | 中国人民解放军陆军装甲兵学院 | 一种跨域装甲目标检测方法、系统、电子设备及存储介质 |
-
2023
- 2023-03-16 CN CN202310258566.3A patent/CN116342942A/zh active Pending
Cited By (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116778277A (zh) * | 2023-07-20 | 2023-09-19 | 湖南大学无锡智能控制研究院 | 基于渐进式信息解耦的跨域模型训练方法 |
CN116778277B (zh) * | 2023-07-20 | 2024-03-01 | 湖南大学无锡智能控制研究院 | 基于渐进式信息解耦的跨域模型训练方法 |
CN117576453A (zh) * | 2023-11-14 | 2024-02-20 | 中国人民解放军陆军装甲兵学院 | 一种跨域装甲目标检测方法、系统、电子设备及存储介质 |
CN117456309A (zh) * | 2023-12-20 | 2024-01-26 | 合肥综合性国家科学中心人工智能研究院(安徽省人工智能实验室) | 基于中间域引导与度量学习约束的跨域目标识别方法 |
CN117456309B (zh) * | 2023-12-20 | 2024-03-15 | 合肥综合性国家科学中心人工智能研究院(安徽省人工智能实验室) | 基于中间域引导与度量学习约束的跨域目标识别方法 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
He et al. | An end-to-end steel surface defect detection approach via fusing multiple hierarchical features | |
CN107563372B (zh) | 一种基于深度学习ssd框架的车牌定位方法 | |
Yuan et al. | Gated CNN: Integrating multi-scale feature layers for object detection | |
CN116342942A (zh) | 基于多级域适应弱监督学习的跨域目标检测方法 | |
EP3620980B1 (en) | Learning method, learning device for detecting lane by using cnn and testing method, testing device using the same | |
CN110909820A (zh) | 基于自监督学习的图像分类方法及系统 | |
CN113807420A (zh) | 一种考虑类别语义匹配的域自适应目标检测方法及系统 | |
CN112883931A (zh) | 基于长短期记忆网络的实时真假运动判断方法 | |
CN104778699A (zh) | 一种自适应对象特征的跟踪方法 | |
CN113657414B (zh) | 一种物体识别方法 | |
CN113808123B (zh) | 一种基于机器视觉的药液袋动态检测方法 | |
CN115019133A (zh) | 基于自训练和标签抗噪的图像中弱目标的检测方法及系统 | |
CN110751005B (zh) | 融合深度感知特征和核极限学习机的行人检测方法 | |
CN114549909A (zh) | 一种基于自适应阈值的伪标签遥感图像场景分类方法 | |
CN114626461A (zh) | 基于领域自适应的跨域目标检测方法 | |
CN113129336A (zh) | 一种端到端多车辆跟踪方法、系统及计算机可读介质 | |
Su et al. | Segmented handwritten text recognition with recurrent neural network classifiers | |
Wu et al. | DA-STD: deformable attention-based scene text detection in arbitrary shape | |
CN113344102A (zh) | 基于图像hog特征与elm模型的目标图像识别方法 | |
Ranjbar et al. | Scene novelty prediction from unsupervised discriminative feature learning | |
Li et al. | A fast detection method for polynomial fitting lane with self-attention module added | |
Xia et al. | Multi-RPN Fusion-Based Sparse PCA-CNN Approach to Object Detection and Recognition for Robot-Aided Visual System | |
Budiarsa et al. | Face recognition for occluded face with mask region convolutional neural network and fully convolutional network: a literature review | |
CN117132997B (zh) | 一种基于多头注意力机制和知识图谱的手写表格识别方法 | |
Ye et al. | Detection & tracking of multi-scenic lane based on segnet-LSTM semantic split network |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |