CN115482396A - 模型训练方法、图像分类方法、装置、设备和介质 - Google Patents

模型训练方法、图像分类方法、装置、设备和介质 Download PDF

Info

Publication number
CN115482396A
CN115482396A CN202211219252.4A CN202211219252A CN115482396A CN 115482396 A CN115482396 A CN 115482396A CN 202211219252 A CN202211219252 A CN 202211219252A CN 115482396 A CN115482396 A CN 115482396A
Authority
CN
China
Prior art keywords
deep learning
feature
loss value
learning model
converted
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202211219252.4A
Other languages
English (en)
Inventor
张婉平
温圣召
田飞
杨馥魁
张刚
冯浩城
韩钧宇
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Beijing Baidu Netcom Science and Technology Co Ltd
Original Assignee
Beijing Baidu Netcom Science and Technology Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Beijing Baidu Netcom Science and Technology Co Ltd filed Critical Beijing Baidu Netcom Science and Technology Co Ltd
Priority to CN202211219252.4A priority Critical patent/CN115482396A/zh
Publication of CN115482396A publication Critical patent/CN115482396A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/40Extraction of image or video features
    • G06V10/44Local feature extraction by analysis of parts of the pattern, e.g. by detecting edges, contours, loops, corners, strokes or intersections; Connectivity analysis, e.g. of connected components
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N20/00Machine learning
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/764Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/82Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • General Physics & Mathematics (AREA)
  • Evolutionary Computation (AREA)
  • Software Systems (AREA)
  • Artificial Intelligence (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Computing Systems (AREA)
  • Medical Informatics (AREA)
  • Multimedia (AREA)
  • General Health & Medical Sciences (AREA)
  • Health & Medical Sciences (AREA)
  • Data Mining & Analysis (AREA)
  • Databases & Information Systems (AREA)
  • Mathematical Physics (AREA)
  • General Engineering & Computer Science (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Computational Linguistics (AREA)
  • Molecular Biology (AREA)
  • Image Analysis (AREA)

Abstract

本公开提供了一种深度学习模型的训练方法,涉及人工智能技术领域,尤其涉及深度学习、图像处理、计算机视觉等技术领域。具体实现方案为:将样本图像输入第一深度学习模型的第一特征提取网络,得到样本图像的第一特征;对样本图像的第二特征进行转换,得到转换后的第二特征,其中,第一特征的维度与转换后的第二特征的维度之间的差异小于或等于预设维度差异阈值;将转换后的第二特征输入第一深度学习模型的第一分类网络,得到转换后的分类结果;以及根据转换后的分类结果,训练第一深度学习模型。本公开还提供了一种深度学习模型的训练装置、图像分类装置、电子设备和存储介质。

Description

模型训练方法、图像分类方法、装置、设备和介质
技术领域
本公开涉及人工智能技术领域,尤其涉及深度学习、图像处理、计算机视觉等技术领域,可应用于人脸识别等场景下。更具体地,本公开提供了一种深度学习模型的训练方法、图像分类方法、装置、电子设备和存储介质。
背景技术
随着人工智能技术的发展,深度学习模型广泛地应用于各种图像处理场景。例如,可以利用参数量较大的教师模型处理图像,得到处理结果。也可以用参数量较小的学生模型来拟合该处理结果,以进行知识蒸馏,来提高学生模型的性能。
发明内容
本公开提供了一种深度学习模型的训练方法、图像分类方法、装置、设备以及存储介质。
根据本公开的一方面,提供了一种深度学习模型的训练方法,该方法包括:将样本图像输入第一深度学习模型的第一特征提取网络,得到样本图像的第一特征;对样本图像的第二特征进行转换,得到转换后的第二特征,其中,第一特征的维度与转换后的第二特征的维度之间的差异小于或等于预设维度差异阈值;将转换后的第二特征输入第一深度学习模型的第一分类网络,得到转换后的分类结果;以及根据转换后的分类结果,训练第一深度学习模型。
根据本公开的另一方面,提供了一种图像分类方法,该装置包括:将目标图像输入第一深度学习模型,得到目标分类结果,其中,第一深度学习模型是利用本公开提供的方法训练的。
根据本公开的另一方面,提供了一种深度学习模型的训练装置,该装置包括:第一获得模块,用于将样本图像输入第一深度学习模型的第一特征提取网络,得到样本图像的第一特征;转换模块,用于对样本图像的第二特征进行转换,得到转换后的第二特征,其中,第一特征的维度与转换后的第二特征的维度之间的差异小于或等于预设维度差异阈值;第二获得模块,用于将转换后的第二特征输入第一深度学习模型的第一分类网络,得到转换后的分类结果;以及训练模块,用于根据转换后的分类结果,训练第一深度学习模型。
根据本公开的另一方面,提供了一种图像分类装置,该装置包括:第三获得模块,用于将目标图像输入第一深度学习模型,得到目标分类结果,其中,第一深度学习模型是利用本公开提供的装置训练的。
根据本公开的另一方面,提供了一种电子设备,包括:至少一个处理器;以及与至少一个处理器通信连接的存储器;其中,存储器存储有可被至少一个处理器执行的指令,指令被至少一个处理器执行,以使至少一个处理器能够执行根据本公开提供的方法。
根据本公开的另一方面,提供了一种存储有计算机指令的非瞬时计算机可读存储介质,该计算机指令用于使计算机执行根据本公开提供的方法。
根据本公开的另一方面,提供了一种计算机程序产品,包括计算机程序,计算机程序在被处理器执行时实现根据本公开提供的方法。
应当理解,本部分所描述的内容并非旨在标识本公开的实施例的关键或重要特征,也不用于限制本公开的范围。本公开的其它特征将通过以下的说明书而变得容易理解。
附图说明
附图用于更好地理解本方案,不构成对本公开的限定。其中:
图1是根据本公开的一个实施例的深度学习模型的训练方法的流程图;
图2是根据本公开的一个实施例的深度学习模型的训练方法的原理图;
图3是根据本公开的另一个实施例的深度学习模型的训练方法的流程图;
图4是根据本公开的一个实施例的图像分类方法的流程图;
图5是根据本公开的一个实施例的深度学习模型的训练装置的框图i
图6是根据本公开的一个实施例的图像分类装置的框图;以及
图7是根据本公开的一个实施例的可以应用深度学习模型的训练方法和/或图像分类方法的电子设备的框图。
具体实施方式
以下结合附图对本公开的示范性实施例做出说明,其中包括本公开实施例的各种细节以助于理解,应当将它们认为仅仅是示范性的。因此,本领域普通技术人员应当认识到,可以对这里描述的实施例做出各种改变和修改,而不会背离本公开的范围和精神。同样,为了清楚和简明,以下的描述中省略了对公知功能和结构的描述。
可以利用参数量较小的第一深度学习模型来处理图像,得到图像的类别或图像中对象的类别。对象可以是各种物体、动物,也可以是动物或物体的局部(例如动物的面部)。为了提高第一深度学习模型的精度,可以对其进行训练。
例如,基于有监督的训练方法,可以利用样本图像及其标签,训练第一深度学习模型。但这种训练方式训练出的模型的精度较低。
又例如,将第一深度学习模型作为学生模型,将另一经训练的参数量较大的第二深度学习模型作为教师模型。可以调整学生模型的参数,使得学生模型输出的处理结果趋近教师模型输出的处理结果。这种按照蒸馏训练的方式训练出的学生模型可以有较高的精度。但,教师模型的参数量较大,处理时间较长,导致蒸馏训练所需的时间成本较高。
又例如,为了提高蒸馏训练的效率,可以用教师模型提前处理样本图像,得到相关特征和处理结果。然而,为了提高学生模型的精度,可以对样本图像进行数据增强后,再由学生模型进行处理,导致学生模型处理的图像与教师模型处理的图像之间存在区别,进而导致蒸馏训练的效率难以充分提高。
图1是是根据本公开的一个实施例的深度学习模型的训练方法的流程图。
如图1所示,该方法100可以包括操作S110至操作S140。
在操作S110,将样本图像输入第一深度学习模型的第一特征提取网络,得到样本图像的第一特征。
在本公开实施例中,第一特征提取网络可以是各种深度学习网络。例如,第一特征提取网络可以是卷积神经网络(Convolutional Neural Network,CNN)。又例如,第一特征提取网络也可以包括一个或多个Transformer编码块(Transformer Block)。
在本公开实施例中,样本图像可以来自于各种图像数据集。例如,样本图像可以来自于ImageNet图像数据集。
在操作S120,对样本图像的第二特征进行转换,得到转换后的第二特征。
在本公开实施例中,在获得第一特征之前,可以获得第二特征。例如,样本图像与一个第二特征对应。由此,可以提高第一深度学习模型的训练效率。
在本公开实施例中,第二特征的维度可以大于或等于第一特征的维度。例如,与第一特征相比,第二特征中样本图像的信息可以更多、更加丰富。
在本公开实施例中,第一特征的维度可以与转换后的第二特征的维度之间的差异小于或等于预设维度差异阈值。例如,预设维度差异阈值可以是一个较小值。又例如,转换后的第二特征的维度可以与第一特征的维度一致。
在本公开实施例中,可以利用各种网络对第二特征进行转换。例如,可以利用包括多个全连接层的全连接网络来对第二特征进行转换。
在操作S130,将转换后的第二特征输入第一深度学习模型的第一分类网络,得到转换后的分类结果。
例如,转换后的第二特征的维度与第一特征的差异较小,可以由第一分类网络进行处理。
在操作S140,根据转换后的分类结果,训练第一深度学习模型。
在本公开实施例中,根据转换后的分类结果,可以根据各种方式调整第一深度学习模型的参数。例如,基于有监督的训练方式,根据样本图像的标签和转换后的分类结果之间的差异,可以调整第一深度学习模型的参数。又例如,也可以将第一特征输入第一分类网络,得到第一分类结果。基于蒸馏训练的训练方式,根据转换后的分类结果和第一分类结果之间的差异,可以调整第一深度学习模型的参数。
通过本公开实施例,对第二特征进行了转换,得到了与第一特征的维度的差异较小的转换后的第二特征,使得第一分类网络可以根据转换后的第二特征进行分类,得到了转换后的分类结果。该转换后的分类结果同第二特征具有较强的关联,利用该转换后的分类结果训练模型,可以使得第一特征趋近与信息更多、更加丰富的第二特征,进而有助于提高第一深度学习模型输出的分类结果的精度,也有助于提高训练效率。
下面将结合相关实施例对本公开的第一深度学习模型和用于输出第二特征的第二深度学习模型进行说明。
在一些实施例中,第二特征是利用第二深度学习模型的第二特征提取网络处理样本图像得到的。例如,可以将样本图像输入第二特征提取网络,得到样本图像的第二特征。
在本公开实施例中,第二特征提取网络可以是各种深度学习网络。例如,第二特征提取网络可以是卷积神经网络。又例如,第二特征提取网络也可以包括一个或多个Transformer编码块。
在本公开实施例中,第二深度学习模型的参数量大于或等于第一深度学习模型的参数量。第二特征提取网络的参数量可以大于第一特征提取网络。例如,第二特征提取网络中Transformer编码块的数量可以大于第一特征提取网络中Transformer编码块的数量。又例如,第二特征提取网络中Transformer编码块的参数量也可以大于第一特征提取网络中Transformer编码块的参数量。
在本公开实施例中,样本图像可以与一个第二特征对应。该第二特征可以与一个第二分类结果对应。例如,第二分类结果可以是利用第二深度学习模型的第二分类网络处理第二特征得到的。又例如,在获得第一特征之前,可以获得第二特征以及第二分类结果。
在一些实施例中,第一深度学习模型可以作为学生模型,第二深度学习模型可以作为教师模型。例如,如上述的操作S140的一些实施方式中,训练第一深度学习模型还可以包括:利用第二分类结果和第一分类结果之间的差异或者第二分类结果和转换后的分类结果之间的差异,调整第一深度学习模型的参数。
可以理解,上文对本公开的第一深度学习模型和第二深度学习模型进行了说明,下面将结合相关实施例对输入第一深度学习模型的样本图像进行进一步详细说明。
在一些实施例中,在上述的操作S110的一些实施方式中,将样本图像输入第一深度学习模型的第一特征提取网络可以包括:对样本图像进行数据增强处理,得到增强后的样本图像;以及将增强后的样本图像输入第一特征提取网络。
在本公开实施例中,可以利用各种方式进行数据增强处理。例如,可以对样本图像翻转、旋转、裁切等处理,以进行数据增强。可以理解,样本图像与增强后的样本图像之间存在区别。
在本公开实施例中,根据一个样本图像,可以对第一深度学习模型进行多个周期的训练。在每个训练周期,利用一种数据增强处理方式对样本图像进行数据增强处理,得到一个增强后的样本图像。在每个训练周期,可以利用第一深度学习模型处理该增强后的样本图像,得到该训练周期的第一分类结果。由此,可以调整第一深度学习模型的参数,使得该训练周期的第一分类结果和转换后的分类结果之间的差异收敛。
通过本公开实施例,在同一样本图像的情况下,对样本图像进行了数据增强,使得第一深度学习模型的输入和第二深度模型的输入之间存在区别。由此,在利用不同的增强后的样本图像对第一深度学习模型进行多个周期的训练之后,可以使得该周期的第一分类结果与转换后的分类结果之间的差异收敛,进而使得第一深度学习模型可以从增强后的样本图像中提取出能得到准确分类结果的特征。即,可以提取出精度更高更有效的特征,有助于提高第一深度学习模型的精度。
可以理解,上文对第一深度学习模型的输入进行了详细描述,下面将对样本图像的第二特征进行转换的一些方式进行详细描述。
在一些实施例中,在操作S120的一些实施方式中,对样本图像的第二特征进行转换,得到转换后的第二特征可以包括:利用特征转换网络对第二特征进行转换,得到转换后的第二特征。
在本公开实施例中,特征转换网络可以为一个全连接网络。例如,特征转换网络可以包括多个全连接层。又例如,经过特征转换网络处理之后,可以对第二特征降维,使得转换后的第二特征的维度与第一特征的维度一致。由此,第一分类网络可以根据转换后的第二特征进行分类。
可以理解,上文对第一深度学习模型的输入进行了详细描述,下面将结合相关实施例对训练第一深度学习模型的一些实施方式进行详细说明。
在一些实施例中,在上述的操作S140的一些实施方式中,根据转换后的分类结果,训练第一深度学习模型可以包括:将第一特征输入第一分类网络,得到第一分类结果;根据第一分类结果和转换后的分类结果,确定损失值;根据损失值,调整第一深度学习模型的参数,以训练第一深度学习模型。下面将结合图2进行详细说明。
图2是根据本公开的一个实施例的深度学习模型的训练方法的原理图。
上述的第一深度学习模型可以包括第一特征提取网络N211和第一分类网络N212。上述的第二深度学习模型可以包括第二特征提取网络N221。例如,第一深度学习模型可以包括多个高效数据图像Transformer(Data-efficient image Transformer,DeiT)编码块。第二深度学习模型也可以包括多个高效数据图像Transformer编码块。第二深度学习模型的参数量可以大于第一深度学习模型的参数量。
如图2所示,可以对样本图像201进行数据增强处理,得到增强后的样本图像202。将增强后的样本图像202输入第一深度学习模型的第一特征提取网络N211,可以得到第一特征211。将第一特征211输入第一深度学习模型的第一分类网络N212,可以得到第一分类结果212。
如图2所示,在获得第一特征211的之前或同时,可以将样本图像201输入第二深度学习模型的第二特征提取网络N221,得到第二特征221。将第二特征221输入特征转换网络N230,进行特征转换,得到转换后的第二特征231。将转换后的第二特征231输入第一分类网络N212,可以得到转换后的分类结果232。在本公开实施例中,特征转换网络N230可以部署于第一深度学习模型,也可以部署于第二深度学习模型,也可以实现为独立的第三深度学习模型,本公开对此不进行限制。
在本公开实施例中,根据第一分类结果和转换后的分类结果,确定损失值可以包括:根据第一分类结果和转换后的分类结果,确定第一蒸馏子损失值。根据第一蒸馏子损失值,确定损失值。例如,如图2所示,可以根据第一分类结果212和转换后的分类结果232,确定第一蒸馏子损失值241。在一个示例中,可以将第一蒸馏子损失值241作为损失值。
又例如,可以根据以下公式确定第一蒸馏子损失值:
Ldistill_s=cross_entropy(logitss,argmax(logitst,trans)) (公式一)
Ldistill_s可以为第一蒸馏子损失值。cross_entropy(·)可以为交叉熵损失函数。logitss可以为第一分类结果。logitst,trans可以为转换后的分类结果。argmax(·)为一个数学函数,可以从分类结果中获取一个或多个最值信息。
接下来,可以根据第一蒸馏子损失值,基于反向传播或梯度下降算法,调整第一特征提取网络、第一分类网络以及特征转换网络的参数。
通过本公开实施例,在损失值收敛之后,对于一个样本图像,可以使得第一分类结果和转换后的分类结果之间的差异大幅降低。第二深度学习模型的第二分类网络也可以根据第二特征输出一个第二分类结果。与将第二分类结果与第一分类结果蒸馏相比,将转换后的分类结果与第一分类结果蒸馏,可以加快损失值收敛的效率,有助于将第二深度学习模型的性能充分地传递给第一深度学习模型,进而提高第一深度学习模型的性能,有助于提高图像分类的准确率。
此外,通过本公开实施例,利用特征转换网络对第二特征进行转换,可以快速地获取转换后的第二特征,有助于提高模型训练效率。在训练过程中,可以对特征转换网络进行调整,使得特征转换网络可以更加准确全面地对第二特征进行转换,有助于将第二深度学习模型的性能充分地传递给第一深度学习模型。
可以理解,上文对训练第一深度学习模型的一些实施方式进行了详细描述,下面将结合相关实施例对训练第一深度学习模型的另一些实施方式进行详细描述。
图3是根据本公开的另一个实施例的深度学习模型的训练方法的原理图。
如图3所示,上述的第一深度学习模型可以包括第一特征提取网络N311和第一分类网络N312。上述的第二深度学习模型可以包括第二特征提取网络N321和第二分类网络N322。例如,第一深度学习模型可以包括多个高效数据图像Transformer编码块。第二深度学习模型也可以包括多个高效数据图像Transformer编码块。第二深度学习模型的参数量可以大于第一深度学习模型的参数量。
如图3所示,可以对样本图像301进行数据增强处理,得到增强后的样本图像302。将增强后的样本图像302输入第一深度学习模型的第一特征提取网络N311,可以得到第一特征311。将第一特征311输入第一深度学习模型的第一分类网络N312,可以得到第一分类结果312。
如图3所示,在获得第一特征311的之前或同时,可以将样本图像301输入第二深度学习模型的第二特征提取网络N321,得到第二特征321。将第二特征321输入第二分类网络N322,得到第二分类结果。可以理解,图3中第一分类网络和第二分类网络的尺寸仅为示意。在一些实施例中,第二分类网络的参数量可以大于第一分类网络的参数量。
将第二特征321输入特征转换网络N330,进行特征转换,得到转换后的第二特征331。将转换后的第二特征331输入第一分类网络N312,可以得到转换后的分类结果332。在本公开实施例中,特征转换网络N330可以部署于第一深度学习模型,也可以部署于第二深度学习模型,本公开对此不进行限制。
在本公开实施例中,根据第一分类结果和转换后的分类结果,确定损失值可以包括:根据第一分类结果和转换后的分类结果,确定第一蒸馏子损失值。例如,如图3所示,可以根据第一分类结果312和转换后的分类结果332,确定第一蒸馏子损失值341。又例如,可以根据上述的公式一确定第一蒸馏子损失值:
在本公开实施例中,根据第一蒸馏子损失值,确定损失值还可以包括:根据第二分类结果和转换后的分类结果,确定第二蒸馏子损失值。例如,如上述,第二分类结果322是利用第二深度学习模型的第二分类网络N322处理第二特征321得到的。又例如,如图3所示,可以根据第二分类结果322和转换后的分类结果332,确定第二蒸馏子损失值342。又例如,可以通过以下公式确定第二蒸馏子损失值:
Ldistill_t=cross_entropy(logitst,trans,argmax(logitst)) (公式二)
Ldistill_t可以为第二蒸馏子损失值。logitst可以为第二分类结果。logitst,trans可以为转换后的分类结果。
接下来,可以根据第一蒸馏子损失值和第二蒸馏子损失值,进行求和运算或加权求和运算等各种运算,得到损失值。再根据损失值,基于反向传播或梯度下降算法,调整第一特征提取网络、第一分类网络以及特征转换网络的参数。
通过本公开实施例,在损失值收敛之后,对于一个样本图像,可以使得第一分类结果和转换后的分类结果之间的差异大幅降低,也可以使得第二分类结果和转换后的分类结果之间的差异大幅降低。由此,基于这种两次的渐进的蒸馏训练方式,第一分类结果可以更加趋近于第二分类结果,可以加快损失值收敛的效率,有助于将第二深度学习模型的性能更加充分地传递给第一深度学习模型,进而进一步提高第一深度学习模型的性能,有助于进一步提高图像分类的准确率。
此外,在本公开实施例中,根据第一蒸馏子损失值,确定损失值还可以包括:根据第一分类结果和样本图像的标签,确定第一分类子损失值。例如,标签可以指示图像或图像中对象的真实类别。例如,如图3所示,可以根据第一分类结果312和样本图像301的标签,确定第一分类子损失值313。又例如,可以通过以下公式确定第一分类子损失值:
Lcls=cross_entropy(logitss,label) (公式三)
Lcls可以为第一分类子损失值。label可以为样本图像的标签。
此外,在本公开实施例中,根据第一蒸馏子损失值,确定损失值还可以包括:根据转换后的分类结果和样本图像的标签,确定第二分类子损失值。例如,如图3所示,可以根据转换后的分类结果332和样本图像301的标签,确定第二分类子损失值333。又例如,可以通过以下公式确定第二分类子损失值:
Lt_cls=cross_entropy(logitst,trans,label) (公式四)
Lt_cls可以为第二分类子损失值。
在本公开实施例中,根据第一蒸馏子损失值,确定损失值还可以包括:根据第一蒸馏子损失值、第二蒸馏子损失值、第一分类子损失值和第二分类子损失值,确定损失值。例如,根据第一蒸馏子损失值341、第二蒸馏子损失值342、第一分类子损失值313和第二分类子损失值333,可以进行求和运算或加权求和运算等各种运算,得到损失值。又例如,可以通过以下公式确定损失值:
L=Lcls+Lt_cls+Ldistill_t+Ldistill_s (公式五)
L可以为损失值。
可以理解,公式五为确定损失值的一种方式。还可以其他方式确定损失值,例如,利用预设的权重值进行加权求和。
通过本公开实施例,利用结合蒸馏训练和有监督的训练方式,可以加快第一深度学习模型收敛,提高训练效率,也有助于提高第一深度学习模型的精度,进而提高图像分类的准确率和效率。
可以理解,上文对本公开的深度学习模型的训练方法进行了详细说明,下面将结合相关实施例对本公开提供的图像分类方法进行详细说明。
图4是根据本公开的另一个实施例的图像分类方法的流程图。
如图4所示,该方法400可以包括操作S410。
在操作S410,将目标图像输入第一深度学习模型,得到目标分类结果。
在本公开实施例中,第一深度学习模型可以是利用本公开提供的深度学习模型的训练方法训练的。例如,第一深度学习模型可以是利用方法100进行训练的。
在本公开实施例中,目标图像可以是各种图像。例如,目标对象可以是各种物体、动物,也可以是动物或物体的局部(例如动物的面部)。
在本公开实施例中,目标分类结果可以指示目标图像的类别,也可以指示目标图像中对象的类别。
图5是根据本公开的一个实施例的深度学习模型的训练装置的框图。
如图5所示,该装置500可以包括第一获得模块510、转换模块520、第二获得模块530和训练模块540。
第一获得模块510,用于将样本图像输入第一深度学习模型的第一特征提取网络,得到样本图像的第一特征。
转换模块520,用于对样本图像的第二特征进行转换,得到转换后的第二特征,其中,第一特征的维度与转换后的第二特征的维度之间的差异小于或等于预设维度差异阈值。
第二获得模块530,用于将转换后的第二特征输入第一深度学习模型的第一分类网络,得到转换后的分类结果。
训练模块540,用于根据转换后的分类结果,训练第一深度学习模型。
在一些实施例中,第二特征是利用第二深度学习模型的第二特征提取网络处理样本图像得到的,第二深度学习模型的参数量大于或等于第一深度学习模型的参数量。
在一些实施例中,第一获得模块包括:数据增强子模块,用于对样本图像进行数据增强处理,得到增强后的样本图像。输入子模块,用于将增强后的样本图像输入第一特征提取网络。
在一些实施例中,转换模块包括:第一获得子模块,用于将第一特征输入第一分类网络,得到第一分类结果。确定子模块,用于根据第一分类结果和转换后的分类结果,确定损失值。调整子模块,用于根据损失值,调整第一深度学习模型的参数,以训练第一深度学习模型。
在一些实施例中,确定子模块包括:第一确定单元,用于根据第一分类结果和转换后的分类结果,确定第一蒸馏子损失值。第二确定单元,用于根据第一蒸馏子损失值,确定损失值。
在一些实施例中,第二确定单元包括:第一确定子单元,用于根据第二分类结果和转换后的分类结果,确定第二蒸馏子损失值。第二分类结果是利用第二深度学习模型的第二分类网络处理第二特征得到的。第二确定子单元,用于根据第一蒸馏子损失值和第二蒸馏子损失值,确定损失值。
在一些实施例中,第二确定单元还包括:第三确定子单元,用于根据第一分类结果和样本图像的标签,确定第一分类子损失值;第四确定子单元,用于根据转换后的分类结果和标签,确定第二分类子损失值;第五确定子单元,用于根据第一蒸馏子损失值、第二蒸馏子损失值、第一分类子损失值和第二分类子损失值,确定损失值。
在一些实施例中,转换模块包括:转换子模块,用于利用特征转换网络对第二特征进行转换,得到转换后的第二特征。
在一些实施例中,装置500还可以包括:调整模块,用于根据损失值,调整特征转换网络的参数。
在一些实施例中,第一特征的维度与转换后的第二特征的维度一致。
图6是根据本公开的另一个实施例的图像分类装置的框图。
如图6所示,该装置600可以包括第三获得模块610。
第三获得模块610,用于将目标图像输入第一深度学习模型,得到目标分类结果。
例如,第一深度学习模型是利用本公开提供的装置训练的。
本公开的技术方案中,所涉及的用户个人信息的收集、存储、使用、加工、传输、提供和公开等处理,均符合相关法律法规的规定,且不违背公序良俗。
根据本公开的实施例,本公开还提供了一种电子设备、一种可读存储介质和一种计算机程序产品。
图7示出了可以用来实施本公开的实施例的示例电子设备700的示意性框图。电子设备旨在表示各种形式的数字计算机,诸如,膝上型计算机、台式计算机、工作台、个人数字助理、服务器、刀片式服务器、大型计算机、和其它适合的计算机。电子设备还可以表示各种形式的移动装置,诸如,个人数字处理、蜂窝电话、智能电话、可穿戴设备和其它类似的计算装置。本文所示的部件、它们的连接和关系、以及它们的功能仅仅作为示例,并且不意在限制本文中描述的和/或者要求的本公开的实现。
如图7所示,设备700包括计算单元701,其可以根据存储在只读存储器(ROM)702中的计算机程序或者从存储单元708加载到随机访问存储器(RAM)703中的计算机程序,来执行各种适当的动作和处理。在RAM 703中,还可存储设备700操作所需的各种程序和数据。计算单元701、ROM 702以及RAM 703通过总线704彼此相连。输入/输出(I/O)接口705也连接至总线704。
设备700中的多个部件连接至I/O接口705,包括:输入单元706,例如键盘、鼠标等;输出单元707,例如各种类型的显示器、扬声器等;存储单元708,例如磁盘、光盘等;以及通信单元709,例如网卡、调制解调器、无线通信收发机等。通信单元709允许设备700通过诸如因特网的计算机网络和/或各种电信网络与其他设备交换信息/数据。
计算单元701可以是各种具有处理和计算能力的通用和/或专用处理组件。计算单元701的一些示例包括但不限于中央处理单元(CPU)、图形处理单元(GPU)、各种专用的人工智能(AI)计算芯片、各种运行机器学习模型算法的计算单元、数字信号处理器(DSP)、以及任何适当的处理器、控制器、微控制器等。计算单元701执行上文所描述的各个方法和处理,例如深度学习模型的训练方法和/或图像分类方法。例如,在一些实施例中,深度学习模型的训练方法和/或图像分类方法可被实现为计算机软件程序,其被有形地包含于机器可读介质,例如存储单元708。在一些实施例中,计算机程序的部分或者全部可以经由ROM 702和/或通信单元709而被载入和/或安装到设备700上。当计算机程序加载到RAM 703并由计算单元701执行时,可以执行上文描述的深度学习模型的训练方法和/或图像分类方法的一个或多个步骤。备选地,在其他实施例中,计算单元701可以通过其他任何适当的方式(例如,借助于固件)而被配置为执行深度学习模型的训练方法和/或图像分类方法。
本文中以上描述的系统和技术的各种实施方式可以在数字电子电路系统、集成电路系统、场可编程门阵列(FPGA)、专用集成电路(ASIC)、专用标准产品(ASSP)、芯片上系统的系统(SOC)、复杂可编程逻辑设备(CPLD)、计算机硬件、固件、软件、和/或它们的组合中实现。这些各种实施方式可以包括:实施在一个或者多个计算机程序中,该一个或者多个计算机程序可在包括至少一个可编程处理器的可编程系统上执行和/或解释,该可编程处理器可以是专用或者通用可编程处理器,可以从存储系统、至少一个输入装置、和至少一个输出装置接收数据和指令,并且将数据和指令传输至该存储系统、该至少一个输入装置、和该至少一个输出装置。
用于实施本公开的方法的程序代码可以采用一个或多个编程语言的任何组合来编写。这些程序代码可以提供给通用计算机、专用计算机或其他可编程数据处理装置的处理器或控制器,使得程序代码当由处理器或控制器执行时使流程图和/或框图中所规定的功能/操作被实施。程序代码可以完全在机器上执行、部分地在机器上执行,作为独立软件包部分地在机器上执行且部分地在远程机器上执行或完全在远程机器或服务器上执行。
在本公开的上下文中,机器可读介质可以是有形的介质,其可以包含或存储以供指令执行系统、装置或设备使用或与指令执行系统、装置或设备结合地使用的程序。机器可读介质可以是机器可读信号介质或机器可读储存介质。机器可读介质可以包括但不限于电子的、磁性的、光学的、电磁的、红外的、或半导体系统、装置或设备,或者上述内容的任何合适组合。机器可读存储介质的更具体示例会包括基于一个或多个线的电气连接、便携式计算机盘、硬盘、随机存取存储器(RAM)、只读存储器(ROM)、可擦除可编程只读存储器(EPROM或快闪存储器)、光纤、便捷式紧凑盘只读存储器(CD-ROM)、光学储存设备、磁储存设备、或上述内容的任何合适组合。
为了提供与用户的交互,可以在计算机上实施此处描述的系统和技术,该计算机具有:用于向用户显示信息的显示装置(例如,CRT(阴极射线管)显示器或者LCD(液晶显示器));以及键盘和指向装置(例如,鼠标或者轨迹球),用户可以通过该键盘和该指向装置来将输入提供给计算机。其它种类的装置还可以用于提供与用户的交互;例如,提供给用户的反馈可以是任何形式的传感反馈(例如,视觉反馈、听觉反馈、或者触觉反馈);并且可以用任何形式(包括声输入、语音输入或者、触觉输入)来接收来自用户的输入。
可以将此处描述的系统和技术实施在包括后台部件的计算系统(例如,作为数据服务器)、或者包括中间件部件的计算系统(例如,应用服务器)、或者包括前端部件的计算系统(例如,具有图形用户界面或者网络浏览器的用户计算机,用户可以通过该图形用户界面或者该网络浏览器来与此处描述的系统和技术的实施方式交互)、或者包括这种后台部件、中间件部件、或者前端部件的任何组合的计算系统中。可以通过任何形式或者介质的数字数据通信(例如,通信网络)来将系统的部件相互连接。通信网络的示例包括:局域网(LAN)、广域网(WAN)和互联网。
计算机系统可以包括客户端和服务器。客户端和服务器一般远离彼此并且通常通过通信网络进行交互。通过在相应的计算机上运行并且彼此具有客户端-服务器关系的计算机程序来产生客户端和服务器的关系。
应该理解,可以使用上面所示的各种形式的流程,重新排序、增加或删除步骤。例如,本公开中记载的各步骤可以并行地执行也可以顺序地执行也可以不同的次序执行,只要能够实现本公开公开的技术方案所期望的结果,本文在此不进行限制。
上述具体实施方式,并不构成对本公开保护范围的限制。本领域技术人员应该明白的是,根据设计要求和其他因素,可以进行各种修改、组合、子组合和替代。任何在本公开的精神和原则之内所作的修改、等同替换和改进等,均应包含在本公开保护范围之内。

Claims (25)

1.一种深度学习模型的训练方法,包括:
将样本图像输入第一深度学习模型的第一特征提取网络,得到所述样本图像的第一特征;
对所述样本图像的第二特征进行转换,得到转换后的第二特征,其中,所述第一特征的维度与所述转换后的第二特征的维度之间的差异小于或等于预设维度差异阈值;
将所述转换后的第二特征输入所述第一深度学习模型的第一分类网络,得到转换后的分类结果;以及
根据所述转换后的分类结果,训练所述第一深度学习模型。
2.根据权利要求1所述的方法,其中,所述第二特征是利用第二深度学习模型的第二特征提取网络处理所述样本图像得到的,所述第二深度学习模型的参数量大于或等于所述第一深度学习模型的参数量。
3.根据权利要求1所述的方法,其中,所述将样本图像输入第一深度学习模型的第一特征提取网络包括:
对所述样本图像进行数据增强处理,得到增强后的样本图像;以及
将所述增强后的样本图像输入所述第一特征提取网络。
4.根据权利要求2所述的方法,其中,所述根据所述转换后的分类结果,训练所述第一深度学习模型包括:
将所述第一特征输入所述第一分类网络,得到第一分类结果;
根据所述第一分类结果和所述转换后的分类结果,确定损失值;以及
根据所述损失值,调整所述第一深度学习模型的参数,以训练所述第一深度学习模型。
5.根据权利要求4所述的方法,其中,所述根据所述第一分类结果和所述转换后的分类结果,确定损失值包括:
根据所述第一分类结果和所述转换后的分类结果,确定第一蒸馏子损失值;以及
根据所述第一蒸馏子损失值,确定所述损失值。
6.根据权利要求5所述的方法,其中,所述根据所述第一蒸馏子损失值,确定所述损失值包括:
根据第二分类结果和所述转换后的分类结果,确定第二蒸馏子损失值,其中,所述第二分类结果是利用所述第二深度学习模型的第二分类网络处理所述第二特征得到的;
根据所述第一蒸馏子损失值和所述第二蒸馏子损失值,确定所述损失值。
7.根据权利要求6所述的方法,其中,所述根据所述第一蒸馏子损失值,确定所述损失值还包括:
根据所述第一分类结果和所述样本图像的标签,确定第一分类子损失值;
根据所述转换后的分类结果和所述标签,确定第二分类子损失值;
根据所述第一蒸馏子损失值、所述第二蒸馏子损失值、所述第一分类子损失值和所述第二分类子损失值,确定所述损失值。
8.根据权利要求4所述的方法,其中,所述对所述样本图像的第二特征进行转换,得到转换后的第二特征包括:
利用特征转换网络对所述第二特征进行转换,得到所述转换后的第二特征。
9.根据权利要求8所述的方法,还包括:
根据所述损失值,调整所述特征转换网络的参数。
10.根据权利要求1所述的方法,其中,所述第一特征的维度与所述转换后的第二特征的维度一致。
11.一种图像分类方法,包括:
将目标图像输入第一深度学习模型,得到目标分类结果,
其中,所述第一深度学习模型是利用权利要求1至10任一项所述的方法训练的。
12.一种深度学习模型的训练装置,包括:
第一获得模块,用于将样本图像输入第一深度学习模型的第一特征提取网络,得到所述样本图像的第一特征;
转换模块,用于对所述样本图像的第二特征进行转换,得到转换后的第二特征,其中,所述第一特征的维度与所述转换后的第二特征的维度之间的差异小于或等于预设维度差异阈值;
第二获得模块,用于将所述转换后的第二特征输入所述第一深度学习模型的第一分类网络,得到转换后的分类结果;以及
训练模块,用于根据所述转换后的分类结果,训练所述第一深度学习模型。
13.根据权利要求12所述的装置,其中,所述第二特征是利用第二深度学习模型的第二特征提取网络处理所述样本图像得到的,所述第二深度学习模型的参数量大于或等于所述第一深度学习模型的参数量。
14.根据权利要求12所述的装置,其中,所述第一获得模块包括:
数据增强子模块,用于对所述样本图像进行数据增强处理,得到增强后的样本图像;以及
输入子模块,用于将所述增强后的样本图像输入所述第一特征提取网络。
15.根据权利要求13所述的装置,其中,所述转换模块包括:
第一获得子模块,用于将所述第一特征输入所述第一分类网络,得到第一分类结果;
确定子模块,用于根据所述第一分类结果和所述转换后的分类结果,确定损失值;
调整子模块,用于根据所述损失值,调整所述第一深度学习模型的参数,以训练所述第一深度学习模型。
16.根据权利要求15所述的装置,其中,所述确定子模块包括:
第一确定单元,用于根据所述第一分类结果和所述转换后的分类结果,确定第一蒸馏子损失值;以及
第二确定单元,用于根据所述第一蒸馏子损失值,确定所述损失值。
17.根据权利要求16所述的装置,其中,所述第二确定单元包括:
第一确定子单元,用于根据第二分类结果和所述转换后的分类结果,确定第二蒸馏子损失值,其中,所述第二分类结果是利用所述第二深度学习模型的第二分类网络处理所述第二特征得到的;
第二确定子单元,用于根据所述第一蒸馏子损失值和所述第二蒸馏子损失值,确定所述损失值。
18.根据权利要求17所述的装置,其中,所述第二确定单元还包括:
第三确定子单元,用于根据所述第一分类结果和所述样本图像的标签,确定第一分类子损失值;
第四确定子单元,用于根据所述转换后的分类结果和所述标签,确定第二分类子损失值;
第五确定子单元,用于根据所述第一蒸馏子损失值、所述第二蒸馏子损失值、所述第一分类子损失值和所述第二分类子损失值,确定所述损失值。
19.根据权利要求15所述的装置,其中,所述转换模块包括:
转换子模块,用于利用特征转换网络对所述第二特征进行转换,得到所述转换后的第二特征。
20.根据权利要求19所述的装置,还包括:
调整模块,用于根据所述损失值,调整所述特征转换网络的参数。
21.根据权利要求12所述的装置,其中,所述第一特征的维度与所述转换后的第二特征的维度一致。
22.一种图像分类装置,包括:
第三获得模块,用于将目标图像输入第一深度学习模型,得到目标分类结果,
其中,所述第一深度学习模型是利用权利要求12至21任一项所述的装置训练的。
23.一种电子设备,包括:
至少一个处理器;以及
与所述至少一个处理器通信连接的存储器;其中,
所述存储器存储有可被所述至少一个处理器执行的指令,所述指令被所述至少一个处理器执行,以使所述至少一个处理器能够执行权利要求1至11中任一项所述的方法。
24.一种存储有计算机指令的非瞬时计算机可读存储介质,其中,所述计算机指令用于使所述计算机执行根据权利要求1至11中任一项所述的方法。
25.一种计算机程序产品,包括计算机程序,所述计算机程序在被处理器执行时实现根据权利要求1至11中任一项所述的方法。
CN202211219252.4A 2022-09-30 2022-09-30 模型训练方法、图像分类方法、装置、设备和介质 Pending CN115482396A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202211219252.4A CN115482396A (zh) 2022-09-30 2022-09-30 模型训练方法、图像分类方法、装置、设备和介质

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202211219252.4A CN115482396A (zh) 2022-09-30 2022-09-30 模型训练方法、图像分类方法、装置、设备和介质

Publications (1)

Publication Number Publication Date
CN115482396A true CN115482396A (zh) 2022-12-16

Family

ID=84393147

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202211219252.4A Pending CN115482396A (zh) 2022-09-30 2022-09-30 模型训练方法、图像分类方法、装置、设备和介质

Country Status (1)

Country Link
CN (1) CN115482396A (zh)

Similar Documents

Publication Publication Date Title
CN113326764B (zh) 训练图像识别模型和图像识别的方法和装置
JP2022135991A (ja) クロスモーダル検索モデルのトレーニング方法、装置、機器、および記憶媒体
CN113901907A (zh) 图文匹配模型训练方法、图文匹配方法及装置
CN114242113B (zh) 语音检测方法、训练方法、装置和电子设备
CN115482395B (zh) 模型训练方法、图像分类方法、装置、电子设备和介质
JP2023547010A (ja) 知識の蒸留に基づくモデルトレーニング方法、装置、電子機器
CN113360700A (zh) 图文检索模型的训练和图文检索方法、装置、设备和介质
CN114494784A (zh) 深度学习模型的训练方法、图像处理方法和对象识别方法
CN114549840A (zh) 语义分割模型的训练方法和语义分割方法、装置
CN115147680B (zh) 目标检测模型的预训练方法、装置以及设备
CN114861758A (zh) 多模态数据处理方法、装置、电子设备及可读存储介质
CN114494747A (zh) 模型的训练方法、图像处理方法、装置、电子设备及介质
CN113989152A (zh) 图像增强方法、装置、设备以及存储介质
CN113361523A (zh) 文本确定方法、装置、电子设备和计算机可读存储介质
CN115035351B (zh) 基于图像的信息提取方法、模型训练方法、装置、设备及存储介质
CN113361522B (zh) 用于确定字符序列的方法、装置和电子设备
CN115496916A (zh) 图像识别模型的训练方法、图像识别方法以及相关装置
CN113239215B (zh) 多媒体资源的分类方法、装置、电子设备及存储介质
CN114724144A (zh) 文本识别方法、模型的训练方法、装置、设备及介质
CN115481285A (zh) 跨模态的视频文本匹配方法、装置、电子设备及存储介质
CN114973333A (zh) 人物交互检测方法、装置、设备以及存储介质
CN114707638A (zh) 模型训练、对象识别方法及装置、设备、介质和产品
CN114882334A (zh) 用于生成预训练模型的方法、模型训练方法及装置
CN115482396A (zh) 模型训练方法、图像分类方法、装置、设备和介质
CN113806541A (zh) 情感分类的方法和情感分类模型的训练方法、装置

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination