CN113361645A - 基于元学习及知识记忆的目标检测模型构建方法及系统 - Google Patents
基于元学习及知识记忆的目标检测模型构建方法及系统 Download PDFInfo
- Publication number
- CN113361645A CN113361645A CN202110753866.XA CN202110753866A CN113361645A CN 113361645 A CN113361645 A CN 113361645A CN 202110753866 A CN202110753866 A CN 202110753866A CN 113361645 A CN113361645 A CN 113361645A
- Authority
- CN
- China
- Prior art keywords
- training
- model
- category
- target detection
- target
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Granted
Links
- 238000001514 detection method Methods 0.000 title claims abstract description 182
- 238000010276 construction Methods 0.000 title claims abstract description 16
- 238000012549 training Methods 0.000 claims abstract description 167
- 238000000034 method Methods 0.000 claims abstract description 64
- 230000008569 process Effects 0.000 claims abstract description 25
- 238000012216 screening Methods 0.000 claims abstract description 10
- 238000011176 pooling Methods 0.000 claims description 56
- 238000012360 testing method Methods 0.000 claims description 24
- 230000006870 function Effects 0.000 claims description 13
- 238000013135 deep learning Methods 0.000 claims description 9
- 238000013528 artificial neural network Methods 0.000 claims description 4
- 230000000007 visual effect Effects 0.000 abstract description 4
- 238000013508 migration Methods 0.000 description 10
- 230000005012 migration Effects 0.000 description 10
- 238000013526 transfer learning Methods 0.000 description 7
- 238000012935 Averaging Methods 0.000 description 6
- 238000003860 storage Methods 0.000 description 6
- 238000010586 diagram Methods 0.000 description 5
- 238000004364 calculation method Methods 0.000 description 3
- 238000000605 extraction Methods 0.000 description 3
- 230000007547 defect Effects 0.000 description 2
- 238000009776 industrial production Methods 0.000 description 2
- 238000005070 sampling Methods 0.000 description 2
- 238000012952 Resampling Methods 0.000 description 1
- 230000003044 adaptive effect Effects 0.000 description 1
- 238000009411 base construction Methods 0.000 description 1
- 230000009286 beneficial effect Effects 0.000 description 1
- 238000013500 data storage Methods 0.000 description 1
- 238000005516 engineering process Methods 0.000 description 1
- 238000004519 manufacturing process Methods 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 238000012544 monitoring process Methods 0.000 description 1
- 238000010606 normalization Methods 0.000 description 1
- 230000003287 optical effect Effects 0.000 description 1
- 238000012545 processing Methods 0.000 description 1
- 239000007787 solid Substances 0.000 description 1
- 238000012546 transfer Methods 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/22—Matching criteria, e.g. proximity measures
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/084—Backpropagation, e.g. using gradient descent
Abstract
本发明实施例涉及工业场景视觉检测技术领域,公开了一种基于元学习及知识记忆的目标检测模型构建方法及系统。该方法包括:利用开源数据集对目标检测模型进行训练得到预训练模型;利用预训练模型统计开源数据集中每个类别的特征,并构建开源特征记忆库Bp;利用检测目标的训练集微调预训练模型;利用调整后的所述预训练模型从所述检测目标的训练集中抽取每个类别的特征,并从开源特征记忆库Bp中筛选出与每个类别匹配的记忆单元,存入新的特征记忆库;将开源特征记忆库加入目标检测模型后进行训练;通过相似度误差和预测误差对目标检测模型动态更新。本发明将先验知识动态迁移至算法模型对目标样本的学习过程中,提升对长尾样本的学习能力和识别精度。
Description
技术领域
本发明涉及工业场景视觉检测技术应用技术领域,特别涉及一种基于元学习及知识记忆的目标检测模型构建方法及系统。
背景技术
随着计算机视觉算法以及深度学习的知识体系趋于成熟,基于深度学习的目标检测应用已大量应用于智能工业场景,比如缺陷识别,原件分拣,流水线生产实时监控等。然而在真实的工作场景中,由于目标数量不均和样本采集难度不同,实际采集到的数据集通常是不均衡的。利用不均衡样本训练的目标检测模型识别能力大幅度下降,在工作中经常出现误检,错检和漏检等问题。
工业应用中普遍使用的数据平衡方法包括图像增强算法以及人工采集方法。这类方法带来三个问题:首先,人工丰富过程繁琐复杂并且样本重复率高,降低工作效率;其次,这类方法虽然平衡了类间样本量,但是没有显著丰富尾类特征。导致模型对小样本过拟合学习,模型的泛化能力差。最后,人工丰富样本无法提升算法识别能力上限,难以突破实际应用瓶颈。
针对不均衡样本的视觉识别问题主要包涵重加权和重采样两种算法,目前衍生出四种主要方法。第一种是使用类别均衡损失自适应函数,让模型在训练时更关注尾类,但这种算法无法提升对所有类别的整体识别能力,难以提升算法识别能力上限。其次,基于重采样方法旨在调整头尾类别的采样率平衡输入样本,但这种方法不可避免地带来头类样本欠学习,尾类样本过拟合等问题。再次,基于课程学习和集成学习等方法证明了分段式学习策略能有效提升模型对所有类别的识别精度。但其网络架构和训练过程复杂,无法灵活落地实现。最后,基于迁移学习及元学习的方法利用大样本泛化小样本特征,从而提高模型对小样本的学习速度和学习能力。相比之下,基于迁移学习的解决方法更加灵活,适合需求简单但环境复杂的工业视觉任务。
在工业视觉目标检测领域,迁移学习的主要应用方式仍停留于调优训练大样本预训练模型,对于如何利用外部开源数据,如何筛选可迁移特征,如何实现知识的动态迁移过程等方面暂未提出具体有效的实施方案。
发明内容
本发明实施例的目的在于提供一种基于元学习及知识记忆的目标检测模型构建方法及系统,可提升模型在面对不均衡样本时对所有目标的检测和识别能力,以及模型在多个场景的泛化能力。
为解决上述技术问题,第一方面,本发明实施例提供了一种基于元学习及知识记忆的目标检测模型构建方法,包括:
获取检测目标的训练集、检测目标的测试集以及外部开源数据集;
利用所述开源数据集对选取的目标检测模型进行训练得到预训练模型;所述训练模型包含预训练骨干网络以及预测网络;
利用所述预训练模型统计所述开源数据集中每个类别的特征,并构建开源特征记忆库Bp;所述Bp包含每个类别的卷积层特征、池化层特征以及对应的类别标签形成的记忆单元;
利用所述检测目标的训练集微调所述预训练模型;微调过程中固定所述预训练骨干网络的浅层卷积层,降低所述预训练骨干网络中其余卷积层的学习率,重新训练所述目标检测网络中的预测网络;
利用调整后的所述预训练模型从所述检测目标的训练集中抽取每个类别的特征,并从所述开源特征记忆库Bp中筛选出与每个类别匹配的记忆单元,存入新的特征记忆库;
将所述开源特征记忆库加入所述目标检测模型后进行训练;在训练时所述目标检测模型通过计算当前特征与所述新的特征记忆库中的特征之间的相似度误差,以及目标检测网络的预测误差,通过所述相似度误差和所述预测误差对所述目标检测模型动态更新;
利用所述检测目标的测试集测试更新后的所述目标检测模型。
另外,所述利用所述预训练模型统计所述开源数据集中每个类别的特征,包括:从所述开源数据集中为每个类别随机抽取预设数量的图片,分别输入所述预训练模型并得到所述预训练骨干网络最后一个卷积层输出的特征;计算每个类别的平均卷积层特征,作为每个类别对应的卷积层特征;对平均卷积层特征每个通道的特征图取全局平均值得到每个类别的池化层特征。
另外,构建开源特征记忆库Bp,包括:对所述卷积层特征和池化层特征进行归一化后与对应的类别标签构成记忆单元;将每个类别的记忆单元组合后得到所述开源特征记忆库Bp。
另外,所述降低所述预训练骨干网络中其余卷积层的学习率,包括:将所述其余卷积层的学习率降低至训练得到所述预训练模型时骨干网络中对应卷积层的学习率的预设比例。
另外,所述预设比例为二分之一。
另外,所述利用调整后的所述预训练模型从所述检测目标的训练集中抽取每个类别的特征,包括:从所述检测目标的训练集中为每个类别随机抽取预设数量的样本,输入调整后的所述目标检测模型统计得到每个类别的卷积层特征以及池化层特征。
另外,从所述开源特征记忆库Bp中筛选出与每个类别匹配的记忆单元,包括:采用EMD距离从所述Bp中遍历出与所述检测目标的训练集中的每个类别的各池化层特征匹配的记忆单元。
另外,所述将所述开源特征记忆库加入所述目标检测模型后进行训练;在训练时所述目标检测模型通过计算当前特征与所述新的特征记忆库中的特征之间的相似度误差,以及目标检测网络预测误差,通过所述特征相似度误差和所述预测误差对所述目标检测模型进行动态更新,包括:
在包含所述新的特征记忆库的目标检测模型的骨干网络的最后一个卷积层后加入全局池化层,通过所述骨干网络直接得到当前样本的卷积层特征和池化层特征;
对所述当前样本的卷积层特征和池化层特征归一化,利用MMD距离逐条计算当前池化层特征与所述新的特征记忆库中的记忆单元的池化层特征的相似度,得到与当前检测目标匹配的记忆单元;
利用MSE距离计算当前卷积层特征与匹配的记忆单元的卷积层特征的相似度误差,得到所述骨干网络的记忆库匹配损失函数Lmse;
将所述当前样本的卷积层特征继续输入至后续预测网络,计算预测误差,得到损失函数Lce;
训练所述目标检测网络,利用所述Lmse以及Lce更新所述目标检测模型。
另外,所述目标检测模型为深度学习神经网络。
第二方面,本发明实施例提供了一种基于元学习及知识记忆的目标检测模型构建系统,包括:
获取模块,用于获取检测目标的训练集、检测目标的测试集以及外部开源数据集;
预训练模块,用于利用所述开源数据集对选取的目标检测模型进行训练得到预训练模型;所述预训练模型包含预训练骨干网络及预测网络;
特征记忆库构建模块,用于利用所述预训练模型统计所述开源数据集中每个类别的特征,并构建开源特征记忆库Bp;所述Bp包含每个类别的卷积层特征、池化层特征以及对应的类别标签形成的记忆单元;
模型调整模块,用于利用所述检测目标的训练集微调所述预训练模型;微调过程中固定所述预训练骨干网络的浅层卷积层,降低所述预训练骨干网络中其余卷积层的学习率,重新训练所述目标检测网络的预测网络;
特征匹配模块,用于利用调整后的所述预训练模型从所述检测目标的训练集中抽取每个类别的特征,并从所述开源特征记忆库Bp中筛选出与每个类别匹配的记忆单元,存入新的特征记忆库;
模型更新模块,用于将所述开源特征记忆库加入所述目标检测模型后进行训练;在训练时所述目标检测模型通过计算当前特征与所述新的特征记忆库中的特征之间的相似度误差,以及目标检测网络的预测误差,通过所述相似度误差和所述预测误差对所述目标检测模型动态更新;
模型测试模块,用于利用所述检测目标的测试集测试更新后的所述目标检测模型。
本发明实施例与现有技术相比,通过获取检测目标的训练集、检测目标的测试集以及外部开源数据集;利用所述开源数据集对选取的目标检测模型进行训练得到预训练模型;所述训练模型包含预训练骨干网络以及预测网络;利用所述预训练模型统计所述开源数据集中每个类别的特征,并构建开源特征记忆库Bp;所述特征记忆库包含每个类别的卷积层特征、池化层特征以及对应的类别标签形成的记忆单元;利用所述检测目标的训练集微调所述预训练模型;微调过程中固定所述预训练骨干网络的浅层卷积层,降低所述预训练骨干网络中其余卷积层的学习率,重新训练所述目标检测网络的预测网络;利用调整后的所述预训练模型从所述检测目标的训练集中抽取每个类别的特征,并从所述源特征记忆库Bp中筛选出与每个类别匹配的记忆单元,存入新的特征记忆库;将所述开源特征记忆库加加入所述目标检测模型后进行训练;在训练时所述目标检测模型通过计算当前特征与所述新的特征记忆库中的特征之间的相似度误差以及目标检测网络的预测误差,通过所述相似度误差和所述预测误差对所述目标检测模型动态更新;利用所述检测目标的测试集测试更新后的所述目标检测模型。因此,本发明实施例通过迁移学习及元学习的知识记忆概念解决了基于深度学习的目标检测模型在面对不均衡样本时识别能力严重下降等问题。通过将外部大样本数据的特征知识迁移至当前学习过程,加速算法模型对小样本的学习进程,矫正对小样本的学习偏移,提高对所有类别的整体学习和识别能力。
附图说明
为了更清楚地说明本发明实施例或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,可以理解地,下面描述中的附图仅仅是本发明的实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据提供的附图获得其他的附图。
图1是本发明实施例一提供的基于元学习和知识记忆的目标检测模型构建方法的流程示意图;
图2是本发明实施例一提供的基于元学习和知识记忆的目标检测模型构建源特征库的流程示意图;
图3是本发明实施例一提供的基于元学习和知识记忆的目标检测模型动态更新的网络架构示意图;
图4是本发明实施例一提供的基于元学习和知识记忆的目标检测模型动态更新的流程示意图;
图5是本发明实施例二提供的基于元学习和知识记忆的目标检测模型构建系统的结构示意图;
图6是本发明实施例三提供的服务器的结构示意图。
具体实施方式
为使本发明的目的、技术方案和优点更加清楚,以下将参照本发明实施例中的附图,通过实施方式清楚、完整地描述本发明的技术方案,显然,所描述的实施例是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
图1是本发明实施例一提供的一种基于元学习和知识记忆的目标检测模型构建方法的流程图。该方法可以由本发明实施例提供的一种基于元学习和知识记忆的目标检测模型构建方法系统来执行,该系统可以采用软件和/或硬件的方式实现,并配置于服务器。如图1所示,本实施例的基于元学习和知识记忆的目标检测模型构建方法包括以下步骤:
步骤101:获取检测目标的训练集、检测目标的测试集以及外部开源数据集。
开源数据集可以从互联网下载,在选取开源数据集时,应尽量选取和当前任务相关的数据集。如若外部开源数据集相关性差,则可迁移特征数量大幅减少,影响模型对检测目标的学习速度和精度。
步骤102:利用所述开源数据集对选取的目标检测模型进行训练得到预训练模型;所述训练模型包含预训练骨干网络及预测网络。
目标检测模型可以根据实际需求选择深度学习神经网络,比如EfficientNet,CenterNet,YoloV5等,在此不做具体限制。然后利用选取的开源数据集对目标检测模型进行训练得到预训练模型。基于深度学习的目标检测模型通常包含骨干网络和预测网络。骨干网络可看作是图像的特征提取模块,迁移学习过程通常在此模块完成。
步骤103:利用所述预训练模型统计所述开源数据集中每个类别的特征,并构建开源特征记忆库Bp;所述Bp包含每个类别的卷积层特征、池化层特征以及对应的类别标签形成的记忆单元。
可选地,利用所述预训练模型统计所述开源数据集中每个类别的特征,包括:从所述开源数据集中为每个类别随机抽取预设数量的图片,分别输入所述预训练模型并得到所述预训练骨干网络最后一个卷积层输出的特征;计算每个类别的平均卷积层特征(即分别统计每个类别的所有图片的特征均值),作为每个类别对应的卷积层特征;对平均卷积层特征每个通道内的特征图取全局平均值得到每个类别的池化层特征。具体地,为了减少随机性,为开源数据集中每个类别随机抽取100张图片,分别输入到预训练模型中,采用预训练骨干网络的最后一个卷积层输出的特征作为当前类别的特征,得到当前类别对应的100张图片的随机特征,再对该类别的所有随机特征取均值,得到此类别的平均特征。需要注意的是,此处的平均操作只针对100个图片在同一个特征点的特征,特征点间与通道间不做平均操作。将平均特征命名为卷积层特征。得到平均特征后,对卷积层特征每一个通道内的特征图取全局平均值,这一步操作等同于全局池化。将当前C*H*W的三维特征转为H*1的二维特征,此步骤得到的特征命名为池化层特征。
可选地,构建开源特征记忆库Bp包括:对所述卷积层特征和池化层特征进行归一化后与对应的类别标签构成记忆单元;将每个类别的记忆单元组合后得到所述开源特征记忆库Bp。对卷积层特征和池化层特征归一化将其数值限制在[0,1]区间内,然后与当前类别标签构建记忆单元。因此,每个记忆单元包含卷积层特征、池化层特征以及对应的类别标签。构建操作可参考Python字典,此处不再赘述。Bp包含开源数据集中每个类别的特征,请参考图2。
步骤104:利用所述检测目标的训练集微调所述预训练模型;微调过程中固定所述预训练骨干网络的浅层卷积层,降低所述预训练骨干网络中其余卷积层的学习率,重新训练所述预目标检测网络的预测网络。
可选地,降低所述预训练骨干网络中其余卷积层的学习率,包括:将所述其余卷积层的学习率降低至训练得到所述预训练模型时骨干网络中对应卷积层的学习率的预设比例,该预设比例可以为二分之一。具体地,利用检测目标的训练集重新训练预训练模型时,固定预训练骨干网络的前三层(即浅层卷积层),其余卷积层的学习率降至步骤102训练得到预训练模型时对应卷积层的学习率的1/2。由于网络浅层特征一般是图像的普遍特征,还未具有针对指定目标的深层语义。因此在微调时固定网络的前三层用来提取图像的普遍特征,微调后续的卷积层让网络学习到针对于当前任务的深层语义特征。
步骤105:利用调整后的所述预训练模型从所述检测目标的训练集中抽取每个类别的特征,并从所述开源特征记忆库Bp中筛选出与每个类别匹配的记忆单元,存入新的特征记忆库。
可选地,所述利用调整后的所述预训练模型从所述检测目标的训练集中抽取每个类别的特征,包括:从所述检测目标的训练集中为每个类别随机抽取预设数量的样本,输入调整后的所述目标检测模型统计得到每个类别的卷积层特征以及池化层特征。具体地,步骤104训练结束后,为检测目标的训练集中的每个类别随机抽取100张图片,如果当前类别不足100张图片(亦可称为样本),则选择所有样本。将每个类别的所有样本分别输入调整后的所述目标检测模型,得到其骨干网络的最后一个卷积层的特征,即得到当前类别对应的所有图片的随机特征,再对该类别的所有随机特征取均值,得到此类别的平均特征。与通过目标检测网络抽取开源数据集中每个类别的特征相同,此处的平均操作只针对所有图片在同一个特征点内的特征,特征点间与通道间不做平均操作。将平均特征命名为卷积层特征。得到平均特征后,对卷积层特征每一个通道内的特征图取全局平均值,得到该类别的池化层特征。然后对每个类别的卷积层特征和池化层特征进行归一化。
可选地,从所述开源特征记忆库Bp中筛选出与每个类别匹配的记忆单元,包括:采用EMD距离从所述Bp中遍历出与所述检测目标的训练集中的每个类别的各池化层特征匹配的记忆单元。
具体地,利用EMD距离(Earth Mover Distance,地球移动距离)在开源特征记忆库Bp中寻找与当前类别匹配度最高的五个类别,加入到一个新的特征记忆库中,如果新的特征记忆库中存在将加入的类别的特征,则跳过。利用检测目标的训练集中的所有类别的特征分别遍历开源特征记忆库Bp,得到与每个类别匹配度最高的五种类别的特征,加入一个新的特征记忆库。新的特征记忆库由于与检测目标相似,因此可以避免迁移学习过程中的负迁移问题。
步骤106:将所述开源特征记忆库加入所述目标检测模型后进行训练;在训练时所述目标检测模型通过计算当前特征与所述新的特征记忆库中的特征之间的相似度误差以及目标检测网络的预测误差,通过所述相似度误差和所述预测误差对所述目标检测模型动态更新。
如图4所示,将所述开源特征记忆库加入所述目标检测模型后进行训练;在训练时所述目标检测模型通过计算当前特征与所述新的特征记忆库中的特征之间的相似度误差以及目标检测网络的预测误差,通过所述相似度误差和所述预测误差对所述目标检测模型进行动态更新,包括:
步骤401:在包含所述新的特征记忆库的目标检测模型的骨干网络的最后一个卷积层后加入全局池化层,通过所述骨干网络直接得到当前样本的卷积层特征和池化层特征。
具体地,在当前目标检测网络的基础上修改网络框架,网络框架的具体信息请参考图3。在骨干网络的最后一个卷积层后加入全局池化层,得到当前输出的池化层特征。
步骤402:对所述当前样本的卷积层特征和池化层特征归一化,利用MMD距离逐条计算当前池化层特征与所述新的特征记忆库中的记忆单元的池化层特征的相似度,得到与当前检测目标匹配的记忆单元。
对卷积层特征和池化层特征归一化,使其取值范围固定在[0,1]区间内。利用MMD(Maximum Mean Discrepancy,最大均值差异)距离逐条计算当前池化层特征与新的特征记忆库中的池化层特征,得到与当前检测目标最相近的记忆单元。
步骤403:利用MSE距离计算当前卷积层特征与匹配的记忆单元的卷积层特征的相似度误差,得到所述骨干网络的记忆库匹配损失函数Lmse。由于类别不同,误差只能尽量减小,不能完全为零。
步骤404:将所述当前样本的卷积层特征继续输入至后续预测网络,计算预测误差,得到损失函数Lce。
具体地,将步骤401得到的特征继续输入至后续预测网络,计算预测误差,与训练集标签计算损失函数Lce。
步骤405:训练所述目标检测网络,利用所述Lmse以及Lce更新所述目标检测模型。
请参考图3,一方面,Lce计算的是预测误差,因此可通过Lce的误差反向传播过程迭代更新骨干网络以及预测网络。另一方面,Lmse计算的是特征匹配误差,缩小Lmse误差可拉近当前特征与记忆单元特征之间的距离,利用特征记忆库矫正骨干网络学习到的特征,加速网络对所有样本的学习能力以及学习速度。
步骤107:利用所述检测目标的测试集测试更新后的所述目标检测模型。
本发明实施例通过迁移学习以及元学习中的知识记忆概念解决了基于深度学习的目标检测网络当面对不均衡样本时识别能力严重下降等问题。通过将外部大样本数据的知识迁移至当前学习过程,加速模型对小样本的学习进程,矫正对小样本的学习偏移,提高对所有类别的整体学习和识别能力。因此,本发明实施例从实际应用需求出发,同时考虑到迁移学习的技术缺陷以及实际应用限制,有效解决了目标检测的长尾识别问题。与现有技术相比,本发明实施例有如下优点:
1、考虑检测目标之间可能不存在任何的相关性,利用外部数据丰富类别特征,有效解决了长尾识别问题中模型对尾类识别能力不足,学习过程中尾类特征逐渐向头类偏移等问题。
2、通过距离函数筛选可迁移特征,有效避免的迁移学习过程中的负迁移问题,同时增强了知识迁移的可解释性。
3、在原网络模型架构中嵌入外部知识记忆单元并引入特征匹配误差,实现端到端训练动态知识迁移过程。本发明实施例可应用于任何目标检测算法,满足工业生产需要。
图5是本发明实施例二提供的基于元学习和知识记忆的目标检测模型构建系统的结构框图。该系统可配置于服务器,用于执行上述任意实施例所提供的基于元学习和知识记忆的目标检测模型构建方法。该系统500包括:
获取模块501,用于获取检测目标的训练集、检测目标的测试集以及外部开源数据集。
预训练模块502,用于利用所述开源数据集对选取的目标检测模型进行训练得到预训练模型;所述预训练模型包含预训练骨干网络以及预测网络。
特征记忆库构建模块503,用于利用所述预训练模型统计所述开源数据集中每个类别的特征,并构建开源特征记忆库Bp;所述Bp包含每个类别的卷积层特征、池化层特征以及对应的类别标签形成的记忆单元。
模型调整模块504,用于利用所述检测目标的训练集微调所述预训练模型;微调过程中固定所述预训练骨干网络的浅层卷积层,降低所述预训练骨干网络中其余卷积层的学习率,重新训练所述目标检测网络的预测网络。
特征匹配模块505,用于利用调整后的所述预训练模型从所述检测目标的训练集中抽取每个类别的特征,并从所述开源特征记忆库Bp中筛选出与每个类别匹配的记忆单元,存入新的特征记忆库。
模型更新模块506,用于将所述开源特征记忆库加入所述目标检测模型后进行训练;在训练时所述目标检测模型通过计算当前特征与所述新的特征记忆库中的特征之间的相似度误差以及目标检测网络的预测误差,通过所述相似度误差和所述预测误差对所述目标检测模型动态更新。
模型测试模块507,用于利用所述检测目标的测试集测试更新后的所述目标检测模型。
可选地,特征记忆库构建模块503包括:
源图片卷积层特征计算子模块,用于从所述开源数据集中为每个类别随机抽取预设数量的图片,分别输入所述预训练模型并得到所述预训练骨干网络最后一个卷积层输出的特征;
源卷积层特征统计子模块,用于计算每个类别的平均卷积层特征,作为每个类别对应的卷积层特征,作为每个类别对应的卷积层特征;以及
源池化层特征统计子模块,用于对平均卷积层特征每个通道内的特征图取全局平均值得到每个类别的池化层特征。
可选地,特征记忆库构建模块503还包括:
源特征归一化子模块,用于对所述卷积层特征和池化层特征进行归一化后与对应的类别标签构成记忆单元;
组合子模块,用于将每个类别的记忆单元组合后得到所述开源特征记忆库Bp。
可选地,模型调整模块504具体用于将所述其余卷积层的学习率降低至训练得到所述预训练模型时骨干网络中对应卷积层的学习率的预设比例。所述预设比例可以为二分之一。
可选地,特征匹配模块505包括:目标特征抽取子模块,用于从所述检测目标的训练集中为每个类别随机抽取预设数量的样本,输入调整后的所述目标检测模型统计得到每个类别的卷积层特征以及池化层特征。
可选地,特征匹配模块505还包括:
匹配子模块,用于采用EMD距离从所述Bp中遍历出与所述检测目标的训练集中的每个类别的池化层特征匹配的记忆单元。
可选地,模型更新模块506包括:
模型修改子模块,用于在包含所述新的特征记忆库的目标检测模型的骨干网络的最后一个卷积层后加入全局池化层,通过所述骨干网络直接得到当前样本的卷积层特征和池化层特征;
相似特征匹配子模块,用于对所述当前样本的卷积层特征和池化层特征归一化,利用MMD距离逐条计算当前池化层特征与所述新的特征记忆库中的记忆单元的池化层特征的相似度,得到与当前检测目标匹配的记忆单元;
特征误差计算子模块,用于利用MSE相似度计算当前卷积层特征与匹配的记忆单元的卷积层特征的相似度误差,得到所述骨干网络的记忆库匹配损失函数Lmse;
预测误差计算子模块,用于将所述当前样本的卷积层特征继续输入至后续预测网络,计算预测误差,得到损失函数Lce;
更新子模块,用于训练所述目标检测网络,利用所述Lmse以及Lce更新所述目标检测模型。
可选地,所述目标检测模型为深度学习神经网络。
本发明实施例的构建系统与现有技术相比有如下优点:
1、考虑检测目标之间可能不存在任何的相关性,利用外部数据丰富类别特征,有效解决了长尾识别问题中模型对尾类识别能力不足,学习过程中尾类特征逐渐向头类偏移等问题。
2、通过距离函数筛选可迁移特征,有效避免的迁移学习过程中的负迁移问题,同时增强了知识迁移的可解释性。
3、在原网络模型架构中嵌入外部知识记忆单元并引入特征匹配误差,实现端到端训练动态知识迁移过程。本发明实施例可应用于任何目标检测算法,满足工业生产需要。
图6为本发明实施例三提供的一种服务器的结构示意图。如图6所示,该云端包括:存储器602、处理器601;
其中,所述存储器602存储有可被所述至少一个处理器601执行的指令,所述指令被所述至少一个处理器601执行以实现前述任意实施例所述的基于元学习和知识记忆的目标检测模型构建方法。
该服务器可以包括一个或多个处理器601以及存储器602,图6中以一个处理器601为例。处理器601、存储器602可以通过总线或者其他方式连接,图6中以通过总线连接为例。存储器602作为一种非易失性计算机可读存储介质,可用于存储非易失性软件程序、非易失性计算机可执行程序以及模块。处理器601通过运行存储在存储器602中的非易失性软件程序、指令以及模块,从而执行云端的各种功能应用以及数据处理,即实现上述任一实施例所述的基于元学习和知识记忆的目标检测模型构建方法。
存储器602可以包括存储程序区和存储数据区,其中,存储程序区可存储操作系统、至少一个功能所需要的应用程序。此外,存储器602可以包括高速随机存取存储器,还可以包括非易失性存储器,例如至少一个磁盘存储器件、闪存器件、或其他非易失性固态存储器件。
一个或者多个模块存储在存储器602中,当被一个或者多个处理器601执行时,执行上述任意方法实施方式中的基于元学习和知识记忆的目标检测模型构建方法。
上述服务器可执行本发明实施方式所提供的方法,具备执行方法相应的功能模块和有益效果,未在本实施方式中详尽描述的技术细节,可参见本发明实施方式所提供的方法。
本发明实施例四提供一种计算机可读存储介质,用于存储计算机可读程序,所述计算机可读程序用于供云端执行上述部分或全部的方法实施例。
即,本领域技术人员可以理解,实现上述实施例方法中的全部或部分步骤是可以通过程序来指令相关的硬件来完成,该程序存储在一个存储介质中,包括若干指令用以使得一个云端(可以是单片机,芯片等)或处理器(processor)执行本发明各个实施例所述方法的全部或部分步骤。而前述的存储介质包括:U盘、移动硬盘、只读存储器(ROM,Read-OnlyMemory)、随机存取存储器(RAM,Random Access Memory)、磁碟或者光盘等各种可以存储程序代码的介质。
本领域的普通技术人员可以理解,上述各实施方式是实现本发明的具体实施例,而在实际应用中,可以在形式上和细节上对其作各种改变,而不偏离本发明的精神和范围。
Claims (10)
1.一种基于元学习和知识记忆的目标检测模型构建方法,其特征在于,包括:
获取检测目标的训练集、检测目标的测试集以及外部开源数据集;
利用所述开源数据集对选取的目标检测模型进行训练得到预训练模型;所述预训练模型包含预训练骨干网络以及预测网络;
利用所述预训练模型统计所述开源数据集中每个类别的特征,并构建开源特征记忆库Bp;所述Bp包含每个类别的卷积层特征、池化层特征以及对应的类别标签形成的记忆单元;
利用所述检测目标的训练集微调所述预训练模型;微调过程中固定所述预训练骨干网络的浅层卷积层,降低所述预训练骨干网络中其余卷积层的学习率,重新训练所述目标检测网络中的预测网络;
利用调整后的所述预训练模型从所述检测目标的训练集中抽取每个类别的特征,并从所述开源特征记忆库Bp中筛选出与每个类别匹配的记忆单元,存入新的特征记忆库;
将所述开源特征记忆库加入所述目标检测模型后进行训练;在训练时所述目标检测模型通过计算当前特征与所述新的特征记忆库中的特征之间的相似度误差以及目标检测网络的预测误差,通过所述相似度误差和所述预测误差对所述目标检测模型进行动态更新;
利用所述检测目标的测试集测试更新后的所述目标检测模型。
2.根据权利要求1所述的方法,其特征在于,利用所述预训练模型统计所述开源数据集中每个类别的特征,包括:
从所述开源数据集中为每个类别随机抽取预设数量的图片,分别输入所述预训练模型并得到所述预训练骨干网络最后一个卷积层输出的特征;计算每个类别的平均卷积层特征,作为每个类别对应的卷积层特征;对平均卷积层特征每个通道的特征图取全局平均值得到每个类别的池化层特征。
3.根据权利要求2所述的方法,其特征在于,构建开源特征记忆库Bp,包括:
对所述卷积层特征和池化层特征进行归一化后与对应的类别标签构成记忆单元;
将每个类别的记忆单元组合后得到所述开源特征记忆库Bp。
4.根据权利要求1所述的方法,其特征在于,所述降低所述预训练骨干网络中其余卷积层的学习率,包括:
将所述其余卷积层的学习率降低至训练得到所述预训练模型时骨干网络中对应卷积层的学习率的预设比例。
5.根据权利要求4所述的方法,其特征在于,所述预设比例为二分之一。
6.根据权利要求1所述的方法,其特征在于,所述利用调整后的所述预训练模型从所述检测目标的训练集中抽取每个类别的特征,包括:
从所述检测目标的训练集中为每个类别随机抽取预设数量的样本,输入调整后的所述目标检测模型统计得到每个类别的卷积层特征以及池化层特征。
7.根据权利要求6所述的方法,其特征在于,从所述开源特征记忆库Bp中筛选出与每个类别匹配的记忆单元,包括:
采用EMD距离从所述Bp中遍历出与所述检测目标的训练集中的每个类别的各池化层特征匹配的记忆单元。
8.根据权利要求1所述的方法,其特征在于,所述将所述新的开源特征记忆库加入所述目标检测模型后进行训练;在训练时所述目标检测模型通过计算当前特征与所述新的特征记忆库中的特征之间的相似度误差以及目标检测网络的预测误差,通过所述相似度误差和所述预测误差对所述目标检测模型动态更新,包括:
在包含所述新的特征记忆库的目标检测模型的骨干网络的最后一个卷积层后加入全局池化层,通过所述骨干网络直接得到当前样本的卷积层特征和池化层特征;
对所述当前样本的卷积层特征和池化层特征归一化,利用MMD距离逐条计算当前池化层特征与所述新的特征记忆库中的记忆单元的池化层特征的相似度,得到与当前检测目标匹配的记忆单元;
利用MSE距离计算当前卷积层特征与匹配的记忆单元的卷积层特征的相似度误差,得到所述骨干网络的记忆库匹配损失函数Lmse;
将所述当前样本的卷积层特征继续输入至后续预测网络,计算预测误差,得到损失函数Lce;
训练所述目标检测网络,利用所述Lmse以及Lce更新所述目标检测模型。
9.根据权利要求1所述的方法,其特征在于,所述目标检测模型为深度学习神经网络。
10.一种基于元学习和知识记忆的目标检测模型构建系统,其特征在于,包括:
获取模块,用于获取检测目标的训练集、检测目标的测试集以及外部开源数据集;
预训练模块,用于利用所述开源数据集对选取的目标检测模型进行训练得到预训练模型;所述预训练模型包含预训练骨干网络以及预测网络;
特征记忆库构建模块,用于利用所述预训练模型统计所述开源数据集中每个类别的特征,并构建开源特征记忆库Bp;所述Bp包含每个类别的卷积层特征、池化层特征以及对应的类别标签形成的记忆单元;
模型调整模块,用于利用所述检测目标的训练集微调所述预训练模型;微调过程中固定所述预训练骨干网络的浅层卷积层,降低所述预训练骨干网络中其余卷积层的学习率,重新训练所述目标检测网络中的预测网络;
特征匹配模块,用于利用调整后的所述预训练模型从所述检测目标的训练集中抽取每个类别的特征,并从所述开源特征记忆库Bp中筛选出与每个类别匹配的记忆单元,存入新的特征记忆库;
模型更新模块,用于将所述开源特征记忆库加入所述目标检测模型后进行训练;在训练时所述目标检测模型通过计算当前特征与所述新的特征记忆库中的特征之间的相似度误差,以及目标检测网络的预测误差,通过所述相似度误差和所述预测误差对所述目标检测模型动态更新;
模型测试模块,用于利用所述检测目标的测试集测试更新后的所述目标检测模型。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202110753866.XA CN113361645B (zh) | 2021-07-03 | 2021-07-03 | 基于元学习及知识记忆的目标检测模型构建方法及系统 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202110753866.XA CN113361645B (zh) | 2021-07-03 | 2021-07-03 | 基于元学习及知识记忆的目标检测模型构建方法及系统 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN113361645A true CN113361645A (zh) | 2021-09-07 |
CN113361645B CN113361645B (zh) | 2024-01-23 |
Family
ID=77538133
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202110753866.XA Active CN113361645B (zh) | 2021-07-03 | 2021-07-03 | 基于元学习及知识记忆的目标检测模型构建方法及系统 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN113361645B (zh) |
Cited By (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN113822368A (zh) * | 2021-09-29 | 2021-12-21 | 成都信息工程大学 | 一种基于无锚的增量式目标检测方法 |
CN114330446A (zh) * | 2021-12-30 | 2022-04-12 | 安徽心之声医疗科技有限公司 | 一种基于课程元学习的心律失常检测算法和系统 |
CN115331128A (zh) * | 2022-10-11 | 2022-11-11 | 松立控股集团股份有限公司 | 一种高架桥裂痕检测方法 |
CN116524297A (zh) * | 2023-04-28 | 2023-08-01 | 迈杰转化医学研究(苏州)有限公司 | 一种基于专家反馈的弱监督学习训练方法 |
Citations (6)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
WO2018137357A1 (zh) * | 2017-01-24 | 2018-08-02 | 北京大学 | 一种目标检测性能优化的方法 |
CN109508655A (zh) * | 2018-10-28 | 2019-03-22 | 北京化工大学 | 基于孪生网络的不完备训练集的sar目标识别方法 |
CN109961089A (zh) * | 2019-02-26 | 2019-07-02 | 中山大学 | 基于度量学习和元学习的小样本和零样本图像分类方法 |
CN112084330A (zh) * | 2020-08-12 | 2020-12-15 | 东南大学 | 一种基于课程规划元学习的增量关系抽取方法 |
CN112132257A (zh) * | 2020-08-17 | 2020-12-25 | 河北大学 | 基于金字塔池化及长期记忆结构的神经网络模型训练方法 |
CN112699966A (zh) * | 2021-01-14 | 2021-04-23 | 中国人民解放军海军航空大学 | 基于深度迁移学习的雷达hrrp小样本目标识别预训练及微调方法 |
-
2021
- 2021-07-03 CN CN202110753866.XA patent/CN113361645B/zh active Active
Patent Citations (6)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
WO2018137357A1 (zh) * | 2017-01-24 | 2018-08-02 | 北京大学 | 一种目标检测性能优化的方法 |
CN109508655A (zh) * | 2018-10-28 | 2019-03-22 | 北京化工大学 | 基于孪生网络的不完备训练集的sar目标识别方法 |
CN109961089A (zh) * | 2019-02-26 | 2019-07-02 | 中山大学 | 基于度量学习和元学习的小样本和零样本图像分类方法 |
CN112084330A (zh) * | 2020-08-12 | 2020-12-15 | 东南大学 | 一种基于课程规划元学习的增量关系抽取方法 |
CN112132257A (zh) * | 2020-08-17 | 2020-12-25 | 河北大学 | 基于金字塔池化及长期记忆结构的神经网络模型训练方法 |
CN112699966A (zh) * | 2021-01-14 | 2021-04-23 | 中国人民解放军海军航空大学 | 基于深度迁移学习的雷达hrrp小样本目标识别预训练及微调方法 |
Cited By (6)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN113822368A (zh) * | 2021-09-29 | 2021-12-21 | 成都信息工程大学 | 一种基于无锚的增量式目标检测方法 |
CN114330446A (zh) * | 2021-12-30 | 2022-04-12 | 安徽心之声医疗科技有限公司 | 一种基于课程元学习的心律失常检测算法和系统 |
CN114330446B (zh) * | 2021-12-30 | 2023-04-07 | 安徽心之声医疗科技有限公司 | 一种基于课程元学习的心律失常检测算法和系统 |
CN115331128A (zh) * | 2022-10-11 | 2022-11-11 | 松立控股集团股份有限公司 | 一种高架桥裂痕检测方法 |
CN116524297A (zh) * | 2023-04-28 | 2023-08-01 | 迈杰转化医学研究(苏州)有限公司 | 一种基于专家反馈的弱监督学习训练方法 |
CN116524297B (zh) * | 2023-04-28 | 2024-02-13 | 迈杰转化医学研究(苏州)有限公司 | 一种基于专家反馈的弱监督学习训练方法 |
Also Published As
Publication number | Publication date |
---|---|
CN113361645B (zh) | 2024-01-23 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN113361645B (zh) | 基于元学习及知识记忆的目标检测模型构建方法及系统 | |
CN112734775B (zh) | 图像标注、图像语义分割、模型训练方法及装置 | |
CN111209907B (zh) | 一种复杂光污染环境下产品特征图像人工智能识别方法 | |
CN110796048A (zh) | 一种基于深度神经网络的船舰目标实时检测方法 | |
CN112734803B (zh) | 基于文字描述的单目标跟踪方法、装置、设备及存储介质 | |
CN110929848A (zh) | 基于多挑战感知学习模型的训练、跟踪方法 | |
CN110827312A (zh) | 一种基于协同视觉注意力神经网络的学习方法 | |
CN111126278A (zh) | 针对少类别场景的目标检测模型优化与加速的方法 | |
CN112084895B (zh) | 一种基于深度学习的行人重识别方法 | |
CN114842343A (zh) | 一种基于ViT的航空图像识别方法 | |
CN111695640A (zh) | 地基云图识别模型训练方法及地基云图识别方法 | |
CN116452810A (zh) | 一种多层次语义分割方法、装置、电子设备及存储介质 | |
CN111310837A (zh) | 车辆改装识别方法、装置、系统、介质和设备 | |
CN111222534A (zh) | 一种基于双向特征融合和更平衡l1损失的单发多框检测器优化方法 | |
CN111914949B (zh) | 基于强化学习的零样本学习模型的训练方法及装置 | |
CN113780287A (zh) | 一种多深度学习模型的最优选取方法及系统 | |
CN112991281A (zh) | 视觉检测方法、系统、电子设备及介质 | |
TWI803243B (zh) | 圖像擴增方法、電腦設備及儲存介質 | |
CN116958809A (zh) | 一种特征库迁移的遥感小样本目标检测方法 | |
CN114612450B (zh) | 基于数据增广机器视觉的图像检测分割方法、系统、电子设备 | |
CN115410250A (zh) | 阵列式人脸美丽预测方法、设备及存储介质 | |
CN113420824A (zh) | 针对工业视觉应用的预训练数据筛选及训练方法、系统 | |
CN112419362B (zh) | 一种基于先验信息特征学习的运动目标跟踪方法 | |
CN114299012A (zh) | 一种基于卷积神经网络的物体表面缺陷检测方法及系统 | |
Pang et al. | Target tracking based on siamese convolution neural networks |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |