CN117036901A - 一种基于视觉自注意力模型的小样本微调方法 - Google Patents
一种基于视觉自注意力模型的小样本微调方法 Download PDFInfo
- Publication number
- CN117036901A CN117036901A CN202310867841.1A CN202310867841A CN117036901A CN 117036901 A CN117036901 A CN 117036901A CN 202310867841 A CN202310867841 A CN 202310867841A CN 117036901 A CN117036901 A CN 117036901A
- Authority
- CN
- China
- Prior art keywords
- self
- attention model
- fine tuning
- norm
- visual self
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000000034 method Methods 0.000 title claims abstract description 32
- 230000000007 visual effect Effects 0.000 title claims abstract description 31
- 238000012549 training Methods 0.000 claims abstract description 24
- 238000010606 normalization Methods 0.000 claims abstract description 23
- 239000013598 vector Substances 0.000 claims abstract description 19
- 238000006243 chemical reaction Methods 0.000 claims abstract description 7
- 238000004422 calculation algorithm Methods 0.000 claims description 7
- HDAJUGGARUFROU-JSUDGWJLSA-L MoO2-molybdopterin cofactor Chemical compound O([C@H]1NC=2N=C(NC(=O)C=2N[C@H]11)N)[C@H](COP(O)(O)=O)C2=C1S[Mo](=O)(=O)S2 HDAJUGGARUFROU-JSUDGWJLSA-L 0.000 claims description 3
- 238000004364 calculation method Methods 0.000 abstract description 5
- 230000009286 beneficial effect Effects 0.000 abstract description 4
- 230000006870 function Effects 0.000 description 5
- 238000012360 testing method Methods 0.000 description 3
- 238000011156 evaluation Methods 0.000 description 2
- 238000003058 natural language processing Methods 0.000 description 2
- ORILYTVJVMAKLC-UHFFFAOYSA-N Adamantane Natural products C1C(C2)CC3CC1CC2C3 ORILYTVJVMAKLC-UHFFFAOYSA-N 0.000 description 1
- 238000012935 Averaging Methods 0.000 description 1
- 230000004913 activation Effects 0.000 description 1
- 238000013136 deep learning model Methods 0.000 description 1
- 230000007547 defect Effects 0.000 description 1
- 238000010586 diagram Methods 0.000 description 1
- 238000012854 evaluation process Methods 0.000 description 1
- 238000002474 experimental method Methods 0.000 description 1
- 238000011835 investigation Methods 0.000 description 1
- 238000004091 panning Methods 0.000 description 1
- 238000003909 pattern recognition Methods 0.000 description 1
- 238000011160 research Methods 0.000 description 1
- 238000009966 trimming Methods 0.000 description 1
- 238000010200 validation analysis Methods 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/82—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/0895—Weakly supervised learning, e.g. semi-supervised or self-supervised learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/09—Supervised learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/764—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/77—Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
- G06V10/774—Generating sets of training patterns; Bootstrap methods, e.g. bagging or boosting
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Evolutionary Computation (AREA)
- Computing Systems (AREA)
- Artificial Intelligence (AREA)
- General Physics & Mathematics (AREA)
- General Health & Medical Sciences (AREA)
- Health & Medical Sciences (AREA)
- Software Systems (AREA)
- Multimedia (AREA)
- Medical Informatics (AREA)
- Databases & Information Systems (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Life Sciences & Earth Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Data Mining & Analysis (AREA)
- Molecular Biology (AREA)
- General Engineering & Computer Science (AREA)
- Mathematical Physics (AREA)
- Image Analysis (AREA)
Abstract
本发明公开了一种基于视觉自注意力模型的小样本微调方法,采用在大规模数据集上进行预训练和在小样本任务上进行微调的流程,视觉自注意力模型被用作主干网络,同时构建一个可学习的转换模块norm adapter,由两个向量组成,用于校正原始视觉自注意力模型归一化层的增益和偏置,norm adapter位于视觉自注意力模型ViT的所有归一化层之后,通过逐元素的乘法和加法实现;在预训练期间,使用大规模数据集上以全监督或自监督方式训练的主干网络;在微调过程中,采用原型网络ProtoNet分类头。本发明计算简便,通过逐元素相乘和相加即可实现,因此占用的存储和计算资源比较少,有利于预训练模型投入实际应用场景。
Description
技术领域
本发明属于模式识别技术领域,具体涉及一种基于视觉自注意力模型的小样本微调方法。
背景技术
预训练模型被广泛用于自然语言处理(NLP)和计算机视觉(CV)领域,并大大改善了下游任务的性能。因此,预训练-微调范式已经被大家普遍接受,特别是在视觉注意力模型(ViT)兴起之后。由于预训练模型的规模很大,如何在有限的计算和存储开销下将预训练知识有效地迁移到下游任务中,仍在研究之中。已经有一些方法被提出来解决这个问题,称为参数高效微调(PEFT)方法,如:Adapter,bias-tuning,vis ual prompt tuning等等。
然而,关于参数高效微调方法在小样本图像分类中的研究却很少。小样本图像分类是小样本学习(few-shot learning)的一项基本任务。小样本学习可以通过模仿人类智能,凭借少量样本泛化到全新的概念,以此来扩大深度学习模型的应用范围。在小样本的设置中,测试数据会被划分成许多任务,每个任务由两部分组成:支持集和查询集,支持集包含N*K个带标注样本,即N个类别的数据,每类有K个样本,这样的小样本任务被称为“N-wayK-shot”形式的;查询集包含N*Q数量的样本,用于评估模型。
最近,Shell等人首次将预训练模型引入小样本分类领域。他们采用了预训练、元训练,最后微调的流程。首先模型在大规模数据集上(如ImageNet数据集)进行预训练,然后在目标域的基类数据上进行元训练,最后在微调过程中,使用少量样本对模型的所有参数进行更新(full-tuning)。预训练-元训练-微调的流程极大地提高了模型的性能。然而,用于元训练的目标域的基类数据并不容易获得,大多数情况下,只有极少量的带标注样本可以获取到。因此,在这种情况下无法进行元训练,而凭借少量样本对模型的所有参数进行更新(full-tuning)不能充分利用预训练知识。而且,更新全部参数带来的计算和存储开销很大,严重限制了其应用场景。因此,如何在小样本情况下进行高效微调仍然是一个开放的问题。
发明内容
为了克服现有技术的不足,本发明提供了一种基于视觉自注意力模型的小样本微调方法,采用在大规模数据集上进行预训练和在小样本任务上进行微调的流程,视觉自注意力模型被用作主干网络,同时构建一个可学习的转换模块norm adapter,由两个向量组成,用于校正原始视觉自注意力模型归一化层的增益和偏置,norm adapter位于视觉自注意力模型ViT的所有归一化层之后,通过逐元素的乘法和加法实现;在预训练期间,使用大规模数据集上以全监督或自监督方式训练的主干网络;在微调过程中,采用原型网络ProtoNet分类头。本发明计算简便,通过逐元素相乘和相加即可实现,因此占用的存储和计算资源比较少,有利于预训练模型投入实际应用场景。
本发明解决其技术问题所采用的技术方案包括如下步骤:
步骤1:构建主干网络;
采用改进的视觉自注意力模型ViT作为主干网络;
原始视觉自注意力模型由一个补丁嵌入层和N个transformer层组成;经过补丁嵌入层,输入图像被编码成一定数量的token向量,再与位置编码相加后,输入的token向量连同CLS token被送入N个transformer层中;最终,经过N个transformer层和一个归一化层LayerNorm后,CLS token用于分类或其他目的;每个transformer层包含两个归一化层LayerNorm,一个MLP块和一个多头自注意力块MHSA;
构建一个可学习的转换模块,由两个向量组成,用于校正原始视觉自注意力模型归一化层LayerNorm的增益gain和偏置bias,将可学习的转换模块称为norm adapter;所述norm adapter位于视觉自注意力模型ViT的所有归一化层之后,通过逐元素的乘法和加法实现;如公式(1)所示,Scale和Shift分别是norm adapter的两个可学习向量,y是归一化层的输出,⊙代表逐元素乘法:
h = Scale ⊙ y + Shift (1)
norm adapter的参数Scale和Shift的结构与归一化层的参数增益gain和偏置bias相同,分别被初始化为全1和全0的向量;在微调时,只有参数Scale和Shift被更新,其他参数在预训练后被冻结,不进行优化;
步骤2:在预训练期间,使用大规模数据集上以全监督或自监督方式训练的主干网络;
步骤3:在微调过程中,采用原型网络ProtoNet分类头;该分类头根据查询图像和原型在嵌入空间中的距离,产生一个概率分布,如公式(2)所示:
其中,fφ是主干网络,将输入编码到特征空间;ck为类别k的原型,是属于类别k的特征的平均值;d是度量函数;具体来说,各个类别的原型是由支持集中每类样本求均值计算出来的,并将进行数据增强后的支持集作为伪查询集,然后由原型和伪查询集之间的余弦距离计算损失,更新参数;
损失函数选择交叉熵损失。
优选地,所述自监督方式,采用DINO和MOCO v3算法在ImageNet-1K数据集上训练主干网络;所述全监督方式,主干网络是在ImageNet-21K数据集上训练得到的。
优选地,所述度量函数使用余弦距离。
本发明的有益效果如下:
(1)本发明作为一种小样本微调方法,更新的参数量小,仅相当于全部微调(full-tuning)所需更新参数量的0.045%,计算简便,通过逐元素相乘和相加即可实现,因此占用的存储和计算资源比较少,有利于预训练模型投入实际应用场景。
(2)本发明在real、clipart、sketch、quickdraw四个数据集上的测试结果明显优于全部微调(full-tuning)、bias-tuning,visual prompt tuning等方法。
附图说明
图1为视觉自注意力模型ViT的transformer层示意图。
图2为加入norm adapter后的transformer层示意图。
具体实施方式
下面结合附图和实施例对本发明进一步说明。
本发明采用了在大规模数据集上进行预训练和在小样本任务上进行微调的流程,没有在目标域的基类数据上进行训练。视觉自注意力模型(ViT)被用作主干网络,一个普通的视觉自注意力模型由一个补丁嵌入层(patch embedding)和N个transformer层组成。经过补丁嵌入层,输入图像被编码成一定数量的token向量,在与位置编码相加后,输入的token向量连同CLS token被送入N个transformer层中。最终,经过N个transformer层和一个归一化层(LayerNorm)后,CLS token用于分类或其他目的。每个transformer层包含两个归一化层(LayerNorm),一个MLP块和一个多头自注意力块(MHSA)。图1为视觉自注意力模型(ViT)的transformer层,对应全部微调方法(Full-tuning),transformer层中的归一化层(LayerNorm),MLP块和多头自注意力块(MHSA)都是可学习的。
本发明提出使用一个可学习的转换模块,由两个向量组成,来校正归一化层(LayerNorm)的增益(gain)和偏置(bias),称为“norm adapter”。“norm adapter”位于视觉自注意力模型(ViT)的所有归一化层之后,以与增益和偏置相同的方式对激活值进行缩放和移位,具体来说,是通过逐元素的乘法和加法实现的,如公式(1)所示,Scale,Shift分别是“norm adapter”的两个可学习向量,y是归一化层的输出,⊙代表逐元素乘法。
h = Scale ⊙ y + Shift (1)
“norm adapter”的参数s1和s2的形状与归一化层的增益(gain)和偏置(bias)相同,分别被初始化为全一和全零的向量,因此,与微调前的原始预训练模型相比,带有“normadapter”的模型在计算结果上没有变化。在微调时,只有“norm adapter”的参数Scale和Shift被更新,其他参数在预训练后被冻结,不进行优化。图2为加入“norm adapter”后的transformer层,对应本发明提出的微调方法,transformer层中只有“norm adapter”的参数Scale、Shift是可学习的。
在预训练期间,使用在大规模数据集上以全监督或自监督方式训练的主干网络。对于自监督算法,采用DINO和MOCO v3算法在ImageNet-1K数据集上训练主干网络;对于全监督算法,主干网络是在ImageNet-21K数据集上训练得到的。
在微调过程中,采用了原型网络(ProtoNet)分类头。该分类头根据查询图像和原型在嵌入空间中的距离,产生一个概率分布,如公式(2)所示:
fφ是主干网络,将输入编码到特征空间。ck为类别k的原型,是属于类别k的特征的平均值。d是度量函数,这里使用的是余弦距离。具体来说,原型是由支持集计算出来的,并将数据增强后的支持集作为伪查询集。然后由原型和伪查询集之间的余弦距离计算损失,更新参数。损失函数选择交叉熵损失(Cross Entropy)。
本发明采用视觉自注意力模型(ViT)作为主干网络,包括ViT-Base/16和ViT-Small/16,对于ViT-Base/16,我们分别采用监督学习方法在ImageNet-21K数据集上训练,采用MOCO-v3算法在ImageNet-1K数据集上训练得到预训练主干网络;对于ViT-Small/16,采用DINO算法在ImageNet-1K数据集上训练。
在下游任务上微调和评估时采用了real、clipart、sketch、quickdraw四个数据集,它们是DomainNet的子数据集,包含相同的类别名。
在微调和评估过程中,采用30-way 5-shot的形式来构建小样本任务,每个任务包含5个类别的数据,每类数据有5张带标注样本和15张查询样本;所有图像均被调整成224*224分辨率大小;用于生成伪查询集的随机数据增强包括颜色抖动、水平翻转和平移;微调过程中有三个超参数比较关键:学习率、迭代次数和优化器,由于每个任务中的样本有限,最终性能对超参数的选择比较敏感,所以对于各种情况,根据验证集上50个任务的平均准确率来选择超参数,优化器从Adam或SGD中选择,学习率和迭代次数从经验范围内选择,分别为[1e-1,1e-2,1e-3,1e-4,1e-5,1e-6]和[20,50,80,100];最后,从测试集中随机选取600个任务进行评估,计算平均精度作为最终结果。所有的实验均采用固定的随机数种子。
Claims (3)
1.一种基于视觉自注意力模型的小样本微调方法,其特征在于,包括如下步骤:
步骤1:构建主干网络;
采用改进的视觉自注意力模型ViT作为主干网络;
原始视觉自注意力模型由一个补丁嵌入层和N个transformer层组成;经过补丁嵌入层,输入图像被编码成一定数量的token向量,再与位置编码相加后,输入的token向量连同CLS token被送入N个transformer层中;最终,经过N个transformer层和一个归一化层LayerNorm后,CLS token用于分类或其他目的;每个transformer层包含两个归一化层LayerNorm,一个MLP块和一个多头自注意力块MHSA;
构建一个可学习的转换模块,由两个向量组成,用于校正原始视觉自注意力模型归一化层LayerNorm的增益gain和偏置bias,将可学习的转换模块称为norm adapter;所述normadapter位于视觉自注意力模型ViT的所有归一化层之后,通过逐元素的乘法和加法实现;如公式(1)所示,Scale和Shift分别是norm adapter的两个可学习向量,y是归一化层的输出,⊙代表逐元素乘法:
h = Scale ⊙ y + Shift (1)
norm adapter的参数Scale和Shift的结构与归一化层的参数增益gain和偏置bias相同,分别被初始化为全1和全0的向量;在微调时,只有参数Scale和Shift被更新,其他参数在预训练后被冻结,不进行优化;
步骤2:在预训练期间,使用大规模数据集上以全监督或自监督方式训练的主干网络;
步骤3:在微调过程中,采用原型网络ProtoNet分类头;该分类头根据查询图像和原型在嵌入空间中的距离,产生一个概率分布,如公式(2)所示:
其中,fφ是主干网络,将输入编码到特征空间;ck为类别k的原型,是属于类别k的特征的平均值;d是度量函数;具体来说,各个类别的原型是由支持集中每类样本求均值计算出来的,并将进行数据增强后的支持集作为伪查询集,然后由原型和伪查询集之间的余弦距离计算损失,更新参数;
损失函数选择交叉熵损失。
2.根据权利要求1所述的一种基于视觉自注意力模型的小样本微调方法,其特征在于,所述自监督方式,采用DINO和MOCO v3算法在ImageNet-1K数据集上训练主干网络;所述全监督方式,主干网络是在ImageNet-21K数据集上训练得到的。
3.根据权利要求1所述的一种基于视觉自注意力模型的小样本微调方法,其特征在于,所述度量函数使用余弦距离。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310867841.1A CN117036901A (zh) | 2023-07-16 | 2023-07-16 | 一种基于视觉自注意力模型的小样本微调方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310867841.1A CN117036901A (zh) | 2023-07-16 | 2023-07-16 | 一种基于视觉自注意力模型的小样本微调方法 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN117036901A true CN117036901A (zh) | 2023-11-10 |
Family
ID=88627066
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202310867841.1A Pending CN117036901A (zh) | 2023-07-16 | 2023-07-16 | 一种基于视觉自注意力模型的小样本微调方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN117036901A (zh) |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117689044A (zh) * | 2024-02-01 | 2024-03-12 | 厦门大学 | 一种适用于视觉自注意力模型的量化方法 |
-
2023
- 2023-07-16 CN CN202310867841.1A patent/CN117036901A/zh active Pending
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117689044A (zh) * | 2024-02-01 | 2024-03-12 | 厦门大学 | 一种适用于视觉自注意力模型的量化方法 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
US11450066B2 (en) | 3D reconstruction method based on deep learning | |
US9129222B2 (en) | Method and apparatus for a local competitive learning rule that leads to sparse connectivity | |
CN106845529A (zh) | 基于多视野卷积神经网络的影像特征识别方法 | |
CN110728219A (zh) | 基于多列多尺度图卷积神经网络的3d人脸生成方法 | |
CN114332578A (zh) | 图像异常检测模型训练方法、图像异常检测方法和装置 | |
CN113989100B (zh) | 一种基于样式生成对抗网络的红外纹理样本扩充方法 | |
CN109146061A (zh) | 神经网络模型的处理方法和装置 | |
CN115222998B (zh) | 一种图像分类方法 | |
CN111210382A (zh) | 图像处理方法、装置、计算机设备和存储介质 | |
CN117036901A (zh) | 一种基于视觉自注意力模型的小样本微调方法 | |
CN113095254A (zh) | 一种人体部位关键点的定位方法及系统 | |
CN115311502A (zh) | 基于多尺度双流架构的遥感图像小样本场景分类方法 | |
CN113989612A (zh) | 基于注意力及生成对抗网络的遥感影像目标检测方法 | |
Wang et al. | Global aligned structured sparsity learning for efficient image super-resolution | |
CN115471016A (zh) | 一种基于cisso与daed的台风预测方法 | |
CN117974693B (zh) | 图像分割方法、装置、计算机设备和存储介质 | |
CN117333516A (zh) | 一种基于光流卷积神经网络的鲁棒性粒子图像测速方法 | |
CN117992919A (zh) | 基于机器学习和多气象模态融合的河流洪水预警方法 | |
CN117611838A (zh) | 一种基于自适应超图卷积网络的多标签图像分类方法 | |
CN115760670B (zh) | 基于网络隐式先验的无监督高光谱融合方法及装置 | |
CN117274664A (zh) | 一种视觉认知驱动的小样本图像分类方法、系统及介质 | |
CN116797681A (zh) | 渐进式多粒度语义信息融合的文本到图像生成方法及系统 | |
CN116109868A (zh) | 基于轻量化神经网络的图像分类模型构建和小样本图像分类方法 | |
Saenz et al. | Dimensionality-reduction of climate data using deep autoencoders | |
CN112991257B (zh) | 基于半监督孪生网络的异质遥感图像变化快速检测方法 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |