CN114445679A - 模型训练方法及相关装置、设备和存储介质 - Google Patents

模型训练方法及相关装置、设备和存储介质 Download PDF

Info

Publication number
CN114445679A
CN114445679A CN202210101882.5A CN202210101882A CN114445679A CN 114445679 A CN114445679 A CN 114445679A CN 202210101882 A CN202210101882 A CN 202210101882A CN 114445679 A CN114445679 A CN 114445679A
Authority
CN
China
Prior art keywords
source domain
model
target
prediction result
training
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Withdrawn
Application number
CN202210101882.5A
Other languages
English (en)
Inventor
宋涛
张少霆
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Shanghai Sensetime Intelligent Technology Co Ltd
Original Assignee
Shanghai Sensetime Intelligent Technology Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Shanghai Sensetime Intelligent Technology Co Ltd filed Critical Shanghai Sensetime Intelligent Technology Co Ltd
Priority to CN202210101882.5A priority Critical patent/CN114445679A/zh
Publication of CN114445679A publication Critical patent/CN114445679A/zh
Withdrawn legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/217Validation; Performance evaluation; Active pattern learning techniques

Landscapes

  • Engineering & Computer Science (AREA)
  • Data Mining & Analysis (AREA)
  • Theoretical Computer Science (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Artificial Intelligence (AREA)
  • Evolutionary Biology (AREA)
  • Evolutionary Computation (AREA)
  • Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Management, Administration, Business Operations System, And Electronic Commerce (AREA)

Abstract

本申请公开了一种模型训练方法及相关装置、设备和存储介质,方法包括:获取基于源域样本数据训练得到的目标模型和评价模型,其中,评价模型用于对目标模型输出的预测结果的准确性进行评价;利用目标模型对目标域样本数据进行预测,得到目标域预测结果;利用评价模型对目标域预测结果进行评价,得到目标域评价结果;基于目标域评价结果,调整目标模型的网络参数。通过该方法,实现了目标模型的域适应。

Description

模型训练方法及相关装置、设备和存储介质
技术领域
本申请涉及深度学习技术领域,特别是涉及一种模型训练方法及相关装置、设备和存储介质。
背景技术
深度学习的快速发展,各行各业使用神经网络模型进行工作已经成为常态。例如,在医学领域,利用神经网络模型进行医学图像分割。又如,在交通领域,利用神经网络模型进行车辆识别。
随着神经网络模型的逐渐普及,神经网络模型的域不适应问题变得日益严重。神经网络模型的域不适应问题主要表现为利用源域数据训练的模型,在目标域数据上的效果不好。域不适应的问题极大地限制了神经网络模型的进一步普及。
因此,如何解决域不适应问题,是当下研究的重点,对于促进神经网络模型的进一步普及,具有重要的意义。
发明内容
本申请至少提供一种模型训练方法及相关装置、设备和存储介质。
本申请第一方面提供了一种模型训练方法,方法包括:获取基于源域样本数据训练得到的目标模型和评价模型,其中,评价模型用于对目标模型输出的预测结果的准确性进行评价;利用目标模型对目标域样本数据进行预测,得到目标域预测结果;利用评价模型对目标域预测结果进行评价,得到目标域评价结果;基于目标域评价结果,调整目标模型的网络参数。
因此,通过获得基于源域样本数据训练得到的目标模型和评价模型,并利用评价模型对目标模型基于目标域样本数据预测得到的目标域预测结果进行评价,以此实现了目标模型在目标域的训练,有助于提高目标模型在目标域样本数据的预测准确性,以此实现了目标模型的域适应。
其中,上述的获取基于源域样本数据训练得到的目标模型和评价模型,包括:基于源域样本数据对目标模型和评价模型进行至少一次迭代训练,其中,每次训练所基于的源域样本数据相同或不同。
因此,通过利用源域样本数据对目标模型和评价模型进行至少一次迭代训练,可以提高目标模型输出的预测结果的准确性,同时也能提高评价模型的评价准确度。
其中,上述的每次基于源域样本数据对目标模型和评价模型进行训练,包括:利用目标模型对源域样本数据进行预测,得到本次训练对应的第一源域预测结果;基于本次训练对应的第一源域预测结果,调整目标模型的网络参数;以及利用评价模型对第二源域预测结果进行评价,得到源域评价结果,其中,第二源域预测结果包括本次训练和/或历史训练对应的第一源域预测结果;基于第二源域预测结果以及源域评价结果,调整评价模型的网络参数。
因此,通过利用评价模型对第二源域预测结果进行评价来得到源域评价结果,后续便可基于第二源域预测结果以及源域评价结果,来调整评价模型的网络参数,以此实现对评价模型的训练。
其中,在利用评价模型对第二源域预测结果进行评价,得到源域评价结果之前,方法还包括:基于目标模型当前预测的准确性,将本次训练和/或历史训练对应的第一源域预测结果作为第二源域预测结果;其中,目标模型当前预测的准确性是基于前若干次训练对应的第一源域预测结果或前若干次训练中评价模型输出的评价结果确定的。
因此,通过确定目标模型当前预测的准确性,可以基于目标模型的训练程度,决定是否将本次训练和/或历史训练对应的第一源域预测结果作为第二源域预测结果。
其中,上述的基于目标模型当前预测的准确性,将本次训练和/或历史训练对应的第一源域预测结果作为第二源域预测结果,包括:响应于目标模型当前预测的准确性满足第一预设要求,选出至少一个历史训练对应的第一源域预测结果,并将本次训练对应的第一源域预测结果和选出的第一源域预测结果,作为第二源域预测结果;响应于目标模型当前预测的准确性不满足第一预设要求,将本次训练对应的第一源域预测结果作为第二源域预测结果。
因此,通过确定目标模型当前预测的准确性是否满足第一预设要求,可以确定是否将历史训练对应的第一源域预测结果作为第二源域预测结果,以此灵活控制输入到评价模型的预测结果的数量,而且根据在目标模型的预测准确性可增加历史训练的预测结果对评价模型进行训练,可提高对评价模型的训练效果。
其中,在利用目标模型对源域样本数据进行预测,得到本次训练对应的第一源域预测结果之后,方法还包括:响应于当前满足第二预设要求,将本次训练对应的第一源域预测结果保存至预设结果集中;选出至少一个历史训练对应的第一源域预测结果,包括:从预设结果集中选出至少一个第一源域预测结果。
因此,通过判断目标模型当前是否满足第二预设要求,可以确定是否将本次训练对应的第一源域预测结果保存至预设结果集中,进而可实现后续从预设结果集中获取到历史训练的预测结果,以用于对评价模型进行训练。
其中,上述的第二预设要求包括以下至少一者:目标模型当前预测的准确性不满足第一预设要求,当前目标模型的训练次数少于预设数量。
因此,当目标模型当前预测的准确性不满足第一预设要求和/或当前目标模型的训练次数少于预设数量时,通过将本次训练对应的第一源域预测结果保存至预设结果集中,可以使得预设结果集中能够存储有预测结果准确率较差的第一源域预测结果,使得后续从预设结果集中选择第一源域预测结果作为第二源域预测结果来对评价模型进行训练时,训练的样本数据更加丰富,提高评价模型对准确率较差的第一源域预测结果的识别能力,有助于提高评价模型的训练效果。
预测结果集中的第一源域预测结果是按照第一源域预测结果对应的第一损失值划分在不同损失值区间中;述从预设结果集中选出至少一个第一源域预测结果,包括:从预设结果集中随机选出至少一个第一源域预测结果;或者,从预设结果的各损失值区间中分别选出至少一个第一源域预测结果。
因此,通过利用第一损失值来对预测结果集中的第一源域预测结果进行分类,实现基于第一源域预测结果的准确率进行分类。其中,上述的基于第二源域预测结果以及源域评价结果,调整评价模型的网络参数,包括:获取第二源域预测结果对应的第一损失值,其中,第一损失值是基于第二源域预测结果与对应的源域数据的源域标注信息之间的差异确定的;基于第一损失值和源域评价结果,调整评价模型的网络参数。
因为第一损失值能够表示第二源域预测结果的优劣,源域评价结果也能够表示第二源域预测结果的优劣,因此可以将第一损失值作为标签信息,通过比较第一损失值和源域评价结果的差异,来调整评价模型的网络参数,使得源域评价结果能够与第一损失值相互对应。
其中,上述的调整目标模型的网络参数,包括:调整目标模型的部分网络层的参数。
因此,通过确定仅调整目标模型的部分网络层的参数,可以较少需要调整的参数量,有助于提高训练速度。
其中,上述的基于目标域评价结果,调整目标模型的网络参数,包括:基于目标域评价结果,得到第二损失值;基于第二损失值,调整目标模型的网络参数。
因此,通过基于目标域评价结果得到第二损失值,后续便可基于第二损失值,调整目标模型的网络参数,以此实现对目标模型的训练。
其中,上述的基于目标域评价结果,得到第二损失值,包括:对目标域评价结果进行预设运算,得到第二损失值;和/或,第二损失值与目标模型输出的预测结果的准确性为负相关关系。
因此,通过将第二损失值设置为与目标模型输出的预测结果的准确性为负相关关系,可以直观地通过第二损失值判断目标域预测结果的优劣。
其中,上述的源域样本数据和目标域样本数据均为包含目标器官的三维图像;目标模型为图像分割模型。
因此,通过限定源域样本数据和目标域样本数据均为包含目标器官的三维图像,使得在利用源域样本数据训练目标模型和评价模型时,评价模型能够学习到关于目标器官的先验,例如是标签空间的分布以及目标器官的形状。
本申请第二方面提供了一种模型训练装置,装置包括:获取模块、预测模块、确定模块和调整模块。预测模块用于利用目标模型对目标域样本数据进行预测,得到目标域预测结果;确定模块用于利用评价模型对目标域预测结果进行评价,得到目标域评价结果;调整模块用于基于目标域评价结果,调整目标模型的网络参数。
本申请第三方面提供了一种电子设备,包括相互耦接的存储器和处理器,处理器用于执行存储器中存储的程序指令,以实现上述第一方面中的模型训练方法。
本申请第四方面提供了一种计算机可读存储介质,其上存储有程序指令,程序指令被处理器执行时实现上述第一方面中的模型训练方法。
上述方案,通过获得基于源域样本数据训练得到的目标模型和评价模型,并利用评价模型对目标模型基于目标域样本数据预测得到的目标域预测结果进行评价,以此实现了目标模型在目标域的训练,有助于提高目标模型在目标域样本数据的预测准确性,以此实现了目标模型的域适应。
应当理解的是,以上的一般描述和后文的细节描述仅是示例性和解释性的,而非限制本申请。
附图说明
此处的附图被并入说明书中并构成本说明书的一部分,这些附图示出了符合本申请的实施例,并与说明书一起用于说明本申请的技术方案。
图1是本申请模型训练方法第一实施例的第一流程示意图;
图2是本申请模型训练方法第二实施例的流程示意图;
图3是本申请模型训练方法第三实施例的流程示意图;
图4是本申请模型训练方法第四实施例的流程示意图;
图5是本申请模型训练方法第一实施例的第二流程示意图;
图6是本申请模型训练方法实施例目标模型的一结构示意图;
图7是本申请模型训练方法实施例评价模型的一结构示意图;
图8是本申请用于模型训练装置的一框架示意图;
图9是本申请电子设备一实施例的框架示意图;
图10是本申请计算机可读存储介质一实施例的框架示意图。
具体实施方式
下面结合说明书附图,对本申请实施例的方案进行详细说明。
以下描述中,为了说明而不是为了限定,提出了诸如特定系统结构、接口、技术之类的具体细节,以便透彻理解本申请。
本文中术语“和/或”,仅仅是一种描述关联对象的关联关系,表示可以存在三种关系,例如,A和/或B,可以表示:单独存在A,同时存在A和B,单独存在B这三种情况。另外,本文中字符“/”,一般表示前后关联对象是一种“或”的关系。此外,本文中的“多”表示两个或者多于两个。另外,本文中术语“至少一种”表示多种中的任意一种或多种中的至少两种的任意组合,例如,包括A、B、C中的至少一种,可以表示包括从A、B和C构成的集合中选择的任意一个或多个元素。
请参阅图1,图1是本申请模型训练方法第一实施例的第一流程示意图。具体而言,可以包括如下步骤:
步骤S11:获取基于源域样本数据训练得到的目标模型和评价模型。
在本申请中,源域样本数据可以是图像数据、例如是医学图像数据,也可以是文本数据等等,本申请不做限制。利用源域样本数据对目标模型进行训练的过程,可以是利用本领域通用的监督学习方法来训练。
在本申请中,评价模型用于对目标模型输出的预测结果的准确性进行评价。评价模型的例如是编码-预测结构。评价模型的输入为目标模型的输出。例如,目标模型为图像分割模型,则目标模型的输出结果为预测分割结果,则输入至评价模型的数据为目标模型输出的预测分割结果,评价模型能够输出对输入的预测结果的评价分数。此外,评价模型对目标模型输出的预测结果的准确性进行评价,表明评价模型能够识别目标模型输出的预测结果的优劣。例如,目标模型输出的预测结果越好(损失值越小),利用评价模型得到的评价分数越高。
在一个实施方式中,可以基于源域样本数据对目标模型和评价模型进行至少一次迭代训练。在本实施例方式中,每次训练所基于的源域样本数据相同或不同。也即,可以利用相同或不同的源域样本数据,对目标模型和评价模型执行不少于一次的训练。在一个具体实施方式中,可以进行多次的迭代训练。因此,通过利用源域样本数据对目标模型和评价模型进行至少一次迭代训练,可以提高目标模型输出的预测结果的准确性,同时也能提高评价模型的评价准确度。
步骤S12:利用目标模型对目标域样本数据进行预测,得到目标域预测结果。
在本申请中,目标域样本数据不同于源域样本数据。在一个具体实施方式中,目标域样本数据与源域样本数据的不同可以体现为,对于同一目标对象,其来源不同。例如,目标对象为肺部三维图像,源域样本数据来源为A中心,目标域样本数据来源为B中心。在另一个具体实施方式中,目标域样本数据与源域样本数据的不同可以体现为模态不同,例如,源域样本数据的模态为M,目标域样本数据的模态为N。
将目标域样本数据输入到目标模型中,目标模型会对应的输出目标域预测结果。例如,目标域样本数据为某一器官图像,则目标域预测结果为该器官的分割结果。
步骤S13:利用评价模型对目标域预测结果进行评价,得到目标域评价结果。
具体的,可以将目标域预测结果输入到评价模型中,使得评价模型能够输出目标域评价结果。因为评价模型已经利用了源域样本数据进行训练,表明评价模型已经学习到了源域样本数据的特征信息,也即学习到了源域样本数据中的目标对象的特征信息,由于目标域样本数据与源域样本数据都是基于同一目标对象得到的,只是来源不同。因此,评价模型同样能够对目标域预测结果进行合理的评价,得到符合要求的目标域评价结果,即目标域评价结果也是对目标模型输出的目标域预测结果的准确性进行评价。
步骤S14:基于目标域评价结果,调整目标模型的网络参数。
因为目标域评价结果是对目标模型输出的目标域预测结果的准确性进行评价,因为可以反映出目标域预测结果的优劣。在一个实施方式中,也可以将目标域评价结果视为目标域预测结果的损失值。后续,根据目标域评价结果体现的目标域预测结果的优劣,便可相应调整目标模型的网络参数,实现目标模型在目标域的训练,进而提高目标模型在目标域样本数据的预测结果准确性。
因此,通过获得基于源域样本数据训练得到的目标模型和评价模型,并利用评价模型对目标模型基于目标域样本数据预测得到的目标域预测结果进行评价,以此实现了目标模型在目标域的训练,有助于提高目标模型在目标域样本数据的预测准确性,以此实现了目标模型的域适应。
在一个实施例中,源域样本数据和目标域样本数据均为包含目标器官的三维(3D)图像,目标器官例如是肺部、心脏、大脑等等。另外,目标模型为图像分割模型,图像分割模型例如是3D的U-Net模型、Mask R-CNN模型等等。因此,通过限定源域样本数据和目标域样本数据均为包含目标器官的三维图像,使得在利用源域样本数据训练目标模型和评价模型时,评价模型能够学习到关于目标器官的先验,例如是标签空间的分布以及目标器官的形状。
请参阅图2,图2是本申请模型训练方法第二实施例的流程示意图。在本实施例中,上述步骤提及的“每次基于源域样本数据对目标模型和评价模型进行训练”,具体包括步骤S21至步骤S24。
步骤S21:利用目标模型对源域样本数据进行预测,得到本次训练对应的第一源域预测结果。
通过将源域样本数据输入到目标模型中,可以得到对应的第一源域样本数据。例如,将目标器官的三维图像输入到目标模型中,能够得到目标器官的分割结果。在一个实施方式中,源域样本数据是有标签信息的数据。例如,对于目标器官的血管分割而言,标签信息可以是每一个点是否为血管,血管的类型是动脉、静脉等等标签信息。
步骤S22:基于本次训练对应的第一源域预测结果,调整目标模型的网络参数。
在一个实施方式中,具体可以是基于本次训练对应的第一源域预测结果和标签信息,基于确定的损失函数,得到损失值,然后根据损失值来调整目标模型的网络参数。例如,对于目标器官的血管分割,则是基于血管分割的预测结果,血管的标签信息,得到对应的损失值,然后根据损失值来调整目标模型的网络参数。
步骤S23:利用评价模型对第二源域预测结果进行评价,得到源域评价结果。
在本实施例中,第二源域预测结果包括本次训练和/或历史训练对应的第一源域预测结果。历史训练对应的第一源域预测结果,可以是在本次训练之前,将源域样本数据输入到目标模型,由目标模型输出的预测结果。也即,在本实施例中,输入到评价的预测结果,可以是利用源域样本数据进行训练时,本次训练对应的预测结果,还可以包括之前利用源域样本数据进行训练得到的预测结果。通过增加输入到评价模型的预测结果,可以利用更多的数据对评价模型进行训练,有助于加快评价模型的训练速度。
评价模型可以分别对输入到评价模型中的第二源域预测结果中的每一个预测结果进行评价。例如,输入到评价模型中的第二源域预测结果中包含5个预测结果,则评价模型会分别对这5个预测结果进行评价。源域评价结果可以是基于第二源域预测结果中的每一个预测结果对应的评价结果而得到的。例如,可以对第二源域预测结果中的每一个预测结果对应的评价结果进行加权求和,以此得到源域评价结果。
步骤S24:基于第二源域预测结果以及源域评价结果,调整评价模型的网络参数。
源域评价结果表明评价模型对第二源域预测结果中的每一个预测结果的准确性高低,作用与利用损失函数得到的损失值相同。因此,可以基于第二源域预测结果中的每一个预测结果对应的损失值,以及源域评价结果,调整评价模型的网络参数。
具体的,可以针对第二源域预测结果中的每一个预测结果对应的损失值进行处理,以此得到一个综合的损失值。另外,源域评价结果也是基于第二源域预测结果中的每一个预测结果对应的评价结果而得到的综合的评价结果。以此,便可根据第二源域预测结果对应的损失值以及源域评价结果的差异,调整评价模型的网络参数,以使得评价模型经过训练后,能够正确评价目标模型输出的预测结果的准确性。
因此,通过利用评价模型对第二源域预测结果进行评价来得到源域评价结果,后续便可基于第二源域预测结果以及源域评价结果,来调整评价模型的网络参数,以此实现对评价模型的训练。
在一个实施例中,在步骤“利用评价模型对第二源域预测结果进行评价,得到源域评价结果”之前,模型训练方法还包括:基于目标模型当前预测的准确性,将本次训练和/或历史训练对应的第一源域预测结果作为第二源域预测结果。
在本实施例中,目标模型当前预测的准确性是基于前若干次训练对应的第一源域预测结果或前若干次训练中评价模型输出的评价结果确定的。例如,目标模型当前预测的准确性是基于前0次训练对应的第一源域预测结果得到的,也即,目标模型当前预测的准确性可以是基于本次训练对应的第一源域预测结果得到。本次训练对应的第一源域预测结果为0.18,则表明目标模型当前预测的准确率为82%左右。又如,目标模型当前预测的准确性是基于前1次训练对应的第一源域预测结果得到的。前1次训练对应的第一源域预测结果的损失值为0.2,则表明目标模型当前预测的准确率为80%左右。又如,目标模型当前预测的准确性可以是基于前5次训练对应的第一源域预测结果得到的。前5次训练对应的第一源域预测结果的损失值为0.2、0.22、0.25、0.28和0.30,则目标模型当前预测的准确率可以是这5个损失值的平均值。再如,目标模型当前预测的准确性是基于前1次训练中评价模型输出的评价结果确定的。前1次训练评价模型输出的评价结果为8分(满分10分),则表明目标模型当前预测的准确率为80%左右。
因此,通过确定目标模型当前预测的准确性,可以基于目标模型的训练程度,决定是否将本次训练和/或历史训练对应的第一源域预测结果作为第二源域预测结果。
在一个实施方式中,上述步骤“基于目标模型当前预测的准确性,将本次训练和/或历史训练对应的第一源域预测结果作为第二源域预测结果”具体包括步骤1和步骤2(图未示)。
步骤1:响应于目标模型当前预测的准确性满足第一预设要求,选出至少一个历史训练对应的第一源域预测结果,并将本次训练对应的第一源域预测结果和选出的第一源域预测结果,作为第二源域预测结果。
在一个实施方式中,第一预设要求可以是目标模型当前预测的准确性达到某一阈值。阈值例如是70%的准确率。阈值的设置可以根据需要确定,此处不做限制。
在本实施例中,可以保存历史训练中每一次训练得到的第一源域预测结果。以此,当目标模型当前预测的准确性满足第一预设要求时,可以选出至少一个历史训练对应的第一源域预测结果,以及将本次训练对应的第一源域预测结果作为第二源域预测结果。
步骤2:响应于目标模型当前预测的准确性不满足第一预设要求,将本次训练对应的第一源域预测结果作为第二源域预测结果。
当目标模型当前预测的准确性不满足第一预设要求时,可以不将历史训练对应的第一源域预测结果作为第二源域预测结果,也即,仅将次训练对应的第一源域预测结果输入至评价模型中。
因此,通过确定目标模型当前预测的准确性是否满足第一预设要求,可以确定是否将历史训练对应的第一源域预测结果作为第二源域预测结果,以此灵活控制输入到评价模型的预测结果的数量,而且根据在目标模型的预测准确性可增加历史训练的预测结果对评价模型进行训练,可提高对评价模型的训练效果。
在一个实施例中,在步骤“利用目标模型对源域样本数据进行预测,得到本次训练对应的第一源域预测结果”之后,模型训练方法还包括:响应于当前满足第二预设要求,将本次训练对应的第一源域预测结果保存至预设结果集中。对应于存在预设结果集的实施例,上述步骤提及的“选出至少一个历史训练对应的第一源域预测结果”具体包括:从预设结果集中选出至少一个第一源域预测结果。因此,通过判断目标模型当前是否满足第二预设要求,可以确定是否将本次训练对应的第一源域预测结果保存至预设结果集中,进而可实现后续从预设结果集中获取到历史训练的预测结果,以用于对评价模型进行训练。
当前满足第二预设要求,可以是对目标模型的训练过程设定的要求。在一个实施方式中,第二预设要求包括以下至少一者:目标模型当前预测的准确性不满足第一预设要求,当前目标模型的训练次数少于预设数量。
目标模型当前预测的准确性不满足第一预设要求,可以认为目标模型本次训练输出的预测结果的准确率不满足要求。前目标模型的训练次数少于预设数量,可以认为目标模型的训练还处于比较早的阶段,目标模型输出的预测结果准确率不高。此时可以将本次训练对应的第一源域预测结果保存至预设结果集中,使得后续可以利用该第一源域预测结果对评价模型进行训练。
因此,当目标模型当前预测的准确性不满足第一预设要求和/或当前目标模型的训练次数少于预设数量时,通过将本次训练对应的第一源域预测结果保存至预设结果集中,可以使得预设结果集中能够存储有预测结果准确率较差的第一源域预测结果,使得后续从预设结果集中选择第一源域预测结果作为第二源域预测结果来对评价模型进行训练时,训练的样本数据更加丰富,提高评价模型对准确率较差的第一源域预测结果的识别能力,有助于提高评价模型的训练效果。
在一个实施方式中,预测结果集中的第一源域预测结果是按照第一源域预测结果对应的第一损失值划分在不同损失值区间中。也即,预测结果集的第一源域预测结果可以按照其对应的第一损失值进行分类。预设区间可以根据需要进行设置,此处不再赘述。因此,通过利用第一损失值来对预测结果集中的第一源域预测结果进行分类,实现基于第一源域预测结果的准确率进行分类。
请参阅图3,图3是本申请模型训练方法第三实施例的流程示意图。上述步骤“从预设结果集中选出至少一个第一源域预测结果”具体包括步骤S31或者步骤S32。
步骤S31:从预设结果集中随机选出至少一个第一源域预测结果。
步骤S32:从预设结果的各损失值区间中分别选出至少一个第一源域预测结果。
因此,通过从预设结果的各损失值区间中分别选出至少一个第一源域预测结果,使得后续能够利用处于不同损失值区间对应的第一源域预测结果来对评价模型进行训练,提高评价模型识别不同准确率的第一源域预测结果的识别能力。
请参阅图4,图4是本申请模型训练方法第四实施例的流程示意图。上述步骤“基于第二源域预测结果以及源域评价结果,调整评价模型的网络参数”具体包括步骤S41和步骤S42。
步骤S41:获取第二源域预测结果对应的第一损失值。
在本实施例中,第一损失值是基于第二源域预测结果与对应的源域数据的源域标注信息之间的差异确定的。具体的,第一损失值是基于第二源域预测结果包含的每一个第一源域预测结果与对应的源域数据的源域标注信息之间的差异确定的。基于每一个第一源域预测结果与对应的源域数据的源域标注信息之间的差异,可以确定每一个第一源域预测结果对应的损失值。每一个第一源域预测结果对应的损失值,可以直接获取历史训练时或本次训练时得到的损失值;也可以是重新基于每一个第一源域预测结果与标注信息之间的差异,得到对应的损失值。
后续,第二源域预测结果的第一损失值可以基于第二源域预测结果包含的每一个第一源域预测结果对应的损失值得到的。例如是基于于第二源域预测结果包含每一个第一源域预测结果对应的损失值进行加权求和得到第一损失值。
步骤S42:基于第一损失值和源域评价结果,调整评价模型的网络参数。
第一损失值能够表示第二源域预测结果的优劣,源域评价结果也能够表示第二源域预测结果的优劣,因此可以将第一损失值作为标签信息,通过比较第一损失值和源域评价结果的差异,来调整评价模型的网络参数,使得源域评价结果能够与第一损失值相互对应。
请参阅图5,图5是本申请模型训练方法第一实施例的第二流程示意图。在本实施例中,上述步骤“基于目标域评价结果,调整目标模型的网络参数”具体可以包括步骤S51和步骤S52。
步骤S51:基于目标域评价结果,得到第二损失值。
在一个实施方式中,可以直接将目标域评价结果作为第二损失值。在另一个实施方式中,也可以是对目标域评价结果进行预设运算,得到第二损失值,预设运算例如是归一化运算。在一个具体实施方式中,可以将第二损失值设置为与目标模型输出的预测结果的准确性为负相关关系。也即,目标域评价结果表明目标域预测结果越好,第二损失值越小。例如,目标域评价结果的评价分数越高,表明目标域预测结果越好,对应的第二损失值越小。因此,通过将第二损失值设置为与目标模型输出的预测结果的准确性为负相关关系,可以直观地通过第二损失值判断目标域预测结果的优劣。
步骤S52:基于第二损失值,调整目标模型的网络参数。
确定第二损失值后,表明已经能够对目标域预测结果进行正确的评价,此时即可以根据第二损失值来调整目标模型的网络参数。例如,可以通过调整目标模型的网络参数,使得第二损失值尽可能的小,以提高目标模型输出的目标域预测结果的准确度。基于第二损失值调整目标模型的网络参数过程,可以与一般的网络模型的训练过程相同,此处不再赘述。
在一个具体实施方式中,调整目标模型的网络参数可以本领域通用的微调(FineTune),即是调整目标模型的部分网络层的参数,例如是调整批标准化层(batchnormalization)的网络参数。因此,通过确定仅调整目标模型的部分网络层的参数,可以较少需要调整的参数量,有助于提高训练速度。
因此,通过基于目标域评价结果得到第二损失值,后续便可基于第二损失值,调整目标模型的网络参数,以此实现对目标模型的训练。
请参阅图6,图6是本申请模型训练方法实施例目标模型的一结构示意图。在本实施例中,目标模型60包括特征提取模块61、特征解码模块62和预测层63。特征提取模块61包括特征提取层611-615。特征解码模块62包括特征解码层621至624。特征提取层和特征解码层均可以包括若干层的卷积层。特征提取层之间,以及特征提取层与特征解码层之间可以设置有池化层,例如是最大池化层(图未示)。特征解码层之间可以设置有上采样层(图未示)。另外,特征提取层611与特征解码层624连接,特征提取层612与特征解码层623连接,特征提取层613与特征解码层622连接,特征提取层614与特征解码层621连接。特征提取层611与特征解码层624连接,表示特征提取层611的输出会与特征解码层623的输出进行融合,拼接后的特征信息会输入到特征解码层624。预测结果层63例如是1*1的卷积层。
在本实施例中,输入为包含目标器官的三维图像。特征提取模块61能够提取关于目标器官的特征信息,具体可以是特征提取模块61的每一个特征提取层来提取特征信息。特征解码模块62能够解码关于目标器官的特征信息,具体可以是特征解码模块62的每一个特征解码层来解码特征信息。预测结果层可以基于特征解码模块62输出的特征信息,输出预测结果。预测结果具体可以是对目标器官的像素点进行分类的信息,例如对目标器官的点是否属于血管,血管的类别进行分类,最终实现对目标器官的分割。
请参阅图7,图7是本申请模型训练方法实施例评价模型的一结构示意图。评价模型70包括特征提取模块71和输出层72。特征提取模块71包括特征提取层711-713。输出层72例如是1*1的卷积层。评价模型70的输入为目标模型输出的预测结果,在本实施例中,输入为目标模型输出的目标器官的分割结果。输出具体为评价分数。以此。评价模型能够对输入的预测结果进行评价。
请参阅图8,图8是本申请用于模型训练装置的一框架示意图。用于模型训练装置80包括获取模块81、预测模块82、确定模块83和调整模块84。获取模块81用于获取基于源域样本数据训练得到的目标模型和评价模型,其中,评价模型用于对目标模型输出的预测结果的准确性进行评价;预测模块82用于利用目标模型对目标域样本数据进行预测,得到目标域预测结果;确定模块83用于利用评价模型对目标域预测结果进行评价,得到目标域评价结果;调整模块84用于基于目标域评价结果,调整目标模型的网络参数。
其中,获取模块81用于获取基于源域样本数据训练得到的目标模型和评价模型,包括:基于源域样本数据对目标模型和评价模型进行至少一次迭代训练,其中,每次训练所基于的源域样本数据相同或不同。
其中,获取模块81用于每次基于源域样本数据对目标模型和评价模型进行训练,包括:利用目标模型对源域样本数据进行预测,得到本次训练对应的第一源域预测结果;基于本次训练对应的第一源域预测结果,调整目标模型的网络参数;以及利用评价模型对第二源域预测结果进行评价,得到源域评价结果,其中,第二源域预测结果包括本次训练和/或历史训练对应的第一源域预测结果;基于第二源域预测结果以及源域评价结果,调整评价模型的网络参数。
其中,在确定模块83用于利用评价模型对目标域预测结果进行评价,得到目标域评价结果之前,用于模型训练装置80的第二源域预测结果确定模块用于基于目标模型当前预测的准确性,将本次训练和/或历史训练对应的第一源域预测结果作为第二源域预测结果;其中,目标模型当前预测的准确性是基于前若干次训练对应的第一源域预测结果或前若干次训练中评价模型输出的评价结果确定的。
其中,第二源域预测结果确定模块用于基于目标模型当前预测的准确性,将本次训练和/或历史训练对应的第一源域预测结果作为第二源域预测结果,包括:响应于目标模型当前预测的准确性满足第一预设要求,选出至少一个历史训练对应的第一源域预测结果,并将本次训练对应的第一源域预测结果和选出的第一源域预测结果,作为第二源域预测结果;响应于目标模型当前预测的准确性不满足第一预设要求,将本次训练对应的第一源域预测结果作为第二源域预测结果。
其中,在获取模块81用于用目标模型对源域样本数据进行预测,得到本次训练对应的第一源域预测结果之后,用于模型训练装置80的选择模块用于响应于当前满足第二预设要求,将本次训练对应的第一源域预测结果保存至预设结果集中。第二源域预测结果确定模块用于选出至少一个历史训练对应的第一源域预测结果,包括:从预设结果集中选出至少一个第一源域预测结果。
其中,上述的第二预设要求包括以下至少一者:目标模型当前预测的准确性不满足第一预设要求,当前目标模型的训练次数少于预设数量;上述的预测结果集中的第一源域预测结果是按照第一源域预测结果对应的第一损失值划分在不同损失值区间中;上述的第二源域预测结果确定模块用于从预设结果集中选出至少一个第一源域预测结果,包括:从预设结果集中随机选出至少一个第一源域预测结果;或者,从预设结果的各损失值区间中分别选出至少一个第一源域预测结果。
其中,获取模块81用于基于第二源域预测结果以及源域评价结果,调整评价模型的网络参数,包括:获取第二源域预测结果对应的第一损失值,其中,第一损失值是基于第二源域预测结果与对应的源域数据的源域标注信息之间的差异确定的;基于第一损失值和源域评价结果,调整评价模型的网络参数。
其中,调整模块84用于调整目标模型的网络参数,包括:调整目标模型的部分网络层的参数。
其中,调整模块84用于述基于目标域评价结果,调整目标模型的网络参数,包括:基于目标域评价结果,得到第二损失值;基于第二损失值,调整目标模型的网络参数。
其中,调整模块84用于基于目标域评价结果,得到第二损失值,包括:对目标域评价结果进行预设运算,得到第二损失值;和/或,第二损失值与目标模型输出的预测结果的准确性为负相关关系。
其中,上述的源域样本数据和目标域样本数据均为包含目标器官的三维图像;目标模型为图像分割模型。
请参阅图9,图9是本申请电子设备一实施例的框架示意图。电子设备90包括相互耦接的存储器91和处理器92,处理器92用于执行存储器91中存储的程序指令,以实现上述任一模型训练方法实施例中的步骤。在一个具体的实施场景中,电子设备90可以包括但不限于:微型计算机、服务器,此外,电子设备90还可以包括笔记本电脑、平板电脑等移动设备,在此不做限定。
具体而言,处理器92用于控制其自身以及存储器91以实现上述任一模型训练方法实施例中的步骤。处理器92还可以称为CPU(Central Processing Unit,中央处理单元)。处理器92可能是一种集成电路芯片,具有信号的处理能力。处理器92还可以是通用处理器、数字信号处理器(Digital Signal Processor,DSP)、专用集成电路(Application SpecificIntegrated Circuit,ASIC)、现场可编程门阵列(Field-Programmable Gate Array,FPGA)或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件。通用处理器可以是微处理器或者该处理器也可以是任何常规的处理器等。另外,处理器92可以由集成电路芯片共同实现。
请参阅图10,图10为本申请计算机可读存储介质一实施例的框架示意图。计算机可读存储介质100存储有能够被处理器运行的程序指令101,程序指令101用于实现上述任一模型训练方法实施例中的步骤。
在一些实施例中,本公开实施例提供的装置具有的功能或包含的模块可以用于执行上文方法实施例描述的方法,其具体实现可以参照上文方法实施例的描述,为了简洁,这里不再赘述。
上述方案,通过获得基于源域样本数据训练得到的目标模型和评价模型,并利用评价模型对目标模型基于目标域样本数据预测得到的目标域预测结果进行评价,以此实现了目标模型在目标域的训练,有助于提高目标模型在目标域样本数据的预测准确性,以此实现了目标模型的域适应。
上文对各个实施例的描述倾向于强调各个实施例之间的不同之处,其相同或相似之处可以互相参考,为了简洁,本文不再赘述。
在本申请所提供的几个实施例中,应该理解到,所揭露的方法和装置,可以通过其它的方式实现。例如,以上所描述的装置实施方式仅仅是示意性的,例如,模块或单元的划分,仅仅为一种逻辑功能划分,实际实现时可以有另外的划分方式,例如单元或组件可以结合或者可以集成到另一个系统,或一些特征可以忽略,或不执行。另一点,所显示或讨论的相互之间的耦合或直接耦合或通信连接可以是通过一些接口,装置或单元的间接耦合或通信连接,可以是电性、机械或其它的形式。
另外,在本申请各个实施例中的各功能单元可以集成在一个处理单元中,也可以是各个单元单独物理存在,也可以两个或两个以上单元集成在一个单元中。上述集成的单元既可以采用硬件的形式实现,也可以采用软件功能单元的形式实现。
集成的单元如果以软件功能单元的形式实现并作为独立的产品销售或使用时,可以存储在一个计算机可读取存储介质中。基于这样的理解,本申请的技术方案本质上或者说对现有技术做出贡献的部分或者该技术方案的全部或部分可以以软件产品的形式体现出来,该计算机软件产品存储在一个存储介质中,包括若干指令用以使得一台计算机设备(可以是个人计算机,服务器,或者网络设备等)或处理器(processor)执行本申请各个实施方式方法的全部或部分步骤。而前述的存储介质包括:U盘、移动硬盘、只读存储器(ROM,Read-Only Memory)、随机存取存储器(RAM,Random Access Memory)、磁碟或者光盘等各种可以存储程序代码的介质。
若本申请技术方案涉及个人信息,应用本申请技术方案的产品在处理个人信息前,已明确告知个人信息处理规则,并取得个人自主同意。若本申请技术方案涉及敏感个人信息,应用本申请技术方案的产品在处理敏感个人信息前,已取得个人单独同意,并且同时满足“明示同意”的要求。例如,在摄像头等个人信息采集装置处,设置明确显著的标识告知已进入个人信息采集范围,将会对个人信息进行采集,若个人自愿进入采集范围即视为同意对其个人信息进行采集;或者在个人信息处理的装置上,利用明显的标识/信息告知个人信息处理规则的情况下,通过弹窗信息或请个人自行上传其个人信息等方式获得个人授权;其中,个人信息处理规则可包括个人信息处理者、个人信息处理目的、处理方式以及处理的个人信息种类等信息。

Claims (15)

1.一种模型训练方法,其特征在于,包括:
获取基于源域样本数据训练得到的目标模型和评价模型,其中,所述评价模型用于对所述目标模型输出的预测结果的准确性进行评价;
利用所述目标模型对目标域样本数据进行预测,得到目标域预测结果;
利用所述评价模型对所述目标域预测结果进行评价,得到目标域评价结果;
基于所述目标域评价结果,调整所述目标模型的网络参数。
2.根据权利要求1所述的方法,其特征在于,所述获取基于源域样本数据训练得到的目标模型和评价模型,包括:
基于所述源域样本数据对所述目标模型和评价模型进行至少一次迭代训练,其中,每次训练所基于的所述源域样本数据相同或不同。
3.根据权利要求1所述的方法,其特征在于,每次基于所述源域样本数据对所述目标模型和评价模型进行训练,包括:
利用所述目标模型对所述源域样本数据进行预测,得到本次训练对应的第一源域预测结果;
基于本次训练对应的所述第一源域预测结果,调整所述目标模型的网络参数;以及
利用所述评价模型对第二源域预测结果进行评价,得到源域评价结果,其中,所述第二源域预测结果包括本次训练和/或历史训练对应的所述第一源域预测结果;
基于所述第二源域预测结果以及所述源域评价结果,调整所述评价模型的网络参数。
4.根据权利要求3所述的方法,其特征在于,在所述利用所述评价模型对第二源域预测结果进行评价,得到源域评价结果之前,所述方法还包括:
基于所述目标模型当前预测的准确性,将本次训练和/或历史训练对应的所述第一源域预测结果作为所述第二源域预测结果;
其中,所述目标模型当前预测的准确性是基于前若干次训练对应的第一源域预测结果或前若干次训练中所述评价模型输出的评价结果确定的。
5.根据权利要求4所述的方法,其特征在于,所述基于所述目标模型当前预测的准确性,将本次训练和/或历史训练对应的所述第一源域预测结果作为所述第二源域预测结果,包括:
响应于所述目标模型当前预测的准确性满足第一预设要求,选出至少一个历史训练对应的第一源域预测结果,并将本次训练对应的第一源域预测结果和所述选出的第一源域预测结果,作为所述第二源域预测结果;
响应于所述目标模型当前预测的准确性不满足第一预设要求,将本次训练对应的所述第一源域预测结果作为所述第二源域预测结果。
6.根据权利要求5所述的方法,其特征在于,在所述利用所述目标模型对所述源域样本数据进行预测,得到本次训练对应的第一源域预测结果之后,所述方法还包括:
响应于当前满足第二预设要求,将本次训练对应的所述第一源域预测结果保存至预设结果集中;
所述选出至少一个历史训练对应的第一源域预测结果,包括:
从所述预设结果集中选出至少一个所述第一源域预测结果。
7.根据权利要求6所述的方法,其特征在于,所述第二预设要求包括以下至少一者:所述目标模型当前预测的准确性不满足第一预设要求,当前所述目标模型的训练次数少于预设数量;
和/或,所述预测结果集中的第一源域预测结果是按照所述第一源域预测结果对应的第一损失值划分在不同损失值区间中;所述从所述预设结果集中选出至少一个所述第一源域预测结果,包括:
从所述预设结果集中随机选出至少一个所述第一源域预测结果;或者,
从所述预设结果的各所述损失值区间中分别选出至少一个所述第一源域预测结果。
8.根据权利要求3至7任一项所述的方法,其特征在于,所述基于所述第二源域预测结果以及所述源域评价结果,调整所述评价模型的网络参数,包括:
获取所述第二源域预测结果对应的第一损失值,其中,所述第一损失值是基于所述第二源域预测结果与对应的所述源域数据的源域标注信息之间的差异确定的;
基于所述第一损失值和所述源域评价结果,调整所述评价模型的网络参数。
9.根据权利要求1至8任一项所述的方法,其特征在于,所述调整所述目标模型的网络参数,包括:调整所述目标模型的部分网络层的参数。
10.根据权利要求1至9任一项所述的方法,其特征在于,所述基于所述目标域评价结果,调整所述目标模型的网络参数,包括:
基于所述目标域评价结果,得到第二损失值;
基于所述第二损失值,调整所述目标模型的网络参数。
11.根据权利要求10所述的方法,其特征在于,所述基于所述目标域评价结果,得到第二损失值,包括:
对所述目标域评价结果进行预设运算,得到第二损失值;
和/或,所述第二损失值与所述目标模型输出的预测结果的准确性为负相关关系。
12.根据权利要求1至11任一项所述的方法,其特征在于,所述源域样本数据和所述目标域样本数据均为包含目标器官的三维图像;所述目标模型为图像分割模型。
13.一种用于模型训练装置,其特征在于,包括:
获取模块,用于获取基于源域样本数据训练得到的目标模型和评价模型,其中,所述评价模型用于对所述目标模型输出的预测结果的准确性进行评价;
预测模块,用于利用所述目标模型对目标域样本数据进行预测,得到目标域预测结果;
确定模块,用于利用所述评价模型对所述目标域预测结果进行评价,得到目标域评价结果;
调整模块,用于基于所述目标域评价结果,调整所述目标模型的网络参数。
14.一种电子设备,其特征在于,包括相互耦接的存储器和处理器,所述处理器用于执行所述存储器中存储的程序指令,以实现权利要求1至12任一项所述的模型训练方法。
15.一种计算机可读存储介质,其上存储有程序指令,其特征在于,所述程序指令被处理器执行时实现权利要求1至12任一项所述的模型训练方法。
CN202210101882.5A 2022-01-27 2022-01-27 模型训练方法及相关装置、设备和存储介质 Withdrawn CN114445679A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202210101882.5A CN114445679A (zh) 2022-01-27 2022-01-27 模型训练方法及相关装置、设备和存储介质

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202210101882.5A CN114445679A (zh) 2022-01-27 2022-01-27 模型训练方法及相关装置、设备和存储介质

Publications (1)

Publication Number Publication Date
CN114445679A true CN114445679A (zh) 2022-05-06

Family

ID=81370687

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202210101882.5A Withdrawn CN114445679A (zh) 2022-01-27 2022-01-27 模型训练方法及相关装置、设备和存储介质

Country Status (1)

Country Link
CN (1) CN114445679A (zh)

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
WO2024066927A1 (zh) * 2022-09-30 2024-04-04 腾讯科技(深圳)有限公司 图像分类模型的训练方法、装置及设备

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
WO2024066927A1 (zh) * 2022-09-30 2024-04-04 腾讯科技(深圳)有限公司 图像分类模型的训练方法、装置及设备

Similar Documents

Publication Publication Date Title
CN106845421B (zh) 基于多区域特征与度量学习的人脸特征识别方法及系统
CN109376615B (zh) 用于提升深度学习网络预测性能的方法、装置及存储介质
CN109919928B (zh) 医学影像的检测方法、装置和存储介质
CN109583332B (zh) 人脸识别方法、人脸识别系统、介质及电子设备
JP2015087903A (ja) 情報処理装置及び情報処理方法
CN111126396A (zh) 图像识别方法、装置、计算机设备以及存储介质
CN111881741B (zh) 车牌识别方法、装置、计算机设备和计算机可读存储介质
CN110717554A (zh) 图像识别方法、电子设备及存储介质
CN113688851B (zh) 数据标注方法和装置和精细粒度识别方法和装置
CN110796199A (zh) 一种图像处理方法、装置以及电子医疗设备
CN111914665A (zh) 一种人脸遮挡检测方法、装置、设备及存储介质
CN111401196A (zh) 受限空间内自适应人脸聚类的方法、计算机装置及计算机可读存储介质
CN111027347A (zh) 一种视频识别方法、装置和计算机设备
CN112149754B (zh) 一种信息的分类方法、装置、设备及存储介质
CN112330624A (zh) 医学图像处理方法和装置
CN111753702A (zh) 目标检测方法、装置及设备
CN111340213B (zh) 神经网络的训练方法、电子设备、存储介质
CN112383824A (zh) 视频广告过滤方法、设备及存储介质
CN112182269A (zh) 图像分类模型的训练、图像分类方法、装置、设备及介质
CN114445679A (zh) 模型训练方法及相关装置、设备和存储介质
CN113269307B (zh) 神经网络训练方法以及目标重识别方法
CN112949456B (zh) 视频特征提取模型训练、视频特征提取方法和装置
CN111414930A (zh) 深度学习模型训练方法及装置、电子设备及存储介质
CN111242176A (zh) 计算机视觉任务的处理方法、装置及电子系统
CN114445716B (zh) 关键点检测方法、装置、计算机设备、介质及程序产品

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
WW01 Invention patent application withdrawn after publication

Application publication date: 20220506

WW01 Invention patent application withdrawn after publication