CN112085524B - 一种基于q学习模型的结果推送方法和系统 - Google Patents

一种基于q学习模型的结果推送方法和系统 Download PDF

Info

Publication number
CN112085524B
CN112085524B CN202010896316.9A CN202010896316A CN112085524B CN 112085524 B CN112085524 B CN 112085524B CN 202010896316 A CN202010896316 A CN 202010896316A CN 112085524 B CN112085524 B CN 112085524B
Authority
CN
China
Prior art keywords
value
gradient
learning model
network parameter
result
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN202010896316.9A
Other languages
English (en)
Other versions
CN112085524A (zh
Inventor
徐君
贾浩男
张骁
蒋昊
文继荣
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Huawei Technologies Co Ltd
Renmin University of China
Original Assignee
Huawei Technologies Co Ltd
Renmin University of China
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Huawei Technologies Co Ltd, Renmin University of China filed Critical Huawei Technologies Co Ltd
Priority to CN202010896316.9A priority Critical patent/CN112085524B/zh
Publication of CN112085524A publication Critical patent/CN112085524A/zh
Application granted granted Critical
Publication of CN112085524B publication Critical patent/CN112085524B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06QINFORMATION AND COMMUNICATION TECHNOLOGY [ICT] SPECIALLY ADAPTED FOR ADMINISTRATIVE, COMMERCIAL, FINANCIAL, MANAGERIAL OR SUPERVISORY PURPOSES; SYSTEMS OR METHODS SPECIALLY ADAPTED FOR ADMINISTRATIVE, COMMERCIAL, FINANCIAL, MANAGERIAL OR SUPERVISORY PURPOSES, NOT OTHERWISE PROVIDED FOR
    • G06Q30/00Commerce
    • G06Q30/02Marketing; Price estimation or determination; Fundraising
    • G06Q30/0241Advertisements
    • G06Q30/0251Targeted advertisements
    • G06Q30/0255Targeted advertisements based on user history
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N20/00Machine learning
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06QINFORMATION AND COMMUNICATION TECHNOLOGY [ICT] SPECIALLY ADAPTED FOR ADMINISTRATIVE, COMMERCIAL, FINANCIAL, MANAGERIAL OR SUPERVISORY PURPOSES; SYSTEMS OR METHODS SPECIALLY ADAPTED FOR ADMINISTRATIVE, COMMERCIAL, FINANCIAL, MANAGERIAL OR SUPERVISORY PURPOSES, NOT OTHERWISE PROVIDED FOR
    • G06Q30/00Commerce
    • G06Q30/02Marketing; Price estimation or determination; Fundraising
    • G06Q30/0241Advertisements
    • G06Q30/0251Targeted advertisements
    • G06Q30/0255Targeted advertisements based on user history
    • G06Q30/0256User search
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06QINFORMATION AND COMMUNICATION TECHNOLOGY [ICT] SPECIALLY ADAPTED FOR ADMINISTRATIVE, COMMERCIAL, FINANCIAL, MANAGERIAL OR SUPERVISORY PURPOSES; SYSTEMS OR METHODS SPECIALLY ADAPTED FOR ADMINISTRATIVE, COMMERCIAL, FINANCIAL, MANAGERIAL OR SUPERVISORY PURPOSES, NOT OTHERWISE PROVIDED FOR
    • G06Q30/00Commerce
    • G06Q30/02Marketing; Price estimation or determination; Fundraising
    • G06Q30/0241Advertisements
    • G06Q30/0251Targeted advertisements
    • G06Q30/0269Targeted advertisements based on user profile or attribute
    • G06Q30/0271Personalized advertisement
    • HELECTRICITY
    • H04ELECTRIC COMMUNICATION TECHNIQUE
    • H04LTRANSMISSION OF DIGITAL INFORMATION, e.g. TELEGRAPHIC COMMUNICATION
    • H04L67/00Network arrangements or protocols for supporting network services or applications
    • H04L67/50Network services
    • H04L67/55Push-based network services

Landscapes

  • Engineering & Computer Science (AREA)
  • Business, Economics & Management (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • General Physics & Mathematics (AREA)
  • Strategic Management (AREA)
  • Finance (AREA)
  • Development Economics (AREA)
  • Accounting & Taxation (AREA)
  • Software Systems (AREA)
  • Game Theory and Decision Science (AREA)
  • Entrepreneurship & Innovation (AREA)
  • General Business, Economics & Management (AREA)
  • General Engineering & Computer Science (AREA)
  • Evolutionary Computation (AREA)
  • Mathematical Physics (AREA)
  • Data Mining & Analysis (AREA)
  • Marketing (AREA)
  • Economics (AREA)
  • Artificial Intelligence (AREA)
  • Computing Systems (AREA)
  • Health & Medical Sciences (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • General Health & Medical Sciences (AREA)
  • Molecular Biology (AREA)
  • Computational Linguistics (AREA)
  • Biophysics (AREA)
  • Biomedical Technology (AREA)
  • Signal Processing (AREA)
  • Management, Administration, Business Operations System, And Electronic Commerce (AREA)
  • Computer Networks & Wireless Communication (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Medical Informatics (AREA)

Abstract

本发明涉及一种基于Q学习模型的结果推送方法和系统,包括以下步骤:将状态st、推送结果at,下一状态st+1和奖励值rt+1组成一个数据组,并将其存储至经验池D中;从经验池D中提取若干数据组,计算网络参数
Figure DDA0002658524720000011
下的全梯度均值,此时的网络参数为锚点网络参数;随机提取上一步骤中的数据组,并计算其在当前网络参数下和锚点网络参数下的目标Q值和梯度值,将梯度值和全梯度均值带入方差缩减公式实现梯度更新;重复上述步骤直至训练结束,获得最终的Q学习模型,将待测状态输入最终的Q学习模型获得最佳推送结果。其通过将方差缩减技术引入到随机梯度下降的Q学习模型中,提高了强化学习的训练过程的稳定性。

Description

一种基于Q学习模型的结果推送方法和系统
技术领域
本发明是关于一种基于Q学习模型的结果推送方法及系统,属于互联网技术领域。
背景技术
在信息检索中,采用结果推送方法或者按照结果与检索信息的相关度进行排序可以大大降低检索者的工作量,提高信息获取效率。目前已经有很多将强化学习模型,例如深度Q学习模型,应用到检索结果推送中,通过使用检索者的历史检索记录对强化学习模型进行训练,可以是推送出的结果更加符合检索者的要求,进一步提高检索效率。但现有的利用深度Q学习模型生成的结果推送的方法还存在以下问题:
一方面,由于深度Q学习模型(DQN)在基于值函数的深度强化学习方面起着绝对的引领作用,导致对DQN算法的改进多注重于改进DQN算法的网络结构以提升其效率;另一方面,由于强化学习算法有着“试错”的训练特点,导致其在训练过程通常很不稳定,而其不稳定性主要是由奖励值、Q值等的方差过高而引起的。
发明内容
针对上述现有技术的不足,本发明的目的是提供了一种基于Q学习模型的结果推送方法及系统,其通过将方差缩减技术引入到随机梯度下降的Q学习模型中,降低了奖励值或Q值的方差,提高了强化学习的训练过程的稳定性。
为实现上述目的,本发明提供了一种基于Q学习模型的结果推送方法,包括以下步骤:S1确定当前状态st,将当前状态st带入初始Q学习模型获得Q值,根据Q值获得原始推送结果at;S2将原始推送结果推送给用户,并通过记录用户浏览,获得奖励值rt+1;S3将状态st、推送结果at,下一状态st+1和奖励值rt+1组成一个数据组,并将其存储至经验池D中;S4从经验池D中提取若干数据组,并根据提取的数据组计算网络参数
Figure GDA0003831718570000015
下的全梯度均值,此时的网络参数为锚点网络参数;S5随机提取一步骤S4中的数据组,并计算其在当前网络参数下和锚点网络参数下的目标Q值和梯度值,将梯度值和全梯度均值带入方差缩减公式实现梯度更新;S6重复步骤S4-S5直至训练结束,获得最终的Q学习模型,将待测状态输入最终的Q学习模型获得最佳推送结果。
进一步,步骤S5中的方差缩减公式:
Figure GDA0003831718570000011
其中,
Figure GDA0003831718570000012
是下一个网络参数;
Figure GDA0003831718570000013
是当前网络参数;α是学习率;
Figure GDA0003831718570000014
是梯度值;g是全梯度均值。
进一步,梯度值的计算公式为:
当前网络参数下的梯度值:
Figure GDA0003831718570000021
锚点网络参数下的梯度值:
Figure GDA0003831718570000022
其中,s,a分别为步骤S5中随机提取的一数据组中的状态和状态对应的推送结果,qm是当前网络参数下的目标Q值,q0是锚点网络参数下的目标Q值,
Figure GDA0003831718570000023
是锚点网络参数,Q()为Q网络。
进一步,目标Q值的计算公式为:
当前网络参数下的目标Q值:
Figure GDA0003831718570000024
锚点网络参数下的目标Q值:
Figure GDA0003831718570000025
其中,s′,a′分别为步骤S5中随机提取的一数据组中的下一个状态和下一个状态对应的推送结果,r是奖励值,γ是折扣系数。
进一步,全梯度均值的计算公式为:
Figure GDA0003831718570000026
其中,N为数据组的数量,l()为损失函数。
本发明还公开了另一种基于Q学习模型的结果推送方法,包括以下步骤:S1确定当前状态s1,将当前状态st带入初始Q学习模型获得Q值,根据Q值获得原始推送结果at;S2将原始推送结果推送给用户,并通过记录用户浏览,获得奖励值rt+1;S3将状态st、推送结果at,下一状态st+1和奖励值rt+1组成一个数据组,并将其存储至经验池D中;S4从经验池D中提取若干数据组,并根据提取的数据组计算网络参数
Figure GDA0003831718570000027
下的全梯度均值,对全梯度均值进行梯度优化:
Figure GDA0003831718570000028
其中,
Figure GDA0003831718570000029
是下一个网络参数;
Figure GDA00038317185700000210
是当前网络参数;
Figure GDA00038317185700000211
是当前网络参数下的全梯度均值;S5随机提取一步骤S4中的数据组,并计算其在当前网络参数下和上一个网络参数下的目标Q值和梯度值,将梯度值和全梯度均值带入方差缩减公式实现梯度更新;S6重复步骤S4-S5直至训练结束,获得最终的Q学习模型,将待测状态输入最终的Q学习模型获得最佳推送结果。
进一步,步骤S5中的方差缩减公式:
Figure GDA00038317185700000212
其中,l()为损失函数,
Figure GDA00038317185700000213
是上一个网络参数;
Figure GDA00038317185700000214
是当前网络参数;
Figure GDA00038317185700000215
是上一个网络参数下的全梯度均值;
Figure GDA00038317185700000216
是当前网络参数下的全梯度均值。
进一步,梯度值的计算公式为:
当前网络参数下的梯度值:
Figure GDA0003831718570000031
上一个网络参数下的梯度值:
Figure GDA0003831718570000032
其中,s,a分别为步骤S5中随机提取的一数据组中的状态和状态对应的推送结果,qm是当前网络参数下的目标Q值,q0是锚点网络参数下的目标Q值,
Figure GDA0003831718570000033
是锚点网络参数,Q()为Q网络。
进一步,目标Q值的计算公式为:
当前网络参数下的目标Q值:
Figure GDA0003831718570000034
上一个网络参数下的目标Q值:
Figure GDA0003831718570000035
其中,s′,a′分别为步骤S5中随机提取的一数据组中的下一个状态和下一个状态对应的推送结果,r是奖励值,γ是折扣系数。
本发明还公开了一种基于Q学习模型的结果推送系统,包括:原始推送结果生成模块,用于确定当前状态st,将当前状态st带入初始Q学习模型获得Q值,根据Q值获得原始推送结果at;奖励值生成模块,用于将原始推送结果推送给用户,并通过记录用户浏览,获得奖励值rt+1;存储模块,用于将状态st、推送结果at,下一状态st+1和奖励值rt+1组成一个数据组,并将其存储至经验池D中;全梯度均值计算模块,用于从经验池D中提取若干数据组,并根据提取的数据组计算网络参数
Figure GDA0003831718570000036
下的全梯度均值,此时的网络参数为锚点网络参数;梯度更新模块,用于随机提取一步骤S4中的数据组,并计算其在当前网络参数下和锚点网络参数下的目标Q值和梯度值,将梯度值和全梯度均值带入方差缩减公式实现梯度更新;输出模块,用于重复步骤S4-S5直至训练结束,获得最终的Q学习模型,将待测状态输入最终的Q学习模型获得最佳推送结果。
本发明由于采取以上技术方案,其具有以下优点:
1、通过将方差缩减技术引入到随机梯度下降的Q学习模型中,降低了奖励值或Q值的方差,提高了强化学习的训练过程的精度、稳定性。
2、采用随机递归梯度算法(Stochastic recursive gradient algorithm,SARAH)解决了随机方差缩减梯度下降技术(Stochastic Variance Reduced Gradient Descent,SVRG)在训练时网络的参数不固定的且可能会逐渐偏移采样时的参数,从而造成信息差越来越大的问题,使模型计算更加准确。
附图说明
图1是本发明一实施例中基于深度学习模型的地震数据不连续性检测方法的示意图;
图2是本发明一实施例中梯度优化算法的示意图,图2(a)是传统的梯度优化算法的示意图,图2(b)是随机梯度下降的梯度优化算法的示意图;
图3是本发明一实施例中基于方差缩减的的深度Q学习模型训练框架的逻辑示意图。
具体实施方式
为了使本领域技术人员更好的理解本发明的技术方向,通过具体实施例对本发明进行详细的描绘。然而应当理解,具体实施方式的提供仅为了更好地理解本发明,它们不应该理解成对本发明的限制。在本发明的描述中,需要理解的是,所用到的术语仅仅是用于描述的目的,而不能理解为指示或暗示相对重要性。
实施例一
本实施例公开了一种基于Q学习模型的结果推送方法,如图1所示,包括以下步骤:
S1首先,设定初始Q学习模型,确定当前状态st,其中,初始化状态s0通过用户当前浏览记录活动;随后的状况通过用户上一次交互后的浏览历史获得;将当前状态st带入初始Q学习模型获得Q值,根据Q值获得原始推送结果at;其中,推送结果包括推送内容和推送内容的位置。
S2将原始推送结果推送给用户,并通过记录用户浏览,获得奖励值rt+1
S3将状态st、推送结果at,下一状态st+1和奖励值rt+1组成一个数据组,并将其存储至经验池D中;
S4从经验池D中提取若干数据组,并根据提取的数据组计算网络参数
Figure GDA0003831718570000041
下的全梯度均值,此时的网络参数为锚点网络参数;
全梯度均值的计算公式为:
Figure GDA0003831718570000042
其中,N为数据组的数量,l()为损失函数。
S5随机提取一步骤S4中的数据组,并计算其在当前网络参数下和锚点网络参数下的目标Q值和梯度值,将梯度值和全梯度均值带入方差缩减公式实现梯度更新;
其中,目标Q值的计算公式为:
当前网络参数下的目标Q值:
Figure GDA0003831718570000043
锚点网络参数下的目标Q值:
Figure GDA0003831718570000044
其中,s′,a′分别为步骤S5中随机提取的一数据组中的下一个状态和下一个状态对应的推送结果,r是奖励值,γ是折扣系数。
若引入目标网络Q`(s,a;θ),目标Q值的计算公式为:
当前网络参数下的目标Q值:
qm←r+γmaxa′Q`(s′,a′;θ-)
锚点网络参数下的目标Q值:
q0←r+γmaxa′Q`(s′,a′;θ-)
其中,参数θ-代表上一次训练网络Q(s,a;θ)向目标网络Q`(s,a;θ)的参数值,而目标网络Q`是与训练网络Q结构相同但网络参数不同的网络。
梯度值的计算公式为:
当前网络参数下的梯度值:
Figure GDA0003831718570000051
锚点网络参数下的梯度值:
Figure GDA0003831718570000052
其中,s,a分别为步骤S5中随机提取的一数据组中的状态和状态对应的推送结果,qm是当前网络参数下的目标Q值,q0是锚点网络参数下的目标Q值,
Figure GDA0003831718570000053
是锚点网络参数,Q()为Q网络。
方差缩减公式为:
Figure GDA0003831718570000054
其中,
Figure GDA0003831718570000055
是下一个网络参数;
Figure GDA0003831718570000056
是当前网络参数;α是学习率;
Figure GDA0003831718570000057
是梯度值;g是全梯度均值。
S6重复步骤S4-S5直至训练结束,获得最终的Q学习模型,将待测状态输入最终的Q学习模型获得最佳推送结果。
本实施例主要采用基于随机方差缩减梯度下降技术(Stochastic VarianceReduced Gradient Descent,SVRG)的Q学习模型实现。如图2所示,在传统的梯度优化算法中,以梯度下降(GD)为主体的算法能保证待优化参数达到一个全局最优点,但由于其每一步都涉及到全梯度的计算,这在数据量过大的问题背景下通常会造成大计算量消耗,从而使训练过程变得迟缓。随机梯度下降(SGD)算法为避免每一步训练的大计算量消耗,其放弃了全梯度的计算,通过每一步采样一个(或一小批)数据来训练模型,虽然同样能保证优化目标的收敛,但由于其随机采样的特点,在优化层面上仍然有着因梯度方差过高引起收敛速度慢的局限性。
为解决上述问题,通过在随机梯度下降的过程中引入方差缩减技术进行优化。方差缩减的数学定义为:
Zα=α(X-Y)+E[Y]
其中,X代表需要被缩减方差的随机变量,Y代表另一个与X有正相关关系的随机变量,E[Y]代表随机变量Y的数学期望,Zα代表被方差缩减优化后的随机变量。
随机方差缩减梯度下降技术将原始的参数更新步骤改为了形如上Zα的方差缩减形式,通过定期采样批量训练数据充当方差缩减定义中的Y,其梯度更新公式为:
Figure GDA0003831718570000061
其中θt为训练至第t步时的待优化参数,θold代表计算全梯度时的参数值,
Figure GDA0003831718570000062
代表批量数据损失函数的全梯度值的期望,
Figure GDA0003831718570000063
代表单个数据样本损失函数的梯度值,η代表学习率。
本发明将损失函数l(s,a;θ)对网络各层参数的梯度
Figure GDA0003831718570000064
作为待缩减方差的随机变量X。如图3所示,基于方差缩减的深度Q-learning训练框架,其中当前网络Q代表学习模型,环境代表与网络Q交互的对象,网络Q接受环境的当前状态s作为输入,并且根据当前的网络参数θm评估在状态s下执行各个动作的Q值,根据Q值选出最优动作a输出至环境,环境接收该动作并转入下一状态s′。该框架以当前网络Q作为输入,以方差优化后的网络作为输出,具体而言,输入该网络的参数θ0,输出经过方差缩减训练过的优化网络参数
Figure GDA0003831718570000065
在训练过程中,环境与当前网络不断交互产生转移数据组(s,a,r,s′),容量有限的经验池D负责存储这些产生的数据并定期送入网络进行训练。由SVRG算法的特性可知,首先需要在经验池中采样一批数据,同时需要根据采样批数据时的网络
Figure GDA0003831718570000066
计算出这批数据的全梯度均值g,用于充当SVRG优化过程中的期望E[Y]。批数据中的单个样本在采样批数据时的网络
Figure GDA0003831718570000067
下的梯度值则充当了优化过程中的辅助变量Y。
实施例二
基于相同的发明构思,本实施例公开了另一种基于Q学习模型的结果推送方法,包括以下步骤:
S1首先,设定初始Q学习模型,确定当前状态st,其中,初始化状态s0通过用户当前浏览记录活动;随后的状况通过用户上一次交互后的浏览历史获得;将当前状态st带入初始Q学习模型获得Q值,根据Q值获得原始推送结果at;其中,推送结果包括推送内容和推送内容的位置。
S2将原始推送结果推送给用户,并通过记录用户浏览,获得奖励值rt+1
S3将状态st、推送结果at,下一状态st+1和奖励值rt+1组成一个数据组,并将其存储至经验池D中;S4从经验池D中提取若干数据组,并根据提取的数据组计算网络参数
Figure GDA0003831718570000071
下的全梯度均值,对全梯度均值进行梯度优化:
Figure GDA0003831718570000072
其中,
Figure GDA0003831718570000073
是下一个网络参数;
Figure GDA0003831718570000074
是当前网络参数;
Figure GDA0003831718570000075
是当前网络参数下的全梯度均值;
S5随机提取一步骤S4中的数据组,并计算其在当前网络参数下和上一个网络参数下的目标Q值和梯度值,将梯度值和全梯度均值带入方差缩减公式实现梯度更新;
其中,目标Q值的计算公式为:
当前网络参数下的目标Q值:
Figure GDA0003831718570000076
上一个网络参数下的目标Q值:
Figure GDA0003831718570000077
其中,s′,a′分别为步骤S5中随机提取的一数据组中的下一个状态和下一个状态对应的推送结果,r是奖励值,γ是折扣系数。
步骤S5中的方差缩减公式:
Figure GDA0003831718570000078
其中,l()为损失函数,
Figure GDA0003831718570000079
是上一个网络参数;
Figure GDA00038317185700000710
是当前网络参数;
Figure GDA00038317185700000711
是上一个网络参数下的全梯度均值;
Figure GDA00038317185700000712
是当前网络参数下的全梯度均值。
梯度值的计算公式为:
当前网络参数下的梯度值:
Figure GDA00038317185700000713
上一个网络参数下的梯度值:
Figure GDA00038317185700000714
其中,s,a分别为步骤S5中随机提取的一数据组中的状态和状态对应的推送结果,qm是当前网络参数下的目标Q值,q0是锚点网络参数下的目标Q值,
Figure GDA00038317185700000715
是锚点网络参数,Q()为Q网络。
步骤S5中的方差缩减公式:
Figure GDA00038317185700000716
其中,l()为损失函数,
Figure GDA00038317185700000717
是上一个网络参数;
Figure GDA00038317185700000718
是当前网络参数;
Figure GDA00038317185700000719
是上一个网络参数下的全梯度均值。
S6重复步骤S4-S5直至训练结束,获得最终的Q学习模型,将待测状态输入最终的Q学习模型获得最佳推送结果。
本实施例主要采用基于随机递归梯度算法(Stochastic recursive gradientalgorithm,SARAH)的Q学习模型实现。SVRG算法中使用一个固定的批数据全梯度均值g作为修正量E[Y],并且使用固定的网络(批数据采样时网络)
Figure GDA00038317185700000720
去计算单个样本的梯度值来充当Y,而在训练时网络的参数是不固定的且可能会逐渐偏移采样时的参数θ0,从而造成信息差越来越大的问题。
为了解决这一问题,SARAH提出使用循环更新或适应性更新的方法来处理梯度和全梯度的估计值,放弃使用固定的批数据全梯度均值g和固定的采样参数θold,而在训练过程中对全梯度均值g进行逐步更新,并且使用上一步的参数θt-1来代替θold,综上可以得出,在SARAH算法中,带有方差缩减效用梯度更新步骤如下:
Figure GDA0003831718570000081
θt+1=θt-ηgt
相对图3中SVRG算法,本实施例中将SVRG操作单元替换为上述的SARAH更新单元,并且在更新参数的同时保持对全梯度均值g的更新,此外本实施例采用固定的采样时网络替换为上一步训练时的网络、即
Figure GDA0003831718570000082
实施例三
基于相同的发明构思,本实施例公开了一种基于Q学习模型的结果推送系统,包括:
原始推送结果生成模块,用于确定当前状态st,将当前状态st带入初始Q学习模型获得Q值,根据Q值获得原始推送结果at
奖励值生成模块,用于将原始推送结果推送给用户,并通过记录用户浏览,获得奖励值rt+1
存储模块,用于将状态st、推送结果at,下一状态st+1和奖励值rt+1组成一个数据组,并将其存储至经验池D中;
全梯度均值计算模块,用于从经验池D中提取若干数据组,并根据提取的数据组计算网络参数
Figure GDA0003831718570000083
下的全梯度均值,此时的网络参数为锚点网络参数;
梯度更新模块,用于随机提取一步骤S4中的数据组,并计算其在当前网络参数下和锚点网络参数下的目标Q值和梯度值,将梯度值和全梯度均值带入方差缩减公式实现梯度更新;
输出模块,用于重复步骤S4-S5直至训练结束,获得最终的Q学习模型,将待测状态输入最终的Q学习模型获得最佳推送结果。
最后应当说明的是:以上实施例仅用以说明本发明的技术方案而非对其限制,尽管参照上述实施例对本发明进行了详细的说明,所属领域的普通技术人员应当理解:依然可以对本发明的具体实施方式进行修改或者等同替换,而未脱离本发明精神和范围的任何修改或者等同替换,其均应涵盖在本发明的权利要求保护范围之内。上述内容仅为本申请的具体实施方式,但本申请的保护范围并不局限于此,任何熟悉本技术领域的技术人员在本申请揭露的技术范围内,可轻易想到变化或替换,都应涵盖在本申请的保护范围之内。因此,本申请的保护范围应以权利要求的保护范围。

Claims (9)

1.一种基于Q学习模型的结果推送方法,其特征在于,包括以下步骤:
S1确定当前状态st,将当前状态st带入初始Q学习模型获得Q值,根据所述Q值获得原始推送结果at
S2将所述原始推送结果推送给用户,并通过记录用户浏览,获得奖励值rt+1
S3将状态st、推送结果at,下一状态st+1和奖励值rt+1组成一个数据组,并将其存储至经验池D中;
S4从所述经验池D中提取若干数据组,并根据提取的数据组计算网络参数
Figure FDA00038317185600000110
下的全梯度均值,此时的网络参数为锚点网络参数;
S5随机提取一步骤S4中的数据组,并计算其在当前网络参数下和锚点网络参数下的目标Q值和梯度值,将所述梯度值和全梯度均值带入方差缩减公式实现梯度更新;
S6重复步骤S4-S5直至训练结束,获得最终的Q学习模型,将待测状态输入所述最终的Q学习模型获得最佳推送结果;
所述步骤S5中的方差缩减公式:
Figure FDA0003831718560000011
其中,
Figure FDA0003831718560000012
是下一个网络参数;
Figure FDA0003831718560000013
是当前网络参数;α是学习率;
Figure FDA0003831718560000014
是梯度值;g是全梯度均值。
2.如权利要求1所述的基于Q学习模型的结果推送方法,其特征在于,所述梯度值的计算公式为:
当前网络参数下的梯度值:
Figure FDA0003831718560000015
锚点网络参数下的梯度值:
Figure FDA0003831718560000016
其中,s,a分别为步骤S5中随机提取的一数据组中的状态和所述状态对应的推送结果,qm是当前网络参数下的目标Q值,q0是锚点网络参数下的目标Q值,
Figure FDA0003831718560000017
是锚点网络参数,Q()为Q网络。
3.如权利要求2所述的基于Q学习模型的结果推送方法,其特征在于,所述目标Q值的计算公式为:
当前网络参数下的目标Q值:
Figure FDA0003831718560000018
锚点网络参数下的目标Q值:
Figure FDA0003831718560000019
其中,s′,a′分别为步骤S5中随机提取的一数据组中的下一个状态和所述下一个状态对应的推送结果,r是奖励值,γ是折扣系数。
4.如权利要求3所述的基于Q学习模型的结果推送方法,其特征在于,所述全梯度均值的计算公式为:
Figure FDA0003831718560000021
其中,N为数据组的数量,l()为损失函数。
5.一种基于Q学习模型的结果推送方法,其特征在于,包括以下步骤:
S1确定当前状态st,将当前状态st带入初始Q学习模型获得Q值,根据所述Q值获得原始推送结果at
S2将所述原始推送结果推送给用户,并通过记录用户浏览,获得奖励值rt+1
S3将状态st、推送结果at,下一状态st+1和奖励值rt+1组成一个数据组,并将其存储至经验池D中;
S4从所述经验池D中提取若干数据组,并根据提取的数据组计算网络参数
Figure FDA0003831718560000022
下的全梯度均值,对所述全梯度均值进行梯度优化:
Figure FDA0003831718560000023
其中,
Figure FDA0003831718560000024
是下一个网络参数;
Figure FDA0003831718560000025
是当前网络参数;
Figure FDA0003831718560000026
是当前网络参数下的全梯度均值;
S5随机提取一步骤S4中的数据组,并计算其在当前网络参数下和上一个网络参数下的目标Q值和梯度值,将所述梯度值和全梯度均值带入方差缩减公式实现梯度更新;
S6重复步骤S4-S5直至训练结束,获得最终的Q学习模型,将待测状态输入所述最终的Q学习模型获得最佳推送结果;
所述步骤S5中的方差缩减公式:
Figure FDA0003831718560000027
其中,
Figure FDA0003831718560000028
是下一个网络参数;
Figure FDA0003831718560000029
是当前网络参数;α是学习率;
Figure FDA00038317185600000210
是梯度值;g是全梯度均值。
6.如权利要求5所述的基于Q学习模型的结果推送方法,其特征在于,所述步骤S5中的方差缩减公式:
Figure FDA00038317185600000211
其中,l()为损失函数,
Figure FDA00038317185600000212
是上一个网络参数;
Figure FDA00038317185600000213
是当前网络参数;
Figure FDA00038317185600000214
是上一个网络参数下的全梯度均值;
Figure FDA00038317185600000215
是当前网络参数下的全梯度均值。
7.如权利要求6所述的基于Q学习模型的结果推送方法,其特征在于,所述梯度值的计算公式为:
当前网络参数下的梯度值:
Figure FDA00038317185600000216
上一个网络参数下的梯度值:
Figure FDA00038317185600000217
其中,s,a分别为步骤S5中随机提取的一数据组中的状态和所述状态对应的推送结果,qm是当前网络参数下的目标Q值,q0是锚点网络参数下的目标Q值,
Figure FDA0003831718560000031
是锚点网络参数,Q()为Q网络。
8.如权利要求7所述的基于Q学习模型的结果推送方法,其特征在于,所述目标Q值的计算公式为:
当前网络参数下的目标Q值:
Figure FDA0003831718560000032
上一个网络参数下的目标Q值:
Figure FDA0003831718560000033
其中,s′,a′分别为步骤S5中随机提取的一数据组中的下一个状态和所述下一个状态对应的推送结果,r是奖励值,γ是折扣系数。
9.一种基于Q学习模型的结果推送系统,其特征在于,包括:
原始推送结果生成模块,用于确定当前状态st,将当前状态st带入初始Q学习模型获得Q值,根据所述Q值获得原始推送结果at
奖励值生成模块,用于将所述原始推送结果推送给用户,并通过记录用户浏览,获得奖励值rt+1
存储模块,用于将状态st、推送结果at,下一状态st+1和奖励值rt+1组成一个数据组,并将其存储至经验池D中;
全梯度均值计算模块,用于从所述经验池D中提取若干数据组,并根据提取的数据组计算网络参数
Figure FDA0003831718560000038
下的全梯度均值,此时的网络参数为锚点网络参数;
梯度更新模块,用于随机提取一步骤S4中的数据组,并计算其在当前网络参数下和锚点网络参数下的目标Q值和梯度值,将所述梯度值和全梯度均值带入方差缩减公式实现梯度更新;
输出模块,用于重复步骤S4-S5直至训练结束,获得最终的Q学习模型,将待测状态输入所述最终的Q学习模型获得最佳推送结果;
所述步骤S5中的方差缩减公式:
Figure FDA0003831718560000034
其中,
Figure FDA0003831718560000035
是下一个网络参数;
Figure FDA0003831718560000036
是当前网络参数;α是学习率;
Figure FDA0003831718560000037
是梯度值;g是全梯度均值。
CN202010896316.9A 2020-08-31 2020-08-31 一种基于q学习模型的结果推送方法和系统 Active CN112085524B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202010896316.9A CN112085524B (zh) 2020-08-31 2020-08-31 一种基于q学习模型的结果推送方法和系统

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202010896316.9A CN112085524B (zh) 2020-08-31 2020-08-31 一种基于q学习模型的结果推送方法和系统

Publications (2)

Publication Number Publication Date
CN112085524A CN112085524A (zh) 2020-12-15
CN112085524B true CN112085524B (zh) 2022-11-15

Family

ID=73731256

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202010896316.9A Active CN112085524B (zh) 2020-08-31 2020-08-31 一种基于q学习模型的结果推送方法和系统

Country Status (1)

Country Link
CN (1) CN112085524B (zh)

Citations (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN109471963A (zh) * 2018-09-13 2019-03-15 广州丰石科技有限公司 一种基于深度强化学习的推荐算法
CN110084378A (zh) * 2019-05-07 2019-08-02 南京大学 一种基于本地学习策略的分布式机器学习方法
KR20190132193A (ko) * 2018-05-18 2019-11-27 한양대학교 에리카산학협력단 스마트 그리드에서 동적 가격 책정 수요반응 방법 및 시스템

Patent Citations (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
KR20190132193A (ko) * 2018-05-18 2019-11-27 한양대학교 에리카산학협력단 스마트 그리드에서 동적 가격 책정 수요반응 방법 및 시스템
CN109471963A (zh) * 2018-09-13 2019-03-15 广州丰石科技有限公司 一种基于深度强化学习的推荐算法
CN110084378A (zh) * 2019-05-07 2019-08-02 南京大学 一种基于本地学习策略的分布式机器学习方法

Also Published As

Publication number Publication date
CN112085524A (zh) 2020-12-15

Similar Documents

Publication Publication Date Title
CN110674604B (zh) 基于多维时序帧卷积lstm的变压器dga数据预测方法
CN108875916B (zh) 一种基于gru神经网络的广告点击率预测方法
CN111260030B (zh) 基于a-tcn电力负荷预测方法、装置、计算机设备及存储介质
WO2021109644A1 (zh) 一种基于元学习的混合动力车辆工况预测方法
CN110942194A (zh) 一种基于tcn的风电预测误差区间评估方法
CN112381673B (zh) 一种基于数字孪生的园区用电信息分析方法及装置
CN113449919B (zh) 一种基于特征和趋势感知的用电量预测方法及系统
CN112015719A (zh) 基于正则化和自适应遗传算法的水文预测模型的构建方法
CN115271219A (zh) 一种基于因果关系分析的短期负荷预测方法及预测系统
CN114548591A (zh) 一种基于混合深度学习模型和Stacking的时序数据预测方法及系统
CN114742209A (zh) 一种短时交通流预测方法及系统
CN113807596B (zh) 一种信息化工程造价的管理方法及系统
CN114971090A (zh) 一种电供暖负荷预测方法、系统、设备和介质
CN112085524B (zh) 一种基于q学习模型的结果推送方法和系统
CN103607219B (zh) 一种电力线通信系统的噪声预测方法
CN112951209A (zh) 一种语音识别方法、装置、设备及计算机可读存储介质
CN109740221B (zh) 一种基于搜索树的智能工业设计算法
CN115829123A (zh) 基于灰色模型与神经网络的天然气需求预测方法及装置
CN116151581A (zh) 一种柔性车间调度方法、系统及电子设备
CN113705878B (zh) 水平井出水量的确定方法、装置、计算机设备及存储介质
CN115035304A (zh) 一种基于课程学习的图像描述生成方法及系统
CN112348275A (zh) 一种基于在线增量学习的区域生态环境变化预测方法
CN111859807A (zh) 汽轮机初压寻优方法、装置、设备及存储介质
CN111369046A (zh) 一种基于灰色神经网络的风光互补功率预测方法
CN110580548A (zh) 一种基于类集成学习的多步交通速度预测方法

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant