CN114494814A - 基于注意力的模型训练方法、装置及电子设备 - Google Patents
基于注意力的模型训练方法、装置及电子设备 Download PDFInfo
- Publication number
- CN114494814A CN114494814A CN202210102176.2A CN202210102176A CN114494814A CN 114494814 A CN114494814 A CN 114494814A CN 202210102176 A CN202210102176 A CN 202210102176A CN 114494814 A CN114494814 A CN 114494814A
- Authority
- CN
- China
- Prior art keywords
- output matrix
- output
- attention
- updated
- neural network
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000012549 training Methods 0.000 title claims abstract description 47
- 238000000034 method Methods 0.000 title claims abstract description 43
- 239000011159 matrix material Substances 0.000 claims abstract description 264
- 238000003062 neural network model Methods 0.000 claims abstract description 103
- 238000012545 processing Methods 0.000 claims abstract description 35
- 238000011176 pooling Methods 0.000 claims abstract description 27
- 238000004364 calculation method Methods 0.000 claims abstract description 16
- 230000009467 reduction Effects 0.000 claims abstract description 16
- 238000010606 normalization Methods 0.000 claims abstract description 13
- 230000002452 interceptive effect Effects 0.000 claims description 16
- 238000004590 computer program Methods 0.000 claims description 11
- 239000000126 substance Substances 0.000 claims description 2
- 238000013473 artificial intelligence Methods 0.000 abstract description 4
- 238000013135 deep learning Methods 0.000 abstract description 3
- 238000001514 detection method Methods 0.000 abstract description 3
- 238000004891 communication Methods 0.000 description 8
- 238000013528 artificial neural network Methods 0.000 description 7
- 230000007246 mechanism Effects 0.000 description 7
- 230000008569 process Effects 0.000 description 7
- 230000006870 function Effects 0.000 description 6
- 238000010586 diagram Methods 0.000 description 5
- 230000008859 change Effects 0.000 description 3
- 230000000875 corresponding effect Effects 0.000 description 3
- 238000013461 design Methods 0.000 description 3
- 230000003993 interaction Effects 0.000 description 3
- 230000004048 modification Effects 0.000 description 3
- 238000012986 modification Methods 0.000 description 3
- 230000003287 optical effect Effects 0.000 description 3
- 238000013527 convolutional neural network Methods 0.000 description 2
- 230000002596 correlated effect Effects 0.000 description 2
- 238000005516 engineering process Methods 0.000 description 2
- 238000000605 extraction Methods 0.000 description 2
- 238000012360 testing method Methods 0.000 description 2
- 238000003491 array Methods 0.000 description 1
- 230000006399 behavior Effects 0.000 description 1
- 230000005540 biological transmission Effects 0.000 description 1
- 230000001413 cellular effect Effects 0.000 description 1
- 230000006835 compression Effects 0.000 description 1
- 238000007906 compression Methods 0.000 description 1
- 238000010276 construction Methods 0.000 description 1
- 238000011161 development Methods 0.000 description 1
- 230000000694 effects Effects 0.000 description 1
- 230000006872 improvement Effects 0.000 description 1
- 239000004973 liquid crystal related substance Substances 0.000 description 1
- 238000010801 machine learning Methods 0.000 description 1
- 239000013307 optical fiber Substances 0.000 description 1
- 238000011946 reduction process Methods 0.000 description 1
- 239000004065 semiconductor Substances 0.000 description 1
- 230000001953 sensory effect Effects 0.000 description 1
- 238000006467 substitution reaction Methods 0.000 description 1
- 238000012546 transfer Methods 0.000 description 1
- 230000000007 visual effect Effects 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/047—Probabilistic or stochastic networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
Landscapes
- Engineering & Computer Science (AREA)
- Physics & Mathematics (AREA)
- Theoretical Computer Science (AREA)
- Evolutionary Computation (AREA)
- Molecular Biology (AREA)
- Artificial Intelligence (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Data Mining & Analysis (AREA)
- Health & Medical Sciences (AREA)
- General Health & Medical Sciences (AREA)
- Life Sciences & Earth Sciences (AREA)
- Computing Systems (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Mathematical Physics (AREA)
- Software Systems (AREA)
- Probability & Statistics with Applications (AREA)
- Image Analysis (AREA)
Abstract
本公开提供了一种基于注意力的模型训练方法、装置及电子设备,涉及人工智能技术领域,具体为深度学习、计算机视觉技术领域,可应用于图像处理、图像检测等场景。具体实现方案为:获取神经网络模型中注意力模块的注意力输出矩阵,基于神经网络模型的池化层对所述注意力输出矩阵的样本维度和数据块维度进行降维计算,确定池化后的第一输出矩阵;基于神经网络模型的卷积层对第一输出矩阵进行卷积操作,确定卷积后的第二输出矩阵,对第二输出矩阵中各个头的输出值进行归一化处理和加权处理,获得更新后的第二输出矩阵;基于更新后的第二输出矩阵获取更新后的注意力输出矩阵,并基于所述更新后的注意力输出矩阵训练所述神经网络模型。
Description
技术领域
本公开涉及人工智能技术领域,具体为深度学习、计算机视觉技术领域,可应用于图像处理、图像检测等场景,具体涉及一种基于注意力的模型训练方法、装置及电子设备。
背景技术
随着计算机技术的不断发展,各种神经网络模型在诸如图像、文本、语音等领域得到了广泛应用,例如卷积神经网络(Convolutional Neural Network,CNN)作为一种具有深度结构的前馈神经网络,其通过卷积计算实现特征的提取,通过网络结构的加深实现特征从局部到全局的捕获,通过增加通道的方式实现多个维度特征的叠加。目前,技术人员需要具备大量的神经网络结构设计及参数调整经验,耗费大量的硬件资源经多次更换、实验不同结构的神经网络来获得神经网络结构。
发明内容
本公开提供了一种基于注意力的模型训练方法、装置及电子设备。
根据本公开的第一方面,提供了一种基于注意力的模型训练方法,包括:
获取神经网络模型中注意力模块的注意力输出矩阵,所述注意力输出矩阵包括头维度、样本维度和数据块维度;
基于所述神经网络模型的池化层对所述注意力输出矩阵的样本维度和数据块维度进行降维计算,确定池化后的第一输出矩阵;
基于所述神经网络模型的卷积层对所述第一输出矩阵进行卷积操作,确定卷积后的第二输出矩阵,并获取所述第二输出矩阵中各个头的输出值;
对所述第二输出矩阵中各个头的输出值进行归一化处理和加权处理,获得更新后的第二输出矩阵;
基于所述更新后的第二输出矩阵获取更新后的注意力输出矩阵,并基于所述更新后的注意力输出矩阵训练所述神经网络模型。
根据本公开的第二方面,提供了一种基于注意力的模型训练装置,包括:
获取模块,用于获取神经网络模型中注意力模块的注意力输出矩阵,所述注意力输出矩阵包括头维度、样本维度和数据块维度;
池化模块,用于基于所述神经网络模型的池化层对所述注意力输出矩阵的样本维度和数据块维度进行降维计算,确定池化后的第一输出矩阵;
卷积模块,用于基于所述神经网络模型的卷积层对所述第一输出矩阵进行卷积操作,确定卷积后的第二输出矩阵,并获取所述第二输出矩阵中各个头的输出值;
更新模块,用于对所述第二输出矩阵中各个头的输出值进行归一化处理和加权处理,获得更新后的第二输出矩阵;
训练模块,用于基于所述更新后的第二输出矩阵获取更新后的注意力输出矩阵,并基于所述更新后的注意力输出矩阵训练所述神经网络模型。
根据本公开的第三方面,提供了一种电子设备,包括:
至少一个处理器;以及
与所述至少一个处理器通信连接的存储器;其中,
所述存储器存储有可被所述至少一个处理器执行的指令,所述指令被所述至少一个处理器执行,以使所述至少一个处理器能够执行第一方面中所述的方法。
根据本公开的第四方面,提供了一种存储有计算机指令的非瞬时计算机可读存储介质,其中,所述计算机指令用于使所述计算机执行根据第一方面所述的方法。
根据本公开的第五方面,提供了一种计算机程序产品,包括计算机程序,所述计算机程序在被处理器执行时实现根据第一方面所述的方法。
本公开实施例中,基于注意力机制,以及对神经网络模型中注意力模块输出的注意力输出矩阵进行池化和卷积操作,使得所述神经网络模型中各个head之间能够进行交互学习,进而让所述神经网络模型更关注重要的head的学习而弱化不重要的head的学习,从而优化所述神经网络模型的网络结构和网络参数,基于优化后的神经网络模型来进行模型训练,基于此训练得到的神经网络模型的最终输出的精确度更高。
应当理解,本部分所描述的内容并非旨在标识本公开的实施例的关键或重要特征,也不用于限制本公开的范围。本公开的其它特征将通过以下的说明书而变得容易理解。
附图说明
附图用于更好地理解本方案,不构成对本公开的限定。其中:
图1是本公开实施例提供的一种基于注意力的模型训练方法的流程图之一;
图2是本公开实施例提供的一种基于注意力的模型训练方法的流程图之二;
图3是本公开实施例提供的一种基于注意力的模型训练装置的结构图;
图4是用来实现本公开实施例的基于注意力的模型训练方法的电子设备的框图。
具体实施方式
以下结合附图对本公开的示范性实施例做出说明,其中包括本公开实施例的各种细节以助于理解,应当将它们认为仅仅是示范性的。因此,本领域普通技术人员应当认识到,可以对这里描述的实施例做出各种改变和修改,而不会背离本公开的范围和精神。同样,为了清楚和简明,以下的描述中省略了对公知功能和结构的描述。
本公开涉及人工智能技术领域,具体为深度学习、计算机视觉技术领域,可应用于图像处理、图像检测等场景,例如应用于人脸识别、图像识别、行为比对等场景中。以下结合具体实施例,对本公开提供的方案进行解释说明。
请参照图1,图1是本公开实施例提供的一种基于注意力的模型训练方法的流程图之一,如图1所示,所述方法包括以下步骤:
步骤S101、获取神经网络模型中注意力模块的注意力输出矩阵,所述注意力输出矩阵包括头维度、样本维度和数据块维度。
需要说明地,本公开实施例提供的方法可以是应用于计算机、手机、平板电脑等电子设备。
本公开实施例中,所述神经网络模型可以是应用于图像识别或图像分类等神经网络模型,例如视觉转换器(vision transformer)。
可以理解地,vision transformer中主要通过缩放点积注意力(Scaled Dot-Product Attention)的方式来获得不同数据块(patch,或者也称补丁)之间的注意力(attention)关系。本公开实施例中,可以是基于多头注意力机制来对神经网络模型(例如vision transformer)进行优化和训练,能够让学习到的注意力关系更加丰富。具体地,神经网络模型引入多头注意力机制,通过获得不同的头(head)输出,将这些head分别输入注意力模块,例如缩放点积注意力模块,获取缩放点积注意力模块的attention输出,所述输出为矩阵,也即获取缩放点积注意力模块输出的注意力输出矩阵。
需要说明地,所述注意力输出矩阵包括三个不同的维度:头维度(head number)、样本维度(sample number)、数据块维度(patch number)。其中,所述神经网络模型引入了多头注意力机制,所述头维度也就大于1;而神经网络模型通常需要大量的样本进行训练,所述样本维度也大于1;在神经网络模型的训练过程中,针对每一个样本,会划分成多个数据块,进行特征提取学习,所述数据块维度也大于1。
步骤S102、基于所述神经网络模型的池化层对所述注意力输出矩阵的样本维度和数据块维度进行降维计算,确定池化后的第一输出矩阵。
可以理解地,神经网络模型的池化层用于对输入的矩阵进行池化操作。可选地,所述池化操作为全局平均池化(Global Average Pooling,GAP)操作。本公开实施例中,在多头注意力机制中,对神经网络模型中经过缩放点积注意力(Scaled Dot-ProductAttention)模块输出的注意力输出矩阵进行GAP操作,以对所述注意力输出矩阵的样本维度和数据块维度进行降维计算,确定池化后的第一输出矩阵。
需要说明地,所述池化操作为一个信息压缩的过程,或者也可称为降维过程。本公开实施例中,基于所述池化操作,针对性地对所述注意力输出矩阵的样本维度(samplenumber)和数据块维度(patch number)进行降维处理,而所述注意力输出矩阵的头维度(head number)保持不变。也就是说,池化后的样本维度和数据块维度均要小于池化前的维度,例如,将样本维度从10降维到1,将数据块维度从10降维到2。进而,通过池化操作,也就能够对所述注意力输出矩阵进行降维,有效减少网络参数。
可选地,所述池化后的第一输出矩阵中的样本维度和数据块维度均为1。也就是说,通过池化操作,将所述注意力输出矩阵的样本维度和数据块维度降维处理为1,这样也就能够有效降低数据量,减少网络参数,更有利于后续流程中各个head之间的学习交互。
步骤S103、基于所述神经网络模型的卷积层对所述第一输出矩阵进行卷积操作,确定卷积后的第二输出矩阵,并获取所述第二输出矩阵中各个头的输出值。
本公开实施例中,在对所述注意力输出矩阵的样本维度和数据块维度进行降维处理,得到第一输出矩阵后,通过卷积层对所述第一输出矩阵进行卷积(conv)操作,获得卷积后的第二输出矩阵,并获取所述第二输出矩阵中各个头(head)的输出值。
需要说明地,所述卷积操作中的卷积核为1×1的卷积核。通过所述卷积操作,也就使得所述第一输出矩阵中的各个head能够实现交互学习,进而使得第一输出矩阵中各个head的特征值得以调整,得到卷积后的第二输出矩阵,所述第二输出矩阵中各个head的输出值也即第一输出矩阵中各个head调整了的特征值。
步骤S104、对所述第二输出矩阵中各个头的输出值进行归一化处理和加权处理,获得更新后的第二输出矩阵。
本公开实施例中,获得卷积后的第二输出矩阵后,对所述第二输出矩阵中各个head的输出值进行归一化处理和加权处理,进而以对所述第二输出矩阵中各个head的输出值进行更新,得到更新后的第二输出矩阵。其中,所述归一化处理可以是基于归一化指数函数,例如softmax函数来实现,通过归一化处理来计算获得各个head的权重值,然后再对不同的head进行加权处理,以实现对第二输出矩阵的更新。
步骤S105、基于所述更新后的第二输出矩阵获取更新后的注意力输出矩阵,并基于所述更新后的注意力输出矩阵训练所述神经网络模型。
需要说明地,在对第二输出矩阵进行归一化处理和加权处理,得到更新后的第二输出矩阵后,基于所述更新后的第二输出矩阵与所述注意力输出矩阵进行加权乘计算,以得到更新后的注意力输出矩阵,并基于所述更新后的注意力输出矩阵训练所述神经网络模型。
其中,所述加权乘计算可以是将所述更新后的第二输出矩阵与步骤S101中的所述注意力输出矩阵相乘。可以理解地,所述更新后的第二输出矩阵中的样本维度和数据块维度均为1,而所述注意力输出矩阵中的样本维度和数据块维度为降维处理前的维度,经过所述加权乘计算得到的所述更新后的注意力输出矩阵,其样本维度和数据块维度与所述注意力输出矩阵中的样本维度和数据块维度一致。例如,所述更新后的第二输出矩阵中的头维度×样本维度×数据块维度为10×1×1,所述注意力输出矩阵中的头维度×样本维度×数据块维度为10×10×10,经过所述加权乘计算得到的所述更新后的注意力输出矩阵中的头维度×样本维度×数据块维度为10×10×10。这样,也就使得所述更新后的注意力输出矩阵中的各维度与原始的所述注意力输出矩阵中的各维度保持一致,避免因池化操作导致的降维而对影响神经网络模型的训练造成影响。
可选地,在获得更新后的注意力输出矩阵后,基于所述更新后的注意力输出矩阵对所述神经网络模型进行训练或者说优化。可以理解地,神经网络模型是由存在输入输出关系的大量网络结构和网络参数构成的,其中一个网络结构和/或网络参数的改变,会影响与其存在直接或间接连接关系的网络结构的输出,进而调整神经网络模型的最终输出。本公开实施例中,更新后的注意力输出矩阵中的各参数相比于原始的所述注意力输出矩阵的各参数发生了变化,参数的变化也就会影响神经网络模型中其他网络结构的输出,这样也就能够对神经网络模型的网络结构和/或网络参数进行优化,进而改变神经网络模型的最终输出,进而以实现对神经网络模型的训练,提升神经网络模型的输出精度。需要说明地,所述神经网络模型的训练可以是参照相关技术,例如电子设备在获取输入神经网络模型的样本后,将该样本输入注意力模块以获取输出的注意力输出矩阵,然后基于上述步骤对所述注意力输出矩阵进行池化、卷积、归一化等操作,以改变神经网络模型中后续层级网络结构的输出,进而以对神经网络模型进行优化和训练,以提升神经网络模型的最终输出的精确度。
本公开实施例中,基于多头注意力机制,以及对神经网络模型中缩放点积注意力模块输出的注意力输出矩阵进行池化和卷积操作,使得所述神经网络模型中各个head之间能够进行交互学习,以使得不同head之间的注意力输出进行互相关,得到不同head的权重,进而让所述神经网络模型更关注重要的head的学习而弱化不重要的head的学习,以对所述神经网络模型进行训练,从而使得所述神经网络模型的最终输出的精确度更高。
另外,相比于现有技术中如何获得性能较好的神经网络结构,需要技术人员具备大量的神经网络结构设计及参数调整经验,多次更换、实验不同结构的神经网络需要耗费大量的硬件资源,而本公开实施例中,通过电子设备针对性地对神经网络模型中注意力模块输出的注意力输出矩阵进行池化和卷积操作,能够对神经网络模型的网络参数实现降维处理,进而能得到最佳的神经网络模型,使得神经网络的性能得到提高,也有效节省了计算资源和存储资源。
可选地,所述步骤S103、基于所述神经网络模型的卷积层对所述第一输出矩阵进行卷积操作,确定卷积后的第二输出矩阵,并获取所述第二输出矩阵中各个头的输出值,包括:
基于所述神经网络模型的卷积层对所述第一输出矩阵中各个头的输出值进行交互学习以实现卷积操作,确定卷积后的第二输出矩阵,并获取所述第二输出矩阵中各个头的输出值。
其中,所述第二输出矩阵中目标头的输出值为所述第一输出矩阵中对应的所述目标头的输出值经交互学习后得到的输出值,所述目标头为所述第二输出矩阵中任一个头。
本公开实施例中,所述神经网络模型引入了多头注意力机制,在基于池化层对注意力输出矩阵的样本维度和数据块维度进行降维处理后,获得池化后的第一输出矩阵,也即第一输出矩阵的头维度是和所述注意力输出矩阵的头维度保持一致的。可选地,所述池化后的第一输出矩阵的样本维度和数据块维度均为1,基于卷积层的1×1的卷积核对所述第一输出矩阵进行卷积操作,进而也就使得所述第一输出矩阵中的各个头进行交互学习,从而影响各个头的权重和输出值,获得卷积后的第二输出矩阵。其中,所述第二输出矩阵中某个头的输出值也即第一输出矩阵中该头经交互学习后得到的输出值。
本公开实施例中,通过卷积层的卷积操作来实现第一输出矩阵中各个头之间的交互学习,以使得不同head之间的注意力输出进行互相关,让所述神经网络模型更关注重要的head的学习而弱化不重要的head的学习,进而以改变各个头的权重和输出值,以优化所述神经网络模型的网络结构和网络参数,依此对神经网络模型进行训练,从而提升所述神经网络模型的最终输出的精确度。
可选地,所述步骤S104、对所述第二输出矩阵中各个头的输出值进行归一化处理和加权处理,获得更新后的第二输出矩阵,包括:
对所述第二输出矩阵中各个头的输出值进行归一化处理,获得所述第二输出矩阵中各个头的权重值;
基于所述第二输出矩阵中目标头的权重值对所述目标头的输出值进行加权处理,所述目标头为所述第二输出矩阵中任一个头;
基于所述加权处理,获得更新后的第二输出矩阵。
本公开实施例中,在基于卷积层的卷积操作获得第二输出矩阵后,对所述第二输出矩阵中各个头的输出值可以是进行归一化处理,例如基于归一化指数函数(softmax)来获得所述第二输出矩阵中各个头的权重值。
进一步地,对所述第二输出矩阵中各个头的输出值进行加权处理。例如对于所述第二输出矩阵中的某个头,将该头的输出值与该头对应的权重值相乘,以获得该头更新后的输出值。基于这样的加权处理,从而也就能够对所述第二输出矩阵的各个头的输出值进行更新,进而得到更新后的第二输出矩阵。这样,也就能够让所述神经网络模型更关注重要的头的学习,从而对所述神经网络模型进行优化,并依此优化后的神经网络模型进行训练,提升神经网络模型的输出精度。
可选地,所述步骤S105、基于所述更新后的第二输出矩阵获取更新后的注意力输出矩阵,并基于所述更新后的注意力输出矩阵训练所述神经网络模型,包括:
对所述注意力输出矩阵及所述更新后的第二输出矩阵进行矩阵乘计算,获取更新后的注意力输出矩阵;
基于所述更新后的注意力输出矩阵训练所述神经网络模型。
需要说明地,在基于上述方式获得更新后的第二输出矩阵后,可以理解地,所述更新后的第二输出矩阵的样本维度和数据块维度是步骤S101中的所述注意力输出矩阵经过降维处理后的维度,例如这两个维度都为1,如果继续基于这样的维度来进行所述神经网络模型的网络参数的传递,会导致所述神经网络模型的最终输出出现偏差。
本公开实施例中,将所述更新后的第二输出矩阵与所述注意力输出矩阵进行矩阵乘计算,以得到更新后的注意力输出矩阵,进而所述更新后的注意力输出矩阵的各个参数维度和所述注意力输出矩阵的各个参数维度保持一致,避免神经网络模型的最终输出出现偏差。而所述更新后的注意力输出矩阵同时也是基于所述更新后的第二输出矩阵得到,进而基于所述更新后的注意力输出矩阵来继续神经网络模型中网络结构和网络参数的传递,也就使得所述神经网络模型的网络参数得到了优化和训练,让所述神经网络模型更加关注重要的head的学习,以提升所述神经网络模型的最终输出的精确度。
请参照图2,图2是本公开实施例提供的一种基于注意力的模型训练方法的流程图之二,如图2所示,获取神经网络模型中的注意力输出矩阵A,该注意力输出矩阵A包括头维度(head num)、样本维度(sample num)、数据块维度(patch num),对该注意力输出矩阵A进行全局平均池化(GAP)操作,得到第一输出矩阵A_p,该第一输出矩阵A_p的样本维度和数据块维度均为1,头维度(head num)不变,第一输出矩阵A_p的参数维度表示为head num×1×1;对该第一输出矩阵A_p进行卷积核为1×1的卷积(conv)操作和归一化(softmax)操作,得到第二输出矩阵A_w,该第二输出矩阵A_w的参数维度仍然为head num×1×1;将该第二输出矩阵A_w与注意力输出矩阵A进行矩阵乘(scale)操作,得到更新后的注意力输出矩阵A_new,该更新后的注意力输出矩阵A_new的头维度(head num)、样本维度(sample num)和数据块维度(patch num)与注意力输出矩阵A的头维度、样本维度和数据块维度一致,以基于更新后的注意力输出矩阵A_new来实现对神经网络模型的训练。
本公开实施例中,所述神经网络模型可以为视觉转换器(vision transformer),在神经网络模型的head输出上添加基于头(head-based)的压缩和激发(Squeeze-and-Excitation,SE)模块,也即通过对注意力输出矩阵进行池化和卷积操作,让神经网络模型不同head之间的注意力输出进行互相关,也即让神经网络模型的各个head之间能够进行交互学习,进而让所述神经网络模型更关注重要的head的学习而弱化不重要的head的学习,从而使得所述神经网络模型的最终输出的精确度更高。
请参照图3,图3是本公开实施例提供的一种基于注意力的模型训练装置的结构图,如图3所示,基于注意力的模型训练装置300包括:
获取模块301,用于获取神经网络模型中注意力模块的注意力输出矩阵,所述注意力输出矩阵包括头维度、样本维度和数据块维度;
池化模块302,用于基于所述神经网络模型的池化层对所述注意力输出矩阵的样本维度和数据块维度进行降维计算,确定池化后的第一输出矩阵;
卷积模块303,用于基于所述神经网络模型的卷积层对所述第一输出矩阵进行卷积操作,确定卷积后的第二输出矩阵,并获取所述第二输出矩阵中各个头的输出值;
更新模块304,用于对所述第二输出矩阵中各个头的输出值进行归一化处理和加权处理,获得更新后的第二输出矩阵;
训练模块305,用于基于所述更新后的第二输出矩阵获取更新后的注意力输出矩阵,并基于所述更新后的注意力输出矩阵训练所述神经网络模型。
可选地,所述池化后的第一输出矩阵中的样本维度和数据块维度均为1。
可选地,所述卷积模块303还用于:
基于所述神经网络模型的卷积层对所述第一输出矩阵中各个头的输出值进行交互学习以实现卷积操作,确定卷积后的第二输出矩阵,并获取所述第二输出矩阵中各个头的输出值;
其中,所述第二输出矩阵中目标头的输出值为所述第一输出矩阵中对应的所述目标头的输出值经交互学习后得到的输出值,所述目标头为所述第二输出矩阵中任一个头。
可选地,所述更新模块304还用于:
对所述第二输出矩阵中各个头的输出值进行归一化处理,获得所述第二输出矩阵中各个头的权重值;
基于所述第二输出矩阵中目标头的权重值对所述目标头的输出值进行加权处理,所述目标头为所述第二输出矩阵中任一个头;
基于所述加权处理,获得更新后的第二输出矩阵。
可选地,所述训练模块305还用于:
对所述注意力输出矩阵及所述更新后的第二输出矩阵进行矩阵乘计算,获取更新后的注意力输出矩阵;
基于所述更新后的注意力输出矩阵训练所述神经网络模型。
本公开实施例中,通过对注意力输出矩阵进行池化和卷积操作,使得神经网络模型中各个head之间能够进行交互学习,进而让所述神经网络模型更关注重要的head的学习而弱化不重要的head的学习,从而使得所述神经网络模型的最终输出的精确度更高。
需要说明地,本公开实施例提供的基于注意力的模型训练装置300能够实现上述图1和图2所述基于注意力的模型训练方法实施例中的全部技术方案,因此至少能够实现上述图1和图2所述方法实施例的全部技术效果,此处不再赘述。
本公开的技术方案中,所涉及的用户个人信息的获取,存储和应用等,均符合相关法律法规的规定,且不违背公序良俗。
根据本公开的实施例,本公开还提供了一种电子设备、一种可读存储介质和一种计算机程序产品。
图4示出了可以用来实施本公开的实施例的示例电子设备400的示意性框图。电子设备旨在表示各种形式的数字计算机,诸如,膝上型计算机、台式计算机、工作台、个人数字助理、服务器、刀片式服务器、大型计算机、和其它适合的计算机。电子设备还可以表示各种形式的移动装置,诸如,个人数字处理、蜂窝电话、智能电话、可穿戴设备和其它类似的计算装置。本文所示的部件、它们的连接和关系、以及它们的功能仅仅作为示例,并且不意在限制本文中描述的和/或者要求的本公开的实现。
如图4所示,电子设备400包括计算单元401,其可以根据存储在只读存储器(ROM)402中的计算机程序或者从存储单元408加载到随机访问存储器(RAM)403中的计算机程序,来执行各种适当的动作和处理。在RAM 403中,还可存储电子设备400操作所需的各种程序和数据。计算单元401、ROM 402以及RAM 403通过总线404彼此相连。输入/输出(I/O)接口405也连接至总线404。
电子设备400中的多个部件连接至I/O接口405,包括:输入单元406,例如键盘、鼠标等;输出单元407,例如各种类型的显示器、扬声器等;存储单元408,例如磁盘、光盘等;以及通信单元409,例如网卡、调制解调器、无线通信收发机等。通信单元409允许电子设备400通过诸如因特网的计算机网络和/或各种电信网络与其他设备交换信息/数据。
计算单元401可以是各种具有处理和计算能力的通用和/或专用处理组件。计算单元401的一些示例包括但不限于中央处理单元(CPU)、图形处理单元(GPU)、各种专用的人工智能(AI)计算芯片、各种运行机器学习模型算法的计算单元、数字信号处理器(DSP)、以及任何适当的处理器、控制器、微控制器等。计算单元401执行上文所描述的各个方法和处理,例如基于注意力的模型训练方法。例如,在一些实施例中,基于注意力的模型训练方法可被实现为计算机软件程序,其被有形地包含于机器可读介质,例如存储单元408。在一些实施例中,计算机程序的部分或者全部可以经由ROM 402和/或通信单元409而被载入和/或安装到电子设备400上。当计算机程序加载到RAM 403并由计算单元401执行时,可以执行上文描述的基于注意力的模型训练方法的一个或多个步骤。备选地,在其他实施例中,计算单元401可以通过其他任何适当的方式(例如,借助于固件)而被配置为执行基于注意力的模型训练方法。
本文中以上描述的系统和技术的各种实施方式可以在数字电子电路系统、集成电路系统、场可编程门阵列(FPGA)、专用集成电路(ASIC)、专用标准产品(ASSP)、芯片上系统的系统(SOC)、负载可编程逻辑设备(CPLD)、计算机硬件、固件、软件、和/或它们的组合中实现。这些各种实施方式可以包括:实施在一个或者多个计算机程序中,该一个或者多个计算机程序可在包括至少一个可编程处理器的可编程系统上执行和/或解释,该可编程处理器可以是专用或者通用可编程处理器,可以从存储系统、至少一个输入装置、和至少一个输出装置接收数据和指令,并且将数据和指令传输至该存储系统、该至少一个输入装置、和该至少一个输出装置。
用于实施本公开的方法的程序代码可以采用一个或多个编程语言的任何组合来编写。这些程序代码可以提供给通用计算机、专用计算机或其他可编程数据处理装置的处理器或控制器,使得程序代码当由处理器或控制器执行时使流程图和/或框图中所规定的功能/操作被实施。程序代码可以完全在机器上执行、部分地在机器上执行,作为独立软件包部分地在机器上执行且部分地在远程机器上执行或完全在远程机器或服务器上执行。
在本公开的上下文中,机器可读介质可以是有形的介质,其可以包含或存储以供指令执行系统、装置或设备使用或与指令执行系统、装置或设备结合地使用的程序。机器可读介质可以是机器可读信号介质或机器可读储存介质。机器可读介质可以包括但不限于电子的、磁性的、光学的、电磁的、红外的、或半导体系统、装置或设备,或者上述内容的任何合适组合。机器可读存储介质的更具体示例会包括基于一个或多个线的电气连接、便携式计算机盘、硬盘、随机存取存储器(RAM)、只读存储器(ROM)、可擦除可编程只读存储器(EPROM或快闪存储器)、光纤、便捷式紧凑盘只读存储器(CD-ROM)、光学储存设备、磁储存设备、或上述内容的任何合适组合。
为了提供与用户的交互,可以在计算机上实施此处描述的系统和技术,该计算机具有:用于向用户显示信息的显示装置(例如,CRT(阴极射线管)或者LCD(液晶显示器)监视器);以及键盘和指向装置(例如,鼠标或者轨迹球),用户可以通过该键盘和该指向装置来将输入提供给计算机。其它种类的装置还可以用于提供与用户的交互;例如,提供给用户的反馈可以是任何形式的传感反馈(例如,视觉反馈、听觉反馈、或者触觉反馈);并且可以用任何形式(包括声输入、语音输入或者、触觉输入)来接收来自用户的输入。
可以将此处描述的系统和技术实施在包括后台部件的计算系统(例如,作为数据服务器)、或者包括中间件部件的计算系统(例如,应用服务器)、或者包括前端部件的计算系统(例如,具有图形用户界面或者网络浏览器的用户计算机,用户可以通过该图形用户界面或者该网络浏览器来与此处描述的系统和技术的实施方式交互)、或者包括这种后台部件、中间件部件、或者前端部件的任何组合的计算系统中。可以通过任何形式或者介质的数字数据通信(例如,通信网络)来将系统的部件相互连接。通信网络的示例包括:局域网(LAN)、广域网(WAN)和互联网。
计算机系统可以包括客户端和服务器。客户端和服务器一般远离彼此并且通常通过通信网络进行交互。通过在相应的计算机上运行并且彼此具有客户端-服务器关系的计算机程序来产生客户端和服务器的关系。服务器可以是云服务器,也可以为分布式系统的服务器,或者是结合了区块链的服务器。
应该理解,可以使用上面所示的各种形式的流程,重新排序、增加或删除步骤。例如,本公开中记载的各步骤可以并行地执行也可以顺序地执行也可以不同的次序执行,只要能够实现本公开公开的技术方案所期望的结果,本文在此不进行限制。
上述具体实施方式,并不构成对本公开保护范围的限制。本领域技术人员应该明白的是,根据设计要求和其他因素,可以进行各种修改、组合、子组合和替代。任何在本公开的精神和原则之内所作的修改、等同替换和改进等,均应包含在本公开保护范围之内。
Claims (13)
1.一种基于注意力的模型训练方法,包括:
获取神经网络模型中注意力模块的注意力输出矩阵,所述注意力输出矩阵包括头维度、样本维度和数据块维度;
基于所述神经网络模型的池化层对所述注意力输出矩阵的样本维度和数据块维度进行降维计算,确定池化后的第一输出矩阵;
基于所述神经网络模型的卷积层对所述第一输出矩阵进行卷积操作,确定卷积后的第二输出矩阵,并获取所述第二输出矩阵中各个头的输出值;
对所述第二输出矩阵中各个头的输出值进行归一化处理和加权处理,获得更新后的第二输出矩阵;
基于所述更新后的第二输出矩阵获取更新后的注意力输出矩阵,并基于所述更新后的注意力输出矩阵训练所述神经网络模型。
2.根据权利要求1所述的方法,其中,所述池化后的第一输出矩阵中的样本维度和数据块维度均为1。
3.根据权利要求1所述的方法,其中,所述基于所述神经网络模型的卷积层对所述第一输出矩阵进行卷积操作,确定卷积后的第二输出矩阵,并获取所述第二输出矩阵中各个头的输出值,包括:
基于所述神经网络模型的卷积层对所述第一输出矩阵中各个头的输出值进行交互学习以实现卷积操作,确定卷积后的第二输出矩阵,并获取所述第二输出矩阵中各个头的输出值;
其中,所述第二输出矩阵中目标头的输出值为所述第一输出矩阵中对应的所述目标头的输出值经交互学习后得到的输出值,所述目标头为所述第二输出矩阵中任一个头。
4.根据权利要求1所述的方法,其中,所述对所述第二输出矩阵中各个头的输出值进行归一化处理和加权处理,获得更新后的第二输出矩阵,包括:
对所述第二输出矩阵中各个头的输出值进行归一化处理,获得所述第二输出矩阵中各个头的权重值;
基于所述第二输出矩阵中目标头的权重值对所述目标头的输出值进行加权处理,所述目标头为所述第二输出矩阵中任一个头;
基于所述加权处理,获得更新后的第二输出矩阵。
5.根据权利要求1所述的方法,其中,所述基于所述更新后的第二输出矩阵获取更新后的注意力输出矩阵,并基于所述更新后的注意力输出矩阵训练所述神经网络模型,包括:
对所述注意力输出矩阵及所述更新后的第二输出矩阵进行矩阵乘计算,获取更新后的注意力输出矩阵;
基于所述更新后的注意力输出矩阵训练所述神经网络模型。
6.一种基于注意力的模型训练装置,包括:
获取模块,用于获取神经网络模型中注意力模块的注意力输出矩阵,所述注意力输出矩阵包括头维度、样本维度和数据块维度;
池化模块,用于基于所述神经网络模型的池化层对所述注意力输出矩阵的样本维度和数据块维度进行降维计算,确定池化后的第一输出矩阵;
卷积模块,用于基于所述神经网络模型的卷积层对所述第一输出矩阵进行卷积操作,确定卷积后的第二输出矩阵,并获取所述第二输出矩阵中各个头的输出值;
更新模块,用于对所述第二输出矩阵中各个头的输出值进行归一化处理和加权处理,获得更新后的第二输出矩阵;
训练模块,用于基于所述更新后的第二输出矩阵获取更新后的注意力输出矩阵,并基于所述更新后的注意力输出矩阵训练所述神经网络模型。
7.根据权利要求6所述的装置,其中,所述池化后的第一输出矩阵中的样本维度和数据块维度均为1。
8.根据权利要求6所述的装置,其中,所述卷积模块还用于:
基于所述神经网络模型的卷积层对所述第一输出矩阵中各个头的输出值进行交互学习以实现卷积操作,确定卷积后的第二输出矩阵,并获取所述第二输出矩阵中各个头的输出值;
其中,所述第二输出矩阵中目标头的输出值为所述第一输出矩阵中对应的所述目标头的输出值经交互学习后得到的输出值,所述目标头为所述第二输出矩阵中任一个头。
9.根据权利要求6所述的装置,其中,所述更新模块还用于:
对所述第二输出矩阵中各个头的输出值进行归一化处理,获得所述第二输出矩阵中各个头的权重值;
基于所述第二输出矩阵中目标头的权重值对所述目标头的输出值进行加权处理,所述目标头为所述第二输出矩阵中任一个头;
基于所述加权处理,获得更新后的第二输出矩阵。
10.根据权利要求6所述的装置,其中,所述训练模块还用于:
对所述注意力输出矩阵及所述更新后的第二输出矩阵进行矩阵乘计算,获取更新后的注意力输出矩阵;
基于所述更新后的注意力输出矩阵训练所述神经网络模型。
11.一种电子设备,包括:
至少一个处理器;以及
与所述至少一个处理器通信连接的存储器;其中,
所述存储器存储有可被所述至少一个处理器执行的指令,所述指令被所述至少一个处理器执行,以使所述至少一个处理器能够执行权利要求1-5中任一项所述的方法。
12.一种存储有计算机指令的非瞬时计算机可读存储介质,其中,所述计算机指令用于使所述计算机执行根据权利要求1-5中任一项所述的方法。
13.一种计算机程序产品,包括计算机程序,所述计算机程序在被处理器执行时实现根据权利要求1-5中任一项所述的方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210102176.2A CN114494814A (zh) | 2022-01-27 | 2022-01-27 | 基于注意力的模型训练方法、装置及电子设备 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210102176.2A CN114494814A (zh) | 2022-01-27 | 2022-01-27 | 基于注意力的模型训练方法、装置及电子设备 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN114494814A true CN114494814A (zh) | 2022-05-13 |
Family
ID=81476623
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202210102176.2A Pending CN114494814A (zh) | 2022-01-27 | 2022-01-27 | 基于注意力的模型训练方法、装置及电子设备 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN114494814A (zh) |
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN114819149A (zh) * | 2022-06-28 | 2022-07-29 | 深圳比特微电子科技有限公司 | 基于变换神经网络的数据处理方法、装置和介质 |
CN114999637A (zh) * | 2022-07-18 | 2022-09-02 | 华东交通大学 | 多角度编码与嵌入式互学习的病理图像诊断方法与系统 |
Citations (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN111626119A (zh) * | 2020-04-23 | 2020-09-04 | 北京百度网讯科技有限公司 | 目标识别模型训练方法、装置、设备以及存储介质 |
CN113379655A (zh) * | 2021-05-18 | 2021-09-10 | 电子科技大学 | 一种基于动态自注意力生成对抗网络的图像合成方法 |
-
2022
- 2022-01-27 CN CN202210102176.2A patent/CN114494814A/zh active Pending
Patent Citations (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN111626119A (zh) * | 2020-04-23 | 2020-09-04 | 北京百度网讯科技有限公司 | 目标识别模型训练方法、装置、设备以及存储介质 |
CN113379655A (zh) * | 2021-05-18 | 2021-09-10 | 电子科技大学 | 一种基于动态自注意力生成对抗网络的图像合成方法 |
Non-Patent Citations (1)
Title |
---|
QILONG WANG等: "ECA-Net: Efficient Channel Attention for Deep Convolutional Neural Networks", 2020 IEEE/CVF CONFERENCE ON COMPUTER VISION AND PATTERN RECOGNITION (CVPR), 5 August 2020 (2020-08-05), pages 11531 - 11539 * |
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN114819149A (zh) * | 2022-06-28 | 2022-07-29 | 深圳比特微电子科技有限公司 | 基于变换神经网络的数据处理方法、装置和介质 |
CN114999637A (zh) * | 2022-07-18 | 2022-09-02 | 华东交通大学 | 多角度编码与嵌入式互学习的病理图像诊断方法与系统 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN113239705A (zh) | 语义表示模型的预训练方法、装置、电子设备和存储介质 | |
CN114494814A (zh) | 基于注意力的模型训练方法、装置及电子设备 | |
CN112580732B (zh) | 模型训练方法、装置、设备、存储介质和程序产品 | |
CN113705628B (zh) | 预训练模型的确定方法、装置、电子设备以及存储介质 | |
CN115456167B (zh) | 轻量级模型训练方法、图像处理方法、装置及电子设备 | |
CN114821063A (zh) | 语义分割模型的生成方法及装置、图像的处理方法 | |
CN113516185B (zh) | 模型训练的方法、装置、电子设备及存储介质 | |
CN113642710A (zh) | 一种网络模型的量化方法、装置、设备和存储介质 | |
CN113052063A (zh) | 置信度阈值选择方法、装置、设备以及存储介质 | |
CN117351299A (zh) | 图像生成及模型训练方法、装置、设备和存储介质 | |
CN113361621B (zh) | 用于训练模型的方法和装置 | |
CN113642654B (zh) | 图像特征的融合方法、装置、电子设备和存储介质 | |
CN116363444A (zh) | 模糊分类模型训练方法、识别模糊图像的方法及装置 | |
CN114707638A (zh) | 模型训练、对象识别方法及装置、设备、介质和产品 | |
CN114898742A (zh) | 流式语音识别模型的训练方法、装置、设备和存储介质 | |
CN115482443A (zh) | 图像特征融合及模型训练方法、装置、设备以及存储介质 | |
CN112784967B (zh) | 信息处理方法、装置以及电子设备 | |
CN113361575A (zh) | 模型训练方法、装置和电子设备 | |
CN114120416A (zh) | 模型训练方法、装置、电子设备及介质 | |
CN114254028A (zh) | 事件属性抽取方法、装置、电子设备和存储介质 | |
CN114021642A (zh) | 数据处理方法、装置、电子设备和存储介质 | |
CN116151215B (zh) | 文本处理方法、深度学习模型训练方法、装置以及设备 | |
CN114549948B (zh) | 深度学习模型的训练方法、图像识别方法、装置和设备 | |
CN114186097A (zh) | 用于训练模型的方法和装置 | |
CN115758142A (zh) | 深度学习模型的训练方法、数据处理方法和装置 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |