CN113222105A - 元协作训练范式 - Google Patents
元协作训练范式 Download PDFInfo
- Publication number
- CN113222105A CN113222105A CN202110162379.6A CN202110162379A CN113222105A CN 113222105 A CN113222105 A CN 113222105A CN 202110162379 A CN202110162379 A CN 202110162379A CN 113222105 A CN113222105 A CN 113222105A
- Authority
- CN
- China
- Prior art keywords
- model
- generator
- neural network
- parameter values
- training
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Granted
Links
- 238000012549 training Methods 0.000 title claims abstract description 184
- 238000009826 distribution Methods 0.000 claims abstract description 33
- 238000000034 method Methods 0.000 claims description 94
- 238000003062 neural network model Methods 0.000 claims description 83
- 230000006870 function Effects 0.000 claims description 23
- 230000004044 response Effects 0.000 claims description 17
- 239000000203 mixture Substances 0.000 claims description 12
- 238000005070 sampling Methods 0.000 claims description 10
- 230000036039 immunity Effects 0.000 claims description 9
- 238000004590 computer program Methods 0.000 claims description 8
- 238000007476 Maximum Likelihood Methods 0.000 claims description 6
- 238000013528 artificial neural network Methods 0.000 claims description 4
- 238000002474 experimental method Methods 0.000 abstract description 11
- 230000003042 antagnostic effect Effects 0.000 abstract description 8
- 238000013459 approach Methods 0.000 abstract description 8
- 230000007246 mechanism Effects 0.000 abstract description 7
- 238000004519 manufacturing process Methods 0.000 abstract description 4
- 230000008569 process Effects 0.000 description 15
- 238000011156 evaluation Methods 0.000 description 12
- 238000005457 optimization Methods 0.000 description 12
- 238000012545 processing Methods 0.000 description 9
- 230000000694 effects Effects 0.000 description 7
- 238000012360 testing method Methods 0.000 description 7
- 238000003860 storage Methods 0.000 description 6
- 238000010998 test method Methods 0.000 description 6
- 238000004891 communication Methods 0.000 description 4
- 230000002093 peripheral effect Effects 0.000 description 4
- 230000003287 optical effect Effects 0.000 description 3
- 238000002679 ablation Methods 0.000 description 2
- 230000009286 beneficial effect Effects 0.000 description 2
- 238000010586 diagram Methods 0.000 description 2
- 230000006872 improvement Effects 0.000 description 2
- ORILYTVJVMAKLC-UHFFFAOYSA-N Adamantane Natural products C1C(C2)CC3CC1CC2C3 ORILYTVJVMAKLC-UHFFFAOYSA-N 0.000 description 1
- 238000013256 Gubra-Amylin NASH model Methods 0.000 description 1
- 241000764238 Isis Species 0.000 description 1
- 230000003044 adaptive effect Effects 0.000 description 1
- 239000000654 additive Substances 0.000 description 1
- 230000000996 additive effect Effects 0.000 description 1
- 238000003491 array Methods 0.000 description 1
- 230000015572 biosynthetic process Effects 0.000 description 1
- 230000015556 catabolic process Effects 0.000 description 1
- 238000012512 characterization method Methods 0.000 description 1
- 150000001875 compounds Chemical class 0.000 description 1
- 230000001143 conditioned effect Effects 0.000 description 1
- 238000007796 conventional method Methods 0.000 description 1
- 238000001816 cooling Methods 0.000 description 1
- 238000006731 degradation reaction Methods 0.000 description 1
- 238000005474 detonation Methods 0.000 description 1
- 238000004880 explosion Methods 0.000 description 1
- 239000000835 fiber Substances 0.000 description 1
- 238000007667 floating Methods 0.000 description 1
- 238000009472 formulation Methods 0.000 description 1
- 230000002068 genetic effect Effects 0.000 description 1
- 238000010348 incorporation Methods 0.000 description 1
- 238000010801 machine learning Methods 0.000 description 1
- 239000011159 matrix material Substances 0.000 description 1
- 238000003058 natural language processing Methods 0.000 description 1
- 238000007781 pre-processing Methods 0.000 description 1
- 238000011158 quantitative evaluation Methods 0.000 description 1
- 238000007670 refining Methods 0.000 description 1
- 230000001105 regulatory effect Effects 0.000 description 1
- 230000002787 reinforcement Effects 0.000 description 1
- 238000011160 research Methods 0.000 description 1
- 230000000717 retained effect Effects 0.000 description 1
- 238000007493 shaping process Methods 0.000 description 1
- 239000007787 solid Substances 0.000 description 1
- 238000006467 substitution reaction Methods 0.000 description 1
- 238000003786 synthesis reaction Methods 0.000 description 1
- 230000002194 synthesizing effect Effects 0.000 description 1
- 239000010409 thin film Substances 0.000 description 1
- 238000013519 translation Methods 0.000 description 1
- 238000012795 verification Methods 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/084—Backpropagation, e.g. using gradient descent
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- General Health & Medical Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Evolutionary Computation (AREA)
- Artificial Intelligence (AREA)
- Molecular Biology (AREA)
- Computing Systems (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Mathematical Physics (AREA)
- Software Systems (AREA)
- Health & Medical Sciences (AREA)
- Machine Translation (AREA)
Abstract
生成对抗模型有很多益处;然而,由于模式瓦解,这些生成器面临质量‑多样性的折衷(即,生成器模型为了提高生成质量而牺牲了生成多样性)。本文提出的是通过减速模式瓦解来提高对抗性内容生成的性能的实施例。在一个或多个实施例中,采用协作训练范式,其中第二模型与生成器协作训练,并且帮助有效地塑造生成器的数据分布以防止模式瓦解。此外,可以使用元学习机制的实施例,其中对生成器的协作更新用作高级元任务,并且其有助于确保在对抗性更新后生成器参数保持对模式瓦解的抵抗力。在实验中,经测试的使用证明了对抗文本生成器的模式瓦解的有效减慢。总体而言,实施例在生成质量和多样性两者方面以明显的优势胜过基准方法。
Description
技术领域
本公开总体上涉及用于计算机学习的系统和方法,该系统和方法可以提供改进的计算机性能、特征和用途。更具体地,本公开涉及用于生成模型的对抗训练的系统和方法。
背景技术
神经网络在许多领域都取得了巨大的成功,诸如计算机视觉、自然语言处理、推荐系统等。神经网络模型的一种类型是生成模型,该生成模型用于生成内容,诸如文本和图像。训练生成模型以从训练集中学习真实的数据分布,并且能够在训练完成时生成新的数据点。近年来,它们已成功应用于广泛的应用,包括图像生成、风格化、半监督分类和自然语言生成。应用的一个领域是文本生成的新兴任务,通常将其建模为顺序的离散数据生成过程。此类任务在许多现实世界应用中扮演着关键角色,诸如机器翻译、文本摘要和对话系统。
顺序文本生成模型的训练在很大程度上依赖于在自动回归模型上应用强制教学(teacher forcing),即,以最大似然估计(MLE)进行优化。然而,用强制教学训练生成模型将遭受曝光偏差(exposure bias),即,模型在推理时间被馈送到其预测数据而不是地面真实数据,并且因此由于积累的误差而导致生成不良样本。为了解决曝光偏差问题,针对文本生成的正在进行的主要研究集中在利用对抗训练技术来推导更好的文本生成模型。通常,这种尝试可以分为以下两个方面:第一种方法将生成对抗网络(GAN)与强化学习(RL)进行组合,表示为基于RL;第二种方法仅玩双人对抗式游戏,而无需使用RL,表示为无RL。
基于RL和无RL的文本生成方法两者都遭受模式瓦解(mode collapse),这对于训练基于GAN的模型是众所周知的挑战。也就是说,随着对抗训练的进行,所生成的分布倾向于与生成用于数据的模式子集形成对比。因此,生成器输出重复的语句,并且因此不再表达性地表示数据生成分布。在最近的研究中,已经对这种效果进行了定量评估,结果表明,当从MLE训练移动到对抗训练阶段时,生成器的输出分布的熵将经历明显的下降。为了使用基于GAN的技术推导更好的文本生成模型,一项关键任务是通过有效地减慢对抗性生成器的模式瓦解来实现更好的质量-多样性折衷,即,让生成器从对抗性更新中获取丰富的梯度信息以使其输出更真实(即提高质量),同时容忍较小的模式瓦解效果(即降低多样性)。然而,有限数量的现有基于RL或无RL的方法明确考虑处理GAN训练的模式瓦解。
因此,需要明确地解决对抗训练的模式瓦解的挑战的方法,从而产生改进的文本生成模型。
发明内容
在第一方面,本公开提供了一种用于训练生成器的计算机实现的方法,其包括:
响应于尚未达到停止条件,执行步骤,所述步骤包括:
从训练数据中采样一组数据点;
使用包括一组生成器参数值的生成器模型来生成一组生成的数据点;
使用对抗训练损失函数来计算所述生成器模型的对抗损失;
使用所述对抗损失和梯度下降来确定用于所述生成器模型的一组中间生成器参数值;
使用从所述训练数据中采样的所述一组数据点作为到包括第二神经网络模型组的参数值的第二神经网络模型的输入以及到包括所述一组中间生成器参数值的所述生成器模型的输入,计算所述生成器模型的协作训练损失;
使用所述协作训练损失来确定元梯度;
使用对抗梯度来更新所述一组生成器参数值,所述对抗梯度是使用所述生成器模型的所述对抗损失和所述元梯度获得的;
使用鉴别器模型的对抗损失来更新所述鉴别器模型的一组鉴别器参数值;以及
使用所述第二神经网络模型的协作训练损失来更新所述第二神经网络模型的所述第二神经网络模型组的参数值;以及
响应于已达到所述停止条件,输出所述生成器模型,所述生成器模型包括生成器参数值的最终更新的集合。
在第二方面,本公开提供了一种系统,其包括:
一个或多个处理器;以及
非暂时性计算机可读介质或媒介,包括一个或多个指令集,所述指令集在由所述一个或多个处理器中的至少一者执行时致使执行包括以下各项的步骤:
响应于尚未达到停止条件,执行步骤,所述步骤包括:
从具有第一分布的训练数据中采样一组数据点;
使用包括一组生成器参数值的生成器模型来生成一组生成的数据点;
使用对抗训练损失函数来计算所述生成器模型的对抗损失;
使用所述对抗损失和梯度下降来确定用于所述生成器模型的一组中间生成器参数值;
使用从所述训练数据中采样的所述一组数据点作为到包括第二神经网络模型组的参数值的第二神经网络模型的输入以及到包括所述一组中间生成器参数值的所述生成器模型的输入,计算所述生成器模型的协作训练损失;
使用所述生成器模型的所述协作训练损失来确定元梯度;
使用对抗梯度来更新所述一组生成器参数值,所述对抗梯度是使用所述生成器模型的所述对抗损失和所述元梯度获得的;
使用鉴别器模型的对抗损失来更新所述鉴别器模型的一组鉴别器参数值;以及
使用所述第二神经网络模型的协作训练损失来更新所述第二神经网络模型的所述第二神经网络模型组的参数值;以及
响应于已达到所述停止条件,输出所述生成器模型,所述生成器模型包括生成器参数值的最终更新的集合。
在第三方面,本公开提供了一种用于训练生成器的计算机实现的方法,其包括:
响应于尚未达到停止条件,执行步骤,所述步骤包括:
使用来自真实数据的训练数据集的一组数据点和来自生成对抗系统的生成器模型来生成一组生成的数据点,所述生成对抗系统包括具有一组生成器模型参数值的所述生成器模型以及具有一组鉴别器参数值的鉴别器模型;
使用对抗训练损失函数来计算所述生成器模型的对抗损失;
使用所述对抗损失和梯度下降来确定用于所述生成器模型的一组中间生成器参数值;
使用具有所述一组中间生成器参数值的所述生成器模型和第二神经网络模型来协同训练所述生成器模型,以减速所述生成器模型的模式瓦解;
使用所述鉴别器模型的对抗损失来更新所述鉴别器模型的一组鉴别器参数值;以及
使用所述第二神经网络模型的协作训练损失来更新所述第二神经网络模型的一组参数值;以及
响应于已达到所述停止条件,输出所述生成器模型。
在第四方面,本公开提供了一种系统,其包括:
一个或多个处理器;以及
非暂时性计算机可读介质或媒介,包括一个或多个指令集,所述指令集在由所述一个或多个处理器中的至少一者执行时致使执行包括以下各项的步骤:
响应于尚未达到停止条件,执行步骤,所述步骤包括:
使用来自真实数据的训练数据集的一组数据点和来自生成对抗系统的生成器模型来生成一组生成的数据点,所述生成对抗系统包括具有一组生成器模型参数值的所述生成器模型以及具有一组鉴别器参数值的鉴别器模型;
使用对抗训练损失函数来计算所述生成器模型的对抗损失;
使用所述对抗损失和梯度下降来确定用于所述生成器模型的一组中间生成器参数值;
使用具有所述一组中间生成器参数值的所述生成器模型和第二神经网络模型来协同训练所述生成器模型,以减速所述生成器模型的模式瓦解;
使用所述鉴别器模型的对抗损失来更新所述鉴别器模型的一组鉴别器参数值;以及
使用所述第二神经网络模型的协作训练损失来更新所述第二神经网络模型的一组参数值;以及
响应于已达到所述停止条件,输出所述生成器模型。
在第四方面,本公开提供了一种计算机程序产品,包括计算机程序,所述计算机程序在被处理器执行时实现根据第一方面所述的方法。
在第五方面,本公开提供了一种计算机程序产品,包括计算机程序,所述计算机程序在被处理器执行时实现根据第三方面所述的方法。
附图说明
将参考本公开的实施例,它们的示例可示于附图中。这些附图旨在是说明性的而非限制性的。虽然本公开大体上在这些实施例的上下文中描述,但应理解,本公开的范围并不旨在限于这些特定实施例。附图中的项目可能未按比例绘制。
附图(“图”)1描绘了根据本公开的实施例的协作训练过程的高级概述。
图2描绘了根据本公开的实施例的示例性生成系统。
图3描绘了根据本公开的实施例的示例性鉴别器系统。
图4描绘了根据本公开的实施例的GAN系统和Meta-CoTGAN数据流方法的概述。
图5描绘了根据本公开的实施例的用于训练生成器模型的Meta-CoTGAN方法。
图6描绘了根据本公开的实施例的用于使用已经使用Meta-CoTGAN方法训练的生成器模型的方法。
图7描绘了根据本公开的实施例的在关于NLLoracle损失具有长度为20的合成oracle上的评估结果。
图8包含表2,其根据本公开的实施例呈现了在数据集上的评估结果。结果经6次运行取平均(随机种子),并且对于NLLgen(最后一列),越小越好。
图9描绘了根据本公开的实施例的RelGAN和Meta-CoTGAN实施例的NLLgen和BLEU-5结果。
图10包含表3,其根据本公开的实施例呈现了在数据集2上的评估结果。结果经6次运行取平均,并且对于NLLgen(最后一列),越小越好。
图11包含表4,其根据本公开的实施例呈现了在数据集1上的消融研究结果。当协作训练部分和元优化分别关闭时,评估包括Meta-CoTGAN实施例。报告的评分源自6个随机种子。
图12描绘根据本公开的实施例的计算设备/信息处理系统的简化框图。
具体实施方式
在以下描述中,出于解释目的,阐明具体细节以便提供对本公开的理解。然而,将对本领域的技术人员显而易见的是,可在没有这些细节的情况下实践本公开。此外,本领域的技术人员将认识到,下文描述的本公开的实施例可以以各种方式(例如过程、装置、系统、设备或方法)在有形的计算机可读介质上实施。
附图中示出的组件或模块是本公开实施例的示例性说明,并且意图避免使本公开不清楚。还应理解,在本论述的全文中,组件可描述为单独的功能单元(可包括子单元),但是本领域的技术人员将认识到,各种组件或其部分可划分成单独组件,或者可整合在一起(包括例如位于单个的系统或组件内)。应注意,本文论述的功能或操作可实施为组件。组件可以以软件、硬件、或它们的组合实施。
此外,附图内的组件或系统之间的连接并不旨在限于直接连接。相反,在这些组件之间的数据可由中间组件修改、重格式化、或以其它方式改变。另外,可使用另外或更少的连接。还应注意,术语“联接”、“连接”、“通信地联接”、“接合”、“接口”或其派生词中的任一个,应理解为包括直接连接、通过一个或多个中间设备来进行的间接连接、和无线连接。还应注意,任何通信(诸如信号、响应、答复、确认、消息、查询等)可包括一个或多个信息交换。
在本说明书中对“一个或多个实施例”、“优选实施例”、“实施例”、“多个实施例”等的提及表示结合实施例所描述的具体特征、结构、特性或功能包括在本公开的至少一个实施例中,以及可包括在多于一个的实施例中。另外,在本说明书的各个地方出现以上所提到的短语并不一定全都是指相同的实施例或多个相同实施例。
在本说明书的各个地方使用某些术语目的在于说明,并且不应被理解为限制。服务、功能或资源并不限于单个服务、单个功能或单个资源;这些术语的使用可指代相关服务、功能或资源的可分布或聚合的分组。术语“包括”、“包括有”、“包含”和“包含有”应理解为开放性的术语,并且其后任何列出内容都是实例,而不旨在限于所列项目。“层”可包括一个或多个操作。词“最佳”、“优化”、“最优化”等是指对结果或过程的改进,并非要求指定的结果或过程已达到“最佳”或峰值状态。存储器、数据库、信息库、数据存储、表、硬件、高速缓存等在本文中的使用,可用来指代可输入信息或以其它方式记录信息的一个或多个系统组件。
在一个或多个实施例中,停止条件可包括:(1)已执行了设定次数的迭代;(2)已达到一定量的处理时间;(3)收敛(例如,连续迭代之间的差小于第一阈值);(4)发散(例如,性能劣化);(5)已达到可接受的结果。
本领域技术人员应当认识到:(1)可以可选地执行某些步骤;(2)步骤可以不限于本文所述的特定顺序;(3)某些步骤可以以不同的顺序执行;并且(4)某些步骤可以同时进行。
本文所使用的任何标题仅是为了组织目的,并且不应被用于限制说明书或权利要求的范围。本专利文献中提到的每个参考文献/文件以其整体通过引用并入本文。
应注意,本文提供的任何实验和结果均以说明性的方式提供,并且是在特定条件下使用特定实施例进行的;因此,这些实验及其结果均不得用于限制当前专利文件的公开范围。
还应注意,尽管本文描述的实施例可能在文本生成的情景内,但是本公开的各方面不限于此。因此,本公开的各方面可应用或适用于其它情景并用于生成其它内容。
A.概述
可以生成具有足够多样性的高质量文本的训练生成模型是自然语言生成(NLG)社区的重要开放问题。最近,生成性对抗模型已广泛应用于文本生成任务,其中经对抗性训练的生成器减轻了常规最大似然方法所经历的曝光偏差,并且产生了有前途的生成质量。然而,由于对抗训练的模式瓦解的臭名昭著的缺陷,经对抗训练的生成器面临质量-多样性的折衷,即,生成器模型倾向于为了提高生成质量而严重牺牲生成多样性。
本文提出的是新颖方法的实施例,其经由有效地减速对抗训练的模式瓦解来提高对抗内容生成的性能。为此,提出了协作训练范式的实施例,其中与生成器协作地训练语言模型,并且在一个或多个实施例中,语言模型被用来有效地塑造生成器的数据分布以防止模式瓦解。此外,在一个或多个实施例中,代替原则上进行针对生成器的协作更新,制定了元学习机制,其中对生成器的协作更新充当高级元任务,凭直觉确保对抗性更新后的生成器的参数将保持一致以防止模式瓦解。在实验中,证明了实施例可以有效地减慢对抗文本生成器的模式瓦解的速度。总体而言,实施例能够在验证域中的生成质量和多样性两者方面以明显的优势胜过基准方法。
除了具有强制教学的训练语言模型的常规方法之外,当前用于文本生成的方法通常可以被分类为基于RL的方法或无RL的方法。大多数基于RL的方法将文本生成公式化为马尔可夫决策过程(MDP)。通常,生成器会通过策略梯度算法或其变型使用从GAN的鉴别器得出的奖励信号进行更新。此类方法的突出示例包括SeqGAN、RankGAN、LeakGAN和MaskGAN。从鉴别器模型得出的有噪声的奖励信号倾向于使这种基于RL的模型遭受高方差梯度,以更新生成器的参数。除了梯度的高方差之外,基于RL的方法还面临由部分序列评估、学习缓慢和敏感超参数带来的困难。考虑到针对基于RL的方法的此类挑战,可以认为实施例属于但不限于无RL方法的类别。无RL方法的突出示例包括TextGAN、FM-GAN、GSGAN和Rel-GAN。此类方法为生成器提供低方差梯度,并且经常导致更稳定的训练。
大多数对抗文本生成模型首先通过MLE进行预训练,然后在基于RL或无RL的机制下通过对抗训练不断进行优化。当从MLE训练切换到对抗训练阶段时,基于RL和无RL方法两者的生成器模型将遭受模式瓦解问题。本文的一个或多个实施例的核心直觉是利用协作训练的语言模型来减速对抗训练的模式瓦解。尽管利用语言模型以促进对抗性文本生成的类似直觉与其他著作吻合,但是还是存在明显差异。在J.Xu、X.Ren、J.Lin和X.Sun的“DP-GAN:用于生成信息性文本和多样化文本的促进多样性的生成性对抗网络(DP-GAN:Diversity-Promoting Generative Adversarial Network for Generating Informative andDiversified Text)”(获自arXiv预印本arXiv:1802.01345(2018))中,用于对抗训练的鉴别器被建模为语言模型,该模型使真实数据的概率最大化,并且使生成数据的概率最小化。此外,在基于RL的设置下,将从语言模型得出的输出用作奖励信号,以促进生成多样性。由Sidi Lu、Lantao Yu、Siyuan Feng、Yaoming Zhu和Weinan Zhang在第36届机器学习国际会议论文集的PMLR97:4164-4172(2019)(以下称为“Lu等人2019”)中的“CoT:用于离散数据的生成建模的协作训练(CoT:Cooperative training for generative modeling ofdiscrete data)”中,其中在线训练语言模型,以提供目标分布,以便使实际数据分布与生成的分布之间的詹森-香农(Jensen-Shannon)散度最小化。相比之下,一个或多个实施例可以被认为采用相似的策略来训练语言模型,但是针对生成器模型的协作训练在其他差异中大为不同。例如,实施例包括不同的元学习设置,以优化生成器的协作训练损失。
总体而言,该专利文件中提出了至少三个贡献。首先,提出了新颖的协作训练方法的实施例,其中使用语言模型来有效地塑造对抗文本生成器的输出分布。该方法的实施例有效地减慢了对抗文本生成器的模式瓦解,并且因此导致文本生成朝向更好的质量-多样性折衷。其次,为了优化生成器的协作训练损失,本文提出了新颖的元学习机制的实施例。在一个或多个实施例中,协作训练任务用作元任务,并且对抗训练用作基本任务。因此,实施例确保在对抗性更新之后的生成器参数对模式瓦解有抵抗力。第三,在合成和真实世界的数据集上进行的大量实验表明,实施例能够在质量和多样性方面产生更好的文本生成模型。
B.序言
文本生成的任务通常被建模为顺序的离散数据生成过程。让为从基础数据生成分布pdata中提取的N个数据点。每个数据点被表示为离散令牌的序列:x=(y1,...,yT),其中yi表示第i个令牌,并且T表示序列的长度。让Gθ表示由θ参数化的生成器模型。常规的文本生成方法通常使用以下最大似然估计(MLE)来训练语言模型:
其中每个序列x的概率以自回归方式表示:
其中y<i表示先前令牌的序列y1,...,yi-1。
利用GAN进行文本生成的方法尝试在生成器Gθ与鉴别器D之间玩双人游戏。让鉴别器D通过φ来参数化。在对抗设置下,训练生成器Gθ以从pdata产生真实句子给定样本,并且鉴别器Dφ试图区分Gθ的生成的分布pθ和真实的数据分布pdata。因此,上述过程可以公式化为对抗训练机制,如下所示:
在自动回归生成过程中,第i个令牌yi是通过以生成器的先前令牌y<i为条件从生成器的输出分布中采样而生成的。进行这种采样会给生成器利用鉴别器的预测结果带来很大的困难。也就是说,对抗损失的反向传播路线,即
相对于生成器的参数θ变成不可微分,因为由于采样将为零。为了克服上述问题,基于RL的方法主要依赖于REINFORCE算法或其变型来得出梯度,以优化生成器,其中可以利用鉴别器的预测来得出奖励信号。无RL的方法通常通过一些连续近似来松弛不可微分的采样函数,诸如soft-argmax或gumbel-softmax。在一个或多个实施例中,可以使用gumbel-softmax松弛,其将采样的效果建模成将噪声引入到输入中,从而使得输出变得连续且可微分。具体地,噪声是通过Gumbel分布建模的,该分布如下形成:
C.方法实施例
用对抗训练机制(基于RL和无RL的方法两者)训练的语言生成器在从强制教学切换到对抗训练阶段时会遭受模式瓦解。在这一部分中,新颖的元协作训练方法的实施例用于克服此类挑战。总体而言,目标是经由降低其对抗训练的模式瓦解来为语言生成器实现更好的质量-多样性折衷。即,该方法的实施例允许生成器从对抗训练中获得丰富的梯度信息以便提高生成质量,同时在生成多样性方面牺牲很少。总体而言,在一个或多个实施例中,采用语言模型来减速生成器的输出分布的模式瓦解。在一个或多个实施例中,在对抗训练期间,语言模型与生成器Gθ协作训练。语言模型对来自真实数据分布pdata的样本的输出可以用于塑造生成器的输出分布。此外,可以用元优化设置来制定监督。
1.协作训练制定实施例
在本节中给出了协作训练范式的实施例,该协作训练范式参与了对抗生成器Gθ、对抗鉴别器Dφ和语言模型Mψ的交错训练过程,其中ψ表示语言模型的参数。
图1描绘了根据本公开的实施例的协作训练过程的高级概述。用对抗训练训练的生成器Gθ130倾向于遭受模式瓦解(由向内的短黑箭头、例如箭头115形象地描绘)。即,当生成器Gθ130通过对抗损失进行训练时,其生成多样性趋于逐渐减小,以试图提高生成质量。为了克服这种挑战,可以协作训练语言模型Mψ125。在一个或多个实施例中,语言模型125对Gθ的输出分布进行监督,趋向于保留真实数据的期望的生成概率,从而减慢模式瓦解(形象地描绘为向外的短的、浅色的、虚线的轮廓箭头,例如箭头120)。可以从混合样本分布pθ和pdata中训练语言模型。在一个或多个实施例中,从语言模型到语言生成器的监督作用于来自pdata的样本。生成器130可以通过对抗损失和协作训练损失来更新。
在协作训练过程中,可以通过MLE损失来一致地优化语言模型。为了为生成器提供平稳变化的目标分布,在一个或多个实施例中,语言模式用来自混合分布的数据与来自真实数据和生成的数据的平衡样本进行训练,例如尽管可以使用其他混合。在下文的等式(2)中正式定义了用于用MLE更新语言模型的协作训练损失的实施例。可以解释为最小化Mψ之间的直接KL散度以及具有的分布的最佳混合密度模型M*。
用来自真实数据的样本一致地更新语言模型Mψ,并且使用强制教学损失可以使其经历轻微的模式瓦解效果。因此,其输出预测可以提供对生成器Gθ的输出分布的有效监督,以便减慢模式瓦解。此外,与仅使用真实数据分布相比,使用混合分布更新Mψ将提供目标分布,该目标分布趋向于生成器的更新而平稳地变化,这被证明是更有益的。正式地,针对生成器模型的协作训练损失被提出如下:
其中,yi是来自序列x的第i个令牌。因此,KL损失将由语言模型给出的输出分布提炼到生成器。当考虑模式瓦解时,在一个或多个实施例中,关注的是保留来自pdata而不是来自pθ的实际数据的分布。因此,在优化等式(3)时,在一个或多个实施例中,仅采用来自真实数据分布pdata的样本来计算KL损失。由于上述协作训练损失,可以如下得出用于更新生成器的参数的梯度:
因此,可以认为在生成器上应用协作训练的效果等同于以加权方式增加真实数据的密度。
2.元协作优化实施例
在这一部分中,提出了用于对生成器模型参数的对抗训练损失和协作训练损失的优化进行交织的元学习范式的实施例。与致力于实现更快的学习、任务概括或得出自适应模型的常规元学习方法不同,这里的直觉是保留对抗文本生成器模型的生成分布以减速其模式瓦解。
为此,在一个或多个实施例中,将优化对抗损失建模为基本任务,并且将优化协作训练损失建模为元任务。通过这种设置,在一个或多个实施例中,元优化方案确保在用对抗训练损失来优化生成器参数值θ以便提高生成质量之后,所得参数将表现出对模式瓦解的相当大的抵抗力,即,在保持相当大的生成多样性的同时提高生成质量。
正式地,在一个或多个实施例中,可以首先通过优化基本任务损失来对生成器参数θ进行一个梯度更新:
然后,在一个或多个实施例中,从真实数据分布中获得新样本:x~pdata并且根据更新的参数θ′来推断实际样本的元损失元梯度可以由λ>0加权并添加到基本任务梯度以更新参数θ。最后,在元协作训练范式的实施例下的对抗性更新可以公式化如下:
在以下方法1中给出了用于元协作训练的示例性完整方法的实施例。
方法1—元协作训练实施例
图2描绘了根据本公开的实施例的具有相关存储器的示例性生成系统。在并入新的观察xt后,系统通过应用自我关注机制将存储器Mt更新为Mt+1。应当注意,存储器矩阵Mt的每一行是存储器时隙(slot),并且表示查询,表示密钥(key),并且表示值。还应当注意,语言模型也可以是与生成器相同或相似的系统。
图3描绘了根据本公开的实施例的示例性鉴别器系统。在一个或多个实施例中,鉴别器300包括嵌入层、一个或多个卷积层、自我关注层、一个或多个卷积层、线性层和分对数输出。
图4描绘了根据本公开的实施例的GAN系统的概述,并且图5描绘了根据本公开的实施例的用于训练生成器模型的Meta-CoTGAN方法。在一个或多个实施例中,用于训练生成器的计算机实现的方法可以包括以下步骤。可以对来自训练数据405的一组数据点410进行采样(505),并且使用包括一组生成器参数值的生成器模型415,可以生成(510)一组生成的数据点(例如,伪数据点)。使用鉴别器420,该鉴别器接收真实数据点和伪数据点并且试图在两者之间进行区分,可以使用对抗训练损失函数445来计算生成器模型的对抗损失。鉴别器模型的对抗损失和生成器模型的对抗损失可以通过使用最小-最大损失函数来获得。
在一个或多个实施例中,然后可以使用(515)对抗损失和梯度下降来确定生成器模型的一组中间生成器参数值。
在一个或多个实施例中,将从训练数据中采样的一组数据点用作以下各项的输入:(1)第二神经网络模型(例如语言模型425),其包括第二神经网络模型组的参数值;以及(2)使用一组中间生成器参数值的生成器模型415,计算(520)生成器模型的协作训练损失。在一个或多个实施例中,该协作训练损失可以然后用于确定(525)元梯度。
在一个或多个实施例中,使用对抗梯度来更新(530)一组生成器参数值,该对抗梯度是使用生成器模型的对抗损失和元梯度获得的。还可以使用第二神经网络模型的协作训练损失来更新(540)第二神经网络模型的第二神经网络模型组的参数值;并且可以使用鉴别器模型的对抗损失来更新(535)鉴别器模型的一组鉴别器参数值。
在一个或多个实施例中,该处理可以重复直到达到(545)停止条件为止;否则,如果已经达到停止条件,则输出具有其生成器参数值的最终更新集合的生成器模型(550),并且可以将其用于生成。接下来参考图6(如下)讨论经训练的生成器的示例性部署。
在一个或多个实施例中,图5的过程还可以包括初始化步骤。例如,可以初始化至少生成器模型的生成器参数值的集合和鉴别器模型的鉴别器参数值的集合,并且可以使用训练数据、生成器模型和鉴别器模型对生成器模型进行预训练。在一个或多个实施例中,可以使用最小-最大对抗训练来完成预训练。
在一个或多个实施例中,如前所述,第二神经网络模型和生成器模型可以共用相同的神经网络结构。因此,在一个或多个实施例中,来自预训练生成器模型的一组生成器参数值中的至少一些可以用作第二神经网络模型的参数值。还应当注意的是,第二神经网络模型首先用不同的值进行初始化。例如,可以首先使用随机值来初始化所有模型。
在一个或多个实施例中,使用协作训练损失来更新第二神经网络模型的第二神经网络模型组的参数值的步骤可以包括使用最大似然估计(MLE)损失函数。换句话讲,使用协作训练损失来更新第二神经网络模型的第二神经网络模型组的参数值的步骤包括最小化使用从训练数据采样的一组数据点的第二神经网络模型与使用从训练数据采样的数据点和从由生成器模型生成的数据点采样的数据点的混合的第二神经网络模型之间的Kullback-Leibler散度。在一个或多个实施例中,该混合可以是来自训练数据的相等数量或近似相等数量的数据点以及由生成器模型生成的数据点。
图6描绘了根据本公开的实施例的用于使用已经使用Meta-CoTGAN方法训练的生成器模型的方法。给定已经使用Meta-CoTGAN方法实施例训练的生成器模型,可以部署(605)生成器模型以便生成内容。因此,已训练和部署的Meta-CoTGAN生成器模型可以用于(610)生成输出。
D.实验结果
为了方便起见,通常可以将元协作训练生成对抗网络的实施例表示为Meta-CoTGAN。在实验部分中,首先,将一个实施例与另一个但不同的协作训练对应物CoT(Lu等人2019)在合成数据集上进行比较。然后,示出了实施例与两个文本生成数据集(数据集1和数据集2)上的若干基于RL和无RL的方法之间的比较结果。
应当注意,这些实验和结果仅通过说明的方式提供并且使用一个或多个具体实施例在具体条件下执行;因此,这些实验和它们的结果都不应被用来限制本专利文献的公开的范围。
1.实施细节
实施例是在RelGAN(由Weili Nie、Nina Narodytska和Ankit Patel在2019年的国际学习表征会议(ICLR)中的“RelGAN:用于文本生成的关系生成对抗网络(RelGAN:Relational Generative Adversarial Networks For Text Generation)”中提出,该文献通过援引以其全部内容并入本文)、是最先进的方法之一的无RL的对抗文本生成模型之上实现的。应当注意,可以使用其他生成对抗网络。具体地,Rel-GAN采用关系记忆来对输入令牌之间的长距离依赖性建模,并且采用gumbel-softmax松弛来克服生成器训练中的不可微分问题。关系存储器采用1个存储器时隙、带有2头的多头注意力,并且注意密钥大小设置为512。用于协作训练的语言模型采用与生成器相同的网络体系结构,并且在进行预训练后将生成器的参数的权重分配给语言模型。鉴别器采用大小为64的多种表示。在测试实施例中,Adam被用作用于更新所有模型参数的优化算法。
2.评估指标
为了比较,同时根据样本质量和样本多样性来评估各种模型。在当今大多数文本生成工作之后,通过在数据集上进行测试时的BLEU评分指标来评估样本质量,并且在合成数据集上进行测试时通过NLLoracle损失来评估样本质量。NLLoracle损失被定义为从目标LSTM模型得出的由Gθ生成的数据的负对数似然。根据NLLgen损失评估样本多样性,其呈以下形式:
其中,在生成器模型上评估真实数据的密度。因此,具有更好样本多样性的模型将在实际数据空间上具有更广泛的覆盖范围,并且导致更低的NLLgen损失。遭受严重模式瓦解的模型将不再很好地表示真实数据,并且导致更高的NLLgen损失。
3.基准模型
为了评估测试实施例的效率,考虑了MLE以及基于RL的基准,包括SeqGAN、RankGAN和LeakGAN。另外,还与最相关的无RL基准RelGAN进行了比较。在评估过程中,遵循在Rel-GAN中提出的温度设置,并且此处提出了用100和1000的温度评估时所测试方法实施例的结果。
4.合成数据集
第一评估域是合成oracle数据集。该实验采用随机初始化的长短期(LSTM)模型作为目标模型以模拟真实世界的序列,并且从真实数据分布中生成数据。进行的合成实验的序列长度被设置为20。在该域中进行实验的目的是将被测试的实施例与其最接近的协作训练对应物CoT进行比较。虽然可以认为这两种模型采用相似的方式来训练语言模型,但是调查了在这两种方法中提出的在生成器模型上采用相应的协作训练损失的效率。
在图7中证明了NLLoracle损失的学习曲线。总体而言,测试模型实施例收敛到比CoT明显更好的标准。应注意,CoT没有任何预训练阶段,并且其NLLoracle损失会逐渐减少。所测试的方法实施例采用预训练阶段,并且在预训练阶段和对抗训练阶段两者中的损失都减少。应当注意,在收敛时,用于测试方法实施例的NLLoracle损失显著低于CoT。这表明,在样本质量方面,由CoT提出的协作训练机制与测试方法实施例没有可比性。下表1中呈现NLLoracle和NLLgen的评估评分。当比较NLLgen时,所测试的方法实施例实现了比CoT低得多的损失规模。这证明了本文的方法实施例在保留样本多样性方面传达了更高的效率。总体而言,考虑到该模型的性能较差且训练时间较长,因此在以下现实世界数据集实验中未对此进行进一步考虑。
表1:序列长度为20的合成oracle的评估结果。对于CoT,呈现它们对于NLLgen的最佳评分。
方法 | NLL<sub>oracle</sub> | NLL<sub>gen</sub> |
CoT | 8.19 | 7.54 |
Meta-CoTGAN实施例 | 7.69 | 6.86 |
5.数据集1
第二评估域使用真实世界数据集(数据集1),其涉及图像字幕。在朱耀明(YaomingZhu)、陆思迪(Sidi Lu)、郑磊(Lei Zheng)、郭佳贤(Jiaxian Guo)、张卫南(WeinanZhang)、王军(Jun Wang)和余勇(Yong Yu)在2018年6月的SIGIR‘18:第41届国际ACM SIGIR信息检索研究与发展会议(第1097 1100页)中的“Texygen:用于文本生成模型的基准测试平台(Texygen:A Benchmarking Platform for Text Generation Models)”(该文献通过援引以其全部内容并入本文)中提出预处理方法。训练和测试集分别包含大约10,000个句子。句子的最小长度为7,并且最大长度为37。词汇量大约为4,700。
在表2(在图8中)中呈现用于测量样本质量的BLEU-2至BLEU-5的评分以及用于测量样本多样性的NLLgen评分。对于RelGAN和Meta-CoTGAN,将温度(在括号中)设置为100和1000,并且将结果经6次运行取平均(随机种子)。对于NLLgen(最后一列),越小越好。总体而言,所测试的方法实施例显示出优于所有样本质量/多样性指标的显著优势。值得注意的是,所测试的方法实施例导致NLLgen损失明显低于其他基准方法。这表明实施例可以为对抗训练提供对模式瓦解的有效控制,并且最终导致优异的样本多样性。在减少模式瓦解的同时,协同训练也可以产生具有更好的样本质量的模型。
为了进一步验证这一点,在图9中呈现用于样本多样性指标和作为代表性样本质量指标的BLEU-5的学习曲线。图9展示了测试方法实施例以及数据集1上的基准RelGAN的质量-多样性折衷。与RelGAN相比,Meta-CoTGAN实施例逐步获得了更好的BLEU-5评分,其中模式瓦解的进程明显较慢。用于RelGAN的BLEU-5被绘制到对应的NLLgen损失达到其报告标准时的那一点。否则,由于模型已变成严重的模式瓦解(即生成重复的句子),因此BLEU-5评分不再有意义。
已经观察到,用于RelGAN的NLLgen将迅速上升,这是模式瓦解的迹象。然而,对于Meta-CoTGAN实施例,进展相当缓慢。它表明本文的方法实施例可以有效地减速模式瓦解并且控制由于爆炸而产生的NLLgen损失。当研究样本质量指标时,观察到用于RelGAN的BLEU-5评分将比Meta-CoTGAN实施例上升得更快。但是最终,测试的模型实施例实现了比RelGAN显著更高的标准。另外,观察到当用于RelGAN的NLLgen爆炸时(例如,在400个时期之后),重复率相当高,并且因此生成器就变得无用了。然而,测试方法实施例保留了好得多的多样性。另外,从生成的真实句子中观察到,测试的模型实施例可以生成相当长的句子,而大多数GAN模型都达不到目的。
6.数据集2
第三评估域是另一个数据集(数据集2),其大小比数据集1大得多。数据集2包含270,000个句子的训练集和10,000个句子的测试集。句子的最大长度为51,并且词汇量为大约5,250。使用数据集2的结果在表3(在图10中)中呈现。
可以看出,就所有的BLEU指标和NLLgen而言,测试的Meta-CoTGAN实施例始终优于所有基准。在100的温度设置下,测试的方法实施例在BLEU-4/BLEU-5上优于强RelGAN基准0.041/0.039。明显地,当NLLgen损失处于显著低于RelGAN的水平时,获得用于测试方法实施例的最佳BLEU评分。这表明通过进行协作训练,可以同时得出具有更好样本质量和样本多样性的生成器模型。此外,它表明实施例可以在相当具有挑战性和多样化的现实世界数据集中稳健地良好执行。同时,所测试的方法实施例的性能非常稳健,在所有评估指标上在两个温度设置下均始终优于Rel-GAN。通过调查生成的真实样本,可以观察到生成的句子传达了相当多的语义,并且输出包含相当长的句子,这与常规的对抗文本生成器不同,常规的对抗文本生成器很快就会落入生成简短和重复句子的阶段。
E.消融研究
1.协作训练语言模型的影响
在该部分中,展示了使用在线更新的语言模型进行协作训练过程的实施例的影响。为此,直接比较是使用未用协作训练更新的预训练的语言模型。我们将这种基准表示为Meta-CoTGANcot-off。数据集1的结果呈现在表4(在图11中)中。可以观察到,当关闭对语言模型的在线更新时,该模型仍保留了NLLgen方面的可比较的样本多样性,因为在实际数据上仍采用了协作训练损失。然而,在两个温度设定下,样本质量指标表现得没有整组测试方法实施例那么好。这表明与生成器一起共同更新语言模型以使其向生成器提供平稳变化的目标分布是有益的。
2.元优化的影响
还评估了元优化设置的影响。为此,将实施例与采用协同训练损失来优化生成器参数的原理方法进行了比较,该方法以通过加权方式对对抗性损失和协同训练损失进行线性求和的形式来提出,即该基准表示为Meta-CoTGANmeta-off。结果在表4(图11)中示出。总体而言,Meta-CoTGANmeta-off获得NLLgen的可比较的评分。然而,就样本质量指标而言,其性能仍比使用整套解决方案更差。因此,可以得出结论,元优化是用于平衡质量-多样性折衷的重要组成部分。直观地,元优化设置实施例提供了有效的方式来确保在对抗性更新之后生成器参数将从模式瓦解中减速,这对于得出优异的性能很重要。
F.一些结论
本文提出的是用于促进对抗生成模型的训练的元协作训练方法的实施例。实施例利用协作训练的第二模型(例如,语言模型)来经由将实际数据上的第二模型的预测输出分布提炼到对抗生成器模型来有效地减速对抗训练的模式瓦解。使用合成数据集和两个现实世界数据集(具有的序列长度在7到51的范围内)两者来评估所提出方法的实施例。因此,经测试的方法同时在样本质量指标和样本多样性指标上始终优于基准算法。该方法的实施例是通用的,并且可以与面临模式瓦解问题的基于RL或无RL的不同的对抗文本生成算法一起应用。元协作训练的实施例也可以应用于或适应于更多新兴的基于RL的/无GAN的模型。
G.计算系统实施例
在一个或多个实施例中,本专利文献的方面可涉及、可包括一个或多个信息处理系统/计算系统,或者可在一个或多个信息处理系统(或计算系统)上实现。信息处理系统/计算系统可包括可操作来计算、运算、确定、分类、处理、传输、接收、检索、发起、路由、交换、存储、显示、通信、显现、检测、记录、再现、处理或利用任何形式信息、智能或数据的任何手段或手段的组合。例如,计算系统可以是或可包括个人计算机(例如,膝上型计算机)、平板电脑、移动设备(例如,个人数字助理(PDA)、智能手机、平板手机、平板等)、智能手表、服务器(例如,刀片式服务器或机架式服务器)、网络存储设备、摄像机或任何其它合适设备,并且可在大小、形状、性能、功能和价格方面改变。计算系统可包括随机存取存储器(RAM)、一个或多个处理资源(诸如中央处理器(CPU)或硬件或软件控制逻辑)、只读存储器(ROM)和/或其它类型的存储器。计算系统的另外组件可包括一个或多个驱动器(例如,硬盘驱动器、固态驱动器或两者)、用于与外部设备通信的一个或多个网络端口、以及各种输入和输出(I/O)设备(例如键盘、鼠标、手写笔、触摸屏和/或视频显示器)。计算系统还可包括可操作为在各种硬件组件之间传输通信的一个或多个总线。
图12描绘了根据本公开的实施例的信息处理系统(或计算系统)的简化框图。应理解,计算系统可不同地配置并且包括不同组件,包括如图12中所示的更少或更多的部件,但应理解,针对系统1200所示出的功能可操作为支持计算系统的各种实施例。
如图12所示,计算系统1200包括一个或多个中央处理器(CPU)1201,CPU 1201提供计算资源并控制计算机。CPU 1201可用微处理器等实现,并且还可包括一个或多个图处理单元(GPU)1202和/或用于数学计算的浮点协处理器。在一个或多个实施例中,一个或多个GPU 1202可并入显示控制器1209内,诸如一个或多个图卡的一部分。系统1200还可包括系统存储器1219,系统存储器1219可包括随机存取存储器(RAM)、只读存储器(ROM)或两者。
如图12中所示,还可提供多个控制器和外围设备。输入控制器1203表示至各种输入设备1204的接口,例如键盘、鼠标、触摸屏和/或触笔。计算系统1200还可包括存储控制器1207,该存储控制器1207用于与一个或多个存储设备1208对接,存储设备中的每个包括存储介质(诸如磁带或盘)或光学介质(其可用于记录用于操作系统、实用工具和应用程序的指令的程序,它们可包括实施本公开的各方面的程序的实施例)。存储设备1208还可用于存储经处理的数据或是将要根据本公开处理的数据。系统1200还可包括显示控制器1209,该显示控制器1209用于为显示设备1211提供接口,显示设备1211可为阴极射线管(CRT)显示器、薄膜晶体管(TFT)显示器、有机发光二极管、电致发光面板、等离子面板或任何其它类型的显示器。计算系统1200还可包括用于一个或多个外围设备1206的一个或多个外围设备控制器或接口1205。外围设备的示例可包括一个或多个打印机、扫描仪、输入设备、输出设备、传感器等。通信控制器1214可与一个或多个通信设备1215对接,这使系统1200能够通过各种网络(包括互联网、云资源(例如以太云、经以太网的光纤通道(FCoE)/数据中心桥接(DCB)云等)、局域网(LAN)、广域网(WAN)、存储区域网络(SAN))中的任一网络,或通过任何合适电磁载波信号(包括红外信号)来连接至远程设备。如描绘的实施例中所示,计算系统1200包括一个或多个风扇或风扇托盘1218以及一个或多个冷却子系统控制器1217,其监视系统1200(或其组件)的热温度并操作风扇/风扇托盘1218以助于调节温度。
在示出的系统中,所有主要系统组件可连接至总线1216,总线1216可表示多于一个的物理总线。然而,各种系统组件可在物理上彼此接近或可不在物理上彼此接近。例如,输入数据和/或输出数据可远程地从一个物理位置传输到另一物理位置。另外,实现本公开的各方面的程序可经由网络从远程位置(例如,服务器)访问。此类数据和/或程序可通过各种机器可读介质中的任一机器可读介质来传送,机器可读介质包括例如:诸如硬盘、软盘和磁带的磁性介质;诸如光盘(CD)和全息设备的光学介质;磁光介质;以及专门配置成存储或存储并执行程序代码的硬件设备,诸如专用集合成电路(ASIC)、可编程逻辑器件(PLD)、闪存设备、其它非易失性存储器(NVM)设备(诸如基于XPoint的3D设备)、以及ROM和RAM设备。
本公开的方面可利用用于一个或多个处理器或处理单元以使步骤执行的指令在一个或多个非暂态计算机可读介质上编码。应注意,一个或多个非暂态计算机可读介质应包括易失性存储器和/或非易失性存储器。应注意,替代实现方式是可能的,其包括硬件实现方式或软件/硬件实现方式。硬件实施的功能可使用ASIC、可编程的阵列、数字信号处理电路等来实现。因此,任何权利要求中的术语“手段”旨在涵盖软件实现方式和硬件实现方式两者。类似地,如本文使用的术语“计算机可读媒介或介质”包括具有实施在其上的指令程序的软件和/或硬件或它们的组合。利用所构想的这些替代实现方式,应理解,附图以及随附描述提供本领域的技术人员编写程序代码(即,软件)和/或制造电路(即,硬件)以执行所需处理所要求的功能信息。
应注意,本公开的实施例还可涉及具有其上具有用于执行各种计算机实施的操作的计算机代码的非暂态有形计算机可读介质的计算机产品。介质和计算机代码可为出于本公开的目的而专门设计和构造的介质和计算机代码,或者它们可为相关领域中的技术人员已知或可用的。有形计算机可读介质的示例包括例如:诸如硬盘、软盘和磁带的磁性介质;诸如CD和全息设备的光学介质;磁光介质;以及专门配置成存储或存储并执行程序代码的硬件设备,诸如ASIC、可编程逻辑器件(PLD)、闪存设备、其它非易失性存储器(NVM)设备(诸如基于XPoint的3D设备)、以及ROM和RAM设备。计算机代码的示例包括机器代码(例如,编译器产生的代码)以及包含可由计算机使用解释器来执行的更高级代码的文件。本公开的实施例可整体地或部分地实施为可在由处理设备执行的程序模块中的机器可执行指令。程序模块的示例包括库、程序、例程、对象、组件和数据结构。在分布的计算环境中,程序模块可物理上定位在本地、远程或两者的设定中。
本领域的技术人员将认识到,计算系统或编程语言对本公开的实践来说均不重要。本领域的技术人员将还将认识到,多个上述元件可物理地和/或在功能上划分成模块和/或子模块或组合在一起。
本领域技术人员将理解,前文的示例和实施例是示例性的,并且不限制本公开的范围。旨在说明的是,在本领域的技术人员阅读本说明书并研究附图后将对本领域的技术人员显而易见的本公开的所有、置换、增强、等同、组合或改进包括在本公开的真实精神和范围内。还应注意,任何权利要求书的元素可不同地布置,包括具有多个从属、配置和组合。
Claims (25)
1.一种用于训练生成器的计算机实现的方法,其包括:
响应于尚未达到停止条件,执行步骤,所述步骤包括:
从训练数据中采样一组数据点;
使用包括一组生成器参数值的生成器模型来生成一组生成的数据点;
使用对抗训练损失函数来计算所述生成器模型的对抗损失;
使用所述对抗损失和梯度下降来确定用于所述生成器模型的一组中间生成器参数值;
使用从所述训练数据中采样的所述一组数据点作为到包括第二神经网络模型组的参数值的第二神经网络模型的输入以及到包括所述一组中间生成器参数值的所述生成器模型的输入,计算所述生成器模型的协作训练损失;
使用所述协作训练损失来确定元梯度;
使用对抗梯度来更新所述一组生成器参数值,所述对抗梯度是使用所述生成器模型的所述对抗损失和所述元梯度获得的;
使用鉴别器模型的对抗损失来更新所述鉴别器模型的一组鉴别器参数值;以及
使用所述第二神经网络模型的协作训练损失来更新所述第二神经网络模型的所述第二神经网络模型组的参数值;以及
响应于已达到所述停止条件,输出所述生成器模型,所述生成器模型包括生成器参数值的最终更新的集合。
2.如权利要求1所述的计算机实现的方法,其还包括以下初始步骤:
至少初始化所述生成器模型的所述一组生成器参数值以及所述鉴别器模型的所述一组鉴别器参数值;以及
使用所述训练数据、所述生成器模型和所述鉴别器模型对所述生成器模型进行预训练。
3.如权利要求2所述的计算机实现的方法,其中所述第二神经网络模型和所述生成器模型共用相同的神经网络结构,并且所述方法还包括:
使用来自所预训练的生成器模型的所述一组生成器参数值中的至少一些作为所述第二神经网络模型的参数值。
4.如权利要求1所述的计算机实现的方法,其中使用所述第二神经网络模型组的协作训练损失来更新所述第二神经网络模型的所述第二神经网络模型组的参数值的步骤包括:
使用最大似然估计损失函数。
5.如权利要求4所述的计算机实现的方法,其中使用所述第二神经网络模型组的协作训练损失来更新所述第二神经网络模型的所述第二神经网络模型组的参数值的步骤还包括:
最小化使用从所述训练数据采样的所述一组数据点的所述第二神经网络模型与使用从所述训练数据采样的数据点和从由所述生成器模型生成的数据点采样的数据点的混合的所述第二神经网络模型之间的Kullback-Leibler散度。
6.如权利要求5所述的计算机实现的方法,其中所述混合包括来自所述训练数据的相等数量或近似相等数量的数据点以及由所述生成器模型生成的数据点。
7.如权利要求1所述的计算机实现的方法,其中所述鉴别器模型的所述对抗损失和所述生成器模型的所述对抗损失通过使用最小-最大损失函数来获得。
8.一种系统,其包括:
一个或多个处理器;以及
非暂时性计算机可读介质或媒介,包括一个或多个指令集,所述指令集在由所述一个或多个处理器中的至少一者执行时致使执行包括以下各项的步骤:
响应于尚未达到停止条件,执行步骤,所述步骤包括:
从具有第一分布的训练数据中采样一组数据点;
使用包括一组生成器参数值的生成器模型来生成一组生成的数据点;
使用对抗训练损失函数来计算所述生成器模型的对抗损失;
使用所述对抗损失和梯度下降来确定用于所述生成器模型的一组中间生成器参数值;
使用从所述训练数据中采样的所述一组数据点作为到包括第二神经网络模型组的参数值的第二神经网络模型的输入以及到包括所述一组中间生成器参数值的所述生成器模型的输入,计算所述生成器模型的协作训练损失;
使用所述生成器模型的所述协作训练损失来确定元梯度;
使用对抗梯度来更新所述一组生成器参数值,所述对抗梯度是使用所述生成器模型的所述对抗损失和所述元梯度获得的;
使用鉴别器模型的对抗损失来更新所述鉴别器模型的一组鉴别器参数值;以及
使用所述第二神经网络模型的协作训练损失来更新所述第二神经网络模型的所述第二神经网络模型组的参数值;以及
响应于已达到所述停止条件,输出所述生成器模型,所述生成器模型包括生成器参数值的最终更新的集合。
9.如权利要求8所述的系统,其中所述非暂时性计算机可读介质或媒介还包括一个或多个指令集,所述指令集在由所述一个或多个处理器中的至少一者执行时致使执行包括以下各项的步骤:
至少初始化所述生成器模型的所述一组生成器参数值以及所述鉴别器模型的所述一组鉴别器参数值;以及
使用所述训练数据、所述生成器模型和所述鉴别器模型对所述生成器模型进行预训练。
10.如权利要求9所述的系统,其中所述第二神经网络模型和所述生成器模型共用相同的神经网络结构,并且所述非暂时性计算机可读介质或媒介还包括一个或多个指令集,所述指令集在由所述一个或多个处理器中的至少一者执行时致使执行包括以下各项的步骤:
使用来自所预训练的生成器模型的所述一组生成器参数值中的至少一些作为所述第二神经网络模型的参数值。
11.如权利要求8所述的系统,其中使用所述第二神经网络模型组的协作训练损失来更新所述第二神经网络模型的所述第二神经网络模型组的参数值的步骤包括:
使用最大似然估计损失函数。
12.如权利要求11所述的系统,其中使用所述第二神经网络模型组的协作训练损失来更新所述第二神经网络模型的所述第二神经网络模型组的参数值的步骤还包括:
最小化使用从所述训练数据采样的所述一组数据点的所述第二神经网络模型与使用从所述训练数据采样的数据点和从由所述生成器模型生成的数据点采样的数据点的混合的所述第二神经网络模型之间的Kullback-Leibler散度。
13.如权利要求12所述的系统,其中所述混合包括来自所述训练数据的相等数量或近似相等数量的数据点以及由所述生成器模型生成的数据点。
14.如权利要求8所述的系统,其中所述鉴别器模型的所述对抗损失和所述生成器模型的所述对抗损失通过使用最小-最大损失函数来获得。
15.一种用于训练生成器的计算机实现的方法,其包括:
响应于尚未达到停止条件,执行步骤,所述步骤包括:
使用来自真实数据的训练数据集的一组数据点和来自生成对抗系统的生成器模型来生成一组生成的数据点,所述生成对抗系统包括具有一组生成器模型参数值的所述生成器模型以及具有一组鉴别器参数值的鉴别器模型;
使用对抗训练损失函数来计算所述生成器模型的对抗损失;
使用所述对抗损失和梯度下降来确定用于所述生成器模型的一组中间生成器参数值;
使用具有所述一组中间生成器参数值的所述生成器模型和第二神经网络模型来协同训练所述生成器模型,以减速所述生成器模型的模式瓦解;
使用所述鉴别器模型的对抗损失来更新所述鉴别器模型的一组鉴别器参数值;以及
使用所述第二神经网络模型的协作训练损失来更新所述第二神经网络模型的一组参数值;以及
响应于已达到所述停止条件,输出所述生成器模型。
16.如权利要求15所述的计算机实现的方法,其中使用具有所述一组中间生成器参数值的所述生成器模型和第二神经网络模型来协同训练所述生成器模型,以减速所述生成器模型的模式瓦解的步骤包括:
使用从真实数据的所述训练数据集中采样的所述一组数据点作为到所述第二神经网络模型和到包括所述一组中间生成器参数值的所述生成器模型的输入来计算所述生成器模型的协作训练损失;
使用所述生成器模型的所述协作训练损失来确定元梯度;以及
使用对抗梯度来更新所述一组生成器参数值,所述对抗梯度是使用所述生成器模型的所述对抗损失和所述元梯度获得的。
17.如权利要求15所述的计算机实现的方法,其还包括以下初始步骤:
至少初始化所述生成器模型的所述一组生成器参数值以及所述鉴别器模型的所述一组鉴别器参数值;
使用所述训练数据集和所述生成器模型和所述鉴别器模型对所述生成器模型进行预训练;以及
使用来自所预训练的生成器模型的所述一组生成器参数值中的至少一些作为所述第二神经网络模型的参数值。
18.如权利要求15所述的计算机实现的方法,其中所述生成器是自然语言文本生成器,并且所述第二神经网络模型是语言模型。
19.如权利要求15所述的计算机实现的方法,其中使用所述第二神经网络模型组的协作训练损失来更新所述第二神经网络模型的所述第二神经网络模型组的参数值的步骤包括:
最小化使用从真实数据的所述训练数据集采样的所述一组数据点的所述第二神经网络模型与使用从真实数据的所述训练数据集采样的数据点和从由所述生成器模型生成的数据点采样的数据点的混合的所述第二神经网络模型之间的Kullback-Leibler散度。
20.如权利要求19所述的计算机实现的方法,其中所述混合包括来自所述训练数据集的相等数量或近似相等数量的数据点以及由所述生成器模型生成的数据点。
21.一种系统,其包括:
一个或多个处理器;以及
非暂时性计算机可读介质或媒介,包括一个或多个指令集,所述指令集在由所述一个或多个处理器中的至少一者执行时致使执行包括以下各项的步骤:
响应于尚未达到停止条件,执行步骤,所述步骤包括:
使用来自真实数据的训练数据集的一组数据点和来自生成对抗系统的生成器模型来生成一组生成的数据点,所述生成对抗系统包括具有一组生成器模型参数值的所述生成器模型以及具有一组鉴别器参数值的鉴别器模型;
使用对抗训练损失函数来计算所述生成器模型的对抗损失;
使用所述对抗损失和梯度下降来确定用于所述生成器模型的一组中间生成器参数值;
使用具有所述一组中间生成器参数值的所述生成器模型和第二神经网络模型来协同训练所述生成器模型,以减速所述生成器模型的模式瓦解;
使用所述鉴别器模型的对抗损失来更新所述鉴别器模型的一组鉴别器参数值;以及
使用所述第二神经网络模型的协作训练损失来更新所述第二神经网络模型的一组参数值;以及
响应于已达到所述停止条件,输出所述生成器模型。
22.如权利要求21所述的系统,其中使用具有所述一组中间生成器参数值的所述生成器模型和第二神经网络模型来协同训练所述生成器模型,以减速所述生成器模型的模式瓦解的步骤包括:
使用从真实数据的所述训练数据集中采样的所述一组数据点作为到所述第二神经网络模型和到包括所述一组中间生成器参数值的所述生成器模型的输入来计算所述生成器模型的协作训练损失;
使用所述生成器模型的所述协作训练损失来确定元梯度;以及
使用对抗梯度来更新所述一组生成器参数值,所述对抗梯度是使用所述生成器模型的所述对抗损失和所述元梯度获得的。
23.如权利要求21所述的系统,其中使用所述第二神经网络模型组的协作训练损失来更新所述第二神经网络模型的所述第二神经网络模型组的参数值的步骤包括:
最小化使用从真实数据的所述训练数据集采样的所述一组数据点的所述第二神经网络模型与使用从真实数据的所述训练数据集采样的数据点和从由所述生成器模型生成的数据点采样的数据点的混合的所述第二神经网络模型之间的Kullback-Leibler散度。
24.一种计算机程序产品,包括计算机程序,所述计算机程序在被处理器执行时实现根据权利要求1-7中任一项所述的方法。
25.一种计算机程序产品,包括计算机程序,所述计算机程序在被处理器执行时实现根据权利要求15-20中任一项所述的方法。
Applications Claiming Priority (4)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
US202062970638P | 2020-02-05 | 2020-02-05 | |
US62/970,638 | 2020-02-05 | ||
US17/136,054 | 2020-12-29 | ||
US17/136,054 US20210241099A1 (en) | 2020-02-05 | 2020-12-29 | Meta cooperative training paradigms |
Publications (2)
Publication Number | Publication Date |
---|---|
CN113222105A true CN113222105A (zh) | 2021-08-06 |
CN113222105B CN113222105B (zh) | 2024-07-26 |
Family
ID=77084745
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202110162379.6A Active CN113222105B (zh) | 2020-02-05 | 2021-02-05 | 元协作训练范式 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN113222105B (zh) |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN114925699A (zh) * | 2022-04-28 | 2022-08-19 | 电子科技大学 | 一种基于风格变换的高迁移性对抗文本生成方法 |
Citations (6)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US20170344808A1 (en) * | 2016-05-28 | 2017-11-30 | Samsung Electronics Co., Ltd. | System and method for a unified architecture multi-task deep learning machine for object recognition |
KR20180065620A (ko) * | 2016-12-08 | 2018-06-18 | 아크위드 주식회사 | 기계학습과 온톨로지 기반 사업모델 작성 방법 및 이를 위한 관리 서비스 시스템 |
US20180336471A1 (en) * | 2017-05-19 | 2018-11-22 | Mehdi Rezagholizadeh | Semi-supervised regression with generative adversarial networks |
US20190251952A1 (en) * | 2018-02-09 | 2019-08-15 | Baidu Usa Llc | Systems and methods for neural voice cloning with a few samples |
US20190295302A1 (en) * | 2018-03-22 | 2019-09-26 | Northeastern University | Segmentation Guided Image Generation With Adversarial Networks |
CN110634108A (zh) * | 2019-08-30 | 2019-12-31 | 北京工业大学 | 一种基于元-循环一致性对抗网络的复合降质网络直播视频增强方法 |
-
2021
- 2021-02-05 CN CN202110162379.6A patent/CN113222105B/zh active Active
Patent Citations (6)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US20170344808A1 (en) * | 2016-05-28 | 2017-11-30 | Samsung Electronics Co., Ltd. | System and method for a unified architecture multi-task deep learning machine for object recognition |
KR20180065620A (ko) * | 2016-12-08 | 2018-06-18 | 아크위드 주식회사 | 기계학습과 온톨로지 기반 사업모델 작성 방법 및 이를 위한 관리 서비스 시스템 |
US20180336471A1 (en) * | 2017-05-19 | 2018-11-22 | Mehdi Rezagholizadeh | Semi-supervised regression with generative adversarial networks |
US20190251952A1 (en) * | 2018-02-09 | 2019-08-15 | Baidu Usa Llc | Systems and methods for neural voice cloning with a few samples |
US20190295302A1 (en) * | 2018-03-22 | 2019-09-26 | Northeastern University | Segmentation Guided Image Generation With Adversarial Networks |
CN110634108A (zh) * | 2019-08-30 | 2019-12-31 | 北京工业大学 | 一种基于元-循环一致性对抗网络的复合降质网络直播视频增强方法 |
Non-Patent Citations (2)
Title |
---|
J.XU, AT EL.: ""DP‑GAN:Diversity‑ Promoting Generative Adversarial Network for Generating Informative and Diversified Text"", ARXIV * |
WEILI NIE, AT EL.: ""RelGAN: Relational Generative Adversarial Networks For Text Generation"", ICLR * |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN114925699A (zh) * | 2022-04-28 | 2022-08-19 | 电子科技大学 | 一种基于风格变换的高迁移性对抗文本生成方法 |
Also Published As
Publication number | Publication date |
---|---|
CN113222105B (zh) | 2024-07-26 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
Wang et al. | Self-tuning for data-efficient deep learning | |
Nguyen et al. | The application of machine learning and deep learning in sport: predicting NBA players’ performance and popularity | |
Rere et al. | Metaheuristic algorithms for convolution neural network | |
Kant et al. | Practical text classification with large pre-trained language models | |
Zhang et al. | Towards efficient data free black-box adversarial attack | |
US11983617B2 (en) | Scalable and compressive neural network data storage system | |
Li et al. | Insufficient data can also rock! learning to converse using smaller data with augmentation | |
US20210241099A1 (en) | Meta cooperative training paradigms | |
CN112015904B (zh) | 确定文档语料库的潜在主题的方法、系统和计算机可读介质 | |
Zhu et al. | Distance based multiple kernel ELM: A fast multiple kernel learning approach | |
Rusak et al. | If your data distribution shifts, use self-learning | |
Paul et al. | Lottery tickets on a data diet: Finding initializations with sparse trainable networks | |
US20190228310A1 (en) | Generation of neural network containing middle layer background | |
Zhang et al. | Badlabel: A robust perspective on evaluating and enhancing label-noise learning | |
CN113222105B (zh) | 元协作训练范式 | |
Jiang et al. | Dynamic loss for robust learning | |
Zhao et al. | Balanced and accurate pseudo-labels for semi-supervised image classification | |
CN114190102B (zh) | 用于多目标排序的系统、计算机实施的方法和非暂时性计算机可读介质 | |
Ahn et al. | Fine tuning pre trained models for robustness under noisy labels | |
Zhu et al. | A hybrid model for nonlinear regression with missing data using quasilinear kernel | |
US20210383226A1 (en) | Cross-transformer neural network system for few-shot similarity determination and classification | |
Ouyang et al. | Supervised contrastive learning with corrected labels for noisy label learning | |
Chen et al. | Automatic Noise Generation and Reduction for Text Classification | |
Tao et al. | Cross‐Corpus Speech Emotion Recognition Based on Transfer Learning and Multi‐Loss Dynamic Adjustment | |
Fallahian et al. | Beyond Noise: Incorporating Pre-Trained Contractive Autoencoders for Enhanced GAN-based Tabular Data Creation |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |