CN112149359A - 信任域引导裁剪的策略优化方法、系统、存储介质及应用 - Google Patents

信任域引导裁剪的策略优化方法、系统、存储介质及应用 Download PDF

Info

Publication number
CN112149359A
CN112149359A CN202011074176.3A CN202011074176A CN112149359A CN 112149359 A CN112149359 A CN 112149359A CN 202011074176 A CN202011074176 A CN 202011074176A CN 112149359 A CN112149359 A CN 112149359A
Authority
CN
China
Prior art keywords
trust domain
matrix
graph
batch
attention
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202011074176.3A
Other languages
English (en)
Inventor
刘世旋
吴克宇
冯旸赫
刘忠
黄金才
程光权
陈超
成清
胡星辰
杜航
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
National University of Defense Technology
Original Assignee
National University of Defense Technology
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by National University of Defense Technology filed Critical National University of Defense Technology
Priority to CN202011074176.3A priority Critical patent/CN112149359A/zh
Publication of CN112149359A publication Critical patent/CN112149359A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F30/00Computer-aided design [CAD]
    • G06F30/20Design optimisation, verification or simulation
    • G06F30/27Design optimisation, verification or simulation using machine learning, e.g. artificial intelligence, neural networks, support vector machines [SVM] or training a model
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F2111/00Details relating to CAD techniques
    • G06F2111/04Constraint-based CAD
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F2111/00Details relating to CAD techniques
    • G06F2111/08Probabilistic or stochastic CAD

Landscapes

  • Engineering & Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Evolutionary Computation (AREA)
  • Theoretical Computer Science (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Medical Informatics (AREA)
  • Software Systems (AREA)
  • Artificial Intelligence (AREA)
  • Computer Hardware Design (AREA)
  • Geometry (AREA)
  • General Engineering & Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Management, Administration, Business Operations System, And Electronic Commerce (AREA)

Abstract

本发明属于数据信任域策略优化技术领域,公开了一种信任域引导裁剪的策略优化方法、系统、存储介质及应用,所述信任域引导裁剪的策略优化方法包括:缓冲集合到经验回放器;从经验回放器中采样;计算批概率比矩阵、计算批KL距离矩阵;运用批KL距离矩阵根据裁剪概率比矩阵。本发明提出信任域引导裁剪下的策略优化方法来搜索有向图空间,而不会遭受使用REINFORCE或PPO的局部收敛问题。这种方法也适用于解决其他高维动作空间控制问题。本发明结合了图神经网络和缩放点积图注意力机制的最新进展,形成了可缩放的点积图图注意力网络。该网络可以从原始数据中更好地抽取特征关系,提高DAG的精确生成。

Description

信任域引导裁剪的策略优化方法、系统、存储介质及应用
技术领域
本发明属于数据信任域策略优化技术领域,尤其涉及一种信任域引导裁剪的策略优化方法、系统、存储介质及应用。
背景技术
目前,数据生成过程的因果结构对于许多研究问题至关重要。确定自然现象背后的因果机制在许多科学领域仍摆在首要位置,例如在了解病毒作用机制后,可以更有效地开发新药或者控制病毒传播。最近研究还表明,因果推理对于经典的机器学习任务(例如在具有因果关系的函数估计情况下,了解因果及其影响对半监督学习和迁移学习具有决定性意义。发现因果关系的最有效方法是通过受控随机实验,但是受控随机实验在某些科学领域几乎不可能进行模拟(比如自然灾害影响或者政治影响研究)。因此因果发现方法的最新研究集中于从被动可观察的数据推断因果关系。基于分数的方法将结构学习的问题表述为:在组合性非环约束的基础上,优化根据邻接矩阵和观测数据计算得出的得分函数。然而这个问题的搜索空间随着图节点数量增加是呈指数增加的,这对基于分数的优化方法提出了严峻的挑战。对此,提出了一种强化学习的方法,以贝叶斯信息准则(BIC)作为得分函数,并使用编解码器神经网络体系结构来搜索得分最高的DAG。其思想是:在使用观察数据生成图形之后,特定的强化学习智能体可以通过奖励信号流优化其策略,该奖励信号流包括生成图的得分函数和无环性约束。但是他们并未从强化学习方法上进行有效的探究,其采用的REINFORCE算法易于局部收敛且数据利于效率较低。信任域策略优化(TRPO)可以有效解决这些问题并提供可靠稳定的性能,但是其计算复杂度异常高。尽管近端策略优化(PPO)实施起来更简单并且具有TRPO的一些优点,但是本发明发现,这种由简易的裁剪导致的与信任区域约束的微小偏差,将随着动作数量增多而呈指数地放大,这导致其用于因果推理的过程中出现异常的探索行为,因为智能体无法优化某些子动作或使其中的一些动作陷入局部最优状态。
尽管在转纳和归纳测试中,图注意力网络(GAT)的性能都优于其他最新的图网络。但由于其简单的加性机制,它无法在本发明因果推断的任务中提供强大的特征抽取能力。传统的因果发现方法包括基于得分的方法、基于约束的方法和混合方法。广泛应用的基于得分的方法依赖于预定义的分数函数来对因果问题建模,应用启发式方法在DAG空间中搜索得分最高的DAG。但是由于该问题仍然是NP-难的,因此对于中大规模的实际问题,通常需要在采用近似搜索时附加结构假设。一些混合方法可以在某些假定约束的情况下,减小基于得分方法的搜索空间,但是这些方法缺乏关于得分函数和启发式策略选择的范式。Zheng等人通过运用邻接矩阵的连续函数规范了一种等效的无环约束。通过选择适当的损失函数,将问题的组合性质变成带有加权邻接矩阵的连续优化问题。尽管这种优化问题由于其非线性性,仅具有定点解而不是全局最优,这种局部解在应用中,可与耗时更久的、通过组合搜索获得的全局最优解高度等价。但是许多有效的得分函数,例如Huang等人的广义得分函数、Peters等人的“基于独立性的得分函数”可能无法和这种方法进整合。因为这些函数要么太复杂,要么无法闭式表示。
神经网络(NN)的最新发展推动了基于NN的因果推断方法的兴起。Goudet等人提出因果生成神经网络,通过最小化最大平均差异来学习一个观测变量联合分布的生成模型。Kalainathan等人提出了一种对抗生成式的结构不可知模型,该方法运用对抗生成,从连续观察数据中恢复完整的因果模型。Yu等人提出了DAG-GNN,一种新型图神经网络(GNN)架构参数化下的可变自动编码器来生成DAG。为了扩展神经网络使其可以处理任意图结构,Gori等人和Scarselli等人提出了图神经网络。图神经网络是指可以直接处理一般图类的递归神经网络的总称。此外,为了将卷积推广到图域,研究人员提出了频谱方法(例如ChebNet,GCN,一阶模型)和非频谱方法(例如MoNet,GraphSAGE)。此外,由于注意力机制已成为许多序列-到-序列问题的基础,
Figure BDA0002716112360000031
等人将自我注意力引入到了图传播过程,从而在转纳和归纳任务上均达到顶尖水平。
近年来,强化学习(RL)也刺激了许多应用的发展。除了在游戏领域中最著名的应用(例如AlphaGo和StarCraftII)外,RL还在机器人技术、目标定位、自然语言处理,决策服务和运输调度上展现出了良好的适用性、鲁棒性和通用性。RL也被Zoph等和Krishna等运用到神经体系结构的搜索,得到了人类水平的结果。
推断多个变量内的因果结构是许多经验科学领域中极其重要的步骤,为了解决传统的基于分数或基于约束的方法所面临的无向边或违反假设的问题,提出了用于因果推理的强化学习范例,并根据精心设计的奖励使用REINFORCE算法作为搜索策略来搜索有向非环图(DAG)。但是这种搜索方法易于局部收敛,并且在处理真实数据集时性能不稳定。同时,对于DAG生成之类的高维动作空间控制问题,信任域策略优化所需计算量非常大,而近端策略优化则容易出现异常行为。近端策略优化这种在单个动作上产生的微小偏差,会随着动作数量的增多而呈指数增长。
通过上述分析,现有技术存在的问题及缺陷为:目前使用REINFORCE算法作为搜索策略搜索有向非环图在处理真实数据集时,生成的图结果较差且训练性能不稳定。另一方面图注意力网络在没有相邻信息时对数据抽取的能力较差。
解决以上问题及缺陷的难度为:对于DAG生成之类的高维动作空间控制问题,信任域策略优化所需计算量非常大,而近端策略优化则容易出现异常行为;近端策略优化这种在单个动作上产生的微小偏差,会随着动作数量的增多而呈指数增长。
解决以上问题及缺陷的意义为:通过信任域引导裁剪,可以进一步保证因果推断过程中生成有向无环图的准确性。且提出的基于缩放点积图注意力网络可以在未给予相邻信息时对数据保持较好的抽取能力。
发明内容
针对现有技术存在的问题,本发明提供了一种信任域引导裁剪的策略优化方法、系统、存储介质及应用。
本发明是这样实现的,一种信任域引导裁剪的策略优化方法,所述信任域引导裁剪的策略优化方法包括:
缓冲
Figure BDA0002716112360000041
到R中;
从R中采样
Figure BDA0002716112360000042
Figure BDA0002716112360000043
计算批概率比矩阵
Figure BDA0002716112360000044
计算批KL距离矩阵{DKL(b,πθ|At,St)}N
运用批KL距离矩阵根据
Figure BDA0002716112360000045
裁剪批概率比矩阵:
Figure BDA0002716112360000046
Figure BDA0002716112360000047
以αθ,αω的学习率,优化θ和ω以最小化Lθ,Lω...;
其中,经验回放器R;批大小N;移动平均更新率αm;Actor和Critic学习率αθ,αω;熵值权重λe;裁剪比例∈;信任域区间δ;t=1,2,......
Figure BDA0002716112360000048
进一步,所述信任域引导裁剪的策略优化方法的更新策略的方向为:
Figure BDA0002716112360000049
其中,b(·|St)表示旧策略,πθ(At|St)中的每个元素表示独立子动作的概率,其动作空间为{0,1},给定At和St下,新旧策略之间的KL距离
Figure BDA00027161123600000410
计算为:
Figure BDA0002716112360000051
使用PPO计算联合概率比
Figure BDA0002716112360000052
时,要计算n*(n-1)个比例的累乘,对于其中的每个比率,两个约束之间的小偏差将随着比例的增多而成倍增大;
将基于似然比的触发条件替换为基于信任域的触发条件,策略不在信任域内时,似然比例将会被裁剪:
Figure BDA0002716112360000053
进一步,所述信任域引导裁剪的策略优化方法的因果模型定义fi为线性、Ni为高斯噪音的线性高斯模型;fi为线性、Ni为非高斯噪音的LiNGAM模型;fi为二次函数、Ni为非高斯噪音的非线性模型;fi为高斯过程采样的函数、Ni为正态分布噪音的非线性模型;根据因果模型,使用固定的随机DAG生成所有变量,并将数据集采样为X∈Rn×M
进一步,所述信任域引导裁剪的策略优化方法的状态、策略和动作包括:S∈Rn×m代表状态,其中n代表节点数量,m代表每个节点中采样得到的特征数量,在t时刻,
Figure BDA0002716112360000054
在训练过程中,通过从整个观察到的数据集X中随机抽取样本来构造St
策略由θ参数化,定义为
Figure BDA0002716112360000055
理解为邻接概率矩阵,其元素-子策略
Figure BDA0002716112360000056
代表存在从节点i到j的边的概率,在图生成的最后一步中,策略中的所有对角线元素都被设为零,即
Figure BDA0002716112360000057
A∈{0,1}n×n表示动作,为图
Figure BDA00027161123600000513
对应的二进制邻接矩阵,时间t时,
Figure BDA0002716112360000058
其子动作
Figure BDA0002716112360000059
表示存在在t时刻从节点i到节点j的边,每个子动作
Figure BDA00027161123600000510
是根据值为
Figure BDA00027161123600000511
的伯努利分布采样生成的;每个子动作是独立采样的,在状态St下运用策略采样得到At的概率为:
Figure BDA00027161123600000512
进一步,所述信任域引导裁剪的策略优化方法的奖励R∈R是包含得分函数和无环约束项的奖励信号,采用BIC作为得分,给定整个观察数据X,邻接矩阵At的BIC得分为:
Figure BDA0002716112360000061
其中,
Figure BDA0002716112360000062
表示由μ参数化的最大似然估计器,dμ代表μ的维数,如果使用线性模型对每个因果关系进行建模,则BIC得分计算为:
Figure BDA0002716112360000063
Figure BDA0002716112360000064
表示估算值,ne代表边数,公式中第二项用来惩罚冗余边数,第一项等效于GraN-DAG采用的对数似然目标,节点数据的附加噪声方差相等,得到另一种形式的BIC得分:
Figure BDA0002716112360000065
将有环性惩罚定义为:
Figure BDA0002716112360000066
如果At是有向图,则CL(At)=0;增加了值相对较大的指标惩罚:
Figure BDA0002716112360000067
通过将权重λ1和λ2赋予两项有环性惩罚,将三项惩罚值结合,总的奖励信号定义为:
R=-[BIC(At)+λ1CL(At)+λ2Ind(At)]。
进一步,所述信任域引导裁剪的策略优化方法的强化学习过程,根据对因果推断强化学习的建模,将优化目标定义为:
Figure BDA0002716112360000068
通过最大化有向图空间上的奖赏值,在以下条件下获得得分最高的DAG:
Figure BDA0002716112360000069
其中,上界BICu可以通过随机的DAG计算得到,λ1发置了一个相对较小的值。
进一步,所述信任域引导裁剪的策略优化方法的图生成神经网络架构编码器和解码器一起形成由θ参数化的actorθ模块,其输入是观测数据St,输出是图邻接矩阵At,编码器应理解变量之间的内在关系,并输出最能描述因果关系的编码enci,而解码器应使用该编码来解释变量之间的相互关;
(1)缩放点积图注意力编码器,首先运用图注意力网络GAT作为编码器,将GAT中的加性注意力替换为缩放点积注意力,形成SDGAT,SDGAT由ns个独立注意层堆叠,对于具有m特征的一组n-节点变量,
Figure BDA0002716112360000071
单个注意层将输出一组基数为m′的新n-节点变量;
SDGAT中采用了两级多头注意力的层次化结构,第一级采用相同于GAT的设计,在第一级层次结构中,h被分成ngat段,每段由
Figure BDA0002716112360000072
并由p索引,每个h(p)并行地插入具有nsd个头的第二级层次结构,每个头都遵循缩放点积注意力的基本结构,在第(p,q)子层中,查询Qpq,键Kpq和值Vpq由h(p)通过三套独立的线性投影层得到:
Figure BDA0002716112360000073
其中
Figure BDA0002716112360000074
而dk是Qpq和Kpq的隐藏层维度,通过相应的查询和键,第(p,q)子层中的注意力矩阵αpq∈Rn×n被并行计算为:
Figure BDA0002716112360000075
其中,
Figure BDA0002716112360000076
表示节点j的特征对节点i的重要性,
Figure BDA0002716112360000077
是对应的缩放比例因子;通过二元邻接矩阵掩盖相应
Figure BDA0002716112360000078
将结构信息注入到注意力机制中;
将nsd个独立的特征输出连接后进行线性投影,得到第一级注意力的输出,再对ngat个这样的输出再次进行串联,得到以下单层的输出表示形式:
Figure BDA0002716112360000079
其中,
Figure BDA00027161123600000710
对于网络的最后一层,对ngat个头的连接操作的意义不大,采用取平均操作:
Figure BDA00027161123600000711
所有隐藏层维度和单层输出维度保持不变。输入的特征在ns个注意力层之后,得出编码enc∈Rn×m′,以进行图的生成和Critic对状态值函数的估计;
(2)线性图生成解码器,给定编码输出enc,线性图生成解码器按逐元素操作的方式生成图邻接矩阵,通过遍历enc中的所有enci-encj对,子策略
Figure BDA0002716112360000081
计算为:
Figure BDA0002716112360000082
其中,
Figure BDA0002716112360000083
是可训练的参数,dh是解码器的隐藏层维度,所有对角线元素都被掩码为0,每个子动作
Figure BDA0002716112360000084
根据概率
Figure BDA0002716112360000085
按伯努利分布进行采样,生成整个邻接矩阵At
所述信任域引导裁剪的策略优化方法的应用于因果推断的强化学习算法包括:
(1)带有移动平均基线的REINFORCE算法,批大小N;移动平均更新率αm;Actor和Critic学习率αθ,αω;熵值权重λe;t=1,2,......生成批在线经验{St,At,Rt}N
Figure BDA0002716112360000086
Figure BDA0002716112360000087
Figure BDA0002716112360000088
Figure BDA0002716112360000089
以αθ,αω的学习率,优化θ和ω以最小化Lθ,Lω...
{·}N表示一批元素,
Figure BDA00027161123600000810
表示对批元素取平均运算,根据REINFORCE算法,朝以下方向优化:
Figure BDA00027161123600000811
其中,
Figure BDA00027161123600000812
表示用Rt+Rm-Vω(St)估计的优势函数,其中Rm表示移动平均值,而Vω(St)是ω参数化下的Critic由编码{enc}估计的值函数;由一批Rt以αm速率更新的Rm,通过减少参数基线的方差来稳定训练过程;Critic是一个带有ReLU单元的双层前馈网络,经过训练将其对状态值函数的预测值与真实奖励加上移动平均值的和之间的均方误差Lw降至最低;
熵正则化项也被添加到代理损失Lθ中:
Figure BDA0002716112360000091
通过最小化两个代理损失Lθ和Lω,使用Adam优化器分别以学习率αθ和αω训练Actor和Critic模块;
(2)优先级采样辅助的REINFORCE算法,将优先采样引入REINFORCE算法中,基于排名的优先级定义pi
Figure BDA0002716112360000092
其中经验i在经验回放区中按
Figure BDA0002716112360000093
排序;每条经验的采样概率i定义为:
Figure BDA0002716112360000094
经验回放器R;批大小N;移动平均更新率αm;Actor和Critic学习率αθ,αω;熵值权重λe
Figure BDA0002716112360000095
缓存
Figure BDA0002716112360000096
到R中
从R中根据采样
Figure BDA0002716112360000097
Figure BDA0002716112360000098
Figure BDA0002716112360000099
Figure BDA00027161123600000910
Figure BDA00027161123600000911
以αθ,αω的学习率,优化θ和ω以最小化Lθ,Lω...。
本发明的另一目的在于提供一种计算机可读存储介质,存储有计算机程序,所述计算机程序被处理器执行时,使得所述处理器执行如下步骤:
缓冲集合到经验回放器;
从经验回放器中采样;
计算批概率比矩阵、计算批KL距离矩阵;
运用批KL距离矩阵根据裁剪概率比矩阵。
本发明的另一目的在于提供一种所述信任域引导裁剪的策略优化方法在机器学习中的应用。
本发明的另一目的在于提供一种运行所述信任域引导裁剪的策略优化方法的信任域引导裁剪的策略优化系统,所述信任域引导裁剪的策略优化系统包括:
缓冲集合处理模块,用于缓冲集合到经验回放器;
采样模块,用于从经验回放器中采样;
矩阵计算模块,用于计算批概率比矩阵、计算批KL距离矩阵;
概率比矩阵裁剪模块,用于运用批KL距离矩阵根据裁剪概率比矩阵。
结合上述的所有技术方案,本发明所具备的优点及积极效果为:本发明为了解决目前使用REINFORCE算法作为搜索策略搜索有向非环图在处理真实数据集时性能不稳定;同时,对于DAG生成之类的高维动作空间控制问题,信任域策略优化所需计算量非常大,而近端策略优化则容易出现异常行为;近端策略优化这种在单个动作上产生的微小偏差,会随着动作数量的增多而呈指数增长的问题,提出了一种用于图注意力下因果推断的信任域引导裁剪的策略优化方法(TRCPO),与REINFORCE算法和优先采样引导的REINFORCE相比,该方法可确保更好的搜索效率和策略优化稳定性。还设计了一种称为缩放点积注意力图注意力网络(SDGAT),通过用缩放点积注意力替代原始图注意力网络中的加性注意力机制,可以在没有邻域信息的情况下掌握更多特征信息。
本发明提出信任域引导裁剪下的策略优化方法来搜索有向图空间,而不会遭受使用REINFORCE或PPO的局部收敛问题。这种方法也适用于解决其他高维动作空间控制问题。本发明结合了图神经网络和缩放点积图注意力机制的最新进展,形成了可缩放的点积图图注意力网络。该网络可以从原始数据中更好地抽取特征关系,从而提高DAG的精确生成。
附图说明
为了更清楚地说明本申请实施例的技术方案,下面将对本申请实施例中所需要使用的附图做简单的介绍,显而易见地,下面所描述的附图仅仅是本申请的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下还可以根据这些附图获得其他的附图。
图1是本发明实施例提供的信任域引导裁剪的策略优化方法流程图。
图2是本发明实施例提供的信任域引导裁剪的策略优化系统的结构示意图;
图2中:1、缓冲集合处理模块;2、采样模块;3、矩阵计算模块;4、概率比矩阵裁剪模块。
图3是本发明实施例提供的因果推断强化学习范式示意图。
图4是本发明实施例提供的REINFORCE训练过程示例示意图。
图5是本发明实施例提供的缩放点积图注意力网络示意图。
图6是本发明实施例提供的对两个非线性合成模型进行实验的结构示意图。
图7是本发明实施例提供的本发明的方法生成的有向无环图是所有因果发现方法得到的最好结果示意图。
具体实施方式
为了使本发明的目的、技术方案及优点更加清楚明白,以下结合实施例,对本发明进行进一步详细说明。应当理解,此处所描述的具体实施例仅仅用以解释本发明,并不用于限定本发明。
针对现有技术存在的问题,本发明提供了一种信任域引导裁剪的策略优化方法、系统、存储介质及应用,下面结合附图对本发明作详细的描述。
如图1所示,本发明提供的信任域引导裁剪的策略优化方法包括以下步骤:
S101:缓冲集合到经验回放器;
S102:从经验回放器中采样;
S103:计算批概率比矩阵、计算批KL距离矩阵;
S104:运用批KL距离矩阵根据裁剪概率比矩阵。
本发明提供的信任域引导裁剪的策略优化方法业内的普通技术人员还可以采用其他的步骤实施,图1的本发明提供的信任域引导裁剪的策略优化方法仅仅是一个具体实施例而已。
如图2所示,本发明提供的信任域引导裁剪的策略优化系统包括:
缓冲集合处理模块1,用于缓冲集合到经验回放器;
采样模块2,用于从经验回放器中采样;
矩阵计算模块3,用于计算批概率比矩阵、计算批KL距离矩阵;
概率比矩阵裁剪模块4,用于运用批KL距离矩阵根据裁剪概率比矩阵。
下面结合附图对本发明的技术方案作进一步的描述。
1、因果模型定义
考虑一组观测到的有限随机变量
Figure BDA0002716112360000121
其由序号集合V=1,...,n标注.每一个标量
Figure BDA0002716112360000122
都和图
Figure BDA0002716112360000124
中的一个节点i相关,图由n个节点组成的节点集合V和节点之间形成的边集合
Figure BDA0002716112360000123
组成。对于任意边v∈V,(v,v)∈ε。当(i,j)∈ε或者(j,i)∈ε一个节点i可以被认为是另一节点j的父节点或子节点。本发明将数据生成过程视为基于有向无环图的加性噪声模型,其中观测值xi是由图
Figure BDA0002716112360000125
中父节点集PAi的值,通过函数计算得到并带有加性噪声Ni
xi=fi(PAi)+Ni,i=1,...,n (1)
假设所有噪声变量Ni都具有严格的正密度,并且是共同独立的。因此为满足因果极小度,每个函数fi中的任何系数都不恒定。在常规的马尔可夫和忠实性假设下,以上模型只能达到马尔可夫等价类。为了评估本发明方法的适用性,本发明使用以下模型对合成数据集建模以进行实验:fi为线性、Ni为高斯噪音的线性高斯模型;fi为线性、Ni为非高斯噪音的LiNGAM模型;fi为二次函数、Ni为非高斯噪音的非线性模型;fi为高斯过程采样的函数、Ni为正态分布噪音的非线性模型。
根据上述模型,使用固定的随机DAG生成所有变量,并将数据集采样为X∈Rn×M
2、因果推断强化学习
图3给出了因果推理的RL模型,其采用表演者-评论家(Actor-Critic)模型。因果发现的目的是从观测数据X在时间t时刻下采样得到的状态St,推断数据间隐藏的因果DAG。基于编码器-解码器模块的Actor根据St生成图
Figure BDA00027161123600001310
的邻接矩阵。环境的奖励模块在得到输出后为Critic计算奖励Rt,以估算值函数V(St),将其与Rt一起组成优化演员的优势信号。该问题被建模为一个决策过程,由元组<S,A,R>表示。对于每个轨迹仅具有一个决策步骤。为了进一步解释学习过程,图4中采用REINFORCE的训练过程作为示例。
2.1状态、策略和动作
S∈Rn×m代表状态,其中n代表节点数量,m代表每个节点中采样得到的特征数量。在t时刻,
Figure BDA0002716112360000131
在训练过程中,通过从整个观察到的数据集x中随机抽取样本来构造St。值得注意的是,由于数据集x不会随时间变化,因此采样得到St的概率P(St)在所有时间t都相同,符合隐藏马尔可夫性。
策略由θ参数化,定义为
Figure BDA0002716112360000132
可理解为邻接概率矩阵。其元素-子策略
Figure BDA0002716112360000133
代表存在从节点i到j的边的概率。为了避免自环,在图生成的最后一步中,策略中的所有对角线元素都被设为零,即
Figure BDA0002716112360000134
然而,即使不考虑对角线元素,动作空间2n×(n-1)仍然是相当高维的,可能会带来局部收敛和梯度溢出的问题。
A∈{0,1}n×n表示动作,可以理解为图
Figure BDA00027161123600001311
对应的二进制邻接矩阵。时间t时,
Figure BDA0002716112360000135
其子动作
Figure BDA0002716112360000136
表示存在在t时刻从节点i到节点j的边。每个子动作
Figure BDA0002716112360000137
是根据值为
Figure BDA0002716112360000138
的伯努利分布采样生成的。由于每个子动作是独立采样的,因此在状态St下运用策略采样得到At的概率为:
Figure BDA0002716112360000139
2.2奖励信号
R∈R是包含得分函数和无环约束项的奖励信号。传统的基于得分的方法采用因果关系的由μ参数化的参数模型。本发明遵循Zhu等人的工作,采用BIC(因为其一致性和可分解性)作为得分。给定整个观察数据X,邻接矩阵At的BIC得分为:
Figure BDA0002716112360000141
其中,
Figure BDA0002716112360000142
表示由μ参数化的最大似然估计器,dμ代表μ的维数。如果使用线性模型对每个因果关系进行建模,则BIC得分可以计算为:
Figure BDA0002716112360000143
Figure BDA0002716112360000144
表示估算值,ne代表边数,因此公式中第二项用来惩罚冗余边数。此外,第一项等效于GraN-DAG采用的对数似然目标。假设节点数据的附加噪声方差相等,本发明可以得到另一种形式的BIC得分:
Figure BDA0002716112360000145
除了线性回归用于线性数据回归,二次回归和高斯过程回归也被用于非线性数据回归,以建模X的因果关系。考虑Zheng等人对无环性性约束的定义,本发明将有环性惩罚定义为:
Figure BDA0002716112360000146
如果At是有向图,则CL(At)=0。但是对于某些有环图,CL(At)的值非常小,因此非DAG的最小值很难计算。因此,本发明增加了值相对较大的指标惩罚,以利于DAG的生成:
Figure BDA0002716112360000147
通过将权重λ1和λ2赋予两项有环性惩罚,将三项惩罚值结合,总的奖励信号定义为:
R=-[BIC(At)+λ1CL(At)+λ2Ind(At)] (8)
2.3强化学习流程
根据对因果推断强化学习的建模,本发明将优化目标定义为:
Figure BDA0002716112360000148
通过最大化有向图空间上的奖赏值,可以在以下条件下获得得分最高的DAG:
Figure BDA0002716112360000149
其中,上界BICu可以通过随机的DAG计算得到。考虑Peters等人对基于独立性的分数的分析,下界BICl可以简单地被设置为零(考虑)。即使
Figure BDA0002716112360000151
难以计算,通过设置λ2=BICu-BICl,对于任意λ1≥0,条件(10)都能成立。本发明为λ1设置了一个相对较小的值,以助于推动DAG的准确生成。总体强化学习流程由算法2.3给出。
Figure BDA0002716112360000152
与NOTEARS中采用的拉格朗日方法相似,λ1和λ2从较小的值开始并逐渐递增。λ1逐次增加Δ1,λ2逐次乘以Δ2。其上界分别为Λ1和BICu,以确保约束(10)成立。同时,由于得分函数是无界的,而CL(At)和Ind(At)与得分函数的范围无关,因此本发明使用
Figure BDA0002716112360000153
将定义的得分函数限制在[0,BIC0]。在本发明实验中,BIC1是由通过完整的有向图计算得到,而BICu最初由空图计算得到。在DAG生成期间,按周期用记录的最低分数BICmin进行参数的调整。
3、图生成神经网络架构
为了从观察到的数据中推断出最能描述(1)中描述的数据生成过程的因果图,本发明需要设计一个合适的神经网络架构来进行图生成。编码器-解码器体系结构已成为许多序列-到-序列学习(例如机器翻译(Cho等人)和文本摘要(Nallapati等人))的首要选择,因此本发明中也被采用该架构。如图3所示,编码器和解码器一起形成由θ参数化的actorθ模块,其输入是观测数据St,输出是图邻接矩阵At。设计选择的编码器应理解变量之间的内在关系,并输出最能描述因果关系的编码enci,而解码器应使用该编码来解释变量之间的相互关系。
3.1缩放点积图注意力编码器
为了更好把握图间变量的因果关系,本发明首先运用图注意力网络(GAT)作为编码器。但是本发明在实验中,生成的图形和训练速度非常糟糕。由于本实验起初没有图结构的先验知识,该模型会忽略所有邻域状态,计算所有节点对之间的互信息。最常用的两种注意力机制是加性注意力和缩放点积注意力,本发明注意到,在没有邻接信息的情况下,GAT中采用的加性注意力不能有效地理解变量之间的相互关系。
本发明将GAT中的加性注意力替换为缩放点积注意力,从而形成SDGAT。与GAT相似,SDGAT也由ns个独立注意层堆叠。对于具有m特征的一组n-节点变量,
Figure BDA0002716112360000161
单个注意层将输出一组基数为m′的新n-节点变量。
如图5所示,SDGAT中采用了两级多头注意力的层次化结构,第一级采用相同于GAT的设计。在第一级层次结构中,h被分成ngat段,每段由
Figure BDA0002716112360000162
并由p索引。每个h(p)并行地插入具有nsd个头的第二级层次结构,每个头都遵循缩放点积注意力的基本结构。在第(p,q)子层中,查询Qpq,键Kpq和值Vpq由h(p)通过三套独立的线性投影层得到:
Figure BDA0002716112360000163
其中
Figure BDA0002716112360000164
而dk是Qpq和Kpq的隐藏层维度。通过相应的查询和键,第(p,q)子层中的注意力矩阵αpq∈Rn×n可以被并行计算为:
Figure BDA0002716112360000165
其中,
Figure BDA0002716112360000166
表示节点j的特征对节点i的重要性,
Figure BDA0002716112360000167
是对应的缩放比例因子。在本发明的实验中,该模型允许每个节点与每个其他节点建立连接,因为在生成图之前没有任何结构信息。如果用于其他给定图结构的任务,本发明可以通过二元邻接矩阵掩盖相应
Figure BDA0002716112360000168
将结构信息注入到注意力机制中。
将nsd个独立的特征输出连接后进行线性投影,得到第一级注意力的输出,再对ngat个这样的输出再次进行串联,从而得到以下单层的输出表示形式:
Figure BDA0002716112360000171
其中,
Figure BDA0002716112360000172
对于网络的最后一层,对ngat个头的连接操作的意义不大,因此采用取平均操作:
Figure BDA0002716112360000173
在本发明的实验中,所有隐藏层维度和单层输出维度保持不变。输入的特征在ns个注意力层之后,得出编码enc∈Rn×m′,以进行图的生成和Critic对状态值函数的估计。
3.2线性图生成解码器
给定编码输出enc,线性图生成解码器按逐元素操作的方式生成图邻接矩阵。通过遍历enc中的所有enci-encj对,在2.2中涉及的子策略
Figure BDA0002716112360000174
计算为:
Figure BDA0002716112360000175
其中,
Figure BDA0002716112360000176
是可训练的参数,dh是解码器的隐藏层维度。所有对角线元素都被掩码为0。每个子动作
Figure BDA0002716112360000177
根据概率
Figure BDA0002716112360000178
按伯努利分布进行采样,进而生成整个邻接矩阵At
4、应用于因果推断的强化学习算法
4.1带有移动平均基线的REINFORCE算法
Figure BDA0002716112360000179
如算法所示,{·}N表示一批元素,
Figure BDA00027161123600001710
表示对批元素取平均运算。根据REINFORCE算法,(9)可以朝以下方向优化:
Figure BDA0002716112360000181
其中,
Figure BDA0002716112360000182
表示用Rt+Rm-Vω(St)估计的优势函数,其中Rm表示移动平均值,而Vω(St)是ω参数化下的Critic由编码{enc}估计的值函数。由一批Rt以αm速率更新的Rm,通过减少参数基线的方差来稳定训练过程。本发明的Critic是一个带有ReLU单元的双层前馈网络,经过训练可将其对状态值函数的预测值与真实奖励加上移动平均值的和之间的均方误差Lω降至最低。
为了鼓励智能体的探索行为,熵正则化项也被添加到代理损失Lθ中:
Figure BDA0002716112360000183
最后本发明通过最小化两个代理损失Lθ和Lω,使用Adam优化器分别以学习率αθ和αω来训练Actor和Critic模块。
4.2优先级采样辅助的REINFORCE算法
本发明将优先采样引入REINFORCE算法中,以指导选择更有意义的经验批次。该方法较于原始REIN-FORCE算法展现出更强健的策略改进性能。在本发明中采用经验回放区是要记住罕见的但潜在的有用更新经验,并有助于确保更新独立同分布假设。
本发明重描具有较高优先级的经验,优先级由其优势函数
Figure BDA0002716112360000184
的大小来衡量。在Schaul等人的工作中,提出了两种优先级排序方式:均衡优先直接根据
Figure BDA0002716112360000185
的绝对值大小定义优先级pi;而基于排名的优先级定义pi
Figure BDA0002716112360000186
其中经验i在经验回放区中按
Figure BDA0002716112360000187
排序。由于基于排名的优先级不太容易受到
Figure BDA0002716112360000188
异常值的影响,且在应用中展现出更好的鲁棒性,因此本发明采用基于排名的优先级。
为了确保每条经验被采样的概率是单调的,同时对于最低优先级经验的采样概率要保证非零,因此每条经验的采样概率i定义为:
Figure BDA0002716112360000189
Figure BDA0002716112360000191
一般情况下,正如Q学习和ACER中,优先抽样在优化的过程中会带来偏差,因为它会以不受控制的方式更改估计的分布,从而可能收敛到不同解。在这些情况下,需要为每条经验计算对应的重要性采样率来调整优化方向与大小。但是在本发明的实验中,由于轨迹仅为单步,因此这种偏差在不存在序贯决策的情况下很小。而且在优化的过程中,本发明重新根据采样到状态和奖励计算优势函数时,从而消除偏差。此外本发明限制了经验回放区的大小,进一步限制了这种偏差的产生。
4.3信任域引导裁剪的策略优化
TRPO方法通过最大化受约束条件的代理目标来更新策略。按照本发明的框架,本发明更新策略的方向为:
Figure BDA0002716112360000192
其中,b(·|St)表示旧策略。如2.1中,πθ(At|St)中的每个元素表示独立子动作的概率,其动作空间为{0,1},因此给定At和St下,新旧策略之间的Kullback-Leibler(KL)距离
Figure BDA0002716112360000193
可计算为:
Figure BDA0002716112360000194
复杂的二阶形式使得优化过程中的计算量繁杂庞大,尤其是神经网络架构复杂时。作为一阶优化方法的PPO通过采用裁剪机制来避免处理前文所述的硬约束,从而大大降低了计算的复杂度。但是这种启发式的似然比例约束与信任区域约束之间存在明显差距。当本发明使用PPO计算联合概率比
Figure BDA0002716112360000201
时,要计算n*(n-1)个比例的累乘。对于其中的每个比率,两个约束之间的小偏差将随着比例的增多而成倍增大。另一方面,自适应KL惩罚版本的PPO,因为优化目标涉及KL散度的惩罚项,因此反向传播期间会消耗大量计算量。
鉴于这些问题,本发明提出更好的削减策略,在限制这两个约束间差距的同时,可以达到一阶优化的效率。受PPO将似然比用作裁剪触发条件的启发,本发明将基于似然比的触发条件替换为基于信任域的触发条件。一旦策略不在信任域内时,似然比例将会被裁剪:
Figure BDA0002716112360000202
Figure BDA0002716112360000203
下面结合实验对本发明的技术效果作详细的描述。
如图6所示,对两个非线性合成模型进行了实验,都以上面讨论的类似方式对上三角矩阵进行采样。对于二阶模型,给定第i个节点的上三角矩阵中隐含了父亲节点集合,将生成所有将一阶和二阶特征组合在一起的函数,系数为0或从[-1,-0.5]∪[0.5,1]的均匀分布中采样得到的实数。使用非高斯噪声设置,将生成一个包含5000个样本的10节点数据集,本发明使用二次回归对因果关系进行建模,通过对一阶和二阶项的所有系数进行阈值化来完成修剪过程。对于高斯过程模型,每个描述因果关系的函数都是从具有带宽为1的RBF内核的高斯过程采样的,附加噪声通常是正态分布的,方差是均匀采样的。因此,此处采用带有RBF内核的高斯过程回归模型对因果关系进行建模。本发明对观察到的数据进行归一化,并对内核带宽应用中间启发式算法,以避免由固定内核带宽引起的过拟合。本发明观察到,与REINFORCE相比,信任域辅助裁剪可以提高所有性能指标,并且在二阶模型中几乎可以生成真实的图形。
真实数据集方面,本发明考虑称为CYTO的一个真实的数据集,其由853个单细胞记录组成,其中包含11种磷蛋白和磷脂的丰度。作为因果关系问题的基线测试问题,它包含具有11个节点和17个边的图形的观察数据和干预数据。本发明使用观测数据,运用本发明提出的方法进行因果发现的探索。其数据归一化过程与前文所述一致,裁剪过程采用CAM裁剪。如图7所示,本发明提出的方法生成的有向无环图是所有因果发现方法得到的最好结果,其结构汉明距离仅为10。信任域辅助裁剪得到的结果中,发现了9个边,且8个边皆为正确边,且训练收敛性能优于其他所有方法。
应当注意,本发明的实施方式可以通过硬件、软件或者软件和硬件的结合来实现。硬件部分可以利用专用逻辑来实现;软件部分可以存储在存储器中,由适当的指令执行系统,例如微处理器或者专用设计硬件来执行。本领域的普通技术人员可以理解上述的设备和方法可以使用计算机可执行指令和/或包含在处理器控制代码中来实现,例如在诸如磁盘、CD或DVD-ROM的载体介质、诸如只读存储器(固件)的可编程的存储器或者诸如光学或电子信号载体的数据载体上提供了这样的代码。本发明的设备及其模块可以由诸如超大规模集成电路或门阵列、诸如逻辑芯片、晶体管等的半导体、或者诸如现场可编程门阵列、可编程逻辑设备等的可编程硬件设备的硬件电路实现,也可以用由各种类型的处理器执行的软件实现,也可以由上述硬件电路和软件的结合例如固件来实现。
以上所述,仅为本发明的具体实施方式,但本发明的保护范围并不局限于此,任何熟悉本技术领域的技术人员在本发明揭露的技术范围内,凡在本发明的精神和原则之内所作的任何修改、等同替换和改进等,都应涵盖在本发明的保护范围之内。

Claims (10)

1.一种信任域引导裁剪的策略优化方法,其特征在于,所述信任域引导裁剪的策略优化方法包括:
缓冲
Figure FDA0002716112350000011
到R中;
从R中采样
Figure FDA0002716112350000012
Figure FDA0002716112350000013
计算批概率比矩阵
Figure FDA0002716112350000014
计算批KL距离矩阵{DKL(b,πθ|At,St)}N
运用批KL距离矩阵根据
Figure FDA0002716112350000015
裁剪批概率比矩阵:
Figure FDA0002716112350000016
Figure FDA0002716112350000017
以αθ,αω的学习率,优化θ和ω以最小化Lθ,Lω...;
其中,经验回放器R;批大小N;移动平均更新率αm;Actor和Critic学习率αθ,αω;熵值权重λe;裁剪比例∈;信任域区间δ;
Figure FDA0002716112350000018
Figure FDA0002716112350000019
2.如权利要求1所述的信任域引导裁剪的策略优化方法,其特征在于,所述信任域引导裁剪的策略优化方法的更新策略的方向为:
Figure FDA00027161123500000110
其中,b(·|St)表示旧策略,πθ(At|St)中的每个元素表示独立子动作的概率,其动作空间为{0,1},给定At和St下,新旧策略之间的KL距离
Figure FDA00027161123500000111
计算为:
Figure FDA00027161123500000112
使用PPO计算联合概率比
Figure FDA00027161123500000113
时,要计算n*(n-1)个比例的累乘,对于其中的每个比率,两个约束之间的小偏差将随着比例的增多而成倍增大;
将基于似然比的触发条件替换为基于信任域的触发条件,策略不在信任域内时,似然比例将会被裁剪:
Figure FDA0002716112350000021
3.如权利要求1所述的信任域引导裁剪的策略优化方法,其特征在于,所述信任域引导裁剪的策略优化方法的因果模型定义fi为线性、Ni为高斯噪音的线性高斯模型;fi为线性、Ni为非高斯噪音的LiNGAM模型;fi为二次函数、Ni为非高斯噪音的非线性模型;fi为高斯过程采样的函数、Ni为正态分布噪音的非线性模型;根据因果模型,使用固定的随机DAG生成所有变量,并将数据集采样为
Figure FDA0002716112350000022
4.如权利要求1所述的信任域引导裁剪的策略优化方法,其特征在于,所述信任域引导裁剪的策略优化方法的状态、策略和动作包括:
Figure FDA0002716112350000023
代表状态,其中n代表节点数量,m代表每个节点中采样得到的特征数量,在t时刻,
Figure FDA0002716112350000024
在训练过程中,通过从整个观察到的数据集X中随机抽取样本来构造St
策略由θ参数化,定义为
Figure FDA0002716112350000025
理解为邻接概率矩阵,其元素-子策略
Figure FDA0002716112350000026
代表存在从节点i到j的边的概率,在图生成的最后一步中,策略中的所有对角线元素都被设为零,即
Figure FDA0002716112350000027
A∈{0,1}n×n表示动作,为图g对应的二进制邻接矩阵,时间t时,
Figure FDA0002716112350000028
其子动作
Figure FDA0002716112350000029
表示存在在t时刻从节点i到节点j的边,每个子动作
Figure FDA00027161123500000210
是根据值为
Figure FDA00027161123500000211
的伯努利分布采样生成的;每个子动作是独立采样的,在状态St下运用策略采样得到At的概率为:
Figure FDA00027161123500000212
5.如权利要求1所述的信任域引导裁剪的策略优化方法,其特征在于,所述信任域引导裁剪的策略优化方法的奖励
Figure FDA00027161123500000213
是包含得分函数和无环约束项的奖励信号,采用BIC作为得分,给定整个观察数据X,邻接矩阵At的BIC得分为:
Figure FDA0002716112350000031
其中,
Figure FDA0002716112350000032
表示由μ参数化的最大似然估计器,dμ代表μ的维数,如果使用线性模型对每个因果关系进行建模,则BIC得分计算为:
Figure FDA0002716112350000033
Figure FDA0002716112350000034
表示估算值,ne代表边数,公式中第二项用来惩罚冗余边数,第一项等效于GraN-DAG采用的对数似然目标,节点数据的附加噪声方差相等,得到另一种形式的BIC得分:
Figure FDA0002716112350000035
将有环性惩罚定义为:
Figure FDA0002716112350000036
如果At是有向图,则CL(At)=0;增加了值相对较大的指标惩罚:
Figure FDA0002716112350000037
通过将权重λ1和λ2赋予两项有环性惩罚,将三项惩罚值结合,总的奖励信号定义为:
R=-[BIC(At)+λ1CL(At)+λ2Ind(At)]。
6.如权利要求1所述的信任域引导裁剪的策略优化方法,其特征在于,所述信任域引导裁剪的策略优化方法的强化学习过程,根据对因果推断强化学习的建模,将优化目标定义为:
Figure FDA0002716112350000038
通过最大化有向图空间上的奖赏值,在以下条件下获得得分最高的DAG:
Figure FDA0002716112350000039
其中,上界BICu可以通过随机的DAG计算得到,λ1发置了一个相对较小的值。
7.如权利要求1所述的信任域引导裁剪的策略优化方法,其特征在于,所述信任域引导裁剪的策略优化方法的图生成神经网络架构编码器和解码器一起形成由θ参数化的actorθ模块,其输入是观测数据St,输出是图邻接矩阵At,编码器应理解变量之间的内在关系,并输出最能描述因果关系的编码enci,而解码器应使用该编码来解释变量之间的相互关;
(1)缩放点积图注意力编码器,首先运用图注意力网络GAT作为编码器,将GAT中的加性注意力替换为缩放点积注意力,形成SDGAT,SDGAT由ns个独立注意层堆叠,对于具有m特征的一组n-节点变量,
Figure FDA0002716112350000041
单个注意层将输出一组基数为m′的新n-节点变量;
SDGAT中采用了两级多头注意力的层次化结构,第一级采用相同于GAT的设计,在第一级层次结构中,h被分成ngat段,每段由
Figure FDA0002716112350000042
并由p索引,每个h(p)并行地插入具有nsd个头的第二级层次结构,每个头都遵循缩放点积注意力的基本结构,在第(p,q)子层中,查询Qpq,键Kpq和值Vpq由h(p)通过三套独立的线性投影层得到:
Figure FDA0002716112350000043
其中
Figure FDA0002716112350000044
而dk是Qpq和Kpq的隐藏层维度,通过相应的查询和键,第(p,q)子层中的注意力矩阵αpq∈Rn×n被并行计算为:
Figure FDA0002716112350000045
其中,
Figure FDA0002716112350000046
表示节点j的特征对节点i的重要性,
Figure FDA0002716112350000047
是对应的缩放比例因子;通过二元邻接矩阵掩盖相应
Figure FDA0002716112350000048
将结构信息注入到注意力机制中;
将nsd个独立的特征输出连接后进行线性投影,得到第一级注意力的输出,再对ngat个这样的输出再次进行串联,得到以下单层的输出表示形式:
Figure FDA0002716112350000049
其中,
Figure FDA00027161123500000410
对于网络的最后一层,对ngat个头的连接操作的意义不大,采用取平均操作:
Figure FDA00027161123500000411
所有隐藏层维度和单层输出维度保持不变。输入的特征在ns个注意力层之后,得出编码enc∈Rn×m′,以进行图的生成和Critic对状态值函数的估计;
(2)线性图生成解码器,给定编码输出enc,线性图生成解码器按逐元素操作的方式生成图邻接矩阵,通过遍历enc中的所有enci-encj对,子策略
Figure FDA0002716112350000051
计算为:
Figure FDA0002716112350000052
其中,
Figure FDA0002716112350000053
是可训练的参数,dh是解码器的隐藏层维度,所有对角线元素都被掩码为0,每个子动作
Figure FDA0002716112350000054
限据概率
Figure FDA0002716112350000055
按伯努利分布进行采样,生成整个邻接矩阵At
所述信任域引导裁剪的策略优化方法的应用于因果推断的强化学习算法包括:
(1)带有移动平均基线的REINFORCE算法,批大小N;移动平均更新率αm;Actor和Critic学习率αθ,αω;熵值权重λe;t=1,2,......生成批在线经验{St,At,Rt}N
Figure FDA0002716112350000056
Figure FDA0002716112350000057
Figure FDA0002716112350000058
Figure FDA0002716112350000059
以αθ,αω的学习率,优化θ和ω以最小化Lθ,Lω...
{·}N表示一批元素,
Figure FDA00027161123500000510
表示对批元素取平均运算,根据REINFORCE算法,朝以下方向优化:
Figure FDA00027161123500000511
其中,
Figure FDA00027161123500000512
表示用Rt+Rm-Vω(St)估计的优势函数,其中Rm表示移动平均值,而Vω(St)是ω参数化下的Critic由编码{enc}估计的值函数;由一批Rt以αm速率更新的Rm,通过减少参数基线的方差来稳定训练过程;Critic是一个带有ReLU单元的双层前馈网络,经过训练将其对状态值函数的预测值与真实奖励加上移动平均值的和之间的均方误差Lw降至最低;
熵正则化项也被添加到代理损失Lθ中:
Figure FDA0002716112350000061
通过最小化两个代理损失Lθ和Lω,使用Adam优化器分别以学习率αθ和αω训练Actor和Critic模块;
(2)优先级采样辅助的REINFORCE算法,将优先采样引入REINFORCE算法中,基于排名的优先级定义pi
Figure FDA0002716112350000062
其中经验i在经验回放区中按
Figure FDA0002716112350000063
排序;每条经验的采样概率i定义为:
Figure FDA0002716112350000064
经验回放器R;批大小N;移动平均更新率αm;Actor和Critic学习率αθ,αω;熵值权重λe;t=1,2,......
Figure FDA0002716112350000065
缓存
Figure FDA0002716112350000066
到R中
从R中根据采样
Figure FDA0002716112350000067
Figure FDA0002716112350000068
Figure FDA0002716112350000069
Figure FDA00027161123500000610
Figure FDA00027161123500000611
以αθ,αω的学习率,优化θ和ω以最小化Lθ,Lω...。
8.一种计算机可读存储介质,存储有计算机程序,所述计算机程序被处理器执行时,使得所述处理器执行如下步骤:
缓冲集合到经验回放器;
从经验回放器中采样;
计算批概率比矩阵、计算批KL距离矩阵;
运用批KL距离矩阵根据裁剪概率比矩阵。
9.一种如权利要求1~7任意一项所述信任域引导裁剪的策略优化方法在机器学习中的应用。
10.一种运行权利要求1~7任意一项所述信任域引导裁剪的策略优化方法的信任域引导裁剪的策略优化系统,其特征在于,所述信任域引导裁剪的策略优化系统包括:
缓冲集合处理模块,用于缓冲集合到经验回放器;
采样模块,用于从经验回放器中采样;
矩阵计算模块,用于计算批概率比矩阵、计算批KL距离矩阵;
概率比矩阵裁剪模块,用于运用批KL距离矩阵根据裁剪概率比矩阵。
CN202011074176.3A 2020-10-09 2020-10-09 信任域引导裁剪的策略优化方法、系统、存储介质及应用 Pending CN112149359A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202011074176.3A CN112149359A (zh) 2020-10-09 2020-10-09 信任域引导裁剪的策略优化方法、系统、存储介质及应用

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202011074176.3A CN112149359A (zh) 2020-10-09 2020-10-09 信任域引导裁剪的策略优化方法、系统、存储介质及应用

Publications (1)

Publication Number Publication Date
CN112149359A true CN112149359A (zh) 2020-12-29

Family

ID=73952724

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202011074176.3A Pending CN112149359A (zh) 2020-10-09 2020-10-09 信任域引导裁剪的策略优化方法、系统、存储介质及应用

Country Status (1)

Country Link
CN (1) CN112149359A (zh)

Cited By (8)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN112990437A (zh) * 2021-03-24 2021-06-18 厦门吉比特网络技术股份有限公司 一种基于因果多输出的强化学习神经网络及其构建方法
CN113064907A (zh) * 2021-04-26 2021-07-02 陕西悟空云信息技术有限公司 一种基于深度强化学习的内容更新方法
CN113296502A (zh) * 2021-05-08 2021-08-24 华东师范大学 动态环境下基于层级关系图学习的多机器人协同导航方法
CN113470758A (zh) * 2021-07-06 2021-10-01 北京科技大学 基于因果发现和多结构信息编码的化学反应收率预测方法
CN114493885A (zh) * 2022-03-30 2022-05-13 支付宝(杭州)信息技术有限公司 策略组合的优化方法及装置
CN114666204A (zh) * 2022-04-22 2022-06-24 广东工业大学 一种基于因果强化学习的故障根因定位方法及系统
CN114661783A (zh) * 2022-03-02 2022-06-24 华南师范大学 一种基于用电行为的生活状态检测方法
CN117648673A (zh) * 2024-01-29 2024-03-05 深圳海云安网络安全技术有限公司 一种基于大模型的安全编码规范多标融合方法及系统

Cited By (14)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN112990437B (zh) * 2021-03-24 2024-05-14 厦门吉比特网络技术股份有限公司 一种基于因果多输出的强化学习神经网络及其构建方法
CN112990437A (zh) * 2021-03-24 2021-06-18 厦门吉比特网络技术股份有限公司 一种基于因果多输出的强化学习神经网络及其构建方法
CN113064907B (zh) * 2021-04-26 2023-02-21 陕西悟空云信息技术有限公司 一种基于深度强化学习的内容更新方法
CN113064907A (zh) * 2021-04-26 2021-07-02 陕西悟空云信息技术有限公司 一种基于深度强化学习的内容更新方法
CN113296502A (zh) * 2021-05-08 2021-08-24 华东师范大学 动态环境下基于层级关系图学习的多机器人协同导航方法
CN113470758A (zh) * 2021-07-06 2021-10-01 北京科技大学 基于因果发现和多结构信息编码的化学反应收率预测方法
CN113470758B (zh) * 2021-07-06 2023-10-13 北京科技大学 基于因果发现和多结构信息编码的化学反应收率预测方法
CN114661783A (zh) * 2022-03-02 2022-06-24 华南师范大学 一种基于用电行为的生活状态检测方法
CN114661783B (zh) * 2022-03-02 2024-07-19 华南师范大学 一种基于用电行为的生活状态检测方法
CN114493885A (zh) * 2022-03-30 2022-05-13 支付宝(杭州)信息技术有限公司 策略组合的优化方法及装置
CN114666204A (zh) * 2022-04-22 2022-06-24 广东工业大学 一种基于因果强化学习的故障根因定位方法及系统
CN114666204B (zh) * 2022-04-22 2024-04-16 广东工业大学 一种基于因果强化学习的故障根因定位方法及系统
CN117648673A (zh) * 2024-01-29 2024-03-05 深圳海云安网络安全技术有限公司 一种基于大模型的安全编码规范多标融合方法及系统
CN117648673B (zh) * 2024-01-29 2024-05-03 深圳海云安网络安全技术有限公司 一种基于大模型的安全编码规范多标融合方法及系统

Similar Documents

Publication Publication Date Title
CN112149359A (zh) 信任域引导裁剪的策略优化方法、系统、存储介质及应用
US11436496B2 (en) Systems and methods for regularizing neural networks
CA3090759A1 (en) Systems and methods for training generative machine learning models
Last et al. A compact and accurate model for classification
Hassan et al. A hybrid of multiobjective Evolutionary Algorithm and HMM-Fuzzy model for time series prediction
Chen et al. Experiments with repeating weighted boosting search for optimization signal processing applications
Jun-Zhong et al. A Bayesian network learning algorithm based on independence test and ant colony optimization
Peng et al. Towards sparsification of graph neural networks
Qin et al. Temporal link prediction: A unified framework, taxonomy, and review
CN113326884B (zh) 大规模异构图节点表示的高效学习方法及装置
Chen et al. Unsupervised sampling promoting for stochastic human trajectory prediction
Wang et al. Dynamic graph Conv-LSTM model with dynamic positional encoding for the large-scale traveling salesman problem
Phan et al. A New Fuzzy Logic‐Based Similarity Measure Applied to Large Gap Imputation for Uncorrelated Multivariate Time Series
Cui et al. Path-based multi-hop reasoning over knowledge graph for answering questions via adversarial reinforcement learning
Guan et al. Uac: Offline reinforcement learning with uncertain action constraint
Zhang et al. Missing-edge aware knowledge graph inductive inference through dual graph learning and traversing
Lei et al. A novel time-delay neural grey model and its applications
Deng et al. A progressive predictor-based quantum architecture search with active learning
Wang et al. Discovering Lin-Kernighan-Helsgaun heuristic for routing optimization using self-supervised reinforcement learning
CN117933402A (zh) 基于gnn的电网知识图谱多跳推理方法与系统
Zhou et al. Online recommendation based on incremental-input self-organizing map
Ninniri et al. Classifier-free graph diffusion for molecular property targeting
Kaur et al. Non-parametric learning of lifted restricted boltzmann machines
Block et al. Butterfly Effects of SGD Noise: Error Amplification in Behavior Cloning and Autoregression
Meng et al. Learning non-stationary dynamic Bayesian network structure from data stream

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
RJ01 Rejection of invention patent application after publication
RJ01 Rejection of invention patent application after publication

Application publication date: 20201229