CN114861873A - 多阶段计算高效的神经网络推断 - Google Patents

多阶段计算高效的神经网络推断 Download PDF

Info

Publication number
CN114861873A
CN114861873A CN202210391735.6A CN202210391735A CN114861873A CN 114861873 A CN114861873 A CN 114861873A CN 202210391735 A CN202210391735 A CN 202210391735A CN 114861873 A CN114861873 A CN 114861873A
Authority
CN
China
Prior art keywords
network
neural network
input
output
training
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202210391735.6A
Other languages
English (en)
Inventor
安基特·辛格·拉瓦特
曼齐尔·查希尔
阿迪蒂亚·克里希纳·梅农
桑吉夫·库马尔
阿姆尔·艾哈迈德
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Google LLC
Original Assignee
Google LLC
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Google LLC filed Critical Google LLC
Publication of CN114861873A publication Critical patent/CN114861873A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • G06F18/2148Generating training patterns; Bootstrap methods, e.g. bagging or boosting characterised by the process organisation or structure, e.g. boosting cascade
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/211Selection of the most significant subset of features
    • G06F18/2113Selection of the most significant subset of features by ranking or filtering the set of features, e.g. using a measure of variance or of feature cross-correlation
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/217Validation; Performance evaluation; Active pattern learning techniques
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/22Matching criteria, e.g. proximity measures
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • G06F18/241Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
    • G06F18/2413Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches based on distances to training or reference patterns
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • G06F18/243Classification techniques relating to the number of classes
    • G06F18/2431Multiple classes
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/044Recurrent networks, e.g. Hopfield networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/0895Weakly supervised learning, e.g. semi-supervised or self-supervised learning
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N5/00Computing arrangements using knowledge-based models
    • G06N5/04Inference or reasoning models
    • G06N5/041Abduction
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/084Backpropagation, e.g. using gradient descent

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Data Mining & Analysis (AREA)
  • Physics & Mathematics (AREA)
  • Artificial Intelligence (AREA)
  • Evolutionary Computation (AREA)
  • General Engineering & Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Evolutionary Biology (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Computational Linguistics (AREA)
  • Computing Systems (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Biomedical Technology (AREA)
  • Molecular Biology (AREA)
  • General Health & Medical Sciences (AREA)
  • Biophysics (AREA)
  • Health & Medical Sciences (AREA)
  • Image Analysis (AREA)

Abstract

本公开涉及多阶段计算高效的神经网络推断。用于使用第一神经网络和第二神经网络的多阶段计算高效推断的方法、系统和装置,所述方法、系统和装置包括在计算机存储介质上编码的计算机程序。

Description

多阶段计算高效的神经网络推断
技术领域
本说明书涉及使用神经网络对输入进行分类。
背景技术
神经网络是采用一层或多层非线性单元来预测接收到的输入的输出的机器学习模型。一些神经网络除了输出层之外还包括一个或多个隐藏层。每个隐藏层的输出用作网络中下一层——即下一个隐藏层或输出层——的输入。网络的每一层根据相应一组参数的当前值从接收到的输入生成输出。
发明内容
本说明书描述了在一个或多个位置的一个或多个计算机上实现为计算机程序的系统,该系统使用第一神经网络对于新网络输入执行分类任务,并且对于一些输入使用第二神经网络。
可以实现本说明书中描述的主题的特定实施例以实现以下优点中的一个或多个。
本说明书描述了一种两阶段推断机制,该机制首先使用计算量轻的学生神经网络(也称为“第一神经网络”)来进行预测。如果学生神经网络对预测的准确性有信心,则系统就会发出该预测。只有当学生神经网络对预测的准确性没有信心时,系统才会使用计算量大的教师神经网络(也称为“第二神经网络”)进行预测,然后发出教师的预测。
因为学生神经网络通常对所有“简单”实例有信心,并且因为现实世界中的数据,即在生产设置中的推断时间,是严重长尾的,所以向系统的大部分推断时间查询仅使用学生进行处理。“长尾”的数据是指大部分数据属于类别的真子集(“简单”类别),其中只有少量数据属于少数剩余类别(“硬”类别)。因此,现实世界中的绝大多数输入将是学生神经网络对其有信心的示例,因此不需要使用教师神经网络。
当学生没有信心时(这仅在推断时发生在少数困难示例中)系统回退到更大的(“巨型”)教师模型。因此,可以为所有推断时间输入生成准确的预测,同时仅对一小部分输入使用大型教师模型。这允许推断保持计算高效(特别是对于仅使用轻量学生模型处理的“简单”示例),但具有显著提高的预测/分类准确性。
此外,这种两阶段推断在现代设置(如边缘计算和5G小云)中特别有用,其中轻量学生模型在设备上运行,以低时延进行大部分预测,并且仅在很少的情况下有硬实例需要委托给在云中运行的共享巨型教师模型。也就是说,所描述的技术允许完全在边缘计算设备处处理绝大多数推断时间输入,其中只有少量推断时间输入被路由到云进行处理。相对于将所有输入都路由到云进行处理的传统方法,这大大减少了需要通过数据通信网络发送的数据量,同时相对于在边缘计算设备处本地地进行所有预测的方法,这显著提高了预测/分类的准确性,其中仅有最少量的网络流量增加。
本说明书的主题的一个或多个实施例的细节在附图和以下描述中阐述。本主题的其它特征、方面和优点将从描述、附图和权利要求中变得显而易见。
附图说明
图1示出了示例训练系统和示例推断系统。
图2是为新网络输入生成分类输出的流程图。
图3是用于训练第一和第二神经网络的示例过程的流程图。
各种附图中相同的附图标记和名称表示相同的元件。
具体实施方式
图1示出了示例训练系统100和示例推断系统150。
训练系统100和推断系统150是在一个或多个位置的一个或多个计算机上实现为计算机程序的系统的示例,其中可以实现下文描述的系统、组件和技术。系统100和系统150可以在同一组一个或多个计算机上或在不同位置的一个或多个计算机的不同组上实现。
训练系统100用训练数据102训练第一神经网络110和第二神经网络120以执行分类任务。
一旦第一神经网络110和第二神经网络120已经被训练,推断系统150就使用经训练的第一神经网络110和第二神经网络120来执行推断,即接收新网络输入152和处理新网络输入152以生成用于分类任务的分类输出154。
神经网络110和120可以被配置成执行多种分类任务中的任一种。如在本说明书中所使用的,分类任务是以下任何任务,该任务需要神经网络110或120生成包括一组多个类别中的每一个类别的相应分数的输出,并且然后使用相应的分数选择一个或多个类别作为对网络输入的“分类”。
分类任务的一个示例是图像分类,其中到神经网络110的输入是图像,即图像像素的强度值,类别是对象类别,而任务是将图像分类为该图像描绘了对象类别中的一个或多个对象。也就是说,给定输入图像的分类输出是对输入图像中描绘的一个或多个对象类别的预测。
分类任务的另一示例是文本分类,其中到神经网络110的输入是文本,而任务是将文本分类为属于一个多个类别。这种任务的一个示例是情绪分析任务,其中每个类别对应于任务的不同可能情绪。这种任务的另一个示例是阅读理解任务,其中输入文本包括上下文段落和问题,并且每个类别对应于上下文段落中可能是问题答案的不同片段。可以被构建为分类任务的文本处理任务的其它示例包括含义任务、释义任务、文本相似性任务、情绪任务、句子完成任务、语法任务等。
分类任务的其它示例包括语音处理任务,其中到神经网络的输入是表示语音的音频数据。语音处理任务的示例包括语言识别(其中类别是语音的不同可能语言)、热词识别(其中类别指示音频数据中是否说出一个或多个特定“热词”)等等。
作为另一个示例,该任务可以是健康预测任务,其中输入是从患者的电子健康记录数据导出的序列,并且类别是与患者未来健康相关的相应预测,例如,应为患者开具的预测治疗,患者发生不良健康事件的可能性,或患者的预测诊断。
因此,如上所述,神经网络110和120被配置成处理网络输入以生成以下输出,该输出包括一组的多个类别中的每一个类别的相应分数。
第一神经网络110和第二神经网络120通常都是被配置成执行相同分类任务的神经网络。
在一些情况下,第一神经网络110和第二神经网络120具有相同的架构,因此具有相同数量的参数。例如,两个神经网络都可以是卷积神经网络、基于自注意力的神经网络(Transformers)或循环神经网络。
然而,在一些其它情况下,该两个神经网络具有不同的架构,其中第二神经网络120具有比第一神经网络110更多数量的参数。在这些情况下,在推断时,更大、计算效率低的第二神经网络120用于提高更小、计算高效的第一神经网络110的性能。
例如,该两个神经网络都可以是卷积神经网络、基于自注意力的神经网络(Transformers)或循环神经网络,但是第一神经网络110由于具有更少的层而具有更少的参数,在具有较小尺寸的内部表示上运行(例如,在卷积层的情况下,输出过滤器较少,或者在Transformer中的自注意力子层的查询、键和值的较小尺寸),或两者兼而有之。
例如,在推断时,第一神经网络110可以部署在边缘计算设备上或部署在具有有限计算预算的另一个计算环境中,而第二神经网络120部署在远离第一神经网络110且包括一个或多个计算机的云计算系统上,例如,在有大量计算资源可用的云中。
因此,如下文将描述的,可以成功地利用额外的计算资源来提高第一神经网络110在“长尾”或“硬”网络输入上的性能,同时仅使用计算高效的第一神经网络110用于“简单”的输入。
作为另一个示例,在推断时,第一神经网络110和第二神经网络120可以部署在相同的一组一个或多个计算机上,但是第一神经网络110可以部署在与第二神经网络不同的硬件上。例如,第一神经网络110可以部署在计算机的一个或多个ASIC上(并且在某些情况下,部署在计算机的一个或多个ASIC的优化器上),例如视觉处理单元(VPU)、张量处理单元(TPU),图形处理单元(GPU),而第二神经网络120可以在计算机的其它硬件上执行,例如,使用一个或多个中央处理单元(CPU)。结果,由于第一神经网络110部署在ASIC上,因此“简单”输入可以以减少的时延、减少的功耗或两者进行处理,而第二神经网络120可以用于确保对于“硬”输入来说保持预测质量高。
更具体地,在推断时,系统150使用第一和第二神经网络110和120对新网络输入104执行推断。也就是说,系统150为每个网络输入104使用这两个神经网络来生成相应的分类输出106。
特别地,当系统150获得新网络输入152时,系统150使用第一神经网络110处理新网络输入104以生成第一网络输出,该第一网络输出包括第一组类别中的每一个类别(即对于分类任务所需的所有类别)的相应第一分数。可选地,第一网络输出还包括“弃权”类的分数,如下文将更详细描述的。
系统150根据第一网络输出确定第一网络输出是否可能不准确。在下面参考图2更详细地描述了确定由第一神经网络110生成的给定网络输出是否可能不准确。
响应于确定第一网络输出不太可能不准确,即,输出可能准确,系统150使用第一网络输出生成分类输出154,例如,通过提供第一网络输出作为分类输出154或通过根据第一网络输出选择一个或多个最高得分类别并提供识别所选类别的数据以及可选地所选类别的第一分数作为分类输出154。即,响应于确定第一网络输出不太可能不准确,系统150基于第一网络输出来对新网络输入152进行分类,而不将新网络输入152作为输入提供给第二神经网络120,即,不使用第二神经网络120。
响应于确定第一网络输出可能不准确,系统150提供新网络输入152作为第二神经网络120的输入。第二神经网络120被配置成处理新网络输入104以生成第二网络输出,该第二网络输出包括第二组类别中的每一个类别的相应第二分数。通常,第二组类别包括第一组类别中的所有类别,并且可选地包括一个或多个附加类别。
系统150然后基于第二网络输出对新网络输入152进行分类。即,系统150使用第二网络输出来生成分类输出154,例如,通过提供第二网络输出作为分类输出154或者通过根据第二网络输出选择一个或多个最高得分类别并且提供识别所选类别的数据以及可选地,所选类别的第二分数作为分类输出154。
因此,当系统确定第一网络输出可能不准确时,系统150仅使用第二神经网络120(仅将新网络输入作为输入提供给第二神经网络)。换言之,系统150仅将第二神经网络120用于被认为是“难”的网络输入,并且仅将第一神经网络110用于被认为是“容易”的网络输入,因为第一神经网络网络110有信心其输出是准确的。
因此,在第一神经网络110部署在边缘计算设备上或部署在具有有限计算预算的另一个计算环境中而第二神经网络120部署在远离第一神经网络110且包括一个或多个计算机的云计算系统上的示例中,系统150可以在边缘计算设备上本地使用第一神经网络110处理新网络输入104,并且仅在响应于确定第一网络输出可能不准确的情况下,将新网络输入104发送到云计算系统以供第二神经网络120处理。
因此,如下文将描述的,可以成功地利用额外的计算资源来提高第一神经网络110在“长尾”或“硬”网络输入上的性能,同时仅使用计算高效的第一神经网络110用于“简单”的输入,从而确保绝大多数输入在设备160上以最小时延本地处理。
作为另一示例,当第一神经网络110和第二神经网络120部署在相同的一组一个或多个计算机上,但第一神经网络110可以部署在与第二神经网络120不同的硬件上时,系统150系统150可以在硬件上(例如专用于第一神经网络110的一个或多个ASIC)本地使用第一神经网络110处理新网络输入104,并且仅在响应于确定第一网络输出可能不准确的情况下,才将新网络输入104发送到专用于第二神经网络120的硬件上以进行处理。结果,由于第一神经网络110部署在ASIC上,因此“简单”的输入可以以减少的时延、减少的功耗或两者来处理,而第二神经网络120可以用于确保对于“硬”输入来说保持预测质量高。
如上所述,在使用神经网络110和120执行推断之前,训练系统100用训练数据102训练这两个神经网络。
通常,训练系统100使用蒸馏来训练神经网络。在蒸馏中,系统100首先训练第二神经网络120,然后使用经训练的第二神经网络120的输出作为训练第一神经网络110的一部分。
通常,系统可以使用任何合适的蒸馏技术来用训练数据102训练神经网络110和120。
在蒸馏技术中,一个神经网络(在这种情况下是第二神经网络120)的输出用于生成用于训练另一个神经网络(在这种情况下是第一神经网络110)的目标。
也就是说,系统100首先训练第二神经网络120,然后使用经训练的第二神经网络120来训练第一神经网络110。
蒸馏技术的示例在下文参考图3进行更详细的描述。
图2是用于为新网络输入生成分类输出的示例过程200的流程图。为方便起见,过程200将被描述为由在一个或多个位置的一个或多个计算机的系统执行。例如,根据本说明书合适地编程的推断系统(例如图1的推断系统150)可以执行过程200。
系统获得新网络输入(步骤202)。
系统使用第一神经网络处理新网络输入以生成第一网络输出,该第一网络输出包括第一组类别中的每一个类别的相应第一分数(步骤204)。
系统根据第一网络输出确定该第一网络输出是否可能不准确(步骤206)。
系统可以以多种方式中的任一种方式确定第一网络输出是否可能不准确。
在一些实现方式中,系统基于第一组类别中的哪个类别具有最高的第一分数来确定第一网络输出是否可能不准确。特别地,系统可以确定具有最高第一分数的类别是否在第一多个类别的第一预定真子集中。如果是,则当具有最高第一分数的类别在该多个类别的预定真子集中时,系统确定第一网络输出可能不准确。
系统或另一系统(例如,图1的训练系统100)可以选择第一真子集,使得在第一真子集中的类别比不在第一个真子集中的类别更可能与第一神经网络的不准确分类相关联,即第一神经网络更有可能将网络输入错误地分类为属于第一真子集中的类别而不是分类为不在第一真子集中的类别。通常,系统可以选择第一真子集,使得第一真子集包括训练数据频率分布的“尾部”中的类别。例如,系统可以选择阈值数量的具有最小出现频率的类别作为第一真子集,或者可以按出现频率的升序选择类别,即从最小的出现频率开始,直到所选择的类的频率总和超过阈值。
在一些其它实现方式中,系统使用基于间隔(margin-based)的方法来确定第一网络输出是否可能不准确,即,基于第一集中的不同类别的分数之间的差异。作为特定示例,系统可以确定(i)第一组类别中的任何类别的最高第一分数与(ii)第一组中的任何类别的第二高第一分数之间的差是否满足阈值,即,小于阈值。如果差值小于阈值,则系统确定第一网络输出可能不准确。
作为另一示例,第一神经网络可以被配置成除了针对第一组类别的第一分数之外生成针对“弃权”类的第一分数。在该示例中,系统可以确定弃权类的第一分数是否高于第一组类别中的任何类别的第一分数,然后仅当弃权类的第一分数高于第一组类别中的任何类别的第一分数时,才确定第一网络输出可能不准确。
在该示例中,如下文将描述的,系统可以使用鼓励第一神经网络将高分数指配给“硬”示例(即,第一神经网络对于正确分类没有信心的示例)的弃权类的目标来训练神经网络。
响应于确定第一网络输出可能不准确,系统将新网络输入作为输入提供给第二神经网络(步骤208)。第二神经网络被配置成处理新网络输入以生成第二网络输出,该第二网络输出包括用于第二多个类别中的每一个类别的相应的第二分数。第二组类别包括第一组类别中的所有类别以及可选的一个或多个附加类别。
然后系统基于第二网络输出对新网络输入进行分类(步骤210)。例如,系统可以将网络输入分类为属于具有最高第二分数的一个或多个类别。
响应于确定第一网络输出不太可能不准确,系统基于第一网络输出对新网络输入进行分类,而不将新网络输入作为输入提供给第二神经网络(步骤212)。
如上所述,在一些实现方式中,第二神经网络部署在远离边缘计算设备的一个或多个第二计算机上。例如,第二神经网络可以部署在云计算系统中,而第一个神经网络可以部署在边缘设备上。在这些实现方式中,网络输入在边缘设备处被接收,并且步骤202、204、206以及当被执行时,步骤212在边缘设备上本地执行。在这些实现方式中,为了将新网络输入作为输入提供给第二神经网络,系统通过数据通信网络将来自边缘计算设备的新网络输入提供给一个或多个第二计算机。然后边缘设备通过数据通信网络从(一个或多个)第二计算机获得第二网络输出。
图3是用于训练第一神经网络的示例过程300的流程图。为方便起见,过程300将被描述为由位于一个或多个位置的一个或多个计算机的系统执行。例如,训练系统,例如图1的训练系统100,根据本说明书适当地编程,可以执行过程300。
在使用常规监督学习或半监督学习技术在训练数据上训练第二神经网络之后,系统可以对不同批量的训练样本重复执行过程300以训练第一神经网络,即以确定第一神经网络的网络参数的训练值。
系统获得一批量的一个或多个训练输入(步骤302)。
在一些实施方式中,一些或所有训练输入与训练输入的标签相关联,该标签识别训练输入的真实值(ground truth)类别。
特别地,系统可以从一组训练数据中采样一批量的一个或多个训练输入,以用于训练第一神经网络。训练数据可以是与用于训练第二神经网络的训练数据相同的或是不同的一组训练数据,例如,包含比用于训练第二神经网络的组更多的训练输入的更大的一组训练数据,例如,除了带有标签的那些训练输入之外,还包括未加标签的训练输入。
系统可以使用任何适当的采样技术从一组训练数据中对批量进行采样。例如,系统可以从一组训练数据中均匀随机采样固定数量的训练样例。作为另一个示例,系统可以使用过采样或欠采样对训练示例进行采样。
对于批量中的每个训练输入,系统使用第二神经网络为批量生成伪标签(步骤304)。即,系统使用训练的第二神经网络处理训练输入,并使用第二网络输出,即第二神经网络生成的分数分布,作为训练输入的伪标签。
系统使用伪标签训练神经网络以最小化损失函数(步骤306)。
特别地,系统例如通过反向传播确定损失函数相对于第一神经网络的网络参数的梯度。
系统使用梯度更新网络参数的当前值。特别是,系统通过使用适当的优化器——例如,Adam、rmsProp、Adafactor、SGD——将梯度映射到更新来更新当前值,然后应用更新,例如,将更新添加到网络参数的当前值或从网络参数的当前值中减去。
在一些实施方式中,损失函数包括针对具有真实值标签的任一个给定训练网络输入鼓励由第一神经网络针对该给定训练输入生成的第一网络输出的项,以匹配从第二神经网络为给定训练输入生成的第二网络输出生成的伪标签,所述真实值标签将该给定训练网络输入指配到第二多个类别的预定真子集中的任一个类别。
在这些实现中的一些中,该项针对具有真实值标签的任一个给定训练网络输入鼓励由第一神经网络针对该给定训练输入生成的第一网络输出,以匹配独立于由第二神经网络为该给定训练输入生成的任一个第二网络输出的分布,所述真实值标签将该给定训练网络输入指配到不在多个类别的预定真子集中的任一个类别。
例如,该分布可以是从给定训练网络输入的真实值标签生成的标签平滑分布。
作为另一示例,该分布可以是第二多个类别上的均匀分布。
作为另一示例,当第一神经网络还为弃权类生成分数时,该分布可以是将非零分数仅指配给弃权类的独热分布。
通常,系统基于用于训练第一神经网络并且具有标签的训练数据中的每个类别的出现频率来选择预定真子集。特别地,系统可以选择第一真子集,使得第一真子集包括训练数据频率分布的“头部”中的类别。例如,系统可以选择阈值数量的具有最高出现频率的类别作为真子集,或者可以按出现频率的降序选择类别,即从最高出现频率开始,直到总和所选择类的频率超过阈值。
在一些其他实施方式中,损失函数包括针对训练数据中的训练输入的预定真子集中的任一个给定训练网络输入鼓励由第一神经网络针对该给定训练输入生成的第一网络输出的项,以匹配从第二神经网络为给定训练输入生成的第二网络输出生成的伪标签。
在这些实施方式中,已经基于由第二神经网络为训练数据中的训练输入生成的第二网络输出中的分数之间的间隔来选择训练数据中训练输入的预定真子集。例如,系统可以确定(i)第二组类别中任一个类别的最高第一分数与(ii)第二组中任一个类别的第二高的第一分数之间的差是否大于阈值。如果差大于阈值,则系统确定训练输入可能“容易”,并将训练输入添加到真子集中。
如果训练输入不在子集中,则该项鼓励由第一神经网络为给定训练输入生成的第一网络输出以匹配独立于由第二神经网络为给定训练输入生成的任一个第二网络输出的分布。例如,该分布可以是从给定训练网络输入的真实值标签生成的标签平滑分布。
在上述实施方式中,损失函数中的项可以是交叉熵项,其测量由第一神经网络生成的第一网络输出与如上所述生成的目标输出——即,上述分布中的一个——之间的交叉熵损失。
可选地,损失函数可以包括一个或多个附加项,例如,正则化项或测量第一网络输出和真实值标签之间的误差的一个或多个项。
本说明书使用与系统和计算机程序组件相关的术语“配置”。对于配置为执行特定操作或动作的一个或多个计算机的系统而言,意味着系统已在其上安装了软件、固件、硬件或它们的组合,这些软件、固件、硬件或它们的组合在操作中导致系统执行操作或动作。一个或多个计算机程序被配置为执行特定操作或动作意味着该一个或多个程序包括指令,当由数据处理装置执行所述指令时,使该装置执行操作或动作。
本说明书中描述的主题和功能操作的实施例可以在包括在本说明书中公开的结构和它们的结构等效物的数字电子电路中、在有形体现的计算机软件或固件中、在计算机硬件中实现,或者在它们中的一个或多个的组合中实现。本说明书中描述的主题的实施例可以实现为一个或多个计算机程序,即,一个或多个编码在有形的非暂时性存储介质上的计算机程序指令模块,用于由数据处理装置执行或控制数据处理装置的操作。计算机存储介质可以是机器可读存储设备、机器可读存储基板、随机或串行存取存储器设备,或者它们中的一个或多个的组合。替代地或附加地,程序指令可以在人工生成的传播信号上编码,例如机器生成的电、光或电磁信号,所述传播信号被生成以编码信息以传输到合适的接收器装置以供数据处理装置执行。
术语“数据处理装置”指的是数据处理硬件并且涵盖用于处理数据的所有种类的装置、设备和机器,包括例如可编程处理器、计算机或多个处理器或计算机。该装置还可以是或进一步包括专用逻辑电路,例如FPGA(现场可编程门阵列)或ASIC(专用集成电路)。除了硬件之外,该装置可以可选地包括为计算机程序创建执行环境的代码,例如,构成处理器固件、协议栈、数据库管理系统、操作系统或它们中的一个或多个的组合的代码。
计算机程序,也可以称为或描述为程序、软件、软件应用、app、模块、软件模块、脚本或代码,可以用任何形式的编程语言编写,所述编程语言包括编译或解释性语言,或声明性或程序性语言;并且它可以以任何形式部署,包括作为独立程序或作为模块、组件、子程序或其他适用于计算环境的单元。程序可以但不必对应于文件系统中的文件。程序可以存储在包含其他程序或数据的文件的一部分中,例如,存储在标记语言文档中的一个或多个脚本、专用于所讨论程序的单个文件或多个协调文件中,例如,存储一个或多个模块、子程序或部分代码的文件。可以部署计算机程序以在一个计算机或位于一个站点或分布在多个站点并通过数据通信网络互连的多个计算机上执行。
在本说明书中,术语“数据库”广泛用于指代任何数据合集:数据不需要以任何特定方式结构化,或者根本不需要结构化,它可以存储在一个或多个地点中的存储设备上。因此,例如,索引数据库可以包括多个数据合集,每个数据合集可以以不同方式组织和访问。
类似地,在本说明书中,术语“引擎”广泛用于指代基于软件的系统、子系统或被编程为执行一个或多个特定功能的过程。通常,引擎将被实现为一个或多个软件模块或组件,安装在一个或多个位置的一个或多个计算机上。在某些情况下,一个或多个计算机将专用于特定引擎;在其他情况下,可以在同一个或多个计算机上安装和运行多个引擎。
本说明书中描述的过程和逻辑流程可以由执行一个或多个计算机程序的一个或多个可编程计算机来执行以通过对输入数据进行操作并生成输出来执行功能。过程和逻辑流程也可以由专用逻辑电路(例如FPGA或ASIC)或由专用逻辑电路和一个或多个编程计算机的组合来执行。
适合执行计算机程序的计算机可以基于通用或专用微处理器或两者,或任何其他类型的中央处理单元。通常,中央处理单元将从只读存储器或随机存取存储器或两者接收指令和数据。计算机的基本元件是用于执行或执行指令的中央处理单元以及用于存储指令和数据的一个或多个存储器设备。中央处理单元和存储器可以由专用逻辑电路补充或结合在专用逻辑电路中。通常,计算机还将包括或可操作地耦合以从一个或多个用于存储数据的大容量存储设备(例如,磁、磁光盘或光盘)接收数据或向其传递数据或两者。然而,计算机不需要有这样的设备。此外,计算机可以嵌入到另一个设备中,例如移动电话、个人数字助理(PDA)、移动音频或视频播放器、游戏控制台、全球定位系统(GPS)接收器或便携式存储设备,例如,通用串行总线(USB)闪存驱动器,仅举几例。
适用于存储计算机程序指令和数据的计算机可读介质包括所有形式的非易失性存储器、介质和存储设备,例如包括半导体存储设备,例如EPROM、EEPROM和闪存设备;磁盘,例如内部硬盘或可移动磁盘;磁光盘;和CD ROM和DVD-ROM磁盘。
为了提供与用户的交互,本说明书中描述的主题的实施例可以在计算机上实现,所述计算机具有用于将信息显示给用户的显示设备和用户可以通过它们向计算机提供输入的键盘和定点设备,所述显示设备例如是CRT(阴极射线管)或LCD(液晶显示器)监视器,所述定点设备例如是鼠标或轨迹球。也可以使用其他类型的设备来提供与用户的交互;例如,提供给用户的反馈可以是任何形式的感官反馈,例如视觉反馈、听觉反馈或触觉反馈;可以以任何形式接收来自用户的输入,包括声音、语音或触觉输入。此外,计算机可以通过向用户使用的设备发送文档和从其接收文档来与用户交互;例如,通过响应从web浏览器收到的请求而将网页发送到用户设备上的web浏览器。此外,计算机可以通过将文本消息或其他形式的消息发送到个人设备(例如,运行消息传递应用的智能手机)并继而接收来自用户的响应消息来与用户交互。
用于实现机器学习模型的数据处理装置还可以包括例如专用硬件加速器单元,以用于处理机器学习训练或生产的常见和计算密集型部分,即推断、工作负载。
机器学习模型可以使用机器学习框架来实现和部署,所述机器学习框架例如是TensorFlow框架、Microsoft Cognitive Toolkit框架、Apache Singa框架或Apache MXNet框架。
本说明书中描述的主题的实施例可以在计算系统中实现,该计算系统包括:后端组件,例如,作为数据服务器;或者包括中间件组件,例如,应用服务器;或者包括前端组件,例如,具有图形用户界面、web浏览器或用户可以通过其与本说明书中描述的主题的实施方式进行交互的app的客户端计算机;或者一个或多个这样的后端、中间件或前端组件的任一个组合。系统的组件可以通过任何形式或介质的数字数据通信互连,例如通信网络。通信网络的示例包括局域网(LAN)和广域网(WAN),例如因特网。
计算系统可以包括客户端和服务器。客户端和服务器通常彼此远离并且通常通过通信网络进行交互。客户端和服务器的关系是通过在相应的计算机上运行并且彼此具有客户端-服务器关系的计算机程序而产生的。在一些实施例中,服务器将数据(例如HTML页面)传输到用户设备,例如,用于向与作为客户端的设备交互的用户显示数据和从用户接收用户输入的目的。在用户设备处生成的数据,例如用户交互的结果,可以在服务器处从设备接收。
虽然本说明书包含许多具体的实施方式细节,但这些不应被解释为对任何发明的范围或可要求保护的范围的限制,而是对可能特定于特定发明的特定实施例的特征的描述。在本说明书中在单独实施例的上下文中描述的某些特征也可以在单个实施例中组合实现。相反,在单个实施例的上下文中描述的各种特征也可以在多个实施例中单独或以任何合适的子组合来实现。此外,尽管特征可以在上面描述为在某些组合中起作用,甚至最初被要求保护,但在某些情况下,来自要求保护的组合的一个或多个特征可以从组合中删除,并且要求保护的组合可以针对子组合或子组合的变化。
类似地,虽然操作在附图中被描绘并且在权利要求中以特定顺序叙述,但这不应被理解为要求这样的操作以所示的特定顺序或按顺序执行,或者所有所示的操作都被执行,以达到理想的效果。在某些情况下,多任务和并行处理可能是有利的。此外,上述实施例中各个系统模块和组件的分离不应被理解为在所有实施例中都需要这样的分离,而应该理解的是,所描述的程序组件和系统通常可以集成在单个软件产品中,或者打包成多个软件产品。
已经描述了本主题的特定实施例。其他实施例在以下权利要求的范围内。例如,权利要求中记载的动作可以以不同的顺序执行,但仍能达到期望的结果。作为一个示例,附图中描绘的过程不一定需要所示的特定顺序或顺序来实现期望的结果。在某些情况下,多任务和并行处理可能是有利的。

Claims (22)

1.一种由一个或多个计算机执行的方法,所述方法包括:
获得新网络输入;
使用第一神经网络处理所述新网络输入以生成第一网络输出,所述第一网络输出包括针对第一多个类别中的每个类别的相应第一分数;
根据所述第一网络输出确定所述第一网络输出是否可能是不准确的;
响应于确定所述第一网络输出可能是不准确的:
将所述新网络输入作为输入提供给第二神经网络,所述第二神经网络被配置为处理所述新网络输入以生成第二网络输出,所述第二网络输出包括第二多个类别中的每一个类别的相应第二分数,所述第二多个类别包括所述第一多个类别中的所有类别;和
基于所述第二网络输出来对所述新网络输入进行分类。
2.根据权利要求1所述的方法,其中,基于所述第二网络输出来对所述新网络输入进行分类包括:
将所述网络输入分类为属于具有最高第二分数的一个或多个类别。
3.根据权利要求1所述的方法,还包括:
响应于确定所述第一网络输出不可能是不准确的:
基于所述第一网络输出来对所述新网络输入进行分类,而不将所述新网络输入作为输入提供给所述第二神经网络。
4.根据权利要求1所述的方法,其中,所述一个或多个计算机是边缘计算设备,并且其中,所述第一神经网络被部署在所述边缘计算设备上。
5.根据权利要求4所述的方法,其中,所述第一神经网络被部署在所述边缘计算设备上的专用集成电路(ASIC)或现场可编程门阵列(FPGA)上。
6.根据权利要求4所述的方法,其中,所述第二神经网络被部署在远离所述边缘计算设备的一个或多个第二计算机上,其中,将所述新网络输入作为输入提供给所述第二神经网络包括通过数据通信网络将所述新网络输入从所述边缘计算设备提供给所述一个或多个第二计算机,并且其中,所述方法还包括通过所述数据通信网络获得所述第二网络输出。
7.根据权利要求1所述的方法,其中,根据所述第一网络输出确定所述第一网络输出是否可能是不准确的包括:
确定具有最高第一分数的类别是否在所述第一多个类别的预定真子集中;和
在具有最高第一分数的所述类别在所述多个类别的所述预定真子集中时,确定所述第一网络输出可能是不准确的。
8.根据权利要求7所述的方法,其中,所述多个类别的所述预定真子集已经基于用于训练所述第一神经网络的训练数据中的每个类别的出现频率来选择。
9.根据权利要求1所述的方法,其中,根据所述第一网络输出确定所述第一网络输出是否可能是不准确的包括:
确定(i)所述第一多个类别中的任一个类别的最高第一分数与(ii)所述第一多个类别中的任一个类别的第二高的第一分数之间的差是否满足阈值;和
在所述差满足所述阈值时,确定所述第一网络输出可能是不准确的。
10.根据权利要求1所述的方法,其中,所述第一网络输出还包括弃权类的第一分数,并且其中,根据所述第一网络输出确定所述第一网络输出是否可能是不准确的包括:
确定所述弃权类的第一分数是否高于所述第一多个类别中的任一个类别的第一分数;和
在所述弃权类的第一分数高于所述第一多个类别中的任一个类别的所述第一分数时,确定所述第一网络输出可能是不准确的。
11.根据权利要求1所述的方法,其中,所述第一神经网络已经使用由所述第二神经网络通过处理训练网络输入生成的伪标签来训练。
12.根据权利要求11所述的方法,其中,所述第一神经网络已经被训练以最小化包括第一项的损失函数,所述第一项针对具有真实值标签的任一个给定训练网络输入鼓励由所述第一神经网络针对所述给定训练输入生成的第一网络输出,以与从由所述第二神经网络针对所述给定训练输入生成的第二网络输出生成的伪标签相匹配,所述真实值标签将所述给定训练网络输入指配给所述第二多个类别的预定真子集中的任一个类别。
13.根据权利要求12所述的方法,其中,所述第一项针对具有真实值标签的任一个给定训练网络输入鼓励由所述第一神经网络针对所述给定训练输入生成的第一网络输出,以与独立于由所述第二神经网络针对所述给定训练输入生成的任一个第二网络输出的分布相匹配,所述真实值标签将所述给定训练网络输入指配给不在所述多个类别的所述预定真子集中的任一个类别。
14.根据权利要求13所述的方法,其中,所述分布是针对所述给定训练网络输入从所述真实值标签生成的标签平滑分布。
15.根据权利要求13所述的方法,其中,所述分布是在所述第二多个类别上的均匀分布。
16.根据权利要求13所述的方法,其中,所述分布是仅将非零分数指配给弃权类的独热分布。
17.根据权利要求12所述的方法,其中,所述第二多个类别的所述预定真子集已经基于用于训练所述第一神经网络的训练数据中的每个类别的出现频率来选择。
18.根据权利要求11所述的方法,其中,所述第一神经网络已被训练以最小化包括第一项的损失函数,所述第一项针对在所述训练数据中的训练输入的预定真子集的任一个给定训练网络输入鼓励由所述第一神经网络针对所述给定训练输入生成的第一网络输出,以与从所述第二神经网络针对所述给定训练输入生成的第二网络输出生成的伪标签相匹配。
19.根据权利要求18所述的方法,其中,所述训练数据中的训练输入的所述预定真子集已经基于由所述第二神经网络针对所述训练数据中的所述训练输入生成的第二网络输出中的分数之间的间隔来选择。
20.根据权利要求1-19中的任一项所述的方法,其中,所述第一神经网络具有比所述第二神经网络更少的参数。
21.一种系统,包括:
一个或多个计算机;和
一个或多个存储指令的存储设备,所述指令在由所述一个或多个计算机执行时,使所述一个或多个计算机执行根据权利要求1-20中的任一项所述的方法的相应操作。
22.一个或多个存储指令的计算机可读存储介质,所述指令在由一个或多个计算机执行时,使所述一个或多个计算机执行根据权利要求1-20中的任一项所述的方法的相应操作。
CN202210391735.6A 2021-04-14 2022-04-14 多阶段计算高效的神经网络推断 Pending CN114861873A (zh)

Applications Claiming Priority (2)

Application Number Priority Date Filing Date Title
US202163175042P 2021-04-14 2021-04-14
US63/175,042 2021-04-14

Publications (1)

Publication Number Publication Date
CN114861873A true CN114861873A (zh) 2022-08-05

Family

ID=82630538

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202210391735.6A Pending CN114861873A (zh) 2021-04-14 2022-04-14 多阶段计算高效的神经网络推断

Country Status (2)

Country Link
US (1) US20220335274A1 (zh)
CN (1) CN114861873A (zh)

Families Citing this family (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US11922314B1 (en) * 2018-11-30 2024-03-05 Ansys, Inc. Systems and methods for building dynamic reduced order physical models

Also Published As

Publication number Publication date
US20220335274A1 (en) 2022-10-20

Similar Documents

Publication Publication Date Title
US11934956B2 (en) Regularizing machine learning models
CN111602148B (zh) 正则化神经网络架构搜索
US11995528B2 (en) Learning observation representations by predicting the future in latent space
US11681924B2 (en) Training neural networks using a variational information bottleneck
US20210150355A1 (en) Training machine learning models using task selection policies to increase learning progress
US20220121906A1 (en) Task-aware neural network architecture search
US20190286984A1 (en) Neural architecture search by proxy
US20230049747A1 (en) Training machine learning models using teacher annealing
US20220215209A1 (en) Training machine learning models using unsupervised data augmentation
US20220230065A1 (en) Semi-supervised training of machine learning models using label guessing
US20220188636A1 (en) Meta pseudo-labels
CN114861873A (zh) 多阶段计算高效的神经网络推断
WO2023158881A1 (en) Computationally efficient distillation using generative neural networks
US20230017505A1 (en) Accounting for long-tail training data through logit adjustment
US20220108174A1 (en) Training neural networks using auxiliary task update decomposition
US20210383195A1 (en) Compatible neural networks
US20230206030A1 (en) Hyperparameter neural network ensembles
US20220019856A1 (en) Predicting neural network performance using neural network gaussian process
WO2021159099A9 (en) Searching for normalization-activation layer architectures
US20240169211A1 (en) Training neural networks through reinforcement learning using standardized absolute deviations
US20230145129A1 (en) Generating neural network outputs by enriching latent embeddings using self-attention and cross-attention operations
US20240152809A1 (en) Efficient machine learning model architecture selection
WO2022167660A1 (en) Generating differentiable order statistics using sorting networks
EP4315180A1 (en) Efficient hardware accelerator configuration exploration
CN114595743A (zh) 基于模仿学习的小样本目标分类方法及装置

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination