CN112329885B - 模型训练方法、装置以及计算机可读存储介质 - Google Patents

模型训练方法、装置以及计算机可读存储介质 Download PDF

Info

Publication number
CN112329885B
CN112329885B CN202011338954.5A CN202011338954A CN112329885B CN 112329885 B CN112329885 B CN 112329885B CN 202011338954 A CN202011338954 A CN 202011338954A CN 112329885 B CN112329885 B CN 112329885B
Authority
CN
China
Prior art keywords
sample
training
classification model
model
initial
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN202011338954.5A
Other languages
English (en)
Other versions
CN112329885A (zh
Inventor
冯于树
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Jiangsu Yuncongxihe Artificial Intelligence Co ltd
Original Assignee
Jiangsu Yuncongxihe Artificial Intelligence Co ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Jiangsu Yuncongxihe Artificial Intelligence Co ltd filed Critical Jiangsu Yuncongxihe Artificial Intelligence Co ltd
Priority to CN202011338954.5A priority Critical patent/CN112329885B/zh
Publication of CN112329885A publication Critical patent/CN112329885A/zh
Application granted granted Critical
Publication of CN112329885B publication Critical patent/CN112329885B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • G06F18/241Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
    • G06F18/2411Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches based on the proximity to a decision surface, e.g. support vector machines
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N5/00Computing arrangements using knowledge-based models
    • G06N5/02Knowledge representation; Symbolic representation

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Data Mining & Analysis (AREA)
  • Evolutionary Computation (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Evolutionary Biology (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Computational Linguistics (AREA)
  • Computing Systems (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Image Analysis (AREA)

Abstract

本发明涉及数据分类技术领域,具体提供了一种模型训练方法、装置以及计算机可读存储介质,旨在解决代价敏感学习算法与数据增强方法无法进行有效的结合,导致数据分类模型的精准度和性能无法一同得到提升的技术问题。为此目的,根据本发明实施例的方法,可以采用代价敏感学习算法并且根据初始训练样本组对第一分类模型进行模型训练,得到初始数据分类模型;对初始训练样本组中的训练样本进行数据增强处理,以生成增强样本;采用知识蒸馏算法,使初始数据分类模型指导第二分类模型使用增强样本进行模型训练,得到最终的数据分类模型。通过上述步骤,可以将代价敏感学习算法和数据增强方法进行有效的结合,同时提升了模型分类的精准度和性能。

Description

模型训练方法、装置以及计算机可读存储介质
技术领域
本发明涉及数据分类技术领域,具体涉及一种模型训练方法、装置以及计算机可读存储介质。
背景技术
随着信息技术的高速发展,深度学习技术在图像分类任务上的性能已经远远超越了传统的图像识别方法。深度卷积神经网络(Convolutional Neural Network,CNN)是特别设计用于识别图像的多层感知器,CNN的权重共享网络结构与生物神经网络类似,通过对图像进行多次的卷积核池化操作,逐渐提取到图像的高层表达,再使用神经网络对特征进行分类,以此来实现对图像分类的功能。此外,通过对数据进行标注,CNN在图像分类领域表现出极大的优势。
然而,在实际的图像分类过程中可能会出现数据不平衡的情况,标注为某一类别的数据量远远小于标注为其他类别的数据量,神经网络模型往往会忽略该类别从而使得模型分类的精准度下降。为解决该问题,代价敏感学习算法是其中一种有效的方法;另一方面,在实际的图像分类过程中还可能会因为数据量较少而导致模型分类的性能差,现有技术中往往采用数据增强方法来提高神经网络模型的性能,但是代价敏感学习算法与数据增强方法无法进行有效的结合,导致神经网络模型分类的精准度和性能得到无法一同得到提升。
发明内容
为了克服上述缺陷,提出了本发明,以提供解决或至少部分地解决代价敏感学习算法与数据增强方法无法进行有效的结合,导致数据分类模型的精准度和性能无法一同得到提升的技术问题的模型训练方法、装置以及计算机可读存储介质。
第一方面,提供一种模型训练方法,所述模型训练方法包括:
采用代价敏感学习算法,用初始训练样本组对第一分类模型进行模型训练,得到初始数据分类模型;
对所述初始训练样本组中的训练样本进行数据增强处理,以生成增强样本;
采用知识蒸馏算法,使所述初始数据分类模型指导第二分类模型使用所述增强样本进行模型训练,得到最终的数据分类模型;
其中,
所述第一分类模型与所述第二分类模型的模型结构相同;所述初始训练样本组中一部分类别训练样本的数量远小于其他类别训练样本的数量。
在上述模型训练方法的一个技术方案中,“采用代价敏感学习算法,用初始训练样本组对第一分类模型进行模型训练,得到初始数据分类模型”的步骤具体包括:
采用代价敏感学习算法按照下列公式所示的代价敏感学习函数对所述第一分类模型进行模型训练:
Figure BDA0002798019490000021
其中,所述L1表示所述代价敏感学习函数,所述N表示所述初始训练样本组中训练样本的个数;所述li表示所述初始训练样本组中第i个训练样本的训练误差,i=1,2,3,...,N;
Figure BDA0002798019490000022
所述m表示所述初始训练样本组中样本类别的总数;所述Wj表示第j个样本类别的权重且
Figure BDA0002798019490000023
所述nj表示第j个样本类别的训练样本的个数;所述pij表示第i个训练样本被分类为第j个样本类别的预测概率;所述qij表示第i个训练样本被标记为第j个样本类别的标签值。
在上述模型训练方法的一个技术方案中,“采用知识蒸馏算法并且利用所述初始数据分类模型与所述增强样本对第二分类模型进行模型训练,得到最终的数据分类模型”的步骤具体包括:
将所述增强样本同时输入至所述初始数据分类模型以及所述第二分类模型;
采用知识蒸馏算法并且按照下列公式所示的知识蒸馏函数对所述第二分类模型进行模型训练:
Figure BDA0002798019490000031
其中,所述L2表示所述知识蒸馏函数,所述la表示所述第二分类模型在对所述增强样本进行训练时确定的损失函数,所述lb表示利用所述初始数据分类模型对所述第二分类模型使用所述增强样本进行训练指导学习时确定的知识蒸馏损失函数。
在上述模型训练方法的一个技术方案中,每个所述增强样本分别由所述初始训练样本组中的任意两个训练样本各自对应的一部分样本数据组成;
所述第二分类模型的损失函数la如下式所示:
Figure BDA0002798019490000032
其中,所述r表示浮点数且r∈[0,1];所述cuj表示与增强样本相关的一个训练样本被标记为第j个样本类别的标签值,所述cvj表示与当前增强样本相关的另一个训练样本被标记为第j个样本类别的标签值,所述sj表示增强样本被分类为第j个样本类别的预测概率;
并且/或者,
所述知识蒸馏损失函数lb如下式所示:
Figure BDA0002798019490000033
其中,所述T表示超参数,T为[2,5]之间的整数;所述fj表示利用所述初始数据分类模型获取到的所述增强样本被分类为第j个样本类别的预测概率,所述hj表示利用所述第二分类模型获取到的所述增强样本被分类为第j个样本类别的预测概率;
Figure BDA0002798019490000034
所述zj表示所述初始数据分类模型的特征提取模块输出的所述增强样本对应的第j个样本类别的样本特征向量;
Figure BDA0002798019490000041
所述kj表示所述第二分类模型的特征提取模块输出的所述增强样本对应的第j个样本类别的样本特征向量。
在上述模型训练方法的一个技术方案中,“对所述初始训练样本组进行数据增强处理,以生成增强样本”的步骤具体包括:
采用混合样本数据增强算法对所述初始训练样本组进行数据增强处理。
第二方面,提供一种模型训练装置,所述模型训练装置包括:
代价敏感学习模块,其被配置成采用代价敏感学习算法,用初始训练样本组对第一分类模型进行模型训练,得到初始数据分类模型;
数据增强模块,其被配置成对所述初始训练样本组进行数据增强处理,以生成增强样本;
知识蒸馏模块,其被配置成采用知识蒸馏算法,使所述初始数据分类模型指导第二分类模型使用所述增强样本进行模型训练,得到最终的数据分类模型;
其中,
所述第一分类模型与所述第二分类模型的模型结构相同;所述初始训练样本组中一部分类别训练样本的数量远小于其他类别训练样本的数量。
在上述模型训练装置的一个技术方案中,所述代价敏感学习模块还被配置成执行以下操作:
采用代价敏感学习算法按照下列公式所示的代价敏感学习函数对所述第一分类模型进行模型训练:
Figure BDA0002798019490000042
其中,所述L1表示所述代价敏感学习函数,所述N表示所述初始训练样本组中训练样本的个数;所述li表示所述初始训练样本组中第i个训练样本的训练误差,i=1,2,3,...,N;
Figure BDA0002798019490000051
所述m表示所述初始训练样本组中样本类别的总数;所述Wj表示第j个样本类别的权重且
Figure BDA0002798019490000052
所述nj表示第j个样本类别的训练样本的个数;所述pij表示第i个训练样本被分类为第j个样本类别的预测概率;所述qij表示第i个训练样本被标记为第j个样本类别的标签值。
在上述模型训练装置的一个技术方案中,所述知识蒸馏模块还被配置成执行以下操作:
将所述增强样本同时输入至所述初始数据分类模型以及所述第二分类模型;
采用知识蒸馏算法并且按照下列公式所示的知识蒸馏函数对所述第二分类模型进行模型训练:
Figure BDA0002798019490000053
其中,所述L2表示所述知识蒸馏函数,所述la表示所述第二分类模型在对所述增强样本进行训练时确定的损失函数,所述lb表示利用所述初始数据分类模型对所述第二分类模型使用所述增强样本进行训练指导学习时确定的知识蒸馏损失函数。
在上述模型训练装置的一个技术方案中,所述知识蒸馏模块还被配置成执行以下操作:
每个所述增强样本分别由所述初始训练样本组中的任意两个训练样本各自对应的一部分样本数据组成;
所述第二分类模型的损失函数la如下式所示:
Figure BDA0002798019490000054
其中,所述r表示浮点数且r∈[0,1];所述cuj表示与增强样本相关的一个训练样本被标记为第j个样本类别的标签值,所述cvj表示与当前增强样本相关的另一个训练样本被标记为第j个样本类别的标签值,所述sj表示增强样本被分类为第j个样本类别的预测概率;
并且/或者,
所述知识蒸馏损失函数lb如下式所示:
Figure BDA0002798019490000061
其中,所述T表示超参数,T为[2,5]之间的整数;所述fj表示利用所述初始数据分类模型获取到的所述增强样本被分类为第j个样本类别的预测概率,所述hj表示利用所述第二分类模型获取到的所述增强样本被分类为第j个样本类别的预测概率;
Figure BDA0002798019490000062
所述zj表示所述初始数据分类模型的特征提取模块输出的所述增强样本对应的第j个样本类别的样本特征向量;
Figure BDA0002798019490000063
所述kj表示所述第二分类模型的特征提取模块输出的所述增强样本对应的第j个样本类别的样本特征向量。
在上述模型训练装置的一个技术方案中,所述数据增强模块还被配置成执行以下操作:
采用混合样本数据增强算法对所述初始训练样本组进行数据增强处理。
第三方面,提供一种模型训练装置,该模型训练装置包括处理器和存储装置,所述存储装置适于存储多条程序代码,所述程序代码适于由所述处理器加载并运行以执行上述任一项技术方案所述的模型训练方法。
第四方面,提供一种计算机可读存储介质,其中存储有多条程序代码,所述程序代码适于由处理器加载并运行以执行上述任一项技术方案所述的模型训练方法。
本发明上述一个或多个技术方案,至少具有如下一种或多种有益效果:
在实施本发明的技术方案中,首先,通过采用代价敏感学习算法训练得到初始数据分类模型,使得初始数据分类模型能够对类别不平衡的数据进行分类,提高了模型分类的精准度;其次,对初始训练样本组中的训练样本进行数据增强处理生成增强样本,使得模型训练过程中有足够数量的样本,提高了模型的性能;最后,采用知识蒸馏算法利用初始数据分类模型指导第二分类模型使用增强样本进行模型训练,得到最终的数据分类模型,使得最终的数据分类模型不仅分类的精准度高并且模型性能也得到了提升,通过这样的设置,采用知识蒸馏算法将代价敏感学习算法和数据增强方法进行有效的结合,同时提升了模型分类的精准度和性能。
附图说明
下面参照附图来描述本发明的具体实施方式,附图中:
图1是根据本发明的一个实施例的模型训练方法的主要步骤流程示意图;
图2是根据本发明的一个实施例的模型训练装置的主要结构框图;
附图标记列表:
11:代价敏感学习模块;12:数据增强模块;13:知识蒸馏模块。
具体实施方式
下面参照附图来描述本发明的一些实施方式。本领域技术人员应当理解的是,这些实施方式仅仅用于解释本发明的技术原理,并非旨在限制本发明的保护范围。
在本发明的描述中,“模块”、“处理器”可以包括硬件、软件或者两者的组合。一个模块可以包括硬件电路,各种合适的感应器,通信端口,存储器,也可以包括软件部分,比如程序代码,也可以是软件和硬件的组合。处理器可以是中央处理器、微处理器、图像处理器、数字信号处理器或者其他任何合适的处理器。处理器具有数据和/或信号处理功能。处理器可以以软件方式实现、硬件方式实现或者二者结合方式实现。非暂时性的计算机可读存储介质包括任何合适的可存储程序代码的介质,比如磁碟、硬盘、光碟、闪存、只读存储器、随机存取存储器等等。术语“A和/或B”表示所有可能的A与B的组合,比如只是A、只是B或者A和B。术语“至少一个A或B”或者“A和B中的至少一个”含义与“A和/或B”类似,可以包括只是A、只是B或者A和B。单数形式的术语“一个”、“这个”也可以包含复数形式。
这里先解释本发明涉及到的一些术语。
代价敏感学习算法(Cost-sensitive learning algorithm)是机器学习技术领域中一种常规的机器学习算法,该算法能够考虑到不同的错误分类造成的结果不同,为了权衡不同结果产生的不同损失,将错误分类赋予非均等代价。
知识蒸馏算法(Knowledge distillation algorithm)是机器学习技术领域中一种常规的机器学习算法,该算法通过构建教师模型—学生模型框架,由教师模型指导学生模型的训练,将模型结构复杂、参数量大的教师模型所学到的关于特征表示的“知识”蒸馏出来,将这些“知识”迁移到模型结构简单、参数量少,学习能力弱的学生模型中。
目前传统的数据分类方法主要是利用卷积神经网络对数据进行多次的卷积核池化操作,逐渐提取到数据的高层表达,再使用神经网络对特征进行分类,从而对图像、语音等数据进行分类。然而,在实际的数据分类过程中可能会出现数据不平衡的情况,即标注为某一类别的数据量远远小于标注为其他类别的数据量,神经网络模型往往会忽略该类别从而使得模型分类的精准度下降。为解决该问题,代价敏感学习算法是其中一种有效的方法,代价敏感学习算法的做法在于根据各个类别的数量,分别给予不同类别的数据不同大小的权重,类别数量少的数据在模型训练计算损失时有更大的权重,从而提高神经网络模型分类的精准度。
另一方面,在实际的数据分类过程中还可能会因为数据量较少而导致模型分类的性能差,现有技术中往往采用数据增强方法来提高神经网络分类模型的性能,但是在实际应用中,数据增强方法可能无法与代价敏感学习算法直接结合使用。例如cutmix算法的做法是随机将一张图像的随机区域放置到另一张图像的相应区域,形成新的图像,并将新的图像输入模型中,计算损失时根据两张图像的比例进行加权求和。当一张图像在cutmix算法拼接后的图像中所占的比重很小时,cutmix算法将给予这张图像一个较小的权重,另一方面,如果这张图像来自图像数量较小的类别,代价敏感学习算法又会给其一个较大的权重,使得针对图像数据的增强方法和针对训练策略的算法无法进行有效的结合,从而影响神经网络模型分类的精准度和性能。
在本发明实施例中,可以采用代价敏感学习算法并且根据初始训练样本组对第一分类模型进行模型训练,得到初始数据分类模型;对初始训练样本组中的训练样本进行数据增强处理,以生成增强样本;采用知识蒸馏算法并且利用初始数据分类模型与增强样本,对第二分类模型进行模型训练,得到最终的数据分类模型;其中,第一分类模型与第二分类模型的模型结构相同;初始训练样本组中一部分类别训练样本的数量远小于其他类别训练样本的数量。首先,通过采用代价敏感学习算法训练得到初始数据分类模型,使得初始数据分类模型能够对类别不平衡的数据进行分类,提高了模型分类的精准度;其次,对初始训练样本组中的训练样本进行数据增强处理生成增强样本,使得模型训练过程中有足够数量的样本,提高了模型的性能;最后,采用知识蒸馏算法使初始数据分类模型指导第二分类模型使用增强样本进行模型训练,得到最终的数据分类模型,使得最终的数据分类模型不仅分类的精准度高并且模型性能也得到了提升,通过这样的设置,采用知识蒸馏算法将代价敏感学习算法和数据增强方法进行有效的结合,提升了模型分类的精准度和性能。
在本发明的一个应用场景中,需要判断非机动车辆是否非法进入高速公路,因此,需要训练一个能够对机动车辆和非机动车辆进行分类的神经网络模型。首先,将高速公路上的监控器拍摄到的图片作为初始训练样本组采用代价敏感学习算法对第一分类模型进行模型训练,得到能够分类机动车辆和非机动车辆的初始数据分类模型,然后,采用cutmix算法对拍摄到的图片进行任意两张图片之间的拼接以生成增强样本,最后,采用知识蒸馏算法使初始数据分类模型指导与第一分类模型结构相同的第二分类模型使用增强样本进行模型训练,得到最终的数据分类模型,以使得最终的数据分类模型能够准确地识别出机动车辆和非机动车辆。
参阅附图1,图1是根据本发明的一个实施例的模型训练方法的主要步骤流程示意图。如图1所示,本发明实施例中的模型训练方法主要包括以下步骤:
步骤S101:采用代价敏感学习算法,用初始训练样本组对第一分类模型进行模型训练,得到初始数据分类模型,其中,初始训练样本组中一部分类别训练样本的数量远小于其他类别训练样本的数量。
在本实施例中,第一分类模型包括但不限于:基于XGBoost(eXtreme GradientBoosting)算法的分类模型、基于支持向量机(Support Vector Machines,SVM)的分类模型、基于神经网络的分类模型,本领域技术人员可以根据实际需求灵活设置。训练样本包括但不限于:图像样本、语音样本,本领域技术人员可以根据实际需求灵活设置。
在本实施例中,上述的远小于指的是一部分类别训练样本的数量与其他类别训练样本的数量的差值大于预设的阈值,举一个例子:初始训练样本组包括A和B两个类别,A类别有1个训练样本,B类别有99个训练样本,而预设的阈值是80,由于99-1>80,则判定A类别训练样本的数量远小于B类别训练样本的数量。
一个实施方式中,“采用代价敏感学习算法,用初始训练样本组对第一分类模型进行模型训练,得到初始数据分类模型”的步骤具体包括:采用代价敏感学习算法按照下列公式(1)所示的代价敏感学习函数对第一分类模型进行模型训练:
Figure BDA0002798019490000101
公式(1)中各参数含义是:
L1表示代价敏感学习函数,N表示初始训练样本组中训练样本的个数;li表示初始训练样本组中第i个训练样本的训练误差,i=1,2,3,...,N;
Figure BDA0002798019490000102
m表示初始训练样本组中样本类别的总数;Wj表示第j个样本类别的权重且
Figure BDA0002798019490000103
nj表示第j个样本类别的训练样本的个数;pij表示第i个训练样本被分类为第j个样本类别的预测概率;qij表示第i个训练样本被标记为第j个样本类别的标签值。在本实施方式中,可以利用代价敏感学习函数进行梯度回传,通过完成指定次数的迭代训练来得到初始数据分类模型,或者通过迭代训练使得L1达到预设值来得到初始数据分类模型。
在本实施方式中,通过根据样本类别的训练样本的个数,分别给予不同样本类别不同大小的权重,训练样本个数少的样本类别在模型训练计算损失时有更大的权重,从而使得训练后得到的初始数据分类模型针对数据不平衡的数据集进行分类时具备很好的精准度。
步骤S102:对初始训练样本组进行数据增强处理,以生成增强样本。
一个实施方式中,“对初始训练样本组进行数据增强处理,以生成增强样本”的步骤具体包括:采用混合样本数据增强算法对初始训练样本组进行数据增强处理。通过对初始训练样本组中的训练样本进行数据增强处理,使得模型训练过程中有足够数量的样本,提高了模型的性能。
在本实施方式中,混合样本数据增强算法(Mixed Sample Data Augmentation,MSDA)是数据处理技术领域中一种常规的数据增强算法,该算法能够按一定的比例随机混合样本集中的样本及其标签从而生成更多的样本及标签。混合样本数据增强算法包括但不限于:cutmix算法、mixup算法、attention mix算法,本领域技术人员可以根据实际需求灵活设置。
步骤S103:采用知识蒸馏算法,使初始数据分类模型指导第二分类模型使用增强样本进行模型训练,得到最终的数据分类模型,其中,第一分类模型与第二分类模型的模型结构相同。
在本实施例中,第一分类模型与第二分类模型的模型结构相同,可以随机初始化一个与第一分类模型结构相同的模型作为第二分类模型。
一个实施方式中,“采用知识蒸馏算法并且利用初始数据分类模型与增强样本对第二分类模型进行模型训练,得到最终的数据分类模型”的步骤具体包括:将增强样本同时输入至初始数据分类模型以及第二分类模型;采用知识蒸馏算法并且按照下列公式(2)所示的知识蒸馏函数对第二分类模型进行模型训练:
Figure BDA0002798019490000111
公式(2)中各参数含义是:
L2表示知识蒸馏函数,la表示第二分类模型在对增强样本进行训练时确定的损失函数,lb表示利用初始数据分类模型对第二分类模型使用增强样本进行训练指导学习时确定的知识蒸馏损失函数。在本实施方式中,可以利用知识蒸馏函数进行梯度回传,通过完成指定次数的迭代训练来得到初始数据分类模型,或者通过迭代训练使得L2达到预设值来得到初始数据分类模型。
在本实施方式中,利用初始数据分类模型指导第二分类模型进行模型训练,使得模型获得对类别不平衡的数据进行分类的能力,提高了模型分类的精准度,并且,在训练过程中使用增强样本作为训练样本,提高了模型的性能,采用知识蒸馏算法将代价敏感学习算法和数据增强方法进行有效的结合,使得最终的数据分类模型不仅分类的精准度高并且模型性能也得到了提升。
一个实施方式中,每个增强样本分别由初始训练样本组中的任意两个训练样本各自对应的一部分样本数据组成;第二分类模型的损失函数la如下列公式(3)所示:
Figure BDA0002798019490000121
公式(3)中各参数含义是:
r表示浮点数且r∈[0,1];cuj表示与增强样本相关的一个训练样本被标记为第j个样本类别的标签值,cvj表示与当前增强样本相关的另一个训练样本被标记为第j个样本类别的标签值,sj表示增强样本被分类为第j个样本类别的预测概率;并且/或者,知识蒸馏损失函数lb如下列公式(4)所示:
Figure BDA0002798019490000122
公式(4)中各参数含义是:
T表示超参数,T为[2,5]之间的任意一个整数;fj表示利用初始数据分类模型获取到的增强样本被分类为第j个样本类别的预测概率,hj表示利用第二分类模型获取到的增强样本被分类为第j个样本类别的预测概率;
Figure BDA0002798019490000123
zj表示初始数据分类模型的特征提取模块输出的增强样本对应的第j个样本类别的样本特征向量;
Figure BDA0002798019490000131
kj表示第二分类模型的特征提取模块输出的增强样本对应的第j个样本类别的样本特征向量。在本实施方式中,采用知识蒸馏算法将代价敏感学习算法和数据增强方法进行有效的结合,使得最终的数据分类模型不仅分类的精准度高并且模型性能也得到了提升。
在本实施方式中,特征提取模块能够提取训练样本的样本特征以便于模型按照上述公式(2)所示的知识蒸馏函数进行模型优化。
在本实施方式中,组成每个增强样本的两个训练样本各自对应的一部分样本数据的比例可以相同,也可以不同,本领域技术人员可以根据实际需求灵活设置。在一个可能的实施方式中,初始训练样本组中的训练样本是图像样本,可以任意获取初始训练样本组中的两个训练样本x1和x2,选取训练样本x1的左侧区域并将该区域表示为
Figure BDA0002798019490000132
选取训练样本x2的右侧区域并将该区域表示为
Figure BDA0002798019490000133
将训练样本x1的左侧区域和训练样本x2的右侧区域进行拼接生成增强样本x,则x的组成可以表示为:
Figure BDA0002798019490000134
其中,B表示图像样本的宽度,r表示浮点数且r∈[0,1]。
在本发明实施例中,首先,通过采用代价敏感学习算法训练得到初始数据分类模型,使得初始数据分类模型能够对类别不平衡的数据进行分类,提高了模型分类的精准度;其次,对初始训练样本组中的训练样本进行数据增强处理生成增强样本,使得模型训练过程中有足够数量的样本,提高了模型的性能;最后,采用知识蒸馏算法利用初始数据分类模型指导第二分类模型使用增强样本进行模型训练,得到最终的数据分类模型,使得最终的数据分类模型不仅分类的精准度高并且模型性能也得到了提升,通过这样的设置,采用知识蒸馏算法将代价敏感学习算法和数据增强方法进行有效的结合,同时提升了模型分类的精准度和性能。
需要指出的是,尽管上述实施例中将各个步骤按照特定的先后顺序进行了描述,但是本领域技术人员可以理解,为了实现本发明的效果,不同的步骤之间并非必须按照这样的顺序执行,其可以同时(并行)执行或以其他顺序执行,这些变化都在本发明的保护范围之内。
进一步,本发明还提供了一种模型训练装置。
参阅附图2,图2是根据本发明的一个实施例的模型训练装置的主要结构框图。如图2所示,本发明实施例中的模型训练装置主要包括代价敏感学习模块11、数据增强模块12和知识蒸馏模块13。在一些实施例中,代价敏感学习模块11、数据增强模块12和知识蒸馏模块13中的一个或多个可以合并在一起成为一个模块。在一些实施例中,代价敏感学习模块11可以被配置成采用代价敏感学习算法并且根据初始训练样本组对第一分类模型进行模型训练,得到初始数据分类模型,其中,初始训练样本组包括多个训练样本以及每个训练样本各自对应的样本类别标签,并且一部分样本类别标签对应的训练样本的数量远大于另一部分样本类型标签对应的训练样本的数量。数据增强模块12可以被配置成对初始训练样本组中的训练样本进行数据增强处理,以生成增强样本。知识蒸馏模块13可以被配置成采用知识蒸馏算法,使初始数据分类模型指导第二分类模型使用增强样本进行模型训练,得到最终的数据分类模型,其中,第一分类模型与第二分类模型的模型结构相同。一个实施方式中,具体实现功能的描述可以参见步骤S101-S103所述。
在一个实施方式中,代价敏感学习模块11还被配置成执行以下操作:按照公式(1)所示的代价敏感学习函数对第一分类模型进行模型训练。一个实施方式中,具体实现功能的描述可以参见步骤S101所述。
在一个实施方式中,数据增强模块12还被配置成执行以下操作:采用混合样本数据增强算法对初始训练样本组中的训练样本进行数据增强处理。一个实施方式中,具体实现功能的描述可以参见步骤S102所述。
在一个实施方式中,知识蒸馏模块13还被配置成执行以下操作:将增强样本同时输入至初始数据分类模型以及第二分类模型;采用知识蒸馏算法并且按照公式(2)所示的知识蒸馏函数对第二分类模型进行模型训练。一个实施方式中,具体实现功能的描述可以参见步骤S103所述。
在一个实施方式中,知识蒸馏模块13还被配置成执行以下操作:每个增强样本分别由初始训练样本组中的任意两个训练样本各自对应的一部分样本数据组成;第二分类模型的损失函数la如公式(3)所示;并且/或者,知识蒸馏损失函数lb如公式(4)所示。一个实施方式中,具体实现功能的描述可以参见步骤S103所述。
上述模型训练装置以用于执行图1所示的模型训练方法实施例,两者的技术原理、所解决的技术问题及产生的技术效果相似,本技术领域技术人员可以清楚地了解到,为了描述的方便和简洁,模型训练装置的具体工作过程及有关说明,可以参考模型训练方法的实施例所描述的内容,此处不再赘述。
本领域技术人员能够理解的是,本发明实现上述一实施例的方法中的全部或部分流程,也可以通过计算机程序来指令相关的硬件来完成,所述的计算机程序可存储于一计算机可读存储介质中,该计算机程序在被处理器执行时,可实现上述各个方法实施例的步骤。其中,所述计算机程序包括计算机程序代码,所述计算机程序代码可以为源代码形式、对象代码形式、可执行文件或某些中间形式等。所述计算机可读介质可以包括:能够携带所述计算机程序代码的任何实体或装置、介质、U盘、移动硬盘、磁碟、光盘、计算机存储器、只读存储器、随机存取存储器、电载波信号、电信信号以及软件分发介质等。需要说明的是,所述计算机可读介质包含的内容可以根据司法管辖区内立法和专利实践的要求进行适当的增减,例如在某些司法管辖区,根据立法和专利实践,计算机可读介质不包括电载波信号和电信信号。
进一步,本发明还提供了一种模型训练装置。在根据本发明的一个模型训练装置实施例中,模型训练装置包括处理器和存储装置,存储装置可以被配置成存储执行上述方法实施例的模型训练方法的程序,处理器可以被配置成用于执行存储装置中的程序,该程序包括但不限于执行上述方法实施例的模型训练方法的程序。为了便于说明,仅示出了与本发明实施例相关的部分,具体技术细节未揭示的,请参照本发明实施例方法部分。该控制装置可以是包括各种电子设备形成的控制装置设备。
进一步,本发明还提供了一种计算机可读存储介质。在根据本发明的一个计算机可读存储介质实施例中,计算机可读存储介质可以被配置成存储执行上述方法实施例的模型训练方法的程序,该程序可以由处理器加载并运行以实现上述模型训练方法。为了便于说明,仅示出了与本发明实施例相关的部分,具体技术细节未揭示的,请参照本发明实施例方法部分。该计算机可读存储介质可以是包括各种电子设备形成的存储装置设备,可选的,本发明实施例中存储是非暂时性的计算机可读存储介质。
进一步,应该理解的是,由于各个模块的设定仅仅是为了说明本发明的系统的功能单元,这些模块对应的物理器件可以是处理器本身,或者处理器中软件的一部分,硬件的一部分,或者软件和硬件结合的一部分。因此,图中的各个模块的数量仅仅是示意性的。
本领域技术人员能够理解的是,可以对系统中的各个模块进行适应性地拆分或合并。对具体模块的这种拆分或合并并不会导致技术方案偏离本发明的原理,因此,拆分或合并之后的技术方案都将落入本发明的保护范围内。
至此,已经结合附图所示的一个实施方式描述了本发明的技术方案,但是,本领域技术人员容易理解的是,本发明的保护范围显然不局限于这些具体实施方式。在不偏离本发明的原理的前提下,本领域技术人员可以对相关技术特征作出等同的更改或替换,这些更改或替换之后的技术方案都将落入本发明的保护范围之内。

Claims (12)

1.一种模型训练方法,应用于图像分类,其特征在于,所述模型训练方法包括:
采用代价敏感学习算法,用初始训练样本组对第一分类模型进行模型训练,得到初始数据分类模型;其中,所述初始训练样本组中的训练样本是图像样本;
对所述图像样本进行数据增强处理,以生成增强样本;
采用知识蒸馏算法,使所述初始数据分类模型指导第二分类模型使用所述增强样本进行模型训练,得到最终的数据分类模型,以使用所述最终的数据分类模型对图像进行图像分类;
其中,所述第一分类模型与所述第二分类模型的模型结构相同;所述初始训练样本组中一部分类别训练样本的数量远小于其他类别训练样本的数量。
2.根据权利要求1所述的模型训练方法,其特征在于,“采用代价敏感学习算法,用初始训练样本组对第一分类模型进行模型训练,得到初始数据分类模型”的步骤具体包括:
采用代价敏感学习算法按照下列公式所示的代价敏感学习函数对所述第一分类模型进行模型训练:
Figure FDA0003071443780000011
其中,所述L1表示所述代价敏感学习函数,所述N表示所述初始训练样本组中图像样本的个数;所述li表示所述初始训练样本组中第i个图像样本的训练误差,i=1,2,3,...,N;
Figure FDA0003071443780000012
所述m表示所述初始训练样本组中图像样本的样本类别的总数;所述Wj表示第j个样本类别的权重且
Figure FDA0003071443780000013
j=1,2,3,...,m;所述nj表示第j个样本类别的图像样本的个数;所述pij表示第i个图像样本被分类为第j个样本类别的预测概率;所述qij表示第i个图像样本被标记为第j个样本类别的标签值。
3.根据权利要求1所述的模型训练方法,其特征在于,“采用知识蒸馏算法并且利用所述初始数据分类模型与所述增强样本对第二分类模型进行模型训练,得到最终的数据分类模型,以使用所述最终的数据分类模型对图像进行图像分类”的步骤具体包括:
将所述增强样本同时输入至所述初始数据分类模型以及所述第二分类模型;
采用知识蒸馏算法并且按照下列公式所示的知识蒸馏函数对所述第二分类模型进行模型训练:
Figure FDA0003071443780000021
其中,所述L2表示所述知识蒸馏函数,所述la表示所述第二分类模型在对所述增强样本进行训练时确定的损失函数,所述lb表示利用所述初始数据分类模型对所述第二分类模型使用所述增强样本进行训练指导学习时确定的知识蒸馏损失函数。
4.根据权利要求3所述的模型训练方法,其特征在于,每个所述增强样本分别由所述初始训练样本组中的任意两个图像样本各自对应的一部分样本数据组成;
所述第二分类模型的损失函数la如下式所示:
Figure FDA0003071443780000022
其中,所述r表示浮点数且r∈[0,1];所述cuj表示与增强样本相关的一个图像样本被标记为第j个样本类别的标签值,所述cvj表示与当前增强样本相关的另一个图像样本被标记为第j个样本类别的标签值,所述sj表示增强样本被分类为第j个样本类别的预测概率;
并且/或者,
所述知识蒸馏损失函数lb如下式所示:
Figure FDA0003071443780000023
其中,所述T表示超参数,T为[2,5]之间的整数;所述fj表示利用所述初始数据分类模型获取到的所述增强样本被分类为第j个样本类别的预测概率,所述hj表示利用所述第二分类模型获取到的所述增强样本被分类为第j个样本类别的预测概率;
Figure FDA0003071443780000031
所述zj表示所述初始数据分类模型的特征提取模块输出的所述增强样本对应的第j个样本类别的样本特征向量;
Figure FDA0003071443780000032
所述kj表示所述第二分类模型的特征提取模块输出的所述增强样本对应的第j个样本类别的样本特征向量。
5.根据权利要求1至4中任一项所述的模型训练方法,其特征在于,“对所述图像样本进行数据增强处理”的步骤具体包括:
采用混合样本数据增强算法对所述图像样本进行数据增强处理。
6.一种模型训练装置,应用于图像分类,其特征在于,所述训练装置包括:
代价敏感学习模块,其被配置成采用代价敏感学习算法,用初始训练样本组对第一分类模型进行模型训练,得到初始数据分类模型;其中,所述初始训练样本组中的训练样本是图像样本;
数据增强模块,其被配置成对所述图像样本进行数据增强处理,以生成增强样本;
知识蒸馏模块,其被配置成采用知识蒸馏算法,使所述初始数据分类模型指导第二分类模型使用所述增强样本进行模型训练,得到最终的数据分类模型,以使用所述最终的数据分类模型对图像进行图像分类;
其中,
所述第一分类模型与所述第二分类模型的模型结构相同;所述初始训练样本组中一部分类别训练样本的数量远小于其他类别训练样本的数量。
7.根据权利要求6所述的训练装置,其特征在于,所述代价敏感学习模块还被配置成执行以下操作:
采用代价敏感学习算法按照下列公式所示的代价敏感学习函数对所述第一分类模型进行模型训练:
Figure FDA0003071443780000041
其中,所述L1表示所述代价敏感学习函数,所述N表示所述初始训练样本组中图像样本的个数;所述li表示所述初始训练样本组中第i个图像样本的训练误差,i=1,2,3,...,N;
Figure FDA0003071443780000042
所述m表示所述初始训练样本组中图像样本样本类别的总数;所述Wj表示第j个样本类别的权重且
Figure FDA0003071443780000043
j=1,2,3,...,m;所述nj表示第j个样本类别的图像样本的个数;所述pij表示第i个图像样本被分类为第j个样本类别的预测概率;所述qij表示第i个图像样本被标记为第j个样本类别的标签值。
8.根据权利要求6所述的训练装置,其特征在于,所述知识蒸馏模块还被配置成执行以下操作:
将所述增强样本同时输入至所述初始数据分类模型以及所述第二分类模型;
采用知识蒸馏算法并且按照下列公式所示的知识蒸馏函数对所述第二分类模型进行模型训练:
Figure FDA0003071443780000044
其中,所述L2表示所述知识蒸馏函数,所述la表示所述第二分类模型在对所述增强样本进行训练时确定的损失函数,所述lb表示利用所述初始数据分类模型对所述第二分类模型使用所述增强样本进行训练指导学习时确定的知识蒸馏损失函数。
9.根据权利要求8所述的训练装置,其特征在于,所述知识蒸馏模块还被配置成执行以下操作:
每个所述增强样本分别由所述初始训练样本组中的任意两个图像样本各自对应的一部分样本数据组成;
所述第二分类模型的损失函数la如下式所示:
Figure FDA0003071443780000051
其中,所述r表示浮点数且r∈[0,1];所述cuj表示与增强样本相关的一个图像样本被标记为第j个样本类别的标签值,所述cvj表示与当前增强样本相关的另一个图像样本被标记为第j个样本类别的标签值,所述sj表示增强样本被分类为第j个样本类别的预测概率;
并且/或者,
所述知识蒸馏损失函数lb如下式所示:
Figure FDA0003071443780000052
其中,所述T表示超参数,T为[2,5]之间的整数;所述fj表示利用所述初始数据分类模型获取到的所述增强样本被分类为第j个样本类别的预测概率,所述hj表示利用所述第二分类模型获取到的所述增强样本被分类为第j个样本类别的预测概率;
Figure FDA0003071443780000053
所述zj表示所述初始数据分类模型的特征提取模块输出的所述增强样本对应的第j个样本类别的样本特征向量;
Figure FDA0003071443780000054
所述kj表示所述第二分类模型的特征提取模块输出的所述增强样本对应的第j个样本类别的样本特征向量。
10.根据权利要求6至9中任一项所述的训练装置,其特征在于,所述数据增强模块还被配置成执行以下操作:
采用混合样本数据增强算法对所述图像样本进行数据增强处理。
11.一种模型训练装置,包括处理器和存储装置,所述存储装置适于存储多条程序代码,其特征在于,所述程序代码适于由所述处理器加载并运行以执行权利要求1至5中任一项所述的模型训练方法。
12.一种计算机可读存储介质,其中存储有多条程序代码,其特征在于,所述程序代码适于由处理器加载并运行以执行权利要求1至5中任一项所述的模型训练方法。
CN202011338954.5A 2020-11-25 2020-11-25 模型训练方法、装置以及计算机可读存储介质 Active CN112329885B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202011338954.5A CN112329885B (zh) 2020-11-25 2020-11-25 模型训练方法、装置以及计算机可读存储介质

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202011338954.5A CN112329885B (zh) 2020-11-25 2020-11-25 模型训练方法、装置以及计算机可读存储介质

Publications (2)

Publication Number Publication Date
CN112329885A CN112329885A (zh) 2021-02-05
CN112329885B true CN112329885B (zh) 2021-07-09

Family

ID=74309694

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202011338954.5A Active CN112329885B (zh) 2020-11-25 2020-11-25 模型训练方法、装置以及计算机可读存储介质

Country Status (1)

Country Link
CN (1) CN112329885B (zh)

Families Citing this family (6)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113326768B (zh) * 2021-05-28 2023-12-22 浙江商汤科技开发有限公司 训练方法、图像特征提取方法、图像识别方法及装置
CN115544029A (zh) * 2021-06-29 2022-12-30 华为技术有限公司 一种数据处理方法及相关装置
CN113642605A (zh) * 2021-07-09 2021-11-12 北京百度网讯科技有限公司 模型蒸馏方法、装置、电子设备及存储介质
CN117616428A (zh) * 2021-11-30 2024-02-27 英特尔公司 用于在资源受约束的图像识别应用中执行并行双批自蒸馏的方法和装置
CN114202673B (zh) * 2021-12-13 2024-10-18 深圳壹账通智能科技有限公司 证件分类模型的训练方法、证件分类方法、装置和介质
CN114595785B (zh) * 2022-03-29 2022-11-04 小米汽车科技有限公司 模型训练方法、装置、电子设备及存储介质

Citations (7)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN109034219A (zh) * 2018-07-12 2018-12-18 上海商汤智能科技有限公司 图像的多标签类别预测方法及装置、电子设备和存储介质
CN110223281A (zh) * 2019-06-06 2019-09-10 东北大学 一种数据集中含有不确定数据时的肺结节图像分类方法
WO2020111574A1 (en) * 2018-11-30 2020-06-04 Samsung Electronics Co., Ltd. System and method for incremental learning
CN111242297A (zh) * 2019-12-19 2020-06-05 北京迈格威科技有限公司 基于知识蒸馏的模型训练方法、图像处理方法及装置
CN111444760A (zh) * 2020-02-19 2020-07-24 天津大学 一种基于剪枝与知识蒸馏的交通标志检测与识别方法
CN111738303A (zh) * 2020-05-28 2020-10-02 华南理工大学 一种基于层次学习的长尾分布图像识别方法
CN111967534A (zh) * 2020-09-03 2020-11-20 福州大学 基于生成对抗网络知识蒸馏的增量学习方法

Family Cites Families (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
EP3736741A1 (en) * 2019-05-06 2020-11-11 Dassault Systèmes Experience learning in virtual world

Patent Citations (7)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN109034219A (zh) * 2018-07-12 2018-12-18 上海商汤智能科技有限公司 图像的多标签类别预测方法及装置、电子设备和存储介质
WO2020111574A1 (en) * 2018-11-30 2020-06-04 Samsung Electronics Co., Ltd. System and method for incremental learning
CN110223281A (zh) * 2019-06-06 2019-09-10 东北大学 一种数据集中含有不确定数据时的肺结节图像分类方法
CN111242297A (zh) * 2019-12-19 2020-06-05 北京迈格威科技有限公司 基于知识蒸馏的模型训练方法、图像处理方法及装置
CN111444760A (zh) * 2020-02-19 2020-07-24 天津大学 一种基于剪枝与知识蒸馏的交通标志检测与识别方法
CN111738303A (zh) * 2020-05-28 2020-10-02 华南理工大学 一种基于层次学习的长尾分布图像识别方法
CN111967534A (zh) * 2020-09-03 2020-11-20 福州大学 基于生成对抗网络知识蒸馏的增量学习方法

Non-Patent Citations (3)

* Cited by examiner, † Cited by third party
Title
Autoregressive Knowledge Distillation through Imitation Learning;Alexander Lin et al;《Conference on Empirical Methods in Natural Language Processing》;20200930;全文 *
Knowledge distill via neuron selectivity transfer;Huang Z et al;《arXiv》;20171231;全文 *
基于特征重建的知识蒸馏方法;郭俊伦等;《现代计算机》;20201031(第29期);全文 *

Also Published As

Publication number Publication date
CN112329885A (zh) 2021-02-05

Similar Documents

Publication Publication Date Title
CN112329885B (zh) 模型训练方法、装置以及计算机可读存储介质
CN111275107A (zh) 一种基于迁移学习的多标签场景图像分类方法及装置
Singh et al. Shunt connection: An intelligent skipping of contiguous blocks for optimizing MobileNet-V2
CN111507370A (zh) 获得自动标注图像中检查标签的样本图像的方法和装置
CN104933428B (zh) 一种基于张量描述的人脸识别方法及装置
CN109118504B (zh) 一种基于神经网络的图像边缘检测方法、装置及其设备
CN111489297A (zh) 用于检测危险要素的学习用图像数据集的生成方法和装置
CN111008626A (zh) 基于r-cnn检测客体的方法和装置
CN112884235B (zh) 出行推荐方法、出行推荐模型的训练方法、装置
CN112364828B (zh) 人脸识别方法及金融系统
CN116740364B (zh) 一种基于参考机制的图像语义分割方法
CN116883726B (zh) 基于多分支与改进的Dense2Net的高光谱图像分类方法及系统
CN112651324A (zh) 视频帧语义信息的提取方法、装置及计算机设备
CN112966754A (zh) 样本筛选方法、样本筛选装置及终端设备
CN115690752A (zh) 一种驾驶员行为检测方法及装置
CN114022727B (zh) 一种基于图像知识回顾的深度卷积神经网络自蒸馏方法
CN116861262B (zh) 一种感知模型训练方法、装置及电子设备和存储介质
CN117884379A (zh) 一种矿石分选方法及系统
CN112508684A (zh) 一种基于联合卷积神经网络的催收风险评级方法及系统
CN116563850A (zh) 多类别目标检测方法及其模型训练方法、装置
CN118411531A (zh) 一种神经网络的训练方法、图像处理的方法以及装置
CN116777814A (zh) 图像处理方法、装置、计算机设备、存储介质及程序产品
CN111709479B (zh) 一种图像分类方法和装置
CN113723431A (zh) 图像识别方法、装置以及计算机可读存储介质
Haas et al. Neural network compression through shunt connections and knowledge distillation for semantic segmentation problems

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant