CN115482441A - 训练数据筛选方法、装置及计算机可读存储介质 - Google Patents
训练数据筛选方法、装置及计算机可读存储介质 Download PDFInfo
- Publication number
- CN115482441A CN115482441A CN202211409768.5A CN202211409768A CN115482441A CN 115482441 A CN115482441 A CN 115482441A CN 202211409768 A CN202211409768 A CN 202211409768A CN 115482441 A CN115482441 A CN 115482441A
- Authority
- CN
- China
- Prior art keywords
- training data
- training
- learning model
- active learning
- data
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Granted
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/77—Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
- G06V10/774—Generating sets of training patterns; Bootstrap methods, e.g. bagging or boosting
- G06V10/7753—Incorporation of unlabelled data, e.g. multiple instance learning [MIL]
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/77—Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
- G06V10/778—Active pattern-learning, e.g. online learning of image or video features
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/82—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02P—CLIMATE CHANGE MITIGATION TECHNOLOGIES IN THE PRODUCTION OR PROCESSING OF GOODS
- Y02P90/00—Enabling technologies with a potential contribution to greenhouse gas [GHG] emissions mitigation
- Y02P90/30—Computing systems specially adapted for manufacturing
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Evolutionary Computation (AREA)
- Databases & Information Systems (AREA)
- General Physics & Mathematics (AREA)
- Artificial Intelligence (AREA)
- General Health & Medical Sciences (AREA)
- Health & Medical Sciences (AREA)
- Computing Systems (AREA)
- Software Systems (AREA)
- Multimedia (AREA)
- Medical Informatics (AREA)
- Computer Vision & Pattern Recognition (AREA)
- General Engineering & Computer Science (AREA)
- Mathematical Physics (AREA)
- Biomedical Technology (AREA)
- Life Sciences & Earth Sciences (AREA)
- Molecular Biology (AREA)
- Data Mining & Analysis (AREA)
- Computational Linguistics (AREA)
- Biophysics (AREA)
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
- Image Analysis (AREA)
Abstract
本发明公开了一种训练数据筛选方法、装置及计算机可读存储介质,其中,所述方法包括:将已标注训练数据和未标注训练数据输入深度主动学习模型;基于所述深度主动学习模型的卷积神经网络,确定所述已标注训练数据的第一全连接层值和所述未标注训练数据的第二全连接层值;根据所述第一全连接层值和所述第二全连接层值确定每个所述已标注训练数据与每个所述未标注训练数据之间的欧式距离;根据所述欧式距离从所述未标注训练数据中确定目标训练数据。本发明旨在提高筛选出的训练数据的代表性,以降低训练成本,提高训练模型的效率。
Description
技术领域
本发明涉及深度学习领域,尤其涉及一种训练数据筛选方法、装置及计算机可读存储介质。
背景技术
深度主动学习模型在计算机视觉和模式识别的许多研究领域取得了前所未有的成功,如图像分类、目标检测和场景分割。虽然深度主动学习模型在许多任务中普遍成功,但它们有一个主要缺点;他们需要大量的标记数据才能学习大量的参数,尤其是工业场景图像分类。
在相关技术中,工业场景获取缺陷数据成本高,难度较大且缺陷类型多,因此这些有缺陷数据标注需要有经验的人进行标注,且标注量大、耗时费力,因此表现出有标注样本获取代价非常昂贵。而主动学习查询策略一般是通过确定样本数据的不确定性来进行筛选。但是由于不确定性的单独采样会导致采样偏差,而忽略了用于模型训练的样本数据的分布,筛选出来的样本数据作为训练数据,不利于提高模型性能,反而需要标注更多的样本数据,造成标注成本增加,因此,目前训练数据筛选代表性不高,不利于提高模型性能。
上述内容仅用于辅助理解本发明的技术方案,并不代表承认上述内容是现有技术。
发明内容
本发明的主要目的在于提供一种训练数据筛选方法、装置及计算机可读存储介质,旨在达成提高训练数据筛选代表性的效果。
为实现上述目的,本发明提供一种训练数据筛选方法,所述方法包括:
将已标注训练数据和未标注训练数据输入深度主动学习模型;
基于所述深度主动学习模型的卷积神经网络,确定所述已标注训练数据的第一全连接层值和所述未标注训练数据的第二全连接层值;
根据所述第一全连接层值和所述第二全连接层值确定每个所述已标注训练数据与每个所述未标注训练数据之间的欧式距离;
根据所述欧式距离从所述未标注训练数据中确定目标训练数据。
可选地,所述根据所述欧式距离从所述未标注训练数据中确定目标训练数据的步骤包括:
确定每个所述未标注训练数据对应的最小欧式距离为目标欧式距离;
将所述目标欧式距离进行降序排列;
确定前预设数量的目标欧式距离对应的未标注训练数据为所述目标训练数据。
可选地,所述将已标注训练数据和未标注训练数据输入深度主动学习模型的步骤之后,还包括:
根据所述已标注训练数据训练所述深度主动学习模型;
所述根据所述卷积神经网络确定所述已标注训练数据的第一全连接层值和所述未标注训练数据的第二全连接层值的步骤包括:
基于训练后的所述深度主动学习模型的卷积神经网络确定所述第一全连接层值和所述第二全连接层值。
可选地,所述将已标注训练数据和未标注训练数据输入深度主动学习模型的步骤之前,还包括:
将主动选择模块封装为功能函数;
将所述功能函数连接到所述卷积神经网络的分类模块之后,以组成所述深度主动学习模型。
可选地,所述根据所述欧式距离从所述未标注训练数据中确定目标训练数据的步骤之后,还包括:
获取进行标注后的所述目标训练数据;
根据标注后的所述目标训练数据更新所述已标注训练数据;
根据更新后的已标注训练数据训练所述深度主动学习模型。
可选地,所述根据更新后的已标注训练数据训练所述深度主动学习模型的步骤之后,还包括:
获取已标注测试数据;
将所述已标注测试数据输入所述深度主动学习模型,确定预测准确的正类测试数据数量和负类测试数据数量;
根据所述正类测试数据数量和所述负类测试数据数量确定所述深度主动学习模型的准确率指标;
当所述准确率指标小于或等于预设阈值,重新执行所述将已标注训练数据和未标注训练数据输入深度主动学习模型的步骤。
可选地,所述根据所述正类测试数据数量和所述负类测试数据数量确定所述深度主动学习模型的准确率指标的步骤之后,还包括:
获取历史训练轮次中所述深度主动学习模型的历史准确率指标;
根据所述历史准确率指标确定所述预设阈值。
可选地,所述根据更新后的已标注训练数据训练所述深度主动学习模型的步骤之后,还包括:
确定剩余未标记训练数据数量;
当所述数量大于或等于预设阈值,重新执行所述将已标注训练数据和未标注训练数据输入深度主动学习模型的步骤;
当所述数量小于预设阈值,终止训练,输出训练数据不足的提示信息。
此外,为实现上述目的,本发明还提供一种训练数据筛选装置,所述训练数据筛选装置包括存储器、处理器及存储在所述存储器上并可在所述处理器上运行的训练数据筛选程序,所述训练数据筛选程序被所述处理器执行时实现如上所述的训练数据筛选方法的步骤。
此外,为实现上述目的,本发明还提供一种计算机可读存储介质,所述计算机可读存储介质上存储有训练数据筛选程序,所述训练数据筛选程序被处理器执行时实现如上所述的训练数据筛选方法的步骤。
本发明实施例提出的一种训练数据筛选方法、装置及计算机可读存储介质,先将已标注训练数据和未标注训练数据输入深度主动学习模型;基于所述深度主动学习模型的卷积神经网络,确定所述已标注训练数据的第一全连接层值和所述未标注训练数据的第二全连接层值;根据所述第一全连接层值和所述第二全连接层值确定每个所述已标注训练数据与每个所述未标注训练数据之间的欧式距离;根据所述欧式距离从所述未标注训练数据中确定目标训练数据。这样通过已标注训练数据和未标注训练数据的全连接层值,确定已标注训练数据与未标注训练数据之间的欧式距离,根据两者的欧式距离,从未标注训练数据中筛选出的目标训练数据,从而关注到了训练数据的分布,在更新标注训练数据时,可以减少采样偏差。使得筛选出来的训练数据更具代表性,不仅提高的模型的准确率,还提高了训练模型的效率。
附图说明
图1是本发明实施例方案涉及的硬件运行环境的终端结构示意图;
图2为本发明训练数据筛选方法的一实施例的流程示意图;
图3为本发明训练数据筛选方法的另一实施例的流程示意图;
图4为本发明训练数据筛选方法涉及的模型训练流程。
本发明目的的实现、功能特点及优点将结合实施例,参照附图做进一步说明。
具体实施方式
应当理解,此处所描述的具体实施例仅仅用以解释本发明,并不用于限定本发明。
由于在相关技术中,主动学习查询策略对训练数据的筛选一般是通过确定样本数据的不确定性来进行筛选,不确定性的单独采样会导致采样偏差,而忽略了用于模型训练的样本数据的分布,筛选出来的样本数据作为训练数据,不利于提高模型性能,因此,目前筛选出的训练数据的代表性不高,不利于提高模型性能。
为了提高筛选出的训练数据的代表性,本发明实施例提出一种训练数据筛选方法、装置及计算机可读存储介质,其中,所述方法的主要步骤包括:
将已标注训练数据和未标注训练数据输入深度主动学习模型;
基于所述深度主动学习模型的卷积神经网络,确定所述已标注训练数据的第一全连接层值和所述未标注训练数据的第二全连接层值;
根据所述第一全连接层值和所述第二全连接层值确定每个所述已标注训练数据与每个所述未标注训练数据之间的欧式距离;
根据所述欧式距离从所述未标注训练数据中确定目标训练数据。
这样通过已标注训练数据和未标注训练数据的全连接层值,确定已标注训练数据与未标注训练数据之间的欧式距离,根据两者的欧式距离,从未标注训练数据中筛选出的目标训练数据关注到了训练数据的分布,更具有代表性,使得筛选出来的训练数据更具代表性,不仅提高的模型的准确率,还提高了训练模型的效率。
以下结合附图对本发明权利要求要求保护的内容进行详细说明。
如图1所示,图1是本发明实施例方案涉及的硬件运行环境的终端结构示意图。
本发明实施例终端可以是训练数据筛选装置。
如图1所示,该终端可以包括:处理器1001,例如CPU,存储器1003,通信总线1002。其中,通信总线1002用于实现这些组件之间的连接通信。存储器1003可以是高速RAM存储器,也可以是稳定的存储器(non-volatile memory),例如磁盘存储器。存储器1003可选的还可以是独立于前述处理器1001的存储装置。
本领域技术人员可以理解,图1中示出的终端结构并不构成对终端的限定,可以包括比图示更多或更少的部件,或者组合某些部件,或者不同的部件布置。
如图1所示,作为一种计算机存储介质的存储器1003中可以包括操作系统以及训练数据筛选程序。
在图1所示的终端中,处理器1001可以用于调用存储器1003中存储的训练数据筛选程序,并执行以下操作:
进一步地,处理器1001可以调用存储器1003中存储的训练数据筛选程序,还执行以下操作:
将已标注训练数据和未标注训练数据输入深度主动学习模型;
基于所述深度主动学习模型的卷积神经网络,确定所述已标注训练数据的第一全连接层值和所述未标注训练数据的第二全连接层值;
根据所述第一全连接层值和所述第二全连接层值确定每个所述已标注训练数据与每个所述未标注训练数据之间的欧式距离;
根据所述欧式距离从所述未标注训练数据中确定目标训练数据。
进一步地,处理器1001可以调用存储器1003中存储的训练数据筛选程序,还执行以下操作:
确定每个所述未标注训练数据对应的最小欧式距离为目标欧式距离;
将所述目标欧式距离进行降序排列;
确定前预设数量的目标欧式距离对应的未标注训练数据为所述目标训练数据。
进一步地,处理器1001可以调用存储器1003中存储的训练数据筛选程序,还执行以下操作:
根据所述已标注训练数据训练所述深度主动学习模型;
所述根据所述卷积神经网络确定所述已标注训练数据的第一全连接层值和所述未标注训练数据的第二全连接层值的步骤包括:
基于训练后的所述深度主动学习模型的卷积神经网络确定所述第一全连接层值和所述第二全连接层值。
进一步地,处理器1001可以调用存储器1003中存储的训练数据筛选程序,还执行以下操作:
将主动选择模块封装为功能函数;
将所述功能函数连接到所述卷积神经网络的分类模块之后,以组成所述深度主动学习模型。
进一步地,处理器1001可以调用存储器1003中存储的训练数据筛选程序,还执行以下操作:
获取进行标注后的所述目标训练数据;
根据标注后的所述目标训练数据更新所述已标注训练数据;
根据更新后的已标注训练数据训练所述深度主动学习模型。
进一步地,处理器1001可以调用存储器1003中存储的训练数据筛选程序,还执行以下操作:
获取已标注测试数据;
将所述已标注测试数据输入所述深度主动学习模型,确定预测准确的正类测试数据数量和负类测试数据数量;
根据所述正类测试数据数量和所述负类测试数据数量确定所述深度主动学习模型的准确率指标;
当所述准确率指标小于或等于预设阈值,重新执行所述将已标注训练数据和未标注训练数据输入深度主动学习模型的步骤。
进一步地,处理器1001可以调用存储器1003中存储的训练数据筛选程序,还执行以下操作:
获取历史训练轮次中所述深度主动学习模型的历史准确率指标;
根据所述历史准确率指标确定所述预设阈值。
进一步地,处理器1001可以调用存储器1003中存储的训练数据筛选程序,还执行以下操作:
确定剩余未标记训练数据数量;
当所述数量大于或等于预设阈值,重新执行所述将已标注训练数据和未标注训练数据输入深度主动学习模型的步骤;
当所述数量小于预设阈值,终止训练,输出训练数据不足的提示信息。
目前工业场景图像分类存在很大的数据问题。数据问题与深度学习技术出现对立现象,具体表现为:监督式的深度学习模型目前仍需要大量的数据样本,而工业场景获取缺陷数据成本高,难度较大且缺陷类型多,因此这些有缺陷数据标注需要有经验的人进行标注,且标注量大、耗时费力,因此表现出有标注样本获取代价非常昂贵。
深度卷积神经网络在计算机视觉和模式识别的许多研究领域取得了前所未有的成功,如图像分类、目标检测和场景分割。虽然卷积神经网络在许多任务中普遍成功,但它们有一个主要缺点;他们需要大量的标记数据才能学习大量的参数。更重要的是,使用更大的数据几乎总是更好的,因为卷积神经网络的精度通常不会随着数据集的大小增加而饱和。因此,人们一直希望收集越来越多的数据。虽然从算法角度来看这是一种理想的行为(代表性更强通常更好),但标记数据集是一项耗时且昂贵的任务。这些实际考虑提出了一个关键问题:“在给定固定标记预算的情况下,选择要标记的数据点以获得最高精度的最佳方法是什么”主动学习是解决这个问题的常见范例之一。
目前主要的查询策略包括基于不确定性的方法、基于多样性的方法和预期模型变化。此外,许多工作还研究了混合查询策略,考虑了查询样本的不确定性和多样性,并试图在这两种策略之间找到平衡。但是基于不确定性的单独采样通常会导致采样偏差,因此当前选择的样本不能代表未标记数据集的分布。另一方面,仅考虑促进抽样多样性的策略可能会导致标签成本增加,因为可能会因此选择大量信息含量低的样本。因此,采用了不确定性、多样性等策略来挑选“最有价值”的样本数据,而忽略了训练数据的分布,在数据更新的情况下,有可能导致严重的采样偏差,无法反映数据真实的分布情况。采样这样挑选出来的样本数据作为训练数据,不利于提高模型性能。
由此可见,在相关训练数据筛选方法中,存在上述缺陷。本发明实施例为解决上述缺陷,提出一种训练数据筛选方法,旨在达成通过已标注训练数据与未标注训练数据之间的欧式距离从未标注训练数据中筛选出根据代表性的训练数据,提高筛选出的训练数据的代表性,以提高模型训练效率的效果。
以下,通过具体示例性方案对本发明权利要求要求保护的内容,进行解释说明,以便本领域技术人员更好地理解本发明权利要求的保护范围。可以理解的是,以下示例性方案不对本发明的保护范围进行限定,仅用于解释本发明。
示例性地,参照图2,在本发明训练数据筛选方法的一实施例中,所述训练数据筛选方法包括以下步骤:
步骤S10、将已标注训练数据和未标注训练数据输入深度主动学习模型;
在本实施例中,深度主动学习模型是执行工业场景图像分类任务的模型,训练数据可以来自工业场景项目中的工业场景图像。在对深度主动学习模型进行训练时,会对训练数据提前进行标注,将已标注训练数据输入深度主动学习模型中,进行训练。用于训练的数据应为分类难度高,缺陷大的图像数据,这样对模型训练更具有价值。而工业场景中获取缺陷数据成本高,难度较大且缺陷类型多,这些有缺陷数据标注需要有经验的人进行标注,且标注量大、耗时费力,因此表现出有标注训练获取代价非常昂贵。
获取具有多个训练数据的数据集,根据深度学习模型各个分类类型,从数据集中随机出抽取每个分类类型对应的多个训练数据,并通过人工或者高精度模型进行标注,划为初始训练集train,其中,初始训练集train中的训练数据是已标注训练数据,剩下的训练数据放入未标注集unlabel,其中未标注集unlabel中的训练数据是未标注训练数据。将已标注训练数据和未标注训练数据输入深度主动学习模型,需要说明的是,已标注训练数据和未标注训练数据可以同时输入,也可以异步输入。
已标注训练数据可以用于训练深度主动学习模型,深度主动学习模型可以对训练数据进行分类,深度主动学习模型可以是卷积神经网络和主动选择模块结合的模型,主动选择模块包括主动选择策略算法,基于卷积神经网络可以执行训练和分类任务,基于主动选择模块可以从未标注训练数据中选择目标训练数据,对目标训练数据进行标注后,可以再输入深度主动学习模型中,进行下一轮训练。
可选地,将主动选择模块封装为功能函数;将所述功能函数连接到所述卷积神经网络的分类模块之后,以组成所述深度主动学习模型。
卷积神经网络的网络架构采用19层的卷积层、Pooling层、ReLU层、全连接层和Softmax层的Caffe分类网络。Caffe框架主要有五个组件,Blob,Solver,Net,Layer,Proto。Solver负责深度网络的训练,每个Solver中包含一个训练网络对象和一个测试网络对象。每个网络则由若干个Layer构成。每个Layer的输入和输出Feature map表示为Input Blob和Output Blob。Blob是Caffe实际存储数据的结构,是一个不定维的矩阵,在Caffe中一般用来表示一个拉直的四维矩阵,四个维度分别对应Batch Size(N),Feature Map的通道数(C),Feature Map高度(H)和宽度(W)。主动学习模块中包括主动选择策略算法,将主动挑选策略算法分装成抽象层C++功能函数,并将其作为模块追加在全连接层或者Softmax层的Caffe网络分类模块之后,从而构成深度主动学习模型。
步骤S20、基于所述深度主动学习模型的卷积神经网络,确定所述已标注训练数据的第一全连接层值和所述未标注训练数据的第二全连接层值;
在本实施例中,深度主动学习模型包括卷积神经网络,卷积神经网络包括全连接层,全连接层值是训练数据经过卷积神经的全连接层得到的特征值。训练数据输入深度主动学习模型进行运算,运算过程经过卷积神经网络的全连接层后,可以得到训练数据对应的全连接层值。已标注训练数据对应得到第一全连接层值,未标注训练数据对应得到第二全连接层值。
需要说明的是,输入深度主动学习模型的数据无论是用于训练还是筛选,都会经过卷积神经网络的全连接层,得到全连接层值。
可选地,根据所述已标注训练数据训练所述深度主动学习模型;所述根据所述卷积神经网络确定所述已标注训练数据的第一全连接层值和所述未标注训练数据的第二全连接层值的步骤包括:基于训练后的所述深度主动学习模型的卷积神经网络确定所述第一全连接层值和所述第二全连接层值。
将已标注训练数据输入深度主动学习模型,已标注训练数据可以用于对深度主动学习模型的训练,主要是用于训练深度主动学习模型中的卷积神经网络,进行训练后,模型精度提高,对数据预测的准确性也会提高。将已标注训练数据和未标注训练数据可以一起输入深度主动学习模型,已标注训练数据用于训练深度主动学习模型。在训练过程中,未经训练的模型可以得出已标注训练数据对应的第一全连接层值,同时,基于未训练的深度主动学习模型或者训练完成后的深度主动学习模型,预测出未标注训练数据对应的第二全连接层值。也可以先将已标注训练数据输入深度主动学习模型进行训练,训练完成后,再输入一起输入已标注训练数据和未标注训练数据,经过训练后的深度主动学习模型得出全连接层值。为了提高已标注训练数据和未标注训练数据对应的全连接层值的一致性,得出第一全连接层值和所述第二全连接层值的模型是一致的,可以均为训练前的深度主动学习模型,也可以均为训练后的深度主动学习模型。基于训练后的深度主动学习模型得出的全连接层值更为准确,筛选出来的目标训练数据准确性更高。
可选地,在步骤S20之前包括,基于深度主动学习模型的卷积神经网络,确定每个未标注训练数据的Softmax值,Softmax值是训练数据经过卷积神经网络中的Softmax层得到特征值,Softmax值是深度主动学习模型的卷积神经网络对训练数据的预测结果。可以理解的是,将训练数据输入卷积神经网络后,会经过卷积神经网络的各个层级,各个层级得到对应的特征值,包括全连接层值和Softmax值。
根据未标注训练数据的Softmax值确定每个未标注训练数据的不确定性值,不确定性值越高,表示未标注训练数据信息量更高,分类难度越大。根据Softmax值计算出每个未标注训练数据的不确定性的计算公式如下:
可以根据entropy值按照未标注训练数据总量的预设比例,例如60%进行筛选,从而选出未标注训练数据中不确定性高的样本。具体地,根据不确定性值对未标注训练数据进行降序排列,剔除预设比例的排序靠后的未标注训练数据,剩余未标注训练数据不确定性值较大;或者剔除不确定性值小于预设阈值的未标注训练数据,剩余未标注训练数据是不确定性值大于或等于预设阈值的未标注训练数据。剩余标注训练数据不确定性值较大,因此信息量更高,分类难度越大,相较于剔除掉的未标注训练数据,对模型训练更具价值,更能提高模型精度。基于剩余未标注训练数据,执行后续步骤,包括基于深度主动学习模型的卷积神经网络,确定已标注训练数据的第一全连接层值和剩余未标注训练数据的第二全连接层值。根据第一全连接层值和第二全连接层值确定每个已标注训练数据与每个剩余未标注训练数据之间的欧式距离,也即根据第一全连接层值和第二全连接层值确定每个已标注训练数据与每个不确定性高的未标注训练数据之间的欧式距离。根据欧式距离从剩余未标注训练数据中确定目标训练数据。这样先基于未标注训练数据本身的不确定性进行筛选,在此基础上,再根据每个已标注训练数据与每个未标注训练数据之间的欧式距离对剩余不确定性校高的未标注训练数据进行筛选。在不确定性的基础上关注到了训练数据的分布,即在考虑到训练数据本身的训练价值,还考虑到即将参与训练的训练数据分布情况,避免筛选出冗余的目标训练数据。基于这样更具价值、更具代表性的目标训练数据对模型进行训练,更利于提升模型性能。
步骤S30、根据所述第一全连接层值和所述第二全连接层值确定每个所述已标注训练数据与每个所述未标注训练数据之间的欧式距离;
在本实施例中,全连接层值是卷积神经网络中全连接层的特征值。全连接层值是深度主动学习模型的卷积神经网络对训练数据的预测。全连接层得到训练数据的全连接层值后,卷积神经网络中位于全连接层之后的Caffe分类模块根据全连接层值得出训练数据的分类结果。因此,主动选择模块可以获取已标注训练数据对应的第一全连接层值和未标注训练数据对应的第二全连接层值,根据第一全连接层值和第二全连接层值计算出每个已标注训练数据与每个未标注训练数据之间的欧式距离,计算公式如下:
步骤S40、根据所述欧式距离从所述未标注训练数据中确定目标训练数据。
在本实施例中,欧式距离表示已标注训练数据和未标注训练数据之间的差异。根据欧式距离中从未标注训练数据中选择与已标注训练数据欧式距离最大的训练数据作为目标训练数据。目标训练数据的数量可以根据训练需求预先设置。目标训练数据进行标注后,加入已标注训练数据,继续用于下一轮模型训练,提高模型精度。通过欧式距离挑选出来的目标训练数据,考虑了训练数据的分布情况。选择较大欧式距离对应的未标注训练数据作为目标训练数据,可以与已标注训练数据保持差异,防止训练数据的冗余,这样的目标训练数据更具有代表性,更大限度的提高模型训练的效率。
可选地,确定每个所述每个未标注训练数据对应的最小欧式距离为目标欧式距离;将所述目标欧式距离进行降序排列;确定前预设数量的目标欧式距离对应的未标注训练数据为所述目标训练数据。
根据欧式距离从未标注训练数据中确定更具有代表性的目标训练数据的问题,相当于是一个集合中,选择与另外一个集合差异最大的目标集合。确定每个未标注训练数据对应的最小欧式距离为目标欧式距离,将目标欧式距离进行降序排列,确定前预设数量的目标欧式距离对应的未标注训练数据为目标训练数据,若预设数量是n个,则选择进行降序排列后,使得目标训练数据与已标注训练数据之间的最小距离达到最大。增加目标训练数据与已标注训练数据的差异性。
具体地,未标注训练数据记为unlabeled data(),已标注训练数据记为
initial labeled data()。目标训练数据需要从未标注训练数据中选择Budget个训练
数据,可以设置为全部数据的5%。主动选择模块将这个过程视为寻找一个当前最佳集合的
问题,顺序从unlabeled data中选出Budget个训练数据加入集合,新加入的目标训练
数据需要满足与集合的距离最大。一个未标注训练数据与集合的欧式距离为:该未标
注训练数据与集合内各个已标注训练数据的欧式距离的最小值,具体计算公式如下:
其中,,表示集合,中的训练数据,n为全部数据,通过深度主动学习
模型预测的全连接层的特征值,n为全部数据量。在工业场景图像分类场景中,其为n为数
据,n为图像的类别,表示L2-norm距离(欧式距离)。
可选地,若是已标注训练数据对深度主动学习模型的训练是初始训练轮次,即深度主动学习模型还未被训练过,则深度主动学习模型精度不高,对训练数据的全连接层值的预测准确性较低,使得筛选出来的目标训练数据准确性不高。为了节省计算成本,提高训练效率,主动选择模块在初始训练轮次中筛选目标训练数据可以是随机选择策略。
在本实施例公开的技术方案中,先将已标注训练数据和未标注训练数据输入深度主动学习模型;基于所述深度主动学习模型的卷积神经网络,确定所述已标注训练数据的第一全连接层值和所述未标注训练数据的第二全连接层值;根据所述第一全连接层值和所述第二全连接层值确定每个所述已标注训练数据与每个所述未标注训练数据之间的欧式距离;根据所述欧式距离从所述未标注训练数据中确定目标训练数据。这样通过已标注训练数据和未标注训练数据的全连接层值,确定已标注训练数据与未标注训练数据之间的欧式距离,根据两者的欧式距离,从未标注训练数据中筛选出的目标训练数据,关注到了训练数据的分布,在更新标注训练数据时,可以减少采样偏差。目标训练数据进行标注后需要加入已标注训练数据后,用于下一轮模型训练,通过欧式距离筛选的方式,可以避免目标训练数据与已标注训练数据之间的冗余。使得筛选出来的训练数据更具代表性,不仅提高了模型准确率,还提高了训练模型的效率,降低标注成本。
可选地,参照图3,基于上述任一实施例,在本发明训练数据筛选方法的另一实施例中,所述训练数据筛选方法还包括:
步骤S50、获取进行标注后的所述目标训练数据;
在本实施例中,目标训练数据是从未标注训练数据中挑选出来的训练数据,还未进行标注,选择额外的标注方法对目标训练数据进行标注,包括采用人工标注的方式或者采用精度更高的模型进行标注。可以采用Artificial Intelligent Defect Inspection软件或lableIme进行标注,生成对应的已标注训练数据。
步骤S60、根据标注后的所述目标训练数据更新所述已标注训练数据;
在本实施例中,已标注训练数据包括多个已标注的训练数据,本质是数据集,将进行标注后的目标训练数据加入该数据集中,更新该数据集的数据分布,更新已标注训练数据,还可以将标注后的目标数据全部作为新的已标注训练数据,删除之前的已标注训练数据。
步骤S70、根据更新后的已标注训练数据训练所述深度主动学习模型。
在本实施例中,参照图4,将更新后的已标注训练数据输入深度主动学习模型,根据更新后的已标注训练数据和caffe分类模块进行训练,从而得到精度更高的深度主动学习模型,更适于执行工业场景图像分类任务,该步骤也可以与步骤S10将已标注训练数据和未标注训练数据输入深度主动学习模型进行衔接,即将更新后的已标注训练数据和剩余未标注训练数据输入深度主动学习模型,挑选新一轮的目标训练数据进行标注,从而完成深度主动学习模型迭代训练。在整个迭代训练的过程中,由于目标训练数据的筛选考虑到了与已标注训练数据的数据分布,提高了筛选准确性,可以用更少轮次、更少训练数据使得深度主动学习模型达到训练目标,提高了训练数据,降低了标注成本。
在迭代训练前,需要对训练参数初始化,基础学习率base_lr为0.01,迭代的过程中,通过lr_policy: "step"对基础学习率进行调整,梯度更新的权重momentum设为0.9,优化算法为Adam。迭代训练需要有终止训练的条件,整个迭代训练的过程的停止过程可以由人工判断是否终止,也可以由训练轮次、训练时间等参数作为训练终止条件,还可以将本训练轮次的模型精度与历史训练轮次的模型精度进行对比,根据对比结果确定是否终止训练,若相较于历史训练轮次的模型精度有提升,则继续训练,若相较于历史训练轮次的模型精度没有提升,则终止训练。
可选地,获取已标注测试数据;将所述已标注测试数据输入所述深度主动学习模型,确定预测准确的正类测试数据数量和负类测试数据数量;根据所述正类测试数据数量和所述负类测试数据数量确定所述深度主动学习模型的准确率指标;当所述准确率指标小于或等于预设阈值,重新执行所述将已标注训练数据和未标注训练数据输入深度主动学习模型的步骤。
在获取到多个未标注的工业场景图像时,按照预设比例将图像数据分为训练数据和测试数据,在训练数据中,每类图像随机挑选3张,划为初始训练集train,剩下的未标注训练数据放入未标注集unlabel。初始训练集train和测试数据都需要进行额外的标注。根据将已标注测试数据输入深度主动学习模型,确定预测准确的正类测试数据数量和负类测试数据数量,根据正类测试数据数量和负类测试数据数量确定准确率指标,公式如下:
其中,True Positive(真正,TP):将正类预测为正类的数量,True Negative(真负,TN):将负类预测为负类的数量,Accuracy为准确率指标,Accuracy的预设阀值可以设置设为0.99。当准确率指标小于或等于预设阈值,重新执行步骤S10,继续重复训练,若大于预设阈值,则终止训练,输出最终的深度主动学习模型。
进一步地,获取历史训练轮次中所述深度主动学习模型的历史准确率指标;根据所述历史准确率指标确定所述预设阈值。
还可以获取历史训练轮次中深度主动学习模型的历史准确率指标,历史轮次指在本次训练轮次前的训练轮次,在每一次训练后,都可以基于测试数据计算准确率指标,根据历史准确率指标建立准确率指标变化曲线,根据准确率指标变化曲线预测本次训练轮次准确率指标,将预测的准确率指标作为预设阈值。或者根据历史准确率指标确定本次训练轮次的最低准确率指标,将最低准确率指标作为预设阈值。
可选地,确定剩余未标记训练数据数量;当所述数量大于或等于预设阈值,重新执行所述将已标注训练数据和未标注训练数据输入深度主动学习模型的步骤。
若终止训练的指标设置成固定的标注预算量,则在主动学习模块挑选的目标训练数据达到预算量时停止迭代更新。例如数据集全部数据为1000张,标注预算量为500张,在进行多轮自动挑选之后已标注数据达到500张数据时停止迭代更新。同样地,也可以判断剩余未标记训练数据数量是否达到预设阈值,若数量小于预设阈值,剩余的未标注训练数据数量太小,无法完成训练任务或者提高模型精度不明显,停止训练,还可以输出训练数据不足的提示信息,向用户确认是否输出最终的深度主动学习模型,若数量大于或等于预设阈值,则重新执行步骤S10。
在本实施例公开的技术方案中,获取进行标注后的所述目标训练数据;根据标注后的所述目标训练数据更新所述已标注训练数据;根据更新后的已标注训练数据训练所述深度主动学习模型。通过目标训练数据更新已标注训练数据,训练深度主动学习模型。由于目标训练数据的筛选考虑到了目标训练数据与已标注训练数据的数据分布,目标训练数据的筛选准确性更高,提高了训练出来的模型精度。
此外,本发明实施例还提出一种训练数据筛选装置,所述训练数据筛选装置包括存储器、处理器及存储在所述存储器上并可在所述处理器上运行的训练数据筛选程序,所述训练数据筛选程序被所述处理器执行时实现如上各个实施例所述的训练数据筛选方法的步骤。
此外,本发明实施例还提出一种计算机可读存储介质,所述计算机可读存储介质上存储有训练数据筛选程序,所述训练数据筛选程序被处理器执行时实现如上各个实施例所述的训练数据筛选方法的步骤。
需要说明的是,在本文中,术语“包括”、“包含”或者其任何其他变体意在涵盖非排他性的包含,从而使得包括一系列要素的过程、方法、物品或者系统不仅包括那些要素,而且还包括没有明确列出的其他要素,或者是还包括为这种过程、方法、物品或者系统所固有的要素。在没有更多限制的情况下,由语句“包括一个……”限定的要素,并不排除在包括该要素的过程、方法、物品或者系统中还存在另外的相同要素。
上述本发明实施例序号仅仅为了描述,不代表实施例的优劣。
通过以上的实施方式的描述,本领域的技术人员可以清楚地了解到上述实施例方法可借助软件加必需的通用硬件平台的方式来实现,当然也可以通过硬件,但很多情况下前者是更佳的实施方式。基于这样的理解,本发明的技术方案本质上或者说对现有技术做出贡献的部分可以以软件产品的形式体现出来,该计算机软件产品存储在如上所述的一个存储介质(如ROM/RAM、磁碟、光盘)中,包括若干指令用以使得训练数据筛选装置执行本发明各个实施例所述的方法。
以上仅为本发明的优选实施例,并非因此限制本发明的专利范围,凡是利用本发明说明书及附图内容所作的等效结构或等效流程变换,或直接或间接运用在其他相关的技术领域,均同理包括在本发明的专利保护范围内。
Claims (10)
1.一种训练数据筛选方法,其特征在于,所述方法包括:
将已标注训练数据和未标注训练数据输入深度主动学习模型;
基于所述深度主动学习模型的卷积神经网络,确定所述已标注训练数据的第一全连接层值和所述未标注训练数据的第二全连接层值;
根据所述第一全连接层值和所述第二全连接层值确定每个所述已标注训练数据与每个所述未标注训练数据之间的欧式距离;
根据所述欧式距离从所述未标注训练数据中确定目标训练数据。
2.如权利要求1所述的训练数据筛选方法,其特征在于,所述根据所述欧式距离从所述未标注训练数据中确定目标训练数据的步骤包括:
确定每个所述未标注训练数据对应的最小欧式距离为目标欧式距离;
将所述目标欧式距离进行降序排列;
确定前预设数量的目标欧式距离对应的未标注训练数据为所述目标训练数据。
3.如权利要求1所述的训练数据筛选方法,其特征在于,所述将已标注训练数据和未标注训练数据输入深度主动学习模型的步骤之后,还包括:
根据所述已标注训练数据训练所述深度主动学习模型;
所述根据所述卷积神经网络确定所述已标注训练数据的第一全连接层值和所述未标注训练数据的第二全连接层值的步骤包括:
基于训练后的所述深度主动学习模型的卷积神经网络确定所述第一全连接层值和所述第二全连接层值。
4.如权利要求1所述的训练数据筛选方法,其特征在于,所述将已标注训练数据和未标注训练数据输入深度主动学习模型的步骤之前,还包括:
将主动选择模块封装为功能函数;
将所述功能函数连接到所述卷积神经网络的分类模块之后,以组成所述深度主动学习模型。
5.如权利要求1所述的训练数据筛选方法,其特征在于,所述根据所述欧式距离从所述未标注训练数据中确定目标训练数据的步骤之后,还包括:
获取进行标注后的所述目标训练数据;
根据标注后的所述目标训练数据更新所述已标注训练数据;
根据更新后的已标注训练数据训练所述深度主动学习模型。
6.如权利要求5所述的训练数据筛选方法,其特征在于,所述根据更新后的已标注训练数据训练所述深度主动学习模型的步骤之后,还包括:
获取已标注测试数据;
将所述已标注测试数据输入所述深度主动学习模型,确定预测准确的正类测试数据数量和负类测试数据数量;
根据所述正类测试数据数量和所述负类测试数据数量确定所述深度主动学习模型的准确率指标;
当所述准确率指标小于或等于预设阈值,重新执行所述将已标注训练数据和未标注训练数据输入深度主动学习模型的步骤。
7.如权利要求6所述的训练数据筛选方法,其特征在于,所述根据所述正类测试数据数量和所述负类测试数据数量确定所述深度主动学习模型的准确率指标的步骤之后,还包括:
获取历史训练轮次中所述深度主动学习模型的历史准确率指标;
根据所述历史准确率指标确定所述预设阈值。
8.如权利要求5所述的训练数据筛选方法,其特征在于,所述根据更新后的已标注训练数据训练所述深度主动学习模型的步骤之后,还包括:
确定剩余未标记训练数据数量;
当所述数量大于或等于预设阈值,重新执行所述将已标注训练数据和未标注训练数据输入深度主动学习模型的步骤;
当所述数量小于预设阈值,终止训练,输出训练数据不足的提示信息。
9.一种训练数据筛选装置,其特征在于,所述训练数据筛选装置包括:存储器、处理器及存储在所述存储器上并可在所述处理器上运行的训练数据筛选程序,所述训练数据筛选程序被所述处理器执行时实现如权利要求1至8中任一项所述的训练数据筛选方法的步骤。
10.一种计算机可读存储介质,其特征在于,所述计算机可读存储介质上存储有训练数据筛选程序,所述训练数据筛选程序被处理器执行时实现如权利要求1至8中任一项所述的训练数据筛选方法的步骤。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202211409768.5A CN115482441B (zh) | 2022-11-11 | 2022-11-11 | 训练数据筛选方法、装置及计算机可读存储介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202211409768.5A CN115482441B (zh) | 2022-11-11 | 2022-11-11 | 训练数据筛选方法、装置及计算机可读存储介质 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN115482441A true CN115482441A (zh) | 2022-12-16 |
CN115482441B CN115482441B (zh) | 2023-06-23 |
Family
ID=84396428
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202211409768.5A Active CN115482441B (zh) | 2022-11-11 | 2022-11-11 | 训练数据筛选方法、装置及计算机可读存储介质 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN115482441B (zh) |
Citations (6)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN109920501A (zh) * | 2019-01-24 | 2019-06-21 | 西安交通大学 | 基于卷积神经网络和主动学习的电子病历分类方法及系统 |
CN109961009A (zh) * | 2019-02-15 | 2019-07-02 | 平安科技(深圳)有限公司 | 基于深度学习的行人检测方法、系统、装置及存储介质 |
CN110659740A (zh) * | 2018-06-28 | 2020-01-07 | 国际商业机器公司 | 基于边缘节点处的数据输入对机器学习模型排序和更新 |
CN111461232A (zh) * | 2020-04-02 | 2020-07-28 | 大连海事大学 | 一种基于多策略批量式主动学习的核磁共振图像分类方法 |
CN112508092A (zh) * | 2020-12-03 | 2021-03-16 | 上海云从企业发展有限公司 | 一种样本筛选方法、系统、设备及介质 |
CN114154570A (zh) * | 2021-11-30 | 2022-03-08 | 深圳壹账通智能科技有限公司 | 一种样本筛选方法、系统及神经网络模型训练方法 |
-
2022
- 2022-11-11 CN CN202211409768.5A patent/CN115482441B/zh active Active
Patent Citations (6)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN110659740A (zh) * | 2018-06-28 | 2020-01-07 | 国际商业机器公司 | 基于边缘节点处的数据输入对机器学习模型排序和更新 |
CN109920501A (zh) * | 2019-01-24 | 2019-06-21 | 西安交通大学 | 基于卷积神经网络和主动学习的电子病历分类方法及系统 |
CN109961009A (zh) * | 2019-02-15 | 2019-07-02 | 平安科技(深圳)有限公司 | 基于深度学习的行人检测方法、系统、装置及存储介质 |
CN111461232A (zh) * | 2020-04-02 | 2020-07-28 | 大连海事大学 | 一种基于多策略批量式主动学习的核磁共振图像分类方法 |
CN112508092A (zh) * | 2020-12-03 | 2021-03-16 | 上海云从企业发展有限公司 | 一种样本筛选方法、系统、设备及介质 |
CN114154570A (zh) * | 2021-11-30 | 2022-03-08 | 深圳壹账通智能科技有限公司 | 一种样本筛选方法、系统及神经网络模型训练方法 |
Also Published As
Publication number | Publication date |
---|---|
CN115482441B (zh) | 2023-06-23 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN108564097B (zh) | 一种基于深度卷积神经网络的多尺度目标检测方法 | |
CN112668630B (zh) | 一种基于模型剪枝的轻量化图像分类方法、系统及设备 | |
CN112488025B (zh) | 基于多模态特征融合的双时相遥感影像语义变化检测方法 | |
CN112528845B (zh) | 一种基于深度学习的物理电路图识别方法及其应用 | |
CN111428558A (zh) | 一种基于改进YOLOv3方法的车辆检测方法 | |
CN111369526B (zh) | 基于半监督深度学习的多类型旧桥裂痕识别方法 | |
CN111489370A (zh) | 基于深度学习的遥感图像的分割方法 | |
CN110599459A (zh) | 基于深度学习的地下管网风险评估云系统 | |
CN111178196B (zh) | 一种细胞分类的方法、装置及设备 | |
CN116385374A (zh) | 基于卷积神经网络的细胞计数方法 | |
CN114201572A (zh) | 基于图神经网络的兴趣点分类方法和装置 | |
CN104598898A (zh) | 一种基于多任务拓扑学习的航拍图像快速识别系统及其快速识别方法 | |
CN115292538A (zh) | 一种基于深度学习的地图线要素提取方法 | |
CN113052217A (zh) | 预测结果标识及其模型训练方法、装置及计算机存储介质 | |
CN110569871B (zh) | 一种基于深度卷积神经网络的鞍部点识别方法 | |
CN115984632A (zh) | 一种高光谱塑料垃圾材质快速分类方法、装置及存储介质 | |
CN115482441A (zh) | 训练数据筛选方法、装置及计算机可读存储介质 | |
CN116524296A (zh) | 设备缺陷检测模型的训练方法、装置和设备缺陷检测方法 | |
CN115880477A (zh) | 一种基于深度卷积神经网络的苹果检测定位方法与系统 | |
CN115457366A (zh) | 基于图卷积神经网络的中草药多标签识别模型 | |
CN112465821A (zh) | 一种基于边界关键点感知的多尺度害虫图像检测方法 | |
CN113192108A (zh) | 一种针对视觉跟踪模型的人在回路训练方法及相关装置 | |
CN113627537B (zh) | 一种图像识别方法、装置、存储介质及设备 | |
CN113313079B (zh) | 一种车辆属性识别模型的训练方法、系统及相关设备 | |
CN115272814B (zh) | 一种远距离空间自适应多尺度的小目标检测方法 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |