CN116580824A - 基于联邦图机器学习的跨地域医疗合作预测方法 - Google Patents
基于联邦图机器学习的跨地域医疗合作预测方法 Download PDFInfo
- Publication number
- CN116580824A CN116580824A CN202310555544.3A CN202310555544A CN116580824A CN 116580824 A CN116580824 A CN 116580824A CN 202310555544 A CN202310555544 A CN 202310555544A CN 116580824 A CN116580824 A CN 116580824A
- Authority
- CN
- China
- Prior art keywords
- node
- model
- client
- formula
- probability value
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000000034 method Methods 0.000 title claims abstract description 33
- 238000010801 machine learning Methods 0.000 title claims abstract description 13
- 238000012549 training Methods 0.000 claims abstract description 26
- 230000003044 adaptive effect Effects 0.000 claims abstract description 7
- 230000006870 function Effects 0.000 claims description 30
- 238000012937 correction Methods 0.000 claims description 21
- 238000012360 testing method Methods 0.000 claims description 9
- 238000013528 artificial neural network Methods 0.000 claims description 8
- 230000007246 mechanism Effects 0.000 claims description 7
- 239000011159 matrix material Substances 0.000 claims description 6
- 230000000295 complement effect Effects 0.000 claims description 5
- 235000002020 sage Nutrition 0.000 claims description 4
- NAWXUBYGYWOOIX-SFHVURJKSA-N (2s)-2-[[4-[2-(2,4-diaminoquinazolin-6-yl)ethyl]benzoyl]amino]-4-methylidenepentanedioic acid Chemical compound C1=CC2=NC(N)=NC(N)=C2C=C1CCC1=CC=C(C(=O)N[C@@H](CC(=C)C(O)=O)C(O)=O)C=C1 NAWXUBYGYWOOIX-SFHVURJKSA-N 0.000 claims description 3
- 238000012935 Averaging Methods 0.000 claims description 3
- 230000004913 activation Effects 0.000 claims description 3
- 230000002776 aggregation Effects 0.000 claims description 3
- 238000004220 aggregation Methods 0.000 claims description 3
- 238000005315 distribution function Methods 0.000 claims description 3
- 230000010354 integration Effects 0.000 claims description 3
- 238000005457 optimization Methods 0.000 claims description 3
- 230000009466 transformation Effects 0.000 claims description 3
- 239000013598 vector Substances 0.000 claims description 3
- 238000004364 calculation method Methods 0.000 claims description 2
- 238000000547 structure data Methods 0.000 description 5
- 201000010099 disease Diseases 0.000 description 2
- 208000037265 diseases, disorders, signs and symptoms Diseases 0.000 description 2
- 230000008569 process Effects 0.000 description 2
- 238000012545 processing Methods 0.000 description 2
- 208000017667 Chronic Disease Diseases 0.000 description 1
- 206010011224 Cough Diseases 0.000 description 1
- 206010037660 Pyrexia Diseases 0.000 description 1
- 238000013459 approach Methods 0.000 description 1
- 230000007547 defect Effects 0.000 description 1
- 235000013399 edible fruits Nutrition 0.000 description 1
- 230000000694 effects Effects 0.000 description 1
- 238000002955 isolation Methods 0.000 description 1
- 238000005065 mining Methods 0.000 description 1
- 238000001356 surgical procedure Methods 0.000 description 1
Classifications
-
- G—PHYSICS
- G16—INFORMATION AND COMMUNICATION TECHNOLOGY [ICT] SPECIALLY ADAPTED FOR SPECIFIC APPLICATION FIELDS
- G16H—HEALTHCARE INFORMATICS, i.e. INFORMATION AND COMMUNICATION TECHNOLOGY [ICT] SPECIALLY ADAPTED FOR THE HANDLING OR PROCESSING OF MEDICAL OR HEALTHCARE DATA
- G16H40/00—ICT specially adapted for the management or administration of healthcare resources or facilities; ICT specially adapted for the management or operation of medical equipment or devices
- G16H40/20—ICT specially adapted for the management or administration of healthcare resources or facilities; ICT specially adapted for the management or operation of medical equipment or devices for the management or administration of healthcare resources or facilities, e.g. managing hospital staff or surgery rooms
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N20/00—Machine learning
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02A—TECHNOLOGIES FOR ADAPTATION TO CLIMATE CHANGE
- Y02A90/00—Technologies having an indirect contribution to adaptation to climate change
- Y02A90/10—Information and communication technologies [ICT] supporting adaptation to climate change, e.g. for weather forecasting or climate simulation
Landscapes
- Engineering & Computer Science (AREA)
- Medical Informatics (AREA)
- General Business, Economics & Management (AREA)
- Health & Medical Sciences (AREA)
- Theoretical Computer Science (AREA)
- Software Systems (AREA)
- Business, Economics & Management (AREA)
- Artificial Intelligence (AREA)
- Epidemiology (AREA)
- Public Health (AREA)
- General Health & Medical Sciences (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Data Mining & Analysis (AREA)
- Evolutionary Computation (AREA)
- Primary Health Care (AREA)
- Physics & Mathematics (AREA)
- Computing Systems (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Mathematical Physics (AREA)
- Biomedical Technology (AREA)
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
Abstract
本发明公开一种基于联邦图机器学习的跨地域医疗合作预测方法,其特征在于,包括如下步骤:1)定义联邦学习系统;2)本地客户端模型训练;3)根据组合策略生成全局节点表示;4)得到微平衡的类别预测概率值;5)类别的预测概率值调整;6)自适应校准来自不同方法得到的类别预测概率值。这种方法在保护隐私的前提下,能解决non‑IID和长尾分布数据异质性导致的模型偏差问题,从而获得良好的预测性能,为跨地域医疗合作提供有力的支持。
Description
技术领域
本发明涉及联邦学习以及机器学习领域,具体是一种基于联邦图机器学习的跨地域医疗合作预测方法。
背景技术
跨地域医疗合作是指不同地区的医疗机构之间合作,共同开展医疗服务和资源共享。这种合作可以提高医疗资源的利用效率,提高医疗服务的质量和水平,也可以解决某些地区医疗资源短缺的问题。然而,不同医疗机构之间的数据隔离、数据安全和隐私保护等问题也成为了跨地域医疗合作面临的挑战。
图神经网络是一种能够处理图结构数据的人工神经网络。近年来,图神经网络在处理图结构数据方面被证明是一种有效的方法。然而,由于单个客户的数据可用性有限以及对数据隐私的需求,对分布式图结构数据学习的需求逐渐增加。但是直接的数据共享会导致隐私的泄露,因此集中式图神经网络的隐私保护方法受到了很多关注。然而,这些方法往往会引入过多的噪声,这可能会严重影响模型的准确性,使数据共享失去意义。
联邦学习是解决跨地域医疗机构之间数据隐私保护挑战的方法,允许多个数据所有者在不共享原始数据的情况下参与训练过程。然而,联邦学习最初是为图像数据设计的,对于图数据来说存在挑战,因为图数据具有不规则性和连通性。在跨地域医疗合作这样的实际场景中,服务器代表医疗管理中心,而每个客户端代表一个医疗机构。医疗数据通常是图结构数据,其中节点表示患者,节点属性表示疾病(例如咳嗽和发热),两个患者之间的边可能表示他们患有相同的慢性疾病、接受过类似的手术,或者由同一位医生治疗过。因此,需要研究专门针对图数据的联邦学习算法,以解决跨地域医疗机构之间数据共享和隐私保护问题。
联邦图机器学习是一种分布式图结构机器学习方法,可以处理图结构数据,可以在不损害客户数据隐私的情况下对分布式图结构数据进行学习。然而,联邦图机器学习面临着一些挑战,比如非独立同分布(non-IID)和长尾分布数据。Non-IID数据是指在分布式设置中,每个医疗机构可能有不同的患者人群,导致医疗状况和结果的分布存在差异。长尾数据是某些疾病在医院中可能比其它疾病更为罕见。这些数据异质问题常常导致客户端之间存在显著的模型偏差,特别是在图挖掘任务中,因此,在跨地区医疗合作中,由于涉及到不同地区的医疗数据,数据之间的差异性可能比较大,这就需要一种新的联邦图机器学习方法来解决模型偏差问题。
发明内容
本发明的目的是针对现有技术的不足,而提供一种基于联邦图机器学习的跨地域医疗合作预测方法。这种方法在保护隐私的前提下,能解决non-IID和长尾分布数据异质性导致的模型偏差问题,从而获得良好的预测性能,为跨地域医疗合作提供有力的支持。
实现本发明目的的技术方案是:
一种基于联邦图机器学习的跨地域医疗合作预测方法,包括如下步骤:
1)定义联邦学习系统:设定一个中央服务器S,K个客户端,每个客户端Di拥有本地子图数据Gi=(Vi,Ei,Xi),Vi是节点集合,Ei是边的集合,Ei表示节点特征矩阵,i∈[K],采用Pareto分布函数将Di的数据分布设定为长尾分布,将每个客户端Di的10%的数据划分为测试集,剩余的90%的数据划分为训练集;
2)本地客户端模型训练:每个客户端Di的训练集输入本地数据Gi到本地两阶段的GraphSage-Mixup模型训练,采用图神经网络GraphSage和Mixup模型学习到匿名节点表示,并将匿名节点表示发送给服务器,采用两阶段的GraphSage-Mixup模型,在第一阶段,先采用GraphSage获取节点的隐藏表示,但不进行Mixup,计算方法如公式(1)所示:
在第二阶段中,随机将小批量中的节点配对,按照公式(2)所示方式进行节点属性和标签的Mixup:
xij=(1-λ)xi+λxj and yij=(1-λ)yi+λyj (2),
最后,在每个层l,将GraphSage视为一种基于节点i和j拓扑结构的聚合函数,分别对它们进行如公式(3)所示聚合:
其中,并且在下一层之前将来自两个拓扑结构的聚合特征进行混合如公式(4)所示:
通过公式(2)、(3)和公式(4)应用于图中的所有节点,得到了混合节点嵌入混合节点嵌入/>与类别数目具有相同的维度,并采用一个SoftMax分类器传递,用于进行多类节点分类训练,采用公式(5)计算分类结果的交叉熵误差:
其中Vl是一个带标签的训练节点集合,Y∈Rn×|L|是图节点的独热标签指示矩阵,采用公式(2)和公式(4)生成混合节点嵌入,从而提供隐私保护.因此,潜在的攻击者很难从这些节点嵌入中推断出敏感信息;
3)根据组合策略生成全局节点表示:服务器接收客户端上传的节点表示后,采用Concat策略生成全局节点表示,Concat策略即将每个客户端的本地节点嵌入连接起来生成全局节点嵌入,具体为:服务器采用Concat的组合策略将来自每个客户端的本地节点嵌入组合起来以获得全局节点嵌入hg,如公式(6)所示:
hg←CONCAT(h1,h2,...,hK) (6);
4)得到微平衡的类别预测概率值:服务器根据客户端上传的模型参数wk采用注意力机制量化每个客户端的贡献程度,生成个性化模型参数wg,得到个性化模型参数后,服务器将采用个性化模型参数和全局节点表示训练模型,从而得到微平衡的每个类别预测概率值,具体为:
在服务器上采用层级别的注意力机制如公式(7)所示:
其中,wg表示个性化模型参数,wk表示客户端模型的参数,对于第k个客户端模型的每一层l,和/>分别表示注意力权重和模型参数,||是一个将多个向量按顺序连接的运算符,注意力权重采用softmax函数计算如公式(8)所示:
然后,基于全局嵌入hg和校正后的特征提取器如公式(9)所示得到微平衡的类别预测概率值:
5)类别的预测概率值调整:对所有客户端上传的模型参数wk进行集成并生成集成模型μe(hg),采用基于元学习的元网络校正方法,通过学习多个客户端微调的每个类别的预测概率值,对集成模型进行微调,在服务器上定义集成模型如公式(10)所示:
然后采用非线性变换来计算集成权重pk如公式(11)所示:
其中,ap∈RC和bp是可学习的参数,然后,pk被归一化,使pk之和等于1,采用基于元学习的元网络知识校准方法,将logits∈μe(hg)得到类别预测概率值,且通过可学习的温度参数τ进行缩放,使映射到0到1之间的范围,然后,scaled_logits采用减去可学习的偏移参数α对scaled_logits进行校准如公式(12)所示:
将sigmoid函数应用于scaled_logits,并沿着第一维取平均值,得到校准因子correction,其中函数f是一个多层感知器如公式(13)所示:
correction=f(mean(sigmoid(scaled_logits),axis=0)) (13),
将corrected_logits与correction因子进行逐元素乘积,得到调整后的类别预测概率值zcl,其中符号⊙表示哈达玛积,如公式(14)所示:
zcl=corrected_logits⊙correction (14);
6)自适应校准来自不同方法得到的类别预测概率值:采用自适应校正函数校正微平衡和微调每个类别的预测概率值,对平衡的预测概率值进行交叉熵损失优化模型,依据自适应校正函数校正微平衡和微调的类别预测概率值,以平衡zmb和zcl之间的权重,使调整后的对数和微观平衡对数得以有效整合,并发挥微平衡和微调每个类别的预测概率值互补优势如公式(15)所示:
z′=σ(x)·zmb+(1-σ(x))·zcl (15),
σ(x)是sigmoid激活函数,x是可学习的参数,sigmoid函数输出介于0和1之间的值,sigmoid函数决定了zmb和zcl之间的权衡关系,如果sigmoid函数的输出接近1,那么校准后的zcl将对最终输出z′产生更大的影响,相反,如果sigmoid函数的输出接近0,那么微调后的zmb将对最终输出产生更大的影响,通过在训练过程中调整x的值,模型学习有效地整合校准后的zmb和微调后的zcl,并使它们相互补充,在得到校准后的类别预测概率值z′之后,采用交叉熵损失训练全局模型,并获得更新的全局模型参数θg,然后将全局模型参数θg发送回每个客户端进行模型参数wk更新得到更新的本地客户端模型参数wk+1,更新后的模型可以用于对客户端本地的测试集进行测试和预测。
本技术方案的主要目的是解决跨地域医疗合作中医疗数据异质的问题,虽然现有的联邦学习方法已经能够处理non-IID数据导致的负面影响,但它们只考虑了客户端节点的类分布平衡的情况,在跨地域医疗合作这种实际场景中,客户端节点的类分布通常是长尾分布的,这可能会导致模型发生偏置,因此,本技术方案分为本地阶段和全局阶段,以处理non-IID和长尾数据导致的全局模型偏置问题。
本技术方案针对联邦学习中本地客户端医疗数据量稀疏,仅仅使用图神经网络生成的节点表示能力较弱,采用GraphSage结合Mixup增强节点表示能力,同时使用Mixup混合节点对,使节点特征匿名,防止攻击者通过收集图中节点信息后,对目标的隐私进行推测;本技术方案采用集成模型的思想,结合基于元学习的元网络方法,来纠正长尾数据对模型性能的影响;本技术方案中服务器采用层次化注意力机制量化每个客户端的贡献,从而得到无偏的个性化模型参数,解决了非独立同分布和长尾本地数据带来的影响。
这种方法在保护隐私的前提下,能解决non-IID和长尾分布数据异质性导致的模型偏差问题,从而获得良好的预测性能,为跨地域医疗合作提供有力的支持。
附图说明
图1为实施例的方法流程示意图。
具体实施方式
下面结合附图和实施例对本发明的内容做进一步的阐述,但不是对本发明的限定。
实施例:
参照图1,一种基于联邦图机器学习的跨地域医疗合作预测方法,包括如下步骤:
1)定义联邦学习系统:设定一个中央服务器S,K个客户端,每个客户端Di拥有本地子图数据Gi=(Vi,Ei,Xi),Vi是节点集合,Ei是边的集合,Xi表示节点特征矩阵,i∈[K],采用Pareto分布函数将Di的数据分布设定为长尾分布,将每个客户端Di的10%的数据划分为测试集,剩余的90%的数据划分为训练集;
2)本地客户端模型训练:每个客户端Di的训练集输入本地数据Gi到本地两阶段的GraphSage-Mixup模型训练,采用图神经网络GraphSage和Mixup模型学习到匿名节点表示,并将匿名节点表示发送给服务器,采用两阶段的GraphSage-Mixup模型,在第一阶段,先采用GraphSage获取节点的隐藏表示,但不进行Mixup,计算方法如公式(1)所示:
在第二阶段中,随机将小批量中的节点配对,按照公式(2)所示方式进行节点属性和标签的Mixup:
xij=(1-λ)xi+λxj and yij=(1-λ)yi+λyj (2),
最后,在每个层l,将GraphSage视为一种基于节点i和j拓扑结构的聚合函数,分别对它们进行如公式(3)所示聚合:
其中,并且在下一层之前将来自两个拓扑结构的聚合特征进行混合如公式(4)所示:
通过公式(2)、(3)和公式(4)应用于图中的所有节点,得到了混合节点嵌入混合节点嵌入/>与类别数目具有相同的维度,并采用一个SoffMax分类器传递,用于进行多类节点分类训练,采用公式(5)计算分类结果的交叉熵误差:
其中Vl是一个带标签的训练节点集合,Y∈Rn×|L|是图节点的独热标签指示矩阵,采用公式(2)和公式(4)生成混合节点嵌入,从而提供隐私保护.因此,潜在的攻击者很难从这些节点嵌入中推断出敏感信息;
3)根据组合策略生成全局节点表示:服务器接收客户端上传的节点表示后,采用Concat策略生成全局节点表示,Concat策略即将每个客户端的本地节点嵌入连接起来生成全局节点嵌入,具体为:服务器采用Concat的组合策略将来自每个客户端的本地节点嵌入组合起来以获得全局节点嵌入hg,如公式(6)所示:
hg←CONCAT(h1,h2,...,hK) (6);
4)得到微平衡的类别预测概率值:服务器根据客户端上传的模型参数wk采用注意力机制量化每个客户端的贡献程度,生成个性化模型参数wg,得到个性化模型参数后,服务器将采用个性化模型参数和全局节点表示训练模型,从而得到微平衡的每个类别预测概率值,具体为:
在服务器上采用层级别的注意力机制如公式(7)所示:
其中,wg表示个性化模型参数,wk表示客户端模型的参数,对于第k个客户端模型的每一层l,和/>分别表示注意力权重和模型参数,||是一个将多个向量按顺序连接的运算符,注意力权重采用softmax函数计算如公式(8)所示:
然后,基于全局嵌入hg和校正后的特征提取器如公式(9)所示得到微平衡的类别预测概率值:
5)类别的预测概率值调整:对所有客户端上传的模型参数wk进行集成并生成集成模型μe(hg),采用基于元学习的元网络校正方法,通过学习多个客户端微调的每个类别的预测概率值,对集成模型进行微调,在服务器上定义集成模型如公式(10)所示:
然后采用非线性变换来计算集成权重pk如公式(11)所示:
其中,ap∈RC和bp是可学习的参数,然后,pk被归一化,使pk之和等于1,采用基于元学习的元网络知识校准方法,将logits∈μe(hg)得到类别预测概率值,且通过可学习的温度参数τ进行缩放,使映射到0到1之间的范围,然后,scaled_logits采用减去可学习的偏移参数σ对scaled_logits进行校准如公式(12)所示:
将sigmoid函数应用于scaled_logits,并沿着第一维取平均值,得到校准因子correction,其中函数f是一个多层感知器如公式(13)所示:
correction=f(mean(sigmoid(scaled_logits),axis=0)) (13),
将corrected_logits与correction因子进行逐元素乘积,得到调整后的类别预测概率值zcl,其中符号⊙表示哈达玛积,如公式(14)所示:
zcl=corrected_logits⊙correction (14);
6)自适应校准来自不同方法得到的类别预测概率值:采用自适应校正函数校正微平衡和微调每个类别的预测概率值,对平衡的预测概率值进行交叉熵损失优化模型,依据自适应校正函数校正微平衡和微调的类别预测概率值,以平衡zmb和zcl之间的权重,使调整后的对数和微观平衡对数得以有效整合,并发挥微平衡和微调每个类别的预测概率值互补优势如公式(15)所示:
z′=σ(x)·zmb+(1-σ(x))·zcl (15),
σ(x)是sigmoid激活函数,x是可学习的参数,sigmoid函数输出介于0和1之间的值,sigmoid函数决定了zmb和zcl之间的权衡关系,如果sigmoid函数的输出接近1,那么校准后的zcl将对最终输出z′产生更大的影响,相反,如果sigmoid函数的输出接近0,那么微调后的zmb将对最终输出产生更大的影响,通过在训练过程中调整x的值,模型学习有效地整合校准后的zmb和微调后的zcl,并使它们相互补充,在得到校准后的类别预测概率值z′之后,采用交叉熵损失训练全局模型,并获得更新的全局模型参数θg,然后将全局模型参数θg发送回每个客户端进行模型参数wk更新得到更新的本地客户端模型参数wk+1,更新后的模型可以用于对客户端本地的测试集进行测试和预测。
Claims (1)
1.一种基于联邦图机器学习的跨地域医疗合作预测方法,其特征在于,包括如下步骤:
1)定义联邦学习系统:设定一个中央服务器S,K个客户端,每个客户端Di拥有本地子图数据Gi=(Vi,Ei,Xi),Vi是节点集合,Ei是边的集合,Xi表示节点特征矩阵,i∈[K],采用Pareto分布函数将Di的数据分布设定为长尾分布,将每个客户端Di的10%的数据划分为测试集,剩余的90%的数据划分为训练集;
2)本地客户端模型训练:每个客户端Di的训练集输入本地数据Gi到本地两阶段的GraphSage-Mixup模型训练,采用图神经网络GraphSage和Mixup模型学习到匿名节点表示,并将匿名节点表示发送给服务器,采用两阶段的GraphSage-Mixup模型,在第一阶段,先采用GraphSage获取节点的隐藏表示,但不进行Mixup,计算方法如公式(1)所示:
在第二阶段中,随机将小批量中的节点配对,按照公式(2)所示方式进行节点属性和标签的Mixup:
xij=(1-λ)xi+λxjand yij=(1-λ)yi+λyj (2),
最后,在每个层l,将GraphSage视为一种基于节点i和j拓扑结构的聚合函数,分别对它们进行如公式(3)所示聚合:
其中,并且在下一层之前将来自两个拓扑结构的聚合特征进行混合如公式(4)所示:
通过公式(2)、(3)和公式(4)应用于图中的所有节点,得到了混合节点嵌入混合节点嵌入/>与类别数目具有相同的维度,并采用一个SoftMax分类器传递,采用公式(5)计算分类结果的交叉熵误差:
其中Vl是一个带标签的训练节点集合,Y∈Rn×|L|是图节点的独热标签指示矩阵,采用公式(2)和公式(4)生成混合节点嵌入;
3)根据组合策略生成全局节点表示:服务器接收客户端上传的节点表示后,采用Concat策略生成全局节点表示,Concat策略即将每个客户端的本地节点嵌入连接起来生成全局节点嵌入,具体为:服务器采用Concat的组合策略将来自每个客户端的本地节点嵌入组合起来以获得全局节点嵌入hg,如公式(6)所示:
hg←CONCAT(h1,h2,...,hK) (6);
4)得到微平衡的类别预测概率值:服务器根据客户端上传的模型参数wk采用注意力机制量化每个客户端的贡献程度,生成个性化模型参数wg,得到个性化模型参数后,服务器将采用个性化模型参数和全局节点表示训练模型,从而得到微平衡的每个类别预测概率值,具体为:
在服务器上采用层级别的注意力机制如公式(7)所示:
其中,wg表示个性化模型参数,wk表示客户端模型的参数,对于第k个客户端模型的每一层l,和/>分别表示注意力权重和模型参数,||是一个将多个向量按顺序连接的运算符,注意力权重采用softmax函数计算如公式(8)所示:
然后,基于全局嵌入hg和校正后的特征提取器如公式(9)所示得到微平衡的类别预测概率值:
5)类别的预测概率值调整:对所有客户端上传的模型参数wk进行集成并生成集成模型μe(hg),采用基于元学习的元网络校正方法,通过学习多个客户端微调的每个类别的预测概率值,对集成模型进行微调,在服务器上定义集成模型如公式(10)所示:
然后采用非线性变换来计算集成权重pk如公式(11)所示:
其中,ap∈RC和bp是可学习的参数,然后,pk被归一化,使pk之和等于1,采用基于元学习的元网络知识校准方法,将logits∈μe(hg)得到类别预测概率值,且通过可学习的温度参数r进行缩放,使映射到0到1之间的范围,然后,scaled_logits采用减去可学习的偏移参数α对scaled_logits进行校准如公式(12)所示:
将sigmoid函数应用于scaledlogits,并沿着第一维取平均值,得到校准因子correction,其中函数f是一个多层感知器如公式(13)所示:
correction=f(mean(sigmoid(scaled_logits),axis=0)) (13),
将corrected_logits与correction因子进行逐元素乘积,得到调整后的类别预测概率值zcl,其中符号⊙表示哈达玛积,如公式(14)所示:
zcl=corrected_logits⊙correction (14);
6)自适应校准来自不同方法得到的类别预测概率值:采用自适应校正函数校正微平衡和微调每个类别的预测概率值,对平衡的预测概率值进行交叉熵损失优化模型,依据自适应校正函数校正微平衡和微调的类别预测概率值,微平衡和微调每个类别的预测概率值互补优势如公式(15)所示:
z′=σ(x)·zmb+(1-σ(x))·zcl (15),
σ(x)是sigmoid激活函数,x是可学习的参数,sigmoid函数输出介于0和1之间的值,在得到校准后的类别预测概率值z′之后,采用交叉熵损失训练全局模型,并获得更新的全局模型参数θg,然后将全局模型参数θg发送回每个客户端进行模型参数wk更新得到更新的本地客户端模型参数wk+1,更新后的模型可以用于对客户端本地的测试集进行测试和预测。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310555544.3A CN116580824A (zh) | 2023-05-17 | 2023-05-17 | 基于联邦图机器学习的跨地域医疗合作预测方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310555544.3A CN116580824A (zh) | 2023-05-17 | 2023-05-17 | 基于联邦图机器学习的跨地域医疗合作预测方法 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN116580824A true CN116580824A (zh) | 2023-08-11 |
Family
ID=87539216
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202310555544.3A Pending CN116580824A (zh) | 2023-05-17 | 2023-05-17 | 基于联邦图机器学习的跨地域医疗合作预测方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN116580824A (zh) |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116958149A (zh) * | 2023-09-21 | 2023-10-27 | 湖南红普创新科技发展有限公司 | 医疗模型训练方法、医疗数据分析方法、装置及相关设备 |
-
2023
- 2023-05-17 CN CN202310555544.3A patent/CN116580824A/zh active Pending
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116958149A (zh) * | 2023-09-21 | 2023-10-27 | 湖南红普创新科技发展有限公司 | 医疗模型训练方法、医疗数据分析方法、装置及相关设备 |
CN116958149B (zh) * | 2023-09-21 | 2024-01-12 | 湖南红普创新科技发展有限公司 | 医疗模型训练方法、医疗数据分析方法、装置及相关设备 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
US20230039182A1 (en) | Method, apparatus, computer device, storage medium, and program product for processing data | |
CN113850272B (zh) | 基于本地差分隐私的联邦学习图像分类方法 | |
CN113190654B (zh) | 一种基于实体联合嵌入和概率模型的知识图谱补全方法 | |
CN112990385B (zh) | 一种基于半监督变分自编码器的主动众包图像学习方法 | |
CN113518007B (zh) | 一种基于联邦学习的多物联网设备异构模型高效互学习方法 | |
CN112101404B (zh) | 基于生成对抗网络的图像分类方法、系统及电子设备 | |
CN115358487A (zh) | 面向电力数据共享的联邦学习聚合优化系统及方法 | |
CN114357676B (zh) | 一种针对层次化模型训练框架的聚合频率控制方法 | |
CN116580824A (zh) | 基于联邦图机器学习的跨地域医疗合作预测方法 | |
CN115051929A (zh) | 基于自监督目标感知神经网络的网络故障预测方法及装置 | |
CN116205383A (zh) | 一种基于元学习的静态动态协同图卷积交通预测方法 | |
WO2021196240A1 (zh) | 面向跨网络的表示学习算法 | |
CN116094792A (zh) | 基于时空特征和注意力机制的加密恶意流识别方法及装置 | |
Ren et al. | Federated distillation for medical image classification: Towards trustworthy computer-aided diagnosis | |
CN118036706A (zh) | 基于图子树差异实现图联邦迁移学习的多任务处理系统 | |
CN113037778A (zh) | 针对连续变量量子密钥分发系统的攻击检测方法 | |
CN116486150A (zh) | 一种基于不确定性感知的图像分类模型回归误差消减方法 | |
Mi et al. | Fedmdr: Federated model distillation with robust aggregation | |
CN111247556A (zh) | 在附加的任务上训练人工神经网络的同时避免灾难性干扰 | |
CN109118483A (zh) | 一种标签质量检测方法及装置 | |
CN115150246A (zh) | 基于新型嵌套链架构的面向海量实时物联网的上链方法 | |
CN109461498B (zh) | 一种基于卷积神经网络对舌体胖瘦精细分类的方法 | |
CN117036910B (zh) | 一种基于多视图及信息瓶颈的医学图像训练方法 | |
CN116935143B (zh) | 基于个性化联邦学习的dfu医学图像分类方法及系统 | |
CN116452559B (zh) | 基于ctDNA片段化模式的肿瘤病灶的定位方法及装置 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |