CN116883545A - 基于扩散模型的图片数据集扩充方法、介质及设备 - Google Patents
基于扩散模型的图片数据集扩充方法、介质及设备 Download PDFInfo
- Publication number
- CN116883545A CN116883545A CN202310827912.5A CN202310827912A CN116883545A CN 116883545 A CN116883545 A CN 116883545A CN 202310827912 A CN202310827912 A CN 202310827912A CN 116883545 A CN116883545 A CN 116883545A
- Authority
- CN
- China
- Prior art keywords
- image
- diffusion model
- style
- word
- training
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000009792 diffusion process Methods 0.000 title claims abstract description 155
- 238000000034 method Methods 0.000 title claims abstract description 91
- 238000012549 training Methods 0.000 claims abstract description 71
- 230000008569 process Effects 0.000 claims description 48
- 239000011159 matrix material Substances 0.000 claims description 12
- 238000004590 computer program Methods 0.000 claims description 11
- 230000004927 fusion Effects 0.000 claims description 9
- 238000013507 mapping Methods 0.000 claims description 9
- 238000004364 calculation method Methods 0.000 claims description 8
- 238000005457 optimization Methods 0.000 claims description 5
- 230000003190 augmentative effect Effects 0.000 claims description 3
- 238000005070 sampling Methods 0.000 claims description 3
- 238000005215 recombination Methods 0.000 abstract 1
- 230000006798 recombination Effects 0.000 abstract 1
- 230000000875 corresponding effect Effects 0.000 description 34
- 238000009826 distribution Methods 0.000 description 11
- 238000012360 testing method Methods 0.000 description 5
- 230000001276 controlling effect Effects 0.000 description 4
- 230000000694 effects Effects 0.000 description 4
- 238000012545 processing Methods 0.000 description 4
- 238000013528 artificial neural network Methods 0.000 description 2
- 238000013135 deep learning Methods 0.000 description 2
- 238000011156 evaluation Methods 0.000 description 2
- 230000006870 function Effects 0.000 description 2
- 238000013526 transfer learning Methods 0.000 description 2
- 230000009466 transformation Effects 0.000 description 2
- 238000000844 transformation Methods 0.000 description 2
- 238000003491 array Methods 0.000 description 1
- 230000006399 behavior Effects 0.000 description 1
- 230000009286 beneficial effect Effects 0.000 description 1
- 230000001364 causal effect Effects 0.000 description 1
- 238000013145 classification model Methods 0.000 description 1
- 238000010276 construction Methods 0.000 description 1
- 230000002596 correlated effect Effects 0.000 description 1
- 230000007423 decrease Effects 0.000 description 1
- 230000007547 defect Effects 0.000 description 1
- 238000010586 diagram Methods 0.000 description 1
- 238000005516 engineering process Methods 0.000 description 1
- 230000001747 exhibiting effect Effects 0.000 description 1
- 230000006872 improvement Effects 0.000 description 1
- 238000012804 iterative process Methods 0.000 description 1
- 239000012633 leachable Substances 0.000 description 1
- 238000010801 machine learning Methods 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 230000003287 optical effect Effects 0.000 description 1
- 238000011084 recovery Methods 0.000 description 1
- 239000007787 solid Substances 0.000 description 1
- 238000012795 verification Methods 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06T—IMAGE DATA PROCESSING OR GENERATION, IN GENERAL
- G06T11/00—2D [Two Dimensional] image generation
- G06T11/60—Editing figures and text; Combining figures or text
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/084—Backpropagation, e.g. using gradient descent
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/0895—Weakly supervised learning, e.g. semi-supervised or self-supervised learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/764—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/82—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Evolutionary Computation (AREA)
- General Physics & Mathematics (AREA)
- General Health & Medical Sciences (AREA)
- Health & Medical Sciences (AREA)
- Computing Systems (AREA)
- Software Systems (AREA)
- Artificial Intelligence (AREA)
- Life Sciences & Earth Sciences (AREA)
- Medical Informatics (AREA)
- Databases & Information Systems (AREA)
- Multimedia (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Data Mining & Analysis (AREA)
- Molecular Biology (AREA)
- General Engineering & Computer Science (AREA)
- Mathematical Physics (AREA)
- Image Analysis (AREA)
Abstract
本发明提供了一种基于扩散模型的图片数据集扩充方法、介质及设备。该方法包括:S1:针对原始图片数据集中的图像,设计对应的文本描述;S2:构建单词‑图像重映射模块嵌入预训练过的扩散模型,使用原始图片和对应文本描述作为输入数据对,训练单词‑图像重映射模块,构建图片到单词的重新映射;S3:固定扩散模型与其中的单词‑图像重映射模块,通过文本描述中单词的重新组合和拼接,构建新的文本描述。通过使用该文本描述与不同的随机噪声,通过扩散模型生成不同于数据集的图片,从而完成对原始图片数据集的扩充。
Description
技术领域
本发明涉及计算机视觉处理领域,尤其涉及一种基于扩散模型的图片数据集扩充方法。
背景技术
扩散模型(Diffusion Models)是一种基于概率的生成模型,其目的是通过学习数据的潜在分布来生成新的数据样本。在深度学习领域,扩散模型通常被用于生成图像、文本和其他类型的数据。
扩散模型的核心思想是将数据的生成过程建模为一个随机扩散过程。在这个过程中,模型从一个简单的先验分布(如高斯分布)开始,然后通过一系列的随机变换,逐渐将这个分布扩散到目标数据的潜在分布。这些随机变换通常由神经网络参数化,以便可以使用梯度下降等优化方法来学习它们。扩散模型的一个关键概念是噪声扩散过程(NoiseDiffusion Process)。在这个过程中,模型首先将原始数据样本加入噪声,然后逐渐减小噪声的强度,直到数据恢复到其原始形态。这个过程可以通过以下公式表示:
xt=sqrt(1-αt)*x0+sqrt(αt)*∈t,
其中αt表示在时间步t的数据样本,x_0表示原始数据样本,αt是噪声强度参数,∈t是从标准正态分布采样的噪声。在扩散模型中,需要学习一个神经网络来预测每个时间步的噪声强度参数∈t。扩散模型的另一个关键概念是反向扩散过程(Reverse DiffusionProcess)。在这个过程中,模型的目标是从加噪后的数据样本xt逆向恢复原始数据样本x0。此过程可以通过公式表示:x0=(xt-sqrt(αt)*∈t)/sqrt(1-αt)
在训练扩散模型时,模型需要最小化原始数据样本x0和通过反向扩散过程恢复的数据样本之间的差异。这可以通过最小化均方误差(MSE)损失或其他类似的目标函数来实现。为了生成新的数据样本,模型可以从简单的先验分布(如高斯分布)中采样一个随机样本,然后通过扩散模型的反向扩散过程将其转换为目标数据的潜在分布。在实践中,模型通常需要通过多个时间步来生成新的数据样本,以确保生成的样本具有足够的多样性。
扩散模型在生成图像方面取得了显著的成功。例如,OpenAI的DALL-E是一个基于扩散模型的图像生成系统,可以根据文本描述生成高质量的图像。扩散模型还被用于生成文本、音频和其他类型的数据,表现出强大的生成能力和多样性。然而,扩散模型也存在一些局限性。对于图像而言,单调的文本会导致生成图片的多样性变差。当遇到一词多义的情况下,扩散模型也会受到迷惑从而生成错误的图片。
图像数据的规模和泛化能力之间存在着密切的关系。泛化能力是指机器学习模型在未见过的新数据上的表现。在深度学习和计算机视觉领域,图像数据的规模通常对模型的泛化能力产生显著影响。以下是图像数据规模与泛化能力之间关系的一些关键点:
1.更大的数据集通常带来更好的泛化能力:在一般情况下,具有更多样本的数据集能够提供更丰富的信息,有助于训练出具有更强泛化能力的模型。这是因为更大的数据集能够更好地捕捉到数据的底层分布,从而使模型能够学习到更多的特征和模式。
2.数据的多样性对泛化能力至关重要:仅仅增加数据的数量并不一定能提高泛化能力,数据的多样性同样重要。一个理想的图像数据集应该包含各种各样的场景、物体和视角,以便模型能够学习到更具代表性的特征。当数据集包含多样性丰富的样本时,模型在面对新数据时更有可能做出正确的预测。
3.数据规模对过拟合的影响:过拟合是指模型在训练数据上表现良好,但在测试数据上表现较差的现象。通常情况下,当数据集规模较小时,模型更容易发生过拟合。这是因为模型可能会过度拟合训练数据中的噪声和特定样本,而忽略了数据的底层分布。随着数据规模的增加,过拟合的风险通常会降低,从而提高模型的泛化能力。
总之,图像数据的规模和泛化能力之间存在着密切的关系。更大、更多样化的数据集通常有助于提高模型的泛化能力,降低过拟合的风险。然而,仅仅增加数据的数量并不一定能提高泛化能力,数据的质量和多样性同样重要。在实践中,数据增强等技术可以有效地提高数据集的规模和多样性,从而提高模型的泛化能力。但是从真实世界中获取额外数据扩充数据集成本高昂,因此本发明使用扩散模型解决这一问题。
发明内容
本发明的目的是克服现有技术的不足,提供一种基于扩散模型的图片数据集扩充方法、介质及设备。
本发明的发明构思是:数据集规模和多样性与模型的性能正相关,但是从真实世界中收集额外的数据成本高昂。因此本发明基于扩散模型提出了一种全新的生成模型。可以基于当前有限的数据集,生成同风格和不同风格两种图像。由于生成模型仅需要一个固定的文本描述模板即可根据不同的随机噪声生成目标图像,因此可以不受限制地生成任意张同风格和不同风格的图像,从而对现有数据集进行有效的扩充。
为实现上述发明目的,本发明具体采用的技术方案如下:
第一方面,本发明提供了一种基于扩散模型的图片数据集扩充方法,其包括如下步骤:
S1:针对原始图片数据集中的图像,将每张图像与该图像的风格和类别的文本描述构建为图像文本对,并将原始图片数据集中的图像按照风格划分为子数据集;
S2:针对预训练扩散模型的Unet网络中每个注意力层中的每个线性层L,设置一个单词-图像重映射模块,每个单词-图像重映射模块中包含两个可学习矩阵M1和M2;固定预训练扩散模型的其余模型参数,仅设置单词-图像重映射模块中的两个可学习矩阵可调;
然后针对每一种图像风格,利用对应风格的子数据集训练单词-图像重映射模块;训练过程的每一轮迭代中,将图像文本对输入预训练扩散模型后,扩散模型经过正向过程和反向过程得到还原图像,且扩散模型生成还原图像的过程中,每个线性层L的权重需采用线性层L原始权重与残差的加权和,所述残差为对应的单词-图像重映射模块中两个可学习矩阵M1和M2的点积,再通过最小化图像文本对中的原始图像与模型输出的还原图像之间误差损失,对单词-图像重映射模块中的两个可学习矩阵M1和M2进行更新,而线性层L的原始权重保持不变;每一种图像风格完成训练过程后,保存预训练扩散模型中所有单词-图像重映射模块中最终优化后的两个可学习矩阵M1和M2;
S3:将预训练扩散模型中各单词-图像重映射模块最终优化后的可学习矩阵M1和M2,以残差形式直接更新或融合后更新至预训练扩散模型的线性层权重中,再利用更新后的预训练扩散模型,基于新的文本描述对原始图片数据集进行扩充。
作为上述第一方面的优选,所述S1的具体步骤包括:
S101:对待扩充的原始图片数据集Xsrc={x1,x2,…,xI}根据风格和类别进行分组,其中包含的风格集合为D={d1,d2,…,dK},其中包含的类别集合为C={c1,c2,…,cM};针对任一风格dk和类别cm的组合,风格dk对应的文本描述为类别cm对应的文本描述为该组合的文本描述为/>所有风格和类别的组合所对应文本描述集合为:
S102:针对原始图片数据集中的每个图像xi构建图像文本对,其中若图像xi属于dk风格且类别为cm,则选取对应的文字描述构建成图像文本对<xi,pk,m>。
作为上述第一方面的优选,所述S2的具体步骤包括:
S201:对于当前训练的风格dk∈D,针对预训练扩散模型G中的Unet网络,在Unet网络的每个注意力层中的每个线性层L上对应设置一个单词-图像重映射模块,每个单词-图像重映射模块中包含两个可学习矩阵和/>固定扩散模型G中包含线性层L的权重WL在内的所有模型参数,仅设置单词-图像重映射模块中的两个可学习矩阵可调,且扩散模型G在根据输入数据生成还原图像的过程中,每个线性层L参与计算的权重为该线性层L原始权重WL与残差的加权和/>
S202:针对当前训练的风格dk∈D,从原始图片数据集中属于该风格dk的子数据集中随机选取不同的图像文本对<xi,pk,m>组成一个批处理数据,并输入到扩散模型G中,输入图像xi经过扩散模型正向过程逐渐累加噪声得到随后将/>和对应文本描述pk,m经过扩散模型反向过程得到还原图像x′i;再通过计算xi和x′i之间的均方误差损失,来对文本-图像重映射模块中的两个可学习矩阵/>和/>进行优化,但除文本-图像重映射模块外的其余模型参数全部冻结;不断采样不同的批处理数据对文本-图像重映射模块进行迭代训练,达到终止条件后,保存预训练扩散模型中所有单词-图像重映射模块中最终优化后的两个可学习矩阵/>和/>完成在dk风格下的文本-图片重映射;
S203:针对风格集合D中的其余每一种风格,分别重复执行S201和S202,直到遍历完成风格集合D中的所有风格。
作为上述第一方面的优选,所述S3中,需从生成风格与原始图片数据集一致的扩充图片和生成风格与原始图片数据集不一致的扩充图片两个方向完成对原始图片数据集的扩充。
作为上述第一方面的优选,所述S3中,生成风格与原始图片数据集一致的扩充图片的具体步骤包括:
S311:经过上述S2过程,针对每种风格dk,将预训练扩散模型G中各单词-图像重映射模块最终优化后的可学习矩阵和/>以残差形式直接更新至预训练扩散模型对应的线性层WL权重中,更新后的线性层权重为:
其中,α为控制残差比例的超参数;预训练扩散模型G中所有线性层经过权重更新后,去除所有单词-图像重映射模块,得到能正确反映单词-图向映射关系的扩散模型该模型的输入为描述文本/>输出为图像风格属于dk且图中目标属于cm类别的图片;
S312:在对每个风格dk∈D进行图片扩充时,针对类别集合为C中的M个类别{c1,c2,…,cM},分别将风格dk对应的文本描述与每一种类别cm对应的文本描述/>进行组合,形成文本描述集合/>依次将集合/>中的每个文本描述/>作为扩散模型/>的输入,并通过设置若干不同的随机种子生成若干不同的扩充图片,实现对原始图片数据集中已有风格的图片扩充。
作为上述第一方面的优选,所述S3中,生成风格与原始图片数据集不一致的扩充图片的具体步骤包括:
S321:经过上述S2过程,针对任意两种不同的风格dk1∈D和dk2∈D,将预训练扩散模型G中各单词-图像重映射模块最终优化后的可学习矩阵,以残差形式融合后更新至预训练扩散模型对应的线性层WL权重中,更新后的线性层权重为:
其中,β是融合超参数,负责控制两种风格dk1和dk2融合的尺度;预训练扩散模型G中所有线性层经过权重更新后,去除所有单词-图像重映射模块,得到能正确反映单词-图向映射关系的扩散模型该模型的输入为描述文本/>输出为图像风格属于融合了dk1和dk2的新风格d″k1,k2且图中目标属于cm类别的图片x″;
S322:在对每个新风格进行图片扩充时,针对类别集合为C中的M个类别{c1,c2,…,cM},分别将风格dk1对应的文本描述/>风格dk2对应的文本描述/>与每一种类别cm对应的文本描述/>进行组合,形成文本描述集合依次将集合/>中的每个文本描述/>作为扩散模型/>的输入,并通过设置若干不同的随机种子生成若干不同的扩充图片,实现对原始图片数据集中新风格的图片扩充。
作为上述第一方面的优选,所述超参数α设置为0.5~0.7。
作为上述第一方面的优选,所述超参数β设置为0.5~0.7。
第二方面,本发明提供了一种计算机可读存储介质,所述存储介质上存储有计算机程序,当所述计算机程序被处理器执行时,实现如上述第一方面任一方案所述的基于扩散模型的图片数据集扩充方法。
第三方面,本发明提供了一种计算机电子设备,其包括存储器和处理器;
所述存储器,用于存储计算机程序;
所述处理器,用于当执行所述计算机程序时,实现如上述第一方面任一方案所述的基于扩散模型的图片数据集扩充方法。
本发明与背景技术相比,具有的有益的效果是:
本发明针对使用有限特定风格图片训练,在未见新风格图片分类的任务,提出了一种基于扩散模型的图片数据集扩充方法。该方法从实际应用角度出发,对有限的训练数据进行了充分的利用。通过对比学习和因果指引的方法,使简单的模型对未见风格图片的分类性能有了明显提升。基于本发明,在仅仅依赖有限风格的图片,便可对其他多样风格的图片进行稳定的预测,扩展了模型在真实世界中的应用场景并提升了性能表现。
附图说明
为了更清楚地说明本发明实施例或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其它附图。
图1是本发明实施例提供的一种基于扩散模型的图片数据集扩充的流程图。
图2是本发明实施例提供的模型架构以及流程图。
具体实施方式
为了使本技术领域的人员更好地理解本发明方案,下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本发明一部分的实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
需要说明的是,本发明的说明书和权利要求书及上述附图中的术语“第一”、“第二”等是用于区别类似的对象,而不必用于描述特定的顺序或先后次序。应该理解这样使用的数据在适当情况下可以互换,以便这里描述的本发明的实施例能够以除了在这里图示或描述的那些以外的顺序实施。此外,术语“包括”和“具有”以及他们的任何变形,意图在于覆盖不排他的包含,例如,包含了一系列步骤或单元的过程、方法、系统、产品或设备不必限于清楚地列出的那些步骤或单元,而是可包括没有清楚地列出的或对于这些过程、方法、产品或设备固有的其它步骤或单元。
为了解决现有技术中存在的问题,本发明实施例提供了一种基于扩散模型的图片数据集扩充方法,该方法对多个不同风格的有限个类别的图片和其对应的文字描述作为输入数据;使用基于扩散模型完成单词-图片重映射;随后通过不同文本即可生成与原数据集同风格和不同风格的新图片。本发明可以基于有限的数据进行训练,并扩充任意数目的同风格和不同风格的图片。
本发明提供了一种基于扩散模型的图片数据集扩充方法,其基于扩散模型来实现。需要说明的是,本发明中的扩散模型(Diffusion Model)属于现有技术,其中使用Unet作为扩散模型的基本算子,通过正向过程(扩散)和反向过程(去噪)得到还原图像,具体的原理不再赘述。其中现有技术已经存在预训练好的扩散模型,本发明后续亦可基于预训练号的扩散模型进行进一步的微调,使其满足本发明的数据集扩充任务。扩散模型内的Unet网络,存在一系列的注意力层(Attention Layer),而每个注意力层中有包含很多线性层,本发明的核心是在每个线性层上外加一个单词-图像重映射模块,来微调单词到图片的重新映射过程。
如图1所示,在本发明的一个较佳实施例中,上述基于扩散模型的图片数据集扩充方法包括如下具体步骤:
S1:对原始图片数据集的图像构建对应的文本描述,构建<图像,文本>对作为输入数据。
在该步骤中,针对原始图片数据集中的图像,将每张图像与该图像的风格和类别的文本描述构建为图像文本对,并将原始图片数据集中的图像按照风格划分为子数据集。
在本发明的实施例中,上述S1的具体步骤包括如下子步骤:
S101:对待扩充的原始图片数据集Xsrc={x1,x2,…,xI}根据风格和类别进行分组,其中包含的风格集合为D={d1,d2,…,dK},其中包含的类别集合为C={c1,c2,…,cM},I为原始图片数据集中的图片总数,K为图片的风格总数,M为图片中的目标对象的类别总数。
本发明的实施例中,若将风格和类别的文字描述分别定义为txtD和txtC,则可将txtD和txtC进行组合形成对原始图片的文本描述定义为p=“txtdtxtc”。具体而言,针对任一风格dk和类别cm的组合,风格dk对应的文本描述定义为类别cm对应的文本描述定义为/>该组合的文本描述定义为/>所有风格和类别的组合所对应文本描述集合定义为:
S102:针对原始图片数据集中的每个图像xi构建图像文本对,其中若图像xi属于dk风格且类别为cm,则选取对应的文字描述构建成图像文本对<xi,pk,m>。
S2:构建单词-图像重映射模块嵌入预训练扩散模型的UNet网络当中,构建单词-图片的重新映射。
在该步骤中,针对预训练扩散模型的Unet网络中每个注意力层中的每个线性层L,设置一个单词-图像重映射模块,每个单词-图像重映射模块中包含两个可学习矩阵M1和M2;固定预训练扩散模型的其余模型参数,仅设置单词-图像重映射模块中的两个可学习矩阵可调。
需要特别说明的是,由于后续需要针对每一种图像风格进行训练,而每一种图像风格初始的预训练扩散模型和单词-图像重映射模块均是相同的。
当构建完预训练扩散模型和单词-图像重映射模块后,即可针对每一种图像风格,利用对应风格的子数据集训练单词-图像重映射模块。与传统的模型训练类似,训练过程是一个迭代的过程,每一轮迭代中,将图像文本对输入预训练扩散模型后,扩散模型经过正向过程和反向过程得到还原图像,且扩散模型生成还原图像的过程中,每个线性层L的权重需采用线性层L原始权重与残差的加权和,所述残差为对应的单词-图像重映射模块中两个可学习矩阵M1和M2的点积,获得还原图像后再通过最小化图像文本对中的原始图像与模型输出的还原图像之间误差损失,对单词-图像重映射模块中的两个可学习矩阵M1和M2进行更新,而线性层L的原始权重保持不变,从而完成一轮迭代,更新后的两个可学习矩阵M1和M2继续参与下一轮迭代时的前向计算过程。每一种图像风格完成训练过程后,保存预训练扩散模型中所有单词-图像重映射模块中最终优化后的两个可学习矩阵M1和M2。
在本发明的实施例中,上述S2的具体步骤可表述为包括如下子步骤:
S201:对于当前训练的风格dk∈D,针对预训练扩散模型G中的Unet网络,在Unet网络的每个注意力层中的每个线性层L上对应设置一个单词-图像重映射模块,每个单词-图像重映射模块中包含两个可学习矩阵和/>固定扩散模型G中包含线性层L的权重WL在内的所有模型参数,仅设置单词-图像重映射模块中的两个可学习矩阵可调,且扩散模型G在根据输入数据生成还原图像的前向计算过程中,每个线性层L参与计算的权重为该线性层L原始权重WL与残差的加权和/>
由此可见,本发明将扩散模型G中UNet的Attention Layer中的线性层L权重定义为WL,针对风格dk构建两个可学习的矩阵和/>两者的点积作为WL的残差用来在下一次计算过程中迭代到WL上,从而基于新的权重获得更好地模型输出。由于残差更新的数据比原始模型微调数量少的多,因此可以快速完成单词-图像重映射。
S202:针对当前训练的风格dk∈D,从原始图片数据集中属于该风格dk的子数据集中随机选取不同的图像文本对<xi,pk,m>组成一个批处理数据,并输入到扩散模型G中,输入图像xi经过扩散模型正向过程逐渐累加噪声得到随后将/>和对应文本描述pk,m经过扩散模型反向过程得到还原图像x′i;再通过计算xi和x′i之间的均方误差损失(MeanSquared Error,MSE),来对文本-图像重映射模块中的两个可学习矩阵/>和/>进行优化,但除文本-图像重映射模块外的其余模型参数全部冻结;不断采样不同的批处理数据对文本-图像重映射模块进行迭代训练,达到终止条件后,保存预训练扩散模型中所有单词-图像重映射模块中最终优化后的两个可学习矩阵/>和/>完成在dk风格下的文本-图片重映射。
需要特别说明的是,上述单词-图像重映射模块是以外挂形式设置在每个线性层L上的,其中的两个可学习矩阵M1和M2在模型前向计算过程中,会以残差形式叠加在线性层L的自身权重上,也就是说模型前向计算时每个线性层L是以其自身权重和残差的加权和替代自身权重进行计算的。但是在模型反向传播更新参数时,仅更新单词-图像重映射模块,而模型中的所有网络层(包括线性层L)则被冻结,不参与反向传播优化过程。
S203:针对风格集合D中的其余每一种风格,分别重复执行S201和S202,直到遍历完成风格集合D中的所有风格。
在本发明的实施例中,上述终止条件可以是模型收敛或者达到最大迭代次数,此处可设置达到1000次迭代时终止。即设置训练部署为固定1000步,即可完成在dk风格下的文本-图片重映射,使得所训练得到的单词-图像重映射模块叠加到扩散模型G后可以准确理解风格dk和所有类别集合C中的全部类别。
S3:文本描述中单词的重新组合和拼接,并使用该文本描述与不同的随机噪声作为输入生成与原数据集同风格和不同风格的两类图片。
在该步骤中,将预训练扩散模型中各单词-图像重映射模块最终优化后的可学习矩阵M1和M2,以残差形式直接更新或融合后更新至预训练扩散模型的线性层权重中,再利用更新后的预训练扩散模型,基于新的文本描述对原始图片数据集进行扩充。
在本发明的实施例中,通过固定扩散模型与单词-图像重映射模块,将单词-图像重映射模块中的信息注入扩散模型形成新的扩散模型,后续可基于这些新的扩散模型生成全新图片完成数据集扩充。新的扩散模型的输入需要将通过文本描述中单词的重新组合和拼接,构建新的文本描述,并使用新的文本描述与不同的随机噪声(通过设置随机种子来实现),通过扩散模型生成风格与数据集一致的图片和风格与数据集不一致的图片,从两个方向完成对原始图片数据集的扩充,可提升了数据集的多样性。下面分别对两种扩充方向的具体实现形式进行展开介绍。
在本发明的实施例中,上述步骤S3中,生成风格与原始图片数据集一致的扩充图片的具体步骤包括:
S311:经过上述S2过程,针对每种风格dk,将预训练扩散模型G中各单词-图像重映射模块最终优化后的可学习矩阵和/>以残差形式直接更新至预训练扩散模型对应的线性层WL权重中,更新后的线性层权重为:
其中,α为控制残差比例的超参数,可设置为0.5~0.7,本实施例优选设置为0.6。预训练扩散模型G中所有线性层经过权重更新后,去除所有单词-图像重映射模块,得到能正确反映单词-图向映射关系的扩散模型该模型的输入为描述文本/>输出为图像风格属于dk且图中目标属于cm类别的图片;
S312:在对每个风格dk∈D进行图片扩充时,针对类别集合为C中的M个类别{c1,c2,…,cM},分别将风格dk对应的文本描述与每一种类别cm对应的文本描述/>进行组合,形成文本描述集合/>依次将集合/>中的每个文本描述/>作为扩散模型/>的输入,并通过设置若干不同的随机种子(可从0开始)生成若干不同的扩充图片,实现对原始图片数据集中已有风格的图片扩充。
在本发明的实施例中,可定义使用文本描述生成图片x′的风格为d′,类别为c′,则d′=dk且c′=cm且/>
由此可见,上述S311~S312可以得到已有风格的新图片,但是由于仅能对/>到dk风格有正确的映射,无法创造出一个新的风格/>为解决这个问题,本发明将D中多个风格融合形成新的风格。而为创造一个同时包含两种已有风格的新风格,本发明需要将两个扩散模型中的Unet中的单词-图像重映射模块进行重新融合,使Unet中每个Attention Layer的线性层L得到新的权重。
具体而言,在本发明的实施例中,上述步骤S3中,生成风格与原始图片数据集不一致的扩充图片的具体步骤包括:
S321:经过上述S2过程,针对任意两种不同的风格dk1∈D和dk2∈D,将预训练扩散模型G中各单词-图像重映射模块最终优化后的可学习矩阵,以残差形式融合后更新至预训练扩散模型对应的线性层WL权重中,更新后的线性层权重为:
其中,β是融合超参数,负责控制两种风格dk1和dk2融合的尺度,可设置为0.5~0.7,本实施例优选设置为0.6。预训练扩散模型G中所有线性层经过权重更新后,去除所有单词-图像重映射模块,得到能正确反映单词-图向映射关系的扩散模型该模型的输入为描述文本/>输出为图像风格属于融合了dk1和dk2的新风格d″k1,k2且图中目标属于cm类别的图片x″;
S322:在对每个新风格进行图片扩充时,针对类别集合为C中的M个类别{c1,c2,…,cM},分别将风格dk1对应的文本描述/>风格dk2对应的文本描述/>与每一种类别cm对应的文本描述/>进行组合,形成文本描述集合依次将集合/>中的每个文本描述/>作为扩散模型/>的输入,并通过设置若干不同的随机种子(可从0开始)生成若干不同的扩充图片,实现对原始图片数据集中新风格的图片扩充。
具体而言,如图2所示,通过S2步骤训练形成的单词-图像重映射模块,如果要用于进行已有风格d1的图片扩充,则可以直接用扩散模型来实现,但如果要创造一个同时包含d1和d2的新风格d″1,2并扩充这个新风格的图片,则需要两个风格下训练得到的单词-图像重映射模块进行重新融合,使Unet中每个Attention Layer的线性层L得到新的权重/>更新公式如下:
经过融合后的扩散模型定义为通过设置/>作为文本描述,输入/>生成具有d″1,2风格的cm类别图像x″。针对任意两个在D中的风格,均进行上述操作,即可生成若干不同风格的图像,从而实验数据集的扩充。
同样的,基于同一发明构思,本发明的另一较佳实施例中还提供了与上述实施例提供的基于扩散模型的图片数据集扩充方法方法对应的一种电子设备,其包括存储器和处理器;
所述存储器,用于存储计算机程序;
所述处理器,用于当执行所述计算机程序时,实现前述任一实施例中描述的基于扩散模型的图片数据集扩充方法方法。
此外,上述的存储器中的逻辑指令可以通过软件功能单元的形式实现并作为独立的产品销售或使用时,可以存储在一个计算机可读取存储介质中。基于这样的理解,本发明的技术方案本质上或者说对现有技术做出贡献的部分或者该技术方案的部分可以以软件产品的形式体现出来,该计算机软件产品存储在一个存储介质中,包括若干指令用以使得一台计算机设备(可以是个人计算机,服务器,或者网络设备等)执行本发明各个实施例所述方法的全部或部分步骤。
由此,基于同一发明构思,本发明的另一较佳实施例中还提供了与上述实施例提供的基于扩散模型的图片数据集扩充方法方法对应的一种计算机可读存储介质,该所述存储介质上存储有计算机程序,当所述计算机程序被处理器执行时,能实现前述任一实施例中描述的基于扩散模型的图片数据集扩充方法方法。
具体而言,在上述两个实施例的计算机可读存储介质中,存储的计算机程序被处理器执行,可执行前述S1~S3的步骤。
可以理解的是,上述存储介质可以包括随机存取存储器(Random Access Memory,RAM),也可以包括非易失性存储器(Non-Volatile Memory,NVM),例如至少一个磁盘存储器。同时存储介质还可以是U盘、移动硬盘、磁碟或者光盘等各种可以存储程序代码的介质。
可以理解的是,上述的处理器可以是通用处理器,包括中央处理器(CentralProcessing Unit,CPU)、网络处理器(Network Processor,NP)等;还可以是数字信号处理器(Digital Signal Processing,DSP)、专用集成电路(Application SpecificIntegrated Circuit,ASIC)、现场可编程门阵列(Field-Programmable Gate Array,FPGA)或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件。
另外需要说明的是,所属领域的技术人员可以清楚地了解到,为描述的方便和简洁,上述描述的系统的具体工作过程,可以参考前述方法实施例中的对应过程,在此不再赘述。在本申请所提供的各实施例中,所述的系统和方法中对于步骤或者模块的划分,仅仅为一种逻辑功能划分,实际实现时可以有另外的划分方式,例如多个模块或步骤可以结合或者可以集成到一起,一个模块或者步骤亦可进行拆分。
下面将上述实施例中的基于扩散模型的图片数据集扩充方法,应用至具体数据集上进行分类测试。具体的步骤如S1-S3所述,不再赘述,主要展示其具体的参数以及技术效果。
实施例
为了验证本发明的效果,本实施例的验证方法在多个数据上分类的结果。本实施例按照前述S1~S3步骤的实现过程,首先获取从真实世界中采集的数据集,按照风格和类别将其分类,并构建对应的文本描述。单词-图像重映射的训练,并固扩散模型其他模块。通过使用不同的本文表述和设定不同的随机种子,可以生成与原数据集相同风格和不同风格的图片,从而实现数据集的扩充。为量化指标,本实施例对不同数据集进行扩充,使用ResNet50作为基础模型训练分类模型并测试其性能。选取PACS、OfficeHome两个泛化数据集进行多域泛化能力测试;选取ImageNet进行大规模数据集条件下的测试;选取Aircraft、Cars196、DTD、EuroSAT、Flowers、Pets、Food101和SUN397数据集对迁移学习的效果,其结果分别如表1、表2和表3所示。
表1多域泛化评估表
表2大规模泛化评估表
表3迁移学习评估表
/>
以上所述仅为本申请的较佳实施例而已,并非用于限定本申请的保护范围。凡在本申请的精神和原则之内所作的任何修改、等同替换、改进等,均包含在本申请的保护范围内。
Claims (10)
1.一种基于扩散模型的图片数据集扩充方法,其特征在于,包括如下步骤:
S1:针对原始图片数据集中的图像,将每张图像与该图像的风格和类别的文本描述构建为图像文本对,并将原始图片数据集中的图像按照风格划分为子数据集;
S2:针对预训练扩散模型的Unet网络中每个注意力层中的每个线性层,设置一个单词-图像重映射模块,每个单词-图像重映射模块中包含两个可学习矩阵;固定预训练扩散模型的其余模型参数,仅设置单词-图像重映射模块中的两个可学习矩阵可调;
然后针对每一种图像风格,利用对应风格的子数据集训练单词-图像重映射模块;训练过程的每一轮迭代中,将图像文本对输入预训练扩散模型后,扩散模型经过正向过程和反向过程得到还原图像,且扩散模型生成还原图像的过程中,每个线性层的权重需采用线性层原始权重与残差的加权和,所述残差为对应的单词-图像重映射模块中两个可学习矩阵的点积,再通过最小化图像文本对中的原始图像与模型输出的还原图像之间误差损失,对单词-图像重映射模块中的两个可学习矩阵进行更新,而线性层的原始权重保持不变;每一种图像风格完成训练过程后,保存预训练扩散模型中所有单词-图像重映射模块中最终优化后的两个可学习矩阵;
S3:将预训练扩散模型中各单词-图像重映射模块最终优化后的可学习矩阵,以残差形式直接更新或融合后更新至预训练扩散模型的线性层权重中,再利用更新后的预训练扩散模型,基于新的文本描述对原始图片数据集进行扩充。
2.如权利要求1所述的基于扩散模型的图片数据集扩充方法,其特征在于,所述S1的具体步骤包括:
S101:对待扩充的原始图片数据集Xsrc={x1,x2,…,xI}根据风格和类别进行分组,其中包含的风格集合为D={d1,d2,…,dK},其中包含的类别集合为C={c1,c2,…,cM};针对任一风格dk和类别cm的组合,风格dk对应的文本描述为类别cm对应的文本描述为/>该组合的文本描述为/>所有风格和类别的组合所对应文本描述集合为:
S102:针对原始图片数据集中的每个图像xi构建图像文本对,其中若图像xi属于dk风格且类别为cm,则选取对应的文字描述构建成图像文本对<xi,pk,m>。
3.如权利要求2所述的基于扩散模型的图片数据集扩充方法,其特征在于,所述S2的具体步骤包括:
S201:对于当前训练的风格dk∈D,针对预训练扩散模型G中的Unet网络,在Unet网络的每个注意力层中的每个线性层L上对应设置一个单词-图像重映射模块,每个单词-图像重映射模块中包含两个可学习矩阵和/>固定扩散模型G中包含线性层L的权重WL在内的所有模型参数,仅设置单词-图像重映射模块中的两个可学习矩阵可调,且扩散模型G在根据输入数据生成还原图像的过程中,每个线性层L参与计算的权重为该线性层L原始权重WL与残差的加权和/>
S202:针对当前训练的风格dk∈D,从原始图片数据集中属于该风格dk的子数据集中随机选取不同的图像文本对<xi,pk,m>组成一个批处理数据,并输入到扩散模型G中,输入图像xi经过扩散模型正向过程逐渐累加噪声得到随后将/>和对应文本描述pk,m经过扩散模型反向过程得到还原图像xi′;再通过计算xi和x′i之间的均方误差损失,来对文本-图像重映射模块中的两个可学习矩阵/>和/>进行优化,但除文本-图像重映射模块外的其余模型参数全部冻结;不断采样不同的批处理数据对文本-图像重映射模块进行迭代训练,达到终止条件后,保存预训练扩散模型中所有单词-图像重映射模块中最终优化后的两个可学习矩阵/>和/>完成在dk风格下的文本-图片重映射;
S203:针对风格集合D中的其余每一种风格,分别重复执行S201和S202,直到遍历完成风格集合D中的所有风格。
4.如权利要求3所述的基于扩散模型的图片数据集扩充方法,其特征在于,所述S3中,需从生成风格与原始图片数据集一致的扩充图片和生成风格与原始图片数据集不一致的扩充图片两个方向完成对原始图片数据集的扩充。
5.如权利要求4所述的基于扩散模型的图片数据集扩充方法,其特征在于,所述S3中,生成风格与原始图片数据集一致的扩充图片的具体步骤包括:
S311:经过上述S2过程,针对每种风格dk,将预训练扩散模型G中各单词-图像重映射模块最终优化后的可学习矩阵和/>以残差形式直接更新至预训练扩散模型对应的线性层WL权重中,更新后的线性层权重为:
其中,α为控制残差比例的超参数;预训练扩散模型G中所有线性层经过权重更新后,去除所有单词-图像重映射模块,得到能正确反映单词-图向映射关系的扩散模型该模型的输入为描述文本/>输出为图像风格属于dk且图中目标属于cm类别的图片;
S312:在对每个风格dk∈D进行图片扩充时,针对类别集合为C中的M个类别{c1,c2,…,cM},分别将风格dk对应的文本描述与每一种类别cm对应的文本描述/>进行组合,形成文本描述集合/>依次将集合/>中的每个文本描述作为扩散模型/>的输入,并通过设置若干不同的随机种子生成若干不同的扩充图片,实现对原始图片数据集中已有风格的图片扩充。
6.如权利要求4所述的基于扩散模型的图片数据集扩充方法,其特征在于,所述S3中,生成风格与原始图片数据集不一致的扩充图片的具体步骤包括:
S321:经过上述S2过程,针对任意两种不同的风格dk1∈D和dk2∈D,将预训练扩散模型G中各单词-图像重映射模块最终优化后的可学习矩阵,以残差形式融合后更新至预训练扩散模型对应的线性层WL权重中,更新后的线性层权重为:
其中,β是融合超参数,负责控制两种风格dk1和dk2融合的尺度;预训练扩散模型G中所有线性层经过权重更新后,去除所有单词-图像重映射模块,得到能正确反映单词-图向映射关系的扩散模型该模型的输入为描述文本/>输出为图像风格属于融合了dk1和dk2的新风格d′k′1,k2且图中目标属于cm类别的图片x″;
S322:在对每个新风格进行图片扩充时,针对类别集合为C中的M个类别{c1,c2,…,cM},分别将风格dk1对应的文本描述/>风格dk2对应的文本描述/>与每一种类别cm对应的文本描述/>进行组合,形成文本描述集合依次将集合/>中的每个文本描述/>作为扩散模型/>的输入,并通过设置若干不同的随机种子生成若干不同的扩充图片,实现对原始图片数据集中新风格的图片扩充。
7.如权利要求6所述的基于扩散模型的图片数据集扩充方法,其特征在于,所述超参数α设置为0.5~0.7。
8.如权利要求6所述的基于扩散模型的图片数据集扩充方法,其特征在于,所述超参数β设置为0.5~0.7。
9.一种计算机可读存储介质,其特征在于,所述存储介质上存储有计算机程序,当所述计算机程序被处理器执行时,实现如权利要求1~8任一所述的基于扩散模型的图片数据集扩充方法。
10.一种计算机电子设备,其特征在于,包括存储器和处理器;
所述存储器,用于存储计算机程序;
所述处理器,用于当执行所述计算机程序时,实现如权利要求1~8任一所述的基于扩散模型的图片数据集扩充方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310827912.5A CN116883545A (zh) | 2023-07-06 | 2023-07-06 | 基于扩散模型的图片数据集扩充方法、介质及设备 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310827912.5A CN116883545A (zh) | 2023-07-06 | 2023-07-06 | 基于扩散模型的图片数据集扩充方法、介质及设备 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN116883545A true CN116883545A (zh) | 2023-10-13 |
Family
ID=88269157
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202310827912.5A Pending CN116883545A (zh) | 2023-07-06 | 2023-07-06 | 基于扩散模型的图片数据集扩充方法、介质及设备 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN116883545A (zh) |
Cited By (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117095258A (zh) * | 2023-10-17 | 2023-11-21 | 苏州元脑智能科技有限公司 | 一种扩散模型训练方法、装置、电子设备及存储介质 |
CN117216886A (zh) * | 2023-11-09 | 2023-12-12 | 中国空气动力研究与发展中心计算空气动力研究所 | 一种基于扩散模型的飞行器气动布局反设计方法 |
CN117593595A (zh) * | 2024-01-18 | 2024-02-23 | 腾讯科技(深圳)有限公司 | 基于人工智能的样本增广方法、装置及电子设备 |
-
2023
- 2023-07-06 CN CN202310827912.5A patent/CN116883545A/zh active Pending
Cited By (6)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117095258A (zh) * | 2023-10-17 | 2023-11-21 | 苏州元脑智能科技有限公司 | 一种扩散模型训练方法、装置、电子设备及存储介质 |
CN117095258B (zh) * | 2023-10-17 | 2024-02-20 | 苏州元脑智能科技有限公司 | 一种扩散模型训练方法、装置、电子设备及存储介质 |
CN117216886A (zh) * | 2023-11-09 | 2023-12-12 | 中国空气动力研究与发展中心计算空气动力研究所 | 一种基于扩散模型的飞行器气动布局反设计方法 |
CN117216886B (zh) * | 2023-11-09 | 2024-04-05 | 中国空气动力研究与发展中心计算空气动力研究所 | 一种基于扩散模型的飞行器气动布局反设计方法 |
CN117593595A (zh) * | 2024-01-18 | 2024-02-23 | 腾讯科技(深圳)有限公司 | 基于人工智能的样本增广方法、装置及电子设备 |
CN117593595B (zh) * | 2024-01-18 | 2024-04-23 | 腾讯科技(深圳)有限公司 | 基于人工智能的样本增广方法、装置及电子设备 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN116883545A (zh) | 基于扩散模型的图片数据集扩充方法、介质及设备 | |
US10984319B2 (en) | Neural architecture search | |
TWI751458B (zh) | 神經網路搜索方法及裝置、處理器、電子設備和電腦可讀儲存媒體 | |
CN111950594A (zh) | 基于子图采样的大规模属性图上的无监督图表示学习方法和装置 | |
CN112418292B (zh) | 一种图像质量评价的方法、装置、计算机设备及存储介质 | |
CN102915448B (zh) | 一种基于AdaBoost的三维模型自动分类方法 | |
CN109389166A (zh) | 基于局部结构保存的深度迁移嵌入聚类机器学习方法 | |
CN114610900A (zh) | 知识图谱补全方法及系统 | |
CN109614611B (zh) | 一种融合生成非对抗网络与卷积神经网络的情感分析方法 | |
Huai et al. | Zerobn: Learning compact neural networks for latency-critical edge systems | |
US11651129B2 (en) | Selecting a subset of training data from a data pool for a power prediction model | |
CN115146580A (zh) | 基于特征选择和深度学习的集成电路路径延时预测方法 | |
US20200074277A1 (en) | Fuzzy input for autoencoders | |
Afzal et al. | Discriminative feature abstraction by deep L2 hypersphere embedding for 3D mesh CNNs | |
CN113096133A (zh) | 一种基于注意力机制的语义分割网络的构建方法 | |
CN113222160B (zh) | 一种量子态的转换方法及装置 | |
CN112529057A (zh) | 一种基于图卷积网络的图相似性计算方法及装置 | |
CN116975347A (zh) | 图像生成模型训练方法及相关装置 | |
CN115543762A (zh) | 一种磁盘smart数据扩充方法、系统及电子设备 | |
CN114265954B (zh) | 基于位置与结构信息的图表示学习方法 | |
JP6705506B2 (ja) | 学習プログラム、情報処理装置および学習方法 | |
US20240005129A1 (en) | Neural architecture and hardware accelerator search | |
CN115358178A (zh) | 一种基于融合神经网络的电路良率分析方法 | |
CN114546804A (zh) | 信息推送的效应评估方法、装置、电子设备和存储介质 | |
CN114154572A (zh) | 一种基于异构平台的异构数据集中接入分析方法 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |