CN116468112B - 目标检测模型的训练方法、装置、电子设备和存储介质 - Google Patents
目标检测模型的训练方法、装置、电子设备和存储介质 Download PDFInfo
- Publication number
- CN116468112B CN116468112B CN202310357394.5A CN202310357394A CN116468112B CN 116468112 B CN116468112 B CN 116468112B CN 202310357394 A CN202310357394 A CN 202310357394A CN 116468112 B CN116468112 B CN 116468112B
- Authority
- CN
- China
- Prior art keywords
- pseudo tag
- pseudo
- target detection
- loss
- obtaining
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
- 238000001514 detection method Methods 0.000 title claims abstract description 93
- 238000012549 training Methods 0.000 title claims abstract description 60
- 238000000034 method Methods 0.000 title claims abstract description 49
- 238000003860 storage Methods 0.000 title claims abstract description 16
- 238000009499 grossing Methods 0.000 claims abstract description 32
- 238000012545 processing Methods 0.000 claims abstract description 17
- 238000004364 calculation method Methods 0.000 claims description 13
- 238000004590 computer program Methods 0.000 claims description 12
- 238000012423 maintenance Methods 0.000 claims description 3
- 230000000694 effects Effects 0.000 abstract description 9
- 238000013473 artificial intelligence Methods 0.000 abstract description 3
- 238000013135 deep learning Methods 0.000 abstract description 2
- 239000011159 matrix material Substances 0.000 description 13
- 238000010586 diagram Methods 0.000 description 9
- 238000004891 communication Methods 0.000 description 8
- 238000005457 optimization Methods 0.000 description 7
- 230000006870 function Effects 0.000 description 6
- 238000009825 accumulation Methods 0.000 description 3
- 238000002372 labelling Methods 0.000 description 3
- 238000012986 modification Methods 0.000 description 3
- 230000004048 modification Effects 0.000 description 3
- 230000003287 optical effect Effects 0.000 description 3
- 230000008569 process Effects 0.000 description 3
- 230000033228 biological regulation Effects 0.000 description 2
- 238000010276 construction Methods 0.000 description 2
- 230000003993 interaction Effects 0.000 description 2
- 238000003491 array Methods 0.000 description 1
- 230000001413 cellular effect Effects 0.000 description 1
- 238000012937 correction Methods 0.000 description 1
- 230000007547 defect Effects 0.000 description 1
- 238000013461 design Methods 0.000 description 1
- 238000005516 engineering process Methods 0.000 description 1
- 238000010191 image analysis Methods 0.000 description 1
- 239000004973 liquid crystal related substance Substances 0.000 description 1
- 238000010801 machine learning Methods 0.000 description 1
- 238000004519 manufacturing process Methods 0.000 description 1
- 238000010606 normalization Methods 0.000 description 1
- 239000013307 optical fiber Substances 0.000 description 1
- 239000004065 semiconductor Substances 0.000 description 1
- 230000001953 sensory effect Effects 0.000 description 1
- 238000006467 substitution reaction Methods 0.000 description 1
- 230000001629 suppression Effects 0.000 description 1
- 230000000007 visual effect Effects 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/096—Transfer learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/0464—Convolutional networks [CNN, ConvNet]
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/084—Backpropagation, e.g. using gradient descent
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/82—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V2201/00—Indexing scheme relating to image or video recognition or understanding
- G06V2201/07—Target detection
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Evolutionary Computation (AREA)
- General Health & Medical Sciences (AREA)
- Health & Medical Sciences (AREA)
- Artificial Intelligence (AREA)
- Software Systems (AREA)
- General Physics & Mathematics (AREA)
- Computing Systems (AREA)
- Biomedical Technology (AREA)
- Molecular Biology (AREA)
- Data Mining & Analysis (AREA)
- General Engineering & Computer Science (AREA)
- Computational Linguistics (AREA)
- Mathematical Physics (AREA)
- Biophysics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Databases & Information Systems (AREA)
- Medical Informatics (AREA)
- Multimedia (AREA)
- Image Analysis (AREA)
Abstract
本公开提供了一种目标检测模型的训练方法、装置、电子设备和存储介质,涉及人工智能技术领域,尤其涉及深度学习、图像处理、计算机视觉等领域。具体实现方案为:基于样本图像以及教师模型,得到当前伪标签;基于所述当前伪标签以及历史伪标签库中的至少一个伪标签,得到记忆平滑伪标签;基于所述样本图像以及学生模型,得到目标检测结果;基于所述记忆平滑伪标签以及所述目标检测结果,得到第一损失;基于所述第一损失对所述学生模型进行训练,得到目标检测模型。本公开可以提升目标检测模型的训练效果,进而提升目标检测的准确度。
Description
技术领域
本公开涉及人工智能技术领域,尤其涉及深度学习、图像处理、计算机视觉等领域。
背景技术
数据标注缺失长期以来是目标检测中的一个难点。在现实场景中,图像数据相对容易获取,但是对海量的图像数据进行人工标注却需要大量的时间和精力,尤其是在自动驾驶、智慧医疗、缺陷检测和航拍图像分析等领域,需要数据标注者具有较高的专业知识背景,更加难以获取数据标注。因此,使用有标注数据和无标注数据相结合的半监督学习技术得到了广泛关注。
发明内容
本公开提供了一种目标检测模型的训练方法、装置、电子设备和存储介质。
根据本公开的一方面,提供了一种目标检测模型的训练方法,包括:
基于样本图像以及教师模型,得到当前伪标签;
基于所述当前伪标签以及历史伪标签库中的至少一个伪标签,得到记忆平滑伪标签;
基于所述样本图像以及学生模型,得到目标检测结果;
基于所述记忆平滑伪标签以及所述目标检测结果,得到第一损失;
基于所述第一损失对所述学生模型进行训练,得到目标检测模型。
根据本公开的另一方面,提供了一种目标检测模型的训练装置,包括:
教师模型处理模块,用于基于样本图像以及教师模型,得到当前伪标签;
标签平滑模块,用于基于所述当前伪标签以及历史伪标签库中的至少一个伪标签,得到记忆平滑伪标签;
学生模型处理模块,用于基于所述样本图像以及学生模型,得到目标检测结果;
损失计算模块,用于基于所述记忆平滑伪标签以及所述目标检测结果,得到第一损失;
训练模块,用于基于所述第一损失对所述学生模型进行训练,得到目标检测模型。
根据本公开的另一方面,提供了一种电子设备,包括:
至少一个处理器;以及
与该至少一个处理器通信连接的存储器;其中,
该存储器存储有可被该至少一个处理器执行的指令,该指令被该至少一个处理器执行,以使该至少一个处理器能够执行本公开实施例中任一的方法。
根据本公开的另一方面,提供了一种存储有计算机指令的非瞬时计算机可读存储介质,其中,该计算机指令用于使该计算机执行根据本公开实施例中任一的方法。
根据本公开的另一方面,提供了一种计算机程序产品,包括计算机程序,该计算机程序在被处理器执行时实现根据本公开实施例中任一的方法。
根据本公开实施例的技术方案,在利用教师模型和样本图像得到当前伪标签之后,利用当前伪标签和至少一个伪标签,得到记忆平滑伪标签。将该记忆平滑伪标签应用于学生模型训练过程中的损失计算中,如此,可以减少伪标签的偏置,避免伪标签的偏置累积影响后续学生模型的优化,从而提升模型训练效果,进而提升目标检测的准确度。
应当理解,本部分所描述的内容并非旨在标识本公开的实施例的关键或重要特征,也不用于限制本公开的范围。本公开的其它特征将通过以下的说明书而变得容易理解。
附图说明
附图用于更好地理解本方案,不构成对本公开的限定。其中:
图1是本公开一实施例提供的目标检测模型的训练方法的流程示意图;
图2是一个应用示例中目标检测模型的训练方法的流程示意图;
图3是本公开一实施例提供的目标检测模型的训练装置的示意性框图;
图4是本公开另一实施例提供的目标检测模型的训练装置的示意性框图;
图5是用来实现本公开实施例的目标检测模型的训练方法的电子设备的框图。
具体实施方式
以下结合附图对本公开的示范性实施例做出说明,其中包括本公开实施例的各种细节以助于理解,应当将它们认为仅仅是示范性的。因此,本领域普通技术人员应当认识到,可以对这里描述的实施例做出各种改变和修改,而不会背离本公开的范围。同样,为了清楚和简明,以下的描述中省略了对公知功能和结构的描述。
图1示出了本公开一实施例提供的目标检测模型的训练方法的流程示意图。该方法可以应用于上述目标检测模型的训练装置,该装置可以部署于电子设备中。电子设备例如是单机或多机的终端、服务器或其他处理设备。其中,终端可以为移动设备、个人数字助理(Personal Digital Assistant,PDA)、手持设备、计算设备、车载设备、可穿戴设备等用户设备(User Equipment,UE)。在一些可能的实现方式中,该方法还可以通过处理器调用存储器中存储的计算机可读指令的方式来实现。如图1所示,该方法可以包括:
S110、基于样本图像以及教师模型,得到当前伪标签;
S120、基于当前伪标签以及历史伪标签库中的至少一个伪标签,得到记忆平滑伪标签;
S130、基于样本图像以及学生模型,得到目标检测结果;
S140、基于记忆平滑伪标签以及目标检测结果,得到第一损失;
S150、基于第一损失对学生模型进行训练,得到目标检测模型。
示例性地,在本公开实施例中,教师模型和学生模型均可用于对样本图像进行目标检测。其中,教师模型输出的目标检测结果可以作为伪标签,以替代人工标签。学生模型输出的目标检测结果可以用于结合伪标签计算损失,从而基于损失进行模型的参数优化,得到目标检测模型。
在一种示例中,学生模型以及教师模型可以具有不同的模型结构,例如,教师模型为结构较为复杂、精度较高的模型,且教师模型可以采用带人工标签的样本图像训练得到。从而教师模型可以输出较为准确的目标检测结果,提高伪标签的准确度。
在另一种示例中,学生模型以及教师模型可以具有相同的模型结构。其中,学生模型的参数优化可以通过利用损失进行梯度的反向传播实现。教师模型的参数优化可以通过对学生模型的参数进行指数移动平均实现。
可选地,在本公开实施例中的模型,例如学生模型、教师模型和目标检测模型,可以是基于PP-YOLOE等无锚框检测框架的模型。
示例性地,历史伪标签库中可以包括前M次训练中产生的伪标签,M为正整数。例如,对学生模型进行多次迭代(多次训练),针对每次迭代,保存前M次迭代中的伪标签得到历史伪标签库,基于历史伪标签库中的每个伪标签和当前伪标签,得到记忆平滑伪标签。可以理解,结合前M次迭代中的信息可以对当前伪标签进行平滑,得到的记忆平滑伪标签相比于当前伪标签,会减少偏置,从而避免偏置的累积,提升伪标签的准确度。
示例性地,在上述步骤S110中,可以将样本图像的弱增强图像输入教师模型,得到教师模型的输出结果,作为当前伪标签。相应地,在上述步骤S130中,可以将样本图像的强增强图像输入学生模型,得到学生模型的输出结果,即目标检测结果。通过将复杂程度较低的弱增强图像输入教师模型,有利于保证伪标签的准确度。通过将复杂程度较高的强增强图像输入学生模型,有利于提升学生模型的学习效果和性能。
根据上述方法,在利用教师模型和样本图像得到当前伪标签之后,利用当前伪标签和至少一个伪标签,得到记忆平滑伪标签。将该记忆平滑伪标签应用于学生模型训练过程中的损失计算中,如此,可以减少伪标签的偏置,避免伪标签的偏置累积影响后续学生模型的优化,从而提升模型训练效果,进而提升目标检测的准确度,有利于实现半监督学习技术在无锚框检测器上的应用。
在一种示例性的实施方式中,上述方法还可以包括:将当前伪标签添加到历史伪标签库中。
可选地,还可以在历史伪标签库中将时间最早的伪标签删除。例如,历史伪标签库中包括M次训练中产生的伪标签;在完成当前迭代时,将当前迭代产生的伪标签添加到历史伪标签库中,并删除前第M次训练中产生的伪标签,保留当前迭代的伪标签以及前M-1次训练中产生的伪标签,用于下一次迭代中。
根据该实施方式,可以实现对历史伪标签的维护,保持历史伪标签库中的伪标签为与当前伪标签的产生时间较为接近,从而避免偏置矫正过度,保证记忆平滑伪标签的准确度。
在一种示例性的实施方式中,S120、基于当前伪标签以及历史伪标签库中的至少一个伪标签,得到记忆平滑伪标签,可以包括:基于至少一个伪标签中的每个伪标签与当前伪标签之间的时间间隔,确定每个伪标签的权重;基于每个伪标签的权重,对至少一个伪标签以及当前伪标签计算移动平均值,得到多个记忆平滑伪标签。
示例性地,时间间隔大的历史伪标签,权重可以相对较小;时间间隔小的历史伪标签,权重可以相对较大。例如,前1次迭代中产生的历史伪标签与当前伪标签的时间间隔最小,可以设置相对较大的权重。前第M次迭代中产生的历史伪标签与当前伪标签的时间间隔最大,可以设置较小的权重。
示例性地,移动平均值是指对一组测定值,按顺序取一定数量的数据并算出其全部算数平均值。基于此,通过对一组伪标签计算得到的移动平均值,可以包括多个伪标签,该多个伪标签可以作为多个记忆平滑伪标签。
可以理解,基于该权重配置方式进行移动平均值的计算,可以提升计算结果的平滑效果,即提升记忆平滑伪标签的平滑效果,进一步抑制偏置,提升伪标签的准确度。
在一种示例性的实施方式中,S140、基于记忆平滑伪标签以及目标检测结果,得到第一损失,包括:基于多个记忆平滑伪标签两两之间的相关性信息,得到伪标签图;基于目标检测结果与当前伪标签之间的相关性信息,得到预测嵌入图;基于伪标签图与预测嵌入图,得到第一损失。
示例性地,可以通过矩阵相乘的方式,得到上述相关性信息。例如,基于K个记忆平滑伪标签,得到K*1矩阵以及1*K矩阵,对K*1矩阵以及1*K矩阵进行矩阵相乘,则得到的矩阵中各元素可以表征K个记忆平滑伪标签两两之间的相关性信息,其中,K为不小于2的整数。可以理解,矩阵也可以视为图像,因此,可以将该矩阵作为伪标签图。相应地,对学生模型输出的目标检测结果(例如特征图)与教师模型输出的当前伪标签进行矩阵相乘,也可以得到用于表征两者之间的相关性信息的矩阵,将该矩阵作为预测嵌入图。
示例性地,可以通过对预测嵌入图和伪标签图进行对比损失的计算,从而得到第一损失。
根据上述实施方式,可以充分提取不同标签之间的信息相关性,利用伪标签图与预测嵌入图提升损失计算的准确度,从而提升模型优化效果,相应地提升目标检测的准确度。
在一种示例性的实施方式中,S110、基于样本图像以及教师模型,得到当前伪标签,可以包括:将样本图像对应的增强图像输入教师模型,得到教师模型输出的特征图;基于特征图中的每个特征点的分类分数,在特征图中选取N个特征点,N为大于1的整数,且N是基于特征图的尺寸以及样本图像的数量确定的;基于N个特征点中每个特征点对应的检测框,得到当前伪标签。
实际应用中,可以预先确定N值。示例性地,可以先基于特征图的尺寸以及当前训练批次中样本图像的数量计算模型最后一层的特征点数X。再基于特征点数X和预设比例确定N。例如,X=H×W×Y,其中,H和W表示最后一层中的特征图的高度和宽度,Y表示样本图像的数量,H、W和Y均为不小于2的整数。进一步地,预设比例为1%,则N=1%*X。
示例性地,可以选取分类分数最大的N个特征点,将N个特征点中每个特征点对应的检测框作为当前伪标签。
在上述方式中,采用密集学习的思路确定当前伪标签。相比于采用NMS(极大值抑制)方式,可以降低伪标签的构建复杂度,同时可以选取大量的特征点对应的检测框,提升了伪标签中的信息丰富程度,避免错过有用信息。基于此,可以提升模型训练效率以及效果。
在一种示例性的实施方式中,S150、基于第一损失对学生模型进行训练,得到目标检测模型,可以包括:基于当前伪标签以及目标检测结果,确定第二损失;基于第一损失以及第二损失,得到第三损失;基于第三损失对学生模型进行训练,得到目标检测模型。
示例性地,可以基于当前伪标签所包含的特征点信息,在学生模型输出的目标检测结果中确定对应的特征点信息,再基于伪标签中的特征点信息与目标检测结果中的特征点信息计算一致性损失,得到第二损失。
示例性地,上述第二损失可以包括分类损失和回归损失。可选地,分类损失和回归损失均可以通过质量焦点损失(Quality Focal Loss)计算。
根据上述实施方式,可以结合当前伪标签和目标检测结果的一致性损失,以及记忆平滑伪标签对应的损失,进行总损失计算,并基于总损失(第三损失)对学生模型进行训练,从而可以提升训练效果,相应地提升训练得到的目标检测模型的目标检测准确度。
为了便于理解上述实施例,下面提供一个具体的应用示例。图2示出了该应用示例中目标检测模型的训练方法的流程示意图。该方法包括以下三部分内容:
1、基于PP-YOLOE半监督目标检测框架
该半监督目标检测框架使用师生互助学习方法。如图2所示,对教师模型201输入经过弱数据增强的无标注图像202,并产生预测结果203,基于预测结果203生成伪标签204。对学生模型205输入经过强数据增强的无标注图像206,产生学生模型205的预测结果。通过计算学生模型205的预测结果和伪标签204的一致性损失211,来对学生模型205进行优化和参数更新,在此期间教师模型201的梯度不进行反向传播,教师模型201的参数由学生模型205使用指数移动平均方法进行更新。
2、密集学习伪标签
该部分将密集学习方法迁移到PP-YOLOE中。分别统计学生模型205和教师模型201的最后一层特征层的特征点数n,n=H×W×N。其中,H和W表示特征层的长和宽,N表示训练批次中的样本数量。并选取1%×n作为保留的特征点数量;根据n个特征点所预测的各类别中的最大值进行排序,选择分类分数最大的前1%×n个特征点作为计算无监督损失值的特征点。根据前1%×n个特征点的索引值选择学生模型205的预测结果中对应特征点所对应的检测框和分类分数,并计算一致性损失211,一致性损失211包括分类损失和回归损失。由于PP-YOLOE中所使用的损失函数只能对离散值进行处理,因此,在无监督分支将损失函数替换成质量焦点损失(Quality Focal Loss)来对学生模型205预测的分类分数和教师模型201预测的分类分数进行损失计算。
3、基于图的对比学习辅助训练
如图2所示,将本训练迭代的前n个迭代的伪标签存储为种子库207,之后对本迭代伪标签和种子库207中的伪标签计算移动平均,获得记忆平滑伪标签来减少伪标签的偏置。对记忆平滑伪标签自身计算矩阵乘法,得到伪标签图208。针对学生模型的预测,通过计算学生模型的输出和教师模型的输出的矩阵乘法来计算两者的相似度,并进行归一化处理,得到本迭代的学生模型的预测嵌入图209。最终使用预测嵌入图209与伪标签图208计算对比辅助损失210,利用对比辅助损失210优化无监督分类分支。
可以看到,本公开实施例的方法,将基于图的学习模型和对比学习方法结合应用到半监督目标检测中,使用前几次迭代的伪标签生成平滑伪标签,降低伪标签偏置,利于对比损失辅助优化分类分支训练,提高了物体检测精度。
根据本公开的实施例,本公开还提供了一种目标检测模型的训练装置,图3示出了本公开一实施例提供的目标检测模型的训练装置的示意性框图,如图3所示,该装置包括:
教师模型处理模块310,用于基于样本图像以及教师模型,得到当前伪标签;
标签平滑模块320,用于基于当前伪标签以及历史伪标签库中的至少一个伪标签,得到记忆平滑伪标签;
学生模型处理模块330,用于基于样本图像以及学生模型,得到目标检测结果;
损失计算模块340,用于基于记忆平滑伪标签以及目标检测结果,得到第一损失;
训练模块350,用于基于第一损失对学生模型进行训练,得到目标检测模型。
图4是根据本公开另一实施例的目标检测模型的训练装置的结构示意图,该装置可以包括上述实施例的目标检测模型的训练装置的一个或多个特征,在一种可能的实施方式中,该装置还包括:
伪标签库维护模块410,用于将当前伪标签添加到历史伪标签库中。
示例性地,在本公开实施例中,标签平滑模块320用于:
基于至少一个伪标签中的每个伪标签与当前伪标签之间的时间间隔,确定每个伪标签的权重;
基于每个伪标签的权重,对至少一个伪标签以及当前伪标签计算移动平均值,得到多个记忆平滑伪标签。
示例性地,在本公开实施例中,损失计算模块340用于:
基于多个记忆平滑伪标签两两之间的相关性信息,得到伪标签图;
基于目标检测结果与当前伪标签之间的相关性信息,得到预测嵌入图;
基于伪标签图与预测嵌入图,得到第一损失。
示例性地,在本公开实施例中,教师模型处理模块310用于:
将样本图像对应的增强图像输入教师模型,得到教师模型输出的特征图;
基于特征图中的每个特征点的分类分数,在特征图中选取N个特征点;其中,N为大于1的整数,且N是基于特征图的尺寸以及样本图像的数量确定的;
基于N个特征点中每个特征点对应的检测框,得到当前伪标签。
示例性地,在本公开实施例中,训练模块350用于:
基于当前伪标签以及目标检测结果,确定第二损失;
基于第一损失以及第二损失,得到第三损失;
基于第三损失对学生模型进行训练,得到目标检测模型。
本公开实施例的装置的各模块、子模块的具体功能和示例的描述,可以参见上述方法实施例中对应步骤的相关描述,在此不再赘述。
本公开的技术方案中,所涉及的用户个人信息的获取,存储和应用等,均符合相关法律法规的规定,且不违背公序良俗。
根据本公开的实施例,本公开还提供了一种电子设备、一种可读存储介质和一种计算机程序产品。
图5示出了可以用来实施本公开的实施例的示例电子设备500的示意性框图。电子设备旨在表示各种形式的数字计算机,诸如,膝上型计算机、台式计算机、工作台、个人数字助理、服务器、刀片式服务器、大型计算机、和其它适合的计算机。电子设备还可以表示各种形式的移动装置,诸如,个人数字助理、蜂窝电话、智能电话、可穿戴设备和其它类似的计算装置。本文所示的部件、它们的连接和关系、以及它们的功能仅仅作为示例,并且不意在限制本文中描述的和/或者要求的本公开的实现。
如图5所示,电子设备500包括计算单元501,其可以根据存储在只读存储器(ROM)502中的计算机程序或者从存储单元508加载到随机访问存储器(RAM)503中的计算机程序,来执行各种适当的动作和处理。在RAM 503中,还可存储设备500操作所需的各种程序和数据。计算单元501、ROM 502以及RAM 503通过总线504彼此相连。输入/输出(I/O)接口505也连接至总线504。
电子设备500中的多个部件连接至I/O接口505,包括:输入单元506,例如键盘、鼠标等;输出单元507,例如各种类型的显示器、扬声器等;存储单元508,例如磁盘、光盘等;以及通信单元509,例如网卡、调制解调器、无线通信收发机等。通信单元509允许设备500通过诸如因特网的计算机网络和/或各种电信网络与其他设备交换信息/数据。
计算单元501可以是各种具有处理和计算能力的通用和/或专用处理组件。计算单元501的一些示例包括但不限于中央处理单元(CPU)、图形处理单元(GPU)、各种专用的人工智能(AI)计算芯片、各种运行机器学习模型算法的计算单元、数字信号处理器(DSP)、以及任何适当的处理器、控制器、微控制器等。计算单元501执行上文所描述的各个方法和处理,例如目标检测模型的训练方法。例如,在一些实施例中,目标检测模型的训练方法可被实现为计算机软件程序,其被有形地包含于机器可读介质,例如存储单元508。在一些实施例中,计算机程序的部分或者全部可以经由ROM 502和/或通信单元509而被载入和/或安装到设备500上。当计算机程序加载到RAM 503并由计算单元501执行时,可以执行上文描述的目标检测模型的训练方法的一个或多个步骤。备选地,在其他实施例中,计算单元501可以通过其他任何适当的方式(例如,借助于固件)而被配置为执行目标检测模型的训练方法。
本文中以上描述的系统和技术的各种实施方式可以在数字电子电路系统、集成电路系统、现场可编程门阵列(FPGA)、专用集成电路(ASIC)、专用标准产品(ASSP)、芯片上系统的系统(SOC)、负载可编程逻辑设备(CPLD)、计算机硬件、固件、软件、和/或它们的组合中实现。这些各种实施方式可以包括:实施在一个或者多个计算机程序中,该一个或者多个计算机程序可在包括至少一个可编程处理器的可编程系统上执行和/或解释,该可编程处理器可以是专用或者通用可编程处理器,可以从存储系统、至少一个输入装置、和至少一个输出装置接收数据和指令,并且将数据和指令传输至该存储系统、该至少一个输入装置、和该至少一个输出装置。
用于实施本公开的方法的程序代码可以采用一个或多个编程语言的任何组合来编写。这些程序代码可以提供给通用计算机、专用计算机或其他可编程数据处理装置的处理器或控制器,使得程序代码当由处理器或控制器执行时使流程图和/或框图中所规定的功能/操作被实施。程序代码可以完全在机器上执行、部分地在机器上执行,作为独立软件包部分地在机器上执行且部分地在远程机器上执行或完全在远程机器或服务器上执行。
在本公开的上下文中,机器可读介质可以是有形的介质,其可以包含或存储以供指令执行系统、装置或设备使用或与指令执行系统、装置或设备结合地使用的程序。机器可读介质可以是机器可读信号介质或机器可读储存介质。机器可读介质可以包括但不限于电子的、磁性的、光学的、电磁的、红外的、或半导体系统、装置或设备,或者上述内容的任何合适组合。机器可读存储介质的更具体示例会包括基于一个或多个线的电气连接、便携式计算机盘、硬盘、随机存取存储器(RAM)、只读存储器(ROM)、可擦除可编程只读存储器(EPROM或快闪存储器)、光纤、便捷式紧凑盘只读存储器(CD-ROM)、光学储存设备、磁储存设备、或上述内容的任何合适组合。
为了提供与用户的交互,可以在计算机上实施此处描述的系统和技术,该计算机具有:用于向用户显示信息的显示装置(例如,CRT(阴极射线管)或者LCD(液晶显示器)监视器);以及键盘和指向装置(例如,鼠标或者轨迹球),用户可以通过该键盘和该指向装置来将输入提供给计算机。其它种类的装置还可以用于提供与用户的交互;例如,提供给用户的反馈可以是任何形式的传感反馈(例如,视觉反馈、听觉反馈、或者触觉反馈);并且可以用任何形式(包括声输入、语音输入、或者触觉输入)来接收来自用户的输入。
可以将此处描述的系统和技术实施在包括后台部件的计算系统(例如,作为数据服务器)、或者包括中间件部件的计算系统(例如,应用服务器)、或者包括前端部件的计算系统(例如,具有图形用户界面或者网络浏览器的用户计算机,用户可以通过该图形用户界面或者该网络浏览器来与此处描述的系统和技术的实施方式交互)、或者包括这种后台部件、中间件部件、或者前端部件的任何组合的计算系统中。可以通过任何形式或者介质的数字数据通信(例如,通信网络)来将系统的部件相互连接。通信网络的示例包括:局域网(LAN)、广域网(WAN)和互联网。
计算机系统可以包括客户端和服务器。客户端和服务器一般远离彼此并且通常通过通信网络进行交互。通过在相应的计算机上运行并且彼此具有客户端-服务器关系的计算机程序来产生客户端和服务器的关系。服务器可以是云服务器,也可以为分布式系统的服务器,或者是结合了区块链的服务器。
应该理解,可以使用上面所示的各种形式的流程,重新排序、增加或删除步骤。例如,本公开中记载的各步骤可以并行地执行也可以顺序地执行也可以不同的次序执行,只要能够实现本公开的技术方案所期望的结果,本文在此不进行限制。
上述具体实施方式,并不构成对本公开保护范围的限制。本领域技术人员应该明白的是,根据设计要求和其他因素,可以进行各种修改、组合、子组合和替代。任何在本公开的原则之内所作的修改、等同替换和改进等,均应包含在本公开保护范围之内。
Claims (13)
1.一种目标检测模型的训练方法,包括:
基于样本图像以及教师模型,得到当前伪标签;
基于所述当前伪标签以及历史伪标签库中的至少一个伪标签,得到记忆平滑伪标签;
基于所述样本图像以及学生模型,得到目标检测结果;
基于所述记忆平滑伪标签以及所述目标检测结果,得到第一损失;
基于所述第一损失对所述学生模型进行训练,得到目标检测模型;
其中,所述基于所述当前伪标签以及历史伪标签库中的至少一个伪标签,得到记忆平滑伪标签,包括:
基于所述至少一个伪标签中的每个伪标签与所述当前伪标签之间的时间间隔,确定所述每个伪标签的权重;
基于所述每个伪标签的权重,对所述至少一个伪标签以及所述当前伪标签计算移动平均值,得到多个记忆平滑伪标签。
2.根据权利要求1所述的方法,还包括:
将所述当前伪标签添加到所述历史伪标签库中。
3.根据权利要求1所述的方法,其中,所述基于所述记忆平滑伪标签以及所述目标检测结果,得到第一损失,包括:
基于所述多个记忆平滑伪标签两两之间的相关性信息,得到伪标签图;
基于所述目标检测结果与所述当前伪标签之间的相关性信息,得到预测嵌入图;
基于所述伪标签图与所述预测嵌入图,得到所述第一损失。
4.根据权利要求1-3中任一项所述的方法,其中,所述基于样本图像以及教师模型,得到当前伪标签,包括:
将样本图像对应的增强图像输入所述教师模型,得到所述教师模型输出的特征图;
基于所述特征图中的每个特征点的分类分数,在所述特征图中选取N个特征点;其中,N为大于1的整数,且N是基于所述特征图的尺寸以及样本图像的数量确定的;
基于所述N个特征点中每个特征点对应的检测框,得到所述当前伪标签。
5.根据权利要求1-3中任一项所述的方法,其中,所述基于所述第一损失对所述学生模型进行训练,得到目标检测模型,包括:
基于所述当前伪标签以及所述目标检测结果,确定第二损失;
基于所述第一损失以及所述第二损失,得到第三损失;
基于所述第三损失对所述学生模型进行训练,得到所述目标检测模型。
6.一种目标检测模型的训练装置,包括:
教师模型处理模块,用于基于样本图像以及教师模型,得到当前伪标签;
标签平滑模块,用于基于所述当前伪标签以及历史伪标签库中的至少一个伪标签,得到记忆平滑伪标签;
学生模型处理模块,用于基于所述样本图像以及学生模型,得到目标检测结果;
损失计算模块,用于基于所述记忆平滑伪标签以及所述目标检测结果,得到第一损失;
训练模块,用于基于所述第一损失对所述学生模型进行训练,得到目标检测模型;
其中,所述标签平滑模块用于:
基于所述至少一个伪标签中的每个伪标签与所述当前伪标签之间的时间间隔,确定所述每个伪标签的权重;
基于所述每个伪标签的权重,对所述至少一个伪标签以及所述当前伪标签计算移动平均值,得到多个记忆平滑伪标签。
7.根据权利要求6所述的装置,还包括:
伪标签库维护模块,用于将所述当前伪标签添加到所述历史伪标签库中。
8.根据权利要求6所述的装置,其中,所述损失计算模块用于:
基于所述多个记忆平滑伪标签两两之间的相关性信息,得到伪标签图;
基于所述目标检测结果与所述当前伪标签之间的相关性信息,得到预测嵌入图;
基于所述伪标签图与所述预测嵌入图,得到所述第一损失。
9.根据权利要求6-8中任一项所述的装置,其中,所述教师模型处理模块用于:
将样本图像对应的增强图像输入所述教师模型,得到所述教师模型输出的特征图;
基于所述特征图中的每个特征点的分类分数,在所述特征图中选取N个特征点;其中,N为大于1的整数,且N是基于所述特征图的尺寸以及样本图像的数量确定的;
基于所述N个特征点中每个特征点对应的检测框,得到所述当前伪标签。
10.根据权利要求6-8中任一项所述的装置,其中,所述训练模块用于:
基于所述当前伪标签以及所述目标检测结果,确定第二损失;
基于所述第一损失以及所述第二损失,得到第三损失;
基于所述第三损失对所述学生模型进行训练,得到所述目标检测模型。
11.一种电子设备,包括:
至少一个处理器;以及
与所述至少一个处理器通信连接的存储器;其中,
所述存储器存储有可被所述至少一个处理器执行的指令,所述指令被所述至少一个处理器执行,以使所述至少一个处理器能够执行权利要求1-5中任一项所述的方法。
12.一种存储有计算机指令的非瞬时计算机可读存储介质,其中,所述计算机指令用于使所述计算机执行根据权利要求1-5中任一项所述的方法。
13.一种计算机程序产品,包括计算机程序,所述计算机程序在被处理器执行时实现根据权利要求1-5中任一项所述的方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310357394.5A CN116468112B (zh) | 2023-04-06 | 2023-04-06 | 目标检测模型的训练方法、装置、电子设备和存储介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310357394.5A CN116468112B (zh) | 2023-04-06 | 2023-04-06 | 目标检测模型的训练方法、装置、电子设备和存储介质 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN116468112A CN116468112A (zh) | 2023-07-21 |
CN116468112B true CN116468112B (zh) | 2024-03-12 |
Family
ID=87178251
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202310357394.5A Active CN116468112B (zh) | 2023-04-06 | 2023-04-06 | 目标检测模型的训练方法、装置、电子设备和存储介质 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN116468112B (zh) |
Families Citing this family (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117115107B (zh) * | 2023-08-24 | 2024-06-07 | 哪吒港航智慧科技(上海)有限公司 | 基于长尾分布概率的外观缺陷检测模型的训练方法及装置 |
Citations (7)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN112801212A (zh) * | 2021-03-02 | 2021-05-14 | 东南大学 | 一种基于小样本半监督学习的白细胞分类计数方法 |
CN113609965A (zh) * | 2021-08-03 | 2021-11-05 | 同盾科技有限公司 | 文字识别模型的训练方法及装置、存储介质、电子设备 |
CN114881129A (zh) * | 2022-04-25 | 2022-08-09 | 北京百度网讯科技有限公司 | 一种模型训练方法、装置、电子设备及存储介质 |
CN114943689A (zh) * | 2022-04-27 | 2022-08-26 | 河钢数字技术股份有限公司 | 基于半监督学习的钢铁冷轧退火炉元器件检测方法 |
CN115019060A (zh) * | 2022-07-12 | 2022-09-06 | 北京百度网讯科技有限公司 | 目标识别方法、目标识别模型的训练方法及装置 |
CN115240035A (zh) * | 2022-07-29 | 2022-10-25 | 北京百度网讯科技有限公司 | 半监督目标检测模型训练方法、装置、设备以及存储介质 |
CN115273148A (zh) * | 2022-08-03 | 2022-11-01 | 北京百度网讯科技有限公司 | 行人重识别模型训练方法、装置、电子设备及存储介质 |
Family Cites Families (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US20040260550A1 (en) * | 2003-06-20 | 2004-12-23 | Burges Chris J.C. | Audio processing system and method for classifying speakers in audio data |
US20130297600A1 (en) * | 2012-05-04 | 2013-11-07 | Thierry Charles Hubert | Method and system for chronological tag correlation and animation |
-
2023
- 2023-04-06 CN CN202310357394.5A patent/CN116468112B/zh active Active
Patent Citations (7)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN112801212A (zh) * | 2021-03-02 | 2021-05-14 | 东南大学 | 一种基于小样本半监督学习的白细胞分类计数方法 |
CN113609965A (zh) * | 2021-08-03 | 2021-11-05 | 同盾科技有限公司 | 文字识别模型的训练方法及装置、存储介质、电子设备 |
CN114881129A (zh) * | 2022-04-25 | 2022-08-09 | 北京百度网讯科技有限公司 | 一种模型训练方法、装置、电子设备及存储介质 |
CN114943689A (zh) * | 2022-04-27 | 2022-08-26 | 河钢数字技术股份有限公司 | 基于半监督学习的钢铁冷轧退火炉元器件检测方法 |
CN115019060A (zh) * | 2022-07-12 | 2022-09-06 | 北京百度网讯科技有限公司 | 目标识别方法、目标识别模型的训练方法及装置 |
CN115240035A (zh) * | 2022-07-29 | 2022-10-25 | 北京百度网讯科技有限公司 | 半监督目标检测模型训练方法、装置、设备以及存储介质 |
CN115273148A (zh) * | 2022-08-03 | 2022-11-01 | 北京百度网讯科技有限公司 | 行人重识别模型训练方法、装置、电子设备及存储介质 |
Non-Patent Citations (1)
Title |
---|
Long Lan,Xiao Teng.Multi-scale Knowledge Distillation for Unsupervised Person Re-Identification.《arXiv:2204.09931v1》.2022,全文. * |
Also Published As
Publication number | Publication date |
---|---|
CN116468112A (zh) | 2023-07-21 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN113326764B (zh) | 训练图像识别模型和图像识别的方法和装置 | |
CN113657465B (zh) | 预训练模型的生成方法、装置、电子设备和存储介质 | |
CN114020950B (zh) | 图像检索模型的训练方法、装置、设备以及存储介质 | |
CN112862005B (zh) | 视频的分类方法、装置、电子设备和存储介质 | |
CN113657483A (zh) | 模型训练方法、目标检测方法、装置、设备以及存储介质 | |
CN114693934B (zh) | 语义分割模型的训练方法、视频语义分割方法及装置 | |
CN115147680B (zh) | 目标检测模型的预训练方法、装置以及设备 | |
CN116468112B (zh) | 目标检测模型的训练方法、装置、电子设备和存储介质 | |
CN113538235A (zh) | 图像处理模型的训练方法、装置、电子设备及存储介质 | |
CN113627536A (zh) | 模型训练、视频分类方法,装置,设备以及存储介质 | |
CN113792876B (zh) | 骨干网络的生成方法、装置、设备以及存储介质 | |
CN118053027A (zh) | 一种缺陷识别方法、装置、电子设备及存储介质 | |
CN113657468A (zh) | 预训练模型的生成方法、装置、电子设备和存储介质 | |
CN113361574A (zh) | 数据处理模型的训练方法、装置、电子设备及存储介质 | |
CN115239889B (zh) | 3d重建网络的训练方法、3d重建方法、装置、设备和介质 | |
CN115272705B (zh) | 显著性物体检测模型的训练方法、装置以及设备 | |
CN114882313B (zh) | 生成图像标注信息的方法、装置、电子设备及存储介质 | |
CN115081630A (zh) | 多任务模型的训练方法、信息推荐方法、装置和设备 | |
CN112560987B (zh) | 图像样本处理方法、装置、设备、存储介质和程序产品 | |
CN113139463B (zh) | 用于训练模型的方法、装置、设备、介质和程序产品 | |
CN114792097A (zh) | 预训练模型提示向量的确定方法、装置及电子设备 | |
CN116416500B (zh) | 图像识别模型训练方法、图像识别方法、装置及电子设备 | |
CN114693950B (zh) | 一种图像特征提取网络的训练方法、装置及电子设备 | |
CN113223058B (zh) | 光流估计模型的训练方法、装置、电子设备及存储介质 | |
CN114331379B (zh) | 用于输出待办任务的方法、模型训练方法和装置 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |