CN112348113B - 离线元强化学习模型的训练方法、装置、设备及存储介质 - Google Patents

离线元强化学习模型的训练方法、装置、设备及存储介质 Download PDF

Info

Publication number
CN112348113B
CN112348113B CN202011354318.1A CN202011354318A CN112348113B CN 112348113 B CN112348113 B CN 112348113B CN 202011354318 A CN202011354318 A CN 202011354318A CN 112348113 B CN112348113 B CN 112348113B
Authority
CN
China
Prior art keywords
task
network
vector
training sample
loss function
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN202011354318.1A
Other languages
English (en)
Other versions
CN112348113A (zh
Inventor
李蓝青
杨瑞
罗迪君
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Tencent Technology Shenzhen Co Ltd
Original Assignee
Tencent Technology Shenzhen Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Tencent Technology Shenzhen Co Ltd filed Critical Tencent Technology Shenzhen Co Ltd
Priority to CN202011354318.1A priority Critical patent/CN112348113B/zh
Publication of CN112348113A publication Critical patent/CN112348113A/zh
Application granted granted Critical
Publication of CN112348113B publication Critical patent/CN112348113B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • G06F18/241Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
    • G06F18/2415Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches based on parametric or probabilistic models, e.g. based on likelihood ratio or false acceptance rate versus a false rejection rate
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods

Abstract

本申请公开了一种离线元强化学习模型的训练方法、装置、设备及存储介质,涉及人工智能技术领域。所述方法包括:从离线数据池中采样获取多个训练样本集;通过任务推断网络生成训练样本集的任务表示向量;基于不同训练样本集的任务表示向量之间的距离度量,确定任务推断网络的损失函数;基于训练样本的状态向量、动作向量和任务表示向量,确定策略网络的损失函数和评判网络的损失函数;基于任务推断网络的损失函数、策略网络的损失函数和评判网络的损失函数,分别对任务推断网络、策略网络和评判网络的参数进行调整。相比于传统的在线强化学习,本申请采用离线学习方式对元强化学习模型进行训练,提高了安全性及数据利用效率,且有助于降低成本。

Description

离线元强化学习模型的训练方法、装置、设备及存储介质
技术领域
本申请实施例涉及人工智能技术领域,特别涉及一种离线元强化学习模型的训练方法、装置、设备及存储介质。
背景技术
目前,强化学习在自动驾驶、机器人控制、农业种植等场景中得到了应用。
在相关技术中,通过在线学习的方式,对强化学习模型进行训练。以农业种植场景为例,获取农作物当前时刻的状态信息(如温室气候、作物发育情况等信息),通过策略网络基于上述状态信息,输出相应的动作策略(如温度、湿度、光照、水分、CO2等控制策略),给农作物施加上述动作策略并采集下一时刻的状态信息,并反馈奖励值,该奖励值反映了农作物在施加上述动作策略前后的性能差异,譬如产量、资源消耗量等。强化学习的优化目标是最大化累积奖励值,从而得到满足预期的策略网络。
然而,上述在线学习的方式需要不断地试错优化,但在很多场景下是不可行的,会带来安全和成本的问题。
发明内容
本申请实施例提供了一种离线元强化学习模型的训练方法、装置、设备及存储介质。所述技术方案如下:
根据本申请实施例的一个方面,提供了一种离线元强化学习模型的训练方法,所述离线元强化学习模型包括任务推断网络、策略网络和评判网络;所述方法包括:
从离线数据池中采样获取多个训练样本集,每个训练样本集中包括属于同一任务的多个训练样本;
通过所述任务推断网络生成所述训练样本集的任务表示向量,所述训练样本集的任务表示向量用于表征所述训练样本集所属的任务;
基于不同训练样本集的任务表示向量之间的距离度量,确定所述任务推断网络的损失函数;其中,所述任务推断网络的损失函数与不同任务间的距离度量负相关;
基于所述训练样本的状态向量、动作向量和任务表示向量,确定所述策略网络的损失函数和所述评判网络的损失函数;
基于所述任务推断网络的损失函数、所述策略网络的损失函数和所述评判网络的损失函数,分别对所述任务推断网络、所述策略网络和所述评判网络的参数进行调整。
根据本申请实施例的一个方面,提供了一种离线元强化学习模型的训练装置,所述离线元强化学习模型包括任务推断网络、策略网络和评判网络;所述装置包括:
离线采样模块,用于从离线数据池中采样获取多个训练样本集,每个训练样本集中包括属于同一任务的多个训练样本;
任务表征模块,用于通过所述任务推断网络生成所述训练样本集的任务表示向量,所述训练样本集的任务表示向量用于表征所述训练样本集所属的任务;
损失计算模块,用于基于不同训练样本集的任务表示向量之间的距离度量,确定所述任务推断网络的损失函数;其中,所述任务推断网络的损失函数与不同任务间的距离度量负相关;
所述损失计算模块,还用于基于所述训练样本的状态向量、动作向量和任务表示向量,确定所述策略网络的损失函数和所述评判网络的损失函数;
参数调整模块,用于基于所述任务推断网络的损失函数、所述策略网络的损失函数和所述评判网络的损失函数,分别对所述任务推断网络、所述策略网络和所述评判网络的参数进行调整。
根据本申请实施例的一个方面,提供了一种计算机设备,所述计算机设备包括处理器和存储器,所述存储器中存储有至少一条指令、至少一段程序、代码集或指令集,所述至少一条指令、所述至少一段程序、所述代码集或指令集由所述处理器加载并执行以实现上述方法。
根据本申请实施例的一个方面,提供了一种计算机可读存储介质,其特征在于,所述存储介质中存储有至少一条指令、至少一段程序、代码集或指令集,所述至少一条指令、所述至少一段程序、所述代码集或指令集由处理器加载并执行以实现上述方法。
根据本申请实施例的一个方面,提供了一种计算机程序产品或计算机程序,该计算机程序产品或计算机程序包括计算机指令,该计算机指令存储在计算机可读存储介质中。计算机设备的处理器从计算机可读存储介质读取该计算机指令,处理器执行该计算机指令,使得该计算机设备执行上述方法。
本申请实施例提供的技术方案至少包括如下有益效果:
通过采用离线学习方式对元强化学习模型进行训练,模型训练所需的样本数据来自于离线数据池,从而无需将不成熟的模型应用到在线交互过程中,相比于传统的在线强化学习,提高了安全性及数据利用效率,且有助于降低成本。以自动驾驶场景为例,如果将不成熟的模型应用到在线交互过程中,对真实场景中运行的车辆进行控制容易造成安全问题,对仿真场景下运行的车辆进行控制则需搭建仿真场景,成本过高,且可靠性难以保证。而本申请提供的离线学习方式对元强化学习模型进行训练,则可以克服上述问题。另外,本申请实施例可以重复利用完全离线的数据进行模型训练,无需在线采集数据,更加高效。
另外,还通过采用元强化学习的方式,训练能够完成多个不同任务的模型,从而能够低成本地推广到新的任务,模型的泛化能力显著提升。
另外,任务推断网络的损失函数与不同任务间的距离度量负相关。通过这种方式,可以使得任务推断网络输出的任务表示向量,对于相同任务输出的任务表示向量尽可能地接近,而对于不同任务输出的任务表示向量尽可能地分散,从而实现对不同任务的有效区分,这对于学习不同任务的策略更加有优势。
附图说明
为了更清楚地说明本申请实施例中的技术方案,下面将对实施例描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本申请的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1是本申请一个实施例提供的离线元强化学习模型的模型架构图;
图2是本申请一个实施例提供的离线元强化学习模型的训练方法的流程图;
图3是本申请另一个实施例提供的离线元强化学习模型的训练方法的流程图;
图4是本申请技术方案和对比方案得出的任务向量表示的示意图;
图5是本申请一个实施例提供的引入注意力机制后的模型架构图;
图6是本申请一个实施例提供的引入注意力机制后生成动作向量的示意图;
图7是本申请一个实施例提供的离线元强化学习模型的训练装置的框图;
图8是本申请一个实施例提供的计算机设备的结构示意图。
具体实施方式
为使本申请的目的、技术方案和优点更加清楚,下面将结合附图对本申请实施方式作进一步地详细描述。
人工智能(Artificial Intelligence,AI)是利用数字计算机或者数字计算机控制的机器模拟、延伸和扩展人的智能,感知环境、获取知识并使用知识获得最佳结果的理论、方法、技术及应用系统。换句话说,人工智能是计算机科学的一个综合技术,它企图了解智能的实质,并生产出一种新的能以人类智能相似的方式做出反应的智能机器。人工智能也就是研究各种智能机器的设计原理与实现方法,使机器具有感知、推理与决策的功能。
人工智能技术是一门综合学科,涉及领域广泛,既有硬件层面的技术也有软件层面的技术。人工智能基础技术一般包括如传感器、专用人工智能芯片、云计算、分布式存储、大数据处理技术、操作/交互系统、机电一体化等技术。人工智能软件技术主要包括计算机视觉技术、语音处理技术、自然语言处理技术以及机器学习/深度学习等几大方向。
机器学习(Machine Learning,ML)是一门多领域交叉学科,涉及概率论、统计学、逼近论、凸分析、算法复杂度理论等多门学科。专门研究计算机怎样模拟或实现人类的学习行为,以获取新的知识或技能,重新组织已有的知识结构使之不断改善自身的性能。机器学习是人工智能的核心,是使计算机具有智能的根本途径,其应用遍及人工智能的各个领域。机器学习和深度学习通常包括人工神经网络、置信网络、强化学习、迁移学习、归纳学习、示教学习等技术。
本申请实施例提供的技术方案,涉及人工智能的机器学习等技术,主要涉及强化学习技术。在对本申请技术方案进行介绍说明之前,先对本申请涉及的一些名词及术语进行解释说明。
1、强化学习(reinforcement learning)
属于机器学习的范畴,通常用于解决序列决策问题,主要包括环境和智能体两个组成成分,智能体根据环境的状态选择动作执行,环境根据智能体的动作转移到新的状态并反馈一个数值的奖励,智能体根据环境反馈的奖励不断优化策略。
2、离线强化学习(off-line reinforcement learning)
离线强化学习是一类完全从离线的数据中进行学习的强化学习方法,不与环境交互采样,通常这类方法使用动作约束(behavior regularization)来控制在线测试时数据分布与离线数据分布的差异。
3、经验回放
强化学习中离线策略算法的使用的一个技巧,保持一个经验池储存智能体与环境交互的数据,训练策略时,从经验池中采样数据来训练策略网络。经验回放的方式可以重复利用已获得的数据样本,使得离线策略算法的数据利用效率高于在线策略算法。
4、元强化学习(meta reinforcement learning)
元强化学习的目标是学习到具有较强泛化性能的强化学习模型。通过在服从一定分布的多个任务上进行训练,模型能够利用较少的数据量和时间成本,适应或推广到在训练期间从未遇到过的新任务和新环境。
5、度量学习(metric learning)
度量学习的对象通常是样本特征向量的距离,度量学习的目的是通过训练和学习,减小或限制同类别样本之间的距离,同时增大不同类别样本之间的距离,从而得到多类别样本的有效表征。
6、注意力机制(attention mechanism)
深度学习中针对时序数据的经典处理方式,利用特殊神经网络结构组成的运算单元,将输入中各元素(self-attention,自注意力)以及输入输出间各元素(generalattention,整体注意力)的相关性进行量化。类似人类视觉的注意力机制,针对不同的输出,将不同的权重或“注意力”赋予特定的输入元素,从而提升模型效果。
强化学习通常可以表示为马尔科夫决策过程(Markov Decision Process,MDP),MDP包含了五元组(S,A,R,P,γ),其中,S代表状态空间,A代表动作空间,R代表奖励函数,P代表状态转移概率矩阵,γ代表折扣因子。智能体每个时刻观测到状态st,根据状态执行动作at,环境接收到动作后转移到下一个状态st+1并反馈奖励rt,强化学习优化的目标是最大化累积奖励值
Figure BDA0002802149390000061
智能体根据策略π(at|st)选择动作,动作值函数Q(st,at)代表在状态st执行动作at后的期望累积奖励,
Figure BDA0002802149390000062
元强化学习的优化目标为:
Figure BDA0002802149390000063
其中,τi代表第i个任务,p(τ)代表所有任务的分布,
Figure BDA0002802149390000064
表示使用策略θ在任务τi下获得的期望累积奖励。
离线强化学习中引入了动作约束来限制所学习策略πθ和离线数据的策略πb的差异,由此控制两者采样的数据分布(数据分布与策略相关)。假设D是一种衡量分布差异的函数,比如KL散度(Kullback-Leibler divergence),B是固定的离线数据池,我们定义带约束的动作值函数QD如下:
QD(s,a)=Q(s,a)-αD(πθ(·|s),πb(·|s))。
在强化学习模型中,包括actor(策略网络)和critic(评判网络)。其中,策略网络用于基于概率选择行为,评判网络用于基于策略网络的行为评判该行为的得分(即奖励)。策略网络根据评判网络给出的评分,修改选行为的概率。
我们分别定义critic(评判网络)、actor(策略网络)的损失函数为:
Figure BDA0002802149390000065
Figure BDA0002802149390000066
式中,Lcritic表示评判网络的损失函数,Lactor表示策略网络的损失函数,s表示状态,a表示在状态s下给定的动作,r表示奖励,s′表示执行动作a之后转移至的状态,a′表示在状态s′下给定的动作,a″表示在状态s下基于概率选择的动作,πθ表示所学习策略,γ代表折扣因子,α为模型可调超参数,Q为带约束的动作值函数,D是一种衡量分布差异的函数。
图1示出了本申请一个实施例提供的离线元强化学习模型的模型架构图。如图1所示,在本申请实施例中,离线元强化学习模型可以包括任务推断网络10、策略网络20和评判网络30。
离线数据池40用于提供多个训练样本集,每个训练样本集中包括属于同一任务的多个训练样本。
任务推断网络10用于生成训练样本集的任务表示向量,该任务表示向量用于表征训练样本集所属的任务。
策略网络20用于基于训练样本的状态向量和任务表示向量,生成该训练样本的动作向量。
评判网络30用于基于训练样本的状态向量、动作向量和任务表示向量,生成相应的评分(即奖励)。
任务推断网络10、策略网络20和评判网络30的损失函数,分别以Ldml、Lcritic和Lactor表示。在模型训练过程中,通过计算上述3个损失函数,据此分别对任务推断网络10、策略网络20和评判网络30的参数进行调整,例如以最小化其损失函数值为目标,不断优化各网络的参数,以达到优化训练整个模型的目的。
本申请实施例提供的方法,各步骤的执行主体可以是计算机设备。计算机设备是指具备数据处理、计算和存储能力的电子设备,该计算机设备可以是服务器,也可以是诸如手机、平板电脑、PC(Personal Computer)等终端,本申请实施例对此不作限定。
请参考图2,其示出了本申请一个实施例提供的离线元强化学习模型的训练方法的流程图。该方法可以包括如下几个步骤(201~205):
步骤201,从离线数据池中采样获取多个训练样本集,每个训练样本集中包括属于同一任务的多个训练样本。
在离线数据池中,包括多个不同任务的训练样本,通过采用不同任务的训练样本对模型进行训练,从而实现多任务的元强化学习,提升最终训练出的模型在不同任务间的泛化能力。上述多个不同任务可以是在同一应用场景下服从一定分布的多个任务。以农业种植场景为例,可以从气候、作物品种、温室设施等任意一种因素或多种因素的组合,划分出多个不同任务。例如,以作物品种对不同任务进行划分为例,可以包括西红柿种植任务、黄瓜种植任务、苹果种植任务等多个不同任务。当然,还可以结合多种因素对任务进行划分,例如在夏季气候下种植西红柿、在冬季气候下种植西红柿、在秋季气候下种植黄瓜等多个不同任务。
通过对样本对象的状态信息以及施加给样本对象的动作信息进行记录,可以得到用于构建训练样本的离线数据池。仍然以农业种植场景为例,对于西红柿种植任务来说,样本对象即为西红柿,通过定期对西红柿的状态信息(如温室气候、作物发育情况等信息)和施加给西红柿的动作信息(如温度、湿度、光照、水分、CO2等控制策略)进行采集,并记录在离线数据池中。离线数据池中可以记录多个不同任务的相关数据。通过从离线数据池中采样数据,可以构建用于模型训练的训练样本,训练样本可以包括某一时刻的状态信息,以及针对该状态信息施加的动作信息。
另外,为了实现多任务的元强化学习,从离线数据池中采样获取多个训练样本集,每个训练样本集中包括属于同一任务的多个训练样本。并且,任意两个训练样本集可以对应同一个相同任务,也可以对应两个不同任务。
步骤202,通过任务推断网络生成训练样本集的任务表示向量,该训练样本集的任务表示向量用于表征训练样本集所属的任务。
任务推断网络用于对训练样本集对应的任务进行推断,生成训练样本集的任务表示向量。对于不同的任务,任务推断网络采用不同的任务表示向量来进行表示,从而实现对不同任务的区分。
任务推断网络可以是一个多层神经网络,基于训练样本的与任务相关的信息(如气候、作物品种、温室等信息),生成相应的任务表示向量。经过任务推断网络的映射表征,相同任务对应于同一个任务表示向量,不同任务对应于不同的任务表示向量。
步骤203,基于不同训练样本集的任务表示向量之间的距离度量,确定任务推断网络的损失函数;其中,任务推断网络的损失函数与不同任务间的距离度量负相关。
不同训练样本集的任务表示向量之间的距离度量,用于表征这两个任务表示向量之间的差异程度。在本申请实施例中,对计算两个任务表示向量之间的距离度量的方式不作限定,任何适用于计算两个向量间距离的算法都可适用,如欧式距离、余弦距离、曼哈顿距离、切比雪夫距离等算法。
任务推断网络的损失函数用于对该任务推断网络的表现性能进行评价。在示例性实施例中,假设任务推断网络的损失函数值越小,则表明任务推断网络的表现性能越优;那么,在模型训练过程中,通过不断调整任务推断网络的参数,以最小化其损失函数值,以达到使得任务推断网络不断优化的目的。
在本申请实施例中,任务推断网络的损失函数与不同任务间的距离度量负相关。例如,在一种可能的实现方式中,任务推断网络的损失函数与不同任务间的距离度量的反比相关。通过这种方式,可以使得任务推断网络输出的任务表示向量,对于相同任务输出的任务表示向量尽可能地接近(也即距离度量尽可能地小),而对于不同任务输出的任务表示向量尽可能地分散(也即距离度量尽可能地大),从而实现对不同任务的有效区分,这对于学习不同任务的策略更加有优势。
步骤204,基于训练样本的状态向量、动作向量和任务表示向量,确定策略网络的损失函数和评判网络的损失函数。
策略网络的损失函数用于对该策略网络的表现性能进行评价。在示例性实施例中,假设策略网络的损失函数值越小,则表明策略网络的表现性能越优;那么,在模型训练过程中,通过不断调整策略网络的参数,以最小化其损失函数值,以达到使得策略网络不断优化的目的。
评判网络的损失函数用于对该评判网络的表现性能进行评价。在示例性实施例中,假设评判网络的损失函数值越小,则表明评判网络的表现性能越优;那么,在模型训练过程中,通过不断调整评判网络的参数,以最小化其损失函数值,以达到使得评判网络不断优化的目的。
在本申请实施例中,在计算策略网络和评判网络的损失函数时,除了考虑训练样本的状态向量和动作向量之外,还考虑了训练样本的任务表示向量,从而使得模型能够学习到针对不同任务的策略生成方式和评判方式,提升模型在不同任务上的表现性能。
步骤205,基于任务推断网络的损失函数、策略网络的损失函数和评判网络的损失函数,分别对任务推断网络、策略网络和评判网络的参数进行调整。
在得到上述3个损失函数之后,基于任务推断网络的损失函数对任务推断网络的参数进行调整,基于策略网络的损失函数对策略网络的参数进行调整,基于评判网络的损失函数对评判网络的参数进行调整,例如以最小化其损失函数值为目标,不断优化各网络的参数,以达到优化训练整个模型的目的。
综上所述,本申请实施例提供的技术方案,通过采用离线学习方式对元强化学习模型进行训练,模型训练所需的样本数据来自于离线数据池,从而无需将不成熟的模型应用到在线交互过程中,相比于传统的在线强化学习,提高了安全性及数据利用效率,且有助于降低成本。以自动驾驶场景为例,如果将不成熟的模型应用到在线交互过程中,对真实场景中运行的车辆进行控制容易造成安全问题,对仿真场景下运行的车辆进行控制则需搭建仿真场景,成本过高,且可靠性难以保证。而本申请提供的离线学习方式对元强化学习模型进行训练,则可以克服上述问题。另外,本申请实施例可以重复利用完全离线的数据进行模型训练,无需在线采集数据,更加高效。
另外,还通过采用元强化学习的方式,训练能够完成多个不同任务的模型,从而能够低成本地推广到新的任务,模型的泛化能力显著提升。
另外,任务推断网络的损失函数与不同任务间的距离度量负相关。通过这种方式,可以使得任务推断网络输出的任务表示向量,对于相同任务输出的任务表示向量尽可能地接近,而对于不同任务输出的任务表示向量尽可能地分散,从而实现对不同任务的有效区分,这对于学习不同任务的策略更加有优势。
强化学习适用于序列决策问题,譬如农业中的温室种植场景,作物生长从发芽、生长、开花、结果,遵循一定的时间规律,不同时间段对应的控制策略也有所不同。通常基于深度学习的强化学习算法采用前馈神经网络进行值函数、策略函数的拟合,本申请实施例在其中加入注意力机制,一方面对于时序数据,能更好地捕捉时间维度上各变量的相关性;另一方面也可以在隐空间的度量学习中,关注与任务信息相关性最大的样本,降低混淆样本带来的负面作用。在本申请实施例中,有至少但不限于两处可以引入注意力机制的主要环节,分别是:(1)在进行任务推断得到任务表示向量时,由于稀疏奖励等因素导致的混淆样本的存在,使得每一个批次(batch)中并非所有训练样本都对任务推断有帮助。注意力机制通过将不同训练样本赋予不同权重,帮助任务推断网络着重关注信息量更为丰富的训练样本,从而提升推断效果。且此处注意力机制并不依赖于训练样本的时序性,与本方案中采用的具有顺序不变性的任务推断网络完美兼容。(2)在训练基于任务表示向量的策略网络和评判网络时,由于网络输入为状态向量、动作向量及任务表示向量是3种意义及维度截然不同的变量,普通的函数映射难以捕捉三者间的关联性,而引入注意力机制可以有效地解决这一问题。下面,结合图3实施例,对注意力机制在模型训练过程中的应用进行介绍说明。
请参考图3,其示出了本申请另一个实施例提供的离线元强化学习模型的训练方法的流程图。该方法可以包括如下几个步骤(301~307):
步骤301,从离线数据池中采样获取多个训练样本集,每个训练样本集中包括属于同一任务的多个训练样本。
步骤302,基于注意力机制,生成训练样本集中各个训练样本分别对应的权重;其中,训练样本对应的权重用于表示训练样本对任务表示的重要程度。
步骤303,通过任务推断网络基于训练样本集中各个训练样本分别对应的权重,生成训练样本集的任务表示向量。
任务推断网络会基于属于同一任务的一批样本(如一个训练样本集),生成相应的任务表示向量。考虑到同一训练样本集中的不同训练样本,对于任务表示的重要程度可能会有所不同,即有的训练样本对任务表示提供的信息量更多更有效,有的训练样本对任务表示提供的信息量更少,因此通过引入注意力机制,给不同训练样本赋予不同权重,使得任务推断网络在任务表示时对各个样本进行有针对性的侧重和取舍,有助于提升最终得到的任务表示向量的准确性。
步骤304,基于不同训练样本集的任务表示向量之间的距离度量,确定任务推断网络的损失函数;其中,任务推断网络的损失函数与不同任务间的距离度量负相关。
可选地,对于第一训练样本集和第二训练样本集,确定第一训练样本集的任务表示向量和第二训练样本集的任务表示向量之间的距离度量;基于该距离度量确定任务推断网络的损失函数。其中,在第一训练样本集和第二训练样本集对应相同任务的情况下,距离度量与任务推断网络的损失函数呈正相关关系;在第一训练样本集和第二训练样本集对应不同任务的情况下,距离度量与任务推断网络的损失函数呈负相关关系。
这里先介绍一种对比方案,任务推断网络的损失函数与不同任务间的距离度量线性相关。例如,任务推断网络的损失函数
Figure BDA0002802149390000111
的计算公式如下:
Figure BDA0002802149390000112
其中,xi和xj是两个训练样本集的数据,yi和yj是两个训练样本集的标签,qi和qj是两个训练样本集的任务向量表示,m是调整任务向量表示间距离度量的参数。在该对比方案中,通过最小化
Figure BDA0002802149390000113
就能够学习到相同任务的任务向量表示接近、不同任务的任务向量表示分散的表示。但是,经过实验发现,该对比方案还是不够好,很多情况下无法有效地将不同任务分隔开来。
本申请提供的技术方案,任务推断网络的损失函数与不同任务间的距离度量负相关。例如,任务推断网络的损失函数Ldml的计算公式如下:
Figure BDA0002802149390000121
其中,xi和xj是两个训练样本集的数据,yi和yj是两个训练样本集的标签,qi和qj是两个训练样本集的任务向量表示,β是比例系数,∈是预设常数,用于防止分母为0,n的取值可以是大于等于1。例如,n的取值可以是1、2、3等。
通过实验对本申请技术方案和对比方案进行比较,得到的比较结果如图4所示。在图4中,展示了20个不同任务的数据在隐空间上的二维投影(即任务向量表示)。本申请技术方案得到的实验结果是图4中(a)和(b)部分所示。图4中(a)部分对应上述n的取值为2,任务推断网络的损失函数与距离度量的平方反比相关。图4中(b)部分对应上述n的取值为1,任务推断网络的损失函数与距离度量的反比相关。对比方案得到的实验结果是图4中(c)和(d)部分所示。图4中(c)部分,任务推断网络的损失函数与距离度量线性相关。图4中(d)部分,任务推断网络的损失函数与距离度量的平方线性相关。在图4中,不同的点代表不同任务的样本,从图中可以很明显地看出,本申请技术方案相比于对比方案,不同任务的任务向量表示更为分散,这对于学习不同任务的策略更加有优势。
步骤305,基于自注意力机制,为训练样本的状态向量、动作向量和任务表示向量分别赋予对应的权重。
步骤306,基于训练样本的状态向量、动作向量和任务表示向量以及分别对应的权重,确定策略网络的损失函数和评判网络的损失函数。
在训练基于任务表示向量的策略网络和评判网络时,由于网络输入为状态向量、动作向量及任务表示向量是3种意义及维度截然不同的变量,普通的函数映射难以捕捉三者间的关联性。在本申请实施例中,通过引入自注意力机制,为上述3种不同向量分别赋予对应的权重,从而能够灵活调节各种向量对损失函数的影响程度,提升损失函数计算的准确性。
步骤307,基于任务推断网络的损失函数、策略网络的损失函数和评判网络的损失函数,分别对任务推断网络、策略网络和评判网络的参数进行调整。
参考图5,其示出了引入注意力机制后的模型架构图。离线元强化学习模型可以包括任务推断网络10、策略网络20和评判网络30。
离线数据池40用于提供多个训练样本集,每个训练样本集中包括属于同一任务的多个训练样本。
基于注意力机制,生成训练样本集中各个训练样本分别对应的权重。任务推断网络10用于基于训练样本集中各个训练样本分别对应的权重,生成训练样本集的任务表示向量。
策略网络20用于基于训练样本的状态向量和任务表示向量,生成该训练样本的动作向量。
评判网络30用于基于训练样本的状态向量、动作向量和任务表示向量,生成相应的评分(即奖励)。
任务推断网络10、策略网络20和评判网络30的损失函数,分别以Ldml、Lcritic和Lactor表示。在模型训练过程中,基于不同训练样本集的任务表示向量之间的距离度量,确定任务推断网络10的损失函数。基于自注意力机制,为训练样本的状态向量、动作向量和任务表示向量分别赋予对应的权重,基于训练样本的状态向量、动作向量和任务表示向量以及分别对应的权重,确定策略网络20的损失函数和评判网络30的损失函数。根据上述3个损失函数,分别对任务推断网络10、策略网络20和评判网络30的参数进行调整,例如以最小化其损失函数值为目标,不断优化各网络的参数,以达到优化训练整个模型的目的。
在其他可能的实施例中,考虑到训练样本的状态向量通常是包括多个维度的状态向量。以农业种植场景为例,农作物的状态向量可以包括多种温室气候条件(如温度、湿度、光照等)以及多种发育情况指标(如重量、长度等)等多个维度的状态向量。如图6所示,策略网络在基于状态向量生成动作向量的过程中,可以引入多视角注意力机制。首先,提取多个维度的状态向量分别对应的特征信息,得到多个维度的特征信息;然后,基于多视角注意力机制,为多个维度的特征信息分别赋予对应的权重;最后,基于多个维度的特征信息以及分别对应的权重,生成训练样本的动作向量。通过上述方式,针对不同维度的状态向量分别赋予不同的权重,从而能够灵活调节各维度状态向量对动作向量生成的影响程度,提升最终生成的动作向量的准确性。
需要说明的是,在本实施例中,仅示例性给出了几个例子,以说明注意力机制在本申请提供的离线元强化学习方案中的应用,在实际应用中,可以结合实际需求在合适的步骤中选择使用注意力机制,本申请实施例对此不作限定。
综上所述,本申请实施例提供的技术方案,通过引入注意力机制,一方面对于时序数据,能更好地捕捉时间维度上各变量的相关性;另一方面也可以在隐空间的度量学习中,关注与任务信息相关性最大的样本,降低混淆样本带来的负面作用。
在示例性实施例中,还可以执行如下步骤对训练好的离线元强化学习模型进行测试:
1、获取测试样本;
2、通过任务推断网络生成测试样本的任务表示向量;
3、通过策略网络基于测试样本的状态向量和任务表示向量,生成测试样本的动作向量,测试样本的动作向量用于指导在测试样本的状态向量所表征的状态下,对测试样本执行的动作策略。
例如,在策略网络用于在多任务农业种植场景下为农作物提供种植策略的情况下,通过策略网络基于农作物的状态向量和任务表示向量,生成农作物的动作向量,如该农作物自动控制相关的动作向量。该农作物的动作向量用于指导在农作物的状态向量所表征的状态下,对农作物执行的种植策略。其中,农作物的状态向量可以是其状态信息(如温室气候、作物发育情况等信息)的向量表示,农作物的动作向量可以是其种植策略(如温度、湿度、光照、水分、CO2等控制策略)的向量表示。
又例如,在策略网络用于在多任务自动驾驶场景下为自动驾驶车辆提供车辆控制策略的情况下,通过策略网络基于自动驾驶车辆的状态向量和任务表示向量,生成自动驾驶车辆的动作向量,如该自动驾驶车辆自动控制相关的动作向量。该自动驾驶车辆的动作向量用于指导在自动驾驶车辆的状态向量所表征的状态下,对自动驾驶车辆执行的车辆控制策略。其中,自动驾驶车辆的状态向量可以是其状态信息(如道路状况、车辆状况等信息)的向量表示,自动驾驶车辆的动作向量可以是其车辆控制策略(如车速、方向、灯光等控制策略)的向量表示。
再例如,在策略网络用于在多任务机器人控制场景下为机器人提供操作控制策略的情况下,通过策略网络基于机器人的状态向量和任务表示向量,生成机器人的动作向量,如该机器人自动控制相关的动作向量。该机器人的动作向量用于指导在机器人的状态向量所表征的状态下,对机器人执行的操作控制策略。其中,机器人的状态向量可以是其状态信息(如所处环境状况、机器人自身属性等信息)的向量表示,机器人的动作向量可以是其操作控制策略(如移动、抓取、放置等控制策略)的向量表示。
下述为本申请装置实施例,可以用于执行本申请方法实施例。对于本申请装置实施例中未披露的细节,请参照本申请方法实施例。
请参考图7,其示出了本申请一个实施例提供的离线元强化学习模型的训练装置的框图。该装置具有实现上述方法示例的功能,所述功能可以由硬件实现,也可以由硬件执行相应的软件实现。该装置可以是计算机设备,也可以设置在计算机设备中。该装置700可以包括:离线采样模块710、任务表征模块720、损失计算模块730和参数调整模块740。
离线采样模块710,用于从离线数据池中采样获取多个训练样本集,每个训练样本集中包括属于同一任务的多个训练样本。
任务表征模块720,用于通过所述任务推断网络生成所述训练样本集的任务表示向量,所述训练样本集的任务表示向量用于表征所述训练样本集所属的任务。
损失计算模块730,用于基于不同训练样本集的任务表示向量之间的距离度量,确定所述任务推断网络的损失函数;其中,所述任务推断网络的损失函数与不同任务间的距离度量负相关。
所述损失计算模块730,还用于基于所述训练样本的状态向量、动作向量和任务表示向量,确定所述策略网络的损失函数和所述评判网络的损失函数。
参数调整模块740,用于基于所述任务推断网络的损失函数、所述策略网络的损失函数和所述评判网络的损失函数,分别对所述任务推断网络、所述策略网络和所述评判网络的参数进行调整。
在示例性实施例中,所述损失计算模块730,用于:
对于第一训练样本集和第二训练样本集,确定所述第一训练样本集的任务表示向量和所述第二训练样本集的任务表示向量之间的距离度量;
基于所述距离度量确定所述任务推断网络的损失函数;其中,在所述第一训练样本集和所述第二训练样本集对应相同任务的情况下,所述距离度量与所述任务推断网络的损失函数呈正相关关系;在所述第一训练样本集和所述第二训练样本集对应不同任务的情况下,所述距离度量与所述任务推断网络的损失函数呈负相关关系。
在示例性实施例中,所述任务表征模块720,用于:
基于注意力机制,生成所述训练样本集中各个训练样本分别对应的权重;其中,所述训练样本对应的权重用于表示所述训练样本对任务表示的重要程度;
通过所述任务推断网络基于所述训练样本集中各个训练样本分别对应的权重,生成所述训练样本集的任务表示向量。
在示例性实施例中,所述损失计算模块730,用于:
基于自注意力机制,为所述训练样本的状态向量、动作向量和任务表示向量分别赋予对应的权重;
基于所述训练样本的状态向量、动作向量和任务表示向量以及分别对应的权重,确定所述策略网络的损失函数和所述评判网络的损失函数。
在示例性实施例中,所述训练样本的状态向量包括多个维度的状态向量;所述策略网络,用于:
提取所述多个维度的状态向量分别对应的特征信息,得到多个维度的特征信息;
基于多视角注意力机制,为所述多个维度的特征信息分别赋予对应的权重;
基于所述多个维度的特征信息以及分别对应的权重,生成所述训练样本的动作向量。
在示例性实施例中,所述装置700还包括模型测试模块,用于:
获取测试样本;
通过所述任务推断网络生成所述测试样本的任务表示向量;
通过所述策略网络基于所述测试样本的状态向量和任务表示向量,生成所述测试样本的动作向量,所述测试样本的动作向量用于指导在所述测试样本的状态向量所表征的状态下,对所述测试样本执行的动作策略。
可选地,所述模型测试模块,用于:
在所述策略网络用于在多任务农业种植场景下为农作物提供种植策略的情况下,通过所述策略网络基于农作物的状态向量和任务表示向量,生成所述农作物的动作向量,所述农作物的动作向量用于指导在所述农作物的状态向量所表征的状态下,对所述农作物执行的种植策略;
或者,在所述策略网络用于在多任务自动驾驶场景下为自动驾驶车辆提供车辆控制策略的情况下,通过所述策略网络基于所述自动驾驶车辆的状态向量和任务表示向量,生成所述自动驾驶车辆的动作向量,所述自动驾驶车辆的动作向量用于指导在所述自动驾驶车辆的状态向量所表征的状态下,对所述自动驾驶车辆执行的车辆控制策略;
或者,在所述策略网络用于在多任务机器人控制场景下为机器人提供操作控制策略的情况下,通过所述策略网络基于所述机器人的状态向量和任务表示向量,生成所述机器人的动作向量,所述机器人的动作向量用于指导在所述机器人的状态向量所表征的状态下,对所述机器人执行的操作控制策略。
综上所述,本申请实施例提供的技术方案,通过采用离线学习方式对元强化学习模型进行训练,模型训练所需的样本数据来自于离线数据池,从而无需将不成熟的模型应用到在线交互过程中,相比于传统的在线强化学习,提高了安全性及数据利用效率,且有助于降低成本。以自动驾驶场景为例,如果将不成熟的模型应用到在线交互过程中,对真实场景中运行的车辆进行控制容易造成安全问题,对仿真场景下运行的车辆进行控制则需搭建仿真场景,成本过高,且可靠性难以保证。而本申请提供的离线学习方式对元强化学习模型进行训练,则可以克服上述问题。另外,本申请实施例可以重复利用完全离线的数据进行模型训练,无需在线采集数据,更加高效。
另外,还通过采用元强化学习的方式,训练能够完成多个不同任务的模型,从而能够低成本地推广到新的任务,模型的泛化能力显著提升。
另外,任务推断网络的损失函数与不同任务间的距离度量负相关。通过这种方式,可以使得任务推断网络输出的任务表示向量,对于相同任务输出的任务表示向量尽可能地接近,而对于不同任务输出的任务表示向量尽可能地分散,从而实现对不同任务的有效区分,这对于学习不同任务的策略更加有优势。
需要说明的是,上述实施例提供的装置,在实现其功能时,仅以上述各功能模块的划分进行举例说明,实际应用中,可以根据需要而将上述功能分配由不同的功能模块完成,即将设备的内部结构划分成不同的功能模块,以完成以上描述的全部或者部分功能。另外,上述实施例提供的装置与方法实施例属于同一构思,其具体实现过程详见方法实施例,这里不再赘述。
请参考图8,其示出了本申请一个实施例提供的计算机设备的结构示意图。该计算机设备可以是任何具备数据计算、处理和存储功能的电子设备,如PC(Personal Computer,个人计算机)或服务器。该计算机设备用于实施上述实施例中提供的离线元强化学习模型的训练方法。具体来讲:
该计算机设备800包括处理单元(如CPU(Central Processing Unit,中央处理器)、GPU(Graphics Processing Unit,图形处理器)和FPGA(Field Programmable GateArray,现场可编程逻辑门阵列)等)801、包括RAM(Random-Access Memory,随机存储器)802和ROM(Read-Only Memory,只读存储器)803的系统存储器804,以及连接系统存储器804和中央处理单元801的系统总线805。该计算机设备800还包括帮助服务器内的各个器件之间传输信息的基本输入/输出系统(Input Output System,I/O系统)806,和用于存储操作系统813、应用程序814和其他程序模块815的大容量存储设备807。
该基本输入/输出系统806包括有用于显示信息的显示器808和用于用户输入信息的诸如鼠标、键盘之类的输入设备809。其中,该显示器808和输入设备809都通过连接到系统总线805的输入输出控制器810连接到中央处理单元801。该基本输入/输出系统806还可以包括输入输出控制器810以用于接收和处理来自键盘、鼠标、或电子触控笔等多个其他设备的输入。类似地,输入输出控制器810还提供输出到显示屏、打印机或其他类型的输出设备。
该大容量存储设备807通过连接到系统总线805的大容量存储控制器(未示出)连接到中央处理单元801。该大容量存储设备807及其相关联的计算机可读介质为计算机设备800提供非易失性存储。也就是说,该大容量存储设备807可以包括诸如硬盘或者CD-ROM(Compact Disc Read-Only Memory,只读光盘)驱动器之类的计算机可读介质(未示出)。
不失一般性,该计算机可读介质可以包括计算机存储介质和通信介质。计算机存储介质包括以用于存储诸如计算机可读指令、数据结构、程序模块或其他数据等信息的任何方法或技术实现的易失性和非易失性、可移动和不可移动介质。计算机存储介质包括RAM、ROM、EPROM(Erasable Programmable Read-Only Memory,可擦写可编程只读存储器)、EEPROM(Electrically Erasable Programmable Read-Only Memory,电可擦写可编程只读存储器)、闪存或其他固态存储其技术,CD-ROM、DVD(Digital Video Disc,高密度数字视频光盘)或其他光学存储、磁带盒、磁带、磁盘存储或其他磁性存储设备。当然,本领域技术人员可知该计算机存储介质不局限于上述几种。上述的系统存储器804和大容量存储设备807可以统称为存储器。
根据本申请实施例,该计算机设备800还可以通过诸如因特网等网络连接到网络上的远程计算机运行。也即计算机设备800可以通过连接在该系统总线805上的网络接口单元811连接到网络812,或者说,也可以使用网络接口单元811来连接到其他类型的网络或远程计算机系统(未示出)。
所述存储器还包括至少一条指令、至少一段程序、代码集或指令集,该至少一条指令、至少一段程序、代码集或指令集存储于存储器中,且经配置以由一个或者一个以上处理器执行,以实现上述离线元强化学习模型的训练方法。
在示例性实施例中,还提供了一种计算机可读存储介质,所述存储介质中存储有至少一条指令、至少一段程序、代码集或指令集,所述至少一条指令、所述至少一段程序、所述代码集或所述指令集在被计算机设备的处理器执行时实现上述实施例提供的离线元强化学习模型的训练方法。
可选地,该计算机可读存储介质可以包括:ROM(Read-Only Memory,只读存储器)、RAM(Random-Access Memory,随机存储器)、SSD(Solid State Drives,固态硬盘)或光盘等。其中,随机存取记忆体可以包括ReRAM(Resistance Random Access Memory,电阻式随机存取记忆体)和DRAM(Dynamic Random Access Memory,动态随机存取存储器)。
在示例性实施例中,还提供了一种计算机程序产品或计算机程序,所述计算机程序产品或计算机程序包括计算机指令,所述计算机指令存储在计算机可读存储介质中。计算机设备的处理器从所述计算机可读存储介质中读取所述计算机指令,所述处理器执行所述计算机指令,使得所述计算机设备执行上述离线元强化学习模型的训练方法。
应当理解的是,在本文中提及的“多个”是指两个或两个以上。“和/或”,描述关联对象的关联关系,表示可以存在三种关系,例如,A和/或B,可以表示:单独存在A,同时存在A和B,单独存在B这三种情况。字符“/”一般表示前后关联对象是一种“或”的关系。另外,本文中描述的步骤编号,仅示例性示出了步骤间的一种可能的执行先后顺序,在一些其它实施例中,上述步骤也可以不按照编号顺序来执行,如两个不同编号的步骤同时执行,或者两个不同编号的步骤按照与图示相反的顺序执行,本申请实施例对此不作限定。
以上所述仅为本申请的示例性实施例,并不用以限制本申请,凡在本申请的精神和原则之内,所作的任何修改、等同替换、改进等,均应包含在本申请的保护范围之内。

Claims (7)

1.一种离线元强化学习模型的训练方法,其特征在于,所述离线元强化学习模型包括任务推断网络、策略网络和评判网络;所述方法包括:
从离线数据池中采样获取多个训练样本集,每个训练样本集中包括属于同一任务的多个训练样本;
基于注意力机制,生成所述训练样本集中各个训练样本分别对应的权重;其中,所述训练样本对应的权重用于表示所述训练样本对任务表示的重要程度;
通过所述任务推断网络基于所述训练样本集中各个训练样本分别对应的权重,生成所述训练样本集的任务表示向量,所述训练样本集的任务表示向量用于表征所述训练样本集所属的任务;
基于不同训练样本集的任务表示向量之间的距离度量,确定所述任务推断网络的损失函数;其中,所述任务推断网络的损失函数与不同任务间的距离度量负相关;
通过所述策略网络提取多个维度的状态向量分别对应的特征信息,得到多个维度的特征信息,所述训练样本的状态向量包括所述多个维度的状态向量;基于多视角注意力机制,为所述多个维度的特征信息分别赋予对应的权重;基于所述多个维度的特征信息以及分别对应的权重,生成所述训练样本的动作向量;
基于自注意力机制,为所述训练样本的状态向量、动作向量和任务表示向量分别赋予对应的权重;
基于所述训练样本的状态向量、动作向量和任务表示向量以及分别对应的权重,确定所述策略网络的损失函数和所述评判网络的损失函数;
基于所述任务推断网络的损失函数、所述策略网络的损失函数和所述评判网络的损失函数,分别对所述任务推断网络、所述策略网络和所述评判网络的参数进行调整;
所述方法还包括:
获取测试样本;
通过所述任务推断网络生成所述测试样本的任务表示向量;
在所述策略网络用于在多任务农业种植场景下为农作物提供种植策略的情况下,通过所述策略网络基于农作物的状态向量和任务表示向量,生成所述农作物的动作向量,所述农作物的动作向量用于指导在所述农作物的状态向量所表征的状态下,对所述农作物执行的种植策略;或者,在所述策略网络用于在多任务自动驾驶场景下为自动驾驶车辆提供车辆控制策略的情况下,通过所述策略网络基于所述自动驾驶车辆的状态向量和任务表示向量,生成所述自动驾驶车辆的动作向量,所述自动驾驶车辆的动作向量用于指导在所述自动驾驶车辆的状态向量所表征的状态下,对所述自动驾驶车辆执行的车辆控制策略;或者,在所述策略网络用于在多任务机器人控制场景下为机器人提供操作控制策略的情况下,通过所述策略网络基于所述机器人的状态向量和任务表示向量,生成所述机器人的动作向量,所述机器人的动作向量用于指导在所述机器人的状态向量所表征的状态下,对所述机器人执行的操作控制策略;
其中,所述测试样本的动作向量用于指导在所述测试样本的状态向量所表征的状态下,对所述测试样本执行的动作策略。
2.根据权利要求1所述的方法,其特征在于,所述基于不同训练样本集的任务表示向量之间的距离度量,确定所述任务推断网络的损失函数,包括:
对于第一训练样本集和第二训练样本集,确定所述第一训练样本集的任务表示向量和所述第二训练样本集的任务表示向量之间的距离度量;
基于所述距离度量确定所述任务推断网络的损失函数;其中,在所述第一训练样本集和所述第二训练样本集对应相同任务的情况下,所述距离度量与所述任务推断网络的损失函数呈正相关关系;在所述第一训练样本集和所述第二训练样本集对应不同任务的情况下,所述距离度量与所述任务推断网络的损失函数呈负相关关系。
3.一种离线元强化学习模型的训练装置,其特征在于,所述离线元强化学习模型包括任务推断网络、策略网络和评判网络;所述装置包括:
离线采样模块,用于从离线数据池中采样获取多个训练样本集,每个训练样本集中包括属于同一任务的多个训练样本;
任务表征模块,用于基于注意力机制,生成所述训练样本集中各个训练样本分别对应的权重;其中,所述训练样本对应的权重用于表示所述训练样本对任务表示的重要程度;通过所述任务推断网络基于所述训练样本集中各个训练样本分别对应的权重,生成所述训练样本集的任务表示向量,所述训练样本集的任务表示向量用于表征所述训练样本集所属的任务;
所述策略网络,用于提取多个维度的状态向量分别对应的特征信息,得到多个维度的特征信息,所述训练样本的状态向量包括所述多个维度的状态向量;基于多视角注意力机制,为所述多个维度的特征信息分别赋予对应的权重;基于所述多个维度的特征信息以及分别对应的权重,生成所述训练样本的动作向量;
损失计算模块,用于基于自注意力机制,为所述训练样本的状态向量、动作向量和任务表示向量分别赋予对应的权重;
基于所述训练样本的状态向量、动作向量和任务表示向量以及分别对应的权重,确定所述策略网络的损失函数和所述评判网络的损失函数;
所述损失计算模块,还用于基于所述训练样本的状态向量、动作向量和任务表示向量,确定所述策略网络的损失函数和所述评判网络的损失函数;
参数调整模块,用于基于所述任务推断网络的损失函数、所述策略网络的损失函数和所述评判网络的损失函数,分别对所述任务推断网络、所述策略网络和所述评判网络的参数进行调整;
所述装置还包括模型测试模块,用于:
获取测试样本;
通过所述任务推断网络生成所述测试样本的任务表示向量;
在所述策略网络用于在多任务农业种植场景下为农作物提供种植策略的情况下,通过所述策略网络基于农作物的状态向量和任务表示向量,生成所述农作物的动作向量,所述农作物的动作向量用于指导在所述农作物的状态向量所表征的状态下,对所述农作物执行的种植策略;或者,在所述策略网络用于在多任务自动驾驶场景下为自动驾驶车辆提供车辆控制策略的情况下,通过所述策略网络基于所述自动驾驶车辆的状态向量和任务表示向量,生成所述自动驾驶车辆的动作向量,所述自动驾驶车辆的动作向量用于指导在所述自动驾驶车辆的状态向量所表征的状态下,对所述自动驾驶车辆执行的车辆控制策略;或者,在所述策略网络用于在多任务机器人控制场景下为机器人提供操作控制策略的情况下,通过所述策略网络基于所述机器人的状态向量和任务表示向量,生成所述机器人的动作向量,所述机器人的动作向量用于指导在所述机器人的状态向量所表征的状态下,对所述机器人执行的操作控制策略;
其中,所述测试样本的动作向量用于指导在所述测试样本的状态向量所表征的状态下,对所述测试样本执行的动作策略。
4.根据权利要求3所述的装置,其特征在于,所述损失计算模块,用于:
对于第一训练样本集和第二训练样本集,确定所述第一训练样本集的任务表示向量和所述第二训练样本集的任务表示向量之间的距离度量;
基于所述距离度量确定所述任务推断网络的损失函数;其中,在所述第一训练样本集和所述第二训练样本集对应相同任务的情况下,所述距离度量与所述任务推断网络的损失函数呈正相关关系;在所述第一训练样本集和所述第二训练样本集对应不同任务的情况下,所述距离度量与所述任务推断网络的损失函数呈负相关关系。
5.一种计算机设备,其特征在于,所述计算机设备包括处理器和存储器,所述存储器中存储有至少一条指令、至少一段程序、代码集或指令集,所述至少一条指令、所述至少一段程序、所述代码集或指令集由所述处理器加载并执行以实现如权利要求1或2所述的方法。
6.一种计算机可读存储介质,其特征在于,所述存储介质中存储有至少一条指令、至少一段程序、代码集或指令集,所述至少一条指令、所述至少一段程序、所述代码集或指令集由处理器加载并执行以实现如权利要求1或2所述的方法。
7.一种计算机程序产品,其特征在于,所述计算机程序产品包括计算机程序,所述计算机程序存储在计算机可读存储介质中,处理器从所述计算机可读存储介质读取并执行计算机指令,以实现如上述权利要求1或2所述的方法。
CN202011354318.1A 2020-11-27 2020-11-27 离线元强化学习模型的训练方法、装置、设备及存储介质 Active CN112348113B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202011354318.1A CN112348113B (zh) 2020-11-27 2020-11-27 离线元强化学习模型的训练方法、装置、设备及存储介质

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202011354318.1A CN112348113B (zh) 2020-11-27 2020-11-27 离线元强化学习模型的训练方法、装置、设备及存储介质

Publications (2)

Publication Number Publication Date
CN112348113A CN112348113A (zh) 2021-02-09
CN112348113B true CN112348113B (zh) 2022-11-18

Family

ID=74365000

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202011354318.1A Active CN112348113B (zh) 2020-11-27 2020-11-27 离线元强化学习模型的训练方法、装置、设备及存储介质

Country Status (1)

Country Link
CN (1) CN112348113B (zh)

Families Citing this family (7)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113044064B (zh) * 2021-04-01 2022-07-29 南京大学 基于元强化学习的车辆自适应的自动驾驶决策方法及系统
CN113759709A (zh) * 2021-06-02 2021-12-07 京东城市(北京)数字科技有限公司 策略模型的训练方法、装置、电子设备和存储介质
CN113435935B (zh) * 2021-07-02 2022-06-28 支付宝(杭州)信息技术有限公司 权益推送的方法及装置
CN113759724B (zh) * 2021-09-17 2023-08-15 中国人民解放军国防科技大学 基于数据驱动的机器人控制方法、装置和计算机设备
CN114004233B (zh) * 2021-12-30 2022-05-06 之江实验室 一种基于半训练和句子选择的远程监督命名实体识别方法
CN116983656B (zh) * 2023-09-28 2023-12-26 腾讯科技(深圳)有限公司 决策模型的训练方法、装置、设备及存储介质
CN117056866B (zh) * 2023-10-12 2024-01-30 贵州新思维科技有限责任公司 一种多源特征数据融合的隧道智能调光方法及系统

Citations (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN109901572A (zh) * 2018-12-13 2019-06-18 华为技术有限公司 自动驾驶方法、训练方法及相关装置
CN111260026A (zh) * 2020-01-10 2020-06-09 电子科技大学 一种基于元强化学习的导航迁移方法

Family Cites Families (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US11709495B2 (en) * 2019-03-29 2023-07-25 SafeAI, Inc. Systems and methods for transfer of material using autonomous machines with reinforcement learning and visual servo control
CN110322017A (zh) * 2019-08-13 2019-10-11 吉林大学 基于深度强化学习的自动驾驶智能车轨迹跟踪控制策略
CN110587606B (zh) * 2019-09-18 2020-11-20 中国人民解放军国防科技大学 一种面向开放场景的多机器人自主协同搜救方法

Patent Citations (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN109901572A (zh) * 2018-12-13 2019-06-18 华为技术有限公司 自动驾驶方法、训练方法及相关装置
WO2020119363A1 (zh) * 2018-12-13 2020-06-18 华为技术有限公司 自动驾驶方法、训练方法及相关装置
CN111260026A (zh) * 2020-01-10 2020-06-09 电子科技大学 一种基于元强化学习的导航迁移方法

Non-Patent Citations (1)

* Cited by examiner, † Cited by third party
Title
《基于学习的乒乓球机器人回球决策》;金礼森;《中国优秀硕士学位论文全文数据库 信息科技辑》;20190815(第2019年第08期);全文 *

Also Published As

Publication number Publication date
CN112348113A (zh) 2021-02-09

Similar Documents

Publication Publication Date Title
CN112348113B (zh) 离线元强化学习模型的训练方法、装置、设备及存储介质
CN110119844B (zh) 引入情绪调控机制的机器人运动决策方法、系统、装置
CN110019151B (zh) 数据库性能调整方法、装置、设备、系统及存储介质
US20190061147A1 (en) Methods and Apparatus for Pruning Experience Memories for Deep Neural Network-Based Q-Learning
US20200104717A1 (en) Systems and methods for neural network pruning with accuracy preservation
Mnih et al. Human-level control through deep reinforcement learning
CN110134697B (zh) 一种面向键值对存储引擎的参数自动调优方法、装置、系统
CN112052948B (zh) 一种网络模型压缩方法、装置、存储介质和电子设备
Vezzani et al. Learning latent state representation for speeding up exploration
CN112434791A (zh) 多智能体强对抗仿真方法、装置及电子设备
Yan et al. Locating and navigation mechanism based on place-cell and grid-cell models
US20210150371A1 (en) Automatic multi-objective hardware optimization for processing of deep learning networks
CN114282741A (zh) 任务决策方法、装置、设备及存储介质
CN116510302A (zh) 虚拟对象异常行为的分析方法、装置及电子设备
KR102597184B1 (ko) 가지치기 기반 심층 신경망 경량화에 특화된 지식 증류 방법 및 시스템
CN114840024A (zh) 基于情景记忆的无人机控制决策方法
CN114549516A (zh) 一种应用于多种类高密度极小虫体行为学的智能分析系统
CN114511078A (zh) 基于多策略麻雀搜索算法的bp神经网络预测方法及装置
Kanakis Designing Efficient Deep Neural Networks: Topological Optimization, Quantization and Multi-Task Learning
Spears et al. Scale-invariant temporal history (sith): optimal slicing of the past in an uncertain world
García-Ramírez et al. Model Compression for Deep Reinforcement Learning Through Mutual Information
CN113537318B (zh) 一种仿人脑记忆机理的机器人行为决策方法及设备
US20240028902A1 (en) Learning apparatus and method
CN115115057A (zh) 可持续学习模型的训练方法、装置、设备及存储介质
CN112905013B (zh) 智能体控制方法、装置、计算机设备和存储介质

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
REG Reference to a national code

Ref country code: HK

Ref legal event code: DE

Ref document number: 40037972

Country of ref document: HK

GR01 Patent grant
GR01 Patent grant