CN114596497A - 目标检测模型的训练方法、目标检测方法、装置及设备 - Google Patents

目标检测模型的训练方法、目标检测方法、装置及设备 Download PDF

Info

Publication number
CN114596497A
CN114596497A CN202210495852.7A CN202210495852A CN114596497A CN 114596497 A CN114596497 A CN 114596497A CN 202210495852 A CN202210495852 A CN 202210495852A CN 114596497 A CN114596497 A CN 114596497A
Authority
CN
China
Prior art keywords
network
student
teacher
training
detection result
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Granted
Application number
CN202210495852.7A
Other languages
English (en)
Other versions
CN114596497B (zh
Inventor
陈博
高原
白锦峰
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Beijing Century TAL Education Technology Co Ltd
Original Assignee
Beijing Century TAL Education Technology Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Beijing Century TAL Education Technology Co Ltd filed Critical Beijing Century TAL Education Technology Co Ltd
Priority to CN202210495852.7A priority Critical patent/CN114596497B/zh
Publication of CN114596497A publication Critical patent/CN114596497A/zh
Application granted granted Critical
Publication of CN114596497B publication Critical patent/CN114596497B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques

Landscapes

  • Engineering & Computer Science (AREA)
  • Data Mining & Analysis (AREA)
  • Theoretical Computer Science (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Artificial Intelligence (AREA)
  • Evolutionary Biology (AREA)
  • Evolutionary Computation (AREA)
  • Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Image Analysis (AREA)

Abstract

本公开提供一种目标检测模型的训练方法、目标检测方法、装置及设备,其中该训练方法包括:获取携带有标注信息的图像样本;将图像样本输入至骨干网络,得到初步特征;将初步特征输入至学生预测网络,得到第一检测结果;以及,将初步特征输入至预设的教师预测网络,得到第二检测结果;基于第一检测结果、第二检测结果和标注信息对骨干网络、学生预测网络和教师预测网络进行训练,直至达到预设条件时停止训练;其中,教师预测网络用于在训练过程中为学生预测网络提供指定信息,指定信息是教师预测网络生成的与第二检测结果相关的信息;基于停止训练后的骨干网络及学生预测网络得到训练好的目标检测模型。本公开能够有效提升模型训练效率。

Description

目标检测模型的训练方法、目标检测方法、装置及设备
技术领域
本公开涉及人工智能技术领域,尤其涉及目标检测模型的训练方法、目标检测方法、装置及设备。
背景技术
目标检测(Object Detection)是计算机视觉中一项非常重要的研究课题。目标检测的主要目的在于定位与识别出图像中的感兴趣目标。现有主流的目标检测模型的规模相对较大(也即,重量级模型),通常需要借助诸如服务器等数据处理能力较强的云端设备运行,而在很多实际应用场景中需要诸如移动端设备(如手机)等数据处理能力较弱的设备执行目标检测任务,此类设备对其上运行的模型规模有较为严格的限制,因此需要规模相对较小的目标检测模型,也即,需要轻量级模型执行目标检测任务。
基于此,现有技术大多是预先训练一个规模较大的模型(教师模型),之后再采用训练好的教师模型对轻量级模型(学生模型)进行知识蒸馏(Knowledge Distillation),以得到训练好的学生模型,以便于后续将轻量级的学生模型应用于数据处理能力较弱的设备。然而,这种方式属于两阶段训练方式,较为耗时费力,模型训练效率低下。
发明内容
为了解决上述技术问题或者至少部分地解决上述技术问题,本公开提供了一种目标检测模型的训练方法、目标检测方法、装置及设备。
根据本公开的一方面,提供了一种目标检测模型的训练方法,其中,所述目标检测模型包括依次连接的骨干网络和学生预测网络;所述方法包括:获取携带有标注信息的图像样本;将所述图像样本输入至所述骨干网络,得到初步特征;将所述初步特征输入至所述学生预测网络,得到第一检测结果;以及,将所述初步特征输入至预设的教师预测网络,得到第二检测结果;基于所述第一检测结果、所述第二检测结果和所述标注信息对所述骨干网络、所述学生预测网络和所述教师预测网络进行训练,直至达到预设条件时停止训练;其中,所述教师预测网络用于在训练过程中为所述学生预测网络提供指定信息,所述指定信息是所述教师预测网络生成的与所述第二检测结果相关的信息;基于停止训练后的所述骨干网络及所述学生预测网络得到训练好的目标检测模型。
根据本公开的另一方面,提供了一种目标检测方法,包括:获取待检测图像;通过预先训练好的目标检测模型对所述待检测图像进行处理,得到所述待检测图像中所包含的目标对象的检测结果;其中,所述目标检测模型是采用上述目标检测模型的训练方法训练得到的。
根据本公开的另一方面,提供了一种目标检测模型的训练装置,其中,所述目标检测模型包括依次连接的骨干网络和学生预测网络;所述装置包括:样本获取模块,用于获取携带有标注信息的图像样本;初步特征获取模块,用于将所述图像样本输入至所述骨干网络,得到初步特征;检测结果获取模块,用于将所述初步特征输入至所述学生预测网络,得到第一检测结果;以及,将所述初步特征输入至预设的教师预测网络,得到第二检测结果;训练模块,用于基于所述第一检测结果、所述第二检测结果和所述标注信息对所述骨干网络、所述学生预测网络和所述教师预测网络进行训练,直至达到预设条件时停止训练;其中,所述教师预测网络用于在训练过程中为所述学生预测网络提供指定信息,所述指定信息是所述教师预测网络生成的与所述第二检测结果相关的信息;模型获得模块,用于基于停止训练后的所述骨干网络及所述学生预测网络得到训练好的目标检测模型。
根据本公开的另一方面,提供了一种目标检测装置,包括:图像获取模块,用于获取待检测图像;目标检测模块,用于通过预先训练好的目标检测模型对所述待检测图像进行处理,得到所述待检测图像中所包含的目标对象的检测结果;其中,所述目标检测模型是采用上述目标检测模型的训练方法训练得到的。
根据本公开的另一方面,提供了一种电子设备,包括:处理器;以及存储程序的存储器,其中,所述程序包括指令,所述指令在由所述处理器执行时使所述处理器执行根据上述目标检测模型的训练方法或者上述目标检测方法。
根据本公开的另一方面,提供了一种计算机可读存储介质,所述存储介质存储有计算机程序,所述计算机程序用于执行上述目标检测模型的训练方法或者目标检测方法。
本公开实施例中提供的上述技术方案,目标检测模型包括依次连接的骨干网络和学生预测网络,在训练时,能够将图像样本(携带有标注信息)输入至骨干网络,得到初步特征;不仅将初步特征输入至学生预测网络,得到第一检测结果;还会将初步特征输入至预设的教师预测网络,得到第二检测结果;基于第一检测结果、第二检测结果和标注信息对骨干网络、学生预测网络和教师预测网络进行训练,直至达到预设条件时停止训练;教师预测网络可以在训练过程中为学生预测网络提供指定信息,最后便可基于停止训练后的骨干网络及学生预测网络得到训练好的目标检测模型。在上述方式中,教师预测网络和学生预测网络共用骨干网络,且无需预先训练教师预测网络,而是可以对教师预测网络和学生预测网络同时训练,在训练过程中教师预测网络给学生预测网络提供指定信息,通过一阶段训练方式即可得到训练好的目标检测模型,相比于现有的两阶段训练方式,本公开实施例提供的上述方式能够有效提升模型训练效率。
应当理解,本部分所描述的内容并非旨在标识本公开的实施例的关键或重要特征,也不用于限制本公开的范围。本公开的其它特征将通过以下的说明书而变得容易理解。
附图说明
此处的附图被并入说明书中并构成本说明书的一部分,示出了符合本公开的实施例,并与说明书一起用于解释本公开的原理。
为了更清楚地说明本公开实施例或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,对于本领域普通技术人员而言,在不付出创造性劳动性的前提下,还可以根据这些附图获得其他的附图。
图1为本公开实施例提供的一种目标检测模型的训练方法的流程示意图;
图2为本公开实施例提供的一种目标检测模型的结构示意图;
图3为本公开实施例提供的一种训练时的网络结构示意图;
图4为本公开实施例提供的另一种训练时的网络结构示意图;
图5为本公开实施例提供的一种网络训练示意图;
图6为本公开实施例提供的一种目标检测方法的流程示意图;
图7为本公开实施例提供的一种目标检测模型的训练装置的结构示意图;
图8为本公开实施例提供的一种目标检测装置的结构示意图;
图9为本公开实施例提供的一种电子设备的结构示意图。
具体实施方式
下面将参照附图更详细地描述本公开的实施例。虽然附图中显示了本公开的某些实施例,然而应当理解的是,本公开可以通过各种形式来实现,而且不应该被解释为限于这里阐述的实施例,相反提供这些实施例是为了更加透彻和完整地理解本公开。应当理解的是,本公开的附图及实施例仅用于示例性作用,并非用于限制本公开的保护范围。
应当理解,本公开的方法实施方式中记载的各个步骤可以按照不同的顺序执行,和/或并行执行。此外,方法实施方式可以包括附加的步骤和/或省略执行示出的步骤。本公开的范围在此方面不受限制。
本公开使用的术语“包括”及其变形是开放性包括,即“包括但不限于”。术语“基于”是“至少部分地基于”。术语“一个实施例”表示“至少一个实施例”;术语“另一实施例”表示“至少一个另外的实施例”;术语“一些实施例”表示“至少一些实施例”。其他术语的相关定义将在下文描述中给出。需要注意,本公开中提及的“第一”、“第二”等概念仅用于对不同的装置、模块或单元进行区分,并非用于限定这些装置、模块或单元所执行的功能的顺序或者相互依存关系。
需要注意,本公开中提及的“一个”、“多个”的修饰是示意性而非限制性的,本领域技术人员应当理解,除非在上下文另有明确指出,否则应该理解为“一个或多个”。
为了能够更清楚地理解本公开的上述目的、特征和优点,下面将对本公开的方案进行进一步描述。需要说明的是,在不冲突的情况下,本公开的实施例及实施例中的特征可以相互组合。
目标检测是计算机视觉中的重要任务,也是其它诸如目标跟踪、姿态估计等任务的基础,可以广泛应用于诸如监控、无人驾驶、物流等诸多场景。但是在很多场景下,受限于设备自身性能,需要采用轻量级的目标检测模型执行目标检测任务。现有大多采用两阶段的训练方式,优先训练一个重量级的教师模型,再用训练好的教师模型训练轻量级的学生模型,且在训练学生模型的过程中会固定教师模型的权重,利用教师模型的输出作为监督信息去指导学生模型学习,从而达到提升学生模型性能的目的。考虑到这种两阶段训练方式费时费力,效率低下,为改善此问题,本公开实施例提供了一种目标检测模型的训练方法、目标检测方法、装置及设备,以下进行详细阐述说明。
图1为本公开实施例提供的一种目标检测模型的训练方法的流程示意图,该方法可以由目标检测模型的训练装置执行,其中该装置可以采用软件和/或硬件实现,一般可集成在电子设备中。上述目标检测模型包括依次连接的骨干网络和学生预测网络,本公开实施例提供的目标检测模型可以针对单一特定类型目标进行检测,诸如仅检测人物或者车辆等,也可以是针对多个不同特定类型目标进行检测,诸如可以同时检测人物、车辆、特定物品、特定动物等,在此不进行限定。如图1所示,该方法主要包括如下步骤S102~步骤S110:
步骤S102,获取携带有标注信息的图像样本。
图像样本的数量可以为多个,图像样本包含有待检测的目标对象,该目标对象可以为需要目标检测模型检测到的特定类型的目标,诸如可以为人物、车辆、特定物品或特定动物等一种或多种,在此不进行限定。标注信息即为指示图像样本中所包含目标的信息,诸如标注信息包括图像样本中所包含目标的位置及类别。
步骤S104,将图像样本输入至骨干网络,得到初步特征。本公开实施例对骨干网络的结构不进行限制,任何可以对图像进行初步特征提取的网络结构均可。在一些实施方式中,为了使目标检测模型整体为轻量级结构,以便于目标检测模型可以直接运行在对模型体量有限制的设备,因此可以选择轻量级的网络结构作为骨干网络,示例性地,该骨干网络可以为轻量模型ShuffleNetv2 0.5x。
步骤S106,将初步特征输入至学生预测网络,得到第一检测结果;以及,将初步特征输入至预设的教师预测网络,得到第二检测结果。
本公开实施例在此对学生预测网络和教师预测网络的结构不进行限制,但需要说明的是,教师预测网络的网络结构比学生预测网络的网络结构复杂,诸如,教师预测网络的网络层数量大于学生预测网络的网络层数量,教师预测网络的参数量大于学生预测网络的参数量。因此教师预测网络的性能通常优于学生预测网络的性能,从而可以在训练过程中为学生预测网络提供一定的指导信息。
在一些实施方式中,第一检测结果包括第一分类分数图和第一回归分数图,在此基础上还可以进一步包括基于第一分类分数图和第一回归分数图所得的目标类别和目标位置,第二检测结果包括第二分类分数图和第二回归分数图,在此基础上还可以进一步包括基于第二分类分数图和第二回归分数图所得的目标类别和目标位置。上述分类分数图也可简称为分类图,回归分数图也可简称为回归图。在实际应用中,预测网络通过对图像样本进行目标检测,可以得到每个像素点对应的分类分数以及回归分数,分数可用于表征该像素点属于某指定目标的概率值。
步骤S108,基于第一检测结果、第二检测结果和标注信息对骨干网络、学生预测网络和教师预测网络进行训练,直至达到预设条件时停止训练;其中,教师预测网络用于在训练过程中为学生预测网络提供指定信息。
教师预测网络给学生预测网络提供指定信息(即为指导信息)的过程也即知识蒸馏过程。该指定信息为教师预测网络生成的与第二检测结果相关的信息,该指定信息可以包括第二检测结果、基于第二检测结果生成的其它信息、生成第二检测结果的过程中产生的中间信息中的一种或多种。
相比于相关技术中需要预先训练教师模型,再采用训练好的教师模型训练学生模型,本公开实施例采用的上述方式无需提前训练教师模型,而是在训练学生模型(目标检测模型)的基础上,在目标检测模型的骨干网络之后额外加了一个教师预测网络。也可以理解为,本公开实施例的教师模型包括骨干网络和教师预测网络,该教师模型与学生模型共用学生模型的骨干网络。
本公开实施例可以基于第一检测结果、第二检测结果和标注信息对骨干网络、学生预测网络和教师预测网络进行同时训练,由于教师预测网络的性能通常优于学生预测网络,教师预测网络相比于学生预测网络所得的检测结果更为准确,而且在训练过程中还可以优先输出符合期望的检测结果,此时可停止调整教师预测网络的参数,之后由停止调参的教师预测网络继续指导学生预测网络进行训练,直至学生预测网络也可以输出符合期望结果。由于可以对学生预测网络和教师预测网络同时训练,这种一阶段训练方式能够有效提升模型训练效率。
步骤S110,基于停止训练后的骨干网络及学生预测网络得到训练好的目标检测模型。
训练好的目标检测模型包含停止训练(停止调整参数)的骨干网络和学生预测网络,可直接应用于目标检测任务,尤其适用于对模型规模有限制的设备。
在上述方式中,教师预测网络和学生预测网络共用骨干网络,且无需预先训练教师预测网络,而是可以对教师预测网络和学生预测网络同时训练,在训练过程中教师预测网络给学生预测网络提供指定信息,通过一阶段训练方式即可得到训练好的目标检测模型,相比于现有的两阶段训练方式,本公开实施例提供的上述方式能够有效提升模型训练效率。
应当说明的是,本公开实施例的目标检测模型仅包括骨干网络和学生预测网络,示例性地,可以参照图2所示的目标检测模型的结构示意图。但是在训练目标检测模型时,额外引入了教师预测网络,如图3所示的一种训练时的网络结构示意图,示意出了在训练过程中目标检测模型与教师预测网络的结构关系,该教师预测网络与学生预测网络共用骨干网络。教师预测网络和学生预测网络可以同时开始训练,也即可以同时开始调整参数,但是通常情况下,教师预测网络停止调整参数的时间优先于学生预测网络调整参数的时间,也即,教师预测网络会被优先训练好。在学生预测网络训练的过程中,教师预测网络可以一直为学生预测网络提供指定信息。在训练结束后,直接采用包括骨干网络和学生预测网络的目标检测模型执行目标检测任务即可。由于学生预测网络更为轻量,因此可以便捷应用于诸如手机等对其上运行的模型规模有限制的设备。
在一些实施示例中,在图3的基础上,如图4所示的另一种训练时的网络结构示意图,学生预测网络包括依次连接的学生金字塔网络和学生预测头,教师预测网络包括依次连接的教师金字塔网络和教师预测头。在具体实施时,学生金字塔网络和教师金字塔网络的网络结构相同,网络参数不同。示例性地,学生金字塔网络和教师金字塔网络均可采用CSPAN结构实现,上述金字塔网络又可称为特征金字塔,能够加强特征提取,对骨干网络输入的特征进行融合,并有效提取出多尺度的深层特征,提取到的特征信息更为丰富全面,有助于提升后续目标检测效果。本公开实施例对金字塔网络的层数不进行限制。
在一些具体的实施示例中,学生预测头包括第一指定数量个卷积层;教师预测头包括第二指定数量个卷积层;且第二指定数量大于第一指定数量。为了提升学生预测头的处理效率,学生预测头中的卷积层为深度可分离卷积层。深度可分离卷积层是对常规的卷积计算进行改进所得的算法,其通过拆分空间维度和通道(深度)维度的相关性,减少了卷积计算所需要的参数个数,可以有效提升卷积核参数的使用效率,也有助于使学生预测头所需的运算量更小。教师预测头中的卷积层可以直接采用常规的卷积层即可。在实际应用中,学生预测头中的卷积层和教师预测头中的卷积层均可为分类与回归共享权重的卷积层。在一些具体的实施示例中,学生预测头包含分类与回归共享权重的两层深度可分离卷积层,两层深度可分离卷积层为并列形式;教师预测头包含2个并列的4层分类与回归共享权重的常规卷积层。学生预测头比教师预测头的结构更为精简轻量,而教师预测头比学生预测头具有更强的表征能力,能够为轻量的学生预测头提供一定的指导信息。
在图4所示的网络结构的基础上,将初步特征输入至学生预测网络,得到第一检测结果的步骤,可以参照如下步骤(1)~步骤(3)实现:
步骤(1),将初步特征分别输入至学生金字塔网络,得到学生金字塔网络输出的第一多尺度特征。示例性地,图像样本的尺寸为320*320*3,将图像样本输入至骨干网络,经骨干网络的多个阶段处理,得到3个特征图F1,F2和F3,该特征图大小分别为40x40x116,20x20x232,10x10x464,该3个特征图均为初始特征。接着将这3个特征图送入学生金字塔网络(也即,特征金字塔),得到不同层级的特征
Figure 907877DEST_PATH_IMAGE001
,即为第一多尺度特征,大小分别为40x40x96,20x20x96,10x10x96,5x5x96。
步骤(2),将第一多尺度特征输入至学生预测头,得到学生预测头输出的第一分类分数图和第一回归分数图。学生预测头通过对第一多尺度特征进行解析,可得到第一分类分数图和第一回归分数图。
步骤(3),基于第一分类分数图和第一回归分数图得到第一检测结果。在一些具体的实施示例中,可以直接将第一分类分数图和第一回归分数图作为第一检测结果,在另一些实施示例中,还可以进一步基于第一分类分数图和第一回归分数图得到图像样本中的目标位置及类别,并将第一分类分数图、第一回归分数图和最后所得的目标位置及类别均作为第一检测结果。
相应地,将初步特征输入至教师预测网络,得到第二检测结果的步骤,可以参照如下步骤1~步骤4实现:
步骤1,将初步特征输入至教师金字塔网络,得到教师金字塔网络输出的第二多尺度特征。诸如,将上述骨干网络输出的初步特征F1,F2和F3输入至教师金字塔网络,得到不同层级的特征
Figure 681929DEST_PATH_IMAGE002
,即为第二多尺度特征。
步骤2,将第一多尺度特征和第二多尺度特征进行拼接,得到多尺度拼接特征。示例性地,将
Figure 189134DEST_PATH_IMAGE003
Figure 260995DEST_PATH_IMAGE004
拼接后,得到多尺度拼接特征P 1P 2P 3P 4
步骤3,将多尺度拼接特征输入至教师预测头,得到教师预测头输出的第二分类分数图和第二回归分数图;为了提升教师预测网络与学生预测网络的联合训练效果,可以将第一多尺度特征和第二多尺度特征进行拼接,将多尺度拼接特征输入给教师预测头,通过这种方式,学生金字塔网络也会对教师预测头的输出结果产生影响,有助于联合训练。
步骤4,基于第二分类分数图和第二回归分数图得到第二检测结果。
与第一检测结果类似,在一些具体的实施示例中,可以直接将第二分类分数图和第二回归分数图作为第二检测结果,在另一些实施示例中,还可以进一步基于第二分类分数图和第二回归分数图得到图像样本中的目标位置及类别,并将第二分类分数图、第二回归分数图和最后所得的目标位置及类别均作为第二检测结果。
教师预测网络在训练过程中会给学生预测网络提供与第二检测结果相关的指定信息(也即指导信息或监督信息),该过程也可称为知识蒸馏过程,以便于让学生预测网络模仿教师预测网络。考虑到相关技术中大多仅是借鉴分类模型的相关方法进行知识蒸馏,也即,仅是将教师模型的分类输出(诸如分类图或者分类层前的值)作为指导信息来指导学生模型学习,通过让学生模型去模仿教师模型的输出来提升学生模型的性能,但这种方式更侧重于整体比较以及结果的输出而忽略局部细节,教师模型对学生模型的指导效果不佳。为了能够使教师预测网络更好地指导学生预测网络,在一些具体的实施示例中,教师预测网络在训练过程中为学生预测网络提供的指定信息包括:基于第二检测结果得到的标签分配信息和/或特征概率分布图,在一些具体的实施示例中,可以首先基于初步特征得到包括第二分类分数图和第二回归分数图的第二检测结果,然后基于第二检测结果得到上述指定信息。以下对标签分配信息和特征概率分布图分别进行详细阐述说明:
指定信息包括基于教师预测网络获取的标签分配信息,该标签分配信息可以为标签分配策略(可简称为标签分配),也可以为基于标签分配策略得到的标签分配结果。标签分配是指训练过程中对分类分数图和回归分数图上的像素点(先验点)进行正、负样本的分配。正样本也即对应目标所在位置,反之即为负样本。在本公开实施例中,主要基于教师预测网络得到标签分配信息,具体而言,基于教师预测网络输出的第二检测结果(第二分类分数图和第二回归分数图)以及标注信息(真实标记)生成代价矩阵,在代价矩阵中每个真实标记对应的前N个结果(诸如前10个结果)作为正样本,其余作为负样本。示例性地,以第二分类分数图为例,一行有M个点,每个点都对应分数,基于代价矩阵选取代价值最小的前10个点作为最有可能匹配目标的点,将这些点作为正样本,其余作为负样本。以上仅为示例,不应当被视为限制。在选取好正负样本之后,即可将样本对应的分类分数/回归分数送入相应的损失函数(诸如分类损失函数、回归损失函数)分别计算损失值,通过损失值的大小来反向调整网络参数。上述基于损失函数的动态分配策略在性能上具有较为明显的优势,而轻量级的学生预测网络表征能力相对较弱,预测结果可能不够准确,损失计算不稳定。容易造成次优的标签分配,因此本公开实施例通过引入重量级的教师预测网络来引导学生预测网络的标签分配,可以实现更优的标签分配结果,提升训练的稳定性与最终的模型性能。在一些具体的实施方式中,可以仅由教师预测网络生成标签分配信息,学生预测网络直接利用教师预测网络的标签分配信息即可。
指定信息包括基于教师预测网络获取的特征概率分布图,在本公开实施例中,通过教师预测网络得到包含有第二分类分数图的第二检测结果时,可以对第二分类分数图进行归一化处理,得到教师预测网络对应的特征概率分布图(也即第二特征概率分布图)。通过第二特征概率分布图指导学生预测网络,可以理解的是,现有的主流蒸馏方法都是对分类层的值或特征图的向量进行蒸馏,将每个位置视为同等重要,这种做法对于目标检测这类密集预测并非最优方式,因此本公开实施例采取将分类分数图进行归一化为概率图的方式进行蒸馏,以突出代表性特征的重要性。同样,学生预测网络得到包含有第一分类分数图的第一检测结果时,可以对第一分类分数图进行归一化处理,得到学生预测网络对应的特征概率分布图(也即第一特征概率分布图)。通过比对第二特征概率分布图和第一特征概率分布图,即可确定蒸馏损失,从而根据蒸馏损失调整网络参数,以逐渐指导学生预测网络模仿教师预测网络的特征概率分布图。在一些具体的实施示例中,可以采用softmax算法对分类分数图进行归一化处理,得到特征概率分布图,之后便可通过KL散度对特征概率分布图进行蒸馏。具体的,归一化公式可参照如下实现:
Figure 170045DEST_PATH_IMAGE005
其中,c=1,2,3…C,具体为通道索引;i为对应通道的空间位置,T为温度超参;W为图像宽度,H为图像高度;y为分类分数。
在实际应用中,可以将学生预测网络和教师预测网络各自输出的分类分数图用上述公式进行归一化处理,将各自得到的结果(第一特征概率分布图和第二特征概率分布图)送入至蒸馏损失中进行计算,该蒸馏损失可以为KL散度损失,也可称为逐通道蒸馏损失。
综上,本公开实施例可以将标签分配信息和特征概率分布图作为指导信息进行知识蒸馏,可以有效提升训练的稳定性以及最终所得模型的性能。
在前述基础上,本公开实施例给出了上述步骤S108的具体实施示例,也即,基于第一检测结果、第二检测结果和标注信息对骨干网络、学生预测网络和教师预测网络进行训练,直至达到预设条件时停止训练的步骤,参照如下(一)和(二)实现:
(一)基于第一检测结果、第二检测结果和标注信息调整骨干网络、学生预测网络和教师预测网络的网络参数。骨干网络和学生预测网络可以视为学生模型,骨干网络与教师预测网络可以视为教师模型,且学生模型与教师模型共用骨干网络。学生模型和教师模型能够分别输出相应的检测结果,并对应有相应的损失函数,从而基于相应的损失函数值进行模型参数调整。在一些具体的实施示例中,可以参照如下步骤A~步骤C实现:
步骤A,基于第一检测结果、第二检测结果、标注信息和预设的学生损失函数,得到学生损失函数值。示例性地,可以参照如下步骤A1和步骤A2实现:
步骤A1,基于第二检测结果和标注信息得到图像样本的标签分配信息。
第二检测结果包括第二分类分数图和第二回归分数图;在此基础上,可以首先基于第二分类分数图、第二回归分数图和标注信息得到代价矩阵;然后根据代价矩阵得到标签分配信息。具体可参照前述相关内容,在此不再赘述。
步骤A2,基于第一检测结果、标签分配信息、标注信息和预设的学生损失函数,得到学生损失函数值。
基于标签分配信息可以获知分类分数图/回归分数图中与像素点对应的正负样本,然后将选取好的样本对应的检测结果(分类分数/回归分数)送入相应的损失函数(分类损失函数/回归损失函数)中进行计算。在一些具体的实施示例中,可以针对正负样本均计算分类损失,仅针对正样本计算回归损失。
在一些实施方式中,学生损失函数包括第一分类损失函数、第一回归损失函数和蒸馏损失函数。学生损失函数值=第一分类损失函数值+第一回归损失函数值+蒸馏损失函数值。
其中,蒸馏损失函数是基于教师预测网络生成的第二特征概率分布图和学生预测网络生成的第一特征概率分布图之间的差异确定;第一特征概率分布图与第一检测结果相关;第二特征概率分布图与第二检测结果相关。在一种具体的实施示例中,第一特征概率分布图是学生预测网络输出的第一分类分数图进行归一化处理后得到的;第二特征概率分布图是教师预测网络输出的第二分类分数图进行归一化处理后得到的;第一分类分数图属于第一检测结果,第二分类分数图属于第二检测结果。具体可参照前述相关内容,在此不再赘述。
步骤B,基于第二检测结果、标注信息和预设的教师损失函数,得到教师损失函数值。
在一些实施方式中,教师损失函数包括第二分类损失函数和第二回归损失函数。教师损失函数值=第二分类损失函数值+第二回归损失函数值。
为便于理解,本公开实施例给出了关于训练过程中所涉及的损失函数的具体示例,以分类损失、回归损失和蒸馏损失为例分别进行说明。
(1)分类损失
在一些具体的实施示例中,上述第一分类损失函数与第二分类损失函数均可以采用质量注意损失函数,质量注意损失函数可以参照如下公式实现:
Figure 974053DEST_PATH_IMAGE006
其中,σ是网络预测分数(也即分类分数),y为真实标记,也即通过标注信息获得;β是大于0的调节因子,可以根据需求而设定,示例性地其值可以为2。
(2)回归损失
在一些具体的实施示例中,上述第一回归损失函数和第二回归损失函数均包括分布注意损失函数和IOU损失函数。
分布注意损失函数可以参照如下公式实现:
Figure 545718DEST_PATH_IMAGE007
其中,
Figure 827794DEST_PATH_IMAGE008
表示区间Si~ Si+1的回归损失,Si与Si+1表示一个区间的左右两个端点位置yi和yi+1相应的回归分数,y为真实标记,也即通过标注信息获得。
IOU损失函数可以参照如下公式实现:
Figure 919247DEST_PATH_IMAGE009
Figure 159736DEST_PATH_IMAGE010
其中,A与B为预测框和真实框,E为A与B的外接矩形。
(3)蒸馏损失
只有学生损失函数会涉及到蒸馏损失,在一些具体的实施示例中,蒸馏损失函数为逐通道蒸馏函数,也可为KL散度损失。逐通道蒸馏函数可以参照如下公式实现:
Figure 454582DEST_PATH_IMAGE011
其中,公式中字母的上角标T均对应教师预测网络,上角标S均对应学生预测网络,诸如yT表示教师预测网络输出的分类分数,yS表示学生预测网络输出的分类分数;公式中各字母含义和前述归一化公式涉及的字母含义一致,在此不再赘述。通过采用前述归一化公式将分类分数进行归一化,然后采用上述蒸馏公式进行蒸馏,可以整体体现出教师预测网络与学生预测网络的特征概率分布之间的损失。
步骤C,基于学生损失函数值和教师损失函数值调整骨干网络的网络参数,基于学生损失函数值调整学生预测网络的网络参数,以及,基于教师损失函数值调整教师预测网络的网络参数。
可以根据损失函数值反向调整网络参数,以使损失函数值尽可能降低,从而使网络的输出结果可以逐步符合期望结果。
(二)在达到第一预设条件时停止调整教师预测网络的网络参数,以及达到第二预设条件时停止调整骨干网络和学生预测网络的网络参数;其中,教师预测网络的停止调整时间早于骨干网络和学生预测网络的停止调整时间。
由于教师预测网络的性能优于学生预测网络,因此基本会预先输出符合期望的检测结果,此时可停止调整教师预测网络的参数。在一些实施示例中,第一预设条件可以是教师损失函数值收敛至预设第一阈值,或者,训练次数达到预设第一次数阈值。停止调整教师预测网络的参数时,也可认为得到了训练好的教师预测网络。在此之后可继续采用训练好的教师预测网络协助目标检测模型(骨干网络和学生预测网络)进行训练,为学生预测网络提供较为准确可靠的标签分配信息以及特征概率分布图,通过对教师预测网络的标签分配信息以及特征概率分布图进行知识蒸馏,可以有效提升学生预测网络的学习性能,保障最终所得的目标检测模型的检测性能。在一些具体实施方式中,在停止调整教师预测网络的参数后,在后续训练目标检测模型时,教师预测网络自身不会再影响模型中骨干网络等网络参数,以防止目标检测模型依赖教师预测网络。
为便于理解,可以参照如图5所示的网络训练示意图,示意出骨干网络、特征金字塔CSPAN,学生预测头和教师预测头,如图5所示,输入图像(诸如大小为320*320*3)首先输入至骨干网络中,骨干网络对其进行四阶段的特征提取,得到4个特征图,考虑到骨干网络的第一层特征图较大,因此可舍去,最后选取其余三个阶段的特征图,具体为F1,F2和F3,大小分别为40x40x116,20x20x232,10x10x464;然后将初始特征F1,F2和F3输入给特征金字塔,然后由特征金字塔对初步特征进行多尺度特征提取,得到不同层级的特征
Figure 805929DEST_PATH_IMAGE003
,大小分别为40x40x96,20x20x96,10x10x96,5x5x96,具体的,特征金字塔有N个层级,便可最后输出N个不同尺度的特征,如图5中特征金字塔有4个层级,因此最后可得到4个多尺度特征,最后这些多尺度特征可以进入学生预测头,然后学生预测头可基于多尺度特征输出检测结果。应当注意的是,图5中的黑色线条表示训练及测试所需走的流程,在测试完毕确定得到可直接投入至目标检测任务的目标检测模型(骨干网络、学生特征金字塔和学生预测头)后,也会按照该黑色线条所指示的流程运行,而灰色线条表示仅在训练阶段所需走的流程,图5中的教师预测头、特征概率分布图、代价矩阵及标签分配仅用于训练阶段,在测试以及后续应用中将不再使用。为了便于查看,在图5中仅统一用一个特征金字塔CSPAN示意,在实际应用中会分别有结构相同但参数不同的教师特征金字塔和学生特征金字塔。在训练过程中,基于教师预测头输出的分类分数图、回归分数图以及图像样本的标注信息可得到代价矩阵,代价矩阵中的GT表示groundtrue,也即真实标签,GT0、GT1、GT2表示图像中的不同目标,诸如GT0对应人,GT1对应人当前进餐的苹果,GT2对应人身后的沙发,BG表示图像中的背景。代价矩阵中的每个方格都对应特征图中的像素点的代价值,在得到代价矩阵后,基于代价矩阵便可进行标签分配,后续即可确定正负样本,从而将正负样本对应的分类分数/回归分数送入相应的损失函数中进行计算。另外,基于教师预测头输出的分类分数图进行归一化处理可得到特征概率分布图,学生预测头输出的分类分数图进行归一化处理也可得到特征概率分布图,图5中仅是以一个特征概率分布图进行简单示意,教师预测头对应的特征概率分布图和学生预测头对应的特征概率分布图也会送入相应的损失函数(蒸馏损失)中进行计算。在图5中示意出的学生损失函数为:Losscls+Lossiou+Lossdfl+αLosscwd;教师损失函数为:AuxLosscls+AuxLossiou+AuxLossdfl。其中,Losscls表示学生分类损失,Lossiou表示学生IOU损失,Lossdfl表示学生分布注意损失,Lossiou和Lossdfl均属于回归损失, Losscwd表示蒸馏损失,α为预设系数。AuxLosscls表示教师分类损失,AuxLossiou表示教师IOU损失,AuxLossdfl表示教师分布注意损失。上述损失的具体实现方式可参照前述相关内容,在此不再赘述。教师预测头与学生预测头共用骨干网络,且在最初同时训练,并根据相应的损失函数值进行参数调整,通常重量级的教师预测头的性能会优于学生预测头的性能,并可提前输出符合期望的检测结果而停止参数调整,停止调整参数后的教师预测头会继续辅助学生预测头进行训练,但是不会再对骨干网络、学生特征金字塔等参数产生影响(停止反传),避免目标检测模型依赖教师预测头,在后续训练时,骨干网络、学生特征金字塔及学生预测头的网络参数根据学生损失函数值进行调整,直至学生预测头可以输出符合预期的检测结果。具体训练方式可参照前述相关内容,在此不再赘述。
综上所述,本公开实施例提供的上述目标检测模型的训练方法,在训练轻量级的目标检测模型时,无需预先训练重量级的教师模型,而是直接在训练过程中引入教师预测网络,该教师预测网络可直接接在目标检测模型的骨干网络之后(也可理解为教师模型与学生模型共用学生模型的骨干网络),然后进行同时训练,且在训练过程中由教师预测网络为学生预测网络提供指定信息。这种方式可以有效避免两阶段训练的计算开销,提升模型训练效率。另外,本公开实施例充分考虑了目标检测任务的密集型特点,不再直接针对教师预测网络输出的特征图或者分类层值进行蒸馏,而是基于教师预测网络的检测结果确定标签分配信息和特征概率分布图,以此来作为指导信息对学生预测网络进行知识蒸馏,可以有效提升训练的稳定性以及最终所得的轻量级目标检测网络的性能。
在通过本公开实施例提供的目标检测模型的训练方法得到训练好的目标检测模型的基础上,本公开实施例提供了一种目标检测方法,参见图6所示的目标检测方法的流程示意图,主要包括如下步骤S602~步骤S604:
步骤S602,获取待检测图像;
步骤S604,通过预先训练好的目标检测模型对待检测图像进行处理,得到待检测图像中所包含的目标对象的检测结果;其中,目标检测模型是采用本公开实施例提供的目标检测模型的训练方法训练得到的。
上述目标检测模型为轻量级模型,可以便捷应用于诸如手机等数据处理能力较弱的设备。
对应于前述目标检测模型的训练方法,本公开实施例还提供了一种目标检测模型的训练装置,图7为本公开实施例提供的一种目标检测模型的训练装置的结构示意图,目标检测模型包括依次连接的骨干网络和学生预测网络;该装置可由软件和/或硬件实现,一般可集成在电子设备中。如图7所示,目标检测模型的训练装置700包括:
样本获取模块702,用于获取携带有标注信息的图像样本;
初步特征获取模块704,用于将图像样本输入至骨干网络,得到初步特征;
检测结果获取模块706,用于将初步特征输入至学生预测网络,得到第一检测结果;以及,将初步特征输入至预设的教师预测网络,得到第二检测结果;
训练模块708,用于基于第一检测结果、第二检测结果和标注信息对骨干网络、学生预测网络和教师预测网络进行训练,直至达到预设条件时停止训练;其中,教师预测网络用于在训练过程中为学生预测网络提供指定信息,指定信息是教师预测网络生成的与第二检测结果相关的信息;
模型获得模块710,用于基于停止训练后的骨干网络及学生预测网络得到训练好的目标检测模型。
在上述装置中,教师预测网络和学生预测网络共用骨干网络,且无需预先训练教师预测网络,而是可以对教师预测网络和学生预测网络同时训练,在训练过程中教师预测网络给学生预测网络提供指定信息,通过一阶段训练方式即可得到训练好的目标检测模型,相比于现有的两阶段训练方式,本公开实施例提供的上述方式能够有效提升模型训练效率。
在一些实施例中,所述指定信息包括所述教师预测网络基于第二检测结果得到的标签分配信息和/或特征概率分布图。
在一些实施例中,训练模块708用于:基于所述第一检测结果、所述第二检测结果和所述标注信息调整所述骨干网络、所述学生预测网络和所述教师预测网络的网络参数;在达到第一预设条件时停止调整所述教师预测网络的网络参数,以及达到第二预设条件时停止调整所述骨干网络和所述学生预测网络的网络参数;其中,所述教师预测网络的停止调整时间早于所述骨干网络和所述学生预测网络的停止调整时间。
在一些实施例中,训练模块708具体用于:基于所述第一检测结果、所述第二检测结果、所述标注信息和预设的学生损失函数,得到学生损失函数值;基于所述第二检测结果、所述标注信息和预设的教师损失函数,得到教师损失函数值;基于所述学生损失函数值和所述教师损失函数值调整所述骨干网络的网络参数,基于所述学生损失函数值调整所述学生预测网络的网络参数,以及,基于所述教师损失函数值调整所述教师预测网络的网络参数。
在一些实施例中,训练模块708具体用于:基于所述第二检测结果和所述标注信息得到所述图像样本的标签分配信息;基于所述第一检测结果、所述标签分配信息、所述标注信息和所述预设的学生损失函数,得到学生损失函数值。
在一些实施例中,所述第一检测结果包括第一分类分数图和第一回归分数图;所述第二检测结果包括第二分类分数图和第二回归分数图;
训练模块708具体用于:基于所述第二分类分数图、所述第二回归分数图和所述标注信息得到代价矩阵;根据所述代价矩阵得到标签分配信息。
在一些实施例中,所述学生损失函数包括第一分类损失函数、第一回归损失函数和蒸馏损失函数;所述教师损失函数包括第二分类损失函数和第二回归损失函数。
在一些实施例中,所述蒸馏损失函数是基于所述教师预测网络生成的第二特征概率分布图和所述学生预测网络生成的第一特征概率分布图之间的差异确定;所述第一特征概率分布图与所述第一检测结果相关;所述第二特征概率分布图与所述第二检测结果相关。
在一些实施例中,所述第一特征概率分布图是所述学生预测网络输出的第一分类分数图进行归一化处理后得到的;所述第二特征概率分布图是所述教师预测网络输出的第二分类分数图进行归一化处理后得到的;所述第一分类分数图属于所述第一检测结果,所述第二分类分数图属于所述第二检测结果。
在一些实施例中,所述第一分类损失函数与所述第二分类损失函数均为质量注意损失函数;所述第一回归损失函数和所述第二回归损失函数均包括分布注意损失函数和IOU损失函数;所述蒸馏损失函数为逐通道蒸馏函数。
在一些实施例中,所述学生预测网络包括依次连接的学生金字塔网络和学生预测头,所述教师预测网络包括依次连接的教师金字塔网络和教师预测头。
在一些实施例中,所述学生金字塔网络和所述教师金字塔网络的网络结构相同,网络参数不同。
在一些实施例中,所述学生预测头包括第一指定数量个卷积层;所述教师预测头包括第二指定数量个卷积层;且所述第二指定数量大于所述第一指定数量。
在一些实施例中,所述学生预测头中的卷积层为深度可分离卷积层。
在一些实施例中,检测结果获取模块706具体用于:将所述初步特征分别输入至所述学生金字塔网络,得到所述学生金字塔网络输出的第一多尺度特征;将所述第一多尺度特征输入至所述学生预测头,得到所述学生预测头输出的第一分类分数图和第一回归分数图;基于所述第一分类分数图和所述第一回归分数图得到第二检测结果。
在一些实施例中,检测结果获取模块706具体用于:将所述初步特征输入至所述教师金字塔网络,得到所述教师金字塔网络输出的第二多尺度特征;将所述第一多尺度特征和所述第二多尺度特征进行拼接,得到多尺度拼接特征;将所述多尺度拼接特征输入至所述教师预测头,得到所述教师预测头输出的第二分类分数图和第二回归分数图;基于所述第二分类分数图和所述第二回归分数图得到第二检测结果。
所属领域的技术人员可以清楚地了解到,为描述的方便和简洁,上述描述的装置实施例的具体工作过程,可以参考方法实施例中的对应过程,在此不再赘述。
对应于前述目标检测方法,本公开实施例还提供了一种目标检测装置,图8为本公开实施例提供的一种目标检测装置的结构示意图,该装置可由软件和/或硬件实现,一般可集成在电子设备中。如图8所示,目标检测装置800包括:
图像获取模块802,用于获取待检测图像;
目标检测模块804,用于通过预先训练好的目标检测模型对待检测图像进行处理,得到待检测图像中所包含的目标对象的检测结果;其中,目标检测模型是采用上述任一项的目标检测模型的训练方法训练得到的。
上述目标检测模型为轻量级模型,可以便捷应用于诸如手机等数据处理能力较弱的设备。
所属领域的技术人员可以清楚地了解到,为描述的方便和简洁,上述描述的装置实施例的具体工作过程,可以参考方法实施例中的对应过程,在此不再赘述。
本公开示例性实施例还提供一种电子设备,包括:至少一个处理器;以及与至少一个处理器通信连接的存储器。所述存储器存储有能够被所述至少一个处理器执行的计算机程序,所述计算机程序在被所述至少一个处理器执行时用于使所述电子设备执行根据本公开实施例的方法。
本公开示例性实施例还提供一种存储有计算机程序的非瞬时计算机可读存储介质,其中,所述计算机程序在被计算机的处理器执行时用于使所述计算机执行根据本公开实施例的方法。
本公开示例性实施例还提供一种计算机程序产品,包括计算机程序,其中,所述计算机程序在被计算机的处理器执行时用于使所述计算机执行根据本公开实施例的方法。
所述计算机程序产品可以以一种或多种程序设计语言的任意组合来编写用于执行本公开实施例操作的程序代码,所述程序设计语言包括面向对象的程序设计语言,诸如Java、C++等,还包括常规的过程式程序设计语言,诸如“C”语言或类似的程序设计语言。程序代码可以完全地在用户计算设备上执行、部分地在用户设备上执行、作为一个独立的软件包执行、部分在用户计算设备上部分在远程计算设备上执行、或者完全在远程计算设备或服务器上执行。
此外,本公开的实施例还可以是计算机可读存储介质,其上存储有计算机程序指令,所述计算机程序指令在被处理器运行时使得所述处理器执行本公开实施例所提供的公式识别方法。所述计算机可读存储介质可以采用一个或多个可读介质的任意组合。可读介质可以是可读信号介质或者可读存储介质。可读存储介质例如可以包括但不限于电、磁、光、电磁、红外线、或半导体的系统、装置或器件,或者任意以上的组合。可读存储介质的更具体的例子(非穷举的列表)包括:具有一个或多个导线的电连接、便携式盘、硬盘、随机存取存储器(RAM)、只读存储器(ROM)、可擦式可编程只读存储器(EPROM或闪存)、光纤、便携式紧凑盘只读存储器(CD-ROM)、光存储器件、磁存储器件、或者上述的任意合适的组合。
参考图9,现将描述可以作为本公开的服务器或客户端的电子设备900的结构框图,其是可以应用于本公开的各方面的硬件设备的示例。电子设备旨在表示各种形式的数字电子的计算机设备,诸如,膝上型计算机、台式计算机、工作台、个人数字助理、服务器、刀片式服务器、大型计算机、和其它适合的计算机。电子设备还可以表示各种形式的移动装置,诸如,个人数字处理、蜂窝电话、智能电话、可穿戴设备和其它类似的计算装置。本文所示的部件、它们的连接和关系、以及它们的功能仅仅作为示例,并且不意在限制本文中描述的和/或者要求的本公开的实现。
如图9所示,电子设备900包括计算单元901,其可以根据存储在只读存储器(ROM)902中的计算机程序或者从存储单元908加载到随机访问存储器(RAM)903中的计算机程序,来执行各种适当的动作和处理。在RAM 903中,还可存储设备900操作所需的各种程序和数据。计算单元901、ROM 902以及RAM 903通过总线904彼此相连。输入/输出(I/O)接口905也连接至总线904。
电子设备900中的多个部件连接至I/O接口905,包括:输入单元906、输出单元907、存储单元908以及通信单元909。输入单元906可以是能向电子设备900输入信息的任何类型的设备,输入单元906可以接收输入的数字或字符信息,以及产生与电子设备的用户设置和/或功能控制有关的键信号输入。输出单元907可以是能呈现信息的任何类型的设备,并且可以包括但不限于显示器、扬声器、视频/音频输出终端、振动器和/或打印机。存储单元908可以包括但不限于磁盘、光盘。通信单元909允许电子设备900通过诸如因特网的计算机网络和/或各种电信网络与其他设备交换信息/数据,并且可以包括但不限于调制解调器、网卡、红外通信设备、无线通信收发机和/或芯片组,例如蓝牙TM设备、WiFi设备、WiMax设备、蜂窝通信设备和/或类似物。
计算单元901可以是各种具有处理和计算能力的通用和/或专用处理组件。计算单元901的一些示例包括但不限于中央处理单元(CPU)、图形处理单元(GPU)、各种专用的人工智能(AI)计算芯片、各种运行机器学习模型算法的计算单元、数字信号处理器(DSP)、以及任何适当的处理器、控制器、微控制器等。计算单元901执行上文所描述的各个方法和处理。例如,在一些实施例中,目标检测模型的训练方法或者目标检测方法可被实现为计算机软件程序,其被有形地包含于机器可读介质,例如存储单元908。在一些实施例中,计算机程序的部分或者全部可以经由ROM 902和/或通信单元909而被载入和/或安装到电子设备900上。在一些实施例中,计算单元901可以通过其他任何适当的方式(例如,借助于固件)而被配置为执行目标检测模型的训练方法或者目标检测方法。
用于实施本公开的方法的程序代码可以采用一个或多个编程语言的任何组合来编写。这些程序代码可以提供给通用计算机、专用计算机或其他可编程数据处理装置的处理器或控制器,使得程序代码当由处理器或控制器执行时使流程图和/或框图中所规定的功能/操作被实施。程序代码可以完全在机器上执行、部分地在机器上执行,作为独立软件包部分地在机器上执行且部分地在远程机器上执行或完全在远程机器或服务器上执行。
在本公开的上下文中,机器可读介质可以是有形的介质,其可以包含或存储以供指令执行系统、装置或设备使用或与指令执行系统、装置或设备结合地使用的程序。机器可读介质可以是机器可读信号介质或机器可读储存介质。机器可读介质可以包括但不限于电子的、磁性的、光学的、电磁的、红外的、或半导体系统、装置或设备,或者上述内容的任何合适组合。机器可读存储介质的更具体示例会包括基于一个或多个线的电气连接、便携式计算机盘、硬盘、随机存取存储器(RAM)、只读存储器(ROM)、可擦除可编程只读存储器(EPROM或快闪存储器)、光纤、便捷式紧凑盘只读存储器(CD-ROM)、光学储存设备、磁储存设备、或上述内容的任何合适组合。
如本公开使用的,术语“机器可读介质”和“计算机可读介质”指的是用于将机器指令和/或数据提供给可编程处理器的任何计算机程序产品、设备、和/或装置(例如,磁盘、光盘、存储器、可编程逻辑装置(PLD)),包括,接收作为机器可读信号的机器指令的机器可读介质。术语“机器可读信号”指的是用于将机器指令和/或数据提供给可编程处理器的任何信号。
为了提供与用户的交互,可以在计算机上实施此处描述的系统和技术,该计算机具有:用于向用户显示信息的显示装置(例如,CRT(阴极射线管)或者LCD(液晶显示器)监视器);以及键盘和指向装置(例如,鼠标或者轨迹球),用户可以通过该键盘和该指向装置来将输入提供给计算机。其它种类的装置还可以用于提供与用户的交互;例如,提供给用户的反馈可以是任何形式的传感反馈(例如,视觉反馈、听觉反馈、或者触觉反馈);并且可以用任何形式(包括声输入、语音输入或者、触觉输入)来接收来自用户的输入。
可以将此处描述的系统和技术实施在包括后台部件的计算系统(例如,作为数据服务器)、或者包括中间件部件的计算系统(例如,应用服务器)、或者包括前端部件的计算系统(例如,具有图形用户界面或者网络浏览器的用户计算机,用户可以通过该图形用户界面或者该网络浏览器来与此处描述的系统和技术的实施方式交互)、或者包括这种后台部件、中间件部件、或者前端部件的任何组合的计算系统中。可以通过任何形式或者介质的数字数据通信(例如,通信网络)来将系统的部件相互连接。通信网络的示例包括:局域网(LAN)、广域网(WAN)和互联网。
计算机系统可以包括客户端和服务器。客户端和服务器一般远离彼此并且通常通过通信网络进行交互。通过在相应的计算机上运行并且彼此具有客户端-服务器关系的计算机程序来产生客户端和服务器的关系。
需要说明的是,在本文中,诸如“第一”和“第二”等之类的关系术语仅仅用来将一个实体或者操作与另一个实体或操作区分开来,而不一定要求或者暗示这些实体或操作之间存在任何这种实际的关系或者顺序。而且,术语“包括”、“包含”或者其任何其他变体意在涵盖非排他性的包含,从而使得包括一系列要素的过程、方法、物品或者设备不仅包括那些要素,而且还包括没有明确列出的其他要素,或者是还包括为这种过程、方法、物品或者设备所固有的要素。在没有更多限制的情况下,由语句“包括一个……”限定的要素,并不排除在包括所述要素的过程、方法、物品或者设备中还存在另外的相同要素。
以上所述仅是本公开的具体实施方式,使本领域技术人员能够理解或实现本公开。对这些实施例的多种修改对本领域的技术人员来说将是显而易见的,本文中所定义的一般原理可以在不脱离本公开的精神或范围的情况下,在其它实施例中实现。因此,本公开将不会被限制于本文所述的这些实施例,而是要符合与本文所公开的原理和新颖特点相一致的最宽的范围。

Claims (21)

1.一种目标检测模型的训练方法,其中,所述目标检测模型包括依次连接的骨干网络和学生预测网络;所述方法包括:
获取携带有标注信息的图像样本;
将所述图像样本输入至所述骨干网络,得到初步特征;
将所述初步特征输入至所述学生预测网络,得到第一检测结果;以及,将所述初步特征输入至预设的教师预测网络,得到第二检测结果;
基于所述第一检测结果、所述第二检测结果和所述标注信息对所述骨干网络、所述学生预测网络和所述教师预测网络进行训练,直至达到预设条件时停止训练;其中,所述教师预测网络用于在训练过程中为所述学生预测网络提供指定信息,所述指定信息是所述教师预测网络生成的与所述第二检测结果相关的信息;
基于停止训练后的所述骨干网络及所述学生预测网络得到训练好的目标检测模型。
2.如权利要求1所述的目标检测模型的训练方法,其中,所述指定信息包括基于所述第二检测结果得到的标签分配信息和/或特征概率分布图。
3.如权利要求1所述的目标检测模型的训练方法,其中,基于所述第一检测结果、所述第二检测结果和所述标注信息对所述骨干网络、所述学生预测网络和所述教师预测网络进行训练,直至达到预设条件时停止训练的步骤,包括:
基于所述第一检测结果、所述第二检测结果和所述标注信息调整所述骨干网络、所述学生预测网络和所述教师预测网络的网络参数;
在达到第一预设条件时停止调整所述教师预测网络的网络参数,以及达到第二预设条件时停止调整所述骨干网络和所述学生预测网络的网络参数;其中,所述教师预测网络的停止调整时间早于所述骨干网络和所述学生预测网络的停止调整时间。
4.如权利要求3所述的目标检测模型的训练方法,其中,基于所述第一检测结果、所述第二检测结果和所述标注信息调整所述骨干网络、所述学生预测网络和所述教师预测网络的网络参数的步骤,包括:
基于所述第一检测结果、所述第二检测结果、所述标注信息和预设的学生损失函数,得到学生损失函数值;
基于所述第二检测结果、所述标注信息和预设的教师损失函数,得到教师损失函数值;
基于所述学生损失函数值和所述教师损失函数值调整所述骨干网络的网络参数,基于所述学生损失函数值调整所述学生预测网络的网络参数,以及,基于所述教师损失函数值调整所述教师预测网络的网络参数。
5.如权利要求4所述的目标检测模型的训练方法,基于所述第一检测结果、所述第二检测结果、所述标注信息和预设的学生损失函数,得到学生损失函数值的步骤,包括:
基于所述第二检测结果和所述标注信息得到所述图像样本的标签分配信息;
基于所述第一检测结果、所述标签分配信息、所述标注信息和所述预设的学生损失函数,得到学生损失函数值。
6.如权利要求5所述的目标检测模型的训练方法,其中,所述第一检测结果包括第一分类分数图和第一回归分数图;所述第二检测结果包括第二分类分数图和第二回归分数图;
基于所述第二检测结果和所述标注信息得到所述图像样本的标签分配信息的步骤,包括:
基于所述第二分类分数图、所述第二回归分数图和所述标注信息得到代价矩阵;
根据所述代价矩阵得到标签分配信息。
7.如权利要求4所述的目标检测模型的训练方法,其中,所述学生损失函数包括第一分类损失函数、第一回归损失函数和蒸馏损失函数;所述教师损失函数包括第二分类损失函数和第二回归损失函数。
8.如权利要求7所述的目标检测模型的训练方法,其中,所述蒸馏损失函数是基于所述教师预测网络生成的第二特征概率分布图和所述学生预测网络生成的第一特征概率分布图之间的差异确定;所述第一特征概率分布图与所述第一检测结果相关;所述第二特征概率分布图与所述第二检测结果相关。
9.如权利要求8所述的目标检测模型的训练方法,其中,所述第一特征概率分布图是所述学生预测网络输出的第一分类分数图进行归一化处理后得到的;所述第二特征概率分布图是所述教师预测网络输出的第二分类分数图进行归一化处理后得到的;所述第一分类分数图属于所述第一检测结果,所述第二分类分数图属于所述第二检测结果。
10.如权利要求7所述的目标检测模型的训练方法,其中,所述第一分类损失函数与所述第二分类损失函数均为质量注意损失函数;
所述第一回归损失函数和所述第二回归损失函数均包括分布注意损失函数和IOU损失函数;
所述蒸馏损失函数为逐通道蒸馏函数。
11.如权利要求1所述的目标检测模型的训练方法,其中,所述学生预测网络包括依次连接的学生金字塔网络和学生预测头,所述教师预测网络包括依次连接的教师金字塔网络和教师预测头。
12.如权利要求11所述的目标检测模型的训练方法,其中,所述学生金字塔网络和所述教师金字塔网络的网络结构相同,网络参数不同。
13.如权利要求11所述的目标检测模型的训练方法,其中,所述学生预测头包括第一指定数量个卷积层;所述教师预测头包括第二指定数量个卷积层;且所述第二指定数量大于所述第一指定数量。
14.如权利要求13所述的目标检测模型的训练方法,其中,所述学生预测头中的卷积层为深度可分离卷积层。
15.如权利要求11所述的目标检测模型的训练方法,其中,将所述初步特征输入至所述学生预测网络,得到第一检测结果的步骤,包括:
将所述初步特征分别输入至所述学生金字塔网络,得到所述学生金字塔网络输出的第一多尺度特征;
将所述第一多尺度特征输入至所述学生预测头,得到所述学生预测头输出的第一分类分数图和第一回归分数图;
基于所述第一分类分数图和所述第一回归分数图得到第一检测结果。
16.如权利要求15所述的目标检测模型的训练方法,其中,将所述初步特征输入至所述教师预测网络,得到第二检测结果的步骤,包括:
将所述初步特征输入至所述教师金字塔网络,得到所述教师金字塔网络输出的第二多尺度特征;
将所述第一多尺度特征和所述第二多尺度特征进行拼接,得到多尺度拼接特征;
将所述多尺度拼接特征输入至所述教师预测头,得到所述教师预测头输出的第二分类分数图和第二回归分数图;
基于所述第二分类分数图和所述第二回归分数图得到第二检测结果。
17.一种目标检测方法,包括:
获取待检测图像;
通过预先训练好的目标检测模型对所述待检测图像进行处理,得到所述待检测图像中所包含的目标对象的检测结果;其中,所述目标检测模型是采用权利要求1至16任一项所述的目标检测模型的训练方法训练得到的。
18.一种目标检测模型的训练装置,其中,所述目标检测模型包括依次连接的骨干网络和学生预测网络;所述装置包括:
样本获取模块,用于获取携带有标注信息的图像样本;
初步特征获取模块,用于将所述图像样本输入至所述骨干网络,得到初步特征;
检测结果获取模块,用于将所述初步特征输入至所述学生预测网络,得到第一检测结果;以及,将所述初步特征输入至预设的教师预测网络,得到第二检测结果;
训练模块,用于基于所述第一检测结果、所述第二检测结果和所述标注信息对所述骨干网络、所述学生预测网络和所述教师预测网络进行训练,直至达到预设条件时停止训练;其中,所述教师预测网络用于在训练过程中为所述学生预测网络提供指定信息,所述指定信息是所述教师预测网络生成的与所述第二检测结果相关的信息;
模型获得模块,用于基于停止训练后的所述骨干网络及所述学生预测网络得到训练好的目标检测模型。
19.一种目标检测装置,包括:
图像获取模块,用于获取待检测图像;
目标检测模块,用于通过预先训练好的目标检测模型对所述待检测图像进行处理,得到所述待检测图像中所包含的目标对象的检测结果;其中,所述目标检测模型是采用权利要求1至16任一项所述的目标检测模型的训练方法训练得到的。
20.一种电子设备,包括:
处理器;以及
存储程序的存储器,
其中,所述程序包括指令,所述指令在由所述处理器执行时使所述处理器执行根据权利要求1-16中任一项所述的目标检测模型的训练方法或者权利要求17所述的目标检测方法。
21.一种计算机可读存储介质,所述存储介质存储有计算机程序,所述计算机程序用于执行上述权利要求1-16中任一项所述的目标检测模型的训练方法或者权利要求17所述的目标检测方法。
CN202210495852.7A 2022-05-09 2022-05-09 目标检测模型的训练方法、目标检测方法、装置及设备 Active CN114596497B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202210495852.7A CN114596497B (zh) 2022-05-09 2022-05-09 目标检测模型的训练方法、目标检测方法、装置及设备

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202210495852.7A CN114596497B (zh) 2022-05-09 2022-05-09 目标检测模型的训练方法、目标检测方法、装置及设备

Publications (2)

Publication Number Publication Date
CN114596497A true CN114596497A (zh) 2022-06-07
CN114596497B CN114596497B (zh) 2022-08-19

Family

ID=81811456

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202210495852.7A Active CN114596497B (zh) 2022-05-09 2022-05-09 目标检测模型的训练方法、目标检测方法、装置及设备

Country Status (1)

Country Link
CN (1) CN114596497B (zh)

Cited By (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN114972877A (zh) * 2022-06-09 2022-08-30 北京百度网讯科技有限公司 一种图像分类模型训练方法、装置及电子设备
CN114998570A (zh) * 2022-07-19 2022-09-02 上海闪马智能科技有限公司 一种对象检测框的确定方法、装置、存储介质及电子装置
CN115527083A (zh) * 2022-09-27 2022-12-27 中电金信软件有限公司 图像标注方法、装置和电子设备

Citations (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN111507378A (zh) * 2020-03-24 2020-08-07 华为技术有限公司 训练图像处理模型的方法和装置
CN112766087A (zh) * 2021-01-04 2021-05-07 武汉大学 一种基于知识蒸馏的光学遥感图像舰船检测方法
CN112950642A (zh) * 2021-02-25 2021-06-11 中国工商银行股份有限公司 点云实例分割模型的训练方法、装置、电子设备和介质
US20210279595A1 (en) * 2020-03-05 2021-09-09 Deepak Sridhar Methods, devices and media providing an integrated teacher-student system
CN113705532A (zh) * 2021-09-10 2021-11-26 中国人民解放军国防科技大学 基于中低分辨率遥感图像的目标检测方法、装置及设备

Patent Citations (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20210279595A1 (en) * 2020-03-05 2021-09-09 Deepak Sridhar Methods, devices and media providing an integrated teacher-student system
CN111507378A (zh) * 2020-03-24 2020-08-07 华为技术有限公司 训练图像处理模型的方法和装置
CN112766087A (zh) * 2021-01-04 2021-05-07 武汉大学 一种基于知识蒸馏的光学遥感图像舰船检测方法
CN112950642A (zh) * 2021-02-25 2021-06-11 中国工商银行股份有限公司 点云实例分割模型的训练方法、装置、电子设备和介质
CN113705532A (zh) * 2021-09-10 2021-11-26 中国人民解放军国防科技大学 基于中低分辨率遥感图像的目标检测方法、装置及设备

Non-Patent Citations (3)

* Cited by examiner, † Cited by third party
Title
AMIN BANITALEBI-DEHKORDI等: "Knowledge Distillation for Low-Power Object Detection: A Simple Technique and Its Extensions for Training Compact Models Using Unlabeled Data", 《2021 IEEE/CVF INTERNATIONAL CONFERENCE ON COMPUTER VISION WORKSHOPS (ICCVW)》 *
杨柏松: "基于深度学习的目标检测系统研究与实现", 《中国优秀博硕士学位论文全文数据库(硕士)信息科技辑》 *
褚晶辉等: "适用于目标检测的上下文感知知识蒸馏网络", 《浙江大学学报》 *

Cited By (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN114972877A (zh) * 2022-06-09 2022-08-30 北京百度网讯科技有限公司 一种图像分类模型训练方法、装置及电子设备
CN114998570A (zh) * 2022-07-19 2022-09-02 上海闪马智能科技有限公司 一种对象检测框的确定方法、装置、存储介质及电子装置
CN114998570B (zh) * 2022-07-19 2023-03-28 上海闪马智能科技有限公司 一种对象检测框的确定方法、装置、存储介质及电子装置
CN115527083A (zh) * 2022-09-27 2022-12-27 中电金信软件有限公司 图像标注方法、装置和电子设备

Also Published As

Publication number Publication date
CN114596497B (zh) 2022-08-19

Similar Documents

Publication Publication Date Title
CN114596497B (zh) 目标检测模型的训练方法、目标检测方法、装置及设备
US11017220B2 (en) Classification model training method, server, and storage medium
US10635979B2 (en) Category learning neural networks
US10936949B2 (en) Training machine learning models using task selection policies to increase learning progress
US20230043174A1 (en) Method for pushing anchor information, computer device, and storage medium
CN112541122A (zh) 推荐模型的训练方法、装置、电子设备及存储介质
EP3872652B1 (en) Method and apparatus for processing video, electronic device, medium and product
CN110264274B (zh) 客群划分方法、模型生成方法、装置、设备及存储介质
CN110413988A (zh) 文本信息匹配度量的方法、装置、服务器及存储介质
US20210165970A1 (en) Method and terminal for generating a text based on self-encoding neural network, and medium
KR102265573B1 (ko) 인공지능 기반 입시 수학 학습 커리큘럼 재구성 방법 및 시스템
CN111382573A (zh) 用于答案质量评估的方法、装置、设备和存储介质
CN108804577B (zh) 一种资讯标签兴趣度的预估方法
CN113326852A (zh) 模型训练方法、装置、设备、存储介质及程序产品
CN111582500A (zh) 一种提高模型训练效果的方法和系统
CN111554276B (zh) 语音识别方法、装置、设备及计算机可读存储介质
CN106663210B (zh) 基于感受的多媒体处理
CN113656582A (zh) 神经网络模型的训练方法、图像检索方法、设备和介质
CN113657483A (zh) 模型训练方法、目标检测方法、装置、设备以及存储介质
US11941867B2 (en) Neural network training using the soft nearest neighbor loss
CN112418302A (zh) 一种任务预测方法及装置
US20220188636A1 (en) Meta pseudo-labels
CN114817478A (zh) 基于文本的问答方法、装置、计算机设备及存储介质
CN113392920B (zh) 生成作弊预测模型的方法、装置、设备、介质及程序产品
CN114037052A (zh) 检测模型的训练方法、装置、电子设备及存储介质

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant