CN113330462A - 使用软最近邻损失的神经网络训练 - Google Patents

使用软最近邻损失的神经网络训练 Download PDF

Info

Publication number
CN113330462A
CN113330462A CN202080010180.5A CN202080010180A CN113330462A CN 113330462 A CN113330462 A CN 113330462A CN 202080010180 A CN202080010180 A CN 202080010180A CN 113330462 A CN113330462 A CN 113330462A
Authority
CN
China
Prior art keywords
data
network
data elements
data element
neural network
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202080010180.5A
Other languages
English (en)
Inventor
杰弗里·E·欣顿
尼古拉斯·迈尔斯·维塞纳·福罗斯特
尼古拉斯·盖伊·罗伯特·帕佩尔诺特
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Google LLC
Original Assignee
Google LLC
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Google LLC filed Critical Google LLC
Publication of CN113330462A publication Critical patent/CN113330462A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/084Backpropagation, e.g. using gradient descent
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/77Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
    • G06V10/774Generating sets of training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/044Recurrent networks, e.g. Hopfield networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/0464Convolutional networks [CNN, ConvNet]
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/047Probabilistic or stochastic networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/0475Generative networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/0499Feedforward networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/088Non-supervised learning, e.g. competitive learning
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/09Supervised learning

Abstract

用于训练分类神经网络的方法、系统和装置,包括在计算机存储介质上编码的计算机程序。在一个方面,一种方法包括:对多个网络输入中的每一个网络输入:使用分类神经网络处理网络输入以生成定义网络输入的预测类别的分类输出;确定软最近邻损失,其中,软最近邻损失鼓励不同类别的网络输入的中间表示变得更加纠缠,其中,不同类别的网络输入的中间表示的纠缠表征了不同类别的网络输入的中间表示对相对于同一类别的网络输入的中间表示对的相似程度;以及使用软最近邻损失相对于分类神经网络参数的梯度来调整分类神经网络参数的当前值。

Description

使用软最近邻损失的神经网络训练
技术领域
本说明书涉及使用机器学习模型处理数据。
背景技术
机器学习模型接收输入,并基于所接收的输入生成输出,例如,预测输出。一些机器学习模型是参数模型,并且基于所接收的输入和模型的参数的值生成输出。
一些机器学习模型是采用多层模型对所接收的输入生成输出的深度模型。例如,深度神经网络是包括输出层以及一个或多个隐藏层的深度机器学习模型,每个隐藏层将非线性变换应用于所接收的输入以生成输出
发明内容
本说明书描述了被实现为在一个或多个位置的一个或多个计算机上的计算机程序的训练系统,该训练系统被配置为使用软最近邻损失(soft nearest neighbor loss)训练分类神经网络和生成神经网络。
根据第一方面,提供一种由一个或多个数据处理装置执行的用于训练分类神经网络的方法,所述方法包括:对于多个网络输入中的每一个网络输入:根据分类神经网络参数的当前值,使用所述分类神经网络处理所述网络输入,以生成定义所述网络输入的预测类别的分类输出。对于包括来自多个网络输入的第一网络输入和第二网络输入的多对网络输入中的每一对网络输入,基于以下两者之间的相应相似性度量来确定软最近邻损失:(i)所述第一网络输入的中间表示,所述第一网络输入的中间表示由所述分类神经网络的一个或多个隐藏层通过处理所述第一网络输入以生成用于所述第一网络输入的分类输出而生成,以及(ii)所述第二网络输入的中间表示,所述第二网络输入的中间表示由所述分类神经网络的一个或多个隐藏层通过处理所述第二网络输入以生成用于所述第二网络输入的分类输出而生成。所述软最近邻损失鼓励不同类别的网络输入的中间表示变得更加纠缠,其中,不同类别的网络输入的中间表示的纠缠表征了不同类别的网络输入的中间表示对相对于同一类别的网络输入的中间表示对的相似程度。使用软最近邻损失相对于所述分类神经网络参数的梯度来调整所述分类神经网络参数的当前值。
在一些实现方式中,确定软最近邻损失包括:对于多个网络输入中的每个给定网络输入:确定给定网络输入的类内变化,所述给定网络输入的类内变化表征所述给定网络输入的中间表示与属于与所述给定网络输入同一类别的所述多个网络输入中的其他网络输入的中间表示的相似程度。确定所述给定网络输入的总变化,所述给定网络输入的总变化表征所述给定网络输入的中间表示与属于任何类别的所述多个网络输入中的其他网络输入的中间表示的相似程度。确定所述给定网络输入的所述类内变化与所述总变化的比率;以及基于用于每个给定网络输入的类内变化与总变化的各个比率确定所述软最近邻损失。
在一些实现方式中,确定用于所述给定网络输入的类内变化包括确定:
Figure BDA0003172800360000021
其中,j索引所述多个网络输入的网络输入,b是所述多个网络输入中的网络输入的总数,i是所述给定网络输入的索引,yi表示所述给定网络输入的类别,yj表示对应于索引j的网络输入的类别,xi表示所述给定网络输入的中间表示,xj表示对应于索引j的网络输入的中间表示,S(·,·)是相似性度量,T是温度参数。
在一些实现方式中,确定用于所述给定网络输入的总变化包括确定:
Figure BDA0003172800360000031
其中,j索引所述多个网络输入的网络输入,b是所述多个网络输入中的网络输入的总数,i是所述给定网络输入的索引,xi表示所述给定网络输入的中间表示,xj表示对应于索引j的网络输入的中间表示,S(·,·)是相似性度量,T是温度参数。
在一些实现方式中,基于用于每个给定网络输入的类内变化和总变化的各个比率来确定软最近邻损失包括将软最近邻损失确定为:
Figure BDA0003172800360000032
其中,b是所述多个网络输入中的网络输入的总数,i索引所述给定网络输入,以及Ri表示用于对应于索引i的所述给定网络输入的类内变化和总变化的比率。
在一些实现方式中,使用软最近邻损失相对于所述分类神经网络参数的梯度来调整所述分类神经网络参数的当前值包括:使用软最近邻损失相对于所述温度参数的梯度来调整所述温度参数的当前值。
在一些实现方式中,定义所述网络输入的预测类别的分类输出包括用于多个可能类别中的每一个的相应似然性得分,其中,用于给定类别的似然性得分指示所述网络输入属于给定类别的似然性。
在一些实现方式中,该方法进一步包括基于定义每个网络输入的预测类别的相应分类输出来确定分类损失。使用所述分类损失相对于所述分类神经网络参数的梯度来调整所述分类神经网络参数的当前值。
在一些实现方式中,所述分类损失包括交叉熵损失。
在一些实现方式中,所述多对网络输入包括由来自所述多个网络输入的第一网络输入和第二不同网络输入组成的每个可能网络输入对。
在一些实现方式中,所述网络输入是图像。
在一些实现方式中,图像的类别定义在所述图像中描绘的对象的分类。
根据第二方面,提供一种由一个或多个数据处理装置执行的用于训练生成神经网络以基于真实数据元素的训练数据集合来生成合成数据元素的方法,所述方法包括:根据生成神经网络参数的当前值,使用所述生成神经网络来生成合成数据元素集合。从真实数据元素的训练数据集合中获得真实数据元素集合。对于包括第一数据元素和第二数据元素的多个数据元素对中的每一个数据元素对,基于所述第一数据元素与所述第二数据元素之间的各个相似性度量来确定软最近邻损失,所述第一数据元素和第二数据元素来自包括合成数据元素集合和真实数据元素集合的组合数据元素集合。所述软最近邻损失鼓励不同类别的数据元素变得更加纠缠。数据元素的类别定义所述数据元素是真实数据元素还是合成数据元素。不同类别的数据元素的纠缠表征了不同类别的数据元素对相对于同一类别的数据元素对的相似程度。使用所述软最近邻损失相对于所述生成神经网络参数的梯度来调整所述生成神经网络参数的当前值。
在一些实现方式中,确定软最近邻损失包括:对于来自包含合成数据元素集合和真实数据元素集合的组合数据元素集合中的每个给定数据元素:确定用于给定数据元素的类内变化,所述给定数据元素的类内变化表征所述给定数据元素与来自与所述给定数据元素属于同一类别的组合数据元素集合的其他数据元素的相似程度。确定用于所述给定数据元素的总变化,所述给定数据元素的总变化表征所述给定数据元素与来自属于任何类别的组合数据元素集合中的其他数据元素的相似程度。确定所述给定数据元素的所述类内变化与所述总变化的比率。基于用于所述每个给定数据元素的类内变化与总变化的相应比率来确定所述软最近邻损失。
在一些实现方式中,确定用于所述给定数据元素的类内变化包括确定:
Figure BDA0003172800360000051
其中,j索引来自组合数据元素集合的数据元素,b是组合数据元素集合中的数据元素的总数,i是所述给定数据元素的索引,yi表示所述给定数据元素的类别,yj表示对应于所述索引j的数据元素的类别,xi表示所述给定数据元素,xj表示对应于索引j的数据元素,S(·,·)是相似性度量,以及T是温度参数。
在一些实现方式中,确定用于给定数据元素的总变化包括确定:
Figure BDA0003172800360000052
其中,j索引来自组合数据元素集合的数据元素,b是组合数据元素集合中的数据元素的总数,i是所述给定数据元素的索引,xi表示所述给定数据元素,xj表示对应于索引j的数据元素,S(·,·)是相似性度量,S(·,·)是相似性度量以及T是温度参数。
在一些实现方式中,基于用于每个给定数据元素的类内变化和总变化的各个比率来确定软最近邻损失包括将所述软最近邻损失确定为:
Figure BDA0003172800360000053
其中,b是所述组合数据元素集合中的数据元素的总数,i索引所述给定数据元素,以及Ri表示用于对应于索引i的给定数据元素的类内变化与总变化的比率。
在一些实现方式中,使用所述软最近邻损失相对于所述生成神经网络参数的梯度来调整所述生成神经网络参数的当前值包括:使用所述软最近邻损失相对于所述温度参数的梯度来调整所述温度参数的当前值。
在一些实现方式中,所述数据元素是图像。
根据第三方面,提供一种由一个或多个数据处理装置执行的用于训练生成神经网络以基于真实数据元素的训练数据集合来生成合成数据元素的方法。所述方法包括根据生成神经网络参数的当前值,使用所述生成神经网络来生成合成数据元素集合。从真实数据元素的训练数据集合中获得真实数据元素集合。对于包括所述合成数据元素集合和所述真实数据元素集合的组合数据元素集合中的每个数据元素,使用判别器神经网络以生成所述数据元素的嵌入。对于包括第一数据元素和第二数据元素的多个数据元素对中的每一个数据元素对,基于所述第一数据元素的嵌入和所述第二数据元素的嵌入之间的各个相似性度量来确定软最近邻损失,所述第一数据元素和第二数据元素来自包括合成数据元素集合和真实数据元素集合的组合数据元素集合。所述软最近邻损失鼓励不同类别的数据元素的嵌入变得更加纠缠。数据元素的类别定义所述数据元素是真实数据元素还是合成数据元素。不同类别的数据元素的嵌入的纠缠表征了不同类别的数据元素对的嵌入与同一类别的数据元素对的嵌入的相似程度。使用所述软最近邻损失相对于所述生成神经网络参数的梯度来调整所述生成神经网络参数的当前值。
在一些实现方式中,确定软最近邻损失包括对于来自包含合成数据元素集合和真实数据元素集合的组合数据元素集合中的每个给定数据元素:确定用于给定数据元素的类内变化,所述给定数据元素的类内变化表征所述给定数据元素的嵌入与来自与所述给定数据元素属于同一类别的组合数据元素集合的其他数据元素的嵌入的相似程度。确定用于所述给定数据元素的总变化,所述给定数据元素的总变化表征所述给定数据元素的嵌入与来自属于任何类别的组合数据元素集合中的其他数据元素的嵌入的相似程度。确定所述给定数据元素的类内变化与所述总变化的比率。基于用于所述每个给定数据元素的类内变化与总变化的相应比率来确定所述软最近邻损失。
在一些实现方式中,确定用于所述给定数据元素的类内变化包括确定:
Figure BDA0003172800360000071
其中,j索引来自组合数据元素集合的数据元素,b是组合数据元素集合中的数据元素的总数,i是所述给定数据元素的索引,yi表示所述给定数据元素的类别,yj表示对应于索引j的数据元素的类别,E(xi)表示所述给定数据元素的嵌入,E(xj)表示对应于索引j的数据元素的嵌入,S(·,·)是相似性度量,以及T是温度参数。
在一些实现方式中,确定给定数据元素的总变化包括确定:
Figure BDA0003172800360000072
其中,j索引来自组合数据元素集合的数据元素,b是组合数据元素集合中的数据元素的总数,i是所述给定数据元素的索引,E(xi)表示所述给定数据元素的嵌入,E(xj)表示对应于索引j的数据元素的嵌入,S(·,·)是相似性度量,以及T是温度参数。
在一些实现方式中,基于用于每个给定数据元素的类内变化与总变化的各个比率来确定软最近邻损失包括将所述软最近邻损失确定为:
Figure BDA0003172800360000073
其中,b是所述组合数据元素集合中的数据元素的总数,i索引所述给定数据元素,以及Ri表示用于对应于索引i的给定数据元素的类内变化与总变化的比率。
在一些实现方式中,使用所述软最近邻损失相对于所述生成神经网络参数的梯度来调整所述生成神经网络参数的当前值包括:使用所述软最近邻损失相对于所述温度参数的梯度来调整所述温度参数的当前值。
在一些实现方式中,所述数据元素是图像。
在一些实现方式中,该方法进一步包括:使用所述软最近邻损失相对于所述判别器神经网络参数的梯度来调整所述判别器神经网络参数的当前值。
在一些实现方式中,调整所述判别器神经网络参数的当前值鼓励所述判别器神经网络生成纠缠较少的不同类别的数据元素的嵌入。
根据第四方面,提供一种由一个或多个数据处理装置执行的用于对数据进行分类的方法,所述方法包括:向分类神经网络提供输入数据,所述分类神经网络已经通过执行第一方面的方法进行了训练;使用所述分类神经网络对所述输入数据进行分类;以及接收来自所述分类神经网络的分类输出,所述输出指示所述输入数据的类别。
根据第五方面,提供一种由一个或多个数据处理装置执行的用于生成合成数据的方法,所述方法包括:向生成神经网络提供输入数据,所述生成神经网络已经通过执行第二或第三方面的方法进行了训练;使用所述生成神经网络,基于所述输入数据生成合成数据;以及从所述生成神经网络接收所述合成数据。
根据第六方面,提供一种系统,包括一个或多个计算机和一个或多个存储设备,所述存储设备存储指令,所述指令当由所述一个或多个计算机执行时,使所述一个或多个计算机执行所述方法的相应操作。
根据第六方面,提供一种存储指令的计算机程序产品,所述指令当由一个或多个计算机执行时,使所述一个或多个计算机执行所述方法的相应操作。
可以实现在本说明书中描述的主题的具体实施例以便实现一个或多个以下优点。
本说明书描述了一种判别训练系统,该系统使用软最近邻损失训练分类神经网络,该软最近邻损失鼓励来自不同类别的网络输入的中间表示(即,由分类神经网络的隐藏层生成)变得更加纠缠。软最近邻损失可以使网络输入的中间表示表征用来捕获提高分类精度信息的类别无关特征,从而正则化分类神经网络并提高其泛化能力,即在训练期间未使用的网络输入上实现可接受的预测精度。如果特征的值对于不同类别的网络输入可能相似,则该特征可以被称为“类别无关”。通过充当正则化器,软最近邻损失还能够通过使用更少的训练数据、更少的训练迭代或两者来训练分类网络,从而减少计算资源(例如,内存和计算功率)的消耗。
本说明书还描述了一种生成训练系统,该系统可以使用软最近邻损失训练生成神经网络,该软最近邻损失鼓励由生成神经网络生成的“合成(synthetic)”数据元素与来自训练数据集合的“真实(genuine)”数据元素变得更加纠缠。使用软最近邻损失训练生成神经网络使得生成神经网络能够生成“现实的(realistic)”合成数据元素,即,与训练数据集合中的真实数据元素具有相似特征的合成数据元素。软最近邻损失能够通过使用较少的训练数据、较少的训练迭代或两者来训练生成神经网络。因此,与不使用软最近邻损失的一些传统训练系统相比,生成训练系统在训练生成神经网络时可能消耗更少的计算资源(例如,内存和计算功率)。
在附图和以下描述中阐述了本说明书的主题的一个或多个实施例的细节。本主题的其他特征、方面和优点根据说明书、附图和权利要求将变得显而易见。
附图说明
图1包括四个面板,每个面板图示数据点集合和相关联的软最近邻损失值,该值度量不同类别的数据点的纠缠。
图2示出了示例判别训练系统。
图3示出了示例生成训练系统。
图4示出了指示在由分类神经网络中的隐藏层生成的不同类别的网络输入的中间表示的纠缠的图。
图5是用于计算分别与相应类别相关联的数据点集合的软最近邻损失的示例过程的流程图。
图6是使用软最近邻损失训练分类神经网络的示例过程的流程图。
图7是使用软最近邻损失训练生成神经网络的示例过程的流程图。
各个附图中相同的附图标记和名称表示相同的元件。
具体实现方式
如本说明书中所使用的,对于分别与相应类别相关联的数据点集合,不同类别的数据点的“纠缠(entanglement)”表征了不同类别的数据点对与同一类别的数据点对的相似程度。
数据点是指数值的有序集合,例如数值的向量或矩阵,其可以表示例如图像、音频数据段或文本的一部分。在一个示例中,数据点可以是由神经网络的一个或多个隐藏层通过处理网络输入生成的网络输入的中间表示;在另一个示例中,数据点可以是由神经网络的输出层通过处理网络输入生成的网络输入的嵌入。
数据点的类别是指数据点的标签(例如,在表示图像的数据点的情况下,数据点的类别可以指定图像中描绘的对象的分类)。数据点对之间的相似性可以使用数值相似性度量(例如,欧几里德相似性度量或余弦相似性度量)来测量。
图1包括四个面板(即面板102、104、106和108),分别图示一个数据点集合,其中,每个数据点的类别通过其形状和颜色来区分。每个面板还与度量不同类别的数据点的纠缠的软最近邻损失值相关联。面板102图示具有最高软最近邻损失(即最高纠缠)的数据点,而面板108图示具有最低软最近邻损失(即最低纠缠)的数据点。可以意识到,面板108中所示的具有低纠缠的数据点被分组为类别同质聚类,而面板102中所示的具有高纠缠的数据点未被分组。
本说明书描述了一种判别训练系统,该系统使用软最近邻损失来训练分类神经网络,该分类神经网络被配置为处理网络输入以生成预测网络输入的类别的相应分类输出。使用软最近邻损失训练分类神经网络鼓励来自不同类别的网络输入的中间表示(即,由分类神经网络的隐藏层生成)变得更加纠缠。软最近邻损失可以通过鼓励网络输入的中间表示来表征捕获提高分类精度信息的类别无关特征,从而使分类神经网络正则化。软最近邻损失可以被添加为到分类损失的附加项,该分类损失鼓励分类神经网络生成与指定网络输入的类别的目标输出相匹配的分类输出。
本说明书进一步描述了一种生成训练系统,该生成训练系统使用软最近邻损失来训练生成神经网络,该生成神经网络被配置为生成具有与来自训练数据集合中的“真实”数据元素相似特性的“合成”数据元素。使用软最近邻损失训练生成神经网络可能会鼓励合成数据元素(或其嵌入)和真实数据元素(或其嵌入)变得更加纠缠。增加合成数据元素与真实数据元素之间的纠缠可以增加合成数据元素的真实性,例如,通过增加它们的相似之处以真实化数据元素。
在下文中,更详细地描述这些特征和其他特征。
图2示出了示例判别训练系统200。判别训练系统200是被实现为在一个或多个位置中的一台或多台计算机上的计算机程序的系统的示例,其中,实现了以下系统、组件和技术。
判别训练系统200训练分类神经网络202。分类神经网络202被配置为处理网络输入204以生成定义网络输入204的预测类别的相应分类输出206。网络输入可以是任何种类的数字数据输入,例如图像数据、视频数据、音频数据或文本数据。分类输出可以包括用于多种可能类别中的每一个类别的相应得分,其中,类别的得分指示网络输入来自该类别的似然性;网络输入的预测类别可以被标识为具有最高得分的类别。
在一个示例中,网络输入可以是图像或由图像导出的特征,并且网络输入的类别可以指定图像的分类。例如,图像的类别可以指定图像是否描绘了特定类型的对象,例如车辆、行人、道路标志等。作为另一个示例,医学图像的类别可以指定在医学图像中描绘的患者的医疗状况。作为另一个示例,图像的类别可以指定图像是否描绘了不适当的(例如,令人反感的)内容。作为另一个示例,图像的类别可以指定正由图像中描绘的人执行的动作的类型(例如,坐、站、跑等)。
在另一个示例中,网络输入可以是文本序列并且网络输入的类别可以指定由网络输入表达的意图,例如执行某个动作的意图。
在另一示例中,网络输入可以是一系列音频数据样本,并且网络输入的类别可以指定对应于音频数据样本的音素或字素。
在另一个示例中,网络输入可以是互联网资源(例如,网页)或文档(或其一部分,或其提取的特征),并且网络输入的类别可以指定网络输入的主题。
分类神经网络可以具有任何合适的神经网络架构,例如前馈架构或递归架构,并且可以包括任何合适种类的神经网络层或块,例如全连接层、卷积层或残差块。通常,分类神经网络包括一个或多个隐藏层208,即在分类神经网络的架构中,在输入层之后并在输出层之前的层。输出层是指生成分类神经网络的分类输出的神经网络层,即生成指示网络输入来自每个类别的似然性的类别得分的层。
判别训练系统200在包括多个训练示例的训练数据集合上的多次训练迭代上训练分类神经网络202。每个训练示例指定:(i)网络输入,以及(ii)网络输入的目标(即实际)类别。在每次训练迭代中,判别训练系统200可以从训练数据获得(例如,采样)当前训练示例的“批”(即,集合),并且处理由每个训练示例指定的网络输入204以生成:(i)网络输入204的中间表示210,以及(ii)网络输入204的分类输出206。网络输入204的中间表示210是指由分类神经网络的一个或多个隐藏层208通过处理网络输入而生成的输出。中间表示210可以被表示为数值的有序集合,例如数值的向量或矩阵。在一个示例中,网络输入214的中间表示210可以是分类神经网络的指配隐藏层208的输出。在另一示例中,网络输入214的中间表示210可以是分类神经网络的多个指配隐藏层的输出的组合(例如,级联)。
在处理来自训练数据的当前批网络输入204之后,判别训练系统200使用以下之一或两者来更新分类神经网络的当前参数值:(i)软最近邻损失212,和(ii)分类损失214。特别地,判别训练系统200使用软最近邻损失212、分类损失214或两者相对于分类神经网络的当前参数值的梯度来更新分类神经网络的当前参数值。例如,判别训练系统200可以使用由下式给出的复合损失函数
Figure BDA0003172800360000131
的梯度来更新分类神经网络的当前参数值:
Figure BDA0003172800360000132
其中,
Figure BDA0003172800360000133
是分类损失,
Figure BDA0003172800360000134
是软最近邻损失,并且α>0是控制分类损失和软最近邻损失相对重要性的超参数。在其他示例中,复合损失函数可以包括多个软最近邻损失,每个软最近邻损失对应于由分类神经网络的相应隐藏层生成的中间输出。即,判别训练系统200可以生成每个网络输入的多个中间表示,并且复合损失函数可以包括对应于这些中间表示中的每一个中间表示的相应软最近邻损失。
判别训练系统200可以使用反向传播技术来计算软最近邻损失212和分类损失214的梯度,并且使用任何适当的梯度下降优化过程,例如RMSprop或Adam来更新分类神经网络202的当前参数值。
使用软最近邻损失212来更新分类神经网络202的当前参数值鼓励不同类别的网络输入204的中间表示210变得更纠缠。即,软最近邻损失212相对于相同类别的网络输入的中间表示对鼓励在不同类别的网络输入204的中间表示对之间增加相似性。将参考图5更详细地描述用于计算软最近邻损失212的示例技术。
使用分类损失214更新分类神经网络202的当前参数值鼓励网络输入204的分类输出206匹配由训练示例指定的目标类别。分类损失214可以是例如交叉熵损失。
使用软最近邻损失212可以提高所训练的分类神经网络从训练数据泛化到先前未见过的网络输入的能力,即可以提高分类神经网络对未用于训练分类神经网络的网络输入的精度。具体地,软最近邻损失可以通过鼓励网络输入的中间表示表征捕获提高分类精度信息的类别无关特征来正则化分类神经网络。在没有软最近邻损失的情况下,网络输入的中间表示可以形成类别同质聚类(即,大部分来自同一类别的中间表示组);如果网络输入的中间表示未在这些类别同质聚类之一中被表示,则分类输出可能不准确。软最近邻损失不鼓励形成中间表示的类别同质聚类,从而可以提高分类神经网络的泛化和鲁棒性。
除了对分类神经网络的训练进行正则化之外,使用软最近邻损失也可以促进确定由针对“测试”网络输入(即不是被包括在训练数据中的训练网络输入)的分类神经网络生成的类别预测的置信度。为了确定针对测试网络输入生成的类别预测的置信度,可以标识与测试网络输入的中间表示最相似(即,最接近)的预定义的K个训练网络输入的中间表示。然后,可以基于与测试网络输入的预测类别共享同一类别的K个最近中间表示的分数,确定用于测试网络输入的类别预测的置信度。通常,与测试网络输入的预测类别共享同一类别的K个最近中间表示的较高分数指示对测试网络输入的类别预测的较高置信度,反之亦然。
如上所述,使用软最近邻损失可以通过不鼓励形成中间表示的类别同质聚类来增加类别预测的置信度和类别预测的精度之间的相关性。当分类神经网络可能被提供异常值测试数据时这尤其重要,例如,在医疗诊断环境中(患者可能受到未知状况的困扰),或由于对抗性攻击。在对抗性攻击中,网络输入被提供给分类神经网络以试图使分类神经网络生成不准确的类别预测。通过促进对分类神经网络生成的类别预测的置信度的评估,软最近邻损失可以例如,通过使分类神经网络(以及通过扩展,计算机系统)不太容易受到恶意行为者的对抗性攻击,来提高使用分类神经网络的计算机系统的安全性。
使用软最近邻损失可以更好地处理与用于训练分类神经网络的训练数据不同的异常数据。可以通过观察隐藏层中的数据来识别不是来自训练分布的数据,因为该数据具有少于来自预测类别的正常数量的邻居。在一个示例中,这可以允许检测对抗性攻击。在另一个示例中,这可以被用来在患者状况未知时辅助医疗诊断。
图3示出了示例生成训练系统300。生成训练系统300是在一个或多个位置的一台或多台计算机上实现为计算机程序的系统的示例,其中,实现了以下所述的系统、组件和技术。
生成训练系统300使用软最近邻损失308训练生成神经网络302以生成具有与“真实”数据元素306的训练数据集合相似特征的“合成”数据元素304。数据元素可以是例如图像、文本片段或音频片段。
通常,生成神经网络302可以具有使其能够生成数据元素的任何适当的神经网络架构,并且生成神经网络302可以以多种方式中的任何一种生成数据元素。在一个示例中,为了生成合成数据元素304,生成神经网络302可以处理从潜变量空间上的概率分布采样的潜变量。在该示例中,潜变量的空间可以是例如实数集合,并且潜在空间上的概率分布可以是例如正态(0,1)概率分布。在另一个示例中,为了生成合成数据元素304,生成神经网络302可以在可能的数据元素的空间上生成概率分布,并且根据概率分布对合成数据元素进行采样。
生成训练系统300在多次训练迭代中训练生成神经网络。在每次训练迭代时,生成训练系统300使用生成神经网络302来生成当前批(集合)的合成数据元素304,并且获得(例如,采样)当前批(集合)的真实数据元素306。生成训练系统300提供当前合成数据元素304和真实数据元素306作为到判别神经网络310(其也可以被称为“判别器”神经网络)的相应输入,该判别神经网络310被配置为处理输入数据元素以生成输入数据元素的嵌入312。数据元素的嵌入是指将数据元素表示为数值的有序聚合,例如数值的向量或矩阵。判别神经网络310可以具有使得其能够生成嵌入的任何适当的神经网络架构,例如全连接或卷积神经网络架构。
在生成当前合成和真实数据元素的嵌入312之后,生成训练系统300使用软最近邻损失308来更新生成神经网络的当前参数值。特别地,生成训练系统300相对于生成神经网络的当前参数值使用软最近邻损失308的梯度来更新生成神经网络的当前参数值。生成训练系统300可以使用反向传播技术来计算软最近邻损失308的梯度,并且使用任何适当的梯度下降优化过程(例如,RMSprop或Adam)来更新生成神经网络302的当前参数值。
通常,对于多个数据元素对(例如,包括两个真实数据元素、两个合成数据元素或一个真实数据元素和一个合成数据元素的数据元素对)中的每一个数据元素对,软最近邻损失308基于该对中的数据元素的嵌入312之间的相似性的度量。使用软最近邻损失308更新生成神经网络302的当前参数值鼓励合成数据元素304和真实数据元素306的嵌入312变得更加纠缠。也就是说,如果数据元素的类别被理解为定义数据元素是合成的还是真实的,则软最近邻损失308相对于相同类别的数据元素的嵌入对鼓励不同类别的数据元素的嵌入对312之间的相似性增加。参考图5更详细地描述了用于计算软最近邻损失308的示例技术。
生成训练系统300与生成神经网络302协同地训练判别神经网络310,例如,通过在训练判别神经网络310和生成神经网络302之间交替。具体地,在多次训练迭代中的每一次,生成训练系统300使用判别神经网络以生成合成数据元素304和真实数据元素306的嵌入312。然后,生成训练系统300使用软最近邻损失308来更新判别神经网络的当前参数值,以鼓励合成数据元素304和真实数据元素306的嵌入312变得更少纠缠。
在训练过程中,生成神经网络302在生成具有与真实数据样本的嵌入更加纠缠的嵌入的合成数据样本方面不断变得更好。同时,判别神经网络不断适应以使合成数据元素的嵌入与真实数据元素的嵌入更少纠缠。生成神经网络302和判别神经网络310的对抗训练导致由生成神经网络302生成的合成数据元素304的特性越来越类似于真实数据元素306的特性。
判别神经网络310使得生成训练系统300能够评估学习嵌入空间中的合成和真实数据元素的纠缠,这可以促进生成神经网络302的更有效训练。然而,在简化的实现方式中,生成训练系统300可以避免使用判别神经网络310。在这些实现方式中,生成训练系统300可以通过使用软最近邻损失308来训练生成神经网络302,以鼓励合成数据元素变得与真实数据元素更加纠缠,即,在数据元素空间中而不是在嵌入空间中。即,在这些实现方式中,对于多个数据元素对(例如,包括两个真实数据元素、两个合成数据元素或一个真实数据元素和一个合成数据元素的数据元素对)中的每一个数据元素对,软最近邻损失是基于该对中的数据元素之间的相似性的度量。如果将数据元素的类别理解为定义数据元素是合成的还是真实的,则软最近邻损失相对于相同类别的数据元素对鼓励不同类别的数据元素对之间的相似性增加。参考图5更详细地描述了用于计算软最近邻损失的示例技术。
图4示出了指示由分类神经网络中的一系列层(特别是在CIFAR-10训练数据集合上训练的ResNet的最后一个块中的层)生成的不同类别的网络输入的中间表示的纠缠(由软最近邻损失进行测量)的图。可以意识到,在使用软最近邻损失的训练过程中,除了输出层(在图4中标记为“最后一层”)之外,每一层的中间表示的纠缠通常会增加。软最近邻损失不被应用于输出层,从而允许输出层保持判别性。
图5是用于计算分别与相应类别相关联的数据点集合的软最近邻损失的示例过程500的流程图。为方便起见,过程500将被描述为由位于一个或多个位置的一个或多个计算机的系统执行。例如,根据本说明书适当编程的训练系统,例如图2的判别训练系统200或图3的生成训练系统300可以执行过程500。
参考分别与相应类别相关联的数据点集合描述过程500。在一个示例中,每个数据点可以是由分类神经网络的隐藏层通过在训练分类神经网络期间处理来自当前一批网络输入的网络输入而生成的中间输出,并且每个数据点的类别可以是相应网络输入的类别,如参考图2所述。在另一个示例中,每个数据点可以是在训练生成神经网络期间,来自当前一批数据元素的数据元素或数据元素的嵌入,并且每个数据点的类别可以指示该数据元素是真实的或合成的(即,由生成神经网络生成),如参考图3所述。
系统为每个数据点确定相应的类内变化(502)。给定数据点的类内变化表征该给定数据点与属于与该给定数据点同一类别的其他数据点的相似程度。在一个示例中,系统可以将给定数据点的类内变化确定为:
Figure BDA0003172800360000191
其中,j索引数据点,b是数据点的总数(例如对应于当前批的数据点),i是给定数据点的索引,yi表示给定数据点的类别(例如,指示数据点i对应于真实数据元素还是合成数据元素),yj表示对应于索引j的数据点的类别(例如,指示数据点j对应于真实数据元素还是合成数据元素),pi表示给定数据点,pj表示对应于索引j的数据点,S(·,·)是相似性度量(例如,S(pi,pj)=|pi-pj|2),以及T是温度参数,其控制赋予数据点对之间的相似性的相对重要性。
系统确定每个数据点的相应总变化(504)。给定数据点的总变化表征该数据点与属于任何类别的其他数据点的相似程度。在一个示例中,系统可以将给定数据点的总变化确定为:
Figure BDA0003172800360000192
其中,j索引数据点,b是数据点的总数,i是给定数据点的索引,pi表示给定数据点,pj表示对应于索引j的数据点,S(·,·)是相似性度量(例如S(pi,pj)=|pi-pj|2),以及T是温度参数,其控制赋予数据点对之间的相似性的相对重要性。
系统基于每个数据点的类内变化与总变化的相应比率来确定软最近邻损失(506)。例如,系统可以将软最近邻损失
Figure BDA0003172800360000201
确定为:
Figure BDA0003172800360000202
其中,b是数据点的总数,i对数据点进行索引,Ri表示用于对应于索引i的数据点的类内变化与总变化的比率,以及
Figure BDA0003172800360000203
被表示为温度参数T的函数(如上所述)。
从等式(2)到(4)可以看出,软最近邻损失使用了一批b个数据点中的所有数据点。这导致中间表示变得比可能基于单个正数据点和单个负数据点的其他损失函数产生的中间表示更纠缠。
在一些实现方式中,系统可以将软最近邻损失
Figure BDA0003172800360000204
确定为所有温度下的最小值,即:
Figure BDA0003172800360000205
其中,参考等式(4)描述
Figure BDA0003172800360000206
在训练判别神经网络或生成神经网络期间(如上所述),通过将T初始化为预定值并且在T上利用梯度下降进行优化来近似(参考等式(5)所述的)软最近邻损失
Figure BDA0003172800360000207
以最小化软最近邻损失,即通过使用软最近邻损失相对于T的梯度来调整T的当前值。也就是说,在训练判别神经网络或生成神经网络期间,T的值与判别神经网络或生成神经网络的参数值联合调整以优化软最近邻损失。将软最近邻损失确定为所有温度下的最小值消除了手动地设置温度超参数值的任何需求,例如通过试错测试。
图6是用于使用软最近邻损失训练分类神经网络的示例过程600的流程图。为方便起见,过程600将被描述为由位于一个或多个位置的一台或多台计算机的系统执行。例如,根据本说明书适当编程的训练系统(例如图2的判别训练系统200)可以执行过程600。
系统使用分类神经网络(602)处理多个网络输入中的每一个网络输入。分类神经网络被配置为处理网络输入以生成定义网络输入的预测类别的分类输出。
系统确定软最近邻损失(604)。对于包括第一网络输入和第二网络输入的多对网络输入中的每对网络输入,该系统基于以下两者之间的相应相似性度量来确定软最近邻损失:(i)第一网络输入的中间表示,以及(ii)第二网络输入的中间表示。
系统使用软最近邻损失相对于分类神经网络参数的梯度来调整分类神经网络参数的当前值(606)。软最近邻损失鼓励不同类别的网络输入的中间表示变得更加纠缠。
在调整分类神经网络参数的当前值之后,系统可以返回到步骤602以执行另一次训练迭代。当满足训练终止标准时,例如,当已经执行预定数量的训练迭代时,或者当分类神经网络对验证数据集合的精度达到预定义的阈值时,系统可以确定训练完成。
图7是用于使用软最近邻损失来训练生成神经网络的示例过程700的流程图。为方便起见,过程700将被描述为由位于一个或多个位置的一台或多台计算机的系统执行。例如,根据本说明书适当编程的训练系统,例如图3的生成训练系统300可以执行过程700。
该系统使用生成神经网络来生成当前的合成数据元素集合(702)。
该系统例如通过从训练数据集合中采样预定义数量的真实数据元素,从真实数据元素的训练数据集合中获得当前真实数据元素集合(704)。
该系统使用判别神经网络生成包括当前合成数据元素集合和当前真实数据元素集合的组合数据元素集合中的每个数据元素的相应嵌入(706)。
系统确定软最近邻损失(708)。对于包括来自组合数据元素集合的第一数据元素和第二数据元素的多个数据元素对中的每一个数据元素对,该系统基于第一数据的嵌入与第二数据元素的嵌入之间的相应相似性度量来确定软最近邻损失。
系统使用软最近邻损失相对于生成神经网络参数的梯度来调整生成神经网络参数的当前值(710)。软最近邻损失鼓励合成数据元素的嵌入与真实数据元素的嵌入变得更加纠缠。
在调整了生成神经网络参数的当前值之后,系统可以返回到步骤702以执行另一次训练迭代。当满足训练终止标准时,例如当已经执行了预定义数量的训练迭代时,系统可以确定训练完成。
本说明书连同系统和计算机程序组件一起使用术语“被配置”。对于要被配置成执行特定操作或动作的一个或多个计算机的系统意指系统已在其上安装了在操作中使该系统执行这些操作或动作的软件、固件、硬件或软件、固件、硬件的组合。对于要被配置成执行特定操作或动作的一个或多个计算机程序意指该一个或多个程序包括指令,所述指令当由数据处理装置执行时,使该装置执行操作或动作。
本说明书中描述的主题和功能操作的实施例可用数字电子电路、用有形地具体实现的计算机软件或固件、用包括本说明书中公开的结构及其结构等同物的计算机硬件或者用它们中的一个或多个的组合来实现。本说明书中描述的主题的实施例可作为一个或多个计算机程序被实现,所述一个或多个计算机程序即在有形非暂时性存储介质上编码以供数据处理装置执行或者控制数据处理装置的操作的计算机程序指令的一个或多个模块。计算机存储介质可以是机器可读存储设备、机器可读存储基板、随机或串行访问存储设备或它们中的一个或多个的组合。替换地或此外,可将程序指令编码在人工生成的传播信号上,所述传播信号例如是机器生成的电、光或电磁信号,该传播信号被生成来对信息进行编码以用于传输到适合的接收器装置以供数据处理装置执行。
术语“数据处理装置”指代数据处理硬件并且包含用于处理数据的所有种类的装置、设备和机器,作为示例包括可编程处理器、计算机或多个处理器或计算机。装置还可以是或者进一步包括专用逻辑电路,例如,FPGA(现场可编程门阵列)或ASIC(专用集成电路)。装置除了包括硬件之外还可以可选地包括为计算机程序创建执行环境的代码,例如,构成处理器固件、协议栈、数据库管理系统、操作系统或它们中的一个或多个的组合的代码。
也可以被称为或者描述为程序、软件、软件应用、app、模块、软件模块、脚本或代码的计算机程序可用包括编译或解释语言或声明或过程语言的任何形式的编程语言编写;并且它可以以任何形式部署,包括作为独立程序或者作为模块、组件、子例行程序或适合于在计算环境中使用的其它单元。程序可以但是不必对应于文件系统中的文件。程序可以被存储在保持其它程序或数据的文件的一部分中,例如存储在标记语言文档中的一个或多个脚本;在专用于所述程序的单个文件中或者在多个协调文件中,例如存储代码的一个或多个模块、子程序或部分的文件。可将计算机程序部署成在一个计算机上或者在位于一个站点处或者分布在多个站点上并通过数据通信网络互连的多个计算机上执行。
在本说明书中,术语“引擎”广泛地用于指代被编程来执行一个或多个具体功能的基于软件的系统、子系统或过程。通常,引擎将作为安装在一个或多个位置中的一个或多个计算机上的一个或多个软件模块或组件被实现。在一些情况下,一个或多个计算机将专用于特定引擎;在其它情况下,可在同一计算机或多个计算机上安装并运行多个引擎。
本说明书中描述的过程和逻辑流程可由执行一个或多个计算机程序的一个或多个可编程计算机执行以通过对输入数据进行操作并生成输出来执行功能。过程和逻辑流程还可由例如是FPGA或ASIC的专用逻辑电路执行,或者通过专用逻辑电路和一个或多个编程计算机的组合来执行。
适合于执行计算机程序的计算机可基于通用微处理器或专用微处理器或两者,或任何其它种类的中央处理器。通常,中央处理单元将从只读存储器或随机存取存储器或两者接收指令和数据。计算机的必要元件是用于执行或者实行指令的中央处理单元以及用于存储指令和数据的一个或多个存储设备。中央处理单元和存储器可由专用逻辑电路补充或者并入在专用逻辑电路中。通常,计算机还将包括用于存储数据的一个或多个大容量存储设备,例如磁盘、磁光盘或光盘,或者操作上被耦合以从所述一个或多个大容量存储设备接收数据或者将数据传送到所述一个或多个大容量存储设备,或者两者以用于存储数据。然而,计算机不必具有这样的设备。此外,计算机可被嵌入在另一设备中,所述另一设备例如是移动电话、个人数字助理(PDA)、移动音频或视频播放器、游戏控制器、全球定位系统(GPS)接收器或便携式存储设备,例如通用串行总线(USB)闪存驱动器等。
适合于存储计算机程序指令和数据的计算机可读介质包括所有形式的非易失性存储器、介质和存储设备,作为示例包括半导体存储设备,例如EPROM、EEPROM和闪速存储器设备;磁盘,例如内部硬盘或可移动盘;磁光盘;以及CD ROM和DVD-ROM盘。
为了提供与用户的交互,可在计算机上实现本说明书中描述的主题的实施例,所述计算机具有用于向用户显示信息的显示设备以及用户可用来向该计算机提供输入的键盘和定点设备,所述显示设备例如是CRT(阴极射线管)或LCD(液晶显示器)监视器,所述定点设备例如是鼠标或轨迹球。其它种类的设备也可用于提供与用户的交互;例如,提供给用户的反馈可以是任何形式的感觉反馈,例如视觉反馈、听觉反馈或触觉反馈;并且可以以任何形式接收来自用户的输入,包括声、语音或触觉输入。此外,计算机可通过向由用户使用的设备发送文档并从由用户使用的设备接收文档来与用户交互;例如,通过响应于从web浏览器接收到请求而向用户的设备上的web浏览器发送网页。另外,计算机可通过向个人设备发送文本消息或其它形式的消息并且继而从用户接收响应消息来与用户交互,所述个人设备例如是正在运行消息传送应用的智能电话。
用于实现机器学习模型的数据处理装置还可以包括例如用于处理机器学习训练或生产(即推理、工作负载)的公共和计算密集部分的专用硬件加速器单元。
可以使用机器学习框架来实现和部署机器学习模型。所述机器学习框架例如是TensorFlow框架、Microsoft Cognitive Toolkit框架、Apache Singa框架或Apache MXNet框架。
本说明书中描述的主题的实施例可被实现在计算系统中,所述计算系统包括后端组件,例如作为数据服务器;或者包括中间件组件,例如应用服务器;或者包括前端组件,例如具有用户可用来与本说明书中描述的主题的实现方式交互的图形用户界面、web浏览器或app的客户端计算机;或者包括一个或多个这样的后端、中间件或前端组件的任何组合。系统的组件可通过例如通信网络的任何形式或介质的数字数据通信来互连。通信网络的示例包括局域网(LAN)和广域网(WAN),例如互联网。
计算系统可包括客户端和服务器。客户端和服务器一般地彼此远离并通常通过通信网络来交互。客户端和服务器的关系借助于在相应的计算机上运行并且彼此具有客户端-服务器关系的计算机程序而产生。在一些实施例中,服务器向用户设备传输例如HTML页面的数据例如以用于向与作为客户端的设备交互的用户显示数据并从该用户接收用户输入的目的。可在服务器处从设备接收在用户设备处生成的数据,例如,用户交互的结果。
虽然本说明书包含许多具体实现方式细节,但是这些不应该被解释为对任何发明的或可能要求保护的范围的限制,而是相反地被解释为对可能特定于特定发明的特定实施例的特征的描述。也可在单个实施例中组合地实现在本说明书中在单独的实施例的上下文中描述的某些特征。相反地,也可单独地或者按照任何适合的子组合在多个实施例中实现在单个实施例的上下文中描述的各种特征。此外,尽管特征可能在上面被描述按照某些组合起作用并且甚至最初被如此要求保护,但是来自要求保护的组合的一个或多个特征可在一些情况下被从该组合中除去,并且所要求保护的组合可以针对子组合或子组合的变化。
类似地,虽然按照特定次序在附图中描绘并在权利要求书中记载操作,但是这不应该被理解为要求按照所示的特定次序或者以先后次序执行这样的操作,或者要求执行所有图示的操作以实现所预期的结果。在一些情况下,多任务处理和并行处理可以是有利的。此外,上述实施例中的各种系统模块和组件的分离不应该被理解为在所有实施例中要求这样的分离,并且应该理解的是,所描述的程序组件和系统一般地可被一起集成在单个软件产品中或者包装到多个软件产品中。
已描述了主题的特定实施例。其它实施例在所附权利要求的范围内。例如,权利要求中记载的动作可被以不同的次序执行并仍然实现所预期的结果。作为一个示例,附图中描绘的过程不一定要求所示的特定次序或顺序次序以实现所预期的结果。在一些情况下,多任务处理和并行处理可以是有利的。

Claims (32)

1.一种由一个或多个数据处理装置执行的用于训练分类神经网络的方法,所述方法包括:
对于多个网络输入中的每一个网络输入:
根据分类神经网络参数的当前值,使用所述分类神经网络处理所述网络输入,以生成定义所述网络输入的预测类别的分类输出;
对于包括来自所述多个网络输入的第一网络输入和第二网络输入的多对网络输入中的每一对网络输入,基于以下两者之间的相应相似性度量来确定软最近邻损失:
(i)所述第一网络输入的中间表示,所述第一网络输入的中间表示由所述分类神经网络的一个或多个隐藏层通过处理所述第一网络输入以生成用于所述第一网络输入的分类输出而生成,以及
(ii)所述第二网络输入的中间表示,所述第二网络输入的中间表示由所述分类神经网络的一个或多个隐藏层通过处理所述第二网络输入以生成用于所述第二网络输入的分类输出而生成;
其中,所述软最近邻损失鼓励不同类别的网络输入的中间表示变得更加纠缠,其中,不同类别的网络输入的中间表示的所述纠缠表征了不同类别的网络输入的中间表示对相对于同一类别的网络输入的中间表示对的相似程度;以及
使用所述软最近邻损失相对于所述分类神经网络参数的梯度来调整所述分类神经网络参数的所述当前值。
2.如权利要求1所述的方法,其中,确定所述软最近邻损失包括:
对于所述多个网络输入中的每个给定网络输入:
确定所述给定网络输入的类内变化,所述给定网络输入的类内变化表征所述给定网络输入的所述中间表示与属于与所述给定网络输入同一类别的所述多个网络输入中的其他网络输入的中间表示的相似程度;
确定所述给定网络输入的总变化,所述给定网络输入的总变化表征所述给定网络输入的所述中间表示与属于任何类别的所述多个网络输入中的其他网络输入的中间表示的相似程度;以及
确定所述给定网络输入的所述类内变化与所述总变化的比率;以及
基于用于每个给定网络输入的所述类内变化与所述总变化的相应比率来确定所述软最近邻损失。
3.如权利要求2所述的方法,其中,确定用于所述给定网络输入的所述类内变化包括确定:
Figure FDA0003172800350000021
其中,j索引所述多个网络输入中的所述网络输入,b是所述多个网络输入中的网络输入的总数,i是所述给定网络输入的索引,yi表示所述给定网络输入的类别,yj表示对应于索引j的所述网络输入的类别,xi表示所述给定网络输入的所述中间表示,xj表示对应于索引j的所述网络输入的所述中间表示,S(·,·)是相似性度量,以及T是温度参数。
4.如权利要求2-3中的任一项所述的方法,其中,确定用于所述给定网络输入的所述总变化包括确定:
Figure FDA0003172800350000022
其中,j索引所述多个网络输入中的所述网络输入,b是所述多个网络输入中的网络输入的总数,i是所述给定网络输入的索引,xi表示所述给定网络输入的所述中间表示,xj表示对应于索引j的所述网络输入的所述中间表示,S(·,·)是相似性度量,以及T是温度参数。
5.如权利要求2-4中的任一项所述的方法,其中,基于用于每个给定网络输入的所述类内变化与所述总变化的所述相应比率来确定所述软最近邻损失包括将所述软最近邻损失确定为:
Figure FDA0003172800350000031
其中,b是所述多个网络输入中的网络输入的总数,i索引所述给定网络输入,以及Ri表示用于对应于索引i的所述给定网络输入的所述类内变化与所述总变化的所述比率。
6.如权利要求3-5中的任一项所述的方法,其中,使用所述软最近邻损失相对于所述分类神经网络参数的梯度来调整所述分类神经网络参数的所述当前值包括:
使用所述软最近邻损失相对于所述温度参数的梯度来调整所述温度参数的当前值。
7.如任一项前述权利要求所述的方法,其中,定义所述网络输入的所述预测类别的所述分类输出包括用于多个可能类别中的每一个的相应似然性得分,其中,用于给定类别的所述似然性得分指示所述网络输入属于所述给定类别的似然性。
8.如任一项前述权利要求所述的方法,进一步包括:
基于定义每个网络输入的所述预测类别的相应分类输出来确定分类损失;以及
使用所述分类损失相对于所述分类神经网络参数的梯度来调整所述分类神经网络参数的所述当前值。
9.如权利要求8所述的方法,其中,所述分类损失包括交叉熵损失。
10.如任一项前述权利要求所述的方法,其中,所述多对网络输入包括每个可能的网络输入对,所述每个可能的网络输入对包括来自所述多个网络输入的第一网络输入和第二不同网络输入。
11.如任一项前述权利要求所述的方法,其中,所述网络输入是图像。
12.如权利要求11所述的方法,其中,图像的类别定义在所述图像中描绘的对象的分类。
13.一种由一个或多个数据处理装置执行的用于训练生成神经网络以基于真实数据元素的训练数据集合来生成合成数据元素的方法,所述方法包括:
根据生成神经网络参数的当前值,使用所述生成神经网络来生成合成数据元素集合;
从所述真实数据元素的训练数据集合中获得真实数据元素集合;
对于包括第一数据元素和第二数据元素的多个数据元素对中的每一个数据元素对,基于所述第一数据元素与所述第二数据元素之间的相应相似性度量来确定软最近邻损失,所述第一数据元素和第二数据元素来自包括所述合成数据元素集合和所述真实数据元素集合的组合数据元素集合,其中:
所述软最近邻损失鼓励不同类别的数据元素变得更加纠缠,
数据元素的所述类别定义所述数据元素是真实数据元素还是合成数据元素,以及
不同类别的数据元素的所述纠缠表征了不同类别的数据元素对相对于同一类别的数据元素对的相似程度;以及
使用所述软最近邻损失相对于所述生成神经网络参数的梯度来调整所述生成神经网络参数的所述当前值。
14.如权利要求13所述的方法,其中,确定所述软最近邻损失包括:
对于来自包括所述合成数据元素集合和所述真实数据元素集合的所述组合数据元素集合中的每个给定数据元素:
确定用于所述给定数据元素的类内变化,所述给定数据元素的类内变化表征所述给定数据元素与来自属于与所述给定数据元素同一类别的所述组合数据元素集合中的其他数据元素的相似程度;
确定用于所述给定数据元素的总变化,所述给定数据元素的总变化表征所述给定数据元素与来自属于任何类别的所述组合数据元素集合中的其他数据元素的相似程度;以及
确定所述给定数据元素的所述类内变化与所述总变化的比率;以及
基于用于每个给定数据元素的所述类内变化与所述总变化的相应比率来确定所述软最近邻损失。
15.如权利要求14所述的方法,其中,确定用于所述给定数据元素的所述类内变化包括确定:
Figure FDA0003172800350000051
其中,j索引来自所述组合数据元素集合的所述数据元素,b是所述组合数据元素集合中的数据元素的总数,i是所述给定数据元素的索引,yi表示所述给定数据元素的所述类别,yj表示对应于索引j的所述数据元素的所述类别,xi表示所述给定数据元素,xj表示对应于索引j的所述数据元素,S(·,·)是相似性度量,以及T是温度参数。
16.如权利要求14-15中的任一项所述的方法,其中,确定用于所述给定数据元素的所述总变化包括确定:
Figure FDA0003172800350000052
其中,j索引来自所述组合数据元素集合的所述数据元素,b是所述组合数据元素集合的数据元素的总数,i是所述给定数据元素的索引,xi表示所述给定数据元素,xj表示对应于索引j的所述数据元素,S(·,·)是相似性度量,S(·,·)是相似性度量,以及T是温度参数。
17.如权利要求14-16中的任一项所述的方法,其中,基于用于每个给定数据元素的所述类内变化与所述总变化的相应比率来确定所述软最近邻损失包括将所述软最近邻损失确定为:
Figure FDA0003172800350000061
其中,b是所述组合数据元素集合中的数据元素的总数,i索引所述给定数据元素,以及Ri表示用于对应于索引i的所述给定数据元素的所述类内变化与所述总变化的所述比率。
18.如权利要求15-17中的任一项所述的方法,其中,使用所述软最近邻损失相对于所述生成神经网络参数的梯度来调整所述生成神经网络参数的所述当前值包括:
使用所述软最近邻损失相对于所述温度参数的梯度来调整所述温度参数的当前值。
19.如权利要求13所述的方法,其中,所述数据元素是图像。
20.一种由一个或多个数据处理装置执行的用于训练生成神经网络以基于真实数据元素的训练数据集合来生成合成数据元素的方法,所述方法包括:
根据生成神经网络参数的当前值,使用所述生成神经网络来生成合成数据元素集合;
从所述真实数据元素的训练数据集合中获得真实数据元素集合;
对于包括所述合成数据元素集合和所述真实数据元素集合的组合数据元素集合中的每个数据元素,使用判别器神经网络来生成所述数据元素的嵌入;
对于包括第一数据元素和第二数据元素的多个数据元素对中的每一个数据元素对,基于所述第一数据元素的嵌入和所述第二数据元素的嵌入之间的相应相似性度量来确定软最近邻损失,所述第一数据元素和第二数据元素来自包括所述合成数据元素集合和所述真实数据元素集合的所述组合数据元素集合,其中:
所述软最近邻损失鼓励不同类别的数据元素的嵌入变得更加纠缠,
数据元素的所述类别定义所述数据元素是真实数据元素还是合成数据元素,以及
不同类别的数据元素的嵌入的所述纠缠表征了不同类别的数据元素对的嵌入相对于同一类别的数据元素对的嵌入的相似程度;以及
使用所述软最近邻损失相对于所述生成神经网络参数的梯度来调整所述生成神经网络参数的所述当前值。
21.如权利要求20所述的方法,其中,确定所述软最近邻损失包括:
对于来自包括所述合成数据元素集合和所述真实数据元素集合的所述组合数据元素集合中的每个给定数据元素:
确定用于所述给定数据元素的类内变化,所述给定数据元素的类内变化表征所述给定数据元素的所述嵌入与来自属于与所述给定数据元素同一类别的所述组合数据元素集合的其他数据元素的嵌入的相似程度;
确定用于所述给定数据元素的总变化,所述给定数据元素的总变化表征所述给定数据元素的所述嵌入与来自属于任何类别的所述组合数据元素集合的其他数据元素的嵌入的相似程度;以及
确定用于所述给定数据元素的所述类内变化与所述总变化的比率;以及
基于用于每个给定数据元素的所述类内变化与所述总变化的相应比率来确定所述软最近邻损失。
22.如权利要求21所述的方法,其中,确定用于所述给定数据元素的所述类内变化包括确定:
Figure FDA0003172800350000071
其中,j索引来自所述组合数据元素集合的所述数据元素,b是所述组合数据元素集合中的数据元素的总数,i是所述给定数据元素的索引,yi表示所述给定数据元素的所述类别,yj表示对应于索引j的所述数据元素的所述类别,E(xi)表示所述给定数据元素的所述嵌入,E(xj)表示对应于索引j的所述数据元素的所述嵌入,S(·,·)是相似性度量,以及T是温度参数。
23.如权利要求21-22中的任一项所述的方法,其中,确定所述给定数据元素的所述总变化包括确定:
Figure FDA0003172800350000081
其中,j索引来自所述组合数据元素集合的所述数据元素,b是所述组合数据元素集合中的数据元素的总数,i是所述给定数据元素的索引,E(xi)表示所述给定数据元素的所述嵌入,E(xj)表示对应于索引j的所述数据元素的所述嵌入,S(·,·)是相似性度量,以及T是温度参数。
24.如权利要求21-23中的任一项所述的方法,其中,基于用于每个给定数据元素的所述类内变化与所述总变化的相应比率来确定所述软最近邻损失包括将所述软最近邻损失确定为:
Figure FDA0003172800350000082
其中,b是所述组合数据元素集合中的数据元素的总数,i索引所述给定数据元素,以及Ri表示用于对应于索引i的所述给定数据元素的所述类内变化与所述总变化的所述比率。
25.如权利要求22-24中的任一项所述的方法,其中,使用所述软最近邻损失相对于所述生成神经网络参数的梯度来调整所述生成神经网络参数的所述当前值包括:
使用所述软最近邻损失相对于所述温度参数的梯度来调整所述温度参数的当前值。
26.如权利要求20至25中的任一项所述的方法,其中,所述数据元素是图像。
27.如权利要求20至26中的任一项所述的方法,进一步包括:
使用所述软最近邻损失相对于所述判别器神经网络参数的梯度来调整所述判别器神经网络参数的当前值。
28.如权利要求27所述的方法,其中,调整所述判别器神经网络参数的所述当前值鼓励所述判别器神经网络生成纠缠较少的不同类别的数据元素的嵌入。
29.一种由一个或多个数据处理装置执行的用于对数据进行分类的方法,所述方法包括:
向分类神经网络提供输入数据,所述分类神经网络已经通过执行如权利要求1-12中的任一项所述的方法进行了训练;
使用所述分类神经网络对所述输入数据进行分类;以及
接收来自所述分类神经网络的分类输出,所述输出指示所述输入数据的类别。
30.一种由一个或多个数据处理装置执行的用于生成合成数据的方法,所述方法包括:
向生成神经网络提供输入数据,所述生成神经网络已经通过执行如权利要求13-28中的任一项所述的方法进行了训练;
使用所述生成神经网络,基于所述输入数据生成合成数据;以及
从所述生成神经网络接收所述合成数据。
31.一种包括一个或多个计算机和一个或多个存储设备的系统,所述存储设备存储指令,所述指令当由所述一个或多个计算机执行时,使所述一个或多个计算机执行如权利要求1-30中的任一项所述的方法的相应操作。
32.一种存储指令的计算机程序产品,所述指令当由一个或多个计算机执行时,使所述一个或多个计算机执行如权利要求1-30中的任一项所述的方法的相应操作。
CN202080010180.5A 2019-01-23 2020-01-22 使用软最近邻损失的神经网络训练 Pending CN113330462A (zh)

Applications Claiming Priority (3)

Application Number Priority Date Filing Date Title
US201962796001P 2019-01-23 2019-01-23
US62/796,001 2019-01-23
PCT/US2020/014571 WO2020154373A1 (en) 2019-01-23 2020-01-22 Neural network training using the soft nearest neighbor loss

Publications (1)

Publication Number Publication Date
CN113330462A true CN113330462A (zh) 2021-08-31

Family

ID=69724091

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202080010180.5A Pending CN113330462A (zh) 2019-01-23 2020-01-22 使用软最近邻损失的神经网络训练

Country Status (4)

Country Link
US (1) US11941867B2 (zh)
EP (1) EP3732632A1 (zh)
CN (1) CN113330462A (zh)
WO (1) WO2020154373A1 (zh)

Families Citing this family (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US11210775B1 (en) 2020-09-18 2021-12-28 International Business Machines Corporation Gradient-embedded video anomaly detection
CN113114633A (zh) * 2021-03-24 2021-07-13 华南理工大学 一种入侵检测系统对抗攻击防御方法、系统、装置及介质
KR20230138294A (ko) * 2022-03-23 2023-10-05 주식회사 Lg 경영개발원 검사 성능을 유지하기 위한 메모리 기반 비전 검사 장치 및 그 방법

Family Cites Families (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
WO2020104590A2 (en) * 2018-11-21 2020-05-28 Deepmind Technologies Limited Aligning sequences by generating encoded representations of data items
US10453197B1 (en) * 2019-02-18 2019-10-22 Inception Institute of Artificial Intelligence, Ltd. Object counting and instance segmentation using neural network architectures with image-level supervision

Also Published As

Publication number Publication date
US11941867B2 (en) 2024-03-26
EP3732632A1 (en) 2020-11-04
US20220101624A1 (en) 2022-03-31
WO2020154373A1 (en) 2020-07-30

Similar Documents

Publication Publication Date Title
US11651259B2 (en) Neural architecture search for convolutional neural networks
US10817805B2 (en) Learning data augmentation policies
Mao et al. Predicting remaining useful life of rolling bearings based on deep feature representation and long short-term memory neural network
EP3711000B1 (en) Regularized neural network architecture search
CN110929114A (zh) 利用动态记忆网络来跟踪数字对话状态并生成响应
CN116261731A (zh) 基于多跳注意力图神经网络的关系学习方法与系统
CN113330462A (zh) 使用软最近邻损失的神经网络训练
US20220121934A1 (en) Identifying neural networks that generate disentangled representations
Azzouz et al. Steady state IBEA assisted by MLP neural networks for expensive multi-objective optimization problems
Napoli et al. An agent-driven semantical identifier using radial basis neural networks and reinforcement learning
JP6172317B2 (ja) 混合モデル選択の方法及び装置
CN114298851A (zh) 基于图表征学习的网络用户社交行为分析方法、装置及存储介质
Chivukula et al. Game theoretical adversarial deep learning with variational adversaries
WO2023174064A1 (zh) 自动搜索方法、自动搜索的性能预测模型训练方法及装置
US20240020531A1 (en) System and Method for Transforming a Trained Artificial Intelligence Model Into a Trustworthy Artificial Intelligence Model
Chugh Mono-surrogate vs multi-surrogate in multi-objective Bayesian optimisation
CN114936890A (zh) 一种基于逆倾向加权方法的反事实公平的推荐方法
Tencer et al. TITS-FM: Transductive incremental Takagi-Sugeno fuzzy models
CN114358364A (zh) 一种基于注意力机制的短视频点击率大数据预估方法
CN112861601A (zh) 生成对抗样本的方法及相关设备
Wang et al. DualMatch: Robust Semi-supervised Learning with Dual-Level Interaction
Li et al. Substep active deep learning framework for image classification
Rongali et al. Parameter optimization of support vector machine by improved ant colony optimization
Alshmrany LFD-CNN: Levy flight distribution based convolutional neural network for an adaptive learning style prediction in E-learning environment
CN117010480A (zh) 模型训练方法、装置、设备、存储介质及程序产品

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination