CN117036901A - 一种基于视觉自注意力模型的小样本微调方法 - Google Patents

一种基于视觉自注意力模型的小样本微调方法 Download PDF

Info

Publication number
CN117036901A
CN117036901A CN202310867841.1A CN202310867841A CN117036901A CN 117036901 A CN117036901 A CN 117036901A CN 202310867841 A CN202310867841 A CN 202310867841A CN 117036901 A CN117036901 A CN 117036901A
Authority
CN
China
Prior art keywords
self
attention model
fine tuning
norm
visual self
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202310867841.1A
Other languages
English (en)
Inventor
王鹏
付铭禹
李煜堃
索伟
张艳宁
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Northwestern Polytechnical University
Original Assignee
Northwestern Polytechnical University
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Northwestern Polytechnical University filed Critical Northwestern Polytechnical University
Priority to CN202310867841.1A priority Critical patent/CN117036901A/zh
Publication of CN117036901A publication Critical patent/CN117036901A/zh
Pending legal-status Critical Current

Links

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/82Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/0895Weakly supervised learning, e.g. semi-supervised or self-supervised learning
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/09Supervised learning
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/764Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/77Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
    • G06V10/774Generating sets of training patterns; Bootstrap methods, e.g. bagging or boosting

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Evolutionary Computation (AREA)
  • Computing Systems (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • General Health & Medical Sciences (AREA)
  • Health & Medical Sciences (AREA)
  • Software Systems (AREA)
  • Multimedia (AREA)
  • Medical Informatics (AREA)
  • Databases & Information Systems (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Computational Linguistics (AREA)
  • Data Mining & Analysis (AREA)
  • Molecular Biology (AREA)
  • General Engineering & Computer Science (AREA)
  • Mathematical Physics (AREA)
  • Image Analysis (AREA)

Abstract

本发明公开了一种基于视觉自注意力模型的小样本微调方法,采用在大规模数据集上进行预训练和在小样本任务上进行微调的流程,视觉自注意力模型被用作主干网络,同时构建一个可学习的转换模块norm adapter,由两个向量组成,用于校正原始视觉自注意力模型归一化层的增益和偏置,norm adapter位于视觉自注意力模型ViT的所有归一化层之后,通过逐元素的乘法和加法实现;在预训练期间,使用大规模数据集上以全监督或自监督方式训练的主干网络;在微调过程中,采用原型网络ProtoNet分类头。本发明计算简便,通过逐元素相乘和相加即可实现,因此占用的存储和计算资源比较少,有利于预训练模型投入实际应用场景。

Description

一种基于视觉自注意力模型的小样本微调方法
技术领域
本发明属于模式识别技术领域,具体涉及一种基于视觉自注意力模型的小样本微调方法。
背景技术
预训练模型被广泛用于自然语言处理(NLP)和计算机视觉(CV)领域,并大大改善了下游任务的性能。因此,预训练-微调范式已经被大家普遍接受,特别是在视觉注意力模型(ViT)兴起之后。由于预训练模型的规模很大,如何在有限的计算和存储开销下将预训练知识有效地迁移到下游任务中,仍在研究之中。已经有一些方法被提出来解决这个问题,称为参数高效微调(PEFT)方法,如:Adapter,bias-tuning,vis ual prompt tuning等等。
然而,关于参数高效微调方法在小样本图像分类中的研究却很少。小样本图像分类是小样本学习(few-shot learning)的一项基本任务。小样本学习可以通过模仿人类智能,凭借少量样本泛化到全新的概念,以此来扩大深度学习模型的应用范围。在小样本的设置中,测试数据会被划分成许多任务,每个任务由两部分组成:支持集和查询集,支持集包含N*K个带标注样本,即N个类别的数据,每类有K个样本,这样的小样本任务被称为“N-wayK-shot”形式的;查询集包含N*Q数量的样本,用于评估模型。
最近,Shell等人首次将预训练模型引入小样本分类领域。他们采用了预训练、元训练,最后微调的流程。首先模型在大规模数据集上(如ImageNet数据集)进行预训练,然后在目标域的基类数据上进行元训练,最后在微调过程中,使用少量样本对模型的所有参数进行更新(full-tuning)。预训练-元训练-微调的流程极大地提高了模型的性能。然而,用于元训练的目标域的基类数据并不容易获得,大多数情况下,只有极少量的带标注样本可以获取到。因此,在这种情况下无法进行元训练,而凭借少量样本对模型的所有参数进行更新(full-tuning)不能充分利用预训练知识。而且,更新全部参数带来的计算和存储开销很大,严重限制了其应用场景。因此,如何在小样本情况下进行高效微调仍然是一个开放的问题。
发明内容
为了克服现有技术的不足,本发明提供了一种基于视觉自注意力模型的小样本微调方法,采用在大规模数据集上进行预训练和在小样本任务上进行微调的流程,视觉自注意力模型被用作主干网络,同时构建一个可学习的转换模块norm adapter,由两个向量组成,用于校正原始视觉自注意力模型归一化层的增益和偏置,norm adapter位于视觉自注意力模型ViT的所有归一化层之后,通过逐元素的乘法和加法实现;在预训练期间,使用大规模数据集上以全监督或自监督方式训练的主干网络;在微调过程中,采用原型网络ProtoNet分类头。本发明计算简便,通过逐元素相乘和相加即可实现,因此占用的存储和计算资源比较少,有利于预训练模型投入实际应用场景。
本发明解决其技术问题所采用的技术方案包括如下步骤:
步骤1:构建主干网络;
采用改进的视觉自注意力模型ViT作为主干网络;
原始视觉自注意力模型由一个补丁嵌入层和N个transformer层组成;经过补丁嵌入层,输入图像被编码成一定数量的token向量,再与位置编码相加后,输入的token向量连同CLS token被送入N个transformer层中;最终,经过N个transformer层和一个归一化层LayerNorm后,CLS token用于分类或其他目的;每个transformer层包含两个归一化层LayerNorm,一个MLP块和一个多头自注意力块MHSA;
构建一个可学习的转换模块,由两个向量组成,用于校正原始视觉自注意力模型归一化层LayerNorm的增益gain和偏置bias,将可学习的转换模块称为norm adapter;所述norm adapter位于视觉自注意力模型ViT的所有归一化层之后,通过逐元素的乘法和加法实现;如公式(1)所示,Scale和Shift分别是norm adapter的两个可学习向量,y是归一化层的输出,⊙代表逐元素乘法:
h = Scale ⊙ y + Shift (1)
norm adapter的参数Scale和Shift的结构与归一化层的参数增益gain和偏置bias相同,分别被初始化为全1和全0的向量;在微调时,只有参数Scale和Shift被更新,其他参数在预训练后被冻结,不进行优化;
步骤2:在预训练期间,使用大规模数据集上以全监督或自监督方式训练的主干网络;
步骤3:在微调过程中,采用原型网络ProtoNet分类头;该分类头根据查询图像和原型在嵌入空间中的距离,产生一个概率分布,如公式(2)所示:
其中,fφ是主干网络,将输入编码到特征空间;ck为类别k的原型,是属于类别k的特征的平均值;d是度量函数;具体来说,各个类别的原型是由支持集中每类样本求均值计算出来的,并将进行数据增强后的支持集作为伪查询集,然后由原型和伪查询集之间的余弦距离计算损失,更新参数;
损失函数选择交叉熵损失。
优选地,所述自监督方式,采用DINO和MOCO v3算法在ImageNet-1K数据集上训练主干网络;所述全监督方式,主干网络是在ImageNet-21K数据集上训练得到的。
优选地,所述度量函数使用余弦距离。
本发明的有益效果如下:
(1)本发明作为一种小样本微调方法,更新的参数量小,仅相当于全部微调(full-tuning)所需更新参数量的0.045%,计算简便,通过逐元素相乘和相加即可实现,因此占用的存储和计算资源比较少,有利于预训练模型投入实际应用场景。
(2)本发明在real、clipart、sketch、quickdraw四个数据集上的测试结果明显优于全部微调(full-tuning)、bias-tuning,visual prompt tuning等方法。
附图说明
图1为视觉自注意力模型ViT的transformer层示意图。
图2为加入norm adapter后的transformer层示意图。
具体实施方式
下面结合附图和实施例对本发明进一步说明。
本发明采用了在大规模数据集上进行预训练和在小样本任务上进行微调的流程,没有在目标域的基类数据上进行训练。视觉自注意力模型(ViT)被用作主干网络,一个普通的视觉自注意力模型由一个补丁嵌入层(patch embedding)和N个transformer层组成。经过补丁嵌入层,输入图像被编码成一定数量的token向量,在与位置编码相加后,输入的token向量连同CLS token被送入N个transformer层中。最终,经过N个transformer层和一个归一化层(LayerNorm)后,CLS token用于分类或其他目的。每个transformer层包含两个归一化层(LayerNorm),一个MLP块和一个多头自注意力块(MHSA)。图1为视觉自注意力模型(ViT)的transformer层,对应全部微调方法(Full-tuning),transformer层中的归一化层(LayerNorm),MLP块和多头自注意力块(MHSA)都是可学习的。
本发明提出使用一个可学习的转换模块,由两个向量组成,来校正归一化层(LayerNorm)的增益(gain)和偏置(bias),称为“norm adapter”。“norm adapter”位于视觉自注意力模型(ViT)的所有归一化层之后,以与增益和偏置相同的方式对激活值进行缩放和移位,具体来说,是通过逐元素的乘法和加法实现的,如公式(1)所示,Scale,Shift分别是“norm adapter”的两个可学习向量,y是归一化层的输出,⊙代表逐元素乘法。
h = Scale ⊙ y + Shift (1)
“norm adapter”的参数s1和s2的形状与归一化层的增益(gain)和偏置(bias)相同,分别被初始化为全一和全零的向量,因此,与微调前的原始预训练模型相比,带有“normadapter”的模型在计算结果上没有变化。在微调时,只有“norm adapter”的参数Scale和Shift被更新,其他参数在预训练后被冻结,不进行优化。图2为加入“norm adapter”后的transformer层,对应本发明提出的微调方法,transformer层中只有“norm adapter”的参数Scale、Shift是可学习的。
在预训练期间,使用在大规模数据集上以全监督或自监督方式训练的主干网络。对于自监督算法,采用DINO和MOCO v3算法在ImageNet-1K数据集上训练主干网络;对于全监督算法,主干网络是在ImageNet-21K数据集上训练得到的。
在微调过程中,采用了原型网络(ProtoNet)分类头。该分类头根据查询图像和原型在嵌入空间中的距离,产生一个概率分布,如公式(2)所示:
fφ是主干网络,将输入编码到特征空间。ck为类别k的原型,是属于类别k的特征的平均值。d是度量函数,这里使用的是余弦距离。具体来说,原型是由支持集计算出来的,并将数据增强后的支持集作为伪查询集。然后由原型和伪查询集之间的余弦距离计算损失,更新参数。损失函数选择交叉熵损失(Cross Entropy)。
本发明采用视觉自注意力模型(ViT)作为主干网络,包括ViT-Base/16和ViT-Small/16,对于ViT-Base/16,我们分别采用监督学习方法在ImageNet-21K数据集上训练,采用MOCO-v3算法在ImageNet-1K数据集上训练得到预训练主干网络;对于ViT-Small/16,采用DINO算法在ImageNet-1K数据集上训练。
在下游任务上微调和评估时采用了real、clipart、sketch、quickdraw四个数据集,它们是DomainNet的子数据集,包含相同的类别名。
在微调和评估过程中,采用30-way 5-shot的形式来构建小样本任务,每个任务包含5个类别的数据,每类数据有5张带标注样本和15张查询样本;所有图像均被调整成224*224分辨率大小;用于生成伪查询集的随机数据增强包括颜色抖动、水平翻转和平移;微调过程中有三个超参数比较关键:学习率、迭代次数和优化器,由于每个任务中的样本有限,最终性能对超参数的选择比较敏感,所以对于各种情况,根据验证集上50个任务的平均准确率来选择超参数,优化器从Adam或SGD中选择,学习率和迭代次数从经验范围内选择,分别为[1e-1,1e-2,1e-3,1e-4,1e-5,1e-6]和[20,50,80,100];最后,从测试集中随机选取600个任务进行评估,计算平均精度作为最终结果。所有的实验均采用固定的随机数种子。

Claims (3)

1.一种基于视觉自注意力模型的小样本微调方法,其特征在于,包括如下步骤:
步骤1:构建主干网络;
采用改进的视觉自注意力模型ViT作为主干网络;
原始视觉自注意力模型由一个补丁嵌入层和N个transformer层组成;经过补丁嵌入层,输入图像被编码成一定数量的token向量,再与位置编码相加后,输入的token向量连同CLS token被送入N个transformer层中;最终,经过N个transformer层和一个归一化层LayerNorm后,CLS token用于分类或其他目的;每个transformer层包含两个归一化层LayerNorm,一个MLP块和一个多头自注意力块MHSA;
构建一个可学习的转换模块,由两个向量组成,用于校正原始视觉自注意力模型归一化层LayerNorm的增益gain和偏置bias,将可学习的转换模块称为norm adapter;所述normadapter位于视觉自注意力模型ViT的所有归一化层之后,通过逐元素的乘法和加法实现;如公式(1)所示,Scale和Shift分别是norm adapter的两个可学习向量,y是归一化层的输出,⊙代表逐元素乘法:
h = Scale ⊙ y + Shift (1)
norm adapter的参数Scale和Shift的结构与归一化层的参数增益gain和偏置bias相同,分别被初始化为全1和全0的向量;在微调时,只有参数Scale和Shift被更新,其他参数在预训练后被冻结,不进行优化;
步骤2:在预训练期间,使用大规模数据集上以全监督或自监督方式训练的主干网络;
步骤3:在微调过程中,采用原型网络ProtoNet分类头;该分类头根据查询图像和原型在嵌入空间中的距离,产生一个概率分布,如公式(2)所示:
其中,fφ是主干网络,将输入编码到特征空间;ck为类别k的原型,是属于类别k的特征的平均值;d是度量函数;具体来说,各个类别的原型是由支持集中每类样本求均值计算出来的,并将进行数据增强后的支持集作为伪查询集,然后由原型和伪查询集之间的余弦距离计算损失,更新参数;
损失函数选择交叉熵损失。
2.根据权利要求1所述的一种基于视觉自注意力模型的小样本微调方法,其特征在于,所述自监督方式,采用DINO和MOCO v3算法在ImageNet-1K数据集上训练主干网络;所述全监督方式,主干网络是在ImageNet-21K数据集上训练得到的。
3.根据权利要求1所述的一种基于视觉自注意力模型的小样本微调方法,其特征在于,所述度量函数使用余弦距离。
CN202310867841.1A 2023-07-16 2023-07-16 一种基于视觉自注意力模型的小样本微调方法 Pending CN117036901A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202310867841.1A CN117036901A (zh) 2023-07-16 2023-07-16 一种基于视觉自注意力模型的小样本微调方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202310867841.1A CN117036901A (zh) 2023-07-16 2023-07-16 一种基于视觉自注意力模型的小样本微调方法

Publications (1)

Publication Number Publication Date
CN117036901A true CN117036901A (zh) 2023-11-10

Family

ID=88627066

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202310867841.1A Pending CN117036901A (zh) 2023-07-16 2023-07-16 一种基于视觉自注意力模型的小样本微调方法

Country Status (1)

Country Link
CN (1) CN117036901A (zh)

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN117689044A (zh) * 2024-02-01 2024-03-12 厦门大学 一种适用于视觉自注意力模型的量化方法

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN117689044A (zh) * 2024-02-01 2024-03-12 厦门大学 一种适用于视觉自注意力模型的量化方法

Similar Documents

Publication Publication Date Title
US11450066B2 (en) 3D reconstruction method based on deep learning
US9129222B2 (en) Method and apparatus for a local competitive learning rule that leads to sparse connectivity
CN106845529A (zh) 基于多视野卷积神经网络的影像特征识别方法
CN110728219A (zh) 基于多列多尺度图卷积神经网络的3d人脸生成方法
CN114332578A (zh) 图像异常检测模型训练方法、图像异常检测方法和装置
CN113989100B (zh) 一种基于样式生成对抗网络的红外纹理样本扩充方法
CN109146061A (zh) 神经网络模型的处理方法和装置
CN115222998B (zh) 一种图像分类方法
CN111210382A (zh) 图像处理方法、装置、计算机设备和存储介质
CN117036901A (zh) 一种基于视觉自注意力模型的小样本微调方法
CN113095254A (zh) 一种人体部位关键点的定位方法及系统
CN115311502A (zh) 基于多尺度双流架构的遥感图像小样本场景分类方法
CN113989612A (zh) 基于注意力及生成对抗网络的遥感影像目标检测方法
Wang et al. Global aligned structured sparsity learning for efficient image super-resolution
CN115471016A (zh) 一种基于cisso与daed的台风预测方法
CN117974693B (zh) 图像分割方法、装置、计算机设备和存储介质
CN117333516A (zh) 一种基于光流卷积神经网络的鲁棒性粒子图像测速方法
CN117992919A (zh) 基于机器学习和多气象模态融合的河流洪水预警方法
CN117611838A (zh) 一种基于自适应超图卷积网络的多标签图像分类方法
CN115760670B (zh) 基于网络隐式先验的无监督高光谱融合方法及装置
CN117274664A (zh) 一种视觉认知驱动的小样本图像分类方法、系统及介质
CN116797681A (zh) 渐进式多粒度语义信息融合的文本到图像生成方法及系统
CN116109868A (zh) 基于轻量化神经网络的图像分类模型构建和小样本图像分类方法
Saenz et al. Dimensionality-reduction of climate data using deep autoencoders
CN112991257B (zh) 基于半监督孪生网络的异质遥感图像变化快速检测方法

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination