CN117910533A - 用于扩散神经网络的噪声调度 - Google Patents
用于扩散神经网络的噪声调度 Download PDFInfo
- Publication number
- CN117910533A CN117910533A CN202410118622.8A CN202410118622A CN117910533A CN 117910533 A CN117910533 A CN 117910533A CN 202410118622 A CN202410118622 A CN 202410118622A CN 117910533 A CN117910533 A CN 117910533A
- Authority
- CN
- China
- Prior art keywords
- output
- network
- diffuse
- network output
- new
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000013528 artificial neural network Methods 0.000 title claims abstract description 92
- 238000012549 training Methods 0.000 claims abstract description 62
- 238000000034 method Methods 0.000 claims abstract description 39
- 238000003860 storage Methods 0.000 claims abstract description 11
- 238000009792 diffusion process Methods 0.000 claims description 48
- 238000012545 processing Methods 0.000 claims description 22
- 238000005070 sampling Methods 0.000 claims description 12
- 238000009826 distribution Methods 0.000 claims description 11
- 230000001419 dependent effect Effects 0.000 claims description 3
- 238000004590 computer program Methods 0.000 abstract description 15
- 238000013527 convolutional neural network Methods 0.000 abstract description 4
- 230000008569 process Effects 0.000 description 18
- 239000013598 vector Substances 0.000 description 9
- 230000001143 conditioned effect Effects 0.000 description 7
- 238000010801 machine learning Methods 0.000 description 6
- 230000009471 action Effects 0.000 description 5
- 238000004891 communication Methods 0.000 description 5
- 230000006870 function Effects 0.000 description 5
- 230000008901 benefit Effects 0.000 description 3
- 238000001514 detection method Methods 0.000 description 3
- 238000010586 diagram Methods 0.000 description 3
- 230000003993 interaction Effects 0.000 description 3
- 241001522296 Erithacus rubecula Species 0.000 description 2
- 241001465754 Metazoa Species 0.000 description 2
- 230000003044 adaptive effect Effects 0.000 description 2
- 230000000694 effects Effects 0.000 description 2
- 238000007667 floating Methods 0.000 description 2
- 230000000116 mitigating effect Effects 0.000 description 2
- 238000010606 normalization Methods 0.000 description 2
- 230000003287 optical effect Effects 0.000 description 2
- 230000004044 response Effects 0.000 description 2
- 238000013515 script Methods 0.000 description 2
- 238000000926 separation method Methods 0.000 description 2
- 230000026676 system process Effects 0.000 description 2
- 230000000007 visual effect Effects 0.000 description 2
- ORILYTVJVMAKLC-UHFFFAOYSA-N Adamantane Natural products C1C(C2)CC3CC1CC2C3 ORILYTVJVMAKLC-UHFFFAOYSA-N 0.000 description 1
- 206010028980 Neoplasm Diseases 0.000 description 1
- 230000005540 biological transmission Effects 0.000 description 1
- 230000003750 conditioning effect Effects 0.000 description 1
- 125000004122 cyclic group Chemical group 0.000 description 1
- 230000007423 decrease Effects 0.000 description 1
- 238000005516 engineering process Methods 0.000 description 1
- 238000012886 linear function Methods 0.000 description 1
- 239000004973 liquid crystal related substance Substances 0.000 description 1
- 238000004519 manufacturing process Methods 0.000 description 1
- 239000011159 matrix material Substances 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 230000001537 neural effect Effects 0.000 description 1
- 230000000644 propagated effect Effects 0.000 description 1
- 230000000306 recurrent effect Effects 0.000 description 1
- 230000011218 segmentation Effects 0.000 description 1
- 239000004065 semiconductor Substances 0.000 description 1
- 230000001953 sensory effect Effects 0.000 description 1
- 230000005236 sound signal Effects 0.000 description 1
- 239000000758 substrate Substances 0.000 description 1
- 230000002123 temporal effect Effects 0.000 description 1
- 238000012546 transfer Methods 0.000 description 1
- 238000009827 uniform distribution Methods 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/0464—Convolutional networks [CNN, ConvNet]
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/0475—Generative networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/09—Supervised learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
- G06N3/0455—Auto-encoder networks; Encoder-decoder networks
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- General Health & Medical Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Evolutionary Computation (AREA)
- Artificial Intelligence (AREA)
- Molecular Biology (AREA)
- Computing Systems (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Mathematical Physics (AREA)
- Software Systems (AREA)
- Health & Medical Sciences (AREA)
- Image Analysis (AREA)
Abstract
本公开涉及用于扩散神经网络的噪声调度。用于使用扩散神经网络生成网络输出以及用于利用修改的噪声调度策略训练扩散神经网络的方法、系统和装置,其包括在计算机存储介质上编码的计算机程序。
Description
技术领域
本说明书涉及使用神经网络以条件输入(conditioning input)为条件生成输出。
背景技术
神经网络是机器学习模型,它采用一层或多层非线性单元来针对接收到的输入预测输出。一些神经网络除了输出层之外还包括一个或多个隐藏层。每个隐藏层的输出用作对于网络中一个或多个其他层——即,一个或多个其他隐藏层、输出层或两者——的输入。网络的每一层根据相应参数集的当前值从接收到的输入生成输出。
发明内容
本说明书描述了一种在一个或多个位置的一个或多个计算机上实现为计算机程序的系统,该系统训练扩散神经网络以生成训练网络输出。
一般而言,条件输入表征网络输出的一个或多个期望属性,即表征由系统生成的最终网络输出应具有的一个或多个属性。
更具体地说,该系统使用扩散神经网络生成网络输出。
该系统可以修改对扩散神经网络的训练、对训练后的扩散神经网络的输入、或两者,以提高由扩散神经网络生成的网络输出的质量。
可以实现本说明书中描述的主题的特定实施例以便实现以下优点中的一个或多个。
本说明书描述了用于改进扩散神经网络在生成输出——例如生成图像、音频或视频——方面的性能的技术。特别地,本说明书描述了作为扩散神经网络的训练的一部分的、对应用于训练网络输出的噪声进行修改的技术。也就是说,本说明书描述了用于修改用于扩散神经网络的训练的噪声调度(noise scheduling)策略,以便改进由训练后的扩散神经网络生成的网络输出的质量的技术。
特别是,不同的噪声调度(different noise schedules)可以极大地影响扩散神经网络的性能,并且最优噪声调度可以取决于任务(例如,取决于需要由扩散神经网络生成的网络输出的大小)。例如,当增加图像大小时,由于像素冗余增加,最优噪声调度会移向噪声较大的调度。因此,不同的任务可以从不同的噪声调度中受益。
为了缓解这些问题并改进性能,本说明书描述了用于按标度因子缩放输入数据的技术。本说明书还描述了在训练期间使用线性噪声调度的技术。通过应用这些修改之一或两者,系统可以显著改进扩散神经网络的训练,从而改进训练后的神经网络的操作。可选地,对于扩散神经网络的输入也可以在被扩散神经网络处理之前被归一化,减轻缩放因子对由模型处理的输入的方差的影响并进一步改进训练质量。
在附图和下面的描述中阐述了本说明书的主题的一个或多个实施例的细节。本主题的其他特征、方面和优点将从说明书、附图和权利要求中变得显而易见。
附图说明
图1是示例数据生成系统的图。
图2是用于训练扩散神经网络的示例过程的流程图。
图3示出了一组带噪声的图像的示例。
图4是用于使用经过训练的扩散神经网络生成最终网络输出的示例过程的流程图。
图5示出了所描述的技术在图像生成任务上的性能的示例。
各个附图中相同的附图标号和标记指示相同的元件。
具体实施方式
本说明书描述了一种在一个或多个位置的一个或多个计算机上实现为计算机程序的系统,该系统训练扩散神经网络以用于生成网络输出。
可选地,任何给定网络输出的生成可以以条件输入为条件。一般而言,条件输入表征网络输出的一个或多个期望属性,即表征由系统生成的最终网络输出应具有的一个或多个属性。
该系统可以被配置为以无条件的方式或者以多种条件输入中的任何一种为条件生成多种网络输出中的任何一种。
例如,系统可以被配置为生成音频数据,例如音频波形或音频频谱图,例如梅尔频谱图或频率处于不同标度的音频的频谱图。
在此示例中,条件输入可以是音频应当表示的文本或文本的特征,即,使得系统充当文本到语音的机器学习模型,其将文本或文本的特征转换为用于该文本被说出的话语的音频数据。
作为另一个示例,条件输入可以识别音频的期望说话者,即,使得系统生成表示期望说话者的语音的音频数据。
作为另一个示例,条件输入可以表征歌曲或其他音乐片段的属性,例如歌词、流派(genre)等,使得系统生成具有由条件输入表征的属性的音乐片段。
作为另一个示例,条件输入可以指定将音频数据分类为来自一组可能类中的类,使得系统生成属于该类的音频数据。例如,这些类可以表示乐器或其他音频发射设备的类型,即,使得系统生成由对应类、动物类型所发出的音频,即,使得系统生成表示由对应动物生成的噪声等的音频。
作为另一特定示例,网络输出可以是图像,使得系统可以通过生成图像的像素的强度值来执行有条件的图像生成。
在该特定示例中,条件输入可以是文本序列,并且网络输出可以是描述文本的图像,即,条件输入可以是输出图像的标题。
作为又一特定示例,条件输入可以是指定一个或多个边界框以及可选地指定应当在每个边界框中描绘的相应类型的对象的对象检测输入。
作为又一特定示例,条件输入可以从多个对象类中指定输出图像中描绘的对象应该属于的对象类。
作为又一特定示例,条件输入可以指定第一分辨率的图像,并且网络输出可以包括第二更高分辨率的图像。
作为又一特定示例,条件输入可以指定图像,并且网络输出可以包括图像的去噪版本。
作为又一特定示例,条件输入可以指定包括用于检测的目标实体——例如肿瘤——的图像,并且网络输出可以包括没有目标实体的图像,例如以通过比较图像来促进目标实体的检测。
作为又一特定示例,条件输入可以是分割,其将输出图像的多个像素中的每一个像素分配给一组类别中的一个类别,例如,将所述类别中的相应一个类别分配给每个像素。
更一般地,任务可以是输出以条件输入为条件的连续数据的任何任务。例如,输出可以是不同传感器的输出,例如激光雷达点云、雷达点云、心电图读数等,并且条件输入可以表示应由传感器测量的数据的类型。当期望离散输出时,可以例如通过阈值化获得该输出。
在上述示例中的任何示例中,使用扩散神经网络生成的网络输出可以是输出空间中的网络输出,即,使得网络输出中的值是适当类型的网络输出的值,例如,图像像素的值、音频信号的幅度值等,或者潜在空间中的网络输出,即,使得网络输出中的值是输出空间中的网络输出的潜在表示中的值。
当在潜在空间中生成网络输出时,系统可以通过使用解码器神经网络——例如,已经在自动编码器框架中预训练的解码器神经网络——处理潜在空间中的网络输出来生成像素空间中的最终网络输出。在训练期间,系统可以使用编码器神经网络——例如,已经在自动编码器框架中与解码器联合预训练的编码器神经网络——来对目标网络输出进行编码,以生成扩散神经网络的目标输出。
图1是示例数据生成系统100的图。数据生成系统100是在一个或多个位置的一个或多个计算机上实现为计算机程序的系统的示例,其中可以实现下面描述的系统、组件和技术。
系统100获得条件输入102并使用条件输入102来生成网络输出112,该网络输出112具有由条件输入102表征的一个或多个期望属性。
具体地,为了生成网络输出112,系统100使用扩散神经网络110通过执行反向扩散过程来跨多个更新迭代生成网络输出112。
扩散神经网络110可以是任何适当的扩散神经网络,其已经例如由系统100或另一训练系统训练,以在任何给定的更新迭代处处理包括当前网络输出(截至更新迭代)的用于更新迭代的扩散输入以生成用于更新迭代的扩散输出。
例如,扩散神经网络110可以是具有多个卷积层块的卷积神经网络,例如U-Net。在这些情况下,扩散神经网络110可以包括散布在卷积层块当中的一个或多个注意力层块。如下文将描述的,一些或所有注意力块可以以条件输入102的表示为条件。
作为另一个示例,扩散神经网络110可以是循环接口网络(RIN)。循环接口网络是一种包括一系列神经网络块的神经网络,每个神经网络块对从对于神经网络的输入导出的一组接口向量进行更新。具体地,每个块使用一组潜在向量来更新该组接口向量,其中,该组中的潜在向量的数量独立于该组接口向量中的接口向量的数量。特别地,该组中的潜在向量的数量通常小于该组中的接口向量的数量。在下文中更详细地描述了循环接口网络:在arXiv:2212.11972处可获得的Scalable Adaptive Computation for IterativeGeneration(用于迭代生成的可缩放自适应计算);以及,在12/22/2023提交的申请号PCT/US2023/085784,RECURRENT INTERFACE NETWORKS(循环接口网络),它们的全部内容为在此通过引用整体并入本文。
作为另一个示例,扩散神经网络110可以是Transformer神经网络,其通过一组自注意力层处理扩散输入以生成去噪输出。
神经网络110可以以多种方式中的任一种以条件输入102为条件。
作为一个示例,系统100可以使用编码器神经网络来生成表示条件输入102的一个或多个嵌入,并且扩散神经网络110可以包括一个或多个交叉注意力层,每个交叉注意力层交叉关注到一个或多个嵌入。
如本说明书中所使用的,嵌入是数值的有序集合,例如浮点值或其他类型的值的向量。
例如,当条件输入是文本时,系统可以使用文本编码器神经网络,例如,Transformer神经网络,来生成表示条件输入的固定或可变数量的文本嵌入。
当条件输入是图像时,系统可以使用图像编码器神经网络,例如,卷积神经网络或视觉Transformer神经网络,来生成表示图像的一组嵌入。
当条件输入是音频时,系统可以使用例如音频编码器神经网络,例如已经作为神经音频编解码器的一部分与解码器神经网络联合训练的音频编码器神经网络,以生成对音频进行编码的一个或多个嵌入。
当条件输入是标量值时,系统可以使用例如嵌入矩阵来将标量值或标量值的独热表示映射到嵌入。
在一些实施方式中,扩散输出是对当前网络输出的噪声分量的估计,即该噪声分量为需要与最终网络输出即由系统100生成的网络输出112相组合(例如被添加或被减去)以生成当前网络输出的噪声。
在一些其他实施方式中,扩散输出是给定当前网络输出的情况下的最终网络输出的估计,即,通过去除当前网络输出的噪声分量而产生的网络输出的估计。
例如,扩散神经网络110可以已经使用去噪分数匹配目标在一组训练网络输出上进行训练以生成扩散输出。
下面参考图2和3更详细地描述训练扩散神经网络110。
在每次更新迭代时,系统100使用由扩散神经网络110生成的扩散输出来更新截至更新迭代时的当前网络输出。
在最后一次更新迭代之后,系统100输出当前网络输出作为最终网络输出112。
例如,系统100可以提供网络输出112以在用户计算机上向用户呈现或回放或者存储网络输出112以供以后使用。
在一些实施方式中,扩散神经网络110是系统100用来生成最终网络输出的扩散神经网络序列(例如扩散神经网络的层级或级联)中的一个。例如,该序列中的每个扩散神经网络可以接收由该序列中的前一个扩散神经网络生成的网络输出作为输入,并生成相对于该序列中的前一个扩散神经网络具有增加的分辨率——例如,增加的空间分辨率、增加的时间分辨率、或两者——的网络输出。在这些实施方式中,该序列中的所有神经网络可以接收条件输入102,或者仅该序列中的扩散神经网络的真子集可以接收条件输入102,例如,在仅该序列中的一个或多个最早位置处的扩散神经网络。
一般而言,系统100可以修改以下各项中的一个或多个:对扩散神经网络110的训练、对扩散神经网络110的输入、或者在训练之后如何使用扩散神经网络110来生成网络输出以提高由扩散神经网络110生成的网络输出的质量。
作为一个示例,系统100可以修改在扩散神经网络110的训练期间使用的噪声调度,以便改进由扩散神经网络110在训练之后生成的网络输出的质量。
这在下面参考图2更详细地描述。
作为另一个示例,在训练之后并且在每次更新迭代时,系统100可以修改网络输出的当前版本,以便改进在最后一次更新迭代之后生成的最终网络输出的质量。
这在下面参考图4更详细地描述。
图2是用于训练扩散神经网络使得扩散神经网络可以以可变数量的上下文网络输出为条件的示例过程200的流程图。为了方便起见,过程200将被描述为由位于一个或多个位置的一个或多个计算机的系统执行。例如,根据本说明书适当地编程的数据生成系统——例如图1中描绘的数据生成系统100——可以执行过程200。
系统可以重复地执行过程200的迭代以便训练扩散神经网络。
系统获得一组一个或多个训练网络输出(步骤202)。可选地,训练网络输出中的一些或全部可以与对应的条件输入相关联。例如,系统可以从更大的一组训练网络输出——即,从用于训练扩散神经网络的一组训练数据——对训练网络输出进行采样。
然后系统对所述训练网络输出中的每个训练网络输出执行步骤204-210。
系统通过从在时间步长分布的下限和上限之间的时间步长上的时间步长分布进行采样来采样时间步长(步骤204)。例如,时间步长分布可以是在零和一(包括零和一)之间的间隔上的连续均匀分布。也就是说,时间步长具有在零和一(包括零和一)之间的值。
系统生成新的噪声分量(步骤206)。噪声分量通常具有与训练网络输出相同的维度,但具有噪声值。例如,系统可以通过从例如正态分布的指定噪声分布采样新的噪声分量中的每个值来生成新的噪声分量。
系统通过下述方式来生成新的噪声网络输出:根据不等于1的缩放因子和取决于采样时间步长的噪声调度来组合训练网络输出和新的噪声分量(步骤208)。
也就是说,噪声调度(noise schedule)是一个函数,该函数将对于该函数的输入映射到定义用于训练网络输出和新的噪声分量的相应权重的输出。然后,系统例如通过计算训练网络输出和新的噪声分量的加权总和而根据相应的权重组合训练网络输出,以生成新的噪声网络输出。
例如,系统可以将训练网络输出和新的噪声分量组合如下:
其中,∈是噪声分量,b是缩放因子,γ(t)是噪声水平,即采样时间步长t的噪声调度的输出,并且x0是训练网络输入。
图3中的示例可以证明噪声水平和噪声调度的重要性。
图3示出了一组带噪声的图像a)至e)的示例300。每个带噪声的图像具有不同的分辨率,并且已经通过将(i)相同的真实值图像(下采样到对应的分辨率)与(i)根据相同噪声水平γ=.7从相同噪声分布采样的噪声分量相组合而生成。从图3可以看出,随着图像大小的增加,相同噪声水平(即相同γ)下的去噪任务变得更简单。这是因为下述事实:数据中的信息的冗余(例如,附近像素当中的相关性)通常随着图像大小的增加而增加。此外,噪声被独立地添加到每个像素,使得当图像大小增加时更容易恢复原始信号。因此,较小分辨率下的最优调度在较高分辨率下可能不是最优的。例如,较高的分辨率(较大的输出)可能在训练期间需要较高的噪声水平,以便有效地训练神经网络。因此,如果不相应地调整噪声调度,可能会导致某些噪声水平的训练不足,并可能损害所训练的扩散神经网络的性能。
利用缩放因子和噪声调度可以解决该现象并改进训练后神经网络的性能。
特别是,通过减小缩放因子(减小到小于1的数字),噪声网络输出中的噪声水平会增加。因此,考虑到较高分辨率的任务需要较高的噪声水平,当输出分辨率更高时,缩放因子可以被设置为较小的值。例如,缩放因子可被设置为.1、.2、.3、.4、.5、.6、.7、.8或.9之一。
系统可以与缩放因子组合地使用多种噪声调度中的任何一种。
作为一个示例,噪声调度可以是采样时间步长的一维函数。
例如,噪声调度可以是采样时间步长的线性函数,并且更具体地,γ(t)=1-t。使用这个噪声调度表可以确保训练覆盖训练期间的所有噪声水平,从而改进训练后扩散神经网络的性能。
作为另一个示例,噪声调度可以是余弦调度或S形调度。
系统使用扩散神经网络处理新的扩散输入,该新的扩散输入包括(i)新的噪声网络输出、以及(ii)指定采样时间步长的数据,以生成新的扩散输出,该新的扩散输出定义用于采样时间步长的新的噪声分量的估计(步骤210)。
如上所述,在一些实施方式中,扩散输出是对新的噪声网络输出的噪声分量的估计。
在一些其他实施方式中,扩散输出是给定新的噪声网络输出的情况下的对训练网络输出的估计,即,对通过去除新的噪声网络输出的噪声分量而会产生的网络输出的估计。
当训练网络输出与条件输入相关联时,新的扩散输入还包括条件输入的表示。
新的扩散输入还可以包括其他数据,例如噪声水平、采样时间步长或两者的表示。
可选地,作为处理新的扩散输入的一部分,系统可以在由扩散神经网络的输入层处理新的扩散输入之前对新的噪声网络输出进行归一化。例如,系统可以通过新的噪声网络输出的方差来归一化新的噪声网络输出。也就是说,系统可以计算新的噪声网络输出内的值的方差,然后将每个值除以方差。执行此归一化可以减轻缩放因子b对新噪声网络输出的方差的影响。也就是说,在不执行归一化的情况下,由于缩放因子的应用,即使当前网络输出和噪声分量具有相同的方差,新的噪声网络输出也可以具有与噪声分量和当前网络输出不同的方差,这可能会降低训练过程的有效性。对新的噪声网络输出进行归一化确保了其在被扩散神经网络处理之前具有单位方差,从而减轻缩放因子对方差的影响。
然后系统在目标上训练扩散神经网络(步骤212)。
对于每个训练网络输出,目标测量以下两项之间的误差:(i)通过处理用于训练网络输出的对应新扩散输入生成的采样时间步的新的噪声分量的估计、以及(ii)用于训练网络输出的采样时间步长的新的噪声分量。作为特定示例,该目标可以是误差的平均值或总和,或者可以包括作为误差的平均值或总和的第一项以及一个或多个其他项,例如正则化项、辅助损失项等。
例如,当新的扩散输出表示训练网络输出的预测时,误差的一个示例可以是:
||f(xt)-x0||2
其中,f(xt)是新的扩散输出。
作为另一个示例,当新的扩散输出表示噪声分量的预测时,误差的一个示例可以是:
||f(xt)-∈||2
其中,f(xt)是新的扩散输出。
为了在目标上训练扩散神经网络,系统可以例如通过反向传播来计算目标相对于扩散神经网络的参数的梯度,然后通过向梯度应用优化器,例如Adam优化器、AdamW优化器、Adafactor优化器、学习优化器等,来更新参数。
图4是用于使用所训练的扩散神经网络来生成最终网络输出的示例过程400的流程图。为了方便起见,过程400将被描述为由位于一个或多个位置的一个或多个计算机的系统执行。例如,根据本说明书适当编程的数据生成系统,例如,图1中描绘的数据生成系统100,可以执行过程400。
系统获得条件输入(步骤402)。
系统初始化网络输出(步骤404)。
通常,初始化的网络输出是与最终网络输出相同的维度,但是具有噪声值。也就是说,初始化的网络输出具有与最终网络输出相同数量的元素。
例如,系统可以通过从对应的噪声分布——例如,正态分布或不同的噪声分布——采样网络输出中的每个元素的值来初始化网络输出,即,可以生成网络输出的第一实例。也就是说,网络输出包括多个元素,并且初始网络输出包括相同数量的元素,其中,从对应的噪声分布采样每个元素的值。
然后,系统通过在多次更新迭代中的每次更新迭代更新网络输出来生成最终网络输出。换句话说,最终网络输出是多次更新迭代中的最后一次迭代之后的网络输出。
在一些情况下,迭代次数是固定的。在其他情况下,系统或另一系统可以基于生成最终网络输出的时延要求来调整迭代次数,即,可以选择迭代次数,使得最终网络输出将被生成以满足时延要求。在其他情况下,系统或另一系统可以基于用于生成最终网络输出的计算资源消耗要求来调整迭代次数,即,可以选择迭代次数,使得最终网络输出将被生成以满足该要求。例如,该要求可以是作为生成最终网络输出的一部分要执行的最大浮动运算数量(FLOPS)。
通常,系统跨更新迭代执行反向扩散过程,以在每次迭代时更新当前网络输出。每个更新迭代对应于时间间隔中的不同时间步长——例如,零与一之间的间隔——或对应于另一适当的时间间隔。例如,每个不同的时间步长可以对应于时间间隔的均匀离散化的不同点或对应于时间间隔的不同的非均匀离散化的不同点。
特别地,在每次更新迭代时,系统执行步骤406-412以更新截至该更新迭代的当前网络输出。
系统归一化截至该更新迭代的当前网络输出(步骤406)。对于第一次更新迭代,当前网络输出是初始化的网络输出。对于每个后续更新迭代,当前网络输出是来自先前更新迭代的更新网络输出。
例如,系统可以使用如上所述的当前网络输出的方差来归一化当前网络输出。
系统使用扩散神经网络处理用于更新迭代的、包括当前网络输出和条件输入的表示的第一扩散输入,以生成用于更新迭代的第一扩散输出(步骤408)。
例如,在第一次更新迭代之前,系统可以使用嵌入神经网络来处理条件输入以生成该条件输入的一个或多个嵌入。
然后,用于任何给定更新迭代的第一扩散输入可以包括条件输入的一个或多个嵌入。
第一扩散输入还可以包括以下各项中的一个或多个:标识更新迭代的数据、表征在网络输出生成期间用作上下文的一个或多个上下文网络输出的数据、所生成的网络输出的一个或多个属性的标量值等。
可选地,即,当使用无分类器的指导时,系统还可以处理用于更新迭代的一个或多个附加扩散输入,以针对每个附加扩散输入生成用于更新迭代的相应附加扩散输出(步骤410)。
每个附加扩散输入还包括截至更新迭代的当前网络输出,但包括不同的条件输入。
例如,附加扩散输入之一可以是无条件的扩散输入,其包括已经被指定为指示应当无条件地生成网络输出的条件输入的表示。
作为另一示例,附加扩散输入之一可以是负扩散输入,其包括指示所生成的网络输出不应当具有的属性的负条件输入的表示。
也就是说,系统还可以接收指示所生成的网络输出不应当具有的属性的负条件输入,并且可以在负扩散输入中包括负条件输入的表示,例如,从负条件输入生成的一个或多个嵌入。
系统根据第一扩散输出以及在生成时的附加扩散输出来确定用于更新迭代的最终扩散输出(步骤412)。
当没有生成附加扩散输出时,系统可以将最终扩散输出设置为等于第一扩散输出。
当生成一个或多个附加扩散输出时,系统可以根据用于更新迭代的引导权重w来组合第一扩散输出和最终扩散输出。
例如,系统可以将最终扩散输出设置为等于(1+w)*第一扩散输出–w*附加扩散输出,或者当存在多个附加扩散输出时,附加扩散输出的总和。
系统然后使用最终扩散输出更新当前网络输出(步骤414)。
例如,系统可以根据最终扩散输出计算最终网络输出的初始估计,然后使用最终网络输出的初始估计来更新当前网络输出。
例如,当扩散输出是最终网络输出的估计时,系统可以使用最终扩散输出作为噪声分量的初始估计。
当扩散输出是噪声分量的估计时,系统可以使用最终扩散输出来计算最终网络输出的初始估计,例如,如下:
其中,是最终扩散输出,t是对应于更新迭代的时间步长,并且γ(t)是作为用于推断的噪声调度的输出的噪声水平。注意,与用于训练的噪声调度相比,不同的噪声调度可以用于推断。例如,噪声调度可以是用于训练的1-t调度和用于推断的余弦调度。
对于最后更新迭代,系统可以使用初始估计作为更新的网络输出。
对于除了最后更新迭代之外的每个更新迭代,系统可以将适当的扩散采样器应用于初始估计以生成更新的网络输出。
图5示出了在三个图像生成任务上的所描述的技术的性能的示例500,一个需要生成64x64个图像,一个需要生成128x128个图像,并且一个需要生成256x256个图像。
特别地,图5示出了FID方面的结果,其中,对于输入标度因子的各种值、对于各种图像分辨率以及对于1-t调度和余弦调度,较低的分数较好。
从图5可以看出,随着图像分辨率增加,最优输入缩放因子减小,即,对于除了最小分辨率之外的所有分辨率,最优输入缩放因子小于1。
此外,图5还示出,在给定对应分辨率的最优输入缩放因子的情况下,1-t调度通常比余弦调度表现得更好。
本说明书结合系统和计算机程序组件使用术语“被配置”。对于要被配置为执行特定操作或动作的一个或多个计算机的系统,意味着系统已经在其上安装了软件、固件、硬件或它们的组合,其在操作中使系统执行所述操作或动作。对于要被配置为执行特定操作或动作的一个或多个计算机程序,意味着所述一个或多个程序包括指令,所述指令在由数据处理装置执行时使得所述装置执行操作或动作。
在本说明书中描述的主题和功能操作的实施例可以在数字电子电路系统中、在有形地体现的计算机软件或固件中、在计算机硬件中——包括在本说明书中公开的结构及其结构等同物——或在它们中的一个或多个的组合中实现。本说明书中描述的主题的实施例可以被实现为一个或多个计算机程序,例如,编码在有形非暂时性存储介质上的计算机程序指令的一个或多个模块,用于由数据处理装置执行或控制数据处理装置的操作。计算机存储介质可以是机器可读存储设备、机器可读存储基板、随机或串行访问存储器设备或它们中的一个或多个的组合。可替代地或附加地,程序指令可以被编码在人工生成的传播信号——例如,机器生成的电信号、光信号或电磁信号——上,该传播信号被生成以编码用于传输到合适的接收器装置以供数据处理装置执行的信息。
术语“数据处理装置”是指数据处理硬件,并且涵盖用于处理数据的所有种类的装置、设备和机器,包括例如可编程处理器、计算机或多个处理器或计算机。该装置还可以是或进一步包括专用逻辑电路系统,例如FPGA(现场可编程门阵列)或ASIC(专用集成电路)。除了硬件之外,该装置还可以可选地包括创建用于计算机程序的执行环境的代码,例如,构成处理器固件、协议栈、数据库管理系统、操作系统或它们中的一个或多个的组合的代码。
也可以被称为或描述为程序、软件、软件应用、应用、模块、软件模块、脚本或代码的计算机程序可以用任何形式的编程语言编写,包括编译或解释语言、或声明或过程语言;并且它可以以任何形式部署,包括作为独立程序或作为模块、组件、子例程、或适合于在计算环境中使用的其他单元。程序可以但不必对应于文件系统中的文件。程序可以存储在保存其他程序或数据的文件的一部分中,例如存储在标记语言文档中的一个或多个脚本,存储在专用于所讨论的程序的单个文件中,或者存储在多个协调文件中,例如存储一个或多个模块、子程序或代码部分的文件。可以部署计算机程序以在位于一个站点或跨多个站点分布并通过数据通信网络互连的一个计算机或多个计算机上执行。
在本说明书中,术语“数据库”广泛地用于指代任何数据集合:数据不需要以任何特定方式结构化或根本不需要结构化,并且它可以存储在一个或多个位置中的存储设备上。因此,例如,索引数据库可以包括多个数据集合,每个数据集合可以被不同地组织和访问。
类似地,在本说明书中,术语“引擎”广泛地用于指代被编程为执行一个或多个特定功能的基于软件的系统、子系统或过程。通常,引擎将被实现为安装在一个或多个位置中的一个或多个计算机上的一个或多个软件模块或组件。在一些情况下,一个或多个计算机将专用于特定引擎;在其他情况下,可以在相同的一个或多个计算机上安装和运行多个引擎。
本说明书中描述的过程和逻辑流程可以由执行一个或多个计算机程序的一个或多个可编程计算机执行,以通过对输入数据进行操作并生成输出来执行功能。过程和逻辑流程还可以由例如FPGA或ASIC的专用逻辑电路系统或由专用逻辑电路系统和一个或多个编程计算机的组合来执行。
适于执行计算机程序的计算机可以基于通用或专用微处理器或两者,或任何其他类型的中央处理单元。通常,中央处理单元将从只读存储器或随机存取存储器或两者接收指令和数据。计算机的基本元件是用于执行或实施指令的中央处理单元以及用于存储指令和数据的一个或多个存储器设备。中央处理单元和存储器可以由专用逻辑电路系统补充或并入专用逻辑电路系统中。通常,计算机还将包括或可操作地耦合以从用于存储数据的一个或多个大容量存储设备——例如磁盘、磁光盘或光盘——接收数据或向其传送数据或两者都有。然而,计算机不需要具有这样的设备。此外,计算机可以被嵌入在另一设备中,例如移动电话、个人数字助理(PDA)、移动音频或视频播放器、游戏控制台、全球定位系统(GPS)接收器或便携式存储设备,例如通用串行总线(USB)快闪驱动器,仅举几例。
适合于存储计算机程序指令和数据的计算机可读介质包括所有形式的非易失性存储器、介质和存储器设备,包括例如半导体存储器设备,例如EPROM、EEPROM和闪存设备;磁盘,例如内部硬盘或可移动盘;磁光盘;以及CD ROM和DVD-ROM盘。
为了提供与用户的交互,本说明书中描述的主题的实施例可以在计算机上实现,该计算机具有用于向用户显示信息的显示设备,例如CRT(阴极射线管)或LCD(液晶显示器)监视器,以及用户可以通过其向计算机提供输入的键盘和指点设备,例如鼠标或轨迹球。其他种类的设备也可以用于提供与用户的交互;例如,提供给用户的反馈可以是任何形式的感觉反馈,例如视觉反馈、听觉反馈或触觉反馈;并且可以以任何形式接收来自用户的输入,包括声学、语音或触觉输入。另外,计算机可以通过向用户使用的设备发送文档和从用户使用的设备接收文档来与用户交互;例如,通过响应于从用户的设备上的web浏览器接收到的请求而向web浏览器发送网页。此外,计算机可以通过向个人设备——例如正在运行消息传送应用的智能电话——发送文本消息或其他形式的消息并且作为回报从用户接收响应消息来与用户交互。
用于实现机器学习模型的数据处理装置还可以包括例如专用硬件加速器单元,用于处理机器学习训练或生产的公共和计算密集型部分,例如推理、工作负载。
可以使用机器学习框架——例如TensorFlow框架或Jax框架——来实现和部署机器学习模型。
本说明书中描述的主题的实施例可以在计算系统中实现,该计算系统包括后端组件,例如,作为数据服务器,或者包括中间件组件,例如,应用服务器,或者包括前端组件,例如,具有用户可以通过其与本说明书中描述的主题的实施方式进行交互的图形用户界面、web浏览器或应用的客户端计算机,或者一个或多个这样的后端、中间件或前端组件的任何组合。系统的组件可以通过例如通信网络的任何形式或介质的数字数据通信互连。通信网络的示例包括局域网(LAN)和广域网(WAN),例如因特网。
计算系统可以包括客户端和服务器。客户端和服务器通常彼此远离,并且通常通过通信网络进行交互。客户端和服务器的关系通过在相应计算机上运行并且彼此具有客户端-服务器关系的计算机程序而产生。在一些实施例中,服务器将例如HTML页面的数据传输到用户设备,例如,用于向充当客户端的设备交互的用户显示数据和从其接收用户输入。可以在服务器处从设备接收在用户设备处生成的数据,例如用户交互的结果。
虽然本说明书包含许多特定的实施方式细节,但是这些不应当被解释为对任何发明的范围或对可以要求保护的范围的限制,而是作为可以特定于特定发明的特定实施例的特征的描述。在本说明书中在分开的实施例的上下文中描述的某些特征也可以在单个实施例中组合实现。相反,在单个实施例的上下文中描述的各种特征也可以分开地或以任何合适的子组合在多个实施例中实现。此外,尽管以上可以将特征描述为以某些组合起作用并且甚至最初如此要求保护,但是在一些情况下,来自要求保护的组合的一个或多个特征可以从组合中去除,并且要求保护的组合可以涉及子组合或子组合的变型。
类似地,虽然以特定次序在附图中描绘并且在权利要求中叙述了操作,但是这不应当被理解为要求以所示的特定次序或以顺序次序执行这样的操作,或者执行所有所示的操作以实现期望的结果。在某些情况下,多任务和并行处理可能是有利的。此外,上述实施例中的各种系统模块和组件的分离不应被理解为在所有实施例中都需要这样的分离,并且应当理解,所描述的程序组件和系统通常可以一起集成在单个软件产品中或被封装到多个软件产品中。
已经描述了主题的特定实施例。其他实施例在所附权利要求书的范围内。例如,权利要求中记载的动作可以以不同的次序执行并且仍然实现期望的结果。作为一个示例,附图中描绘的过程不一定需要所示的特定次序或顺序次序来实现期望的结果。在一些情况下,多任务和并行处理可能是有利的。
Claims (12)
1.一种训练扩散神经网络的方法,所述方法包括:
获得一组一个或多个训练网络输出;
对于每个训练网络输出:
通过从在时间步长分布的下限和上限之间的时间步长上的所述时间步长分布进行采样,来采样时间步长;
生成新的噪声分量;
通过根据不等于1的缩放因子和取决于所采样的时间步长的噪声调度来组合所述训练网络输出和所述新的噪声分量,来生成新的噪声网络输出;
使用所述扩散神经网络处理包括(i)所述新的噪声网络输出和(ii)指定所采样的时间步长的数据的新的扩散输入,以生成定义所采样的时间步长的所述新的噪声分量的估计的新的扩散输出;以及
在目标上训练所述扩散神经网络,所述目标针对每个训练网络输出测量如下两项之间的误差:通过对包括从所述训练网络输出生成的所述新的噪声网络输出的所述新的扩散输入进行处理而生成的所采样的时间步长的所述新的噪声分量的所述估计、以及所采样的时间步长的所述新的噪声分量。
2.根据权利要求1所述的方法,其中,使用所述扩散神经网络处理包括(i)所述新的噪声网络输出和(ii)指定所采样的时间步长的数据的新的扩散输入,以生成定义所采样的时间步长的所述新的噪声分量的估计的新的扩散输出,包括:
通过所述新的噪声网络输出的方差来归一化所述新的噪声网络输出。
3.根据权利要求1所述的方法,其中,通过根据不等于1的缩放因子和取决于所采样的时间步长的噪声调度来组合所述训练网络输出和所述新的噪声分量来生成新的噪声网络输出包括生成满足以下项的新的噪声网络输出xt:
其中,∈是所述噪声分量,b是所述缩放因子,γ(t)是所采样的时间步长t的所述噪声调度的所述输出,并且x0是所述训练网络输入。
4.根据权利要求3所述的方法,其中,γ(t)=1-t。
5.根据权利要求1所述的方法,其中,所述噪声调度是余弦调度或S形调度。
6.根据权利要求1所述的方法,其中,所述一个或多个训练网络输出是图像。
7.根据权利要求1所述的方法,其中,每个训练网络输出与条件输入相关联,并且其中,所述新的扩散输入包括与所述训练网络输出相关联的所述条件输入的表示。
8.根据权利要求7所述的方法,其中,所述条件输入是文本提示。
9.根据权利要求1-8中任一项所述的方法,进一步包括:
在所述训练之后,使用所训练的扩散神经网络来生成新的网络输出,包括在多个迭代中的每个迭代处:
生成用于该迭代的最终扩散输出,包括使用所述扩散神经网络处理包括截至该迭代的当前网络输出的第一扩散输入以生成第一扩散输出,所述处理包括归一化当前网络输出;以及
使用该迭代的所述最终扩散输出来更新当前网络输出。
10.根据权利要求9所述的方法,其中,归一化当前网络输出包括基于当前网络输出的方差来归一化当前网络输出。
11.一种系统,包括:
一个或多个计算机;以及
通信地耦合到所述一个或多个计算机的一个或多个存储设备,其中,所述一个或多个存储设备存储指令,所述指令在由所述一个或多个计算机执行时使所述一个或多个计算机执行根据权利要求1-10中任一项所述的方法的操作。
12.存储指令的一个或多个非暂时性计算机存储介质,所述指令在由一个或多个计算机执行时使所述一个或多个计算机执行根据权利要求1-10中任一项所述的方法的操作。
Applications Claiming Priority (2)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
US202363441417P | 2023-01-26 | 2023-01-26 | |
US63/441,417 | 2023-01-26 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN117910533A true CN117910533A (zh) | 2024-04-19 |
Family
ID=89767010
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202410118622.8A Pending CN117910533A (zh) | 2023-01-26 | 2024-01-26 | 用于扩散神经网络的噪声调度 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN117910533A (zh) |
-
2024
- 2024-01-26 CN CN202410118622.8A patent/CN117910533A/zh active Pending
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN109741736B (zh) | 使用生成对抗网络进行鲁棒语音识别的系统和方法 | |
CN112289342B (zh) | 使用神经网络生成音频 | |
KR102213013B1 (ko) | 신경망을 이용한 주파수 기반 오디오 분석 | |
CN111386536B (zh) | 语义一致的图像样式转换的方法和系统 | |
US20210224578A1 (en) | Classifying input examples using a comparison set | |
US20200104640A1 (en) | Committed information rate variational autoencoders | |
US11922281B2 (en) | Training machine learning models using teacher annealing | |
CN111699497B (zh) | 使用离散潜变量的序列模型的快速解码 | |
US20240127058A1 (en) | Training neural networks using priority queues | |
CN113039555B (zh) | 在视频剪辑中进行动作分类的方法、系统及存储介质 | |
CN116468070A (zh) | 使用规范化的目标输出训练神经网络 | |
CN110663049A (zh) | 神经网络优化器搜索 | |
US11755879B2 (en) | Low-pass recurrent neural network systems with memory | |
US20220129740A1 (en) | Convolutional neural networks with soft kernel selection | |
CN113826125A (zh) | 使用无监督数据增强来训练机器学习模型 | |
US11062229B1 (en) | Training latent variable machine learning models using multi-sample objectives | |
WO2023144386A1 (en) | Generating data items using off-the-shelf guided generative diffusion processes | |
US20230206030A1 (en) | Hyperparameter neural network ensembles | |
CN117910533A (zh) | 用于扩散神经网络的噪声调度 | |
CN111868752B (zh) | 神经网络层权重的连续参数化 | |
WO2020182930A1 (en) | Compressed sensing using neural networks | |
US20230325658A1 (en) | Conditional output generation through data density gradient estimation | |
WO2024138177A1 (en) | Recurrent interface networks | |
CN113298248B (zh) | 一种针对神经网络模型的处理方法、装置以及电子设备 | |
US20240038212A1 (en) | Normalizing flows with neural splines for high-quality speech synthesis |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |