CN114155365A - 模型训练方法、图像处理方法及相关装置 - Google Patents

模型训练方法、图像处理方法及相关装置 Download PDF

Info

Publication number
CN114155365A
CN114155365A CN202210115405.4A CN202210115405A CN114155365A CN 114155365 A CN114155365 A CN 114155365A CN 202210115405 A CN202210115405 A CN 202210115405A CN 114155365 A CN114155365 A CN 114155365A
Authority
CN
China
Prior art keywords
image
feature map
model
original
detection
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Granted
Application number
CN202210115405.4A
Other languages
English (en)
Other versions
CN114155365B (zh
Inventor
孟慧
余子牛
谷宁波
李青锋
牛建伟
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Hangzhou Innovation Research Institute of Beihang University
Original Assignee
Hangzhou Innovation Research Institute of Beihang University
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Hangzhou Innovation Research Institute of Beihang University filed Critical Hangzhou Innovation Research Institute of Beihang University
Priority to CN202210115405.4A priority Critical patent/CN114155365B/zh
Publication of CN114155365A publication Critical patent/CN114155365A/zh
Application granted granted Critical
Publication of CN114155365B publication Critical patent/CN114155365B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N20/00Machine learning

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Data Mining & Analysis (AREA)
  • Evolutionary Computation (AREA)
  • General Engineering & Computer Science (AREA)
  • Software Systems (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • General Physics & Mathematics (AREA)
  • Artificial Intelligence (AREA)
  • Physics & Mathematics (AREA)
  • Evolutionary Biology (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Medical Informatics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Computing Systems (AREA)
  • Mathematical Physics (AREA)
  • Image Analysis (AREA)

Abstract

本发明涉及图像处理技术领域,提供一种模型训练方法、图像处理方法及相关装置,所述方法包括:获取带标签的第一样本图像和不带标签的第二样本图像;将第一样本图像及第二样本图像分别输入至原始模型的特征提取网络,得到第一特征图和第二特征图;将第一特征图和第二特征图输入至原始模型的注意力约束网络,得到第二特征图的强化特征图;将第一特征图及强化特征图分别输入至原始模型的回归网络,得到原始模型的第一检测结果和第二检测结果;根据第一检测结果、标签、第二检测结果及预设损失函数,对原始模型的参数进行调整,直至达到预设的训练完成条件,得到检测模型。本发明训练得到的检测模型具有较高的检测精度和泛化能力。

Description

模型训练方法、图像处理方法及相关装置
技术领域
本发明涉及图像处理技术领域,具体而言,涉及一种模型训练方法、图像处理方法及相关装置。
背景技术
监督式学习和无监督式学习是图像处理领域中经常使用的两种机器学习策略,两者的区别在于是否需要人工参与数据的标注,即是否需要对训练数据打标签。传统的监督式学习训练需要预先收集大量数据并对其进行标注以构建训练集,然后在此基础上进行建模拟合,最后让模型预测未知数据的结果。对于难以进行人工标注类别或进行人工类别标注的成本太高或者只需要分类,并不需要识别具体类别等应用场景,传统的监督式学习适用度并不高,无监督式学习虽然避免了数据标注带来的巨大工作量,但是图像处理时进行目标检测的精度通常不能满足要求。
发明内容
本发明的目的在于提供了一种模型训练方法、图像处理方法及相关装置,其能够在模型训练时利用注意力约束机制,以加强无标签的样本图像对模型的优化,最终提高训练后得到的检测模型的检测精度。
为了实现上述目的,本发明实施例采用的技术方案如下:
第一方面,本发明实施例提供了一种模型训练方法,所述方法包括:
获取带标签的第一样本图像和不带标签的第二样本图像;
将所述第一样本图像及所述第二样本图像分别输入至原始模型的特征提取网络,得到第一特征图和第二特征图;
将所述第一特征图和所述第二特征图输入至所述原始模型的注意力约束网络,得到所述第二特征图的强化特征图;
将所述第一特征图及所述强化特征图分别输入至所述原始模型的回归网络,得到所述原始模型的第一检测结果和第二检测结果;
根据所述第一检测结果、所述标签、所述第二检测结果及预设损失函数,对所述原始模型的参数进行调整,直至达到预设的训练完成条件,得到检测模型。
进一步地,所述将所述第一特征图和所述第二特征图输入至所述原始模型的注意力约束网络,得到所述第二特征图的强化特征图的步骤包括:
利用所述注意力约束网络的池化层对所述第一特征图进行池化处理,得到注意力向量;
将所述注意力向量和所述第二特征图输入至所述注意力约束网络的强化层进行特征强化,得到所述第二特征图的强化特征图。
进一步地,所述利用所述注意力约束网络的池化层对所述第一特征图进行池化处理,得到注意力向量的步骤包括:
根据所述第一特征图及所述标签,从所述第一特征图中确定目标区域;
将所述目标区域输入至所述注意力约束网络的池化层进行处理,得到所述注意力向量。
进一步地,所述第一样本图像包括第一原始图像和对所述第一原始图像进行翻转得到的第一翻转图像,所述第二样本图像包括第二原始图像和对所述第二原始图像进行翻转得到的第二翻转图像,所述第一检测结果包括所述第一原始图像的检测结果和所述第一翻转图像的检测结果,所述第二检测结果包括所述第二原始图像的检测结果和所述第二翻转图像的检测结果,所述预设损失函数为:
Figure 728469DEST_PATH_IMAGE001
其中,
Figure 824601DEST_PATH_IMAGE002
表示预设损失函数,
Figure 270625DEST_PATH_IMAGE003
表示有监督损失函数,
Figure 463709DEST_PATH_IMAGE004
表示一致性损失权重函数,
Figure 602567DEST_PATH_IMAGE005
表示一致性损失函数,
Figure 999044DEST_PATH_IMAGE006
表示所述第一原始图像的标签,
Figure 248760DEST_PATH_IMAGE007
表示所述第一翻转图像的标签,
Figure 234033DEST_PATH_IMAGE008
表示所述第一原始图像的检测结果,
Figure 340530DEST_PATH_IMAGE009
表示所述第一翻转图像的检测结果,
Figure 411254DEST_PATH_IMAGE010
表示所述第二原始图像的检测结果,
Figure 464660DEST_PATH_IMAGE011
表示所述第二翻转图像的检测结果。
进一步地,所述第一原始图像的预测结果包括所述第一原始图像的类别预测结果和位置预测结果,所述第一翻转图像的预测结果包括所述第一翻转图像的类别预测结果和位置预测结果,所述第二原始图像的预测结果包括所述第二原始图像的类别预测结果和位置预测结果,所述第二翻转图像的预测结果包括所述第二翻转图像的类别预测结果和位置预测结果;
所述有监督损失函数为:
Figure 884534DEST_PATH_IMAGE012
Figure 365194DEST_PATH_IMAGE013
表示有监督分类损失函数,
Figure 657635DEST_PATH_IMAGE014
表示有监督位置损失函数,
Figure 577049DEST_PATH_IMAGE015
Figure 271336DEST_PATH_IMAGE016
分别表示所述第一原始图像的类别预测结果和位置预测结果,
Figure 922897DEST_PATH_IMAGE017
Figure 515683DEST_PATH_IMAGE018
分别表示所述第一翻转图像的类别预测结果和位置预测结果,
Figure 910893DEST_PATH_IMAGE019
Figure 459686DEST_PATH_IMAGE020
分别表示所述第一原始图像的类别标签和位置标签,
Figure 875624DEST_PATH_IMAGE021
Figure 142657DEST_PATH_IMAGE022
分别表示所述第一翻转图像的类别标签和位置标签;
所述一致性损失函数为:
Figure 121983DEST_PATH_IMAGE023
,其中,
Figure 790862DEST_PATH_IMAGE024
表示分类一致性损失函数,
Figure 49805DEST_PATH_IMAGE025
表示位置一致性损失函数;
Figure 866451DEST_PATH_IMAGE026
Figure 869042DEST_PATH_IMAGE027
分别表示所述第二原始图像的类别预测结果和位置预测结果,
Figure 392427DEST_PATH_IMAGE028
Figure 369742DEST_PATH_IMAGE029
分别表示所述第二翻转图像的类别预测结果和位置预测结果;
所述分类一致性损失函数为:
Figure 611367DEST_PATH_IMAGE030
,其中,
Figure 152070DEST_PATH_IMAGE031
表示求均值,
Figure 326699DEST_PATH_IMAGE032
表示
Figure 927445DEST_PATH_IMAGE033
Figure 390787DEST_PATH_IMAGE034
之间的詹森香农JS散度,
Figure 303556DEST_PATH_IMAGE035
表示
Figure 801533DEST_PATH_IMAGE036
Figure 307601DEST_PATH_IMAGE037
之间的詹森香农JS散度;
所述位置一致性损失函数为:
Figure 320556DEST_PATH_IMAGE038
其中,
Figure 203062DEST_PATH_IMAGE039
表示
Figure 555546DEST_PATH_IMAGE040
Figure 45564DEST_PATH_IMAGE041
之间的位置一致性损失,
Figure 483498DEST_PATH_IMAGE042
表示
Figure 497591DEST_PATH_IMAGE043
Figure 704581DEST_PATH_IMAGE044
之间的位置一致性损失。
第二方面,本发明实施例提供了一种图像处理方法,所述方法包括:
获取包含目标图像的待处理图像;
将所述待处理图像输入检测模型,所述检测模型是通过第一方面中的模型训练方法进行训练得到的,所述检测模型包括特征提取网络和回归网络;
利用所述检测模型的特征提取网络,得到特征图像;
利用所述检测模型的回归网络对所述特征图像进行目标检测,以检测出所述目标图像。
第三方面,本发明实施例还提供了一种模型训练装置,所述装置包括:
样本获取模块,获取带标签的第一样本图像和不带标签的第二样本图像;
训练模块,用于将所述第一样本图像及所述第二样本图像分别输入至原始模型的特征提取网络,得到第一特征图和第二特征图;
所述训练模块,还用于将所述第一特征图和所述第二特征图输入至所述原始模型的注意力约束网络,得到所述第二特征图的强化特征图;
所述训练模块,还用于将所述第一特征图及所述强化特征图分别输入至所述原始模型的回归网络,得到所述原始模型的第一检测结果和第二检测结果;
所述训练模块,还用于根据所述第一检测结果、所述标签、所述第二检测结果及预设损失函数,对所述原始模型的参数进行调整,直至达到预设的训练完成条件,得到检测模型。
第四方面,本发明实施例还提供了一种图像处理装置,所述装置包括:
图像获取模块,用于获取包含目标图像的待处理图像;
处理模块,用于将所述待处理图像输入检测模型,利用所述检测模型的特征提取网络,得到特征图像,所述检测模型是通过第一方面中的模型训练方法进行训练得到的,所述检测模型包括特征提取网络和回归网络;
所述处理模块,还用于利用所述检测模型的回归网络对所述特征图像进行目标检测,以检测出所述目标图像。
第五方面,本发明实施例还提供了一种电子设备,包括处理器和存储器;所述存储器用于存储程序;所述处理器用于在执行所述程序时,实现上述第一方面中的模型训练方法和/或上述第二方面中的图像处理方法。
第六方面,本发明实施例还提供了一种计算机可读存储介质,其上存储有计算机程序,该计算机程序被处理器执行时实现如上述第一方面中的模型训练方法和/或上述第二方面中的图像处理方法。
相对于现有技术,本发明实施例提供的一种模型训练方法、图像处理方法及相关装置,通过对带标签的第一样本图像和不带标签的第二样本图像分别提取第一特征图和第二特征图,再将第一特征图和第二特征图输入至原始模型的注意力约束网络,得到第二特征图的强化特征图,最终将第一特征图和强化特征图分别输入回归网络,得到第一检测结果和第二检测结果,根据第一检测结果、标签、第二检测结果及预设损失函数,对原始模型的参数进行调整,最终得到检测模型,由于利用注意力约束网络对无标签的样本图像的特征图进行了特征强化,使得模型训练时,能够充分利用无标签的样本图像的特征,在减少样本图像打标签的工作量的同时,使训练得到的检测模型具有较高的检测精度和泛化能力。
附图说明
为了更清楚地说明本发明实施例的技术方案,下面将对实施例中所需要使用的附图作简单地介绍,应当理解,以下附图仅示出了本发明的某些实施例,因此不应被看作是对范围的限定,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他相关的附图。
图1示出了本发明实施例提供的模型训练方法的流程示例图之一。
图2示出了本发明实施例提供的病灶区域标记的过程示例图。
图3示出了本发明实施例提供的模型训练方法的流程示例图之二。
图4示出了本发明实施例提供的利用第一特征图对第二特征图进行强化的示例图。
图5示出了本发明实施例提供的利用第一特征图的目标区域对第二特征图进行强化的示例图。
图6示出了本发明实施例提供的训练过程示例图。
图7示出了本发明实施例提供的图像处理方法的流程示例图。
图8示出了本发明实施例提供的模型训练装置的方框示意图。
图9示出了本发明实施例提供的图像处理装置的方框示意图。
图10示出了本发明实施例提供的电子设备的方框示意图。
图标:10-电子设备;11-处理器;12-存储器;13-总线;100-模型训练装置;110-样本获取模块;120-训练模块;200-图像处理装置;210-图像获取模块;220-处理模块。
具体实施方式
为使本发明实施例的目的、技术方案和优点更加清楚,下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例是本发明一部分实施例,而不是全部的实施例。通常在此处附图中描述和示出的本发明实施例的组件可以以各种不同的配置来布置和设计。
因此,以下对在附图中提供的本发明的实施例的详细描述并非旨在限制要求保护的本发明的范围,而是仅仅表示本发明的选定实施例。基于本发明中的实施例,本领域普通技术人员在没有作出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
应注意到:相似的标号和字母在下面的附图中表示类似项,因此,一旦某一项在一个附图中被定义,则在随后的附图中不需要对其进行进一步定义和解释。
在本发明的描述中,需要说明的是,若出现术语“上”、“下”、“内”、“外”等指示的方位或位置关系为基于附图所示的方位或位置关系,或者是该发明产品使用时惯常摆放的方位或位置关系,仅是为了便于描述本发明和简化描述,而不是指示或暗示所指的装置或元件必须具有特定的方位、以特定的方位构造和操作,因此不能理解为对本发明的限制。
此外,若出现术语“第一”、“第二”等仅用于区分描述,而不能理解为指示或暗示相对重要性。
需要说明的是,在不冲突的情况下,本发明的实施例中的特征可以相互结合。
现有的图像处理技术主要分为传统的机器学习方法和基于数据驱动的深度学习方法,其中,传统的机器学习方法通常利用图像处理算子提取影像特征,然后利用分类器区分目标区域和背景区域,例如可变形的组件模型DPM(Deformable Part Model, DPM)等。在基于深度学习的图像处理技术中,影像特征提取和分类全部由神经网络完成,实现了目标检测的自动化。通常使用到的网络模型有Faster-RCNN,Inception-ResNet-v2,FCN-AlexNet等。
常见的机器学习策略包括监督式学习和无监督式学习,两者的区别在于是否需要人工参与数据的标注。传统的监督学习训练需要预先收集大量数据并对其进行标注以构建训练集,然后在此基础上进行建模拟合,最后让模型预测未知数据的结果。监督式学习在进行模型训练时,通常需要对样本图像进行精准标注,得到具有准确标签的样本图像,但是对于标签的准确性要求和具有标签的样本图像的数据量规模限制了训练得到的检测模型的精度和准确性。针对这一问题,解决策略一般分为两类,第一类策略是充分利用少量数据来提升检测效果,常用的技术包括微调(fine-tuning)和元学习(Meta Learning)等。第二类策略是基于传统分类领域的半监督式学习策略,该策略通过合理利用无标签数据,降低数据成本,提升训练效果。
半监督学习的思路是:混合使用有标注数据(即带标签的样本图像)和无标注数据(即不带标签的样本图像),利用未标记的数据来提高模型的性能。半监督学习算法一般分为两类:一类是先用无标签数据预训练网络,然后用有标签数据对网络进行微调(fine-tune);另外一类是同时利用有标签数据和无标签数据训练网络,利用从网络中得到的深度特征来做半监督算法。例如,现有的基于一致性的半监督学习目标检测算法CSD(Consistency-based Semi-supervised Learning for object Detection, CSD), 通过引入一致性损失方法,强化网络对无标签数据的特征学习,使得变换后的无标签样本产生的检测框对应保持一致。
发明人发现,虽然现有的基于半监督学习有效减少了模型训练对带标签的样本图像的数据需求量,降低了模型训练的成本,但是,最终训练得到的检测模型的精度并不如人意,通过对现有技术中的各种实现方式的深入分析发现,其原因是半监督学习通常会产生大量无效背景框,干扰了模型训练时的学习效果,降低了训练出的检测模型的精度。
有鉴于此,本发明实施例提供一种模型训练方法、图像处理方法及相关装置,能够有效地过滤掉背景区域的负样本框,充分发挥了无标签数据在网络训练中的作用,提高检测模型的检测精度,下面将对其进行详细描述。
请参照图1,图1示出了本发明实施例提供的模型训练方法的流程示例图之一,该方法包括以下步骤:
步骤S100,获取带标签的第一样本图像和不带标签的第二样本图像。
在本实施例中,标签根据识别的需求而定,例如,需要识别的是第一样本图像中目标区域的位置,则标签可以是用于表征目标区域的坐标,目标区域的坐标具体可以是目标区域的左上角的角点的坐标和右下角的角点的坐标,再例如,需要识别的是第一样本图像中目标区域的类别,则标签可以是用于表征目标区域的类别的值,具体可以是:用0表示类别0,用1表示类别1等。
在本实施例中,对于一种具体的应用场景,例如对于超声图像而言,样本图像可以是包含病灶的超声切面图像、标签可以是病灶区域的标注信息,包括、但不限于表征病灶区域的位置,或者用于表征病灶类别,例如是恶性还是良性,也可以同时表征病灶区域的位置和类别。请参照图2,图2示出了本发明实施例提供的病灶区域标记的过程示例图,病灶标记的过程包括:(1)采集超声图像;(2)勾画病灶区域;(3)获取掩模图;(4)提取包裹框信息,包括病灶区域的宽度和高度及病灶区域的左上角的坐标。
除了样本图像和标签,还可以将样本图像对应的病理诊断信息作为模型训练的训练数据的一部分。
由于超声成像高度依赖操作者或诊断医生的扫描手法和阅片经验,给经验不足的医生带来了极大的挑战,同时超声因噪声和伪影引起的低成像质量制约着病灶的精准检测和诊断,利用本发明实施例提供的训练方法训练得到的检测模型,可以准确地识别超声图像中的病灶区域的位置和类别,能够辅助医生提高诊断准确性,降低漏误诊率。
在本实施例中,无论是第一样本图像还是第二样本图像,为了保证训练得到的检测模型的泛化性,通常会对样本图像(包括第一样本图像和第二样本图像)进行数据增强操作,数据增强操作包括、但不限于针对图像的像素内容的变换和针对空间几何的变换,针对像素内容的变换可以为随机改变图像亮度、随机改变对比度、色度、饱和度、随机改变颜色通道等;针对空间几何的变换可以为随机扩展、随机裁剪、随机镜像等。对于任一原始训练图像,可以随机选择变换方式,在选定变换方式后,可以随机选择该变换方式中的至少一种进行变换,当然,对于具体变换的亮度值、对比度值、色度值等也可以随机设定,例如,对于原始训练图像,随机选择的变换方式为针对像素内容的变换,可以进一步选择对原始训练图像进行图像亮度和对比度的改变, 将原始训练图像的亮度值改变为亮度值a,将其对比度值改变为对比度b,最终得到样本图像。
步骤S101,将第一样本图像及第二样本图像分别输入至原始模型的特征提取网络,得到第一特征图和第二特征图。
在本实施例中,第一样本图像和第二样本图像可以先后输入,也可以同时输入,不管哪种输入方式,原始模型的特征提取网络均会对第一样本图像和第二样本图像分别进行特征提取,得到第一样本图像的第一特征图和第二样本图像的第二特征图。
在本实施例中,原始模型为需要训练的模型,原始模型包括特征提取网络、注意力约束网络及回归网络,其中,特征提取网络用于对样本图像进行特征提取得到样本图像的特征图,当输入特征提取网络的样本图像为第一样本图像时,特征提取网络输出的为第一样本图像的特征图(即第一特征图),当输入特征提取网络的样本图像为第二样本图像时,特征提取网络输出的为第二样本图像的特征图(即第二特征图)。
还需要说明的是,第一样本图像和第二样本图像可以采用相同的特征提取网络的参数进行特征提取,特征提取网络的参数包括、但不限于卷积层的层数、卷积核的大小等参数,由此,一方面可以节省参数数量,减小网络成本;另一方面第一样本图像学习到的特征可以帮助第二样本图像学习特征,以强化网络对第二样本图像的特征学习。
步骤S102,将第一特征图和第二特征图输入至原始模型的注意力约束网络,得到第二特征图的强化特征图。
在本实施例中,注意力约束网络用于利用注意力机制和对第二特征图进行特征强化。注意力机制(attention mechanism)是一种有效融合先验知识的技术。主流的注意力机制可以分为以下三种:通道注意力、空间注意力和自注意力(self-attention)。通道注意力旨在建模出不同通道之间的相关性,通过网络学习的方式自动获取到每个特征通道的重要程度,最后为每个通道赋予不同的权重系数,从而来强化重要的特征,抑制非重要的特征。空间注意力本质上是将原始图片中的空间信息通过空间转换模块,变换到另一个空间中并保留关键信息,为每个位置生成权重掩膜并加权输出,从而增强感兴趣的特定目标区域同时弱化不相关的背景区域。自注意力的目的是为了减少对外部信息依赖,尽可能地利用特征内部固有的信息进行注意力的交互,例如Non-Local模块,DANet、GC-Net等。
在本实施例中,由于第一特征图是根据带标签的第一样本图像进行特征提取得到的,利用第一特征图计算注意力权重,再通过注意力权重对第二特征图进行特征强化,加强无标签的第二样本图像中的特征信息,提升了检测精度。
步骤S103,将第一特征图及强化特征图分别输入至原始模型的回归网络,得到原始模型的第一检测结果和第二检测结果。
在本实施例中,第一特征图和强化特征图可以同时输入回归网络,也可以先后输入回归网络,无论哪种输入方式,回归网络均会根据第一特征图和强化特征图分别计算出对应的第一检测结果和第二检测结果。
在本实施例中,第一检测结果和第二检测结果分别为根据不同的输入得到预测结果。
步骤S104,根据第一检测结果、标签、第二检测结果及预设损失函数,对原始模型的参数进行调整,直至达到预设的训练完成条件,得到检测模型。
在本实施例中,预设损失函数包括有监督损失函数和一致性损失函数,有监督损失函数用于表征第一样本图像的检测结果和标签之间的偏差,其值是根据第一检测结果和标签计算得到的,一致性损失函数用于表征第一检测结果和第二检测结果之间的偏差,其值是根据第一检测结果和第二检测结果计算得到的。
在本实施例中,预设的训练完成条件可以是训练的次数达到预设次数,也可以是预测结果的准确度达到预设准确度,训练方式包括、但不限于利用梯度下降算法进行迭代训练。
本发明实施例提供的上述方法,通过对带标签的样本图像提取的特征图约束不带标签的样本图像提取的特征图,对不带标签的样本图像的特征图进行强化,有效地过滤掉了背景区域的负样本框,充分发挥了无标签数据在网络训练中的作用,提高了检测模型的精度。
在图1的基础上,本发明实施例还提供了一种对第二特征图进行强化的具体实现方式,请参照图3,图3示出了本发明实施例提供的模型训练方法的流程示例图之二,步骤S102包括以下子步骤:
子步骤S1021,利用注意力约束网络的池化层对第一特征图进行池化处理,得到注意力向量。
在本实施例中,注意力约束网络包括池化层和强化层,池化层用于计算注意力向量,注意力向量也称为注意力权重。强化层用于根据注意力向量对第二特征图进行特征强化,具体可以是根据注意力向量对第二特征图进行加权计算,得到强化特征图。
在本实施例中,可以将第一特征图直接输入池化层进行处理得到注意力向量,也可以先利用标签从第一特征图中确定目标区域,再将目标区域输入至池化层进行处理得到注意力向量,后者由于采用的是更有针对性的目标区域得到的注意力向量,因而可以排除掉更多与目标区域无关的噪声区域,实现更有效的特征强化。本发明实施例以第一特征图的目标区域为例,提供一种具体实现方式:
首先,根据第一特征图及标签,从第一特征图中确定目标区域。
在本实施例中,由于标签可以表征目标区域的位置,因此,根据标签可以从第一特征图中确定目标区域。
其次,将目标区域输入至注意力约束网络的池化层进行处理,得到注意力向量。
在本实施例中,池化层包括、但不限于平均池化、最大值池化及最小值池化等。
为了更清楚地说明利用第一特征图和利用第一特征图中的目标区域两种方式的区别,请参照图4和图5,图4示出了本发明实施例提供的利用第一特征图对第二特征图进行强化的示例图,图5示出了本发明实施例提供的利用第一特征图的目标区域对第二特征图进行强化的示例图,由图4和图5可以看出,图5中由于利用了更具有针对性的目标区域得到的注意力向量,因此,利用图5的注意力向量强化效果会更好。
子步骤S1022,将注意力向量和第二特征图输入至注意力约束网络的强化层进行特征强化,得到第二特征图的强化特征图。
在本实施例中,作为一种具体实现方式,可以对第一特征图i(H*W*C,H,W,C分别表示第一特征图的高、宽、通道)进行平均池化得到注意力向量P(1*1*C),然后将注意力向量
Figure 818031DEST_PATH_IMAGE045
在第二特征图j(H*W*C,H,W,C分别表示第二特征图的高、宽、通道)按每一个像素点位置的特征值相乘,第一特征图和第二特征图的高、宽、通道均一致,注意力向量P的通道与第一特征图及第二特征图的通道均一致,计算公式如下:
Figure 789267DEST_PATH_IMAGE046
Figure 13575DEST_PATH_IMAGE047
为经过注意力向量加权得到的加强特征图在位置(h,w,c)上的特征值,
Figure 340651DEST_PATH_IMAGE048
表示特征图j在位置(h,w,c)上的特征值。
在图1的基础上,本发明实施例在训练过程中采用一致性回归方法,不但利用正常图像作为样本,还同时利用了正常图像的翻转图像作为样本进行训练,并使得正常图像和翻转图像的预测框对应保持一致,提升了模型训练过程中对第二样本图像的特征学习。
在本实施例中,第一样本图像包括第一原始图像和对第一原始图像进行翻转得到的第一翻转图像,第二样本图像包括第二原始图像和对第二原始图像进行翻转得到的第二翻转图像,其中,第一原始图像可以是上述进行像素内容变换或者空间几何变换后的图像,无论是对第一原始图像进行翻转还是对第二原始图像进行翻转,其过程都是类似。
本实施例以对第一原始图像进行翻转进行说明,对第一原始图像进行翻转之前,根据实际情况确定是否需要对第一原始图像进行缩放,如果第一原始图像与原始模型要求的输入图像的大小不匹配,需要首先对第一原始图像进行缩放,以满足原始模型的输入需求,由于缩放可能造成目标图像的大小和位置发生变化,为了使最终训练得到的检测模型的精度不受缩放的影响,还需要对缩放后的第一原始图像的标签进行对应的调整,然后,对缩放后的第一原始图像进行翻转,翻转可以是竖直方向的翻转,也可以是水平方向的翻转,具体可以根据实际场景的需要进行设定,如果翻转后得到的第一翻转图像中的目标图像的大小和位置发生了变化,为了使最终训练得到的检测模型的精度不受缩放的影响,还需要对第一翻转图像的标签进行对应的调整。
在本实施例中,第一检测结果包括第一原始图像的检测结果和第一翻转图像的检测结果,第二检测结果包括第二原始图像的检测结果和第二翻转图像的检测结果,预设损失函数包括有监督损失函数和一致性损失函数,预设损失函数为:
Figure 625002DEST_PATH_IMAGE049
其中,
Figure 99845DEST_PATH_IMAGE050
表示预设损失函数,
Figure 127844DEST_PATH_IMAGE051
表示有监督损失函数,
Figure 309427DEST_PATH_IMAGE052
表示一致性损失权重函数,一致性损失权重函数的值随着时间逐渐增加的,
Figure 312149DEST_PATH_IMAGE053
表示一致性损失函数,
Figure 211972DEST_PATH_IMAGE054
表示第一原始图像的检测结果,
Figure 840399DEST_PATH_IMAGE055
表示第一原始图像的标签,
Figure 876488DEST_PATH_IMAGE056
表示第一翻转图像的检测结果,
Figure 502642DEST_PATH_IMAGE057
表示第一翻转图像的标签,
Figure 204275DEST_PATH_IMAGE058
表示第二原始图像的检测结果,
Figure 574076DEST_PATH_IMAGE059
表示第二翻转图像的检测结果。
在本实施例中,有监督损失函数的值是根据第一原始图像的检测结果和第一原始图像的标签,以及第一翻转图像的检测结果和第一翻转图像的标签得到的。一致性损失函数是根据第一原始图像的检测结果和第一翻转图像的检测结果、以及第二原始图像的检测结果和第二翻转图像的检测结果计算得到的。
作为一种具体实施方式,第一原始图像的预测结果包括第一原始图像的类别预测结果和位置预测结果,第一翻转图像的预测结果包括第一翻转图像的类别预测结果和位置预测结果,第二原始图像的预测结果包括第二原始图像的类别预测结果和位置预测结果,第二翻转图像的预测结果包括第二翻转图像的类别预测结果和位置预测结果,标签包括类别标签和位置标签。
在本实施例中,作为一种具体实施方式,有监督损失函数包括有监督分类损失函数和有监督位置损失函数,一致性损失函数包括分类一致性损失函数和位置一致性损失函数。
其中,有监督损失函数为:
Figure 730251DEST_PATH_IMAGE060
Figure 324044DEST_PATH_IMAGE061
表示有监督分类损失函数,
Figure 932880DEST_PATH_IMAGE062
表示有监督位置损失函数,
Figure 840793DEST_PATH_IMAGE063
Figure 664523DEST_PATH_IMAGE064
分别表示第一原始图像的类别预测结果和位置预测结果,
Figure 632479DEST_PATH_IMAGE065
Figure 728611DEST_PATH_IMAGE066
分别表示第一翻转图像的类别预测结果和位置预测结果,
Figure 502532DEST_PATH_IMAGE067
Figure 633299DEST_PATH_IMAGE068
分别表示第一原始图像的类别标签和位置标签,
Figure 772157DEST_PATH_IMAGE069
Figure 401590DEST_PATH_IMAGE070
分别表示第一翻转图像的类别标签和位置标签。
一致性损失函数为:
Figure 916885DEST_PATH_IMAGE071
,其中,
Figure 902158DEST_PATH_IMAGE072
表示分类一致性损失函数,
Figure 539813DEST_PATH_IMAGE073
表示位置一致性损失函数;
Figure 344958DEST_PATH_IMAGE074
Figure 398365DEST_PATH_IMAGE075
分别表示第二原始图像的类别预测结果和位置预测结果,
Figure 785615DEST_PATH_IMAGE076
Figure 266275DEST_PATH_IMAGE077
分别表示第二翻转图像的类别预测结果和位置预测结果。
分类一致性损失函数为:
Figure 824295DEST_PATH_IMAGE078
,其中,
Figure 478130DEST_PATH_IMAGE079
表示求均值,
Figure 437996DEST_PATH_IMAGE080
表示
Figure 404071DEST_PATH_IMAGE081
Figure 183809DEST_PATH_IMAGE082
之间的詹森香农JS散度(Jensen-ShannonDivergence,JS),
Figure 579018DEST_PATH_IMAGE083
表示
Figure 455707DEST_PATH_IMAGE084
Figure 278170DEST_PATH_IMAGE085
之间的詹森香农JS散度,即:首先,计算第一原始图像的类别预测结果和第一翻转图像的类别预测结果之间的第一JS散度,计算第二原始图像的类别预测结果和第二翻转图像的类别预测结果之间的第二JS散度,再计算第一JS散度和第二JS散度的均值。例如,第一原始图像的类别预测结果为a、b、c、d,第一翻转图像的类别预测结果为a`、b`、c`、d`,第二原始图像的类别预测结果为x、y、z,第二翻转图像的类别预测结果为x`、y`、z`,分别计算a和a`、b和b`、c和c`、d和d`之间的第一JS散度,第一JS散度为:js1、js2、js3、js4,分别计算x和x`、y和y`、z和z`之间的第二JS散度,第二JS散度为:js5、js6、js7,最后计算js1、js2、js3、js4、js5、js6、js7的均值,将均值作为分类一致性损失函数的取值。
位置一致性损失函数为:
Figure 545203DEST_PATH_IMAGE086
其中,
Figure 557152DEST_PATH_IMAGE087
表示
Figure 226031DEST_PATH_IMAGE088
Figure 219395DEST_PATH_IMAGE089
之间的位置一致性损失,
Figure 36041DEST_PATH_IMAGE090
表示
Figure 38632DEST_PATH_IMAGE091
Figure 562017DEST_PATH_IMAGE092
之间的位置一致性损失。首先,计算第一原始图像的位置预测结果和第一翻转图像的位置预测结果之间的第一位置一致性损失,计算第二原始图像的位置预测结果和第二翻转图像的位置预测结果之间的第二位置一致性损失,再计算第一位置一致性损失和第二位置一致性损失的均值。例如,第一原始图像的位置预测结果为m、n、o,第一翻转图像的预测结果为m`、n `、o `,第二原始图像为r、s、t,对应的第二翻转图像为r`、s `、t `,分别计算m和m`、n和n`、o和o`之间的第一位置一致性损失,第一位置一致性损失为:con1、con 2、con 3,分别计算r和r`、s和s `、t和t `之间的第二位置一致性损失,第二位置一致性损失为:con 4、con 5、con 6,最后计算con 1、con 2、con 3、con 4、con5、con 6的均值,将均值作为位置一致性损失函数的取值。
在本实施例中,作为一种具体实施方式,位置预测结果可以用预测框信息表示,预测框信息包括预测框的中心点位置、预测框的宽度及预测框的高度,位置一致性损失函数可以根据预测框的中心点位置相对于对应的锚框的中心点位置的偏移、预测框的宽度和高度分别相对于对应的锚框的宽度和高度的偏移计算得到,其中,锚框也称为Anchor Box,是目标检测算法中以锚点为中心,由算法预定义的多个不同长宽比的先验框,在本实施例中,每一个预测框均存在一个与其对应的锚框。
Figure 37867DEST_PATH_IMAGE093
的计算公式如下:
Figure 279492DEST_PATH_IMAGE094
其中,
Figure 882512DEST_PATH_IMAGE095
Figure 260404DEST_PATH_IMAGE096
Figure 861149DEST_PATH_IMAGE097
Figure 324492DEST_PATH_IMAGE098
Figure 481935DEST_PATH_IMAGE099
Figure 714333DEST_PATH_IMAGE100
分别表示第一原始图像的预测框的中心点位置相对于对应的锚框的中心位置的x坐标偏移和y坐标偏移的偏移,
Figure 548297DEST_PATH_IMAGE101
Figure 498935DEST_PATH_IMAGE102
分别表示第一翻转图像的预测框的中心点位置相对于对应的锚框的中心位置的x坐标偏移和y坐标偏移的偏移,
Figure 381441DEST_PATH_IMAGE103
Figure 48439DEST_PATH_IMAGE104
分别表示第一原始图像的预测框的宽、高相对于对应的锚框的宽和高的偏移量、
Figure 990987DEST_PATH_IMAGE105
Figure 428922DEST_PATH_IMAGE106
分别表示第一翻转图像的预测框的宽、高相对于对应的锚框的宽和高的偏移量,
Figure 177435DEST_PATH_IMAGE107
Figure 384425DEST_PATH_IMAGE108
对应、
Figure 497875DEST_PATH_IMAGE109
Figure 236155DEST_PATH_IMAGE110
对应,表示如下:
Figure 460463DEST_PATH_IMAGE111
Figure 849856DEST_PATH_IMAGE112
,其中,
Figure 868627DEST_PATH_IMAGE113
表示两组数据之间的对应关系。
Figure 281154DEST_PATH_IMAGE114
类似,
Figure 620737DEST_PATH_IMAGE115
的计算公式如下:
Figure 802320DEST_PATH_IMAGE116
其中,
Figure 257572DEST_PATH_IMAGE117
Figure 954133DEST_PATH_IMAGE118
Figure 520243DEST_PATH_IMAGE119
Figure 821912DEST_PATH_IMAGE120
Figure 261114DEST_PATH_IMAGE121
Figure 382654DEST_PATH_IMAGE122
分别表示第二原始图像的预测框的中心点位置相对于对应的锚框的中心位置的x坐标偏移和y坐标偏移的偏移,
Figure 814772DEST_PATH_IMAGE123
Figure 970947DEST_PATH_IMAGE124
分别表示第二翻转图像的预测框的中心点位置相对于对应的锚框的中心位置的x坐标偏移和y坐标偏移的偏移,
Figure 94235DEST_PATH_IMAGE125
Figure 703071DEST_PATH_IMAGE126
分别表示第二原始图像的预测框的宽、高相对于对应的锚框的宽和高的偏移量、
Figure 610984DEST_PATH_IMAGE127
Figure 683982DEST_PATH_IMAGE128
分别表示第二翻转图像的预测框的宽、高相对于对应的锚框的宽和高的偏移量,
Figure 651938DEST_PATH_IMAGE129
Figure 748070DEST_PATH_IMAGE130
对应、
Figure 272723DEST_PATH_IMAGE131
Figure 137911DEST_PATH_IMAGE132
对应,表示如下:
Figure 276769DEST_PATH_IMAGE133
Figure 922514DEST_PATH_IMAGE134
,其中,
Figure 437809DEST_PATH_IMAGE135
表示两组数据之间的对应关系。
需要说明的是,预测结果可以包括一个预测框,也可以包括多个预测框,当预测框为一个时,在原始模型训练过程中可以使用上述预设损失函数计算损失值,当预测框为多个时,每一个预测框的处理都是一样的,此时,针对任一预测框,可以使用上述预设损失函数计算每一预测框的损失值,再对所有预测框的损失值进行处理,得到总损失值,处理方式包括、但不限于加权计算、平均计算、取最大值、取最小值等。
在本实施例中,为了更清楚地说明训练过程,本发明实施例还提供了一种训练过程的示例图,请参照图6,图6示出了本发明实施例提供的训练过程示例图,图6中,虽然画出两个特征提取网络和两个回归网络,在实际实现上,特征提取网络只有一个,回归网络也只有一个,第一样本图像包括第一原始图像及第一原始图像经过翻转得到第一翻转图像,第一原始图像及第一翻转图像均带有各自的标签,第二样本图像包括第二原始图像及第二原始图像经过翻转得到第二翻转图像,第二原始图像及第二翻转图像均不带标签,将第一原始图像和第一翻转图像分别输入特征提取网络,得到第一原始图像的特征图和第一翻转图像的特征图,第一特征图包括第一原始图像的特征图和第一翻转图像的特征图,将第二原始图像和第二翻转图像分别输入特征提取网络,得到第二原始图像的特征图和第二翻转图像的特征图,第二特征图包括第二原始图像的特征图和第二翻转图像的特征图,利用注意力约束网络的池化层对第一特征图进行池化处理,得到注意力向量,将注意力向量和第二特征图输入至注意力约束网络的强化层进行特征强化,得到第二特征图的强化特征图,第二特征图的强化特征图包括第二原始图像的特征图的强化特征图和第二翻转图像的特征图的强化特征图,将第一特征图输入至回归网络得到第一检测结果,将强化特征图输入至回归网络得到第二检测结果,最终根据第一检测结果、标签、第二检测结果及预设损失函数,对原始模型的参数进行调整,直至达到预设的训练完成条件,得到检测模型。
在本实施例中,在对原始模型利用上述模型训练方法进行训练得到检测模型后,为了利用检测模型对待检测图像进行处理,以检测待检测图像中的目标图像,本发明实施例还提了一种图像处理的实现方式,请参照图7,图7示出了本发明实施例提供的图像处理方法的流程示例图,该方法包括以下步骤:
步骤S200,获取包含目标图像的待处理图像。
在本实施例中,待处理图像可以是超声影像,目标图像可以是病灶区域的图像,待处理图像还可以是由拍摄装置针对目标物体拍摄的图像,目标图像为该图像中目标物体的图像。
步骤S201,将待处理图像输入检测模型,检测模型是通过本申请实施例上述模型训练方法进行训练得到的,检测模型包括特征提取网络和回归网络。
在本实施例中,检测模型和原始模型的结构有所不同,检测模型不包括原始模型中的注意力约束网络。
步骤S202,利用检测模型的特征提取网络,得到特征图像。
在本实施例中,利用检测模型的特征提取网络得到特征图像的过程与前述实施例中得到第一特征图和第二特征图的过程一样,此处不再赘述。
步骤S203,利用检测模型的回归网络对特征图像进行目标检测,以检测出目标图像。
在本实施例中,利用检测模型的回归网络对特征图像进行目标检测,以检测出目标图像的过程与得到第一检测结果和第二检测结果的过程一样,此处不再赘述。
本发明实施例提供的上述方法,通过检测模型的特征提取网络,得到特征图像,通过检测模型的回归网络对特征图像进行目标检测,可以精确地检测出待检测图像中的目标图像。
为了执行上述实施例及各个可能的实施方式中模型训练方法的相应步骤,下面给出一种模型训练装置100的实现方式。请参照图8,图8示出了本发明实施例提供的模型训练装置100的方框示意图。需要说明的是,本实施例所提供的模型训练装置100,其基本原理及产生的技术效果和上述实施例相同,为简要描述,本实施例部分未提及指出。
模型训练装置100包括样本获取模块110和训练模块120。
样本获取模块110,用于获取带标签的第一样本图像和不带标签的第二样本图像。
训练模块120,用于将第一样本图像及第二样本图像分别输入至原始模型的特征提取网络,得到第一特征图和第二特征图。
训练模块120,还用于将第一特征图和所述第二特征图输入至原始模型的注意力约束网络,得到第二特征图的强化特征图。
训练模块120,还用于将第一特征图及强化特征图分别输入至原始模型的回归网络,得到原始模型的第一检测结果和第二检测结果。
训练模块120,还用于根据第一检测结果、标签、第二检测结果及预设损失函数,对原始模型的参数进行调整,直至达到预设的训练完成条件,得到检测模型。
具体地,训练模块120具体用于:利用注意力约束网络的池化层对第一特征图进行池化处理,得到注意力向量;将注意力向量和第二特征图输入至注意力约束网络的强化层进行特征强化,得到第二特征图的强化特征图。
具体地,训练模块120在用于利用注意力约束网络的池化层对第一特征图进行池化处理,得到注意力向量时,具体用于:根据第一特征图及标签,从第一特征图中确定目标区域;将目标区域输入至注意力约束网络的池化层进行处理,得到注意力向量。
具体地,所述第一样本图像包括第一原始图像和对第一原始图像进行翻转得到的第一翻转图像,第二样本图像包括第二原始图像和对第二原始图像进行翻转得到的第二翻转图像,第一检测结果包括第一原始图像的检测结果和第一翻转图像的检测结果,第二检测结果包括第二原始图像的检测结果和第二翻转图像的检测结果,训练模块120中的预设损失函数为:
Figure 157503DEST_PATH_IMAGE136
其中,
Figure 467261DEST_PATH_IMAGE137
表示预设损失函数,
Figure 849570DEST_PATH_IMAGE138
表示有监督损失函数,
Figure 902977DEST_PATH_IMAGE139
表示一致性损失权重函数,
Figure 742757DEST_PATH_IMAGE140
表示一致性损失函数,
Figure 285734DEST_PATH_IMAGE141
表示第一原始图像的标签,
Figure 843754DEST_PATH_IMAGE142
表示第一翻转图像的标签,
Figure 248322DEST_PATH_IMAGE143
表示第一原始图像的检测结果,
Figure 942608DEST_PATH_IMAGE144
表示第一翻转图像的检测结果,
Figure 859748DEST_PATH_IMAGE145
表示第二原始图像的检测结果,
Figure 701803DEST_PATH_IMAGE146
表示第二翻转图像的检测结果。
具体地,第一原始图像的预测结果包括第一原始图像的类别预测结果和位置预测结果,第一翻转图像的预测结果包括第一翻转图像的类别预测结果和位置预测结果,第二原始图像的预测结果包括第二原始图像的类别预测结果和位置预测结果,第二翻转图像的预测结果包括第二翻转图像的类别预测结果和位置预测结果;
所述训练模块120中的有监督损失函数为:
Figure 97012DEST_PATH_IMAGE147
Figure 911384DEST_PATH_IMAGE148
表示有监督分类损失函数,
Figure 48361DEST_PATH_IMAGE149
表示有监督位置损失函数,
Figure 315394DEST_PATH_IMAGE150
Figure 576611DEST_PATH_IMAGE151
分别表示第一原始图像的类别预测结果和位置预测结果,
Figure 245490DEST_PATH_IMAGE152
Figure 238854DEST_PATH_IMAGE153
分别表示第一翻转图像的类别预测结果和位置预测结果,
Figure 806232DEST_PATH_IMAGE067
Figure 543244DEST_PATH_IMAGE154
分别表示第一原始图像的类别标签和位置标签,
Figure 66629DEST_PATH_IMAGE155
Figure 558791DEST_PATH_IMAGE156
分别表示第一翻转图像的类别标签和位置标签;
所述训练模块120中的一致性损失函数为:
Figure 800416DEST_PATH_IMAGE157
,其中,
Figure 75540DEST_PATH_IMAGE158
表示分类一致性损失函数,
Figure 30595DEST_PATH_IMAGE159
表示位置一致性损失函数;
Figure 365761DEST_PATH_IMAGE160
Figure 94683DEST_PATH_IMAGE161
分别表示第二原始图像的类别预测结果和位置预测结果,
Figure 235814DEST_PATH_IMAGE162
Figure 733792DEST_PATH_IMAGE163
分别表示第二翻转图像的类别预测结果和位置预测结果;
分类一致性损失函数为:
Figure 505439DEST_PATH_IMAGE164
,其中,
Figure 269126DEST_PATH_IMAGE031
表示求均值,
Figure 151632DEST_PATH_IMAGE165
表示
Figure 504116DEST_PATH_IMAGE150
Figure 508981DEST_PATH_IMAGE152
之间的詹森香农JS散度,
Figure 946915DEST_PATH_IMAGE166
表示
Figure 367532DEST_PATH_IMAGE167
Figure 154616DEST_PATH_IMAGE168
之间的詹森香农JS散度;
所述位置一致性损失函数为:
Figure 268066DEST_PATH_IMAGE169
其中,
Figure 255613DEST_PATH_IMAGE170
表示
Figure 479921DEST_PATH_IMAGE171
Figure 541418DEST_PATH_IMAGE172
之间的位置一致性损失,
Figure 638818DEST_PATH_IMAGE173
表示
Figure 51345DEST_PATH_IMAGE174
Figure 79344DEST_PATH_IMAGE175
之间的位置一致性损失。
为了执行上述实施例及各个可能的实施方式中模型训练方法的相应步骤,下面给出一种图像处理装置200的实现方式。请参照图9,图9示出了本发明实施例提供的图像处理装置200的方框示意图。需要说明的是,本实施例所提供的图像处理装置200,其基本原理及产生的技术效果和上述实施例相同,为简要描述,本实施例部分未提及指出。
图像处理装置200包括图像获取模块210和处理模块220。
图像获取模块210,用于获取包含目标图像的待处理图像。
处理模块220,用于将待处理图像输入检测模型,利用检测模型的特征提取网络,得到特征图像,检测模型是通过本申请实施例上述模型训练方法进行训练得到的,检测模型包括特征提取网络和回归网络。
处理模块220,还用于利用检测模型的回归网络对特征图像进行目标检测,以检测出目标图像。
请参照图10,图10示出了本申请实施例提供的电子设备10的方框示意图。电子设备10可以是计算机设备,例如,智能手机、平板电脑、个人电脑、服务器、地面站、私有云、公有云等中的任意一种,上述设备都可以用于实现上述实施例提供的模型训练方法或者图像处理方法,具体可根据实际应用场景确定,在此不作限制。电子设备10包括处理器11、存储器12及总线13,处理器11通过总线13与存储器12连接。
存储器12用于存储程序,例如图8所示的模型训练装置100或者图9所示的图像处理装置200,模型训练装置100或者图像处理装置200均包括至少一个可以软件或固件(firmware)的形式存储于存储器12中的软件功能模块,处理器11在接收到执行指令后,执行所述程序以实现上述实施例揭示的模型训练方法或者图像处理方法。
存储器12可能包括高速随机存取存储器(Random Access Memory,RAM),也可能还包括非易失存储器(non-volatile memory,NVM)。
处理器11可能是一种集成电路芯片,具有信号的处理能力。在实现过程中,上述方法的各步骤可以通过处理器11中的硬件的集成逻辑电路或者软件形式的指令完成。上述的处理器11可以是通用处理器,包括中央处理器(Central Processing Unit,CPU)、微控制单元(Microcontroller Unit,MCU)、复杂可编程逻辑器件(Complex Programmable LogicDevice,CPLD)、现场可编程门阵列(Field Programmable Gate Array,FPGA)、嵌入式ARM等芯片。
本申请实施例还提供了一种计算机可读存储介质,其上存储有计算机程序,计算机程序被处理器11执行时实现上述实施例揭示的模型训练方法或者图像处理方法。
综上所述,本发明实施例提供了一种模型训练方法、图像处理方法及相关装置,所述方法包括:获取带标签的第一样本图像和不带标签的第二样本图像;将所述第一样本图像及所述第二样本图像分别输入至原始模型的特征提取网络,得到第一特征图和第二特征图;将所述第一特征图和所述第二特征图输入至所述原始模型的注意力约束网络,得到所述第二特征图的强化特征图;将所述第一特征图及所述强化特征图分别输入至所述原始模型的回归网络,得到所述原始模型的第一检测结果和第二检测结果;根据所述第一检测结果、所述标签、所述第二检测结果及预设损失函数,对所述原始模型的参数进行调整,直至达到预设的训练完成条件,得到检测模型。与现有技术相比,本发明实施例利用注意力约束网络对无标签的样本图像的特征图进行了特征强化,使得模型训练时,能够充分利用无标签的样本图像的特征,在减少样本图像打标签的工作量的同时,使训练得到的检测模型具有较高的检测精度和泛化能力。
以上所述,仅为本发明的具体实施方式,但本发明的保护范围并不局限于此,任何熟悉本技术领域的技术人员在本发明揭露的技术范围内,可轻易想到的变化或替换,都应涵盖在本发明的保护范围之内。因此,本发明的保护范围应以所述权利要求的保护范围为准。

Claims (10)

1.一种模型训练方法,其特征在于,所述方法包括:
获取带标签的第一样本图像和不带标签的第二样本图像;
将所述第一样本图像及所述第二样本图像分别输入至原始模型的特征提取网络,得到第一特征图和第二特征图;
将所述第一特征图和所述第二特征图输入至所述原始模型的注意力约束网络,得到所述第二特征图的强化特征图;
将所述第一特征图及所述强化特征图分别输入至所述原始模型的回归网络,得到所述原始模型的第一检测结果和第二检测结果;
根据所述第一检测结果、所述标签、所述第二检测结果及预设损失函数,对所述原始模型的参数进行调整,直至达到预设的训练完成条件,得到检测模型。
2.如权利要求1所述的模型训练方法,其特征在于,所述将所述第一特征图和所述第二特征图输入至所述原始模型的注意力约束网络,得到所述第二特征图的强化特征图的步骤包括:
利用所述注意力约束网络的池化层对所述第一特征图进行池化处理,得到注意力向量;
将所述注意力向量和所述第二特征图输入至所述注意力约束网络的强化层进行特征强化,得到所述第二特征图的强化特征图。
3.如权利要求2所述的模型训练方法,其特征在于,所述利用所述注意力约束网络的池化层对所述第一特征图进行池化处理,得到注意力向量的步骤包括:
根据所述第一特征图及所述标签,从所述第一特征图中确定目标区域;
将所述目标区域输入至所述注意力约束网络的池化层进行处理,得到所述注意力向量。
4.如权利要求1所述的模型训练方法,其特征在于,所述第一样本图像包括第一原始图像和对所述第一原始图像进行翻转得到的第一翻转图像,所述第二样本图像包括第二原始图像和对所述第二原始图像进行翻转得到的第二翻转图像,所述第一检测结果包括所述第一原始图像的检测结果和所述第一翻转图像的检测结果,所述第二检测结果包括所述第二原始图像的检测结果和所述第二翻转图像的检测结果,所述预设损失函数为:
Figure DEST_PATH_IMAGE001
其中,
Figure 518934DEST_PATH_IMAGE002
表示预设损失函数,
Figure DEST_PATH_IMAGE003
表示有监督损失函数,
Figure 788241DEST_PATH_IMAGE004
表示一致性损失权重函数,
Figure DEST_PATH_IMAGE005
表示一致性损失函数,
Figure 636112DEST_PATH_IMAGE006
表示所述第一原始图像的标签,
Figure DEST_PATH_IMAGE007
表示所述第一翻转图像的标签,
Figure 129717DEST_PATH_IMAGE008
表示所述第一原始图像的检测结果,
Figure DEST_PATH_IMAGE009
表示所述第一翻转图像的检测结果,
Figure 416342DEST_PATH_IMAGE010
表示所述第二原始图像的检测结果,
Figure DEST_PATH_IMAGE011
表示所述第二翻转图像的检测结果。
5.如权利要求4所述的模型训练方法,其特征在于,所述第一原始图像的预测结果包括所述第一原始图像的类别预测结果和位置预测结果,所述第一翻转图像的预测结果包括所述第一翻转图像的类别预测结果和位置预测结果,所述第二原始图像的预测结果包括所述第二原始图像的类别预测结果和位置预测结果,所述第二翻转图像的预测结果包括所述第二翻转图像的类别预测结果和位置预测结果;
所述有监督损失函数为:
Figure 743418DEST_PATH_IMAGE012
Figure DEST_PATH_IMAGE013
表示有监督分类损失函数,
Figure 575239DEST_PATH_IMAGE014
表示有监督位置损失函数,
Figure DEST_PATH_IMAGE015
Figure 315662DEST_PATH_IMAGE016
分别表示所述第一原始图像的类别预测结果和位置预测结果,
Figure DEST_PATH_IMAGE017
Figure 124087DEST_PATH_IMAGE018
分别表示所述第一翻转图像的类别预测结果和位置预测结果,
Figure DEST_PATH_IMAGE019
Figure 367986DEST_PATH_IMAGE020
分别表示所述第一原始图像的类别标签和位置标签,
Figure DEST_PATH_IMAGE021
Figure 823238DEST_PATH_IMAGE022
分别表示所述第一翻转图像的类别标签和位置标签;
所述一致性损失函数为:
Figure DEST_PATH_IMAGE023
,其中,
Figure 801690DEST_PATH_IMAGE024
表示分类一致性损失函数,
Figure DEST_PATH_IMAGE025
表示位置一致性损失函数;
Figure 367800DEST_PATH_IMAGE026
Figure DEST_PATH_IMAGE027
分别表示所述第二原始图像的类别预测结果和位置预测结果,
Figure 200627DEST_PATH_IMAGE028
Figure DEST_PATH_IMAGE029
分别表示所述第二翻转图像的类别预测结果和位置预测结果;
所述分类一致性损失函数为:
Figure 406874DEST_PATH_IMAGE030
,其中,
Figure DEST_PATH_IMAGE031
表示求均值,
Figure 793993DEST_PATH_IMAGE032
表示
Figure DEST_PATH_IMAGE033
Figure 694953DEST_PATH_IMAGE034
之间的詹森香农JS散度,
Figure DEST_PATH_IMAGE035
表示
Figure 664177DEST_PATH_IMAGE036
Figure DEST_PATH_IMAGE037
之间的詹森香农JS散度;
所述位置一致性损失函数为:
Figure 789128DEST_PATH_IMAGE038
其中,
Figure DEST_PATH_IMAGE039
表示
Figure 397964DEST_PATH_IMAGE040
Figure DEST_PATH_IMAGE041
之间的位置一致性损失,
Figure 617461DEST_PATH_IMAGE042
表示
Figure DEST_PATH_IMAGE043
Figure 690459DEST_PATH_IMAGE044
之间的位置一致性损失。
6.一种图像处理方法,其特征在于,所述方法包括:
获取包含目标图像的待处理图像;
将所述待处理图像输入检测模型,所述检测模型是通过权利要求1-5中任一项所述模型训练方法进行训练得到的,所述检测模型包括特征提取网络和回归网络;
利用所述检测模型的特征提取网络,得到特征图像;
利用所述检测模型的回归网络对所述特征图像进行目标检测,以检测出所述目标图像。
7.一种模型训练装置,其特征在于,所述装置包括:
样本获取模块,获取带标签的第一样本图像和不带标签的第二样本图像;
训练模块,用于将所述第一样本图像及所述第二样本图像分别输入至原始模型的特征提取网络,得到第一特征图和第二特征图;
所述训练模块,还用于将所述第一特征图和所述第二特征图输入至所述原始模型的注意力约束网络,得到所述第二特征图的强化特征图;
所述训练模块,还用于将所述第一特征图及所述强化特征图分别输入至所述原始模型的回归网络,得到所述原始模型的第一检测结果和第二检测结果;
所述训练模块,还用于根据所述第一检测结果、所述标签、所述第二检测结果及预设损失函数,对所述原始模型的参数进行调整,直至达到预设的训练完成条件,得到检测模型。
8.一种图像处理装置,其特征在于,所述装置包括:
图像获取模块,用于获取包含目标图像的待处理图像;
处理模块,用于将所述待处理图像输入检测模型,利用所述检测模型的特征提取网络,得到特征图像,所述检测模型是通过权利要求1-5中任一项所述模型训练方法进行训练得到的,所述检测模型包括特征提取网络和回归网络;
所述处理模块,还用于利用所述检测模型的回归网络对所述特征图像进行目标检测,以检测出所述目标图像。
9.一种电子设备,其特征在于,包括处理器和存储器;所述存储器用于存储程序;所述处理器用于在执行所述程序时,实现如权利要求1-5中任一项所述的模型训练方法和/或权利要求6所述的图像处理方法。
10.一种计算机可读存储介质,其上存储有计算机程序,其特征在于,该计算机程序被处理器执行时实现如权利要求1-5中任一项所述的模型训练方法和/或权利要求6所述的图像处理方法。
CN202210115405.4A 2022-02-07 2022-02-07 模型训练方法、图像处理方法及相关装置 Active CN114155365B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202210115405.4A CN114155365B (zh) 2022-02-07 2022-02-07 模型训练方法、图像处理方法及相关装置

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202210115405.4A CN114155365B (zh) 2022-02-07 2022-02-07 模型训练方法、图像处理方法及相关装置

Publications (2)

Publication Number Publication Date
CN114155365A true CN114155365A (zh) 2022-03-08
CN114155365B CN114155365B (zh) 2022-06-14

Family

ID=80449938

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202210115405.4A Active CN114155365B (zh) 2022-02-07 2022-02-07 模型训练方法、图像处理方法及相关装置

Country Status (1)

Country Link
CN (1) CN114155365B (zh)

Cited By (6)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN114612717A (zh) * 2022-03-09 2022-06-10 四川大学华西医院 Ai模型训练标签生成方法、训练方法、使用方法及设备
CN114638829A (zh) * 2022-05-18 2022-06-17 安徽数智建造研究院有限公司 隧道衬砌检测模型的抗干扰训练方法及隧道衬砌检测方法
CN114724183A (zh) * 2022-04-08 2022-07-08 平安科技(深圳)有限公司 人体关键点检测方法、系统、电子设备及可读存储介质
CN115439686A (zh) * 2022-08-30 2022-12-06 一选(浙江)医疗科技有限公司 一种基于扫描影像的关注对象检测方法及系统
CN115861684A (zh) * 2022-11-18 2023-03-28 百度在线网络技术(北京)有限公司 图像分类模型的训练方法、图像分类方法及装置
CN114724183B (zh) * 2022-04-08 2024-05-24 平安科技(深圳)有限公司 人体关键点检测方法、系统、电子设备及可读存储介质

Citations (11)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN111539947A (zh) * 2020-04-30 2020-08-14 上海商汤智能科技有限公司 图像检测方法及相关模型的训练方法和相关装置、设备
CN112200722A (zh) * 2020-10-16 2021-01-08 鹏城实验室 图像超分辨重构模型的生成方法、重构方法及电子设备
CN112818903A (zh) * 2020-12-10 2021-05-18 北京航空航天大学 一种基于元学习和协同注意力的小样本遥感图像目标检测方法
CN112949549A (zh) * 2021-03-19 2021-06-11 中山大学 一种基于超分辨率的多分辨率遥感影像的变化检测方法
CN113240655A (zh) * 2021-05-21 2021-08-10 深圳大学 一种自动检测眼底图像类型的方法、存储介质及装置
CN113392855A (zh) * 2021-07-12 2021-09-14 昆明理工大学 一种基于注意力和对比学习的小样本目标检测方法
CN113449775A (zh) * 2021-06-04 2021-09-28 广州大学 一种基于类激活映射机制的多标签图像分类方法和系统
US20210326656A1 (en) * 2020-04-15 2021-10-21 Adobe Inc. Panoptic segmentation
CN113688931A (zh) * 2021-09-01 2021-11-23 什维新智医疗科技(上海)有限公司 一种基于深度学习的超声图像筛选方法和装置
CN113971764A (zh) * 2021-10-29 2022-01-25 燕山大学 一种基于改进YOLOv3的遥感图像小目标检测方法
CN114004760A (zh) * 2021-10-22 2022-02-01 北京工业大学 图像去雾方法、电子设备、存储介质和计算机程序产品

Patent Citations (11)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20210326656A1 (en) * 2020-04-15 2021-10-21 Adobe Inc. Panoptic segmentation
CN111539947A (zh) * 2020-04-30 2020-08-14 上海商汤智能科技有限公司 图像检测方法及相关模型的训练方法和相关装置、设备
CN112200722A (zh) * 2020-10-16 2021-01-08 鹏城实验室 图像超分辨重构模型的生成方法、重构方法及电子设备
CN112818903A (zh) * 2020-12-10 2021-05-18 北京航空航天大学 一种基于元学习和协同注意力的小样本遥感图像目标检测方法
CN112949549A (zh) * 2021-03-19 2021-06-11 中山大学 一种基于超分辨率的多分辨率遥感影像的变化检测方法
CN113240655A (zh) * 2021-05-21 2021-08-10 深圳大学 一种自动检测眼底图像类型的方法、存储介质及装置
CN113449775A (zh) * 2021-06-04 2021-09-28 广州大学 一种基于类激活映射机制的多标签图像分类方法和系统
CN113392855A (zh) * 2021-07-12 2021-09-14 昆明理工大学 一种基于注意力和对比学习的小样本目标检测方法
CN113688931A (zh) * 2021-09-01 2021-11-23 什维新智医疗科技(上海)有限公司 一种基于深度学习的超声图像筛选方法和装置
CN114004760A (zh) * 2021-10-22 2022-02-01 北京工业大学 图像去雾方法、电子设备、存储介质和计算机程序产品
CN113971764A (zh) * 2021-10-29 2022-01-25 燕山大学 一种基于改进YOLOv3的遥感图像小目标检测方法

Non-Patent Citations (3)

* Cited by examiner, † Cited by third party
Title
B SINGH ET AL: "《An Analysis of Scale Invariance in Object Detection Snip》", 《IEEE》 *
刘鑫辰: "《城市视频监控网络中车辆搜索关键技术研究》", 《中国博士学位论文全文数据库 信息科技辑》 *
陈珺莹: "《基于区域信息增强的细粒度图像分类研究及应用》", 《中国优秀硕士学位论文全文数据库 信息科技辑》 *

Cited By (8)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN114612717A (zh) * 2022-03-09 2022-06-10 四川大学华西医院 Ai模型训练标签生成方法、训练方法、使用方法及设备
CN114724183A (zh) * 2022-04-08 2022-07-08 平安科技(深圳)有限公司 人体关键点检测方法、系统、电子设备及可读存储介质
CN114724183B (zh) * 2022-04-08 2024-05-24 平安科技(深圳)有限公司 人体关键点检测方法、系统、电子设备及可读存储介质
CN114638829A (zh) * 2022-05-18 2022-06-17 安徽数智建造研究院有限公司 隧道衬砌检测模型的抗干扰训练方法及隧道衬砌检测方法
CN115439686A (zh) * 2022-08-30 2022-12-06 一选(浙江)医疗科技有限公司 一种基于扫描影像的关注对象检测方法及系统
CN115439686B (zh) * 2022-08-30 2024-01-09 一选(浙江)医疗科技有限公司 一种基于扫描影像的关注对象检测方法及系统
CN115861684A (zh) * 2022-11-18 2023-03-28 百度在线网络技术(北京)有限公司 图像分类模型的训练方法、图像分类方法及装置
CN115861684B (zh) * 2022-11-18 2024-04-09 百度在线网络技术(北京)有限公司 图像分类模型的训练方法、图像分类方法及装置

Also Published As

Publication number Publication date
CN114155365B (zh) 2022-06-14

Similar Documents

Publication Publication Date Title
CN114155365B (zh) 模型训练方法、图像处理方法及相关装置
CN109325954B (zh) 图像分割方法、装置及电子设备
Liu et al. Blind image quality assessment by relative gradient statistics and adaboosting neural network
US11830230B2 (en) Living body detection method based on facial recognition, and electronic device and storage medium
Li et al. Robust saliency detection via regularized random walks ranking
EP3333768A1 (en) Method and apparatus for detecting target
Jiang et al. Robust feature matching for remote sensing image registration via linear adaptive filtering
JP6232982B2 (ja) 画像処理装置、画像処理方法およびプログラム
CN107633237B (zh) 图像背景分割方法、装置、设备及介质
CN108986152B (zh) 一种基于差分图像的异物检测方法及装置
CN111626163B (zh) 一种人脸活体检测方法、装置及计算机设备
CN111429482A (zh) 目标跟踪方法、装置、计算机设备和存储介质
CN110598715A (zh) 图像识别方法、装置、计算机设备及可读存储介质
CN111814905A (zh) 目标检测方法、装置、计算机设备和存储介质
CN114359665A (zh) 全任务人脸识别模型的训练方法及装置、人脸识别方法
CN111382791B (zh) 深度学习任务处理方法、图像识别任务处理方法和装置
CN112348116A (zh) 利用空间上下文的目标检测方法、装置和计算机设备
CN107221005B (zh) 物体检测方法及装置
CN111382638B (zh) 一种图像检测方法、装置、设备和存储介质
CN114444565A (zh) 一种图像篡改检测方法、终端设备及存储介质
CN111507288A (zh) 图像检测方法、装置、计算机设备和存储介质
CN113095310B (zh) 人脸位置检测方法、电子设备和存储介质
Hossny et al. Towards autonomous image fusion
CN109871814B (zh) 年龄的估计方法、装置、电子设备和计算机存储介质
CN116977895A (zh) 用于通用相机镜头的污渍检测方法、装置及计算机设备

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant