CN114926701A - 一种模型训练方法、目标检测方法、以及相关设备 - Google Patents

一种模型训练方法、目标检测方法、以及相关设备 Download PDF

Info

Publication number
CN114926701A
CN114926701A CN202110133274.8A CN202110133274A CN114926701A CN 114926701 A CN114926701 A CN 114926701A CN 202110133274 A CN202110133274 A CN 202110133274A CN 114926701 A CN114926701 A CN 114926701A
Authority
CN
China
Prior art keywords
enhancement
model
training
updated
training sample
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202110133274.8A
Other languages
English (en)
Inventor
黄泽昊
刘翱铭
王乃岩
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Beijing Tusimple Technology Co Ltd
Original Assignee
Beijing Tusimple Technology Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Beijing Tusimple Technology Co Ltd filed Critical Beijing Tusimple Technology Co Ltd
Priority to CN202110133274.8A priority Critical patent/CN114926701A/zh
Publication of CN114926701A publication Critical patent/CN114926701A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/29Graphical models, e.g. Bayesian networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Data Mining & Analysis (AREA)
  • Physics & Mathematics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • Evolutionary Computation (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Evolutionary Biology (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Computational Linguistics (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Health & Medical Sciences (AREA)
  • General Health & Medical Sciences (AREA)
  • Molecular Biology (AREA)
  • Computing Systems (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Image Analysis (AREA)

Abstract

本申请实施例公开了一种模型训练方法、目标检测方法、以及相关设备,本申请实施例可以获取训练样本集,所述训练样本集包括训练样本和验证样本;获取所述训练样本对应的增强策略;根据所述增强策略对所述训练样本进行增强操作,得到增强后训练样本;通过初始模型对所述增强后训练样本进行预测;基于对所述增强后训练样本的预测结果更新所述初始模型的参数,得到更新后模型;通过所述更新后模型对所述验证样本进行预测;基于对所述验证样本的预测结果更新所述增强策略,得到更新后增强策略;根据所述更新后增强策略对所述更新后模型进行迭代训练,得到训练后模型。本申请实施例提高了对模型训练准确性和精度。

Description

一种模型训练方法、目标检测方法、以及相关设备
技术领域
本申请涉及数据处理技术领域,具体涉及一种模型训练方法、目标检测方法、以及相关设备。
背景技术
数据增强是防止深度神经网络过拟合的关键技术,因此,合适的数据增强策略对于深度神经网络模型的表现有着非常重要的影响,一般是通过人工选择数据增强策略来进行数据增强,但是人工选择数据增强策略的效率非常低,进而出现了自动数据增强,自动数据增强是指针对不同的任务或数据集,通过特定的搜索优化方法来自动找到最合适的数据增强策略。
现有的自动数据增强技术可以大致分为两类,一类是不可微分的自动数据增强技术,这类技术主要通过增强学习、启发式搜索算法或贝叶斯优化来实现数据增强策略的搜索优化(例如AutoAugment)。具体地,可以设计一个数据增强的搜索空间,每一种策略由多个子策略组成,子策略包含两个图像处理操作,如平移、旋转或剪切等,对于每一个操作都有一组概率和幅度来表征这个操作的使用性质。搜索算法由控制器和优化算法组成,控制器产生每一步的数据增强策略决策,然后将决策作为下一步操作的嵌入向量。每一步采样一个数据增强策略,并用这个策略在从训练集中采样出来的一个小型代理数据集上训练一个模型,所训练模型在代理验证集上的表现作为奖励信号来优化控制器。该自动数据增强技术能够实现较好数据增强技术搜索,但是由于采样与优化效率的限制,搜索的时间成本往往较高,以及搜索到的数据增强可靠性低,从而降低了对模型训练的准确性和精度。
另一类是可微分的自动数据增强技术,采用Gumbel-Softmax等近似技术来将数据增强搜索转化为可微分优化问题,例如代表技术为自动数据增强搜索方式(Differentiable Automatic Data Augmentation,DADA)。具体地,首先将数据增广策略搜索形式化为类别分布(Categorical Distribution)采样问题,对每个子策略里每个操作的概率作为伯努利分布(Bernoulli Distribution)采样问题。然后通过gumbel-softmax技术来将上述分布参数的优化松弛为一个可微分优化问题,并通过使用RELAX估计器估计上述分布的梯度,同时使用二阶梯度估计技术来提升搜索速度。DADA使用基于一步梯度更新的元学习来交替优化模型权重和数据增广策略参数梯度。可微分的自动数据增强技术虽然能够提升搜索的效率,但是在实现可微分搜索的过程中,需要使用很多的近似技术,例如使用gumbel-softmax来实现可微分松弛,或者使用REINFORCE等估计器来估计数据增强参数的梯度,这些近似技术一方面造成了技术上的复杂性,限制了自动数据增强技术的广泛使用,另一方面可能引入不准确的近似,导致搜索到的数据增强测率性能不强,降低了对模型训练的准确性和精度。
发明内容
本申请实施例提供一种模型训练方法、目标检测方法、以及相关设备,其中,相关设备可以包括模型训练装置、目标检测装置、计算机设备及计算机可读存储介质,本申请实施例可以提高对模型训练准确性和精度。
为解决上述技术问题,本申请实施例提供以下技术方案:
本申请实施例提供了一种模型训练方法,包括:获取训练样本集,所述训练样本集包括训练样本和验证样本;获取训练样本对应的增强策略;根据增强策略对训练样本进行增强操作,得到增强后训练样本;通过初始模型对增强后训练样本进行预测;基于对增强后训练样本的预测结果更新初始模型的参数,得到更新后模型;通过更新后模型对验证样本进行预测;基于对验证样本的预测结果更新所述增强策略,得到更新后增强策略;根据更新后增强策略对更新后模型进行迭代训练,得到训练后模型。
根据本申请的一个方面,还提供了一种目标检测方法,包括:获取待检测的图像;通过训练后目标检测模型提取图像的特征信息,训练后目标检测模型为基于更新后增强策略进行迭代训练得到,更新后增强策略为根据更新后模型基于验证样本的预测结果对增强策略进行更新得到,更新后模型为根据增强后训练样本的预测结果对初始模型的参数进行更新得到,增强后训练样本为通过增强策略对训练样本进行增强操作得到;通过训练后目标检测模型基于特征信息对所述图像进行目标检测。
根据本申请的一个方面,还提供了一种模型训练装置,包括:第一获取模块,用于获取训练样本集,所述训练样本集包括训练样本和验证样本;第二获取模块,用于获取训练样本对应的增强策略;增强模块,用于根据增强策略对训练样本进行增强操作,得到增强后训练样本;第一预测模块,用于通过初始模型对增强后训练样本进行预测;第一更新模块,用于基于对增强后训练样本的预测结果更新初始模型的参数,得到更新后模型;第二预测模块,用于通过更新后模型对所述验证样本进行预测;第二更新模块,用于基于对验证样本的预测结果更新增强策略,得到更新后增强策略;训练模块,用于根据更新后增强策略对更新后模型进行迭代训练,得到训练后模型。
根据本申请的一个方面,还提供了一种目标检测装置,包括:图像获取模块,用于获取待检测的图像;提取模块,用于通过训练后目标检测模型提取图像的特征信息,训练后目标检测模型为基于更新后增强策略进行迭代训练得到,更新后增强策略为根据更新后模型基于验证样本的预测结果对增强策略进行更新得到,更新后模型为根据增强后训练样本的预测结果对初始模型的参数进行更新得到,增强后训练样本为通过增强策略对训练样本进行增强操作得到;检测模块,用于通过训练后目标检测模型基于特征信息对图像进行目标检测。
根据本申请的一个方面,还提供了一种计算机设备,包括处理器和存储器,所述存储器中存储有计算机程序,所述处理器调用所述存储器中的计算机程序时执行本申请实施例提供的任一种模型训练方法,或执行本申请实施例提供的任一种目标检测方法。
根据本申请的一个方面,还提供了一种计算机可读存储介质,所述计算机可读存储介质用于存储计算机程序,所述计算机程序被处理器加载,以执行本申请实施例提供的任一种模型训练方法,或执行本申请实施例提供的任一种目标检测方法。
本申请实施例可以基于训练样本集中训练样本对应的增强策略对训练样本进行增强操作,得到增强后训练样本,以及通过初始模型基于增强后训练样本进行预测,并基于对增强后训练样本的预测结果更新初始模型的参数,得到更新后模型;然后可以通过更新后模型对训练样本集中验证样本进行预测,并基于对验证样本的预测结果自动更新增强策略,得到更新后增强策略;此时可以根据更新后增强策略对更新后模型进行迭代训练,得到训练后模型。该方案通过增强策略快速增强训练样本,基于增强后训练样本更新初始模型的参数,并通过更新后模型基于验证样本自动更新增强策略,以对模型进行迭代训练,提高了对训练样本增强的效率,以及提高了对模型训练的准确性和精度。
附图说明
为了更清楚地说明本申请实施例中的技术方案,下面将对实施例描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本申请的一些实施例,对于本领域技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1是本申请实施例提供的模型训练方法应用的场景示意图;
图2是本申请实施例提供的模型训练方法的流程示意图;
图3是本申请实施例提供的训练样本增强的示意图;
图4是本申请实施例提供的模型训练方法的另一流程示意图;
图5是本申请实施例提供的训练样本增强的另一示意图;
图6是本申请实施例提供的模型训练方法的另一流程示意图;
图7是本申请实施例提供的模型训练方法的另一流程示意图;
图8是本申请实施例提供的目标检测方法的流程示意图;
图9是本申请实施例提供的模型训练装置的示意图;
图10是本申请实施例提供的目标检测装置的示意图;
图11是本申请实施例提供的计算机设备的结构示意图。
具体实施方式
下面将结合本申请实施例中的附图,对本申请实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本申请一部分实施例,而不是全部的实施例。基于本申请中的实施例,本领域技术人员在没有作出创造性劳动前提下所获得的所有其他实施例,都属于本申请保护的范围。
在以下的描述中,所涉及的术语“第一\第二”仅仅是区别类似的对象,不代表针对对象的特定排序,可以理解地,“第一\第二”在允许的情况下可以互换特定的顺序或先后次序,以使这里描述的本申请实施例能够以除了在这里图示或描述的以外的顺序实施。
除非另有定义,本文所使用的所有的技术和科学术语与属于本申请的技术领域的技术人员通常理解的含义相同。本文中所使用的术语只是为了描述本申请实施例的目的,不是旨在限制本申请。
本申请实施例提供一种模型训练方法、目标检测方法、以及相关设备,其中,相关设备可以包括模型训练装置、目标检测装置、计算机设备及计算机可读存储介质等。
本申请实施例所提供的模型训练方法,可以由终端或服务器独自实现,也可以由终端和服务器协同实现,例如终端独自承担下文所述的模型训练方法,或者,终端可以向服务器发送针对模型训练的训练请求,服务器可以根据接收的针对训练请求执行模型训练方法,并向终端发送训练结果。以及,本申请实施例所提供的目标检测方法,可以由终端或服务器独自实现,也可以由终端和服务器协同实现。
本申请实施例提供的用于模型训练的计算机设备可以是各种类型的终端设备或服务器,其中,服务器可以是独立的物理服务器,也可以是多个物理服务器构成的服务器集群或者分布式系统,还可以是提供云服务、云数据库、云计算、云函数、云存储、网络服务、云通信、中间件服务、域名服务、安全服务、内容分发网络(Content Delivery Network,CDN)、以及大数据和人工智能平台等基础云计算服务的云服务器,但并不局限于此。终端可以是智能手机、平板电脑、笔记本电脑、台式计算机、摄像机、可穿戴设备、或车载终端等,该车载终端可以位于无人驾驶车辆,但并不局限于此。终端与服务器之间可以通过有线或无线通信方式进行直接或间接地连接,本申请实施例在此不做限制。
请参阅图1,图1为本申请实施例所提供的模型训练方法和目标检测方法应用的场景示意图,该模型训练方法可以应用于模型训练装置,该模型训练装置具体可以集成在终端或服务器等计算机设备中,该目标检测方法可以应用于目标检测装置,该目标检测装置具体可以集成在终端或服务器等计算机设备中。以下将以模型训练装置集成在服务器中、目标检测装置集成在终端中为例进行详细说明。例如,服务器可以基于训练样本集中训练样本对应的增强策略对训练样本进行增强操作,得到增强后训练样本,以及通过初始模型基于增强后训练样本进行预测,并基于对增强后训练样本的预测结果更新初始模型的参数,得到更新后模型;然后可以通过更新后模型对训练样本集中验证样本进行预测,并基于对验证样本的预测结果自动更新增强策略,得到更新后增强策略;此时可以根据更新后增强策略对更新后模型进行迭代训练,得到训练后模型。服务器通过增强策略快速增强训练样本,基于增强后训练样本更新初始模型的参数,并通过更新后模型基于验证样本自动更新增强策略,以对模型进行迭代训练,提高了对训练样本增强的效率,以及提高了对模型训练的准确性和精度。在得到训练后模型后,服务器可以将训练后模型发送给终端,该训练后模型的具体类型可以根据实际需要进行灵活设置,例如训练后模型可以是训练后目标检测模型,此时终端可以获取待检测的图像,通过训练后目标检测模型提取图像的特征信息,通过训练后目标检测模型基于特征信息对图像进行目标检测,通过训练后模型对目标进行检测,提高了对目标检测的准确性。
需要说明的是,图1所示的模型训练方法应用的场景示意图仅仅是一个示例,本申请实施例描述的模型训练方法应用以及场景是为了更加清楚的说明本申请实施例的技术方案,并不构成对于本申请实施例提供的技术方案的限定,本领域普通技术人员可知,随着模型训练方法应用的演变和新业务场景的出现,本申请实施例提供的技术方案对于类似的技术问题,同样适用。
以下分别进行详细说明。需说明的是,以下实施例的描述顺序不作为对实施例优选顺序的限定。
在本实施例中,将从模型训练装置的角度进行描述,该模型训练装置具体可以集成在终端或服务器等计算机设备中。
请参阅图2,图2是本申请一实施例提供的模型训练方法的流程示意图。该模型训练方法可以包括:
S101、获取训练样本集,训练样本集包括训练样本和验证样本。
其中,可以从本地存储数据库中获取训练样本集,或从服务器上获取训练样本集,训练样本集中可以包括多个样本,该样本可以是图像、视频、音频、文本等本领域常见的数据,本申请对此不作限制。另外,可以从训练样本集提取部分样本作为训练样本,以及从训练样本集提取另一部分样本作为验证样本,训练样本可以用于给初始模型进行预测,验证样本可以用于更新后模型验证,训练样本可以包括一个或多个样本,验证样本可以包括一个或多个样本。
例如,可以从训练样本集中筛选出多个样本作为代理训练集以及筛选出多个样本作为代理验证集,可以从代理训练集中选择一组或多组样本作为训练样本,每组训练样本中至少包括一个样本,以及从代理验证集选择一组或者多组样本作为验证样本,每组验证样本中至少包括一个样本。
训练样本集中的样本可以包括真实值,该真实值可以包括目标对象所在的位置或者正确的数据标签等,例如,当训练样本集中的样本为图像时,该图像可以包括目标对象,该真实值可以是目标对象所在的位置,以便基于该图像对初始模型进行训练得到的训练后模型可以准确检测图像中目标对象所在的位置,该目标对象可以包括车辆、人、建筑物、动物或植物等。
S102、获取训练样本对应的增强策略。
在一实施方式中,增强策略可以包括平移、旋转、裁剪、翻转、镜像、缩小、放大、噪声叠加、颜色变换、以及亮度调节中的任一种或多种组合的增强操作。
在一实施方式中,获取训练样本对应的增强策略可以包括:获取增强权重,根据增强权重确定训练样本对应的增强策略。
为了提高增强策略确定的准确性,可以通过训练样本的增强权重确定训练样本对应的增强策略,首先可以获取训练样本对应的增强权重,该增强权重可以包括增强概率(也可以称为总体分布概率ptp)和执行概率(也可以称为操作分布概率po)等,总体分布概率ptp可以表示训练样本整体上多大概率施加数据增强(即训练样本要进行增强的概率),操作分布概率po可以表示选择每个候选增强操作的概率。
在一实施方式中,获取增强权重可以包括:获取当前的增强概率,以及获取不同增强操作对应的执行概率;根据增强概率和执行概率,确定增强策略对应的增强权重。
例如,可以预先为训练样本设置初始的增强概率,以及预先为不同增强操作设置初始的执行概率,开始训练时,可以获取训练样本初始的增强概率,以及获取不同增强操作初始的执行概率,在训练过程中,可以调整训练样本的增强概率,以及调整不同增强操作的执行概率,此时可以获取得到的调整后的训练样本的增强概率,以及调整后的不同增强操作的执行概率。
在得到训练样本对应的增强概率和增强操作对应的执行概率后,可以根据增强概率和执行概率,确定训练样本对应的增强权重。在一实施方式中,根据增强概率和执行概率,确定增强策略对应的增强权重可以包括:将增强概率和增强策略所对应的多个增强操作的执行概率进行连乘操作,得到增强策略对应的增强权重。
例如,训练样本对应的增强权重=增强概率ptp*执行概率po
在一实施方式中,根据增强权重确定训练样本对应的增强策略可以包括:根据增强权重采样多个增强策略;根据增强策略对训练样本进行增强操作,得到增强后训练样本包括:将采样得到的多个增强策略作用到不同的训练样本上,得到增强后的训练样本。
例如,可以基于增强策略中包含的增强操作对不同的训练样本进行增强,得到增强后的训练样本。其中,增强策略还可以包括各增强操作的执行优先级等,例如,增强策略A可以包括增强操作1、增强操作2和增强操作3,以及增强操作1、增强操作2和增强操作3之间串联执行的执行优先级顺序:依次对训练样本执行增强操作3、增强操作2和增强操作4得到增强后训练样本。
在一实施方式中,多个增强策略的生成方式可以包括:从多个增强操作中筛选出满足条件的多个候选增强操作;将多个候选增强操作划分为多个增强操作组;基于每个增强操作组以及每个增强操作组中各增强操作的执行顺序生成多个增强策略。
具体地,该多个增强操作可以构成增强操作库,假设一个增强操作库中有K个增强操作,则可以从K(K大于N)个增强操作中筛选出满足条件的多个候选增强操作:O1,O2,......,ONo,例如,可以从K个增强操作中随机筛选出N个候选增强操作,或者可以从K个增强操作中筛选出执行概率较大的前N个候选增强操作,等等。然后,可以将多个候选增强操作划分为多个增强操作组(例如L组增强操作组):
Figure BDA0002926147060000081
每个增强操作组中可以包括至少一个增强操作,具体划分方式可以是随机划分、平均划分或按照增强操作的执行概率进行划分等,在此处不做限定。此时,可以基于每个增强操作组以及每个增强操作组中各增强操作的执行顺序生成多个增强策略。
在一实施方式中,增强概率的初始值为设定值;若增强操作库中包括K个增强操作,则每个增强操作所对应的执行概率的初始值为1/K;增强概率和执行概率在每轮迭代后被更新。
一般地,增强概率可以与当前训练迭代的次数有关,初始的增强概率为设定值(例如0.5或0.4等),初始的执行概率值可以根据增强操作库中的增强操作数目设定为均分值,之后每轮迭代后,增强概率值和执行概率值进行更新。例如,设定操作库中有100个增强操作,则每个增强操作的初始执行概率均为0.01。
S103、根据增强策略对训练样本进行增强操作,得到增强后训练样本。
在得到增强策略后,可以基于增强策略对训练样本进行增强操作,例如,如图3所示,对于训练样本D,可以基于训练样本D的增强概率ptp确定需要进行增强操作,基于增强概率ptp和执行概率po确定增强策略中包含的增强操作:O1,O2,......,ONo,并对训练样本进行增强操作,得到增强后训练样本Da
例如,如图4所示,可以从K个增强操作中随机选取No个增强操作,基于No个增强操作生成L组增强策略
Figure BDA0002926147060000082
将L组增强策略
Figure BDA0002926147060000083
Figure BDA0002926147060000084
分别施加在L组训练样本上,生成L组增强后训练样本
Figure BDA0002926147060000085
Figure BDA0002926147060000086
图4中,
Figure BDA0002926147060000087
为训练样本,Vt为验证样本,ptp为增强概率,po为执行概率,Neural Networkθ为需要训练的模型(例如神经网络),
Figure BDA0002926147060000088
为基于训练样本
Figure BDA0002926147060000089
得到的梯度值,
Figure BDA00029261470600000810
分别为基于L组增强后训练样本
Figure BDA00029261470600000811
Figure BDA00029261470600000812
得到的梯度值,
Figure BDA00029261470600000813
为基于验证样本Vt得到的梯度值,可以将训练样本
Figure BDA00029261470600000814
L组增强后训练样本
Figure BDA00029261470600000815
以及验证样本Vt等输入需要训练的模型(例如初始模型),并基于模型的输出的梯度值
Figure BDA00029261470600000816
对模型进行迭代训练,以下将进行详细说明。
根据增强策略对训练样本进行增强操作,得到增强后训练样本可以包括:当训练样本对应的增强概率大于预设概率阈值时,根据增强策略中各个增强操作对应的执行概率,确定增强策略中各个增强操作的执行优先级顺序;根据执行优先级顺序对训练样本进行平移、旋转、裁剪、翻转、镜像、缩小、放大、噪声叠加、颜色变换、以及亮度调节中的任一种或多种组合的增强操作,得到增强后训练样本。
可选地,在得到训练样本的增强概率后,可以判断训练样本对应的增强概率是否大于预设概率阈值,其中,预设概率阈值可以根据实际需要进行灵活设置。例如,如图5所示,当训练样本对应的增强概率大于预设概率阈值时,说明需要对训练样本进行增强操作,此时,可以对训练样本施加增强操作,例如,可以根据增强策略中各个增强操作对应的执行概率,确定增强策略中各个增强操作的执行优先级顺序。例如,增强操作的执行概率越大,增强操作的执行优先级越低,反之,增强操作的执行概率越小,增强操作的执行优先级越高。然后,可以根据执行优先级顺序对训练样本进行平移、旋转、裁剪、翻转、镜像、缩小、放大、噪声叠加、颜色变换、以及亮度调节中的任一种或多种组合的增强操作,得到增强后训练样本,以便后续可以通过增强后训练样本队模型进行训练。例如,当增强策略为依次执行旋转、裁剪、放大和亮度调节时,首先可以对训练样本进行旋转操作,得到旋转后训练样本,然后可以对旋转后训练样本进行裁剪操作,得到裁剪后训练样本,其次,可以对裁剪后训练样本进行放大操作,得到放大后训练样本,最后,可以对放大后训练样本进行亮度调节操作,得到增强后训练样本。当训练样本对应的增强概率小于或等于预设概率阈值时,说明不需要对训练样本进行增强操作,后续可以直接利用训练样本训练模型。
需要说明的是,在根据增强策略对训练样本进行增强操作过程中,可以是基于增强策略包含的各增强操作分别并行对训练样本进行增强操作,得到多个增强后训练样本。例如,当增强策略包括旋转、裁剪、缩小和噪声叠加时,可以分别对训练样本进行旋转、裁剪、缩小和噪声叠加操作,得到旋转后训练样本、裁剪后训练样本、缩小后训练样本和噪声叠加后训练样本等多个增强后训练样本。
S104、通过初始模型对增强后训练样本进行预测。
其中,初始模型的类型及结构等可以根据实际需要进行灵活设置,具体在此处不做限定。例如,初始模型可以是神经网络模型。
S105、基于对增强后训练样本的预测结果更新初始模型的参数,得到更新后模型。
在一实施方式中,通过初始模型基于增强后训练样本进行预测,基于对增强后训练样本的预测结果更新初始模型的参数,得到更新后模型可以包括:通过初始模型基于增强后训练样本进行预测,得到预测值;基于预测值与增强后训练样本对应的真实值获取损失值;基于损失值获取第一梯度值;根据第一梯度值更新初始模型的参数,得到更新后模型。
具体地,可以通过初始模型基于增强后训练样本进行预测,得到预测值,例如,当需要基于样本图像(即训练样本)对初始模型进行训练,以基于训练后模型准确检测图像中目标对象所在的位置时,可以通过初始模型基于增强后样本图像对目标对象所在的位置进行预测,得到预测位置。以及,可以获取增强后训练样本对应的真实值,该真实值可以是目标对象在增强后训练样本中的真实位置。然后可以获取初始模型基于增强后训练样本预测得到的预测值与增强后训练样本对应的真实值之间的损失值(也可以称为期望值),基于损失值确定第一梯度值,即通过初始模型基于增强后训练样本进行前向传播与反向传播操作以获取第一梯度值,当有多个增强后训练样本时,可以对应得到多个第一梯度值
Figure BDA0002926147060000101
此时可以根据第一梯度值更新初始模型的参数至合适数值,得到更新后模型。
S106、通过更新后模型对验证样本进行预测。
S107、基于对验证样本的预测结果更新增强策略,得到更新后增强策略。
在一实施方式中,通过更新后模型对验证样本进行预测,基于对验证样本的预测结果更新增强策略,得到更新后增强策略可以包括:通过更新后模型对验证样本进行预测,以基于预测结果获取第二梯度值;根据第二梯度值更新增强权重,得到更新后增强权重;根据更新后增强权重确定更新后增强策略。
具体地,可以通过更新后模型对验证样本进行预测,得到验证样本对应的预测值,例如,当验证样本为图像时,为了使得训练后模型能够准确检测图像中目标对象所在的位置,可以通过更新后模型基于验证样本对目标对象所在的位置进行预测,得到预测位置。
以及,可以获取验证样本对应的真实值,该真实值可以是目标对象在验证样本中的真实位置。然后可以获取更新后模型基于验证样本预测得到的预测值与验证样本对应的真实值之间的损失值,基于该损失值确定第二梯度值,即通过更新后模型基于验证样本进行前向传播与反向传播操作以获取第二梯度值
Figure BDA0002926147060000102
当有多个验证样本时,可以对应得到多个第二梯度值。此时,可以根据第二梯度值更新增强概率和执行概率等增强权重,得到更新后增强权重(例如更新后增强概率gptp和更新后执行概率
Figure BDA0002926147060000111
),具体更新方式可以根据实际需要进行灵活设置。最后,可以根据不同增强权重与增强策略之间的映射关系,确定与更新后增强权重匹配的增强策略,得到更新后增强策略。
为提高对模型参数更新的准确性和可靠性,可以对第一梯度值以及第二梯度值进行归一化处理,在一实施方式中,根据第一梯度值更新初始模型的参数,得到更新后模型可以包括:对第一梯度值进行归一化,得到归一化第一梯度值;根据归一化后第一梯度值更新初始模型的参数,得到更新后模型。
在一实施方式中,根据第二梯度值更新增强权重,得到更新后增强权重可以包括:对第二梯度值进行归一化,得到归一化第二梯度值;根据归一化第二梯度值更新增强权重,得到更新后增强权重。
例如,可以用增强策略对应的增强权重
Figure BDA0002926147060000112
作为权重对第一梯度值和第二梯度值进行归一化,使得计算得到的梯度在同一预设尺度上,增强权重可以包括增强概率ptp和执行概率po。具体归一化方式可以如下公式所示:
Figure BDA0002926147060000113
Figure BDA0002926147060000114
其中,
Figure BDA0002926147060000115
可以表示归一化后第一梯度,gptp可以表示归一化后第二梯度,
Figure BDA0002926147060000116
可以表示基于验证样本Vt得到的梯度值,
Figure BDA0002926147060000117
可以表示基于训练样本
Figure BDA0002926147060000118
得到的梯度值,
Figure BDA0002926147060000119
可以表示基于增强后训练样本
Figure BDA00029261470600001110
得到的梯度值,ptp可以表示训练样本的增强概率,
Figure BDA00029261470600001111
可以表示增强策略对应的增强权重,η可以表示系数,其具体取值可以根据实际需要进行灵活设置。
其中Z与Zg为归一化系数,其计算方式可以如下公式所示:
Figure BDA00029261470600001112
Figure BDA00029261470600001113
基于以上归一化后第一梯度和归一化后第二梯度,可以对数据增强参数进行搜索优化,例如,可以根据归一化后第一梯度值更新模型的参数,得到更新后模型,以便通过更新后模型对验证样本进行预测,以及根据归一化后第二梯度值更新增强权重,得到更新后增强权重,以便根据更新后增强权重确定更新后增强策略,根据更新后增强策略对更新后模型进行迭代训练。
S108、根据更新后增强策略对更新后模型进行迭代训练,得到训练后模型。
例如,如图6所示,在按照上述方式选取训练样本及确定训练样本对应的增强策略后,可以对训练样本施加增强策略,生成增强后训练样本,然后可以基于增强后训练样本计算第一梯度来更新模型参数,得到更新后模型,以及通过更新后模型基于选取的验证样本计算第二梯度值,并基于第二梯度值更新增强策略,得到更新后增强策略,基于更新后增强策略对更新后模型进行迭代训练,得到训练后模型。
在一实施方式中,训练样本集为全量训练集中的部分样本集,根据更新后增强策略对更新后模型进行迭代训练,得到训练后模型可以包括:根据更新后增强策略对更新后模型进行迭代训练,得到最终的增强策略;采用最终的增强策略对全量训练集进行增强操作;基于增强后的全量训练集对更新后模型进行迭代训练,得到训练后模型。
其中,全量训练集包括多个训练样本,训练样本集是全量训练集的一部分,可以按照上述迭代训练方式基于更新后增强策略对更新后模型进行迭代训练,从而可以在训练样本集迭代训练的基础上得到最终的增强策略,将最终的增强策略映射到全量训练集来进行增强操作,根据增强后的全量训练集对更新后模型进行迭代训练,得到训练后模型,这样可以提升模型训练的精准性。
在一实施方式中,采用最终的增强策略对全量训练集进行增强操作可以包括:将全量训练集中的每一样本均采用最终的增强策略进行增强操作;或者,根据全量训练集和代理训练集的样本数比例,将每轮迭代后得到的更新后增强策略按比例映射到对应的全量训练集中。
其中,最终的增强策略可以是训练样本集上训练终止时得到的策略,然后全量训练集都采用该策略,即将全量训练集中的每一样本均采用最终的增强策略进行增强操作。或者,可以获取全量训练集的样本数和代理训练集的样本数,根据全量训练集的样本数和代理训练集的样本数,确定全量训练集和代理训练集的样本数比例(例如,1:s),将第一批训练样本集所采用的增强策略映射到前s个全量样本中,第二批训练样本集所采用的增强策略映射到第二组全量样本(s~2s)中,依次类推,实现将每轮迭代后得到的更新后增强策略按比例映射到对应的全量训练集中,提高了模型训练的灵活性和可靠性。
在一实施方式中,训练样本包括第一训练样本和第二训练样本,根据增强策略对训练样本进行增强操作,得到增强后训练样本可以包括:根据增强策略对第一训练样本进行增强操作,得到增强后训练样本。根据更新后增强策略对更新后模型进行迭代训练,得到训练后模型可以包括:根据更新后增强策略对第二训练样本进行增强操作,得到目标增强后训练样本;将目标增强后训练样本作为增强后训练样本,以及将更新后模型作为初始模型,并返回执行通过初始模型基于增强后训练样本进行预测的操作,直至迭代训练次数达到预设次数或者更新后模型预测得到的损失值小于预设阈值,得到训练后模型。
具体地,训练样本可以包括多组训练样本,例如,第一训练样本、第二训练样本、......、以及第n训练样本等。在上述根据增强策略对训练样本进行增强操作的过程中,可以是根据增强策略对第一训练样本进行增强操作,得到增强后训练样本。此时,在根据更新后增强策略对更新后模型进行迭代训练的过程中,可以根据更新后增强策略对第二训练样本进行增强操作,得到目标增强后训练样本,将目标增强后训练样本作为增强后训练样本,并将更新后模型作为初始模型,返回执行通过初始模型基于增强后训练样本进行预测的操作,并基于对增强后训练样本的预测结果更新初始模型的参数,得到更新后模型,通过更新后模型对验证样本进行预测,并基于对验证样本的预测结果更新增强策略,得到更新后增强策略,根据更新后增强策略对更新后模型进行迭代训练的操作,直至迭代训练次数达到预设次数或者更新后模型预测得到的损失值小于预设阈值(即更新后模型收敛),得到训练后模型。其中,预设次数或预设阈值可以根据实际需要进行灵活设置,具体取值在此处不做限定。
例如,本实施例中训练模型的目的是使得模型的参数收敛:
Figure BDA0002926147060000131
Figure BDA0002926147060000132
其中,
Figure BDA0002926147060000133
可以表示模型的参数,
Figure BDA0002926147060000134
可以表示验证样本的损失,
Figure BDA0002926147060000135
可以表示训练样本的损失,Ep可以表示期望值。
其中,期望值Ep的计算方式可以如下所示:
Figure BDA0002926147060000136
其中,第l个增强后训练样本的概率
Figure BDA0002926147060000137
的计算方式可以如下所示:
Figure BDA0002926147060000141
其中,训练样本的概率
Figure BDA0002926147060000142
的计算方式可以如下所示:
Figure BDA0002926147060000143
其中,模型参数的迭代更新方式可以如下所示:
θt+1=θt-η·gt
其中,梯度值gt的计算方式可以如下所示:
Figure BDA0002926147060000144
其中,
Figure BDA0002926147060000145
可以表示训练样本,
Figure BDA0002926147060000146
可以表示增强后训练样本,θt可以表示模型的参数,
Figure BDA0002926147060000147
可以表示基于增强后训练样本
Figure BDA0002926147060000148
得到的梯度值,
Figure BDA0002926147060000149
可以表示基于训练样本
Figure BDA00029261470600001410
得到的梯度值。在搜索过程中,可以通过对训练过程中的损失函数取期望,从而将概率参数转化为损失函数中的权重,进而实现可微分化近似。
在一实施方式中,训练后模型包括训练后目标检测模型或训练后图像分类模型,模型训练方法还可以包括:获取待检测的图像;通过训练后目标检测模型或训练后图像分类模型提取图像的特征信息,并基于特征信息对图像进行目标检测或者图像分类。
例如,当训练样本为样本图像时,可以基于样本图像按照上述训练方式对初始目标检测模型进行训练,并基于训练后目标检测模型准确检测图像中目标对象所在的位置。其中,目标检测模型的类型和结构等可以根据实际需要进行灵活设置,待检测的图像可以是通过激光雷达采集到的深度图像,或者是,待检测的图像可以是通过手机、摄像头或相机等采集得到的图像。
又例如,当训练后模型为训练后图像分类模型时,训练样本可以为样本图像,可以基于样本图像按照上述训练方式对初始图像分类模型进行训练,并基于训练后图像分类模型准确对图像进行分类。其中,图像分类模型的类型和结构等可以根据实际需要进行灵活设置。首先可以通过激光雷达、手机、摄像头或相机等采集待检测的图像,然后可以通过训练后图像分类模型提取图像的特征信息,例如可以通过训练后图像分类模型的卷积算子提取待检测的图像对应的特征信息(例如特征图)然后,可以基于该特征信息对待检测的图像进行分类,该以确定该图像所属的类别,提高了对图像分类的准确性和便捷性。
本申请实施例可以基于训练样本集中训练样本对应的增强策略对训练样本进行增强操作,得到增强后训练样本,以及通过初始模型基于增强后训练样本进行预测,并基于对增强后训练样本的预测结果更新初始模型的参数,得到更新后模型;然后可以通过更新后模型对训练样本集中验证样本进行预测,并基于对验证样本的预测结果自动更新增强策略,得到更新后增强策略;此时可以根据更新后增强策略对更新后模型进行迭代训练,得到训练后模型。该方案通过增强策略快速增强训练样本,基于增强后训练样本更新初始模型的参数,并通过更新后模型基于验证样本自动更新增强策略,以对模型进行迭代训练,提高了对训练样本增强的效率,以及提高了对模型训练的准确性和精度。
根据上述实施例所描述的方法,以下将说明本申请实施例在一个实际的应用场景中的示例性应用,例如,无人驾驶车辆的应用场景。本实施例以模型训练装置集成在无人驾驶车辆为例、以初始模型为初始目标检测模型、训练后模型为训练后目标检测模型为例进行详细说明,其中,无人驾驶车辆上可以设置有用于采集图像的激光雷达或摄像头等,以及设置有目标检测模型等。
请参阅图7,图7是本申请一实施例提供的模型训练方法的流程示意图。该模型训练方法可以包括:
S201、获取样本图像集,样本图像集包括训练样本图像和验证样本图像。
无人驾驶车辆可以从本地存储数据库中获取样本图像集,或从服务器上下载样本图像集,样本图像集中可以包括多张图像,可以从样本图像集中筛选出多张图像作为代理训练集以及筛选出多张图像作为代理验证集,可以从代理训练集中选择一组或多组样本图像作为训练样本图像,每组训练样本图像中至少包括一张样本图像,以及从代理验证集选择一组或者多组样本图像作为验证样本图像,每组验证样本图像中至少包括一张样本图像。
其中,样本图像集中的样本图像可以包括目标对象以及目标表对象对应的真实值,该真实值可以包括目标对象在样本图像中的位置,还没有补考目标对象的类别等,该目标对象可以包括车辆、人、建筑物、动物或植物等。
S202、获取训练样本图像对应的增强权重,根据增强权重确定训练样本图像对应的增强策略。
其中,增强权重可以包括增强概率ptp和执行概率po等,例如,可以预先为训练样本图像设置初始的增强概率,以及预先为不同增强操作设置初始的执行概率,开始训练时,可以获取训练样本图像初始的增强概率,以及获取不同增强操作初始的执行概率,在训练过程中,可以调整训练样本图像的增强概率,以及调整不同增强操作的执行概率,此时可以获取得到的调整后的训练样本图像的增强概率,以及调整后的不同增强操作的执行概率。
其中,增强操作可以包括平移、旋转、裁剪、翻转、镜像、缩小、放大、噪声叠加、颜色变换、以及亮度调节等,增强策略可以包括平移、旋转、裁剪、翻转、镜像、缩小、放大、噪声叠加、颜色变换、以及亮度调节等增强操作中的任一种或多种组合,以及各增强操作的执行优先级等。例如,增强策略A可以包括增强操作1、增强操作2和增强操作3,以及增强操作1、增强操作2和增强操作3之间串联执行的执行优先级顺序:依次对训练样本图像执行增强操作3、增强操作2和增强操作4得到增强后训练样本图像。
在得到训练样本图像对应的增强概率和增强操作对应的执行概率后,无人驾驶车辆可以根据增强概率和执行概率,确定训练样本图像对应的增强权重,例如,训练样本对应的增强权重=增强概率ptp*执行概率po;或者,训练样本对应的增强权重=第一系数*增强概率ptp+第二系数*执行概率po
此时,无人驾驶车辆可以根据增强权重确定训练样本图像对应的增强策略,例如,可以获取预先建立的各个增强策略与增强权重之间的映射关系,根据该映射关系确定与训练样本图像对应的增强权重匹配的增强策略,即可得到训练样本图像对应的增强策略。
S203、根据增强策略对训练样本图像进行增强操作,得到增强后训练样本图像。
在得到增强策略后,无人驾驶车辆可以基于增强策略对训练样本图像进行增强操作,例如,可以在基于训练样本图像的增强概率ptp确定需要进行增强操作后,可以基于增强概率ptp和执行概率po确定增强策略中包含的增强操作:O1,O2,......,ONo,并对训练样本图像进行增强操作,得到增强后训练样本图像。
具体地,在得到训练样本图像的增强概率后,无人驾驶车辆可以判断训练样本图像对应的增强概率是否大于预设概率阈值,当训练样本图像对应的增强概率大于预设概率阈值时,说明需要对训练样本图像进行增强操作,此时,可以根据增强策略中各个增强操作对应的执行优先级顺序,对训练样本图像进行平移、旋转、裁剪、翻转、镜像、缩小、放大、噪声叠加、颜色变换、以及亮度调节中的任一种或多种组合的增强操作,得到增强后训练样本图像。
例如,当增强策略为依次执行旋转、裁剪、放大和亮度调节时,首先可以对训练样本图像进行旋转操作,得到旋转后训练样本图像,然后可以对旋转后训练样本图像进行裁剪操作,得到裁剪后训练样本图像,其次,可以对裁剪后训练样本图像进行放大操作,得到放大后训练样本图像,最后,可以对放大后训练样本图像进行亮度调节操作,得到增强后训练样本图像。当训练样本图像对应的增强概率小于或等于预设概率阈值时,说明不需要对训练样本图像进行增强操作,后续可以直接利用训练样本图像训练模型。
S204、通过初始目标检测模型基于增强后训练样本图像进行预测,并基于对增强后训练样本图像的预测结果更新初始目标检测模型的参数,得到更新后目标检测模型。
例如,无人驾驶车辆可以通过初始目标检测模型基于增强后训练样本图像进行预测,得到预测值(即预测位置),基于预测值与增强后训练样本图像对应的真实值获取损失值,基于损失值获取第一梯度值,根据第一梯度值更新初始目标检测模型的参数,得到更新后目标检测模型。
S205、通过更新后目标检测模型对验证样本图像进行预测,并基于对验证样本图像的预测结果更新增强策略,得到更新后增强策略。
例如,无人驾驶车辆可以通过更新后目标检测模型对验证样本图像进行预测,以基于预测结果获取第二梯度值,例如,可以获取更新后目标检测模型基于验证样本图像预测得到的预测值与验证样本图像对应的真实值之间的损失值,根据该损失值确定第二梯度值,根据第二梯度值更新增强权重,得到更新后增强权重,根据更新后增强权重确定更新后增强策略。
S206、根据更新后增强策略对更新后目标检测模型进行迭代训练,得到训练后目标检测模型。
无人驾驶车辆可以根据更新后增强策略对第二训练样本图像进行增强操作,得到目标增强后训练样本图像,将目标增强后训练样本图像作为增强后训练样本图像,并将更新后目标检测模型作为初始目标检测模型,返回执行通过初始目标检测模型基于增强后训练样本图像进行预测的操作,直至迭代训练次数达到预设次数或者更新后目标检测模型预测收敛,得到训练后模型。
其中,训练样本图像可以包括多组训练样本图像,例如,第一训练样本图像、第二训练样本图像、......、以及第n训练样本图像等。在上述根据增强策略对训练样本图像进行增强操作的过程中,可以是根据增强策略对第一训练样本图像进行增强操作,得到增强后训练样本图像。在根据更新后增强策略对训练样本图像进行增强操作时,可以是根据更新后增强策略对第二训练样本图像进行增强操作,得到目标增强后训练样本图像。
S207、获取待检测的图像,通过训练后目标检测模型对图像进行目标检测。
其中,待检测的图像可以是通过无人驾驶车辆上预设的激光雷达采集到的深度图像,或者是,待检测的图像可以是通过无人驾驶车辆上预设的摄像头或相机等采集得到的图像。
在得到待检测的图像后,可以通过训练后目标检测模型提取图像的特征信息,例如可以通过训练后目标检测模型的卷积算子提取待检测的图像对应的特征信息(例如特征图),具体地,可以通过卷积算子从待检测的图像中采样预设邻域范围内的区域,得到滑窗区域,获取滑窗区域内采样基准点与邻域点之间的相对关系(例如相对位置关系),基于相对关系获取滑窗区域内各采样点的权重向量,提取滑窗区域内各采样点的特征向量,根据滑窗区域内各采样点的权重向量和特征向量,获取待检测的图像对应的特征信息。
然后,可以基于该特征信息对待检测的图像进行检测,例如,可以根据待检测的图像中各位置点的属性信息对特征信息(例如特征图)进行检测,生成各特征图所对应的预设属性区间内的多个检测框、以及各检测框的分类置信度,该分类置信度可以是检测框与对应的同类别真实框的重合度。之后,可以采用加权非极大值抑制多个检测框进行合并,得到待检测深度图像对应的目标框。
在上述实施例中,对各个实施例的描述都各有侧重,某个实施例中没有详述的部分,可以参见上文针对模型训练方法的详细描述,此处不再赘述。
本申请实施例可以获取样本图像集包括训练样本图像和验证样本图像的样本图像集,以及获取训练样本图像对应的增强权重,根据增强权重确定训练样本图像对应的增强策略。然后可以根据增强策略对训练样本图像进行增强操作,得到增强后训练样本图像,提高了对训练样本图像增强的效率,以及通过初始目标检测模型基于增强后训练样本图像进行预测,并基于对增强后训练样本图像的预测结果更新初始目标检测模型的参数,得到更新后目标检测模型。其次可以通过更新后目标检测模型对验证样本图像进行预测,并基于对验证样本图像的预测结果更新增强策略,得到更新后增强策略,根据更新后增强策略对更新后目标检测模型进行迭代训练,得到训练后目标检测模型,提高了对目标检测模型训练的准确性和精度;此时可以获取待检测的图像,通过训练后目标检测模型对图像进行目标检测,提高了对目标检测的准确性。
在本实施例中,将从目标检测装置的角度进行描述,该目标检测装置具体可以集成在终端或服务器等计算机设备中。
请参阅图8,图8是本申请一实施例提供的目标检测方法的流程示意图。该目标检测方法可以包括:
S301、获取待检测的图像。
S302、通过训练后目标检测模型提取图像的特征信息,训练后目标检测模型为基于更新后增强策略进行迭代训练得到,更新后增强策略为根据更新后模型基于验证样本的预测结果对增强策略进行更新得到,更新后模型为根据增强后训练样本的预测结果对初始模型的参数进行更新得到,增强后训练样本为通过增强策略对训练样本进行增强操作得到。该目标检测模型的详细训练过程已基于图1-图7的描述中详细公开,这里不再展开赘述。
S303、通过训练后目标检测模型基于特征信息对图像进行目标检测。
在上述实施例中,对各个实施例的描述都各有侧重,某个实施例中没有详述的部分,可以参见上文针对模型训练方法的详细描述,此处不再赘述。
为便于更好的实施本申请实施例提供的模型训练方法,本申请实施例还提供一种基于上述模型训练方法的装置。其中名词的含义与上述模型训练方法中相同,具体实现细节可以参考方法实施例中的说明。
请参阅图9,图9为本申请实施例提供的模型训练装置400的结构示意图,其中该模型训练装置400可以包括第一获取模块401、第二获取模块402、增强模块403、第一预测模块404、第一更新模块405、第二预测模块406、第二更新模块407以及训练模块408等。
其中,第一获取模块401,用于获取训练样本集,训练样本集包括训练样本和验证样本。
第二获取模块402,用于获取训练样本对应的增强策略。
增强模块403,用于根据增强策略对训练样本进行增强操作,得到增强后训练样本。
第一预测模块404,用于通过初始模型对增强后训练样本进行预测。
第一更新模块405,用于基于对增强后训练样本的预测结果更新初始模型的参数,得到更新后模型。
第二预测模块406,用于通过更新后模型对验证样本进行预测;
第二更新模块407,用于基于对验证样本的预测结果更新增强策略,得到更新后增强策略。
训练模块408,用于根据更新后增强策略对更新后模型进行迭代训练,得到训练后模型。
在一实施方式中,训练样本集为全量训练集中的部分样本集,训练模块408包括:
第一训练子模块,用于根据更新后增强策略对更新后模型进行迭代训练,得到最终的增强策略;
增强子模块,用于采用最终的增强策略对全量训练集进行增强操作;
第二训练子模块,用于基于增强后的全量训练集对更新后模型进行迭代训练,得到训练后模型。
在一实施方式中,增强子模块具体可以用于:将全量训练集中的每一样本均采用最终的增强策略进行增强操作;或者,根据全量训练集和代理训练集的样本数比例,将每轮迭代后得到的更新后增强策略按比例映射到对应的全量训练集中。
在一实施方式中,第二获取模块402包括:
获取子模块,用于获取增强权重;
确定子模块,用于根据增强权重确定训练样本对应的增强策略。
在一实施方式中,确定子模块具体可以用于:根据增强权重采样多个增强策略;
增强模块403具体可以用于:将采样得到的多个增强策略作用到不同的训练样本上,得到增强后的训练样本。
在一实施方式中,获取子模块具体可以用于:获取当前的增强概率,以及获取不同增强操作对应的执行概率;根据增强概率和执行概率,确定增强策略对应的增强权重。
在一实施方式中,增强概率的初始值为设定值;若增强操作库中包括K个操作,则每个增强操作所对应的执行概率的初始值为1/K;增强概率和执行概率在每轮迭代后被更新。
在一实施方式中,获取子模块具体可以用于:将增强概率和增强策略所对应的多个增强操作的执行概率进行连乘操作,得到增强策略对应的增强权重。
在一实施方式中,模型训练装置400可以包括:
筛选模块,用于从多个增强操作中筛选出满足条件的多个候选增强操作;
划分模块,用于将多个候选增强操作划分为多个增强操作组;
生成模块,用于基于每个增强操作组以及每个增强操作组中各增强操作的执行顺序生成多个增强策略。
在一实施方式中,增强策略包括平移、旋转、裁剪、翻转、镜像、缩小、放大、噪声叠加、颜色变换、以及亮度调节中的任一种或多种组合的增强操作。
在一实施方式中,第一预测模块404具体可以用于:通过初始模型基于增强后训练样本进行预测,得到预测值。
第一更新模块405具体可以用于:基于预测值与增强后训练样本对应的真实值获取损失值;基于损失值获取第一梯度值;根据第一梯度值更新初始模型的参数,得到更新后模型。
在一实施方式中,第一更新模块405具体可以用于:对第一梯度值进行归一化,得到归一化第一梯度值;根据归一化后第一梯度值更新初始模型的参数,得到更新后模型。
在一实施方式中,第二预测模块406具体可以用于:通过更新后模型对验证样本进行预测,以基于预测结果获取第二梯度值;
第二更新模块407具体可以用于:根据第二梯度值更新增强权重,得到更新后增强权重;根据更新后增强权重确定更新后增强策略。
在一实施方式中,第二更新模块407具体可以用于:对第二梯度值进行归一化,得到归一化第二梯度值;根据归一化第二梯度值更新增强权重,得到更新后增强权重。
在一实施方式中,训练样本包括第一训练样本和第二训练样本,增强模块403具体可以用于:根据增强策略对第一训练样本进行增强操作,得到增强后训练样本;
训练模块408具体可以用于:根据更新后增强策略对第二训练样本进行增强操作,得到目标增强后训练样本;将目标增强后训练样本作为增强后训练样本,以及将更新后模型作为初始模型,并返回执行通过初始模型基于增强后训练样本进行预测的操作,直至迭代训练次数达到预设次数或者更新后模型预测得到的损失值小于预设阈值,得到训练后模型。
在一实施方式中,模型训练装置400还可以包括:
检测模块,用于获取待检测的图像,通过训练后目标检测模型或训练后图像分类模型提取图像的特征信息,并基于特征信息对图像进行目标检测或者图像分类。
本申请实施例可以由第一获取模块401获取训练样本集,训练样本集包括训练样本和验证样本,以及由第二获取模块402获取训练样本对应的增强策略,由增强模块403基于训练样本集中训练样本对应的增强策略对训练样本进行增强操作,得到增强后训练样本,以及由第一预测模块404通过初始模型基于增强后训练样本进行预测,并由第一更新模块405基于对增强后训练样本的预测结果更新初始模型的参数,得到更新后模型;然后可以由第二预测模块406通过更新后模型对训练样本集中验证样本进行预测,并由第二更新模块407基于对验证样本的预测结果自动更新增强策略,得到更新后增强策略;此时可以由训练模块408根据更新后增强策略对更新后模型进行迭代训练,得到训练后模型。该方案通过增强策略快速增强训练样本,基于增强后训练样本更新初始模型的参数,并通过更新后模型基于验证样本自动更新增强策略,以对模型进行迭代训练,提高了对训练样本增强的效率,以及提高了对模型训练的准确性和精度。
为便于更好的实施本申请实施例提供的目标检测方法,本申请实施例还提供一种基于上述目标检测方法的装置。其中名词的含义与上述目标检测方法中相同,具体实现细节可以参考方法实施例中的说明。
请参阅图10,图10为本申请实施例提供的目标检测装置500的结构示意图,其中该目标检测装置500可以包括图像获取模块501、提取模块502以及检测模块503等。
其中,图像获取模块501,用于获取待检测的图像。
提取模块502,用于通过训练后目标检测模型提取图像的特征信息,训练后目标检测模型为基于更新后增强策略进行迭代训练得到,更新后增强策略为根据更新后模型基于验证样本的预测结果对增强策略进行更新得到,更新后模型为根据增强后训练样本的预测结果对初始模型的参数进行更新得到,增强后训练样本为通过增强策略对训练样本进行增强操作得到。
检测模块503,用于通过训练后目标检测模型基于特征信息对图像进行目标检测。
本申请实施例还提供一种计算机设备,该计算机设备可以是终端或服务器等,如图11所示,其示出了本申请实施例所涉及的计算机设备的结构示意图,具体来讲:
该计算机设备可以包括一个或者一个以上处理核心的处理器601、一个或一个以上计算机可读存储介质的存储器602、电源603和输入单元604等部件。本领域技术人员可以理解,图11中示出的计算机设备结构并不构成对计算机设备的限定,可以包括比图示更多或更少的部件,或者组合某些部件,或者不同的部件布置。其中:
处理器601是该计算机设备的控制中心,利用各种接口和线路连接整个计算机设备的各个部分,通过运行或执行存储在存储器602内的软件程序和/或模块,以及调用存储在存储器602内的数据,执行计算机设备的各种功能和处理数据,从而对计算机设备进行整体监控。可选的,处理器601可包括一个或多个处理核心;优选的,处理器601可集成应用处理器和调制解调处理器,其中,应用处理器主要处理操作系统、用户界面和应用程序等,调制解调处理器主要处理无线通信。可以理解的是,上述调制解调处理器也可以不集成到处理器601中。
存储器602可用于存储软件程序以及模块,处理器601通过运行存储在存储器602的软件程序以及模块,从而执行各种功能应用以及数据处理。存储器602可主要包括存储程序区和存储数据区,其中,存储程序区可存储操作系统、至少一个功能所需的应用程序(比如声音播放功能、图像播放功能等)等;存储数据区可存储根据计算机设备的使用所创建的数据等。此外,存储器602可以包括高速随机存取存储器,还可以包括非易失性存储器,例如至少一个磁盘存储器件、闪存器件、或其他易失性固态存储器件。相应地,存储器602还可以包括存储器控制器,以提供处理器601对存储器602的访问。
计算机设备还包括给各个部件供电的电源603,优选的,电源603可以通过电源管理系统与处理器601逻辑相连,从而通过电源管理系统实现管理充电、放电、以及功耗管理等功能。电源603还可以包括一个或一个以上的直流或交流电源、再充电系统、电源故障检测电路、电源转换器或者逆变器、电源状态指示器等任意组件。
该计算机设备还可包括输入单元604,该输入单元604可用于接收输入的数字或字符信息,以及产生与用户设置以及功能控制有关的键盘、鼠标、操作杆、光学或者轨迹球信号输入。
尽管未示出,计算机设备还可以包括显示单元等,在此不再赘述。具体在本实施例中,计算机设备中的处理器601会按照如下的指令,将一个或一个以上的应用程序的进程对应的可执行文件加载到存储器602中,并由处理器601来运行存储在存储器602中的应用程序,从而实现各种功能,如下:
获取训练样本集,训练样本集包括训练样本和验证样本,以及获取训练样本对应的增强策略;根据增强策略对训练样本进行增强操作,得到增强后训练样本;通过初始模型基于增强后训练样本进行预测,基于对增强后训练样本的预测结果更新初始模型的参数,得到更新后模型;通过更新后模型对验证样本进行预测,基于对验证样本的预测结果更新增强策略,得到更新后增强策略;根据更新后增强策略对更新后模型进行迭代训练,得到训练后模型。
在上述实施例中,对各个实施例的描述都各有侧重,某个实施例中没有详述的部分,可以参见上文针对模型训练方法的详细描述,此处不再赘述。
根据本申请的一个方面,提供了一种计算机程序产品或计算机程序,该计算机程序产品或计算机程序包括计算机指令,该计算机指令存储在计算机可读存储介质中。计算机设备的处理器从计算机可读存储介质读取该计算机指令,处理器执行该计算机指令,使得该计算机设备执行上述实施例中各种可选实现方式中提供的方法。
本领域普通技术人员可以理解,上述实施例的各种方法中的全部或部分步骤可以通过计算机指令来完成,或通过计算机指令控制相关的硬件来完成,该计算机指令可以存储于一计算机可读存储介质中,并由处理器进行加载和执行。为此,本申请实施例提供一种计算机可读存储介质,其中存储有计算机程序,计算机程序可以包括计算机指令,该计算机程序能够被处理器进行加载,以执行本申请实施例所提供的任一种模型训练方法。
以上各个操作的具体实施可参见前面的实施例,在此不再赘述。
其中,该计算机可读存储介质可以包括:只读存储器(ROM,Read Only Memory)、随机存取记忆体(RAM,Random Access Memory)、磁盘或光盘等。
由于该计算机可读存储介质中所存储的指令,可以执行本申请实施例所提供的任一种模型训练方法,以及执行本申请实施例所提供的任一种模型训练方法,因此,可以实现本申请实施例所提供的任一种模型训练方法所能实现的有益效果,详见前面的实施例,在此不再赘述。
以上对本申请实施例所提供的一种模型训练方法、模型训练方法、模型训练装置、模型训练装置、计算机设备及计算机可读存储介质进行了详细介绍,本文中应用了具体个例对本申请的原理及实施方式进行了阐述,以上实施例的说明只是用于帮助理解本申请的方法及其核心思想;同时,对于本领域的技术人员,依据本申请的思想,在具体实施方式及应用范围上均会有改变之处,综上所述,本说明书内容不应理解为对本申请的限制。

Claims (21)

1.一种模型训练方法,其特征在于,包括:
获取训练样本集,所述训练样本集包括训练样本和验证样本;
获取所述训练样本对应的增强策略;
根据所述增强策略对所述训练样本进行增强操作,得到增强后训练样本;
通过初始模型对所述增强后训练样本进行预测;
基于对所述增强后训练样本的预测结果更新所述初始模型的参数,得到更新后模型;
通过所述更新后模型对所述验证样本进行预测;
基于对所述验证样本的预测结果更新所述增强策略,得到更新后增强策略;以及
根据所述更新后增强策略对所述更新后模型进行迭代训练,得到训练后模型。
2.根据权利要求1所述的方法,其特征在于,所述训练样本集为全量训练集中的部分样本集,所述根据所述更新后增强策略对所述更新后模型进行迭代训练,得到训练后模型,包括:
根据所述更新后增强策略对所述更新后模型进行迭代训练,得到最终的增强策略;
采用所述最终的增强策略对所述全量训练集进行增强操作;
基于增强后的全量训练集对所述更新后模型进行迭代训练,得到所述训练后模型。
3.根据权利要求2所述的方法,其特征在于,所述采用所述最终的增强策略对全量训练集进行增强操作,包括:
将所述全量训练集中的每一样本均采用所述最终的增强策略进行增强操作;或者
根据所述全量训练集和代理训练集的样本数比例,将每轮迭代后得到的更新后增强策略按比例映射到对应的全量训练集中。
4.根据权利要求1所述的方法,其特征在于,所述获取所述训练样本对应的增强策略,包括:
获取增强权重,根据所述增强权重确定所述训练样本对应的增强策略。
5.根据权利要求4所述的方法,其特征在于,所述根据所述增强权重确定所述训练样本对应的增强策略,包括:
根据增强权重采样多个增强策略;
所述根据所述增强策略对所述训练样本进行增强操作,得到增强后训练样本包括:
将采样得到的多个增强策略作用到不同的训练样本上,得到增强后的训练样本。
6.根据权利要求4所述的方法,其特征在于,所述获取增强权重,包括:
获取当前的增强概率,以及获取不同增强操作对应的执行概率;
根据所述增强概率和所述执行概率,确定所述增强策略对应的增强权重。
7.根据权利要求6所述的方法,其特征在于,所述增强概率的初始值为设定值;若增强操作库中包括K个增强操作,则每个增强操作所对应的执行概率的初始值为1/K;所述增强概率和执行概率在每轮迭代后被更新。
8.根据权利要求6所述的方法,其特征在于,所述根据所述增强概率和所述执行概率,确定所述增强策略对应的增强权重,包括:
将所述增强概率和所述增强策略所对应的多个增强操作的执行概率进行连乘操作,得到所述增强策略对应的增强权重。
9.根据权利要求5所述的方法,其特征在于,所述方法还包括:
从多个增强操作中筛选出满足条件的多个候选增强操作;
将所述多个候选增强操作划分为多个增强操作组;
基于每个增强操作组以及每个增强操作组中各增强操作的执行顺序生成多个增强策略。
10.根据权利要求1所述的方法,其特征在于,所述增强策略包括平移、旋转、裁剪、翻转、镜像、缩小、放大、噪声叠加、颜色变换、以及亮度调节中的任一种或多种组合的增强操作。
11.根据权利要求1所述的方法,其特征在于,所述通过初始模型基于所述增强后训练样本进行预测,基于对所述增强后训练样本的预测结果更新所述初始模型的参数,得到更新后模型,包括:
通过初始模型基于所述增强后训练样本进行预测,得到预测值;
基于所述预测值与所述增强后训练样本对应的真实值获取损失值;
基于所述损失值获取第一梯度值;
根据所述第一梯度值更新所述初始模型的参数,得到更新后模型。
12.根据权利要求11所述的方法,其特征在于,所述根据所述第一梯度值更新所述初始模型的参数,得到更新后模型,包括:
对所述第一梯度值进行归一化,得到归一化第一梯度值;
根据所述归一化后第一梯度值更新所述初始模型的参数,得到更新后模型。
13.根据权利要求1所述的方法,其特征在于,所述通过所述更新后模型对所述验证样本进行预测,基于对所述验证样本的预测结果更新所述增强策略,得到更新后增强策略,包括:
通过所述更新后模型对所述验证样本进行预测,以基于预测结果获取第二梯度值;
根据所述第二梯度值更新所述增强权重,得到更新后增强权重;
根据所述更新后增强权重确定更新后增强策略。
14.根据权利要求13所述的模型训练方法,其特征在于,所述根据所述第二梯度值更新所述增强权重,得到更新后增强权重,包括:
对所述第二梯度值进行归一化,得到归一化第二梯度值;
根据归一化第二梯度值更新所述增强权重,得到更新后增强权重。
15.根据权利要求1所述的方法,其特征在于,所述训练样本包括第一训练样本和第二训练样本,所述根据所述增强策略对所述训练样本进行增强操作,得到增强后训练样本包括:根据所述增强策略对所述第一训练样本进行增强操作,得到增强后训练样本;
所述根据所述更新后增强策略对所述更新后模型进行迭代训练,得到训练后模型包括:
根据所述更新后增强策略对所述第二训练样本进行增强操作,得到目标增强后训练样本;
将所述目标增强后训练样本作为增强后训练样本,以及将所述更新后模型作为初始模型,并返回执行通过初始模型基于所述增强后训练样本进行预测的操作,直至迭代训练次数达到预设次数或者更新后模型预测得到的损失值小于预设阈值,得到训练后模型。
16.根据权利要求1至15任一项所述的方法,其特征在于,所述训练后模型包括训练后目标检测模型或训练后图像分类模型,所述方法还包括:
获取待检测的图像;
通过所述训练后目标检测模型或训练后图像分类模型提取所述图像的特征信息,并基于所述特征信息对所述图像进行目标检测或者图像分类。
17.一种目标检测方法,其特征在于,包括:
获取待检测的图像;
通过训练后目标检测模型提取所述图像的特征信息,所述训练后目标检测模型为基于更新后增强策略进行迭代训练得到,所述更新后增强策略为根据更新后模型基于验证样本的预测结果对增强策略进行更新得到,所述更新后模型为根据增强后训练样本的预测结果对初始模型的参数进行更新得到,所述增强后训练样本为通过增强策略对训练样本进行增强操作得到;以及
通过所述训练后目标检测模型基于所述特征信息对所述图像进行目标检测。
18.一种模型训练装置,其特征在于,包括:
第一获取模块,用于获取训练样本集,所述训练样本集包括训练样本和验证样本;
第二获取模块,用于获取所述训练样本对应的增强策略;
增强模块,用于根据所述增强策略对所述训练样本进行增强操作,得到增强后训练样本;
第一预测模块,用于通过初始模型对所述增强后训练样本进行预测;
第一更新模块,用于基于对所述增强后训练样本的预测结果更新所述初始模型的参数,得到更新后模型;
第二预测模块,用于通过所述更新后模型对所述验证样本进行预测;
第二更新模块,用于基于对所述验证样本的预测结果更新所述增强策略,得到更新后增强策略;以及
训练模块,用于根据所述更新后增强策略对所述更新后模型进行迭代训练,得到训练后模型。
19.一种目标检测装置,其特征在于,包括:
图像获取模块,用于获取待检测的图像;
提取模块,用于通过训练后目标检测模型提取所述图像的特征信息,所述训练后目标检测模型为基于更新后增强策略进行迭代训练得到,所述更新后增强策略为根据更新后模型基于验证样本的预测结果对增强策略进行更新得到,所述更新后模型为根据增强后训练样本的预测结果对初始模型的参数进行更新得到,所述增强后训练样本为通过增强策略对训练样本进行增强操作得到;
检测模块,用于通过所述训练后目标检测模型基于所述特征信息对所述图像进行目标检测。
20.一种计算机设备,其特征在于,包括处理器和存储器,所述存储器中存储有计算机程序,所述处理器调用所述存储器中的计算机程序时执行如权利要求1至16任一项所述的模型训练方法,或执行如权利要求17所述的目标检测方法。
21.一种计算机可读存储介质,其特征在于,所述计算机可读存储介质用于存储计算机程序,所述计算机程序被处理器加载以执行权利要求1至16任一项所述的模型训练方法,或执行权利要求17所述的目标检测方法。
CN202110133274.8A 2021-02-01 2021-02-01 一种模型训练方法、目标检测方法、以及相关设备 Pending CN114926701A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202110133274.8A CN114926701A (zh) 2021-02-01 2021-02-01 一种模型训练方法、目标检测方法、以及相关设备

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202110133274.8A CN114926701A (zh) 2021-02-01 2021-02-01 一种模型训练方法、目标检测方法、以及相关设备

Publications (1)

Publication Number Publication Date
CN114926701A true CN114926701A (zh) 2022-08-19

Family

ID=82804075

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202110133274.8A Pending CN114926701A (zh) 2021-02-01 2021-02-01 一种模型训练方法、目标检测方法、以及相关设备

Country Status (1)

Country Link
CN (1) CN114926701A (zh)

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN116416492A (zh) * 2023-03-20 2023-07-11 湖南大学 一种基于特征自适应的自动数据增广方法

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN116416492A (zh) * 2023-03-20 2023-07-11 湖南大学 一种基于特征自适应的自动数据增广方法
CN116416492B (zh) * 2023-03-20 2023-12-01 湖南大学 一种基于特征自适应的自动数据增广方法

Similar Documents

Publication Publication Date Title
CN109948029B (zh) 基于神经网络自适应的深度哈希图像搜索方法
US11240121B2 (en) Methods and systems for controlling data backup
CN110766142A (zh) 模型生成方法和装置
CN111340221B (zh) 神经网络结构的采样方法和装置
CN109961041B (zh) 一种视频识别方法、装置及存储介质
CN111368973B (zh) 用于训练超网络的方法和装置
CN111079780A (zh) 空间图卷积网络的训练方法、电子设备及存储介质
CN112200296B (zh) 网络模型量化方法、装置、存储介质及电子设备
CN113344016A (zh) 深度迁移学习方法、装置、电子设备及存储介质
CN116684330A (zh) 基于人工智能的流量预测方法、装置、设备及存储介质
CN114926701A (zh) 一种模型训练方法、目标检测方法、以及相关设备
CN113420792A (zh) 图像模型的训练方法、电子设备、路侧设备及云控平台
CN111957053A (zh) 游戏玩家匹配方法、装置、存储介质与电子设备
CN111161238A (zh) 图像质量评价方法及装置、电子设备、存储介质
CN116432780A (zh) 一种模型增量学习方法、装置、设备及存储介质
CN109815474B (zh) 一种词序列向量确定方法、装置、服务器及存储介质
CN114120180B (zh) 一种时序提名的生成方法、装置、设备及介质
CN115240704A (zh) 音频识别方法、装置、电子设备和计算机程序产品
CN114489574B (zh) 一种基于svm的流处理框架的自动调优方法
CN116027829A (zh) 机房温度控制方法、装置、设备及存储介质
CN113033397A (zh) 目标跟踪方法、装置、设备、介质及程序产品
CN111931994A (zh) 一种短期负荷及光伏功率预测方法及其系统、设备、介质
CN115630772B (zh) 综合能源检测配电方法、系统、设备及存储介质
WO2024012179A1 (zh) 模型训练方法、目标检测方法及装置
CN112437460B (zh) Ip地址黑灰名单分析方法、服务器、终端及存储介质

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination