CN113902926A - 一种基于自注意力机制的通用图像目标检测方法和装置 - Google Patents

一种基于自注意力机制的通用图像目标检测方法和装置 Download PDF

Info

Publication number
CN113902926A
CN113902926A CN202111477045.4A CN202111477045A CN113902926A CN 113902926 A CN113902926 A CN 113902926A CN 202111477045 A CN202111477045 A CN 202111477045A CN 113902926 A CN113902926 A CN 113902926A
Authority
CN
China
Prior art keywords
layer
dimension
attention
ith
image
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Granted
Application number
CN202111477045.4A
Other languages
English (en)
Other versions
CN113902926B (zh
Inventor
李特
王世杰
朱世强
顾建军
王兴刚
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Zhejiang Lab
Original Assignee
Zhejiang Lab
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Zhejiang Lab filed Critical Zhejiang Lab
Priority to CN202111477045.4A priority Critical patent/CN113902926B/zh
Publication of CN113902926A publication Critical patent/CN113902926A/zh
Application granted granted Critical
Publication of CN113902926B publication Critical patent/CN113902926B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • G06F18/241Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • G06F18/241Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
    • G06F18/2415Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches based on parametric or probabilistic models, e.g. based on likelihood ratio or false acceptance rate versus a false rejection rate
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/047Probabilistic or stochastic networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • Evolutionary Computation (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • Computing Systems (AREA)
  • Software Systems (AREA)
  • Molecular Biology (AREA)
  • Computational Linguistics (AREA)
  • Biophysics (AREA)
  • Biomedical Technology (AREA)
  • Mathematical Physics (AREA)
  • General Health & Medical Sciences (AREA)
  • Health & Medical Sciences (AREA)
  • Probability & Statistics with Applications (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Evolutionary Biology (AREA)
  • Image Analysis (AREA)

Abstract

本发明公开一种基于自注意力机制的通用图像目标检测方法,该方法是基于DETR模型的改进,其包括对将含边界框标注的训练集图像输入图像特征提取网络,获得图像特征;将图像特征依次通过多头十字交叉注意力模块和多方向交叉注意力模块,获得解码器输出增强目标查询向量;将增强目标查询向量分别通过模型的分类层和回归层得到目标图像物体边界框和物体类别概率;计算网络整体损失对模型进行训练,得到目标检测模型;利用上述模型对待检测图像进行目标检测。本发明相比于DETR模型,在保证目标检测准确的同时,加快模型训练速度,减小模型的计算复杂度,提高模型灵活性与实用性。

Description

一种基于自注意力机制的通用图像目标检测方法和装置
技术领域
本发明涉及计算机视觉技术领域,尤其涉及一种基于自注意力机制的通用图像目标检测方法和装置。
背景技术
在计算机视觉领域,目标检测是一项基本任务,其目的是将图像中待检测对象与背景区分开,并预测图像中待检测对象的位置和类别。现有的流行技术基于卷积神经网络。基于该技术目标检测算法可以分为两大类:一类是one-stage算法。其思路是直接产生待检测物体类别概率和坐标位置,不需要产生候选框。另一类算法是two-stage算法。其思路是将检测问题划分为两个部分:首先产生候选区域,然后基于候选区域进行分类和边框回归,得到待检测物体类别概率和坐标位置。
Transformer架构最初应用于自然语言处理领域。其关键的自注意力机制(Self-attention mechanism)使得Transformer架构在不同任务上取得了不错的效果。2020年,Facebook提出DETR模型首次将Transformer架构应用于目标检测领域。它针对基于卷积神经网络的目标检测模型需要不同人工设计步骤的问题。以更加直接的方法简化整个目标检测流程,实现真正地端到端解决目标检测问题。
然而,DETR模型也有其自身地问题:第一:模型收敛速度慢,训练时间长;第二:模型计算量大,这使得DETR模型现阶段难以用于实际应用。
发明内容
针对现有技术的不足,本发明基于Transformer架构,提供了一种基于自注意力机制的通用图像目标检测方法和装置。该方法基于Transformer架构得到图像目标检测模型,收敛速度更快,训练时间更短,计算量更少。
本发明的目的通过如下的技术方案来实现:
一种基于自注意力机制的通用图像目标检测方法,该方法包括如下步骤:
步骤一:将含边界框标注的训练集图像输入图像特征提取网络,获得维度为
Figure 340847DEST_PATH_IMAGE001
的图像特征;
步骤二:将所述图像特征,输入由L个十字交叉注意力层串联而成的多头十字交叉注意力模块,获得编码器输出特征图;
其中,所述十字交叉注意力层首先通过该层输入的图像特征获得该层多组编码器归一化后的注意力权重A和对应的编码器值向量V;然后通过A和V求出该层初步增强特征图
Figure 15542DEST_PATH_IMAGE002
;接着,将所述
Figure 991457DEST_PATH_IMAGE002
替换该层输入的图像特征,再次经过上述步骤,得到中间结果增强特征图;最后将中间结果增强特征图与输入的图像特征对应元素相加,经过该层编码器前馈网络和该层编码器层归一化后,得到该多头十字交叉注意力层输出的增强特征;
步骤三:将所述编码器输出特征图,通过由L个多方向交叉注意力层串联而成的多方向交叉注意力模块,获得解码器输出增强目标查询向量;
所述多方向交叉注意力层首先通过编码器输出特征图获得该层多组解码器归一化后的注意力权重
Figure 93405DEST_PATH_IMAGE003
和对应的解码器值向量
Figure 23446DEST_PATH_IMAGE004
,然后通过所述
Figure 767411DEST_PATH_IMAGE003
Figure 161615DEST_PATH_IMAGE004
求出该层融合后的增强目标查询特征
Figure 700043DEST_PATH_IMAGE005
;最后,改变所述
Figure 507593DEST_PATH_IMAGE005
维度为
Figure 55249DEST_PATH_IMAGE006
,依次经过该层解码器前馈网络和该层解码器层归一化后,得到该层多方向交叉注意力层输出的目标查询向量;
步骤四:将所述解码器输出增强目标查询向量分别通过分类层和回归层得到训练集图像的预测边界框类别概率和位置;
步骤五:将所述训练集图像的预测边界框类别概率和位置,与训练集图像的真实边界框类别和位置信息计算网络整体损失函数,通过反向传播方法对模型进行训练,得到目标检测模型;
步骤六:利用所述目标检测模型对待检测图像进行目标检测,以检测出所述待检测图像中待检测物体。
进一步地,每个多头十字交叉注意力层的操作具体如下:
S2.1:对于第一层多头十字交叉注意力层,将所述图像特征按照第1个维度等分为M组维度为
Figure 756489DEST_PATH_IMAGE007
子图像特征;对于第i层多头十字交叉注意力层,
Figure 715087DEST_PATH_IMAGE008
,将第i-1层多头十字交叉注意力层输出的维度为
Figure 777694DEST_PATH_IMAGE009
的增强特征
Figure 66724DEST_PATH_IMAGE010
按照第1个维度等分为M组子图像特征,其中第i层输入的第m组子图像特征
Figure 107624DEST_PATH_IMAGE011
的维度为
Figure 722276DEST_PATH_IMAGE012
Figure 612740DEST_PATH_IMAGE013
;将第L层多头十字交叉注意力层输出的增强特征
Figure 502199DEST_PATH_IMAGE014
作为编码器输出特征图;
S2.2:将所述
Figure 397605DEST_PATH_IMAGE011
分别经过第i层第m组
Figure 448738DEST_PATH_IMAGE015
编码器查询向量卷积
Figure 577231DEST_PATH_IMAGE016
和第i层第m组
Figure 254068DEST_PATH_IMAGE015
编码器匹配键值卷积
Figure 784407DEST_PATH_IMAGE017
,分别得到第i层第m组编码器查询向量
Figure 757173DEST_PATH_IMAGE018
、第i层第m组的编码器匹配键值
Figure 107383DEST_PATH_IMAGE019
;所述
Figure 322333DEST_PATH_IMAGE018
Figure 441598DEST_PATH_IMAGE019
维度为
Figure 100113DEST_PATH_IMAGE020
S2.3:采用下式计算第i层第m组第u个位置编码器未归一化注意力权重
Figure 676632DEST_PATH_IMAGE021
Figure 632956DEST_PATH_IMAGE022
其中,
Figure 295144DEST_PATH_IMAGE023
表示
Figure 858980DEST_PATH_IMAGE024
的第u个位置向量,其维度为
Figure 167471DEST_PATH_IMAGE025
Figure 740535DEST_PATH_IMAGE026
表示
Figure 319546DEST_PATH_IMAGE027
第u个位置同行同列向量,其维度为
Figure 319863DEST_PATH_IMAGE028
;u表示在分辨率维度上的一个位置,
Figure 850070DEST_PATH_IMAGE029
将所有位置
Figure 695666DEST_PATH_IMAGE030
拼接成为第i层第m组编码器未归一化注意力权重
Figure 660342DEST_PATH_IMAGE031
,其维度为
Figure 831561DEST_PATH_IMAGE032
Figure 130955DEST_PATH_IMAGE031
每一元素除以
Figure 763931DEST_PATH_IMAGE033
后,在第1个维度上进行softmax操作,得到第i层第m组编码器归一化后的注意力权重
Figure 363539DEST_PATH_IMAGE034
S2.4:将所述
Figure 190812DEST_PATH_IMAGE035
经过第i层第m组
Figure 711923DEST_PATH_IMAGE015
编码器值向量卷积
Figure 679748DEST_PATH_IMAGE036
,得到第i层第m组的编码器值向量
Figure 337125DEST_PATH_IMAGE037
,其维度为
Figure 600879DEST_PATH_IMAGE038
S2.5:根据下式计算第i层第m组第u个位置初步增强特征图
Figure 609286DEST_PATH_IMAGE039
Figure 865955DEST_PATH_IMAGE040
其中,
Figure 158265DEST_PATH_IMAGE041
表示
Figure 842187DEST_PATH_IMAGE042
第u个位置向量,其维度为
Figure 557465DEST_PATH_IMAGE043
Figure 617825DEST_PATH_IMAGE044
表示所述
Figure 967903DEST_PATH_IMAGE045
第u个位置同行同列向量,其维度为
Figure 822727DEST_PATH_IMAGE046
将所有位置
Figure 25300DEST_PATH_IMAGE039
拼接后经过第i层
Figure 623772DEST_PATH_IMAGE015
编码器融合卷积
Figure 890674DEST_PATH_IMAGE047
,从而得到第i层初步增强特征图
Figure 916399DEST_PATH_IMAGE048
,其维度为
Figure 386694DEST_PATH_IMAGE049
S2.6:将所述
Figure 551308DEST_PATH_IMAGE048
替换步骤S2.1中的
Figure 423449DEST_PATH_IMAGE050
,在所有卷积参数权值共享下,重复S2.1~S2.5后,将其输出的第i层第m组中间结果增强特征图的对应元素加上
Figure 869343DEST_PATH_IMAGE051
,最终获得第i层第m组再次增强特征图
Figure 826935DEST_PATH_IMAGE052
S2.7:将所述
Figure 32788DEST_PATH_IMAGE052
在第1个维度拼接,经过第i层编码器前馈网络和第i层编码器层归一化,得到第i层多头十字交叉注意力层输出的增强特征
Figure 510168DEST_PATH_IMAGE053
进一步地,所述步骤三中的每个多方向交叉注意力层进行如下操作:
S3.1:对于第一层多方向交叉注意力层,输入维度为
Figure 877695DEST_PATH_IMAGE054
的可学习的目标查询向量,并对所述目标查询向量进行标准正态分布的随机初始化;对于第i层多方向交叉注意力层,
Figure 322583DEST_PATH_IMAGE055
,将第i-1层多方向交叉注意力层输出的目标查询向量
Figure 581395DEST_PATH_IMAGE056
作为第i层多方向交叉注意力层输入的目标查询向量;将第L层多方向交叉注意力层输出的目标查询向量
Figure 162549DEST_PATH_IMAGE057
作为解码器输出增强目标查询向量;
S3.2:将所述
Figure 966557DEST_PATH_IMAGE056
输入到两层的多层感知机网络,生成维度为
Figure 118315DEST_PATH_IMAGE058
的第i层建议框;将所述第L层多头十字交叉注意力层输出的增强特征
Figure 931550DEST_PATH_IMAGE059
按照第1个维度等分为M组,第L层多头十字交叉注意力层输出的第m组的子图像特征
Figure 350899DEST_PATH_IMAGE060
的维度为
Figure 325808DEST_PATH_IMAGE061
S3.3:从N个第i层建议框中心出发,对所述
Figure 479709DEST_PATH_IMAGE062
均匀向外张开M个方向,在每个方向上使用双线性插值均匀采样K个点,得到维度为
Figure 316209DEST_PATH_IMAGE063
的第i层第m组采样视觉特征向量
Figure 871956DEST_PATH_IMAGE064
;通过改变维度的方式将所述
Figure 267034DEST_PATH_IMAGE056
变成维度为
Figure 908231DEST_PATH_IMAGE065
的第i层第m组目标查询特征
Figure 797689DEST_PATH_IMAGE066
S3.4:将所述
Figure 958674DEST_PATH_IMAGE066
经过第i层第m组
Figure 275386DEST_PATH_IMAGE015
解码器查询向量卷积
Figure 403879DEST_PATH_IMAGE067
,得到第i层第m组解码器查询向量
Figure 80717DEST_PATH_IMAGE068
,维度为
Figure 611056DEST_PATH_IMAGE065
;将所述
Figure 583822DEST_PATH_IMAGE064
经过第i层第m组
Figure 199611DEST_PATH_IMAGE015
解码器匹配键值卷积
Figure 165293DEST_PATH_IMAGE069
,得到第i层第m组解码器匹配键值
Figure 799406DEST_PATH_IMAGE070
,维度为
Figure 457920DEST_PATH_IMAGE063
S3.5:通过下式计算得到第i层第m组第j个解码器未归一化注意力权重
Figure 46158DEST_PATH_IMAGE071
Figure 815531DEST_PATH_IMAGE072
其中,
Figure 320462DEST_PATH_IMAGE073
为所述
Figure 399145DEST_PATH_IMAGE074
的第2个维度第j个矩阵,维度为
Figure 723947DEST_PATH_IMAGE075
Figure 297011DEST_PATH_IMAGE076
为所述
Figure 407181DEST_PATH_IMAGE070
的第2个维度第j个矩阵,维度为
Figure 141919DEST_PATH_IMAGE077
;其中,
Figure 203284DEST_PATH_IMAGE078
将所有维度
Figure 314460DEST_PATH_IMAGE079
在第2个维度进行拼接,成为第i层第m组解码器未归一化注意力权重
Figure 528403DEST_PATH_IMAGE080
,其维度为
Figure 173056DEST_PATH_IMAGE081
Figure 472451DEST_PATH_IMAGE082
每一元素除以
Figure 636585DEST_PATH_IMAGE033
后,在第1个维度上进行softmax操作,得到第i层第m组解码器归一化后的注意力权重
Figure 705035DEST_PATH_IMAGE083
S3.6:将所述
Figure 47154DEST_PATH_IMAGE084
经过第i层第m组
Figure 318998DEST_PATH_IMAGE015
解码器值向量卷积
Figure 771976DEST_PATH_IMAGE085
,得到第i层第m组的解码器值向量
Figure 944200DEST_PATH_IMAGE086
,其维度为
Figure 191642DEST_PATH_IMAGE087
S3.7:通过下式计算得到第i层第m组第j个增强目标查询特征
Figure 950782DEST_PATH_IMAGE088
Figure 207451DEST_PATH_IMAGE089
其中,
Figure 499761DEST_PATH_IMAGE090
表示所述
Figure 918104DEST_PATH_IMAGE091
第2个维度的第j个向量,其维度为
Figure 679386DEST_PATH_IMAGE092
Figure 959320DEST_PATH_IMAGE093
为所述
Figure 856869DEST_PATH_IMAGE086
第2个维度取出第j个向量,其维度为
Figure 226539DEST_PATH_IMAGE094
将所有
Figure 475118DEST_PATH_IMAGE088
在第2个维度拼接,成为第i层第m组增强目标查询特征
Figure 73590DEST_PATH_IMAGE095
,其维度为
Figure 576377DEST_PATH_IMAGE096
S3.8:将上述所有第i层每组增强目标查询特征在第3个维度拼接后,通过第i层
Figure 602102DEST_PATH_IMAGE015
解码器融合卷积
Figure 321665DEST_PATH_IMAGE097
,第i层得到融合后的增强目标查询特征
Figure 723828DEST_PATH_IMAGE098
,改变维度为
Figure 595969DEST_PATH_IMAGE099
,之后经过第i层解码器前馈网络和第i层解码器层归一化,得到第i层多方向交叉注意力层输出的目标查询向量
Figure 543327DEST_PATH_IMAGE100
,维度为
Figure 500919DEST_PATH_IMAGE101
进一步地,所述步骤四包括:将所述解码器输出增强目标查询向量分别输入到由两个不同全连接网络组成的分类层和回归层,输出训练集图像的预测边界框类别概率和位置。
进一步地,所述步骤五包括:将所述输出训练集图像的预测边界框类别概率和位置与真实边界框类别和位置信息通过匈牙利匹配算法获得最佳匹配,然后计算分类损失函数和位置回归损失函数之和作为网络整体损失;网络整体损失表达式如下所示:
Figure 424882DEST_PATH_IMAGE102
其中,
Figure 417108DEST_PATH_IMAGE103
表示分类损失函数,计算预测边界框类别概率和真实边界框类别的焦点损失;
Figure 784636DEST_PATH_IMAGE104
表示预测边界框位置和真实边界框位置的L1损失,
Figure 714677DEST_PATH_IMAGE105
表示预测边界框位置和真实边界框位置的广义的IoU损失,
Figure 193063DEST_PATH_IMAGE104
Figure 289064DEST_PATH_IMAGE105
之和表示位置回归损失函数;
Figure 827492DEST_PATH_IMAGE106
Figure 759676DEST_PATH_IMAGE107
Figure 792485DEST_PATH_IMAGE108
分别表示分类损失函数、L1损失和广义的IoU损失对应的权重系数;
最后使用反向传播方法对整个模型进行训练,当网络整体损失不再降低时,得到目标检测模型。
一种基于自注意力机制的通用图像目标检测装置,包括一个或多个处理器,用于实现上述的基于自注意力机制的通用图像目标检测方法。
一种计算机可读存储介质,其上存储有程序,该程序被处理器执行时,实现上述的基于自注意力机制的通用图像目标检测方法。
本发明具有如下的有益效果:
本发明提供的基于自注意力机制的通用图像目标检测方法,其中多头十字交叉注意力模块和多方向交叉注意力模块使得图像特征中空间信息更好的保留下来。特别地,多方向交叉注意力模块更好地聚集了第L层多头十字交叉注意力层输出的增强特征
Figure 228146DEST_PATH_IMAGE059
中局部信息,更加有利于检测框的定位。由于上述所设计的适合目标检测任务的操作,模型在一定程度上降低计算量;在保持精度一定的情况下,加快模型训练速度和收敛速度。
附图说明
图1为本发明提供的基于自注意力机制的通用图像目标检测方法的流程示意图;
图2为本发明提供的基于自注意力机制的通用图像目标检测方法的网络架构图;
图3为多方向交叉注意力模块采样示例图(图示张开16个方向,每个方向采样3个点)。
图4为本发明提供的基于自注意力机制的通用图像目标检测装置的结构框图。
具体实施方式
下面根据附图和优选实施例详细描述本发明,本发明的目的和效果将变得更加明白,应当理解,此处所描述的具体实施例仅仅用以解释本发明,并不用于限定本发明。
首先就本发明的技术术语进行解释说明:
ResNet50,ResNet50-DC5:ResNet是残差神经网络,由2015年提出并获得同年ILSVRC冠军。其主要贡献是通过快捷连接(Shortcut connection),消除了深度过大网络训练困难的问题;ResNet50表示50层的ResNet网络。ResNet50-DC5中DC5表示最后一个stage采用空洞率和stride相同的设置。目的是在不进行下采样基础上扩大感受野,输出特征图保持不变。
FPN:FPN是特征金字塔网络,于2017年提出,目的是采用特征金字塔做目标检测,通过自下而上网络提取不同网络层特征图,然后经过自上而下的网络融合特征图,最后在每一不同分辨率特征图上进行目标检测任务。这种特征融合和分而治之的思路被证明能够有效提高目标检测任务效果。
DETR:DEtection TRansformer模型的简称,于2020年提出。该模型将网络简化为图像特征提取模块,编码器和解码器模块,其中编码器模块使用多头自注意力机制获得编码后的序列特征;解码器模块通过输入可学习的目标查询向量和编码后的序列特征,使用多头自注意力机制获得增强的目标查询向量,最后通过两个不同的前馈网络获得待检测对象类别概率和边界框位置,从而实现端到端的目标检测流程。
请参阅图1,本发明实施例提供的基于自注意力机制的通用图像目标检测方法包括以下步骤:
S100、将含边界框标注的训练集图像输入图像特征提取网络,获得维度为
Figure 186744DEST_PATH_IMAGE001
的图像特征;
在本实例中,所述输入图像特征提取网络可以使用不同种类的卷积神经网络来提取图像特征。
S200、将所述图像特征,通过由L个十字交叉注意力层串联而成的多头十字交叉注意力模块,获得编码器输出特征图;
在本发明中,考虑到编码器设计应该关注全局信息的同时,针对DETR模型将二维图像特征转化为一维序列进行预测的设计会损失图像空间信息这一问题。本发明在该部分使用十字交叉注意力来进行改进。进一步地,考虑模型训练和其实际计算效率,本发明设计了多头十字交叉注意力模块来优化该部分。
所述多头十字交叉注意力模块请参阅图2,该模块由L个多头十字交叉注意力层组成,整体模块输入S100的图像特征,输出第L层多头十字交叉注意力层输出的增强特征,即编码器输出特征图。
所述十字交叉注意力层的操作用公式表达如下:
首先,通过下述公式的计算得到该层第m组分辨率维度上第u个空间位置编码器归一化后的注意力权重
Figure 340644DEST_PATH_IMAGE109
Figure 691991DEST_PATH_IMAGE110
上述公式中,
Figure 10189DEST_PATH_IMAGE111
是该层第m组
Figure 155999DEST_PATH_IMAGE015
编码器查询向量卷积,
Figure 797196DEST_PATH_IMAGE112
是该层第u个位置输入图像特征,
Figure 670343DEST_PATH_IMAGE113
是该层第u个位置同行同列输入图像特征,
Figure 80596DEST_PATH_IMAGE114
是该层第m组
Figure 148040DEST_PATH_IMAGE015
编码器匹配键值卷积,C为输出输入图像特征通道数,M是该层分组的总组数。
然后采用下列公式得到该层初步增强特征图
Figure 276533DEST_PATH_IMAGE115
Figure 704103DEST_PATH_IMAGE116
上述公式中,
Figure 218130DEST_PATH_IMAGE117
是该层第m组
Figure 705743DEST_PATH_IMAGE015
编码器融合卷积,
Figure 72265DEST_PATH_IMAGE118
是该层
Figure 303526DEST_PATH_IMAGE015
编码器值向量卷积。
接着,将所述该层初步增强特征图
Figure 688371DEST_PATH_IMAGE115
再重复上述步骤,得到中间结果增强特征图。将中间结果增强特征图与输入图像特征对应元素相加,得到该层再次增强特征图
Figure 330574DEST_PATH_IMAGE119
最后,将所述
Figure 433659DEST_PATH_IMAGE119
依次经过该层编码器前馈网络(feed-forward network)和该层编码器层归一化(Layer Normalization),得到该层多头十字交叉注意力层输出的增强特征。
具体地,每一个多头十字交叉注意力层具体步骤如下所述:
(1)对于第一层多头十字交叉注意力层,将所述图像特征按照第1个维度等分为M组维度为
Figure 203032DEST_PATH_IMAGE007
子图像特征;对于第i层多头十字交叉注意力层,
Figure 943848DEST_PATH_IMAGE008
,将第i-1层多头十字交叉注意力层输出的维度为
Figure 507685DEST_PATH_IMAGE009
的增强特征
Figure 347334DEST_PATH_IMAGE010
按照第1个维度等分为M组子图像特征,其中第i层输入第m组子图像特征
Figure 920398DEST_PATH_IMAGE011
的维度为
Figure 14255DEST_PATH_IMAGE012
Figure 765305DEST_PATH_IMAGE013
;将第L层多头十字交叉注意力层输出的增强特征
Figure 311824DEST_PATH_IMAGE014
作为编码器输出特征图。
(2)将所述
Figure 203425DEST_PATH_IMAGE011
分别经过第i层第m组
Figure 417369DEST_PATH_IMAGE015
编码器查询向量卷积
Figure 588587DEST_PATH_IMAGE016
和第i层第m组
Figure 638714DEST_PATH_IMAGE015
编码器匹配键值卷积
Figure 553581DEST_PATH_IMAGE017
,分别得到第i层第m组编码器查询向量
Figure 622031DEST_PATH_IMAGE018
、第i层第m组的编码器匹配键值
Figure 478997DEST_PATH_IMAGE019
;所述
Figure 108DEST_PATH_IMAGE018
Figure 718666DEST_PATH_IMAGE019
维度为
Figure 657934DEST_PATH_IMAGE020
(3)采用下式计算第i层第m组第u个位置编码器未归一化注意力权重
Figure 170955DEST_PATH_IMAGE021
Figure 694209DEST_PATH_IMAGE022
其中,
Figure 950878DEST_PATH_IMAGE023
表示
Figure 993920DEST_PATH_IMAGE024
的第u个位置向量,其维度为
Figure 416856DEST_PATH_IMAGE025
Figure 912560DEST_PATH_IMAGE026
表示
Figure 238499DEST_PATH_IMAGE027
第u个位置同行同列向量,其维度为
Figure 385315DEST_PATH_IMAGE028
;u表示在分辨率维度上的一个位置,
Figure 974559DEST_PATH_IMAGE029
将所有位置
Figure 973871DEST_PATH_IMAGE030
拼接成为第i层第m组编码器未归一化注意力权重
Figure 572342DEST_PATH_IMAGE031
,其维度为
Figure 324398DEST_PATH_IMAGE032
Figure 130548DEST_PATH_IMAGE031
每一元素除以
Figure 335265DEST_PATH_IMAGE033
后,在第1个维度上进行softmax操作,得到第i层第m组编码器归一化后的注意力权重
Figure 3006DEST_PATH_IMAGE034
(4)将所述
Figure 94721DEST_PATH_IMAGE035
经过第i层第m组
Figure 822506DEST_PATH_IMAGE015
编码器值向量卷积
Figure 780098DEST_PATH_IMAGE036
,得到第i层第m组的编码器值向量
Figure 969640DEST_PATH_IMAGE037
,其维度为
Figure 696287DEST_PATH_IMAGE038
(5)根据下式计算i层第m组第u个位置初步增强特征图
Figure 80126DEST_PATH_IMAGE039
Figure 259435DEST_PATH_IMAGE040
其中,
Figure 3400DEST_PATH_IMAGE041
表示
Figure 99401DEST_PATH_IMAGE042
第u个位置向量,其维度为
Figure 637829DEST_PATH_IMAGE120
Figure 570013DEST_PATH_IMAGE044
表示所述
Figure 602823DEST_PATH_IMAGE045
第u个位置同行同列向量,其维度为
Figure 569642DEST_PATH_IMAGE046
将所有位置
Figure 278972DEST_PATH_IMAGE039
拼接后经过第i层
Figure 947719DEST_PATH_IMAGE015
编码器融合卷积
Figure 767908DEST_PATH_IMAGE047
,从而得到第i层初步增强特征图
Figure 339966DEST_PATH_IMAGE048
,其维度为
Figure 485776DEST_PATH_IMAGE049
(6)将所述
Figure 126973DEST_PATH_IMAGE048
替换步骤(1)中的
Figure 120DEST_PATH_IMAGE050
,在所有卷积参数权值共享下,重复(1)~(5)后,将其输出的第i层第m组中间结果增强特征图的对应元素加上
Figure 675952DEST_PATH_IMAGE051
,最终获得第i层第m组再次增强特征图
Figure 727085DEST_PATH_IMAGE052
(7)将所述
Figure 606310DEST_PATH_IMAGE052
在第1个维度拼接,经过第i层编码器前馈网络和第i层编码器层归一化,得到第i层多头十字交叉注意力层输出的增强特征
Figure 33880DEST_PATH_IMAGE053
S300、将所述编码器输出特征图,通过由L个多方向交叉注意力层串联而成的多方向交叉注意力模块,获得解码器输出增强目标查询向量;
本发明该部分是针对DETR在解码器中将全局信息作为序列进行处理的操作进行优化。动机在于考虑到图像空间信息和上下文语义信息对于目标检测任务的重要性。从而本发明解码器的设计更关注于局部的上下文语义信息而不是全局信息。从而本发明提出了多方向交叉注意力模块。
所述多方向交叉注意力模块请参阅图2和3,该模块包括L个多方向交叉注意力层组成。整体模块输入编码器输出特征图,输出第L层多方向交叉注意力层输出的目标查询向量,即解码器输出增强目标查询向量。
所述多方向交叉注意力层的执行过程通过公式表达如下:
首先,由下述公式计算得到该层第m组第2维度第j个解码器归一化后的注意力权重
Figure 547907DEST_PATH_IMAGE121
Figure 301099DEST_PATH_IMAGE122
上述公式中,
Figure 916889DEST_PATH_IMAGE123
是该层第m组
Figure 633303DEST_PATH_IMAGE015
解码器查询向量卷积,
Figure 18148DEST_PATH_IMAGE124
是该层多方向交叉注意力层输入第2维度第j个的目标查询向量;
Figure 676662DEST_PATH_IMAGE125
是该层第m组
Figure 29015DEST_PATH_IMAGE015
解码器匹配键值卷积,
Figure 798388DEST_PATH_IMAGE126
是该层采样视觉特征向量,即编码器输出特征图;
然后采用下列公式得到该层融合后的增强目标查询特征
Figure 800191DEST_PATH_IMAGE127
Figure 629606DEST_PATH_IMAGE128
上述公式中,
Figure 219988DEST_PATH_IMAGE129
是该层第m组
Figure 776740DEST_PATH_IMAGE015
解码器值向量卷积,
Figure 401756DEST_PATH_IMAGE130
是该层第m组
Figure 136494DEST_PATH_IMAGE015
解码器融合卷积。
最后,改变所述
Figure 699325DEST_PATH_IMAGE131
维度为
Figure 341659DEST_PATH_IMAGE132
,之后依次经过该层解码器前馈网络和该层解码器层归一化,得到该层多方向交叉注意力层输出的目标查询向量。
进一步地,每个多方向交叉注意力层具体步骤如下所述:
(1)对于第一层多方向交叉注意力层,输入维度为
Figure 555602DEST_PATH_IMAGE133
的可学习的目标查询向量,并对所述目标查询向量进行标准正态分布的随机初始化;对于第i层多方向交叉注意力层,
Figure 976088DEST_PATH_IMAGE055
,将第i-1层多方向交叉注意力层输出的目标查询向量
Figure 275482DEST_PATH_IMAGE056
作为第i层多方向交叉注意力层输入的目标查询向量;将第L层多方向交叉注意力层输出的目标查询向量
Figure 675502DEST_PATH_IMAGE057
作为解码器输出增强目标查询向量。
(2)将所述
Figure 9531DEST_PATH_IMAGE056
输入到两层的多层感知机网络,生成维度为
Figure 351651DEST_PATH_IMAGE058
的第i层建议框;将所述第L层多头十字交叉注意力层输出的增强特征
Figure 387609DEST_PATH_IMAGE059
按照第1个维度等分为M组,第L层多头十字交叉注意力层输出的第m组的子图像特征
Figure 575008DEST_PATH_IMAGE134
的维度为
Figure 514276DEST_PATH_IMAGE061
(3)从N个第i层建议框中心出发,对所述
Figure 27297DEST_PATH_IMAGE062
均匀向外张开M个方向,在每个方向上使用双线性插值均匀采样K个点,得到维度为
Figure 301284DEST_PATH_IMAGE063
的第i层第m组采样视觉特征向量
Figure 807220DEST_PATH_IMAGE064
;通过改变维度的方式将所述
Figure 850263DEST_PATH_IMAGE056
变成维度为
Figure 534185DEST_PATH_IMAGE065
的第i层第m组目标查询特征
Figure 780621DEST_PATH_IMAGE066
(4)将所述
Figure 840981DEST_PATH_IMAGE066
经过第i层第m组
Figure 987797DEST_PATH_IMAGE015
解码器查询向量卷积
Figure 842620DEST_PATH_IMAGE067
,得到第i层第m组解码器查询向量
Figure 310773DEST_PATH_IMAGE074
,维度为
Figure 909245DEST_PATH_IMAGE065
;将所述
Figure 661300DEST_PATH_IMAGE064
经过第i层第m组
Figure 201872DEST_PATH_IMAGE015
解码器匹配键值卷积
Figure 672167DEST_PATH_IMAGE069
,得到第i层第m组解码器匹配键值
Figure 808751DEST_PATH_IMAGE070
,维度为
Figure 431624DEST_PATH_IMAGE063
(5)通过下式计算得到第i层第m组第j个解码器未归一化注意力权重
Figure 893829DEST_PATH_IMAGE071
Figure 835110DEST_PATH_IMAGE135
其中,
Figure 40963DEST_PATH_IMAGE073
为所述
Figure 502031DEST_PATH_IMAGE074
的第2个维度第j个矩阵,维度为
Figure 885870DEST_PATH_IMAGE075
Figure 330758DEST_PATH_IMAGE076
为所述
Figure 74723DEST_PATH_IMAGE070
的第2个维度第j个矩阵,维度为
Figure 905145DEST_PATH_IMAGE077
;其中,
Figure 709153DEST_PATH_IMAGE078
将所有维度
Figure 849192DEST_PATH_IMAGE071
在第2个维度进行拼接,成为第i层第m组解码器未归一化注意力权重
Figure 396848DEST_PATH_IMAGE082
,其维度为
Figure 98088DEST_PATH_IMAGE081
Figure 56685DEST_PATH_IMAGE082
每一元素除以
Figure 476165DEST_PATH_IMAGE033
后,在第1个维度上进行softmax操作,得到第i层第m组解码器归一化后的注意力权重
Figure 312665DEST_PATH_IMAGE083
(6)将所述
Figure 868412DEST_PATH_IMAGE084
经过第i层第m组
Figure 748643DEST_PATH_IMAGE015
解码器值向量卷积
Figure 639108DEST_PATH_IMAGE085
,得到第i层第m组的解码器值向量
Figure 528566DEST_PATH_IMAGE086
,其维度为
Figure 938819DEST_PATH_IMAGE087
(7)通过下式计算得到第i层第m组第j个增强目标查询特征
Figure 6263DEST_PATH_IMAGE088
Figure 134756DEST_PATH_IMAGE089
其中,
Figure 811594DEST_PATH_IMAGE090
表示所述
Figure 607512DEST_PATH_IMAGE091
第2个维度的第j个向量,其维度为
Figure 829546DEST_PATH_IMAGE092
Figure 930488DEST_PATH_IMAGE093
为所述
Figure 161749DEST_PATH_IMAGE086
第2个维度取出第j个向量,其维度为
Figure 546594DEST_PATH_IMAGE136
将所有
Figure 454376DEST_PATH_IMAGE088
在第2个维度拼接,成为第i层第m组增强目标查询特征
Figure 557461DEST_PATH_IMAGE095
,其维度为
Figure 592413DEST_PATH_IMAGE096
(8)将上述所有第i层每组增强目标查询特征在第3个维度拼接后,通过第i层
Figure 316918DEST_PATH_IMAGE015
解码器融合卷积
Figure 146334DEST_PATH_IMAGE097
,第i层得到融合后的增强目标查询特征
Figure 720403DEST_PATH_IMAGE098
,改变维度为
Figure 293467DEST_PATH_IMAGE099
,之后经过第i层解码器前馈网络和第i层解码器层归一化,得到第i层多方向交叉注意力层输出的目标查询向量
Figure 918484DEST_PATH_IMAGE100
,维度为
Figure 669533DEST_PATH_IMAGE101
S400、将所述解码器输出增强目标查询向量分别通过分类层和回归层得到训练集图像的预测边界框类别概率和位置。
在本实例中,请参阅图2,将解码器输出增强目标查询向量分别输出到由两个不同全连接网络组成的分类层和回归层输出训练集图像的预测边界框类别概率和位置。其中,解码器输出增强目标查询向量即所述第L层多方向交叉注意力层输出的目标查询向量
Figure 216052DEST_PATH_IMAGE137
S500、将所述训练集图像的预测边界框类别概率和位置和训练集图像的真实边界框类别和位置信息计算网络整体损失,通过反向传播方法对模型进行训练,得到目标检测模型;
在本实例中,将所述输出训练集图像的预测边界框类别概率和位置与真实边界框类别和位置信息通过匈牙利匹配算法获得最佳匹配,然后计算分类损失函数和位置回归损失函数之和作为网络整体损失。网络整体损失表达式如下所示:
Figure 576495DEST_PATH_IMAGE102
其中,
Figure 790439DEST_PATH_IMAGE103
表示分类损失函数,计算预测边界框类别概率和真实边界框类别的焦点损失。
Figure 961657DEST_PATH_IMAGE104
表示预测边界框位置和真实边界框位置的L1损失,
Figure 11784DEST_PATH_IMAGE105
表示预测边界框位置和真实边界框位置的广义的IoU损失,
Figure 926650DEST_PATH_IMAGE104
Figure 526259DEST_PATH_IMAGE105
之和表示位置回归损失函数。
Figure 117646DEST_PATH_IMAGE106
Figure 373178DEST_PATH_IMAGE107
Figure 357315DEST_PATH_IMAGE108
分别表示分类损失函数、L1损失和广义的IoU损失对应的权重系数。
最后使用反向传播方法对整个模型进行训练,当网络整体损失不再降低时,得到目标检测模型。
S600、利用所述目标检测模型对待检测图像进行目标检测,以检测出所述待检测图像中待检测物体。
与前述基于自注意力机制的通用图像目标检测方法的实施例相对应,本发明还提供了基于自注意力机制的通用图像目标检测装置的实施例。
参见图4,本发明实施例提供的一种基于自注意力机制的通用图像目标检测装置,包括一个或多个处理器,用于实现上述实施例中的基于自注意力机制的通用图像目标检测方法。
本发明基于自注意力制的通用图像目标检测装置的实施例可以应用在任意具备数据处理能力的设备上,该任意具备数据处理能力的设备可以为诸如计算机等设备或置。装置实施例可以通过软件实现,也可以通过硬件或者软硬件结合的方式实现。
Figure 296583DEST_PATH_IMAGE138
以软件实现为例,作为一个逻辑意义上的装置,是通过其所在任意具备数据处理能力的设备的处理器将非易失性存储器中对应的计算机程序指令读取到内存中运行形成的。从硬件层面而言,如图4所示,为本发明基于自注意力机制的通用图像目标检测装置所在任意具备数据处理能力的设备的一种硬件结构图,除了图4所示的处理器、内存、网络接口、以及非易失性存储器之外,实施例中装置所在的任意具备数据处理能力的设备通常根据该任意具备数据处理能力的设备的实际功能,还可以包括其他硬件,对此不再赘述。
上述装置中各个单元的功能和作用的实现过程具体详见上述方法中对应步骤的实现过程,在此不再赘述。
对于装置实施例而言,由于其基本对应于方法实施例,所以相关之处参见方法实施例的部分说明即可。以上所描述的装置实施例仅仅是示意性的,其中所述作为分离部件说明的单元可以是或者也可以不是物理上分开的,作为单元显示的部件可以是或者也可以不是物理单元,即可以位于一个地方,或者也可以分布到多个网络单元上。可以根据实际的需要选择其中的部分或者全部模块来实现本发明方案的目的。本领域普通技术人员在不付出创造性劳动的情况下,即可以理解并实施。
本发明实施例还提供一种计算机可读存储介质,其上存储有程序,该程序被处理器执行时,实现上述实施例中的基于自注意力机制的通用图像目标检测方法。
所述计算机可读存储介质可以是前述任一实施例所述的任意具备数据处理能力的设备的内部存储单元,例如硬盘或内存。所述计算机可读存储介质也可以是外部存储设备,例如所述设备上配备的插接式硬盘、智能存储卡(SmartMedia card, SMC)、SD卡、闪存卡(Flash card)等。进一步的,所述计算机可读存储介质还可以既包括任意具备数据处理能力的设备的内部存储单元也包括外部存储设备。所述计算机可读存储介质用于存储所述计算仉程序以及所述任意具备数据处理能力的设备所需的其他程序和数据,还可以用于暂时地存储己经输出或者将要输出的数据。
为了验证本发明有效性,将本发明在COCO 2017训练集上进行训练,使用
Figure 809604DEST_PATH_IMAGE139
层多头十字交叉注意力层和多方向交叉注意力层。在所述多头十字交叉注意力模块和多方向交叉注意力模块中取
Figure 67279DEST_PATH_IMAGE140
。网络整体损失表达式中
Figure 323948DEST_PATH_IMAGE141
。为了更好对比,同时实现了已有的同规模DETR作为参考。将以上模型训练好在COCO 2017验证集上进行测试,结果如下表所示。
实验结果表明,本发明(CCTR)在保证检测准确率的前提下,模型训练迭代次数由原来的500轮减少到50轮,模型参数量与计算量有一定程度减少。对比表中CCTR-ResNet-5-FPN和DETR-DC5的结果,可以发现本发明在获取更好的物体检测精度(AP)的情况下,模型计算量(FLOPs)更小,且训练迭代次数减少了90%。
表1 本发明方案模型与DETR模型在COCO 2017验证集结果对比表
Figure 366990DEST_PATH_IMAGE142
本领域普通技术人员可以理解,以上所述仅为发明的优选实例而已,并不用于限制发明,尽管参照前述实例对发明进行了详细的说明,对于本领域的技术人员来说,其依然可以对前述各实例记载的技术方案进行修改,或者对其中部分技术特征进行等同替换。凡在发明的精神和原则之内,所做的修改、等同替换等均应包含在发明的保护范围之内。

Claims (7)

1.一种基于自注意力机制的通用图像目标检测方法,其特征在于,该方法包括如下步骤:
步骤一:将含边界框标注的训练集图像输入图像特征提取网络,获得维度为
Figure 445950DEST_PATH_IMAGE001
的图像特征;
步骤二:将所述图像特征,输入由L个十字交叉注意力层串联而成的多头十字交叉注意力模块,获得编码器输出特征图;
其中,所述十字交叉注意力层首先通过该层输入的图像特征获得该层多组编码器归一化后的注意力权重A和对应的编码器值向量V;然后通过A和V求出该层初步增强特征图
Figure 300773DEST_PATH_IMAGE002
;接着,将所述
Figure 1882DEST_PATH_IMAGE002
替换该层输入的图像特征,再次经过上述步骤,得到中间结果增强特征图;最后将中间结果增强特征图与输入的图像特征对应元素相加,经过该层编码器前馈网络和该层编码器层归一化后,得到该多头十字交叉注意力层输出的增强特征;
步骤三:将所述编码器输出特征图,通过由L个多方向交叉注意力层串联而成的多方向交叉注意力模块,获得解码器输出增强目标查询向量;
所述多方向交叉注意力层首先通过编码器输出特征图获得该层多组解码器归一化后的注意力权重
Figure 131512DEST_PATH_IMAGE003
和对应的解码器值向量
Figure 617988DEST_PATH_IMAGE004
,然后通过所述
Figure 128866DEST_PATH_IMAGE003
Figure 130320DEST_PATH_IMAGE004
求出该层融合后的增强目标查询特征
Figure 250592DEST_PATH_IMAGE005
;最后,改变所述
Figure 857154DEST_PATH_IMAGE005
维度为
Figure 584938DEST_PATH_IMAGE006
,依次经过该层解码器前馈网络和该层解码器层归一化后,得到该层多方向交叉注意力层输出的目标查询向量;
步骤四:将所述解码器输出增强目标查询向量分别通过分类层和回归层得到训练集图像的预测边界框类别概率和位置;
步骤五:将所述训练集图像的预测边界框类别概率和位置,与训练集图像的真实边界框类别和位置信息计算网络整体损失函数,通过反向传播方法对模型进行训练,得到目标检测模型;
步骤六:利用所述目标检测模型对待检测图像进行目标检测,以检测出所述待检测图像中待检测物体。
2.根据权利要求1所述的基于自注意力机制的通用图像目标检测方法,其特征在于,每个多头十字交叉注意力层的操作具体如下:
S2.1:对于第一层多头十字交叉注意力层,将所述图像特征按照第1个维度等分为M组维度为
Figure 762104DEST_PATH_IMAGE007
子图像特征;对于第i层多头十字交叉注意力层,
Figure 233536DEST_PATH_IMAGE008
,将第i-1层多头十字交叉注意力层输出的维度为
Figure 678293DEST_PATH_IMAGE009
的增强特征
Figure 842558DEST_PATH_IMAGE010
按照第1个维度等分为M组子图像特征,其中第i层输入的第m组子图像特征
Figure 756288DEST_PATH_IMAGE011
的维度为
Figure 708108DEST_PATH_IMAGE012
Figure 86000DEST_PATH_IMAGE013
;将第L层多头十字交叉注意力层输出的增强特征
Figure 342538DEST_PATH_IMAGE014
作为编码器输出特征图;
S2.2:将所述
Figure 805880DEST_PATH_IMAGE011
分别经过第i层第m组
Figure 556798DEST_PATH_IMAGE015
编码器查询向量卷积
Figure 539929DEST_PATH_IMAGE016
和第i层第m组
Figure 983680DEST_PATH_IMAGE015
编码器匹配键值卷积
Figure 668739DEST_PATH_IMAGE017
,分别得到第i层第m组编码器查询向量
Figure 738195DEST_PATH_IMAGE018
、第i层第m组的编码器匹配键值
Figure 762783DEST_PATH_IMAGE019
;所述
Figure 190484DEST_PATH_IMAGE018
Figure 300523DEST_PATH_IMAGE019
维度为
Figure 455560DEST_PATH_IMAGE020
S2.3:采用下式计算第i层第m组第u个位置编码器未归一化注意力权重
Figure 849502DEST_PATH_IMAGE021
Figure 697372DEST_PATH_IMAGE022
其中,
Figure 311018DEST_PATH_IMAGE023
表示
Figure 207430DEST_PATH_IMAGE024
的第u个位置向量,其维度为
Figure 3348DEST_PATH_IMAGE025
Figure 209070DEST_PATH_IMAGE026
表示
Figure 356017DEST_PATH_IMAGE027
第u个位置同行同列向量,其维度为
Figure 806853DEST_PATH_IMAGE028
;u表示在分辨率维度上的一个位置,
Figure 722856DEST_PATH_IMAGE029
将所有位置
Figure 850212DEST_PATH_IMAGE030
拼接成为第i层第m组编码器未归一化注意力权重
Figure 671406DEST_PATH_IMAGE031
,其维度为
Figure 971938DEST_PATH_IMAGE032
Figure 696442DEST_PATH_IMAGE031
每一元素除以
Figure 791437DEST_PATH_IMAGE033
后,在第1个维度上进行softmax操作,得到第i层第m组编码器归一化后的注意力权重
Figure 116239DEST_PATH_IMAGE034
S2.4:将所述
Figure 407412DEST_PATH_IMAGE035
经过第i层第m组
Figure 766849DEST_PATH_IMAGE015
编码器值向量卷积
Figure 298325DEST_PATH_IMAGE036
,得到第i层第m组的编码器值向量
Figure 595576DEST_PATH_IMAGE037
,其维度为
Figure 441173DEST_PATH_IMAGE038
S2.5:根据下式计算第i层第m组第u个位置初步增强特征图
Figure 435542DEST_PATH_IMAGE039
Figure 75602DEST_PATH_IMAGE040
其中,
Figure 906155DEST_PATH_IMAGE041
表示
Figure 40595DEST_PATH_IMAGE042
第u个位置向量,其维度为
Figure 640204DEST_PATH_IMAGE043
Figure 700433DEST_PATH_IMAGE044
表示所述
Figure 18282DEST_PATH_IMAGE045
第u个位置同行同列向量,其维度为
Figure 205680DEST_PATH_IMAGE046
将所有位置
Figure 625509DEST_PATH_IMAGE047
拼接后经过第i层
Figure 669688DEST_PATH_IMAGE015
编码器融合卷积
Figure 661784DEST_PATH_IMAGE048
,从而得到第i层初步增强特征图
Figure 449611DEST_PATH_IMAGE049
,其维度为
Figure 227075DEST_PATH_IMAGE050
S2.6:将所述
Figure 130571DEST_PATH_IMAGE051
替换步骤S2.1中的
Figure 157433DEST_PATH_IMAGE052
,在所有卷积参数权值共享下,重复S2.1~S2.5后,将其输出的第i层第m组中间结果增强特征图的对应元素加上
Figure 935902DEST_PATH_IMAGE053
,最终获得第i层第m组再次增强特征图
Figure 364609DEST_PATH_IMAGE054
S2.7:将所述
Figure 953853DEST_PATH_IMAGE054
在第1个维度拼接,经过第i层编码器前馈网络和第i层编码器层归一化,得到第i层多头十字交叉注意力层输出的增强特征
Figure 422006DEST_PATH_IMAGE055
3.根据权利要求1所述的基于自注意力机制的通用图像目标检测方法,其特征在于,所述步骤三中的每个多方向交叉注意力层进行如下操作:
S3.1:对于第一层多方向交叉注意力层,输入维度为
Figure 551636DEST_PATH_IMAGE056
的可学习的目标查询向量,并对所述目标查询向量进行标准正态分布的随机初始化;对于第i层多方向交叉注意力层,
Figure 569271DEST_PATH_IMAGE057
,将第i-1层多方向交叉注意力层输出的目标查询向量
Figure 844263DEST_PATH_IMAGE058
作为第i层多方向交叉注意力层输入的目标查询向量;将第L层多方向交叉注意力层输出的目标查询向量
Figure 783400DEST_PATH_IMAGE059
作为解码器输出增强目标查询向量;
S3.2:将所述
Figure 670716DEST_PATH_IMAGE058
输入到两层的多层感知机网络,生成维度为
Figure 808436DEST_PATH_IMAGE060
的第i层建议框;将所述第L层多头十字交叉注意力层输出的增强特征
Figure 988751DEST_PATH_IMAGE061
按照第1个维度等分为M组,第L层多头十字交叉注意力层输出的第m组的子图像特征
Figure 477501DEST_PATH_IMAGE062
的维度为
Figure 886616DEST_PATH_IMAGE063
S3.3:从N个第i层建议框中心出发,对所述
Figure 895155DEST_PATH_IMAGE064
均匀向外张开M个方向,在每个方向上使用双线性插值均匀采样K个点,得到维度为
Figure 997103DEST_PATH_IMAGE065
的第i层第m组采样视觉特征向量
Figure 160100DEST_PATH_IMAGE066
;通过改变维度的方式将所述
Figure 435224DEST_PATH_IMAGE058
变成维度为
Figure 750798DEST_PATH_IMAGE067
的第i层第m组目标查询特征
Figure 774380DEST_PATH_IMAGE068
S3.4:将所述
Figure 237723DEST_PATH_IMAGE068
经过第i层第m组
Figure 237909DEST_PATH_IMAGE015
解码器查询向量卷积
Figure 470307DEST_PATH_IMAGE069
,得到第i层第m组解码器查询向量
Figure 648478DEST_PATH_IMAGE070
,维度为
Figure 349849DEST_PATH_IMAGE067
;将所述
Figure 904459DEST_PATH_IMAGE066
经过第i层第m组
Figure 443893DEST_PATH_IMAGE015
解码器匹配键值卷积
Figure 120862DEST_PATH_IMAGE071
,得到第i层第m组解码器匹配键值
Figure 230901DEST_PATH_IMAGE072
,维度为
Figure 136671DEST_PATH_IMAGE065
S3.5:通过下式计算得到第i层第m组第j个解码器未归一化注意力权重
Figure 281344DEST_PATH_IMAGE073
Figure 863635DEST_PATH_IMAGE074
其中,
Figure 975817DEST_PATH_IMAGE075
为所述
Figure 872229DEST_PATH_IMAGE070
的第2个维度第j个矩阵,维度为
Figure 672739DEST_PATH_IMAGE076
Figure 629194DEST_PATH_IMAGE077
为所述
Figure 776141DEST_PATH_IMAGE072
的第2个维度第j个矩阵,维度为
Figure 725512DEST_PATH_IMAGE078
;其中,
Figure 579198DEST_PATH_IMAGE079
将所有维度
Figure 519604DEST_PATH_IMAGE073
在第2个维度进行拼接,成为第i层第m组解码器未归一化注意力权重
Figure 91530DEST_PATH_IMAGE080
,其维度为
Figure 844592DEST_PATH_IMAGE081
Figure 615102DEST_PATH_IMAGE080
每一元素除以
Figure 913359DEST_PATH_IMAGE033
后,在第1个维度上进行softmax操作,得到第i层第m组解码器归一化后的注意力权重
Figure 785631DEST_PATH_IMAGE082
S3.6:将所述
Figure 561957DEST_PATH_IMAGE083
经过第i层第m组
Figure 452553DEST_PATH_IMAGE015
解码器值向量卷积
Figure 170979DEST_PATH_IMAGE084
,得到第i层第m组的解码器值向量
Figure 717498DEST_PATH_IMAGE085
,其维度为
Figure 376143DEST_PATH_IMAGE086
S3.7:通过下式计算得到第i层第m组第j个增强目标查询特征
Figure 58929DEST_PATH_IMAGE087
Figure 761305DEST_PATH_IMAGE088
其中,
Figure 778809DEST_PATH_IMAGE089
表示所述
Figure 162517DEST_PATH_IMAGE090
第2个维度的第j个向量,其维度为
Figure 778437DEST_PATH_IMAGE091
Figure 589398DEST_PATH_IMAGE092
为所述
Figure 907247DEST_PATH_IMAGE085
第2个维度取出第j个向量,其维度为
Figure 78334DEST_PATH_IMAGE093
将所有
Figure 1291DEST_PATH_IMAGE087
在第2个维度拼接,成为第i层第m组增强目标查询特征
Figure 530624DEST_PATH_IMAGE094
,其维度为
Figure 539031DEST_PATH_IMAGE095
S3.8:将上述所有第i层每组增强目标查询特征在第3个维度拼接后,通过第i层
Figure 779388DEST_PATH_IMAGE015
解码器融合卷积
Figure 353589DEST_PATH_IMAGE096
,第i层得到融合后的增强目标查询特征
Figure 506353DEST_PATH_IMAGE097
,改变维度为
Figure 283947DEST_PATH_IMAGE098
,之后经过第i层解码器前馈网络和第i层解码器层归一化,得到第i层多方向交叉注意力层输出的目标查询向量
Figure 813148DEST_PATH_IMAGE099
,维度为
Figure 241856DEST_PATH_IMAGE100
4.根据权利要求1所述的基于自注意力机制的通用图像目标检测方法,其特征在于,所述步骤四包括:将所述解码器输出增强目标查询向量分别输入到由两个不同全连接网络组成的分类层和回归层,输出训练集图像的预测边界框类别概率和位置。
5.根据权利要求1所述的基于自注意力机制的通用图像目标检测方法,其特征在于,所述步骤五包括:将所述输出训练集图像的预测边界框类别概率和位置与真实边界框类别和位置信息通过匈牙利匹配算法获得最佳匹配,然后计算分类损失函数和位置回归损失函数之和作为网络整体损失;网络整体损失表达式如下所示:
Figure 814788DEST_PATH_IMAGE101
其中,
Figure 17362DEST_PATH_IMAGE102
表示分类损失函数,计算预测边界框类别概率和真实边界框类别的焦点损失;
Figure 146992DEST_PATH_IMAGE103
表示预测边界框位置和真实边界框位置的L1损失,
Figure 367889DEST_PATH_IMAGE104
表示预测边界框位置和真实边界框位置的广义的IoU损失,
Figure 439619DEST_PATH_IMAGE103
Figure 378756DEST_PATH_IMAGE104
之和表示位置回归损失函数;
Figure 12211DEST_PATH_IMAGE105
Figure 415511DEST_PATH_IMAGE106
Figure 346558DEST_PATH_IMAGE107
分别表示分类损失函数、L1损失和广义的IoU损失对应的权重系数;
最后使用反向传播方法对整个模型进行训练,当网络整体损失不再降低时,得到目标检测模型。
6.一种基于自注意力机制的通用图像目标检测装置,其特征在于,包括一个或多个处理器,用于实现权利要求1-5中任一项所述的基于自注意力机制的通用图像目标检测方法。
7.一种计算机可读存储介质,其特征在于,其上存储有程序,该程序被处理器执行时,实现权利要求1-5中任一项所述的基于自注意力机制的通用图像目标检测方法。
CN202111477045.4A 2021-12-06 2021-12-06 一种基于自注意力机制的通用图像目标检测方法和装置 Active CN113902926B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202111477045.4A CN113902926B (zh) 2021-12-06 2021-12-06 一种基于自注意力机制的通用图像目标检测方法和装置

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202111477045.4A CN113902926B (zh) 2021-12-06 2021-12-06 一种基于自注意力机制的通用图像目标检测方法和装置

Publications (2)

Publication Number Publication Date
CN113902926A true CN113902926A (zh) 2022-01-07
CN113902926B CN113902926B (zh) 2022-05-31

Family

ID=79195365

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202111477045.4A Active CN113902926B (zh) 2021-12-06 2021-12-06 一种基于自注意力机制的通用图像目标检测方法和装置

Country Status (1)

Country Link
CN (1) CN113902926B (zh)

Cited By (16)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN114283347A (zh) * 2022-03-03 2022-04-05 粤港澳大湾区数字经济研究院(福田) 目标检测方法、系统、智能终端及计算机可读存储介质
CN114359283A (zh) * 2022-03-18 2022-04-15 华东交通大学 基于Transformer的缺陷检测方法和电子设备
CN114596273A (zh) * 2022-03-02 2022-06-07 江南大学 利用yolov4网络的陶瓷基板多种瑕疵智能检测方法
CN114612378A (zh) * 2022-01-21 2022-06-10 华东师范大学 一种目标检测中使用IoU加强自注意力机制的方法
CN114758032A (zh) * 2022-06-15 2022-07-15 之江实验室 基于时空注意力模型的多相期ct图像分类系统及构建方法
CN114972976A (zh) * 2022-07-29 2022-08-30 之江实验室 基于频域自注意力机制的夜间目标检测、训练方法及装置
CN114998748A (zh) * 2022-07-28 2022-09-02 北京卫星信息工程研究所 遥感图像目标精细识别方法、电子设备及存储介质
CN115170828A (zh) * 2022-07-15 2022-10-11 哈尔滨市科佳通用机电股份有限公司 基于深度学习的折角塞门卡子丢失故障检测方法
CN115953665A (zh) * 2023-03-09 2023-04-11 武汉人工智能研究院 一种目标检测方法、装置、设备及存储介质
CN116129228A (zh) * 2023-04-19 2023-05-16 中国科学技术大学 图像匹配模型的训练方法、图像匹配方法及其装置
CN116258931A (zh) * 2022-12-14 2023-06-13 之江实验室 基于ViT和滑窗注意力融合的视觉指代表达理解方法和系统
CN116384593A (zh) * 2023-06-01 2023-07-04 深圳市国电科技通信有限公司 分布式光伏出力预测方法、装置、电子设备和介质
CN116993996A (zh) * 2023-09-08 2023-11-03 腾讯科技(深圳)有限公司 对图像中的对象进行检测的方法及装置
CN117152142A (zh) * 2023-10-30 2023-12-01 菲特(天津)检测技术有限公司 一种轴承缺陷检测模型构建方法及系统
WO2024007619A1 (zh) * 2022-07-06 2024-01-11 京东科技信息技术有限公司 解码器的训练方法、目标检测方法、装置以及存储介质
CN117542045A (zh) * 2024-01-10 2024-02-09 济南大学 一种基于空间引导自注意力的食品识别方法及系统

Citations (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN112819037A (zh) * 2021-01-12 2021-05-18 广东石油化工学院 基于交叉注意力和自注意力的分类参数分布的故障诊断方法

Patent Citations (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN112819037A (zh) * 2021-01-12 2021-05-18 广东石油化工学院 基于交叉注意力和自注意力的分类参数分布的故障诊断方法

Non-Patent Citations (3)

* Cited by examiner, † Cited by third party
Title
YUAN LI 等: "《Deep attention network for joint hand gesture localization and recognition using static RGB-D images》", 《INFORMATION SCIENCES》 *
费文曲: "《基于位置感知交叉注意力网络的方面情感分析》", 《信息通信》 *
马浩男 等: "《十字交叉与时空注意力在动作识别系统的实践》", 《福建电脑》 *

Cited By (29)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN114612378A (zh) * 2022-01-21 2022-06-10 华东师范大学 一种目标检测中使用IoU加强自注意力机制的方法
CN114612378B (zh) * 2022-01-21 2024-04-26 华东师范大学 一种目标检测中使用IoU加强自注意力机制的方法
CN114596273B (zh) * 2022-03-02 2022-11-25 江南大学 利用yolov4网络的陶瓷基板多种瑕疵智能检测方法
CN114596273A (zh) * 2022-03-02 2022-06-07 江南大学 利用yolov4网络的陶瓷基板多种瑕疵智能检测方法
CN114283347B (zh) * 2022-03-03 2022-07-15 粤港澳大湾区数字经济研究院(福田) 目标检测方法、系统、智能终端及计算机可读存储介质
CN114283347A (zh) * 2022-03-03 2022-04-05 粤港澳大湾区数字经济研究院(福田) 目标检测方法、系统、智能终端及计算机可读存储介质
CN114359283A (zh) * 2022-03-18 2022-04-15 华东交通大学 基于Transformer的缺陷检测方法和电子设备
CN114359283B (zh) * 2022-03-18 2022-07-05 华东交通大学 基于Transformer的缺陷检测方法和电子设备
CN114758032A (zh) * 2022-06-15 2022-07-15 之江实验室 基于时空注意力模型的多相期ct图像分类系统及构建方法
JP7411126B2 (ja) 2022-06-15 2024-01-10 之江実験室 時空間的アテンションモデルに基づく多時相ct画像分類システム及び構築方法
WO2024007619A1 (zh) * 2022-07-06 2024-01-11 京东科技信息技术有限公司 解码器的训练方法、目标检测方法、装置以及存储介质
CN115170828A (zh) * 2022-07-15 2022-10-11 哈尔滨市科佳通用机电股份有限公司 基于深度学习的折角塞门卡子丢失故障检测方法
CN115170828B (zh) * 2022-07-15 2023-03-14 哈尔滨市科佳通用机电股份有限公司 基于深度学习的折角塞门卡子丢失故障检测方法
CN114998748A (zh) * 2022-07-28 2022-09-02 北京卫星信息工程研究所 遥感图像目标精细识别方法、电子设备及存储介质
CN114998748B (zh) * 2022-07-28 2023-02-03 北京卫星信息工程研究所 遥感图像目标精细识别方法、电子设备及存储介质
CN114972976B (zh) * 2022-07-29 2022-12-20 之江实验室 基于频域自注意力机制的夜间目标检测、训练方法及装置
CN114972976A (zh) * 2022-07-29 2022-08-30 之江实验室 基于频域自注意力机制的夜间目标检测、训练方法及装置
CN116258931A (zh) * 2022-12-14 2023-06-13 之江实验室 基于ViT和滑窗注意力融合的视觉指代表达理解方法和系统
CN116258931B (zh) * 2022-12-14 2023-09-15 之江实验室 基于ViT和滑窗注意力融合的视觉指代表达理解方法和系统
CN115953665A (zh) * 2023-03-09 2023-04-11 武汉人工智能研究院 一种目标检测方法、装置、设备及存储介质
CN116129228A (zh) * 2023-04-19 2023-05-16 中国科学技术大学 图像匹配模型的训练方法、图像匹配方法及其装置
CN116384593B (zh) * 2023-06-01 2023-08-18 深圳市国电科技通信有限公司 分布式光伏出力预测方法、装置、电子设备和介质
CN116384593A (zh) * 2023-06-01 2023-07-04 深圳市国电科技通信有限公司 分布式光伏出力预测方法、装置、电子设备和介质
CN116993996A (zh) * 2023-09-08 2023-11-03 腾讯科技(深圳)有限公司 对图像中的对象进行检测的方法及装置
CN116993996B (zh) * 2023-09-08 2024-01-12 腾讯科技(深圳)有限公司 对图像中的对象进行检测的方法及装置
CN117152142A (zh) * 2023-10-30 2023-12-01 菲特(天津)检测技术有限公司 一种轴承缺陷检测模型构建方法及系统
CN117152142B (zh) * 2023-10-30 2024-02-02 菲特(天津)检测技术有限公司 一种轴承缺陷检测模型构建方法及系统
CN117542045A (zh) * 2024-01-10 2024-02-09 济南大学 一种基于空间引导自注意力的食品识别方法及系统
CN117542045B (zh) * 2024-01-10 2024-05-10 山东记食信息科技有限公司 一种基于空间引导自注意力的食品识别方法及系统

Also Published As

Publication number Publication date
CN113902926B (zh) 2022-05-31

Similar Documents

Publication Publication Date Title
CN113902926B (zh) 一种基于自注意力机制的通用图像目标检测方法和装置
CN109522942B (zh) 一种图像分类方法、装置、终端设备和存储介质
US20190370647A1 (en) Artificial intelligence analysis and explanation utilizing hardware measures of attention
WO2016119076A1 (en) A method and a system for face recognition
CN110929080B (zh) 基于注意力和生成对抗网络的光学遥感图像检索方法
CN111160375A (zh) 三维关键点预测及深度学习模型训练方法、装置及设备
CN111027576B (zh) 基于协同显著性生成式对抗网络的协同显著性检测方法
CN111738270B (zh) 模型生成方法、装置、设备和可读存储介质
CN115147598B (zh) 目标检测分割方法、装置、智能终端及存储介质
CN111179419A (zh) 三维关键点预测及深度学习模型训练方法、装置及设备
CN111598087B (zh) 不规则文字的识别方法、装置、计算机设备及存储介质
CN112116064A (zh) 光谱超分辨自适应加权注意力机制深层网络数据处理方法
US11948078B2 (en) Joint representation learning from images and text
CN115409896A (zh) 位姿预测方法、装置、电子设备和介质
CN114067371B (zh) 一种跨模态行人轨迹生成式预测框架、方法和装置
CN114820755A (zh) 一种深度图估计方法及系统
CN113487027B (zh) 基于时序对齐预测的序列距离度量方法、存储介质及芯片
CN115346125A (zh) 一种基于深度学习的目标检测方法
CN112257686B (zh) 人体姿态识别模型的训练方法、装置及存储介质
CN114819140A (zh) 模型剪枝方法、装置和计算机设备
CN114998630A (zh) 一种从粗到精的地对空图像配准方法
CN114743187A (zh) 银行安全控件自动登录方法、系统、设备及存储介质
CN113095328A (zh) 一种基尼指数引导的基于自训练的语义分割方法
CN112232360A (zh) 图像检索模型优化方法、图像检索方法、装置及存储介质
US20230298326A1 (en) Image augmentation method, electronic device and readable storage medium

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant