CN114445656A - 多标签模型处理方法、装置、电子设备及存储介质 - Google Patents
多标签模型处理方法、装置、电子设备及存储介质 Download PDFInfo
- Publication number
- CN114445656A CN114445656A CN202111657515.5A CN202111657515A CN114445656A CN 114445656 A CN114445656 A CN 114445656A CN 202111657515 A CN202111657515 A CN 202111657515A CN 114445656 A CN114445656 A CN 114445656A
- Authority
- CN
- China
- Prior art keywords
- label
- data
- loss value
- model
- type
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
- G06F18/243—Classification techniques relating to the number of classes
- G06F18/2431—Multiple classes
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- Evolutionary Computation (AREA)
- Life Sciences & Earth Sciences (AREA)
- Artificial Intelligence (AREA)
- General Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- General Health & Medical Sciences (AREA)
- Software Systems (AREA)
- Molecular Biology (AREA)
- Computing Systems (AREA)
- Biophysics (AREA)
- Biomedical Technology (AREA)
- Mathematical Physics (AREA)
- Computational Linguistics (AREA)
- Health & Medical Sciences (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Bioinformatics & Computational Biology (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Evolutionary Biology (AREA)
- Information Retrieval, Db Structures And Fs Structures Therefor (AREA)
Abstract
本发明公开了一种多标签模型处理方法、装置、电子设备及存储介质,该方法包括:对标签数据进行过采样处理;其中,所述标签数据中包括样本数据和标签,所述标签用于作为索引以指向对应的所述样本数据;将过采样处理后的标签数据分类,以得到第一类标签数据和第二类标签数据;分别对所述第一类标签数据和所述第二类标签数据进行损失计算,得到第一损失值和第二损失值;根据所述第一损失值和所述第二损失值确定出目标损失值;采用预设算法对多标签模型进行模型参数更新,并在所述目标损失值满足预设终止条件时,终止更新所述模型参数。本发明能够使得标签的类别更为均衡,提升了多标签模型的准确率和性能。
Description
技术领域
本发明涉及数据处理技术领域,具体涉及一种多标签模型处理方法、装置、电子设备及存储介质。
背景技术
多标签学习(Multi-Label Learning)是工程应用中常用的技术,例如预测人体属性时,输入一张人体图,同时预测出人体的多个属性,大幅降低了推理耗时和显存占用。然而,多标签模型的训练通常会面临一些挑战,其中,缺失标签、类别不平衡、数据增广问题是比较常见的问题。目前,常用的过采样技术能缓解类别不平衡问题,但会加剧标签缺失问题,导致对多标签模型的准确率提升不足。
发明内容
第一方面,本发明的主要目的是提供一种多标签模型处理方法,包括:
对标签数据进行过采样处理;其中,所述标签数据中包括样本数据和标签,所述标签用于作为索引以指向对应的所述样本数据;
将过采样处理后的标签数据分类,以得到第一类标签数据和第二类标签数据;
分别对所述第一类标签数据和所述第二类标签数据进行损失计算,得到第一损失值和第二损失值;
根据所述第一损失值和所述第二损失值确定出目标损失值;
采用预设算法对多标签模型进行模型参数更新,并在所述目标损失值满足预设终止条件时,终止更新所述模型参数。
可选地,所述将过采样处理后的标签数据分类,以得到第一类标签数据和第二类标签数据包括:
确定所述标签数据中每个所述样本数据对应的标签;
判断每个所述样本数据对应的标签是否出现缺失;
当所述样本数据对应的标签没有缺失时,将所述没有缺失标签的样本数据确定为第一类标签数据;
当所述样本数据对应的标签出现缺失时,将所述出现缺失标签的样本数据确定为第二类标签数据。
可选地,所述分别对所述第一类标签数据和所述第二类标签数据进行损失计算,得到第一损失值和第二损失值包括:
确定所述第一类标签数据中的第一样本数据和第一类标签;
基于所述多标签模型对所述第一样本数据进行预测,得到第一预测结果;
将所述第一预测结果与所述第一类标签输入损失函数,以确定出所述第一类标签对应的第一损失值。
可选地,所述分别对所述第一类标签数据和所述第二类标签数据进行损失计算,得到第一损失值和第二损失值包括:
确定所述第二类标签数据中的第二样本数据和第二类标签;
采用第一增广策略对所述第二样本数据进行增广,得到第一增广数据;
基于所述多标签模型对所述第一增广数据进行预测,得到第二预测结果;
将所述第二预测结果转换为预定形式的伪标签;
将所述第二样本数据再次进行预测,并根据预测结果与所述伪标签确定出第二损失值。
可选地,所述将所述第二样本数据再次进行预测,并根据预测结果与所述伪标签确定出第二损失值包括:
采用第二增广策略对所述第二样本数据进行增广,得到第二增广数据;
基于所述多标签模型对所述第二增广数据进行预测,得到第三预测结果;
将所述第三预测结果和所述伪标签输入损失函数,以确定出所述第二类标签对应的第二损失值。
可选地,所述采用预设算法对多标签模型进行模型参数更新,并在所述目标损失值满足预设终止条件时,终止更新所述模型参数包括:
采用预设算法对多标签模型进行模型参数更新;其中,所述预设算法为随机梯度下降算法;
判断所述目标损失值是否低于预设阈值;
当所述目标损失值低于预设阈值时,确定终止更新所述模型参数。
可选地,所述采用预设算法对多标签模型进行模型参数更新,并在所述目标损失值满足预设终止条件时,终止更新所述模型参数,还包括:
当所述目标损失值不低于所述预设阈值时,将执行模型参数更新后的标签数据进行分类;
将分类后的标签数据继续执行所述损失计算的操作,得到新的目标损失值;
对新的目标损失值继续执行判断操作,根据判断结果确定是否继续对所述多标签模型执行模型参数更新、分类及损失计算的操作,或终止更新所述模型参数。
第二方面,本发明实施例提供了一种多标签模型处理装置,包括:
过采样处理模块,用于对标签数据进行过采样处理;其中,所述标签数据中包括样本数据和标签,所述标签用于作为索引以指向对应的所述样本数据;
分类模块,用于将过采样处理后的标签数据分类,以得到第一类标签数据和第二类标签数据;
计算模块,用于分别对所述第一类标签数据和所述第二类标签数据进行损失计算,得到第一损失值和第二损失值;
确定模块,用于根据所述第一损失值和所述第二损失值确定出目标损失值;
更新模块,用于采用预设算法对标签模型进行模型参数更新,并在所述目标损失值满足预设终止条件时,终止更新所述模型参数。
第三方面,本发明实施例提供了一种电子设备,包括存储器、处理器以及存储在所述存储器中并可在所述处理器上运行的计算机程序,所述处理器执行所述计算机程序时实现如上述的多标签模型处理方法的步骤。
第四方面,本发明实施例提供了一种计算机可读存储介质,所述计算机可读存储介质存储有计算机程序,所述计算机程序被处理器执行时实现如上述的多标签模型处理方法的步骤。
本发明的上述方案至少包括以下有益效果:
本发明提供的多标签模型处理方法,首先对标签数据进行过采样处理;并将过采样处理后的标签数据分类,以得到第一类标签数据和第二类标签数据;然后分别对所述第一类标签数据和所述第二类标签数据进行损失计算,得到第一损失值和第二损失值;根据所述第一损失值和所述第二损失值确定出目标损失值;最后采用预设算法对多标签模型进行模型参数更新,并在所述目标损失值满足预设终止条件时,终止更新所述模型参数。本发明能够使得标签的类别更为均衡,提升了多标签模型的准确率和性能。
附图说明
为了更清楚地说明本发明实施例或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图示出的结构获得其他的附图。
图1为本发明实施例提供的多标签模型处理方法的整体流程示意图;
图2为本发明实施例提供的步骤S20的具体流程示意图;
图3为本发明实施例提供的步骤S30的具体流程示意图;
图4为本发明实施例提供的步骤S30的另一流程示意图;
图5为本发明实施例提供的步骤S30的又一流程示意图;
图6为本发明实施例提供的步骤S50的另一流程示意图;
图7为本发明实施例提供的多标签模型的结构示意图;
图8为本发明实施例提供的多标签模型处理装置的结构框图;
图9为本发明实施例提供的电子设备的结构框图。
本发明目的的实现、功能特点及优点将结合实施例,参照附图做进一步说明。
具体实施方式
下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本发明的一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有作出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
本发明的说明书和权利要求书及上述附图中的术语“第一”、“第二”和“第三”等是用于区别不同对象,而非用于描述特定顺序。此外,术语“包括”以及它们任何变形,意图在于覆盖不排他的包含。例如包含了一系列步骤或单元的过程、方法、系统、产品或设备没有限定于已列出的步骤或单元,而是可选地还包括没有列出的步骤或单元,或可选地还包括对于这些过程、方法、产品或设备固有的其它步骤或单元。
首先结合相关附图来举例介绍下本申请实施例的方案。
如图1所示,本发明的具体实施例提供了一种多标签模型处理方法,包括:
S10、对标签数据进行过采样处理;其中,标签数据中包括样本数据和标签,标签用于作为索引以指向对应的样本数据。
在本实施例中,标签数据为多标签模型中的多标签数据,多标签表示指的是在一个模型中可以同时输出多个标签,例如,一个模型可以同时预测人的年龄、是否拎包,那这个模型就是二标签模型;如果设定年龄可能的取值有小孩、成人、老人,候选的这些取值称为类,因此上述年龄标签有三个类;其中,在收集到多标签数据之后,其中某些类的数量少,另外一些类的数量多,如果出现少数类:多数类<1:4时,就称为类别不平衡,因此,需要对标签数据进行过采样处理,以实现标签数据的类别平衡。
如下表所示,在表中的1-10号样本表示原始收集的样本数据,此时,年龄标签中小孩:成人:老人=1:8:1,拎包标签中否:是=1:8。其中,第5号样本数据中拎包标签为-1表示标签缺失,表示为打标签的过程中没有进行标注,因此需要在标签内进行过采样处理,例如,复制9号样本数据及标签,就可以得到两个小孩的样本数据,复制10号样本数据及标签,就可以得到两个老人的样本数据,但是为了避免影响是否拎包的标签,通常是将是否拎包标签标记为缺失(-1);同时复制1号样本数据及标签,就可以得到拎包的样本数据和标签,同时将年龄标签标记为缺失(-1),类似拎包标签的类别不平衡处理。经过过采样处理之后,则年龄标签中小孩:成人:老人=2:8:2,是否拎包标签中否:是=2:8,如此实现了标签数据的类别平衡。
经过上述过采样处理之后,可以实现标签数据的类别平衡,但是标签中出现了很多-1标签,一部分本来就是缺失的,另一部分是过采样产生的,当标签个数(例如实际的人体属性模型输出的标签个数有23个)很多的时候,这种-1标签会占比会更多,而这种-1标签在训练的过程中损失是置为0的,因此可能会噪声损失不稳定,影响训练效果;因此,在过采样处理后,可以将缺失标签和非缺失标签进行分类筛选,以分别计算对应的损失值。
S20、将过采样处理后的标签数据分类,以得到第一类标签数据和第二类标签数据。
在本实施例中,第一类标签数据可以表示为有效标签数据,第二类标签数据可以表示为缺失标签数据,在对标签数据进行分类时,可以依次对每个标签数据进行判断,并在每个标签数据的某一个标签出现缺失时,则对应划分为第一类标签数据和第二类标签数据,例如在上表的13号样本数据中,小孩对应有1个-1标签,则可以将其划分至第二类标签数据中,10号样本数据中,老人对应有1个1标签,则可以将其划分至第一类标签数据中,因此,通过对每个样本数据对应的标签进行判断,并将对应的样本数据及标签进行划分,如此可以确定出第一类标签数据和第二类标签数据。
如图2所示,上述步骤S20的具体实现方式包括:
S21、确定标签数据中每个样本数据对应的标签;
S22、判断每个样本数据对应的标签是否出现缺失;
S23、当样本数据对应的标签没有缺失时,将没有缺失标签的样本数据确定为第一类标签数据;
S24、当样本数据对应的标签出现缺失时,将出现缺失标签的样本数据确定为第二类标签数据。
其中,每个样本数据具有与其对应的标签,当标签没有缺失时,可以将其划分为第一类标签数据;当标签出现缺失时,可以将其划分为第二类标签数据;也就是说,在确定第一类标签数据和第二类标签数据时,主要是通过标签进行判断,进而对样本数据进行划分。
举例来说,在上表中,样本数据1的年龄标签的类取值为成人1,拎包标签的类取值为是1,则样本数据1可以划分为第一类标签数据;样本数据2的年龄标签的类取值为成人1,拎包标签的类取值为否1,则样本数据2可以划分为第一类标签数据;因此,在第一类标签数据中确定有样本数据1及其对应的年龄标签和拎包标签,以及样本数据2及其对应的年龄标签和拎包标签;其次,样本数据5的年龄标签的类取值为成人1,拎包标签的类取值为否-1、是-1,则样本数据5可以划分为第二类标签数据;因此,在第二类标签数据中确定有样本数据5及其对应的标签。
S30、分别对第一类标签数据和第二类标签数据进行损失计算,得到第一损失值和第二损失值。
在本实施例中,可以采用损失函数进行对第一类标签数据和第二类标签数据进行损失计算,损失函数包括平方损失函数、指数损失函数、感知损失函数及交叉熵损失函数等,损失函数是将随机事件或其有关随机变量的取值映射为非负实数以表示该随机事件的“风险”或“损失”的函数;可选地,可以采用交叉熵损失函数对第一类标签数据和第二类标签数据进行计算,交叉熵损失函数表示实际输出(概率)与期望输出(概率)的距离,也就是交叉熵的值越小,两个概率分布就越接近;因此,第一损失值和第二损失值越低时,则表示第一类标签数据和第二类标签数据对应的缺失标签越少,那么对应的多标签模型的精度则更准确。
如图3所示,上述步骤S30的具体实现方式包括:
S31、确定第一类标签数据中的第一样本数据和第一类标签;
S32、基于多标签模型对第一样本数据进行预测,得到第一预测结果;
S33、将第一预测结果与第一类标签输入损失函数,以确定出第一类标签对应的第一损失值。
其中,多标签模型可以是人体多标签模型,每个分支对应一个标签,例如图7中的年龄、性别、头发颜色、头发长度等标签,并且所有分支均共享一个主干网络,主干网络可以是ResNet、MobileNet等,优先使用ResNet;通过多标签模型对第一样本数据进行预测,在对第一样本数据进行预测时,可以将第一样本数据转换为数组,该数组可以是多维数组,并将第一样本数据输入多标签模型中,使多标签模型通过对应的标签对第一样本数据进行预测,从而得到第一预测结果,在确定得到第一预测结果后,将第一预测结果和第一类标签输入交叉熵损失函数进行计算,从而计算得到第一损失值。
举例来说,在对上表中样本数据9进行预测时,可以将样本数据9转换为多维数组,并将多维数组输入多标签模型进行预测,得到的结果有(小孩,0.98),(老人,0.01),(成人,0.01),由此,可以确定出样本数据9为小孩;可以理解的是,在对第一样本数据进行预测以确定其对应的第一预测结果,并和第一类标签通过交叉熵损失函数进行计算,进而可以确定出第一损失值,同时可以将第一损失值设定为L1。
如图4所示,上述步骤S30的具体实现方式包括:
S31、确定第二类标签数据中的第二样本数据和第二类标签;
S32、采用第一增广策略对第二样本数据进行增广,得到第一增广数据;
S33、基于多标签模型对第一增广数据进行预测,得到第二预测结果;
S34、将第二预测结果转换为预定形式的伪标签;
S35、将第二样本数据再次进行预测,并根据预测结果与伪标签确定出第二损失值。
在本实施例中,第一增广策略为弱数据增广,弱数据增广可以表示在模型中采用水平翻转、小幅度地图像亮度、对比度、饱和度调节等方式进行增广,通过对第二样本数据进行增广,并通过多标签模型对第一增广数据进行预测;同理,在对第一增广数据进行预测时,可以将第一增广数据转换为NumPy数组,并将第一增广数据输入多标签模型中,使多标签模型通过对应的标签对第一增广数据进行预测,从而得到第二预测结果;可以理解的是,通过对第二样本数据进行增广,可以提升模型的性能。
其中,预定形式可以是One-Hot形式,One-Hot形式表示将类别变量转换为机器学习算法易于利用的一种形式的过程,这个向量的表示为一项属性的特征向量,也就是同一时间只有一个激活点(不为0),这个向量只有一个特征是不为0的,其他都是0;例如,第二预测结果为<0.2,0.8>的向量,那么转换为One-Hot形式之后的伪标签应该为<0,1>的向量,也就是将最大的转换为1,其他都转换为0;并且,伪标签可以表示为利用在已标注数据所训练的模型在未标注的数据上进行预测;如此,通过将第二预测结果转换为One-Hot形式的伪标签,进而提升预测结果的一致性,避免破坏类别平衡。
如图5所示,上述步骤S35的具体实现方式包括:
S351、采用第二增广策略对第二样本数据进行增广,得到第二增广数据;
S352、基于多标签模型对第二增广数据进行预测,得到第三预测结果;
S353、将第三预测结果和伪标签输入损失函数,以确定出第二类标签对应的第二损失值。
在本实施例中,可以在得到伪标签之后,再次对第二样本数据进行增广,第二增广策略为强数据增广,强数据增广可以表示在模型中采用水平翻转,大幅度地图像亮度、对比度、饱和度调节,高斯模型,灰色增广,随机擦除等方式进行增广,在得到第二增广数据后,可以通过多标签模型对第二增广数据再次进行预测,从而得到第三预测结果,在确定出第三预测结果后,可以将第三预测结果和伪标签输入交叉熵损失函数进行计算,进而可以确定出第二损失值,同时可以将第二损失值设定为L2。
S40、根据第一损失值和第二损失值确定出目标损失值。
其中,第一损失值为L1,第二损失值为L2,可以将第一损失值和第二损失值进行求和,进而可以确定出目标损失值;可以理解的是,目标损失值表示为多标签模型的损失值,在目标损失值比较大时,会导致多标签模型的训练中会出现类别不平衡等问题,因此,可以对目标损失值进行判断,进而在后续对多标签模型进行参数更新。
S50、采用预设算法对多标签模型进行模型参数更新,并在目标损失值满足预设终止条件时,终止更新模型参数。
如图6所示,上述步骤S50的具体实现方式包括:
S51、采用预设算法对多标签模型进行模型参数更新;其中,预设算法为随机梯度下降算法;
S52、判断目标损失值是否低于预设阈值;
S53、当目标损失值低于预设阈值时,确定终止更新模型参数;
S54、当目标损失值不低于预设阈值时,将执行模型参数更新后的标签数据进行分类;
S55、将分类后的标签数据继续执行损失计算的操作,得到新的目标损失值;
S56、对新的目标损失值继续执行判断操作,根据判断结果确定是否继续对多标签模型执行模型参数更新、分类及损失计算的操作,或终止更新模型参数。
在本实施例中,在确定出目标损失值后,可以先判断目标损失值是否低于预设阈值,然后根据判断结果进行模型参数更新或停止模型参数更新,不低于预设阈值表示大于或等于预设阈值,低于预设阈值表示小于预设阈值;例如,在首次得到目标损失值后,由于缺失标签较多,因此可以先采用预设算法对多标签模型进行模型参数更新,然后再判断目标损失值是否低于预设阈值;可以理解的是,对目标损失值的判断可以在模型参数更新后,也可以在模型参数更新前进行判断,当然,在采用预设算法进行模型参数更新一轮后,则可以在模型参数更新前进行判断;其中,随机梯度下降法表示一轮迭代只用一条随机选取的数据,通过随机梯度下降法可以使得反馈时间更快;因此,在目标损失值低于预设阈值时,则可以停止更新模型参数;在目标损失值不低于预设阈值时,可以再次重复对多标签模型进行更新,并将标签数据进行分类、损失计算及目标损失值判断的操作,直至目标损失值低于预设阈值时,则可以停止更新多标签模型的模型参数,使得更新后的多标签模型类别更为均衡且增加了数据增广信息,显著提升了模型的性能。
本发明提供的多标签模型处理方法,首先对标签数据进行过采样处理;并将过采样处理后的标签数据分类,以得到第一类标签数据和第二类标签数据;然后分别对第一类标签数据和第二类标签数据进行损失计算,得到第一损失值和第二损失值;根据第一损失值和第二损失值确定出目标损失值;最后采用预设算法对多标签模型进行模型参数更新,并在目标损失值满足预设终止条件时,终止更新模型参数。本发明能够使得标签的类别更为均衡,提升了多标签模型的准确率和性能。
如图8所示,本发明实施例提供了一种多标签模型处理装置10,包括:
过采样处理模块11,用于对标签数据进行过采样处理;其中,标签数据中包括样本数据和标签,标签用于作为索引以指向对应的样本数据;
分类模块12,用于将过采样处理后的标签数据分类,以得到第一类标签数据和第二类标签数据;
计算模块13,用于分别对第一类标签数据和第二类标签数据进行损失计算,得到第一损失值和第二损失值;
确定模块14,用于根据第一损失值和第二损失值确定出目标损失值;
更新模块15,用于采用预设算法对标签模型进行模型参数更新,并在目标损失值满足预设终止条件时,终止更新模型参数。
本发明提供的多标签模型处理装置,首先对标签数据进行过采样处理;并将过采样处理后的标签数据分类,以得到第一类标签数据和第二类标签数据;然后分别对第一类标签数据和第二类标签数据进行损失计算,得到第一损失值和第二损失值;根据第一损失值和第二损失值确定出目标损失值;最后采用预设算法对多标签模型进行模型参数更新,并在目标损失值满足预设终止条件时,终止更新模型参数。本发明能够使得标签的类别更为均衡,提升了多标签模型的准确率和性能。
需要说明的是,本发明具体实施例提供的多标签模型处理装置10为与上述多标签模型处理方法对应的装置,上述多标签模型处理方法的所有实施例均适用于该多标签模型处理装置10,上述多标签模型处理装置10实施例中均有相应的模块对应上述多标签模型处理方法中的步骤,能达到相同或相似的有益效果,为避免过多重复,在此不对多标签模型处理装置2中的每一模块进行过多赘述。
如图9所示,本发明的具体实施例还提供了一种电子设备20,包括存储器202、处理器201以及存储在存储器202中并可在处理器201上运行的计算机程序,该处理器201执行计算机程序时实现上述的多标签模型处理方法的步骤。
具体的,处理器201用于调用存储器202存储的计算机程序,执行如下步骤:
对标签数据进行过采样处理;其中,所述标签数据中包括样本数据和标签,所述标签用于作为索引以指向对应的所述样本数据;
将过采样处理后的标签数据分类,以得到第一类标签数据和第二类标签数据;
分别对所述第一类标签数据和所述第二类标签数据进行损失计算,得到第一损失值和第二损失值;
根据所述第一损失值和所述第二损失值确定出目标损失值;
采用预设算法对多标签模型进行模型参数更新,并在所述目标损失值满足预设终止条件时,终止更新所述模型参数。
可选的,处理器201执行的将过采样处理后的标签数据分类,以得到第一类标签数据和第二类标签数据包括:
确定所述标签数据中每个所述样本数据对应的标签;
判断每个样本数据对应的标签是否出现缺失;
当样本数据对应的标签没有缺失时,将没有缺失标签的样本数据确定为第一类标签数据;
当样本数据对应的标签出现缺失时,将出现缺失标签的样本数据确定为第二类标签数据。
可选的,处理器201执行的分别对所述第一类标签数据和所述第二类标签数据进行损失计算,得到第一损失值和第二损失值包括:
确定所述第一类标签数据中的第一样本数据和第一类标签;
基于所述多标签模型对所述第一样本数据进行预测,得到第一预测结果;
将所述第一预测结果与所述第一类标签输入损失函数,以确定出所述第一类标签对应的第一损失值。
可选的,处理器201执行的分别对所述第一类标签数据和所述第二类标签数据进行损失计算,得到第一损失值和第二损失值,还包括:
确定所述第二类标签数据中的第二样本数据和第二类标签;
采用第一增广策略对所述第二样本数据进行增广,得到第一增广数据;
基于所述多标签模型对所述第一增广数据进行预测,得到第二预测结果;
将所述第二预测结果转换为预定形式的伪标签;
将第二样本数据再次进行预测,并根据预测结果与伪标签确定出第二损失值。
可选的,处理器201执行的分别对所述第一类标签数据和所述第二类标签数据进行损失计算,得到第一损失值和第二损失值,还包括:
采用第二增广策略对所述第二样本数据进行增广,得到第二增广数据;
基于所述多标签模型对所述第二增广数据进行预测,得到第三预测结果;
将所述第三预测结果和所述伪标签输入损失函数,以确定出所述第二类标签对应的第二损失值。
可选的,处理器201执行的采用预设算法对多标签模型进行模型参数更新,并在所述目标损失值满足预设终止条件时,终止更新所述模型参数包括:
采用预设算法对多标签模型进行模型参数更新;其中,所述预设算法为随机梯度下降算法;
判断所述目标损失值是否低于预设阈值;
当所述目标损失值低于预设阈值时,确定终止更新所述模型参数。
可选的,处理器201执行的采用预设算法对多标签模型进行模型参数更新,并在所述目标损失值满足预设终止条件时,终止更新所述模型参数,还包括:
当所述目标损失值不低于所述预设阈值时,将执行模型参数更新后的标签数据进行分类;
将分类后的标签数据继续执行所述损失计算的操作,得到新的目标损失值;
对新的目标损失值继续执行判断操作,根据判断结果确定是否继续对所述多标签模型执行模型参数更新、分类及损失计算的操作,或终止更新所述模型参数。
即,在本发明的具体实施例中,电子设备20的处理器201执行计算机程序时实现上述多标签模型处理方法的步骤,由此能够使得标签的类别更为均衡,提升了多标签模型的准确率和性能。
需要说明的是,由于电子设备20的处理器201执行计算机程序时实现上述多标签模型处理方法的步骤,因此上述多标签模型处理方法的所有实施例均适用于该电子设备20,且均能达到相同或相似的有益效果。
本发明实施例中提供的计算机可读存储介质,计算机可读存储介质上存储有计算机程序,该计算机程序被处理器执行时实现本发明实施例提供的多标签模型处理方法或应用端多标签模型处理方法的各个过程,且能达到相同的技术效果,为避免重复,这里不再赘述。
本领域普通技术人员可以理解实现上述实施例方法中的全部或部分流程,是可以通过计算机程序来指令相关的硬件来完成,所述的程序可存储于一计算机可读取存储介质中,该程序在执行时,可包括如上述各方法的实施例的流程。其中,所述的存储介质可为磁碟、光盘、只读存储记忆体(Read-Only Memory,ROM)或随机存取存储器(Random AccessMemory,简称RAM)等。
在本说明书的描述中,参考术语“一个实施例”、“一些实施例”、“示例”、“具体示例”、或“一些示例”等的描述意指结合该实施例或示例描述的具体特征、结构、材料或者特点包含于本发明的至少一个实施例或示例中。在本说明书中,对上述术语的示意性表述不一定指的是相同的实施例或示例。而且,描述的具体特征、结构、材料或者特点可以在任何的一个或多个实施例或示例中以合适的方式结合。
以上所述仅为本发明的优选实施例,并非因此限制本发明的专利范围,凡是在本发明的构思下,利用本发明说明书及附图内容所作的等效结构变换,或直接/间接运用在其他相关的技术领域均包括在本发明的专利保护范围内。
Claims (10)
1.一种多标签模型处理方法,其特征在于,包括:
对标签数据进行过采样处理;其中,所述标签数据中包括样本数据和标签,所述标签用于作为索引以指向对应的所述样本数据;
将过采样处理后的标签数据分类,以得到第一类标签数据和第二类标签数据;
分别对所述第一类标签数据和所述第二类标签数据进行损失计算,得到第一损失值和第二损失值;
根据所述第一损失值和所述第二损失值确定出目标损失值;
采用预设算法对多标签模型进行模型参数更新,并在所述目标损失值满足预设终止条件时,终止更新所述模型参数。
2.根据权利要求1所述的多标签模型处理方法,其特征在于,所述将过采样处理后的标签数据分类,以得到第一类标签数据和第二类标签数据包括:
确定所述标签数据中每个所述样本数据对应的标签;
判断每个所述样本数据对应的标签是否出现缺失;
当所述样本数据对应的标签没有缺失时,将所述没有缺失标签的样本数据确定为第一类标签数据;
当所述样本数据对应的标签出现缺失时,将所述出现缺失标签的样本数据确定为第二类标签数据。
3.根据权利要求2所述的多标签模型处理方法,其特征在于,所述分别对所述第一类标签数据和所述第二类标签数据进行损失计算,得到第一损失值和第二损失值包括:
确定所述第一类标签数据中的第一样本数据和第一类标签;
基于所述多标签模型对所述第一样本数据进行预测,得到第一预测结果;
将所述第一预测结果与所述第一类标签输入损失函数,以确定出所述第一类标签对应的第一损失值。
4.根据权利要求2所述的多标签模型处理方法,其特征在于,所述分别对所述第一类标签数据和所述第二类标签数据进行损失计算,得到第一损失值和第二损失值包括:
确定所述第二类标签数据中的第二样本数据和第二类标签;
采用第一增广策略对所述第二样本数据进行增广,得到第一增广数据;
基于所述多标签模型对所述第一增广数据进行预测,得到第二预测结果;
将所述第二预测结果转换为预定形式的伪标签;
将所述第二样本数据再次进行预测,并根据预测结果与所述伪标签确定出第二损失值。
5.根据权利要求4所述的多标签模型处理方法,其特征在于,所述将所述第二样本数据再次进行预测,并根据预测结果与所述伪标签确定出第二损失值包括:
采用第二增广策略对所述第二样本数据进行增广,得到第二增广数据;
基于所述多标签模型对所述第二增广数据进行预测,得到第三预测结果;
将所述第三预测结果和所述伪标签输入损失函数,以确定出所述第二类标签对应的第二损失值。
6.根据权利要求3所述的多标签模型处理方法,其特征在于,所述采用预设算法对多标签模型进行模型参数更新,并在所述目标损失值满足预设终止条件时,终止更新所述模型参数包括:
采用预设算法对多标签模型进行模型参数更新;其中,所述预设算法为随机梯度下降算法;
判断所述目标损失值是否低于预设阈值;
当所述目标损失值低于预设阈值时,确定终止更新所述模型参数。
7.根据权利要求6所述的多标签模型处理方法,其特征在于,所述采用预设算法对多标签模型进行模型参数更新,并在所述目标损失值满足预设终止条件时,终止更新所述模型参数,还包括:
当所述目标损失值不低于所述预设阈值时,将执行模型参数更新后的标签数据进行分类;
将分类后的标签数据继续执行所述损失计算的操作,得到新的目标损失值;
对新的目标损失值继续执行判断操作,根据判断结果确定是否继续对所述多标签模型执行模型参数更新、分类及损失计算的操作,或终止更新所述模型参数。
8.一种多标签模型处理装置,其特征在于,包括:
过采样处理模块,用于对标签数据进行过采样处理;其中,所述标签数据中包括样本数据和标签,所述标签用于作为索引以指向对应的所述样本数据;
分类模块,用于将过采样处理后的标签数据分类,以得到第一类标签数据和第二类标签数据;
计算模块,用于分别对所述第一类标签数据和所述第二类标签数据进行损失计算,得到第一损失值和第二损失值;
确定模块,用于根据所述第一损失值和所述第二损失值确定出目标损失值;
更新模块,用于采用预设算法对标签模型进行模型参数更新,并在所述目标损失值满足预设终止条件时,终止更新所述模型参数。
9.一种电子设备,包括存储器、处理器以及存储在所述存储器中并可在所述处理器上运行的计算机程序,其特征在于,所述处理器执行所述计算机程序时实现如权利要求1至7任一项所述的多标签模型处理方法的步骤。
10.一种计算机可读存储介质,所述计算机可读存储介质存储有计算机程序,其特征在于,所述计算机程序被处理器执行时实现如权利要求1至7任一项所述的多标签模型处理方法的步骤。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202111657515.5A CN114445656A (zh) | 2021-12-30 | 2021-12-30 | 多标签模型处理方法、装置、电子设备及存储介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202111657515.5A CN114445656A (zh) | 2021-12-30 | 2021-12-30 | 多标签模型处理方法、装置、电子设备及存储介质 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN114445656A true CN114445656A (zh) | 2022-05-06 |
Family
ID=81364880
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202111657515.5A Pending CN114445656A (zh) | 2021-12-30 | 2021-12-30 | 多标签模型处理方法、装置、电子设备及存储介质 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN114445656A (zh) |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117975171A (zh) * | 2024-03-29 | 2024-05-03 | 南京大数据集团有限公司 | 面向标签不完全和不平衡的多标签学习方法及系统 |
-
2021
- 2021-12-30 CN CN202111657515.5A patent/CN114445656A/zh active Pending
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117975171A (zh) * | 2024-03-29 | 2024-05-03 | 南京大数据集团有限公司 | 面向标签不完全和不平衡的多标签学习方法及系统 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
KR102641116B1 (ko) | 데이터 증강에 기초한 인식 모델 트레이닝 방법 및 장치, 이미지 인식 방법 및 장치 | |
US8331655B2 (en) | Learning apparatus for pattern detector, learning method and computer-readable storage medium | |
CN111310814A (zh) | 利用不平衡正负样本对业务预测模型训练的方法及装置 | |
KR20210062687A (ko) | 이미지 분류 모델 훈련 방법, 이미지 처리 방법 및 장치 | |
CN111079780A (zh) | 空间图卷积网络的训练方法、电子设备及存储介质 | |
JP6926934B2 (ja) | 分類タスクの複雑度の評価装置及び方法 | |
CN114492279A (zh) | 一种模拟集成电路的参数优化方法及系统 | |
CN113537630A (zh) | 业务预测模型的训练方法及装置 | |
CN111046949A (zh) | 一种图像分类方法、装置及设备 | |
CN112966754A (zh) | 样本筛选方法、样本筛选装置及终端设备 | |
CN112420125A (zh) | 分子属性预测方法、装置、智能设备和终端 | |
CN113449840A (zh) | 神经网络训练方法及装置、图像分类的方法及装置 | |
CN111901594A (zh) | 面向视觉分析任务的图像编码方法、电子设备及介质 | |
CN114445656A (zh) | 多标签模型处理方法、装置、电子设备及存储介质 | |
US20240095529A1 (en) | Neural Network Optimization Method and Apparatus | |
CN113887655A (zh) | 模型链回归预测方法、装置、设备及计算机存储介质 | |
KR102152081B1 (ko) | 딥러닝 기반의 가치 평가 방법 및 그 장치 | |
CN110059743B (zh) | 确定预测的可靠性度量的方法、设备和存储介质 | |
CN115713669B (zh) | 一种基于类间关系的图像分类方法、装置、存储介质及终端 | |
CN114282684A (zh) | 训练用户相关的分类模型、进行用户分类的方法及装置 | |
CN111523308B (zh) | 中文分词的方法、装置及计算机设备 | |
CN113934813A (zh) | 一种样本数据划分的方法、系统、设备及可读存储介质 | |
CN112861601A (zh) | 生成对抗样本的方法及相关设备 | |
CN116912920B (zh) | 表情识别方法及装置 | |
US20220398833A1 (en) | Information processing device, learning method, and recording medium |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |