CN112734641B - 目标检测模型的训练方法、装置、计算机设备及介质 - Google Patents

目标检测模型的训练方法、装置、计算机设备及介质 Download PDF

Info

Publication number
CN112734641B
CN112734641B CN202011625437.6A CN202011625437A CN112734641B CN 112734641 B CN112734641 B CN 112734641B CN 202011625437 A CN202011625437 A CN 202011625437A CN 112734641 B CN112734641 B CN 112734641B
Authority
CN
China
Prior art keywords
sample image
sub
image
images
detection model
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN202011625437.6A
Other languages
English (en)
Other versions
CN112734641A (zh
Inventor
陈建强
陈德健
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Bigo Technology Pte Ltd
Original Assignee
Bigo Technology Pte Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Bigo Technology Pte Ltd filed Critical Bigo Technology Pte Ltd
Priority to CN202011625437.6A priority Critical patent/CN112734641B/zh
Publication of CN112734641A publication Critical patent/CN112734641A/zh
Application granted granted Critical
Publication of CN112734641B publication Critical patent/CN112734641B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06TIMAGE DATA PROCESSING OR GENERATION, IN GENERAL
    • G06T3/00Geometric image transformations in the plane of the image
    • G06T3/40Scaling of whole images or parts thereof, e.g. expanding or contracting
    • G06T3/4038Image mosaicing, e.g. composing plane images from plane sub-images
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/084Backpropagation, e.g. using gradient descent
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06TIMAGE DATA PROCESSING OR GENERATION, IN GENERAL
    • G06T7/00Image analysis
    • G06T7/10Segmentation; Edge detection
    • G06T7/11Region-based segmentation
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06TIMAGE DATA PROCESSING OR GENERATION, IN GENERAL
    • G06T7/00Image analysis
    • G06T7/60Analysis of geometric attributes
    • G06T7/62Analysis of geometric attributes of area, perimeter, diameter or volume
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06TIMAGE DATA PROCESSING OR GENERATION, IN GENERAL
    • G06T2207/00Indexing scheme for image analysis or image enhancement
    • G06T2207/20Special algorithmic details
    • G06T2207/20081Training; Learning
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06TIMAGE DATA PROCESSING OR GENERATION, IN GENERAL
    • G06T2207/00Indexing scheme for image analysis or image enhancement
    • G06T2207/20Special algorithmic details
    • G06T2207/20084Artificial neural networks [ANN]
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06TIMAGE DATA PROCESSING OR GENERATION, IN GENERAL
    • G06T2207/00Indexing scheme for image analysis or image enhancement
    • G06T2207/20Special algorithmic details
    • G06T2207/20112Image segmentation details
    • G06T2207/20132Image cropping
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V2201/00Indexing scheme relating to image or video recognition or understanding
    • G06V2201/07Target detection
    • YGENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
    • Y02TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
    • Y02TCLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
    • Y02T10/00Road transport of goods or passengers
    • Y02T10/10Internal combustion engine [ICE] based vehicles
    • Y02T10/40Engine management systems

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • General Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • General Engineering & Computer Science (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • Evolutionary Computation (AREA)
  • Health & Medical Sciences (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Biophysics (AREA)
  • Biomedical Technology (AREA)
  • General Health & Medical Sciences (AREA)
  • Molecular Biology (AREA)
  • Computing Systems (AREA)
  • Evolutionary Biology (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Computational Linguistics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Geometry (AREA)
  • Image Analysis (AREA)

Abstract

本申请实施例公开了一种目标检测模型的训练方法、装置、计算机设备及介质,属于计算机技术领域。该方法包括:基于原始样本图像生成第一样本图像,原始样本图像中包含待检测对象,第一样本图像包含至少两张子图像,子图像是对原始样本图像中包含的待检测对象进行裁剪得到;基于第一样本图像预训练目标检测模型,预训练的目的为调整目标检测模型中特征提取网络的网络参数;基于原始样本图像生成第二样本图像,第二样本图像包含至少两张原始样本图像;基于第二样本图像对目标检测模型进行微调。增加样本图像中数据的多样性,避免通过样本图像训练得到的目标检测模型对某些属性较为依赖,从而增加目标检测模型的鲁棒性和准确性。

Description

目标检测模型的训练方法、装置、计算机设备及介质
技术领域
本申请实施例涉及计算机技术领域,特别涉及一种目标检测模型的训练方法、装置、计算机设备及介质。
背景技术
目标检测任务指示定位出图像中目标对象的位置,并确定出目标对象所属的类别,即目标检测任务包含定位和分类两个子任务,随着深度学习与计算机硬件的迅速发展,目标检测与深度学习的集合,使得目标检测任务也取得了较大的发展。
相关技术中,在构造目标检测任务的训练数据集时,一般直接采用目标检测数据集中的原始样本图像作为训练样本,而目标检测数据集中的原始样本图像一般包含单个用于进行目标检测的目标对象,训练样本的数据多样性较差,基于该训练样本训练目标检测模型时,会导致目标检测模型对某些属性的依赖性较高,从而影响目标检测模型的模型鲁棒性。
发明内容
本申请实施例提供了一种目标检测模型的训练方法、装置、计算机设备及介质。所述技术方案如下:
一方面,本申请实施例提供了一种目标检测模型的训练方法,所述方法包括:
基于原始样本图像生成第一样本图像,所述原始样本图像中包含待检测对象,所述第一样本图像包含至少两张子图像,所述子图像是对所述原始样本图像中包含的所述待检测对象进行裁剪得到;
基于所述第一样本图像预训练目标检测模型,预训练的目的为调整所述目标检测模型中特征提取网络的网络参数;
基于所述原始样本图像生成第二样本图像,所述第二样本图像包含至少两张所述原始样本图像;
基于所述第二样本图像对所述目标检测模型进行微调。
另一方面,本申请实施例提供了一种目标检测模型的训练装置,所述装置包括:
第一生成模块,用于基于原始样本图像生成第一样本图像,所述原始样本图像中包含待检测对象,所述第一样本图像包含至少两张子图像,所述子图像是对所述原始样本图像中包含的所述待检测对象进行裁剪得到;
预训练模块,用于基于所述第一样本图像预训练目标检测模型,预训练的目的为调整所述目标检测模型中特征提取网络的网络参数;
第二生成模块,用于基于所述原始样本图像生成第二样本图像,所述第二样本图像包含至少两张所述原始样本图像;
训练模块,用于基于所述第二样本图像对所述目标检测模型进行微调。
另一方面,本申请实施例提供了一种计算机设备,所述计算机设备包括处理器和存储器,所述存储器中存储有至少一条指令,所述至少一条指令由所述处理器加载并执行以实现如上述方面所述的目标检测模型的训练方法。
另一方面,本申请实施例提供了一种计算机可读存储介质,所述计算机可读存储介质中存储有至少一条指令,所述至少一条指令由处理器加载并执行以实现如上述方面所述的目标检测模型的训练方法。
另一方面,本申请实施例提供了一种计算机程序产品,该计算机程序产品包括计算机指令,该计算机指令存储在计算机可读存储介质中。计算机设备的处理器从计算机可读存储介质读取该计算机指令,处理器执行该计算机指令,使得该计算机设备执行上述方面所述的目标检测模型的训练方法。
本申请实施例提供的技术方案带来的有益效果至少包括:
在训练目标检测模型过程中,通过在预训练以及微调目标检测模型的过程中,均采用包含多个待检测对象的样本图像,即通过原始样本图像裁剪或拼接得到预训练和微调阶段所需的样本图像,从而增加样本图像中数据的多样性,避免通过样本图像训练得到的目标检测模型对某些属性较为依赖,从而增加目标检测模型的鲁棒性和准确性;此外,通过对原始样本图像中的待检测对象进行裁剪和拼接后生成的第一样本图像,作为预训练阶段的样本图像,由于基于待检测对象进行裁剪,可以排除待检测对象之外的其他因素的干扰,从而加速预训练阶段模型的收敛。
附图说明
图1示出了本申请一个示例性实施例示出的目标检测模型的训练方法的流程图;
图2是本申请一个示例性实施例示出的第一样本图像的生成过程示意图;
图3是本申请一个示例性实施例示出的第二样本图像的生成过程示意图;
图4示出了本申请一个示例性实施例示出的第一样本图像的生成方法流程图;
图5示出了本申请实施例所示出的三类子图像的示意图;
图6示出了本申请另一个示例性实施例示出的目标检测模型的训练方法的流程图;
图7示出了本申请一个示例性实施例示出的预训练阶段的分类模型示意图;
图8示出了本申请一个示例性实施例示出的目标检测模型的训练过程示意图;
图9示出了本申请一个示例性实施例示出的完整目标检测模型的训练方法的流程图;
图10示出了本申请一个示例性实施例提供的目标检测模型的训练装置的结构框图;
图11示出了本申请一个示例性实施例提供的计算机设备的结构示意图。
具体实施方式
为使本申请的目的、技术方案和优点更加清楚,下面将结合附图对本申请实施方式作进一步地详细描述。
在本文中提及的“多个”是指两个或两个以上。“和/或”,描述关联对象的关联关系,表示可以存在三种关系,例如,A和/或B,可以表示:单独存在A,同时存在A和B,单独存在B这三种情况。字符“/”一般表示前后关联对象是一种“或”的关系。
请参考图1,其示出了本申请一个示例性实施例示出的目标检测模型的训练方法的流程图,本申请实施例以该方法应用于计算机设备为例进行说明,该方法包括:
步骤101,基于原始样本图像生成第一样本图像,原始样本图像中包含待检测对象,第一样本图像包含至少两张子图像,子图像是对原始样本图像中包含的待检测对象进行裁剪得到。
其中,原始样本图像中包含至少一个待检测对象,其可以从目标检测数据集中获取得到,目标检测数据集可以是VOC数据集、MSCOCO数据集等,本申请实施例对原始样本图像的来源不构成限定。
不同于相关技术中直接采用原始样本图像训练目标检测模型,原始样本图像中一般仅包含单个待检测对象,训练样本的数据多样性和数据结构较为单一,本申请实施例中,为了提高训练样本的数据多样性,在预训练目标检测模型时,通过对原始样本图像进行裁剪和拼接,生成类似于马赛克样式的第一样本图像,也就是说,第一样本图像中包含至少两张子图像,且子图像是对原始样本图像中包含的待检测对象进行裁剪得到。
在一种可能的实施方式中,原始样本图像中包含有待检测对象,且原始样本图像中标注有待检测对象对应的边框(bounding box),可以按照边框对原始样本图像进行裁剪,得到所需要的目标区域,即仅包含单个待检测对象的子图像,再对至少两个子图像进行拼接,从而得到用于预训练的第一样本图像。
可选的,为了获得更好的上下文信息,在对原始样本图像进行裁剪时,可以在边框的上下左右各多裁剪一定长度,比如,多裁剪30个像素的区域。
如图2所示,其是本申请一个示例性实施例示出的第一样本图像的生成过程示意图。从目标检测数据集中获取到原始样本图像201,按照标注框202对原始样本图像进行裁剪,得到仅包含单个待检测对象的子图像203,对其中的4个子图像203进行拼接和像素处理,得到第一样本图像204,其中,斜线的阴影部分为像素填充区域。
步骤102,基于第一样本图像预训练目标检测模型,预训练的目的为调整目标检测模型中特征提取网络的网络参数。
目标检测模型的训练过程包括预训练阶段和微调阶段,其中,预训练阶段主要是为了调整目标检测模型中特征提取网络的网络参数,使得预训练得到的特征提取网络可以较好的提取物体的特征表示,从而有利于目标检测模型在微调阶段的模型收敛。
由于第一样本图像由目标检测数据集中的原始样本图像生成基于生成,通过第一样本图像预训练目标检测模型,相比于采用目标分类数据集中的图像作为训练样本,可以降低算法的计算负担,可以在有限的计算资源的情况下预训练好所需的特征提取网络来加速微调阶段的收敛。
在一种可能的实施方式中,由于目标分类模型和目标检测模型在特征提取阶段类似,因此,预训练阶段采用的网络模型一般为目标分类模型,其中,目标分类模型一般由卷积层、池化层和全连接层构成,卷积层即为特征提取网络。
以目标检测模型为yolo网络模型为例,yolo网络模型由多层卷积构成,预训练阶段的分类网络模型包括20个卷积层+1个池化层+1个全连接层,即预训练阶段主要训练yolo网络模型的前20个卷积层,调整好前20个卷积层的网络参数,以便在训练阶段,通过前20个卷积层的网络参数来初始化yolo网络模型前20个卷积层的网络参数。
步骤103,基于原始样本图像生成第二样本图像,第二样本图像包含至少两张原始样本图像。
与预训练阶段中构建的第一样本图像类似,为了增加训练阶段样本图像的数据多样性,在一种可能的实施方式中,通过对多张原始样本图像进行拼接,生成类似于马赛克样式的第二样本图像。
由于目标检测任务包括对象分类和对象定位两个任务,即训练目标检测模型过程中,不仅需要目标检测模型准确识别到待检测对象所属的类别,还需要目标检测模型准确定位到待检测对象在第二样本图像中的位置,因此,在生成目标检测模型微调阶段的第二样本图像过程中,无需对原始样本图像进行裁剪,可以直接根据多个原始样本图像拼接为第二样本图像。
如图3所示,其是本申请一个示例性实施例示出的第二样本图像的生成过程示意图。从目标检测数据集中获取到原始样本图像301,从中选取4个原始样本图像进行拼接和像素填充处理,得到第二样本图像302,其中,斜线的阴影部分为像素填充区域。
步骤104,基于第二样本图像对目标检测模型进行微调。
在一种可能的实施方式中,在获取到微调阶段对应的第二样本图像后,可以直接根据第二样本图像训练目标检测模型,即根据得到的分类损失和回归框损失调整整个目标检测模型的参数。
以目标检测模型为yolo网络模型为例,通过预训练阶段调整前20个卷积层的网络参数,在微调阶段,通过调整好的20个卷积层的网络参数初始化yolo网络模型前20个卷积层的网络参数,并通过第二样本图像训练完整的yolo网络模型。
综上所述,本申请实施例中,在训练目标检测模型过程中,通过在预训练以及微调目标检测模型的过程中,均采用包含多个待检测对象的样本图像,即通过原始样本图像裁剪或拼接得到预训练和微调阶段所需的样本图像,从而增加样本图像中数据的多样性,避免通过样本图像训练得到的目标检测模型对某些属性较为依赖,从而增加目标检测模型的鲁棒性和准确性;此外,通过对原始样本图像中的待检测对象进行裁剪和拼接后生成的第一样本图像,作为预训练阶段的样本图像,由于基于待检测对象进行裁剪,可以排除待检测对象之外的其他因素的干扰,从而加速预训练阶段模型的收敛。
由于预训练阶段的模型任务是得到分类结果,为了降低拼接不同子图像对预训练阶段的干扰,比如,在拼接多个子图像生成第一样本图像时,导致第一样本图像中的留白(像素填充)过多,从而影响预训练阶段的模型收敛速率,因此,在一种可能的实施方式中,通过统计不同子图像的长宽比,以便合理采用具有不同长宽比的子图像拼接形成第一样本图像,尽量减少第一样本图像中的像素填充区域。
请参考图4,其示出了本申请一个示例性实施例示出的第一样本图像的生成方法流程图,本申请实施例以该方法应用于计算机设备为例进行说明,该方法包括:
步骤401,基于待检测对象在原始样本图像中的位置对原始样本图像进行裁剪,得到子图像。
从目标检测数据集中获取原始样本图像时,也会获取到原始样本图像中包含的待检测对象对应的边框,边框即指示待检测对象在原始样本图像中的位置,因此,在一种可能的实施方式中,在对原始样本图像进行裁剪时,可以基于待检测对象在原始样本图像中的位置(边框)对原始样本图像进行裁剪,从而得到包含单个待检测对象的子图像。
步骤401,获取各个子图像对应的长宽比,长宽比指子图像的长边和宽边之间的比值。
由于不同原始样本图像大小的差异,以及原始样本图像中包含的待检测对象大小的差异,使得裁剪得到的子图像存在形状和大小上的差异,而一般训练样本具有固定尺寸,如果多个子样本拼接后形成的样本图像与训练样本的图形尺寸差异较大,就需要对差异部分进行像素填充,因此,为了减少像素填充的部分,在一种可能的实施方式中,通过获取各个子图像对应的长宽比,来确定拼接方案。其中,长宽比为子图像对应的图像长边和图像宽边之间的比值。
步骤402,基于长宽比对至少两张子图像进行拼接处理,得到第一样本图像。
在一种可能的实施方式中,在获取到各个子图像对应的长宽比后,可以选择长宽比存在互补关系的多张子图像进行拼接,得到第一样本图像,以尽量减少第一样本图像中像素填充的区域。
其中,基于长宽比选择多个子图像进行拼接处理的过程可以包括以下步骤:
一、按照长宽比将子图像划分为第一类子图像、第二类子图像和第三类子图像,第一类子图像对应的长宽比介于第一长宽比阈值和第二长宽比阈值之间,第二类子图像对应的长宽比小于第一长宽比阈值,第三类子图像对应的长宽比大于第一长宽比阈值。
在一种可能的实施方式中,基于不同长宽比对应的子图像的形状,将子图像划分为第一类子图像(方形)、第二类子图像(竖形)以及第三类子图像(横形)。
其中,第一类图像(方形)对应的长宽比介于第一长宽比阈值和第二长宽比阈值之间,第一长宽比阈值和第二长宽比阈值可以基于方形的长宽比特征确定,由于方形图像对应的长边和宽边的长度较为接近或相等,因此,方形对应的长宽比一般位于1附近,因此,可以设置第一长宽比阈值为0.8,第二长宽比阈值为1.2,即第一类子图像的长宽比介于0.8~1.2之间。
由于第二类子图像的形状为竖形,而竖形图像对应的长边较短,宽边较长,因此,竖形图像对应的长宽比应该小于1,对应的,可以设置第二类子图像对应的长宽比小于第一长宽比阈值,比如,第二类子图像对应的长宽比小于0.8。
由于第三类子图像对应横形,与竖形相反,即横形图像对应的长边较长,宽边较短,因此,横形图像对应的长宽比应该大于1,对应的,可以设置第三类子图像对应的长宽比大于第二长宽比阈值,比如,第三类子图像对应的长宽比大于1.2。
如图5所示,其示出了本申请实施例所示出的三类子图像的示意图。第一类子图像(方形502),其对应的长宽比(长边和宽边的比值)介于0.8~1.2之间,第二类子图像(竖形501),其对应的长宽比小于0.8,第三类子图像(横形503),其对应的长宽比大于1.2。
二、从第一类子图像、第二类子图像和第三类子图像中选取至少两张子图像。
基于拼接原理(尽量减少像素填充区域),在一种可能的实施方式中,基于需要从不同类子图像中选取第一样本所需要的子图像,比如,若只需要两张第一类子图像可以拼接出第一样本图像,则仅需要从第一类子图像中随机选取两张进行拼接即可;若需要两张第一类子图像和一张第二类子图像,则需要从第一类子图像中选取两张,从第二类子图像中选取一张,本申请实施例对选取的具体方式不构成限定。
三、基于至少两张子图像生成第一样本图像。
在一种可能的实施方式中,当基于长宽比选择出所需要的多张子图像后,可以对多张子图像进行拼接处理,若对多张子图像进行拼接处理后的图像与第一样本图像的尺寸还存在差异,可以在缺失部分填充像素。
以第一样本图像中包含四张子图像为例,通过拼接实验得到,由四张子图像拼接第一样本图像时,分别从第一类子图像中选取一大一小的两张子图像,从第二类子图像和第三类子图像中分别选取一张子图像,可以尽量保证拼出方形的第一样本图像,且像素填充较少,因此,在一种可能的实施方式中,对四张子图像进行拼接处理,得到第一原始样本图像,四张子图像包括两张第一类子图像、一张第二类子图像和一张第三类子图像。
可选的,由于第一样本图像设置有预设图像尺寸,若对四张图像进行拼接后,得到的第一原始样本图像对应的图像尺寸与预设图像尺寸存在差异,还需要按照预设图像尺寸述第一原始样本图像进行像素填充处理,得到第一样本图像。
本实施例中,通过统计裁剪后的各个子图像的长宽比,可以基于长宽比将不同子图像划分为不同类别,对应在根据不同子图像拼接第一样本图像的过程中,可以基于长宽比选择合适的子图像进行拼接,使得生成的第一样本图像中像素填充区域较少,避免像素填充区域较多对预训练阶段的影响。
上文实施例主要介绍了目标检测模型的完整训练流程,包括预训练阶段和微调阶段,本实施例中详细介绍每个阶段的训练过程,比如,预训练阶段损失函数的构建以及微调阶段损失函数的构建过程。
如图6所示,其示出了本申请另一个示例性实施例示出的目标检测模型的训练方法的流程图,本申请实施例以该方法应用于计算机设备为例进行说明,该方法包括:
步骤601,基于原始样本图像生成第一样本图像。
步骤601的实施方式可以参考上文实施例,本实施例在此不做赘述。
步骤602,将第一样本图像输入目标分类模型,得到目标分类模型输出的第一预测分类信息,目标分类模型中包含特征提取网络。
由于目标分类模型的任务是预测第一样本图像中各个待检测对象所属的类别,因此,在一种可能的实施方式中,将第一样本图像输入目标分类模型中,可以得到目标分类模型输出的第一样本图像中各个待检测对象所属类别的预测概率,也就是说,第一预测分类信息为第一样本图像中各个待检测对象所属类别的预测概率。
其中,目标分类模型主要由特征提取网络+池化层+全连接层构成,在预训练阶段,主要通过预训练调整特征提取网络的参数,特征提取网络主要由多个卷积层构成。
可选的,目标分类模型可以基于移动面部网络(MobileNetV2)模型的基础上进行改进得到,在一种可能的实施方式中,可以将mobilenet-V2最后一层全连接层替换为包含C个神经元的全连接层,其中,C为目标检测数据集中类别的数量,比如,目标检测数据集中包含20中类别的对象,对应的C就是20,也就是目标分类网络中的全连接层包含20个神经元。
如图7所示,其示出了本申请一个示例性实施例示出的预训练阶段的分类模型示意图。将第一样本图像701输入特征提取网络702,进行特征提取,得到图像特征,该图像特征经过池化层703和全连接层704处理后,由全连接层704输出预测分类信息705。
目标检测数据集中的类别可以是生物类别,比如,狗、猫、鸟、花等;也可以是非生物类别,比如,手机、电脑、桌子、椅子等。
步骤603,基于第一预测分类信息和第一样本图像对应的第一标注分类信息,计算得到目标分类模型对应的第一分类损失。
神经网络训练过程,即网络模型学习预测和标准之间差异(损失)的过程,以便基于损失调整网络模型中的各个参数,因此,在训练目标分类模型过程中,需要比较预测分类信息和标注分类信息之间的差异,对应的,需要获取目标分类模型输出的第一预测分类信息,以及第一样本图像对应的第一标注分类信息,并基于第一预测分类信息和第一标注分类信息,计算得到为目标分类模型对应的第一分类损失。
与第一预测分类信息类似,第一标注分类信息为第一样本图像中各个待检测对象所属类别的标注概率。其中,第一标注信息可以是人工预先标注得到,或直接从目标检测数据集中获取得到。
由于本申请实施例中的第一样本图像中包含多个子图像,每个子图像中包含有待检测对象,因此,在计算目标分类模型对应的分类损失时,需要综合考虑到不同子图像对应的分类损失,在一个示例性的例子中,计算第一分类损失的过程可以包括以下步骤:
一、基于同一子图像中包含的待检测对象对应的第一预测分类信息和第一标注分类信息,计算得到子图像对应的交叉熵损失。
由于第一样本图像中包含多个子图像,因此,需要分别计算不同子图像对应的交叉熵损失,以便综合分析得到目标检测模型对应的整体分类损失,在一种可能的实施方式中,基于同一子图像中包含的待检测对象对应的第一预测分类信息和第二标注分类信息,计算得到该子图像对应的交叉熵损失,依次类推,可以得到第一样本图像中各个子图像对应的交叉熵损失。
在一个示例性的例子中,以目标检测集中包含20种类别为例,子图像A中包含单个待检测对象,子图像A对应的第一预测分类信息可以为:A={P1,P2,P3…P20},分别表示子图像中待检测对象分别属于不同类别的概率。
其中,计算交叉熵损失的公式可以表示为:
其中,LCE表交叉熵损失,M表示类别的数量,N表示一次输入目标检测模型中的第一样本图像的数量;yic指示变量(即标注概率),如果该类别和待检测对象i的类别相同就是1,否则是0;pic表示待检测对象i属于类别c的预测概率。对应的,将每个子图像对应的第一标注分类信息和第一预测分了信息的代入公式(1)中,可以计算得到每个子图像对应的交叉熵损失。
二、将第一样本图像包含的各个子图像对应的交叉熵损失之和,确定为目标分类模型对应的第一分类损失。
由于不同子图像占第一样本图像的百分比并不相同,比如,第一样本图像中有的子图像的图像尺寸较大,其在第一样本图像中的占比较大,有的子图像对应的图像尺寸较小,其在第一样本图像中的占比较小,不同子图像在第一样本图像中的占比不同,因此,在一种可能的实施方式中,设置不同子图像对应的交叉熵损失对应不同交叉熵损失权重,对应在计算整体分类损失时,通过对不同子图像对应的交叉熵损失乘以其对应的交叉熵损失权重,再求和,从而得到目标分类模型对应的第一分类损失。
其中,交叉熵损失权重由子图像占第一样本图像的百分比确定,比如,交叉熵损失权重可以是子图像的像素占第一样本图像像素的百分比。
在一个示例性的例子中,第一分类损失的计算公式可以为:
其中,LCE(x,y)表示第一分类损失,n表示第一样本图像中包含的子样本图像的数量,若第一样本图像中包含4个子样本,则n等于4,ai表示每个子样本对应的交叉熵损失权重,其值为每个子图像像素占第一样本图像像素的百分比,CE(xi,yi)表示每个子样本图像对应的交叉熵损失。
步骤604,基于第一分类损失训练目标分类模型,得到目标检测模型中的特征提取网络。
基于上述得到的目标分类模型对应的第一分类损失,对目标分类模型进行反向传播算法,更新目标分类模型中各个网络的参数,直至目标分类模型收敛,预训练结束,可以得到具有较好特征提取功能的特征提取网络,并将该特征提取网络对应的网络参数作为目标检测模型中的特征提取网络的初始参数,从而进行后续对目标检测网络的微调阶段。
步骤605,基于原始样本图像生成第二样本图像,第二样本图像包含至少两张原始样本图像。
由于第二样本图像也是经过拼接得到的,在拼接时,也可以统计不同原始样本图像的长宽比进行第二样本图像的拼接过程,由于第二样本图像也设置有预设图像尺寸,因此,当对多个原始样本图像进行拼接后,若拼接后的第二原始样本图像的图像尺寸与该预设图像尺寸存在差异,也可以按照该预设图像尺寸对第二原始样本图像进行像素填充,以便生成第二样本图像。
步骤606,将第二样本图像输入目标检测模型,得到目标检测模型输出的预测对象信息。
在目标检测模型的模型微调阶段,由于目标检测模型的任务包括分类任务和定位任务,分类任务即预测第二样本图像中包含的各个待检测对象对应的类别,定位任务即预测待检测对象在第二样本图像中的位置,因此,在一种可能的实施方式中,将第二样本图像输入目标检测模型中,可以得到目标检测模型输出的预测对象信息,预测对象信息包括第二预测分类信息和预测位置信息,第二预测分类信息为第二样本图像各个待检测对象所属类别的预测概率,预测位置信息指示各个待检测对象在第二样本图像中的预测位置区域。
按照目标检测模型中包括的各个网络层对应的功能划分,目标检测模型中可以包括特征提取网络、特征融合网络以及目标检测网络,其中,特征提取网络用于提取输入样本的图像特征,特征融合网络用于对不同分辨率的特征进行加工、融合和增强,而目标检测网络即检测头,用于输出目标检测任务需要的预测结果,比如,预测类别或预测位置信息等;在一个示例性的例子中,目标检测模型对第二样本图像的处理过程可以包括以下步骤:
一、将第二样本图像输入特征提取网络,得到特征提取网络输出的n个第一样本特征图,不同第一样本特征图对应不同分辨率。
其中,在初始化目标检测模型时,特征提取网络的网络参数采用预训练阶段得到的特征提取网络的参数。
在一种可能的实施方式中,将第二样本图像输入目标检测模型后,首先经过特征提取网络对第二样本图像进行图像特征提取,可以得到特征提取网络输出的n个第一样本特征图,即从第二样本图像中提取出不同层级上的图像特征,由于特征提取网络是对第二样本图像进行下采样得到,对应可以得到不同分辨率的n个第一样本特征图。
如图8所示,其示出了本申请一个示例性实施例示出的目标检测模型的训练过程示意图。将第二样本图像801输入特征提取网络802中,得到特征提取网络802输出的n个第一样本特征图803,其中,不同第一样本特征图803具备不同的分辨率,比如,第一样本特征图P1对应的分辨率为1/2,第一样本特征图P2对应的分辨率为1/4,第一样本特征图P3对应的分辨率为1/8,第一样本特征图P4对应的分辨率为1/16,第一样本特征图P5对应的分辨率为1/32。
二、将n个第一样本特征图输入特征融合网络,得到特征融合网络输出的n个第二样本特征图,不同第二样本特征图对应不同分辨率,特征融合网络用于按照预设权重对n个第一样本特征图进行混合。
特征融合网络用于对特征提取网络提取到的图像特征进行加工和增强,使得经过特征融合网络处理过的图像特征是所需要的图像特征,在一种可能的实施方式中,将n个第一样本特征图输入特征融合网络,由特征融合网络按照预设权重对n个第一样本特征图进行混合,从而得到特征融合网络输出的n个第二样本特征图,其中,不同第二样本特征图对应不同分辨率。
相关技术中yolo网络模型采用的特征融合网络为金字塔注意网络(PyramidAttention Network,PAN),虽然采用PAN网络可以提高yolo网络模型的准确性和精度,但是由于PAN网络会导致训练过程中的运算参数较多,从而一样想yolo模型的训练效率和运行效率,因此,本申请实施例中,基于原有的yolo网络模型进行改进,采用双向特征金字塔网络(BidirectionalFeature Pyramid Networks,BiFPN)作为目标检测模型(yolo网络模型)中的特征融合网络,以提高yolo模型的运行效率,同时保证了yolo网络模型的准确性和鲁棒性。
如图8所示,将n个第一样本特征图803输入特征融合网络804,该特征融合网络804采用BiFPN网络,得到特征融合网络804输出的第二样本特征图805,不同第二样本特征图对应不同分辨率,比如,第二样本特征图P6对应的分辨率为1/8,第二样本特征图P7对应的分辨率为1/16,第二样本特征图P8对应的分辨率为1/32。
三、将n个第二样本特征图输入目标检测网络中,得到目标检测网络输出的预测对象信息。
其中,目标检测网络即检测头,通过对第二样本特征图进行卷积操作,得到所需要的目标检测信息,即预测对象信息,该预测对象信息包括预测类别信息和预测位置信息。
如图8所示,将经过特征融合后生成的第二样本特征图805输入目标检测网络806,得到目标检测网络806输出的预测对象信息807,该预测对象信息包括预测类别信息和预测位置信息。
步骤607,基于预测对象信息和第二样本图像对应的标注对象信息,计算得到目标检测模型对应的目标损失。
其中,标注对象信息包括第二标注分类信息和标注位置信息,第二标注分类信息为第二样本图像中各个待检测对象所属类别的标注概率,标注位置信息指示各个待检测对象在第二样本图像中的标注位置区域。
由于本实施例中目标检测任务包括分类任务和定位任务,对应的,目标检测模型对应的目标损失也包含两部分,一部分是分类损失,由第二标注分类信息和第二预测分类信息计算得到,一部分是定位损失,由标注位置区域和预测位置区域计算得到,其中,标注位置区域即标注边框,因此,定位损失也可以称为边框回归损失。
在一个示例性的例子中,计算目标检测模型对应的目标损失的过程可以包括以下步骤:
一、基于第二预测分类信息和第二标注分类信息,计算得到目标检测模型对应的第二分类损失。
分类损失即预测类别概率和标注类别概率之间的损失,对应的,可以基于第二预测分类信息和第二标注分类信息,计算得到目标检测模型对应的第二分类损失。
其中,计算第二分类损失的过程可以参考上文实施例中计算第一分类损失的过程,本申请实施例在此不做赘述。
二、基于预测位置信息和标注位置信息,计算得到目标检测模型对应的定位损失。
其中,预测位置信息即第二样本图像中各个待检测对象对应的预测框位置,比如,预测框四个顶点的坐标,标注位置信息即第二样本图像中各个待检测对象对应的标注框位置,比如,标注框四个顶点的坐标。
可选的,标注位置信息可以由人工标注获取得到。
在一个示例性的例子中,计算定位损失(边框回归损失)的公式可以表示为:
其中,LCIOU表示定位损失,IOU表示预测框和标注框的交并比,是用来衡量长宽比一致性的参数,ωgt表示标注框的宽边,hgt表示标注框的长边,ω表示预测框的宽边,h表示预测框的长边;ρ2表示预测框与标注框中心点的平方,c2表示刚好包住预测框和标注框的最小矩形的对角线的平方,b表示预测框,bgt表示标注框。
在一种可能的实施方式中,将预测位置信息中的预测框坐标和标注位置信息中的标注框坐标带入公式(3),可以计算得到目标检测模型对应的定位损失。
三、将第二分类损失和定位损失之和,确定为目标检测模型对应的目标损失,其中,第二分类损失和定位损失对应不同损失权重。
在一种可能的实施方式中,预设有第二分类损失和定位损失分别对应的损失权重,当获取到第二分类损失和定位损失后,可以根据第二分类损失、定位损失和其对应的损失权重,计算得到目标检测模型对应的目标损失。
在一个示例性的例子中,目标检测模型对应的目标损失的计算公式可以表示为:
L=α1CE_Loss+α2CIOU_Loss (4)
其中,L表示目标检测模型对应的目标损失,α1表示第二分类损失对应的损失权重,α2表示定位损失对应的损失权重,CE_Loss表示第二分类损失,CIOU_Loss表示定位损失。将上文实施例中计算得到的第二分类损失和定位损失带入公式(4)中,即可以求得目标检测模型对应的目标损失。
步骤608,根据目标损失训练目标检测模型。
在一种可能的实施方式中,根据公式(4)计算得到的目标检测模型对应的目标损失,从而利用该目标损失对目标检测模型执行反向传播算法,更新目标检测模型中各个网络对应的网络参数。
可选的,在多个训练周期内,按照上文实施例所示的方法重复对目标检测模型进行训练,直至目标检测模型对应的损失函数完全收敛时,完成目标检测模型的训练。
可选的,当目标检测模型达到收敛状态时,将马赛克输入增强关掉,即更换第二样本图像,直接采用原始样本图像再微调一会目标检测网络。
本实施例中,描述了对目标检测模型进行预训练和微调的训练过程,以及预训练和微调阶段损失函数的构建过程,实现通过预训练得到具有较好特征提取功能的特征提取网络,并在微调阶段直接采用预训练得到的特征提取网络的网络参数,从而使得微调阶段的目标检测模型可以快速收敛,提高了目标检测模型的训练速率,此外,本申请实施例通过采用BiFPN网络作为yolo网络模型中的特征融合网络,可以在保证yolo网络模型准确性的基础上,提高yolo网络模型的运行速率。
请参考图9,其示出了本申请一个示例性实施例示出的完整目标检测模型的训练方法的流程图,该方法包括:
步骤901,裁剪出预训练过程需要的子样本图像。
步骤902,按规则基于子样本图像构造预训练的马赛克图片。
步骤903,预训练目标分类模型。
步骤904,将目标分类模型中特征提取网络的网络参数作为目标检测模型中特征提取网络的初始化网络参数。
步骤905,从目标检测数据集中获取原始样本图像。
步骤906,按规则基于原始样本图像构建训练目标检测模型的马赛克图片或者原始样本图像。
步骤907,训练目标检测模型。
请参考图10,其示出了本申请一个示例性实施例提供的目标检测模型的训练装置的结构框图。该装置可以通过软件、硬件或者两者的结合实现成为计算机设备的全部或一部分,该装置包括:
第一生成模块1001,用于基于原始样本图像生成第一样本图像,所述原始样本图像中包含待检测对象,所述第一样本图像包含至少两张子图像,所述子图像是对所述原始样本图像中包含的所述待检测对象进行裁剪得到;
预训练模块1002,用于基于所述第一样本图像预训练目标检测模型,预训练的目的为调整所述目标检测模型中特征提取网络的网络参数;
第二生成模块1003,用于基于所述原始样本图像生成第二样本图像,所述第二样本图像包含至少两张所述原始样本图像;
训练模块1004,用于基于所述第二样本图像对所述目标检测模型进行微调。
可选的,所述预训练模块1002,包括:
第一处理单元,用于将所述第一样本图像输入目标分类模型,得到所述目标分类模型输出的第一预测分类信息,所述第一预测分类信息为所述第一样本图像中各个所述待检测对象所属类别的预测概率,所述目标分类模型中包含所述特征提取网络;
第一计算单元,用于基于所述第一预测分类信息和所述第一样本图像对应的第一标注分类信息,计算得到所述目标分类模型对应的第一分类损失,所述第一标注分类信息为所述第一样本图像中各个所述待检测对象所属类别的标注概率;
第一训练单元,用于基于所述第一分类损失训练所述目标分类模型,得到所述目标检测模型中的所述特征提取网络。
可选的,所述第一计算单元,还用于:
基于同一所述子图像中包含的所述待检测对象对应的所述第一预测分类信息和所述第一标注分类信息,计算得到所述子图像对应的交叉熵损失;
将所述第一样本图像包含的各个所述子图像对应的所述交叉熵损失之和,确定为所述目标分类模型对应的所述第一分类损失,其中,不同子图像对应的所述交叉熵损失对应不同交叉熵损失权重,所述交叉熵损失权重由所述子图像占所述第一样本图像的百分比确定。
可选的,所述训练模块1004,包括:
第二处理单元,用于将所述第二样本图像输入所述目标检测模型,得到所述目标检测模型输出的预测对象信息,所述预测对象信息包括第二预测分类信息和预测位置信息,所述第二预测分类信息为所述第二样本图像各个所述待检测对象所属类别的预测概率,所述预测位置信息指示各个所述待检测对象在所述第二样本图像中的预测位置区域;
第二计算单元,用于基于所述预测对象信息和所述第二样本图像对应的标注对象信息,计算得到所述目标检测模型对应的目标损失,所述标注对象信息包括第二标注分类信息和标注位置信息,所述第二标注分类信息为所述第二样本图像中各个所述待检测对象所属类别的标注概率,所述标注位置信息指示各个所述待检测对象在所述第二样本图像中的标注位置区域;
第二训练单元,用于根据所述目标损失训练所述目标检测模型。
可选的,所述第二计算单元,还用于:
基于所述第二预测分类信息和所述第二标注分类信息,计算得到所述目标检测模型对应的第二分类损失;
基于所述预测位置信息和所述标注位置信息,计算得到所述目标检测模型对应的定位损失;
将所述第二分类损失和所述定位损失之和,确定为所述目标检测模型对应的所述目标损失,其中,所述第二分类损失和所述定位损失对应不同损失权重。
可选的,所述目标检测模型还包括特征融合网络和目标检测网络;
所述第二处理单元,还用于:
将所述第二样本图像输入所述特征提取网络,得到所述特征提取网络输出的n个第一样本特征图,不同第一样本特征图对应不同分辨率;
将n个所述第一样本特征图输入所述特征融合网络,得到所述特征融合网络输出的n个第二样本特征图,不同第二样本特征图对应不同分辨率,所述特征融合网络用于按照预设权重对n个所述第一样本特征图进行混合;
将n个所述第二样本特征图输入所述目标检测网络中,得到所述目标检测网络输出的所述预测对象信息。
可选的,所述目标检测模型采用yolo网络模型,所述yolo网络模型中的所述特征融合网络采用BiFPN。
可选的,所述第一生成模块1001,包括:
裁剪单元,用于基于所述待检测对象在所述原始样本图像中的位置对所述原始样本图像进行裁剪,得到所述子图像;
获取单元,用于获取各个所述子图像对应的长宽比,所述长宽比指所述子图像的长边和宽边之间的比值;
拼接处理单元,用于基于所述长宽比对至少两张所述子图像进行拼接处理,得到所述第一样本图像。
可选的,所述拼接处理单元,还用于:
按照所述长宽比将所述子图像划分为第一类子图像、第二类子图像和第三类子图像,所述第一类子图像对应的长宽比介于第一长宽比阈值和第二长宽比阈值之间,所述第二类子图像对应的长宽比小于所述第一长宽比阈值,所述第三类子图像对应的长宽比大于所述第一长宽比阈值;
从所述第一类子图像、所述第二类子图像和所述第三类子图像中选取至少两张所述子图像;
基于至少两张所述子图像生成所述第一样本图像。
可选的,所述第一样本图像中包含四张所述子图像,且所述第一样本图像对应预设图像尺寸;
所述拼接处理单元,还用于:
对四张所述子图像进行拼接处理,得到第一原始样本图像,四张所述子图像包括两张所述第一类子图像、一张所述第二类子图像和一张所述第三类子图像;
按照所述预设图像尺寸对所述第一原始样本图像进行像素填充处理,得到所述第一样本图像。
本申请实施例中,在训练目标检测模型过程中,通过在预训练以及微调目标检测模型的过程中,均采用包含多个待检测对象的样本图像,即通过原始样本图像裁剪或拼接得到预训练和微调阶段所需的样本图像,从而增加样本图像中数据的多样性,避免通过样本图像训练得到的目标检测模型对某些属性较为依赖,从而增加目标检测模型的鲁棒性和准确性;此外,通过对原始样本图像中的待检测对象进行裁剪和拼接后生成的第一样本图像,作为预训练阶段的样本图像,由于基于待检测对象进行裁剪,可以排除待检测对象之外的其他因素的干扰,从而加速预训练阶段模型的收敛。
请参考图11,其示出了本申请一个示例性实施例提供的计算机设备的结构示意图。所述计算机设备1100包括中央处理单元(Central Processing Unit,CPU)1101、包括随机存取存储器(Random Access Memory,RAM)1102和只读存储器(Read-Only Memory,ROM)1103的系统存储器1104,以及连接系统存储器1104和中央处理单元1101的系统总线1105。所述计算机设备1100还包括帮助计算机设备内的各个器件之间传输信息的基本输入/输出系统(Input/Output系统,I/O系统)1106,和用于存储操作系统1113、应用程序1114和其他程序模块1115的大容量存储设备1107。
所述基本输入/输出系统1106包括有用于显示信息的显示器1108和用于用户输入信息的诸如鼠标、键盘之类的输入设备1109。其中所述显示器1108和输入设备1109都通过连接到系统总线1105的输入输出控制器1110连接到中央处理单元1101。所述基本输入/输出系统1106还可以包括输入输出控制器1110以用于接收和处理来自键盘、鼠标、或电子触控笔等多个其他设备的输入。类似地,输入输出控制器1110还提供输出到显示屏、打印机或其他类型的输出设备。
所述大容量存储设备1107通过连接到系统总线1105的大容量存储控制器(未示出)连接到中央处理单元1101。所述大容量存储设备1107及其相关联的计算机可读存储介质为计算机设备1100提供非易失性存储。也就是说,所述大容量存储设备1107可以包括诸如硬盘或者只读光盘(Compact Disc Read-Only Memory,CD-ROM)驱动器之类的计算机可读存储介质(未示出)。
不失一般性,所述计算机可读存储介质可以包括计算机存储介质和通信介质。计算机存储介质包括以用于存储诸如计算机可读存储指令、数据结构、程序模块或其他数据等信息的任何方法或技术实现的易失性和非易失性、可移动和不可移动介质。计算机存储介质包括RAM、ROM、可擦除可编程只读寄存器(Erasable Programmable Read OnlyMemory,EPROM)、电子抹除式可复写只读存储器(Electrically-Erasable ProgrammableRead-Only Memory,EEPROM)、闪存或其他固态存储设备,CD-ROM、数字多功能光盘(DigitalVersatile Disc,DVD)或其他光学存储、磁带盒、磁带、磁盘存储或其他磁性存储设备。当然,本领域技术人员可知所述计算机存储介质不局限于上述几种。上述的系统存储器1104和大容量存储设备1107可以统称为存储器。
存储器存储有一个或多个程序,一个或多个程序被配置成由一个或多个中央处理单元1101执行,一个或多个程序包含用于实现上述方法实施例的指令,中央处理单元1101执行该一个或多个程序实现上述各个方法实施例提供的目标检测模型的训练方法。
根据本申请的各种实施例,所述计算机设备1100还可以通过诸如因特网等网络连接到网络上的远程服务器运行。也即计算机设备1100可以通过连接在所述系统总线1105上的网络接口单元1111连接到网络1112,或者说,也可以使用网络接口单元1111来连接到其他类型的网络或远程服务器系统(未示出)。
所述存储器还包括一个或者一个以上的程序,所述一个或者一个以上程序存储于存储器中,所述一个或者一个以上程序包含用于进行本申请实施例提供的目标检测模型的训练方法中由计算机设备所执行的步骤。
本申请实施例还提供了一种计算机可读存储介质,该计算机可读存储介质存储有至少一条指令,所述至少一条指令由所述处理器加载并执行以实现如上各个实施例所述的目标检测模型的训练方法。
本申请实施例还提供了一种计算机程序产品,该计算机程序产品存储有至少一条指令,所述至少一条指令由所述处理器加载并执行以实现如上各个实施例所述的目标检测模型的训练方法。
本申请实施例还提供了一种计算机程序产品或计算机程序,该计算机程序产品或计算机程序包括计算机指令,该计算机指令存储在计算机可读存储介质中。计算机设备的处理器从计算机可读存储介质读取该计算机指令,处理器执行该计算机指令,使得该计算机设备执行上述方面的各种可选实现方式中提供的目标检测模型的训练方法。
本领域技术人员应该可以意识到,在上述一个或多个示例中,本申请实施例所描述的功能可以用硬件、软件、固件或它们的任意组合来实现。当使用软件实现时,可以将这些功能存储在计算机可读存储介质中或者作为计算机可读存储介质上的一个或多个指令或代码进行传输。计算机可读存储介质包括计算机存储介质和通信介质,其中通信介质包括便于从一个地方向另一个地方传送计算机程序的任何介质。存储介质可以是通用或专用计算机能够存取的任何可用介质。
以上所述仅为本申请的可选实施例,并不用以限制本申请,凡在本申请的精神和原则之内,所作的任何修改、等同替换、改进等,均应包含在本申请的保护范围之内。

Claims (13)

1.一种目标检测模型的训练方法,其特征在于,所述方法包括:
基于原始样本图像生成第一样本图像,所述原始样本图像中包含待检测对象,所述第一样本图像包含至少两张子图像,所述子图像是对所述原始样本图像中包含的所述待检测对象进行裁剪得到;
基于所述第一样本图像预训练目标检测模型,预训练的目的为调整所述目标检测模型中特征提取网络的网络参数,且预训练阶段的模型任务为分类任务;
基于所述原始样本图像生成第二样本图像,所述第二样本图像包含至少两张所述原始样本图像;
基于所述第二样本图像对所述目标检测模型进行微调,其中,微调阶段的模型任务为分类任务及定位任务。
2.根据权利要求1所述的方法,其特征在于,所述基于所述第一样本图像预训练目标检测模型,包括:
将所述第一样本图像输入目标分类模型,得到所述目标分类模型输出的第一预测分类信息,所述第一预测分类信息为所述第一样本图像中各个所述待检测对象所属类别的预测概率,所述目标分类模型中包含所述特征提取网络;
基于所述第一预测分类信息和所述第一样本图像对应的第一标注分类信息,计算得到所述目标分类模型对应的第一分类损失,所述第一标注分类信息为所述第一样本图像中各个所述待检测对象所属类别的标注概率;
基于所述第一分类损失训练所述目标分类模型,得到所述目标检测模型中的所述特征提取网络。
3.根据权利要求2所述的方法,其特征在于,所述基于所述第一预测分类信息和所述第一样本图像对应的第一标注分类信息,计算得到所述目标分类模型对应的第一分类损失,包括:
基于同一所述子图像中包含的所述待检测对象对应的所述第一预测分类信息和所述第一标注分类信息,计算得到所述子图像对应的交叉熵损失;
将所述第一样本图像包含的各个所述子图像对应的所述交叉熵损失之和,确定为所述目标分类模型对应的所述第一分类损失,其中,不同子图像对应的所述交叉熵损失对应不同交叉熵损失权重,所述交叉熵损失权重由所述子图像占所述第一样本图像的百分比确定。
4.根据权利要求1至3任一所述的方法,其特征在于,所述基于所述第二样本图像对所述目标检测模型进行微调,包括:
将所述第二样本图像输入所述目标检测模型,得到所述目标检测模型输出的预测对象信息,所述预测对象信息包括第二预测分类信息和预测位置信息,所述第二预测分类信息为所述第二样本图像各个所述待检测对象所属类别的预测概率,所述预测位置信息指示各个所述待检测对象在所述第二样本图像中的预测位置区域;
基于所述预测对象信息和所述第二样本图像对应的标注对象信息,计算得到所述目标检测模型对应的目标损失,所述标注对象信息包括第二标注分类信息和标注位置信息,所述第二标注分类信息为所述第二样本图像中各个所述待检测对象所属类别的标注概率,所述标注位置信息指示各个所述待检测对象在所述第二样本图像中的标注位置区域;
根据所述目标损失训练所述目标检测模型。
5.根据权利要求4所述的方法,其特征在于,所述基于所述预测对象信息和所述第二样本图像对应的标注对象信息,计算得到所述目标检测模型对应的目标损失,包括:
基于所述第二预测分类信息和所述第二标注分类信息,计算得到所述目标检测模型对应的第二分类损失;
基于所述预测位置信息和所述标注位置信息,计算得到所述目标检测模型对应的定位损失;
将所述第二分类损失和所述定位损失之和,确定为所述目标检测模型对应的所述目标损失,其中,所述第二分类损失和所述定位损失对应不同损失权重。
6.根据权利要求4所述的方法,其特征在于,所述目标检测模型还包括特征融合网络和目标检测网络;
所述将所述第二样本图像输入所述目标检测模型,得到所述目标检测模型输出的预测对象信息,包括:
将所述第二样本图像输入所述特征提取网络,得到所述特征提取网络输出的n个第一样本特征图,不同第一样本特征图对应不同分辨率;
将n个所述第一样本特征图输入所述特征融合网络,得到所述特征融合网络输出的n个第二样本特征图,不同第二样本特征图对应不同分辨率,所述特征融合网络用于按照预设权重对n个所述第一样本特征图进行混合;
将n个所述第二样本特征图输入所述目标检测网络中,得到所述目标检测网络输出的所述预测对象信息。
7.根据权利要求6所述的方法,其特征在于,
所述目标检测模型采用yolo网络模型,所述yolo网络模型中的所述特征融合网络采用双向特征金字塔网络BiFPN。
8.根据权利要求1至3任一所述的方法,其特征在于,所述基于原始样本图像生成第一样本图像,包括:
基于所述待检测对象在所述原始样本图像中的位置对所述原始样本图像进行裁剪,得到所述子图像;
获取各个所述子图像对应的长宽比,所述长宽比指所述子图像的长边和宽边之间的比值;
基于所述长宽比对至少两张所述子图像进行拼接处理,得到所述第一样本图像。
9.根据权利要求8所述的方法,其特征在于,所述基于所述长宽比对至少两张所述子图像进行拼接处理,得到所述第一样本图像,包括:
按照所述长宽比将所述子图像划分为第一类子图像、第二类子图像和第三类子图像,所述第一类子图像对应的长宽比介于第一长宽比阈值和第二长宽比阈值之间,所述第二类子图像对应的长宽比小于所述第一长宽比阈值,所述第三类子图像对应的长宽比大于所述第一长宽比阈值;
从所述第一类子图像、所述第二类子图像和所述第三类子图像中选取至少两张所述子图像;
基于至少两张所述子图像生成所述第一样本图像。
10.根据权利要求9所述的方法,其特征在于,所述第一样本图像中包含四张所述子图像,且所述第一样本图像对应预设图像尺寸;
所述基于至少两张所述子图像生成所述第一样本图像,包括:
对四张所述子图像进行拼接处理,得到第一原始样本图像,四张所述子图像包括两张所述第一类子图像、一张所述第二类子图像和一张所述第三类子图像;
按照所述预设图像尺寸对所述第一原始样本图像进行像素填充处理,得到所述第一样本图像。
11.一种目标检测模型的训练装置,其特征在于,所述装置包括:
第一生成模块,用于基于原始样本图像生成第一样本图像,所述原始样本图像中包含待检测对象,所述第一样本图像包含至少两张子图像,所述子图像是对所述原始样本图像中包含的所述待检测对象进行裁剪得到;
预训练模块,用于基于所述第一样本图像预训练目标检测模型,预训练的目的为调整所述目标检测模型中特征提取网络的网络参数,且预训练阶段的模型任务为分类任务;
第二生成模块,用于基于所述原始样本图像生成第二样本图像,所述第二样本图像包含至少两张所述原始样本图像;
训练模块,用于基于所述第二样本图像对所述目标检测模型进行微调,其中,微调阶段的模型任务为分类任务及定位任务。
12.一种计算机设备,其特征在于,所述计算机设备包括处理器和存储器,所述存储器中存储有至少一条指令,所述至少一条指令由所述处理器加载并执行以实现如权利要求1至10任一所述的目标检测模型的训练方法。
13.一种计算机可读存储介质,其特征在于,所述计算机可读存储介质中存储有至少一条指令,所述至少一条指令由处理器加载并执行以实现如权利要求1至10任一所述的目标检测模型的训练方法。
CN202011625437.6A 2020-12-31 2020-12-31 目标检测模型的训练方法、装置、计算机设备及介质 Active CN112734641B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202011625437.6A CN112734641B (zh) 2020-12-31 2020-12-31 目标检测模型的训练方法、装置、计算机设备及介质

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202011625437.6A CN112734641B (zh) 2020-12-31 2020-12-31 目标检测模型的训练方法、装置、计算机设备及介质

Publications (2)

Publication Number Publication Date
CN112734641A CN112734641A (zh) 2021-04-30
CN112734641B true CN112734641B (zh) 2024-05-31

Family

ID=75609826

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202011625437.6A Active CN112734641B (zh) 2020-12-31 2020-12-31 目标检测模型的训练方法、装置、计算机设备及介质

Country Status (1)

Country Link
CN (1) CN112734641B (zh)

Families Citing this family (16)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113066078A (zh) * 2021-04-15 2021-07-02 上海找钢网信息科技股份有限公司 管状物计数、模型的训练方法、设备及存储介质
CN113065533B (zh) * 2021-06-01 2021-11-02 北京达佳互联信息技术有限公司 一种特征提取模型生成方法、装置、电子设备和存储介质
CN113361588B (zh) * 2021-06-03 2024-06-25 北京文安智能技术股份有限公司 基于图像数据增强的图像训练集生成方法和模型训练方法
CN113269267B (zh) * 2021-06-15 2024-04-26 苏州挚途科技有限公司 目标检测模型的训练方法、目标检测方法和装置
KR20220169373A (ko) * 2021-06-17 2022-12-27 센스타임 인터내셔널 피티이. 리미티드. 타겟 검출 방법들, 장치들, 전자 디바이스들 및 컴퓨터 판독가능한 저장 매체
CN113505800A (zh) * 2021-06-30 2021-10-15 深圳市慧鲤科技有限公司 图像处理方法及其模型的训练方法和装置、设备、介质
CN113361487B (zh) * 2021-07-09 2024-09-06 无锡时代天使医疗器械科技有限公司 异物检测方法、装置、设备及计算机可读存储介质
CN113657269A (zh) * 2021-08-13 2021-11-16 北京百度网讯科技有限公司 人脸识别模型的训练方法、装置及计算机程序产品
CN113947775A (zh) * 2021-09-30 2022-01-18 北京三快在线科技有限公司 识别证照图像完整性的方法、装置、设备及存储介质
CN114140637B (zh) * 2021-10-21 2023-09-12 阿里巴巴达摩院(杭州)科技有限公司 图像分类方法、存储介质和电子设备
CN114004840A (zh) * 2021-10-29 2022-02-01 北京百度网讯科技有限公司 图像处理方法、训练方法、检测方法、装置、设备及介质
CN114255389A (zh) * 2021-11-15 2022-03-29 浙江时空道宇科技有限公司 一种目标对象检测方法、装置、设备和存储介质
CN114387266A (zh) * 2022-01-19 2022-04-22 北京大学第一医院 结核杆菌检测模型的训练方法、装置、设备及存储介质
CN114587416A (zh) * 2022-03-10 2022-06-07 山东大学齐鲁医院 基于深度学习多目标检测的胃肠道粘膜下肿瘤诊断系统
CN114417046B (zh) * 2022-03-31 2022-07-12 腾讯科技(深圳)有限公司 特征提取模型的训练方法、图像检索方法、装置及设备
CN114842454B (zh) * 2022-06-27 2022-09-13 小米汽车科技有限公司 障碍物检测方法、装置、设备、存储介质、芯片及车辆

Citations (16)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN101438249A (zh) * 2006-05-07 2009-05-20 应用材料股份有限公司 用于错误诊断的多种错误特征
CN105677496A (zh) * 2016-01-12 2016-06-15 电子科技大学 基于两层神经网络的测试性指标分配方法
CN107368787A (zh) * 2017-06-16 2017-11-21 长安大学 一种面向深度智驾应用的交通标志识别算法
CN108182456A (zh) * 2018-01-23 2018-06-19 哈工大机器人(合肥)国际创新研究院 一种基于深度学习的目标检测模型及其训练方法
CN108304873A (zh) * 2018-01-30 2018-07-20 深圳市国脉畅行科技股份有限公司 基于高分辨率光学卫星遥感影像的目标检测方法及其系统
CN109284704A (zh) * 2018-09-07 2019-01-29 中国电子科技集团公司第三十八研究所 基于cnn的复杂背景sar车辆目标检测方法
CN110096964A (zh) * 2019-04-08 2019-08-06 厦门美图之家科技有限公司 一种生成图像识别模型的方法
CN110187334A (zh) * 2019-05-28 2019-08-30 深圳大学 一种目标监控方法、装置及计算机可读存储介质
CN110263697A (zh) * 2019-06-17 2019-09-20 哈尔滨工业大学(深圳) 基于无监督学习的行人重识别方法、装置及介质
CN110363138A (zh) * 2019-07-12 2019-10-22 腾讯科技(深圳)有限公司 模型训练方法、图像处理方法、装置、终端及存储介质
WO2020087974A1 (zh) * 2018-10-30 2020-05-07 北京字节跳动网络技术有限公司 生成模型的方法和装置
CN111488930A (zh) * 2020-04-09 2020-08-04 北京市商汤科技开发有限公司 分类网络的训练方法、目标检测方法、装置和电子设备
WO2020164282A1 (zh) * 2019-02-14 2020-08-20 平安科技(深圳)有限公司 基于yolo的图像目标识别方法、装置、电子设备和存储介质
CN111815592A (zh) * 2020-06-29 2020-10-23 郑州大学 一种肺结节检测模型的训练方法
CN111832443A (zh) * 2020-06-28 2020-10-27 华中科技大学 一种施工违规行为检测模型的构建方法及其应用
CN112070074A (zh) * 2020-11-12 2020-12-11 中电科新型智慧城市研究院有限公司 物体检测方法、装置、终端设备和存储介质

Patent Citations (16)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN101438249A (zh) * 2006-05-07 2009-05-20 应用材料股份有限公司 用于错误诊断的多种错误特征
CN105677496A (zh) * 2016-01-12 2016-06-15 电子科技大学 基于两层神经网络的测试性指标分配方法
CN107368787A (zh) * 2017-06-16 2017-11-21 长安大学 一种面向深度智驾应用的交通标志识别算法
CN108182456A (zh) * 2018-01-23 2018-06-19 哈工大机器人(合肥)国际创新研究院 一种基于深度学习的目标检测模型及其训练方法
CN108304873A (zh) * 2018-01-30 2018-07-20 深圳市国脉畅行科技股份有限公司 基于高分辨率光学卫星遥感影像的目标检测方法及其系统
CN109284704A (zh) * 2018-09-07 2019-01-29 中国电子科技集团公司第三十八研究所 基于cnn的复杂背景sar车辆目标检测方法
WO2020087974A1 (zh) * 2018-10-30 2020-05-07 北京字节跳动网络技术有限公司 生成模型的方法和装置
WO2020164282A1 (zh) * 2019-02-14 2020-08-20 平安科技(深圳)有限公司 基于yolo的图像目标识别方法、装置、电子设备和存储介质
CN110096964A (zh) * 2019-04-08 2019-08-06 厦门美图之家科技有限公司 一种生成图像识别模型的方法
CN110187334A (zh) * 2019-05-28 2019-08-30 深圳大学 一种目标监控方法、装置及计算机可读存储介质
CN110263697A (zh) * 2019-06-17 2019-09-20 哈尔滨工业大学(深圳) 基于无监督学习的行人重识别方法、装置及介质
CN110363138A (zh) * 2019-07-12 2019-10-22 腾讯科技(深圳)有限公司 模型训练方法、图像处理方法、装置、终端及存储介质
CN111488930A (zh) * 2020-04-09 2020-08-04 北京市商汤科技开发有限公司 分类网络的训练方法、目标检测方法、装置和电子设备
CN111832443A (zh) * 2020-06-28 2020-10-27 华中科技大学 一种施工违规行为检测模型的构建方法及其应用
CN111815592A (zh) * 2020-06-29 2020-10-23 郑州大学 一种肺结节检测模型的训练方法
CN112070074A (zh) * 2020-11-12 2020-12-11 中电科新型智慧城市研究院有限公司 物体检测方法、装置、终端设备和存储介质

Non-Patent Citations (5)

* Cited by examiner, † Cited by third party
Title
Walk and Learn: Facial Attribute Representation Learning from Egocentric Video and Contextual Data;Jing Wang;《2016 IEEE Conference on Computer Vision and Pattern Recognition (CVPR)》;20161212;全文 *
基于度量学习的小样本零器件表面缺陷检测;于重重;《仪器仪表学报》;20200731;全文 *
基于深度学习的目标检测框架进展研究;寇大磊;权冀川;张仲伟;;计算机工程与应用;20190326(第11期);全文 *
基于深度学习的航空对地小目标检测;梁华;宋玉龙;钱锋;宋策;;液晶与显示;20180915(第09期);全文 *
基于深度学习的行人检测及其应用研究;郑丽琴;《中国优秀硕士论文全文数据库》;20190715;全文 *

Also Published As

Publication number Publication date
CN112734641A (zh) 2021-04-30

Similar Documents

Publication Publication Date Title
CN112734641B (zh) 目标检测模型的训练方法、装置、计算机设备及介质
CN114155543B (zh) 神经网络训练方法、文档图像理解方法、装置和设备
CN114202672A (zh) 一种基于注意力机制的小目标检测方法
CN109934792B (zh) 电子装置及其控制方法
JP2021508123A (ja) リモートセンシング画像認識方法、装置、記憶媒体及び電子機器
EP2458872B1 (en) Adaptive method and system for encoding digital images for the internet
WO2021147817A1 (zh) 文本定位方法和系统以及文本定位模型训练方法和系统
CN112215171B (zh) 目标检测方法、装置、设备及计算机可读存储介质
WO2021090771A1 (en) Method, apparatus and system for training a neural network, and storage medium storing instructions
EP4404148A1 (en) Image processing method and apparatus, and computer-readable storage medium
CN110222726A (zh) 图像处理方法、装置及电子设备
CN116645592B (zh) 一种基于图像处理的裂缝检测方法和存储介质
AU2021354030B2 (en) Processing images using self-attention based neural networks
CN113516666A (zh) 图像裁剪方法、装置、计算机设备及存储介质
CN112242002B (zh) 基于深度学习的物体识别和全景漫游方法
CN112561801A (zh) 基于se-fpn的目标检测模型训练方法、目标检测方法及装置
CN113033516A (zh) 对象识别统计方法及装置、电子设备、存储介质
CN111292377A (zh) 目标检测方法、装置、计算机设备和存储介质
CN116071300A (zh) 一种基于上下文特征融合的细胞核分割方法及相关设备
CN113762327A (zh) 机器学习方法、机器学习系统以及非暂态电脑可读取媒体
CN110969641A (zh) 图像处理方法和装置
CN114998672B (zh) 基于元学习的小样本目标检测方法与装置
CN114511862B (zh) 表格识别方法、装置及电子设备
CN116246064A (zh) 一种多尺度空间特征增强方法及装置
CN115933949A (zh) 一种坐标转换方法、装置、电子设备及存储介质

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant