CN116452922B - 模型训练方法、装置、计算机设备及可读存储介质 - Google Patents
模型训练方法、装置、计算机设备及可读存储介质 Download PDFInfo
- Publication number
- CN116452922B CN116452922B CN202310677518.8A CN202310677518A CN116452922B CN 116452922 B CN116452922 B CN 116452922B CN 202310677518 A CN202310677518 A CN 202310677518A CN 116452922 B CN116452922 B CN 116452922B
- Authority
- CN
- China
- Prior art keywords
- classifier
- feature vector
- labels
- training
- model
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
- 238000012549 training Methods 0.000 title claims abstract description 101
- 238000000034 method Methods 0.000 title claims abstract description 45
- 238000003860 storage Methods 0.000 title claims abstract description 10
- 239000013598 vector Substances 0.000 claims abstract description 136
- 238000004590 computer program Methods 0.000 claims description 9
- 238000000605 extraction Methods 0.000 abstract 1
- 230000003902 lesion Effects 0.000 description 14
- 238000004891 communication Methods 0.000 description 11
- 230000008569 process Effects 0.000 description 7
- 210000004072 lung Anatomy 0.000 description 6
- 230000008901 benefit Effects 0.000 description 5
- 238000005457 optimization Methods 0.000 description 5
- 206010035664 Pneumonia Diseases 0.000 description 4
- 239000000284 extract Substances 0.000 description 4
- 230000036541 health Effects 0.000 description 4
- 238000003062 neural network model Methods 0.000 description 4
- 238000010586 diagram Methods 0.000 description 3
- 208000035143 Bacterial infection Diseases 0.000 description 2
- 238000005481 NMR spectroscopy Methods 0.000 description 2
- 208000036142 Viral infection Diseases 0.000 description 2
- 230000002776 aggregation Effects 0.000 description 2
- 238000004220 aggregation Methods 0.000 description 2
- 208000022362 bacterial infectious disease Diseases 0.000 description 2
- 230000005540 biological transmission Effects 0.000 description 2
- 238000004364 calculation method Methods 0.000 description 2
- 239000003086 colorant Substances 0.000 description 2
- 230000009385 viral infection Effects 0.000 description 2
- 230000015556 catabolic process Effects 0.000 description 1
- 238000013136 deep learning model Methods 0.000 description 1
- 238000006731 degradation reaction Methods 0.000 description 1
- 238000009826 distribution Methods 0.000 description 1
- 238000005516 engineering process Methods 0.000 description 1
- 230000003993 interaction Effects 0.000 description 1
- 230000002452 interceptive effect Effects 0.000 description 1
- 238000010801 machine learning Methods 0.000 description 1
- 239000000203 mixture Substances 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 238000012545 processing Methods 0.000 description 1
- 238000012795 verification Methods 0.000 description 1
- 238000004383 yellowing Methods 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/77—Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
- G06V10/774—Generating sets of training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/40—Extraction of image or video features
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/764—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
Abstract
本申请公开了一种模型训练方法、装置、计算机设备及可读存储介质,方法包括:获取多个采集图像的多个图像数据、多个第一标签和多个第二标签;根据预设模型的特征提取器,确定多个采集图像的多个特征向量集;根据多个特征向量集和多个第一标签,对预设模型的第一分类器进行训练;根据多个图像数据和多个第二标签,对预设模型的第二分类器进行训练;根据多个特征向量集、多个第一标签、多个第二标签、第一分类器和第二分类器,对特征提取器进行训练。通过将特征提取器与第一分类器相结合进行同目标训练的同时,将特征提取器与第二分类器进行对抗训练,提高了与病灶有关的特征提取的精准度,进而提高了联邦学习的精确度。
Description
技术领域
本申请涉及数据处理技术领域,特别是涉及一种模型训练方法、装置、计算机设备及可读存储介质。
背景技术
联邦学习是一种新兴的机器学习范式,其核心思想是在客户端上利用本地数据训练本地模型,并将模型参数发送到服务器以聚合全局模型。然而,在整个学习过程中,不同医疗机构的图像数据往往存在很大差异,客户端的训练模型性能较差,进而导致联邦学习精度较低。
发明内容
有鉴于此,本申请提供了一种模型训练方法、装置、计算机设备及可读存储介质,主要目的在于解决不同医疗机构的图像数据往往存在很大差异,客户端的训练模型性能较差,进而导致联邦学习精度较低的问题。
依据本申请第一方面,提供了一种模型训练方法,该方法包括:
获取多个采集图像的多个图像数据、多个第一标签和多个第二标签;
根据预设模型的特征提取器,确定多个采集图像的多个特征向量集;
根据多个特征向量集和多个第一标签,对预设模型的第一分类器进行训练;
根据多个图像数据和多个第二标签,对预设模型的第二分类器进行训练;
根据多个特征向量集、多个第一标签、多个第二标签、第一分类器和第二分类器,对特征提取器进行训练。
可选地,根据预设模型的特征提取器,确定多个采集图像的多个特征向量集的步骤,具体包括:
依次将每个采集图像输入特征提取器,确定每个采集图像对应的特征向量集。
可选地,根据多个特征向量集和多个第一标签,对预设模型的第一分类器进行训练的步骤,具体包括:
将多个采集图像的多个特征向量集作为输入项,多个第一标签作为输出项,对第一分类器的参数进行调整。
可选地,根据多个图像数据和多个第二标签,对预设模型的第二分类器进行训练的步骤,具体包括:
将多个采集图像的多个图像数据作为输入项,多个第二标签作为输出项,对第二分类器的参数进行调整。
可选地,根据多个特征向量集、多个第一标签、多个第二标签、第一分类器和第二分类器,对特征提取器进行训练的步骤,具体包括:
将多个采集图像的多个特征向量集输入第一分类器,生成多个第一识别结果;
将多个第一识别结果和多个第一标签进行比较,判断多个第一识别结果与多个第一标签是否均相同;
若否,在多个第一识别结果中,获取与第一标签不相同的至少一个第一目标识别结果,以及每个第一目标识别结果对应的第一目标特征向量集;
根据至少一个第一目标识别结果和至少一个第一目标特征向量集,对特征提取器进行训练;
将多个采集图像的多个特征向量集输入第二分类器,生成多个第二识别结果;
将多个第二识别结果和多个第二标签进行比较,判断多个第二识别结果中,是否包含与第二标签相同的识别结果;
若是,在多个第二识别结果中,获取与第二标签相同的至少一个第二目标识别结果,以及每个第二目标识别结果对应的第二目标特征向量集;
根据至少一个第二目标识别结果和至少一个第二目标特征向量集,对特征提取器进行训练。
可选地,该方法应用于多个客户端,每个客户端与服务端通信连接,该方法还包括:
根据预设数量,在多个客户端中,确定多个目标客户端;
获取多个目标客户端的预设模型的多个第一模型参数、多个第二模型参数和多个第三模型参数;
将多个第一模型参数、多个第二模型参数和多个第三模型参数发送至服务端。
可选地,将多个第一模型参数、多个第二模型参数和多个第三模型参数发送至服务端之后,还包括:
接收服务端发送的第四模型参数、第五模型参数和第六模型参数;
根据第四模型参数,对特征提取器进行更新;
根据第五模型参数,对第一分类器进行更新;
根据第六模型参数,对第二分类器进行更新。
依据本申请第二方面,提供了一种模型训练装置,该装置包括:
获取模块,用于获取多个采集图像的多个图像数据、多个第一标签和多个第二标签;
确定模块,用于根据预设模型的特征提取器,确定多个采集图像的多个特征向量集;
训练模块,用于根据多个特征向量集和多个第一标签,对预设模型的第一分类器进行训练;
训练模块,还用于根据多个图像数据和多个第二标签,对预设模型的第二分类器进行训练;
训练模块,还用于根据多个特征向量集、多个第一标签、多个第二标签、第一分类器和第二分类器,对特征提取器进行训练。
可选地,确定模块,具体用于:
依次将每个采集图像输入特征提取器,确定每个采集图像对应的特征向量集。
可选地,训练模块,具体用于:
将多个采集图像的多个特征向量集作为输入项,多个第一标签作为输出项,对第一分类器的参数进行调整。
可选地,训练模块,具体还用于:
将多个采集图像的多个图像数据作为输入项,多个第二标签作为输出项,对第二分类器的参数进行调整。
可选地,该装置还包括:
生成模块,用于将多个采集图像的多个特征向量集输入第一分类器,生成多个第一识别结果;
判断模块,用于将多个第一识别结果和多个第一标签进行比较,判断多个第一识别结果与多个第一标签是否均相同。
可选地,获取模块,还用于若否,在多个第一识别结果中,获取与第一标签不相同的至少一个第一目标识别结果,以及每个第一目标识别结果对应的第一目标特征向量集。
可选地,训练模块,还用于根据至少一个第一目标识别结果和至少一个第一目标特征向量集,对特征提取器进行训练。
可选地,生成模块,还用于将多个采集图像的多个特征向量集输入第二分类器,生成多个第二识别结果。
可选地,判断模块,还用于将多个第二识别结果和多个第二标签进行比较,判断多个第二识别结果中,是否包含与第二标签相同的识别结果。
可选地,获取模块,还用于若是,在多个第二识别结果中,获取与第二标签相同的至少一个第二目标识别结果,以及每个第二目标识别结果对应的第二目标特征向量集。
可选地,训练模块,还用于根据至少一个第二目标识别结果和至少一个第二目标特征向量集,对特征提取器进行训练。
可选地,确定模块,还用于根据预设数量,在多个客户端中,确定多个目标客户端;
获取模块,还用于获取多个目标客户端的预设模型的多个第一模型参数、多个第二模型参数和多个第三模型参数。
可选地,该装置还包括:
发送模块,用于将多个第一模型参数、多个第二模型参数和多个第三模型参数发送至服务端。
可选地,该装置还包括:
接收模块,用于接收服务端发送的第四模型参数、第五模型参数和第六模型参数;
更新模块,用于根据第四模型参数,对特征提取器进行更新;
更新模块,还用于根据第五模型参数,对第一分类器进行更新;
更新模块,还用于根据第六模型参数,对第二分类器进行更新。
依据本申请第三方面,提供了一种计算机设备,包括存储器和处理器,所述存储器存储有计算机程序,所述处理器执行所述计算机程序时实现第一方面中任一项所述方法的步骤。
依据本申请第四方面,提供了一种可读存储介质,其上存储有计算机程序,所述计算机程序被处理器执行时实现第一方面中任一项所述的方法的步骤。
借由上述技术方案,本申请提供的一种模型训练方法、装置、计算机设备及可读存储介质,相较于现有技术的联邦学习过程中,各个客户端将样本图像内提取的全部特征作为模型训练数据进行模型训练的方式,使得特征向量集中参杂与病灶无关的图像特征,进而导致客户端的训练模型性能较差,使得联邦学习精度较低的技术问题,本申请提出了通过将特征提取器与第一分类器相结合进行同目标训练的同时,将特征提取器与第二分类器进行对抗训练,确保特征提取器提取与病灶有关的特征的精准度,进而使得各个客户端提供的病灶识别模型的模型参数准确性较高,有效提高了联邦学习的精确度。
上述说明仅是本申请技术方案的概述,为了能够更清楚了解本申请的技术手段,而可依照说明书的内容予以实施,并且为了让本申请的上述和其它目的、特征和优点能够更明显易懂,以下特举本申请的具体实施方式。
附图说明
通过阅读下文优选实施方式的详细描述,各种其他的优点和益处对于本领域普通技术人员将变得清楚明了。附图仅用于示出优选实施方式的目的,而并不认为是对本申请的限制。而且在整个附图中,用相同的参考符号表示相同的部件。在附图中:
图1示出了本申请实施例提供的一种模型训练方法流程示意图;
图2示出了本申请实施例提供的模型训练方法的示意框图;
图3示出了本申请实施例提供的一种模型训练装置的结构示意图。
具体实施方式
下面将参照附图更详细地描述本申请的示例性实施例。虽然附图中显示了本申请的示例性实施例,然而应当理解,可以以各种形式实现本申请而不应被这里阐述的实施例所限制。相反,提供这些实施例是为了能够更透彻地理解本申请,并且能够将本申请的范围完整的传达给本领域的技术人员。
本申请实施例提供了一种模型训练方法,如图1所示,该方法包括:
S101、获取多个采集图像的多个图像数据、多个第一标签和多个第二标签。
可以理解的是,本发明的执行主体可以为多个医疗机构的多个客户端,具体地,每个客户端与服务端通信连接。
在该步骤中,获取到多个采集图像,以及每个采集图像对应的多个图像数据、第一标签和第二标签。其中,采集图像可以为核磁共振图像等医疗图像。第一标签指的是采集图像中的病灶类别,例如,采集图像为肺部影像图像,那么第一标签为肺部健康类别,肺部健康类别可以为:肺部健康、病毒感染肺炎或细菌感染肺炎等。第二标签指的是每个采集图像对应的出处类别,例如采集图像为第一医院的样本图像,那么该采集图像的第二标签即为第一医院。
S102、根据预设模型的特征提取器,确定多个采集图像的多个特征向量集。
在该步骤中,预设模型是用于执行分类任务的神经网络模型,具体地,对于任一客户端来说,在获取作为样本的多个采集图像后,依次将每个采集图像输入至预设模型中,利用预设模型的特征提取器将采集图像的像素信息编码成更低维度的向量,而向量中蕴含着采集图像的特征向量集,进而提取出采集图像的深层特征向量集。
S103、根据多个特征向量集和多个第一标签,对预设模型的第一分类器进行训练。
在该步骤中,预设模型的第一分类器用于对采集图像中的不同病灶类型进行分类。在对第一分类器进行训练时,将每个采集图像的特征向量集作为输入,将该采集图像对应的第一标签作为输出,对第一分类器进行训练,以实现对第一分类器的参数调整。使得第一分类器能够更加准确地识别出病灶类型。
S104、根据多个图像数据和多个第二标签,对预设模型的第二分类器进行训练。
在该步骤中,预设模型的第二分类器用于根据特征向量集中的图像特征,分辨出该特征向量集的来源,以辨别出该特征向量集属于哪一个客户端。在实际应用中,每一个医疗机构因图像采集传感器的差异,对同一病灶拍摄的图片存在像素、色彩等差异,而这些是与病灶无关的图像特征,为了确保特征提取器只提取与病灶有关的特征,利用第二分类器对图像特征进行识别。在对第二分类器进行训练时,将每个采集图像的多个图像数据作为输入,将该采集图像的第二标签作为输出,对第二分类器进行训练,以实现对第二分类器的参数调整。使得第二分类器能够更加精准地识别出病灶相关特征。
S105、根据多个特征向量集、多个第一标签、多个第二标签第一分类器和第二分类器,对特征提取器进行训练。
在该步骤中,依次将每个采集图像的特征向量集输入训练后的第一分类器,根据第一分类器的输出,得到特征向量集中是否包含病灶类型的识别结果。其后,依次将每个识别结果与其对应的特征向量集的第一标签进行比较,判断识别结果与第一标签是否相同。若识别结果与第一标签不同,说明第一分类器无法基于特征向量集识别出病灶类型,此时需要将这一识别结果反馈给特征提取器,使得特征提取器基于识别结果以及其对应的特征向量集进行优化训练,以实现对特征分类器的参数调整。进一步地,依次将每个采集图像对应的特征向量集输入训练后的第二分类器,根据第二分类器的输出,得到特征向量集中是否包含图像特征的识别结果,例如识别结果为第一医院,或无法识别特征出处。其后,依次将每个识别结果与其对应的特征向量集的第二标签进行比对,若识别结果与第二标签相同,说明第二分类器能够基于特征向量集识别出采集图像来自哪一个客户端,即采集图像的特征向量集中包含了与病灶无关的图像特征。此时需要将这一识别结果反馈给特征提取器,使得特征提取器基于该特征向量集以及识别结果进行优化训练,以实现特征提取器的参数调整。
在实际应用中,将采集图像输入特征提取器中,得到该采集图像的特征向量集。将特征向量集输入第一分类器,若第一分类器输出的识别结果为“无病灶”,说明第一分类器无法根据特征向量集确定病灶类型,此时将识别结果反馈给特征提取器进行参数调整。同时,将特征向量集输入第二分类器,若第二分类器输出的识别结果为“第一医院”,且采集图像的第二标签也为“第一医院”,说明特征向量集中包含了与病灶无关的图像特征,此时将识别结果反馈给特征提取器进行参数调整。
本申请实施例提供的模型训练方法,相较于现有技术的联邦学习过程中,各个客户端将样本图像内提取的全部特征作为模型训练数据进行模型训练的方式,使得特征向量集中参杂与病灶无关的图像特征,进而导致客户端的训练模型性能较差,使得联邦学习精度较低的技术问题,本申请提出了通过将特征提取器与第一分类器相结合进行同目标训练的同时,将特征提取器与第二分类器进行对抗训练,确保特征提取器提取与病灶有关的特征的精准度,进而使得各个客户端提供的病灶识别模型的模型参数准确性较高,有效提高了联邦学习的精确度。
进一步的,作为上述实施例具体实施方式的细化和扩展,为了完整说明本实施例的具体实施过程,本申请实施例提供了另一种模型训练方法,该方法包括:
S201、获取多个采集图像的多个图像数据、多个第一标签和多个第二标签。
在该步骤中,获取到多个采集图像,以及每个采集图像对应的多个图像数据、第一标签和第二标签。其中,采集图像可以为核磁共振图像等医疗图像。第一标签指的是采集图像中的病灶类别,例如,采集图像为肺部影像图像,那么第一标签为肺部健康类别,肺部健康类别可以为:肺部健康、病毒感染肺炎或细菌感染肺炎等。第二标签指的是每个采集图像对应的出处类别,例如采集图像为第一医院的样本图像,那么该采集图像的第二标签即为第一医院。
S202、依次将每个采集图像输入特征提取器,确定每个采集图像对应的特征向量集。
在该步骤中,在收集每个医疗机构中,作为样本的大量采集图像后,依次将每个采集图像输入至特征提取器中,利用特征提取器将每个采集图像中的像素信息编码成更低维度的向量,而向量中蕴含着采集图像的特征向量,进而利用特征提取器输出每个采集图像的特征向量集。
在实际应用中,根据具体执行的分类任务,预设模型可采用适于模型的分类任务的神经网络模型,本申请在此不作具体限定。例如,可采用ResNet模型、EfficientNet模型等。需要说明的是,在联邦学习过程中,多个客户端所采用的预设模型可以是相同的神经网络模型,或者不同的神经网络模型。但不同的预设模型所采用的分类任务、标签以及输出的识别结果的集合是相同的。
S203、将多个采集图像的多个特征向量集作为输入项,多个第一标签作为输出项,对第一分类器的参数进行调整。
在该步骤中,在对预设模型的第一分类器进行模型训练时,将每个采集图像的特征向量集作为输入,将该采集图像的第一标签作为输出,对第一分类器进行训练,进而对第一分类器的参数进行调整。使得第一分类器能够精准地识别出采集图像中的病灶类型。
S204、将多个采集图像的多个图像数据作为输入项,多个第二标签作为输出项,对第二分类器的参数进行调整。
在该步骤中,每一个医疗机构因图像采集传感器的差异,对同一病灶拍摄的图片存在像素、色彩等差异,而这些是与病灶无关的图像特征,为了确保特征提取器只提取与病灶有关的特征,利用第二分类器对图像特征进行识别。在对第二分类器进行训练时,将每个采集图像的多个图像数据作为输入,将该采集图像的第二标签作为输出,对第二分类器进行训练,进而对第二分类器的参数进行调整。使得第二分类器能够更加精准地识别出病灶相关特征。
S205、将多个采集图像的多个特征向量集输入第一分类器,生成多个第一识别结果。
在该步骤中,联邦学习是为了通过不同医疗机构的数据联合训练一个病灶识别模型。每一轮训练时,由每个医疗机构基于自己的本地数据进行模型训练,将训练后的模型参数统一发往服务端,经由服务端进行聚合后,再向给各个医疗机构分发。也就是说,各个医疗机构客户端的本地模型要能够精准地识别出病灶。然而,特征提取器提取出的特征向量集中,不仅包含与病灶有关的特征,还包含了与病灶无关的图像特征,为了提高各个客户端上传的模型参数的精度,需要对特征提取器进行优化,使得特征提取器只提取与病灶有关的特征。具体地,将每个采集图像的特征向量集输入训练后的第一分类器,根据第一分类器的输出,得出特征向量集中是否包含病灶类型的第一识别结果。
S206、将多个第一识别结果和多个第一标签进行比较,判断多个第一识别结果与多个第一标签是否均相同,若是,进入步骤S209,若否,进入步骤S207。
在该步骤中,将每个采集图像的第一识别结果与该采集图像的第一标签进行比较,判断第一识别结果与第一标签是否相同。若所有第一识别结果与第一标签均相同,说明第一分类器能够基于特征向量集识别出采集图像中的病灶类型,即特征提取器提取出的特征向量集中的与病灶有关的特征向量是准确的。若任一第一识别结果与第一标签不同,说明第一分类器无法基于特征向量集识别出采集图像中的病灶类型,或识别出的病灶类型与正确的病灶类型不符,即特征提取器提取出的特征向量集中的与病灶有关的特征向量是错误的或不全面的,此时,需要对特征提取器进行优化,以对特征提取器的参数进行调整。
S207、在多个第一识别结果中,获取与第一标签不相同的至少一个第一目标识别结果,以及每个第一目标识别结果对应的第一目标特征向量集。
S208、根据至少一个第一目标识别结果和至少一个第一目标特征向量集,对特征提取器进行训练。
在步骤S207和S208中,在确定出第一分类器所识别的多个第一识别结果中,包含了与第一标签不同的识别结果,即多个第一识别结果中存在错误时,在多个第一识别结果中,调取出错误的至少一个第一目标识别结果,并基于每个第一目标识别结果,确定其对应的第一目标特征向量集。其后,将每个第一目标识别结果和其对应的第一目标特征向量集反馈给特征提取器,使得特征提取器根据反馈的至少一个第一目标识别结果和至少一个第一目标特征向量集进行模型优化。
可选地,当特征提取器优化基于反馈的至少一个第一目标识别结果和至少一个第一目标特征向量集进行优化后,利用优化后的特征提取器再次提取每个采集图像的特征向量集,将特征向量集输入第一分类器中,得到第一识别结果,并对第一识别结果的准确性进行判断,若第一识别结果正确无误,停止对特征提取器的优化;若第一识别结果还是错误,继续对特征提取器进行优化,直至第一识别结果无误为止。
S209、无需对特征提取器进行优化。
在该步骤中,确定出第一分类器识别出的所有第一识别结果均与多个第一标签相同,说明多个第一识别结果准确无误,即特征提取器提取出的与病灶相关的特征数据足够全面,且准确性高,则无需对特征提取器进行优化。
通过上述方式,基于第一分类器识别出的与病灶类型相关的识别结果,对特征提取器进行优化,以实现特征提取器与第一分类器的同目标训练,提高特征提取器提取出的与病灶相关特征的精准性。
S210、将多个采集图像的多个特征向量集输入第二分类器,生成多个第二识别结果。
在该步骤中,特征提取器提取出的特征向量中,不仅包含了与病灶有关的特征,还可能包含了与病灶无关的图像特征,例如,图像像素、图像偏黄等色差特征,为了提高联邦学习的病灶识别模型的准确性,需要确保各个客户端提供的模型参数的精准度,则需要对特征提取器中提取的特征向量进行验证,以判断特征向量集中是否包含与病灶无关的图像特征,进而根据验证结果对特征提取器进行优化。具体地,将每个采集图像的特征向量集输入第二分类器,根据第二分类器的输出,得到是否识别出特征向量集出自哪个客户端的第二识别结果。
S211、将多个第二识别结果和多个第二标签进行比较,判断多个第二识别结果中,是否包含与第二标签相同的识别结果,若是,进入步骤S212,若否,进入步骤S214。
在该步骤中,将每个第二识别结果与其对应的第二标签进行比较,判断第二识别结果与第二标签是否相同。若多个第二识别结果中,存在与第二标签相同的识别结果,说明第二分类器识别出特征向量集来自哪个客户端,则确定特征向量集中包含了该客户端的图像特征,则需要对特征提取器进行优化。进一步地,若多个第二识别结果与第二标签均不相同,说明第二分类器无法识别出特征向量集来自哪个客户端,此时确定特征向量集中并未包含该客户端的图像特征。
S212、若是,在多个第二识别结果中,获取与第二标签相同的至少一个第二目标识别结果,以及每个第二目标识别结果对应的第二目标特征向量集。
S213、根据至少一个第二目标识别结果和至少一个第二目标特征向量集,对特征提取器进行训练。
在步骤S212和S213中,在确定出第二分类器识别出的第二识别结果中,包含了与第二标签相同的识别结果,即特征向量集中包含了该客户端的图像特征时,在多个第二识别结果中,调取出与第二标签相同的至少一个第二目标识别结果,并基于每个第二目标识别结果结果,确定其对应的第二目标特征向量集。其后,将每个第二目标识别结果和其对应的第二目标特征向量集反馈给特征提取器,使得特征提取器根据反馈的至少一个第二目标识别结果和至少一个第二目标特征向量集进行优化。
可选地,当特征提取器优化基于反馈的至少一个第二目标识别结果和至少一个第二目标特征向量集进行优化后,利用优化后的特征提取器再次提取每个采集图像的特征向量集,将特征向量集输入第二分类器中,得到第二识别结果,并对第二识别结果与第二标签进行比较,若第二识别结果均与第二标签均不同,停止对特征提取器的优化;若任一第二识别结果与第二标签相同,继续对特征提取器进行优化,直至第二识别结果均与第二标签均不同为止。
S214、无需对特征提取器进行优化。
在该步骤中,确定出第二分类器识别出的第二识别结果均与第二标签不同,即特征提取器提取出的特征向量集中未包含与病灶无关的图像特征,则无需对特征提取器进行优化。
通过上述方式,基于第二分类器识别出的与特征出处相关的识别结果,对特征提取器进行优化,以实现特征提取器与第一分类器的对抗训练,使得特征提取器不会提取出的与病灶无关的特征,一方面,提高特征提取器提取与病灶相关特征的精准性;另一方面,使得客户端的图像特征不会流出该客户端,保障了数据的隐私。
S215、根据预设数量,在多个客户端中,确定多个目标客户端。
在该步骤中,联邦学习可以通过不同医疗机构的数据联合训练一个病灶识别模型,具体地,联邦学习中包含多个医疗机构的客户端,每个客户端与服务端通信连接。整个联邦学习分为多个通讯轮次,在任一通讯轮次过程中,每个客户端基于本地数据对本地的预设模型进行训练,其后,将训练后的模型的参数统一发往服务端,经由服务端进行聚合再向各个医疗机构的客户端进行分发,进行下一轮训练。然而,在整个联邦学习过程中,若每一轮训练,每个客户端均与服务端进行数据交互,客户端需要与服务端进行频繁的交互通信,其通信效率也制约着联邦学习病灶识别模型训练的效率。且通信成本较高。因此,为了提高联邦学习效率,在每轮训练过程中,在多个客户端中选择部分客户端进行交互,减少单次通信开销和整体通信次数。具体地,对应任一通讯轮次,在多个客户端中,按照预设数量随机选取多个目标客户端,作为本轮模型参数传输的客户端。
可选地,预设数量的范围为整体客户端数量的10%至50%,具体数量可根据本轮通信信号确定,本申请在此不做具体先动。进一步地,当本轮训练结束后,对本轮次选择的多个目标客户端进行标记,在下一轮训练时,在标记的多个目标客户端外的其余客户端中再次随机选择。
S216获取多个目标客户端的预设模型的多个第一模型参数、多个第二模型参数和多个第三模型参数。
S217将多个第一模型参数、多个第二模型参数和多个第三模型参数发送至服务端。
在步骤S216和S217中,在多个客户端中筛选出多个目标客户端后,获取每个目标客户端的预设模型的第一模型参数、第二模型参数和第三模型参数。其中,第一模型参数为训练好的特征提取器的参数,第二模型参数为训练好的第一分类器的参数,第三模型参数为训练好的第二分类器的参数。其后,将多个目标客户端的多个第一模型参数、多个第二模型参数和多个第三模型参数发送至服务端,以供服务端对接收到的模型参数进行聚合、分发。
可选地,服务端接收到多个第一模型参数、多个第二模型参数和多个第三模型参数后,利用加权平均的方式分别对多个第一模型参数、多个第二模型参数和多个第三模型参数进行计算,以得到新的第四模型参数、第五模型参数和第六模型参数。
通过上述方式,选取部分客户端的模型参数进行传输,降低了通信时间,提高了传输效率,提高了联邦学习病灶识别模型训练的效率。
S218、接收服务端发送的第四模型参数、第五模型参数和第六模型参数。
S219、根据第四模型参数,对特征提取器进行更新。
S220、根据第五模型参数,对第一分类器进行更新。
S221、根据第六模型参数,对第二分类器进行更新。
在步骤S218至S221中,服务端基于接收到的多个第一模型参数、多个第二模型参数和多个第三模型参数进行加权平均计算后,得到第四模型参数、第五模型参数和第六模型参数。其后,将第四模型参数、第五模型参数和第六模型参数分别发送给每个客户端。每个客户端在接收到服务端发送的第四模型参数、第五模型参数和第六模型参数后,基于第四模型参数,对特征提取器进行更新,基于第五模型参数,对第一分类器进行更新,基于第六模型参数,对第二分类器进行更新。
通过上述方式,在各个客户端利用本地数据进行深度学习模型的训练的同时,引入服务端负责模型的聚合与分发,实现了对病灶识别模型的训练的同时,保障了各个医疗机构的数据隐私。
作为一种实施方式,本申请实施例提供了一种模型训练方法,如图2所示,为模型训练方法的示意框图。其中,各个客户端各自将特征提取器与第一分类器进行同目标训练,同时,将特征提取器与第二分类器进行对抗训练,以得到训练后的第一模型参数、第二模型参数以及第三模型参数,其后,分别将选取的多个目标客户端的多个第一模型参数、多个第二模型参数以及多个第三模型参数发送至服务端,以供服务端对模型参数采用加权平均的方式进行参数聚合计算。其后,接收服务端发送的计算后的第四模型参数、第五模型参数和第六模型参数,分别对特征提取器、第一分类器以及第二分类器进行更新。通过上述方式,确保特征异构的联邦学习场景下的模型训练的精准度,有效避免特征异构情况下造成较大的模型表现退化。
进一步地,作为图1所述方法的具体实现,本申请实施例提供了一种模型训练装置400,如图3所示,该装置包括:
获取模块401,用于获取多个采集图像的多个图像数据、多个第一标签和多个第二标签;
确定模块402,用于根据预设模型的特征提取器,确定多个采集图像的多个特征向量集;
训练模块403,用于根据多个特征向量集和多个第一标签,对预设模型的第一分类器进行训练;
训练模块403,还用于根据多个图像数据和多个第二标签,对预设模型的第二分类器进行训练;
训练模块403,还用于根据多个特征向量集、多个第一标签、多个第二标签、第一分类器和第二分类器,对特征提取器进行训练。
可选地,确定模块402,具体用于:
依次将每个采集图像输入特征提取器,确定每个采集图像对应的特征向量集。
可选地,训练模块403,具体用于:
将多个采集图像的多个特征向量集作为输入项,多个第一标签作为输出项,对第一分类器的参数进行调整。
可选地,训练模块403,具体还用于:
将多个采集图像的多个图像数据作为输入项,多个第二标签作为输出项,对第二分类器的参数进行调整。
可选地,该装置还包括:
生成模块404,用于将多个采集图像的多个特征向量集输入第一分类器,生成多个第一识别结果;
判断模块405,用于将多个第一识别结果和多个第一标签进行比较,判断多个第一识别结果与多个第一标签是否均相同。
可选地,获取模块401,还用于若否,在多个第一识别结果中,获取与第一标签不相同的至少一个第一目标识别结果,以及每个第一目标识别结果对应的第一目标特征向量集。
可选地,训练模块403,还用于根据至少一个第一目标识别结果和至少一个第一目标特征向量集,对特征提取器进行训练。
可选地,生成模块404,还用于将多个采集图像的多个特征向量集输入第二分类器,生成多个第二识别结果。
可选地,判断模块405,还用于将多个第二识别结果和多个第二标签进行比较,判断多个第二识别结果中,是否包含与第二标签相同的识别结果。
可选地,获取模块401,还用于若是,在多个第二识别结果中,获取与第二标签相同的至少一个第二目标识别结果,以及每个第二目标识别结果对应的第二目标特征向量集。
可选地,训练模块403,还用于根据至少一个第二目标识别结果和至少一个第二目标特征向量集,对特征提取器进行训练。
可选地,确定模块402,还用于根据预设数量,在多个客户端中,确定多个目标客户端;
获取模块401,还用于获取多个目标客户端的预设模型的多个第一模型参数、多个第二模型参数和多个第三模型参数。
可选地,该装置还包括:
发送模块406,用于将多个第一模型参数、多个第二模型参数和多个第三模型参数发送至服务端。
可选地,该装置还包括:
接收模块407,用于接收服务端发送的第四模型参数、第五模型参数和第六模型参数;
更新模块408,用于根据第四模型参数,对特征提取器进行更新;
更新模块408,还用于根据第五模型参数,对第一分类器进行更新;
更新模块408,还用于根据第六模型参数,对第二分类器进行更新。
本申请实施例提供的模型训练装置400,相较于现有技术的联邦学习过程中,各个客户端将样本图像内提取的全部特征作为模型训练数据进行模型训练的方式,使得特征向量集中参杂与病灶无关的图像特征,进而导致客户端的训练模型性能较差,使得联邦学习精度较低的技术问题,本申请提出了通过将特征提取器与第一分类器相结合进行同目标训练的同时,将特征提取器与第二分类器进行对抗训练,提高特征提取器提取与病灶有关的特征的精准度,进而使得各个客户端提供的病灶识别模型的模型参数准确性较高,有效提高了联邦学习的精确度。
在示例性实施例中,本申请还提供了一种计算机设备,包括存储器和处理器。该存储器存储有计算机程序,处理器,用于执行存储器上所存放的程序,执行上述实施例中的模型训练方法。
在示例性实施例中,本申请还提供了一种可读存储介质,其上存储有计算机程序,所述计算机程序被处理器执行时实现所述的模型训练方法的步骤。
通过以上的实施方式的描述,本领域的技术人员可以清楚地了解到本申请可以通过硬件实现,也可以借助软件加必要的通用硬件平台的方式来实现。基于这样的理解,本申请的技术方案可以以软件产品的形式体现出来,该软件产品可以存储在一个非易失性存储介质(可以是CD-ROM,U盘,移动硬盘等)中,包括若干指令用以使得一台计算机设备(可以是个人计算机,服务器,或者网络设备等)执行本申请各个实施场景所述的方法。
本领域技术人员可以理解附图只是一个优选实施场景的示意图,附图中的模块或流程并不一定是实施本申请所必须的。
本领域技术人员可以理解实施场景中的装置中的模块可以按照实施场景描述进行分布于实施场景的装置中,也可以进行相应变化位于不同于本实施场景的一个或多个装置中。上述实施场景的模块可以合并为一个模块,也可以进一步拆分成多个子模块。
上述本申请序号仅仅为了描述,不代表实施场景的优劣。
以上公开的仅为本申请的几个具体实施场景,但是,本申请并非局限于此,任何本领域的技术人员能思之的变化都应落入本申请的保护范围。
Claims (9)
1.一种模型训练方法,其特征在于,包括:
获取多个采集图像的多个图像数据、多个第一标签和多个第二标签;
根据预设模型的特征提取器,确定所述多个采集图像的多个特征向量集;
根据所述多个特征向量集和所述多个第一标签,对预设模型的第一分类器进行训练;
根据所述多个图像数据和所述多个第二标签,对预设模型的第二分类器进行训练;
根据所述多个特征向量集、所述多个第一标签、所述多个第二标签、所述第一分类器和所述第二分类器,对所述特征提取器进行训练;
所述根据所述多个特征向量集、所述多个第一标签、所述多个第二标签、所述第一分类器和所述第二分类器,对所述特征提取器进行训练的步骤,具体包括:
将所述多个采集图像的所述多个特征向量集输入所述第一分类器,生成多个第一识别结果;
将所述多个第一识别结果和所述多个第一标签进行比较,判断所述多个第一识别结果与所述多个第一标签是否均相同;
若否,在所述多个第一识别结果中,获取与第一标签不相同的至少一个第一目标识别结果,以及每个第一目标识别结果对应的第一目标特征向量集;
根据所述至少一个第一目标识别结果和至少一个第一目标特征向量集,对所述特征提取器进行训练;
将所述多个采集图像的所述多个特征向量集输入所述第二分类器,生成多个第二识别结果;
将所述多个第二识别结果和所述多个第二标签进行比较,判断所述多个第二识别结果中,是否包含与第二标签相同的识别结果;
若是,在所述多个第二识别结果中,获取与第二标签相同的至少一个第二目标识别结果,以及每个第二目标识别结果对应的第二目标特征向量集;
根据所述至少一个第二目标识别结果和至少一个第二目标特征向量集,对所述特征提取器进行训练。
2.根据权利要求1所述的方法,其特征在于,所述根据预设模型的特征提取器,确定所述多个采集图像的多个特征向量集的步骤,具体包括:
依次将每个采集图像输入所述特征提取器,确定所述每个采集图像对应的特征向量集。
3.根据权利要求1所述的方法,其特征在于,所述根据所述多个特征向量集和所述多个第一标签,对预设模型的第一分类器进行训练的步骤,具体包括:
将所述多个采集图像的所述多个特征向量集作为输入项,所述多个第一标签作为输出项,对所述第一分类器的参数进行调整。
4.根据权利要求3所述的方法,其特征在于,所述根据所述多个图像数据和所述多个第二标签,对预设模型的第二分类器进行训练的步骤,具体包括:
将所述多个采集图像的所述多个图像数据作为输入项,所述多个第二标签作为输出项,对所述第二分类器的参数进行调整。
5.根据权利要求1至4中任一项所述的方法,其特征在于,应用于多个客户端,每个客户端与服务端通信连接,所述方法还包括:
根据预设数量,在所述多个客户端中,确定多个目标客户端;
获取多个目标客户端的预设模型的多个第一模型参数、多个第二模型参数和多个第三模型参数;
将所述多个第一模型参数、所述多个第二模型参数和所述多个第三模型参数发送至所述服务端。
6.根据权利要求5所述的方法,其特征在于,所述将所述多个第一模型参数、所述多个第二模型参数和所述多个第三模型参数发送至所述服务端之后,还包括:
接收所述服务端发送的第四模型参数、第五模型参数和第六模型参数;
根据所述第四模型参数,对特征提取器进行更新;
根据所述第五模型参数,对第一分类器进行更新;
根据所述第六模型参数,对第二分类器进行更新。
7.一种模型训练装置,其特征在于,包括:
获取模块,用于获取多个采集图像的多个图像数据、多个第一标签和多个第二标签;
确定模块,用于根据预设模型的特征提取器,确定所述多个采集图像的多个特征向量集;
训练模块,用于根据所述多个特征向量集和所述多个第一标签,对预设模型的第一分类器进行训练;
所述训练模块,还用于根据所述多个图像数据和所述多个第二标签,对预设模型的第二分类器进行训练;
所述训练模块,还用于根据所述多个特征向量集、所述多个第一标签、所述多个第二标签、所述第一分类器和所述第二分类器,对所述特征提取器进行训练;
生成模块,用于将多个采集图像的多个特征向量集输入第一分类器,生成多个第一识别结果;
判断模块,用于将多个第一识别结果和多个第一标签进行比较,判断多个第一识别结果与多个第一标签是否均相同;
所述获取模块,还用于若否,在多个第一识别结果中,获取与第一标签不相同的至少一个第一目标识别结果,以及每个第一目标识别结果对应的第一目标特征向量集;
所述训练模块,还用于根据至少一个第一目标识别结果和至少一个第一目标特征向量集,对特征提取器进行训练;
所述生成模块,还用于将多个采集图像的多个特征向量集输入第二分类器,生成多个第二识别结果;
所述判断模块,还用于将多个第二识别结果和多个第二标签进行比较,判断多个第二识别结果中,是否包含与第二标签相同的识别结果;
所述获取模块,还用于若是,在多个第二识别结果中,获取与第二标签相同的至少一个第二目标识别结果,以及每个第二目标识别结果对应的第二目标特征向量集;
所述训练模块,还用于根据至少一个第二目标识别结果和至少一个第二目标特征向量集,对特征提取器进行训练。
8.一种计算机设备,包括存储器和处理器,存储器存储有计算机程序,其特征在于,处理器执行计算机程序时实现权利要求1至6中任一项方法的步骤。
9.一种可读存储介质,其上存储有计算机程序,其特征在于,计算机程序被处理器执行时实现权利要求1至6中任一项方法的步骤。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310677518.8A CN116452922B (zh) | 2023-06-09 | 2023-06-09 | 模型训练方法、装置、计算机设备及可读存储介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310677518.8A CN116452922B (zh) | 2023-06-09 | 2023-06-09 | 模型训练方法、装置、计算机设备及可读存储介质 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN116452922A CN116452922A (zh) | 2023-07-18 |
CN116452922B true CN116452922B (zh) | 2023-09-22 |
Family
ID=87133990
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202310677518.8A Active CN116452922B (zh) | 2023-06-09 | 2023-06-09 | 模型训练方法、装置、计算机设备及可读存储介质 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN116452922B (zh) |
Citations (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
WO2021139309A1 (zh) * | 2020-07-31 | 2021-07-15 | 平安科技(深圳)有限公司 | 人脸识别模型的训练方法、装置、设备及存储介质 |
CN113505797A (zh) * | 2021-09-09 | 2021-10-15 | 深圳思谋信息科技有限公司 | 模型训练方法、装置、计算机设备和存储介质 |
CN113609521A (zh) * | 2021-07-27 | 2021-11-05 | 广州大学 | 一种基于对抗训练的联邦学习隐私保护方法及系统 |
CN114565807A (zh) * | 2022-03-03 | 2022-05-31 | 腾讯科技(深圳)有限公司 | 训练目标图像检索模型的方法和装置 |
CN115731424A (zh) * | 2022-12-03 | 2023-03-03 | 北京邮电大学 | 基于强化联邦域泛化的图像分类模型训练方法及系统 |
Family Cites Families (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN108875821A (zh) * | 2018-06-08 | 2018-11-23 | Oppo广东移动通信有限公司 | 分类模型的训练方法和装置、移动终端、可读存储介质 |
-
2023
- 2023-06-09 CN CN202310677518.8A patent/CN116452922B/zh active Active
Patent Citations (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
WO2021139309A1 (zh) * | 2020-07-31 | 2021-07-15 | 平安科技(深圳)有限公司 | 人脸识别模型的训练方法、装置、设备及存储介质 |
CN113609521A (zh) * | 2021-07-27 | 2021-11-05 | 广州大学 | 一种基于对抗训练的联邦学习隐私保护方法及系统 |
CN113505797A (zh) * | 2021-09-09 | 2021-10-15 | 深圳思谋信息科技有限公司 | 模型训练方法、装置、计算机设备和存储介质 |
CN114565807A (zh) * | 2022-03-03 | 2022-05-31 | 腾讯科技(深圳)有限公司 | 训练目标图像检索模型的方法和装置 |
CN115731424A (zh) * | 2022-12-03 | 2023-03-03 | 北京邮电大学 | 基于强化联邦域泛化的图像分类模型训练方法及系统 |
Also Published As
Publication number | Publication date |
---|---|
CN116452922A (zh) | 2023-07-18 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN108229321B (zh) | 人脸识别模型及其训练方法和装置、设备、程序和介质 | |
CN108256591B (zh) | 用于输出信息的方法和装置 | |
JP6994588B2 (ja) | 顔特徴抽出モデル訓練方法、顔特徴抽出方法、装置、機器および記憶媒体 | |
CN110969202B (zh) | 基于颜色分量和感知哈希算法的人像采集环境验证方法及系统 | |
CN111401193B (zh) | 获取表情识别模型的方法及装置、表情识别方法及装置 | |
CN110717555B (zh) | 一种基于自然语言和生成对抗网络的图片生成系统及装置 | |
CN113568900A (zh) | 基于人工智能的大数据清洗方法及云服务器 | |
CN112949456B (zh) | 视频特征提取模型训练、视频特征提取方法和装置 | |
CN112200862B (zh) | 目标检测模型的训练方法、目标检测方法及装置 | |
CN116452922B (zh) | 模型训练方法、装置、计算机设备及可读存储介质 | |
CN112115994A (zh) | 图像识别模型的训练方法、装置、服务器及存储介质 | |
CN114693554B (zh) | 一种大数据图像处理方法及系统 | |
CN108229320B (zh) | 选帧方法和装置、电子设备、程序和介质 | |
EP4068163A1 (en) | Using multiple trained models to reduce data labeling efforts | |
CN115909335A (zh) | 一种商品标注方法及装置 | |
CN115146191A (zh) | 基于ai进行视频监控资产识别的方法、装置及电子设备 | |
CN111160330B (zh) | 电子标签识别辅助提升图像识别准确度的训练方法 | |
CN114241253A (zh) | 违规内容识别的模型训练方法、系统、服务器及存储介质 | |
CN114519416A (zh) | 模型蒸馏方法、装置及电子设备 | |
CN110414845B (zh) | 针对目标交易的风险评估方法及装置 | |
CN112232380A (zh) | 一种神经网络鲁棒性检测方法和装置 | |
CN113076983A (zh) | 一种图像的识别方法和装置 | |
CN113762382B (zh) | 模型的训练及场景识别方法、装置、设备及介质 | |
CN115082574B (zh) | 网络模型训练方法和脏器超声切面编码生成方法、装置 | |
CN116245962B (zh) | 用于无线传输到区块链服务器的数据提取系统及方法 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |