CN116956098A - 一种基于感知分布式对比学习框架的长尾轨迹预测方法 - Google Patents
一种基于感知分布式对比学习框架的长尾轨迹预测方法 Download PDFInfo
- Publication number
- CN116956098A CN116956098A CN202311222987.7A CN202311222987A CN116956098A CN 116956098 A CN116956098 A CN 116956098A CN 202311222987 A CN202311222987 A CN 202311222987A CN 116956098 A CN116956098 A CN 116956098A
- Authority
- CN
- China
- Prior art keywords
- track
- gate
- cluster
- network
- tail
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000000034 method Methods 0.000 title claims abstract description 41
- 230000008447 perception Effects 0.000 title claims abstract description 9
- 239000013598 vector Substances 0.000 claims description 39
- 239000011159 matrix material Substances 0.000 claims description 29
- 230000006870 function Effects 0.000 claims description 18
- 238000012549 training Methods 0.000 claims description 6
- 238000001914 filtration Methods 0.000 claims description 3
- 238000009499 grossing Methods 0.000 claims description 3
- 238000013507 mapping Methods 0.000 claims description 3
- 238000010606 normalization Methods 0.000 claims description 3
- 238000012545 processing Methods 0.000 description 7
- 230000009286 beneficial effect Effects 0.000 description 6
- 238000013461 design Methods 0.000 description 4
- 230000008569 process Effects 0.000 description 4
- 238000013459 approach Methods 0.000 description 2
- 230000008901 benefit Effects 0.000 description 2
- 238000012986 modification Methods 0.000 description 2
- 230000004048 modification Effects 0.000 description 2
- 238000011160 research Methods 0.000 description 2
- 238000013528 artificial neural network Methods 0.000 description 1
- 238000004364 calculation method Methods 0.000 description 1
- 230000007547 defect Effects 0.000 description 1
- 230000001419 dependent effect Effects 0.000 description 1
- 238000011161 development Methods 0.000 description 1
- 238000010586 diagram Methods 0.000 description 1
- 230000000694 effects Effects 0.000 description 1
- 238000011156 evaluation Methods 0.000 description 1
- 230000003993 interaction Effects 0.000 description 1
- 230000007774 longterm Effects 0.000 description 1
- 238000005457 optimization Methods 0.000 description 1
- 230000000306 recurrent effect Effects 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/23—Clustering techniques
- G06F18/232—Non-hierarchical techniques
- G06F18/2321—Non-hierarchical techniques using statistics or function optimisation, e.g. modelling of probability density functions
- G06F18/23213—Non-hierarchical techniques using statistics or function optimisation, e.g. modelling of probability density functions with fixed number of clusters, e.g. K-means clustering
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/213—Feature extraction, e.g. by transforming the feature space; Summarisation; Mappings, e.g. subspace methods
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/044—Recurrent networks, e.g. Hopfield networks
- G06N3/0442—Recurrent networks, e.g. Hopfield networks characterised by memory or gating, e.g. long short-term memory [LSTM] or gated recurrent units [GRU]
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/0464—Convolutional networks [CNN, ConvNet]
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/048—Activation functions
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- Evolutionary Computation (AREA)
- Life Sciences & Earth Sciences (AREA)
- Artificial Intelligence (AREA)
- General Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- General Health & Medical Sciences (AREA)
- Molecular Biology (AREA)
- Software Systems (AREA)
- Mathematical Physics (AREA)
- Computing Systems (AREA)
- Health & Medical Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Evolutionary Biology (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Bioinformatics & Computational Biology (AREA)
- Probability & Statistics with Applications (AREA)
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
Abstract
本发明公开了一种基于感知分布式对比学习框架的长尾轨迹预测方法,涉及自动驾驶轨迹预测领域,本发明通过离线轨迹聚类模块和原型对比学习,识别和分离出数据集中的不同轨迹模式,将同一模式群中的轨迹的特征拉近,而将不同模式群中的特征推远,通过分布感知的预测器,为具有不同模式的轨迹输入提供分离的解码器参数,即根据输入轨迹的模式,使用不同的参数进行解码,从而更准确地预测出未来的轨迹。本发明解决了现有方法忽视数据集中的长尾现象,以及尾部样本的特征被错误编码导致预测结果不准确的问题。
Description
技术领域
本发明涉及自动驾驶轨迹预测领域,特别是涉及一种基于感知分布式对比学习框架的长尾轨迹预测方法。
背景技术
在自动驾驶场景中,轨迹预测的重要性不言而喻。这种预测的目标是根据过去观察到的轨迹,为道路上的参与者预测出一系列可能的未来位置。这项任务对于保证安全和高效的自动驾驶具有关键的作用。在轨迹预测的研究领域中,最近几年已经出现了许多新的方法和技术。这些方法在预测的模态方面有所不同,包括单模态和多模态两种类型。单模态预测主要关注预测出一种最可能的未来轨迹,而多模态预测则试图预测出所有可能的未来轨迹。无论是单模态还是多模态预测,都有一系列的方法和技术可供选择。这些新的预测方法无疑为自动驾驶的发展提供了更多可能性,也为未来的研究开启了新的道路。
尽管现有的预测方法在精度上已经取得了相当高的成绩,但在训练和评估阶段,这些方法大多将数据集中的样本平等对待,这就忽视了一个现象,即数据集中的长尾现象。当尾部样本的特征被错误编码时,模型在进行预测时就可能产生误导,导致预测结果的准确性大大降低。
发明内容
针对现有技术中的上述不足,本发明提供的一种基于感知分布式对比学习框架的长尾轨迹预测方法解决了现有方法忽视数据集中的长尾现象,以及尾部样本的特征被错误编码导致预测结果不准确的问题。
为了达到上述发明目的,本发明采用的技术方案为:一种基于感知分布式对比学习框架的长尾轨迹预测方法,包括以下步骤:
S1:利用轨迹特征提取器提取历史轨迹数据的特征;
S2:利用离散聚类模块对提取的特征进行聚类,将历史轨迹数据划分为多个轨迹模式;
S3:基于划分的多个轨迹模式,在基线预测网络的历史编码特征上,根据离散聚类模块生成的伪标签进行原型对比学习,使轨迹编码器的特征空间根据不同轨迹模式分别进行聚类;
S4:基于轨迹编码器的聚类结果,利用分布感知的预测器,为不同的轨迹输入生成不同的解码器权重,将尾部轨迹模式和头部轨迹模式以不同的方式进行预测,实现长尾轨迹的预测。
上述方案的有益效果是:本发明提供了一种对比学习框架,能够识别并分离出长尾数据中的不同运动模式,并在特征空间中形成独立的模式群,这使得可以更准确地预测出未来的轨迹,而不是单一地依赖于大多数数据,同时提出了一种分布感知的预测器,根据特征空间的分布情况进行预测,从而提高对罕见轨迹的预测准确性,解决了现有方法忽视数据集中的长尾现象,以及尾部样本的特征被错误编码导致预测结果不准确的问题。
进一步地,S1中轨迹特征提取器为:在1D卷积网络的基础上增加LSTM网络。
上述进一步方案的有益效果是:使用一个附加了LSTM的1D卷积网络作为轨迹特征提取器,进行轨迹的编码和重构。
进一步地,S2中利用离散聚类模块对提取的特征进行聚类,包括以下分步骤:
S2-1:将历史轨迹数据按照时间划分为多个时间段,且每个时间段包含多个轨迹点;
S2-2:将每个时间段中的轨迹点按照空间位置进行聚类,得到多个轨迹簇;
S2-3:计算每个轨迹簇的质心,并作为当前轨迹簇的代表点;
S2-4:将所有的代表点作为输入,使用K-means算法进行聚类,得到多个聚类簇;
S2-5:将每个聚类簇的代表点作为聚类簇的原型,完成对历史轨迹数据轨迹模式的划分。
上述进一步方案的有益效果是:如果采用在线聚类的方式,将需要消耗大量的计算资源和时间。因此选择使用离线聚类模块,以提高处理效率,同时保持预测性能,同时选用K-means算法,具有计算效率高的优点。
进一步地,S3中原型对比学习采用多级聚类方法计算损失,公式为:
其中,为原型对比学习损失函数,/>为实例间的对比项,/>为实例-原型的对比项;
其中,为轨迹实例,/>为批次大小,/>为在批次中与轨迹实例/>为正样本的数量,/>为正样本,/>为轨迹实例在编码器后的特征嵌入,/>为正样本在编码器后的特征嵌入,/>为实例间对比项的对比温度,/>为当前批次数据中的任意样本在编码器后的特征嵌入,/>为K-means聚类层次的数量,/>为第/>层簇的数量,/>为当前批次数据中的任意样本,/>为实例所属的簇的原型,/>为任意簇的原型,/>为实例所属的簇的密度,/>为任意簇的密度,/>为簇中的实例数量,/>为簇中的实例,/>为/>的动量更新特征,/>为簇的原型,为平滑因子,/>为范数,/>为以/>为底的指数函数。
上述进一步方案的有益效果是:计算原型对比学习损失时,采用具有多层次的多级聚类方法,这种设计使的预测模型可以更好的区分不同的轨迹模式,从而提高了预测的准确性,同时也使得模型具有了处理复杂轨迹数据的能力。
进一步地,S4中分布感知的预测器中将LSTM网络作为轨迹解码器,所述LSTM网络的公式为:
其中,为LSTM网络的输入门,/>为LSTM网络的更新门,/>为LSTM网络的遗忘门,为LSTM网络的输出门,/>为时间步/>的隐藏状态向量,/>为时间步/>的输入向量,/>为输入门隐藏状态权重矩阵,/>为输入门输入权重矩阵,/>为输入门偏置向量,/>为更新门隐藏状态权重矩阵,/>为更新门输入权重矩阵,/>为更新门偏置向量,/>为遗忘门隐藏状态权重矩阵,/>为遗忘门输入权重矩阵,/>为遗忘门偏置向量,/>为输出门隐藏状态权重矩阵,/>为输出门输入权重矩阵,/>为输出门偏置向量,/>为时间步/>的单元状态,为时间步/>的隐藏状态向量,/>为sigmoid操作符,/>为元素级乘积运算,/>为时间步的单元状态,/>为双曲正切函数;
则带有超网络的LSTM网络公式为:
其中,为原始LSTM网络中的任意一个门,/>为层归一化,/>为隐藏状态权重调整向量,/>为任意一个门的隐藏状态权重矩阵,/>为输入权重调整向量,/>为任意一个门的输入权重矩阵,/>为任意一个门的偏置向量,/>为超网络输出向量;
对于具有的轨迹实例/>,使用/>表示为:
其中,为超网络的映射函数,/>为对于具有/>的轨迹实例/>的超网络输出向量;
利用带有超网络的LSTM网络,为不同的轨迹输入生成不同的解码器权重,将尾部轨迹模式和头部轨迹模式以不同的方式进行预测。
上述进一步方案的有益效果是:使用LSTM网络作为轨迹解码器,LSTM网络是一种特殊的递归神经网络,能够在处理时间序列数据时捕捉长期依赖关系,LSTM的关键是其内部的记忆单元,能够在过程中存储信息,用于后续的预测。在轨迹预测中,LSTM可以用来解码或生成预测的轨迹,可有效地处理时间序列数据。而使用超网络解码器可以对尾部簇进行不同的预测。
进一步地,S4生成不同的解码器权重,公式为:
其中,为最终的损失函数,/>为基线预测网络损失函数,/>为原型对比学习损失项的系数,/>为常量超参数,/>为用于过滤头部样本的阈值,/>为预热阶段训练后的网络预测损失。
上述进一步方案的有益效果是:对于网络已经能够拟合较好的简单样本,原型对比学习损失在网络优化过程中无法带来更多的优势,因此通过改变原型对比学习损失项的系数,可以关闭在简单样本上的原型对比学习损失。
附图说明
图1为一种基于感知分布式对比学习框架的长尾轨迹预测方法流程图。
图2为FEND框架的时序图。
具体实施方式
下面结合附图和具体实施例对本发明做进一步说明。
如图1所示,一种基于感知分布式对比学习框架的长尾轨迹预测方法,包括以下步骤:
S1:利用轨迹特征提取器提取历史轨迹数据的特征;
S2:利用离散聚类模块对提取的特征进行聚类,将历史轨迹数据划分为多个轨迹模式;
S3:基于划分的多个轨迹模式,在基线预测网络的历史编码特征上,根据离散聚类模块生成的伪标签进行原型对比学习,使轨迹编码器的特征空间根据不同轨迹模式分别进行聚类;
S4:基于轨迹编码器的聚类结果,利用分布感知的预测器,为不同的轨迹输入生成不同的解码器权重,将尾部轨迹模式和头部轨迹模式以不同的方式进行预测,实现长尾轨迹的预测。
本发明提出了一种长尾轨迹预测框架(FEND),FEND框架主要包含两个部分。第一部分是一个未来增强的对比学习方法。这个方法的主要目标是通过提升特征嵌入的质量,来优化轨迹编码器的性能。通过这种方式,能够更好地理解和预测那些罕见但重要的轨迹模式。第二部分是一个灵活的分布感知预测器。该预测器的设计目标是减轻头部样本对尾部样本的影响。这意味着我们的模型可以根据输入轨迹的模式,使用不同的参数进行解码,从而更准确地预测出未来的轨迹。通过这两个主要的组件,FEND框架不仅能够更好地处理长尾现象,同时也能提高轨迹预测的整体性能,为解决自动驾驶中的轨迹预测问题提供了新的解决方案。
S1中轨迹特征提取器为:在1D卷积网络的基础上增加LSTM网络。
S2中利用离散聚类模块对提取的特征进行聚类,包括以下分步骤:
S2-1:将历史轨迹数据按照时间划分为多个时间段,且每个时间段包含多个轨迹点;
S2-2:将每个时间段中的轨迹点按照空间位置进行聚类,得到多个轨迹簇;
S2-3:计算每个轨迹簇的质心,并作为当前轨迹簇的代表点;
S2-4:将所有的代表点作为输入,使用K-means算法进行聚类,得到多个聚类簇;
S2-5:将每个聚类簇的代表点作为聚类簇的原型,完成对历史轨迹数据轨迹模式的划分。
轨迹聚类步骤的执行结果是我们已经得到了群集标签。这些标签在后续步骤中被作为伪标签使用,以计算原型和密度。需要指出的是,原始的原型对比学习(PCL)是一种自我监督的方法,其包含了期望最大化(EM)步骤,因此需要在每个训练周期之前进行聚类。然而,本方案通过使用伪标签来减少聚类步骤,因此相比于原始的PCL,本方案的方法可以节省更多的计算资源。
给定伪群标签后,利用PCL将属于同一群的实例的特征拉在一起,并将不同群的实例的特征推开。这与普通对比学习处理正样本和负样本的方式类似,但在处理长尾问题上,这种方法可以更好地区分不同的轨迹模式,从而提高预测的准确性。
本方案在编码器-解码器轨迹预测网络的瓶颈处执行了原型对比学习,对编码器后面添加了一个全连接(FC)层,并将PCL损失添加到了这个FC层的输出特征中。需要注意的是,这个FC层之前的特征会被传递给轨迹解码器。
S3中原型对比学习采用多级聚类方法计算损失,公式为:
其中,为原型对比学习损失函数,/>为实例间的对比项,/>为实例-原型的对比项;
其中,为轨迹实例,/>为批次大小,/>为在批次中与轨迹实例/>为正样本的数量,/>为正样本,/>为轨迹实例在编码器后的特征嵌入,/>为正样本在编码器后的特征嵌入,/>为实例间对比项的对比温度,/>为当前批次数据中的任意样本在编码器后的特征嵌入,/>为K-means聚类层次的数量,/>为第/>层簇的数量,/>为当前批次数据中的任意样本,/>为实例所属的簇的原型,/>为任意簇的原型,/>为实例所属的簇的密度,/>为任意簇的密度,/>为簇中的实例数量,/>为簇中的实例,/>为/>的动量更新特征,/>为簇的原型,为平滑因子,/>为范数,/>为以/>为底的指数函数。
在计算PCL损失时,采用了具有M层次的多级聚类。这种设计使得本方案的模型可以更好地区分不同的轨迹模式,从而提高了预测的准确性。同时,由于在计算PCL损失时采用了多级聚类,这也使得模型具有了处理复杂轨迹数据的能力。
S4中分布感知的预测器中将LSTM网络作为轨迹解码器,所述LSTM网络的公式为:
其中,为LSTM网络的输入门,/>为LSTM网络的更新门,/>为LSTM网络的遗忘门,为LSTM网络的输出门,/>为时间步/>的隐藏状态向量,/>为时间步/>的输入向量,/>为输入门隐藏状态权重矩阵,/>为输入门输入权重矩阵,/>为输入门偏置向量,/>为更新门隐藏状态权重矩阵,/>为更新门输入权重矩阵,/>为更新门偏置向量,/>为遗忘门隐藏状态权重矩阵,/>为遗忘门输入权重矩阵,/>为遗忘门偏置向量,/>为输出门隐藏状态权重矩阵,/>为输出门输入权重矩阵,/>为输出门偏置向量,/>为时间步/>的单元状态,为时间步/>的隐藏状态向量,/>为sigmoid操作符,/>为元素级乘积运算,/>为时间步的单元状态,/>为双曲正切函数;
则带有小型超网络的LSTM网络公式为:
其中,为原始LSTM网络中的任意一个门,/>为层归一化,/>为隐藏状态权重调整向量,/>为任意一个门的隐藏状态权重矩阵,/>为输入权重调整向量,/>为任意一个门的输入权重矩阵,/>为任意一个门的偏置向量,/>为超网络输出向量;
对于具有的轨迹实例/>,使用/>表示为:
其中,为小型超网络的映射函数,/>为对于具有/>的轨迹实例/>的超网络输出向量;
利用带有小型超网络的LSTM网络,为不同的轨迹输入生成不同的解码器权重,将尾部轨迹模式和头部轨迹模式以不同的方式进行预测。
对于头部簇和尾部簇应该使用不同的解码器进行处理,以此来减弱他们之间的相互影响。然而,由于尾部样本的数量较少,如果为其独立训练解码器,可能会导致严重的过拟合问题。因此,本发明提出一种解决方案,即在整个数据集中分享通用的知识,同时保持每个解码器的建模灵活性。这种需求自然可以通过使用超网络来实现,超网络是一种使用较小的网络(即超网络)生成主网络权重的方法。
超网络能够包含整个样本集的知识,这可以有效防止过拟合。同时,由于头部和尾部簇都有自己独立的解码器参数,因此解码器能够意识到聚类特征空间的分布,因此,超网络解码器可以对尾部簇进行不同的预测。
S4生成不同的解码器权重,公式为:
其中,为最终的损失函数,/>为基线预测网络损失函数,/>为原型对比学习损失项的系数,/>为常量超参数,/>为用于过滤头部样本的阈值,/>为预热阶段训练后的网络预测损失。
在本发明的一个实施例中,如图2所示,基于设计的FEND框架,输入数据包括历史轨迹数据和当前帧的场景信息;通过编码器将历史轨迹数据和场景信息编码成特征向量;通过对比学习的方式,将同一轨迹的不同部分的特征向量进行对比,使得相似的部分在特征空间中更加接近,不相似的部分在特征空间中更加远离;通过解码器将对比后的特征向量解码成轨迹预测结果;计算预测结果与真实轨迹之间的误差,并通过反向传播更新模型参数;重复以上步骤,直到模型收敛或达到预设的训练轮数。
本发明提出了一个未来增强的对比特征学习框架,专门用于处理长尾轨迹预测的问题,可以更好地区分出数据集中尾部模式和头部模式。这是通过代表不同模式的群聚原型来实现的,这样做的目的是为了增强模型对尾部数据的建模能力。通过这种方式,能够更好地理解和预测罕见但重要的轨迹模式。其次,本发明提出了一个名为分布感知的预测器。这个预测器的设计目标是为具有不同模式的轨迹输入提供分离的解码器参数,说明本发明的模型可以根据输入轨迹的模式,使用不同的参数进行解码,从而更准确地预测出未来的轨迹。
本领域的普通技术人员将会意识到,这里所述的实施例是为了帮助读者理解本发明的原理,应被理解为本发明的保护范围并不局限于这样的特别陈述和实施例。本领域的普通技术人员可以根据本发明公开的这些技术启示做出各种不脱离本发明实质的其它各种具体变形和组合,这些变形和组合仍然在发明的保护范围内。
Claims (6)
1.一种基于感知分布式对比学习框架的长尾轨迹预测方法,其特征在于,包括以下步骤:
S1:利用轨迹特征提取器提取历史轨迹数据的特征;
S2:利用离散聚类模块对提取的特征进行聚类,将历史轨迹数据划分为多个轨迹模式;
S3:基于划分的多个轨迹模式,在基线预测网络的历史编码特征上,根据离散聚类模块生成的伪标签进行原型对比学习,使轨迹编码器的特征空间根据不同轨迹模式分别进行聚类;
S4:基于轨迹编码器的聚类结果,利用分布感知的预测器,为不同的轨迹输入生成不同的解码器权重,将尾部轨迹模式和头部轨迹模式以不同的方式进行预测,实现长尾轨迹的预测。
2.根据权利要求1所述的基于感知分布式对比学习框架的长尾轨迹预测方法,其特征在于,所述S1中轨迹特征提取器为:在1D卷积网络的基础上增加LSTM网络。
3.根据权利要求1所述的基于感知分布式对比学习框架的长尾轨迹预测方法,其特征在于,所述S2中利用离散聚类模块对提取的特征进行聚类,包括以下分步骤:
S2-1:将历史轨迹数据按照时间划分为多个时间段,且每个时间段包含多个轨迹点;
S2-2:将每个时间段中的轨迹点按照空间位置进行聚类,得到多个轨迹簇;
S2-3:计算每个轨迹簇的质心,并作为当前轨迹簇的代表点;
S2-4:将所有的代表点作为输入,使用K-means算法进行聚类,得到多个聚类簇;
S2-5:将每个聚类簇的代表点作为聚类簇的原型,完成对历史轨迹数据轨迹模式的划分。
4.根据权利要求1所述的基于感知分布式对比学习框架的长尾轨迹预测方法,其特征在于,所述S3中原型对比学习采用多级聚类方法计算损失,公式为:
其中,为原型对比学习损失函数,/>为实例间的对比项,/>为实例-原型的对比项;
其中,为轨迹实例,/>为批次大小,/>为在批次中与轨迹实例/>为正样本的数量,/>为正样本,/>为轨迹实例在编码器后的特征嵌入,/>为正样本在编码器后的特征嵌入,/>为实例间对比项的对比温度,/>为当前批次数据中的任意样本在编码器后的特征嵌入,/>为K-means聚类层次的数量,/>为第/>层簇的数量,/>为当前批次数据中的任意样本,/>为实例所属的簇的原型,/>为任意簇的原型,/>为实例所属的簇的密度,/>为任意簇的密度,/>为簇中的实例数量,/>为簇中的实例,/>为/>的动量更新特征,/>为簇的原型,/>为平滑因子,/>为范数,/>为以/>为底的指数函数。
5.根据权利要求4所述的基于感知分布式对比学习框架的长尾轨迹预测方法,其特征在于,所述S4中分布感知的预测器中将LSTM网络作为轨迹解码器,所述LSTM网络的公式为:
其中,为LSTM网络的输入门,/>为LSTM网络的更新门,/>为LSTM网络的遗忘门,/>为LSTM网络的输出门,/>为时间步/>的隐藏状态向量,/>为时间步/>的输入向量,/>为输入门隐藏状态权重矩阵,/>为输入门输入权重矩阵, />为输入门偏置向量,/>为更新门隐藏状态权重矩阵,/>为更新门输入权重矩阵,/>为更新门偏置向量,/>为遗忘门隐藏状态权重矩阵,/>为遗忘门输入权重矩阵,/>为遗忘门偏置向量,/>为输出门隐藏状态权重矩阵,/>为输出门输入权重矩阵,/>为输出门偏置向量,/>为时间步/>的单元状态,/>为时间步/>的隐藏状态向量,/>为sigmoid操作符,/>为元素级乘积运算,/>为时间步的单元状态,/>为双曲正切函数;
则带有超网络的LSTM网络公式为:
其中,为原始LSTM网络中的任意一个门,/>为层归一化,/>为隐藏状态权重调整向量,/>为任意一个门的隐藏状态权重矩阵,/>为输入权重调整向量,/>为任意一个门的输入权重矩阵,/>为任意一个门的偏置向量,/>为超网络输出向量;
对于具有的轨迹实例/>,使用/>表示为:
其中,为超网络的映射函数,/>为对于具有/>的轨迹实例/>的超网络输出向量;
利用带有超网络的LSTM网络,为不同的轨迹输入生成不同的解码器权重,将尾部轨迹模式和头部轨迹模式以不同的方式进行预测。
6.根据权利要求5所述的基于感知分布式对比学习框架的长尾轨迹预测方法,其特征在于,所述S4生成不同的解码器权重,公式为:
其中,为最终的损失函数,/>为基线预测网络损失函数,/>为原型对比学习损失项的系数,/>为常量超参数,/>为用于过滤头部样本的阈值,/>为预热阶段训练后的网络预测损失。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202311222987.7A CN116956098A (zh) | 2023-09-21 | 2023-09-21 | 一种基于感知分布式对比学习框架的长尾轨迹预测方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202311222987.7A CN116956098A (zh) | 2023-09-21 | 2023-09-21 | 一种基于感知分布式对比学习框架的长尾轨迹预测方法 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN116956098A true CN116956098A (zh) | 2023-10-27 |
Family
ID=88453316
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202311222987.7A Pending CN116956098A (zh) | 2023-09-21 | 2023-09-21 | 一种基于感知分布式对比学习框架的长尾轨迹预测方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN116956098A (zh) |
Citations (12)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN109213524A (zh) * | 2017-06-29 | 2019-01-15 | 英特尔公司 | 用于难预测分支的预测器 |
CN111488984A (zh) * | 2020-04-03 | 2020-08-04 | 中国科学院计算技术研究所 | 一种用于训练轨迹预测模型的方法和轨迹预测方法 |
CN113158861A (zh) * | 2021-04-12 | 2021-07-23 | 杭州电子科技大学 | 一种基于原型对比学习的运动分析方法 |
WO2022115766A1 (en) * | 2020-11-30 | 2022-06-02 | Saudi Arabian Oil Company | Deep learning-based localization of uavs with respect to nearby pipes |
CN115049991A (zh) * | 2022-06-15 | 2022-09-13 | 上海钧正网络科技有限公司 | 共享设备位姿整齐度判断方法、装置、终端及介质 |
US20220396289A1 (en) * | 2021-06-15 | 2022-12-15 | Nvidia Corporation | Neural network path planning |
CN115829171A (zh) * | 2023-02-24 | 2023-03-21 | 山东科技大学 | 一种联合时空信息和社交互动特征的行人轨迹预测方法 |
US20230131815A1 (en) * | 2020-05-29 | 2023-04-27 | Imra Europe S.A.S. | Computer-implemented method for predicting multiple future trajectories of moving objects |
CN116092055A (zh) * | 2023-01-30 | 2023-05-09 | 北京百度网讯科技有限公司 | 训练方法、获取方法、装置、设备及自动驾驶车辆 |
CN116186358A (zh) * | 2023-02-07 | 2023-05-30 | 和智信(山东)大数据科技有限公司 | 一种深度轨迹聚类方法、系统及存储介质 |
CN116258252A (zh) * | 2023-01-10 | 2023-06-13 | 浙江工业大学 | 基于注意力机制的飞机轨迹预测方法、储存介质及设备 |
CN116349211A (zh) * | 2020-09-14 | 2023-06-27 | 华为云计算技术有限公司 | 基于自注意力的深度学习的分布式轨迹异常检测 |
-
2023
- 2023-09-21 CN CN202311222987.7A patent/CN116956098A/zh active Pending
Patent Citations (12)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN109213524A (zh) * | 2017-06-29 | 2019-01-15 | 英特尔公司 | 用于难预测分支的预测器 |
CN111488984A (zh) * | 2020-04-03 | 2020-08-04 | 中国科学院计算技术研究所 | 一种用于训练轨迹预测模型的方法和轨迹预测方法 |
US20230131815A1 (en) * | 2020-05-29 | 2023-04-27 | Imra Europe S.A.S. | Computer-implemented method for predicting multiple future trajectories of moving objects |
CN116349211A (zh) * | 2020-09-14 | 2023-06-27 | 华为云计算技术有限公司 | 基于自注意力的深度学习的分布式轨迹异常检测 |
WO2022115766A1 (en) * | 2020-11-30 | 2022-06-02 | Saudi Arabian Oil Company | Deep learning-based localization of uavs with respect to nearby pipes |
CN113158861A (zh) * | 2021-04-12 | 2021-07-23 | 杭州电子科技大学 | 一种基于原型对比学习的运动分析方法 |
US20220396289A1 (en) * | 2021-06-15 | 2022-12-15 | Nvidia Corporation | Neural network path planning |
CN115049991A (zh) * | 2022-06-15 | 2022-09-13 | 上海钧正网络科技有限公司 | 共享设备位姿整齐度判断方法、装置、终端及介质 |
CN116258252A (zh) * | 2023-01-10 | 2023-06-13 | 浙江工业大学 | 基于注意力机制的飞机轨迹预测方法、储存介质及设备 |
CN116092055A (zh) * | 2023-01-30 | 2023-05-09 | 北京百度网讯科技有限公司 | 训练方法、获取方法、装置、设备及自动驾驶车辆 |
CN116186358A (zh) * | 2023-02-07 | 2023-05-30 | 和智信(山东)大数据科技有限公司 | 一种深度轨迹聚类方法、系统及存储介质 |
CN115829171A (zh) * | 2023-02-24 | 2023-03-21 | 山东科技大学 | 一种联合时空信息和社交互动特征的行人轨迹预测方法 |
Non-Patent Citations (2)
Title |
---|
YUNING WANG等: "FEND: A Future Enhanced Distribution-Aware Contrastive Learning Framework for Long-tail Trajectory Prediction", 《ARXIV:2303.16574V1》, pages 2 - 4 * |
张天予等: "视频中的未来动作预测研究综述", 《计算机学报》, vol. 46, no. 06, pages 1315 - 1338 * |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
Wu et al. | Progressive tandem learning for pattern recognition with deep spiking neural networks | |
US10380236B1 (en) | Machine learning system for annotating unstructured text | |
Wiseman et al. | Sequence-to-sequence learning as beam-search optimization | |
CN110929092B (zh) | 一种基于动态注意力机制的多事件视频描述方法 | |
CN113642225A (zh) | 一种基于attention机制的CNN-LSTM短期风电功率预测方法 | |
Zhang et al. | Unsupervised representation learning from pre-trained diffusion probabilistic models | |
CN110033089B (zh) | 基于分布式估计算法的手写体数字图像识别深度神经网络参数优化方法及系统 | |
CN107506865A (zh) | 一种基于lssvm优化的负荷预测方法及系统 | |
CN116187555A (zh) | 基于自适应动态图的交通流预测模型构建方法及预测方法 | |
CN111709754A (zh) | 一种用户行为特征提取方法、装置、设备及系统 | |
Shen et al. | Reservoir transformers | |
CN110570035A (zh) | 同时建模时空依赖性和每日流量相关性的人流量预测系统 | |
CN113033861A (zh) | 一种基于时间序列模型的水质预测方法及系统 | |
CN115578680B (zh) | 一种视频理解方法 | |
Bolelli et al. | A hierarchical quasi-recurrent approach to video captioning | |
Wang et al. | Quantformer: Learning extremely low-precision vision transformers | |
Liu et al. | Spiking-diffusion: Vector quantized discrete diffusion model with spiking neural networks | |
Xu et al. | Generative data free model quantization with knowledge matching for classification | |
CN116956098A (zh) | 一种基于感知分布式对比学习框架的长尾轨迹预测方法 | |
Mondal et al. | SSDMM-VAE: variational multi-modal disentangled representation learning | |
CN114154582A (zh) | 基于环境动态分解模型的深度强化学习方法 | |
CN111539989B (zh) | 基于优化方差下降的计算机视觉单目标跟踪方法 | |
CN117980915A (zh) | 用于端到端自监督预训练的对比学习和掩蔽建模 | |
CN113657405A (zh) | 平滑模型训练方法、装置及电子设备 | |
Potapov et al. | Differences between Kolmogorov complexity and Solomonoff probability: consequences for AGI |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |