CN116681128A - 一种带噪多标签数据的神经网络模型训练方法和装置 - Google Patents
一种带噪多标签数据的神经网络模型训练方法和装置 Download PDFInfo
- Publication number
- CN116681128A CN116681128A CN202310509397.6A CN202310509397A CN116681128A CN 116681128 A CN116681128 A CN 116681128A CN 202310509397 A CN202310509397 A CN 202310509397A CN 116681128 A CN116681128 A CN 116681128A
- Authority
- CN
- China
- Prior art keywords
- label
- transfer matrix
- learning
- class
- noise transfer
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000000034 method Methods 0.000 title claims abstract description 60
- 238000012549 training Methods 0.000 title claims abstract description 42
- 238000003062 neural network model Methods 0.000 title claims abstract description 18
- 238000012546 transfer Methods 0.000 claims abstract description 89
- 239000011159 matrix material Substances 0.000 claims abstract description 88
- 230000001419 dependent effect Effects 0.000 claims abstract description 54
- 238000013528 artificial neural network Methods 0.000 claims abstract description 22
- 238000005457 optimization Methods 0.000 claims abstract description 14
- 238000004590 computer program Methods 0.000 claims description 6
- 239000000203 mixture Substances 0.000 claims description 6
- 238000002372 labelling Methods 0.000 description 8
- 230000006870 function Effects 0.000 description 6
- 230000000694 effects Effects 0.000 description 4
- 230000008878 coupling Effects 0.000 description 3
- 238000010168 coupling process Methods 0.000 description 3
- 238000005859 coupling reaction Methods 0.000 description 3
- 238000013135 deep learning Methods 0.000 description 3
- 238000010586 diagram Methods 0.000 description 3
- 241000251468 Actinopterygii Species 0.000 description 2
- 238000013459 approach Methods 0.000 description 2
- 238000013461 design Methods 0.000 description 2
- 238000000605 extraction Methods 0.000 description 2
- 239000004973 liquid crystal related substance Substances 0.000 description 2
- 238000013507 mapping Methods 0.000 description 2
- 230000003287 optical effect Effects 0.000 description 2
- 238000012545 processing Methods 0.000 description 2
- 230000009286 beneficial effect Effects 0.000 description 1
- 238000000354 decomposition reaction Methods 0.000 description 1
- 238000001514 detection method Methods 0.000 description 1
- 238000011161 development Methods 0.000 description 1
- 238000004821 distillation Methods 0.000 description 1
- 239000000284 extract Substances 0.000 description 1
- 230000002427 irreversible effect Effects 0.000 description 1
- 230000003446 memory effect Effects 0.000 description 1
- 230000001575 pathological effect Effects 0.000 description 1
- 230000008092 positive effect Effects 0.000 description 1
- 238000005070 sampling Methods 0.000 description 1
- 238000012216 screening Methods 0.000 description 1
- 238000012360 testing method Methods 0.000 description 1
- 238000010200 validation analysis Methods 0.000 description 1
- 230000000007 visual effect Effects 0.000 description 1
- XLYOFNOQVPJJNP-UHFFFAOYSA-N water Substances O XLYOFNOQVPJJNP-UHFFFAOYSA-N 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/0985—Hyperparameter optimisation; Meta-learning; Learning-to-learn
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/0464—Convolutional networks [CNN, ConvNet]
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/764—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/82—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02T—CLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
- Y02T90/00—Enabling technologies or technologies with a potential or indirect contribution to GHG emissions mitigation
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Evolutionary Computation (AREA)
- General Health & Medical Sciences (AREA)
- Health & Medical Sciences (AREA)
- Artificial Intelligence (AREA)
- Software Systems (AREA)
- General Physics & Mathematics (AREA)
- Computing Systems (AREA)
- Biomedical Technology (AREA)
- Molecular Biology (AREA)
- Data Mining & Analysis (AREA)
- General Engineering & Computer Science (AREA)
- Computational Linguistics (AREA)
- Mathematical Physics (AREA)
- Biophysics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Databases & Information Systems (AREA)
- Medical Informatics (AREA)
- Multimedia (AREA)
- Image Analysis (AREA)
Abstract
本发明涉及一种带噪多标签数据的神经网络模型训练方法和装置。该方法的步骤包括:通过样本选择算法为每一类别选择出干净样本集合作为元数据集,并进行类别依赖的标签噪声转移矩阵估计;利用类别依赖的标签噪声转移矩阵对实例特征依赖的标签噪声转移矩阵网络中的部分参数进行初始化;基于统计一致性的标签噪声学习损失,将学习问题转化为双层优化问题,用元学习算法同时学习实例特征依赖的标签噪声转移矩阵网络参数、数据不平衡参数和多标签分类神经网络参数。本发明创新性地利用元学习算法以数据驱动的方式,将实例特征依赖的标签噪声转移矩阵网络参数、数据不平衡参数和多标签分类神经网络参数统一到一个框架下学习。
Description
技术领域
本发明属于互联网领域,具体而言,涉及一种带噪多标签数据的神经网络模型训练方法和装置。
背景技术
近年来,深度学习在图像识别,目标检测,视觉跟踪和文本匹配等领域均取得了显著的成果。这主要归功于深度神经网络强大的非线性映射能力,即能够保证数据的特征空间的表达高度可分。然而,随着深度学习的发展,训练高精度模型所需的数据量随着要求精度的增加而产生了爆发式的增长,例如ImageNet图像分类数据集就包含了14,197,122张图片,CLIP跨模态预训练模型使用了4亿个文本-图像对进行训练。在数据集的规模越来越大的同时,给这些大规模的数据集进行精准的标注所需要的人工成本以及经济成本呈指数型增长,这尤其体现在医学图像处理、金融风控等领域。例如在在医学图像处理领域,一张病灶照片的准确标注需要经过一个甚至多个专家的诊断,大大增加了标注大数据集的难度;在金融风控领域的国际盗卡场景中,通常需要等6个月才能获得反馈的案件标签。
在此背景下,一些简单标注方法应运而生,例如将大数据集分成若干个小数据集后分发给不同的标注者的众包方法,以及利用搜索引擎、网络链接或视频标签等的自动标注方法。然而,这些方法在以较低的经济花费获得大量标注数据的同时,也为构建的数据集引入了不可逆的标签噪声问题,例如在众包方法中,各个标注者对数据的识别能力以及本身对各个类别的偏向不同,这会导致一些数据的标签产生不同的错误。
针对上述带噪标记样本下的学习问题,主要有以下几种方法。
一类方法是具有统计一致性的方法:通过设计损失函数,使得利用噪声数据学习的分类器将渐近收敛到在干净域上定义的最佳分类器。噪声转移矩阵表示干净标签转换为噪声标签的概率,因此它被用来构建一系列的统计一致性算法。具体来说,它已被用于修改损失函数以建立具有风险一致性的标签噪声学习算法,以及被用于限制假设空间以构建具有分类器一致性的标签噪声学习算法。
第二类方法是不具有统计一致性的方法:采用启发式设计来减少标签噪声的副作用,例如提取可靠示例,校正标签,和添加隐式或显式正则化。目前有效的提取可靠示例方法主要包括但不限于以下方法:蒸馏法、样本筛法、高斯混合模型损失分布建模、基于置信度的样本集合,基于小损失的方法,以及一些早期停止技术。
第三类方法是利用少量干净数据的方法:通过尽可能地利用少量干净标签的分布信息来抵抗标签噪声的影响。大多数关于标签噪声学习的工作都假设所有训练数据的标签都可能错误。但是,通常情况下有一些可信示例可用以创建验证和测试集。通过假设训练的一个子集是可信的,利用少量干净数据的方法改变了所有训练数据都可能被破坏的假设,并且证明拥有一定数量的可信训练数据可以显着提高稳健性。这类方法大都采用了元学习形式对假设空间进行限制。
尽管相关学者们已经提出了许多方法用于标签噪声深度学习,但是所提出的场景大都是简单的多分类噪声学习场景。而由于标签形式和学习方式的不同,实际应用中往往包含具有各不相同的场景特点,而这些方法难以简单迁移。具体到带噪多标签学习场景中,这种场景中每个样本中含有多个目标类别,其每个类别的标签都有可能含有噪声。在多标签学习中,很少有方法关注标签噪声的后果。因此,提出一种带噪多标签数据的神经网络模型训练方法十分有必要。
发明内容
本发明提供了一种带噪多标签数据的神经网络模型训练方法,以解决在带噪多标签学习场景中训练强决策能力的分类网络的技术问题。
本发明的技术方案为:
一种带噪多标签数据的神经网络模型训练方法,包括以下步骤:
通过样本选择算法为每一类别选择出“干净”的样本集合(后文称为干净样本集合)作为元数据集,并进行类别依赖的标签噪声转移矩阵估计;
利用类别依赖的标签噪声转移矩阵对实例特征依赖的标签噪声转移矩阵网络中的部分参数进行初始化;
基于统计一致性的标签噪声学习损失,采用元学习算法同时学习实例特征依赖的标签噪声转移矩阵网络参数、数据不平衡参数和多标签分类神经网络参数。
进一步地,所述样本选择算法是一个基于深度神经网络的记忆效应的算法。本发明在噪声训练样本集上训练具有标准多标签分类损失的分类器若干个轮次,然后执行样本选择算法以获得每个类标签的选定的干净样本集合作为元数据集。基于所得到的干净样本集合和已有带噪数据,本发明利用标签相关性的不匹配性进行类别依赖的标签噪声转移矩阵估计。
进一步地,通过利用所求得的类别依赖的标签噪声转移矩阵对于实例特征依赖的标签噪声转移矩阵网络中的部分参数进行初始化,为学习优化该网络提供了很好的参数初始点。
进一步地,本发明将带噪单标签场景中的统计一致性的算法应用于多标签任务分解得到的每个二分类问题,并将选择的干净样本集合作为元数据集,通过元学习算法同时学习实例特征依赖的标签噪声转移矩阵网络参数、数据不平衡参数和多标签分类神经网络参数,能够同时缓解标签噪声和数据不平衡的影响。
一种带噪多标签数据的神经网络模型训练装置,其包括:
样本选择及类别依赖的标签噪声转移矩阵估计模块,用于通过样本选择算法为每一类别选择出干净样本集合作为元数据集,并进行类别依赖的标签噪声转移矩阵估计;
实例特征依赖的标签噪声转移矩阵初始化模块,用于利用类别依赖的标签噪声转移矩阵对实例特征依赖的标签噪声转移矩阵网络中的部分参数进行初始化;
元学习训练模块,用于基于统计一致性的标签噪声学习损失,采用元学习算法同时学习实例特征依赖的标签噪声转移矩阵网络参数、数据不平衡参数和多标签分类神经网络参数。
与现有技术相比,本发明的积极效果为:
1)利用每个类别的样本采样得到的干净样本集合作为带噪多标签学习的元数据集;
2)利用所求得的类别依赖的标签噪声转移矩阵对于实例特征的标签噪声转移矩阵网络中的部分参数进行初始化,为学习优化实例特征依赖的标签噪声转移矩阵提供了很好的参数初始点;
3)创新性地利用元学习算法以数据驱动的方式学习优化实例特征的标签噪声转移矩阵。将实例特征依赖的标签噪声转移矩阵网络参数、数据不平衡参数和多标签分类神经网络参数统一到一个框架下学习,能够同时缓解标签噪声和数据不平衡的影响。
附图说明
此处所说明的附图用来提供对本发明的进一步理解,构成本发明的一部分,本发明的示意性实施例及其说明用于解释本发明,并不构成对本发明的不当限定。在附图中:
图1为本发明的方法流程图。
图2为本发明的实例特征依赖的标签噪声转移矩阵网络结构图。
图3为本发明的元学习算法内层训练流程图。
图4为本发明的元学习算法外层训练流程图。
图5对本发明在服饰属性分类场景中的应用流程图。
具体实施方式
为了使本技术领域的人员更好地理解本发明方案,下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述。
需要说明的是,本发明的说明书和权利要求书及上述附图中的术语“第一”、“第二”等是用于区别类似的对象,而不必用于描述特定的顺序或先后次序。应该理解这样使用的数据在适当情况下可以互换,以便这里描述的本发明的实施例能够以除了在这里图示或描述的那些以外的顺序实施。此外,术语“包括”和“具有”以及他们的任何变形,意图在于覆盖不排他的包含,例如,包含了一系列步骤或单元的过程、方法、系统、产品或设备不必限于清楚地列出的那些步骤或单元,而是可包括没有清楚地列出的或对于这些过程、方法、产品或设备固有的其它步骤或单元。
根据本发明实施例的一方面,提供了一种模型的训练方法的方法实施例。本发明的方法流程如图1所示。本发明的方法框架包含样本选择及类别依赖的标签噪声转移矩阵估计、实例特征依赖的标签噪声转移矩阵初始化和用于带噪多标签场景的元学习训练模块组成。
一.样本选择及类别依赖的标签噪声转移矩阵估计
根据已有的带噪多标签数据,为了构建元学习算法所需的元数据,本发明利用样本选择算法来获得带尽可能干净的标签的数据作为元数据集。具体地,本发明在带噪多标签数据Dt上预热训练神经网络模型f若干个(例如5-20个等)轮次,其采用的损失函数为如下的标准多标签二分类学习损失:
其中f为多标签预测神经网络,结构为ResNet;fj(X)表示模型对输入数据特征X进行非线性映射得到的对第j类的预测结果,其中X表示输入特征向量;为样本的噪声标签向量,/>为第j类的噪声标签;q为总的类别数;l为二分类交叉熵损失;Y为样本的干净标签向量。
经过预热训练之后,神经网络模型获得了一定的判别能力,本发明通过使用高斯混合模型(GMM)对每一类j的每个样本损失的分布进行建模,提取损失较小的样本子集,获得每个类标签j的选定的干净样本集合
然后,为了便于后续学习实例特征依赖的标签噪声转移矩阵网络,本发明利用多标签之间的标签相关性来估计类别依赖的标签噪声转移矩阵,即其中Yj为第j类的干净标签,v和k为0或1,代表负类或正类。具体来说,一些在实践中不应该存在的标签相关性包含在带噪多标签学习中。例如,真实多标签数据中“鱼”和“水”总是同时出现,而“鸟”和“天空”总是同时出现。但是,由于标签错误,“鱼”和“天空”之间存在轻微的相关性,这是不切实际的。因此,可以利用元数据和带噪数据中标签相关性的不匹配性进行估计。
该步骤的创新性体现在:利用样本选择算法得到每一类标签的选定的干净样本集合作为带噪多标签场景元学习算法的元数据集,并利用其估计类别依赖的标签噪声转移矩阵。
二.实例特征依赖的标签噪声转移矩阵初始化
为了缓解在带噪多标签学习中实例特征依赖的标签噪声转移矩阵网络学习难的问题,本发明利用已估计得到的类别依赖的标签噪声转移矩阵对实例特征依赖的标签噪声转移矩阵网络的部分参数进行初始化。
实例特征依赖的标签噪声转移矩阵网络的结构如图2所示。类别j的标签噪声转移矩阵网络的网络结构分为骨干网络g(.)和线性层Lj(.),骨干网络的输入为实例特征x,输出为噪声模式表征g(x),线性层的输入为噪声模式表征g(x)和输出为类别j的实例特征依赖的标签噪声转移矩阵/>其中/>表示分类器f对于特征x除第j类外的其他类别的预测输出集合。
设线性层Lj(.)的输入变量为z,则其可以表示为Lj(z)=az+b,其中a和b为该线性层的可学习参数。为了便于学习,在初始化时,本发明对于参数a采用均值为0,方差为0.01的正态分布进行初始化,对于参数b,本发明利用估计得到的第j类的类别依赖的标签噪声转移矩阵,即对其进行初始化。这样的初始化能够保证实例特征依赖的标签噪声转移矩阵拥有损失比较小的初始解,有利于进行后续的优化。
该步骤的创新性体现在:利用类别依赖的标签噪声转移矩阵初始化实例特征依赖的标签噪声转移矩阵的部分参数,缓解实例特征依赖的标签噪声转移矩阵网络难学习的问题。
三.用于带噪多标签场景的元学习训练框架
为了缓解带噪多标签场景下的数据不平衡参数和标签噪声转移矩阵参数的耦合导致学习难的问题,本发明提出一种用于带噪多标签场景的元学习训练框架。
首先,给定输入X,本发明将任务分解为q个条件独立的二分类问题,即预测P(Yj∣X)之间相互独立。
其次,本发明将单标签情况下的具有统计一致性的标签噪声学习损失应用于每个二分类问题,加和得到最终的损失进行学习。不失一般性,这里采用重加权算法(Reweight)作为具有统计一致性的单标签噪声学习损失用于学习神经网络分类器:
其中,Lj为第j类别的损失函数,其可以替换为其他具有统计一致性的二分类噪声标签学习损失,n为样本总数,为第i个样本的第j类的噪声标签;
其中,为将转移矩阵Tj(x)中的元素,带有下标ik表示从将样本x的第j类标签从取值为i翻转为取值为k的概率,i、k=0或1;
其中,wj为数据不平衡参数,用于缓解数据不平衡问题。
在本框架中,需要同时学习实例特征依赖的噪声转移矩阵Tj(.)、数据不平衡参数wj和神经网络分类器fj(.)参数,如果直接最小化损失L,存在多组可行解,无法保证学习的效果。
为了解决这一问题,本发明将学习问题转化为下列双层优化问题(如图3和图4),并利用元学习进行训练:
其中,w*为学习到的最优数据不平衡参数,T*为学习到的最优转移矩阵,f*为学习到的最优分类器,w为可学习的数据不平衡参数,T为可学习的最优转移矩阵,f为可学习的分类器,为之前利用高斯混合模型选择得到的j类数据集/>的合集,/>为原始的带噪多标签训练集。在该双层优化问题中,外层优化利用选择的干净样本集合作为元数据集学习实例特征依赖的噪声转移矩阵Tj(.)和数据不平衡参数wj,内层优化根据已学习到的噪声转移矩阵和数据不平衡参数在原始的带噪多标签训练集上利用具有统计一致性的损失学习得到神经网络分类器fj(.)的参数。
该方法模块的创新性体现在:基于统计一致性的标签噪声学习损失,用元学习算法同时学习实例特征依赖的标签噪声转移矩阵网络参数、数据不平衡参数和多标签分类神经网络参数。
采用本发明的上述技术方案,实现了不准确监督信息数据下的学习,解决了如下的两个问题。第一,带噪多标签场景下的实例特征依赖的标签噪声转移矩阵估计问题:在没有任何假设的情况下,直接优化实例特征依赖的标签噪声转移矩阵是一个病态问题,为了解决这一问题,本发明利用了样本选择得到的干净样本集合提供额外信息进行学习;此外为了进一步促进学习,本发明创新地利用了类别依赖的标签噪声转移矩阵初始化实例特征依赖的标签噪声转移矩阵。第二,带噪多标签场景下的数据不平衡参数和标签噪声转移矩阵参数的耦合问题:由于统计一致性算法要求在学习过程中能够比较好地拟合噪声后验概率,但是多标签场景下的类别不平衡和正负样本不平衡问题往往会严重影响噪声后验概率的学习,如果仅学习标签噪声转移矩阵,会由于数据不平衡的耦合影响导致学习的不准确性,为了更好地缓解这个问题,本发明采用元学习框架同时学习学习实例特征依赖的标签噪声转移矩阵网络参数、数据不平衡参数和多标签分类神经网络参数。
需要说明的是,对于前述的各方法实施例,为了简单描述,故将其都表述为一系列的动作组合,但是本领域技术人员应该知悉,本发明并不受所描述的动作顺序的限制,因为依据本发明,某些步骤可以采用其他顺序或者同时进行。其次,本领域技术人员也应该知悉,说明书中所描述的实施例均属于优选实施例,所涉及的动作和模块并不一定是本发明所必须的。
通过以上的实施方式的描述,本领域的技术人员可以清楚地了解到根据上述实施例的方法可借助软件加必需的通用硬件平台的方式来实现,当然也可以通过硬件,但很多情况下前者是更佳的实施方式。基于这样的理解,本发明的技术方案本质上或者说对现有技术做出贡献的部分可以以软件产品的形式体现出来,该计算机软件产品存储在一个存储介质(如ROM/RAM、磁碟、光盘)中,包括若干指令用以使得一台终端设备(可以是手机,计算机,服务器,或者网络设备等)执行本发明各个实施例所述的方法。例如:
本发明的一个实施例提供一种带噪多标签数据的神经网络模型训练装置,其包括:
样本选择及类别依赖的标签噪声转移矩阵估计模块,用于通过样本选择算法为每一类别选择出干净样本集合作为元数据集,并进行类别依赖的标签噪声转移矩阵估计;
实例特征依赖的标签噪声转移矩阵初始化模块,用于利用类别依赖的标签噪声转移矩阵对实例特征依赖的标签噪声转移矩阵网络中的部分参数进行初始化;
元学习训练模块,用于基于统计一致性的标签噪声学习损失,采用元学习算法同时学习实例特征依赖的标签噪声转移矩阵网络参数、数据不平衡参数和多标签分类神经网络参数。
其中各模块的具体实施过程参见前文对本发明方法的描述。
本发明的另一实施例提供一种计算机设备(手机,计算机,服务器,或者网络设备等),其包括存储器和处理器,所述存储器存储计算机程序,所述计算机程序被配置为由所述处理器执行,所述计算机程序包括用于执行本发明方法中各步骤的指令。
本发明的另一实施例提供一种计算机可读存储介质(ROM/RAM、磁碟、光盘等),所述计算机可读存储介质存储计算机程序,所述计算机程序被计算机执行时,实现本发明方法的各个步骤。
下面结合图5对本发明在具体场景中的利用进行描述。
当今服装的形态属性各异,如何有效地对服装的多个属性进行识别成为一些互联网电商业务的关注点。近年来使用深度神经网络(DNN)在干净监督数据上训练模型,取得了良好效果。但是,这种训练方式在很多实际业务落地中存在困难。一方面,数据标注过程所需的人力和时间成本极高;另一方面,许多服装属性的十分复杂难辨,即使是人工标签往往也具有很大的不准确性。本发明涉及的一种带噪多标签数据的神经网络模型训练方法能够很好地解决这个问题。如图5,将服装图像数据经过特征提取网络得到的低维特征和带噪多属性标签输入本发明涉及的一种带噪多标签数据的神经网络模型训练方法,可以训练得到一个准确的神经网络模型,进而准确的属性分类和检索。
本发明的以上技术方案中,未详细描述的部分可以采用现有技术实现。
本发明方案中,样本选择算法、统计一致性损失、神经网络架构和元学习优化方式并不仅限于本发明方案中所描述的基于高斯混合模型的小损失选择算法、重加权算法、ResNet网络和SGD优化方式,而可以根据具体业务场景设计和选择具体的样本选择算法、统计一致性损失、网络架构和优化方式。本发明方案中定义的交叉熵损失函数可替换为其他通用分类损失函数。
显然,以上所描述的实施例仅仅是本发明一部分的实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都应当属于本发明保护的范围。
Claims (10)
1.一种带噪多标签数据的神经网络模型训练方法,其特征在于,包括以下步骤:
通过样本选择算法为每一类别选择出干净样本集合作为元数据集,并进行类别依赖的标签噪声转移矩阵估计;
利用类别依赖的标签噪声转移矩阵对实例特征依赖的标签噪声转移矩阵网络中的部分参数进行初始化;
基于统计一致性的标签噪声学习损失,采用元学习算法同时学习实例特征依赖的标签噪声转移矩阵网络参数、数据不平衡参数和多标签分类神经网络参数。
2.根据权利要求1所述的方法,其特征在于,所述通过样本选择算法为每一类别选择出干净样本集合作为元数据集,包括:在噪声训练样本集上训练具有标准多标签分类损失的分类器若干个轮次,然后执行样本选择算法以获得每个类标签的选定的干净样本集合作为元数据集。
3.根据权利要求2所述的方法,其特征在于,采用以下步骤获得所述干净样本集合:
在带噪多标签数据Dt上预热训练神经网络模型f若干个轮次,其采用如下的标准多标签二分类学习损失:
其中f为多标签预测神经网络,结构为ResNet;用fj(X)表示模型对输入的X进行非线性操作得到的对第j类的预测结果,其中X表示输入特征向量;为样本的噪声标签向量,/>为第j类的噪声标签;q为总的类别数;l为二分类交叉熵损失;
通过使用高斯混合模型对每一类j的每个样本损失的分布进行建模,提取损失较小的样本子集,获得每个类标签j的干净样本集合
4.根据权利要求1所述的方法,其特征在于,所述进行类别依赖的标签噪声转移矩阵估计,包括:基于所述干净样本集合和已有带噪数据,利用标签相关性的不匹配性进行类别依赖的标签噪声转移矩阵估计。
5.根据权利要求1所述的方法,其特征在于,利用类别依赖的标签噪声转移矩阵初始化实例特征依赖的标签噪声转移矩阵的部分参数,缓解实例特征依赖的标签噪声转移矩阵网络难学习的问题。
6.根据权利要求1所述的方法,其特征在于,所述利用类别依赖的标签噪声转移矩阵对实例特征依赖的标签噪声转移矩阵网络中的部分参数进行初始化,包括:
设转移矩阵网络的最后线性层Lj(.)的输入变量为z,并表示为Lj(z)=az+b,其中a和b为该线性层的可学习参数;
在初始化时,对于参数a采用均值为0,方差为0.01的正太分布进行初始化;对于参数b,利用估计得到的第j类的类别依赖的标签噪声转移矩阵对其进行初始化。
7.根据权利要求1所述的方法,其特征在于,所述用元学习算法同时学习实例特征依赖的标签噪声转移矩阵网络参数、数据不平衡参数和多标签分类神经网络参数,是将学习问题转化为下列双层优化问题,并利用元学习进行训练:
其中,w*为学习到的最优数据不平衡参数,T*为学习到的最优转移矩阵,f*为学习到的最优分类器,w为可学习的数据不平衡参数,T为可学习的最优转移矩阵,f为可学习的分类器,为之前利用高斯混合模型选择得到的j类数据集/>的合集,/>为原始的带噪多标签训练集;在该双层优化问题中,外层优化利用选择的“干净”的样本集合作为元数据集学习实例特征依赖的噪声转移矩阵Tj(.)和数据不平衡参数wj,内层优化根据已学习到的噪声转移矩阵和数据不平衡参数在原始的带噪多标签训练集上利用具有统计一致性的损失学习得到神经网络分类器fj(.)的参数。
8.一种带噪多标签数据的神经网络模型训练装置,其特征在于,包括:
样本选择及类别依赖的标签噪声转移矩阵估计模块,用于通过样本选择算法为每一类别选择出干净样本集合作为元数据集,并进行类别依赖的标签噪声转移矩阵估计;
实例特征依赖的标签噪声转移矩阵初始化模块,用于利用类别依赖的标签噪声转移矩阵对实例特征依赖的标签噪声转移矩阵网络中的部分参数进行初始化;
元学习训练模块,用于基于统计一致性的标签噪声学习损失,采用元学习算法同时学习实例特征依赖的标签噪声转移矩阵网络参数、数据不平衡参数和多标签分类神经网络参数。
9.一种计算机设备,其特征在于,包括存储器和处理器,所述存储器存储计算机程序,所述计算机程序被配置为由所述处理器执行,所述计算机程序包括用于执行权利要求1~7中任一项所述方法的指令。
10.一种计算机可读存储介质,其特征在于,所述计算机可读存储介质存储计算机程序,所述计算机程序被计算机执行时,实现权利要求1~7中任一项所述的方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310509397.6A CN116681128A (zh) | 2023-05-08 | 2023-05-08 | 一种带噪多标签数据的神经网络模型训练方法和装置 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310509397.6A CN116681128A (zh) | 2023-05-08 | 2023-05-08 | 一种带噪多标签数据的神经网络模型训练方法和装置 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN116681128A true CN116681128A (zh) | 2023-09-01 |
Family
ID=87786288
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202310509397.6A Pending CN116681128A (zh) | 2023-05-08 | 2023-05-08 | 一种带噪多标签数据的神经网络模型训练方法和装置 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN116681128A (zh) |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117237720A (zh) * | 2023-09-18 | 2023-12-15 | 大连理工大学 | 基于强化学习的标签噪声矫正图像分类方法 |
-
2023
- 2023-05-08 CN CN202310509397.6A patent/CN116681128A/zh active Pending
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117237720A (zh) * | 2023-09-18 | 2023-12-15 | 大连理工大学 | 基于强化学习的标签噪声矫正图像分类方法 |
CN117237720B (zh) * | 2023-09-18 | 2024-04-12 | 大连理工大学 | 基于强化学习的标签噪声矫正图像分类方法 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN111724083B (zh) | 金融风险识别模型的训练方法、装置、计算机设备及介质 | |
US11494616B2 (en) | Decoupling category-wise independence and relevance with self-attention for multi-label image classification | |
CN104298682B (zh) | 一种基于人脸表情图像的信息推荐效果的评价方法及手机 | |
CN110866140A (zh) | 图像特征提取模型训练方法、图像搜索方法及计算机设备 | |
CN112380435A (zh) | 基于异构图神经网络的文献推荐方法及推荐系统 | |
CN110598869B (zh) | 基于序列模型的分类方法、装置、电子设备 | |
CN113065409A (zh) | 一种基于摄像分头布差异对齐约束的无监督行人重识别方法 | |
CN111582506A (zh) | 基于全局和局部标记关系的偏多标记学习方法 | |
WO2022035942A1 (en) | Systems and methods for machine learning-based document classification | |
CN114255371A (zh) | 一种基于组件监督网络的小样本图像分类方法 | |
CN115687610A (zh) | 文本意图分类模型训练方法、识别方法、装置、电子设备及存储介质 | |
CN116681128A (zh) | 一种带噪多标签数据的神经网络模型训练方法和装置 | |
CN115577283A (zh) | 一种实体分类方法、装置、电子设备及存储介质 | |
Xu et al. | Weakly supervised facial expression recognition via transferred DAL-CNN and active incremental learning | |
Tahir et al. | Explainable deep learning ensemble for food image analysis on edge devices | |
CN111291705A (zh) | 一种跨多目标域行人重识别方法 | |
Okokpujie et al. | Predictive modeling of trait-aging invariant face recognition system using machine learning | |
CN113657473A (zh) | 一种基于迁移学习的Web服务分类方法 | |
Liu et al. | Iterative deep neighborhood: a deep learning model which involves both input data points and their neighbors | |
CN112270334A (zh) | 一种基于异常点暴露的少样本图像分类方法及系统 | |
Dhanalakshmi et al. | Tomato leaf disease identification by modified inception based sequential convolution neural networks | |
CN115392474A (zh) | 一种基于迭代优化的局部感知图表示学习方法 | |
Yang et al. | iCausalOSR: invertible Causal Disentanglement for Open-set Recognition | |
CN114842301A (zh) | 一种图像注释模型的半监督训练方法 | |
CN112613341A (zh) | 训练方法及装置、指纹识别方法及装置、电子设备 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |