CN117994611B - 一种图像分类模型的训练方法、装置及电子设备 - Google Patents
一种图像分类模型的训练方法、装置及电子设备 Download PDFInfo
- Publication number
- CN117994611B CN117994611B CN202410400212.2A CN202410400212A CN117994611B CN 117994611 B CN117994611 B CN 117994611B CN 202410400212 A CN202410400212 A CN 202410400212A CN 117994611 B CN117994611 B CN 117994611B
- Authority
- CN
- China
- Prior art keywords
- image
- sample
- mixed
- sample images
- loss value
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
- 238000012549 training Methods 0.000 title claims abstract description 151
- 238000000034 method Methods 0.000 title claims abstract description 79
- 238000013145 classification model Methods 0.000 title claims abstract description 65
- 238000002156 mixing Methods 0.000 claims abstract description 61
- 230000006870 function Effects 0.000 claims description 24
- 238000004891 communication Methods 0.000 claims description 18
- 238000000605 extraction Methods 0.000 claims description 16
- 238000009826 distribution Methods 0.000 claims description 14
- 238000005457 optimization Methods 0.000 claims description 14
- 238000005070 sampling Methods 0.000 claims description 12
- 238000004590 computer program Methods 0.000 claims description 10
- 230000008447 perception Effects 0.000 claims description 6
- 230000008569 process Effects 0.000 description 15
- 238000003062 neural network model Methods 0.000 description 10
- 230000009471 action Effects 0.000 description 3
- 238000013528 artificial neural network Methods 0.000 description 3
- 238000013527 convolutional neural network Methods 0.000 description 3
- 238000010586 diagram Methods 0.000 description 3
- 238000005516 engineering process Methods 0.000 description 3
- 238000012545 processing Methods 0.000 description 3
- 208000037170 Delayed Emergence from Anesthesia Diseases 0.000 description 2
- 230000006872 improvement Effects 0.000 description 2
- 230000002093 peripheral effect Effects 0.000 description 2
- 238000011176 pooling Methods 0.000 description 2
- 239000007787 solid Substances 0.000 description 2
- 230000000007 visual effect Effects 0.000 description 2
- 238000012935 Averaging Methods 0.000 description 1
- 230000002159 abnormal effect Effects 0.000 description 1
- 238000004458 analytical method Methods 0.000 description 1
- 238000013459 approach Methods 0.000 description 1
- 230000009286 beneficial effect Effects 0.000 description 1
- 239000003086 colorant Substances 0.000 description 1
- 238000013500 data storage Methods 0.000 description 1
- 238000013135 deep learning Methods 0.000 description 1
- 238000011161 development Methods 0.000 description 1
- 230000000694 effects Effects 0.000 description 1
- 238000000802 evaporation-induced self-assembly Methods 0.000 description 1
- 238000010191 image analysis Methods 0.000 description 1
- 230000010354 integration Effects 0.000 description 1
- 238000002372 labelling Methods 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 238000012544 monitoring process Methods 0.000 description 1
- 230000003287 optical effect Effects 0.000 description 1
- 239000013307 optical fiber Substances 0.000 description 1
- 239000004065 semiconductor Substances 0.000 description 1
- 238000012795 verification Methods 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/77—Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
- G06V10/774—Generating sets of training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/0464—Convolutional networks [CNN, ConvNet]
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/088—Non-supervised learning, e.g. competitive learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/764—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- General Health & Medical Sciences (AREA)
- Health & Medical Sciences (AREA)
- Artificial Intelligence (AREA)
- Software Systems (AREA)
- General Physics & Mathematics (AREA)
- Computing Systems (AREA)
- Evolutionary Computation (AREA)
- Biomedical Technology (AREA)
- Molecular Biology (AREA)
- Data Mining & Analysis (AREA)
- General Engineering & Computer Science (AREA)
- Computational Linguistics (AREA)
- Mathematical Physics (AREA)
- Biophysics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Databases & Information Systems (AREA)
- Medical Informatics (AREA)
- Multimedia (AREA)
- Image Analysis (AREA)
Abstract
本申请实施例提供了一种图像分类模型的训练方法、装置及电子设备,涉及计算机视觉技术领域,本申请实施例包括:针对每两张样本图像,对该两张样本图像按照指定比例混合,得到混合图像,并对该两张样本图像的训练标签按照指定比例混合,得到混合图像的混合标签。再将样本图像和混合图像分别输入图像分类网络,之后基于图像分类网络输出的样本图像所属的类别和训练标签,确定样本损失值,并基于图像分类网络输出的混合图像所属的类别和混合标签,确定混合损失值。再基于样本损失值和混合损失值,调整图像分类网络的网络参数,直至图像分类网络收敛时,将当前的图像分类网络作为图像分类模型。能够提高图像分类的准确度。
Description
技术领域
本申请涉及计算机视觉技术领域,特别是涉及一种图像分类模型的训练方法、装置及电子设备。
背景技术
图像分类技术的应用范围十分广泛,例如,可应用于人脸识别、自动驾驶、智能家居和医学影像分析等领域。使用神经网络模型能够预测图像属于每种预设类别的概率,实现了对图像的分类,该方式能提高图像分类的速度和准确度,因此神经网络的快速发展进一步推动了图像分类技术在各个领域中落地。
在实际落地场景中,需要使用高可信的神经网络模型。即实际应用中期望神经网络模型预测的更高的概率对应的分类结果更可能是正确的;且更低的概率对应的分类结果更可能是不准确的,意味着神经网络对此次预测结果不太确认。
为了保证神经网络模型对图像进行分类的准确率,并降低误报率,目前在神经网络的训练过程中,通常将样本图像输入神经网络模型,得到神经网络模型输出的样本图像属于每种预设类别的概率,之后基于神经网络模型的输出结果与样本图像的训练标签计算损失值,然后利用损失值调整神经网络模型的网络参数。但该方式对于提高神经网络模型的预测准确度的效果有限,即训练后的得到神经网络模型对图像分类的准确度不够高,存在较多的误报。
发明内容
本申请实施例的目的在于提供一种图像分类模型的训练方法、装置及电子设备,以提高图像分类的准确度。具体技术方案如下:
本申请实施例的第一方面,提供了一种图像分类模型的训练方法,所述方法包括:
获取多张样本图像以及每张样本图像的训练标签,每张样本图像的训练标签表示该样本图像实际所属的类别;
针对每两张样本图像,对该两张样本图像按照指定比例混合,得到混合图像,并对该两张样本图像的训练标签按照所述指定比例混合,得到所述混合图像的混合标签;
将所述样本图像和所述混合图像分别输入图像分类网络,得到所述图像分类网络输出的所述样本图像所属的类别和所述混合图像所属的类别;
基于所述图像分类网络输出的所述样本图像所属的类别和所述样本图像的训练标签,确定样本损失值;
基于所述图像分类网络输出的所述混合图像所属的类别和所述混合图像的混合标签,确定混合损失值;
基于所述样本损失值和所述混合损失值,调整所述图像分类网络的网络参数,并返回所述获取多张样本图像以及每张样本图像的训练标签的步骤,直至所述图像分类网络收敛时,将当前的图像分类网络作为图像分类模型。
可选的,每张原始样本的尺寸均相同;所述对该两张样本图像按照指定比例混合,得到混合图像,包括:
在0到1范围内采样,得到混合权重;
计算该两张样本图像中的第一样本图像的像素值与所述混合权重的第一乘积;
计算该两张样本图像中的第二样本图像的像素值与指定权重的第二乘积;其中,所述指定权重为1与所述混合权重的差值;
计算所述第一乘积与所述第二乘积的和值,作为所述混合图像。
可选的,所述对该两张样本图像的训练标签按照所述指定比例混合,得到所述混合图像的混合标签,包括:
计算所述第一样本图像的训练标签与所述混合权重的第三乘积;
计算所述第二样本图像的训练标签与所述指定权重的第四乘积;
计算所述第三乘积与所述第四乘积的和值,作为所述混合图像的混合标签。
可选的,多次采样获得的混合权重满足贝塔分布。
可选的,所述获取多张样本图像以及每张样本图像的训练标签,包括:
确定上一次从样本图像集包括的多组样本图像中选择的一组样本图像;
若上一次选择的一组样本图像不为最后一组样本图像,则获取上一次选择的样本图像组的下一组样本图像以及所述下一组样本图像的训练标签;
若上一次选择的一组样本图像为最后一组样本图像,则获取第一组样本图像以及所述第一组样本图像的训练标签。
可选的,所述图像分类网络输出的所述样本图像所属的类别包括所述样本图像属于每种预设类别的概率;所述基于所述样本损失值和所述混合损失值,调整所述图像分类网络的网络参数,包括:
获取每张样本图像的正确分类次数以及目标预测概率;其中,所述正确分类次数为:所述图像分类网络输出的该样本图像属于每种预设类别的概率中,最大概率对应的类别与该样本图像的训练标签表示的目标类别相同的次数;所述目标预测概率为:所述图像分类网络输出的该样本图像属于目标类别的概率;
针对每两张样本图像,根据该两张样本图像的正确分类次数之间的差值,以及该两张样本图像的目标预测概率之间的差值,确定该两张样本图像之间的一致性偏差;
计算每两张样本图像之间的一致性偏差的和值,作为一致性损失值;
根据所述样本损失值、所述混合损失值和所述一致性损失值,确定总损失值;
利用所述总损失值,调整所述图像分类网络的网络参数。
可选的,每两张样本图像之间的一致性偏差为:
;
其中,表示第i张样本图像和第j张样本图像之间的一致性偏差,表
示第i张样本图像,表示第j张样本图像,表示第i张样本图像的正确分类次数,表示第
j张样本图像的正确分类次数,max和sign分别表示函数运算,表示第i张样本图像的目标
预测概率,表示第j张样本图像的目标预测概率。
可选的,所述总损失值为:
;
其中,为所述总损失值,为所述样本损失值,和为预设的超参
数,为所述混合损失值,为所述一致性损失值。
可选的,所述利用所述总损失值,调整所述图像分类网络的网络参数,包括:
基于锐度感知最小化优化算法,将所述总损失值的最小值对应的所述图像分类网络的网络参数,作为候选网络参数;
若当前迭代次数未达到指定次数,则将所述图像分类网络的网络参数修改为本次计算的候选网络参数;
若当前迭代次数达到指定次数,则计算最近的预设次数针对最后一组样本图像确定的候选网络参数的平均值,将所述图像分类网络的网络参数修改为所述平均值。
可选的,所述图像分类网络包括特征提取层和余弦分类器,所述特征提取层用于对输入的图像进行特征提取得到图像特征,所述余弦分类器用于基于所述图像特征与每种预设类别的权重之间的余弦相似度,确定输入的图像属于每种预设类别的概率。
本申请实施例的第二方面,提供了一种图像分类模型的训练装置,所述装置包括:
获取模块,用于获取多张样本图像以及每张样本图像的训练标签,每张样本图像的训练标签表示该样本图像实际所属的类别;
增强模块,用于针对每两张样本图像,对该两张样本图像按照指定比例混合,得到混合图像,并对该两张样本图像的训练标签按照所述指定比例混合,得到所述混合图像的混合标签;
分类模块,用于将所述样本图像和所述混合图像分别输入图像分类网络,得到所述图像分类网络输出的所述样本图像所属的类别和所述混合图像所属的类别;
确定模块,用于基于所述图像分类网络输出的所述样本图像所属的类别和所述样本图像的训练标签,确定样本损失值;
所述确定模块,还用于基于所述图像分类网络输出的所述混合图像所属的类别和所述混合图像的混合标签,确定混合损失值;
调整模块,用于基于所述样本损失值和所述混合损失值,调整所述图像分类网络的网络参数,并调用所述获取模块执行所述获取多张样本图像以及每张样本图像的训练标签的步骤,直至所述图像分类网络收敛时,将当前的图像分类网络作为图像分类模型。
可选的,每张原始样本的尺寸均相同;所述增强模块,具体用于:
在0到1范围内采样,得到混合权重;
计算该两张样本图像中的第一样本图像的像素值与所述混合权重的第一乘积;
计算该两张样本图像中的第二样本图像的像素值与指定权重的第二乘积;其中,所述指定权重为1与所述混合权重的差值;
计算所述第一乘积与所述第二乘积的和值,作为所述混合图像。
可选的,所述增强模块,具体用于:
计算所述第一样本图像的训练标签与所述混合权重的第三乘积;
计算所述第二样本图像的训练标签与所述指定权重的第四乘积;
计算所述第三乘积与所述第四乘积的和值,作为所述混合图像的混合标签。
可选的,多次采样获得的混合权重满足贝塔分布。
可选的,所述获取模块,具体用于:
确定上一次从样本图像集包括的多组样本图像中选择的一组样本图像;
若上一次选择的一组样本图像不为最后一组样本图像,则获取上一次选择的样本图像组的下一组样本图像以及所述下一组样本图像的训练标签;
若上一次选择的一组样本图像为最后一组样本图像,则获取第一组样本图像以及所述第一组样本图像的训练标签。
可选的,所述图像分类网络输出的所述样本图像所属的类别包括所述样本图像属于每种预设类别的概率;所述调整模块,具体用于:
获取每张样本图像的正确分类次数以及目标预测概率;其中,所述正确分类次数为:所述图像分类网络输出的该样本图像属于每种预设类别的概率中,最大概率对应的类别与该样本图像的训练标签表示的目标类别相同的次数;所述目标预测概率为:所述图像分类网络输出的该样本图像属于目标类别的概率;
针对每两张样本图像,根据该两张样本图像的正确分类次数之间的差值,以及该两张样本图像的目标预测概率之间的差值,确定该两张样本图像之间的一致性偏差;
计算每两张样本图像之间的一致性偏差的和值,作为一致性损失值;
根据所述样本损失值、所述混合损失值和所述一致性损失值,确定总损失值;
利用所述总损失值,调整所述图像分类网络的网络参数。
可选的,每两张样本图像之间的一致性偏差为:
;
其中,表示第i张样本图像和第j张样本图像之间的一致性偏差,表
示第i张样本图像,表示第j张样本图像,表示第i张样本图像的正确分类次数,表示第
j张样本图像的正确分类次数,max和sign分别表示函数运算,表示第i张样本图像的目标
预测概率,表示第j张样本图像的目标预测概率。
可选的,所述总损失值为:
;
其中,为所述总损失值,为所述样本损失值,和为预设的超参
数,为所述混合损失值,为所述一致性损失值。
可选的,所述调整模块,具体用于:
基于锐度感知最小化优化算法,将所述总损失值的最小值对应的所述图像分类网络的网络参数,作为候选网络参数;
若当前迭代次数未达到指定次数,则将所述图像分类网络的网络参数修改为本次计算的候选网络参数;
若当前迭代次数达到指定次数,则计算最近的预设次数针对最后一组样本图像确定的候选网络参数的平均值,将所述图像分类网络的网络参数修改为所述平均值。
可选的,所述图像分类网络包括特征提取层和余弦分类器,所述特征提取层用于对输入的图像进行特征提取得到图像特征,所述余弦分类器用于基于所述图像特征与每种预设类别的权重之间的余弦相似度,确定输入的图像属于每种预设类别的概率。
本申请实施例的第三方面,提供了一种电子设备,包括处理器、通信接口、存储器和通信总线,其中,处理器,通信接口,存储器通过通信总线完成相互间的通信;
存储器,用于存放计算机程序;
处理器,用于执行存储器上所存放的程序时,实现第一方面任一项所述的方法。
本申请实施例的第四方面,提供了一种计算机可读存储介质,所述计算机可读存储介质内存储有计算机程序,所述计算机程序被处理器执行时实现第一方面任一项所述的方法。
本申请实施例的第五方面,提供了一种包含指令的计算机程序产品,当其在计算机上运行时,使得计算机执行上述第一方面任一项所述的图像分类模型的训练方法。
本申请实施例有益效果:
本申请实施例提供的图像分类模型的训练方法、装置及电子设备,可以分别对样本图像以及样本图像的训练标签进行混合,得到混合图像以及混合图像的混合标签,之后利用样本图像和混合图像共同训练图像分类网络。由于本申请实施例中训练图像分类网络时使用的图像更多,且混合图像包含的内容更丰富,因此提高了图像分类网络的泛化性,即提高了图像分类网络的准确率,并降低了误报率。
当然,实施本申请的任一产品或方法并不一定需要同时达到以上所述的所有优点。
附图说明
为了更清楚地说明本申请实施例或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本申请的一些实施例,对于本领域普通技术人员来讲,还可以根据这些附图获得其他的实施例。
图1为本申请实施例提供的一种图像分类模型的训练方法的流程图;
图2为本申请实施例提供的一种Resnet18的网络结构示意图;
图3为本申请实施例提供的另一种图像分类模型的训练方法的流程图;
图4为本申请实施例提供的一种图像分类模型的训练装置的结构示意图;
图5为本申请实施例提供的一种电子设备的结构示意图。
具体实施方式
下面将结合本申请实施例中的附图,对本申请实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本申请一部分实施例,而不是全部的实施例。基于本申请中的实施例,本领域普通技术人员基于本申请所获得的所有其他实施例,都属于本申请保护的范围。
为了提高图像分类的准确度,降低误报率,本申请实施例提供了一种图像分类模型的训练方法,该方法应用于电子设备,例如电子设备可以为:服务器、台式计算机、笔记本电脑、平板电脑或者手机等具备图像处理能力的设备。如图1所示,本申请实施例提供的图像分类模型的训练方法包括如下步骤:
S101、获取多张样本图像以及每张样本图像的训练标签。
每张样本图像均包含至少一个预设类别的对象。预设类别可以根据实际应用场景设置。例如,在自动驾驶场景下,预设类别包括:人、路障和车等。
每张样本图像的训练标签表示该样本图像实际所属的类别,即,该样本图像包括的对象实际所属的类别。其中,样本图像的训练标签可以通过人工标注获得。
可选的,可以从公开的图像集中获取多张样本图像以及每张样本图像的训练标签。或者,可以拍摄得到多张样本图像,并对每张样本图像进行人工标注,得到每张样本图像的训练标签。或者还可以通过其他方式获取多张样本图像以及每张样本图像的训练标签,本申请实施例对此不作具体限定。
S102、针对每两张样本图像,对该两张样本图像按照指定比例混合,得到混合图像,并对该两张样本图像的训练标签按照指定比例混合,得到混合图像的混合标签。
具体的混合方式可参考下文描述。
S103、将样本图像和混合图像分别输入图像分类网络,得到图像分类网络输出的样本图像所属的类别和混合图像所属的类别。
即,将每张样本图像分别输入图像分类网络,得到图像分类网络输出的每张样本图像所属的类别。以及将每张混合图像分别输入图像分类网络,得到图像分类网络输出的每张混合图像所属的类别。
图像分类网络可以为:卷积神经网络(Convolutional Neural Network,CNN)、视觉几何小组(Visual Geometry Group,VGG)、残差网络(Residual Network,Resnet)或者转换器(transformer)等具备图像分类能力的深度学习网络。
以图像分类网络基于Resnet18构建为例,Resnet18的网络结构如图2所示,图像分类网络包括:一个卷积层、8个残差块、一个平均池化(average pooling,Avg pool)层、一个全连接(Full Connection,FC)层和一个归一化指数函数(softmax)层。其中,每个残差块包括两个串联的卷积层。图2中每个卷积层内“3×3conv”表示卷积层的卷积核大小为3×3,“3×3conv”之后的数字表示卷积层输出的特征维度,“/”后的数字表示卷积步长。图2中输入(input)表示输入图像。
S104、基于图像分类网络输出的样本图像所属的类别和样本图像的训练标签,确定样本损失值。
可以将图像分类网络输出的样本图像所属的类别以及样本图像的训练标签代入预设的损失函数,从而计算出样本损失值。例如,预设的损失函数可以为:交叉熵损失函数、L1损失函数或者L2损失函数等分类损失函数。其中,L1损失函数,也称为平均绝对误差(Mean Absolute Error,MAE)损失函数;L2损失函数,也称为均方误差(Mean SquaredError,MSE)损失函数。
S105、基于图像分类网络输出的混合图像所属的类别和混合图像的混合标签,确定混合损失值。
可以将图像分类网络输出的混合图像所属的类别以及混合图像的混合标签代入预设的损失函数,从而计算出混合损失值。例如,预设的损失函数可以为:交叉熵损失函数、L1损失函数或者L2损失函数等分类损失函数。
需说明的是,S104和S105可以并行执行,也可以先后执行,本申请实施例对S104和S105的执行顺序不作具体限定。
S106、基于样本损失值和混合损失值,调整图像分类网络的网络参数,并返回S101,直至图像分类网络收敛时,将当前的图像分类网络作为图像分类模型。
本申请实施例提供的图像分类模型的训练方法,可以分别对样本图像以及样本图像的训练标签进行混合,得到混合图像以及混合图像的混合标签,之后利用样本图像和混合图像共同训练图像分类网络。由于本申请实施例中训练图像分类网络时使用的图像更多,且混合图像包含的内容更丰富,因此提高了图像分类网络的泛化性,即提高了图像分类网络的准确率,并降低了误报率。
以下对图1所示的图像分类模型的训练方法包括的步骤进行详细说明:
上述S102是对样本图像进行数据增强的过程。对两张样本图像进行混合,可以理解为对两张样本图像中相同位置的像素点按照比例叠加。
参见图3,S102中对该两张样本图像按照指定比例混合,得到混合图像的方式,包括以下步骤:
S1021、在0到1范围内采样,得到混合权重。
其中,采样范围可以包括0或不包括0,也可以包括1或不包括1。可选的,可以在0到1范围内随机采样,或者等间隔采样,获得混合权重。例如,对S101获取的每两张样本进行混合,假设共需要混合10次,则每次获取的混合权重分别为:0.1、0.2、0.3、0.4、0.5、0.6、0.7、0.8、0.9、1。
或者,可以按照贝塔分布曲线,在0到1范围内随机采样,获得混合权重。使得多次采样获得的混合权重满足贝塔分布,其中,贝塔分布,也称β分布,指一组定义在(0,1)区间的连续概率分布。
即,,本申请实施例设置贝塔分布的两个参数均为β,
从而使得混合权重满足的概率分布曲线是对称的,且对称轴对应0.5。
通过设置多次采样得到的混合权重满足贝塔分布,可以使得混合权重的设定更灵活,且对0到1范围内的覆盖更广,避免随机采样带来的不稳定性。而且,本申请实施例中混合权重满足的贝塔分布是对称的,降低每次混合样本图像i和j时,总是依赖样本图像i或j,保障了混合的平均性,提高了混合图像的多样性。
S1022、计算该两张样本图像中的第一样本图像的像素值与混合权重的第一乘积。
可以分别计算第一样本图像中每个像素点的像素值与混合权重的第一乘积。
S1023、计算该两张样本图像中的第二样本图像的像素值与指定权重的第二乘积。其中,指定权重为1与混合权重的差值。第一样本图像和第二样本图像中的“第一”和“第二”仅用于区分这两张样本图像。
可以分别计算第二样本图像中每个像素点的像素值与指定权重的第二乘积。
S1024、计算第一乘积与第二乘积的和值,作为混合图像。
即,混合图像为:
(1);
其中,为第n张混合图像,为混合权重,为第i张样本图像,为第j张样本
图像。i和j的取值范围均为[1,M],M为S101获取的样本图像数量。
本申请实施例中,每张样本图像的尺寸均相同,因此可以针对每两张样本图像中每个位置的像素点,通过公式(1)计算该像素点的像素值的混合像素值,从而得到混合图像。
通过上述方法,本申请实施例可以对每两张样本图像中相同位置的像素点的像素值,按指定比例混合,得到混合图像。使得混合图像包含两张样本图像叠加后的色彩和纹理,因此混合图像更加复杂且多样,丰富了图像分类网络的训练样本集,提高了图像分类网络的泛化性。
参见图3,上述S102对该两张样本图像的训练标签按照指定比例混合,得到混合图像的混合标签的方式,包括如下步骤:
S1025、计算第一样本图像的训练标签与混合权重的第三乘积。
训练标签表示样本图像属于多种预设类别中的其中一种预设类别。
样本图像的训练标签为该样本图像所属的类别的独热码(one-hot code)。其中,各类别的one-hot code长度相同;每种类别的one-hot code中只有一位为1,其余位置均为0,且每种预设类别的one-hot code包括的1的位置各不相同。
例如,各预设类别包括:人、车和路障。包含人的样本图像的训练标签为(1,0,0),包含车的样本图像的训练标签为(0,1,0),包含路障的样本图像的训练标签为(0,0,1)。
基于此,可以将第一样本图像的训练标签中的每一位分别与混合权重相乘,得到第三乘积。例如,训练标签为(1,0,0),混合权重为0.5,第三乘积为(0.3,0,0)。
S1026、计算第二样本图像的训练标签与指定权重的第四乘积。其中,指定权重=1-混合权重。
计算第四乘积的方式与计算第三乘积的方式相同,可参考S1025中的相关描述,此处不再赘述。
S1027、计算第三乘积与第四乘积的和值,作为混合图像的混合标签。
即,混合图像的混合标签为:
(2);
其中,为第n张混合图像,为混合权重,为第i张样本图像的训练标签,为
第j张样本图像的训练标签。i和j的取值范围均为[1,M],M为S101获取的样本图像数量。需
要说明的是,基于S1025中的描述,此处样本图像的训练标签和混合图像的混合标签都是
one-hot形式。
由于第三乘积和第四乘积包含的数据位数相同,因此可以针对第三乘积中的每一位数据,通过公式(2),计算第三乘积中该位数据与第四乘积中该位数据的和值,从而得到混合图像的混合标签。
例如,第三乘积为(0.3,0,0),第四乘积为(0,0.7,0),则混合标签为(0.3,0.7,0)。
由于对两张样本图像按比例进行混合,能够使得样本图像中的对象的色彩和纹理发生一定的改变,因此混合图像的混合标签也需要通过混合得到。即本申请实施例通过对样本图像的训练标签按比例进行混合,能够提高混合标签的准确度。
而且,常规的图像分类模型的训练方法中,训练图像分类模型的训练标签中必然有一个数值为1,使得以此训练得到的图像分类模型对每张图像进行分类时,都会认定该图像属于其中一个预设类别的概率接近1。这种情况下,若向图像分类模型输入一个不属于任何预设类别的图像,图像分类模型仍会输出该图像属于一种预设类别的概率接近1,导致图像分类模型的分类结果不准确。
而本申请实施例中,混合图像的混合标签中不必然包含1,使得以此训练得到的图像分类模型不会认定每张图像都必然属于其中一个预设类别。因此向本申请实施例训练得到的图像分类模型输入一个不属于任何预设类别的图像时,图像分类模型也不会输出该图像属于一种预设类别的概率接近1。因此本申请实施例提高了图像分类模型的准确度,降低了误报率。
本申请实施例中,图1所示的是基于每个小批次(mini batch)的训练过程,该过程称为一次迭代,整体的模型训练过程包括多个纪元(epoch),且每个epoch包括多次迭代。
S101所示的每次迭代过程中,获取多张样本图像以及每张样本图像的训练标签的方式,包括如下步骤:
步骤一、确定上一次从样本图像集包括的多组样本图像中选择的一组样本图像。
其中,样本图像集称为batch,将样本图像集进行切分后得到的每组样本图像称为mini-batch。
步骤二、若上一次选择的一组样本图像不为最后一组样本图像,则获取上一次选择的样本图像组的下一组样本图像以及下一组样本图像的训练标签。
步骤三、若上一次选择的一组样本图像为最后一组样本图像,则获取第一组样本图像以及第一组样本图像的训练标签。
结合步骤一~步骤三举例,假设样本图像集包括500张样本图像,将第1~100张样本图像作为样本图像组1,将第101~200张样本图像作为样本图像组2,将第201~300张样本图像作为样本图像组3,将第301~400张样本图像作为样本图像组4,将第401~500张样本图像作为样本图像组5。每个epoch中按照样本图像组顺序,在每次迭代时获取一组样本图像及其训练标签。即,在每个epoch中,第一次迭代获取样本图像组1及其训练标签,第二次迭代获取样本图像组2及其训练标签,以此类推,第五次迭代获取样本图像组5及其训练标签。
通过上述方法,本申请实施例能够利用样本图像集通过多个epoch对图像分类网络进行训练,使得图像分类网络在训练过程中,逐渐接近并收敛得到最优解,改善图像分类网络的分类性能和泛化能力。
图3所示的其他步骤的具体实现方式,可参考上文或下文描述,此处不再赘述。
本申请实施例中,上述S106基于样本损失值和混合损失值,调整图像分类网络的网络参数的方式,包括如下步骤:
步骤①、获取每张样本图像的正确分类次数以及目标预测概率。
其中,图像分类网络输出的样本图像所属的类别包括样本图像属于每种预设类别的概率。相应的,每张样本图像的正确分类次数为:图像分类网络输出的该样本图像属于每种预设类别的概率中,最大概率对应的类别与该样本图像的训练标签表示的目标类别相同的次数。第i张样本图像的正确分类次数记为ci。
由于每个epoch中,样本图像集中的每张样本图像都会被图像分类网络处理一次,即训练过程中,每张样本图像都会经过图像分类网络多次处理,因此本申请实施例可以统计每张样本图像的正确分类次数。
例如,样本图像1的训练标签为(1,0,0),表示样本图像1所属的目标类别为行人图像。假设图像分类网络针对样本图像1的输出结果为(0.7,0.2,0.1),即图像分类网络认为样本图像1属于行人图像的概率为0.7,属于车辆图像的概率为0.2,属于路障图像的概率为0.1,可见最大概率0.7对应的类别为行人图像,与样本图像1的训练标签表示的目标类别相同,因此本次图像分类网络对图标图像1进行分类的验证结果为正确分类。
本申请实施例中,目标预测概率为:图像分类网络输出的该样本图像属于目标类别的概率。第i张样本图像的目标预测概率记为si。
例如,样本图像1的训练标签为(1,0,0),表示样本图像1所属的目标类别为行人图像,图像分类网络针对样本图像1的输出结果为(0.7,0.2,0.1),则目标分类网络预测样本图像1为行人图像,且目标预测概率为0.7。
步骤②、针对每两张样本图像,根据该两张样本图像的正确分类次数之间的差值,以及该两张样本图像的目标预测概率之间的差值,确定该两张样本图像之间的一致性偏差。
每两张样本图像之间的一致性偏差为:
(3);
其中,表示第i张样本图像和第j张样本图像之间的一致性偏差,表
示第i张样本图像,表示第j张样本图像,表示第i张样本图像的正确分类次数,表示第
j张样本图像的正确分类次数,max和sign分别表示函数运算,表示第i张样本图像的目标
预测概率,表示第j张样本图像的目标预测概率。
后续通过降低一致性损失值,能够使得样本图像的正确分类次数之间的差值与目标预测概率之间的差值相近,从而使得样本图像的正确分类次数与目标预测概率相近,使得目标预测概率更能体现图像分类网络对样本图像的正确分类次数,从而提高了图像分类模型的分类结果的可信度。
步骤③、计算每两张样本图像之间的一致性偏差的和值,作为一致性损失值。
步骤④、根据样本损失值、混合损失值和一致性损失值,确定总损失值。
总损失值为:
(4);
其中,为总损失值,为样本损失值,和为预设的超参数,
为混合损失值,为一致性损失值。
总损失值通过多种损失值组合得到,使得后续通过降低总损失值,对图像分类网络进行优化,实现了对图像分类网络进行多角度优化,使得以此训练得到的图像分类模型对未混合图像以及混合图像的识别准确度更高,且分类结果更可信。
步骤⑤、利用总损失值,调整图像分类网络的网络参数。
通过每次迭代过程,能够以降低总损失为目标,对图像分类网络的网络参数进行一次优化,从而使得图像分类网络在训练过程中,总损失值越来越小,即图像分类网络的输出结果越来越接近图像的标签,因此训练得到的图像分类模型的准确度更高,误报率更低。
在本申请实施例中,上述步骤⑤中利用总损失值,调整图像分类网络的网络参数的方式,包括如下步骤:
步骤1、基于锐度感知最小化优化(Sharpness Awareness Minimization,SAM)算法,即通过公式(5),将总损失值的最小值对应的图像分类网络的网络参数,作为候选网络参数。
(5);
其中,min和max分别表示函数运算,表示图像分类网络的网络参数,表示扰动,表示取二范数,为预设的扰动范围。
步骤2、若当前迭代次数未达到指定次数,则将图像分类网络的网络参数修改为本次计算的候选网络参数。
步骤3、若当前迭代次数达到指定次数,则计算最近的预设次数针对最后一组样本图像确定的候选网络参数的平均值,将图像分类网络的网络参数修改为平均值。
可以采用随机加权平均(Stochastic Weight Averaging,SWA)优化的方式,确定多次迭代得到的候选网络参数的平均值,并将图像分类网络的网络参数修改为平均值。
其中,指定次数可以为一次或多次。例如,指定次数为每个epoch迭代次数的整数倍,使得每个epoch执行一次SWA优化。或者,指定次数为迭代的每一次,使得每次迭代都执行一次SWA优化。或者,还可以设置指定次数为其他次数,本申请实施例对此不作具体限定。
通过SAM算法获得的候选网络参数,能够获得平坦区域的解,保证梯度更新对细小的扰动不敏感,从而获得稳定的全局最优解,保障图像分类网络的分类准确性。而且,在迭代次数达到指定次数时,本申请实施例还能对网络参数进行平均,从而减少异常的网络参数对模型训练的影响,保证图像分类网络的分类准确度。
在本申请实施例中,图像分类网络包括特征提取层和余弦分类器。
其中,特征提取层用于对输入的图像进行特征提取得到图像特征。余弦分类器用于通过公式(6),基于图像特征与每种预设类别的权重之间的余弦相似度,确定输入的图像属于每种预设类别的概率:
(6);
其中,为输入的图像i属于类别k的概率,为预设的超参数,cos表示余弦相似
度函数,为对输入的图像i进行特征提取得到的图像特征,是预设类别k的权重,表示取二范数。
其中,属于图像分类网络的网络参数,可在训练过程中被调整。
常规的图像分类模型一般在特征提取层后加上一层softmax算子,利用softmax算子将图像特征映射到一个概率空间,从而得到输入模型的图像属于每种预设类型的概率。该方式输出的概率中,一般包含一个接近1的概率,但输入图像分类模型的图像可能不属于任何预设类型,使得该方式获得的分类结果不准确。
而本申请实施例中,使用余弦分类器计算输入的图像的图像特征与每个预设类别的权重之间的相似度,并以此确定该图像属于该类别的概率。即图像属于类别的概率,是通过图像与该类别之间的相似度得到的,其中并不必然包含一个接近1的概率,因此保障了图像分类模型的分类结果的准确度。
以下通过实验结果,对使用本申请实施例提供的图像分类模型的训练方法得到的图像分类模型,与常规的图像分类模型的性能进行对比说明。
其中,使用CIFAR100作为模型的样本图像集,CIFAR100包括多张图像,且每张图像均被标注有其所属的类别。图像分类模型的性能指标包括:准确度(Accuracy)和风险覆盖率曲线下的区域(Area Under the Risk-Coverage Curve,AURC)。准确度越高,图像分类模型的分类结果的可信度更高,准确度越低,图像分类模型的分类结果的可信度更低。AURC越低,说明图像分类模型的正确的预测概率和错误的预测概率更能够使用一个分类阈值(threshold)区分开,即图像分类模型的分类结果的可信度更高,AURC越高,图像分类模型的分类结果的可信度更低。
参见表一,表一中基础(Baseline)表示常规的图像分类模型。
Ours w.o L_{mix}表示在本申请提供的模型训练方法的基础上,去除混合图像的相关步骤后,训练得到的图像分类模型。
Ours w.o L_{crl}表示在本申请提供的模型训练方法的基础上,去除一致性损失值的相关步骤后,训练得到的图像分类模型。
Ours w.o cosine classifier表示在本申请提供的模型训练方法的基础上,使用softmax算子代替余弦分类器后,训练得到的图像分类模型。
Ours w.o SAM表示在本申请提供的模型训练方法的基础上,去除SAM优化算法后,训练得到的图像分类模型。
Ours w.o SWA表示在本申请提供的模型训练方法的基础上,去除SWA优化算法后,训练得到的图像分类模型。
Ours表示利用本申请提供的模型训练方法训练得到的图像分类模型。
表一
从表一中可以看出,本申请实施例提供的模型训练方法,能够将图像分类模型的准确率从75.89提升至80.43,有效地提升了图像分类模型的准确率。同时,本申请实施例提供的模型训练方法,还能够将图像分类模型的AURC从69.44降低至45.81,大幅度提升了图像分类模型的置信度。且本申请实施例提供的模型训练方法包括的各个部分对上述提升均有贡献。
而且,本申请实施例提供的模型训练方法与网络架构无关,即本申请实施例中,S102涉及的数据增强、步骤①~步骤④涉及的优化损失和步骤1~3涉及的优化过程,均与网络架构解耦,使得本申请实施例提供的模型训练方法可应用于各种图像分类网络架构,即可对各种图像分类网络进行训练,并提升其分类性能。使得本申请实施例可以应用于各种图像分类场景,例如,自动驾驶、智能安防监控和医学图片分析等。因此本申请实施例提供的图像分类模型的训练方法实施简单、效率高且灵活性高。
基于相同的发明构思,对应于上述方法实施例,本申请实施例还提供了一种图像分类模型的训练装置,如图4所示,该装置包括:获取模块401、增强模块402、分类模块403、确定模块404和调整模块405;
获取模块401,用于获取多张样本图像以及每张样本图像的训练标签,每张样本图像的训练标签表示该样本图像实际所属的类别;
增强模块402,用于针对每两张样本图像,对该两张样本图像按照指定比例混合,得到混合图像,并对该两张样本图像的训练标签按照指定比例混合,得到混合图像的混合标签;
分类模块403,用于将样本图像和混合图像分别输入图像分类网络,得到图像分类网络输出的样本图像所属的类别和混合图像所属的类别;
确定模块404,用于基于图像分类网络输出的样本图像所属的类别和样本图像的训练标签,确定样本损失值;
确定模块404,还用于基于图像分类网络输出的混合图像所属的类别和混合图像的混合标签,确定混合损失值;
调整模块405,用于基于样本损失值和混合损失值,调整图像分类网络的网络参数,并调用获取模块401执行获取多张样本图像以及每张样本图像的训练标签的步骤,直至图像分类网络收敛时,将当前的图像分类网络作为图像分类模型。
可选的,每张原始样本的尺寸均相同;增强模块402,具体用于:
在0到1范围内采样,得到混合权重;
计算该两张样本图像中的第一样本图像的像素值与混合权重的第一乘积;
计算该两张样本图像中的第二样本图像的像素值与指定权重的第二乘积;其中,指定权重为1与混合权重的差值;
计算第一乘积与第二乘积的和值,作为混合图像。
可选的,增强模块402,具体用于:
计算第一样本图像的训练标签与混合权重的第三乘积;
计算第二样本图像的训练标签与指定权重的第四乘积;
计算第三乘积与第四乘积的和值,作为混合图像的混合标签。
可选的,多次采样获得的混合权重满足贝塔分布。
可选的,获取模块401,具体用于:
确定上一次从样本图像集包括的多组样本图像中选择的一组样本图像;
若上一次选择的一组样本图像不为最后一组样本图像,则获取上一次选择的样本图像组的下一组样本图像以及下一组样本图像的训练标签;
若上一次选择的一组样本图像为最后一组样本图像,则获取第一组样本图像以及第一组样本图像的训练标签。
可选的,图像分类网络输出的样本图像所属的类别包括样本图像属于每种预设类别的概率;调整模块405,具体用于:
获取每张样本图像的正确分类次数以及目标预测概率;其中,正确分类次数为:图像分类网络输出的该样本图像属于每种预设类别的概率中,最大概率对应的类别与该样本图像的训练标签表示的目标类别相同的次数;目标预测概率为:图像分类网络输出的该样本图像属于目标类别的概率;
针对每两张样本图像,根据该两张样本图像的正确分类次数之间的差值,以及该两张样本图像的目标预测概率之间的差值,确定该两张样本图像之间的一致性偏差;
计算每两张样本图像之间的一致性偏差的和值,作为一致性损失值;
根据样本损失值、混合损失值和一致性损失值,确定总损失值;
利用总损失值,调整图像分类网络的网络参数。
可选的,每两张样本图像之间的一致性偏差为:
;
其中,表示第i张样本图像和第j张样本图像之间的一致性偏差,表
示第i张样本图像,表示第j张样本图像,表示第i张样本图像的正确分类次数,表示第
j张样本图像的正确分类次数,max和sign分别表示函数运算,表示第i张样本图像的目标
预测概率,表示第j张样本图像的目标预测概率。
可选的,总损失值为:
;
其中,为总损失值,为样本损失值,和为预设的超参数,
为混合损失值,为一致性损失值。
可选的,调整模块405,具体用于:
基于锐度感知最小化优化算法,将总损失值的最小值对应的图像分类网络的网络参数,作为候选网络参数;
若当前迭代次数未达到指定次数,则将图像分类网络的网络参数修改为本次计算的候选网络参数;
若当前迭代次数达到指定次数,则计算最近的预设次数针对最后一组样本图像确定的候选网络参数的平均值,将图像分类网络的网络参数修改为平均值。
可选的,图像分类网络包括特征提取层和余弦分类器,特征提取层用于对输入的图像进行特征提取得到图像特征,余弦分类器用于基于图像特征与每种预设类别的权重之间的余弦相似度,确定输入的图像属于每种预设类别的概率。
本申请实施例还提供了一种电子设备,如图5所示,包括处理器501、通信接口502、存储器503和通信总线504,其中,处理器501,通信接口502,存储器503通过通信总线504完成相互间的通信,
存储器503,用于存放计算机程序;
处理器501,用于执行存储器503上所存放的程序时,实现上述方法实施例中的方法步骤。
上述电子设备提到的通信总线可以是外设部件互连标准(Peripheral ComponentInterconnect,PCI)总线或扩展工业标准结构(Extended Industry StandardArchitecture,EISA)总线等。该通信总线可以分为地址总线、数据总线、控制总线等。为便于表示,图中仅用一条粗线表示,但并不表示仅有一根总线或一种类型的总线。
通信接口用于上述电子设备与其他设备之间的通信。
存储器可以包括随机存取存储器(Random Access Memory,RAM),也可以包括非易失性存储器(Non-Volatile Memory,NVM),例如至少一个磁盘存储器。可选的,存储器还可以是至少一个位于远离前述处理器的存储装置。
上述的处理器可以是通用处理器,包括中央处理器(Central Processing Unit,CPU)、网络处理器(Network Processor,NP)等;还可以是数字信号处理器(Digital SignalProcessor,DSP)、专用集成电路(Application Specific Integrated Circuit,ASIC)、现场可编程门阵列(Field-Programmable Gate Array,FPGA)或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件。
在本申请提供的又一实施例中,还提供了一种计算机可读存储介质,该计算机可读存储介质内存储有计算机程序,所述计算机程序被处理器执行时实现上述任一图像分类模型的训练方法的步骤。
在本申请提供的又一实施例中,还提供了一种包含指令的计算机程序产品,当其在计算机上运行时,使得计算机执行上述实施例中任一图像分类模型的训练方法。
在上述实施例中,可以全部或部分地通过软件、硬件、固件或者其任意组合来实现。当使用软件实现时,可以全部或部分地以计算机程序产品的形式实现。所述计算机程序产品包括一个或多个计算机指令。在计算机上加载和执行所述计算机程序指令时,全部或部分地产生按照本申请实施例所述的流程或功能。所述计算机可以是通用计算机、专用计算机、计算机网络、或者其他可编程装置。所述计算机指令可以存储在计算机可读存储介质中,或者从一个计算机可读存储介质向另一个计算机可读存储介质传输,例如,所述计算机指令可以从一个网站站点、计算机、服务器或数据中心通过有线(例如同轴电缆、光纤、数字用户线(DSL))或无线(例如红外、无线、微波等)方式向另一个网站站点、计算机、服务器或数据中心进行传输。所述计算机可读存储介质可以是计算机能够存取的任何可用介质或者是包含一个或多个可用介质集成的服务器、数据中心等数据存储设备。所述可用介质可以是磁性介质,(例如,软盘、硬盘、磁带)、光介质(例如,DVD)、或者半导体介质(例如固态硬盘Solid State Disk (SSD))等。
需要说明的是,在本文中,诸如第一和第二等之类的关系术语仅仅用来将一个实体或者操作与另一个实体或操作区分开来,而不一定要求或者暗示这些实体或操作之间存在任何这种实际的关系或者顺序。而且,术语“包括”、“包含”或者其任何其他变体意在涵盖非排他性的包含,从而使得包括一系列要素的过程、方法、物品或者设备不仅包括那些要素,而且还包括没有明确列出的其他要素,或者是还包括为这种过程、方法、物品或者设备所固有的要素。在没有更多限制的情况下,由语句“包括一个……”限定的要素,并不排除在包括所述要素的过程、方法、物品或者设备中还存在另外的相同要素。
本说明书中的各个实施例均采用相关的方式描述,各个实施例之间相同相似的部分互相参见即可,每个实施例重点说明的都是与其他实施例的不同之处。尤其,对于装置实施例而言,由于其基本相似于方法实施例,所以描述的比较简单,相关之处参见方法实施例的部分说明即可。
以上所述仅为本申请的较佳实施例,并非用于限定本申请的保护范围。凡在本申请的精神和原则之内所作的任何修改、等同替换、改进等,均包含在本申请的保护范围内。
Claims (22)
1.一种图像分类模型的训练方法,其特征在于,所述方法包括:
获取多张样本图像以及每张样本图像的训练标签,每张样本图像的训练标签表示该样本图像实际所属的类别;
针对每两张样本图像,对该两张样本图像按照指定比例混合,得到混合图像,并对该两张样本图像的训练标签按照所述指定比例混合,得到所述混合图像的混合标签;
将所述样本图像和所述混合图像分别输入图像分类网络,得到所述图像分类网络输出的所述样本图像所属的类别和所述混合图像所属的类别;
基于所述图像分类网络输出的所述样本图像所属的类别和所述样本图像的训练标签,确定样本损失值;
基于所述图像分类网络输出的所述混合图像所属的类别和所述混合图像的混合标签,确定混合损失值;
基于所述样本损失值和所述混合损失值,调整所述图像分类网络的网络参数,并返回所述获取多张样本图像以及每张样本图像的训练标签的步骤,直至所述图像分类网络收敛时,将当前的图像分类网络作为图像分类模型。
2.根据权利要求1所述的方法,其特征在于,每张原始样本的尺寸均相同;所述对该两张样本图像按照指定比例混合,得到混合图像,包括:
在0到1范围内采样,得到混合权重;
计算该两张样本图像中的第一样本图像的像素值与所述混合权重的第一乘积;
计算该两张样本图像中的第二样本图像的像素值与指定权重的第二乘积;其中,所述指定权重为1与所述混合权重的差值;
计算所述第一乘积与所述第二乘积的和值,作为所述混合图像。
3.根据权利要求2所述的方法,其特征在于,所述对该两张样本图像的训练标签按照所述指定比例混合,得到所述混合图像的混合标签,包括:
计算所述第一样本图像的训练标签与所述混合权重的第三乘积;
计算所述第二样本图像的训练标签与所述指定权重的第四乘积;
计算所述第三乘积与所述第四乘积的和值,作为所述混合图像的混合标签。
4.根据权利要求2所述的方法,其特征在于,多次采样获得的混合权重满足贝塔分布。
5.根据权利要求1所述的方法,其特征在于,所述获取多张样本图像以及每张样本图像的训练标签,包括:
确定上一次从样本图像集包括的多组样本图像中选择的一组样本图像;
若上一次选择的一组样本图像不为最后一组样本图像,则获取上一次选择的样本图像组的下一组样本图像以及所述下一组样本图像的训练标签;
若上一次选择的一组样本图像为最后一组样本图像,则获取第一组样本图像以及所述第一组样本图像的训练标签。
6.根据权利要求5所述的方法,其特征在于,所述图像分类网络输出的所述样本图像所属的类别包括所述样本图像属于每种预设类别的概率;所述基于所述样本损失值和所述混合损失值,调整所述图像分类网络的网络参数,包括:
获取每张样本图像的正确分类次数以及目标预测概率;其中,所述正确分类次数为:所述图像分类网络输出的该样本图像属于每种预设类别的概率中,最大概率对应的类别与该样本图像的训练标签表示的目标类别相同的次数;所述目标预测概率为:所述图像分类网络输出的该样本图像属于目标类别的概率;
针对每两张样本图像,根据该两张样本图像的正确分类次数之间的差值,以及该两张样本图像的目标预测概率之间的差值,确定该两张样本图像之间的一致性偏差;
计算每两张样本图像之间的一致性偏差的和值,作为一致性损失值;
根据所述样本损失值、所述混合损失值和所述一致性损失值,确定总损失值;
利用所述总损失值,调整所述图像分类网络的网络参数。
7.根据权利要求6所述的方法,其特征在于,每两张样本图像之间的一致性偏差为:
;
其中,表示第i张样本图像和第j张样本图像之间的一致性偏差,表示第i张样本图像,表示第j张样本图像,表示第i张样本图像的正确分类次数,表示第j张样本图像的正确分类次数,max和sign分别表示函数运算,表示第i张样本图像的目标预测概率,表示第j张样本图像的目标预测概率。
8.根据权利要求6所述的方法,其特征在于,所述总损失值为:
;
其中,为所述总损失值,为所述样本损失值,和为预设的超参数,为所述混合损失值,为所述一致性损失值。
9.根据权利要求6所述的方法,其特征在于,所述利用所述总损失值,调整所述图像分类网络的网络参数,包括:
基于锐度感知最小化优化算法,将所述总损失值的最小值对应的所述图像分类网络的网络参数,作为候选网络参数;
若当前迭代次数未达到指定次数,则将所述图像分类网络的网络参数修改为本次计算的候选网络参数;
若当前迭代次数达到指定次数,则计算最近的预设次数针对最后一组样本图像确定的候选网络参数的平均值,将所述图像分类网络的网络参数修改为所述平均值。
10.根据权利要求1-9任一项所述的方法,其特征在于,所述图像分类网络包括特征提取层和余弦分类器,所述特征提取层用于对输入的图像进行特征提取得到图像特征,所述余弦分类器用于基于所述图像特征与每种预设类别的权重之间的余弦相似度,确定输入的图像属于每种预设类别的概率。
11.一种图像分类模型的训练装置,其特征在于,所述装置包括:
获取模块,用于获取多张样本图像以及每张样本图像的训练标签,每张样本图像的训练标签表示该样本图像实际所属的类别;
增强模块,用于针对每两张样本图像,对该两张样本图像按照指定比例混合,得到混合图像,并对该两张样本图像的训练标签按照所述指定比例混合,得到所述混合图像的混合标签;
分类模块,用于将所述样本图像和所述混合图像分别输入图像分类网络,得到所述图像分类网络输出的所述样本图像所属的类别和所述混合图像所属的类别;
确定模块,用于基于所述图像分类网络输出的所述样本图像所属的类别和所述样本图像的训练标签,确定样本损失值;
所述确定模块,还用于基于所述图像分类网络输出的所述混合图像所属的类别和所述混合图像的混合标签,确定混合损失值;
调整模块,用于基于所述样本损失值和所述混合损失值,调整所述图像分类网络的网络参数,并调用所述获取模块执行所述获取多张样本图像以及每张样本图像的训练标签的步骤,直至所述图像分类网络收敛时,将当前的图像分类网络作为图像分类模型。
12.根据权利要求11所述的装置,其特征在于,每张原始样本的尺寸均相同;所述增强模块,具体用于:
在0到1范围内采样,得到混合权重;
计算该两张样本图像中的第一样本图像的像素值与所述混合权重的第一乘积;
计算该两张样本图像中的第二样本图像的像素值与指定权重的第二乘积;其中,所述指定权重为1与所述混合权重的差值;
计算所述第一乘积与所述第二乘积的和值,作为所述混合图像。
13.根据权利要求12所述的装置,其特征在于,所述增强模块,具体用于:
计算所述第一样本图像的训练标签与所述混合权重的第三乘积;
计算所述第二样本图像的训练标签与所述指定权重的第四乘积;
计算所述第三乘积与所述第四乘积的和值,作为所述混合图像的混合标签。
14.根据权利要求12所述的装置,其特征在于,多次采样获得的混合权重满足贝塔分布。
15.根据权利要求11所述的装置,其特征在于,所述获取模块,具体用于:
确定上一次从样本图像集包括的多组样本图像中选择的一组样本图像;
若上一次选择的一组样本图像不为最后一组样本图像,则获取上一次选择的样本图像组的下一组样本图像以及所述下一组样本图像的训练标签;
若上一次选择的一组样本图像为最后一组样本图像,则获取第一组样本图像以及所述第一组样本图像的训练标签。
16.根据权利要求15所述的装置,其特征在于,所述图像分类网络输出的所述样本图像所属的类别包括所述样本图像属于每种预设类别的概率;所述调整模块,具体用于:
获取每张样本图像的正确分类次数以及目标预测概率;其中,所述正确分类次数为:所述图像分类网络输出的该样本图像属于每种预设类别的概率中,最大概率对应的类别与该样本图像的训练标签表示的目标类别相同的次数;所述目标预测概率为:所述图像分类网络输出的该样本图像属于目标类别的概率;
针对每两张样本图像,根据该两张样本图像的正确分类次数之间的差值,以及该两张样本图像的目标预测概率之间的差值,确定该两张样本图像之间的一致性偏差;
计算每两张样本图像之间的一致性偏差的和值,作为一致性损失值;
根据所述样本损失值、所述混合损失值和所述一致性损失值,确定总损失值;
利用所述总损失值,调整所述图像分类网络的网络参数。
17.根据权利要求16所述的装置,其特征在于,每两张样本图像之间的一致性偏差为:
;
其中,表示第i张样本图像和第j张样本图像之间的一致性偏差,表示第i张样本图像,表示第j张样本图像,表示第i张样本图像的正确分类次数,表示第j张样本图像的正确分类次数,max和sign分别表示函数运算,表示第i张样本图像的目标预测概率,表示第j张样本图像的目标预测概率。
18.根据权利要求16所述的装置,其特征在于,所述总损失值为:
;
其中,为所述总损失值,为所述样本损失值,和为预设的超参数,为所述混合损失值,为所述一致性损失值。
19.根据权利要求16所述的装置,其特征在于,所述调整模块,具体用于:
基于锐度感知最小化优化算法,将所述总损失值的最小值对应的所述图像分类网络的网络参数,作为候选网络参数;
若当前迭代次数未达到指定次数,则将所述图像分类网络的网络参数修改为本次计算的候选网络参数;
若当前迭代次数达到指定次数,则计算最近的预设次数针对最后一组样本图像确定的候选网络参数的平均值,将所述图像分类网络的网络参数修改为所述平均值。
20.根据权利要求11-19任一项所述的装置,其特征在于,所述图像分类网络包括特征提取层和余弦分类器,所述特征提取层用于对输入的图像进行特征提取得到图像特征,所述余弦分类器用于基于所述图像特征与每种预设类别的权重之间的余弦相似度,确定输入的图像属于每种预设类别的概率。
21.一种电子设备,其特征在于,包括处理器、通信接口、存储器和通信总线,其中,处理器,通信接口,存储器通过通信总线完成相互间的通信;
存储器,用于存放计算机程序;
处理器,用于执行存储器上所存放的程序时,实现权利要求1-10任一项所述的方法。
22.一种计算机可读存储介质,其特征在于,所述计算机可读存储介质内存储有计算机程序,所述计算机程序被处理器执行时实现权利要求1-10任一项所述的方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202410400212.2A CN117994611B (zh) | 2024-04-03 | 2024-04-03 | 一种图像分类模型的训练方法、装置及电子设备 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202410400212.2A CN117994611B (zh) | 2024-04-03 | 2024-04-03 | 一种图像分类模型的训练方法、装置及电子设备 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN117994611A CN117994611A (zh) | 2024-05-07 |
CN117994611B true CN117994611B (zh) | 2024-07-02 |
Family
ID=90900979
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202410400212.2A Active CN117994611B (zh) | 2024-04-03 | 2024-04-03 | 一种图像分类模型的训练方法、装置及电子设备 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN117994611B (zh) |
Citations (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN115937616A (zh) * | 2023-02-21 | 2023-04-07 | 深圳新视智科技术有限公司 | 图像分类模型的训练方法、系统及移动终端 |
CN117237757A (zh) * | 2023-09-19 | 2023-12-15 | 英特灵达信息技术(深圳)有限公司 | 一种人脸识别模型训练方法、装置、电子设备及介质 |
Family Cites Families (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN111814810A (zh) * | 2020-08-11 | 2020-10-23 | Oppo广东移动通信有限公司 | 图像识别方法、装置、电子设备及存储介质 |
-
2024
- 2024-04-03 CN CN202410400212.2A patent/CN117994611B/zh active Active
Patent Citations (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN115937616A (zh) * | 2023-02-21 | 2023-04-07 | 深圳新视智科技术有限公司 | 图像分类模型的训练方法、系统及移动终端 |
CN117237757A (zh) * | 2023-09-19 | 2023-12-15 | 英特灵达信息技术(深圳)有限公司 | 一种人脸识别模型训练方法、装置、电子设备及介质 |
Also Published As
Publication number | Publication date |
---|---|
CN117994611A (zh) | 2024-05-07 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN112990432B (zh) | 目标识别模型训练方法、装置及电子设备 | |
CN108921206B (zh) | 一种图像分类方法、装置、电子设备及存储介质 | |
CN108898086B (zh) | 视频图像处理方法及装置、计算机可读介质和电子设备 | |
CN108171203B (zh) | 用于识别车辆的方法和装置 | |
JP2021532434A (ja) | 顔特徴抽出モデル訓練方法、顔特徴抽出方法、装置、機器および記憶媒体 | |
CN111368636B (zh) | 目标分类方法、装置、计算机设备和存储介质 | |
CN110909663B (zh) | 一种人体关键点识别方法、装置及电子设备 | |
CN112906823B (zh) | 目标对象识别模型训练方法、识别方法及识别装置 | |
CN114549840B (zh) | 语义分割模型的训练方法和语义分割方法、装置 | |
CN112132206A (zh) | 图像识别方法及相关模型的训练方法及相关装置、设备 | |
CN111325067B (zh) | 违规视频的识别方法、装置及电子设备 | |
CN112001403A (zh) | 一种图像轮廓检测方法及系统 | |
CN111178364A (zh) | 一种图像识别方法和装置 | |
EP4343616A1 (en) | Image classification method, model training method, device, storage medium, and computer program | |
CN110135428B (zh) | 图像分割处理方法和装置 | |
CN117994611B (zh) | 一种图像分类模型的训练方法、装置及电子设备 | |
CN117315310A (zh) | 一种图像识别方法、图像识别模型训练方法及装置 | |
CN112800813B (zh) | 一种目标识别方法及装置 | |
CN111860623A (zh) | 基于改进ssd神经网络的统计树木数量的方法及系统 | |
CN111340140A (zh) | 图像数据集的获取方法、装置、电子设备及存储介质 | |
CN116258873A (zh) | 一种位置信息确定方法、对象识别模型的训练方法及装置 | |
CN108183736B (zh) | 基于机器学习的发射机码字选择方法、装置和发射机 | |
CN113706428B (zh) | 一种图像生成方法及装置 | |
CN115546554A (zh) | 敏感图像的识别方法、装置、设备和计算机可读存储介质 | |
CN112001211A (zh) | 对象检测方法、装置、设备及计算机可读存储介质 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant |