CN114881169A - 使用随机特征损坏的自监督对比学习 - Google Patents
使用随机特征损坏的自监督对比学习 Download PDFInfo
- Publication number
- CN114881169A CN114881169A CN202210597656.0A CN202210597656A CN114881169A CN 114881169 A CN114881169 A CN 114881169A CN 202210597656 A CN202210597656 A CN 202210597656A CN 114881169 A CN114881169 A CN 114881169A
- Authority
- CN
- China
- Prior art keywords
- feature
- neural network
- training
- unlabeled training
- network parameters
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/77—Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
- G06V10/774—Generating sets of training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
- G06F18/2155—Generating training patterns; Bootstrap methods, e.g. bagging or boosting characterised by the incorporation of unlabelled data, e.g. multiple instance learning [MIL], semi-supervised techniques using expectation-maximisation [EM] or naïve labelling
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/217—Validation; Performance evaluation; Active pattern learning techniques
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N20/00—Machine learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
- G06N3/0455—Auto-encoder networks; Encoder-decoder networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/047—Probabilistic or stochastic networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/0895—Weakly supervised learning, e.g. semi-supervised or self-supervised learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/82—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- Evolutionary Computation (AREA)
- Artificial Intelligence (AREA)
- General Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- Software Systems (AREA)
- Life Sciences & Earth Sciences (AREA)
- Computing Systems (AREA)
- General Health & Medical Sciences (AREA)
- Health & Medical Sciences (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Mathematical Physics (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Biomedical Technology (AREA)
- Molecular Biology (AREA)
- Evolutionary Biology (AREA)
- Bioinformatics & Computational Biology (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Medical Informatics (AREA)
- Databases & Information Systems (AREA)
- Multimedia (AREA)
- Probability & Statistics with Applications (AREA)
- Image Analysis (AREA)
Abstract
本公开涉及使用随机特征损坏的自监督对比学习,尤其是用于训练具有多个网络参数的神经网络的方法、系统以及包括在计算机存储介质上编码的计算机程序的装置。其中一种方法包括:从未标记训练数据集合获得未标记训练输入;处理未标记训练输入以生成第一嵌入;生成未标记训练输入的损坏版本,包括:确定特征维度的真子集并且对于在特征维度的真子集中的每个特征维度,使用从如未标记训练数据集合中指定的该特征维度的边缘分布采样的一个或多个特征值来对该特征维度中的相应特征应用损坏;处理未标记训练输入的损坏版本以生成第二嵌入;以及确定对多个网络参数的当前值的更新。
Description
相关申请的交叉引用
本申请要求于2021年5月28日提交的美国申请No.63/194,899的申请日的权益。在先申请的公开内容被认为是本申请的公开内容的一部分并且通过引用并入在本申请的公开内容中。
技术领域
本说明书涉及训练神经网络。
背景技术
神经网络是采用一个或多个层的非线性单元来针对接收到的输入预测输出的机器学习模型。除了输出层之外,一些神经网络还包括一个或多个隐藏层。每个隐藏层的输出被用作网络中的下一层——即下一隐藏层或输出层——的输入。网络的每个层依照相应参数集的当前值从接收到的输入生成输出。
发明内容
本说明书描述一种作为计算机程序实现在一个或多个位置中的一个或多个计算机上的系统,该系统实现并训练能够对一个或多个接收到的输入执行机器学习任务的神经网络。特别地,该神经网络使用如下两阶段过程来训练:预训练阶段和微调阶段。神经网络的预训练阶段利用自监督对比学习方案。
能够在特定实施例中实现本说明书中描述的主题以便实现以下优点中的一个或多个。
如本说明书中描述的系统预训练神经网络以通过处理不需要被标记的网络输入对来生成可能稍后在特定下游任务中有用的任务不可知表示。特别地,网络输入对包括未标记训练输入,例如图像、视频或文本序列,以及由系统通过随机化未标记训练输入的随机特征集合的特征值自动地生成的未标记训练输入的损坏副本。与通常高度地特定于来自诸如计算机视觉或自然语言处理的窄范围技术领域的数据的现有自监督学习技术不同,系统所采用的边缘采样损坏技术普遍适用于跨越各种技术领域的不同格式或不同类型的数据或者两者。
此外,经预训练后的神经网络然后能够用于使用比用于预训练网络少几个数量级的数据来有效地适应特定机器学习任务。例如,虽然预训练网络可能利用数十亿个未标记训练输入,但是使网络适应特定任务可能需要仅仅几千个标记训练输入。由于针对特定任务训练网络需要比现有方法少,有时少几个数量级,的标记训练输入,所以系统因此能够在微调期间更高效地利用计算资源,例如存储器、挂钟时间或两者。系统还能够以与数据标记相关联的较低人类劳动成本训练神经网络,同时仍然确保匹配或甚至超过目前技术水平的经训练后的神经网络在一系列任务上的竞争性能,同时附加地可推广且可容易适应新任务。
本说明书中描述的主题的一个或多个实施例的细节在下面的附图和描述中阐述。主题的其他特征、方面和优点将从说明书、附图和权利要求书中变得显而易见。
附图说明
图1A示出预训练阶段期间的示例神经网络系统。
图1B示出微调阶段期间的示例神经网络系统。
图2是用于使用自监督对比学习方案来预训练神经网络的示例过程的流程图。
图3是用于在机器学习任务上微调神经网络的示例过程的流程图。
图4A-B分别是预训练和微调神经网络的示例图示。
在各个附图中相似的附图标记和名称指示相似的元件。
具体实施方式
本说明书描述一种作为计算机程序实现在一个或多个位置中的一个或多个计算机上的系统,该系统实现并训练能够对一个或多个接收到的输入执行机器学习任务的神经网络。取决于任务,神经网络能够被配置成接收任何种类的数字数据输入并且依照神经网络的当前参数值来处理所接收到的输入以基于该输入生成一个或多个输出。
在一些情况下,神经网络的输入包括表列数据。表列数据是指按行和列或者以单元格矩阵的形式布置的数字数据或信息。表列数据是指信息的布置,而不是指在列、行或单元格中的给定位置处找到的数据的特定类型。表列数据也不是指可以通过表列数据表示的实际数据。例如,每个给定位置可以具有表示像素值的数值(在表列数据表示图像数据的情况下)或者可以替换地具有表示字母、单词、短语或句子的数值(在表列数据表示文本数据的情况下)。
在一些情况下,神经网络的输出包括任何种类的分类输出。分类可以是例如类型、类、组、类别或量度。
例如,神经网络能够被配置成在制造工厂的场境下执行自动模式识别任务,其中神经网络接收包括已制造产品的描述故障的多个特征——例如位置、大小等——的输入数据,并且处理该输入数据以生成分类输出,该分类输出指定故障的类型,例如划痕、污点、脏污、凹凸等。在此示例中,输入可以以具有对应于描述已制造产品的故障的多个特征的行或列的表列数据格式布置,例如,可以具有其中每列具有描述故障的相应特征的多个列。
作为另一示例,神经网络能够被配置成处理描述植物的叶子样本的物理特性——例如形状、纹理、边缘等——的输入数据,以生成指定植物的种类的分类输出。
能够在加州大学欧文分校机器学习储存库(UCI储存库)和开放媒体库(OpenML)中找到用于此类任务和其他类似分类任务的标记数据集的示例。
在另外的示例中,任务可以是计算机视觉任务,其中输入是图像或点云并且输出是图像或点云的计算机视觉输出。例如,神经网络能够被配置成执行图像处理任务,例如,以接收包括图像数据的输入,所述图像数据包含多个像素。图像数据可以例如包括一个或多个图像或已从一个或多个图像中提取的特征。神经网络能够被配置成处理图像数据以生成用于图像处理任务的输出。
例如,如果任务是图像分类,则由神经网络针对给定图像生成的输出可以是对象类别集合中的每一个对象类别的分数,其中每个分数表示图像包含属于该类别的对象的图像的估计可能性。
作为另一示例,如果任务是对象检测,则由神经网络针对给定图像生成的输出可以是每一个与相应分数相关联的一个或多个有界框,其中每个有界框表示图像中的估计位置并且相应分数表示在图像中的位置处——即在有界框内——描绘对象的估计可能性。
作为另一示例,如果任务是语义分割,则由神经网络针对给定图像生成的输出可以是用于图像中的多个像素中的每一个的标签,其中每个像素被标记为属于对象类别集合中的一个对象类别。替换地,对于多个像素中的每一个,输出可以是包括对象类别集合中的每一个对象类别的相应分数的分数集合,该相应分数表示该像素属于来自该对象类别的对象的可能性。
作为另一示例,如果神经网络的输入是互联网资源(例如,web页面)、文档或文档的部分或从互联网资源、文档或文档的部分中提取的特征,则由神经网络针对给定互联网资源、文档或文档的部分生成的输出可以是主题集合中的每一个主题的分数,其中每个分数表示互联网资源、文档或文档部分是关于该主题的估计可能性。
作为另一示例,如果神经网络的输入是特定广告的闪现(impression)上下文的特征,则由神经网络生成的输出可以是表示该特定广告将被点击的估计可能性的分数。
作为另一示例,如果神经网络的输入是针对用户的个性化推荐的特征,例如表征推荐的场境的特征,例如表征由用户采取的先前动作的特征,则由神经网络生成的输出可以是内容项集合中的每一个的分数,其中每个分数表示用户将积极地对被推荐该内容项做出响应的估计可能性。
作为另一示例,任务可以是对某种自然语言的文本序列进行操作的自然语言处理或理解任务,例如,蕴涵任务、释义任务、文本相似性任务、情感任务、句子完成任务、语法性任务等。
例如,如果神经网络的输入是一种语言的文本序列,则由神经网络生成的输出可以是另一种语言的文本片段集合中的每一个文本片段的分数,其中每个分数表示另一种语言的文本片段是输入文本到该另一种语言的适当翻译的估计可能性。
作为另一示例,如果神经网络的输入是表示口语话语的序列,则由神经网络生成的输出可以是文本片段集合中的每一个文本片段的分数,每个分数表示该文本片段是该话语的正确转录的估计可能性。
作为另一示例,任务可以是健康预测任务,其中输入是从患者的电子健康记录数据得出的序列并且输出是与患者的将来健康相关的预测,例如,应该为患者采用的预测治疗、患者将发生不利健康事件的可能性、或针对患者的预测诊断。
图1A示出预训练阶段期间的示例神经网络系统100。神经网络系统100是作为计算机程序实现在一个或多个位置中的一个或多个计算机上的系统的示例,其中能够实现下述系统、组件和技术。
神经网络系统100包括损坏引擎120、神经网络130和训练引擎140。神经网络130被配置成接收输入并且基于所接收到的输入和神经网络130的网络参数150的值来生成输出。
一般而言,神经网络130能够具有使得它能够执行上面提及的机器学习任务的任何适当的神经网络架构。在图1A的示例中,神经网络130包括编码器子网络132和嵌入生成子网络134。神经网络的子网络是指神经网络中的一组一个或多个神经网络层。当输入包括文本数据时,编码器子网络132可以是被配置成处理输入以生成编码器网络输出的全连接子网络,即,所述全连接子网络包括一个或多个全连接神经网络层,并且在一些实现方式中,所述全连接子网络包括一个或多个非线性激活层,例如ReLU激活层。当输入包括图像数据时,编码器子网络132能够附加地或替换地包括一个或多个卷积神经网络层。嵌入生成子网络134能够被类似地配置为全连接子网络,然后能够处理由编码器子网络132生成的编码器网络输出以生成用于输入的嵌入,其通常是具有固定维数的数值表示。
作为另一示例,神经网络130可以是包括一个或多个注意力层的注意力神经网络。如本文所使用的,注意力层是包括注意力机制——例如多头自注意力机制——的神经网络层。注意力神经网络的配置的示例和注意力神经网络的其他组件的细节——例如将输入嵌入到神经网络的嵌入层或注意力网络的层内的前馈层——在Vaswani,et al,AttentionIs All You Need,arXiv:1706.03762(Vaswani等人,注意力就是你所需要的,arXiv:1706.03762)以及Raffel,et al,Exploring the Limits of Transfer Learning with aUnified Text-to-Text Transformer,arXiv:1910.10683(Raffel,et al,使用统一的文本到文本转换器探索迁移学习的极限,arXiv:1910.10683)中被更详细地描述,其全部内容特此通过引用整体地并入全文。
在一些情况下,神经网络130的架构在预训练阶段和微调阶段两者期间保持相同,然而在其他情况下,神经网络130能够在两个阶段期间具有不同架构。在后者情况下,神经网络130能够在预训练阶段和微调阶段两者期间具有公共主干子网络(例如,图1A的编码器子网络132),并且能够具有在每个阶段使用的不同辅助子网络(例如,在预训练阶段期间使用的嵌入生成子网络134或图1B的在微调阶段期间使用的输出子网络136)。
在图1A的示例中,神经网络130包括仅用于在预训练阶段期间协助编码器子网络132的训练的嵌入生成子网络134。换句话说,一旦预训练已完成,即在微调阶段或部署期间,嵌入生成子网络134将不再作为神经网络130的一部分被包括。
系统100中的训练引擎140在未标记训练数据110上训练神经网络130以使用迭代训练过程来从网络参数的初始值确定网络参数150的学习值。在训练过程的每次迭代时,训练引擎140确定对网络参数150(包括编码器子网络132的参数和嵌入生成子网络134的参数)的当前值的参数值更新,然后将更新应用于网络参数150的当前值。
特别地,为了通过利用可跨越广泛范围的机器学习任务相对更容易地大量获得的未标记训练数据110来有效地确定神经网络130的参数150的训练值,即与标记(例如,人类注释的)训练数据相比,与损坏引擎120协同工作的训练引擎140通过使用自监督对比学习技术来训练神经网络130。
来自未标记训练数据110的未标记训练输入112是指对于其关于应该由神经网络130生成的已知真实值输出——例如训练输入的真实值分类——的信息未被系统100使用的训练输入。未标记训练输入112包括表示任何种类的数字数据的多个特征。在一些示例中,每个特征能够表示描述分类任务的主题的属性或特征集合中的一个属性或特征。在其他示例中,每个特征能够针对任何适当的任务表示对应像素的对应通道的不同强度值、文本序列中的不同文本词元、音频数据中的不同振幅值、点云中的不同点等。
在预训练阶段期间,对于每个未标记训练输入112,损坏引擎120处理未标记训练输入以通过损坏,即修改,包含在原始未标记训练输入中的特征的子集来生成未标记训练输入的损坏版本(“损坏训练输入”)114。特别地,损坏引擎120被配置成使用边缘采样损坏技术来生成损坏训练输入114。
许多对比学习和相关联损坏技术已经在视觉领域(例如,诸如随机裁切、颜色失真和模糊的基于图像的损坏技术)和自然语言领域(例如,诸如词元屏蔽、删除和填充的基于文本的损坏技术)中成功。尽管为计算中最常见的数据类型之一,表列数据是似乎缺少的又一种类型的数据。
具体地,在表列数据格式中,未标记训练输入112可以在多个特征维度中的每一个中,例如在多个行或列或两者中的每一个中,具有相应特征。每个相应特征可以具有表示该特征的特征值,其通常是数值。每个相应特征可以是数字特征或者可以替换地是离散特征。换句话说,未标记训练输入112可以包括作为数字特征的一些特征和作为离散特征的一些特征。数值特征是具有可以为某个范围内的任何值的数值的特征,然而离散特征包括二元特征和仅能够取少数可能的值之一的其他特征,例如分类特征。
通过应用所公开的有效地适用于表列数据的边缘采样损坏技术,损坏引擎120能够通过首先选择要损坏哪些特征维度并且对于每个选择的特征维度基于训练输入中的特征值的经验边缘分布来对特征维度中的特征值应用损坏而生成损坏训练输入114。
对于每个未标记训练输入112,神经网络130处理未标记训练输入112的原始未损坏版本以生成第一嵌入142。此外,神经网络130处理已由损坏引擎120从未标记训练输入112生成的损坏训练输入114以生成第二嵌入144。也就是说,第一嵌入和第二嵌入144由相同神经网络(具有相同架构和相同参数值)针对相同训练输入的两个不同版本——原始版本和损坏版本——生成。
训练引擎140然后能够通过将对比损失函数的梯度146反向传播通过嵌入生成子网络134和编码器子网络132的参数来确定参数值更新,所述对比损失函数来测量第一嵌入142和第二嵌入144之间的差异。例如,对比损失函数可以是噪声对比估计(NCE)损失函数,例如InfoNCE损失函数。
图1B示出微调阶段期间的示例神经网络系统100。
在预训练之后,系统100的训练引擎120然后利用包括多个标记训练输入118的标记训练数据116来使经预训练后的神经网络130适应下游任务,该下游任务可以是上面提及的机器学习任务中的任一个。
在一些情况下,随后对已预训练的所有神经网络130进行微调,然而在其他情况下,随后对神经网络130的仅一部分进行微调。在图1B的示例中,除了具有编码器子网络132之外,神经网络130还包括输出子网络136代替嵌入生成子网络134,所述输出子网络136能够被配置成处理由编码器子网络132生成的编码器网络输出以生成用于下游任务的输出。不再需要并因此不进一步微调嵌入生成子网络134。
来自标记训练数据116的标记训练输入118是指如下训练输入:对于所述训练输入关于应该由神经网络130生成的已知真实值输出——例如训练输入的真实值分类——的信息由训练输入定义或以其他方式指定并且因此可被系统100利用。
通常,用于微调阶段的数据能够比用于预训练阶段的数据小几个数量级。在一些实现方式中,未标记训练数据110包括数百万个未标记训练输入,然而标记训练数据132包括仅仅几千个标记训练输入。另外,不再需要自监督对比学习技术以及损坏处理步骤。相反,可以在微调阶段期间使用更常规的监督学习技术。
使经预训练后的神经网络130适应下游任务涉及调整一些或所有网络参数150的学习值。在图1B的示例中,在微调阶段期间,编码器子网络132和输出子网络136的参数而不是不再作为神经网络130的一部分被包括的嵌入生成子网络134的参数被调整。训练引擎140能够通过将适合于下游任务的目标函数的梯度反向传播通过输出子网络136和编码器子网络132的参数148来确定参数值更新。例如,目标函数可以是通过处理训练输入,即相对于与训练输入相关联的真实值分类,来测量由神经网络130生成的分类输出的质量的交叉熵损失函数。
一旦两阶段过程已完成,系统100就能够向另一系统——例如服务器——提供指定经训练后的神经网络的数据,例如,指定神经网络的架构(其可以与在微调而不是预训练阶段期间使用的架构相同)和神经网络的网络参数150的训练值的数据,以用于在处理新输入时使用。作为提供指定经训练后的神经网络的数据的代替或补充,系统100能够使用经训练后的神经网络来处理新输入并且生成相应输出。
图2是用于使用自监督对比学习方案来预训练神经网络的示例过程200的流程图。为了方便,过程200将被描述为由位于一个或多个位置中的一个或多个计算机的系统执行。例如,依照本说明书适当地编程的神经网络系统,例如图1的神经网络系统100,能够执行过程200。
系统从未标记训练数据集合获得未标记训练输入(步骤202)。能够通过从用于预训练神经网络的未标记训练数据中随机采样来获得未标记训练数据集合。未标记训练数据集合能够包括固定数目的未标记训练输入,例如,64、128或256。系统通常对包括在未标记训练数据集合中的每个未标记训练输入执行步骤202-208的一迭代。
未标记训练输入能够具有表列数据格式。未标记训练输入能够在多个特征维度的每一个中具有相应特征。每个相应特征可以具有表示该特征的特征值,其通常是数值。例如,未标记训练输入能够包括描述矩阵的数据,所述矩阵具有作为矩阵的行或列中的矩阵元素布置的未标记训练输入的特征,其中每行或列对应于特定特征维度。在其他类似示例中,未标记训练输入能够包括描述向量、表、数组等的数据。
图4A是预训练神经网络的示例图示。如图示,未标记训练输入402是6维向量,即具有六个特征维度的向量。未标记训练输入402在六个特征维度的每一个中具有相应特征。
系统使用神经网络并且依照多个网络参数的当前值来处理未标记训练输入以生成未标记训练输入的第一嵌入(步骤204)。嵌入可以是具有固定维数的数值表示。
在图4A的示例中,神经网络包括编码器子网络和嵌入生成子网络。在此示例中,系统能够首先依照编码器网络参数(由f表示)的当前值来处理未标记训练输入402以生成编码器网络输出(嵌入406A),然后依照嵌入生成网络参数(由g表示)的当前值来处理编码器网络输出以生成未标记训练输入的第一嵌入408A。
系统生成未标记训练输入的损坏版本(步骤206)。
生成未标记训练输入的损坏版本能够包括确定特征维度的真子集并且对于在特征维度的真子集中的每个特征维度使用从如未标记训练数据集合中指定的特征维度的边缘分布采样的一个或多个特征值来对特征维度中的相应特征应用损坏的操作。应用损坏能够包括用一个或多个采样的特征值替换真子集中的每个特征维度中的特征。
在一些实现方式中,系统能够通过按均匀随机性从多个特征维度对特征维度的真子集进行采样来确定特征维度的真子集。在一些实现方式中,系统能够依照指定要选择的特征维度的总数的预定损坏率来确定特征维度的真子集。例如,预定损坏率c可以是相对于包括在未标记训练输入中的特征维度的总数M而定义的百分比值(例如,20%、30%、50%等)。在此示例中,系统能够对总数c×M个特征维度进行采样,然后对每个采样的特征维度中的相应特征应用损坏。
在图4A的示例中,系统对未标记训练输入402的六个特征维度中的一半进行采样,然后用从特征维度的经验边缘分布采样的特征值替换每个采样的特征维度中的原始特征值。
特别地,能够将特征维度的边缘分布定义为跨越未标记训练数据集合在特征维度中的特征已采用的所有值之上的均匀分布。换句话说,为了对于真子集中的每个特征维度确定一个或多个替换特征值,系统能够从跨越未标记训练数据集合出现在特征维度中至少阈值次数的所有特征值之上的均匀分布采样。例如,阈值是一,但是在其他示例中可以提高阈值。
在数学上,令未标记训练数据集为其中M是特征维度的数量,是之上的均匀分布,其中xj表示x的第j个特征维度,对于每个未标记训练输入系统能够从大小q的多个特征维度{1,...,M}均匀地对特征维度的真子集进行采样并且生成未标记训练输入的损坏版本如下:如果则否则其中
系统使用神经网络并且依照多个网络参数的当前值来处理未标记训练输入的损坏版本以生成未标记训练输入的损坏版本的第二嵌入(步骤208)。换句话说,系统使用已用于生成未标记训练输入的第一嵌入的相同神经网络(具有相同神经网络架构和相同参数值)来通过处理未标记训练输入的损坏版本而生成相同未标记训练输入的第二嵌入。
如图示,系统首先依照编码器网络参数(由f表示)的当前值来处理未标记训练输入的损坏版本404以生成编码器网络输出(嵌入406B),然后依照嵌入生成网络参数(由g表示)的当前值来处理编码器网络输出以生成未标记训练输入的损坏版本的第二嵌入408B。
系统例如通过反向传播来计算对比学习损失函数的相对于多个网络参数的梯度(步骤210)。对于未标记训练数据集合中的每个未标记训练输入,对比学习损失函数评价未标记训练输入的第一嵌入与未标记训练输入的损坏版本的第二嵌入之间的差异。另外,对于未标记训练数据集合中的每个未标记训练输入,对比学习损失函数评价未标记训练输入的第一嵌入与已由神经网络针对集合中的每个其他未标记训练输入生成的对应第一嵌入之间的差异。
对比学习损失函数训练神经网络以通过使相同输入的不同版本的相应嵌入之间(即,正训练对的嵌入之间)的相似性最大化并且使不同输入的不同版本的相应嵌入之间(即,负训练对的嵌入之间)的相似性最小化来生成对相同输入的不同版本稳健的表示。例如,对比学习损失函数可以是噪声对比估计(NCE)损失函数,例如InfoNCE损失函数。
系统继续基于梯度并且通过使用适当的梯度下降优化技术,例如随机梯度下降、RMSprop或Adam技术,来更新当前参数值。
系统能够重复地执行过程200直到满足预训练终止准则,例如,在过程200已被执行预定次数之后、在对比学习函数的梯度已收敛于指定值之后、或在满足某个早期停止准则之后。
在确定满足预训练终止准则之后,系统能够继续使神经网络适应特定机器学习任务。在一些情况下,随后对所有经预训练后的神经网络进行微调,然而在其他情况下,随后对经预训练后的神经网络的仅一部分进行微调。在后者情况下,系统能够通过相对于标记训练数据与输出子网络协同重新训练编码器子网络来微调编码器子网络,包括调整编码器网络参数的学习值。标记训练数据包括专用于特定机器学习任务并且各自与对应的真实值输出相关联的训练输入。
图3是用于在机器学习任务上微调神经网络的示例过程300的流程图。为了方便,过程300将被描述为由位于一个或多个位置中的一个或多个计算机的系统执行。例如,依照本说明书适当地编程的神经网络系统,例如图1的神经网络系统100,能够执行过程300。
系统使用编码器子网络并且依照多个编码器网络参数的学习值来处理一组一个或多个标记训练输入中的标记训练输入以生成标记训练输入的嵌入(步骤302)。例如,标记训练输入集合是从更大的标记训练数据集采样的。
系统使用输出子网络并且依照多个输出网络参数的当前值来处理嵌入以针对标记训练输入集合中的每个标记训练输入生成训练输出(步骤304)。
图4B是微调神经网络的示例图示。如图示,系统能够首先依照编码器网络参数(由f表示)的当前值来处理标记训练输入412以生成编码器网络输出(嵌入416),然后依照输出网络参数(由h表示)的当前值来处理编码器网络输出以生成训练输出416。
系统计算监督学习损失函数(步骤306)。对于标记训练输入集合中的每个标记训练输入,监督学习损失函数评价训练输出与和标记训练输入相关联的真实值输出之间的差异。系统还例如通过反向传播来计算监督学习损失函数相对于多个编码器网络参数和相对于多个输出网络参数的梯度。
在图4B的示例中,特定机器学习任务是分类任务,并且监督学习损失函数可以是评价训练输出416与和标记训练输入412相关联的真实值输出414之间的差异的分类损失函数,例如交叉熵损失函数。
系统然后继续基于梯度并且通过使用适当的梯度下降优化技术,例如随机梯度下降、RMSprop或Adam技术,来更新编码器网络参数和输出网络参数的当前值(步骤308)。
以这种方式,在预训练过程期间学习的参数值被调整,使得它们适应特定机器学习任务。
本说明书连同系统和计算机程序组件一起使用术语“配置”。对于一个或多个计算机的系统被配置成执行特定操作或动作意味着系统已在其上安装了软件、固件、硬件或它们的组合,这些软件、固件、硬件或它们的组合在操作中使系统执行操作或动作。对于一个或多个计算机程序被配置成执行特定操作或动作意味着一个或多个程序包括指令,当由数据处理装置执行时,这些指令使该装置执行操作或动作。
本说明书中描述的主题和功能操作的实施例能够用数字电子电路系统、用有形地体现的计算机软件或固件、用计算机硬件(包括本说明书中公开的结构及其结构等同物)或者用它们中的一种或多种的组合加以实现。能够将本说明书中描述的主题的实施例实现为一个或多个计算机程序,即,在有形非暂时性存储介质上编码以用于由数据处理装置执行或者控制数据处理装置的操作的计算机程序指令的一个或多个模块。计算机存储介质可以是机器可读存储设备、机器可读存储基板、随机或串行存取存储器设备或它们中的一种或多种的组合。替换地或另外,程序指令能够被编码在人工生成的传播信号上,该传播信号例如机器生成的电、光或电磁信号,该信号被生成来对信息进行编码以便传输到合适的接收器装置以供由数据处理装置执行。
术语“数据处理装置”是指数据处理硬件并且包含用于处理数据的所有种类的装置、设备和机器,作为示例包括可编程处理器、计算机或多个处理器或计算机。该装置还可以是或进一步包括专用逻辑电路系统,例如FPGA(现场可编程门阵列)或ASIC(专用集成电路)。除了硬件之外,该装置还能够可选地包括为计算机程序创建执行环境的代码,例如,构成处理器固件、协议栈、数据库管理系统、操作系统或它们中的一种或多种的组合的代码。
还可以被称为或者描述为程序、软件、软件应用、app、模块、软件模块、脚本或代码的计算机程序能够用任何形式的编程语言编写,所述编程语言包括编译或解释语言或声明或过程语言;并且它能够被以任何形式部署,包括作为独立程序或作为模块、组件、子例程或适合于在计算环境中使用的其他单元。程序可以但不必对应于文件系统中的文件。能够在保持其他程序或数据(例如,存储在标记语言文档中的一个或多个脚本)的文件的一部分中、在专用于所述程序的单个文件中或者在多个协调文件(例如,存储一个或多个模块、子程序或代码的部分的文件)中存储程序。能够将计算机程序部署成在一个计算机上或在位于一个站点处或者跨越多个站点分布并且通过数据通信网络互连的多个计算机上执行。
在本说明书中,术语“数据库”广泛地用于是指数据的任何合集:数据不需要被以任何特定方式构造,或者根本不构造,并且它能够被存储在一个或多个位置中的存储设备上。因此,例如,索引数据库能够包括数据的多个合集,其中的每个可以被不同地组织和访问。
类似地,在本说明书中术语“引擎”广泛用于是指被编程以执行一个或多个特定功能的基于软件的系统、子系统或过程。通常,引擎将被实现为安装在一个或多个位置中的一个或多个计算机上的一个或多个软件模块或组件。在一些情况下,一个或多个计算机将专用于特定引擎;在其他情况下,能够在同一个或多个计算机上安装和运行多个引擎。
本说明书中描述的过程和逻辑流程可以是通过一个或多个可编程计算机执行一个或多个计算机程序以通过对输入数据操作并生成输出来执行功能而执行的。过程和逻辑流程还能够由专用逻辑电路例如FPGA或ASIC执行,或者由专用逻辑电路系统和一个或多个编程计算机的组合执行。
适合于执行计算机程序的计算机能够基于通用或专用微处理器或两者,或任何其他种类的中央处理单元。通常,中央处理单元将从只读存储器或随机存取存储器或两者接收指令和数据。计算机的必要元件是用于执行或执行指令的中央处理单元以及用于存储指令和数据的一个或多个存储器设备。中央处理单元和存储器能够由专用逻辑电路系统补充,或者并入在专用逻辑电路系统中。通常,计算机还将包括或者在操作上耦合以从用于存储数据的一个或多个大容量存储设备(例如,磁盘、磁光盘或光盘)接收数据或者向其转移数据,或者兼而有之。然而,计算机不需要有此类设备。此外,计算机能够被嵌入到另一设备中,该另一设备例如移动电话、个人数字助理(PDA)、移动音频或视频播放器、游戏机、全球定位系统(GPS)接收器或便携式存储设备,例如,通用串行总线(USB)闪存驱动器,仅举几例。
适合于存储计算机程序指令和数据的计算机可读介质包括所有形式的非易失性存储器、介质和存储器设备,作为示例包括半导体存储器设备,例如EPROM、EEPROM和闪速存储器设备;磁盘,例如内部硬盘或可移动盘;磁光盘;以及CD ROM和DVD-ROM盘。
为了提供与用户的交互,能够在计算机上实现本说明书中描述的主题的实施例,该计算机具有用于向用户显示信息的显示设备——例如CRT(阴极射线管)或LCD(液晶显示器)监视器——以及用户能够通过其向计算机提供输入的键盘和指点设备,例如鼠标或轨迹球。其他种类的设备也能够用于提供与用户的交互;例如,提供给用户的反馈可以是任何形式的感觉反馈,例如视觉反馈、听觉反馈或触觉反馈;并且能够以任何形式接收来自用户的输入,包括声、语音或触觉输入。另外,计算机能够通过向由用户使用的设备发送文档并且从由用户使用的设备接收文档来与用户交互;例如,通过响应于从web浏览器接收到的请求而向用户的设备上的web浏览器发送网页。另外,计算机能够通过向个人设备例如正在运行消息传送应用的智能电话发送文本消息或其他形式的消息并且作为回报接收来自用户的响应消息来与用户交互。
用于实现机器学习模型的数据处理装置还能够包括例如用于处理机器学习训练或生产,即推理,工作负载的常见和计算密集部分的专用硬件加速器单元。
机器学习模型能够使用机器学习框架,例如TensorFlow框架、MicrosoftCognitive Toolkit框架、Apache Singa框架或Apache MXNet框架,来实现和部署。
能够在计算系统中实现本说明书中描述的主题的实施例,该计算系统包括后端组件(例如,作为数据服务器),或者包括中间件组件(例如,应用服务器),或者包括前端组件(例如,具有用户能够通过其与本说明书中描述的主题的实现方式交互的图形用户接口、web浏览器或app的客户端计算机),或者包括一个或多个此类后端、中间件或前端组件的任何组合。系统的组件能够通过任何形式或介质的数字数据通信例如通信网络来互连。通信网络的示例包括局域网(LAN)和广域网(WAN),例如互联网。
计算系统能够包括客户端和服务器。客户端和服务器通常彼此远离并且通常通过通信网络来交互。客户端和服务器的关系借助于在相应计算机上运行并且彼此具有客户端-服务器关系的计算机程序而发生。在一些实施例中,服务器向用户设备传送数据(例如,HTML页面),例如,用于向与作为客户端的设备交互的用户显示数据并且从与作为客户端的设备交互的用户接收用户输入的目的。能够在服务器处从设备接收在用户设备处生成的数据,例如用户交互的结果。
虽然本说明书包含许多特定实现方式细节,但是这些不应该被解释为对任何发明的范围或对可能要求保护的东西的范围构成限制,而是相反被解释为可能特定于特定发明的特定实施例的特征的描述。还能够在单个实施例中相结合地实现在本说明书中在单独实施例的上下文中描述的某些特征。相反地,还能够单独地在多个实施例中或在任何合适的子组合中实现在单个实施例的上下文中描述的各种特征。此外,尽管特征可以在上面被描述为按某些组合起作用并且甚至最初被如此要求保护,但是来自要求保护的组合的一个或多个特征能够在一些情况下被从组合中除去,并且所要求保护的组合可以针对子组合或子组合的变化。
类似地,虽然以特定次序在附图中描绘并在权利要求中叙述操作,但是这不应该被理解为要求以所示特定次序或以顺序次序执行此类操作,或者要求执行所有图示的操作以实现所希望的结果。在某些情况下,多任务处理和并行处理可以是有利的。此外,上述实施例中的各种系统模块和组件的分离不应该被理解为在所有实施例中要求这种分离,并且应该理解,所描述的程序组件和系统通常能够被集成在单个软件产品中或者包装到多个软件产品中。
已经描述了主题的特定实施例。其他实施例在以下权利要求的范围内。例如,权利要求中叙述的动作能够被以不同次序执行并且仍然实现所希望的结果。作为一个示例,附图中描绘的过程不一定要求所示的特定次序或顺序次序以实现所希望的结果。在某些情况下,多任务处理和并行处理可以是有利的。
所要求保护的是:
Claims (18)
1.一种训练具有多个网络参数的神经网络的计算机实现的方法,所述方法包括:
从未标记训练数据集合中获得未标记训练输入,所述未标记训练输入在多个特征维度中的每一个中具有相应特征;
使用所述神经网络依照所述多个网络参数的当前值来处理所述未标记训练输入以生成所述未标记训练输入的第一嵌入;
生成所述未标记训练输入的损坏版本,包括:
从所述多个特征维度中确定特征维度的真子集,并且
对于在所述特征维度的真子集中的每一个特征维度,使用从所述未标记训练数据集合中所指定的该特征维度的边缘分布中采样的一个或多个特征值来对该特征维度中的所述相应特征应用损坏;
使用所述神经网络依照所述多个网络参数的所述当前值来处理所述未标记训练输入的所述损坏版本以生成所述未标记训练输入的所述损坏版本的第二嵌入;以及
基于计算对所述第一嵌入与所述第二嵌入之间的差异进行求解的对比学习损失函数的相对于所述多个网络参数的梯度,确定对所述多个网络参数的所述当前值的更新。
2.根据权利要求1所述的方法,其中,所述对比学习损失函数包括噪声对比估计(NCE)损失函数。
3.根据权利要求2所述的方法,其中,所述NCE损失函数包括InfoNCE损失函数。
4.根据权利要求1所述的方法,其中,确定所述特征维度的真子集包括按均匀随机性从所述多个特征维度中采样所述特征维度的真子集。
5.根据权利要求4所述的方法,其中,所述特征维度的真子集是依照指定要选择的特征维度的总数的预定损坏率来按均匀随机性采样的。
6.根据权利要求1所述的方法,其中,所述一个或多个特征值是从跨所述未标记训练数据集合中的所述未标记训练输入出现在所述特征维度中至少阈值次数的所述特征值上的均匀分布中采样的。
7.根据权利要求6所述的方法,其中,所述阈值是一。
8.根据权利要求1所述的方法,其中,使用所述一个或多个特征值来对所述相应特征应用所述损坏包括用所述一个或多个特征值替换所述相应特征。
9.根据权利要求1所述的方法,其中,至少一个特征维度中的特征是数值特征。
10.根据权利要求1所述的方法,其中,至少一个特征维度中的特征是分类特征。
11.根据权利要求1所述的方法,其中,第一特征维度中的特征是数值特征并且第二特征维度中的特征是分类特征。
12.根据权利要求1-11中的任一项所述的方法,其中,所述神经网络包括具有多个编码器网络参数的编码器子神经网络和具有多个嵌入生成网络参数的嵌入生成子神经网络。
13.根据权利要求12所述的方法,进一步包括:在所述未标记训练数据集合上训练所述神经网络之后,使所述编码器子神经网络适应特定机器学习任务,包括使用包括标记训练输入的标记数据来调整所述多个编码器网络参数的学习值。
14.根据权利要求13所述的方法,其中,使所述编码器子神经网络适应所述特定机器学习任务还包括:
使用所述编码器子神经网络依照所述多个编码器网络参数的所述学习值来处理标记训练输入以生成所述标记训练输入的嵌入;
使用输出子神经网络依照多个输出网络参数的当前值来处理所述嵌入以生成训练输出;
计算对所述训练输出与和所述标记训练输入相关联的真实值输出之间的差异进行求解的监督学习损失函数;以及
基于计算所述监督学习损失函数相对于所述多个编码器网络参数和相对于所述多个输出网络参数的梯度,确定对所述多个编码器网络参数的所述学习值的调整。
15.根据权利要求14所述的方法,其中,所述特定机器学习任务包括分类任务,并且其中,所述监督学习损失函数包括交叉熵损失函数。
16.根据权利要求13-15中任一项所述的方法,进一步包括提供所述多个编码器网络参数的所述学习值以用于在执行所述特定机器学习任务时使用。
17.一种包括一个或多个计算机和存储指令的一个或多个存储设备的系统,所述指令在由所述一个或多个计算机执行时,使所述一个或多个计算机执行根据权利要求1-16中任一项所述的相应方法的操作。
18.一种编码有指令的非易失性计算机可读存储介质,所述指令在由一个或多个计算机执行时,使所述一个或多个计算机执行根据权利要求1-16中任一项所述的相应方法的操作。
Applications Claiming Priority (2)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
US202163194899P | 2021-05-28 | 2021-05-28 | |
US63/194,899 | 2021-05-28 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN114881169A true CN114881169A (zh) | 2022-08-09 |
Family
ID=82679261
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202210597656.0A Pending CN114881169A (zh) | 2021-05-28 | 2022-05-30 | 使用随机特征损坏的自监督对比学习 |
Country Status (2)
Country | Link |
---|---|
US (1) | US20220383120A1 (zh) |
CN (1) | CN114881169A (zh) |
Families Citing this family (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116229063B (zh) * | 2023-01-08 | 2024-01-26 | 复旦大学 | 基于类别色彩化技术的语义分割网络模型及其训练方法 |
CN116089838B (zh) * | 2023-03-01 | 2023-09-26 | 中南大学 | 窃电用户智能识别模型训练方法和识别方法 |
-
2022
- 2022-05-27 US US17/827,448 patent/US20220383120A1/en active Pending
- 2022-05-30 CN CN202210597656.0A patent/CN114881169A/zh active Pending
Also Published As
Publication number | Publication date |
---|---|
US20220383120A1 (en) | 2022-12-01 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
US11568207B2 (en) | Learning observation representations by predicting the future in latent space | |
CN112084327B (zh) | 在保留语义的同时对稀疏标注的文本文档的分类 | |
CN105631479B (zh) | 基于非平衡学习的深度卷积网络图像标注方法及装置 | |
US11847541B2 (en) | Training neural networks using data augmentation policies | |
CN110188195B (zh) | 一种基于深度学习的文本意图识别方法、装置及设备 | |
CN111079532A (zh) | 一种基于文本自编码器的视频内容描述方法 | |
US12118064B2 (en) | Training machine learning models using unsupervised data augmentation | |
CN114881169A (zh) | 使用随机特征损坏的自监督对比学习 | |
CN109948149A (zh) | 一种文本分类方法及装置 | |
US11250838B2 (en) | Cross-modal sequence distillation | |
CN111639186B (zh) | 动态嵌入投影门控的多类别多标签文本分类模型及装置 | |
CN110968725B (zh) | 图像内容描述信息生成方法、电子设备及存储介质 | |
CN111475622A (zh) | 一种文本分类方法、装置、终端及存储介质 | |
CN113486175B (zh) | 文本分类方法、文本分类装置、计算机设备及存储介质 | |
US20230205994A1 (en) | Performing machine learning tasks using instruction-tuned neural networks | |
US20220230065A1 (en) | Semi-supervised training of machine learning models using label guessing | |
US20240152749A1 (en) | Continual learning neural network system training for classification type tasks | |
CN115398446A (zh) | 使用符号编程的机器学习算法搜索 | |
WO2021159099A1 (en) | Searching for normalization-activation layer architectures | |
CN112561530A (zh) | 一种基于多模型融合的交易流水处理方法及系统 | |
CN117033961A (zh) | 一种上下文语境感知的多模态图文分类方法 | |
CN111008329A (zh) | 基于内容分类的页面内容推荐方法及装置 | |
WO2023158881A1 (en) | Computationally efficient distillation using generative neural networks | |
CN117436457B (zh) | 反讽识别方法、装置、计算设备及存储介质 | |
CN117932073B (zh) | 一种基于提示工程的弱监督文本分类方法及系统 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |