CN115761408A - 一种基于知识蒸馏的联邦域适应方法及系统 - Google Patents

一种基于知识蒸馏的联邦域适应方法及系统 Download PDF

Info

Publication number
CN115761408A
CN115761408A CN202211475594.2A CN202211475594A CN115761408A CN 115761408 A CN115761408 A CN 115761408A CN 202211475594 A CN202211475594 A CN 202211475594A CN 115761408 A CN115761408 A CN 115761408A
Authority
CN
China
Prior art keywords
model
loss
data
teacher
models
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202211475594.2A
Other languages
English (en)
Inventor
肖云鹏
朱海鹏
李暾
李茜
庞育才
王蓉
唐飞
王国胤
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Chongqing University of Post and Telecommunications
Original Assignee
Chongqing University of Post and Telecommunications
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Chongqing University of Post and Telecommunications filed Critical Chongqing University of Post and Telecommunications
Priority to CN202211475594.2A priority Critical patent/CN115761408A/zh
Publication of CN115761408A publication Critical patent/CN115761408A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • YGENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
    • Y02TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
    • Y02DCLIMATE CHANGE MITIGATION TECHNOLOGIES IN INFORMATION AND COMMUNICATION TECHNOLOGIES [ICT], I.E. INFORMATION AND COMMUNICATION TECHNOLOGIES AIMING AT THE REDUCTION OF THEIR OWN ENERGY USE
    • Y02D10/00Energy efficient computing, e.g. low power processors, power management or thermal management

Abstract

本发明属于数据安全技术领域,具体涉及一种基于知识蒸馏的联邦域适应方法及系统;该方法包括:多个医疗机构作为客户端采集数据库中的医疗图像,服务器采集本地医疗数据库中的医疗图像;将客户端中的数据作为源域数据,服务器中的数据作为目标域数据;构建基于医疗图像分类模型的联邦学习模型;根据源域数据和目标域数据对联邦学习模型进行医疗图像分类模型训练和对比学习,得到训练好的全局模型;服务器采集目标医疗机构的医疗图像并将其输入到全局模型中,得到医疗图像分类结果;本发明预测精度高,对用户数据隐私保护性好,具有良好的应用前景。

Description

一种基于知识蒸馏的联邦域适应方法及系统
技术领域
本发明属于数据安全技术领域,具体涉及一种基于知识蒸馏的联邦域适应方法及系统。
背景技术
近年来,我们见证了机器学习在人工智能应用领域的迅猛发展,人工智能热潮在短短几年内席卷至各大领域成为研究热门,大数据医疗、大数据金融、智慧城市等服务孕育而生。而这些技术的成功,尤其是深度学习,无一不是基于海量数据基础之上的。然而在实际情况中,人们发现在很多应用领域无法满足如此大规模的数据量。且伴随着社会不断发展,现代社会已经逐渐意识到了数据以及隐私安全的重要性,例如在,大数据金融以及大数据医疗领域,对用户的隐私保护需求较高,这使得这些企业、机构之间数据难以流通,想要得到一份高质量、大数量的训练数据,不得不面对难以桥接的“数据孤岛”现象。为了解决以上现象,联邦学习理论被提出。在传统分布式的基础上构思,提出了数据不动模型动的思想,在数据不出域的场景下进行安全的学习,解决了“数据孤岛”问题,充分发掘了分布在各处的数据的潜在价值。
与此同时,为了避免每次对深度学习中所需的海量数据进行标签,一部分的工作选择使用相似的数据集去训练目标模型。无监督域适应的技术则致力于提升在目标场景的模型性能,且有时需要多个源域数据集来提升目标模型的性能,无监督多源域适应通过建立从多个源域到无标记目标域的可转移特征来解决此类问题。
近年来,越来越多的研究者对联邦学习的场景下进行多源域适应进行研究。研究的方面主要有利用模型对抗训练、计算域最大平均差异以及知识蒸馏方法。基于对抗训练的思想就是在特征空间中应用对抗性训练优化源域与目标域之间的H-散度。基于最大平均差异方法则是通过构建一个可复制核特征空间,然后通过缩小最大平均差异距离来优化H-散度。基于知识蒸馏方法通过教师-学生策略将知识提炼扩展到域适应场景中,在源域中训练多个教师模型,然后在目标域上集成它们训练一个学生模型。
针对数据隐私环境下多源域适应无法直接获取源域数据问题,考虑到知识蒸馏允许只访问模型即可获取知识的特点,采用知识蒸馏的方式从多个源域获取知识。现有一种基于知识投票的多源模型知识蒸馏方法,用以获取高质量的域共识知识。然后定义每个源域所贡献共识知识的质量,并得到一个可以识别无关域与恶意域的指标。最后利用深度学习模型中的正则化归一层所记录的特征滑动均值与方差,提出了BatchNorm MMD距离。但该方法在图像数据集中表现效能表现欠佳,且训练过程中同一时间源域与目标域一方必须闲置,训练效率较低;本发明提出了一种基于知识蒸馏的联邦域适应方法,引入了针对域数据质量参差问题的多教师师置信度知识蒸馏方法以及对比学习的思想,不仅可以一定程度上提高在目标域上的准确度,还可以识别一些不相关的源域和恶意源域,提高医疗图像分类模型的分类准确性。
发明内容
针对现有技术存在的不足,本发明提出了一种基于基于知识蒸馏的联邦域适应方法及系统,该方法包括:
S1:多个医疗机构作为客户端采集数据库中的医疗图像,服务器采集本地医疗数据库中的医疗图像;将客户端中的数据作为源域数据,服务器中的数据作为目标域数据;
S2:构建基于医疗图像分类模型的联邦学习模型;
S3:根据源域数据和目标域数据对联邦学习模型进行医疗图像分类模型训练和对比学习,得到训练好的全局模型;
S4:服务器采集目标医疗机构的医疗图像并将其输入到全局模型中,得到医疗图像分类结果。
优选的,对联邦学习模型进行医疗图像分类模型训练的过程包括:
S31:根据源域数据训练医疗图像分类模型,得到初始源域模型;
S32:根据初始源域模型对目标域数据进行知识投票,得到高质量的知识共识;
S33:根据高质量的知识共识扩展源域,得到扩展源域数据;根据扩展源域数据训练医疗图像分类模型,得到扩展源域模型;
S34:根据目标域数据对所有源域模型进行置信度多教师知识蒸馏,训练得到适应于目标域的学生模型。
进一步的,对初始源域模型进行知识投票的过程包括:
将目标域数据输入到初始源域模型中,得到输出结果;计算每个模型输出结果属于不同类别的置信度;根据置信度采用高阶置信度门进行过滤处理,去除不自信的模型;
将剩余模型按照输出结果所属类别进行计数,将计数最多的类别作为共识类,去除与共识类不一致的模型,得到共识模型;
计算所有共识模型输出结果为相同类别的置信度均值,将置信度均值作为共识模型的共识知识,将共识模型数量作为每个共识模型的共识权重;
若高阶置信度门过滤了所有模型,则将所有模型输出结果的置信度均值作为共识知识,并为其分配一个低共识权重。
进一步的,对源域模型进行置信度多教师知识蒸馏的过程包括:
将初始源域模型和扩展源域模型作为教师模型,采用目标域数据对教师模型进行置信度多教师知识蒸馏,得到预测结果;根据每个教师模型的预测结果计算第二交叉熵损失;根据第二交叉熵损失计算所有教师模型的第一权重;根据所有教师模型的第一权重和第二交叉熵损失计算标签损失;
根据教师模型分类层中的学生特征向量计算第三交叉熵损失;根据第三交叉熵损失计算所有教师模型的第二权重;根据所有教师模型的第二权重和第三交叉熵损失计算传递损失;
若目标域数据不存在标签,则根据标签损失和传递损失计算总体损失;若目标域数据存在部分标签,则根据教师模型的预测结果计算常规交叉熵损失,根据常规交叉熵损失、标签损失和传递损失计算总体损失;
根据总体损失指导学生模型训练,得到训练好的学生模型。
进一步的,计算标签损失的公式为:
Figure BDA0003959697600000041
其中,LKD表示标签损失,
Figure BDA0003959697600000042
表示第一权重,M表示教师模型的数量,
Figure BDA0003959697600000043
表示第k个教师模型输出结果为类别c的置信度,
Figure BDA0003959697600000044
表示第二交叉熵损失,
Figure BDA0003959697600000045
表示学生模型输出结果类别c的置信度。
进一步的,计算传递损失的公式为:
Figure BDA0003959697600000046
其中,Linter表示传递损失,
Figure BDA0003959697600000047
表示第二权重,M表示教师模型的数量,
Figure BDA0003959697600000048
表示第k个教师模型提取的特征,r(FS)表示学生模型提取的特征。
优选的,对联邦学习模型进行对比学习的过程包括:
获取输入数据在本地模型输出层前的网络中的映射表征向量、上一轮次本地训练好后发送给服务器的模型输出的表征向量、当前轮次服务器发送给本地的全局模型输出的表征向量;
根据三种表征向量进行对比学习,计算本地模型的对比学习损失和监督学习交叉熵损失;根据对比训练损失和监督学习交叉熵损失计算总体学习损失;根据总体学习损失调整本地模型的参数,得到训练好的局部模型;
服务器聚合训练好的局部模型,得到训练好的全局模型。
一种基于知识蒸馏的联邦域适应系统,包括:数据采集模块、模型训练模块和分类模块;
所述数据采集模块用于为各终端采集训练医疗图像数据或待分类医疗图像数据;
所述模型训练模块用于根据医疗图像数据训练基于医疗图像分类模型的联邦学习模型,得到全局模型;
所述分类模块用于使用全局模型对待分类医疗图像进行分类,输出分类结果。
本发明的有益效果为:本发明在联邦环境下对模型进行域适应迁移训练,保障了数据安全;采用了知识蒸馏的方式进行训练,相比其他联邦迁移学习方法一定程度上减少了通信量;针对图像分类问题进行了模型对比学习处理,以及引入知识投票、置信度多教师知识蒸馏方法,提升了最终目标域医疗图像分类模型的性能,从而提高了分类结果的准确性。
附图说明
图1为本发明中基于知识蒸馏的域适应方法结构示意图;
图2为本发明中联邦学习模型学习过程示意图;
图3为本发明中知识投票过程示意图;
图4为本发明中置信度多教师蒸馏模型图。
具体实施方式
下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
本发明提出了一种基于知识蒸馏的联邦域适应方法及系统,如图1所示,所述方法包括以下内容:
S1:多个医疗机构作为客户端采集数据库中的医疗图像,服务器采集本地医疗数据库中的医疗图像;将客户端中的数据作为源域数据,服务器中的数据作为目标域数据。优选的,该医疗图像可为胸腔X射线图像。
获取数据的方式可以是直接查询多个医疗机构提供的医疗图像,医疗图像带有标签,该标签为医疗图像可划分的具体类别;对于胸腔X射线图像而言,标签包括“无肺炎”、“病毒性肺炎”、“病菌性肺炎”3种。
每个域都包含了一定的域特征信息,其最主要的特征就是每个域的数据分布。且在无监督多源域适应场景中,存在K个源域且源域数据有标签即K个医疗机构数据库中的带标签的医疗图像,其中源域表示为
Figure BDA0003959697600000061
因此可符号化源域信息为:
Figure BDA0003959697600000062
其中,Nk表示源域的数据数量,
Figure BDA0003959697600000063
表示源域中第i个用户的医疗图片,
Figure BDA0003959697600000064
表示第i个用户的肺炎标签,。
目标域可带标签和不带标签,若不带标签,其目标域可表示为DT,符号化目标域信息为:
Figure BDA0003959697600000065
其中,
Figure BDA0003959697600000066
表示目标域中第i个用户的医疗图像,NT表示目标域的数据数量。
S2:构建基于医疗图像分类模型的联邦学习模型。
如图2所示,迭代上述过程,可得到需要的全局模型;
S3:根据源域数据和目标域数据对联邦学习模型进行医疗图像分类模型训练和对比学习,得到训练好的全局模型。
联邦学习模型在进行迭代训练过程中,对联邦学习模型进行医疗图像分类模型训练的过程包括:
服务器将基于医疗图像分类模型作为全局模型发送给各个客户端,在客户端中,将医疗图像分类模型作为局部模型并采用医疗机构数据库中的医疗图像进行本地训练。
S31:根据源域数据训练医疗图像分类模型,得到初始源域模型;
优选的,医疗图像分类模型可使用ResNet模型、CNN模型等模型。
S32:根据初始源域模型对目标域数据进行知识投票,得到高质量的知识共识;
客户端将训练好的初始源域模型发送给服务器,服务器采用本地医疗数据库中的医疗图像对对初始源域模型进行知识投票,其过程如下:
如图3所示,将目标域数据输入到初始源域模型中,得到输出结果;计算每个模型输出结果即图像分类结果(无肺炎图像、病毒性肺炎图像、细菌性肺炎图像)的置信度;设置置信度阈值,对于每个用户医疗图像信息
Figure BDA0003959697600000071
计算各模型输出结果
Figure BDA0003959697600000072
的置信度,根据各模型输出结果的置信度采用高阶置信度门进行过滤处理,去除不置信的模型即去除置信度较低的模型;
将剩余模型按照输出结果所属类别进行计数,将计数最多的类别作为共识类,去除与共识类不一致的模型(模型输出结果为共识类的置信度非最高值),得到共识模型;其中,判断模型所属类别的方法是判断模型输出结果属于某一类别的置信度是否最高,若是,则判断模型属于该类别。
计算所有共识模型输出结果为相同类别的置信度均值,将置信度均值作为共识模型的共识知识pi,将共识模型数量作为每个共识模型的共识权重
Figure BDA0003959697600000073
若高阶置信度门过滤了所有模型,则将所有模型输出结果的置信度均值作为共识知识,并为其分配一个低的共识权重,例如,共识权重取0.001。
S33:根据高质量的知识共识扩展源域,得到扩展源域数据;根据扩展源域数据训练医疗图像分类模型,得到扩展源域模型;其中,扩展源域数据
Figure BDA0003959697600000074
表示为:
Figure BDA0003959697600000075
根据扩展源域数据训练医疗图像分类模型包括:
根据扩展源域模型的输出结果以及每个知识投票筛选出的初始源域模型的共识知识、共识权重计算知识蒸馏损失;其公式为:
Figure BDA0003959697600000081
其中,
Figure BDA0003959697600000088
表示第一交叉损失,即扩展源域模型输出结果与共识知识的交叉熵损失,
Figure BDA0003959697600000082
表示扩展域模型的输出结果,pi表示第i条输入数据的共识知识,
Figure BDA0003959697600000083
表示第i条输入数据的共识权重。
根据知识蒸馏损失调整模型参数,得到训练好的扩展源域模型。
与其他集成策略相比,知识投票策略对置信度高、支持域多的类别赋予了较高的权重,使模型学习到高质量的共识知识,避免了受到一些无关域和恶意域的影响,提升蒸馏出的模型性能。
S34:根据目标域数据对所有源域模型进行置信度多教师知识蒸馏,训练得到适应于目标域的学生模型。
对所有源域模型进行置信度多教师知识蒸馏的过程包括:
将初始源域模型和扩展源域模型作为教师模型,采用目标域数据对教师模型进行置信度多教师知识蒸馏,得到预测结果;
如图4所示,对源域模型进行置信度多教师知识蒸馏,目标域数据输入到特征提取层,经过一系列卷积池化处理,得到特征向量,对特征向量输入softmax层进行降维处理,得到一个一维向量,作为输出类别预测结果。对模型的输出置信度进行分析,根据置信度对学生模型与教师模型的交叉熵软标签损失进行加权,同时也让学生模型学习了教师模型特征提取层,最终得到一个学生模型。
源域模型在每个
Figure BDA0003959697600000084
上的预测结果:
Figure BDA0003959697600000085
Figure BDA0003959697600000086
根据每个教师模型的预测结果计算第二交叉熵损失;计算第二交叉熵损失
Figure BDA0003959697600000087
的公式为:
Figure BDA0003959697600000091
其中,
Figure BDA0003959697600000092
表示第k个教师的输出结果标签为c的置信度,τ表示温度系数。
为了有效地聚合多个教师的预测分布,通过计算教师预测和共识标签之间的交叉熵损失来分配不同的权重,以反映其样本置信度;根据第二交叉熵损失计算所有教师模型的第一权重;计算教师模型的第一权重的公式为:
Figure BDA0003959697600000093
Figure BDA0003959697600000094
其中,
Figure BDA0003959697600000095
表示教师预测和共识标签之间的交叉熵损失,yc表示共识标签在类别c的置信度,
Figure BDA0003959697600000096
表示知识蒸馏时第k个教师模型的第一权重;
Figure BDA0003959697600000097
越小,
Figure BDA00039596976000000912
越大,教师的标签是由计算得到的权重聚合而成的,根据所有教师模型的第一权重和第二交叉熵损失计算标签损失:
Figure BDA0003959697600000098
其中,
Figure BDA0003959697600000099
表示学生模型的输出结果标签为c的置信度,M=K+1表示教师模型数量。
根据上述公式,预测更接近共识标签即共识类的教师将被分配更大的权重
Figure BDA00039596976000000910
因为它有足够的信心做出正确的判断,以获得正确的指导。相比之下,如果只是通过计算教师预测的熵来获得权重,那么当输出分布尖锐时,无论最高概率类别是否正确,权重都会变大。在这种情况下,这些有偏见的目标可能会误导学生的训练,并进一步损害其蒸馏性能。
除了标签损失函数外,模型中间层(特征提取层)也有利于学习结构知识,因此将方法扩展到中间层,以挖掘更多信息。中间特征
Figure BDA00039596976000000911
Figure BDA0003959697600000101
其中,
Figure BDA0003959697600000102
表示第k个教师的分类层,vs∈Rc是最后的特征提取层输出的学生特征向量,即vs=AvgPooling(Fs),FS表示学生模型特征提取层输出。
根据教师模型分类层(图像分类模型的最后一层即softmax层为分类层)中的学生特征向量计算第三交叉熵损失:
Figure BDA0003959697600000103
其中,
Figure BDA0003959697600000104
表示学生特征向量输入到教师分类层的输出标签为c的置信度,
Figure BDA0003959697600000105
表示学生特征向量输入到教师分类层的输出标签为c的交叉熵,
Figure BDA0003959697600000106
表示第k个教师模型分类层的交叉熵损失,通过每个教师分类层传递vs得到的。
根据第三交叉熵损失计算所有教师模型的第二权重:
Figure BDA0003959697600000107
其中,
Figure BDA0003959697600000108
Figure BDA0003959697600000109
为了稳定知识转移过程,本发明让学生在相似的特征空间中更专注于模仿教师,且
Figure BDA00039596976000001010
确实能够表示出教师分类层在学生特征空间中的可辨别性;根据所有教师模型的第二权重和第三交叉熵损失计算传递损失:
Figure BDA00039596976000001011
其中,r(·)是一个用于对齐学生和教师特征尺寸的函数,
Figure BDA00039596976000001012
表示教师模型特征提取层输出,FS表示学生模型特征提取层输出,L2损失用作中间特征的距离度量,在中间层的选择中,通常只采用最后一层的输出特征,以避免产生太多的计算开销。
若目标域数据不存在标签,则根据标签损失和传递损失计算总体损失:
L=αLKD+Limter
如果目标域存在少量标签,则总体损失函数除了上述两种损失,还要计算与真实标签的常规交叉熵;根据教师模型的预测结果计算常规交叉熵损失:
Figure BDA0003959697600000111
根据常规交叉熵损失、标签损失和传递损失计算总体损失:
L=LCE+αLKD+βLinter
其中,α和β是超参数,用于平衡知识蒸馏和标准交叉熵损失的影响。
根据总体损失指导学生模型训练,得到训练好的学生模型。
联邦学习模型按照步骤S2进行迭代训练,进行迭代训练过程中,当前轮次的学生模型将作为全局模型发送给客户端。在训练过程中,联邦学习模型将进行对比学习,对联邦学习模型进行对比学习的过程包括:
针对联邦学习中客户端数据域偏移问题,考虑到全局模型能够更加准确的提取出数据特征表示,引入对比学习的思想让全局模型参与指导本地模型训练,控制源域漂移,弥合局部模型学习到的表示与全局模型之间的偏差。模型对比学习主要针对于图2中第二步的客户端在本地更新局部模型。
对比学习的核心思想可以概括为“同类相聚,异类相斥”,将其思想体现在模型中则表现为相同类别输入的输出结果应该更加接近,不同类别输入之间的输出结果应该有较大的差异。更简单点讲对比学习可以认知到谁与谁相似,谁与谁不相似的特点。对应医疗图像中相同分类的特征应该相似,与不同分类的医疗图像中提取出的特征应该很不相似,以求找到它们之间的最大区分点。
获取输入数据在本地模型(训练中的源域模型)输出层前的网络中的映射表征向量
Figure BDA0003959697600000112
上一轮次训练完客户端发送给服务器的本地模型输出的表征向量
Figure BDA0003959697600000113
当前轮次服务器发送给客户端的全局模型输出的表征向量
Figure BDA0003959697600000114
其中,Rh(·)表示模型h在输出层前的网络,也就是说Rh(x)是输入x的映射表征向量。
全局模型相比于局部模型能够学习到更好的表示,所以希望让z在训练中一定程度上的靠近zglob且远离zprev。由此可以让z和zglob作为一对正样本,z和zprev作为一对负样本进行模型层面的对比学习。类似于SimCLR算法的损失函数,本地模型对比学习损失函数为:
Figure BDA0003959697600000121
其中,Lcon表示对比学习损失,sim(z,zglob)表示当前轮次轮本地模型输出得到的表征向量与下发的全局模型输出得到的表征向量的相似度,sim(z,zprev)表示当前轮次本地模型输出得到的表征向量与前一轮本地模型输出得到的表征向量的相似度,τ表示温度系数。
联邦学习本地域内还有监督学习交叉熵损失函数:
Figure BDA0003959697600000122
其中,Fh(x)表示模型h输入x得到的输出结果,也就是这张医疗图像属于具体某一类图像的分类结果;优选的,当医疗图像为胸腔X射线图像时,其输出结果为该图像无肺炎图像图像、细菌性肺炎图像、病毒性肺炎图像的可能性;y表示真实标签,即是否为肺炎图像及肺炎图像类型。
根据对比训练损失和监督学习交叉熵损失计算总体学习损失:
Figure BDA0003959697600000123
其中,μ表示超参数,控制对比学习损失权重。
根据总体学习损失调整本地模型的参数,得到训练好的局部模型。
服务器聚合训练好的局部模型,得到训练好的全局模型。
S4:服务器采集目标医疗机构的医疗图像并将其输入到全局模型中,得到医疗图像分类结果。
实时获取目标医疗机构的用户医疗图像数据,将其输入到全局模型中,可得到医疗图像分类结果。
本发明还提出了一种基于知识蒸馏的联邦域适应系统,包括:数据采集模块、模型训练模块和分类模块;
所述数据采集模块用于为各终端采集训练医疗图像数据或待分类医疗图像数据;可从各医疗机构数据库中获取医疗图像数据;
所述模型训练模块用于根据医疗图像数据训练基于医疗图像分类模型的联邦学习模型,得到全局模型;
所述分类模块用于使用全局模型对待分类医疗图像进行分类,输出分类结果并由显示器显示;分类结果包括图像属于哪一类别及其概率。
上述系统可执行基于知识蒸馏的联邦域适应方法,执行步骤与基于知识蒸馏的联邦域适应方法类似,此处不再赘述。
本发明通过基于知识蒸馏学习出的目标域模型输出结果,模型能够在目标场景表现出较高性能的指标,省去了对于目标域数据进行打标签这种耗时耗财的重复工作;训练过程满足联邦学习场景要求,其他医疗机构数据不出域,保证了用户信息安全,增大了各方合作提供源域数据的可能性。
需要说明的是,本领域普通技术人员可以理解实现上述方法实施例中的全部或部分流程,是可以通过计算机程序来指令相关的硬件来完成,所述的程序可存储于一计算机可读取存储介质中,该程序在执行时,可包括如上述各方法实施例的流程。其中,所述存储介质可为磁碟、光盘、只读存储记忆体(Read-0nly Memory,ROM)或随机存储记忆体(RandomAccess Memory,RAM)等。
以上所举实施例,对本发明的目的、技术方案和优点进行了进一步的详细说明,所应理解的是,以上所举实施例仅为本发明的优选实施方式而已,并不用以限制本发明,凡在本发明的精神和原则之内对本发明所作的任何修改、等同替换、改进等,均应包含在本发明的保护范围之内。

Claims (8)

1.一种基于知识蒸馏的联邦域适应方法,其特征在于,包括:
S1:多个医疗机构作为客户端采集数据库中的医疗图像,服务器采集本地医疗数据库中的医疗图像;将客户端中的数据作为源域数据,服务器中的数据作为目标域数据;
S2:构建基于医疗图像分类模型的联邦学习模型;
S3:根据源域数据和目标域数据对联邦学习模型进行医疗图像分类模型训练和对比学习,得到训练好的全局模型;
S4:服务器采集目标医疗机构的医疗图像并将其输入到全局模型中,得到医疗图像分类结果。
2.根据权利要求1所述的一种基于知识蒸馏的联邦域适应方法,其特征在于,对联邦学习模型进行医疗图像分类模型训练的过程包括:
S31:根据源域数据训练医疗图像分类模型,得到初始源域模型;
S32:根据初始源域模型对目标域数据进行知识投票,得到高质量的知识共识;
S33:根据高质量的知识共识扩展源域,得到扩展源域数据;根据扩展源域数据训练医疗图像分类模型,得到扩展源域模型;
S34:根据目标域数据对所有源域模型进行置信度多教师知识蒸馏,训练得到适应于目标域的学生模型。
3.根据权利要求2所述的一种基于知识蒸馏的联邦域适应方法,其特征在于,对初始源域模型进行知识投票的过程包括:
将目标域数据输入到初始源域模型中,得到输出结果;计算每个模型输出结果属于不同类别的置信度;根据置信度采用高阶置信度门进行过滤处理,去除不自信的模型;
将剩余模型按照输出结果所属类别进行计数,将计数最多的类别作为共识类,去除与共识类不一致的模型,得到共识模型;
计算所有共识模型输出结果为相同类别的置信度均值,将置信度均值作为共识模型的共识知识,将共识模型数量作为每个共识模型的共识权重;
若高阶置信度门过滤了所有模型,则将所有模型输出结果的置信度均值作为共识知识,并为其分配一个低共识权重。
4.根据权利要求2所述的一种基于知识蒸馏的联邦域适应方法,其特征在于,对源域模型进行置信度多教师知识蒸馏的过程包括:
将初始源域模型和扩展源域模型作为教师模型,采用目标域数据对教师模型进行置信度多教师知识蒸馏,得到预测结果;根据每个教师模型的预测结果计算第二交叉熵损失;根据第二交叉熵损失计算所有教师模型的第一权重;根据所有教师模型的第一权重和第二交叉熵损失计算标签损失;
根据教师模型分类层中的学生特征向量计算第三交叉熵损失;根据第三交叉熵损失计算所有教师模型的第二权重;根据所有教师模型的第二权重和第三交叉熵损失计算传递损失;
若目标域数据不存在标签,则根据标签损失和传递损失计算总体损失;若目标域数据存在部分标签,则根据教师模型的预测结果计算常规交叉熵损失,根据常规交叉熵损失、标签损失和传递损失计算总体损失;
根据总体损失指导学生模型训练,得到训练好的学生模型。
5.根据权利要求4所述的一种基于知识蒸馏的联邦域适应方法,其特征在于,计算标签损失的公式为:
Figure FDA0003959697590000021
其中,LKD表示标签损失,
Figure FDA0003959697590000022
表示第一权重,M表示教师模型的数量,
Figure FDA0003959697590000023
表示第k个教师模型输出结果为类别c的置信度,
Figure FDA0003959697590000024
表示第二交叉熵损失,
Figure FDA0003959697590000025
表示学生模型输出结果类别c的置信度。
6.根据权利要求4所述的一种基于知识蒸馏的联邦域适应方法,其特征在于,计算传递损失的公式为:
Figure FDA0003959697590000031
其中,Linter表示传递损失,
Figure FDA0003959697590000032
表示第二权重,M表示教师模型的数量,
Figure FDA0003959697590000033
表示第k个教师模型提取的特征,r(FS)表示学生模型提取的特征。
7.根据权利要求1所述的一种基于知识蒸馏的联邦域适应方法,其特征在于,对联邦学习模型进行对比学习的过程包括:
获取输入数据在本地模型输出层前的网络中的映射表征向量、上一轮次本地训练好后发送给服务器的模型输出的表征向量、当前轮次服务器发送给本地的全局模型输出的表征向量;
根据三种表征向量进行对比学习,计算本地模型的对比学习损失和监督学习交叉熵损失;根据对比训练损失和监督学习交叉熵损失计算总体学习损失;根据总体学习损失调整本地模型的参数,得到训练好的局部模型;
服务器聚合训练好的局部模型,得到训练好的全局模型。
8.一种基于知识蒸馏的联邦域适应系统,其特征在于,包括:数据采集模块、模型训练模块和分类模块;
所述数据采集模块用于为各终端采集训练医疗图像数据或待分类医疗图像数据;
所述模型训练模块用于根据医疗图像数据训练基于医疗图像分类模型的联邦学习模型,得到全局模型;
所述分类模块用于使用全局模型对待分类医疗图像进行分类,输出分类结果。
CN202211475594.2A 2022-11-23 2022-11-23 一种基于知识蒸馏的联邦域适应方法及系统 Pending CN115761408A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202211475594.2A CN115761408A (zh) 2022-11-23 2022-11-23 一种基于知识蒸馏的联邦域适应方法及系统

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202211475594.2A CN115761408A (zh) 2022-11-23 2022-11-23 一种基于知识蒸馏的联邦域适应方法及系统

Publications (1)

Publication Number Publication Date
CN115761408A true CN115761408A (zh) 2023-03-07

Family

ID=85336123

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202211475594.2A Pending CN115761408A (zh) 2022-11-23 2022-11-23 一种基于知识蒸馏的联邦域适应方法及系统

Country Status (1)

Country Link
CN (1) CN115761408A (zh)

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN116701939A (zh) * 2023-06-09 2023-09-05 浙江大学 一种基于机器学习的分类器训练方法及装置
CN117011563A (zh) * 2023-08-04 2023-11-07 山东建筑大学 基于半监督联邦学习的道路损害巡检跨域检测方法及系统

Cited By (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN116701939A (zh) * 2023-06-09 2023-09-05 浙江大学 一种基于机器学习的分类器训练方法及装置
CN116701939B (zh) * 2023-06-09 2023-12-15 浙江大学 一种基于机器学习的分类器训练方法及装置
CN117011563A (zh) * 2023-08-04 2023-11-07 山东建筑大学 基于半监督联邦学习的道路损害巡检跨域检测方法及系统
CN117011563B (zh) * 2023-08-04 2024-03-01 山东建筑大学 基于半监督联邦学习的道路损害巡检跨域检测方法及系统

Similar Documents

Publication Publication Date Title
CN115761408A (zh) 一种基于知识蒸馏的联邦域适应方法及系统
CN111414461A (zh) 一种融合知识库与用户建模的智能问答方法及系统
CN111598167B (zh) 基于图学习的小样本图像识别方法及系统
CN111753918A (zh) 一种基于对抗学习的去性别偏见的图像识别模型及应用
CN114863175A (zh) 一种无监督多源部分域适应图像分类方法
CN114419379A (zh) 一种基于对抗性扰动的深度学习模型公平性提升系统及方法
CN114579794A (zh) 特征一致性建议的多尺度融合地标图像检索方法及系统
CN116910571B (zh) 一种基于原型对比学习的开集域适应方法及系统
CN113536015A (zh) 一种基于深度辨识度迁移的跨模态检索方法
CN117152459A (zh) 图像检测方法、装置、计算机可读介质及电子设备
CN112102135A (zh) 基于lstm神经网络的高校贫困生精准资助模型
CN115439791A (zh) 跨域视频动作识别方法、装置、设备和计算机可存储介质
CN116824216A (zh) 一种无源无监督域适应图像分类方法
CN116109834A (zh) 一种基于局部正交特征注意力融合的小样本图像分类方法
CN112149556B (zh) 一种基于深度互学习和知识传递的人脸属性识别方法
CN114998973A (zh) 一种基于域自适应的微表情识别方法
CN114298160A (zh) 一种基于孪生知识蒸馏与自监督学习的小样本分类方法
CN114792114A (zh) 一种基于黑盒多源域通用场景下的无监督域适应方法
CN114491103A (zh) 一种基于多标记深度关联分析的物联网跨媒体大数据检索方法
CN114139655A (zh) 一种蒸馏式竞争学习的目标分类系统和方法
CN114462466A (zh) 一种面向深度学习的数据去偏方法
CN113449631A (zh) 图像分类方法及系统
CN111860441A (zh) 基于无偏深度迁移学习的视频目标识别方法
Cai et al. Monitoring harmful bee colony with deep learning based on improved grey prediction algorithm
CN111914108A (zh) 基于语义保持的离散监督跨模态哈希检索方法

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination