CN111382800B - 一种适用于样本分布不均衡的多标签多分类方法 - Google Patents
一种适用于样本分布不均衡的多标签多分类方法 Download PDFInfo
- Publication number
- CN111382800B CN111382800B CN202010166042.8A CN202010166042A CN111382800B CN 111382800 B CN111382800 B CN 111382800B CN 202010166042 A CN202010166042 A CN 202010166042A CN 111382800 B CN111382800 B CN 111382800B
- Authority
- CN
- China
- Prior art keywords
- label
- hidden layer
- comparison
- labels
- layer output
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
- G06F18/241—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
- G06F18/2415—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches based on parametric or probabilistic models, e.g. based on likelihood ratio or false acceptance rate versus a false rejection rate
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F16/00—Information retrieval; Database structures therefor; File system structures therefor
- G06F16/50—Information retrieval; Database structures therefor; File system structures therefor of still image data
- G06F16/55—Clustering; Classification
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- General Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- Life Sciences & Earth Sciences (AREA)
- Artificial Intelligence (AREA)
- Evolutionary Computation (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Evolutionary Biology (AREA)
- Bioinformatics & Computational Biology (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Computational Linguistics (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Health & Medical Sciences (AREA)
- General Health & Medical Sciences (AREA)
- Molecular Biology (AREA)
- Computing Systems (AREA)
- Mathematical Physics (AREA)
- Software Systems (AREA)
- Probability & Statistics with Applications (AREA)
- Databases & Information Systems (AREA)
- Image Analysis (AREA)
Abstract
本发明涉及一种适用于样本分布不均衡的多标签多分类方法,包括以下步骤:S1:构建并训练一个基于神经网络的多标签多分类模型,并设定比较对象;S2:利用训练完成的多标签多分类模型和训练样本,计算各个标签上,所有训练样本对应比较对象的平均值,作为比较平均值;S3:将待检测图片输入多标签多分类模型,得到在各标签上,该图片对应比较对象的值,作为比较值;S4:选取该图片比较值与比较平均值最接近的前N个标签,作为待检测图片的标签,完成对图片的多标签多分类,与现有技术相比,本发明具有实现简单、容易训练且适用性广等优点。
Description
技术领域
本发明涉及深度学习的多标签多分类领域,尤其是涉及一种适用于样本分布不均衡的多标签多分类方法。
背景技术
利用深度学习做多标签多分类问题时一般会遇到样本数量不均衡的问题,比以flickr30K样本集做图像标签为例,该样本集共有30000张图片,每张图片有5个标签,所有的标签合并在一起后,共有30W条摘要,下表为包含狗、猫、海豚和大象的标签的样本数量:
类别 | 狗 | 猫 | 海豚 | 大象 |
数量 | 10619 | 308 | 40 | 94 |
显然相对于包含狗的样本,其它的样本数量少到可以忽略不计。而目前主流的基于深度学习的多分类方法中,最后一层的激活函数为sigmoid函数,该函数的值域为(0,1),其结果往往被用于作为某个标签的概率。模型的期望损失函数为:
其中,Ni为包含标签i的样本数量,Nj为不包含标签j的样本数量,Pi为模型预测标签i的平均概率,Pj为模型预测标签j的平均概率。
由于海豚标签的训练样本数量远小于狗标签的训练样本数量,就算海豚标签全部被预测错,只要把狗标签全部预测对,那么模型的损失也是很小的,因此用这样的样本训练出来的结果不可避免地会出现这样的问题:对于出现频率高的标签,模型预测的概率也相应偏高,这样的问题就是样本类别不均衡问题。
目前在单分类领域解决样本不均衡问题的方法为:在训练模型时增加携带低频率标签样本的训练次数,减少携带高频率标签样本的训练次数。虽然这样的做法在单分类任务中有一定的效果,但是在多分类任务中这样的办法基本上失效。在多分类任务中,每个样本都带有多个标签,因此向模型输入带有低频率标签样本的同时,也不可控制地把该样本携带的其它的标签也输入了模型。例如当多输入带有海豚标签的图片给模型时,这张图片中还包含的大象也会被输入模型,因此当海豚标签与狗标签这两种标签均衡时,大象的标签将出现偏多,同样造成样本分布不均衡问题。
发明内容
本发明的目的就是为了克服上述现有技术存在的缺陷而提供一种适用于样本分布不均衡的多标签多分类方法。
本发明的目的可以通过以下技术方案来实现:
一种适用于样本分布不均衡的多标签多分类方法,包括以下步骤:
S1:构建并训练一个基于神经网络的多标签多分类模型,并设定比较对象;
S2:利用训练完成的多标签多分类模型和训练样本,计算各个标签上,所有训练样本对应比较对象的平均值,作为比较平均值;
S3:将待检测图片输入多标签多分类模型,得到在各标签上,该图片对应比较对象的值,作为比较值;
S4:选取该图片比较值与比较平均值最接近的前N个标签,作为待检测图片的标签,完成对图片的多标签多分类。
所述的步骤S4中,选取该图片比较值与比较平均值最接近的前3个标签,作为待检测图片的标签。
所述的比较对象为隐藏层输出logits。
所述的多标签多分类模型首先对输入的图像进行特征提取,得到特征向量V,再通过线性变换得到隐藏层输出logits。
当比较对象为隐藏层输出logits时,所述的步骤S2-步骤S4具体包括:
A2:将待测图片输入多标签多分类模型,计算得到其在各标签上对应的隐藏层输出logits;
所述的比较对象为标签概率P、第一标签概率对数log(P)或第二标签概率对数ln(P)。
与现有技术相比,本发明具有以下优点:
1)实现简单:只需要增加对各标签对应比较对象的平均值进行计算,无需对原分类模型进行修改,实现简单;
2)容易训练:训练过程中采用随机梯度下降算法优化,每一步从3万个训练样本中随机采样128个,只需要训练2000步,就能通过纵向比较算法提取出很准确的标签,模型训练2000步实际上只采样了256000次,相当于每个样本被采样了不到10次,而通常利用深度学习分类器,每个样本需要被采样上百次,节省了训练时间的同时也避免了过量训练引起的过拟合问题;
附图说明
图1为本发明的流程示意图;
图2为多标签多分类模型计算隐藏层输出logits的过程示意图;
图3为使用本发明方法进行图片预测分类过程的示意图;
图4为实施例中用于测试的图片。
具体实施方式
下面结合附图和具体实施例对本发明进行详细说明。显然,所描述的实施例是本发明的一部分实施例,而不是全部实施例。基于本发明中的实施例,本领域普通技术人员在没有做出创造性劳动的前提下所获得的所有其他实施例,都应属于本发明保护的范围。
实施例
如图1所示,本发明一种适用于样本分布不均衡的多标签多分类方法,包括以下步骤:
步骤1:构建并训练一个基于神经网络的多标签多分类模型。如图2所示,该多标签多分类模型首先对输入的样本进行特征提取得到特征向量V,再利用公式:logits=W*V+B(其中B为N维向量),将特征向量V线性变换转换成N维向量,得到隐藏层输出logits,最后对隐藏层输出logits利用sigmoid函数进行激活,转换成样本属于每个标签的概率分布。
步骤3:如图3所示,将待预测的图片输入多标签多分类模型,计算得到其在各标签上对应的隐藏层输出logits。
下表为同样使用flickr30K样本集进行训练后,对图4所示的图片进行分类后得到的结果。利用传统方法选用logits作为指标和使用本发明方法的作为指标,分别选取前3个得分最高的标签作为图4的标签,使用本发明方法得到的结果是“海豚”、“跃出”和“水花”,而使用传统方法得到的结果为“黑狗”、“水花”和“水面”,由该结果可知,由于flickr30K样本集中包含狗的样本过多,包含海豚的样本过少,因此使用传统方法选用logits作为指标进行分类时,出现了样本分布不均衡问题,而使用本发明方法得到的分类结果有效避免了该问题。
本发明方法的原理为:采用纵向比较的方法代替传统的横向比较,实现图片标签的提取,即在每一类标签内部进行的比较,例如对于训练集中的训练样本,模型预测存在海豚的概率平均值为0.0001,但对于某张待预测的海豚图片,模型预测存在海豚的概率为0.2,那么虽然这个概率也很小,却比平均概率高出2000倍;而对于训练集中的训练样本,模型预测存在狗的概率平均值为0.6,对于同样一张待预测的海豚图片,模型预测存在狗的概率为0.5,那么及时这个概率大于存在海豚的概率,其仍然比存在狗的概率平均值低,这样便可实现对图片标签的正确提取和分类,而不受样本分布不均衡的影响。
因此本发明通过利用多标签多分类模型在某个标签上对设定比较对象的预测值,与所有训练样本在该标签上对该比较对象的平均值进行比较,来实现图像多标签分类。其中,比较对象可以是各标签的概率P,也可以是隐藏层输出logits、概率对数log(P)或者其他可行的对象,本实施例中选用隐藏层输出logits与其平均值之间的差值作为指标。具体原理如下:
sigmoid函数的形式为:
其具有如下性质:单调递增,即某标签的logits值越大,则样本属于对应标签的概率也越大;值域为(0,1),与概率的值域一致;当x小于0时,sigmoid函数无限趋近0,此时误差按指数衰减;当x大于0时,sigmoid函数无限趋近1,此时误差也按指数衰减。
由于logits在小于0时趋近于log(p),表征了概率P的数量级;在大于0时近似于-log(1-p)表征了趋近于1的数量级,因此隐藏层输出在概率极小时表示可能性相差的数量级,在概率极大时表示不可能性相差的数量级。
根据玻尔兹曼分布公式可以看到logits是一个与能量成正相关的值,实际上对应了平均的能量,平均能量是有物理意义的。而概率p的平均值并没有什么物理意义,但是概率的对数ln(p)的平均值就是物理量熵,具有物理意义。在恒温的热力学过程中熵的增量与热能的增量也是成正比的。因此采用与在某种意义上是等价的。
以上所述,仅为本发明的具体实施方式,但本发明的保护范围并不局限于此,任何熟悉本技术领域的工作人员在本发明揭露的技术范围内,可轻易想到各种等效的修改或替换,这些修改或替换都应涵盖在本发明的保护范围之内。因此,本发明的保护范围应以权利要求的保护范围为准。
Claims (5)
1.一种适用于样本分布不均衡的多标签多分类方法,其特征在于,包括以下步骤:
S1:构建并训练一个基于神经网络的多标签多分类模型,并设定比较对象;
S2:利用训练完成的多标签多分类模型和训练样本,计算各个标签上,所有训练样本对应比较对象的平均值,作为比较平均值;
S3:将待检测图片输入多标签多分类模型,得到在各标签上,该图片对应比较对象的值,作为比较值;
S4:选取该图片比较值与比较平均值最接近的前N个标签,作为待检测图片的标签,完成对图片的多标签多分类;
所述的比较对象为隐藏层输出logits,所述的多标签多分类模型首先对输入的图像进行特征提取,得到特征向量V,再通过线性变换得到隐藏层输出logits;
当比较对象为隐藏层输出logits时,所述的步骤S2-步骤S4具体包括:
A2:将待测图片输入多标签多分类模型,计算得到其在各标签上对应的隐藏层输出logits;
2.根据权利要求1所述的一种适用于样本分布不均衡的多标签多分类方法,其特征在于,所述的比较对象为标签概率P、第一标签概率对数log(P)或第二标签概率对数ln(P)。
4.根据权利要求1所述的一种适用于样本分布不均衡的多标签多分类方法,其特征在于,所述的步骤S4中,选取该图片比较值与比较平均值最接近的前3个标签,作为待检测图片的标签。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202010166042.8A CN111382800B (zh) | 2020-03-11 | 2020-03-11 | 一种适用于样本分布不均衡的多标签多分类方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202010166042.8A CN111382800B (zh) | 2020-03-11 | 2020-03-11 | 一种适用于样本分布不均衡的多标签多分类方法 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN111382800A CN111382800A (zh) | 2020-07-07 |
CN111382800B true CN111382800B (zh) | 2022-11-25 |
Family
ID=71222693
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202010166042.8A Active CN111382800B (zh) | 2020-03-11 | 2020-03-11 | 一种适用于样本分布不均衡的多标签多分类方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN111382800B (zh) |
Citations (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN108133240A (zh) * | 2018-01-31 | 2018-06-08 | 湖北工业大学 | 一种基于烟花算法的多标签分类方法及系统 |
CN109934299A (zh) * | 2019-03-20 | 2019-06-25 | 中国科学技术大学 | 一种考虑了不均衡查询代价的多标签主动学习方法 |
CN110210515A (zh) * | 2019-04-25 | 2019-09-06 | 浙江大学 | 一种图像数据多标签分类方法 |
CN110516098A (zh) * | 2019-08-26 | 2019-11-29 | 苏州大学 | 基于卷积神经网络及二进制编码特征的图像标注方法 |
-
2020
- 2020-03-11 CN CN202010166042.8A patent/CN111382800B/zh active Active
Patent Citations (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN108133240A (zh) * | 2018-01-31 | 2018-06-08 | 湖北工业大学 | 一种基于烟花算法的多标签分类方法及系统 |
CN109934299A (zh) * | 2019-03-20 | 2019-06-25 | 中国科学技术大学 | 一种考虑了不均衡查询代价的多标签主动学习方法 |
CN110210515A (zh) * | 2019-04-25 | 2019-09-06 | 浙江大学 | 一种图像数据多标签分类方法 |
CN110516098A (zh) * | 2019-08-26 | 2019-11-29 | 苏州大学 | 基于卷积神经网络及二进制编码特征的图像标注方法 |
Non-Patent Citations (2)
Title |
---|
Improving Pairwise Ranking for Multi-label Image Classification;Yuncheng Li et.al;《arXiv:1704.03135v3 [cs.CV]》;20170601;第1-9页 * |
基于迁移学习与多标签平滑策略的图像自动标注;汪鹏 等;《计算机应用》;20181110;第38卷(第11期);第3199-3203页 * |
Also Published As
Publication number | Publication date |
---|---|
CN111382800A (zh) | 2020-07-07 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN110309331B (zh) | 一种基于自监督的跨模态深度哈希检索方法 | |
Cao et al. | Heteroskedastic and imbalanced deep learning with adaptive regularization | |
CN105608471B (zh) | 一种鲁棒直推式标签估计及数据分类方法和系统 | |
WO2019210695A1 (zh) | 模型训练和业务推荐 | |
CN107683469A (zh) | 一种基于深度学习的产品分类方法及装置 | |
CN110046634B (zh) | 聚类结果的解释方法和装置 | |
CN111259140B (zh) | 一种基于lstm多实体特征融合的虚假评论检测方法 | |
CN107943856A (zh) | 一种基于扩充标记样本的文本分类方法及系统 | |
CN105184298A (zh) | 一种快速局部约束低秩编码的图像分类方法 | |
CN105354595A (zh) | 一种鲁棒视觉图像分类方法及系统 | |
CN113723492B (zh) | 一种改进主动深度学习的高光谱图像半监督分类方法及装置 | |
CN110598753A (zh) | 一种基于主动学习的缺陷识别方法 | |
CN113268675B (zh) | 一种基于图注意力网络的社交媒体谣言检测方法和系统 | |
CN114298851A (zh) | 基于图表征学习的网络用户社交行为分析方法、装置及存储介质 | |
CN108596204B (zh) | 一种基于改进型scdae的半监督调制方式分类模型的方法 | |
Padhy et al. | Image classification in artificial neural network using fractal dimension | |
CN111382800B (zh) | 一种适用于样本分布不均衡的多标签多分类方法 | |
CN112836007A (zh) | 一种基于语境化注意力网络的关系元学习方法 | |
CN110263808B (zh) | 一种基于lstm网络和注意力机制的图像情感分类方法 | |
CN111209813A (zh) | 基于迁移学习的遥感图像语义分割方法 | |
CN116541704A (zh) | 一种多类噪声分离的偏标记学习方法 | |
CN112949590B (zh) | 一种跨域行人重识别模型构建方法及构建系统 | |
CN114882279A (zh) | 基于直推式半监督深度学习的多标签图像分类方法 | |
CN113837220A (zh) | 基于在线持续学习的机器人目标识别方法、系统及设备 | |
CN110647630A (zh) | 检测同款商品的方法及装置 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |