CN115358413A - 一种点云多任务模型的训练方法、装置及电子设备 - Google Patents

一种点云多任务模型的训练方法、装置及电子设备 Download PDF

Info

Publication number
CN115358413A
CN115358413A CN202211115837.1A CN202211115837A CN115358413A CN 115358413 A CN115358413 A CN 115358413A CN 202211115837 A CN202211115837 A CN 202211115837A CN 115358413 A CN115358413 A CN 115358413A
Authority
CN
China
Prior art keywords
task
gradient
point cloud
branch
model
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202211115837.1A
Other languages
English (en)
Inventor
李骏
张新钰
王力
黄毅
谢涛
杨淋淇
吴新刚
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Tsinghua University
Original Assignee
Tsinghua University
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Tsinghua University filed Critical Tsinghua University
Priority to CN202211115837.1A priority Critical patent/CN115358413A/zh
Publication of CN115358413A publication Critical patent/CN115358413A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N20/00Machine learning

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Software Systems (AREA)
  • Data Mining & Analysis (AREA)
  • Evolutionary Computation (AREA)
  • Medical Informatics (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Physics & Mathematics (AREA)
  • Computing Systems (AREA)
  • General Engineering & Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Mathematical Physics (AREA)
  • Artificial Intelligence (AREA)
  • Management, Administration, Business Operations System, And Electronic Commerce (AREA)

Abstract

本申请提供了一种点云多任务模型的训练方法、装置及电子设备,涉及智能驾驶技术领域,该方法包括:利用点云多任务模型对每个多任务训练样本组合进行处理,得到每个任务分支的损失函数,分别计算每个任务分支的主干网络参数的梯度;对每个任务分支的主干网络参数的梯度进行更新以消除梯度冲突,得到每个任务分支的主干网络参数的最终梯度及当前多任务模型的主干网络参数的梯度;利用当前多任务模型的主干网络参数的梯度更新主干网络参数;基于更新后的主干网络参数和所述多个多任务训练样本组合,继续进行主干网络的参数更新过程,直至达到预设的迭代结束条件。本申请训练出的不同任务分支间的共享参数,能够减少各任务之间的干扰。

Description

一种点云多任务模型的训练方法、装置及电子设备
技术领域
本申请涉及智能驾驶技术领域,尤其是涉及一种点云多任务模型的训练方法、装置及电子设备。
背景技术
在目前的多任务学习方法,多为对每个任务单独设计一个深度卷积网络结构,输入图片,输出对应标签或关键点位置信息。这种方法具有以下问题:每个任务设计一个独立的深度卷积网络,网络间没有共享的参数,总参数量和计算量大,模型推理耗时长。解决这类学习问题的一种方法是联合训练一个针对所有任务的网络,目的是发现跨任务的共享结构,其效率和性能优于单独解决任务。
将多任务学习设计为能够从多任务监督信号中学习共享表达的网络,与每个单独的任务都有自己的单独的网络相比,多任务网络具有以下优势:
首先,由于它们固有的层共享,因此产生的内存占用大大减少。其次,因为它们明确避免重复计算共享层中的特征,所以拥有更快的推理速度。最重要的是,如果关联的任务共享互补信息,或相互充当正则化器,那么它们就有提高性能的潜力。
然而,一次性学习多个任务会产生优化问题,有时会导致整体性能和数据效率低于单独学习任务。
发明内容
有鉴于此,本申请提供了一种点云多任务模型的训练方法、装置及电子设备,以解决上述技术问题。
第一方面,本申请实施例提供了一种点云多任务模型的训练方法,所述点云多任务模型包括一个主干网络和多个任务处理模型,所述主干网络和每个任务处理模型连接构成多个任务分支;所述方法包括:
获取多个多任务训练样本组合,每个多任务训练样本组合包括多个标注不同任务结果的点云数据样本;
利用点云多任务模型对每个多任务训练样本组合进行处理,得到每个任务分支的损失函数,分别计算每个任务分支的主干网络参数的梯度;
对于每个任务分支上的主干网络参数的梯度,判断其与其它任务分支的主干网络参数的梯度是否存在冲突,若存在则对其它任务分支的主干网络参数的梯度进行更新,得到每个任务分支的主干网络参数的最终梯度;
计算每个任务分支的主干网络参数的最终梯度的和,作为当前多任务模型的主干网络参数的梯度;利用当前多任务模型的主干网络参数的梯度更新主干网络参数;
基于更新后的主干网络参数和所述多个多任务训练样本组合,继续进行主干网络的参数更新过程,直至达到预设的迭代结束条件,将得到的主干网络参数作为训练好的点云多任务模型的模型参数。
进一步,所述点云多任务模型包括N个任务分支,获取多个多任务训练样本组合;包括:
获取N个任务训练数据集合,每个任务训练数据集合包括多个标注一个任务结果的点云数据样本;
分别从各任务训练数据集合中抽取一个点云数据样本,将N个点云数据样本进行组合,得到多任务训练样本组合。
进一步,利用点云多任务模型对每个多任务训练样本组合进行处理,得到每个任务分支的损失函数,分别计算每个任务分支的主干网络参数的梯度;包括:
将每个多任务训练样本组合中的点云数据样本输入对应的任务分支,得到预测结果;
根据预测结果和点云数据样本的标注结果计算损失函数Li(θ),i为任务分支的编号,1≤i≤N;θ为主干网络参数;
根据损失函数Li(θ),计算第i个任务分支的主干网络参数θ的梯度gi
Figure BDA0003845508640000031
其中,
Figure BDA0003845508640000032
为对Li(θ)中的参数θ的梯度运算。
进一步,对于每个任务分支上的主干网络参数的梯度,判断其与其它任务分支的主干网络参数的梯度是否存在冲突,若存在则将其它任务分支的主干网络参数的梯度进行更新,得到每个任务分支的主干网络参数的最终梯度;包括:
从i=1开始,执行下述步骤,直至i=N-1:
对于第i个任务分支的主干网络参数θ的梯度gi,计算其与第j个任务分支的主干网络参数θ的梯度gj的夹角为φij,其中,i+1≤j≤N;
判断cosφij<0是否成立,若成立,则梯度gi和梯度gj存在梯度冲突,否则,不存在梯度冲突;
当梯度gi和梯度gj存在梯度冲突,利用下式得到更新后的梯度
Figure BDA0003845508640000033
Figure BDA0003845508640000034
利用
Figure BDA0003845508640000035
更新gj
由此得到第i个任务分支的主干网络参数θ的最终梯度。
第二方面,本申请实施例提供了一种点云多任务模型的训练装置,所述点云多任务模型包括一个主干网络和多个任务处理模型,所述主干网络和每个任务处理模型连接构成任务分支;所述装置包括:
获取单元,用于获取多个多任务训练样本组合,每个多任务训练样本组合包括多个标注不同任务结果的点云数据样本;
计算单元,用于利用点云多任务模型对每个多任务训练样本组合进行处理,得到每个任务分支的损失函数,分别计算每个任务分支的主干网络参数的梯度;
梯度冲突消除单元,用于对于每个任务分支上的主干网络参数的梯度,判断其与其它任务分支的主干网络参数的梯度是否存在冲突,若存在则对其它任务分支的主干网络参数的梯度进行更新,得到每个任务分支的主干网络参数的最终梯度;
主干网络参数更新单元,用于计算每个任务分支的主干网络参数的最终梯度的和,作为当前多任务模型的主干网络参数的梯度;利用当前多任务模型的主干网络参数的梯度更新主干网络参数;
迭代单元,用于基于更新后的主干网络参数和所述多个多任务训练样本组合,继续进行主干网络的参数更新过程,直至达到预设的迭代结束条件,将得到的主干网络参数作为训练好的点云多任务模型的模型参数。
第三方面,本申请实施例提供了一种电子设备,包括:存储器、处理器和存储在所述存储器上并可在所述处理器上运行的计算机程序,所述处理器执行所述计算机程序时实现本申请实施例的点云多任务模型的训练方法。
第四方面,本申请实施例提供了一种计算机可读存储介质,所述计算机可读存储介质存储有计算机指令,所述计算机指令被处理器执行时实现本申请实施例的点云多任务模型的训练方法。
本申请训练出的不同任务分支间的共享参数,能够减少各任务之间的干扰。
附图说明
为了更清楚地说明本申请具体实施方式或现有技术中的技术方案,下面将对具体实施方式或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图是本申请的一些实施方式,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1为本申请实施例提供的点云多任务模型的训练方法的流程图;
图2为本申请实施例提供的梯度消除的示意图;
图3为本申请实施例提供的点云多任务模型的训练装置的功能结构图;
图4为本申请实施例提供的电子设备的功能结构图。
具体实施方式
为使本申请实施例的目的、技术方案和优点更加清楚,下面将结合本申请实施例中的附图,对本申请实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例是本申请一部分实施例,而不是全部的实施例。通常在此处附图中描述和示出的本申请实施例的组件可以以各种不同的配置来布置和设计。
因此,以下对在附图中提供的本申请的实施例的详细描述并非旨在限制要求保护的本申请的范围,而是仅仅表示本申请的选定实施例。基于本申请中的实施例,本领域普通技术人员在没有作出创造性劳动前提下所获得的所有其他实施例,都属于本申请保护的范围。
首先对本申请实施例的设计思想进行简单介绍。
点云学习因其在计算机视觉、自动驾驶、机器人等领域的广泛应用而受到越来越多的关注。深度学习作为人工智能的主流技术,已经成功地应用于解决各种二维视觉问题。然而,由于用深度神经网络处理点云所面临的独特挑战,点云的深度学习仍处于起步阶段。
为了方便分析,相关学者将使用纯激光雷达点云的3D检测分为基于点素和基于体素两个分支。基于点素的方式采用原始的点云数据坐标作为特征载体,直接利用激光雷达点云进行处理。基于体素的方式将点云数据转化成规则数据,利用卷积实现任务,换而言之,该方式将体素中心作为CNN感知特征载体,但相对原始点云对图像的坐标索引来说,体素中心与原始图像的索引存在偏差。
无论是哪种方法进行三维目标检测,本质上都利用深度神经网络处理点云信息。传统上,神经网络对于所给的任务是单独处理的,即为每个任务训练一个单独的神经网络。然而,许多现实世界的问题本质上是多模态的。例如,一辆自动驾驶汽车应该能够检测场景中的所有物体,定位它们,了解它们是什么,估计它们的距离和轨迹等,以便在它的周围安全导航。
上述观察结果促使研究人员开发了多任务学习模型,即给定一个输入图像可以推断出所有所需的任务输出。从自然语言处理和语音识别到计算机视觉,多任务学习已经成功地应用于深度学习几乎所有领域。多任务学习的形式有很多种,例如联合学习、自主学习、借助辅助任务学习等等。这些都只是被用来指代这种学习形式的一些名称,一般来说,一旦出现优化了不止一个损失函数的情况,都可以定义为在有效地进行多任务学习(相较于单任务学习)。即使有时候只是优化一个损失函数,有可能存在一个辅助任务能有助于改进主要任务,这种现象简明扼要地总结了多任务学习的目标,即多任务学习通过利用相关任务的训练信号中包含的领域特定信息来提高泛化能力。
本申请中的点云多任务模型包括一个主干网络backbone和多个并联的header,其中backbone用于提取点云特征,每个header对应一个处理任务,Backbone输出的点云特征是各header共用的。
在点云多任务模型的训练中,是单独对每个任务的backbone和header进行训练,训练完成一个任务的backbone和header后,当训练下一个任务的backbone和header,backbone的参数会去适应新的任务而发生变化,由于主干网络backbone的参数是各任务共享的,因此主干网络backbone的参数会在各任务之间产生冲突。
为解决上述技术问题,本申请提出了一种点云多任务模型训练的梯度更新策略,能够调整不同任务间的共享参数的梯度,以尽量减少各任务之间的干扰。具体思路如下:
步骤1、在给定任务批次B中选定一个任务Ti∈B,再以随机顺序从B中选定一个不同于任务Ti的任务Tj∈B\Ti,任务梯度表示如下:
Figure BDA0003845508640000071
Figure BDA0003845508640000072
步骤2、定义两个任务梯度gi和gj之间的夹角为φij,以余弦相似度来衡量两个任务是否存在梯度冲突,若cosφij<0则存在梯度冲突,反之则无梯度冲突。
步骤3、对于gi和gj存在梯度冲突的情况,需要利用梯度更新规则进行梯度更新,即将gi在gj向量的法平面上的投影来替代原来的gi,更新公式如下:
Figure BDA0003845508640000081
步骤4、对当前批次中随机选取的所有其他任务
Figure BDA0003845508640000082
重复步骤2和步骤3的过程得到任务Ti的最终梯度
Figure BDA0003845508640000083
步骤5、对当前批次中的所有任务执行步骤2,步骤3和步骤4以获得它们各自的梯度,将所有梯度求和实现共享参数θ的优化,表达如下:
Figure BDA0003845508640000084
接下来对上述方法的效果进行详细理论证明:
步骤6、考虑两个任务的损失函数L1:Rn→R和L2:Rn→R,定义一个两任务学习过程,总任务损失函数为L(θ)=L1(θ)+L2(θ),θ∈Rn为模型共享参数。假设L1和L2是凸可微的且L>0并满足利普希茨连续条件,那么采用步长
Figure BDA0003845508640000085
的梯度更新规则要么收敛于“优化地形”中cosφ12=-1的一个位置,要么收敛于最优值L(θ*)。
步骤7、用‖·‖2来表示L2范数,并令
Figure BDA0003845508640000086
根据步骤1和步骤2,设
Figure BDA0003845508640000087
φ12为g1和g2之间的夹角。在每次更新时会有cosφ12<0和cosφ12≥0两种情况。
步骤8、如果cosφ12≥0,就用步长
Figure BDA0003845508640000088
的标准梯度下降更新,目标函数值L(θ)会严格下降(因为它也是凸的),直到θ=θ*,
Figure BDA0003845508640000089
时,即到达最优解。
步骤9、对于cosφ12<0的情况,假设
Figure BDA00038455086400000810
是利普希茨连续且L为常数,这意味着
Figure BDA00038455086400000811
是一个半负定矩阵。根据这个推论,可以围绕L(θ)对L进行二次展开,得到如下不等式:
Figure BDA0003845508640000091
现在我们可以引入梯度更新规则,即
Figure BDA0003845508640000092
带进不等式(5)中可得到
Figure BDA0003845508640000093
将不等式(7)根据恒等式g=g1+g2展开,有
Figure BDA0003845508640000094
进一步整理可得,
Figure BDA0003845508640000095
带入恒等式
Figure BDA0003845508640000096
Figure BDA0003845508640000097
步骤10、由于cosφ12<0,所以最后一项非负,因为步长设定为
Figure BDA0003845508640000098
可以知道
Figure BDA0003845508640000099
且Lt2≤t。将上述结论带入不等式(10)中可得,
Figure BDA0003845508640000101
如果cosφ12>-1,那么
Figure BDA0003845508640000102
总是正的(除非g=0),不等式(11)表明目标函数随着每次cosφ12>-1的迭代严格递减。因此,反复进行梯度更新过程可以达到最优值L(θ)=L(θ*)或者cosφ12=-1,分别对应最优解和次优解的情况,需要注意的是,此结论只在步长t设置得非常小时成立即
Figure BDA0003845508640000103
在介绍了本申请实施例的应用场景和设计思想之后,下面对本申请实施例提供的技术方案进行说明。
如图1所示,本申请实施提供一种点云多任务模型的训练方法,该方法包括如下步骤:
步骤101:获取多个多任务训练样本组合,每个多任务训练样本组合包括多个标注不同任务结果的点云数据样本;
具体的,首先获取N个任务训练数据集合,每个任务训练数据集合包括多个标注一个任务结果的点云数据样本;然后分别从各任务训练数据集合中抽取一个点云数据样本,将N个点云数据样本进行组合,得到一个多任务训练样本组合;最后将所有的多任务训练样本组合构成多任务训练数据集。
其中,任务包括:三维目标检测、三维点云分割、行人轨迹预测和室外场景理解等。
步骤102:利用点云多任务模型对每个多任务训练样本组合进行处理,得到每个任务分支的损失函数,分别计算每个任务分支的主干网络参数的梯度;
在本实施例中,该步骤具体包括:
将每个多任务训练样本组合中的点云数据样本输入对应的任务分支,得到预测结果;
根据预测结果和点云数据样本的标注结果计算损失函数Li(θ),i为任务分支的编号,1≤i≤N;θ为主干网络参数;
根据损失函数Li(θ),计算第i个任务分支的主干网络参数θ的梯度gi
Figure BDA0003845508640000111
其中,
Figure BDA0003845508640000112
为对Li(θ)中的参数θ的梯度运算。
步骤103:对于每个任务分支上的主干网络参数的梯度,判断其与其它任务分支的主干网络参数的梯度是否存在冲突,若存在则对其它任务分支的主干网络参数的梯度进行更新,得到每个任务分支的主干网络参数的最终梯度;
在本实施例中,该步骤具体包括:
从i=1开始,执行下述步骤,直至i=N-1:
对于第i个任务分支的主干网络参数θ的梯度gi,计算其与第j个任务分支的主干网络参数θ的梯度gj的夹角为φij,其中,i+1≤j≤N;
判断cosφij<0是否成立,若成立,则梯度gi和梯度gj存在梯度冲突,否则,不存在梯度冲突;
当梯度gi和梯度gj存在梯度冲突,利用下式得到更新后的梯度
Figure BDA0003845508640000113
Figure BDA0003845508640000121
利用
Figure BDA0003845508640000122
更新gj
由此得到第i个任务分支的主干网络参数θ的最终梯度
Figure BDA0003845508640000123
步骤104:计算每个任务分支的主干网络参数的最终梯度的和,作为当前多任务模型的主干网络参数的梯度;利用当前多任务模型的主干网络参数的梯度更新主干网络参数;
当前多任务模型的主干网络参数θ的梯度的Δθ为:
Figure BDA0003845508640000124
步骤105:基于更新后的主干网络参数和所述多个多任务训练样本组合,继续进行更新过程,直至达到预设的迭代结束条件,将得到的主干网络参数作为训练好的点云多任务模型的模型参数。
其中,迭代结束条件为:主干网络参数的梯度的迭代次数达到预设次数,或者,主干网络参数的梯度小于预设的阈值。
基于上述实施例,本申请实施例提供了一种点云多任务模型的训练装置,参阅图3所示,本申请实施例提供的一种点云多任务模型的训练装置200至少包括:
获取单元201,用于获取多个多任务训练样本组合,每个多任务训练样本组合包括多个标注不同任务结果的点云数据样本;
计算单元202,用于利用点云多任务模型对每个多任务训练样本组合进行处理,得到每个任务分支的损失函数,分别计算每个任务分支的主干网络参数的梯度;
梯度冲突消除单元203,用于对于每个任务分支上的主干网络参数的梯度,判断其与其它任务分支的主干网络参数的梯度是否存在冲突,若存在则对其它任务分支的主干网络参数的梯度进行更新,得到每个任务分支的主干网络参数的最终梯度;
主干网络参数更新单元204,用于计算每个任务分支的主干网络参数的最终梯度的和,作为当前多任务模型的主干网络参数的梯度;利用当前多任务模型的主干网络参数的梯度更新主干网络参数;
迭代单元205,用于计算每个任务分支的主干网络参数的最终梯度的和,作为当前多任务模型的主干网络参数的梯度;利用当前多任务模型的主干网络参数的梯度更新主干网络参数。
需要说明的是,本申请实施例提供的一种点云多任务模型的训练200解决技术问题的原理与本申请实施例提供的一种点云多任务模型的训练方法相似,因此,本申请实施例提供的一种点云多任务模型的训练装置200的实施可以参见本申请实施例提供的一种点云多任务模型的训练方法的实施,重复之处不再赘述。
基于上述实施例,本申请实施例还提供了一种电子设备,参阅图4所示,本申请实施例提供的电子设备300至少包括:处理器301、存储器302和存储在存储器302上并可在处理器301上运行的计算机程序,处理器301执行计算机程序时实现本申请实施例提供的点云多任务模型的训练方法。
本申请实施例提供的电子设备300还可以包括连接不同组件(包括处理器301和存储器302)的总线303。其中,总线303表示几类总线结构中的一种或多种,包括存储器总线、外围总线、局域总线等。
存储器302可以包括易失性存储器形式的可读介质,例如随机存储器(RandomAccess Memory,RAM)3021和/或高速缓存存储器3022,还可以进一步包括只读存储器(ReadOnly Memory,ROM)3023。
存储器302还可以包括具有一组(至少一个)程序模块3024的程序工具3025,程序模块3024包括但不限于:操作子系统、一个或者多个应用程序、其它程序模块以及程序数据,这些示例中的每一个或某种组合中可能包括网络环境的实现。
电子设备300也可以与一个或多个外部设备304(例如键盘、遥控器等)通信,还可以与一个或者多个使得用户能与电子设备300交互的设备通信(例如手机、电脑等),和/或,与使得电子设备300与一个或多个其它电子设备300进行通信的任何设备(例如路由器、调制解调器等)通信。这种通信可以通过输入/输出(Input/Output,I/O)接口305进行。并且,电子设备300还可以通过网络适配器306与一个或者多个网络(例如局域网(Local AreaNetwork,LAN),广域网(Wide Area Network,WAN)和/或公共网络,例如因特网)通信。如图4所示,网络适配器306通过总线303与电子设备300的其它模块通信。应当理解,尽管图4中未示出,可以结合电子设备300使用其它硬件和/或软件模块,包括但不限于:微代码、设备驱动器、冗余处理器、外部磁盘驱动阵列、磁盘阵列(Redundant Arrays of IndependentDisks,RAID)子系统、磁带驱动器以及数据备份存储子系统等。
需要说明的是,图4所示的电子设备300仅仅是一个示例,不应对本申请实施例的功能和使用范围带来任何限制。
本申请实施例还提供了一种计算机可读存储介质,该计算机可读存储介质存储有计算机指令,该计算机指令被处理器执行时实现本申请实施例提供的点云多任务模型的训练方法。具体地,该可执行程序可以内置或者安装在电子设备300中,这样,电子设备300就可以通过执行内置或者安装的可执行程序实现本申请实施例提供的点云多任务模型的训练方法。
本申请实施例提供的点云多任务模型的训练方法还可以实现为一种程序产品,该程序产品包括程序代码,当该程序产品可以在电子设备300上运行时,该程序代码用于使电子设备300执行本申请实施例提供的点云多任务模型的训练方法。
本申请实施例提供的程序产品可以采用一个或多个可读介质的任意组合,其中,可读介质可以是可读信号介质或者可读存储介质,而可读存储介质可以是但不限于是电、磁、光、电磁、红外线、或半导体的系统、装置或器件,或者任意以上的组合,具体地,可读存储介质的更具体的例子(非穷举的列表)包括:具有一个或多个导线的电连接、便携式盘、硬盘、RAM、ROM、可擦式可编程只读存储器(Erasable Programmable Read Only Memory,EPROM)、光纤、便携式紧凑盘只读存储器(Compact Disc Read-Only Memory,CD-ROM)、光存储器件、磁存储器件、或者上述的任意合适的组合。
本申请实施例提供的程序产品可以采用CD-ROM并包括程序代码,还可以在计算设备上运行。然而,本申请实施例提供的程序产品不限于此,在本申请实施例中,可读存储介质可以是任何包含或存储程序的有形介质,该程序可以被指令执行系统、装置或者器件使用或者与其结合使用。
应当注意,尽管在上文详细描述中提及了装置的若干单元或子单元,但是这种划分仅仅是示例性的并非强制性的。实际上,根据本申请的实施方式,上文描述的两个或更多单元的特征和功能可以在一个单元中具体化。反之,上文描述的一个单元的特征和功能可以进一步划分为由多个单元来具体化。
此外,尽管在附图中以特定顺序描述了本申请方法的操作,但是,这并非要求或者暗示必须按照该特定顺序来执行这些操作,或是必须执行全部所示的操作才能实现期望的结果。附加地或备选地,可以省略某些步骤,将多个步骤合并为一个步骤执行,和/或将一个步骤分解为多个步骤执行。
最后所应说明的是,以上实施例仅用以说明本申请的技术方案而非限制。尽管参照实施例对本申请进行了详细说明,本领域的普通技术人员应当理解,对本申请的技术方案进行修改或者等同替换,都不脱离本申请技术方案的精神和范围,其均应涵盖在本申请的权利要求范围当中。

Claims (7)

1.一种点云多任务模型的训练方法,所述点云多任务模型包括一个主干网络和多个任务处理模型,所述主干网络和每个任务处理模型连接构成多个任务分支;其特征在于,包括:
获取多个多任务训练样本组合,每个多任务训练样本组合包括多个标注不同任务结果的点云数据样本;
利用点云多任务模型对每个多任务训练样本组合进行处理,得到每个任务分支的损失函数,分别计算每个任务分支的主干网络参数的梯度;
对于每个任务分支上的主干网络参数的梯度,判断其与其它任务分支的主干网络参数的梯度是否存在冲突,若存在则对其它任务分支的主干网络参数的梯度进行更新,得到每个任务分支的主干网络参数的最终梯度;
计算每个任务分支的主干网络参数的最终梯度的和,作为当前多任务模型的主干网络参数的梯度;利用当前多任务模型的主干网络参数的梯度更新主干网络参数;
基于更新后的主干网络参数和所述多个多任务训练样本组合,继续进行主干网络的参数更新过程,直至达到预设的迭代结束条件,将得到的主干网络参数作为训练好的点云多任务模型的模型参数。
2.根据权利要求1所述的点云多任务模型的训练方法,其特征在于,所述点云多任务模型包括N个任务分支,获取多个多任务训练样本组合;包括:
获取N个任务训练数据集合,每个任务训练数据集合包括多个标注一个任务结果的点云数据样本;
分别从各任务训练数据集合中抽取一个点云数据样本,将N个点云数据样本进行组合,得到多任务训练样本组合。
3.根据权利要求2所述的点云多任务模型的训练方法,其特征在于,利用点云多任务模型对每个多任务训练样本组合进行处理,得到每个任务分支的损失函数,分别计算每个任务分支的主干网络参数的梯度;包括:
将每个多任务训练样本组合中的点云数据样本输入对应的任务分支,得到预测结果;
根据预测结果和点云数据样本的标注结果计算损失函数Li(θ),i为任务分支的编号,1≤i≤N;θ为主干网络参数;
根据损失函数Li(θ),计算第i个任务分支的主干网络参数θ的梯度gi
Figure FDA0003845508630000021
其中,
Figure FDA0003845508630000022
为对Li(θ)中的参数θ的梯度运算。
4.根据权利要求3所述的点云多任务模型的训练方法,其特征在于,对于每个任务分支上的主干网络参数的梯度,判断其与其它任务分支的主干网络参数的梯度是否存在冲突,若存在则将其它任务分支的主干网络参数的梯度进行更新,得到每个任务分支的主干网络参数的最终梯度;包括:
从i=1开始,执行下述步骤,直至i=N-1:
对于第i个任务分支的主干网络参数θ的梯度gi,计算其与第j个任务分支的主干网络参数θ的梯度gj的夹角为φij,其中,i+1≤j≤N;
判断cosφij<0是否成立,若成立,则梯度gi和梯度gj存在梯度冲突,否则,不存在梯度冲突;
当梯度gi和梯度gj存在梯度冲突,利用下式得到更新后的梯度
Figure FDA0003845508630000023
Figure FDA0003845508630000031
利用
Figure FDA0003845508630000032
更新gj
由此得到第i个任务分支的主干网络参数θ的最终梯度。
5.一种点云多任务模型的训练装置,所述点云多任务模型包括一个主干网络和多个任务处理模型,所述主干网络和每个任务处理模型连接构成任务分支;其特征在于,包括:
获取单元,用于获取多个多任务训练样本组合,每个多任务训练样本组合包括多个标注不同任务结果的点云数据样本;
计算单元,用于利用点云多任务模型对每个多任务训练样本组合进行处理,得到每个任务分支的损失函数,分别计算每个任务分支的主干网络参数的梯度;
梯度冲突消除单元,用于对于每个任务分支上的主干网络参数的梯度,判断其与其它任务分支的主干网络参数的梯度是否存在冲突,若存在则对其它任务分支的主干网络参数的梯度进行更新,得到每个任务分支的主干网络参数的最终梯度;
主干网络参数更新单元,用于计算每个任务分支的主干网络参数的最终梯度的和,作为当前多任务模型的主干网络参数的梯度;利用当前多任务模型的主干网络参数的梯度更新主干网络参数;
迭代单元,用于基于更新后的主干网络参数和所述多个多任务训练样本组合,继续进行主干网络的参数更新过程,直至达到预设的迭代结束条件,将得到的主干网络参数作为训练好的点云多任务模型的模型参数。
6.一种电子设备,其特征在于,包括:存储器、处理器和存储在所述存储器上并可在所述处理器上运行的计算机程序,所述处理器执行所述计算机程序时实现如权利要求1-4任一项所述的点云多任务模型的训练方法。
7.一种计算机可读存储介质,其特征在于,所述计算机可读存储介质存储有计算机指令,所述计算机指令被处理器执行时实现如权利要求1-4任一项所述的点云多任务模型的训练方法。
CN202211115837.1A 2022-09-14 2022-09-14 一种点云多任务模型的训练方法、装置及电子设备 Pending CN115358413A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202211115837.1A CN115358413A (zh) 2022-09-14 2022-09-14 一种点云多任务模型的训练方法、装置及电子设备

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202211115837.1A CN115358413A (zh) 2022-09-14 2022-09-14 一种点云多任务模型的训练方法、装置及电子设备

Publications (1)

Publication Number Publication Date
CN115358413A true CN115358413A (zh) 2022-11-18

Family

ID=84007405

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202211115837.1A Pending CN115358413A (zh) 2022-09-14 2022-09-14 一种点云多任务模型的训练方法、装置及电子设备

Country Status (1)

Country Link
CN (1) CN115358413A (zh)

Cited By (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115984827A (zh) * 2023-03-06 2023-04-18 安徽蔚来智驾科技有限公司 点云感知方法、计算机设备及计算机可读存储介质
CN115994936A (zh) * 2023-03-23 2023-04-21 季华实验室 点云融合模型获取方法、装置、电子设备及存储介质
CN116070119A (zh) * 2023-03-31 2023-05-05 北京数慧时空信息技术有限公司 基于小样本的多任务组合模型的训练方法
CN116385825A (zh) * 2023-03-22 2023-07-04 小米汽车科技有限公司 模型联合训练方法、装置及车辆
CN116740669A (zh) * 2023-08-16 2023-09-12 之江实验室 多目图像检测方法、装置、计算机设备和存储介质

Citations (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN112561077A (zh) * 2020-12-14 2021-03-26 北京百度网讯科技有限公司 多任务模型的训练方法、装置及电子设备
CN113420787A (zh) * 2021-05-31 2021-09-21 哈尔滨工业大学(深圳) 一种缓解多任务学习中任务冲突方法、装置及存储介质
US20210374542A1 (en) * 2020-12-14 2021-12-02 Beijing Baidu Netcom Science And Technology Co., Ltd. Method and apparatus for updating parameter of multi-task model, and storage medium
CN114237838A (zh) * 2021-11-23 2022-03-25 华南理工大学 基于自适应损失函数加权的多任务模型训练方法
CN114820463A (zh) * 2022-04-06 2022-07-29 合众新能源汽车有限公司 点云检测和分割方法、装置,以及,电子设备

Patent Citations (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN112561077A (zh) * 2020-12-14 2021-03-26 北京百度网讯科技有限公司 多任务模型的训练方法、装置及电子设备
US20210374542A1 (en) * 2020-12-14 2021-12-02 Beijing Baidu Netcom Science And Technology Co., Ltd. Method and apparatus for updating parameter of multi-task model, and storage medium
CN113420787A (zh) * 2021-05-31 2021-09-21 哈尔滨工业大学(深圳) 一种缓解多任务学习中任务冲突方法、装置及存储介质
CN114237838A (zh) * 2021-11-23 2022-03-25 华南理工大学 基于自适应损失函数加权的多任务模型训练方法
CN114820463A (zh) * 2022-04-06 2022-07-29 合众新能源汽车有限公司 点云检测和分割方法、装置,以及,电子设备

Non-Patent Citations (1)

* Cited by examiner, † Cited by third party
Title
TIANHE YU等: "Gradient Surgery for Multi-Task Learning" *

Cited By (10)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115984827A (zh) * 2023-03-06 2023-04-18 安徽蔚来智驾科技有限公司 点云感知方法、计算机设备及计算机可读存储介质
CN115984827B (zh) * 2023-03-06 2024-02-02 安徽蔚来智驾科技有限公司 点云感知方法、计算机设备及计算机可读存储介质
CN116385825A (zh) * 2023-03-22 2023-07-04 小米汽车科技有限公司 模型联合训练方法、装置及车辆
CN116385825B (zh) * 2023-03-22 2024-04-30 小米汽车科技有限公司 模型联合训练方法、装置及车辆
CN115994936A (zh) * 2023-03-23 2023-04-21 季华实验室 点云融合模型获取方法、装置、电子设备及存储介质
CN115994936B (zh) * 2023-03-23 2023-06-30 季华实验室 点云融合模型获取方法、装置、电子设备及存储介质
CN116070119A (zh) * 2023-03-31 2023-05-05 北京数慧时空信息技术有限公司 基于小样本的多任务组合模型的训练方法
CN116070119B (zh) * 2023-03-31 2023-10-27 北京数慧时空信息技术有限公司 基于小样本的多任务组合模型的训练方法
CN116740669A (zh) * 2023-08-16 2023-09-12 之江实验室 多目图像检测方法、装置、计算机设备和存储介质
CN116740669B (zh) * 2023-08-16 2023-11-14 之江实验室 多目图像检测方法、装置、计算机设备和存储介质

Similar Documents

Publication Publication Date Title
CN115358413A (zh) 一种点云多任务模型的训练方法、装置及电子设备
Pierson et al. Deep learning in robotics: a review of recent research
Carlucho et al. AUV position tracking control using end-to-end deep reinforcement learning
CN111507378A (zh) 训练图像处理模型的方法和装置
US20220335304A1 (en) System and Method for Automated Design Space Determination for Deep Neural Networks
US20230419113A1 (en) Attention-based deep reinforcement learning for autonomous agents
WO2020062911A1 (en) Actor ensemble for continuous control
CN113204988B (zh) 小样本视点估计
CN113449859A (zh) 一种数据处理方法及其装置
CN114792359B (zh) 渲染网络训练和虚拟对象渲染方法、装置、设备及介质
CN109902192B (zh) 基于无监督深度回归的遥感图像检索方法、系统、设备及介质
CN115860102B (zh) 一种自动驾驶感知模型的预训练方法、装置、设备和介质
CN111274994A (zh) 漫画人脸检测方法、装置、电子设备及计算机可读介质
CN113011568A (zh) 一种模型的训练方法、数据处理方法及设备
Wodziński et al. Sequential classification of palm gestures based on A* algorithm and MLP neural network for quadrocopter control
Ou et al. GPU-based global path planning using genetic algorithm with near corner initialization
CN117372983B (zh) 一种低算力的自动驾驶实时多任务感知方法及装置
Lin et al. Robot grasping based on object shape approximation and LightGBM
CN113239799A (zh) 训练方法、识别方法、装置、电子设备和可读存储介质
Xia et al. Hybrid feature adaptive fusion network for multivariate time series classification with application in AUV fault detection
CN111260074A (zh) 一种超参数确定的方法、相关装置、设备及存储介质
CN113837993B (zh) 轻量级虹膜图像分割方法、装置、电子设备及存储介质
CN115952856A (zh) 一种基于双向分割的神经网络流水线并行训练方法及系统
CN114707643A (zh) 一种模型切分方法及其相关设备
CN116710974A (zh) 在合成数据系统和应用程序中使用域对抗学习的域适应

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
RJ01 Rejection of invention patent application after publication

Application publication date: 20221118

RJ01 Rejection of invention patent application after publication