CN115376103A - 一种基于时空图注意力网络的行人轨迹预测方法 - Google Patents

一种基于时空图注意力网络的行人轨迹预测方法 Download PDF

Info

Publication number
CN115376103A
CN115376103A CN202211030137.2A CN202211030137A CN115376103A CN 115376103 A CN115376103 A CN 115376103A CN 202211030137 A CN202211030137 A CN 202211030137A CN 115376103 A CN115376103 A CN 115376103A
Authority
CN
China
Prior art keywords
pedestrian
time
historical
ith
predicted
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202211030137.2A
Other languages
English (en)
Inventor
郭洪艳
刘嫣然
孟庆瑜
李嘉霖
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Jilin University
Original Assignee
Jilin University
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Jilin University filed Critical Jilin University
Priority to CN202211030137.2A priority Critical patent/CN115376103A/zh
Publication of CN115376103A publication Critical patent/CN115376103A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V20/00Scenes; Scene-specific elements
    • G06V20/50Context or environment of the image
    • G06V20/56Context or environment of the image exterior to a vehicle by using sensors mounted on the vehicle
    • G06V20/58Recognition of moving objects or obstacles, e.g. vehicles or pedestrians; Recognition of traffic objects, e.g. traffic signs, traffic lights or roads
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06TIMAGE DATA PROCESSING OR GENERATION, IN GENERAL
    • G06T7/00Image analysis
    • G06T7/20Analysis of motion
    • G06T7/246Analysis of motion using feature-based methods, e.g. the tracking of corners or segments
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06TIMAGE DATA PROCESSING OR GENERATION, IN GENERAL
    • G06T7/00Image analysis
    • G06T7/70Determining position or orientation of objects or cameras
    • G06T7/73Determining position or orientation of objects or cameras using feature-based methods
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/40Extraction of image or video features
    • G06V10/44Local feature extraction by analysis of parts of the pattern, e.g. by detecting edges, contours, loops, corners, strokes or intersections; Connectivity analysis, e.g. of connected components
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/764Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/82Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06TIMAGE DATA PROCESSING OR GENERATION, IN GENERAL
    • G06T2207/00Indexing scheme for image analysis or image enhancement
    • G06T2207/10Image acquisition modality
    • G06T2207/10016Video; Image sequence
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06TIMAGE DATA PROCESSING OR GENERATION, IN GENERAL
    • G06T2207/00Indexing scheme for image analysis or image enhancement
    • G06T2207/20Special algorithmic details
    • G06T2207/20081Training; Learning
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06TIMAGE DATA PROCESSING OR GENERATION, IN GENERAL
    • G06T2207/00Indexing scheme for image analysis or image enhancement
    • G06T2207/20Special algorithmic details
    • G06T2207/20084Artificial neural networks [ANN]
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06TIMAGE DATA PROCESSING OR GENERATION, IN GENERAL
    • G06T2207/00Indexing scheme for image analysis or image enhancement
    • G06T2207/30Subject of image; Context of image processing
    • G06T2207/30241Trajectory
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06TIMAGE DATA PROCESSING OR GENERATION, IN GENERAL
    • G06T2207/00Indexing scheme for image analysis or image enhancement
    • G06T2207/30Subject of image; Context of image processing
    • G06T2207/30248Vehicle exterior or interior
    • G06T2207/30252Vehicle exterior; Vicinity of vehicle

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • General Physics & Mathematics (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Multimedia (AREA)
  • Evolutionary Computation (AREA)
  • Artificial Intelligence (AREA)
  • Computing Systems (AREA)
  • Health & Medical Sciences (AREA)
  • General Health & Medical Sciences (AREA)
  • Software Systems (AREA)
  • Databases & Information Systems (AREA)
  • Medical Informatics (AREA)
  • Biomedical Technology (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Biophysics (AREA)
  • Computational Linguistics (AREA)
  • Data Mining & Analysis (AREA)
  • Molecular Biology (AREA)
  • General Engineering & Computer Science (AREA)
  • Mathematical Physics (AREA)
  • Management, Administration, Business Operations System, And Electronic Commerce (AREA)

Abstract

本发明提供了一种基于时空图注意力网络的行人轨迹预测方法,步骤包括:采集行人轨迹数据构建数据集;数据预处理提取轨迹特征;时间注意力机制计算时间状态特征;空间注意力机制计算时空状态特征;预测未来轨迹,训练预测模型;本方法一方面利用时间注意力机制提取每个行人的时间特征,考虑了每个行人自身过去不同的历史时刻信息对当前预测结果的影响,有效提高预测结果的准确性;另一方面空间注意力机制将上一步时间注意力机制提取的场景中所有行人的时间状态特征作为输入,利用图注意力神经网络对每个行人的相邻行人分配合理的注意力系数从而融合相邻行人的特征信息,模拟行人之间包含社交因素的空间交互作用,保证预测结果的合理性。

Description

一种基于时空图注意力网络的行人轨迹预测方法
技术领域
本发明属于自动驾驶技术领域,涉及一种行人轨迹预测模型建立方法,更加具体来讲,涉及一种基于时空图注意力网络的行人轨迹预测方法。
背景技术
近年来随着人工智能技术的不断发展,自动驾驶领域的研究也在不断地深入。自动驾驶系统主要分为环境感知、决策规划和运动控制三个模块。通过传感器从交通场景中获得道路使用者的位姿信息对其未来运动轨迹进行精准预测,能够提高决策规划系统的合理性和准确性。保证交通场景中道路使用者的自身安全是自动驾驶汽车被普及应用的前提条件,而行人作为弱势道路使用者,通过对目标行人未来运动位置的精准预测,可以减少车辆与行人碰撞事故的发生,提高自动驾驶汽车的行驶安全性。除此之外,根据行人未来的行动轨迹可以帮助自动驾驶系统制定更加合理的行驶策略,解决交通拥堵等问题。因此,研究行人轨迹预测问题对于自动驾驶技术的发展具有重要的现实意义。
行人轨迹预测的任务是根据行人过去一段时间的轨迹,预测其未来时刻的运动位置坐标。行人轨迹预测的挑战性在于行人的运动复杂灵活,难以建立合理的动力学模型,且行人的运动会受到外界环境多样性因素的影响。现有的轨迹预测方法根据建模方法主要分为两类:一类是基于模型的方法,这类方法都依赖于手工函数,无法模拟复杂情景的交互作用并且泛化能力差;另一类是近年来发展迅速的基于深度学习的预测方法。得益于神经网络的广泛应用,其完备的知识体系和丰富的网络模型为提高行人轨迹预测的准确度和合理性提供了必要条件。目前大部分的行人轨迹预测方法只考虑空间上的约束和交互作用,而忽略时间连续性,预测精度不够。
发明内容
针对现有技术存在的问题,为了提升行人轨迹预测精度,本发明提出了一种基于时空图注意力网络的行人轨迹预测方法。
为实现上述目的,本发明是采用如下技术方案实现的:
一种基于时空图注意力网络的行人轨迹预测方法,应用于自动驾驶领域,针对行人横穿马路场景进行行人的行为分析和预测,利用自动驾驶汽车感知系统装备的车载摄像机采集行人信息,其特征在于,具体步骤如下:
步骤一、采集行人轨迹数据构建数据集:
利用自动驾驶汽车感知系统装备的车载摄像机采集车辆行驶过程中前方和两侧的道路视频数据,运用语义分割、图像分类和数据标注与转换技术手段提取行人信息,其中包括道路视频每帧中每个行人在图像坐标系下的坐标值,构建行人轨迹数据集,进一步分为训练数据集和测试数据集;
步骤二、数据预处理提取轨迹特征:
对行人轨迹训练数据集进行预处理,为了在保证预测精度的前提下减小计算量,对输入的数据进行合理采样,提取每个行人的历史观测坐标,定义每个行人在历史时刻t的观测坐标
Figure BDA0003816940970000031
为:
Figure BDA0003816940970000032
其中,i表示场景中第i个行人,t表示历史时刻,Tobs表示观测时域长度,N表示场景中的行人总数量,
Figure BDA0003816940970000033
表示第i个行人在图像坐标系下在历史时刻t沿x轴方向的观测坐标值和沿y轴方向的观测坐标值;
定义每个行人在预测时刻tp的真实坐标
Figure BDA0003816940970000034
为:
Figure BDA0003816940970000035
其中,i表示场景中第i个行人,tp表示预测时刻,Tpre表示预测时域长度,N表示场景中的行人总数量,
Figure BDA0003816940970000036
表示第i个行人在图像坐标系下在预测时刻tp沿x轴方向的真实坐标值和沿y轴方向的真实坐标值;
计算每个行人在历史时刻t与上一个历史时刻t-1的历史相对位置
Figure BDA0003816940970000037
Figure BDA0003816940970000038
Figure BDA0003816940970000039
其中,
Figure BDA00038169409700000310
表示第i个行人在图像坐标系下在历史时刻t沿x轴方向的历史相对值和沿y轴方向的历史相对值;
利用嵌入函数φ(·)对历史相对位置
Figure BDA0003816940970000041
进行升维,得到每个行人在历史时刻t的嵌入向量
Figure BDA0003816940970000042
Figure BDA0003816940970000043
其中,
Figure BDA0003816940970000044
表示第i个行人在历史时刻t的嵌入向量且维数为16,φ(·)表示嵌入函数,嵌入函数由全连接层组成,We表示可学习的全连接网络参数,网络的输入特征维数为2,输出特征维数为16,层数为1;
接着将每个行人在历史时刻t的嵌入向量
Figure BDA0003816940970000045
输入到长短期记忆网络LSTM中,计算得到每个行人在历史时刻t的隐藏状态特征
Figure BDA0003816940970000046
Figure BDA0003816940970000047
其中,
Figure BDA0003816940970000048
表示第i个行人在历史时刻t的隐藏状态特征且维数为32,LSTM(·)由长短期记忆网络单元组成,Wen为可学习得到的长短期记忆网络权重参数,网络的输入特征维数为16,输出特征维数为32,隐藏特征维数为32,层数为1;
步骤三、时间注意力机制计算时间状态特征:
通过时间注意力机制计算包含时间相关性的时间状态特征
Figure BDA0003816940970000049
利用时间注意力机制计算每个行人的其他历史时刻r,r∈{1,...,t}隐藏状态特征对历史时刻t隐藏状态特征的时间注意力系数,提取每个行人的历史轨迹的时间相关性,具体过程如下:
首先,输入每个行人的隐藏状态特征
Figure BDA00038169409700000410
计算第i个行人在其他历史时刻r的隐藏状态特征
Figure BDA00038169409700000411
对历史时刻t的隐藏状态特征
Figure BDA00038169409700000412
的时间注意力系数
Figure BDA0003816940970000051
计算过程如下:
Figure BDA0003816940970000052
Figure BDA0003816940970000053
其中,f(·)表示余弦相似性函数,用来计算相似性值,
Figure BDA0003816940970000054
表示第i个行人在其他历史时刻r的隐藏状态特征,softmax(·)表示归一化指数函数,
Figure BDA0003816940970000055
表示第i个行人在其他历史时刻r的时间注意力系数;
接着,利用第i个行人在其他历史时刻r的时间注意力系数
Figure BDA0003816940970000056
计算第i个行人在历史时刻t的时间状态特征
Figure BDA0003816940970000057
计算过程如下:
Figure BDA0003816940970000058
其中,
Figure BDA0003816940970000059
表示第i个行人在历史时刻t的时间状态特征且维数为32;
步骤四、空间注意力机制计算时空状态特征:
空间注意力机制将每个行人的时间状态特征输入到图注意力网络中,场景中的所有行人对应图结构中的各个节点,行人之间的交互对应图结构中的各个边,基于图注意力网络融合第i个行人在历史时刻t与相邻行人的轨迹交互特征,得到的时空状态特征即包含了时间相关性,也包含了空间交互性,具体过程如下:
首先,定义在图结构中,第i个行人的相邻行人集合为Ni,将所有行人的时间状态特征输入到图注意力网络中,计算在历史时刻t同一场景中第j个行人对第i个行人的空间注意力系数
Figure BDA00038169409700000510
Figure BDA0003816940970000061
其中,j∈{1,...,N}且j∈Ni
Figure BDA0003816940970000062
表示在历史时刻t同一场景中第j个行人对第i个行人的空间注意力系数,Ni表示第i个行人的相邻行人集合,
Figure BDA0003816940970000063
表示第j个行人在历史时刻t的时间状态特征,
Figure BDA0003816940970000064
表示第i个行人的任一相邻的第m个行人在历史时刻t的时间状态特征,m∈{1,...,N}且m∈Ni,LeakyRelu(·)表示非线性激活函数,a表示可学习的模型参数,W表示可学习的节点特征变换权重参数,||表示拼接操作;
其次,在计算得到在历史时刻t第j个行人对第i个行人的空间注意力系数
Figure BDA0003816940970000065
后,利用图注意力网络计算第i个行人在历史时刻t融合相邻行人的空间交互特征的时空状态特征
Figure BDA0003816940970000066
Figure BDA0003816940970000067
其中,
Figure BDA0003816940970000068
表示第i个行人在历史时刻t的时空状态特征且维数为32,σ(·)表示非线性函数;
步骤五、预测未来轨迹,训练预测模型:
将第i个行人在历史时刻Tobs的时空状态特征
Figure BDA0003816940970000069
和隐藏状态特征
Figure BDA00038169409700000610
进行拼接,为了模拟真实场景中行人运动的不确定性,加入服从正态分布的噪声向量z,得到第i个行人在历史时刻Tobs的轨迹解码特征
Figure BDA00038169409700000611
计算过程如下:
Figure BDA00038169409700000612
其中,z表示噪声向量且维数为16,
Figure BDA0003816940970000071
表示第i个行人在历史时刻Tobs的轨迹解码特征且维数为80,||表示拼接操作;
利用由长短期记忆网络组成的解码器Decoder来计算得到未来的预测相对位置,将第i个行人在历史时刻Tobs的轨迹解码特征
Figure BDA0003816940970000072
作为解码器Decoder的输入,计算得到第i个行人在预测时刻Tobs+1的轨迹解码特征
Figure BDA0003816940970000073
计算过程如下:
Figure BDA0003816940970000074
其中,
Figure BDA0003816940970000075
表示第i个行人在预测时刻Tobs+1的轨迹解码特征且维数为80,
Figure BDA0003816940970000076
表示第i个行人在历史时刻Tobs的嵌入向量,Wd表示可学习的网络权重参数;
将计算得到第i个行人在预测时刻Tobs+1的轨迹解码特征
Figure BDA0003816940970000077
通过全连接层进行降维,得到维数为2的第i个行人在预测时刻Tobs+1的预测相对位置
Figure BDA0003816940970000078
Figure BDA0003816940970000079
其中,
Figure BDA00038169409700000710
表示第i个行人在预测时刻Tobs+1的预测相对位置且维数为2,δ(·)表示全连接层网络,Wd表示可学习的网络参数,网络的输入特征维数为80,输出特征维数为2,层数为1;
将第i个行人在预测时刻Tobs+1的预测相对位置
Figure BDA00038169409700000711
与历史时刻Tobs的观测坐标
Figure BDA00038169409700000712
相加即可得到第i个行人在预测时刻Tobs+1的预测坐标
Figure BDA00038169409700000713
计算过程如下:
Figure BDA0003816940970000081
其中,
Figure BDA0003816940970000082
表示第i个行人在预测时刻Tobs+1的预测坐标,
Figure BDA0003816940970000083
第i个行人在图像坐标系下在预测时刻Tobs+1沿x轴方向的预测坐标值和沿y轴方向的预测坐标值;
在得到第i个行人在预测时刻Tobs+1的预测坐标之后,下一预测时刻Tobs+2的预测坐标的利用同样的方法计算得到,依次迭代,即可计算得到各个预测时刻的预测坐标
Figure BDA0003816940970000084
Figure BDA0003816940970000085
在得到每个行人的各个预测时刻的预测坐标后,考虑到生成的合理的行人预测轨迹可能不止一条,结合步骤二中每个行人的预测时刻的真实坐标
Figure BDA0003816940970000086
构造多样损失函数Lvariety,通过采样生成多个轨迹样本,计算其中欧式距离最小的样本作为最佳的预测轨迹,计算方法如下:
Figure BDA0003816940970000087
其中,k是一个初始设定为20的超参数,表示随机抽样生成的样本个数,
Figure BDA0003816940970000088
表示根据第i个行人的预测坐标
Figure BDA0003816940970000089
随机抽样生成的轨迹样本,Lvariety表示多样损失函数;
本发明的模型训练是在pytorch深度学习框架下进行的,使用Adam优化器进行优化,学习率设置为0.01,批大小设置为64,利用训练数据集对方法中所涉及的各种网络的权重参数和模型参数进行训练,计算多样损失函数Lvariety选出最佳的预测轨迹,保存相应的各种权重参数和模型参数,得到训练好的预测模型,然后用测试数据集执行上述步骤二至步骤五来预测行人未来轨迹。
与现有技术相比本发明的有益效果是:
本发明公开了一种基于时空图注意力网络的行人轨迹预测方法,对行人交互作用建模同时考虑时间相关性和空间交互性,一方面针对现有的轨迹预测方法往往忽略行人自身的时间相关性的问题,利用时间注意力机制提取每个行人的时间特征,考虑了每个行人自身过去不同的历史时刻信息对当前预测结果的影响,有效提高预测结果的准确性;
另一方面空间注意力机制将上一步时间注意力机制提取的场景中所有行人的时间状态特征作为输入,利用图注意力神经网络对每个行人的相邻行人分配合理的注意力系数从而融合相邻行人的特征信息,模拟行人之间包含社交因素的空间交互作用,保证预测结果的合理性;
本发明考虑了行人空间交互作用的连续性,实现了时空交互信息的有效融合,可以最大化利用行人轨迹数据中的有效信息,提高行人轨迹预测结果的准确度和合理性。
附图说明
图1为本发明所述的一种基于时空图注意力网络的行人轨迹预测方法的流程示意图;
图2为本方法步骤三中的时间注意力机制的原理示意图。
具体实施方式
下面结合附图对本发明作详细的描述:
本发明提出了一种基于时空图注意力网络的行人轨迹预测方法,如图1所示为本发明的流程示意图,具体的方法步骤如下:
步骤一、采集行人轨迹数据构建数据集:
利用自动驾驶汽车感知系统装备的车载摄像机采集车辆行驶过程中前方和两侧的道路视频数据,运用语义分割、图像分类和数据标注与转换技术手段提取行人信息,其中包括道路视频每帧中每个行人在图像坐标系下的坐标值,构建行人轨迹数据集,进一步分为训练数据集和测试数据集;
步骤二、数据预处理提取轨迹特征:
对行人轨迹训练数据集进行预处理,为了在保证预测精度的前提下减小计算量,对输入的数据进行合理采样,提取每个行人的历史观测坐标,定义每个行人在历史时刻t的观测坐标
Figure BDA0003816940970000101
为:
Figure BDA0003816940970000102
其中,i表示场景中第i个行人,t表示历史时刻,Tobs表示观测时域长度,N表示场景中的行人总数量,
Figure BDA0003816940970000103
表示第i个行人在图像坐标系下在历史时刻t沿x轴方向的观测坐标值和沿y轴方向的观测坐标值;
定义每个行人在预测时刻tp的真实坐标
Figure BDA0003816940970000104
为:
Figure BDA0003816940970000111
其中,i表示场景中第i个行人,tp表示预测时刻,Tpre表示预测时域长度,N表示场景中的行人总数量,
Figure BDA0003816940970000112
表示第i个行人在图像坐标系下在预测时刻tp沿x轴方向的真实坐标值和沿y轴方向的真实坐标值;
经过采样后的数据帧频为2.5fps,即每一帧的时长为0.4s,设定观测时域帧数为8帧,预测时域帧数为12帧,即根据历史观测3.2s的轨迹信息来预测未来4.8s的轨迹信息;
计算每个行人在历史时刻t与上一个历史时刻t-1的历史相对位置
Figure BDA0003816940970000113
Figure BDA0003816940970000114
Figure BDA0003816940970000115
其中,
Figure BDA0003816940970000116
表示第i个行人在图像坐标系下在历史时刻t沿x轴方向的历史相对值和沿y轴方向的历史相对值;
利用嵌入函数φ(·)对历史相对位置
Figure BDA0003816940970000117
进行升维,得到每个行人在历史时刻t的嵌入向量
Figure BDA0003816940970000118
Figure BDA0003816940970000119
其中,
Figure BDA00038169409700001110
表示第i个行人在历史时刻t的嵌入向量且维数为16,φ(·)表示嵌入函数,嵌入函数由全连接层组成,We表示可学习的全连接网络参数,网络的输入特征维数为2,输出特征维数为16,层数为1;
接着将每个行人在历史时刻t的嵌入向量
Figure BDA0003816940970000121
输入到长短期记忆网络LSTM中,计算得到每个行人在历史时刻t的隐藏状态特征
Figure BDA0003816940970000122
Figure BDA0003816940970000123
其中,
Figure BDA0003816940970000124
表示第i个行人在历史时刻t的隐藏状态特征且维数为32,LSTM(·)由长短期记忆网络单元组成,Wen为可学习得到的长短期记忆网络权重参数,网络的输入特征维数为16,输出特征维数为32,隐藏特征维数为32,层数为1;
步骤三、时间注意力机制计算时间状态特征:
通过时间注意力机制计算包含时间相关性的时间状态特征
Figure BDA0003816940970000125
时间注意力机制的原理示意图如图2所示,利用时间注意力机制计算每个行人的其他历史时刻r,r∈{1,...,t}隐藏状态特征对历史时刻t隐藏状态特征的时间注意力系数,提取每个行人的历史轨迹的时间相关性,具体过程如下:
首先,输入每个行人的隐藏状态特征
Figure BDA0003816940970000126
计算第i个行人在其他历史时刻r的隐藏状态特征
Figure BDA0003816940970000127
对历史时刻t的隐藏状态特征
Figure BDA0003816940970000128
的时间注意力系数
Figure BDA0003816940970000129
计算过程如下:
Figure BDA00038169409700001210
Figure BDA00038169409700001211
其中,f(·)表示余弦相似性函数,用来计算相似性值,
Figure BDA00038169409700001212
表示第i个行人在其他历史时刻r的隐藏状态特征,softmax(·)表示归一化指数函数,
Figure BDA00038169409700001213
表示第i个行人在其他历史时刻r的时间注意力系数;
接着,利用第i个行人在其他历史时刻r的时间注意力系数
Figure BDA0003816940970000131
计算第i个行人在历史时刻t的时间状态特征
Figure BDA0003816940970000132
计算过程如下:
Figure BDA0003816940970000133
其中,
Figure BDA0003816940970000134
表示第i个行人在历史时刻t的时间状态特征且维数为32;
步骤四、空间注意力机制计算时空状态特征:
空间注意力机制将每个行人的时间状态特征输入到图注意力网络中,场景中的所有行人对应图结构中的各个节点,行人之间的交互对应图结构中的各个边,基于图注意力网络融合第i个行人在历史时刻t与相邻行人的轨迹交互特征,得到的时空状态特征即包含了时间相关性,也包含了空间交互性,具体过程如下:
首先,定义在图结构中,第i个行人的相邻行人集合为Ni,将所有行人的时间状态特征输入到图注意力网络中,计算在历史时刻t同一场景中第j个行人对第i个行人的空间注意力系数
Figure BDA0003816940970000135
Figure BDA0003816940970000136
其中,j∈{1,...,N}且j∈Ni
Figure BDA0003816940970000137
表示在历史时刻t同一场景中第j个行人对第i个行人的空间注意力系数,Ni表示第i个行人的相邻行人集合,
Figure BDA0003816940970000138
表示第j个行人在历史时刻t的时间状态特征,
Figure BDA0003816940970000139
表示第i个行人的任一相邻的第m个行人在历史时刻t的时间状态特征,m∈{1,...,N}且m∈Ni,LeakyRelu(·)表示非线性激活函数,a表示可学习的模型参数,W表示可学习的节点特征变换权重参数,||表示拼接操作;
其次,在计算得到在历史时刻t第j个行人对第i个行人的空间注意力系数
Figure BDA0003816940970000141
后,利用图注意力网络计算第i个行人在历史时刻t融合相邻行人的空间交互特征的时空状态特征
Figure BDA0003816940970000142
Figure BDA0003816940970000143
其中,
Figure BDA0003816940970000144
表示第i个行人在历史时刻t的时空状态特征且维数为32,σ(·)表示非线性函数;
步骤五、预测未来轨迹,训练预测模型:
将第i个行人在历史时刻Tobs的时空状态特征
Figure BDA0003816940970000145
和隐藏状态特征
Figure BDA0003816940970000146
进行拼接,为了模拟真实场景中行人运动的不确定性,加入服从正态分布的噪声向量z,得到第i个行人在历史时刻Tobs的轨迹解码特征
Figure BDA0003816940970000147
计算过程如下:
Figure BDA0003816940970000148
其中,z表示噪声向量且维数为16,
Figure BDA0003816940970000149
表示第i个行人在历史时刻Tobs的轨迹解码特征且维数为80,||表示拼接操作;
利用由长短期记忆网络组成的解码器Decoder来计算得到未来的预测相对位置,将第i个行人在历史时刻Tobs的轨迹解码特征
Figure BDA00038169409700001410
作为解码器Decoder的输入,计算得到第i个行人在预测时刻Tobs+1的轨迹解码特征
Figure BDA00038169409700001411
计算过程如下:
Figure BDA00038169409700001412
其中,
Figure BDA0003816940970000151
表示第i个行人在预测时刻Tobs+1的轨迹解码特征且维数为80,
Figure BDA0003816940970000152
表示第i个行人在历史时刻Tobs的嵌入向量,Wd表示可学习的网络权重参数;
将计算得到第i个行人在预测时刻Tobs+1的轨迹解码特征
Figure BDA0003816940970000153
通过全连接层进行降维,得到维数为2的第i个行人在预测时刻Tobs+1的预测相对位置
Figure BDA0003816940970000154
Figure BDA0003816940970000155
其中,
Figure BDA0003816940970000156
表示第i个行人在预测时刻Tobs+1的预测相对位置且维数为2,δ(·)表示全连接层网络,Wd表示可学习的网络参数,网络的输入特征维数为80,输出特征维数为2,层数为1;
将第i个行人在预测时刻Tobs+1的预测相对位置
Figure BDA0003816940970000157
与历史时刻Tobs的观测坐标
Figure BDA0003816940970000158
相加即可得到第i个行人在预测时刻Tobs+1的预测坐标
Figure BDA0003816940970000159
计算过程如下:
Figure BDA00038169409700001510
其中,
Figure BDA00038169409700001511
表示第i个行人在预测时刻Tobs+1的预测坐标,
Figure BDA00038169409700001512
第i个行人在图像坐标系下在预测时刻Tobs+1沿x轴方向的预测坐标值和沿y轴方向的预测坐标值;
在得到第i个行人在预测时刻Tobs+1的预测坐标之后,下一预测时刻Tobs+2的预测坐标的利用同样的方法计算得到,依次迭代,即可计算得到各个预测时刻的预测坐标
Figure BDA00038169409700001513
Figure BDA0003816940970000161
在得到每个行人的各个预测时刻的预测坐标后,考虑到生成的合理的行人预测轨迹可能不止一条,结合步骤二中每个行人的预测时刻的真实坐标
Figure BDA0003816940970000162
构造多样损失函数Lvariety,通过采样生成多个轨迹样本,计算其中欧式距离最小的样本作为最佳的预测轨迹,计算方法如下:
Figure BDA0003816940970000163
其中,k是一个初始设定为20的超参数,表示随机抽样生成的样本个数,
Figure BDA0003816940970000164
表示根据第i个行人的预测坐标
Figure BDA0003816940970000165
随机抽样生成的轨迹样本,Lvariety表示多样损失函数;
本发明的模型训练是在pytorch深度学习框架下进行的,使用Adam优化器进行优化,学习率设置为0.01,批大小设置为64,利用训练数据集对方法中所涉及的各种网络的权重参数和模型参数进行训练,计算多样损失函数Lvariety选出最佳的预测轨迹,保存相应的各种权重参数和模型参数,得到训练好的预测模型,然后用测试数据集执行上述步骤二至步骤五来预测行人未来轨迹。

Claims (1)

1.一种基于时空图注意力网络的行人轨迹预测方法,应用于自动驾驶领域,针对行人横穿马路场景进行行人的行为分析和预测,利用自动驾驶汽车感知系统装备的车载摄像机采集行人信息,其特征在于,具体步骤如下:
步骤一、采集行人轨迹数据构建数据集:
利用自动驾驶汽车感知系统装备的车载摄像机采集车辆行驶过程中前方和两侧的道路视频数据,运用语义分割、图像分类和数据标注与转换技术手段提取行人信息,其中包括道路视频每帧中每个行人在图像坐标系下的坐标值,构建行人轨迹数据集,进一步分为训练数据集和测试数据集;
步骤二、数据预处理提取轨迹特征:
对行人轨迹训练数据集进行预处理,为了在保证预测精度的前提下减小计算量,对输入的数据进行合理采样,提取每个行人的历史观测坐标,定义每个行人在历史时刻t的观测坐标
Figure FDA0003816940960000011
为:
Figure FDA0003816940960000012
其中,i表示场景中第i个行人,t表示历史时刻,Tobs表示观测时域长度,N表示场景中的行人总数量,
Figure FDA0003816940960000013
表示第i个行人在图像坐标系下在历史时刻t沿x轴方向的观测坐标值和沿y轴方向的观测坐标值;
定义每个行人在预测时刻tp的真实坐标
Figure FDA0003816940960000014
为:
Figure FDA0003816940960000015
其中,i表示场景中第i个行人,tp表示预测时刻,Tpre表示预测时域长度,N表示场景中的行人总数量,
Figure FDA0003816940960000016
表示第i个行人在图像坐标系下在预测时刻tp沿x轴方向的真实坐标值和沿y轴方向的真实坐标值;
计算每个行人在历史时刻t与上一个历史时刻t-1的历史相对位置
Figure FDA0003816940960000017
Figure FDA0003816940960000021
Figure FDA0003816940960000022
其中,
Figure FDA0003816940960000023
表示第i个行人在图像坐标系下在历史时刻t沿x轴方向的历史相对值和沿y轴方向的历史相对值;
利用嵌入函数φ(·)对历史相对位置
Figure FDA0003816940960000024
进行升维,得到每个行人在历史时刻t的嵌入向量
Figure FDA0003816940960000025
Figure FDA0003816940960000026
其中,
Figure FDA0003816940960000027
表示第i个行人在历史时刻t的嵌入向量且维数为16,φ(·)表示嵌入函数,嵌入函数由全连接层组成,We表示可学习的全连接网络参数,网络的输入特征维数为2,输出特征维数为16,层数为1;
接着将每个行人在历史时刻t的嵌入向量
Figure FDA0003816940960000028
输入到长短期记忆网络LSTM中,计算得到每个行人在历史时刻t的隐藏状态特征
Figure FDA0003816940960000029
Figure FDA00038169409600000210
其中,
Figure FDA00038169409600000211
表示第i个行人在历史时刻t的隐藏状态特征且维数为32,LSTM(·)由长短期记忆网络单元组成,Wen为可学习得到的长短期记忆网络权重参数,网络的输入特征维数为16,输出特征维数为32,隐藏特征维数为32,层数为1;
步骤三、时间注意力机制计算时间状态特征:
通过时间注意力机制计算包含时间相关性的时间状态特征
Figure FDA00038169409600000212
利用时间注意力机制计算每个行人的其他历史时刻r,r∈{1,...,t}隐藏状态特征对历史时刻t隐藏状态特征的时间注意力系数,提取每个行人的历史轨迹的时间相关性,具体过程如下:
首先,输入每个行人的隐藏状态特征
Figure FDA0003816940960000031
计算第i个行人在其他历史时刻r的隐藏状态特征
Figure FDA0003816940960000032
对历史时刻t的隐藏状态特征
Figure FDA0003816940960000033
的时间注意力系数
Figure FDA0003816940960000034
计算过程如下:
Figure FDA0003816940960000035
Figure FDA0003816940960000036
其中,f(·)表示余弦相似性函数,用来计算相似性值,
Figure FDA0003816940960000037
表示第i个行人在其他历史时刻r的隐藏状态特征,softmax(·)表示归一化指数函数,
Figure FDA0003816940960000038
表示第i个行人在其他历史时刻r的时间注意力系数;
接着,利用第i个行人在其他历史时刻r的时间注意力系数
Figure FDA0003816940960000039
计算第i个行人在历史时刻t的时间状态特征
Figure FDA00038169409600000310
计算过程如下:
Figure FDA00038169409600000311
其中,
Figure FDA00038169409600000312
表示第i个行人在历史时刻t的时间状态特征且维数为32;
步骤四、空间注意力机制计算时空状态特征:
空间注意力机制将每个行人的时间状态特征输入到图注意力网络中,场景中的所有行人对应图结构中的各个节点,行人之间的交互对应图结构中的各个边,基于图注意力网络融合第i个行人在历史时刻t与相邻行人的轨迹交互特征,得到的时空状态特征即包含了时间相关性,也包含了空间交互性,具体过程如下:
首先,定义在图结构中,第i个行人的相邻行人集合为Ni,将所有行人的时间状态特征输入到图注意力网络中,计算在历史时刻t同一场景中第j个行人对第i个行人的空间注意力系数
Figure FDA0003816940960000041
Figure FDA0003816940960000042
其中,j∈{1,...,N}且j∈Ni
Figure FDA0003816940960000043
表示在历史时刻t同一场景中第j个行人对第i个行人的空间注意力系数,Ni表示第i个行人的相邻行人集合,
Figure FDA0003816940960000044
表示第j个行人在历史时刻t的时间状态特征,
Figure FDA0003816940960000045
表示第i个行人的任一相邻的第m个行人在历史时刻t的时间状态特征,m∈{1,...,N}且m∈Ni,LeakyRelu(·)表示非线性激活函数,a表示可学习的模型参数,W表示可学习的节点特征变换权重参数,||表示拼接操作;
其次,在计算得到在历史时刻t第j个行人对第i个行人的空间注意力系数
Figure FDA0003816940960000046
后,利用图注意力网络计算第i个行人在历史时刻t融合相邻行人的空间交互特征的时空状态特征
Figure FDA0003816940960000047
Figure FDA0003816940960000048
其中,
Figure FDA0003816940960000049
表示第i个行人在历史时刻t的时空状态特征且维数为32,σ(·)表示非线性函数;
步骤五、预测未来轨迹,训练预测模型:
将第i个行人在历史时刻Tobs的时空状态特征
Figure FDA00038169409600000410
和隐藏状态特征
Figure FDA00038169409600000411
进行拼接,为了模拟真实场景中行人运动的不确定性,加入服从正态分布的噪声向量z,得到第i个行人在历史时刻Tobs的轨迹解码特征
Figure FDA00038169409600000412
计算过程如下:
Figure FDA0003816940960000051
其中,z表示噪声向量且维数为16,
Figure FDA0003816940960000052
表示第i个行人在历史时刻Tobs的轨迹解码特征且维数为80,||表示拼接操作;
利用由长短期记忆网络组成的解码器Decoder来计算得到未来的预测相对位置,将第i个行人在历史时刻Tobs的轨迹解码特征
Figure FDA0003816940960000053
作为解码器Decoder的输入,计算得到第i个行人在预测时刻Tobs+1的轨迹解码特征
Figure FDA0003816940960000054
计算过程如下:
Figure FDA0003816940960000055
其中,
Figure FDA0003816940960000056
表示第i个行人在预测时刻Tobs+1的轨迹解码特征且维数为80,
Figure FDA0003816940960000057
表示第i个行人在历史时刻Tobs的嵌入向量,Wd表示可学习的网络权重参数;
将计算得到第i个行人在预测时刻Tobs+1的轨迹解码特征
Figure FDA0003816940960000058
通过全连接层进行降维,得到维数为2的第i个行人在预测时刻Tobs+1的预测相对位置
Figure FDA0003816940960000059
Figure FDA00038169409600000510
其中,
Figure FDA00038169409600000511
表示第i个行人在预测时刻Tobs+1的预测相对位置且维数为2,δ(·)表示全连接层网络,Wd表示可学习的网络参数,网络的输入特征维数为80,输出特征维数为2,层数为1;
将第i个行人在预测时刻Tobs+1的预测相对位置
Figure FDA00038169409600000512
与历史时刻Tobs的观测坐标
Figure FDA00038169409600000513
相加即可得到第i个行人在预测时刻Tobs+1的预测坐标
Figure FDA0003816940960000061
计算过程如下:
Figure FDA0003816940960000062
其中,
Figure FDA0003816940960000063
表示第i个行人在预测时刻Tobs+1的预测坐标,
Figure FDA0003816940960000064
第i个行人在图像坐标系下在预测时刻Tobs+1沿x轴方向的预测坐标值和沿y轴方向的预测坐标值;
在得到第i个行人在预测时刻Tobs+1的预测坐标之后,下一预测时刻Tobs+2的预测坐标的利用同样的方法计算得到,依次迭代,即可计算得到各个预测时刻的预测坐标
Figure FDA0003816940960000065
Figure FDA0003816940960000066
在得到每个行人的各个预测时刻的预测坐标后,考虑到生成的合理的行人预测轨迹可能不止一条,结合步骤二中每个行人的预测时刻的真实坐标
Figure FDA0003816940960000067
构造多样损失函数Lvariety,通过采样生成多个轨迹样本,计算其中欧式距离最小的样本作为最佳的预测轨迹,计算方法如下:
Figure FDA0003816940960000068
其中,k是一个初始设定为20的超参数,表示随机抽样生成的样本个数,
Figure FDA0003816940960000069
表示根据第i个行人的预测坐标
Figure FDA00038169409600000610
随机抽样生成的轨迹样本,Lvariety表示多样损失函数;
本方法的模型训练是在pytorch深度学习框架下进行的,使用Adam优化器进行优化,学习率设置为0.01,批大小设置为64,利用训练数据集对方法中所涉及的各种网络的权重参数和模型参数进行训练,计算多样损失函数Lvariety选出最佳的预测轨迹,保存相应的各种权重参数和模型参数,得到训练好的预测模型,然后用测试数据集执行上述步骤二至步骤五来预测行人未来轨迹。
CN202211030137.2A 2022-08-26 2022-08-26 一种基于时空图注意力网络的行人轨迹预测方法 Pending CN115376103A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202211030137.2A CN115376103A (zh) 2022-08-26 2022-08-26 一种基于时空图注意力网络的行人轨迹预测方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202211030137.2A CN115376103A (zh) 2022-08-26 2022-08-26 一种基于时空图注意力网络的行人轨迹预测方法

Publications (1)

Publication Number Publication Date
CN115376103A true CN115376103A (zh) 2022-11-22

Family

ID=84067343

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202211030137.2A Pending CN115376103A (zh) 2022-08-26 2022-08-26 一种基于时空图注意力网络的行人轨迹预测方法

Country Status (1)

Country Link
CN (1) CN115376103A (zh)

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115829171A (zh) * 2023-02-24 2023-03-21 山东科技大学 一种联合时空信息和社交互动特征的行人轨迹预测方法
CN116882148A (zh) * 2023-07-03 2023-10-13 成都信息工程大学 一种基于空间社会力图神经网络的行人轨迹预测方法及系统

Cited By (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115829171A (zh) * 2023-02-24 2023-03-21 山东科技大学 一种联合时空信息和社交互动特征的行人轨迹预测方法
CN116882148A (zh) * 2023-07-03 2023-10-13 成都信息工程大学 一种基于空间社会力图神经网络的行人轨迹预测方法及系统
CN116882148B (zh) * 2023-07-03 2024-01-30 成都信息工程大学 基于空间社会力图神经网络的行人轨迹预测方法及系统

Similar Documents

Publication Publication Date Title
CN110164128B (zh) 一种城市级智能交通仿真系统
CN112965499B (zh) 基于注意力模型和深度强化学习的无人车行驶决策方法
Zhao et al. A spatial-temporal attention model for human trajectory prediction.
CN115376103A (zh) 一种基于时空图注意力网络的行人轨迹预测方法
CN109131348B (zh) 一种基于生成式对抗网络的智能车驾驶决策方法
CN110991027A (zh) 一种基于虚拟场景训练的机器人模仿学习方法
CN113094357B (zh) 一种基于时空注意力机制的交通缺失数据补全方法
CN114802296A (zh) 一种基于动态交互图卷积的车辆轨迹预测方法
CN114372116B (zh) 一种基于lstm和时空注意力机制的车辆轨迹预测方法
CN112734808B (zh) 一种车辆行驶环境下易受伤害道路使用者的轨迹预测方法
CN111597961B (zh) 面向智能驾驶的移动目标轨迹预测方法、系统、装置
CN115829171B (zh) 一种联合时空信息和社交互动特征的行人轨迹预测方法
CN108791302B (zh) 驾驶员行为建模系统
CN116307152A (zh) 时空交互式动态图注意力网络的交通预测方法
CN115438856A (zh) 基于时空交互特征和终点信息的行人轨迹预测方法
Liu et al. A method for short-term traffic flow forecasting based on GCN-LSTM
CN116595871A (zh) 基于动态时空交互图的车辆轨迹预测建模方法与装置
CN112927507B (zh) 一种基于LSTM-Attention的交通流量预测方法
CN114596726A (zh) 基于可解释时空注意力机制的停车泊位预测方法
CN115331460B (zh) 一种基于深度强化学习的大规模交通信号控制方法及装置
CN116434569A (zh) 基于stnr模型的交通流量预测方法及系统
CN115512214A (zh) 一种基于因果注意力的室内视觉导航方法
CN111443701A (zh) 基于异构深度学习的无人驾驶车辆/机器人行为规划方法
Zhao et al. End-to-end spatiotemporal attention model for autonomous driving
Zhang et al. A virtual end-to-end learning system for robot navigation based on temporal dependencies

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination