CN113920302A - 基于交叉注意力机制的多头弱监督目标检测方法 - Google Patents

基于交叉注意力机制的多头弱监督目标检测方法 Download PDF

Info

Publication number
CN113920302A
CN113920302A CN202111037829.5A CN202111037829A CN113920302A CN 113920302 A CN113920302 A CN 113920302A CN 202111037829 A CN202111037829 A CN 202111037829A CN 113920302 A CN113920302 A CN 113920302A
Authority
CN
China
Prior art keywords
candidate region
feature
head
image
class
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202111037829.5A
Other languages
English (en)
Inventor
李浥东
李慧芳
董海荣
韩瑜珊
金�一
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Beijing Jiaotong University
Original Assignee
Beijing Jiaotong University
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Beijing Jiaotong University filed Critical Beijing Jiaotong University
Priority to CN202111037829.5A priority Critical patent/CN113920302A/zh
Publication of CN113920302A publication Critical patent/CN113920302A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/22Matching criteria, e.g. proximity measures
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/25Fusion techniques
    • G06F18/253Fusion techniques of extracted features
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Data Mining & Analysis (AREA)
  • Physics & Mathematics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • Evolutionary Computation (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Evolutionary Biology (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Computational Linguistics (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Health & Medical Sciences (AREA)
  • General Health & Medical Sciences (AREA)
  • Molecular Biology (AREA)
  • Computing Systems (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Image Analysis (AREA)

Abstract

本发明提供了一种基于交叉注意力机制的多头弱监督目标检测方法,包括:获取图像,对图像进行处理,将处理后的图像分成训练集和测试集;定义类原型特征,构建包含类原型特征的基于交叉注意力机制的弱监督目标检测网络WCAN模型;采用训练数据集对WCAN模型进行训练;基于训练好的WCAN模型,对测试集图像进行目标检测。本方法在只有类别标签的条件下,能够更全面地、充分地感知位置和类别信息,实现更准确高效的目标检测。

Description

基于交叉注意力机制的多头弱监督目标检测方法
技术领域
本发明涉及计算机视觉领域,尤其涉及一种基于交叉注意力机制的多头弱监督目标检测方法。
背景技术
目标检测是找出图像或者视频中所有感兴趣的目标,并获得这一目标的类别信息和位置信息,目标检测是视觉识别和分析的基础,是计算机视觉领域的核心问题之一。近年来,随着深度学习的发展,目标检测模型的性能得到了极大的提高。但是,这些研究大多是针对全监督目标检测而提出的,全监督模型的学习需要大规模的且具有精确的目标框标注的数据。目前,精确的标注多是由算法工作人员手工标注完成,标注工作耗费人力、物力和财力,尤其是当一副图像中存在多个目标的时候,需要逐个对目标进行手工标注。为了解决标注的问题,相关研究人员提出了弱监督目标检测,该方法只需要标定这副图像中包含的目标的类别信息,提供图像级别的标注信息。这种标注方式,极大地减轻了标注工作量,将标注的速度提升10倍以上。因此,如何提高弱监督目标检测的性能成为当前的研究热点。
弱监督目标检测的任务是在给定图像中存在目标的类别的条件下,通过算法找出所有的目标并定位。多示例学习(MIL,Multiple Instance Learning)算法是弱监督目标检测最常用的方法,通过把整张图片作为一个包,并通过选择性搜索算法(SS,SelectiveSearch)或多尺度组合分组方法(MCG,Multiscale Combinatorial Grouping)生成大量的候选框,把这些候选框作为对象,通过多示例学习算法,训练一个分类器为这些候选框赋予类别标签。在此基础上,现有技术中使用的方法迭代地进行候选框选择和外观模型估计,并逐步检测出目标。但是该迭代方法很容易使得模型陷入局部最优。针对上述问题,提出了一个简洁的端到端的弱监督目标检测框架WSDDN,其独立地执行候选框区域分类和候选区域定位,然后将两个分支的输出集成到图像级监督学习中,隐式地学习目标检测器。但是,该方法中的两个分支是独立学习的,定位分支没有考虑整体语义信息,可能输出错误的候选区域定位得分,分类分支中的分类器只关注目标局部判别性部位,不能取得最优的候选区域得分。两个独立的分支不能全面地感知类别和位置信息,这不符合目标的要求。为了解决这个问题,现有的一些方法通过利用上下文信息,引入额外的分割模块或者训练在线对象优化模块来提高弱监督面部检测的性能。虽然这些方法对现有的方法进行了改进,但是它们的基分类器都是基于WSDDN框架,这个基础模型阻碍了已有方法性能的进一步提升。
因此,亟需一种可以防止陷入局部最优,能实现更准确的目标检测的方法。
发明内容
本发明提供了一种基于交叉注意力机制的多头弱监督目标检测方法,以突破现有的技术瓶颈,解决现有技术问题中存在的缺陷。
为了实现上述目的,本发明采取了如下技术方案。
一种基于交叉注意力机制的多头弱监督目标检测方法,包括:
S1获取图像,对所述图像进行处理,将处理后的图像分成训练集和测试集;
S2定义类原型特征,构建包含类原型特征的基于交叉注意力机制的弱监督目标检测网络WCAN模型;
S3采用训练数据集对WCAN模型进行训练;
S4基于训练好的WCAN模型,对测试集图像进行目标检测。
优选地,对图像进行处理,包括:对每个图像,标注图像中包含的类别标签,采用选择性搜索方法或多尺度组合分组方法获取数据集中每个图像中的候选区域。
优选地,定义类原型特征,包括:定义类原型特征
Figure BDA0003247948740000031
每个类原型特征Ei对应一个特定的语义类别,C是类别数量,D是类原型特征的维度,
Figure BDA0003247948740000032
表示实数。
优选地,WCAN模型包括候选区域特征提取网络、类原型特征和多头弱监督目标检测网络;
所述多头弱监督目标检测网络包括并列的类别感知的交叉注意力模块和定位感知的交叉注意力模块。
优选地,WCAN模型还包括多头弱监督位置定位网络,多头弱监督位置定位网络包括位置感知的交叉注意力模块。
优选地,采用训练数据集对WCAN模型进行训练,包括:
S31对WCAN模型进行初始化;
S32将训练集中的图像输入至WCAN模型的候选区域特征提取网络,通过主干网络进行特征图提取;
S33将提取后的图像特征图与对应的候选区域同时输入至区域池化层提取候选区域特征Fp;候选区域特征Fp和类原型特征E同时输入至多头弱监督目标检测网络中的类感知的交叉注意力模块和定位感知的交叉注意力模块,分别得到多头候选区域类别相似性得分和多头候选区域目标相似性得分,对同一区域的多头候选区域类别相似性得分和多头候选区域目标相似性得分融合得到候选区域得分,并对多头候选区域得分聚合得到图像得分Si
S34采用交叉熵损失函数计算对每个图像得分Si和图像标签之间的误差,并对多个误差值求和,使用随机梯度下降算法更新WCAN模型中主干网络、类原型特征和多头弱监督目标检测网络模型的参数;
S35重复步骤S32-S34,直至交叉熵损失函数的结果不再降低,得到训练好的WCAN模型。
优选地,S33和S34之间还包括:
将主干网络特征提取后的图像特征图在空间维度进行形状变换得到位置特征,采用线性映射层将位置特征维度对齐到类原型特征维度;
将位置特征Fpi和类原型特征E输入至多头弱监督位置定位网络中的位置感知的交叉注意力模块,得到多头位置得分,分别对多头位置得分进行聚合,得到多头图像得分
Figure BDA0003247948740000041
采用交叉熵损失函数计算对每个图像得分
Figure BDA0003247948740000042
和图像标签之间的误差,并对多个误差值求和,使用随机梯度下降算法更新WCAN模型中主干网络、类原型特征和多头弱监督目标检测网络模型的参数。
优选地,候选区域特征Fp和类原型特征E同时输入至多头弱监督目标检测网络中的类感知的交叉注意力模块和定位感知的交叉注意力模块,分别得到多头候选区域类别相似性得分和多头候选区域目标相似性得分,具体包括:
对于多头弱监督目标检测网络中类感知的交叉注意力模块:采用线性映射层FCck对候选区域特征进行映射,采用层归一化层,线性映射层FCcq对类原型特征进行映射,然后在特征维度对映射后的类原型特征和映射后的候选区域特征进行切分;接下来用dot-product操作计算第i个类原型特征切片和第i个候选区域特征切片的相似性得分,再在类别维度对相似性得分进行softmaxc归一化得到第i个候选区域类别相似性得分Pi,c,共可得到与切分个数相同个数的候选区域类别相似性得分;
对于多头弱监督目标检测网络中定位感知的交叉注意力模块:采用线性映射层FCdk对输入的候选区域特征进行映射,采用层归一化层,线性映射层FCdq对类原型特征进行映射,然后对映射后的类原型特征和映射后的候选区域特征进行切片,得到类原型特征切片和候选区域特征切片;接下来用dot-product操作计算第i个类原型特征切片和第i个候选区域特征切片的相似性得分,再在候选区域维度对相似性得分进行softmaxd归一化得到第i个候选区域目标相似性得分Pi,d,共可得到与切片个数相同个数的候选区域目标相似性得分。
优选地,将位置特征Fpi和类原型特征E输入至多头弱监督位置定位网络中的位置感知的交叉注意力模块,得到多头位置得分,具体包括:
采用线性映射层FCpi_k对位置特征进行映射,用层归一化层,线性映射层FCpi_q对类原型进行映射;然后在特征维度对映射后的类原型特征和映射后的位置特征进行切分;接下来采用dot-product操作计算第i个类原型特征切片和第i个位置特征切片的相似性得分,再在位置维度对相似性得分进行softmaxl归一化得到第i个位置得分PIi,共可得到与切分个数相同个数的位置得分。
优选地,S4包括:
S41用主干网络提取测试集中的图像特征图;
S42用区域池化层提取图像的候选区域特征;
S43将候选区域特征和类原型特征输入至基于多头弱监督目标检测网络中的类感知的交叉注意力模块和定位感知的交叉注意力模块得到多头候选区域得分;
S44对所有多头候选区域得分求平均值得到最终的候选区域得分;
S45采用NMS算法去除重叠的和低分的候选区域,得到最终的目标检测结果。
由上述本发明的基于交叉注意力机制的多头弱监督目标检测方法提供的技术方案可以看出,本发明考虑到弱监督条件下只有类别标签,为了使模型综合考虑分类和定位任务,定义类原型特征,从类别感知和定位维度与候选区域特征同时执行交叉注意力学习,从而鼓励模型学习对分类和定位同时敏感的知识,更符合目标检测的任务要求;提出了一个多检测头来实现弱监督目标检测,引导网络同时关注不同的目标区域,从而防止陷入局部最优,获得更完整的检测结果,缓解了弱监督目标检测模型只关注局部最显眼的部位的问题;提出在细粒度的图像特征上进行原型特征的优化,防止引入背景特征,学习更准确的类原型特征。因此,本发明在只有类别标签的条件下,模型能够更全面地、充分地感知位置和类别信息,目标检测更准确高效。
本发明附加的方面和优点将在下面的描述中部分给出,这些将从下面的描述中变得明显,或通过本发明的实践了解到。
附图说明
为了更清楚地说明本发明实施例的技术方案,下面将对实施例描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1为本实施例的基于交叉注意力机制的多头弱监督目标检测方法流程示意图;
图2为基于交叉注意力机制的多头弱监督检测网络模型结构示意图;
图3为训练流程示意图;
图4为测试流程示意图;
图5为多头弱监督目标检测网络中类感知的交叉注意力模块处理示意图;
图6为多头弱监督目标检测网络中定位感知的交叉注意力模块处理示意图;
图7为多头弱监督位置定位网络中的位置感知的交叉注意力模块处理示意图;
图8为检测结果示意图。
具体实施方式
下面详细描述本发明的实施方式,所述实施方式的示例在附图中示出,其中自始至终相同或类似的标号表示相同或类似的元件或具有相同或类似功能的元件。下面通过参考附图描述的实施方式是示例性的,仅用于解释本发明,而不能解释为对本发明的限制。
本技术领域技术人员可以理解,除非特意声明,这里使用的单数形式“一”、“一个”、“所述”和“该”也可包括复数形式。应该进一步理解的是,本发明的说明书中使用的措辞“包括”是指存在所述特征、整数、步骤、操作、元件和/或组件,但是并不排除存在或添加一个或多个其他特征、整数、步骤、操作、元件、组件和/或它们的组。应该理解,当我们称元件被“连接”或“耦接”到另一元件时,它可以直接连接或耦接到其他元件,或者也可以存在中间元件。此外,这里使用的“连接”或“耦接”可以包括无线连接或耦接。这里使用的措辞“和/或”包括一个或更多个相关联的列出项的任一单元和全部组合。
本技术领域技术人员可以理解,除非另外定义,这里使用的所有术语(包括技术术语和科学术语)具有与本发明所属领域中的普通技术人员的一般理解相同的意义。还应该理解的是,诸如通用字典中定义的那些术语应该被理解为具有与现有技术的上下文中的意义一致的意义,并且除非像这里一样定义,不会用理想化或过于正式的含义来解释。
为便于对本发明实施例的理解,下面将结合附图以具体实施例为例做进一步的解释说明,且并不构成对本发明实施例的限定。
实施例
图1为本实施例的基于交叉注意力机制的多头弱监督目标检测方法流程示意图,参照图1,该方法包括:
S1获取图像,对图像进行处理,将处理后的图像分成训练集和测试集。数据集图像可从网站上抓取,也可用公开的目标检测数据集,如PASCAL VOC或者MS-COCO等。对获取的每个图像,标注图像中包含的类别标签Y={Y1,Y2,...,YC},Yc∈{0,1},采用选择性搜索方法(SS,Selective Search)或多尺度组合分组方法(MCG,Multiscale CombinatorialGrouping)获取数据集中每个图像中的候选区域。其中Y是每个图像中提取的候选区域(也叫候选框)的总数,每个候选区域用4维的坐标向量表示,指示该候选区域在图像中的位置。然后将标注和划分候选区域的图像数据集分成训练集和测试集。将数据集按照一定的比例,如8:2或者7:3,将数据集划分成训练集和测试集。公开的目标检测数据集已经划分好了训练集和测试集。
S2定义类原型特征,构建包含类原型特征的基于交叉注意力机制的弱监督目标检测网络(WCAN,Weakly supervised object detection with Cross-Attention Network)模型。
定义类原型特征
Figure BDA0003247948740000081
每个类原型特征Ei对应一个特定的语义类别,C是类别数量,D是类原型特征的维度,
Figure BDA0003247948740000082
表示实数。
图2为基于交叉注意力机制的多头弱监督检测网络模型结构示意图,参照图2,WCAN模型包括候选区域特征提取网络、类原型特征和多头弱监督目标检测网络;多头弱监督目标检测网络包括并列的类别感知的交叉注意力模块和定位感知的交叉注意力模块。需要说明的是,WCAN模型还包括多头弱监督位置定位网络,多头弱监督位置定位网络包括位置感知的交叉注意力模块。
由于本实施例构建的类别感知的交叉注意力模块、定位感知的交叉注意力模块和位置感知的交叉注意力模块都可以同时输出多头预测值。因此,构建的网络模型是一个基于交叉注意力机制的多头弱监督目标检测网络模型。
S3采用训练数据集对WCAN模型进行训练,图3为训练流程示意图。
对WCAN模型进行初始化:
对于候选区域特征提取网络中的主干网络,采用已有的在ImageNet数据集上预训练的网络模型参数进行初始化,其他网络模型参数均采用随机参数初始化,类原型特征E也采用随机参数进行初始化。
初始化后,从训练集中采样一张训练图像,同时提取该图像的类别标签Y={Y1,Y2,...,YC},Yc∈{0,1}和候选区域
Figure BDA0003247948740000091
将训练集中的图像输入至WCAN模型的候选区域特征提取网络,通过主干网络进行图像特征图提取,提取图像特征图
Figure BDA0003247948740000092
其中,H、W和Ch分别表示图像特征图的宽、高和特征维度。
将提取后的图像特征图F与对应的候选区域P同时输入至区域池化层(RP,ROIPooling)提取候选区域特征Fp。需要说明的是,候选区域特征图同时要在空间维度进行形状变换,再经过两个非线性映射层FC1,FC2获取候选区域特征
Figure BDA0003247948740000093
D1是候选区域特征的维度。每个非线性映射层由全连接层FC,非线性激活函数ReLU和Dropout层组成。
候选区域特征Fp和类原型特征E同时输入至多头弱监督目标检测网络中的类感知的交叉注意力模块和定位感知的交叉注意力模块,分别得到多头候选区域类别相似性得分Pi,c,i={1,...,H1}和多头候选区域目标相似性得分Pi,d,i={1,...,H1}。
多头弱监督目标检测网络中类感知的交叉注意力模块处理示意图如图5所示:采用线性映射层FCck对候选区域特征Fp进行映射,得到映射后的候选区域特征Fck。采用层归一化层,线性映射层FCcq对类原型特征进行映射,得到映射后的类原型特征Ecq。然后在特征维度对映射后的类原型特征和映射后的候选区域特征进行切分,例如,把D维的特征切分成H1份,得到H1个特征切片,每个切片的特征维度是D/H1,得到H1个类原型特征切片和H1个候选区域特征切片,具体地,将Fck在特征维度切割成H1份,得到切片候选区域特征
Figure BDA0003247948740000101
将Ecq在特征维度切割成H1份,得到类原型特征切片
Figure BDA0003247948740000102
接下来用dot-product操作计算第i个类原型特征切片和第i个候选区域特征切片的相似性得分,对每对候选区域特征切片
Figure BDA0003247948740000103
和类原型特征切片
Figure BDA0003247948740000104
进行向量内积,输出类原型特征和候选区域特征的相似性矩阵
Figure BDA0003247948740000105
⊙是dot-product操作。其他切片对也进行相同的操作,共输出H1个相似性矩阵
Figure BDA0003247948740000106
再在类别维度对相似性得分进行softmaxc归一化得到第i个候选区域类别相似性得分Pi,c,共可得到与切分个数相同个数的候选区域类别相似性得分Pi,c,i={1,...,H1},输出候选区域类别得分矩阵
Figure BDA0003247948740000107
表示候选区域Pn与类原型Em的类别相似性得分。
多头弱监督目标检测网络中定位感知的交叉注意力模块处理示意图如图6所示:采用线性映射层FCdk对输入的候选区域特征Fp进行映射,得到映射后的候选区域特征Fdk。采用层归一化层,线性映射层FCdq对类原型特征进行映射,得到映射后的类原型特征Edq。然后对映射后的类原型特征和映射后的候选区域特征进行切片,得到类原型特征切片
Figure BDA0003247948740000108
和候选区域特征切片
Figure BDA0003247948740000109
具体地,进行H1次切片,得到H1个类原型特征切片和H1个候选区域特征切片;接下来用dot-product操作计算第i个类原型特征切片和第i个候选区域特征切片的相似性得分,具体地,对每个候选区域特征切片
Figure BDA00032479487400001010
和类原型特征切片
Figure BDA00032479487400001011
进行向量内积,输出类原型特征和候选区域特征的相似性矩阵
Figure BDA00032479487400001012
⊙是dot-product操作,其他切片对也进行相同对操作,共输出H1个相似性矩阵
Figure BDA00032479487400001013
再在候选区域维度对相似性得分进行softmaxd归一化得到第i个候选区域目标相似性得分Pi,d,共可得到与切片个数相同个数的候选区域目标相似性得分Pi,d,i={1,...,H1},输出候选区域目标相似性矩阵
Figure BDA0003247948740000111
Figure BDA0003247948740000112
表示候选区域Pn与类原型Em的目标相似性得分。
具体地,多头弱监督目标检测网络中类感知的交叉注意力模块输出H1个候选区域类别相似性得分,定位感知的交叉注意力模块输出H1个候选区域目标性相似性得分,优选地,H1=4。
通过类感知的交叉注意模块和定位感知的交叉注意力模块从不同维度衡量类原型特征和候选区域特征的相似度,可以输出更准确的候选区域得分。
对同一区域的多头候选区域类别相似性得分和多头候选区域目标相似性得分融合得到候选区域得分。具体地,对第i个候选区域类别相似性得分和第i个候选区域目标相似性得分进行融合,用逐元素相乘操作(element-wise product)进行融合,得到第i个候选区域得分
Figure BDA0003247948740000113
是逐元素相乘操作。融合两个模块的输出得到多头候选区域得分Pi,i={1,...,H1}。
对多头候选区域得分聚合得到图像得分Si,采用求和池化层聚合候选区域得分得到图像得分。具体来说,对每个类别c,计算候选区域得分的累加和得到图像在类别c下的得分/预测值
Figure BDA0003247948740000114
因此,得到图像得分
Figure BDA0003247948740000115
Figure BDA0003247948740000116
对分别对多头候选区域得分进行聚合,可得到多头图像得分Si,i={1,...,H1}。
S34采用交叉熵损失函数计算对每个图像得分Si和图像标签之间的误差
Figure BDA0003247948740000117
并对多个误差值求和
Figure BDA0003247948740000118
使用随机梯度下降算法更新WCAN模型中主干网络、类原型特征和多头弱监督目标检测网络模型的参数;
S35重复步骤S32-S34,直至交叉熵损失函数的结果不再降低,得到训练好的WCAN模型。
上述步骤实现了基于交叉注意力机制的弱监督目标检测。构建的类感知的交叉注意力模块和定位感知的交叉注意力模块从不同维度约束了候选区域和类原型之间的相似度,正确性更高。提出的类原型特征又增加了两个模块之间的交互,其既可以学习类判别性知识又可以学习候选区域目标性知识。在综合性的类原型特征的帮助下,网络模型可以输出准确的候选区域得分,从而正确地选择目标候选区域。与此同时,本实施例构建的两个模块可以同时输出多头候选区域得分,是一个多头弱监督目标检测网络,可以防止网络陷入局部最小值。
进一步地,本实施例还提出多头弱监督位置定位网络,其主要由位置感知的交叉注意力模块组成。这个模块可以进一步优化类原型特征,用最相似的位置特征优化类原型特征,可以抑制非目标区域背景位置特征的干扰。具体内容为:
在S33和S34之间还可以包括:
将主干网络特征提取后的图像特征图
Figure BDA0003247948740000121
在空间维度进行形状变换得到位置特征,采用线性映射层FCalign2将位置特征维度对齐到类原型特征维度。
将位置特征Fpi和类原型特征E输入至多头弱监督位置定位网络中的位置感知的交叉注意力模块,得到多头位置得分PIi,i={1,...,H2}。
多头弱监督位置定位网络中的位置感知的交叉注意力模块处理示意图如图7所示,具体地,采用线性映射层FCpi_k对位置特征进行映射,得到映射后的位置特征Fpi_k。用层归一化层,线性映射层FCpi_q对类原型进行映射,得到映射后的类原型特征Epi_q。然后在特征维度对映射后的类原型特征和映射后的位置特征进行切分,例如,把D维的特征切分成H2份,得到H2个特征切片,每个切片的特征维度是D/H2,可得到H2个类原型特征切片
Figure BDA0003247948740000122
和H2个位置特征切片
Figure BDA0003247948740000123
Figure BDA0003247948740000124
接下来采用dot-product操作计算第i个类原型特征切片和第i个位置特征切片的相似性得分,具体地,对输出的
Figure BDA0003247948740000125
将第i个类原型特征切片
Figure BDA0003247948740000131
和第i个位置特征切片
Figure BDA0003247948740000132
进行向量内积,输出类原型特征和候选区域特征的相似性矩阵
Figure BDA0003247948740000133
⊙是dot-product操作。其他切片对也进行同样的操作,得到H2个相似性矩阵
Figure BDA0003247948740000134
再在位置维度对相似性得分进行softmaxl归一化得到第i个位置得分PIi,共可得到与切分个数相同个数的位置得分PIi,i={1,...,H2},输出位置得分矩阵
Figure BDA0003247948740000135
Figure BDA0003247948740000136
表示位置PIn与类原型Em的匹配得分。
分别对多头位置得分进行聚合,得到多头图像得分
Figure BDA0003247948740000137
用最大值池化层对位置得分进行聚合。具体来说,对每个类别c,选择位置得分的最大值得到图像在类别c下的预测值
Figure BDA0003247948740000138
因此,得到图像得分
Figure BDA0003247948740000139
Figure BDA00032479487400001310
分别对多头位置得分进行聚合,可得到多头图像得分
Figure BDA00032479487400001311
Figure BDA00032479487400001312
在位置维度对位置得分求最大值,得到图像得分。这相当于用最相似的位置特征来优化类原型特征,可以抑制非目标区域背景位置特征的干扰。
采用交叉熵损失函数计算对每个图像得分
Figure BDA00032479487400001313
和图像标签之间的误差
Figure BDA00032479487400001314
Figure BDA00032479487400001315
并对多个误差值求和,使用随机梯度下降算法更新WCAN模型中主干网络、类原型特征和多头弱监督目标检测网络模型的参数。此处需要说明的是,在步骤执行过程中,这里的交叉熵损失函数也要满足步骤S35中的条件。
S4基于训练好的WCAN模型,对测试集图像进行目标检测。图4为本实施例的测试流程示意图。
S41用主干网络提取测试集中的图像特征图。
S42用区域池化层提取图像的候选区域特征。
S43将候选区域特征和类原型特征输入至基于多头弱监督目标检测网络中的类感知的交叉注意力模块和定位感知的交叉注意力模块得到多头候选区域得分。
S44对所有多头候选区域得分求平均值得到最终的候选区域得分;
S45采用非极大值抑制(NMS,Non-Maximum Suppression)算法去除重叠的和低分的候选区域,得到最终的目标检测结果。通过PASCAL VOC 2007数据集验证本实施例方法,具体地,用PASCAL VOC 2007中的训练集和验证集训练模型,用PASCAL VOC 2007中的测试集测试模型。在PASCAL VOC 2007测试集上,本实施例的方法的mAP(mean AveragePrecision)指标是0.401。相比于WSDDN方法,取得了很大的提升。如图8所示,本实施例方法得到的检测结果示意图。
由于在测试阶段,所有网络模型参数已经收敛,可以省去基于多头弱监督位置定位网络(在训练阶段,其作用主要是辅助优化主干网络和类原型特征的网络参数,抑制非目标区域位置特征的干扰),以提高目标检测的效率。
本领域技术人员应能理解,图1仅为简明起见而示出的各类网络元素的数量可能小于一个实际网络中的数量,但这种省略无疑是以不会影响对发明实施例进行清楚、充分的公开为前提的。
通过以上的实施方式的描述可知,本领域的技术人员可以清楚地了解到本发明可借助软件加必需的通用硬件平台的方式来实现。基于这样的理解,本发明的技术方案本质上或者说对现有技术做出贡献的部分可以以软件产品的形式体现出来,该计算机软件产品可以存储在存储介质中,如ROM/RAM、磁碟、光盘等,包括若干指令用以使得一台计算机设备(可以是个人计算机,服务器,或者网络设备等)执行本发明各个实施例或者实施例的某些部分所述的方法。
以上所述,仅为本发明较佳的具体实施方式,但本发明的保护范围并不局限于此,任何熟悉本技术领域的技术人员在本发明揭露的技术范围内,可轻易想到的变化或替换,都应涵盖在本发明的保护范围之内。因此,本发明的保护范围应该以权利要求的保护范围为准。

Claims (10)

1.一种基于交叉注意力机制的多头弱监督目标检测方法,其特征在于,包括:
S1获取图像,对所述图像进行处理,将处理后的图像分成训练集和测试集;
S2定义类原型特征,构建包含类原型特征的基于交叉注意力机制的弱监督目标检测网络WCAN模型;
S3采用训练数据集对WCAN模型进行训练;
S4基于训练好的WCAN模型,对测试集图像进行目标检测。
2.根据权利要求1所述的方法,其特征在于,所述对图像进行处理,包括:对每个图像,标注图像中包含的类别标签,采用选择性搜索方法或多尺度组合分组方法获取数据集中每个图像中的候选区域。
3.根据权利要求1所述的方法,其特征在于,所述定义类原型特征,包括:定义类原型特征
Figure FDA0003247948730000011
每个类原型特征Ei对应一个特定的语义类别,C是类别数量,D是类原型特征的维度,
Figure FDA0003247948730000012
表示实数。
4.根据权利要求1所述的方法,其特征在于,所述WCAN模型包括候选区域特征提取网络、类原型特征和多头弱监督目标检测网络;
所述多头弱监督目标检测网络包括并列的类别感知的交叉注意力模块和定位感知的交叉注意力模块。
5.根据权利要求4所述的方法,其特征在于,所述WCAN模型还包括多头弱监督位置定位网络,多头弱监督位置定位网络包括位置感知的交叉注意力模块。
6.根据权利要求1所述的方法,其特征在于,所述采用训练数据集对WCAN模型进行训练,包括:
S31对WCAN模型进行初始化;
S32将训练集中的图像输入至WCAN模型的候选区域特征提取网络,通过主干网络进行特征图提取;
S33将提取后的图像特征图与对应的候选区域同时输入至区域池化层提取候选区域特征Fp;候选区域特征Fp和类原型特征E同时输入至多头弱监督目标检测网络中的类感知的交叉注意力模块和定位感知的交叉注意力模块,分别得到多头候选区域类别相似性得分和多头候选区域目标相似性得分,对同一区域的多头候选区域类别相似性得分和多头候选区域目标相似性得分融合得到候选区域得分,并对多头候选区域得分聚合得到图像得分Si
S34采用交叉熵损失函数计算对每个图像得分Si和图像标签之间的误差,并对多个误差值求和,使用随机梯度下降算法更新WCAN模型中主干网络、类原型特征和多头弱监督目标检测网络模型的参数;
S35重复步骤S32-S34,直至交叉熵损失函数的结果不再降低,得到训练好的WCAN模型。
7.根据权利要求6所述的方法,其特征在于,所述S33和S34之间还包括:
将主干网络特征提取后的图像特征图在空间维度进行形状变换得到位置特征,采用线性映射层将位置特征维度对齐到类原型特征维度;
将位置特征Fpi和类原型特征E输入至多头弱监督位置定位网络中的位置感知的交叉注意力模块,得到多头位置得分,分别对多头位置得分进行聚合,得到多头图像得分
Figure FDA0003247948730000021
采用交叉熵损失函数计算对每个图像得分
Figure FDA0003247948730000022
和图像标签之间的误差,并对多个误差值求和,使用随机梯度下降算法更新WCAN模型中主干网络、类原型特征和多头弱监督目标检测网络模型的参数。
8.根据权利要求6所述的方法,其特征在于,所述候选区域特征Fp和类原型特征E同时输入至多头弱监督目标检测网络中的类感知的交叉注意力模块和定位感知的交叉注意力模块,分别得到多头候选区域类别相似性得分和多头候选区域目标相似性得分,具体包括:
对于多头弱监督目标检测网络中类感知的交叉注意力模块:采用线性映射层FCck对候选区域特征进行映射,采用层归一化层,线性映射层FCcq对类原型特征进行映射,然后在特征维度对映射后的类原型特征和映射后的候选区域特征进行切分;接下来用dot-product操作计算第i个类原型特征切片和第i个候选区域特征切片的相似性得分,再在类别维度对相似性得分进行softmaxc归一化得到第i个候选区域类别相似性得分Pi,c,共可得到与切分个数相同个数的候选区域类别相似性得分;
对于多头弱监督目标检测网络中定位感知的交叉注意力模块:采用线性映射层FCdk对输入的候选区域特征进行映射,采用层归一化层,线性映射层FCdq对类原型特征进行映射,然后对映射后的类原型特征和映射后的候选区域特征进行切片,得到类原型特征切片和候选区域特征切片;接下来用dot-product操作计算第i个类原型特征切片和第i个候选区域特征切片的相似性得分,再在候选区域维度对相似性得分进行softmaxd归一化得到第i个候选区域目标相似性得分Pi,d,共可得到与切片个数相同个数的候选区域目标相似性得分。
9.根据权利要求7所述的方法,其特征在于,所述将位置特征Fpi和类原型特征E输入至多头弱监督位置定位网络中的位置感知的交叉注意力模块,得到多头位置得分,具体包括:
采用线性映射层FCpi_k对位置特征进行映射,用层归一化层,线性映射层FCpi_q对类原型进行映射;然后在特征维度对映射后的类原型特征和映射后的位置特征进行切分;接下来采用dot-product操作计算第i个类原型特征切片和第i个位置特征切片的相似性得分,再在位置维度对相似性得分进行softmaxl归一化得到第i个位置得分PIi,共可得到与切分个数相同个数的位置得分。
10.根据权利要求1所述的方法,其特征在于,所述S4包括:
S41用主干网络提取测试集中的图像特征图;
S42用区域池化层提取图像的候选区域特征;
S43将候选区域特征和类原型特征输入至基于多头弱监督目标检测网络中的类感知的交叉注意力模块和定位感知的交叉注意力模块得到多头候选区域得分;
S44对所有多头候选区域得分求平均值得到最终的候选区域得分;
S45采用NMS算法去除重叠的和低分的候选区域,得到最终的目标检测结果。
CN202111037829.5A 2021-09-06 2021-09-06 基于交叉注意力机制的多头弱监督目标检测方法 Pending CN113920302A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202111037829.5A CN113920302A (zh) 2021-09-06 2021-09-06 基于交叉注意力机制的多头弱监督目标检测方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202111037829.5A CN113920302A (zh) 2021-09-06 2021-09-06 基于交叉注意力机制的多头弱监督目标检测方法

Publications (1)

Publication Number Publication Date
CN113920302A true CN113920302A (zh) 2022-01-11

Family

ID=79234069

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202111037829.5A Pending CN113920302A (zh) 2021-09-06 2021-09-06 基于交叉注意力机制的多头弱监督目标检测方法

Country Status (1)

Country Link
CN (1) CN113920302A (zh)

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN114972737A (zh) * 2022-06-08 2022-08-30 湖南大学 基于原型对比学习的遥感图像目标检测系统及方法

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN114972737A (zh) * 2022-06-08 2022-08-30 湖南大学 基于原型对比学习的遥感图像目标检测系统及方法
CN114972737B (zh) * 2022-06-08 2024-03-15 湖南大学 基于原型对比学习的遥感图像目标检测系统及方法

Similar Documents

Publication Publication Date Title
Lin et al. Automated defect inspection of LED chip using deep convolutional neural network
CN110097568B (zh) 一种基于时空双分支网络的视频对象检测与分割方法
Wang et al. Shape2motion: Joint analysis of motion parts and attributes from 3d shapes
Li et al. Localizing and quantifying damage in social media images
CN110765921B (zh) 一种基于弱监督学习和视频时空特征的视频物体定位方法
CN109740676B (zh) 基于相似目标的物体检测迁移方法
CA3066029A1 (en) Image feature acquisition
WO2018005413A1 (en) Method and system for cell annotation with adaptive incremental learning
CN110728694B (zh) 一种基于持续学习的长时视觉目标跟踪方法
CN110287879B (zh) 一种基于注意力机制的视频行为识别方法
Gómez et al. Cutting Sayre's Knot: reading scene text without segmentation. application to utility meters
CN111368690A (zh) 基于深度学习的海浪影响下视频图像船只检测方法及系统
CN108038515A (zh) 无监督多目标检测跟踪方法及其存储装置与摄像装置
CN115131613B (zh) 一种基于多向知识迁移的小样本图像分类方法
CN105930792A (zh) 一种基于视频局部特征字典的人体动作分类方法
US20230095533A1 (en) Enriched and discriminative convolutional neural network features for pedestrian re-identification and trajectory modeling
Saqib et al. Intelligent dynamic gesture recognition using CNN empowered by edit distance
Akhlaghi et al. Farsi handwritten phone number recognition using deep learning
CN110705384A (zh) 一种基于跨域迁移增强表示的车辆再识别方法
CN113920302A (zh) 基于交叉注意力机制的多头弱监督目标检测方法
CN117829243A (zh) 模型训练方法、目标检测方法、装置、电子设备及介质
CN110287970B (zh) 一种基于cam与掩盖的弱监督物体定位方法
Tu et al. Toward automatic plant phenotyping: starting from leaf counting
CN113158878B (zh) 一种基于子空间的异构迁移故障诊断方法、系统和模型
CN110135306B (zh) 基于角度损失函数的行为识别方法

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination