CN113887610A - 基于交叉注意力蒸馏Transformer的花粉图像分类方法 - Google Patents

基于交叉注意力蒸馏Transformer的花粉图像分类方法 Download PDF

Info

Publication number
CN113887610A
CN113887610A CN202111147668.5A CN202111147668A CN113887610A CN 113887610 A CN113887610 A CN 113887610A CN 202111147668 A CN202111147668 A CN 202111147668A CN 113887610 A CN113887610 A CN 113887610A
Authority
CN
China
Prior art keywords
token
distillation
attention
network
module
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Granted
Application number
CN202111147668.5A
Other languages
English (en)
Other versions
CN113887610B (zh
Inventor
石宝
段凯博
杨传颖
马少瑛
黄林
李林
张心月
田宇
周昊
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Inner Mongolia University of Technology
Original Assignee
Inner Mongolia University of Technology
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Inner Mongolia University of Technology filed Critical Inner Mongolia University of Technology
Priority to CN202111147668.5A priority Critical patent/CN113887610B/zh
Publication of CN113887610A publication Critical patent/CN113887610A/zh
Application granted granted Critical
Publication of CN113887610B publication Critical patent/CN113887610B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • G06F18/241Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
    • G06F18/2415Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches based on parametric or probabilistic models, e.g. based on likelihood ratio or false acceptance rate versus a false rejection rate
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/084Backpropagation, e.g. using gradient descent

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • Evolutionary Computation (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • Computing Systems (AREA)
  • Software Systems (AREA)
  • Molecular Biology (AREA)
  • Computational Linguistics (AREA)
  • Biophysics (AREA)
  • Biomedical Technology (AREA)
  • Mathematical Physics (AREA)
  • General Health & Medical Sciences (AREA)
  • Health & Medical Sciences (AREA)
  • Probability & Statistics with Applications (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Evolutionary Biology (AREA)
  • Image Analysis (AREA)

Abstract

一种基于交叉注意力蒸馏Transformer的花粉图像分类方法,利用两个网络训练数据,两个网络互为对方的老师;网络一将图片编码为图片令牌,并加入Class令牌和蒸馏令牌;利用再注意力Transformer模块计算所有令牌的全局关联性;采用动态令牌稀疏化模块修剪掉冗余图片令牌,提高吞吐量;网络二将图片通过卷积运算编码为图片令牌,增加对图片令牌内部信息的建模,并加入Class令牌和蒸馏令牌;利用卷积投影以动态的卷积注意力机制来实现图片令牌的局部和全局像素信息的融合;本发明使两个网络通过各自的蒸馏令牌在蒸馏损失部分与老师网络的输出空间进行交互,学习老师网络的特征空间表达,最后输出分类结果。

Description

基于交叉注意力蒸馏Transformer的花粉图像分类方法
技术领域
本发明属于计算机视觉技术领域,特别涉及一种基于交叉注意力蒸馏Transformer的花粉图像分类方法。
背景技术
自从AlexNet网络在2012年的ImageNet图像分类比赛中获得冠军后,深度学习大热,随后相继出现许多优秀的CNN模型,如VGG-16,GoogleNet,ResNet。随着CNN网络的大放异彩,以卷积神经网络为主要模型的深度学习方法成为处理计算机视觉任务的主流。
Transformer是谷歌团队在2017年发表论文《Attention is All You Need》中提出的针对自然语言处理(NLP)的模型,它以可并行化计算和建立全局依赖关系等优点迅速成为NLP领域的首选模型。然而在计算机视觉中,由LeCun,Krizhevsky等人相继提出的卷积神经网络模型任然占据主导位置,受到Transformer在NLP领域中获得的巨大成功,研究者开始把目光投入计算机视觉领域,2020年,Facebook AI提出BETR模型,这是第一个将Transformer成功应用于目标检测和全景分割中的目标检测框架,在性能上,DETR达到了当时的SOTA效果,研究者发现在COCO目标检测数据集中,DETR在大型目标上的检测性能要优于Faster-R-CNN,但在小目标上不如后者。2020年,Dosovitskiy等人首次尝试将标准Transformer模型直接应用于图像分类,并尽可能少的修改,称之为视觉转换器(VisionTransformer)ViT。ViT将输入图像分割为大小相等的补丁Patch,并且将这些补丁的线性序列作为Transformer的输入,补丁的处理方式与NLP中的令牌Token相同,以监督学习的方式训练ViT图像分类模型。在大型私有数据集JFT-300M上进行预训练时,ViT在多个图像识别基准上接近或超过最新水平。2021年,Hugo Touvron等人在ViT的基础上加入了知识蒸馏策略,通过添加一个蒸馏令牌来与Teacher Model交互学习,最后通过蒸馏损失输出,称之为DeiT。DeiT通过一组优秀的超参数以及蒸馏操作,在不使用任何卷积操作的情况下,在86M参数的ImageNet数据集上就实现了83.1%的Top-1精度。
ViT模型在没有依赖CNN的情况下,将原版的Transformer迁移到图像分类任务上,并且在大规模数据集上也能得到很好的效果。同时ViT也存在缺点,比如在需要大型数据集来进行预训练才能达到很好的效果;在堆叠层数更深时,ViT模型的性能会迅速饱和,称之为注意力雷同问题;视觉Transformer中的最终预测结果仅基于信息量最大的令牌的子集,也就是存在冗余令牌;以及缺少卷积神经网络的如平移不变性和权重共享。
发明内容
为了克服上述现有技术的缺点,本发明的目的在于提供一种基于交叉注意力蒸馏Transformer的花粉图像分类方法,采用两个主干Transformer网络,网络一致力于解决冗余令牌和注意力雷同问题,网络二在ViT的基础上引入卷积操作,两个主干Transformer网络互为对方的老师和学生,利用知识蒸馏的方法,各自用一个蒸馏令牌与对方交互,以达到互相学习的目的。
为了实现上述目的,本发明采用的技术方案是:
基于交叉注意力蒸馏Transformer的花粉图像分类方法,采用网络一和网络二的架构实现,其中:
所述网络一将输入花粉图片进行分割,然后线性投影为图片令牌,并加入一个蒸馏令牌和一个Class令牌得到令牌序列一,利用再注意力transformer模块的再注意力机制消除令牌序列一中的注意力雷同问题,再利用动态令牌稀疏化模块的动态令牌稀疏化去除冗余令牌,多次经过再注意力Transformer模块和动态令牌稀疏化模块后,输出令牌序列二,将令牌序列二中的Class令牌和蒸馏令牌加权进行预测分类;
所述网络二将输入花粉图片进行分割,然后卷积编码为图片令牌,并加入一个蒸馏令牌和一个Class令牌得到令牌序列三,利用卷积Transformer模块的卷积注意力机制实现局部感受野,并共享卷积权重,利用卷积令牌编码模块减少令牌数量,同时增加令牌宽度,多次经过卷积Transformer模块和卷积令牌编码模块后,输出得到令牌序列四,将得到令牌序列四中的Class令牌和蒸馏令牌加权进行预测分类;
所述网络一和网络二中,蒸馏令牌和Class令牌均与图片令牌进行注意力运算并且输出概率值,取网络一、网络二的最大准确率作为最终预测分类结果;
所述网络一和网络二的损失函数组成交叉注意力蒸馏模块,交叉注意力蒸馏模块同时训练网络一和网络二,网络一和网络二互为对方的老师和学生,通过各自的蒸馏令牌以蒸馏损失的目标输出方式与老师网络进行交互。
在其中一个实施例中,所述网络一由Transformer编码器、再注意力transformer模块和动态令牌稀疏化模块组成,所述再注意力transformer模块和动态令牌稀疏化模块均有多个,依次交替设置,且再注意力transformer模块的数量较动态令牌稀疏化模块的数量多一个,所述再注意力transformer模块由再注意力模块和前馈网络组成。
在其中一个实施例中,所述Transformer编码器将输入花粉图片
Figure BDA0003285997210000031
重新划分为2D图像块序列
Figure BDA0003285997210000032
其中H,W是输入图像的长和宽,C是通道数,
Figure BDA0003285997210000033
是图像集合,N是产生的图像块个数,产生的图像块即所述补丁,补丁的维度大小是(P2·C),(P,P)是每个补丁的分辨率,C是每个补丁的维数,对每个补丁使用可训练的线性投影得到(N,D)的二维图片令牌向量,线性投影后的补丁即为图片令牌,然后初始化D维的Class令牌和D维的蒸馏令牌,加入图片令牌序列;
所述Class令牌和蒸馏令牌是初始化的可学习的嵌入向量,Class令牌和蒸馏令牌通过与图片令牌进行注意力运算,对令牌之间的全局关系进行建模,并且融合所有令牌的信息,最终与分类器相连进行类别预测。
在其中一个实施例中,使用位置编码加入令牌序列,所述位置编码是初始化的与输入令牌序列相同维度的可训练变量,通过位置编码对无序的令牌进行编码排序,以保留每个令牌的绝对或相对位置信息。
在其中一个实施例中,所述再注意力模块建立在多头注意力机制的基础上,注意力机制将每个输入令牌线性投影为可训练的查询Q、键K、值V三组值,通过所有令牌的K对Q进行点积生成注意力图,并且除以缩放因子
Figure BDA0003285997210000041
经过Softmax激活函数以获得V的权重输出到下一个再注意力模块,该注意力图表示每个再注意Transformer模块内所有令牌之间的全局相关性;多头注意力机制利用不同的权值矩阵将每个输入令牌投影到h个不同的子空间,每个子空间并行地执行注意力机制,将它们的输出值连接起来再次进行投影,得到再注意力模块的输出,再注意力机制通过定义一个端到端可训练的变换矩阵
Figure BDA0003285997210000042
使用变换矩阵动态聚合同一再注意力Transformer模块中不同头部之间的注意力映射图,重新映射出新的注意力图,解决了因层数加深而产生的注意力雷同问题,通过加深Transformer模块层数来增加注意力特征表达空间的多样性,更好的对令牌的全局关系进行建模,提高花粉图像分类的准确率。
在其中一个实施例中,所述前馈网络由MLP构成,其中包含两层线性层以及GELU激活函数,前馈网络用于对向量进行融合,对各个位置的信息进行变换,并且投影到所需的维度。
在其中一个实施例中,所述动态令牌稀疏化模块由预测模块和注意力屏蔽策略组成,对于每个输入的令牌实例,预测模块生成一个二进制决策掩码,来决定每个令牌的保留和修剪,预测模块被添加到再注意力Transformer模块中,从而逐步增加令牌的修剪数量,实现分层的稀疏化处理,一个令牌被修剪后,将不再参与后续的注意力运算;
所述预测模块以二进制决策掩码
Figure BDA0003285997210000043
和令牌
Figure BDA0003285997210000044
作为输入,使用MLP计算令牌的局部特征和全局特征,结合两者后通过一个线性变换,再利用Softmax激活函数得到一个概率值π,应用Gumbel-Softmax激活函数从π中取样得到当前的决策D,利用
Figure BDA0003285997210000051
更新
Figure BDA0003285997210000052
所述注意力屏蔽策略基于二进制决策掩码,在计算注意力映射图时加上一个注意力掩码矩阵
Figure BDA0003285997210000053
Figure BDA0003285997210000054
由二进制决策掩码生成,通过
Figure BDA0003285997210000055
显示切断的已修剪令牌和其他令牌之间的联系,只考虑当前阶段保留令牌之间的注意力矩阵运算,同时保持令牌数量不变。
在其中一个实施例中,所述网络二由多个卷积令牌编码模块和多个卷积Transformer模块组成,卷积令牌编码模块和卷积Transformer模块依次交替设置,所述卷积Transformer模块由多头注意力机制和前馈网络组成;
所述卷积令牌编码模块将分割得到的2D图像或上一阶段输出的2D重塑令牌图
Figure BDA0003285997210000056
作为本阶段的输入,通过卷积操作得到一组新的令牌图
Figure BDA0003285997210000057
f(·)表示卷积操作函数,然后将得到的令牌展平为HiWi×Ci的1D令牌序列,其中HiWi是本阶段令牌的个数,Ci是本阶段每个令牌的特征维数,Hi-1Wi-1是上一阶段令牌的个数,Ci-1是上一阶段每个令牌的特征维数,接着通过当前阶段的卷积Transformer模块执行注意力运算;
所述卷积Transformer模块由多头注意力机制(MHSA)和前馈网络(FFN)交替构成,将本阶段卷积令牌编码模块输出的1D令牌序列重塑为2D令牌图,使用深度可分离卷积生成图片令牌的查询Q、键K、值V,并且将Class令牌和蒸馏令牌线性投影到相同维度的Q、K、V,与图片令牌一起执行多头注意力运算。
所述网络一输出的Class令牌和蒸馏令牌组成网络一的损失函数,其中Class令牌与真实标签组成交叉熵损失,蒸馏令牌与网络二的Class令牌输出组成蒸馏损失;
所述网络二输出的Class令牌和蒸馏令牌组成网络二的损失函数,其中Class令牌与真实标签组成交叉熵损失,蒸馏令牌与网络一的Class令牌输出组成蒸馏损失;
所述网络一和网络二中,Class令牌与图片真实标签组成交叉熵损失,蒸馏令牌与老师网络输出组成蒸馏损失。
所述网络一和网络二中,除了自身Class令牌输出的交叉熵损失,还使用蒸馏令牌输出一个蒸馏损失,并且使用蒸馏策略来进行优化,所述蒸馏策略为软交叉蒸馏策略或硬交叉蒸馏策略;
所述软交叉蒸馏策略,计算蒸馏令牌与老师网络输出的KL散度,得到两个输出概率分布之间的差异性,通过减小KL散度值来使蒸馏令牌在自身注意力机制运算的基础上,逐渐向老师网络的输出方向靠拢;
所述硬交叉蒸馏策略,直接与老师网络的概率输出取交叉熵损失,使用带温度的Softmax激活函数输出的交叉熵损失,允许学生网络学习老师网络输出中的高概率负标签携带的有用信息。
与现有技术相比,本发明的有益效果是:
1、本发明从两个网络方向进行研究,每个网络结构解决视觉Transformer中存在的不同问题,通过交叉注意力蒸馏的损失函数来学习对方网络空间的输出表达特征,训练两个独立的网络通道结构,减少了模型的理论参数量和运算量,提高了图像分类水平。
2、本发明针对Transformer模块数量加深带来的注意力雷同问题,使用端到端的变换矩阵动态聚合以生产新的注意力特征图,通过堆叠更深层数来增加注意力特征图在不同层次上的多样性。
3、本发明针对视觉Transformer网络中存在的冗余令牌问题,使用动态令牌稀疏化模块分层的修剪掉多余的令牌,通过Gumbel-Softmax函数和注意力屏蔽策略实现了端到端的训练方式,减少了冗余的浮点数运算,保留了Transformer模型并行化计算的优点。
4、本发明针对视觉Transformer网络缺少对图片令牌内部的局部建模问题,使用卷积令牌编码模块来调整每个阶段令牌的个数和维数,通过卷积操作实现了对图片令牌内部信息的建模。使用卷积投影以一种动态的卷积注意力机制实现图片令牌局部和全局像素信息的融合。
附图说明
图1为基于交叉注意力蒸馏Transformer的花粉图像分类网络。
图2为再注意力结构。
图3为再注意力Transformer模块结构。
图4为预测模块结构。
图5为网络二体系结构。
图6为卷积Transformer模块结构。
图7为卷积投影结构。
图8为深度可分离卷积示意图。
具体实施方式
下面结合附图和实施例详细说明本发明的实施方式。
如图1所示,本发明为一种基于交叉注意力蒸馏Transformer的花粉图像分类方法,采用网络一和网络二的架构实现,从两个网络结构通道出发,两个网络结构致力于解决不同的问题,最后通过交叉注意力蒸馏策略相互学习,有效的减少了模型的理论参数量,提高了花粉图像的分类水平。其中网络一和网络二互为对方的老师,利用两个网络训练数据,分别预测分类,最终取网络一、网络二的最大准确率作为最终预测分类结果。
在网络一,将输入花粉图片进行分割,然后线性投影为图片令牌,并加入一个蒸馏令牌和一个Class令牌得到令牌序列一,利用再注意力transformer模块的再注意力机制计算所有令牌的全局关联性,消除令牌序列一中的注意力雷同问题,再利用动态令牌稀疏化模块的动态令牌稀疏化修剪掉冗余令牌,去除冗余令牌,提高吞吐量,多次经过再注意力Transformer模块和动态令牌稀疏化模块后,输出令牌序列二,将令牌序列二中的Class令牌和蒸馏令牌加权进行预测分类。
在网络二,将输入花粉图片进行分割,然后卷积编码为图片令牌,并加入一个蒸馏令牌和一个Class令牌得到令牌序列三,利用卷积Transformer模块的卷积注意力机制实现局部感受野,并共享卷积权重,利用卷积令牌编码模块减少令牌数量,同时增加令牌宽度,从而增加像素空间表示的丰富性和多样性,实现图片令牌的局部和全局像素信息的融合,多次经过卷积Transformer模块和卷积令牌编码模块后,输出得到令牌序列四,将得到令牌序列四中的Class令牌和蒸馏令牌加权进行预测分类。
在网络一和网络二中,加入的是初始化的D维的Class令牌和蒸馏令牌(此时两个向量都是初始化的随机值),由于Class令牌与真实标签组成交叉熵损失,蒸馏令牌与老师网络输出组成蒸馏损失,所以每个网络中,Class令牌和蒸馏令牌均趋于不同的方向,也就是两个向量趋于不同的值。而Class令牌和蒸馏令牌又均与图片令牌进行注意力运算并且输出概率值,所以最终输出的Class令牌和蒸馏令牌代表了所有图片令牌的信息。
在每个网络中,Class令牌各自初始化与自己网络的图片令牌相同的维度,由于两个网络的结构不同,所以Class令牌在各自网络中注意力运算后的值也不同。而网络一的蒸馏令牌与网络一的图片令牌注意力运算,与网络二输出组成蒸馏损失,网络二的蒸馏令牌与网络二的图片令牌注意力运算,与网络一输出组成蒸馏损失,所以两个蒸馏令牌也趋于不同的值。
本发明中,网络一可由Transformer编码器、再注意力transformer模块和动态令牌稀疏化模块组成,再注意力transformer模块和动态令牌稀疏化模块均有多个,依次交替设置,且再注意力transformer模块的数量较动态令牌稀疏化模块的数量多一个,再注意力transformer模块由再注意力模块和前馈网络组成。
网络二可由多个卷积令牌编码模块和多个卷积Transformer模块组成,卷积令牌编码模块和卷积Transformer模块依次交替设置,卷积Transformer模块由多头注意力机制和前馈网络组成。
网络一和网络二的损失函数组成交叉注意力蒸馏模块,交叉注意力蒸馏模块同时训练网络一和网络二,网络一和网络二互为对方的老师和学生,通过各自的蒸馏令牌以蒸馏损失的目标输出方式与老师网络进行交互,学习老师网络的特征空间表达。
下面更加详细地描述网络一和网络二所执行的过程。
一、在网络一
1.Transformer编码
标准Transformer模型的输入是一维的令牌序列,为了处理2D的花粉图像,Transformer编码器首先将输入花粉图片
Figure BDA0003285997210000091
重新划分为2D图像块序列
Figure BDA0003285997210000092
其中H,W是输入图像的长和宽,C是通道数,
Figure BDA0003285997210000093
是图像或令牌(补丁)序列的集合,N是产生的图像块个数,产生的图像块即所述补丁,每个补丁的维度大小是(P2·C),(P,P)是每个补丁的分辨率,C是每个补丁的维数,所以计算H×W×C→N×(P2·C),其中N=HW/P2。这里的N刚好满足Transformer模型的有效输入序列长度。为了使所有再注意力transformer模块中的补丁的恒定潜在变量为D维,所以将序列xp通过补丁编码的方式转换成(N,D)二维输入,方法是对每个补丁使用可训练的线性投影将(P2·C)映射到D维,得到(N,D)的二维图片令牌向量,线性投影后的补丁即图片令牌,此时N是图片令牌个数,D是图片令牌的维数。
通过重新划分和线性投影得到了N个D维的令牌序列作为后续输入,接着初始化两个可学习的xclass和xDistill,分别为Class令牌和蒸馏令牌,和每个图片令牌的维度一样为D维。Class令牌和蒸馏令牌是初始化的可学习的嵌入向量,它们的维度都是D维,Class令牌和蒸馏令牌通过与图片令牌序列进行注意力运算,对图像令牌之间的全局关系进行建模,并且融合所有令牌的信息,Class令牌与真实标签输出交叉熵损失部分,蒸馏令牌与老师网络输出组成蒸馏损失部分,不同的损失部分使它们收敛于不同的方向,最后输出Class令牌和蒸馏令牌的加权预测分类结果。示例地,Class令牌的最后输出连接一个分类头进行类别预测,分类头在预训练时由一个隐藏层的MLP实现,在微调时由一个线性层实现。蒸馏令牌用于在损失函数中与老师网络的输出交互学习,组成蒸馏损失部分。
随后用一个可学习的位置编码加入令牌序列,位置编码是初始化的与输入令牌序列相同维度的可训练变量,通过位置编码对无序的令牌进行编码排序,以保留令牌的绝对或相对位置信息,增强了对花粉图像分类语义的表述。位置编码(Positional encoding)的维度为
Figure BDA0003285997210000101
对训练好的位置编码,位置越接近的令牌越具有相似的位置编码。
Transformer编码完的向量如下:
Figure BDA0003285997210000102
其中xclass为Class令牌,xDistill为蒸馏令牌,它们的维度都是D维,
Figure BDA0003285997210000103
为补丁块,E为将补丁投影到D维图片令牌的线性变换,Epos为位置编码。
2.再注意力Transformer模块
A、再注意力模块
参考图2,再注意力模块以多头再注意力机制为基础。
在标准的注意力机制中,首先将输入的令牌序列通过一个层规范化LayerNorm(x),层规范化在训练和测试时执行同样的计算,在每个时间步分别计算规范化统计信息,可以显著降低训练时间。然后对每个令牌进行线性投影为可训练的查询Q、键K、值V三组值,通过以下公式计算:
Figure BDA0003285997210000104
其中,
Figure BDA0003285997210000105
Q,K,V分别表示查询、键、值,Softmax为激活函数,dk为输入维度,
Figure BDA0003285997210000106
为缩放因子,当dk为一个很大的值时,Q和K点乘得到的结果维度很大,从而导致结果位于softmax激活函数梯度很小的区域,因此除以一个缩放因子,使得维度可以缩小。通过所有令牌的K对Q进行点积生成注意力图,并且除以
Figure BDA0003285997210000107
经过Softmax函数以获得V的权重输出到下一个再注意力模块,该注意力图表示每个再注意Transformer模块内所有令牌之间的全局相关性。
以上是标准的注意力机制,为了解决注意力雷同问题,也就是随着Transformer模块的加深,注意力映射图逐渐变得相似的问题,本发明使用了再注意力机制。由于同一再注意Transformer模块不同头部的注意力图相似度很小,这表示来自同一层的不同头部关注令牌的不同方面,所以以头部的注意力映射图为基础,通过可学习的方式来交换不同头部的信息,动态聚合它们来重新生成各层的注意力映射图,来增加它们在不同层次上的多样性,可以显著增强对特征的表达,从而提高花粉图像分类的准确率。
具体来说,定义一个端到端可训练的变换矩阵
Figure BDA0003285997210000111
使用变换矩阵动态聚合同一再注意力Transformer模块中不同头部之间的注意力映射图,即使用变换矩阵沿着头部维度乘以注意力映射图,将多头注意力映射图混合到重新生成的新的注意力映射图中,然后进行标准化处理,最后与V相乘。具体公式如下:
Figure BDA0003285997210000112
其中softmax为激活函数,
Figure BDA0003285997210000113
用于调节Q,K点积完的维度,Norm为层标准化函数。
多头再注意力机制利用不同的权值矩阵将每个输入令牌线性投影到h个不同的子空间,每个子空间并行地执行再注意力机制,将它们的输出值连接起来再次进行投影,得到最终的值,即再注意力模块的输出,公式如下:
MultiHead(Q,K,V)=Concat(head1,…,headh)WO
where headi=Re-Attention(QWi Q,KWi K,VWi V)
其中参数为投影矩阵
Figure BDA0003285997210000114
Figure BDA0003285997210000115
Concat为向量拼接操作,h为多头再注意力的头数,WO是对各个头的注意力输出进行拼接后的向量线性映射函数。
再注意力模块解决了因层数加深而产生的注意力雷同问题,通过加深再注意力Transformer模块层数来增加注意力特征表达空间的多样性,更好的对令牌的全局关系进行建模,提高花粉图像分类的准确率。
B、前馈网络
前馈网络为神经网络结构,由MLP构成,其中包含两层线性层以及相应的GELU激活函数,主要作用是对向量进行融合,作用与卷积神经网络中的1×1卷积操作相似,负责对各个位置的信息进行变换,并且可以投影到所需的维度。
如图3所示,再注意力transformer模块由再注意力模块(MHRT)和前馈网络(FFN)交替构成,再注意力Transformer模块整体公式流程如下:
z'l=MHRT(LN(zl-1))+zl-1 l=1…L
zl=MLP(LN(z'l))+z'l l=1…L
Figure BDA0003285997210000121
其中L为网络中再注意力Transformer模块的个数(也叫层数),MHRT为再注意力模块,MLP为前馈网络,LN为层标准化函数。
3.动态令牌稀疏化模块
动态令牌稀疏化模块主要由预测模块以及注意力屏蔽策略组成。
A、预测模块
参考图4,预测模块动态地对输入的令牌进行选择性的修剪,对于输入的每个令牌,预测模块生成一个特别的二进制决策掩码来决定哪些令牌是需要裁剪掉的以及哪些令牌是需要保留的。在再注意力Transformer模块中加入预测模块,剪枝令牌的数量随着每个预测模块的增加而增加,从而逐步增加令牌的修剪数量,实现分层的稀疏化处理。一个令牌在某个层被删除后,它在后续层中不会再参与注意力运算。通过这种方法可以分层地修剪大量令牌,从而大大的减少每秒的浮点数运算,提高吞吐量,且能保证准确率的下降控制在0.5%以内,实现了速度和精度之间的完美权衡。
具体来说,决定每个令牌是保留还是丢弃主要通过一个二进制决策掩码
Figure BDA0003285997210000131
来决定,其中N是令牌的个数。首先将决策掩码中的所有元素初始化为1,1和0分别代表保留和丢弃令牌,随着正向传播动态的更新掩码。当前决策
Figure BDA0003285997210000132
和所有令牌
Figure BDA0003285997210000133
作为预测模块的输入,使用MLP计算令牌的局部特征和全局特征,计算公式如下:
Figure BDA0003285997210000134
Figure BDA0003285997210000135
Figure BDA0003285997210000136
其中zlocal是计算局部特征,zglobal计算全局特征,全局特征由一个Agg函数实现,它的功能是聚合当前阶段所有参与运算的令牌的信息,并且实现一个简单的平均池化。ui表示第i个令牌,
Figure BDA0003285997210000137
表示第i个令牌的二进制掩码。
这里局部特征是对特定令牌的信息进行编码,而全局特征则包含整个花粉图像的上下文关联信息,两者都是有用的。因此,结合两者后通过一个线性变换来获得局部-全局的信息建模,并将它们放入到另一个MLP中,利用Softmax激活函数得到一个概率值π,即预测丢弃/保留令牌的概率,如下式:
Figure BDA0003285997210000138
Figure BDA0003285997210000139
其中分别用πi,1和πi,0来表示保留和丢弃第i个令牌的概率,N是令牌个数,通过从π中取样生成当前的决策D,利用
Figure BDA00032859972100001310
更新
Figure BDA00032859972100001311
其中是Hadamard积,这表明一旦一个令牌被丢弃,它将永远不会被使用。
对于输出的概率π,应用Gumbel-Softmax激活函数从概率π中进行采样得到当前的决策D,并利用
Figure BDA00032859972100001312
更新
Figure BDA00032859972100001313
公式如下:
D=Gumbel-Softmax(π)∈{0,1}N
其中Gumbel-Softmax函数的输出是一个one-hot向量,其期望值正好等于π,N是令牌个数。由于Gumbel-Softmax函数具有可微性,使得二进制决策掩码D可以进行反向传播,从而实现了端到端的训练。
B、注意力屏蔽策略
注意力屏蔽策略是基于二进制决策掩码的。在网络训练过程中,为了使令牌的数量始终保持一致同时阻止修剪后的令牌与其他令牌进行交互,采用注意力屏蔽策略来显示切断的已修剪令牌和其他令牌之间的联系,让已经被稀疏化的令牌不参与注意力的运算,保证预测结果只与保留下来的令牌有关,从而使模型更加稳定,即减少了冗余的浮点数运算,又保留了Transformer模型并行化计算的优点。
具体来说,在计算注意力特征图的时候加上一个注意力掩码矩阵
Figure BDA0003285997210000141
通过
Figure BDA0003285997210000142
显示切断的已修剪令牌和其他令牌之间的联系,
Figure BDA0003285997210000143
由二进制决策掩码
Figure BDA0003285997210000144
转化而成,计算公式如下:
Figure BDA0003285997210000145
Figure BDA0003285997210000146
Figure BDA0003285997210000147
其中Q,K表示为查询,键,
Figure BDA0003285997210000148
是注意力掩码矩阵,Gij=1表示第j个令牌将有助于第i个令牌的更新,
Figure BDA0003285997210000149
表示第j个令牌不会对出自己以外的任何令牌做出贡献。所以
Figure BDA00032859972100001410
只考虑当前阶段保留的令牌的注意力矩阵运算,同时保持令牌数量不变,通过此方法可以完全消除已修剪令牌的影响,并且在训练过程中保持N×N的大小不变。
二、在网络二
网络二体系结构如图5所示,其由多个卷积令牌编码模块和多个卷积Transformer模块组成。网络二采用了卷积神经网络的多层空间结构设计,它基于两种卷积的操作,即卷积令牌编码和卷积投影,网络结构通道一共分为三个阶段,每个阶段均包括卷积令牌编码模块和卷积Transformer模块,其中卷积Transformer模块中包含卷积投影,每个阶段包含卷积令牌编码和卷积投影两种卷积操作。
1、卷积令牌编码模块
首先,对输入花粉图像(或2D重塑令牌图)进行卷积令牌编码,具体来说将输入图像(重塑令牌图)的重叠块使用卷积操作投影到二维空间网格作为输入,用步长来控制重叠程度,然后对令牌进行额外的层标准化。卷积操作的目的是逐步减少每个阶段的令牌数量(即特征分辨率),同时增加令牌的宽度(即特征位数)从而实现空间下采样和局部感受野来增加像素空间表示的丰富性和多样性。
具体来说,给定一个2D图像或者从第i-1阶段得到的2D的重塑令牌图
Figure BDA0003285997210000151
作为第i阶段的输入,通过一个常规的卷积操作f(·)得到一组新的令牌图
Figure BDA0003285997210000152
f(·)的二维卷积核大小为s×s,卷积核个数为Ci,步长为s-o,填充为p,新的令牌图
Figure BDA0003285997210000153
高度和宽度计算公式为:
Figure BDA0003285997210000154
然后将得到的
Figure BDA0003285997210000155
展平成HiWi×Ci的1D令牌序列,其中HiWi是第i阶段令牌的个数,Ci是第i阶段每个令牌的特征维数,也就是卷积核的个数,Hi-1Wi-1是上一阶段令牌的个数,Ci-1是上一阶段每个令牌的特征维数。接着通过一个层标准化的操作,作为下一个卷积Transformer模块的输入,执行注意力运算。卷积令牌编码模块通过改变卷积运算的参数来调整每个阶段的令牌特征维数和数量,类似于卷积神经网络的特征图运算,通过这种方式,在每个阶段逐步减少令牌序列的长度,同时增加令牌的特征维数,使令牌能够在越来越大的空间上表示越来越复杂的视觉模式。
2、卷积Transformer模块
卷积Transformer模块如图6所示,和原版Transformer模块流程一样,由多头注意力机制(MHSA)和前馈网络(FFN)交替构成,并且可在每个块后应用残差连接。不同的是,它使用深度可分离卷积的多头注意力机制(MHSA)代替原来的位置线性投影进而形成卷积投影层。带卷积投影的Transformer模块是对原有Transformer模块的推广,目标是以一种动态的卷积注意力机制来实现全局像素信息的融合,特别是实现局部像素空间上下文的额外建模。
具体来说,首先将1D令牌序列重塑为2D令牌图,然后使用核大小为S的深度可分离卷积层来实现卷积投影,生成图片令牌的查询Q、键K、值V,并将Class令牌和蒸馏令牌线性投影到相同维度的Q、K、V,投影后的Q、K、V令牌图被展平成1D序列与图片令牌一起参与后续多头注意力运算。计算式如下:
Figure BDA0003285997210000161
其中
Figure BDA0003285997210000162
是第i层q/k/v矩阵的输入,xi是未卷积投影之前的令牌,Conv2d是深度可分离卷积,卷积投影如图7所示。
深度可分离卷积将标准化卷积分解为逐通道卷积(depthwise convolution)和逐点1×1卷积(pointwise convolution)。逐通道卷积操作把来自上一层的多通道特征图全部拆分为单个通道的特征图,分别对其进行单通道卷积,然后重新堆叠到一起,它调整了特征图的尺寸而没有改变通道数。逐点卷积采用1×1的卷积核对前面得到的特征图进行第二次卷积,每个卷积核的维度与上一层的特征图通道数一样,通过选择卷积核的个数来控制输出特征图的维度。通过使用深度可分离卷积,在损失较小准确率的情况下,大大的降低了模型卷积运算部分的理论参数量,并且通过量化参数减少每个参数的占用内存,深度可分离卷积如图8所示:
在第一次进行卷积令牌编码时,得到了H1W1个维度为C1的1D令牌,特别的,初始化两个可训练的C1维Class令牌和蒸馏令牌加入1D令牌序列,将得到的H1W1+2个1D令牌输入到接下来的N1个卷积Transformer模块中,对H1W1个维度为C1的1D令牌进行卷积投影得到H1W1组查询Q、键K、值V,对Class令牌和蒸馏令牌进行线性投影得到2组查询Q、键K、值V,然后对得到的H1W1+2组查询Q、键K、值V进行多头注意力运算来对全局关系进行建模。
初始化的Class令牌和蒸馏令牌将不通过卷积令牌编码模块,为了与各个阶段卷积令牌编码模块完的维度保持一致,Class令牌和蒸馏令牌在2、3阶段线性投影得到的Q、K、V,保持与当前阶段的图像令牌卷积投影后的Q、K、V相同的维度,并且分别进行入到N2、N3个Transformer模块中与图像令牌进行多头注意力运算。
三、交叉注意力蒸馏模块
网络一在标准ViT图像分类器中加入了再注意力机制和动态令牌稀疏化模块。再注意力机制通过一个端到端可学习的矩阵Θ来关注同一注意力层中不同头部之间的信息,通过矩阵Θ的变换来重新生成新的注意力特征图,解决了因模型层数加深而产生的注意力特征图雷同的问题,使得模型可以训练更深的层数。通过加深Transformer模块的层数,可以学习到更大范围内的令牌之间的关联性,并且使网络更深层的注意力特征图保持了多样性,不再冗余,融合了全局信息,并且在同等参数量的情况下获得了更好的花粉图像分类准确率。通过可视化注意力特征图发现在视觉Transformer模型的推理过程中,注意力运算主要集中在一部分信息丰富的令牌上,因此训练一个端到端的可学习预测模块,预测模块通过维护一个二进制决策掩码来决定令牌是丢弃还是保留,分层的修剪掉重要性较低的令牌,逐步的自适应选择出重要的令牌,进而加速了推理的过程。分层的剪枝策略可以通过剪枝66%左右的令牌数量,以达到减少31%-37%左右的吞吐量,提高40%左右的模型运行速度,并且准确率下降控制在0.5%以内,实现了速度和精度之间的完美权衡。
网络二将卷积操作融入Transformer中,通过卷积令牌编码模块和卷积投影来逐步的减少令牌的数量(即特征尺寸),扩大令牌深度(即特征维数)。使得网络二既有CNN的局部感受野,共享卷积权重,空间下采样等优点,又具备Transformer的可并行化计算和融合全局信息的优点。且不需要大型数据集进行预训练也能达到很高的图像分类准确率。
网络一致力于加深Transformer模块的层数,并且修剪冗余令牌来达到速度和精度之间的平衡,但需要大型数据集进行预训练才能对全局信息进行良好的建模,以实现高效数据训练。并且缺少卷积神经网络的归纳假设和对局部像素的建模。网络二通过在Transformer上引入卷积操作加入CNN网络的诸多优点,并且不需要大规模数据集进行预训练。但即使每个阶段都在减少令牌个数,网络二依旧拥有这3k-0.2k左右的庞大令牌数量。
交叉注意力蒸馏策略允许同时训练这两个网络,并且两个网络互为对方的老师和学生,利用各自的蒸馏令牌通过蒸馏损失的方式与老师网络进行交互,以此学习老师网络身上的优点。特别的,两个网络分别初始化了一个Class令牌和一个蒸馏令牌,蒸馏令牌与Class令牌在Transformer模块中与其他图像令牌进行交互,执行注意力运算。区别在于Class令牌的目标是与真实的标签值一致,而蒸馏令牌的目标是要与老师网络预测的标签一致,两个令牌朝着不同的方向收敛,最终产生相似而不相同的目标。
具体的做法是通过在两个网络的损失函数中加上一个蒸馏损失部分。普通的视觉Transformer分类器的输出是各个类别的数值Zi,某个类别的Zi数值越大,模型认为输入花粉图片属于该类别的可能性越大,各个类别的Zi汇总值叫Logits,Logits通过Softmax函数得到各个类别的概率值作为最终分类结果概率,取概率最大值的类别做为模型预测结果,并且将输出的Softmax值与真实标签取交叉熵损失,以降低损失值来进行反向传播来更新参数,Softmax函数如下所示:
Figure BDA0003285997210000181
其中qi,zi代表第i个类别的概率和Logits值,N代表总类别数,当Softmax输出的概率分布熵相对较小时,负标签的值都很接近0,对损失函数的贡献非常小,所以引入了带温度的Softmax函数,公式如下:
Figure BDA0003285997210000191
其中τ代表温度,当τ=1时,就是标准的Softmax函数,τ越高,Softmax函数的输出概率分布越趋于平滑,其分布的信息熵越大,负标签携带的信息会被相对放大,使得模型训练时加大关注那些概率显著高于平均值的负标签。
在网络一和网络二中,除了网络自身Class令牌输出的交叉熵损失,额外的使用蒸馏令牌输出一个蒸馏损失,网络一输出的Class令牌和蒸馏令牌组成网络一的损失函数,其中Class令牌与真实标签组成交叉熵损失,蒸馏令牌与网络二的Class令牌输出组成蒸馏损失。网络二输出的Class令牌和蒸馏令牌组成网络二的损失函数,其中Class令牌与真实标签组成交叉熵损失,蒸馏令牌与网络一的Class令牌输出组成蒸馏损失。网络一和网络二中,Class令牌与图片真实标签组成交叉熵损失,蒸馏令牌与老师网络输出组成蒸馏损失。损失函数不仅有原来的Class令牌与真实标签的交叉熵损失值,还有各自的蒸馏令牌与老师网络的输出值取KL散度或交叉熵损失,最终的损失函数由Class令牌部分和蒸馏令牌部分两者加权输出,使蒸馏令牌与老师网络的输出值交互学习,迫使两个令牌向不同的方向收敛,从而学习到老师网络身上的优点。
本发明使用两种蒸馏策略来进行优化,分别是软交叉蒸馏策略和硬交叉蒸馏策略。软交叉蒸馏策略计算蒸馏令牌与老师网络输出的KL散度,得到两个输出概率分布之间的差异性,通过减小KL散度值来使蒸馏令牌在自身注意力机制运算的基础上,逐渐向老师网络的输出方向靠拢;硬交叉蒸馏策略直接与老师网络的概率输出取交叉熵损失,使用带温度的Softmax激活函数输出的交叉熵损失,允许学生网络学习老师网络输出中的高概率负标签携带的有用信息。
网络一和网络二的软交叉损失函数如下所示:
Figure BDA0003285997210000201
Figure BDA0003285997210000202
其中Li,Zi
Figure BDA0003285997210000203
代表网络i的软交叉损失函数,Class令牌输出值,蒸馏令牌输出值,LCE代表交叉熵损失函数,
Figure BDA0003285997210000204
是温度为τ的Softmax函数,y代表真实标签,KL代表KL散度,用于度量两个概率分布间差异的非对称性,λ是超参数。
上式加入的蒸馏损失部分,使得两个网络各自通过计算自身蒸馏令牌输出值和老师网络输出值之间的KL散度,得到两个概率分布之间的差异性,通过反向传播更新参数,减小概率分布之间的距离,从而使蒸馏令牌在自身注意力机制运算的基础上,朝着老师网络的目标输出方向收敛,达到了与老师网络交互学习的目的。
网络一和网络二的硬交叉损失函数如下所示:
Figure BDA0003285997210000205
Figure BDA0003285997210000206
其中
Figure BDA0003285997210000207
Zi
Figure BDA0003285997210000208
代表网络i的硬交叉损失函数,Class令牌输出值,蒸馏令牌输出值,LCE代表交叉熵损失函数,
Figure BDA0003285997210000209
是温度为τ的Softmax函数,y代表真实标签,通过调节τ的大小来决定从老师模型中学习到的负标签比例。
硬交叉蒸馏策略直接与老师网络的概率输出取交叉熵损失,同上述软交叉蒸馏一样,通过减少蒸馏令牌输出值与老师网络输出值的交叉熵损失值,使得蒸馏令牌向老师网络的输出值方向聚拢,以达到学习老师网络的目的,对于给定的图像,硬交叉蒸馏策略可能会根据具体的数据增强而变化,这归功于蒸馏令牌与老师网络的输出取交叉熵损失。其中带温度的Softmax函数允许学生网络也可以学习到老师网络输出值中的高概率负标签所携带的信息,网络一和网络二互为对方的学生和老师,并且它们的输出空间向量趋于不同的方向,通过带温度的Softmax函数将老师网络的输出空间中学到的特征表达迁移到学生网络中,起到了互相监督和提高泛化能力的作用。
在测试时,将网络一和网络二各自输出的Class令牌和蒸馏令牌进行线性映射为概率向量,之后通过Softmax激活函数输出归一化的类别概率值,两个网络将自己的Class令牌和蒸馏令牌的Softmax值相加来得到预测结果,最后取网络一、网络二的Top-1精度最大值作为最终预测结果。公式如下:
Figure BDA0003285997210000211
Yfinal=max{y1,y2}
其中yi,zi
Figure BDA0003285997210000212
代表第i个网络的预测结果,Class令牌输出值,蒸馏令牌输出值,linear为线性函数,W和b分别表示线性映射的权值矩阵及偏置,Yfinal代表两个两个网络的最终预测结果。
本发明的整体流程可表述如下:
(1)网络一使用线性编码器将图片编码为图片令牌;
(2)网络一加入Class令牌和蒸馏令牌用于后续分类、与老师网络交互;
(3)网络一使用再注意力机制加深令牌交互的模块数量;
(4)网络一使用动态令牌稀疏化模块修剪冗余令牌;
(5)网络二使用卷积令牌编码模块对图片进行编码并提取特征;
(6)网络二加入Class令牌和蒸馏令牌用于后续分类、与老师网络交互;
(7)网络二使用卷积投影生成查询Q、键K、值V执行注意力运算;
(8)网络一和网络二同时进行训练;
(9)网络一和网络二利用两种交叉蒸馏策略,通过减少蒸馏损失,让蒸馏令牌趋向老师网络的输出空间,学习老师网络的空间特征表达。

Claims (10)

1.基于交叉注意力蒸馏Transformer的花粉图像分类方法,采用网络一和网络二的架构实现,其特征在于:
所述网络一将输入花粉图片进行分割,然后线性投影为图片令牌,并加入一个蒸馏令牌和一个Class令牌得到令牌序列一,利用再注意力transformer模块的再注意力机制消除令牌序列一中的注意力雷同问题,再利用动态令牌稀疏化模块的动态令牌稀疏化去除冗余令牌,多次经过再注意力Transformer模块和动态令牌稀疏化模块后,输出令牌序列二,将令牌序列二中的Class令牌和蒸馏令牌加权进行预测分类;
所述网络二将输入花粉图片进行分割,然后卷积编码为图片令牌,并加入一个蒸馏令牌和一个Class令牌得到令牌序列三,利用卷积Transformer模块的卷积注意力机制实现局部感受野,并共享卷积权重,利用卷积令牌编码模块减少令牌数量,同时增加令牌宽度,多次经过卷积Transformer模块和卷积令牌编码模块后,输出得到令牌序列四,将得到令牌序列四中的Class令牌和蒸馏令牌加权进行预测分类;
所述网络一和网络二中,蒸馏令牌和Class令牌均与图片令牌进行注意力运算并且输出概率值,取网络一、网络二的最大准确率作为最终预测分类结果;
所述网络一和网络二的损失函数组成交叉注意力蒸馏模块,交叉注意力蒸馏模块同时训练网络一和网络二,网络一和网络二互为对方的老师和学生,通过各自的蒸馏令牌以蒸馏损失的目标输出方式与老师网络进行交互。
2.根据权利要求1所述基于交叉注意力蒸馏Transformer的花粉图像分类方法,其特征在于,所述网络一由Transformer编码器、再注意力transformer模块和动态令牌稀疏化模块组成,所述再注意力transformer模块和动态令牌稀疏化模块均有多个,依次交替设置,且再注意力transformer模块的数量较动态令牌稀疏化模块的数量多一个,所述再注意力transformer模块由再注意力模块和前馈网络组成。
3.根据权利要求2所述基于交叉注意力蒸馏Transformer的花粉图像分类方法,其特征在于,所述Transformer编码器将输入花粉图片
Figure FDA0003285997200000021
重新划分为2D图像块序列
Figure FDA0003285997200000022
其中H,W是输入图像的长和宽,C是通道数,
Figure FDA0003285997200000023
是图像集合,N是产生的图像块个数,产生的图像块即所述补丁,补丁的维度大小是(P2·C),(P,P)是每个补丁的分辨率,C是每个补丁的维数,对每个补丁使用可训练的线性投影得到(N,D)的二维图片令牌向量,线性投影后的补丁即为图片令牌,然后初始化D维的Class令牌和D维的蒸馏令牌,加入图片令牌序列;
所述Class令牌和蒸馏令牌是初始化的可学习的嵌入向量,Class令牌和蒸馏令牌通过与图片令牌进行注意力运算,对图片令牌之间的全局关系进行建模,并且融合所有图片令牌的信息,最终与分类器相连进行类别预测。
4.根据权利要求3所述基于交叉注意力蒸馏Transformer的花粉图像分类方法,其特征在于,使用位置编码加入令牌序列,所述位置编码是初始化的与输入令牌序列相同维度的可训练变量,通过位置编码对无序的令牌进行编码排序,以保留每个令牌的绝对或相对位置信息。
5.根据权利要求2或3所述基于交叉注意力蒸馏Transformer的花粉图像分类方法,其特征在于,所述再注意力模块建立在多头注意力机制的基础上,注意力机制将每个输入令牌线性投影为可训练的查询Q、键K、值V三组值,通过所有令牌的K对Q进行点积生成注意力图,并且除以缩放因子
Figure FDA0003285997200000025
经过Softmax激活函数以获得V的权重输出到下一个再注意力模块,该注意力图表示每个再注意Transformer模块内所有令牌之间的全局相关性;多头注意力机制利用不同的权值矩阵将每个输入令牌投影到h个不同的子空间,每个子空间并行地执行注意力机制,将它们的输出值连接起来再次进行投影,得到再注意力模块的输出,再注意力机制通过定义一个端到端可训练的变换矩阵
Figure FDA0003285997200000024
使用变换矩阵动态聚合同一再注意力Transformer模块中不同头部之间的注意力映射图,重新映射出新的注意力图。
6.根据权利要求2或3所述基于交叉注意力蒸馏Transformer的花粉图像分类方法,其特征在于,所述前馈网络由MLP构成,其中包含两层线性层以及GELU激活函数,前馈网络用于对向量进行融合,对各个位置的信息进行变换,并且投影到所需的维度。
7.根据权利要求2或3所述基于交叉注意力蒸馏Transformer的花粉图像分类方法,其特征在于,所述动态令牌稀疏化模块由预测模块和注意力屏蔽策略组成,对于每个输入的令牌实例,预测模块生成一个二进制决策掩码,来决定每个令牌的保留和修剪,预测模块被添加到再注意力Transformer模块中,从而逐步增加令牌的修剪数量,实现分层的稀疏化处理,一个令牌被修剪后,将不再参与后续的注意力运算;
所述预测模块以二进制决策掩码
Figure FDA0003285997200000031
和令牌
Figure FDA0003285997200000032
作为输入,使用MLP计算令牌的局部特征和全局特征,结合两者后通过一个线性变换,再利用Softmax激活函数得到一个概率值π,应用Gumbel-Softmax激活函数从π中取样得到当前的决策D,利用
Figure FDA0003285997200000033
更新
Figure FDA0003285997200000034
所述注意力屏蔽策略基于二进制决策掩码,在计算注意力映射图时加上一个注意力掩码矩阵
Figure FDA0003285997200000035
Figure FDA0003285997200000036
由二进制决策掩码生成,通过
Figure FDA0003285997200000037
显示切断的已修剪令牌和其他令牌之间的联系,只考虑当前阶段保留令牌之间的注意力矩阵运算,同时保持令牌数量不变。
8.根据权利要求1所述基于交叉注意力蒸馏Transformer的花粉图像分类方法,其特征在于,所述网络二由多个卷积令牌编码模块和多个卷积Transformer模块组成,卷积令牌编码模块和卷积Transformer模块依次交替设置,所述卷积Transformer模块由多头注意力机制和前馈网络组成;
所述卷积令牌编码模块将分割得到的2D图像或上一阶段输出的2D重塑令牌图
Figure FDA0003285997200000038
作为本阶段的输入,通过卷积操作得到一组新的令牌图
Figure FDA0003285997200000041
f(·)表示卷积操作函数,然后将得到的令牌展平为HiWi×Ci的1D令牌序列,其中HiWi是本阶段令牌的个数,Ci是本阶段每个令牌的特征维数,Hi-1Wi-1是上一阶段令牌的个数,Ci-1是上一阶段每个令牌的特征维数,接着通过当前阶段的卷积Transformer模块执行注意力运算;
所述卷积Transformer模块由多头注意力机制和前馈网络交替构成,将本阶段卷积令牌编码模块输出的1D令牌序列重塑为2D令牌图,使用深度可分离卷积生成图片令牌的查询Q、键K、值V,并且将Class令牌和蒸馏令牌线性投影到相同维度的Q、K、V,与图片令牌一起执行多头注意力运算。
9.根据权利要求1所述基于交叉注意力蒸馏Transformer的花粉图像分类方法,其特征在于,
所述网络一输出的Class令牌和蒸馏令牌组成网络一的损失函数,其中Class令牌与真实标签组成交叉熵损失,蒸馏令牌与网络二的Class令牌输出组成蒸馏损失;
所述网络二输出的Class令牌和蒸馏令牌组成网络二的损失函数,其中Class令牌与真实标签组成交叉熵损失,蒸馏令牌与网络一的Class令牌输出组成蒸馏损失;
所述网络一和网络二中,Class令牌与图片真实标签组成交叉熵损失,蒸馏令牌与老师网络输出组成蒸馏损失。
10.根据权利要求1所述基于交叉注意力蒸馏Transformer的花粉图像分类方法,其特征在于,所述网络一和网络二中,除了自身Class令牌输出的交叉熵损失,还使用蒸馏令牌输出一个蒸馏损失,并且使用蒸馏策略来进行优化,所述蒸馏策略为软交叉蒸馏策略或硬交叉蒸馏策略;
所述软交叉蒸馏策略,计算蒸馏令牌与老师网络输出的KL散度,得到两个输出概率分布之间的差异性,通过减小KL散度值来使蒸馏令牌在自身注意力机制运算的基础上,逐渐向老师网络的输出方向靠拢;
所述硬交叉蒸馏策略,直接与老师网络的概率输出取交叉熵损失,使用带温度的Softmax激活函数输出的交叉熵损失,允许学生网络学习老师网络输出中的高概率负标签携带的有用信息。
CN202111147668.5A 2021-09-29 2021-09-29 基于交叉注意力蒸馏Transformer的花粉图像分类方法 Active CN113887610B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202111147668.5A CN113887610B (zh) 2021-09-29 2021-09-29 基于交叉注意力蒸馏Transformer的花粉图像分类方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202111147668.5A CN113887610B (zh) 2021-09-29 2021-09-29 基于交叉注意力蒸馏Transformer的花粉图像分类方法

Publications (2)

Publication Number Publication Date
CN113887610A true CN113887610A (zh) 2022-01-04
CN113887610B CN113887610B (zh) 2024-02-02

Family

ID=79007683

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202111147668.5A Active CN113887610B (zh) 2021-09-29 2021-09-29 基于交叉注意力蒸馏Transformer的花粉图像分类方法

Country Status (1)

Country Link
CN (1) CN113887610B (zh)

Cited By (19)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN114049527A (zh) * 2022-01-10 2022-02-15 湖南大学 基于在线协作与融合的自我知识蒸馏方法与系统
CN114463559A (zh) * 2022-01-29 2022-05-10 新疆爱华盈通信息技术有限公司 图像识别模型的训练方法、装置、网络和图像识别方法
CN114494791A (zh) * 2022-04-06 2022-05-13 之江实验室 一种基于注意力选择的transformer运算精简方法及装置
CN114648664A (zh) * 2022-03-23 2022-06-21 北京工业大学 一种基于多视角信息融合的图像分类方法
CN114663952A (zh) * 2022-03-28 2022-06-24 北京百度网讯科技有限公司 对象分类方法、深度学习模型的训练方法、装置和设备
CN114842253A (zh) * 2022-05-04 2022-08-02 哈尔滨理工大学 基于自适应光谱空间核结合ViT高光谱图像分类方法
CN114926460A (zh) * 2022-07-19 2022-08-19 合肥综合性国家科学中心人工智能研究院(安徽省人工智能实验室) 眼底图像分类模型的训练方法、眼底图像分类方法及系统
CN115099294A (zh) * 2022-03-21 2022-09-23 昆明理工大学 一种基于特征增强和决策融合的花卉图像分类算法
CN115170638A (zh) * 2022-07-13 2022-10-11 东北林业大学 一种双目视觉立体匹配网络系统及其构建方法
CN115205986A (zh) * 2022-08-09 2022-10-18 山东省人工智能研究院 一种基于知识蒸馏与transformer的假视频检测方法
CN115272369A (zh) * 2022-07-29 2022-11-01 苏州大学 动态聚合变换器网络及视网膜血管分割方法
CN115471724A (zh) * 2022-11-02 2022-12-13 青岛杰瑞工控技术有限公司 一种基于自适应归一化的细粒度鱼类疫病识别融合算法
CN115761437A (zh) * 2022-11-09 2023-03-07 北京百度网讯科技有限公司 基于视觉转换器的图像处理方法、训练方法和电子设备
CN115797751A (zh) * 2023-01-18 2023-03-14 中国科学技术大学 基于对比掩码图像建模的图像分析方法与系统
CN115937567A (zh) * 2022-09-07 2023-04-07 北京交通大学 一种基于小波散射网络和ViT的图像分类方法
CN116091849A (zh) * 2023-04-11 2023-05-09 山东建筑大学 基于分组解码器的轮胎花纹分类方法、系统、介质及设备
CN116152240A (zh) * 2023-04-18 2023-05-23 厦门微图软件科技有限公司 一种基于知识蒸馏的工业缺陷检测模型压缩方法
CN116385839A (zh) * 2023-06-05 2023-07-04 深圳须弥云图空间科技有限公司 图像预训练模型的训练方法、装置、电子设备及存储介质
CN117422911A (zh) * 2023-10-20 2024-01-19 哈尔滨工业大学 一种协同学习驱动的多类别全切片数字病理图像分类系统

Citations (19)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN103986890A (zh) * 2014-05-04 2014-08-13 苏州乐聚一堂电子科技有限公司 文字特效卡拉ok手机点歌系统
CN110506279A (zh) * 2017-04-14 2019-11-26 易享信息技术有限公司 采用隐树注意力的神经机器翻译
WO2019240964A1 (en) * 2018-06-12 2019-12-19 Siemens Aktiengesellschaft Teacher and student based deep neural network training
CN111144490A (zh) * 2019-12-26 2020-05-12 南京邮电大学 一种基于轮替知识蒸馏策略的细粒度识别方法
CN111444709A (zh) * 2020-03-09 2020-07-24 腾讯科技(深圳)有限公司 文本分类方法、装置、存储介质及设备
CN111611377A (zh) * 2020-04-22 2020-09-01 淮阴工学院 基于知识蒸馏的多层神经网络语言模型训练方法与装置
CN111767711A (zh) * 2020-09-02 2020-10-13 之江实验室 基于知识蒸馏的预训练语言模型的压缩方法及平台
US20200342316A1 (en) * 2017-10-27 2020-10-29 Google Llc Attention-based decoder-only sequence transduction neural networks
CN112230990A (zh) * 2020-11-10 2021-01-15 北京邮电大学 一种基于层级注意力神经网络的程序代码查重方法
WO2021023202A1 (zh) * 2019-08-07 2021-02-11 交叉信息核心技术研究院(西安)有限公司 一种卷积神经网络的自蒸馏训练方法、设备和可伸缩动态预测方法
CN112380379A (zh) * 2020-11-18 2021-02-19 北京字节跳动网络技术有限公司 歌词特效展示方法、装置、电子设备及计算机可读介质
CN112970024A (zh) * 2018-08-22 2021-06-15 Netapp股份有限公司 大型文档语料库中的令牌匹配
US20210183484A1 (en) * 2019-12-06 2021-06-17 Surgical Safety Technologies Inc. Hierarchical cnn-transformer based machine learning
US20210182662A1 (en) * 2019-12-17 2021-06-17 Adobe Inc. Training of neural network based natural language processing models using dense knowledge distillation
CN112990296A (zh) * 2021-03-10 2021-06-18 中科人工智能创新技术研究院(青岛)有限公司 基于正交相似度蒸馏的图文匹配模型压缩与加速方法及系统
US20210232773A1 (en) * 2020-01-23 2021-07-29 Salesforce.Com, Inc. Unified Vision and Dialogue Transformer with BERT
US20210232753A1 (en) * 2021-01-28 2021-07-29 Microsoft Technology Licensing Llc Ml using n-gram induced input representation
CN113255915A (zh) * 2021-05-20 2021-08-13 深圳思谋信息科技有限公司 基于结构化实例图的知识蒸馏方法、装置、设备和介质
CN113408343A (zh) * 2021-05-12 2021-09-17 杭州电子科技大学 基于双尺度时空分块互注意力的课堂动作识别方法

Patent Citations (19)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN103986890A (zh) * 2014-05-04 2014-08-13 苏州乐聚一堂电子科技有限公司 文字特效卡拉ok手机点歌系统
CN110506279A (zh) * 2017-04-14 2019-11-26 易享信息技术有限公司 采用隐树注意力的神经机器翻译
US20200342316A1 (en) * 2017-10-27 2020-10-29 Google Llc Attention-based decoder-only sequence transduction neural networks
WO2019240964A1 (en) * 2018-06-12 2019-12-19 Siemens Aktiengesellschaft Teacher and student based deep neural network training
CN112970024A (zh) * 2018-08-22 2021-06-15 Netapp股份有限公司 大型文档语料库中的令牌匹配
WO2021023202A1 (zh) * 2019-08-07 2021-02-11 交叉信息核心技术研究院(西安)有限公司 一种卷积神经网络的自蒸馏训练方法、设备和可伸缩动态预测方法
US20210183484A1 (en) * 2019-12-06 2021-06-17 Surgical Safety Technologies Inc. Hierarchical cnn-transformer based machine learning
US20210182662A1 (en) * 2019-12-17 2021-06-17 Adobe Inc. Training of neural network based natural language processing models using dense knowledge distillation
CN111144490A (zh) * 2019-12-26 2020-05-12 南京邮电大学 一种基于轮替知识蒸馏策略的细粒度识别方法
US20210232773A1 (en) * 2020-01-23 2021-07-29 Salesforce.Com, Inc. Unified Vision and Dialogue Transformer with BERT
CN111444709A (zh) * 2020-03-09 2020-07-24 腾讯科技(深圳)有限公司 文本分类方法、装置、存储介质及设备
CN111611377A (zh) * 2020-04-22 2020-09-01 淮阴工学院 基于知识蒸馏的多层神经网络语言模型训练方法与装置
CN111767711A (zh) * 2020-09-02 2020-10-13 之江实验室 基于知识蒸馏的预训练语言模型的压缩方法及平台
CN112230990A (zh) * 2020-11-10 2021-01-15 北京邮电大学 一种基于层级注意力神经网络的程序代码查重方法
CN112380379A (zh) * 2020-11-18 2021-02-19 北京字节跳动网络技术有限公司 歌词特效展示方法、装置、电子设备及计算机可读介质
US20210232753A1 (en) * 2021-01-28 2021-07-29 Microsoft Technology Licensing Llc Ml using n-gram induced input representation
CN112990296A (zh) * 2021-03-10 2021-06-18 中科人工智能创新技术研究院(青岛)有限公司 基于正交相似度蒸馏的图文匹配模型压缩与加速方法及系统
CN113408343A (zh) * 2021-05-12 2021-09-17 杭州电子科技大学 基于双尺度时空分块互注意力的课堂动作识别方法
CN113255915A (zh) * 2021-05-20 2021-08-13 深圳思谋信息科技有限公司 基于结构化实例图的知识蒸馏方法、装置、设备和介质

Non-Patent Citations (3)

* Cited by examiner, † Cited by third party
Title
HUGO TOUVRON 等: "Training data-efficient image transformers & distillation through attention", 《ARXIV.ORG》, pages 1 - 22 *
王晓莉;叶东毅;: "基于字词特征自注意力学习的社交媒体文本分类方法", 模式识别与人工智能, no. 04, pages 4 - 11 *
胡滨: "基于知识蒸馏的高效生物医学命名实体识别模型", 《清华大学学报(自然科学版)》, vol. 61, no. 09, pages 936 - 942 *

Cited By (28)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN114049527A (zh) * 2022-01-10 2022-02-15 湖南大学 基于在线协作与融合的自我知识蒸馏方法与系统
CN114049527B (zh) * 2022-01-10 2022-06-14 湖南大学 基于在线协作与融合的自我知识蒸馏方法与系统
CN114463559A (zh) * 2022-01-29 2022-05-10 新疆爱华盈通信息技术有限公司 图像识别模型的训练方法、装置、网络和图像识别方法
CN114463559B (zh) * 2022-01-29 2024-05-10 芯算一体(深圳)科技有限公司 图像识别模型的训练方法、装置、网络和图像识别方法
CN115099294A (zh) * 2022-03-21 2022-09-23 昆明理工大学 一种基于特征增强和决策融合的花卉图像分类算法
CN114648664A (zh) * 2022-03-23 2022-06-21 北京工业大学 一种基于多视角信息融合的图像分类方法
CN114663952A (zh) * 2022-03-28 2022-06-24 北京百度网讯科技有限公司 对象分类方法、深度学习模型的训练方法、装置和设备
CN114494791A (zh) * 2022-04-06 2022-05-13 之江实验室 一种基于注意力选择的transformer运算精简方法及装置
CN114494791B (zh) * 2022-04-06 2022-07-08 之江实验室 一种基于注意力选择的transformer运算精简方法及装置
CN114842253A (zh) * 2022-05-04 2022-08-02 哈尔滨理工大学 基于自适应光谱空间核结合ViT高光谱图像分类方法
CN115170638B (zh) * 2022-07-13 2023-04-18 东北林业大学 一种双目视觉立体匹配网络系统及其构建方法
CN115170638A (zh) * 2022-07-13 2022-10-11 东北林业大学 一种双目视觉立体匹配网络系统及其构建方法
CN114926460A (zh) * 2022-07-19 2022-08-19 合肥综合性国家科学中心人工智能研究院(安徽省人工智能实验室) 眼底图像分类模型的训练方法、眼底图像分类方法及系统
CN114926460B (zh) * 2022-07-19 2022-10-25 合肥综合性国家科学中心人工智能研究院(安徽省人工智能实验室) 眼底图像分类模型的训练方法、眼底图像分类方法及系统
CN115272369A (zh) * 2022-07-29 2022-11-01 苏州大学 动态聚合变换器网络及视网膜血管分割方法
CN115205986A (zh) * 2022-08-09 2022-10-18 山东省人工智能研究院 一种基于知识蒸馏与transformer的假视频检测方法
CN115205986B (zh) * 2022-08-09 2023-05-19 山东省人工智能研究院 一种基于知识蒸馏与transformer的假视频检测方法
CN115937567A (zh) * 2022-09-07 2023-04-07 北京交通大学 一种基于小波散射网络和ViT的图像分类方法
CN115471724A (zh) * 2022-11-02 2022-12-13 青岛杰瑞工控技术有限公司 一种基于自适应归一化的细粒度鱼类疫病识别融合算法
CN115761437B (zh) * 2022-11-09 2024-02-06 北京百度网讯科技有限公司 基于视觉转换器的图像处理方法、训练方法和电子设备
CN115761437A (zh) * 2022-11-09 2023-03-07 北京百度网讯科技有限公司 基于视觉转换器的图像处理方法、训练方法和电子设备
CN115797751A (zh) * 2023-01-18 2023-03-14 中国科学技术大学 基于对比掩码图像建模的图像分析方法与系统
CN116091849A (zh) * 2023-04-11 2023-05-09 山东建筑大学 基于分组解码器的轮胎花纹分类方法、系统、介质及设备
CN116152240A (zh) * 2023-04-18 2023-05-23 厦门微图软件科技有限公司 一种基于知识蒸馏的工业缺陷检测模型压缩方法
CN116385839A (zh) * 2023-06-05 2023-07-04 深圳须弥云图空间科技有限公司 图像预训练模型的训练方法、装置、电子设备及存储介质
CN116385839B (zh) * 2023-06-05 2023-08-11 深圳须弥云图空间科技有限公司 图像预训练模型的训练方法、装置、电子设备及存储介质
CN117422911A (zh) * 2023-10-20 2024-01-19 哈尔滨工业大学 一种协同学习驱动的多类别全切片数字病理图像分类系统
CN117422911B (zh) * 2023-10-20 2024-04-30 哈尔滨工业大学 一种协同学习驱动的多类别全切片数字病理图像分类系统

Also Published As

Publication number Publication date
CN113887610B (zh) 2024-02-02

Similar Documents

Publication Publication Date Title
CN113887610A (zh) 基于交叉注意力蒸馏Transformer的花粉图像分类方法
CN110263912B (zh) 一种基于多目标关联深度推理的图像问答方法
Wang et al. Development of convolutional neural network and its application in image classification: a survey
Chen et al. Learning student networks via feature embedding
CN111274869B (zh) 基于并行注意力机制残差网进行高光谱图像分类的方法
CN115294407B (zh) 基于预习机制知识蒸馏的模型压缩方法及系统
CN111523546A (zh) 图像语义分割方法、系统及计算机存储介质
CN110321967A (zh) 基于卷积神经网络的图像分类改进算法
CN113379655B (zh) 一种基于动态自注意力生成对抗网络的图像合成方法
Peng et al. CNN and transformer framework for insect pest classification
CN116580440B (zh) 基于视觉transformer的轻量级唇语识别方法
Han et al. Batch-normalized Mlpconv-wise supervised pre-training network in network
Liu et al. RB-Net: Training highly accurate and efficient binary neural networks with reshaped point-wise convolution and balanced activation
Chen et al. Coupled end-to-end transfer learning with generalized fisher information
CN116796810A (zh) 一种基于知识蒸馏的深度神经网络模型压缩方法及装置
Jeon et al. Leveraging angular distributions for improved knowledge distillation
Jiang et al. Cross-level reinforced attention network for person re-identification
CN114065834B (zh) 一种模型训练方法、终端设备及计算机存储介质
CN115273046A (zh) 一种用于智能视频分析的驾驶员行为识别方法
He et al. ACSL: Adaptive correlation-driven sparsity learning for deep neural network compression
Ding et al. A novel two-stage learning pipeline for deep neural networks
Li et al. Prototype-guided Cross-task Knowledge Distillation for Large-scale Models
Medvedev et al. Optimization of the local search in the training for SAMANN neural network
Qian et al. No-reference image quality assessment based on automatic machine learning
CN117036698B (zh) 一种基于双重特征知识蒸馏的语义分割方法

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant