CN116861262B - 一种感知模型训练方法、装置及电子设备和存储介质 - Google Patents

一种感知模型训练方法、装置及电子设备和存储介质 Download PDF

Info

Publication number
CN116861262B
CN116861262B CN202311128070.0A CN202311128070A CN116861262B CN 116861262 B CN116861262 B CN 116861262B CN 202311128070 A CN202311128070 A CN 202311128070A CN 116861262 B CN116861262 B CN 116861262B
Authority
CN
China
Prior art keywords
interference
network
training
perception model
branch
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN202311128070.0A
Other languages
English (en)
Other versions
CN116861262A (zh
Inventor
张腾飞
李茹杨
刘广庆
张恒
邓琪
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Suzhou Inspur Intelligent Technology Co Ltd
Original Assignee
Suzhou Inspur Intelligent Technology Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Suzhou Inspur Intelligent Technology Co Ltd filed Critical Suzhou Inspur Intelligent Technology Co Ltd
Priority to CN202311128070.0A priority Critical patent/CN116861262B/zh
Publication of CN116861262A publication Critical patent/CN116861262A/zh
Application granted granted Critical
Publication of CN116861262B publication Critical patent/CN116861262B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/0464Convolutional networks [CNN, ConvNet]

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Data Mining & Analysis (AREA)
  • Physics & Mathematics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • Evolutionary Computation (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Evolutionary Biology (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Computational Linguistics (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Health & Medical Sciences (AREA)
  • General Health & Medical Sciences (AREA)
  • Molecular Biology (AREA)
  • Computing Systems (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Image Analysis (AREA)

Abstract

本申请公开了一种感知模型训练方法、装置及电子设备和存储介质,涉及计算机技术领域,该方法包括:获取云端感知模型;其中,云端感知模型包括第一预设数量个基础元网络和第二预设数量个抗干扰元网络;将云端感知模型分解为基础分支和抗干扰分支;其中,基础分支包括第一预设数量个基础元网络,抗干扰分支包括第二预设数量个抗干扰元网络;基于正常场景数据训练基础分支,基于干扰场景数据训练抗干扰分支,合并训练完成的基础分支和训练完成的抗干扰分支得到中间云端感知模型,基于正常场景数据和干扰场景数据训练中间云端感知模型得到训练完成的云端感知模型。本申请提高了车端感知模型的鲁棒性。

Description

一种感知模型训练方法、装置及电子设备和存储介质
技术领域
本申请涉及计算机技术领域,更具体地说,涉及一种感知模型训练方法、装置及电子设备和存储介质。
背景技术
感知系统是自动驾驶的关键组成部分,用于感知和理解车辆周围环境,当前主流的自动驾驶感知模型,通常以摄像机、激光雷达等传感器采集的数据作为输入,输出对周围环境的感知结果,如道路目标检测、车道线检测、可行驶区域分割等。基于深度学习的感知模型已取得显著进展,但模型鲁棒性仍面临严重挑战,面对交通场景中的各种恶劣环境因素(如炫光、弱光、雨、雪)、传感器故障(摄像机损坏、激光雷达损坏等)以及恶意攻击等情况,模型准确性难以保证,对自动驾驶安全构成严重威胁。
当前提升感知模型鲁棒性的方法通常只针对一种或某几种干扰因素,没有涵盖尽可能多的干扰因素,所获得的模型只能在一种或某几种干扰因素下具备较好的鲁棒性,当模型部署到车端,仍容易遭受其他干扰因素的影响,如针对恶劣天气因素训练的模型,仍会受到恶意攻击的影响,导致鲁棒性下降,威胁自动驾驶安全。
因此,如何提高感知模型的鲁棒性是本领域技术人员需要解决的技术问题。
发明内容
本申请的目的在于提供一种感知模型训练方法、装置及一种电子设备和一种计算机可读存储介质,提高了感知模型的鲁棒性。
为实现上述目的,本申请提供了一种感知模型训练方法,包括:
获取云端感知模型;其中,所述云端感知模型包括第一预设数量个基础元网络和第二预设数量个抗干扰元网络;
将所述云端感知模型分解为基础分支和抗干扰分支;其中,所述基础分支包括第一预设数量个基础元网络,所述抗干扰分支包括第二预设数量个抗干扰元网络;
基于正常场景数据训练所述基础分支,基于干扰场景数据训练所述抗干扰分支,合并训练完成的基础分支和训练完成的抗干扰分支得到中间云端感知模型,基于所述正常场景数据和所述干扰场景数据训练所述中间云端感知模型得到训练完成的云端感知模型。
其中,所述云端感知模型包括依次连接的输入模块、基础网络、云端元网络、元知识融合网络和任务网络,所述云端元网络包括基础元网络组和抗干扰元网络组,所述基础元网络组包括第一预设数量个基础元网络,所述抗干扰元网络组包括第二预设数量个抗干扰元网络;
所述基础分支包括依次连接的所述输入模块、所述基础网络、所述基础元网络组、所述元知识融合网络和所述任务网络,所述抗干扰分支包括依次连接的所述输入模块、所述基础网络、所述抗干扰元网络组、所述元知识融合网络和所述任务网络。
其中,所述输入模块包括多个单模态的输入单元,每个所述输入单元通过对应的基础网络连接多模态融合网络,所述多模态融合网络连接所述云端元网络。
其中,所述基础网络包括依次连接的预处理模块、骨干网络和多尺度特征提取网络。
其中,所述基础元网络包括第一基础子网络、第二基础子网络、第三基础子网络,所述第一基础子网络、所述第二基础子网络、所述第三基础子网络包括多层卷积神经网络或全连接层或注意力层。
其中,所述抗干扰元网络包括第一抗干扰子网络、第二抗干扰子网络、第三抗干扰子网络,所述第一抗干扰子网络的模型结构与所述第一基础子网络的模型结构相同,所述第二抗干扰子网络包括所有所述第二基础子网络、每个所述第二基础子网络对应的特征学习模型和与所有所述特征学习模型连接的特征融合模块,所述第三抗干扰子网络包括多层卷积神经网络或全连接层或注意力层。
其中,所述特征学习模型用于将对应的第二基础子网络输出的特征与对应的学习参数相乘得到对应的学习特征。
其中,所述特征融合模块用于对所有所述学习特征采用均值融合方式进行聚合得到聚合特征。
其中,所述基于正常场景数据训练所述基础分支,包括:
在正常场景数据中采样正常训练样本,将所述正常训练样本输入所述基础分支中,基于所述任务网络的损失对所述基础分支的模型参数进行训练,得到训练完成的基础分支。
其中,所述基于干扰场景数据训练所述抗干扰分支,包括:
基于训练完成的基础分支初始化所述抗干扰分支,并冻结初始化完成的抗干扰分支中每个抗干扰元网络中的第一抗干扰子网络的参数;
在干扰场景数据中采样干扰训练样本,将所述干扰训练样本输入所述初始化完成的抗干扰分支中,基于所述任务网络的损失对所述初始化完成的抗干扰分支的模型参数进行训练,得到训练完成的抗干扰分支。
其中,所述基于训练完成的基础分支初始化所述抗干扰分支,包括:
基于训练完成的基础分支中基础元网络中的第一基础子网络的参数初始化所述抗干扰分支中每个抗干扰元网络中的第一抗干扰子网络的参数;
基于训练完成的基础分支中基础元网络中的第二基础子网络构建所述抗干扰分支中每个抗干扰元网络中的第二抗干扰子网络。
其中,所述基于训练完成的基础分支中基础元网络中的第一基础子网络的参数初始化所述抗干扰分支中每个抗干扰元网络中的第一抗干扰子网络的参数,包括:
对训练完成的基础分支中基础元网络中的所有第一基础子网络的参数进行均值融合,基于融合后的参数初始化所述抗干扰分支中每个抗干扰元网络中的第一抗干扰子网络的参数。
其中,所述任务网络包括多个子任务网络。
其中,所述基于所述任务网络的损失对所述初始化完成的抗干扰分支的模型参数进行训练,得到训练完成的抗干扰分支,包括:
计算所述任务网络中多个所述子任务网络的损失和,基于所述损失和计算所述初始化完成的抗干扰分支中每个抗干扰元网络的输出特征的每个维度的梯度,对每个所述输出特征的每个维度的梯度进行绝对值求和得到每个所述输出特征的梯度和;
对所有所述输出特征的梯度和由大至小进行排序,确定排序结果中前第五预设数量个输出特征对应的抗干扰元网络作为目标抗干扰元网络;
冻结除所述目标抗干扰元网络之外的其他抗干扰元网络的参数,更新所述目标抗干扰元网络的参数,得到训练完成的抗干扰分支。
其中,所述干扰训练样本包括干扰数据和对应的标注,所述标注包括任务标注和干扰因素类型。
其中,所述合并训练完成的基础分支和训练完成的抗干扰分支得到中间云端感知模型,包括:
对训练完成的基础分支中的基础网络的参数和训练完成的抗干扰分支中的基础网络的参数进行加权合并,得到中间云端感知模型中基础网络的参数;
对训练完成的基础分支中的元知识融合网络的参数和训练完成的抗干扰分支中的元知识融合网络的参数进行加权合并,得到中间云端感知模型中基础网络的参数;
对训练完成的基础分支中的任务网络的参数和训练完成的抗干扰分支中的任务网络的参数进行加权合并,得到中间云端感知模型中任务网络的参数;
将训练完成的基础分支中的基础元网络组和训练完成的抗干扰分支中的抗干扰元网络组合并为中间云端感知模型中的云端元网络。
其中,所述基于所述正常场景数据和所述干扰场景数据训练所述中间云端感知模型得到训练完成的云端感知模型,包括:
确定正常训练样本的第一采样率或干扰训练样本的第二采样率;其中,所述第一采样率与所述第二采样率的和为一;
基于所述第一采样率或所述第二采样率在正常场景数据中采样正常训练样本、在干扰场景数据中采样干扰训练样本;
将所述干扰训练样本输入所述中间云端感知模型中,基于所述任务网络的损失对所述中间云端感知模型的模型参数进行训练,得到训练完成的云端感知模型。
为实现上述目的,本申请提供了一种感知模型训练方法,包括:
获取云端感知模型;其中,所述云端感知模型包括第一预设数量个基础元网络和第二预设数量个抗干扰元网络;
将所述云端感知模型分解为基础分支和抗干扰分支;其中,所述基础分支包括第一预设数量个基础元网络,所述抗干扰分支包括第二预设数量个抗干扰元网络;
基于正常场景数据训练所述基础分支,基于干扰场景数据训练所述抗干扰分支,合并训练完成的基础分支和训练完成的抗干扰分支得到中间云端感知模型,基于所述正常场景数据和所述干扰场景数据训练所述中间云端感知模型得到训练完成的云端感知模型;
基于所述训练完成的云端感知模型构建车端感知模型;其中,所述车端感知模型包括第三预设数量个基础元网络和第四预设数量个抗干扰元网络,所述第三预设数量小于所述第一预设数量,所述第四预设数量小于所述第二预设数量;
基于所述正常场景数据和所述干扰场景数据训练所述车端感知模型得到训练完成的车端感知模型。
其中,所述车端感知模型包括依次连接的输入模块、基础网络、车端元网络、元知识融合网络和所述任务网络,车端元网络包括第三预设数量个基础元网络和第四预设数量个抗干扰元网络。
其中,所述基于所述训练完成的云端感知模型构建车端感知模型,包括:
将所述训练完成的云端感知模型中云端元网络中的所有基础元网络划分为所述第三预设数量个类别;
对每个类别中基础元网络的参数进行均值融合得到车端感知模型中对应的基础元网络的参数;
将所述训练完成的云端感知模型中云端元网络中的所有抗干扰元网络划分为所述第四预设数量个类别;
对每个类别中抗干扰元网络的参数进行均值融合得到车端感知模型中对应的抗干扰元网络的参数。
其中,所述基于所述正常场景数据和所述干扰场景数据训练所述车端感知模型得到训练完成的车端感知模型,包括:
确定正常训练样本的第三采样率或干扰训练样本的第四采样率;其中,所述第三采样率与所述第四采样率的和为一;
基于所述第三采样率或所述第四采样率在正常场景数据中采样正常训练样本、在干扰场景数据中采样干扰训练样本;
将所述干扰训练样本输入所述车端感知模型中,基于所述任务网络的损失对所述车端感知模型的模型参数进行训练,得到训练完成的车端感知模型。
其中,所述基于所述任务网络的损失对所述车端感知模型的模型参数进行训练,得到训练完成的车端感知模型,包括:
基于所述任务网络的损失和知识蒸馏损失对所述车端感知模型的模型参数进行训练,得到训练完成的车端感知模型;
其中,所述知识蒸馏损失为基于所述训练完成的云端感知模型中的多模态融合网络输出的融合特征和所述车端感知模型中的多模态融合网络输出的融合特征计算得到的,或,所述知识蒸馏损失为基于所述训练完成的云端感知模型中的元知识融合网络输出的融合特征和所述车端感知模型中的元知识融合网络输出的融合特征计算得到的。
其中,所述知识蒸馏损失的计算公式为:
其中,为所述知识蒸馏损失,/>为所述训练完成的云端感知模型中的多模态融合网络或元知识融合网络输出的融合特征第(i,j)位置的第k通道处的特征值,/>所述车端感知模型中的多模态融合网络或元知识融合网络输出的融合特征第(i,j)位置的第k通道处的特征值,1≤i≤W,1≤j≤H,1≤k≤C,W为融合特征的宽度,H为融合特征的高度,C为融合特征的通道数量。
为实现上述目的,本申请提供了一种感知模型训练装置,包括:
获取单元,用于获取云端感知模型;其中,所述云端感知模型包括第一预设数量个基础元网络和第二预设数量个抗干扰元网络;
分解单元,用于将所述云端感知模型分解为基础分支和抗干扰分支;其中,所述基础分支包括第一预设数量个基础元网络,所述抗干扰分支包括第二预设数量个抗干扰元网络;
第一训练单元,用于基于正常场景数据训练所述基础分支,基于干扰场景数据训练所述抗干扰分支,合并训练完成的基础分支和训练完成的抗干扰分支得到中间云端感知模型,基于所述正常场景数据和所述干扰场景数据训练所述中间云端感知模型得到训练完成的云端感知模型。
为实现上述目的,本申请提供了一种感知模型训练装置,包括:
获取单元,用于获取云端感知模型;其中,所述云端感知模型包括第一预设数量个基础元网络和第二预设数量个抗干扰元网络;
分解单元,用于将所述云端感知模型分解为基础分支和抗干扰分支;其中,所述基础分支包括第一预设数量个基础元网络,所述抗干扰分支包括第二预设数量个抗干扰元网络;
第一训练单元,用于基于正常场景数据训练所述基础分支,基于干扰场景数据训练所述抗干扰分支,合并训练完成的基础分支和训练完成的抗干扰分支得到中间云端感知模型,基于所述正常场景数据和所述干扰场景数据训练所述中间云端感知模型得到训练完成的云端感知模型;
构建单元,用于基于所述训练完成的云端感知模型构建车端感知模型;其中,所述车端感知模型包括第三预设数量个基础元网络和第四预设数量个抗干扰元网络,所述第三预设数量小于所述第一预设数量,所述第四预设数量小于所述第二预设数量;
第二训练单元,用于基于所述正常场景数据和所述干扰场景数据训练所述车端感知模型得到训练完成的车端感知模型。
为实现上述目的,本申请提供了一种电子设备,包括:
存储器,用于存储计算机程序;
处理器,用于执行所述计算机程序时实现如上述感知模型训练方法的步骤。
为实现上述目的,本申请提供了一种计算机可读存储介质,所述计算机可读存储介质上存储有计算机程序,所述计算机程序被处理器执行时实现如上述感知模型训练方法的步骤。
通过以上方案可知,本申请提供的一种感知模型训练方法,包括:获取单元,用于获取云端感知模型;其中,所述云端感知模型包括第一预设数量个基础元网络和第二预设数量个抗干扰元网络;分解单元,用于将所述云端感知模型分解为基础分支和抗干扰分支;其中,所述基础分支包括第一预设数量个基础元网络,所述抗干扰分支包括第二预设数量个抗干扰元网络;第一训练单元,用于基于正常场景数据训练所述基础分支,基于干扰场景数据训练所述抗干扰分支,合并训练完成的基础分支和训练完成的抗干扰分支得到中间云端感知模型,基于所述正常场景数据和所述干扰场景数据训练所述中间云端感知模型得到训练完成的云端感知模型;构建单元,用于基于所述训练完成的云端感知模型构建车端感知模型;其中,所述车端感知模型包括第三预设数量个基础元网络和第四预设数量个抗干扰元网络,所述第三预设数量小于所述第一预设数量,所述第四预设数量小于所述第二预设数量;第二训练单元,用于基于所述正常场景数据和所述干扰场景数据训练所述车端感知模型得到训练完成的车端感知模型。
本申请提供的感知模型训练方法,针对云端感知模型,将其划分为基础分支和抗干扰分支,分别基于正常场景数据和干扰数据进行有效训练,提升云端感知模型对干扰数据的学习效果,提高了感知模型的鲁棒性。本申请还公开了一种感知模型训练装置及一种电子设备和一种计算机可读存储介质,同样能实现上述技术效果。
应当理解的是,以上的一般描述和后文的细节描述仅是示例性的,并不能限制本申请。
附图说明
为了更清楚地说明本申请实施例或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本申请的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。附图是用来提供对本公开的进一步理解,并且构成说明书的一部分,与下面的具体实施方式一起用于解释本公开,但并不构成对本公开的限制。在附图中:
图1为根据一示例性实施例示出的一种感知模型训练方法的流程图;
图2为根据一示例性实施例示出的一种云端感知模型的结构图;
图3为根据一示例性实施例示出的一种基础元网络的结构图;
图4为根据一示例性实施例示出的一种抗干扰元网络的结构图;
图5为根据一示例性实施例示出的一种云端感知模型中基础分支的结构图;
图6为根据一示例性实施例示出的一种云端感知模型中抗干扰分支的结构图;
图7为根据一示例性实施例示出的一种云端感知模型中基础分支、抗干扰分支的元知识融合网络状态示意图;
图8为根据一示例性实施例示出的另一种车端感知模型训练方法的流程图;
图9为根据一示例性实施例示出的一种车端感知模型的结构图;
图10为根据一示例性实施例示出的一种车端感知模型知识蒸馏示意图;
图11为本申请提供的一种应用实施例的流程图;
图12为根据一示例性实施例示出的一种感知模型训练装置的结构图;
图13为根据一示例性实施例示出的另一种感知模型训练装置的结构图;
图14为根据一示例性实施例示出的一种电子设备的结构图。
具体实施方式
下面将结合本申请实施例中的附图,对本申请实施例中的技术方案进行清楚、完整地描述。显然,所描述的实施例仅仅是本申请一部分实施例,而不是全部的实施例。基于本申请中的实施例,本领域普通技术人员在没有作出创造性劳动前提下所获得的所有其他实施例,都属于本申请保护的范围。另外,在本申请实施例中,“第一”、“第二”等是用于区别类似的对象,而不必用于描述特定的顺序或先后次序。
感知大模型(即云端感知模型)是具有大规模参数的模型,其参数量通常为亿级别以上,甚至可达百亿、千亿级别,感知大模型可以为单模态或多模态大模型。相比传统的感知小模型(即车端感知模型),感知大模型参数量更多,模型的表达能力和鲁棒性更强。本发明提出一种利用大模型提高车端感知模型鲁棒性的方法。首先,本发明构建了一种感知大模型结构,包含基础元网络组和抗干扰元网络组,分别负责正常场景数据和干扰场景,在本领域现有的感知大模型方案中,缺少处理干扰场景的专用网络模块,限制了感知大模型对干扰因素的鲁棒性,本发明通过显式的为正常场景和干扰场景分别设置元网络组,既能确保感知大模型对正常场景的准确感知,又能提升感知大模型对干扰因素的鲁棒性。其次,相对正常场景,干扰数据样本量更少,抗干扰元网络容易过拟合,因此本发明构建了一种纺锤结构的抗干扰元网络,该纺锤结构的抗干扰元网络是本领域现有感知大模型方案所不具备的,有着更强的表达能力和抗过拟合能力,可有效提升抗干扰元网络的鲁棒性。进一步地,采用通用的训练策略训练抗干扰元网络会限制感知大模型训练效果,本领域现有感知大模型方案普遍缺少针对干扰场景的训练策略,本发明提出一种抗干扰训练方法,包含分支训练方法和基于梯度的优化训练方法,可显著提高感知大模型的训练效果。最后,本发明利用训练完成的感知大模型对车端感知模型进行知识蒸馏,以获得具备高准确性、高鲁棒性的车端感知模型。
综上所述,本发明提出的方法使车端感知模型具备对恶劣环境、硬件故障以及恶意攻击等多种干扰因素的鲁棒性,进而提升自动驾驶系统的安全性。
本申请实施例公开了一种感知模型训练方法,提高了感知模型的鲁棒性。
参见图1,根据一示例性实施例示出的一种车端感知模型训练方法的流程图,如图1所示,包括:
S101:获取云端感知模型;其中,所述云端感知模型包括第一预设数量个基础元网络和第二预设数量个抗干扰元网络;
在本实施例中,云端感知模型包括第一预设数量个基础元网络和第二预设数量个抗干扰元网络。作为一种可行的实施方式,所述云端感知模型包括依次连接的输入模块、基础网络、云端元网络、元知识融合网络和任务网络,所述云端元网络包括基础元网络组和抗干扰元网络组,所述基础元网络组包括第一预设数量个基础元网络,所述抗干扰元网络组包括第二预设数量个抗干扰元网络。在具体实施中,云端感知模型如图2所示,包括依次连接的输入模块、基础网络、云端元网络、元知识融合网络和任务网络。作为一种优选实施方式,所述输入模块包括多个单模态的输入单元,每个所述输入单元通过对应的基础网络连接多模态融合网络,所述多模态融合网络连接所述云端元网络。基础网络包括依次连接的预处理模块、骨干网络和多尺度特征提取网络。骨干网络可采用卷积神经网络如ResNet(Residual Network,残差网络),Transformer 网络如Swin Transformer等,多尺度特征提取网络可采用特征金字塔网络,多模态特征融合网络可采用卷积神经网络层、全连接层或Transformer交叉注意力层组成,多模态特征融合网络输出融合后的特征,随后输入云端元网络。云端元网络包括基础元网络组和抗干扰元网络组,基础元网络组包括第一预设数量个基础元网络,抗干扰元网络组包括第二预设数量个抗干扰元网络。元网络可由卷积神经网络、多层感知机等网络结构组成。元知识融合网络以各元网络的输出作为输出,对各元网络的输出进行融合,获得元知识融合后的特征,并输入到任务网络中。元知识融合网络可采用卷积神经网络层、全连接层或Transformer交叉注意力层组成。任务网络可以包括多个子任务网络,如3D目标检测、车道线分割、目标跟踪等子任务。
对于基础元网络组,本实施例构建Enormal个基础元网络,例如Enormal=100。对于抗干扰元网络组,本实施例构建Edisturb个抗干扰元网络/>,具体地,本实施例为每类干扰因素构建Eper_disturb个抗干扰元网络,Eper_disturb≥1,例如对20类干扰因素,每类构建5个抗干扰元网络,此时Eper_disturb=5,Edisturb=100。通过对每类干扰因素设置多个抗干扰元网络,可显著提升对干扰因素的鲁棒性。
作为一种可行的实施方式,所述基础元网络包括第一基础子网络、第二基础子网络、第三基础子网络,所述第一基础子网络、所述第二基础子网络、所述第三基础子网络包括多层卷积神经网络或全连接层或注意力层。在具体实施中,如图3所示,基础元网络包含三个子网络:第一基础子网络、第二基础子网络和第三基础子网络,其中,每个基础子网络可由多层卷积神经网络或全连接网络等网络结构组成。
作为一种优选的实施方式,所述抗干扰元网络包括第一抗干扰子网络、第二抗干扰子网络、第三抗干扰子网络,所述第一抗干扰子网络的模型结构与所述第一基础子网络的模型结构相同,所述第二抗干扰子网络包括所有所述第二基础子网络、每个所述第二基础子网络对应的特征学习模型和与所有所述特征学习模型连接的特征融合模块,所述第三抗干扰子网络包括多层卷积神经网络和全连接层。
如图4所示,本实施例提出的纺锤结构的抗干扰元网络包含三个子网络:第一抗干扰子网络、第二抗干扰子网络、第三抗干扰子网络,其中,抗干扰元网络的第一抗干扰子网络的模型结构与基础元网络的第一进出子网络相同,抗干扰元网络的参数由基础元网络的第一子网络的元参数φ初始化,该元参数φ通过对所有Enormal个基础元网络的第一子网络参数计算均值获得,用φi表示第i个基础元网络的第一基础子网络的参数,则元参数φ的计算公式为:
第二抗干扰子网络由基础元网络组中每个基础元网络的第二基础子网络构成,每个第二基础子网络对应一个特征学习模型,特征学习模型用于将对应的第二基础子网络输出的特征与对应的学习参数相乘得到对应的学习特征。在具体实施中,对于第i个基础元网络的第二基础子网络的输出特征Fi,设置对应的可学习参数αi,每个抗干扰元网络的第二基础子网络的可学习参数可以相同,也可以不同,在此不进行具体限定。该可学习参数αi与对应的基础元网络的第二基础子网络的输出特征Fi相乘,得到特征Fi ,即Fi i×Fi。特征融合模块用于对所有所述学习特征采用均值融合方式进行聚合得到聚合特征。在具体实施中,将全部的Enormal个特征Fi 输入到第二抗干扰子网络的特征聚合模块中,将这些特征聚合为一个特征,如采用均值聚合方法得到聚合特征/>的计算公式为:
该公式表示对每个特征Fi 对应特征值求和,再求均值。将获得的聚合特征输入到第三抗干扰子网络中继续处理,第三抗干扰子网络的输出特征/>即为该纺锤结构的抗干扰元网络的输出特征,第三抗干扰子网络可由多层卷积神经网络或全连接网络等网络结构组成。
S102:将所述云端感知模型分解为基础分支和抗干扰分支;其中,所述基础分支包括第一预设数量个基础元网络,所述抗干扰分支包括第二预设数量个抗干扰元网络;
在本步骤中,将云端感知模型分解为两个分支:基础分支和抗干扰分支,基础分支包括第一预设数量个基础元网络,抗干扰分支包括第二预设数量个抗干扰元网络。作为一种可行的实施方式,所述基础分支包括依次连接的所述输入模块、所述基础网络、所述基础元网络组、所述元知识融合网络和所述任务网络。在具体实施中,基础分支如图5所示,包括输入模块、基础网络(预处理模块、骨干网络、多尺度特征提取网络)、基础元网络组、元知识融合网络和任务网络,当输入模块包括多个单模态的输入单元时,基础分支还包括多模态融合网络,多个基础网络连接多模态融合网络,多模态融合网络连接基础元网络组。作为一种可行的实施方式,所述抗干扰分支包括依次连接的所述输入模块、所述基础网络、所述抗干扰元网络组、所述元知识融合网络和所述任务网络。在具体实施中,抗干扰分支如图6所示,包括输入模块、基础网络(预处理模块、骨干网络、多尺度特征提取网络)、抗干扰元网络组、元知识融合网络和任务网络,当输入模块包括多个单模态的输入单元时,抗干扰分支还包括多模态融合网络,多个基础网络连接多模态融合网络,多模态融合网络连接抗干扰元网络组。如图7所示,在基础分支中,元知识融合网络仅来自基础元网络的输入激活,而在抗干扰分支中,元知识融合网络仅来自抗干扰元网络的输入激活。
S103:基于正常场景数据训练所述基础分支,基于干扰场景数据训练所述抗干扰分支,合并训练完成的基础分支和训练完成的抗干扰分支得到中间云端感知模型,基于所述正常场景数据和所述干扰场景数据训练所述中间云端感知模型得到训练完成的云端感知模型;
正常场景数据为摄像机采集的图像数据和激光雷达采集的点云数据,其中不存在干扰因素。干扰场景数据为存在干扰因素的图像数据或点云数据,干扰因素主要可分为恶劣环境和硬件故障。恶劣环境包含影响摄像机成像质量的各种光照因素,如:炫光、反光等;雨、雪、雾、霾等天气条件会影响摄像机和激光雷达采集的数据质量;当各种传感器发生硬件故障时,会造成采集的数据质量降低甚至完全失效。如摄像机成像时产生的噪声以及一些坏点;激光雷达也可能出现故障,导致对部分视角无法采集到点云数据。
对于正常场景数据或干扰数据,均由数据帧F表示,每帧数据包含至少一种模态的数据,如,表示第i帧数据Fi包含M种模态,模态包括图像、点云等,其中每种模态样本集合又包含至少一条数据,如/>,表示第1个模态/>包含N条数据。对于干扰数据帧,其中包含至少一条干扰数据。数据集中的数据来源包括:在车辆行驶过程中采集的道路场景数据;通过仿真软件获得的道路场景数据;通过模型生成的道路场景数据。
可见,本发明实施例针对各类干扰因素进行数据采集,为提高云端感知模型对干扰因素的鲁棒性提供了良好的数据基础。
进一步的,对数据集中的样本进行标注,包括任务标注,如对目标检测任务,标注每个样本中感兴趣目标的边界框、类别等信息,对场景分割任务,标注每个像素或点云的类别。对于干扰样本,除了针对特定任务的任务标注外,还需要标注该数据帧中包含的干扰因素类型集合,如{p1,p2,p3}表示当前数据帧中包含3种干扰因素类型。样本标注可采用人工标注,或采用云端感知模型的预测结果作为标注,以及人工标注和云端感知模型预测相结合的方式。
分支训练方法分为三个训练阶段:第一训练阶段、第二训练阶段、第三训练阶段。在第一训练阶段中,在正常场景数据上训练基础分支。作为一种可行的实施方式,所述基于正常场景数据训练所述基础分支,包括:在正常场景数据中采样正常训练样本,将所述正常训练样本输入所述基础分支中,基于所述任务网络的损失对所述基础分支的模型参数进行训练,得到训练完成的基础分支。在具体实施中,每次训练从正常场景数据中采样一个批次的训练样本,如批次大小设定为32,即表示采样32个数据帧。优化器可采用GradientDescent (GD,梯度下降)、Adaptive Moment estimation(Adam,自适应矩估计)、AdamWeight Decay Regularization(AdamW,解耦梯度下降与权重衰减正则化)等。使用各子任务网络计算获得的损失,如3D目标检测子任务网络中的目标分类损失和位置回归损失等,反向传播优化基础分支。
在第二训练阶段中,在干扰数据上训练云端感知模型的抗干扰分支。作为一种可行的实施方式,所述基于干扰场景数据训练所述抗干扰分支,包括:基于训练完成的基础分支初始化所述抗干扰分支,并冻结初始化完成的抗干扰分支中每个抗干扰元网络中的第一抗干扰子网络的参数;在干扰场景数据中采样干扰训练样本,将所述干扰训练样本输入所述初始化完成的抗干扰分支中,基于所述任务网络的损失对所述初始化完成的抗干扰分支的模型参数进行训练,得到训练完成的抗干扰分支。
作为一种可行的实施方式,所述基于训练完成的基础分支初始化所述抗干扰分支,包括:基于训练完成的基础分支中基础元网络中的第一基础子网络的参数初始化所述抗干扰分支中每个抗干扰元网络中的第一抗干扰子网络的参数;基于训练完成的基础分支中基础元网络中的第二基础子网络构建所述抗干扰分支中每个抗干扰元网络中的第二抗干扰子网络。
在具体实施中,首先,使用训练完成的云端感知模型基础分支的基础元网络的第一基础子网络的元参数初始化云端感知模型抗干扰分支的每个抗干扰元网络的第一抗干扰子网络,作为一种可行的实施方式,对训练完成的基础分支中基础元网络中的所有第一基础子网络的参数进行均值融合,基于融合后的参数初始化抗干扰分支中每个抗干扰元网络中的第一抗干扰子网络的参数。然后,并冻结每个抗干扰元网络的第一抗干扰子网络的参数。其次,使用训练完成的云端感知模型基础分支的基础元网络的第二基础子网络构建云端感知模型抗干扰分支的每个抗干扰元网络的第二抗干扰子网络,对于全部的抗干扰元网络的第二抗干扰子网络,共享云端感知模型基础分支的基础元网络的第二基础子网络,需要注意的是,云端感知模型基础分支的基础元网络的第二基础子网络参数被冻结。每次训练从干扰数据中采样一个批次的训练样本,如批次大小设定为32,即表示采样32个数据帧。使用各子任务网络计算获得的损失,如3D目标检测子任务网络中的目标分类损失和位置回归损失等,反向传播优化云端感知模型抗干扰分支。
作为一种优选的实施方式,所述基于所述任务网络的损失对所述初始化完成的抗干扰分支的模型参数进行训练,得到训练完成的抗干扰分支,包括:计算所述任务网络中多个所述子任务网络的损失和,基于所述损失和计算所述初始化完成的抗干扰分支中每个抗干扰元网络的输出特征的每个维度的梯度,对每个所述输出特征的每个维度的梯度进行绝对值求和得到每个所述输出特征的梯度和;对所有所述输出特征的梯度和由大至小进行排序,确定排序结果中前第五预设数量个输出特征对应的抗干扰元网络作为目标抗干扰元网络;冻结除所述目标抗干扰元网络之外的其他抗干扰元网络的参数,更新所述目标抗干扰元网络的参数,得到训练完成的抗干扰分支。
为了提高云端感知模型对抗元网络的训练效果,采用基于梯度的优化训练方法。用Li表示第i个子任务的损失函数,则所有N个子任务的损失函数的损失和为,计算云端感知模型第j个抗干扰元网络的输出特征/>的梯度为:/>,输出特征/>包含多个维度,计算得到的梯度包含多个维度,对梯度中每个维度的绝对值求和得到每个输出特征的梯度和/>。例如,假设特征为(1,64,100,100),计算得到的该特征的梯度也是(1,64,100,100),则将这个梯度中所有值取绝对值,然后求和。对全部Edisturb个抗干扰元网络的输出特征的梯度和进行降序排序,获取第五预设数量K个最大的梯度和对应的抗干扰元网络/>,也即筛选出K各训练效果较差的抗干扰元网络。在该次训练迭代中,更新该K个抗干扰元网络的第一抗干扰子网络的参数,对于剩余的抗干扰元网络的第一抗干扰子网络,仍冻结其参数。
在第三训练阶段中,首先合并训练完成的云端感知模型基础分支和抗干扰分支,得到完整的云端感知模型,随后从数据集中选取正常场景样本和干扰样本,训练完整的云端感知模型。
作为一种可行的实施方式,所述合并训练完成的基础分支和训练完成的抗干扰分支得到中间云端感知模型,包括:对训练完成的基础分支中的基础网络的参数和训练完成的抗干扰分支中的基础网络的参数进行加权合并,得到中间云端感知模型中基础网络的参数;对训练完成的基础分支中的元知识融合网络的参数和训练完成的抗干扰分支中的元知识融合网络的参数进行加权合并,得到中间云端感知模型中基础网络的参数;对训练完成的基础分支中的任务网络的参数和训练完成的抗干扰分支中的任务网络的参数进行加权合并,得到中间云端感知模型中任务网络的参数;将训练完成的基础分支中的基础元网络组和训练完成的抗干扰分支中的抗干扰元网络组合并为中间云端感知模型中的云端元网络。
在将第一训练阶段、第二训练阶段获得的云端感知模型基础分支和云端感知模型抗干扰分支组合为完整的云端感知模型时,每个输入模态的预处理模块、骨干网络、多尺度特征提取网络、多模态融合网络、元知识融合网络、以及任务网络的参数从云端感知模型基础分支和抗干扰分支的参数中通过加权求和获取,例如完整云端感知模型的参数为w,云端感知模型基础分支对应的参数为wnormal,云端感知模型抗干扰分支对应的参数为wdisturb,则w=βwnormal+(1-β)wdisturb,且β∈[0,1];基础元网络组从云端感知模型的基础分支中获取,抗干扰元网络组从云端感知模型的抗干扰分支中获取。
作为一种优选的实施方式,所述基于所述正常场景数据和所述干扰场景数据训练所述中间云端感知模型得到训练完成的云端感知模型,包括:确定正常训练样本的第一采样率或干扰训练样本的第二采样率;其中,所述第一采样率与所述第二采样率的和为一;基于所述第一采样率或所述第二采样率在正常场景数据中采样正常训练样本、在干扰场景数据中采样干扰训练样本;将所述干扰训练样本输入所述中间云端感知模型中,基于所述任务网络的损失对所述中间云端感知模型的模型参数进行训练,得到训练完成的云端感知模型。
在具体实施中,对正常训练样本和干扰训练样本采取对应的采样概率,正常训练样本的采样概率为αnormal,干扰训练样本的采样概率为αdisturb,其中,αnormaldisturb=1。例如,αnormal=0.5、αdisturb0.5。在模型训练过程中,每次训练前,根据前述采样概率首先确定当前选取的是正常训练样本还是干扰训练样本,如果为正常训练样本,则直接选取正常训练样本;如果为干扰训练样本,则从包含至少一种干扰因素的数据帧中采样干扰数据帧。上述采样策略通过调整采样概率,可以平衡正常场景和干扰数据样本量的差异,可以提升模型对各类干扰因素的鲁棒性,同时保持模型在正常场景下的性能。每次训练时,采用上述采样策略从正常场景数据和干扰数据中采样一个批次的训练样本,训练完整的云端感知模型,显著提高云端感知模型的训练效果,得到更鲁棒的云端感知模型。
本申请实施例提供的感知模型训练方法,针对云端感知模型,将其划分为基础分支和抗干扰分支,分别基于正常场景数据和干扰数据进行有效训练,提升云端感知模型对干扰数据的学习效果,提高了感知模型的鲁棒性。
本申请实施例公开了一种感知模型训练方法,相对于上一实施例,本实施例对技术方案作了进一步的说明和优化。具体的:
参见图8,根据一示例性实施例示出的另一种感知模型训练方法的流程图,如图8所示,包括:
S201:获取云端感知模型;其中,所述云端感知模型包括第一预设数量个基础元网络和第二预设数量个抗干扰元网络;
S202:将所述云端感知模型分解为基础分支和抗干扰分支;其中,所述基础分支包括第一预设数量个基础元网络,所述抗干扰分支包括第二预设数量个抗干扰元网络;
S203:基于正常场景数据训练所述基础分支,基于干扰场景数据训练所述抗干扰分支,合并训练完成的基础分支和训练完成的抗干扰分支得到中间云端感知模型,基于所述正常场景数据和所述干扰场景数据训练所述中间云端感知模型得到训练完成的云端感知模型;
S204:基于所述训练完成的云端感知模型构建车端感知模型;其中,所述车端感知模型包括第三预设数量个基础元网络和第四预设数量个抗干扰元网络,所述第三预设数量小于所述第一预设数量,所述第四预设数量小于所述第二预设数量;
车端感知模型的结构与云端感知模型类似,区别在于车端感知模型包含的基础元网络和抗干扰元网络的数量少于云端感知模型。作为一种可行的实施方式,所述车端感知模型包括依次连接的输入模块、基础网络、车端元网络、元知识融合网络和所述任务网络,车端元网络包括第三预设数量个基础元网络和第四预设数量个抗干扰元网络。在具体实施中,基础网络包括依次连接的预处理模块、骨干网络和多尺度特征提取网络,骨干网络可采用ResNet等;多尺度特征提取网络,如特征金字塔网络;多模态特征融合网络可采用卷积神经网络层、全连接层或Transformer交叉注意力层组成,多模态特征融合网络输出融合后的特征,随后输入元网络,元网络由至少一个基础元网络和至少一个抗干扰元网络构成,如图9所示,分别设置了一个基础元网络和一个抗干扰元网络。知识融合网络以各元网络的输出作为输出,对各元网络的输出进行融合,获得元知识融合后的特征,并输入到各下游子任务网络中,如3D目标检测、车道线分割、目标跟踪等子任务。元知识融合网络可采用卷积神经网络层、全连接层或Transformer交叉注意力层组成。
作为一种可行的实施方式,所述基于所述训练完成的云端感知模型构建车端感知模型,包括:将所述训练完成的云端感知模型中云端元网络中的所有基础元网络划分为所述第三预设数量个类别;对每个类别中基础元网络的参数进行均值融合得到车端感知模型中对应的基础元网络的参数;将所述训练完成的云端感知模型中云端元网络中的所有抗干扰元网络划分为所述第四预设数量个类别;对每个类别中抗干扰元网络的参数进行均值融合得到车端感知模型中对应的抗干扰元网络的参数。
在具体实施中,当车端感知模型只包含一个基础元网络和一个抗干扰元网络时,使用云端感知模型的全部基础元网络的参数均值、全部抗干扰元网络的参数均值来初始化车端感知模型的基础元网络、抗干扰元网络。当车端感知模型包含enormal(enormal>1)个基础元网络或edisturb(edisturb>1)个抗干扰元网络时,则将云端感知模型的Enormal个基础元网络划分为enormal组,每组包含Enormal/enormal个基础元网络,对应一个车端感知模型的基础元网络,将云端感知模型的Edisturb个抗干扰元网络划分为edisturb组,使用每组的Edisturb/edisturb个云端感知模型的抗干扰元网络的参数均值来初始化该组对应的车端感知模型的抗干扰元网络。需要注意的是,当Enormal/enormal或Edisturb/edisturb不为整数时,则向下取整,在划分完前(enormal-1)、(edisturb-1)组后,将剩余的基础元网络分为一组、剩余的抗干扰元网络分为一组。
S205:基于所述正常场景数据和所述干扰场景数据训练所述车端感知模型得到训练完成的车端感知模型。
在本步骤中,基于正常场景数据和干扰场景数据训练车端感知模型得到训练完成的车端感知模型。
作为一种优选的实施方式,所述基于所述正常场景数据和所述干扰场景数据训练所述车端感知模型得到训练完成的车端感知模型,包括:确定正常训练样本的第三采样率或干扰训练样本的第四采样率;其中,所述第三采样率与所述第四采样率的和为一;基于所述第三采样率或所述第四采样率在正常场景数据中采样正常训练样本、在干扰场景数据中采样干扰训练样本;将所述干扰训练样本输入所述车端感知模型中,基于所述任务网络的损失对所述车端感知模型的模型参数进行训练,得到训练完成的车端感知模型。
在具体实施中,每次训练时,采用前述的干扰数据采样策略从正常场景数据集和感知鲁棒性干扰数据库中采样一个批次的训练样本,如批次大小设定为32,即表示采样32个数据帧F。需要注意的是,此时干扰数据采集策略中的采样概率不必与云端感知模型训练阶段保持一致,可以灵活修改。将每个批次的训练样本分别输入云端感知模型和车端感知模型中,获得各自的多模态融合特征和融合的元知识特征,并使用各子任务网络计算获得的损失,反向传播优化车端感知模型。
作为一种优选的实施方式,所述基于所述任务网络的损失对所述车端感知模型的模型参数进行训练,得到训练完成的车端感知模型,包括:基于所述任务网络的损失和知识蒸馏损失对所述车端感知模型的模型参数进行训练,得到训练完成的车端感知模型;其中,所述知识蒸馏损失为基于所述训练完成的云端感知模型中的多模态融合网络输出的融合特征和所述车端感知模型中的多模态融合网络输出的融合特征计算得到的,或,所述知识蒸馏损失为基于所述训练完成的云端感知模型中的元知识融合网络输出的融合特征和所述车端感知模型中的元知识融合网络输出的融合特征计算得到的。
在具体实施中,计算车端感知模型鲁棒性蒸馏损失函数,使用蒸馏损失和各子任务网络计算获得的损失,反向传播优化车端感知模型。通过知识蒸馏提升车端感知模型的鲁棒性以及在正常场景下的感知准确率。如图10所示,本发明提供两种知识蒸馏模块,分别为多模态融合特征知识蒸馏模块、元知识蒸馏模块。多模态融合特征知识蒸馏模块使用云端感知模型的多模态融合网络的输出特征,对车端感知模型的多模态融合网络的输出特征进行知识蒸馏。
若使用多模态融合特征知识蒸馏模块,使用Γ表示云端感知模型的多模态融合网络输出的融合特征,其形状为(W,H,C),分别表示特征的宽、高、通道数量,使用τ表示车端感知模型的多模态融合网络输出的融合特征,其形状为(W,H,C),在多模态融合特征知识蒸馏模块中,使用知识蒸馏损失函数进行知识蒸馏,知识蒸馏损失的计算公式为:
其中,为所述知识蒸馏损失,/>为所述训练完成的云端感知模型中的多模态融合网络输出的融合特征第(i,j)位置的第k通道处的特征值,/>所述车端感知模型中的多模态融合网络输出的融合特征第(i,j)位置的第k通道处的特征值,1≤i≤W,1≤j≤H,1≤k≤C,W为融合特征的宽度,H为融合特征的高度,C为融合特征的通道数量。
在云端感知模型和车端感知模型的多模态融合网络输出的融合特征间进行知识蒸馏,提高车端感知模型多模态融合特征的质量,进而提高车端感知模型在正常场景下的感知准确性以及对干扰因素的鲁棒性。
对于元知识蒸馏模块,同样采用上述流程在云端感知模型的元知识融合网络输出的融合元知识特征和车端感知模型的元知识融合网络输出的融合元知识特征间进行知识蒸馏,知识蒸馏损失的计算公式为:
其中,为所述知识蒸馏损失,/>为所述训练完成的云端感知模型中的元知识融合网络输出的融合特征第(i,j)位置的第k通道处的特征值,/>所述车端感知模型中的元知识融合网络输出的融合特征第(i,j)位置的第k通道处的特征值,1≤i≤W,1≤j≤H,1≤k≤C,W为融合特征的宽度,H为融合特征的高度,C为融合特征的通道数量。
在云端感知模型和车端感知模型的元知识融合网络输出的融合的元知识特征间进行知识蒸馏,提高车端感知模型元知识特征的质量,进而提高车端感知模型在正常场景下的感知准确性以及对干扰因素的鲁棒性。
由此可见,本发明实施例构建了一种云端感知模型结构,包含基础元网络组和抗干扰元网络组,分别负责正常场景数据和干扰场景,通过显式的为正常场景和干扰场景分别设置元网络组,既能确保云端感知模型对正常场景的准确感知,又能提升云端感知模型对干扰因素的鲁棒性。针对干扰数据样本量少,抗干扰元网络容易过拟合的问题,本发明实施例构建了一种纺锤结构的抗干扰元网络,该纺锤结构的抗干扰元网络有着更强的表达能力和抗过拟合能力,可有效提升抗干扰元网络的鲁棒性。针对通用的训练策略会限制抗干扰元网络训练效果的问题,本发明实施例提出一种抗干扰训练方法,包含分支训练方法和基于梯度的优化训练方法,可显著提高云端感知模型的训练效果。最后,本发明实施例利用训练完成的云端感知模型对车端感知模型进行知识蒸馏,以获得具备高准确性、高鲁棒性的车端感知模型。
下面介绍本申请提供的一种应用实施例,如图11所示,包括以下步骤:
步骤1:采集正常场景数据和干扰数据,构建数据集;
步骤2:从数据集中选取正常场景样本和干扰样本,训练感知大模型的基础分支;
步骤3:利用训练完成的基础分支的基础元网络,构建抗干扰分支中的抗干扰元网络,从数据集中选取正常场景样本和干扰样本,使用抗干扰训练方法,训练感知大模型的抗干扰分支;
步骤4:合并训练完成的感知大模型基础分支和抗干扰分支,得到完整的感知大模型,从数据集中选取正常场景样本和干扰样本,训练完整的感知大模型;
步骤5:从数据集中选取正常场景样本和干扰样本,利用训练完成的完整的感知大模型,对车端感知模型进行知识蒸馏;
步骤6:将训练完成的车端感知模型部署至自动驾驶系统中。
下面对本申请实施例提供的一种感知模型训练装置进行介绍,下文描述的一种感知模型训练装置与上文描述的一种感知模型训练方法可以相互参照。
参见图12,根据一示例性实施例示出的一种感知模型训练装置的结构图,如图12所示,包括:
获取单元100,用于获取云端感知模型;其中,所述云端感知模型包括第一预设数量个基础元网络和第二预设数量个抗干扰元网络;
分解单元200,用于将所述云端感知模型分解为基础分支和抗干扰分支;其中,所述基础分支包括第一预设数量个基础元网络,所述抗干扰分支包括第二预设数量个抗干扰元网络;
第一训练单元300,用于基于正常场景数据训练所述基础分支,基于干扰场景数据训练所述抗干扰分支,合并训练完成的基础分支和训练完成的抗干扰分支得到中间云端感知模型,基于所述正常场景数据和所述干扰场景数据训练所述中间云端感知模型得到训练完成的云端感知模型。
本申请实施例提供的感知模型训练装置,针对云端感知模型,将其划分为基础分支和抗干扰分支,分别基于正常场景数据和干扰数据进行有效训练,提升云端感知模型对干扰数据的学习效果,提高了感知模型的鲁棒性。
在上述实施例的基础上,作为一种优选实施方式,所述云端感知模型包括依次连接的输入模块、基础网络、云端元网络、元知识融合网络和任务网络,所述云端元网络包括基础元网络组和抗干扰元网络组,所述基础元网络组包括第一预设数量个基础元网络,所述抗干扰元网络组包括第二预设数量个抗干扰元网络;
所述基础分支包括依次连接的所述输入模块、所述基础网络、所述基础元网络组、所述元知识融合网络和所述任务网络,所述抗干扰分支包括依次连接的所述输入模块、所述基础网络、所述抗干扰元网络组、所述元知识融合网络和所述任务网络。
在上述实施例的基础上,作为一种优选实施方式,所述输入模块包括多个单模态的输入单元,每个所述输入单元通过对应的基础网络连接多模态融合网络,所述多模态融合网络连接所述云端元网络。
在上述实施例的基础上,作为一种优选实施方式,所述基础网络包括依次连接的预处理模块、骨干网络和多尺度特征提取网络。
在上述实施例的基础上,作为一种优选实施方式,所述基础元网络包括第一基础子网络、第二基础子网络、第三基础子网络,所述第一基础子网络、所述第二基础子网络、所述第三基础子网络包括多层卷积神经网络或全连接层或注意力层。
在上述实施例的基础上,作为一种优选实施方式,所述抗干扰元网络包括第一抗干扰子网络、第二抗干扰子网络、第三抗干扰子网络,所述第一抗干扰子网络的模型结构与所述第一基础子网络的模型结构相同,所述第二抗干扰子网络包括所有所述第二基础子网络、每个所述第二基础子网络对应的特征学习模型和与所有所述特征学习模型连接的特征融合模块,所述第三抗干扰子网络包括多层卷积神经网络或全连接层或注意力层。
在上述实施例的基础上,作为一种优选实施方式,所述特征学习模型用于将对应的第二基础子网络输出的特征与对应的学习参数相乘得到对应的学习特征。
在上述实施例的基础上,作为一种优选实施方式,所述特征融合模块用于对所有所述学习特征采用均值融合方式进行聚合得到聚合特征。
在上述实施例的基础上,作为一种优选实施方式,所述第一训练单元300包括:
第一训练子单元,用于在正常场景数据中采样正常训练样本,将所述正常训练样本输入所述基础分支中,基于所述任务网络的损失对所述基础分支的模型参数进行训练,得到训练完成的基础分支。
在上述实施例的基础上,作为一种优选实施方式,所述第一训练单元300包括:
初始化子单元,用于基于训练完成的基础分支初始化所述抗干扰分支,并冻结初始化完成的抗干扰分支中每个抗干扰元网络中的第一抗干扰子网络的参数;
第二训练子单元,用于在干扰场景数据中采样干扰训练样本,将所述干扰训练样本输入所述初始化完成的抗干扰分支中,基于所述任务网络的损失对所述初始化完成的抗干扰分支的模型参数进行训练,得到训练完成的抗干扰分支。
在上述实施例的基础上,作为一种优选实施方式,所述初始化子单元具体用于:基于训练完成的基础分支中基础元网络中的第一基础子网络的参数初始化所述抗干扰分支中每个抗干扰元网络中的第一抗干扰子网络的参数;基于训练完成的基础分支中基础元网络中的第二基础子网络构建所述抗干扰分支中每个抗干扰元网络中的第二抗干扰子网络;冻结初始化完成的抗干扰分支中每个抗干扰元网络中的第一抗干扰子网络的参数。
在上述实施例的基础上,作为一种优选实施方式,所述初始化子单元具体用于:对训练完成的基础分支中基础元网络中的所有第一基础子网络的参数进行均值融合,基于融合后的参数初始化所述抗干扰分支中每个抗干扰元网络中的第一抗干扰子网络的参数;基于训练完成的基础分支中基础元网络中的第二基础子网络构建所述抗干扰分支中每个抗干扰元网络中的第二抗干扰子网络;冻结初始化完成的抗干扰分支中每个抗干扰元网络中的第一抗干扰子网络的参数。
在上述实施例的基础上,作为一种优选实施方式,所述任务网络包括多个子任务网络。
在上述实施例的基础上,作为一种优选实施方式,所述第二训练子单元具体用于:在干扰场景数据中采样干扰训练样本,将所述干扰训练样本输入所述初始化完成的抗干扰分支中,计算所述任务网络中多个所述子任务网络的损失和,基于所述损失和计算所述初始化完成的抗干扰分支中每个抗干扰元网络的输出特征的每个维度的梯度,对每个所述输出特征的每个维度的梯度进行绝对值求和得到每个所述输出特征的梯度和;对所有所述输出特征的梯度和由大至小进行排序,确定排序结果中前第五预设数量个输出特征对应的抗干扰元网络作为目标抗干扰元网络;冻结除所述目标抗干扰元网络之外的其他抗干扰元网络的参数,更新所述目标抗干扰元网络的参数,得到训练完成的抗干扰分支。
在上述实施例的基础上,作为一种优选实施方式,所述干扰训练样本包括干扰数据和对应的标注,所述标注包括任务标注和干扰因素类型。
在上述实施例的基础上,作为一种优选实施方式,所述第一训练单元300包括:
合并子单元,用于对训练完成的基础分支中的基础网络的参数和训练完成的抗干扰分支中的基础网络的参数进行加权合并,得到中间云端感知模型中基础网络的参数;对训练完成的基础分支中的元知识融合网络的参数和训练完成的抗干扰分支中的元知识融合网络的参数进行加权合并,得到中间云端感知模型中基础网络的参数;对训练完成的基础分支中的任务网络的参数和训练完成的抗干扰分支中的任务网络的参数进行加权合并,得到中间云端感知模型中任务网络的参数;将训练完成的基础分支中的基础元网络组和训练完成的抗干扰分支中的抗干扰元网络组合并为中间云端感知模型中的云端元网络。
在上述实施例的基础上,作为一种优选实施方式,所述第一训练单元300包括:
第三训练子单元,用于确定正常训练样本的第一采样率或干扰训练样本的第二采样率;其中,所述第一采样率与所述第二采样率的和为一;基于所述第一采样率或所述第二采样率在正常场景数据中采样正常训练样本、在干扰场景数据中采样干扰训练样本;将所述干扰训练样本输入所述中间云端感知模型中,基于所述任务网络的损失对所述中间云端感知模型的模型参数进行训练,得到训练完成的云端感知模型。
下面对本申请实施例提供的另一种感知模型训练装置进行介绍,下文描述的一种感知模型训练装置与上文描述的另一种感知模型训练方法可以相互参照。
参见图13,根据一示例性实施例示出的一种感知模型训练装置的结构图,如图13所示,包括:
获取单元100,用于获取云端感知模型;其中,所述云端感知模型包括第一预设数量个基础元网络和第二预设数量个抗干扰元网络;
分解单元200,用于将所述云端感知模型分解为基础分支和抗干扰分支;其中,所述基础分支包括第一预设数量个基础元网络,所述抗干扰分支包括第二预设数量个抗干扰元网络;
第一训练单元300,用于基于正常场景数据训练所述基础分支,基于干扰场景数据训练所述抗干扰分支,合并训练完成的基础分支和训练完成的抗干扰分支得到中间云端感知模型,基于所述正常场景数据和所述干扰场景数据训练所述中间云端感知模型得到训练完成的云端感知模型;
构建单元400,用于基于所述训练完成的云端感知模型构建车端感知模型;其中,所述车端感知模型包括第三预设数量个基础元网络和第四预设数量个抗干扰元网络,所述第三预设数量小于所述第一预设数量,所述第四预设数量小于所述第二预设数量;
第二训练单元500,用于基于所述正常场景数据和所述干扰场景数据训练所述车端感知模型得到训练完成的车端感知模型。
本申请实施例提供的感知模型训练装置,针对云端感知模型,将其划分为基础分支和抗干扰分支,分别基于正常场景数据和干扰数据进行有效训练,提升云端感知模型对干扰数据的学习效果,进一步提升了车端感知模型对干扰数据的学习效果,提高了车端感知模型的鲁棒性。
在上述实施例的基础上,作为一种优选实施方式,所述车端感知模型包括依次连接的输入模块、基础网络、车端元网络、元知识融合网络和所述任务网络,车端元网络包括第三预设数量个基础元网络和第四预设数量个抗干扰元网络。
在上述实施例的基础上,作为一种优选实施方式,所述构建单元400具体用于:将所述训练完成的云端感知模型中云端元网络中的所有基础元网络划分为所述第三预设数量个类别;对每个类别中基础元网络的参数进行均值融合得到车端感知模型中对应的基础元网络的参数;将所述训练完成的云端感知模型中云端元网络中的所有抗干扰元网络划分为所述第三预设数量个类别;对每个类别中抗干扰元网络的参数进行均值融合得到车端感知模型中对应的抗干扰元网络的参数。
在上述实施例的基础上,作为一种优选实施方式,所述第二训练单元500包括:
确定子单元,用于确定正常训练样本的第三采样率或干扰训练样本的第四采样率;其中,所述第三采样率与所述第四采样率的和为一;
采样子单元,用于基于所述第三采样率或所述第四采样率在正常场景数据中采样正常训练样本、在干扰场景数据中采样干扰训练样本;
第四训练子单元,用于将所述干扰训练样本输入所述车端感知模型中,基于所述任务网络的损失对所述车端感知模型的模型参数进行训练,得到训练完成的车端感知模型。
在上述实施例的基础上,作为一种优选实施方式,所述第四训练子单元具体用于:将所述干扰训练样本输入所述车端感知模型中,基于所述任务网络的损失和知识蒸馏损失对所述车端感知模型的模型参数进行训练,得到训练完成的车端感知模型;其中,所述知识蒸馏损失为基于所述训练完成的云端感知模型中的多模态融合网络输出的融合特征和所述车端感知模型中的多模态融合网络输出的融合特征计算得到的,或,所述知识蒸馏损失为基于所述训练完成的云端感知模型中的元知识融合网络输出的融合特征和所述车端感知模型中的元知识融合网络输出的融合特征计算得到的。
在上述实施例的基础上,作为一种优选实施方式,所述知识蒸馏损失的计算公式为:
其中,为所述知识蒸馏损失,/>为所述训练完成的云端感知模型中的多模态融合网络或元知识融合网络输出的融合特征第(i,j)位置的第k通道处的特征值,/>所述车端感知模型中的多模态融合网络或元知识融合网络输出的融合特征第(i,j)位置的第k通道处的特征值,1≤i≤W,1≤j≤H,1≤k≤C,W为融合特征的宽度,H为融合特征的高度,C为融合特征的通道数量。
关于上述实施例中的装置,其中各个模块执行操作的具体方式已经在有关该方法的实施例中进行了详细描述,此处将不做详细阐述说明。
基于上述程序模块的硬件实现,且为了实现本申请实施例的方法,本申请实施例还提供了一种电子设备,图14为根据一示例性实施例示出的一种电子设备的结构图,如图14所示,电子设备包括:
通信接口1,能够与其它设备比如网络设备等进行信息交互;
处理器2,与通信接口1连接,以实现与其它设备进行信息交互,用于运行计算机程序时,执行上述一个或多个技术方案提供的车端感知模型训练方法。而所述计算机程序存储在存储器3上。
当然,实际应用时,电子设备中的各个组件通过总线系统4耦合在一起。可理解,总线系统4用于实现这些组件之间的连接通信。总线系统4除包括数据总线之外,还包括电源总线、控制总线和状态信号总线。但是为了清楚说明起见,在图14中将各种总线都标为总线系统4。
本申请实施例中的存储器3用于存储各种类型的数据以支持电子设备的操作。这些数据的示例包括:用于在电子设备上操作的任何计算机程序。
可以理解,存储器3可以是易失性存储器或非易失性存储器,也可包括易失性和非易失性存储器两者。其中,非易失性存储器可以是只读存储器(ROM,Read Only Memory)、可编程只读存储器(PROM,Programmable Read-Only Memory)、可擦除可编程只读存储器(EPROM,Erasable Programmable Read-Only Memory)、电可擦除可编程只读存储器(EEPROM,Electrically Erasable Programmable Read-Only Memory)、磁性随机存取存储器(FRAM,ferromagnetic random access memory)、快闪存储器(Flash Memory)、磁表面存储器、光盘、或只读光盘(CD-ROM,Compact Disc Read-Only Memory);磁表面存储器可以是磁盘存储器或磁带存储器。易失性存储器可以是随机存取存储器(RAM,Random AccessMemory),其用作外部高速缓存。通过示例性但不是限制性说明,许多形式的RAM可用,例如静态随机存取存储器(SRAM,Static Random Access Memory)、同步静态随机存取存储器(SSRAM,Synchronous Static Random Access Memory)、动态随机存取存储器(DRAM,Dynamic Random Access Memory)、同步动态随机存取存储器(SDRAM,SynchronousDynamic Random Access Memory)、双倍数据速率同步动态随机存取存储器(DDRSDRAM,Double Data Rate Synchronous Dynamic Random Access Memory)、增强型同步动态随机存取存储器(ESDRAM,Enhanced Synchronous Dynamic Random Access Memory)、同步连接动态随机存取存储器(SLDRAM,SyncLink Dynamic Random Access Memory)、直接内存总线随机存取存储器(DRRAM,Direct Rambus Random Access Memory)。本申请实施例描述的存储器3旨在包括但不限于这些和任意其它适合类型的存储器。
上述本申请实施例揭示的方法可以应用于处理器2中,或者由处理器2实现。处理器2可能是一种集成电路芯片,具有信号的处理能力。在实现过程中,上述方法的各步骤可以通过处理器2中的硬件的集成逻辑电路或者软件形式的指令完成。上述的处理器2可以是通用处理器、DSP,或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件等。处理器2可以实现或者执行本申请实施例中的公开的各方法、步骤及逻辑框图。通用处理器可以是微处理器或者任何常规的处理器等。结合本申请实施例所公开的方法的步骤,可以直接体现为硬件译码处理器执行完成,或者用译码处理器中的硬件及软件模块组合执行完成。软件模块可以位于存储介质中,该存储介质位于存储器3,处理器2读取存储器3中的程序,结合其硬件完成前述方法的步骤。
处理器2执行所述程序时实现本申请实施例的各个方法中的相应流程,为了简洁,在此不再赘述。
在示例性实施例中,本申请实施例还提供了一种存储介质,即计算机存储介质,具体为计算机可读存储介质,例如包括存储计算机程序的存储器3,上述计算机程序可由处理器2执行,以完成前述方法所述步骤。计算机可读存储介质可以是FRAM、ROM、PROM、EPROM、EEPROM、Flash Memory、磁表面存储器、光盘、CD-ROM等存储器。
本领域普通技术人员可以理解:实现上述方法实施例的全部或部分步骤可以通过程序指令相关的硬件来完成,前述的程序可以存储于一计算机可读取存储介质中,该程序在执行时,执行包括上述方法实施例的步骤;而前述的存储介质包括:移动存储设备、ROM、RAM、磁碟或者光盘等各种可以存储程序代码的介质。
或者,本申请上述集成的单元如果以软件功能模块的形式实现并作为独立的产品销售或使用时,也可以存储在一个计算机可读取存储介质中。基于这样的理解,本申请实施例的技术方案本质上或者说对现有技术做出贡献的部分可以以软件产品的形式体现出来,该计算机软件产品存储在一个存储介质中,包括若干指令用以使得一台电子设备(可以是个人计算机、服务器、网络设备等)执行本申请各个实施例所述方法的全部或部分。而前述的存储介质包括:移动存储设备、ROM、RAM、磁碟或者光盘等各种可以存储程序代码的介质。
以上所述,仅为本申请的具体实施方式,但本申请的保护范围并不局限于此,任何熟悉本技术领域的技术人员在本申请揭露的技术范围内,可轻易想到变化或替换,都应涵盖在本申请的保护范围之内。因此,本申请的保护范围应以所述权利要求的保护范围为准。

Claims (27)

1.一种感知模型训练方法,其特征在于,应用于自动驾驶,包括:
获取云端感知模型;其中,所述云端感知模型包括第一预设数量个基础元网络和第二预设数量个抗干扰元网络;
将所述云端感知模型分解为基础分支和抗干扰分支;其中,所述基础分支包括第一预设数量个基础元网络,所述抗干扰分支包括第二预设数量个抗干扰元网络;
基于正常场景数据训练所述基础分支,基于干扰场景数据训练所述抗干扰分支,合并训练完成的基础分支和训练完成的抗干扰分支得到中间云端感知模型,基于所述正常场景数据和所述干扰场景数据训练所述中间云端感知模型得到训练完成的云端感知模型;其中,所述正常场景数据为摄像机采集的图像数据或激光雷达采集的点云数据,所述干扰场景数据为存在干扰因素的图像数据或点云数据。
2.根据权利要求1所述感知模型训练方法,其特征在于,所述云端感知模型包括依次连接的输入模块、基础网络、云端元网络、元知识融合网络和任务网络,所述云端元网络包括基础元网络组和抗干扰元网络组,所述基础元网络组包括第一预设数量个基础元网络,所述抗干扰元网络组包括第二预设数量个抗干扰元网络;
所述基础分支包括依次连接的所述输入模块、所述基础网络、所述基础元网络组、所述元知识融合网络和所述任务网络,所述抗干扰分支包括依次连接的所述输入模块、所述基础网络、所述抗干扰元网络组、所述元知识融合网络和所述任务网络。
3.根据权利要求2所述感知模型训练方法,其特征在于,所述输入模块包括多个单模态的输入单元,每个所述输入单元通过对应的基础网络连接多模态融合网络,所述多模态融合网络连接所述云端元网络。
4.根据权利要求2所述感知模型训练方法,其特征在于,所述基础网络包括依次连接的预处理模块、骨干网络和多尺度特征提取网络。
5.根据权利要求2所述感知模型训练方法,其特征在于,所述基础元网络包括第一基础子网络、第二基础子网络、第三基础子网络,所述第一基础子网络、所述第二基础子网络、所述第三基础子网络包括多层卷积神经网络或全连接层或注意力层。
6.根据权利要求5所述感知模型训练方法,其特征在于,所述抗干扰元网络包括第一抗干扰子网络、第二抗干扰子网络、第三抗干扰子网络,所述第一抗干扰子网络的模型结构与所述第一基础子网络的模型结构相同,所述第二抗干扰子网络包括所有所述第二基础子网络、每个所述第二基础子网络对应的特征学习模型和与所有所述特征学习模型连接的特征融合模块,所述第三抗干扰子网络包括多层卷积神经网络或全连接层或注意力层。
7.根据权利要求6所述感知模型训练方法,其特征在于,所述特征学习模型用于将对应的第二基础子网络输出的特征与对应的学习参数相乘得到对应的学习特征。
8.根据权利要求7所述感知模型训练方法,其特征在于,所述特征融合模块用于对所有所述学习特征采用均值融合方式进行聚合得到聚合特征。
9.根据权利要求6所述感知模型训练方法,其特征在于,所述基于正常场景数据训练所述基础分支,包括:
在正常场景数据中采样正常训练样本,将所述正常训练样本输入所述基础分支中,基于所述任务网络的损失对所述基础分支的模型参数进行训练,得到训练完成的基础分支。
10.根据权利要求9所述感知模型训练方法,其特征在于,所述基于干扰场景数据训练所述抗干扰分支,包括:
基于训练完成的基础分支初始化所述抗干扰分支,并冻结初始化完成的抗干扰分支中每个抗干扰元网络中的第一抗干扰子网络的参数;
在干扰场景数据中采样干扰训练样本,将所述干扰训练样本输入所述初始化完成的抗干扰分支中,基于所述任务网络的损失对所述初始化完成的抗干扰分支的模型参数进行训练,得到训练完成的抗干扰分支。
11.根据权利要求10所述感知模型训练方法,其特征在于,所述基于训练完成的基础分支初始化所述抗干扰分支,包括:
基于训练完成的基础分支中基础元网络中的第一基础子网络的参数初始化所述抗干扰分支中每个抗干扰元网络中的第一抗干扰子网络的参数;
基于训练完成的基础分支中基础元网络中的第二基础子网络构建所述抗干扰分支中每个抗干扰元网络中的第二抗干扰子网络。
12.根据权利要求11所述感知模型训练方法,其特征在于,所述基于训练完成的基础分支中基础元网络中的第一基础子网络的参数初始化所述抗干扰分支中每个抗干扰元网络中的第一抗干扰子网络的参数,包括:
对训练完成的基础分支中基础元网络中的所有第一基础子网络的参数进行均值融合,基于融合后的参数初始化所述抗干扰分支中每个抗干扰元网络中的第一抗干扰子网络的参数。
13.根据权利要求11所述感知模型训练方法,其特征在于,所述任务网络包括多个子任务网络。
14.根据权利要求13所述感知模型训练方法,其特征在于,所述基于所述任务网络的损失对所述初始化完成的抗干扰分支的模型参数进行训练,得到训练完成的抗干扰分支,包括:
计算所述任务网络中多个所述子任务网络的损失和,基于所述损失和计算所述初始化完成的抗干扰分支中每个抗干扰元网络的输出特征的每个维度的梯度,对每个所述输出特征的每个维度的梯度进行绝对值求和得到每个所述输出特征的梯度和;
对所有所述输出特征的梯度和由大至小进行排序,确定排序结果中前K个输出特征对应的抗干扰元网络作为目标抗干扰元网络;
冻结除所述目标抗干扰元网络之外的其他抗干扰元网络的参数,更新所述目标抗干扰元网络的参数,得到训练完成的抗干扰分支。
15.根据权利要求10所述感知模型训练方法,其特征在于,所述干扰训练样本包括干扰数据和对应的标注,所述标注包括任务标注和干扰因素类型。
16.根据权利要求1所述感知模型训练方法,其特征在于,所述合并训练完成的基础分支和训练完成的抗干扰分支得到中间云端感知模型,包括:
对训练完成的基础分支中的基础网络的参数和训练完成的抗干扰分支中的基础网络的参数进行加权合并,得到中间云端感知模型中基础网络的参数;
对训练完成的基础分支中的元知识融合网络的参数和训练完成的抗干扰分支中的元知识融合网络的参数进行加权合并,得到中间云端感知模型中基础网络的参数;
对训练完成的基础分支中的任务网络的参数和训练完成的抗干扰分支中的任务网络的参数进行加权合并,得到中间云端感知模型中任务网络的参数;
将训练完成的基础分支中的基础元网络组和训练完成的抗干扰分支中的抗干扰元网络组合并为中间云端感知模型中的云端元网络。
17.根据权利要求2所述感知模型训练方法,其特征在于,所述基于所述正常场景数据和所述干扰场景数据训练所述中间云端感知模型得到训练完成的云端感知模型,包括:
确定正常训练样本的第一采样率或干扰训练样本的第二采样率;其中,所述第一采样率与所述第二采样率的和为一;
基于所述第一采样率或所述第二采样率在正常场景数据中采样正常训练样本、在干扰场景数据中采样干扰训练样本;
将所述干扰训练样本输入所述中间云端感知模型中,基于所述任务网络的损失对所述中间云端感知模型的模型参数进行训练,得到训练完成的云端感知模型。
18.一种感知模型训练方法,其特征在于,应用于自动驾驶,包括:
获取云端感知模型;其中,所述云端感知模型包括第一预设数量个基础元网络和第二预设数量个抗干扰元网络;
将所述云端感知模型分解为基础分支和抗干扰分支;其中,所述基础分支包括第一预设数量个基础元网络,所述抗干扰分支包括第二预设数量个抗干扰元网络;
基于正常场景数据训练所述基础分支,基于干扰场景数据训练所述抗干扰分支,合并训练完成的基础分支和训练完成的抗干扰分支得到中间云端感知模型,基于所述正常场景数据和所述干扰场景数据训练所述中间云端感知模型得到训练完成的云端感知模型;其中,所述正常场景数据为摄像机采集的图像数据或激光雷达采集的点云数据,所述干扰场景数据为存在干扰因素的图像数据或点云数据;
基于所述训练完成的云端感知模型构建车端感知模型;其中,所述车端感知模型包括第三预设数量个基础元网络和第四预设数量个抗干扰元网络,所述第三预设数量小于所述第一预设数量,所述第四预设数量小于所述第二预设数量;
基于所述正常场景数据和所述干扰场景数据训练所述车端感知模型得到训练完成的车端感知模型。
19.根据权利要求18所述感知模型训练方法,其特征在于,所述车端感知模型包括依次连接的输入模块、基础网络、车端元网络、元知识融合网络和任务网络,车端元网络包括第三预设数量个基础元网络和第四预设数量个抗干扰元网络。
20.根据权利要求18所述感知模型训练方法,其特征在于,所述基于所述训练完成的云端感知模型构建车端感知模型,包括:
将所述训练完成的云端感知模型中云端元网络中的所有基础元网络划分为所述第三预设数量个类别;
对每个类别中基础元网络的参数进行均值融合得到车端感知模型中对应的基础元网络的参数;
将所述训练完成的云端感知模型中云端元网络中的所有抗干扰元网络划分为所述第四预设数量个类别;
对每个类别中抗干扰元网络的参数进行均值融合得到车端感知模型中对应的抗干扰元网络的参数。
21.根据权利要求19所述感知模型训练方法,其特征在于,所述基于所述正常场景数据和所述干扰场景数据训练所述车端感知模型得到训练完成的车端感知模型,包括:
确定正常训练样本的第三采样率或干扰训练样本的第四采样率;其中,所述第三采样率与所述第四采样率的和为一;
基于所述第三采样率或所述第四采样率在正常场景数据中采样正常训练样本、在干扰场景数据中采样干扰训练样本;
将所述干扰训练样本输入所述车端感知模型中,基于所述任务网络的损失对所述车端感知模型的模型参数进行训练,得到训练完成的车端感知模型。
22.根据权利要求21所述感知模型训练方法,其特征在于,所述基于所述任务网络的损失对所述车端感知模型的模型参数进行训练,得到训练完成的车端感知模型,包括:
基于所述任务网络的损失和知识蒸馏损失对所述车端感知模型的模型参数进行训练,得到训练完成的车端感知模型;
其中,所述知识蒸馏损失为基于所述训练完成的云端感知模型中的多模态融合网络输出的融合特征和所述车端感知模型中的多模态融合网络输出的融合特征计算得到的,或,所述知识蒸馏损失为基于所述训练完成的云端感知模型中的元知识融合网络输出的融合特征和所述车端感知模型中的元知识融合网络输出的融合特征计算得到的。
23.根据权利要求22所述感知模型训练方法,其特征在于,所述知识蒸馏损失的计算公式为:
其中,为所述知识蒸馏损失,/>为所述训练完成的云端感知模型中的多模态融合网络或元知识融合网络输出的融合特征第(i,j)位置的第k通道处的特征值,/>所述车端感知模型中的多模态融合网络或元知识融合网络输出的融合特征第(i,j)位置的第k通道处的特征值,1≤i≤W,1≤j≤H,1≤k≤C,W为融合特征的宽度,H为融合特征的高度,C为融合特征的通道数量。
24.一种感知模型训练装置,其特征在于,应用于自动驾驶,包括:
获取单元,用于获取云端感知模型;其中,所述云端感知模型包括第一预设数量个基础元网络和第二预设数量个抗干扰元网络;
分解单元,用于将所述云端感知模型分解为基础分支和抗干扰分支;其中,所述基础分支包括第一预设数量个基础元网络,所述抗干扰分支包括第二预设数量个抗干扰元网络;
第一训练单元,用于基于正常场景数据训练所述基础分支,基于干扰场景数据训练所述抗干扰分支,合并训练完成的基础分支和训练完成的抗干扰分支得到中间云端感知模型,基于所述正常场景数据和所述干扰场景数据训练所述中间云端感知模型得到训练完成的云端感知模型;其中,所述正常场景数据为摄像机采集的图像数据或激光雷达采集的点云数据,所述干扰场景数据为存在干扰因素的图像数据或点云数据。
25.一种感知模型训练装置,其特征在于,应用于自动驾驶,包括:
获取单元,用于获取云端感知模型;其中,所述云端感知模型包括第一预设数量个基础元网络和第二预设数量个抗干扰元网络;
分解单元,用于将所述云端感知模型分解为基础分支和抗干扰分支;其中,所述基础分支包括第一预设数量个基础元网络,所述抗干扰分支包括第二预设数量个抗干扰元网络;
第一训练单元,用于基于正常场景数据训练所述基础分支,基于干扰场景数据训练所述抗干扰分支,合并训练完成的基础分支和训练完成的抗干扰分支得到中间云端感知模型,基于所述正常场景数据和所述干扰场景数据训练所述中间云端感知模型得到训练完成的云端感知模型;其中,所述正常场景数据为摄像机采集的图像数据或激光雷达采集的点云数据,所述干扰场景数据为存在干扰因素的图像数据或点云数据;
构建单元,用于基于所述训练完成的云端感知模型构建车端感知模型;其中,所述车端感知模型包括第三预设数量个基础元网络和第四预设数量个抗干扰元网络,所述第三预设数量小于所述第一预设数量,所述第四预设数量小于所述第二预设数量;
第二训练单元,用于基于所述正常场景数据和所述干扰场景数据训练所述车端感知模型得到训练完成的车端感知模型。
26.一种电子设备,其特征在于,包括:
存储器,用于存储计算机程序;
处理器,用于执行所述计算机程序时实现如权利要求1至23任一项所述感知模型训练方法的步骤。
27.一种计算机可读存储介质,其特征在于,所述计算机可读存储介质上存储有计算机程序,所述计算机程序被处理器执行时实现如权利要求1至23任一项所述感知模型训练方法的步骤。
CN202311128070.0A 2023-09-04 2023-09-04 一种感知模型训练方法、装置及电子设备和存储介质 Active CN116861262B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202311128070.0A CN116861262B (zh) 2023-09-04 2023-09-04 一种感知模型训练方法、装置及电子设备和存储介质

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202311128070.0A CN116861262B (zh) 2023-09-04 2023-09-04 一种感知模型训练方法、装置及电子设备和存储介质

Publications (2)

Publication Number Publication Date
CN116861262A CN116861262A (zh) 2023-10-10
CN116861262B true CN116861262B (zh) 2024-01-19

Family

ID=88219384

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202311128070.0A Active CN116861262B (zh) 2023-09-04 2023-09-04 一种感知模型训练方法、装置及电子设备和存储介质

Country Status (1)

Country Link
CN (1) CN116861262B (zh)

Families Citing this family (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN117097797B (zh) * 2023-10-19 2024-02-09 浪潮电子信息产业股份有限公司 云边端协同方法、装置、系统、电子设备及可读存储介质

Citations (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN111785085A (zh) * 2020-06-11 2020-10-16 北京航空航天大学 视觉感知以及感知网络训练方法、装置、设备和存储介质
CN115131679A (zh) * 2022-07-01 2022-09-30 合众新能源汽车有限公司 检测方法、装置及计算机存储介质
CN115393684A (zh) * 2022-10-27 2022-11-25 松立控股集团股份有限公司 一种基于自动驾驶场景多模态融合的抗干扰目标检测方法
CN115879535A (zh) * 2023-02-10 2023-03-31 北京百度网讯科技有限公司 一种自动驾驶感知模型的训练方法、装置、设备和介质
CN115907009A (zh) * 2023-02-10 2023-04-04 北京百度网讯科技有限公司 一种自动驾驶感知模型的迁移方法、装置、设备和介质

Patent Citations (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN111785085A (zh) * 2020-06-11 2020-10-16 北京航空航天大学 视觉感知以及感知网络训练方法、装置、设备和存储介质
CN115131679A (zh) * 2022-07-01 2022-09-30 合众新能源汽车有限公司 检测方法、装置及计算机存储介质
CN115393684A (zh) * 2022-10-27 2022-11-25 松立控股集团股份有限公司 一种基于自动驾驶场景多模态融合的抗干扰目标检测方法
CN115879535A (zh) * 2023-02-10 2023-03-31 北京百度网讯科技有限公司 一种自动驾驶感知模型的训练方法、装置、设备和介质
CN115907009A (zh) * 2023-02-10 2023-04-04 北京百度网讯科技有限公司 一种自动驾驶感知模型的迁移方法、装置、设备和介质

Also Published As

Publication number Publication date
CN116861262A (zh) 2023-10-10

Similar Documents

Publication Publication Date Title
CN113780296B (zh) 基于多尺度信息融合的遥感图像语义分割方法及系统
CN110929577A (zh) 一种基于YOLOv3的轻量级框架改进的目标识别方法
Dong et al. A hybrid spatial–temporal deep learning architecture for lane detection
CN111047078B (zh) 交通特征预测方法、系统及存储介质
CN116861262B (zh) 一种感知模型训练方法、装置及电子设备和存储介质
CN112784954A (zh) 确定神经网络的方法和装置
CN112446888A (zh) 图像分割模型的处理方法和处理装置
CN115147598A (zh) 目标检测分割方法、装置、智能终端及存储介质
Li et al. Gated auxiliary edge detection task for road extraction with weight-balanced loss
JP2024513596A (ja) 画像処理方法および装置、ならびにコンピュータ可読ストレージ媒体
CN114445461A (zh) 基于非配对数据的可见光红外目标跟踪训练方法及装置
Sun et al. Road crack detection network under noise based on feature pyramid structure with feature enhancement (road crack detection under noise)
CN117217280A (zh) 神经网络模型优化方法、装置及计算设备
Liang et al. Car detection and classification using cascade model
CN116432736A (zh) 神经网络模型优化方法、装置及计算设备
Dong et al. Refinement Co‐supervision network for real‐time semantic segmentation
Kozlov et al. Development of real-time ADAS object detector for deployment on CPU
CN111160282B (zh) 一种基于二值化Yolov3网络的红绿灯检测方法
CN113119996B (zh) 一种轨迹预测方法、装置、电子设备及存储介质
CN112001211B (zh) 对象检测方法、装置、设备及计算机可读存储介质
CN116821699B (zh) 一种感知模型训练方法、装置及电子设备和存储介质
Wang Remote sensing image semantic segmentation network based on ENet
CN114444597B (zh) 基于渐进式融合网络的视觉跟踪方法及装置
CN113362372B (zh) 一种单目标追踪方法及计算机可读介质
CN113031600B (zh) 一种轨迹生成方法、装置、存储介质及电子设备

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant