CN117474075A - 一种基于扩散模型的多任务策略学习方法 - Google Patents
一种基于扩散模型的多任务策略学习方法 Download PDFInfo
- Publication number
- CN117474075A CN117474075A CN202310680335.1A CN202310680335A CN117474075A CN 117474075 A CN117474075 A CN 117474075A CN 202310680335 A CN202310680335 A CN 202310680335A CN 117474075 A CN117474075 A CN 117474075A
- Authority
- CN
- China
- Prior art keywords
- task
- sequence
- track
- multitasking
- prompt
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000009792 diffusion process Methods 0.000 title claims abstract description 82
- 238000000034 method Methods 0.000 title claims abstract description 74
- 238000012549 training Methods 0.000 claims abstract description 33
- 230000002787 reinforcement Effects 0.000 claims abstract description 15
- 230000009471 action Effects 0.000 claims description 58
- 230000006870 function Effects 0.000 claims description 28
- 230000008569 process Effects 0.000 claims description 21
- 238000005070 sampling Methods 0.000 claims description 17
- 230000007704 transition Effects 0.000 claims description 8
- 238000010606 normalization Methods 0.000 claims description 4
- 238000004590 computer program Methods 0.000 claims description 3
- 238000009826 distribution Methods 0.000 abstract description 15
- 238000000605 extraction Methods 0.000 abstract description 3
- 230000000875 corresponding effect Effects 0.000 description 13
- 239000003795 chemical substances by application Substances 0.000 description 7
- 238000012804 iterative process Methods 0.000 description 4
- 238000004422 calculation algorithm Methods 0.000 description 3
- 238000004364 calculation method Methods 0.000 description 3
- 238000012545 processing Methods 0.000 description 3
- 238000012360 testing method Methods 0.000 description 3
- 238000012546 transfer Methods 0.000 description 3
- 238000013461 design Methods 0.000 description 2
- 230000000694 effects Effects 0.000 description 2
- 230000001965 increasing effect Effects 0.000 description 2
- 230000003993 interaction Effects 0.000 description 2
- 238000013507 mapping Methods 0.000 description 2
- LJROKJGQSPMTKB-UHFFFAOYSA-N 4-[(4-hydroxyphenyl)-pyridin-2-ylmethyl]phenol Chemical compound C1=CC(O)=CC=C1C(C=1N=CC=CC=1)C1=CC=C(O)C=C1 LJROKJGQSPMTKB-UHFFFAOYSA-N 0.000 description 1
- 102100033814 Alanine aminotransferase 2 Human genes 0.000 description 1
- 101710096000 Alanine aminotransferase 2 Proteins 0.000 description 1
- 230000010391 action planning Effects 0.000 description 1
- 238000013528 artificial neural network Methods 0.000 description 1
- 230000006399 behavior Effects 0.000 description 1
- 230000008901 benefit Effects 0.000 description 1
- 238000006243 chemical reaction Methods 0.000 description 1
- 230000001276 controlling effect Effects 0.000 description 1
- 238000010586 diagram Methods 0.000 description 1
- 230000002708 enhancing effect Effects 0.000 description 1
- 230000007613 environmental effect Effects 0.000 description 1
- 238000002474 experimental method Methods 0.000 description 1
- 230000008676 import Effects 0.000 description 1
- 230000006872 improvement Effects 0.000 description 1
- 230000002452 interceptive effect Effects 0.000 description 1
- 230000007246 mechanism Effects 0.000 description 1
- 230000015654 memory Effects 0.000 description 1
- 238000005457 optimization Methods 0.000 description 1
- 230000002688 persistence Effects 0.000 description 1
- 230000003014 reinforcing effect Effects 0.000 description 1
- 238000011160 research Methods 0.000 description 1
- 238000004088 simulation Methods 0.000 description 1
- 230000003068 static effect Effects 0.000 description 1
- MYVIATVLJGTBFV-UHFFFAOYSA-M thiamine(1+) chloride Chemical compound [Cl-].CC1=C(CCO)SC=[N+]1CC1=CN=C(C)N=C1N MYVIATVLJGTBFV-UHFFFAOYSA-M 0.000 description 1
- 238000009827 uniform distribution Methods 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/092—Reinforcement learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/047—Probabilistic or stochastic networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/082—Learning methods modifying the architecture, e.g. adding, deleting or silencing nodes or connections
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N7/00—Computing arrangements based on specific mathematical models
- G06N7/01—Probabilistic graphical models, e.g. probabilistic networks
Landscapes
- Engineering & Computer Science (AREA)
- Physics & Mathematics (AREA)
- Theoretical Computer Science (AREA)
- Data Mining & Analysis (AREA)
- General Physics & Mathematics (AREA)
- Artificial Intelligence (AREA)
- Evolutionary Computation (AREA)
- General Engineering & Computer Science (AREA)
- Computing Systems (AREA)
- Life Sciences & Earth Sciences (AREA)
- Software Systems (AREA)
- Mathematical Physics (AREA)
- Computational Linguistics (AREA)
- Molecular Biology (AREA)
- General Health & Medical Sciences (AREA)
- Biophysics (AREA)
- Biomedical Technology (AREA)
- Health & Medical Sciences (AREA)
- Probability & Statistics with Applications (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Bioinformatics & Computational Biology (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Evolutionary Biology (AREA)
- Algebra (AREA)
- Computational Mathematics (AREA)
- Mathematical Analysis (AREA)
- Mathematical Optimization (AREA)
- Pure & Applied Mathematics (AREA)
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
Abstract
本发明涉及一种基于扩散模型的多任务策略学习方法,属于强化学习,用于解决多任务策略在学习中具有梯度冲突且泛化能力弱的问题。本案的目的在于提出一种基于扩散模型的多任务策略学习方法,不仅能够建模复杂的多任务轨迹分布,同时具有轻量化的网络结构;此外,使用包括任务的单条轨迹信息的任务提示语,在训练中获得任务提示语和任务之间的关系,不需要额外的预训练或语言模型进行特征提取。在新的任务中只需要提供该任务的单条示教数据,就能够将现有策略泛化到其他任务中。
Description
技术领域
本案涉及强化学习,尤其涉及基于扩散模型的多任务策略学习方法。
背景技术
从离线数据集中学习是避免智能体进行高昂代价的在线交互来获得策略的方法。通过从多个任务的混合数据集进行学习,希望使智能体能够直接获得多任务的策略,同时该策略能够在未见到过的任务中具有一定的泛化能力。
现有方法主要通过设计网络结构和使用任务标记来解决多任务学习的梯度冲突问题和任务泛化问题,主要分为三类:
(1)基于多头网络结构的多任务策略学习。多头网络在解决多任务梯度冲突中具有很强的能力,然而多头网络具有较大的参数量和较高的计算代价。特别的,当任务数量较大时,多头网络的分支数量也会随之增长,从而无法使用。
(2)基于共享网络和任务标记的多任务策略学习。共享网络结构虽然在一定程度上解决了多头网络存在的问题,然而引入的任务编码或者任务描述也需要独立的编码器来进行特征提取,或使用独立的网络结构来进行编码的预训练,从而带来较大的计算代价。
(3)使用大语言模型的多任务策略学习。大语言模型驱动的任务描述虽然能够调用预训练的语言模型而直接获得任务相关的编码,但却需要额外的大模型API或大模型基础设施(多卡GPU)来导入现有的预训练大模型进行推理,具有一定的技术难度。同时,该方法要求每个任务需要带有与之相应的任务描述,这在一定程度上需要人类对任务进行详细描述。
此外,所有现有方法在任务层面的泛化能力都是有限的。具体的,多头网络结构在任务层面不具有泛化性,在遇到新的任务时需要在原有多头结构的基础上增加新的分支,并对该分支进行重新训练。共享网络结构在使用任务编码作为输入时,在任务层面的泛化能力取决于任务编码的泛化性。特别的,任务编号本身不包含任何的任务描述信息,故而在任务层面不具备泛化性。任务的描述信息具有一定的泛化性,但取决于描述信息是否非常详尽,同时是否抽取到了任务的关键要素,以及使用的大语言模型是否能够提取到相应的特征。
发明内容
为了解决现有技术中存在的上述问题,本案的目的在于提出一种基于扩散模型的多任务策略学习方法,不仅能够建模复杂的轨迹分布,同时具有轻量化的网络结构;此外,使用任务的单条轨迹信息作为任务提示语,在训练中建模提示语和任务之间的关系,不需要额外的预训练或语言模型进行特征提取。在新的任务中只需要提供该任务的单条示教数据,就能够将现有策略泛化到其他任务中。为了实现上述技术目的,本案的技术方案如下。
第一方面,本案提出一种基于扩散模型的多任务策略学习方法,所述方法包括下述步骤:
对离线多任务数据集中的动作序列或者轨迹序列,利用训练好的多任务扩散模型,获取智能体能够与环境交互的最优动作序列,或生成新的轨迹序列以用于强化学习;
所述多任务扩散模型为对多任务轨迹使用扩散模型建模,获得任务提示语和任务之间的关系;
任务提示语包括给定任务对应的轨迹提示,轨迹为关于状态、动作和奖励的转移序列。
在上述技术方案的一种实施方式中,对于动作序列,任务提示语还包括奖励回报以及历史状态。
在上述技术方案的一种实施方式中,每个任务设置1-3条轨迹为专家轨迹,将专家轨迹切分成多个片段,多任务扩散模型的逆向扩散过程中,通过采样片段重组成轨迹提示。
在上述技术方案的一种实施方式中,当为动作序列时,多任务扩散模型的逆向扩散过程中的预测噪声为:
式中:∈θ为预测噪声函数,为给定任务T第k步的动作序列,/>表示缺省值为空,R(T)为给定任务T的回报,α为超参数,y′(T)由给定任务T的轨迹提示和历史状态观测序列构成,y′(T)和R(T)共同构成给定任务T的任务提示语;
当为轨迹序列时,多任务扩散模型的逆向扩散过程中的预测噪声为:
式中:∈θ为预测噪声函数,为给定任务T第k步的轨迹序列,ys(T)为在给定任务T的任务提示语,由轨迹提示构成。
在上述技术方案的一种实施方式中,逆向扩散过程采用GPT网络预测噪声。
在上述技术方案的一种实施方式中,GPT的输入数据在输入之前,使用归一化层处理成统一的令牌;输出之后使用一个由全连接层组成的预测头来预测扩散时间步长k处相应的噪声。
在上述技术方案的一种实施方式中,给定任务T的动作序列则第k步去噪得到的动作序列/>满足:
其中:μθ为均值函数,为给定任务T的第k步的动作序列,y′(T)和R(T)构成给定任务T的第k步的任务提示语,y′(T)由给定任务T的轨迹提示和历史状态观测序列构成,R(T)为给定任务T的第k步回报,∑k为第k步方差,β用于减小方差以生成最优动作序列,β∈[0,1);
给定任务T的轨迹序列则第k步去噪得到的轨迹序列/>满足:/>
其中:μθ为均值函数,∑k为第k步方差,为给定任务T的第k步轨迹序列,ys(T)为给定任务T的任务提示语。
在上述技术方案的一种实施方式中,为训练稳定,将不同大小的原始输入转换成相同大小维度,在归一化层进行加强叠加处理统一的令牌或/>
其中:表示动作序列对应的令牌,/>表示轨迹序列对应的令牌;LN表示归一化层加强叠加处理函数;hTi为扩散时间步转换后的量,hP为轨迹提示转换后的量,为动作序列对应的回报转换后的量,/>为历史状态序列转换后的量,/>为动作序列转换后的量,/>为轨迹序列转换后的量,Epos为归一化层要处理的各个量的位置量。
在上述技术方案的一种实施方式中,多任务扩散模型的训练包括下述步骤:
当为动作序列时:
构造训练任务集Ttrain,对训练任务集中的每个任务Ti,从给定的多任务子集中采样M组长度为H的动作序列集及对应的长度为L历史状态观测序列/>计算每个任务Ti下的标准返回R(Ti),从任务子集对应的多任务轨迹提示Zi中采样M个长度为J的轨迹提示/>
获取一组样本|Ttrain|为构造的训练任务集的大小;
随机采样扩散时间步k~U(1,K),获得噪声动作序列
以概率β~Bern(p)将R(Ti)置空;
计算损失函数以更新多任务扩散模型;
当为轨迹序列时:
构造训练任务集Ttrain,对训练任务集中的每个任务Ti,从给定的多任务子集中采样M个长度为H的轨迹序列
从多任务轨迹提示P*采样M个长度为J的轨迹提示Zi;
获取一组样本|Ttrain|为构造的训练任务集的大小;
随机采样扩散时间步k~U(1,K),获得噪声轨迹序列
计算损失函数以更新多任务扩散模型。
第二方面,本案提出一种机器人,机器人采用上述任一种方法获得与环境交互的动作序列或者生成离线强化学习策略。
第三方面,本案提出一种可读存储介质,存储有能够被处理器加载并执行上述任一种方法的计算机程序。
本案的技术效果如下:
(1)利用扩散模型对多模态数据的建模能力来直接的对多任务离线数据集进行建模,轻量化的网络结构使参数量远小于多头网络模型和共享网络模型,极大的降低了模型的计算需求。
(2)不需要任务描述信息作为输入,而是通过获得几条专家轨迹数据作为任务相关属性的代表,在结构上更加容易实现。
(3)使用任务提示语作为逆向扩散过程的引导条件,通过训练条件扩散模型中任务提示语和任务策略之间的关系,使策略能够在训练任务以外的任务中进行泛化。
(3)通过控制扩散模型的生成过程使模型生成多样化的数据,这些数据将用于扩充原有的数据集合,扩充后的数据集能够帮助离线强化学习算法获得更好的策略。
附图说明
为了更清楚地说明本申请实施例中的技术方案,下面将对实施例描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本申请的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动性的前提下,还可以根据这些附图获得其他的附图。
图1、一种实施方式中基于扩散模型的多任务学习模型架构图;
图2、一种基于扩散模型的多任务策略学习方法的实施方式流程图。
具体实施方式
相关概率分布:
1、称X服从区间[a,b]上的均匀分布,记作X~U[a,b],b>a。
2、随机变量X~N(0,1)为标准正态分布,随机变量X~N(μ,σ2)为正态分布。
现有的离线强化学习算法一般针对单任务开展研究,智能体从单任务数据集中学习策略。在面对多个任务时,智能体需要独立的学习多个策略,同时学到的策略无法扩展到新的任务中。现有方法解决该问题的思路是通过训练多头结构的网络来得到多任务策略,多头结构通过共享主网络来降低参数量的需求。此类方法可以解决多任务策略在学习中产生的梯度冲突问题,使用多头机制在一定程度上保持了多任务策略网络的独立性。然而,该结构无法解决策略在任务层面的泛化问题。其他方法尝试使用任务编号作为输入,使用多任务共享网络来建模编号和策略之间的关系,然而任务编号往往无法涵盖任务相关信息,同时在任务差别较大时无法进行泛化。
为了解决以上问题,本案提出了一种新的扩散模型结构用于多任务策略学习,利用扩散模型对多模态数据的建模能力来直接的对多任务离线数据集进行建模,设计轻量化的网络结构使参数量远小于多头网络模型和共享网络模型。同时,为了使策略能够在训练任务以外的任务中进行泛化,本案提出使用任务提示语作为网络模型的条件,在训练条件扩散模型中建模提示语和任务策略之间的关系。在测试中使用测试任务的提示语作为输入,能够使策略泛化到测试任务中。本案通过实验验证了提出的条件扩散模型在多任务离线强化学习中的泛化能力,相比于现有方法获得了很大提升。
下面将结合本申请实施例中的附图1-2,对本案实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本申请一部分实施例,而不是全部的实施例。基于本申请中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本申请保护的范围。
一、相关知识
(1)多任务MDP(Markov Decision Process,马尔科夫决策过程)
一个马尔科夫决策过程可以使用元组(S,A,P,R,μ,γ)定义,其中:S为状态矢量空间,A动作空间,P:S×A→S是状态转移函数,R:S×A×S→R是奖励函数,γ∈(0,1]是损失因子,μ是初始状态分布。在每一个时间步t,智能体根据下述策略π:S→ΔA(策略π是从状态S到动作A概率分布的映射)选择一个动作at,然后智能体获得下一个状态st+1并获得一个标量奖励rt。在一个任务的强化学习中,其目标通过最大化相应任务的累积奖励期望来学习策略π*:
在多任务设置中,不同的任务有不同的奖励函数、状态向量空间和状态转移函数。本案中的所有任务都用相同的智能体共用相同的动作空间。对于一类特定任务,多任务强化学习的目标是找到一个最优策略,它可以得到所有任务的最大期望回报。特定任务可以是已有任务,也可以是全新的任务。
(2)多任务离线决策
在离线决策中,给定一个未知行为策略收集状态转换的静态数据集来学习该策略,其中sj为第j个状态,aj第j个动作,s′j为第j个状态转换的新状态,rj为执行第j个动作的奖励。
在多任务离线强化学习设置中,任务集D被划分为多个任务子集,离线设置中强化学习的关键问题是由时序差异(Temporal-difference,TD)引起的分布偏移问题。本案通过使用扩散模型来解决这个问题,将多任务策略学习看作一个条件生成过程而不拟合值函数,利用扩散模型对多任务数据强大的分布建模能力,以避免分布偏移。并且,通过扩散模型基于潜在的马尔科夫过程生成新的状态转移数据(s,a,s′,r),以扩增原始数据集,从而实现显著改进策略。
(3)扩散模型
本案采用扩散模型从多任务数据中学习多任务轨迹分布。使用xk作为使用扩散模型的第k步去噪输出,使用y表示引导条件。
前向扩散过程是指的对数据逐渐增加高斯噪音直至数据变成随机噪音的过程,本案在K步中使用预定义的方差计划(variance schedule)βk向数据x0~q(x)中逐步添加噪声,这个过程可以表示为:
本案中:
式中:βmin=0.1,βmax=10为常量。
一个可训练的逆向扩散过程是学习条件分布q(x|y)。构造逆向扩散过程概率函数:
pθ(xk-1|xk,y):=N(xk-1|μθ(xk,y,k),∑k) (2)
式中:μθ为均值函数,∑k为第k步方差。
通过下述损失进行优化:
其中,∈θ是参数化的深度神经网络,通过训练来预测添加到数据集样本x0中以生成xk的噪声∈,∈~N(0,I)。通过设置αk:=1-βk,可逐步计算去噪结果:
在训练阶段,同时学习一个有条件噪声预测和一个无条件噪声预测,因此设计扰动噪声来生成样本,α为引导标量,是一个超参数,/>表示条件为空。
二、实施方式
为了得到从多个马尔科夫过程中采样轨迹的多模态分布,利用扩散模型将多任务轨迹模建模为一个条件生成问题,以获得任务提示语和任务之间的关系。多任务扩散模型如下:
式中:x0(T)为给定任务T的预期生成序列,y(T)为给定任务T的任务提示语。pθ为条件逆向去噪过程。通过最大化一个变分下界来获得(5)的近似最大值。
给定任务T的序列x(T)={xk(T)|k=0,1,2,…,K),序列x(T)既可以表示动作序列,也可以表示轨迹序列。
(1)动作序列
对于动作序列,多任务扩散模型记作MTDIFF-p,其目标在于设计出能够最大化回报的最优动作序列。
使用表示给定任务T第k步噪声扩散后的动作序列,条件相应表示为yp(T):=[y′(T),R(T)],y′(T):=(Z,st-L+1,…,st),t表示给定任务T中的时间步,H表示动作序列x的长度,yp(T)为给定任务T的任务提示语,R(T)为给定任务T的回报,Z表示轨迹提示语,st-L+1,…,st为历史状态观测序列,L为历史状态观测序列长度。将y′(T)为一个条件在训练和应用中都用的普通条件,同时考虑R(T)作为无分类器的引导(guidance),以获得给定任务的最优动作序列。其中:H、L设定长度与实验环境相关,在一个实施例中,L=i0,H取值为200~300。
最优动作序列是从高斯噪声/>开始采样,从/>细化到/>在每个中间时间步长采用下述扰动噪声/>:
式中:∈θ为预测噪声函数,为给定任务T第k步的动作序列,/>表示缺省值为空,y′(T)由给定任务T的轨迹提示和历史状态观测序列构成,R(T)为给定任务T的回报(或归一化回报),y′(T)和R(T)共同构成给定任务T的任务提示语。α是超参,用于寻求增加和提取高回报的轨迹的最佳部分。在训练过程中,采用DDPM和无分类器引导来训练逆向扩散过程pθ,该逆向扩散过程通过噪声模型∈θ进行参数化,采用下述损失函数进行训练:
将第k步回报R(T)以概率β空置,β服从伯努利分布。在推理过程中,通过采用低温采样技术(low-temperature sampling technique)来生成高相似的序列。去噪过程中的动作序列β∈[0,1),用于减小方差来生成具有更高最优性的动作序列。Rmax(T)为给定任务T的最大回报。
(2)轨迹序列
对于轨迹序列,多任务扩散模型记作MTDIFF-s。本案将状态、动作和奖励的转移序列作为轨迹。将轨迹序列表示如下:
相应地,引导条件ys(T):=[Z],Z表示轨迹提示,t为轨迹起始时间步,H为轨迹长度。轨迹提示为关于状态、动作的转移序列。
MTDIFF-s的目标在于综合不同的轨迹以进行数据增强,因此不需要像R(T)一样的引导。采用下述无引导的损失:
采样值
(3)任务提示语
在技术方案的实施过程中,通过采用不同任务提示语为条件,不仅能将特定任务生成的动作序列和其它任务动作序列分离,而且也能学习多模态轨迹的分布。
在多任务学习中,任务提示语包括轨迹提示,轨迹提示由少量的专家轨迹组成,对于每个任务,可指定该任务中的1-3条轨迹作为专家轨迹,可以是历史轨迹也可以是新设置任务的轨迹。从专家轨迹信息获取轨迹提示,以作为示范来以最直接的方式展示了任务的特征以及如何完成该任务,从而解决在机械臂等任务中往往难以通过语言的方式来准确描述任务信息的问题。此外,从专家轨迹信息获取任务提示语非常类似于人类学习新任务的过程。人类在新任务的学习中往往只需要专家在新任务上进行几次示范,就能够将旧的多任务策略快速泛化到该新任务中。因而在设计中快速利用新任务的示教数据,能够加速在其他任务中进行泛化。
在一种实施方式中,将这些任务提示语构造一个提示语信息集合。
在一种实施方式中,特定任务轨迹提示语Z包含状态和动作,如下所示:
其中,J是用于识别任务的环境步数。每个带“*”的元素是与轨迹相关的轨迹提示,s为状态,a为动作,i为轨迹提示语用到的时间步。
作为进一步改进,在一种具体实现中,将轨迹提示切分成多个片段,通过采样片段重组,以增加轨迹提示的多样性。
将任务提示语作为条件,多任务扩散模型能够通过隐式捕获转移模型和存储在提示中的奖励函数来指定任务,能够更好地通过没有额外的参数调整来泛化到不可见的任务。
综上,多任务扩散模型为对多任务轨迹使用扩散模型建模,前向扩散过程为在K步中使用预定义的方差计划向最优动作序列或生成新的轨迹序列添加噪声的过程,逆向扩散过程中为学习最优动作序列或生成新的轨迹序列在任务提示语下的条件分布。
(3)逆向扩散过程
本案提出一个新的扩散模型如图1所示,在其中采用GPT-2作为变压器预测噪声,一方面由于GPT模型相比于现有的扩散基础模型U-net等具有更少的参数量,另一方面GPT模型通过多个全局注意力模块能够更好的建模条件信息和任务轨迹的相关关系,同时提取到轨迹层面的深层特征。此外,GPT模型拥有统一的输入结构来编码条件信息和轨迹信息,将其转换为输入向量元组,使网络能够处理数量足够多的条件信息。并且,GPT能够实现性能和计算效率之间的良好的平衡,从而使整体扩散模型具有良好的性能。
参见图1,将不同原始输入x嵌入独立MLP(Multi-Layer Perceptron,全连接神经网络)中,将MLP的映射函数记作f,输出记作h,从而可以获得相同大小的输出,可表示如下:
(3.1)对于轨迹提示xprompt和扩散时间步xtimestep,有:
hP=fP(xprompt)
hTi=fTi(xtimestep)
(3.2)对于轨迹序列xtransitions,有:
(3.3)对于动作序列xactions,对应的历史状态序列xhistory,以及对应的动作回报xreturn,有:
其中,hP、hTi分别是MTDIFF-s和MTDIFF-p的共同输入。
在一种实施方式中的进一步改进,通过加强叠加输入来使训练稳定。具体地,通过与扩散时间步长相乘并与返回值相加,然后与各个量的位置量Epos相加,再输入归一化层LN处理,得到GPT的输入令牌或/>即:
基于GPT的输出,使用一个由全连接层组成的预测头(Prediction Head)来预测扩散时间步长k处相应的噪声,该噪声在推理过程中用于逆向去噪过程pθ。
上述实施方式是在GPT模型的基础上构建基于任务提示语的条件生成模型。具体的,GPT模型将使用离线数据集的轨迹信息和任务提示语共同作为输入,得到预测的噪声,经去噪能够得到预期的动作序列。
(4)方法汇总
(4.1)对于动作序列的MTDIFF-p的训练和应用方法
(4.1.1)训练过程
初始化:用于训练的任务集合Ttrain、迭代次数N、多任务数据集D、批大小M,多任务轨迹提示P*,扩散时间总步数K
迭代过程伪代码:
(4.1.2)应用过程
给定任务T,设置要想要的返回Rmax(T)
给定多任务轨迹提示Z,给定初始状态历史h0=(st-L+1,…,st),L是观察的状态历史长度,t表示轨迹T中的时间步;
设置低温采样技术标量β,无分类器引导标量α
迭代过程伪代码:
/>
(4.2)对于轨迹序列的MTDIFF-s的训练和应用方法
(4.2.1)训练过程
初始化:用于训练的轨迹集合Ttrain、迭代次数N、多任务数据集D、批大小M,多任务轨迹提示P*,扩散时间总步数K
迭代过程伪代码:
(4.2.2)应用过程
给定任务T,给定多任务轨迹提示P*,生成M个轨迹序列,生成轨迹集合
迭代过程伪代码:
参见图2,在模型设计的基础上,使用机器人多任务离线数据集来对学习效果进行评价。
在动作规划的设定下,使用任务提示语来产生动作序列,使用较高的目标回报作为条件来引导最优动作序列的生成,生成的最优动作序列可以用于和环境的直接交互学习。
在数据生成的设定下,使用任务提示语来生成多样化的轨迹序列,通过控制扩散模型的生成过程使模型生成多样化的数据,这些数据将用于扩充原有的数据集合,扩充后的数据集能够帮助离线强化学习算法获得更好的策略。使用Meta World等多任务机器人仿真环境对本案提出的方法进行实验,同时在未见过的任务中进行泛化。
通过以上的实施方式的描述,所属领域的技术人员可以清楚地了解到本公开可借助软件加必需的通用硬件的方式来实现,当然也可以通过专用硬件包括专用集成电路、专用CPU、专用存储器、专用元器件等来实现。一般情况下,凡由计算机程序完成的功能都可以很容易地用相应的硬件来实现,而且,用来实现同一功能的具体硬件结构也可以是多种多样的,例如模拟电路、数字电路或专用电路等。但是,对本公开而言更多情况下,软件程序实现是更佳的实施方式。
需要说明的是在本说明书中所谈到的“一个实施例”、“另一个实施例”、“实施例”等,指的是结合该实施例描述的具体特征、结构或者特点包括在本申请概括性描述的至少一个实施例中。在说明书中多个地方出现同种表述不是一定指的是同一个实施例。进一步来说,结合任一实施例描述一个具体特征、结构或者特点时,所要主张的是结合其他实施例来实现这种特征、结构或者特点也落在本案的范围内。
尽管以上结合附图对本案的实施方案进行了描述,但本案并不局限于上述的具体实施方案和应用领域,上述的具体实施方案仅仅是示意性的、指导性的,而不是限制性的。本领域的普通技术人员在本说明书的启示下和在不脱离本案权利要求所保护的范围的情况下,还可以做出很多种的形式,这些均属于本案保护之列。
Claims (10)
1.一种基于扩散模型的多任务策略学习方法,其特征在于,所述方法包括下述步骤:
对离线多任务数据集中的动作序列或者轨迹序列,利用训练好的多任务扩散模型,获取智能体能够与环境交互的最优动作序列,或生成新的轨迹序列以用于强化学习;
所述多任务扩散模型为对多任务轨迹使用扩散模型建模,获得任务提示语和任务之间的关系;
任务提示语包括给定任务对应的轨迹提示,轨迹提示为关于状态、动作的转移序列。
2.根据权利要求1所述的方法,其特征在于,对于动作序列,任务提示语还包括奖励回报以及历史状态观测序列。
3.根据权利要求1所述的方法,其特征在于,每个任务设置1-3条轨迹为专家轨迹,从专家轨迹获取轨迹提示,将轨迹提示切分成多个片段,多任务扩散模型的逆向扩散过程中,通过采样片段重组成轨迹提示。
4.根据权利要求1所述的方法,其特征在于:
当为动作序列时,多任务扩散模型的逆向扩散过程中的预测噪声为:
式中:∈θ为预测噪声函数,为给定任务T第k步的动作序列,/>表示缺省值为空,R(T)为给定任务T的回报,α为超参数,y′(T)由给定任务T的轨迹提示和历史状态观测序列构成,y′(T)和R(T)共同构成给定任务T的任务提示语;
当为轨迹序列时,多任务扩散模型的逆向扩散过程中的预测噪声为:
式中:∈θ为预测噪声函数,为给定任务T第k步的轨迹序列,ys(T)为在给定任务T的任务提示语。
5.根据权利要求4所述的方法,其特征在于,逆向扩散过程采用GPT网络预测噪声。
6.根据权利要求5所述的方法,其特征在于,GPT的输入数据在输入之前,将不同大小的原始输入转换成相同大小维度,再使用归一化层处理成统一的令牌;输出之后使用一个由全连接层组成的预测头来预测扩散时间步长k处相应的噪声。
7.根据权利要求1所述的方法,其特征在于:
给定任务T的动作序列则第k步去噪得到的动作序列/>满足:
其中:μθ为均值函数,为给定任务T的第k步的动作序列,y′(T)和R(T)构成给定任务T的第k步的任务提示语,y′(T)由给定任务T的轨迹提示和历史状态观测序列构成,R(T)为给定任务T的第k步回报,∑k为第k步方差,β为低温采样标量;
给定任务T的轨迹序列则第k步去噪得到的轨迹序列/>满足:
其中:μθ为均值函数,∑k为第k步方差,为给定任务T的第k步轨迹序列,ys(T)为给定任务T的任务提示语。
8.根据权利要求1所述的方法,其特征在于:多任务扩散模型的训练包括下述步骤:
当为动作序列时:
构造训练任务集Ttrain,对训练任务集中的每个任务Ti,从给定的多任务子集中采样M组长度为H的动作序列集及对应的长度为L历史状态观测集/>计算每个任务Ti下的标准返回R(Ti),从任务子集对应的多任务轨迹提示Zi中采样M个长度为J的轨迹提示/>
获取一组样本|Ttrain|为构造的训练任务集的大小;
随机采样扩散时间步k~U(1,K),获得噪声动作序列
以概率β~Bern(p)将R(Ti)置空;
计算损失函数以更新多任务扩散模型;
当为轨迹序列时:
构造训练任务集Ttrain,对训练任务集中的每个任务Ti,从给定的多任务子集中采样M个长度为H的轨迹序列
从多任务轨迹提示P*采样M个长度为J的轨迹提示Zi;
获取一组样本|Ttrain|为构造的训练任务集的大小;
随机采样扩散时间步k~U(1,K),获得噪声轨迹序列
计算损失函数以更新多任务扩散模型。
9.一种机器人,其特征在于:机器人采用权利要求1至8中任一种方法获得与环境交互的动作序列或者生成离线强化学习策略。
10.一种可读存储介质,其特征在于:存储有能够被处理器加载并执行如权利要求1至8中任一种方法的计算机程序。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310680335.1A CN117474075A (zh) | 2023-06-08 | 2023-06-08 | 一种基于扩散模型的多任务策略学习方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310680335.1A CN117474075A (zh) | 2023-06-08 | 2023-06-08 | 一种基于扩散模型的多任务策略学习方法 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN117474075A true CN117474075A (zh) | 2024-01-30 |
Family
ID=89624460
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202310680335.1A Pending CN117474075A (zh) | 2023-06-08 | 2023-06-08 | 一种基于扩散模型的多任务策略学习方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN117474075A (zh) |
-
2023
- 2023-06-08 CN CN202310680335.1A patent/CN117474075A/zh active Pending
Similar Documents
Publication | Publication Date | Title |
---|---|---|
Wang et al. | Deep reinforcement learning: a survey | |
Rennie et al. | Self-critical sequence training for image captioning | |
Gabora et al. | Two cognitive transitions underlying the capacity for cultural evolution | |
Harmer et al. | Imitation learning with concurrent actions in 3d games | |
US11086938B2 (en) | Interpreting human-robot instructions | |
CN111602144A (zh) | 生成指令序列以控制执行任务的代理的生成神经网络系统 | |
CN112434171A (zh) | 一种基于强化学习的知识图谱推理补全方法及系统 | |
CN111461325B (zh) | 一种用于稀疏奖励环境问题的多目标分层强化学习算法 | |
CN110309170A (zh) | 一种任务型多轮对话中的复杂意图识别方法 | |
Thórisson | Seed-programmed autonomous general learning | |
CN116205298A (zh) | 一种基于深度强化学习的对手行为策略建模方法及系统 | |
Li et al. | SADRL: Merging human experience with machine intelligence via supervised assisted deep reinforcement learning | |
Persiani et al. | A working memory model improves cognitive control in agents and robots | |
CN117454965A (zh) | 基于随机Transformer模型的有模型深度强化学习方法 | |
CN113379027A (zh) | 一种生成对抗交互模仿学习方法、系统、存储介质及应用 | |
CN117474075A (zh) | 一种基于扩散模型的多任务策略学习方法 | |
CN111783983A (zh) | 用于实现导航的可迁移的元学习的无监督dqn强化学习 | |
Zintgraf | Fast adaptation via meta reinforcement learning | |
CN116306947A (zh) | 一种基于蒙特卡洛树探索的多智能体决策方法 | |
Liu | Learning task-oriented dialog with neural network methods | |
Wang et al. | A review of deep reinforcement learning methods and military application research | |
Sun et al. | Research on Sports Dance Video Recommendation Method Based on Style | |
Liu et al. | Soft-Actor-Attention-Critic Based on Unknown Agent Action Prediction for Multi-Agent Collaborative Confrontation | |
Zhai et al. | Building Open-Ended Embodied Agent via Language-Policy Bidirectional Adaptation | |
Villarrubia-Martin et al. | A hybrid online off-policy reinforcement learning agent framework supported by transformers |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |