CN114998691B - 半监督船舶分类模型训练方法及装置 - Google Patents
半监督船舶分类模型训练方法及装置 Download PDFInfo
- Publication number
- CN114998691B CN114998691B CN202210721409.7A CN202210721409A CN114998691B CN 114998691 B CN114998691 B CN 114998691B CN 202210721409 A CN202210721409 A CN 202210721409A CN 114998691 B CN114998691 B CN 114998691B
- Authority
- CN
- China
- Prior art keywords
- training
- class
- data
- ship classification
- ship
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/77—Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
- G06V10/774—Generating sets of training patterns; Bootstrap methods, e.g. bagging or boosting
- G06V10/7753—Incorporation of unlabelled data, e.g. multiple instance learning [MIL]
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/084—Backpropagation, e.g. using gradient descent
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/764—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/82—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02A—TECHNOLOGIES FOR ADAPTATION TO CLIMATE CHANGE
- Y02A10/00—TECHNOLOGIES FOR ADAPTATION TO CLIMATE CHANGE at coastal zones; at river basins
- Y02A10/40—Controlling or monitoring, e.g. of flood or hurricane; Forecasting, e.g. risk assessment or mapping
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Evolutionary Computation (AREA)
- Physics & Mathematics (AREA)
- Artificial Intelligence (AREA)
- Databases & Information Systems (AREA)
- General Physics & Mathematics (AREA)
- General Health & Medical Sciences (AREA)
- Health & Medical Sciences (AREA)
- Computing Systems (AREA)
- Software Systems (AREA)
- Multimedia (AREA)
- Medical Informatics (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Biomedical Technology (AREA)
- Mathematical Physics (AREA)
- General Engineering & Computer Science (AREA)
- Molecular Biology (AREA)
- Data Mining & Analysis (AREA)
- Computational Linguistics (AREA)
- Biophysics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Information Retrieval, Db Structures And Fs Structures Therefor (AREA)
Abstract
本发明公开了一种半监督船舶分类模型训练方法及装置,其中,该方法包括:采用标定数据集分别对初始总类别船舶分类模型及多个初始单类别船舶分类模型进行模型训练,得到第一总类别船舶分类模型及多个第一单类别船舶分类模型;使用第一总类别船舶分类模型及多个第一单类别船舶分类模型分别赋值网络,并进行模型训练;使用训练得到的模型对最新的伪标签训练集进行预测,并从中删除奇异样本;将删除奇异样本的伪标签训练集作为训练集对总类别船舶分类模型进行训练,得到最终船舶分类模型。本发明降低了错误样本对模型的影响,提高模型的准确率。
Description
技术领域
本发明涉及船舶类型分类领域,尤其涉及一种半监督船舶分类模型训练方法及装置。
背景技术
随着水上交通管理任务不断增加和人工智能的快速发展,人工智能在水上管理业务的重要性迅速提升,其中船舶分类是实现水上交通管理自动化重要因素之一。但由于我国水路交通路线复杂且环境差别偏大,同一类型船舶在外观上区别较大且在部分水上路线数据集无法采集,由于这些原因,导致收集到的数据集具有局限性,无法覆盖我国所有水上交通路线背景和船舶类型信息。基于上述问题,目前很多算法工程师研发了半监督模型训练方法,但目前半监督船舶分类模型训练算法存下以下问题:
1)伪标签数据集由目前已存在船舶分类模型获得,不受其他模型监督;
2)没有挖掘船舶分类模型训练过程中的推理结果对伪标签训练集的监督。发明内容
为解决上述问题,本发明提供一种半监督船舶分类模型训练方法及装置,通过在模型训练中,对伪标签数据集中数据的置信度及类别的变化进行分析;另外,通过将总类别船舶分类模型差分成多个单类别船舶分类模型,并采用多个单类别船舶分类模型的推理结果对总类别船舶分类模型的分析结果进行监督,提高最终生成的船舶分类模型的准确率,以解决上述现有技术中的问题。
为达到上述目的,S1、获取初始总类别船舶分类模型、多个初始单类别船舶分类模型、标定数据集及未标定数据集;其中,所述总类别船舶分类模型用于分类所有船舶类型,单类别船舶分类模型用于分类一种船舶类型;S2、用标定数据集分别对所述初始总类别船舶分类模型及多个初始单类别船舶分类模型进行模型训练,得到第一总类别船舶分类模型及多个第一单类别船舶分类模型;S3、用所述第一总类别船舶分类模型预测所述未标定数据集,得到第一总类别数据集;其中,所述第一总类别数据集包括每个未标定数据的第一预测结果;分别用多个第一单类别船舶分类模型预测所述未标定数据集,得到第一单类别数据集;其中,所述第一单类别数据集包括每个未标定数据的第二预测结果;将所述第一总类别数据集中的第一预测结果与所述第一单类别数据集中的第二预测结果融合,得到对应有第三预测结果的未标定数据的集合,记为第一伪标签训练集;S4、提取所述多个第一单类别船舶分类模型的参数,并分别对相应的单类别船舶分类网络进行赋值,将所述第一伪标签训练集与标定数据集输入至相应的单类别船舶分类网络进行模型训练,得到多个第二单类别船舶分类模型;分别用所述多个第二单类别船舶分类模型对所述未标定数据集进行预测,得到第二单类别数据集;其中,所述第二单类别数据集包括每个未标定数据集的第四预测结果;S5、将所述第一总类别数据集中的第一预测结果与所述第二单类别数据集中的第四预测结果融合,得到对应有第五预测结果的未标定数据的集合,记为第二伪标签训练集;提取所述第一总类别船舶分类模型的参数,并对总类别船舶分类网络进行赋值,将所述第二伪标签训练集与标定数据集输入至总类别船舶分类网络进行模型训练,得到第二总类别船舶分类模型;用所述第二总类别船舶分类模型对所述未标定数据集进行预测,得到第二总类别数据集;其中,所述第二总类别数据集包括每个未标定数据集的第六预测结果;S6、将所述第二总类别数据集中的第六预测结果与第二单类别数据集中的第四预测结果进行融合,得到对应有第七预测结果的未标定数据的集合,记为第三伪标签训练集;用所述第三伪标签训练集及标定数据集作为目标训练集分别对所述多个第二单类别船舶分类模型及第二总类别船舶分类模型进行训练,并在训练过程中删除奇异伪标签数据,得到第四伪标签训练集、多个第三单类别船舶分类模型及第三总类别船舶分类模型;S7、用所述多个第三单类别船舶分类模型对所述第四伪标签训练集进行预测,得到第三单类别数据集;其中,所述第三单类别数据集包括每个未标定数据的第八预测结果;用所述第三总类别船舶分类模型对所述第四伪标签训练集进行预测,得到第三总类别数据集;其中,所述第三总类别数据集包括每个未标定数据的第九预测结果;将所述第三单类别数据集中的第八预测结果与所述第三总类别数据集中的第九预测结果融合,得到对应有第十预测结果的未标定数据的集合,记为第五伪标签训练集;采用所述第五伪标签训练集对所述第三总类别船舶分类模型进行训练,得到最终船舶分类模型。
进一步可选的,所述将所述第一总类别数据集中的第一预测结果与所述第一单类别数据集中的第二预测结果融合,得到对应有第三预测结果的未标定数据的集合,记为第一伪标签训练集,包括:S301、识别所述第一预测结果中每个未标定数据对应的第一类别及第一置信度;S302、识别所述第二预测结果中每个未标定数据对应的第二类别及第二置信度;S303、根据每个未标定数据对应的第一类别及第二类别,计算类别赋值权重,根据所述类别赋值权重、第一置信度及第二置信度,计算得到第三置信度。
进一步可选的,所述根据每个未标定数据对应的第一类别及第二类别,计算类别赋值权重,通过以下公式计算:
其中,mkij为第k个未标定数据的类别赋值权重,为第k个未标定数据第一类别,为第k个未标定数据的第二类别;
所述根据所述类别赋值权重、第一置信度及第二置信度,计算得到第三置信度,通过以下公式计算:
其中,pk为第k个未标定数据的第三置信度,pki为第k个未标定数据的第一置信度,pkj为第k个未标定数据的第二置信度。
进一步可选的,所述将所述第一伪标签训练集与标定数据集输入至相应的单类别船舶分类网络进行模型训练,得到多个第二单类别船舶分类模型,其中,任一单类别船舶分类网络进行模型训练包括:S401、按第一比例在所述第一伪标签训练集中抽取第一伪标签训练子集,按第二比例在所述标定数据集中抽取第一标定训练子集,将所述第一伪标签训练子集和第一标定训练子集作为该次迭代的第一训练样本集;S402、采用所述第一训练样本集对所述单类别船舶分类网络进行训练,并计算该次迭代的损失值;S403、根据该次迭代的损失值对所述单类别船舶分类网络进行反向求偏导,根据偏导结果对所述单类别船舶分类网络的参数进行修正;S404、重复S401-S403,直至所述单类别船舶分类网络收敛,得到对应的第二单类别船舶分类模型。
进一步可选的,所述将所述第二伪标签训练集与标定数据集输入至总类别船舶分类网络进行模型训练,得到第二总类别船舶分类模型,包括:S501、按第一比例在所述第二伪标签训练集中抽取第二伪标签训练子集,按第二比例在所述标定数据集中抽取第二标定训练子集,将所述第二伪标签训练子集和第二标定训练子集作为该次迭代的第二训练样本集;S502、采用所述第二训练样本集对所述总类别船舶分类网络进行训练,得到该次迭代的损失值;S503、根据该次迭代的损失值对所述总类别船舶分类网络进行反向求偏导,根据偏导结果对所述总类别船舶分类网络的参数进行修正;S504、重复S501-503,直至所述总类别船舶分类网络收敛,得到所述第二总类别船舶分类模型。
进一步可选的,所述用所述第三伪标签训练集及标定数据集作为目标训练集分别对所述多个第二单类别船舶分类模型及第二总类别船舶分类模型进行训练,并在训练过程中删除奇异伪标签数据,得到第四伪标签训练集,包括:S601、采用所述目标训练集分别对所述多个第二单类别船舶分类模型进行训练,依次记录每轮训练生成的中间单类别船舶分类模型对未标定数据集进行预测的预测结果,记为第一变化清单,选取所述第一变化清单中,连续预设轮数类别均不一致,和/或置信度连续变小的数据,将其标记为奇异伪标签数据;S602、采用所述目标训练集对所述第二总类别船舶分类模型进行训练,依次记录每轮训练生成的中间总类别船舶分类模型对未标定数据集预测的预测结果,记为第二变化清单,选取所述第二变化清单中,连续预设轮数类别均不一致,和/或置信度连续变小的数据,将其标记为奇异伪标签数据;S603、在所述第三伪标签训练集中删除所述奇异伪标签数据,得到第四伪标签训练集。
进一步可选的,采用以下公式计算该次迭代的损失值:
其中,loss为每次迭代的损失值,为每次迭代对应的伪标签训练子集中的数据量,Nlabel为每次迭代对应的标定训练子集的数据量,Pi为第i个未标定数据对应的第三置信度,lossi为第i个未标定数据的预测类别与伪标签的损失值,lossj表示第j个标定数据的预测类别与原标定类别的损失值。
另一方面,本发明还提供了一种半监督船舶分类模型训练装置,包括:获取模块,用于获取初始总类别船舶分类模型、多个初始单类别船舶分类模型、标定数据集及未标定数据集;其中,所述总类别船舶分类模型用于分类所有船舶类型,单类别船舶分类模型用于分类一种船舶类型;训练模块,用于用标定数据集分别对所述初始总类别船舶分类模型及多个初始单类别船舶分类模型进行模型训练,得到第一总类别船舶分类模型及多个第一单类别船舶分类模型;融合模块,用于用所述第一总类别船舶分类模型预测所述未标定数据集,得到第一总类别数据集;其中,所述第一总类别数据集包括每个未标定数据的第一预测结果;分别用多个第一单类别船舶分类模型预测所述未标定数据集,得到第一单类别数据集;其中,所述第一单类别数据集包括每个未标定数据的第二预测结果;将所述第一总类别数据集中的第一预测结果与所述第一单类别数据集中的第二预测结果融合,得到对应有第三预测结果的未标定数据的集合,记为第一伪标签训练集;第一预测模块,用于提取所述多个第一单类别船舶分类模型的参数,并分别对相应的单类别船舶分类网络进行赋值,将所述第一伪标签训练集与标定数据集输入至相应的单类别船舶分类网络进行模型训练,得到多个第二单类别船舶分类模型;分别用所述多个第二单类别船舶分类模型对所述未标定数据集进行预测,得到第二单类别数据集;其中,所述第二单类别数据集包括每个未标定数据集的第四预测结果;第二预测模块,将所述第二总类别数据集中的第六预测结果与第二单类别数据集中的第四预测结果进行融合,得到对应有第七预测结果的未标定数据的集合,记为第三伪标签训练集;用所述第三伪标签训练集及标定数据集作为目标训练集分别对所述多个第二单类别船舶分类模型及第二总类别船舶分类模型进行训练,并在训练过程中删除奇异伪标签数据,得到第四伪标签训练集、多个第三单类别船舶分类模型及第三总类别船舶分类模型;数据筛选模块,用于将所述第二总类别数据集中的第六预测结果与第二单类别数据集中的第四预测结果进行融合,得到对应有第七预测结果的未标定数据的集合,记为第三伪标签训练集;用所述第三伪标签训练集及标定数据集作为目标训练集分别对所述多个第二单类别船舶分类模型及第二总类别船舶分类模型进行训练,并在训练过程中删除奇异伪标签数据,得到第四伪标签训练集、多个第三单类别船舶分类模型及第三总类别船舶分类模型;最终船舶分类模型生成模块,用于用所述多个第三单类别船舶分类模型对所述第四伪标签训练集进行预测,得到第三单类别数据集;其中,所述第三单类别数据集包括每个未标定数据的第八预测结果;用所述第三总类别船舶分类模型对所述第四伪标签训练集进行预测,得到第三总类别数据集;其中,所述第三总类别数据集包括每个未标定数据的第九预测结果;将所述第三单类别数据集中的第八预测结果与所述第三总类别数据集中的第九预测结果融合,得到对应有第十预测结果的未标定数据的集合,记为第五伪标签训练集;采用所述第五伪标签训练集对所述第三总类别船舶分类模型进行训练,得到最终船舶分类模型。
进一步可选的,所述融合模块包括:第一识别子模块,用于识别所述第一预测结果中每个未标定数据对应的第一类别及第一置信度;第二识别子模块,用于识别所述第二预测结果中每个未标定数据对应的第二类别及第二置信度;计算子模块,用于根据每个未标定数据对应的第一类别及第二类别,计算类别赋值权重,根据所述类别赋值权重、第一置信度及第二置信度,计算得到第三置信度。
进一步可选的,所述第一预测模块包括:第一采样子模块,用于按第一比例在所述第一伪标签训练集中抽取第一伪标签训练子集,按第二比例在所述标定数据集中抽取第一标定训练子集,将所述第一伪标签训练子集和第一标定训练子集作为该次迭代的第一训练样本集;第一单次迭代子模块,用于采用所述第一训练样本集对所述单类别船舶分类网络进行训练,并计算该次迭代的损失值;第一网络更新子模块,用于根据该次迭代的损失值对所述单类别船舶分类网络进行反向求偏导,根据偏导结果对所述单类别船舶分类网络的参数进行修正;第一循环子模块,用于控制重复第一采样子模块、第一单次迭代子模块及第一网络更新子模块的操作,直至所述单类别船舶分类网络收敛,得到对应的第二单类别船舶分类模型。
进一步可选的,所述第二预测模块包括:第二采样子模块,用于按第一比例在所述第二伪标签训练集中抽取第二伪标签训练子集,按第二比例在所述标定数据集中抽取第二标定训练子集,将所述第二伪标签训练子集和第二标定训练子集作为该次迭代的第二训练样本集;第二单次迭代子模块、采用所述第二训练样本集对所述总类别船舶分类网络进行训练,得到该次迭代的损失值;第二网络更新子模块,用于根据该次迭代的损失值对所述总类别船舶分类网络进行反向求偏导,根据偏导结果对所述总类别船舶分类网络的参数进行修正;第二循环子模块,用于控制重复第二采样子模块、第二单次迭代子模块及第二网络更新子模块的操作,直至所述总类别船舶分类网络收敛,得到所述第二总类别船舶分类模型。
进一步可选的,所述数据筛选模块包括:第一奇异伪标签数据标记子模块,用于采用所述目标训练集分别对所述多个第二单类别船舶分类模型进行训练,依次记录每轮训练生成的中间单类别船舶分类模型对未标定数据集进行预测的预测结果,记为第一变化清单,选取所述第一变化清单中,连续预设轮数类别均不一致,和/或置信度连续变小的数据,将其标记为奇异伪标签数据;第二奇异伪标签数据标记子模块,用于采用所述目标训练集对所述第二总类别船舶分类模型进行训练,依次记录每轮训练生成的中间总类别船舶分类模型对未标定数据集预测的预测结果,记为第二变化清单,选取所述第二变化清单中,连续预设轮数类别均不一致,和/或置信度连续变小的数据,将其标记为奇异伪标签数据;数据剔除子模块,用于在所述第三伪标签训练集中删除所述奇异伪标签数据,得到第四伪标签训练集。
上述技术方案具有如下有益效果:在以往半监督分类模型中增加其他模型监督,使伪标签数据集的标签每一轮都会受到其他船舶单类别分类模型监督,得到更准确类别的伪标签数据集;伪标签数据集准确率受自身船舶类别分类模型推理历史结果监督,挖掘每轮训练推理的类别和置信度,减少错位伪标签数据对模型训练的影响;每次迭代增加正确数据集的张数,从而降低错误样本对损失值的占比,减少模型反向传播中错误样本的作用;对伪标签损失函数增加权重参数,增加自身模型推理结果和其他类别模型推理结果监督,如果伪标签的分类结果错误,那么自身推理的置信度和其他模型推理的结果中的置信度一般偏低,从而降低错误样本的对模型的影响。
附图说明
为了更清楚地说明本发明实施例或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1是本发明实施例提供的半监督船舶分类模型训练方法的流程图;
图2是本发明实施例提供的第一伪标签训练集生成方法的流程图;
图3是本发明实施例提供的第二单类别船舶分类模型生成方法的流程图;
图4是本发明实施例提供的第二总类别船舶分类模型生成方法的流程图;
图5是本发明实施例提供的伪标签训练集更新方法的流程图;
图6是本发明实施例提供的半监督船舶分类模型训练装置的结构示意图;
图7是本发明实施例提供的融合模块的结构示意图;
图8是本发明实施例提供的第一预测模块的结构示意图;
图9是本发明实施例提供的第二预测模块的结构示意图;
图10是本发明实施例提供的数据筛选模块的结构示意图。
附图标记:100-获取模块 200-训练模块 300-融合模块 3001-第一识别子模块3002-第二识别子模块 3003-计算子模块 400-第一预测模块 4001-第一采样子模块4002-第一单次迭代子模块 4003-第一网络更新子模块 4004-第一循环子模块 500-第二预测模块 5001-第二采样子模块 5002-第二单次迭代子模块 5003-第二网络更新子模块 5004-第二循环子模块 600-数据筛选模块 6001-第一奇异伪标签数据标记子模块6002-第二奇异伪标签数据标记子模块 6003-数据剔除子模块 700-最终船舶分类模型生成模块。
具体实施方式
下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
由于我国水上交通线路复杂和船舶类型内间差异大,导致无法对所有水上交通背景和船舶类型进行收集,另外,目前已存在半监督模型训练方法存在不受其他模型监督和伪标签训练存在一定概率错误样本等弊端。众所周知,深度学习分类模型对数据集类别的标定十分的依赖,如果训练集不准确,将会引导船舶分类模型向错误方向学习,导致模型的准确性降低,错误率增加。
因此为减少学习错误样本对模型的影响和提高伪标签训练集的准确性,本发明提供了一种半监督船舶分类模型训练方法,图1是本发明实施例提供的半监督船舶分类模型训练方法的流程图,如图1所示,该方法包括:
S1、获取初始总类别船舶分类模型、多个初始单类别船舶分类模型、标定数据集及未标定数据集;其中,总类别船舶分类模型用于分类所有船舶类型,单类别船舶分类模型用于分类一种船舶类型;
初始总类别船舶分类模型可以识别多种类别的船舶,例如,初始总类别船舶分类模型可以对快艇、管装船、货船、液压船及帆船五种类别的船舶进行分类。
每个初始单类别船舶分类模型用于分类初始总类别船舶分类模型中的其中一种船舶类型。例如,若初始总类别船舶分类模型可以对快艇、管装船、货船、液压船及帆船五种类别的船舶进行分类,那么该初始总类别船舶分类模型则对应有五个初始单类别船舶分类模型:快艇分类模型,其只能对快艇进行分类,其它类型船舶识别为背景;管装船分类模型,其只能对管装船进行分类,其它类型船舶识别为背景;货船分类模型,其只能对货船进行分类,其它类型船舶识别为背景;液压船分类模型,其只能对液压船进行分类,其它类型的船舶识别为背景;帆船分类模型,其只能对帆船进行分类,其它类型的船舶识别为背景。
标定数据集提前由人工进行打标,标定数据集中的每个标定数据均为一张船舶图片,且带有人工标记的原标签。例如,一张快艇的图片,其标签(本实施例中指类别)由人工标定为快艇。
未标定数据集为没有标签的数据集合,未标定数据集中的每个未标定数据可能为一张船舶图片,也可能为不存在船舶的纯背景图片。例如,一张快艇的图片,但其不对应任何标签(本实施例中指类别)。
S2、用标定数据集分别对初始总类别船舶分类模型及多个初始单类别船舶分类模型进行模型训练,得到第一总类别船舶分类模型及多个第一单类别船舶分类模型;
标定数据集为准确的数据集,采用标定数据集对初始总类别船舶分类模型进行模型训练,得到第一总类别船舶分类模型;
另外,标定数据集还对每个初始单类别船舶分类模型进行模型训练,得到每个初始单类别船舶分类模型对应的第一单类别船舶分类模型。
作为一种可选的实施方式,标定数据集对从类别船舶分类模型的训练、和对多个初始单类别船舶分类模型的训练可以同时训练,也可以依次进行训练,当然,也可以同时对其中几个分类模型进行训练。本实施例对同一时间训练的模型数量不做限制。
S3、用第一总类别船舶分类模型预测未标定数据集,得到第一总类别数据集;其中,第一总类别数据集包括每个未标定数据的第一预测结果;分别用多个第一单类别船舶分类模型预测未标定数据集,得到第一单类别数据集;其中,第一单类别数据集包括每个未标定数据的第二预测结果;将第一总类别数据集中的第一预测结果与第一单类别数据集中的第二预测结果融合,得到对应有第三预测结果的未标定数据的集合,记为第一伪标签训练集;
采用第一总类别船舶分类模型对未标定数据集中的每个未标定数据(本实施例中数据指图片)进行推理,得到每个未标定数据的第一预测结果,这些未标定数据的集合记为第一总类别数据集。其中,第一预测结果包括每个未标定数据的类别与置信度。
采用多个第一单类别船舶分类模型对未标定数据集中的每个未标定数据进行推理,得到其中每个未标定数据的第二预测结果,这些未标定数据的集合记为第一单类别数据集。其中,第二预测结果包括每个未标定数据的类别与置信度。具体的,首先选取其中一个第一单类别船舶分类模型对未标定数据集中的每个未标定数据进行预测,之后再选取另一个第一单类别船舶分类模型对未标定数据集中的每个未标定数据进行预测,重复上述步骤,直至所有单类别船舶分类模型均对未标定数据集进行了预测,由于每个第一单类别船舶分类模型仅能分类一种类别,因此,所有的第一单类别船舶分类模型可对未标定数据集中的每个未标定数据进行全面的预测。
此时,对于每个未标定数据来说均对应一个第一预测结果及第二预测结果,为实现后续的数据处理,需将第一预测结果与第二预测结果融合,得到每个未标定数据的第三预测结果,所有对应有第三预测结果的未标定数据的集合记为第一伪标签训练集。其中,第三预测结果包括每个未标定数据集的类别及置信度。
本实施例中,由于单类别船舶分类模型的识别准确率高于总类别船舶分类模型的识别准确率,因此使用多个单类别船舶分类模型分别进行后续的预测操作,可以使单类别船舶分类模型的预测结果对总类别船舶分类模型的预测结果产生监督作用。
S4、提取多个第一单类别船舶分类模型的参数,并分别对相应的单类别船舶分类网络进行赋值,将第一伪标签训练集与标定数据集输入至相应的单类别船舶分类网络进行模型训练,得到多个第二单类别船舶分类模型;分别用多个第二单类别船舶分类模型对未标定数据集进行预测,得到第二单类别数据集;其中,第二单类别数据集包括每个未标定数据集的第四预测结果;
将一个第一单类别船舶分类模型作为预训练模型,提取第一单类别船舶分类模型的参数并对相应单类别船舶分类网络进行赋值。使用第一伪标签训练集和标定数据集作为训练集,对单类别船舶分类网络进行模型训练,得到对应的第二单类别船舶分类模型。
对每个第一单类别船舶分类模型均重复上述步骤,得到多个第二单类别船舶分类模型。
分别使用训练得到的多个第二单类别船舶分类模型对未标定数据集中的每个未标定数据进行预测,得到每个未标定数据的第四预测结果,这些未标定数据的集合记为第二单类别数据集。其中,第四预测结果包括未标定数据的类别及置信度。
S5、将第一总类别数据集中的第一预测结果与第二单类别数据集中的第四预测结果融合,得到对应有第五预测结果的未标定数据的集合,记为第二伪标签训练集;提取第一总类别船舶分类模型的参数,并对总类别船舶分类网络进行赋值,将第二伪标签训练集与标定数据集输入至总类别船舶分类网络进行模型训练,得到第二总类别船舶分类模型;用第二总类别船舶分类模型对未标定数据集进行预测,得到第二总类别数据集;其中,第二总类别数据集包括每个未标定数据集的第六预测结果;
在总类别船舶分类网络进行模型训练之前,以每个未标定数据为主体,将其对应的第一预测结果与第四预测结果进行融合,得到每个未标定数据的第五预测结果,所有对应有第五预测结果的未标定数据的集合记为第二伪标签训练集。其中,第五预测结果包括每个未标定数据的类别及置信度。
将第一总类别船舶分类模型作为预训练模型,提取第一总类别船舶分类模型的参数并对总类别船舶分类网络进行赋值。使用第二伪标签训练集和标定数据集作为训练集,对总类别船舶分类网络进行模型训练,得到的第二总类别船舶分类模型。
使用训练得到的第二总类别船舶分类模型对未标定数据集中的每个未标定数据进行预测,得到每个未标定数据的第六预测结果,这些未标定数据的集合记为第二总类别数据集。其中,第六预测结果包括未标定数据的类别及置信度。
S6、将第二总类别数据集中的第六预测结果与第二单类别数据集中的第四预测结果进行融合,得到对应有第七预测结果的未标定数据的集合,记为第三伪标签训练集;用第三伪标签训练集及标定数据集作为目标训练集分别对多个第二单类别船舶分类模型及第二总类别船舶分类模型进行训练,并在训练过程中删除奇异伪标签数据,得到第四伪标签训练集、多个第三单类别船舶分类模型及第三总类别船舶分类模型;
以每个未标定数据为主体,将其对应的第六预测结果与第四预测结果进行融合,得到每个未标定数据的第七预测结果,所有对应有第七预测结果的未标定数据的集合记为第三伪标签训练集。其中,第七预测结果包括每个未标定数据的类别及置信度。
采用第三伪标签训练集及标定数据集作为目标训练集,对一个第二单类别船舶分类模型进行训练,每轮训练均记录未标定数据对应的预测结果,根据多轮训练的预测结果变化情况,选取出多轮预测结果差异过大的数据,记为奇异伪标签数据,并从第三伪标签训练集中删除该数据。训练完成后得到对应的第三单类别船舶分类模型。
对每个第二单类别船舶分类模型均重复上述步骤,直至根据每个第二单类别船舶分类模型训练过程筛选出奇异伪标签数据,并将其从第三伪标签训练集中删除。得到多个第三单类别船舶分类模型。
采用第三伪标签训练集及标定数据集作为目标训练集,对第二总类别船舶分类模型进行训练,每轮训练均记录未标定数据对应的预测结果,根据多轮训练的预测结果变化情况,选取出多轮预测结果差异过大的数据,并从第三伪标签训练集中删除该数据。训练完成后得到对应的第三总类别船舶分类模型。
将删除了所有奇异伪标签数据的第三伪标签训练集记为第四伪标签训练集。
根据模型训练中每轮的推理结果找出奇异样本,排除错误样本对模型的负面影响。
S7、用多个第三单类别船舶分类模型对第四伪标签训练集进行预测,得到第三单类别数据集;其中,第三单类别数据集包括每个未标定数据的第八预测结果;用第三总类别船舶分类模型对第四伪标签训练集进行预测,得到第三总类别数据集;其中,第三总类别数据集包括每个未标定数据的第九预测结果;将第三单类别数据集中的第八预测结果与第三总类别数据集中的第九预测结果融合,得到对应有第十预测结果的未标定数据的集合,记为第五伪标签训练集;采用第五伪标签训练集对第三总类别船舶分类模型进行训练,得到最终船舶分类模型。
分别使用训练得到的多个第三单类别船舶分类模型对第四伪标签训练集中每个数据进行预测,得到每个数据的第八预测结果,这些数据的集合记为第三单类别数据集。其中,第八预测结果包括数据的类别及置信度。
使用第三总类别船舶分类模型对第四伪标签训练集中的每个数据进行预测,得到每个数据的第九预测结果,这些数据的集合记为第三总类别数据集。其中,第九预测结果包括数据的类别及置信度。
以每个数据为主体,将其对应的第八预测结果与第九预测结果进行融合,得到每个数据的第十预测结果,所有对应有第十预测结果的数据集合记为第五伪标签训练集。其中,第十预测结果包括每个数据的类别及置信度。
采用第五伪标签训练集对第三总类别船舶分类模型进行训练,直至模型收敛,得到最终船舶分类模型。
作为一种可选的实施方式,图2是本发明实施例提供的第一伪标签训练集生成方法的流程图,如图2所示,将第一总类别数据集中的第一预测结果与第一单类别数据集中的第二预测结果融合,得到对应有第三预测结果的未标定数据的集合,记为第一伪标签训练集,包括:
S301、识别第一预测结果中每个未标定数据对应的第一类别及第一置信度;
第一类别及第一置信度均由第一总类别船舶分类模型对未标定数据预测得到。
S302、识别第二预测结果中每个未标定数据对应的第二类别及第二置信度;
第二类别及第二置信度均由多个第一单类别船舶分类模型对未标定数据预测得到。
S303、根据每个未标定数据对应的第一类别及第二类别,计算类别赋值权重,根据类别赋值权重、第一置信度及第二置信度,计算得到第三置信度。
两种预测结果的融合时,首先通过第一类别与第二类别计算类别赋值权重,再根据第一置信度、第二置信度及类别赋值权重,得到未标定数据对应的第三置信度。
作为一种具体的实施方式,根据每个未标定数据对应的第一类别及第二类别,计算类别赋值权重,通过以下公式计算:
其中,mkij为第k个未标定数据的类别赋值权重,为第k个未标定数据第一类别,为第k个未标定数据的第二类别;
若每个未标定数据的第一类别与第二类别一致,则将类别赋值权重为1;若每个未标定数据的第一类别与第二类别不一致,则将类别赋值权重为0。
根据类别赋值权重、第一置信度及第二置信度,计算得到第三置信度,通过以下公式计算:
其中,pk为第k个未标定数据的第三置信度,pki为第k个未标定数据的第一置信度,pkj为第k个未标定数据的第二置信度。
两个预测结果融合时,若第一类别与第二类别不一致,则置信度也为0,即不参与候选的数据处理;若第一类别与第二类别一致,则置信度不为0,根据第一置信度与第二置信度得到对应未标定数据的第三置信度。
该步骤实现了单类别船舶分类模型的监督,如果总类别分类模型与单类别分类模型的预测类别不一致,置为0;类别一致,置为1。采用第三类别赋值置信度权重,防止错误样本对模型训练的影响。
作为一种可选的实施方式,图3是本发明实施例提供的第二单类别船舶分类模型生成方法的流程图,如图3所示,将第一伪标签训练集与标定数据集输入至相应的单类别船舶分类网络进行模型训练,得到多个第二单类别船舶分类模型,其中,任一单类别船舶分类网络进行模型训练,包括:
S401、按第一比例在第一伪标签训练集中抽取第一伪标签训练子集,按第二比例在标定数据集中抽取第一标定训练子集,将第一伪标签训练子集和第一标定训练子集作为该次迭代的第一训练样本集;
本实施例提供一种新的数据采样器,将训练集分为两个文件,分别为人工标定的训练集及伪标签数据训练集。
每次迭代的第一训练样本集为N张,作为一种可选的实施例,第一比例为,第二比例为。也即,每次迭代从标定数据集中抽取张图片,记为第一标定训练子集;从第一伪标签训练集中抽取张图片,作为第一伪标签训练子集,将二者组成该次迭代所需的包含有N张图片的第一训练样本集。
为提高模型训练效率,减少错误率,本实施例中,抽取每次迭代的第一伪标签训练子集时,从第一伪标签训练集中抽取置信度大于0的张图片。
本实施例提供的采样方法,目的是为了模型每一步迭代都有标定数据和伪标签数据参与,降低伪标签的影响,提高人工打标的标签比重,既可以让模型学习到伪标签信息,又降低伪标签错误信息的影响。
S402、采用第一训练样本集对单类别船舶分类网络进行训练,并计算该次迭代的损失值;
使用得到的第一训练样本集对单类别船舶分类网络进行模型训练,并计算该次迭代的损失值。
S403、根据该次迭代的损失值对单类别船舶分类网络进行反向求偏导,根据偏导结果对单类别船舶分类网络的参数进行修正;
根据该次迭代的损失值对单类别船舶分类网络进行反向求偏导,依据偏导结果对网络中的参数进行修正,使单类别船舶分类网络更改后的向正确方向学习。
S404、重复S401-S403,直至单类别船舶分类网络收敛,得到对应的第二单类别船舶分类模型。
重复以上操作,即每次迭代使用上述采样方法获取该次迭代的训练样本集,并采用训练样本集对当前模型网络进行反向传播,更新网络参数,直至所有的数据集遍历完成,标志完成一轮模型训练;进行多轮训练,直至网络收敛,此时得到第二单类别船舶分类模型。
作为一种可选的实施方式,当该次迭代的损失值与上次迭代的损失值差距范围在0.003时,说明损失值基本趋于稳定,标志模型收敛。
对每个第一单类别船舶分类模型均采用上述操作,完成每个第一单类别船舶分类模型的模型训练,得到多个单类别船舶分类模型。
作为一种可选的实施方式,图4是本发明实施例提供的第二总类别船舶分类模型生成方法的流程图,如图4所示,将第二伪标签训练集与标定数据集输入至总类别船舶分类网络进行模型训练,得到第二总类别船舶分类模型,包括:
S501、按第一比例在第二伪标签训练集中抽取第二伪标签训练子集,按第二比例在标定数据集中抽取第二标定训练子集,将第二伪标签训练子集和第二标定训练子集作为该次迭代的第二训练样本集;
本实施例提供一种新的数据采样器,将训练集分为两个文件,分别为人工标定的训练集及伪标签数据训练集。
每次迭代的第一训练样本集为N张,作为一种可选的实施例,第一比例为,第二比例为。也即,每次迭代从标定数据集中抽取张图片,记为第二标定训练子集;从第二伪标签训练集中抽取张图片,作为第二伪标签训练子集,将二者组成该次迭代所需的包含有N张图片的第二训练样本集。
为提高模型训练效率,减少错误率,本实施例中,抽取每次迭代的第二伪标签训练子集时,从第二伪标签训练集中抽取置信度大于0的张图片。
本实施例提供的采样方法,目的是为了模型每一步迭代都有标定数据和伪标签数据参与,降低伪标签的影响,提高人工打标的标签比重,既可以让模型学习到伪标签信息,又降低伪标签错误信息的影响。
S502、采用第二训练样本集对所述总类别船舶分类网络进行训练,得到该次迭代的损失值;
使用得到的第二训练样本集对总类别船舶分类网络进行模型训练,并计算该次迭代的损失值。
S503、根据该次迭代的损失值对总类别船舶分类网络进行反向求偏导,依据偏导结果对总类别船舶分类网络中的参数进行修正;
根据该次迭代的损失值对总类别船舶分类网络进行反向求偏导,依据偏导结果对网络中的参数进行修正,使总类别船舶分类网络更改后的向正确方向学习。
S504、重复S501-503,直至单类别船舶分类网络收敛,得到第二总类别船舶分类模型。
重复以上操作,即每次迭代使用上述采样方法获取该次迭代的训练样本集,并采用训练样本集对当前模型网络进行反向传播,更新网络参数,直至所有的数据集遍历完成,标志完成一轮模型训练;进行多轮训练,直至网络收敛,此时得到第二总类别船舶分类模型。
作为一种可选的实施方式,当该次迭代的损失值与上次迭代的损失值差距范围在0.003时,说明损失值基本趋于稳定,标志模型收敛。
作为一种可选的实施方式,图5是本发明实施例提供的伪标签训练集更新方法的流程图,如图5所示,用第三伪标签训练集及标定数据集作为目标训练集分别对多个第二单类别船舶分类模型及第二总类别船舶分类模型进行训练,并在训练过程中删除奇异伪标签数据,得到第四伪标签训练集,包括:
S601、采用目标训练集分别对多个第二单类别船舶分类模型进行训练,依次记录每轮训练生成的中间单类别船舶分类模型对未标定数据集进行预测的预测结果,记为第一变化清单,选取第一变化清单中,连续预设轮数类别均不一致,和/或置信度连续变小的数据,将其标记为奇异伪标签数据;
以一个第二单类别船舶分类模型为例,采用以上采样方法作为每次迭代的采样方法,从第三伪标签训练集及标定数据集中进行采样,得到每次迭代的目标训练子集,采用目标训练子集完成一次迭代,当遍历完成第三伪标签训练集和标定数据集时,完成一轮训练。每轮训练后会得到一个中间单类别船舶分类模型,使用该中间单类别船舶分类模型对未标定数据集进行预测,会得到未标定数据集的预测结果。直至模型收敛,该模型会进行多轮训练,多轮训练后,会得到未标定数据的多轮预测结果,将这些预测结果记录进第一变化清单。
每个第二单类别船舶分类模型均进行以上操作,得到每个未标定数据的多轮预测结果,这些预测结果记录进第一变化清单。
第一变化清单中记录有每个未标定数据每轮的置信度及类别,通过在训练过程中分析连续预设轮数置信度和类别的变化,筛选出奇异伪标签数据。
作为一种具体的实施方式,预设轮数为n=5,当前轮数为k,分析第一变化清单,如果满足k-n>0,对k-n,……,k-3,k-2,k-1,k预测结果进行分析,如果当前伪标签训练集中的伪标签近n轮的类别都不一样或置信度连续变小,则将该数据标记为奇异伪标签数据,不参与下一轮单类别船舶分类模型训练。
S602、采用目标训练集对第二总类别船舶分类模型进行训练,依次记录每轮训练生成的中间总类别船舶分类模型对未标定数据集预测的预测结果,记为第二变化清单,选取第二变化清单中,连续预设轮数类别均不一致,和/或置信度连续变小的数据,将其标记为奇异伪标签数据;
采用以上采样方法作为每次迭代的采样方法,从第三伪标签训练集及标定数据集中进行采样,得到每次迭代的目标训练子集,采用目标训练子集完成一次迭代,当遍历完成第三伪标签训练集和标定数据集时,完成一轮训练。每轮训练后会得到一个中间总类别船舶分类模型,使用该中间单类别船舶分类模型对未标定数据集进行预测,会得到未标定数据集的预测结果。直至模型收敛,该模型会进行多轮训练,多轮训练后,会得到未标定数据的多轮预测结果,将这些预测结果记录进第二变化清单。
第二变化清单中记录有每个未标定数据每轮的置信度及类别,通过在训练过程中分析连续预设轮数置信度和类别的变化,筛选出奇异伪标签数据。
作为一种具体的实施方式,预设轮数为n=5,当前轮数为k,分析第二变化清单,如果满足k-n>0,对k-n,……,k-3,k-2,k-1,k预测结果进行分析,如果当前伪标签训练集中的伪标签近n轮的类别都不一样或置信度连续变小,则将该数据标记为奇异伪标签数据,不参与下一轮总类别船舶分类模型训练。
S603、在第三伪标签训练集中删除奇异伪标签数据,得到第四伪标签训练集。
从第三伪标签训练集中删除以上标记的奇异伪标签数据,得到更新后的第四伪标签训练集。
作为一种可选的实施方式,采用以下公式计算该次迭代的损失值:
其中,loss为每次迭代的损失值,为每次迭代对应的伪标签训练子集中的数据量,Nlabel为每次迭代对应的标定训练子集的数据量,Pi为第i个未标定数据对应的第三置信度,lossi为第i个未标定数据的预测类别与伪标签的损失值,lossj表示第j个标定数据的预测类别与原标定类别的损失值。
具体的,在单类别船舶分类模型的损失值计算中,loss为每次迭代的损失值;为每次迭代的第一伪标签训练子集的数据量;Nlabel为每次迭代的第一标定训练子集的数据量;Pi为第i个未标定数据对应的第三置信度,即上述的Pk;lossi为第i个未标定数据通过当前单类别船舶分类模型预测得到的预测类别与为伪标签的损失值;lossj表示第j个标定数据通过当前单类别船舶分类模型预测得到的预测类别与原标定类别的损失值。
其中,通过当前单类别船舶分类模型预测得到的预测类别与伪标签的损失值、通过当前单类别船舶分类模型预测得到的预测类别与原标定类别的损失值可通过Sigmoid函数进行计算。
其中,伪标签为第一伪标签训练集中每个未标定数据对应的类别。
具体的,在总类别船舶分类模型的损失值计算中,loss为每次迭代的损失值;为每次迭代的第二伪标签训练子集的数据量;Nlabel为每次迭代的第二标定训练子集的数据量;Pi为第i个未标定数据对应的第三置信度,即上述的Pk;lossi为第i个未标定数据通过当前总类别船舶分类模型预测得到的预测类别与伪标签的损失值;lossj表示第j个标定数据通过当前总类别船舶分类模型预测得到的预测类别与原标定类别的损失值。
其中,当前总类别船舶分类模型预测得到的预测类别与伪标签的损失值、通过当前总类别船舶分类模型预测得到的预测类别与原标定类别的损失值可通过Sigmoid函数进行计算。
其中,伪标签为第二伪标签训练集中每个未标定数据对应的类别。
作为一种可选的实施方式,本发明实施例还提供了一种半监督船舶分类模型训练装置,图6是本发明实施例提供的半监督船舶分类模型训练装置的结构示意图,如图6所示,该装置包括:
获取模块100,用于获取初始总类别船舶分类模型、多个初始单类别船舶分类模型、标定数据集及未标定数据集;其中,总类别船舶分类模型用于分类所有船舶类型,单类别船舶分类模型用于分类一种船舶类型;
初始总类别船舶分类模型可以识别多种类别的船舶,例如,初始总类别船舶分类模型可以对快艇、管装船、货船、液压船及帆船五种类别的船舶进行分类。
每个初始单类别船舶分类模型用于分类初始总类别船舶分类模型中的其中一种船舶类型。例如,若初始总类别船舶分类模型可以对快艇、管装船、货船、液压船及帆船五种类别的船舶进行分类,那么该初始总类别船舶分类模型则对应有五个初始单类别船舶分类模型:快艇分类模型,其只能对快艇进行分类,其它类型船舶识别为背景;管装船分类模型,其只能对管装船进行分类,其它类型船舶识别为背景;货船分类模型,其只能对货船进行分类,其它类型船舶识别为背景;液压船分类模型,其只能对液压船进行分类,其它类型的船舶识别为背景;帆船分类模型,其只能对帆船进行分类,其它类型的船舶识别为背景。
标定数据集提前由人工进行打标,标定数据集中的每个标定数据均为一张船舶图片,且带有人工标记的原标签。例如,一张快艇的图片,其标签(本实施例中指类别)由人工标定为快艇。
未标定数据集为没有标签的数据集合,未标定数据集中的每个未标定数据可能为一张船舶图片,也可能为不存在船舶的纯背景图片。例如,一张快艇的图片,但其不对应任何标签(本实施例中指类别)。
训练模块200,用于用标定数据集分别对初始总类别船舶分类模型及多个初始单类别船舶分类模型进行模型训练,得到第一总类别船舶分类模型及多个第一单类别船舶分类模型;
标定数据集为准确的数据集,采用标定数据集对初始总类别船舶分类模型进行模型训练,得到第一总类别船舶分类模型;
另外,标定数据集还对每个初始单类别船舶分类模型进行模型训练,得到每个初始单类别船舶分类模型对应的第一单类别船舶分类模型。
作为一种可选的实施方式,标定数据集对从类别船舶分类模型的训练、和对多个初始单类别船舶分类模型的训练可以同时训练,也可以依次进行训练,当然,也可以同时对其中几个分类模型进行训练。本实施例对同一时间训练的模型数量不做限制。
融合模块300,用于用第一总类别船舶分类模型预测未标定数据集,得到第一总类别数据集;其中,第一总类别数据集包括每个未标定数据的第一预测结果;分别用多个第一单类别船舶分类模型预测未标定数据集,得到第一单类别数据集;其中,第一单类别数据集包括每个未标定数据的第二预测结果;将第一总类别数据集中的第一预测结果与第一单类别数据集中的第二预测结果融合,得到对应有第三预测结果的未标定数据的集合,记为第一伪标签训练集;
采用第一总类别船舶分类模型对未标定数据集中的每个未标定数据(本实施例中数据指图片)进行推理,得到每个未标定数据的第一预测结果,这些未标定数据的集合记为第一总类别数据集。其中,第一预测结果包括每个未标定数据的类别与置信度。
采用多个第一单类别船舶分类模型对未标定数据集中的每个未标定数据进行推理,得到其中每个未标定数据的第二预测结果,这些未标定数据的集合记为第一单类别数据集。其中,第二预测结果包括每个未标定数据的类别与置信度。具体的,首先选取其中一个第一单类别船舶分类模型对未标定数据集中的每个未标定数据进行预测,之后再选取另一个第一单类别船舶分类模型对未标定数据集中的每个未标定数据进行预测,重复上述步骤,直至所有单类别船舶分类模型均对未标定数据集进行了预测,由于每个第一单类别船舶分类模型仅能分类一种类别,因此,所有的第一单类别船舶分类模型可对未标定数据集中的每个未标定数据进行全面的预测。
此时,对于每个未标定数据来说均对应一个第一预测结果及第二预测结果,为实现后续的数据处理,需将第一预测结果与第二预测结果融合,得到每个未标定数据的第三预测结果,所有对应有第三预测结果的未标定数据的集合记为第一伪标签训练集。其中,第三预测结果包括每个未标定数据集的类别及置信度。
本实施例中,由于单类别船舶分类模型的识别准确率高于总类别船舶分类模型的识别准确率,因此使用多个单类别船舶分类模型分别进行后续的预测操作,可以使单类别船舶分类模型的预测结果对总类别船舶分类模型的预测结果产生监督作用。
第一预测模块400,用于提取多个第一单类别船舶分类模型的参数,并分别对相应的单类别船舶分类网络进行赋值,将第一伪标签训练集与标定数据集输入至相应的单类别船舶分类网络进行模型训练,得到多个第二单类别船舶分类模型;分别用多个第二单类别船舶分类模型对未标定数据集进行预测,得到第二单类别数据集;其中,第二单类别数据集包括每个未标定数据集的第四预测结果;
将一个第一单类别船舶分类模型作为预训练模型,提取第一单类别船舶分类模型的参数并对相应单类别船舶分类网络进行赋值。使用第一伪标签训练集和标定数据集作为训练集,对单类别船舶分类网络进行模型训练,得到对应的第二单类别船舶分类模型。
对每个第一单类别船舶分类模型均重复上述步骤,得到多个第二单类别船舶分类模型。
分别使用训练得到的多个第二单类别船舶分类模型对未标定数据集中的每个未标定数据进行预测,得到每个未标定数据的第四预测结果,这些未标定数据的集合记为第二单类别数据集。其中,第四预测结果包括未标定数据的类别及置信度。
第二预测模块500,将第二总类别数据集中的第六预测结果与第二单类别数据集中的第四预测结果进行融合,得到对应有第七预测结果的未标定数据的集合,记为第三伪标签训练集;用第三伪标签训练集及标定数据集作为目标训练集分别对多个第二单类别船舶分类模型及第二总类别船舶分类模型进行训练,并在训练过程中删除奇异伪标签数据,得到第四伪标签训练集、多个第三单类别船舶分类模型及第三总类别船舶分类模型;
在总类别船舶分类网络进行模型训练之前,以每个未标定数据为主体,将其对应的第一预测结果与第四预测结果进行融合,得到每个未标定数据的第五预测结果,所有对应有第五预测结果的未标定数据的集合记为第二伪标签训练集。其中,第五预测结果包括每个未标定数据的类别及置信度。
将第一总类别船舶分类模型作为预训练模型,提取第一总类别船舶分类模型的参数并对总类别船舶分类网络进行赋值。使用第二伪标签训练集和标定数据集作为训练集,对总类别船舶分类网络进行模型训练,得到的第二总类别船舶分类模型。
使用训练得到的第二总类别船舶分类模型对未标定数据集中的每个未标定数据进行预测,得到每个未标定数据的第六预测结果,这些未标定数据的集合记为第二总类别数据集。其中,第六预测结果包括未标定数据的类别及置信度。
数据筛选模块600,用于将第二总类别数据集中的第六预测结果与第二单类别数据集中的第四预测结果进行融合,得到对应有第七预测结果的未标定数据的集合,记为第三伪标签训练集;用第三伪标签训练集及标定数据集作为目标训练集分别对多个第二单类别船舶分类模型及第二总类别船舶分类模型进行训练,并在训练过程中删除奇异伪标签数据,得到第四伪标签训练集、多个第三单类别船舶分类模型及第三总类别船舶分类模型;
以每个未标定数据为主体,将其对应的第六预测结果与第四预测结果进行融合,得到每个未标定数据的第七预测结果,所有对应有第七预测结果的未标定数据的集合记为第三伪标签训练集。其中,第七预测结果包括每个未标定数据的类别及置信度。
采用第三伪标签训练集及标定数据集作为目标训练集,对一个第二单类别船舶分类模型进行训练,每轮训练均记录未标定数据对应的预测结果,根据多轮训练的预测结果变化情况,选取出多轮预测结果差异过大的数据,记为奇异伪标签数据,并从第三伪标签训练集中删除该数据。训练完成后得到对应的第三单类别船舶分类模型。
对每个第二单类别船舶分类模型均重复上述步骤,直至根据每个第二单类别船舶分类模型训练过程筛选出奇异伪标签数据,并将其从第三伪标签训练集中删除。得到多个第三单类别船舶分类模型。
采用第三伪标签训练集及标定数据集作为目标训练集,对第二总类别船舶分类模型进行训练,每轮训练均记录未标定数据对应的预测结果,根据多轮训练的预测结果变化情况,选取出多轮预测结果差异过大的数据,并从第三伪标签训练集中删除该数据。训练完成后得到对应的第三总类别船舶分类模型
将删除了所有奇异伪标签数据的第三伪标签训练集记为第四伪标签训练集。
根据模型训练中每轮的推理结果找出奇异样本,排除错误样本对模型的负面影响。
最终船舶分类模型生成模块700,用于用多个第三单类别船舶分类模型对第四伪标签训练集进行预测,得到第三单类别数据集;其中,第三单类别数据集包括每个未标定数据的第八预测结果;用第三总类别船舶分类模型对第四伪标签训练集进行预测,得到第三总类别数据集;其中,第三总类别数据集包括每个未标定数据的第九预测结果;将第三单类别数据集中的第八预测结果与第三总类别数据集中的第九预测结果融合,得到对应有第十预测结果的未标定数据的集合,记为第五伪标签训练集;采用第五伪标签训练集对第三总类别船舶分类模型进行训练,得到最终船舶分类模型。
分别使用训练得到的多个第三单类别船舶分类模型对第四伪标签训练集中每个数据进行预测,得到每个数据的第八预测结果,这些数据的集合记为第三单类别数据集。其中,第八预测结果包括数据的类别及置信度。
使用第三总类别船舶分类模型对第四伪标签训练集中的每个数据进行预测,得到每个数据的第九预测结果,这些数据的集合记为第三总类别数据集。其中,第九预测结果包括数据的类别及置信度。
以每个数据为主体,将其对应的第八预测结果与第九预测结果进行融合,得到每个数据的第十预测结果,所有对应有第十预测结果的数据集合记为第五伪标签训练集。其中,第十预测结果包括每个数据的类别及置信度。
采用第五伪标签训练集对第三总类别船舶分类模型进行训练,直至模型收敛,得到最终船舶分类模型。
作为一种可选的实施方式,图7是本发明实施例提供的融合模块的结构示意图,如图7所示,融合模块300包括:
第一识别子模块3001,用于识别第一预测结果中每个未标定数据对应的第一类别及第一置信度;
第一类别及第一置信度均由第一总类别船舶分类模型对未标定数据预测得到。
第二识别子模块3002,用于识别第二预测结果中每个未标定数据对应的第二类别及第二置信度;
第二类别及第二置信度均由多个第一单类别船舶分类模型对未标定数据预测得到。
计算子模块3003,用于根据每个未标定数据对应的第一类别及第二类别,计算类别赋值权重,根据类别赋值权重、第一置信度及第二置信度,计算得到第三置信度。
两种预测结果的融合时,首先通过第一类别与第二类别计算类别赋值权重,再根据第一置信度、第二置信度及类别赋值权重,得到未标定数据对应的第三置信度。
作为一种可选的实施方式,根据每个未标定数据对应的第一类别及第二类别,计算类别赋值权重,通过以下公式计算:
其中,mkij为第k个未标定数据的类别赋值权重,为第k个未标定数据第一类别,为第k个未标定数据的第二类别;
若每个未标定数据的第一类别与第二类别一致,则将类别赋值权重为1;若每个未标定数据的第一类别与第二类别不一致,则将类别赋值权重为0。
根据类别赋值权重、第一置信度及第二置信度,计算得到第三置信度,通过以下公式计算:
其中,pk为第k个未标定数据的第三置信度,pki为第k个未标定数据的第一置信度,pkj为第k个未标定数据的第二置信度。
两个预测结果融合时,若第一类别与第二类别不一致,则置信度也为0,即不参与候选的数据处理;若第一类别与第二类别一致,则置信度不为0,根据第一置信度与第二置信度得到对应未标定数据的第三置信度。
该步骤实现了单类别船舶分类模型的监督,如果总类别分类模型与单类别分类模型的预测类别不一致,置为0;类别一致,置为1。采用第三类别赋值置信度权重,防止错误样本对模型训练的影响。
作为一种可选的实施方式,图8是本发明实施例提供的第一预测模块的结构示意图,如图8所示,第一预测模块400,包括:
第一采样子模块4001,用于按第一比例在第一伪标签训练集中抽取第一伪标签训练子集,按第二比例在标定数据集中抽取第一标定训练子集,将第一伪标签训练子集和第一标定训练子集作为该次迭代的第一训练样本集;
本实施例提供一种新的数据采样器,将训练集分为两个文件,分别为人工标定的训练集及伪标签数据训练集。
每次迭代的第一训练样本集为N张,作为一种可选的实施例,第一比例为,第二比例为。也即,每次迭代从标定数据集中抽取张图片,记为第一标定训练子集;从第一伪标签训练集中抽取张图片,作为第一伪标签训练子集,将二者组成该次迭代所需的包含有N张图片的第一训练样本集。
为提高模型训练效率,减少错误率,本实施例中,抽取每次迭代的第一伪标签训练子集时,从第一伪标签训练集中抽取置信度大于0的张图片。
本实施例提供的采样方法,目的是为了模型每一步迭代都有标定数据和伪标签数据参与,降低伪标签的影响,提高人工打标的标签比重,既可以让模型学习到伪标签信息,又降低伪标签错误信息的影响。
第一单次迭代子模块4002,用于采用第一训练样本集对单类别船舶分类网络进行训练,并计算该次迭代的损失值;
使用得到的第一训练样本集对单类别船舶分类网络进行模型训练,并计算该次迭代的损失值。
第一网络更新子模块4003,用于根据该次迭代的损失值对单类别船舶分类网络进行反向求偏导,根据偏导结果对单类别船舶分类网络的参数进行修正;
根据该次迭代的损失值对单类别船舶分类网络进行反向求偏导,依据偏导结果对网络中的参数进行修正,使单类别船舶分类网络更改后的向正确方向学习。
第一循环子模块4004,用于控制重复第一采样子模块、第一单次迭代子模块及第一网络更新子模块的操作,直至单类别船舶分类网络收敛,得到对应的第二单类别船舶分类模型。
重复以上操作,即每次迭代使用上述采样方法获取该次迭代的训练样本集,并采用训练样本集对当前模型网络进行反向传播,更新网络参数,直至所有的数据集遍历完成,标志完成一轮模型训练;进行多轮训练,直至网络收敛,此时得到第二单类别船舶分类模型。
作为一种可选的实施方式,当该次迭代的损失值与上次迭代的损失值的差距范围在0.003时,说明损失值基本趋于稳定,标志模型收敛。
对每个第一单类别船舶分类模型均采用上述操作,完成每个第一单类别船舶分类模型的模型训练,得到多个单类别船舶分类模型。
作为一种可选的实施方式,图9是本发明实施例提供的第二预测模块的结构示意图,如图9所示,第二预测模块500,包括:
第二采样子模块5001,用于按第一比例在第二伪标签训练集中抽取第二伪标签训练子集,按第二比例在标定数据集中抽取第二标定训练子集,将第二伪标签训练子集和第二标定训练子集作为该次迭代的第二训练样本集;
本实施例提供一种新的数据采样器,将训练集分为两个文件,分别为人工标定的训练集及伪标签数据训练集。
每次迭代的第一训练样本集为N张,作为一种可选的实施例,第一比例为,第二比例为。也即,每次迭代从标定数据集中抽取张图片,记为第二标定训练子集;从第二伪标签训练集中抽取张图片,作为第二伪标签训练子集,将二者组成该次迭代所需的包含有N张图片的第二训练样本集。
为提高模型训练效率,减少错误率,本实施例中,抽取每次迭代的第二伪标签训练子集时,从第二伪标签训练集中抽取置信度大于0的张图片。
本实施例提供的采样方法,目的是为了模型每一步迭代都有标定数据和伪标签数据参与,降低伪标签的影响,提高人工打标的标签比重,既可以让模型学习到伪标签信息,又降低伪标签错误信息的影响。
第二单次迭代子模块5002,采用第二训练样本集对总类别船舶分类网络进行训练,得到该次迭代的损失值;
使用得到的第二训练样本集对总类别船舶分类网络进行模型训练,并计算该次迭代的损失值。
第二网络更新子模块5003,用于根据该次迭代的损失值对总类别船舶分类网络进行反向求偏导,根据偏导结果对总类别船舶分类网络的参数进行修正;
根据该次迭代的损失值对总类别船舶分类网络进行反向求偏导,依据偏导结果对网络中的参数进行修正,使总类别船舶分类网络更改后的向正确方向学习。
第二循环子模块5004,用于控制重复第二采样子模块、第二单次迭代子模块及第二网络更新子模块的操作,直至总类别船舶分类网络收敛,得到第二总类别船舶分类模型。
重复以上操作,即每次迭代使用上述采样方法获取该次迭代的训练样本集,并采用训练样本集对当前模型网络进行反向传播,更新网络参数,直至所有的数据集遍历完成,标志完成一轮模型训练;进行多轮训练,直至网络收敛,此时得到第二总类别船舶分类模型。
作为一种可选的实施方式,当该次迭代的损失值与上次迭代的损失值的差距范围在0.003时,说明损失值基本趋于稳定,标志模型收敛。
作为一种可选的实施方式,图10是本发明实施例提供的数据筛选模块的结构示意图,如图9所示,数据筛选模块600,包括:
第一奇异伪标签数据标记子模块6001,用于采用目标训练集分别对多个第二单类别船舶分类模型进行训练,依次记录每轮训练生成的中间单类别船舶分类模型对未标定数据集进行预测的预测结果,记为第一变化清单,选取第一变化清单中,连续预设轮数类别均不一致,和/或置信度连续变小的数据,将其标记为奇异伪标签数据;
以一个第二单类别船舶分类模型为例,采用以上采样方法作为每次迭代的采样方法,从第三伪标签训练集及标定数据集中进行采样,得到每次迭代的目标训练子集,采用目标训练子集完成一次迭代,当遍历完成第三伪标签训练集和标定数据集时,完成一轮训练。每轮训练后会得到一个中间单类别船舶分类模型,使用该中间单类别船舶分类模型对未标定数据集进行预测,会得到未标定数据集的预测结果。多轮训练后,会得到未标定数据的多轮预测结果,将这些预测结果记录进第一变化清单。
每个第二单类别船舶分类模型均进行以上操作,得到每个未标定数据的多轮预测结果,这些预测结果记录进第一变化清单。
第一变化清单中记录有每个未标定数据每轮的置信度及类别,通过在训练过程中分析连续预设轮数置信度和类别的变化,筛选出奇异伪标签数据。
作为一种具体的实施方式,预设轮数为n=5,当前轮数为k,分析第一变化清单,如果满足k-n>0,对k-n,……,k-3,k-2,k-1,k预测结果进行分析,如果当前伪标签训练集中的伪标签近n轮的类别都不一样或置信度连续变小,则将该数据标记为奇异伪标签数据,不参与下一轮单类别船舶分类模型训练。
第二奇异伪标签数据标记子模块6002,用于采用目标训练集对第二总类别船舶分类模型进行训练,依次记录每轮训练生成的中间总类别船舶分类模型对未标定数据集预测的预测结果,记为第二变化清单,选取第二变化清单中,连续预设轮数类别均不一致,和/或置信度连续变小的数据,将其标记为奇异伪标签数据;
采用以上采样方法作为每次迭代的采样方法,从第三伪标签训练集及标定数据集中进行采样,得到每次迭代的目标训练子集,采用目标训练子集完成一次迭代,当遍历完成第三伪标签训练集和标定数据集时,完成一轮训练。每轮训练后会得到一个中间总类别船舶分类模型,使用该中间单类别船舶分类模型对未标定数据集进行预测,会得到未标定数据集的预测结果。多轮训练后,会得到未标定数据的多轮预测结果,将这些预测结果记录进第二变化清单。
第二变化清单中记录有每个未标定数据每轮的置信度及类别,通过在训练过程中分析连续预设轮数置信度和类别的变化,筛选出奇异伪标签数据。
作为一种具体的实施方式,预设轮数为n=5,当前轮数为k,分析第二变化清单,如果满足k-n>0,对k-n,……,k-3,k-2,k-1,k预测结果进行分析,如果当前伪标签训练集中的伪标签近n轮的类别都不一样或置信度连续变小,则将该数据标记为奇异伪标签数据,不参与下一轮总类别船舶分类模型训练。
数据剔除子模块6003,用于在第三伪标签训练集中删除奇异伪标签数据,得到第四伪标签训练集。
从第三伪标签训练集中删除以上标记的奇异伪标签数据,得到更新后的第四伪标签训练集。
作为一种可选的实施方式,采用以下公式计算该次迭代的损失值:
其中,loss为每次迭代的损失值,为每次迭代对应的伪标签训练子集中的数据量,Nlabel为每次迭代对应的标定训练子集的数据量,Pi为第i个未标定数据对应的第三置信度,lossi为第i个未标定数据的预测类别与伪标签的损失值,lossj表示第j个标定数据的预测类别与原标定类别的损失值。
具体的,在单类别船舶分类模型的损失值计算中,loss为每次迭代的损失值;为每次迭代的第一伪标签训练子集的数据量;Nlabel为每次迭代的第一标定训练子集的数据量;Pi为第i个未标定数据对应的第三置信度,即上述的Pk;lossi为第i个未标定数据通过当前单类别船舶分类模型预测得到的预测类别与为伪标签的损失值;lossj表示第j个标定数据通过当前单类别船舶分类模型预测得到的预测类别与原标定类别的损失值。
其中,通过当前单类别船舶分类模型预测得到的预测类别与伪标签的损失值、通过当前单类别船舶分类模型预测得到的预测类别与原标定类别的损失值可通过Sigmoid函数进行计算。
其中,伪标签为第一伪标签训练集中每个未标定数据对应的类别。
具体的,在总类别船舶分类模型的损失值计算中,loss为每次迭代的损失值;为每次迭代的第二伪标签训练子集的数据量;Nlabel为每次迭代的第二标定训练子集的数据量;Pi为第i个未标定数据对应的第三置信度,即上述的Pk;lossi为第i个未标定数据通过当前总类别船舶分类模型预测得到的预测类别与伪标签的损失值;lossj表示第j个标定数据通过当前总类别船舶分类模型预测得到的预测类别与原标定类别的损失值。
其中,当前总类别船舶分类模型预测得到的预测类别与伪标签的损失值、通过当前总类别船舶分类模型预测得到的预测类别与原标定类别的损失值可通过Sigmoid函数进行计算。
其中,伪标签为第二伪标签训练集中每个未标定数据对应的类别。
上述技术方案具有如下有益效果:基于推理结果分析和其他模型监督的半监督船舶类型模型训练方法,缓解了伪标签错误类别对分类模型错误引导,从损失函数和伪标签数据集类别的准确率进行优化。修改采样方法,每次迭代的训练集由较高比例的标定训练数据和较低比例的伪标签数据组成,用人工标定数据引导伪标签数据集,降低伪标签数据集错误类别影响;伪标签数据集损失权重,利用其他模型得到的推理结果和自身模型推理结果进行结果融合,得到损失权重,降低伪标签数据集错误类别的影响;选用准确率高的其他船舶类型分类模型作为监督,对伪标签数据预测得到伪标签类别;挖掘自身船舶类别分类模型每轮推理结果信息,通过类别和置信度筛选,得到准确率更高的伪标签训练集。
在采样方法上进行修改,增加每次迭代正确训练样本数量,在训练样本数量上降低错误类别样本对损失值影响;在损失函数上进行优化,增加伪标签数据集的权重,其权重由其他模型和自身模型结果融合得到,降低错误样本的影响;通过自身模型推理结果,对伪标签数据集的类别准确率进行分析;其他模型监督,提高伪标签数据集的准确率。
以上发明的具体实施方式,对本发明的目的、技术方案和有益效果进行了进一步详细说明,所应理解的是,以上内容仅为本发明的具体实施方式而已,并不用于限定本发明的保护范围,凡在本发明的精神和原则之内,所做的任何修改、等同替换、改进等,均应包含在本发明的保护范围之内。
Claims (12)
1.一种半监督船舶分类模型训练方法,其特征在于,包括:
S1、获取初始总类别船舶分类模型、多个初始单类别船舶分类模型、标定数据集及未标定数据集;其中,所述总类别船舶分类模型用于分类所有船舶类型,单类别船舶分类模型用于分类一种船舶类型;
S2、用标定数据集分别对所述初始总类别船舶分类模型及多个初始单类别船舶分类模型进行模型训练,得到第一总类别船舶分类模型及多个第一单类别船舶分类模型;
S3、用所述第一总类别船舶分类模型预测所述未标定数据集,得到第一总类别数据集;其中,所述第一总类别数据集包括每个未标定数据的第一预测结果;分别用多个第一单类别船舶分类模型预测所述未标定数据集,得到第一单类别数据集;其中,所述第一单类别数据集包括每个未标定数据的第二预测结果;将所述第一总类别数据集中的第一预测结果与所述第一单类别数据集中的第二预测结果融合,得到对应有第三预测结果的未标定数据的集合,记为第一伪标签训练集;
S4、提取所述多个第一单类别船舶分类模型的参数,并分别对相应的单类别船舶分类网络进行赋值,将所述第一伪标签训练集与标定数据集输入至相应的单类别船舶分类网络进行模型训练,得到多个第二单类别船舶分类模型;分别用所述多个第二单类别船舶分类模型对所述未标定数据集进行预测,得到第二单类别数据集;其中,所述第二单类别数据集包括每个未标定数据集的第四预测结果;
S5、将所述第一总类别数据集中的第一预测结果与所述第二单类别数据集中的第四预测结果融合,得到对应有第五预测结果的未标定数据的集合,记为第二伪标签训练集;提取所述第一总类别船舶分类模型的参数,并对总类别船舶分类网络进行赋值,将所述第二伪标签训练集与标定数据集输入至总类别船舶分类网络进行模型训练,得到第二总类别船舶分类模型;用所述第二总类别船舶分类模型对所述未标定数据集进行预测,得到第二总类别数据集;其中,所述第二总类别数据集包括每个未标定数据集的第六预测结果;
S6、将所述第二总类别数据集中的第六预测结果与第二单类别数据集中的第四预测结果进行融合,得到对应有第七预测结果的未标定数据的集合,记为第三伪标签训练集;用所述第三伪标签训练集及标定数据集作为目标训练集分别对所述多个第二单类别船舶分类模型及第二总类别船舶分类模型进行训练,并在训练过程中删除奇异伪标签数据,得到第四伪标签训练集、多个第三单类别船舶分类模型及第三总类别船舶分类模型;
S7、用所述多个第三单类别船舶分类模型对所述第四伪标签训练集进行预测,得到第三单类别数据集;其中,所述第三单类别数据集包括每个未标定数据的第八预测结果;用所述第三总类别船舶分类模型对所述第四伪标签训练集进行预测,得到第三总类别数据集;其中,所述第三总类别数据集包括每个未标定数据的第九预测结果;将所述第三单类别数据集中的第八预测结果与所述第三总类别数据集中的第九预测结果融合,得到对应有第十预测结果的未标定数据的集合,记为第五伪标签训练集;采用所述第五伪标签训练集对所述第三总类别船舶分类模型进行训练,得到最终船舶分类模型。
2.根据权利要求1所述的半监督船舶分类模型训练方法,其特征在于,所述将所述第一总类别数据集中的第一预测结果与所述第一单类别数据集中的第二预测结果融合,得到对应有第三预测结果的未标定数据的集合,记为第一伪标签训练集,包括:
S301、识别所述第一预测结果中每个未标定数据对应的第一类别及第一置信度;
S302、识别所述第二预测结果中每个未标定数据对应的第二类别及第二置信度;
S303、根据每个未标定数据对应的第一类别及第二类别,计算类别赋值权重,根据所述类别赋值权重、第一置信度及第二置信度,计算得到第三置信度。
3.根据权利要求2所述的半监督船舶分类模型训练方法,其特征在于,所述根据每个未标定数据对应的第一类别及第二类别,计算类别赋值权重,通过以下公式计算:
其中,mkij为第k个未标定数据的类别赋值权重,Labelik为第k个未标定数据第一类别,Labeljk为第k个未标定数据的第二类别;
所述根据所述类别赋值权重、第一置信度及第二置信度,计算得到第三置信度,通过以下公式计算:
pk=mkijpkipkj
其中,pk为第k个未标定数据的第三置信度,pki为第k个未标定数据的第一置信度,pkj为第k个未标定数据的第二置信度。
4.根据权利要求2所述的半监督船舶分类模型训练方法,其特征在于,所述将所述第一伪标签训练集与标定数据集输入至相应的单类别船舶分类网络进行模型训练,得到多个第二单类别船舶分类模型,其中,任一单类别船舶分类网络进行模型训练包括:
S401、按第一比例在所述第一伪标签训练集中抽取第一伪标签训练子集,按第二比例在所述标定数据集中抽取第一标定训练子集,将所述第一伪标签训练子集和第一标定训练子集作为该次迭代的第一训练样本集;
S402、采用所述第一训练样本集对所述单类别船舶分类网络进行训练,并计算该次迭代的损失值;
S403、根据该次迭代的损失值对所述单类别船舶分类网络进行反向求偏导,根据偏导结果对所述单类别船舶分类网络的参数进行修正;
S404、重复S401-S403,直至所述单类别船舶分类网络收敛,得到对应的第二单类别船舶分类模型。
5.根据权利要求4所述的半监督船舶分类模型训练方法,其特征在于,所述将所述第二伪标签训练集与标定数据集输入至总类别船舶分类网络进行模型训练,得到第二总类别船舶分类模型,包括:
S501、按第一比例在所述第二伪标签训练集中抽取第二伪标签训练子集,按第二比例在所述标定数据集中抽取第二标定训练子集,将所述第二伪标签训练子集和第二标定训练子集作为该次迭代的第二训练样本集;
S502、采用所述第二训练样本集对所述总类别船舶分类网络进行训练,得到该次迭代的损失值;
S503、根据该次迭代的损失值对所述总类别船舶分类网络进行反向求偏导,根据偏导结果对所述总类别船舶分类网络的参数进行修正;
S504、重复S501-503,直至所述总类别船舶分类网络收敛,得到所述第二总类别船舶分类模型。
6.根据权利要求5所述的半监督船舶分类模型训练方法,其特征在于,所述用所述第三伪标签训练集及标定数据集作为目标训练集分别对所述多个第二单类别船舶分类模型及第二总类别船舶分类模型进行训练,并在训练过程中删除奇异伪标签数据,得到第四伪标签训练集,包括:
S601、采用所述目标训练集分别对所述多个第二单类别船舶分类模型进行训练,依次记录每轮训练生成的中间单类别船舶分类模型对未标定数据集进行预测的预测结果,记为第一变化清单,选取所述第一变化清单中,连续预设轮数类别均不一致,和/或置信度连续变小的数据,将其标记为奇异伪标签数据;
S602、采用所述目标训练集对所述第二总类别船舶分类模型进行训练,依次记录每轮训练生成的中间总类别船舶分类模型对未标定数据集预测的预测结果,记为第二变化清单,选取所述第二变化清单中,连续预设轮数类别均不一致,和/或置信度连续变小的数据,将其标记为奇异伪标签数据;
S603、在所述第三伪标签训练集中删除所述奇异伪标签数据,得到第四伪标签训练集。
7.根据权利要求4-5任一项所述的半监督船舶分类模型训练方法,其特征在于,采用以下公式计算该次迭代的损失值:
其中,loss为每次迭代的损失值,N伪为每次迭代对应的伪标签训练子集中的数据量,Nlabel为每次迭代对应的标定训练子集的数据量,Pi为第i个未标定数据对应的第三置信度,lossi为第i个未标定数据的预测类别与伪标签的损失值,lossj表示第j个标定数据的预测类别与原标定类别的损失值。
8.一种半监督船舶分类模型训练装置,其特征在于,包括:
获取模块,用于获取初始总类别船舶分类模型、多个初始单类别船舶分类模型、标定数据集及未标定数据集;其中,所述总类别船舶分类模型用于分类所有船舶类型,单类别船舶分类模型用于分类一种船舶类型;
训练模块,用于用标定数据集分别对所述初始总类别船舶分类模型及多个初始单类别船舶分类模型进行模型训练,得到第一总类别船舶分类模型及多个第一单类别船舶分类模型;
融合模块,用于用所述第一总类别船舶分类模型预测所述未标定数据集,得到第一总类别数据集;其中,所述第一总类别数据集包括每个未标定数据的第一预测结果;分别用多个第一单类别船舶分类模型预测所述未标定数据集,得到第一单类别数据集;其中,所述第一单类别数据集包括每个未标定数据的第二预测结果;将所述第一总类别数据集中的第一预测结果与所述第一单类别数据集中的第二预测结果融合,得到对应有第三预测结果的未标定数据的集合,记为第一伪标签训练集;
第一预测模块,用于提取所述多个第一单类别船舶分类模型的参数,并分别对相应的单类别船舶分类网络进行赋值,将所述第一伪标签训练集与标定数据集输入至相应的单类别船舶分类网络进行模型训练,得到多个第二单类别船舶分类模型;分别用所述多个第二单类别船舶分类模型对所述未标定数据集进行预测,得到第二单类别数据集;其中,所述第二单类别数据集包括每个未标定数据集的第四预测结果;
第二预测模块,用于将所述第一总类别数据集中的第一预测结果与所述第二单类别数据集中的第四预测结果融合,得到对应有第五预测结果的未标定数据的集合,记为第二伪标签训练集;提取所述第一总类别船舶分类模型的参数,并对总类别船舶分类网络进行赋值,将所述第二伪标签训练集与标定数据集输入至总类别船舶分类网络进行模型训练,得到第二总类别船舶分类模型;用所述第二总类别船舶分类模型对所述未标定数据集进行预测,得到第二总类别数据集;其中,所述第二总类别数据集包括每个未标定数据集的第六预测结果;
数据筛选模块,用于将所述第二总类别数据集中的第六预测结果与第二单类别数据集中的第四预测结果进行融合,得到对应有第七预测结果的未标定数据的集合,记为第三伪标签训练集;用所述第三伪标签训练集及标定数据集作为目标训练集分别对所述多个第二单类别船舶分类模型及第二总类别船舶分类模型进行训练,并在训练过程中删除奇异伪标签数据,得到第四伪标签训练集、多个第三单类别船舶分类模型及第三总类别船舶分类模型;
最终船舶分类模型生成模块,用于用所述多个第三单类别船舶分类模型对所述第四伪标签训练集进行预测,得到第三单类别数据集;其中,所述第三单类别数据集包括每个未标定数据的第八预测结果;用所述第三总类别船舶分类模型对所述第四伪标签训练集进行预测,得到第三总类别数据集;其中,所述第三总类别数据集包括每个未标定数据的第九预测结果;将所述第三单类别数据集中的第八预测结果与所述第三总类别数据集中的第九预测结果融合,得到对应有第十预测结果的未标定数据的集合,记为第五伪标签训练集;采用所述第五伪标签训练集对所述第三总类别船舶分类模型进行训练,得到最终船舶分类模型。
9.根据权利要求8所述的半监督船舶分类模型训练装置,其特征在于,所述融合模块包括:
第一识别子模块,用于识别所述第一预测结果中每个未标定数据对应的第一类别及第一置信度;
第二识别子模块,用于识别所述第二预测结果中每个未标定数据对应的第二类别及第二置信度;
计算子模块,用于根据每个未标定数据对应的第一类别及第二类别,计算类别赋值权重,根据所述类别赋值权重、第一置信度及第二置信度,计算得到第三置信度。
10.根据权利要求9所述的半监督船舶分类模型训练装置,其特征在于,所述第一预测模块包括:
第一采样子模块,用于按第一比例在所述第一伪标签训练集中抽取第一伪标签训练子集,按第二比例在所述标定数据集中抽取第一标定训练子集,将所述第一伪标签训练子集和第一标定训练子集作为该次迭代的第一训练样本集;
第一单次迭代子模块,用于采用所述第一训练样本集对所述单类别船舶分类网络进行训练,并计算该次迭代的损失值;
第一网络更新子模块,用于根据该次迭代的损失值对所述单类别船舶分类网络进行反向求偏导,根据偏导结果对所述单类别船舶分类网络的参数进行修正;
第一循环子模块,用于控制重复第一采样子模块、第一单次迭代子模块及第一网络更新子模块的操作,直至所述单类别船舶分类网络收敛,得到对应的第二单类别船舶分类模型。
11.根据权利要求10所述的半监督船舶分类模型训练装置,其特征在于,所述第二预测模块包括:
第二采样子模块,用于按第一比例在所述第二伪标签训练集中抽取第二伪标签训练子集,按第二比例在所述标定数据集中抽取第二标定训练子集,将所述第二伪标签训练子集和第二标定训练子集作为该次迭代的第二训练样本集;
第二单次迭代子模块,用于采用所述第二训练样本集对所述总类别船舶分类网络进行训练,得到该次迭代的损失值;
第二网络更新子模块,用于根据该次迭代的损失值对所述总类别船舶分类网络进行反向求偏导,根据偏导结果对所述总类别船舶分类网络的参数进行修正;
第二循环子模块,用于控制重复第二采样子模块、第二单次迭代子模块及第二网络更新子模块的操作,直至所述总类别船舶分类网络收敛,得到所述第二总类别船舶分类模型。
12.根据权利要求11所述的半监督船舶分类模型训练装置,其特征在于,所述数据筛选模块包括:
第一奇异伪标签数据标记子模块,用于采用所述目标训练集分别对所述多个第二单类别船舶分类模型进行训练,依次记录每轮训练生成的中间单类别船舶分类模型对未标定数据集进行预测的预测结果,记为第一变化清单,选取所述第一变化清单中,连续预设轮数类别均不一致,和/或置信度连续变小的数据,将其标记为奇异伪标签数据;
第二奇异伪标签数据标记子模块,用于采用所述目标训练集对所述第二总类别船舶分类模型进行训练,依次记录每轮训练生成的中间总类别船舶分类模型对未标定数据集预测的预测结果,记为第二变化清单,选取所述第二变化清单中,连续预设轮数类别均不一致,和/或置信度连续变小的数据,将其标记为奇异伪标签数据;
数据剔除子模块,用于在所述第三伪标签训练集中删除所述奇异伪标签数据,得到第四伪标签训练集。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210721409.7A CN114998691B (zh) | 2022-06-24 | 2022-06-24 | 半监督船舶分类模型训练方法及装置 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210721409.7A CN114998691B (zh) | 2022-06-24 | 2022-06-24 | 半监督船舶分类模型训练方法及装置 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN114998691A CN114998691A (zh) | 2022-09-02 |
CN114998691B true CN114998691B (zh) | 2023-04-18 |
Family
ID=83036565
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202210721409.7A Active CN114998691B (zh) | 2022-06-24 | 2022-06-24 | 半监督船舶分类模型训练方法及装置 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN114998691B (zh) |
Families Citing this family (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN115359062B (zh) * | 2022-10-24 | 2023-01-24 | 浙江华是科技股份有限公司 | 通过半监督实例分割标定监控目标的方法及系统 |
CN115620155B (zh) * | 2022-12-19 | 2023-03-10 | 浙江华是科技股份有限公司 | 一种变电站目标检测方法、系统及计算机存储介质 |
CN117152587B (zh) * | 2023-10-27 | 2024-01-26 | 浙江华是科技股份有限公司 | 一种基于对抗学习的半监督船舶检测方法及系统 |
CN117557477B (zh) * | 2024-01-09 | 2024-04-05 | 浙江华是科技股份有限公司 | 一种船舶去雾复原方法及系统 |
Citations (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN113705769A (zh) * | 2021-05-17 | 2021-11-26 | 华为技术有限公司 | 一种神经网络训练方法以及装置 |
CN114186615A (zh) * | 2021-11-22 | 2022-03-15 | 浙江华是科技股份有限公司 | 船舶检测半监督在线训练方法、装置及计算机存储介质 |
Family Cites Families (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN111382758B (zh) * | 2018-12-28 | 2023-12-26 | 杭州海康威视数字技术股份有限公司 | 训练图像分类模型、图像分类方法、装置、设备及介质 |
CN112183577A (zh) * | 2020-08-31 | 2021-01-05 | 华为技术有限公司 | 一种半监督学习模型的训练方法、图像处理方法及设备 |
CN113269267B (zh) * | 2021-06-15 | 2024-04-26 | 苏州挚途科技有限公司 | 目标检测模型的训练方法、目标检测方法和装置 |
-
2022
- 2022-06-24 CN CN202210721409.7A patent/CN114998691B/zh active Active
Patent Citations (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN113705769A (zh) * | 2021-05-17 | 2021-11-26 | 华为技术有限公司 | 一种神经网络训练方法以及装置 |
CN114186615A (zh) * | 2021-11-22 | 2022-03-15 | 浙江华是科技股份有限公司 | 船舶检测半监督在线训练方法、装置及计算机存储介质 |
Also Published As
Publication number | Publication date |
---|---|
CN114998691A (zh) | 2022-09-02 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN114998691B (zh) | 半监督船舶分类模型训练方法及装置 | |
CN107067025A (zh) | 一种基于主动学习的数据自动标注方法 | |
WO2022057671A1 (zh) | 一种基于神经网络的知识图谱不一致性推理方法 | |
CN111985325B (zh) | 特高压环境评价中的航拍小目标快速识别方法 | |
CN112149547A (zh) | 基于图像金字塔引导和像素对匹配的遥感影像水体识别 | |
CN113037783B (zh) | 一种异常行为检测方法及系统 | |
CN114863091A (zh) | 一种基于伪标签的目标检测训练方法 | |
Mahasin et al. | Comparison of cspdarknet53, cspresnext-50, and efficientnet-b0 backbones on yolo v4 as object detector | |
CN116484024A (zh) | 一种基于知识图谱的多层次知识库构建方法 | |
CN116225760A (zh) | 一种基于运维知识图谱的实时根因分析方法 | |
CN114758199A (zh) | 检测模型的训练方法、装置、设备和存储介质 | |
CN114298270A (zh) | 融合领域知识的污染物浓度预测方法及其相关设备 | |
CN112579777B (zh) | 一种未标注文本的半监督分类方法 | |
CN117058882A (zh) | 一种基于多特征双判别器的交通数据补偿方法 | |
CN115438190B (zh) | 一种配电网故障辅助决策知识抽取方法及系统 | |
CN116630277A (zh) | 一种基于持续学习的pcb板缺陷检测方法及装置 | |
CN114595770B (zh) | 一种船舶轨迹的长时序预测方法 | |
CN113821452B (zh) | 根据被测系统测试表现动态生成测试案例的智能测试方法 | |
CN115457305A (zh) | 一种半监督目标检测方法与系统 | |
CN114943741A (zh) | 一种动态场景下基于目标检测和几何概率的视觉slam方法 | |
CN112347826B (zh) | 一种基于强化学习的视频连续手语识别方法及系统 | |
CN114972429A (zh) | 云边协同自适应推理路径规划的目标追踪方法和系统 | |
CN117252851B (zh) | 一种基于图像检测识别的标准质量检测管理平台 | |
CN113434617B (zh) | 一种基于船舶轨迹的行为自动划分方法、系统及电子设备 | |
Wang | Semi-supervised Semantic Segmentation Network based on Knowledge Distillation |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |