CN109952583A - 神经网络的半监督训练 - Google Patents
神经网络的半监督训练 Download PDFInfo
- Publication number
- CN109952583A CN109952583A CN201780070359.8A CN201780070359A CN109952583A CN 109952583 A CN109952583 A CN 109952583A CN 201780070359 A CN201780070359 A CN 201780070359A CN 109952583 A CN109952583 A CN 109952583A
- Authority
- CN
- China
- Prior art keywords
- training
- marked
- training program
- unmarked
- input
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/084—Backpropagation, e.g. using gradient descent
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F17/00—Digital computing or data processing equipment or methods, specially adapted for specific functions
- G06F17/10—Complex mathematical operations
- G06F17/16—Matrix or vector computation, e.g. matrix-matrix or matrix-vector multiplication, matrix factorization
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F17/00—Digital computing or data processing equipment or methods, specially adapted for specific functions
- G06F17/10—Complex mathematical operations
- G06F17/18—Complex mathematical operations for evaluating statistical data, e.g. average values, frequency distributions, probability functions, regression analysis
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/047—Probabilistic or stochastic networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
Abstract
用于训练神经网络的方法、系统和装置,包括在计算机存储介质上编码的计算机程序。一种方法包括:获得一批次的已标记训练项目和一批次的未标记训练项目;使用所述神经网络并根据网络参数的当前值来处理所述已标记训练项目和所述未标记训练项目,以生成相应的嵌入;确定多个相似性值,每个相似性值测量用于相应的已标记训练项目的嵌入与用于相应的未标记训练项目的嵌入之间的相似性;确定用于多个往返路径中的每一个的相应的往返路径概率;以及执行神经网络训练过程的迭代,以确定对所述网络参数的当前值的第一值更新,所述第一值更新减小不正确的往返路径的往返路径概率。
Description
技术领域
本说明书涉及一种训练神经网络。
背景技术
神经网络是采用非线性单元的一个或多个层来针对接收到的输入预测输出的机器学习模型。一些神经网络除了包括输出层之外还包括一个或多个隐藏层。每个隐藏层的输出被用作网络中的下一个层(即,下一个隐藏层或输出层)的输入。网络的每个层根据一相应组的参数的当前值来从接收到的输入生成输出。
发明内容
本说明书描述了作为计算机程序实现在一个或多个位置中的一个或多个计算机上的系统,所述系统训练神经网络,所述神经网络具有网络参数并且被配置为接收输入数据项目并处理输入数据项目,以根据网络参数来生成输入数据项目的嵌入。特别地,系统以半监督方式(即,使用已标记训练项目和未标记训练项目两者)训练神经网络。
在一些方面中,方法和相对应的系统包括:获得已标记训练项目的标记批次,其中标记批次中的每个已标记训练项目与识别已标记训练项目所属的类别的相应的标记相关联;使用神经网络并根据网络参数的当前值来处理标记批次中的已标记训练项目,以针对已标记训练项目中的每一个生成相应的嵌入;获得未标记训练项目的未标记批次;使用神经网络并根据网络参数的当前值来处理未标记批次中的未标记训练项目,以针对未标记训练项目中的每一个生成相应的嵌入;确定多个相似性值,每个相似性值测量用于相应的已标记训练项目的嵌入与用于相应的未标记训练项目的嵌入之间的相似性;根据相似性值确定用于多个往返路径中的每一个的相应的往返路径概率,每个往返路径从用于相应的起始已标记训练项目的嵌入开始,转到用于相应的未标记训练项目的嵌入,并且返回到用于相应的结束已标记训练项目的嵌入;以及执行神经网络训练过程的迭代,以确定对网络参数的当前值的第一值更新,所述第一值更新减小不正确的往返路径的往返路径概率,其中不正确的往返路径是起始已标记训练项目和结束已标记训练项目具有不同的标记的往返路径。
可选地,该方法然后包括:提供指定经训练的神经网络的数据,以用于生成新输入数据项目的嵌入。
在一些实现方式中为了减小往返路径概率,训练可以采用损失项目,后面被称为沃克(walker)损失项目,其依赖于目标分布与一个或多个往返概率的和之间的差异。例如通过在目标分布中给这些不正确的路径指派比正确的路径低的概率,例如,通过给不正确的路径指派零概率,可以选取目标分布以阻止与正确的往返路径有关的不正确的往返路径,所述正确的往返路径在具有相同标记的已标记训练项目处开始和结束。用于第一值更新的一个适合的损失项目是目标分布与往返概率的和之间的交叉熵损失项目。训练可以包括包含有损失项目的基于梯度上升或下降的过程,其目的旨在使此项最小化。
在一些实现方式中确定路径的往返概率可以包括:对于往返路径确定从用于起始已标记训练项目的嵌入到用于未标记训练项目的嵌入的前向路径概率并且确定从用于未标记训练项目的嵌入到用于结束已标记训练项目的嵌入的后向路径概率。可以基于两个嵌入之间的确定的相似性来确定路径概率,例如,根据这些嵌入之间的点积或者使用这些嵌入之间的距离的量度来计算。
训练可以包括第二访问损失项目以鼓励方法包括或者访问用于未标记训练输入的嵌入。这旨在鼓励方法通过使用更多的未标记训练输入来更好地扩大运用(generalize)。访问损失项目可以增加在用于未标记训练输入的嵌入的访问概率上的一致性,其中用于给定未标记训练输入的嵌入的访问概率是转到用于未标记训练输入的嵌入的前向路径的前向路径概率的和。因此访问损失项目可以旨在使访问概率与一致概率分布之间的差异最小化。一个适合的损失项目是交叉熵损失项目。训练过程因此可以确定对网络参数的当前值的第二值更新。通过组合沃克损失项目和访问损失项目可以同时执行第一值更新和第二值更新。
可选地,第一(沃克)损失项目和第二(访问)损失项目可以在组合损失项目上具有不相等的权重。例如在方法被采用数据集或域中的标记相同但是标记的分布可以不同的域自适应的情况下,访问损失项目可以被赋予比沃克项更低的权重。域自适应的示例是使用来自一个域的诸如手写数字的标记,在另一域中使用诸如街道视图门牌号码的标记。
嵌入的创建本身是有用的,例如以为执行如后面所描述的各式各样的任务中的任一个的机器学习系统提供前端。然而神经网络还可以被配置为处理数据项目的嵌入,以针对数据项目生成分类输出,所述分类输出包括针对多个类别中的每一种的相应的分值。因此可以在执行训练过程时包括另一个分类损失项目,例如其目的是使用于训练项目的标记与其分类输出之间的另外的交叉熵损失最小化。可以与嵌入同时训练分类,或者可以存在使用分类损失的某种预先训练。
能够实现本说明书中描述的主题的特定实施例以便实现以下优点中的一个或多个。通过像本说明书中所描述的那样训练神经网络,经训练的神经网络能够生成嵌入,所述嵌入准确地反映输入数据项目之间的相似性并且能够有效地用于对输入数据项目进行分类,即,通过经由一个或多个附加神经网络层处理所生成的嵌入。特别地,通过像本说明书中所描述的那样训练神经网络,能够将未标记训练数据有效地并入到训练中,从而改进经训练的神经网络的性能。也就是说,能够在不需要附加已标记训练数据的情况下改进经训练的神经网络的性能。因为未标记训练数据一般地比已标记训练数据更容易获得,所以能够在不大大地增加获得或生成附加训练数据的时间和计算成本的情况下改进训练的有效性。附加地,即使能够通过像本说明书中所描述的那样将容易获得的未标记训练数据有效地并入到训练中获得仅相对少量的已标记训练数据,也能够有效地训练神经网络。因此,训练神经网络以生成准确的嵌入变得不太依赖于准确地标记的训练数据的可用性。
在下面的附图和描述中阐述本说明书中描述的主题的一个或多个实施例的细节。主题的其它特征、方面和优点将根据说明书、附图和权利要求书变得显而易见。
附图说明
图1示出示例神经网络训练系统。
图2是用于训练神经网络的示例过程的流程图。
图3示出嵌入之间的示例往返路径。
图4是用于确定往返路径概率的示例过程的流程图。
在各个附图中相似的附图标记和名称指示相似的元件。
具体实施方式
图1示出示例神经网络训练系统100。神经网络训练系统100是作为计算机程序实现在一个或多个位置中的一个或多个计算机上的系统的示例,其中能够实现下面描述的系统、组件和技术。
神经网络训练系统100是在已标记训练数据140和未标记训练数据150上训练神经网络110,以根据网络参数的初始值来确定神经网络110的参数(在本说明书中称为网络参数)的训练值的系统。
神经网络110是被配置为接收输入数据项目102并且被配置为处理该输入数据项目以根据网络参数来生成输入数据项目102的嵌入112的神经网络。一般地,数据项目的嵌入是例如矢量的数值的有序合集,该有序合集表示数据项目。换句话说,每个嵌入是多维嵌入空间中的点。一旦被训练,嵌入在多维空间中的位置就能够反映这些嵌入表示的数据项目之间的相似性。
神经网络110能够被配置为接收任何类型的数字数据输入作为输入并且从该输入生成嵌入。例如,输入数据项目能够是图像、文档的部分、文本序列、音频数据等中的任一种。
在一些情况下,神经网络110还能够被配置为通过经由一个或多个附加神经网络层(例如,经由一个或多个全连接层和输出层)来处理数据项目的嵌入,从而生成针对输入数据项目的分类输出。
例如,如果到神经网络100的输入是图像,则针对给定图像由神经网络110生成的分类输出可以是针对一组对象类别中的每一个的分值,其中每个分值表示图像包含属于类别的对象的图像的估计可能性。
作为另一示例,如果到神经网络110的输入是来自互联网资源(例如,web页面)或文档的文本,则针对给定互联网资源、文档、或文档的部分由神经网络110生成的分类输出可以是针对一组主题中的每一个的分值,其中每个分值表示互联网资源、文档、或文档部分是关于主题的估计可能性。
神经网络110能够具有适于由神经网络110处理的网络输入的类型的任何架构。例如,当网络输入是图像时,神经网络110能够是卷积神经网络。
一旦被训练,由神经网络110生成的嵌入就能够被用于各种目的中的任一种,即,除了用于生成分类输出之外或者替代用于生成分类输出。
例如,系统100能够将由经训练的神经网络生成的嵌入作为输入提供给另一系统,作为相对应的网络输入的特征,例如,用于在对网络输入执行机器学习任务时使用。示例任务可以包括基于特征的检索、聚类、近似重复检测、验证、特征匹配、域自适应、基于视频的弱监督学习等。作为另外的示例,系统还可以被用于对象或动作辨识/检测、图像分割、从静止或运动图像中的视觉概念提取、控制任务、识别诸如药物的化学品的特性或性质、以及机器翻译。
由系统100使用以训练神经网络110的已标记训练数据140包括多个批次已标记训练项目。该训练项目被称为“已标记”训练项目是因为已标记训练数据140对于每个已标记训练项目还包括识别已标记训练项目所属的类别(例如,对象类或主题)的标记。
由系统100使用以训练神经网络110的未标记训练数据150包括多个批次未标记训练项目。训练项目被称为“未标记”训练项目是因为系统100不具有对用于未标记训练项目中的任一个的任何标记的访问。
系统100通过执行神经网络训练过程的多次迭代来训练神经网络110,其中对一批次的已标记训练项目和一批次的未标记训练项目两者执行过程的每次迭代。在每次迭代期间,系统100确定对从迭代时起对网络参数的当前值的更新并且将该更新应用于当前值,以生成网络参数的更新值。
在下面参考图2至图4更详细地描述使用已标记和未标记训练项目来训练神经网络。
一旦已经训练神经网络,系统100就提供指定经训练的神经网络的数据以供在处理新网络输入时使用。也就是说,系统100能够(例如,通过向用户设备输出或者通过在对系统100可访问的存储器中存储)输出网络参数的训练值,以供稍后在使用经训练的神经网络来处理输入时使用。可替选地或者除了输出经训练的神经网络数据之外,系统100还能够实例化具有网络参数的训练值的神经网络的实例,例如,通过由系统提供的应用编程接口(API)来接收要处理的输入,使用经训练的神经网络来处理所接收的输入以生成嵌入、分类输出或两者,然后响应于所接收的输入而提供所生成的嵌入、分类输出或两者。
图2是用于在一批次的未标记训练项目和一批次的已标记训练项目上训练神经网络的示例过程200的流程图。为了方便,将过程200描述为由位于一个或多个位置中的一个或多个计算机的系统来执行。例如,适当地编程的神经网络训练系统(例如,图1的神经网络训练系统100)能够执行过程200。
系统能够针对多个不同的标记批次-未标记批次组合多次执行过程200,以根据网络参数的初始值来确定网络参数的训练值。例如,系统能够继续执行过程200指定的迭代次数、指定的时间量、或者直到参数的值变化下降至阈值以下。
系统获得一批次的已标记训练项目(步骤202)。
系统使用神经网络并根据网络参数的当前值来处理批次中的每个已标记训练项目,以针对所标记的训练项目中的每一个生成相应的嵌入(步骤204)。
系统获得未标记训练项目的未标记批次(步骤206)。
系统使用神经网络并根据网络参数的当前值来处理批次中的每个未标记训练项目,以针对未标记训练项目中的每一个生成相应的嵌入(步骤208)。
对于已标记训练项目和未标记训练项目的每个可能的组合,系统确定组合中的已标记训练项目的嵌入与组合中的未标记训练项目的嵌入之间的相似性值(步骤210)。每个相似性值测量一个嵌入与另一嵌入之间的相似性。例如,相似性能够是两个嵌入之间的点积相似性、两个嵌入之间的欧几里德距离、或另一适当的相似性度量。
系统根据相似性值确定多个往返路径中的每一个的相应的往返路径概率(步骤212)。每个往返路径从相应的起始已标记训练项目的嵌入开始,转到用于相应的未标记训练项目的嵌入,并且返回到用于相应的结束已标记训练项目的嵌入。
作为示例,图3图示示例往返路径,所述示例往返路径根据用于已标记训练项目310的一组示例嵌入在已标记训练项目的嵌入处开始和结束。特别地,图3中图示的往返路径从起始已标记训练项目312的嵌入开始,并且包括从起始已标记训练项目312的嵌入到来自未标记训练项目320的一组嵌入的未标记训练项目322的嵌入的前向路径302。然后,往返路径包括从未标记训练项目322的嵌入到结束已标记训练项目314的嵌入的后向路径304。
在下面参考图4描述确定用于给定往返路径的往返路径概率。
系统执行神经网络训练过程的迭代,以确定对网络参数的当前值的第一更新,所述第一更新减小不正确的往返路径的往返路径概率(步骤214)。
也就是说,神经网络训练过程是常规的基于梯度的过程,例如,随机梯度下降或Adam,并且系统执行迭代以使损失函数最小化,该损失函数包括沃克损失项目,所述沃克损失项目是依赖于用于不正确的往返路径的往返路径概率。
不正确的往返路径是起始已标记训练项目和结束已标记训练项目具有不同的标记的往返路径。也就是说,如果用于给定往返路径的起始已标记训练项目和结束已标记训练项目具有不同的标记,则系统确定给定往返路径是不正确的往返路径。在图3的示例中,如果起始已标记训练项目312和结束已标记训练项目314具有不同的标记,则图3中图示的往返路径将是不正确的往返路径。如果起始已标记训练项目312和结束已标记训练项目314具有相同的标记,则图3中图示的往返路径将不是不正确的往返路径。
特别地,在一些实现方式中,第一(沃克)损失项目是以下各项之间的交叉熵损失:(i)在各自包括相应的第一已标记训练输入和相应的第二已标记训练输入的已标记训练输入对上的目标分布以及(ii)对于已标记训练输入对中的每一个,从该对的第一已标记训练输入处开始并且返回到该对的第二已标记训练输入的往返路径的往返概率之和。也就是说,对于包括已标记训练输入i和已标记训练输入j的对,该和满足:
其中该和是遍及所有未标记训练项目,并且其中是从用于已标记训练输入i的嵌入处开始、转到用于未标记训练输入k的嵌入并且返回到用于已标记训练输入j的嵌入的往返路径的往返路径概率。
为了让第一损失项目阻止不正确的往返路径,目标分布向具有相同标记的训练输入对指派比具有不同标记的训练输入对更高的概率。在一些实现方式中,目标分布向包括具有不同标记的第一已标记训练输入和第二已标记训练输入的已标记训练输入对指派零概率。在这些实现方式中的一些中,目标分布向包括具有相同标记的第一已标记训练输入和第二已标记训练输入的已标记训练输入的每个对指派等于一除以该批次中具有由第一已标记训练输入和第二已标记训练输入共享的标记的已标记训练输入的总数目的概率。因此,在这些实现方式中,第一损失项目鼓励正确的往返路径的(即,在具有相同标记的已标记训练项目处开始和结束的路径的)一致概率分布并且惩罚不正确的往返路径。
损失函数还可以包括其它损失项目,即,系统还能够确定对网络参数的当前值的其它更新,作为执行神经网络训练技术的迭代的部分。
特别地,在一些实现方式中,损失函数还包括访问损失项目,所述访问损失项目测量针对未标记训练输入的嵌入的访问概率的一致性。针对给定未标记训练输入的嵌入的访问概率是转到未标记训练输入的嵌入的前向路径的前向路径概率的和。在图3的示例中,针对未标记训练项目322的嵌入的访问概率将是从用于已标记训练项目310的集合的嵌入到未标记训练项目322的嵌入的前向路径的前向路径概率的和。在下面参考图4更详细地描述确定前向路径概率。
通过执行迭代并使访问损失项目最小化,系统确定对网络参数的当前值的第二更新,所述第二更新增加在未标记训练输入的嵌入的访问概率上的一致性。
在这些示例中的一些中,访问损失项目是(i)在未标记训练输入上的一致目标分布与(ii)用于未标记训练输入的访问概率之间的交叉熵损失。一致目标分布向每个未标记训练输入指派相同的概率,即,每个概率等于一除以未标记训练输入的总数目。
如上所述,在一些实现方式中,神经网络还针对每个输入数据项目生成分类输出。在这些实现方式中,损失函数还能够包括分类损失项目,所述分类损失项目是对于每个已标记训练项目的针对训练项目的标记与针对训练项目的分类之间的交叉熵损失。因此,通过执行迭代,系统使此交叉熵损失最小化以确定对网络参数的当前值的增加由神经网络生成的分类输出的准确性的第三值更新。
作为神经网络训练技术的迭代的部分,系统通过确定损失函数的相对于网络参数的梯度(即,通过反向传播),并且然后根据该梯度来确定对这些参数的当前值的更新而确定对网络参数的当前值的更新。例如,当训练过程是随机梯度下降时,系统通过对梯度应用学习速率来确定这些更新。
系统通过将第一值更新和通过执行该迭代确定的任何其它值更新应用到这些参数的当前值来生成网络参数的更新值(步骤216)。特别地,系统将这些值更新添加到网络参数的当前值,以确定这些网络参数的更新值。
图4是用于确定在用于起始已标记训练项目的嵌入处开始、转到用于特定未标记训练项目的嵌入,并且返回到用于结束已标记训练项目的嵌入的往返的往返路径概率的示例过程400的流程图。为了方便,将过程400描述为由位于一个或多个位置中的一个或多个计算机的系统来执行。例如,适当地编程的神经网络训练系统(例如,图1的神经网络训练系统100)能够执行过程400。
系统确定从用于起始已标记训练项目的嵌入到用于特定未标记训练项目的嵌入的前向路径的前向路径概率(步骤402)。一般地,系统基于用于起始已标记训练项目的嵌入与用于特定未标记训练项目的嵌入之间的相似性值来确定前向路径概率。特别地,从已标记训练项目i的嵌入到未标记训练项目k的嵌入的前向路径概率满足:
其中,Mik是已标记训练项目i的嵌入与未标记训练项目k的嵌入之间的相似性,并且该和是遍及所有未标记训练项目的和。
系统确定从用于特定未标记训练项目的嵌入到用于结束已标记训练项目的嵌入的后向路径的后向路径概率(步骤404)。一般地,系统基于用于特定未标记训练项目的嵌入与用于结束已标记训练项目的嵌入之间的相似性值来确定后向路径概率。特别地,从特定未标记训练项目k的嵌入到已标记训练项目j的嵌入的后向路径概率满足:
其中Mkj是特定未标记训练项目k的嵌入与已标记训练项目j的嵌入之间的相似性,并且该和是遍及所有已标记训练项目的和。
系统根据前向路径概率和后向路径概率来确定往返路径概率(步骤406)。特别地,往返路径概率是前向路径概率和后向路径概率的乘积。本说明书连同系统和计算机程序组件一起使用术语“被配置”。对于要配置为执行特定操作或动作的一个或多个计算机的系统,意味着系统已在其上安装了软件、固件、硬件或其组合,所述软件、固件、硬件或其组合在操作中使系统执行操作或动作。对于要配置为执行特定操作或动作的一个或多个计算机程序意味着一个或多个程序包括指令,所述指令当由数据处理装置执行时,使该装置执行操作或动作。
本说明书中描述的主题和功能操作的实施例能够用数字电子电路、用有形地实施的计算机软件或固件、用计算机硬件(包括本说明书中公开的结构及其结构等同物)或者用它们中的一个或多个的组合加以实现。本说明书中描述的主题的实施例能够作为一个或多个计算机程序(即,在有形非暂时性存储介质上编码以供由数据处理装置执行或者控制数据处理装置的操作的计算机程序指令的一个或多个模块)被实现。计算机存储介质能够是机器可读存储设备、机器可读存储基板、随机或串行存取存储器设备,或它们中的一个或多个的组合。可替选地或此外,能够将程序指令编码在人工生成的传播信号上,所述传播信号例如为机器生成的电、光学或电磁信号,该信号被生成来对信息进行编码以便传输到适合的接收器装置以供由数据处理装置执行。
术语“数据处理装置”指代数据处理硬件并且涵盖用于处理数据的所有类型的装置、设备和机器,作为示例包括可编程处理器、计算机、或多个处理器或计算机。装置还能够是或者进一步包括专用逻辑电路,例如FPGA(现场可编程门阵列)或ASIC(专用集成电路)。装置除了包括硬件之外还能够可选地包括为计算机程序创建执行环境的代码,例如,构成处理器固件、协议栈、数据库管理系统、操作系统、或它们中的一个或多个的组合的代码。
计算机程序(其还可以被称为或者描述为程序、软件、软件应用、app、模块、软件模块、脚本或代码)能够用任何形式的编程语言编写,所述编程语言包括编译或解释语言,或声明性或过程语言,并且它能够以任何形式被部署,包括作为独立程序或者作为模块、组件、子例行程序、或适合于在计算环境中使用的其它单元。程序可以但不必需与文件系统中的文件相对应。能够在保持其它程序或数据(例如,在标记语言文档中存储的一个或多个脚本)的文件的部分中、在专用于所述程序的单个文件中、或者在多个协调文件(例如,存储一个或多个模块、子程序或代码的部分的文件)中存储程序。能够将计算机程序部署为在一个计算机上或者在位于一个站点处或者跨多个站点分布并且通过数据通信网络互连的多个计算机上执行。
在本说明书中,术语“数据库”被广泛地用于指代数据的任何合集:数据不需要被以任何特定方式构造,或者根本不构造,并且它能够被存储在一个或多个位置中的存储设备中。因此,例如,索引数据库能够包括数据的多个合集,所述多个合集的每一个均可以被不同地组织和访问。
类似地,在本说明书中术语“引擎”被广泛地用于指代被编程来执行一个或多个具体功能的基于软件的系统、子系统、或过程。一般地,引擎将作为一个或多个软件模块或组件被实现,将被安装在一个或多个位置中的一个或多个计算机上。在一些情况下,一个或多个计算机将专用于特定引擎;在其它情况下,能够在相同的一个或多个计算机上安装并运行多个引擎。
本说明书中描述的过程和逻辑流程能够通过以下操作来执行:一个或多个可编程计算机执行一个或多个计算机程序,以通过对输入数据进行操作并生成输出来执行功能。过程和逻辑流程还能够通过专用逻辑电路(例如,FPGA和ASIC),或者通过专用逻辑电路和一个或多个编程计算机的组合来执行。
适合于执行计算机程序的计算机能够基于通用微处理器或专用微处理器或两者,或任何其它类型的中央处理单元。一般地,中央处理单元将从只读存储器或随机存取存储器或两者接收指令和数据。计算机的必要元件是用于执行或者实行指令的中央处理单元以及用于存储指令和数据的一个或多个存储器设备。中央处理单元和存储器能够由专用逻辑电路补充,或者并入在专用逻辑电路中。一般地,计算机还将包括或者在操作上耦合以从用于存储数据的一个或多个大容量存储设备(例如,磁盘、磁光盘或光盘)接收数据,或者将数据转移到用于存储数据的一个或多个大容量存储设备,或者兼而有之。然而,计算机不必具有此类设备。此外,计算机能够被嵌入在另一设备中,所述另一设备例如为移动电话、个人数字助理(PDA)、移动音频或视频播放器、游戏控制台、全球定位系统(GPS)接收器、或便携式存储设备,例如,通用串行总线(USB)闪速驱动器等等。
适合于存储计算机程序指令和数据的计算机可读介质包括所有形式的非易失性存储器、介质和存储器设备,作为示例包括半导体存储器设备,例如,EPROM、EEPROM和闪速存储器设备;磁盘,例如内部硬盘或可移动磁盘;磁光盘;以及CD ROM和DVD-ROM盘。
为了提供用于与用户的交互,能够在计算机上实现本说明书中描述的主题的实施例,所述计算机具有用于向用户显示信息的显示设备(例如,CRT(阴极射线管)或LCD(液晶显示器)监视器)以及用户能够通过其来向该计算机提供输入的键盘和指向设备,例如鼠标或轨迹球。其它类型的设备还不够用于提供与用户的交互;例如,提供给用户的反馈能够是例如视觉反馈、听觉反馈或触觉反馈的任何形式的感觉反馈;并且能够任何形式接收来自用户的输入,包括声学、语音或触觉输入。此外,计算机能够通过向由用户使用的设备发送文档并且从由用户使用的设备接收文档来与用户交互;例如,通过响应于从web浏览器接收到的请求而向用户的客户端设备上的web浏览器发送web页面。另外,计算机能够通过向个人设备(例如,正在运行消息传送应用的智能电话)发送文本消息或其它形式的消息并且返回来从用户接收响应消息来与用户交互。
用于实现机器学习模型的数据处理装置还能够包括例如用于处理机器学习训练或生产(即,推理)工作负载的公共和计算密集部分的专用硬件加速器单元。
能够使用机器学习框架(例如,TensorFlow框架、Microsoft Cognitive Toolkit框架、Apache Singa框架、或Apache MXNet框架)来实现和部署机器学习模型。
能够在计算系统中实现本说明书中描述的主题的实施例,所述计算系统包括后端组件(例如,作为数据服务器),或者包括中间件组件(例如,应用服务器),或者包括前端组件(例如,具有用户能够与本说明书中描述的主题的实现方式交互的图形用户界面、web浏览器或app的客户端计算机),或者包括一个或多个此类后端、中间件、或前端组件的任何组合。系统的组件能够通过任何形式或介质的数字数据通信(例如,通信网络)来互连。通信网络的示例包括局域网(LAN)和广域网(WAN),例如互联网。
计算系统能够包括客户端和服务器。客户端和服务器一般地彼此远离并且通常通过通信网络来交互。客户端和服务器的关系借助于在相应的计算机上运行并且彼此具有客户端-服务器关系的计算机程序而产生。在一些实施例中,服务器向用户设备发送数据(例如,HTML页面),例如,用于向与作为客户端的设备交互的用户显示数据并且从与作为客户端的设备交互的用户接收用户输入的目的。能够在服务器处从设备接收在用户设备处生成的数据,例如,用户交互的结果。
虽然本说明书包含许多具体实现方式细节,但是这些不应当被解释为对任何发明的或可能要求保护的东西的范围构成限制,而是相反被解释为可以特定于特定发明的特定实施例的特征的描述。还能够在单个实施例中相结合地实现在本说明书中在单独的实施例的上下文中描述的某些特征。相反地,还能够单独地或者按照任何适合的子组合在多个实施例中实现在单个实施例的上下文中描述的各种特征。此外,尽管特征可以在上面被描述为按照某些组合起作用并且甚至最初如此要求保护,然而来自要求保护的组合的一个或多个特征能够在一些情况下被从该组合中除去,并且所要求保护的组合可以致力于子组合或子组合的变化。
类似地,虽然按照特定次序在附图中描绘并在权利要求书中叙述操作,但是这不应当被理解为要求按照所示的特定次序或者按照顺序次序执行此类操作,或者执行所有图示的操作以实现所希望的效果。在某些情况下,多任务处理和并行处理可以是有利的。此外,上述的实施例中的各种系统模块和组件的分离不应当被理解为在所有实施例中要求这种分离,并且应当理解的是,所描述的程序组件和系统一般地可被一起集成在单个软件产品中或者包装到多个软件产品中。
已经描述了主题的特定实施例。其它实施例在以下权利要求的范围内。例如,权利要求书中叙述的动作可被以不同的次序执行并仍然实现所希望的结果。作为一个示例,附图中描绘的过程不一定要求所示的特定次序或顺序次序以实现所希望的结果。在某些情况下,多任务处理和并行处理可以是有利的。
Claims (15)
1.一种训练神经网络的方法,所述神经网络具有多个网络参数并且被配置为接收输入数据项目并处理所述输入数据项目以根据所述网络参数来生成所述输入数据项目的嵌入,所述方法包括:
获得已标记训练项目的标记批次,其中,所述标记批次中的每个已标记训练项目与识别所述已标记训练项目所属的类别的相应标记相关联;
使用所述神经网络并根据所述网络参数的当前值来处理所述标记批次中的所述已标记训练项目,以针对每个所述已标记训练项目生成相应的嵌入;
获得未标记训练项目的未标记批次;
使用所述神经网络并根据所述网络参数的当前值来处理所述未标记批次中的所述未标记训练项目,以针对每个所述未标记训练项目生成相应的嵌入;
确定多个相似性值,每个相似性值测量用于相应的已标记训练项目的嵌入与用于相应的未标记训练项目的嵌入之间的相似性;
根据所述相似性值确定用于多个往返路径中的每一个的相应的往返路径概率,其中,每个往返路径从用于相应的起始已标记训练项目的嵌入开始,转到用于相应的未标记训练项目的嵌入,并且返回到用于相应的结束已标记训练项目的嵌入;以及
执行神经网络训练过程的迭代,以确定对所述网络参数的所述当前值的第一值更新,所述第一值更新减小用于不正确的往返路径的往返路径概率,其中,不正确的往返路径是所述起始已标记训练项目和所述结束已标记训练项目具有不同的标记的往返路径。
2.根据权利要求1所述的方法,进一步包括:
提供指定经训练的神经网络的数据,以用于生成新输入数据项目的嵌入。
3.根据权利要求1或2中的任何一项所述的方法,其中,确定用于多个往返路径中的每一个的相应的往返路径概率包括:
基于用于所述往返路径的所述起始已标记训练项目的嵌入与用于所述往返路径的所述未标记训练项目的嵌入之间的所述相似性值来确定从用于所述往返路径的所述起始已标记训练项目的嵌入到用于所述往返路径的所述未标记训练项目的嵌入的前向路径的前向路径概率;
基于用于所述往返路径的所述未标记训练项目的嵌入与用于所述往返路径的所述结束已标记训练项目的嵌入之间的所述相似性值来确定从用于所述往返路径的所述未标记训练项目的嵌入到用于所述往返路径的所述结束已标记训练项目的嵌入的后向路径的后向路径概率;以及
根据所述前向路径概率和所述后向路径概率来确定所述往返路径概率。
4.根据权利要求3所述的方法,其中,执行所述神经网络训练过程的迭代包括:
执行所述迭代以确定对所述网络参数的当前值的第二值更新,所述第二值更新增加在用于未标记训练输入的嵌入的访问概率上的一致性,其中,给定未标记训练输入的嵌入的访问概率是转到用于所述未标记训练输入的所述嵌入的前向路径的前向路径概率的和。
5.根据权利要求4所述的方法,其中,执行所述迭代以确定所述第二值更新包括:
执行所述迭代以使(i)在所述未标记训练输入上的一致目标分布与(ii)从用于所述未标记训练输入的所述访问概率取得的项之间的交叉熵损失最小化。
6.根据权利要求1至5中的任何一项所述的方法,其中,执行所述迭代以确定所述第一值更新包括:
执行所述迭代以使以下各项之间的交叉熵损失最小化:(i)在各自包括相应的第一已标记训练输入和相应的第二已标记训练输入的已标记训练输入对上的目标分布,以及(ii)对于每个已标记训练输入对,在所述对的所述第一已标记训练输入处开始并且返回到所述对的所述第二已标记训练输入的往返路径的往返概率的和。
7.根据权利要求6所述的方法,其中,所述目标分布向包括具有不同标记的第一已标记训练输入和第二已标记训练输入的已标记训练输入对指派零概率。
8.根据权利要求6或7中的任何一项所述的方法,其中,所述目标分布向包括具有相同标记的第一已标记训练输入和第二已标记训练输入的已标记训练输入对指派等于一除以具有所述标记的已标记训练输入的总数目的概率。
9.根据权利要求1至7中的任何一项所述的方法,其中,所述神经网络进一步被配置为处理所述数据项目的所述嵌入,以针对所述数据项目生成分类输出,所述分类输出包括针对多个类别中的每一个的相应分值。
10.根据权利要求9所述的方法,其中,执行所述迭代进一步包括:
执行所述神经网络训练过程的迭代,以使对于每个已标记训练项目的用于所述训练项目的标记与用于所述训练项目的所述分类之间的交叉熵损失最小化。
11.根据权利要求1至10中的任何一项所述的方法,其中,所述神经网络训练过程是随机梯度下降。
12.根据权利要求1至11中的任何一项所述的方法,其中,所述相似性值是点积。
13.根据权利要求1至11中的任何一项所述的方法,其中,所述相似性值是欧几里德距离。
14.一种系统,包括一个或多个计算机和存储指令的一个或多个存储设备,所述指令在由所述一个或多个计算机执行时使所述一个或多个计算机执行根据权利要求1至13中的任何一项所述的相应方法的操作。
15.一个或多个存储指令的计算机存储介质,所述指令在由一个或多个计算机执行时使所述一个或多个计算机执行根据权利要求1至13中的任何一项所述的相应方法的操作。
Applications Claiming Priority (3)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
US201662422550P | 2016-11-15 | 2016-11-15 | |
US62/422,550 | 2016-11-15 | ||
PCT/US2017/061839 WO2018093926A1 (en) | 2016-11-15 | 2017-11-15 | Semi-supervised training of neural networks |
Publications (1)
Publication Number | Publication Date |
---|---|
CN109952583A true CN109952583A (zh) | 2019-06-28 |
Family
ID=60570251
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN201780070359.8A Pending CN109952583A (zh) | 2016-11-15 | 2017-11-15 | 神经网络的半监督训练 |
Country Status (3)
Country | Link |
---|---|
US (1) | US11443170B2 (zh) |
CN (1) | CN109952583A (zh) |
WO (1) | WO2018093926A1 (zh) |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN112215487A (zh) * | 2020-10-10 | 2021-01-12 | 吉林大学 | 一种基于神经网络模型的车辆行驶风险预测方法 |
Families Citing this family (13)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US11521044B2 (en) * | 2018-05-17 | 2022-12-06 | International Business Machines Corporation | Action detection by exploiting motion in receptive fields |
KR102107021B1 (ko) * | 2018-08-10 | 2020-05-07 | 주식회사 딥핑소스 | 데이터를 식별 처리하는 방법, 시스템 및 비일시성의 컴퓨터 판독 가능 기록 매체 |
US11410031B2 (en) * | 2018-11-29 | 2022-08-09 | International Business Machines Corporation | Dynamic updating of a word embedding model |
US20200202210A1 (en) * | 2018-12-24 | 2020-06-25 | Nokia Solutions And Networks Oy | Systems and methods for training a neural network |
US11651209B1 (en) | 2019-10-02 | 2023-05-16 | Google Llc | Accelerated embedding layer computations |
CN110909803B (zh) * | 2019-11-26 | 2023-04-18 | 腾讯科技(深圳)有限公司 | 图像识别模型训练方法、装置和计算机可读存储介质 |
WO2021120028A1 (en) * | 2019-12-18 | 2021-06-24 | Intel Corporation | Methods and apparatus for modifying machine learning model |
US10783257B1 (en) * | 2019-12-20 | 2020-09-22 | Capital One Services, Llc | Use of word embeddings to locate sensitive text in computer programming scripts |
GB202008030D0 (en) * | 2020-05-28 | 2020-07-15 | Samsung Electronics Co Ltd | Learning the prediction distribution for semi-supervised learning with normalising flows |
CN111737487A (zh) * | 2020-06-10 | 2020-10-02 | 深圳数联天下智能科技有限公司 | 一种辅助本体构建的方法、电子设备及存储介质 |
CN112381227B (zh) * | 2020-11-30 | 2023-03-24 | 北京市商汤科技开发有限公司 | 神经网络生成方法、装置、电子设备及存储介质 |
CN113989596B (zh) * | 2021-12-23 | 2022-03-22 | 深圳佑驾创新科技有限公司 | 图像分类模型的训练方法及计算机可读存储介质 |
CN115174251B (zh) * | 2022-07-19 | 2023-09-05 | 深信服科技股份有限公司 | 一种安全告警的误报识别方法、装置以及存储介质 |
Family Cites Families (11)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US8234228B2 (en) | 2008-02-07 | 2012-07-31 | Nec Laboratories America, Inc. | Method for training a learning machine having a deep multi-layered network with labeled and unlabeled training data |
US8527432B1 (en) * | 2008-08-08 | 2013-09-03 | The Research Foundation Of State University Of New York | Semi-supervised learning based on semiparametric regularization |
US9710463B2 (en) * | 2012-12-06 | 2017-07-18 | Raytheon Bbn Technologies Corp. | Active error detection and resolution for linguistic translation |
BR112017003893A8 (pt) * | 2014-09-12 | 2017-12-26 | Microsoft Corp | Rede dnn aluno aprendiz via distribuição de saída |
AU2016284455A1 (en) * | 2015-06-22 | 2017-11-23 | Myriad Women's Health, Inc. | Methods of predicting pathogenicity of genetic sequence variants |
EP3332375A4 (en) * | 2015-08-06 | 2019-01-16 | Hrl Laboratories, Llc | SYSTEM AND METHOD FOR IDENTIFYING THE INTERESTS OF A USER THROUGH A SOCIAL MEDIA |
US9767565B2 (en) * | 2015-08-26 | 2017-09-19 | Digitalglobe, Inc. | Synthesizing training data for broad area geospatial object detection |
US9275347B1 (en) * | 2015-10-09 | 2016-03-01 | AlpacaDB, Inc. | Online content classifier which updates a classification score based on a count of labeled data classified by machine deep learning |
US10552735B1 (en) * | 2015-10-14 | 2020-02-04 | Trading Technologies International, Inc. | Applied artificial intelligence technology for processing trade data to detect patterns indicative of potential trade spoofing |
US10686829B2 (en) * | 2016-09-05 | 2020-06-16 | Palo Alto Networks (Israel Analytics) Ltd. | Identifying changes in use of user credentials |
US11544529B2 (en) * | 2016-09-07 | 2023-01-03 | Koninklijke Philips N.V. | Semi-supervised classification with stacked autoencoder |
-
2017
- 2017-11-15 US US16/461,287 patent/US11443170B2/en active Active
- 2017-11-15 WO PCT/US2017/061839 patent/WO2018093926A1/en active Application Filing
- 2017-11-15 CN CN201780070359.8A patent/CN109952583A/zh active Pending
Non-Patent Citations (1)
Title |
---|
MEHDI SAJJADI等: "Mutual exclusivity loss for semi-supervised deep learning", 《2016 IEEE INTERNATIONAL CONFERENCE ON IMAGE PROCESSING》, pages 2 - 5 * |
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN112215487A (zh) * | 2020-10-10 | 2021-01-12 | 吉林大学 | 一种基于神经网络模型的车辆行驶风险预测方法 |
CN112215487B (zh) * | 2020-10-10 | 2023-05-23 | 吉林大学 | 一种基于神经网络模型的车辆行驶风险预测方法 |
Also Published As
Publication number | Publication date |
---|---|
US20200057936A1 (en) | 2020-02-20 |
US11443170B2 (en) | 2022-09-13 |
WO2018093926A1 (en) | 2018-05-24 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN109952583A (zh) | 神经网络的半监督训练 | |
KR102122373B1 (ko) | 사용자 포트레이트를 획득하는 방법 및 장치 | |
CN109923560A (zh) | 使用变分信息瓶颈来训练神经网络 | |
CN109564575A (zh) | 使用机器学习模型来对图像进行分类 | |
US20180225391A1 (en) | System and method for automatic data modelling | |
US9129190B1 (en) | Identifying objects in images | |
CN109690576A (zh) | 在多个机器学习任务上训练机器学习模型 | |
US11080560B2 (en) | Low-shot learning from imaginary 3D model | |
CN110366734A (zh) | 优化神经网络架构 | |
CN110520871A (zh) | 训练机器学习模型 | |
CN110383308A (zh) | 预测管道泄漏的新型自动人工智能系统 | |
US20210374453A1 (en) | Segmenting objects by refining shape priors | |
US20240029086A1 (en) | Discovery of new business openings using web content analysis | |
US11797839B2 (en) | Training neural networks using priority queues | |
CN109313720A (zh) | 具有稀疏访问的外部存储器的增强神经网络 | |
CN103534697B (zh) | 用于提供统计对话管理器训练的方法和系统 | |
US11416760B2 (en) | Machine learning based user interface controller | |
WO2023020005A1 (zh) | 神经网络模型的训练方法、图像检索方法、设备和介质 | |
US10748041B1 (en) | Image processing with recurrent attention | |
US11568212B2 (en) | Techniques for understanding how trained neural networks operate | |
CN110476173A (zh) | 利用强化学习的分层设备放置 | |
CN113939791A (zh) | 图像标注方法、装置、设备及介质 | |
CN111902812A (zh) | 电子装置及其控制方法 | |
CN107209763A (zh) | 指定和应用数据的规则 | |
US20210216874A1 (en) | Radioactive data generation |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |