CN112016611A - 生成器网络和策略生成网络的训练方法、装置和电子设备 - Google Patents

生成器网络和策略生成网络的训练方法、装置和电子设备 Download PDF

Info

Publication number
CN112016611A
CN112016611A CN202010867110.3A CN202010867110A CN112016611A CN 112016611 A CN112016611 A CN 112016611A CN 202010867110 A CN202010867110 A CN 202010867110A CN 112016611 A CN112016611 A CN 112016611A
Authority
CN
China
Prior art keywords
network
training
vector
prediction
generator
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202010867110.3A
Other languages
English (en)
Inventor
白沁洵
尼尔·拉茨拉夫
徐伟
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Nanjing Horizon Robotics Technology Co Ltd
Original Assignee
Nanjing Horizon Robotics Technology Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Nanjing Horizon Robotics Technology Co Ltd filed Critical Nanjing Horizon Robotics Technology Co Ltd
Publication of CN112016611A publication Critical patent/CN112016611A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • G06F18/241Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
    • G06F18/2415Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches based on parametric or probabilistic models, e.g. based on likelihood ratio or false acceptance rate versus a false rejection rate
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N20/00Machine learning

Abstract

公开了一种生成器网络的训练方法、用于增强学习的策略生成网络的训练方法、装置和电子设备。该生成器网络的训练方法通过生成器网络的每个网络单元生成预测网络的一层,并基于预测网络所预测的状态向量的概率分布与真实的状态向量之间的KL散度值更新该生成器网络的参数。且该用于增强学习的策略生成网络的训练方法基于生成器网络生成的多个预测网络所预测的多个状态向量之间的差异来计算用于增强学习的策略生成网络的内在奖励函数值,以训练策略生成网络。这样,提高了生成器网络的性能和策略生成网络的探索效率。

Description

生成器网络和策略生成网络的训练方法、装置和电子设备
技术领域
本公开涉及增强学习技术领域,且更为具体地,涉及一种生成器网络的训练方法、用于增强学习的策略生成网络的训练方法、装置和电子设备。
背景技术
近来,增强学习(RL)在很多应用领域内取得了成功,包括在各种游戏中展现出超越人的性能,此外在机器人控制任务和基于图像的控制任务也表现出优异的性能。
但是,尽管取得了很多成就,当前的增强学习技术遭受采样效率不良的影响,对于实际的执行任务的对象来说,通常在实现合理的性能之前需要进行几百万甚至几十亿的模拟步骤的训练。因此,这种统计效率的缺乏使得增强学习难以应用于实际世界的任务,因为在实际世界的任务中用于执行任务的对象的动作成本要远高于在模拟器中模拟的执行任务的对象的模拟动作的成本。
也就是,在当前的实际的增强学习的任务中,用于生成执行任务的对象的动作的策略生成网络的训练方案需要进行改进,以提高执行任务的对象的动作的有效性。
发明内容
为了解决上述技术问题,提出了本公开。本公开的实施例提供了一种生成器网络的训练方法、用于增强学习的策略生成网络的训练方法、装置和电子设备,其通过预测的状态向量的概率分布与真实的状态向量的概率分布之间的KL散度值更新用于生成预测网络的生成器网络的参数,并基于生成器网络生成的多个预测网络所预测的多个状态向量之间的差异来计算用于增强学习的策略生成网络的内在奖励函数值,由于内在奖励函数值能够体现出策略生成网络相对于动态环境状态的认知的不确定性,通过其训练策略生成网络,可以促进策略生成网络对环境的探索,从而提高了策略生成网络的探索效率。
根据本公开的一方面,提供了一种生成器网络的训练方法,包括:获取用于增强学习任务的训练用当前状态向量、训练用动作向量和与所述训练用当前状态向量和训练用动作向量对应的训练用下一状态向量以及由其确定的真实的后验概率分布;将已知概率分布的一组随机噪声向量输入生成器网络以获得一组预测网络,所述生成器网络包括多个网络单元,每个网络单元用于生成所述预测网络的一层;将所述训练用当前状态向量和所述训练用动作向量输入所述一组预测网络以获得预测性的概率分布;确定所述预测性的概率分布与所述真实的后验概率分布之间的KL散度值;以及,基于所述KL散度值来更新所述生成器网络的参数。
根据本公开的另一方面,提供了一种用于增强学习的策略生成网络的训练方法,包括:获取如上所述的生成器网络的训练方法训练的生成器网络;由所述生成器网络生成N个预测网络;获取当前状态向量和由策略生成网络生成的动作向量;将所述当前状态向量和所述动作向量输入所述N个预测网络以获得N个下一状态向量;基于所述N个下一状态向量之间的差异计算用于增强学习的内在奖励函数值;以及,基于所述内在奖励函数值更新所述策略生成网络的参数。
根据本公开的再一方面,提供了一种生成器网络的训练装置,包括:训练向量获取单元,用于获取用于增强学习任务的训练用当前状态向量、训练用动作向量和与所述训练用当前状态向量和训练用动作向量对应的训练用下一状态向量以及由其确定的真实的后验概率分布;预测网络生成单元,用于将已知概率分布的一组随机噪声向量输入生成器网络以获得一组预测网络,所述生成器网络包括多个网络单元,每个网络单元用于生成所述预测网络的一层;向量预测单元,用于将所述训练向量获取单元所获取的所述训练用当前状态向量和所述训练用动作向量输入所述预测网络生成单元所生成的所述一组预测网络以获得预测性的概率分布;散度值确定单元,用于确定所述向量预测单元所获得的所述预测性的概率分布与训练向量获取单元所获取的所述真实的后验概率分布之间的KL散度值;以及,生成器更新单元,用于基于所述散度值确定单元所确定的所述KL散度值来更新所述生成器网络的参数。
根据本公开的又一方面,提供了一种用于增强学习的策略生成网络的训练装置,包括:网络获取单元,用于获取如上所述的生成器网络的训练装置训练的生成器网络;网络生成单元,用于由所述网络获取单元所获取的所述生成器网络生成N个预测网络;向量获取单元,用于获取当前状态向量和由策略生成网络生成的动作向量;预测向量获得单元,用于将所述向量获取单元所获取的所述当前状态向量和所述动作向量输入所述网络生成单元所生成的所述N个预测网络以获得N个下一状态向量;奖励函数计算单元,用于基于所述预测向量获得单元所获得的所述N个下一状态向量之间的差异计算用于增强学习的内在奖励函数值;以及,网络更新单元,用于基于所述奖励函数计算单元所计算的所述内在奖励函数值更新所述策略生成网络的参数。
根据本公开的再一方面,提供了一种电子设备,包括:处理器;以及,存储器,在所述存储器中存储有计算机程序指令,所述计算机程序指令在被所述处理器运行时使得所述处理器执行如上所述的生成器网络的训练方法或者如上所述的用于增强学习的策略生成网络的训练方法。
根据本公开的又一方面,提供了一种计算机可读介质,其上存储有计算机程序指令,所述计算机程序指令在被处理器运行时使得所述处理器执行如上所述的生成器网络的训练方法或者如上所述的用于增强学习的策略生成网络的训练方法。
本公开的实施例提供的生成器网络的训练方法、装置和电子设备,通过生成器网络的每个网络单元生成预测网络的一层,并基于预测网络所预测的状态向量的概率分布与真实的后验概率分布之间的KL散度值更新该生成器网络的参数,可以通过改变所使用的网络单元的数目来生成任意层数的预测网络,增强了生成器网络的灵活性,并且,基于KL散度值可以不更新预测网络的参数而直接更新生成器网络的参数,使得生成器网络的训练过程简单。
另外,本公开的实施例提供的用于增强学习的策略生成网络的训练方法、装置和电子设备基于生成器网络生成的多个预测网络所预测出的状态向量之间的差异,即基于多个预测网络之间的贝叶斯不确定性来计算用于增强学习的策略生成网络的内在奖励函数值,以训练策略生成网络,由于内在奖励函数值能够体现出策略生成网络相对于动态环境状态的认知的不确定性,通过其训练策略生成网络,可以在增强学习的过程中,促进策略生成网络对环境的探索,从而策略生成网络能够在不同环境中更有效地探索,提高了训练的策略生成网络的探索效率。
附图说明
通过结合附图对本公开实施例进行更详细的描述,本公开的上述以及其他目的、特征和优势将变得更加明显。附图用来提供对本公开实施例的进一步理解,并且构成说明书的一部分,与本公开实施例一起用于解释本公开,并不构成对本公开的限制。在附图中,相同的参考标号通常代表相同部件或步骤。
图1图示了标准的增强学习模型的示意图。
图2图示了根据本公开实施例的生成器网络的训练方法的流程图。
图3图示了根据本公开实施例的生成器网络的架构的示意图。
图4图示了根据本公开实施例的用于增强学习的策略生成网络的训练方法的流程图。
图5图示了根据本公开实施例的生成器网络的训练装置的框图。
图6图示了根据本公开实施例的用于增强学习的策略生成网络的训练装置的框图。
图7图示了根据本公开实施例的电子设备的框图。
具体实施方式
下面,将参考附图详细地描述根据本公开的示例实施例。显然,所描述的实施例仅仅是本公开的一部分实施例,而不是本公开的全部实施例,应理解,本公开不受这里描述的示例实施例的限制。
申请概述
图1图示了标准的增强学习模型的示意图。如图1所示,策略生成网络N生成动作A,环境的当前状态S0基于动作A迁移到环境的下一状态S1,且p用于表示当前状态到下一状态的迁移概率。另外,奖励函数r输入到策略生成网络N,用于由策略生成网络N更新其生成动作A的策略,例如,通常目的是最大化奖励函数的累积数值。
在增强学习中,为了构建能够有效地生成动作的策略,一个方面是有效地利用采集到的数据,包括环境的状态数据和动作数据。这通常又包含三个关键方面,泛化(generalization),探索(exploration)和长期后果认识(long-term consequenceawareness)。
这里,执行任务的对象(agent)通过策略生成网络生成的策略来执行动作,因此,对环境的探索既可以称为执行任务的对象对环境的探索,也可以称为策略对环境的探索,或者也可以称为策略生成网络对环境的探索。并且,由于执行任务的对象对环境的探索实质上也是通过执行策略生成网络生成的策略进行的,因此归根结底是要提高策略生成网络的探索性能。
在增强学习领域中,探索指的是获取更多的关于环境的信息。在如上所述的标准的增强学习模型中,策略生成网络对于每个环境状态的迁移步骤接收奖励函数r,在很多情况下,该奖励函数r是外部奖励函数。但是,在很多任务情况下,存在外部奖励稀疏或者几乎可忽略的环境,在这种情况下,需要来以某种类型的内在奖励驱动策略生成网络的探索,这可以认为是有效的累积关于环境的信息的纯粹的探索问题。
因此,本公开的基本构思是通过贝叶斯不确定性的估计来促进增强学习模型的策略生成网络的有效探索,这里,贝叶斯不确定性的估计值能够以非参数方式特性化策略生成网络相对于动态环境状态的认知的不确定性。具体地,贝叶斯不确定性可以表示为不同的预测网络从当前状态和当前动作预测出的下一状态的不确定性。这样,通过引入贝叶斯不确定性的估计值来用于增强学习中的内在奖励值,可以促进增强学习中策略生成网络基于内在奖励值的探索,提高策略生成网络的探索性能。
并且,在本公开中,为了估计贝叶斯不确定性,需要生成从当前状态和动作预测下一状态的多个预测模型,因此,在本公开的用于增强学习的策略生成网络的训练方法中,使用生成器网络来获得N个预测网络,并通过N个预测网络从当前状态和当前动作预测N个下一状态,以计算N个预测出的下一状态之间的差异,即贝叶斯不确定性的估计值。
具体地,本公开提供的生成器网络的训练方法、装置和电子设备首先获取用于增强学习任务的训练用当前状态向量、训练用动作向量和与所述训练用当前状态向量和训练用动作向量对应的训练用下一状态向量以及由其确定的真实的后验概率分布;然后将已知概率分布的一组随机噪声向量输入生成器网络以获得一组预测网络,所述生成器网络包括多个网络单元,每个网络单元用于生成所述预测网络的一层;再将所述训练用当前状态向量和所述训练用动作向量输入所述一组预测网络以获得预测性的概率分布;之后确定所述预测性的概率分布与所述真实的后验概率分布之间的KL散度值;最后基于所述KL散度值来更新所述生成器网络的参数。
这样,本公开提供的生成器网络的训练方法能够使得生成器网络更加灵活,也就是,通过网络单元生成预测网络的一层,可以生成任意层数的预测网络。并且,在更新生成器网络时,仅需要更新生成器网络的参数,而不需要更新由生成器网络生成的预测网络的参数,可以使得生成器网络的训练过程简单。
另一方面,本公开提供的用于增强学习的策略生成网络的训练方法、装置和电子设备首先获取如上所述的生成器网络的训练方法训练的生成器网络;由所述生成器网络生成N个预测网络,然后获取当前状态向量和由策略生成网络生成的动作向量,再将所述当前状态向量和所述动作向量输入所述N个预测网络以获得N个下一状态向量,之后基于所述N个下一状态向量之间的差异计算用于增强学习的内在奖励函数值,最后基于所述内在奖励函数值更新所述策略生成网络的参数。
这样,本公开提供的用于增强学习的策略生成网络的训练方法、装置和电子设备基于生成器网络生成的多个预测网络所预测出的状态向量之间的差异,即基于多个预测网络之间的贝叶斯不确定性来计算用于增强学习的策略生成网络的内在奖励函数值,以训练策略生成网络,由于内在奖励函数值能够体现出策略生成网络相对于动态环境状态的认知的不确定性,通过其训练策略生成网络,可以在增强学习的过程中,促进策略生成网络对环境的探索,从而策略生成网络能够在不同环境中更有效地探索,提高了训练的策略生成网络的探索效率。
在介绍了本公开的基本原理之后,下面将参考附图来具体介绍本公开的各种非限制性实施例。
示例性方法
图2图示了根据本公开实施例的生成器网络的训练方法的流程图。
如图2所示,根据本公开实施例的生成器网络的训练方法包括如下步骤。
步骤S110,获取用于增强学习任务的训练用当前状态向量、训练用动作向量和与所述训练用当前状态向量和训练用动作向量对应的训练用下一状态向量以及由其确定的真实的后验概率分布。
如上所述,对于增强学习任务来说,环境的当前状态通过执行动作的对象(在增强学习任务中通常称为Agent)的动作迁移到下一状态。例如,在迷宫探索类的任务中,动作指的是探索迷宫的对象的移动方向,而状态可以以迷宫的已经探索的部分占整个部分的百分比来表示。此外,在围棋类任务中,动作就是指的棋子落子的位置,即棋盘的81×81的网格坐标中的位置坐标,而状态可以是统计出的棋盘当前状态下的获胜概率。
另外,对于状态和动作,将其分别转换为状态向量和动作向量。在本公开实施例中,通过动作实际作用在环境上,获得当前状态、动作和下一状态并转换为向量以用于生成器网络的训练。
这里,真实的后验概率分布等于先验概率分布乘以似然函数,在本公开实施例中,因为先验的概率分布假设为均匀的,从而可以忽略,并且基于真实的训练数据,可以将似然函数的对数设计为预测结果的损失函数。因此,可以通过通用的方法,基于真实的训练数据以及设计好的损失函数得到所述真实的后验概率分布。在本公开实施例中,所述真实的后验概率分布可以记为p(f|D),且可以简单地记为p,其中,D表示给定的当前状态向量和动作向量。
步骤S120,将已知概率分布的一组随机噪声向量输入生成器网络以获得一组预测网络,所述生成器网络包括多个网络单元,每个网络单元用于生成所述预测网络的一层。
图3图示了根据本公开实施例的生成器网络的架构的示意图。如图3所示,根据本公开实施例的生成器网络包括多个网络单元,例如如图3所示的G1、G2、…、Gn。随机噪声向量的独立噪声样本,例如如图3所示的具有对角协方差的标准高斯噪声Z的噪声样本Z1、Z2、…、Zn输入各个网络单元,以生成预测网络的各层的参数向量,例如如图3所示的θ1、θ2、…、θn。
例如,在本公开实施例中,所述生成器网络可以包括4个网络单元,从而生成包括4层网络的预测网络。并且,所述生成器网络的每个网络单元都可以是全连接神经网络。
这样,如上所述,如图3所示的生成器网络的架构的优点在于其灵活性和高效率性,因为其仅需要维护多个网络单元的参数,并且可以通过增加或者减少使用的网络单元的数目,来生成任意层数的预测网络。并且,通过输入随机噪声向量,可以生成任意数目的预测网络,以用于将在下文中说明的策略生成网络的训练。
并且,在本公开实施例中,将已知概率分布,比如标准正态分布的一组随机噪声向量输入所述生成器网络后,所生成的预测网络对应的函数将呈现由生成器网络对应的函数变换后的概率分布。
值得注意的是,在本公开实施例中,所述一组随机噪声向量是从预定维度的概率分布,比如d维的标准正态分布独立生成的,而不是联合生成的。
这是因为如果多个随机噪声向量是联合生成的,就需要考虑各个随机噪声向量彼此之间的相关性。而在根据本公开实施例的生成器网络中,希望初始的输入尽可能地简单,而将可能的相关性留给生成器网络去学习,这可以提高生成器网络的性能。
在本公开实施例中,所述生成器网络可以表示为,其参数可以表示为η。并且,预测网络可以表示为fθ,且其参数可以表示为θ。
步骤S130,将所述训练用当前状态向量和所述训练用动作向量输入所述一组预测网络以获得预测性的概率分布。也就是,所述一组预测网络中的每个预测网络从当前状态和动作来预测下一状态,即,从所述训练用当前状态向量和所述训练用动作向量获得预测性的下一状态向量。因此,通过输入所述训练用当前状态向量和所述训练用动作向量,可以由满足变换后的概率分布的一组预测网络对应的函数获得预测出的一组向量的概率分布,即,预测性的概率分布。在本公开实施例中,在预测网络表示为θ的情况下,所述预测性的概率分布可以表示为q(fθ),且可以简单地记为q。
步骤S140,确定所述预测性的概率分布与所述真实的后验概率分布之间的KL散度值。这里,为了使得该预测性的概率分布能够尽可能地接近真实的后验概率分布,通过使用KL散度值来度量这种接近性,因此,在本公开实施例中,确定所述预测性的概率分布与所述真实的后验概率分布之间的KL散度值。
步骤S150,基于所述KL散度值来更新所述生成器网络的参数。具体地,为了使得由生成器网络生成的预测网络所获得的预测性的概率分布尽可能地接近真实的后验概率分布,通过最小化所述KL散度值来更新所述生成器网络的参数。例如,可以使用常用的变分推断的方法来最小化所述预测性的概率分布与所述真实的后验概率分布之间的KL散度值。
这样,根据本公开实施例的生成器网络的训练方法可以在训练生成器网络的过程中,仅更新生成器网络的参数,而不更新由生成器网络生成的预测网络的参数,从而使得生成器网络的训练过程简单。
另外,在一个示例中,在本公开实施例中,可以使用斯特恩变分梯度下降(SteinVariational Gradient Descent)方法,该方法是一种非参数化的变分推断方法,其将要训练的网络表示为一组粒子,而不做出参数化的假定,从而通过迭代的粒子演化来获得函数梯度下降值。
因此,在根据本公开实施例的生成器网络的训练方法中,基于所述KL散度值来更新所述生成器网络的参数包括:使用斯特恩(stein)变分梯度下降方法计算相对于所述KL散度值的函数梯度下降值;以及,基于所述函数梯度下降值更新所述生成器网络的参数。
这样,通过斯特恩变分梯度下降的方法,可以以非参数化的方式来计算相对于所述KL散度值的函数梯度下降值,以更新所述生成器网络的参数,使得计算简单。
具体地,采用斯特恩变分梯度下降的方法,需要将函数梯度通过经由生成器网络的反向传播投影到生成器网络的参数空间中,对于生成器网络G生成的一组预测网络的动态函数
Figure BDA0002650072510000091
斯特恩变分梯度下降通过以下公式来更新预测网络:
Figure BDA0002650072510000092
其中∈是步长大小,且φ*是以再现核希尔伯特空间(RKHS)
Figure BDA0002650072510000097
的单元球的函数,其最大化地减小由一组预测网络
Figure BDA0002650072510000098
表示的预测性的概率分布q与真实的后验概率分布p之间的KL散度值,即DKL(q||p)。具体地,该函数由以下公式表示:
Figure BDA0002650072510000093
该优化问题具有近似形式的解:
Figure BDA0002650072510000094
其中,
Figure BDA0002650072510000095
表示关于f的期望,且f服从概率分布q。并且,函数k(·,·)是与RKHS相关联的正有限核,在本公开实施例中,可以使用高斯核,即计算相邻两次生成的预测网络
Figure BDA0002650072510000096
之间的距离。log p(f)对应于对于空间D内的所有状态迁移的将来的状态预测的递归损失函数的负数,也就是
Figure BDA0002650072510000101
其中,s表示当前状态,a表示动作,且s′表示下一状态。
这样,由于预测网络是由生成器网络的参数确定的,如果将预测网络
Figure BDA0002650072510000102
的参数定义为θi,则有:
θi←θi+∈φ*i)
其中,
Figure BDA0002650072510000103
其中,
Figure BDA0002650072510000104
表示关于θ的期望,且θ服从概率分布G。并且,预测网络的参数θi由生成器网络G生成,因此生成器网络的参数η的更新规则也可以由链式规则获得:
Figure BDA0002650072510000105
其中,φ*i)通过使用采样样本的经验期望来计算:
Figure BDA0002650072510000106
也就是,在根据本公开实施例的生成器网络的训练方法中,使用斯特恩(stein)变分梯度下降方法计算所述KL散度值的函数梯度下降值包括:计算所述生成器网络所生成的每个预测网络的以再现核希尔伯特空间的单元球的预定函数;计算所述生成器网络与所述预定函数之积关于所述生成器网络的参数的梯度;以及,将所述梯度对于所述一组预测网络求和以获得梯度和;并且,基于所述函数梯度下降值更新所述生成器网络的参数包括:基于当前生成器网络的参数、所述梯度和和所述第一系数获得更新的生成器网络的参数。
并且,在根据本公开实施例的生成器网络的训练方法中,计算每个预测网络的以再现核希尔伯特空间的单元球的预定函数包括:计算所述预测网络从当前状态和动作预测出的下一状态与真实的下一状态之间的差异函数值关于所述预测网络的梯度;将所述梯度关于所述状态空间和动作空间内的所有状态和动作求和并乘以核函数以获得核函数积,所述核函数用于计算相邻两次生成的预测网络之间的距离;以及,将所述核函数关于所述预测网络的梯度减去所述核函数积并关于所述一组预测网络求和以获得所述预定函数。
这样,在完全生成器网络的训练之后,就可以通过使用生成器网络生成的预测网络来进行用于增强学习的策略生成网络的训练,也就是,通过生成器网络生成多个预测网络,并多个预测网络所预测出的状态向量之间的差异,即基于多个预测网络之间的贝叶斯不确定性来计算用于增强学习的策略生成网络的内在奖励函数值,以训练策略生成网络,由于内在奖励函数值能够体现出策略生成网络相对于动态环境状态的认知的不确定性,通过其训练策略生成网络,可以在增强学习的过程中,促进策略生成网络对环境的探索,从而策略生成网络能够在不同环境中更有效地探索,提高了训练的策略生成网络的探索效率。
图4图示了根据本公开实施例的用于增强学习的策略生成网络的训练方法的流程图。
如图4所示,根据本公开实施例的用于增强学习的策略生成网络的训练方法包括以下步骤。
S210,获取如上所述的生成器网络的训练方法训练的生成器网络。也就是,获取训练好的生成器网络。
S220,由所述生成器网络生成N个预测网络。在一个示例中,可以使用所述生成器网络生成32个预测网络。这里,由于所述生成器网络输入的随机噪声向量,因此所生成的32个预测网络的参数也不相同,所以生成的N个预测网络可以被称为动态预测网络。
S230,获取当前状态向量和由策略生成网络生成的动作向量。例如,如图1所示,获得当前状态S0的向量和由策略生成网络N生成的动作A的向量。
S240,将所述当前状态向量和所述动作向量输入所述N个预测网络以获得N个下一状态向量。也就是,将当前状态向量和动作向量输入不同参数的N个预测网络,得到的N个下一状态向量也不相同,这可以被称为动态预测网络的贝叶斯不确定性。
S250,基于所述N个下一状态向量之间的差异计算用于增强学习的内在奖励函数值。也就是,将动态预测网络的贝叶斯不确定性用于计算用于增强学习任务的内在奖励函数值。
S260,基于所述内在奖励函数值更新所述策略生成网络的参数。在本公开实施例中,可以使用不受模型约束的增强学习算法来进行策略生成网络的参数的更新,例如,可以采用柔性致动评价(Soft Actor Critic)算法。
这样,基于生成器网络生成的多个预测网络所预测出的状态向量之间的差异,即基于多个预测网络之间的贝叶斯不确定性来计算用于增强学习的策略生成网络的内在奖励函数值,以训练策略生成网络,可以在增强学习的过程中,促进执行任务的对象通过使用策略生成网络所生成的策略,来基于内在奖励对环境进行探索,从而能够在不同环境中更有效地探索,提高了训练的策略生成网络的性能。
在一个示例中,可以通过以下公式计算内在奖励函数值
Figure BDA0002650072510000121
Figure BDA0002650072510000122
其中,st表示当前状态向量,at表示动作向量,
Figure BDA0002650072510000123
Figure BDA0002650072510000124
表示预测网络从当前状态向量和动作向量预测出的下一向量,m是预测网络的数目。
因此,在根据本公开实施例的用于增强学习的策略生成网络的训练方法中,基于所述N个下一状态向量计算用于增强学习的内在奖励函数值包括:计算所述N个下一状态向量的均值向量;计算所述N个下一状态向量中的每个下一状态向量与所述均值向量的L2距离值以获得N个L2距离值;以及,计算所述N个L2距离值的均值以获得所述用于增强学习的奖励函数值。
另外,在本公开实施例中,所述下一状态向量除了由预测网络从当前状态向量和动作向量预测的下一状态向量以外,还可以包括增强学习任务中的真实的执行任务的对象基于所述当前状态向量和所述动作向量获得的真实的下一状态向量。
也就是,所述状态向量和动作向量既可以是从实际环境中采集的,也可以是由生成器网络生成的动态预测网络模拟生成的,因为在本公开实施例中,是将动态预测网络的贝叶斯不确定性用于计算用于增强学习任务的内在奖励函数值,以体现出策略生成网络相对于动态环境状态的认知的不确定性,这样,通过该内在奖励函数值来训练策略生成网络,可以在增强学习的过程中,促进策略生成网络对环境的探索,从而策略生成网络能够在不同环境中更有效地探索,提高了训练的策略生成网络的探索效率。
示例性装置
图5图示了根据本公开实施例的生成器网络的训练装置的框图。
如图5所示,根据本公开实施例的生成器网络的训练装置300包括:训练向量获取单元310,用于获取用于增强学习任务的训练用当前状态向量、训练用动作向量和与所述训练用当前状态向量和训练用动作向量对应的训练用下一状态向量以及由其确定的真实的后验概率分布;预测网络生成单元320,用于将已知概率分布的一组随机噪声向量输入生成器网络以获得一组预测网络,所述生成器网络包括多个网络单元,每个网络单元用于生成所述预测网络的一层;向量预测单元330,用于将所述训练向量获取单元310所获取的所述训练用当前状态向量和所述训练用动作向量输入所述预测网络生成单元320所生成的所述一组预测网络以获得预测性的概率分布;散度值确定单元340,用于确定所述向量预测单元330所获得的所述预测性的概率分布与训练向量获取单元310所获取的所述真实的后验概率分布之间的KL散度值;以及,生成器更新单元350,用于基于所述散度值确定单元340所确定的所述KL散度值来更新所述生成器网络的参数。
在一个示例中,在上述生成器网络的训练装置300中,所述生成器更新单元350包括:梯度计算子单元,用于使用斯特恩(stein)变分梯度下降方法计算所述散度值确定单元340所确定的所述KL散度值的函数梯度下降值;以及,参数更新子单元,用于基于所述梯度计算子单元所计算的所述函数梯度下降值更新所述生成器网络的参数。
在一个示例中,在上述生成器网络的训练装置300中,所述梯度计算子单元用于:计算所述生成器网络所生成的每个预测网络的以再现核希尔伯特空间的单元球的预定函数;计算所述生成器网络与所述预定函数之积关于所述生成器网络的参数的梯度;以及,将所述梯度对于所述一组预测网络求和以获得梯度和;和,所述参数更新子单元用于:基于当前生成器网络的参数、所述梯度和和所述第一系数获得更新的生成器网络的参数。
在一个示例中,在上述生成器网络的训练装置300中,所述梯度计算子单元计算每个预测网络的以再现核希尔伯特空间的单元球的预定函数包括:计算所述预测网络从当前状态和动作预测出的下一状态与真实的下一状态之间的差异函数值关于所述预测网络的梯度;将所述梯度关于所述状态空间和动作空间内的所有状态和动作求和并乘以核函数以获得核函数积,所述核函数用于计算相邻两次生成的预测网络之间的距离;以及,将所述核函数关于所述预测网络的梯度减去所述核函数积并关于所述一组预测网络求和以获得所述预定函数。
在一个示例中,在上述生成器网络的训练装置300中,所述预测网络生成单元320用于:对于每个预测网络,将从具有对角协方差的标准高斯噪声获得的独立噪声样本输入每个网络单元,以生成所述预测网络的一层。
图6图示了根据本公开实施例的用于增强学习的策略生成网络的训练装置的框图。
如图6所示,根据本公开实施例的用于增强学习的策略生成网络的训练装置400包括:网络获取单元410,用于获取如上所述的生成器网络的训练装置300训练的生成器网络;网络生成单元420,用于由所述网络获取单元410所获取的所述生成器网络生成N个预测网络;向量获取单元430,用于获取当前状态向量和由策略生成网络生成的动作向量;预测向量获得单元440,用于将所述向量获取单元430所获取的所述当前状态向量和所述动作向量输入所述网络生成单元420所生成的所述N个预测网络以获得N个下一状态向量;奖励函数计算单元450,用于基于所述预测向量获得单元440所获得的所述N个下一状态向量之间的差异计算用于增强学习的内在奖励函数值;以及,网络更新单元460,用于基于所述奖励函数计算单元450所计算的所述内在奖励函数值更新所述策略生成网络的参数。
在一个示例中,在上述用于增强学习的策略生成网络的训练装置400中,所述奖励函数计算单元450用于:计算所述N个下一状态向量的均值向量;计算所述N个下一状态向量中的每个下一状态向量与所述均值向量的L2距离值以获得N个L2距离值;以及,计算所述N个L2距离值的均值以获得所述用于增强学习的奖励函数值。
在一个示例中,在上述用于增强学习的策略生成网络的训练装置400中,所述下一状态向量包括增强学习任务中的真实的执行任务的对象基于所述当前状态向量和所述动作向量获得的真实的下一状态向量。
这里,本领域技术人员可以理解,上述生成器网络的训练装置300和用于增强学习的策略生成网络的训练装置400中的各个单元和模块的具体功能和操作已经在上面参考图1到图4的生成器网络的训练方法和用于增强学习的策略生成网络的训练方法的描述中得到了详细介绍,并因此,将省略其重复描述。
如上所述,根据本公开实施例的生成器网络的训练装置300和用于增强学习的策略生成网络的训练装置400可以实现在各种终端设备中,例如用于增强学习任务的服务器等。在一个示例中,根据本公开实施例的生成器网络的训练装置300和用于增强学习的策略生成网络的训练装置400可以作为一个软件模块和/或硬件模块而集成到终端设备中。例如,其可以是该终端设备的操作系统中的一个软件模块,或者可以是针对于该终端设备所开发的一个应用程序;当然,生成器网络的训练装置300和用于增强学习的策略生成网络的训练装置400同样可以是该终端设备的众多硬件模块之一。
替换地,在另一示例中,生成器网络的训练装置300和用于增强学习的策略生成网络的训练装置400与该终端设备也可以是分立的设备,并且其可以通过有线和/或无线网络连接到该终端设备,并且按照约定的数据格式来传输交互信息。
示例性电子设备
下面,参考图7来描述根据本公开实施例的电子设备。
图7图示了根据本公开实施例的电子设备的框图。
如图7所示,电子设备10包括一个或多个处理器11和存储器12。
处理器11可以是中央处理单元(CPU)或者具有数据处理能力和/或指令执行能力的其他形式的处理单元,并且可以控制电子设备10中的其他组件以执行期望的功能。
存储器12可以包括一个或多个计算机程序产品,所述计算机程序产品可以包括各种形式的计算机可读存储介质,例如易失性存储器和/或非易失性存储器。所述易失性存储器例如可以包括随机存取存储器(RAM)和/或高速缓冲存储器(cache)等。所述非易失性存储器例如可以包括只读存储器(ROM)、硬盘、闪存等。在所述计算机可读存储介质上可以存储一个或多个计算机程序指令,处理器11可以运行所述程序指令,以实现上文所述的本公开的各个实施例的生成器网络的训练方法和用于增强学习的策略生成网络的训练方法以及/或者其他期望的功能。在所述计算机可读存储介质中还可以存储诸如状态向量、动作向量、预测网络的参数等各种内容。
在一个示例中,电子设备10还可以包括:输入装置13和输出装置14,这些组件通过总线系统和/或其他形式的连接机构(未示出)互连。
该输入装置13可以包括例如键盘、鼠标等等。
该输出装置14可以向外部输出各种信息,包括训练好的生成器网络和策略生成网络的参数等。该输出装置14可以包括例如显示器、扬声器、打印机、以及通信网络及其所连接的远程输出设备等等。
当然,为了简化,图7中仅示出了该电子设备10中与本公开有关的组件中的一些,省略了诸如总线、输入/输出接口等等的组件。除此之外,根据具体应用情况,电子设备10还可以包括任何其他适当的组件。
示例性计算机程序产品和计算机可读存储介质
除了上述方法和设备以外,本公开的实施例还可以是计算机程序产品,其包括计算机程序指令,所述计算机程序指令在被处理器运行时使得所述处理器执行本说明书上述“示例性方法”部分中描述的根据本公开各种实施例的生成器网络的训练方法和用于增强学习的策略生成网络的训练方法中的步骤。
所述计算机程序产品可以以一种或多种程序设计语言的任意组合来编写用于执行本公开实施例操作的程序代码,所述程序设计语言包括面向对象的程序设计语言,诸如Java、C++等,还包括常规的过程式程序设计语言,诸如“C”语言或类似的程序设计语言。程序代码可以完全地在用户计算设备上执行、部分地在用户设备上执行、作为一个独立的软件包执行、部分在用户计算设备上部分在远程计算设备上执行、或者完全在远程计算设备或服务器上执行。
此外,本公开的实施例还可以是计算机可读存储介质,其上存储有计算机程序指令,所述计算机程序指令在被处理器运行时使得所述处理器执行本说明书上述“示例性方法”部分中描述的根据本公开各种实施例的生成器网络的训练方法和用于增强学习的策略生成网络的训练方法中的步骤。
所述计算机可读存储介质可以采用一个或多个可读介质的任意组合。可读介质可以是可读信号介质或者可读存储介质。可读存储介质例如可以包括但不限于电、磁、光、电磁、红外线、或半导体的系统、装置或器件,或者任意以上的组合。可读存储介质的更具体的例子(非穷举的列表)包括:具有一个或多个导线的电连接、便携式盘、硬盘、随机存取存储器(RAM)、只读存储器(ROM)、可擦式可编程只读存储器(EPROM或闪存)、光纤、便携式紧凑盘只读存储器(CD-ROM)、光存储器件、磁存储器件、或者上述的任意合适的组合。
以上结合具体实施例描述了本公开的基本原理,但是,需要指出的是,在本公开中提及的优点、优势、效果等仅是示例而非限制,不能认为这些优点、优势、效果等是本公开的各个实施例必须具备的。另外,上述公开的具体细节仅是为了示例的作用和便于理解的作用,而非限制,上述细节并不限制本公开为必须采用上述具体的细节来实现。
本公开中涉及的器件、装置、设备、系统的方框图仅作为例示性的例子并且不意图要求或暗示必须按照方框图示出的方式进行连接、布置、配置。如本领域技术人员将认识到的,可以按任意方式连接、布置、配置这些器件、装置、设备、系统。诸如“包括”、“包含”、“具有”等等的词语是开放性词汇,指“包括但不限于”,且可与其互换使用。这里所使用的词汇“或”和“和”指词汇“和/或”,且可与其互换使用,除非上下文明确指示不是如此。这里所使用的词汇“诸如”指词组“诸如但不限于”,且可与其互换使用。
还需要指出的是,在本公开的装置、设备和方法中,各部件或各步骤是可以分解和/或重新组合的。这些分解和/或重新组合应视为本公开的等效方案。
提供所公开的方面的以上描述以使本领域的任何技术人员能够做出或者使用本公开。对这些方面的各种修改对于本领域技术人员而言是非常显而易见的,并且在此定义的一般原理可以应用于其他方面而不脱离本公开的范围。因此,本公开不意图被限制到在此示出的方面,而是按照与在此公开的原理和新颖的特征一致的最宽范围。
为了例示和描述的目的已经给出了以上描述。此外,此描述不意图将本公开的实施例限制到在此公开的形式。尽管以上已经讨论了多个示例方面和实施例,但是本领域技术人员将认识到其某些变型、修改、改变、添加和子组合。

Claims (11)

1.一种生成器网络的训练方法,包括:
获取用于增强学习任务的训练用当前状态向量、训练用动作向量和与所述训练用当前状态向量和训练用动作向量对应的训练用下一状态向量以及由其确定的真实的后验概率分布;
将已知概率分布的一组随机噪声向量输入生成器网络以获得一组预测网络,所述生成器网络包括多个网络单元,每个网络单元用于生成所述预测网络的一层;
将所述训练用当前状态向量和所述训练用动作向量输入所述一组预测网络以获得预测性的概率分布;
确定所述预测性的概率分布与所述真实的后验概率分布之间的KL散度值;以及
基于所述KL散度值来更新所述生成器网络的参数。
2.如权利要求1所述的生成器网络的训练方法,其中,基于所述KL散度值来更新所述生成器网络的参数包括:
使用斯特恩变分梯度下降方法计算所述KL散度值的函数梯度下降值;
基于所述函数梯度下降值更新所述生成器网络的参数。
3.如权利要求2所述的生成器网络的训练方法,其中,使用斯特恩变分梯度下降方法计算所述KL散度值的函数梯度下降值包括:
计算所述生成器网络所生成的每个预测网络的以再现核希尔伯特空间的单元球的预定函数;
计算所述生成器网络与所述预定函数之积关于所述生成器网络的参数的梯度;以及,
将所述梯度对于所述一组预测网络求和以获得梯度和;和
基于所述函数梯度下降值更新所述生成器网络的参数包括:
基于当前生成器网络的参数、所述梯度和和所述第一系数获得更新的生成器网络的参数。
4.如权利要求3所述的生成器网络的训练方法,其中,计算每个预测网络的以再现核希尔伯特空间的单元球的预定函数包括:
计算所述预测网络从当前状态和动作预测出的下一状态与真实的下一状态之间的差异函数值关于所述预测网络的梯度;
将所述梯度关于所述状态空间和动作空间内的所有状态和动作求和并乘以核函数以获得核函数积,所述核函数用于计算相邻两次生成的预测网络之间的距离;以及,
将所述核函数关于所述预测网络的梯度减去所述核函数积并关于所述一组预测网络求和以获得所述预定函数。
5.如权利要求1所述的生成器网络的训练方法,其中,将已知概率分布的一组随机噪声向量输入生成器网络以获得一组预测网络包括:
对于每个预测网络,将从具有对角协方差的标准高斯噪声获得的独立噪声样本输入每个网络单元,以生成所述预测网络的一层。
6.一种用于增强学习的策略生成网络的训练方法,包括:
获取如权利要求1到5中任一项所述的生成器网络的训练方法训练的生成器网络;
由所述生成器网络生成N个预测网络;
获取当前状态向量和由策略生成网络生成的动作向量;
将所述当前状态向量和所述动作向量输入所述N个预测网络以获得N个下一状态向量;
基于所述N个下一状态向量之间的差异计算用于增强学习的内在奖励函数值;以及
基于所述内在奖励函数值更新所述策略生成网络的参数。
7.如权利要求6所述的用于增强学习的策略生成网络的训练方法,其中,基于所述N个下一状态向量计算用于增强学习的内在奖励函数值包括:
计算所述N个下一状态向量的均值向量;
计算所述N个下一状态向量中的每个下一状态向量与所述均值向量的L2距离值以获得N个L2距离值;以及
计算所述N个L2距离值的均值以获得所述用于增强学习的奖励函数值。
8.如权利要求6所述的用于增强学习的策略生成网络的训练方法,其中,
所述下一状态向量包括增强学习任务中的真实的执行任务的对象基于所述当前状态向量和所述动作向量获得的真实的下一状态向量。
9.一种生成器网络的训练装置,包括:
训练向量获取单元,用于获取用于增强学习任务的训练用当前状态向量、训练用动作向量和与所述训练用当前状态向量和训练用动作向量对应的训练用下一状态向量以及由其确定的真实的后验概率分布;
预测网络生成单元,用于将已知概率分布的一组随机噪声向量输入生成器网络以获得一组预测网络,所述生成器网络包括多个网络单元,每个网络单元用于生成所述预测网络的一层;
向量预测单元,用于将所述训练向量获取单元所获取的所述训练用当前状态向量和所述训练用动作向量输入所述预测网络生成单元所生成的所述一组预测网络以获得预测性的概率分布;
散度值确定单元,用于确定所述向量预测单元所获得的所述预测性的概率分布与训练向量获取单元所获取的所述真实的后验概率分布之间的KL散度值;以及
生成器更新单元,用于基于所述散度值确定单元所确定的所述KL散度值来更新所述生成器网络的参数。
10.一种用于增强学习的策略生成网络的训练装置,包括:
网络获取单元,用于获取如权利要求9所述的生成器网络的训练装置训练的生成器网络;
网络生成单元,用于由所述网络获取单元所获取的所述生成器网络生成N个预测网络;
向量获取单元,用于获取当前状态向量和由策略生成网络生成的动作向量;
预测向量获得单元,用于将所述向量获取单元所获取的所述当前状态向量和所述动作向量输入所述网络生成单元所生成的所述N个预测网络以获得N个下一状态向量;
奖励函数计算单元,用于基于所述预测向量获得单元所获得的所述N个下一状态向量之间的差异计算用于增强学习的内在奖励函数值;以及
网络更新单元,用于基于所述奖励函数计算单元所计算的所述内在奖励函数值更新所述策略生成网络的参数。
11.一种电子设备,包括:
处理器;以及
存储器,在所述存储器中存储有计算机程序指令,所述计算机程序指令在被所述处理器运行时使得所述处理器执行如权利要求1-6中任一项所述的生成器网络的训练方法或者如权利要求7或8所述的用于增强学习的策略生成网络的训练方法。
CN202010867110.3A 2019-09-23 2020-08-26 生成器网络和策略生成网络的训练方法、装置和电子设备 Pending CN112016611A (zh)

Applications Claiming Priority (2)

Application Number Priority Date Filing Date Title
US201962904382P 2019-09-23 2019-09-23
US62/904,382 2019-09-23

Publications (1)

Publication Number Publication Date
CN112016611A true CN112016611A (zh) 2020-12-01

Family

ID=73503524

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202010867110.3A Pending CN112016611A (zh) 2019-09-23 2020-08-26 生成器网络和策略生成网络的训练方法、装置和电子设备

Country Status (1)

Country Link
CN (1) CN112016611A (zh)

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113506328A (zh) * 2021-07-16 2021-10-15 北京地平线信息技术有限公司 视线估计模型的生成方法和装置、视线估计方法和装置

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113506328A (zh) * 2021-07-16 2021-10-15 北京地平线信息技术有限公司 视线估计模型的生成方法和装置、视线估计方法和装置

Similar Documents

Publication Publication Date Title
Zheng Gradient descent algorithms for quantile regression with smooth approximation
WO2022068623A1 (zh) 一种模型训练方法及相关设备
US10366325B2 (en) Sparse neural control
CN114048331A (zh) 一种基于改进型kgat模型的知识图谱推荐方法及系统
Al-Wesabi et al. Energy aware resource optimization using unified metaheuristic optimization algorithm allocation for cloud computing environment
WO2013086186A2 (en) Particle methods for nonlinear control
CN115066694A (zh) 计算图优化
WO2021025075A1 (ja) 訓練装置、推定装置、訓練方法、推定方法、プログラム及びコンピュータ読み取り可能な非一時的記憶媒体
Zhao et al. Surrogate modeling of nonlinear dynamic systems: a comparative study
CN114261400A (zh) 一种自动驾驶决策方法、装置、设备和存储介质
CN112016678A (zh) 用于增强学习的策略生成网络的训练方法、装置和电子设备
EP3446258B1 (en) Model-free control for reinforcement learning agents
CN115951989A (zh) 一种基于严格优先级的协同流量调度数值模拟方法与系统
CN113407820B (zh) 利用模型进行数据处理的方法及相关系统、存储介质
CN112016611A (zh) 生成器网络和策略生成网络的训练方法、装置和电子设备
CN114648103A (zh) 用于处理深度学习网络的自动多目标硬件优化
WO2020169182A1 (en) Method and apparatus for allocating tasks
JP7150651B2 (ja) ニューラルネットワークのモデル縮約装置
KR102561799B1 (ko) 디바이스에서 딥러닝 모델의 레이턴시를 예측하는 방법 및 시스템
Krishnan et al. Multi-Agent Reinforcement Learning for Microprocessor Design Space Exploration
Mateo et al. A variable selection approach based on the delta test for extreme learning machine models
CN114445692B (zh) 图像识别模型构建方法、装置、计算机设备及存储介质
Wong et al. Hybrid data regression model based on the generalized adaptive resonance theory neural network
Lo et al. Learning based mesh generation for thermal simulation in handheld devices with variable power consumption
WO2020054402A1 (ja) ニューラルネットワーク処理装置、コンピュータプログラム、ニューラルネットワーク製造方法、ニューラルネットワークデータの製造方法、ニューラルネットワーク利用装置、及びニューラルネットワーク小規模化方法

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination