CN116524307A - 一种基于扩散模型的自监督预训练方法 - Google Patents
一种基于扩散模型的自监督预训练方法 Download PDFInfo
- Publication number
- CN116524307A CN116524307A CN202310350662.0A CN202310350662A CN116524307A CN 116524307 A CN116524307 A CN 116524307A CN 202310350662 A CN202310350662 A CN 202310350662A CN 116524307 A CN116524307 A CN 116524307A
- Authority
- CN
- China
- Prior art keywords
- network
- diffusion model
- image
- representing
- feature
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000000034 method Methods 0.000 title claims abstract description 81
- 238000009792 diffusion process Methods 0.000 title claims abstract description 66
- 238000012549 training Methods 0.000 title claims abstract description 62
- 238000005070 sampling Methods 0.000 claims abstract description 41
- 238000001514 detection method Methods 0.000 claims abstract description 12
- 230000011218 segmentation Effects 0.000 claims abstract description 12
- 239000013598 vector Substances 0.000 claims description 57
- 239000004973 liquid crystal related substance Substances 0.000 claims description 12
- 238000013528 artificial neural network Methods 0.000 claims description 10
- 230000000873 masking effect Effects 0.000 claims description 10
- 239000011159 matrix material Substances 0.000 claims description 10
- 238000010586 diagram Methods 0.000 claims description 9
- 238000000605 extraction Methods 0.000 claims description 4
- 238000012544 monitoring process Methods 0.000 claims 3
- 230000006870 function Effects 0.000 description 15
- 230000008569 process Effects 0.000 description 6
- 230000000007 visual effect Effects 0.000 description 5
- 238000004590 computer program Methods 0.000 description 4
- 230000002708 enhancing effect Effects 0.000 description 3
- 238000012545 processing Methods 0.000 description 3
- 238000010191 image analysis Methods 0.000 description 2
- 238000003709 image segmentation Methods 0.000 description 2
- 238000012986 modification Methods 0.000 description 2
- 230000004048 modification Effects 0.000 description 2
- 230000009286 beneficial effect Effects 0.000 description 1
- 238000004364 calculation method Methods 0.000 description 1
- 238000013527 convolutional neural network Methods 0.000 description 1
- 230000007547 defect Effects 0.000 description 1
- 238000011161 development Methods 0.000 description 1
- 230000018109 developmental process Effects 0.000 description 1
- 230000000694 effects Effects 0.000 description 1
- 238000000265 homogenisation Methods 0.000 description 1
- 230000005012 migration Effects 0.000 description 1
- 238000013508 migration Methods 0.000 description 1
- 230000003287 optical effect Effects 0.000 description 1
- 238000012360 testing method Methods 0.000 description 1
- 230000009466 transformation Effects 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/77—Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
- G06V10/778—Active pattern-learning, e.g. online learning of image or video features
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
- G06N3/0455—Auto-encoder networks; Encoder-decoder networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/0895—Weakly supervised learning, e.g. semi-supervised or self-supervised learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/20—Image preprocessing
- G06V10/26—Segmentation of patterns in the image field; Cutting or merging of image elements to establish the pattern region, e.g. clustering-based techniques; Detection of occlusion
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/40—Extraction of image or video features
- G06V10/42—Global feature extraction by analysis of the whole pattern, e.g. using frequency domain transformations or autocorrelation
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/77—Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
- G06V10/774—Generating sets of training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/82—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V20/00—Scenes; Scene-specific elements
- G06V20/70—Labelling scene content, e.g. deriving syntactic or semantic representations
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02T—CLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
- Y02T10/00—Road transport of goods or passengers
- Y02T10/10—Internal combustion engine [ICE] based vehicles
- Y02T10/40—Engine management systems
Abstract
本发明提供了一种基于扩散模型的自监督预训练方法,包括以下步骤:步骤1,在预训练的数据集上基于噪声预测的方式训练扩散模型,并将其作为教师网络;步骤2,将步骤1中训练完的扩散模型中U‑Net网络中的上采样部分的特征图提取出来,并进行拼接;步骤3,将学生网络输出的特征图与步骤2中提取的特征图进行对齐,从而对学生网络进行训练;步骤4,通过步骤3训练完的学生网络得到图像的全局特征图。本发明方法对图像中不同区域的语义相关性进行了显式地建模,并且添加了对图像的全局特征的正则化约束,从而大幅提升了通过本发明方法预训练得到的模型在图像分类、目标检测和语义分割等下游任务的性能。
Description
技术领域
本发明涉及一种计算机视觉领域自监督预训练方法,特别是一种基于扩散模型的自监督预训练方法。
背景技术
计算机视觉(Computer Vision)是指让计算机能够理解和处理图像或视频数据的技术。计算机视觉在各个领域都有广泛的应用,例如人脸识别、自动驾驶、医学影像分析等。为了提高计算机视觉任务的性能,需要从大量的数据中学习有效的视觉特征,即能够表征图像内容和语义信息的向量。
近年来,随着深度神经网络(Deep Neural Network)在计算机视觉领域的发展,通过监督学习(Supervised Learning)提取图像中的视觉特征逐渐成为了主流。然而,在监督学习的框架下通常需要大量的人工标注的数据来训练模型。在实际应用中,并不总是有足够多且高质量的标注数据可用,因为收集和标注数据既耗时又昂贵,而且可能存在噪声、偏差和不一致等问题。此外,在某些领域或任务中,获取标注数据本身就非常困难或不可行,例如医学影像分析、无人驾驶等。为了解决这个问题,自监督学习(Self-supervisedLearning)利用海量的未标注数据来学习图像的视觉特征。自监督学习方法往往通过设计一些前置任务(Pretext Task),从数据本身生成伪标签,并以此训练神经网络,从而学习到通用的视觉特征。而通过自监督学习训练神经网络并从海量未标注数据中提取视觉特征的过程被称为预训练(Pre-training)。
当前针对计算机视觉领域的自监督预训练方法依然存在以下问题:1)现有的预训练方法缺乏对图像中各局部区域之间的语义相关性的建模;2)现有的预训练方法要么过于关注图像中的像素级别的特征,要么过于关注图像的全局特征而忽略图像中的语义结构,这些方法都无法在学习图像的全局特征和局部细节的能力上达到平衡。上述问题会导致通过这些方法预训练得到的模型在线性分类、目标检测等下游任务上的性能不佳。
发明内容
发明目的:本发明所要解决的技术问题是针对现有技术的不足,提供一种基于扩散模型的自监督预训练方法。
为了解决上述技术问题,本发明公开了一种基于扩散模型的自监督预训练方法,包括以下步骤:
步骤1,在图片数据集上采用噪声预测的方法训练扩散模型,并将训练好的扩散模型作为教师网络;
所述采用噪声预测的方法训练扩散模型的方法包括:
步骤1-1,选定图片数据集并将其中的所有图片的像素值缩放到区间[0,1];
步骤1-2,对于数据集的图片x0,向图片x0中注入噪声并得到t时刻的图片xt,具体方法如下:
其中,表示各个时刻缩放因子的连乘,/>αs表示s时刻的缩放因子,αt表示t时刻的缩放因子,αt:=1-βt,βt表示t时刻的噪声方差,∈表示注入的随机采样的高斯噪声,/>表示标准正态分布。
步骤1-3,采用第一损失函数L(θ)对扩散模型进行训练,所述第一损失函数如下:
其中,表示求方括号中的函数的期望,∈θ表示噪声预测网络。
步骤2,使用步骤1中得到的训练好的扩散模型提取特征向量,即将输入图片经过扩散模型中的前向扩散过程得到对应时间步的图片,然后将得到的图片输入扩散模型中的噪声预测网络,并将噪声预测网络中上采样区块中的特征图提取出来,最后将特征图进行拼接,并进行采样得到输入图片的特征向量;
所述提取特征向量的方法包括:
步骤2-1,将步骤1-2中得到的t时刻的图片xt输入噪声预测网络∈θ,提取出噪声预测网络中上采样部分的中间层特征图,并将这些特征图拼接起来;
步骤2-2,从步骤2-1得到的特征图中均匀采样K个点,采用双线性插值得到采样点对应的特征向量表示第i个图像的特征向量即教师网络生成的特征向量,表示K×D大小的实数矩阵,D表示特征向量的维度。
步骤3,将待训练的神经网络作为学生网络,并对其进行训练,即对待训练的学生网络输出的特征图进行采样,然后与步骤2中采样得到的特征向量进行对齐,最后采用反向传播对学生网络进行训练,得到训练完成的学生网络;
所述对学生网络进行训练,具体方法包括:
步骤3-1,对输入图像x0进行数据增强,并使用二进制掩码遮去图像中50%的部分;
所述的数据增强的方法包括:
采用随机裁剪、水平翻转、颜色抖动和高斯模糊对图像进行增强。
所述的使用二进制掩码遮去图像中50%的部分,具体方法包括:
步骤3-1-1,计算图像中各个区块的重要性程度s,具体方法如下:
其中,q0表示所述图片全局特征的查询向量,K1:表示该图片除了全局特征向量之外的索引矩阵,D表示特征向量的维度,表示数学期望,Softmax(·)表示Softmax函数。
步骤3-1-2,根据各个区块的重要性程度s进行重要性采样,即采样出重要性程度s大于阈值的区块,使用掩码遮去。
步骤3-2,将图像剩余的50%的部分输入学生网络中的编码器,得到剩余部分图像的特征向量;
步骤3-3,将步骤3-2中得到的特征向量和掩码token拼接起来,并输入到学生网络中的解码器中,从而得到剩余图像的特征图和全局特征;
步骤3-4,将步骤3-3得到的特征图通过与步骤2-2相同的采样方法得到剩余图像的特征向量
步骤3-5,计算学生网络生成的特征向量和教师网络生成的特征向量/>之间的距离相关系数/>然后采用第二损失函数/>进行特征图对齐,方法如下:
其中,N表示数据集中图片的数量,α表示缩放因子;
步骤3-6,对输入图像的全局特征进行归一化,然后引入均匀正则化项对图片的全局特征进行约束,具体方法如下:
其中,表示第三损失函数,/>表示归一化后的图片的全局特征向量,I表示单位矩阵,/>表示矩阵的弗罗贝尼乌斯范数;
步骤3-7,结合步骤3-5和步骤3-6中的第二损失函数和第三损失函数,得到总损失函数来训练学生网络。
所述的总损失函数如下:
其中,λ表示平衡因子。
步骤4,将步骤3中得到的训练完成的学生网络作为特征提取器,进行特征提取,将图像输入训练完成的学生网络得到图像的全局特征图,完成基于扩散模型的自监督预训练;
所述进行特征提取的方法包括:
将通过步骤3中训练得到的学生网络中的编码器作为特征提取的网络,输入图片后将编码器最后一层的特征图保留作为提取的特征。
步骤5,将完成基于扩散模型的自监督预训练的学生网络应用于目标检测或语义分割的任务中,对图像中的物体进行识别和定位。
有益效果:
本发明提出的一种基于扩散模型的自监督预训练方法,将待训练的学生网络的输出与扩散模型中提取的特征图进行对齐,并且对图像的全局特征进行均匀化约束。与其它自监督预训练方法相比,本发明对图像中各区域的语义相关性进行了显式地建模,并且加强了模型对图像全局特征的提取能力,从而大幅提升了预训练得到的模型在各种下游任务上的性能。
附图说明
下面结合附图和具体实施方式对本发明做更进一步的具体说明,本发明的上述和/或其他方面的优点将会变得更加清楚。
图1为本发明提供的一种基于扩散模型的自监督预训练方法的流程图。
图2为本发明提供的基于扩散模型的特征图提取方法的流程图。
图3为本发明提供的学生网络训练策略的流程图。
图4为本发明在目标检测和语义分割任务上应用的示意图。
具体实施方式
本发明提供一种基于扩散模型的自监督预训练方法,本方法充分利用了扩散模型中具有丰富语义信息的特征图,对图像中各区域的语义相关性进行了建模,并且通过正则化项加强了模型提取全局特征的能力,提高了模型在各种下游任务上的迁移性能。通过本发明提出的方法进行预训练的模型在目标检测和语义分割等需要图像中细节特征的任务上达到了大幅超越之前方法的性能,并且在线性分类等需要图像全局特征的任务上也取得了较好的效果。
本发明公开了一种基于扩散模型的自监督预训练方法,包括以下步骤:
步骤1,在图片数据集上采用噪声预测的方法训练扩散模型,并将其作为教师网络;
扩散模型是一种生成模型,扩散模型在前向过程中对图像逐步施加噪声,直至图像被破坏变成完全的高斯噪声,然后在逆向过程中学习从高斯噪声还原为原始图像的过程。在训练阶段,扩散模型通过训练噪声预测网络来预测图像中的噪声,在生成阶段,扩散模型将带噪声的图像输入噪声预测网络,并通过其预测的噪声来去除图像中的噪声,最终逐步地得到没有噪声的图像。噪声预测网络采用的是U-Net架构,U-Net是一类被广泛应用于图像生成、图像分割等领域的神经网络架构。U-Net中包含两个部分:下采样部分和上采样部分。下采样部分逐步地将输入图像降低分辨率并同时增加通道数,而上采样部分则相反,通过上采样将低分辨率的特征图逐步还原成原始图像的维度。上采样部分由多层的神经网络组成,中间层则具体指上采样部分的第6、7、8层。
步骤2,对步骤1中得到的训练完成的扩散模型进行特征图提取,即将输入图片经过前向扩散过程得到对应时间步的图片,然后将对应的U-Net中上采样区块中的特征图提取出来并进行拼接;
步骤3,对步骤2中提取出的特征图进行采样,然后将待训练的学生网络输出的特征图与采样后的特征图进行对齐,最后采用反向传播对学生网络进行训练;
学生网络,通常采用ViT架构(参考:[2010.11929]An Image is Worth 16x16Words:Transformers for Image Recognition at Scale(arxiv.org)),学生网络和教师网络是解耦的,教师网络使用扩散模型,学生网络可以采用多种架构,一般会采用ViT架构。
步骤4,将步骤3中得到的训练完成的学生网络作为特征提取器,将图像输入学生网络得到图像的全局特征图。
本发明步骤1中所述基于噪声预测的扩散模型训练方法包括:
步骤1-1,选定图片数据集并将其中的所有图片的像素值缩放到区间[0,1];
步骤1-2,对于数据集的图片x0,通过如下公式向图片中注入噪声并得到t时刻的图片xt:
其中βt表示t时刻的噪声方差;
步骤1-3,采用如下的损失函数对扩散模型(参考:[2006.11239]DenoisingDiffusion Probabilistic Models(arxiv.org))进行训练:
其中∈θ表示采用U-Net作为基础架构的噪声预测网络(参考:[1505.04597]U-Net:Convolutional Networks for Biomedical Image Segmentation(arxiv.org)),∈表示采样的高斯噪声。训练完成后得到的模型称之为教师网络。
本发明步骤2中所述特征图拼接方法包括:
步骤2-1,给定训练集中的图片x0,向其中注入随机采样的高斯噪声,并按照步骤1-2中的公式得到t时刻的图片xt;
步骤2-2,将图片xt输入教师网络∈θ,然后提取出U-Net上采样部分的中间层特征图,并将这些特征图拼接起来;
步骤2-3,从步骤2-2得到的特征图中均匀采样K个点,然后采样双线性插值得到这些采样点对应的特征向量
本发明步骤3中所述学生网络的训练策略包括:
步骤3-1,对输入图像x0进行数据增强,并使用二进制掩码遮去图像中50%的部分;
步骤3-2,将图像剩余的可见部分输入学生网络中的编码器,得到这部分图像的特征向量;
步骤3-3,将步骤3-2中得到的特征向量和可学习的掩码token(用来表示图像中掩码区域的token,可以看作是一组可学习的参数)拼接起来,并输入到学生网络中的解码器中,从而得到图像的特征图和全局特征;
步骤3-4,将步骤3-3得到的特征图通过与步骤2-3相同的采样策略得到图像的特征向量
步骤3-5,计算学生网络生成的特征向量和教师网络生成的特征向量/>之间的距离相关系数/>然后采用如下损失函数进行特征图对齐:
其中N表示图片的数量,α表示缩放因子。
步骤3-6,对图片的全局特征进行归一化,然后引入如下的均匀正则化项对图片的全局特征进行约束:
其中表示归一化后的图片的全局特征向量。
步骤3-7,结合步骤3-5和步骤3-6中的损失函数,得到如下的损失函数来训练学生网络:
其中λ表示用来权衡第一项和第二项重要性的平衡因子。
本发明步骤3-1中所述数据增强和掩码的方法包括:
步骤3-1-1,采用随机裁剪、水平翻转、颜色抖动和高斯模糊等策略对图片进行增强。
步骤3-1-2,采用如下公式计算图片中各个patch的重要性程度:
其中,q0表示图片全局token的query向量,K1:表示除了全局特征向量之外的索引矩阵。
步骤3-1-3,根据各个patch的重要性程度进行重要性采样,将采样出的patch使用掩码遮去。
本发明步骤4中所述特征提取的方法包括:
将通过步骤3-5训练得到的学生网络中的编码器作为特征提取的网络,输入图片后将编码器最后一层的特征图保留作为提取的特征。
实施例:
本发明公开了一种基于扩散模型的自监督预训练方法,如图1所示,所述方法包括:
步骤1,在图片数据集上采用噪声预测的方法训练扩散模型,并将其作为教师网络;
步骤2,对步骤1中得到的训练完成的扩散模型进行特征图提取,即将输入图片经过前向扩散过程得到对应时间步的图片,然后将对应的U-Net中上采样区块中的特征图提取出来并进行拼接;
步骤3,对步骤2中提取出的特征图进行采样,然后将待训练的学生网络输出的特征图与采样后的特征图进行对齐,最后采用反向传播对学生网络进行训练;
步骤4,将步骤3中得到的训练完成的学生网络作为特征提取器,将图像输入学生网络得到图像的全局特征图。
步骤1包含如下步骤:
步骤1-1,选定图片数据集例如ImageNet,并将其中的所有图片的像素值缩放到区间[0,1];
步骤1-2,对于数据集的图片x0,通过如下公式向图片中注入噪声并得到t时刻的图片xt:
其中βt表示t时刻的噪声方差;
步骤1-3,采用如下的损失函数对扩散模型进行训练:
其中∈θ表示采用U-Net作为基础架构的噪声预测网络,∈表示采样的高斯噪声。训练完成后得到的模型称之为教师网络。
训练完成后的扩散模型的输入为t时刻的带有噪声的图片xt和时间t,输出为图片xt中含有的噪声,然后便可以预测出t-1时刻的图片xt-1。
如图2所示,步骤2包含如下步骤:
步骤2-1,给定训练集中的图片x0,向其中注入随机采样的高斯噪声,并按照步骤1-2中的公式得到t时刻的图片xt;
步骤2-2,将图片xt输入教师网络∈θ,然后提取出U-Net上采样部分的中间层特征图,并将这些特征图拼接起来;
步骤2-3,从步骤2-2得到的特征图中均匀采样K个点,然后采样双线性插值得到这些采样点对应的特征向量
由于扩散模型中提取出的特征图包含丰富的语义信息,因此这些采样点对应的特征向量包含了这些点周围区域的语义信息。
如图3所示,步骤3中包含如下步骤:
步骤3-1,对输入图像x0进行数据增强,并使用二进制掩码遮去图像中50%的部分;
步骤3-2,将图像剩余的可见部分输入学生网络中的编码器,得到这部分图像的特征向量。编码器的网络结构与传统的ViT保持一致,即将输入的图片划分为若干个大小相等的patch,并采用线性变换将原始的像素块转换为向量,然后输入L层标准的TransformerBlock,每个Block都由自注意力模块和多层感知机组成。;
步骤3-3,将步骤3-2中得到的特征向量和掩码token拼接起来,并输入到学生网络中的解码器中,从而得到图像的特征图和全局特征。与编码器不同的是,解码器由一系列的互注意力模块组成。该模块的输入hm由M个掩码token加上相应的位置编码组成,每个token都是相同的,并且是可学习的参数。互注意力模块中的K和V则是由图像中可见部分的特征zv组成。互注意力模块的计算公式如下:
其中,Qm=hmWQ,K=zvWK,V=zvWV,WQ、WK、WV为可学习的参数。
步骤3-4,将步骤3-3得到的特征图通过与步骤2-3相同的采样策略得到图像的特征向量
步骤3-5,计算学生网络生成的特征向量和教师网络生成的特征向量/>之间的距离相关系数/>然后采用如下损失函数进行特征图对齐:
其中N表示图片的数量,α表示缩放因子。
步骤3-6,对图片的全局特征进行归一化,然后引入如下的均匀正则化项对图片的全局特征进行约束:
其中表示归一化后的图片的全局特征向量。训练时,对于一个批次内的N张图片,这些图片的全局特征/>由解码器输出的[cls]token组成。然后进一步对这些特征向量进行批次归一化,使得每个特征维度都具有0均值和/>的标准差:
由此可以得到归一化后的图片的全局特征向量
步骤3-7,结合步骤3-5和步骤3-6中的损失函数,得到如下的损失函数来训练学生网络:
其中λ表示用来权衡第一项和第二项重要性的平衡因子。
步骤3-1包含如下步骤:
步骤3-1-1,采用随机裁剪、水平翻转、颜色抖动和高斯模糊等策略对图片进行增强。
步骤3-1-2,采用如下公式计算图片中各个patch的重要性程度:
其中,q0表示图片全局特征向量,K1:表示除了全局特征向量之外的索引矩阵。
步骤3-1-3,根据各个patch的重要性程度进行重要性采样,将采样出的patch使用掩码遮去。在实际采样过程中,首先根据各个patch的重要性得到分布π(x),然后通过随机数生成器得到随机分布p(x),最后根据的值对各个patch由大到小进行排序,将前50%的patch遮去,并将剩余的patch作为学生网络中编码器的输入。
步骤4包括:将通过步骤3-5训练得到的学生网络中的编码器作为特征提取的网络,输入图片后将编码器最后一层的特征图保留作为提取的特征。
为了验证本发明的有效性,本发明采用ImageNet作为预训练数据集,与其它优秀模型在目标检测和语义分割任务上的性能进行了对比。对于目标检测任务,本发明采用了常用的Mask R-CNN结构,将预训练得到的模型作为骨干网络,然后在MS-COCO数据上进行微调,最终的实验结果如表1所示:
表1本发明提供的目标检测任务上的性能对比表
由表1可以看出:本发明提出的预训练方法在目标检测任务上达到了52.8%的平均精确率(Average Precision,AP),在实例分割任务上达到了46.7%的AP,比之前最好的方法CAE分别提升了3.0%和2.8%。
对于语义分割任务,本发明采用了UperNet结构,然后在ADE20K数据集上进行微调,最终的实验结果如表2所示:
表2本发明提供的语义分割任务上的性能对比表
由表2可以看出:本发明提出的预训练方法在语义分割任务上达到了52.4%的平均交并比(mean Intersection Over Union,mIOU),62.8%的全局像素准确率(all pixelAccuracy,aAcc)和86.1%的平均类别准确率(mean class Accuracy,mAcc),比之前最好的方法CAE分别提升了2.1%、2.5%和1.3%。
本发明在目标检测(图4上)和语义分割任务(图4下)上应用如图4所示。
综上所述,由实验结果可知,本发明在目标检测和语义分割任务上的多个测试指标均大幅超越了之前的模型。
具体实现中,本申请提供计算机存储介质以及对应的数据处理单元,其中,该计算机存储介质能够存储计算机程序,所述计算机程序通过数据处理单元执行时可运行本发明提供的一种基于扩散模型的自监督预训练方法的发明内容以及各实施例中的部分或全部步骤。所述的存储介质可为磁碟、光盘、只读存储记忆体(read-only memory,ROM)或随机存储记忆体(random access memory,RAM)等。
本领域的技术人员可以清楚地了解到本发明实施例中的技术方案可借助计算机程序以及其对应的通用硬件平台的方式来实现。基于这样的理解,本发明实施例中的技术方案本质上或者说对现有技术做出贡献的部分可以以计算机程序即软件产品的形式体现出来,该计算机程序软件产品可以存储在存储介质中,包括若干指令用以使得一台包含数据处理单元的设备(可以是个人计算机,服务器,单片机,MUU或者网络设备等)执行本发明各个实施例或者实施例的某些部分所述的方法。
本发明提供了一种基于扩散模型的自监督预训练方法的思路及方法,具体实现该技术方案的方法和途径很多,以上所述仅是本发明的优选实施方式,应当指出,对于本技术领域的普通技术人员来说,在不脱离本发明原理的前提下,还可以做出若干改进和润饰,这些改进和润饰也应视为本发明的保护范围。本实施例中未明确的各组成部分均可用现有技术加以实现。
Claims (10)
1.一种基于扩散模型的自监督预训练方法,其特征在于,包括以下步骤:
步骤1,在图片数据集上采用噪声预测的方法训练扩散模型,并将训练好的扩散模型作为教师网络;
步骤2,使用步骤1中得到的训练好的扩散模型提取特征向量,即将输入图片经过扩散模型中的前向扩散过程得到对应时间步的图片,然后将得到的图片输入扩散模型中的噪声预测网络,并将噪声预测网络中上采样区块中的特征图提取出来,最后将特征图进行拼接,并进行采样得到输入图片的特征向量;
步骤3,将待训练的神经网络作为学生网络,并对其进行训练,即对待训练的学生网络输出的特征图进行采样,然后与步骤2中采样得到的特征向量进行对齐,最后采用反向传播对学生网络进行训练,得到训练完成的学生网络;
步骤4,将步骤3中得到的训练完成的学生网络作为特征提取器,进行特征提取,将图像输入训练完成的学生网络得到图像的全局特征图,完成基于扩散模型的自监督预训练;
步骤5,将完成基于扩散模型的自监督预训练的学生网络应用于目标检测或语义分割的任务中,对图像中的物体进行识别和定位。
2.根据权利要求1所述的一种基于扩散模型的自监督预训练方法,其特征在于,步骤1中所述采用噪声预测的方法训练扩散模型的方法包括:
步骤1-1,选定图片数据集并将其中的所有图片的像素值缩放到区间[0,1];
步骤1-2,对于数据集的图片x0,向图片x0中注入噪声并得到t时刻的图片xt;
步骤1-3,采用第一损失函数L(θ)对扩散模型进行训练,所述第一损失函数如下:
其中,表示求方括号中的函数的期望,∈θ表示噪声预测网络。
3.根据权利要求2所述的一种基于扩散模型的自监督预训练方法,其特征在于,,步骤1-2中所述的向图片x0中注入噪声并得到t时刻的图片xt,具体方法如下:
其中,表示各个时刻缩放因子的连乘,/>αs表示s时刻的缩放因子,αt表示t时刻的缩放因子,αt:=1-βt,βt表示t时刻的噪声方差,∈表示注入的随机采样的高斯噪声,/>表示标准正态分布。
4.根据权利要求3所述的一种基于扩散模型的自监督预训练方法,其特征在于,步骤2中所述提取特征向量的方法包括:
步骤2-1,将步骤1-2中得到的t时刻的图片xt输入噪声预测网络∈θ,提取出噪声预测网络中上采样部分的中间层特征图,并将这些特征图拼接起来;
步骤2-2,从步骤2-1得到的特征图中均匀采样K个点,采用双线性插值得到采样点对应的特征向量 表示第i个图像的特征向量即教师网络生成的特征向量,/>表示K×D大小的实数矩阵,D表示特征向量的维度。
5.根据权利要求4所述的一种基于扩散模型的自监督预训练方法,其特征在于,步骤3中所述对学生网络进行训练,具体方法包括:
步骤3-1,对输入图像x0进行数据增强,并使用二进制掩码遮去图像中50%的部分;
步骤3-2,将图像剩余的50%的部分输入学生网络中的编码器,得到剩余部分图像的特征向量;
步骤3-3,将步骤3-2中得到的特征向量和可学习的掩码token拼接起来,并输入到学生网络中的解码器中,从而得到剩余图像的特征图和全局特征;
步骤3-4,将步骤3-3得到的特征图通过与步骤2-2相同的采样方法得到剩余图像的特征向量
步骤3-5,计算学生网络生成的特征向量和教师网络生成的特征向量/>之间的距离相关系数/>然后采用第二损失函数/>进行特征图对齐,方法如下:
其中,N表示数据集中图片的数量,α表示缩放因子;
步骤3-6,对输入图像的全局特征进行归一化,然后引入均匀正则化项对图片的全局特征进行约束,具体方法如下:
其中,表示第三损失函数,/>表示归一化后的图片的全局特征向量,I表示单位矩阵,/>表示矩阵的弗罗贝尼乌斯范数;
步骤3-7,结合步骤3-5和步骤3-6中的第二损失函数和第三损失函数,得到总损失函数来训练学生网络。
6.根据权利要求5所述的一种基于扩散模型的自监督预训练方法,其特征在于,步骤3-7中所述的总损失函数如下:
其中,λ表示平衡因子。
7.根据权利要求6所述的一种基于扩散模型的自监督预训练方法,其特征在于,步骤4中所述进行特征提取的方法包括:
将通过步骤3中训练得到的学生网络中的编码器作为特征提取的网络,输入图片后将编码器最后一层的特征图保留作为提取的特征。
8.根据权利要求7所述的一种基于扩散模型的自监督预训练方法,其特征在于,步骤3-1中所述的数据增强的方法包括:
采用随机裁剪、水平翻转、颜色抖动和高斯模糊对图像进行增强。
9.根据权利要求8所述的一种基于扩散模型的自监督预训练方法,其特征在于,步骤3-1中所述的使用二进制掩码遮去图像中50%的部分,具体方法包括:
步骤3-1-1,计算图像中各个区块的重要性程度s;
步骤3-1-2,根据各个区块的重要性程度s进行重要性采样,即采样出重要性程度s大于阈值的区块,使用掩码遮去。
10.根据权利要求9所述的一种基于扩散模型的自监督预训练方法,其特征在于,步骤3-1-1中所述的计算图像中各个区块的重要性程度s,具体方法如下:
其中,q0表示所述图片全局特征的查询向量,K1:表示该图片除了全局特征向量之外的索引矩阵,D表示特征向量的维度,表示数学期望,Softmax(·)表示Softmax函数。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310350662.0A CN116524307A (zh) | 2023-04-04 | 2023-04-04 | 一种基于扩散模型的自监督预训练方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310350662.0A CN116524307A (zh) | 2023-04-04 | 2023-04-04 | 一种基于扩散模型的自监督预训练方法 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN116524307A true CN116524307A (zh) | 2023-08-01 |
Family
ID=87393152
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202310350662.0A Pending CN116524307A (zh) | 2023-04-04 | 2023-04-04 | 一种基于扩散模型的自监督预训练方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN116524307A (zh) |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116825130A (zh) * | 2023-08-24 | 2023-09-29 | 硕橙(厦门)科技有限公司 | 一种深度学习模型蒸馏方法、装置、设备及介质 |
-
2023
- 2023-04-04 CN CN202310350662.0A patent/CN116524307A/zh active Pending
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116825130A (zh) * | 2023-08-24 | 2023-09-29 | 硕橙(厦门)科技有限公司 | 一种深度学习模型蒸馏方法、装置、设备及介质 |
CN116825130B (zh) * | 2023-08-24 | 2023-11-21 | 硕橙(厦门)科技有限公司 | 一种深度学习模型蒸馏方法、装置、设备及介质 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
Lin et al. | FPGAN: Face de-identification method with generative adversarial networks for social robots | |
Liu et al. | Connecting image denoising and high-level vision tasks via deep learning | |
CN113240580B (zh) | 一种基于多维度知识蒸馏的轻量级图像超分辨率重建方法 | |
CN109299274B (zh) | 一种基于全卷积神经网络的自然场景文本检测方法 | |
CN109886121B (zh) | 一种遮挡鲁棒的人脸关键点定位方法 | |
CN111681252A (zh) | 一种基于多路径注意力融合的医学图像自动分割方法 | |
CN109636721B (zh) | 基于对抗学习和注意力机制的视频超分辨率方法 | |
CN116309648A (zh) | 一种基于多注意力融合的医学图像分割模型构建方法 | |
CN111445496B (zh) | 一种水下图像识别跟踪系统及方法 | |
CN113706544A (zh) | 一种基于完备注意力卷积神经网络的医学图像分割方法 | |
CN116524307A (zh) | 一种基于扩散模型的自监督预训练方法 | |
Zheng et al. | T-net: Deep stacked scale-iteration network for image dehazing | |
CN116310394A (zh) | 显著性目标检测方法及装置 | |
Uddin et al. | A perceptually inspired new blind image denoising method using $ L_ {1} $ and perceptual loss | |
Chen et al. | MICU: Image super-resolution via multi-level information compensation and U-net | |
Yang et al. | Unsupervised learning polarimetric underwater image recovery under nonuniform optical fields | |
Zhang et al. | Dense haze removal based on dynamic collaborative inference learning for remote sensing images | |
Kim et al. | Infrared and visible image fusion using a guiding network to leverage perceptual similarity | |
CN114202473A (zh) | 一种基于多尺度特征和注意力机制的图像复原方法及装置 | |
CN116258652B (zh) | 基于结构注意和文本感知的文本图像修复模型及方法 | |
CN116778164A (zh) | 一种基于多尺度结构改进DeeplabV3+网络的语义分割方法 | |
CN116452472A (zh) | 基于语义知识引导的低照度图像增强方法 | |
CN114627293A (zh) | 基于多任务学习的人像抠图方法 | |
CN112396598A (zh) | 一种基于单阶段多任务协同学习的人像抠图方法及系统 | |
Di et al. | FDNet: An end-to-end fusion decomposition network for infrared and visible images |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |