CN115937567A - 一种基于小波散射网络和ViT的图像分类方法 - Google Patents
一种基于小波散射网络和ViT的图像分类方法 Download PDFInfo
- Publication number
- CN115937567A CN115937567A CN202211089518.8A CN202211089518A CN115937567A CN 115937567 A CN115937567 A CN 115937567A CN 202211089518 A CN202211089518 A CN 202211089518A CN 115937567 A CN115937567 A CN 115937567A
- Authority
- CN
- China
- Prior art keywords
- scatvit
- image
- network
- model
- wavelet scattering
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Granted
Links
- 238000000034 method Methods 0.000 title claims abstract description 27
- 238000013145 classification model Methods 0.000 claims abstract description 37
- 238000012549 training Methods 0.000 claims abstract description 30
- 238000007781 pre-processing Methods 0.000 claims abstract description 6
- 239000011159 matrix material Substances 0.000 claims description 39
- 239000013598 vector Substances 0.000 claims description 22
- 230000006870 function Effects 0.000 claims description 18
- 238000010606 normalization Methods 0.000 claims description 13
- 230000009466 transformation Effects 0.000 claims description 13
- 238000012795 verification Methods 0.000 claims description 12
- 238000004364 calculation method Methods 0.000 claims description 11
- 241001274197 Scatophagus argus Species 0.000 claims description 10
- 230000004913 activation Effects 0.000 claims description 6
- 238000011478 gradient descent method Methods 0.000 claims description 3
- 238000013507 mapping Methods 0.000 claims description 3
- 230000007547 defect Effects 0.000 abstract description 7
- 230000008439 repair process Effects 0.000 abstract description 3
- 230000036961 partial effect Effects 0.000 abstract description 2
- 230000000903 blocking effect Effects 0.000 abstract 2
- 238000013519 translation Methods 0.000 description 10
- 238000013527 convolutional neural network Methods 0.000 description 9
- 238000013528 artificial neural network Methods 0.000 description 7
- 238000012545 processing Methods 0.000 description 6
- 238000005520 cutting process Methods 0.000 description 5
- 238000010586 diagram Methods 0.000 description 5
- 238000005516 engineering process Methods 0.000 description 5
- 230000008569 process Effects 0.000 description 5
- 238000011176 pooling Methods 0.000 description 4
- 238000012360 testing method Methods 0.000 description 4
- 230000000007 visual effect Effects 0.000 description 4
- 238000006243 chemical reaction Methods 0.000 description 3
- 238000000605 extraction Methods 0.000 description 3
- 238000003709 image segmentation Methods 0.000 description 3
- 230000011218 segmentation Effects 0.000 description 3
- 230000008901 benefit Effects 0.000 description 2
- 230000015556 catabolic process Effects 0.000 description 2
- 230000006378 damage Effects 0.000 description 2
- 238000006731 degradation reaction Methods 0.000 description 2
- 238000009795 derivation Methods 0.000 description 2
- 238000013461 design Methods 0.000 description 2
- 230000000694 effects Effects 0.000 description 2
- 230000000670 limiting effect Effects 0.000 description 2
- 230000007246 mechanism Effects 0.000 description 2
- 238000003058 natural language processing Methods 0.000 description 2
- 230000002829 reductive effect Effects 0.000 description 2
- 208000037170 Delayed Emergence from Anesthesia Diseases 0.000 description 1
- 230000003213 activating effect Effects 0.000 description 1
- 238000004422 calculation algorithm Methods 0.000 description 1
- 230000008859 change Effects 0.000 description 1
- 230000003247 decreasing effect Effects 0.000 description 1
- 238000013135 deep learning Methods 0.000 description 1
- 238000001514 detection method Methods 0.000 description 1
- 230000008034 disappearance Effects 0.000 description 1
- 238000002474 experimental method Methods 0.000 description 1
- 238000004880 explosion Methods 0.000 description 1
- 239000002360 explosive Substances 0.000 description 1
- 230000004927 fusion Effects 0.000 description 1
- 230000006698 induction Effects 0.000 description 1
- 230000003993 interaction Effects 0.000 description 1
- 230000002452 interceptive effect Effects 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 210000002569 neuron Anatomy 0.000 description 1
- 230000003287 optical effect Effects 0.000 description 1
- 230000000750 progressive effect Effects 0.000 description 1
- 230000000087 stabilizing effect Effects 0.000 description 1
- 238000013526 transfer learning Methods 0.000 description 1
- 238000010200 validation analysis Methods 0.000 description 1
Images
Classifications
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02T—CLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
- Y02T10/00—Road transport of goods or passengers
- Y02T10/10—Internal combustion engine [ICE] based vehicles
- Y02T10/40—Engine management systems
Landscapes
- Image Analysis (AREA)
Abstract
本发明提供了一种基于小波散射网络和ViT的图像分类方法。该方法包括:对图像数据进行预处理,获取带标签的预处理后的图像数据;构建基于小波散射网络和ViT的分类模型ScatViT,设定模型参数;设定训练参数,利用预处理后的图像数据训练分类模型ScatViT;利用训练好的分类模型ScatViT对待分类图像进行分类处理。本发明结合小波散射网络和ViT两个模型,提出了将图像切块操作改为使用小波散射网络提取图像特征的ScatViT模型,该模型改进了小波散射网络由于滤波器权重固定导致的无法从数据中学习的缺陷,修复了由于切块操作所丢失的部分信息,并排除了与图像分类无关信息的干扰,能更准确地表达图像的特征信息。
Description
技术领域
本发明涉及计算机视觉领域的图像分类技术,是一种基于小波散射网络和ViT的图像分类方法。
背景技术
图像分类是指根据图像所包含的信息对不同类别的图像进行区分,为每个图像分配预设范围内的类别标签,从而达到理解图像信息的目的。作为图像识别中最为基础的技术,图像分类在计算机视觉领域起着至关重要的作用。已有的图像分类方法包括卷积神经网络、小波散射网络以及基于Transformer的网络等。
卷积神经网络从人类视觉系统演变而来,是一类包含卷积计算的神经网络。1998年,Yann LeCun等在图像分类任务上首次使用卷积神经网络,提出LeNet,通过接连使用卷积和池化层的组合提取图像特征,采用了局部连接、权重共享、池化等操作,在手写数字识别任务上取得了巨大成功。但是该网络结构复杂度低且网络深度较浅,因而图像特征提取能力一般。2012年,Alex Krizhevsky等在大规模图片数据集ImageNet上应用卷积神经网络,提出AlexNet,获得了当年大规模视觉识别挑战赛的冠军,将错误率降低了10个百分点,引起图像领域的极大震撼。与LeNet相比,AlexNet具有更深的网络结构,计算量增大,具有更多的参数且可以有效避免过拟合现象。2014年,Simonyan和Zisserman提出了VGG网络结构,VGGNet采用小卷积核,层数更深、特征图更宽。VGGNet结构简单、性能优秀,其网络结构的独特设计,为构建深度神经网络提供了一般化方法。同年,Christian Szegedy等提出了GoogLeNet,并取得了当年大规模视觉识别挑战赛的冠军。相比于卷积神经网络中单纯的“卷积+池化+全连接”的操作技术,GoogLeNet引入Inception结构,并用全局平均池化替换了原始结构中的全连接层。2015年,Kaiming He和Jian Sun等提出了ResNet来解决深度神经网络的退化问题,其核心思想是使用Residual Connection和残差块,在大规模数据集ImagNet上将错误率降低至3.57%,超过了人眼识别的能力,后续深度神经网络的设计也在不断借鉴Residual Connection的操作。近年来深度学习在计算机视觉领域中的图像分类、检测等任务上不断获得成功,很大程度上是因为卷积神经网络的不断进步。但是,卷积操作缺乏对图像数据的全局理解,受到局部相互作用的限制,无法充分利用图像数据的全局信息。此外,卷积神经网络在训练过程中对数据需求量大,网络参数多,同时存在梯度消失、梯度爆炸、网络退化、可解释性差等问题。
小波散射网络是一种基于小波变换的非反馈式神经网络,该网络作为特征提取器具有如非扩张性、微小形变稳定性、平移不变性的良好性质,经过了严格的数学推导和理论证明。然而,在实际应用中,图像往往还要包含诸如遮挡、杂乱背景等更加复杂的变化。在这些情况中,仅仅使用小波散射网络是无法捕捉到有效特征表达的。小波散射网络是非反馈式结构,采用预先定义的权重固定的滤波器,权重固定的滤波器无需通过学习得到,能够降低计算复杂度,是小波散射网络的一大优点。但事实上,权重固定的滤波器意味着小波散射网络只能捕捉到如平移、旋转等刚性变换,而对更复杂的变化无能为力。
Transformer是一种完全基于自注意力机制、能够并行化处理数据的深度神经网络。由于其对于大规模数据表现出来的巨大潜力,该模型一直受到研究者们的关注。相比卷积神经网络,Transformer的自注意力机制利用全局信息,能挖掘长距离的依赖关系,根据不同的任务目标学习最合适的归纳偏置。近年来,基于Transformer的模型不断涌现,给计算机视觉领域注入了新的活力,引领了新的变革。诞生于自然语言处理领域的Transformer应用于计算机视觉领域的主要困难在于图像数据转化为序列数据所带来的爆炸式的计算量增长。事实上,如果直接将大小为224×224的图像按像素点转化为序列数据,将会得到长度50176的序列数据。
ViT通过对图像做切块展平处理来解决伴随数据转化而来的计算量陡增问题,是最早将Transformer应用于图像分类任务的模型,其结构不依赖卷积神经网络,在许多大规模数据集上面实现了非常好的分类效果,但缺陷也十分明显,主要表现在:将原本应用于自然语言处理领域的Transformer引入计算机视觉领域,自然需要将图像数据转换成序列数据,而这种先切块后展平的转换必然伴随着图片内部结构的破坏,从而导致分类性能的下降。
现有的图像分类方法包括小波散射网络和ViT方案。技术方案如下:设小波函数为ψ,那么对其进行2j的尺度缩放与r的旋转,可得到小波如下:
ψλ(u)=2-2jψ(2-jr-1u)
其中ψλ(u)是经过特定的尺度缩放和旋转后得到的小波函数,j和r分别是尺度参数和角度参数,j∈Z确定尺度,r∈G确定方向,G是平面旋转群,λ=2-jr∈2-Z×G=Λ。对于图像x(u)和有序路径p=(λ1,λ2,...,λm),其对应的小波散射变换为:
遍历所有可能的路径然后拼接即可得到最终的输出结果,记为X:
定义1:Γ是平移不变算子,若Γ满足对任意x(u)∈L2(R2),任意c∈R2:
Γ(x(u-c))=Γ(x(u))
定义2:Γ是非扩张算子,若Γ满足对任意x(u)、y(u)∈L2(R2),存在C>0:
||Γ(x(u))-Γ(y(u))||≤C||x(u)-y(u)||
定义3:Γ是形变稳定算子,若Γ满足对任意x(u)∈L2(R2),任意不恒为常数的形变算子τ:R2→R2,存在C>0:
ViT的结构包括PatchEmbedding层、Encoder层和MLPHead层。在实际应用中,经常要在大规模数据集上进行预训练,之后根据迁移学习的原理进行微调训练。PatchEmbedding层通过切块展平将二维图像数据转化为序列数据,之后类似于机器翻译任务中的词嵌入算法,将其映射到高维空间。在传入编码器中之前,需要像处理机器翻译任务的原始Transformer模型一样附加位置信息,即加上一个位置向量。除此之外,还需要添加一个分类标志位,便于最终输出概率分布。Encoder层由Multi-Head Attention与MLP构成,包含Residual Connection和Layer Normalization。为防止过拟合,引入Dropout,在使用数据训练时随机去掉一些神经元。MLPHead层将提取的分类标志位通过线性变换与激活函数的组合,得到待输出的类别概率分布。
在Encoder层,Multi-Head Attention将参数映射到不同的子空间,分别进行注意力计算,最终将各个结果进行拼接,它能够使各个独立的头分别关注不同的信息,例如全局信息、局部信息,所以可以寻找数据之间不同角度的关联,其计算公式为:
MultiHead(Q,K,V)=Concat(head1,...,headh)WO
headi=Attention(QWi Q,KWi K,VWi V)
其中Q是查询矩阵,K是关键字矩阵,V是值矩阵,Wi Q是查询矩阵第i头的参数矩阵,Wi K是关键字矩阵第i头的参数矩阵,Wi V是值矩阵第i头的参数矩阵,WO是参数矩阵,Concat是连接操作,dk是查询矩阵的列数。
Encoder包含多个相同的EncoderBlock,具体结构如下图:
假设输入图像形状是224×224×3,那么计算过程如下:
步骤1:将图像进行16×16的分块处理,得到的每个图像块大小为14×14。使各个图像块展平成一维向量,再进行线性投影变换,整个过程可以通过使用卷积与展平实现,得到矩阵形状为196×768;
步骤2:添加分类标志位以及位置编码,得到矩阵形状为197×768;.
步骤3:输入到Encoder,依次通过Multi-Head Attention与MLP,每个子层包含Layer Normalization与Residual Connection,得到矩阵形状为197×768;
步骤4:Layer Normalization后提取分类标志位,经过MLPHead获取给定类别范围的概率分布,这里面最大的概率对应的类别即为预测的类别。
上述现有技术中的小波散射网络和ViT方案的缺点包括:小波散射网络是一种基于小波变换的非反馈式神经网络,该网络作为特征提取器具有如非扩张性、微小形变稳定性、平移不变性的良好性质,经过了严格的数学推导和理论证明。但是小波散射网络的滤波器权重固定,无法从数据中学习,因此只能捕捉到如平移、旋转等刚性变换,而无法处理更复杂的变换。
ViT是一种基于Transformer的用于处理图像分类任务的网络模型,具有学习长距离依赖能力强、多模态融合能力强、更具可解释性等优点,但是ViT在将图像数据转换成序列数据时,对图像的先切块后展平的操作必然伴随着图像内部结构的破坏。
发明内容
本发明的实施例提供了一种基于小波散射网络和ViT的图像分类方法,以实现有效地对图像进行分类。
为了实现上述目的,本发明采取了如下技术方案。
一种基于小波散射网络和ViT的图像分类方法,包括:
对图像数据进行预处理,获取带标签的预处理后的图像数据;
构建基于小波散射网络和ViT的分类模型ScatViT,设定模型参数;
设定训练参数,利用预处理后的图像数据训练分类模型ScatViT,得到训练好的分类模型ScatViT;
利用训练好的分类模型ScatViT对待分类图像进行分类处理。
优选地,所述的对图像数据进行预处理,获取带标签的预处理后的图像数据,包括:
对图像数据集进行划分,将图像数据集按19:1的比例均匀分为训练集和验证集,验证集中的每个类别的图片数量相同,将每一张图片按通道维度进行归一化处理,所述数据集包括cifar-10数据集和cifar-100数据集。
优选地,所述的构建基于小波散射网络和ViT的分类模型ScatViT,设定模型参数,包括:
将ViT的Patch Embedding模块替换为小波散射网络ScatNet,使用小波散射网络提取图像特征,利用改进后的小波散射网络和ViT构建分类模型ScatViT,其由ScatEmbedding、Encoder和MLP Head三部分组成,设定模型所涉及的参数包括:小波散射角度参数L=6,尺度参数J=2,最大路径长度M=2,嵌入层维度大小D=768,Encoder的深度S=12,Multi-Head Attention中的head数量H=12。
优选地,所述的Scat Embedding通过小波散射网络将待分类的二维图像数据转化为特征图序列,通过线性映射将特征图序列投影到高维空间,添加一个分类标志向量,以用于最终输出概率分布,添加一个可学习的位置编码矩阵,以用于附加位置信息;
Encoder由Multi-Head Attention与MLP Block构成,每个子层内部均使用Residual Connection,同时每个子层末端使用Layer Normalization,将Multi-HeadAttention的操作记为MSA,MLP Block的操作记为MLP。Multi-Head Attention是指将参数映射到不同子空间,分别进行注意力计算,最终将各个结果进行拼接;
MLP Head将提取的分类标志向量通过线性变换与激活函数的组合,得到待输出的类别概率分布。
优选地,所述的设定训练参数,利用预处理后的图像数据训练分类模型ScatViT,得到训练好的分类模型ScatViT,包括:
步骤3.1,对已构建好的分类模型ScatViT中的网络参数进行初始化,输入训练数据集;
步骤3.2,使用小波散射网络作为图像特征提取器,通过Scat Embedding中的小波散射网络提取多尺度、多方向的图像特征;
步骤3.3,在小波散射网络所提取的图像特征的基础上,将图像特征展平并投影到更高的维度,之后在图像特征中添加类别标记向量与可学习的位置编码矩阵,将改进后的图像特征输入到Encoder中来学习距离依赖关系;
步骤3.4,将Encoder的输出进行层标准化Layer Normalization后提取所添加的类别标记向量,将类别标记向量通过多层感知机MLP得到类别概率分布;
步骤3.5,根据得到的类别概率分布和真实标签计算交叉熵损失,使用梯度下降法更新网络参数,相关公式为:
其中num为计算样本数量,num_classes为类别数量,yic为符号函数,类别与真实标签相等时取值为1否则为0;pic是样本i属于c类的预测概率,θ是待更新参数,η是学习率,是Loss关于θ的梯度;
步骤3.6,所有训练集数据都处理完成后,输入验证集数据,计算分类准确率,回到步骤3.2迭代进行,直到到达设定的最大迭代轮数;
步骤3.7,选取验证集准确率最高的模型作为训练好的分类模型ScatViT。
由上述本发明的实施例提供的技术方案可以看出,本发明实施例提出了将图像切块操作改为使用小波散射网络提取图像特征的ScatViT模型,该模型改进了小波散射网络的由滤波器权重固定导致的无法从数据中学习的缺陷,修复由于切块操作所丢失的部分信息,并排除与图像分类无关信息的干扰,能更准确地表达图像的特征信息。
附图说明
为了更清楚地说明本发明实施例的技术方案,下面将对实施例描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1为本发明实施例提供的一种基于小波散射网络和ViT的图像分类方法的处理流程图。
图2为本发明实施例提供的一种分类模型ScatViT的结构对比图。
图3为本发明实施例提供的一种cifar-10上的实验结果。
图4为本发明实施例提供的一种cifar-100上的实验结果。
具体实施方式
下面详细描述本发明的实施方式,所述实施方式的示例在附图中示出,其中自始至终相同或类似的标号表示相同或类似的元件或具有相同或类似功能的元件。下面通过参考附图描述的实施方式是示例性的,仅用于解释本发明,而不能解释为对本发明的限制。
本技术领域技术人员可以理解,除非特意声明,这里使用的单数形式“一”、“一个”、“所述”和“该”也可包括复数形式。应该进一步理解的是,本发明的说明书中使用的措辞“包括”是指存在所述特征、整数、步骤、操作、元件和/或组件,但是并不排除存在或添加一个或多个其他特征、整数、步骤、操作、元件、组件和/或它们的组。应该理解,当我们称元件被“连接”或“耦接”到另一元件时,它可以直接连接或耦接到其他元件,或者也可以存在中间元件。此外,这里使用的“连接”或“耦接”可以包括无线连接或耦接。这里使用的措辞“和/或”包括一个或更多个相关联的列出项的任一单元和全部组合。
本技术领域技术人员可以理解,除非另外定义,这里使用的所有术语(包括技术术语和科学术语)具有与本发明所属领域中的普通技术人员的一般理解相同的意义。还应该理解的是,诸如通用字典中定义的那些术语应该被理解为具有与现有技术的上下文中的意义一致的意义,并且除非像这里一样定义,不会用理想化或过于正式的含义来解释。
为便于对本发明实施例的理解,下面将结合附图以几个具体实施例为例做进一步的解释说明,且各个实施例并不构成对本发明实施例的限定。
为了改进小波散射网络的由滤波器权重固定导致的无法从数据中学习的缺陷,以及ViT的由对图像简单切块的操作导致的破坏图像内部结构的缺陷,本发明提出了一种基于小波散射网络和ViT的图像分类方法ScatViT。该方法首先利用小波散射网络提取多尺度、多方向的图像特征,获得具有局部平移不变性、微小形变稳定性、非扩张性的特征图,之后在获得的特征图的基础上再利用ViT进行图像分类。
为本发明实施例提供的一种基于小波散射网络和ViT的图像分类方法的处理流程如图1所示,包括以下步骤:
步骤1:对图像数据进行预处理,获取带标签的预处理后的图像数据;
步骤2:构建基于小波散射网络和ViT的分类模型ScatViT,设定分类模型ScatViT的参数;
步骤3:设定训练参数,利用预处理后的图像数据训练分类模型ScatViT;利用验证数据集对分类模型ScatViT的图像分类性能进行评估,得到评估合格的训练好的分类模型ScatViT;
步骤4:将待分类图像输入到训练好的分类模型ScatViT,分类模型ScatViT输出所述待分类图像的分类结果。
进一步地,上述步骤1具体包括:
本发明使用cifar-10数据集和cifar-100数据集进行训练和测试。以cifar-10数据集为例,该数据集共有10个不同类别,包含了训练集数据与验证集数据两部分,分别有50000张和10000张图片,图片形状为32×32×3。首先对数据集进行划分,将原始的图像数据集按19:1的比例均匀分为训练集和验证集,即划分后的训练集包含47500张图片,验证集包含2500张图片,且验证集中的每个类别的图片数量相同。之后将每一张图片按通道维度进行归一化处理。
进一步地,上述步骤2具体包括:
分类模型ScatViT由小波散射网络和ViT改进而来,将ViT的Patch Embedding模块替换为小波散射网络ScatNet,即由原来的对图像的简单切块操作改为使用小波散射网络提取图像特征。分类模型ScatViT由Scat Embedding、Encoder和MLP Head三部分组成,其结构图如图2所示。设定分类模型ScatViT中所涉及的参数包括:小波散射角度参数L=6,尺度参数J=2,最大路径长度M=2,嵌入层维度大小D=768,Encoder的深度S=12,Multi-HeadAttention中的head数量H=12。
Scat Embedding通过小波散射网络将待分类的二维图像数据转化为特征图序列,然后通过线性映射将特征图序列投影到高维空间,之后添加一个分类标志向量,以用于最终输出概率分布,添加一个可学习的位置编码矩阵,以用于附加位置信息。Scat Embedding中使用的小波函数表达式如下:
ψ(u)=(eiuξ-β)φ(u)
其中φ(u)是高斯均值滤波,σ2是二维高斯分布的参数,ψ(u)是Morlet小波,由φ(u)变换而来,i是虚数单位,ξ是变换的参数,β<<1是可调整的以满足∫ψ(u)=0,在本发明所有的数值实验中,σ=0.8,对ψ进行2j的尺度缩放与r的旋转,可得到小波如下:
ψλ(u)=2-2jψ(2-jr-1u)
其中j和r分别是尺度参数和角度参数,j∈Z确定尺度,r∈G确定方向,G是平面旋转群,λ=2-jr∈2-Z×G=ΛJ。对于图像x(u)和有序路径p=(λ1,λ2,...,λm),其对应的小波散射变换为:
其中,SI[p]是沿特定路径小波散射变换算子,它将图像沿路径p的变换得到图像特征,下标J代表最大尺度,即尺度参数j的最大值。 是将高斯均值滤波经过尺度参数为J的缩放得到的函数,*是卷积操作,将长度为m的有序路径p组成的集合记为所有长度的路径组成的集合记为即:
遍历所有可能的路径然后拼接即可得到最终的输出结果:
将小波散射网络的输出结果作为ViT模型的输入是出于以下几点考虑:
第一,小波散射网络作为图像特征提取器具有背景技术中提到的局部平移不变性、非扩张性、微小形变稳定性,当小波散射网络的输入发生微小扰动时,小波散射网络的输出不会发生大的变化,也即具有较强的稳健性;
第二,使用小波散射网络作为图像特征提取器,可以使ViT模型的输入从直接的图像变为处理过后的图像特征,对ViT模型的图像分类起到辅助的作用,相比直接将原始图像作为输入的端到端模型能取得更好的分类效果;
第三,通过对小波散射网络的尺度参数和ViT模型的切块大小进行调整,可以使得特征图的大小保持一致,这样在对两个模型进行最终的比较时,能确保模型的优越性来源于模型结构而非模型规模。
是将图像映射成最终结果的算子,由遍历所有长度得到,也即由SI[p]遍历所有可能的路径得到。在此结果的基础上,展平并投影到高维空间,添加一个分类标志向量用于最终输出概率分布以及一个可学习的位置编码矩阵用于附加位置信息,公式为:
z0=[Xclass;X1E;…;XNE]+Epos
其中,Xclass是分类标志向量,Xi是之前利用小波散射网络得到的特征图的第i个位置序列,E是投影矩阵,Epos是位置信息矩阵。需要注意的是,Xclass、Epos、E都是可学习的,这意味着它们是作为网络参数而存在的,随着不断训练它们的值不断更新。
Encoder由Multi-Head Attention与MLP Block构成,每个子层内部均使用Residual Connection,同时每个子层末端使用Layer Normalization。将Multi-HeadAttention的操作记为MSA,MLP Block的操作记为MLP。Multi-Head Attention是指将参数映射到不同子空间,分别进行注意力计算,最终将各个结果进行拼接,它能够使各个独立的头分别关注不同的信息,例如全局信息、局部信息,所以可以寻找数据之间不同角度的关联,其计算公式为:
MultiHead(Q,K,V)=Concat(head1,...,headh)WO
headi=Attention(QWi Q,KWi K,VWi V)
其中Q是查询矩阵,K是关键字矩阵,V是值矩阵,Wi Q是查询矩阵第i头的参数矩阵,Wi K是关键字矩阵第i头的参数矩阵,Wi V是值矩阵第i头的参数矩阵,Wo是参数矩阵,Concat是连接操作,dk是查询矩阵的列数。在本模型中,Q、K、V三者相等,都是经过层标准化的上述提取的小波散射特征或者迭代中间值,即LN(zl-1),l=1,...,S。
MLP Block将结果投影到高维,之后通过激活函数,最后再降低到原始维度,公式为:
MLP(X)=T·GELU(KX)其中,T和K是可学习的投影矩阵,它们作为网络参数而存在,随着不断训练值不断变化,GELU是用于引入非线性的激活函数,其表达式为
整个计算过程为:
z′l=MSA(LN(zl-1))+zl-1,l=1,...,S
zl=MLP(LN(z′l))+z′l
其中MSA指的是Multi-Head Attention的操作,MLP指的是MLP Block的操作。由于Encoder的深度是S,所以需要将输入z0先经过Layer Normalization再经过MSA,之后再经过Layer Normalization后经过MLP,如此循环S次,最后提取分类标志向量进行LayerNormalization得到最终结果。LN是Layer Normalization,是zs的第0维度即提取的分类标志向量。
MLP Head将提取的分类标志向量通过线性变换与激活函数的组合,得到待输出的类别概率分布,公式为:
MLPHead(y)=tanh(yW)R
进一步地,上述步骤3具体包括:
输入图像大小img_size=32,类别总数numclasses=10,最大迭代轮次epochs=200,初始学习率base_lr=0.002,每个Batch包含的图像数量batch_size=32,Dropout层的几率dropout_ratio=0.3。
数据在ScatViT模型中的计算过程为:使用小波散射网络作为图像特征提取器,将图像输入已设置好参数的小波散射网络,得到的特征图大小为8×8×48,使各个图像块展平成一维向量,再进行线性投影变换,得到矩阵形状为64×768;添加分类标志位以及位置编码,得到矩阵形状为65×768;输入到编码器,依次通过Multi-Head Attention与MLPBlock,得到矩阵形状为65×768;标准化后提取形状为1×768的分类标志向量,输入到最后的MLP Mead,得到形状为1×10的类别概率分布。
步骤3.1,对已构建好的ScatViT中的网络参数进行初始化,输入训练数据集。
步骤3.2,通过Scat Embedding中的小波散射网络提取图像多尺度、多方向的特征。
步骤3.3,在所提取的图像特征的基础上,展平并投影到更高的维度,之后添加类别标记向量与可学习的位置编码矩阵,输入到Encoder中来学习距离依赖关系。
步骤3.4,将Encoder的输出进行Layer Normalization后提取所添加的类别标记向量,将其通过MLP得到类别概率分布。
步骤3.5,根据得到的类别概率分布和真实标签计算交叉熵损失,使用梯度下降法更新网络参数。相关公式为:
其中num为计算样本数量,num_classes为类别数量,yic为符号函数,类别与真实标签相等时取值为1否则为0;pic是样本i属于c类的预测概率,θ是待更新参数,η是学习率,是Loss关于θ的梯度。
步骤3.6,所有训练集数据都处理完成后,输入验证集并计算分类准确率,回到步骤3.2迭代进行,直到到达设定的最大迭代轮数200。
步骤3.7,选取验证集准确率最高的模型作为最优的分类模型ScatViT。
输出:训练集准确率曲线、验证集准确率曲线、最优模型。
cifar-10和cifar-100上的实验结果,包括训练集和验证集的准确率曲线、损失函数曲线,分别如图3和图4所示。可在损失函数曲线中看到ScatViT模型的收敛速度更快,在两个小规模数据集上ScatViT模型收敛均要早于ViT模型约50个epoch。
下表列出了ScatViT和ViT在cifar-10数据集、cifar-100数据集上的测试集准确率和模型参数量。
从上表可以看到,首先在参数量上,两个模型相差不大,这是由于模型超参数基本相同,只是网络结构有略微的差异。除此之外,无论在cifar-10数据集还是cifar-100数据集上,ScatViT模型的测试集准确率高于ViT模型,在cifar-10数据集上高5.4%,在cifar-100数据集上高3.6%。
综上所述,本发明实施例结合小波散射网络和ViT两个模型,提出了将图像切块操作改为使用小波散射网络提取图像特征的ScatViT模型,该模型改进了小波散射网络的由滤波器权重固定导致的无法从数据中学习的缺陷以及ViT的由对图像的简单切块的操作导致的破坏图像内部结构的缺陷,修复由于切块操作所丢失的部分信息,并排除与图像分类无关信息的干扰,使用小波散射网络相比简单的切块操作更能准确表达图像的特征信息。
本发明提出的ScatViT模型在小规模数据集、有限计算资源的情形下的图像分类性能相比ViT的图像分类性能更好,这体现在两方面,其一是ScatViT模型的测试集准确率更高,在cifar-10数据集上高5.4%,在cifar-100数据集上高3.6%;其二是ScatViT模型的收敛速度更快,可在损失函数曲线中看到在两个小规模数据集上ScatViT模型收敛均要早于ViT模型约50个epoch。本发明涉及的ScatViT模型相比同规模的ViT模型,在小规模数据集上的收敛速度更快、准确率更高。
本发明将ViT模型中的图像切块操作改为使用小波散射网络提取图像特征,修复由于切块操作所丢失的部分信息,并排除与图像分类无关信息的干扰,获得对于分类更为有效的特征表达。同时,ScatViT模型不仅可应用在图像分类领域,也可应用于计算机视觉中的其他领域,其应用领域非常广泛。
本领域普通技术人员可以理解:附图只是一个实施例的示意图,附图中的模块或流程并不一定是实施本发明所必须的。
通过以上的实施方式的描述可知,本领域的技术人员可以清楚地了解到本发明可借助软件加必需的通用硬件平台的方式来实现。基于这样的理解,本发明的技术方案本质上或者说对现有技术做出贡献的部分可以以软件产品的形式体现出来,该计算机软件产品可以存储在存储介质中,如ROM/RAM、磁碟、光盘等,包括若干指令用以使得一台计算机设备(可以是个人计算机,服务器,或者网络设备等)执行本发明各个实施例或者实施例的某些部分所述的方法。
本说明书中的各个实施例均采用递进的方式描述,各个实施例之间相同相似的部分互相参见即可,每个实施例重点说明的都是与其他实施例的不同之处。尤其,对于装置或系统实施例而言,由于其基本相似于方法实施例,所以描述得比较简单,相关之处参见方法实施例的部分说明即可。以上所描述的装置及系统实施例仅仅是示意性的,其中所述作为分离部件说明的单元可以是或者也可以不是物理上分开的,作为单元显示的部件可以是或者也可以不是物理单元,即可以位于一个地方,或者也可以分布到多个网络单元上。可以根据实际的需要选择其中的部分或者全部模块来实现本实施例方案的目的。本领域普通技术人员在不付出创造性劳动的情况下,即可以理解并实施。
以上所述,仅为本发明较佳的具体实施方式,但本发明的保护范围并不局限于此,任何熟悉本技术领域的技术人员在本发明揭露的技术范围内,可轻易想到的变化或替换,都应涵盖在本发明的保护范围之内。因此,本发明的保护范围应该以权利要求的保护范围为准。
Claims (5)
1.一种基于小波散射网络和ViT的图像分类方法,其特征在于,包括:
对图像数据进行预处理,获取带标签的预处理后的图像数据;
构建基于小波散射网络和ViT的分类模型ScatViT,设定模型参数;
设定训练参数,利用预处理后的图像数据训练分类模型ScatViT,得到训练好的分类模型ScatViT;
利用训练好的分类模型ScatViT对待分类图像进行分类处理。
2.根据权利要求1所述的方法,其特征在于,所述的对图像数据进行预处理,获取带标签的预处理后的图像数据,包括:
对图像数据集进行划分,将图像数据集按19:1的比例均匀分为训练集和验证集,验证集中的每个类别的图片数量相同,将每一张图片按通道维度进行归一化处理,所述数据集包括cifar-10数据集和cifar-100数据集。
3.根据权利要求1所述的方法,其特征在于,所述的构建基于小波散射网络和ViT的分类模型ScatViT,设定模型参数,包括:
将ViT的Patch Embedding模块替换为小波散射网络ScatNet,使用小波散射网络提取图像特征,利用改进后的小波散射网络和ViT构建分类模型ScatViT,其由Scat Embedding、Encoder和MLP Head三部分组成,设定模型所涉及的参数包括:小波散射角度参数L=6,尺度参数J=2,最大路径长度M=2,嵌入层维度大小D=768,Encoder的深度S=12,Multi-Head Attention中的head数量H=12。
4.根据权利要求3所述的方法,其特征在于,所述的Scat Embedding通过小波散射网络将待分类的二维图像数据转化为特征图序列,通过线性映射将特征图序列投影到高维空间,添加一个分类标志向量,以用于最终输出概率分布,添加一个可学习的位置编码矩阵,以用于附加位置信息;
Encoder由Multi-Head Attention与MLP Block构成,每个子层内部均使用ResidualConnection,同时每个子层末端使用Layer Normalization,将Multi-Head Attention的操作记为MSA,MLP Block的操作记为MLP。Multi-Head Attention是指将参数映射到不同子空间,分别进行注意力计算,最终将各个结果进行拼接;
MLP Head将提取的分类标志向量通过线性变换与激活函数的组合,得到待输出的类别概率分布。
5.根据权利要求3或者4所述的方法,其特征在于,所述的设定训练参数,利用预处理后的图像数据训练分类模型ScatViT,得到训练好的分类模型ScatViT,包括:
步骤3.1,对已构建好的分类模型ScatViT中的网络参数进行初始化,输入训练数据集;
步骤3.2,使用小波散射网络作为图像特征提取器,通过Scat Embedding中的小波散射网络提取多尺度、多方向的图像特征;
步骤3.3,在小波散射网络所提取的图像特征的基础上,将图像特征展平并投影到更高的维度,之后在图像特征中添加类别标记向量与可学习的位置编码矩阵,将改进后的图像特征输入到Encoder中来学习距离依赖关系;
步骤3.4,将Encoder的输出进行层标准化Layer Normalization后提取所添加的类别标记向量,将类别标记向量通过多层感知机MLP得到类别概率分布;
步骤3.5,根据得到的类别概率分布和真实标签计算交叉熵损失,使用梯度下降法更新网络参数,相关公式为:
其中num为计算样本数量,num_classes为类别数量,yic为符号函数,类别与真实标签相等时取值为1否则为0;pic是样本i属于c类的预测概率,θ是待更新参数,η是学习率,是Loss关于θ的梯度;
步骤3.6,所有训练集数据都处理完成后,输入验证集数据,计算分类准确率,回到步骤3.2迭代进行,直到到达设定的最大迭代轮数;
步骤3.7,选取验证集准确率最高的模型作为训练好的分类模型ScatViT。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202211089518.8A CN115937567B (zh) | 2022-09-07 | 2022-09-07 | 一种基于小波散射网络和ViT的图像分类方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202211089518.8A CN115937567B (zh) | 2022-09-07 | 2022-09-07 | 一种基于小波散射网络和ViT的图像分类方法 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN115937567A true CN115937567A (zh) | 2023-04-07 |
CN115937567B CN115937567B (zh) | 2023-07-07 |
Family
ID=86654621
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202211089518.8A Active CN115937567B (zh) | 2022-09-07 | 2022-09-07 | 一种基于小波散射网络和ViT的图像分类方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN115937567B (zh) |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117974670A (zh) * | 2024-04-02 | 2024-05-03 | 齐鲁工业大学(山东省科学院) | 一种融合散射网络的图像分析方法、装置、设备及介质 |
Citations (9)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US20200226427A1 (en) * | 2015-06-05 | 2020-07-16 | Kepler Vision Technologies Bv | Deep receptive field networks |
WO2021132633A1 (ja) * | 2019-12-26 | 2021-07-01 | 公益財団法人がん研究会 | Aiを用いた病理診断支援方法、及び支援装置 |
US20210241041A1 (en) * | 2020-01-31 | 2021-08-05 | Element Ai Inc. | Method of and system for joint data augmentation and classification learning |
CN113887610A (zh) * | 2021-09-29 | 2022-01-04 | 内蒙古工业大学 | 基于交叉注意力蒸馏Transformer的花粉图像分类方法 |
US20220036564A1 (en) * | 2020-08-03 | 2022-02-03 | Korea Advanced Institute Of Science And Technology | Method of classifying lesion of chest x-ray radiograph based on data normalization and local patch and apparatus thereof |
CN114332039A (zh) * | 2021-12-30 | 2022-04-12 | 东北电力大学 | 一种光伏板积灰浓度识别网络、系统及方法 |
CN114445366A (zh) * | 2022-01-26 | 2022-05-06 | 沈阳派得林科技有限责任公司 | 基于自注意力网络的长输管道射线影像缺陷智能识别方法 |
CN114758360A (zh) * | 2022-04-24 | 2022-07-15 | 北京医准智能科技有限公司 | 一种多模态图像分类模型训练方法、装置及电子设备 |
CN114966696A (zh) * | 2021-12-23 | 2022-08-30 | 昆明理工大学 | 一种基于Transformer的跨模态融合目标检测方法 |
-
2022
- 2022-09-07 CN CN202211089518.8A patent/CN115937567B/zh active Active
Patent Citations (9)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US20200226427A1 (en) * | 2015-06-05 | 2020-07-16 | Kepler Vision Technologies Bv | Deep receptive field networks |
WO2021132633A1 (ja) * | 2019-12-26 | 2021-07-01 | 公益財団法人がん研究会 | Aiを用いた病理診断支援方法、及び支援装置 |
US20210241041A1 (en) * | 2020-01-31 | 2021-08-05 | Element Ai Inc. | Method of and system for joint data augmentation and classification learning |
US20220036564A1 (en) * | 2020-08-03 | 2022-02-03 | Korea Advanced Institute Of Science And Technology | Method of classifying lesion of chest x-ray radiograph based on data normalization and local patch and apparatus thereof |
CN113887610A (zh) * | 2021-09-29 | 2022-01-04 | 内蒙古工业大学 | 基于交叉注意力蒸馏Transformer的花粉图像分类方法 |
CN114966696A (zh) * | 2021-12-23 | 2022-08-30 | 昆明理工大学 | 一种基于Transformer的跨模态融合目标检测方法 |
CN114332039A (zh) * | 2021-12-30 | 2022-04-12 | 东北电力大学 | 一种光伏板积灰浓度识别网络、系统及方法 |
CN114445366A (zh) * | 2022-01-26 | 2022-05-06 | 沈阳派得林科技有限责任公司 | 基于自注意力网络的长输管道射线影像缺陷智能识别方法 |
CN114758360A (zh) * | 2022-04-24 | 2022-07-15 | 北京医准智能科技有限公司 | 一种多模态图像分类模型训练方法、装置及电子设备 |
Non-Patent Citations (3)
Title |
---|
ALEXEY DOSOVITSKIY ET AL.: "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale", 《ARXIV[CS.CV]》, pages 1 - 22 * |
JOAN BRUNA ET AL.: "Invariant Scattering Convolution Networks", 《IEEE TRANSACTIONS ON PATTERN ANALYSIS AND MACHINE INTELLIGENCE》, vol. 35, no. 8, pages 1872 - 1886, XP011515339, DOI: 10.1109/TPAMI.2012.230 * |
曹琨: "基于Transformer框架的雷达遥感图像序列特征提取及分类研究", 《中国优秀硕士学位论文全文数据库工程科技Ⅱ辑》, no. 1, pages 028 - 137 * |
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117974670A (zh) * | 2024-04-02 | 2024-05-03 | 齐鲁工业大学(山东省科学院) | 一种融合散射网络的图像分析方法、装置、设备及介质 |
CN117974670B (zh) * | 2024-04-02 | 2024-06-04 | 齐鲁工业大学(山东省科学院) | 一种融合散射网络的图像分析方法、装置、设备及介质 |
Also Published As
Publication number | Publication date |
---|---|
CN115937567B (zh) | 2023-07-07 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN106650813B (zh) | 一种基于深度残差网络和lstm的图像理解方法 | |
Gholamalinezhad et al. | Pooling methods in deep neural networks, a review | |
CN110490946B (zh) | 基于跨模态相似度和生成对抗网络的文本生成图像方法 | |
JP7193252B2 (ja) | 画像の領域のキャプション付加 | |
Passalis et al. | Training lightweight deep convolutional neural networks using bag-of-features pooling | |
CN105138973B (zh) | 人脸认证的方法和装置 | |
CN111160343B (zh) | 一种基于Self-Attention的离线数学公式符号识别方法 | |
EP3029606A2 (en) | Method and apparatus for image classification with joint feature adaptation and classifier learning | |
CN111723220A (zh) | 基于注意力机制和哈希的图像检索方法、装置及存储介质 | |
CN106339753A (zh) | 一种有效提升卷积神经网络稳健性的方法 | |
CN110188827A (zh) | 一种基于卷积神经网络和递归自动编码器模型的场景识别方法 | |
JP7252009B2 (ja) | 人工ニューラルネットワークを用いたocrシステムのための、線認識最大-最小プーリングを用いたテキスト画像の処理 | |
CN109492610B (zh) | 一种行人重识别方法、装置及可读存储介质 | |
CN112163114B (zh) | 一种基于特征融合的图像检索方法 | |
CN116075820A (zh) | 用于搜索图像数据库的方法、非暂时性计算机可读存储介质和设备 | |
CN115937567B (zh) | 一种基于小波散射网络和ViT的图像分类方法 | |
CN115131607A (zh) | 图像分类方法及装置 | |
Davoudi et al. | Ancient document layout analysis: Autoencoders meet sparse coding | |
CN110135363B (zh) | 基于判别词典嵌入行人图像检索方法、系统、设备及介质 | |
CN115640418B (zh) | 基于残差语义一致性跨域多视角目标网站检索方法及装置 | |
Liu et al. | Multi-digit recognition with convolutional neural network and long short-term memory | |
CN114944002B (zh) | 文本描述辅助的姿势感知的人脸表情识别方法 | |
CN116089646A (zh) | 一种基于显著性捕获机制的无人机图像哈希检索方法 | |
CN113449751A (zh) | 基于对称性和群论的物体-属性组合图像识别方法 | |
CN117994861B (zh) | 一种基于多模态大模型clip的视频动作识别方法及装置 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |