CN112801029B - 基于注意力机制的多任务学习方法 - Google Patents

基于注意力机制的多任务学习方法 Download PDF

Info

Publication number
CN112801029B
CN112801029B CN202110182158.5A CN202110182158A CN112801029B CN 112801029 B CN112801029 B CN 112801029B CN 202110182158 A CN202110182158 A CN 202110182158A CN 112801029 B CN112801029 B CN 112801029B
Authority
CN
China
Prior art keywords
frame
network
task
convolution
attention
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN202110182158.5A
Other languages
English (en)
Other versions
CN112801029A (zh
Inventor
邢德旺
刘兆英
张婷
李玉鑑
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Beijing University of Technology
Original Assignee
Beijing University of Technology
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Beijing University of Technology filed Critical Beijing University of Technology
Priority to CN202110182158.5A priority Critical patent/CN112801029B/zh
Publication of CN112801029A publication Critical patent/CN112801029A/zh
Application granted granted Critical
Publication of CN112801029B publication Critical patent/CN112801029B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V20/00Scenes; Scene-specific elements
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/22Matching criteria, e.g. proximity measures
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/20Image preprocessing
    • G06V10/26Segmentation of patterns in the image field; Cutting or merging of image elements to establish the pattern region, e.g. clustering-based techniques; Detection of occlusion
    • G06V10/267Segmentation of patterns in the image field; Cutting or merging of image elements to establish the pattern region, e.g. clustering-based techniques; Detection of occlusion by performing operations on regions, e.g. growing, shrinking or watersheds
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V2201/00Indexing scheme relating to image or video recognition or understanding
    • G06V2201/07Target detection

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • General Physics & Mathematics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • General Engineering & Computer Science (AREA)
  • Evolutionary Computation (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Evolutionary Biology (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Multimedia (AREA)
  • Biophysics (AREA)
  • Biomedical Technology (AREA)
  • Health & Medical Sciences (AREA)
  • Computational Linguistics (AREA)
  • General Health & Medical Sciences (AREA)
  • Molecular Biology (AREA)
  • Computing Systems (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Image Analysis (AREA)

Abstract

本发明公开了基于注意力机制的多任务学习方法,具体步骤包括:(1)使用全局共享特征池提取图像特征;(2)使用注意力机制提取特定任务的特征;(3)将注意力机制提取的特征进行解码,使其适应于该任务;(4)对模型进行训练;(5)利用训练模型生成多任务学习结果。本发明方法大多为在共享特征池后进行分流,这样会使得共享特征利用不充分,无法利用低层次特征,最终使得效果偏差。本发明利用vgg16特征提取网络作为共享特征池,并在共享特征池中多次利用注意力机制提取低、中、高层次特征,充分利用了特征池中的各个特征。大大解决了特征利用不充分的问题,为机器视觉的发展奠定了基础。

Description

基于注意力机制的多任务学习方法
技术领域
本发明属于多任务学习和计算机视觉领域,涉及图像识别、语义分割、目标检测等任务,尤其涉及一种基于注意力机制的多任务学习方法。
背景技术
近年来,卷积神经网络在许多计算机视觉方面的任务上取得了巨大的成功,包括图像分类、语义分割、风格转换等。例如在2012年,Alex等人提出的AlexNet网络在ImageNet大赛上以远超第二名的成绩夺冠,在2014年的ILSVRC比赛中,VGG在Top-5中取得了92.3%的正确率。同年的冠军是googlenet。与此同时,目标检测和语义分割领域的算法层出不穷,推动着图像识别技术快速发展。FCN和SSD网络在语音分割和目标检测任务上分别取得了令人满意的成绩。然而,这些网络是典型的单任务网络,只能实现特定的任务。对于在真实场景下应用的大多数计算机视觉系统,按照传统的方法是针对每个任务建立适合该任务的网络,每个网络有不同的输入和输出,没有相互影响。这就带来了很多问题,首先参数量是成倍增长,虽说可以对每个网络进行优化,但是参数量还是很大。其次如果同时进行多个任务的运行需要同时运行多个网络,这对内存的消耗非常大。最后从数据集的角度说,每个任务需要不同的数据集,由于任务间没有共享,训练一组网络往往需要更大的数据集。因此,建立可以同时执行多个任务的网络比建立一组独立的网络要更可取。这不仅对内存和运算速度来说是一个有效的提升,在能利用更少的数据得到更好的准确率,因为相关的任务可能共享更有用的视觉特征。多任务学习用于同时学习多个相关任务,通过联合学习,它既保持了任务间的差异性又充分利用其相关性,从而从整体上提高所有任务的学习性能。
因此,本发明针对舰船目标的识别问题,研究多任务深度卷积神经网络的目标识别方法。通过多任务学习和共享卷积特征以提高舰船目标的识别性能。在此基础上引入注意力机制,有效地在参数共享的基础上筛选出对特定任务更为有效的特征,同时抑制对该任务没有帮助的特征,最终实现舰船目标的识别。本文的成果可以为舰船目标的识别问题提供重要的技术参考,具有重要的军事意义和应用价值。
发明内容
1、基于注意力机制的多任务学习方法,其特征在于:该方法包括如下步骤,
步骤1:构建舰船数据集。舰船数据集来源于CNSS海事服务网,称为数据集D,舰船数据集舰船图像为Im,类别标签为xm,语义标签为ym,目标检测框为zm;分割标签ym为使用Labelme进行精准标注的灰度图像,目标检测框zm为使用LabelImg进行精准标注的xml文件;数据集D共包含M类舰船的N幅图像,将数据集D划分为训练集Dt和测试集Ds;训练集Dt包含Nt幅图像,测试集Ds包含Ns幅图像;m为图像的序号数;
步骤2:构建主干网络。选择VGG-16的前13层作为主干网络,该主干网络由5个卷积块组成;前两个卷积块中各包含两个卷积层,后三个卷积块中各包含三个卷积层,卷积层表示为Ci-j,其中i表示当前是第几个卷积块,j表示当前卷积层是该卷积块中的第几个卷积层;每个卷积块后有一个池化层;主干网络的输入为彩色的舰船图像Ii∈Rh×w×3(1<i<Nt),其中h和w分别表示图像的高度和宽度,3表示图像的通道个数;主干网络不做任何输出,由各个任务对应的注意力机制来做输出。
步骤3:在主干网络中添加注意力机制。主干网络中有5个卷积块,因此本发明中每个任务的注意力机制包含5个注意力掩膜。注意力掩膜实现将全局特征池中的前后特征进行连接和融合。
步骤4:建立损失函数。为每个任务建立损失函数:交叉熵是建立在熵的基础上表示两种概率分布之间的差异的一种度量方法。交叉熵损失函数常用于分类任务中,尤其是在神经网络分类问题中使用更为普遍。分类任务的损失函数Lcla为网络实际输出概率Pc与标签值yc的交叉熵,表示为:
其中K为类别数量;Pc为网络预测的该图像属于类别c的概率。yc是ont-hot格式的标签,也就是如果类别是c,则yc=1,否则等于0;语义分割任务同分类任务类似,其损失函数Lseg为实际输出特征图Oseg与真实标签值yseg的逐像素交叉熵损失,其中Oseg和yseg为h×w的二维向量,表示为:
其中1<p<h,1<q<w;yseg中的所有值为0或1。
在计算目标检测的损失函数时,首先需要计算网络预测层输出的每个点的预测框和真实框的交并比,若交并比大于设定的阈值,就可以认为这个预测框与真实框标记的类别相同,认为这是一个正例,否则就认为这个框是负例,指向背景。所有正例组成的集合叫Pos,所有负例组成的集合为Neg。因此目标检测损失函数由两部分组成,一方面来自于预测框与真实框位置的损失Lloc,另一方面来自于该框预测的类别置信度的损失Lconf,总的损失为两个损失加权和,表示为:
其中N是匹配的先验框的数量(就是正负样本的数量之和),位置损失Lloc是预测框(l)和真实标签值框(g)参数之间的smoothL1损失
其中中a是预选框序号,b是真实框序号,p是类别序号,当p=0时为背景,则/>表示第a个预测框与第b个真实框关于类别k是否匹配,若匹配则该值为1,否则为0。
smoothL1是做光滑处理之后的L1范数损失函数,其计算公式如下:
表示预测框相对于第b个真实框在m方向上的偏移量,其中m∈{cx,cy,w,h},计算公式分别如下:
其中ga代表真实框,即图片中目标的位置,分别为cx、cy、h、w,da表示预测框。
由分类任务中的交叉熵公式可得置信度计算中的交叉熵损失函数:
其中p(xc)为真实框属于第c类的概率,q(xc)表示预测框属于第c类的概
率,计算置信损失Lconf由下面公式计算得出:
其中,表示第a个预测框对应类别c的预测概率,/>表示第a个预测框属于背景的概率。
计算完三个任务的损失之后,网络的总损失为三个损失函数的加和:
Lmtl=Lcla+Lseg+Lobj
步骤5:网络训练。设置迭代次数、学习率超参数,将训练集Dt输入网络,使用随机梯度下降算法对网络参数进行迭代更新,直到损失收敛,保存最终的模型。
步骤6:网络测试。加载保存的模型,利用测试集Ds测试各个任务,输入单幅图像Ii∈Rh×w×3(1<i<Nt),获得3个输出out1、out2、out3;out1为维度为K的向量,其中K为类别数;out2为维度为(h,w)的向量;out3为维度为(K,nm,5)的向量,其中nm表示输出前n×m个置信度最大的框,5表示[cx,cy,h,w,conf],其中conf为这个预测框是该类别的置信度。根据以上三个输出,对out1取最大值索引即可获得预测类别,使用opencv将out2显示为灰度图即可获得预测的分割图,对out3中取出所有类别中最大置信度的框合并到输入图像中即可获得目标检测检测框和类别。
2、根据权利要求1所述的基于注意力机制的多任务学习方法,其特征在于:步骤3中,包括如下步骤,步骤3.1:每个卷积块中的第一个和最后一个卷积层的输出将做为该任务注意力掩膜的输入,并且该注意力掩膜将这两个输入进行连接。任务1注意力机制A1的第一个注意力掩膜接收主干网络第一个卷积块的两个输出out1-1和out1-2,假设该掩膜中的两个函数为f1-1和f1-2,f1-1中包含的层顺序为卷积层、batchnorm层、relu激活层、卷积层、batchnorm层、sigmoid层,即:
f1-1(x)=sigmoid(bn(conv(relu(bn(conv(x))))))
其中conv表示卷积计算,bn表示batchnorm计算,sigmoid和relu分别表示不同的激活函数;f1-2将f1-1的结果和out1-2进行逐元素相乘后进行卷积和池化运算,即
f1-2(x)=maxpool(conv(f1-1(x)*out1-2)))
其中maxpool表示最大池化计算,*表示逐元素相乘。
步骤3.2:下一个注意力掩膜通过结合主干网络中的部分输出和上一个掩膜的输出来达到特征融合的目的。结合下一个掩膜输入结果
Rcat(x)=concat(f1-2(x),out2-1,dim=1)
其中concat为拼接函数,接收一个dim参数来指定拼接维度;则第二个掩膜的输入为:Rcat(x)、out2-2,此时计算方法同第一个掩膜。后续的3个掩膜同第二个掩膜方法类似,且后续的3个掩膜的输入分别为:Rcat(x)和out3-3、Rcat(x)和out4-3、Rcat(x)和out5-3。其中Rcat(x)为上一个掩膜的输出;
步骤3.3:由于实现目标检测、语义分割、分类3个任务,在注意力机制结构设计上构建3个上述注意力机制。假设注意力机制A1对应分类任务,将A1的输出A1-out输入到分类器classifier中,分类器为全连接层结构,将所有神经元全连接到c个神经元中产生每个类别的概率;注意力机制A2对应语义分割任务,将A2的输出A2-out输入到分割解码网络seg中,解码网络主要使用上采样方法,将特征图还原成原始输入大小,在应用softmax就能产生和原始图像一样大小的二维概率矩阵;注意力机制A3对应目标检测任务,将A3的输出A3-out输入到检测网络obj中,利用辅助卷积层分别提取不同大小的目标特征进行检测和识别。
附图说明
图1为原始图像。
图2为骨干网络结构。
图3为网络整体结构。
图4为测试结果图。
具体实施方式
1、基于注意力机制的多任务学习方法,其特征在于:该方法包括如下步骤,
步骤1:构建舰船数据集。舰船数据集来源于CNSS海事服务网,称为数据集D,舰船数据集舰船图像为Im,类别标签为xm,语义标签为ym,目标检测框为zm;分割标签ym为使用Labelme进行精准标注的灰度图像,目标检测框zm为使用LabelImg进行精准标注的xml文件;数据集D共包含M类舰船的N幅图像,将数据集D划分为训练集Dt和测试集Ds;训练集Dt包含Nt幅图像,测试集Ds包含Ns幅图像;m为图像的序号数;
步骤2:构建主干网络。选择VGG-16的前13层作为主干网络,该主干网络由5个卷积块组成;前两个卷积块中各包含两个卷积层,后三个卷积块中各包含三个卷积层,卷积层表示为Ci-j,其中i表示当前是第几个卷积块,j表示当前卷积层是该卷积块中的第几个卷积层;每个卷积块后有一个池化层;主干网络的输入为彩色的舰船图像Ii∈Rh×w×3(1<i<Nt),其中h和w分别表示图像的高度和宽度,3表示图像的通道个数;主干网络不做任何输出,由各个任务对应的注意力机制来做输出。
步骤3:在主干网络中添加注意力机制。主干网络中有5个卷积块,因此本发明中每个任务的注意力机制包含5个注意力掩膜。注意力掩膜实现将全局特征池中的前后特征进行连接和融合。
步骤4:建立损失函数。为每个任务建立损失函数:交叉熵是建立在熵的基础上表示两种概率分布之间的差异的一种度量方法。交叉熵损失函数常用于分类任务中,尤其是在神经网络分类问题中使用更为普遍。分类任务的损失函数Lcla为网络实际输出概率Pc与标签值yc的交叉熵,表示为:
其中K为类别数量;Pc为网络预测的该图像属于类别c的概率。yc是ont-hot格式的标签,也就是如果类别是c,则yc=1,否则等于0;语义分割任务同分类任务类似,其损失函数Lseg为实际输出特征图Oseg与真实标签值yseg的逐像素交叉熵损失,其中Oseg和yseg为h×w的二维向量,表示为:
其中1<p<h,1<q<w;yseg中的所有值为0或1。
在计算目标检测的损失函数时,首先需要计算网络预测层输出的每个点的预测框和真实框的交并比,若交并比大于设定的阈值,就可以认为这个预测框与真实框标记的类别相同,认为这是一个正例,否则就认为这个框是负例,指向背景。所有正例组成的集合叫Pos,所有负例组成的集合为Neg。因此目标检测损失函数由两部分组成,一方面来自于预测框与真实框位置的损失Lloc,另一方面来自于该框预测的类别置信度的损失Lconf,总的损失为两个损失加权和,表示为:
其中N是匹配的先验框的数量(就是正负样本的数量之和),位置损失Lloc是预测框(l)和真实标签值框(g)参数之间的smoothL1损失
其中中a是预选框序号,b是真实框序号,p是类别序号,当p=0时为背景,则/>表示第a个预测框与第b个真实框关于类别k是否匹配,若匹配则该值为1,否则为0。
smoothL1是做光滑处理之后的L1范数损失函数,其计算公式如下:
表示预测框相对于第b个真实框在m方向上的偏移量,其中m∈{cx,cy,w,h},计算公式分别如下:
其中ga代表真实框,即图片中目标的位置,分别为cx、cy、h、w,da表示预测框。
由分类任务中的交叉熵公式可得置信度计算中的交叉熵损失函数:
其中p(xc)为真实框属于第c类的概率,q(xc)表示预测框属于第c类的概
率,计算置信损失Lconf由下面公式计算得出:
其中,表示第a个预测框对应类别c的预测概率,/>表示第a个预测框属于背景的概率。
计算完三个任务的损失之后,网络的总损失为三个损失函数的加和:
Lmtl=Lcla+Lseg+Lobj
步骤5:网络训练。设置迭代次数、学习率超参数,将训练集Dt输入网络,使用随机梯度下降算法对网络参数进行迭代更新,直到损失收敛,保存最终的模型。
步骤6:网络测试。加载保存的模型,利用测试集Ds测试各个任务,输入单幅图像Ii∈Rh×w×3(1<i<Nt),获得3个输出out1、out2、out3;out1为维度为K的向量,其中K为类别数;out2为维度为(h,w)的向量;out3为维度为(K,nm,5)的向量,其中nm表示输出前n×m个置信度最大的框,5表示[cx,cy,h,w,conf],其中conf为这个预测框是该类别的置信度。根据以上三个输出,对out1取最大值索引即可获得预测类别,使用opencv将out2显示为灰度图即可获得预测的分割图,如附图4(a)所示。对out3中取出所有类别中最大置信度的框合并到输入图像中即可获得目标检测检测框和类别,如附图4(b)所示。
2、根据权利要求1所述的基于注意力机制的多任务学习方法,其特征在于:步骤3中,包括如下步骤,步骤3.1:每个卷积块中的第一个和最后一个卷积层的输出将做为该任务注意力掩膜的输入,并且该注意力掩膜将这两个输入进行连接。任务1注意力机制A1的第一个注意力掩膜接收主干网络第一个卷积块的两个输出out1-1和out1-2,假设该掩膜中的两个函数为f1-1和f1-2,f1-1中包含的层顺序为卷积层、batchnorm层、relu激活层、卷积层、batchnorm层、sigmoid层,即:
f1-1(x)=sigmoid(bn(conv(relu(bn(conv(x))))))
其中conv表示卷积计算,bn表示batchnorm计算,sigmoid和relu分别表示不同的激活函数;f1-2将f1-1的结果和out1-2进行逐元素相乘后进行卷积和池化运算,即
f1-2(x)=maxpool(conv(f1-1(x)*out1-2)))
其中maxpool表示最大池化计算,*表示逐元素相乘。
步骤3.2:下一个注意力掩膜通过结合主干网络中的部分输出和上一个掩膜的输出来达到特征融合的目的。结合下一个掩膜输入结果
Rcat(x)=concat(f1-2(x),out2-1,dim=1)
其中concat为拼接函数,接收一个dim参数来指定拼接维度;则第二个掩膜的输入为:Rcat(x)、out2-2,此时计算方法同第一个掩膜。后续的3个掩膜同第二个掩膜方法类似,且后续的3个掩膜的输入分别为:Rcat(x)和out3-3、Rcat(x)和out4-3、Rcat(x)和out5-3。其中Rcat(x)为上一个掩膜的输出;
步骤3.3:由于实现目标检测、语义分割、分类3个任务,在注意力机制结构设计上构建3个上述注意力机制。假设注意力机制A1对应分类任务,将A1的输出A1-out输入到分类器classifier中,分类器为全连接层结构,将所有神经元全连接到c个神经元中产生每个类别的概率;注意力机制A2对应语义分割任务,将A2的输出A2-out输入到分割解码网络seg中,解码网络主要使用上采样方法,将特征图还原成原始输入大小,在应用softmax就能产生和原始图像一样大小的二维概率矩阵;注意力机制A3对应目标检测任务,将A3的输出A3-out输入到检测网络obj中,利用辅助卷积层分别提取不同大小的目标特征进行检测和识别。
以上实例仅用于描述本发明,而非限制本发明所描述的技术方案。因此,一切不脱离本发明精神和范围的技术方案及其改进,均应涵盖在本发明的权利要求范围中。

Claims (2)

1.基于注意力机制的多任务学习方法,其特征在于:该方法包括如下步骤,
步骤1:构建舰船数据集;舰船数据集来源于CNSS海事服务网,称为数据集D,舰船数据集舰船图像为Im,类别标签为xm,语义标签为ym,目标检测框为zm;分割标签ym为使用Labelme进行精准标注的灰度图像,目标检测框zm为使用LabelImg进行精准标注的xml文件;数据集D共包含M类舰船的N幅图像,将数据集D划分为训练集Dt和测试集Ds;训练集Dt包含Nt幅图像,测试集Ds包含Ns幅图像;m为图像的序号数;
步骤2:构建主干网络;选择VGG-16的前13层作为主干网络,该主干网络由5个卷积块组成;前两个卷积块中各包含两个卷积层,后三个卷积块中各包含三个卷积层,卷积层表示为Ci-j,其中i表示当前是第几个卷积块,j表示当前卷积层是该卷积块中的第几个卷积层;每个卷积块后有一个池化层;主干网络的输入为彩色的舰船图像Ii∈Rh×w×3,1<i<Nt,其中h和w分别表示图像的高度和宽度,3表示图像的通道个数;主干网络不做任何输出,由各个任务对应的注意力机制来做输出;
步骤3:在主干网络中添加注意力机制;主干网络中有5个卷积块,因此本发明中每个任务的注意力机制包含5个注意力掩膜;注意力掩膜实现将全局特征池中的前后特征进行连接和融合;
步骤4:建立损失函数;为每个任务建立损失函数:交叉熵是建立在熵的基础上表示两种概率分布之间的差异的一种度量方法;分类任务的损失函数Lcla为网络实际输出概率Pc与标签值yc的交叉熵,表示为:
其中K为类别数量;Pc为网络预测的图像属于类别c的概率;yc是ont-hot格式的标签,也就是如果类别是c,则yc=1,否则等于0;语义分割任务同分类任务类似,损失函数Lseg为实际输出特征图Oseg与真实标签值yseg的逐像素交叉熵损失,其中Oseg和yseg为h×w的二维向量,表示为:
其中1<p<h,1<q<w;yseg中的所有值为0或1;
在计算目标检测的损失函数时,首先需要计算网络预测层输出的每个点的预测框和真实框的交并比,若交并比大于设定的阈值,认为这个预测框与真实框标记的类别相同,认为这是一个正例,否则就认为这个框是负例,指向背景;所有正例组成的集合叫Pos,所有负例组成的集合为Neg;因此目标检测损失函数由两部分组成,一方面来自于预测框与真实框位置的损失Lloc,另一方面来自于该框预测的类别置信度的损失Leonf,总的损失为两个损失加权和,表示为:
其中N是匹配的先验框的数量,位置损失Lloc是预测框l和真实标签值框g参数之间的smoothL1损失
其中中a是预选框序号,b是真实框序号,p是类别序号,当p=0时为背景,则/>表示第a个预测框与第b个真实框关于类别k是否匹配,若匹配则该值为1,否则为0;
smoothL1是做光滑处理之后的L1范数损失函数,其计算公式如下:
表示预测框相对于第b个真实框在m方向上的偏移量,其中m∈{cx,cy,w,h},计算公式分别如下:
其中ga代表真实框,即图片中目标的位置,分别为cx、cy、h、w,da表示预测框;
由分类任务中的交叉熵公式可得置信度计算中的交叉熵损失函数:
其中p(xc)为真实框属于第c类的概率,q(xc)表示预测框属于第c类的概率,计算置信损失Lconf由下面公式计算得出:
其中,表示第a个预测框对应类别c的预测概率,/>表示第a个预测框属于背景的概率;
计算完三个任务的损失之后,网络的总损失为三个损失函数的加和:
Lmtl=Lcla+Lseg+Lobj
步骤5:网络训练;设置迭代次数、学习率超参数,将训练集Dt输入网络,使用随机梯度下降算法对网络参数进行迭代更新,直到损失收敛,保存最终的模型;
步骤6:网络测试;加载保存的模型,利用测试集Ds测试各个任务,输入单幅图像Ii∈Rh ×w×3,1<i<Nt,获得3个输出out1、out2、out3;out1为维度为K的向量,其中K为类别数;out2为维度为(h,w)的向量;out3为维度为(K,nm,5)的向量,其中nm表示输出前n×m个置信度最大的框,5表示[cx,cy,h,w,conf],其中conf为这个预测框是该类别的置信度;根据以上三个输出,对outl取最大值索引即可获得预测类别,使用opencv将out2显示为灰度图即可获得预测的分割图,对out3中取出所有类别中最大置信度的框合并到输入图像中即可获得目标检测检测框和类别。
2.根据权利要求1所述的基于注意力机制的多任务学习方法,其特征在于:步骤3中,包括如下步骤,步骤3.1:每个卷积块中的第一个和最后一个卷积层的输出将做为该任务注意力掩膜的输入,并且该注意力掩膜将这两个输入进行连接;任务1注意力机制A1的第一个注意力掩膜接收主干网络第一个卷积块的两个输出out1-1和out1-2,假设该掩膜中的两个函数为f1-1和f1-2,f1-1中包含的层顺序为卷积层、batchnorm层、relu激活层、卷积层、batchnorm层、sigmoid层,即:
f1-1(x)=sigmoid(bn(conv(relu(bn(conv(x))))))
其中conv表示卷积计算,bn表示batchnorm计算,sigmoid和relu分别表示不同的激活函数;f1-2将f1-1的结果和out1-2进行逐元素相乘后进行卷积和池化运算,即
f1-2(x)=maxpool(conv(f1-1(x)*out1-2)))
其中maxpool表示最大池化计算,*表示逐元素相乘;
步骤3.2:下一个注意力掩膜通过结合主干网络中的部分输出和上一个掩膜的输出来达到特征融合的目的;结合下一个掩膜输入结果
Rcat(x)=concat(f1-2(x),out2-1,dim=1)
其中concat为拼接函数,接收一个dim参数来指定拼接维度;则第二个掩膜的输入为:Rcat(x)、out2-2,此时计算方法同第一个掩膜;后续的3个掩膜同第二个掩膜方法类似,且后续的3个掩膜的输入分别为:Rcat(x)和out3-3、Rcat(x)和out4-3、Rcat(x)和out5-3;其中Rcat(x)为上一个掩膜的输出;
步骤3.3:由于实现目标检测、语义分割、分类3个任务,在注意力机制结构设计上构建3个上述注意力机制;假设注意力机制A1对应分类任务,将A1的输出A1-out输入到分类器classifier中,分类器为全连接层结构,将所有神经元全连接到c个神经元中产生每个类别的概率;注意力机制A2对应语义分割任务,将A2的输出A2-out输入到分割解码网络seg中,解码网络主要使用上采样方法,将特征图还原成原始输入大小,在应用softmax就能产生和原始图像一样大小的二维概率矩阵;注意力机制A3对应目标检测任务,将A3的输出A3-out输入到检测网络obj中,利用辅助卷积层分别提取不同大小的目标特征进行检测和识别。
CN202110182158.5A 2021-02-09 2021-02-09 基于注意力机制的多任务学习方法 Active CN112801029B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202110182158.5A CN112801029B (zh) 2021-02-09 2021-02-09 基于注意力机制的多任务学习方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202110182158.5A CN112801029B (zh) 2021-02-09 2021-02-09 基于注意力机制的多任务学习方法

Publications (2)

Publication Number Publication Date
CN112801029A CN112801029A (zh) 2021-05-14
CN112801029B true CN112801029B (zh) 2024-05-28

Family

ID=75815038

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202110182158.5A Active CN112801029B (zh) 2021-02-09 2021-02-09 基于注意力机制的多任务学习方法

Country Status (1)

Country Link
CN (1) CN112801029B (zh)

Families Citing this family (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113392724B (zh) * 2021-05-25 2022-12-27 中国科学院西安光学精密机械研究所 基于多任务学习的遥感场景分类方法
CN113554156B (zh) * 2021-09-22 2022-01-11 中国海洋大学 基于注意力机制与可变形卷积的多任务图像处理方法

Citations (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN111209975A (zh) * 2020-01-13 2020-05-29 北京工业大学 一种基于多任务学习的舰船目标识别方法
CN111275688A (zh) * 2020-01-19 2020-06-12 合肥工业大学 基于注意力机制的上下文特征融合筛选的小目标检测方法
CN111340096A (zh) * 2020-02-24 2020-06-26 北京工业大学 一种基于对抗互补学习的弱监督蝴蝶目标检测方法
CN111353505A (zh) * 2020-05-25 2020-06-30 南京邮电大学 可联合实现语义分割和景深估计的网络模型及训练方法
CN111539469A (zh) * 2020-04-20 2020-08-14 东南大学 一种基于视觉自注意力机制的弱监督细粒度图像识别方法

Family Cites Families (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US10956817B2 (en) * 2018-04-18 2021-03-23 Element Ai Inc. Unsupervised domain adaptation with similarity learning for images

Patent Citations (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN111209975A (zh) * 2020-01-13 2020-05-29 北京工业大学 一种基于多任务学习的舰船目标识别方法
CN111275688A (zh) * 2020-01-19 2020-06-12 合肥工业大学 基于注意力机制的上下文特征融合筛选的小目标检测方法
CN111340096A (zh) * 2020-02-24 2020-06-26 北京工业大学 一种基于对抗互补学习的弱监督蝴蝶目标检测方法
CN111539469A (zh) * 2020-04-20 2020-08-14 东南大学 一种基于视觉自注意力机制的弱监督细粒度图像识别方法
CN111353505A (zh) * 2020-05-25 2020-06-30 南京邮电大学 可联合实现语义分割和景深估计的网络模型及训练方法

Non-Patent Citations (2)

* Cited by examiner, † Cited by third party
Title
基于双注意力机制的遥感图像目标检测;周幸;陈立福;;计算机与现代化;20200815(第08期);全文 *
基于注意力LSTM和多任务学习的远场语音识别;张宇;张鹏远;颜永红;;清华大学学报(自然科学版);20180315(第03期);全文 *

Also Published As

Publication number Publication date
CN112801029A (zh) 2021-05-14

Similar Documents

Publication Publication Date Title
CN111639692B (zh) 一种基于注意力机制的阴影检测方法
Turhan et al. Recent trends in deep generative models: a review
CN112749626B (zh) 一种面向dsp平台的快速人脸检测与识别方法
Jiang et al. Cascaded subpatch networks for effective CNNs
CN112906720A (zh) 基于图注意力网络的多标签图像识别方法
CN112801029B (zh) 基于注意力机制的多任务学习方法
CN112288011A (zh) 一种基于自注意力深度神经网络的图像匹配方法
CN111310766A (zh) 基于编解码和二维注意力机制的车牌识别方法
Amirian et al. Dissection of deep learning with applications in image recognition
CN114283352A (zh) 一种视频语义分割装置、训练方法以及视频语义分割方法
CN110598746A (zh) 一种基于ode求解器自适应的场景分类方法
CN115482518A (zh) 一种面向交通场景的可扩展多任务视觉感知方法
CN116596966A (zh) 一种基于注意力和特征融合的分割与跟踪方法
CN111899203A (zh) 基于标注图在无监督训练下的真实图像生成方法及存储介质
CN114780767A (zh) 一种基于深度卷积神经网络的大规模图像检索方法及系统
CN114550014A (zh) 道路分割方法及计算机装置
CN111967408B (zh) 基于“预测-恢复-识别”的低分辨率行人重识别方法及系统
CN110942463B (zh) 一种基于生成对抗网络的视频目标分割方法
CN117011515A (zh) 基于注意力机制的交互式图像分割模型及其分割方法
CN115424275B (zh) 一种基于深度学习技术的渔船船牌号识别方法及系统
CN116665114A (zh) 基于多模态的遥感场景识别方法、系统及介质
Malekijoo et al. Convolution-deconvolution architecture with the pyramid pooling module for semantic segmentation
CN113688946B (zh) 基于空间关联的多标签图像识别方法
CN115294353A (zh) 基于多层属性引导的人群场景图像字幕描述方法
CN111639563B (zh) 一种基于多任务的篮球视频事件与目标在线检测方法

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant