CN116542321B - 基于扩散模型的图像生成模型压缩和加速方法及系统 - Google Patents

基于扩散模型的图像生成模型压缩和加速方法及系统 Download PDF

Info

Publication number
CN116542321B
CN116542321B CN202310823847.9A CN202310823847A CN116542321B CN 116542321 B CN116542321 B CN 116542321B CN 202310823847 A CN202310823847 A CN 202310823847A CN 116542321 B CN116542321 B CN 116542321B
Authority
CN
China
Prior art keywords
model
image
student
training
teacher
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN202310823847.9A
Other languages
English (en)
Other versions
CN116542321A (zh
Inventor
曹巍瀚
张一帆
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Zhongke Nanjing Artificial Intelligence Innovation Research Institute
Original Assignee
Zhongke Nanjing Artificial Intelligence Innovation Research Institute
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Zhongke Nanjing Artificial Intelligence Innovation Research Institute filed Critical Zhongke Nanjing Artificial Intelligence Innovation Research Institute
Priority to CN202310823847.9A priority Critical patent/CN116542321B/zh
Publication of CN116542321A publication Critical patent/CN116542321A/zh
Application granted granted Critical
Publication of CN116542321B publication Critical patent/CN116542321B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/096Transfer learning
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • G06N3/0455Auto-encoder networks; Encoder-decoder networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/0475Generative networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/0495Quantised networks; Sparse networks; Compressed networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/084Backpropagation, e.g. using gradient descent
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/09Supervised learning
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/094Adversarial learning
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/0985Hyperparameter optimisation; Meta-learning; Learning-to-learn
    • YGENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
    • Y02TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
    • Y02DCLIMATE CHANGE MITIGATION TECHNOLOGIES IN INFORMATION AND COMMUNICATION TECHNOLOGIES [ICT], I.E. INFORMATION AND COMMUNICATION TECHNOLOGIES AIMING AT THE REDUCTION OF THEIR OWN ENERGY USE
    • Y02D10/00Energy efficient computing, e.g. low power processors, power management or thermal management

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • General Health & Medical Sciences (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Computational Linguistics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Evolutionary Computation (AREA)
  • Artificial Intelligence (AREA)
  • Molecular Biology (AREA)
  • Computing Systems (AREA)
  • General Engineering & Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Health & Medical Sciences (AREA)
  • Image Analysis (AREA)

Abstract

本发明公开了一种基于扩散模型的图像生成模型压缩和加速方法及系统,所述方法包括如下步骤:构建并训练图像生成模型,训练完成后部署于服务器中;接收用户输入的数据并预处理,将预处理后的数据传送至训练后的图像生成模型;采用训练后的图像生成模型生成图像并输出显示;所述图像生成模型采用TS模型进行知识蒸馏训练,TS模型为教师学生模型,即教师(teacher)~学生(student)模型。通过知识蒸馏的方法减少了学生模型的采样步数,从而提高了图像生成模型的图像生成的速度和质量。

Description

基于扩散模型的图像生成模型压缩和加速方法及系统
技术领域
本发明涉及人工智能算法相关技术,尤其是基于扩散模型的图像生成模型压缩和加速方法。
背景技术
图像生成任务是近年来计算机视觉领域中备受关注的研究领域之一。基于扩散模型的生成方法在从文本到图像生成领域展现出强大的生成能力,其生成结果在生成可控性和图像质量方面超过了以往基于对抗生成网络的方法。无条件扩散模型可以生成真实图片,但无法根据输入文本输出满足特定意向的图片。有条件扩散模型可以根据输入文本生成对应图片。
然而,基于扩散模型的生成网络在图像生成过程中对计算量的需求较高,这成为阻碍其进一步发展的一个因素。对于使用T个扩散步骤训练的模型,在图像生成阶段,通常会使用相同的时间步序列进行采样。但这会导致扩散模型的图像生成速度变得很慢。一个很直接的方法是使用跨步采样策略,即每[T/S]+1步进行一次采样,以此将采样步骤从T步降为S步。此时,图像生成过程使用的时间步序列为{τ1,τ2,…,τS},其中τ1<τ2<…<τS,τ1至τS属于[1,T]且S<T。然而,减少采样步骤的同时会导致生成图像质量下降。如何提高生成速度、降低资源消耗,并且提高生成效率,是目前需要解决的问题。
因此,需要进行研究创新,以解决上述问题。
发明内容
发明目的:提供一种基于扩散模型的图像生成模型压缩和加速方法,以解决现有技术存在的上述问题。在进一步的实施例中,提供一种基于上述方法的系统。
技术方案:基于扩散模型的图像生成模型压缩和加速方法,包括如下步骤:
步骤S1、构建并训练图像生成模型,训练完成后部署于服务器中;
步骤S2、接收用户输入的数据并预处理,将预处理后的数据传送至训练后的图像生成模型;
步骤S3、采用训练后的图像生成模型生成图像并输出显示;
所述图像生成模型采用TS模型进行知识蒸馏训练,TS模型为教师学生模型,即教师(teacher)~学生(student)模型。
根据本申请的一个方面,所述步骤S1进一步为:
步骤S11、构建并训练至少一个有条件的扩散模型作为教师模型,该教师模型包括至少一个编码器网络和至少一个解码器网络组成,编码器网络将输入的图像和文本信息转换为隐空间向量,解码器网络将隐空间向量转换为输出图像;
步骤S12、初始化至少一个与教师模型结构相同的学生模型,并使用教师模型的参数作为初始参数;
步骤S13、对学生模型进行知识蒸馏训练,在完成一轮知识蒸馏后,使用该学生模型作为新的训练轮次的教师模型,并重复进行知识蒸馏训练;得到训练完成的学生模型并作为图像生成模型。
根据本申请的一个方面,所述步骤S13中,对学生模型进行知识蒸馏训练的过程进一步为:
步骤S13a、接收训练集的数据,并从训练数据集中随机选择一张图像和对应的文本信息,并使用预训练的文本编码模型将文本信息转换为特征向量;
步骤S13b、从指导强度范围内随机选择一个权重,从时间步集合中随机选择一个时间步,并对图像加t步的随机噪声,得到噪声图像;
步骤S13c、对教师模型进行两次前向传播,分别得到两个输出结果zt1和zt2;
第一次前向传播时,输入为噪声图像zt、时间步t和特征向量c,并计算(1+w)*(x(c,θ))’(zt,t,c)-w*(xθ)’(zt,t,Φ)得到输出结果zt1;
第二次前向传播时,输入为输出结果zt1、时间步t1=t–stride//2和特征向量c,并计算(1+w)*(x(c,θ))’(zt1,t1,c)-w*(xθ)’(zt1,t1,Φ)得到输出结果zt2;//表示相除后向下取整;
步骤S13d、对学生模型进行一次前向传播,输入为噪声图像zt、时间步t和特征向量c,并计算(1+w)*(x(c,η))’(zt,t,c)-w*(xη)’(zt,t,Φ)得到输出结果zst3;
其中Φ、η表示与c不符合的负向监督文本特征向量,w表示生成图像在多样性和质量之间的权衡系数。
步骤S13e、计算教师模型的输出结果zt2与学生模型的输出结果zst3的均方误差,并计算学生模型参数η对应的梯度,并反向传播以完成一次迭代训练;
步骤S13f、判断学生模型训练是否收敛。若收敛,结束程序,否则重复上述步骤。
根据本申请的一个方面,当所述教师模型为至少两个,形成集成学习模块时,数据处理过程如下:
分别训练每一教师模型,并保存各自的模型参数;
接收输入数据,并分别使用各个教师模型进行图像生成,并得到多个输出结果;
对于每个输出结果,计算其与输入数据之间的语义相似度,并根据相似度给予权重;
对于每个输出结果,使用加权平均来确定最终输出,并将其作为目标图像;
将目标图像作为监督信号输出至学生模型。
根据本申请的一个方面,在步骤S13c中,判断学生模型的采样步数N是否低于阈值,若低于,将时间步集合从{0,stride,2stride,…,(N-1)stride}修改为{stride-1,2stride-1,3stride-1,…,N*stride-1}。N为自然数,stride为步长。
根据本申请的一个方面,所述学生模型的编码器和解码器中包括自注意力层,对特征向量c进行自注意力计算,并得到加权后的特征向量c’,以加权后的特征向量c’作为学生模型的输入。
根据本申请的一个方面,所述学生模型中还包括预训练的判别器,能够区分真实图像和生成图像,并给出概率值,在训练时用于执行以下流程:
分别接收教师模型和学生模型的输出结果,并调取教师模型和学生模型的目标图像;
使用判别器对分别对教师模型和学生模型的目标图像和输出结果进行判别,并得到两个概率值;
使用目标图像、输出结果和两个概率值作为输入,并基于构建的对抗损失函数;
使用对抗损失函数作为监督信号,在学生模型进行知识蒸馏时进行监督。
根据本申请的一个方面,还包括用于对训练集进行预处理的元学习模块;
在使用元学习模块时,训练数据集分为至少两个子任务,每个子任务均包含训练集和测试集,每个训练集和测试集分别包含输入文本和对应的目标图像;
训练学生模型时,元学习模块针对每个子任务进行如下步骤:使用子任务的训练集对学生模型进行至少一次梯度更新,并得到一个更新后的模型参数;使用更新后的模型参数和子任务的测试集计算生成图像与目标图像之间的损失函数,并累积该损失函数作为元学习目标函数;
在完成一批子任务后,使用累积的元学习目标函数对学生模型进行梯度更新,并得到新的模型参数。
根据本申请的一个方面,所述步骤S2中接收用户输入的数据并预处理的过程至少包括:
调用已构建的用于估算输入数据复杂度的度量函数,以及用于计算时间步长的映射函数;
在接收用户输入的数据时,使用度量函数计算输入数据的复杂度分数;使用映射函数根据复杂度分数得到时间步长;使用该时间步长作为学生模型生成图像所需的采样步数,并进行图像生成,并输出显示。
根据本申请的一个方面,当采用学生模型为两个或以上时,还包括步骤S4:
步骤S41、接收用户输入的选择信号,并查找对应的学生模型;
步骤S42、调整该学生模型的时间步长,再次生成图像并输出;
步骤S43、重复步骤S41和步骤S42,直至接收到用户下载图像的信号。
根据本申请的另一个方面,一种基于扩散模型的图像生成模型压缩和加速系统,其特征在于,包括:
至少一个处理器;以及
与至少一个所述处理器通信连接的存储器;其中,
所述存储器存储有可被所述处理器执行的指令,所述指令用于被所述处理器执行以实现上述任一项技术方案所述的基于扩散模型的图像生成模型压缩和加速方法。
有益效果:提出了一个蒸馏有条件无分类器引导的隐空间扩散模型的方法,通过知识蒸馏的方法减少了学生模型的采样步数,能够对模型进行压缩和加速,同时提高了图像生成的速度和质量。通过无分类器引导的方法实现了正向监督和负向监督的同时进行,提高了图像生成的语义一致性和多样性。通过修改时间步的采样集合,弥补了训练数据域和测试数据域之间的差异,提高了图像生成的清晰度和真实性。
附图说明
图1是本发明的流程图。
图2是本发明步骤S1的流程图。
图3是本发明步骤S13的流程图。
具体实施方式
为了解决现有技术存在的上述问题,提供了一种蒸馏有条件无分类器引导的隐空间扩散模型的方法及使用过程,使得学生模型在生成图片时具有较少的采样步骤,从而提高生成速度,并保证图像生成质量。
与无条件扩散模型的知识蒸馏相比,对有条件扩散模型进行知识蒸馏的难度较高。因为蒸馏后的学生网络不仅需要清晰度较高且较为真实,还需要满足特定的语义,即学生网络还需要具备较高的语义信息理解能力。例如,当扩散模型的采样步数分别为64步和4步时,根据输入“(一座宝可梦风格的山)”生成的图像内容存在显著差别。当采样步数降低后,不仅生成图像的清晰度降低了,生成图像的语义信息也被大幅度削弱了。为此,提供如下方案。
如图1所示,提供一种基于扩散模型的图像生成模型压缩和加速方法,包括如下步骤:
步骤S1、构建并训练图像生成模型,训练完成后部署于服务器中;
步骤S2、接收用户输入的数据并预处理,将预处理后的数据传送至训练后的图像生成模型;
步骤S3、采用训练后的图像生成模型生成图像并输出显示;
所述图像生成模型采用TS模型进行知识蒸馏训练,TS模型为教师学生模型,即教师(teacher)~学生(student)模型。
针对现有技术的缺点,提出了一种基于扩散模型的图像生成模型压缩和加速方法,通过教师模型和学生模型进行蒸馏,实现压缩和加速。本实施例,在运行时占用得空间与资源大大减少。隐空间扩散模型的加入,使得最终模型具有更加广泛的应用空间。利用用户输入的关键信息,使得图形生成算法更加容易收敛,并且能够较为准确地绘制出用户所希望得到的图形。
根据本申请的一个方面,所述步骤S1进一步为:
步骤S11、构建并训练至少一个有条件的扩散模型作为教师模型,该教师模型包括至少一个编码器网络和至少一个解码器网络组成,编码器网络将输入的图像和文本信息转换为隐空间向量,解码器网络将隐空间向量转换为输出图像;
步骤S12、初始化至少一个与教师模型结构相同的学生模型,并使用教师模型的参数作为初始参数;
步骤S13、对学生模型进行知识蒸馏训练,在完成一轮知识蒸馏后,使用该学生模型作为新的训练轮次的教师模型,并重复进行知识蒸馏训练;得到训练完成的学生模型并作为图像生成模型。
根据本申请的一个方面,所述步骤S13中,对学生模型进行知识蒸馏训练的过程进一步为:
步骤S13a、接收训练集的数据,并从训练数据集中随机选择一张图像和对应的文本信息,并使用预训练的文本编码模型将文本信息转换为特征向量;
步骤S13b、从指导强度范围内随机选择一个权重,从时间步集合中随机选择一个时间步,并对图像加t步的随机噪声,得到噪声图像;
步骤S13c、对教师模型进行两次前向传播,分别得到输出结果zt1和输出结果zt2;
第一次前向传播时,输入为噪声图像zt、时间步t和特征向量c,并计算(1+w)*(x(c,θ))’(zt,t,c)-w*(xθ)’(zt,t,Φ)得到输出结果zt1;
第二次前向传播时,输入为输出结果zt1、时间步t1=t–stride//2和特征向量c,并计算(1+w)*(x(c,θ))’(zt1,t1,c)-w*(xθ)’(zt1,t1,Φ)得到输出结果zt2;
步骤S13d、对学生模型进行一次前向传播,输入为噪声图像zt、时间步t和特征向量c,并计算(1+w)*(x(c,η))’(zt,t,c)-w*(xη)’(zt,t,Φ)得到输出结果zst3;
其中Φ、η表示与c不符合的负向监督文本特征向量,w表示生成图像在多样性和质量之间的权衡系数。
步骤S13e、计算教师模型的输出结果zt2与学生模型的输出结果zst3的均方误差,并计算学生模型参数η对应的梯度,并反向传播以完成一次迭代训练;
步骤S13f、判断学生模型训练是否收敛。
具体地,为使得模型在少采样步数时仍能生成满足输入语义信息的图像,提出上述蒸馏方案。
首先,使用已经预训练的有条件扩散模型((x’(c,θ),(xθ)’)作为教师模型,该教师模型可以通过2N步采样(如2N=64)生成较为清晰的具有充足语义信息的图像。学生模型被构建为((x’(c,η),(xη)’),其中η表示可学习参数。在本实施例中,教师模型和学生模型的网络结构相同。
在训练的初始阶段,使用教师模型的模型参数初始化学生模型并计算学生模型图像生成阶段跨步采样使用的步长stride=T//N,此时学生模型可以通过2N步采样生成质量较高的图像,但仅进行N步采样生成的图像质量不达标、语义信息不足。随后进行训练,在训练的每次迭代中,我们从训练集中采样数据x,从指导强度范围内采样得到w,w用来控制生成图像在多样性和质量之间的权衡,从集合{0,stride,2*stride,…,(N-1)*stride}中采样得到当前训练迭代所需的时间步t,并对采样数据x加t步的随机噪声得到zt。
接着,为了使学生模型经过一轮采样得到的结果可以逼近教师网络两轮采样的结果,具体方法如下:
首先,分别对教师模型进行一次前项传播,模型x(c,θ)的输入为zt、时间步t和监督文本对应的特征向量c,模型xθ的输入为zt、时间步t和负向监督文本对应的特征向量Φ。并计算(1+w)*(x(c,θ))’(zt,t,c)-w*(xθ)’(zt,t,Φ)得到教师模型第一个采样步对应的输出,以此来同时实现正向监督(上式的第一部分)和负向监督(上式的第二部分)。
随后,根据教师模型第一采样步输出计算得到教师模型经一次采样得到的结果zt1。随后重复上述步骤得到教师模型第二次采样后得到的结果zt2,与第一次采样不同的是,第二次采样时教师模型(x(c,θ))’(zt,t,c),(xθ)’(zt,t,Φ)的输入时间步不再为t,而是t1=t–stride//2。接着,使用与教师模型类似的方法,对学生模型((x’(c,η),(xη)’)执行一步采样得到结果zst2。计算zt2与zst2的均方误差并计算学生模型参数η对应的梯度,并反向传播以完成一次迭代训练。重复上述过程直至学生模型训练收敛,此时学生模型通过N轮采样生成图像的语义信息与教师模型2N轮采样较为接近。
在进行完一轮知识蒸馏后,我们使用该学生网络作为新的训练轮次的教师网络,重复上述操作并得到仅需N/2次采样步骤的学生网络,以此类推,直至学生网络的采样步数降至较低水平。
为了解决模型压缩和加速后,导致模型泛化能力降低的问题,提供如下方案。
根据本申请的一个方面,当所述教师模型为至少两个,形成集成学习模块时,数据处理过程如下:
分别训练每一教师模型,并保存各自的模型参数;
接收输入数据,并分别使用各个教师模型进行图像生成,并得到多个输出结果;
对于每个输出结果,计算其与输入数据之间的语义相似度,并根据相似度给予权重;
对于每个输出结果,使用加权平均来确定最终输出,并将其作为目标图像;
将目标图像作为监督信号输出至学生模型。
在某个案例中,具体实施过程如下:
设输入文本为x,特征向量c为f(x),教师模型集合为T={T1,T2,…,Tn},学生模型为S,图像生成损失函数为L,语义相似度函数为sim,加权平均或投票机制为M。
则教师模型的输出结果集合为Y={y1,y2,…,yn},其中yi=Ti(S(C))。
对于每个输出结果yi,计算其与输入文本之间的语义相似度wi=sim(yi,x),并根据相似度给予不同权重。
对于每个输出结果yi,使用加权平均或投票机制来决定最终输出y=M(Y,W),并将其作为目标图像。
知识蒸馏训练的目标是最小化损失函数L(S(C), y)。
在本实施例中,利用多个教师模型的不同特点,提高学生模型的泛化能力和鲁棒性,提高了提高学生模型生成图像的质量和多样性,另外,这个模型可以灵活地选择不同采样步数或不同网络结构或不同训练数据集的教师模型进行集成学习,提高学生模型的适应性。
根据本申请的一个方面,在步骤S13c中,判断学生模型的采样步数N是否低于阈值,若低于,将时间步集合从{0,stride,2stride,…,(N-1)stride}修改为{stride-1,2stride-1,3stride-1,…,N*stride-1}。
上述实施例缓解了采样步数下降导致的生成图像语义信息错误的问题,然而经上述方法蒸馏得到的学生模型生成图像的清晰度仍然不足。具体的表现形式为训练阶段的损失函数数值收敛至了一个较低的水准,然而推理阶段扩散模型根据随机高斯噪声进行图像生成时生成的图像清晰度较低,当学生网络的采样步数较少时该现象尤其明显。分析训练阶段与推理阶段的流程发现训练、推理阶段最大的差别为:训练阶段模型输入为带噪声图像,而推理阶段最初模型的输入为随机的高斯噪声。
由于训练过程的带噪声图片中的噪声强度与采样得到的时间步t相关,而t是从集合{0,stride,2*stride,…,(N-1)*stride}采样得到的结果,因此当学生网络的采样步数N较小时(即跨步采样的步长stride较大时),集合{0,stride,2*stride,…,(N-1)*stride}中的最大值仍然较小。
这就导致在训练阶段所有的带噪声图片中的噪声含量不足,进而导致学生模型的训练数据域与测试数据域差异较大,少量噪声的原图与随机高斯噪声,进而导致学生模型生成图像的清晰度不足、噪声偏多。为缓解上述问题,需要修改上面的蒸馏方法,将每次迭代中时间步集合从{0,stride,2*stride,…,(N-1)*stride}修改为{stride-1,2*stride-1,3*stride-1,…,N*stride-1}。上述修改弥补了学生模型的训练数据域与测试数据域的差异,并极大地提升了学生网络生成图像的清晰度。
在进一步的实施例中,发现当学生网络的采样步数N较小时,上述改进对图像清晰度的提升尤其明显,因此,在学生网络采样步数N低于16时,将上述修改加入蒸馏算法中,可以得到较优性能。学生模型的采样步数为4步,监督文本为:一个蓝黑色的有着两只眼睛的宝可梦风格的东西。实验结果表明,在不做上述改进时,训练得到的学生模型生成图像具有的语义较为正确,但清晰度较低,且图像中含有较多噪声。而进行上述修改后,学生模型生成图像的清晰度得到了较大程度的提升。
对于隐空间扩散模型的图像生成过程,采样(去噪)部分耗时占全部耗时的95%以上,因此通过知识蒸馏的方法减少模型的采样步数可以大幅度提升扩散模型生成图像的速度。实验证明,本申请可以仅需使用4个采样步骤即可生成逼真图像,通过知识蒸馏将图像生成耗时减少至原来的1/8。
根据本申请的一个方面,所述学生模型的编码器和解码器中包括自注意力层,对特征向量c进行自注意力计算,并得到加权后的特征向量c’,以加权的特征向量c’作为学生模型的输入。
在进一步的实施例中,可以采用BERT(Bidirectional Encoder Representationsfrom Transformers)模型、词嵌入模型、预训练语言模型或变换器模型来实现,以BERT为例,具体过程如下:
下载BERT的预训练权重文件和词汇表文件,或者使用自己的数据集进行预训练。
使用BERT的分词器将输入文本分词,并添加特殊符号[CLS]和[SEP]。
使用BERT的词嵌入层将分词后的文本转换为词向量,并添加位置向量和段落向量。
使用BERT的编码器层对词向量进行多层双向自注意力计算,并得到每个词的隐藏状态向量。
使用[CLS]对应的隐藏状态向量作为特征向量,或者对所有词的隐藏状态向量进行平均或最大池化操作,得到特征向量。
使用特征向量作为学生模型的输入,进行图像生成,并得到输出结果。
使用教师模型生成的目标图像作为监督信号,对学生模型进行知识蒸馏训练。
在进一步的实施例中,由于时间步的减少,可能会导致生成内容质量的下降,为了使学生网络能够生成更加符合预期的图像,给出如下的技术方案。
根据本申请的一个方面,所述学生模型中还包括预训练的判别器,能够区分真实图像和生成图像,并给出概率值,在训练时用于执行以下流程:
分别接收教师模型和学生模型的输出结果,并调取教师模型和学生模型的目标图像;
使用判别器对分别对教师模型和学生模型的目标图像和输出结果进行判别,并得到两个概率值;
使用目标图像、输出结果和两个概率值作为输入,并基于构建的对抗损失函数;
使用对抗损失函数作为监督信号,在学生模型进行知识蒸馏时进行监督。
具体而言,实现过程如下:设输入文本为x,特征向量c为f(x),教师模型为T,学生模型为S,判别器网络为D,图像生成损失函数为L1,对抗损失函数为L2。
则教师模型的目标图像为y=T(c),学生模型的输出结果为y’=S(c)。
使用判别器网络对目标图像和输出结果进行判别,并得到两个概率值p=D(y)和p’=D(y’)。
使用目标图像、输出结果和两个概率值作为输入,设计对抗损失函数L2(y,y’,p,p’),使其能够同时考虑与目标图像的相似度和与真实图像的相似度,并提高生成图像的真实性和清晰度。使用对抗损失函数作为监督信号,对学生模型进行知识蒸馏训练,同时对判别器网络进行更新。具体地,可以使用以下公式来更新参数:θSS-α▽θS L2(y,y’,p,p’),θDD-β▽θDL2(y,y’,p,p’),其中θS和θD分别表示学生模型和判别器网络的参数,α和β分别表示学习率。
在本实施例中,可以灵活地设计不同的对抗损失函数来适应不同的任务和数据集,使其能够生成更符合期望的图像。
为了提高模型生成速度,提高对少量数据的快速适应能力,在不同子任务之间进行迁移和泛化,提供如下技术方案。
根据本申请的一个方面,还包括用于对训练集进行预处理的元学习模块;
在使用元学习模块时,训练数据集分为至少两个子任务,每个子任务均包含训练集和测试集,每个训练集和测试集分别包含输入文本和对应的目标图像;
训练学生模型时,元学习模块针对每个子任务进行如下步骤:使用子任务的训练集对学生模型进行至少一次梯度更新,并得到一个更新后的模型参数;使用更新后的模型参数和子任务的测试集计算生成图像与目标图像之间的损失函数,并累积该损失函数作为元学习目标函数;
在完成一批子任务后,使用累积的元学习目标函数对学生模型进行梯度更新,并得到新的模型参数。
具体地,实现过程如下:设输入文本为x,特征向量c为f(x),教师模型为T,学生模型为S,图像生成损失函数为L1,元学习目标函数为L2。
则教师模型的目标图像为y=T(c),学生模型的输出结果为y’=S(c)。
在使用元学习模块时,训练数据集分为至少两个子任务,每个子任务均包含训练集和测试集,每个训练集和测试集分别包含输入文本和对应的目标图像。
训练学生模型时,元学习模块针对每个子任务进行如下步骤:
使用子任务的训练集对学生模型进行至少一次梯度更新,并得到一个更新后的模型参数θ’。具体地,可以使用以下公式来更新参数:θ’=θ-α▽θ L1(y,y’),其中θ表示学生模型的参数,α表示学习率。使用更新后的模型参数θ’和子任务的测试集计算生成图像与目标图像之间的损失函数L1(y,y’)并累积该损失函数作为元学习目标函数L2(θ’)。具体地,可以使用以下公式来累积损失函数:L2(θ’)=L2(θ’)+L1(y,y’)。
在完成一批子任务后,使用累积的元学习目标函数L2(θ’)对学生模型进行梯度更新,并得到新的模型参数θ。具体地,可以使用以下公式来更新参数:θ=θ-β▽θ L2(θ’),其中β表示学习率。
根据本申请的一个方面,所述步骤S2中接收用户输入的数据并预处理的过程至少包括:
调用已构建的用于估算输入数据复杂度的度量函数,以及用于计算时间步长的映射函数;
在接收用户输入的数据时,使用度量函数计算输入数据的复杂度分数;使用映射函数根据复杂度分数得到时间步长;使用该时间步长作为学生模型生成图像所需的采样步数,并进行图像生成,并输出显示。
具体地,文本复杂度的度量函数使用输入文本的长度(以字符数计算)和词汇量(以不同单词数计算)作为特征,使用一个简单的线性回归模型来计算文本复杂度的分数。具体地,该函数可以表示为:f(text)=a*length(text)+b*vocabulary(text)+c。其中,a,b,c是线性回归模型的参数,可以根据一些标注好的文本复杂度的数据来学习或者人工设定。时间步长的映射函数使用一个查找表来根据文本复杂度的分数来映射出一个合适的时间步长。具体地,该函数可以表示为:g(score)=table[score];其中,table是一个预定义好的查找表,可以根据实验数据或者经验知识来设计或者学习。比如得分为0-10时,时间步的长度为8;得分为41-50时,时间步的长度为64。
比如,在某个实施例中,用户输入的是英文表述的“画一幅带着冰球帽子的猫”。则过程如下:使用文本复杂度的度量函数计算输入文本的复杂度分数。假设我们已经知道a=0.1,b=1,c=-5,那么我们可以得到:
f(text)=0.1*length(text)+1*vocabulary(text)-5=0.1*28+1*7-5=2.8+7-5=4.8。
使用时间步长的映射函数根据复杂度分数得到一个时间步长。根据上述查找表,我们可以得到:g(score)=table[score]=table[4.8]=table[0-10]=8。使用该时间步长作为学生模型生成图像所需的采样步数,并进行图像生成,并输出显示。
根据本申请的一个方面,为了提高学生模型对输入文本的歧义性处理能力,提供如下方案。学生模型中引入一个隐变量,使其能够根据输入文本特征向量和随机噪声生成不同风格或视角的图像,并提高生成图像的歧义性处理能力。数据处理步骤如下:对于每个输入文本,使用预训练的文本编码模型(如BERT)将其转换为特征向量。
在学生模型中引入一个隐变量,使其能够根据输入文本特征向量和随机噪声生成不同风格或视角的图像。具体地,可以使用一个编码器网络将特征向量和随机噪声映射为隐变量的均值和方差,然后从正态分布中采样得到隐变量。然后使用一个解码器网络将隐变量和特征向量作为输入,进行图像生成,并得到输出结果。
使用教师模型生成的目标图像作为监督信号,对学生模型进行知识蒸馏训练。具体地,可以设计一个条件变分自编码器损失函数,使其能够同时考虑与目标图像的相似度、与输入文本的一致性和隐变量z的先验分布的相似度,并提高生成图像的歧义性处理能力。
根据本申请的一个方面,当采用学生模型为两个或以上时,在首次生成图像时,采用的时间步长不同,生成不同风格的内容,在后续过程时还包括步骤S4:
步骤S41、接收用户输入的选择信号,并查找对应的学生模型;
步骤S42、调整该学生模型的时间步长,再次生成图像并输出;
步骤S43、重复步骤S41和步骤S42,直至接收到用户下载图像的信号。
在这个实施例中,通过先使用较小时间步长的模型生成图像内容,然后由用户判断内容是否符合预期,如果符合预期,则通过选择对应的图像,然后模型找到对应的学生模型,增加时间步长,然后将对用户选择的图像进行精细化生成,提高图像的分辨率和内容细节。在步骤S42中,对上一次生成的图像进行分辨率和内容细节,如此重复,直到内容和分辨率均符合预期。
在这个实施例中,避免了直接生成高分辨率但内容不符合客户预期的图像。通过快速生成图像,由用户选择图像内容,判断图像内容是否符合预期,然后根据客户的预期对图像进行细节优化,从而实现模型生成图像质量的加速。
在某个具体的实施例中,可以采用如下步骤:
构建至少一个包含多种动物类别的图像数据集,例如ImageNet,以及一个包含类别标签的条件向量数据集,例如one-hot编码。
使用有条件隐空间扩散模型(Conditional Hidden Space Diffusion Model,CHSDM)作为教师模型,在图像数据集和条件向量数据集上进行训练,得到一个能够根据条件向量生成高质量图像的模型。
使用至少两个较小或相同的有条件隐空间扩散模型作为学生模型,在图像数据集和条件向量数据集上进行训练,同时使用上述的知识蒸馏技术(Knowledge Distillation,KD),将教师模型的输出作为额外的监督信号,来提高学生模型的生成能力。
接收用户输入的条件向量,例如[0, 0, 1, 0, 0]表示猫的类别,将其输入学生模型,并根据学生模型的时间步长(Time Step Length, TSL),从隐空间中采样一个随机噪声向量,并通过扩散过程逐步生成图像,并输出。
如果用户对生成的图像不满意,可以选择调整条件向量或时间步长,并重复上述步骤,直至用户下载图像或结束会话。
其中,时间步长是指在扩散过程中,对隐空间向量。时间步长越大优化的迭代次数,表示扩散过程越慢,生成的图像越清晰;时间步长越小,表示扩散过程越快,生成的图像越模糊。
根据本申请的另一个方面,一种基于扩散模型的图像生成模型压缩和加速系统,其特征在于,包括:
至少一个处理器;以及
与至少一个所述处理器通信连接的存储器;其中,
所述存储器存储有可被所述处理器执行的指令,所述指令用于被所述处理器执行以实现上述任一项技术方案所述的基于扩散模型的图像生成模型压缩和加速方法。将上述方法做成软件模块,然后配置到电脑中,即可获得基于扩散模型的图像生成模型压缩和加速系统,相关技术为现有技术,在此不再详述。
以上详细描述了本发明的优选实施方式,但是,本发明并不限于上述实施方式中的具体细节,在本发明的技术构思范围内,可以对本发明的技术方案进行多种等同变换,这些等同变换均属于本发明的保护范围。

Claims (7)

1.基于扩散模型的图像生成模型压缩和加速方法,其特征在于,包括如下步骤:
步骤S1、构建并训练图像生成模型,训练完成后部署于服务器中;
步骤S2、接收用户输入的数据并预处理,将预处理后的数据传送至训练后的图像生成模型;
步骤S3、采用训练后的图像生成模型生成图像并输出显示;
所述图像生成模型采用TS模型进行知识蒸馏训练;
所述步骤S1进一步为:
步骤S11、构建并训练至少一个有条件的扩散模型作为教师模型,该教师模型包括至少一个编码器网络和至少一个解码器网络组成,编码器网络将输入的图像和文本信息转换为隐空间向量,解码器网络将隐空间向量转换为输出图像;
步骤S12、初始化至少一个与教师模型结构相同的学生模型,并使用教师模型的参数作为初始参数;
步骤S13、对学生模型进行知识蒸馏训练,在完成一轮知识蒸馏后,使用该学生模型作为新的训练轮次的教师模型,并重复进行知识蒸馏训练;得到训练完成的学生模型并作为图像生成模型;
所述步骤S13中,对学生模型进行知识蒸馏训练的过程进一步为:
步骤S13a、接收训练集的数据,并从训练数据集中随机选择一张图像和对应的文本信息,并使用预训练的文本编码模型将文本信息转换为特征向量;
步骤S13b、从指导强度范围内随机选择一个权重,从时间步集合中随机选择一个时间步,并对图像加t步的随机噪声,得到噪声图像;
步骤S13c、对教师模型进行两次前向传播,分别得到两个输出结果zt1和zt2;
第一次前向传播时,输入为噪声图像zt、时间步t和特征向量c,并计算(1+w)*(x(c,θ))’(zt,t,c)-w*(xθ)’(zt,t,Φ)得到输出结果zt1;
第二次前向传播时,输入为输出结果zt1、时间步t1=t–stride//2和特征向量c,并计算(1+w)*(x(c,θ))’(zt1,t1,c)-w*(xθ)’(zt1,t1,Φ)得到输出结果zt2;stride表示步长;//表示相除后向下取整;((x’(c,θ),(xθ)’)为教师模型;θ为模型参数;
步骤S13d、对学生模型进行一次前向传播,输入为噪声图像zt、时间步t和特征向量c,并计算(1+w)*(x(c,η))’(zt,t,c)-w*(xη)’(zt,t,Φ)得到输出结果zst3;
其中Φ、η表示与c不符合的负向监督文本特征向量,w表示生成图像在多样性和质量之间的权衡系数;((x’(c,η),(xη)’)为学生模型;
步骤S13e、计算教师模型的输出结果zt2与学生模型的输出结果zst3的均方误差,并计算学生模型参数η对应的梯度,并反向传播以完成一次迭代训练;
步骤S13f、判断学生模型训练是否收敛;
当所述教师模型为至少两个,形成集成学习模块时,数据处理过程如下:
分别训练每一教师模型,并保存各自的模型参数;
接收输入数据,并分别使用各个教师模型进行图像生成,并得到多个输出结果;
对于每个输出结果,计算其与输入数据之间的语义相似度,并根据相似度给予权重;
对于每个输出结果,使用加权平均来确定最终输出,并将其作为目标图像;
将目标图像作为监督信号输出至学生模型;
在步骤S13c中,判断学生模型的采样步数N是否低于阈值,若低于,将时间步集合从{0,stride,2stride,…,(N-1)stride}修改为{stride-1,2stride-1,3stride-1,…,N*stride-1},N为自然数,stride为步长。
2.如权利要求1所述的基于扩散模型的图像生成模型压缩和加速方法,其特征在于,所述学生模型的编码器和解码器中包括自注意力层,对特征向量c进行自注意力计算,并得到加权后的特征向量c’,以加权后的特征向量c’作为学生模型的输入。
3.如权利要求1所述的基于扩散模型的图像生成模型压缩和加速方法,其特征在于,所述学生模型中还包括预训练的判别器,能够区分真实图像和生成图像,并给出概率值,在训练时用于执行以下流程:
分别接收教师模型和学生模型的输出结果,并调取教师模型和学生模型的目标图像;
使用判别器对分别对教师模型和学生模型的目标图像和输出结果进行判别,并得到两个概率值;
使用目标图像、输出结果和两个概率值作为输入,并基于构建的对抗损失函数;
使用对抗损失函数作为监督信号,在学生模型进行知识蒸馏时进行监督。
4.如权利要求1所述的基于扩散模型的图像生成模型压缩和加速方法,其特征在于,还包括用于对训练集进行预处理的元学习模块;
在使用元学习模块时,训练数据集分为至少两个子任务,每个子任务均包含训练集和测试集,每个训练集和测试集分别包含输入文本和对应的目标图像;
训练学生模型时,元学习模块针对每个子任务进行如下步骤:使用子任务的训练集对学生模型进行至少一次梯度更新,并得到一个更新后的模型参数;使用更新后的模型参数和子任务的测试集计算生成图像与目标图像之间的损失函数,并累积该损失函数作为元学习目标函数;
在完成一批子任务后,使用累积的元学习目标函数对学生模型进行梯度更新,并得到新的模型参数。
5.如权利要求1所述的基于扩散模型的图像生成模型压缩和加速方法,其特征在于,所述步骤S2中接收用户输入的数据并预处理的过程至少包括:
调用已构建的用于估算输入数据复杂度的度量函数,以及用于计算时间步长的映射函数;
在接收用户输入的数据时,使用度量函数计算输入数据的复杂度分数;使用映射函数根据复杂度分数得到时间步长;使用该时间步长作为学生模型生成图像所需的采样步数,并进行图像生成,并输出显示。
6.如权利要求1所述的基于扩散模型的图像生成模型压缩和加速方法,其特征在于,当采用学生模型为两个或以上时,还包括步骤S4:
步骤S41、接收用户输入的选择信号,并查找对应的学生模型;
步骤S42、调整该学生模型的时间步长,再次生成图像并输出;
步骤S43、重复步骤S41和步骤S42,直至接收到用户下载图像的信号。
7.一种基于扩散模型的图像生成模型压缩和加速系统,其特征在于,包括:
至少一个处理器;以及
与至少一个所述处理器通信连接的存储器;其中,
所述存储器存储有可被所述处理器执行的指令,所述指令用于被所述处理器执行以实现权利要求1至6任一项所述的基于扩散模型的图像生成模型压缩和加速方法。
CN202310823847.9A 2023-07-06 2023-07-06 基于扩散模型的图像生成模型压缩和加速方法及系统 Active CN116542321B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202310823847.9A CN116542321B (zh) 2023-07-06 2023-07-06 基于扩散模型的图像生成模型压缩和加速方法及系统

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202310823847.9A CN116542321B (zh) 2023-07-06 2023-07-06 基于扩散模型的图像生成模型压缩和加速方法及系统

Publications (2)

Publication Number Publication Date
CN116542321A CN116542321A (zh) 2023-08-04
CN116542321B true CN116542321B (zh) 2023-09-01

Family

ID=87458252

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202310823847.9A Active CN116542321B (zh) 2023-07-06 2023-07-06 基于扩散模型的图像生成模型压缩和加速方法及系统

Country Status (1)

Country Link
CN (1) CN116542321B (zh)

Families Citing this family (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN117993050B (zh) * 2023-12-27 2024-09-17 清华大学 一种基于知识增强扩散模型的建筑设计方法及系统
CN117576518B (zh) * 2024-01-15 2024-04-23 第六镜科技(成都)有限公司 图像蒸馏方法、装置、电子设备和计算机可读存储介质
CN118379219B (zh) * 2024-06-19 2024-09-13 阿里巴巴达摩院(杭州)科技有限公司 模型生成方法及图像生成方法
CN118365510B (zh) * 2024-06-19 2024-09-13 阿里巴巴达摩院(杭州)科技有限公司 图像处理方法、图像处理模型的训练方法及图像生成方法

Citations (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN112465111A (zh) * 2020-11-17 2021-03-09 大连理工大学 一种基于知识蒸馏和对抗训练的三维体素图像分割方法
CN114399668A (zh) * 2021-12-27 2022-04-26 中山大学 基于手绘草图和图像样例约束的自然图像生成方法及装置
CN114419691A (zh) * 2021-12-13 2022-04-29 深圳数联天下智能科技有限公司 人脸衰老图像的生成方法、模型训练方法、设备和介质
CN115620074A (zh) * 2022-11-11 2023-01-17 浪潮(北京)电子信息产业有限公司 一种图像数据的分类方法、装置以及介质

Family Cites Families (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN111767711B (zh) * 2020-09-02 2020-12-08 之江实验室 基于知识蒸馏的预训练语言模型的压缩方法及平台

Patent Citations (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN112465111A (zh) * 2020-11-17 2021-03-09 大连理工大学 一种基于知识蒸馏和对抗训练的三维体素图像分割方法
CN114419691A (zh) * 2021-12-13 2022-04-29 深圳数联天下智能科技有限公司 人脸衰老图像的生成方法、模型训练方法、设备和介质
CN114399668A (zh) * 2021-12-27 2022-04-26 中山大学 基于手绘草图和图像样例约束的自然图像生成方法及装置
CN115620074A (zh) * 2022-11-11 2023-01-17 浪潮(北京)电子信息产业有限公司 一种图像数据的分类方法、装置以及介质

Non-Patent Citations (1)

* Cited by examiner, † Cited by third party
Title
基于深度特征蒸馏的人脸识别;葛仕明;赵胜伟;刘文瑜;李晨钰;;北京交通大学学报(第06期);全文 *

Also Published As

Publication number Publication date
CN116542321A (zh) 2023-08-04

Similar Documents

Publication Publication Date Title
CN116542321B (zh) 基于扩散模型的图像生成模型压缩和加速方法及系统
CN107392255B (zh) 少数类图片样本的生成方法、装置、计算设备及存储介质
CN111373417B (zh) 与基于度量学习的数据分类相关的设备及其方法
CN107392973B (zh) 像素级手写体汉字自动生成方法、存储设备、处理装置
US20230281445A1 (en) Population based training of neural networks
CN110520871A (zh) 训练机器学习模型
US20200410365A1 (en) Unsupervised neural network training using learned optimizers
WO2018204371A1 (en) System and method for batch-normalized recurrent highway networks
CN114387365B (zh) 一种线稿上色方法及装置
CN111476228A (zh) 针对场景文字识别模型的白盒对抗样本生成方法
CN113158554B (zh) 模型优化方法、装置、计算机设备及存储介质
CN111046178A (zh) 一种文本序列生成方法及其系统
CN111709493A (zh) 对象分类方法、训练方法、装置、设备及存储介质
CN111708871A (zh) 对话状态跟踪方法、装置及对话状态跟踪模型训练方法
JP2024532679A (ja) 自己回帰言語モデルニューラルネットワークを使用して出力系列を評価すること
CN111353541A (zh) 一种多任务模型的训练方法
CN116958712B (zh) 基于先验概率分布的图像生成方法、系统、介质及设备
CN112597777A (zh) 一种多轮对话改写方法和装置
CN111259673A (zh) 一种基于反馈序列多任务学习的法律判决预测方法及系统
CN116186384A (zh) 一种基于物品隐含特征相似度的物品推荐方法及系统
CN115438210A (zh) 文本图像生成方法、装置、终端及计算机可读存储介质
CN113535911B (zh) 奖励模型处理方法、电子设备、介质和计算机程序产品
KR102303626B1 (ko) 단일 이미지에 기반하여 비디오 데이터를 생성하기 위한 방법 및 컴퓨팅 장치
CN118377384B (zh) 一种虚拟场景交互方法和系统
CN117575894B (zh) 图像生成方法、装置、电子设备和计算机可读存储介质

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant