CN114821282B - 一种基于域对抗神经网络的图像检测装置及方法 - Google Patents

一种基于域对抗神经网络的图像检测装置及方法 Download PDF

Info

Publication number
CN114821282B
CN114821282B CN202210738094.7A CN202210738094A CN114821282B CN 114821282 B CN114821282 B CN 114821282B CN 202210738094 A CN202210738094 A CN 202210738094A CN 114821282 B CN114821282 B CN 114821282B
Authority
CN
China
Prior art keywords
domain
target
training
image
image detection
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN202210738094.7A
Other languages
English (en)
Other versions
CN114821282A (zh
Inventor
李骏
杨苏
周方明
黄伟国
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Suzhou Lichuang Zhiheng Electronic Technology Co ltd
Original Assignee
Suzhou Lichuang Zhiheng Electronic Technology Co ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Suzhou Lichuang Zhiheng Electronic Technology Co ltd filed Critical Suzhou Lichuang Zhiheng Electronic Technology Co ltd
Priority to CN202210738094.7A priority Critical patent/CN114821282B/zh
Publication of CN114821282A publication Critical patent/CN114821282A/zh
Application granted granted Critical
Publication of CN114821282B publication Critical patent/CN114821282B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V20/00Scenes; Scene-specific elements
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/048Activation functions
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/40Extraction of image or video features
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/764Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/77Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
    • G06V10/774Generating sets of training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/82Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V2201/00Indexing scheme relating to image or video recognition or understanding
    • G06V2201/07Target detection

Abstract

本申请提供一种基于域对抗神经网络的图像检测装置及方法。所述图像检测装置包括通过图像检测训练模型按照预设训练方法得到的第一特征提取器以及第一标签分类器。所述图像检测训练模型包括第二特征提取器、第二标签分类器、梯度翻转层、全局域判别器和多个局部域判别器。在训练过程中,使用全局域判别器对齐源域和目标域的边缘分布,使用局部域判别器对齐源域和目标域的条件分布,并通过设置目标域各类样本在局部域判别器损失函数中的权重平衡因子,来解决目标域训练数据集类不平衡带来的图像检测性能下降问题。因此,训练得到的图像检测装置在面对实际工业视觉检测场景中图像数据类别不平衡时,检测准确率高。

Description

一种基于域对抗神经网络的图像检测装置及方法
技术领域
本申请涉及工业视觉检测技术领域,尤其涉及一种基于域对抗神经网络的图像检测装置及方法。
背景技术
基于域对抗神经网络的图像检测模型被广泛应用在工业视觉检测技术领域,例如,可以应用在列车关键部位的故障检测中。首先,通过采集列车关键部位的图像,将采集到的图像输入到训练好的图像检测模型中,由训练好的图像检测模型进行分类,然后根据分类结果判断列车关键部位是否存在故障、以及是什么样的故障。
目前,基于域对抗神经网络的图像检测模型一般由由三个部分组成:特征提取器、域判别器和标签分类器。其中,特征提取器将源域和目标域数据作为输入,输出高层隐含特征。域判别器则是以高层隐含特征作为输入,并区分高层隐含特征是来自哪个域。而标签分类器则是对提取自源域和目标域的高层隐含特征进行分类,尽可能地识别出高层隐含特征的种类。特征提取器和域判别器之间是一种对抗性学习。在模型训练中,域判别器不断被训练优化识别数据的来源,特征提取器则不断被训练用来混淆域判别器,直到达到纳什平衡。在不断迭代性的训练中,特征提取器可以提取可迁移的高层隐含特征。在训练完成之后将目标域测试数据输入到训练好的特征提取器和标签分类器中进行图像检测,得到其故障类别标签。
但是,目前的基于域对抗神经网络的图像检测模型在检测过程中要求源域和目标域各类别样本数量基本平衡,并且仅对源域和目标域的边缘分布进行对齐,并没有考虑到域间条件分布的差异及其与边缘分布对齐的相对重要性。而在实际工业视觉检测应用中,待检测部分一般在不同角度和明暗程度下图像的数据分布不一致;并且同一个部位在不同故障类型下也具有不同的图像数据量,导致图像数据目标集中类别不平衡。
因此,目前的基于域对抗神经网路的图像检测模型,在面对实际工业场景中图像数据集中类别不平衡时,检测准确率不高。
发明内容
为了解决目前的基于域对抗神经网络的图像检测模型在面对实际工业场景中图像数据集中类别不平衡时,检测准确率不高的问题,本申请通过以下方面公开了一种基于域对抗神经网络的图像检测装置及方法。
本申请的第一方面公开一种基于域对抗神经网络的图像检测装置,包括依次相连的第一特征提取器以及第一标签分类器;
第一特征提取器用于提取待检测图像的目标特征向量,将目标特征向量输出至第一标签分类器;
第一标签分类器用于根据目标特征向量,输出待检测图像的目标类别标签;
其中,第一特征提取器和第一标签分类器通过图像检测训练模型按照预设训练方法训练得到,其中,图像检测训练模型包括:第二特征提取器、第二标签分类器、梯度翻转层、全局域判别器和多个局部域判别器;其中,局部域判别器的数量与训练数据集中的故障类别的数量一致;
预设训练方法包括:
获取源域训练数据集、目标域训练数据集,其中,源域训练数据集包括预设数量类的源域训练图像和对应的类别标签,目标域训练数据集包括预设数量类的目标域训练图像,目标域训练数据集中各类别图像的数量不平衡;
第二特征提取器提取源域训练图像或者目标域训练图像的训练特征向量;
第二标签分类器根据训练特征向量,输出对应的类别预测标签,第二标签分类器的损失函数为源域训练图像的标签的交叉熵和目标域训练图像的预测标签的熵;
梯度翻转层翻转训练特征向量的梯度,得到中间特征向量;
全局域判别器根据中间特征向量,输出对应的域类别,域类别为源域或者目标域;其中,全局域判别器的损失函数为源域和目标域边缘分布的Wasserstein距离;
目标局部域判别器用于根据中间特征向量和目标预测概率,输出对应的域类别;其中,目标局部域判别器的损失函数为源域和目标域条件分布的Wasserstein距离,目标预测概率为对应的高层特征向量被第二标签分类器分为目标类的概率,并根据目标预测概率设置目标局部域判别器的损失函数中目标域训练数据集中各类别的权重平衡因子,其中目标局部域判别器为多个局部域判别器中的任意一个;
根据第二标签分类器的损失函数、全局域判别器的损失函数、局部域判别器的损失函数以及预设优化算法训练图像检测训练模型,得到训练后的第二特征提取器和训练后的第二标签分类器,其中,第一特征提取器为训练后的第二特征提取器,第一标签分类器为训练后的第二标签分类器。
在一些可能的实现方式中,目标局部域判别器的损失函数中目标域训练数据集中 各类别的权重平衡因子为
Figure 821804DEST_PATH_IMAGE001
其中,
Figure 676628DEST_PATH_IMAGE001
按照以下公式计算:
Figure 456365DEST_PATH_IMAGE002
其中,m为目标域训练图像的数量,
Figure 851574DEST_PATH_IMAGE003
为目标域训练数据集,
Figure 855827DEST_PATH_IMAGE004
为第二标签分类 器对目标域训练图像
Figure 678289DEST_PATH_IMAGE005
的第c类预测概率。
在一些可能的实现方式中,全局域判别器的损失函数的权重参数为μ,局部域判别器的损失函数的权重参数为1-μ
其中,μ按照以下公式计算:
Figure 883006DEST_PATH_IMAGE006
其中,s表示源域,t表示目标域,
Figure 81906DEST_PATH_IMAGE007
Figure 937735DEST_PATH_IMAGE008
分别指源域和目标域数据的边缘分布,
Figure 931099DEST_PATH_IMAGE009
Figure 623112DEST_PATH_IMAGE010
分别指源域和目标域数据的条件分布,
Figure 360123DEST_PATH_IMAGE011
Figure 70459DEST_PATH_IMAGE012
分别为域 间边缘分布和条件分布的Wasserstein距离。
在一些可能的实现方式中,第二特征提取器包括第一全连接网络或者深度卷积神经网络或者深度置信神经网络或者深度残差神经网络中的一种。
在一些可能的实现方式中,第二标签分类器包括第二全连接网络。
在一些可能的实现方式中,全局域判别器和预设数量个局部域判别器均为第三全连接网络。
在一些可能的实现方式中,预设优化算法为自适应矩估计算法或者随机梯度下降法或者均方根传递算法。
本申请第二方面提供一种基于域对抗神经网络的图像检测方法,包括:
获取待检测图像;
将待检测图像输入至本申请第一方面提供的基于域对抗神经网络的图像检测装置中,得到待检测图像的目标类别标签。
在一些可能的实现方式中,获取目标工业设备的待检测图像,包括:
获取目标物的初始图像;
对初始图像按照预设像素压缩,得到压缩后的图像;
将压缩后的图像按照预设尺寸裁剪,得到待检测图像。
本申请第三方面提供一种终端装置,包括:
至少一个处理器和存储器;
存储器,用于存储程序指令;
处理器,用于调用并执行存储器中存储的程序指令,以使终端装置执行如本申请第二方面提供的基于域对抗神经网络的图像检测方法。
本申请提供一种基于域对抗神经网络的图像检测装置及方法。所述图像检测装置包括通过图像检测训练模型按照预设训练方法得到的第一特征提取器以及第一标签分类器。所述图像检测训练模型包括第二特征提取器、第二标签分类器、梯度翻转层、全局域判别器和多个局部域判别器。在训练过程中,使用全局域判别器对齐源域和目标域的边缘分布,使用局部域判别器对齐源域和目标域的条件分布,并通过设置目标域各类样本在局部域判别器损失函数中的权重平衡因子,来解决目标域训练数据集类不平衡带来的图像检测性能下降问题。因此,训练得到的图像检测装置在面对实际工业视觉检测场景中图像数据类别不平衡时,检测准确率高。
附图说明
为了更清楚地说明本申请的技术方案,下面将对实施例中所需要使用的附图作简单地介绍,显而易见地,对于本领域普通技术人员而言,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1为本申请实施例提供的一种基于域对抗神经网络的图像检测装置的结构示意图;
图2为本申请实施例提供的一种基于域对抗神经网络的图像检测装置中用于训练使用的图像检测训练模型的结构示意图;
图3为不同状态下的齿轮箱螺栓防松动铁丝图像示例;
图4为本申请实施例提供的所述图像检测装置和对比模型的可视化聚类结果示意图;
图5为本申请提供的所述图像检测装置和对比模型输出结果的混淆矩阵示意图。
具体实施方式
为了解决目前的基于域对抗神经网络的图像检测模型,在面对实际工业场景中图像数据集中类别不平衡时,检测准确率不高的问题,本申请通过以下实施例公开了一种基于域对抗神经网络的图像检测装置及方法。参见图1,本申请第一实施例公开的一种基于域对抗神经网络的图像检测装置包括:依次相连的第一特征提取器以及第一标签分类器。
所述第一特征提取器用于提取待检测图像的目标特征向量,输出至所述第一标签分类器。
所述第一标签分类器用于根据所述目标特征向量,输出待检测图像的目标类别标签。
其中,所述第一特征提取器和所述第一标签分类器通过图像检测训练模型按照预设训练方法训练得到。在本实施例中,图像检测训练模型是一种经过改进的域对抗神经网络模型。参见图2,所述图像检测训练模型包括:第二特征提取器、第二标签分类器、梯度翻转层、全局域判别器和多个局部域判别器;其中,所述局部域判别器的数量与训练数据集中的故障类别的数量一致。
所述预设训练方法包括:
步骤301,获取源域训练数据集、目标域训练数据集,其中,所述源域训练数据集包括预设数量类的源域训练图像和对应的类别标签,所述目标域训练数据集包括预设数量类的目标域训练图像,所述目标域训练数据集中各类别图像的数量不平衡。在实际应用中,所述目标域训练数据集中也包括对应的类别标签,但不参与训练过程,仅用于评价模型预测结果的准确性。
在一种实现方式中,要对提供的训练数据进行预处理,统一格式及尺寸,才能作为源域训练数据集、目标域训练数据集。对应的,输入到所述第一特征提取器中的待检测图像,也是经过预先的格式及尺寸处理。
步骤302,第二特征提取器提取所述源域训练图像或者所述目标域训练图像的训练特征向量。第二特征提取器也就是训练前的第一特征提取器,用于将输入的图像数据映射到高层特征空间,以训练图像作为输入,输出训练图像的高层隐含特征(即训练特征向量)。
在本实施例中,第二特征提取器由第一全连接网络、深度卷积神经网络、深度置信神经网络、深度残差神经网络中的一种神经网络构建,但不限于以上提到的神经网络。
步骤303,第二标签分类器根据所述训练特征向量,输出对应的目标类别预测标签,所述第二标签分类器的损失函数为所述源域训练图像的标签的交叉熵和所述目标域训练图像的预测标签的熵。所述第二标签分类器也就是训练前的第一标签分类器,用于通过预测训练图像的标签实现图像分类,得到对应的类别预测标签。
在本实施例中,第二标签分类器包括第二全连接网络。示例性的,第二全连接网络共设计三层,其中隐含层维数分别为256、256、4,三层全连接层后分别连接ReLU、ReLU和Softmax激活函数,第二全连接网络最后输出一个四维向量来表示输入数据的类别。其中最后一层的隐含层维数可以理解为故障类别数量。
步骤304,梯度翻转层翻转所述训练特征向量的梯度,得到中间特征向量。
步骤305,全局域判别器根据所述中间特征向量,输出对应的域类别,所述域类别为源域或者目标域;其中,所述全局域判别器的损失函数为源域和目标域边缘分布的Wasserstein距离。
步骤306,目标局部域判别器用于根据所述中间特征向量和目标预测概率,输出对应的域类别;其中,所述目标局部域判别器的损失函数为源域和目标域条件分布的Wasserstein距离,所述目标预测概率为对应的高层特征向量被所述第二标签分类器分为目标类的概率,并根据所述目标预测概率设置所述目标局部域判别器的损失函数中目标域训练数据集中各类别的权重平衡因子,其中目标局部域判别器为所述多个局部域判别器中的任意一个。
权重平衡因子用于目标域各类样本在局部域判别器损失函数中的权重平衡。目标 局部域判别器的权重平衡因子为
Figure 500304DEST_PATH_IMAGE001
其中,
Figure 679612DEST_PATH_IMAGE001
按照以下公式计算:
Figure 954736DEST_PATH_IMAGE013
其中,所述m为目标域训练图像的数量,
Figure 785157DEST_PATH_IMAGE014
为所述目标域训练数据集,
Figure 120324DEST_PATH_IMAGE004
为所述 第二标签分类器对目标域训练图像
Figure 786928DEST_PATH_IMAGE005
的第c类预测概率。
为了进一步提高图像检测装置的检测精度和泛化能力,在一些实施例中,可以通过设置动态平衡因子μ来评价便于分布对齐和条件分布对齐的相对重要性。将μ作为全局域判别器损失函数的权重参数,1-μ作为局部域判别器损失函数的权重参数。动态平衡因子μ按照以下公式计算:
Figure 865743DEST_PATH_IMAGE015
其中,s表示源域,t表示目标域,
Figure 547741DEST_PATH_IMAGE007
Figure 319388DEST_PATH_IMAGE008
分别指源域和目标域数据的边缘分布,
Figure 207710DEST_PATH_IMAGE009
Figure 90215DEST_PATH_IMAGE010
分别指源域和目标域数据的条件分布,
Figure 629650DEST_PATH_IMAGE011
Figure 572198DEST_PATH_IMAGE012
分别为域 间边缘分布和条件分布的Wasserstein距离。
在本实施例中,全局域判别器用于来对齐数据边缘分布,局部域判别器用于对齐条件分布。所述全局域判别器和预设数量个所述局部域判别器均为第三全连接网络。示例性的,第三全连接网络共设三层,其中隐含层维数分别为256、256、1,每个全连接层后分别连接ReLU、ReLU和Sigmoid激活函数,第三全连接网络最后输出一个一维向量来表示输入数据的域类别。通过全局域判别器和局部域判别器的输出来计算目标域中各类别样本的权重平衡因子、全局域判别损失、局部域判别损失。
需要说明的是,第二全连接网络或者第三全连接网络的最后一个激活函数层中的激活函数不限于使用上述示例中的Softmax或者Sigmoid函数。
步骤307,根据所述第二标签分类器的损失函数、所述全局域判别器的损失函数、所述局部域判别器的损失函数以及预设优化算法训练所述图像检测训练模型,得到训练后的第二特征提取器和训练后的第二标签分类器,其中,所述第一特征提取器为所述训练后的第二特征提取器,所述第一标签分类器为所述训练后的第二标签分类器。
在本实施例中,所述预设优化算法为自适应矩估计算法、随机梯度下降法、均方根传递算法中的一种,但不限于以上列举的迭代优化算法。
所述图像检测训练模型的训练过程为:将源域和目标域训练图像经特征提取器提取的高层隐含特征输入标签分类器、全局域判别器和局部域判别器。对来自源域的带标签数据,模型不断最小化所述第二标签分类器的损失(包含源域样本预测标签的交叉熵损失和目标域样本预测标签的熵损失),对来自源域和目标域的全部数据,网络不断最小化域判别器损失(包含全局域判别器损失和局部域判别器损失)。全局域判别器和局部域判别器的训练目标是尽量将输入的特征分到正确的域类别,而第二特征提取器所提取的特征的目标是使全局域判别器和局部域判别器不能正确的判断出特征来自哪一个域,因此形成一种对抗关系。当对抗训练使模型达到纳什平衡时停止训练。
本实施例提供一种基于域对抗神经网络的图像检测装置。所述图像检测装置包括通过图像检测训练模型按照预设训练方法得到的第一特征提取器以及第一标签分类器。所述图像检测训练模型包括第二特征提取器、第二标签分类器、梯度翻转层、全局域判别器和多个局部域判别器。在训练过程中,使用全局域判别器对齐源域和目标域的边缘分布,使用局部域判别器对齐源域和目标域的条件分布,并通过设置目标域各类样本在局部域判别器损失函数中的权重平衡因子,来解决目标域训练数据集类不平衡带来的图像检测性能下降问题。因此,训练得到的图像检测装置在面对实际工业视觉检测常见中图像数据类别不平衡时,检测准确率高。
进一步地,通过动态平衡因子评价边缘分布对齐和条件分布对齐的相对重要性,来提高装置的检测精度和泛化能力。
本申请第二实施例提供一种基于域对抗神经网络的图像检测方法,所述方法包括:
步骤401,获取待检测图像;
步骤402,将所述待检测图像输入至本申请第一实施例提供的基于域对抗神经网络的图像检测装置中,得到待检测图像的目标类别标签。
也就是说,本申请第二实施例提供的图像检测方法,是使用本申请第一实施例提供的所述图像检测装置来进行检测。
在一种实现方式中,所述获取列车关键部位的待检测图像,包括:
步骤4011,获取目标物的初始图像;
步骤4012,对所述初始图像按照预设像素压缩,得到压缩后的图像;
步骤4013,将所述压缩后的图像按照预设尺寸裁剪,得到待检测图像。
本实施例提供的所述方法在应用在工业视觉检测领域时的作用效果可参见本申请第一实施例中的说明,在此不再赘述。
为了更加清楚地了解本申请的技术方案及其效果,下面结合一个具体的示例进行详细说明。
以列车车底齿轮箱螺栓防松动铁丝故障的故障检测为例,由车底检测机器人拍摄齿轮箱螺栓防松动铁丝图像。如图3所示,防松动铁丝的状态类别包含正常(N)、绷断_位置1(B1)、绷断_位置2(B2)和缺失(L)等4种,类别标签分别用0、1、2和3来表示。实验中将图像数据集划分为源域训练数据集、目标域训练数据集和目标域测试数据集,源域和目标域包含多种不同角度和明暗程度下的图像。源域和目标域中均包含N、B1、B2和L等4种状态的图像样本。
步骤501,对车底检测机器人拍摄的齿轮箱螺丝防松动铁丝图像进行压缩裁剪,统一图像尺寸,把图像划分为源域训练数据集、目标域训练数据集和目标域测试数据集。源域训练数据集中各类样本数量均为100,目标域训练数据集中4类样本数量分别为100、30、10和5,目标域测试数据集中各类别样本数量均为50。
齿轮箱铁丝图像的关键部分位于图像中心部位,先对图像进行压缩像素为512*512,然后从图像中心裁剪出像素为400*400的图像作为图像检测训练模型的输入。
步骤502,建立图像检测训练模型。其中,图像检测训练模型包括第二特征提取器、第二标签分类器、梯度翻转层、全局域判别器和4个局部域判别器。
第二特征提取器采用经典的深度残差网络ResNet-18,以预处理后的RGB三通道图片作为输入,输出长度为512的高层隐含特征向量。
第二标签分类器采用全连接网络,共设计三层,其中隐含层维数分别为256、256、4,三层全连接层后分别连接ReLU、ReLU和Softmax激活函数,第二标签分类器最后输出一个四维向量来表示输入数据的类别。通过第二标签分类器的标签输出来计算源域样本预测标签的交叉熵和目标域样本预测标签的熵。
全局域判别器和4个局部域判别器均采用相同的全连接网络,共设计三层,其中隐含层维数分别为256、256、1,每个全连接层后分别连接ReLU、ReLU和Sigmoid激活函数,全局域判别器和局部域判别器最后输出一个一维向量来表示输入数据的域类别。通过全局域判别器和局部域判别器的输出来计算目标域中各类别样本的权重平衡因子、全局域判别损失、局部域判别损失、边缘分布和条件分布的动态平衡因子。
步骤503,训练图像检测训练模型。将带有标签的源域训练数据集和无标签的目标域训练数据集输入到构建的图像检测训练模型中,根据损失函数和优化算法进行模型训练。
图像检测训练模型的目标损失函数包含第二标签分类器损失和全局域判别器损失和局部域判别器损失。
优化算法采用随机梯度下降算法(Stochastic Gradient Descent,SGD),学习率为0.01,动量为0.9,迭代150次后模型目标函数损失趋于平衡,结束模型训练。
步骤504,使用训练后的第二特征提取器和训练后的第二标签分类器组成图像检测装置。将目标域测试数据集中的测试图像依次输入到所述图像检测装置中,在线输出故障类别。
为了验证本申请的有效性,分别使用本申请的图像检测装置和对比模型的检测结果进行对比。其中,对比模型为无条件分布对齐和权重平衡的域对抗神经网络模型训练得到的检测装置。本申请和对比模型利用t-SNE进行各健康状态图片样本的特征聚类结果可视化,分别如图4中的(a)、图4中的(b)所示,其中S和T分别为源域训练样本和目标域测试样本。由图4看出,相比对比模型,本申请提供的图像检测装置可以有效地减小源域和目标域中相同类别样本特征的数据分布距离,并使不同类别样本特征之间的距离变大,且只有少部分不同类别样本产生了混淆。本申请提供的图像检测装置和对比模型输出结果的混淆矩阵分别如图5中的(a)、图5中的(b)所示。从图5中可以看出本申请提供的图像障检测装置的诊断准确率很高,达到了98.5%。本申请提供的图像检测装置的输出结果中只有三个样本被错误分类,且具有最少样本数量的L类别全部分类准确,而对比模型将目标域少数类训练样本L类别错误分类到了多数类样本中,诊断准确率仅有86%。
综上所述,对源域和目标域样本同时进行数据边缘分布和条件分布动态对齐,并通过权重平衡因子对目标域不平衡图像样本进行条件分布加权对齐,以及利用动态平衡因子评价边缘分布对齐和条件分布对齐的相对重要性,可以提高图像可迁移特征的提取能力,突破源域和目标域数据分布不一致的限制,解决目标域中类不平衡问题,实现列车关键部位故障的精确故障检测。
本申请第三实施例提供一种终端装置,包括:
至少一个处理器和存储器;
存储器,用于存储程序指令;
处理器,用于调用并执行存储器中存储的程序指令,以使终端装置执行如本申请第二实施例提供的基于域对抗神经网络的图像检测方法。
以上结合具体实施方式和范例性实例对本申请进行了详细说明,不过这些说明并不能理解为对本申请的限制。本领域技术人员理解,在不偏离本申请精神和范围的情况下,可以对本申请技术方案及其实施方式进行多种等价替换、修饰或改进,这些均落入本申请的范围内。本申请的保护范围以所附权利要求为准。
本说明书中各个实施例之间相似相同部分互相参见即可。

Claims (10)

1.一种基于域对抗神经网络的图像检测装置,其特征在于,包括依次相连的第一特征提取器以及第一标签分类器;
所述第一特征提取器用于提取待检测图像的目标特征向量,将所述目标特征向量输出至所述第一标签分类器;
所述第一标签分类器用于根据所述目标特征向量,输出待检测图像的目标类别标签;
其中,所述第一特征提取器和所述第一标签分类器通过图像检测训练模型按照预设训练方法训练得到,其中,所述图像检测训练模型包括:第二特征提取器、第二标签分类器、梯度翻转层、全局域判别器和多个局部域判别器;其中,所述局部域判别器的数量与训练数据集中的故障类别的数量一致;
所述预设训练方法包括:
获取源域训练数据集、目标域训练数据集,其中,所述源域训练数据集包括预设数量类的源域训练图像和对应的类别标签,所述目标域训练数据集包括预设数量类的目标域训练图像,所述目标域训练数据集中各类别图像的数量不平衡;
第二特征提取器提取所述源域训练图像或者所述目标域训练图像的训练特征向量;
第二标签分类器根据所述训练特征向量,输出对应的类别预测标签,所述第二标签分类器的损失函数为所述源域训练图像的标签的交叉熵和所述目标域训练图像的预测标签的熵;
梯度翻转层翻转所述训练特征向量的梯度,得到中间特征向量;
全局域判别器根据所述中间特征向量,输出对应的域类别,所述域类别为源域或者目标域;其中,所述全局域判别器的损失函数为源域和目标域边缘分布的Wasserstein距离;
目标局部域判别器用于根据所述中间特征向量和目标预测概率,输出对应的域类别;其中,所述目标局部域判别器的损失函数为源域和目标域条件分布的Wasserstein距离,所述目标预测概率为对应的高层特征向量被所述第二标签分类器分为目标类的概率,并根据所述目标预测概率设置所述目标局部域判别器的损失函数中目标域训练数据集中各类别的权重平衡因子,其中目标局部域判别器为所述多个局部域判别器中的任意一个;
根据所述第二标签分类器的损失函数、所述全局域判别器的损失函数、所述局部域判别器的损失函数以及预设优化算法训练所述图像检测训练模型,得到训练后的第二特征提取器和训练后的第二标签分类器,其中,所述第一特征提取器为所述训练后的第二特征提取器,所述第一标签分类器为所述训练后的第二标签分类器。
2.根据权利要求1所述的基于域对抗神经网络的图像检测装置,其特征在于,所述目标 局部域判别器的损失函数中目标域训练数据集中各类别的权重平衡因子为
Figure 889757DEST_PATH_IMAGE001
其中,
Figure 139472DEST_PATH_IMAGE001
按照以下公式计算:
Figure 813162DEST_PATH_IMAGE002
其中,所述m为目标域训练图像的数量,
Figure 60603DEST_PATH_IMAGE003
为所述目标域训练数据集,
Figure 131327DEST_PATH_IMAGE004
为所述第二 标签分类器对目标域训练图像
Figure 371685DEST_PATH_IMAGE005
的第c类预测概率。
3.根据权利要求1所述的基于域对抗神经网络的图像检测装置,其特征在于,所述全局域判别器的损失函数的权重参数为μ,所述局部域判别器的损失函数的权重参数为1-μ
其中,μ按照以下公式计算:
Figure 211465DEST_PATH_IMAGE006
其中,s表示源域,t表示目标域,
Figure 629808DEST_PATH_IMAGE007
Figure 922249DEST_PATH_IMAGE008
分别指源域和目标域数据的边缘分布,
Figure 464832DEST_PATH_IMAGE009
Figure 159119DEST_PATH_IMAGE010
分别指源域和目标域数据的条件分布,
Figure 748363DEST_PATH_IMAGE011
Figure 793680DEST_PATH_IMAGE012
分别为域间边 缘分布和条件分布的Wasserstein距离。
4.根据权利要求1所述的基于域对抗神经网络的图像检测装置,其特征在于,所述第二特征提取器包括第一全连接网络或者深度卷积神经网络或者深度置信神经网络或者深度残差神经网络中的一种。
5.根据权利要求1所述的基于域对抗神经网络的图像检测装置,其特征在于,所述第二标签分类器包括第二全连接网络。
6.根据权利要求1所述的基于域对抗神经网络的图像检测装置,其特征在于,所述全局域判别器和预设数量个所述局部域判别器均为第三全连接网络。
7.根据权利要求1所述的基于域对抗神经网络的图像检测装置,其特征在于,所述预设优化算法为自适应矩估计算法或者随机梯度下降法或者均方根传递算法。
8.一种基于域对抗神经网络的图像检测方法,其特征在于,包括:
获取待检测图像;
将所述待检测图像输入至如权利要求1-7任一项所述的基于域对抗神经网络的图像检测装置中,得到待检测图像的目标类别标签。
9.根据权利要求8所述的基于域对抗神经网络的图像检测方法,其特征在于,所述获取待检测图像,包括:
获取目标物的初始图像;
对所述初始图像按照预设像素压缩,得到压缩后的图像;
将所述压缩后的图像按照预设尺寸裁剪,得到待检测图像。
10.一种终端装置,其特征在于,包括:
至少一个处理器和存储器;
所述存储器,用于存储程序指令;
所述处理器,用于调用并执行所述存储器中存储的程序指令,以使所述终端装置执行如权利要求8-9任一项所述的基于域对抗神经网络的图像检测方法。
CN202210738094.7A 2022-06-28 2022-06-28 一种基于域对抗神经网络的图像检测装置及方法 Active CN114821282B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202210738094.7A CN114821282B (zh) 2022-06-28 2022-06-28 一种基于域对抗神经网络的图像检测装置及方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202210738094.7A CN114821282B (zh) 2022-06-28 2022-06-28 一种基于域对抗神经网络的图像检测装置及方法

Publications (2)

Publication Number Publication Date
CN114821282A CN114821282A (zh) 2022-07-29
CN114821282B true CN114821282B (zh) 2022-11-04

Family

ID=82523147

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202210738094.7A Active CN114821282B (zh) 2022-06-28 2022-06-28 一种基于域对抗神经网络的图像检测装置及方法

Country Status (1)

Country Link
CN (1) CN114821282B (zh)

Families Citing this family (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115063459B (zh) * 2022-08-09 2022-11-04 苏州立创致恒电子科技有限公司 点云配准方法及装置、全景点云融合方法及系统
CN117011571A (zh) * 2022-09-30 2023-11-07 腾讯科技(深圳)有限公司 图像分类模型的训练方法、装置及设备
CN115880538A (zh) * 2023-02-17 2023-03-31 阿里巴巴达摩院(杭州)科技有限公司 图像处理模型的域泛化、图像处理的方法及设备
CN116129198B (zh) * 2023-04-12 2023-07-18 山东建筑大学 一种多域轮胎花纹图像分类方法、系统、介质及设备

Family Cites Families (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113392967A (zh) * 2020-03-11 2021-09-14 富士通株式会社 领域对抗神经网络的训练方法
CN114358124B (zh) * 2021-12-03 2024-03-15 华南理工大学 基于深度对抗卷积神经网络的旋转机械新故障诊断方法
CN114492574A (zh) * 2021-12-22 2022-05-13 中国矿业大学 基于高斯均匀混合模型的伪标签损失无监督对抗域适应图片分类方法

Also Published As

Publication number Publication date
CN114821282A (zh) 2022-07-29

Similar Documents

Publication Publication Date Title
CN114821282B (zh) 一种基于域对抗神经网络的图像检测装置及方法
CN112990432B (zh) 目标识别模型训练方法、装置及电子设备
CN111507370A (zh) 获得自动标注图像中检查标签的样本图像的方法和装置
CN108846835A (zh) 基于深度可分离卷积网络的图像变化检测方法
CN105574550A (zh) 一种车辆识别方法及装置
CN111695466B (zh) 一种基于特征mixup的半监督极化SAR地物分类方法
CN110135505B (zh) 图像分类方法、装置、计算机设备及计算机可读存储介质
CN115731164A (zh) 基于改进YOLOv7的绝缘子缺陷检测方法
CN111368690A (zh) 基于深度学习的海浪影响下视频图像船只检测方法及系统
JP2020126613A (ja) イメージを分析するために、ディープラーニングネットワークに利用するためのトレーニングイメージに対するラベルリング信頼度を自動的に評価するための方法、及びこれを利用した信頼度評価装置
CN109543760A (zh) 基于图像滤镜算法的对抗样本检测方法
CN110222604A (zh) 基于共享卷积神经网络的目标识别方法和装置
CN116503399B (zh) 基于yolo-afps的绝缘子污闪检测方法
CN111523558A (zh) 一种基于电子围网的船只遮挡检测方法、装置及电子设备
CN113569981A (zh) 一种基于单阶段目标检测网络的电力巡检鸟窝检测方法
CN110321867B (zh) 基于部件约束网络的遮挡目标检测方法
CN116912796A (zh) 一种基于新型动态级联YOLOv8的自动驾驶目标识别方法及装置
CN111598854A (zh) 基于丰富鲁棒卷积特征模型的复杂纹理小缺陷的分割方法
CN111539456A (zh) 一种目标识别方法及设备
CN111553184A (zh) 一种基于电子围网的小目标检测方法、装置及电子设备
CN113723553A (zh) 一种基于选择性密集注意力的违禁物品检测方法
CN116485796B (zh) 害虫检测方法、装置、电子设备及存储介质
CN114821200B (zh) 一种应用于工业视觉检测领域的图像检测模型及方法
CN113379685A (zh) 一种基于双通道特征比对模型的pcb板缺陷检测方法及装置
CN109086737A (zh) 基于卷积神经网络的航运货物监控视频识别方法及系统

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant