CN116805162A - 基于自监督学习的Transformer模型训练方法 - Google Patents
基于自监督学习的Transformer模型训练方法 Download PDFInfo
- Publication number
- CN116805162A CN116805162A CN202310475772.XA CN202310475772A CN116805162A CN 116805162 A CN116805162 A CN 116805162A CN 202310475772 A CN202310475772 A CN 202310475772A CN 116805162 A CN116805162 A CN 116805162A
- Authority
- CN
- China
- Prior art keywords
- model
- student
- parameters
- teacher
- self
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000012549 training Methods 0.000 title claims abstract description 79
- 238000000034 method Methods 0.000 title claims abstract description 73
- 238000009826 distribution Methods 0.000 claims description 52
- 230000006870 function Effects 0.000 claims description 20
- 238000011176 pooling Methods 0.000 claims description 8
- 238000013140 knowledge distillation Methods 0.000 abstract description 31
- 238000012546 transfer Methods 0.000 abstract description 5
- 230000008569 process Effects 0.000 description 19
- 238000004422 calculation algorithm Methods 0.000 description 11
- 238000007781 pre-processing Methods 0.000 description 11
- 238000004590 computer program Methods 0.000 description 10
- 238000010586 diagram Methods 0.000 description 10
- 238000012545 processing Methods 0.000 description 9
- 238000004891 communication Methods 0.000 description 8
- 230000000694 effects Effects 0.000 description 7
- 238000004821 distillation Methods 0.000 description 6
- 238000005516 engineering process Methods 0.000 description 4
- 230000006872 improvement Effects 0.000 description 4
- 239000013598 vector Substances 0.000 description 4
- 230000003287 optical effect Effects 0.000 description 3
- 238000013459 approach Methods 0.000 description 2
- 238000004364 calculation method Methods 0.000 description 2
- 230000006835 compression Effects 0.000 description 2
- 238000007906 compression Methods 0.000 description 2
- 238000010276 construction Methods 0.000 description 2
- 238000011161 development Methods 0.000 description 2
- 230000014509 gene expression Effects 0.000 description 2
- 230000003993 interaction Effects 0.000 description 2
- 238000002372 labelling Methods 0.000 description 2
- 239000004973 liquid crystal related substance Substances 0.000 description 2
- 238000012986 modification Methods 0.000 description 2
- 230000004048 modification Effects 0.000 description 2
- 238000003062 neural network model Methods 0.000 description 2
- 238000009827 uniform distribution Methods 0.000 description 2
- 238000012935 Averaging Methods 0.000 description 1
- 238000003491 array Methods 0.000 description 1
- 238000013473 artificial intelligence Methods 0.000 description 1
- 230000001413 cellular effect Effects 0.000 description 1
- 238000012512 characterization method Methods 0.000 description 1
- 238000000354 decomposition reaction Methods 0.000 description 1
- 238000013135 deep learning Methods 0.000 description 1
- 238000013461 design Methods 0.000 description 1
- 238000011156 evaluation Methods 0.000 description 1
- 238000002474 experimental method Methods 0.000 description 1
- 239000000284 extract Substances 0.000 description 1
- 238000000605 extraction Methods 0.000 description 1
- 238000009499 grossing Methods 0.000 description 1
- 238000011423 initialization method Methods 0.000 description 1
- 238000007689 inspection Methods 0.000 description 1
- 238000010801 machine learning Methods 0.000 description 1
- 238000005259 measurement Methods 0.000 description 1
- 230000007246 mechanism Effects 0.000 description 1
- 239000013307 optical fiber Substances 0.000 description 1
- 238000012216 screening Methods 0.000 description 1
- 239000004065 semiconductor Substances 0.000 description 1
- 230000001953 sensory effect Effects 0.000 description 1
- 238000006467 substitution reaction Methods 0.000 description 1
- 238000012795 verification Methods 0.000 description 1
- 230000000007 visual effect Effects 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/0895—Weakly supervised learning, e.g. semi-supervised or self-supervised learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/0464—Convolutional networks [CNN, ConvNet]
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/047—Probabilistic or stochastic networks
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02T—CLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
- Y02T10/00—Road transport of goods or passengers
- Y02T10/10—Internal combustion engine [ICE] based vehicles
- Y02T10/40—Engine management systems
Landscapes
- Engineering & Computer Science (AREA)
- Physics & Mathematics (AREA)
- Theoretical Computer Science (AREA)
- General Health & Medical Sciences (AREA)
- Computing Systems (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Data Mining & Analysis (AREA)
- Evolutionary Computation (AREA)
- Life Sciences & Earth Sciences (AREA)
- Molecular Biology (AREA)
- Artificial Intelligence (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Mathematical Physics (AREA)
- Software Systems (AREA)
- Health & Medical Sciences (AREA)
- Probability & Statistics with Applications (AREA)
- Image Analysis (AREA)
Abstract
本公开的实施例提供了一种基于自监督学习的Transformer模型训练方法。应用于深度学习领域,所述方法包括将训练数据集输入一个包含教师模型和学生模型的预设的Transformer模型,每个模型的输出有两部分,一个对应目标类任务,一个对应非目标类任务,将两部分输入到一个带有超参数的解耦Kullback‑Leibler散度公式,获取损失值,利用梯度反向传播,更新学生模型的自身参数并不断调整超参数,同时基于学生模型的自身参数更新教师模型的自身参数,直到学生模型达到收敛。以此方式,本公开将传统的知识蒸馏思想解耦成目标知识蒸馏和非目标知识蒸馏,使目标知识蒸馏和非目标知识蒸馏两部分各自的权重更易调整,最终提高了教师模型和学生模型间的知识转移有效性。
Description
技术领域
本发明涉及深度学习技术领域,尤其涉及一种基于自监督学习的Transformer模型训练方法。
背景技术
自监督学习是一种模型训练方式,主要是利用辅助任务从大规模的无监督数据中挖掘自身的监督信息,通过这种构造的监督信息对网络进行训练,从而可以学习到对下游任务有价值的表征。自监督学习技术在一定程度上缓解了下游任务训练数据少,训练数据无标签等情况带来的模型性能欠佳问题。
知识蒸馏是一种模型压缩方式,采用一种采用“教师-学生模型架构”和Kullback-Leibler散度模型的训练方法。其中教师模型是任一种已完成训练且效果较好的模型,学生模型是一个随机初始化的模型,学生模型需通过知识蒸馏技术学习到教师模型的暗知识,以达到教师模型的效果。知识蒸馏领域出现了新的方法,例如自蒸馏。自蒸馏是采用自监督学习进行知识蒸馏。它也是采用“教师-学生网模型”,与知识蒸馏不同的是,教师模型和学生模型是同结构,且教师模型和学生模型是共同训练,而非使用已经完成训练且效果较好的模型作为教师模型。
最近,基于自监督学习的预训练模式以其良好的数据利用效率和泛化能力,在计算机视觉中获得了非常巨大的进展。随着自监督学习的不断发展,各种算法层出不穷。其中,一种名为DINO模型的采用了传统的自蒸馏方式的自监督学习算法取得了良好的成绩。DINO是一种基于自监督学习的Transformer模型训练方法。通过在相同图像的不同视图上匹配模型输出,DINO模型算法使用教师模型和学生模型的logit分布进行自蒸馏,能够有效地发现目标对象和跨图像的共享特征。DINO模型算法的研发,使得视觉技术又向前迈进了一大步,省去以往传统的数据标注才能识别,减少计算的过程,降低训练时间,更加高效便捷地识别人或物,软硬件协作更为紧密。
然而,DINO模型算法在多项实验中表现不及基于深度特征的蒸馏方式。因为目标与非目标logit高度耦合,DINO模型算法无法防止非目标知识蒸馏部分中暗知识的转移。也无法控制目标知识蒸馏和非目标知识蒸馏两部分各自的知识贡献。
发明内容
本公开提供了一种基于自监督学习的Transformer模型训练方法、装置、设备以及存储介质。
根据本公开的第一方面,提供了一种基于自监督学习的Transformer模型训练方法。该方法包括:获取训练数据集;将所述训练数据集输入预设的Transformer模型,所述Transformer模型包括教师模型和学生模型;将所述教师模型和所述学生模型的输出输入到带有超参数的解耦Kullback-Leibler散度公式,获取损失值;基于所述损失值,利用梯度反向传播,更新所述学生模型的自身参数并不断调整所述超参数,同时基于所述学生模型的自身参数更新教师模型的自身参数,直到所述学生模型达到收敛,保存所述学生模型。
如上所述的方面和任一可能的实现方式,进一步提供一种实现方式,所述训练数据集包括:多组图像数据;每组图像数据包含对应同一原始图像的选定数量的全局视图和选定数量的局部视图;将每组图像数据输入学生模型;将每组图像数据中的全局视图输入教师模型。
如上所述的方面和任一可能的实现方式,进一步提供一种实现方式,所述学生模型包括:骨干模型,全局平均池化层和两个Softmax层,所述两个softmax层具有不同输出维度,一个Softmax层对应的是目标类任务,另一个Softmax层对应的是非目标类任务。
如上所述的方面和任一可能的实现方式,进一步提供一种实现方式,所述教师模型包括:骨干模型,center层,全局平均池化层和两个Softmax层,所述两个softmax层具有不同输出维度,一个Softmax层对应的是目标类任务,另一个Softmax层对应的是非目标类任务。
如上所述的方面和任一可能的实现方式,进一步提供一种实现方式,所述解耦Kullback-Leibler散度公式为:
其中,DKL(pT||pS)是解耦Kullback-Leibler散度;pT和pS分别为教师目标logit分布和学生目标logit分布;是目标类概率分布的Kullback-Leibler散度;/>和/>是pT和pS的二进制表示;/>是非目标类概率分布的Kullback-Leibler散度;/>和/>分别为教师非目标类别概率分布和学生非目标类别概率分布;γ和δ是超参数。
如上所述的方面和任一可能的实现方式,进一步提供一种实现方式,所述更新教师模型的自身参数包括:利用所述学生模型的指数滑动平均参数来更新所述教师模型的自身参数;所述学生模型的指数滑动平均参数是学生模型自身参数的指数移动平均值。
根据本公开的第二方面,提供了一种基于自监督学习的Transformer模型训练装置。该装置包括:数据单元,用于获取训练数据集;Transformer模型单元,用于将所述训练数据集输入预设的Transformer模型,所述Transformer模型包括教师模型和学生模型;损失函数单元,用于将所述教师模型和所述学生模型的输出输入到带有超参数的解耦Kullback-Leibler散度公式,获取损失值;训练单元,用于基于所述损失值,利用梯度反向传播,更新所述学生模型的自身参数并不断调整所述超参数,同时基于所述学生模型的自身参数更新教师模型的自身参数,直到所述学生模型达到收敛,保存所述Transformer模型的学生模型。
根据本公开的第三方面,提供了一种电子设备。该电子设备包括:存储器和处理器,所述存储器上存储有计算机程序,所述处理器执行所述程序时实现如以上所述的方法。
根据本公开的第四方面,提供了一种计算机可读存储介质,其上存储有计算机程序,所述程序被处理器执行时实现如根据本公开的第一方面和/或第二发面的方法。
本公开提出一种基于自监督学习的Transformer模型训练方法,引入了解耦知识蒸馏思想和重新设计了整体模型架构,改进了DINO模型,解耦了Kullback-Leibler散度公式,将传统的知识蒸馏思想解耦成目标知识蒸馏和非目标知识蒸馏,解除非目标知识蒸馏和教师模型logit分布的负相关关系,引入两个可控超参数,使目标知识蒸馏和非目标知识蒸馏两部分各自的权重更易调整,最终提高了教师模型和学生模型间的知识转移有效性,进而提高整体训练效果和效率。
应当理解,发明内容部分中所描述的内容并非旨在限定本公开的实施例的关键或重要特征,亦非用于限制本公开的范围。本公开的其它特征将通过以下的描述变得容易理解。
附图说明
结合附图并参考以下详细说明,本公开各实施例的上述和其他特征、优点及方面将变得更加明显。附图用于更好地理解本方案,不构成对本公开的限定在附图中,相同或相似的附图标记表示相同或相似的元素,其中:
图1示出了根据示例性实施例的基于自监督学习的Transformer模型训练系统的架构框图;
图2示出了根据本公开的实施例的基于自监督学习的Transformer模型训练的流程图;
图3示出了根据本公开的实施例的数据增强流程图;
图4示出了根据本公开的实施例的Transformer模型结构图;
图5示出了根据本公开的实施例的解耦Kullback-Leibler散度公式的流程图;
图6示出了根据本公开的实施例的利用指数滑动平均参数更新教师模型的流程图;
图7示出了根据本公开的实施例的一种基于自监督学习的Transformer模型训练装置700的框图;
图8示出了能够实施本公开的实施例的示例性电子设备的方框图。
具体实施方式
为使本公开实施例的目的、技术方案和优点更加清楚,下面将结合本公开实施例中的附图,对本公开实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例是本公开一部分实施例,而不是全部的实施例。基于本公开中的实施例,本领域普通技术人员在没有作出创造性劳动前提下所获得的全部其他实施例,都属于本公开保护的范围。
另外,本文中术语“和/或”,仅仅是一种描述关联对象的关联关系,表示可以存在三种关系,例如,A和/或B,可以表示:单独存在A,同时存在A和B,单独存在B这三种情况。另外,本文中字符“/”,一般表示前后关联对象是一种“或”的关系。
本公开中,通过改进了DINO模型和解耦Kullback-Leibler散度公式,将传统的知识蒸馏思想解耦成目标知识蒸馏和非目标知识蒸馏,解除非目标知识蒸馏和教师模型logit分布的负相关关系,引入两个可控超参数,最终提高了教师模型和学生模型间的知识转移有效性,进而提高整体训练效果和效率。
图1示出了根据示例性实施例的基于自监督学习的Transformer模型训练系统的架构框图。
在自监督学习的Transformer模型训练系统100中包括数据预处理模块110、模型训练模块120和预测模块130。
110是数据预处理模块,
在一些实施例中,数据预处理模块用于获取的数据,包括图像数据等数据进行数据预处理。数据预处理模块的作用是将原始数据处理成为网络模型可以接收的数据格式。数据预处理模块包括筛选单元111和图像增强单元112。预处理模块首先是从图像数据集中选出用于模型训练和检验的数据集。然后再将图像处理成符合模型训练要求的数据。如果图像的丰富度不满足要求,可以通过图像缩放调整图像的大小,也可以对数据进行图像剪裁增加数据的丰富度。图像缩放是通过将图像的高度和宽度乘以缩放因子来更加图像大小和像素的空间范围和纵横比。图像裁剪是提取图像的选定子区域并保留该区域的每个像素的空间范围。预处理模块也通过对图像的旋转,翻转和平移变换对图像进行增强。当使用图像增强的时候,虽然参见模型训练的图像的实际数量不变,但是训练过程中,每轮使用的数据集会不同。
120是模型训练模块,
在一些实施例中,模型训练模块包括数据单元121,Transformer模型单元122,损失函数单元123,训练单元124等单元。模型训练模块通常可以调整Transformer模型初始参数,可以调整训练选项。也可以选择训练数据和评估数据。同时调整损失函数或是选择不同的损失函数,选择模型准确度评估方法。另外可以根据损失函数输出监控训练过程。并不断调整数据集,调整超参数,观察损失函数输出来控制模型训练过程。
130是预测模块,
在一些实施例中,预测模块包括数据单元131,模型存储单元132,预测单元133和输出管理单元134。数据单元是用来管理待预测数据。模型存储单元是用来保存和调用训练好的模型。预测单元是加载训练好的模型,将待预测数据按照一定的顺序输入到训练好的模型。输出管理单元是保存模型的输出,并管理输出图像和待预测数据的关系。在一些实施例,输出管理单元需要对输出数据进行数据增强操作,例如色彩抖动、对比度调整和分辨率增强。
图2示出了根据本公开的实施例的一种基于自监督学习的Transformer模型训练方法的流程图。
参考图2,
在框210,
在一些实施例中,获取训练数据集。
需要说明的是,本公开所描述的模型训练的过程中,不需要对数据进行标注,因为本公开是一种基于自监督学习的Transformer模型训练方法,通过无监督的方式来训练模型。这种方法不需要使用人工标注的标签数据,而是使用自然的数据本身作为监督信号,从而可以有效地利用大量未标记的数据来训练模型。
在一些实施例中,通常采用公开数据集对模型进行训练。可以用常用的已经标记的数据集包括ImageNet、CIFAR-10、CIFAR-100等,也可以用没有标记的数据集,如ImageNet-21K等。同时也可以根据模型的具体应用领域单独采集训练数据。
在一些实施例中,根据本公开所涉及的模型的特点,当获取图像数据集后,尤其是公开数据集,通常需要利用数据增强方法进行预处理,得到增强后的无标签数据。在本公开所涉及的模型架构下,对于教师模型和学生模型的输入是完全独立的数据增强过程。将增强后的无标签数据分别输入到教师模型和学生模型。
在框220,
在一些实施例中,将所述训练数据集输入一个预设的Transformer模型,所述Transformer模型包括教师模型和学生模型。
需要说明的是,Transformer模型是一种基于自注意力机制的神经网络模型。在图像处理中,Transformer模型的输入是一张图像,输出可以是图片中某些目标的位置信息或者图像的标签。Transformer模型将图像拆分成若干个局部区域,然后对每个局部区域进行处理。在这个过程中,每个局部区域都会被转化为一个向量,这些向量将被不断地进行自注意力计算,从而获得局部区域之间的关系信息。
还需要说明的是,本公开是基于DINO(self-DIstillation with NO labels)模型进行改进。DINO是一种基于自监督学习的Transformer模型训练方法。DINO模型也是一种知识蒸馏方法。知识蒸馏是一种模型压缩方式,采用“教师-学生模型架构”和Kullback-Leibler散度的训练方法。其中教师模型和学生模型都是一个随机初始化的模型,学生模型需通过知识蒸馏技术学习到教师模型的暗知识,以达到教师模型的效果。DINO包含了一个骨架,Transformer解码器和多个预测头。DINO使用了两个Transformer编码器,一个作为教师模型,另一个作为学生模型。在训练过程中,教师模型将图像编码为特征向量,然后学生模型将这些特征向量映射到不同的视图。然后,通过比较教师模型和学生模型的输出,可以计算出损失函数,并使用反向传播算法来更新模型参数。
还需要说明的是,本公开对模型进行改进的部分是在所述教师模型和所述学生模型的输出分别添加了两个有不同输出维度的Softmax层,其中一个对应的是目标类任务,另一个对应的是非目标类任务。
在一些实施例中,需要对模型进行初始化。初始化参数指的是在网络模型训练之前,对各个节点的权重和偏置进行初始化赋值的过程。一般情况下,随机参数服从高斯分布/正态分布(Gaussian distribution/normaldistribution)和均匀分布(uniformdistribution)都是有效的初始化方法。
在框230,
在一些实施例中,将所述教师模型和所述学生模型的输出输入到带有超参数的解耦Kullback-Leibler散度公式,获取损失值。
在一些实施例中,教师模型和学生模型通过数据增强得到的图像在通过各自的主干网络和编码器后,计算出图像的匹配度也就是logits。该logits会通过使用softmax函数进行归一化处理。
在一些实施例中,基于目标类问题方法,将教师模型和学生模型输出的logits分别分成两个部分,其中一个对应的是目标类任务,另一个对应的是非目标类任务。也可以理解为一部分是目标类别的分数和非目标类别的分数,得到四个部分。
需要说明的是,和DINO模型不一样的地方是,本公开没有采用原始的Kullback-Leibler散度模型作为损失函数。Kullback-Leibler散度是一种用于评估预测分布和真实分布的交叉熵与真实分布的熵的差的方法。换句话说就是计算学生模型的logit分布和教师模型的logit分布的交叉熵。本公开是对Kullback-Leibler散度模型进行了解耦,并加入超参数(在图5会详细介绍解耦过程)。然后将解耦Kullback-Leibler散度公式作为训练的损失函数进行模型训练。
在一些实施例中,将这个四个部分输入到所述的解耦Kullback-Leibler散度公式,得到损失值。
在框240,
在一些实施例中,基于所述损失值,利用梯度反向传播,更新所述学生模型的自身参数并不断调整所述超参数,同时基于所述学生模型的自身参数更新教师模型的自身参数,直到所述学生模型达到收敛,保存所述学生模型。
在一些实施例中,根据损失函数所得损失值调整训练过程的误差,直到学生模型达到收敛。保存学生模型用于后续的预测或识别任务。
本公开提出的基于自监督学习的Transformer模型训练方法,首先不需要大量标记数据,而是通过知识蒸馏发现数据本事的结构和特征。另外它将传统的知识蒸馏思想解耦成目标知识蒸馏和非目标知识蒸馏,最终提高了教师模型和学生模型间的知识转移有效性,进而提高整体训练效果和效率。在降低模型的复杂度和计算成本的同时,在处理未标注数据时具有更好的泛化能力。
图3示出了根据本公开的实施例的数据增强流程图。
参考图3,
在框310,
在一些实施例中,从一个数据集中选择一部分用于训练模型的图像子集,
需要说明的是,用于模型训练的图像可以是无标签的图像。如果采用公开数据集对模型进行训练,例如ImageNet、CIFAR-10、CIFAR-100、ImageNet-21K、Google Landmarksv2等,而且需要对模型进行评估则需要对数据进行标注。同时也可以根据模型的具体应用领域单独采集训练数据,并为了支持模型检验对数据进行标注。
在一些实施例中,本公开是对DINO模型进行改进。因此采集数据,需要满足DINO模型的对于训练数据的要求。和其他神经网络模型一样,DINO模型训练,要求数据应该充足,多样化的,训练数据的图像质量和分辨率也应该足够高,以便模型学习到有意义的特征表示。
在框320,
在一些实施例中,利用数据增强方法进行预处理,得到增强后的无标签图像数据;
在一些实施例中,获取无标签数据后,会利用数据增强方法进行预处理,得到增强后的无标签图像数据。图像数据增强方法包括但不限于随机缩放裁剪、随机旋转、随机翻转等等。在本公开,最主要的数据增强方法是随机缩放。
在一些实施例中,输入图像需要基于随机缩放裁剪处理成局部视图。当提取局部视图的时候,对于每张输入图像,可以随机选择一部分像素来构成局部视图。例如,可以使用一个滑动窗口来遍历整张图像,并在每个位置上选择一个固定大小的窗口作为局部视图。通常情况下,小于原始图像50%覆盖区域的会被认为是局部视图。
在一些实施例中,输入图像需要基于随机缩放裁剪处理成全局视图。当提取局部视图的时候,对于全局视图,可以使用整张图像作为输入或是截取整张图像的较大面积的部分。如果对全局视图的数量要求较高,可以对输入图像进行随机的数据增强操作,如旋转、裁剪、缩放等,以获得更多的全局视角。通常情况下,大于原始图像50%覆盖区域的会被认为是全局视图。
在一些实施例中,利用随机缩放裁剪,可以生成一组不同视图的集合。该集合包几个全局视图以及几个较小分辨率的局部视图。所有视图都通过学生模型,而只有全局视图通过老师模型,因此鼓励“局部到全局”的对应。例如可以使用2个分辨率较大的全局视图,即为覆盖原始图像大区域的,以及几个分辨率较小的局部视图,即为仅覆盖原始图像的小区域的(例如小于50%)。
图4示出了根据本公开的实施例的Transformer模型结构图。
参考图4,
在一些实施例中,本公开所用模型是在DINO算法中的模型结构基础上进行修改。学生模型由一个骨干模型1,一个全局平均池化层1,和两个Softmax层(Softmax层1和Softmax层2)组成,而教师模型由一个骨干模型2,一个center层(和DINO算法中的对应层的结构及作用一样),一个全局平均池化层2,和两个Softmax层(Softmax层3和Softmax层4)组成。其中,学生模型和教师模型的骨干模型同为任一种Transformer模型。4层Softmax层的输入维度均一致,具体输入维度需视Transformer骨干模型输出维度而定。学生模型和教师模型的输出总维度参考了DINO算法中的设置,即为65536。由于本公开中已经将输出解耦成两部分,因此设置Softmax层1和Softmax层3输出维度为2,Softmax层2和Softmax层4输出维度为65534(65536-2),以分别学习目标概率分布和非目标概率分布。维度为2对应的就是目标类任务(目标概率分布),维度为65534对应的就是非目标类任务(非目标概率分布)。
给定输入,DINO用ResNet或ViT Transformer等主干网络提取多尺度特征,然后将它们与相应的位置嵌入一起输入Transformer编码器。
假设给定一组有M个样本的无标签预训练数据集U用于自监督训练,那么上述4层Softmax层对数据集U的logit输出可定义为:
其中,si表示不同Softmax层的序号,取值为1至4,即si={1,2,3,4}。Csi为各个Softmax层的相应输出维度,Csi={C1,C2,C3,C4}={2,65534,2,65534}。
在此发明中,学生模型Transformer骨干模型所输出需要输入至Softmax层1,将其余输出需要输入至全局平均池化层1,然后再输入至Softmax层2;教师模型Transformer骨干模型的输出需要输入至Center层,然后再输入到Softmax层3,而其余输出需要输入至全局平均池化层2,最后再输入到Softmax层4。
Softmax层1和Softmax层3的输出分别作为学生目标logit分布pS=p1和教师目标logit分布pT=p3,Softmax层2和Softmax层4的输出分别作为学生非目标logit分布和教师非目标logit分布/>
图5示出了根据本公开的实施例的解耦Kullback-Leibler散度公式的流程图。
参考图5,
在框510,
在一些实施例中,定义随机变量的预测概率。
在一些实施例中,以离散随机变量为例,有一个离散随机变量集X={x1,x2,x3,…,xn},其中某随机变量xi在某分布p中的对应概率为pi=p(X=xi),则随机变量集X的熵可定义为:
除了分布p外,还有另一个分布为q,则某随机变量xi在分布q的对应概率为qi=q(X=xi)。此时,可以使用以下公式表示两个分布的交叉熵,其代表了两个分布间的差异性信息:
在以上两个公式的基础熵,即可得到Kullback-Leibler散度(亦称为相对熵):
DKL(p||q)=H(p,q)-H(p),
在上述离散随机变量集X基础上,假设有其对应的类别标签集Y∈R1×K,其中K表示类别数量,则离散随机变量集X中某随机变量x的预测概率可写为其中第c类预测概率的基本表达式可表示为:
换个角度,若当前第t类为目标类别,然后将非目标类问题拆解出二分类问题,即相当于把第t类预测概率看成两部分,指第t类预测概率,/>非第t类预测概率,以下是两者数学公式:
其中,G指任一函数,exp为指数函数。
以上将原非目标类别概率转换成了一种新的目标类别概率分布,换言之,把目标类别概率从原非目标类概率分布中独立出来,形成新的目标类概率分布。在此基础上,进一步将原非目标类概率分布进行分解,即将非目标类别概率分布也中独立出来,可通过以下公式表达:
最后将原非目标类概率分布分解成一个新的目标类概率分布和非目标类别概率分布/>结合上述表达式,即可得到以下公式:
以上则为随机变量的预测概率的定义。
在框520,
在一些实施例中,分解Kullback-Leibler散度公式。
基于上述数学支持,即可将和另一任意概率分布/>代入Kullback-Leibler散度公式,并进行分解:
由最终分解结果可看出,Kullback-Leibler散度公式也被分解成了两部分,第一部分是两个分布间(目标类分布和非目标类分布)的目标类概率分布的Kullback-Leibler散度,第二部分两个分布间的非目标类别概率分布的Kullback-Leibler散度。为了实现完全解耦,使用两个可调超参数γ和δ,对这两部分进行控制,因此解耦Kullback Leibler散度公式如下所示:
将增强后的无标签数据分别输入到教师模型和学生模型,得到pS,pT,将pS,pT,/>代入上述解耦Kullback Leibler散度公式,即得到最终的解耦KullbackLeibler散度公式:
其中,DKL(pT||pS)是解耦Kullback-Leibler散度公式,即为在两个分布之间的Kullback-Leibler散度;pT和pS分别为教师目标logit分布和学生目标logit分布;公式右侧将其分解成两个Kullback-Leibler散度的和,其中和/>是pT和pS的二进制表示,对应所述目标类任务,γ和δ是超参数,/>和/>分别为教师非目标类别概率分布和学生非目标类别概率分布,对应所述非目标类任务。
图6示出了根据本公开的实施例的利用指数滑动平均参数更新教师模型的流程图。
参考图6,
在框610,
在一些实施例中,将全部softmax输出都传递到损失函数中。
需要说明的是,四个softmax输出都传递到损失函数中,使用随机梯度下降(SGD)执行反向传播。在这里的反向传播是通过学生模型执行的。
在框620,
在一些实施例中,对学生模型参数使用指数移动平均指数。
需要说明的是,为了更新教师模型,DINO对学生模型参数使用指数移动平均(EMA),将学生模型的模型参数传输到教师模型。EMA(Exponential Moving Average)是一种平滑技术,用于计算一个序列的指数移动平均值。在DINO模型中,EMA用于防止学生模型的输出特征向量的方差过大,防止过度拟合。
具体来说,对于每个学生模型参数,EMA可以根据以下公式对教师模型的参数进行更新:
其中,表示学生模型对样本xi的输出参数,EMA表示指数移动平均,α是一个滑动平均系数,在训练过程中遵循余弦计划从0.996到1。
需要说明的是,对于前述的各方法实施例,为了简单描述,故将其都表述为一系列的动作组合,但是本领域技术人员应该知悉,本公开并不受所描述的动作顺序的限制,因为依据本公开,某些步骤可以采用其他顺序或者同时进行。其次,本领域技术人员也应该知悉,说明书中所描述的实施例均属于可选实施例,所涉及的动作和模块并不一定是本公开所必须的。
以上是关于方法实施例的介绍,以下通过装置实施例,对本公开所述方案进行进一步说明。
图7示出了根据本公开的实施例的一种基于自监督学习的Transformer模型训练装置700的框图。装置700可以被包括在图1的120中。如图7所示,装置700包括:
数据单元701,用于获取训练数据集;
Transformer模型单元702,用于将所述训练数据集输入预设的Transformer模型,所述Transformer模型包括教师模型和学生模型;
损失函数单元703,用于将所述教师模型和所述学生模型的输出输入到带有超参数的解耦Kullback-Leibler散度公式,获取损失值;
训练单元704,基于所述损失值,利用梯度反向传播,更新所述学生模型的自身参数并不断调整所述超参数,同时基于所述学生模型的自身参数更新教师模型的自身参数,直到所述学生模型达到收敛,保存所述学生模型。
所属领域的技术人员可以清楚地了解到,为描述的方便和简洁,所述描述的模块的具体工作过程,可以参考前述方法实施例中的对应过程,在此不再赘述。
本公开的技术方案中,所涉及的用户个人信息的获取,存储和应用等,均符合相关法律法规的规定,且不违背公序良俗。
根据本公开的实施例,本公开还提供了一种电子设备、一种可读存储介质和一种计算机程序产品。
图8示出了可以用来实施本公开的实施例的电子设备800的示意性框图。电子设备旨在表示各种形式的数字计算机,诸如,膝上型计算机、台式计算机、工作台、个人数字助理、服务器、刀片式服务器、大型计算机、和其它适合的计算机。电子设备还可以表示各种形式的移动装置,诸如,个人数字处理、蜂窝电话、智能电话、可穿戴设备和其它类似的计算装置。本文所示的部件、它们的连接和关系、以及它们的功能仅仅作为示例,并且不意在限制本文中描述的和/或者要求的本公开的实现。
电子设备800包括计算单元801,其可以根据存储在ROM802中的计算机程序或者从存储单元808加载到RAM803中的计算机程序,来执行各种适当的动作和处理。在RAM803中,还可存储电子设备800操作所需的各种程序和数据。计算单元801、ROM802以及RAM803通过总线804彼此相连。I/O接口805也连接至总线804。
电子设备800中的多个部件连接至I/O接口805,包括:输入单元806,例如键盘、鼠标等;输出单元807,例如各种类型的显示器、扬声器等;存储单元808,例如磁盘、光盘等;以及通信单元809,例如网卡、调制解调器、无线通信收发机等。通信单元809允许电子设备800通过诸如因特网的计算机网络和/或各种电信网络与其他设备交换信息/数据。
计算单元801可以是各种具有处理和计算能力的通用和/或专用处理组件。计算单元801的一些示例包括但不限于中央处理单元(CPU)、图形处理单元(GPU)、各种专用的人工智能(AI)计算芯片、各种运行机器学习模型算法的计算单元、数字信号处理器(DSP)、以及任何适当的处理器、控制器、微控制器等。计算单元801执行上文所描述的各个方法和处理,例如基于自监督学习的Transformer模型训练方法。例如,在一些实施例中,基于自监督学习的Transformer模型训练方法可被实现为计算机软件程序,其被有形地包含于机器可读介质,例如存储单元808。在一些实施例中,计算机程序的部分或者全部可以经由ROM802和/或通信单元809而被载入和/或安装到电子设备800上。当计算机程序加载到RAM803并由计算单元801执行时,可以执行上文描述的基于自监督学习的Transformer模型训练方法的一个或多个步骤。备选地,在其他实施例中,计算单元801可以通过其他任何适当的方式(例如,借助于固件)而被配置为执行基于自监督学习的Transformer模型训练方法。
本文中以上描述的系统和技术的各种实施方式可以在数字电子电路系统、集成电路系统、现场可编程门阵列(FPGA)、专用集成电路(ASIC)、专用标准产品(ASSP)、芯片上系统(SOC)、负载可编程逻辑设备(CPLD)、计算机硬件、固件、软件、和/或它们的组合中实现。这些各种实施方式可以包括:实施在一个或者多个计算机程序中,该一个或者多个计算机程序可在包括至少一个可编程处理器的可编程系统上执行和/或解释,该可编程处理器可以是专用或者通用可编程处理器,可以从存储系统、至少一个输入装置、和至少一个输出装置接收数据和指令,并且将数据和指令传输至该存储系统、该至少一个输入装置、和该至少一个输出装置。
用于实施本公开的方法的程序代码可以采用一个或多个编程语言的任何组合来编写。这些程序代码可以提供给通用计算机、专用计算机或其他可编程数据处理装置的处理器或控制器,使得程序代码当由处理器或控制器执行时使流程图和/或框图中所规定的功能/操作被实施。程序代码可以完全在机器上执行、部分地在机器上执行,作为独立软件包部分地在机器上执行且部分地在远程机器上执行或完全在远程机器或服务器上执行。
在本公开的上下文中,机器可读介质可以是有形的介质,其可以包含或存储以供指令执行系统、装置或设备使用或与指令执行系统、装置或设备结合地使用的程序。机器可读介质可以是机器可读信号介质或机器可读储存介质。机器可读介质可以包括但不限于电子的、磁性的、光学的、电磁的、红外的、或半导体系统、装置或设备,或者上述内容的任何合适组合。机器可读存储介质的更具体示例会包括基于一个或多个线的电气连接、便携式计算机盘、硬盘、随机存取存储器(RAM)、只读存储器(ROM)、可擦除可编程只读存储器(EPROM或快闪存储器)、光纤、便捷式紧凑盘只读存储器(CD-ROM)、光学储存设备、磁储存设备、或上述内容的任何合适组合。
为了提供与用户的交互,可以在计算机上实施此处描述的系统和技术,该计算机具有:用于向用户显示信息的显示装置;以及键盘和指向装置(例如,鼠标或者轨迹球),用户可以通过该键盘和该指向装置来将输入提供给计算机。其它种类的装置还可以用于提供与用户的交互;例如,提供给用户的反馈可以是任何形式的传感反馈(例如,视觉反馈、听觉反馈、或者触觉反馈);并且可以用任何形式(包括声输入、语音输入或者、触觉输入)来接收来自用户的输入。
可以将此处描述的系统和技术实施在包括后台部件的计算系统(例如,作为数据服务器)、或者包括中间件部件的计算系统(例如,应用服务器)、或者包括前端部件的计算系统(例如,具有图形用户界面或者网络浏览器的用户计算机,用户可以通过该图形用户界面或者该网络浏览器来与此处描述的系统和技术的实施方式交互)、或者包括这种后台部件、中间件部件、或者前端部件的任何组合的计算系统中。可以通过任何形式或者介质的数字数据通信(例如,通信网络)来将系统的部件相互连接。通信网络的示例包括:局域网(LAN)、广域网(WAN)和互联网。
计算机系统可以包括客户端和服务器。客户端和服务器一般远离彼此并且通常通过通信网络进行交互。通过在相应的计算机上运行并且彼此具有客户端-服务器关系的计算机程序来产生客户端和服务器的关系。服务器可以是云服务器,也可以为分布式系统的服务器,或者是结合了区块链的服务器。
应该理解,可以使用上面所示的各种形式的流程,重新排序、增加或删除步骤。例如,本公开中记载的各步骤可以并行地执行也可以顺序地执行也可以不同的次序执行,只要能够实现本公开公开的技术方案所期望的结果,本文在此不进行限制。
上述具体实施方式,并不构成对本公开保护范围的限制。本领域技术人员应该明白的是,根据设计要求和其他因素,可以进行各种修改、组合、子组合和替代。任何在本公开的精神和原则之内所作的修改、等同替换和改进等,均应包含在本公开保护范围之内。
Claims (9)
1.一种基于自监督学习的Transformer模型训练方法,其特征在于,包括:
获取训练数据集;
将所述训练数据集输入预设的Transformer模型,所述Transformer模型包括教师模型和学生模型;
将所述教师模型和所述学生模型的输出输入到带有超参数的解耦Kullback-Leibler散度公式,获取损失值;
基于所述损失值,利用梯度反向传播,更新所述学生模型的自身参数并不断调整所述超参数,同时基于所述学生模型的自身参数更新教师模型的自身参数,直到所述学生模型达到收敛,保存所述学生模型。
2.根据权利要求1所述的方法,其特征在于,所述训练数据集包括:多组图像数据;每组图像数据包含对应同一原始图像的选定数量的全局视图和选定数量的局部视图;将每组图像数据输入学生模型;将每组图像数据中的全局视图输入教师模型。
3.根据权利要求1所述的方法,其特征在于,所述学生模型包括:骨干模型,全局平均池化层和两个Softmax层,所述两个softmax层具有不同输出维度,一个Softmax层对应的是目标类任务,另一个Softmax层对应的是非目标类任务。
4.根据权利要求3所述的方法,其特征在于,所述教师模型包括:骨干模型,center层,全局平均池化层和两个Softmax层,所述两个softmax层具有不同输出维度,一个Softmax层对应的是目标类任务,另一个Softmax层对应的是非目标类任务。
5.根据权利要求4所述的方法,其特征在于,所述解耦Kullback-Leibler散度公式为:
其中,DKL(pT||pS)是解耦Kullback-Leibler散度;pT和pS分别为教师目标logit分布和学生目标logit分布;和/>是pT和pS的二进制表示;/>是非目标类概率分布的Kullback-Leibler散度;/>和/>分别为教师非目标类别概率分布和学生非目标类别概率分布;γ和δ是超参数。
6.根据权利要求1所述的方法,其特征在于,所述更新教师模型的自身参数包括:利用所述学生模型的指数滑动平均参数来更新所述教师模型的自身参数;所述学生模型的指数滑动平均参数是学生模型自身参数的指数移动平均值。
7.一种基于自监督学习的Transformer模型训练装置,其特征在于,包括:
数据单元,用于获取训练数据集;
Transformer模型单元,用于将所述训练数据集输入预设的Transformer模型,所述Transformer模型包括教师模型和学生模型;损失函数单元,用于将所述教师模型和所述学生模型的输出输入到带有超参数的解耦Kullback-Leibler散度公式,获取损失值;训练单元,用于基于所述损失值,利用梯度反向传播,更新所述学生模型的自身参数并不断调整所述超参数,同时基于所述学生模型的自身参数更新教师模型的自身参数,直到所述学生模型达到收敛,保存所述Transformer模型的学生模型。
8.一种电子设备,包括:至少一个处理器;以及与所述至少一个处理器通信连接的存储器;其中,所述存储器存储有可被所述至少一个处理器执行的指令,所述指令被所述至少一个处理器执行,以使所述至少一个处理器能够执行权利要求1-6中任一权利要求所述的方法。
9.一种存储有计算机指令的非瞬时计算机可读存储介质,其中,所述计算机指令用于使所述计算机执行根据权利要求1至6中任一权利要求所述的方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310475772.XA CN116805162A (zh) | 2023-04-27 | 2023-04-27 | 基于自监督学习的Transformer模型训练方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310475772.XA CN116805162A (zh) | 2023-04-27 | 2023-04-27 | 基于自监督学习的Transformer模型训练方法 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN116805162A true CN116805162A (zh) | 2023-09-26 |
Family
ID=88080097
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202310475772.XA Pending CN116805162A (zh) | 2023-04-27 | 2023-04-27 | 基于自监督学习的Transformer模型训练方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN116805162A (zh) |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117787922A (zh) * | 2024-02-27 | 2024-03-29 | 东亚银行(中国)有限公司 | 基于蒸馏学习和自动学习的反洗钱业务处理方法、系统、设备和介质 |
-
2023
- 2023-04-27 CN CN202310475772.XA patent/CN116805162A/zh active Pending
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117787922A (zh) * | 2024-02-27 | 2024-03-29 | 东亚银行(中国)有限公司 | 基于蒸馏学习和自动学习的反洗钱业务处理方法、系统、设备和介质 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN109949255B (zh) | 图像重建方法及设备 | |
JP6504590B2 (ja) | 画像のセマンティックセグメンテーションのためのシステム及びコンピューター実施方法、並びに非一時的コンピューター可読媒体 | |
CN109345508B (zh) | 一种基于两阶段神经网络的骨龄评价方法 | |
CN107292352B (zh) | 基于卷积神经网络的图像分类方法和装置 | |
CN111507993A (zh) | 一种基于生成对抗网络的图像分割方法、装置及存储介质 | |
US20230162477A1 (en) | Method for training model based on knowledge distillation, and electronic device | |
CN113128478A (zh) | 模型训练方法、行人分析方法、装置、设备及存储介质 | |
CN111127360A (zh) | 一种基于自动编码器的灰度图像迁移学习方法 | |
WO2022111387A1 (zh) | 一种数据处理方法及相关装置 | |
CN111667483A (zh) | 多模态图像的分割模型的训练方法、图像处理方法和装置 | |
CN116805162A (zh) | 基于自监督学习的Transformer模型训练方法 | |
Peng et al. | An industrial-grade solution for agricultural image classification tasks | |
Li et al. | An end-to-end framework for joint denoising and classification of hyperspectral images | |
CN114495101A (zh) | 文本检测方法、文本检测网络的训练方法及装置 | |
WO2024060839A1 (zh) | 对象操作方法、装置、计算机设备以及计算机存储介质 | |
CN115909336A (zh) | 文本识别方法、装置、计算机设备和计算机可读存储介质 | |
CN111753995A (zh) | 一种基于梯度提升树的局部可解释方法 | |
CN114913339B (zh) | 特征图提取模型的训练方法和装置 | |
JP2022075620A (ja) | 畳み込みニューラルネットワークをトレーニングする方法およびシステム | |
CN114330576A (zh) | 模型处理方法、装置、图像识别方法及装置 | |
Jin et al. | Blind image quality assessment for multiple distortion image | |
Wirayasa et al. | Comparison of Convolutional Neural Networks Model Using Different Optimizers for Image Classification | |
CN114708471B (zh) | 跨模态图像生成方法、装置、电子设备与存储介质 | |
JP2020030702A (ja) | 学習装置、学習方法及び学習プログラム | |
CN115578613B (zh) | 目标再识别模型的训练方法和目标再识别方法 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |