CN109923560A - 使用变分信息瓶颈来训练神经网络 - Google Patents

使用变分信息瓶颈来训练神经网络 Download PDF

Info

Publication number
CN109923560A
CN109923560A CN201780066234.8A CN201780066234A CN109923560A CN 109923560 A CN109923560 A CN 109923560A CN 201780066234 A CN201780066234 A CN 201780066234A CN 109923560 A CN109923560 A CN 109923560A
Authority
CN
China
Prior art keywords
training
network
neural network
input
output
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN201780066234.8A
Other languages
English (en)
Inventor
亚历山大·埃米尔·阿勒米
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Google LLC
Original Assignee
Google LLC
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Google LLC filed Critical Google LLC
Publication of CN109923560A publication Critical patent/CN109923560A/zh
Pending legal-status Critical Current

Links

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/084Backpropagation, e.g. using gradient descent
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • G06N3/0455Auto-encoder networks; Encoder-decoder networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/047Probabilistic or stochastic networks

Landscapes

  • Engineering & Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Theoretical Computer Science (AREA)
  • General Health & Medical Sciences (AREA)
  • Computing Systems (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Computational Linguistics (AREA)
  • Data Mining & Analysis (AREA)
  • Evolutionary Computation (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Molecular Biology (AREA)
  • Artificial Intelligence (AREA)
  • General Engineering & Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Health & Medical Sciences (AREA)
  • Probability & Statistics with Applications (AREA)
  • Debugging And Monitoring (AREA)
  • Machine Translation (AREA)

Abstract

一种用于训练神经网络的方法、系统和装置,包括在计算机存储介质上编码的计算机程序。所述方法中的一种包括:接收训练数据;在所述训练数据上训练神经网络,其中,所述神经网络被配置成:接收网络输入,将所述网络输入转换成所述网络输入的潜在表示,并且处理所述潜在表示以从所述网络输入生成网络输出,并且其中,在所述训练数据上训练所述神经网络包括在变分信息瓶颈目标上训练所述神经网络,所述变分信息瓶颈目标对于每个训练输入鼓励针对所述训练输入生成的所述潜在表示与所述训练输入具有低的互信息,同时针对所述训练输入生成的所述网络输出与针对所述训练输入的所述目标输出具有高的互信息。

Description

使用变分信息瓶颈来训练神经网络
技术领域
本说明书涉及训练神经网络。
背景技术
神经网络是采用非线性单元的一个或多个层来针对接收到的输入预测输出的机器学习模型。一些神经网络除了包括输出层之外还包括一个或多个隐藏层。每个隐藏层的输出被用作网络中的下一个层(即,下一个隐藏层或输出层)的输入。网络的每个层依照相应组的参数的当前值从接收到的输入生成输出。
一些神经网络是递归神经网络。递归神经网络是接收输入序列并且从该输入序列生成输出序列的神经网络。特别地,递归神经网络可在当前时间步计算输出时使用网络从前一个时间步起的内部状态中的一些或全部。递归神经网络的示例是包括一个或多个LSTM记忆块的长短期(LSTM)神经网络。每个LSTM记忆块可包括一个或多个细胞(cell),所述一个或多个细胞各自包括输入门、遗忘门和输出门,这些门允许细胞存储该细胞的先前状态,例如,以用于在生成当前激活时使用或者被提供给LSTM神经网络的其它组件。
发明内容
本说明书一般地描述作为一个或多个位置中的一个或多个计算机实现的系统,所述系统在训练数据上训练神经网络,所述训练数据包括一组训练输入以及针对每个训练输入的相应的目标输出。
神经网络是被配置成接收网络输入、将网络输入转换成网络输入的潜在表示并且处理潜在表示以从网络输入生成网络输出的神经网络。
特别地,系统在变分信息瓶颈目标上训练神经网络,所述变分信息瓶颈目标对于每个训练输入鼓励针对训练输入生成的潜在表示与训练输入具有低的互信息,同时针对训练输入生成的网络输出与针对训练输入的目标输出具有高的互信息。
因此在一个方面中,方法包括:接收训练数据,所述训练数据包括多个训练输入以及针对每个训练输入的相应的目标输出;在训练数据上训练神经网络,其中,神经网络被配置成:接收网络输入,将网络输入转换成网络输入的潜在表示,并且处理潜在表示以从网络输入生成网络输出,并且其中,在训练数据上训练神经网络包括在变分信息瓶颈目标上训练神经网络,所述变分信息瓶颈目标对于每个训练输入鼓励针对训练输入生成的潜在表示与训练输入具有低的互信息,同时针对训练输入生成的网络输出与针对训练输入的目标输出具有高的互信息。
可选地,系统然后可提供指定经训练的神经网络的数据以用于在处理新网络输入时使用。
可实现本说明书中描述的主题的特定实施例以便实现以下优点中的一个或多个。通过在上述目标上训练神经网络,经训练的神经网络可在推广到新输入时超过在不同目标上训练的神经网络的性能,并且,如在下面更详细地讨论的,对对抗性攻击更加鲁棒。附加地,由经训练的神经网络生成的预测分布将比在不同目标上训练网络情况下被更好地校准。特别地,对神经网络的训练比在常规目标上训练相同网络被更好地规则化。因此,通过以本说明书中描述的方式训练神经网络,结果得到的训练后的神经网络将是高性能的,同时还抵抗对抗性攻击而不会在训练数据上过拟合。
在下面的附图和描述中阐述本说明书中描述的主题的一个或多个实施例的细节。主题的其它特征、方面和优点将根据说明书、附图和权利要求书变得显而易见。
附图说明
图1示出示例神经网络训练系统。
图2是用于训练神经网络的示例过程的流程图。
图3是用于确定对神经网络的参数的当前值的更新的示例过程的流程图。
在各个附图中相似的附图标记和名称指示相似的元件。
具体实施方式
图1示出示例性神经网络训练系统100。神经网络训练系统100是作为计算机程序实现在一个或多个位置中的一个或多个计算机上的系统的示例,其中可实现在下面描述的系统、组件和技术。
神经网络训练系统100是在训练数据140上训练神经网络110以从参数的初始值确定神经网络110的参数的训练值的系统。
神经网络110可被配置成接收任何种类的数字数据输入并且针对该输入生成网络输出。网络输出可以是定义针对输入的一组可能的输出上的分数分布的任何输出。
例如,如果到神经网络的输入是图像或从图像中已提取的特征,则由神经网络针对给定图像生成的输出可以是针对一组对象类别中的每一个的分数,其中每个分数表示图像包含属于该类别的对象的图像的估计可能性。
作为另一示例,如果到神经网络的输入是因特网资源(例如,web页面)、文档或文档的部分或从因特网资源、文档或文档的部分中提取的特征,则由神经网络针对给定因特网资源、文档或文档的部分生成的输出可以是针对一组主题中的每一个的分数,其中每个分数表示因特网资源、文档或文档部分关于该主题的估计可能性。
作为另一示例,如果到神经网络的输入是针对特定广告的印象上下文的特征,则由神经网络生成的输出可以是表示该特定广告将被点击的估计可能性的分数。
作为另一示例,如果到神经网络的输入是针对用户的个性化推荐的特征,例如,表征针对推荐的上下文的特征,例如,表征由用户采取的先前动作的特征,则由神经网络生成的输出可以是针对一组内容项目中的每一个的分数,其中每个分数表示用户将有利地对被推荐该内容项目做出响应的估计可能性。
作为另一示例,如果到神经网络的输入是一种语言的文本的序列,则由神经网络生成的输出可以是针对另一语言的一组文本块中的每一个的分数,其中每个分数表示另一语言的文本块是输入文本变成该另一语言的正确翻译的估计可能性。
作为另一示例,如果到神经网络的输入是表示口头话语的序列,则由神经网络生成的输出可以是针对一组文本块中的每一个的分数,每个分数表示文本块是针对话语的正确转录的估计可能性。
神经网络110被配置成接收网络输入102,将网络输入102映射到网络输入102的潜在表示122,然后从网络输入102的潜在表示122生成网络输出132。
特别地,神经网络包括一个或多个初始神经网络层120和一个或多个附加神经网络层130,所述一个或多个初始神经网络层120接收网络输入102并处理网络输入102以生成定义潜在表示122的输出,所述一个或多个附加神经网络层130处理潜在表示122以生成网络输出132。
一般地,潜在表示122是数值的有序合集,例如矢量、矩阵或多维矩阵,所述数值是如由神经网络110所确定的网络输入102的表示。
一般地,神经网络110是随机神经网络。随机神经网络是假定网络参数的固定值有时将针对相同的网络输入生成不同的网络输出的网络。在这些实施方式中,初始神经网络层120生成可能的潜在表示上的分布的参数作为中间输出,并且神经网络110根据通过由初始神经网络层120生成的中间输出参数化的分布对潜在表示122进行采样。例如,中间输出可以是可能的潜在表示上的多元分布的均值和协方差。
在训练期间,为了让采样相对于中间输出是确定性的并且因此允许通过网络有效地反向传播梯度,神经网络110还根据独立于神经网络110的参数的噪声分布对噪声进行采样。神经网络110然后使用经采样的噪声和中间输出来生成潜在表示122,即,通过确定地组合中间输出和经采样的噪声。例如,噪声分布可以是多元高斯分布。当中间输出是可能的潜在表示上的多元分布的均值和协方差时,神经网络110可通过针对每个维度确定噪声和协方差的乘积并且然后将该乘积加到均值以生成潜在表示来确定潜在表示122。
由系统100使用来训练神经网络110的训练数据140包括多个训练输入,并且对于每个训练输入,包括应该由神经网络110通过处理该训练输入来生成的目标输出。例如,在分类上下文中,目标输出可以是网络输入应该被分类为的正确类别或类的单热编码。
系统100通过优化变分瓶颈目标150来在训练数据140上训练神经网络110。特别地,变分瓶颈目标150是对于训练数据140中的每个训练输入鼓励如下的目标:(i)针对训练输入生成的潜在表示与训练输入具有低的互信息,同时(ii)针对训练输入生成的网络输出与针对训练输入的目标输出具有高的互信息。
在下面参考图2和图3更详细地描述在此目标上训练神经网络。
图2是用于使用变分信息瓶颈来训练神经网络的示例过程200的流程图。为了方便,过程200将被描述为由位于一个或多个位置中的一个或多个计算机的系统执行。例如,适当地编程的强化学习系统(例如,图1的神经网络训练系统100)可执行过程200。
系统获得用于训练神经网络的训练数据(步骤202)。训练数据包括多个训练输入,并且对于每个训练输入,包括应该由神经网络针对该训练输入生成的相应的目标输出。
系统在训练数据上训练神经网络以优化变分信息瓶颈目标(步骤204)。
一般地,变分信息瓶颈目标具有如下形式:
I(Z,Y)-βI(Z,X),
其中I(Z,Y)是针对网络输入的潜在表示与目标输出之间的互信息,I(Z,X)是潜在表示与网络输入之间的互信息并且β是固定正常数值。因此,通过在此目标上训练神经网络,系统鼓励网络当在仍然使潜在表示可预测目标输出的同时生成潜在表示时“遗忘”网络输入(到由β统治的程度)。
在这样做时,经训练的神经网络变得能够更好地推广到在训练期间未看到的示例,从而在网络正在被训练来执行的任务上产生改进的性能。特别地,经训练后的神经网络可针对新接收到的输入(即,与在训练神经网络时使用的输入中的任一个不同的输入)生成更准确的输出。
经训练的神经网络还变得更能抵抗对抗性攻击。对抗性攻击是可破坏采用神经网络的计算机系统(例如,向用户提供基于由一个或多个神经网络生成的网络输出而生成的数据的系统)的可靠性的计算机安全问题。通过将神经网络训练为更能抵抗如本说明书中所描述的对抗性攻击,计算机系统的计算机安全性被改进。更具体地,当恶意用户给神经网络提供受合法输入扰动最小的输入以便试图使网络生成不正确的输出时发生对抗性攻击,这会降低使用网络的输出的系统的可靠性。也就是说,一般地,良好训练的神经网络应该针对给定测试输入和受该测试输入扰动最小的另一输入生成相同的输出。然而,另外在给定任务上执行良好的许多神经网络将替代地针对扰动输入生成与针对测试输入更加不同的输出,网络可以以其它方式正确地处理这个。然而,使用所描述的目标来训练的神经网络将更能抵抗此类攻击,并且将很可能替代地针对测试输入和最小扰动输入生成相同的(正确的)输出。
在许多情况下系统直接地优化变分信息瓶颈目标是不可行的,即,因为针对大量数据通过变分信息瓶颈目标直接地计算和反向传播至少部分地由于将需要的互信息量度的直接计算而在计算上是不可行的。
因此,为了在目标上训练神经网络,对于每个训练输入系统执行机器学习训练过程的相应迭代,例如,具有反向传播的梯度下降,以相对于变分信息瓶颈目标的下界的网络参数确定梯度,并且然后确定对网络参数的当前值的对应更新。在下面参考图3更详细地描述确定这种更新。
一旦已训练了神经网络,在一些实施方式中系统就输出训练后的神经网络数据(步骤206)。也就是说,系统可输出(例如,通过向用户设备输出或者通过存储在系统可访问的存储器中)网络参数的训练值以供在使用经训练后的神经网络来处理输入时稍后使用。
可替选地或除了输出经训练的神经网络数据之外,系统可例如通过由系统提供的应用编程接口(API)接收要处理的输入,使用经训练的神经网络来处理所接收到的输入以生成网络输出,然后提供响应于所接收到的输入所生成的网络输出。
图3是用于确定对网络参数的当前值的更新的示例过程300的流程图。为了方便,过程300将被描述为由位于一个或多个位置中的一个或多个计算机的系统执行。例如,适当地编程的强化学习系统(例如,图1的神经网络训练系统100)可执行过程300。
系统可在神经网络的训练期间针对一批训练输入执行过程300以针对该批次中的每个输入确定对网络参数的当前值的相应更新。系统然后可应用(即,添加)针对该批次中的输入所确定的更新以生成网络参数的更新值。
系统接收训练输入和针对该训练输入的目标输出(步骤302)。
系统使用神经网络并且依照网络参数的当前值来处理训练输入以确定针对该训练输入的网络输出(步骤304)。如上所述,网络输出一般地定义针对训练输入的可能的输出上的分数分布。作为处理训练输入的一部分,神经网络将训练输入映射到中间输出,所述中间输出定义可能的潜在表示上的分布,使用中间输出来对潜在表示进行采样,然后从潜在表示生成网络输出。如上所述,为了让采样相对于中间输出是确定性的,为了对潜在表示进行采样,神经网络根据独立于神经网络的参数的噪声分布对噪声进行采样并且确定地组合中间输出和经采样的噪音。
一般地,系统根据预先确定的噪声分布(例如,高斯分布)对噪声进行采样。在一些实施方式中,系统针对每个训练输入来对噪声进行采样。在其它实施方式中,系统对批次中的每个训练输入使用相同的噪声,即,每批次仅对噪声采样一次。
系统确定相对于变分信息瓶颈目标的下界的网络参数的梯度(步骤306)。特别地,系统将下界表示为对于给定训练输入xn满足下式的最小化的目标函数:
其中,N是训练数据集中的训练输入的总数,q(yn|f(xn,∈))是通过针对训练输入xn的网络输出指派给针对训练输入xn的目标输出的分数,∈是从噪声分布采样的噪声,f(xn,∈)是使用噪声∈和针对训练输入xn的中间输出来采样的潜在表示,KL是Kullback-Leibler发散,p(Z|xn)是通过中间输出所定义的可能的潜在表示上的概率分布,并且r(Z)是潜在表示的边际分布的变分近似。
系统可使用任何适当的分布作为潜在表示的边际分布的变分近似。例如,当潜在表示是K维的时,边际分布的变分近似可以是固定K维球面高斯。
系统可使用常规技术(例如,通过经由神经网络反向传播梯度)来确定下界相对于网络参数的梯度。
系统依照用于训练神经网络的训练技术来根据梯度确定对网络参数的更新(步骤308)。例如,当技术是随机梯度下降时,系统可对梯度应用学习速率以确定更新。
系统可针对多批训练数据执行过程300以迭代地将网络参数的值从初始值更新为训练值。
本说明书连同系统和计算机程序组件一起使用术语“被配置”。对于要配置成执行特定操作或动作的一个或多个计算机的系统意味着该系统已在其上安装了软件、固件、硬件或其组合,所述软件、固件、硬件或其组合在操作中使该系统执行操作或动作。对于要配置成执行特定操作或动作的一个或多个计算机程序意味着一个或多个程序包括指令,所述指令当由数据处理装置执行时,使该装置执行操作或动作。
本说明书中描述的主题和功能操作的实施例可用数字电子电路、用有形地具体实现的计算机软件或固件、用计算机硬件(包括本说明书中公开的结构及其结构等同物)或者用它们中的一个或多个的组合加以实现。本说明书中描述的主题的实施例可作为一个或多个计算机程序(即,在有形非暂时性存储介质上编码以供由数据处理装置执行或者控制数据处理装置的操作的计算机程序指令的一个或多个模块)被实现。计算机存储介质可以是机器可读存储设备、机器可读存储基板、随机或串行存取存储器设备,或它们中的一个或多个的组合。可替选地或此外,可将程序指令编码在人工生成的传播信号上,所述传播信号例如为机器生成的电、光学或电磁信号,该信号被生成来对信息进行编码以便传输到适合的接收器装置以供由数据处理装置执行。
术语“数据处理装置”指代数据处理硬件并且包含用于处理数据的所有种类的装置、设备和机器,作为示例包括可编程处理器、计算机或多个处理器或计算机。装置还可以是或者进一步包括专用逻辑电路,例如FPGA(现场可编程门阵列)或ASIC(专用集成电路)。装置除了包括硬件之外还可可选地包括为计算机程序创建执行环境的代码,例如,构成处理器固件、协议栈、数据库管理系统、操作系统或它们中的一个或多个的组合的代码。
计算机程序(其还可以被称为或者描述为程序、软件、软件应用、app、模块、软件模块、脚本或代码)可用任何形式的编程语言编写,所述编程语言包括编译或解释语言,或声明性或过程语言;并且它可被以任何形式部署,包括作为独立程序或者作为模块、组件、子例行程序或适合于在计算环境中使用的其它单元。程序可以但不必对应于文件系统中的文件。可在保持其它程序或数据(例如,存储在标记语言文档中的一个或多个脚本)的文件的一部分中、在专用于所述程序的单个文件中或者在多个协调文件(例如,存储一个或多个模块、子程序或代码的部分的文件)中存储程序。可将计算机程序部署成在一个计算机上或者在位于一个站点处或者跨越多个站点分布并通过数据通信网络互连的多个计算机上执行。
在本说明书中,术语“数据库”广泛地用于指代数据的任何合集:数据不需要被以任何特定方式构造,或者根本不构造,并且它可被存储在一个或多个位置中的存储设备上。因此,例如,索引数据库可包括数据的多个合集,其中的每一个均可以被不同地组织和访问。
类似地,在本说明书中术语“引擎”广泛地用于指代被编程来执行一个或多个具体功能的基于软件的系统、子系统或过程。一般地,引擎将作为安装在一个或多个位置中的一个或多个计算机上的一个或多个软件模块或组件被实现。在一些情况下,一个或多个计算机将专用于特定引擎;在其它情况下,可在相同的一个或多个计算机上安装和运行多个引擎。
本说明书中描述的过程和逻辑流程可通过一个或多个可编程计算机执行一个或多个计算机程序以通过对输入数据进行操作并生成输出来执行功能而被执行。过程和逻辑流程也可由专用逻辑电路(例如,FPGA或ASIC)或者由专用逻辑电路和一个或多个编程计算机的组合来执行。
适合于执行计算机程序的计算机可基于通用微处理器或专用微处理器或两者,或任何其它种类的中央处理单元。一般地,中央处理单元将从只读存储器或随机存取存储器或两者接收指令和数据。计算机的必要元件是用于执行或者实行指令的中央处理单元以及用于存储指令和数据的一个或多个存储器设备。中央处理单元和存储器可由专用逻辑电路补充,或者并入在专用逻辑电路中。一般地,计算机还将包括或者在操作上耦合以从用于存储数据的一个或多个大容量存储设备(例如,磁盘、磁光盘或光盘)接收数据或者将数据转移到用于存储数据的一个或多个大容量存储设备,或者兼而有之。然而,计算机不必具有此类设备。此外,计算机可被嵌入在另一设备中,所述另一设备例如为移动电话、个人数字助理(PDA)、移动音频或视频播放器、游戏控制台、全球定位系统(GPS)接收器或便携式存储设备,例如,通用串行总线(USB)闪存驱动器等等。
适合于存储计算机程序指令和数据的计算机可读介质包括所有形式的非易失性存储器、介质和存储器设备,作为示例包括半导体存储器设备,例如,EPROM、EEPROM和闪速存储器设备;磁盘,例如内部硬盘或可移动磁盘;磁光盘;以及CD ROM和DVD-ROM盘。
为了提供与用户的交互,可在计算机上实现本说明书中描述的主题的实施例,所述计算机具有用于向用户显示信息的显示设备(例如,CRT(阴极射线管)或LCD(液晶显示器)监视器)以及用户可用来向该计算机提供输入的键盘和指点设备,例如鼠标或轨迹球。其它种类的设备也可用于提供与用户的交互;例如,提供给用户的反馈可以是任何形式的感觉反馈,例如视觉反馈、听觉反馈或触觉反馈;并且可以任何形式接收来自用户的输入,包括声学、语音或触觉输入。此外,计算机可通过向由用户使用的设备发送文档并且从由用户使用的设备接收文档来与用户交互;例如,通过响应于从web浏览器接收到的请求而向用户的设备上的web浏览器发送web页面。另外,计算机可通过向个人设备(例如,正在运行消息传送应用的智能电话)发送文本消息或其它形式的消息并且作为回报从用户接收响应消息来与用户交互。
用于实现机器学习模型的数据处理装置还可包括例如用于处理机器学习训练或生产的公共且计算密集部分(即,推理、工作负载)的专用硬件加速器单元。
可使用机器学习框架(例如,TensorFlow框架、Microsoft Cognitive Toolkit框架、Apache Singa框架或Apache MXNet框架)来实现和部署机器学习模型。
可在计算系统中实现本说明书中描述的主题的实施例,所述计算系统包括后端组件(例如,作为数据服务器),或者包括中间件组件(例如,应用服务器),或者包括前端组件(例如,具有用户可用来与本说明书中描述的主题的实施方式交互的图形用户界面、web浏览器或app的客户端计算机),或者包括一个或多个此类后端、中间件或前端组件的任何组合。系统的组件可通过任何形式或介质的数字数据通信(例如,通信网络)来互连。通信网络的示例包括局域网(LAN)和广域网(WAN),例如因特网。
计算系统可包括客户端和服务器。客户端和服务器一般地彼此远离并且通常通过通信网络来交互。客户端和服务器的关系借助于在相应的计算机上运行并且彼此具有客户端-服务器关系的计算机程序而产生。在一些实施例中,服务器向用户设备发送数据(例如,HTML页面),例如,用于向与作为客户端的设备交互的用户显示数据并且从与作为客户端的设备交互的用户接收用户输入的目的。可在服务器处从设备接收在用户设备处生成的数据,例如,用户交互的结果。
虽然本说明书包含许多具体实施方式细节,但是这些不应该被解释为对任何发明的范围或者对可能要求保护的东西的范围构成限制,而是相反被解释为可以特定于特定发明的特定实施例的特征的描述。还可在单个实施例中相结合地实现在本说明书中在单独的实施例的上下文中描述的某些特征。相反地,还可单独地或者按照任何适合的子组合在多个实施例中实现在单个实施例的上下文中描述的各种特征。此外,尽管特征可以在上面被描述为按照某些组合起作用并且甚至最初如此要求保护,然而来自要求保护的组合的一个或多个特征可在一些情况下被从该组合中除去,并且所要求保护的组合可以针对子组合或子组合的变化。
类似地,虽然按照特定次序在附图中描绘并在权利要求书中叙述操作,但是这不应该被理解为要求按照所示的特定次序或者按照顺序次序执行此类操作,或者执行所有图示的操作以实现所希望的效果。在某些情况下,多任务处理和并行处理可以是有利的。此外,上述的实施例中的各种系统模块和组件的分离不应该被理解为在所有实施例中要求这种分离,并且应该理解的是,所描述的程序组件和系统一般地可被一起集成在单个软件产品中或者包装到多个软件产品中。
已经描述了主题的特定实施例。其它实施例在以下权利要求的范围内。例如,权利要求中叙述的动作可以被以不同的次序执行并仍然实现所希望的结果。作为一个示例,附图中描绘的过程不一定要求所示的特定次序或顺序次序以实现所希望的结果。在一些情况下,多任务处理和并行处理可以是有利的。

Claims (11)

1.一种方法,包括:
接收训练数据,所述训练数据包括多个训练输入以及针对每个训练输入的相应的目标输出;
在所述训练数据上训练神经网络,其中,所述神经网络被配置成:
接收网络输入,
将所述网络输入转换成所述网络输入的潜在表示,并且
处理所述潜在表示以从所述网络输入生成网络输出,
其中,在所述训练数据上训练所述神经网络包括在变分信息瓶颈目标上训练所述神经网络,所述变分信息瓶颈目标对于每个训练输入鼓励针对该训练输入生成的所述潜在表示与该训练输入具有低的互信息,同时针对该训练输入生成的所述网络输出与针对该训练输入的所述目标输出具有高的互信息;以及
提供指定经训练的神经网络的数据以供在处理新网络输入时使用。
2.根据权利要求1所述的方法,其中,所述神经网络是随机神经网络,并且其中,所述神经网络被配置成:
处理所述网络输入以生成中间输出,所述中间输出定义在可能的潜在表示上的分布;以及
根据通过所述中间输出所定义的所述分布对所述网络输入的潜在表示进行采样。
3.根据权利要求2所述的方法,其中,对所述潜在表示进行采样包括:
根据独立于所述中间输出的预先确定的噪声分布对噪声进行采样;以及
从经采样的噪声和所述中间输出生成所述潜在表示。
4.根据权利要求1至3中的任一项所述的方法,其中,训练所述神经网络包括:
对于每个训练输入,在所述变分信息瓶颈目标的下界上执行随机梯度下降的迭代以确定对所述神经网络的参数的当前值的更新。
5.根据权利要求4所述的方法,其中,所述下界取决于在针对所述训练输入的所述潜在表示的情况下针对所述训练输入的所述网络输出的可能性的变分近似。
6.根据权利要求4或5中的任一项所述的方法,其中,所述下界取决于针对所述训练输入的所述潜在表示的边际分布的变分近似。
7.根据权利要求4至6中的任一项所述的方法,其中,所述下界被表示为对于给定训练输入xn满足下式的最小化的目标函数:
其中,N是训练数据集中的训练输入的总数,q(yn|f(xn,∈))是通过针对所述训练输入xn的所述网络输出指派给针对所述训练输入xn的所述目标输出的分数,∈是从噪声分布采样的噪声,f(xn,∈)是从经采样的噪声和针对所述训练输入xn的所述中间输出生成的所述潜在表示,KL是Kullback-Leibler发散,p(Z|xn)是在通过所述中间输出所定义的可能的潜在表示上的概率分布,并且r(Z)是所述潜在表示的边际分布的变分近似。
8.根据权利要求1至7中的任一项所述的方法,其中,经训练的神经网络抵抗对抗性扰动。
9.根据权利要求8所述的方法,其中,经训练的神经网络针对测试输入和所述测试输入的最小扰动生成相同的网络输出。
10.一种系统,包括一个或多个计算机和存储指令的一个或多个存储设备,所述指令在由所述一个或多个计算机执行时使所述一个或多个计算机执行根据权利要求1至9中的任一项所述的相应方法的操作。
11.存储指令的一个或多个计算机存储介质,所述指令在由一个或多个计算机执行时使所述一个或多个计算机执行根据权利要求1至9中的任一项所述的相应方法的操作。
CN201780066234.8A 2016-11-04 2017-11-03 使用变分信息瓶颈来训练神经网络 Pending CN109923560A (zh)

Applications Claiming Priority (3)

Application Number Priority Date Filing Date Title
US201662418100P 2016-11-04 2016-11-04
US62/418,100 2016-11-04
PCT/US2017/060003 WO2018085697A1 (en) 2016-11-04 2017-11-03 Training neural networks using a variational information bottleneck

Publications (1)

Publication Number Publication Date
CN109923560A true CN109923560A (zh) 2019-06-21

Family

ID=60382627

Family Applications (1)

Application Number Title Priority Date Filing Date
CN201780066234.8A Pending CN109923560A (zh) 2016-11-04 2017-11-03 使用变分信息瓶颈来训练神经网络

Country Status (4)

Country Link
US (2) US10872296B2 (zh)
EP (1) EP3520037B1 (zh)
CN (1) CN109923560A (zh)
WO (1) WO2018085697A1 (zh)

Cited By (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN110348573A (zh) * 2019-07-16 2019-10-18 腾讯科技(深圳)有限公司 训练图神经网络的方法、图神经网络设备、装置、介质
CN112667071A (zh) * 2020-12-18 2021-04-16 宜通世纪物联网研究院(广州)有限公司 基于随机变分信息的手势识别方法、装置、设备及介质
CN112717415A (zh) * 2021-01-22 2021-04-30 上海交通大学 一种基于信息瓶颈理论的强化学习对战游戏ai训练方法
CN113434683A (zh) * 2021-06-30 2021-09-24 平安科技(深圳)有限公司 文本分类方法、装置、介质及电子设备
CN113516153A (zh) * 2020-04-10 2021-10-19 三星电子株式会社 学习多个随机变量之间的随机推断模型的方法和装置

Families Citing this family (13)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
JP7002404B2 (ja) * 2018-05-15 2022-01-20 株式会社日立製作所 データから潜在因子を発見するニューラルネットワーク
US10635979B2 (en) * 2018-07-20 2020-04-28 Google Llc Category learning neural networks
EP3611854B1 (en) * 2018-08-13 2021-09-22 Nokia Technologies Oy Method and apparatus for defending against adversarial attacks
JP6830464B2 (ja) * 2018-09-26 2021-02-17 株式会社Kokusai Electric 基板処理装置、半導体装置の製造方法および記録媒体。
US11625487B2 (en) 2019-01-24 2023-04-11 International Business Machines Corporation Framework for certifying a lower bound on a robustness level of convolutional neural networks
US11227215B2 (en) 2019-03-08 2022-01-18 International Business Machines Corporation Quantifying vulnerabilities of deep learning computing systems to adversarial perturbations
MX2022000163A (es) * 2019-06-24 2022-05-20 Insurance Services Office Inc Sistemas y metodos de aprendizaje de maquina para localizacion mejorada de falsificacion de imagenes.
CN111724767B (zh) * 2019-12-09 2023-06-02 江汉大学 基于狄利克雷变分自编码器的口语理解方法及相关设备
CN113222103A (zh) * 2020-02-05 2021-08-06 北京三星通信技术研究有限公司 神经网络的更新方法、分类方法和电子设备
US20210241112A1 (en) * 2020-02-05 2021-08-05 Samsung Electronics Co., Ltd. Neural network update method, classification method and electronic device
US11868428B2 (en) 2020-07-21 2024-01-09 Samsung Electronics Co., Ltd. Apparatus and method with compressed neural network computation
CN113191392B (zh) * 2021-04-07 2023-01-24 山东师范大学 一种乳腺癌图像信息瓶颈多任务分类和分割方法及系统
CN113488060B (zh) * 2021-06-25 2022-07-19 武汉理工大学 一种基于变分信息瓶颈的声纹识别方法及系统

Citations (9)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN101639937A (zh) * 2009-09-03 2010-02-03 复旦大学 一种基于人工神经网络的超分辨率方法
CN103531205A (zh) * 2013-10-09 2014-01-22 常州工学院 基于深层神经网络特征映射的非对称语音转换方法
US20140229158A1 (en) * 2013-02-10 2014-08-14 Microsoft Corporation Feature-Augmented Neural Networks and Applications of Same
CN104933245A (zh) * 2015-06-15 2015-09-23 南华大学 基于神经网络和遗传算法的船用反应堆屏蔽设计优化方法
CN105144203A (zh) * 2013-03-15 2015-12-09 谷歌公司 信号处理系统
CN105378764A (zh) * 2013-07-12 2016-03-02 微软技术许可有限责任公司 计算机-人交互式学习中的交互式概念编辑
CN105637540A (zh) * 2013-10-08 2016-06-01 谷歌公司 用于强化学习的方法和设备
EP3054403A2 (en) * 2015-02-06 2016-08-10 Google, Inc. Recurrent neural networks for data item generation
CN105940395A (zh) * 2014-01-31 2016-09-14 谷歌公司 生成文档的矢量表示

Patent Citations (9)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN101639937A (zh) * 2009-09-03 2010-02-03 复旦大学 一种基于人工神经网络的超分辨率方法
US20140229158A1 (en) * 2013-02-10 2014-08-14 Microsoft Corporation Feature-Augmented Neural Networks and Applications of Same
CN105144203A (zh) * 2013-03-15 2015-12-09 谷歌公司 信号处理系统
CN105378764A (zh) * 2013-07-12 2016-03-02 微软技术许可有限责任公司 计算机-人交互式学习中的交互式概念编辑
CN105637540A (zh) * 2013-10-08 2016-06-01 谷歌公司 用于强化学习的方法和设备
CN103531205A (zh) * 2013-10-09 2014-01-22 常州工学院 基于深层神经网络特征映射的非对称语音转换方法
CN105940395A (zh) * 2014-01-31 2016-09-14 谷歌公司 生成文档的矢量表示
EP3054403A2 (en) * 2015-02-06 2016-08-10 Google, Inc. Recurrent neural networks for data item generation
CN104933245A (zh) * 2015-06-15 2015-09-23 南华大学 基于神经网络和遗传算法的船用反应堆屏蔽设计优化方法

Non-Patent Citations (3)

* Cited by examiner, † Cited by third party
Title
DIEDERIK P. KINGMA: "Semi-supervised Learning with Deep Generative Models", 《ADVANCES IN NEURAL INFORMATION PROCESSING SYSTEM》 *
NAFTALI TISHBY: "Deep Learning and the Information Bottleneck Principle", 《IEEE》 *
李海滨: "结构有限元分析神经网络计算研究", 《中国优秀博硕士学位论文全文数据库 (博士) 信息科技辑》 *

Cited By (8)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN110348573A (zh) * 2019-07-16 2019-10-18 腾讯科技(深圳)有限公司 训练图神经网络的方法、图神经网络设备、装置、介质
CN113516153A (zh) * 2020-04-10 2021-10-19 三星电子株式会社 学习多个随机变量之间的随机推断模型的方法和装置
CN113516153B (zh) * 2020-04-10 2024-05-24 三星电子株式会社 学习多个随机变量之间的随机推断模型的方法和装置
CN112667071A (zh) * 2020-12-18 2021-04-16 宜通世纪物联网研究院(广州)有限公司 基于随机变分信息的手势识别方法、装置、设备及介质
CN112717415A (zh) * 2021-01-22 2021-04-30 上海交通大学 一种基于信息瓶颈理论的强化学习对战游戏ai训练方法
CN112717415B (zh) * 2021-01-22 2022-08-16 上海交通大学 一种基于信息瓶颈理论的强化学习对战游戏ai训练方法
CN113434683A (zh) * 2021-06-30 2021-09-24 平安科技(深圳)有限公司 文本分类方法、装置、介质及电子设备
CN113434683B (zh) * 2021-06-30 2023-08-29 平安科技(深圳)有限公司 文本分类方法、装置、介质及电子设备

Also Published As

Publication number Publication date
EP3520037A1 (en) 2019-08-07
EP3520037B1 (en) 2024-01-03
WO2018085697A1 (en) 2018-05-11
US20210103823A1 (en) 2021-04-08
US11681924B2 (en) 2023-06-20
US20190258937A1 (en) 2019-08-22
US10872296B2 (en) 2020-12-22

Similar Documents

Publication Publication Date Title
CN109923560A (zh) 使用变分信息瓶颈来训练神经网络
US20210150355A1 (en) Training machine learning models using task selection policies to increase learning progress
JP7157154B2 (ja) 性能予測ニューラルネットワークを使用したニューラルアーキテクチャ探索
EP3329411B1 (en) Classifying user behavior as anomalous
Halvaiee et al. A novel model for credit card fraud detection using Artificial Immune Systems
US9576248B2 (en) Record linkage sharing using labeled comparison vectors and a machine learning domain classification trainer
US11636314B2 (en) Training neural networks using a clustering loss
US11443170B2 (en) Semi-supervised training of neural networks
US11694109B2 (en) Data processing apparatus for accessing shared memory in processing structured data for modifying a parameter vector data structure
CN109564575A (zh) 使用机器学习模型来对图像进行分类
US20240127058A1 (en) Training neural networks using priority queues
CN109196527A (zh) 广度和深度机器学习模型
US11756059B2 (en) Discovery of new business openings using web content analysis
Andrés et al. Linkages between Formal Institutions, ICT Adoption, and inclusive human development in sub-Saharan Africa
US11249751B2 (en) Methods and systems for automatically updating software functionality based on natural language input
CN108475346A (zh) 神经随机访问机器
Tang et al. A framework for constrained optimization problems based on a modified particle swarm optimization
US9176948B2 (en) Client/server-based statistical phrase distribution display and associated text entry technique
US20230169364A1 (en) Systems and methods for classifying a webpage or a webpage element
CN113191527A (zh) 一种基于预测模型进行人口预测的预测方法及装置
Zhao et al. Structural reliability assessment based on low-discrepancy adaptive importance sampling and artificial neural network
US20240028973A1 (en) Cross-hierarchical machine learning prediction
US20210064961A1 (en) Antisymmetric neural networks
Korolev et al. Applying Time Series for Background User Identification Based on Their Text Data Analysis
Tang Design and development of a machine learning-based framework for phishing website detection

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination