CN117273225B - 一种基于时空特征的行人路径预测方法 - Google Patents
一种基于时空特征的行人路径预测方法 Download PDFInfo
- Publication number
- CN117273225B CN117273225B CN202311253071.8A CN202311253071A CN117273225B CN 117273225 B CN117273225 B CN 117273225B CN 202311253071 A CN202311253071 A CN 202311253071A CN 117273225 B CN117273225 B CN 117273225B
- Authority
- CN
- China
- Prior art keywords
- model
- track
- data
- gail
- pedestrian
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
- 238000000034 method Methods 0.000 title claims abstract description 151
- 230000008569 process Effects 0.000 claims description 47
- 230000009471 action Effects 0.000 claims description 36
- 238000012549 training Methods 0.000 claims description 33
- 230000000875 corresponding effect Effects 0.000 claims description 24
- 239000003795 chemical substances by application Substances 0.000 claims description 19
- 230000006870 function Effects 0.000 claims description 14
- 230000004927 fusion Effects 0.000 claims description 14
- 238000004458 analytical method Methods 0.000 claims description 11
- 238000012545 processing Methods 0.000 claims description 11
- 238000009826 distribution Methods 0.000 claims description 10
- 238000004364 calculation method Methods 0.000 claims description 9
- 238000005070 sampling Methods 0.000 claims description 9
- 238000009795 derivation Methods 0.000 claims description 8
- 230000008859 change Effects 0.000 claims description 6
- 230000006399 behavior Effects 0.000 claims description 5
- 230000001364 causal effect Effects 0.000 claims description 4
- 238000013145 classification model Methods 0.000 claims description 3
- 238000005259 measurement Methods 0.000 claims description 3
- 238000004422 calculation algorithm Methods 0.000 abstract description 36
- 238000002474 experimental method Methods 0.000 abstract description 14
- 238000005457 optimization Methods 0.000 abstract description 14
- 230000008901 benefit Effects 0.000 abstract description 6
- 238000010276 construction Methods 0.000 abstract description 5
- 102100040653 Tryptophan 2,3-dioxygenase Human genes 0.000 abstract 1
- 101710136122 Tryptophan 2,3-dioxygenase Proteins 0.000 abstract 1
- 230000003042 antagnostic effect Effects 0.000 abstract 1
- 238000012360 testing method Methods 0.000 description 17
- 238000010586 diagram Methods 0.000 description 16
- 230000002787 reinforcement Effects 0.000 description 12
- 230000000694 effects Effects 0.000 description 9
- 230000003993 interaction Effects 0.000 description 8
- 230000033001 locomotion Effects 0.000 description 7
- 230000007246 mechanism Effects 0.000 description 7
- 238000011160 research Methods 0.000 description 6
- 238000011161 development Methods 0.000 description 4
- 230000018109 developmental process Effects 0.000 description 4
- 238000006073 displacement reaction Methods 0.000 description 4
- 230000007613 environmental effect Effects 0.000 description 4
- 230000006872 improvement Effects 0.000 description 4
- 238000012800 visualization Methods 0.000 description 4
- 238000013256 Gubra-Amylin NASH model Methods 0.000 description 3
- 238000013528 artificial neural network Methods 0.000 description 3
- 238000013527 convolutional neural network Methods 0.000 description 3
- 238000010801 machine learning Methods 0.000 description 3
- 238000007500 overflow downdraw method Methods 0.000 description 3
- 238000011176 pooling Methods 0.000 description 3
- 238000004088 simulation Methods 0.000 description 3
- 230000000007 visual effect Effects 0.000 description 3
- ORILYTVJVMAKLC-UHFFFAOYSA-N Adamantane Natural products C1C(C2)CC3CC1CC2C3 ORILYTVJVMAKLC-UHFFFAOYSA-N 0.000 description 2
- 230000006978 adaptation Effects 0.000 description 2
- 230000003044 adaptive effect Effects 0.000 description 2
- 230000008485 antagonism Effects 0.000 description 2
- 238000013459 approach Methods 0.000 description 2
- 238000005516 engineering process Methods 0.000 description 2
- 239000000284 extract Substances 0.000 description 2
- 238000000605 extraction Methods 0.000 description 2
- 238000007499 fusion processing Methods 0.000 description 2
- 230000000670 limiting effect Effects 0.000 description 2
- 238000013507 mapping Methods 0.000 description 2
- 230000036961 partial effect Effects 0.000 description 2
- 101100004280 Caenorhabditis elegans best-2 gene Proteins 0.000 description 1
- 238000013473 artificial intelligence Methods 0.000 description 1
- 230000009286 beneficial effect Effects 0.000 description 1
- 238000013135 deep learning Methods 0.000 description 1
- 238000011156 evaluation Methods 0.000 description 1
- 238000013213 extrapolation Methods 0.000 description 1
- 238000001914 filtration Methods 0.000 description 1
- 230000005021 gait Effects 0.000 description 1
- 230000002452 interceptive effect Effects 0.000 description 1
- 238000003874 inverse correlation nuclear magnetic resonance spectroscopy Methods 0.000 description 1
- 238000011835 investigation Methods 0.000 description 1
- 230000002829 reductive effect Effects 0.000 description 1
- 230000002441 reversible effect Effects 0.000 description 1
- 238000012552 review Methods 0.000 description 1
- 238000012216 screening Methods 0.000 description 1
- 230000011218 segmentation Effects 0.000 description 1
- 230000002123 temporal effect Effects 0.000 description 1
- 239000009891 weiqi Substances 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06Q—INFORMATION AND COMMUNICATION TECHNOLOGY [ICT] SPECIALLY ADAPTED FOR ADMINISTRATIVE, COMMERCIAL, FINANCIAL, MANAGERIAL OR SUPERVISORY PURPOSES; SYSTEMS OR METHODS SPECIALLY ADAPTED FOR ADMINISTRATIVE, COMMERCIAL, FINANCIAL, MANAGERIAL OR SUPERVISORY PURPOSES, NOT OTHERWISE PROVIDED FOR
- G06Q10/00—Administration; Management
- G06Q10/04—Forecasting or optimisation specially adapted for administrative or management purposes, e.g. linear programming or "cutting stock problem"
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/044—Recurrent networks, e.g. Hopfield networks
- G06N3/0442—Recurrent networks, e.g. Hopfield networks characterised by memory or gating, e.g. long short-term memory [LSTM] or gated recurrent units [GRU]
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/0475—Generative networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/092—Reinforcement learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/094—Adversarial learning
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02T—CLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
- Y02T10/00—Road transport of goods or passengers
- Y02T10/10—Internal combustion engine [ICE] based vehicles
- Y02T10/40—Engine management systems
Landscapes
- Engineering & Computer Science (AREA)
- Physics & Mathematics (AREA)
- Theoretical Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Computational Linguistics (AREA)
- Evolutionary Computation (AREA)
- Software Systems (AREA)
- Mathematical Physics (AREA)
- General Engineering & Computer Science (AREA)
- Computing Systems (AREA)
- Molecular Biology (AREA)
- General Health & Medical Sciences (AREA)
- Data Mining & Analysis (AREA)
- Biophysics (AREA)
- Biomedical Technology (AREA)
- Artificial Intelligence (AREA)
- Health & Medical Sciences (AREA)
- Life Sciences & Earth Sciences (AREA)
- Business, Economics & Management (AREA)
- Human Resources & Organizations (AREA)
- Economics (AREA)
- Strategic Management (AREA)
- Tourism & Hospitality (AREA)
- Game Theory and Decision Science (AREA)
- Development Economics (AREA)
- Quality & Reliability (AREA)
- Operations Research (AREA)
- Marketing (AREA)
- Entrepreneurship & Innovation (AREA)
- General Business, Economics & Management (AREA)
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
Abstract
本发明公开了一种基于时空特征的行人轨迹预测方法。首先,通过构建基于传统的置信域策略优化算法(Trust Region Policy Optimization,TRPO)和改进的近端策略优化算法(PP O‑penalty)的实验,验证了基于PPO‑penalty的GAIL(Generative Adversarial Imitation Learnin g,生成对抗式模仿学习)模型具有较大的优势,因此选择GAIL(PPO‑penalty)结构实现行人轨迹预测。然后,为了提高信息的利用率,防止信息的丢失,引入了ConstantPadding(常数填充)的方法,并将该方法命名为ConstantPadding‑GAIL。最后,创新性的提出使用Mogrifier LST M抽取行人历史信息中存在的时序特征,并将其融合到当前状态的实验构建。本发明使用基于PPO‑penalty的GAIL模型在已有的行人历史真实轨迹数据集基础上,学习行人在社交场合中如何运动,从而预测行人轨迹。
Description
技术领域
本发明属于人工智能领域,具体地说,涉及一种基于时空特征的行人路径预测方法。
背景技术
当前机器学习按照学习模式的不同,基本实现方法不同,大体上能够分为三大类,即:有监督学习Supervised Learning(基于拥有样本标签的数据),无监督学习Unsupervised Le arning(训练数据中不存在事先人工标记的标签)和强化学习(一种基于马尔可夫决策模型的代理-环境互动探索模型)。其中,强化学习(ReinforcementLearning)作为机器学习的一种实现,已经在越来越多决策问题领域发挥至关重要的作用,并且在一些策略游戏中也已经展现了出色的成绩,如DeepMind公司推出的AlphaGo、AphaZero等围棋机器人系列;同时在智慧城市交通信号控制等研究中也表现出了巨大潜力,基于传统的强化学习和深度神经网络的发展,深度强化学习(Deep ReinforcementLearning)也随之诞生,基于目前机器学习前沿的深度学习算法所具备的优异的特征感知能力,将其与传统强化学习优秀的多步连续问题中的决策能力相合并,开拓了新的天地;接下来,为了解决在奖励函数很难完整定义等场景问题,研究人员又提出了逆向强化学习(Inverse Reinforcement Learning),通过Agent代理与环境交互过程的数据中,尽可能学习到一个近似最优的专家奖励函数,从而指导策略更新,学到一个近似专家的策略,最终给出预测结果;而在2016年,GAIL(Generative Adversarial Imitation Learning)的作者证明了该算法在实施上与Inverse RL的等效性,同时在具体实现过程中通过减少内层冗余的强化学习过程,极大地节省了计算资源,是一个优秀的模拟学习的算法。
现有技术中也提出了部分行人路径预测方法,例如:
现有技术一,如公开号CN113888638A,公开了一种基于注意力机制的图神经网络行人路径预测方法,在交互技术中同时关注空间相关性和时间相关性,并通过注意力机制将有效信息最大化,其提出的技术方案包括:采集行人轨迹信息,提取轨迹运动特征,构建行人轨迹原始节点图,其中所述行人轨迹原始节点图包含行人轨迹的空间信息和时间信息;对所述行人轨迹原始节点图进行融合、舍弃和放大,过滤出对形成行人轨迹影响重大的信息,生成行人轨迹最终节点图,其中所述融合、舍弃和放大由图通道注意力机制完成;利用时空图卷积神经网络提取行人轨迹最终节点图的时空特征,根据所述时空特征构建行人轨迹的原始时空特征图,并利用图通道注意力机制筛选出重要的时空特征组成新的时空特征图;将所述新的时空特征图输入预测器,预测器输出预测的预设时间内的行人轨迹,其中所述预测器采用时间外推神经网络,预测的行人轨迹包含多种不同结果;给所述预测的行人轨迹分配权重,以权重最大的轨迹作为最终的预测结果,其中所述分配权重由时间通道注意力机制完成。
现有技术二,如公开号为CN113658228B,公开了一种基于卷积神经网络的行人路径预测系统和方法,其提出的技术方案包括:重新构建了基于卷积神经网络的行人路径预测的解决方案,优化了行人路径预测系统的整体系统架构,提高了数据处理的速率和现实场景下的预测能力,综合考虑了现实世界场景下的各种上下文背景和环境因素。
对于模型性能评估的指标主要有两个:平均位移误差(ADE)和最终位移误差(FDE)。具体来说,ADE评价的是模型沿轨迹的平均预测性能,而FDE只考虑最后的预测精度。两个指标的数值越小,网络的表现性能越优。两个指标的定义如下:
ADE:用来衡量模型预测轨迹点的坐标和真实值坐标的差距(越小预测精度越高)。
其中,t表示时间帧,pt表示时间帧为t时行人的坐标位置。
在上式中,n表示第n个Agent或者同一个Agent的第n轮预测,假设场景存在5个Agent,此时N为5,n从0到4,T表示总预测时间帧数,因此上式表示每个Agent每一步采取行动之后,当前预测坐标与真实坐标的二范数值即距离。
FDE:用来衡量模型预测到最终点的坐标和真实最终点的差距(越小离真实目标越近)。
其中,符号定义同ADE过程。此时只取每个Agent的最终坐标值进行衡量;因此只在t=T的时候衡量其距离的平均值。
但是,行人运动轨迹和模式受规则常识、相互作用、步态特征等影响,研究中仍存在一些问题导致以上两个指标难以提升。问题比较集中在:
1)交互缺乏可解释性。网络模型在进行训练时,使用的数据都是能客观测量的数据,对行人运动意图把握不准,缺少依赖于人的主观判断来训练算法的数据。例如,在自动驾驶场景中,自动驾驶辅助系统关心的是此时的行人是否会过马路。有的模型利用头部姿势,结合行人行为预测进行了一些尝试,但获取数据的方式单一,行人的主观意图研究的少。所以,目前的模型对计算到的交互缺乏可解释性,仍然依赖于数据驱动。
2)动态图缺少时序特征。基于图结构的网络架构,在时序构建动态图的过程中,对不同时刻的目标的相关信息缺少跟踪与更新。换言之,模型在各个时间点能清楚地获得目标(例如障碍物)的位置,但目前算法没有对目标在时序上进行关联,网络无法理解两个时刻目标的对应关系,降低了交互的性能,导致图网络结构不稳定。
3)预测算法环境适应能力不强。现有的社交感知方法假设所有被观察到的行人行为相似,并且他们的运动可以用相同的模型和特征来预测,对高层社会属性的捕捉和推理不强。大多数模型都是针对特定的场景、任务或运动而设计。这些方法在空间结构具体、运动模式固定时表现良好,例如,当环境中运动模式显著、空间结构和行人目标已知时,而在未定义的、不断变化的情况下性能较差。
发明内容
要解决的技术问题
针对上述现有技术存在的问题,本发明提供一种基于时空特征的行人路径预测方法。
本发明使用GAIL模型在已有的行人历史真实轨迹数据集基础上,学习行人在社交场合中如何运动,从而预测行人轨迹。根据调查和检索,目前还结合生成对抗模仿学习GAIL并且考虑社交特性进行预测的工作,本发明填补了行人路径预测技术方面的空缺。
首先,本方法构建了实验分别采用传统的置信域策略优化算法和近端策略优化算法,通过对GAIL(PPO-penalty)以及GAIL(TRPO)的模型结果进行的对比分析,可以看出基于PPO-penalty的GAIL模型结果的优势十分明显,因此也奠定了后续继续研究对GAIL(PPO-penalty)结构的拓展和改进。传统的GAIL(TRPO)算法已经在性能上劣于SGAN模型,但是通过对比,本方法使用的GAIL(PPO-penalty)在ADE和FDE指标上均优于SGAN。此处仅基线方法就已经能够取得很好的效果;在此基础上,为了解决时序特征问题,对当前的基础模型进一步优化改进。
其次,基于GAIL(PPO-penalty)算法的结果,此处,本方法引入了ConstantPadding的方法,测试发现ConstantPadding的方法在数据集有限的情况下,对于提高数据质量和利用率具有较为明显的作用,只引入了ConstantPadding的拓展方法,还没有添加MogrifierLSTM时序特性的情况下,模型能够表现最佳结果。基于结果对比,可以看到,GAIL(ConstantPadding)方法可以将模型对于训练数据的利用率提高,并且较好的提高两个指标ADE和FDE的准确度。最后,基于ConstantPadding-GAIL的结果,考虑使用Mogrifier LSTM抽取行人历史信息中存在的时序特征,并将其融合到当前状态的实验构建。
技术方案
为解决上述问题,本发明采用如下的技术方案。
本申请请求保护一种基于时空特征的行人路径预测方法,其特征在于,包括以下步骤:
步骤S1:构建基于GAN网络的GAIL模型,包括:
步骤S11:构建GAIL模型,如下式(I):
在上式(I)中,E表示策略的期望,或策略为专家的期望,π是训练得到的策略模型,Eπ即为对策略π求取期望运算;EπE表示对专家数据蕴含的专家策略求取期望运算;log表示求以10为底的对数;D为判别器,s表示当前时刻的状态,H表示λ参数控制的策略调整器,同时根据所学习的策略,输出对应的动作值记作aπ,aπE表示采用专家策略生成的动作值;action所对应的动作空间是A<vx,vy>;
步骤S12:GAIL模型对照到GAN网络,可以得到公式(II):
其中,s和a分别表示状态空间和动作空间,S和A分别表示s和a的取值范围;其余参数同公式(I);其训练过程是一个最小化min和最大化max的过程,同时进行生成器和策略网络的博弈,使得策略网络在生成器的打分中不断优化模型参数,得到最小值,同时,判别器也在不断训练判别能力,企图最大化以上的目标函数值;
步骤S2:在所述GAIL模型中融合Mogrifier LSTM提取的历史信息,基于前n个可变步长的预测过程,给出下一时刻的结果预测,其中,n为自然数;
步骤S3:构建与Mogrifier LSTM模型适配的缓冲区;
步骤S4:构建基于Mogrifier LSTM的MogrifierGAIL模型;
步骤S5:将判别器模型和生成器进行优化,直至达到一种稳态即纳什均衡点;
步骤S6:输入行人轨迹数据至S5得到的训练好的模型;将所观测到的轨迹作为输入,所观测的长度可以是1个五元组【x1,y1,x,y,t=1】至8个五元组组成的序列;
步骤S7:利用训练好的模型进行输出,得到该行人的预测轨迹数据。
进一步的,本申请请求保护一种基于时空特征的行人路径预测方法,还包括:
步骤S13:对于策略π进行优化,将问题归类为最小化JS散度的问题,具体包括:
步骤S131:首先将度量两种策略的占用度量公式(occupancy measure公式)转化为特殊的风险期望值;
步骤S132:为了最小化策略度量,可以对风险期望值进行最小化探索,将对于正则化项的推导进一步转化成为对于风险期望函数的推导;
步骤S133:最终得到对于策略更新过程中的自然梯度更新。
进一步的,本申请请求保护一种基于时空特征的行人路径预测方法,还包括:
步骤S2中在GAIL模型中融合Mogrifier LSTM提取的历史信息的操作如下:
步骤S21:对行人的下一步位置或者行人所采取的行动进行预测,提出加入Mogrifier L STM进行前几步长的时序特征隐藏层输出值融入模型中;
步骤S22:将前几步长的时序特征与当前的观测状态同时进行考量,做出基于前几个时间步而言在合理接受的阈值范围内的行为动作;
步骤S23:得到较为可靠的预测精确度。
进一步的,本申请请求保护一种基于时空特征的行人路径预测方法,还包括:
其中,前n个时刻的历史轨迹将会被作为模型的记忆信息输入Mogrifier LSTM网络中,基于上述历史轨迹的绝对坐标信息,提取出对应的OUT输出状态;
将上述输出状态在后续应用到MLP多层感知机中进行特征融合,使用MogrifierLSTM提取时序特征。
进一步的,本申请请求保护一种基于时空特征的行人路径预测方法,还包括:
步骤S3中构建与长短期记忆模型(Mogrifier LSTM)适配的缓冲区的具体操作如下:
步骤S31:假设我们目前设定的观测值为n,在前期,从开始到当前观测时刻t的长度不到n个可变步长时,均采用ConstantPadding的方法填充0;
步骤S32:在n个可变步长后,不断更新当前状态,其中最右边的五元组:(xt,yt,x,y,t)为当前状态,包括当前时刻的二维坐标(xt,yt)目标坐标(x,y)以及当前时间信息(t),会随着时间的变化,更新这个五元组;其中前n-1个序列为观测值;
步骤S33:经过上述数据结构的变化,模型最终在每次输入的时候,均携带前期的总共n个可变步长的观测值信息。
进一步的,本申请请求保护一种基于时空特征的行人路径预测方法,还包括:
步骤S3中,在GAIL模型中通过构建专家数据缓冲区,为模型不断地提供行人历史真实轨迹数据分布的学习与计算;
在缓冲区中,将当前时刻的状态作为最后一个时刻,引入变量obs_len观测步长;
在上述过程中,采用了ConstantPadding方法,给不足特定步长的部分,填补为0。
进一步的,本申请请求保护一种基于时空特征的行人路径预测方法,还包括:
步骤S4中构建基于Mogrifier LSTM的MogrifierGAIL模型的具体操作包括:
步骤S41:将采样的轨迹或者行人历史真实轨迹,输入到Input中,其中轨迹为n个步长的特征五元组所组成的连续轨迹;
步骤S42:输入的数据会被分割成两部分,一部分采用最低维分割出来当前Agent的位置信息,存储在(d1,d2,x)的数据中,x将会存储自定义的hidden_size大小信息,以便根据当前状态做出下一步的决策;另外一部分为原始数据的备份,和另外的时序数据将会被分在另一个分支中;
步骤S43:首先进行维度的变化,将rollout_len和processor_num合并成为批数据,同时对Mogrifier LSTM采用batch_first=True的设定,批量对数据进行处理。
进一步的,本申请请求保护一种基于时空特征的行人路径预测方法,还包括:
步骤S4中在Mogrifier LSTM处理之后,还包括:
步骤S44:Mogrifier LSTM处理之后,数据将会被取出其Out[:,-1:]的数据,之后被进行维度拉伸,回到原始的(d1,d2,x)的维度大小,x将会存储自定义的hidden_size大小信息,即代表时序特征的信息;
步骤S45:将原始的当前位置特征与时序信息进行融合,合并之后采用嵌入的方法提取融合特征;
步骤S46:输入标准的MLP层用来得到对应out_size的输出,对应不同的out_size大小,信息会分为两种处理取到,action信息将会被用来计算对应的loss值,并且更新generator生成器的参数,奖励信息会被用优化器来更新判别器的参数。
进一步的,本申请请求保护一种基于时空特征的行人路径预测方法,还包括:
步骤S4,还包括:
步骤S47:行人历史真实轨迹和采样轨迹均经过共用的Actor_Critic的MLP类结构,具体在实际的训练过程中,利用所用的普通多层感知机;
步骤S48:采用了多处理器并行运行的方式,对应不同的处理器,会初始化对应设定数量的环境个数,随即进行特征分析与融合;
步骤S49:进行采样与轨迹更新,从产生动作中进行判别器打分作为奖励,继续优化训练的策略。
进一步的,本申请请求保护一种基于时空特征的行人路径预测方法,还包括:
所述步骤S5具体包括:
步骤S51:判别器D是一个二分类模型,将在真实轨迹A与生成器所生成的轨迹B所组成的带标签的数据池中随机采样,并且将采样的轨迹作为判别器的输入,判别器将给出即判断输入属于真实轨迹或生成轨迹的分类结果,根据判别器的分类结果,可以与标签真值进行对比并且计算误差;优化器目标为将误差值进行最小化;
步骤S52:对于生成器,其从一个初始化的分布中进行采样,经过生成器的网络的处理并输出,输出将会作为生成的轨迹与真实轨迹组合成为数据池;在每次输入轨迹进入判别器时,都会进行轨迹池的更新;
步骤S53:将生成器和判别器不断优化时,最终达到一种稳态即纳什均衡点,从而获得了能够达到最优的轨迹的生成器模型参数。
有益效果
相比于现有技术,本发明的有益效果为:
首先,本方法构建了实验分别采用传统的置信域策略优化算法和近端策略优化算法,通过对GAIL(PPO-penalty)以及GAIL(TRPO)的模型结果进行的对比分析,可以看出基于PPO-penalty的GAIL模型结果的优势十分明显,因此也奠定了后续继续研究对GAIL(PPO-penalty)结构的拓展和改进。传统的GAIL(TRPO)算法已经在性能上劣于SGAN模型,但是通过对比,本方法使用的GAIL(PPO-penalty)在ADE和FDE指标上均优于SGAN。此处仅基线方法就已经能够取得很好的效果,在此基础上,为了解决时序特征问题,对当前的基础模型进一步优化改进。
其次,基于GAIL(PPO-penalty)算法的结果,此处,本方法引入了ConstantPadding的方法,测试发现ConstantPadding的方法在数据集有限的情况下,对于提高数据质量和利用率具有较为明显的作用,只引入了ConstantPadding的拓展方法,还没有添加MogrifierLSTM时序特性的情况下,模型能够表现的最佳结果。基于结果对比,可以看到,GAIL(ConstantPadding)方法可以将模型对于训练数据的利用率提高,并且较好的提高两个指标ADE和FDE的准确度。
最后,基于ConstantPadding-GAIL的结果,考虑使用Mogrifier LSTM抽取行人历史信息中存在的时序特征,并将其融合到当前状态的实验构建。在基于Mogrifier LSTM的GAIL模型中,本方法前期实验结果非常不理想,随即进行了模型结构的具体分析,发现在前期的实验中,模型结构是直接将原始的Agent当前坐标(5元组)进行了Linear映射之后才与Mogrifier LSTM时序特征进行了融合,在这个基础上,又进行了一次Embedding和MLP的操作,使得当前最重要的t时刻特征被不同的权重矩阵进行了拉伸,影响了后续模型对于特征图中特征的重要性判断出现了失误,因此无法进行准确的提取和预测。基于以上分析,本方法对模型特征融合过程进行了改进,后来采用将当前状态分割预留出来,首先对于整体再进行时序特征提取,之后进行特征拼接的操作顺序,保证了重要特征不被掩盖。其中还尝试了自注意力机制,但是发现在当前场景中效果一般,因此没有增加这一部分的结构以免使得模型过分冗余。通过实验结果可以看出,本方法最终提出的第二种模型结构具有最好的表现。
附图说明
图1为本发明一种基于时空特征的行人路径预测方法的流程图;
图2为本发明中基于Mogrifier LSTM的MogrifierGAIL模型结构图;
图3为本发明中部分GAIL(TRPO)结果可视化,分别由5组“真实-预测”轨迹组成图;
图4为本发明中部分GAIL(PPO-penalty)结果可视化,分别由5组“真实-预测”轨迹组成图;
图5为本发明中MogrifierGAIL训练过程奖励值可视化图;
图6为本发明中MogrifierGAIL模型效果可视化,分别由5组“真实-预测”轨迹组成图;
图7为现有技术中Social-GAN模型基于8步预测的8、16、32步预测结果可视化图;
图8为现有技术中Social-GAN论文Pooling部分的结构与复现办法图;
图9为本发明中不同结构的另一特征融合结果示意图。
具体实施方式
下面结合具体实施例对本发明进一步进行描述。
本专利所使用的术语“包括”和“包含”应被理解为包含性的和开放式的,而不具有排他性。具体而言,当在说明书和权利要求书中使用术语“包括”和“包含”及其同义词时,是表示包括指定的特征、步骤或组成部分。这些术语不能被理解为排除其他特征、步骤或组成部分的存在。
本发明所述基于时空特征的行人路径预测方法,整体流程图如图1所示,包括以下步骤:
步骤S1:构建基于GAN网络的GAIL模型,包括:
步骤S11:构建GAIL模型:
本发明的GAIL模型采用了生成对抗网络(Generative Adversarial Network,GAN),训练一个生成器,基于一定的数据分布产生对应的行为,以便欺骗另一个同时训练的判别器,尽最大可能使得判别器无法判别真实轨迹与生成轨迹;而判别器的作用即为能够区分数据中哪些是真实轨迹,哪些是生成器生成的假轨迹。
在这样的一个反复的生成数据和判别(trade off)过程中,不断地优化两个部分的模型参数;目标是生成器(Generator)生成的数据分布与真实(Real)数据分布最接近。
具体地说,本发明采用的算法公式如下:
在上式(I)中,定义E表示策略的期望,或策略为专家的期望,π是训练得到的策略模型,Eπ即为对策略π求取期望运算;EπE表示对专家数据蕴含的专家策略求取期望运算;log表示求以10为底的对数;
D为判别器,s表示当前时刻的状态,H表示λ参数控制的策略调整器,同时根据所学习的策略,输出一个/一组对应的动作值记作aπ,aπE表示采用专家策略生成的动作值;action所对应的动作空间是A<vx,vy>,同理,若是采用专家策略生成的动作值记作
其中,上述所述基于时空特征的行人路径预测方法,动作空间大小一致。假设D表示判断其为假数据的概率,因此D的目标即为使得概率越高越好(因为表示的是生成器学习的策略,即为假策略),同时使得1-D越小越好,因为这意味着判别器能够判断出来专家策略并不是生成的假数据轨迹,为了保证整体公式的方向一致性,在上述公式中取目标为D使得越大越好;这样前后两项的目标趋于一致;在具体实现的过程中,GAIL中训练的“生成器”策略,其更新的参数记为aπ,判别器所使用的更新策略权重参数记为λ,对于判别器,在算法中选择使用自适应矩估计(Adam adaptive moment estimation)的梯度更新方法来优化参数,选择使用TRPO算法(Trust Region Policy Optimization)进行策略参数的更新;同时在上式的尾部,添加了一个基于因果熵梯度的特殊的正则化项,使得训练过程中能够更好地收敛。
步骤S12:GAIL模型,对照到GAN网络,可以得到公式(II):
其中,s表示状态空间,a表示动作空间,S和A分别表示s和a的取值范围;其余参数同公式(I)。其训练过程是一个最小化min和最大化max的过程,进行生成器和策略网络的博弈,使得策略网络在生成器的打分中不断优化模型参数,得到最小值;同时,判别器也在不断训练判别能力,企图最大化以上的目标函数值。其与生成对抗网络模型的目标略微不同之处,即在博弈过程中,生成对抗网络是对于同一个目标函数分别基于两个部分进行有侧重的优化;而在GAIL模型中,非常重要的一步是正则化项的加入,对于策略的优化起到非常重要的作用。基于一种模仿学习算法的度量启发,可以把问题归类为最小化JS散度的问题。首先将度量两种策略的occupancy measure公式转化为特殊的风险期望值,f-divergence和风险期望具有一致性,且JS散度是属于f-divergence的一种特殊情况。因此,为了最小化策略度量,即同理可以对风险期望值进行最小化探索,将对于正则化项的推导进一步转化成为对于风险期望函数的推导。最终得到下列伪代码中对于策略更新过程中的自然梯度更新步骤,其中n表示所设置的训练timestep总数,D表示判别器,E表示期望,s表示状态空间内的状态元组,a表示基于策略所产生的动作序列。
GAIL模型伪代码
上述所述基于时空特征的行人路径预测方法,
进一步的,步骤S1中,还可以包括如下步骤:
步骤S13:对于策略π进行优化,进行问题归类为最小化JS散度的问题。
其中,JS散度(Jensen-Shannon divergence,缩写JSD)是基于KL散度(相对熵)的一种统计学度量,能够衡量两个概率分布之间的差异程度,可以参考:B.Fuglede andF.Topsoe,“Jensen-Shannon divergence and Hilbert space embedding,”inInternational Symposium onInformation Theory,2004.ISIT 2004.Proceedings.,2004,p.31。
步骤S131:首先将度量两种策略的占用度量(occupancy measure)公式转化为特殊的风险期望值;
具体来说,可以采用下面的公式把损失函数Ф转换最小预期风险R Ф (ρπ,ρπE):
其中,Ф为损失函数,ρπ和ρπE为占用度量(occupancy measure),γ为因果熵(causal entropy),π为轨迹,s表示状态空间内的状态元组,a表示基于策略所产生的动作序列。
步骤S132:为了最小化策略度量,即同理可以对风险期望值进行最小化探索,将对于正则化项的推导进一步转化成为对于风险期望函数的推导;
步骤S133:最终得到伪代码中对于策略更新过程中的自然梯度更新步骤。
本发明还包括:
步骤S2:在所述GAIL模型中融合Mogrifier LSTM提取的历史信息,基于前n个可变步长的预测过程,给出下一时刻的结果预测,其中,n为自然数。
Mogrifier LSTM(形变长短期记忆模型)是2020年提出的LSTM的一种改进版本,可以参见2020年发表的论文https://arxiv.org/pdf/1909.01792.pdf,它通过各个门很好地控制了时间步前后的信息。
具体的,步骤S2中在GAIL模型中融合Mogrifier LSTM提取的历史信息的操作如下:
步骤S21:对行人的下一步位置或者行人所采取的行动进行预测,提出加入Mogrifier LSTM进行前几步长的时序特征隐藏层输出值融入模型中;
步骤S22:将前几步长的时序特征与当前的观测状态同时进行考量,做出基于前几个时间步而言在合理接受的阈值范围内的行为动作;
步骤S23:得到较为可靠的提高预测精确度。
其中,步骤S2中GAIL模型将会基于前8个可变步长的预测过程,给出下一时刻的结果预测,在上述过程中,前8个时刻的历史轨迹将会被作为模型的记忆信息输入MogrifierLSTM网络中,基于上述历史轨迹的绝对坐标信息,提取出对应的OUT输出状态,将上述输出状态在后续应用到MLP(Multilayer Perceptron)中进行特征融合,最终起到使用Mogrifier LSTM提取时序特征的作用。前几步长的历史轨迹和未来的预测轨迹在图3,图4,图6中由不同线条表示。
本发明还包括:
步骤S3:构建与Mogrifier LSTM适配的缓冲区。
上述所述基于时空特征的行人路径预测方法,步骤S3中构建与长短期记忆模型适配的缓冲区的操作如下:
步骤S31:假设我们目前设定的观测值为n(优选的,n为8),在前期不到n个可变步长时,均采用常数填充(ConstantPadding)的方法填充0;
步骤S32:在n个可变步长后,不断更新当前状态,其中最右边的五元组:(xt,yt,x,y,t)为当前状态,包括当先时刻的二维坐标(xt,yt)目标坐标(x,y)以及当前时间信息(t),会随着时间的变化,更新这个五元组;其中前n-1个序列为观测值;
步骤S33:经过上述数据结构的变化,模型最终在每次输入的时候,均携带前期的总共n个可变步长的观测值信息。
例如:当n=8时,总特征窗口大小为8*5=40,如果当前t<=7,则观测特征为:(5*t个实际观测值+5*(8-t)个数值0填充)。
上述所述基于时空特征的行人路径预测方法,步骤S3中在GAIL模型中通过构建专家数据缓冲区,为模型不断地提供行人历史真实轨迹数据分布的学习与计算;在缓冲区中,将当前时刻的状态作为最后一个时刻,引入变量obs_len观测步长,在上述过程中,采用了常数填充方法,给不足特定步长的部分,填补为0。
本发明还包括:
步骤S4:构建基于Mogrifier LSTM的MogrifierGAIL模型。
图2为本发明中基于Mogrifier LSTM的MogrifierGAIL模型流程图。
上述所述基于时空特征的行人路径预测方法,
步骤S4中构建基于Mogrifier LSTM的MogrifierGAIL模型的具体操作如下:
步骤S41:将采样的轨迹或者行人历史真实轨迹,输入到模型的Input中,其中轨迹为n个步长的特征五元组所组成的连续轨迹;
步骤S42:输入的数据会被分割成为两部分,一部分采用最低维分割出来当前Agent的位置信息,存储在(d1,d2,x)的数据中,此处x=5,以便根据当前状态做出下一步的决策;另外一部分为原始数据的备份,和另外的时序数据将会被分在另一个分支中;
步骤S43,首先进行维度的变化,将rollout_len(初始化轨迹个数参数)和processor_num(处理器数量参数)合并成为批数据,同时对Mogrifier LSTM采用batch_first=True的设定,批量对数据进行处理。
上述所述基于时空特征的行人路径预测方法,步骤S4还包括:
步骤S44:在Mogrifier LSTM处理之后的数据将会被取出其Out[:,-1:]的数据,之后被进行维度拉伸,回到原始的(d1,d2,x)的维度大小,x将会存储自定义的hidden_size大小信息,即代表时序特征的信息;
步骤S45:将原始的当前位置特征与时序信息进行融合,合并之后采用嵌入的方法提取融合特征;
步骤S46:输入标准的MLP层用来得到对应out_size的输出,对应不同的out_size大小,信息会分为两种处理取到,动作(action)信息将会被用来计算对应的损失(loss)值,并且更新生成器(generator)的参数,奖励(reward)信息会被用优化器来更新判别器(discriminator)的参数。
上述所述基于时空特征的行人路径预测方法,步骤S4还包括:
步骤S47:行人历史真实轨迹和采样轨迹均经过共用的Actor_Critic的MLP类结构,具体在实际的训练过程中,利用所用的普通多层感知机;同时,结合了Mogrifier LSTM网络模型结构,根据提供的包含时序信息的Sequential Demo Buffer的数据;
步骤S48:采用了多处理器并行运行的方式,对应不同的处理器,会初始化对应设定数量的环境个数,随即进行特征分析与融合;
步骤S49:进行采样与轨迹更新,从产生动作中进行判别器打分作为奖励,继续优化训练的策略。
本发明还包括:
步骤S5:将判别器模型和生成器进行优化,直至达到一种稳态即纳什均衡点。
具体包括:
步骤S51:判别器D是一个二分类模型,将在真实轨迹A与生成器所生成的轨迹B所组成的带标签的数据池中随机采样,并且将采样的轨迹作为判别器的输入,判别器将给出即判断输入属于真实轨迹或生成轨迹的分类结果,根据判别器的分类结果,可以与标签真值进行对比并且计算误差;优化器目标为将误差值进行最小化;
步骤S52:同理,对于生成器,其从一个初始化的分布中进行采样,经过生成器的网络的处理并输出,输出将会作为生成的轨迹与真实轨迹组合成为数据池;在每次输入轨迹进入判别器时,都会进行轨迹池的更新;因为每一步都会对生成器进行优化,优化过程同理,如果判别器的损失误差越大则对生成器越有利。
步骤S53:将生成器和判别器不断优化时,最终达到一种稳态即纳什均衡点。即在判别器最优的情况下,获得了能够达到最优的轨迹的生成器模型参数。
具体地说,上述步骤S53中,在优化过程中采用Adam优化器进行梯度下降计算。
本发明还包括:
步骤S6:输入到行人轨迹数据至S5得到的训练好的模型;将所观测到的轨迹作为输入,所观测的长度可以是1个五元组【x1,y1,x,y,t=1】至8个五元组组成的序列。
步骤S7:利用训练好的模型进行输出,得到该行人的预测轨迹图,或者数据。
输出将是下一时刻最有可能所采取的动作,即(x增量,y增量),因此可以自由决定预测n个时间步长的动作值,只需在每次动作之后更新当前状态和观测值即可,这样便可以得到一个具有特定时间步长的预测值。
关于本发明的实验环境与数据特征设定
在硬件方面:首先本方法中所有实验均在Ubuntu 18.04操作系统完成,型号为Intel(R)Core(TM)i5-4590 CPU@3.30GHz。Python使用3.7版本,Pytorch使用1.9.1+cu102。
在强化学习环境设定方面:本方法可视化训练过程所基于的实验环境是名为“mycrowd-v2”的自定义Gym实验环境,对“路径预测场景”进行了基于精确度和探索方向的环境构建:对于观测空间,将其定义为一段连续空间状态值:obs_space=space.Box(obe_low,obs_high),其中obs_low与obs_high分别对应观测上限和下限,T表示当前时刻时间节点,xi表示第i个时刻下的x坐标位置,yi表示第i个时刻下的y坐标位置(也可以理解成为第t个时刻,i和t含义相同),xg表示目标点goal的x坐标位置,yg表示目标点y坐标的位置。各个参数满足下列条件:
在实验中,设定每帧(Frame)为固定值0.25,作为时间递增的模拟,并且在后期可视化的时候进行展示。
环境与代理进行交互的过程符合以下流程:首先基于当前的整体观测状态,输入obs中,模型将以obs变量作为输入,产生Action[vx,vy]的结果,基于当前结果,环境进行行人状态的更新,调用Step()函数,对全局观测信息进行更新,在MogrifierGAIL中,由于存在一个obs_len长度的观测窗口用来刻画时序信息,因此在更新的过程中需要对其进行适配处理(缓冲区)。
在本方法的数据结构中,当前信息将存放在观测序列的最后一段,其他信息将根据时间距当前帧的跨度值逆序排列,从而保证在获取当前状态的时候只需对观测序列的逆向进行切片即可,如Trajectory[-5:]便可以获取倒数的最近时刻观测信息(5个特征值)。在强化学习实验过程的奖励设定方面,由于本方法中的GAIL基础模型算法不以所提供的奖励进行反馈,而是基于判别器的打分值为信号。因此本方法中的环境反馈的奖励信息也仅作为对模型训练过程表现的一个参考指标,而非实际指导模型进行策略更新。在环境奖励值的设置中,其分别基于状态更新前后相对于目标距离的变化进行奖励或惩罚,假设dis_before表示更新状态之前对于目标值的距离,dis_after表示更新状态之后距离目标值的距离。则设定,若dis_after<dis_before,则奖励值为2*(dis_befor-dis_after),否则奖励值为-2*(dis_after-dis_befor),同时设定Done信号为False;另一种情况是Agent代理已经抵达目标附近,因此设定奖励为10,Done信号为True表示一轮探索已经完成。
2)实验所用衡量指标
ADE:用来衡量模型预测轨迹点的坐标和真实值坐标的差距(越小预测精度越高)。
其中,t表示时间帧,pt表示时间帧为t时行人的坐标位置。
在上式中,n表示第n个Agent或者同一个Agent的第n轮预测,假设场景存在5个Agent,此时N为5,n从0到4,T表示总预测时间帧数,因此上式表示每个Agent每一步采取行动之后,当前预测坐标与真实坐标的二范数值即距离。
FDE:用来衡量模型预测到最终点的坐标和真实最终点的差距(越小离真实目标越近)。
其中,符号定义同ADE过程。此时只取每个Agent的最终坐标值进行衡量;因此只在t=T的时候衡量其距离的平均值。
以上指标均被同行业其他科研人员广泛使用,具有普适性和较强的说服力。
3)基于TRPO的GAIL模型(传统算法)
表1 GAIL(TRPO)算法在10个测试场景中的精度指标
对GAIL在测试集上测试结果(部分):
如图3所述,分别由5组“真实-预测”轨迹组成;经过反复实验,以及对超参数的调节,最终基于TRPO的GAIL模型最佳的结果为表2。
表2 GAIL(TRPO)算法在10个测试场景中的精度指标
通过传统的GAIL模型,我们可以发现,agent在刚开始的前几步预测过程中,能够保持相对平稳的预测轨迹和较为准确的方向。但是在经过交互点的时候,开始出现偏离轨道实际方向并且折返的现象。因此我们考虑使用Mogrifier LSTM融合时序特性增强前后的状态整体一致性,再进行分析和比较;为了比较基于PPO-penalty的GAIL模型是否具有优势,我们同样在一致的实验设置下,实现了下列算法。
4)基于PPO-penalty的GAIL模型(本发明)
经过上述实验,本发明的方法获得了基于TRPO算法的GAIL模型的最佳结果数据,接下来,在同样的专家数据Demo中进行同样的采样长度的设定,基于同样的训练条件,并且设置了相同的Gym环境测试指标,采用一致的ADE与FDE衡量标准,得到如下结果,如表3和图4所示。图4中,Pred:预测轨迹坐标,Real:真实轨迹坐标。
表3GAIL(PPO-penalty)算法在10个测试场景中的精度指标
经过分析实验结果,可以发现,基于PPO-penalty的GAIL模型在整体的预测精度ADE和FDE方面均要优于第一部分实验GAIL(TRPO)。但是其在中间部分的预测过程中,仍会出现难以捕捉时序状态的信息的情况。于是本方法在基于PPO-penalty的GAIL模型结构上,尝试加入Mogrifier LSTM。在Mogrifier LSTM的模型结构方面,本方法进行了数十次实验,最终找到了最合适的特征融合方法,即上文模型结构中所说明的。
5)在PPO-penalty基础上,加入Mogrifier LSTM与Sequential Demo Buffer的GAIL模型
通过详细的实验设定与对比模拟,在本方法中构建了一种高效的特征融合方式,模型训练过程中的步长消耗和奖励曲线如图5所示。
在本组模型训练过程中,不仅模型精确度表现优异,而且在训练过程和训练时长方面,收敛速度也更快。如图6所示,在Epoch=11时间节点便已经取得了相对精确的模型参数解,相对于传统的GAIL模型,在1Epoch=1024Steps的设定下,大约需要14个Epoch的训练时间,能够提高2-3个时间单位。
表4MogrifierGAIL(Epoch=11)算法在10个测试场景中的精度指标
由表4可以看出,即使只经过了11个Epoch,MogrifierGAIL便已经达到了收敛状态,并且能够得到非常理想的ADE和FDE指标值。
6)与Social-GAN的对照实验对比
为了与其他已发表的工作进行合理的对比,以验证本方法算法的有效性,经过大量的文献查阅,本方法选择对Social-GAN行人路径预测模型进行复现,Social-GAN是LiFeifei等人在2018年的CVPR会议中提出的一种基于时序的预测模型,其主要特点是基于social-Pooling的方法,对场景中的行人进行建模,从而进行未来轨迹的预测。基于文章的内容,本方法分析其核心Encoder模型结构如图7所示。
基于以上模型,本方法将基于同样的10组行人历史真实轨迹进行训练,每组行人历史真实轨迹中包含1个场景,场景中含有5个行人,都是基于crossing交互问题,行人之间会有方向上的交互行走,经过训练后,将在另外10组测试集中进行测试和验证,测试结果分析其FDE和ADE指标,与本方法所提出的方法进行对比。
根据图7的可视化结果,可以看出,在Social-GAN模型下,基于8步预测8步的情况下还是相对准确的,从图示来看比较接近行人历史真实轨迹,但是在预测8步的基础上,再增加8步、16步、32步便逐渐出现严重的偏离行人历史真实轨迹的情况。
图8为现有技术中Social-GAN论文Pooling部分的结构与复现办法图。
对SGAN进行测试集测试结果如下表5:
表5Social-GAN算法在10个测试场景中的精度指标
7)综合结果分析
首先,本方法构建了实验分别采用传统的置信域策略优化算法和近端策略优化算法,通过对GAIL(PPO-penalty)以及GAIL(TRPO)的模型结果进行的对比分析,可以看出基于PPO-penalty的GAIL模型结果的优势十分明显,因此也奠定了后续继续研究对GAIL(PPO-penalty)结构的拓展和改进,对比分析如表6所示:
表6 GAIL(TRPO)、GAIL(PPO-penalty)、SGAN结果对比(取自10组测试结果的平均值)
可以看出,传统的GAIL(TRPO)算法已经在性能上劣于SGAN模型,但是通过对比,本方法使用的GAIL(PPO-penalty)在ADE和FDE指标上均优于SGAN。此处仅基线方法就已经能够取得很好的效果,在此基础上,为了解决时序特征问题,对当前的基础模型进一步优化改进,结果如下述。
其次,基于GAIL(PPO-penalty)算法的结果,本方法引入了ConstantPadding的方法,测试发现ConstantPadding的方法在数据集有限的情况下,对于提高数据质量和利用率具有较为明显的作用,下表展示的是只引入了ConstantPadding的拓展方法,还没有添加Mogrifier LSTM时序特性的情况下,模型能够表现的最佳结果如表7中所示。
表7 ConstantPadding-GAIL(epoch10)在10个测试场景中的精度指标
基于上述结果,本方法将GAIL(ConstantPadding)与基础模型GAIL(PPO-penalty)进行对比,结果呈现如表8所列。
表8 GAIL(ConstantPadding)与GAIL(PPO-penalty)所得到的平均指标
基于结果对比表可以看到,GAIL(ConstantPadding)方法可以将模型对于训练数据的利用率提高,并且较好地提高两个指标ADE和FDE的准确度。
最后,基于ConstantPadding-GAIL的结果,考虑使用Mogrifier LSTM抽取行人历史信息中存在的时序特征,并将其融合到当前状态的实验构建。在基于Mogrifier LSTM的GAIL模型中,本方法前期实验结果非常不理想。随即进行了模型结构的具体分析,发现在前期的实验中,模型结构是直接将原始的Agent当前坐标(5元组)进行了Linear映射之后才与Mogrifier LSTM时序特征进行了融合。在这个基础上,又进行了一次Embedding和MLP的操作,使得当前最重要的t时刻特征被不同的权重矩阵进行了拉伸,影响了后续模型对于特征图中特征的重要性判断出现了失误,因此无法进行准确的提取和预测。基于以上分析,本方法对模型特征融合过程进行了改进,后来采用将当前状态分割预留出来,首先对于整体进行时序特征提取,之后进行特征拼接的操作顺序,保证了重要特征不被掩盖。其中还尝试了自注意力机制,但是发现在当前场景中效果一般,因此没有增加这一部分的结构以免使得模型过分冗余。下面是不同类型Mogrifier LSTM结构对应的平均结果,可以看出,第二种模型结构(即本方法最终提出的)具有最好的表现。
表9 Mogrifier LSTM-GAIL特征融合结构对比(10组测试的平均值,第②组最佳)
基于表9平均结果ADE与FDE指标可以看出,在采用第二种结构的特征融合方法时,在ADE和FDE指标上的预测精度均有很好的表现,优于其他三种融合方式。
此外,本方法对训练过程中的模型受到的环境奖励反馈进行可视化,如图9所示:
在图9中,第一中融合方法最终收敛的奖励值约为30,第二种结构可以达到最好效果40+,第三、第四种分别约为25到30。由此可见,图9中呈现的预测精度与模型在训练过程中,对探索时的结果反馈具有一致性。因此更加验证了模型的有效性。
表10 MogrifierGAIL与常填充后的pure-GAIL结果对比
可以看出基于Mogrifier LSTM的GAIL(PPO-penalty)算法(即本研究所提出的MogrifierGAIL)更加优于只用了ConstantPadding手段之后的GAIL(PPO-penalty)。分别在平均位移误差ADE、最终位移误差FDE两个指标上提升了14%和21%。结合上文,最终本方法的模型比前人所以提出的Social-GAN模型在ADE指标上提高了29.8%。
以上所述仅为本发明的较佳实施例,并不用以限制本发明,对于本领域技术人员而言,显然本发明不限于上述示范性实施例的细节,而且在不背离本发明的精神或基本特征的情况下,能够以其他的具体形式实现本发明。因此,无论从哪一点来看,均应将实施例看作是示范性的,而且是非限制性的,本发明的范围由所附权利要求而不是上述说明限定,因此旨在将落在权利要求的等同要件的含义和范围内的所有变化囊括在本发明内。不应将权利要求中的任何附图标记视为限制所涉及的权利要求。
此外,应当理解,虽然本说明书按照实施方式加以描述,但并非每个实施方式仅包含一个独立的技术方案,说明书的这种叙述方式仅仅是为清楚起见,本领域技术人员应当将说明书作为一个整体,各实施例中的技术方案也可以经适当组合,形成本领域技术人员可以理解的其他实施方式。
Claims (5)
1.一种基于时空特征的行人路径预测方法,其特征在于,包括以下步骤:
步骤S1:构建基于GAN网络的GAIL模型,包括:
步骤S11:构建GAIL模型,如下式(I):
在上式(I)中,E表示策略的期望或策略为专家的期望,π是训练得到的策略模型,Eπ即为对策略π求取期望运算;EπE表示对专家数据蕴含的专家策略求取期望运算;log表示求以10为底的对数;D为判别器,s表示当前时刻的状态,H表示λ参数控制的策略调整器,同时根据所学习的策略,输出对应的动作值记作aπ,aπE表示采用专家策略生成的动作值;action所对应的动作空间是A<υx,υy>;
步骤S12:GAIL模型对照到GAN网络,可以得到公式(II):
其中,s和a分别表示状态空间和动作空间,S和A分别表示s和a的取值范围;其余参数同公式(I);其训练过程是一个最小化和最大化的过程,同时进行生成器和策略网络的博弈,使得策略网络在生成器的打分中不断优化模型参数,得到最小值,同时,对判别器不断进行判别能力训练,以便最大化以上的目标函数值;
步骤S2:在所述GAIL模型中融合Mogrifier LSTM提取的历史信息,基于前n个可变步长的预测过程,给出下一时刻的结果预测,其中,n为自然数;
步骤S2中,在GAIL模型中融合Mogrifier LSTM提取的历史信息的操作如下:
步骤S21:对行人的下一步位置或者行人所采取的行动进行预测,提出加入MogrifierLSTM进行前几步长的时序特征隐藏层输出值融入模型中;
步骤S22:将前几步长的时序特征与当前的观测状态同时进行考量,做出基于前几个时间步而言在阈值范围内的行为动作,得到较为可靠的预测精确度;
步骤S3:构建与Mogrifier LSTM模型适配的缓冲区;具体操作如下:
步骤S31:设定观测值为n,在前期,从开始到当前观测时刻t的长度不到n个可变步长时,均采用常数填充的方法填充0;
步骤S32:在n个可变步长后,不断更新当前状态,其中最右边的五元组:(xt,yt,x,y,t)为当前状态,包括当前时刻的二维坐标(xt,yt),目标坐标(x,y),以及当前时间信息(t),会随着时间的变化,更新这个五元组;其中前n-1个序列为观测值;
步骤S33:经过上述数据结构的变化,模型最终在每次输入的时候,均携带前期的总共n个可变步长的观测值信息;
步骤S4:构建基于Mogrifier LSTM的MogrifierGAIL模型;具体操作包括:
步骤S41:将采样的轨迹或者行人历史真实轨迹,输入到Input中,其中轨迹为n个步长的特征五元组所组成的连续轨迹;
步骤S42:输入的数据会被分割成两部分,一部分采用最低维分割出来当前Agent的位置信息,存储在(d1,d2,x)的数据中,x将会存储自定义的hidden_size大小信息,以便根据当前状态做出下一步的决策;另外一部分为原始数据的备份,和另外的时序数据将会被分在另一个分支中;
步骤S43:首先进行维度的变化,将初始化轨迹个数参数和processor_num合并成为批数据,同时对Mogrifier LSTM采用batch_first=True的设定,批量对数据进行处理;
步骤S44:Mogrifier LSTM处理之后,数据将会被取出其Out[:,-1:]的数据,之后被进行维度拉伸,回到原始的(d1,d2,x)的维度大小,x将会存储自定义的hidden_size大小信息,即代表时序特征的信息;
步骤S45:将原始的当前位置特征与时序信息进行融合,合并之后采用嵌入的方法提取融合特征;
步骤S46:输入标准的MLP层用来得到对应out_size的输出,对应不同的out_size大小,信息会分为两种处理取到,动作信息将会被用来计算对应的损失值,并且更新生成器的参数,奖励信息会被用优化器来更新判别器的参数;
步骤S47:行人历史真实轨迹和采样轨迹均经过共用的Actor_Critic的MLP类结构,具体在实际的训练过程中,利用所用的普通多层感知机;
步骤S48:采用了多处理器并行运行的方式,对应不同的处理器,会初始化对应设定数量的环境个数,随即进行特征分析与融合;
步骤S49:进行采样与轨迹更新,从产生动作中进行判别器打分作为奖励,继续优化训练的策略;
步骤S5:将判别器模型和生成器进行优化,直至达到一种稳态即纳什均衡点;
步骤S6:输入行人轨迹数据至S5得到的训练好的模型;将所观测到的轨迹作为输入,所观测的长度可以是1个五元组【x1,y1,x,y,t=1】至8个五元组组成的序列;
步骤S7:利用训练好的模型进行输出,得到该行人的预测轨迹数据。
2.根据权利要求1所述基于时空特征的行人路径预测方法,其特征在于,还包括:
步骤S13:对于策略π进行优化,将问题归类为最小化JS散度的问题,具体包括:
步骤S131:首先将度量两种策略的占用度量公式转化为特殊的风险期望值,即最小预期风险R Ф (pπ,PπE);
其中,Ф为损失函数,ρπ和ρπE为占用度量,γ为因果熵,π为轨迹,s表示状态空间内的状态元组,a表示基于策略所产生的动作序列;
步骤S132:为了最小化策略度量,可以对风险期望值进行最小化探索,将对于正则化项的推导进一步转化成为对于风险期望函数的推导;
步骤S133:最终得到对于策略更新过程中的自然梯度更新。
3.根据权利要求1所述基于时空特征的行人路径预测方法,其特征在于:
其中,前n个时刻的历史轨迹将会被作为模型的记忆信息输入Mogrifier LSTM网络中,基于上述历史轨迹的绝对坐标信息,提取出对应的输出状态;
将上述输出状态在后续应用到多层感知器中进行特征融合,使用Mogrifier LSTM提取时序特征。
4.根据权利要求1所述基于时空特征的行人路径预测方法,其特征在于:
步骤S3中,在GAIL模型中通过构建专家数据缓冲区,为模型不断地提供行人历史真实轨迹数据分布的学习与计算;
在缓冲区中,将当前时刻的状态作为最后一个时刻,引入变量obs_len观测步长;
在上述过程中,采用了常数填充的方法,给不足特定步长的部分,填补为0。
5.根据权利要求1所述基于时空特征的行人路径预测方法,其特征在于:
所述步骤S5具体包括:
步骤S51:判别器D是一个二分类模型,将在真实轨迹A与生成器所生成的轨迹B所组成的带标签的数据池中随机采样,并且将采样的轨迹作为判别器的输入,判别器将给出即判断输入属于真实轨迹或生成轨迹的分类结果,根据判别器的分类结果,可以与标签真值进行对比并且计算误差;优化器目标为将误差值进行最小化;
步骤S52:对于生成器,其从一个初始化的分布中进行采样,经过生成器的网络的处理并输出,输出将会作为生成的轨迹与真实轨迹组合成为数据池;在每次输入轨迹进入判别器时,都会进行轨迹池的更新;
步骤S53:将生成器和判别器不断优化时,最终达到一种稳态即纳什均衡点,从而获得了能够达到最优的轨迹的生成器模型参数。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202311253071.8A CN117273225B (zh) | 2023-09-26 | 2023-09-26 | 一种基于时空特征的行人路径预测方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202311253071.8A CN117273225B (zh) | 2023-09-26 | 2023-09-26 | 一种基于时空特征的行人路径预测方法 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN117273225A CN117273225A (zh) | 2023-12-22 |
CN117273225B true CN117273225B (zh) | 2024-05-03 |
Family
ID=89210083
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202311253071.8A Active CN117273225B (zh) | 2023-09-26 | 2023-09-26 | 一种基于时空特征的行人路径预测方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN117273225B (zh) |
Citations (10)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN112766561A (zh) * | 2021-01-15 | 2021-05-07 | 东南大学 | 一种基于注意力机制的生成式对抗轨迹预测方法 |
CN113538506A (zh) * | 2021-07-23 | 2021-10-22 | 陕西师范大学 | 基于全局动态场景信息深度建模的行人轨迹预测方法 |
CN114445465A (zh) * | 2022-02-28 | 2022-05-06 | 常州大学 | 一种基于融合逆强化学习的轨迹预测方法 |
CN114611663A (zh) * | 2022-02-21 | 2022-06-10 | 北京航空航天大学 | 一种基于在线更新策略的定制化行人轨迹预测方法 |
KR20220102395A (ko) * | 2021-01-13 | 2022-07-20 | 부경대학교 산학협력단 | 자율주행 차량 군집 운행을 위한 비신호 교차로에서의 강화학습기반 통행 개선을 위한 장치 및 방법 |
WO2022241808A1 (zh) * | 2021-05-19 | 2022-11-24 | 广州中国科学院先进技术研究所 | 一种多机器人轨迹规划方法 |
CN115829171A (zh) * | 2023-02-24 | 2023-03-21 | 山东科技大学 | 一种联合时空信息和社交互动特征的行人轨迹预测方法 |
CN115826601A (zh) * | 2022-11-17 | 2023-03-21 | 中国人民解放军海军航空大学 | 基于逆向强化学习的无人机路径规划方法 |
WO2023155231A1 (zh) * | 2022-02-21 | 2023-08-24 | 东南大学 | 一种高度类人的自动驾驶营运车辆安全驾驶决策方法 |
JP2023132902A (ja) * | 2022-03-11 | 2023-09-22 | 国立研究開発法人産業技術総合研究所 | サンプル効率の良い強化学習 |
-
2023
- 2023-09-26 CN CN202311253071.8A patent/CN117273225B/zh active Active
Patent Citations (10)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
KR20220102395A (ko) * | 2021-01-13 | 2022-07-20 | 부경대학교 산학협력단 | 자율주행 차량 군집 운행을 위한 비신호 교차로에서의 강화학습기반 통행 개선을 위한 장치 및 방법 |
CN112766561A (zh) * | 2021-01-15 | 2021-05-07 | 东南大学 | 一种基于注意力机制的生成式对抗轨迹预测方法 |
WO2022241808A1 (zh) * | 2021-05-19 | 2022-11-24 | 广州中国科学院先进技术研究所 | 一种多机器人轨迹规划方法 |
CN113538506A (zh) * | 2021-07-23 | 2021-10-22 | 陕西师范大学 | 基于全局动态场景信息深度建模的行人轨迹预测方法 |
CN114611663A (zh) * | 2022-02-21 | 2022-06-10 | 北京航空航天大学 | 一种基于在线更新策略的定制化行人轨迹预测方法 |
WO2023155231A1 (zh) * | 2022-02-21 | 2023-08-24 | 东南大学 | 一种高度类人的自动驾驶营运车辆安全驾驶决策方法 |
CN114445465A (zh) * | 2022-02-28 | 2022-05-06 | 常州大学 | 一种基于融合逆强化学习的轨迹预测方法 |
JP2023132902A (ja) * | 2022-03-11 | 2023-09-22 | 国立研究開発法人産業技術総合研究所 | サンプル効率の良い強化学習 |
CN115826601A (zh) * | 2022-11-17 | 2023-03-21 | 中国人民解放军海军航空大学 | 基于逆向强化学习的无人机路径规划方法 |
CN115829171A (zh) * | 2023-02-24 | 2023-03-21 | 山东科技大学 | 一种联合时空信息和社交互动特征的行人轨迹预测方法 |
Also Published As
Publication number | Publication date |
---|---|
CN117273225A (zh) | 2023-12-22 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
Lobo et al. | Evolving spiking neural networks for online learning over drifting data streams | |
CN114519469B (zh) | 一种基于Transformer框架的多变量长序列时间序列预测模型的构建方法 | |
Hu et al. | A framework for probabilistic generic traffic scene prediction | |
CN117574308B (zh) | 基于人工智能的计量芯片异常检测方法及系统 | |
CN112347923A (zh) | 一种基于对抗生成网络的路侧端行人轨迹预测算法 | |
CN109993770B (zh) | 一种自适应时空学习与状态识别的目标跟踪方法 | |
CN114155270A (zh) | 行人轨迹预测方法、装置、设备及存储介质 | |
Bharilya et al. | Machine learning for autonomous vehicle's trajectory prediction: A comprehensive survey, challenges, and future research directions | |
CN111695737A (zh) | 一种基于lstm神经网络的群目标行进趋势预测方法 | |
Bougie et al. | Combining deep reinforcement learning with prior knowledge and reasoning | |
CN111191722B (zh) | 通过计算机训练预测模型的方法及装置 | |
CN115544239A (zh) | 一种基于深度学习模型的布局偏好预测方法 | |
Zhang et al. | Residual memory inference network for regression tracking with weighted gradient harmonized loss | |
CN117747064A (zh) | 一种基于ai的抑郁症临床决策方法及系统 | |
CN111626198A (zh) | 自动驾驶场景下基于Body Pix的行人运动检测方法 | |
CN117273225B (zh) | 一种基于时空特征的行人路径预测方法 | |
Li et al. | Active temporal action detection in untrimmed videos via deep reinforcement learning | |
KR20220014744A (ko) | 강화 학습을 기반으로 한 데이터 전처리 시스템 및 방법 | |
CN113836818B (zh) | 一种基于bp神经网络预测模型的洋流运动预测算法 | |
CN113360772B (zh) | 一种可解释性推荐模型训练方法与装置 | |
CN114612810B (zh) | 一种动态自适应异常姿态识别方法及装置 | |
Liu et al. | Long short-term memory networks based on particle filter for object tracking | |
Bilro et al. | Pedestrian Trajectory Prediction Using LSTM and Sparse Motion Fields | |
CN112735600B (zh) | 基于大数据监测和深度学习级联预测的提前预警方法 | |
KR102617344B1 (ko) | 비지도 학습 기반의 깊이 예측 방법 및 이를 이용하는 시스템 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |