CN109886335A - 分类模型训练方法及装置 - Google Patents
分类模型训练方法及装置 Download PDFInfo
- Publication number
- CN109886335A CN109886335A CN201910129385.4A CN201910129385A CN109886335A CN 109886335 A CN109886335 A CN 109886335A CN 201910129385 A CN201910129385 A CN 201910129385A CN 109886335 A CN109886335 A CN 109886335A
- Authority
- CN
- China
- Prior art keywords
- training
- model
- label
- preliminary classification
- value
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Granted
Links
Landscapes
- Image Analysis (AREA)
Abstract
本申请公开了一种分类模型训练方法及装置,通过单独计算每个标签的二值交叉熵,提高所获得的标签的精度。详细地,首先通过获取多个第一训练样本,其中,每个第一训练样本包括第一训练图像以及与该第一训练图像对应的第一预设数量个标签,所述第一预设数量个标签包括与图像内容对应的上位分类以及下位分类分别对应的标签;然后,根据多个所述第一训练样本进行机器学习训练,获得初始分类模型;在获得初始分类模型后,分别获取初始分类模型中每个标签的二值交叉熵作为该标签的子误差值;最后,根据每个标签的子误差值获得初始分类模型的总误差值,进而根据所述总误差值调整所述初始分类模型,获得目标分类模型。
Description
技术领域
本申请涉及图像处理技术领域,具体而言,涉及一种分类模型训练方法及装置。
背景技术
现有技术中,在图像分类时,有两大类常见的分类方法,其中一种分类方法就是单标签分类方法,在这种方法中,每张图像只击打一个标签,也就是说,每张图像只分为一个类别,因此,这种分类方法并不能完整地表达图像的语义;另一种分类方法是多标签分类,这种方法中,同一图像可以对应击打多个标签,也就是说,每张图像可以划分为多个类别。在现有的多标签分类算法中,如果同一图像对应的多个标签中存在一个标签所对应的含义是另一个或多个标签所对应含义的上位分类,即这些标签之间存在依赖关系。现有的多标签分类算法中,若识别到的图像分类具有存在依赖关系的多个标签,只有输出其中一个标签作为该图像的分类标签。因此,现有多标签分类方法存在输出的标签精度低的问题。
发明内容
为了克服现有技术中的上述不足,本申请的目的在于提供一种分类模型训练方法,所述方法包括:
获取多个第一训练样本,其中,每个第一训练样本包括第一训练图像以及与该第一训练图像对应的第一预设数量个标签,所述第一预设数量个标签包括与图像内容对应的上位分类以及下位分类分别对应的标签;
根据多个所述第一训练样本进行机器学习训练,获得初始分类模型;
分别获取初始分类模型中每个标签的二值交叉熵作为对应的子误差值;
根据所述每个标签的子误差值计算获得所述初始分类模型的总误差值;
根据所述总误差值调整所述初始分类模型,获得目标分类模型。
可选地,所述根据所述总误差值调整所述初始分类模型,获得目标分类模型的步骤包括:
检测所述总误差值是否大于预设值;
若所述总误差值大于预设值,则调整所述初始分类模型的参数,直至所述总误差值小于所述预设值;
若所述总误差值小于所述预设值,则将调整参数后的初始分类模型作为目标分类模型。
可选地,所述获取多个第一训练样本的步骤前,所述方法还包括:
获取多个初始样本,每个所述初始样本包括初始图像以及与该初始图像对应的第一预设数量个标签,所述第一预设数量个标签包括与图像内容对应的上位分类的标签以及下位分类的标签;
针对每个初始样本,对该初始样本中的所述初始图像进行变换,获得多个所述第一训练样本。
可选地,所述变换包括随机裁剪、随机翻转、随机颜色调整或随机亮度调整中的至少一种。
可选地,所述根据多个所述第一训练样本进行机器学习训练,获得初始分类模型的步骤包括:
将所述多个第一训练样本输入已训练的预分类模型,对该预分类模型进行再次训练;
对所述预分类模型进行调整,获得所述初始分类模型。
可选地,所述预分类模型包括卷积层、池化层和全连接层,所述对所述预分类模型进行调整,获得所述初始分类模型的步骤包括:
对所述预分类模型的全连接层的参数进行调整,获得中间模型;
分别对中间模型的卷积层、池化层和全连接层的参数进行调整,获得初始分类模型。
可选地,所述根据多个所述第一训练样本进行机器学习训练,获得初始分类模型的步骤前,所述方法还包括:
将第二训练样本输入深度学习框架,获得所述预分类模型;
其中,所述第二训练样本包括第二训练图像以及与每个第二训练图像对应的第二预设数量个标签。
可选地,所述方法还包括:
将待识别图像输入所述目标分类模型,获得所述待识别图像对应的第一预设数量个标签。
本申请的另一目的在于提供一种分类模型训练装置,所述装置包括获取模块、训练模块、计算模块和调整模块;
所述获取模块用于获取多个第一训练样本,其中,每个第一训练样本包括第一训练图像以及与该第一训练图像对应的第一预设数量个标签,所述第一预设数量个标签包括与图像内容对应的上位分类的标签以及下位分类的标签;
所述训练模块用于根据多个所述第一训练样本进行机器学习训练,获得初始分类模型;
所述计算模块用于分别获取初始分类模型中每个标签的二值交叉熵作为对应的子误差值,以及
根据所述每个标签的子误差值计算所述初始分类模型的总误差值;
所述调整模块用于根据所述总误差值调整所述初始分类模型,获得目标分类模型。
可选地,所述调整模块用于根据所述总误差值调整所述初始分类模型,获得目标分类模型的步骤包括:
检测所述总误差值是否大于预设值;
若所述总误差值大于预设值,则调整所述初始分类模型的参数,直至所述总误差值小于所述预设值;
若所述总误差值小于所述预设值,则将调整参数后的初始分类模型作为目标分类模型。
相对于现有技术而言,本申请具有以下有益效果:
本申请实施例中,通过为每个图像设置多个标签来进行机器学习训练,获得初始分类模型,并分别获取每个标签的二值交叉熵,从而根据每个标签的二值交叉熵来计算初始分类模型的总误差值,进而根据初始分类模型的总误差来调整该初始分类模型的参数。如此,由于各个标签对应的子误差值是单独计算的,各个标签对应的子误差值不会受到其他标签的影响,使得同一图像可以同时具有一上位分类对应的标签以及该上位分类对应的下位分类对应的标签,也就是说,训练出的目标分类模型的分类精度能够得到提高。
附图说明
为了更清楚地说明本申请实施例的技术方案,下面将对实施例中所需要使用的附图作简单地介绍,应当理解,以下附图仅示出了本申请的某些实施例,因此不应被看作是对范围的限定,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他相关的附图。
图1为本申请实施例提供的分类模型训练设备的结构示意框图;
图2为本申请实施例提供的分类模型训练方法的流程示意图一;
图3为本申请实施例提供的分类模型训练方法的流程示意图二;
图4为本申请实施例提供的分类模型训练方法的流程示意图三;
图5为本申请实施例提供的分类模型训练方法的流程示意图四;
图6为本申请实施例提供的分类模型训练方法的流程示意图五;
图7为本申请实施例提供的分类模型训练方法的流程示意图六;
图8为本申请实施例提供的分类模型训练方法的流程示意图七;
图9为本申请实施例提供的分类模型训练装置的结构示意框图。
图标:100-分类模型训练设备;110-分类模型训练装置;111-获取模块;112-训练模块;113-计算模块;114-调整模块;120-存储器;130-处理器。
具体实施方式
为使本申请实施例的目的、技术方案和优点更加清楚,下面将结合本申请实施例中的附图,对本申请实施例中的技术方案进行清楚、完整地描述。
在对图像进行分类处理时,会先训练出一个模型,现有技术中,所训练出的模型在对图像进行分类时,一般会存在两种分类方法,其中一种分类方法就是单标签分类方法,在这种方法中,每张图像只击打一个标签,也就是说,每张图像只分为一个类别,因此,这种分类方法并不能完整地表达图像的语义。
现有技术中的另一种分类方法是多标签分类,这种方法中,同一图像可以对应击打多个标签,也就是说,每张图像可以划分为多个类别。然而,在多标签分类算法中,如果同一图像对应的多个标签中存在一个标签所对应的含义是另一个或多个标签所对应含义的上位分类,那么,就说明这些标签之间会相互依赖,这些相互依赖的标签中,就只有一个会存在于该图像的输出标签中。例如,图像对应有一级标签护肤,护肤标签下对应有二级标签护肤效果,护肤效果的标签下对应有三级标签美白、补水等,这种情况下,图像就不可能同时被标注为护肤标签、护肤效果标签、美白标签和补水标签,而只能被标注为这四种标签中的一种标签,因此,现有的多标签分类方法会存在输出的标签精度低的问题。
请参见图1,图1是本申请实施例提供的分类模型训练设备100的结构示意框图,所述分类模型训练设备100包括分类模型训练装置110、存储器120、处理器130。
所述存储器120、处理器130各元件之间直接或间接地电性连接,以实现数据的传输或者交互。例如,这些元件相互之间可通过一条或多条通讯总线或信号线实现电性连接。所述分类模型训练装置110包括至少一个可以软件或固件(firmware)的形式存储于所述存储器120中或固化在所述分类模型训练设备100的操作系统(operating system,OS)中的软件功能模块。所述存储器120存储有可执行模块。所述处理器130用于执行所述存储器120中存储的可执行模块,例如所述分类模型训练设备100所包括的软件功能模块及计算机程序等。
其中,所述存储器120可以是,但不限于,随机存取存储器(Random AccessMemory,RAM),只读存储器(Read Only Memory,ROM),可编程只读存储器(ProgrammableRead-Only Memory,PROM),可擦除只读存储器(Erasable Programmable Read-OnlyMemory,EPROM),电可擦除只读存储器(Electric Erasable Programmable Read-OnlyMemory,EEPROM)等。其中,存储器120用于存储程序,处理器130在接收到执行指令后,执行所述程序。所述处理器130以及其他可能的组件对存储器120的访问可在存储控制器的控制下进行。
所述处理器130可能是一种集成电路芯片,具有信号的处理能力。上述的处理器130可以是通用处理器130,包括中央处理器130(Central Processing Unit,CPU)、网络处理器130(Network Processor,NP)等;还可以是数字信号处理器130(DSP)、专用集成电路(ASIC)、现场可编程门阵列(FPGA)或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件。可以实现或者执行本申请实施例中公开的各方法、步骤及逻辑框图。通用处理器130可以是微处理器130或者该处理器130也可以是任何常规的处理器130等。
请参见图2,图2为分类模型训练方法的流程示意图。所述分类模型训练方法包括步骤S110-步骤S150。
步骤S110,获取多个第一训练样本,其中,每个第一训练样本包括第一训练图像以及与该第一训练图像对应的第一预设数量个标签,所述第一预设数量个标签包括与图像内容对应的上位分类以及下位分类分别对应的标签。
本实施例用于获取用于进行后续机器学习训练的第一训练样本。其中,每个第一训练样本都对应第一预设数量个标签,每个标签对应一个类别,也就是说,第一训练样本中的第一训练图像可能分为每个第一预设数量个标签中每个对应的类别。
本实施例中的标签,可以采用One-Hot模式表示,也就是说,将样本的多个标签组成多为向量,对于命中的标签,将其向量的值设置为1,否则,将其向量的值设置为0。以包含240个类别的第一训练样本为例,该样本对应的标签可以形成240维的向量,在该样本的图像中,含有护肤、精华、乳液面霜、护肤效果、补水、美白等六个标签,则在240维向量中,这几个标签对应的索引标记为1,其余标记为0,则240维中有6维的值为1,其余234维的值为0。
步骤S120,根据多个所述第一训练样本进行机器学习训练,获得初始分类模型。
本实施例用于将多个第一训练样本进行机器学习训练,以获得初始分类模型。
步骤S130,分别获取初始分类模型中每个标签的二值交叉熵作为对应的子误差值。
具体地,分别获取每个标签的交叉熵,然后将每个标签的交叉熵输入Sigmoid函数,如此,便可以将每个交叉熵转化为0-1之间的数值,获得标签对应的二值交叉熵。本实施例中,交叉熵的计算公式为:
CEi=-ti*log(si)-(1-ti)*log(1-si)
Sigmoid函数如下:
其中,i是标签的编号,CEi是第i个标签的交叉熵,ti是图像中第i个已有标签的真值(Groundtruth),例如,One-hot模式中,标签的真值可能为0或1。si是算法预测的标签,即si就是初始分类模型的预测值通过Sigmoid激活函数之后,所得到的输出值。x为参数,f(x)的取值范围是0~1之间。这样,就将x的任意正负值映射到0~1的区间中,例如,要将交叉熵映射到0~1之间,就可以将CEi作为x。
本实施例用于分别计算每个标签对应的二值交叉熵,也就是分别计算每个标签对应的子误差值。
步骤S140,根据所述每个标签的子误差值获得所述初始分类模型的总误差值。
本实施例用于根据每个标签的子误差值计算初始分类模型的总误差值,具体地,可以将各个标签的子误差值,也就是各个标签的二值交叉熵相加,从而获得初始分类模型的总误差值。
步骤S150,根据所述总误差值调整所述初始分类模型,获得目标分类模型。
本实施例用于根据总误差值调整初始分类模型,以获得调整后的分类模型。
请参见图3,本实施例中,可选地,步骤S150包括步骤S151-步骤S153。
步骤S151,检测所述总误差值是否大于预设值。
步骤S152,若所述总误差值大于预设值,则调整所述初始分类模型的参数,直至所述总误差值小于所述预设值。
也就是说,本实施例用于在初始分类模型的总误差值大于预设值的情况下,调整该初始分类模型的参数,并在调整后的分类模型的总误差值大于预设值的情况下,反复对该次调整后的分类模型的参数进行调整,以获得总误差值小于预设值的模型。
步骤S153,若所述总误差值小于所述预设值,则将调整参数后的初始分类模型作为目标分类模型。
本实施例用于在初始分类模型或者某次调整后的分类模型的总误差值小于预设值时,将该初始分类模型或者该次调整后的分类模型作为中间模型,以不断调整该分类模型的误差,以获得目标分类模型。
请参见图4,本实施例中,可选地,所述步骤S110前,所述方法还包括步骤S210和步骤S220。
步骤S210,获取多个初始样本,每个所述初始样本包括初始图像以及与该初始图像对应的第一预设数量个标签,所述第一预设数量个标签包括图像内容对应的上位分类的标签以及下位分类的标签。
本实施例中的初始样本为多个,例如,可以为60万个。其中,每个初始样本可以来源于社交网站等的图像,例如,用户上传到图像社区的真实图像。初始样本中的标签数量可以为多个,例如240个,也就是说同一初始样本的图像对应多个标签,例如,同一初始样本可以包括护肤、精华、乳液面霜、护肤效果、补水、美白等多个标签。也就是说,每个初始样本包括一张图像以及与该图像对应的多个标签。其中,每个标签对应一个类别,所述标签可以由人工标注完成。初始样本的多个标签中,不同标签之间可以存在层级关系,也就是说,一个标签所对应的类别可以包括其他一些标签所对应的类别。例如,多个标签中的至少部分标签可以包括三个层级,其中一级标签是护肤类别,二级标签是包括于护肤类别的护肤效果,三级标签是美白、补水等。
步骤S220,通过对初始样本的图片进行变换,获得多个第一训练样本。
具体地,针对每个初始样本,对该初始样本中的所述初始图像进行变换,获得多个所述第一训练样本。
也就是说,本实施例用于对每个初始样本中的初始图像分别进行多类变换,通过每一类变换,便获得与一个与该初始样本对应的第一训练样本,进过多类变换,便获得与该初始样本对应的多个第一训练样本。对每个初始样本中的图片都进行同样类型的变换后,便可获得多个第一训练样本。其中,每类变换包括随机裁剪、随机翻转、随机颜色调整或随机亮度调整中的至少一种。
例如,在一次变换过程中,可以对一个初始样本中的初始图像同时进行随机裁剪、随机翻转、随机颜色调整和随机亮度调整。此时,进行变换的具体步骤请参见图5,本实施例中,可选地,步骤S220包括步骤S221-步骤S225。
步骤S221,按照预设尺寸对所述初始图像进行随机裁剪,获得第一中间图像。
本实施例中用于对图像进行随机尺寸裁剪(Random Resized Crop),也就是说,按照预设尺寸随机对初始图像进行裁剪,例如,预设尺寸可以是224*224。
步骤S222,对所述第一中间图像进行随机翻转,获得第二中间图像。
本实施例用于对图像进行随机左右翻转(Random Flip Left Right),也就是随机地将第一中间图像向左或者向右翻转。其中,左或者右均为图像正放时,图像的左边或者右边。
步骤S223,对第二中间图像进行随机颜色调整,获得第三中间图像。
本实施例用于随机颜色抖动(Random Color Jitter),也就是用于在预设的调整范围内,随机地调整图像(第二图像)的明亮度(Brightness)、对比度(Contrast)或者饱和度(Saturation)。例如,明亮度的调整范围可以是-0.4~0.4倍,对比度的调整范围可以是-0.4~0.4倍,饱和度的调整范围可以是-0.4~0.4倍。
步骤S224,对第三中间图像进行随机亮度调整,获得调整亮度后的第一训练图像。
本实施例用于随机亮度(Random Lighting),即随机调整图像(第三中间图像)亮度(Lighting),并随机添加基于主成分分析(Principal Component Analysis,简称PCA)的图像噪声,例如,可以随机添加0.1倍的基于主成分分析的图像噪声。
步骤S225,根据第一训练图像进行标签标注获得所述第一训练样本。
本实施例中,所述多个初始样本构成数据集,本实施例用于对初始样本中的初始图像进行变换,也就是说,用于扩充已有的数据集,即数据扩充(Data Augmentation)。变换后的图像与变换前的图像的标签一致,也就是说,每个初始样本变化为第一训练样本后,只是图像进行相应的变换,其对应的标签并不发生变化。
本实施例中,在对初始图像进行变换时,步骤S221-步骤S224也可以按照其他顺序进行。
本实施例可以对初始训练样本的图像进行处理,可以增大用于训练的数据集,避免过拟合,进而增加通过第一训练样本训练得到的模型的泛化性。
本实施中,在步骤S225前,还可以对所述第一训练图像进行归一化和正则化,以获得新的第一训练图像。
以RGB图像为例,红绿蓝(RGB)通道的像素值范围为0~255,将每个像素的红绿蓝(RGB)通道中的各个通道除以255,从而将该通道的像素值转化为0~1之间的值。在将图像归一化后,将归一化后的每个通道的像素值分别减去该通道对应的均值,然后再除以该通道对应的标准差。其中,所述均值和所述标准差均是已知的图像数据集的均值和标准差。例如,公开图像数据集ImageNet中,红绿蓝(RGB)三个通道对应的均值的统计值分别为0.485、0.456、0.406,标准差的统计值对应为0.229、0.224、0.225。
在实际对图像进行变换的过程中,对图像进行变化的各个步骤的顺序可以交换。
请参照图6,本实施例中,可选地,步骤S120包括步骤S121-步骤S122。
步骤S121,利用多个第一训练样本对已训练的预分类模型进行再次训练。
具体地,将所述多个第一训练样本输入已训练的预分类模型,对该预分类模型进行再次训练。
步骤S122,对所述预分类模型进行调整,获得所述初始分类模型。
本实施例用于将第一训练样本输入已训练的预分类模型,从而对该预分类模型进行进一步的训练。也就是说,本实施中采用迁移学习获得初始分类模型。采用迁移学习的方式,利用第一训练样本在已训练的预分类模型上再次训练以获得初始分类模型。如此,便可以复用已有的参数,快速且精准地建立初始分类模型,也就是说,本实施例能够减少训练初始分类模型的时间。
请参照图7,本实施例中,可选地,所述预分类模型包括卷积层、池化层和全连接层,步骤S122包括步骤S1221-步骤S1222。
步骤S1221,对所述预分类模型的全连接层的参数进行调整,获得中间模型。
步骤S1222,对中间模型的参数进行调整,获得初始分类模型。
具体地,分别对中间模型的卷积层、池化层和全连接层的参数进行调整,获得初始分类模型。
本实施例中,中间模型为预分类模型的全连接层参数调整后转化得到的,也就是说,中间模型为调整后的预分类模型。先对全连接层的参数进行调整,能够对分类部分(全连接层)的最后一层重定义,将输出的类别改为所需预测的类别,如240个类别就是240维的输出值。在训练的过程中,先冻结(即停止训练)卷积部分(卷积层和池化层),只是训练分类部分;然后再将卷积部分和分类部分全部进行训练。也就是说,分类部分被重新定义,参数随机设置,需要优先训练,避免干扰卷积部分。这样,初始分类模型训练的时间较短,收敛较快,精度较高。
可选地,步骤S120前,所述方法还包括,将第二训练样本输入深度学习框架,获得所述预分类模型。
其中,所述第二训练样本包括第二图像以及与每个第二图像对应的第二预设数量个标签。
本实施例中,所采用的深度学习框架可以是各种各样的主流算法框架,例如VGG系列、ResNet系列、Inception系列和MobileNet系列等。本实施例中,采用MobileNet系列的深度学习框架具有高速的特点。此外,本实施例中,采用深度学习框架还具有以下优点:深度学习框架直接以底层的图像的像素点为基础,通过卷积的方式自动提取特征,不需要特征工程,可以避免算法设计的工作量较多,以及因挑选特征所导致的性能误差。深度学习框架逐层抽象数据的特征,底部数据通过逐层抽象,逐渐转变为高层特征,由像素转换为纹理,由纹理转换为局部,通过深度学习框架获得的模型可以学习图像的语义信息。深度学习框架的特征不断地执行非线性变换,通过卷积、池化、ReLU、Dropout和BN等操作,避免模型过拟合,提升泛化性,即使模型通过训练数据,可以更好地理解未知图像。深度学习框架的学习能力较强,针对于海量的大数据,可以通过增加模型复杂度的方式,完成对于数据的理解,在数据量级较大的情况下,深度学习框架更有优势。
请参照图8,本实施例中,可选地,所述方法还包括步骤S310。
步骤S310,将待识别图像输入所述目标分类模型,获得所述待识别图像对应的第一预设数量个标签。
本实施例中,待识别图像可以是大型基准数据集中的图像,例如ImageNet数据集中的图像。
以分类类别为240的目标分类模型为例,将待识别图像输入目标分类模型后,待识别图像的像素点首先会进行正则化处理,该处理过程可以参见前面对正则化过程的描述。
目标分类模型的最后一层会输出240维的向量,也就是与240个类别分别对应的概率所组成的向量。使用Sigmoid函数,将每一个类别的概率(浮点值)映射到0至1的区间,即概率值为0~1的浮点数,再使用阈值,例如0.5,将向量中的值映射后的值,大于阈值0.5的设置为1,小于阈值0.5的设置为0。其中,类别的名称与向量的位置一一对应,向量值为1的位置,则对应为图像所属的类别。多个1,则图像对应可以属于多个类别。也就是,输入一张图像,经过目标分类模型,输出多个类别。
请参照图9,本申请的另一目的在于提供一种分类模型训练装置110,所述装置包括获取模块111、训练模块112、计算模块113和调整模块114。所述分类模型训练装置110包括一个可以软件或固件的形式存储于所述存储器120中或固化在所述分类模型训练设备100的操作系统(operating system,OS)中的软件功能模块。
所述获取模块111用于获取多个第一训练样本,其中,每个第一训练样本包括第一训练图像以及与该第一训练图像对应的第一预设数量个标签,所述第一预设数量个标签包括与图像内容对应的上位分类的标签以及下位分类的标签;
本实施例中的获取模块111用于执行步骤S110,关于所述获取模块111的具体描述可参照对所述步骤S110的描述。
所述训练模块112用于根据多个所述第一训练样本进行机器学习训练,获得初始分类模型。
本实施例中的训练模块112用于执行步骤S120,关于所述训练模块112的具体描述可参照对所述步骤S120的描述。
所述计算模块113用于分别获取初始分类模型中每个标签的二值交叉熵作为对应的子误差值,以及
根据所述每个标签的子误差值计算所述初始分类模型的总误差值。
本实施例中的计算模块113用于执行步骤S130-步骤S140,关于所述计算模块113的具体描述可参照对所述步骤S130-步骤S140的描述。
所述调整模块114用于根据所述总误差值调整所述初始分类模型,获得目标分类模型。
本实施例中的调整模块114用于执行步骤S150,关于所述调整模块114的具体描述可参照对所述步骤S150的描述。
可选地,所述调整模块114用于根据所述总误差值调整所述初始分类模型,获得目标分类模型的步骤包括:
检测所述总误差值是否大于预设值。
若所述总误差值大于预设值,则调整所述初始分类模型的参数,直至所述总误差值小于所述预设值。
若所述总误差值小于所述预设值,则将调整参数后的初始分类模型作为目标分类模型。
综上所述,本申请实施例中,通过为每个图像设置多个标签来进行机器学习训练,获得初始分类模型,并分别获取每个标签的二值交叉熵,从而根据每个标签的二值交叉熵来计算初始分类模型的总误差值,进而根据初始分类模型的总误差值来调整该初始分类模型的参数。由于各个标签对应的子误差值是单独计算的,因此,各个标签对应的子误差值不会受到其他标签的影响,因此,同一图像可以同时标记为一上位分类对应的标签以及该上位分类对应的下位分类对应的标签,也就是说,训练出的目标分类模型的分类精度能够得到提高。
在本申请所提供的实施例中,应该理解到,所揭露的装置和方法,也可以通过其它的方式实现。以上所描述的装置实施例仅仅是示意性的.
以上所述,仅为本申请的具体实施方式,但本申请的保护范围并不局限于此,任何熟悉本技术领域的技术人员在本申请揭露的技术范围内,可轻易想到变化或替换,都应涵盖在本申请的保护范围之内。因此,本申请的保护范围应以所述权利要求的保护范围为准。
Claims (10)
1.一种分类模型训练方法,其特征在于,所述方法包括:
获取多个第一训练样本,其中,每个第一训练样本包括第一训练图像以及与该第一训练图像对应的第一预设数量个标签,所述第一预设数量个标签包括与图像内容对应的上位分类以及下位分类分别对应的标签;
根据多个所述第一训练样本进行机器学习训练,获得初始分类模型;
分别获取初始分类模型中每个标签的二值交叉熵作为对应的子误差值;
根据所述每个标签的子误差值计算获得所述初始分类模型的总误差值;
根据所述总误差值调整所述初始分类模型,获得目标分类模型。
2.根据权利要求1所述的分类模型训练方法,其特征在于,所述根据所述总误差值调整所述初始分类模型,获得目标分类模型的步骤包括:
检测所述总误差值是否大于预设值;
若所述总误差值大于预设值,则调整所述初始分类模型的参数,直至所述总误差值小于所述预设值;
若所述总误差值小于所述预设值,则将调整参数后的初始分类模型作为目标分类模型。
3.根据权利要求1所述的分类模型训练方法,其特征在于,所述获取多个第一训练样本的步骤前,所述方法还包括:
获取多个初始样本,每个所述初始样本包括初始图像以及与该初始图像对应的第一预设数量个标签,所述第一预设数量个标签包括与图像内容对应的上位分类的标签以及下位分类的标签;
针对每个初始样本,对该初始样本中的所述初始图像进行变换,获得多个所述第一训练样本。
4.根据权利要求3所述的分类模型训练方法,其特征在于,所述变换包括随机裁剪、随机翻转、随机颜色调整或随机亮度调整中的至少一种。
5.根据权利要求1所述的分类模型训练方法,其特征在于,所述根据多个所述第一训练样本进行机器学习训练,获得初始分类模型的步骤包括:
将所述多个第一训练样本输入已训练的预分类模型,对该预分类模型进行再次训练;
对所述预分类模型进行调整,获得所述初始分类模型。
6.根据权利要求5所述的分类模型训练方法,其特征在于,所述预分类模型包括卷积层、池化层和全连接层,所述对所述预分类模型进行调整,获得所述初始分类模型的步骤包括:
对所述预分类模型的全连接层的参数进行调整,获得中间模型;
分别对中间模型的卷积层、池化层和全连接层的参数进行调整,获得初始分类模型。
7.根据权利要求6所述的分类模型训练方法,其特征在于,所述根据多个所述第一训练样本进行机器学习训练,获得初始分类模型的步骤前,所述方法还包括:
将第二训练样本输入深度学习框架,获得所述预分类模型;
其中,所述第二训练样本包括第二训练图像以及与每个第二训练图像对应的第二预设数量个标签。
8.根据权利要求1所述的分类模型训练方法,其特征在于,所述方法还包括:
将待识别图像输入所述目标分类模型,获得所述待识别图像对应的第一预设数量个标签。
9.一种分类模型训练装置,其特征在于,所述装置包括获取模块、训练模块、计算模块和调整模块;
所述获取模块用于获取多个第一训练样本,其中,每个第一训练样本包括第一训练图像以及与该第一训练图像对应的第一预设数量个标签,所述第一预设数量个标签包括与图像内容对应的上位分类的标签以及下位分类的标签;
所述训练模块用于根据多个所述第一训练样本进行机器学习训练,获得初始分类模型;
所述计算模块用于分别获取初始分类模型中每个标签的二值交叉熵作为对应的子误差值,以及
根据所述每个标签的子误差值计算所述初始分类模型的总误差值;
所述调整模块用于根据所述总误差值调整所述初始分类模型,获得目标分类模型。
10.根据权利要求9所述的分类模型训练装置,其特征在于,所述调整模块用于根据所述总误差值调整所述初始分类模型,获得目标分类模型的步骤包括:
检测所述总误差值是否大于预设值;
若所述总误差值大于预设值,则调整所述初始分类模型的参数,直至所述总误差值小于所述预设值;
若所述总误差值小于所述预设值,则将调整参数后的初始分类模型作为目标分类模型。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN201910129385.4A CN109886335B (zh) | 2019-02-21 | 2019-02-21 | 分类模型训练方法及装置 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN201910129385.4A CN109886335B (zh) | 2019-02-21 | 2019-02-21 | 分类模型训练方法及装置 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN109886335A true CN109886335A (zh) | 2019-06-14 |
CN109886335B CN109886335B (zh) | 2021-11-26 |
Family
ID=66928678
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN201910129385.4A Active CN109886335B (zh) | 2019-02-21 | 2019-02-21 | 分类模型训练方法及装置 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN109886335B (zh) |
Cited By (11)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN110458245A (zh) * | 2019-08-20 | 2019-11-15 | 图谱未来(南京)人工智能研究院有限公司 | 一种多标签分类模型训练方法、数据处理方法及装置 |
CN110598869A (zh) * | 2019-08-27 | 2019-12-20 | 阿里巴巴集团控股有限公司 | 基于序列模型的分类方法、装置、电子设备 |
CN111046932A (zh) * | 2019-12-03 | 2020-04-21 | 内蒙古拜欧牧业科技有限公司 | 模型训练方法、肉类鉴别方法、装置、终端和存储介质 |
CN111160429A (zh) * | 2019-12-17 | 2020-05-15 | 平安银行股份有限公司 | 图像检测模型的训练方法、图像检测方法、装置及设备 |
CN111275107A (zh) * | 2020-01-20 | 2020-06-12 | 西安奥卡云数据科技有限公司 | 一种基于迁移学习的多标签场景图像分类方法及装置 |
CN111582409A (zh) * | 2020-06-29 | 2020-08-25 | 腾讯科技(深圳)有限公司 | 图像标签分类网络的训练方法、图像标签分类方法及设备 |
CN111639520A (zh) * | 2020-04-14 | 2020-09-08 | 北京迈格威科技有限公司 | 图像处理、模型训练方法、装置和电子设备 |
CN111652320A (zh) * | 2020-06-10 | 2020-09-11 | 创新奇智(上海)科技有限公司 | 一种样本分类方法、装置、电子设备及存储介质 |
CN112084861A (zh) * | 2020-08-06 | 2020-12-15 | 中国科学院空天信息创新研究院 | 模型训练方法、装置、电子设备和存储介质 |
CN112241452A (zh) * | 2020-10-16 | 2021-01-19 | 百度(中国)有限公司 | 一种模型训练方法、装置、电子设备及存储介质 |
CN113222043A (zh) * | 2021-05-25 | 2021-08-06 | 北京有竹居网络技术有限公司 | 一种图像分类方法、装置、设备及存储介质 |
Citations (14)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN104166706A (zh) * | 2014-08-08 | 2014-11-26 | 苏州大学 | 基于代价敏感主动学习的多标签分类器构建方法 |
CN104881689A (zh) * | 2015-06-17 | 2015-09-02 | 苏州大学张家港工业技术研究院 | 一种多标签主动学习分类方法及系统 |
CN105138973A (zh) * | 2015-08-11 | 2015-12-09 | 北京天诚盛业科技有限公司 | 人脸认证的方法和装置 |
CN105426908A (zh) * | 2015-11-09 | 2016-03-23 | 国网冀北电力有限公司信息通信分公司 | 一种基于卷积神经网络的变电站属性分类方法 |
CN105740402A (zh) * | 2016-01-28 | 2016-07-06 | 百度在线网络技术(北京)有限公司 | 数字图像的语义标签的获取方法及装置 |
CN105868773A (zh) * | 2016-03-23 | 2016-08-17 | 华南理工大学 | 一种基于层次随机森林的多标签分类方法 |
CN107004363A (zh) * | 2014-12-10 | 2017-08-01 | 三菱电机株式会社 | 图像处理装置及车载显示系统及显示装置及图像处理方法及图像处理程序 |
CN108319980A (zh) * | 2018-02-05 | 2018-07-24 | 哈工大机器人(合肥)国际创新研究院 | 一种基于gru的递归神经网络多标签学习方法 |
CN108664924A (zh) * | 2018-05-10 | 2018-10-16 | 东南大学 | 一种基于卷积神经网络的多标签物体识别方法 |
CN108776808A (zh) * | 2018-05-25 | 2018-11-09 | 北京百度网讯科技有限公司 | 一种用于检测钢包溶蚀缺陷的方法和装置 |
CN109196514A (zh) * | 2016-02-01 | 2019-01-11 | 西-奥特私人有限公司 | 图像分类和标记 |
CN109190482A (zh) * | 2018-08-06 | 2019-01-11 | 北京奇艺世纪科技有限公司 | 多标签视频分类方法及系统、系统训练方法及装置 |
CN109241835A (zh) * | 2018-07-27 | 2019-01-18 | 上海商汤智能科技有限公司 | 图像处理方法及装置、电子设备和存储介质 |
CN109325148A (zh) * | 2018-08-03 | 2019-02-12 | 百度在线网络技术(北京)有限公司 | 生成信息的方法和装置 |
-
2019
- 2019-02-21 CN CN201910129385.4A patent/CN109886335B/zh active Active
Patent Citations (14)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN104166706A (zh) * | 2014-08-08 | 2014-11-26 | 苏州大学 | 基于代价敏感主动学习的多标签分类器构建方法 |
CN107004363A (zh) * | 2014-12-10 | 2017-08-01 | 三菱电机株式会社 | 图像处理装置及车载显示系统及显示装置及图像处理方法及图像处理程序 |
CN104881689A (zh) * | 2015-06-17 | 2015-09-02 | 苏州大学张家港工业技术研究院 | 一种多标签主动学习分类方法及系统 |
CN105138973A (zh) * | 2015-08-11 | 2015-12-09 | 北京天诚盛业科技有限公司 | 人脸认证的方法和装置 |
CN105426908A (zh) * | 2015-11-09 | 2016-03-23 | 国网冀北电力有限公司信息通信分公司 | 一种基于卷积神经网络的变电站属性分类方法 |
CN105740402A (zh) * | 2016-01-28 | 2016-07-06 | 百度在线网络技术(北京)有限公司 | 数字图像的语义标签的获取方法及装置 |
CN109196514A (zh) * | 2016-02-01 | 2019-01-11 | 西-奥特私人有限公司 | 图像分类和标记 |
CN105868773A (zh) * | 2016-03-23 | 2016-08-17 | 华南理工大学 | 一种基于层次随机森林的多标签分类方法 |
CN108319980A (zh) * | 2018-02-05 | 2018-07-24 | 哈工大机器人(合肥)国际创新研究院 | 一种基于gru的递归神经网络多标签学习方法 |
CN108664924A (zh) * | 2018-05-10 | 2018-10-16 | 东南大学 | 一种基于卷积神经网络的多标签物体识别方法 |
CN108776808A (zh) * | 2018-05-25 | 2018-11-09 | 北京百度网讯科技有限公司 | 一种用于检测钢包溶蚀缺陷的方法和装置 |
CN109241835A (zh) * | 2018-07-27 | 2019-01-18 | 上海商汤智能科技有限公司 | 图像处理方法及装置、电子设备和存储介质 |
CN109325148A (zh) * | 2018-08-03 | 2019-02-12 | 百度在线网络技术(北京)有限公司 | 生成信息的方法和装置 |
CN109190482A (zh) * | 2018-08-06 | 2019-01-11 | 北京奇艺世纪科技有限公司 | 多标签视频分类方法及系统、系统训练方法及装置 |
Cited By (17)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN110458245A (zh) * | 2019-08-20 | 2019-11-15 | 图谱未来(南京)人工智能研究院有限公司 | 一种多标签分类模型训练方法、数据处理方法及装置 |
CN110598869A (zh) * | 2019-08-27 | 2019-12-20 | 阿里巴巴集团控股有限公司 | 基于序列模型的分类方法、装置、电子设备 |
CN110598869B (zh) * | 2019-08-27 | 2024-01-19 | 创新先进技术有限公司 | 基于序列模型的分类方法、装置、电子设备 |
CN111046932A (zh) * | 2019-12-03 | 2020-04-21 | 内蒙古拜欧牧业科技有限公司 | 模型训练方法、肉类鉴别方法、装置、终端和存储介质 |
CN111160429B (zh) * | 2019-12-17 | 2023-09-05 | 平安银行股份有限公司 | 图像检测模型的训练方法、图像检测方法、装置及设备 |
CN111160429A (zh) * | 2019-12-17 | 2020-05-15 | 平安银行股份有限公司 | 图像检测模型的训练方法、图像检测方法、装置及设备 |
CN111275107A (zh) * | 2020-01-20 | 2020-06-12 | 西安奥卡云数据科技有限公司 | 一种基于迁移学习的多标签场景图像分类方法及装置 |
CN111639520A (zh) * | 2020-04-14 | 2020-09-08 | 北京迈格威科技有限公司 | 图像处理、模型训练方法、装置和电子设备 |
CN111639520B (zh) * | 2020-04-14 | 2023-12-08 | 天津极豪科技有限公司 | 图像处理、模型训练方法、装置和电子设备 |
CN111652320A (zh) * | 2020-06-10 | 2020-09-11 | 创新奇智(上海)科技有限公司 | 一种样本分类方法、装置、电子设备及存储介质 |
CN111652320B (zh) * | 2020-06-10 | 2022-08-09 | 创新奇智(上海)科技有限公司 | 一种样本分类方法、装置、电子设备及存储介质 |
CN111582409B (zh) * | 2020-06-29 | 2023-12-26 | 腾讯科技(深圳)有限公司 | 图像标签分类网络的训练方法、图像标签分类方法及设备 |
CN111582409A (zh) * | 2020-06-29 | 2020-08-25 | 腾讯科技(深圳)有限公司 | 图像标签分类网络的训练方法、图像标签分类方法及设备 |
CN112084861A (zh) * | 2020-08-06 | 2020-12-15 | 中国科学院空天信息创新研究院 | 模型训练方法、装置、电子设备和存储介质 |
CN112241452A (zh) * | 2020-10-16 | 2021-01-19 | 百度(中国)有限公司 | 一种模型训练方法、装置、电子设备及存储介质 |
CN112241452B (zh) * | 2020-10-16 | 2024-01-05 | 百度(中国)有限公司 | 一种模型训练方法、装置、电子设备及存储介质 |
CN113222043A (zh) * | 2021-05-25 | 2021-08-06 | 北京有竹居网络技术有限公司 | 一种图像分类方法、装置、设备及存储介质 |
Also Published As
Publication number | Publication date |
---|---|
CN109886335B (zh) | 2021-11-26 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN109886335A (zh) | 分类模型训练方法及装置 | |
CN111027493B (zh) | 一种基于深度学习多网络软融合的行人检测方法 | |
CN109583425A (zh) | 一种基于深度学习的遥感图像船只集成识别方法 | |
CN110059586B (zh) | 一种基于空洞残差注意力结构的虹膜定位分割系统 | |
CN110390251A (zh) | 一种基于多神经网络模型融合处理的图像文字语义分割方法 | |
CN106650806A (zh) | 一种用于行人检测的协同式深度网络模型方法 | |
CN110399821B (zh) | 基于人脸表情识别的顾客满意度获取方法 | |
CN109711448A (zh) | 基于判别关键域和深度学习的植物图像细粒度分类方法 | |
CN104915972A (zh) | 图像处理装置、图像处理方法以及程序 | |
CN105825168B (zh) | 一种基于s-tld的川金丝猴面部检测和追踪方法 | |
CN114998220B (zh) | 一种基于改进的Tiny-YOLO v4自然环境下舌像检测定位方法 | |
CN109145964B (zh) | 一种实现图像颜色聚类的方法和系统 | |
CN111914797A (zh) | 基于多尺度轻量级卷积神经网络的交通标志识别方法 | |
CN111680739A (zh) | 一种目标检测和语义分割的多任务并行方法及系统 | |
CN111553414A (zh) | 一种基于改进Faster R-CNN的车内遗失物体检测方法 | |
CN111627080A (zh) | 基于卷积神经与条件生成对抗性网络的灰度图像上色方法 | |
CN110827265A (zh) | 基于深度学习的图片异常检测方法 | |
CN114548208A (zh) | 一种基于YOLOv5改进的植物种子实时分类检测方法 | |
CN114821022A (zh) | 融合主观逻辑和不确定性分布建模的可信目标检测方法 | |
CN104794726B (zh) | 一种水下图像并行分割方法及装置 | |
CN116434012A (zh) | 一种基于边缘感知的轻量型棉铃检测方法及系统 | |
CN112446417B (zh) | 基于多层超像素分割的纺锤形果实图像分割方法及系统 | |
CN113643297A (zh) | 一种基于神经网络的计算机辅助牙龄分析方法 | |
CN110119739A (zh) | 一种冰晶图片的自动分类方法 | |
CN114283083B (zh) | 一种基于解耦表示的场景生成模型的美学增强方法 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |