CN115730637A - 多模态车辆轨迹预测模型训练方法、装置及轨迹预测方法 - Google Patents
多模态车辆轨迹预测模型训练方法、装置及轨迹预测方法 Download PDFInfo
- Publication number
- CN115730637A CN115730637A CN202211490477.3A CN202211490477A CN115730637A CN 115730637 A CN115730637 A CN 115730637A CN 202211490477 A CN202211490477 A CN 202211490477A CN 115730637 A CN115730637 A CN 115730637A
- Authority
- CN
- China
- Prior art keywords
- category
- training
- inputting
- prediction model
- vector
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
Images
Landscapes
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
Abstract
本申请涉及多模态车辆轨迹预测模型训练方法、装置及轨迹预测方法,模型训练方法包括将车辆运动轨迹训练数据划分为多个类别;对每个类别的训练数据基于解码器模块进行权重迭代剪枝,得到每个类别对应的掩码矩阵;基于每个类别的训练数据、每个类别对应的掩码矩阵对编码器模块和解码器模块进行训练,得到训练后的多模态车辆轨迹预测模型,能够输出与多个类别对应的多个轨迹预测结果。本申请的多模态车辆轨迹预测模型训练方法,采用基于数据驱动的方式,充分考虑不同驾驶操作意图,结合车辆的历史运动轨迹状态信息和不同驾驶操作意图的特征,进行多模态车辆轨迹预测模型的训练,生成的车辆轨迹预测模型能够输出多个合理的预测轨迹。
Description
技术领域
本申请涉及自动驾驶技术领域,具体地,涉及一种多模态车辆轨迹预测模型训练方法、装置及轨迹预测方法。
背景技术
精准轨迹预测是引导自动驾驶汽车在道路交通系统中实现自主路径规划、驾驶风险规避的前提,为其安全行驶提供有力保障。如何准确预测周围交互车辆的运动轨迹是交通领域的一个重要挑战。基于模型驱动的轨迹预测方法以车辆或行人为研究对象,通过运动学或统计物理的方法,如卡尔曼滤波等,但这种方法没有考虑其他交通参与者对研究对象的影响,无法充分利用大量真实车辆数据。随着深度学习的发展,以循环神经网络(RNN)为代表的基于数据驱动的方法能够让模型自动化地学习车辆运动特征,从而预测未来轨迹。长短期记忆网络[(LSTM)的长期预测能力有限,其预测误差随预测时长的增加而急剧增加。而对于长时预测能力较强的Transformer模型,其模型参数量与计算量巨大,使得模型过于复杂,成为其目前的主要短板。同时,上述这些模型对驾驶行为缺乏全面正确的理解,即在相同的驾驶情况下,目标车辆也有可能采取不同的驾驶操作意图,因此仅产生单个预测轨迹对预测准确率有很大影响。
发明内容
为了克服现有技术中的至少一个不足,本申请实施例提供一种多模态车辆轨迹预测模型训练方法、装置及轨迹预测方法。
第一方面,提供一种多模态车辆轨迹预测模型训练方法,包括:
将车辆运动轨迹训练数据划分为多个类别,类别包括左偏移、右偏移和直行;
多模态车辆轨迹预测模型包括编码器模块和解码器模块;对每个类别的训练数据基于解码器模块进行权重迭代剪枝,得到每个类别对应的多个掩码矩阵;
基于每个类别的训练数据、每个类别对应的多个掩码矩阵对编码器模块和解码器模块进行训练,得到训练后的多模态车辆轨迹预测模型,训练后的多模态车辆轨迹预测模型能够输出与多个类别对应的多个轨迹预测结果。
在一个实施例中,对每个类别的训练数据基于解码器模块进行权重迭代剪枝,得到每个类别对应的多个掩码矩阵,包括:
初始化解码器模块的网络参数,采用车辆运动轨迹训练数据对解码器模块进行多轮预训练,得到基网络;
将每个类别的训练数据分别输入到基网络中,进行多轮预训练,得到每个类别对应的子网;
对每个类别对应的子网的多个全连接层的权重参数进行权重迭代剪枝,确定每个类别对应的多个掩码矩阵。
在一个实施例中,基于每个类别的训练数据、每个类别对应的多个掩码矩阵对编码器模块和解码器模块进行训练,得到训练后的多模态车辆轨迹预测模型,包括:
训练的过程进行多次,每次训练包括:将每个类别的训练数据分别输入到编码器模块中,得到每个类别对应的高维空间向量;将每个类别的训练数据、对应的多个掩码矩阵和对应的高维空间向量输入到解码器模块中,得到每个类别对应的轨迹预测结果;
每次训练结束后,基于车辆运动轨迹测试数据对多模态车辆轨迹预测模型进行测试,当多模态车辆轨迹预测模型满足测试要求时,训练结束,得到训练后的多模态车辆轨迹预测模型。
在一个实施例中,将每个类别的训练数据分别输入到编码器模块中,得到每个类别对应的高维空间向量,包括:
编码器模块包括依次连接的多头注意力模块、残差连接层、归一化层和前馈神经网络;
针对每个类别,将训练数据输入到多头注意力模块,得到Value向量的权重分布;
将Value向量的权重分布和训练数据输入到残差连接层进行残差连接,并将残差连接的结果输入到归一化层进行归一化处理,得到归一化后的向量;
将归一化后的向量输入到前馈神经网络,得到降维后的向量;
将降维后的向量和归一化后的向量输入到残差连接层进行残差连接,并将残差连接的结果输入到归一化层进行归一化处理,得到类别对应的高维空间向量。
在一个实施例中,将每个类别的训练数据、对应的多个掩码矩阵和对应的高维空间向量输入到解码器模块中,得到每个类别对应的轨迹预测结果,包括:
解码器模块包括掩码多头注意力模块、多头注意力模块、残差连接层、归一化层和前馈神经网络、全连接层网络;
针对每个类别,将训练数据、多个掩码矩阵输入到掩码多头注意力模块,得到Value向量的第一权重分布;
将Value向量的第一权重分布和训练数据输入到残差连接层进行残差连接,并将残差连接的结果输入到归一化层进行归一化处理,得到第一归一化后的向量;
将第一归一化后的向量、高维空间向量和多个掩码矩阵输入到多头注意力模块,得到Value向量的第二权重分布;
将Value向量的第二权重分布和第一归一化后的向量输入到残差连接层进行残差连接,并将残差连接的结果输入到归一化层进行归一化处理,得到第二归一化后的向量;
将第二归一化后的向量输入前馈神经网络,得到降维后的向量;
将降维后的向量和第二归一化后的向量输入到残差连接层进行残差连接,并将残差连接的结果输入到归一化层进行归一化处理,得到第三归一化后的向量;
将第三归一化后的向量输入到全连接层网络,得到类别对应的轨迹预测结果。
在一个实施例中,基于车辆运动轨迹测试数据对多模态车辆轨迹预测模型进行测试,当多模态车辆轨迹预测模型满足测试要求时,训练结束,包括:
车辆运动轨迹测试数据包括多个类别的测试数据;将每个类别的测试数据分别输入到编码器模块中,得到每个类别对应的高维空间向量;
将每个类别的测试数据、对应的多个掩码矩阵和对应的高维空间向量输入到解码器模块中,得到每个类别对应的轨迹预测结果;
对每个类别对应的高维空间向量进行线性变换和Softmax函数处理,得到每个类别的分布概率;
计算分布概率最大的类别对应的轨迹预测结果与真实轨迹结果之间的误差;
若误差在设定范围内,则多模态车辆轨迹预测模型满足测试要求,训练结束。
第二方面,提供一种多模态车辆轨迹预测模型训练装置,包括:
数据类别划分单元,用于将车辆运动轨迹训练数据划分为多个类别,类别包括左偏移、右偏移和直行;
掩码矩阵获取单元,用于多模态车辆轨迹预测模型包括编码器模块和解码器模块;对每个类别的训练数据基于解码器模块进行权重迭代剪枝,得到每个类别对应的多个掩码矩阵;
训练单元,用于基于每个类别的训练数据、每个类别对应的多个掩码矩阵对编码器模块和解码器模块进行训练,得到训练后的多模态车辆轨迹预测模型,训练后的多模态车辆轨迹预测模型能够输出与多个类别对应的多个轨迹预测结果。
在一个实施例中,掩码矩阵获取单元,还用于:
初始化解码器模块的网络参数,采用车辆运动轨迹训练数据对解码器模块进行多轮预训练,得到基网络;
将每个类别的训练数据分别输入到基网络中,进行多轮预训练,得到每个类别对应的子网;
对每个类别对应的子网的多个全连接层的权重参数进行权重迭代剪枝,确定每个类别对应的多个掩码矩阵。
在一个实施例中,训练单元,还用于:
训练的过程进行多次,每次训练包括:将每个类别的训练数据分别输入到编码器模块中,得到每个类别对应的高维空间向量;将每个类别的训练数据、对应的多个掩码矩阵和对应的高维空间向量输入到解码器模块中,得到每个类别对应的轨迹预测结果;
每次训练结束后,基于车辆运动轨迹测试数据对多模态车辆轨迹预测模型进行测试,当多模态车辆轨迹预测模型满足测试要求时,训练结束,得到训练后的多模态车辆轨迹预测模型。
第三方面,提供一种多模态车辆轨迹预测方法,包括:
将待预测轨迹序列输入到多模态车辆轨迹预测模型中,输出多个轨迹预测结果;
多模态车辆轨迹预测模型为根据上述多模态车辆轨迹预测模型训练方法得到的。
相对于现有技术而言,本申请具有以下有益效果:
1、使用基于数据驱动的方式,充分考虑不同驾驶操作意图,结合车辆的历史运动轨迹状态信息和不同驾驶操作意图的特征,通过稀疏权重共享的解码器进行多模态轨迹预测,生成多个合理的预测轨迹。
2、设计权重迭代剪枝单元,对Transformer中的自注意力机制进行参数稀疏化处理,通过多次迭代剪枝为不同驾驶操作意图的数据生成不同的子网掩码,使得其结构与每个类别的数据集高维特征空间相适应,从而提取不同类别的轨迹的隐藏信息,能够极大程度的正确预测未来轨迹。
3、采用并行训练策略,将不同驾驶操作意图的数据集送入对应的子网中训练,实现了网络中部分参数只在特定子网下得到更新,在保证模型效率的前提下达到减少模型计算开销的目的,提升了轨迹预测的准确率,实现了在复杂交通场景下精确可靠的车辆轨迹预测。
附图说明
本申请可以通过参考下文中结合附图所给出的描述而得到更好的理解,附图连同下面的详细说明一起包含在本说明书中并且形成本说明书的一部分。在附图中:
图1示出了根据本申请实施例的多模态车辆轨迹预测模型训练方法的流程框图;
图2示出了根据本申请实施例的多模态车辆轨迹预测模型训练方法的原理图;
图3示出了根据本申请实施例的多模态车辆轨迹预测模型训练装置的结构框图;
图4示出了本申请的多模态车辆轨迹预测模型与现有模型得到的均方根误差对比结果图;
图5示出了车辆正向行驶且向左偏移的预测结果图;
图6示出了车辆反向行驶且向右偏移的预测结果图。
具体实施方式
在下文中将结合附图对本申请的示例性实施例进行描述。为了清楚和简明起见,在说明书中并未描述实际实施例的所有特征。然而,应该了解,在开发任何这种实际实施例的过程中可以做出很多特定于实施例的决定,以便实现开发人员的具体目标,并且这些决定可能会随着实施例的不同而有所改变。
在此,还需要说明的一点是,为了避免因不必要的细节而模糊了本申请,在附图中仅仅示出了与根据本申请的方案密切相关的装置结构,而省略了与本申请关系不大的其他细节。
应理解的是,本申请并不会由于如下参照附图的描述而只限于所描述的实施形式。在本文中,在可行的情况下,实施例可以相互组合、不同实施例之间的特征替换或借用、在一个实施例中省略一个或多个特征。
本申请实施例提供一种多模态车辆轨迹预测模型训练方法、装置及轨迹预测方法,在模型训练过程中,设计了基于无监督聚类的数据类别划分单元,通过车辆纵坐标位移之差,用聚类算法筛选出不同驾驶操作类别的数据集;为减小模型的计算量,提高运行效率,设计掩码矩阵获取单元,为解码器中的自注意力机制计算Q、K、V的全连接层进行参数的稀疏化处理,为多类别数据集生成对应的多个掩码矩阵,得到各类别子网络;基于Transformer的编码器模块和稀疏权重共享的解码器模块,生成的模态车辆轨迹预测模型能够输出多个合理的预测轨迹。
图1示出了根据本申请实施例的多模态车辆轨迹预测模型训练方法的流程框图,多模态车辆轨迹预测模型训练方法中,多模态车辆轨迹预测模型包括编码器模块和解码器模块,包括:
步骤S11,将车辆运动轨迹训练数据划分为多个类别,类别包括左偏移、右偏移和直行;
该步骤中,采用德国高速公路的车辆轨迹数据集(HighD数据集),HighD数据集包括来自6个地点、11.5小时的测量值,采样频率为25Hz,记录了11万辆车、里程45000公里的数据。
获取车辆轨迹数据集中一段时间内运动轨迹的终点和起点之间的纵坐标之差Δy,采用无监督学习中的K-Means聚类方法学习Δy的特征,将车辆轨迹数据集划分为多个类别,将每个类别的车辆轨迹数据划分为车辆运动轨迹训练数据和车辆运动轨迹测试数据。图2示出了根据本申请实施例的多模态车辆轨迹预测模型训练方法的原理图,步骤S11可以由图2中的数据处理模块实现。
步骤S12,对每个类别的训练数据基于解码器模块进行权重迭代剪枝,得到每个类别对应的多个掩码矩阵;步骤S12可以由图2中的权重迭代剪枝模块实现。
步骤S13,基于每个类别的训练数据、每个类别对应的掩码矩阵对编码器模块和解码器模块进行训练,得到训练后的多模态车辆轨迹预测模型,训练后的多模态车辆轨迹预测模型能够输出与多个类别对应的多个轨迹预测结果。步骤S12可以由图2中的编码器模块和解码器模块实现。
本申请实施例的多模态车辆轨迹预测模型训练方法,采用基于数据驱动的方式,充分考虑不同驾驶操作意图,结合车辆的历史运动轨迹状态信息和不同驾驶操作意图的特征,进行多模态车辆轨迹预测模型的训练,生成的模态车辆轨迹预测模型能够输出多个合理的预测轨迹。
传统的Transformer自注意力机制存在大量参数,但有些参数不仅对最终的输出结果贡献不大,还存在大量冗余,本申请的实施例使用权重迭代剪枝方法对Transformer解码器中的自注意力机制计算Query、Key和Value向量的全连接层进行参数的稀疏化处理,从而减少模型大小,提升运行速度,同时在一定程度上防止过拟合。具体的,在一个实施例中,步骤S12中,对每个类别的训练数据基于解码器模块进行权重迭代剪枝,得到每个类别对应的多个掩码矩阵,包括:
步骤S121,初始化解码器模块的网络参数,采用车辆运动轨迹训练数据对解码器模块进行多轮预训练,得到基网络;这里,假设该基网络是过参数化的,包含多个类别的解,初始化各类别的掩码矩阵Mc=1,其中c=1,2,3,表示类别编号,这里,掩码矩阵中的元素与解码器中的全连接层的多个权重参数相对应。
步骤S122,将每个类别的训练数据分别输入到基网络中,进行多轮预训练,得到每个类别对应的子网;
步骤S123,对每个类别对应的子网的多个全连接层的权重参数进行权重迭代剪枝,确定每个类别对应的多个掩码矩阵。
该步骤中,每个全连接层具有多个权重参数,对每个全连接层的权重参数的绝对值按照升序排序,修剪绝对值较小的前α%的权重参数,其中α为每次剪枝率,若全连接层的某个参数被修剪,则与该权重参数的掩码矩阵中的相应元素置为0。通过上述操作,左偏移、右偏移和直行中每个类别均对应得到多个掩码矩阵。
该实施例中,设计权重迭代剪枝单元,对Transformer中的自注意力机制进行参数稀疏化处理,通过多次迭代剪枝为不同驾驶操作意图的数据生成不同的子网掩码,使得其结构与每个类别的数据集高维特征空间相适应,从而提取不同类别的轨迹的隐藏信息,能够极大程度的正确预测未来轨迹。
在一个实施例中,步骤S13中,基于每个类别的训练数据、每个类别对应的多个掩码矩阵对编码器模块和解码器模块进行训练,得到训练后的多模态车辆轨迹预测模型,包括:
训练的过程进行多次,每次训练包括:将每个类别的训练数据分别输入到编码器模块中,得到每个类别对应的高维空间向量;将每个类别的训练数据、对应的多个掩码矩阵和对应的高维空间向量输入到解码器模块中,得到每个类别对应的轨迹预测结果;
每次训练结束后,基于车辆运动轨迹测试数据对多模态车辆轨迹预测模型进行测试,当多模态车辆轨迹预测模型满足测试要求时,训练结束,得到训练后的多模态车辆轨迹预测模型。
该实施例中,在训练过程中,针对每个类别的训练数据,先将某训练数据按照一定的数据量大小划分为多个训练子集,每次训练采用不同的训练子集进行训练。
具体地,上述实施例中,对编码器模块进行训练的具体过程中,将每个类别的训练数据分别输入到编码器模块中,这里,编码器模块包括依次连接的多头注意力模块、残差连接层、归一化层和前馈神经网络,得到每个类别对应的高维空间向量,可以包括:
步骤S131,针对每个类别,将训练数据输入到多头注意力模块,得到Value向量的权重分布;
该步骤中,多头注意力模块包括多个注意力单元,每个注意力单元包括3个全连接层,分别用于处理训练数据得到Query、Key和Value向量,具体采用的公式如下:
Q=WqX
K=WkX
V=WvX(1)
其中,Q、K、V分别表示Query、Key和Value向量,Wq,Wk,Wv分别表示3个全连接层的权重参数,X表示训练数据;
然后,采用Softmax函数对Query、Key和Value向量进行处理,得到每个注意力单元输出的Value向量的权重分布,具体采用的公式如下:
其中,dk表示Key向量的维度,attention(Q,K,V)表示每个注意力单元输出的Value向量的权重分布;
然后,将每个注意力单元输出的Value向量的权重分布进行拼接,得到多头注意力模块输出的Value向量的权重分布。
步骤S132,将Value向量的权重分布和训练数据输入到残差连接层进行残差连接,并将残差连接的结果输入到归一化层进行归一化处理,得到归一化后的向量;
步骤S133,将归一化后的向量输入到前馈神经网络,得到降维后的向量;
该步骤中,前馈神经网络的结构简单,由两个线性层和中间的Relu激活函数组成,对向量进行升维后再降维。
步骤S134,将降维后的向量和归一化后的向量输入到残差连接层进行残差连接,并将残差连接的结果输入到归一化层进行归一化处理,得到类别对应的高维空间向量。
该实施例中,网络捕获序列非线性的能力主要来自于注意模块,自注意力作为注意力的变体,其减少了对外部信息的依赖,更擅长捕捉数据或特征的内部相关性,从大量输入信息中筛选出少量重要信息,忽略其他信息,即通过计算向量间的相互影响,以解决长期依赖问题。
在一个实施例中,前述实施例中,将每个类别的训练数据、对应的多个掩码矩阵和对应的高维空间向量输入到解码器模块中,这里,解码器模块包括掩码多头注意力模块、多头注意力模块、残差连接层、归一化层和前馈神经网络、全连接层网络;得到每个类别对应的轨迹预测结果,具体可以包括:
步骤S141,针对每个类别,将训练数据、多个掩码矩阵输入到掩码多头注意力模块,得到Value向量的第一权重分布;
该步骤中,掩码多头注意力模块同样包括多个注意力单元,每个注意力单元包括3个全连接层,分别用于处理训练数据和该全连接层对应的掩码矩阵得到Query、Key和Value向量,具体采用的公式如下:
Q=Mq⊙Wq)X
K=(Mk⊙Wk)X
V=(Mv⊙Wv)X (3)
其中,Q、K、V分别表示Query、Key和Value向量,Wq,Wk,Wv分别表示3个全连接层的权重参数,Mq表示生成Query向量的全连接层对应的掩码矩阵,Mk表示生成Key向量的全连接层对应的掩码矩阵,Mv表示生成Value向量的全连接层对应的掩码矩阵;X表示训练数据。
然后,采用Softmax函数对Query、Key和Value向量进行处理,得到每个注意力单元输出的Value向量的权重分布,具体采用的公式如下:
其中,dk表示Key向量的维度,attention(Q,K,V)表示每个注意力单元输出的Value向量的权重分布;
然后,将每个注意力单元输出的Value向量的权重分布进行拼接,得到掩码多头注意力模块输出的Value向量的第一权重分布。
步骤S142,将Value向量的第一权重分布和训练数据输入到残差连接层进行残差连接,并将残差连接的结果输入到归一化层进行归一化处理,得到第一归一化后的向量;
步骤S143,将第一归一化后的向量、高维空间向量和多个掩码矩阵输入到多头注意力模块,得到Value向量的第二权重分布;
该步骤中,多头注意力模块包括多个注意力单元,每个注意力单元包括3个全连接层,分别用于处理第一归一化后的向量、步骤S134得到的高维空间向量和该全连接层对应的掩码矩阵得到Query、Key和Value向量,具体采用的公式如下:
Q=Mq⊙Wq)X归一化
K=Mk⊙Wk)enc
V=(Mv⊙Wv)enc (5)
其中,Q、K、V分别表示Query、Key和Value向量,Wq,Wk,Wv分别表示3个全连接层的权重参数,Mq表示生成Query向量的全连接层对应的掩码矩阵,Mk表示生成Key向量的全连接层对应的掩码矩阵,Mv表示生成Value向量的全连接层对应的掩码矩阵;X归一化表示第一归一化后的向量,enc表示高维空间向量。
然后,采用Softmax函数对Query、Key和Value向量进行处理,得到每个注意力单元输出的Value向量的权重分布,具体采用的公式如下:
其中,dk表示Key向量的维度,attention(Q,K,V)表示每个注意力单元输出的Value向量的权重分布;
然后,将每个注意力单元输出的Value向量的权重分布进行拼接,得到多头注意力模块输出的Value向量的第二权重分布。
步骤S144,将Value向量的第二权重分布和第一归一化后的向量输入到残差连接层进行残差连接,并将残差连接的结果输入到归一化层进行归一化处理,得到第二归一化后的向量;
步骤S145,将第二归一化后的向量输入前馈神经网络,得到降维后的向量;
步骤S146,将降维后的向量和第二归一化后的向量输入到残差连接层进行残差连接,并将残差连接的结果输入到归一化层进行归一化处理,得到第三归一化后的向量;
步骤S147,将第三归一化后的向量输入到全连接层网络,得到类别对应的轨迹预测结果。
在一个实施例中,基于车辆运动轨迹测试数据对多模态车辆轨迹预测模型进行测试,当多模态车辆轨迹预测模型满足测试要求时,训练结束,包括:
步骤S151,车辆运动轨迹测试数据包括多个类别的测试数据;将每个类别的测试数据分别输入到编码器模块中,得到每个类别对应的高维空间向量;该步骤的具体实现过程与步骤S131-步骤S134一致。
步骤S152,将每个类别的测试数据、对应的多个掩码矩阵和对应的高维空间向量输入到解码器模块中,得到每个类别对应的轨迹预测结果;该步骤的具体实现过程与步骤S141-步骤S147一致。
步骤S153,对每个类别对应的高维空间向量进行线性变换和Softmax函数处理,得到每个类别的分布概率P(c|x);具体可采用以下公式:
P(c|x)=softmax(wx+b)(7)
其中,x表示高维空间向量,w和b是线性变换中的可学习参数。
步骤S154,计算分布概率最大的类别对应的轨迹预测结果与真实轨迹结果之间的误差;这里可以计算均方根误差(RMSE)。
步骤S155,若误差在设定范围内,则多模态车辆轨迹预测模型满足测试要求,训练结束。
本申请实施例还提供一种多模态车辆轨迹预测方法,包括:
将待预测轨迹序列输入到多模态车辆轨迹预测模型中,输出多个轨迹预测结果;多模态车辆轨迹预测模型为根据上述实施例的多模态车辆轨迹预测模型训练方法得到的。
基于与多模态车辆轨迹预测模型训练方法相同的发明构思,本实施例还提供与之对应的多模态车辆轨迹预测模型训练装置,图3示出了根据本申请实施例的多模态车辆轨迹预测模型训练装置的结构框图,装置包括:
数据类别划分单元31,用于将车辆运动轨迹训练数据划分为多个类别,类别包括左偏移、右偏移和直行;
掩码矩阵获取单元32,用于多模态车辆轨迹预测模型包括编码器模块和解码器模块;对每个类别的训练数据基于解码器模块进行权重迭代剪枝,得到每个类别对应的多个掩码矩阵;
训练单元33,用于基于每个类别的训练数据、每个类别对应的掩码矩阵对编码器模块和解码器模块进行训练,得到训练后的多模态车辆轨迹预测模型,训练后的多模态车辆轨迹预测模型能够输出与多个类别对应的多个轨迹预测结果。
本申请实施例的多模态车辆轨迹预测模型训练装置,采用基于数据驱动的方式,充分考虑不同驾驶操作意图,结合车辆的历史运动轨迹状态信息和不同驾驶操作意图的特征,进行多模态车辆轨迹预测模型的训练,生成的模态车辆轨迹预测模型能够输出多个合理的预测轨迹。
在一个实施例中,掩码矩阵获取单元32,还用于:
初始化解码器模块的网络参数,采用车辆运动轨迹训练数据对解码器模块进行多轮预训练,得到基网络;
将每个类别的训练数据分别输入到基网络中,进行多轮预训练,得到每个类别对应的子网;
对每个类别对应的子网的多个全连接层的权重参数进行权重迭代剪枝,确定每个类别对应的多个掩码矩阵。
该实施例中,设计权重迭代剪枝单元,对Transformer中的自注意力机制进行参数稀疏化处理,通过多次迭代剪枝为不同驾驶操作意图的数据生成不同的子网掩码,使得其结构与每个类别的数据集高维特征空间相适应,从而提取不同类别的轨迹的隐藏信息,能够极大程度的正确预测未来轨迹。
在一个实施例中,训练单元33,还用于:
训练的过程进行多次,每次训练包括:将每个类别的训练数据分别输入到编码器模块中,得到每个类别对应的高维空间向量;将每个类别的训练数据、对应的多个掩码矩阵和对应的高维空间向量输入到解码器模块中,得到每个类别对应的轨迹预测结果;
每次训练结束后,基于车辆运动轨迹测试数据对多模态车辆轨迹预测模型进行测试,当多模态车辆轨迹预测模型满足测试要求时,训练结束,得到训练后的多模态车辆轨迹预测模型。
为验证本申请的有效性,采用公开的德国高速公路的大型自然车辆轨迹数据集HighD进行模型的训练和测试。该数据集包含来自6个地点、11.5小时的测量值,采样频率为25Hz,记录了11万辆车、里程45000公里的数据。处理后的数据集为左偏移、右偏移与保持直行状态的3种驾驶操作类别数据集,以40帧(即8秒)为一个试验样本,前15帧(即3秒)为历史轨迹序列Th=15,后25帧(即5秒)为未来轨迹序列Tf=25。
为了检测多模态车辆轨迹预测模型的效果,使用均方根误差(RMSE)来验证模型效果,RMSE越大,表示误差越大,均方根误差计算公式如下:
其中,Th、Tf为历史轨迹序列长度和未来轨迹序列长度,xt、yt为t时刻目标车辆的真实坐标,xpred t、ypred t为t时刻目标车辆的预测坐标。
本申请为了验证多模态车辆轨迹预测模型的性能,将本申请与现有模型ConvolutionalSocial-LSTM(CS-LSTM)和Social-GAN(S-GAN)1-5秒内RMSE指标进行了可视化,图4示出了本申请的多模态车辆轨迹预测模型与现有模型得到的均方根误差对比结果图,从图4中可以看出,本申请虽然在短期轨迹预测方面(1-2秒)与现有模型存在一定差距,但在长期轨迹预测(3-5秒)方面明显优于现有模型,且从RMSE增长的趋势来看,现有模型的RMSE随预测时域的增加呈指数增长趋势,而本申请的误差增长非常缓慢,保持在较低的水平上,预测时域越长,模型的预测精度越高。这反映了本申请的整体性能是优于现有提出的模型的。
同时,本申请为了获得定性结果,图5、图6展示了一些典型的预测结果,图5示出了车辆正向行驶且向左偏移的预测结果图,图6示出了车辆反向行驶且向右偏移的预测结果图,以反映模型在不同驾驶场景下的性能。图中左上角为坐标系原点,横轴为高速公路行驶方向,纵轴为垂直于高速公路行驶方向,黑色虚线为目标车辆的历史轨迹坐标,通过3种不同朝向的三角形表示模型3个子网(3种驾驶操作类别)的预测结果,并使用五角星标注出当前类别分布概率最大的预测轨迹,与圆形标注的真实未来轨迹形成对比。根据图5和图6,可以看出模型驾驶分类准确且分布概率最大的预测轨迹与真实未来轨迹基本重合。
综上,本申请具有以下有益效果:
1、使用基于数据驱动的方式,充分考虑不同驾驶操作意图,结合车辆的历史运动轨迹状态信息和不同驾驶操作意图的特征,通过稀疏权重共享的解码器进行多模态轨迹预测,生成多个合理的预测轨迹。
2、设计权重迭代剪枝模块,对Transformer中的自注意力机制进行参数稀疏化处理,通过多次迭代剪枝为不同驾驶操作意图的数据生成不同的子网掩码,使得其结构与每个类别的数据集高维特征空间相适应,从而提取不同类别的轨迹的隐藏信息,能够极大程度的正确预测未来轨迹。
3、采用并行训练策略,将不同驾驶操作意图的数据集送入对应的子网中训练,实现了网络中部分参数只在特定子网下得到更新,在保证模型效率的前提下达到减少模型计算开销的目的,提升了轨迹预测的准确率,实现了在复杂交通场景下精确可靠的车辆轨迹预测。
以上所述,仅为本申请的各种实施方式,但本申请的保护范围并不局限于此,任何熟悉本技术领域的技术人员在本申请揭露的技术范围内,可轻易想到变化或替换,都应涵盖在本申请的保护范围之内。因此,本申请的保护范围应以所述权利要求的保护范围为准。
Claims (10)
1.一种多模态车辆轨迹预测模型训练方法,其特征在于,包括:
将车辆运动轨迹训练数据划分为多个类别,所述类别包括左偏移、右偏移和直行;
所述多模态车辆轨迹预测模型包括编码器模块和解码器模块;对每个类别的训练数据基于所述解码器模块进行权重迭代剪枝,得到每个类别对应的多个掩码矩阵;
基于所述每个类别的训练数据、所述每个类别对应的多个掩码矩阵对所述编码器模块和解码器模块进行训练,得到训练后的多模态车辆轨迹预测模型,所述训练后的多模态车辆轨迹预测模型能够输出与所述多个类别对应的多个轨迹预测结果。
2.如权利要求1所述的方法,其特征在于,其中,对每个类别的训练数据基于所述解码器模块进行权重迭代剪枝,得到每个类别对应的多个掩码矩阵,包括:
初始化所述解码器模块的网络参数,采用所述车辆运动轨迹训练数据对所述解码器模块进行多轮预训练,得到基网络;
将每个类别的训练数据分别输入到所述基网络中,进行多轮预训练,得到每个类别对应的子网;
对所述每个类别对应的子网的多个全连接层的权重参数进行权重迭代剪枝,确定每个类别对应的多个掩码矩阵。
3.如权利要求1所述的方法,其特征在于,其中,基于所述每个类别的训练数据、所述每个类别对应的多个掩码矩阵对所述编码器模块和解码器模块进行训练,得到训练后的多模态车辆轨迹预测模型,包括:
所述训练的过程进行多次,每次训练包括:将每个类别的训练数据分别输入到所述编码器模块中,得到每个类别对应的高维空间向量;将每个类别的训练数据、对应的多个掩码矩阵和对应的高维空间向量输入到所述解码器模块中,得到所述每个类别对应的轨迹预测结果;
每次训练结束后,基于车辆运动轨迹测试数据对所述多模态车辆轨迹预测模型进行测试,当所述多模态车辆轨迹预测模型满足测试要求时,训练结束,得到训练后的多模态车辆轨迹预测模型。
4.如权利要求3所述的方法,其特征在于,其中,将每个类别的训练数据分别输入到所述编码器模块中,得到每个类别对应的高维空间向量,包括:
所述编码器模块包括依次连接的多头注意力模块、残差连接层、归一化层和前馈神经网络;
针对每个类别,将所述训练数据输入到所述多头注意力模块,得到Value向量的权重分布;
将所述Value向量的权重分布和所述训练数据输入到所述残差连接层进行残差连接,并将所述残差连接的结果输入到所述归一化层进行归一化处理,得到归一化后的向量;
将所述归一化后的向量输入到所述前馈神经网络,得到降维后的向量;
将所述降维后的向量和所述归一化后的向量输入到所述残差连接层进行残差连接,并将所述残差连接的结果输入到所述归一化层进行归一化处理,得到所述类别对应的高维空间向量。
5.如权利要求3所述的方法,其特征在于,其中,将每个类别的训练数据、对应的多个掩码矩阵和对应的高维空间向量输入到所述解码器模块中,得到所述每个类别对应的轨迹预测结果,包括:
所述解码器模块包括掩码多头注意力模块、多头注意力模块、残差连接层、归一化层和前馈神经网络、全连接层网络;
针对每个类别,将所述训练数据、所述多个掩码矩阵输入到所述掩码多头注意力模块,得到Value向量的第一权重分布;
将所述Value向量的第一权重分布和所述训练数据输入到所述残差连接层进行残差连接,并将所述残差连接的结果输入到所述归一化层进行归一化处理,得到第一归一化后的向量;
将所述第一归一化后的向量、所述高维空间向量和所述多个掩码矩阵输入到所述多头注意力模块,得到Value向量的第二权重分布;
将所述Value向量的第二权重分布和所述第一归一化后的向量输入到所述残差连接层进行残差连接,并将所述残差连接的结果输入到所述归一化层进行归一化处理,得到第二归一化后的向量;
将所述第二归一化后的向量输入所述前馈神经网络,得到降维后的向量;
将所述降维后的向量和所述第二归一化后的向量输入到所述残差连接层进行残差连接,并将所述残差连接的结果输入到所述归一化层进行归一化处理,得到第三归一化后的向量;
将所述第三归一化后的向量输入到所述全连接层网络,得到所述类别对应的轨迹预测结果。
6.如权利要求3所述的方法,其特征在于,基于车辆运动轨迹测试数据对所述多模态车辆轨迹预测模型进行测试,当所述多模态车辆轨迹预测模型满足测试要求时,训练结束,包括:
所述车辆运动轨迹测试数据包括多个类别的测试数据;将每个类别的测试数据分别输入到所述编码器模块中,得到每个类别对应的高维空间向量;
将每个类别的测试数据、对应的多个掩码矩阵和对应的高维空间向量输入到所述解码器模块中,得到所述每个类别对应的轨迹预测结果;
对所述每个类别对应的高维空间向量进行线性变换和Softmax函数处理,得到每个类别的分布概率;
计算所述分布概率最大的类别对应的轨迹预测结果与真实轨迹结果之间的误差;
若所述误差在设定范围内,则所述多模态车辆轨迹预测模型满足测试要求,训练结束。
7.一种多模态车辆轨迹预测模型训练装置,其特征在于,包括:
数据类别划分单元,用于将车辆运动轨迹训练数据划分为多个类别,所述类别包括左偏移、右偏移和直行;
掩码矩阵获取单元,用于所述多模态车辆轨迹预测模型包括编码器模块和解码器模块;对每个类别的训练数据基于所述解码器模块进行权重迭代剪枝,得到每个类别对应的多个掩码矩阵;
训练单元,用于基于所述每个类别的训练数据、所述每个类别对应的多个掩码矩阵对所述编码器模块和解码器模块进行训练,得到训练后的多模态车辆轨迹预测模型,所述训练后的多模态车辆轨迹预测模型能够输出与所述多个类别对应的多个轨迹预测结果。
8.如权利要求7所述的装置,其特征在于,所述掩码矩阵获取单元,还用于:
初始化所述解码器模块的网络参数,采用所述车辆运动轨迹训练数据对所述解码器模块进行多轮预训练,得到基网络;
将每个类别的训练数据分别输入到所述基网络中,进行多轮预训练,得到每个类别对应的子网;
对所述每个类别对应的子网的多个全连接层的权重参数进行权重迭代剪枝,确定每个类别对应的多个掩码矩阵。
9.如权利要求7所述的装置,其特征在于,所述训练单元,还用于:
所述训练的过程进行多次,每次训练包括:将每个类别的训练数据分别输入到所述编码器模块中,得到每个类别对应的高维空间向量;将每个类别的训练数据、对应的多个掩码矩阵和对应的高维空间向量输入到所述解码器模块中,得到所述每个类别对应的轨迹预测结果;
每次训练结束后,基于车辆运动轨迹测试数据对所述多模态车辆轨迹预测模型进行测试,当所述多模态车辆轨迹预测模型满足测试要求时,训练结束,得到训练后的多模态车辆轨迹预测模型。
10.一种多模态车辆轨迹预测方法,其特征在于,包括:
将待预测轨迹序列输入到多模态车辆轨迹预测模型中,输出多个轨迹预测结果;
所述多模态车辆轨迹预测模型为根据权利要求1-6中任一权利要求的多模态车辆轨迹预测模型训练方法得到的。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202211490477.3A CN115730637A (zh) | 2022-11-25 | 2022-11-25 | 多模态车辆轨迹预测模型训练方法、装置及轨迹预测方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202211490477.3A CN115730637A (zh) | 2022-11-25 | 2022-11-25 | 多模态车辆轨迹预测模型训练方法、装置及轨迹预测方法 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN115730637A true CN115730637A (zh) | 2023-03-03 |
Family
ID=85298287
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202211490477.3A Pending CN115730637A (zh) | 2022-11-25 | 2022-11-25 | 多模态车辆轨迹预测模型训练方法、装置及轨迹预测方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN115730637A (zh) |
Cited By (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116226787A (zh) * | 2023-05-04 | 2023-06-06 | 中汽信息科技(天津)有限公司 | 商用车出险概率预测方法、设备和介质 |
CN116469041A (zh) * | 2023-06-20 | 2023-07-21 | 成都理工大学工程技术学院 | 一种目标对象的运动轨迹预测方法、系统及设备 |
CN116680656A (zh) * | 2023-07-31 | 2023-09-01 | 合肥海普微电子有限公司 | 基于生成型预训练变换器的自动驾驶运动规划方法及系统 |
CN117333847A (zh) * | 2023-12-01 | 2024-01-02 | 山东科技大学 | 一种基于车辆行为识别的轨迹预测方法及系统 |
-
2022
- 2022-11-25 CN CN202211490477.3A patent/CN115730637A/zh active Pending
Cited By (7)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116226787A (zh) * | 2023-05-04 | 2023-06-06 | 中汽信息科技(天津)有限公司 | 商用车出险概率预测方法、设备和介质 |
CN116469041A (zh) * | 2023-06-20 | 2023-07-21 | 成都理工大学工程技术学院 | 一种目标对象的运动轨迹预测方法、系统及设备 |
CN116469041B (zh) * | 2023-06-20 | 2023-09-19 | 成都理工大学工程技术学院 | 一种目标对象的运动轨迹预测方法、系统及设备 |
CN116680656A (zh) * | 2023-07-31 | 2023-09-01 | 合肥海普微电子有限公司 | 基于生成型预训练变换器的自动驾驶运动规划方法及系统 |
CN116680656B (zh) * | 2023-07-31 | 2023-11-07 | 合肥海普微电子有限公司 | 基于生成型预训练变换器的自动驾驶运动规划方法及系统 |
CN117333847A (zh) * | 2023-12-01 | 2024-01-02 | 山东科技大学 | 一种基于车辆行为识别的轨迹预测方法及系统 |
CN117333847B (zh) * | 2023-12-01 | 2024-03-15 | 山东科技大学 | 一种基于车辆行为识别的轨迹预测方法及系统 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN115730637A (zh) | 多模态车辆轨迹预测模型训练方法、装置及轨迹预测方法 | |
CN110889546B (zh) | 一种基于注意力机制的交通流量模型训练方法 | |
CN110162018B (zh) | 基于知识蒸馏与隐含层共享的增量式设备故障诊断方法 | |
CN109639739B (zh) | 一种基于自动编码器网络的异常流量检测方法 | |
CN110097755B (zh) | 基于深度神经网络的高速公路交通流量状态识别方法 | |
CN113487061A (zh) | 一种基于图卷积-Informer模型的长时序交通流量预测方法 | |
CN112884059B (zh) | 一种融合先验知识的小样本雷达工作模式分类方法 | |
CN110991471B (zh) | 一种高速列车牵引系统故障诊断方法 | |
Jin et al. | Transformer-based map-matching model with limited labeled data using transfer-learning approach | |
CN115859077A (zh) | 一种变工况下多特征融合的电机小样本故障诊断方法 | |
CN116415200A (zh) | 一种基于深度学习的异常车辆轨迹异常检测方法及系统 | |
CN116192500A (zh) | 一种对抗标签噪声的恶意流量检测装置及方法 | |
Zheng et al. | Real‐time driving style classification based on short‐term observations | |
Lee et al. | Probing the purview of neural networks via gradient analysis | |
Sun et al. | HRRP target recognition based on soft-boundary deep SVDD with LSTM | |
CN116946183A (zh) | 一种考虑驾驶能力的商用车驾驶行为预测方法及车用设备 | |
CN107229944B (zh) | 基于认知信息粒子的半监督主动识别方法 | |
Zhang et al. | CAGFuzz: Coverage-guided adversarial generative fuzzing testing of deep learning systems | |
CN116166642A (zh) | 基于引导信息的时空数据填补方法、系统、设备及介质 | |
CN115421029A (zh) | 一种fcm-ga-pnn的模拟电路故障诊断方法 | |
CN114328921A (zh) | 一种基于分布校准的小样本实体关系抽取方法 | |
Kim et al. | Line-Post Insulator Fault Classification Model Using Deep Convolutional GAN-Based Synthetic Images | |
Behnia et al. | Deep generative models for vehicle speed trajectories | |
Li et al. | Research on abnormal monitoring of vehicle traffic network data based on support vector machine | |
CN114328791B (zh) | 一种基于深度学习的地图匹配算法 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |