CN114998691A - 半监督船舶分类模型训练方法及装置 - Google Patents

半监督船舶分类模型训练方法及装置 Download PDF

Info

Publication number
CN114998691A
CN114998691A CN202210721409.7A CN202210721409A CN114998691A CN 114998691 A CN114998691 A CN 114998691A CN 202210721409 A CN202210721409 A CN 202210721409A CN 114998691 A CN114998691 A CN 114998691A
Authority
CN
China
Prior art keywords
training
class
data
ship classification
ship
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Granted
Application number
CN202210721409.7A
Other languages
English (en)
Other versions
CN114998691B (zh
Inventor
吴显德
邹凡
欧阳志益
雷明根
赵立立
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Zhejiang Whyis Technology Co ltd
Original Assignee
Zhejiang Whyis Technology Co ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Zhejiang Whyis Technology Co ltd filed Critical Zhejiang Whyis Technology Co ltd
Priority to CN202210721409.7A priority Critical patent/CN114998691B/zh
Publication of CN114998691A publication Critical patent/CN114998691A/zh
Application granted granted Critical
Publication of CN114998691B publication Critical patent/CN114998691B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/77Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
    • G06V10/774Generating sets of training patterns; Bootstrap methods, e.g. bagging or boosting
    • G06V10/7753Incorporation of unlabelled data, e.g. multiple instance learning [MIL]
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/084Backpropagation, e.g. using gradient descent
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/764Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/82Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
    • YGENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
    • Y02TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
    • Y02ATECHNOLOGIES FOR ADAPTATION TO CLIMATE CHANGE
    • Y02A10/00TECHNOLOGIES FOR ADAPTATION TO CLIMATE CHANGE at coastal zones; at river basins
    • Y02A10/40Controlling or monitoring, e.g. of flood or hurricane; Forecasting, e.g. risk assessment or mapping

Abstract

本发明公开了一种半监督船舶分类模型训练方法及装置,其中,该方法包括:采用标定数据集分别对初始总类别船舶分类模型及多个初始单类别船舶分类模型进行模型训练,得到第一总类别船舶分类模型及多个第一单类别船舶分类模型;使用第一总类别船舶分类模型及多个第一单类别船舶分类模型分别赋值网络,并进行模型训练;使用训练得到的模型对最新的伪标签训练集进行预测,并从中删除奇异样本;将删除奇异样本的伪标签训练集作为训练集对总类别船舶分类模型进行训练,得到最终船舶分类模型。本发明降低了错误样本对模型的影响,提高模型的准确率。

Description

半监督船舶分类模型训练方法及装置
技术领域
本发明涉及船舶类型分类领域,尤其涉及一种半监督船舶分类模型训练方法及装置。
背景技术
随着水上交通管理任务不断增加和人工智能的快速发展,人工智能在水上管理业务的重要性迅速提升,其中船舶分类是实现水上交通管理自动化重要因素之一。但由于我国水路交通路线复杂且环境差别偏大,同一类型船舶在外观上区别较大且在部分水上路线数据集无法采集,由于这些原因,导致收集到的数据集具有局限性,无法覆盖我国所有水上交通路线背景和船舶类型信息。基于上述问题,目前很多算法工程师研发了半监督模型训练方法,但目前半监督船舶分类模型训练算法存下以下问题:
1)伪标签数据集由目前已存在船舶分类模型获得,不受其他模型监督;
2)没有挖掘船舶分类模型训练过程中的推理结果对伪标签训练集的监督。发明内容
为解决上述问题,本发明提供一种半监督船舶分类模型训练方法及装置,通过在模型训练中,对伪标签数据集中数据的置信度及类别的变化进行分析;另外,通过将总类别船舶分类模型差分成多个单类别船舶分类模型,并采用多个单类别船舶分类模型的推理结果对总类别船舶分类模型的分析结果进行监督,提高最终生成的船舶分类模型的准确率,以解决上述现有技术中的问题。
为达到上述目的,S1、获取初始总类别船舶分类模型、多个初始单类别船舶分类模型、标定数据集及未标定数据集;其中,所述总类别船舶分类模型用于分类所有船舶类型,单类别船舶分类模型用于分类一种船舶类型;S2、用标定数据集分别对所述初始总类别船舶分类模型及多个初始单类别船舶分类模型进行模型训练,得到第一总类别船舶分类模型及多个第一单类别船舶分类模型;S3、用所述第一总类别船舶分类模型预测所述未标定数据集,得到第一总类别数据集;其中,所述第一总类别数据集包括每个未标定数据的第一预测结果;分别用多个第一单类别船舶分类模型预测所述未标定数据集,得到第一单类别数据集;其中,所述第一单类别数据集包括每个未标定数据的第二预测结果;将所述第一总类别数据集中的第一预测结果与所述第一单类别数据集中的第二预测结果融合,得到对应有第三预测结果的未标定数据的集合,记为第一伪标签训练集;S4、提取所述多个第一单类别船舶分类模型的参数,并分别对相应的单类别船舶分类网络进行赋值,将所述第一伪标签训练集与标定数据集输入至相应的单类别船舶分类网络进行模型训练,得到多个第二单类别船舶分类模型;分别用所述多个第二单类别船舶分类模型对所述未标定数据集进行预测,得到第二单类别数据集;其中,所述第二单类别数据集包括每个未标定数据集的第四预测结果;S5、将所述第一总类别数据集中的第一预测结果与所述第二单类别数据集中的第四预测结果融合,得到对应有第五预测结果的未标定数据的集合,记为第二伪标签训练集;提取所述第一总类别船舶分类模型的参数,并对总类别船舶分类网络进行赋值,将所述第二伪标签训练集与标定数据集输入至总类别船舶分类网络进行模型训练,得到第二总类别船舶分类模型;用所述第二总类别船舶分类模型对所述未标定数据集进行预测,得到第二总类别数据集;其中,所述第二总类别数据集包括每个未标定数据集的第六预测结果;S6、将所述第二总类别数据集中的第六预测结果与第二单类别数据集中的第四预测结果进行融合,得到对应有第七预测结果的未标定数据的集合,记为第三伪标签训练集;用所述第三伪标签训练集及标定数据集作为目标训练集分别对所述多个第二单类别船舶分类模型及第二总类别船舶分类模型进行训练,并在训练过程中删除奇异伪标签数据,得到第四伪标签训练集、多个第三单类别船舶分类模型及第三总类别船舶分类模型;S7、用所述多个第三单类别船舶分类模型对所述第四伪标签训练集进行预测,得到第三单类别数据集;其中,所述第三单类别数据集包括每个未标定数据的第八预测结果;用所述第三总类别船舶分类模型对所述第四伪标签训练集进行预测,得到第三总类别数据集;其中,所述第三总类别数据集包括每个未标定数据的第九预测结果;将所述第三单类别数据集中的第八预测结果与所述第三总类别数据集中的第九预测结果融合,得到对应有第十预测结果的未标定数据的集合,记为第五伪标签训练集;采用所述第五伪标签训练集对所述第三总类别船舶分类模型进行训练,得到最终船舶分类模型。
进一步可选的,所述将所述第一总类别数据集中的第一预测结果与所述第一单类别数据集中的第二预测结果融合,得到对应有第三预测结果的未标定数据的集合,记为第一伪标签训练集,包括:S301、识别所述第一预测结果中每个未标定数据对应的第一类别及第一置信度;S302、识别所述第二预测结果中每个未标定数据对应的第二类别及第二置信度;S303、根据每个未标定数据对应的第一类别及第二类别,计算类别赋值权重,根据所述类别赋值权重、第一置信度及第二置信度,计算得到第三置信度。
进一步可选的,所述根据每个未标定数据对应的第一类别及第二类别,计算类别赋值权重,通过以下公式计算:
Figure 32714DEST_PATH_IMAGE001
其中,mkij为第k个未标定数据的类别赋值权重,
Figure 60713DEST_PATH_IMAGE002
为第k个未标定数据第一类别,
Figure 304613DEST_PATH_IMAGE003
为第k个未标定数据的第二类别;
所述根据所述类别赋值权重、第一置信度及第二置信度,计算得到第三置信度,通过以下公式计算:
Figure 494285DEST_PATH_IMAGE004
其中,pk为第k个未标定数据的第三置信度,pki为第k个未标定数据的第一置信度,pkj为第k个未标定数据的第二置信度。
进一步可选的,所述将所述第一伪标签训练集与标定数据集输入至相应的单类别船舶分类网络进行模型训练,得到多个第二单类别船舶分类模型,其中,任一单类别船舶分类网络进行模型训练包括:S401、按第一比例在所述第一伪标签训练集中抽取第一伪标签训练子集,按第二比例在所述标定数据集中抽取第一标定训练子集,将所述第一伪标签训练子集和第一标定训练子集作为该次迭代的第一训练样本集;S402、采用所述第一训练样本集对所述单类别船舶分类网络进行训练,并计算该次迭代的损失值;S403、根据该次迭代的损失值对所述单类别船舶分类网络进行反向求偏导,根据偏导结果对所述单类别船舶分类网络的参数进行修正;S404、重复S401-S403,直至所述单类别船舶分类网络收敛,得到对应的第二单类别船舶分类模型。
进一步可选的,所述将所述第二伪标签训练集与标定数据集输入至总类别船舶分类网络进行模型训练,得到第二总类别船舶分类模型,包括:S501、按第一比例在所述第二伪标签训练集中抽取第二伪标签训练子集,按第二比例在所述标定数据集中抽取第二标定训练子集,将所述第二伪标签训练子集和第二标定训练子集作为该次迭代的第二训练样本集;S502、采用所述第二训练样本集对所述总类别船舶分类网络进行训练,得到该次迭代的损失值;S503、根据该次迭代的损失值对所述总类别船舶分类网络进行反向求偏导,根据偏导结果对所述总类别船舶分类网络的参数进行修正;S504、重复S501-503,直至所述总类别船舶分类网络收敛,得到所述第二总类别船舶分类模型。
进一步可选的,所述用所述第三伪标签训练集及标定数据集作为目标训练集分别对所述多个第二单类别船舶分类模型及第二总类别船舶分类模型进行训练,并在训练过程中删除奇异伪标签数据,得到第四伪标签训练集,包括:S601、采用所述目标训练集分别对所述多个第二单类别船舶分类模型进行训练,依次记录每轮训练生成的中间单类别船舶分类模型对未标定数据集进行预测的预测结果,记为第一变化清单,选取所述第一变化清单中,连续预设轮数类别均不一致,和/或置信度连续变小的数据,将其标记为奇异伪标签数据;S602、采用所述目标训练集对所述第二总类别船舶分类模型进行训练,依次记录每轮训练生成的中间总类别船舶分类模型对未标定数据集预测的预测结果,记为第二变化清单,选取所述第二变化清单中,连续预设轮数类别均不一致,和/或置信度连续变小的数据,将其标记为奇异伪标签数据;S603、在所述第三伪标签训练集中删除所述奇异伪标签数据,得到第四伪标签训练集。
进一步可选的,采用以下公式计算该次迭代的损失值:
Figure 456425DEST_PATH_IMAGE005
其中,loss为每次迭代的损失值,
Figure 22536DEST_PATH_IMAGE006
为每次迭代对应的伪标签训练子集中的数据量,Nlabel为每次迭代对应的标定训练子集的数据量,Pi为第i个未标定数据对应的第三置信度,lossi为第i个未标定数据的预测类别与伪标签的损失值,lossj表示第j个标定数据的预测类别与原标定类别的损失值。
另一方面,本发明还提供了一种半监督船舶分类模型训练装置,包括:获取模块,用于获取初始总类别船舶分类模型、多个初始单类别船舶分类模型、标定数据集及未标定数据集;其中,所述总类别船舶分类模型用于分类所有船舶类型,单类别船舶分类模型用于分类一种船舶类型;训练模块,用于用标定数据集分别对所述初始总类别船舶分类模型及多个初始单类别船舶分类模型进行模型训练,得到第一总类别船舶分类模型及多个第一单类别船舶分类模型;融合模块,用于用所述第一总类别船舶分类模型预测所述未标定数据集,得到第一总类别数据集;其中,所述第一总类别数据集包括每个未标定数据的第一预测结果;分别用多个第一单类别船舶分类模型预测所述未标定数据集,得到第一单类别数据集;其中,所述第一单类别数据集包括每个未标定数据的第二预测结果;将所述第一总类别数据集中的第一预测结果与所述第一单类别数据集中的第二预测结果融合,得到对应有第三预测结果的未标定数据的集合,记为第一伪标签训练集;第一预测模块,用于提取所述多个第一单类别船舶分类模型的参数,并分别对相应的单类别船舶分类网络进行赋值,将所述第一伪标签训练集与标定数据集输入至相应的单类别船舶分类网络进行模型训练,得到多个第二单类别船舶分类模型;分别用所述多个第二单类别船舶分类模型对所述未标定数据集进行预测,得到第二单类别数据集;其中,所述第二单类别数据集包括每个未标定数据集的第四预测结果;第二预测模块,将所述第二总类别数据集中的第六预测结果与第二单类别数据集中的第四预测结果进行融合,得到对应有第七预测结果的未标定数据的集合,记为第三伪标签训练集;用所述第三伪标签训练集及标定数据集作为目标训练集分别对所述多个第二单类别船舶分类模型及第二总类别船舶分类模型进行训练,并在训练过程中删除奇异伪标签数据,得到第四伪标签训练集、多个第三单类别船舶分类模型及第三总类别船舶分类模型;数据筛选模块,用于将所述第二总类别数据集中的第六预测结果与第二单类别数据集中的第四预测结果进行融合,得到对应有第七预测结果的未标定数据的集合,记为第三伪标签训练集;用所述第三伪标签训练集及标定数据集作为目标训练集分别对所述多个第二单类别船舶分类模型及第二总类别船舶分类模型进行训练,并在训练过程中删除奇异伪标签数据,得到第四伪标签训练集、多个第三单类别船舶分类模型及第三总类别船舶分类模型;最终船舶分类模型生成模块,用于用所述多个第三单类别船舶分类模型对所述第四伪标签训练集进行预测,得到第三单类别数据集;其中,所述第三单类别数据集包括每个未标定数据的第八预测结果;用所述第三总类别船舶分类模型对所述第四伪标签训练集进行预测,得到第三总类别数据集;其中,所述第三总类别数据集包括每个未标定数据的第九预测结果;将所述第三单类别数据集中的第八预测结果与所述第三总类别数据集中的第九预测结果融合,得到对应有第十预测结果的未标定数据的集合,记为第五伪标签训练集;采用所述第五伪标签训练集对所述第三总类别船舶分类模型进行训练,得到最终船舶分类模型。
进一步可选的,所述融合模块包括:第一识别子模块,用于识别所述第一预测结果中每个未标定数据对应的第一类别及第一置信度;第二识别子模块,用于识别所述第二预测结果中每个未标定数据对应的第二类别及第二置信度;计算子模块,用于根据每个未标定数据对应的第一类别及第二类别,计算类别赋值权重,根据所述类别赋值权重、第一置信度及第二置信度,计算得到第三置信度。
进一步可选的,所述第一预测模块包括:第一采样子模块,用于按第一比例在所述第一伪标签训练集中抽取第一伪标签训练子集,按第二比例在所述标定数据集中抽取第一标定训练子集,将所述第一伪标签训练子集和第一标定训练子集作为该次迭代的第一训练样本集;第一单次迭代子模块,用于采用所述第一训练样本集对所述单类别船舶分类网络进行训练,并计算该次迭代的损失值;第一网络更新子模块,用于根据该次迭代的损失值对所述单类别船舶分类网络进行反向求偏导,根据偏导结果对所述单类别船舶分类网络的参数进行修正;第一循环子模块,用于控制重复第一采样子模块、第一单次迭代子模块及第一网络更新子模块的操作,直至所述单类别船舶分类网络收敛,得到对应的第二单类别船舶分类模型。
进一步可选的,所述第二预测模块包括:第二采样子模块,用于按第一比例在所述第二伪标签训练集中抽取第二伪标签训练子集,按第二比例在所述标定数据集中抽取第二标定训练子集,将所述第二伪标签训练子集和第二标定训练子集作为该次迭代的第二训练样本集;第二单次迭代子模块、采用所述第二训练样本集对所述总类别船舶分类网络进行训练,得到该次迭代的损失值;第二网络更新子模块,用于根据该次迭代的损失值对所述总类别船舶分类网络进行反向求偏导,根据偏导结果对所述总类别船舶分类网络的参数进行修正;第二循环子模块,用于控制重复第二采样子模块、第二单次迭代子模块及第二网络更新子模块的操作,直至所述总类别船舶分类网络收敛,得到所述第二总类别船舶分类模型。
进一步可选的,所述数据筛选模块包括:第一奇异伪标签数据标记子模块,用于采用所述目标训练集分别对所述多个第二单类别船舶分类模型进行训练,依次记录每轮训练生成的中间单类别船舶分类模型对未标定数据集进行预测的预测结果,记为第一变化清单,选取所述第一变化清单中,连续预设轮数类别均不一致,和/或置信度连续变小的数据,将其标记为奇异伪标签数据;第二奇异伪标签数据标记子模块,用于采用所述目标训练集对所述第二总类别船舶分类模型进行训练,依次记录每轮训练生成的中间总类别船舶分类模型对未标定数据集预测的预测结果,记为第二变化清单,选取所述第二变化清单中,连续预设轮数类别均不一致,和/或置信度连续变小的数据,将其标记为奇异伪标签数据;数据剔除子模块,用于在所述第三伪标签训练集中删除所述奇异伪标签数据,得到第四伪标签训练集。
上述技术方案具有如下有益效果:在以往半监督分类模型中增加其他模型监督,使伪标签数据集的标签每一轮都会受到其他船舶单类别分类模型监督,得到更准确类别的伪标签数据集;伪标签数据集准确率受自身船舶类别分类模型推理历史结果监督,挖掘每轮训练推理的类别和置信度,减少错位伪标签数据对模型训练的影响;每次迭代增加正确数据集的张数,从而降低错误样本对损失值的占比,减少模型反向传播中错误样本的作用;对伪标签损失函数增加权重参数,增加自身模型推理结果和其他类别模型推理结果监督,如果伪标签的分类结果错误,那么自身推理的置信度和其他模型推理的结果中的置信度一般偏低,从而降低错误样本的对模型的影响。
附图说明
为了更清楚地说明本发明实施例或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1是本发明实施例提供的半监督船舶分类模型训练方法的流程图;
图2是本发明实施例提供的第一伪标签训练集生成方法的流程图;
图3是本发明实施例提供的第二单类别船舶分类模型生成方法的流程图;
图4是本发明实施例提供的第二总类别船舶分类模型生成方法的流程图;
图5是本发明实施例提供的伪标签训练集更新方法的流程图;
图6是本发明实施例提供的半监督船舶分类模型训练装置的结构示意图;
图7是本发明实施例提供的融合模块的结构示意图;
图8是本发明实施例提供的第一预测模块的结构示意图;
图9是本发明实施例提供的第二预测模块的结构示意图;
图10是本发明实施例提供的数据筛选模块的结构示意图。
附图标记:100-获取模块 200-训练模块 300-融合模块 3001-第一识别子模块3002-第二识别子模块 3003-计算子模块 400-第一预测模块 4001-第一采样子模块4002-第一单次迭代子模块 4003-第一网络更新子模块 4004-第一循环子模块 500-第二预测模块 5001-第二采样子模块 5002-第二单次迭代子模块 5003-第二网络更新子模块 5004-第二循环子模块 600-数据筛选模块 6001-第一奇异伪标签数据标记子模块6002-第二奇异伪标签数据标记子模块 6003-数据剔除子模块 700-最终船舶分类模型生成模块。
具体实施方式
下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
由于我国水上交通线路复杂和船舶类型内间差异大,导致无法对所有水上交通背景和船舶类型进行收集,另外,目前已存在半监督模型训练方法存在不受其他模型监督和伪标签训练存在一定概率错误样本等弊端。众所周知,深度学习分类模型对数据集类别的标定十分的依赖,如果训练集不准确,将会引导船舶分类模型向错误方向学习,导致模型的准确性降低,错误率增加。
因此为减少学习错误样本对模型的影响和提高伪标签训练集的准确性,本发明提供了一种半监督船舶分类模型训练方法,图1是本发明实施例提供的半监督船舶分类模型训练方法的流程图,如图1所示,该方法包括:
S1、获取初始总类别船舶分类模型、多个初始单类别船舶分类模型、标定数据集及未标定数据集;其中,总类别船舶分类模型用于分类所有船舶类型,单类别船舶分类模型用于分类一种船舶类型;
初始总类别船舶分类模型可以识别多种类别的船舶,例如,初始总类别船舶分类模型可以对快艇、管装船、货船、液压船及帆船五种类别的船舶进行分类。
每个初始单类别船舶分类模型用于分类初始总类别船舶分类模型中的其中一种船舶类型。例如,若初始总类别船舶分类模型可以对快艇、管装船、货船、液压船及帆船五种类别的船舶进行分类,那么该初始总类别船舶分类模型则对应有五个初始单类别船舶分类模型:快艇分类模型,其只能对快艇进行分类,其它类型船舶识别为背景;管装船分类模型,其只能对管装船进行分类,其它类型船舶识别为背景;货船分类模型,其只能对货船进行分类,其它类型船舶识别为背景;液压船分类模型,其只能对液压船进行分类,其它类型的船舶识别为背景;帆船分类模型,其只能对帆船进行分类,其它类型的船舶识别为背景。
标定数据集提前由人工进行打标,标定数据集中的每个标定数据均为一张船舶图片,且带有人工标记的原标签。例如,一张快艇的图片,其标签(本实施例中指类别)由人工标定为快艇。
未标定数据集为没有标签的数据集合,未标定数据集中的每个未标定数据可能为一张船舶图片,也可能为不存在船舶的纯背景图片。例如,一张快艇的图片,但其不对应任何标签(本实施例中指类别)。
S2、用标定数据集分别对初始总类别船舶分类模型及多个初始单类别船舶分类模型进行模型训练,得到第一总类别船舶分类模型及多个第一单类别船舶分类模型;
标定数据集为准确的数据集,采用标定数据集对初始总类别船舶分类模型进行模型训练,得到第一总类别船舶分类模型;
另外,标定数据集还对每个初始单类别船舶分类模型进行模型训练,得到每个初始单类别船舶分类模型对应的第一单类别船舶分类模型。
作为一种可选的实施方式,标定数据集对从类别船舶分类模型的训练、和对多个初始单类别船舶分类模型的训练可以同时训练,也可以依次进行训练,当然,也可以同时对其中几个分类模型进行训练。本实施例对同一时间训练的模型数量不做限制。
S3、用第一总类别船舶分类模型预测未标定数据集,得到第一总类别数据集;其中,第一总类别数据集包括每个未标定数据的第一预测结果;分别用多个第一单类别船舶分类模型预测未标定数据集,得到第一单类别数据集;其中,第一单类别数据集包括每个未标定数据的第二预测结果;将第一总类别数据集中的第一预测结果与第一单类别数据集中的第二预测结果融合,得到对应有第三预测结果的未标定数据的集合,记为第一伪标签训练集;
采用第一总类别船舶分类模型对未标定数据集中的每个未标定数据(本实施例中数据指图片)进行推理,得到每个未标定数据的第一预测结果,这些未标定数据的集合记为第一总类别数据集。其中,第一预测结果包括每个未标定数据的类别与置信度。
采用多个第一单类别船舶分类模型对未标定数据集中的每个未标定数据进行推理,得到其中每个未标定数据的第二预测结果,这些未标定数据的集合记为第一单类别数据集。其中,第二预测结果包括每个未标定数据的类别与置信度。具体的,首先选取其中一个第一单类别船舶分类模型对未标定数据集中的每个未标定数据进行预测,之后再选取另一个第一单类别船舶分类模型对未标定数据集中的每个未标定数据进行预测,重复上述步骤,直至所有单类别船舶分类模型均对未标定数据集进行了预测,由于每个第一单类别船舶分类模型仅能分类一种类别,因此,所有的第一单类别船舶分类模型可对未标定数据集中的每个未标定数据进行全面的预测。
此时,对于每个未标定数据来说均对应一个第一预测结果及第二预测结果,为实现后续的数据处理,需将第一预测结果与第二预测结果融合,得到每个未标定数据的第三预测结果,所有对应有第三预测结果的未标定数据的集合记为第一伪标签训练集。其中,第三预测结果包括每个未标定数据集的类别及置信度。
本实施例中,由于单类别船舶分类模型的识别准确率高于总类别船舶分类模型的识别准确率,因此使用多个单类别船舶分类模型分别进行后续的预测操作,可以使单类别船舶分类模型的预测结果对总类别船舶分类模型的预测结果产生监督作用。
S4、提取多个第一单类别船舶分类模型的参数,并分别对相应的单类别船舶分类网络进行赋值,将第一伪标签训练集与标定数据集输入至相应的单类别船舶分类网络进行模型训练,得到多个第二单类别船舶分类模型;分别用多个第二单类别船舶分类模型对未标定数据集进行预测,得到第二单类别数据集;其中,第二单类别数据集包括每个未标定数据集的第四预测结果;
将一个第一单类别船舶分类模型作为预训练模型,提取第一单类别船舶分类模型的参数并对相应单类别船舶分类网络进行赋值。使用第一伪标签训练集和标定数据集作为训练集,对单类别船舶分类网络进行模型训练,得到对应的第二单类别船舶分类模型。
对每个第一单类别船舶分类模型均重复上述步骤,得到多个第二单类别船舶分类模型。
分别使用训练得到的多个第二单类别船舶分类模型对未标定数据集中的每个未标定数据进行预测,得到每个未标定数据的第四预测结果,这些未标定数据的集合记为第二单类别数据集。其中,第四预测结果包括未标定数据的类别及置信度。
S5、将第一总类别数据集中的第一预测结果与第二单类别数据集中的第四预测结果融合,得到对应有第五预测结果的未标定数据的集合,记为第二伪标签训练集;提取第一总类别船舶分类模型的参数,并对总类别船舶分类网络进行赋值,将第二伪标签训练集与标定数据集输入至总类别船舶分类网络进行模型训练,得到第二总类别船舶分类模型;用第二总类别船舶分类模型对未标定数据集进行预测,得到第二总类别数据集;其中,第二总类别数据集包括每个未标定数据集的第六预测结果;
在总类别船舶分类网络进行模型训练之前,以每个未标定数据为主体,将其对应的第一预测结果与第四预测结果进行融合,得到每个未标定数据的第五预测结果,所有对应有第五预测结果的未标定数据的集合记为第二伪标签训练集。其中,第五预测结果包括每个未标定数据的类别及置信度。
将第一总类别船舶分类模型作为预训练模型,提取第一总类别船舶分类模型的参数并对总类别船舶分类网络进行赋值。使用第二伪标签训练集和标定数据集作为训练集,对总类别船舶分类网络进行模型训练,得到的第二总类别船舶分类模型。
使用训练得到的第二总类别船舶分类模型对未标定数据集中的每个未标定数据进行预测,得到每个未标定数据的第六预测结果,这些未标定数据的集合记为第二总类别数据集。其中,第六预测结果包括未标定数据的类别及置信度。
S6、将第二总类别数据集中的第六预测结果与第二单类别数据集中的第四预测结果进行融合,得到对应有第七预测结果的未标定数据的集合,记为第三伪标签训练集;用第三伪标签训练集及标定数据集作为目标训练集分别对多个第二单类别船舶分类模型及第二总类别船舶分类模型进行训练,并在训练过程中删除奇异伪标签数据,得到第四伪标签训练集、多个第三单类别船舶分类模型及第三总类别船舶分类模型;
以每个未标定数据为主体,将其对应的第六预测结果与第四预测结果进行融合,得到每个未标定数据的第七预测结果,所有对应有第七预测结果的未标定数据的集合记为第三伪标签训练集。其中,第七预测结果包括每个未标定数据的类别及置信度。
采用第三伪标签训练集及标定数据集作为目标训练集,对一个第二单类别船舶分类模型进行训练,每轮训练均记录未标定数据对应的预测结果,根据多轮训练的预测结果变化情况,选取出多轮预测结果差异过大的数据,记为奇异伪标签数据,并从第三伪标签训练集中删除该数据。训练完成后得到对应的第三单类别船舶分类模型。
对每个第二单类别船舶分类模型均重复上述步骤,直至根据每个第二单类别船舶分类模型训练过程筛选出奇异伪标签数据,并将其从第三伪标签训练集中删除。得到多个第三单类别船舶分类模型。
采用第三伪标签训练集及标定数据集作为目标训练集,对第二总类别船舶分类模型进行训练,每轮训练均记录未标定数据对应的预测结果,根据多轮训练的预测结果变化情况,选取出多轮预测结果差异过大的数据,并从第三伪标签训练集中删除该数据。训练完成后得到对应的第三总类别船舶分类模型。
将删除了所有奇异伪标签数据的第三伪标签训练集记为第四伪标签训练集。
根据模型训练中每轮的推理结果找出奇异样本,排除错误样本对模型的负面影响。
S7、用多个第三单类别船舶分类模型对第四伪标签训练集进行预测,得到第三单类别数据集;其中,第三单类别数据集包括每个未标定数据的第八预测结果;用第三总类别船舶分类模型对第四伪标签训练集进行预测,得到第三总类别数据集;其中,第三总类别数据集包括每个未标定数据的第九预测结果;将第三单类别数据集中的第八预测结果与第三总类别数据集中的第九预测结果融合,得到对应有第十预测结果的未标定数据的集合,记为第五伪标签训练集;采用第五伪标签训练集对第三总类别船舶分类模型进行训练,得到最终船舶分类模型。
分别使用训练得到的多个第三单类别船舶分类模型对第四伪标签训练集中每个数据进行预测,得到每个数据的第八预测结果,这些数据的集合记为第三单类别数据集。其中,第八预测结果包括数据的类别及置信度。
使用第三总类别船舶分类模型对第四伪标签训练集中的每个数据进行预测,得到每个数据的第九预测结果,这些数据的集合记为第三总类别数据集。其中,第九预测结果包括数据的类别及置信度。
以每个数据为主体,将其对应的第八预测结果与第九预测结果进行融合,得到每个数据的第十预测结果,所有对应有第十预测结果的数据集合记为第五伪标签训练集。其中,第十预测结果包括每个数据的类别及置信度。
采用第五伪标签训练集对第三总类别船舶分类模型进行训练,直至模型收敛,得到最终船舶分类模型。
作为一种可选的实施方式,图2是本发明实施例提供的第一伪标签训练集生成方法的流程图,如图2所示,将第一总类别数据集中的第一预测结果与第一单类别数据集中的第二预测结果融合,得到对应有第三预测结果的未标定数据的集合,记为第一伪标签训练集,包括:
S301、识别第一预测结果中每个未标定数据对应的第一类别及第一置信度;
第一类别及第一置信度均由第一总类别船舶分类模型对未标定数据预测得到。
S302、识别第二预测结果中每个未标定数据对应的第二类别及第二置信度;
第二类别及第二置信度均由多个第一单类别船舶分类模型对未标定数据预测得到。
S303、根据每个未标定数据对应的第一类别及第二类别,计算类别赋值权重,根据类别赋值权重、第一置信度及第二置信度,计算得到第三置信度。
两种预测结果的融合时,首先通过第一类别与第二类别计算类别赋值权重,再根据第一置信度、第二置信度及类别赋值权重,得到未标定数据对应的第三置信度。
作为一种具体的实施方式,根据每个未标定数据对应的第一类别及第二类别,计算类别赋值权重,通过以下公式计算:
Figure 887986DEST_PATH_IMAGE001
其中,mkij为第k个未标定数据的类别赋值权重,
Figure 248560DEST_PATH_IMAGE002
为第k个未标定数据第一类别,
Figure 697996DEST_PATH_IMAGE003
为第k个未标定数据的第二类别;
若每个未标定数据的第一类别与第二类别一致,则将类别赋值权重为1;若每个未标定数据的第一类别与第二类别不一致,则将类别赋值权重为0。
根据类别赋值权重、第一置信度及第二置信度,计算得到第三置信度,通过以下公式计算:
Figure 67797DEST_PATH_IMAGE004
其中,pk为第k个未标定数据的第三置信度,pki为第k个未标定数据的第一置信度,pkj为第k个未标定数据的第二置信度。
两个预测结果融合时,若第一类别与第二类别不一致,则置信度也为0,即不参与候选的数据处理;若第一类别与第二类别一致,则置信度不为0,根据第一置信度与第二置信度得到对应未标定数据的第三置信度。
该步骤实现了单类别船舶分类模型的监督,如果总类别分类模型与单类别分类模型的预测类别不一致,置为0;类别一致,置为1。采用第三类别赋值置信度权重,防止错误样本对模型训练的影响。
作为一种可选的实施方式,图3是本发明实施例提供的第二单类别船舶分类模型生成方法的流程图,如图3所示,将第一伪标签训练集与标定数据集输入至相应的单类别船舶分类网络进行模型训练,得到多个第二单类别船舶分类模型,其中,任一单类别船舶分类网络进行模型训练,包括:
S401、按第一比例在第一伪标签训练集中抽取第一伪标签训练子集,按第二比例在标定数据集中抽取第一标定训练子集,将第一伪标签训练子集和第一标定训练子集作为该次迭代的第一训练样本集;
本实施例提供一种新的数据采样器,将训练集分为两个文件,分别为人工标定的训练集及伪标签数据训练集。
每次迭代的第一训练样本集为N张,作为一种可选的实施例,第一比例为
Figure 20710DEST_PATH_IMAGE007
,第二比例为
Figure 817765DEST_PATH_IMAGE008
。也即,每次迭代从标定数据集中抽取
Figure 488917DEST_PATH_IMAGE008
张图片,记为第一标定训练子集;从第一伪标签训练集中抽取
Figure 662410DEST_PATH_IMAGE007
张图片,作为第一伪标签训练子集,将二者组成该次迭代所需的包含有N张图片的第一训练样本集。
为提高模型训练效率,减少错误率,本实施例中,抽取每次迭代的第一伪标签训练子集时,从第一伪标签训练集中抽取置信度大于0的
Figure 64021DEST_PATH_IMAGE007
张图片。
本实施例提供的采样方法,目的是为了模型每一步迭代都有标定数据和伪标签数据参与,降低伪标签的影响,提高人工打标的标签比重,既可以让模型学习到伪标签信息,又降低伪标签错误信息的影响。
S402、采用第一训练样本集对单类别船舶分类网络进行训练,并计算该次迭代的损失值;
使用得到的第一训练样本集对单类别船舶分类网络进行模型训练,并计算该次迭代的损失值。
S403、根据该次迭代的损失值对单类别船舶分类网络进行反向求偏导,根据偏导结果对单类别船舶分类网络的参数进行修正;
根据该次迭代的损失值对单类别船舶分类网络进行反向求偏导,依据偏导结果对网络中的参数进行修正,使单类别船舶分类网络更改后的向正确方向学习。
S404、重复S401-S403,直至单类别船舶分类网络收敛,得到对应的第二单类别船舶分类模型。
重复以上操作,即每次迭代使用上述采样方法获取该次迭代的训练样本集,并采用训练样本集对当前模型网络进行反向传播,更新网络参数,直至所有的数据集遍历完成,标志完成一轮模型训练;进行多轮训练,直至网络收敛,此时得到第二单类别船舶分类模型。
作为一种可选的实施方式,当该次迭代的损失值与上次迭代的损失值差距范围在0.003时,说明损失值基本趋于稳定,标志模型收敛。
对每个第一单类别船舶分类模型均采用上述操作,完成每个第一单类别船舶分类模型的模型训练,得到多个单类别船舶分类模型。
作为一种可选的实施方式,图4是本发明实施例提供的第二总类别船舶分类模型生成方法的流程图,如图4所示,将第二伪标签训练集与标定数据集输入至总类别船舶分类网络进行模型训练,得到第二总类别船舶分类模型,包括:
S501、按第一比例在第二伪标签训练集中抽取第二伪标签训练子集,按第二比例在标定数据集中抽取第二标定训练子集,将第二伪标签训练子集和第二标定训练子集作为该次迭代的第二训练样本集;
本实施例提供一种新的数据采样器,将训练集分为两个文件,分别为人工标定的训练集及伪标签数据训练集。
每次迭代的第一训练样本集为N张,作为一种可选的实施例,第一比例为
Figure 31977DEST_PATH_IMAGE007
,第二比例为
Figure 190426DEST_PATH_IMAGE008
。也即,每次迭代从标定数据集中抽取
Figure 902030DEST_PATH_IMAGE008
张图片,记为第二标定训练子集;从第二伪标签训练集中抽取
Figure 829535DEST_PATH_IMAGE007
张图片,作为第二伪标签训练子集,将二者组成该次迭代所需的包含有N张图片的第二训练样本集。
为提高模型训练效率,减少错误率,本实施例中,抽取每次迭代的第二伪标签训练子集时,从第二伪标签训练集中抽取置信度大于0的
Figure 968392DEST_PATH_IMAGE007
张图片。
本实施例提供的采样方法,目的是为了模型每一步迭代都有标定数据和伪标签数据参与,降低伪标签的影响,提高人工打标的标签比重,既可以让模型学习到伪标签信息,又降低伪标签错误信息的影响。
S502、采用第二训练样本集对所述总类别船舶分类网络进行训练,得到该次迭代的损失值;
使用得到的第二训练样本集对总类别船舶分类网络进行模型训练,并计算该次迭代的损失值。
S503、根据该次迭代的损失值对总类别船舶分类网络进行反向求偏导,依据偏导结果对总类别船舶分类网络中的参数进行修正;
根据该次迭代的损失值对总类别船舶分类网络进行反向求偏导,依据偏导结果对网络中的参数进行修正,使总类别船舶分类网络更改后的向正确方向学习。
S504、重复S501-503,直至单类别船舶分类网络收敛,得到第二总类别船舶分类模型。
重复以上操作,即每次迭代使用上述采样方法获取该次迭代的训练样本集,并采用训练样本集对当前模型网络进行反向传播,更新网络参数,直至所有的数据集遍历完成,标志完成一轮模型训练;进行多轮训练,直至网络收敛,此时得到第二总类别船舶分类模型。
作为一种可选的实施方式,当该次迭代的损失值与上次迭代的损失值差距范围在0.003时,说明损失值基本趋于稳定,标志模型收敛。
作为一种可选的实施方式,图5是本发明实施例提供的伪标签训练集更新方法的流程图,如图5所示,用第三伪标签训练集及标定数据集作为目标训练集分别对多个第二单类别船舶分类模型及第二总类别船舶分类模型进行训练,并在训练过程中删除奇异伪标签数据,得到第四伪标签训练集,包括:
S601、采用目标训练集分别对多个第二单类别船舶分类模型进行训练,依次记录每轮训练生成的中间单类别船舶分类模型对未标定数据集进行预测的预测结果,记为第一变化清单,选取第一变化清单中,连续预设轮数类别均不一致,和/或置信度连续变小的数据,将其标记为奇异伪标签数据;
以一个第二单类别船舶分类模型为例,采用以上采样方法作为每次迭代的采样方法,从第三伪标签训练集及标定数据集中进行采样,得到每次迭代的目标训练子集,采用目标训练子集完成一次迭代,当遍历完成第三伪标签训练集和标定数据集时,完成一轮训练。每轮训练后会得到一个中间单类别船舶分类模型,使用该中间单类别船舶分类模型对未标定数据集进行预测,会得到未标定数据集的预测结果。直至模型收敛,该模型会进行多轮训练,多轮训练后,会得到未标定数据的多轮预测结果,将这些预测结果记录进第一变化清单。
每个第二单类别船舶分类模型均进行以上操作,得到每个未标定数据的多轮预测结果,这些预测结果记录进第一变化清单。
第一变化清单中记录有每个未标定数据每轮的置信度及类别,通过在训练过程中分析连续预设轮数置信度和类别的变化,筛选出奇异伪标签数据。
作为一种具体的实施方式,预设轮数为n=5,当前轮数为k,分析第一变化清单,如果满足k-n>0,对k-n,……,k-3,k-2,k-1,k预测结果进行分析,如果当前伪标签训练集中的伪标签近n轮的类别都不一样或置信度连续变小,则将该数据标记为奇异伪标签数据,不参与下一轮单类别船舶分类模型训练。
S602、采用目标训练集对第二总类别船舶分类模型进行训练,依次记录每轮训练生成的中间总类别船舶分类模型对未标定数据集预测的预测结果,记为第二变化清单,选取第二变化清单中,连续预设轮数类别均不一致,和/或置信度连续变小的数据,将其标记为奇异伪标签数据;
采用以上采样方法作为每次迭代的采样方法,从第三伪标签训练集及标定数据集中进行采样,得到每次迭代的目标训练子集,采用目标训练子集完成一次迭代,当遍历完成第三伪标签训练集和标定数据集时,完成一轮训练。每轮训练后会得到一个中间总类别船舶分类模型,使用该中间单类别船舶分类模型对未标定数据集进行预测,会得到未标定数据集的预测结果。直至模型收敛,该模型会进行多轮训练,多轮训练后,会得到未标定数据的多轮预测结果,将这些预测结果记录进第二变化清单。
第二变化清单中记录有每个未标定数据每轮的置信度及类别,通过在训练过程中分析连续预设轮数置信度和类别的变化,筛选出奇异伪标签数据。
作为一种具体的实施方式,预设轮数为n=5,当前轮数为k,分析第二变化清单,如果满足k-n>0,对k-n,……,k-3,k-2,k-1,k预测结果进行分析,如果当前伪标签训练集中的伪标签近n轮的类别都不一样或置信度连续变小,则将该数据标记为奇异伪标签数据,不参与下一轮总类别船舶分类模型训练。
S603、在第三伪标签训练集中删除奇异伪标签数据,得到第四伪标签训练集。
从第三伪标签训练集中删除以上标记的奇异伪标签数据,得到更新后的第四伪标签训练集。
作为一种可选的实施方式,采用以下公式计算该次迭代的损失值:
Figure 614137DEST_PATH_IMAGE005
其中,loss为每次迭代的损失值,
Figure 863853DEST_PATH_IMAGE006
为每次迭代对应的伪标签训练子集中的数据量,Nlabel为每次迭代对应的标定训练子集的数据量,Pi为第i个未标定数据对应的第三置信度,lossi为第i个未标定数据的预测类别与伪标签的损失值,lossj表示第j个标定数据的预测类别与原标定类别的损失值。
具体的,在单类别船舶分类模型的损失值计算中,loss为每次迭代的损失值;
Figure 412908DEST_PATH_IMAGE006
为每次迭代的第一伪标签训练子集的数据量;Nlabel为每次迭代的第一标定训练子集的数据量;Pi为第i个未标定数据对应的第三置信度,即上述的Pk;lossi为第i个未标定数据通过当前单类别船舶分类模型预测得到的预测类别与为伪标签的损失值;lossj表示第j个标定数据通过当前单类别船舶分类模型预测得到的预测类别与原标定类别的损失值。
其中,通过当前单类别船舶分类模型预测得到的预测类别与伪标签的损失值、通过当前单类别船舶分类模型预测得到的预测类别与原标定类别的损失值可通过Sigmoid函数进行计算。
其中,伪标签为第一伪标签训练集中每个未标定数据对应的类别。
具体的,在总类别船舶分类模型的损失值计算中,loss为每次迭代的损失值;
Figure 722667DEST_PATH_IMAGE006
为每次迭代的第二伪标签训练子集的数据量;Nlabel为每次迭代的第二标定训练子集的数据量;Pi为第i个未标定数据对应的第三置信度,即上述的Pk;lossi为第i个未标定数据通过当前总类别船舶分类模型预测得到的预测类别与伪标签的损失值;lossj表示第j个标定数据通过当前总类别船舶分类模型预测得到的预测类别与原标定类别的损失值。
其中,当前总类别船舶分类模型预测得到的预测类别与伪标签的损失值、通过当前总类别船舶分类模型预测得到的预测类别与原标定类别的损失值可通过Sigmoid函数进行计算。
其中,伪标签为第二伪标签训练集中每个未标定数据对应的类别。
作为一种可选的实施方式,本发明实施例还提供了一种半监督船舶分类模型训练装置,图6是本发明实施例提供的半监督船舶分类模型训练装置的结构示意图,如图6所示,该装置包括:
获取模块100,用于获取初始总类别船舶分类模型、多个初始单类别船舶分类模型、标定数据集及未标定数据集;其中,总类别船舶分类模型用于分类所有船舶类型,单类别船舶分类模型用于分类一种船舶类型;
初始总类别船舶分类模型可以识别多种类别的船舶,例如,初始总类别船舶分类模型可以对快艇、管装船、货船、液压船及帆船五种类别的船舶进行分类。
每个初始单类别船舶分类模型用于分类初始总类别船舶分类模型中的其中一种船舶类型。例如,若初始总类别船舶分类模型可以对快艇、管装船、货船、液压船及帆船五种类别的船舶进行分类,那么该初始总类别船舶分类模型则对应有五个初始单类别船舶分类模型:快艇分类模型,其只能对快艇进行分类,其它类型船舶识别为背景;管装船分类模型,其只能对管装船进行分类,其它类型船舶识别为背景;货船分类模型,其只能对货船进行分类,其它类型船舶识别为背景;液压船分类模型,其只能对液压船进行分类,其它类型的船舶识别为背景;帆船分类模型,其只能对帆船进行分类,其它类型的船舶识别为背景。
标定数据集提前由人工进行打标,标定数据集中的每个标定数据均为一张船舶图片,且带有人工标记的原标签。例如,一张快艇的图片,其标签(本实施例中指类别)由人工标定为快艇。
未标定数据集为没有标签的数据集合,未标定数据集中的每个未标定数据可能为一张船舶图片,也可能为不存在船舶的纯背景图片。例如,一张快艇的图片,但其不对应任何标签(本实施例中指类别)。
训练模块200,用于用标定数据集分别对初始总类别船舶分类模型及多个初始单类别船舶分类模型进行模型训练,得到第一总类别船舶分类模型及多个第一单类别船舶分类模型;
标定数据集为准确的数据集,采用标定数据集对初始总类别船舶分类模型进行模型训练,得到第一总类别船舶分类模型;
另外,标定数据集还对每个初始单类别船舶分类模型进行模型训练,得到每个初始单类别船舶分类模型对应的第一单类别船舶分类模型。
作为一种可选的实施方式,标定数据集对从类别船舶分类模型的训练、和对多个初始单类别船舶分类模型的训练可以同时训练,也可以依次进行训练,当然,也可以同时对其中几个分类模型进行训练。本实施例对同一时间训练的模型数量不做限制。
融合模块300,用于用第一总类别船舶分类模型预测未标定数据集,得到第一总类别数据集;其中,第一总类别数据集包括每个未标定数据的第一预测结果;分别用多个第一单类别船舶分类模型预测未标定数据集,得到第一单类别数据集;其中,第一单类别数据集包括每个未标定数据的第二预测结果;将第一总类别数据集中的第一预测结果与第一单类别数据集中的第二预测结果融合,得到对应有第三预测结果的未标定数据的集合,记为第一伪标签训练集;
采用第一总类别船舶分类模型对未标定数据集中的每个未标定数据(本实施例中数据指图片)进行推理,得到每个未标定数据的第一预测结果,这些未标定数据的集合记为第一总类别数据集。其中,第一预测结果包括每个未标定数据的类别与置信度。
采用多个第一单类别船舶分类模型对未标定数据集中的每个未标定数据进行推理,得到其中每个未标定数据的第二预测结果,这些未标定数据的集合记为第一单类别数据集。其中,第二预测结果包括每个未标定数据的类别与置信度。具体的,首先选取其中一个第一单类别船舶分类模型对未标定数据集中的每个未标定数据进行预测,之后再选取另一个第一单类别船舶分类模型对未标定数据集中的每个未标定数据进行预测,重复上述步骤,直至所有单类别船舶分类模型均对未标定数据集进行了预测,由于每个第一单类别船舶分类模型仅能分类一种类别,因此,所有的第一单类别船舶分类模型可对未标定数据集中的每个未标定数据进行全面的预测。
此时,对于每个未标定数据来说均对应一个第一预测结果及第二预测结果,为实现后续的数据处理,需将第一预测结果与第二预测结果融合,得到每个未标定数据的第三预测结果,所有对应有第三预测结果的未标定数据的集合记为第一伪标签训练集。其中,第三预测结果包括每个未标定数据集的类别及置信度。
本实施例中,由于单类别船舶分类模型的识别准确率高于总类别船舶分类模型的识别准确率,因此使用多个单类别船舶分类模型分别进行后续的预测操作,可以使单类别船舶分类模型的预测结果对总类别船舶分类模型的预测结果产生监督作用。
第一预测模块400,用于提取多个第一单类别船舶分类模型的参数,并分别对相应的单类别船舶分类网络进行赋值,将第一伪标签训练集与标定数据集输入至相应的单类别船舶分类网络进行模型训练,得到多个第二单类别船舶分类模型;分别用多个第二单类别船舶分类模型对未标定数据集进行预测,得到第二单类别数据集;其中,第二单类别数据集包括每个未标定数据集的第四预测结果;
将一个第一单类别船舶分类模型作为预训练模型,提取第一单类别船舶分类模型的参数并对相应单类别船舶分类网络进行赋值。使用第一伪标签训练集和标定数据集作为训练集,对单类别船舶分类网络进行模型训练,得到对应的第二单类别船舶分类模型。
对每个第一单类别船舶分类模型均重复上述步骤,得到多个第二单类别船舶分类模型。
分别使用训练得到的多个第二单类别船舶分类模型对未标定数据集中的每个未标定数据进行预测,得到每个未标定数据的第四预测结果,这些未标定数据的集合记为第二单类别数据集。其中,第四预测结果包括未标定数据的类别及置信度。
第二预测模块500,将第二总类别数据集中的第六预测结果与第二单类别数据集中的第四预测结果进行融合,得到对应有第七预测结果的未标定数据的集合,记为第三伪标签训练集;用第三伪标签训练集及标定数据集作为目标训练集分别对多个第二单类别船舶分类模型及第二总类别船舶分类模型进行训练,并在训练过程中删除奇异伪标签数据,得到第四伪标签训练集、多个第三单类别船舶分类模型及第三总类别船舶分类模型;
在总类别船舶分类网络进行模型训练之前,以每个未标定数据为主体,将其对应的第一预测结果与第四预测结果进行融合,得到每个未标定数据的第五预测结果,所有对应有第五预测结果的未标定数据的集合记为第二伪标签训练集。其中,第五预测结果包括每个未标定数据的类别及置信度。
将第一总类别船舶分类模型作为预训练模型,提取第一总类别船舶分类模型的参数并对总类别船舶分类网络进行赋值。使用第二伪标签训练集和标定数据集作为训练集,对总类别船舶分类网络进行模型训练,得到的第二总类别船舶分类模型。
使用训练得到的第二总类别船舶分类模型对未标定数据集中的每个未标定数据进行预测,得到每个未标定数据的第六预测结果,这些未标定数据的集合记为第二总类别数据集。其中,第六预测结果包括未标定数据的类别及置信度。
数据筛选模块600,用于将第二总类别数据集中的第六预测结果与第二单类别数据集中的第四预测结果进行融合,得到对应有第七预测结果的未标定数据的集合,记为第三伪标签训练集;用第三伪标签训练集及标定数据集作为目标训练集分别对多个第二单类别船舶分类模型及第二总类别船舶分类模型进行训练,并在训练过程中删除奇异伪标签数据,得到第四伪标签训练集、多个第三单类别船舶分类模型及第三总类别船舶分类模型;
以每个未标定数据为主体,将其对应的第六预测结果与第四预测结果进行融合,得到每个未标定数据的第七预测结果,所有对应有第七预测结果的未标定数据的集合记为第三伪标签训练集。其中,第七预测结果包括每个未标定数据的类别及置信度。
采用第三伪标签训练集及标定数据集作为目标训练集,对一个第二单类别船舶分类模型进行训练,每轮训练均记录未标定数据对应的预测结果,根据多轮训练的预测结果变化情况,选取出多轮预测结果差异过大的数据,记为奇异伪标签数据,并从第三伪标签训练集中删除该数据。训练完成后得到对应的第三单类别船舶分类模型。
对每个第二单类别船舶分类模型均重复上述步骤,直至根据每个第二单类别船舶分类模型训练过程筛选出奇异伪标签数据,并将其从第三伪标签训练集中删除。得到多个第三单类别船舶分类模型。
采用第三伪标签训练集及标定数据集作为目标训练集,对第二总类别船舶分类模型进行训练,每轮训练均记录未标定数据对应的预测结果,根据多轮训练的预测结果变化情况,选取出多轮预测结果差异过大的数据,并从第三伪标签训练集中删除该数据。训练完成后得到对应的第三总类别船舶分类模型
将删除了所有奇异伪标签数据的第三伪标签训练集记为第四伪标签训练集。
根据模型训练中每轮的推理结果找出奇异样本,排除错误样本对模型的负面影响。
最终船舶分类模型生成模块700,用于用多个第三单类别船舶分类模型对第四伪标签训练集进行预测,得到第三单类别数据集;其中,第三单类别数据集包括每个未标定数据的第八预测结果;用第三总类别船舶分类模型对第四伪标签训练集进行预测,得到第三总类别数据集;其中,第三总类别数据集包括每个未标定数据的第九预测结果;将第三单类别数据集中的第八预测结果与第三总类别数据集中的第九预测结果融合,得到对应有第十预测结果的未标定数据的集合,记为第五伪标签训练集;采用第五伪标签训练集对第三总类别船舶分类模型进行训练,得到最终船舶分类模型。
分别使用训练得到的多个第三单类别船舶分类模型对第四伪标签训练集中每个数据进行预测,得到每个数据的第八预测结果,这些数据的集合记为第三单类别数据集。其中,第八预测结果包括数据的类别及置信度。
使用第三总类别船舶分类模型对第四伪标签训练集中的每个数据进行预测,得到每个数据的第九预测结果,这些数据的集合记为第三总类别数据集。其中,第九预测结果包括数据的类别及置信度。
以每个数据为主体,将其对应的第八预测结果与第九预测结果进行融合,得到每个数据的第十预测结果,所有对应有第十预测结果的数据集合记为第五伪标签训练集。其中,第十预测结果包括每个数据的类别及置信度。
采用第五伪标签训练集对第三总类别船舶分类模型进行训练,直至模型收敛,得到最终船舶分类模型。
作为一种可选的实施方式,图7是本发明实施例提供的融合模块的结构示意图,如图7所示,融合模块300包括:
第一识别子模块3001,用于识别第一预测结果中每个未标定数据对应的第一类别及第一置信度;
第一类别及第一置信度均由第一总类别船舶分类模型对未标定数据预测得到。
第二识别子模块3002,用于识别第二预测结果中每个未标定数据对应的第二类别及第二置信度;
第二类别及第二置信度均由多个第一单类别船舶分类模型对未标定数据预测得到。
计算子模块3003,用于根据每个未标定数据对应的第一类别及第二类别,计算类别赋值权重,根据类别赋值权重、第一置信度及第二置信度,计算得到第三置信度。
两种预测结果的融合时,首先通过第一类别与第二类别计算类别赋值权重,再根据第一置信度、第二置信度及类别赋值权重,得到未标定数据对应的第三置信度。
作为一种可选的实施方式,根据每个未标定数据对应的第一类别及第二类别,计算类别赋值权重,通过以下公式计算:
Figure 855708DEST_PATH_IMAGE001
其中,mkij为第k个未标定数据的类别赋值权重,
Figure 909115DEST_PATH_IMAGE002
为第k个未标定数据第一类别,
Figure 811212DEST_PATH_IMAGE003
为第k个未标定数据的第二类别;
若每个未标定数据的第一类别与第二类别一致,则将类别赋值权重为1;若每个未标定数据的第一类别与第二类别不一致,则将类别赋值权重为0。
根据类别赋值权重、第一置信度及第二置信度,计算得到第三置信度,通过以下公式计算:
Figure 291871DEST_PATH_IMAGE004
其中,pk为第k个未标定数据的第三置信度,pki为第k个未标定数据的第一置信度,pkj为第k个未标定数据的第二置信度。
两个预测结果融合时,若第一类别与第二类别不一致,则置信度也为0,即不参与候选的数据处理;若第一类别与第二类别一致,则置信度不为0,根据第一置信度与第二置信度得到对应未标定数据的第三置信度。
该步骤实现了单类别船舶分类模型的监督,如果总类别分类模型与单类别分类模型的预测类别不一致,置为0;类别一致,置为1。采用第三类别赋值置信度权重,防止错误样本对模型训练的影响。
作为一种可选的实施方式,图8是本发明实施例提供的第一预测模块的结构示意图,如图8所示,第一预测模块400,包括:
第一采样子模块4001,用于按第一比例在第一伪标签训练集中抽取第一伪标签训练子集,按第二比例在标定数据集中抽取第一标定训练子集,将第一伪标签训练子集和第一标定训练子集作为该次迭代的第一训练样本集;
本实施例提供一种新的数据采样器,将训练集分为两个文件,分别为人工标定的训练集及伪标签数据训练集。
每次迭代的第一训练样本集为N张,作为一种可选的实施例,第一比例为
Figure 646629DEST_PATH_IMAGE007
,第二比例为
Figure 503727DEST_PATH_IMAGE008
。也即,每次迭代从标定数据集中抽取
Figure 758866DEST_PATH_IMAGE008
张图片,记为第一标定训练子集;从第一伪标签训练集中抽取
Figure 410427DEST_PATH_IMAGE007
张图片,作为第一伪标签训练子集,将二者组成该次迭代所需的包含有N张图片的第一训练样本集。
为提高模型训练效率,减少错误率,本实施例中,抽取每次迭代的第一伪标签训练子集时,从第一伪标签训练集中抽取置信度大于0的
Figure 518060DEST_PATH_IMAGE007
张图片。
本实施例提供的采样方法,目的是为了模型每一步迭代都有标定数据和伪标签数据参与,降低伪标签的影响,提高人工打标的标签比重,既可以让模型学习到伪标签信息,又降低伪标签错误信息的影响。
第一单次迭代子模块4002,用于采用第一训练样本集对单类别船舶分类网络进行训练,并计算该次迭代的损失值;
使用得到的第一训练样本集对单类别船舶分类网络进行模型训练,并计算该次迭代的损失值。
第一网络更新子模块4003,用于根据该次迭代的损失值对单类别船舶分类网络进行反向求偏导,根据偏导结果对单类别船舶分类网络的参数进行修正;
根据该次迭代的损失值对单类别船舶分类网络进行反向求偏导,依据偏导结果对网络中的参数进行修正,使单类别船舶分类网络更改后的向正确方向学习。
第一循环子模块4004,用于控制重复第一采样子模块、第一单次迭代子模块及第一网络更新子模块的操作,直至单类别船舶分类网络收敛,得到对应的第二单类别船舶分类模型。
重复以上操作,即每次迭代使用上述采样方法获取该次迭代的训练样本集,并采用训练样本集对当前模型网络进行反向传播,更新网络参数,直至所有的数据集遍历完成,标志完成一轮模型训练;进行多轮训练,直至网络收敛,此时得到第二单类别船舶分类模型。
作为一种可选的实施方式,当该次迭代的损失值与上次迭代的损失值的差距范围在0.003时,说明损失值基本趋于稳定,标志模型收敛。
对每个第一单类别船舶分类模型均采用上述操作,完成每个第一单类别船舶分类模型的模型训练,得到多个单类别船舶分类模型。
作为一种可选的实施方式,图9是本发明实施例提供的第二预测模块的结构示意图,如图9所示,第二预测模块500,包括:
第二采样子模块5001,用于按第一比例在第二伪标签训练集中抽取第二伪标签训练子集,按第二比例在标定数据集中抽取第二标定训练子集,将第二伪标签训练子集和第二标定训练子集作为该次迭代的第二训练样本集;
本实施例提供一种新的数据采样器,将训练集分为两个文件,分别为人工标定的训练集及伪标签数据训练集。
每次迭代的第一训练样本集为N张,作为一种可选的实施例,第一比例为
Figure 913269DEST_PATH_IMAGE007
,第二比例为
Figure 462062DEST_PATH_IMAGE008
。也即,每次迭代从标定数据集中抽取
Figure 612421DEST_PATH_IMAGE008
张图片,记为第二标定训练子集;从第二伪标签训练集中抽取
Figure 879454DEST_PATH_IMAGE007
张图片,作为第二伪标签训练子集,将二者组成该次迭代所需的包含有N张图片的第二训练样本集。
为提高模型训练效率,减少错误率,本实施例中,抽取每次迭代的第二伪标签训练子集时,从第二伪标签训练集中抽取置信度大于0的
Figure 140671DEST_PATH_IMAGE007
张图片。
本实施例提供的采样方法,目的是为了模型每一步迭代都有标定数据和伪标签数据参与,降低伪标签的影响,提高人工打标的标签比重,既可以让模型学习到伪标签信息,又降低伪标签错误信息的影响。
第二单次迭代子模块5002,采用第二训练样本集对总类别船舶分类网络进行训练,得到该次迭代的损失值;
使用得到的第二训练样本集对总类别船舶分类网络进行模型训练,并计算该次迭代的损失值。
第二网络更新子模块5003,用于根据该次迭代的损失值对总类别船舶分类网络进行反向求偏导,根据偏导结果对总类别船舶分类网络的参数进行修正;
根据该次迭代的损失值对总类别船舶分类网络进行反向求偏导,依据偏导结果对网络中的参数进行修正,使总类别船舶分类网络更改后的向正确方向学习。
第二循环子模块5004,用于控制重复第二采样子模块、第二单次迭代子模块及第二网络更新子模块的操作,直至总类别船舶分类网络收敛,得到第二总类别船舶分类模型。
重复以上操作,即每次迭代使用上述采样方法获取该次迭代的训练样本集,并采用训练样本集对当前模型网络进行反向传播,更新网络参数,直至所有的数据集遍历完成,标志完成一轮模型训练;进行多轮训练,直至网络收敛,此时得到第二总类别船舶分类模型。
作为一种可选的实施方式,当该次迭代的损失值与上次迭代的损失值的差距范围在0.003时,说明损失值基本趋于稳定,标志模型收敛。
作为一种可选的实施方式,图10是本发明实施例提供的数据筛选模块的结构示意图,如图9所示,数据筛选模块600,包括:
第一奇异伪标签数据标记子模块6001,用于采用目标训练集分别对多个第二单类别船舶分类模型进行训练,依次记录每轮训练生成的中间单类别船舶分类模型对未标定数据集进行预测的预测结果,记为第一变化清单,选取第一变化清单中,连续预设轮数类别均不一致,和/或置信度连续变小的数据,将其标记为奇异伪标签数据;
以一个第二单类别船舶分类模型为例,采用以上采样方法作为每次迭代的采样方法,从第三伪标签训练集及标定数据集中进行采样,得到每次迭代的目标训练子集,采用目标训练子集完成一次迭代,当遍历完成第三伪标签训练集和标定数据集时,完成一轮训练。每轮训练后会得到一个中间单类别船舶分类模型,使用该中间单类别船舶分类模型对未标定数据集进行预测,会得到未标定数据集的预测结果。多轮训练后,会得到未标定数据的多轮预测结果,将这些预测结果记录进第一变化清单。
每个第二单类别船舶分类模型均进行以上操作,得到每个未标定数据的多轮预测结果,这些预测结果记录进第一变化清单。
第一变化清单中记录有每个未标定数据每轮的置信度及类别,通过在训练过程中分析连续预设轮数置信度和类别的变化,筛选出奇异伪标签数据。
作为一种具体的实施方式,预设轮数为n=5,当前轮数为k,分析第一变化清单,如果满足k-n>0,对k-n,……,k-3,k-2,k-1,k预测结果进行分析,如果当前伪标签训练集中的伪标签近n轮的类别都不一样或置信度连续变小,则将该数据标记为奇异伪标签数据,不参与下一轮单类别船舶分类模型训练。
第二奇异伪标签数据标记子模块6002,用于采用目标训练集对第二总类别船舶分类模型进行训练,依次记录每轮训练生成的中间总类别船舶分类模型对未标定数据集预测的预测结果,记为第二变化清单,选取第二变化清单中,连续预设轮数类别均不一致,和/或置信度连续变小的数据,将其标记为奇异伪标签数据;
采用以上采样方法作为每次迭代的采样方法,从第三伪标签训练集及标定数据集中进行采样,得到每次迭代的目标训练子集,采用目标训练子集完成一次迭代,当遍历完成第三伪标签训练集和标定数据集时,完成一轮训练。每轮训练后会得到一个中间总类别船舶分类模型,使用该中间单类别船舶分类模型对未标定数据集进行预测,会得到未标定数据集的预测结果。多轮训练后,会得到未标定数据的多轮预测结果,将这些预测结果记录进第二变化清单。
第二变化清单中记录有每个未标定数据每轮的置信度及类别,通过在训练过程中分析连续预设轮数置信度和类别的变化,筛选出奇异伪标签数据。
作为一种具体的实施方式,预设轮数为n=5,当前轮数为k,分析第二变化清单,如果满足k-n>0,对k-n,……,k-3,k-2,k-1,k预测结果进行分析,如果当前伪标签训练集中的伪标签近n轮的类别都不一样或置信度连续变小,则将该数据标记为奇异伪标签数据,不参与下一轮总类别船舶分类模型训练。
数据剔除子模块6003,用于在第三伪标签训练集中删除奇异伪标签数据,得到第四伪标签训练集。
从第三伪标签训练集中删除以上标记的奇异伪标签数据,得到更新后的第四伪标签训练集。
作为一种可选的实施方式,采用以下公式计算该次迭代的损失值:
Figure 543971DEST_PATH_IMAGE005
其中,loss为每次迭代的损失值,
Figure 366696DEST_PATH_IMAGE006
为每次迭代对应的伪标签训练子集中的数据量,Nlabel为每次迭代对应的标定训练子集的数据量,Pi为第i个未标定数据对应的第三置信度,lossi为第i个未标定数据的预测类别与伪标签的损失值,lossj表示第j个标定数据的预测类别与原标定类别的损失值。
具体的,在单类别船舶分类模型的损失值计算中,loss为每次迭代的损失值;
Figure 121025DEST_PATH_IMAGE006
为每次迭代的第一伪标签训练子集的数据量;Nlabel为每次迭代的第一标定训练子集的数据量;Pi为第i个未标定数据对应的第三置信度,即上述的Pk;lossi为第i个未标定数据通过当前单类别船舶分类模型预测得到的预测类别与为伪标签的损失值;lossj表示第j个标定数据通过当前单类别船舶分类模型预测得到的预测类别与原标定类别的损失值。
其中,通过当前单类别船舶分类模型预测得到的预测类别与伪标签的损失值、通过当前单类别船舶分类模型预测得到的预测类别与原标定类别的损失值可通过Sigmoid函数进行计算。
其中,伪标签为第一伪标签训练集中每个未标定数据对应的类别。
具体的,在总类别船舶分类模型的损失值计算中,loss为每次迭代的损失值;
Figure 920354DEST_PATH_IMAGE006
为每次迭代的第二伪标签训练子集的数据量;Nlabel为每次迭代的第二标定训练子集的数据量;Pi为第i个未标定数据对应的第三置信度,即上述的Pk;lossi为第i个未标定数据通过当前总类别船舶分类模型预测得到的预测类别与伪标签的损失值;lossj表示第j个标定数据通过当前总类别船舶分类模型预测得到的预测类别与原标定类别的损失值。
其中,当前总类别船舶分类模型预测得到的预测类别与伪标签的损失值、通过当前总类别船舶分类模型预测得到的预测类别与原标定类别的损失值可通过Sigmoid函数进行计算。
其中,伪标签为第二伪标签训练集中每个未标定数据对应的类别。
上述技术方案具有如下有益效果:基于推理结果分析和其他模型监督的半监督船舶类型模型训练方法,缓解了伪标签错误类别对分类模型错误引导,从损失函数和伪标签数据集类别的准确率进行优化。修改采样方法,每次迭代的训练集由较高比例的标定训练数据和较低比例的伪标签数据组成,用人工标定数据引导伪标签数据集,降低伪标签数据集错误类别影响;伪标签数据集损失权重,利用其他模型得到的推理结果和自身模型推理结果进行结果融合,得到损失权重,降低伪标签数据集错误类别的影响;选用准确率高的其他船舶类型分类模型作为监督,对伪标签数据预测得到伪标签类别;挖掘自身船舶类别分类模型每轮推理结果信息,通过类别和置信度筛选,得到准确率更高的伪标签训练集。
在采样方法上进行修改,增加每次迭代正确训练样本数量,在训练样本数量上降低错误类别样本对损失值影响;在损失函数上进行优化,增加伪标签数据集的权重,其权重由其他模型和自身模型结果融合得到,降低错误样本的影响;通过自身模型推理结果,对伪标签数据集的类别准确率进行分析;其他模型监督,提高伪标签数据集的准确率。
以上发明的具体实施方式,对本发明的目的、技术方案和有益效果进行了进一步详细说明,所应理解的是,以上内容仅为本发明的具体实施方式而已,并不用于限定本发明的保护范围,凡在本发明的精神和原则之内,所做的任何修改、等同替换、改进等,均应包含在本发明的保护范围之内。

Claims (12)

1.一种半监督船舶分类模型训练方法,其特征在于,包括:
S1、获取初始总类别船舶分类模型、多个初始单类别船舶分类模型、标定数据集及未标定数据集;其中,所述总类别船舶分类模型用于分类所有船舶类型,单类别船舶分类模型用于分类一种船舶类型;
S2、用标定数据集分别对所述初始总类别船舶分类模型及多个初始单类别船舶分类模型进行模型训练,得到第一总类别船舶分类模型及多个第一单类别船舶分类模型;
S3、用所述第一总类别船舶分类模型预测所述未标定数据集,得到第一总类别数据集;其中,所述第一总类别数据集包括每个未标定数据的第一预测结果;分别用多个第一单类别船舶分类模型预测所述未标定数据集,得到第一单类别数据集;其中,所述第一单类别数据集包括每个未标定数据的第二预测结果;将所述第一总类别数据集中的第一预测结果与所述第一单类别数据集中的第二预测结果融合,得到对应有第三预测结果的未标定数据的集合,记为第一伪标签训练集;
S4、提取所述多个第一单类别船舶分类模型的参数,并分别对相应的单类别船舶分类网络进行赋值,将所述第一伪标签训练集与标定数据集输入至相应的单类别船舶分类网络进行模型训练,得到多个第二单类别船舶分类模型;分别用所述多个第二单类别船舶分类模型对所述未标定数据集进行预测,得到第二单类别数据集;其中,所述第二单类别数据集包括每个未标定数据集的第四预测结果;
S5、将所述第一总类别数据集中的第一预测结果与所述第二单类别数据集中的第四预测结果融合,得到对应有第五预测结果的未标定数据的集合,记为第二伪标签训练集;提取所述第一总类别船舶分类模型的参数,并对总类别船舶分类网络进行赋值,将所述第二伪标签训练集与标定数据集输入至总类别船舶分类网络进行模型训练,得到第二总类别船舶分类模型;用所述第二总类别船舶分类模型对所述未标定数据集进行预测,得到第二总类别数据集;其中,所述第二总类别数据集包括每个未标定数据集的第六预测结果;
S6、将所述第二总类别数据集中的第六预测结果与第二单类别数据集中的第四预测结果进行融合,得到对应有第七预测结果的未标定数据的集合,记为第三伪标签训练集;用所述第三伪标签训练集及标定数据集作为目标训练集分别对所述多个第二单类别船舶分类模型及第二总类别船舶分类模型进行训练,并在训练过程中删除奇异伪标签数据,得到第四伪标签训练集、多个第三单类别船舶分类模型及第三总类别船舶分类模型;
S7、用所述多个第三单类别船舶分类模型对所述第四伪标签训练集进行预测,得到第三单类别数据集;其中,所述第三单类别数据集包括每个未标定数据的第八预测结果;用所述第三总类别船舶分类模型对所述第四伪标签训练集进行预测,得到第三总类别数据集;其中,所述第三总类别数据集包括每个未标定数据的第九预测结果;将所述第三单类别数据集中的第八预测结果与所述第三总类别数据集中的第九预测结果融合,得到对应有第十预测结果的未标定数据的集合,记为第五伪标签训练集;采用所述第五伪标签训练集对所述第三总类别船舶分类模型进行训练,得到最终船舶分类模型。
2.根据权利要求1所述的半监督船舶分类模型训练方法,其特征在于,所述将所述第一总类别数据集中的第一预测结果与所述第一单类别数据集中的第二预测结果融合,得到对应有第三预测结果的未标定数据的集合,记为第一伪标签训练集,包括:
S301、识别所述第一预测结果中每个未标定数据对应的第一类别及第一置信度;
S302、识别所述第二预测结果中每个未标定数据对应的第二类别及第二置信度;
S303、根据每个未标定数据对应的第一类别及第二类别,计算类别赋值权重,根据所述类别赋值权重、第一置信度及第二置信度,计算得到第三置信度。
3.根据权利要求2所述的半监督船舶分类模型训练方法,其特征在于,所述根据每个未标定数据对应的第一类别及第二类别,计算类别赋值权重,通过以下公式计算:
Figure DEST_PATH_IMAGE001
其中,mkij为第k个未标定数据的类别赋值权重,
Figure DEST_PATH_IMAGE002
为第k个未标定数据第一类别,
Figure DEST_PATH_IMAGE003
为第k个未标定数据的第二类别;
所述根据所述类别赋值权重、第一置信度及第二置信度,计算得到第三置信度,通过以下公式计算:
Figure DEST_PATH_IMAGE004
其中,pk为第k个未标定数据的第三置信度,pki为第k个未标定数据的第一置信度,pkj为第k个未标定数据的第二置信度。
4.根据权利要求2所述的半监督船舶分类模型训练方法,其特征在于,所述将所述第一伪标签训练集与标定数据集输入至相应的单类别船舶分类网络进行模型训练,得到多个第二单类别船舶分类模型,其中,任一单类别船舶分类网络进行模型训练包括:
S401、按第一比例在所述第一伪标签训练集中抽取第一伪标签训练子集,按第二比例在所述标定数据集中抽取第一标定训练子集,将所述第一伪标签训练子集和第一标定训练子集作为该次迭代的第一训练样本集;
S402、采用所述第一训练样本集对所述单类别船舶分类网络进行训练,并计算该次迭代的损失值;
S403、根据该次迭代的损失值对所述单类别船舶分类网络进行反向求偏导,根据偏导结果对所述单类别船舶分类网络的参数进行修正;
S404、重复S401-S403,直至所述单类别船舶分类网络收敛,得到对应的第二单类别船舶分类模型。
5.根据权利要求4所述的半监督船舶分类模型训练方法,其特征在于,所述将所述第二伪标签训练集与标定数据集输入至总类别船舶分类网络进行模型训练,得到第二总类别船舶分类模型,包括:
S501、按第一比例在所述第二伪标签训练集中抽取第二伪标签训练子集,按第二比例在所述标定数据集中抽取第二标定训练子集,将所述第二伪标签训练子集和第二标定训练子集作为该次迭代的第二训练样本集;
S502、采用所述第二训练样本集对所述总类别船舶分类网络进行训练,得到该次迭代的损失值;
S503、根据该次迭代的损失值对所述总类别船舶分类网络进行反向求偏导,根据偏导结果对所述总类别船舶分类网络的参数进行修正;
S504、重复S501-503,直至所述总类别船舶分类网络收敛,得到所述第二总类别船舶分类模型。
6.根据权利要求5所述的半监督船舶分类模型训练方法,其特征在于,所述用所述第三伪标签训练集及标定数据集作为目标训练集分别对所述多个第二单类别船舶分类模型及第二总类别船舶分类模型进行训练,并在训练过程中删除奇异伪标签数据,得到第四伪标签训练集,包括:
S601、采用所述目标训练集分别对所述多个第二单类别船舶分类模型进行训练,依次记录每轮训练生成的中间单类别船舶分类模型对未标定数据集进行预测的预测结果,记为第一变化清单,选取所述第一变化清单中,连续预设轮数类别均不一致,和/或置信度连续变小的数据,将其标记为奇异伪标签数据;
S602、采用所述目标训练集对所述第二总类别船舶分类模型进行训练,依次记录每轮训练生成的中间总类别船舶分类模型对未标定数据集预测的预测结果,记为第二变化清单,选取所述第二变化清单中,连续预设轮数类别均不一致,和/或置信度连续变小的数据,将其标记为奇异伪标签数据;
S603、在所述第三伪标签训练集中删除所述奇异伪标签数据,得到第四伪标签训练集。
7.根据权利要求4-5任一项所述的半监督船舶分类模型训练方法,其特征在于,采用以下公式计算该次迭代的损失值:
Figure DEST_PATH_IMAGE005
其中,loss为每次迭代的损失值,
Figure DEST_PATH_IMAGE006
为每次迭代对应的伪标签训练子集中的数据量,Nlabel为每次迭代对应的标定训练子集的数据量,Pi为第i个未标定数据对应的第三置信度,lossi为第i个未标定数据的预测类别与伪标签的损失值,lossj表示第j个标定数据的预测类别与原标定类别的损失值。
8.一种半监督船舶分类模型训练装置,其特征在于,包括:
获取模块,用于获取初始总类别船舶分类模型、多个初始单类别船舶分类模型、标定数据集及未标定数据集;其中,所述总类别船舶分类模型用于分类所有船舶类型,单类别船舶分类模型用于分类一种船舶类型;
训练模块,用于用标定数据集分别对所述初始总类别船舶分类模型及多个初始单类别船舶分类模型进行模型训练,得到第一总类别船舶分类模型及多个第一单类别船舶分类模型;
融合模块,用于用所述第一总类别船舶分类模型预测所述未标定数据集,得到第一总类别数据集;其中,所述第一总类别数据集包括每个未标定数据的第一预测结果;分别用多个第一单类别船舶分类模型预测所述未标定数据集,得到第一单类别数据集;其中,所述第一单类别数据集包括每个未标定数据的第二预测结果;将所述第一总类别数据集中的第一预测结果与所述第一单类别数据集中的第二预测结果融合,得到对应有第三预测结果的未标定数据的集合,记为第一伪标签训练集;
第一预测模块,用于提取所述多个第一单类别船舶分类模型的参数,并分别对相应的单类别船舶分类网络进行赋值,将所述第一伪标签训练集与标定数据集输入至相应的单类别船舶分类网络进行模型训练,得到多个第二单类别船舶分类模型;分别用所述多个第二单类别船舶分类模型对所述未标定数据集进行预测,得到第二单类别数据集;其中,所述第二单类别数据集包括每个未标定数据集的第四预测结果;
第二预测模块,将所述第二总类别数据集中的第六预测结果与第二单类别数据集中的第四预测结果进行融合,得到对应有第七预测结果的未标定数据的集合,记为第三伪标签训练集;用所述第三伪标签训练集及标定数据集作为目标训练集分别对所述多个第二单类别船舶分类模型及第二总类别船舶分类模型进行训练,并在训练过程中删除奇异伪标签数据,得到第四伪标签训练集、多个第三单类别船舶分类模型及第三总类别船舶分类模型;
数据筛选模块,用于将所述第二总类别数据集中的第六预测结果与第二单类别数据集中的第四预测结果进行融合,得到对应有第七预测结果的未标定数据的集合,记为第三伪标签训练集;用所述第三伪标签训练集及标定数据集作为目标训练集分别对所述多个第二单类别船舶分类模型及第二总类别船舶分类模型进行训练,并在训练过程中删除奇异伪标签数据,得到第四伪标签训练集、多个第三单类别船舶分类模型及第三总类别船舶分类模型;
最终船舶分类模型生成模块,用于用所述多个第三单类别船舶分类模型对所述第四伪标签训练集进行预测,得到第三单类别数据集;其中,所述第三单类别数据集包括每个未标定数据的第八预测结果;用所述第三总类别船舶分类模型对所述第四伪标签训练集进行预测,得到第三总类别数据集;其中,所述第三总类别数据集包括每个未标定数据的第九预测结果;将所述第三单类别数据集中的第八预测结果与所述第三总类别数据集中的第九预测结果融合,得到对应有第十预测结果的未标定数据的集合,记为第五伪标签训练集;采用所述第五伪标签训练集对所述第三总类别船舶分类模型进行训练,得到最终船舶分类模型。
9.根据权利要求8所述的半监督船舶分类模型训练装置,其特征在于,所述融合模块包括:
第一识别子模块,用于识别所述第一预测结果中每个未标定数据对应的第一类别及第一置信度;
第二识别子模块,用于识别所述第二预测结果中每个未标定数据对应的第二类别及第二置信度;
计算子模块,用于根据每个未标定数据对应的第一类别及第二类别,计算类别赋值权重,根据所述类别赋值权重、第一置信度及第二置信度,计算得到第三置信度。
10.根据权利要求9所述的半监督船舶分类模型训练装置,其特征在于,所述第一预测模块包括:
第一采样子模块,用于按第一比例在所述第一伪标签训练集中抽取第一伪标签训练子集,按第二比例在所述标定数据集中抽取第一标定训练子集,将所述第一伪标签训练子集和第一标定训练子集作为该次迭代的第一训练样本集;
第一单次迭代子模块,用于采用所述第一训练样本集对所述单类别船舶分类网络进行训练,并计算该次迭代的损失值;
第一网络更新子模块,用于根据该次迭代的损失值对所述单类别船舶分类网络进行反向求偏导,根据偏导结果对所述单类别船舶分类网络的参数进行修正;
第一循环子模块,用于控制重复第一采样子模块、第一单次迭代子模块及第一网络更新子模块的操作,直至所述单类别船舶分类网络收敛,得到对应的第二单类别船舶分类模型。
11.根据权利要求10所述的半监督船舶分类模型训练装置,其特征在于,所述第二预测模块包括:
第二采样子模块,用于按第一比例在所述第二伪标签训练集中抽取第二伪标签训练子集,按第二比例在所述标定数据集中抽取第二标定训练子集,将所述第二伪标签训练子集和第二标定训练子集作为该次迭代的第二训练样本集;
第二单次迭代子模块,用于采用所述第二训练样本集对所述总类别船舶分类网络进行训练,得到该次迭代的损失值;
第二网络更新子模块,用于根据该次迭代的损失值对所述总类别船舶分类网络进行反向求偏导,根据偏导结果对所述总类别船舶分类网络的参数进行修正;
第二循环子模块,用于控制重复第二采样子模块、第二单次迭代子模块及第二网络更新子模块的操作,直至所述总类别船舶分类网络收敛,得到所述第二总类别船舶分类模型。
12.根据权利要求11所述的半监督船舶分类模型训练装置,其特征在于,所述数据筛选模块包括:
第一奇异伪标签数据标记子模块,用于采用所述目标训练集分别对所述多个第二单类别船舶分类模型进行训练,依次记录每轮训练生成的中间单类别船舶分类模型对未标定数据集进行预测的预测结果,记为第一变化清单,选取所述第一变化清单中,连续预设轮数类别均不一致,和/或置信度连续变小的数据,将其标记为奇异伪标签数据;
第二奇异伪标签数据标记子模块,用于采用所述目标训练集对所述第二总类别船舶分类模型进行训练,依次记录每轮训练生成的中间总类别船舶分类模型对未标定数据集预测的预测结果,记为第二变化清单,选取所述第二变化清单中,连续预设轮数类别均不一致,和/或置信度连续变小的数据,将其标记为奇异伪标签数据;
数据剔除子模块,用于在所述第三伪标签训练集中删除所述奇异伪标签数据,得到第四伪标签训练集。
CN202210721409.7A 2022-06-24 2022-06-24 半监督船舶分类模型训练方法及装置 Active CN114998691B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202210721409.7A CN114998691B (zh) 2022-06-24 2022-06-24 半监督船舶分类模型训练方法及装置

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202210721409.7A CN114998691B (zh) 2022-06-24 2022-06-24 半监督船舶分类模型训练方法及装置

Publications (2)

Publication Number Publication Date
CN114998691A true CN114998691A (zh) 2022-09-02
CN114998691B CN114998691B (zh) 2023-04-18

Family

ID=83036565

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202210721409.7A Active CN114998691B (zh) 2022-06-24 2022-06-24 半监督船舶分类模型训练方法及装置

Country Status (1)

Country Link
CN (1) CN114998691B (zh)

Cited By (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115359062A (zh) * 2022-10-24 2022-11-18 浙江华是科技股份有限公司 通过半监督实例分割标定监控目标的方法及系统
CN115620155A (zh) * 2022-12-19 2023-01-17 浙江华是科技股份有限公司 一种变电站目标检测方法、系统及计算机存储介质
CN117152587A (zh) * 2023-10-27 2023-12-01 浙江华是科技股份有限公司 一种基于对抗学习的半监督船舶检测方法及系统
CN117557477A (zh) * 2024-01-09 2024-02-13 浙江华是科技股份有限公司 一种船舶去雾复原方法及系统

Citations (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN111382758A (zh) * 2018-12-28 2020-07-07 杭州海康威视数字技术股份有限公司 训练图像分类模型、图像分类方法、装置、设备及介质
CN112183577A (zh) * 2020-08-31 2021-01-05 华为技术有限公司 一种半监督学习模型的训练方法、图像处理方法及设备
CN113269267A (zh) * 2021-06-15 2021-08-17 苏州挚途科技有限公司 目标检测模型的训练方法、目标检测方法和装置
CN113705769A (zh) * 2021-05-17 2021-11-26 华为技术有限公司 一种神经网络训练方法以及装置
CN114186615A (zh) * 2021-11-22 2022-03-15 浙江华是科技股份有限公司 船舶检测半监督在线训练方法、装置及计算机存储介质

Patent Citations (6)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN111382758A (zh) * 2018-12-28 2020-07-07 杭州海康威视数字技术股份有限公司 训练图像分类模型、图像分类方法、装置、设备及介质
CN112183577A (zh) * 2020-08-31 2021-01-05 华为技术有限公司 一种半监督学习模型的训练方法、图像处理方法及设备
WO2022042002A1 (zh) * 2020-08-31 2022-03-03 华为技术有限公司 一种半监督学习模型的训练方法、图像处理方法及设备
CN113705769A (zh) * 2021-05-17 2021-11-26 华为技术有限公司 一种神经网络训练方法以及装置
CN113269267A (zh) * 2021-06-15 2021-08-17 苏州挚途科技有限公司 目标检测模型的训练方法、目标检测方法和装置
CN114186615A (zh) * 2021-11-22 2022-03-15 浙江华是科技股份有限公司 船舶检测半监督在线训练方法、装置及计算机存储介质

Non-Patent Citations (2)

* Cited by examiner, † Cited by third party
Title
DANYU LAI: "Improving classification with semi-supervised and fine-grained learning" *
余游: "基于深度网络的少样本学习算法研究" *

Cited By (6)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115359062A (zh) * 2022-10-24 2022-11-18 浙江华是科技股份有限公司 通过半监督实例分割标定监控目标的方法及系统
CN115620155A (zh) * 2022-12-19 2023-01-17 浙江华是科技股份有限公司 一种变电站目标检测方法、系统及计算机存储介质
CN117152587A (zh) * 2023-10-27 2023-12-01 浙江华是科技股份有限公司 一种基于对抗学习的半监督船舶检测方法及系统
CN117152587B (zh) * 2023-10-27 2024-01-26 浙江华是科技股份有限公司 一种基于对抗学习的半监督船舶检测方法及系统
CN117557477A (zh) * 2024-01-09 2024-02-13 浙江华是科技股份有限公司 一种船舶去雾复原方法及系统
CN117557477B (zh) * 2024-01-09 2024-04-05 浙江华是科技股份有限公司 一种船舶去雾复原方法及系统

Also Published As

Publication number Publication date
CN114998691B (zh) 2023-04-18

Similar Documents

Publication Publication Date Title
CN114998691B (zh) 半监督船舶分类模型训练方法及装置
CN114241282B (zh) 一种基于知识蒸馏的边缘设备场景识别方法及装置
CN110232350B (zh) 一种基于在线学习的实时水面多运动目标检测跟踪方法
CN110580496A (zh) 一种基于熵最小化的深度迁移学习系统及方法
CN111079836B (zh) 基于伪标签方法和弱监督学习的过程数据故障分类方法
CN109948522B (zh) 一种基于深度神经网络的x光片手骨成熟度判读方法
CN110909909A (zh) 基于深度学习和多层时空特征图的短时交通流预测方法
CN113313166B (zh) 基于特征一致性学习的船舶目标自动标注方法
CN109919302B (zh) 一种用于图像的神经网络的训练方法及装置
CN112215412A (zh) 溶解氧预测方法及装置
CN116484024A (zh) 一种基于知识图谱的多层次知识库构建方法
CN114758199A (zh) 检测模型的训练方法、装置、设备和存储介质
CN114942951A (zh) 一种基于ais数据的渔船捕鱼行为分析方法
EP3975062A1 (en) Method and system for selecting data to train a model
CN111797935B (zh) 基于群体智能的半监督深度网络图片分类方法
CN117058882A (zh) 一种基于多特征双判别器的交通数据补偿方法
CN115438190B (zh) 一种配电网故障辅助决策知识抽取方法及系统
CN114595770B (zh) 一种船舶轨迹的长时序预测方法
CN114495114B (zh) 基于ctc解码器的文本序列识别模型校准方法
CN113821452B (zh) 根据被测系统测试表现动态生成测试案例的智能测试方法
CN115457305A (zh) 一种半监督目标检测方法与系统
CN114972429A (zh) 云边协同自适应推理路径规划的目标追踪方法和系统
CN113139624A (zh) 基于机器学习的网络用户分类方法
Suyal et al. An Agile Review of Machine Learning Technique
CN117251599B (zh) 一种视频语料智能测试优化方法、装置和存储介质

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant