CN114117259A - 一种基于双重注意力机制的轨迹预测方法及装置 - Google Patents
一种基于双重注意力机制的轨迹预测方法及装置 Download PDFInfo
- Publication number
- CN114117259A CN114117259A CN202111449388.XA CN202111449388A CN114117259A CN 114117259 A CN114117259 A CN 114117259A CN 202111449388 A CN202111449388 A CN 202111449388A CN 114117259 A CN114117259 A CN 114117259A
- Authority
- CN
- China
- Prior art keywords
- attention
- target
- time
- module
- graph
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F16/00—Information retrieval; Database structures therefor; File system structures therefor
- G06F16/90—Details of database functions independent of the retrieved data types
- G06F16/95—Retrieval from the web
- G06F16/953—Querying, e.g. by the use of web search engines
- G06F16/9537—Spatial or temporal dependent retrieval, e.g. spatiotemporal queries
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- Databases & Information Systems (AREA)
- General Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- General Health & Medical Sciences (AREA)
- Computational Linguistics (AREA)
- Evolutionary Computation (AREA)
- Biophysics (AREA)
- Molecular Biology (AREA)
- Computing Systems (AREA)
- Biomedical Technology (AREA)
- Artificial Intelligence (AREA)
- Mathematical Physics (AREA)
- Software Systems (AREA)
- Life Sciences & Earth Sciences (AREA)
- Health & Medical Sciences (AREA)
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
Abstract
本发明公开了一种基于双重注意力机制的轨迹预测方法及装置。方法包括:获取观测区域内多个目标在上一时间段的轨迹数据;构建图结构数据,图结构数据中节点与目标一一对应;图注意力网络模组被配置为提取目标之间的空间交互信息以及基于所述空间交互信息更新图结构数据的节点特征;时间注意力网络模组从更新后的目标特征中提取时间交互信息以及基于提取的时间交互信息获取目标在下一时间段的预测轨迹。实现了空间交互信息、时间交互信息的融合即获得了时空交互信息,能够更有效地对交通场景下目标的运动模式建模和预测,以获得更好的目标轨迹预测结果,能并行预测多个目标轨迹,加快轨迹预测速度。
Description
技术领域
本发明涉及轨迹预测领域,特别是涉及一种基于双重注意力机制的轨迹预测方法及装置。
背景技术
移动目标(如行人、汽车、无人驾驶智能汽车、移动机器人等)的轨迹数据中不仅含有丰富的时间特征信息,同一时刻相同场景下的目标之间还具有复杂的空间交互信息,不同时刻的不同目标之间的关联信息定义为时空关联信息。对于某一目标来说,与周围邻近目标的空间交互信息、以及时空关联信息可能会影响到该目标的未来轨迹。但现有技术中,往往都是只是基于目标轨迹数据中的时间特征信息进行轨迹预测,预测精度并不高。
发明内容
本发明旨在至少解决现有技术中存在的技术问题,特别创新地提出了一种基于双重注意力机制的轨迹预测方法及装置。
为了实现本发明的上述目的,根据本发明的第一个方面,本发明提供了一种基于双重注意力机制的轨迹预测方法,包括:获取观测区域内多个目标在上一时间段的轨迹数据;对目标的轨迹数据进行预处理,将预处理结果作为目标特征的初始值;构建图结构数据,所述图结构数据中节点与目标一一对应,所述节点特征的初始值为对应目标的目标特征的初始值,连接相邻节点形成图结构数据的边;将所述图结构数据输入图注意力网络模组,所述图注意力网络模组被配置为提取目标之间的空间交互信息以及基于所述空间交互信息更新图结构数据的节点特征,即更新目标特征;将更新后的目标特征输入时间注意力网络模组,所述时间注意力网络模组被配置为从更新后的目标特征中提取时间交互信息以及基于提取的时间交互信息获取目标在下一时间段的预测轨迹。
为了实现本发明的上述目的,根据本发明的第二个方面,本发明提供了一种轨迹预测装置,包括:获取模块,用于获取获取观测区域内多个目标在上一时间段的轨迹数据,将获取的轨迹数据输入轨迹预测模块;图结构数据构建模块,用于构建图结构数据并将所述图结构数据输入图注意力网络模组,其中,所述图结构数据中节点与目标一一对应,所述节点特征的初始值为对应目标的目标特征的初始值,连接相邻节点形成图结构数据的边;图注意力网络模组,被配置为提取目标之间的空间交互信息以及基于所述空间交互信息更新图结构数据的节点特征,即更新目标特征,将更新后的目标特征输入时间注意力网络模组;时间注意力网络模组,被配置为从更新后的目标特征中提取时间交互信息以及基于提取的时间交互信息获取目标在下一时间段的预测轨迹。
综上所述,由于采用了上述技术方案,本发明的有益效果是:通过图注意力网络模组提取目标轨迹数据中的空间交互信息,并利用空间交互信息更新目标特征,再将更新后的具备空间交互信息的目标特征输入时间注意力网络模组用于提取时间交互信息,这时得到的时间交互信息本质为时空关联信息,实现了空间交互信息、时间交互信息的融合,从轨迹数据中提取了更为丰富的信息,这样能够更有效的地对交通场景下目标的运动模式建模和预测,以获得更好的目标轨迹预测结果,并且能够并行的实现对观测区域内多个目标的轨迹预测,加快了轨迹预测速度。
附图说明
图1是本发明一具体实施方式中基于双重注意力机制的轨迹预测方法的流程示意图;
图2是本发明一具体实施方式中轨迹预测装置部分框架示意图;
图3是本发明一具体实施方式中图注意力网络提取相邻目标之间的空间交互信息的过程示意图;
图4是本发明一具体实施方式中门控激活模块的结构示意图;
图5是本发明一具体实施方式中多头图注意力机制的目标特征更新示意图;
图6是本发明一具体实施方式中两层图注意力网络结构示意图;
图7是本发明一具体实施方式中时间注意力卷积网络的结构示意图;
图8是本发明一具体实施方式中因果卷积过程示意图;
图9是本发明一具体实施方式中时间注意力卷积模块工作示意图;
图10是本发明一具体实施方式中增强残差模块的处理过程示意图。
具体实施方式
下面详细描述本发明的实施例,所述实施例的示例在附图中示出,其中自始至终相同或类似的标号表示相同或类似的元件或具有相同或类似功能的元件。下面通过参考附图描述的实施例是示例性的,仅用于解释本发明,而不能理解为对本发明的限制。
在本发明的描述中,需要理解的是,术语“纵向”、“横向”、“上”、“下”、“前”、“后”、“左”、“右”、“竖直”、“水平”、“顶”、“底”“内”、“外”等指示的方位或位置关系为基于附图所示的方位或位置关系,仅是为了便于描述本发明和简化描述,而不是指示或暗示所指的装置或元件必须具有特定的方位、以特定的方位构造和操作,因此不能理解为对本发明的限制。
在本发明的描述中,除非另有规定和限定,需要说明的是,术语“安装”、“相连”、“连接”应做广义理解,例如,可以是机械连接或电连接,也可以是两个元件内部的连通,可以是直接相连,也可以通过中间媒介间接相连,对于本领域的普通技术人员而言,可以根据具体情况理解上述术语的具体含义。
本发明公开了一种基于双重注意力机制的轨迹预测方法,在一种优选实施方式中,如图1所示,包括:
步骤S100,获取观测区域内多个目标在上一时间段的轨迹数据。观测区域的大小可人为设定,观测区域优选但不限于为半径为3米或3米以上的区域。上一时间段的大小也可根据需要设置,如大于等于1秒。目标优选但不限于为行人、移动机器人、车辆或无人驾驶汽车等。轨迹数据优选的为目标的坐标位置信息。
步骤S200,对目标的轨迹数据进行预处理,将预处理结果作为目标特征的初始值。优选的,具体过程为:将目标在上一时间段每个采样点的绝对位置坐标和相对位置坐标聚合,将目标的所有采样点的聚合结果进行线性变换获得目标特征的初始值。线性变换优选但不限于为将目标的所有采样点的聚合结果按照采样时间顺序构建为列向量。如:设第i个目标在第t个采样点的绝对位置坐标和相对位置坐标聚合结果弟 表示第i个目标在第t个采样点的绝对位置坐标,表示第i个目标在第t个采样点的相对位置坐标。|表示位置坐标聚合符号,|运算优选但不限于将相对位置坐标放置在绝对位置坐标后。将第i个目标在Tobs内所有采样点的聚合结果组成列向量Hi,Hi表示第i个目标的目标特征的初始值。
步骤S300,构建图结构数据,图结构数据中节点与目标一一对应,节点特征的初始值为对应目标的目标特征的初始值,连接相邻节点形成图结构数据的边。具体的,如图2所示,图结构数据包括上一时间段中Tobs个采样点对应的图结构。需要说明的是,图2中为了便于理解,每个采样点对应的图结构连接了一个多头门控图注意力网络,但实际为同一个多头门控图注意力网络,该多头门控图注意力网络同时并行地对多个采样点的图结构进行处理。
步骤S400,将图结构数据输入图注意力网络模组,图注意力网络模组配置为提取目标之间的空间交互信息以及基于空间交互信息更新图结构数据的节点特征,即更新目标特征。
在本实施方式中,优选的,图注意力网络模组包括一层或多层级联的图注意力网络,以及图输出层;图注意力网络被配置为:基于输入图注意力网络的图结构数据提取相邻目标之间的注意力互相关系数,对注意力互相关系数进行正则化处理获得相邻目标之间的空间交互信息,基于相邻目标之间的空间交互信息更新目标特征;图输出层输出最后一层图注意力网络更新后的每个目标的目标特征,图输出层分别输出每个目标更新后的目标特征,设第i个目标更新后的目标特征为H′i。设下一时间段共包含Tpred个时间点,则目标i在未来第t′个时刻预测的绝对位置坐标为:目标i在未来第t′个时刻预测的相对位置坐标为:
在本实施方式中,进一步优选的,在第一层图注意力网络提取相邻目标之间的注意力互相关系数之前,先将目标的目标特征的初始值与图注意力网络的共享权重矩阵叉乘获得第一特征,基于目标的第一特征提取目标之间的注意力互相关系数。设Hi表示第i个目标的目标特征的初始值,那么WHi表示第i个目标的第一特征,同理,WHj表示第j个目标的第一特征,设总共有N个目标,i=1,2,...,Tobs,i不等于j,j=1,2,...,Tobs。W表示图注意力网络的共享权重矩阵,W是图注意力网络在训练过程中不断学习获得。
在本实施方式中,图3展示了图注意力网络提取相邻目标之间的空间交互信息的过程,a表示图注意力网络,其通常为一层前馈神经网络,设节点i和节点j相邻,Hi、Hj分别表示第i个目标和第j个目标的目标特征的初始值,第i个目标和第j个目标的目标特征的初始值均可以看做一个时间序列,那么第i个目标的第一特征WHi和第j个目标的第一特征WHj也可以看作时间序列,如图3所示。图注意力网络提取第i个目标和第j个目标之间的注意力互相关系数eij=a(WHi,WHj)=aT[WHi||WHj],||表示向量的聚合,WHi||WHj表示将WHi的尾部和WHj的头部相连。T代表矩阵的转置操作,即将WHi||WHj输入图注意力网络a,将图注意力网络a的输出结果进行转置。在获得了相邻目标之间的注意力互相关系数后,引入激活SoftMax层对所有注意力互相关系数进行正则化处理得到相邻目标之间的空间交互信息,具体如下式:
在本实施方式中,如图3所示,通过注意力互相关系数与原节点特征相乘再求和得到更新后的节点特征H′i,即:
步骤S500,将更新后的目标特征输入时间注意力网络模组,时间注意力网络模组被配置为从更新后的目标特征中提取时间交互信息以及基于提取的时间交互信息获取目标在下一时间段的预测轨迹。
在一种优选实施方式中,与传统图注意力网络不同,图注意力网络模组还包括门控激活模块;门控激活模块对每个目标的第一特征进行处理并将处理结果作为目标的第二特征,基于相邻目标之间的空间交互信息和目标的第二特征更新目标特征。轨迹数据原始维度较低,因此需要额外地关注原始数据的信息损失和细微特征丢失的问题,而门控激活模块能够动态调整信息损失,较好地保留细节特征。
在本实施方式中,门控激活模块的结构示意图如图4所示。在节点的第一特征与图注意力互相关系数相乘输出新的特征之前,本发明使用门控激活函数进一步处理输入特征,具体如下:
gi=fg(WHi+bh)⊙(WHi+bh)
其中fg(·)为tanh激活函数,bh为偏置,⊙表示逐元素相乘的哈达玛积。通过门控机制可以保证在图注意力网络训练的过程中梯度不会消失,稳定训练过程。
在一种优选实施方式中,为提高效率、简化模型训练,将相邻目标之间的位置差作为图边特征先验知识加入注意力互相关系数的计算过程中。对于上一段时间中时间点t的图边特征先验知识为:
其中,φr表示线性变换,其将同一场景下行人的相对位置信息编码到高维空间F1表示预设的目标特征维度大小,优选但不限于为3;Wr为可学习的变换参数,在训练中学习获得最佳值。加入图边特征先验知识后的相邻目标之间的空间交互信息表示为:
在一种优选实施方式中,为了能够提取更多的目标轨迹数据之间的空间交互特征,提高轨迹预测的准确性,图注意力网络为多头图注意力网络,多头图注意力网络分别获取相邻目标之间的空间交互信息,每头图注意力网络以各自的空间交互信息更新目标特征,聚合多头图注意力网络更新的目标特征获得目标的最终更新的目标特征。具体计算过程可见图5所示,按照该过程获得最终更新的目标特征可表示为:
其中,||表示向量的聚合,αk,ij表示第k个头的节点i,j之间的图注意力互相关系数,σ表示ELU激活函数,R(·)表示图输出层(可为一层全连接层),Wr′表示图输出层的参数。
在一种优选实施方式中,当图注意力网络模组包括多层级联的图注意力网络时,多层图注意力网络之间设置有至少一条残差连接支路和/或至少一条跳线连接支路,如图6所示采用了两层图注意力网络,来更好地为空间关系建模,为更准确提取场景中行人的空间交互特征,在传统的图注意力网络基础上,添加残差块(residual block)和跳线(skipconnection)方法来保留更多全局和细节特征。图6中,展示了两层图注意力网络中设置了两条残差连接支路,即连接第一层图注意力网络的输入到输出的支路,以及连接第二层图注意力网络的输入到输出的支路。还展示了一条跳线支路,即连接第一层图注意力网络的输入到第二层图注意力网络的输出的支路。
在一种优选实施方式中,时间注意力网络模组通过时间注意力模块和第一因果卷积模块从更新后的目标特征中提取不同尺度的时间交互信息,并基于提取的不同尺度的时间交互信息获取目标在下一时间段的预测轨迹。轨迹预测模型通过时间注意力模块和第一因果卷积模块提取目标i轨迹数据不同尺度的时间特征,并基于提取的不同尺度的时间特征获取目标i在下一时间段的预测轨迹。时间注意力网络模组的能够并行地提取所有目标的时间交互信息并输出下一时间段的预测轨迹。
在本实施方式中,递归神经网络模型具备了对未知时序序列数据的处理能力,通过输入最后时刻的序列数据和隐藏状态,递归地输出时序预测数据。递归神经网络模型的数据处理模式有两个优点,第一,能够处理任意长度序列,而隐藏状态始终具有相同大小;第二个优点是可以对所有的数据使用相同的网络参数。这两个因素使递归神经网络模型成为简洁合理的时间序列处理模型。然而,递归神经网络模型使用的参数共享机制是取决于以下假设:时间序列数据的特征在时间分布上是一致的。实际上,由于多种因素与时间共同影响时间序列变换,以上假设在大部分时间序列处理上并不成立,因此单纯地用递归神经网络模型处理时序序列数据的预测效果并不好,预测准确性较低。此外,由于下一时刻输出必须依赖上一时刻的隐藏状态,因此递归神经网络模型在时间维度上并不能做到并行处理,预测速度较慢。时序卷积神经网络模型有因果卷积模型、空洞卷积模型等。时序卷积神经网络模型主要目的是解决递归神经网络在训练过程中不稳定和不能并行化的问题,但是存在着特征提取单一,轨迹预测精度较低的问题。
在本实施方式中,时间注意力网络模组包括一层或一层以上级联的时间注意力卷积网络,以及译码器;如图7所示,第一支路和第二支路均连接在时间注意力卷积网络的输入端和时间注意力卷积网络的聚合模块的输入端之间,第一支路上设置有时间注意力模块,第二支路上设置有第一因果卷积模块,优选的,为简便计算,聚合模块被配置为对第一支路和第二支路输出的特征数据进行叠加。译码器将最后一层时间注意力卷积网络的聚合模块输出的时间交互信息映射为目标在下一时间段的预测轨迹。优选的,译码器包括一个全连接层,通过一个全连接层将最后一层时间注意力卷积网络的聚合模块的输出的时间交互信息映射为目标在下一时间段的预测轨迹。具体为,译码器将各目标的时间交互信息映射为该目标在下一时间段的预测轨迹。
在本实施方式中,由于卷积神经网络的卷积核在时间维度上的卷积处理,这种卷积模式被称为因果卷积(Casual convolution),从而获得了具有卷积核感受野大小含有时间维度信息的特征。因果卷积的实现为通过时间序列数据左侧加入填充数据,使得因果输出的某一时刻的特征只从当前时刻之前的数据中提取。在顺序上保证了未来时刻的信息不会泄露到前面时刻的信息中,保证了时序正确性。
在本实施方式中,如图8所示,为第l层时间注意力卷积网络中第一因果卷积的数据处理过程。对于第i个目标,设输入第l层的第一因果卷积的向量为则其输出为第(l+1)层的第一因果卷积的输入时间序列随着第一因果卷积模块叠加多层,高层输出的特征的感受野也对应线性增大。在因果卷积处理时序序列中,感受野对应的是当前时刻能够观察到多少个历史时刻的数据。假设第一因果卷积的卷积核大小为k′,第l层的感受野为(k′-1)*l+1。因果卷积的具体实现为普通的卷积模块加上左填充数据操作,如图8所示,设定因果卷积核的卷积核大小为3,每次输入到因果卷积模块前,对输入数据最后一维进行左填充,填充维度为2。通过因果卷积,当前时刻的输出特征只与当前时刻及其之前的输入数据有关,保留了轨迹序列中的时序信息。
在本实施方式中,因果卷积具有将之前时刻的信息融入到当前时刻的输出中的能力,然而,因果卷积模块仅仅对历史时序信息进行普通的卷积和融入,对时间交互信息仅有简单的聚合操作,无法选择关注特定的时间信息。基于上述不足,在时间注意力卷积网络中引入时间注意力模块。
在本实施方式中,时间注意力模块是选择性地重点关注一部分时间点信息,同时忽略或较少关注其余时间点信息。注意力机制对输入数据中的重要信息分配较多的权重进行特征提取,而对不重要的信息分配较少的权重。在处理时间交互信息时,注意力机制是轨迹预测模型重点关注距离当前时间点较近的时间段和/或轨迹变化明显的时间段,对一些较远的时间段或轨迹变化较小的时间段设定较小的权重。
在一种优选方式中,为规避传统自注意力机制在处理时间序列数据时,会同时关注历史时刻和未来时刻,这不符合时序序列的顺序特点。本发明在自注意力机制的基础上,进一步优选的,使用掩码来屏蔽未来时刻特征对当前时刻的影响,即时间注意力模块使用掩码屏蔽未来时刻的时间交互信息对当前时刻的时间交互信息的影响,掩码可以是数值0,这样利用注意力机制来整合历史时刻对当前时刻的影响,并自动关注影响较大的历史时序特征,如图9所示,具体过程包括(每个目标的目标特征处理过程均按照如下步骤进行):
首先,将输入时间注意力模块的目标特征分别通过键线性变换矩阵f、查询线性变化矩阵g、值线性变化矩阵h映射为键矩阵查询矩阵值矩阵 表示输入第l层时间注意力卷积网络的时间序列,实际为大小为Tobs×2的一个矩阵。键线性变换矩阵f、查询线性变化矩阵g、值线性变化矩阵h均是大小为2×dk的矩阵,键线性变换矩阵f、查询线性变化矩阵g、值线性变化矩阵h均为时间注意力模块的模型变量。在轨迹预测模型训练前,通过现有的gaussian分布初始化方法初始化三个矩阵的数值(如可参考网址https://zhuanlan.zhihu.com/p/69026430中公开的gaussian分布初始化方法),在轨迹预测模型的训练中不断更新键线性变换矩阵f、查询线性变化矩阵g、值线性变化矩阵h,具体更新方法为时间注意力机制中的常规设置,在此不再赘述。获得的键矩阵、查询矩阵、值矩阵均是大小为Tobs×dk的矩阵,因此,键矩阵、查询矩阵、值矩阵的维度为dk。
之后,保留第一矩阵中下三角元素的数值,将第一矩阵中非下三角的元素赋值为0,即掩码为0,获得第二矩阵Wl′(l):
之后,通过第一激活函数对第二矩阵进行正则化处理获得注意力权重矩阵Wa(l);第一激活函数优选但不限于为SoftMax激活函数。
最后,获取注意力权重矩阵中所有元素的数值累加值,将数值累加值与值矩阵相乘并将相乘结果作为时间注意力模块输出的注意力输出特征SA(l)。
在一种优选实施方式中,为精确捕捉时序序列的时间关联性,进行了二次时间交互信息提取,如图7所示,在第一支路上还包括级联于时间注意力模块之后的第二因果卷积模块,将时间注意力模块提取的注意力输出特征作为第二因果卷积模块的输入。
在一种优选实施方式中,为提取序列中相对重要的信息并将其直接传送到下一层。输入编码器的轨迹序列并没有直接参与正则化注意力权重矩阵的计算,而是通过转化成值矩阵进行计算。因为值矩阵的映射维度通常小于输入的轨迹序列维度,这就有可能存在信息的丢失问题,受残差模块的启发,我们希望注意力能够直接应用于输入的轨迹序列,主要原因有两个,第一,直接应用于输入轨迹序列可以减少在前向传播过程中信息丢失,解决梯度消失的问题;第二,在神经网络训练的情况下,通过注意力权重矩阵告知网络哪些内容是相对重要的部分,其会加强学习特定的部分,帮助模型更快地学习到数据中的特征。为此,本发明提供的时间注意力卷积网络还包括第一残差支路,第一残差支路连接在时间注意力卷积网络的输入端和时间注意力卷积网络的聚合模块的输入端之间,第一残差支路上设置有增强残差模块;增强残差模块将注意力权重矩阵Wa(l)进行行求和获得权重向量Mt,并求取权重向量Mt与输入该层(设为l层)时间注意力卷积网络的目标特征的哈达玛积SR(l),将哈达玛积SR(l)作为第l层的增强残差其中⊙表示求取哈达玛积符号,具体过程如图10所示,展示了增强残差模块的计算过程。增强残差模块通过使用注意力权重矩阵与输入的序列直接相乘,具有保留输入信息,稳定神经网络的训练过程,加快收敛的作用。
在一种优选实施方式中,如图7所示,时间注意力卷积网络还包括第二残差支路Residual(即跳连模块),第二残差支路Residual连接在时间注意力卷积网络的输入端和时间注意力卷积网络的聚合模块的输入端之间,即第二残差支路Residual直接将输入信息引入聚合模块。增加第二残差支路Residual目的是将时间注意力卷积网络的输出时间交互信息表示为原始的输入数据和输入数据的时间交互信息的叠加,之所以要保留原始输入数据,是因为其本身含有丰富的时间交互信息,这样使得聚合后的数据时间交互信息更丰富,轨迹预测更准确。
在本发明的一种实施例中,时间注意力卷积网络包括第一支路、第二支路、第一残差支路和第二残差支路,具体结构如图7所示。第一因果卷积模块和时间注意力模块的主要作用都是时间交互信息提取,但是具体作用有所不同,二者组合能够多尺度地提取时间交互信息。
在本实施例中,第一因果卷积模块直接处理输入的相对轨迹数据,学习和提取序列内部最明显的时间关联性,从宏观上对时间维度进行处理;时间注意力模块能够学习细微的时间交互信息,且将之前时刻的信息融入到当前时刻之中。两个时间处理模块共同作用,多尺度地学习时间交互信息。增强残差模块通过使用时间注意力卷积模块的权重矩阵,与输入的序列直接相乘,具有保留输入信息,稳定神经网络的训练过程,加快收敛的作用。本发明还使用了残差与跳连模块,目的是将时间注意力卷积网络的输出时间交互信息表示为原始的输入数据和输入数据的时间交互信息的叠加,之所以要保留原始输入数据,是因为其本身含有丰富的时间交互信息。聚合模块就是将第一支路、第二支路、第一残差支路、第二残差支路的输出特征叠加,时间注意力卷积网络输出的特征为O:
其中,O表示第l层时间注意力卷积网络的输出,表示第l层时间注意力卷积网络的输入数据,SA(·)表示时间注意力模块的输出特征,SR(·)表示增强残差模块的输出特征,C(·)表示第一因果卷积模块输出特征。
在一种优选实施方式中,目标运动轨迹具有多样性和不确定性(当目标维行人时更具有多样性和不确定性),即使在同一场景和同一观测轨迹的情况下,由于行人自身意图的不同,其未来轨迹也会多种多样,基于此见解,本发明在提出的模型在较为准确预测未来轨迹的基础上,通过在时空交互特征上添加随机噪声,使模型能够生成具有多样性的预测轨迹。为了使轨迹预测模型输出的轨迹更符合实际情况,译码器包括噪声添加模块和时间输出层,噪声添加模块在最后一层时间注意力卷积网络的输出特征中添加随机噪声,并将添加随机噪声后的特征传输给时间输出层,通过时间输出层映射为所述目标在下一时间段的预测轨迹。添加的随机噪声优选但不限于为(0.1,0)的高斯分布噪声。时间输出层优选但不限于为一个全连接层。
在本实施方式中,进一步优选的,为了使轨迹预测模型有一定的生成空间,本发明还公开了一种轨迹预测模型训练方法,在轨迹预测模型的训练中,译码器的噪声添加模块在最后一层时间注意力卷积网络的输出特征中添加不同的噪声,这样获得多个添加了不同噪声的特征,时间输出层将多个添加了不同噪声的特征分别映射为预测轨迹,则获得多个不同的预测轨迹;计算每条预测轨迹与真实轨迹的差异,将差异最小的预测轨迹反向传播更新时间输出层的线性映射参数。差异优选但不限于为均方差值,具体的,时间输出层可包括五层,每一层都是一个线性映射。
在本实施方式中,直接在时间卷积神经网络的输出特征上,添加高斯分布的噪声,然后通过多层感知机解码器网络输出多样的未来时刻相对位置轨迹数据,然后将相对位置坐标轨迹转换成绝对位置坐标轨迹进行损失函数的计算。本发明使用均方误差作为损失函数,且依照Social-GAN使用多样性损失函数来鼓励预测模型的预测多样性和灵活性。具体的操作为,所提出模型通过多次随机采样获得多个噪声数据,与TACN输出的时间交互信息数据进行聚合,再由解码器生成多个预测轨迹,在训练过程中,我们选择与真实的未来时刻轨迹最相近的输出轨迹来计算损失函数,多样性损失函数的具体含义是模型同时生成多条预测轨迹,通过计算这些预测轨迹与真实值的均方误差,选择其中误差最小的一条轨迹进行反向传播和训练。本发明还公开了一种轨迹预测装置,在一种优选实施方式中,该装置包括:获取模块,用于获取获取观测区域内多个目标在上一时间段的轨迹数据,将获取的轨迹数据输入轨迹预测模块;图结构数据构建模块,用于构建图结构数据并将所述图结构数据输入图注意力网络模组,其中,所述图结构数据中节点与目标一一对应,节点特征的初始值为对应目标的目标特征的初始值,连接相邻节点形成图结构数据的边;图注意力网络模组,被配置为提取目标之间的空间交互信息以及基于所述空间交互信息更新图结构数据的节点特征,即更新目标特征,将更新后的目标特征输入时间注意力网络模组;时间注意力网络模组,被配置为从更新后的目标特征中提取时间交互信息以及基于提取的时间交互信息获取目标在下一时间段的预测轨迹。
在本发明提出的基于双重注意力机制的轨迹预测方法的一个实施例中,目标为行人,对改轨迹预测方法进行验证,本发明提出的轨迹预测模型能够进一步学习多行人轨迹序列中的空间交互信息,并在时间注意力卷积网络中输入轨迹数据的空间交互特征,这样时间注意力模块能够生成时空关联信息,因果卷积能够继续提取序列中的时间特征,有效地提高了模型的完善性和精度。
在本实施例中,使用的行人运动数据集ETH和UCY进行所提出模型的训练和验证。这两个数据集共有五个子数据集,分别是ETH,HOTEL,UNIV,ZARA1,ZARA2。图注意力网络第一层的输入数据维度为64,输出维度为16,多头注意力数目为2,第二层的输入数据维度为32,输出数据维度为32,该层未使用多头注意力。同一场景行人的位置差特征变换函数输出维度为64。
使用两种衡量标准来评估所提出模型的精度,分别是平均位移误差,最终位移误差,两种衡量标准越低,代表预测模型的效果越好。
平均位移误差(Average displacement error,ADE):通过对模型输出的N个目标的预测轨迹的数据与真实数据进行均方误差(Mean square error,MSE)的计算得到,具体如下式:
最终位移误差(Final displacement error,FDE):通过对模型输出的最后一个时刻的预测轨迹数据与真实数据进行均方误差得到,具体如下式:
利用本发明提出的轨迹预测方法Graph-TP-TACN模型进行了轨迹预测精度比较,如表1,Graph-TP-TACN能够准确的轨迹预测。
表1本发明所提出轨迹预测方法Graph-TP-TACN的ADE/FDE
在使用ETH和UCY数据集训练完成后,我们将预测精度最高的Graph-TP-TACN模型(本发明的预测模型)与训练好的Social-LSTM、Social-GAN和STGAT模型在自制数据集上进行轨迹预测与精度比较,如表2。可见在实际校园场景中,所提出模型与在标准数据集上的效果相比,所提出模型的预测平均位移误差仍然较低,由此可见具有良好的准确性,进一步说明了所提出模型能够良好的泛化能力。
此外,我们选取五种主流行人轨迹预测模型进行预测精度对比,分别为:
(1)SR-LSTM模型提取人群中每一时刻意图,通过消息传递机制联合迭代更新和细化所有行人的当前状态。
(2)Sophie模型在GAN模型的基础上引入社会注意力和物理注意力机制,使用LSTM进行预测。
(3)Trajectron模型用LSTM提取时空信息,并使用条件变分自动编码器[58]生成未来轨迹。
(4)Social-STGCNN模型在图卷积网络提取时空特征后,使用时间外推卷积神经网络直接对时空特征进行操作生成轨迹。
(5)STAGT模型使用LSTM提取时间交互信息,使用原始的图注意力网络提取空间交互信息,将时空交互信息聚合解码输出预测轨迹。
表2不同算法在广场、超市和教学楼数据集上的ADE/FDE
轨迹预测任务通常应用在自动驾驶,视频监控等即时性系统之中,因此需要对所提出算法进行实时性评估。本文分别对比了几种基于递归神经网络的轨迹预测模型和几种基于图卷积网络的模型的运行速度,在搭载Inter Core i7-10700K CPU和Nvidia RTX3090 GPU的仿真平台上实时性评估,结果如表3。从运行结果可见本发明所提出算法运行速度快,实时性更好。
表3不同轨迹预测模型的运行速度对比
在本说明书的描述中,参考术语“一个实施例”、“一些实施例”、“示例”、“具体示例”、或“一些示例”等的描述意指结合该实施例或示例描述的具体特征、结构、材料或者特点包含于本发明的至少一个实施例或示例中。在本说明书中,对上述术语的示意性表述不一定指的是相同的实施例或示例。而且,描述的具体特征、结构、材料或者特点可以在任何的一个或多个实施例或示例中以合适的方式结合。
尽管已经示出和描述了本发明的实施例,本领域的普通技术人员可以理解:在不脱离本发明的原理和宗旨的情况下可以对这些实施例进行多种变化、修改、替换和变型,本发明的范围由权利要求及其等同物限定。
Claims (17)
1.一种基于双重注意力机制的轨迹预测方法,其特征在于,包括:
获取观测区域内多个目标在上一时间段的轨迹数据;
对目标的轨迹数据进行预处理,将预处理结果作为目标特征的初始值;
构建图结构数据,所述图结构数据中节点与目标一一对应,所述节点特征的初始值为对应目标的目标特征的初始值,连接相邻节点形成图结构数据的边;
将所述图结构数据输入图注意力网络模组,所述图注意力网络模组被配置为提取目标之间的空间交互信息以及基于所述空间交互信息更新图结构数据的节点特征,即更新目标特征;
将更新后的目标特征输入时间注意力网络模组,所述时间注意力网络模组被配置为从更新后的目标特征中提取时间交互信息以及基于提取的时间交互信息获取目标在下一时间段的预测轨迹。
2.如权利要求1所述的基于双重注意力机制的轨迹预测方法,其特征在于,对目标的轨迹数据进行预处理,将预处理结果作为目标特征的初始值,具体为:将目标在上一时间段每个采样点的绝对位置坐标和相对位置坐标聚合,将目标所有采样点的聚合结果进行线性变换获得目标特征的初始值。
3.如权利要求1或2所述的基于双重注意力机制的轨迹预测方法,其特征在于,所述图注意力网络模组包括一层或多层级联的图注意力网络,以及图输出层;
所述图注意力网络被配置为:基于输入所述图注意力网络的图结构数据提取相邻目标之间的注意力互相关系数,对所述注意力互相关系数进行正则化处理获得相邻目标之间的空间交互信息,基于相邻目标之间的空间交互信息更新目标特征;
所述图输出层输出最后一层图注意力网络更新后的每个目标的目标特征。
4.如权利要求3所述的基于双重注意力机制的轨迹预测方法,其特征在于,在第一层图注意力网络提取相邻目标之间的注意力互相关系数之前,先将目标的目标特征的初始值与图注意力网络的共享权重矩阵叉乘获得第一特征,基于目标的第一特征提取目标之间的注意力互相关系数。
5.如权利要求4所述的基于双重注意力机制的轨迹预测方法,其特征在于,所述图注意力网络模组还包括门控激活模块;
所述门控激活模块对每个目标的第一特征进行处理并将处理结果作为目标的第二特征,基于相邻目标之间的空间交互信息和目标的第二特征更新目标特征。
6.如权利要求3所述的基于双重注意力机制的轨迹预测方法,其特征在于,将相邻目标之间的位置差作为计算注意力互相关系数的图边特征先验知识。
7.如权利要求4、5或6所述的基于双重注意力机制的轨迹预测方法,其特征在于,所述图注意力网络为多头图注意力网络,多头图注意力网络分别获取相邻目标之间的空间交互信息,每头图注意力网络以各自的空间交互信息更新目标特征,聚合多头图注意力网络更新的目标特征将聚合的目标特征作为目标最终更新的目标特征。
8.如权利要求4、5或6所述的基于双重注意力机制的轨迹预测方法,其特征在于,当所述图注意力网络模组包括多层级联的图注意力网络时,多层图注意力网络之间设置有至少一条残差连接支路和/或至少一条跳线连接支路。
9.如权利要求1所述的基于双重注意力机制的轨迹预测方法,其特征在于,所述时间注意力网络模组通过时间注意力模块和第一因果卷积模块从更新后的目标特征中提取不同尺度的时间交互信息,并基于提取的不同尺度的时间交互信息获取所述目标在下一时间段的预测轨迹。
10.如权利要求9所述的基于双重注意力机制的轨迹预测方法,其特征在于,所述时间注意力网络模组包括一层或一层以上级联的时间注意力卷积网络,以及译码器;
所述时间注意力卷积网络包括第一支路、第二支路和聚合模块,所述第一支路和第二支路均连接在所述时间注意力卷积网络的输入端和所述时间注意力卷积网络的聚合模块的输入端之间,所述第一支路上设置有时间注意力模块,所述第二支路上设置有第一因果卷积模块;
所述译码器将最后一层时间注意力卷积网络的聚合模块输出的时间交互信息映射为所述目标在下一时间段的预测轨迹。
11.如权利要求9所述的基于双重注意力机制的轨迹预测方法,其特征在于,所述时间注意力模块使用掩码屏蔽未来时刻的时间交互信息对当前时刻的时间交互信息的影响。
12.如权利要求11所述的基于双重注意力机制的轨迹预测方法,其特征在于,所述时间注意力模块的处理过程包括:
将输入所述时间注意力模块的目标特征分别通过键线性变换矩阵、查询线性变化矩阵、值线性变化矩阵映射为键矩阵、查询矩阵、值矩阵;
通过键矩阵与查询矩阵点乘获得第一矩阵,保留所述第一矩阵中下三角元素的数值,将所述第一矩阵中非下三角的元素赋值为0,获得第二矩阵;通过第一激活函数对第二矩阵进行正则化处理获得注意力权重矩阵;
获取所述注意力权重矩阵中所有元素的数值累加值,将数值累加值与所述值矩阵相乘并将相乘结果作为所述时间注意力模块输出的时间注意力输出特征。
13.如权利要求12所述的基于双重注意力机制的轨迹预测方法,其特征在于,所述时间注意力卷积网络还包括第一残差支路,所述第一残差支路连接在所述时间注意力卷积网络的输入端和所述时间注意力卷积网络的聚合模块的输入端之间,所述第一残差支路上设置有增强残差模块;
所述增强残差模块将所述注意力权重矩阵进行行求和获得权重向量,并求取所述权重向量与输入所述时间注意力卷积网络的目标特征的哈达玛积。
14.如权利要求9、10、11、12或13所述的基于双重注意力机制的轨迹预测方法,其特征在于,在所述第一支路上还包括级联于所述时间注意力模块之后的第二因果卷积模块。
15.如权利要求9、10、11、12或13所述的基于双重注意力机制的轨迹预测方法,其特征在于,所述时间注意力卷积网络还包括第二残差支路,所述第二残差支路连接在所述时间注意力卷积网络的输入端和所述时间注意力卷积网络的聚合模块的输入端之间,所述第二残差支路将输入所述时间注意力卷积网络的目标特征输出至所述时间注意力卷积网络的聚合模块的输入端。
16.如权利要求9、10、11、12或13所述的基于双重注意力机制的轨迹预测方法,其特征在于,所述译码器包括噪声添加模块和时间输出层,所述噪声添加模块在最后一层时间注意力卷积网络的输出特征序列中添加随机噪声,并将添加随机噪声后的特征序列传输给时间输出层,通过时间输出层映射为所述目标在下一时间段的预测轨迹。
17.一种轨迹预测装置,其特征在于,包括:
获取模块,用于获取获取观测区域内多个目标在上一时间段的轨迹数据,将获取的轨迹数据输入轨迹预测模块;
图结构数据构建模块,用于构建图结构数据并将所述图结构数据输入图注意力网络模组,其中,所述图结构数据中节点与目标一一对应,所述节点特征的初始值为对应目标的目标特征的初始值,连接相邻节点形成图结构数据的边;
图注意力网络模组,被配置为提取目标之间的空间交互信息以及基于所述空间交互信息更新图结构数据的节点特征,即更新目标特征,将更新后的目标特征输入时间注意力网络模组;
时间注意力网络模组,被配置为从更新后的目标特征中提取时间交互信息以及基于提取的时间交互信息获取目标在下一时间段的预测轨迹。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202111449388.XA CN114117259A (zh) | 2021-11-30 | 2021-11-30 | 一种基于双重注意力机制的轨迹预测方法及装置 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202111449388.XA CN114117259A (zh) | 2021-11-30 | 2021-11-30 | 一种基于双重注意力机制的轨迹预测方法及装置 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN114117259A true CN114117259A (zh) | 2022-03-01 |
Family
ID=80369011
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202111449388.XA Pending CN114117259A (zh) | 2021-11-30 | 2021-11-30 | 一种基于双重注意力机制的轨迹预测方法及装置 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN114117259A (zh) |
Cited By (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN114842681A (zh) * | 2022-07-04 | 2022-08-02 | 中国电子科技集团公司第二十八研究所 | 一种基于多头注意力机制的机场场面航迹预测方法 |
CN115050184A (zh) * | 2022-06-13 | 2022-09-13 | 九识智行(北京)科技有限公司 | 一种路口车辆轨迹预测方法及装置 |
CN115618986A (zh) * | 2022-09-29 | 2023-01-17 | 北京骑胜科技有限公司 | 协调资源的方法和装置 |
CN116602663A (zh) * | 2023-06-02 | 2023-08-18 | 深圳市震有智联科技有限公司 | 一种基于毫米波雷达的智能监测方法及系统 |
CN117191068A (zh) * | 2023-11-07 | 2023-12-08 | 新石器慧通(北京)科技有限公司 | 模型训练方法和装置、轨迹预测方法和装置 |
-
2021
- 2021-11-30 CN CN202111449388.XA patent/CN114117259A/zh active Pending
Cited By (7)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN115050184A (zh) * | 2022-06-13 | 2022-09-13 | 九识智行(北京)科技有限公司 | 一种路口车辆轨迹预测方法及装置 |
CN114842681A (zh) * | 2022-07-04 | 2022-08-02 | 中国电子科技集团公司第二十八研究所 | 一种基于多头注意力机制的机场场面航迹预测方法 |
CN115618986A (zh) * | 2022-09-29 | 2023-01-17 | 北京骑胜科技有限公司 | 协调资源的方法和装置 |
CN116602663A (zh) * | 2023-06-02 | 2023-08-18 | 深圳市震有智联科技有限公司 | 一种基于毫米波雷达的智能监测方法及系统 |
CN116602663B (zh) * | 2023-06-02 | 2023-12-15 | 深圳市震有智联科技有限公司 | 一种基于毫米波雷达的智能监测方法及系统 |
CN117191068A (zh) * | 2023-11-07 | 2023-12-08 | 新石器慧通(北京)科技有限公司 | 模型训练方法和装置、轨迹预测方法和装置 |
CN117191068B (zh) * | 2023-11-07 | 2024-01-19 | 新石器慧通(北京)科技有限公司 | 模型训练方法和装置、轨迹预测方法和装置 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN114117259A (zh) | 一种基于双重注意力机制的轨迹预测方法及装置 | |
CN110851782B (zh) | 一种基于轻量级时空深度学习模型的网络流量预测方法 | |
CN109271933B (zh) | 基于视频流进行三维人体姿态估计的方法 | |
CN109635917B (zh) | 一种多智能体合作决策及训练方法 | |
CN112418409B (zh) | 一种利用注意力机制改进的卷积长短期记忆网络时空序列预测方法 | |
Saxena et al. | D-GAN: Deep generative adversarial nets for spatio-temporal prediction | |
CN110737968B (zh) | 基于深层次卷积长短记忆网络的人群轨迹预测方法及系统 | |
CN114818515A (zh) | 一种基于自注意力机制和图卷积网络的多维时序预测方法 | |
CN113313947A (zh) | 短期交通预测图卷积网络的路况评估方法 | |
CN114116944A (zh) | 一种基于时间注意力卷积网络的轨迹预测方法及装置 | |
CN115512545B (zh) | 一种基于时空动态图卷积网络的交通速度预测方法 | |
CN112415521A (zh) | 基于cgru的强时空特性雷达回波临近预报方法 | |
CN114997067A (zh) | 一种基于时空图与空域聚合Transformer网络的轨迹预测方法 | |
CN110163196A (zh) | 显著特征检测方法和装置 | |
CN115829171A (zh) | 一种联合时空信息和社交互动特征的行人轨迹预测方法 | |
CN115659275A (zh) | 非结构化人机交互环境中的实时准确轨迹预测方法及系统 | |
CN116052254A (zh) | 基于扩展卡尔曼滤波神经网络的视觉连续情感识别方法 | |
CN116246338B (zh) | 一种基于图卷积和Transformer复合神经网络的行为识别方法 | |
CN117236492A (zh) | 基于动态多尺度图学习的交通需求预测方法 | |
CN113869170A (zh) | 一种基于图划分卷积神经网络的行人轨迹预测方法 | |
CN115830707A (zh) | 一种基于超图学习的多视角人体行为识别方法 | |
CN113221450B (zh) | 一种针对稀疏不均匀时序数据的航位预测方法及系统 | |
Cheng et al. | Application of a dynamic recurrent neural network in spatio-temporal forecasting | |
CN117647855B (zh) | 一种基于序列长度的短临降水预报方法、装置及设备 | |
CN117671952A (zh) | 基于时空同步动态图注意力网络的交通流预测方法及系统 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |