CN116596060A - 深度强化学习模型训练方法、装置、电子设备及存储介质 - Google Patents
深度强化学习模型训练方法、装置、电子设备及存储介质 Download PDFInfo
- Publication number
- CN116596060A CN116596060A CN202310884815.XA CN202310884815A CN116596060A CN 116596060 A CN116596060 A CN 116596060A CN 202310884815 A CN202310884815 A CN 202310884815A CN 116596060 A CN116596060 A CN 116596060A
- Authority
- CN
- China
- Prior art keywords
- function
- network
- loss
- dominance
- value
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Granted
Links
- 238000000034 method Methods 0.000 title claims abstract description 75
- 230000002787 reinforcement Effects 0.000 title claims abstract description 68
- 238000012549 training Methods 0.000 title claims abstract description 48
- 230000006870 function Effects 0.000 claims abstract description 167
- 230000009471 action Effects 0.000 claims abstract description 85
- 230000000875 corresponding effect Effects 0.000 claims description 33
- 238000004590 computer program Methods 0.000 claims description 20
- 230000003993 interaction Effects 0.000 claims description 9
- 238000004364 calculation method Methods 0.000 claims description 7
- 238000000342 Monte Carlo simulation Methods 0.000 claims description 4
- 230000007613 environmental effect Effects 0.000 claims description 3
- 230000000052 comparative effect Effects 0.000 claims 1
- 230000000694 effects Effects 0.000 abstract description 10
- 238000004422 calculation algorithm Methods 0.000 description 25
- 230000008569 process Effects 0.000 description 20
- 238000005070 sampling Methods 0.000 description 8
- 230000008901 benefit Effects 0.000 description 7
- 238000005457 optimization Methods 0.000 description 5
- 238000012545 processing Methods 0.000 description 5
- 238000010586 diagram Methods 0.000 description 4
- 230000008878 coupling Effects 0.000 description 3
- 238000010168 coupling process Methods 0.000 description 3
- 238000005859 coupling reaction Methods 0.000 description 3
- 230000006399 behavior Effects 0.000 description 2
- 230000009286 beneficial effect Effects 0.000 description 2
- 238000004891 communication Methods 0.000 description 2
- 238000005516 engineering process Methods 0.000 description 2
- 238000003062 neural network model Methods 0.000 description 2
- 238000012935 Averaging Methods 0.000 description 1
- 238000013459 approach Methods 0.000 description 1
- 238000013528 artificial neural network Methods 0.000 description 1
- 238000012512 characterization method Methods 0.000 description 1
- 230000001276 controlling effect Effects 0.000 description 1
- 238000013527 convolutional neural network Methods 0.000 description 1
- 230000003247 decreasing effect Effects 0.000 description 1
- 238000013461 design Methods 0.000 description 1
- 238000011161 development Methods 0.000 description 1
- 238000011156 evaluation Methods 0.000 description 1
- 238000010801 machine learning Methods 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 230000003287 optical effect Effects 0.000 description 1
- 230000000087 stabilizing effect Effects 0.000 description 1
- 238000006467 substitution reaction Methods 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/092—Reinforcement learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/0895—Weakly supervised learning, e.g. semi-supervised or self-supervised learning
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- General Health & Medical Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Evolutionary Computation (AREA)
- Artificial Intelligence (AREA)
- Molecular Biology (AREA)
- Computing Systems (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Mathematical Physics (AREA)
- Software Systems (AREA)
- Health & Medical Sciences (AREA)
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
Abstract
本申请提供一种深度强化学习模型训练方法、装置、电子设备及存储介质。该方法包括:对深度强化学习模型的策略网络和价值网络进行初始化,将策略网络与环境进行互动,生成一系列的轨迹记录;依据轨迹记录,确定每个状态下执行动作时对应的优势函数;计算优势函数的原始损失,并利用各个状态对应的优势函数值计算对比学习损失,依据原始损失以及对比学习损失计算新的损失函数;利用新的损失函数对价值网络的参数进行更新,并将更新后的价值网络的参数固定,对策略网络的参数进行更新,利用更新后的策略网络与环境进行互动,直至深度强化学习模型收敛。本申请提高深度强化学习模型的学习效率,减少计算资源消耗,提升了模型的泛化能力和效果。
Description
技术领域
本申请涉及计算机技术领域,尤其涉及一种深度强化学习模型训练方法、装置、电子设备及存储介质。
背景技术
深度强化学习是一种模仿人类思维方式,通过观察并反馈结果以提高决策性能的机器学习技术。在深度强化学习中,优势演员-评论家算法(Advantage Actor-Critic,A2C)被广泛应用,该算法能够学习并处理复杂任务。A2C算法在传统的演员-评论家模型基础上,增加了对优势函数的考虑。优势函数能够揭示执行某个动作相较于遵循当前策略的平均性能的优势程度,通过对这个优势函数的引入,策略的调整更加有效,进而提高任务的完成效率。
然而,现有的强化学习算法,包括A2C算法,通常需要大量的训练数据来完成模型训练,这就导致了训练时间较长,计算资源消耗较大的问题。同时,模型在训练过程中的收敛稳定性较差,可能出现训练不稳定的情况,使得模型在新环境的泛化能力以及模型的效果复现性较差。另一方面,由于A2C需要同时训练两个网络,即策略网络和价值网络,因此这可能增加模型训练的复杂性,并可能进一步影响训练的稳定性。
发明内容
有鉴于此,本申请实施例提供了一种深度强化学习模型训练方法、装置、电子设备及存储介质,以解决现有技术存在的需要大量训练数据、模型训练稳定性差、模型在新环境中的泛化能力以及模型效果复现性差的问题。
本申请实施例的第一方面,提供了一种深度强化学习模型训练方法,包括:对深度强化学习模型中的策略网络和价值网络进行初始化,得到初始化后的深度强化学习模型;将策略网络与环境进行互动,生成一系列的轨迹记录,其中轨迹记录中包含动作以及动作对应的奖励和状态;依据轨迹记录,确定每个状态下执行动作时对应的优势函数,其中优势函数用于表征执行动作产生预期奖励相比平均奖励的优势程度;计算优势函数的原始损失,并利用各个状态对应的优势函数值计算对比学习损失,依据原始损失以及对比学习损失计算新的损失函数;利用新的损失函数对价值网络的参数进行更新,并将更新后的价值网络的参数固定,对策略网络的参数进行更新,利用更新后的策略网络与环境进行互动,直至深度强化学习模型收敛。
本申请实施例的第二方面,提供了一种深度强化学习模型训练装置,包括:初始化模块,被配置为对深度强化学习模型中的策略网络和价值网络进行初始化,得到初始化后的深度强化学习模型;生成模块,被配置为将策略网络与环境进行互动,生成一系列的轨迹记录,其中轨迹记录中包含动作以及动作对应的奖励和状态;确定模块,被配置为依据轨迹记录,确定每个状态下执行动作时对应的优势函数,其中优势函数用于表征执行动作产生预期奖励相比平均奖励的优势程度;计算模块,被配置为计算优势函数的原始损失,并利用各个状态对应的优势函数值计算对比学习损失,依据原始损失以及对比学习损失计算新的损失函数;更新模块,被配置为利用新的损失函数对价值网络的参数进行更新,并将更新后的价值网络的参数固定,对策略网络的参数进行更新,利用更新后的策略网络与环境进行互动,直至深度强化学习模型收敛。
本申请实施例的第三方面,提供了一种电子设备,包括存储器,处理器及存储在存储器上并可在处理器上运行的计算机程序,处理器执行计算机程序时实现上述方法的步骤。
本申请实施例的第四方面,提供了一种计算机可读存储介质,该计算机可读存储介质存储有计算机程序,该计算机程序被处理器执行时实现上述方法的步骤。
本申请实施例采用的上述至少一个技术方案能够达到以下有益效果:
通过对深度强化学习模型中的策略网络和价值网络进行初始化,得到初始化后的深度强化学习模型;将策略网络与环境进行互动,生成一系列的轨迹记录,其中轨迹记录中包含动作以及动作对应的奖励和状态;依据轨迹记录,确定每个状态下执行动作时对应的优势函数,其中优势函数用于表征执行动作产生预期奖励相比平均奖励的优势程度;计算优势函数的原始损失,并利用各个状态对应的优势函数值计算对比学习损失,依据原始损失以及对比学习损失计算新的损失函数;利用新的损失函数对价值网络的参数进行更新,并将更新后的价值网络的参数固定,对策略网络的参数进行更新,利用更新后的策略网络与环境进行互动,直至深度强化学习模型收敛。本申请提高深度强化学习模型的学习效率,减少计算资源消耗,提高模型在新环境中的泛化能力以及模型的效果复现性。
附图说明
为了更清楚地说明本申请实施例中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本申请的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其它的附图。
图1是本申请实施例提供的深度强化学习模型训练方法的流程示意图;
图2是本申请实施例提供的深度强化学习模型训练装置的结构示意图;
图3是本申请实施例提供的电子设备的结构示意图。
具体实施方式
以下描述中,为了说明而不是为了限定,提出了诸如特定系统结构、技术之类的具体细节,以便透彻理解本申请实施例。然而,本领域的技术人员应当清楚,在没有这些具体细节的其它实施例中也可以实现本申请。在其它情况中,省略对众所周知的系统、装置、电路以及方法的详细说明,以免不必要的细节妨碍本申请的描述。
随着科技的发展,强化学习的应用场景逐渐增多,例如在机器人控制、自动驾驶、游戏、无人机以及金融交易等领域都有广泛的应用。具体来说,强化学习可以用于控制机器人以实现机器人的自主学习和行动;在自动驾驶领域,使车辆能够在复杂的环境中安全驾驶;在游戏领域,使游戏角色能够从自身的行为中学习,从而更好地控制游戏;在无人机控制领域,让无人机能够自主学习和行动,以实现更好的空中飞行姿态控制;在金融交易领域,可以用于提高金融交易的准确性和效率。
目前,常用的深度强化学习方法是优势演员-评论家算法(Advantage Actor-Critic,A2C)。A2C算法是基于优势函数的强化学习算法,是Actor-Critic算法的升级版,A2C算法在Actor-Critic算法的基础上增加了对优势函数的考虑。A2C算法的主要思想是,通过对策略函数和价值函数的同时估计和优化,来提高策略优化的效率。具体来说,A2C算法中有一个策略网络和一个价值网络,策略网络负责预测动作,价值网络负责预测当前状态优势(价值)函数值。策略网络和价值网络都是通过梯度下降算法来进行优化的。
尽管A2C算法在强化学习中的应用广泛且具有许多优点,例如它可以同时估计和优化策略,通过对优势函数的考虑使得策略优化更加有效,能够应用在高维环境中,并且能处理离散和连续动作空间。但是,该深度强化算法在实际应用中仍存在许多问题。具体来说,A2C算法在训练过程中需要同时优化策略网络和价值网络,这容易导致模型训练的稳定性差,影响模型在新环境中的泛化能力以及模型的效果复现性。同时,由于需要训练一个较好的优势估计函数,导致需要进行大量的采样与环境互动,使得训练时间较长,消耗的计算资源较大。
因此,亟需提出一种新的深度强化学习模型训练方法,以解决现有强化学习算法在训练稳定性、泛化能力、效果复现性以及训练效率上存在的问题,从而提高深度强化学习的实用性。
鉴于现有技术中存在的问题,本申请提供一种度强化学习模型训练方法及装置,通过将对比学习引入到A2C算法中,创新性地在优势函数网络的训练过程中加入正则项,从而有助于稳定模型训练过程,高效地利用数据。另外,对比学习损失的引入使得模型能够学习和理解数据的内在结构和模式,而不仅仅是数据本身的表现形式,因此可以提高模型的性能和泛化能力,即使在新的和未见过的环境中也能表现得较好。此外,通过对比学习损失和正则项的引入,可以更高效地利用数据。传统的强化学习方法通常需要大量的数据才能完成训练,而通过引入对比学习损失和正则项,本申请可以在更小的数据集上训练模型,同时还能提高模型的性能。
图1是本申请实施例提供的深度强化学习模型训练方法的流程示意图。图1的深度强化学习模型训练方法可以由服务器执行。如图1所示,该深度强化学习模型训练方法具体可以包括:
S101,对深度强化学习模型中的策略网络和价值网络进行初始化,得到初始化后的深度强化学习模型;
S102,将策略网络与环境进行互动,生成一系列的轨迹记录,其中轨迹记录中包含动作以及动作对应的奖励和状态;
S103,依据轨迹记录,确定每个状态下执行动作时对应的优势函数,其中优势函数用于表征执行动作产生预期奖励相比平均奖励的优势程度;
S104,计算优势函数的原始损失,并利用各个状态对应的优势函数值计算对比学习损失,依据原始损失以及对比学习损失计算新的损失函数;
S105,利用新的损失函数对价值网络的参数进行更新,并将更新后的价值网络的参数固定,对策略网络的参数进行更新,利用更新后的策略网络与环境进行互动,直至深度强化学习模型收敛。
首先,对A2C(Advantage Actor-Critic)算法的整体工作流程及实现原理进行说明,A2C算法是一种基于优势函数(advantage function)的深度强化学习算法,A2C算法模型的训练过程主要包括以下步骤:
步骤1,初始化网络:A2C算法中有两个主要的网络,即策略网络(又被称为演员网络,负责生成动作)和价值网络(又被称为评论家网络,负责估计状态值函数)。
步骤2,互动与采样:策略网络与环境进行互动,生成动作并获取相应的奖励和新的状态。这样就得到了一条包含状态-动作-奖励(state-action-reward)的轨迹记录。
步骤3,计算优势函数:对于在每个状态下选择的动作,需要计算其对应的优势函数值。优势函数表示的是采取某个动作相比于当前策略平均表现的优势有多大。优势函数值可以通过将动作的实际奖励与价值网络预测的状态值之间的差计算得出。
步骤4,更新评论家网络:使用上述优势函数和实际得到的奖励来更新价值网络。通常情况下,模型希望价值网络能准确预测出在某个状态下能够得到的期望奖励。
步骤5,更新演员网络:利用计算出的优势函数来更新策略网络。模型希望优势函数较大的动作被策略网络选中的概率更高,因此策略网络的更新方向会偏向于提高这些动作的选择概率。
步骤6,迭代优化:重复上述过程,通过大量的互动、采样、计算优势函数和更新网络,使得演员网络(即策略网络)和评论家网络(即价值网络)逐步优化,直到达到一定的收敛条件,例如达到最大迭代次数、策略网络或者价值网络的预测值稳定等。
上述实施例介绍了A2C算法模型训练的基本工作流程,在A2C算法模型的训练过程中,通过策略网络生成动作,然后通过价值网络和优势函数评价这些动作,最后根据这些评价更新策略,使得总体的奖励最大化。
在一些实施例中,对深度强化学习模型中的策略网络和价值网络进行初始化,得到初始化后的深度强化学习模型,包括:利用预设参数或随机参数对深度强化学习模型中的策略网络和价值网络进行初始化,其中,策略网络用于决策动作的选择,价值网络用于计算执行相应动作的优势函数值。
具体地,在将对策略网络进行初始化时。策略网络是一个神经网络模型,它的功能是接收环境的状态作为输入,然后输出对应的动作选择。初始化策略网络的过程可以通过设定预设参数或随机参数进行,例如:这些参数可以是网络权重和偏置,还可以通过预设值设定,或者随机生成一个初始值。
进一步地,在将对价值网络进行初始化时。价值网络也是一个神经网络模型,它的功能是接收环境的状态和动作作为输入,然后输出执行该动作在该状态下的优势函数值。初始化价值网络的过程也可以通过设定预设参数或随机参数进行。通过以上操作,便可以得到初始化后的深度强化学习模型。在该模型中,策略网络和价值网络都已经被初始化,他们都具有一组初始参数。接下来,便可以通过训练过程,逐渐优化这些参数,使模型能够更好地执行任务。
在一些实施例中,将策略网络与环境进行互动,生成一系列的轨迹记录,包括:利用策略网络与环境进行互动,并采用时序差分或蒙特卡洛方法进行采样,产生由一系列状态、动作和奖励组合成的轨迹记录。
具体地,利用已经初始化的策略网络与环境进行互动。策略网络将根据当前的环境状态,输出一个动作,然后,将这个动作应用到环境中,环境将返回一个新的状态和一个奖励。在实际应用中,这个过程将被重复进行,直到达到一定的条件,例如达到一定的步数,或者达到一个终止状态。这样,就能得到一条由状态、动作和奖励组成的轨迹记录。
在一个示例中,本申请实施例可以采用时序差分(TD)方法或者蒙特卡洛(MC)方法进行采样。时序差分方法和蒙特卡洛方法是两种常用的强化学习采样方法。时序差分方法是一种基于动态规划思想的采样方法,它可以实时地更新价值函数的估计值。而蒙特卡洛方法则是一种基于经验平均的采样方法,它需要等待一条完整的轨迹结束后才能更新价值函数的估计值。
在一些实施例中,依据轨迹记录,确定每个状态下执行动作时对应的优势函数,包括:对轨迹记录中的每一个状态及对应执行的动作,将执行动作后的预期奖励与相同状态下按照当前策略执行动作的平均奖励进行比较,得到优势函数值,其中,优势函数值用于表征在状态下执行动作相较于按照当前策略执行动作的优势程度。
具体地,对于轨迹记录中的每一个状态及对应执行的动作,将执行动作后的预期奖励与在相同状态下,按照当前策略执行动作的平均奖励进行比较。通过这种比较,就得到了优势函数值。优势函数值用于表征在特定状态下执行特定动作相较于按照当前策略执行动作的优势程度。基于优势函数值可以理解在某个状态下,采取某个动作相比于其他动作的优势是多少。
进一步地,在得到每个状态下执行动作的优势函数值后,便可以进行策略网络的训练。在这个过程中,根据优势函数对每一个步骤的动作进行评价,如果某个动作的优势函数值较大,那么在训练策略网络时,就会更偏向于选择这个动作。通过优势函数的指导,本申请可以更加有效地训练策略网络,使其更好地理解在不同状态下,应该选择哪个动作才能获得更大的奖励。
在一些实施例中,计算优势函数的原始损失,包括:依据价值网络对每个状态下选择的动作的优势函数值的预测,以及利用从环境互动中实际获得的奖励计算出的实际优势函数值,计算优势函数预测值与实际优势函数值之间的平均平方误差,得到优势函数的原始损失。
具体地,本申请实施例利用价值网络预测出每个状态下选择的动作的优势函数值。然后,利用从环境互动中实际获得的奖励,计算出每个状态下选择的动作的实际优势函数值,这个实际优势函数值反映了当前策略在实际环境中的表现。
进一步地,基于预测的优势函数值和实际的优势函数值,可以计算出这两者之间的平均平方误差,将该平均平方误差作为优势函数的原始损失。在实际应用中,这个优势函数的原始损失反映了价值网络对优势函数的预测准确性。如果原始损失较小,说明价值网络能够准确地预测出每个状态下选择的动作的优势函数值,这对于优化策略网络和价值网络,提升深度强化学习模型的性能和效果具有重要意义。
在一些实施例中,利用各个状态对应的优势函数值计算对比学习损失,包括:利用策略网络预测出在各个状态下执行动作的优势函数值,将优势函数值的差距小于阈值的状态作为正样本,将优势函数值的差距大于阈值的状态作为负样本,基于正样本以及负样本,利用三元组损失方法计算得到对比学习损失。
具体地,本申请实施例引入额外的对比学习损失来提升对优势网络训练的稳定性和泛化能力。下面结合具体实施例对对比学习损失的计算过程及原理进行说明,具体可以包括以下内容:
本申请实施例将优势函数值的差距小于预设阈值的状态作为正样本,将优势函数值的差距大于预设阈值的状态作为负样本。例如,如果设定阈值为2,那么对于状态S1,S2和S3,其优势函数值分别为6,7和1,由于S1和S2的优势函数值差距(6-7=1)小于阈值2,因此将状态S1和S2作为正样本;而S1和S3的优势函数值差距(6-1=5)大于阈值2,因此将状态S1和S3作为负样本。
在得到正样本和负样本后,利用三元组损失方法来计算对比学习损失。三元组损失方法会尽量将正样本之间的距离缩小,将负样本之间的距离扩大,从而使得模型能够更好地区分正样本和负样本。例如,在实际应用中,可以使用triplet loss损失函数来计算对比学习损失,即loss_con = tripletloss(s1,s2,s3)。triplet loss是一种常用于计算相似度的损失函数,它会尽量拉近同类的样本(正样本),推远异类的样本(负样本)。
在一些实施例中,依据原始损失以及对比学习损失计算新的损失函数,包括采用以下公式计算新的损失函数:
;
其中,表示新的损失函数,/>表示优势函数的原始损失,/>表示超参数,/>表示对比学习损失。
具体地,本申请实施例在计算优势函数的损失时,基于计算预测优势函数值与实际优势函数值之间的平均平方误差(即原始损失),以及上述实施例计算得到的对比学习损失,利用权重(即超参数)将原始损失与对比学习损失进行求和,得到新的损失函数。这种计算损失的方法不仅利用了优势函数的预测误差来调整模型,还利用了对比学习损失来进一步优化模型,使模型在未见过的状态下也能做出有效的决策,因此能够有效提升模型的泛化能力和稳定性。
在一个示例中,在更新优势函数时,新的损失函数变为:
;
第一项为通常的优势函数损失,是基于估计值和采样均值计算得到的MSE损失(即原始损失),后一项是本申请实施例加入的对比学习损失,其中,表示状态S经过神经网络映射后得到的低维表征向量,要求模型的隐层表征向量在空间内的距离和最终的终局互动得分分数差成正相关关系。该正则项能在训练过程中抑制模型过拟合,提升估计的准确度,使得优势函数在隐空间的表征分布有良好的依据分数的聚类效果,更易策略模型学习。
进一步地,在得到新的损失函数之后,利用新的损失函数更新价值网络的参数,并且在将价值网络的参数固定的情况下,来更新策略网络的参数。在这个过程中,可以使用优化器(比如梯度下降等)来更新优势函数和策略网络的参数,这一过程会迭代多次,直到模型收敛。最后,使用更新过的策略网络继续和环境进行互动,并重复以上实施例的步骤,直到模型训练满足终止条件,例如达到预设的最大迭代次数或者模型收敛等。
需要说明的是,本申请实施例的策略模型、优势函数模型和特征提取器等网络结构可根据实际任务选择不同的网络结构,比如CNN,LSTM或Transformers等。
根据本申请实施例提供的技术方案,本申请实施例基于对比学习和优势函数,为深度强化学习模型中的训练提供了新的思路。通过创新地引入对比学习损失到优势函数网络的训练过程中,不仅考虑了预测值与实际值之间的误差,还考虑了不同状态之间的相似性,这极大地提升了模型的训练稳定性和泛化能力。本申请通过利用三元组损失方法进行对比学习,使得模型在优化过程中,能更好地区分和学习各种状态,从而有效提高策略网络的学习效果。因此,优势函数不仅能准确地预测各种状态下的优势程度,也能根据各种状态的相似性,有效地指导策略网络进行决策。本申请能够更高效地利用数据,显著提升了模型的泛化能力和效果。
下述为本申请装置实施例,可以用于执行本申请方法实施例。对于本申请装置实施例中未披露的细节,请参照本申请方法实施例。
图2是本申请实施例提供的深度强化学习模型训练装置的结构示意图。如图2所示,该深度强化学习模型训练装置包括:
初始化模块201,被配置为对深度强化学习模型中的策略网络和价值网络进行初始化,得到初始化后的深度强化学习模型;
生成模块202,被配置为将策略网络与环境进行互动,生成一系列的轨迹记录,其中轨迹记录中包含动作以及动作对应的奖励和状态;
确定模块203,被配置为依据轨迹记录,确定每个状态下执行动作时对应的优势函数,其中优势函数用于表征执行动作产生预期奖励相比平均奖励的优势程度;
计算模块204,被配置为计算优势函数的原始损失,并利用各个状态对应的优势函数值计算对比学习损失,依据原始损失以及对比学习损失计算新的损失函数;
更新模块205,被配置为利用新的损失函数对价值网络的参数进行更新,并将更新后的价值网络的参数固定,对策略网络的参数进行更新,利用更新后的策略网络与环境进行互动,直至深度强化学习模型收敛。
在一些实施例中,图2的初始化模块201利用预设参数或随机参数对深度强化学习模型中的策略网络和价值网络进行初始化,其中,策略网络用于决策动作的选择,价值网络用于计算执行相应动作的优势函数值。
在一些实施例中,图2的生成模块202利用策略网络与环境进行互动,并采用时序差分或蒙特卡洛方法进行采样,产生由一系列状态、动作和奖励组合成的轨迹记录。
在一些实施例中,图2的确定模块203对轨迹记录中的每一个状态及对应执行的动作,将执行动作后的预期奖励与相同状态下按照当前策略执行动作的平均奖励进行比较,得到优势函数值,其中,优势函数值用于表征在状态下执行动作相较于按照当前策略执行动作的优势程度。
在一些实施例中,图2的计算模块204依据价值网络对每个状态下选择的动作的优势函数值的预测,以及利用从环境互动中实际获得的奖励计算出的实际优势函数值,计算优势函数预测值与实际优势函数值之间的平均平方误差,得到优势函数的原始损失。
在一些实施例中,图2的计算模块204利用策略网络预测出在各个状态下执行动作的优势函数值,将优势函数值的差距小于阈值的状态作为正样本,将优势函数值的差距大于阈值的状态作为负样本,基于正样本以及负样本,利用三元组损失方法计算得到对比学习损失。
在一些实施例中,图2的计算模块204采用以下公式计算新的损失函数:
;
其中,表示新的损失函数,/>表示优势函数的原始损失,/>表示超参数,/>表示对比学习损失。
理解,上述实施例中各步骤的序号的大小并不意味着执行顺序的先后,各过程的执行顺序应以其功能和内在逻辑确定,而不应对本申请实施例的实施过程构成任何限定。
图3是本申请实施例提供的电子设备3的结构示意图。如图3所示,该实施例的电子设备3包括:处理器301、存储器302以及存储在该存储器302中并且可以在处理器301上运行的计算机程序303。处理器301执行计算机程序303时实现上述各个方法实施例中的步骤。或者,处理器301执行计算机程序303时实现上述各装置实施例中各模块/单元的功能。
示例性地,计算机程序303可以被分割成一个或多个模块/单元,一个或多个模块/单元被存储在存储器302中,并由处理器301执行,以完成本申请。一个或多个模块/单元可以是能够完成特定功能的一系列计算机程序指令段,该指令段用于描述计算机程序303在电子设备3中的执行过程。
电子设备3可以是桌上型计算机、笔记本、掌上电脑及云端服务器等电子设备。电子设备3可以包括但不仅限于处理器301和存储器302。本领域技术人员可以理解,图3仅仅是电子设备3的示例,并不构成对电子设备3的限定,可以包括比图示更多或更少的部件,或者组合某些部件,或者不同的部件,例如,电子设备还可以包括输入输出设备、网络接入设备、总线等。
处理器301可以是中央处理单元(Central Processing Unit,CPU),也可以是其它通用处理器、数字信号处理器(Digital Signal Processor,DSP)、专用集成电路(Application SpecificIntegrated Circuit,ASIC)、现场可编程门阵列(Field-Programmable Gate Array,FPGA)或者其它可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件等。通用处理器可以是微处理器或者该处理器也可以是任何常规的处理器等。
存储器302可以是电子设备3的内部存储单元,例如,电子设备3的硬盘或内存。存储器302也可以是电子设备3的外部存储设备,例如,电子设备3上配备的插接式硬盘,智能存储卡(Smart Media Card,SMC),安全数字(Secure Digital,SD)卡,闪存卡(Flash Card)等。进一步地,存储器302还可以既包括电子设备3的内部存储单元也包括外部存储设备。存储器302用于存储计算机程序以及电子设备所需的其它程序和数据。存储器302还可以用于暂时地存储已经输出或者将要输出的数据。
所属领域的技术人员可以清楚地了解到,为了描述的方便和简洁,仅以上述各功能单元、模块的划分进行举例说明,实际应用中,可以根据需要而将上述功能分配由不同的功能单元、模块完成,即将装置的内部结构划分成不同的功能单元或模块,以完成以上描述的全部或者部分功能。实施例中的各功能单元、模块可以集成在一个处理单元中,也可以是各个单元单独物理存在,也可以两个或两个以上单元集成在一个单元中,上述集成的单元既可以采用硬件的形式实现,也可以采用软件功能单元的形式实现。另外,各功能单元、模块的具体名称也只是为了便于相互区分,并不用于限制本申请的保护范围。上述系统中单元、模块的具体工作过程,可以参考前述方法实施例中的对应过程,在此不再赘述。
在上述实施例中,对各个实施例的描述都各有侧重,某个实施例中没有详述或记载的部分,可以参见其它实施例的相关描述。
本领域普通技术人员可以意识到,结合本文中所公开的实施例描述的各示例的单元及算法步骤,能够以电子硬件、或者计算机软件和电子硬件的结合来实现。这些功能究竟以硬件还是软件方式来执行,取决于技术方案的特定应用和设计约束条件。专业技术人员可以对每个特定的应用来使用不同方法来实现所描述的功能,但是这种实现不应认为超出本申请的范围。
在本申请所提供的实施例中,应该理解到,所揭露的装置/计算机设备和方法,可以通过其它的方式实现。例如,以上所描述的装置/计算机设备实施例仅仅是示意性的,例如,模块或单元的划分,仅仅为一种逻辑功能划分,实际实现时可以有另外的划分方式,多个单元或组件可以结合或者可以集成到另一个系统,或一些特征可以忽略,或不执行。另一点,所显示或讨论的相互之间的耦合或直接耦合或通讯连接可以是通过一些接口,装置或单元的间接耦合或通讯连接,可以是电性,机械或其它的形式。
作为分离部件说明的单元可以是或者也可以不是物理上分开的,作为单元显示的部件可以是或者也可以不是物理单元,即可以位于一个地方,或者也可以分布到多个网络单元上。可以根据实际的需要选择其中的部分或者全部单元来实现本实施例方案的目的。
另外,在本申请各个实施例中的各功能单元可以集成在一个处理单元中,也可以是各个单元单独物理存在,也可以两个或两个以上单元集成在一个单元中。上述集成的单元既可以采用硬件的形式实现,也可以采用软件功能单元的形式实现。
集成的模块/单元如果以软件功能单元的形式实现并作为独立的产品销售或使用时,可以存储在一个计算机可读存储介质中。基于这样的理解,本申请实现上述实施例方法中的全部或部分流程,也可以通过计算机程序来指令相关的硬件来完成,计算机程序可以存储在计算机可读存储介质中,该计算机程序在被处理器执行时,可以实现上述各个方法实施例的步骤。计算机程序可以包括计算机程序代码,计算机程序代码可以为源代码形式、对象代码形式、可执行文件或某些中间形式等。计算机可读介质可以包括:能够携带计算机程序代码的任何实体或装置、记录介质、U盘、移动硬盘、磁碟、光盘、计算机存储器、只读存储器(Read-Only Memory,ROM)、随机存取存储器(Random Access Memory,RAM)、电载波信号、电信信号以及软件分发介质等。需要说明的是,计算机可读介质包含的内容可以根据司法管辖区内立法和专利实践的要求进行适当的增减,例如,在某些司法管辖区,根据立法和专利实践,计算机可读介质不包括电载波信号和电信信号。
以上实施例仅用以说明本申请的技术方案,而非对其限制;尽管参照前述实施例对本申请进行了详细的说明,本领域的普通技术人员应当理解:其依然可以对前述各实施例所记载的技术方案进行修改,或者对其中部分技术特征进行等同替换;而这些修改或者替换,并不使相应技术方案的本质脱离本申请各实施例技术方案的精神和范围,均应包含在本申请的保护范围之内。
Claims (10)
1.一种深度强化学习模型训练方法,其特征在于,包括:
对深度强化学习模型中的策略网络和价值网络进行初始化,得到初始化后的深度强化学习模型;
将所述策略网络与环境进行互动,生成一系列的轨迹记录,其中所述轨迹记录中包含动作以及所述动作对应的奖励和状态;
依据所述轨迹记录,确定每个所述状态下执行所述动作时对应的优势函数,其中所述优势函数用于表征执行所述动作产生预期奖励相比平均奖励的优势程度;
计算所述优势函数的原始损失,并利用各个所述状态对应的优势函数值计算对比学习损失,依据所述原始损失以及所述对比学习损失计算新的损失函数;
利用所述新的损失函数对所述价值网络的参数进行更新,并将更新后的价值网络的参数固定,对所述策略网络的参数进行更新,利用更新后的策略网络与环境进行互动,直至所述深度强化学习模型收敛。
2.根据权利要求1所述的方法,其特征在于,所述对深度强化学习模型中的策略网络和价值网络进行初始化,得到初始化后的深度强化学习模型,包括:
利用预设参数或随机参数对所述深度强化学习模型中的策略网络和价值网络进行初始化,其中,所述策略网络用于决策动作的选择,所述价值网络用于计算执行相应动作的优势函数值。
3.根据权利要求1所述的方法,其特征在于,所述将所述策略网络与环境进行互动,生成一系列的轨迹记录,包括:
利用所述策略网络与所述环境进行互动,并采用时序差分或蒙特卡洛方法进行采样,产生由一系列所述状态、动作和奖励组合成的所述轨迹记录。
4.根据权利要求1所述的方法,其特征在于,所述依据所述轨迹记录,确定每个所述状态下执行所述动作时对应的优势函数,包括:
对所述轨迹记录中的每一个所述状态及对应执行的动作,将执行所述动作后的预期奖励与相同状态下按照当前策略执行动作的平均奖励进行比较,得到优势函数值,其中,所述优势函数值用于表征在所述状态下执行所述动作相较于按照当前策略执行所述动作的优势程度。
5.根据权利要求4所述的方法,其特征在于,所述计算所述优势函数的原始损失,包括:
依据所述价值网络对每个所述状态下选择的所述动作的优势函数值的预测,以及利用从环境互动中实际获得的奖励计算出的实际优势函数值,计算优势函数预测值与所述实际优势函数值之间的平均平方误差,得到所述优势函数的原始损失。
6.根据权利要求4所述的方法,其特征在于,所述利用各个所述状态对应的优势函数值计算对比学习损失,包括:
利用所述策略网络预测出在各个所述状态下执行所述动作的优势函数值,将所述优势函数值的差距小于阈值的状态作为正样本,将所述优势函数值的差距大于阈值的状态作为负样本,基于所述正样本以及所述负样本,利用三元组损失方法计算得到所述对比学习损失。
7.根据权利要求5或6所述的方法,其特征在于,所述依据所述原始损失以及所述对比学习损失计算新的损失函数,包括采用以下公式计算所述新的损失函数:
;
其中,表示新的损失函数,/>表示优势函数的原始损失,/>表示超参数,/>表示对比学习损失。
8.一种深度强化学习模型训练装置,其特征在于,包括:
初始化模块,被配置为对深度强化学习模型中的策略网络和价值网络进行初始化,得到初始化后的深度强化学习模型;
生成模块,被配置为将所述策略网络与环境进行互动,生成一系列的轨迹记录,其中所述轨迹记录中包含动作以及所述动作对应的奖励和状态;
确定模块,被配置为依据所述轨迹记录,确定每个所述状态下执行所述动作时对应的优势函数,其中所述优势函数用于表征执行所述动作产生预期奖励相比平均奖励的优势程度;
计算模块,被配置为计算所述优势函数的原始损失,并利用各个所述状态对应的优势函数值计算对比学习损失,依据所述原始损失以及所述对比学习损失计算新的损失函数;
更新模块,被配置为利用所述新的损失函数对所述价值网络的参数进行更新,并将更新后的价值网络的参数固定,对所述策略网络的参数进行更新,利用更新后的策略网络与环境进行互动,直至所述深度强化学习模型收敛。
9.一种电子设备,包括存储器,处理器及存储在存储器上并可在处理器上运行的计算机程序,其特征在于,所述处理器执行所述计算机程序时实现如权利要求1至7中任一项所述的方法。
10.一种计算机可读存储介质,所述计算机可读存储介质存储有计算机程序,其特征在于,所述计算机程序被处理器执行时实现如权利要求1至7中任一项所述的方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310884815.XA CN116596060B (zh) | 2023-07-19 | 2023-07-19 | 深度强化学习模型训练方法、装置、电子设备及存储介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310884815.XA CN116596060B (zh) | 2023-07-19 | 2023-07-19 | 深度强化学习模型训练方法、装置、电子设备及存储介质 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN116596060A true CN116596060A (zh) | 2023-08-15 |
CN116596060B CN116596060B (zh) | 2024-03-15 |
Family
ID=87594138
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202310884815.XA Active CN116596060B (zh) | 2023-07-19 | 2023-07-19 | 深度强化学习模型训练方法、装置、电子设备及存储介质 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN116596060B (zh) |
Cited By (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117409486A (zh) * | 2023-12-15 | 2024-01-16 | 深圳须弥云图空间科技有限公司 | 基于视觉的动作生成方法、装置、电子设备及存储介质 |
CN117408052A (zh) * | 2023-10-18 | 2024-01-16 | 南栖仙策(南京)高新技术有限公司 | 一种蒸镀机镀膜控制优化方法、装置、设备及存储介质 |
CN118153658A (zh) * | 2024-02-28 | 2024-06-07 | 中国科学院自动化研究所 | 离线强化学习训练方法、动作预测方法、装置及介质 |
CN118398100A (zh) * | 2024-06-27 | 2024-07-26 | 日照盛泉新材料科技有限公司 | 一种动态反馈的泵强制循环反应器催化控制方法 |
Citations (12)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN109492583A (zh) * | 2018-11-09 | 2019-03-19 | 安徽大学 | 一种基于深度学习的车辆重识别方法 |
CN113176776A (zh) * | 2021-03-03 | 2021-07-27 | 上海大学 | 基于深度强化学习的无人艇天气自适应避障方法 |
CN114020945A (zh) * | 2021-11-05 | 2022-02-08 | 中山大学 | 一种面向农业采摘的高效识别控制强化学习算法 |
CN114881228A (zh) * | 2021-09-04 | 2022-08-09 | 大连钜智信息科技有限公司 | 一种基于q学习的平均sac深度强化学习方法和系统 |
CN115330556A (zh) * | 2022-08-10 | 2022-11-11 | 北京百度网讯科技有限公司 | 充电站的信息调整模型的训练方法、装置及产品 |
US20220363279A1 (en) * | 2021-04-21 | 2022-11-17 | Foundation Of Soongsil University-Industry Cooperation | Method for combating stop-and-go wave problem using deep reinforcement learning based autonomous vehicles, recording medium and device for performing the method |
CN115545350A (zh) * | 2022-11-28 | 2022-12-30 | 湖南工商大学 | 综合深度神经网络与强化学习的车辆路径问题求解方法 |
CN115618716A (zh) * | 2022-09-14 | 2023-01-17 | 天津大学 | 一种基于离散SAC算法的gazebo潜航器路径规划算法 |
WO2023043601A1 (en) * | 2021-09-16 | 2023-03-23 | Siemens Corporation | System and method for supporting execution of batch production using reinforcement learning |
CN116010054A (zh) * | 2022-12-28 | 2023-04-25 | 哈尔滨工业大学 | 一种基于强化学习的异构边云ai系统任务调度框架 |
CN116187466A (zh) * | 2022-12-08 | 2023-05-30 | 北京航空航天大学 | 一种基于旋转对称性的多智能体强化学习训练方法 |
WO2023102962A1 (zh) * | 2021-12-06 | 2023-06-15 | 深圳先进技术研究院 | 一种训练端到端的自动驾驶策略的方法 |
Family Cites Families (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN111460650B (zh) * | 2020-03-31 | 2022-11-01 | 北京航空航天大学 | 一种基于深度强化学习的无人机端到端控制方法 |
CN111708355B (zh) * | 2020-06-19 | 2023-04-18 | 中国人民解放军国防科技大学 | 基于强化学习的多无人机动作决策方法和装置 |
-
2023
- 2023-07-19 CN CN202310884815.XA patent/CN116596060B/zh active Active
Patent Citations (12)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN109492583A (zh) * | 2018-11-09 | 2019-03-19 | 安徽大学 | 一种基于深度学习的车辆重识别方法 |
CN113176776A (zh) * | 2021-03-03 | 2021-07-27 | 上海大学 | 基于深度强化学习的无人艇天气自适应避障方法 |
US20220363279A1 (en) * | 2021-04-21 | 2022-11-17 | Foundation Of Soongsil University-Industry Cooperation | Method for combating stop-and-go wave problem using deep reinforcement learning based autonomous vehicles, recording medium and device for performing the method |
CN114881228A (zh) * | 2021-09-04 | 2022-08-09 | 大连钜智信息科技有限公司 | 一种基于q学习的平均sac深度强化学习方法和系统 |
WO2023043601A1 (en) * | 2021-09-16 | 2023-03-23 | Siemens Corporation | System and method for supporting execution of batch production using reinforcement learning |
CN114020945A (zh) * | 2021-11-05 | 2022-02-08 | 中山大学 | 一种面向农业采摘的高效识别控制强化学习算法 |
WO2023102962A1 (zh) * | 2021-12-06 | 2023-06-15 | 深圳先进技术研究院 | 一种训练端到端的自动驾驶策略的方法 |
CN115330556A (zh) * | 2022-08-10 | 2022-11-11 | 北京百度网讯科技有限公司 | 充电站的信息调整模型的训练方法、装置及产品 |
CN115618716A (zh) * | 2022-09-14 | 2023-01-17 | 天津大学 | 一种基于离散SAC算法的gazebo潜航器路径规划算法 |
CN115545350A (zh) * | 2022-11-28 | 2022-12-30 | 湖南工商大学 | 综合深度神经网络与强化学习的车辆路径问题求解方法 |
CN116187466A (zh) * | 2022-12-08 | 2023-05-30 | 北京航空航天大学 | 一种基于旋转对称性的多智能体强化学习训练方法 |
CN116010054A (zh) * | 2022-12-28 | 2023-04-25 | 哈尔滨工业大学 | 一种基于强化学习的异构边云ai系统任务调度框架 |
Cited By (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117408052A (zh) * | 2023-10-18 | 2024-01-16 | 南栖仙策(南京)高新技术有限公司 | 一种蒸镀机镀膜控制优化方法、装置、设备及存储介质 |
CN117409486A (zh) * | 2023-12-15 | 2024-01-16 | 深圳须弥云图空间科技有限公司 | 基于视觉的动作生成方法、装置、电子设备及存储介质 |
CN117409486B (zh) * | 2023-12-15 | 2024-04-12 | 深圳须弥云图空间科技有限公司 | 基于视觉的动作生成方法、装置、电子设备及存储介质 |
CN118153658A (zh) * | 2024-02-28 | 2024-06-07 | 中国科学院自动化研究所 | 离线强化学习训练方法、动作预测方法、装置及介质 |
CN118398100A (zh) * | 2024-06-27 | 2024-07-26 | 日照盛泉新材料科技有限公司 | 一种动态反馈的泵强制循环反应器催化控制方法 |
Also Published As
Publication number | Publication date |
---|---|
CN116596060B (zh) | 2024-03-15 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN116596060B (zh) | 深度强化学习模型训练方法、装置、电子设备及存储介质 | |
Hambly et al. | Recent advances in reinforcement learning in finance | |
CN110956148A (zh) | 无人车的自主避障方法及装置、电子设备、可读存储介质 | |
US20220395975A1 (en) | Demonstration-conditioned reinforcement learning for few-shot imitation | |
CN112613608A (zh) | 一种强化学习方法及相关装置 | |
CN112016678A (zh) | 用于增强学习的策略生成网络的训练方法、装置和电子设备 | |
JP2023548915A (ja) | 深層顔認識のためのメタ学習を用いたドメイン一般化マージン | |
CN111694272B (zh) | 基于模糊逻辑系统的非线性多智能体的自适应控制方法及装置 | |
CN114648103A (zh) | 用于处理深度学习网络的自动多目标硬件优化 | |
Sun et al. | An improved hybrid algorithm based on PSO and BP for stock price forecasting | |
CN113721655B (zh) | 一种控制周期自适应的强化学习无人机稳定飞行控制方法 | |
Zhao et al. | Efficient online estimation of empowerment for reinforcement learning | |
Li | Focus of attention in reinforcement learning | |
Ding et al. | Chaos synchronization of two coupled map lattice systems using safe reinforcement learning | |
CN115545188B (zh) | 基于不确定性估计的多任务离线数据共享方法及系统 | |
Esposito et al. | Bellman residuals minimization using online support vector machines | |
KR102558092B1 (ko) | 샘플 효율적인 탐색을 위한 샘플-인지 엔트로피 정규화 기법 | |
CN117850237B (zh) | 基于迭代学习控制的跟踪处理方法、装置、设备和介质 | |
CN115984804B (zh) | 一种基于多任务检测模型的检测方法及车辆 | |
CN117360552B (zh) | 一种车辆控制方法、装置、设备及可读存储介质 | |
CN113433823B (zh) | 一种改进的自适应双幂次趋近律的方法、装置及存储介质 | |
CN118607610A (zh) | 加速训练神经网络模型方法、装置、设备及存储介质 | |
Zhang et al. | Versatile Navigation under Partial Observability via Value-guided Diffusion Policy | |
Chen et al. | Basics of Machine Learning in Architecture | |
CN117425194A (zh) | 无人机分簇方法和设备 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |