CN115063666A - 解码器的训练方法、目标检测方法、装置以及存储介质 - Google Patents
解码器的训练方法、目标检测方法、装置以及存储介质 Download PDFInfo
- Publication number
- CN115063666A CN115063666A CN202210788886.5A CN202210788886A CN115063666A CN 115063666 A CN115063666 A CN 115063666A CN 202210788886 A CN202210788886 A CN 202210788886A CN 115063666 A CN115063666 A CN 115063666A
- Authority
- CN
- China
- Prior art keywords
- segment
- query
- features
- feature
- predicted
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/82—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/74—Image or video pattern matching; Proximity measures in feature spaces
- G06V10/761—Proximity, similarity or dissimilarity measures
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V20/00—Scenes; Scene-specific elements
- G06V20/40—Scenes; Scene-specific elements in video content
- G06V20/41—Higher-level, semantic clustering, classification or understanding of video scenes, e.g. detection, labelling or Markovian modelling of sport events or news items
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V2201/00—Indexing scheme relating to image or video recognition or understanding
- G06V2201/07—Target detection
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Evolutionary Computation (AREA)
- Software Systems (AREA)
- General Physics & Mathematics (AREA)
- Multimedia (AREA)
- Artificial Intelligence (AREA)
- General Health & Medical Sciences (AREA)
- Computing Systems (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Health & Medical Sciences (AREA)
- Medical Informatics (AREA)
- Databases & Information Systems (AREA)
- Computational Linguistics (AREA)
- Molecular Biology (AREA)
- Mathematical Physics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Biomedical Technology (AREA)
- General Engineering & Computer Science (AREA)
- Biophysics (AREA)
- Data Mining & Analysis (AREA)
- Information Retrieval, Db Structures And Fs Structures Therefor (AREA)
- Image Analysis (AREA)
Abstract
本公开提供了一种编码器的训练方法、目标检测方法、装置以及存储介质,其中的训练方法包括:使用关系注意力模块并基于查询特征,生成与查询特征相对应的显著查询特征集合,用以进行更新处理;使用跨越注意力模块并基于更新后的查询特征,获取与更新后的查询特征相对应的预测片段质量信息,并构建片段质量损失函数;获取与查询特征相对应的预测视频片段之间的片段关系特征,构建片段关系损失函数;根据片段质量损失函数和片段关系损失函数进行调整处理。本公开能够减少无效查询特征对预测的干扰,可以抑制冗余的预测结果,提升检测结果的准确性。
Description
技术领域
本公开涉及人工智能技术领域,尤其涉及一种解码器的训练方法、目标检测方法、装置以及存储介质。
背景技术
随着视频数据量的日益增长,对于视频数据的分析和处理的需求日渐提升。例如,在直播内容安全性检测、短视频危险动作检测等场景,需要使用视频动作检测方法识别视频数据中的风险动作。目前,在进行动作检测时,通常使用DETR(Bidirectional EncoderRepresentations from Transformer,基于transformer结构的双向编码器表征)模型进行目标检测。DETR模型利用Transformer的结构,实现了基于查询的二维图像目标检测。Transformer结构是一种基于注意力(Attention)机制的网络结构,通过Transformer构建模型,能够有效地提升视频动作检测方法的性能。在实现本发明的过程中,发明人发现DETR模型通过编码器-解码器的方式预测出固定数量的检测目标,在解码器中通常采用密集的自注意力机制确定查询特征之间的相关关系,由于没有考虑到与每个查询特征对应的视频片段之间的语义关系,则无效的查询特征能够干扰查询特征预测的结果,并且,对于查询特征的预测存在预测结果不准确的情况。
发明内容
有鉴于此,本发明要解决的一个技术问题是提供一种解码器的训练方法、目标检测方法、装置以及存储介质。
根据本公开的第一方面,提供一种解码器的训练方法,其中,解码器包括:关系注意力模块和跨越注意力模块;所述训练方法包括:使用所述关系注意力模块并基于查询特征,生成与所述查询特征相对应的显著查询特征集合,用以使用所述关系注意力模块并基于所述显著查询特征集合对所述查询特征进行更新处理;使用所述跨越注意力模块并基于更新后的查询特征,获取与所述更新后的查询特征相对应的预测片段质量信息,并根据所述预测片段质量信息构建片段质量损失函数;获取与所述查询特征相对应的预测视频片段之间的片段关系特征,构建片段关系损失函数;根据所述片段质量损失函数和所述片段关系损失函数,对所述关系注意力模块和所述跨越注意力模块进行调整处理。
可选地,所述生成与所述查询特征相对应的显著查询特征集合包括:使用所述关系注意力模块并基于查询特征,获取各个查询特征之间的相似度信息、与各个查询特征对应的视频片段之间的片段关系特征信息;根据所述相似度信息,生成与所述查询特征相对应的相似特征集合;根据所述片段关系特征信息,生成与所述查询特征相对应的关系特征集合;基于所述相似特征集合、所述关系特征集合以及所述查询特征自身,生成所述显著查询特征集合。
可选地,所述根据所述相似度信息,生成与所述查询特征相对应的相似特征集合包括:根据所述相似度信息获取所述查询特征的相似查询特征;其中,所述查询特征与所述相似查询特征之间的相似度大于预设的相似度阈值;基于所述相似查询特征生成所述相似特征集合。
可选地,所述片段关系特征信息包括:片段交并比;所述根据所述片段关系特征信息,生成与所述查询特征相对应的关系特征集合包括:根据所述片段交并比获取所述查询特征的关系查询特征;其中,所述查询特征与所述关系查询特征之间的片段交并比大于预设的交并比阈值;基于所述关系查询特征生成所述关系特征集合。
可选地,所述基于所述相似特征集合、所述关系特征集合以及所述查询特征自身,生成所述显著查询特征集合包括:获取所述相似特征集合关于所述关系特征集合的相对补集;将所述相对补集与所述查询特征自身的并集,作为所述显著查询特征集合。
可选地,所述预测片段质量信息包括:预测片段质量得分;所述使用所述跨越注意力模块并基于更新后的查询特征,获取与所述更新后的查询特征相对应的预测片段质量信息包括:确定与所述更新后的查询特征相对应的预测片段,并获取与所述预测片段相对应的视频片段;确定所述预测片段的中点与所述视频片段的中点之间的预测距离、所述预测片段与所述视频片段之间的预测交并比;基于所述预测距离和所述预测交并比,生成所述预测片段质量得分。
可选地,所述根据所述预测片段质量信息构建片段质量损失函数包括:确定所述预测片段中点与所述视频片段中点之间的片段距离、所述预测片段与所述视频片段之间的片段交并比;根据所述预测距离、所述预测交并比与对应的片段距离、片段交并比之间的偏差信息,构建所述片段质量损失函数。
可选地,所述片段关系特征包括:预测片段交并比;所述获取与所述查询特征相对应的预测视频片段之间的片段关系特征,构建片段关系损失函数包括:确定与所述更新后的查询特征相对应的预测片段之间的预测片段交并比;根据所述预测片段交并比的累计信息,构建所述片段关系损失函数。
可选地,所述使用所述关系注意力模块并基于所述显著查询特征集合,对所述查询特征进行更新处理包括:使用所述关系注意力模块对所述显著查询特征集合内的特征进行自注意力计算处理,用以对所述查询特征进行更新处理。
可选地,所述解码器模块包括:基于Transformer结构的解码器。
根据本公开的第二方面,提供一种目标检测方法,包括:获取训练好的解码器;其中,所述解码器是通过如上所述的训练方法训练得到;使用所述解码器并基于查询特征,生成分类置信度、用于表征目标位置的回归信息和预测片段质量得分;基于所述分类置信度和预测片段质量得分,确定预测得分。
根据本公开的第三方面,提供一种解码器的训练装置,其中,解码器包括:关系注意力模块和跨越注意力模块;所述训练装置包括:查询集合获取模块,用于使用所述关系注意力模块并基于查询特征,生成与所述查询特征相对应的显著查询特征集合;查询特征更新模块,用于使用所述关系注意力模块并基于所述显著查询特征集合,对所述查询特征进行更新处理;片段质量确定模块,用于使用所述跨越注意力模块并基于更新后的查询特征,获取与所述更新后的查询特征相对应的预测片段质量信息,并根据所述预测片段质量信息构建片段质量损失函数;预测损失确定模块,用于确定获取与所述查询特征相对应的预测视频片段之间的片段关系特征,构建片段关系损失函数;模块调整模块,用于根据所述片段质量损失函数和所述片段关系损失函数,对所述关系注意力模块和所述跨越注意力模块进行调整处理。
可选地,所述查询集合获取模块,包括:特征信息获取单元,用于使用所述关系注意力模块并基于查询特征,获取各个查询特征之间的相似度信息、与各个查询特征对应的视频片段之间的片段关系特征信息;相似集合获取单元,用于根据所述相似度信息,生成与所述查询特征相对应的相似特征集合;关系集合获取单元,用于根据所述片段关系特征信息,生成与所述查询特征相对应的关系特征集合;显著集合获取单元,用于基于所述相似特征集合、所述关系特征集合以及所述查询特征自身,生成所述显著查询特征集合。
可选地,所述相似集合获取单元,具体用于根据所述相似度信息获取所述查询特征的相似查询特征;其中,所述查询特征与所述相似查询特征之间的相似度大于预设的相似度阈值;基于所述相似查询特征生成所述相似特征集合。
可选地,所述片段关系特征信息包括:片段交并比;所述关系集合获取单元,具体用于根据所述片段交并比获取所述查询特征的关系查询特征;其中,所述查询特征与所述关系查询特征之间的片段交并比大于预设的交并比阈值;基于所述关系查询特征生成所述关系特征集合。
可选地,所述显著集合获取单元,具体用于获取所述相似特征集合关于所述关系特征集合的相对补集;将所述相对补集与所述查询特征自身的并集,作为所述显著查询特征集合。
可选地,所述预测片段质量信息包括:预测片段质量得分;所述片段质量确定模块,包括:片段质量确定单元,用于确定与所述更新后的查询特征相对应的预测片段,并获取与所述预测片段相对应的视频片段;确定所述预测片段的中点与所述视频片段的中点之间的预测距离、所述预测片段与所述视频片段之间的预测交并比;基于所述预测距离和所述预测交并比,生成所述预测片段质量得分。
可选地,所述片段质量确定模块,包括:质量损失确定单元,用于确定所述预测片段中点与所述视频片段中点之间的片段距离、所述预测片段与所述视频片段之间的片段交并比;根据所述预测距离、所述预测交并比与对应的片段距离、片段交并比之间的偏差信息,构建所述片段质量损失函数。
可选地,所述片段关系特征包括:预测片段交并比;所述预测损失确定模块,具体用于确定与所述更新后的查询特征相对应的预测片段之间的预测片段交并比;根据所述预测片段交并比的累计信息,构建所述片段关系损失函数。
可选地,所述查询特征更新模块,具体用于使用所述关系注意力模块对所述显著查询特征集合内的特征进行自注意力计算处理,用以对所述查询特征进行更新处理。
可选地,所述解码器模块包括:基于Transformer结构的解码器。
根据本公开的第四方面,提供一种解码器的训练装置,包括:存储器;以及耦接至所述存储器的处理器,所述处理器被配置为基于存储在所述存储器中的指令,执行如上所述的方法。
根据本公开的第五方面,提供一种目标检测装置,包括:模型获取模块,用于获取训练好的解码器;其中,所述解码器是通过上述的训练方法训练得到;检测处理模块,用于使用所述解码器并基于查询特征,生成分类置信度、用于表征目标位置的回归信息和预测片段质量得分。预测得分模块,用于基于所述分类置信度和预测片段质量得分,确定预测得分。
根据本公开的第六方面,提供一种目标检测装置,包括:存储器;以及耦接至所述存储器的处理器,所述处理器被配置为基于存储在所述存储器中的指令,执行如上所述的方法。
根据本公开的第七方面,提供一种计算机可读存储介质,所述计算机可读存储介质存储有计算机指令,所述指令被处理器执行如上的方法。
本公开的解码器的训练方法、目标检测方法、装置以及存储介质,根据查询特征之间的关系构建显著查询特征集合,对显著查询特征集合内的查询特征进行自注意力处理,能够减少无效查询特征对预测的干扰;通过获取新增的预测片段质量信息并构建片段质量损失函数,可以抑制冗余的预测结果,提升检测结果的准确性;通过构建片段关系损失函数,能够抑制冗余预测,使得预测结果的更加精准。
附图说明
为了更清楚地说明本公开实施例或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作一简单地介绍,显而易见地,下面描述中的附图仅仅是本公开的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动性的前提下,还可以根据这些附图获得其它的附图。
图1为根据本公开的解码器的训练方法的一个实施例的流程示意图;
图2为本公开的解码器的一个实施例的网络框架结构示意图;
图3为根据本公开的解码器的训练方法的一个实施例中的生成显著查询特征集合的流程示意图;
图4为查询特征之间的关系示意图;
图5为根据本公开的解码器的训练方法的一个实施例中的生成预测片段质量得分的流程示意图;
图6为根据本公开的解码器的训练方法的一个实施例中的对查询特征进行处理的示意图;
图7为根据本公开的解码器的训练方法的一个实施例中的构建片段质量损失函数的流程示意图;
图8为根据本公开的解码器的训练方法的一个实施例中的构建片段关系损失函数的流程示意图;
图9为根据本公开的目标检测方法的一个实施例的流程示意图;
图10为根据本公开的解码器的训练装置的一个实施例的模块示意图;
图11为根据本公开的解码器的训练装置的一个实施例中的查询集合获取模块的模块示意图;
图12为根据本公开的解码器的训练装置的一个实施例中的片段质量确定模块的模块示意图;
图13为根据本公开的解码器的训练装置的另一个实施例的模块示意图;
图14为根据本公开的目标检测装置的一个实施例的模块示意图;
图15为根据本公开的目标检测装置的另一个实施例的模块示意图。
具体实施方式
下面参照附图对本公开进行更全面的描述,其中说明本公开的示例性实施例。下面将结合本公开实施例中的附图,对本公开实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本公开一部分实施例,而不是全部的实施例。基于本公开中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其它实施例,都属于本公开保护的范围。下面结合各个图和实施例对本公开的技术方案进行多方面的描述。
在发明人所知晓的相关技术中,DETR模型包含基于Transformer结构的编码器和解码器,即Transformer编码器和Transformer解码器。原始视频序列经过骨干网络(例如为卷积神经网络等)提取时间和空间特征图并加上位置编码信息,合成嵌入向量,输入Transformer编码器。Transformer编码器通过自注意力机制提取图像编码特征,将图像编码特征与查询特征输入Transformer解码器。Transformer解码器输出目标查询向量,目标查询向量经过全连接层和多层感知机层构建的分类头和回归头,输出检测目标的位置和类别,检测目标可以为走路、跑步等动作。
Transformer结构在特征表示方面有着较好的性能,通过Transformer构建模型,能够有效地提升视频动作检测方法的性能。Transformer编码器包含多个编码器层,现有的编码器层由一个多头自注意力层、两个层归一化层及一个前馈神经网络层组成。现有的Transformer解码器包含多个解码器层,解码器层由两个多头自注意力层、三个归一化层及一个前馈神经网络层组成。
DETR方法以固定数量N个可学习的查询特征(Query Feature)为输入,每个查询特征通过网络自适应地从二维图像中采样像素点,并通过自注意力(Self-attetion)的方式进行查询特征之间的信息交互,最终,每个查询特征被用于单独预测一个检测框的位置和类别。在时间动作检测领域中,通过编码器-解码器的方式预测出固定数量的检测目标。在检测目标时,利用基于稀疏采样的Transformer结构提取时间片段特征。
对于解码器部分,K个可训练的查询特征作为输入。查询特征是一个可学习的向量,查询特征可以根据学习到的统计信息从特定时刻提取时间特征。使用自注意力的操作实现所有查询特征之间的信息交互,每个查询特征都可以通过一层全连接层预测N个时间维度上的采样k个点的归一化坐标,并根据采样点从视频特征中提取特征更新查询特征。例如,通过另一个全连接层,输入查询特征预测k个权重,对采样的k个特征加权求和。更新后的查询特征分别通过回归头和分类头预测出动作的位置以及类型。回归头和分类头分别为三层的全连接和一层的全连接层,回归头预测动作的开始和结束的归一化坐标,分类头预测该动作的分类以及置信度分数。
现有的DETR模型中的解码器通常采用密集的自注意力机制获取查询特征之间的相关关系,没有考虑每个查询特征对应的视频片段之间的语义关系,因此,无效的查询片段会干扰到每个查询特征预测的结果,并且,由于缺少查询特征之间的约束,容易导致冗余的预测结果,存在预测分数不准的情况。
图1为根据本公开的对话生成模型的训练方法的一个实施例的流程示意图,解码器包括关系注意力模块和跨越注意力模块,如图1所示:
步骤101,使用关系注意力模块并基于查询特征,生成与查询特征相对应的显著查询特征集合,用以使用关系注意力模块并基于显著查询特征集合对查询特征进行更新处理。
在一个实施例中,查询特征(Query Feature)可以为现有的Transformer编码器生成的查询向量等。解码器模块包括基于Transformer结构的解码器,即Transformer解码器。如图2所示,Transformer解码器包括关系注意力模块、跨越注意力模块、两个归一化层以及前馈网络。归一化层以及前馈网络可以使用现有的多种实现方式。Transformer解码器的输入为固定数量的可训练的查询特征。关系注意力模块为对现有的Transformer解码器中的自注意力(Self-Attention)模块进行优化处理后的模块,用于对查询特征进行非密集的注意力(Attention)处理。
步骤102,使用跨越注意力模块并基于更新后的查询特征,获取与更新后的查询特征相对应的预测片段质量信息,并根据预测片段质量信息构建片段质量损失函数。
在一个实施例中,使用跨越注意力模块并基于更新后的查询特征,通过前馈网络以及分类头、回归头和片段质量头生成分类置信度、用于表征目标位置的回归信息和预测片段质量得分,其中,目标为视频中的动作等,分类置信度可以为分类置信度的得分等,回归信息可以为动作的起始和终止信息。
跨越注意力模块为对现有的Transformer解码器中的自注意力模块进行优化处理后的模块。通过增加片段质量头,用以获得预测片段质量得分,在进行预测时,将预测片段质量得分和分类置信度的得分相乘,得到查询特征的最终预测得分。
步骤103,获取与查询特征相对应的预测视频片段之间的片段关系特征,构建片段关系损失函数。
步骤104,根据片段质量损失函数和片段关系损失函数,对关系注意力模块和跨越注意力模块进行调整处理。
在一个实施例中,可以采用现有的多种模型调整方法,根据片段质量损失函数和片段关系损失函数,对关系注意力模块和跨越注意力模块等模块的参数进行调整,以使片段质量损失函数的函数值和片段关系损失函数的函数值分别在允许的取值范围内。
在一个实施例中,生成与查询特征相对应的显著查询特征集合可以使用多种方法。图3为根据本公开的解码器的训练方法的一个实施例中的生成显著查询特征集合的流程示意图,如图3所示:
步骤301,使用关系注意力模块并基于查询特征,获取各个查询特征之间的相似度信息、与各个查询特征对应的视频片段之间的片段关系特征信息。
在一个实施例中,可以采用现有多种方法计算各个查询特征之间的相似度信息,相似度信息可以为余弦相似度等。可以采用现有多种方法计算与各个查询特征对应的视频片段之间的片段关系特征信息,片段关系特征信息包括片段交并比等。
步骤302,根据相似度信息,生成与查询特征相对应的相似特征集合。
在一个实施例中,根据相似度信息获取查询特征的相似查询特征,查询特征与相似查询特征之间的相似度大于预设的相似度阈值,相似度可以为余弦相似度等。基于相似查询特征生成相似特征集合。
例如,通过关系注意力模块对查询特征之间的关系进行建模。在图4中,查询特征包括真实标签311、参考查询片段321、显著相似片段331,332,333、显著不相似片段341,342、冗余片段351等。在进入关系注意力模块后,每个查询特征通过全连接层预测出对应的时间片段。对于参考查询片段321,对应的显著查询特征集合包括显著相似片段331,332,333等,相似特征集合中的查询特征具有在语义上相近、在时间维度上不冗余等特征。
根据各个查询特征之间的相似度信息,构建一个相似度矩阵A∈其中,Lq为查询特征的固定数量,A用于表征Lq个查询特征两两之间的相似度的相似度矩阵,相似度矩阵A中的每个元素是两个查询特征的余弦相似度。基于相似度阈值γ∈[-1,1]构建相似特征集合WEI:
Esim={(i,j)|A[i,j]-γ>0} (1-1);
其中,A[i,j]为第i个查询特征和第j个查询特征之间的相似度,γ是训练前预定义的相似度阈值;Esim为根据特征之间的相似度构建的相似特征集合,可以与查询特征相对应,Esim的数量为多个。
步骤303,根据片段关系特征信息,生成与查询特征相对应的关系特征集合。
在一个实施例中,片段关系特征信息为片段交并比(Intersection over Union,简称IoU)等。根据片段交并比获取查询特征的关系查询特征,查询特征与关系查询特征之间的片段交并比大于预设的交并比阈值。基于关系查询特征生成关系特征集合。
例如,片段交并比IoU用于表征两个片段的交集长度/两个片段并集的长度。基于片段交并比构建IoU矩阵:B矩阵中的每个元素为与两个查询特征对应的视频片段(可以为参考特征片段)之间的IoU值。根据交并比阈值v∈[0,1],构建关系特征集合:
EIoU={(i,j)|B[i,j]-τ>0} (1-2);
其中,EIoU为根据IoU关系构建的关系特征集合;B[i,j]是第i个查询特征和第j个查询特征之间的IoU关系,即B[i,j]为第i个查询特征和第j个查询特征对应的视频片段之间的IoU值;τ是训练前预定义的交并比阈值。
步骤304,基于相似特征集合、关系特征集合以及查询特征自身,生成显著查询特征集合。
在一个实施例中,获取相似特征集合关于关系特征集合的相对补集,将相对补集与查询特征自身的并集,作为显著查询特征集合。
例如,构建显著查询特征集合:
E=(EIoU\Esim)∪Eself (1-3);
其中,E为显著查询特征集合,Eself是自连接集合,表示第i个查询特征和自身的连接。
在一个实施例中,使用关系注意力模块对显著查询特征集合内的特征进行自注意力计算处理,用以对查询特征进行更新处理。可以使用现有的自注意力计算处理方法,对显著查询特征集合内的特征进行自注意力计算处理,通过自注意力计算处理,可以基于已有的查询特征获得更具表达性的特征。
例如,对显著查询特征集合内的查询特征计算attention(注意力)权重,计算方法为:
q′i=aiVi T (1-4);
其中,Q、K、V分别为每个查询特征的Query、Key、Value特征,Ki和Vi是第i个查询特征对应的显著查询特征集合内的key和value集合,q′i是第i个查询特征更新后的查询特征,ai是显著查询特征集合内元素的注意力权重,是一个行归一化的矩阵,为对Value集合里的每一个特征加权求和的值。
为了排除无效查询特征片段对预测的干扰,本公开的解码器的训练方法基于特征相似度和IoU两个指标,为每个查询特征动态构建显著查询特征集合,代替Self-attention的密集attention运算,该查询特征仅与显著查询特征集合内的其他查询特征计算Attention。
在一个实施例中,获取与更新后的查询特征相对应的预测片段质量信息可以采用多种方法。图5为根据本公开的解码器的训练方法的一个实施例中的生成预测片段质量得分的流程示意图,预测片段质量信息包括预测片段质量得分,如图5所示:
步骤501,确定与更新后的查询特征相对应的预测片段,并获取与预测片段相对应的视频片段。
步骤502,确定预测片段的中点与视频片段的中点之间的预测距离、预测片段与视频片段之间的预测交并比。
步骤503,基于预测距离和预测交并比,生成预测片段质量得分。
根据预测片段质量信息构建片段质量损失函数可以采用多种方法。图7为根据本公开的解码器的训练方法的一个实施例中的构建片段质量损失函数的流程示意图,如图7所示:
步骤701,确定预测片段中点与视频片段中点之间的片段距离、预测片段与视频片段之间的片段交并比。
步骤702,根据预测距离、预测交并比与对应的片段距离、片段交并比之间的偏差信息,构建片段质量损失函数。
例如,如图6所示,经过关系注意力模块更新的查询特征输入超越注意力模块,超越注意力模块预测时间维度上的采样点,并通过对采样特征加权求和的方式,得到视频片段的特征,将视频片段的特征通过前馈网络送入各个检测头当中。除了现有的回归头和分类头以外,添加了片段质量头(Segment Quality head),用以估计片段的质量。
确定与更新后的查询特征相对应的预测片段sq,sq对应的更新后的查询特征fq。定义(ζ1,ζ2)=φ(fq),表征为通过全连接层预测ζ1,ζ2这两个值,其中,φ()为单层的全连接层的函数,φ()可以为多种函数,ζ1为预测片段的中点与视频片段(动作片段)的中点之间的预测距离,ζ2为预测片段与视频片段(动作片段)之间的预测交并比。预测片段质量得分定义为ζ=ζ1·ζ2。在训练时,使用预测片段中点和与其对应的动作片段中点的偏移量,以及它们之间的交并比IoU值,构建片段质量损失函数:
其中,为预测的片段和最近的Ground truth之间的中点的距离,即为预测片段中点mq与对应的视频片段(与预测片段最近的片段)中点mgt之间的实际片段距离;IoU(sq,sgt)为预测片段和最近的Ground truth之间的IoU,即预测片段sq与对应的视频片段sgt之间的实际片段交并比。
在预测时,将分类头输出的分类执行度得分与ζ相乘,得到每个查询特征的预测片段的最终得分。通过增加片段质量头,将预测片段与真实动作的偏移程度和重合程度的乘积作为质量得分,用于预测时共同决定预测片段得分,提升检测结果的准确性。
构建片段关系损失函数可以使用多种方法。图8为根据本公开的解码器的训练方法的一个实施例中的构建片段关系损失函数的流程示意图,片段关系特征包括预测片段交并比,如图8所示:
步骤801,确定与更新后的查询特征相对应的预测片段之间的预测片段交并比。
步骤802,根据预测片段交并比的累计信息,构建片段关系损失函数。
在一个实施例中,在训练阶段,引入IoU约束项构建片段关系损失函数:
其中,Lq是查询特征的数量,si,sj是第i个和第j个查询特征对应的预测片段,来自上一层回归头预测数的输出;IoU是si,sj这两个片段的IoU(交并比)关系,计算方式为:
通过构建片段关系损失函数,能够抑制冗余的Query预测,从而增大获得更精准预测结果的概率。
图9为根据本公开的目标检测方法的一个实施例的流程示意图,如图9所示:
步骤901,获取训练好的解码器;其中,解码器是通过如上的训练方法训练得到。
步骤902,使用解码器并基于查询特征,生成分类置信度、用于表征目标位置的回归信息和预测片段质量得分。
在一个实施例中,解码器模块包括Transformer解码器,Transformer解码器包括关系注意力模块、跨越注意力模块、两个归一化层以及前馈网络。Transformer解码器的输入为固定数量的可训练的查询特征。关系注意力模块对查询特征进行非密集的注意力处理,使用跨越注意力模块并基于更新后的查询特征,并通过前馈网络以及分类头、回归头和片段质量头生成分类置信度、用于表征目标位置的回归信息和预测片段质量得分。
步骤903,基于分类置信度和预测片段质量得分,确定预测得分。
在一个实施例中,将预测片段质量得分和分类置信度的得分相乘,确定每个查询特征的最终预测得分。
在一个实施例中,如图10所示,本公开提供一种解码器的训练装置110,解码器包括关系注意力模块和跨越注意力模块等;解码器的训练装置110包括查询集合获取模块111、查询特征更新模块112、片段质量确定模块113、预测损失确定模块114和模块调整模块115。
查询集合获取模块111使用关系注意力模块并基于查询特征,生成与查询特征相对应的显著查询特征集合。查询特征更新模块112使用关系注意力模块并基于显著查询特征集合,对查询特征进行更新处理。例如,查询特征更新模块112使用关系注意力模块对显著查询特征集合内的特征进行自注意力计算处理,用以对查询特征进行更新处理。
片段质量确定模块113使用跨越注意力模块并基于更新后的查询特征,获取与更新后的查询特征相对应的预测片段质量信息,并根据预测片段质量信息构建片段质量损失函数。预测损失确定模块114确定获取与查询特征相对应的预测视频片段之间的片段关系特征,构建片段关系损失函数。模块调整模块115根据片段质量损失函数和片段关系损失函数,对关系注意力模块和跨越注意力模块进行调整处理。
在一个实施例中,如图11所示,查询集合获取模块111包括特征信息获取单元1111、相似集合获取单元1112、关系集合获取单元1113和显著集合获取单元1114。特征信息获取单元1111使用关系注意力模块并基于查询特征,获取各个查询特征之间的相似度信息、与各个查询特征对应的视频片段之间的片段关系特征信息。
相似集合获取单元1112根据相似度信息,生成与查询特征相对应的相似特征集合。关系集合获取单元1113根据片段关系特征信息,生成与查询特征相对应的关系特征集合。显著集合获取单元1114基于相似特征集合、关系特征集合以及查询特征自身,生成显著查询特征集合。
在一个实施例中,相似集合获取单元1112根据相似度信息获取查询特征的相似查询特征;其中,查询特征与相似查询特征之间的相似度大于预设的相似度阈值。相似集合获取单元1112基于相似查询特征生成相似特征集合。
片段关系特征信息包括片段交并比等,关系集合获取单元1113根据片段交并比获取查询特征的关系查询特征;其中,查询特征与关系查询特征之间的片段交并比大于预设的交并比阈值。关系集合获取单元1113基于关系查询特征生成关系特征集合。
显著集合获取单元1114获取相似特征集合关于关系特征集合的相对补集。显著集合获取单元1114将相对补集与查询特征自身的并集,作为显著查询特征集合。
在一个实施例中,预测片段质量信息包括预测片段质量得分;如图12所示,片段质量确定模块113包括片段质量确定单元1131和质量损失确定单元1132。片段质量确定单元1131确定与更新后的查询特征相对应的预测片段,并获取与预测片段相对应的视频片段;片段质量确定单元1131确定预测片段的中点与视频片段的中点之间的预测距离、预测片段与视频片段之间的预测交并比;片段质量确定单元1131基于预测距离和预测交并比,生成预测片段质量得分。
质量损失确定单元1132确定预测片段中点与视频片段中点之间的片段距离、预测片段与视频片段之间的片段交并比。质量损失确定单元1132根据预测距离、预测交并比与对应的片段距离、片段交并比之间的偏差信息,构建片段质量损失函数。
在一个实施例中,片段关系特征包括预测片段交并比等,预测损失确定模块114用于确定与更新后的查询特征相对应的预测片段之间的预测片段交并比。预测损失确定模块114根据预测片段交并比的累计信息,构建片段关系损失函数。
在一个实施例中,如图13所示,本公开提供一种解码器的训练装置可包括存储器131、处理器132、通信接口133以及总线134。存储器131用于存储指令,处理器132耦合到存储器131,处理器132被配置为基于存储器131存储的指令执行实现上述的解码器的训练方法。
存储器131可以为高速RAM存储器、非易失性存储器(non-volatile memory)等,存储器131也可以是存储器阵列。存储器131还可能被分块,并且块可按一定的规则组合成虚拟卷。处理器132可以为中央处理器CPU,或专用集成电路ASIC(Application SpecificIntegrated Circuit),或者是被配置成实施本公开的解码器的训练方法的一个或多个集成电路。
在一个实施例中,本公开提供一种目标检测装置140,包括模型获取模块141、检测处理模块142和预测得分模块143。模型获取模块141获取训练好的解码器;其中,解码器是通过上述的训练方法训练得到。
检测处理模块142使用解码器并基于查询特征,生成分类置信度、用于表征目标位置的回归信息和预测片段质量得分。预测得分模块143基于分类置信度和预测片段质量得分,确定预测得分。
在一个实施例中,如图15所示,本公开提供一种目标检测装置可包括存储器151、处理器152、通信接口153以及总线154。存储器151用于存储指令,处理器152耦合到存储器151,处理器152被配置为基于存储器151存储的指令执行实现上述的目标检测方法。
存储器151可以为高速RAM存储器、非易失性存储器(non-volatile memory)等,存储器151也可以是存储器阵列。存储器151还可能被分块,并且块可按一定的规则组合成虚拟卷。处理器152可以为中央处理器CPU,或专用集成电路ASIC(Application SpecificIntegrated Circuit),或者是被配置成实施本公开的目标检测方法的一个或多个集成电路。
在一个实施例中,本公开提供一种计算机可读存储介质,计算机可读存储介质存储有计算机指令,指令被处理器执行时实现如上任一个实施例中的方法。
上述实施例中的编码器的训练方法、目标检测方法、装置以及存储介质,根据查询特征之间的关系构建显著查询特征集合,对显著查询特征集合内的查询特征进行自注意力处理,能够减少无效查询特征对预测的干扰;通过获取新增的预测片段质量信息并构建片段质量损失函数,可以抑制冗余的预测结果,提升检测结果的准确性;通过构建片段关系损失函数,能够抑制冗余预测,使得预测结果的更加精准;提高了用户的使用感受度。
可以使用许多方式来实现本公开的方法和系统。例如,可通过软件、硬件、固件或者软件、硬件、固件的任何组合来实现本公开的方法和系统。用于方法的步骤的上述顺序仅是为了进行说明,本公开的方法的步骤不限于以上具体描述的顺序,除非以其它方式特别说明。此外,在一些实施例中,还可将本公开实施为记录在记录介质中的程序,这些程序包括用于实现根据本公开的方法的机器可读指令。因而,本公开还覆盖存储用于执行根据本公开的方法的程序的记录介质。
本公开的描述是为了示例和描述起见而给出的,而并不是无遗漏的或者将本公开限于所公开的形式。很多修改和变化对于本领域的普通技术人员而言是显然的。选择和描述实施例是为了更好说明本公开的原理和实际应用,并且使本领域的普通技术人员能够理解本公开从而设计适于特定用途的带有各种修改的各种实施例。
Claims (16)
1.一种解码器的训练方法,其中,解码器包括:关系注意力模块和跨越注意力模块;所述训练方法包括:
使用所述关系注意力模块并基于查询特征,生成与所述查询特征相对应的显著查询特征集合,用以使用所述关系注意力模块并基于所述显著查询特征集合对所述查询特征进行更新处理;
使用所述跨越注意力模块并基于更新后的查询特征,获取与所述更新后的查询特征相对应的预测片段质量信息,并根据所述预测片段质量信息构建片段质量损失函数;
获取与所述查询特征相对应的预测视频片段之间的片段关系特征,构建片段关系损失函数;
根据所述片段质量损失函数和所述片段关系损失函数,对所述关系注意力模块和所述跨越注意力模块进行调整处理。
2.如权利要求1所述的方法,所述生成与所述查询特征相对应的显著查询特征集合包括:
使用所述关系注意力模块并基于查询特征,获取各个查询特征之间的相似度信息、与各个查询特征对应的视频片段之间的片段关系特征信息;
根据所述相似度信息,生成与所述查询特征相对应的相似特征集合;
根据所述片段关系特征信息,生成与所述查询特征相对应的关系特征集合;
基于所述相似特征集合、所述关系特征集合以及所述查询特征自身,生成所述显著查询特征集合。
3.如权利要求2所述的方法,所述根据所述相似度信息,生成与所述查询特征相对应的相似特征集合包括:
根据所述相似度信息获取所述查询特征的相似查询特征;其中,所述查询特征与所述相似查询特征之间的相似度大于预设的相似度阈值;
基于所述相似查询特征生成所述相似特征集合。
4.如权利要求2所述的方法,所述片段关系特征信息包括:片段交并比;所述根据所述片段关系特征信息,生成与所述查询特征相对应的关系特征集合包括:
根据所述片段交并比获取所述查询特征的关系查询特征;其中,所述查询特征与所述关系查询特征之间的片段交并比大于预设的交并比阈值;
基于所述关系查询特征生成所述关系特征集合。
5.如权利要求2所述的方法,所述基于所述相似特征集合、所述关系特征集合以及所述查询特征自身,生成所述显著查询特征集合包括:
获取所述相似特征集合关于所述关系特征集合的相对补集;
将所述相对补集与所述查询特征自身的并集,作为所述显著查询特征集合。
6.如权利要求1所述的方法,所述预测片段质量信息包括:预测片段质量得分;所述使用所述跨越注意力模块并基于更新后的查询特征,获取与所述更新后的查询特征相对应的预测片段质量信息包括:
确定与所述更新后的查询特征相对应的预测片段,并获取与所述预测片段相对应的视频片段;
确定所述预测片段的中点与所述视频片段的中点之间的预测距离、所述预测片段与所述视频片段之间的预测交并比;
基于所述预测距离和所述预测交并比,生成所述预测片段质量得分。
7.如权利要求6所述的方法,所述根据所述预测片段质量信息构建片段质量损失函数包括:
确定所述预测片段中点与所述视频片段中点之间的片段距离、所述预测片段与所述视频片段之间的片段交并比;
根据所述预测距离、所述预测交并比与对应的片段距离、片段交并比之间的偏差信息,构建所述片段质量损失函数。
8.如权利要求1所述的方法,所述片段关系特征包括:预测片段交并比;所述获取与所述查询特征相对应的预测视频片段之间的片段关系特征,构建片段关系损失函数包括:
确定与所述更新后的查询特征相对应的预测片段之间的预测片段交并比;
根据所述预测片段交并比的累计信息,构建所述片段关系损失函数。
9.如权利要求1所述的方法,所述使用所述关系注意力模块并基于所述显著查询特征集合,对所述查询特征进行更新处理包括:
使用所述关系注意力模块对所述显著查询特征集合内的特征进行自注意力计算处理,用以对所述查询特征进行更新处理。
10.如权利要求1至9任一项所述的方法,其中,
所述解码器模块包括:基于Transformer结构的解码器。
11.一种目标检测方法,包括:
获取训练好的解码器;其中,所述解码器是通过权利要求1至10中任一项所述的训练方法训练得到;
使用所述解码器并基于查询特征,生成分类置信度、用于表征目标位置的回归信息和预测片段质量得分;
基于所述分类置信度和预测片段质量得分,确定预测得分。
12.一种解码器的训练装置,其中,解码器包括:关系注意力模块和跨越注意力模块;所述训练装置包括:
查询集合获取模块,用于使用所述关系注意力模块并基于查询特征,生成与所述查询特征相对应的显著查询特征集合;
查询特征更新模块,用于使用所述关系注意力模块并基于所述显著查询特征集合,对所述查询特征进行更新处理;
片段质量确定模块,用于使用所述跨越注意力模块并基于更新后的查询特征,获取与所述更新后的查询特征相对应的预测片段质量信息,并根据所述预测片段质量信息构建片段质量损失函数;
预测损失确定模块,用于确定获取与所述查询特征相对应的预测视频片段之间的片段关系特征,构建片段关系损失函数;
模块调整模块,用于根据所述片段质量损失函数和所述片段关系损失函数,对所述关系注意力模块和所述跨越注意力模块进行调整处理。
13.一种解码器的训练装置,包括:
存储器;以及耦接至所述存储器的处理器,所述处理器被配置为基于存储在所述存储器中的指令,执行如权利要求1至10中任一项所述的方法。
14.一种目标检测装置,包括:
模型获取模块,用于获取训练好的解码器;其中,所述解码器是通过权利要求1至10中任一项所述的训练方法训练得到;
检测处理模块,用于使用所述解码器并基于查询特征,生成分类置信度、用于表征目标位置的回归信息和预测片段质量得分。
预测得分模块,用于基于所述分类置信度和预测片段质量得分,确定预测得分。
15.一种目标检测装置,包括:
存储器;以及耦接至所述存储器的处理器,所述处理器被配置为基于存储在所述存储器中的指令,执行如权利要求11所述的方法。
16.一种计算机可读存储介质,所述计算机可读存储介质非暂时性地存储有计算机指令,所述指令被处理器执行如权利要求1至11中任一项所述的方法。
Priority Applications (2)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210788886.5A CN115063666A (zh) | 2022-07-06 | 2022-07-06 | 解码器的训练方法、目标检测方法、装置以及存储介质 |
PCT/CN2023/081879 WO2024007619A1 (zh) | 2022-07-06 | 2023-03-16 | 解码器的训练方法、目标检测方法、装置以及存储介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210788886.5A CN115063666A (zh) | 2022-07-06 | 2022-07-06 | 解码器的训练方法、目标检测方法、装置以及存储介质 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN115063666A true CN115063666A (zh) | 2022-09-16 |
Family
ID=83203954
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202210788886.5A Pending CN115063666A (zh) | 2022-07-06 | 2022-07-06 | 解码器的训练方法、目标检测方法、装置以及存储介质 |
Country Status (2)
Country | Link |
---|---|
CN (1) | CN115063666A (zh) |
WO (1) | WO2024007619A1 (zh) |
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116128158A (zh) * | 2023-04-04 | 2023-05-16 | 西南石油大学 | 混合采样注意力机制的油井效率预测方法 |
WO2024007619A1 (zh) * | 2022-07-06 | 2024-01-11 | 京东科技信息技术有限公司 | 解码器的训练方法、目标检测方法、装置以及存储介质 |
Family Cites Families (7)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN110929869B (zh) * | 2019-12-05 | 2021-09-07 | 同盾控股有限公司 | 序列数据处理方法、装置、设备及存储介质 |
US20210279576A1 (en) * | 2020-03-03 | 2021-09-09 | Google Llc | Attention neural networks with talking heads attention |
CN111639153A (zh) * | 2020-04-24 | 2020-09-08 | 平安国际智慧城市科技股份有限公司 | 基于法律知识图谱的查询方法、装置、电子设备及介质 |
CN113902926B (zh) * | 2021-12-06 | 2022-05-31 | 之江实验室 | 一种基于自注意力机制的通用图像目标检测方法和装置 |
CN114186568B (zh) * | 2021-12-16 | 2022-08-02 | 北京邮电大学 | 一种基于关系编码和层次注意力机制的图像段落描述方法 |
CN114612716A (zh) * | 2022-03-08 | 2022-06-10 | 南京大学 | 一种基于自适应解码器的目标检测方法及装置 |
CN115063666A (zh) * | 2022-07-06 | 2022-09-16 | 京东科技信息技术有限公司 | 解码器的训练方法、目标检测方法、装置以及存储介质 |
-
2022
- 2022-07-06 CN CN202210788886.5A patent/CN115063666A/zh active Pending
-
2023
- 2023-03-16 WO PCT/CN2023/081879 patent/WO2024007619A1/zh unknown
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
WO2024007619A1 (zh) * | 2022-07-06 | 2024-01-11 | 京东科技信息技术有限公司 | 解码器的训练方法、目标检测方法、装置以及存储介质 |
CN116128158A (zh) * | 2023-04-04 | 2023-05-16 | 西南石油大学 | 混合采样注意力机制的油井效率预测方法 |
Also Published As
Publication number | Publication date |
---|---|
WO2024007619A1 (zh) | 2024-01-11 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
US11783199B2 (en) | Image description information generation method and apparatus, and electronic device | |
WO2019228317A1 (zh) | 人脸识别方法、装置及计算机可读介质 | |
CN115063666A (zh) | 解码器的训练方法、目标检测方法、装置以及存储介质 | |
CN110705406B (zh) | 基于对抗迁移学习的人脸美丽预测方法及装置 | |
CN110120064B (zh) | 一种基于互强化与多注意机制学习的深度相关目标跟踪算法 | |
CN111340105A (zh) | 一种图像分类模型训练方法、图像分类方法、装置及计算设备 | |
CN112597883A (zh) | 一种基于广义图卷积和强化学习的人体骨架动作识别方法 | |
CN110599443A (zh) | 一种使用双向长短期记忆网络的视觉显著性检测方法 | |
CN113111968A (zh) | 图像识别模型训练方法、装置、电子设备和可读存储介质 | |
CN112307883A (zh) | 训练方法、装置、电子设备以及计算机可读存储介质 | |
CN115051929B (zh) | 基于自监督目标感知神经网络的网络故障预测方法及装置 | |
CN113569758B (zh) | 基于动作三元组引导的时序动作定位方法、系统、设备及介质 | |
CN115546468A (zh) | 一种基于transformer的细长类物体目标检测方法 | |
US20230252271A1 (en) | Electronic device and method for processing data based on reversible generative networks, associated electronic detection system and associated computer program | |
CN114358250A (zh) | 数据处理方法、装置、计算机设备、介质及程序产品 | |
Li et al. | Active temporal action detection in untrimmed videos via deep reinforcement learning | |
TWI781000B (zh) | 機器學習裝置以及方法 | |
KR102432854B1 (ko) | 잠재 벡터를 이용하여 군집화를 수행하는 방법 및 장치 | |
JP7310927B2 (ja) | 物体追跡装置、物体追跡方法及び記録媒体 | |
CN113822291A (zh) | 一种图像处理方法、装置、设备及存储介质 | |
CN114565791A (zh) | 一种人物档案识别方法、装置、设备及介质 | |
Wang et al. | Efficient Crowd Counting via Dual Knowledge Distillation | |
Passalis et al. | Adaptive inference for face recognition leveraging deep metric learning-enabled early exits | |
CN118097510A (zh) | 一种时序动作定位方法、装置、计算机设备及存储介质 | |
US20230297823A1 (en) | Method and system for training a neural network for improving adversarial robustness |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |