CN112819159A - 一种深度强化学习训练方法及计算机可读存储介质 - Google Patents

一种深度强化学习训练方法及计算机可读存储介质 Download PDF

Info

Publication number
CN112819159A
CN112819159A CN202110208061.7A CN202110208061A CN112819159A CN 112819159 A CN112819159 A CN 112819159A CN 202110208061 A CN202110208061 A CN 202110208061A CN 112819159 A CN112819159 A CN 112819159A
Authority
CN
China
Prior art keywords
situation
reinforcement learning
deep reinforcement
samples
neural network
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202110208061.7A
Other languages
English (en)
Inventor
张甜甜
袁博
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Shenzhen International Graduate School of Tsinghua University
Original Assignee
Shenzhen International Graduate School of Tsinghua University
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Shenzhen International Graduate School of Tsinghua University filed Critical Shenzhen International Graduate School of Tsinghua University
Priority to CN202110208061.7A priority Critical patent/CN112819159A/zh
Publication of CN112819159A publication Critical patent/CN112819159A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/23Clustering techniques
    • G06F18/232Non-hierarchical techniques
    • G06F18/2321Non-hierarchical techniques using statistics or function optimisation, e.g. modelling of probability density functions
    • G06F18/23213Non-hierarchical techniques using statistics or function optimisation, e.g. modelling of probability density functions with fixed number of clusters, e.g. K-means clustering

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Data Mining & Analysis (AREA)
  • Physics & Mathematics (AREA)
  • Evolutionary Computation (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Biophysics (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Probability & Statistics with Applications (AREA)
  • Health & Medical Sciences (AREA)
  • Biomedical Technology (AREA)
  • Evolutionary Biology (AREA)
  • Computational Linguistics (AREA)
  • General Health & Medical Sciences (AREA)
  • Molecular Biology (AREA)
  • Computing Systems (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Feedback Control In General (AREA)

Abstract

本发明提供一种深度强化学习训练方法及计算机可读存储介质,方法包括:指定情境数量,初始化深度强化学习多头神经网络模型的权重参数;智能体随机决策,收集样本存于经验回放缓冲区;依据情境数量,采用在线聚类算法实现自适应情境划分,得到截止当前时刻的情境划分和各情境中心;从经验回放缓冲区随机采样样本,并将各样本分配至距离最近的情境中;依据样本对应情境训练共享特征提取器及相应输出头的权重参数,结合知识蒸馏损失对其他输出头权重参数进行同步更新,估计值函数;下一时间步,智能体依据值函数继续决策,收集样本存于经验回放缓冲区,重复上述步骤,直至完成预先指定的训练次数或达到收敛。提升了模型训练的稳定性和可塑性。

Description

一种深度强化学习训练方法及计算机可读存储介质
技术领域
本发明涉及人工智能技术领域,尤其涉及一种深度强化学习训练方法及计算机可读存储介质。
背景技术
在强化学习领域,深度神经网络强大的学习能力使得智能体直接从高维连续环境中学习有效的控制策略成为可能。理论上,为了实现稳定的训练性能,神经网络一般要求训练数据满足独立同分布(i.i.d.)的特点,这在一般的强化学习范式中几乎是不可能成立的。强化学习边探索边学习的训练模式使得训练数据具有高度时间相关和非平稳的固有属性,由于神经网络在训练过程前后采用的训练数据分布不同,后期训练得到的权重很可能干扰甚至完全覆盖前期已经学习到的好的策略,从而导致模型性能受到干扰甚至是突然崩溃,使得模型训练过程非常不稳定,甚至很难收敛到优策略。对应于实际具体应用,如人工智能围棋系统等各类游戏对战、机器人调优工业设备参数等工业自动化应用、自动驾驶领域车辆运动规划等凡是利用强化学习来自动化寻求最佳序贯决策的真实应用场景,则表现为强化学习智能体在特定环境中学习完成特定任务的策略过程非常不稳定,随着学习的进行,智能体可能会突然忘记已经学习到的稍好的策略以致于面对相应的环境场景做出错误的决策,从而必须重新从头开始再次学习,后期再次遗忘并再次重新学习,如此反复,使得智能体学习优策略的效率大大降低,甚至最终无法学习到完成相应任务的优策略。
以上问题被称为灾难性干扰和遗忘(Catastrophic Interference andForgetting)。现有基于值的深度强化学习训练框架一般采用经验回放和固定目标网络两种策略来缓解灾难性干扰和遗忘问题,其中,经验回放对计算内存有很高的要求,尤其是当处理复杂图像或视频输入问题时,为了能更好地产生近似独立同分布的训练数据,需要设置百万甚至更高级别的经验存储缓冲区大小,这对一般计算机而言是非常困难的;此外,固定目标网络也只能使输出目标相对平稳,单独使用时对灾难性干扰和遗忘问题改善效果非常有限。
现有技术中缺乏解决强化学习领域神经网络模型在训练过程中所遇到的灾难性干扰和遗忘问题的方案。
以上背景技术内容的公开仅用于辅助理解本发明的构思及技术方案,其并不必然属于本专利申请的现有技术,在没有明确的证据表明上述内容在本专利申请的申请日已经公开的情况下,上述背景技术不应当用于评价本申请的新颖性和创造性。
发明内容
本发明为解决现有深度强化学习神经网络模型在训练过程中普遍遭遇的灾难性干扰和遗忘问题,提供一种深度强化学习训练方法及计算机可读存储介质。
为了解决上述问题,本发明采用的技术方案如下所述:
一种深度强化学习训练方法,包括如下步骤:S1:指定情境数量,初始化深度强化学习多头神经网络模型的权重参数;智能体随机决策,收集样本存于经验回放缓冲区;S2:依据所述情境数量,采用在线聚类算法实现自适应情境划分,对当前时间步状态进行在线聚类,自适应进行情境推断,得到截止当前时刻的情境划分和各情境中心;S3:从所述经验回放缓冲区随机采样小批量样本,并依据各所述样本对应的状态与各所述情境中心的欧氏距离依次将各所述样本分配至距离最近的所述情境中;S4:依据所述样本对应情境训练共享特征提取器及相应输出头的权重参数,并结合知识蒸馏损失对其他输出头权重参数进行同步更新,估计值函数;S5:下一时间步,智能体依据所述值函数继续决策,收集样本存于所述经验回放缓冲区,重复进行所述自适应情境划分和所述深度强化学习多头神经网络模型的权重参数更新迭代,直至所述深度强化学习多头神经网络模型完成预先指定的训练次数或达到收敛。
优选地,指定所述情境数量k,其中,k>1;选用一个共享特征提取器和一组线性输出头组成的神经网络结构参数化值函数,每个线性输出头对应于一个特定情境;初始化所述深度强化学习多头神经网络模型的权重参数
Figure BDA0002950046260000021
其中,
Figure BDA0002950046260000022
为共享特征提取器参数,
Figure BDA0002950046260000023
为对应当前训练样本所属情境的输出头参数,
Figure BDA0002950046260000024
为其他输出头参数;将单个强化学习环境划分为k个情境,并针对每个所述情境中包含的状态采用所述深度强化学习多头神经网络分别对每个所述情境进行值函数估计。
优选地,收集样本存于所述经验回放缓冲区,在t时刻收集到的样本表示为{st,at,rt,st+1}其中,st为t时刻智能体所处的环境状态,at为t时刻智能体所采取的动作,rt为采取动作at后环境反馈的奖励,st+1为采取动作后智能体达到的t+1时刻环境状态,经验回放缓冲区表示为
Figure BDA0002950046260000031
大小为N。
优选地,对所述深度强化学习多头神经网络模型训练过程中经历的所有状态进行划分得到有限个簇,每个所述簇称为一个情境ω,Ω={ω1,ω2,...,ωk}为划分得到的所述情境的有限集合。
优选地,利用Sequential K-Means算法对当前时刻t智能体所处的环境状态st进行在线聚类实现自适应情境划分,得到情境划分中心为
Figure BDA0002950046260000032
其中,ci为第i个情境ωi的聚类中心。
优选地,对每个时间步状态进行在线聚类,自适应进行情境推断,得到截止当前时刻的情境划分和各情境中心,具体操作包括:当前时刻t下,各所述情境中心为
Figure BDA0002950046260000033
每个所述情境内包含的状态个数为
Figure BDA0002950046260000034
并分别计算当前时刻智能体所处的环境状态st与各个所述情境中心ci之间的距离d(st,ci),将当前时刻智能体所处的环境状态st分配至距离最近的所述情境ωj中;然后相应地更新所述情境ωj包含的状态个数nj=nj+1及所述情境划分中心
Figure BDA0002950046260000035
其中,i∈{1,2,...,k},j=arg minid(st,ci)。
优选地,从所述经验回放缓冲区
Figure BDA0002950046260000036
中随机采样m个样本,并依据各所述样本对应的状态与各所述情境中心的欧氏距离得到相应的所属情境
Figure BDA0002950046260000037
所述样本总体表示为:
Figure BDA0002950046260000038
优选地,依据所述样本对应情境训练共享特征提取器及相应输出头的权重参数,并结合知识蒸馏损失对其他输出头权重参数进行同步更新,估计值函数包括:所述深度强化学习多头神经网络模型值函数估计的原始损失函数为
Figure BDA0002950046260000039
将知识蒸馏目标网络与原始值函数估计目标网络设为一致,在每次迭代中,蒸馏目标为
Figure BDA0002950046260000041
预测输出为
Figure BDA0002950046260000042
优选地,所述深度强化学习多头神经网络模型是DQN算法,所述深度强化学习多头神经网络模型值函数估计的原始损失函数
Figure BDA0002950046260000043
为:
Figure BDA0002950046260000044
当前时刻智能体所处的环境状态st对应的输出头的蒸馏损失为:
Figure BDA0002950046260000045
其他输出头对应的蒸馏损失为:
Figure BDA0002950046260000046
联合优化损失函数如下:
Figure BDA0002950046260000047
其中,λ∈[0,1]为控制深度强化学习多头神经网络模型可塑性和稳定性平衡系数。
本发明还提供一种计算机可读存储介质,所述计算机可读存储介质存储有计算机程序,所述计算机程序被处理器执行时实现如上任一所述方法的步骤。
本发明的有益效果为:提供一种深度强化学习训练方法及计算机可读存储介质,通过基于情境划分和知识蒸馏的深度强化学习训练框架,联合在线聚类算法实现自适应情境划分,并采用蒸馏损失和多头神经网络架构对各情境下深度强化学习模型值函数进行更有针对性地估计,大幅提升深度强化学习模型训练的稳定性和可塑性。
进一步地,本发明的方法可方便地整合到各类现有的基于值的深度强化学习模型中,大大减少模型训练过程中因训练数据分布漂移而产生的灾难性干扰和遗忘,并显著降低现有模型对计算内存的需求,大幅提升模型的稳定性和可塑性,提升其在各类强化学习任务上的性能,具有很强的通用应用价值。
附图说明
图1是本发明实施例中强化学习模型训练过程中数据分布漂移示例图。
图2是本发明实施例中一种深度强化学习训练方法的示意图。
图3是本发明实施例中一种基于情境划分和知识蒸馏的深度强化学习训练框架示意图。
图4(a)-图4(c)分别是本发明实施例中对现有深度强化学习代表性算法DQN及利用本发明的方法分别在经验回放缓冲区存储容量为1、100和50000下训练OpenAI Gym经典控制游戏CartPole-v1所获得的累计奖励曲线展示。
具体实施方式
为了使本发明实施例所要解决的技术问题、技术方案及有益效果更加清楚明白,以下结合附图及实施例,对本发明进行进一步详细说明。应当理解,此处所描述的具体实施例仅仅用以解释本发明,并不用于限定本发明。
需要说明的是,当元件被称为“固定于”或“设置于”另一个元件,它可以直接在另一个元件上或者间接在另一个元件上。当一个元件被称为是“连接于”另一个元件,它可以是直接连接到另一个元件或间接连接至另一个元件上。另外,连接既可以是用于固定作用也可以是用于电路连通作用。
需要理解的是,术语“长度”、“宽度”、“上”、“下”、“前”、“后”、“左”、“右”、“竖直”、“水平”、“顶”、“底”、“内”、“外”等指示的方位或位置关系为基于附图所示的方位或位置关系,仅是为了便于描述本发明实施例和简化描述,而不是指示或暗示所指的装置或元件必须具有特定的方位、以特定的方位构造和操作,因此不能理解为对本发明的限制。
此外,术语“第一”、“第二”仅用于描述目的,而不能理解为指示或暗示相对重要性或者隐含指明所指示的技术特征的数量。由此,限定有“第一”、“第二”的特征可以明示或者隐含地包括一个或者更多该特征。在本发明实施例的描述中,“多个”的含义是两个或两个以上,除非另有明确具体的限定。
如图1所示,为强化学习模型训练过程中数据分布漂移示例图。实线记录了模型训练过程中遇到的状态的分布情况,虚线展示了相应时刻的模型训练性能。该图显示了模型训练过程中数据分布及模型训练性能的动态变化,揭示了干扰和遗忘产生的内部机理。在T3时刻之前,随着模型训练,数据分布逐步从P1转移至P2再至P3,当神经网络逐步拟合至P3,模型权重被很大程度地更新,使得已学习的P1和P2分布上的信息受到干扰甚至被完全覆盖,因此,当智能体再次遇到P1分布的状态时会突然无法做出正确决策,从而导致模型性能突然下降,此时模型必须在P1分布上重新学习。图1显示了训练数据分布漂移导致的灾难性干扰和遗忘伴随着模型性能的急剧波动。
本发明的目的是为了解决现有基于值的深度强化学习模型训练过程中普遍存在的因训练数据分布漂移引起神经网络出现灾难性干扰和遗忘,从而导致模型训练过程中性能非常不稳定,甚至无法学习到优策略的问题。
如图2和图3所示,为了解决上述技术问题,本发明提供一种深度强化学习训练方法,包括如下步骤:
S1:指定情境数量,初始化深度强化学习多头神经网络模型的权重参数;智能体随机决策,收集样本存于经验回放缓冲区;
S2:依据所述情境数量,采用在线聚类算法实现自适应情境划分,对当前时间步状态进行在线聚类,自适应进行情境推断,得到截止当前时刻的情境划分和各情境中心;
S3:从所述经验回放缓冲区随机采样小批量样本,并依据各所述样本对应的状态与各所述情境中心的欧氏距离依次将各所述样本分配至距离最近的所述情境中;
S4:依据所述样本对应情境训练共享特征提取器及相应输出头的权重参数,并结合知识蒸馏损失对其他输出头权重参数进行同步更新,估计值函数;
S5:下一时间步,智能体依据所述值函数继续决策,收集样本存于所述经验回放缓冲区,重复进行所述自适应情境划分和所述深度强化学习多头神经网络模型的权重参数更新迭代,直至所述深度强化学习多头神经网络模型完成预先指定的训练次数或达到收敛。
本发明的方法通过基于情境划分和知识蒸馏的深度强化学习训练框架,联合在线聚类算法实现自适应情境划分,并采用蒸馏损失和多头神经网络架构对各情境下强化学习模型值函数进行更有针对性地估计,大幅提升强化学习模型训练的稳定性和可塑性。
进一步地,与现有技术相比,本发明具有如下有益效果:
大大减少模型训练过程中因训练数据分布漂移而产生的灾难性干扰和遗忘,大幅提升模型的稳定性和可塑性;
显著降低现有模型对计算内存的需求;
可方便地整合到各类现有的基于值的深度强化学习模型中,提升其在各类强化学习任务上的性能,具有很强的通用应用价值。
在一种具体的实施例中,指定情境数量k,其中,k>1;选用一个共享特征提取器和一组线性输出头组成的神经网络结构参数化值函数,每个线性输出头对应于一个特定情境;初始化深度强化学习多头神经网络模型的权重参数
Figure BDA0002950046260000071
其中,
Figure BDA0002950046260000072
为共享特征提取器参数,
Figure BDA0002950046260000073
为对应当前训练样本所属情境的输出头参数,
Figure BDA0002950046260000074
为其他输出头参数;将单个强化学习环境划分为k个情境,并针对每个情境中包含的状态采用深度强化学习多头神经网络分别对每个情境进行值函数估计。
本发明将针对一个特定任务估计一个值函数的问题转化为针对任务中包含的多个情境分别估计一个单独的值函数问题,从而解耦不同情境包含的状态间的干扰;同时共享特征提取器也能最大限度地促进不同情境在特征提取层的正向泛化,加速训练进程。
可以理解的是,理论上,k值越大则代表情境划分粒度越细,在训练次数足够大的情况下,所获得的模型性能越好,但k值太大会造成神经网络输出头太多从而导致模型非常复杂,增加了训练难度,因此k值也不宜过大。依据经验,对任务情境进行粗略划分,如k=3~5,即可实现明显地模型训练性能改进。
本发明中只需要指定一个超参数来控制情境划分的细粒度,从而确保本发明的方法在实践中的可用性。
收集样本存于经验回放缓冲区,在t时刻收集到的样本表示为{st,at,rt,st+1};其中,st为t时刻智能体所处的环境状态,at为t时刻智能体所采取的动作,rt为采取动作at后环境反馈的奖励,st+1为采取动作后智能体达到的t+1时刻环境状态,经验回放缓冲区表示为
Figure BDA0002950046260000075
大小为N。
对深度强化学习多头神经网络模型训练过程中经历的所有状态进行划分得到有限个簇,每个簇称为一个情境ω,Ω={ω1,ω2,...,ωk}为划分得到的情境的有限集合。本发明的优点是将环境状态空间依距离相似性进行划分,以便在训练阶段对各个状态子空间(情境)分别进行值函数估计,实现不同分布状态间的解耦,从而减少模型训练过程中灾难性干扰和遗忘的发生。
在本发明中,将单个强化学习环境划分为若干个情境,并针对每个情境中包含的状态采用共享特征提取器和一组特定于单个情境的线性输出头组成的神经网络结构进行值函数估计;多个情境间共享特征提取层,提升特征提取层训练效率,加速训练进程。
情境划分:在本发明的一种实施例中,利用Sequential K-Means算法对当前时刻t智能体所处的环境状态st进行在线聚类实现自适应情境划分,得到情境划分中心为
Figure BDA0002950046260000081
其中,ci为第i个情境ωi的聚类中心。本发明的优点是实现自适应情境划分过程,并依据训练进程,实时更新情境划分边界。
进一步地,对每个时间步状态进行在线聚类,自适应进行情境推断,得到截止当前时刻的情境划分和各情境中心,具体操作包括:
当前时刻t下,各情境中心为
Figure BDA0002950046260000082
每个情境内包含的状态个数为
Figure BDA0002950046260000083
并分别计算当前时刻智能体所处的环境状态st与各个情境中心ci之间的距离d(st,ci),将当前时刻智能体所处的环境状态st分配至距离最近的所述情境ωj中;然后相应地更新所述情境ωj包含的状态个数nj=nj+1及所述情境划分中心
Figure BDA0002950046260000084
其中,i∈{1,2,...,k},j=argminid(st,ci)。
本发明中采用多头神经网络结构对每个情境中包含的状态对应的值函数分别进行估计,解耦不同情境状态间对神经网络训练的干扰,提升网络训练稳定性和可塑性。
状态分配:从经验回放缓冲区
Figure BDA0002950046260000085
中随机采样m个样本,并依据各样本对应的状态与各情境中心的欧氏距离得到相应的所属情境
Figure BDA0002950046260000086
样本总体表示为:
Figure BDA0002950046260000091
联合优化:依据样本对应情境训练共享特征提取器及相应输出头的权重参数,并结合知识蒸馏损失对其他输出头权重参数进行同步更新,估计值函数包括:
所述深度强化学习多头神经网络模型值函数估计的原始损失函数为
Figure BDA0002950046260000092
将知识蒸馏目标网络与原始值函数估计目标网络设为一致,在每次迭代中,蒸馏目标为
Figure BDA0002950046260000093
预测输出为
Figure BDA0002950046260000094
本发明采用知识蒸馏正则化损失对神经网络参数进行优化,最大限度保留网络已学习的知识。知识蒸馏包含两项内容,分别是训练当前输入状态对应输出头的蒸馏损失和同步更新其他输出头的蒸馏损失,其中,前项蒸馏损失表示为
Figure BDA0002950046260000095
后项蒸馏损失表示为
Figure BDA0002950046260000096
保证了模型在学习新知识的同时最大限度保留神经网络已学习到的知识,进一步减少灾难性干扰和遗忘。
深度强化学习多头神经网络模型可以为任一基于值函数的深度强化学习模型,以深度强化学习代表性算法DQN为例,深度强化学习多头神经网络模型值函数估计的原始损失函数
Figure BDA0002950046260000097
为:
Figure BDA0002950046260000098
当前时刻智能体所处的环境状态st对应的输出头的蒸馏损失为:
Figure BDA0002950046260000099
其他输出头对应的蒸馏损失为:
Figure BDA00029500462600000910
联合优化损失函数如下:
Figure BDA00029500462600000911
其中,λ∈[0,1]为控制深度强化学习多头神经网络模型可塑性和稳定性平衡系数。
智能体依据所得值函数进行下一步决策,存于所述经验回放缓冲区,重复进行自适应情境划分和深度强化学习多头神经网络模型的权重参数更新迭代,直至深度强化学习多头神经网络模型完成预先指定的训练次数T或达到收敛,最终得到训练好的模型参数θ和情境划分中心
Figure BDA0002950046260000101
模型部署:指导智能体决策以完成相应任务。对于当前状态s,首先依据其与各情境中心距离判断其所属情境:
Figure BDA0002950046260000102
计算第j个输出头对应的Q值
Figure BDA0002950046260000103
智能体依据ε-greedy策略进行决策:
Figure BDA0002950046260000104
如图4所示,是在不同大小的经验回放缓冲区容量设置下,分别采用深度强化学习代表性算法DQN及利用本发明提出的训练框架训练的DQN在OpenAI Gym经典控制游戏CartPole-v1上训练所获得的累计奖励曲线展示。从图中可以看出,在不同大小的缓冲区容量下,原始的DQN方法在训练过程中都出现了非常明显的灾难性遗忘和性能波动,尤其是当缓冲区容量非常小(为1)时,DQN模型根本无法学习到最优策略而实现最大的累计奖励。对比之下,融合了本发明提出的训练框架的DQN方法在不同缓冲区容量设置下的训练性能都稳定得多,且即使是缓冲区容量设置为1时也还是可以学习到完成任务的最优策略,实现最大的累计奖励。
如表1所示,是对图3中所示两种方法训练曲线分别就训练过程中所达到的最大累计奖励和最大的累计奖励下降比例两种指标进行统计所得的结果。
表1统计所得的结果
Figure BDA0002950046260000105
从表中所示结果可以再次印证对图3分析所得出的结论。不论是在哪种缓冲区容量设置下,融合本发明所提出的训练框架的DQN方法都获得了最大的累计奖励(即,融合本发明提出的训练框架的DQN模型具有很强的可塑性),并且在训练过程中,累计奖励波动的最大值都比原始的DQN小得多(即,融合本发明提出的训练框架的DQN模型具有很好的稳定性)。
本申请实施例还提供一种控制装置,包括处理器和用于存储计算机程序的存储介质;其中,处理器用于执行所述计算机程序时至少执行如上所述的方法。
本申请实施例还提供一种存储介质,用于存储计算机程序,该计算机程序被执行时至少执行如上所述的方法。
本申请实施例还提供一种处理器,所述处理器执行计算机程序,至少执行如上所述的方法。
所述存储介质可以由任何类型的易失性或非易失性存储设备、或者它们的组合来实现。其中,非易失性存储器可以是只读存储器(ROM,Read Only Memory)、可编程只读存储器(PROM,Programmable Read-Only Memory)、可擦除可编程只读存储器(EPROM,ErasableProgrammable Read-Only Memory)、电可擦除可编程只读存储器(EEPROM,ElectricallyErasable Programmable Read-Only Memory)、磁性随机存取存储器(FRAM,FerromagneticRandom Access Memory)、快闪存储器(Flash Memory)、磁表面存储器、光盘、或只读光盘(CD-ROM,Compact Disc Read-Only Memory);磁表面存储器可以是磁盘存储器或磁带存储器。易失性存储器可以是随机存取存储器(RAM,Random Access Memory),其用作外部高速缓存。通过示例性但不是限制性说明,许多形式的RAM可用,例如静态随机存取存储器(SRAM,Static Random Access Memory)、同步静态随机存取存储器(SSRAM,SynchronousStatic Random Access Memory)、动态随机存取存储器(DRAM,Dynamic Random AccessMemory)、同步动态随机存取存储器(SDRAM,Synchronous Dynamic Random AccessMemory)、双倍数据速率同步动态随机存取存储器(DDRSDRAM,Double Data RateSynchronous Dynamic Random Access Memory)、增强型同步动态随机存取存储器(ESDRAM,Enhanced Synchronous Dynamic Random Access Memory)、同步连接动态随机存取存储器(SLDRAM,Sync Link Dynamic Random Access Memory)、直接内存总线随机存取存储器(DRRAM,Direct Rambus Random Access Memory)。本发明实施例描述的存储介质旨在包括但不限于这些和任意其它适合类型的存储器。
在本申请所提供的几个实施例中,应该理解到,所揭露的系统和方法,可以通过其它的方式实现。以上所描述的设备实施例仅仅是示意性的,例如,所述单元的划分,仅仅为一种逻辑功能划分,实际实现时可以有另外的划分方式,如:多个单元或组件可以结合,或可以集成到另一个系统,或一些特征可以忽略,或不执行。另外,所显示或讨论的各组成部分相互之间的耦合、或直接耦合、或通信连接可以是通过一些接口,设备或单元的间接耦合或通信连接,可以是电性的、机械的或其它形式的。
上述作为分离部件说明的单元可以是、或也可以不是物理上分开的,作为单元显示的部件可以是、或也可以不是物理单元,即可以位于一个地方,也可以分布到多个网络单元上;可以根据实际的需要选择其中的部分或全部单元来实现本实施例方案的目的。
另外,在本发明各实施例中的各功能单元可以全部集成在一个处理单元中,也可以是各单元分别单独作为一个单元,也可以两个或两个以上单元集成在一个单元中;上述集成的单元既可以采用硬件的形式实现,也可以采用硬件加软件功能单元的形式实现。
本领域普通技术人员可以理解:实现上述方法实施例的全部或部分步骤可以通过程序指令相关的硬件来完成,前述的程序可以存储于一计算机可读取存储介质中,该程序在执行时,执行包括上述方法实施例的步骤;而前述的存储介质包括:移动存储设备、只读存储器(ROM,Read-Only Memory)、随机存取存储器(RAM,Random Access Memory)、磁碟或者光盘等各种可以存储程序代码的介质。
或者,本发明上述集成的单元如果以软件功能模块的形式实现并作为独立的产品销售或使用时,也可以存储在一个计算机可读取存储介质中。基于这样的理解,本发明实施例的技术方案本质上或者说对现有技术做出贡献的部分可以以软件产品的形式体现出来,该计算机软件产品存储在一个存储介质中,包括若干指令用以使得一台计算机设备(可以是个人计算机、服务器、或者网络设备等)执行本发明各个实施例所述方法的全部或部分。而前述的存储介质包括:移动存储设备、ROM、RAM、磁碟或者光盘等各种可以存储程序代码的介质。
本申请所提供的几个方法实施例中所揭露的方法,在不冲突的情况下可以任意组合,得到新的方法实施例。
本申请所提供的几个产品实施例中所揭露的特征,在不冲突的情况下可以任意组合,得到新的产品实施例。
本申请所提供的几个方法或设备实施例中所揭露的特征,在不冲突的情况下可以任意组合,得到新的方法实施例或设备实施例。
以上内容是结合具体的优选实施方式对本发明所做的进一步详细说明,不能认定本发明的具体实施只局限于这些说明。对于本发明所属技术领域的技术人员来说,在不脱离本发明构思的前提下,还可以做出若干等同替代或明显变型,而且性能或用途相同,都应当视为属于本发明的保护范围。

Claims (10)

1.一种深度强化学习训练方法,其特征在于,包括如下步骤:
S1:指定情境数量,初始化深度强化学习多头神经网络模型的权重参数;智能体随机决策,收集样本存于经验回放缓冲区;
S2:依据所述情境数量,采用在线聚类算法实现自适应情境划分,对当前时间步状态进行在线聚类,自适应进行情境推断,得到截止当前时刻的情境划分和各情境中心;
S3:从所述经验回放缓冲区随机采样小批量样本,并依据各所述样本对应的状态与各所述情境中心的欧氏距离依次将各所述样本分配至距离最近的所述情境中;
S4:依据所述样本对应情境训练共享特征提取器及相应输出头的权重参数,并结合知识蒸馏损失对其他输出头权重参数进行同步更新,估计值函数;
S5:下一时间步,智能体依据所述值函数继续决策,收集样本存于所述经验回放缓冲区,重复进行所述自适应情境划分和所述深度强化学习多头神经网络模型的权重参数更新迭代,直至所述深度强化学习多头神经网络模型完成预先指定的训练次数或达到收敛。
2.如权利要求1所述的深度强化学习训练方法,其特征在于,指定所述情境数量k,其中,k>1;选用一个共享特征提取器和一组线性输出头组成的神经网络结构参数化值函数,每个线性输出头对应于一个特定情境;初始化所述深度强化学习多头神经网络模型的权重参数
Figure FDA0002950046250000011
其中,
Figure FDA0002950046250000012
为共享特征提取器参数,
Figure FDA0002950046250000013
为对应当前训练样本所属情境的输出头参数,
Figure FDA0002950046250000014
为其他输出头参数;将单个强化学习环境划分为k个情境,并针对每个所述情境中包含的状态采用所述深度强化学习多头神经网络分别对每个所述情境进行值函数估计。
3.如权利要求2所述的深度强化学习训练方法,其特征在于,收集样本存于所述经验回放缓冲区,在t时刻收集到的样本表示为{st,at,rt,st+1};
其中,st为t时刻智能体所处的环境状态,at为t时刻智能体所采取的动作,rt为采取动作at后环境反馈的奖励,st+1为采取动作后智能体达到的t+1时刻环境状态,经验回放缓冲区表示为
Figure FDA0002950046250000015
大小为N。
4.如权利要求3所述的深度强化学习训练方法,其特征在于,对所述深度强化学习多头神经网络模型训练过程中经历的所有状态进行划分得到有限个簇,每个所述簇称为一个情境ω,Ω={ω1,ω2,...,ωk}为划分得到的所述情境的有限集合。
5.如权利要求4所述的深度强化学习训练方法,其特征在于,利用Sequential K-Means算法对当前时刻t智能体所处的环境状态st进行在线聚类实现自适应情境划分,得到情境划分中心为
Figure FDA0002950046250000021
其中,ci为第i个情境ωi的聚类中心。
6.如权利要求5所述的深度强化学习训练方法,其特征在于,对每个时间步状态进行在线聚类,自适应进行情境推断,得到截止当前时刻的情境划分和各情境中心,具体操作包括:
当前时刻t下,各所述情境中心为
Figure FDA0002950046250000022
每个所述情境内包含的状态个数为
Figure FDA0002950046250000023
并分别计算当前时刻智能体所处的环境状态st与各个所述情境中心ci之间的距离d(st,ci),将当前时刻智能体所处的环境状态st分配至距离最近的所述情境ωj中;然后相应地更新所述情境ωj包含的状态个数nj=nj+1及所述情境划分中心
Figure FDA0002950046250000024
其中,i∈{1,2,...,k},j=arg minid(st,ci)。
7.如权利要求6所述的深度强化学习训练方法,其特征在于,从所述经验回放缓冲区
Figure FDA0002950046250000027
中随机采样m个样本,并依据各所述样本对应的状态与各所述情境中心的欧氏距离得到相应的所属情境
Figure FDA0002950046250000025
所述样本总体表示为:
Figure FDA0002950046250000026
8.如权利要求7所述的深度强化学习训练方法,其特征在于,依据所述样本对应情境训练共享特征提取器及相应输出头的权重参数,并结合知识蒸馏损失对其他输出头权重参数进行同步更新,估计值函数包括:
所述深度强化学习多头神经网络模型值函数估计的原始损失函数为
Figure FDA0002950046250000031
将知识蒸馏目标网络与原始值函数估计目标网络设为一致,在每次迭代中,蒸馏目标为
Figure FDA0002950046250000032
预测输出为
Figure FDA0002950046250000033
9.如权利要求8所述的深度强化学习训练方法,其特征在于,所述深度强化学习多头神经网络模型是DQN算法,所述深度强化学习多头神经网络模型值函数估计的原始损失函数
Figure FDA0002950046250000034
为:
Figure FDA0002950046250000035
当前时刻智能体所处的环境状态st对应的输出头的蒸馏损失为:
Figure FDA0002950046250000036
其他输出头对应的蒸馏损失为:
Figure FDA0002950046250000037
联合优化损失函数如下:
Figure FDA0002950046250000038
其中,λ∈[0,1]为控制深度强化学习多头神经网络模型可塑性和稳定性平衡系数。
10.一种计算机可读存储介质,所述计算机可读存储介质存储有计算机程序,其特征在于,所述计算机程序被处理器执行时实现如权利要求1-9任一所述方法的步骤。
CN202110208061.7A 2021-02-24 2021-02-24 一种深度强化学习训练方法及计算机可读存储介质 Pending CN112819159A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202110208061.7A CN112819159A (zh) 2021-02-24 2021-02-24 一种深度强化学习训练方法及计算机可读存储介质

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202110208061.7A CN112819159A (zh) 2021-02-24 2021-02-24 一种深度强化学习训练方法及计算机可读存储介质

Publications (1)

Publication Number Publication Date
CN112819159A true CN112819159A (zh) 2021-05-18

Family

ID=75865483

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202110208061.7A Pending CN112819159A (zh) 2021-02-24 2021-02-24 一种深度强化学习训练方法及计算机可读存储介质

Country Status (1)

Country Link
CN (1) CN112819159A (zh)

Cited By (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113222035A (zh) * 2021-05-20 2021-08-06 浙江大学 基于强化学习和知识蒸馏的多类别不平衡故障分类方法
CN113449867A (zh) * 2021-07-02 2021-09-28 电子科技大学 一种基于知识蒸馏的深度强化学习多智能体协作方法
CN113836788A (zh) * 2021-08-24 2021-12-24 浙江大学 基于局部数据增强的流程工业强化学习控制的加速方法
CN115816466A (zh) * 2023-02-02 2023-03-21 中国科学技术大学 一种提升视觉观测机器人控制稳定性的方法
CN113269315B (zh) * 2021-06-29 2024-04-02 安徽寒武纪信息科技有限公司 利用深度强化学习执行任务的设备、方法及可读存储介质

Cited By (7)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113222035A (zh) * 2021-05-20 2021-08-06 浙江大学 基于强化学习和知识蒸馏的多类别不平衡故障分类方法
CN113222035B (zh) * 2021-05-20 2021-12-31 浙江大学 基于强化学习和知识蒸馏的多类别不平衡故障分类方法
CN113269315B (zh) * 2021-06-29 2024-04-02 安徽寒武纪信息科技有限公司 利用深度强化学习执行任务的设备、方法及可读存储介质
CN113449867A (zh) * 2021-07-02 2021-09-28 电子科技大学 一种基于知识蒸馏的深度强化学习多智能体协作方法
CN113836788A (zh) * 2021-08-24 2021-12-24 浙江大学 基于局部数据增强的流程工业强化学习控制的加速方法
CN113836788B (zh) * 2021-08-24 2023-10-27 浙江大学 基于局部数据增强的流程工业强化学习控制的加速方法
CN115816466A (zh) * 2023-02-02 2023-03-21 中国科学技术大学 一种提升视觉观测机器人控制稳定性的方法

Similar Documents

Publication Publication Date Title
CN112819159A (zh) 一种深度强化学习训练方法及计算机可读存储介质
Hernandez-Leal et al. A survey and critique of multiagent deep reinforcement learning
Lin et al. Episodic memory deep q-networks
CN108776483B (zh) 基于蚁群算法和多智能体q学习的agv路径规划方法和系统
Wang et al. Towards cooperation in sequential prisoner's dilemmas: a deep multiagent reinforcement learning approach
CN111104595A (zh) 一种基于文本信息的深度强化学习交互式推荐方法及系统
CN111352419B (zh) 基于时序差分更新经验回放缓存的路径规划方法及系统
CN112734014A (zh) 基于置信上界思想的经验回放采样强化学习方法及系统
CN113487039A (zh) 基于深度强化学习的智能体自适应决策生成方法及系统
CN116596060A (zh) 深度强化学习模型训练方法、装置、电子设备及存储介质
CN117707795B (zh) 基于图的模型划分的边端协同推理方法及系统
CN112131089B (zh) 软件缺陷预测的方法、分类器、计算机设备及存储介质
CN116227579A (zh) 一种对离散环境基于值的强化学习训练的优化方法
CN113469369B (zh) 一种面向多任务强化学习的缓解灾难性遗忘的方法
CN115903901A (zh) 内部状态未知的无人集群系统输出同步优化控制方法
Smith et al. Co-Learning Empirical Games and World Models
CN114912518A (zh) 基于用户群体典型特征的强化学习分组方法、装置及介质
CN112906435B (zh) 视频帧优选方法及装置
Li et al. CoAxNN: Optimizing on-device deep learning with conditional approximate neural networks
CN113721655A (zh) 一种控制周期自适应的强化学习无人机稳定飞行控制方法
Venturini Distributed deep reinforcement learning for drone swarm control
CN111612572A (zh) 一种基于推荐系统的自适应局部低秩矩阵近似建模方法
CN114612750B (zh) 自适应学习率协同优化的目标识别方法、装置及电子设备
Gupta Obedience-based multi-agent cooperation for sequential social Dilemmas
Chan et al. Dynamic fusion for ensemble of deep Q-network

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination