CN117095258B - 一种扩散模型训练方法、装置、电子设备及存储介质 - Google Patents
一种扩散模型训练方法、装置、电子设备及存储介质 Download PDFInfo
- Publication number
- CN117095258B CN117095258B CN202311345132.3A CN202311345132A CN117095258B CN 117095258 B CN117095258 B CN 117095258B CN 202311345132 A CN202311345132 A CN 202311345132A CN 117095258 B CN117095258 B CN 117095258B
- Authority
- CN
- China
- Prior art keywords
- model
- target
- data
- diffusion
- determining
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
- 238000009792 diffusion process Methods 0.000 title claims abstract description 207
- 238000012549 training Methods 0.000 title claims abstract description 129
- 238000000034 method Methods 0.000 title claims abstract description 119
- 238000003860 storage Methods 0.000 title claims abstract description 22
- 230000006870 function Effects 0.000 claims description 61
- 238000004140 cleaning Methods 0.000 claims description 45
- 238000009826 distribution Methods 0.000 claims description 34
- 238000010606 normalization Methods 0.000 claims description 26
- 238000005070 sampling Methods 0.000 claims description 21
- 238000004891 communication Methods 0.000 claims description 12
- 238000004590 computer program Methods 0.000 claims description 5
- 230000004044 response Effects 0.000 claims description 4
- 230000002159 abnormal effect Effects 0.000 claims description 3
- 230000008569 process Effects 0.000 description 48
- 238000010586 diagram Methods 0.000 description 10
- 238000012545 processing Methods 0.000 description 9
- 239000013598 vector Substances 0.000 description 9
- 238000004821 distillation Methods 0.000 description 8
- 230000009467 reduction Effects 0.000 description 7
- 238000004364 calculation method Methods 0.000 description 6
- 238000010428 oil painting Methods 0.000 description 6
- 230000002441 reversible effect Effects 0.000 description 6
- 230000000694 effects Effects 0.000 description 5
- 238000013528 artificial neural network Methods 0.000 description 4
- 238000010422 painting Methods 0.000 description 4
- 230000008901 benefit Effects 0.000 description 3
- 238000004422 calculation algorithm Methods 0.000 description 3
- 230000008878 coupling Effects 0.000 description 3
- 238000010168 coupling process Methods 0.000 description 3
- 238000005859 coupling reaction Methods 0.000 description 3
- 238000001514 detection method Methods 0.000 description 3
- 230000003287 optical effect Effects 0.000 description 3
- 238000011946 reduction process Methods 0.000 description 3
- 238000012937 correction Methods 0.000 description 2
- 230000006378 damage Effects 0.000 description 2
- 230000007547 defect Effects 0.000 description 2
- 239000000976 ink Substances 0.000 description 2
- 230000003993 interaction Effects 0.000 description 2
- 239000004973 liquid crystal related substance Substances 0.000 description 2
- 238000013507 mapping Methods 0.000 description 2
- 238000003058 natural language processing Methods 0.000 description 2
- 238000013139 quantization Methods 0.000 description 2
- 230000001133 acceleration Effects 0.000 description 1
- 230000008485 antagonism Effects 0.000 description 1
- 238000006243 chemical reaction Methods 0.000 description 1
- 238000002512 chemotherapy Methods 0.000 description 1
- 230000006835 compression Effects 0.000 description 1
- 238000007906 compression Methods 0.000 description 1
- 238000013135 deep learning Methods 0.000 description 1
- 238000013136 deep learning model Methods 0.000 description 1
- 230000000593 degrading effect Effects 0.000 description 1
- 238000012217 deletion Methods 0.000 description 1
- 230000037430 deletion Effects 0.000 description 1
- 238000013461 design Methods 0.000 description 1
- 238000007599 discharging Methods 0.000 description 1
- 238000001914 filtration Methods 0.000 description 1
- 230000005484 gravity Effects 0.000 description 1
- 238000009499 grossing Methods 0.000 description 1
- 230000010354 integration Effects 0.000 description 1
- 238000012886 linear function Methods 0.000 description 1
- 238000010801 machine learning Methods 0.000 description 1
- 230000007246 mechanism Effects 0.000 description 1
- 238000010295 mobile communication Methods 0.000 description 1
- 238000012544 monitoring process Methods 0.000 description 1
- 238000005457 optimization Methods 0.000 description 1
- 230000000306 recurrent effect Effects 0.000 description 1
- 238000011160 research Methods 0.000 description 1
- 239000007787 solid Substances 0.000 description 1
- 230000005236 sound signal Effects 0.000 description 1
- 238000006467 substitution reaction Methods 0.000 description 1
- 238000010897 surface acoustic wave method Methods 0.000 description 1
- 238000013519 translation Methods 0.000 description 1
- 230000000007 visual effect Effects 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/77—Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
- G06V10/774—Generating sets of training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/044—Recurrent networks, e.g. Hopfield networks
- G06N3/0442—Recurrent networks, e.g. Hopfield networks characterised by memory or gating, e.g. long short-term memory [LSTM] or gated recurrent units [GRU]
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
- G06N3/0455—Auto-encoder networks; Encoder-decoder networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/0475—Generative networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/766—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using regression, e.g. by projecting features on hyperplanes
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/82—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02P—CLIMATE CHANGE MITIGATION TECHNOLOGIES IN THE PRODUCTION OR PROCESSING OF GOODS
- Y02P90/00—Enabling technologies with a potential contribution to greenhouse gas [GHG] emissions mitigation
- Y02P90/30—Computing systems specially adapted for manufacturing
Abstract
本发明实施例提供了一种扩散模型训练方法、装置、电子设备及存储介质,涉及扩散模型训练领域,通过生成通用扩散模型;获取包含风格类型信息的专业任务数据集;基于所述专业任务数据集训练所述通用扩散模型,生成目标扩散模型,实现了分阶段对通用扩散模型和针对不同风格类型的目标模型进行训练,从而实现了在扩散模型训练中,使模型能够在特定任务图像上进行训练,从而提升了模型生成特定风格图像的准确性。
Description
技术领域
本发明涉及扩散模型训练技术领域,特别是涉及一种扩散模型训练方法、一种扩散模型训练装置、一种电子设备以及一种计算机可读存储介质。
背景技术
扩散模型是目前一种主流的图像生成模型,能够生成逼真细腻的图像。该类模型一般用于文生图场景,即给定一段文本,模型生成一张或者多张能够描述该文本的图像,而在扩散模型训练过程中,存在模型训练不充分及在特定任务图像生成质量不佳的问题,只是扩散模型训练不能满足专用任务。
发明内容
本发明实施例是提供一种扩散模型训练方法、装置、电子设备以及计算机可读存储介质,以解决扩散模型训练中模型训练在特定任务图像生成质量不佳的问题。
本发明实施例公开了一种扩散模型训练方法,包括:
生成通用扩散模型;
获取包含风格类型信息的专业任务数据集;
基于所述专业任务数据集训练所述通用扩散模型,生成目标扩散模型。
可选地,所述通用扩散模型具有对应的第一时间步,所述生成通用扩散模型的步骤包括:
获取全量数据集;
从所述全量数据集中确定出清洗数据集;
确定针对所述第一时间步的目标函数;所述目标函数为由伯努利分布和二项式分布组成的共轭先验分布密度函数;所述共轭先验分布密度函数包括第一目标参数和第二目标参数;
通过控制所述第一目标参数和所述第二目标参数确定针对所述通用扩散模型的损失函数;
将所述清洗数据集确定为训练样本,基于所述损失函数生成通用扩散模型。
可选地,所述共轭先验分布密度函数为:
其中,α为所述第一目标参数,β为所述第二目标参数,为伽马函数。
可选地,所述通用扩散模型设置有对应的门控器,所述基于所述专业任务数据集训练所述通用扩散模型,生成目标扩散模型的步骤包括:
在所述通用扩散模型中添加针对所述风格类型信息的专家层独立模型;
采用所述门控器基于所述风格类型信息,确定针对所述专家层独立模型的专家层独立模型权重;
采用所述门控器向所述专家层独立模型输入所述专业任务数据集,生成预测结果;
采用所述门控器基于所述专家层独立模型权重和所述预测结果,生成目标扩散模型。
可选地,所述专家层独立模型通过如下公式表达:
其中,G(x)i为第i个层独立模型的权重,Ei(x)为i个专家的输出,x为前一层网络的输出。
可选地,还包括:
将所述目标扩散模型拆分为目标扩散子模型;所述目标扩散子模型与所述风格类型信息具有关联关系;
采用所述通用扩散模型和所述目标扩散子模型构建模型库;所述通用扩散模型为针对目标风格类型信息为真实风格类型的扩散模型,所述目标扩散子模型为针对目标风格类型信息为真实风格类型的扩散模型。
可选地,还包括:
响应于接收到由用户发送的文本信息和针对所述文本信息的目标风格类型信息,且所述目标风格类型信息为真实风格类型时,调用针对所述真实风格类型的所述通用扩散模型,生成针对所述文本信息的目标图片。
可选地,还包括:
响应于接收到由用户发送的文本信息和针对所述文本信息的目标风格类型信息,且所述目标风格类型信息为非真实风格类型时,调用针对所述非真实风格类型的所述目标扩散子模型,生成针对所述文本信息的目标图片。
可选地,所述非真实风格类型至少包括油画风格类型、动漫风格类型和插画风格类型。
可选地,所述目标扩散模型具有对应的第二时间步,所述将所述目标扩散模型拆分为目标扩散子模型的步骤包括:
确定初始学生模型;
确定针对所述初始学生模型的第一可学习参数,所述第一可学习参数与所述第二时间步的输出具有关联关系;
基于离散第二时间步,将所述初始学生模型转化为目标扩散子模型。
可选地,在所述确定针对所述初始学生模型的第一可学习参数的步骤之前,还包括:
确定合并指导权重;
基于所述合并指导权重生成权重条件模型,并将所述权重条件模型合并至所述初始学生模型。
可选地,在所述确定针对所述初始学生模型的第一可学习参数的步骤之前,还包括:
在所述权重条件模型中添加傅里叶函数。
可选地,在所述确定针对所述初始学生模型的第一可学习参数的步骤之前,还包括:
初始化所述初始学生模型,以控制所述初始学生模型的模型参数与所述目标扩散模型的模型参数相同。
可选地,所述从所述全量数据集中确定出清洗数据集的步骤包括:
确定所述全量数据集中的第一目标数据和其他数据;
采用所述第一目标数据的特征值与其他数据的特征值确定所述第一目标数据和所述其他数据的相似度,并在所述相似度过高时,删除所述第一目标数据生成第一数据子集,并采用所述第一数据子集确定为清洗数据集。
可选地,还包括:
确定所述第一数据子集中的第二目标数据;所述第二目标数据为存在缺失值的数据;
删除所述第二目标数据生成第二数据子集,并采用所述第二数据子集确定为清洗数据集。
可选地,还包括:
确定所述第二数据子集中的第三目标数据;所述第三目标数据为存在异常值的数据;
删除所述第三目标数据生成第三数据子集,并采用所述第三数据子集确定为清洗数据集。
可选地,还包括:
确定所述第三数据子集中的第四目标数据;所述第四目标数据为存在噪声数据;
删除所述第四目标数据生成第四数据子集,并采用所述第四数据子集确定为清洗数据集。
可选地,还包括:
对所述第四数据子集执行归一化操作,并将经执行所述归一化操作的第四数据子集确定为清洗数据集;所述归一化操作为均值归一化,或,方差归一化,或,阈值归一化。
可选地,还包括:
对经执行所述归一化操作的第四数据子集执行抽样操作,生成清洗数据集
本发明实施例还公开了一种扩散模型训练装置,包括:
通用扩散模型生成模块,用于生成通用扩散模型;
专业任务数据集获取模块,用于获取包含风格类型信息的专业任务数据集;
目标扩散模型生成模块,用于基于所述专业任务数据集训练所述通用扩散模型,生成目标扩散模型。
本发明实施例还公开了一种电子设备,包括处理器、通信接口、存储器和通信总线,其中,所述处理器、所述通信接口以及所述存储器通过所述通信总线完成相互间的通信;
所述存储器,用于存放计算机程序;
所述处理器,用于执行存储器上所存放的程序时,实现如本发明实施例所述的方法。
本发明实施例还公开了一种计算机可读存储介质,其上存储有指令,当由一个或多个处理器执行时,使得所述处理器执行如本发明实施例所述的方法。
本发明实施例包括以下优点:
本发明实施例,通过生成通用扩散模型;获取包含风格类型信息的专业任务数据集;基于所述专业任务数据集训练所述通用扩散模型,生成目标扩散模型,实现了分阶段对通用扩散模型和针对不同风格类型的目标模型进行训练,从而实现了在扩散模型训练中,使模型能够在特定任务图像上进行训练,从而提升了模型生成特定风格图像的准确性。
附图说明
图1是本发明实施例中提供的一种扩散模型训练方法的步骤流程图;
图2是本发明实施例中提供的一种针对模型库数据交互示意图;
图3是本发明实施例中提供的一种扩散模型训练方法流程示意图;
图4是本发明实施例中提供的一种针对学生模型的蒸馏流程示意图;
图5是本发明实施例中提供的一种扩散模型训练装置的结构框图;
图6是本发明实施例中提供的一种电子设备的硬件结构框图;
图7是本发明实施例中提供的一种计算机可读介质的示意图。
具体实施方式
为使本发明的上述目的、特征和优点能够更加明显易懂,下面结合附图和具体实施方式对本发明作进一步详细的说明。
为使本领域技术人员更好地理解本发明实施例,以下对本发明实施例所涉及的部分名词进行说明。
L2损失:也称为均方误差(Mean Squared Error,MSE)损失,是深度学习中常用的损失函数之一。它是通过计算预测值与真实值之间的平方差来表示预测值与真实值之间的距离。
Parti:一种基于自回归思想的文生图模型。
VQ-GAN:是一种生成式对抗网络(GAN)架构,它使用向量量化(VectorQuantization,VQ)实现了一种离散的、神经网络可训练的深度学习模型。
VQ-GAN的基本思路是通过将图像像素映射到一组离散的编码向量来实现图像压缩和处理,即通过将实数值的向量空间划分为一组离散的向量空间,来表示原始像素输入。这些编码向量形成了一种更抽象的表示形式,具有更强的表示能力。然后,使用GAN框架来学习如何在给定的编码向量中生成具有高保真度的图像。
Transformer:是一种用于处理序列数据的神经网络结构,由Google提出,并在自然语言处理(NLP)领域中被广泛使用。与传统的循环神经网络(RNN)不同,Transformer使用注意力机制(Self-Attention)来捕捉序列中不同位置之间的关系,具有高度的并行性和可扩展性,能够有效处理长序列数据。
Transformer是由编码器(Encoder)和解码器(Decoder)两部分组成的,通常用于序列到序列(Sequence-to-Sequence,Seq2Seq)模型中,如机器翻译、对话系统等。编码器将输入序列转换为一系列高维的特征表示,解码器则使用这些特征表示生成目标序列。在编码器和解码器之间,使用了更高抽象的注意力层,使得模型能够将关注点放在输入序列中不同的部分上。
stable diffusion:一种基于因扩散模型的开源文生图工具。
扩散模型是目前一种主流的图像生成模型,能够生成逼真细腻的图像。该类模型一般用于文生图场景,即给定一段文本,模型生成一张或者多张能够描述该文本的图像。其训练使用的数据被称为图文对,即一张图片与其对应的文本描述构成一条数据。扩散模型包含两个过程,前向过程与反向过程。前向过程将一张图像逐步添加微小的噪声,直到这张图片变成随机噪声。这个过程是图像破坏的阶段。反向过程是以一个随机噪声作为开始,逐步进行图像降噪,最终将随机噪声恢复成其对应的原始图片的过程。这个过程是一个图像降噪的过程。
在文生图领域,与扩散模型最接近的技术是自回归方法。典型的方法如Parti,其包括两个组成部分,VQ-GAN和Transformer编码器/解码器。VQ-GAN用于实现对于图像转变为向量,再将向量转变为图像。Transformer的编码器/解码器用于学习将文本描述的编码转换为其对应的图像编码。得到此图像编码后,使用VQ-GAN将其转换为图像。自回归方法的主要问题在于,将文本描述转向量变为对应的图像向量这个过程,由于信息的损失,导致生成图像的质量相比于扩散模型方法要差一些。
隐扩散模型(Latent diffusion model)以stable diffusion 为代表的开源文生图模型使用隐扩散模型作为其技术原理。其原理与扩散模型相似,区别在于其扩散过程不是在图像上开展的,而是在下采样后的特征图上开展的。这样做的好处是计算速度更快,所需的算力资源更少,缺点是下采样造成了信息损失,使得图像生成质量同样低于普通的扩散模型。
本发明实施例提出一种扩散模型训练方法。涵盖多阶段训练与模型推理流程。可以解决或部分解决扩散模型训练中模型训练存在性能不充分及在特定任务图像生成质量不佳的技术问题,解决目前扩散模型单纯在大规模数据集训练不能满足专用任务的缺点。
模型训练性能不充分:模型在预训练阶段使用大规模数据集上训练后,获得了图像生成的通识性能力。由于预训练数据分布的不合理,可能导致模型对于某些方面生成能力过强,某些方面生成能力过弱。导致模型训练实际上处于不充分的状态。
特定任务生成质量不佳:模型在通用数据集训练完成后,模型具备在特定任务上的生成能力,如对特定风格,特定目标的生成比较擅长。但是直接训练完的模型往往在特定任务上的生成能力不佳。
参照图1,示出了本发明实施例中提供的一种扩散模型训练方法的步骤流程图,具体可以包括如下步骤:
步骤101,生成通用扩散模型;
步骤102,获取包含风格类型信息的专业任务数据集;
步骤103,基于所述专业任务数据集训练所述通用扩散模型,生成目标扩散模型。
在具体实现中,本发明实施例可以通过多阶方式生成通用扩散模型,示例性地,在一阶段训练流程中,本发明实施例可以使用全量数据集进行短时间的训练,这部分训练的特点是在一个体量大但是质量差的图文对数据集上进行初步的预训练,以获得一个能够具备多样性生成能力的初始权重。全量数据集是指收集的海量图文对数据未对其进行有针对性的清洗。
在第二阶段训练流程中,本发明实施例可以使用清洗数据集,对扩散模型进行长时间的充分的训练。这部分训练的特点是,模型在第一阶段的初始权重下,在一个数据量相对少,但是质量较高的数据集上高效的学习图文对对应关系、高质量图像的降噪过程,引入高质量文本参与模型训练。在此数据集上进行尽可能长的训练有助于获得一个生成质量较好的扩散模型,即,经过第二阶段的模型的训练,可以生成一个生成质量优异的通用扩散模型。
本发明实施例还可以获取包含风格类型信息的专业任务数据集,例如,包含油画或漫画或插画风格的图像集。
在获取到专业任务数据集后,可以基于专业任务数据集训练通用扩散模型,生成目标扩散模型。
示例性地,本发明实施例可以使用第二阶段训练得到的通用扩散模型基于包含多种不同风格类型的专业任务数据集在各专业领域进行调参训练,以具备各种专业所需的特殊能力。在第三阶段,各专业任务仅需要准备数千张图文对,即可完成针对通用扩散模型的调参训练,使得模型具备解决此专业任务的能力。典型的专业任务可以不限于包括:动漫图像生成、西方油画图像生成、中国水墨画图像生成、特定艺术家风格画作生成、插画生成等等。经过此阶段的训练,通用扩散(教师)模型可以从通用模型扩展为多个专用(学生)模型。
本发明实施例,通过生成通用扩散模型;获取包含风格类型信息的专业任务数据集;基于所述专业任务数据集训练所述通用扩散模型,生成目标扩散模型,实现了分阶段对通用扩散模型和针对不同风格类型的目标模型进行训练,从而实现了在扩散模型训练中,使模型能够在特定任务图像上进行训练,从而提升了模型生成特定风格图像的准确性。
在上述实施例的基础上,提出了上述实施例的变型实施例,在此需要说明的是,为了使描述简要,在变型实施例中仅描述与上述实施例的不同之处。
在本发明的一个可选地实施例中,所述通用扩散模型具有对应的第一时间步,所述生成通用扩散模型的步骤包括:
获取全量数据集;
从所述全量数据集中确定出清洗数据集;
确定针对所述第一时间步的目标函数;所述目标函数为由伯努利分布和二项式分布组成的共轭先验分布密度函数;所述共轭先验分布密度函数包括第一目标参数和第二目标参数;
通过控制所述第一目标参数和所述第二目标参数确定针对所述通用扩散模型的损失函数;
将所述清洗数据集确定为训练样本,基于所述损失函数生成通用扩散模型。
本发明实施例可以先获取全量数据集,然后从所述全量数据集中确定出清洗数据集。
示例性地,在一阶段训练流程中,本发明实施例可以使用全量数据集进行短时间的训练,这部分训练的特点是在一个体量大但是质量差的图文对数据集上进行初步的预训练,以获得一个能够具备多样性生成能力的初始权重。全量数据集是指收集的海量图文对数据未对其进行有针对性的清洗。
在第二阶段训练流程中,本发明实施例可以使用清洗数据集,对扩散模型进行长时间的充分的训练。这部分训练的特点是,模型在第一阶段的初始权重下,在一个数据量相对少,但是质量较高的数据集上高效的学习图文对对应关系、高质量图像的降噪过程,引入高质量文本参与模型训练。在此数据集上进行尽可能长的训练有助于获得一个生成质量较好的扩散模型,即,经过第二阶段的模型的训练,可以生成一个生成质量优异的通用扩散模型。
在实际应用中,扩散模型是目前一种主流的图像生成模型,能够生成逼真细腻的图像。该类模型一般用于文生图场景,即给定一段文本,模型生成一张或者多张能够描述该文本的图像。其训练使用的数据被称为图文对,即一张图片与其对应的文本描述构成一条数据。扩散模型包含两个过程,前向过程与反向过程。前向过程将一张图像逐步添加微小的噪声,直到这张图片变成随机噪声。这个过程是图像破坏的阶段。反向过程是以一个随机噪声作为开始,逐步进行图像降噪,最终将随机噪声恢复成其对应的原始图片的过程。这个过程是一个图像降噪的过程。
具体而言,将前向过程划分为T个时间步,如T可以为1000,每两个相邻的时间步之间的图像可以根据添加一个高斯噪声得到。由此,任意两个相邻的时间步的图像的关系是可以确定的。当T足够大时,整个前向过程是一个马尔科夫链的计算过程。再根据重参数的技巧,任意时刻的图像可以由输入图像根据一个高斯分布获得。与之类似,前向过程的任意两个相邻时间步的降噪也可以看作是一个高斯分布的计算,这个过程可以由一个网络模型进行学习。在扩散模型中,一般通过一个网络模型预测反向阶段的两个相邻时间步的噪声。由于对于文生图任务,数据的组成是图像及其对应的文本描述,随意对于任意的时刻t所对应的图像xt,模型以xt,t以及文本描述c作为输入,输出相邻的xt-1应该被降噪的噪声。扩散模型选用的主流网络为U-Net网络模型,模型的损失函数被简化为对于噪声的L2损失。
相关技术对于扩散模型的训练方法包括:
对于原始数据x0,总共包含T步的扩散过程的每一步都是对上一步得到的数据Xt−1按如下公式增加高斯噪音:
由上述公式可知,基于上一时间步Xt−1得到下一时间步 Xt,表示从高斯噪音是基于高斯分布中采样获得。
高斯噪音为每一步所采用的方差可以为
反向过程中,由t时刻的样本Xt及原始图像X0计算得到Xt-1的过程同样为从一个高斯分布中采样的过程可以用公式表达。
高斯分布对应的方差与均值如下:
其中,
方差是一个定量,而均值是一个依赖X0和Xt的函数
扩散模型训练过程中,可以先通过正向过程获得Xt,输入模型据此求得X0,为了简化训练中的损失函数,模型损失函数修改为对于噪声进行L2损失函数。
其中,为正向过程中高斯分布中随机采样的噪声,/>为模型学习的噪声,损失函数为对二者进行L2损失约束。
相关技术对于扩散模型训练损失函数的计算,采用的方式为对所有的时间步进行均匀采样,计算损失函数,即从0时刻到T时刻对于损失函数的贡献是均等的。相关研究发现,在时间步接近0的阶段中,由于噪声较小,模型学习的是简单的降噪任务;在时间步接近T的阶段中,由于噪声较大,模型能学习到的是一些较为粗糙的信息重建,如边缘信息的重建;在时间步距离0和T都相对较远的中间时间段,此时噪声中等,学习到的是对于具体内容的重建。所以我们提出一种在训练阶段对于时刻t的重要性采样策略,以使得中间时刻更多的参与训练,接近0的时刻更少的参与训练。
在具体实现中,第一时间步为针对通用扩散模型的时间步,为了更加灵活的对第一时间步进行采样,针对第二阶段的训练,本发明实施例可以引入贝塔分布控制对于不同时间阶段的重要性采样策略。贝塔分布(Beta Distribution) 是一个作为伯努利分布和二项式分布的共轭先验分布的密度函数,是指一组定义在(0,1) 区间的连续概率分布;有两个参数,α和β,要求二者大于0;本发明实施例可以采用该函数作为目标函数。
在贝塔分布的概率密度可以通过如下公式表达:
其中,可以为伽马函数,α为第一目标参数,β为第二目标参数。
将清洗数据集确定为训练样本,通过控制α和β的参数选择,可以灵活的控制损失函数中使用哪些采样时间t进行计算,并以该损失函数对第一时间步进行采样,以生成目标模型。
示例性地,对于通用扩散模型的生成训练过程,可以设计不同的时间步采样策略调整训练的性能,这一过程可通过修改贝塔分布的α和β的参数选择实现。
如希望扩散阶段的初始时刻参与训练更多,可选α=2和β=8,此时,初始时间段参与训练的比例较大,中间时间段参与训练比例较小,尾段时间段基本不参与训练。
如希望扩散阶段的中间时刻参与训练更多,可选α=5和β=5,此时,中间时间段参与训练比例较大,初始时间段和尾段时间段参与训练较少。
如希望扩散阶段的尾段时刻参与训练更多,可选在α=8和β=2,此时,初始时间段基本不参与训练,中间时间段参与训练比例较小,尾段时间段主要参与训练。
当然,上述例子仅为示例,本领域技术人员可以根据实际情况采去任意值作为第一目标参数和第二目标参数以确定损失函数,对此,本发明实施例不作限定。
在本发明的一个可选地实施例中,所述通用扩散模型设置有对应的门控器,所述基于所述专业任务数据集训练所述通用扩散模型,生成目标扩散模型的步骤包括:
在所述通用扩散模型中添加针对所述风格类型信息的专家层独立模型;
采用所述门控器基于所述风格类型信息,确定针对所述专家层独立模型的专家层独立模型权重;
采用所述门控器向所述专家层独立模型输入所述专业任务数据集,生成预测结果;
采用所述门控器基于所述专家层独立模型权重和所述预测结果,生成目标扩散模型。
专家层独立模型又称为混合专家(Mixture-of-Experts),是一种机器学习模型,用于解决复杂的问题和任务。它通过将多个专家模型组合在一起,以获得更强大的整体性能。
混合专家模型包含两个主要组件:专家和门控器。
专家(Experts): 每个专家是一个独立的子模型,它们被设计为在特定的输入领域或子任务中表现出色。每个专家都有自己的参数和学习能力,可以根据特定的数据子集提供有针对性的预测。例如,在语音识别任务中,一个专家可能擅长识别清晰的语音,而另一个专家可能擅长处理背景噪声。
门控器(Gatekeeper): 门控器是一个选择器,根据输入样本的特性决定哪个专家应该负责给出最终的预测结果。门控器可以是一个简单的线性函数,也可以是一个更复杂的神经网络。门控器的作用是根据输入的上下文信息,调整每个专家模型的权重,使得最适合当前输入样本的专家获得更多的贡献或权重。
混合专家模型的工作过程如下:
输入样本通过门控器,门控器根据输入的特征确定各个专家模型的权重。
输入样本同时送入多个专家模型中,并根据各自的参数得到预测结果。
门控器根据专家的预测结果和权重,进行适当的整合或选择,得到最终的输出结果。
文生图任务的一个重要方面是需要生成多种风格的图像。如自然风格/动漫风格/现实主义/印象主义等。为实现高质量的多种风格图像生成的,本发明实施例的采取的扩散模型训练策略可以为,为通用扩散模型添加针对风格类型信息的专家层独立模型MOE,即,结构为是一组卷积模块,每层MOE对应一种单独的图像风格,用于捕获各种风格的特征,模型中包含一个门控器网络来决定激活哪个或者哪些专家层,具体地,可以采用门控器基于风格类型信息,确定针对专家层独立模型的专家层独立模型权重,同时,可以采用门控器向专家层独立模型输入所述专业任务数据集,生成预测结果,并采用门控器基于专家层独立模型权重和预测结果,生成目标扩散模型。
生成过程可以通过如下公式表达:
G(x)i为第i个专家的权重,Ei(x)为第i个专家的输出,x为前一层网络的输出。
门控器网络的操作可以是,对于输入x,先对其施加一个全连接层W,再对其施加一个归一化指数函数Softmax,将模型输出的实数域映射到[0, 1]区间。
在具体实现中,清洗数据集是指对于海量图文对数据集进行有针对性的清洗,以获得一个数量少但是质量高的数据集。在进行大规模模型的训练时,数据的质量和准确性对于模型的性能至关重要。因此,在进行大规模训练数据的训练之前,可以对数据进行清理和预处理,以保证模型的准确性和稳定性。
在本发明的一个可选地实施例中,可以确定所述全量数据集中的第一目标数据和其他数据;
采用所述第一目标数据的特征值与其他数据的特征值确定所述第一目标数据和所述其他数据的相似度,并在所述相似度过高时,删除所述第一目标数据生成第一数据子集,并采用所述第一数据子集确定为清洗数据集。
和/或,
确定所述第一数据子集中的第二目标数据;所述第二目标数据为存在缺失值的数据;
删除所述第二目标数据生成第二数据子集,并采用所述第二数据子集确定为清洗数据集。
和/或,
确定所述第二数据子集中的第三目标数据;所述第三目标数据为存在异常值的数据;
删除所述第三目标数据生成第三数据子集,并采用所述第三数据子集确定为清洗数据集。
和/或,
确定所述第三数据子集中的第四目标数据;所述第四目标数据为存在噪声数据;
删除所述第四目标数据生成第四数据子集,并采用所述第四数据子集确定为清洗数据集。
和/或,
对所述第四数据子集执行归一化操作,并将经执行所述归一化操作的第四数据子集确定为清洗数据集;所述归一化操作为均值归一化,或,方差归一化,或,阈值归一化。
和/或,
对经执行所述归一化操作的第四数据子集执行抽样操作,生成清洗数据集。
示例性地,可以通过如下方式清晰全量数据集中,以确定出清洗数据集。
一、数据去重
在进行大规模训练数据清理时,可以先进行数据去重操作。由于大规模数据集通常包含大量重复的样本,这些重复的样本会对模型的训练造成干扰,降低模型的性能。因此,通过对数据集进行去重操作,可以减少冗余数据,提高训练效果。
数据去重可以通过比较样本之间的特征值或者使用哈希算法来实现。对于特征值比较的方法,可以通过计算样本之间的相似度,将相似度高于一定阈值的样本视为重复样本,并删除其中一个。而哈希算法则可以将样本映射到一个唯一的哈希值,通过判断哈希值是否重复来去除重复数据。
二、数据清洗
除了去重操作,数据清洗也是大规模训练数据清理的重要步骤之一。在数据采集过程中,可能会存在一些错误、噪声或者异常值,这些数据可能会对模型的训练造成干扰,降低模型的准确性。因此,需要对数据进行清洗,去除这些错误数据。
数据清洗可以通过以下几种方法来实现:
1. 缺失值处理:对于存在缺失值的样本,可以选择删除或者填充缺失值。删除缺失值可能会导致数据量的减少,但可以避免对模型的干扰。而填充缺失值可以通过均值、中位数等方法进行。
2. 异常值处理:对于存在异常值的样本,可以选择删除或者修正异常值。删除异常值可能会导致数据量的减少,但可以避免对模型的干扰。修正异常值可以通过替换为均值、中位数等方法进行。
3. 噪声数据处理:对于存在噪声数据的样本,可以选择删除或者平滑噪声数据。删除噪声数据可能会导致数据量的减少,但可以避免对模型的干扰。平滑噪声数据可以通过滤波等方法进行。
三、数据标准化
在进行大规模模型的训练时,数据的尺度和分布可能会对模型的训练造成影响。因此,需要对数据进行标准化,使得数据具有相同的尺度和分布,以提高模型的训练效果。
数据标准化可以通过以下几种方法来实现:
1. 均值归一化:将数据减去均值,使得数据的均值为0。
2. 方差归一化:将数据除以标准差,使得数据的方差为1。
3. 最大最小值归一化:将数据缩放到0到1的范围内。
四、数据抽样
在进行大规模模型的训练时,由于数据量庞大,可能会导致训练时间过长或者资源消耗过大。因此,可以通过数据抽样的方法,从大规模数据集中随机选择一部分样本进行训练,以减少训练时间和资源消耗。
针对数据的抽样操作可以通过以下几种方法来实现:
1. 随机抽样:从数据集中随机选择一部分样本进行训练。
2. 分层抽样:根据样本的类别进行分层抽样,保证每个类别的样本在抽样后的数据集中的比例与原始数据集中的比例一致。
3. 过采样和欠采样:对于数据不平衡的情况,可以通过过采样和欠采样的方法来调整样本的比例,使得各个类别的样本数量平衡。
在具体实现中,由于扩散模型通常模型参数量较大,使用原始模型进行推理对于存储和算力的需求较大。为了提升模型在推理阶段的部署能力,在本发明的一个可选地实施例中,还可以将所述目标扩散模型拆分为目标扩散子模型;所述目标扩散子模型与所述风格类型信息具有关联关系。
将目标扩散模型拆分为目标扩散子模型,可以是将通用扩散模型中的多个专家层独立模型蒸馏为多个一一对应的目标扩散子模型。
示例性地,可以使用模型蒸馏的方式将目标扩散模型看作教师模型,通过对目标扩散模型进行蒸馏得到的小模型看作学生模型。期望学生模型与教师模型的性能损失尽可能小,同时获得的学生模型的模型参数量较小,推理速度更快,所需要的计算资源更少。
可选地,所述目标扩散模型具有对应的第二时间步,所述将所述目标扩散模型拆分为目标扩散子模型的步骤包括:
确定初始学生模型;
确定针对所述初始学生模型的第一可学习参数,所述第一可学习参数与所述第二时间步的输出具有关联关系;
基于离散第二时间步,将所述初始学生模型转化为目标扩散子模型。
可选地,在所述确定针对所述初始学生模型的第一可学习参数的步骤之前,还包括:
确定合并指导权重;
基于所述合并指导权重生成权重条件模型,并将所述权重条件模型合并至所述初始学生模型。
可选地,在所述确定针对所述初始学生模型的第一可学习参数的步骤之前,还包括:
在所述权重条件模型中添加傅里叶函数。
可选地,在所述确定针对所述初始学生模型的第一可学习参数的步骤之前,还包括:
初始化所述初始学生模型,以控制所述初始学生模型的模型参数与所述目标扩散模型的模型参数相同。
可选地,所述从所述全量数据集中确定出清洗数据集的步骤包括:
确定所述全量数据集中的第一目标数据和其他数据;
采用所述第一目标数据的特征值与其他数据的特征值确定所述第一目标数据和所述其他数据的相似度,并在所述相似度过高时,删除所述第一目标数据生成第一数据子集,并采用所述第一数据子集确定为清洗数据集。
示例性地,第二时间步为针对目标扩散模型的时间步,将目标扩散模型作为教师模型,之后本文分可以分两步完成。
第一步,引入一个连续时间学生模型,该模型具有可学习参数η_1,以匹配教师模型在任意第二时间步 t∈[0,1] 处的输出。给定一个优化范围 [w_min, w_max],对学生模型进行优化。
其中,为了合并指导权重 w,本文可以引入了一个 w-条件模型,其中 w 作为学生模型的输入。为了更好地捕捉特征,本文还对w应用傅里叶嵌入。此外,由于初始化在模型性能中起着关键作用,因此本文初始化学生模型的参数与教师模型相同。
第二步,将离散第二时间步(discrete time-step)考虑在内,并逐步将第一步中的蒸馏模型转化为步数较短的学生模型,其可学习参数为η_2,每次采样步数减半。设N为采样步数,给定 w ~ U[w_min, w_max] 和 t∈{1,…, N},然后根据 Salimans&Ho 等人提出的方法训练学生模型。在将教师模型中的 2N 步蒸馏为学生模型中的 N 步之后,之后使用N 步学生模型作为新的教师模型,这个过程不断重复,直到将教师模型蒸馏为 N/2 步学生模型。
其中,N步可确定性和随机采样:一旦模型训练完成,给定一个指定的w∈ [w_min,w_max],然后使用DDIM(denoising diffusion implicit model,去噪扩散隐式模型)更新规则执行采样。
可选地,还包括:
采用所述通用扩散模型和所述目标扩散子模型构建模型库,所述通用扩散模型为针对目标风格类型信息为真实风格类型的扩散模型,所述目标扩散子模型为针对目标风格类型信息为真实风格类型的扩散模型。
可选地,所述模型库用于,在接收到由用户发送的文本信息和针对所述文本信息的目标风格类型信息,且所述目标风格类型信息为真实风格类型时,调用针对所述真实风格类型的所述通用扩散模型,生成针对所述文本信息的目标图片。
可选地,所述模型库用于,在接收到由用户发送的文本信息和针对所述文本信息的目标风格类型信息,且所述目标风格类型信息为非真实风格类型时,调用针对所述非真实风格类型的所述目标扩散子模型,生成针对所述文本信息的目标图片。
参考图2,图2是本发明实施例中提供的一种针对模型库数据交互示意图;
示例性地,因为目标扩散子模型可以与风格类型信息具有关联关系,采用通用扩散模型和针对不同风格类型信息的目标扩散子模型构建模型库,当用户在用户界面输入文本“一只奔跑的小狗”,并选定风格类型,则模型库可以基于用户选择的风格类型调用对应的目标扩散子模型,以基于文本“一只奔跑的小狗”,生成图片,例如,当用户选定风格类型为“真实风格”则可以采用通用扩散模型生成模拟真实风格的在奔跑中的小狗的图片A,若用户选定油画风格,则可以调用油画风格模型生成油画风格的在奔跑中的小狗的图片B等等。
可选地,所述非真实风格类型至少包括油画风格类型、动漫风格类型和插画风格类型。
为使本领域技术人员更好地理解本发明实施例,以下采用一完整示例对本发明实施例进行说明。
参考图3,图3是本发明实施例中提供的一种扩散模型训练装置的结构框图。
在一阶段训练流程中,可以使用全量数据集进行短时间的训练,这部分训练的特点是在一个体量大但是质量差的图文对数据集上进行初步的预训练,以获得一个能够具备多样性生成能力的初始权重。全量数据集是指收集的海量图文对数据未对其进行有针对性的清洗。
在第二阶段训练流程中,可以使用清洗数据集,对扩散模型进行长时间的充分的训练。这部分训练的特点是,模型在第一阶段的初始权重下,在一个数据量相对少,但是质量较高的数据集上高效的学习图文对对应关系、高质量图像的降噪过程,引入高质量文本参与模型训练。
在此数据集上进行尽可能长的训练有助于获得一个生成质量较好的扩散模型。经过第二阶段的模型的训练,我们得到了一个生成质量优异的通用扩散模型,在具备一定的多样性生成能力的同时,由于数据集质量较好,模型抑制了生成敏感内容。该阶段训练得到模型可用于通用场景的图像生成,生成的图像风格以真实场景的图像为主,具备一定的多样风格图像生成能力,能够生成多样的目标,生成复杂关系的图像内容。清洗数据集是指对于海量图文对数据集进行有针对性的清洗,以获得一个数量少但是质量高的数据集。
在第三阶段训练流程中,可以使用第二阶段训练得到的通用模型在各专业领域进行调参训练,以具备各种专业所需的特殊能力。此阶段,各专业任务仅需要准备数千张图文对即可完成对于第二阶段模型的调参训练,使得模型具备解决此专业任务的能力。典型的专业任务包括:动漫图像生成、西方油画图像生成、中国水墨画图像生成、特定艺术家风格画作生成、插画生成等等。经过此阶段的训练,模型可以从通用模型扩展为多个专用模型。
参考图4,图4是本发明实施例中提供的一种针对学生模型的蒸馏流程示意图;
扩散模型通常模型参数量较大,使用原始模型进行推理对于存储和算力的需求较大。为了实现模型推理阶段的部署,可以使用模型蒸馏的方式将原始扩散模型看作教师模型,蒸馏得到的小模型看作学生模型。期望学生模型与教师模型的性能损失尽可能小,同时获得的学生模型的模型参数量较小,推理速度更快,所需要的计算资源更小。
真实的推理场景中涉及多种图像风格、任务的生成,需要多个模型的参与。经过模型蒸馏后,可形成多个专业化的小模型。一个小模型对应一种生成能力,根据任务的下达,确定需要使用的专业小模型进行有针对性的生成。
通过第一阶段在全量数据上的小时长训练、第二阶段在清洗数据上的大时长训练、第三阶段在专业数据集上的小时长非全量参数更新训练。使用蒸馏学习获得扩散模型对应的小模型,以节省存储及计算资源。根据用户选择的风格,加载对应的模型,生成最匹配用户描述的图像,解决了目前扩散模型单纯在大规模数据集训练不能满足专用任务的缺点。
需要说明的是,对于方法实施例,为了简单描述,故将其都表述为一系列的动作组合,但是本领域技术人员应该知悉,本发明实施例并不受所描述的动作顺序的限制,因为依据本发明实施例,某些步骤可以采用其他顺序或者同时进行。其次,本领域技术人员也应该知悉,说明书中所描述的实施例均属于优选实施例,所涉及的动作并不一定是本发明实施例所必须的。
参照图5,示出了本发明实施例中提供的一种扩散模型训练装置的结构框图,具体可以包括如下模块:
通用扩散模型生成模块501,用于生成通用扩散模型;
专业任务数据集获取模块502,用于获取包含风格类型信息的专业任务数据集;
目标扩散模型生成模块503,用于基于所述专业任务数据集训练所述通用扩散模型,生成目标扩散模型。
对于装置实施例而言,由于其与方法实施例基本相似,所以描述的比较简单,相关之处参见方法实施例的部分说明即可。
另外,本发明实施例还提供了一种电子设备,包括:处理器,存储器,存储在存储器上并可在处理器上运行的计算机程序,该计算机程序被处理器执行时实现上述扩散模型训练方法实施例的各个过程,且能达到相同的技术效果,为避免重复,这里不再赘述。
本发明实施例还提供了一种计算机可读存储介质,计算机可读存储介质上存储有计算机程序,计算机程序被处理器执行时实现上述扩散模型训练方法实施例的各个过程,且能达到相同的技术效果,为避免重复,这里不再赘述。其中,所述的计算机可读存储介质,如只读存储器(Read-Only Memory,简称ROM)、随机存取存储器(Random Access Memory,简称RAM)、磁碟或者光盘等。
图6为实现本发明各个实施例的一种电子设备的硬件结构示意图。
该电子设备600包括但不限于:射频单元601、网络模块602、音频输出单元603、输入单元604、传感器605、显示单元606、用户输入单元607、接口单元608、存储器609、处理器610、以及电源611等部件。本领域技术人员可以理解,图6中示出的电子设备结构并不构成对电子设备的限定,电子设备可以包括比图示更多或更少的部件,或者组合某些部件,或者不同的部件布置。在本发明实施例中,电子设备包括但不限于手机、平板电脑、笔记本电脑、掌上电脑、车载终端、可穿戴设备、以及计步器等。
应理解的是,本发明实施例中,射频单元601可用于收发信息或通话过程中,信号的接收和发送,具体的,将来自基站的下行数据接收后,给处理器610处理;另外,将上行的数据发送给基站。通常,射频单元601包括但不限于天线、至少一个放大器、收发信机、耦合器、低噪声放大器、双工器等。此外,射频单元601还可以通过无线通信系统与网络和其他设备通信。
电子设备通过网络模块602为用户提供了无线的宽带互联网访问,如帮助用户收发电子邮件、浏览网页和访问流式媒体等。
音频输出单元603可以将射频单元601或网络模块602接收的或者在存储器609中存储的音频数据转换成音频信号并且输出为声音。而且,音频输出单元603还可以提供与电子设备600执行的特定功能相关的音频输出(例如,呼叫信号接收声音、消息接收声音等等)。音频输出单元603包括扬声器、蜂鸣器以及受话器等。
输入单元604用于接收音频或视频信号。输入单元604可以包括图形处理器(Graphics Processing Unit,GPU)6041和麦克风6042,图形处理器6041对在视频捕获模式或图像捕获模式中由图像捕获装置(如摄像头)获得的静态图片或视频的图像数据进行处理。处理后的图像帧可以显示在显示单元606上。经图形处理器6041处理后的图像帧可以存储在存储器609(或其它存储介质)中或者经由射频单元601或网络模块602进行发送。麦克风6042可以接收声音,并且能够将这样的声音处理为音频数据。处理后的音频数据可以在电话通话模式的情况下转换为可经由射频单元601发送到移动通信基站的格式输出。
电子设备600还包括至少一种传感器605,比如光传感器、运动传感器以及其他传感器。具体地,光传感器包括环境光传感器及接近传感器,其中,环境光传感器可根据环境光线的明暗来调节显示面板6061的亮度,接近传感器可在电子设备600移动到耳边时,关闭显示面板6061和/或背光。作为运动传感器的一种,加速计传感器可检测各个方向上(一般为三轴)加速度的大小,静止时可检测出重力的大小及方向,可用于识别电子设备姿态(比如横竖屏切换、相关游戏、磁力计姿态校准)、振动识别相关功能(比如计步器、敲击)等;传感器605还可以包括指纹传感器、压力传感器、虹膜传感器、分子传感器、陀螺仪、气压计、湿度计、温度计、红外线传感器等,在此不再赘述。
显示单元606用于显示由用户输入的信息或提供给用户的信息。显示单元606可包括显示面板6061,可以采用液晶显示器(Liquid Crystal Display,LCD)、有机发光二极管(Organic Light-Emitting Diode, OLED)等形式来配置显示面板6061。
用户输入单元607可用于接收输入的数字或字符信息,以及产生与电子设备的用户设置以及功能控制有关的键信号输入。具体地,用户输入单元607包括触控面板6071以及其他输入设备6072。触控面板6071,也称为触摸屏,可收集用户在其上或附近的触摸操作(比如用户使用手指、触笔等任何适合的物体或附件在触控面板6071上或在触控面板6071附近的操作)。触控面板6071可包括触摸检测装置和触摸控制器两个部分。其中,触摸检测装置检测用户的触摸方位,并检测触摸操作带来的信号,将信号传送给触摸控制器;触摸控制器从触摸检测装置上接收触摸信息,并将它转换成触点坐标,再送给处理器610,接收处理器610发来的命令并加以执行。此外,可以采用电阻式、电容式、红外线以及表面声波等多种类型实现触控面板6071。除了触控面板6071,用户输入单元607还可以包括其他输入设备6072。具体地,其他输入设备6072可以包括但不限于物理键盘、功能键(比如音量控制按键、开关按键等)、轨迹球、鼠标、操作杆,在此不再赘述。
进一步的,触控面板6071可覆盖在显示面板6061上,当触控面板6071检测到在其上或附近的触摸操作后,传送给处理器610以确定触摸事件的类型,随后处理器610根据触摸事件的类型在显示面板6061上提供相应的视觉输出。虽然在图6中,触控面板6071与显示面板6061是作为两个独立的部件来实现电子设备的输入和输出功能,但是在某些实施例中,可以将触控面板6071与显示面板6061集成而实现电子设备的输入和输出功能,具体此处不做限定。
接口单元608为外部装置与电子设备600连接的接口。例如,外部装置可以包括有线或无线头戴式耳机端口、外部电源(或电池充电器)端口、有线或无线数据端口、存储卡端口、用于连接具有识别模块的装置的端口、音频输入/输出(I/O)端口、视频I/O端口、耳机端口等等。接口单元608可以用于接收来自外部装置的输入(例如,数据信息、电力等等)并且将接收到的输入传输到电子设备600内的一个或多个元件或者可以用于在电子设备600和外部装置之间传输数据。
存储器609可用于存储软件程序以及各种数据。存储器609可主要包括存储程序区和存储数据区,其中,存储程序区可存储操作系统、至少一个功能所需的应用程序(比如声音播放功能、图像播放功能等)等;存储数据区可存储根据手机的使用所创建的数据(比如音频数据、电话本等)等。此外,存储器609可以包括高速随机存取存储器,还可以包括非易失性存储器,例如至少一个磁盘存储器件、闪存器件、或其他易失性固态存储器件。
处理器610是电子设备的控制中心,利用各种接口和线路连接整个电子设备的各个部分,通过运行或执行存储在存储器609内的软件程序和/或模块,以及调用存储在存储器609内的数据,执行电子设备的各种功能和处理数据,从而对电子设备进行整体监控。处理器610可包括一个或多个处理单元;优选的,处理器610可集成应用处理器和调制解调处理器,其中,应用处理器主要处理操作系统、用户界面和应用程序等,调制解调处理器主要处理无线通信。可以理解的是,上述调制解调处理器也可以不集成到处理器610中。
电子设备600还可以包括给各个部件供电的电源611(比如电池),优选的,电源611可以通过电源管理系统与处理器610逻辑相连,从而通过电源管理系统实现管理充电、放电、以及功耗管理等功能。
另外,电子设备600包括一些未示出的功能模块,在此不再赘述。
需要说明的是,在本文中,术语“包括”、“包含”或者其任何其他变体意在涵盖非排他性的包含,从而使得包括一系列要素的过程、方法、物品或者装置不仅包括那些要素,而且还包括没有明确列出的其他要素,或者是还包括为这种过程、方法、物品或者装置所固有的要素。在没有更多限制的情况下,由语句“包括一个……”限定的要素,并不排除在包括该要素的过程、方法、物品或者装置中还存在另外的相同要素。
通过以上的实施方式的描述,本领域的技术人员可以清楚地了解到上述实施例方法可借助软件加必需的通用硬件平台的方式来实现,当然也可以通过硬件,但很多情况下前者是更佳的实施方式。基于这样的理解,本发明的技术方案本质上或者说对现有技术做出贡献的部分可以以软件产品的形式体现出来,该计算机软件产品存储在一个存储介质(如ROM/RAM、磁碟、光盘)中,包括若干指令用以使得一台终端(可以是手机,计算机,服务器,空调器,或者网络设备等)执行本发明各个实施例所述的方法。
如图7所示,在本发明提供的又一实施例中,还提供了一种计算机可读存储介质701,该计算机可读存储介质中存储有指令,当其在计算机上运行时,使得计算机执行上述实施例中所述的扩散模型训练方法。
上面结合附图对本发明的实施例进行了描述,但是本发明并不局限于上述的具体实施方式,上述的具体实施方式仅仅是示意性的,而不是限制性的,本领域的普通技术人员在本发明的启示下,在不脱离本发明宗旨和权利要求所保护的范围情况下,还可做出很多形式,均属于本发明的保护之内。
本领域普通技术人员可以意识到,结合本发明实施例中所公开的实施例描述的各示例的单元及算法步骤,能够以电子硬件、或者计算机软件和电子硬件的结合来实现。这些功能究竟以硬件还是软件方式来执行,取决于技术方案的特定应用和设计约束条件。专业技术人员可以对每个特定的应用来使用不同方法来实现所描述的功能,但是这种实现不应认为超出本发明的范围。
所属领域的技术人员可以清楚地了解到,为描述的方便和简洁,上述描述的系统、装置和单元的具体工作过程,可以参考前述方法实施例中的对应过程,在此不再赘述。
在本申请所提供的实施例中,应该理解到,所揭露的装置和方法,可以通过其它的方式实现。例如,以上所描述的装置实施例仅仅是示意性的,例如,所述单元的划分,仅仅为一种逻辑功能划分,实际实现时可以有另外的划分方式,例如多个单元或组件可以结合或者可以集成到另一个系统,或一些特征可以忽略,或不执行。另一点,所显示或讨论的相互之间的耦合或直接耦合或通信连接可以是通过一些接口,装置或单元的间接耦合或通信连接,可以是电性,机械或其它的形式。
所述作为分离部件说明的单元可以是或者也可以不是物理上分开的,作为单元显示的部件可以是或者也可以不是物理单元,即可以位于一个地方,或者也可以分布到多个网络单元上。可以根据实际的需要选择其中的部分或者全部单元来实现本实施例方案的目的。
另外,在本发明各个实施例中的各功能单元可以集成在一个处理单元中,也可以是各个单元单独物理存在,也可以两个或两个以上单元集成在一个单元中。
所述功能如果以软件功能单元的形式实现并作为独立的产品销售或使用时,可以存储在一个计算机可读取存储介质中。基于这样的理解,本发明的技术方案本质上或者说对现有技术做出贡献的部分或者该技术方案的部分可以以软件产品的形式体现出来,该计算机软件产品存储在一个存储介质中,包括若干指令用以使得一台计算机设备(可以是个人计算机,服务器,或者网络设备等)执行本发明各个实施例所述方法的全部或部分步骤。而前述的存储介质包括:U盘、移动硬盘、ROM、RAM、磁碟或者光盘等各种可以存储程序代码的介质。
以上所述,仅为本发明的具体实施方式,但本发明的保护范围并不局限于此,任何熟悉本技术领域的技术人员在本发明揭露的技术范围内,可轻易想到变化或替换,都应涵盖在本发明的保护范围之内。因此,本发明的保护范围应以权利要求的保护范围为准。
Claims (20)
1.一种扩散模型训练方法,其特征在于,包括:
生成通用扩散模型;所述通用扩散模型设置有对应的门控器;
获取包含风格类型信息的专业任务数据集;
基于所述专业任务数据集训练所述通用扩散模型,生成目标扩散模型;
所述基于所述专业任务数据集训练所述通用扩散模型,生成目标扩散模型的步骤包括:
在所述通用扩散模型中添加针对所述风格类型信息的专家层独立模型;
采用所述门控器基于所述风格类型信息,确定针对所述专家层独立模型的专家层独立模型权重;
采用所述门控器向所述专家层独立模型输入所述专业任务数据集,生成预测结果;
采用所述门控器基于所述专家层独立模型权重和所述预测结果,生成目标扩散模型;
所述通用扩散模型具有对应的第一时间步,所述生成通用扩散模型的步骤包括:
获取全量数据集;
从所述全量数据集中确定出清洗数据集;
确定针对所述第一时间步的目标函数;所述目标函数为由伯努利分布和二项式分布组成的共轭先验分布密度函数;所述共轭先验分布密度函数包括第一目标参数和第二目标参数;
通过控制所述第一目标参数和所述第二目标参数确定针对所述通用扩散模型的损失函数;
将所述清洗数据集确定为训练样本,基于所述损失函数生成通用扩散模型。
2.根据权利要求1所述的方法,其特征在于,所述共轭先验分布密度函数为:
3.根据权利要求2所述的方法,其特征在于,所述专家层独立模型通过如下公式表达:
其中,G(x)i为第i个层独立模型的权重,Ei(x)为i个专家的输出,x为前一层网络的输出。
4.根据权利要求3所述的方法,其特征在于,还包括:
将所述目标扩散模型拆分为目标扩散子模型;所述目标扩散子模型与所述风格类型信息具有关联关系;
采用所述通用扩散模型和所述目标扩散子模型构建模型库;所述通用扩散模型为针对目标风格类型信息为真实风格类型的扩散模型,所述目标扩散子模型为针对目标风格类型信息为真实风格类型的扩散模型。
5.根据权利要求4所述的方法,其特征在于,还包括:
响应于接收到由用户发送的文本信息和针对所述文本信息的目标风格类型信息,且所述目标风格类型信息为真实风格类型时,调用针对所述真实风格类型的所述通用扩散模型,生成针对所述文本信息的目标图片。
6.根据权利要求4所述的方法,其特征在于,还包括:
响应于接收到由用户发送的文本信息和针对所述文本信息的目标风格类型信息,且所述目标风格类型信息为非真实风格类型时,调用针对所述非真实风格类型的所述目标扩散子模型,生成针对所述文本信息的目标图片。
7.根据权利要求6所述的方法,其特征在于,所述非真实风格类型至少包括油画风格类型、动漫风格类型和插画风格类型。
8.根据权利要求4所述的方法,其特征在于,所述目标扩散模型具有对应的第二时间步,所述将所述目标扩散模型拆分为目标扩散子模型的步骤包括:
确定初始学生模型;
确定针对所述初始学生模型的第一可学习参数,所述第一可学习参数与所述第二时间步的输出具有关联关系;
基于离散第二时间步,将所述初始学生模型转化为目标扩散子模型。
9.根据权利要求8所述的方法,其特征在于,在所述确定针对所述初始学生模型的第一可学习参数的步骤之前,还包括:
确定合并指导权重;
基于所述合并指导权重生成权重条件模型,并将所述权重条件模型合并至所述初始学生模型。
10.根据权利要求9所述的方法,其特征在于,在所述确定针对所述初始学生模型的第一可学习参数的步骤之前,还包括:
在所述权重条件模型中添加傅里叶函数。
11.根据权利要求10所述的方法,其特征在于,在所述确定针对所述初始学生模型的第一可学习参数的步骤之前,还包括:
初始化所述初始学生模型,以控制所述初始学生模型的模型参数与所述目标扩散模型的模型参数相同。
12.根据权利要求1所述的方法,其特征在于,所述从所述全量数据集中确定出清洗数据集的步骤包括:
确定所述全量数据集中的第一目标数据和其他数据;
采用所述第一目标数据的特征值与其他数据的特征值确定所述第一目标数据和所述其他数据的相似度,并在所述相似度过高时,删除所述第一目标数据生成第一数据子集,并采用所述第一数据子集确定为清洗数据集。
13.根据权利要求12所述的方法,其特征在于,还包括:
确定所述第一数据子集中的第二目标数据;所述第二目标数据为存在缺失值的数据;
删除所述第二目标数据生成第二数据子集,并采用所述第二数据子集确定为清洗数据集。
14.根据权利要求13所述的方法,其特征在于,还包括:
确定所述第二数据子集中的第三目标数据;所述第三目标数据为存在异常值的数据;
删除所述第三目标数据生成第三数据子集,并采用所述第三数据子集确定为清洗数据集。
15.根据权利要求14所述的方法,其特征在于,还包括:
确定所述第三数据子集中的第四目标数据;所述第四目标数据为存在噪声数据;
删除所述第四目标数据生成第四数据子集,并采用所述第四数据子集确定为清洗数据集。
16.根据权利要求15所述的方法,其特征在于,还包括:
对所述第四数据子集执行归一化操作,并将经执行所述归一化操作的第四数据子集确定为清洗数据集;所述归一化操作为均值归一化,或,方差归一化,或,阈值归一化。
17.根据权利要求16所述的方法,其特征在于,还包括:
对经执行所述归一化操作的第四数据子集执行抽样操作,生成清洗数据集。
18.一种扩散模型训练装置,其特征在于,包括:
通用扩散模型生成模块,用于获取全量数据集;从所述全量数据集中确定出清洗数据集;确定针对第一时间步的目标函数;所述目标函数为由伯努利分布和二项式分布组成的共轭先验分布密度函数;所述共轭先验分布密度函数包括第一目标参数和第二目标参数;通过控制所述第一目标参数和所述第二目标参数确定针对所述通用扩散模型的损失函数;将所述清洗数据集确定为训练样本,基于所述损失函数生成通用扩散模型;所述通用扩散模型设置有对应的门控器;所述通用扩散模型具有对应的第一时间步;
专业任务数据集获取模块,用于获取包含风格类型信息的专业任务数据集;
目标扩散模型生成模块,用于在所述通用扩散模型中添加针对所述风格类型信息的专家层独立模型;采用所述门控器基于所述风格类型信息,确定针对所述专家层独立模型的专家层独立模型权重;采用所述门控器向所述专家层独立模型输入所述专业任务数据集,生成预测结果;采用所述门控器基于所述专家层独立模型权重和所述预测结果,生成目标扩散模型。
19.一种电子设备,其特征在于,包括处理器、通信接口、存储器和通信总线,其中,所述处理器、所述通信接口以及所述存储器通过所述通信总线完成相互间的通信;
所述存储器,用于存放计算机程序;
所述处理器,用于执行存储器上所存放的程序时,实现如权利要求1-17任一项所述的方法。
20.一种计算机可读存储介质,其上存储有指令,当由一个或多个处理器执行时,使得所述处理器执行如权利要求1-17任一项所述的方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202311345132.3A CN117095258B (zh) | 2023-10-17 | 2023-10-17 | 一种扩散模型训练方法、装置、电子设备及存储介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202311345132.3A CN117095258B (zh) | 2023-10-17 | 2023-10-17 | 一种扩散模型训练方法、装置、电子设备及存储介质 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN117095258A CN117095258A (zh) | 2023-11-21 |
CN117095258B true CN117095258B (zh) | 2024-02-20 |
Family
ID=88777608
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202311345132.3A Active CN117095258B (zh) | 2023-10-17 | 2023-10-17 | 一种扩散模型训练方法、装置、电子设备及存储介质 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN117095258B (zh) |
Citations (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116740204A (zh) * | 2023-03-09 | 2023-09-12 | 网易(杭州)网络有限公司 | 风格化图像生成模型的生成方法、装置、设备及存储介质 |
CN116883545A (zh) * | 2023-07-06 | 2023-10-13 | 浙江大学 | 基于扩散模型的图片数据集扩充方法、介质及设备 |
Family Cites Families (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US11131737B2 (en) * | 2019-06-04 | 2021-09-28 | The Regents Of The University Of California | Joint estimation diffusion imaging (JEDI) |
-
2023
- 2023-10-17 CN CN202311345132.3A patent/CN117095258B/zh active Active
Patent Citations (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116740204A (zh) * | 2023-03-09 | 2023-09-12 | 网易(杭州)网络有限公司 | 风格化图像生成模型的生成方法、装置、设备及存储介质 |
CN116883545A (zh) * | 2023-07-06 | 2023-10-13 | 浙江大学 | 基于扩散模型的图片数据集扩充方法、介质及设备 |
Also Published As
Publication number | Publication date |
---|---|
CN117095258A (zh) | 2023-11-21 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN110288979B (zh) | 一种语音识别方法及装置 | |
US20220261960A1 (en) | Super-resolution reconstruction method and related apparatus | |
CN110162799B (zh) | 模型训练方法、机器翻译方法以及相关装置和设备 | |
CN110009052B (zh) | 一种图像识别的方法、图像识别模型训练的方法及装置 | |
CN110544488B (zh) | 一种多人语音的分离方法和装置 | |
WO2021135577A9 (zh) | 音频信号处理方法、装置、电子设备及存储介质 | |
US20220172737A1 (en) | Speech signal processing method and speech separation method | |
CN111816159B (zh) | 一种语种识别方法以及相关装置 | |
CN113284142B (zh) | 图像检测方法、装置、计算机可读存储介质及计算机设备 | |
CN111680123B (zh) | 对话模型的训练方法、装置、计算机设备及存储介质 | |
CN110516113B (zh) | 一种视频分类的方法、视频分类模型训练的方法及装置 | |
WO2022253061A1 (zh) | 一种语音处理方法及相关设备 | |
CN111225237B (zh) | 一种视频的音画匹配方法、相关装置以及存储介质 | |
CN112184548A (zh) | 图像超分辨率方法、装置、设备及存储介质 | |
CN114418069A (zh) | 一种编码器的训练方法、装置及存储介质 | |
CN112084959B (zh) | 一种人群图像处理方法及装置 | |
CN114065900A (zh) | 数据处理方法和数据处理装置 | |
CN110544287B (zh) | 一种配图处理方法及电子设备 | |
CN115131475A (zh) | 过渡帧生成方法、装置、设备及存储介质 | |
CN116543076B (zh) | 图像处理方法、装置、电子设备及存储介质 | |
CN112488157A (zh) | 一种对话状态追踪方法、装置、电子设备及存储介质 | |
CN117095258B (zh) | 一种扩散模型训练方法、装置、电子设备及存储介质 | |
CN112748899A (zh) | 一种数据处理方法和相关设备 | |
CN115659959B (zh) | 图像的文本纠错方法、装置、电子设备及存储介质 | |
CN112101204A (zh) | 生成式对抗网络的训练方法、图像处理方法、装置和设备 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |