CN116956045A - 一种用于行为预测域自适应的伪标签改进的轨迹预测方法 - Google Patents
一种用于行为预测域自适应的伪标签改进的轨迹预测方法 Download PDFInfo
- Publication number
- CN116956045A CN116956045A CN202310977528.3A CN202310977528A CN116956045A CN 116956045 A CN116956045 A CN 116956045A CN 202310977528 A CN202310977528 A CN 202310977528A CN 116956045 A CN116956045 A CN 116956045A
- Authority
- CN
- China
- Prior art keywords
- traffic parameter
- pseudo tag
- enhancement
- domain
- representing
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Withdrawn
Links
- 238000000034 method Methods 0.000 title claims abstract description 38
- 238000012545 processing Methods 0.000 claims abstract description 45
- 238000012549 training Methods 0.000 claims abstract description 18
- 238000013528 artificial neural network Methods 0.000 claims abstract description 13
- 230000004927 fusion Effects 0.000 claims description 62
- 230000006399 behavior Effects 0.000 claims description 50
- 230000006870 function Effects 0.000 claims description 14
- 230000006978 adaptation Effects 0.000 claims description 12
- 230000000873 masking effect Effects 0.000 claims description 9
- 230000007717 exclusion Effects 0.000 claims description 6
- 238000005457 optimization Methods 0.000 claims description 6
- 230000008569 process Effects 0.000 claims description 5
- 238000004364 calculation method Methods 0.000 claims description 4
- 238000010586 diagram Methods 0.000 claims description 4
- 238000012163 sequencing technique Methods 0.000 claims description 3
- 230000003542 behavioural effect Effects 0.000 claims 9
- 238000005070 sampling Methods 0.000 description 5
- 238000012217 deletion Methods 0.000 description 2
- 230000037430 deletion Effects 0.000 description 2
- 238000007781 pre-processing Methods 0.000 description 2
- 230000009471 action Effects 0.000 description 1
- 230000003044 adaptive effect Effects 0.000 description 1
- 230000009286 beneficial effect Effects 0.000 description 1
- 238000012512 characterization method Methods 0.000 description 1
- 230000001010 compromised effect Effects 0.000 description 1
- 238000001514 detection method Methods 0.000 description 1
- 238000002474 experimental method Methods 0.000 description 1
- 230000006872 improvement Effects 0.000 description 1
- 230000002452 interceptive effect Effects 0.000 description 1
- 238000002372 labelling Methods 0.000 description 1
- 238000003672 processing method Methods 0.000 description 1
- 230000009467 reduction Effects 0.000 description 1
- 230000011218 segmentation Effects 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
- G06F18/2155—Generating training patterns; Bootstrap methods, e.g. bagging or boosting characterised by the incorporation of unlabelled data, e.g. multiple instance learning [MIL], semi-supervised techniques using expectation-maximisation [EM] or naïve labelling
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/217—Validation; Performance evaluation; Active pattern learning techniques
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/25—Fusion techniques
- G06F18/253—Fusion techniques of extracted features
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/0895—Weakly supervised learning, e.g. semi-supervised or self-supervised learning
-
- G—PHYSICS
- G08—SIGNALLING
- G08G—TRAFFIC CONTROL SYSTEMS
- G08G1/00—Traffic control systems for road vehicles
- G08G1/01—Detecting movement of traffic to be counted or controlled
- G08G1/0104—Measuring and analyzing of parameters relative to traffic conditions
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02T—CLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
- Y02T10/00—Road transport of goods or passengers
- Y02T10/10—Internal combustion engine [ICE] based vehicles
- Y02T10/40—Engine management systems
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Data Mining & Analysis (AREA)
- Physics & Mathematics (AREA)
- General Physics & Mathematics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Artificial Intelligence (AREA)
- General Engineering & Computer Science (AREA)
- Evolutionary Computation (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Evolutionary Biology (AREA)
- Bioinformatics & Computational Biology (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Computational Linguistics (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Health & Medical Sciences (AREA)
- General Health & Medical Sciences (AREA)
- Molecular Biology (AREA)
- Computing Systems (AREA)
- Mathematical Physics (AREA)
- Software Systems (AREA)
- Chemical & Material Sciences (AREA)
- Analytical Chemistry (AREA)
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
Abstract
本发明公开了一种用于行为预测域自适应的伪标签改进的轨迹预测方法。将源域输入到预测器中进行训练,获得训练后的预测器;然后将目标域的数据输入到由训练后的预测器优化获得的第一神经网络中进行再次训练,获得训练后的第一神经网络;最后从训练后的第一神经网络提取组建预测网络,对实时待测的数据实进行识别预测处理。本发明能够提升伪标签的质量以提升模型网络的泛化性能,能够确保目标域数据的多样性以及防止模型陷入局部最优,提升了伪标签的预测性能,进而提升轨迹预测能力。
Description
技术领域
本发明涉及了自动驾驶领域中的一种行为标签处理方法,尤其是涉及了一种用于行为预测域自适应的伪标签改进的轨迹行为预测方法。
背景技术
行为预测是自动驾驶系统重要的组成部分,其主要利用道路地图信息和交通交通参数者过去的行为状态来预测其未来的行为轨迹,以此为自动驾驶汽车采取安全可靠的行动作保障。依靠大规模手工标注的数据,尤其是交通参数者过去的行为数据,行为预测实现了显著的突破。然而,因为不可避免的域差异,例如地理位置不同、照明条件差异及天气干扰等,导致在源域训练好的模型直接应用在目标域时,模型预测性能出现很大地下降,而重新收集并标注各种各样的目标域又会消耗大量的人力。因此,行为预测域自适应用于降低成本,减小域漂移,进一步提升跨域行为预测性能。
近年来,用于自适应的伪标签方法在很多任务上(例如,分类、识别、检测和分割等)都取得了显著的进步。这些方法大多针对特定的任务,生成相应的伪标签,然后再利用其监督目标域。经过多次迭代,使模型不断收敛。然而,当直接迁移这些伪标签方法到时序的和交互的行为数据时,因为内部数据结构和网络架构的不同,它们的预测性能可能会进一步下降。
因此,区别于以往的任务,如何有效地生成行为预测领域高质量的伪标签(即伪轨迹)依然面临着巨大挑战。
发明内容
为了解决背景技术中存在的问题,解决预测性能下降等方面的问题,本发明提出了一种用于行为预测域自适应的伪标签改进的轨迹行为预测方法,是一种专门应用于行为预测的域自适应的伪标签方法,进一步提升了模型的预测性能和泛化能力,提升伪标签的质量即提升行为预测域自适应性能。
本发明的技术方案:
1)将源域输入到预测器P中进行训练,获得训练后的预测器P;
2)然后将目标域的数据输入到由训练后的预测器P优化获得的第一神经网络中进行再次训练,获得训练后的第一神经网络;
3)最后从训练后的第一神经网络提取组建预测网络,对实时待测的数据实进行识别预测处理。
所述的步骤2)中,第一神经网络是在预测器P的拓扑结构基础上增设了额外的伪标签更新模块和轨迹主导的对比学习模块;
将目标域输入到训练后的预测器P中进行预测处理获得获得初步伪标签和特征并输入到伪标签更新模块中,同时将特征输入到轨迹主导的对比学习模块中并结合根据源域和目标域的数据处理获得对比损失并返回到伪标签更新模块,伪标签更新模块中根据初步伪标签进行处理获得优化伪标签,根据优化伪标签对损失进行计算获得优化总损失,以优化总损失最小为目标对训练后的预测器P、伪标签更新模块和轨迹主导的对比学习模块进行再次训练,直到损失收敛为止。
所述步骤2)中,针对目标域的预测器P的处理具体为:
S1、针对每个样本,在其他交通参数者和目标交通参数者之间的关系以及车道线点和目标交通参数者之间的关系下分别建立一个加权交通参数者图Ga和一个加权车道线点图Gm,利用加权交通参数者图Ga和一个加权车道线点图Gm进行掩码处理获得掩码处理后的样本数据;
S2、在掩码/掩盖删除处理后,根据样本保留下的交通参数者和车道线点的历史行为状态和地图信息按照以下方式进行特征增强处理,获取第i个样本的四个增强特征:
其中,Ai、Mi分别表示第i个样本掩码处理后的历史行为状态和地图信息;和分别对应表示弱增强、强增强,/>分别表示弱增强后的交通参数者增强特征和强增强后的交通参数者增强特征,/>分别表示弱增强后的地图增强特征和强增强后的地图增强特征;φa、φm分别是预测器P中的交通参数者、地图的特征编码器;
再根据两个交通参数者增强特征和两个地图增强特征/>经过采取以下方式处理获取两个融合特征:
其中,分别表示弱增强后的融合增强特征和强增强后的融合增强特征;φf是预测器P中的融合特征编码器;
由上述四个增强特征和两个融合特征共同组成了预测器P输出的特征。
所述S1具体为:
A)样本在其他交通参数者和目标交通参数者之间的关系下,建立一个加权交通参数者图Ga,以其他交通参数者和目标交通参数者的轨迹坐标均作为节点,节点表示最后观测到的轨迹坐标,交通参数者为中心节点,其他交通参数者的节点均连接到目标交通参数者的节点,其他交通参数者的节点均连接到和目标交通参数者的节点之间通过边连接,每条边赋予边权值;
目标交通参数者的节点Θ与其他交通参数者的节点j之间的边权值wΘj设置为:
其中,ρa是缩放系数,x(Θ,0)、y(Θ,0)表示目标交通参数者的初始轨迹坐标,x(j,0)、y(j,0)表示第j个其他交通参数者的初始轨迹坐标;
针对每个其他交通参数者的节点j,分别设置以下节点的掩码概率以进行弱增强和强增强:
其中,wΘk表示目标交通参数者的节点Θ与其他交通参数者的节点k之间的边权值,k表示其他交通参数者的序数,k∈Va/Θ,Va是所有交通参数者的节点的集合,Va/Θ表示Va排除Θ,/表示排除;
所有其他交通参数者j的节点的掩码概率相加为1。
在对其他交通参数者进行弱增强时,按照上述弱增强对应的掩码概率掩码百分之ma的其他交通参数者的历史行为状态,从而获得第i个样本掩码处理后的历史行为状态。
B)样本在车道线点和目标交通参数者之间的关系下,建立一个加权车道线点图Gm,以其他交通参数者的轨迹坐标和车道线点均作为节点,目标交通参数者为中心节点,车道线点均连接到目标交通参数者的节点,车道线点的节点均连接到和目标交通参数者的节点之间通过边连接,每条边赋予边权值;
目标交通参数者的节点Θ和车道线点的节点之间的边权值wΘj是与目标交通参数者的节点Θ和其他交通参数者的节点j之间的边权值wΘj设置相同。
针对每个车道线点的节点,也按照和其他交通参数者和目标交通参数者之间的关系下的相同计算方式设置节点的掩码概率以进行弱增强和强增强;
所有车道线点的节点的掩码概率相加为1。
在对车道线点进行弱增强时,按照上述弱增强对应的掩码概率掩码百分之mm的车道线点的地图信息,从而获得第i个样本掩码处理后的地图信息。
所述的掩码是指掩盖删除,即将掩码的数据进行删除。
所述步骤1)中,针对源域的预测器P的处理具体为:
将第i个样本的历史行为状态和地图信息分别输入到交通参数者特征编码器、地图特征编码器中获得交通参数者特征、地图特征,同时将样本的历史行为状态和地图信息共同输入到融合特征编码器中获得融合特征,具体表示为:
其中,φa、φm和φf是交通参数者、地图、融合的特征编码器;是预测解码器,和/>分别表示交通参数者特征、地图特征和融合特征,Ai、Mi分别表示第i个样本的历史行为状态和地图信息;
然后将融合特征输入到预测解码器中获得第i个样本的目标交通参数者的预测轨迹坐标及其置信分数Ci:
其中,表示第i个样本的目标交通参数者的预测轨迹坐标,Ci表示第i个样本的预测轨迹坐标/>的置信分数。
在伪标签更新模块中,在训练中的每次迭代时按照以下方式处理:
针对当前迭代轮次,计算第k迭代轮次下获得的初步伪标签和前k-1迭代轮次获得的初步伪标签集合/>之间的一致性进而获取第k迭代轮次下获得的初步伪标签的最佳匹配轨迹,最佳匹配轨迹的索引计算如下:
其中,h表示来自集合中的伪标签,/> 表示最佳匹配轨迹的索引,/>表示余弦相似度函数,/>表示最佳匹配轨迹的索引;
然后根据最佳匹配轨迹的索引进行以下判断:
如果满足或者/>则将第1迭代轮次下获得的初步伪标签/>设置为/>其中Tc是预设的置信分数阈值,/>分别表示初步伪标签/>的置信分数和/>的置信分数,此时伪标签/>不更新;此时还未产生伪标签/>则在之后某次迭代过程中满足后面的条件才产生伪标签/>
如果不满足或者/>则进一步进行以下判断:
若满足则设置第k迭代轮次下更新后的伪标签/>并更新前k-1迭代轮次获得的初步伪标签集合/>进行/>
若不满足则设置第k迭代轮次下更新后的伪标签/>并更新前k-1迭代轮次获得的初步伪标签集合/>进行/>
从而获得第k迭代轮次下更新后的伪标签。
所述伪标签更新模块根据更新后的伪标签按照以下公式计算目标域损失/>作为优化伪标签:
其中,Γ(·)是一个截断函数,ρt是超参数,是P中的损失函数,/>表示当前第k迭代轮次下获得的目标交通参数者/>的初步伪标签,/>表示最佳匹配轨迹的索引,exp()表示指数函数,/>表示在第k个训练轮次中所预测的轨迹。
在所述对比学习模块中,设置两个域Δ和且Δ,/>即表示两个域Δ和▽的每个域为源域或者目标域,然后:
针对第一域Δ中的每个样本的融合特征,是指原始样本的在第二域▽的所有样本的融合增强特征中根据轨迹一致性挑选和第一域Δ的样本的融合增强特征不接近融合增强特征并计算特征作为阳性特征/>在第二域▽的所有融合增强特征中选择和第一域Δ的样本的融合增强特征接近的融合增强特征并计算特征作为阴性特征/>
然后进行以下判断:
如果阳性特征和阴性特征/>其中之一不存在,则对比损失设置为零;
如果阳性特征和阴性特征/>均存在,则设置轨迹主导的对比损失为:
其中,sim()表示余弦相似度,ρc表示温度系数,同时实施域内和域间对比学习;表示第一域Δ中第i个样本的融合特征,/>表示第二域▽中第i个样本的融合特征;
进而遍历两个域Δ和▽相同和不相同以及先后顺序的四种情况,分别获得两个域内对比损失和/>以及两个域间对比损失/>和/>
根据两个域内对比损失和/>以及两个域间对比损失/>和/>进行相加获得总的对比损失/>表示为:
所述步骤2)进行再次训练时,由预测器P输出源域预测损失和目标域预测损失/>由对比学习模块输出对比损失/>结合上述损失采用以下公式进行损失的计算,表达为:
其中,表示优化总损失,η表示损失权重,η∈(0,1)是一个超参数。
所述步骤3)中,具体是最后从训练后的第一神经网络提取出预测器P组成优化后的用于生成高质量伪标签的预测网络,将待测场景的数据实时输入到预测网络中进行预测处理获得最终结果。
首先,如图1(a)所示,本发明在预处理阶段,基于加权随机源码设计了应用于地图中车道节点和交通参数者节点的强增强和弱增强,生成目标交通参数者高质量的伪标签。
具体来说,通过自监督学习,方法分别从交通参数者、地图和融合三个角度来重建特征。同时,在源域的增强后的样本和原样本都被真实标签监督。
其次,在训练过程中,如图1(b)所示,针对于目标域中伪标签的改进,通过评估以往轮次与现在轮次的伪标签一致程度,方法设计了一种伪标签更新增强方式。
最后,如图1(c)所示,为了减轻跨域间交通参数者的表征漂移,利用生成的伪标签,方法提出了轨迹主导的对比学习,将相似的轨迹表征拉近,不同的轨迹表征拉远。
整体方法在推断过程中不引入任何额外的计算复杂度。
本发明的有益效果:
本发明能够提升伪标签的质量以提升模型的泛化性能,提升了行为预测域自适应模型的预测性能和泛化能力,提高了轨迹预测的速度和准确性。
附图说明
图1是本发明方法的整体逻辑图;
图2是车道线点增强和交通参数者增强示例图。
具体实施方式
下面结合附图和具体实施对本发明作进一步说明。
如图1所示,本发明的实施例如下:
给定一个已有的行为预测器P、有标注的源域和无标注的目标域/>
其中表示源域,包含了源域的ns个有标注样本,每个有标注样本包含了历史行为状态、车道线点和标签,/>
表示源域第i个有标注样本中的历史行为状态,历史行为状态包括所有交通参数者的轨迹坐标(即位置信息)和速度,所有交通参数者包含了目标交通参数者和位于目标交通参数者周围且处于同一场景中的其他交通参数者,/>其中Na表示源域第i个有标注样本中包含的交通参数者的总数,是一个路口场景下所有的交通参数者的总数/>表示历史行为状态的步长,即轨迹坐标(即位置信息)和速度采样的间隔时间维度,Ca表示每个状态的维度(例如轨迹坐标和速度等)。
表示源域第i个有标注样本中的地图信息,包括所有地图上的车道线点,其中Nm表示源域第i个有标注样本的地图数量,Cm是车道线点的特征维度(例如坐标和类型等),所述车道线点即为车道线上的采样点,是沿车道线等间隔采样获得的点;
为源域第i个有标注样本中的标签,其中,/>分别表示源域的第i个目标交通参数者在第k次采样时的轨迹坐标,K表示采样的总数;
同样地,表示目标域,包含了目标域的nt个无标注样本,每个无标注样本包含了车道线点和标签,/> 和/>表示在目标域第i个无标注样本中的历史行为状态和地图信息,历史行为状态包括所有交通参数者的轨迹坐标(即位置信息)和速度,所有交通参数者包含了目标交通参数者和位于目标交通参数者周围且处于同一场景中的其他交通参数者,地图信息包括了所有地图上的车道线点。
1)将源域输入到预测器P中进行训练,获得训练后的预测器P;
在源域输入到预测器P中进行训练时,由预测器P输出源域预测损失利用源域预测损失/>进行训练优化,具体损失函数表达为:
其中,表示源域训练时的总损失。
2)然后将目标域输入到训练后的预测器P中进行预测处理获得获得初步伪标签和特征并输入到伪标签更新模块中,同时将特征输入到轨迹主导的对比学习模块中并结合根据源域和目标域的数据处理获得三元组损失并返回到伪标签更新模块,伪标签更新模块中根据初步伪标签进行处理获得优化伪标签,根据优化伪标签对损失进行计算获得优化总损失,以优化总损失最小为目标对训练后的预测器P、伪标签更新模块和轨迹主导的对比学习模块进行再次训练,直到损失收敛为止。
在预处理阶段,给定一个来自源域或者目标域/>的样本i,所述的预测器P包括预测解码器和多种特征编码器,预测器P的流程具体为:
S1、针对每个样本,在其他交通参数者和目标交通参数者之间的关系以及车道线点和目标交通参数者之间的关系下分别建立一个加权交通参数者图Ga和一个加权车道线点图Gm,利用加权交通参数者图Ga和一个加权车道线点图Gm进行掩码处理获得掩码处理后的样本数据。
具体实施利用加权随机掩码来获取弱增强和强增强。
A)样本在其他交通参数者和目标交通参数者之间的关系下,建立一个加权交通参数者图Ga,以其他交通参数者和目标交通参数者的轨迹坐标均作为节点,目标交通参数者为中心节点,其他交通参数者的节点均连接到目标交通参数者的节点,其他交通参数者的节点均连接到和目标交通参数者的节点之间通过边连接,每条边赋予边权值;
具体地,Ga=(Va,Ea,Wa),其中Va是一个节点集合(即所有交通参数者),Ea是一个目标交通参数者的节点Θ连接到其他交通参数者的节点的边集合。
目标交通参数者的节点Θ与其他交通参数者的节点j(Θ≠j)之间的边权值wΘj设置为:
其中,ρa是缩放系数,x(Θ,0)、y(Θ,0)表示目标交通参数者的初始轨迹坐标,x(j,0)、y(j,0)表示第j个其他交通参数者的初始轨迹坐标;
针对每个其他交通参数者的节点j,分别设置以下节点的掩码概率以进行弱增强和强增强,以掩码概率对每个其他交通参数者的节点j的轨迹坐标进行弱增强和强增强:
其中,wΘk表示目标交通参数者的节点Θ与其他交通参数者的节点k之间的边权值,k表示其他交通参数者的序数,k∈Va/Θ,Va是所有交通参数者的节点的集合,Va/Θ表示Va排除Θ,/表示排除;
在对其他交通参数者进行弱增强时,按照上述弱增强对应的掩码概率掩码百分之ma的其他交通参数者的历史行为状态,从而获得第i个样本掩码处理后的历史行为状态。
B)样本在车道线点和目标交通参数者之间的关系下,建立一个加权车道线点图Gm,以其他交通参数者的轨迹坐标和车道线点均作为节点,目标交通参数者为中心节点,车道线点均连接到目标交通参数者的节点,车道线点的节点均连接到和目标交通参数者的节点之间通过边连接,每条边赋予边权值;
具体地,Gm=(Vm,Em,Wm),其中Vm是一个节点集合(即所有车道线点∪目标交通参数者的节点Θ),Ea是一个目标交通参数者的节点Θ连接到车道线点的边集合。
目标交通参数者的节点Θ和车道线点的节点之间的边权值wΘj是与目标交通参数者的节点Θ和其他交通参数者的节点j(Θ≠j)之间的边权值wΘj设置相同。
针对每个车道线点的节点,也按照和其他交通参数者和目标交通参数者之间的关系下的相同计算方式设置节点的掩码概率以进行弱增强和强增强,以掩码概率对每个车道线点的节点的轨迹坐标进行弱增强和强增强;
在对车道线点进行弱增强时,按照上述弱增强对应的掩码概率掩码百分之mm的车道线点的地图信息,从而获得第i个样本掩码处理后的地图信息。
S2、在掩码/掩盖删除后,根据样本保留下的交通参数者和车道线点的历史行为状态和地图信息按照以下方式进行特征增强处理,获取第i个样本的四个增强特征:
其中,Ai、Mi分别表示第i个样本掩码处理后的历史行为状态和地图信息;和分别对应表示弱增强、强增强,/>分别表示弱增强后的交通参数者增强特征和强增强后的交通参数者增强特征,/>分别表示弱增强后的地图增强特征和强增强后的地图增强特征;φa、φm分别是预测器P中的交通参数者、地图的特征编码器;
再根据两个交通参数者增强特征和两个地图增强特征/>经过采取以下方式处理获取两个融合特征:
其中,分别表示弱增强后的融合增强特征和强增强后的融合增强特征;φf是预测器P中的融合特征编码器;
最终获得了上述六种特征,由上述四个增强特征和两个融合特征共同组成了预测器P输出的特征。
具体实施中,弱增强是指对距离目标交通参与者较远的部分节点进行掩码操作,强增强是指对距离目标交通参与者较近的部分节点进行掩码操作。
如图2所示,对(a)中较近的代理节点进行掩码,此时为强代理增强;对(a)中较远的车道线节点进行掩码,此时为弱车道线点代理增强;对(b)中较远的代理节点进行掩码,此时为弱代理增强;对(b)中较近的车道线节点进行掩码,此时为强车道线点代理增强。
具体结果如图2所示,图2中,最浅灰色点表示目标交通参数者,较中灰色点表示相关交通参数者,最深灰色点表示车道线点。
在目标域的每个样本经过预测器P的处理中,还要获得未经增强的交通参数者特征、地图特征,并且通过交通参数者特征、地图特征获得融合特征,再通过融合特征获得每个样本的置信分数Ci。
在源域中,预测器P输出源域预测损失被它的损失函数/>所优化。
在目标域中,预测器P输出目标域预测损失被的损失函数/>所优化。
在目标域的训练中,用均方差损失重建所有增强的交通参数者、车道线点及融合特征的损失,即将上述增强特征和融合特征分别与源域训练时获得的、未经增强的交通参数者特征、地图特征和融合特征作对应的损失计算。
源域训练时获得的交通参数者特征、地图特征和融合特征,以及目标域经预测器P处理中获得的未经增强的交通参数者特征、地图特征、融合特征以及置信分数Ci均按照以下方式获得:
将第i个样本的历史行为状态和地图信息分别输入到交通参数者特征编码器、地图特征编码器中获得交通参数者特征、地图特征,同时将样本的历史行为状态和地图信息共同输入到融合特征编码器中获得融合特征,具体表示为:
其中,φa、φm和φf是交通参数者、地图、融合的特征编码器;是预测解码器,和/>分别表示交通参数者特征、地图特征和融合特征,Ai、Mi分别表示第i个样本的历史行为状态和地图信息;
然后将融合特征输入到预测解码器中获得第i个样本的目标交通参数者的预测轨迹坐标及其置信分数Ci:
其中,表示第i个样本的目标交通参数者的预测轨迹坐标,Ci表示第i个样本的预测轨迹坐标/>的置信分数。
在伪标签更新模块中,按照以下方式处理:
具体来说,给定目标域中样本i,表示当前第k迭代轮次下获得的目标交通参数者/>的初步伪标签,i表示第一个样本的,t表示目标区域目标交通参数者域,/>是前k-1迭代轮次获得的目标交通参数者/>的初步伪标签集合。
初始情况下设置初始迭代轮次获得的初步伪标签集合等于第1迭代轮次下获得的初步伪标签/>其中/>来自于预训练模型,使用源域数据训练得到的模型;
针对当前迭代轮次,计算第k迭代轮次下获得的初步伪标签和前k-1迭代轮次获得的初步伪标签集合/>之间的一致性进而获取第k迭代轮次下获得的初步伪标签的最佳匹配轨迹,最佳匹配轨迹的索引计算如下:
其中,h表示来自集合中的伪标签,/>表示最佳匹配轨迹的索引,/>表示余弦相似度函数,/>表示最佳匹配轨迹的索引;
然后根据最佳匹配轨迹的索引进行以下判断:
如果满足或者/>则将第1迭代轮次下获得的初步伪标签/>设置为/>其中Tc是预设的置信分数阈值,/>分别表示初步伪标签/>的置信分数和索引/>的置信分数;伪标签/>不更新;此时还未产生伪标签/>则在之后某次迭代过程中满足后面的条件才产生伪标签/>
如果不满足或者/>则进一步进行以下判断:
若满足则设置第k迭代轮次下更新后的伪标签bk i并更新前k-1迭代轮次获得的初步伪标签集合/>进行/>
若不满足则设置第k迭代轮次下更新后的伪标签bk i并更新前k-1迭代轮次获得的初步伪标签集合/>进行/>
从而获得第k迭代轮次下更新后的伪标签。
最终将第k迭代轮次下更新后的伪标签被储存并在第k个轮次中监督样本i。
伪标签更新模块考虑到一致性水平,根据更新后的伪标签按照以下公式计算目标域损失/>作为优化伪标签:
其中,Γ(·)是一个截断函数,ρt是超参数,是P中的损失函数,/>表示当前第k迭代轮次下获得的目标交通参数者/>的初步伪标签,/>表示最佳匹配轨迹的索引,exp()表示指数函数,/>表示在第k个训练轮次中所预测的轨迹。
具体实施中,为了减轻跨域间交通参数者的表征漂移,利用生成的伪标签,本发明提出了轨迹主导的对比学习,利用对比学习模块将相似的轨迹表征拉近,不同的轨迹表征拉远。
在所述对比学习模块中,设置两个域Δ和且Δ,/>即表示两个域Δ和/>的每个域为源域或者目标域。
针对第一域Δ中的每个样本的融合特征,在第二域▽的所有样本的融合增强特征中根据轨迹一致性挑选和第一域Δ的样本的融合增强特征不接近融合增强特征并计算特征作为阳性特征在第二域▽的所有融合增强特征中选择和第一域Δ的样本的融合增强特征接近的融合增强特征并计算特征作为阴性特征/>
所述阳性特征的索引计算如下:
其中,表示当前批次在域/>中,与目标交通参与者的节点Θi同一类别(例如车辆、行人和自行车等)下的表征集合,/>其中n来自/>且/>i是域Δ中的元素,/>表示元素i的轨迹特征,/> 表示元素j的轨迹特征,/> 表示元素n的轨迹特征。
然后进行以下判断:
如果阳性特征和阴性特征/>其中之一不存在,则对比损失设置为零,不计算对比损失;
如果阳性特征和阴性特征/>均存在,则设置轨迹主导的对比损失为:
其中,sim()表示余弦相似度,ρc表示温度系数,同时实施域内和域间对比学习;表示第一域Δ中第i个样本的融合特征,/>表示第二域/>中第i个样本的融合特征;
进而遍历两个域Δ和相同和不相同以及先后顺序的四种情况,分别获得两个域内对比损失/>和/>以及两个域间对比损失/>和/>
具体来说,当 表示域内对比损失,记为/>和/>当/> 表示域间对比损失,记为/>和/>
设置域间小于域内/>以加速跨域特征收敛,根据两个域内对比损失和/>以及两个域间对比损失/>和/>进行相加获得总的对比损失/>表示为:
步骤2进行再次训练时,由预测器P输出源域预测损失和目标域预测损失由对比学习模块输出对比损失/>结合上述损失采用以下公式进行损失的计算,表达为:
其中,表示优化总损失,η表示损失权重,η∈(0,1)是一个超参数。
在整个方法中,输入一个已有的行为预测器P、有标注的源域和无标注的目标域经上述处理得到一个适用于行为预测的生成高质量伪标签的模型/网络。
3)最后以再次训练后的预测器P组成优化后的用于生成高质量伪标签的预测网络,将目标域输入到预测网络中进行预测处理获得最终结果。
具体实施中,本发明经大量的实验表明,方法能够有效地生成针对行为预测域自适应的高质量伪标签,进一步提升模型泛化,预测结果准确率更高。
Claims (10)
1.一种用于行为预测域自适应的伪标签改进的轨迹预测方法,其特征在于:方法步骤如下:
1)将源域输入到预测器P中进行训练,获得训练后的预测器P;
2)然后将目标域的数据输入到由训练后的预测器P优化获得的第一神经网络中进行再次训练,获得训练后的第一神经网络;
3)最后从训练后的第一神经网络提取组建预测网络,对实时待测的数据实进行识别预测处理。
2.根据权利要求1所述的一种用于行为预测域自适应的伪标签改进的轨迹预测方法,其特征在于:所述的步骤2)中,第一神经网络是在预测器P的拓扑结构基础上增设了额外的伪标签更新模块和轨迹主导的对比学习模块;将目标域输入到训练后的预测器P中进行预测处理获得获得初步伪标签和特征并输入到伪标签更新模块中,同时将特征输入到轨迹主导的对比学习模块中处理获得对比损失并返回到伪标签更新模块,伪标签更新模块中根据初步伪标签进行处理获得优化伪标签,根据优化伪标签对损失进行计算获得优化总损失,以优化总损失最小为目标对训练后的预测器P、伪标签更新模块和轨迹主导的对比学习模块进行再次训练。
3.根据权利要求2所述的一种用于行为预测域自适应的伪标签改进的轨迹预测方法,其特征在于:所述步骤2)中,预测器P的处理具体为:
S1、针对每个样本,在其他交通参数者和目标交通参数者之间的关系以及车道线点和目标交通参数者之间的关系下分别建立一个加权交通参数者图Ga和一个加权车道线点图Gm,利用加权交通参数者图Ga和一个加权车道线点图Gm进行掩码处理获得掩码处理后的样本数据;
S2、在掩码处理后,根据样本保留下的交通参数者和车道线点的历史行为状态和地图信息按照以下方式进行特征增强处理,获取第i个样本的四个增强特征:
其中,Ai、Mi分别表示第i个样本掩码处理后的历史行为状态和地图信息;和/>分别对应表示弱增强、强增强,/>分别表示弱增强后的交通参数者增强特征和强增强后的交通参数者增强特征,/>分别表示弱增强后的地图增强特征和强增强后的地图增强特征;φa、φm分别是预测器P中的交通参数者、地图的特征编码器;
再根据两个交通参数者增强特征和两个地图增强特征/>经过采取以下方式处理获取两个融合特征:
其中,分别表示弱增强后的融合增强特征和强增强后的融合增强特征;φf是预测器P中的融合特征编码器;
由上述四个增强特征和两个融合特征共同组成了预测器P输出的特征。
4.根据权利要求3所述的一种用于行为预测域自适应的伪标签改进的轨迹预测方法,其特征在于:所述S1具体为:
A)样本在其他交通参数者和目标交通参数者之间的关系下,建立一个加权交通参数者图Ga,以其他交通参数者和目标交通参数者的轨迹坐标均作为节点,交通参数者为中心节点,其他交通参数者的节点均连接到目标交通参数者的节点,其他交通参数者的节点均连接到和目标交通参数者的节点之间通过边连接,每条边赋予边权值;
目标交通参数者的节点Θ与其他交通参数者的节点j之间的边权值wΘj设置为:
其中,ρa是缩放系数,x(Θ,0)、y(Θ,0)表示目标交通参数者的初始轨迹坐标,x(j,0)、y(j,0)表示第j个其他交通参数者的初始轨迹坐标;
针对每个其他交通参数者的节点j,分别设置以下节点的掩码概率以进行弱增强和强增强:
其中,wΘk表示目标交通参数者的节点Θ与其他交通参数者的节点k之间的边权值,k表示其他交通参数者的序数,k∈Va/Θ,Va是所有交通参数者的节点的集合,Va/Θ表示Va排除Θ,/表示排除;
在对其他交通参数者进行弱增强时,按照上述弱增强对应的掩码概率掩码百分之ma的其他交通参数者的历史行为状态,从而获得第i个样本掩码处理后的历史行为状态。
B)样本在车道线点和目标交通参数者之间的关系下,建立一个加权车道线点图Gm,以其他交通参数者的轨迹坐标和车道线点均作为节点,目标交通参数者为中心节点,车道线点均连接到目标交通参数者的节点,车道线点的节点均连接到和目标交通参数者的节点之间通过边连接,每条边赋予边权值;
目标交通参数者的节点Θ和车道线点的节点之间的边权值wΘj是与目标交通参数者的节点Θ和其他交通参数者的节点j之间的边权值wΘj设置相同。
针对每个车道线点的节点,也按照和其他交通参数者和目标交通参数者之间的关系下的相同计算方式设置节点的掩码概率以进行弱增强和强增强;
在对车道线点进行弱增强时,按照上述弱增强对应的掩码概率掩码百分之mm的车道线点的地图信息,从而获得第i个样本掩码处理后的地图信息。
5.根据权利要求1所述的一种用于行为预测域自适应的伪标签改进的轨迹预测方法,其特征在于:所述步骤1)中,针对源域的预测器P的处理具体为:
将第i个样本的历史行为状态和地图信息分别输入到交通参数者特征编码器、地图特征编码器中获得交通参数者特征、地图特征,同时将样本的历史行为状态和地图信息共同输入到融合特征编码器中获得融合特征,具体表示为:
其中,φa、φm和φf是交通参数者、地图、融合的特征编码器;是预测解码器,和/>分别表示交通参数者特征、地图特征和融合特征,Ai、Mi分别表示第i个样本的历史行为状态和地图信息;
然后将融合特征输入到预测解码器中获得第i个样本的目标交通参数者的预测轨迹坐标及其置信分数Ci:
其中,表示第i个样本的目标交通参数者的预测轨迹坐标,Ci表示第i个样本的预测轨迹坐标/>的置信分数。
6.根据权利要求1所述的一种用于行为预测域自适应的伪标签改进的轨迹预测方法,其特征在于:在伪标签更新模块中,在训练中的每次迭代时按照以下方式处理:
针对当前迭代轮次,计算第k迭代轮次下获得的初步伪标签和前k-1迭代轮次获得的初步伪标签集合/>之间的一致性进而获取第k迭代轮次下获得的初步伪标签/>的最佳匹配轨迹,最佳匹配轨迹的索引计算如下:
其中,h表示来自集合中的伪标签,/> 表示最佳匹配轨迹的索引,/>表示余弦相似度函数,/>表示最佳匹配轨迹的索引;
然后根据最佳匹配轨迹的索引进行以下判断:
如果满足或者/>则将第1迭代轮次下获得的初步伪标签/>设置为/>其中Tc是预设的置信分数阈值,/>分别表示初步伪标签/>的置信分数和/>的置信分数,伪标签/>不更新;
如果不满足或者/>则进一步进行以下判断:
若满足则设置第k迭代轮次下更新后的伪标签/>并更新前k-1迭代轮次获得的初步伪标签集合/>进行/>
若不满足则设置第k迭代轮次下更新后的伪标签/>并更新前k-1迭代轮次获得的初步伪标签集合/>进行/>
从而获得第k迭代轮次下更新后的伪标签。
7.根据权利要求6所述的一种用于行为预测域自适应的伪标签改进的轨迹预测方法,其特征在于:所述伪标签更新模块根据更新后的伪标签按照以下公式计算目标域损失作为优化伪标签:
其中,Γ(·)是一个截断函数,ρt是超参数,是P中的损失函数,/>表示当前第k迭代轮次下获得的初步伪标签,/>表示最佳匹配轨迹的索引,exp()表示指数函数,/>表示在第k个训练轮次中所预测的轨迹。
8.根据权利要求1所述的一种用于行为预测域自适应的伪标签改进的轨迹预测方法,其特征在于:在所述对比学习模块中,设置两个域Δ和且/>即表示两个域Δ和/>的每个域为源域或者目标域,然后:
针对第一域Δ中的每个样本的融合特征,在第二域的所有样本的融合增强特征中根据轨迹一致性挑选和第一域Δ的样本的融合增强特征不接近融合增强特征并计算特征作为阳性特征/>在第二域/>的所有融合增强特征中选择和第一域Δ的样本的融合增强特征接近的融合增强特征并计算特征作为阴性特征/>然后进行以下判断:
如果阳性特征和阴性特征/>其中之一不存在,则对比损失设置为零;
如果阳性特征和阴性特征/>均存在,则设置轨迹主导的对比损失为:
其中,sim()表示余弦相似度,ρc表示温度系数,同时实施域内和域间对比学习;表示第一域Δ中第i个样本的融合特征,/>表示第二域/>中第i个样本的融合特征;
进而遍历两个域Δ和相同和不相同以及先后顺序的四种情况,分别获得两个域内对比损失/>和/>以及两个域间对比损失/>和/>
根据两个域内对比损失和/>以及两个域间对比损失/>和/>进行相加获得总的对比损失/>表示为:
9.根据权利要求1所述的一种用于行为预测域自适应的伪标签改进的轨迹预测方法,其特征在于:所述步骤2)进行再次训练时,由预测器P输出源域预测损失和目标域预测损失/>由对比学习模块输出对比损失/>结合上述损失采用以下公式进行损失的计算,表达为:
其中,表示优化总损失,η表示损失权重。
10.根据权利要求1所述的一种用于行为预测域自适应的伪标签改进的轨迹预测方法,其特征在于:所述步骤3)中,具体是最后从训练后的第一神经网络提取出预测器P组成预测网络,将待测场景的数据实时输入到预测网络中进行预测处理获得最终结果。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310977528.3A CN116956045A (zh) | 2023-08-04 | 2023-08-04 | 一种用于行为预测域自适应的伪标签改进的轨迹预测方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310977528.3A CN116956045A (zh) | 2023-08-04 | 2023-08-04 | 一种用于行为预测域自适应的伪标签改进的轨迹预测方法 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN116956045A true CN116956045A (zh) | 2023-10-27 |
Family
ID=88449163
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202310977528.3A Withdrawn CN116956045A (zh) | 2023-08-04 | 2023-08-04 | 一种用于行为预测域自适应的伪标签改进的轨迹预测方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN116956045A (zh) |
-
2023
- 2023-08-04 CN CN202310977528.3A patent/CN116956045A/zh not_active Withdrawn
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN111310583B (zh) | 一种基于改进的长短期记忆网络的车辆异常行为识别方法 | |
CN113326731B (zh) | 一种基于动量网络指导的跨域行人重识别方法 | |
CN108549895A (zh) | 一种基于对抗网络的半监督语义分割方法 | |
CN111967294A (zh) | 一种无监督域自适应的行人重识别方法 | |
CN108399406A (zh) | 基于深度学习的弱监督显著性物体检测的方法及系统 | |
CN111275711A (zh) | 基于轻量级卷积神经网络模型的实时图像语义分割方法 | |
CN112488025B (zh) | 基于多模态特征融合的双时相遥感影像语义变化检测方法 | |
CN114492574A (zh) | 基于高斯均匀混合模型的伪标签损失无监督对抗域适应图片分类方法 | |
CN113313166B (zh) | 基于特征一致性学习的船舶目标自动标注方法 | |
CN113010683B (zh) | 基于改进图注意力网络的实体关系识别方法及系统 | |
CN111860255A (zh) | 驾驶检测模型的训练、使用方法、装置、设备及介质 | |
CN112395957A (zh) | 一种针对视频目标检测的在线学习方法 | |
KR102592935B1 (ko) | 신경망 모델 학습 방법 및 장치, 컴퓨터 프로그램 | |
CN113255837A (zh) | 工业环境下基于改进的CenterNet网络目标检测方法 | |
CN111680702A (zh) | 一种使用检测框实现弱监督图像显著性检测的方法 | |
CN117058024A (zh) | 一种基于Transformer的高效去雾语义分割方法及其应用 | |
CN116433957A (zh) | 一种基于半监督学习的智能驾驶感知方法 | |
CN114693979A (zh) | 一种基于伪标签修正的多目标跟踪无监督域适应方法 | |
CN112163490A (zh) | 一种基于场景图片的目标检测方法 | |
CN114742224A (zh) | 行人重识别方法、装置、计算机设备及存储介质 | |
CN112927266A (zh) | 基于不确定性引导训练的弱监督时域动作定位方法及系统 | |
CN117195031A (zh) | 一种基于神经网络和知识图谱双通道系统的电磁辐射源个体识别方法 | |
CN114626461A (zh) | 基于领域自适应的跨域目标检测方法 | |
CN118230286A (zh) | 一种基于改进YOLOv7的车辆与行人识别方法 | |
CN113569814A (zh) | 一种基于特征一致性的无监督行人重识别方法 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
WW01 | Invention patent application withdrawn after publication |
Application publication date: 20231027 |
|
WW01 | Invention patent application withdrawn after publication |