CN113396429A - 递归机器学习架构的正则化 - Google Patents
递归机器学习架构的正则化 Download PDFInfo
- Publication number
- CN113396429A CN113396429A CN201980091101.5A CN201980091101A CN113396429A CN 113396429 A CN113396429 A CN 113396429A CN 201980091101 A CN201980091101 A CN 201980091101A CN 113396429 A CN113396429 A CN 113396429A
- Authority
- CN
- China
- Prior art keywords
- distribution
- potential
- observation
- current
- network
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/084—Backpropagation, e.g. using gradient descent
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N20/00—Machine learning
- G06N20/20—Ensemble learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/044—Recurrent networks, e.g. Hopfield networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/047—Probabilistic or stochastic networks
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Software Systems (AREA)
- Computing Systems (AREA)
- Artificial Intelligence (AREA)
- Mathematical Physics (AREA)
- General Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- Evolutionary Computation (AREA)
- General Engineering & Computer Science (AREA)
- Biomedical Technology (AREA)
- Molecular Biology (AREA)
- General Health & Medical Sciences (AREA)
- Computational Linguistics (AREA)
- Biophysics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Health & Medical Sciences (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Medical Informatics (AREA)
- Probability & Statistics with Applications (AREA)
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
Abstract
建模系统通过确定潜在状态的先验分布和潜在分布来训练递归机器学习模型。模型的参数基于散度损失而被训练,该散度损失惩罚潜在分布与先验分布之间的显著偏差。鉴于当前观察的值和先前观察的潜在状态,当前观察的潜在分布是潜在状态的分布。鉴于与当前观察的值无关的先前观察的潜在状态,当前观察的先验分布是潜在状态的分布,并且表示在考虑输入证据之前关于潜在状态的置信度。
Description
相关申请的交叉引用
本申请要求于2018年12月11日提交的美国临时申请第62/778,277号的权益和优先权,该申请的全部内容通过引用并入本文。
背景技术
本发明一般涉及递归机器学习模型,更具体地涉及递归机器学习模型的正则化。
建模系统通常使用诸如递归神经网络(RNN)或长短期记忆模型(LSTM)之类的递归机器学习模型,以生成顺序预测。递归机器学习模型被配置为基于针对当前预测的潜在状态来生成后续预测,有时结合实际输入的初始序列。当前潜在状态表示关于直到当前预测之前生成的预测的上下文信息,并且基于针对先前预测的潜在状态和当前预测的值而被生成。例如,顺序预测可以是词序列,并且递归机器学习模型可以基于表示关于实际词标记的初始序列的上下文信息的当前潜在状态和直到当前词标记之前生成的预测词标记来生成针对后续词标记的预测。
在结构上,递归机器学习模型包括与已训练参数集合相关联的一个或多个节点层。通过将递归机器学习模型迭代应用于已知观察序列并更新参数以减少观察序列中的损失函数来训练递归机器学习模型的参数。然而,随着模型的复杂性和大小的增加,参数常常难以训练,这可能会导致将模型过度拟合到数据集或丢失可能对生成预测有用的上下文信息。尽管正则化方法已被应用来降低模型复杂性,但训练递归机器学习模型以保留重要的上下文信息并控制对连续输入数据的敏感性仍然是一个具有挑战性的问题。
发明内容
建模系统通过确定潜在状态的先验分布和潜在分布来训练递归机器学习模型。模型的参数基于散度损失而被训练,该散度损失惩罚潜在分布和先验分布之间的显著偏差。鉴于当前观察的值和一个或多个先前观察的潜在状态,当前观察的潜在分布是潜在状态的分布。鉴于与当前观察的值无关的一个或多个先前观察的潜在状态,当前观察的先验分布是潜在状态的分布,并且表示在考虑任何输入证据之前关于潜在状态的置信度。
通过以这种方式训练递归模型,建模系统惩罚针对连续输入的潜在状态之间的显著变化。这可以防止模型的过度拟合并且防止可能对生成预测有用的重要长期上下文信息的丢失。建模系统可以鼓励更简单的潜在状态分布,其在连续潜在状态之间具有更平滑的转移,并且保留了附加的上下文信息。此外,由于更简单的潜在状态分布,利用散度损失训练递归机器学习模型还可以减少训练时间和复杂性,因为后续的潜在状态倾向于遵循先验分布,并且可以控制它在连续输入之间变化的程度。
在一个实施例中,递归机器学习模型的架构被公式化为包括编码器网络和译码器网络的自动编码器。编码器网络可以被布置为与参数集合相关联的一个或多个节点层。编码器网络接收当前预测和一个或多个先前潜在状态作为输入,并通过将参数集合应用于输入来生成当前潜在状态的潜在分布。译码器网络还可以被布置为与参数集合相关联的一个或多个节点层。译码器网络接收从潜在分布所生成的一个或多个值,并通过将参数集合应用于这些值来生成后续预测。
建模系统使用已知的观察序列作为训练数据来训练递归机器学习模型的参数。每个序列可以表示有序观察集合,这些观察相对于空间或时间是顺序依赖的。在训练过程期间,建模系统迭代地将递归机器学习模型应用于观察序列,并训练模型的参数以减少损失函数。损失函数可以被确定为针对序列中的每个观察的损失的组合。特别地,当前观察的损失包括随着后续观察的预测可能性的降低而增加的预测损失,以及测量当前观察的潜在状态的先验分布和潜在分布之间的差异的散度损失。
附图说明
图1是根据实施例的用于建模系统的系统环境的高级框图。
图2图示出了根据实施例的用于递归机器学习模型的示例推理过程。
图3是根据实施例的建模系统的架构的框图。
图4图示出了根据实施例的用于递归机器学习模型的示例训练过程。
图5图示出了根据实施例的包括嵌入层的递归机器学习模型的示例架构。
图6图示出了根据实施例的用于训练递归机器学习模型的方法。
图7A到图7C图示出了根据实施例的示例递归机器学习模型与其他现有技术模型相比的性能结果。
附图仅出于说明的目的描绘了本发明的各种实施例。本领域的技术人员将从以下讨论中容易地认识到,在不脱离本文所描述的本发明的原理的情况下,可以采用本文所图示的结构和方法的替代实施例。
具体实施方式
图1是根据实施例的用于文档分析系统110的系统环境的高级框图。由图1所示的系统环境100包括一个或多个客户端设备116、网络120和建模系统110。在替代配置中,在系统环境100中可以包括不同的和/或附加的组件。
建模系统110是用于训练各种机器学习模型的系统。建模系统110可以向客户端设备116的用户提供经训练的模型,或者可以使用经训练的模型来执行针对各种任务的推理。在一个实施例中,建模系统110训练可以被用来生成顺序预测的递归机器学习模型。顺序预测是有序预测集合,其中序列中的预测可以依赖于关于空间或时间的先前或后续预测的值。例如,顺序预测可以是依赖于包括在先前句子或段落中的词标记的词标记序列。作为另一个示例,顺序预测可以是依赖于前几天的历史股票价格的股票价格的时间串。
递归机器学习模型接收当前预测并生成后续预测。特别地,从针对当前预测的潜在状态生成后续预测,有时与实际输入的初始序列相结合。当前潜在状态表示直到当前预测之前生成的预测的上下文信息。例如,当顺序预测是词序列时,递归机器学习模型可以基于表示关于实际词标记的初始序列的上下文信息的当前潜在状态和直到当前词标记之前生成的预测词标记来生成针对后续词标记的预测。可以基于针对一个或多个先前预测的一个或多个潜在状态和当前预测的值来生成当前潜在状态。
在一个实施例中,递归机器学习模型的架构被公式化为包括编码器网络和译码器网络的自动编码器。编码器网络可以被布置为与经训练参数集合相关联的一个或多个节点层。用于编码器网络的参数可以包括输入参数集合和递归参数集合。输入参数集合沿节点层传播,而递归参数集合沿时间或空间的序列传播。编码器网络接收前一步的编码器网络层和当前预测,并生成针对当前潜在状态的潜在分布。鉴于针对一个或多个先前预测的潜在状态和当前预测,潜在分布是针对潜在状态的分布。译码器网络还可以被布置为与经训练参数集合相关联的一个或多个节点层。译码器网络接收从潜在分布生成的一个或多个值,并通过将参数集合应用于这些值来生成后续预测。
图2图示出了根据实施例的用于递归机器学习模型的示例推理过程。如图2中所示,递归机器学习模型包括与经训练输入参数集合和经训练递归参数集合γ相关联的编码器网络递归机器学习模型还包括与经训练参数集合θ相关联的译码器网络pθ(·)。图2中所示的示例是词标记的预测,并且标记“该(the)”的先前预测和标记“狐狸(fox)”的当前预测被生成。
在当前迭代t处的推理期间,输入参数集合被应用于沿编码器网络层的当前预测并且递归参数集合γ被应用于前一步t-1处的编码器网络层以生成针对当前潜在状态zt的潜在分布因此,潜在状态zt可以包含关于直到当前预测之前生成的预测的上下文信息。在一种实例中,从由编码器网络所输出的一个或多个统计参数中确定潜在分布从潜在分布中生成一个或多个值vt,并且将译码器网络pθ(·)应用于值vt以生成后续预测在图2中所示的示例中,后续预测是词标记“跑(runs)”,它考虑了“该”和“狐狸”的先前预测,以及在当前预测之前或之后出现的许多其他词。在一种实例中,值vt是潜在分布的均值,或者基于来自潜在分布的一个或多个样本而被确定。然而,该值可以是其他统计参数,诸如分布的中位数。重复此过程,直到对序列进行了所有预测。
返回图1,可以使用已知观察序列作为训练数据来训练递归模型的参数。每个序列可以表示有序观察集合,这些有序观察相对于空间或时间是顺序依赖的,递归机器学习模型可以使用其来学习顺序相关性。在一个实例中,建模系统110可以依赖于递归机器学习模型所针对训练的任务来访问不同类型的训练数据。例如,当顺序预测是词标记时,建模系统110可以访问诸如包含词序列的文档和段落之类的训练数据。作为另一个示例,当顺序预测是未来股票价格时,建模系统110可以访问诸如历史股票价格之类的训练数据。
建模系统110可以通过在前向传递步骤和反向传播步骤之间进行迭代以减少损失函数来训练递归模型的参数。在前向传递步骤期间,建模系统110通过将编码器网络的估计参数应用于前一步的编码器网络层和当前观察来生成当前观察的估计潜在分布。建模系统110通过将译码器网络的估计参数应用于从潜在分布生成的值来生成后续观察的预测可能性。针对后续观察重复此过程。在反向传播步骤期间,建模系统110将损失函数确定为针对序列中的每个观察的损失的组合。当前观察的损失可以包括随着后续观察的预测可能性的降低而增加的预测损失。建模系统110通过从损失函数反向传播一个或多个误差项来更新递归机器学习模型的参数。
然而,尤其是随着递归机器学习模型的复杂性和大小增加,递归机器学习模型的参数常常难以训练。特别地,递归机器学习模型容易过度拟合,并可能会导致丢失可能对生成未来预测有用的长期上下文信息。正则化方法可以被用来限制参数的幅值,从而降低模型复杂性。然而,由于难以应用有效的正则化方法,训练递归机器学习模型仍然是一个具有挑战性的问题。
在一个实施例中,建模系统110通过确定潜在状态的潜在分布和先验分布来训练递归机器学习模型。除了预测损失之外,还基于发散损失来训练模型的参数,该发散损失惩罚潜在分布和先验分布之间的显著偏差。鉴于与当前观察值无关的一个或多个先前观察的潜在状态,当前观察的先验分布是潜在状态的分布。与潜在分布不同,先验分布表示在考虑输入观察之前关于潜在状态的置信度。
在训练过程期间,建模系统110将递归机器学习模型迭代地应用于观察序列,并训练模型的参数以减少损失函数。损失函数可以被确定为针对序列中的每个观察的损失的组合。在一个实施例中,当前观察的损失包括随着后续观察的预测可能性的降低而增加的预测损失,以及测量当前观察的潜在状态的先验分布与潜在分布之间的差异的散度损失。训练过程的更详细描述将在下面结合图4和图5进行描述。
通过以这种方式训练递归模型,建模系统110惩罚针对连续输入的潜在状态之间的显著变化。这可以防止模型的过度拟合并且防止可能对生成预测有用的重要长期上下文信息的丢失。建模系统110可以鼓励更简单的潜在状态分布,其在连续潜在状态之间具有更平滑的转移,并且保留了附加的上下文信息。此外,由于更简单的潜在状态分布,利用散度损失训练递归机器学习模型还可以减少训练时间和复杂性,因为后续的潜在状态倾向于遵循先验分布并促进从先验分布中采样,并且可以控制它在连续输入之间变化的程度。
客户端设备116的用户是向建模系统130提供请求以基于各种感兴趣的任务来训练一个或多个递归机器学习模型的各种实体。用户还可以向建模系统130提供针对感兴趣任务而定制的模型的训练数据。客户端设备116接收经训练的模型,并使用模型来执行顺序预测。例如,客户端设备116可以与对生成用于语言合成的顺序词标记预测感兴趣的自然语言处理实体相关联。作为另一个示例,客户端设备116可以与对生成用于未来投资价格的连续预测感兴趣的金融实体相关联。作为又一个示例,客户端设备116可以与对鉴于给定患者的先前就诊历史而生成估计患者的未来医院就诊的顺序预测感兴趣的医院相关联。
建模系统
图3是根据实施例的建模系统110的架构的框图。由图3所示的建模系统110包括数据管理模块320、训练模块330和预测模块335。建模系统110还包括训练语料库360。在替代配置中,不同的和/或附加的组件可以被包括在建模系统110中。
数据管理模块320管理被用来训练递归机器学习模型的参数的训练数据的训练语料库360。训练数据包括相对于空间或时间是顺序依赖的已知观察序列。数据管理模块320尤其还可以将训练数据编码成数字形式以供递归机器学习模型处理。例如,对于词标记序列x1,x2,...,xT,数据管理模块320可以将每个词标记编码为表示从例如训练语料库360中的文档获得的词的词汇表的独热(one-hot)编码向量,其中与词相对应的唯一元素具有非零值。例如,当训练语料库360的词的词汇表是集合{“向前(forward)”、“向后(backward)”、“左(left)”、“右(right)”}时,词“左”可以被编码为向量x=[0 0 0 1],其中与该词相对应的第四个元素具有唯一的非零值。
训练模块330通过迭代地减少损失函数来训练递归机器学习模型的参数。针对训练序列中的每个观察的损失包括预测损失和散度损失,后者惩罚针对观察的先验分布和潜在分布之间的显著偏差。在一个实施例中,在训练过程期间,递归机器学习模型另外还包括一个转移网络,用于生成潜在状态的先验分布。转移网络可以被布置为与参数集合相关联的一个或多个节点层。转移网络接收从一个或多个先前观察的潜在分布生成的一个或多个值,并通过将参数集合应用于从一个或多个先前观察的潜在分布生成的一个或多个值来生成当前观察的先验分布。
图4图示出了根据实施例的用于递归机器学习模型的示例训练过程。训练模块330在前向传递步骤和反向传播步骤之间进行迭代以训练递归机器学习模型的参数。如图4中所示,除了编码器网络和译码器网络pθ(·)之外,递归机器学习模型另外还包括与训练参数集合ψ相关联的转移网络gψ(·)。除了序列中的其他词标记之外,训练序列还包括词标记xt-1“小(little)”和词标记xt“星星(star)”。
在前向传递步骤期间,训练模块330针对序列中的每个观察生成估计的潜在分布和对应的先验分布。训练模块330还针对序列中的每个观察生成后续观察的预测可能性。具体地,对于当前观察xt,训练模块330通过沿编码器网络的各层将输入参数集合应用于当前观察xt,并且将递归参数集合γ应用于先前步骤t-1的编码器网络层,来生成估计的潜在分布训练模块330还通过将转移网络gψ(·)应用于从先前观察的潜在分布生成的一个或多个值vt-1来生成估计的先验分布gψ(zt|zt-1)。训练模块330还从当前潜在分布生成一个或多个值vt。
训练模块330通过将译码器网络pθ(·)应用于值vt来生成后续观察pθ(xt+1|zt)的预测可能性。针对序列中剩余的后续观察重复此过程。在一个实例中,编码器网络被配置为接收单热编码令牌向量作为输入。在这种实例中,译码器网络可以被配置为生成输出向量,其中输出向量中的每个元素对应于观察该元素的对应标记的预测可能性。
在前向传递步骤之后,训练模块330确定针对序列中的每个观察的损失。对于当前观察xt,损失包括随着后续观察pθ(xt+1|zt)的预测可能性降低而增加的预测损失,以及惩罚观察xt的先验分布gψ(zt|zt-1)和潜在分布之间的显著偏差的散度损失。在一个实施例中,当前观察xt的预测损失由下式给出:
它对在当前潜在分布上预测后续观察xt+1的可能性取期望。因此,等式(1)的预测损失可以通过对后续观察的预测可能性pθ(xt+1|zt)取期望来确定,通过将译码器网络pθ(·)应用于前向传递步骤中的值vt来生成该期望。在一个实施例中,当前观察xt的散度损失由下式给出:
Ld=KL(qφ(zt|xt,zt-1)||gψ(zt|zt-1)) (2)
其中KL(·)标示当前观察xt的潜在分布和先验分布的Kullback-Leibler散度。因此,等式(2)的散度损失测量了当前观察xt的潜在分布和先验分布之间的差异。
训练模块330将损失函数确定为针对序列中的每个观察的损失的组合。在一个实例中,观察序列的损失函数由下式确定:
其中t标示序列中的观察的索引,并且λ、γ是控制对每项的贡献的超参数。在反向传播步骤期间,训练模块330通过反向传播一个或多个误差项以减少损失函数来更新编码器网络译码器网络pθ(·)和转移网络gψ(·)的参数。因此,通过增加λ和γ之间的比率,递归机器学习模型的参数被训练为相对于散度损失减少预测损失,并且通过减小λ和γ之间的比率,参数被训练为相对于预测损失减少散度损失。
以这种方式,编码器网络和译码器网络pθ(·)的参数被训练以使得被用来生成后续预测的当前预测的潜在分布不会显著偏离仅基于当前预测值的先前潜在状态。这允许更简单的潜在状态的表示,以及用于递归机器学习模型的更有效的训练过程。
在一个实例中,当前观察xt的潜在分布q(zt|xt,zt-1)和先验分布gψ(zt|zt-1)由概率分布的统计参数定义。在图4中所示的示例中,估计的潜在分布可以是由均值μt和协方差矩阵∑t所定义的高斯分布。估计的先验分布gψ(zt|zt-1)可能是由均值和协方差矩阵所定义的高斯分布。在这种实例中,编码器网络的最后一层可以被配置为输出定义潜在分布的统计参数。转移网络gψ(·)的最后一层也可以被配置为输出定义先验分布gψ(zt|zt-1)的统计参数。可替代地,训练模块330可以通过将转移网络gψ(·)的估计参数集合应用于从先前潜在分布生成的值vt-1并对这些值的输出求平均来确定先验分布的统计参数。
当先验分布被建模为高斯概率分布时,当前观察xt的先验分布的统计参数可以通过下式来确定:
其中Wμ、bμ、WΣ和bΣ是转移网络gψ(·)的参数集合。在另一个实例中,先验分布的统计参数可以通过下式确定:
其中W1、b1、W2、b2、Wμ和bμ是转移网络gψ(·)的参数集合。在另一个实例中,先验分布的统计参数可以通过下式确定:
kt=sigmoid(W1·vt-1+b1)
其中W1、b1、W2、b2、Wμ、bμ、WΣ和bΣ是转移网络gψ(·)的参数集合。在另一个实例中,先验分布的统计参数可以通过下式确定:
其中W1、b1、W2、b2、W3、b3、W4、b4、Wμ、bμ、WΣ和bΣ是转移网络gψ(·)的参数集合。符号标示矩阵乘法,而⊙标示逐元素乘法。softplus函数被定义为softplus(x)=ln(1+ex)。在一个实例中,转移网络gψ(·)的复杂性从等式(4)增加到(7),并且训练模块330可以依赖于数据的复杂性来选择转移网络gψ(·)的合适架构以进行训练。
在一个实施例中,训练模块330用编码器网络来训练递归机器学习模型,该编码器网络包括嵌入层和放置在嵌入层之后的一系列隐藏层。通过将用于嵌入层的输入参数集合应用于输入向量来生成嵌入层。通过将对应的输入参数子集应用于前一个输出来生成每个隐藏层。在一个实例中,用于递归机器学习模型的递归参数集合被配置为使得用于通过将递归参数子集应用于在前一步t-1处的特定隐藏层的值来生成当前步骤t的特定隐藏层。
图5图示出了根据实施例的包括嵌入层的递归机器学习模型的示例架构。如图5中所示,编码器网络的架构包括作为第一层的嵌入层e和放置在嵌入层e之后的多个隐藏层h1,h2,...,hl。在图5中所示的示例中,通过将输入参数的子集应用于词标记的输入向量来生成用于步骤t的嵌入层et。在训练过程期间,这可以是当前观察xt的输入向量,并且在推理过程期间,这可以是当前预测的输入向量。通过将输入参数的子集应用于前一个输出并将递归参数γ的子集应用于在前一步t-1处的对应隐藏层的值来生成每个后续隐藏层ht。
在训练过程完成之后,用于词标记的嵌入向量e被配置为表示词在潜在空间中的嵌入,以使得针对一个词标记的嵌入与针对共享相似含义或出现在相似上下文中的其他词标记的嵌入在距离上更靠近,并且与针对含义不同或出现在不同上下文中的其他词标记的嵌入在距离上更远,这由词嵌入模型(诸如word2vec)来确定。以这种方式,编码器网络的其余层可以处理具有更好上下文信息的词标记,并且可以帮助提高模型的预测准确度。
返回图3,预测模块335接收请求以执行一个或多个任务从而使用经训练的递归机器学习模型来生成顺序预测。类似于图2的推理过程,预测模块335可以重复地应用编码器网络和译码器网络的参数集合以生成一个或多个顺序预测。在一个示例中,顺序预测是有序词集合。在这样的示例中,基于表示先前词预测的上下文的当前预测的潜在状态来生成后续词预测。在另一个示例中,顺序预测是对访问模式的预测,诸如患者到医院的访问模式。在这样的示例中,基于表示患者先前访问模式的上下文的当前预测的潜在状态来生成后续访问预测。
在一个实例中,当顺序预测是词或短语标记并且译码器网络被配置为生成概率的输出向量时,预测模块335可以通过选择输出向量中与最高可能性相关联的标记来确定当前预测。在另一个实例中,预测模块335可以基于由递归机器学习模型所生成的可能性来选择输出向量中满足替代标准的标记。
图6图示出了根据实施例的用于训练递归机器学习模型的方法。建模系统110获得602已知观察序列。已知观察序列可以是有序数据集合,递归机器学习模型可以使用有序数据集合来学习相对于空间或时间的顺序依赖性。对于序列中的每个观察,建模系统110通过将编码器网络应用于当前观察以及编码器网络针对一个或多个先前观察的值来生成604当前观察的当前潜在分布。鉴于当前观察的值和一个或多个先前观察的潜在状态,当前观察的潜在分布表示当前观察的潜在状态的分布。建模系统110还通过将转移网络应用于从先前观察的先前潜在分布所生成的一个或多个先前观察的估计的潜在状态来生成606先验分布。鉴于与当前观察值无关的一个或多个先前观察的潜在状态,当前观察的先验分布表示当前观察的潜在状态的分布。
建模系统110从当前潜在分布生成608当前观察的估计潜在状态。建模系统110通过将译码器网络应用于当前观察的估计潜在状态来生成610用于观察后续观察的预测可能性。建模系统110将当前观察的损失确定612为预测损失和散度损失的组合。预测损失随着后续观察的预测可能性的降低而增加。散度损失测量当前观察的潜在状态的先验分布和潜在分布之间的差异。建模系统110将损失函数确定614为针对序列中的每个观察的损失的组合,并且反向传播一个或多个误差项以更新编码器网络、译码器网络和转移网络的参数。
示例递归模型的性能结果
图7A到图7C图示出了本文所呈现的示例递归机器学习模型与其他现有技术模型相比的性能结果。具体而言,图7A到图7C中所示的结果分别在作为“Penn Treebank”(PTB)数据集和“WikiText-2”(WT2)数据集的子集的训练数据集上训练本文所讨论的递归机器学习模型和其他模型。PTB数据集包含10,000个词汇表,其中训练数据集中有929,590个标记,验证数据集中有73,761个标记,并且测试数据集中有82,431个标记。WT2数据集包含33,278个词汇,其中训练数据集中有2,088,628个标记,验证数据集中有217,646个标记,并且测试数据集中有245,569个标记。
通过将模型应用于作为不与训练数据重叠的相同数据集的子集的测试数据,并比较测试数据中由模型所生成的迭代的预测词标记等于测试数据中的已知词标记的词标记比例,来确定每个模型的性能。在语言处理上下文中测量模型性能的一个度量是困惑度。困惑度指示模型预测数据集中的样本的程度。低困惑度可以指示该模型擅长生成准确的预测。
图7A图示出了PTB数据集的困惑度。图7B图示出了WT2数据集的困惑度。除了其他类型的模型之外,“LSTM”模型是一个基本两层LSTM架构,具有大小为200的嵌入层、大小为400的隐藏层和大小为200的输出层。“LSTM-tie”模型是在架构上与LSTM模型相似,不同之处在于嵌入层的参数与输出层的参数关联。“AWD-LSTM-MOS”模型是语言处理中最先进的mixture-of-softmaxes模型。“LSTM+LatentShift”模型是使用本文所描述的转移网络通过正则化过程修改的LSTM模型。LSTM+LatentShift模型的输出层被加倍以合并由编码器网络输出的统计参数。类似地,“LSTM-tie+LatentShift”模型是使用转移网络通过正则化过程修改的LSTM-tie模型,“AWD-LSTM-MOS+LatentShift”模型是使用转移网络通过正则化过程修改的AWD-LSTM-MOS模型,其中潜在状态的大小与MOS模型中的输出层的大小匹配。
如图7A-图7B中所指示的,使用本文所描述的正则化过程训练的递归机器学习模型始终优于其他最先进的递归模型,相对增益多于10%。特别地,虽然AWD-LSTM-MOS模型是具有许多超参数的模型,但正则化过程能够在不更改默认超参数值的情况下改进该模型。
图7C图示出了在PTB和WT2数据集上训练MOS模型和MOS+LatentShift模型以达到收敛所需的时期数。如图7C中所指示的,MOS+LatentShift模型针对PTB数据集的收敛速度快了近3倍,并且针对WT2数据集的收敛速度快了近2倍。这是一个非常显著的加速,因为训练MOS模型的计算要求很高,并且即使在多个GPU上也可能需要几天。性能结果表明,应用本文所描述的正则化过程可以减少用于训练递归机器学习模型的计算资源和复杂性,同时提高预测准确性。
总结
以上对本发明实施例的描述是为了说明的目的;它并不旨在详尽或将本发明限制为所公开的精确形式。相关领域的技术人员可以领会,根据上述公开内容,许多修改和变化是可能的。
本说明书的某些部分根据对信息的操作的算法和符号表示来描述本发明的实施例。这些算法描述和表示通常被数据处理领域的技术人员用来将他们的工作的实质有效地传达给本领域的其他技术人员。尽管在功能上、计算上或逻辑上描述了这些操作,但是这些操作应被理解为通过计算机程序或等效电路、微代码等来实现。此外,在不失一般性的情况下,有时也证明将这些操作布置称为模块是方便的。所描述的操作及其相关联的模块可以以软件、固件、硬件或其任何组合来体现。
本文所描述的任何步骤、操作或过程可以利用一个或多个硬件或软件模块单独执行或实现,或者与其他设备组合来执行或实现。在一个实施例中,软件模块利用计算机程序产品来实现,该计算机程序产品包括包含计算机程序代码的计算机可读介质,该计算机程序代码可以由计算机处理器执行以用于执行所描述的任何或所有步骤、操作或过程。
本发明的实施例还可以涉及一种用于执行本文操作的装置。该装置可以为所需目的而专门构造,和/或它可以包括由存储在计算机中的计算机程序选择性地激活或重新配置的通用计算设备。这样的计算机程序可以被存储在非瞬态的有形的计算机可读存储介质或者适合于存储电子指令的任何类型的介质中,其可以耦合到计算机系统总线。此外,本说明书中提及的任何计算系统可以包括单个处理器或者可以是采用多处理器设计以增加计算能力的架构。
本发明的实施例还可以涉及通过本文所描述的计算过程产生的产品。这样的产品可以包括从计算过程产生的信息,其中信息被存储在非瞬态的有形的计算机可读存储介质上并且可以包括计算机程序产品或本文所描述的其他数据组合的任何实施例。
最后,说明书中使用的语言主要是为了可读性和指导性目的而选择的,并且并未被选择来刻画或限制本发明的主题。因此,本发明的范围旨在不受该详细描述的限制,而是由基于本文的申请中所发布的任何权利要求限制。因此,本发明实施例的公开内容旨在说明而非限制在以下权利要求中阐述的本发明的范围。
Claims (20)
1.一种训练递归机器学习模型的方法,所述递归机器学习模型具有编码器网络、译码器网络和转移网络,所述方法包括:
获得观察序列;
对于所述序列中的每个观察,重复执行以下步骤:
通过将所述编码器网络应用于当前观察以及所述编码器网络针对一个或多个先前观察的值,来生成针对所述当前观察的当前潜在分布,鉴于所述当前观察的值和针对所述一个或多个先前观察的潜在状态,所述当前潜在分布表示针对所述当前观察的潜在状态的分布;
通过将所述转移网络应用于从针对所述一个或多个先前观察的先前潜在分布所生成的、针对所述一个或多个先前观察的估计潜在状态,来生成先验分布,鉴于与所述当前观察的所述值无关的、针对所述一个或多个先前观察的所述潜在状态,所述先验分布表示针对所述当前观察的所述潜在状态的分布;
从所述当前潜在分布中生成针对所述当前观察的估计潜在状态;
鉴于针对所述当前观察的所述潜在状态,通过将所述译码器网络应用于针对所述当前观察的所述估计潜在状态,来生成用于观察后续观察的预测可能性;以及
确定针对所述当前观察的损失,所述损失包括预测损失和散度损失的组合,所述预测损失随着针对所述后续观察的所述预测可能性的降低而增加,而所述散度损失指示所述当前潜在分布和所述先验分布之间的差异的量度;以及
将所述观察序列的损失函数确定为针对所述序列中的每个观察的损失的组合;以及
反向传播来自所述损失函数的一个或多个误差项,以更新所述编码器网络、所述译码器网络和所述转移网络的参数。
2.根据权利要求1所述的方法,其中针对所述当前观察的所述估计潜在状态通过对来自针对所述当前观察的所述潜在分布的一个或多个值进行采样而被生成。
3.根据权利要求2所述的方法,其中生成所述预测可能性包括:通过将所述译码器网络应用于来自针对所述当前观察的所述潜在分布的一个或多个采样值,来生成观察所述后续观察的一个或多个预测可能性。
4.根据权利要求3所述的方法,其中所述预测损失是所述一个或多个预测可能性的期望值。
5.根据权利要求1所述的方法,其中所述散度损失是所述先验分布和所述当前潜在分布之间的Kullback-Leibler散度。
6.根据权利要求1所述的方法,其中所述当前潜在分布由概率分布的统计参数集合定义,并且其中所述编码器网络被配置为输出所述统计参数集合。
7.根据权利要求1所述的方法,其中所述先验分布由概率分布的统计参数集合定义,并且其中生成所述先验分布包括:
将所述转移网络应用于从所述先前潜在分布中采样的一个或多个值,以生成一个或多个对应的输出值;以及
从所述一个或多个输出值估计针对所述先验分布的所述统计参数集合。
8.一种非瞬态计算机可读介质,包含用于在处理器上执行的指令,所述指令包括:
获得观察序列;
对于所述序列中的每个观察,重复执行以下步骤:
通过将编码器网络应用于当前观察以及所述编码器网络针对一个或多个先前观察的值,来生成针对所述当前观察的当前潜在分布,鉴于所述当前观察的值和针对所述一个或多个先前观察的潜在状态,所述当前潜在分布表示针对所述当前观察的潜在状态的分布;
通过将转移网络应用于从针对所述一个或多个先前观察的先前潜在分布所生成的、针对所述一个或多个先前观察的估计潜在状态,来生成先验分布,鉴于与所述当前观察的所述值无关的针对所述一个或多个先前观察的所述潜在状态,所述先验分布表示针对所述当前观察的所述潜在状态的分布;
从所述当前潜在分布中生成针对所述当前观察的估计潜在状态;
鉴于针对所述当前观察的所述潜在状态,通过将译码器网络应用于针对所述当前观察的所述估计潜在状态,来生成用于观察后续观察的预测可能性;以及
确定针对所述当前观察的损失,所述损失包括预测损失和散度损失的组合,所述预测损失随着针对所述后续观察的所述预测可能性的降低而增加,而所述散度损失指示所述当前潜在分布和所述先验分布之间的差异的量度;以及
将所述观察序列的损失函数确定为针对所述序列中的每个观察的损失的组合;以及
反向传播来自所述损失函数的一个或多个误差项,以更新所述编码器网络、所述译码器网络和所述转移网络的参数。
9.根据权利要求8所述的计算机可读介质,其中针对所述当前观察的所述估计潜在状态通过对来自针对所述当前观察的所述潜在分布的一个或多个值进行采样而被生成。
10.根据权利要求9所述的计算机可读介质,其中生成所述预测可能性包括:通过将所述译码器网络应用于来自针对所述当前观察的所述潜在分布的一个或多个采样值,来生成观察所述后续观察的一个或多个预测可能性,并且其中预测损失是一个或多个预测可能性的期望值。
11.根据权利要求8所述的计算机可读介质,其中所述散度损失是所述先验分布和所述当前潜在分布之间的Kullback-Leibler散度。
12.根据权利要求8所述的计算机可读介质,其中所述当前潜在分布由概率分布的统计参数集合定义,并且其中所述编码器网络被配置为输出所述统计参数集合。
13.根据权利要求8所述的计算机可读介质,其中所述先验分布由概率分布的统计参数集合定义,并且其中生成所述先验分布包括:
将所述转移网络应用于从所述先前潜在分布中采样的一个或多个值,以生成一个或多个对应的输出值;以及
从所述一个或多个输出值估计针对所述先验分布的所述统计参数集合。
14.一种被存储在计算机可读存储介质上的递归机器学习模型,其中所述递归机器学习模型通过过程来制造,所述过程包括:
获得观察序列;
对于所述序列中的每个观察,重复执行以下步骤:
通过将编码器网络应用于当前观察以及所述编码器网络针对一个或多个先前观察的值,来生成针对所述当前观察的潜在分布,鉴于所述当前观察的值和针对所述一个或多个先前观察的潜在状态,所述潜在分布表示针对所述当前观察的潜在状态的分布;
通过将转移网络应用于从针对所述一个或多个先前观察的潜在分布所生成的、针对所述一个或多个先前观察的估计潜在状态,来生成先验分布,鉴于与所述当前观察的所述值无关的针对所述一个或多个先前观察的所述潜在状态,所述先验分布表示针对所述当前观察的所述潜在状态的分布;
从所述当前观察的所述潜在分布中生成针对所述当前观察的估计潜在状态;
鉴于针对所述当前观察的所述潜在状态,通过将译码器网络应用于针对所述当前观察的所述估计潜在状态,来生成用于观察后续观察的预测可能性;以及
确定针对所述当前观察的损失,所述损失包括预测损失和散度损失的组合,所述预测损失随着针对所述后续观察的所述预测可能性的降低而增加,而所述散度损失指示所述潜在分布和所述先验分布之间的差异的量度;
将所述观察序列的损失函数确定为针对所述序列中的每个观察的损失的组合;
反向传播来自所述损失函数的一个或多个误差项,以更新所述编码器网络、所述译码器网络和所述转移网络的参数;以及
将所述编码器网络和所述译码器网络的所述参数存储在所述计算机可读存储介质上。
15.根据权利要求14所述的递归机器学习模型,其中针对所述当前观察的所述估计潜在状态通过对来自针对所述当前观察的所述潜在分布的一个或多个值进行采样而被生成。
16.根据权利要求15所述的递归机器学习模型,其中生成所述预测可能性包括:通过将所述译码器网络应用于来自针对所述当前观察的所述潜在分布的一个或多个采样值,来生成观察所述后续观察的一个或多个预测可能性。
17.根据权利要求16所述的递归机器学习模型,其中所述预测损失是所述一个或多个预测可能性的期望值。
18.根据权利要求14所述的递归机器学习模型,其中所述散度损失是所述先验分布与所述当前潜在分布之间的Kullback-Leibler散度。
19.根据权利要求14所述的递归机器学习模型,其中所述当前潜在分布由概率分布的统计参数集合定义,并且其中所述编码器网络被配置为输出所述统计参数集合。
20.根据权利要求14所述的递归机器学习模型,其中所述先验分布由概率分布的统计参数集合定义,并且其中生成所述先验分布包括:
将所述转移网络应用于从所述先前潜在分布中采样的一个或多个值,以生成一个或多个对应的输出值;以及
从所述一个或多个输出值估计针对所述先验分布的所述统计参数集合。
Applications Claiming Priority (3)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
US201862778277P | 2018-12-11 | 2018-12-11 | |
US62/778,277 | 2018-12-11 | ||
PCT/CA2019/050801 WO2020118408A1 (en) | 2018-12-11 | 2019-06-07 | Regularization of recurrent machine-learned architectures |
Publications (1)
Publication Number | Publication Date |
---|---|
CN113396429A true CN113396429A (zh) | 2021-09-14 |
Family
ID=70971030
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN201980091101.5A Pending CN113396429A (zh) | 2018-12-11 | 2019-06-07 | 递归机器学习架构的正则化 |
Country Status (5)
Country | Link |
---|---|
US (1) | US12106220B2 (zh) |
EP (1) | EP3895080A4 (zh) |
CN (1) | CN113396429A (zh) |
CA (1) | CA3117833A1 (zh) |
WO (1) | WO2020118408A1 (zh) |
Families Citing this family (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN111899869A (zh) * | 2020-08-03 | 2020-11-06 | 东南大学 | 一种抑郁症患者识别系统及其识别方法 |
CN114037048B (zh) * | 2021-10-15 | 2024-05-28 | 大连理工大学 | 基于变分循环网络模型的信念一致多智能体强化学习方法 |
CN114239744B (zh) * | 2021-12-21 | 2024-07-02 | 南京邮电大学 | 一种基于变分生成对抗网络的个体处理效应评估方法 |
AU2023225811A1 (en) * | 2022-02-28 | 2024-09-12 | The Board Of Trustees Of The Leland Stanford Junior University | Systems and methods to assess neonatal health risk and uses thereof |
Family Cites Families (10)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US10410113B2 (en) * | 2016-01-14 | 2019-09-10 | Preferred Networks, Inc. | Time series data adaptation and sensor fusion systems, methods, and apparatus |
EP3398114B1 (en) * | 2016-02-05 | 2022-08-24 | Deepmind Technologies Limited | Compressing images using neural networks |
CN110383299B (zh) * | 2017-02-06 | 2023-11-17 | 渊慧科技有限公司 | 记忆增强的生成时间模型 |
WO2018172513A1 (en) * | 2017-03-23 | 2018-09-27 | Deepmind Technologies Limited | Training neural networks using posterior sharpening |
US11586915B2 (en) * | 2017-12-14 | 2023-02-21 | D-Wave Systems Inc. | Systems and methods for collaborative filtering with variational autoencoders |
US20190244680A1 (en) * | 2018-02-07 | 2019-08-08 | D-Wave Systems Inc. | Systems and methods for generative machine learning |
US20210004677A1 (en) * | 2018-02-09 | 2021-01-07 | Deepmind Technologies Limited | Data compression using jointly trained encoder, decoder, and prior neural networks |
US10346524B1 (en) * | 2018-03-29 | 2019-07-09 | Sap Se | Position-dependent word salience estimation |
US10872293B2 (en) * | 2018-05-29 | 2020-12-22 | Deepmind Technologies Limited | Deep reinforcement learning with fast updating recurrent neural networks and slow updating recurrent neural networks |
US20200160176A1 (en) * | 2018-11-16 | 2020-05-21 | Royal Bank Of Canada | System and method for generative model for stochastic point processes |
-
2019
- 2019-06-07 EP EP19895236.8A patent/EP3895080A4/en active Pending
- 2019-06-07 CA CA3117833A patent/CA3117833A1/en active Pending
- 2019-06-07 CN CN201980091101.5A patent/CN113396429A/zh active Pending
- 2019-06-07 US US16/435,213 patent/US12106220B2/en active Active
- 2019-06-07 WO PCT/CA2019/050801 patent/WO2020118408A1/en unknown
Also Published As
Publication number | Publication date |
---|---|
US12106220B2 (en) | 2024-10-01 |
WO2020118408A1 (en) | 2020-06-18 |
EP3895080A1 (en) | 2021-10-20 |
CA3117833A1 (en) | 2020-06-18 |
EP3895080A4 (en) | 2022-04-06 |
US20200184338A1 (en) | 2020-06-11 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
Garg et al. | What can transformers learn in-context? a case study of simple function classes | |
Triebe et al. | Ar-net: A simple auto-regressive neural network for time-series | |
US11568207B2 (en) | Learning observation representations by predicting the future in latent space | |
US11593611B2 (en) | Neural network cooperation | |
CN113396429A (zh) | 递归机器学习架构的正则化 | |
US11468324B2 (en) | Method and apparatus with model training and/or sequence recognition | |
CA3074675A1 (en) | System and method for machine learning with long-range dependency | |
US11475220B2 (en) | Predicting joint intent-slot structure | |
US20140032570A1 (en) | Discriminative Learning Via Hierarchical Transformations | |
US11194968B2 (en) | Automatized text analysis | |
EP3916641A1 (en) | Continuous time self attention for improved computational predictions | |
US20220188605A1 (en) | Recurrent neural network architectures based on synaptic connectivity graphs | |
Herzog et al. | Data-driven modeling and prediction of complex spatio-temporal dynamics in excitable media | |
CN111832699A (zh) | 用于神经网络的计算高效富于表达的输出层 | |
US20220405615A1 (en) | Methods and systems for generating an uncertainty score for an output of a gradient boosted decision tree model | |
CN112364659B (zh) | 一种无监督的语义表示自动识别方法及装置 | |
Li et al. | Temporal supervised learning for inferring a dialog policy from example conversations | |
Salman et al. | Nifty method for prediction dynamic features of online social networks from users’ activity based on machine learning | |
US20220414433A1 (en) | Automatically determining neural network architectures based on synaptic connectivity | |
Gupta et al. | Sequential knowledge transfer across problems | |
Oveisi et al. | Software reliability prediction: A machine learning and approximation Bayesian inference approach | |
Nam et al. | Error estimation using neural network technique for solving ordinary differential equations | |
Tomczak | Latent Variable Models | |
Shimamura et al. | A Bayesian approach to multi-task learning with network lasso | |
Huk et al. | David Huk, Lorenzo Pacchiardi, Ritabrata Dutta and Mark Steel's contribution to the Discussion of ‘Martingale posterior distributions’ by Fong, Holmes and Walker |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |