CN116912638B - 一种多数据集的联合训练方法及终端 - Google Patents
一种多数据集的联合训练方法及终端 Download PDFInfo
- Publication number
- CN116912638B CN116912638B CN202311175320.6A CN202311175320A CN116912638B CN 116912638 B CN116912638 B CN 116912638B CN 202311175320 A CN202311175320 A CN 202311175320A CN 116912638 B CN116912638 B CN 116912638B
- Authority
- CN
- China
- Prior art keywords
- data
- training
- marking
- model
- original data
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
- 238000012549 training Methods 0.000 title claims abstract description 142
- 238000000034 method Methods 0.000 title claims abstract description 47
- 238000013145 classification model Methods 0.000 claims abstract description 82
- 238000003062 neural network model Methods 0.000 claims abstract description 33
- 238000012360 testing method Methods 0.000 claims description 35
- 230000006870 function Effects 0.000 claims description 28
- 238000004590 computer program Methods 0.000 claims description 8
- 238000004806 packaging method and process Methods 0.000 claims description 5
- 230000008569 process Effects 0.000 abstract description 10
- 238000012423 maintenance Methods 0.000 abstract description 2
- 239000013598 vector Substances 0.000 description 16
- 238000002372 labelling Methods 0.000 description 7
- 230000000694 effects Effects 0.000 description 4
- 230000004913 activation Effects 0.000 description 3
- 230000009286 beneficial effect Effects 0.000 description 3
- 238000004364 calculation method Methods 0.000 description 3
- 238000007796 conventional method Methods 0.000 description 2
- 239000003550 marker Substances 0.000 description 2
- 238000012545 processing Methods 0.000 description 2
- 101100269850 Caenorhabditis elegans mask-1 gene Proteins 0.000 description 1
- 238000004422 calculation algorithm Methods 0.000 description 1
- 230000008859 change Effects 0.000 description 1
- 238000007635 classification algorithm Methods 0.000 description 1
- 230000007423 decrease Effects 0.000 description 1
- 238000010586 diagram Methods 0.000 description 1
- 238000012854 evaluation process Methods 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 210000002569 neuron Anatomy 0.000 description 1
- 238000007781 pre-processing Methods 0.000 description 1
Landscapes
- Image Analysis (AREA)
Abstract
本发明提供的一种多数据集的联合训练方法及终端,通过一个预设神经网络模型上同时对多个不同的数据集上进行联合训练,得到最优分类模型,以此方式避免维护多个模型,减少模型推理次数,提高模型训练效率。同时本发明根据所有所述标记数据的类型总数构建掩膜数据,使得多个数据集在进行联合训练时,无需补充标注某一数据集中未标注数据,减少数据标注的工作量;并且避免在不同数据集上生成伪标签,屏蔽训练过程中未标注数据所带来的误差,提高多数据集的联合训练模型的精度。
Description
技术领域
本发明涉及图像处理技术领域,尤其涉及一种多数据集的联合训练方法及终端。
背景技术
在进行多标签分类时,其数据集通常是由一个样例和一个集合的标签所组成的样本,该样本可能同时属于多个类别,例如一张图片中同时含有行人、自行车、小汽车等多个目标,则在数据集A中对应的目标标签为行人,在数据集B中对应的目标标签为自行车,在数据集C中对应的目标标签为小汽车,此时数据集A、B、C中的图像虽然同时包含行人、自行车以及小汽车三个目标,但是在进行标注时,每个数据集仅标注该数据集当前关注的目标。而目前实现多数据集的联合训练方法主要包括以下三种方式:
常规的方法:分别在不同的数据集上训练对应的模型,并将模型串联起来进行部署,同一个目标需要依次在多个模型上进行推理,得到对应的推理结果,最后将全部结果合并得到最终的输出;但是这种方法需要维护多个模型,且同一目标需要进行多次推理,存在大量的重复计算。
使用伪标签的方法:先使用大模型分别在不同的数据集上训练对应的分类模型(例如A模型、B模型、C模型等),然后使用训练好的大模型在其他未标注对应属性的数据上进行分类,生成伪标签;最后将标注的标签和生成的伪标签合并,即将多个数据集合并为一个数据集,再进行最终的分类模型训练,得到最终的联合分类模型;但是这种方法训练得到的大模型精度不是100%准确,在各个数据集全部生成伪标签以后,伪标签的数量远大于标注标签的数量,导致在最终的模型训练时,放大精度误差,影响最终输出的联合分类模型的精度。
半监督的训练方法:先使用半监督的方法进行模型训练,逐步增加未标注的数据,并生成对应的伪标签,将标注的标签和生成的伪标签合并,得到一个新的模型;然后再增加部分未标注数据集,生成伪标签,合并数据集,训练得到新模型;经过多次的迭代后得到最终的分类模型;但是这种方法生成的伪标签精度不可控,在数据集数量太大的情况下,伪标签的质量会严重影响最终输出的联合分类模型的精度。
发明内容
本发明所要解决的技术问题是:提供一种多数据集的联合训练方法及终端,无需维护多个模型,也无需生成伪标签,有效提高联合训练精度。
为了解决上述技术问题,本发明采用的技术方案为:
一种多数据集的联合训练方法,包括:
获取多个不同的数据集;所述数据集包括原始数据以及所述原始数据对应的标记数据;每一所述数据集对应一个标记数据的类型集合;不同数据集的所述类型集合不同;
根据所有所述标记数据的类型总数构建每一所述原始数据对应的掩膜数据;所述掩膜数据标识所述标记数据的类型在所述原始数据对应的标记数据中是否存在;
根据所述原始数据、原始数据对应的标记数据以及原始数据对应的掩膜数据构建训练数据集,并根据所述训练数据集训练预设神经网络模型,得到分类模型。
为了解决上述技术问题,本发明采用的另一种技术方案为:
一种多数据集的联合训练终端,包括存储器、处理器及存储在所述存储器上并在所述处理器上运行的计算机程序,所述处理器执行所述计算机程序时实现上述一种多数据集的联合训练方法中的各个步骤。
本发明的有益效果在于:通过一个预设神经网络模型上同时对多个不同的数据集上进行联合训练,得到最优分类模型,以此方式避免维护多个模型,减少模型推理次数,提高模型训练效率。同时本发明根据所有所述标记数据的类型总数构建掩膜数据,使得多个数据集在进行联合训练时,无需补充标注某一数据集中未标注数据,减少数据标注的工作量;并且避免在不同数据集上生成伪标签,屏蔽训练过程中未标注数据所带来的误差,提高多数据集的联合训练模型的精度。
附图说明
图1为本发明实施例提供的一种多数据集的联合训练方法的步骤流程图;
图2为本发明实施例提供的一种多数据集的联合训练方法的程序流程图;
图3为本发明实施例提供的一种多数据集的联合训练终端的结构示意图;
标号说明:
301、存储器;302、处理器。
具体实施方式
为详细说明本发明的技术内容、所实现目的及效果,以下结合实施方式并配合附图予以说明。
请参照图1,本发明实施例提供了一种多数据集的联合训练方法,包括:
获取多个不同的数据集;所述数据集包括原始数据以及所述原始数据对应的标记数据;每一所述数据集对应一个标记数据的类型集合;不同数据集的所述类型集合不同;
根据所有所述标记数据的类型总数构建每一所述原始数据对应的掩膜数据;所述掩膜数据标识所述标记数据的类型在所述原始数据对应的标记数据中是否存在;
根据所述原始数据、原始数据对应的标记数据以及原始数据对应的掩膜数据构建训练数据集,并根据所述训练数据集训练预设神经网络模型,得到分类模型。
从上述描述可知,本发明的有益效果在于:通过一个预设神经网络模型上同时对多个不同的数据集上进行联合训练,得到最优分类模型,以此方式避免维护多个模型,减少模型推理次数,提高模型训练效率。同时本发明根据所有所述标记数据的类型总数构建掩膜数据,使得多个数据集在进行联合训练时,无需补充标注某一数据集中未标注数据,减少数据标注的工作量;并且避免在不同数据集上生成伪标签,屏蔽训练过程中未标注数据所带来的误差,提高多数据集的联合训练模型的精度。
进一步的,所述根据所述训练数据集训练预设神经网络模型,得到分类模型,具体为:
将所述原始数据输入预设神经网络模型进行分类,得到模型预测结果;
根据所述模型预测结果、原始数据对应的标记数据以及原始数据对应的掩膜数据计算损失函数值;
根据所述损失函数值更新所述预设神经网络模型的参数,得到迭代中的待选分类模型;
从所有所述待选分类模型中确定分类模型。
由上述描述可知,通过模型预测结果和标注数据计算损失函数值,以更新预设神经网络模型的参数,得到优化后的初始分类模型;同时通过掩膜数据屏蔽原始数据中不存在的标记数据所造成的误差,使其无法影响到预设神经网络模型参数更新的过程,以此提高联合训练模型的精度。
进一步的,所述得到迭代中的待选分类模型之后,还包括:
判断所述训练数据集是否完成预设次数的训练,若否,则返回执行根据损失函数值更新所述预设神经网络模型的参数的步骤,得到多个迭代中的待选分类模型。
由上述描述可知,将训练数据集经过多次训练,从而得到多个迭代后的分类模型,以此提高分类模型的分类精确度,优化分类模型的训练效果。
进一步的,所述根据所述训练数据集训练预设神经网络模型之前,还包括:
将所述训练数据集分为训练集和测试集;
根据所述训练集中的所述原始数据、原始数据对应的标记数据以及原始数据对应的掩膜数据训练所述预设神经网络模型,得到所述分类模型;
根据所述测试集评估所述分类模型的精确度。
由上述描述可知,当在训练集上训练完成后,通过测试集验证每一个分类模型的分类精确度,以此评估各个分类模型的优劣,从而保证联合训练得到的分类模型的训练效果。
进一步的,根据所述测试集评估所述分类模型的精确度,具体为:
通过所述分类模型预测所述测试集中的每一测试数据得到模型预测结果;
获取每一所述测试数据对应的标记数据;
逐一判断所述测试数据对应的模型预测结果是否与所述标记数据相同,且所述标记数据对应的掩膜数据值是否为预设数值,若结果相同且所述掩膜数据值为预设数值,则预测正确计数增加1;
计算所述预测正确计数与所述测试集对应的所有所述标记数据的类型总数之间的比值,作为所述分类模型对应的精确度。
由上述描述可知,通过判断模型预测结果与标记数据是否相同,从而判断分类模型的标记正确个数,再计算标记正确的个数与所有标记数据的类型总数的比值作为分类模型的精确度,保证分类模型的输出维度为所有数据集的所有类型,同时简化模型评估过程,提高训练效率。
进一步的,还包括:
若所述训练数据集已完成预设次数的训练,则根据每一所述待选分类模型对应的精确度确定分类模型。
由上述描述可知,通过选择精确度最高的分类模型作为最优分类模型,以保证训练效果最佳。
进一步的,所述根据所述模型预测结果、原始数据对应的标记数据以及原始数据对应的掩膜数据计算损失函数值,具体为:
计算每一所述原始数据对应的标记数据与所述模型预测结果之间的二元交叉熵;
根据所述掩膜数据判断所述标记数据的类型在所述原始数据对应的标记数据中是否存在,若不存在则将所述标记数据对应的所述二元交叉熵置为0,得到有效二元交叉熵;
计算每一所述原始数据对应的所述有效二元交叉熵之和得到损失函数值。
由上述描述可知,通过掩膜数据屏蔽训练数据集中不存在的标记数据带来的计算损失,在保证联合训练的一个分类模型能够同时输出数据集中所有标记数据的同时,无需生成伪标签,有效避免错误伪标签所带来的精度误差;从而达到在单模型上联合训练多个不同数据集的目的,实现多个分类模型的合并训练,提高了模型训练效率。
进一步的,所述根据所述原始数据、原始数据对应的标记数据以及原始数据对应的掩膜数据构建训练数据集,具体为:
将每一所述原始数据、原始数据对应的标记数据以及原始数据对应的掩膜数据作为一条所述原始数据对应的记录
打包所有所述原始数据对应的所述记录构建训练数据集。
由上述描述可知,以此保证在模型训练过程中,能够快速且准确获取对应数据进行训练。
进一步的,所述根据所有所述标记数据的类型总数构建每一所述原始数据对应的掩膜数据,具体为:
每一所述原始数据对应的掩膜数据的单位数等于所有所述标记数据的类型总数。
由上述描述可知,以此实现一个分类模型能够同时输出数据集中所有标记数据。
请参照图3,本发明另一实施例提供了一种多数据集的联合训练终端,包括存储器、处理器及存储在所述存储器上并在所述处理器上运行的计算机程序,所述处理器执行所述计算机程序时实现上述的一种多数据集的联合训练方法中的各个步骤。
从上述描述可知,本发明的有益效果在于:通过一个预设神经网络模型上同时对多个不同的数据集上进行联合训练,得到最优分类模型,以此方式避免维护多个模型,减少模型推理次数,提高模型训练效率。同时本发明根据所有所述标记数据的类型总数构建掩膜数据,使得多个数据集在进行联合训练时,无需补充标注某一数据集中未标注数据,减少数据标注的工作量;并且避免在不同数据集上生成伪标签,屏蔽训练过程中未标注数据所带来的误差,提高多数据集的联合训练模型的精度。
本发明实施例提供了一种多数据集的联合训练方法及终端,可应用于多标签分类场景下,无需维护多个模型,也无需生成伪标签,有效提高多数据集的联合训练精度,以下通过具体实施例来说明:
请参照图1至图2,本发明的实施例一为:
一种多数据集的联合训练方法,包括:
S1、获取多个不同的数据集;所述数据集包括原始数据以及所述原始数据对应的标记数据;每一所述数据集对应一个标记数据的类型集合;不同数据集的所述类型集合不同。
需要说明的是,所述数据集为图像数据集,所述原始数据为图像数据,所述标记数据为所述图像数据中被标注出的属性。
在本实施例中,获取数据集A、数据集B以及数据集C。数据集A、B、C的原始数据均为人物的全身图像;数据集A标记数据的类型集合为{a1,a2,a3},其类型集合表示人物图像的3个不同属性;数据集B标记数据的类型集合为{b1,b2,b3,b4,b5},其类型集合表示人物图像的5个不同属性;数据集C标记数据的类型集合为{c1,c2,c3、c4},其类型集合表示人物图像的4个不同属性。其中,数据集A、B、C中各有数十万个标注数据,彼此标注的原始数据并不相同,且关注的属性也各不相同。
S2、根据所有所述标记数据的类型总数构建每一所述原始数据对应的掩膜数据;所述掩膜数据标识所述标记数据的类型在所述原始数据对应的标记数据中是否存在。
需要说明的是,所述掩膜数据为用于标识图像数据中所有的属性,并区分已标注的属性和未标注的属性。
所述步骤S2中:根据所有所述标记数据的类型总数构建每一所述原始数据对应的掩膜数据,具体为:
S21、每一所述原始数据对应的掩膜数据的单位数等于所有所述标记数据的类型总数。
在本实施例中,所有标记数据的类型总数为12,即数据集A标记数据的3个不同属性+数据集B标记数据的5个不同属性+数据集C标记数据的4个不同属性=所有标记数据的类型总数12。
在本实施例中,每一原始数据对应的掩膜数据的单位数等于12,即每一原始数据对应的掩膜数据为{a1,a2,a3,b2,b3,b4,b5,c1,c2,c3、c4};并且标记数据与掩膜数据中各个属性类别的排序方式也是相同,例如标记数据的第N个属性值为“是否佩戴安全帽”,那么掩膜数据的第N个属性值若为1,则代表该原始数据标注过“是否佩戴安全帽”,即该原始数据存在标记数据的类型为“是否佩戴安全帽”的标记数据;若为0则表示该原始数据没有标注这一属性,即该原始数据不存在标记数据的类型为“是否佩戴安全帽”的标记数据。
在本实施例中,所述掩膜数据标识所述标记数据的类型在所述原始数据对应的标记数据中是否存在具体为:将所述标记数据的类型存在的所述原始数据对应的标记数据设置为1,所述标记数据的类型不存在的所述原始数据对应的标记数据设置为0。例如,在数据集A中的原始数据对应的掩膜数据为{1,1,1,0,0,0,0,0,0,0,0,0},在数据集B中的原始数据对应的掩膜数据为{0,0,0,1,1,1,1,1,0,0,0,0},在数据集C中的原始数据对应的掩膜数据为{0,0,0,0,0,0,0,0,1,1,1,1}。
S3、根据所述原始数据、原始数据对应的标记数据以及原始数据对应的掩膜数据构建训练数据集,并根据所述训练数据集训练预设神经网络模型,得到分类模型。
所述步骤S3中:根据所述原始数据、原始数据对应的标记数据以及原始数据对应的掩膜数据构建训练数据集,具体为:
S31、将每一所述原始数据、原始数据对应的标记数据以及原始数据对应的掩膜数据作为一条所述原始数据对应的记录。
S32、打包所有所述原始数据对应的所述记录构建训练数据集。
在本实施例中,即一批数据集中的数据经过预处理后,每一条原始数据image1可得到对应的标记数据label1以及掩膜数据mask1;打包该批数据集所有原始数据对应的标记数据以及掩膜数据构建训练数据集。
所述步骤S3中:根据所述训练数据集训练预设神经网络模型,得到分类模型,具体为:
S33、将所述原始数据输入预设神经网络模型进行分类,得到模型预测结果。
在本实施例中,将原始数据集合images1输入预设神经网络模型进行分类,得到模型预测结果为preds1,其中,原始数据集合images1的数据数量为N,预设神经网络模型分类的类别数量为12,则原始数据集合images1、模型预测结果preds1、标记数据集合labels1以及掩膜数据集合masks1的数据形式如下:
原始数据集合images1的数据形式:N*channels*height*width的多维向量,channels为图像的通道数,通常情况下为3;
模型预测结果preds1、标记数据集合labels1以及掩膜数据集合masks1的数据形式均为N*12的多维向量。
S34、根据所述模型预测结果、原始数据对应的标记数据以及原始数据对应的掩膜数据计算损失函数值。
所述步骤S34,具体为:
S341、计算每一所述原始数据对应的标记数据与所述模型预测结果之间的二元交叉熵。
需要说明的是,所有原始数据对应的二元交叉熵以一个原始数据总数×标记数据的类型总数的多维向量进行表示。
在本实施例中,该原始数据集合images1对应的二元交叉熵结果以多维向量的形式进行表示,则二元交叉熵结果具体为loss_1(N*12);需要说明的是,(N*12)表示loss_1为N*12的多维向量。
S342、根据所述掩膜数据判断所述标记数据的类型在所述原始数据对应的标记数据中是否存在,若不存在则将所述标记数据对应的所述二元交叉熵置为0,得到有效二元交叉熵。
在本实施例中,若不存在则将所述标记数据对应的二元交叉熵置为0,具体为:将二元交叉熵结果集合loss_1(N*12)与原始数据集合images1对应的掩膜数据集合masks1相乘,即loss_1(N*12)*masks1,得到有效二元交叉熵集合为loss_2(N*12)。
S343、计算每一所述原始数据对应的所述有效二元交叉熵之和得到损失函数值。
在本实施例中,步骤S343具体为计算有效二元交叉熵集合loss_2(N*12)中所有元素的总和,得到损失函数值。
在本实施例中,以多标签分类算法所使用的BCEWithLogitsLoss为例,损失函数值Loss为:
Loss=torch.sum(torch.binary_cross_entropy_with_logits(preds1,labels1)*masks1)。
其中,torch.sum是用来计算多维向量所有元素之和;
torch.binary_cross_entropy_with_logits是带激活函数的二元交叉熵计算方法,主要用于计算多标签分类的算法的损失。即上述步骤S341、步骤342以及步骤S343对应为:
1)loss_1=torch.binary_cross_entropy_with_logits(preds1,labels1),用于计算模型预测结果与标记数据的二元交叉熵。
2)loss_2=loss_1*masks1,用于将掩膜数据集合masks1中为0的部分对应的二元交叉熵归零,得到有效二元交叉熵。
3)loss_3=torch.sum(loss_2),有效二元交叉熵求和。
S35、根据所述损失函数值更新所述预设神经网络模型的参数,得到迭代中的待选分类模型。
在本实施例中,所述S35具体为通过调用loss.backward()函数,进行梯度反向传播,使得预设神经网络模型的参数向着损失函数值减小的方向变化。
需要说明的是,反向传播是指模型训练在正向传播过程中,输入信息通过逐层处理并传向输出层。如果在输出层得不到期望的输出值,则通过构造输出值与真实值的损失函数作为目标函数,转入反向传播,逐层求出目标函数对各神经元权值的偏导数,构成目标函数对权值向量的梯度,作为修改权值的依据,网络的学习在权值修改过程中完成。输出值与真实值的误差达到所期望值时,网络学习结束。
S36、从所有所述待选分类模型中确定分类模型。
所述步骤S3中:根据所述训练数据集训练预设神经网络模型之前,还包括:
S301、将所述训练数据集分为训练集和测试集。
S302、根据所述训练集中的所述原始数据、原始数据对应的标记数据以及原始数据对应的掩膜数据训练所述预设神经网络模型,得到所述分类模型。
S303、根据所述测试集评估所述分类模型的精确度。
所述步骤S303,具体为:
S3031、通过所述分类模型预测所述测试集中的每一测试数据得到模型预测结果;
S3032、获取每一测试数据对应的标记数据;
S3033、逐一判断所述测试数据对应的模型预测结果是否与所述标记数据相同,且所述标记数据对应的掩膜数据值是否为预设数值,若结果相同且所述掩膜数据值为预设数值,则预测正确计数增加1;
在本实施例中,所述预设数值为1。
S3034、计算所述预测正确计数与所述测试集对应的所有所述标记数据的类型总数之间的比值,作为所述分类模型对应的精确度。
在本实施例中,所述步骤S3034之后还包括:保存所述分类模型对应的参数,即保存所述分类模型对应的权重值。
在本实施例中,测试集中的测试数据集合images2的数据数量为M,则模型预测结果preds2、标记数据集合labels2和掩膜数据集合masks2均为M*12的多维向量。
在本实施例中,在步骤S3032之前,需要对模型预测结果preds2进行预处理,具体为:
通过激活函数将模型预测结果preds2(M*12)向量中元素的输出转化为(0,1)之间,得到第一结果数据preds2_1(M*12);
在通过设定预设阈值将第一结果数据preds2_1(M*12)进行二值化,使其向量中的每个元素转换为0或1,得到第二结果数据preds2_2(M*12)。
在本实施例中,所述步骤S3032-步骤S3034具体为:
通过逻辑非运算和逻辑异或运算逐一计算第二结果数据preds2_2(M*12)向量中的元素与标记数据2labels(M*12)向量中的对应元素是否相等,得到第三结果数据preds2_true(M*12);
再通过逻辑与运算,逐一计算第三结果数据preds2_true(M*12)向量中的元素与掩膜数据masks2(M*12)向量中的对应元素是否相等,得到标记数据中模型预测结果预测正确的数据preds2_true_mask(M*12),即预测正确计数通过多维向量的数据形式进行表示。
最后计算preds2_true_mask(M*12)向量的元素总和与掩膜数据masks2(M*12)向量的元素总和之间的比值,作为分类模型的精确度。
所述步骤S3之后,还包括:
S4、判断所述训练数据集是否完成预设次数的训练,若否,则返回执行根据损失函数值更新所述预设神经网络模型的参数的步骤,得到多个迭代中的待选分类模型。
该方法还包括:
S5、若所述训练数据集已完成预设次数的训练,则根据每一所述待选分类模型对应的精确度确定分类模型。
请参照图2,本发明的实施例二为:
实施例一所述的一种多数据集的联合训练方法应用于实际场景中,包括:
步骤A1、获取数据集A、B。
其中,数据集A关注人物的着装属性(标记数据共有19个类别,302559个样本),数据集B关注人物的帽子和衣物的颜色属性(标记数据共有8个类别,322037个样本)。需要说明的是,所述样本即为原始数据。
将数据集A和B合并后按照9:1的比例随机划分为训练集和测试集;其中,数据集A的训练集样本272303个,测试集样本30256个,数据集B的训练集样本296613个,测试集样本25424个。
步骤A2、构建每一样本对应的掩膜数据。
将数据集A、B的类别总数扩充到27(19+8)类,其中前面19类对应数据集A的19个类别,后面8类对应数据集B的8个类别;将数据集A的每一个样本的后面8类对应的标签设置为0,数据集B的每一个样本的前面19类对应的标签设置为0。
为每一个数据集的每一个样本设置对应的掩膜数据mask,每一个样本对应的掩膜数据mask的长度为27(即数据集A、B所有标记数据的类别总数);数据集A的每一个样本的掩膜数据mask前面19位设置为1,后面8位设置为0;数据集B的每一个样本的掩膜数据mask前面19位设置为0,后面8位设置为1。
步骤A3、打包每一个样本、样本对应的标记数据以及掩膜数据作为每一个样本对应的记录,确保每一批次数据集获取到的数据集合为样本集合images、标记数据集合labels以及掩膜数据集合masks。
步骤A4、构建损失函数为:
Loss=torch.sum(torch.binary_cross_entropy_with_logits(preds1,labels1)*masks1)。
步骤A5、评估分类模型的精确度。
其中,数据集A和数据集B中有效标签的总数=数据集A所对应的测试集的样本数量乘以数据集A的类别数(即30256*19)+数据集B所对应的测试集的样本数量乘以数据集B的类别数(即25424*8)=778256个,记为Mask_Sum。数据集A和数据集B中分类模型标记正确的标签数量=测试集样本的模型预测结果preds2经过激活函数以及二值化处理后与标记数据labels2相等的样本个数,记为Correct_Num。
分类模型的精确度accuracy=Correct_Num/Mask_Sum。
步骤A6、将训练集输入预设神经网络模型进行分类训练,共训练100次,选取精确度accuracy最高的分类模型作为最优分类模型。
在本实施例中,数据集A的测试集共包含30256*19=574864个标记数据,数据集B的测试集共包含25424*8=203392个标记数据,则测试集合计778256个标记数据。通过本实施例的联合训练方法得到的最优分类模型,其模型预测结果的正确个数为765036个,由此得到最优分类模型的精确度为765036/778256=0.983。
而基于本实施例的数据集A与数据集B,通过背景技术中所述《常规方法》得到的最优分类模型进行推理,其数据集A对应的模型预测结果的正确个数为559521个,数据集B对应的模型预测结果的正确个数为202314个,则数据集A与数据集B模型预测结果的正确个数合计761835个,其模型的精确度为0.979;由此可知,本发明在保证分类模型精确度的同时,仅通过单模型即可实现多个数据集的模型训练,得到最优分类模型,在部署时只需要推理一次即可得到最终的分类结果,提高了联合训练效率,还对模型的整体精度有一定的改善效果。
请参照图3,本发明的实施例三为:
一种多数据集的联合训练终端,包括存储器301、处理器302及存储在所述存储器301上并在所述处理器302上运行的计算机程序,所述处理器302执行所述计算机程序时实现实施例一以及实施例二所述的一种多数据集的联合训练方法中的各个步骤。
综上所述,本发明提供的一种多数据集的联合训练方法及终端,通过一个预设神经网络模型上同时对多个不同的数据集上进行联合训练,得到最优分类模型,以此方式避免维护多个模型,减少模型推理次数,提高模型训练效率。同时本发明根据所有所述标记数据的类型总数构建掩膜数据,使得多个数据集在进行联合训练时,无需补充标注某一数据集中未标注数据,减少数据标注的工作量;并且避免在不同数据集上生成伪标签;此外,相较于常规的图像分类算法,本发明仅在构建训练数据集阶段增加了掩膜数据用于标识标注数据的类型是否存在,使得损失函数值计算过程中能够屏蔽训练过程中不存在的标注数据所带来的误差,提高多数据集的联合训练模型的精度,且训练框架简单易实现,工作原理简洁清晰。
以上所述仅为本发明的实施例,并非因此限制本发明的专利范围,凡是利用本发明说明书及附图内容所作的等同变换,或直接或间接运用在相关的技术领域,均同理包括在本发明的专利保护范围内。
Claims (5)
1.一种多数据集的联合训练方法,其特征在于,包括:
获取多个不同的数据集;所述数据集包括原始数据以及所述原始数据对应的标记数据;每一所述数据集对应一个标记数据的类型集合;不同数据集的所述类型集合不同;
根据所有所述标记数据的类型总数构建每一所述原始数据对应的掩膜数据;所述掩膜数据标识所述标记数据的类型在所述原始数据对应的标记数据中是否存在;
根据所述原始数据、原始数据对应的标记数据以及原始数据对应的掩膜数据构建训练数据集,并根据所述训练数据集训练预设神经网络模型,得到分类模型;
所述根据所述训练数据集训练预设神经网络模型,得到分类模型,具体为:
将所述原始数据输入预设神经网络模型进行分类,得到模型预测结果;
根据所述模型预测结果、原始数据对应的标记数据以及原始数据对应的掩膜数据计算损失函数值;
根据所述损失函数值更新所述预设神经网络模型的参数,得到迭代中的待选分类模型;
从所有所述待选分类模型中确定分类模型;
所述根据所述模型预测结果、原始数据对应的标记数据以及原始数据对应的掩膜数据计算损失函数值,具体为:
计算每一所述原始数据对应的标记数据与所述模型预测结果之间的二元交叉熵;
根据所述掩膜数据判断所述标记数据的类型在所述原始数据对应的标记数据中是否存在,若不存在则将所述标记数据对应的所述二元交叉熵置为0,得到有效二元交叉熵;
计算每一所述原始数据对应的所述有效二元交叉熵之和得到损失函数值;
所述根据所述训练数据集训练预设神经网络模型之前,还包括:
将所述训练数据集分为训练集和测试集;
根据所述训练集中的所述原始数据、原始数据对应的标记数据以及原始数据对应的掩膜数据训练所述预设神经网络模型,得到所述分类模型;
根据所述测试集评估所述分类模型的精确度;
所述根据所述测试集评估所述分类模型的精确度,具体为:
通过所述分类模型预测所述测试集中的每一测试数据得到模型预测结果;
获取每一所述测试数据对应的标记数据;
逐一判断所述测试数据对应的模型预测结果是否与所述标记数据相同,且所述标记数据对应的掩膜数据值是否为预设数值,若结果相同且所述掩膜数据值为预设数值,则预测正确计数增加1;
计算所述预测正确计数与所述测试集对应的所有所述标记数据的类型总数之间的比值,作为所述分类模型对应的精确度;
若所述训练数据集已完成预设次数的训练,则根据每一所述待选分类模型对应的精确度确定分类模型。
2.根据权利要求1所述的一种多数据集的联合训练方法,其特征在于,所述得到迭代中的待选分类模型之后,还包括:
判断所述训练数据集是否完成预设次数的训练,若否,则返回执行根据损失函数值更新所述预设神经网络模型的参数的步骤,得到多个迭代中的待选分类模型。
3.根据权利要求1所述的一种多数据集的联合训练方法,其特征在于,所述根据所述原始数据、原始数据对应的标记数据以及原始数据对应的掩膜数据构建训练数据集,具体为:
将每一所述原始数据、原始数据对应的标记数据以及原始数据对应的掩膜数据作为一条所述原始数据对应的记录;
打包所有所述原始数据对应的所述记录构建训练数据集。
4.根据权利要求1所述的一种多数据集的联合训练方法,其特征在于,所述根据所有所述标记数据的类型总数构建每一所述原始数据对应的掩膜数据,具体为:
每一所述原始数据对应的掩膜数据的单位数等于所有所述标记数据的类型总数。
5.一种多数据集的联合训练终端,包括存储器、处理器及存储在所述存储器上并在所述处理器上运行的计算机程序,其特征在于,所述处理器执行所述计算机程序时实现如权利要求1-4任意一项所述的一种多数据集的联合训练方法中的各个步骤。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202311175320.6A CN116912638B (zh) | 2023-09-13 | 2023-09-13 | 一种多数据集的联合训练方法及终端 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202311175320.6A CN116912638B (zh) | 2023-09-13 | 2023-09-13 | 一种多数据集的联合训练方法及终端 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN116912638A CN116912638A (zh) | 2023-10-20 |
CN116912638B true CN116912638B (zh) | 2024-01-12 |
Family
ID=88358784
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202311175320.6A Active CN116912638B (zh) | 2023-09-13 | 2023-09-13 | 一种多数据集的联合训练方法及终端 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN116912638B (zh) |
Citations (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
EP4027300A1 (en) * | 2021-01-12 | 2022-07-13 | Fujitsu Limited | Apparatus, program, and method for anomaly detection and classification |
CN115730208A (zh) * | 2021-08-27 | 2023-03-03 | Oppo广东移动通信有限公司 | 训练方法、训练装置、训练设备及计算机可读存储介质 |
CN115861617A (zh) * | 2022-12-12 | 2023-03-28 | 中国工商银行股份有限公司 | 语义分割模型训练方法、装置、计算机设备和存储介质 |
Family Cites Families (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US11475280B2 (en) * | 2019-11-15 | 2022-10-18 | Disney Enterprises, Inc. | Data object classification using an optimized neural network |
US20220148189A1 (en) * | 2020-11-10 | 2022-05-12 | Nec Laboratories America, Inc. | Multi-domain semantic segmentation with label shifts |
-
2023
- 2023-09-13 CN CN202311175320.6A patent/CN116912638B/zh active Active
Patent Citations (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
EP4027300A1 (en) * | 2021-01-12 | 2022-07-13 | Fujitsu Limited | Apparatus, program, and method for anomaly detection and classification |
CN115730208A (zh) * | 2021-08-27 | 2023-03-03 | Oppo广东移动通信有限公司 | 训练方法、训练装置、训练设备及计算机可读存储介质 |
CN115861617A (zh) * | 2022-12-12 | 2023-03-28 | 中国工商银行股份有限公司 | 语义分割模型训练方法、装置、计算机设备和存储介质 |
Non-Patent Citations (1)
Title |
---|
基于模糊支持向量的多标签分类方法;郑文博等;广西大学学报(自然科学版);第36卷(第5期);第758-763页 * |
Also Published As
Publication number | Publication date |
---|---|
CN116912638A (zh) | 2023-10-20 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
Sun et al. | Evolving deep convolutional neural networks for image classification | |
WO2021164382A1 (zh) | 针对用户分类模型进行特征处理的方法及装置 | |
US11657267B2 (en) | Neural network apparatus, vehicle control system, decomposition device, and program | |
Huang et al. | Unsupervised domain adaptation with background shift mitigating for person re-identification | |
Chen et al. | Hydra: Hypergradient data relevance analysis for interpreting deep neural networks | |
CN112651418B (zh) | 数据分类方法、分类器训练方法及系统 | |
Liu et al. | A lightweight and accurate recognition framework for signs of X-ray weld images | |
Chaudhuri et al. | Functional criticality analysis of structural faults in AI accelerators | |
JP2021051589A5 (zh) | ||
US20220156519A1 (en) | Methods and systems for efficient batch active learning of a deep neural network | |
CN116912638B (zh) | 一种多数据集的联合训练方法及终端 | |
CN112949778A (zh) | 基于局部敏感哈希的智能合约分类方法、系统及电子设备 | |
Shahriyar et al. | An approach for multi label image classification using single label convolutional neural network | |
CN108830302B (zh) | 一种图像分类方法、训练方法、分类预测方法及相关装置 | |
CN116342906A (zh) | 一种跨域小样本图像识别方法及系统 | |
Adiwardana et al. | Using generative models for semi-supervised learning | |
Miki et al. | Weakly supervised graph convolutional neural network for human action localization | |
Khullar et al. | Investigating efficacy of transfer learning for fruit classification | |
CN111127485B (zh) | 一种ct图像中目标区域提取方法、装置及设备 | |
CN114638845A (zh) | 一种基于双阈值的量子图像分割方法、装置及存储介质 | |
Keshinro | Image Detection and Classification: A Machine Learning Approach | |
Mahdavi et al. | Informed Decision-Making through Advancements in Open Set Recognition and Unknown Sample Detection | |
CN111753992A (zh) | 筛选方法和筛选系统 | |
KR20190050230A (ko) | 피쳐 영향 판단 방법 및 그 시스템 | |
KR102073020B1 (ko) | 피쳐 영향 판단 방법 및 그 시스템 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |