CN114424212A - 基于距离的学习置信度模型 - Google Patents

基于距离的学习置信度模型 Download PDF

Info

Publication number
CN114424212A
CN114424212A CN202080066367.7A CN202080066367A CN114424212A CN 114424212 A CN114424212 A CN 114424212A CN 202080066367 A CN202080066367 A CN 202080066367A CN 114424212 A CN114424212 A CN 114424212A
Authority
CN
China
Prior art keywords
training
query
training examples
distance
ground truth
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202080066367.7A
Other languages
English (en)
Inventor
塞尔坎·厄梅尔·阿勒克
邢晨
张子钊
托马斯·乔恩·普菲斯特
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Google LLC
Original Assignee
Google LLC
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Google LLC filed Critical Google LLC
Publication of CN114424212A publication Critical patent/CN114424212A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • G06F18/2148Generating training patterns; Bootstrap methods, e.g. bagging or boosting characterised by the process organisation or structure, e.g. boosting cascade
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • G06F18/241Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
    • G06F18/2413Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches based on distances to training or reference patterns
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • G06F18/243Classification techniques relating to the number of classes
    • G06F18/2431Multiple classes
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/09Supervised learning

Abstract

一种用于联合训练分类模型(210)和置信度模型(220)的方法(500)包括接收包括多个训练数据子集(112)的训练数据集(110)和选择训练示例的支持集(114S)和训练示例的查询集(114Q)。该方法还包括基于类距离度量和与为训练示例的查询集中的每个训练示例生成的查询编码(212Q)相关联的地面真实距离来更新分类模型的参数。对于被识别为被错误分类的每个训练示例,该方法还包括对新查询编码进行采样(224)并且基于新查询编码来更新置信度模型的参数。

Description

基于距离的学习置信度模型
技术领域
本公开涉及基于距离的学习置信度模型。
背景技术
机器学习模型接收输入并基于接收到的输入生成输出,例如预测输出。机器学习模型是根据数据进行训练的。然而,量化训练模型的置信度以用于预测(也称为置信度校准)是一项挑战。对于“校准良好”的模型,具有更高置信度的预测应该更有可能准确。然而,被错误地解释为模型置信度的在管线末端获得的预测概率(softmax输出)对模型的决策质量进行了很差的校准——即使分类不准确,置信度值也往往很大。
发明内容
本公开的一个方面提供了一种用于联合训练分类模型和置信度模型的方法。该方法包括在数据处理硬件处接收包括多个训练数据子集的训练数据集。每个训练数据子集与不同的相应类相关联,并且具有属于相应类的多个对应的训练示例。该方法还包括,从训练数据集中的两个或更多个训练数据子集中,通过数据处理硬件选择训练示例的支持集和训练示例的查询集。训练示例的支持集包括从所述两个或更多个训练数据子集的每一个中采样的K个训练示例,并且训练示例的查询集包括从所述两个或更多个训练数据子集的每一个采样的不包含在训练示例的支持集中的训练示例。对于与所述两个或更多个训练数据子集相关联的每个相应类,该方法还包括由数据处理硬件使用分类模型通过平均与属于该相应类的训练示例的支持集中的所述K个训练示例相关联的K个支持编码来确定质心值。对于训练示例的查询集中的每个训练示例,该方法还包括:通过数据处理硬件使用分类模型生成查询编码;通过数据处理硬件确定表示查询编码和为每个相应类确定的质心值之间的相应距离的类距离度量;通过数据处理硬件确定查询编码和与训练示例的查询集中的对应训练示例相关联的地面真实标签之间的地面真实距离;通过数据处理硬件基于类距离度量和地面真实距离更新分类模型的参数。对于被识别为错误分类的训练示例的查询集中的每个训练示例,该方法还包括:通过数据处理硬件使用置信度模型生成由分类模型为对应的错误分类训练示例生成的查询编码的标准偏差值;通过数据处理硬件使用标准偏差值和查询编码,对对应的错误分类训练示例的新查询编码进行采样;通过数据处理硬件基于新查询编码更新置信度模型的参数。
本公开的实施方式可以包括以下可选特征中的一个或多个。在一些实施方式中,地面真实标签包括基于距离的表示空间内的地面真实质心值。在一些示例中,基于类距离度量和地面真实距离更新分类模型的参数训练分类模型以最小化类内距离和最大化类间距离。
在一些实施方式中,训练置信度模型以最大化对于较大的地面真实距离的标准偏差值并且对接近于相应地面真实质心值的新查询编码进行采样。在一些示例中,被识别为被错误分类的训练示例的查询集中的任何训练示例包括训练示例的查询集中包括未能满足距离阈值的地面真实距离的任何训练示例。在一些实施方式中,置信度模型没有在所述训练示例的查询集中包括满足距离阈值的地面真实距离的训练示例上进行训练。在一些示例中,更新置信度模型的参数包括更新置信度模型的参数以鼓励置信度模型输出与较大的类距离度量相关联的查询编码的较大的标准偏差值。在一些实施方式中,训练示例包括图像数据。分类模型可以包括深度神经网络(DNN)。在一些示例中,置信度模型包括深度神经网络(DNN)。
本公开的另一方面提供了一种用于联合训练分类模型和置信度模型的系统。该系统包括数据处理硬件和与数据处理硬件通信的存储器硬件。存储器硬件存储指令,所述指令在由数据处理硬件执行时使数据处理硬件执行操作,所述操作包括接收包括多个训练数据子集的训练数据集。每个训练数据子集与不同的相应类相关联,并且具有属于该相应类的多个对应的训练示例。所述操作还包括,从训练数据集中的两个或更多个训练数据子集中,选择训练示例的支持集和训练示例的查询集。训练示例的支持集包括从所述两个或更多个训练数据子集的每一个中采样的K个训练示例,训练示例的查询集包括从所述两个或更多个训练数据子集的每一个采样的不包含在所述训练示例的支持集中的训练示例。对于与所述两个或更多个训练数据子集相关联的每个相应类,所述操作还包括使用分类模型通过平均与属于所述相应类的训练示例的支持集中的K个训练示例相关联的K个支持编码确定质心值。对于训练示例查询集中的每个训练示例,所述操作还包括使用分类模型生成查询编码,确定表示查询编码和为每个相应类确定的质心值之间的相应距离的类距离度量;确定查询编码和与训练示例的查询集中的对应训练示例相关联的地面真实标签之间的地面真实距离;并且基于类距离度量和地面真实距离更新分类模型的参数。对于识别为错误分类的训练示例的查询集中的每个训练示例,所述操作还包括使用置信度模型生成由分类模型为对应的错误分类训练示例生成的查询编码的标准偏差值;使用标准偏差值和查询编码对相应的错误分类训练示例的新查询编码进行采样;以及基于新查询编码更新置信度模型的参数。
本公开的实施方式可以包括以下可选特征中的一个或多个。在一些实施方式中,地面真实标签包括基于距离的表示空间内的地面真实质心值。在一些示例中,基于类距离测量和地面真实距离更新分类模型的参数训练分类模型以最小化类内距离和最大化类间距离。
在一些实施方式中,置信度模型被训练以最大化对于较大的地面真实距离的标准偏差值并且对接近于地面真实质心值的新查询编码进行采样。在一些示例中,被识别为被错误分类的训练示例的查询集中的任何训练示例包括训练示例的查询集中包括未能满足距离阈值的地面真实距离的任何训练示例。在一些实施方式中,置信度模型没有在训练示例的查询集中包括满足距离阈值的地面真实距离的训练示例上进行训练。在一些示例中,更新置信度模型的参数包括更新置信度模型的参数以鼓励置信度模型输出与较大的类距离度量相关联的查询编码的较大的标准偏差值。在一些实施方式中,训练示例包括图像数据。分类模型可以包括深度神经网络(DNN)。在一些示例中,置信度模型包括深度神经网络(DNN)。
本公开的一个或多个实施方式的细节在附图和以下描述中阐述。从描述和附图以及从权利要求中,其他方面、特征和优点将是显而易见的。
附图说明
图1是提供训练框架的示例系统,该训练框架实现基于距离的从错误中学习(DBLE)来训练分类模型和校准模型。
图2是用于训练图1的分类模型和校准模型的DBLE架构的示例。
图3A和图3B示出了训练示例的基于距离的表示空间的图。
图4是使用DBLE训练分类模型和校准模型的示例算法。
图5是用于与置信度模型并行地训练分类模型的方法的示例操作布置的流程图。
图6是可用于实现本文描述的系统和方法的示例计算设备的示意图。
各图中相同的参考符号表示相同的元件。
具体实施方式
用于训练深度神经网络(DNN)的传统技术通常会导致校准不佳的DNN。由于DNN部署在许多重要的决策场景中,因此校准不佳可能会导致代价可能非常高昂的错误决策。为了防止对DNN做出的错误决策采取行动,DNN最好输出对DNN输出的决策的置信度估计。为此,系统可以避免对DNN输出的具有低置信度的决策采取行动,可以避免采取行动和/或可以咨询人类专家,使得如果这些低置信度的决策被依赖和采取行动则可以避免不利的后果。不幸的是,准确的置信度估计对于DNN来说是一个挑战,尤其是对于校准不佳的DNN。
本文的实施方式针对实施基于距离的从错误中学习(DBLE)以产生良好校准的神经网络的训练框架。在DBLE中,系统并行训练分类模型(也称为“预测模型”)和置信度模型。使用DBLE训练分类模型学习基于距离的表示空间,由此,基于距离的表示空间定义了测试样本到文本样本的地面真实类中心的L2距离,用于校准分类模型在给定的测试样本上的性能。因此,与普通训练(优化最大似然度的传统训练)不同,使用DBLE训练分类模型具有用作校准其决策质量的黄金置信度测量的特征。但是,由于计算测试样本的距离需要地面真实类中心的标签,因此无法在推理时直接获得。因此,使用DBLE训练置信度模型被配置为在推理期间将此距离估计为置信度分数。为了训练置信度模型,DBLE在分类模型的训练期间使用了错误分类的训练样本(从训练错误中学习)。
参考图1,在一些实施方式中,系统100包括计算环境130,计算环境130包括资源102,诸如数据处理硬件104(例如,服务器或CPU和/或存储指令的远程存储器硬件106,所述指令当在数据处理硬件104上执行时使数据处理硬件104执行操作。并行训练分类模型210和置信度模型220的基于距离的从错误中学习(DBLE)架构200可以驻留在资源102上。在所示示例中,DBLE架构200在训练数据集110上训练分类模型210,该训练数据集110包括多个训练数据子集112,112a-n,每个训练数据子集包括与不同的相应类相关联的多个训练示例114。每个训练示例114包括对应的地面真实标签,其指示训练示例114所属的相应类。这里,地面真实标签可以包括基于距离的表示空间中的地面真实的质心值212G。在一些示例中,训练示例对应于图像或图像数据。
显而易见,DBLE架构200被配置为通过分类模型210学习基于距离的表示空间,并利用空间中的距离来产生良好校准的分类。DBLE架构200依赖于测试样本在表示空间中的位置和测试样本到同一类中的训练样本的距离包含用于指导置信度估计的有用信息的相关性。即,DBLE架构被配置为适应用于训练和推理的原型学习,以通过分类来学习基于距离的表示空间,使得测试样本到地面真实类中心的距离能够校准分类模型210的性能。如本文所用,原型学习是指训练和预测两者仅取决于样本到其在表示空间中的对应类中心(也称为“原型”)的距离,从而优化分类模型210的训练以最小化类内距离并最大化类间距离,使得相关样本在表示空间中聚集在一起。由于在推理期间地面真实类中心的地面真实标签是未知的,因此DBLE架构200与分类模型210联合地训练单独的置信度模型220,从而允许估计测试样本与其地面真实类中心的距离。具体而言,实施方式针对仅在训练期间被分类模型210错误分类的训练样本上训练置信度模型。模型210、220可以各自包括深度神经网络(DNN)。
图2提供了用于并行联合训练分类模型210和置信度模型220的示例DBLE架构200,使得被识别为被分类模型210错误分类的训练示例114用于训练置信度模型220以使置信度模型220能够在不知道地面真实中心的推理过程中,在基于距离的表示空间中估计测试样本到其地面真实中心的距离。与基于最小批梯度下降的变体的分类的普通训练技术相比,DBLE架构使用回合(episodic)训练来训练分类模型210,其中DBLE通过从训练数据集110中随机采样训练示例114以选择以下两个训练示例集合来创建回合:(1)训练示例的支持集114S;(2)训练示例的查询集114Q。更具体地,DBLE通过首先从多个数据子集112中随机采样/选择N个训练数据子集112来创建每个回合。此后,DBLE通过从N个训练数据子集112中的每一个中采样K个训练示例114Sa-k来选择训练样本的支持集114S,并且通过从不包括在训练样本的支持集114S中的N个训练数据子集112的每一个中采样训练示例114来选择训练示例的查询集114Q。在一些示例中,N个训练数据子集包括训练数据集110中的两个或更多个训练数据子集112。虽然N个训练数据子集可以包括全部多个训练数据子集,但DBLE不需要使用整个训练数据子集,因为当不同类的数量非常大时,将来自批中的训练示例的支持集的训练示例容纳到处理器存储器可能具有挑战性。
对于与N个训练数据子集112相关联的每个相应类,DBLE使用分类模型210通过对与属于相应类的训练示例的支持集114S中的K个训练示例114Sa-k相关联的K个支持编码212S、212Sa-k进行平均来确定质心值214。也就是说,对于给定的类,分类模型210接收训练示例的支持集114S中的K个训练示例114中的每一个作为输入,并且为支持集中每个训练示例生成对应的支持编码212S作为输出。对于给定类,DBLE对K个支持编码212S进行平均以计算/确定相应给定类的相应质心值214。因此,DBLE对于剩余的N个训练子集212重复,从而计算N个质心值214,使得每个质心值214表示N个类中的相应一个。
分类模型210还为训练示例的查询集114Q中的每个训练示例生成相应的查询编码212Q,hi,并且DBLE确定表示查询编码212Q和为每个相应类确定的质心值214之间的相应距离的类距离度量。DBLE还确定查询编码212Q和与训练示例的查询集114Q中的对应训练示例相关联的地面真实质心值212G之间的地面真实距离,并基于类距离度量和地面真实距离来更新分类模型210的参数。具体而言,DBLE使用接收查询编码212Q和为N个相应类中的每一个确定的质心值214,214a-n的用于分类的原型损失215以确定/计算相应类距离度量,并且还接收地面真实质心值212G以确定/计算查询编码212Q和地面真实质心值212G之间的地面真实距离。因此,分类模型210是由可训练参数θ参数化的函数,并且使用由以下等式表示的与给定训练示例的支持集114S的训练示例的查询集114Q中的每个训练示例的地面真实质心值212G的负对数似然相关联的损失:
Figure BDA0003559126270000081
其中Se是训练示例的支持集114S,Qe是训练示例的查询集114Q,yi是地面真实质心值212G,并且xi是输入到分类模型的训练示例的查询集114Q,θ表示分类模型210的可训练参数。用于分类的原型损失215被配置为使用以下等式基于N个类的每个相应类的相应类距离度量来计算训练示例的查询集114Q中的每个训练示例xi的预测标签分布:
Figure BDA0003559126270000082
其中hi是表示基于距离的表示空间中的对应训练示例xi的对应的查询编码214Q。因此,DBLE通过最小化由等式1利用预测标签分布p(yi|xi,Se;θ)计算的损失
Figure BDA0003559126270000083
来更新分类模型210的可训练参数θ,所述预测标签分布p(yi|xi,Se;θ)使用等式2为训练示例的查询集114Q中的每个训练示例xi计算。因此,在查询编码212Q和为N个类确定的质心值214的表示空间中,分类模型210的训练使类间距离最大化并且使类内距离最小化。结果,属于同一类的训练示例聚集在一起,表示不同类的集群在表示空间中被推开。
在使用分类模型210对训练示例的查询集114Q中的每个训练示例进行分类时,DBLE 200识别被分类模型210错误分类的任何训练示例。DBLE 200可以在由分类模型210预测的分类与训练示例的对应地面真实标签212G不匹配时将训练示例识别为错误分类。在一些示例中,当查询编码212Q和与训练示例的查询集114Q中的对应训练示例相关联的地面真实质心值212G之间的相应地面真实距离未能满足表示空间中的距离阈值时,DBLE 200将训练示例识别为错误分类。否则,DBLE 200可以识别训练示例的查询集114Q中的任何训练示例,所述任何训练示例在查询编码212Q和与训练示例的查询集114Q中的对应训练示例相关联的地面真实质心值212G之间具有满足(例如,小于或等于)由分类模型210正确分类的距离阈值的相应的地面真实距离。
在一些实施方式中,置信度模型220对训练示例的查询集114Q中的训练示例进行训练,所述训练示例被识别为被分类模型210错误分类。通常,正确分类的训练示例构成在分类模型的训练期间遇到的绝大多数训练示例。基于这个概念,使用所有训练示例114Q将导致主导置信度模型220的训练的与训练示例的查询集114Q中正确分类的训练示例相关联的小/短类距离度量,从而使置信度模型220更难以捕捉与构成所有训练示例114Q的少数的错误分类的训练示例相关联的较大的类距离度量。
在图2的示例DBLE架构200中,置信度模型220周围的虚线框以及与置信度模型220相关联的采样操作225和用于校准的原型损失250表示仅使用被识别为被错误分类的训练示例的查询集114Q中每个训练示例与分类模型210并行地训练置信度模型220。因此,对于被识别为被错误分类的训练示例的查询集114Q中的每个训练示例,在数据处理硬件104上执行的DBLE:使用置信度模型220为由分类模型210为对应的错误分类训练示例生成的查询编码212Q生成标准偏差值222,σ;使用标准偏差值222和查询编码212Q对对应的错误分类训练示例的新查询编码224进行采样,并基于新查询编码224更新置信度模型220的参数
Figure BDA0003559126270000091
置信度模型220被训练以输出与较大类距离度量相关联的查询编码212Q的较大的标准偏差值222,σ。为了对新查询编码224,zs,进行采样,置信度模型220使用从由对应查询编码212Q,hs,和对应标准偏差值222,σs,参数化的各向同性高斯分布采样的采样操作225。用于校准的原型损失250被配置为使用以下等式使用对每个错误分类的训练示例xs进行采样的新查询编码224,zs,的预测标签分布来计算原型损失:
Figure BDA0003559126270000101
因此,DBLE更新置信度模型220的可训练参数
Figure BDA0003559126270000102
以鼓励置信度模型220为与较大类距离度量相关联的查询编码212Q输出较大的标准偏差值222,σ。值得注意的是,通过在表示空间中为每个错误分类的训练示例固定查询编码212Q,最大化等式3会迫使新查询编码224尽可能接近相应的地面真实质心值212G。由于错误分类的训练示例包括查询编码212Q更远离地面真实质心值212G,因此鼓励置信度模型220输出较大的对应标准偏差值222会迫使新查询编码224接近地面真实质心值212G。图4提供了表示使用图2中描述的DBLE 200训练分类和校准模型210、220的示例算法400。
图3A和图3B示出了表示在表示空间中来自训练示例的查询集114Q的训练示例的点的示例图300a、300b。曲线图300a、300b中的每一个中的虚垂直线表示决策边界,其中左侧和右侧的训练示例属于不同的相应类。此外,虚圆圈表示对应查询编码212Q,ha-hc,的标准偏差值222,σ,其中ha、hb与对应于错误分类训练示例114Q的错误分类查询编码222Q相关联,并且hc与对应于正确分类的训练示例114Q的正确分类的查询编码222Q相关联。图3A的曲线图300a示出了在更新置信度模型220的可训练参数
Figure BDA0003559126270000103
之前错误分类的查询编码ha、hb和正确分类的查询编码hc两者的短标准偏差值222。在更新置信度模型220的可训练参数
Figure BDA0003559126270000104
之后,图3B的曲线图300b示出了错误分类的查询编码ha、hb的较大标准偏差值222,这是由于校准的原型损失将从错误分类的训练示例中采样的新查询编码za、zb尽可能靠近与正确类相关联的地面真实质心值212G移动。
返回参考图2,在推理期间,在数据处理硬件104上执行的DBLE200通过使用以下等式对所有对应训练示例的表示212S进行平均来为训练集中的每个类c计算类中心214:
Figure BDA0003559126270000111
其中
Figure BDA0003559126270000112
是属于类k的所有训练示例的集合。然后,给定测试样本,xt,对应的查询编码212Q到每个类中心214的相应类距离度量。xt的标签的预测基于类距离度量,使得xt被分配给具有表示空间中最近的中心的类。因此,如果查询编码212Q离它的地面真实类中心214G太远,则它很可能被错误分类。由于在推理时测试样本xt的地面真实质心值212G是未知的,即没有标签可用,因此DBLE采用经过训练的置信度模型220来估计相应的类距离度量,以帮助分类模型210预测标签。即,分类模型210使用对应的查询编码ht 212Q为每个测试样本xt预测标签。然后置信度模型220输出查询编码ht的标准偏差值σt 222,并且采样操作225对新查询编码224进行采样。DBLE然后使用以下等式将预测标签分布平均作为置信度估计:
Figure BDA0003559126270000113
其中U是新查询编码zt 224的总数,并且
Figure BDA0003559126270000114
用作校准分类模型210的预测y′t的置信度分数。因此,DBLE为更远离地面真实类中心(可能被错误分类)的测试示例的表示采样增加了更多随机性,因为置信度模型的估计变化很大。
图5是用于与置信度模型220并行地训练分类模型210的方法500的示例操作布置的流程图。方法500可以基于存储在图1的存储器硬件106上的指令在图1的数据处理硬件104上执行。在操作502处,方法500包括在数据处理硬件104处接收包括多个训练数据子集112的训练数据集110。每个训练数据子集112与不同的相应类相关联并且具有属于相应类的多个对应的训练示例114。
在操作504处,对于训练数据集110中的两个或更多个训练数据子集112,方法500包括由数据处理硬件104选择训练示例的支持集114S和训练示例的查询集114Q。训练示例的支持集114S包括从两个或更多个训练数据子集112中的每一个中采样的K个训练示例114。训练示例的查询集114Q包括从不包括在训练示例的支持集114S中的两个或更多个训练数据子集112中的每一个中采样的训练示例114。
在操作506处,方法500包括由数据处理硬件104使用分类模型210,通过对与属于相应类的训练示例的支持集114S中的K个训练示例114相关联的K个支持编码215进行平均来确定质心值214。在操作508处,对于训练示例的查询集114Q中的每个训练示例,方法500包括通过数据处理硬件104使用分类模型210生成查询编码212Q;通过数据处理硬件104确定表示查询编码212Q和为每个相应类确定的质心值214之间的相应距离的类距离度量;通过数据处理硬件104确定查询编码212Q和与训练示例的查询集114Q中的对应训练示例114相关联的地面真实标签214G之间的地面真实距离;以及通过数据处理硬件104基于类距离度量和地面真实距离更新分类模型210的参数。
在操作510处,对于被识别为被错误分类的训练示例的查询集114Q中的每个训练示例114,方法500包括由数据处理硬件104使用置信度模型220生成由分类模型210为对应的错误分类训练示例生成的查询编码212Q的标准偏差值222;由数据处理硬件104使用标准偏差值222和查询编码212Q对对应的错误分类训练示例的新查询编码224进行采样;以及由数据处理硬件104基于新查询编码224更新置信度模型220的参数。
软件应用(即,软件资源)可以指使计算设备执行任务的计算机软件。在一些示例中,软件应用可以被称为“应用”、“app”或“程序”。示例应用包括但不限于系统诊断应用、系统管理应用、系统维护应用、文字处理应用、电子表格应用、消息传递应用、媒体流应用、社交网络应用和游戏应用。
非暂时性存储器可以是用于在临时或永久的基础上存储程序(例如,指令序列)或数据(例如,程序状态信息)以供计算设备使用的物理设备。非瞬态存储器可以是易失性和/或非易失性可寻址半导体存储器。非易失性存储器的示例包括但不限于闪存和只读存储器(ROM)/可编程只读存储器(PROM)/可擦除可编程只读存储器(EPROM)/电可擦除可编程只读存储器存储器(EEPROM)(例如,通常用于固件,例如引导程序)。易失性存储器的示例包括但不限于随机存取存储器(RAM)、动态随机存取存储器(DRAM)、静态随机存取存储器(SRAM)、相变存储器(PCM)以及磁盘或磁带。
图6是可用于实现本文档中描述的系统和方法的示例计算设备600的示意图。计算设备600旨在表示各种形式的数字计算机,例如膝上型电脑、台式机、工作站、个人数字助理、服务器、刀片式服务器、大型机和其他适当的计算机。这里所示的组件、它们的连接和关系以及它们的功能仅是示例性的,并不意味着限制本文档中描述和/或要求保护的发明的实现。
计算设备600包括处理器610、存储器620、存储设备630、连接到存储器620和高速扩展端口650的高速接口/控制器640、以及连接到低速总线670和存储设备630的低速接口/控制器660。组件610、620、630、640、650和660中的每一个使用各种总线互连,并且可以安装在公共主板上或以其他适当的方式安装。处理器610可以处理用于在计算设备600内执行的指令,包括存储在存储器620或存储设备630上的指令以在诸如耦合到高速接口640的显示器680的外部输入/输出设备上显示用于图形用户界面(GUI)的图形信息680。在其他实现中,可以适当使用多个处理器和/或多个总线以及多个存储器和多种类型的存储器。此外,可以连接多个计算设备600,每个设备提供部分必要操作(例如,作为服务器组、刀片服务器组或多处理器系统)。
存储器620在计算设备600内非暂时性地存储信息。存储器620可以是计算机可读介质、易失性存储器单元或非易失性存储器单元。非瞬态存储器620可以是用于在临时或永久基础上存储程序(例如,指令序列)或数据(例如,程序状态信息)以供计算设备600使用的物理设备。非易失性存储器的示例包括但不限于闪存和只读存储器(ROM)/可编程只读存储器(PROM)/可擦可编程只读存储器(EPROM)/电可擦可编程只读存储器(EEPROM)(例如,通常用于固件,例如引导程序)。易失性存储器的示例包括但不限于随机存取存储器(RAM)、动态随机存取存储器(DRAM)、静态随机存取存储器(SRAM)、相变存储器(PCM)以及磁盘或磁带。
存储设备630能够为计算设备600提供大容量存储。在一些实施方式中,存储设备630是计算机可读介质。在各种不同的实施方式中,存储设备630可以是软盘设备、硬盘设备、光盘设备或磁带设备、闪存或其他类似的固态存储设备、或设备阵列,包括在存储区域网络或其他配置中的设备。在另外的实施方式中,计算机程序产品有形地体现在信息载体中。计算机程序产品包含在执行时执行一种或多种方法的指令,例如上述那些。信息载体是计算机或机器可读介质,例如存储器620、存储设备630或处理器610上的存储器。
高速控制器640管理计算设备600的带宽密集型操作,而低速控制器660管理较低带宽密集型操作。这种职责分配只是示例性的。在一些实施方式中,高速控制器640耦合到存储器620、显示器680(例如,通过图形处理器或加速器),并且耦合到可以接受各种扩展卡(未示出)的高速扩展端口650。在一些实施方式中,低速控制器660耦合到存储设备630和低速扩展端口690。可以包括各种通信端口(例如,USB、蓝牙、以太网、无线以太网)的低速扩展端口690可以例如通过网络适配器耦合到一个或多个输入/输出设备,例如键盘、定点设备、扫描仪或网络设备,例如交换机或路由器。
如图所示,计算设备600可以以多种不同的形式实现。例如,它可以实现为标准服务器600a或在一组这样的服务器600a中多次实现,作为膝上型计算机600b,或作为机架服务器系统600c的一部分。
本文所述的系统和技术的各种实施方式可以在数字电子和/或光学电路、集成电路、专门设计的ASIC(专用集成电路)、计算机硬件、固件、软件和/或它们的组合中实现。这些不同的实现可以包括在一个或多个计算机程序中的实现,这些计算机程序在包括至少一个可编程处理器的可编程系统上是可执行和/或可解释的,该可编程处理器可以是专用或通用的,耦合以从存储系统、至少一个输入设备和至少一个输出设备接收数据和指令以及将指令传输到存储系统、至少一个输入设备和至少一个输出设备。
这些计算机程序(也称为程序、软件、软件应用或代码)包括用于可编程处理器的机器指令,并且可以以高级过程和/或面向对象的编程语言和/或以汇编/机器语言实现。如本文所用,术语“机器可读介质”和“计算机可读介质”是指任何计算机程序产品、非暂时性计算机可读介质、装置和/或设备(例如,磁盘、光盘、存储器、可编程逻辑设备(PLD)),其用于向可编程处理器提供机器指令和/或数据,包括接收机器指令作为机器可读信号的机器可读介质。术语“机器可读信号”是指用于向可编程处理器提供机器指令和/或数据的任何信号。
本说明书中描述的过程和逻辑流程可以由一个或多个可编程处理器(也称为数据处理硬件)执行,执行一个或多个计算机程序以通过对输入数据进行操作并生成输出来执行功能。过程和逻辑流程也可以由专用逻辑电路执行,例如FPGA(现场可编程门阵列)或ASIC(专用集成电路)。适合于执行计算机程序的处理器包括,例如,通用和专用微处理器,以及任何类型的数字计算机的任何一个或多个处理器。通常,处理器将从只读存储器或随机存取存储器或两者接收指令和数据。计算机的基本元件是用于执行指令的处理器和一个或多个用于存储指令和数据的存储设备。通常,计算机还将包括或可操作地耦合以从一个或多个用于存储数据的大容量存储设备(例如,磁、磁光盘或光盘)接收数据或向其传输数据或两者。然而,计算机不需要有这样的设备。适用于存储计算机程序指令和数据的计算机可读介质包括所有形式的非易失性存储器、介质和存储器设备,例如包括半导体存储器设备,例如EPROM、EEPROM和闪存设备;磁盘,例如内部硬盘或可移动磁盘;磁光盘;和CD ROM和DVD-ROM磁盘。处理器和存储器可以由专用逻辑电路补充或结合在专用逻辑电路中。
为了提供与用户的交互,本公开的一个或多个方面可以在具有显示设备或用于向用户显示信息的触摸屏以及可选的用户可以通过它们向计算机提供输入的键盘和指示设备的计算机上实现,所述显示设备例如是例如,CRT(阴极射线管)、LCD(液晶显示器)监视器,所述指示设备例如是鼠标或轨迹球。也可以使用其他类型的设备来提供与用户的交互;例如,提供给用户的反馈可以是任何形式的感官反馈,例如视觉反馈、听觉反馈或触觉反馈;可以以任何形式接收来自用户的输入,包括声音、语音或触觉输入。此外,计算机可以通过向用户使用的设备发送文档和从其接收文档来与用户交互;例如,通过响应于从web浏览器接收到的请求,将网页发送到用户客户端设备上的web浏览器。
已经描述了许多实施方式。然而,应当理解,在不背离本公开的精神和范围的情况下可以进行各种修改。因此,其他实施方式在所附权利要求的范围内。

Claims (20)

1.一种用于联合训练分类模型(210)和置信度模型(220)的方法(500),所述方法(500)包括:
在数据处理硬件(104)处接收包括多个训练数据子集(112)的训练数据集(110),每个训练数据子集(112)与不同的相应类相关联并且具有属于所述相应类的多个对应的训练示例(114);
从所述训练数据集(110)中的两个或更多个训练数据子集(112):
通过所述数据处理硬件(104)选择训练示例的支持集(114S),所述训练示例的支持集(114S)包括从所述两个或更多个训练数据子集(112)中的每一个中采样的K个训练示例(114);和
通过所述数据处理硬件(104)选择训练示例的查询集(114Q),所述训练示例的查询集(114Q)包括从所述两个或更多个训练数据子集(112)中的每一个中采样的不包含在所述训练示例的支持集(114S)中的训练示例(114);
对于与所述两个或更多个训练数据子集(112)相关联的每个相应类,由所述数据处理硬件(104)使用所述分类模型(210)通过平均与属于所述相应类的所述训练示例的支持集(114S)中的所述K个训练示例(114)相关联的K个支持编码(212S)来确定质心值(214);
对于所述训练示例的查询集(114Q)中的每个训练示例:
通过所述数据处理硬件(104)使用所述分类模型(210)生成查询编码(212Q);
通过所述数据处理硬件确定表示所述查询编码(212Q)和为每个相应类确定的所述质心值(214)之间的相应距离的类距离度量;
通过所述数据处理硬件确定所述查询编码(212Q)和与所述训练示例的查询集(114Q)中的该对应的训练示例相关联的地面真实标签(214G)之间的地面真实距离;和
通过所述数据处理硬件基于所述类距离度量和所述地面真实距离来更新所述分类模型(210)的参数;和
对于被识别为错误分类的所述训练示例的查询集(114Q)中的每个训练示例:
通过所述数据处理硬件(104)使用所述置信度模型(220)生成由所述分类模型(210)为该对应的错误分类的训练示例生成的查询编码(212Q)的标准偏差值(222);
通过所述数据处理硬件(104)使用所述标准偏差值(222)和所述查询编码(212Q)对所述对应的错误分类的训练示例的新查询编码(224)进行采样;和
通过所述数据处理硬件(104)基于所述新查询编码(224)来更新所述置信度模型(220)的参数。
2.根据权利要求1所述的方法(500),其中,所述地面真实标签(214G)包括基于距离的表示空间内的地面真实质心值。
3.根据权利要求1或权利要求2所述的方法(500),其中,基于所述类距离度量和所述地面真实距离来更新所述分类模型(210)的所述参数训练所述分类模型(210)以最小化类内距离并且最大化类间距离。
4.根据权利要求1-3中的任一项所述的方法(500),其中,所述置信度模型被训练以最大化较大的地面真实距离的标准偏差值,并且对接近于基于距离的表示空间内的相应地面真实质心值的新查询编码(224)进行采样。
5.根据权利要求1-4中的任一项所述的方法(500),其中,被识别为被错误分类的所述训练示例的查询集(114Q)中的任何训练示例包括所述训练示例的查询集(114Q)中包含未能满足距离阈值的地面真实距离的任何训练示例。
6.根据权利要求1-5中的任一项所述的方法(500),其中,所述置信度模型(220)不在所述训练示例的查询集(114Q)中包含满足距离阈值的地面真实距离的训练示例上训练。
7.根据权利要求1-6中的任一项所述的方法(500),其中,更新置信度模型的参数包括更新所述置信度模型(220)的所述参数以鼓励所述置信度模型(220)输出与较大的类距离度量相关联的查询编码(212Q)的较大的标准偏差值(222)。
8.根据权利要求1-7中的任一项所述的方法(500),其中,所述训练示例(114)包括图像数据。
9.根据权利要求1-8中的任一项所述的方法(500),其中,所述分类模型(210)包括深度神经网络(DNN)。
10.根据权利要求1-9中的任一项所述的方法(500),其中,所述置信度模型(220)包括深度神经网络(DNN)。
11.一种用于联合训练分类模型(210)和置信度模型(220)的系统(100),所述系统(100)包括:
数据处理硬件(104);和
与所述数据处理硬件(104)通信的存储器硬件(106),所述存储器硬件(106)存储指令,所述指令当在所述数据处理硬件(104)上执行时,使所述数据处理硬件(104)执行操作,所述操作包括:
接收包括多个训练数据子集(112)的训练数据集(110),每个训练数据子集(112)与不同的相应类相关联并且具有属于所述相应类的多个对应的训练示例(114);
从所述训练数据集(110)中的两个或更多个训练数据子集(112):
选择训练示例的支持集(114S),所述训练示例的支持集(114S)包括从所述两个或更多个训练数据子集(112)中的每一个中采样的K个训练示例(114);和
选择训练示例的查询集(114Q),所述训练示例的查询集(114Q)包括从所述两个或更多个训练数据子集(112)中的每一个中采样的不包含在所述训练示例的支持集(114S)中的训练示例(114);
对于与所述两个或更多个训练数据子集(112)相关联的每个相应类,使用所述分类模型(210)通过平均与属于所述相应类的所述训练示例的支持集(114S)中的所述K个训练示例(114)相关联的K个支持编码(212S)来确定质心值(214);
对于所述训练示例的查询集(114Q)中的每个训练示例:
使用所述分类模型(210)生成查询编码(212Q);
确定表示所述查询编码(212Q)和为每个相应类确定的所述质心值(214)之间的相应距离的类距离度量;
确定所述查询编码(212Q)和与所述训练示例的查询集(114Q)中的该对应的训练示例相关联的地面真实标签(214G)之间的地面真实距离;和
基于所述类距离度量和所述地面真实距离来更新所述分类模型(210)的参数;和
对于被识别为错误分类的所述训练示例的查询集(114Q)中的每个训练示例:
使用所述置信度模型(220)生成由所述分类模型(210)为该对应的错误分类的训练示例生成的查询编码(212Q)的标准偏差值(222);
使用所述标准偏差值(222)和所述查询编码(212Q)对所述对应的错误分类的训练示例的新查询编码(224)进行采样;和
基于所述新查询编码(224)来更新所述置信度模型(220)的参数。
12.根据权利要求11所述的系统(100),其中,所述地面真实标签(214G)包括基于距离的表示空间内的地面真实质心值。
13.根据权利要求11权利要求12所述的系统(100),其中,基于所述类距离度量和所述地面真实距离来更新所述分类模型(210)的所述参数训练所述分类模型(210)以最小化类内距离并且最大化类间距离。
14.根据权利要求11-13中的任一项所述的系统(100),其中,所述置信度模型被训练以最大化较大的地面真实距离的标准偏差值,并且对接近于基于距离的表示空间内的相应地面真实质心值的新查询编码(224)进行采样。
15.根据权利要求11-14中的任一项所述的系统(100),其中,被识别为被错误分类的所述训练示例的查询集(114Q)中的任何训练示例包括所述训练示例的查询集(114Q)中包含未能满足距离阈值的地面真实距离的任何训练示例。
16.根据权利要求11-15中的任一项所述的系统(100),其中,所述置信度模型(220)不在所述训练示例的查询集(114Q)中包含满足距离阈值的地面真实距离的训练示例上训练。
17.根据权利要求11-16中的任一项所述的系统(100),其中,更新置信度模型的参数包括更新所述置信度模型(220)的所述参数以鼓励所述置信度模型(220)输出与较大的类距离度量相关联的查询编码(212Q)的较大的标准偏差值(222)。
18.根据权利要求11-17中的任一项所述的系统(100),其中,所述训练示例(114)包括图像数据。
19.根据权利要求11-18中的任一项所述的系统(100),其中,所述分类模型(210)包括深度神经网络(DNN)。
20.根据权利要求11-19中的任一项所述的系统(100),其中,所述置信度模型(220)包括深度神经网络(DNN)。
CN202080066367.7A 2019-09-24 2020-09-24 基于距离的学习置信度模型 Pending CN114424212A (zh)

Applications Claiming Priority (3)

Application Number Priority Date Filing Date Title
US201962904978P 2019-09-24 2019-09-24
US62/904,978 2019-09-24
PCT/US2020/052451 WO2021061951A1 (en) 2019-09-24 2020-09-24 Distance-based learning confidence model

Publications (1)

Publication Number Publication Date
CN114424212A true CN114424212A (zh) 2022-04-29

Family

ID=72811982

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202080066367.7A Pending CN114424212A (zh) 2019-09-24 2020-09-24 基于距离的学习置信度模型

Country Status (6)

Country Link
US (2) US11487970B2 (zh)
EP (1) EP4035090A1 (zh)
JP (2) JP7292506B2 (zh)
KR (1) KR20220049573A (zh)
CN (1) CN114424212A (zh)
WO (1) WO2021061951A1 (zh)

Families Citing this family (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20230386450A1 (en) * 2022-05-25 2023-11-30 Samsung Electronics Co., Ltd. System and method for detecting unhandled applications in contrastive siamese network training
US20240062529A1 (en) * 2022-08-18 2024-02-22 Microsoft Technology Licensing, Llc Determining media documents embedded in other media documents

Family Cites Families (27)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US8068654B2 (en) * 2007-02-02 2011-11-29 Siemens Akteingesellschaft Method and system for detection and registration of 3D objects using incremental parameter learning
US8311319B2 (en) * 2010-12-06 2012-11-13 Seiko Epson Corporation L1-optimized AAM alignment
US8306257B2 (en) * 2011-01-31 2012-11-06 Seiko Epson Corporation Hierarchical tree AAM
US9396412B2 (en) * 2012-06-21 2016-07-19 Siemens Aktiengesellschaft Machine-learnt person re-identification
US20150302317A1 (en) * 2014-04-22 2015-10-22 Microsoft Corporation Non-greedy machine learning for high accuracy
US10332028B2 (en) 2015-08-25 2019-06-25 Qualcomm Incorporated Method for improving performance of a trained machine learning model
US20170068906A1 (en) * 2015-09-09 2017-03-09 Microsoft Technology Licensing, Llc Determining the Destination of a Communication
US20170068904A1 (en) * 2015-09-09 2017-03-09 Microsoft Technology Licensing, Llc Determining the Destination of a Communication
US10423874B2 (en) * 2015-10-02 2019-09-24 Baidu Usa Llc Intelligent image captioning
WO2017203262A2 (en) * 2016-05-25 2017-11-30 Metail Limited Method and system for predicting garment attributes using deep learning
US11194846B2 (en) * 2016-11-28 2021-12-07 Here Global B.V. Method and apparatus for providing automated generation of parking restriction data using machine learning
WO2019046820A1 (en) * 2017-09-01 2019-03-07 Percipient.ai Inc. IDENTIFICATION OF INDIVIDUALS IN A DIGITAL FILE USING MULTIMEDIA ANALYSIS TECHNIQUES
US10657391B2 (en) * 2018-01-05 2020-05-19 Uatc, Llc Systems and methods for image-based free space detection
US10679330B2 (en) 2018-01-15 2020-06-09 Tata Consultancy Services Limited Systems and methods for automated inferencing of changes in spatio-temporal images
US11068737B2 (en) * 2018-03-30 2021-07-20 Regents Of The University Of Minnesota Predicting land covers from satellite images using temporal and spatial contexts
US10825227B2 (en) * 2018-04-03 2020-11-03 Sri International Artificial intelligence for generating structured descriptions of scenes
US10878296B2 (en) * 2018-04-12 2020-12-29 Discovery Communications, Llc Feature extraction and machine learning for automated metadata analysis
US11630995B2 (en) * 2018-06-19 2023-04-18 Siemens Healthcare Gmbh Characterization of amount of training for an input to a machine-learned network
US10832003B2 (en) * 2018-08-26 2020-11-10 CloudMinds Technology, Inc. Method and system for intent classification
US10878297B2 (en) * 2018-08-29 2020-12-29 International Business Machines Corporation System and method for a visual recognition and/or detection of a potentially unbounded set of categories with limited examples per category and restricted query scope
US11087177B2 (en) * 2018-09-27 2021-08-10 Salesforce.Com, Inc. Prediction-correction approach to zero shot learning
US10885384B2 (en) * 2018-11-15 2021-01-05 Intel Corporation Local tone mapping to reduce bit depth of input images to high-level computer vision tasks
US11756291B2 (en) * 2018-12-18 2023-09-12 Slyce Acquisition Inc. Scene and user-input context aided visual search
US20200193552A1 (en) * 2018-12-18 2020-06-18 Slyce Acquisition Inc. Sparse learning for computer vision
US11941493B2 (en) * 2019-02-27 2024-03-26 International Business Machines Corporation Discovering and resolving training conflicts in machine learning systems
US11657094B2 (en) * 2019-06-28 2023-05-23 Meta Platforms Technologies, Llc Memory grounded conversational reasoning and question answering for assistant systems
US11631029B2 (en) * 2019-09-09 2023-04-18 Adobe Inc. Generating combined feature embedding for minority class upsampling in training machine learning models with imbalanced samples

Also Published As

Publication number Publication date
EP4035090A1 (en) 2022-08-03
US20230120894A1 (en) 2023-04-20
US11487970B2 (en) 2022-11-01
JP2022549006A (ja) 2022-11-22
JP7292506B2 (ja) 2023-06-16
US20210279517A1 (en) 2021-09-09
KR20220049573A (ko) 2022-04-21
JP2023116599A (ja) 2023-08-22
WO2021061951A1 (en) 2021-04-01

Similar Documents

Publication Publication Date Title
US20190378044A1 (en) Processing dynamic data within an adaptive oracle-trained learning system using curated training data for incremental re-training of a predictive model
CN114424210A (zh) 存在标签噪声情况下的鲁棒训练
US8140450B2 (en) Active learning method for multi-class classifiers
US9189750B1 (en) Methods and systems for sequential feature selection based on significance testing
US20230120894A1 (en) Distance-based learning confidence model
CN114600117A (zh) 通过样本一致性评估的主动学习
US20230325675A1 (en) Data valuation using reinforcement learning
US10372743B2 (en) Systems and methods for homogeneous entity grouping
CN114467095A (zh) 基于强化学习的局部可解释模型
Yan et al. A framework of online learning with imbalanced streaming data
JP6172317B2 (ja) 混合モデル選択の方法及び装置
WO2020185101A9 (en) Hybrid machine learning system and method
JP5684084B2 (ja) 誤分類検出装置、方法、及びプログラム
CN110059743B (zh) 确定预测的可靠性度量的方法、设备和存储介质
WO2017096219A1 (en) Methods and systems for determination of the number of contributors to a dna mixture
CN111612022A (zh) 用于分析数据的方法、设备和计算机存储介质
JP6233432B2 (ja) 混合モデルの選択方法及び装置
US20220405585A1 (en) Training device, estimation device, training method, and training program
Manivannan Semi-supervised imbalanced classification of wafer bin map defects using a Dual-Head CNN
US20230222324A1 (en) Learning method, learning apparatus and program
CN115204409A (zh) 检测不可推理的数据
Linusson et al. conformal prediction

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination