CN113785314A - 使用标签猜测对机器学习模型进行半监督训练 - Google Patents

使用标签猜测对机器学习模型进行半监督训练 Download PDF

Info

Publication number
CN113785314A
CN113785314A CN202080033626.6A CN202080033626A CN113785314A CN 113785314 A CN113785314 A CN 113785314A CN 202080033626 A CN202080033626 A CN 202080033626A CN 113785314 A CN113785314 A CN 113785314A
Authority
CN
China
Prior art keywords
input
output
processed
model
unlabeled
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202080033626.6A
Other languages
English (en)
Inventor
戴维·贝特洛
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Google LLC
Original Assignee
Google LLC
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Google LLC filed Critical Google LLC
Publication of CN113785314A publication Critical patent/CN113785314A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/088Non-supervised learning, e.g. competitive learning
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/044Recurrent networks, e.g. Hopfield networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • General Health & Medical Sciences (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Computational Linguistics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Evolutionary Computation (AREA)
  • Artificial Intelligence (AREA)
  • Molecular Biology (AREA)
  • Computing Systems (AREA)
  • General Engineering & Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Health & Medical Sciences (AREA)
  • Machine Translation (AREA)

Abstract

用于训练机器学习模型的方法、系统和装置,包括在计算机存储介质上编码的计算机程序。方法中的一种包括接收未标记批次;从未标记批次和标记批次生成处理后的未标记批次和处理后的标记批次,该生成包括:对于多个未标记训练输入中的每个未标记训练输入:从未标记训练输入生成多个增强未标记训练输入;使用机器学习模型处理扩增的未标记训练输入中的每一个以为每个扩增的未标记训练输入生成相应的模型输出;从扩增的未标记训练输入的模型输出生成猜测模型输出;以及将猜测的模型输出与扩增的未标记训练输入中的每一个相关联;并且在处理后的标记批次和处理后的未标记批次上训练机器学习模型。

Description

使用标签猜测对机器学习模型进行半监督训练
相关申请的交叉引用
本申请要求2019年5月6日提交的美国临时专利申请No.62/843,806的优先权,其全部内容通过引用并入本文。
技术领域
本说明书涉及训练机器学习模型。
背景技术
机器学习模型接收输入并基于接收到的输入和模型的参数值生成输出,例如,预测的输出。
神经网络是机器学习模型,其采用一层或多层非线性单元来预测接收到的输入的输出。一些神经网络除了输出层以外还包括一个或多个隐藏层。每个隐藏层的输出被用作网络中下一层,即,下一隐藏层或输出层的输入。网络的每一层根据相应的参数集的当前值从接收到的输入生成输出。
发明内容
本说明书描述一种作为计算机程序实现在一个或多个位置中的一个或多个计算机上的系统,该系统通过半监督学习来训练机器学习模型以执行机器学习任务,即,通过在包括未标记训练输入和标记训练输入的训练数据上训练机器学习模型。标记训练输入是对其来说可获得地面真值输出,即,应该由机器学习模型通过对标记训练输入执行特定机器学习任务来生成的输出的输入。未标记训练输入是对其来说不可获得地面真值输出的训练输入。
系统通过部分地为训练数据中的未标记训练输入生成猜测模型输出来训练机器学习模型。为了生成猜测模型输出,系统从未标记训练输入生成多个扩增的未标记训练输入。系统然后使用机器学习模型来处理多个扩增的未标记训练输入以为每个扩增的未标记训练输入生成相应的模型输出。系统然后从每个扩增的未标记训练输入的相应的模型输出生成猜测模型输出。
能够实现本说明书中描述的主题的特定实施例以便实现以下优点中的一个或多个。
所描述的系统能够利用有限的标记数据来将训练机器学习模型成在机器学习任务上很好地执行。特别地,通过利用“标签猜测”,即,为未标记训练输入生成的猜测模型输出,所描述的系统能够将机器学习模型训练成与常规技术相比在标记数据与未标记数据的较低比率的情况下具有高性能。给定相同量的标记数据和未标记数据,系统能够将机器学习模型训练成与使用常规技术相比具有更好的准确性。作为特定示例,所描述的技术能够被用于将机器学习模型训练成在各种图像分类任务上实现现有技术水平性能。
另外,系统能够将机器学习模型训练成对输入可变性鲁棒,例如,以有效地处理输入中的可变性。例如,已经根据所描述的技术被训练的训练后的机器学习模型将能够有效地对输入图像进行分类,即使当这些图像具有遮挡或模糊、具有不同程度的歪斜失真、不同程度的旋转等时也如此。
本说明书中描述的主题的一个或多个实施例的细节在附图和以下描述中阐述。本主题的其他特征、方面和优点将从描述、附图和权利要求中变得明显。
附图说明
图1示出示例机器学习模型训练系统。
图2是用于训练机器学习模型的示例过程的流程图。
图3A是用于在未标记训练输入的批次和标记训练输入的批次上训练机器学习模型的示例过程的流程图。
图3B是示出初始处理的未标记批次的生成的示意图。
图4示出所描述的技术相对于其他半监督学习技术的性能。
各图中相同的附图标记和标记指示相同的元件。
具体实施方式
图1示出示例机器学习模型训练系统100。机器学习模型训练系统100是作为计算机程序实现在一个或多个位置中的一个或多个计算机上的系统的示例,在该计算机中能够实现下述系统、组件和技术。
机器学习模型训练系统100是在包括标记训练数据140和未标记训练数据150的训练数据上训练机器学习模型110以根据模型参数的初始值来确定机器学习模型110的参数,在本说明书中被称为模型参数,的训练值的系统。
机器学习模型110是被配置成接收模型输入102并且处理该模型输入以将模型输入102映射到模型输出112以依照模型参数来执行特定机器学习任务的机器学习模型。
机器学习模型110能够被配置成执行各种机器学习任务中的任一种,即,接收任何种类的数字数据输入作为输入并且从输入生成模型输出。通常,模型输出是可能的类的集合上的概率分布。
例如,如果任务是图像分类,则模型110的输入是图像,并且对于给定图像的模型输出可以是对象类别的集合中的每一个的概率,其中每个概率表示图像包含属于该类别的对象的图像的估计的可能性。
例如,如果任务是视频分类,则模型110的输入是一个或多个视频帧,并且模型输出是对象类的集合上的概率分布或主题的集合上的概率分布。
作为另一示例,如果机器学习任务是文档分类,则对机器学习模型110的输入是来自互联网资源(例如,web页面)或文档的文本,并且对于给定互联网资源、文档或文档的一部分的模型输出可以是主题的集合中的每一个的分数,其中每个分数表示互联网资源、文档或文档部分是关于该主题的估计的可能性。
作为另一示例,如果任务是自然语言理解任务,则对机器学习模型110的输入是文本的序列,并且对于文本的给定序列的模型输出能够是适于自然语言理解任务的概率分布,例如,语言可接受性类别、语言情感类别、语言释义类别、句子相似性类别、文本蕴涵类别、问答类别等上的分布。
作为另一示例,如果任务是健康预测任务,则对机器学习模型110的输入是患者的电子健康记录数据,并且对于给定序列的模型输出能够是患者健康相关类别,例如,对于患者可能的诊断、与患者相关联的将来可能的健康事件等上的概率分布。
作为另一示例,如果任务是语音处理任务,则对机器学习模型110的输入能够是表示讲出话语,例如原始音频或声学特征的音频信号,即音频数据,并且模型输出能够是语音分类类别的集合上的概率分布,例如,可能的语言上的概率分布、自然语言文本,例如,可能的热词的集合上的概率分布等。
机器学习模型110能够具有适于由机器学习模型110处理的模型输入的类型的任何架构。例如,当模型输入是图像或音频数据时,机器学习模型110能够是卷积神经网络。当模型输入是文本序列或其他特征,例如,电子健康记录特征或音频特征的序列时,机器学习模型110能够是基于自注意力的神经网络,例如Transformer,或循环神经网络,例如,长短期记忆(LSTM)神经网络。当模型输入包括多种模态,例如图像和文本两者的输入时,模型110能够包括不同类型的神经网络,例如,卷积层和自注意力或循环层两者。
由系统100使用来训练机器学习模型110的标记训练数据140包括标记训练输入的多个批次。训练输入被称为“标记”训练输入是因为标记训练数据140对于每个标记训练输入也包括地面真值输出,即,应该由机器学习模型通过对标记训练输入执行特定机器学习任务来生成的输出。换句话说,地面真值输出是当对对应的标记训练输入执行机器学习任务时的实际输出。
由系统100使用来训练机器学习模型110的未标记训练数据150包括未标记训练输入的多个批次。训练输入被称为“未标记”训练输入是因为对于未标记训练输入的地面真值输出是得不到的,即,系统100不能够访问对于未标记训练输入中的任一个的任何地面真值输出或者由于某个其他原因对于用于训练模型110的未标记训练输入不能使用地面真值输出。
通常,系统100通过执行迭代训练过程来训练机器学习模型110。
在训练过程的每次迭代时,系统100在未标记训练数据的批次和标记训练数据的批次上训练模型110。为了在任何给定训练迭代时在这两个批次上训练模型110,系统100生成标记数据的处理后的批次和未标记数据的处理后的批次,然后在处理后的标记批次和处理后的未标记批次上训练机器学习模型以调整模型参数的当前值,即,从训练迭代时起模型参数的值。
在下面参考图2、图3A和图3B更详细地描述在模型110的训练期间执行训练迭代。
一旦模型110已经被训练了,系统100就能够提供指定训练后的模型以供在处理新网络输入时使用的数据。也就是说,系统100能够,例如,通过向用户设备输出或通过在系统100可访问的存储器中存储,来输出模型参数的训练值以供稍后在使用训练后的模型来处理输入时使用。
替代地或除了输出训练后的模型数据之外,系统100还能够实例化具有模型参数的训练后的值的机器学习模型的实例,例如,通过由系统提供的应用编程接口(API)来接收要处理的输入,使用训练后的模型来处理接收到的输入以生成模型输出,然后响应于接收到的输入来提供生成的模型输出、分类输出或两者。
图2是用于在未标记训练输入的批次和标记训练输入的批次上训练机器学习模型的示例过程200的流程图。为了方便,过程200将被描述为由位于一个或多个位置中的一个或多个计算机的系统执行。例如,适当地编程的机器学习模型训练系统,例如,图1的机器学习模型训练系统100,能够执行过程200。
系统能够对于多个不同的标记批次-未标记批次组合执行过程200多次以根据模型参数的初始值确定模型参数的训练值,即,能够在迭代训练过程的不同训练迭代时重复地执行过程200来训练机器学习模型。例如,系统能够继续执行过程200指定的迭代次数、指定的时间量,或者直到参数的值变化下降至阈值以下为止。
系统获得标记批次,即,标记训练输入的批次,并且对于每个标记训练输入,获得应该由机器学习模型通过对标记训练输入执行特定机器学习任务来生成的地面真值输出(步骤202)。
系统获得未标记批次,即,未标记训练输入的批次(步骤204)。
系统从未标记批次和标记批次生成处理后的未标记批次和处理后的标记批次(步骤206)。
在处理后的标记批次和处理后的未标记批次被生成之后,处理后的标记批次中的每个输入和处理后的未标记批次中的每个输入与相应的目标模型输出相关联。
通常,并且如将在下面更详细地描述的,处理后的标记批次中的每个输入对应于标记训练输入中的相应一个,并且对于输入的目标输出(i)是对于对应的标记训练输入的地面真值输出或者(ii)是从对于对应的标记训练输入的地面真值输出被导出的。
通常,并且如将在下面更详细地描述的,处理后的未标记批次中的每个输入对应于未标记训练输入中的相应一个并且对于输入的目标输出(i)是对于对应的未标记训练输入的猜测模型输出或者(ii)是从对于对应的未标记训练输入的猜测模型输出被导出的。猜测模型输出是基于机器学习模型的模型输出而生成的猜测模型输出,即,并且不是从作为输入提供给系统的任何地面真值信息生成的猜测模型输出。
在下面参考图3更详细地描述生成处理后的标记批次和处理后的未标记批次。
系统在处理后的标记批次和处理后的未标记批次上训练机器学习模型以调整模型参数的当前值(步骤208)。
特别地,系统通过计算包括标记损失项和未标记损失项的自监督学习损失函数的梯度来确定对模型参数的当前值的更新。例如,损失函数能够是标记损失项和未标记损失项的和或加权和。
标记损失项测量以下各项之间的误差:对于处理后的标记批次中的每个输入,(i)由机器学习模型依照参数的当前值对于输入生成的模型输出以及(ii)对于处理后的标记批次中的输入的目标输出。
例如,标记损失项可以是以下各项之间的交叉熵损失:(i)由机器学习模型依照参数的当前值为输入生成的模型输出以及(ii)对于处理后的标记批次中的输入的目标输出。
未标记损失项测量以下各项之间的误差:对于处理后的未标记批次中的每个输入,(i)由机器学习模型依照参数的当前值为输入生成的模型输出以及(ii)对于处理后的未标记批次中的输入的目标输出。
未标记损失项可以是与标记损失项或不同类型的机器学习损失相同的损失,例如交叉熵损失。例如,当标记损失项是交叉熵损失时,未标记损失可能是以下各项之间的平方L2损失:(i)由机器学习模型依照参数的当前值为输入生成的模型输出以及(ii)对于处理后的未标记批次中的输入的目标输出。使用平方L2损失可以是有益的,因为与交叉熵损失不同,它是有界的并且对错误预测不太敏感。
更具体地,系统通过相对于模型参数并在处理后的标记批次和处理后的未标记批次上计算损失函数的梯度、然后使用该梯度来更新模型参数的当前值来更新模型参数的当前值。特别地,系统能够对梯度应用更新规则,例如学习速率、亚当(Adam)优化器更新规则或rmsProp更新规则,以生成更新并且然后将更新应用,即减去或添加到梯度当前值,以确定模型参数的更新值。
图3A是用于生成处理后的标记批次和处理后的未标记批次的示例过程300的流程图。为了方便,过程300将被描述为由位于一个或多个位置中的一个或多个计算机的系统执行。例如,适当地编程的机器学习模型训练系统,例如,图1的机器学习模型训练系统100,能够执行过程300。
系统从标记批次生成初始处理后的标记批次(步骤302)。
在一些实施方式中,初始处理后的标记批次与标记批次相同,即,系统不修改标记批次中的标记训练输入。
在一些其他实施方式中,为了生成初始处理后的标记批次,系统从每个标记训练输入生成相应扩增的标记训练输入并且使该扩增的标记训练输入与对于标记训练输入的地面真值输出相关联。
系统使用以生成扩增的标记训练输入的数据扩增技术能够是适于模型被配置成处理的输入的类型的任何常规扩增技术。在下面参考图3B描述数据扩增技术的示例。
因此,在这些实施方式中,初始处理后的标记批次包括各自与对应的地面真值输出相关联的扩增的标记训练输入的集合。
系统从未标记批次生成初始处理后的未标记批次(步骤304)。
对于未标记批次中的每个未标记训练输入,初始处理后的未标记批次包括各自与同一猜测模型输出相关联的K个扩增的未标记训练输入。为了确保多样性,K被设置为大于1的固定正整数,例如2、4或5。因此,代替未标记批次中的每个未标记训练输入,未标记批次包括从未标记训练输入生成的多个扩增的未标记训练输入。
图3B是示出初始处理后的未标记批次的生成的示意图。
特别地,为了为给定未标记训练输入320生成扩增的未标记训练输入,系统从未标记训练输入生成K个扩增的未标记训练输入330。
特别地,系统对未标记训练输入应用数据扩增技术K次以为未标记训练输入320生成K个扩增的未标记训练输入330。
数据扩增技术能够是适于模型被配置成处理的输入的类型并且是随机的,即,使得将相同技术多次应用于相同输入将通常导致多个不同的扩增输出的任何常规扩增技术。
作为一个示例,当输入是图像时,扩增技术能够是对每个输入图像应用随机水平翻转、随机垂直翻转、随机裁切或随机旋转中的一种或多种的技术。
作为另一示例,当输入是图像时,扩增技术能够是将从分布,例如,高斯分布,采样的扰动添加到每个输入的技术。
作为另一示例,当输入是音频数据时,扩增技术能够是将从分布,例如,高斯分布,采样的扰动添加到每个输入的技术。
作为另一示例,当输入包括文本数据时,扩增技术能够是向文本数据中的每个单词指配概率、依照概率选择固定数量的单词、然后利用不同的单词,例如,从可能的单词的词汇表中采样的单词,替换每个所选择的单词的技术。
系统然后依照参数的当前值使用机器学习模型来处理K个扩增的未标记训练输入330中的每一个,以为每个扩增的未标记训练输入生成相应的模型输出340,即,以生成K个模型输出340.
系统然后从为K个扩增的未标记训练输入的模型输出生成单个猜测模型输出并且使该猜测模型输出与K个扩增的未标记训练输入330中的每一个相关联,即,使得每个扩增的未标记训练输入330与相同的猜测模型输出相关联。
更具体地,系统计算对于K个扩增的未标记训练输入的模型输出的平均值350。
在一些实施方式中,系统使用平均值作为猜测模型输出。
然而,在一些其他实施方式中,系统对模型输出的平均值应用锐化函数360以减少平均值中的不确定性,然后使用锐化函数的输出作为猜测模型输出。
特别地,当模型输出包括L个概率时,锐化函数对于模型输出中的第i个概率的输出满足:
Figure BDA0003337729220000111
其中pi是模型输出的平均值中的第i个概率的值并且T是介于0与1之间的超参数,不包括例如.25、.5或.75。通过将T设置在0与1之间,系统锐化平均值概率分布以降低平均值概率分布的熵。
因此,如能够从上述描述看到的,对于给定未标记训练输入的猜测模型输出是基于由模型为训练输入的扩增版本生成的模型输出而被生成的,即,并且不是从任何外部地面真值数据被生成的。
在一些实施方式中,系统使用初始处理后的标记批次和初始处理后的未标记批次作为用于训练的给定迭代的最终处理后的批次,即,作为在其上计算上面参考图2描述的梯度的批次。
然而,在一些其他实施方式中,系统进一步处理初始处理后的标记批次、初始处理后的未标记批次或两者以生成最终处理后的批次。例如,系统能够执行该进一步处理以规则化训练过程并且一旦被训练就改进模型的泛化。
特别地,在一些实施方式中,系统通过为每个特定扩增的标记输入和关联的地面真值输出生成与处理后的地面真值输出相关联的处理后的标记输入来生成处理后的最终批次(步骤306)。
特别地,对于每个给定特定扩增的标记输入,系统从包括至少扩增的标记输入和关联的地面真值输出的集合中选择输入-输出对。
也就是说,在一些情况下,该集合仅包括扩增的标记输入和关联的地面真值输出。
然而,在其他情况下,该集合包括(i)扩增的标记输入和关联的地面真值输出以及(ii)扩增的未标记输入和关联的猜测输出。在一些情况下,在该集合中包括(i)和(ii)两者能够为模型的训练提供改进的规则化。
例如,系统能够通过在无替换的情况下从集合中随机地采样来选择对,即,使得同一对不被选择用于多于一个特定扩增的标记输入。
系统然后执行扩增的标记输入和所选择的对中的输入的凸组合以生成处理后的输入。为了执行该凸组合,系统能够从预定分布中对权重λ进行采样,然后计算扩增的标记输入和所选择的对中的输入之间的加权和,其中扩增的标记输入被指配了权重λ而所选择的对中的输入被指配了权重(1-λ)。
系统还执行与扩增的标记输入相关联的地面真值输出和所选择的对中的输出的凸组合以生成处理后的模型输出。为了执行该凸组合,系统能够计算地面真值输出与所选择的对中的输出之间的加权和,其中地面真值被指配了权重λ而所选择的对中的输出被指配了权重(1-λ)。
系统然后使处理后的输入与处理后的输出相关联。
代替执行步骤306除此之外,在一些实施方式中,系统为每个特定扩增的未标记输入和关联的猜测输出生成与处理后的猜测输出相关联的处理后的未标记输入(步骤308)。
特别地,对于每个给定特定扩增的未标记输入,系统从包括上述(i)和(ii)的集合中选择输入-输出对。例如,系统能够通过在无替换的情况下从集合中随机地采样来选择对,即,使得同一对不被选择用于不止一个特定扩增的标记输入。当还执行步骤306时,在对用于扩增的未标记输入的对进行采样之前,系统从集合中移除在生成最终处理后的标记批次时采样的任何对。
系统然后执行扩增的未标记输入和所选择的对中的输入的凸组合以生成处理后的输入。为了执行该凸组合,系统能够从预定分布中对权重λ进行采样,然后如上面参考步骤306所描述的那样计算加权和。
系统还执行与扩增的未标记输入相关联的猜测输出和所选择的对中的输出的凸组合以生成处理后的模型输出,即,通过如上所述使用λ来计算加权和。
系统然后使处理后的输入与处理后的输出相关联。
在一些实施方式中,代替使用从分布中直接采样的值作为λ值,系统能够使用修改后的值作为在凸组合中使用的λ值。特别地,当分布在介于0与1(不包括0和1或包括0和1)之间的值的范围内时,系统能够从分布中对值进行采样,然后将λ值设置为(i)样本值和(ii)1中的最大值减去采样值。这确保处理后的特定扩增的输入,即标记或未标记输入,比从集合中采样的输入更接近原始扩增的输入。
图4示出所描述的技术相对于其他半监督学习技术的性能。
特别地,图4示出所描述的技术(“MixMatch”)与若干竞争基线的在两个数据集上并使用不同数量的标记训练输入的比较。特别地,图表410示出所描述的技术和基线集在CIFAR-10数据集上的性能,然而图表420示出所描述的技术和基线集在SVHN数据集上的性能。
如从能够从图4看到的,在不同大小的标记数据下,使用所描述的技术来训练的训练后的模型的错误率始终较低,即比基线更好,即,给定不同大小的标记数据所描述的技术始终胜过基线。
因此,即使与其他半监督学习技术,即,使用标记数据和未标记数据两者的其他技术相比,所描述的技术也产生更有效的模型训练。
本说明书连同系统和计算机程序组件一起使用术语“被配置”与系统和计算机程序组件相结合。对于要被配置成执行特定操作或动作的一个或多个计算机的系统意味着系统已经在其上安装了在操作中使该系统执行这些操作或动作的软件、固件、硬件或软件、固件、硬件的组合。对于要被配置成执行特定操作或动作的一个或多个计算机程序意味着该一个或多个程序包括指令,该指令当由数据处理装置执行时,使该装置执行操作或动作。
本说明书中描述的主题和功能操作的实施例能够以数字电子电路、以有形地体现的计算机软件或固件、以包括本说明书中公开的结构及其结构等价物的计算机硬件或者以它们中的一个或多个的组合而被实现。本说明书中描述的主题的实施例能够被实现为一个或多个计算机程序,即,在有形非暂时性存储介质上编码以用于由数据处理装置执行或者控制数据处理装置的操作的计算机程序指令的一个或多个模块。计算机存储介质能够是机器可读存储设备、机器可读存储基板、随机或串行访问存储设备或它们中的一个或多个的组合。可替代地或另外,程序指令能够被编码在人工生成的传播信号上,该传播信号例如是机器生成的电、光或电磁信号,其被生成为对信息进行编码以向适合的接收器装置传输以供由数据处理装置执行。
术语“数据处理装置”指数据处理硬件并且涵盖用于处理数据的所有种类的装置、设备和机器,作为示例,包括可编程处理器、计算机或多个处理器或计算机。装置还能够是或者进一步包括专用逻辑电路,例如,FPGA(现场可编程门阵列)或ASIC(专用集成电路)。装置除了包括硬件之外还能够可选地包括为计算机程序创建执行环境的代码,例如,构成处理器固件、协议栈、数据库管理系统、操作系统或它们中的一个或多个的组合的代码。
也可以被称为或者被描述为程序、软件、软件应用、app、模块、软件模块、脚本或代码的计算机程序能够以包括编译或解释语言或声明或过程语言的任何形式的编程语言被编写;并且它能够以任何形式被部署,包括作为独立程序或者作为模块、组件、子例行程序或适合于在计算环境中使用的其它单元。程序可以但是不需要对应于文件系统中的文件。程序能够被存储在保持其它程序或数据的文件的一部分中,例如,存储在标记语言文档中的一个或多个脚本;在专用于所涉及的程序的单个文件中或者在多个协调文件中,例如存储代码的一个或多个模块、子程序或部分的文件。计算机程序能够被部署为在一个计算机上或者在位于一个站点处或者分布在多个站点上并通过数据通信网络互连的多个计算机上被执行。
在本说明书中,术语“数据库”被广泛用于指任何数据的合集:数据不需要以任何特定方式被结构化,或者根本不需要被结构化,并且其能够被存储在一个或多个位置中的存储设备中。因此,例如,索引数据库能够包括多个数据合集,每个数据合集可以被不同地组织和访问。
类似地,在本说明书中术语“引擎”被广泛地用于指被编程为执行一个或多个具体功能的基于软件的系统、子系统或过程。通常,引擎将被实现为安装在一个或多个位置中的一个或多个计算机上的一个或多个软件模块或组件。在一些情况下,一个或多个计算机将被专用于特定引擎;在其它情况下,多个引擎能够在同一计算机或多个计算机上被安装并运行。
本说明书中描述的过程和逻辑流程能够由执行一个或多个计算机程序的一个或多个可编程计算机执行以通过对输入数据进行操作并生成输出来执行功能。该过程和逻辑流程还能够由例如FPGA或ASIC的专用逻辑电路执行,或者通过专用逻辑电路和一个或多个编程计算机的组合而被执行。
适合于执行计算机程序的计算机能够基于通用微处理器或专用微处理器或两者,或任何其它种类的中央处理器。通常,中央处理单元将从只读存储器或随机存取存储器或两者接收指令和数据。计算机的必要元件是用于执行或者实行指令的中央处理单元以及用于存储指令和数据的一个或多个存储设备。中央处理单元和存储器能够由专用逻辑电路补充或者被并入在专用逻辑电路中。通常,计算机还将包括用于存储数据的一个或多个大容量存储设备,例如磁盘、磁光盘或光盘,或者可操作地被耦合为从该一个或多个大容量存储设备接收数据或者将数据传送到该一个或多个大容量存储设备,或者两者。然而,计算机不需要具有这样的设备。此外,计算机能够被嵌入在另一设备中,例如,移动电话、个人数字助理(PDA)、移动音频或视频播放器、游戏控制器、全球定位系统(GPS)接收器或便携式存储设备,例如通用串行总线(USB)闪存驱动器等。
适合于存储计算机程序指令和数据的计算机可读介质包括所有形式的非易失性存储器、介质和存储设备,作为示例,包括半导体存储设备,例如,EPROM、EEPROM和闪速存储器设备;磁盘,例如,内部硬盘或可移动盘;磁光盘;以及CD ROM和DVD-ROM盘。
为了提供与用户的交互,本说明书中描述的主题的实施例能够在计算机上被实现,该计算机具有用于向用户显示信息的显示设备,例如,CRT(阴极射线管)或LCD(液晶显示器)监视器,以及用户能够通过其向计算机提供输入的键盘和定点指向设备,例如,鼠标或轨迹球。其它种类的设备能够被用于提供与用户的交互;例如,提供给用户的反馈能够是任何形式的感觉反馈,例如,视觉反馈、听觉反馈或触觉反馈;并且来自用户的输入能够以任何形式被接收,包括声学、语音或触觉输入。另外,计算机能够通过向由用户使用的设备发送文档并从由用户使用的设备接收文档来与用户交互;例如,通过响应于从web浏览器接收到请求而向用户的设备上的web浏览器发送网页。另外,计算机能够通过向个人设备,例如,正在运行消息传送应用的智能电话,发送文本消息或其它形式的消息并且反过来从用户接收响应消息来与用户交互。
用于实现机器学习模型的数据处理装置还能够包括,例如,用于处理机器学习训练或生产,即,推理、工作负载,的公共和计算密集部分的专用硬件加速器单元。
机器学习模型能够使用机器学习框架被实现和部署,例如,TensorFlow框架、Microsoft Cognitive Toolkit框架、Apache Singa框架或Apache MXNet框架。
本说明书中描述的主题的实施例能够被实现在计算系统中,该计算系统包括后端组件,例如,作为数据服务器,或者包括中间件组件,例如,应用服务器,或者包括前端组件,例如,具有用户通过其能够与本说明书中描述的主题的实施方式交互的图形用户界面、web浏览器或app的客户端计算机,或者包括一个或多个这种后端、中间件或前端组件的任何组合。系统的组件能够通过例如通信网络的任何形式或介质的数字数据通信而被互连。通信网络的示例包括局域网(LAN)和广域网(WAN),例如互联网。
计算系统能够包括客户端和服务器。客户端和服务器一般地通常彼此远离并通常通过通信网络来交互。客户端和服务器的关系借助于在相应的计算机上运行并且彼此具有客户端-服务器关系的计算机程序而产生。在一些实施例中,服务器向用户设备发送例如HTML页面的数据,例如,出于向与作为客户端的设备交互的用户显示数据并从该用户接收用户输入的目的。在用户设备处生成的数据,例如,用户交互的结果,能够在服务器处从设备被接收。
虽然本说明书包含许多具体实施方式细节,但是这些不应该被解释为对任何发明的或可能要求保护的范围的限制,而是相反地被解释为对可以特定于特定发明的特定实施例的特征的描述。在本说明书中的单独的实施例的上下文中描述的某些特征也能够在单个实施例中组合地被实现。相反地,在单个实施例的上下文中描述的各种特征也能够单独地或者按照任何适合的子组合在多个实施例中被实现。此外,尽管特征可能在上文被描述按照某些组合起作用并且甚至最初被如此要求保护,但是来自要求保护的组合的一个或多个特征能够在一些情况下被从该组合中去除,并且要求保护的组合可以被引导为子组合或子组合的变形。
类似地,虽然操作按照特定顺序在附图中被描绘并在权利要求书中被记载,但是这不应该被理解为需要这种操作以所示的特定顺序或者以先后顺序被执行,或者需要所有图示的操作被执行以实现预期的结果。在某些情况下,多任务处理和并行处理可以是有益的。此外,上述实施例中的各种系统模块和组件的分离不应该被理解为在所有实施例中需要这种分离,并且应该理解的是,描述的程序组件和系统一般地能够被一起集成在单个软件产品中或者封装到多个软件产品中。
已经描述了主题的特定实施例。其它实施例在所附权利要求的范围内。例如,权利要求中记载的动作能够以不同的顺序被执行并仍然实现预期的结果。作为一个示例,附图中描绘的过程不一定需要示出的特定顺序或依次的顺序以实现预期的结果。在一些情况下,多任务处理和并行处理可以是有益的。

Claims (12)

1.一种训练具有多个参数的机器学习模型以执行机器学习任务的方法,其中,所述机器学习模型被配置成接收输入并且根据所述参数处理所述输入以生成模型输出,所述方法包括:
接收包括多个未标记训练输入的未标记批次;
接收包括多个标记训练输入的标记批次,并且对于每个标记训练输入,接收应由所述机器学习模型通过对所述标记训练输入执行特定机器学习任务来生成的地面真值输出;
从所述未标记批次和所述标记批次生成处理后的未标记批次和处理后的标记批次,所述生成包括:
对于所述多个未标记训练输入中的每个未标记训练输入:
从所述未标记训练输入生成多个扩增的未标记训练输入;
使用所述机器学习模型根据所述参数的当前值处理所述扩增的未标记训练输入中的每一个,以为每个扩增的未标记训练输入生成相应的模型输出;
从所述扩增的未标记训练输入的所述模型输出生成猜测模型输出;以及
将所述猜测的模型输出与所述扩增的未标记训练输入中的每一个相关联;以及
在所述处理后的标记批次和所述处理后的未标记批次上训练所述机器学习模型以调整所述参数的所述当前值。
2.根据权利要求1所述的方法,其中,所述机器学习模型的输入是图像,并且所述模型输出是对象类的集合上的概率分布。
3.根据权利要求1所述的方法,其中,所述机器学习模型的输入是一个或多个视频帧,并且所述模型输出是对象类的集合上的概率分布或主题的集合上的概率分布。
4.根据权利要求1所述的方法,其中,所述机器学习模型的输入是文本,并且所述模型输出是主题的集合上的概率分布。
5.根据权利要求1所述的方法,其中,所述机器学习模型的输入是音频信号,并且所述模型输出是自然语言文本的集合上的概率分布。
6.根据任一前述权利要求所述的方法,其中,从所述扩增的未标记训练输入的模型输出生成猜测模型输出包括:
计算所述扩增的未标记训练输入的所述模型输出的平均值。
7.根据权利要求6所述的方法,其中,从所述扩增的未标记训练输入的所述模型输出生成猜测模型输出进一步包括:
对所述模型输出的平均值应用锐化函数以减少所述平均值的不确定性。
8.根据任一前述权利要求所述的方法,其中,从所述未标记批次和所述标记批次生成处理后的未标记批次和处理后的标记批次进一步包括:
对于所述多个标记训练输入中的每个标记训练输入:
从所述标记训练输入生成扩增的标记训练输入;和
将所述扩增的标记训练输入与所述标记训练输入的地面真值输出相关联。
9.根据权利要求8所述的方法,其中,从所述未标记批次和所述标记批次生成处理后的未标记批次和处理后的标记批次进一步包括:
为每个特定的扩增的标记输入和相关联的地面真值输出生成与处理后的地面真值输出相关联的处理后的标记输入,包括:
从以下的集合中选择输入-输出对:(i)扩增的标记输入和相关联的地面真值输出和(ii)扩增的未标记输入和相关联的猜测输出;
执行所述扩增的标记输入与所述输入选择对中的输入的凸组合以生成处理后的输入;
执行与所述扩增的标记输入相关联的所述地面真值输出与所选择的对中的所述输出的凸组合以生成处理后的输出;以及
将所述处理后的输入与所述处理后的输出相关联。
10.根据权利要求8或9中的任一项所述的方法,其中,从所述未标记批次和所述标记批次生成处理后的未标记批次和处理后的标记批次进一步包括:
为每个特定的扩增的标记输入和相关联的猜测输出生成与处理后的猜测输出相关联的处理后的未标记输入,包括:
从以下的集合中选择输入-输出对:(i)扩增的标记输入和相关联的地面真值输出和(ii)扩增的未标记输入和相关联的猜测输出;
执行所述扩增的未标记输入和所选择的对中的输入的凸组合以生成处理后的输入;
执行与所述扩增的未标记输入相关联的所述猜测输出和所选择的中的所述输出的凸组合以生成处理后的输出;以及
将所述处理后的输入与所述处理后的输出相关联。
11.一种包括一个或多个计算机和存储指令的一个或多个存储设备的系统,所述指令当由所述一个或多个计算机执行时,使所述一个或多个计算机执行根据权利要求1-10中的任一项所述的相应方法的操作。
12.一种或多种存储指令的计算机存储介质,所述指令当由一个或多个计算机执行时,使所述一个或多个计算机执行根据权利要求1-10中的任一项所述的相应方法的操作。
CN202080033626.6A 2019-05-06 2020-05-06 使用标签猜测对机器学习模型进行半监督训练 Pending CN113785314A (zh)

Applications Claiming Priority (3)

Application Number Priority Date Filing Date Title
US201962843806P 2019-05-06 2019-05-06
US62/843,806 2019-05-06
PCT/US2020/031691 WO2020227418A1 (en) 2019-05-06 2020-05-06 Semi-supervised training of machine learning models using label guessing

Publications (1)

Publication Number Publication Date
CN113785314A true CN113785314A (zh) 2021-12-10

Family

ID=70919062

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202080033626.6A Pending CN113785314A (zh) 2019-05-06 2020-05-06 使用标签猜测对机器学习模型进行半监督训练

Country Status (4)

Country Link
US (1) US20220230065A1 (zh)
EP (1) EP3948691A1 (zh)
CN (1) CN113785314A (zh)
WO (1) WO2020227418A1 (zh)

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN117574258A (zh) * 2024-01-15 2024-02-20 合肥综合性国家科学中心人工智能研究院(安徽省人工智能实验室) 一种基于文本噪声标签和协同训练策略的文本分类方法

Families Citing this family (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN112767922B (zh) * 2021-01-21 2022-10-28 中国科学技术大学 一种对比预测编码自监督结构联合训练的语音识别方法
CN113011531B (zh) * 2021-04-29 2024-05-07 平安科技(深圳)有限公司 分类模型训练方法、装置、终端设备及存储介质
CN114943879B (zh) * 2022-07-22 2022-10-04 中国科学院空天信息创新研究院 基于域适应半监督学习的sar目标识别方法

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN117574258A (zh) * 2024-01-15 2024-02-20 合肥综合性国家科学中心人工智能研究院(安徽省人工智能实验室) 一种基于文本噪声标签和协同训练策略的文本分类方法
CN117574258B (zh) * 2024-01-15 2024-04-26 合肥综合性国家科学中心人工智能研究院(安徽省人工智能实验室) 一种基于文本噪声标签和协同训练策略的文本分类方法

Also Published As

Publication number Publication date
EP3948691A1 (en) 2022-02-09
US20220230065A1 (en) 2022-07-21
WO2020227418A1 (en) 2020-11-12

Similar Documents

Publication Publication Date Title
US20210150355A1 (en) Training machine learning models using task selection policies to increase learning progress
US11681924B2 (en) Training neural networks using a variational information bottleneck
US11443170B2 (en) Semi-supervised training of neural networks
US20180189950A1 (en) Generating structured output predictions using neural networks
US11922281B2 (en) Training machine learning models using teacher annealing
WO2019083553A1 (en) NEURONAL NETWORKS IN CAPSULE
US20240127058A1 (en) Training neural networks using priority queues
US11010664B2 (en) Augmenting neural networks with hierarchical external memory
CN113785314A (zh) 使用标签猜测对机器学习模型进行半监督训练
US20220215209A1 (en) Training machine learning models using unsupervised data augmentation
US20210049298A1 (en) Privacy preserving machine learning model training
EP3371747A1 (en) Augmenting neural networks with external memory
US10824946B2 (en) Training neural networks using posterior sharpening
US20220383120A1 (en) Self-supervised contrastive learning using random feature corruption
CN114861873A (zh) 多阶段计算高效的神经网络推断
CN114730380A (zh) 神经网络的深度并行训练
US20230206030A1 (en) Hyperparameter neural network ensembles
US20240152809A1 (en) Efficient machine learning model architecture selection

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination