CN116863278B - 模型训练方法、图像分类方法、装置、设备及存储介质 - Google Patents
模型训练方法、图像分类方法、装置、设备及存储介质 Download PDFInfo
- Publication number
- CN116863278B CN116863278B CN202311087732.4A CN202311087732A CN116863278B CN 116863278 B CN116863278 B CN 116863278B CN 202311087732 A CN202311087732 A CN 202311087732A CN 116863278 B CN116863278 B CN 116863278B
- Authority
- CN
- China
- Prior art keywords
- sample
- classification model
- loss
- class
- model
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
- 238000000034 method Methods 0.000 title claims abstract description 116
- 238000012549 training Methods 0.000 title claims abstract description 77
- 238000013145 classification model Methods 0.000 claims abstract description 172
- 230000001629 suppression Effects 0.000 claims abstract description 65
- 230000006870 function Effects 0.000 claims description 78
- 230000008859 change Effects 0.000 claims description 37
- 238000010606 normalization Methods 0.000 claims description 18
- 238000004590 computer program Methods 0.000 claims description 16
- 238000012545 processing Methods 0.000 claims description 15
- 230000004044 response Effects 0.000 claims description 3
- 230000008569 process Effects 0.000 description 16
- 238000004821 distillation Methods 0.000 description 12
- 238000013135 deep learning Methods 0.000 description 7
- 238000010586 diagram Methods 0.000 description 6
- 238000004891 communication Methods 0.000 description 5
- 238000004364 calculation method Methods 0.000 description 4
- 230000010365 information processing Effects 0.000 description 4
- 230000008901 benefit Effects 0.000 description 3
- 238000009826 distribution Methods 0.000 description 3
- 230000000694 effects Effects 0.000 description 3
- 206010027175 memory impairment Diseases 0.000 description 3
- 230000003287 optical effect Effects 0.000 description 3
- 238000012360 testing method Methods 0.000 description 3
- 238000004422 calculation algorithm Methods 0.000 description 2
- 230000018109 developmental process Effects 0.000 description 2
- 238000006467 substitution reaction Methods 0.000 description 2
- 230000004913 activation Effects 0.000 description 1
- 230000032683 aging Effects 0.000 description 1
- 238000013473 artificial intelligence Methods 0.000 description 1
- 238000013475 authorization Methods 0.000 description 1
- 230000006835 compression Effects 0.000 description 1
- 238000007906 compression Methods 0.000 description 1
- 238000011156 evaluation Methods 0.000 description 1
- 238000000605 extraction Methods 0.000 description 1
- 230000005764 inhibitory process Effects 0.000 description 1
- 238000011176 pooling Methods 0.000 description 1
- 238000012546 transfer Methods 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/77—Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
- G06V10/774—Generating sets of training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/764—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02T—CLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
- Y02T10/00—Road transport of goods or passengers
- Y02T10/10—Internal combustion engine [ICE] based vehicles
- Y02T10/40—Engine management systems
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Evolutionary Computation (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Computing Systems (AREA)
- Databases & Information Systems (AREA)
- Health & Medical Sciences (AREA)
- General Health & Medical Sciences (AREA)
- Medical Informatics (AREA)
- Software Systems (AREA)
- Artificial Intelligence (AREA)
- Physics & Mathematics (AREA)
- General Physics & Mathematics (AREA)
- Multimedia (AREA)
- Image Analysis (AREA)
Abstract
本申请实施例公开了一种模型训练方法、图像分类方法、装置、设备及存储介质,其中,所述模型训练方法包括:获取样本数据集;其中,所述样本数据集包括原始数据集中的至少一个原始样本;基于所述样本数据集,基于目标损失函数对第二分类模型的网络参数进行迭代更新,得到图像分类模型;其中,所述目标损失函数至少包括差异抑制损失;所述差异抑制损失用于表征针对所述样本数据集中同一样本,所述第二分类模型与第一分类模型分别对应的类别得分之间的差异;所述第一分类模型是利用所述原始数据集训练得到的。
Description
技术领域
本申请涉及但不限于人工智能技术领域,尤其涉及一种模型训练方法、图像分类方法、装置、设备及存储介质。
背景技术
深度学习广泛应用于工业视觉,其在复杂场景的表现明显优于传统图像处理算法。但是随着训练数据的不断增加,模型在不同数据集上迁移不可避免地会存在遗忘性问题,即在新数据集上训练深度学习分类模型,训练得到的新深度学习分类模型虽能够精确识别新数据特征,遗忘了在旧数据上学习到的知识的问题。
目前为解决模型遗忘性问题,主要采用蒸馏方法和模型组合方法。蒸馏往往会牺牲模型的准确率,其最主要的原因是:当旧模型的输出和新模型的输出偏差非常大的时候,通过蒸馏方法强行让他们一致,往往会得到负面的结果。模型组合的方式则会增加推理成本,推理时间变长。
发明内容
有鉴于此,本申请实施例至少提供一种模型训练方法、图像分类方法、装置、设备及存储介质。
本申请实施例的技术方案是这样实现的:
第一方面,本申请实施例提供一种模型训练方法,所述方法包括:
获取样本数据集;其中,所述样本数据集包括原始数据集中的至少一个原始样本;基于所述样本数据集,基于目标损失函数对第二分类模型的网络参数进行迭代更新,得到图像分类模型;其中,所述目标损失函数包括差异抑制损失;所述差异抑制损失用于表征针对所述样本数据集中同一样本,所述第二分类模型与第一分类模型分别对应的类别得分之间的差异;所述第一分类模型是利用所述原始数据集训练得到的。
第二方面,本申请实施例提供一种图像分类方法,所述方法包括:
获取待分类的图像数据集;通过已训练的图像分类模型对所述图像数据集进行分类,得到所述图像数据集中每一图像的分类结果;其中,所述图像分类模型是基于上述第一方面所述的模型训练方法进行训练得到的。
第三方面,本申请实施例提供一种模型训练装置,所述装置包括:
样本获取模块,用于获取样本数据集;其中,所述样本数据集包括原始数据集中的至少一个原始样本;
模型训练模块,用于基于所述样本数据集,基于目标损失函数对第二分类模型的网络参数进行迭代更新,得到图像分类模型;其中,所述目标损失函数包括差异抑制损失;所述差异抑制损失用于表征针对所述样本数据集中同一样本,所述第二分类模型与第一分类模型分别对应的类别得分之间的差异;所述第一分类模型是利用所述原始数据集训练得到的。
第四方面,本申请实施例提供一种图像分类装置,所述装置包括:
数据获取模块,用于获取待分类的图像数据集;
图像分类模块,用于通过已训练的图像分类模型对所述图像数据集进行分类,得到所述图像数据集中每一图像的分类结果;其中,所述图像分类模型是基于上述第一方面所述的模型训练方法进行训练得到的。
第五方面,本申请实施例提供一种计算机设备,包括存储器和处理器,所述存储器存储有可在处理器上运行的计算机程序,所述处理器执行所述程序时实现上述第一方面或第二方面方法中的部分或全部步骤。
第六方面,本申请实施例提供一种计算机可读存储介质,其上存储有计算机程序,该计算机程序被处理器执行时实现上述第一方面或第二方面方法中的部分或全部步骤。
本申请实施例中,在利用原始数据集训练得到第一分类模型的基础上,获取包括原始数据集中的至少一个原始样本的样本数据集对第二分类模型进行训练,在训练过程中通过计算第一分类模型和第二分类模型针对同一样本输出的类别得分的差异得到差异抑制损失,通过在损失函数中增加差异抑制损失,惩罚新旧模型对于同一样本输出的类别得分变化,从而使得第二分类模型在精确识别新数据特征的同时,保持在旧数据上的识别精度。
应当理解的是,以上的一般描述和后文的细节描述仅是示例性和解释性的,而非限制本公开的技术方案。
附图说明
此处的附图被并入说明书中并构成本说明书的一部分,这些附图示出了符合本申请的实施例,并与说明书一起用于说明本申请的技术方案。
图1为本申请实施例提供的模型训练方法的一种流程示意图;
图2为本申请实施例提供的模型训练方法的另一种流程示意图;
图3为本申请实施例提供的模型训练方法的再一种流程示意图;
图4为本申请实施例提供的图像分类方法的可选的流程示意图;
图5为本申请实施例提供的一种模型训练装置的组成结构示意图;
图6为本申请实施例提供的一种图像分类装置的组成结构示意图;
图7为本申请实施例提供的一种计算机设备的硬件实体示意图。
具体实施方式
为了使本申请的目的、技术方案和优点更加清楚,下面结合附图和实施例对本申请的技术方案进一步详细阐述,所描述的实施例不应视为对本申请的限制,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其它实施例,都属于本申请保护的范围。
在以下的描述中,涉及到“一些实施例”,其描述了所有可能实施例的子集,但是可以理解,“一些实施例”可以是所有可能实施例的相同子集或不同子集,并且可以在不冲突的情况下相互结合。
所涉及的术语“第一/第二/第三”仅仅是区别类似的对象,不代表针对对象的特定排序,可以理解地,“第一/第二/第三”在允许的情况下可以互换特定的顺序或先后次序,以使这里描述的本申请实施例能够以除了在这里图示或描述的以外的顺序实施。
除非另有定义,本文所使用的所有的技术和科学术语与属于本申请的技术领域的技术人员通常理解的含义相同。本文中所使用的术语只是为了描述本申请的目的,不是旨在限制本申请。
在对本申请实施例进行进一步详细说明之前,先对本申请实施例中涉及的名词和术语进行说明,本申请实施例中涉及的名词和术语适用于如下的解释。
模型蒸馏,旨在把一个大模型或者多个模型全体学到的知识迁移到另一个轻量级单模型上,方便部署。即用小模型去学习大模型的预测结果,而不是直接学习训练集中的标签(label)。
模型组合(bagging),指通过结合几个模型降低泛化误差的技术。主要想法是分别训练几个不同的模型,然后让所有模型表决测试样例的输出。
模型的遗忘性问题,指的是在新数据集上训练深度学习分类模型,训练得到的新深度学习分类模型虽能够精确识别新数据特征,遗忘了在旧数据上学习到的知识的问题。
logit(指类别得分)为深度学习中一种表示模型输出的方式,通常是指模型输出的未经过softmax(归一化)函数处理的原始数值,也就是各个类别的置信度得分(score),不一定归一化或具有概率(probability)的意义。
在训练过程中,模型通常使用logits值作为损失函数的输入,然后通过softmax函数将其转化为概率分布,最终用于计算损失。在测试或推理阶段,通常使用softmax函数将logits转化为概率分布,以便根据得分最高的类别做出预测。
本申请实施例提供一种模型训练方法,该方法可以由计算机设备的处理器执行。其中,计算机设备指的可以是服务器、笔记本电脑、平板电脑、台式计算机、智能电视、机顶盒、移动设备(例如移动电话、便携式视频播放器、个人数字助理、专用消息设备、便携式游戏设备)等具备模型训练能力的设备。图1为本申请实施例提供的模型训练方法的一种流程示意图,如图1所示,该方法包括如下步骤S110至步骤S120:
步骤S110,获取样本数据集。
这里,所述样本数据集包括原始数据集中的至少一个原始样本。其中原始数据集是训练旧模型时所使用的训练样本。原始数据集包含已有特征,样本数据集则除了已有特征,可以包含一些新特征,使得待训练的第二分类模型适用于具有新特征的待检测图像,也适用于具有已有特征的待检测图像,有利于提高训练模型的泛化能力。结合已有的原始样本和新数据集一起作为样本数据集参与训练,此时样本数据集更大,更易获得更高精度的模型。
步骤S120,基于所述样本数据集,基于目标损失函数对第二分类模型的网络参数进行迭代更新,得到图像分类模型。
这里,所述目标损失函数至少包括差异抑制损失;所述差异抑制损失用于表征针对所述样本数据集中同一样本,所述第二分类模型与第一分类模型分别对应的类别得分之间的差异;所述第一分类模型是利用所述原始数据集训练得到的。
在实施中,先随机初始化第二分类模型的网络参数,再将样本数据集输入到第二分类模型中进行处理,获取第二分类模型的归一化层(softmax)之前的那层(通常为全连接层)的网络输出作为类别得分,进一步通过计算第二分类模型与第一分类模型分别对应的类别得分之间的差异确定差异抑制损失。
需要说明的是,第一分类模型和第二分类模型的网络结构可以相同,例如包括特征提取层、全连接层、归一层等。另外,二者的网络结构也可以不相同,而且若第二分类模型的网络结构小于第一分类模型的网络结构,可以起到模型压缩的效果。
在一些实施方式中,分别获取第一分类模型和第二分类模型针对于训练样本集中每一样本的类别得分,再基于两次类别得分之间的变化值计算所述差异抑制损失。这样,在训练第二分类模型的过程中,对第一分类模型和第二分类模型之间输出的类别得分变化部分进行惩罚,而不强迫新旧模型像普通蒸馏一样保持一致,最终得到的第二分类模型会更加灵活。
本申请实施例中,在利用原始数据集训练得到第一分类模型的基础上,获取包括原始数据集中的至少一个原始样本的样本数据集对第二分类模型进行训练,在训练过程中通过计算第一分类模型和第二分类模型针对同一样本输出的类别得分的差异得到差异抑制损失,将差异抑制损失加入目标损失函数中,惩罚新旧模型对于同一样本输出的类别得分变化,从而使得第二分类模型在精确识别新数据特征的同时,保持在旧数据上的识别精度。
在一些实施例中,第二分类模型至少包括全连接层。图2为本申请实施例提供的模型训练方法的另一种流程示意图,如图2所示,上述步骤S120“基于所述样本数据集,基于目标损失函数对第二分类模型的网络参数进行迭代更新,得到图像分类模型”可以包括以下步骤S210至步骤S240:
步骤S210,将所述样本数据集中的目标样本输入所述第二分类模型,得到所述全连接层输出的第二类别得分。
这里,所述第二类别得分是第二分类模型在归一层之前的那层全连接层的输出结果即logit。
步骤S220,基于所述第二类别得分,利用所述目标损失函数确定所述第二分类模型的学习损失值。
这里,所述学习损失值包括差异抑制损失,其中差异抑制损失可以基于目标样本的第二类别得分以及第一分类模型针对同一样本输出的第一类别得分之间的变化值得到。
在一些实施方式中,可以通过以下步骤确定第二分类模型的学习损失值:确定所述第一分类模型针对所述目标样本输出的第一类别得分;基于所述目标样本对应的所述第二类别得分和所述第一类别得分,确定所述差异抑制损失;基于所述差异抑制损失,确定所述学习损失值。
这里,将样本数据集输入到第一分类模型中,获取前向计算过程中在归一化层之前的那层全连接层的网络输出即为目标样本的第一类别得分。通过计算第二类别得分和第一类别得分之间的变化确定差异抑制损失。其中第二类别得分和第一类别得分之间的变化可以通过特征之间的距离体现,具体地距离可以为KL(Kullback-Leiblerdivergence)距离、欧氏距离、曼哈顿距离等。
步骤S230,基于所述学习损失值对所述第二分类模型的网络参数进行反向传播更新。
步骤S240,响应于满足收敛条件,确定所述第二分类模型为所述图像分类模型。
这里,所述收敛条件包括但不限于迭代次数达到预设次数、训练时间满足预设时长或者损失值低于预设阈值等。其中,所述预设次数为经验值,例如为30万次或5千万次等,直至预设次数后认为完成训练过程,得到参数优化的图像分类模型。
需要说明的是,第二分类模型由卷积层、池化层、激活函数层等一系列模块按照一定规律搭建而成,其组成形式即为网络结构;而卷积层这类模块具有参数,网络中各个结构的参数即为网络参数,通过学习损失值对网络参数进行迭代更新。
在上述实施例中,将样本数据集中目标样本输入到第二分类模型中,通过全连接层获取目标样本的第二类别得分,并通过前向计算得到学习损失值后再使用反向传播算法更新每一层的模型参数,从而实现模型训练过程。这样,通过获取第二类别得分和第一类别得分并计算二者之间的差异输入到目标损失函数中,以惩罚新模型和旧模型之间各自输出的类别得分变化,而不强迫它们像普通蒸馏一样保持一致,模型会更加灵活。
在一些实施例中,所述目标损失函数还包括拟合损失,所述拟合损失用于表征所述第二分类模型的预测类别与样本标签之间的差异;所述第二分类模型还包括所述全连接层之后的归一化层。上述步骤S120“基于所述样本数据集,基于目标损失函数对第二分类模型的网络参数进行迭代更新,得到图像分类模型”还包括以下步骤S250和步骤S260:
步骤S250,将所述样本数据集中的目标样本输入所述第二分类模型,得到所述归一化层输出的第二预测类别。
这里,所述第二预测类别是所述归一化层对所述第二类别得分处理得到的。归一化层通常为softmax函数,通过softmax函数将第二类别得分转化为概率分布,以便更加得到最高的类别作出预测。
需要说明的是,步骤S210和步骤S250可以同时进行,即在第二分类模型包括全连接层和归一化层的情况下,将所述样本数据集中的目标样本输入所述第二分类模型,依次通过所述全连接层输出第二类别得分和通过所述归一化层输出第二预测类别。
步骤S260,基于所述第二类别得分和所述第二预测类别,利用所述目标损失函数确定所述第二分类模型的学习损失值。
这里,所述学习损失值包括拟合损失和差异抑制损失两部分,其中通过第二类别得分计算拟合损失,通过第二预测类别计算差异抑制损失。
在一些实施方式中,上述步骤S260可以进一步实施为以下步骤S2601至步骤S2604:
步骤S2601,确定所述第一分类模型针对所述目标样本输出的第一类别得分。
这里,将样本数据集输入到第一分类模型中,获取前向计算过程中在归一化层之前的那层全连接层的网络输出即为目标样本的第一类别得分。
步骤S2602,基于所述目标样本对应的所述第二类别得分和所述第一类别得分,确定所述差异抑制损失。
这里,通过计算第二类别得分和第一类别得分之间的变化确定差异抑制损失。其中第二类别得分和第一类别得分之间的变化可以通过特征之间的距离体现,具体地距离可以为KL距离、欧氏距离、曼哈顿距离等,本申请实施例对此不作限定。
步骤S2603,基于所述第二预测类别和所述样本数据集的样本标签,确定所述拟合损失。
这里,可以通过交叉熵函数计算目标样本的第二预测类别和样本标签之间的拟合损失。拟合损失表征第二分类模型输出结果与真实标签之间的差异,训练的目的在于模型的输出结果更接近真实标签,从而能够提升模型的准确率。
步骤S2604,对所述拟合损失和所述差异抑制损失进行加权求和,得到所述学习损失值。
这里,对所述拟合损失和所述差异抑制损失进行加权求和,得到加权求和值,即为样本数据集在一次迭代的前向计算过程中的学习损失值。
上述实施例中,将样本数据集中目标样本输入到第二分类模型中,分别通过全连接层和归一化层获取目标样本的第二类别得分和第二预测类别,并分别计算得到对应的差异抑制损失和拟合损失,进一步得到前向计算的学习损失值以反向更新每一层的模型参数,从而实现模型训练过程。
图3为本申请实施例提供的模型训练方法的再一种流程示意图,如图3所示,上述步骤S2602“基于所述目标样本对应的所述第二类别得分和所述第一类别得分,确定所述差异抑制损失”可以包括以下步骤S310至步骤S320:
步骤S310,针对所述目标样本,确定所述第二类别得分和所述第一类别得分之间的变化距离。
这里,第一类别得分为第一分类模型的归一化层前的那一层的网络输出,第二类别得分为第二分类模型的归一化层前的那一层的网络输出,步骤S320,基于所述变化距离和预设的焦点函数确定所述差异抑制损失。
这里,所述焦点函数用于调节所述变化距离的权重。差异抑制损失表征新旧模型输出之间的差异。本申请实施例针对不同样本对应的变化距离可以通过焦点函数来调节权重,使得模型更加灵活。
在上述实施例中,通过计算第一分类模型和第二分类模型针对同一样本分别对应的类别得分之间的差异距离,并设计焦点函数调节不同样本的变化距离的权重,从而确定出差异抑制损失,利用差异抑制损失惩罚新模型和旧模型之间各自输出的类别得分变化,而不强迫它们像普通蒸馏一样保持一致,模型会更加灵活。
在一些实施方式中,所述基于所述变化距离和预设的焦点函数确定所述差异抑制损失进一步实施为:确定所述目标样本通过所述焦点函数计算的权重结果;在所述变化距离大于第一阈值的情况下,对所述目标样本对应的所述变化距离与所述权重结果进行相乘并求和,得到所述差异抑制损失。
这里,所述第一阈值为截断阈值,表征新旧模型输出的类别得分之间距离的容忍度,这样针对第一分类模型输出的第一类别得分与第二分类模型输出的第二类别得分之间的变化距离大于第一阈值的情况,才计算差异抑制损失,能提升模型训练的效率。
在一些实施方式中,所述预设的焦点函数为第一焦点参数、第二焦点参数和二值函数的线性组合,所述方法还包括:针对所述样本数据集中除候选样本集之外的每一样本,通过所述第一焦点参数调节所述变化距离的权重;其中,所述候选样本集为所述第一分类模型识别正确的样本集合;针对所述候选样本集中每一样本,在所述第一焦点参数的基础上,通过所述第二焦点参数和所述二值函数增加所述变化距离的权重。
这里,针对第一分类模型识别正确的样本集合,通过第二焦点参数和二值函数增加变化距离的权重,从而在蒸馏过程中聚焦在旧模型识别正确的那部分样本的类别得分上,减少第二分类模型对这部分样本上学习到的知识的遗忘问题,同时能保证模型的准确性和推理时效。
在一些实施方式中,所述方法还包括:确定所述第一分类模型针对所述目标样本输出的第一预测类别;基于所述样本数据集的样本标签,从所述样本数据集中选择所述第一预测类别和所述样本标签一致的样本作为所述候选样本集。
这里,利用第一分类模型对样本数据集进行预测,得到每一样本的第一预测类别,从而结合样本标签从样本数据集中筛选出第一分类模型识别正确的样本作为候选样本集。
图4为本申请实施例提供的图像分类方法的流程示意图,如图4所示,所述方法包括以下步骤S410和步骤S420:
步骤S410,获取待分类的图像数据集。
这里,所述图像数据集可以包括原始训练过程中的旧数据特征,还可以包括未训练过的新数据特征。
步骤S420,通过已训练的图像分类模型对所述图像数据集进行分类,得到所述图像数据集中每一图像的分类结果。
这里,所述图像分类模型是基于本申请提出的模型训练方法进行训练得到的,即利用包括差异抑制损失的目标损失函数对第二分类模型的网络参数进行迭代更新得到的,其中所述差异抑制损失用于表征针对所述样本数据集中同一样本,所述第二分类模型与第一分类模型分别对应的类别得分之间的差异;所述第一分类模型是利用所述原始数据集训练得到的。
上述实施例中,通过计算第一分类模型和第二分类模型针对同一样本输出的类别得分的变化得到差异抑制损失,将差异抑制损失加入目标损失函数中,惩罚新旧模型对于同一样本输出的类别得分变化,从而使得训练好的图像分类模型无论针对新数据特征还是旧数据特征均能保持较高的识别精度。
下面结合一个具体实施例对上述模型训练方法进行说明,然而值得注意的是,该具体实施例仅是为了更好地说明本申请,并不构成对本申请的不当限定。
蒸馏往往会牺牲模型的准确率,其最主要的原因是:当旧模型的输出和新模型的输出偏差非常大的时候,通过蒸馏方法强行让他们一致,往往会得到负面的结果。模型组合方式通过结合几个模型降低泛化误差的技术。主要想法是分别训练几个不同的模型,然后让所有模型表决测试样例的输出,这样会增加推理成本,推理时间变长。
本申请实施例基于蒸馏的方法,提出一种焦点类别得分差异抑制的模型训练方法,以分类任务为例,选取一个分类模型,比如Bert(预训练语言模型)+ MLP(多层感知机)+softmax,并定义目标损失函数如公式(1):
公式(1);
其中,为学习损失值,/>为权重系数,/>为交叉熵损失函数,计算的是新模型输出的预测类别/>与样本标签/>之间的拟合损失,/>为差异抑制损失函数,计算的是针对所述样本数据集中同一样本,第二分类模型输出的第二类别得分/>与第一分类模型输出的第一类别得分/>之间的差异抑制损失,/>为样本特征,/>为样本标签。
差异抑制损失函数的表现形式如公式(2)和(3):
公式(2);
公式(3);
其中,为第二分类模型输出的第二类别得分,/>为第一分类模型输出的第一类别得分;/>和/>分别为第一焦点参数和第二焦点参数,可以调节聚焦样本的类别得分差异的权重;N为样本数;/>为二值函数,在第一分类模型针对样本特征/>输出的预测类别与样本标签/>相等的情况下,/>的值为1,否则为0。
为针对样本特征/>计算的第二类别得分与第一类别得分之间的距离,可以选择KL距离、欧氏距离、曼哈顿距离等;/>为第一阈值,表征新旧模型输出的类别得分之间距离的容忍度。
需要说明的是,第一分类模型和第二分类模型的网络结构可以相同,也可以不相同,可以分别利用原始数据集和样本数据集进行训练,其中原始数据集训练得到的是第一分类模型即旧模型,样本数据集训练的是第二分类模型即新模型。
在实施中,选择带图形处理器(Graphics Processing Unit,GPU)的服务器,按照上述目标损失函数进行迭代训练,得到参数优化的图像分类模型并部署模型服务。经测试,不同模型在afqmc数据集上的评价指标如下表1。需要说明的是,表1中新模型可以对应于前述实施例中的第二分类模型,旧模型可以对应于前述实施例中的第一分类模型。
表1
其中,ACC(Accuracy)表示模型准确率,NFR(Negative flips rate)表示旧模型识别正确但新模型识别错误的概率,可以通过公式(4)计算得到:
公式(4);
其中,N为样本数,是二值函数,/>和/>分别为新模型、旧模型各自针对第i个样本的预测类别,/>为第i个样本对应的真实标签。
可以看出,本申请提出的采用焦点类别得分的差异抑制方法,增加模型准确率的同时,减少了旧模型识别准确但新模型识别错误的概率。
本申请实施例提出的模型训练方法聚焦在哪些样本的类别得分需要蒸馏,重点聚焦在旧模型识别正确的那部分样本的类别得分,其他样本的类别得分是否需要蒸馏可通过焦点参数来调节,不仅要解决模型的遗忘性问题,又能保证模型的准确性和推理时效。同时,本申请实施例提出的差异抑制方法,通过惩罚新模型和旧模型之间分别对应的类别得分之间的差异,而不强迫它们像普通蒸馏一样保持一致,模型会更加灵活。
基于前述的实施例,本申请实施例提供一种模型训练装置,该装置包括所包括的各模块、以及各模块所包括的各子模块及各单元,可以通过计算机设备中的处理器来实现;当然也可通过具体的逻辑电路实现;在实施的过程中,处理器可以为中央处理器(CentralProcessing Unit,CPU)、微处理器(Microprocessor Unit,MPU)、数字信号处理器(DigitalSignal Processor,DSP)或现场可编程门阵列(Field Programmable Gate Array,FPGA)等。
图5为本申请实施例提供的一种模型训练装置的组成结构示意图,如图5所示,模型训练装置500包括:样本获取模块510和模型训练模块520,其中:
所述样本获取模块510,用于获取样本数据集;其中,所述样本数据集包括原始数据集中的至少一个原始样本;
所述模型训练模块520,用于基于所述样本数据集,基于目标损失函数对第二分类模型的网络参数进行迭代更新,得到图像分类模型;其中,所述目标损失函数至少包括差异抑制损失;所述差异抑制损失用于表征针对所述样本数据集中同一样本,所述第二分类模型与第一分类模型分别对应的类别得分之间的差异;所述第一分类模型是利用所述原始数据集训练得到的。
在一些可能的实施例中,所述第二分类模型至少包括全连接层,所述模型训练模块520包括:第一预测子模块,用于将所述样本数据集中的目标样本输入所述第二分类模型,得到所述全连接层输出的第二类别得分;第一损失确定子模块,用于基于所述第二类别得分,利用所述目标损失函数确定所述第二分类模型的学习损失值;参数更新子模块,用于基于所述学习损失值对所述第二分类模型的网络参数进行反向传播更新;响应于满足收敛条件,确定所述第二分类模型为所述图像分类模型。
在一些可能的实施例中,所述第一损失确定子模块包括:第一确定单元,用于确定所述第一分类模型针对所述目标样本输出的第一类别得分;第二确定单元,用于基于所述目标样本对应的所述第二类别得分和所述第一类别得分,确定所述差异抑制损失;第三确定单元,用于基于所述差异抑制损失,确定所述学习损失值。
在一些可能的实施例中,所述目标损失函数还包括拟合损失,所述拟合损失用于表征所述第二分类模型的预测类别与样本标签之间的差异;所述第二分类模型还包括所述全连接层之后的归一化层;所述模型训练模块520还包括:第二预测子模块,用于将所述样本数据集中的目标样本输入所述第二分类模型,得到所述归一化层输出的第二预测类别;其中,所述第二预测类别是所述归一化层对所述第二类别得分处理得到的;第二损失确定子模块,用于基于所述第二类别得分和所述第二预测类别,利用所述目标损失函数确定所述第二分类模型的学习损失值。
在一些可能的实施例中,所述第二损失确定子模块包括:第三确定单元,用于确定所述第一分类模型针对所述目标样本输出的第一类别得分;第四确定单元,用于基于所述目标样本对应的所述第二类别得分和所述第一类别得分,确定所述差异抑制损失;第五确定单元,用于基于所述第二预测类别和所述样本数据集的样本标签,确定所述拟合损失;加权处理单元,用于对所述拟合损失和所述差异抑制损失进行加权求和,得到所述学习损失值。
在一些可能的实施例中,所述第四确定单元包括:差异确定子单元,用于针对所述目标样本,确定所述第二类别得分和所述第一类别得分之间的变化距离;损失确定子单元,用于基于所述变化距离和预设的焦点函数确定所述差异抑制损失;其中,所述焦点函数用于调节所述变化距离的权重。
在一些可能的实施例中,所述误差确定子单元还用于确定所述目标样本通过所述焦点函数计算的权重结果;在所述变化距离大于第一阈值的情况下,对所述目标样本对应的所述变化距离与所述权重结果进行相乘并求和,得到所述差异抑制损失。
在一些可能的实施例中,所述预设的焦点函数为第一焦点参数、第二焦点参数和二值函数的线性组合,所述第三确定单元还包括:第一调节子单元,用于针对所述样本数据集中除候选样本集之外的每一样本,通过所述第一焦点参数调节所述变化距离的权重;其中,所述候选样本集为所述第一分类模型识别正确的样本集合;第二调节子单元,用于针对所述候选样本集中每一样本,在所述第一焦点参数的基础上,通过所述第二焦点参数和所述二值函数增加所述变化距离的权重。
在一些可能的实施例中,所述第三确定单元还包括:预测子单元,用于确定所述第一分类模型针对所述目标样本输出的第一预测类别;选择子单元,用于基于所述样本数据集的样本标签,从所述样本数据集中选择所述第一预测类别和所述样本标签一致的样本作为所述候选样本集。
以上装置实施例的描述,与上述模型训练方法实施例的描述是类似的,具有同模型训练方法实施例相似的有益效果。在一些实施例中,本公开实施例提供的装置具有的功能或包含的模块可以用于执行上述模型训练方法实施例描述的方法,对于本申请装置实施例中未披露的技术细节,请参照本申请模型训练方法实施例的描述而理解。
基于前述的实施例,本申请实施例提供一种模型训练装置,该装置包括所包括的各模块、以及各模块所包括的各子模块及各单元,可以通过计算机设备中的处理器来实现;当然也可通过具体的逻辑电路实现;在实施的过程中,处理器可以为中央处理器、微处理器、数字信号处理器或现场可编程门阵列等。
图6为本申请实施例提供的一种图像分类装置的组成结构示意图,如图6所示,图像分类装置600包括:数据获取模块610和图像分类模块620,其中:
所述数据获取模块610,用于获取待分类的图像数据集;
所述图像分类模块620,用于通过已训练的图像分类模型对所述图像数据集进行分类,得到所述图像数据集中每一图像的分类结果;其中,所述图像分类模型是基于本申请实施例提供的模型训练方法进行训练得到的。
以上装置实施例的描述,与上述图像分类方法实施例的描述是类似的,具有同图像分类方法实施例相似的有益效果。在一些实施例中,本公开实施例提供的装置具有的功能或包含的模块可以用于执行上述图像分类方法实施例描述的方法,对于本申请装置实施例中未披露的技术细节,请参照本申请图像分类方法实施例的描述而理解。
若本申请技术方案涉及个人信息,应用本申请技术方案的产品在处理个人信息前,已明确告知个人信息处理规则,并取得个人自主同意。若本申请技术方案涉及敏感个人信息,应用本申请技术方案的产品在处理敏感个人信息前,已取得个人单独同意,并且同时满足“明示同意”的要求。例如,在摄像头等个人信息采集装置处,设置明确显著的标识告知已进入个人信息采集范围,将会对个人信息进行采集,若个人自愿进入采集范围即视为同意对其个人信息进行采集;或者在个人信息处理的装置上,利用明显的标识/信息告知个人信息处理规则的情况下,通过弹窗信息或请个人自行上传其个人信息等方式获得个人授权;其中,个人信息处理规则可包括个人信息处理者、个人信息处理目的、处理方式、处理的个人信息种类等信息。
需要说明的是,本申请实施例中,如果以软件功能模块的形式实现上述的模型训练方法或图像分类方法,并作为独立的产品销售或使用时,也可以存储在一个计算机可读取存储介质中。基于这样的理解,本申请实施例的技术方案本质上或者说对相关技术做出贡献的部分可以以软件产品的形式体现出来,该软件产品存储在一个存储介质中,包括若干指令用以使得一台计算机设备(可以是个人计算机、服务器、或者网络设备等)执行本申请各个实施例所述方法的全部或部分。而前述的存储介质包括:U盘、移动硬盘、只读存储器(Read Only Memory,ROM)、磁碟或者光盘等各种可以存储程序代码的介质。这样,本申请实施例不限制于任何特定的硬件、软件或固件,或者硬件、软件、固件三者之间的任意结合。
本申请实施例提供一种计算机设备,包括存储器和处理器,所述存储器存储有可在处理器上运行的计算机程序,所述处理器执行所述程序时实现上述模型训练方法或图像分类方法中的部分或全部步骤。
本申请实施例提供一种计算机可读存储介质,其上存储有计算机程序,该计算机程序被处理器执行时实现上述模型训练方法或图像分类方法中的部分或全部步骤。所述计算机可读存储介质可以是瞬时性的,也可以是非瞬时性的。
本申请实施例提供一种计算机程序,包括计算机可读代码,在所述计算机可读代码在计算机设备中运行的情况下,所述计算机设备中的处理器执行用于实现上述模型训练方法或图像分类方法中的部分或全部步骤。
本申请实施例提供一种计算机程序产品,所述计算机程序产品包括存储了计算机程序的非瞬时性计算机可读存储介质,所述计算机程序被计算机读取并执行时,实现上述模型训练方法或图像分类方法中的部分或全部步骤。该计算机程序产品可以具体通过硬件、软件或其结合的方式实现。在一些实施例中,所述计算机程序产品具体体现为计算机存储介质,在另一些实施例中,计算机程序产品具体体现为软件产品,例如软件开发包(Software Development Kit,SDK)等等。
这里需要指出的是:上文对各个实施例的描述倾向于强调各个实施例之间的不同之处,其相同或相似之处可以互相参考。以上设备、存储介质、计算机程序及计算机程序产品实施例的描述,与上述方法实施例的描述是类似的,具有同方法实施例相似的有益效果。对于本申请设备、存储介质、计算机程序及计算机程序产品实施例中未披露的技术细节,请参照本申请方法实施例的描述而理解。
需要说明的是,图7为本申请实施例中计算机设备的一种硬件实体示意图,如图7所示,该计算机设备700的硬件实体包括:处理器701、通信接口702和存储器703,其中:
处理器701通常控制计算机设备700的总体操作。
通信接口702可以使计算机设备通过网络与其他终端或服务器通信。
存储器703配置为存储由处理器701可执行的指令和应用,还可以缓存待处理器701以及计算机设备700中各模块待处理或已经处理的数据(例如,图像数据、音频数据、语音通信数据和视频通信数据),可以通过闪存(FLASH)或随机访问存储器(Random AccessMemory,RAM)实现。处理器701、通信接口702和存储器703之间可以通过总线704进行数据传输。
应理解,说明书通篇中提到的“一个实施例”或“一实施例”意味着与实施例有关的特定特征、结构或特性包括在本申请的至少一个实施例中。因此,在整个说明书各处出现的“在一个实施例中”或“在一实施例中”未必一定指相同的实施例。此外,这些特定的特征、结构或特性可以任意适合的方式结合在一个或多个实施例中。应理解,在本申请的各种实施例中,上述各步骤/过程的序号的大小并不意味着执行顺序的先后,各步骤/过程的执行顺序应以其功能和内在逻辑确定,而不应对本申请实施例的实施过程构成任何限定。上述本申请实施例序号仅仅为了描述,不代表实施例的优劣。
需要说明的是,在本文中,术语“包括”、“包含”或者其任何其他变体意在涵盖非排他性的包含,从而使得包括一系列要素的过程、方法、物品或者装置不仅包括那些要素,而且还包括没有明确列出的其他要素,或者是还包括为这种过程、方法、物品或者装置所固有的要素。在没有更多限制的情况下,由语句“包括一个……”限定的要素,并不排除在包括该要素的过程、方法、物品或者装置中还存在另外的相同要素。
在本申请所提供的几个实施例中,应该理解到,所揭露的设备和方法,可以通过其它的方式实现。以上所描述的设备实施例仅仅是示意性的,例如,所述单元的划分,仅仅为一种逻辑功能划分,实际实现时可以有另外的划分方式,如:多个单元或组件可以结合,或可以集成到另一个系统,或一些特征可以忽略,或不执行。另外,所显示或讨论的各组成部分相互之间的耦合、或直接耦合、或通信连接可以是通过一些接口,设备或单元的间接耦合或通信连接,可以是电性的、机械的或其它形式的。
上述作为分离部件说明的单元可以是、或也可以不是物理上分开的,作为单元显示的部件可以是、或也可以不是物理单元;既可以位于一个地方,也可以分布到多个网络单元上;可以根据实际的需要选择其中的部分或全部单元来实现本实施例方案的目的。
另外,在本申请各实施例中的各功能单元可以全部集成在一个处理单元中,也可以是各单元分别单独作为一个单元,也可以两个或两个以上单元集成在一个单元中;上述集成的单元既可以采用硬件的形式实现,也可以采用硬件加软件功能单元的形式实现。
本领域普通技术人员可以理解:实现上述方法实施例的全部或部分步骤可以通过程序指令相关的硬件来完成,前述的程序可以存储于计算机可读取存储介质中,该程序在执行时,执行包括上述方法实施例的步骤;而前述的存储介质包括:移动存储设备、只读存储器(Read Only Memory,ROM)、磁碟或者光盘等各种可以存储程序代码的介质。
或者,本申请上述集成的单元如果以软件功能模块的形式实现并作为独立的产品销售或使用时,也可以存储在一个计算机可读取存储介质中。基于这样的理解,本申请的技术方案本质上或者说对相关技术做出贡献的部分可以以软件产品的形式体现出来,该计算机软件产品存储在一个存储介质中,包括若干指令用以使得一台计算机设备(可以是个人计算机、服务器、或者网络设备等)执行本申请各个实施例所述方法的全部或部分。而前述的存储介质包括:移动存储设备、ROM、磁碟或者光盘等各种可以存储程序代码的介质。
以上所述,仅为本申请的实施方式,但本申请的保护范围并不局限于此,任何熟悉本技术领域的技术人员在本申请揭露的技术范围内,可轻易想到变化或替换,都应涵盖在本申请的保护范围之内。
Claims (14)
1.一种模型训练方法,其特征在于,所述方法包括:
获取样本数据集;其中,所述样本数据集包括原始数据集中的至少一个原始样本和新数据集;
基于所述样本数据集,基于目标损失函数对第二分类模型的网络参数进行迭代更新,得到图像分类模型;其中,所述目标损失函数至少包括差异抑制损失;
所述差异抑制损失用于表征针对所述样本数据集中目标样本,所述第二分类模型输出的第二类别得分与第一分类模型输出的第一类别得分之间的差异;所述第一分类模型是利用所述原始数据集训练得到的;
所述差异抑制损失基于变化距离和预设的焦点函数的乘积确定,其中,所述变化距离基于所述第一类别得分与第二类别得分确定,所述焦点函数为第一焦点参数加第二焦点参数乘以二值函数,所述第一焦点参数和所述第二焦点参数均用于调节所述变化距离的权重,所述二值函数的输出数值基于所述第一分类模型针对所述目标样本输出的预测类别与样本标签确定。
2.根据权利要求1所述的方法,其特征在于,所述第二分类模型至少包括全连接层,所述基于所述样本数据集,基于目标损失函数对第二分类模型的网络参数进行迭代更新,得到图像分类模型,包括:
将所述样本数据集中的目标样本输入所述第二分类模型,得到所述全连接层输出的第二类别得分;
基于所述第二类别得分,利用所述目标损失函数确定所述第二分类模型的学习损失值;
基于所述学习损失值对所述第二分类模型的网络参数进行反向传播更新;
响应于满足收敛条件,确定所述第二分类模型为所述图像分类模型。
3.根据权利要求2所述的方法,其特征在于,所述基于所述第二类别得分,利用所述目标损失函数确定所述第二分类模型的学习损失值,包括:
确定所述第一分类模型针对所述目标样本输出的第一类别得分;
基于所述目标样本对应的所述第二类别得分和所述第一类别得分,确定所述差异抑制损失;
基于所述差异抑制损失,确定所述学习损失值。
4.根据权利要求2所述的方法,其特征在于,所述目标损失函数还包括拟合损失,所述拟合损失用于表征所述第二分类模型的预测类别与样本标签之间的差异;所述第二分类模型还包括所述全连接层之后的归一化层;
所述基于所述样本数据集,基于目标损失函数对第二分类模型的网络参数进行迭代更新,得到图像分类模型,还包括:
将所述样本数据集中的目标样本输入所述第二分类模型,得到所述归一化层输出的第二预测类别;其中,所述第二预测类别是所述归一化层对所述第二类别得分处理得到的;
基于所述第二类别得分和所述第二预测类别,利用所述目标损失函数确定所述第二分类模型的学习损失值。
5.根据权利要求4所述的方法,其特征在于,所述基于所述第二类别得分和所述第二预测类别,利用所述目标损失函数确定所述第二分类模型的学习损失值,包括:
确定所述第一分类模型针对所述目标样本输出的第一类别得分;
基于所述目标样本对应的所述第二类别得分和所述第一类别得分,确定所述差异抑制损失;
基于所述第二预测类别和所述样本数据集的样本标签,确定所述拟合损失;
对所述拟合损失和所述差异抑制损失进行加权求和,得到所述学习损失值。
6.根据权利要求5所述的方法,其特征在于,所述基于所述目标样本对应的所述第二类别得分和所述第一类别得分,确定所述差异抑制损失,包括:
针对所述目标样本,确定所述第二类别得分和所述第一类别得分之间的变化距离;
基于所述变化距离和预设的焦点函数的乘积确定所述差异抑制损失。
7.根据权利要求6所述的方法,其特征在于,所述基于所述变化距离和预设的焦点函数的乘积确定所述差异抑制损失,包括:
确定所述目标样本通过所述焦点函数计算的权重结果;
在所述变化距离大于第一阈值的情况下,对所述目标样本对应的所述变化距离与所述权重结果进行相乘并求和,得到所述差异抑制损失。
8.根据权利要求6或7所述的方法,其特征在于,所述方法还包括:
针对所述样本数据集中除候选样本集之外的每一样本,通过所述第一焦点参数调节所述变化距离的权重;其中,所述候选样本集为所述第一分类模型识别正确的样本集合;
针对所述候选样本集中每一样本,在所述第一焦点参数的基础上,通过所述第二焦点参数和所述二值函数增加所述变化距离的权重。
9.根据权利要求8所述的方法,其特征在于,所述方法还包括:
确定所述第一分类模型针对所述目标样本输出的第一预测类别;
基于所述样本数据集的样本标签,从所述样本数据集中选择所述第一预测类别和所述样本标签一致的样本作为所述候选样本集。
10.一种图像分类方法,其特征在于,所述方法包括:
获取待分类的图像数据集;
通过已训练的图像分类模型对所述图像数据集进行分类,得到所述图像数据集中每一图像的分类结果;其中,所述图像分类模型是基于权利要求1至9任一项所述的模型训练方法进行训练得到的。
11.一种模型训练装置,其特征在于,所述装置包括:
样本获取模块,用于获取样本数据集;其中,所述样本数据集包括原始数据集中的至少一个原始样本和新数据集;
模型训练模块,用于基于所述样本数据集,基于目标损失函数对第二分类模型的网络参数进行迭代更新,得到图像分类模型;其中,所述目标损失函数包括差异抑制损失;所述差异抑制损失用于表征针对所述样本数据集中目标样本,所述第二分类模型输出的第二类别得分与第一分类模型输出的第一类别得分之间的差异;所述第一分类模型是利用所述原始数据集训练得到的;所述差异抑制损失基于变化距离和预设的焦点函数的乘积确定,其中,所述变化距离基于所述第一类别得分与第二类别得分确定,所述焦点函数为第一焦点参数加第二焦点参数乘以二值函数,所述第一焦点参数和所述第二焦点参数均用于调节所述变化距离的权重,所述二值函数的输出数值基于所述第一分类模型针对所述目标样本输出的预测类别与样本标签确定。
12.一种图像分类装置,其特征在于,所述装置包括:
数据获取模块,用于获取待分类的图像数据集;
图像分类模块,用于通过已训练的图像分类模型对所述图像数据集进行分类,得到所述图像数据集中每一图像的分类结果;其中,所述图像分类模型是基于权利要求1至9任一项所述的模型训练方法进行训练得到的。
13.一种计算机设备,包括存储器和处理器,所述存储器存储有可在处理器上运行的计算机程序,其特征在于,所述处理器执行所述程序时实现权利要求1至9任一项所述方法中的步骤,或执行权利要求10所述方法中的步骤。
14.一种计算机可读存储介质,其上存储有计算机程序,其特征在于,该计算机程序被处理器执行时实现权利要求1至9任一项所述方法中的步骤,或执行权利要求10所述方法中的步骤。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202311087732.4A CN116863278B (zh) | 2023-08-25 | 2023-08-25 | 模型训练方法、图像分类方法、装置、设备及存储介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202311087732.4A CN116863278B (zh) | 2023-08-25 | 2023-08-25 | 模型训练方法、图像分类方法、装置、设备及存储介质 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN116863278A CN116863278A (zh) | 2023-10-10 |
CN116863278B true CN116863278B (zh) | 2024-01-26 |
Family
ID=88219531
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202311087732.4A Active CN116863278B (zh) | 2023-08-25 | 2023-08-25 | 模型训练方法、图像分类方法、装置、设备及存储介质 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN116863278B (zh) |
Families Citing this family (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117974456B (zh) * | 2024-01-04 | 2024-06-18 | 万里云医疗信息科技(北京)有限公司 | 图像生成模型的训练方法、装置以及存储介质 |
CN118349870A (zh) * | 2024-03-29 | 2024-07-16 | 广东奥普特科技股份有限公司 | 一种模型训练和分类方法及相关设备 |
Citations (12)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN109711544A (zh) * | 2018-12-04 | 2019-05-03 | 北京市商汤科技开发有限公司 | 模型压缩的方法、装置、电子设备及计算机存储介质 |
CN112561080A (zh) * | 2020-12-18 | 2021-03-26 | Oppo(重庆)智能科技有限公司 | 样本筛选方法、样本筛选装置及终端设备 |
CN112735478A (zh) * | 2021-01-29 | 2021-04-30 | 华南理工大学 | 一种基于加性角惩罚焦点损失的语音情感识别方法 |
CN113222123A (zh) * | 2021-06-15 | 2021-08-06 | 深圳市商汤科技有限公司 | 模型训练方法、装置、设备及计算机存储介质 |
CN113326768A (zh) * | 2021-05-28 | 2021-08-31 | 浙江商汤科技开发有限公司 | 训练方法、图像特征提取方法、图像识别方法及装置 |
CN114091594A (zh) * | 2021-11-15 | 2022-02-25 | 北京市商汤科技开发有限公司 | 模型训练方法及装置、设备、存储介质 |
CN114529750A (zh) * | 2021-12-28 | 2022-05-24 | 深圳云天励飞技术股份有限公司 | 图像分类方法、装置、设备及存储介质 |
CN114972877A (zh) * | 2022-06-09 | 2022-08-30 | 北京百度网讯科技有限公司 | 一种图像分类模型训练方法、装置及电子设备 |
CN115064155A (zh) * | 2022-06-09 | 2022-09-16 | 福州大学 | 一种基于知识蒸馏的端到端语音识别增量学习方法及系统 |
WO2022262757A1 (zh) * | 2021-06-16 | 2022-12-22 | 上海齐感电子信息科技有限公司 | 模型训练方法、图像检测方法及检测装置 |
CN116304811A (zh) * | 2023-02-28 | 2023-06-23 | 王宇轩 | 一种基于焦点损失函数动态样本权重调整方法及系统 |
CN116503670A (zh) * | 2023-06-13 | 2023-07-28 | 商汤人工智能研究中心(深圳)有限公司 | 图像分类及模型训练方法、装置和设备、存储介质 |
-
2023
- 2023-08-25 CN CN202311087732.4A patent/CN116863278B/zh active Active
Patent Citations (12)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN109711544A (zh) * | 2018-12-04 | 2019-05-03 | 北京市商汤科技开发有限公司 | 模型压缩的方法、装置、电子设备及计算机存储介质 |
CN112561080A (zh) * | 2020-12-18 | 2021-03-26 | Oppo(重庆)智能科技有限公司 | 样本筛选方法、样本筛选装置及终端设备 |
CN112735478A (zh) * | 2021-01-29 | 2021-04-30 | 华南理工大学 | 一种基于加性角惩罚焦点损失的语音情感识别方法 |
CN113326768A (zh) * | 2021-05-28 | 2021-08-31 | 浙江商汤科技开发有限公司 | 训练方法、图像特征提取方法、图像识别方法及装置 |
CN113222123A (zh) * | 2021-06-15 | 2021-08-06 | 深圳市商汤科技有限公司 | 模型训练方法、装置、设备及计算机存储介质 |
WO2022262757A1 (zh) * | 2021-06-16 | 2022-12-22 | 上海齐感电子信息科技有限公司 | 模型训练方法、图像检测方法及检测装置 |
CN114091594A (zh) * | 2021-11-15 | 2022-02-25 | 北京市商汤科技开发有限公司 | 模型训练方法及装置、设备、存储介质 |
CN114529750A (zh) * | 2021-12-28 | 2022-05-24 | 深圳云天励飞技术股份有限公司 | 图像分类方法、装置、设备及存储介质 |
CN114972877A (zh) * | 2022-06-09 | 2022-08-30 | 北京百度网讯科技有限公司 | 一种图像分类模型训练方法、装置及电子设备 |
CN115064155A (zh) * | 2022-06-09 | 2022-09-16 | 福州大学 | 一种基于知识蒸馏的端到端语音识别增量学习方法及系统 |
CN116304811A (zh) * | 2023-02-28 | 2023-06-23 | 王宇轩 | 一种基于焦点损失函数动态样本权重调整方法及系统 |
CN116503670A (zh) * | 2023-06-13 | 2023-07-28 | 商汤人工智能研究中心(深圳)有限公司 | 图像分类及模型训练方法、装置和设备、存储介质 |
Non-Patent Citations (2)
Title |
---|
Focal损失在图像情感分析上的应用研究;傅博文;唐向宏;肖涛;;计算机工程与应用(第10期);全文 * |
基于焦点损失的半监督高光谱图像分类;张凯琳;阎庆;夏懿;章军;丁云;;计算机应用(第04期);全文 * |
Also Published As
Publication number | Publication date |
---|---|
CN116863278A (zh) | 2023-10-10 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN116863278B (zh) | 模型训练方法、图像分类方法、装置、设备及存储介质 | |
CN109815801A (zh) | 基于深度学习的人脸识别方法及装置 | |
CN111126396B (zh) | 图像识别方法、装置、计算机设备以及存储介质 | |
CN111079780B (zh) | 空间图卷积网络的训练方法、电子设备及存储介质 | |
CN112016315B (zh) | 模型训练、文本识别方法及装置、电子设备、存储介质 | |
CN114283350B (zh) | 视觉模型训练和视频处理方法、装置、设备及存储介质 | |
CN108664526A (zh) | 检索的方法和设备 | |
CN110717554A (zh) | 图像识别方法、电子设备及存储介质 | |
CN111626340B (zh) | 一种分类方法、装置、终端及计算机存储介质 | |
CN111523469A (zh) | 一种行人重识别方法、系统、设备及计算机可读存储介质 | |
CN110704668B (zh) | 基于网格的协同注意力vqa方法和装置 | |
CN112270334A (zh) | 一种基于异常点暴露的少样本图像分类方法及系统 | |
CN115205583A (zh) | 图像分类模型训练方法、电子设备和计算机可读存储介质 | |
US20220335566A1 (en) | Method and apparatus for processing point cloud data, device, and storage medium | |
CN112633369B (zh) | 图像匹配方法、装置、电子设备、计算机可读存储介质 | |
CN116503670A (zh) | 图像分类及模型训练方法、装置和设备、存储介质 | |
US20220366242A1 (en) | Information processing apparatus, information processing method, and storage medium | |
CN116486153A (zh) | 图像分类方法、装置、设备及存储介质 | |
CN114972434B (zh) | 一种级联检测和匹配的端到端多目标跟踪系统 | |
CN114912540A (zh) | 迁移学习方法、装置、设备及存储介质 | |
CN115017413A (zh) | 推荐方法、装置、计算设备及计算机存储介质 | |
CN114970732A (zh) | 分类模型的后验校准方法、装置、计算机设备及介质 | |
CN117112880A (zh) | 信息推荐、多目标推荐模型训练方法、装置和计算机设备 | |
CN112749565B (zh) | 基于人工智能的语义识别方法、装置和语义识别设备 | |
CN114329006B (zh) | 图像检索方法、装置、设备、计算机可读存储介质 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |