CN117934995A - 基于多注意力的模型训练方法、系统、设备及存储介质 - Google Patents

基于多注意力的模型训练方法、系统、设备及存储介质 Download PDF

Info

Publication number
CN117934995A
CN117934995A CN202410102137.1A CN202410102137A CN117934995A CN 117934995 A CN117934995 A CN 117934995A CN 202410102137 A CN202410102137 A CN 202410102137A CN 117934995 A CN117934995 A CN 117934995A
Authority
CN
China
Prior art keywords
attention
visual
features
model
feature
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202410102137.1A
Other languages
English (en)
Inventor
李为
钟翔
董全超
李远钱
易长渝
陈佳鹏
焦宾
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Wuhu Yuncong Technology Co ltd
Original Assignee
Wuhu Yuncong Technology Co ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Wuhu Yuncong Technology Co ltd filed Critical Wuhu Yuncong Technology Co ltd
Priority to CN202410102137.1A priority Critical patent/CN117934995A/zh
Publication of CN117934995A publication Critical patent/CN117934995A/zh
Pending legal-status Critical Current

Links

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/77Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
    • G06V10/774Generating sets of training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/40Extraction of image or video features
    • G06V10/42Global feature extraction by analysis of the whole pattern, e.g. using frequency domain transformations or autocorrelation
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/77Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
    • G06V10/80Fusion, i.e. combining data from various sources at the sensor level, preprocessing level, feature extraction level or classification level
    • G06V10/806Fusion, i.e. combining data from various sources at the sensor level, preprocessing level, feature extraction level or classification level of extracted features

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Multimedia (AREA)
  • General Physics & Mathematics (AREA)
  • Physics & Mathematics (AREA)
  • Evolutionary Computation (AREA)
  • General Health & Medical Sciences (AREA)
  • Medical Informatics (AREA)
  • Software Systems (AREA)
  • Databases & Information Systems (AREA)
  • Computing Systems (AREA)
  • Artificial Intelligence (AREA)
  • Health & Medical Sciences (AREA)
  • Image Analysis (AREA)

Abstract

本发明涉及计算机视觉领域,具体提供一种基于多注意力的模型训练方法、系统、设备及存储介质,旨在解决预训练的视觉模型不能更好地兼容灵活使用多注意力机制的下游任务,降低下游任务处理精度。为此目的,本发明方法包括:获取图像特征并掩码处理,获取掩码处理后的图像特征;选择一种目标注意力并通过视觉模型处理该掩码处理后的图像特征,获取输出特征;基于输出特征、输入图像中与掩码位置对应的特征,进行损失计算并回传,训练视觉模型;上述模型训练中采用了多种注意力机制,兼容灵活使用多注意力机制的下游任务,提高下游任务处理精度。

Description

基于多注意力的模型训练方法、系统、设备及存储介质
技术领域
本发明涉及视觉模型训练领域,具体提供一种基于多注意力的模型训练方法、系统、设备及存储介质。
背景技术
随着计算机视觉基础模型的发展,视觉Transformer结构由于其强大的视觉建模能力和良好的网络规模扩展性,逐渐成为各个视觉子领域的主流模型。虽然视觉transformer结构有着十分卓越的表现,但是其在应用前需要经过大量的数据进行预训练,由于经过人工标注的数据十分有限,因此一系列基于对比学习或者掩码重建的自监督预训练方法可以应用于视觉transformer结构。
然而由于预训练和下游任务的图片类型和使用分辨率的不同,以及显存资源的限制,往往下游任务会调整注意力机制,使得和预训练任务使用的注意力机制存在差异,从而影响预训练模型在下游任务上的效果。
相应地,本领域需要一种基于多注意力的模型训练方案来解决上述技术问题。
发明内容
为了克服上述缺陷,提出了本发明,以提供解决或至少部分地解决如何使得预训练的视觉模型能够更好地兼容灵活使用多注意力机制的下游任务,提高下游任务处理精度。
在第一方面,本发明提供一种基于多注意力的模型训练方法,包括:
通过视觉模型中的特征提取模块,对训练数据集中每个输入单元中的输入图像进行特征提取并进行掩码处理,获取掩码处理后的图像特征;
针对每个输入单元,选择一种目标注意力并通过所述视觉模型中的处理模块,处理各个输入单元对应的掩码处理后的图像特征,获取对应的各个输出特征;其中,针对所有输入单元,选择的目标注意力种类包括两种,或两种以上;
针对每个输出特征,确定与掩码位置对应的第一目标特征,并基于所述第一目标特征类型,获取所述输入图像中与掩码位置对应的第二目标特征,并基于所述第一目标特征、所述第二目标特征,进行损失计算并回传,训练所述视觉模型。
在一个实施例中,所述视觉模型包括以下至少之一:视觉变换网络结构即视觉Transformer结构、层次视觉变换网络结构即层次视觉Transformer结构。
在一个实施例中,所述视觉模型为所述层次视觉Transformer结构时,所述特征提取模块包括以下至少之一:进入堆叠的处理模块前的网络层、预设数量的处理模块。
在一个实施例中,对输入图像进行特征提取并进行掩码处理后,对掩码掉的特征,采用可学习的特征向量进行填充,获取掩码处理后的图像特征。
在一个实施例中,所述视觉模型为所述视觉Transformer结构时,所述特征提取模块为进入堆叠的处理模块前的网络层。
在一个实施例中,对输入图像进行特征提取并进行掩码处理后,删除掩码掉的特征,保留未被掩码掉的图像特征。
在一个实施例中,基于所述输出特征,确定与掩码位置对应的第一目标特征,包括:对所述输出特征进行解码,基于解码后的输出特征,确定与掩码位置对应的第一目标特征。
在一个实施例中,所述目标注意力种类包括以下至少之一:全局注意力、局部窗口注意力、自注意力、空间域注意力、通道域注意力、时间域注意力、混合域注意力。
在一个实施例中,所述第一目标特征类型包括以下至少之一:原始像素特征、patch特性即区域特征、HoG特征即方向梯度直方图特征、视觉图像特征;所述第二目标特征类型包括以下至少之一:原始像素特征、patch特性即区域特征、HoG特征即方向梯度直方图特征、视觉图像特征。
在一个实施例中,基于所述第一目标特征、所述第二目标特征,采用以下损失函数之一,进行损失计算:平均绝对误差损失函数L1、均方误差损失函数L2、光滑绝对误差损失函数Smooth-L1、负余弦距离。
在一个实施例中,预先设定训练数据集中输入单元大小即输入单元中的输入图像数量,其中,所述输入图像为样本图像,所述输入单元为Batch即每一次模型训练时所处理的样本数量。
在一个实施例中,将训练数据集划分为多个Batch进行训练。
在第二方面,本发明提供一种基于多注意力的模型训练系统,包括:
特征提取模块,用于通过视觉模型中的特征提取模块,对训练数据集中每个输入单元中的输入图像进行特征提取并进行掩码处理,获取掩码处理后的图像特征;
处理模块,用于针对每个输入单元,选择一种目标注意力并通过所述视觉模型中的处理模块,处理各个输入单元对应的掩码处理后的图像特征,获取对应的各个输出特征;其中,针对所有输入单元,选择的目标注意力种类包括两种,或两种以上;
模型训练模块,用于针对每个输出特征,确定与掩码位置对应的第一目标特征,并基于所述第一目标特征类型,获取所述输入图像中与掩码位置对应的第二目标特征,并基于所述第一目标特征、所述第二目标特征,进行损失计算并回传,训练所述视觉模型。
在一个实施例中,所述视觉模型包括以下至少之一:视觉变换网络结构即视觉Transformer结构、层次视觉变换网络结构即层次视觉Transformer结构。
在第三方面,提供一种计算机设备,包括处理器和存储装置,其中所述存储器中存储有程序,所述处理器执行所述程序时实现上述方法的技术方案中任一项技术方案所述的基于多注意力的模型训练方法。
在第四方面,提供一种计算机可读存储介质,存储有程序,所述程序被执行时实现上述方法的技术方案中任一项技术方案所述的基于多注意力的模型训练方法。
本发明上述一个或多个技术方案,至少具有如下一种或多种有益效果:
在实施本发明的技术方案中:通过视觉模型中的特征提取模块,对每个输入单元中的输入图像进行特征提取并进行掩码处理,获取掩码处理后的图像特征;针对每个输入单元,选择一种目标注意力并通过所述视觉模型中的处理模块,处理各个输入单元对应的掩码处理后的图像特征,获取对应的各个输出特征;其中,针对所有输入单元,选择的目标注意力种类包括两种,或两种以上;针对每个输出特征,确定与掩码位置对应的第一目标特征,并基于所述第一目标特征类型,获取所述输入图像中与掩码位置对应的第二目标特征,并基于所述第一目标特征、所述第二目标特征,进行损失计算并回传,训练所述视觉模型;在训练数据集的一次训练中有效利用了多种注意力机制训练视觉模型,使得训练的视觉模型能够在下游任务中具有较好的兼容性,即无论下游任务使用多种注意力机制中的一个或多个,任务均能有较好的精度。
进一步地,视觉模型中的处理模块,无论采用哪种注意力,处理模块的参数不增加,即视觉模型训练时,不需要增加模型参数和训练成本的情况下,使得模型在下游任务就可以灵活使用多种注意力机制中的一个或多个,任务均能有较好的精度。
附图说明
参照附图,本发明的公开内容将变得更易理解。本领域技术人员容易理解的是:这些附图仅仅用于说明的目的,而并非意在对本发明的保护范围组成限制。此外,图中类似的数字用以表示类似的部件,其中:
图1是根据本发明的一个实施例的基于多注意力的模型训练方法的主要步骤流程示意图;
图2是根据本发明的一个实施例的视觉模型训练方法的主要步骤流程示意图;
图3是根据本发明的一个实施例的基于多注意力的模型训练系统的主要结构框图示意图。
具体实施方式
下面参照附图来描述本发明的一些实施方式。本领域技术人员应当理解的是,这些实施方式仅仅用于解释本发明的技术原理,并非旨在限制本发明的保护范围。
在本发明的描述中,“模块”、“处理器”可以包括硬件、软件或者两者的组合。一个模块可以包括硬件电路,各种合适的感应器,通信端口,存储器,也可以包括软件部分,比如程序代码,也可以是软件和硬件的组合。处理器可以是中央处理器、微处理器、图像处理器、数字信号处理器或者其他任何合适的处理器。处理器具有数据和/或信号处理功能。处理器可以以软件方式实现、硬件方式实现或者二者结合方式实现。非暂时性的计算机可读存储介质包括任何合适的可存储程序代码的介质,比如磁碟、硬盘、光碟、闪存、只读存储器、随机存取存储器等等。术语“A和/或B”表示所有可能的A与B的组合,比如只是A、只是B或者A和B。术语“至少一个A或B”或者“A和B中的至少一个”含义与“A和/或B”类似,可以包括只是A、只是B或者A和B。单数形式的术语“一个”、“这个”也可以包含复数形式。
名词解释:
自监督:属于无监督学习范畴,主要是通过辅助任务从大规模的无标签数据中挖掘到自身监督信息,然后利用此种监督信息对模型进行训练,从而学习到有效的特征表达能力。
全局注意力:在输入的图片变换为特征序列后,基于所有的特征建模计算互相之间的相关性,然后汇聚变换得到新的特征;
局部窗口注意力:在输入的图片变为特征序列后,需要整理为二维的形式,然后将特征划为特定窗口大小的集合,每个窗口内的特征各自进行相关性计算以及特征变换等。通常在使用局部窗口注意力时,除了按照上述最基本的使用方式外,swin transformer提出的shift-window机制也是最受欢迎的变体之一;
HOG特征:方向梯度直方图,通过计算图像部分区域的梯度信息,并进行统计梯度信息的直方图来构成特征向量;
CLIP模型:通过利用互联网上海量的图片文本对,以多模态对齐的自监督方式训练得到的神经网络预训练模型。
参阅附图1,图1是根据本发明的一个实施例的基于多注意力的模型训练方法的主要步骤流程示意图。如图1所示,主要包括下列步骤S10-S30:
S10,通过视觉模型中的特征提取模块,对训练数据集中每个输入单元中的输入图像进行特征提取并进行掩码处理,获取掩码处理后的图像特征;
在本实施例中,预先设定训练数据集中输入单元大小即输入单元中的输入图像数量,其中,所述输入图像为样本图像,所述输入单元为Batch即每一次模型训练时所处理的样本数量;一般将训练数据集划分为多个Batch进行训练。
具体而言,Batch(批次)是指每一次模型训练时所处理的样本数量,在深度学习中,通常将训练数据集划分为多个Batch进行训练,这样可以减少内存占用,加快模型训练速度,同时也可以增加模型的泛化能力。通常,选择的Batch大小要根据训练数据集的大小和硬件性能来确定。
在深度机器学习中,Batch的大小是一个重要的超参数,其对模型训练的效果有重要影响。Batch size指的是每次模型训练时,从训练数据集中抽取的样本个数。其具体影响主要体现在以下几个方面:
1、计算效率:Batch size较大时,每次迭代需要处理的数据更多,可以更充分地利用硬件资源,如GPU的并行计算能力,从而提高计算效率;这一点对于处理大规模数据集特别重要,然而,当batch size过大时,容易导致内存溢出。
2、模型性能:Batch size可以影响模型收敛速度和模型最终性能。当batch size较小时,每次迭代使用的样本少,训练过程中的梯度更新方向可能会比较嘈杂,这样可以帮助模型跳出局部最优,有助于提高模型的泛化性能。而当batch size较大时,梯度的方向通常会比较准确,训练收敛速度快,但可能会陷入局部最优。
3、梯度更新:Batch size也影响梯度的更新。当batch size较小时,由于每批次的样本差异大,可能会导致模型在训练时出现回缩、震荡等现象,影响模型的收敛。而当batchsize较大时,每批次的样本差异小,梯度的方向更稳定,可以保证模型更稳定地收敛。
4、泛化能力:理论上,较小的batch size能够带来更好的模型泛化能力,这是因为小batch size在模型训练中引入了一种随机性,可以使模型有机会跳出局部最优,从而有可能找到更好的全局最优。
在本实施例中,所述视觉模型包括以下至少之一:视觉变换网络结构即视觉Transformer结构、层次视觉变换网络结构即层次视觉Transformer结构。
在本实施例中,在所述视觉模型为所述视觉Transformer结构时,所述特征提取模块为进入堆叠的处理模块前的网络层,具体而言:所述特征提取模块是进入堆叠的transformer block即处理模块前的网络层,如patch embedding层,在对输入图像进行特征提取并进行掩码处理后,删除掩码掉的特征,保留未被掩码掉的图像特征。
在本实施例中,在所述视觉模型为所述层次视觉Transformer结构时,所述特征提取模块包括以下至少之一:进入堆叠的处理模块前的网络层、预设数量的处理模块,具体而言:对于层次视觉transformer结构(如swin transformer)而言,所述特征提取模块是进入堆叠的处理模块前的网络层,如patch embedding层、预设数量的处理模块,如:前两个阶段的处理模块:swin transformer block、patch merge。
在本实施例中,将swin transformer的前两个阶段也归为特征提取模块,第三个阶段作为主阶段(堆叠的多个transformer block),去掉了第四个阶段,由于阶段数的减少,在第三阶段(主阶段)适当增加了swin transformer block的数量,以保证网络的建模能力,主阶段处理的特征图大小是原图的1/16,当然也可以是其他比例。
在本实施例中,为了更好地获取层次化特征来加强视觉模型在下游任务的表现,使用层次视觉Transformer结构,在经过patch embedding层后,需要对提取的特征按照一定的比例进行掩码处理,例如:按照75%的比例进行掩码处理;由于后续会使用到注意力机制(例如:局部窗口注意力机制),所以后续需要保留所有的特征序列进行输入,其中,对掩码掉的特征,采用可学习的特征向量进行填充,获取掩码处理后的图像特征。
S20,针对每个输入单元,选择一种目标注意力并通过所述视觉模型中的处理模块,处理各个输入单元对应的掩码处理后的图像特征,获取对应的各个输出特征;其中,针对所有输入单元,选择的目标注意力种类包括两种,或两种以上;
本实施例中,所述目标注意力种类包括以下至少之一:全局注意力、局部窗口注意力、自注意力、空间域注意力、通道域注意力、时间域注意力、混合域注意力。
在本实施例中,所述视觉模型包括多个处理模块,例如:多个堆叠的transformerblock,处理所述掩码处理后的图像特征;当选择全局注意力时,所述视觉模型中多个堆叠的transformer block(例如:12层transformer block),将以全局注意力的方式处理各个输入单元对应的掩码处理后的图像特征,当选择局部窗口注意力时,所述视觉模型中多个堆叠的transformer block,将以局部窗口注意力的方式处理各个输入单元对应的掩码处理后的图像特征。
由于针对所有输入单元(也即训练数据集的全部样本数量),选择的目标注意力种类包括两种,或两种以上,也即在训练数据集的一次完整训练时,采用混合注意力机制(例如:全局注意力、局部窗口注意力),对模型进行训练,可以同时汲取两种注意力机制的优点,使得其在下游任务上的表现可以超过单一使用一种注意力机制的方法。
视觉模型中的处理模块,无论采用哪种注意力,处理模块的参数不增加,即视觉模型训练时,不需要增加模型参数和训练成本的情况下,使得模型在下游任务就可以灵活使用多种注意力机制中的一个或多个,任务均能有较好的精度。
S30,基于所述输出特征,确定与掩码位置对应的第一目标特征,并基于所述第一目标特征类型,获取所述输入图像中与掩码位置对应的第二目标特征,并基于所述第一目标特征、所述第二目标特征,进行损失计算并回传,训练所述视觉模型。
在本实施例中,在所述视觉模型为所述视觉Transformer结构时,特征提取模块(进入堆叠的transformer block即处理模块前的网络层)在对输入图像进行特征提取并进行掩码处理后,删除掩码掉的特征,保留未被掩码掉的图像特征;此后,获取的所述输出特征需要进行解码处理,得到解码后的输出特征,基于解码后的输出特征,确定与掩码位置对应的第一目标特征;进行解码时候,需要经过多层transformer block的变换进行特征解码。
在本实施例中,在所述视觉模型为层次视觉Transformer结构时,在经过patchembedding层后,需要对提取的特征按照一定的比例进行掩码处理,例如:按照75%的比例进行掩码处理;由于后续会使用到注意力机制(例如:局部窗口注意力机制),所以后续需要保留所有的特征序列进行输入,其中,对掩码掉的特征,采用可学习的特征向量进行填充,获取掩码处理后的图像特征;此后,获取的所述输出特征可以不用进行解码处理,直接确定与掩码位置对应的第一目标特征。
本实施例中,所述第一目标特征类型包括以下至少之一:原始像素特征、patch特性即区域特征、HoG特征即方向梯度直方图特征、视觉图像特征;所述第二目标特征类型包括以下至少之一:原始像素特征、patch特性即区域特征、HoG特征即方向梯度直方图特征、视觉图像特征。
本实施例中,基于所述第一目标特征、所述第二目标特征,采用以下损失函数之一,进行损失计算:平均绝对误差损失函数L1、均方误差损失函数L2、光滑绝对误差损失函数Smooth-L1、负余弦距离。
在实施本发明的技术方案中:通过视觉模型中的特征提取模块,对训练数据集中每个输入单元中的输入图像进行特征提取并进行掩码处理,获取掩码处理后的图像特征;针对每个输入单元,选择一种目标注意力并通过所述视觉模型中的处理模块,处理各个输入单元对应的掩码处理后的图像特征,获取对应的各个输出特征;其中,针对所有输入单元,选择的目标注意力种类包括两种,或两种以上;针对每个输出特征,确定与掩码位置对应的第一目标特征,并基于所述第一目标特征类型,获取所述输入图像中与掩码位置对应的第二目标特征,并基于所述第一目标特征、所述第二目标特征,进行损失计算并回传,训练所述视觉模型;在训练数据集的一次训练中有效利用了多种注意力机制训练视觉模型,使得训练的视觉模型能够在下游任务中具有较好的兼容性,即无论下游任务使用多种注意力机制中的一个或多个,任务均能有较好的精度。
进一步地,视觉模型中的处理模块,无论采用哪种注意力,处理模块的参数不增加,即视觉模型训练时,不需要增加模型参数和训练成本的情况下,使得模型在下游任务就可以灵活使用多种注意力机制中的一个或多个,任务均能有较好的精度。
图2是根据本发明的一个实施例的视觉模型训练方法的主要步骤流程示意图;
S201:样本图像处理
对训练数据集中每个输入单元中的输入图像(样本图像)进行特征提取并进行掩码处理,获取掩码处理后的图像特征。
将样本图片输入,首先经过特征提取模块,当视觉模型是视觉transformer结构,对于普通ViT而言,特征提取模块是进入堆叠的transformer block前的网络层,如patchembedding层;当视觉模型是层次视觉transformer结构而言(如swin transformer),特征提取模块是patch embedding、前两个阶段的swin transformer block以及patch merge。
在本实施例中,为了更好地获取层次化特征来加强视觉模型在下游任务的表现,使用的是后者层次transformer(swin transformer)结构;特别地,经过patch embedding层后,需要对样本图像特征按照一定比例进行掩码处理,例如,可以选择按照75%的比例进行掩码,由于会使用到局部窗口注意力机制,所以后续需要保留所有的特征序列进行输入,其中掩码掉的部分用可学习的特征向量进行填充。
此外,由于将swin transformer的前两个阶段也归为特征提取模块,第三个阶段作为主阶段,去掉了第四个阶段,由于阶段数的减少,在第三阶段(主阶段)适当增加了swintransformer block的数量,以保证网络的建模能力。
S202:混合注意力机制处理
获取到S201输出的掩码处理后的图像特征后,此时将图像特征输入到主阶段中,对于每个输入单元而言,选择一种目标注意力(例如:全局注意力或局部窗口注意力)并通过视觉模型中的处理模块,处理各个输入单元对应的掩码处理后的图像特征,获取对应的各个输出特征;其中,针对所有输入单元,选择的目标注意力种类包括两种,或两种以上。
对于一次完整的数据集训练,模型的主阶段采用的是混合注意力机制处理,且采用多种注意力训练时,共享模型参数,进行混合训练。
S203:特征解码
得到步骤S202的输出特征,将经过多层transformer block的变换进行特征解码。
在本实施例中,在所述视觉模型为所述视觉Transformer结构时,特征提取模块(进入堆叠的transformer block即处理模块前的网络层)在对输入图像进行特征提取并进行掩码处理后,删除掩码掉的特征,保留未被掩码掉的图像特征;此后,获取的所述输出特征需要进行解码处理,基于解码后的输出特征,确定与掩码位置对应的第一目标特征;进行解码时候,需要经过多层transformer block的变换进行特征解码。
S204:监督对象变换
根据第一目标特征类型(例如:原始像素特征、patch特性即区域特征、HoG特征即方向梯度直方图特征、视觉图像特征等),将输入图像变换整理成特定的监督对象,如果是重建原始像素,则将输入图像整理成对应的patch序列即可,如果是重建其他更高级的特征,例如:HoG特征或者经过CLIP提取的特征,那么需要将输入图像根据对应的特征提取方式得到相应的特征。
S205:损失计算
根据掩码位置,获取输入图像中与所述掩码位置对应的特征,以及,输出特征中与所述掩码位置对应的特征,然后进行相应的损失计算,通常损失度量方式可以是L1、Smooth-L1、L2或负余弦距离等;将得到的损失进行回传,通过计算梯度以及更新模型参数,以训练整个视觉模型。
S206:输出训练好的神经网络模型
通过损失计算,对神经网络模型进行训练,多次迭代,不断调整神经网络模型权重,直至完成训练并保存训练的权重文件,输出最终的神经网络模型。
本实施例中,在视觉模型训练时,在完整的一次数据集训练中,采用了混合注意力机制(例如:全局注意力和窗口注意力两种机制)处理的方法,使得经过此种方法得到的视觉模型在下游能够具有较好的兼容性,即无论下游任务使用任何注意力机制(例如:全局注意力机制或局部窗口注意力机制)均能有较好的精度。
本实施中,基于多种注意力共享已有的模型参数,因此在不增加模型参数和训练成本的情况下,能够使得经过训练的视觉模型在下游任务兼容多种注意力机制,提高下游任务处理精度。
本实施中,提出的基于混合注意力机制的方法,得到的自监督预训练视觉模型,在下游任务应用时会具有较好的扩展性,无论下游任务选择全局注意力机制以更好地捕捉全局空间信息,还是选择局部窗口注意力机制来对显存资源的使用更加友好(节省显存资源),所得到的自监督预训练视觉模型均能够很好地适配,并且在下游任务有着较好的精度表现。
进一步,本发明还提供了一种基于多注意力的模型训练系统。如图3所示,本发明实施例中的一种基于多注意力的模型训练系统主要包括特征提取模块31、处理模块32、模型训练模块33;
其中,
特征提取模块31,用于通过视觉模型中的特征提取模块,对训练数据集中每个输入单元中的输入图像进行特征提取并进行掩码处理,获取掩码处理后的图像特征;
处理模块32,用于针对每个输入单元,选择一种目标注意力并通过所述视觉模型中的处理模块,处理各个输入单元对应的掩码处理后的图像特征,获取对应的各个输出特征;其中,针对所有输入单元,选择的目标注意力种类包括两种,或两种以上;
模型训练模块33,用于针对每个输出特征,确定与掩码位置对应的第一目标特征,并基于所述第一目标特征类型,获取所述输入图像中与掩码位置对应的第二目标特征,并基于所述第一目标特征、所述第二目标特征,进行损失计算并回传,训练所述视觉模型。
上述基于多注意力的模型训练系统以用于执行图1所示的基于多注意力的模型训练方法实施例,两者的技术原理、所解决的技术问题及产生的技术效果相似,本技术领域技术人员可以清楚地了解到,为了描述的方便和简洁,基于多注意力的模型训练系统的具体工作过程及有关说明,可以参考基于多注意力的模型训练方法的实施例所描述的内容,此处不再赘述。
本领域技术人员能够理解的是,本发明实现上述一实施例的方法中的全部或部分流程,也可以通过计算机程序来指令相关的硬件来完成,所述的计算机程序可存储于一计算机可读存储介质中,该计算机程序在被处理器执行时,可实现上述各个方法实施例的步骤。其中,所述计算机程序包括计算机程序代码,所述计算机程序代码可以为源代码形式、对象代码形式、可执行文件或某些中间形式等。所述计算机可读存储介质可以包括:能够携带所述计算机程序代码的任何实体或装置、介质、U盘、移动硬盘、磁碟、光盘、计算机存储器、只读存储器、随机存取存储器、电载波信号、电信信号以及软件分发介质等。需要说明的是,所述计算机可读存储介质包含的内容可以根据司法管辖区内立法和专利实践的要求进行适当的增减,例如在某些司法管辖区,根据立法和专利实践,计算机可读存储介质不包括电载波信号和电信信号。
进一步,本发明还提供了一种计算机设备,计算机设备包括处理器和存储装置,存储装置可以被配置成存储执行上述方法实施例的基于多注意力的模型训练方法的程序,处理器可以被配置成用于执行存储装置中的程序,该程序包括但不限于执行上述方法实施例的基于多注意力的模型训练方法的程序。为了便于说明,仅示出了与本发明实施例相关的部分,具体技术细节未揭示的,请参照本发明实施例方法部分。该计算机设备可以是包括各种电子设备形成的计算机设备。
进一步,本发明还提供了一种计算机可读存储介质。在根据本发明的一个计算机可读存储介质实施例中,计算机可读存储介质可以被配置成存储执行上述方法实施例的基于多注意力的模型训练方法的程序,该程序可以由处理器加载并运行以实现上述基于多注意力的模型训练方法。为了便于说明,仅示出了与本发明实施例相关的部分,具体技术细节未揭示的,请参照本发明实施例方法部分。该计算机可读存储介质可以是包括各种电子设备形成的存储装置设备,可选的,本发明实施例中计算机可读存储介质是非暂时性的计算机可读存储介质。
进一步,应该理解的是,由于各个模块的设定仅仅是为了说明本发明的系统的功能单元,这些模块对应的物理器件可以是处理器本身,或者处理器中软件的一部分,硬件的一部分,或者软件和硬件结合的一部分。因此,图中的各个模块的数量仅仅是示意性的。
本领域技术人员能够理解的是,可以对系统中的各个模块进行适应性地拆分或合并。对具体模块的这种拆分或合并并不会导致技术方案偏离本发明的原理,因此,拆分或合并之后的技术方案都将落入本发明的保护范围内。
至此,已经结合附图所示的优选实施方式描述了本发明的技术方案,但是,本领域技术人员容易理解的是,本发明的保护范围显然不局限于这些具体实施方式。在不偏离本发明的原理的前提下,本领域技术人员可以对相关技术特征作出等同的更改或替换,这些更改或替换之后的技术方案都将落入本发明的保护范围之内。

Claims (12)

1.一种基于多注意力的模型训练方法,其特征在于,包括:
通过视觉模型中的特征提取模块,对训练数据集中每个输入单元中的输入图像进行特征提取并进行掩码处理,获取掩码处理后的图像特征;
针对每个输入单元,选择一种目标注意力并通过所述视觉模型中的处理模块,处理各个输入单元对应的掩码处理后的图像特征,获取对应的各个输出特征;其中,针对所有输入单元,选择的目标注意力种类包括两种,或两种以上;
针对每个输出特征,确定与掩码位置对应的第一目标特征,并基于所述第一目标特征类型,获取所述输入图像中与掩码位置对应的第二目标特征,并基于所述第一目标特征、所述第二目标特征,进行损失计算并回传,训练所述视觉模型。
2.根据权利要求1所述的方法,其特征在于,所述视觉模型包括以下至少之一:视觉变换网络结构即视觉Transformer结构、层次视觉变换网络结构即层次视觉Transformer结构。
3.根据权利要求2所述的方法,其特征在于,所述视觉模型为所述层次视觉Transformer结构时,所述特征提取模块包括以下至少之一:进入堆叠的处理模块前的网络层、预设数量的处理模块。
4.根据权利要求3所述的任一方法,其特征在于,对输入图像进行特征提取并进行掩码处理后,对掩码掉的特征,采用可学习的特征向量进行填充,获取掩码处理后的图像特征。
5.根据权利要求2所述的方法,其特征在于,所述视觉模型为所述视觉Transformer结构时,所述特征提取模块为进入堆叠的处理模块前的网络层。
6.根据权利要求5所述的任一方法,其特征在于,对输入图像进行特征提取并进行掩码处理后,删除掩码掉的特征,保留未被掩码掉的图像特征。
7.根据权利要求6所述的方法,其特征在于,基于所述输出特征,确定与掩码位置对应的第一目标特征,包括:
对所述输出特征进行解码,基于解码后的输出特征,确定与掩码位置对应的第一目标特征。
8.根据权利要求1所述的方法,其特征在于,所述目标注意力种类包括以下至少之一:全局注意力、局部窗口注意力、自注意力、空间域注意力、通道域注意力、时间域注意力、混合域注意力。
9.根据权利要求1所述的方法,其特征在于,所述第一目标特征类型包括以下至少之一:原始像素特征、patch特性即区域特征、HoG特征即方向梯度直方图特征、视觉图像特征;所述第二目标特征类型包括以下至少之一:原始像素特征、patch特性即区域特征、HoG特征即方向梯度直方图特征、视觉图像特征。
10.一种基于多注意力的模型训练系统,其特征在于,包括:
特征提取模块,用于通过视觉模型中的特征提取模块,对训练数据集中每个输入单元中的输入图像进行特征提取并进行掩码处理,获取掩码处理后的图像特征;
处理模块,用于针对每个输入单元,选择一种目标注意力并通过所述视觉模型中的处理模块,处理各个输入单元对应的掩码处理后的图像特征,获取对应的各个输出特征;其中,针对所有输入单元,选择的目标注意力种类包括两种,或两种以上;
模型训练模块,用于针对每个输出特征,确定与掩码位置对应的第一目标特征,并基于所述第一目标特征类型,获取所述输入图像中与掩码位置对应的第二目标特征,并基于所述第一目标特征、所述第二目标特征,进行损失计算并回传,训练所述视觉模型。
11.一种计算机设备,包括处理器和存储装置,其中所述存储器中存储有程序,其特征在于,所述处理器执行所述程序时实现权利要求1至9中任一项所述的方法。
12.一种计算机可读存储介质,存储有程序,其特征在于,所述程序被执行时实现权利要求1至9中任一项所述的方法。
CN202410102137.1A 2024-01-24 2024-01-24 基于多注意力的模型训练方法、系统、设备及存储介质 Pending CN117934995A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202410102137.1A CN117934995A (zh) 2024-01-24 2024-01-24 基于多注意力的模型训练方法、系统、设备及存储介质

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202410102137.1A CN117934995A (zh) 2024-01-24 2024-01-24 基于多注意力的模型训练方法、系统、设备及存储介质

Publications (1)

Publication Number Publication Date
CN117934995A true CN117934995A (zh) 2024-04-26

Family

ID=90757064

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202410102137.1A Pending CN117934995A (zh) 2024-01-24 2024-01-24 基于多注意力的模型训练方法、系统、设备及存储介质

Country Status (1)

Country Link
CN (1) CN117934995A (zh)

Similar Documents

Publication Publication Date Title
CN111898696B (zh) 伪标签及标签预测模型的生成方法、装置、介质及设备
WO2020238560A1 (zh) 视频目标跟踪方法、装置、计算机设备及存储介质
JP7059318B2 (ja) 地域的特徴を有する分類器学習のための学習データ生成方法およびそのシステム
CN111860398B (zh) 遥感图像目标检测方法、系统及终端设备
CN111275107A (zh) 一种基于迁移学习的多标签场景图像分类方法及装置
CN109118504B (zh) 一种基于神经网络的图像边缘检测方法、装置及其设备
KR102140805B1 (ko) 위성 영상의 물체 식별을 위한 뉴럴 네트워크 학습 방법 및 장치
CN111444807A (zh) 目标检测方法、装置、电子设备和计算机可读介质
CN114821058A (zh) 一种图像语义分割方法、装置、电子设备及存储介质
CN116740362A (zh) 一种基于注意力的轻量化非对称场景语义分割方法及系统
CN115344805A (zh) 素材审核方法、计算设备及存储介质
Cong et al. CAN: Contextual aggregating network for semantic segmentation
CN116912923B (zh) 一种图像识别模型训练方法和装置
CN112132867B (zh) 一种遥感影像变化检测方法及装置
CN113742525A (zh) 自监督视频哈希学习方法、系统、电子设备及存储介质
CN110852102B (zh) 一种中文的词性标注方法、装置、存储介质及电子设备
CN111914920A (zh) 一种基于稀疏编码的相似性图像检索方法及系统
CN115393868B (zh) 文本检测方法、装置、电子设备和存储介质
CN116797830A (zh) 一种基于YOLOv7的图像风险分类方法及装置
CN113496228B (zh) 一种基于Res2Net、TransUNet和协同注意力的人体语义分割方法
CN113343979B (zh) 用于训练模型的方法、装置、设备、介质和程序产品
CN117934995A (zh) 基于多注意力的模型训练方法、系统、设备及存储介质
CN112559582A (zh) 一种基于样本对关系传播的小样本学习方法和装置
CN114510592A (zh) 图像分类方法、装置、电子设备及存储介质
US12112524B2 (en) Image augmentation method, electronic device and readable storage medium

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination