CN111382800B - 一种适用于样本分布不均衡的多标签多分类方法 - Google Patents

一种适用于样本分布不均衡的多标签多分类方法 Download PDF

Info

Publication number
CN111382800B
CN111382800B CN202010166042.8A CN202010166042A CN111382800B CN 111382800 B CN111382800 B CN 111382800B CN 202010166042 A CN202010166042 A CN 202010166042A CN 111382800 B CN111382800 B CN 111382800B
Authority
CN
China
Prior art keywords
label
hidden layer
comparison
labels
layer output
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN202010166042.8A
Other languages
English (en)
Other versions
CN111382800A (zh
Inventor
马祥祥
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Shanghai Eisoo Information Technology Co Ltd
Original Assignee
Shanghai Eisoo Information Technology Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Shanghai Eisoo Information Technology Co Ltd filed Critical Shanghai Eisoo Information Technology Co Ltd
Priority to CN202010166042.8A priority Critical patent/CN111382800B/zh
Publication of CN111382800A publication Critical patent/CN111382800A/zh
Application granted granted Critical
Publication of CN111382800B publication Critical patent/CN111382800B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • G06F18/241Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
    • G06F18/2415Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches based on parametric or probabilistic models, e.g. based on likelihood ratio or false acceptance rate versus a false rejection rate
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F16/00Information retrieval; Database structures therefor; File system structures therefor
    • G06F16/50Information retrieval; Database structures therefor; File system structures therefor of still image data
    • G06F16/55Clustering; Classification
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • Evolutionary Computation (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Evolutionary Biology (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Computational Linguistics (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Health & Medical Sciences (AREA)
  • General Health & Medical Sciences (AREA)
  • Molecular Biology (AREA)
  • Computing Systems (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Probability & Statistics with Applications (AREA)
  • Databases & Information Systems (AREA)
  • Image Analysis (AREA)

Abstract

本发明涉及一种适用于样本分布不均衡的多标签多分类方法,包括以下步骤:S1:构建并训练一个基于神经网络的多标签多分类模型,并设定比较对象;S2:利用训练完成的多标签多分类模型和训练样本,计算各个标签上,所有训练样本对应比较对象的平均值,作为比较平均值;S3:将待检测图片输入多标签多分类模型,得到在各标签上,该图片对应比较对象的值,作为比较值;S4:选取该图片比较值与比较平均值最接近的前N个标签,作为待检测图片的标签,完成对图片的多标签多分类,与现有技术相比,本发明具有实现简单、容易训练且适用性广等优点。

Description

一种适用于样本分布不均衡的多标签多分类方法
技术领域
本发明涉及深度学习的多标签多分类领域,尤其是涉及一种适用于样本分布不均衡的多标签多分类方法。
背景技术
利用深度学习做多标签多分类问题时一般会遇到样本数量不均衡的问题,比以flickr30K样本集做图像标签为例,该样本集共有30000张图片,每张图片有5个标签,所有的标签合并在一起后,共有30W条摘要,下表为包含狗、猫、海豚和大象的标签的样本数量:
类别 海豚 大象
数量 10619 308 40 94
显然相对于包含狗的样本,其它的样本数量少到可以忽略不计。而目前主流的基于深度学习的多分类方法中,最后一层的激活函数为sigmoid函数,该函数的值域为(0,1),其结果往往被用于作为某个标签的概率。模型的期望损失函数为:
Figure BDA0002407497410000011
其中,Ni为包含标签i的样本数量,Nj为不包含标签j的样本数量,Pi为模型预测标签i的平均概率,Pj为模型预测标签j的平均概率。
由于海豚标签的训练样本数量远小于狗标签的训练样本数量,就算海豚标签全部被预测错,只要把狗标签全部预测对,那么模型的损失也是很小的,因此用这样的样本训练出来的结果不可避免地会出现这样的问题:对于出现频率高的标签,模型预测的概率也相应偏高,这样的问题就是样本类别不均衡问题。
目前在单分类领域解决样本不均衡问题的方法为:在训练模型时增加携带低频率标签样本的训练次数,减少携带高频率标签样本的训练次数。虽然这样的做法在单分类任务中有一定的效果,但是在多分类任务中这样的办法基本上失效。在多分类任务中,每个样本都带有多个标签,因此向模型输入带有低频率标签样本的同时,也不可控制地把该样本携带的其它的标签也输入了模型。例如当多输入带有海豚标签的图片给模型时,这张图片中还包含的大象也会被输入模型,因此当海豚标签与狗标签这两种标签均衡时,大象的标签将出现偏多,同样造成样本分布不均衡问题。
发明内容
本发明的目的就是为了克服上述现有技术存在的缺陷而提供一种适用于样本分布不均衡的多标签多分类方法。
本发明的目的可以通过以下技术方案来实现:
一种适用于样本分布不均衡的多标签多分类方法,包括以下步骤:
S1:构建并训练一个基于神经网络的多标签多分类模型,并设定比较对象;
S2:利用训练完成的多标签多分类模型和训练样本,计算各个标签上,所有训练样本对应比较对象的平均值,作为比较平均值;
S3:将待检测图片输入多标签多分类模型,得到在各标签上,该图片对应比较对象的值,作为比较值;
S4:选取该图片比较值与比较平均值最接近的前N个标签,作为待检测图片的标签,完成对图片的多标签多分类。
所述的步骤S4中,选取该图片比较值与比较平均值最接近的前3个标签,作为待检测图片的标签。
所述的比较对象为隐藏层输出logits。
所述的多标签多分类模型首先对输入的图像进行特征提取,得到特征向量V,再通过线性变换得到隐藏层输出logits。
当比较对象为隐藏层输出logits时,所述的步骤S2-步骤S4具体包括:
A1:利用训练好的多标签多分类模型,计算所有训练样本上各标签对应的隐藏层输出logits之和,并求其平均值,记为隐藏层输出平均
Figure BDA0002407497410000021
A2:将待测图片输入多标签多分类模型,计算得到其在各标签上对应的隐藏层输出logits;
A3:计算隐藏层输出logits与隐藏层输出平均
Figure BDA0002407497410000022
的差值作为指标,选取该值最大的前N个标签作为该图片的标签,完成图片的多标签多分类。
所述的步骤A3中,选取隐藏层输出logits与隐藏层输出平均
Figure BDA0002407497410000023
的差值最大的前3个标签作为该图片的标签。
所述的比较对象为标签概率P、第一标签概率对数log(P)或第二标签概率对数ln(P)。
所述的隐藏层输出平均
Figure BDA0002407497410000031
通过将隐藏层输出平均
Figure BDA0002407497410000032
作为一个变量向量进行训练得到,该训练的目标函数表达式为:
Figure BDA0002407497410000033
与现有技术相比,本发明具有以下优点:
1)实现简单:只需要增加对各标签对应比较对象的平均值进行计算,无需对原分类模型进行修改,实现简单;
2)容易训练:训练过程中采用随机梯度下降算法优化,每一步从3万个训练样本中随机采样128个,只需要训练2000步,就能通过纵向比较算法提取出很准确的标签,模型训练2000步实际上只采样了256000次,相当于每个样本被采样了不到10次,而通常利用深度学习分类器,每个样本需要被采样上百次,节省了训练时间的同时也避免了过量训练引起的过拟合问题;
3)兼容标签频率均衡的训练集:当标签均衡时,各个标签的隐藏层输出平均
Figure BDA0002407497410000034
将趋于一致,其得到的结果与直接将隐藏层输出logits作为指标相同,同样能实现多标签分类,因此本发明的方法对于标签频率均衡的训练集也可使用,应用范围广。
附图说明
图1为本发明的流程示意图;
图2为多标签多分类模型计算隐藏层输出logits的过程示意图;
图3为使用本发明方法进行图片预测分类过程的示意图;
图4为实施例中用于测试的图片。
具体实施方式
下面结合附图和具体实施例对本发明进行详细说明。显然,所描述的实施例是本发明的一部分实施例,而不是全部实施例。基于本发明中的实施例,本领域普通技术人员在没有做出创造性劳动的前提下所获得的所有其他实施例,都应属于本发明保护的范围。
实施例
如图1所示,本发明一种适用于样本分布不均衡的多标签多分类方法,包括以下步骤:
步骤1:构建并训练一个基于神经网络的多标签多分类模型。如图2所示,该多标签多分类模型首先对输入的样本进行特征提取得到特征向量V,再利用公式:logits=W*V+B(其中B为N维向量),将特征向量V线性变换转换成N维向量,得到隐藏层输出logits,最后对隐藏层输出logits利用sigmoid函数进行激活,转换成样本属于每个标签的概率分布。
步骤2:利用训练好的多标签多分类模型,计算所有训练样本上各标签对应的隐藏层输出logits之和,并求其平均值,记为隐藏层输出平均
Figure BDA0002407497410000041
本实施例的具体实施过程中,根据“平均值所在的点相对于所有样本点的平均方差最小”这一定理,在代码实现中可以将隐藏层输出平均
Figure BDA0002407497410000042
作为一个变量向量进行训练,训练的目标为:
Figure BDA0002407497410000043
步骤3:如图3所示,将待预测的图片输入多标签多分类模型,计算得到其在各标签上对应的隐藏层输出logits。
步骤4:计算
Figure BDA0002407497410000044
的值,并选取该值最大的前N个标签作为该图片的标签,完成图片的多标签多分类。
下表为同样使用flickr30K样本集进行训练后,对图4所示的图片进行分类后得到的结果。利用传统方法选用logits作为指标和使用本发明方法的
Figure BDA0002407497410000045
作为指标,分别选取前3个得分最高的标签作为图4的标签,使用本发明方法得到的结果是“海豚”、“跃出”和“水花”,而使用传统方法得到的结果为“黑狗”、“水花”和“水面”,由该结果可知,由于flickr30K样本集中包含狗的样本过多,包含海豚的样本过少,因此使用传统方法选用logits作为指标进行分类时,出现了样本分布不均衡问题,而使用本发明方法得到的分类结果有效避免了该问题。
Figure BDA0002407497410000046
本发明方法的原理为:采用纵向比较的方法代替传统的横向比较,实现图片标签的提取,即在每一类标签内部进行的比较,例如对于训练集中的训练样本,模型预测存在海豚的概率平均值为0.0001,但对于某张待预测的海豚图片,模型预测存在海豚的概率为0.2,那么虽然这个概率也很小,却比平均概率高出2000倍;而对于训练集中的训练样本,模型预测存在狗的概率平均值为0.6,对于同样一张待预测的海豚图片,模型预测存在狗的概率为0.5,那么及时这个概率大于存在海豚的概率,其仍然比存在狗的概率平均值低,这样便可实现对图片标签的正确提取和分类,而不受样本分布不均衡的影响。
因此本发明通过利用多标签多分类模型在某个标签上对设定比较对象的预测值,与所有训练样本在该标签上对该比较对象的平均值进行比较,来实现图像多标签分类。其中,比较对象可以是各标签的概率P,也可以是隐藏层输出logits、概率对数log(P)或者其他可行的对象,本实施例中选用隐藏层输出logits与其平均值
Figure BDA0002407497410000051
之间的差值作为指标。具体原理如下:
sigmoid函数的形式为:
Figure BDA0002407497410000052
其具有如下性质:单调递增,即某标签的logits值越大,则样本属于对应标签的概率也越大;值域为(0,1),与概率的值域一致;当x小于0时,sigmoid函数无限趋近0,此时误差按指数衰减;当x大于0时,sigmoid函数无限趋近1,此时误差也按指数衰减。
由于logits在小于0时趋近于log(p),表征了概率P的数量级;在大于0时近似于-log(1-p)表征了趋近于1的数量级,因此隐藏层输出
Figure BDA0002407497410000053
在概率极小时表示可能性相差的数量级,在概率极大时表示不可能性相差的数量级。
本发明优选采用隐藏层输出
Figure BDA0002407497410000054
与概率对数
Figure BDA0002407497410000055
作为评价指标,下面是根据热力学与统计物理中的知识给出的解释:
根据玻尔兹曼分布公式可以看到logits是一个与能量成正相关的值,
Figure BDA0002407497410000056
实际上对应了平均的能量,平均能量是有物理意义的。而概率p的平均值并没有什么物理意义,但是概率的对数ln(p)的平均值就是物理量熵,具有物理意义。在恒温的热力学过程中熵的增量与热能的增量也是成正比的。因此采用
Figure BDA0002407497410000057
Figure BDA0002407497410000058
在某种意义上是等价的。
以上所述,仅为本发明的具体实施方式,但本发明的保护范围并不局限于此,任何熟悉本技术领域的工作人员在本发明揭露的技术范围内,可轻易想到各种等效的修改或替换,这些修改或替换都应涵盖在本发明的保护范围之内。因此,本发明的保护范围应以权利要求的保护范围为准。

Claims (5)

1.一种适用于样本分布不均衡的多标签多分类方法,其特征在于,包括以下步骤:
S1:构建并训练一个基于神经网络的多标签多分类模型,并设定比较对象;
S2:利用训练完成的多标签多分类模型和训练样本,计算各个标签上,所有训练样本对应比较对象的平均值,作为比较平均值;
S3:将待检测图片输入多标签多分类模型,得到在各标签上,该图片对应比较对象的值,作为比较值;
S4:选取该图片比较值与比较平均值最接近的前N个标签,作为待检测图片的标签,完成对图片的多标签多分类;
所述的比较对象为隐藏层输出logits,所述的多标签多分类模型首先对输入的图像进行特征提取,得到特征向量V,再通过线性变换得到隐藏层输出logits;
当比较对象为隐藏层输出logits时,所述的步骤S2-步骤S4具体包括:
A1:利用训练好的多标签多分类模型,计算所有训练样本上各标签对应的隐藏层输出logits之和,并求其平均值,记为隐藏层输出平均
Figure FDA0003832974610000011
A2:将待测图片输入多标签多分类模型,计算得到其在各标签上对应的隐藏层输出logits;
A3:计算隐藏层输出logits与隐藏层输出平均
Figure FDA0003832974610000012
的差值作为指标,选取该值最大的前N个标签作为该图片的标签,完成图片的多标签多分类。
2.根据权利要求1所述的一种适用于样本分布不均衡的多标签多分类方法,其特征在于,所述的比较对象为标签概率P、第一标签概率对数log(P)或第二标签概率对数ln(P)。
3.根据权利要求1所述的一种适用于样本分布不均衡的多标签多分类方法,其特征在于,所述的隐藏层输出平均
Figure FDA0003832974610000013
通过将隐藏层输出平均
Figure FDA0003832974610000014
作为一个变量向量进行训练得到,该训练的目标函数表达式为:
Figure FDA0003832974610000015
4.根据权利要求1所述的一种适用于样本分布不均衡的多标签多分类方法,其特征在于,所述的步骤S4中,选取该图片比较值与比较平均值最接近的前3个标签,作为待检测图片的标签。
5.根据权利要求1所述的一种适用于样本分布不均衡的多标签多分类方法,其特征在于,所述的步骤A3中,选取隐藏层输出logits与隐藏层输出平均
Figure FDA0003832974610000021
的差值最大的前3个标签作为该图片的标签。
CN202010166042.8A 2020-03-11 2020-03-11 一种适用于样本分布不均衡的多标签多分类方法 Active CN111382800B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202010166042.8A CN111382800B (zh) 2020-03-11 2020-03-11 一种适用于样本分布不均衡的多标签多分类方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202010166042.8A CN111382800B (zh) 2020-03-11 2020-03-11 一种适用于样本分布不均衡的多标签多分类方法

Publications (2)

Publication Number Publication Date
CN111382800A CN111382800A (zh) 2020-07-07
CN111382800B true CN111382800B (zh) 2022-11-25

Family

ID=71222693

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202010166042.8A Active CN111382800B (zh) 2020-03-11 2020-03-11 一种适用于样本分布不均衡的多标签多分类方法

Country Status (1)

Country Link
CN (1) CN111382800B (zh)

Citations (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN108133240A (zh) * 2018-01-31 2018-06-08 湖北工业大学 一种基于烟花算法的多标签分类方法及系统
CN109934299A (zh) * 2019-03-20 2019-06-25 中国科学技术大学 一种考虑了不均衡查询代价的多标签主动学习方法
CN110210515A (zh) * 2019-04-25 2019-09-06 浙江大学 一种图像数据多标签分类方法
CN110516098A (zh) * 2019-08-26 2019-11-29 苏州大学 基于卷积神经网络及二进制编码特征的图像标注方法

Patent Citations (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN108133240A (zh) * 2018-01-31 2018-06-08 湖北工业大学 一种基于烟花算法的多标签分类方法及系统
CN109934299A (zh) * 2019-03-20 2019-06-25 中国科学技术大学 一种考虑了不均衡查询代价的多标签主动学习方法
CN110210515A (zh) * 2019-04-25 2019-09-06 浙江大学 一种图像数据多标签分类方法
CN110516098A (zh) * 2019-08-26 2019-11-29 苏州大学 基于卷积神经网络及二进制编码特征的图像标注方法

Non-Patent Citations (2)

* Cited by examiner, † Cited by third party
Title
Improving Pairwise Ranking for Multi-label Image Classification;Yuncheng Li et.al;《arXiv:1704.03135v3 [cs.CV]》;20170601;第1-9页 *
基于迁移学习与多标签平滑策略的图像自动标注;汪鹏 等;《计算机应用》;20181110;第38卷(第11期);第3199-3203页 *

Also Published As

Publication number Publication date
CN111382800A (zh) 2020-07-07

Similar Documents

Publication Publication Date Title
CN110309331B (zh) 一种基于自监督的跨模态深度哈希检索方法
Cao et al. Heteroskedastic and imbalanced deep learning with adaptive regularization
CN105608471B (zh) 一种鲁棒直推式标签估计及数据分类方法和系统
WO2019210695A1 (zh) 模型训练和业务推荐
CN107683469A (zh) 一种基于深度学习的产品分类方法及装置
CN110046634B (zh) 聚类结果的解释方法和装置
CN111259140B (zh) 一种基于lstm多实体特征融合的虚假评论检测方法
CN107943856A (zh) 一种基于扩充标记样本的文本分类方法及系统
CN105184298A (zh) 一种快速局部约束低秩编码的图像分类方法
CN105354595A (zh) 一种鲁棒视觉图像分类方法及系统
CN113723492B (zh) 一种改进主动深度学习的高光谱图像半监督分类方法及装置
CN110598753A (zh) 一种基于主动学习的缺陷识别方法
CN113268675B (zh) 一种基于图注意力网络的社交媒体谣言检测方法和系统
CN114298851A (zh) 基于图表征学习的网络用户社交行为分析方法、装置及存储介质
CN108596204B (zh) 一种基于改进型scdae的半监督调制方式分类模型的方法
Padhy et al. Image classification in artificial neural network using fractal dimension
CN111382800B (zh) 一种适用于样本分布不均衡的多标签多分类方法
CN112836007A (zh) 一种基于语境化注意力网络的关系元学习方法
CN110263808B (zh) 一种基于lstm网络和注意力机制的图像情感分类方法
CN111209813A (zh) 基于迁移学习的遥感图像语义分割方法
CN116541704A (zh) 一种多类噪声分离的偏标记学习方法
CN112949590B (zh) 一种跨域行人重识别模型构建方法及构建系统
CN114882279A (zh) 基于直推式半监督深度学习的多标签图像分类方法
CN113837220A (zh) 基于在线持续学习的机器人目标识别方法、系统及设备
CN110647630A (zh) 检测同款商品的方法及装置

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant