CN111144451B - 一种图像分类模型的训练方法、装置及设备 - Google Patents
一种图像分类模型的训练方法、装置及设备 Download PDFInfo
- Publication number
- CN111144451B CN111144451B CN201911264127.3A CN201911264127A CN111144451B CN 111144451 B CN111144451 B CN 111144451B CN 201911264127 A CN201911264127 A CN 201911264127A CN 111144451 B CN111144451 B CN 111144451B
- Authority
- CN
- China
- Prior art keywords
- sample set
- image sample
- image
- training
- classification model
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
- G06F18/2155—Generating training patterns; Bootstrap methods, e.g. bagging or boosting characterised by the incorporation of unlabelled data, e.g. multiple instance learning [MIL], semi-supervised techniques using expectation-maximisation [EM] or naïve labelling
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/25—Fusion techniques
- G06F18/251—Fusion techniques of input or preprocessed data
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02T—CLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
- Y02T10/00—Road transport of goods or passengers
- Y02T10/10—Internal combustion engine [ICE] based vehicles
- Y02T10/40—Engine management systems
Landscapes
- Engineering & Computer Science (AREA)
- Data Mining & Analysis (AREA)
- Theoretical Computer Science (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Bioinformatics & Computational Biology (AREA)
- Artificial Intelligence (AREA)
- Evolutionary Biology (AREA)
- Evolutionary Computation (AREA)
- Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Image Analysis (AREA)
Abstract
本申请公开一种图像分类模型的训练方法、装置及设备,所述方法包括:基于有标签图像样本集和无标签图像样本集,构建融合图像样本集;对所述融合图像样本集中的各个图像样本进行增量处理,得到增量图像样本集;通过基于所述融合图像样本集与所述增量图像样本集计算一致性损失量的方式,对经过有监督学习训练的图像分类模型进行半监督迭代训练,得到经过训练的图像分类模型。本申请能够在有监督学习训练的基础上,基于一致性损失量继续对图像分类模型模型进行半监督迭代训练,最终得到训练精度较高的图像分类模型。
Description
技术领域
本申请涉及机器学习领域,具体涉及一种图像分类模型的训练方法、装置及设备。
背景技术
深度学习已经广泛应用于各个领域,在训练大型深层神经网络方面之所以能取得成功,很大程度上要归功于大量有标签数据集的存在。对于某些领域,如病理图像的分类领域,收集有标签图像样本的成本较高,而且非常耗时,通常需要从多个专家的结论中得出有标签图像样本,相比之下,无标签图像样本更容易获取。
半监督学习是有监督学习与无监督学习相结合的一种学习方法。半监督学习在使用有标签数据的同时,大量使用无标签数据,在很大程度上缓解了对有标签数据的大量需求。因此,半监督学习目前正越来越受到人们的重视。
目前,对于图像分类模型的训练一般采用半监督学习,具体的,首先利用有标签数据集对图像分类模型进行有监督学习训练,得到预训练模型。然后,在有监督学习训练的基础上,对无标签数据进行处理,得到无标签数据的预测值,进而利用带有预测值的无标签数据继续对预训练模型进行训练,最终完成对模型的训练。但是,如果预训练模型对无标签数据的预测值是可信的,但实际是错误的,则错误的数据会被用于图像分类模型的训练,导致影响图像分类模型的训练精度。
发明内容
有鉴于此,本申请提供了一种图像分类模型的训练方法,能够在有监督学习训练的基础上,基于一致性损失量继续对图像分类模型模型进行半监督迭代训练,最终得到训练精度较高的图像分类模型。
第一方面,为实现上述发明目的,本申请提供了一种图像分类模型的训练方法,所述方法包括:
基于有标签图像样本集和无标签图像样本集,构建融合图像样本集;
对所述融合图像样本集中的各个图像样本进行增量处理,得到增量图像样本集;
通过基于所述融合图像样本集与所述增量图像样本集计算一致性损失量的方式,对经过有监督学习训练的图像分类模型进行半监督迭代训练,得到经过训练的图像分类模型。
一种可选的实施方式中,所述通过基于所述融合图像样本集与所述增量图像样本集计算一致性损失量的方式,对经过有监督学习训练的图像分类模型进行半监督迭代训练,得到经过训练的图像分类模型,包括:
基于所述融合图像样本集与所述增量图像样本集,计算一致性损失量;
基于所述一致性损失量,确定是否继续对所述图像分类模型进行半监督训练;如果是,则继续执行所述基于有标签图像样本集和无标签图像样本集,构建融合图像样本集的步骤,以对所述图像分类模型进行迭代训练;如果否,则输出经过训练的图像分类模型。
一种可选的实施方式中,所述基于所述一致性损失量,确定是否继续对经过有监督学习训练的图像分类模型进行半监督训练之前,还包括:
计算有监督学习损失量;
相应的,所述基于所述一致性损失量,确定是否继续对经过有监督学习训练的图像分类模型进行半监督训练,包括:
基于所述有监督学习损失量和所述一致性损失量,确定是否继续对经过有监督学习训练的图像分类模型进行半监督训练。
一种可选的实施方式中,所述基于所述融合图像样本集与所述增量图像样本集,计算一致性损失量,包括:
确定所述融合图像样本集和所述增量图像样本集分别对应的领域随机分布图像样本集;
基于所述融合图像样本集和所述增量图像样本集分别对应的领域随机分布图像样本集,计算一致性损失量。
一种可选的实施方式中,所述计算有监督学习损失量,包括:
对所述有标签图像样本集中的各个图像样本进行增量处理,得到有标签增量图像样本集;
确定所述有标签增量图像样本集对应的领域随机分布图像样本集;其中,所述领域随机分布图像样本集中的各个图像样本的标签为基于所述有标签增量图像样本集中对应的图像样本的标签确定;
基于所述有标签增量图像样本集对应的领域随机分布图像样本集,计算有监督学习损失量。
第二方面,本申请提供了一种图像分类模型的训练装置,所述装置包括:
构建模块,用于基于有标签图像样本集和无标签图像样本集,构建融合图像样本集;
增量模块,用于对所述融合图像样本集中的各个图像样本进行增量处理,得到增量图像样本集;
训练模块,用于通过基于所述融合图像样本集与所述增量图像样本集计算一致性损失量的方式,对经过有监督学习训练的图像分类模型进行半监督迭代训练,得到经过训练的图像分类模型。
一种可选的实施方式中,所述训练模块,包括:
第一计算子模块,用于基于所述融合图像样本集与所述增量图像样本集,计算一致性损失量;
第一确定子模块,用于基于所述一致性损失量,确定是否继续对所述图像分类模型进行半监督训练;
触发子模块,用于在所述第一确定子模块的结果为是时,触发所述构建模块,以对所述图像分类模型进行迭代训练;
输出子模块,用于在所述第一确定子模块的结果为否时,输出经过训练的图像分类模型。
一种可选的实施方式中,所述装置还包括:
计算模块,用于计算有监督学习损失量;
相应的,所述第一确定子模块,具体用于:
基于所述有监督学习损失量和所述一致性损失量,确定是否继续对经过有监督学习训练的图像分类模型进行半监督训练。
第三方面,本申请还提供了一种计算机可读存储介质,所述计算机可读存储介质中存储有指令,当所述指令在终端设备上运行时,使得所述终端设备实现如上述权利要求任一项所述的方法。
第四方面,本申请还提供了一种设备,包括:存储器,处理器,及存储在所述存储器上并可在所述处理器上运行的计算机程序,所述处理器执行所述计算机程序时,实现上述权利要求任一项所述的方法。
本申请提供的图像分类模型的训练方法中,预先采用有监督学习对图像分类模型进行训练,在此基础上,利用包含有标签图像样本和无标签图像样本的融合图像样本集,进一步对图像分类模型进行训练,具体的,采用计算一致性损失量的方式对图像分类模型进行迭代训练,最终得到经过训练的图像分类模型。由于对于无标记图像样本而言,一致性损失量的计算只是计算自身偏差量,不会将错误数据引入模型训练中,从而提高了图像分类模型的训练精度。
进一步的,本申请还可以采用领域风险最小化的方法,计算有监督学习损失量和一致性损失量,基于计算得到的有监督学习损失量与一致性损失量之和,对图像分类模型进行半监督迭代训练,进一步提高了图像分类模型的训练精度。
附图说明
为了更清楚地说明本申请实施例中的技术方案,下面将对实施例描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本申请的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动性的前提下,还可以根据这些附图获得其他的附图。
图1为本申请实施例提供的一种图像分类模型的训练方法流程图;
图2为本申请实施例提供的另一种图像分类模型的训练方法流程图;
图3为本申请实施例提供的一种图像分类模型的训练装置的结构示意图;
图4为本申请实施例提供的一种图像分类模型的训练设备的结构图。
具体实施方式
下面将结合本申请实施例中的附图,对本申请实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本申请一部分实施例,而不是全部的实施例。基于本申请中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本申请保护的范围。
在图像分类模型的训练领域,由于具有标签的图像样本的获取成本较高且非常耗时,而无标签的图像样本相对更容易获取,因此,基于数量相对不足的有标签图像样本和足够数量的无标签图像样本,如何在保证模型训练精度的前提下实现对图像分类模型的训练,是目前急需解决的问题。
半监督学习模型训练是有监督学习和无监督学习相结合的一种模型训练方法,由于半监督学习模型训练可以基于数量不足的有标签图像样本和足够数量的无标签图像样本,实现对图像分类模型的训练,因此,半监督学习越来越多的被用于图像分类模型的训练中。
现有的一种半监督学习模型训练方法中,首先利用有标签图像样本对图像分类模型进行有监督学习训练,得到预训练模型,然后利用预训练模型对无标签图像样本进行处理,得到各个无标签图像样本的预测值,进而利用具有预测值的无标签图像样本对预训练模型进行进一步训练,最终得到经过训练的图像分类模型。
但是,如果预训练模型对无标签图像样本的预测值是可信的,但实际是错误的,则错误的数据会被用于图像分类模型的训练中,导致影响图像分类模型的训练精度。因此,上述半监督学习模型训练方法存在问题。
基于此,本申请提供了一种图像分类模型的训练方法,预先采用有监督学习对图像分类模型进行训练,在此基础上,利用包含有标签图像样本和无标签图像样本的融合图像样本集,进一步对图像分类模型进行训练,具体的,采用计算一致性损失量的方式对图像分类模型进行迭代训练,最终得到经过训练的图像分类模型。由于对于无标记图像样本而言,一致性损失量的计算只是计算自身偏差量,不会将上述方法中的错误数据引入模型训练中,从而提高了图像分类模型的训练精度。
以下本申请提供了一种图像分类模型的训练方法,参考图1,为本申请实施例提供的一种图像分类模型的训练方法流程图,所述方法包括:
S101:基于有标签图像样本集和无标签图像样本集,构建融合图像样本集。
本申请实施例中,在获取到有标签图像样本集之后,首先基于有标签图像样本集对图像分类模型进行有监督学习训练,其中,图像分类模型通常为深度卷积神经网络模型。
具体的,在进行有监督学习训练之前,首先对有标签图像样本集中各个图像样本进行预处理,例如包括标准归一化处理、白化处理等,另外,为了丰富样本的多样化,还可以采用数据增量方法,如翻转、随机旋转、随机切割等数据增量方法对有标签图像样本集中的各个图像样本进行处理,得到增量图像样本,具体的,增量图像样本沿用原图像样本的标签,用于有监督学习模型训练中,以增加有标签图像样本的数量,提高有监督学习训练的精度。
本申请实施例中,在利用有标签图像样本集对图像分类模型进行有监督学习训练后,得到已完成初步训练的预训练模型,该预训练模型作为后续半监督学习模型训练的基础。
对于无标签图像样本集,为了保证用于图像分类模型训练的标签的有效性,本申请实施例可以从无标签图像样本中筛选出更适合用于图像分类模型训练的图像样本。具体的,利用已完成初步训练的预训练模型对无标签图像样本集中的各个图像样本进行预测,得到各个图像样本的预测结果,其中,预测结果包括图像样本所属的类别和概率值。由于概率值越高,则说明对应的样本图像属于对应类别的可信度越高,因此,本申请实施例可以将概率值大于预设可信度阈值(例如为80%)的图像样本筛选出来,用于后续的半监督学习模型训练中。
实际应用中,利用上述从无标签图像样本集中筛选出的无标签图像样本,与有标签图像样本集中的有标签图像样本,构成融合图像样本集,用于半监督学习模型训练中。具体的,可以从筛选出的无标签图像样本和有标签图像样本分别随机选择预设数量的图像样本,构成融合图像样本集。由于本申请实施例提供的半监督学习模型训练采用的一致性损失量计算方法不需要利用图像样本的标签,因此,融合图像样本集中的图像样本不需要携带标签。
本申请实施例采用一致性损失量计算方法进行半监督学习模型训练的基本思想是,假设如果以某种很小的方式修改图像样本,也可以确信模型的预测应该在数据点与其扰动之间保持一致,即对图像样本增加扰动(如增加噪音)等后,形成新的图像样本,则原图像样本与新的图像样本的模型预测结果应该一致。本申请实施例采用面向任务域的思想计算一致性损失量,无论有标签图像样本还是无标签图像样本,它们共同组成了该任务的域的分布,因此基于由无标签图像样本和有标签图像样本构成的融合图像样本集,计算一致性损失量,具体的计算方式在后面进行介绍。
S102:对所述融合图像样本集中的各个图像样本进行增量处理,得到增量图像样本集。
本申请实施例中,在得到融合图像样本集后,采用图像增量处理方法,对融合图像样本集中的各个图像样本进行增量处理,得到融合图像样本集的增量图像样本集。其中,图像增量处理方法包括水平翻转,随机切割,随机旋转等,可以根据实际图像分类需求选择图像增量处理方法。值得注意的是,对各个图像样本进行增量处理不改变图像样本的数据特征,同时增量图像样本沿用原图像样本的标签。另外,融合图像样本集中的各个增量图像样本与融合图像样本集中的原图像样本建立对应关系,以便用于后续一致性损失量的计算。
S103:通过基于所述融合图像样本集与所述增量图像样本集计算一致性损失量的方式,对经过有监督学习训练的图像分类模型进行半监督迭代训练,得到经过训练的图像分类模型。
本申请实施例中,在得到融合图像样本集和增量图像样本集后,将融合图像样本集和增量图像样本集中的各个图像样本导入已完成初步训练的预训练模型,即经过有监督学习训练的图像分类模型,用于对其进行半监督学习训练。
本申请实施例中,通过基于融合图像样本集和增量图像样本集计算一致性损失量的方式,对经过有监督学习训练的图像分类模型进行进一步训练,最终得到经过训练的图像分类模型。
实际应用中,基于融合图像样本集和增量图像样本集计算一致性损失量的方式如下:
假设u为融合图像样本集,为u的增量图像样本集,fθ(u)为图像分类模型的模型函数,θ为图像分类模a(u)型的训练参数,Pui为u中第i个图像样本ui的模型预测结果,为a(u)中第i个图像样本/>的模型预测结果。
其中,Pui=fθ(ui),用于表示图像样本ui经过图像分类模型的处理后得到的预测结果,用于表示图像样本/>经过图像分类模型的处理后得到的预测结果。由于图像样本ui与图像样本/>之间具有对应关系,即图像样本/>为图像样本ui的增量图像样本,因此二者的预设结果实际上应该是一致的,一致性损失量的计算实际上是计算二者之间的差距,可以理解的是,二者的差距越小,即一致性损失量越小,则说明图像分类模型的训练精度越高。
具体的,在半监督学习迭代训练的每轮训练中可以利用以下公式(1)计算一致性损失量Lu:
其中,B用于表示融合图像样本集u中图像样本的个数,C用于表示融合图像样本集u中图像样本的类别个数。
本申请实施例中,在利用上述计算方式得到一致性损失量之后,基于该一致性损失量确定是否继续对图像分类模型进行半监督训练。实际应用中,可以基于采用分类交叉熵得到的有监督学习损失量和一致性损失量,共同确定是否继续对图像分类模型进行半监督训练。其中,对于采用分类交叉熵得到有监督学习损失量的方法,本申请实施例不做介绍。
一种可选的实施方式中,如果有监督学习损失量和一致性损失量之和大于预设阈值,则再次执行S101-S103,以对图像分类模型进行迭代训练;如果有监督学习损失量和一致性损失量之和不大于预设阈值,则说明图像分类模型的训练精度已经达到标准,此时输出经过训练的图像分类模型即可。本申请实施例可以经过多轮迭代训练,最终使得图像分类模型的训练精度达到标准。
本申请实施例提供的一种图像分类模型的训练方法中,预先采用有监督学习对图像分类模型进行训练,在此基础上,利用包含有标签图像样本和无标签图像样本的融合图像样本集,进一步对图像分类模型进行训练,具体的,采用计算一致性损失量的方式对图像分类模型进行迭代训练,最终得到经过训练的图像分类模型。由于对于无标记图像样本而言,一致性损失量的计算只是计算自身偏差量,不会将错误数据引入模型训练中,从而提高了图像分类模型的训练精度。
为了提高图像分类模型的泛化性,最终使得基于图像分类模型进行图像分类的结果更准确,本申请实施例采用领域风险最小化方法,对一致性损失量进行计算。参考图2,为本申请实施例提供的另一种图像分类模型的训练方法流程图,该方法包括:
S201:对有标签图像样本集中的各个图像样本进行增量处理,得到有标签增量图像样本集。
其中,对图像样本进行增量处理的方法可以包括翻转、随机旋转、随机切割等数据增量方法。
S202:确定所述有标签增量图像样本集对应的领域随机分布图像样本集;其中,所述领域随机分布图像样本集中的各个图像样本的标签为基于所述有标签增量图像样本集中对应的图像样本的标签确定;
本申请实施例中,有标签增量图像样本集对应的领域随机分布图像样本集中的图像样本,是基于有标签增量图像样本集中任意两个图像样本确定的;另外,有标签增量图像样本集对应的领域随机分布图像样本集中的图像样本的标签也是基于有标签增量图像样本集中对应的图像样本的标签确定。具体的确定方式后续进行介绍。
S203:基于所述有标签增量图像样本集对应的领域随机分布图像样本集,计算有监督学习损失量。
一种可选的实施方式中,假设x为有标签图像样本集,a(x)为x的有标签增量图像样本集,从有标签增量图像样本集a(x)随机抽取两个图像样本xi和xj,根据如下公式(2)生成领域随机分布图像样本集中的图像样本xv,构成领域随机分布图像样本集;根据如下公式(3)生成领域随机分布图像样本集中的图像样本xv的标签yv,其中,yi和yj分别为图像样本xi和xj的标签,基于公式(3)将系数较大到的图像样本的标签作为图像样本xv的标签yv:
xv=λxi+(1-λ)xj (2)
yv=λyi+(1-λ)yj (3)
其中,λ为Beta函数,λ=Beta(a,a),且a为训练参数,a∈(0,∞)。
基于上述方式能够为有标签增量图像样本集确定对应的领域随机分布图像样本集,其中,包括确定域随机分布图像样本集中的各个图像样本和对应的标签。
实际应用中,计算有标签增量图像样本集对应的领域随机分布图像样本集中每个图像样本的预测值与标签之间的差距,该差距可以利用有监督学习损失量Lsv表征。可以理解的是,该差距越小,有监督学习损失量Lsv越小,则说明分类模型精度越高。
其中,领域随机分布图像样本集中每个图像样本的预测值,是指将每个图像样本输入至当前的图像分类模型中,经过其处理后,得到预测值,通过将该预测值与标签相比较,能够得出当前的图像分类模型的训练精度。
具体的,可以从有标签增量图像样本集对应的领域随机分布图像样本集中随机选取预设个数的图像样本,利用如下公式(4)计算有监督学习损失量Lsv:
其中,xm和xn分别为有标签增量图像样本集a(x)对应的领域随机分布图像样本集中的任意两个图像样本;ym和yn分别为xm和xn的标签;pxm和pxn分别为xm和xn的预测值;B用于表示从领域随机分布图像样本集中随机选取的图像样本的预设个数;c(x,y)用于表示分类交叉熵函数;n为闭区间[1,B]中随机正整数。
S204:基于有标签图像样本集和无标签图像样本集,构建融合图像样本集;
S205:对所述融合图像样本集中的各个图像样本进行增量处理,得到增量图像样本集;
S204和S205可参照上述实施例中的S101和S102的描述进行理解,在此不再赘述。
S206:确定所述融合图像样本集和所述增量图像样本集分别对应的领域随机分布图像样本集;
本申请实施例中,融合图像样本集对应的领域随机分布图像样本集中的图像样本,是基于融合图像样本集中任意两个图像样本确定的,其标签也是基于融合图像样本集中对应的图像样本的标签确定。同样的方式确定增量图像样本集分别对应的领域随机分布图像样本集。具体的确定方式后续进行介绍。
S207:基于所述融合图像样本集和所述增量图像样本集分别对应的领域随机分布图像样本集,计算一致性损失量。
一种可选的实施方式中,假设u为融合图像样本集,a(u)为增量图像样本集,从融合图像样本集u中随机选取两个图像样本ui和uj,用于根据公式(5)确定融合图像样本集u对应的领域随机分布图像样本集中的图像样本pi和pj分别为图像样本ui和uj的预测值,利用公式(6)确定图像样本/>的预测值/>
uvi=λui+(1-λ)uj (5)
同样的方式,利用公式(7)和(8)确定增量图像样本集a(u)中各个图像样本及对应的预测值:
其中,a(ui)和a(uj)为从增量图像样本集a(u)中个随机选取的图像样本;用于表示基于a(ui)和a(uj)确定的增量图像样本集分别对应的领域随机分布图像样本集中到的图像样本;/>和/>分别为图像样本a(ui)和a(uj)的预测值,/>为领域随机分布图像样本集中到的图像样本/>的预测值。
实际应用中,可以利用如下公式(9)计算一致性损失量:
其中,Luv用于表示一致性损失量;B为领域随机分布图像样本集中的图像样本个数,C为标签类别个数,j为闭区间[1,B]中的随机正整数。
S208:基于有监督学习损失量与一致性损失量之和,确定是否继续对所述图像分类模型进行半监督训练;如果是,则执行S201;否则输出经过训练的图像分类模型。
本申请实施例中,在得到有监督学习损失量Lsv和一致性损失量Luv之后,确定二者之和L=Lsv+Luv是否小于预设阈值,如果是,则继续对图像分类模型进行迭代训练,否则,说明当前的图像分类模型的训练精度已经达到要求,可以输出经过训练的图像分类模型。
本申请实施例提供的图像分类模型的训练方法中,采用领域风险最小化的方法,计算有监督学习损失量和一致性损失量,基于计算得到的有监督学习损失量与一致性损失量之和,对图像分类模型进行半监督迭代训练,进一步提高了图像分类模型的训练精度。
与上述方法实施例相对应的,本申请还提供了一种图像分类模型的训练装置,参考图3,为本申请实施例提供的一种图像分类模型的训练装置的结构示意图,所述装置包括:
构建模块301,用于基于有标签图像样本集和无标签图像样本集,构建融合图像样本集;
增量模块302,用于对所述融合图像样本集中的各个图像样本进行增量处理,得到增量图像样本集;
训练模块303,用于通过基于所述融合图像样本集与所述增量图像样本集计算一致性损失量的方式,对经过有监督学习训练的图像分类模型进行半监督迭代训练,得到经过训练的图像分类模型。
一种可选的实施方式中,所述训练模块,包括:
第一计算子模块,用于基于所述融合图像样本集与所述增量图像样本集,计算一致性损失量;
第一确定子模块,用于基于所述一致性损失量,确定是否继续对所述图像分类模型进行半监督训练;
触发子模块,用于在所述第一确定子模块的结果为是时,触发所述构建模块,以对所述图像分类模型进行迭代训练;
输出子模块,用于在所述第一确定子模块的结果为否时,输出经过训练的图像分类模型。
另一种可选的实施方式中,所述装置还包括:
计算模块,用于计算有监督学习损失量;
相应的,所述第一确定子模块,具体用于:
基于所述有监督学习损失量和所述一致性损失量,确定是否继续对经过有监督学习训练的图像分类模型进行半监督训练。
本申请实施例提供的一种图像分类模型的训练方法中,预先采用有监督学习对图像分类模型进行训练,在此基础上,利用包含有标签图像样本和无标签图像样本的融合图像样本集,进一步对图像分类模型进行训练,具体的,采用计算一致性损失量的方式对图像分类模型进行迭代训练,最终得到经过训练的图像分类模型。由于对于无标记图像样本而言,一致性损失量的计算只是计算自身偏差量,不会将错误数据引入模型训练中,从而提高了图像分类模型的训练精度。
进一步的,本申请实施例还可以采用领域风险最小化的方法,计算有监督学习损失量和一致性损失量,基于计算得到的有监督学习损失量与一致性损失量之和,对图像分类模型进行半监督迭代训练,进一步提高了图像分类模型的训练精度。
另外,本申请实施例还提供了一种图像分类模型的训练设备,参见图4所示,可以包括:
处理器401、存储器402、输入装置403和输出装置404。图像分类模型的训练设备中的处理器401的数量可以一个或多个,图4中以一个处理器为例。在本发明的一些实施例中,处理器401、存储器402、输入装置403和输出装置404可通过总线或其它方式连接,其中,图4中以通过总线连接为例。
存储器402可用于存储软件程序以及模块,处理器401通过运行存储在存储器402的软件程序以及模块,从而执行图像分类模型的训练设备的各种功能应用以及数据处理。存储器402可主要包括存储程序区和存储数据区,其中,存储程序区可存储操作系统、至少一个功能所需的应用程序等。此外,存储器402可以包括高速随机存取存储器,还可以包括非易失性存储器,例如至少一个磁盘存储器件、闪存器件、或其他易失性固态存储器件。输入装置403可用于接收输入的数字或字符信息,以及产生与图像分类模型的训练设备的用户设置以及功能控制有关的信号输入。
具体在本实施例中,处理器401会按照如下的指令,将一个或一个以上的应用程序的进程对应的可执行文件加载到存储器402中,并由处理器401来运行存储在存储器402中的应用程序,从而实现上述图像分类模型的训练设备的各种功能。
另外,本申请还提供了一种计算机可读存储介质,所述计算机可读存储介质中存储有指令,当所述指令在终端设备上运行时,使得所述终端设备实现图像分类模型的训练功能。
可以理解的是,对于装置实施例而言,由于其基本对应于方法实施例,所以相关之处参见方法实施例的部分说明即可。以上所描述的装置实施例仅仅是示意性的,其中所述作为分离部件说明的单元可以是或者也可以不是物理上分开的,作为单元显示的部件可以是或者也可以不是物理单元,即可以位于一个地方,或者也可以分布到多个网络单元上。可以根据实际的需要选择其中的部分或者全部模块来实现本实施例方案的目的。本领域普通技术人员在不付出创造性劳动的情况下,即可以理解并实施。
需要说明的是,在本文中,诸如第一和第二等之类的关系术语仅仅用来将一个实体或者操作与另一个实体或操作区分开来,而不一定要求或者暗示这些实体或操作之间存在任何这种实际的关系或者顺序。而且,术语“包括”、“包含”或者其任何其他变体意在涵盖非排他性的包含,从而使得包括一系列要素的过程、方法、物品或者设备不仅包括那些要素,而且还包括没有明确列出的其他要素,或者是还包括为这种过程、方法、物品或者设备所固有的要素。在没有更多限制的情况下,由语句“包括一个……”限定的要素,并不排除在包括所述要素的过程、方法、物品或者设备中还存在另外的相同要素。
以上对本申请实施例所提供的一种图像分类模型的训练方法、装置及设备进行了详细介绍,本文中应用了具体个例对本申请的原理及实施方式进行了阐述,以上实施例的说明只是用于帮助理解本申请的方法及其核心思想;同时,对于本领域的一般技术人员,依据本申请的思想,在具体实施方式及应用范围上均会有改变之处,综上所述,本说明书内容不应理解为对本申请的限制。
Claims (4)
1.一种图像分类模型的训练方法,其特征在于,所述方法包括:
基于有标签图像样本集和无标签图像样本集,构建融合图像样本集;
对所述融合图像样本集中的各个图像样本进行增量处理,得到增量图像样本集;
计算有监督学习损失量,包括:
对所述有标签图像样本集中的各个图像样本进行增量处理,得到有标签增量图像样本集;
确定所述有标签增量图像样本集对应的领域随机分布图像样本集;其中,所述领域随机分布图像样本集中的各个图像样本的标签为基于所述有标签增量图像样本集中对应的图像样本的标签确定;
基于所述有标签增量图像样本集对应的领域随机分布图像样本集,计算有监督学习损失量;
基于所述融合图像样本集与所述增量图像样本集,计算一致性损失量,包括:
确定所述融合图像样本集和所述增量图像样本集分别对应的领域随机分布图像样本集,包括:
假设u为融合图像样本集,a(u)为增量图像样本集,从融合图像样本集u中随机选取两个图像样本ui和uj,用于根据公式(5)确定融合图像样本集u对应的领域随机分布图像样本集中的图像样本pi和pj分别为图像样本ui和uj的预测值,利用公式(6)确定图像样本的预测值/>
式中,λ为Beta函数,λ=Beta(a,a),且a为训练参数,a∈(0,∞);
利用公式(7)和(8)确定增量图像样本集a(u)中各个图像样本及对应的预测值:
其中,a(ui)和a(uj)为从增量图像样本集a(u)中个随机选取的图像样本;用于表示基于a(ui)和a(uj)确定的增量图像样本集分别对应的领域随机分布图像样本集中到的图像样本;/>和/>分别为图像样本a(ui)和a(uj)的预测值,/>为领域随机分布图像样本集中到的图像样本的预测值;
基于所述融合图像样本集和所述增量图像样本集分别对应的领域随机分布图像样本集,利用如下公式(9)计算一致性损失量:
其中,Luv用于表示一致性损失量;B为领域随机分布图像样本集中的图像样本个数,C为标签类别个数,j为闭区间[1,B]中的随机正整数;
基于所述有监督学习损失量和所述一致性损失量,确定是否继续对经过有监督学习训练的图像分类模型进行半监督训练;如果是,则继续执行所述基于有标签图像样本集和无标签图像样本集,构建融合图像样本集的步骤,以对所述图像分类模型进行迭代训练;如果否,则输出经过训练的图像分类模型。
2.一种图像分类模型的训练装置,其特征在于,所述装置包括:
构建模块,用于基于有标签图像样本集和无标签图像样本集,构建融合图像样本集;
增量模块,用于对所述融合图像样本集中的各个图像样本进行增量处理,得到增量图像样本集;
计算模块,用于计算有监督学习损失量,包括:
对所述有标签图像样本集中的各个图像样本进行增量处理,得到有标签增量图像样本集;
确定所述有标签增量图像样本集对应的领域随机分布图像样本集;其中,所述领域随机分布图像样本集中的各个图像样本的标签为基于所述有标签增量图像样本集中对应的图像样本的标签确定;
基于所述有标签增量图像样本集对应的领域随机分布图像样本集,计算有监督学习损失量;
训练模块,用于通过基于所述融合图像样本集与所述增量图像样本集计算一致性损失量的方式,对经过有监督学习训练的图像分类模型进行半监督迭代训练,得到经过训练的图像分类模型,包括:
第一计算子模块,用于基于所述融合图像样本集与所述增量图像样本集,计算一致性损失量,包括:
确定所述融合图像样本集和所述增量图像样本集分别对应的领域随机分布图像样本集,包括:
假设u为融合图像样本集,a(u)为增量图像样本集,从融合图像样本集u中随机选取两个图像样本ui和uj,用于根据公式(5)确定融合图像样本集u对应的领域随机分布图像样本集中的图像样本pi和pj分别为图像样本ui和uj的预测值,利用公式(6)确定图像样本的预测值/>
式中,λ为Beta函数,λ=Beta(a,a),且a为训练参数,a∈(0,∞);
利用公式(7)和(8)确定增量图像样本集a(u)中各个图像样本及对应的预测值:
其中,a(ui)和a(uj)为从增量图像样本集a(u)中个随机选取的图像样本;用于表示基于a(ui)和a(uj)确定的增量图像样本集分别对应的领域随机分布图像样本集中到的图像样本;/>和/>分别为图像样本a(ui)和a(uj)的预测值,/>为领域随机分布图像样本集中到的图像样本的预测值;
基于所述融合图像样本集和所述增量图像样本集分别对应的领域随机分布图像样本集,利用如下公式(9)计算一致性损失量:
其中,Luv用于表示一致性损失量;B为领域随机分布图像样本集中的图像样本个数,C为标签类别个数,j为闭区间[1,B]中的随机正整数;
第一确定子模块,用于基于所述有监督学习损失量和所述一致性损失量,确定是否继续对所述图像分类模型进行半监督训练;
触发子模块,用于在所述第一确定子模块的结果为是时,触发所述构建模块,以对所述图像分类模型进行迭代训练;
输出子模块,用于在所述第一确定子模块的结果为否时,输出经过训练的图像分类模型。
3.一种计算机可读存储介质,其特征在于,所述计算机可读存储介质中存储有指令,当所述指令在终端设备上运行时,使得所述终端设备实现如权利要求1所述的方法。
4.一种电子设备,其特征在于,包括:存储器,处理器,及存储在所述存储器上并可在所述处理器上运行的计算机程序,所述处理器执行所述计算机程序时,实现如权利要求1所述的方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN201911264127.3A CN111144451B (zh) | 2019-12-10 | 2019-12-10 | 一种图像分类模型的训练方法、装置及设备 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN201911264127.3A CN111144451B (zh) | 2019-12-10 | 2019-12-10 | 一种图像分类模型的训练方法、装置及设备 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN111144451A CN111144451A (zh) | 2020-05-12 |
CN111144451B true CN111144451B (zh) | 2023-08-25 |
Family
ID=70517988
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN201911264127.3A Active CN111144451B (zh) | 2019-12-10 | 2019-12-10 | 一种图像分类模型的训练方法、装置及设备 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN111144451B (zh) |
Families Citing this family (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN112101217B (zh) * | 2020-09-15 | 2024-04-26 | 镇江启迪数字天下科技有限公司 | 基于半监督学习的行人再识别方法 |
CN112668586B (zh) | 2020-12-18 | 2024-05-14 | 北京百度网讯科技有限公司 | 模型训练、图片处理方法及设备、存储介质、程序产品 |
CN113806535B (zh) * | 2021-09-07 | 2024-09-06 | 清华大学 | 利用无标签文本数据样本提升分类模型表现的方法和装置 |
CN115471717B (zh) * | 2022-09-20 | 2023-06-20 | 北京百度网讯科技有限公司 | 模型的半监督训练、分类方法装置、设备、介质及产品 |
CN115908993A (zh) * | 2022-10-24 | 2023-04-04 | 北京数美时代科技有限公司 | 一种基于图像融合的数据增强方法、系统和存储介质 |
Citations (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN108764281A (zh) * | 2018-04-18 | 2018-11-06 | 华南理工大学 | 一种基于半监督自步学习跨任务深度网络的图像分类方法 |
CN109697469A (zh) * | 2018-12-26 | 2019-04-30 | 西北工业大学 | 一种基于一致性约束的自学习小样本遥感图像分类方法 |
CN109784392A (zh) * | 2019-01-07 | 2019-05-21 | 华南理工大学 | 一种基于综合置信的高光谱图像半监督分类方法 |
CN109815331A (zh) * | 2019-01-07 | 2019-05-28 | 平安科技(深圳)有限公司 | 文本情感分类模型的构建方法、装置和计算机设备 |
CN110059672A (zh) * | 2019-04-30 | 2019-07-26 | 福州大学 | 一种利用增量学习对显微镜细胞图像检测模型进行增类学习的方法 |
-
2019
- 2019-12-10 CN CN201911264127.3A patent/CN111144451B/zh active Active
Patent Citations (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN108764281A (zh) * | 2018-04-18 | 2018-11-06 | 华南理工大学 | 一种基于半监督自步学习跨任务深度网络的图像分类方法 |
CN109697469A (zh) * | 2018-12-26 | 2019-04-30 | 西北工业大学 | 一种基于一致性约束的自学习小样本遥感图像分类方法 |
CN109784392A (zh) * | 2019-01-07 | 2019-05-21 | 华南理工大学 | 一种基于综合置信的高光谱图像半监督分类方法 |
CN109815331A (zh) * | 2019-01-07 | 2019-05-28 | 平安科技(深圳)有限公司 | 文本情感分类模型的构建方法、装置和计算机设备 |
CN110059672A (zh) * | 2019-04-30 | 2019-07-26 | 福州大学 | 一种利用增量学习对显微镜细胞图像检测模型进行增类学习的方法 |
Non-Patent Citations (1)
Title |
---|
Interpolation Consistency Training for Semi-Supervised Learning;Vikas Verma et al;《arxiv.org》;20190309;第1-10页 * |
Also Published As
Publication number | Publication date |
---|---|
CN111144451A (zh) | 2020-05-12 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN111144451B (zh) | 一种图像分类模型的训练方法、装置及设备 | |
Artigue et al. | The principal problem with principal components regression | |
Feng et al. | A hierarchical multi-label classification method based on neural networks for gene function prediction | |
CN112990530B (zh) | 区域人口数量预测方法、装置、电子设备和存储介质 | |
Seki et al. | Machine learning-based prediction of in-hospital mortality using admission laboratory data: A retrospective, single-site study using electronic health record data | |
Wei et al. | Real-time process monitoring using kernel distances | |
CN114445143A (zh) | 一种业务数据的预测方法、装置、设备及介质 | |
Pham et al. | Unsupervised training of Bayesian networks for data clustering | |
Alizadeh et al. | Simulating monthly streamflow using a hybrid feature selection approach integrated with an intelligence model | |
CN117289200A (zh) | 一种基于深度混合标准化的电能表异常检测方法及装置 | |
CN116451081A (zh) | 数据漂移的检测方法、装置、终端及存储介质 | |
Acharya et al. | An improved gradient boosting tree algorithm for financial risk management | |
Eom et al. | Marketable value estimation of patents using ensemble learning methodology: Focusing on US patents for the electricity sector | |
CN112433952B (zh) | 深度神经网络模型公平性测试方法、系统、设备及介质 | |
CN114186646A (zh) | 区块链异常交易识别方法及装置、存储介质及电子设备 | |
Stankovic et al. | Univariate individual household energy forecasting by tuned long short-term memory network | |
Qin et al. | A hybrid deep learning model for short‐term load forecasting of distribution networks integrating the channel attention mechanism | |
CN116777056A (zh) | 预测模型的训练、确定物资需求量的方法和装置 | |
Liu et al. | Combinatorial machine learning approaches for high-rise building cost prediction and their interpretability analysis | |
CN112990826B (zh) | 一种短时物流需求预测方法、装置、设备及可读存储介质 | |
CN115423159A (zh) | 光伏发电预测方法、装置及终端设备 | |
CN111027680B (zh) | 基于变分自编码器的监控量不确定性预测方法及系统 | |
Jiang | Monitoring model group of seepage behavior of earth-rock dam based on the mutual information and support vector machine algorithms | |
Banik et al. | Improved Regression Analysis with Ensemble Pipeline Approach for Applications across Multiple Domains | |
CN112149833B (zh) | 基于机器学习的预测方法、装置、设备和存储介质 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |