CN112163637A - 基于非平衡数据的图像分类模型训练方法、装置 - Google Patents

基于非平衡数据的图像分类模型训练方法、装置 Download PDF

Info

Publication number
CN112163637A
CN112163637A CN202011118747.9A CN202011118747A CN112163637A CN 112163637 A CN112163637 A CN 112163637A CN 202011118747 A CN202011118747 A CN 202011118747A CN 112163637 A CN112163637 A CN 112163637A
Authority
CN
China
Prior art keywords
model
training
image
sub
trained
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Granted
Application number
CN202011118747.9A
Other languages
English (en)
Other versions
CN112163637B (zh
Inventor
谢雨洋
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Shenzhen Saiante Technology Service Co Ltd
Original Assignee
Ping An International Smart City Technology Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Ping An International Smart City Technology Co Ltd filed Critical Ping An International Smart City Technology Co Ltd
Priority to CN202011118747.9A priority Critical patent/CN112163637B/zh
Publication of CN112163637A publication Critical patent/CN112163637A/zh
Application granted granted Critical
Publication of CN112163637B publication Critical patent/CN112163637B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • G06F18/241Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N20/00Machine learning

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Data Mining & Analysis (AREA)
  • Evolutionary Computation (AREA)
  • Physics & Mathematics (AREA)
  • General Physics & Mathematics (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • General Engineering & Computer Science (AREA)
  • Artificial Intelligence (AREA)
  • Evolutionary Biology (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Software Systems (AREA)
  • Medical Informatics (AREA)
  • Computing Systems (AREA)
  • Mathematical Physics (AREA)
  • Image Analysis (AREA)

Abstract

本发明公开了基于非平衡数据的图像分类模型训练方法、装置。方法包括:将训练图像集平均拆分为多个子图像集,将子图像集输入待训练模型得到模型输出信息,根据损失函数及模型输出信息计算得到子图像集中每一训练图像的损失值,损失值根据模型输出信息的置信度系数进行自适应调节得到,根据梯度计算公式及每一训练图像的损失值对待训练模型进行迭代训练,直至所有子训练集均完成对待训练模型的训练。本发明基于人工智能技术,属于机器学习领域,训练图像的差异也会被相应放大并体现在损失值上,可大幅改善训练图像在各样本上的分布不平衡的应用场景中对模型进行训练的质量,可快速提升模型分类的准确度,进而提高模型的训练质量及训练效率。

Description

基于非平衡数据的图像分类模型训练方法、装置
技术领域
本发明涉及人工智能技术领域,属于智慧城市中对图像分类模型进行训练的应用场景,尤其涉及一种基于非平衡数据的图像分类模型训练方法、装置。
背景技术
深度学习在进行视觉分析中应用广泛,为了提高神经网络模型进行图像分析的准确性,通常需要采用海量训练数据对神经网络进行训练。而在某一些应用场景中,训练数据中各样本对应的数据分布不平衡,例如医疗影像中有病变的图像相对于正常图像的数量少很多,因采用分布不平衡的训练数据对神经网络模型进行训练后,在使用过程中模型的分析效果较差,传统技术方法中还可对训练数据中各样的数据进行调整以使各样本对应的数据分布平衡,然而这一调整方法会导致重复采样而让模型过拟合,且训练时间大幅增加,导致对模型进行训练质量较差、训练效率不高。因此,现有技术方法在使用分布不均衡的训练数据对神经网络模型进行训练时,存在训练质量及训练效率不高的问题。
发明内容
本发明实施例提供了一种基于非平衡数据的图像分类模型训练方法、装置、计算机设备及存储介质,旨在解决现有技术方法在使用分布不均衡的训练数据对神经网络模型进行训练时所存在的训练质量及训练效率不高的问题。
第一方面,本发明实施例提供了一种基于非平衡数据的图像分类模型训练方法,其包括:
若接收到用户输入的训练图像集,将所述训练图像集平均拆分为预设数量的子图像集;
将一个所述子图像集输入待训练模型以获取所述子图像集的模型输出信息,其中,所述子图像集包含非平衡的样本数据;
根据所述模型输出信息及预置的损失函数计算得到所述子图像集中每一训练图像的损失值,其中,所述损失值可根据所述模型输出信息的置信度系数进行自适应调节;
根据预存的梯度计算公式及每一所述训练图像的损失值对所述待训练模型中的参数值进行调整,以对所述待训练模型进行迭代训练得到训练后的模型;
判断是否存在下一子图像集;
若存在下一子图像集,将所述训练后的模型作为所述待训练模型并返回执行所述将一个所述子图像集输入待训练模型以获取所述子图像集的模型输出信息的步骤;
若不存在下一子图像集,将所述训练后的模型作为目标图像分类模型。
第二方面,本发明实施例提供了一种基于非平衡数据的图像分类模型训练装置,其包括:
训练图像集拆分单元,用于若接收到用户输入的训练图像集,将所述训练图像集平均拆分为预设数量的子图像集;
模型输出信息获取单元,用于将一个所述子图像集输入待训练模型以获取所述子图像集的模型输出信息,其中,所述子图像集包含非平衡的样本数据;
损失值计算单元,用于根据所述模型输出信息及预置的损失函数计算得到所述子图像集中每一训练图像的损失值,其中,所述损失值可根据所述模型输出信息的置信度系数进行自适应调节;
参数值调整单元,用于根据预存的梯度计算公式及每一所述训练图像的损失值对所述待训练模型中的参数值进行调整,以对所述待训练模型进行迭代训练得到训练后的模型;
判断单元,用于判断是否存在下一子图像集;
返回执行单元,用于若存在下一子图像集,将所述训练后的模型作为所述待训练模型并返回执行所述将一个所述子图像集输入待训练模型以获取所述子图像集的模型输出信息的步骤;
目标模型获取单元,用于若不存在下一子图像集,将所述训练后的模型作为目标图像分类模型。
第三方面,本发明实施例又提供了一种计算机设备,其包括存储器、处理器及存储在所述存储器上并可在所述处理器上运行的计算机程序,所述处理器执行所述计算机程序时实现上述第一方面所述的基于非平衡数据的图像分类模型训练方法。
第四方面,本发明实施例还提供了一种计算机可读存储介质,其中所述计算机可读存储介质存储有计算机程序,所述计算机程序当被处理器执行时使所述处理器执行上述第一方面所述的基于非平衡数据的图像分类模型训练方法。
本发明实施例提供了一种基于非平衡数据的图像分类模型训练方法、装置、计算机设备及存储介质。将训练图像集平均拆分为多个子图像集,将子图像集输入待训练模型得到模型输出信息,根据损失函数及模型输出信息计算得到子图像集中每一训练图像的损失值,损失值根据模型输出信息的置信度系数进行自适应调节得到,根据梯度计算公式及每一训练图像的损失值对待训练模型进行迭代训练,直至所有子训练集均完成对待训练模型的训练。通过上述方法,损失值可基于模型输出信息的置信度系数进行自适应调节,训练图像的差异也会被相应放大并体现于所得到的损失值上,可大幅改善训练图像在各样本上的分布不平衡的应用场景中对待训练模型进行训练的质量,通过少量训练图像对模型进行训练即可快速提升模型分类的准确度,进而提高模型的训练质量及训练效率。
附图说明
为了更清楚地说明本发明实施例技术方案,下面将对实施例描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图是本发明的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1为本发明实施例提供的基于非平衡数据的图像分类模型训练方法的流程示意图;
图2为本发明实施例提供的基于非平衡数据的图像分类模型训练方法的另一流程示意图;
图3为本发明实施例提供的基于非平衡数据的图像分类模型训练方法的子流程示意图;
图4为本发明实施例提供的基于非平衡数据的图像分类模型训练方法的另一子流程示意图;
图5为本发明实施例提供的基于非平衡数据的图像分类模型训练方法的另一子流程示意图;
图6为本发明实施例提供的基于非平衡数据的图像分类模型训练方法的另一子流程示意图;
图7为本发明实施例提供的基于非平衡数据的图像分类模型训练装置的示意性框图;
图8为本发明实施例提供的计算机设备的示意性框图。
具体实施方式
下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
应当理解,当在本说明书和所附权利要求书中使用时,术语“包括”和“包含”指示所描述特征、整体、步骤、操作、元素和/或组件的存在,但并不排除一个或多个其它特征、整体、步骤、操作、元素、组件和/或其集合的存在或添加。
还应当理解,在此本发明说明书中所使用的术语仅仅是出于描述特定实施例的目的而并不意在限制本发明。如在本发明说明书和所附权利要求书中所使用的那样,除非上下文清楚地指明其它情况,否则单数形式的“一”、“一个”及“该”意在包括复数形式。
还应当进一步理解,在本发明说明书和所附权利要求书中使用的术语“和/或”是指相关联列出的项中的一个或多个的任何组合以及所有可能组合,并且包括这些组合。
请参阅图1,图1是本发明实施例提供的基于非平衡数据的图像分类模型训练方法的流程示意图,该基于非平衡数据的图像分类模型训练方法应用于用户终端中,该方法通过安装于用户终端中的应用软件进行执行,用户终端即是用于执行基于非平衡数据的图像分类模型训练方法以完成对图像分类模型进行训练的终端设备,例如台式电脑、笔记本电脑、平板电脑或手机等。如图1所示,该方法包括步骤S110~S170。
S110、若接收到用户输入的训练图像集,将所述训练图像集平均拆分为预设数量的子图像集。
若接收到用户输入的训练图像集,将所述训练图像集平均拆分为预设数量的子图像集。其中,训练图像集中包含多张训练图像,训练图像集中所包含的训练图像在各样本中的分布不平衡,例如训练图像可以是医疗影像;预设数量即为对训练图像集中所包含的训练图像进行拆分的数量信息,根据预设数量即可将训练图像集中的训练图像平均拆分至对应的多个子图像集,每一子图像集中均包含多张训练图像,子图像集即可用于对待训练模型进行训练,子图像集中的每一张训练图像中还包含一个样本分类标签,样本分类标签即为对每一张训练图像的真实样本类型进行记录的标签信息。
例如,所输入的训练图像集中包含600张训练图像,预设数量为12,则将600张训练图像随机分配至12个子数据集,每一子数据集中包含50张训练图像。
在一实施例中,如图2所示,步骤S110之后还包括步骤S111。
S111、判断每一所述子图像集是否满足预设要求,以获取满足所述预设要求的子图像集。
在将子图像集输入待训练模型进行训练之前,还可对每一子图像集是否满足预设要求进行判断以获取满足所述预设要求的子图像集,其中,所述预设要求包括预设比例范围。由于对模型进行训练的方法更适用于各样本数据分布不平衡的应用场景中,为了提高对模型进行训练的效果,可对每一子图像集进行判断,以获取满足预设条件的子图像集对模型进行训练,具体的,可获取每一子图像集中正样本数量与负样本数量的比值,并判断比值是否属于预设比例范围内,以得到子图像集是否满足预设要求的判断结果。
在一实施例中,如图3所示,步骤S111还包括子步骤S1111、S1112、S1113和S1114。
S1111、根据每一训练图像的样本分类标签统计每一所述子图像集中正样本数量及负样本数量;S1112、获取每一所述子图像集中正样本数量与负样本数量的比值;S1113、判断每一所述子图像集的所述比值是否属于预设比值范围内;S1114、若所述子图像集的所述比值属于所述预设比值范围内,判定所述子图像集满足预设要求。若子图像集的比值不属于所述预设比值范围内,判定所述子图像集不满足预设要求
获取每一训练图像的样本分类标签,并基于样本分类标签统计得到每一子图像集中正样本数量及负样本数量。具体的,若训练图像对应两个样本类型,则将样本分类标签中与默认类型相匹配的一个样本类型作为正样本,另一样本类型作为负样本,并统计每一训练图像中正样本数量及负样本数量;若训练图像对应两个以上样本类型,则将与默认类型相匹配的一个样本类型作为正样本,其他样本类型作为负样本,统计每一训练图像中正样本数量及负样本数量。
例如,训练图像为医疗影像,默认类型为病变图像,训练图像对应病变图像及正常图像两个样本类型,则可对一个子图像集中的病变图像的数量进行统计以获取正样本数量,对该子图像集中的正常图像的数量进行统计以获取负样本数量。
根据子图像集的样本数量统计结果即可获取得到正样本数量与负样本数量的比值,并判断每一子图像集的比值是否输入预设比例范围内,例如,可设置预设比例范围为(1:10)-(1:100)。若子图像集的比值属于预设比例范围内,则判断该子图像集满足预设要求,否则判断该子图像集不满足预设要求。通过这一筛选过程,可获取所有满足预设要求的子图像集对待训练模型进行训练。
S120、将一个所述子图像集输入待训练模型以获取所述子图像集的模型输出信息,其中,所述子图像集包含非平衡的样本数据。
将一个所述子图像集输入待训练模型以获取所述子图像集的模型输出信息。具体的,待训练模型可以是基于神经网络所构建得到的图像分类模型,则待训练模型可由卷积层、全连接隐层及输出节点组成,全连接隐层中可包含一个或多个全连接层,每一全连接层中均包含多个全连接计算公式,每一全连接计算公式中均包含对应的参数,例如,全连接计算公式可表示为y1=a×x1+b;其中,a和b为该公式中的参数;全连接层与输出节点之间同样包含与全连接计算公式相类似的计算公式,输出节点可对应所需进行分类的样本类别,一个样本类别对应一个输出节点,输出节点的输出节点值即为训练图像与该输出节点的样本类型相匹配的匹配度。
具体的,子图像集中的每一张训练图像输入待训练模型,则通过待训练模型中的卷积层对训练图像进行卷积处理以提取得到图像特征信息,将图像特征信息输入全连接隐层进行计算,并经输出节点获取相应输出节点值,根据输出节点值获取训练图像与正样本的样本类型之间的置信度,获取子图像集中每一训练图像的置信度作为该子图像集的模型输出信息。其中,置信度的取值范围为[0,1],置信度越大,训练图像与正样本之间的相似度越高,也即是若某一训练图像的置信度为0,则表明该训练图像与正样本最不相似;若置信度为1,则表明该训练图像与正样本最相似。
例如,若待训练模型包含一个输出节点,该输出节点与正样本的样本类型相匹配,则该输出节点的输出节点值即为一张训练图像的置信度;若待训练模型包含两个输出节点或多个输出节点,以两个输出节点为例,两个输出节点分别与正样本的样本类型及负样本的样本类型相匹配,则可分别获取两个输出节点的输出节点值,并基于对两个输出节点值进行归一化计算,得到一张训练图像与正样本的输出节点对应的置信度。
S130、根据所述模型输出信息及预置的损失函数计算得到所述子图像集中每一训练图像的损失值。
根据所述模型输出信息及预置的损失函数计算得到所述子图像集中每一训练图像的损失值。可根据子图像集的模型输出信息及损失函数,计算得到子图像集中每一训练图像的损失值,基于每一训练图像的损失值可对待训练模型进行一次训练,则通过子图像集所包含的多张训练图像的损失值可对待训练模型进行迭代训练。
在一实施例中,如图4所示,步骤S130还包括子步骤S131、S132、S133和S134。
S131、根据所述模型输出信息获取所述子图像集的置信度系数。
首先基于子图像集的模型输出信息中所包含的置信值计算得到该子图像集的置信度系数。置信度系数可用于对子图像集中所包含训练图像的置信度进行评价,其中,置信度系数可以是置信度平均值、置信度平均方差或置信度标准差。
S132、判断每一所述训练图像的样本分类标签是否与正样本的样本类型相匹配。
样本分类标签为对训练图像的真实样本类型进行记录的标签信息,可对每一训练图像的样本分类标签是否与正样本的样本类型相匹配,也即是判断样本分类标签是否与正样本的样本类型相匹配。
在一实施例中,如图5所示,步骤S132之前还包括步骤S132a。
S132a、对每一所述训练图像的样本分类标签进行量化得到每一训练图像的标签量化值。
对每一所述训练图像的样本分类标签进行量化得到每一训练图像的标签量化值,可先获取每一训练图像的标签量化值,并基于标签量化值判断训练图像的样本分类标签是否与正样本的样本类型相匹配。
例如,对训练图像的样本分类标签进行量化,若某一训练图像的样本分类标签与默认类型相同,则该训练图像的标签量化值为1;若训练图像的样本分类标签与默认类型不相同,则该训练图像的标签量化值为0。
S133、若所述训练图像的样本分类标签与正样本的样本类型相匹配,根据所述损失函数中的第一计算公式及所述置信度系数计算得到所述训练图像的损失值;S134、若所述训练图像的样本分类标签与正样本的样本类型不相匹配,根据所述损失函数中的第二计算公式及所述置信度系数计算得到所述训练图像的损失值。
具体的,损失函数可采用公式(1)及公式(2)进行表示:
Figure BDA0002731267300000081
其中,y为某一训练图像的标签量化值,若y=1,则表示该训练图像的样本分类标签与正样本的样本类型相匹配,采用上述公式(1)计算训练图像的损失值;若y=0,则表示该训练图像的样本分类标签与正样本的样本类型不相匹配。其中,θ为损失函数中的权重值,θ=1+y0,y0为置信度系数,y0可以是置信度平均值、置信度平均方差或置信度标准差,y'为某一训练图像的置信度,采用上述公式(2)计算训练图像的损失值。通过计算子图像集中每一训练图像的置信度的平均值即可得到置信度平均值y0,y0还可以是子图像集中所包含训练图像的置信度平均方差,此时可通过公式
Figure BDA0002731267300000082
计算得到置信度平均方差;其中,n为子图像集中样本分类标签为正样本的训练图像的数量,m为子图像集中样本分类标签为负样本的训练图像的数量,y′i为第i个正样本的训练图像的置信度,y′r为第r个负样本的训练图像的置信度,
Figure BDA0002731267300000083
为n个正样本的训练图像的置信度平均值,
Figure BDA0002731267300000084
为m个负样本的训练图像的置信度平均值。y0还可以是子图像集中所包含训练图像的置信度标准差,则此时可通过公式
Figure BDA0002731267300000085
计算得到置信度平均方差,公式中各参数值与计算置信度平均方差的公式所包含参数值相同。
由于子图像集中的训练图像在各样本上的分布不平衡,基于上述损失函数计算得到的损失值,训练图像的置信度差异被对应放大并在损失函数的权重值θ上对被放大的差异进行体现,通过这一对损失值进行计算的方法,可基于子图像集中训练图像的置信度系数对训练图像的损失值之间的差值进行自适应调节,也即是当置信度系数越大,权重值θ也呈几何级放大,训练图像的损失值之间的差值也会被相应放大,从而可大幅改善训练图像在各样本上的分布不平衡的应用场景中对待训练模型进行训练的质量,通过少量训练图像对模型进行训练即可快速提升模型分类的准确度,进而提高模型的训练质量及训练效率。
例如,
Figure BDA0002731267300000091
θ=1.6,若某一训练图像的标签量化值y=1,y'=0.8,此时计算得到是损失值L1=0.0073;若某一训练图像的标签量化值y=1,y'=0.4,此时计算得到是损失值L2=0.1757,则损失值相较于初始置信度被明显放大。
S140、根据预存的梯度计算公式及每一所述训练图像的损失值对所述待训练模型中的参数值进行调整,以对所述待训练模型进行迭代训练得到训练后的模型。
根据预存的梯度计算公式及每一所述训练图像的损失值对所述待训练模型中的参数值进行调整以对所述待训练模型进行迭代训练。具体的,梯度计算公式即为基于梯度下降规则所构建的计算公式,基于梯度计算公式及一张训练图像的损失值即可计算得到待训练模型中每一参数的更新值,对参数的参数值进行一次更新调整,通过子图像集中多张训练图像的损失值可对待训练模型中的参数值对应进行多次调整,也即是实现对待训练模型进行迭代训练。
在一实施例中,如图6所示,步骤S140包括子步骤S141和S142。
S141、根据所述梯度计算公式及每一所述训练图像的损失值计算得到所述待训练模型中每一参数的更新值;S142、根据每一所述参数的更新值对每一所述参数的参数值进行更新调整,以对所述待训练模型进行一次训练。
具体的,将待训练模型中一个参数对训练图像进行计算所得到的计算值输入梯度计算公式,并结合该训练图像的损失值,即可计算得到与该参数对应的更新值,这一计算过程也即为梯度下降计算。
具体的,梯度计算公式可表示为:
Figure BDA0002731267300000092
其中,
Figure BDA0002731267300000093
为计算得到的参数x的更新值,ωx为参数x的原始参数值,η为梯度计算公式中预置的学习率,
Figure BDA0002731267300000094
为基于损失值及参数x对应的计算值对该参数x的偏导值(这一计算过程中需使用参数对应的计算值)。
基于所计算得到的每一参数的更新值对待训练模型中相应参数的参数值进行更新,即可完成对待训练模型的一次训练过程。基于子图像集中另一训练图像的损失值对经过一次训练后的待训练模型再次进行参数调整,并重复上述对参数值进行更新的过程,即可实现对待训练模型进行迭代训练。
当子图像集中的每一张训练图像均被用于训练后,即可终止通过当前子图像及对待训练模型进行训练的过程,并进行下一步操作。
S150、判断是否存在下一子图像集。
具体的,在使用一个子图像集对待训练模型进行训练后,可判断是否还存在下一子图像集;更进一步的,还可判断是否存在满足预设要求的下一子图像集。通过一个子图像集对待训练模型进行一次训练后,所得到的分类准确率会大幅提升,在此基础上通过多个子图像集对待训练模型进行重复训练,可基于不同训练阶段的模型计算得到子图像集的置信度系数,通过不同置信度系数逐步调整训练图像的损失值之间的差值,从而进一步提高对模型进行训练的效率。
S160、若存在下一子图像集,将所述训练后的模型作为所述待训练模型并返回执行所述将一个所述子图像集输入待训练模型以获取所述子图像集的模型输出信息的步骤。
若存在下一子图像集,将所述训练后的模型作为所述待训练模型并返回执行所述将一个所述子图像集输入待训练模型以获取所述子图像集的模型输出信息的步骤;也即是通过下一子图像集对训练后的模型再次进行迭代训练,并重复执行上述对待训练模型进行训练的步骤,也即是返回执行步骤S120。
S170、若不存在下一子图像集,将所述训练后的模型作为目标图像分类模型。
若不存在下一子图像集,将当前训练得到的模型作为目标图像分类模型进行输出,用户可使用所得到的目标图像分类模型对后续的待处理图像进行分类识别处理。
本申请中的技术方法可应用于智慧政务/智慧城管/智慧社区/智慧安防/智慧物流/智慧医疗/智慧教育/智慧环保/智慧交通等包含对图像分类模型进行训练的应用场景中,从而推动智慧城市的建设。
在本发明实施例所提供的基于非平衡数据的图像分类模型训练方法中,将训练图像集平均拆分为多个子图像集,将子图像集输入待训练模型得到模型输出信息,根据损失函数及模型输出信息计算得到子图像集中每一训练图像的损失值,损失值根据模型输出信息的置信度系数进行自适应调节得到,根据梯度计算公式及每一训练图像的损失值对待训练模型进行迭代训练,直至所有子训练集均完成对待训练模型的训练。通过上述方法,损失值可基于模型输出信息的置信度系数进行自适应调节,训练图像的差异也会被相应放大并体现于所得到的损失值上,可大幅改善训练图像在各样本上的分布不平衡的应用场景中对待训练模型进行训练的质量,通过少量训练图像对模型进行训练即可快速提升模型分类的准确度,进而提高模型的训练质量及训练效率。
本发明实施例还提供一种基于非平衡数据的图像分类模型训练装置,该基于非平衡数据的图像分类模型训练装置用于执行前述基于非平衡数据的图像分类模型训练方法的任一实施例。具体地,请参阅图7,图7是本发明实施例提供的基于非平衡数据的图像分类模型训练装置的示意性框图。该基于非平衡数据的图像分类模型训练装置可以配置于用户终端中。
如图7所示,基于非平衡数据的图像分类模型训练装置100包括训练图像集拆分单元110、模型输出信息获取单元120、损失值计算单元130、参数值调整单元140、判断单元150、返回执行单元160和目标模型获取单元170。
训练图像集拆分单元110,用于若接收到用户输入的训练图像集,将所述训练图像集平均拆分为预设数量的子图像集。
在一实施例中,所述基于非平衡数据的图像分类模型训练装置100还包括子单元:预设要求判断单元。
预设要求判断单元,用于判断每一所述子图像集是否满足预设要求,以获取满足所述预设要求的子图像集。
在一实施例中,所述预设要求判断单元包括子单元:样本统计单元、样本比值获取单元、比值判断单元及判断结果获取单元。
样本统计单元,用于根据每一训练图像的样本分类标签统计每一所述子图像集中正样本数量及负样本数量;样本比值获取单元,用于获取每一所述子图像集中正样本数量与负样本数量的比值;比值判断单元,用于判断每一所述子图像集的所述比值是否属于预设比值范围内;判断结果获取单元,用于若所述子图像集的所述比值属于所述预设比值范围内,判定所述子图像集满足预设要求。
模型输出信息获取单元120,用于将一个所述子图像集输入待训练模型以获取所述子图像集的模型输出信息,其中,所述子图像集包含非平衡的样本数据。
损失值计算单元130,用于根据所述模型输出信息及预置的损失函数计算得到所述子图像集中每一训练图像的损失值,其中,所述损失值可根据所述模型输出信息的置信度系数进行自适应调节。
在一实施例中,所述损失值计算单元130包括子单元:置信度系数获取单元、样本分类标签判断单元、第一损失值计算单元及第二损失值计算单元。
置信度系数获取单元,用于根据所述模型输出信息获取所述子图像集的置信度系数;样本分类标签判断单元,用于判断每一所述训练图像的样本分类标签是否与正样本的样本类型相匹配;第一损失值计算单元,用于若所述训练图像的样本分类标签与正样本的样本类型相匹配,根据所述损失函数中的第一计算公式及所述置信度系数计算得到所述训练图像的损失值;第二损失值计算单元,用于若所述训练图像的样本分类标签与正样本的样本类型不相匹配,根据所述损失函数中的第二计算公式及所述置信度系数计算得到所述训练图像的损失值。
在一实施例中,所述损失值计算单元130还包括子单元:标签量化值获取单元。
标签量化值获取单元,用于对每一所述训练图像的样本分类标签进行量化得到每一训练图像的标签量化值。
参数值调整单元140,用于根据预存的梯度计算公式及每一所述训练图像的损失值对所述待训练模型中的参数值进行调整,以对所述待训练模型进行迭代训练得到训练后的模型。
在一实施例中,所述参数值调整单元140包括子单元:更新值计算单元及参数值更新单元。
更新值计算单元,用于根据所述梯度计算公式及每一所述训练图像的损失值计算得到所述待训练模型中每一参数的更新值;参数值更新单元,用于根据每一所述参数的更新值对每一所述参数的参数值进行更新调整,以对所述待训练模型进行一次训练
判断单元150,用于判断是否存在下一子图像集。
返回执行单元160,用于若存在下一子图像集,将所述训练后的模型作为所述待训练模型并返回执行所述将一个所述子图像集输入待训练模型以获取所述子图像集的模型输出信息的步骤。
目标模型获取单元170,用于若不存在下一子图像集,将所述训练后的模型作为目标图像分类模型。
在本发明实施例所提供的基于非平衡数据的图像分类模型训练装置应用上述基于非平衡数据的图像分类模型训练方法,将训练图像集平均拆分为多个子图像集,将子图像集输入待训练模型得到模型输出信息,根据损失函数及模型输出信息计算得到子图像集中每一训练图像的损失值,损失值根据模型输出信息的置信度系数进行自适应调节得到,根据梯度计算公式及每一训练图像的损失值对待训练模型进行迭代训练,直至所有子训练集均完成对待训练模型的训练。通过上述方法,损失值可基于模型输出信息的置信度系数进行自适应调节,训练图像的差异也会被相应放大并体现于所得到的损失值上,可大幅改善训练图像在各样本上的分布不平衡的应用场景中对待训练模型进行训练的质量,通过少量训练图像对模型进行训练即可快速提升模型分类的准确度,进而提高模型的训练质量及训练效率。
上述基于非平衡数据的图像分类模型训练装置可以实现为计算机程序的形式,该计算机程序可以在如图8所示的计算机设备上运行。
请参阅图8,图8是本发明实施例提供的计算机设备的示意性框图。该计算机设备可以是用于执行基于非平衡数据的图像分类模型训练方法以对图像分类模型进行训练的用户终端。
参阅图8,该计算机设备500包括通过系统总线501连接的处理器502、存储器和网络接口505,其中,存储器可以包括非易失性存储介质503和内存储器504。
该非易失性存储介质503可存储操作系统5031和计算机程序5032。该计算机程序5032被执行时,可使得处理器502执行基于非平衡数据的图像分类模型训练方法。
该处理器502用于提供计算和控制能力,支撑整个计算机设备500的运行。
该内存储器504为非易失性存储介质503中的计算机程序5032的运行提供环境,该计算机程序5032被处理器502执行时,可使得处理器502执行基于非平衡数据的图像分类模型训练方法。
该网络接口505用于进行网络通信,如提供数据信息的传输等。本领域技术人员可以理解,图8中示出的结构,仅仅是与本发明方案相关的部分结构的框图,并不构成对本发明方案所应用于其上的计算机设备500的限定,具体的计算机设备500可以包括比图中所示更多或更少的部件,或者组合某些部件,或者具有不同的部件布置。
其中,所述处理器502用于运行存储在存储器中的计算机程序5032,以实现上述的基于非平衡数据的图像分类模型训练方法中对应的功能。
本领域技术人员可以理解,图8中示出的计算机设备的实施例并不构成对计算机设备具体构成的限定,在其他实施例中,计算机设备可以包括比图示更多或更少的部件,或者组合某些部件,或者不同的部件布置。例如,在一些实施例中,计算机设备可以仅包括存储器及处理器,在这样的实施例中,存储器及处理器的结构及功能与图8所示实施例一致,在此不再赘述。
应当理解,在本发明实施例中,处理器502可以是中央处理单元(CentralProcessing Unit,CPU),该处理器502还可以是其他通用处理器、数字信号处理器(DigitalSignal Processor,DSP)、专用集成电路(Application Specific Integrated Circuit,ASIC)、现成可编程门阵列(Field-Programmable Gate Array,FPGA)或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件等。其中,通用处理器可以是微处理器或者该处理器也可以是任何常规的处理器等。
在本发明的另一实施例中提供计算机可读存储介质。该计算机可读存储介质可以为非易失性的计算机可读存储介质。该计算机可读存储介质存储有计算机程序,其中计算机程序被处理器执行时实现上述的基于非平衡数据的图像分类模型训练方法中所包含的步骤。
所属领域的技术人员可以清楚地了解到,为了描述的方便和简洁,上述描述的设备、装置和单元的具体工作过程,可以参考前述方法实施例中的对应过程,在此不再赘述。本领域普通技术人员可以意识到,结合本文中所公开的实施例描述的各示例的单元及算法步骤,能够以电子硬件、计算机软件或者二者的结合来实现,为了清楚地说明硬件和软件的可互换性,在上述说明中已经按照功能一般性地描述了各示例的组成及步骤。这些功能究竟以硬件还是软件方式来执行取决于技术方案的特定应用和设计约束条件。专业技术人员可以对每个特定的应用来使用不同方法来实现所描述的功能,但是这种实现不应认为超出本发明的范围。
在本发明所提供的几个实施例中,应该理解到,所揭露的设备、装置和方法,可以通过其它的方式实现。例如,以上所描述的装置实施例仅仅是示意性的,例如,所述单元的划分,仅仅为逻辑功能划分,实际实现时可以有另外的划分方式,也可以将具有相同功能的单元集合成一个单元,例如多个单元或组件可以结合或者可以集成到另一个系统,或一些特征可以忽略,或不执行。另外,所显示或讨论的相互之间的耦合或直接耦合或通信连接可以是通过一些接口、装置或单元的间接耦合或通信连接,也可以是电的,机械的或其它的形式连接。
所述作为分离部件说明的单元可以是或者也可以不是物理上分开的,作为单元显示的部件可以是或者也可以不是物理单元,即可以位于一个地方,或者也可以分布到多个网络单元上。可以根据实际的需要选择其中的部分或者全部单元来实现本发明实施例方案的目的。
另外,在本发明各个实施例中的各功能单元可以集成在一个处理单元中,也可以是各个单元单独物理存在,也可以是两个或两个以上单元集成在一个单元中。上述集成的单元既可以采用硬件的形式实现,也可以采用软件功能单元的形式实现。
所述集成的单元如果以软件功能单元的形式实现并作为独立的产品销售或使用时,可以存储在一个计算机可读存储介质中。基于这样的理解,本发明的技术方案本质上或者说对现有技术做出贡献的部分,或者该技术方案的全部或部分可以以软件产品的形式体现出来,该计算机软件产品存储在一个计算机可读存储介质中,包括若干指令用以使得一台计算机设备(可以是个人计算机,服务器,或者网络设备等)执行本发明各个实施例所述方法的全部或部分步骤。而前述的计算机可读存储介质包括:U盘、移动硬盘、只读存储器(ROM,Read-Only Memory)、磁碟或者光盘等各种可以存储程序代码的介质。
以上所述,仅为本发明的具体实施方式,但本发明的保护范围并不局限于此,任何熟悉本技术领域的技术人员在本发明揭露的技术范围内,可轻易想到各种等效的修改或替换,这些修改或替换都应涵盖在本发明的保护范围之内。因此,本发明的保护范围应以权利要求的保护范围为准。

Claims (10)

1.一种基于非平衡数据的图像分类模型训练方法,应用于用户终端中,其特征在于,所述方法包括:
若接收到用户输入的训练图像集,将所述训练图像集平均拆分为预设数量的子图像集;
将一个所述子图像集输入待训练模型以获取所述子图像集的模型输出信息,其中,所述子图像集包含非平衡的样本数据;
根据所述模型输出信息及预置的损失函数计算得到所述子图像集中每一训练图像的损失值,其中,所述损失值可根据所述模型输出信息的置信度系数进行自适应调节;
根据预存的梯度计算公式及每一所述训练图像的损失值对所述待训练模型中的参数值进行调整,以对所述待训练模型进行迭代训练得到训练后的模型;
判断是否存在下一子图像集;
若存在下一子图像集,将所述训练后的模型作为所述待训练模型并返回执行所述将一个所述子图像集输入待训练模型以获取所述子图像集的模型输出信息的步骤;
若不存在下一子图像集,将所述训练后的模型作为目标图像分类模型。
2.根据权利要求1所述的基于非平衡数据的图像分类模型训练方法,其特征在于,所述将所述训练图像集平均拆分为预设数量的子图像集之后,还包括:
判断每一所述子图像集是否满足预设要求,以获取满足所述预设要求的子图像集。
3.根据权利要求2所述的基于非平衡数据的图像分类模型训练方法,其特征在于,所述预设要求包括预设比例范围,所述判断每一所述子图像集是否满足预设要求,包括:
根据每一训练图像的样本分类标签统计每一所述子图像集中正样本数量及负样本数量;
获取每一所述子图像集中正样本数量与负样本数量的比值;
判断每一所述子图像集的所述比值是否属于预设比值范围内;
若所述子图像集的所述比值属于所述预设比值范围内,判定所述子图像集满足预设要求。
4.根据权利要求1所述的基于非平衡数据的图像分类模型训练方法,其特征在于,所述根据所述模型输出信息及预置的损失函数计算得到所述子图像集中每一训练图像的损失值,包括:
根据所述模型输出信息获取所述子图像集的置信度系数;
判断每一所述训练图像的样本分类标签是否与正样本的样本类型相匹配;
若所述训练图像的样本分类标签与正样本的样本类型相匹配,根据所述损失函数中的第一计算公式及所述置信度系数计算得到所述训练图像的损失值;
若所述训练图像的样本分类标签与正样本的样本类型不相匹配,根据所述损失函数中的第二计算公式及所述置信度系数计算得到所述训练图像的损失值。
5.根据权利要求4所述的基于非平衡数据的图像分类模型训练方法,其特征在于,所述判断每一所述训练图像的样本分类标签是否与正样本的样本类型相匹配之前,还包括:
对每一所述训练图像的样本分类标签进行量化得到每一训练图像的标签量化值。
6.根据权利要求1所述的基于非平衡数据的图像分类模型训练方法,其特征在于,所述根据预存的梯度计算公式及每一所述训练图像的损失值对所述待训练模型中的参数值进行调整,包括:
根据所述梯度计算公式及每一所述训练图像的损失值计算得到所述待训练模型中每一参数的更新值;
根据每一所述参数的更新值对每一所述参数的参数值进行更新调整,以对所述待训练模型进行一次训练。
7.根据权利要求4所述的基于非平衡数据的图像分类模型训练方法,其特征在于,所述置信度系数为置信度平均值、置信度平均方差或置信度标准差。
8.一种基于非平衡数据的图像分类模型训练装置,其特征在于,包括:
训练图像集拆分单元,用于若接收到用户输入的训练图像集,将所述训练图像集平均拆分为预设数量的子图像集;
模型输出信息获取单元,用于将一个所述子图像集输入待训练模型以获取所述子图像集的模型输出信息,其中,所述子图像集包含非平衡的样本数据;
损失值计算单元,用于根据所述模型输出信息及预置的损失函数计算得到所述子图像集中每一训练图像的损失值,其中,所述损失值可根据所述模型输出信息的置信度系数进行自适应调节;
参数值调整单元,用于根据预存的梯度计算公式及每一所述训练图像的损失值对所述待训练模型中的参数值进行调整,以对所述待训练模型进行迭代训练得到训练后的模型;
判断单元,用于判断是否存在下一子图像集;
返回执行单元,用于若存在下一子图像集,将所述训练后的模型作为所述待训练模型并返回执行所述将一个所述子图像集输入待训练模型以获取所述子图像集的模型输出信息的步骤;
目标模型获取单元,用于若不存在下一子图像集,将所述训练后的模型作为目标图像分类模型。
9.一种计算机设备,包括存储器、处理器及存储在所述存储器上并可在所述处理器上运行的计算机程序,其特征在于,所述处理器执行所述计算机程序时实现如权利要求1至7中任一项所述的基于非平衡数据的图像分类模型训练方法。
10.一种计算机可读存储介质,其特征在于,所述计算机可读存储介质存储有计算机程序,所述计算机程序当被处理器执行时使所述处理器执行如权利要求1至7任一项所述的基于非平衡数据的图像分类模型训练方法。
CN202011118747.9A 2020-10-19 2020-10-19 基于非平衡数据的图像分类模型训练方法、装置 Active CN112163637B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202011118747.9A CN112163637B (zh) 2020-10-19 2020-10-19 基于非平衡数据的图像分类模型训练方法、装置

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202011118747.9A CN112163637B (zh) 2020-10-19 2020-10-19 基于非平衡数据的图像分类模型训练方法、装置

Publications (2)

Publication Number Publication Date
CN112163637A true CN112163637A (zh) 2021-01-01
CN112163637B CN112163637B (zh) 2024-04-19

Family

ID=73867467

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202011118747.9A Active CN112163637B (zh) 2020-10-19 2020-10-19 基于非平衡数据的图像分类模型训练方法、装置

Country Status (1)

Country Link
CN (1) CN112163637B (zh)

Cited By (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113052246A (zh) * 2021-03-30 2021-06-29 北京百度网讯科技有限公司 用于训练分类模型及图像分类的方法和相关装置
CN113066069A (zh) * 2021-03-31 2021-07-02 深圳中科飞测科技股份有限公司 调整方法及装置、调整设备和存储介质
WO2023035586A1 (zh) * 2021-09-10 2023-03-16 上海商汤智能科技有限公司 图像检测方法、模型训练方法、装置、设备、介质及程序
CN117132174A (zh) * 2023-10-26 2023-11-28 扬宇光电(深圳)有限公司 一种应用于工业流水线质量检测的模型训练方法与系统

Citations (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
JP6359716B1 (ja) * 2017-03-30 2018-07-18 インテル コーポレイション 分散型コンピューティングにおける低速タスクの診断
CN109784496A (zh) * 2018-12-29 2019-05-21 厦门大学 一种面向不平衡数据集的分类方法
CN109815332A (zh) * 2019-01-07 2019-05-28 平安科技(深圳)有限公司 损失函数优化方法、装置、计算机设备及存储介质
CN111079841A (zh) * 2019-12-17 2020-04-28 深圳奇迹智慧网络有限公司 目标识别的训练方法、装置、计算机设备和存储介质
CN111680740A (zh) * 2020-06-04 2020-09-18 京东方科技集团股份有限公司 神经网络的训练方法、装置及用电负荷的判别方法、装置

Patent Citations (6)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
JP6359716B1 (ja) * 2017-03-30 2018-07-18 インテル コーポレイション 分散型コンピューティングにおける低速タスクの診断
CN109784496A (zh) * 2018-12-29 2019-05-21 厦门大学 一种面向不平衡数据集的分类方法
CN109815332A (zh) * 2019-01-07 2019-05-28 平安科技(深圳)有限公司 损失函数优化方法、装置、计算机设备及存储介质
WO2020143304A1 (zh) * 2019-01-07 2020-07-16 平安科技(深圳)有限公司 损失函数优化方法、装置、计算机设备及存储介质
CN111079841A (zh) * 2019-12-17 2020-04-28 深圳奇迹智慧网络有限公司 目标识别的训练方法、装置、计算机设备和存储介质
CN111680740A (zh) * 2020-06-04 2020-09-18 京东方科技集团股份有限公司 神经网络的训练方法、装置及用电负荷的判别方法、装置

Non-Patent Citations (3)

* Cited by examiner, † Cited by third party
Title
ANDREA DAL POZZOLO 等: "Calibrating Probability with Undersampling for Unbalanced Classification", 《2015 IEEE SYMPOSIUM SERIES ON COMPUTATIONAL INTELLIGENCE》, pages 159 - 166 *
吴艺凡 等: "基于混合采样的非平衡数据分类算法", 《计算机科学与探索》, pages 342 - 349 *
陆悠 等: "一种基于选择性协同学习的网络用户异常行为检测方法", 《计算机学报》, vol. 37, no. 1, pages 28 - 40 *

Cited By (6)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113052246A (zh) * 2021-03-30 2021-06-29 北京百度网讯科技有限公司 用于训练分类模型及图像分类的方法和相关装置
CN113052246B (zh) * 2021-03-30 2023-08-04 北京百度网讯科技有限公司 用于训练分类模型及图像分类的方法和相关装置
CN113066069A (zh) * 2021-03-31 2021-07-02 深圳中科飞测科技股份有限公司 调整方法及装置、调整设备和存储介质
WO2023035586A1 (zh) * 2021-09-10 2023-03-16 上海商汤智能科技有限公司 图像检测方法、模型训练方法、装置、设备、介质及程序
CN117132174A (zh) * 2023-10-26 2023-11-28 扬宇光电(深圳)有限公司 一种应用于工业流水线质量检测的模型训练方法与系统
CN117132174B (zh) * 2023-10-26 2024-01-30 扬宇光电(深圳)有限公司 一种应用于工业流水线质量检测的模型训练方法与系统

Also Published As

Publication number Publication date
CN112163637B (zh) 2024-04-19

Similar Documents

Publication Publication Date Title
CN112163637A (zh) 基于非平衡数据的图像分类模型训练方法、装置
CN111523621A (zh) 图像识别方法、装置、计算机设备和存储介质
CN112231584B (zh) 基于小样本迁移学习的数据推送方法、装置及计算机设备
CN110147710B (zh) 人脸特征的处理方法、装置和存储介质
CN109241318B (zh) 图片推荐方法、装置、计算机设备及存储介质
CN111814810A (zh) 图像识别方法、装置、电子设备及存储介质
CN112232476A (zh) 更新测试样本集的方法及装置
CN111724370B (zh) 基于不确定性和概率的多任务图像质量评估方法及系统
CN112348079A (zh) 数据降维处理方法、装置、计算机设备及存储介质
CN112183212A (zh) 一种杂草识别方法、装置、终端设备及可读存储介质
CN112990016B (zh) 表情特征提取方法、装置、计算机设备及存储介质
CN110489659A (zh) 数据匹配方法和装置
CN113095333A (zh) 无监督特征点检测方法及装置
CN112199582A (zh) 一种内容推荐方法、装置、设备及介质
CN112329586A (zh) 基于情绪识别的客户回访方法、装置及计算机设备
CN110222734B (zh) 贝叶斯网络学习方法、智能设备及存储装置
CN115223013A (zh) 基于小数据生成网络的模型训练方法、装置、设备及介质
CN111814804B (zh) 基于ga-bp-mc神经网络的人体三维尺寸信息预测方法及装置
CN111078891A (zh) 一种基于粒子群算法的知识图谱优化方法及装置
CN115661618A (zh) 图像质量评估模型的训练方法、图像质量评估方法及装置
CN115546554A (zh) 敏感图像的识别方法、装置、设备和计算机可读存储介质
CN115619729A (zh) 人脸图像质量评估方法、装置及电子设备
CN107203916B (zh) 一种用户信用模型建立方法及装置
CN112766362A (zh) 数据处理方法、装置和设备
CN110874567B (zh) 颜值判定方法、装置、电子设备及存储介质

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
TA01 Transfer of patent application right
TA01 Transfer of patent application right

Effective date of registration: 20210203

Address after: 518000 Room 201, building A, No. 1, Qian Wan Road, Qianhai Shenzhen Hong Kong cooperation zone, Shenzhen, Guangdong (Shenzhen Qianhai business secretary Co., Ltd.)

Applicant after: Shenzhen saiante Technology Service Co.,Ltd.

Address before: 1-34 / F, Qianhai free trade building, 3048 Xinghai Avenue, Mawan, Qianhai Shenzhen Hong Kong cooperation zone, Shenzhen, Guangdong 518000

Applicant before: Ping An International Smart City Technology Co.,Ltd.

SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant