CN113326826A - 网络模型的训练方法、装置、电子设备及存储介质 - Google Patents

网络模型的训练方法、装置、电子设备及存储介质 Download PDF

Info

Publication number
CN113326826A
CN113326826A CN202110883889.2A CN202110883889A CN113326826A CN 113326826 A CN113326826 A CN 113326826A CN 202110883889 A CN202110883889 A CN 202110883889A CN 113326826 A CN113326826 A CN 113326826A
Authority
CN
China
Prior art keywords
network model
data
training
pseudo
loss function
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202110883889.2A
Other languages
English (en)
Inventor
王力超
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Neolix Technologies Co Ltd
Original Assignee
Neolix Technologies Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Neolix Technologies Co Ltd filed Critical Neolix Technologies Co Ltd
Priority to CN202110883889.2A priority Critical patent/CN113326826A/zh
Publication of CN113326826A publication Critical patent/CN113326826A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V20/00Scenes; Scene-specific elements
    • G06V20/50Context or environment of the image
    • G06V20/56Context or environment of the image exterior to a vehicle by using sensors mounted on the vehicle
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • G06F18/241Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • G06F18/241Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
    • G06F18/2415Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches based on parametric or probabilistic models, e.g. based on likelihood ratio or false acceptance rate versus a false rejection rate
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/047Probabilistic or stochastic networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/048Activation functions
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods

Abstract

本公开涉及人工智能技术领域,提供了一种网络模型的训练方法、装置、电子设备及存储介质。该方法应用于无人车,即无人驾驶设备或自动驾驶设备,包括:执行第一训练步骤,以利用点云数据集中的有标签数据对改进的网络模型进行训练;执行处理步骤,以利用训练得到的改进的网络模型对点云数据集中的无标签数据进行打上伪标签的处理,生成无标签数据的伪标签数据;执行第二训练步骤,以利用伪标签数据和有标签数据对改进的网络模型进行训练;交替执行处理步骤和第二训练步骤,直至满足改进的网络模型的训练结束条件。本公开能够基于伪标签数据和有标签数据对改进的网络模型进行训练,因此,提升了网络模型的泛化性能,提高了网络模型的预测准确率。

Description

网络模型的训练方法、装置、电子设备及存储介质
技术领域
本公开涉及人工智能技术领域,尤其涉及一种网络模型的训练方法、装置、电子设备及计算机可读存储介质。
背景技术
深度学习(Deep Learning,DL)是机器学习(Machine Learning,ML)研究中的一个新的领域,其动机在于建立和模拟人脑进行分析学习的神经网络,并模仿人脑的机制来解释数据,例如,图像、声音和文本。深度学习是无监督学习的一种。
近年来,随着深度学习技术的不断进步和计算机算力的不断提升,数据分类技术在诸如语音分析、图像识别、自然语言处理等各个领域取得了巨大的进展。以图像识别技术领域为例,通常可以使用大规模的有标签数据的训练样本作为训练集,应用相应的神经网络来训练分类器,使其可以学习图像的全局或局部特征,并将该全局或局部特征与已学习的特征进行比对,从而确定每个图像中对象的类别。
然而,现有的数据分类技术依赖于人工标注的标签数据,这不仅耗费了大量的时间和人力成本,而且受限于标签数据的准确性和数据规模等因素,导致网络模型的泛化性能差且预测准确率低。
发明内容
有鉴于此,本公开实施例提供了一种网络模型的训练方法、装置、电子设备及计算机可读存储介质,以解决现有技术中数据分类技术依赖于人工标注的标签数据,耗费了大量的时间和人力成本且受限于标签数据的准确性和数据规模,导致网络模型的泛化性能差且预测准确率低的问题。
本公开实施例的第一方面,提供了一种网络模型的训练方法,包括:执行第一训练步骤,以利用点云数据集中的有标签数据对改进的网络模型进行训练;执行处理步骤,以利用训练得到的改进的网络模型对点云数据集中的无标签数据进行打上伪标签的处理,生成无标签数据的伪标签数据;执行第二训练步骤,以利用伪标签数据和有标签数据对改进的网络模型进行训练;交替执行处理步骤和第二训练步骤,直至满足改进的网络模型的训练结束条件。
本公开实施例的第二方面,提供了一种网络模型的训练装置,包括:第一训练模块,被配置为执行第一训练步骤,以利用点云数据集中的有标签数据对改进的网络模型进行训练;处理模块,被配置为执行处理步骤,以利用训练得到的改进的网络模型对点云数据集中的无标签数据进行打上伪标签的处理,生成无标签数据的伪标签数据;第二训练模块,被配置为执行第二训练步骤,以利用伪标签数据和有标签数据对改进的网络模型进行训练;交替模块,被配置为控制处理模块和第二训练模块交替执行处理步骤和第二训练步骤,直至满足改进的网络模型的训练结束条件。
本公开实施例的第三方面,提供了一种电子设备,包括存储器、处理器以及存储在存储器中并且可以在处理器上运行的计算机程序,该处理器执行计算机程序时实现上述方法的步骤。
本公开实施例的第四方面,提供了一种计算机可读存储介质,该计算机可读存储介质存储有计算机程序,该计算机程序被处理器执行时实现上述方法的步骤。
本公开实施例与现有技术相比存在的有益效果是:通过执行第一训练步骤,以利用点云数据集中的有标签数据对改进的网络模型进行训练;执行处理步骤,以利用训练得到的改进的网络模型对点云数据集中的无标签数据进行打上伪标签的处理,生成无标签数据的伪标签数据;执行第二训练步骤,以利用伪标签数据和有标签数据对改进的网络模型进行训练;交替执行处理步骤和第二训练步骤,直至满足改进的网络模型的训练结束条件,能够利用所生成的更准确的伪标签数据和有标签数据对改进的网络模型进行训练,因此,提升了网络模型的性能,并提高了网络模型的预测准确率。
附图说明
为了更清楚地说明本公开实施例中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本公开的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其它的附图。
图1是本公开实施例的应用场景的场景示意图;
图2是本公开实施例提供的一种网络模型的训练方法的流程图;
图3是本公开实施例提供的另一种网络模型的训练方法的流程图;
图4是本公开实施例提供的再一种网络模型的训练方法的流程图;
图5是本公开实施例提供的一种网络模型的训练装置的框图;
图6是本公开实施例提供的一种电子设备的示意图。
具体实施方式
以下描述中,为了说明而不是为了限定,提出了诸如特定系统结构、技术之类的具体细节,以便透彻理解本公开实施例。然而,本领域的技术人员应当清楚,在没有这些具体细节的其它实施例中也可以实现本公开。在其它情况中,省略对众所周知的系统、装置、电路以及方法的详细说明,以免不必要的细节妨碍本公开的描述。
下面将结合附图详细说明根据本公开实施例的一种网络模型的训练方法和装置。
图1是本公开实施例的应用场景的场景示意图。该应用场景可以包括终端设备1、无人驾驶设备2、服务器3和网络4。
终端设备1可以是硬件,也可以是软件。当终端设备1为硬件时,其可以是具有显示屏且支持与服务器3通信的各种电子设备,包括但不限于智能手机、平板电脑、膝上型便携计算机和台式计算机等;当终端设备1为软件时,其可以安装在如上的电子设备中。终端设备1可以实现为多个软件或软件模块,也可以实现为单个软件或软件模块,本公开实施例对此不作限制。进一步地,终端设备1上可以安装有各种应用,例如,数据处理应用、即时通信工具、社交平台软件、远程控制软件、搜索类应用、购物类应用等。
无人驾驶设备2可以是支持无人驾驶、自动驾驶和远程驾驶中的任一功能的车辆,包括但不限于无人驾驶汽车、无人驾驶飞机、无人驾驶船舶、自动配送设备、机器人等。
服务器3可以是提供各种服务的服务器,例如,对与其建立通信连接的终端设备发送的请求进行接收的后台服务器,该后台服务器可以对终端设备发送的请求进行接收和分析等处理,并生成处理结果。服务器3可以是一台服务器,也可以是由若干台服务器组成的服务器集群,或者还可以是一个云计算服务中心,本公开实施例对此不作限制。
需要说明的是,服务器3可以是硬件,也可以是软件。当服务器3为硬件时,其可以是为终端设备1提供各种服务的各种电子设备。当服务器3为软件时,其可以实现为为终端设备1提供各种服务的多个软件或软件模块,也可以实现为为终端设备1提供各种服务的单个软件或软件模块,本公开实施例对此不作限制。
终端设备1或无人驾驶设备2与服务器3之间可以通过网络4进行信息交互。网络4可以是采用同轴电缆、双绞线和光纤连接的有线网络,也可以是无需布线就能实现各种通信设备互联的无线网络,例如,蓝牙(Bluetooth)、近场通信(Near Field Communication,NFC)、红外(Infrared)等,本公开实施例对此不作限制。
本公开实施例提供的网络模型的训练方法可以由终端设备1或无人驾驶设备2执行,也可以由服务器3执行,或者还可以由终端设备1、无人驾驶设备2和服务器3共同执行,本公开实施例对此不作限制。举例来说,可以将基于本公开实施例训练得到的网络模型配置在图1所示的终端设备1上,也可以配置在无人驾驶设备2上,或者还可以配置在服务器3上,当用户通过终端设备1上传数据时,可以通过上述网络模型对上传数据进行分类,从而将用户上传的数据自动划分为相应的类别。
需要说明的是,终端设备1、无人驾驶设备2和服务器3的具体类型、数量和组合可以根据应用场景的实际需求进行调整,本公开实施例对此不作限制。
图2是本公开实施例提供的一种网络模型的训练方法的流程图。图2的网络模型的训练方法可以由图1的终端设备1、无人驾驶设备2和/或服务器3执行。如图2所示,该网络模型的训练方法包括:
S201,执行第一训练步骤,以利用点云数据集中的有标签数据对改进的网络模型进行训练;
S202,执行处理步骤,以利用训练得到的改进的网络模型对点云数据集中的无标签数据进行打上伪标签的处理,生成无标签数据的伪标签数据;
S203,执行第二训练步骤,以利用伪标签数据和有标签数据对改进的网络模型进行训练;
S204,交替执行处理步骤和第二训练步骤,直至满足改进的网络模型的训练结束条件。
具体地,在利用点云数据集中的有标签数据对改进的网络模型进行训练时,可以将有标签数据输入至改进的网络模型,以通过对改进的网络模型进行训练来得到训练好的改进的网络模型;在生成无标签数据的伪标签数据时,可以利用训练好的改进的网络模型对点云数据集中的无标签数据进行打上伪标签的处理以生成伪标签数据,例如,可以通过计算有标签数据与伪标签数据之间的相似度,将具有最高相似度的有标签数据的标签确定为对应伪标签数据的标签,以生成无标签数据的伪标签数据;进一步地,将所生成的伪标签数据和有标签数据输入至改进的网络模型,以通过对改进的网络模型进行训练来得到训练好的改进的网络模型;交替执行处理步骤和第二训练步骤,直至改进的网络模型饱和,即,循环所达到的最大迭代次数。
这里,点云数据是指在一个三维坐标系统中的一组向量的集合。点云数据除了具有几何位置以外,有的还有颜色信息,颜色信息通常是通过相机获取彩色影像,再将对应位置的像素的颜色信息(RGB)赋予点云中对应的点。大多数点云数据是由三维(3D)扫描设备产生的,例如,激光雷达(2D/3D)、立体摄像头(stereo camera)、越渡时间相机(time-of-flight camera)等。这些设备用自动化的方式测量在物体表面的大量的点的信息,再用某种数据文件输出点云数据。在本公开实施例中,点云数据也可以理解为训练数据,是数据挖掘过程中用于训练网络模型的数据。各个点云数据可以构成点云数据集,该点云数据集可以包括有标签数据和无标签数据。有标签数据是指训练数据中带有标签的数据,该标签可以由人工预先标注生成;无标签数据是指训练数据中不带有标签的数据。
机器学习是一门多领域交叉学科,涉及概率论、统计学、逼近论、凸分析、算法复杂度理论等多门学科。专门研究计算机怎样模拟或实现人类的学习行为,以获取新的知识或技能,重新组织已有的知识结构使之不断改善自身的性能。机器学习是人工智能的核心,是使计算机具有智能的根本途径,其应用遍及人工智能的各个领域。机器学习和深度学习通常包括人工神经网络、置信网络、强化学习、迁移学习、归纳学习等技术。机器学习可以分为监督学习、无监督学习和半监督学习,这里,监督学习的所有数据都是有标签数据,无监督学习的所有数据都是无标签数据,半监督学习的部分数据是有标签数据,但大部分数据是无标签数据。
伪标签数据也称软伪标签数据,表示伪标签数据中的标签数据是根据已有标签预测出来的,也就是说,伪标签数据中的标签数据并非真实的标签数据,而是基于已有标签近似得到的标签数据。由于伪标签数据的标签并非是真实标签,为了提高改进的网络模型的准确率,可选地,在生成训练好的网络模型之前,可以根据有标签数据中的标签数据对改进的网络模型进行训练,并调整改进的网络模型的参数。
进一步地,由于伪标签数据的标签并不是完全准确的,为了避免因标签不准确而对改进的网络模型的训练性能造成影响,可选地,在利用伪标签数据对改进的网络模型进行训练时,可以通过置信学习等方法确定伪标签数据中各标签的标签置信度,并根据标签置信度对伪标签数据进行筛选,以得到标签置信度较高的数据。这里,标签置信度也称为标签可靠度,或者标签置信水平、置信系数、置信值等,可以用于衡量标签的真实值有一定概率落在测量结果周围的程度。
网络模型可以具有半监督学习模型的网络结构,网络模型可以是点体素集成网络(PointVoxel-RCNN,PV-RCNN),用于从点云中精确检测三维目标。需要说明的是,本公开的网络模型不限于如上的PV-RCNN,例如,还可以为VAT、LPDSSL、TNAR、pseudo-label、DCT、mean teacher模型中的任一种。
迭代是指程序中对一组指令(或一定步骤)的重复。迭代可以被用作通用的术语(与“重复”同义),也可以用来描述一种特定形式的具有可变状态的重复。由于数值迭代是逐步逼近最优点而获得近似解的,其无限地接近于最优点却又不是理论上的最优点,因此,需要考虑在什么样的条件下才终止迭代,获得一个足够精度的近似极小点,这一条件就是迭代计算的终止准则。对于最优化问题,常用的迭代过程终止准则可以包括但不限于点距准则、函数下降量准则、梯度准则等。
最大迭代次数可以是用户根据经验数据预先设置的阈值,也可以是用户根据所传输的视频图像的清晰度对已设置的间隔进行调整后得到的阈值,本公开实施例对此不作限制。在本公开实施例中,最大迭代次数可以根据实际需要设置,例如,可以为2次、5次、8次、10次、15次、20次、30次等。
需要说明的是,由于迭代计算是一个推算过程,逐渐产生最接近真实结果的解,因此,其得到的解有可能正确,也有可能错误,但误差会相当小。在选项中设置的误差值和迭代计算次数直接影响所得到的解的正确性,次数越大越接近真实的解,但也会耗费更长的时间。
根据本公开实施例提供的技术方案,通过执行第一训练步骤,以利用点云数据集中的有标签数据对改进的网络模型进行训练;执行处理步骤,以利用训练得到的改进的网络模型对点云数据集中的无标签数据进行打上伪标签的处理,生成无标签数据的伪标签数据;执行第二训练步骤,以利用伪标签数据和有标签数据对改进的网络模型进行训练;交替执行处理步骤和第二训练步骤,直至满足改进的网络模型的训练结束条件,能够利用所生成的更准确的伪标签数据和有标签数据对改进的网络模型进行训练,因此,提升了网络模型的性能,并提高了网络模型的预测准确率。
在一些实施例中,改进的网络模型通过如下步骤得到:仿照有标签数据中的正标签数据的交叉熵损失函数,生成负标签数据的交叉熵损失函数;将正标签数据的交叉熵损失函数和负标签数据的交叉熵损失函数进行融合处理,得到融合后的交叉熵损失函数;利用融合后的交叉熵损失函数对原始的网络模型进行训练,得到改进的网络模型。
具体地,正标签数据也称正样本,是指属于某一类别的样本;负标签数据也称负样本,是指不属于某一类别的样本。例如,在进行“猫”的图像识别时,“猫”的样本就属于正样本,不是“猫”的样本就属于负样本。又例如,针对分类的问题,正样本是正确分类出的类别所对应的样本,比如,对一张图片进行分类以确定其是否属于汽车,则在训练时,汽车的图片为正样本,负样本原则上可以选取诸如马路、树木、路灯等任何非汽车的图片。
损失函数(loss function)是将随机事件或其有关随机变量的取值映射为非负实数以表示该随机事件的“风险”或“损失”的函数。在应用中,损失函数通常作为学习准则与优化问题相联系,即通过最小化损失函数求解和评估模型。损失函数用来评价模型的预测值和真实值不一样的程度,损失函数越好,通常模型的性能越好。
交叉熵损失函数(Cross Entropy Loss Function)是一个平滑函数,经常用于分类问题中,特别是在神经网络做分类问题时,经常使用交叉熵作为损失函数,此外,由于交叉熵涉及到计算每个类别的概率,因此,交叉熵几乎每次都和sigmoid(或softmax)数一起出现。交叉熵损失函数用来描述模型预测值和真实值的差距大小,交叉熵损失函数越大代表越不相近。
根据本公开实施例提供的技术方案,通过基于正标签数据和负标签数据来对点云数据集中的无标签数据进行打上伪标签的处理,能够提升标注结果的准确性和可靠性,并进一步地提升了标注速度,节约了大量资源。
在一些实施例中,正标签数据的交叉熵损失函数如下:
Figure 744352DEST_PATH_IMAGE001
,其中,
Figure 30977DEST_PATH_IMAGE002
表示正标签数据的交叉熵损失函数;m表示有标签数据的数量,m为正整数;c表示有标签数据的类别;n表示类别的数量,n为正整数;
Figure 92474DEST_PATH_IMAGE003
表示伪标签是否参与训练的置信值,
Figure 186945DEST_PATH_IMAGE003
的值为0或1;
Figure 599471DEST_PATH_IMAGE004
表示第c个类别的伪标签,
Figure 627470DEST_PATH_IMAGE005
的值为0或1;
Figure 871370DEST_PATH_IMAGE006
表示第c个类别的预测值,
Figure 61043DEST_PATH_IMAGE006
的取值范围为[0,1]。
在一些实施例中,负标签数据的交叉熵损失函数如下:
Figure 773915DEST_PATH_IMAGE007
,其中,
Figure 340025DEST_PATH_IMAGE008
表示负标签数据的交叉熵损失函数。
在一些实施例中,融合后的交叉熵损失函数如下:
Figure 704011DEST_PATH_IMAGE009
,其中,
Figure 64585DEST_PATH_IMAGE010
表示融合后的交叉熵损失函数。
在一些实施例中,该网络模型的训练方法还包括:将训练得到的网络模型部署到车辆中,其中,车辆包括自动驾驶车辆或无人驾驶车辆。
具体地,车辆是能够实现无人驾驶的各种设备,例如,无人驾驶飞机、无人驾驶船舶、自动配送设备、机器人等;也可以是具有自动巡航控制功能的车辆,例如,轿车、房车、卡车、越野车、运动型实用汽车(Sport Utility Vehicle,SUV);或者还可以是电动车、自行车等,本公开实施例对此不作限制。优选地,在本公开实施例中,车辆可以是自动驾驶车辆或无人驾驶车辆。
上述所有可选技术方案,可以采用任意结合形成本申请的可选实施例,在此不再一一赘述。
下面,通过具体示例对本公开实施例的网络模型的训练方法进行描述。
假设一个样本(即,一张点云图像)中有一辆汽车、一个行人和一辆自行车,则可以确定样本的总数为3;进一步地,假设预测结果为“汽车”的伪标签和预测值分别为1和0.6,预测结果为“行人”的伪标签和预测值分别为0和0.1,预测结果为“自行车”的伪标签和预测值分别为0和0.3,如下表所示。
类别(c) 汽车 行人 自行车
伪标签(
Figure 451704DEST_PATH_IMAGE011
1 0 0
预测值(
Figure 634555DEST_PATH_IMAGE012
0.6 0.1 0.3
置信值(
Figure 525150DEST_PATH_IMAGE013
1 0 0
进一步地,假设伪标签的置信度阈值为0.5,由于预测结果为“汽车”的预测值为0.6,大于伪标签的置信度阈值0.5,则汽车参与训练的置信值为1;由于预测结果为“行人”的预测值为0.1,小于伪标签的置信度阈值0.5,则行人参与训练的置信值为0;由于预测结果为“自行车”的预测值为0.3,小于伪标签的置信度阈值0.5,则自行车参与训练的置信值为0。
基于以上数据并根据交叉熵损失函数的计算公式,可以得到:
Figure 384522DEST_PATH_IMAGE014
上述所有可选技术方案,可以采用任意结合形成本申请的可选实施例,在此不再一一赘述。
图3是本公开实施例提供的另一种网络模型的训练方法的流程图。图3的网络模型的训练方法可以由图1的终端设备1、无人驾驶设备2和/或服务器3执行。如图3所示,该网络模型的训练方法包括:
S301,执行第一训练步骤,以利用点云数据集中的有标签数据对改进的网络模型进行训练;
S302,执行处理步骤,以利用训练得到的改进的网络模型对点云数据集中的无标签数据进行打上伪标签的处理,生成无标签数据的伪标签数据;
S303,执行第二训练步骤,以利用伪标签数据和有标签数据对改进的网络模型进行训练;
S304,交替执行处理步骤和第二训练步骤,直至满足改进的网络模型的训练结束条件;
S305,将训练得到的网络模型部署到车辆中,其中,车辆包括自动驾驶车辆或无人驾驶车辆。
具体地,以服务器为例,服务器执行第一训练步骤,以利用点云数据集中的有标签数据对改进的网络模型进行训练,执行处理步骤以利用训练得到的改进的网络模型对点云数据集中的无标签数据进行打上伪标签的处理,生成无标签数据的伪标签数据,并执行第二训练步骤,以利用伪标签数据和有标签数据对改进的网络模型进行训练;进一步地,服务器交替执行处理步骤和第二训练步骤,直至满足改进的网络模型的训练结束条件,并将训练得到的网络模型部署到包括自动驾驶车辆或无人驾驶车辆的车辆中。
根据本公开实施例提供的技术方案,能够利用所生成的更准确的伪标签数据和有标签数据对改进的网络模型进行训练,因此,提升了网络模型的性能,并提高了网络模型的预测准确率。
图4是本公开实施例提供的再一种网络模型的训练方法的流程图。图4的网络模型的训练方法可以由图1的终端设备1、无人驾驶设备2和/或服务器3执行。如图4所示,该网络模型的训练方法包括:
S401,利用点云数据集中的有标签数据对改进的网络模型进行训练;
S402,利用训练得到的改进的网络模型对点云数据集中的无标签数据进行打上伪标签的处理,生成无标签数据的伪标签数据;
S403,利用伪标签数据和有标签数据对改进的网络模型进行训练;
S404,确定是否满足改进的网络模型的训练结束条件,如果是,则执行S405;否则,执行S402;
S405,将训练得到的网络模型部署到车辆中,其中,车辆包括自动驾驶车辆或无人驾驶车辆。
具体地,以服务器为例,服务器利用点云数据集中的有标签数据对改进的网络模型进行训练,并利用训练得到的改进的网络模型对点云数据集中的无标签数据进行打上伪标签的处理,以生成无标签数据的伪标签数据;进一步地,服务器利用伪标签数据和有标签数据对改进的网络模型进行训练,并确定是否满足改进的网络模型的训练结束条件,在确定满足改进的网络模型的训练结束条件的情况下,服务器将训练得到的网络模型部署到车辆中;在确定不满足改进的网络模型的训练结束条件的情况下,服务器返回执行利用训练得到的改进的网络模型对点云数据集中的无标签数据进行打上伪标签的处理以生成无标签数据的伪标签数据的步骤。
根据本公开实施例提供的技术方案,能够利用所生成的更准确的伪标签数据和有标签数据对改进的网络模型进行训练,因此,提升了网络模型的性能,并提高了网络模型的预测准确率。
应理解,上述实施例中各步骤的序号的大小并不意味着执行顺序的先后,各过程的执行顺序应以其功能和内在逻辑确定,而不应对本公开实施例的实施过程构成任何限定。
下述为本公开系统实施例,可以用于执行本公开方法实施例。对于本公开系统实施例中未披露的细节,请参照本公开方法实施例。
图5是本公开实施例提供的一种网络模型的训练装置的示意图。如图5所示,该网络模型的训练装置包括:
第一训练模块501,被配置为执行第一训练步骤,以利用点云数据集中的有标签数据对改进的网络模型进行训练;
处理模块502,被配置为执行处理步骤,以利用训练得到的改进的网络模型对点云数据集中的无标签数据进行打上伪标签的处理,生成无标签数据的伪标签数据;
第二训练模块503,被配置为执行第二训练步骤,以利用伪标签数据和有标签数据对改进的网络模型进行训练;
交替模块504,被配置为控制处理模块和第二训练模块交替执行处理步骤和第二训练步骤,直至满足改进的网络模型的训练结束条件。
根据本公开实施例提供的技术方案,通过执行第一训练步骤,以利用点云数据集中的有标签数据对改进的网络模型进行训练;执行处理步骤,以利用训练得到的改进的网络模型对点云数据集中的无标签数据进行打上伪标签的处理,生成无标签数据的伪标签数据;执行第二训练步骤,以利用伪标签数据和有标签数据对改进的网络模型进行训练;交替执行处理步骤和第二训练步骤,直至满足改进的网络模型的训练结束条件,能够利用所生成的更准确的伪标签数据和有标签数据对改进的网络模型进行训练,因此,提升了网络模型的性能,并提高了网络模型的预测准确率。
在一些实施例中,改进的网络模型通过如下步骤得到:仿照有标签数据中的正标签数据的交叉熵损失函数,生成负标签数据的交叉熵损失函数;将正标签数据的交叉熵损失函数和负标签数据的交叉熵损失函数进行融合处理,得到融合后的交叉熵损失函数;利用融合后的交叉熵损失函数对原始的网络模型进行训练,得到改进的网络模型。
在一些实施例中,正标签数据的交叉熵损失函数如下:
Figure 993358DEST_PATH_IMAGE001
,其中,
Figure 979899DEST_PATH_IMAGE002
表示正标签数据的交叉熵损失函数;m表示有标签数据的数量,m为正整数;c表示有标签数据的类别;n表示类别的数量,n为正整数;
Figure 990581DEST_PATH_IMAGE003
表示伪标签是否参与训练的置信值,
Figure 755274DEST_PATH_IMAGE003
的值为0或1;
Figure 192684DEST_PATH_IMAGE004
表示第c个类别的伪标签,
Figure 107551DEST_PATH_IMAGE005
的值为0或1;
Figure 831793DEST_PATH_IMAGE006
表示第c个类别的预测值,
Figure 439492DEST_PATH_IMAGE006
的取值范围为[0,1]。
在一些实施例中,负标签数据的交叉熵损失函数如下:
Figure 367128DEST_PATH_IMAGE007
,其中,
Figure 85685DEST_PATH_IMAGE008
表示负标签数据的交叉熵损失函数。
在一些实施例中,融合后的交叉熵损失函数如下:
Figure 664434DEST_PATH_IMAGE009
,其中,
Figure 443034DEST_PATH_IMAGE010
表示融合后的交叉熵损失函数。
在一些实施例中,该网络模型的训练装置包括:部署模块505,被配置为将训练得到的网络模型部署到车辆中,其中,车辆包括自动驾驶车辆或无人驾驶车辆。
在一些实施例中,网络模型为点体素集成网络。
图6是本公开实施例提供的一种电子设备6的示意图。如图6所示,该实施例的电子设备6包括:处理器601、存储器602以及存储在该存储器602中并且可以在处理器601上运行的计算机程序603。处理器601执行计算机程序603时实现上述各个方法实施例中的步骤。或者,处理器601执行计算机程序603时实现上述各装置实施例中各模块/单元的功能。
示例性地,计算机程序603可以被分割成一个或多个模块/单元,一个或多个模块/单元被存储在存储器602中,并由处理器601执行,以完成本公开。一个或多个模块/单元可以是能够完成特定功能的一系列计算机程序指令段,该指令段用于描述计算机程序603在电子设备6中的执行过程。
电子设备6可以是桌上型计算机、笔记本、掌上电脑及云端服务器等电子设备。电子设备6可以包括但不仅限于处理器601和存储器602。本领域技术人员可以理解,图6仅仅是电子设备6的示例,并不构成对电子设备6的限定,可以包括比图示更多或更少的部件,或者组合某些部件,或者不同的部件,例如,电子设备还可以包括输入输出设备、网络接入设备、总线等。
处理器601可以是中央处理单元(Central Processing Unit,CPU),也可以是其它通用处理器、数字信号处理器(Digital Signal Processor,DSP)、专用集成电路(Application Specific Integrated Circuit,ASIC)、现场可编程门阵列(Field-Programmable Gate Array,FPGA)或者其它可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件等。通用处理器可以是微处理器或者该处理器也可以是任何常规的处理器等。
存储器602可以是电子设备6的内部存储单元,例如,电子设备6的硬盘或内存。存储器602也可以是电子设备6的外部存储设备,例如,电子设备6上配备的插接式硬盘,智能存储卡(Smart Media Card,SMC),安全数字(Secure Digital,SD)卡,闪存卡(Flash Card)等。进一步地,存储器602还可以既包括电子设备6的内部存储单元也包括外部存储设备。存储器602用于存储计算机程序以及电子设备所需的其它程序和数据。存储器602还可以用于暂时地存储已经输出或者将要输出的数据。
所属领域的技术人员可以清楚地了解到,为了描述的方便和简洁,仅以上述各功能单元、模块的划分进行举例说明,实际应用中,可以根据需要而将上述功能分配由不同的功能单元、模块完成,即将装置的内部结构划分成不同的功能单元或模块,以完成以上描述的全部或者部分功能。实施例中的各功能单元、模块可以集成在一个处理单元中,也可以是各个单元单独物理存在,也可以两个或两个以上单元集成在一个单元中,上述集成的单元既可以采用硬件的形式实现,也可以采用软件功能单元的形式实现。另外,各功能单元、模块的具体名称也只是为了便于相互区分,并不用于限制本申请的保护范围。上述系统中单元、模块的具体工作过程,可以参考前述方法实施例中的对应过程,在此不再赘述。
在上述实施例中,对各个实施例的描述都各有侧重,某个实施例中没有详述或记载的部分,可以参见其它实施例的相关描述。
本领域普通技术人员可以意识到,结合本文中所公开的实施例描述的各示例的单元及算法步骤,能够以电子硬件、或者计算机软件和电子硬件的结合来实现。这些功能究竟以硬件还是软件方式来执行,取决于技术方案的特定应用和设计约束条件。专业技术人员可以对每个特定的应用来使用不同方法来实现所描述的功能,但是这种实现不应认为超出本公开的范围。
在本公开所提供的实施例中,应该理解到,所揭露的装置/电子设备和方法,可以通过其它的方式实现。例如,以上所描述的装置/电子设备实施例仅仅是示意性的,例如,模块或单元的划分,仅仅为一种逻辑功能划分,实际实现时可以有另外的划分方式,多个单元或组件可以结合或者可以集成到另一个系统,或一些特征可以忽略,或不执行。另一点,所显示或讨论的相互之间的耦合或直接耦合或通讯连接可以是通过一些接口,装置或单元的间接耦合或通讯连接,可以是电性,机械或其它的形式。
作为分离部件说明的单元可以是或者也可以不是物理上分开的,作为单元显示的部件可以是或者也可以不是物理单元,即可以位于一个地方,或者也可以分布到多个网络单元上。可以根据实际的需要选择其中的部分或者全部单元来实现本实施例方案的目的。
另外,在本公开各个实施例中的各功能单元可以集成在一个处理单元中,也可以是各个单元单独物理存在,也可以两个或两个以上单元集成在一个单元中。上述集成的单元既可以采用硬件的形式实现,也可以采用软件功能单元的形式实现。
集成的模块/单元如果以软件功能单元的形式实现并作为独立的产品销售或使用时,可以存储在一个计算机可读存储介质中。基于这样的理解,本公开实现上述实施例方法中的全部或部分流程,也可以通过计算机程序来指令相关的硬件来完成,计算机程序可以存储在计算机可读存储介质中,该计算机程序在被处理器执行时,可以实现上述各个方法实施例的步骤。计算机程序可以包括计算机程序代码,计算机程序代码可以为源代码形式、对象代码形式、可执行文件或某些中间形式等。计算机可读介质可以包括:能够携带计算机程序代码的任何实体或装置、记录介质、U盘、移动硬盘、磁碟、光盘、计算机存储器、只读存储器(Read-Only Memory,ROM)、随机存取存储器(Random Access Memory,RAM)、电载波信号、电信信号以及软件分发介质等。需要说明的是,计算机可读介质包含的内容可以根据司法管辖区内立法和专利实践的要求进行适当的增减,例如,在某些司法管辖区,根据立法和专利实践,计算机可读介质不包括电载波信号和电信信号。
以上实施例仅用以说明本公开的技术方案,而非对其限制;尽管参照前述实施例对本公开进行了详细的说明,本领域的普通技术人员应当理解:其依然可以对前述各实施例所记载的技术方案进行修改,或者对其中部分技术特征进行等同替换;而这些修改或者替换,并不使相应技术方案的本质脱离本公开各实施例技术方案的精神和范围,均应包含在本公开的保护范围之内。

Claims (10)

1.一种网络模型的训练方法,其特征在于,包括:
执行第一训练步骤,以利用点云数据集中的有标签数据对改进的网络模型进行训练;
执行处理步骤,以利用训练得到的改进的网络模型对所述点云数据集中的无标签数据进行打上伪标签的处理,生成所述无标签数据的伪标签数据;
执行第二训练步骤,以利用所述伪标签数据和所述有标签数据对所述改进的网络模型进行训练;
交替执行所述处理步骤和所述第二训练步骤,直至满足所述改进的网络模型的训练结束条件。
2.根据权利要求1所述的方法,其特征在于,所述改进的网络模型通过如下步骤得到:
仿照所述有标签数据中的正标签数据的交叉熵损失函数,生成负标签数据的交叉熵损失函数;
将所述正标签数据的交叉熵损失函数和所述负标签数据的交叉熵损失函数进行融合处理,得到融合后的交叉熵损失函数;
利用所述融合后的交叉熵损失函数对原始的网络模型进行训练,得到所述改进的网络模型。
3.根据权利要求2所述的方法,其特征在于,所述正标签数据的交叉熵损失函数如下:
Figure 199145DEST_PATH_IMAGE001
其中,
Figure 277959DEST_PATH_IMAGE002
表示所述正标签数据的交叉熵损失函数;m表示所述有标签数据的数量,m为正整数;c表示所述有标签数据的类别;n表示所述类别的数量,n为正整数;
Figure 588986DEST_PATH_IMAGE003
表示所述伪标签是否参与训练的置信值,
Figure 360633DEST_PATH_IMAGE003
的值为0或1;
Figure 373588DEST_PATH_IMAGE004
表示第c个类别的伪标签,
Figure 256094DEST_PATH_IMAGE005
的值为0或1;
Figure 418697DEST_PATH_IMAGE006
表示第c个类别的预测值,
Figure 361245DEST_PATH_IMAGE006
的取值范围为[0,1]。
4.根据权利要求3所述的方法,其特征在于,所述负标签数据的交叉熵损失函数如下:
Figure 861497DEST_PATH_IMAGE007
其中,
Figure 282114DEST_PATH_IMAGE008
表示所述负标签数据的交叉熵损失函数。
5.根据权利要求4所述的方法,其特征在于,所述融合后的交叉熵损失函数如下:
Figure 567733DEST_PATH_IMAGE009
其中,
Figure 681182DEST_PATH_IMAGE010
表示所述融合后的交叉熵损失函数。
6.根据权利要求1至5中任一项所述的方法,其特征在于,所述方法还包括:
将训练得到的网络模型部署到车辆中,其中,所述车辆包括自动驾驶车辆或无人驾驶车辆。
7.根据权利要求1至5中任一项所述的方法,其特征在于,所述网络模型为点体素集成网络。
8.一种网络模型的训练装置,其特征在于,包括:
第一训练模块,被配置为执行第一训练步骤,以利用点云数据集中的有标签数据对改进的网络模型进行训练;
处理模块,被配置为执行处理步骤,以利用训练得到的改进的网络模型对所述点云数据集中的无标签数据进行打上伪标签的处理,生成所述无标签数据的伪标签数据;
第二训练模块,被配置为执行第二训练步骤,以利用所述伪标签数据和所述有标签数据对所述改进的网络模型进行训练;
交替模块,被配置为控制所述处理模块和所述第二训练模块交替执行所述处理步骤和所述第二训练步骤,直至满足所述改进的网络模型的训练结束条件。
9.一种电子设备,包括存储器、处理器以及存储在所述存储器中并且可以在所述处理器上运行的计算机程序,其特征在于,所述处理器执行所述计算机程序时实现如权利要求1至7中任一项所述方法的步骤。
10.一种计算机可读存储介质,所述计算机可读存储介质存储有计算机程序,其特征在于,所述计算机程序被处理器执行时实现如权利要求1至7中任一项所述方法的步骤。
CN202110883889.2A 2021-08-03 2021-08-03 网络模型的训练方法、装置、电子设备及存储介质 Pending CN113326826A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202110883889.2A CN113326826A (zh) 2021-08-03 2021-08-03 网络模型的训练方法、装置、电子设备及存储介质

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202110883889.2A CN113326826A (zh) 2021-08-03 2021-08-03 网络模型的训练方法、装置、电子设备及存储介质

Publications (1)

Publication Number Publication Date
CN113326826A true CN113326826A (zh) 2021-08-31

Family

ID=77426859

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202110883889.2A Pending CN113326826A (zh) 2021-08-03 2021-08-03 网络模型的训练方法、装置、电子设备及存储介质

Country Status (1)

Country Link
CN (1) CN113326826A (zh)

Cited By (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113592045A (zh) * 2021-09-30 2021-11-02 杭州一知智能科技有限公司 从印刷体到手写体的模型自适应文本识别方法和系统
CN113743618A (zh) * 2021-09-03 2021-12-03 北京航空航天大学 时间序列数据处理方法、装置、可读介质及电子设备
CN114792417A (zh) * 2022-02-24 2022-07-26 广州文远知行科技有限公司 模型训练方法、图像识别方法、装置、设备及存储介质

Citations (7)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20190147320A1 (en) * 2017-11-15 2019-05-16 Uber Technologies, Inc. "Matching Adversarial Networks"
CN111598004A (zh) * 2020-05-18 2020-08-28 北京星闪世图科技有限公司 一种渐进增强自学习的无监督跨领域行人再识别方法
CN112115780A (zh) * 2020-08-11 2020-12-22 西安交通大学 一种基于深度多模型协同的半监督行人重识别方法
CN112215487A (zh) * 2020-10-10 2021-01-12 吉林大学 一种基于神经网络模型的车辆行驶风险预测方法
CN112232416A (zh) * 2020-10-16 2021-01-15 浙江大学 一种基于伪标签加权的半监督学习方法
CN112381098A (zh) * 2020-11-19 2021-02-19 上海交通大学 基于目标分割领域自学习的半监督学习方法和系统
CN112651975A (zh) * 2020-12-29 2021-04-13 奥比中光科技集团股份有限公司 一种轻量化网络模型的训练方法、装置及设备

Patent Citations (7)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20190147320A1 (en) * 2017-11-15 2019-05-16 Uber Technologies, Inc. "Matching Adversarial Networks"
CN111598004A (zh) * 2020-05-18 2020-08-28 北京星闪世图科技有限公司 一种渐进增强自学习的无监督跨领域行人再识别方法
CN112115780A (zh) * 2020-08-11 2020-12-22 西安交通大学 一种基于深度多模型协同的半监督行人重识别方法
CN112215487A (zh) * 2020-10-10 2021-01-12 吉林大学 一种基于神经网络模型的车辆行驶风险预测方法
CN112232416A (zh) * 2020-10-16 2021-01-15 浙江大学 一种基于伪标签加权的半监督学习方法
CN112381098A (zh) * 2020-11-19 2021-02-19 上海交通大学 基于目标分割领域自学习的半监督学习方法和系统
CN112651975A (zh) * 2020-12-29 2021-04-13 奥比中光科技集团股份有限公司 一种轻量化网络模型的训练方法、装置及设备

Non-Patent Citations (2)

* Cited by examiner, † Cited by third party
Title
SHAOSHUAI SHI ET AL.: ""PV-RCNN: Point-Voxel Feature Set Abstraction for 3D Object Detection"", 《2020 IEEE/CVF CONFERENCE ON COMPUTER VISION AND PATTERN RECOGNITION (CVPR)》 *
陈龙: ""基于标签关联的图像分类方法研究"", 《中国优秀博硕士学位论文全文数据库(硕士) 信息科技辑》 *

Cited By (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113743618A (zh) * 2021-09-03 2021-12-03 北京航空航天大学 时间序列数据处理方法、装置、可读介质及电子设备
CN113592045A (zh) * 2021-09-30 2021-11-02 杭州一知智能科技有限公司 从印刷体到手写体的模型自适应文本识别方法和系统
CN113592045B (zh) * 2021-09-30 2022-02-08 杭州一知智能科技有限公司 从印刷体到手写体的模型自适应文本识别方法和系统
CN114792417A (zh) * 2022-02-24 2022-07-26 广州文远知行科技有限公司 模型训练方法、图像识别方法、装置、设备及存储介质
CN114792417B (zh) * 2022-02-24 2023-06-16 广州文远知行科技有限公司 模型训练方法、图像识别方法、装置、设备及存储介质

Similar Documents

Publication Publication Date Title
CN111797893B (zh) 一种神经网络的训练方法、图像分类系统及相关设备
JP7075366B2 (ja) 運転場面データを分類するための方法、装置、機器及び媒体
CN112183577A (zh) 一种半监督学习模型的训练方法、图像处理方法及设备
CN113326826A (zh) 网络模型的训练方法、装置、电子设备及存储介质
EP4066171A1 (en) Vehicle intent prediction neural network
CN112950642A (zh) 点云实例分割模型的训练方法、装置、电子设备和介质
US11270425B2 (en) Coordinate estimation on n-spheres with spherical regression
CN115115872A (zh) 图像识别方法、装置、设备及存储介质
CN111126459A (zh) 一种车辆细粒度识别的方法及装置
CN113807399A (zh) 一种神经网络训练方法、检测方法以及装置
CN111738403A (zh) 一种神经网络的优化方法及相关设备
CN114611672A (zh) 模型训练方法、人脸识别方法及装置
CN115018039A (zh) 一种神经网络蒸馏方法、目标检测方法以及装置
CN114092920B (zh) 一种模型训练的方法、图像分类的方法、装置及存储介质
CN115953643A (zh) 基于知识蒸馏的模型训练方法、装置及电子设备
CN115063585A (zh) 一种无监督语义分割模型的训练方法及相关装置
CN114550116A (zh) 一种对象识别方法和装置
CN112329830B (zh) 一种基于卷积神经网络和迁移学习的无源定位轨迹数据识别方法及系统
Bai et al. Cyber mobility mirror for enabling cooperative driving automation: A co-simulation platform
CN114170484B (zh) 图片属性预测方法、装置、电子设备和存储介质
CN116994021A (zh) 图像检测方法、装置、计算机可读介质及电子设备
CN115576990A (zh) 视觉真值数据与感知数据的评测方法、装置、设备及介质
CN112417260A (zh) 本地化推荐方法、装置及存储介质
CN112434591B (zh) 车道线确定方法、装置
CN115565152B (zh) 融合车载激光点云与全景影像的交通标志牌提取方法

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
RJ01 Rejection of invention patent application after publication
RJ01 Rejection of invention patent application after publication

Application publication date: 20210831