CN113298229A - 联邦学习模型训练方法、客户端、服务器及存储介质 - Google Patents
联邦学习模型训练方法、客户端、服务器及存储介质 Download PDFInfo
- Publication number
- CN113298229A CN113298229A CN202110391127.0A CN202110391127A CN113298229A CN 113298229 A CN113298229 A CN 113298229A CN 202110391127 A CN202110391127 A CN 202110391127A CN 113298229 A CN113298229 A CN 113298229A
- Authority
- CN
- China
- Prior art keywords
- model
- training
- parameters
- neural network
- client
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000012549 training Methods 0.000 title claims abstract description 175
- 238000000034 method Methods 0.000 title claims abstract description 73
- 238000003062 neural network model Methods 0.000 claims abstract description 101
- 238000013140 knowledge distillation Methods 0.000 claims abstract description 44
- 238000004891 communication Methods 0.000 claims description 22
- 230000006870 function Effects 0.000 claims description 13
- 238000004821 distillation Methods 0.000 claims description 11
- 238000011156 evaluation Methods 0.000 claims description 9
- 238000005457 optimization Methods 0.000 claims description 7
- 238000013528 artificial neural network Methods 0.000 claims description 4
- 238000004321 preservation Methods 0.000 claims 1
- 230000000694 effects Effects 0.000 abstract description 9
- 238000013473 artificial intelligence Methods 0.000 abstract description 4
- 238000010586 diagram Methods 0.000 description 5
- 238000001514 detection method Methods 0.000 description 2
- 230000009977 dual effect Effects 0.000 description 2
- 230000003993 interaction Effects 0.000 description 2
- 238000012545 processing Methods 0.000 description 2
- 238000006467 substitution reaction Methods 0.000 description 2
- 238000005516 engineering process Methods 0.000 description 1
- 238000002955 isolation Methods 0.000 description 1
- 230000003287 optical effect Effects 0.000 description 1
- 239000004576 sand Substances 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N20/00—Machine learning
- G06N20/20—Ensemble learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
-
- H—ELECTRICITY
- H04—ELECTRIC COMMUNICATION TECHNIQUE
- H04L—TRANSMISSION OF DIGITAL INFORMATION, e.g. TELEGRAPHIC COMMUNICATION
- H04L41/00—Arrangements for maintenance, administration or management of data switching networks, e.g. of packet switching networks
- H04L41/14—Network analysis or design
- H04L41/145—Network analysis or design involving simulating, designing, planning or modelling of a network
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Software Systems (AREA)
- Computing Systems (AREA)
- Artificial Intelligence (AREA)
- Mathematical Physics (AREA)
- General Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- Evolutionary Computation (AREA)
- General Engineering & Computer Science (AREA)
- Biomedical Technology (AREA)
- Molecular Biology (AREA)
- General Health & Medical Sciences (AREA)
- Computational Linguistics (AREA)
- Biophysics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Health & Medical Sciences (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Medical Informatics (AREA)
- Computer Networks & Wireless Communication (AREA)
- Signal Processing (AREA)
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
Abstract
本发明涉及人工智能算法技术领域,具体提供一种基于知识蒸馏的联邦学习模型训练方法,包括:接收来自服务器的用于模型训练的控制参数;根据控制参数以及本地数据样本对初始的第一神经网络模型进行训练,得到第一模型参数;将第一模型参数发送至服务器;接收来自服务器的第二神经网络模型的第二模型参数;利用知识蒸馏方法使第一神经网络模型学习到第二神经网络模型的知识,训练得到更新的第一神经网络模型。使用本发明的方法通过构建联邦学习系统有效解决目前存在的数据孤岛问题,同时在联邦学习框架系统中增加知识蒸馏模块,使得算法模型可以同时在所有训练数据的知识基础上进行训练优化,进一步提升联邦学习框架系统的训练效果。
Description
技术领域
本发明涉及人工智能技术领域,具体涉及一种基于知识蒸馏 的联邦学习模型训练方法、客户端、服务器及计算机可读存储介质。
背景技术
目前多家单位若合作利用人工智能算法在某一业务场景中 进行落地,会遇到一些问题,例如,由于数据安全与数据隐私要求,各 家单位的数据不能在各单位之间进行有效流通和使用,从而造成数据孤 岛问题。传统的算法训练框架强调数据的多样性和完整性,从而进一步 放大数据孤岛问题给算法能力带来的影响。因此传统的算法训练框架和数据孤岛问题会使得人工智能算法能力陷入瓶颈,并进一步限制算法在 实际应用场景中的使用和落地。
因此,本领域仍然需要一种新的方法来解决由于数据孤岛无 法提高算法能力,限制算法落地应用的问题。
发明内容
为了解决现有技术中的上述问题,即,为了解决现有方案由 于数据孤岛无法提高算法能力,限制算法落地应用的问题,一方面,一 种基于知识蒸馏的联邦学习模型训练方法,包括:接收服务器获取的用 于模型训练的控制参数;根据所述控制参数以及本地数据样本对初始的 第一神经网络模型进行训练,得到第一模型参数;将所述第一模型参数 发送至所述服务器;接收所述服务器获取的第二神经网络模型的第二模 型参数;利用知识蒸馏方法使所述第一神经网络模型学习到所述第二神 经网络模型的知识,训练得到更新的第一神经网络模型,其中,所述第 一神经网络模型是学生网络模型。
在上述联邦学习模型训练方法的优选实施方式中,所述控 制参数至少包括训练次数n,n大于等于2,还可以包括:将所述更新的 第一神经网络模型的第一模型参数发送至所述服务器;接收所述服务 器获取的更新后的第二模型参数;根据所述更新后的第二模型参数得 到更新后的第二神经网络模型;利用知识蒸馏方法使所述更新后的第 二神经网络模型学习到所述更新的第一神经网络模型的知识,训练得 到二次训练后的第二神经网络模型,其中,所述更新后的第二神经网 络模型是学生网络模型;将所述二次训练后的第二神经网络模型的第 二模型参数发送至所述服务器;如此循环,直至所述n次训练结束。
在上述联邦学习模型训练方法的优选实施方式中,所述第 一模型参数和所述第二模型参数至少包括神经网络的权重参数。
在上述联邦学习模型训练方法的优选实施方式中,还可以 包括:所述知识蒸馏方法所采用的损失函数包括以下任意一种:均方 误差损失函数、平均绝对误差损失函数。
根据本发明的另一方面,还提供了一种联邦学习模型训练 方法,包括:向第一客户端和第二客户端分别发送用于模型训练的控 制参数;接收来自第一客户端的经过更新的第一神经网络模型的第一 模型参数;接收来自第二客户端的经过更新的第二神经网络模型的第 二模型参数;将所述第一模型参数发送至所述第二客户端,将所述第 二模型参数发送至所述第一客户端;保存所述第一模型参数和所述第 二模型参数。
在上述联邦学习模型训练方法的优选实施方式中,包括: 所述控制参数至少包括训练次数n,n大于等于2;将更新后的第一模型 参数发送至所述第二客户端;将更新后的第二模型参数发送至所述第 一客户端;接收并保存来自所述第二客户端的经过二次训练的第一模 型参数;接收并保存来自所述第一客户端的经过二次训练的第二模型 参数;如此循环,直至所述n次训练结束。
在上述联邦学习模型训练方法的优选实施方式中,在保存 第一模型参数和第二模型参数时,利用指标评估方法选取一个或多个 第一模型参数和第二模型参数进行保存更新;将选取保存的模型参数 发送至对应客户端。
根据本发明的再一方面,还提供了一种基于知识蒸馏的联 邦学习模型训练客户端,包括:通讯模块,接收来自服务器的用于模 型训练的控制参数,以及接收来自所述服务器的第二神经网络模型的 第二模型参数;算法训练模块,与所述通讯模块连接,根据所述控制 参数以及本地数据样本对初始的第一神经网络模型进行训练,得到第 一模型参数,根据所述第二模型参数得到所述第二神经网络模型,以 及利用知识蒸馏方法使所述第一神经网络模型学习到所述第二神经网 络模型的知识,训练得到更新的第一神经网络模型,其中,所述第一 神经网络模型是学生网络模型。
在上述客户端的优选实施方式中,所述控制参数至少包括 训练次数n,n大于等于2,所述通讯模块还将得到的第一神经网络模型 的模型参数发送至所述服务器进行更新,以及接收来自所述服务器的 更新后的第二模型参数,以及将所二次训练后的第二神经网络模型的 第二模型参数发送至所述服务器;所述算法训练模块还根据所述更新 后的第二模型参数得到更新后的第二神经网络模型,以及利用知识蒸 馏方法使所述更新后的第二神经网络模型学习到所述更新的第一神经 网络模型的知识,训练得到二次训练后的第二神经网络模型,其中, 所述更新后的第二神经网络模型是学生网络模型,如此循环,直至所述n次训练结束。
根据本发明的又一方面,还提供了一种服务器,包括:训 练控制模块,生成用于模型训练的控制参数;通讯模块,向第一客户 端和第二客户端分别发送所述控制参数,接收来自第一客户端的经过 更新的第一神经网络模型的第一模型参数,以及接收来自第二客户端 的经过更新的第二神经网络模型的第二模型参数;参数更新模块,保 存所述第一模型参数和所述第二模型参数。
在上述服务器的优选实施方式中,包括:所述控制参数至 少包括训练次数n,n大于等于2;所述通讯模块还用于将更新后的第一 模型参数发送至所述第二客户端,将更新后的第二模型参数发送至所 述第一客户端,接收并保存来自所述第二客户端的经过二次训练的第 一模型参数,以及接收来自所述第一客户端的经过二次训练的第二模 型参数,如此循环,直至所述n次训练结束;所述参数更新模块保存所 述经过二次训练的第二模型参数。
在上述服务器的优选实施方式中,还可以包括:模型优选 模块,与所述参数更新模块和通讯模块连接,在保存第一模型参数和 第二模型参数时,利用指标评估方法选取一个或多个第一模型参数和 第二模型参数进行保存更新;所述通讯模块将选取保存的模型参数发 送至对应客户端。
本发明进一步还提供了一种基于知识蒸馏的联邦学习模 型训练系统,包括多个如上述任一技术方案中所述的基于知识蒸馏的 联邦学习模型训练客户端和如上述任一技术方案中所述的服务器。
本发明进一步还提供了一种计算机可读存储介质,所述存 储介质中存储有多条程序代码,所述程序代码适用于由处理器加载并 运行以执行如上述任一技术方案中所述的基于知识蒸馏的联邦学习模 型训练方法和上述任一技术方案中所述的联邦学习模型训练方法。
本发明将模型训练设置在本地,并通过中心服务器完成模 型参数交互,解决了数据孤岛和数据隐私的问题,可在数据不离本地 的情况下完成算法模型的训练及优化。能够支持双模型之间的相互蒸 馏,充分利用全部数据的知识,提高算法模型在联邦框架下的训练效 果,同时一次训练过程可以完成两个神经网络模型的训练,打破数据 孤岛和传统训练框架对算法能力造成的瓶颈。
附图说明
下面结合附图来描述本发明的优选实施方式,附图中:
图1为根据本发明实施例的基于知识蒸馏的联邦学习模型 训练方法的流程图;
图2为根据本发明实施例的联邦学习模型训练方法的流程 图;
图3为根据本发明一个实施例的基于知识蒸馏的联邦学习 模型训练的结构示意图。
具体实施方式
为了便于理解本发明,下文将结合说明书附图和实施例对 本发明作更全面、细致的描述,但本领域技术人员应当理解的是,这 些实施方式仅仅用于解释本发明的技术原理,并非旨在限制本发明的 保护范围。
在本发明的描述中,“模块”、“处理器”可以包括硬件、软 件或者两者的组合。一个模块可以包括硬件电路、各种合适的感应器、 通信端口、存储器,也可以包括软件部分,比如程序代码,也可以是 软件和硬件的组合。处理器可以是中央处理器、微处理器、图像处理器、数字信号处理器或者其他任何合适的处理器。处理器具有数据和/ 或信号处理功能。处理器可以以软件方式实现、硬件方式实现或者二 者结合的方式实现。非暂时性的计算机可读存储介质包括任何合适的 可存储程序代码的介质,比如磁碟、硬盘、光碟、闪存、只读存储器、 随机存取存储器等等。术语“A和/或B”表示所有可能的A与B的组合, 比如只是A、只是B或者A和B。术语“至少一个A或B”或者“A和B 中的至少一个”含义与“A和/或B”类似,可以包括只是A、只是B或者 A和B。单数形式的术语“一个”、“这个”也可以包含复数形式。
首先参阅图1,在客户端侧,根据本发明实施例的一种基 于知识蒸馏的联邦学习模型训练方法,包括:
S1,接收来自服务器的用于模型训练的控制参数。该控制 参数可以包括学习率、训练次数等模型训练需要的参数,不仅限于此。
S2,根据控制参数以及本地数据样本对初始的第一神经网 络模型进行训练,得到第一模型参数。每个客户端可以部署在各场景 的本地,直接利用本地的数据样本进行训练,这样就无需将本地数据 传输至外部,保护数据隐私。
S3,将第一模型参数发送至服务器。训练完成后,将模型 参数发送给服务器进行保存更新,以便服务器将第一模型参数发送给 其他客户端。
S4,接收来自服务器的第二神经网络模型的第二模型参 数。除了将自己的模型参数通过服务器发送给其他客户端之外,还接 收其他客户端的第二模型参数,便于后面的知识蒸馏学习。
S5,利用知识蒸馏方法使第一神经网络模型学习到第二神 经网络模型的知识,训练得到更新的第一神经网络模型,其中,第一 神经网络模型是学生网络模型。通过双模型之间的相互蒸馏,就可以 充分利用各客户端的全部数据的知识,提高算法模型在联邦框架下的 训练效果。
在上述联邦学习模型训练方法的优选实施方式中,完成了 一次训练,为了进一步提升训练效果,控制参数还可以包括训练次数n, n大于等于2,即按照同样的思路训练多次,直到符合预设效果。将更 新的第一神经网络模型的第一模型参数发送至服务器;接收来自服务 器的更新后的第二模型参数;根据更新后的第二模型参数得到更新后 的第二神经网络模型;利用知识蒸馏方法使更新后的第二神经网络模 型学习到更新的第一神经网络模型的知识,训练得到二次训练后的第 二神经网络模型,其中,更新后的第二神经网络模型是学生网络模型; 将二次训练后的第二神经网络模型的第二模型参数发送至服务器;如 此循环,直至n次训练结束。在第二次训练时,利用第一次训练更新 后的双模型,并且将学习网络和教师网络进行替换,进一步充分学习 了两个客户端数据的知识,同时汇聚了两个模型的优点。以此类推, 进行多次训练之后,可以获得更好的算法模型。
需要说明的是,第一神经网络模型和第二神经网络模型可 以是相同的模型,也可以不相同的模型。第一模型参数和第二模型参 数至少包括神经网络的权重参数。
在上述联邦学习模型训练方法的优选实施方式中,还可以 包括:知识蒸馏方法所采用的损失函数包括以下任意一种:均方误差 损失函数、平均绝对误差损失函数。
传统的联邦学习技术框架会由于参数更新方式、训练策略 以及数据孤立等问题,造成算法模型的训练效果以及性能要差于传统 的训练框架在全部数据上进行训练的效果。而通过上述实施方案,能 够解决数据孤岛和数据隐私的问题,可在数据不离本地的情况下完成 算法模型的训练及优化。并且能够支持双模型之间的相互蒸馏,充分 利用全部数据的知识,提高算法模型在联邦框架下的训练效果,同时 一次训练过程可以完成两个神经网络模型的训练。打破数据孤岛和传 统训练框架对算法能力造成的瓶颈。
下面结合图2和图3,详细说明根据本发明的另一实施例。
步骤21,server端(即服务器,可以是任意的节点)进行 训练初始化。在server配置神经网络结构模型、启动参数及训练参数, 进行训练初始化。将训练相关参数发送至各client端(即客户端)。在 实际应用时,客户端可以是银行系统的客户端,社保系统的客户端。 由于银行和社保系统都是需要高度安全和隐私,因此两端的数据无法 实现互通,通过本发明可以解决该问题。
本领域人员应该理解,这里的神经网络模型包括但不限于 YOLOv3,YOLOv4。
步骤22,client端启动并开始训练。Client端接收server 相关训练控制参数并进行训练启动,开始模型的训练。控制参数中可 以包括训练次数,例如epoch值。一个epoch就是使用训练集中的全部 样本训练一次。通俗的讲,Epoch的值就是整个训练数据集被反复使用 几次。Epoch数是一个超参数,它定义了学习算法在整个训练数据集中 的工作次数。
步骤S23,每个clinet端完成1个epoch训练后,将训练 完成的模型参数返回至server端。如图3所示,在客户端1中,利用 本地样本数据data1对model1进行训练,得到model1的模型参数,把 该模型参数发送至服务器进行更新。同理,在客户端2中,利用本地样本数据data2对model2进行训练,得到model2的模型参数,把该模 型参数发送至服务器进行更新。
步骤24,server端更新从client端获取的模型参数,并将 更新后的模型参数交换分发至对应的client端。服务器将model2的模 型参数发送给客客户端1,将model1的模型参数发送给客户端2。
步骤25,client端双模型相互蒸馏训练并返回模型参数至 server端。
在epoch 2训练阶段,每个client端都会存在一个学生网 络模型,并通过另一个模型(教师网络模型)进行知识蒸馏。每个客 户端在完成1个epoch的训练后将学生网络模型返回至server端。如图 3中所示,在epoch 2训练阶段,在客户端1中model1是学生网络模型, model2教师网络模型。将训练更新的model1的模型参数反馈至服务器。 同理,在客户端2中,model1是教师网络模型,model2是学生网络模 型。经过知识蒸馏训练后,将model2反馈至服务器。
在知识蒸馏过程中,学生网络通过优化的损失函数进行训 练。优化的损失函数为:
loss=lossA+λ·lossB
这里lossA为学生网络在标注数据下的检测损失,包括目标 框中心点损失,目标框大小损失,目标框是否存在目标的置信度损失 以及分类损失,具体为:lossA=lossxy+losswh+lossconf+losscls。
lossB为学生网络在从教师网络上提取监督信息时的蒸馏损 失,的lossB为
其中Mij为需要蒸馏区域对应的蒸馏mask,W、H、C分别为骨干 网每个阶段输出特征图对应的宽、高和通道数。Fs和Ft分别对应学生网 络和教师网络输出的特征图。N为蒸馏mask中值为1的数目,即λ为目标的检测损失和蒸馏损失之间的权重系数。在本实 施例中,在知识蒸馏损失方面,采用的是MSE loss(均方损失函数), 也可以采用MAEloss等其他损失函数。采用的MSE loss更容易获得较 为稳定的解,而且MSE loss相比于MAE loss更容易捕获到教师网络和 学生网络输出特征图两者差异的地方。
步骤26,server端更新参数,并进行模型选优保存。
server端在保存参数时,利用指标评估方法选取一个或多 个第一模型参数和第二模型参数进行保存更新。在评估时可以采用 mAP(mean Average Precision,不同召回率上的正确率的平均值),loss 等评估指标。
步骤27,server端发送模型至client端。
将server端将model2发送至客户端1,将model1发送给 客户端2。在每个客户端中,交换学生网络模型和教师网络模型,进行 知识蒸馏训练,例如在客户端1中,以model2为学生网络,model1为 教师网络,在客户端2中,以model1为学生网络,model2为教师网络。然后重复25,26步骤直至达到设计的训练epoch数据。通过交换的方 式,可以同时学习两个模型的精华知识。
在服务器侧,根据本发明的实施例的联邦学习模型训练方 法,包括:向第一客户端和第二客户端分别发送用于模型训练的控制 参数;接收来自第一客户端的经过更新的第一神经网络模型的第一模 型参数;接收来自第二客户端的经过更新的第二神经网络模型的第二 模型参数;将第一模型参数发送至第二客户端,将第二模型参数发送 至第一客户端;保存第一模型参数和第二模型参数。
服务器可以用于管理和交换多个客户端之间的模型参数, 在不传输数据的情况下,完成具备多端知识的模型训练,打破数据孤 岛的问题。
在上述联邦学习模型训练方法的优选实施方式中,控制参 数至少包括训练次数n,n大于等于2;将更新后的第一模型参数发送 至第二客户端;将更新后的第二模型参数发送至第一客户端;接收并 保存来自第二客户端的经过二次训练的第一模型参数;接收并保存来 自第一客户端的经过二次训练的第二模型参数;如此循环,直至n次 训练结束。
在上述联邦学习模型训练方法的优选实施方式中,在保存 第一模型参数和第二模型参数时,利用指标评估方法选取一个或多个 第一模型参数和第二模型参数进行保存更新;将选取保存的模型参数 发送至对应客户端。
继续参考图3,根据本发明的实施例的基于知识蒸馏的联 邦学习模型训练客户端31或者32,包括:通讯模块33,接收来自服 务器的用于模型训练的控制参数,以及接收来自服务器的第二神经网 络模型的第二模型参数;算法训练模块32,与通讯模块33连接,根据控制参数以及本地数据样本对初始的第一神经网络模型进行训练,得 到第一模型参数,根据第二模型参数得到第二神经网络模型,以及利 用知识蒸馏方法使第一神经网络模型学习到第二神经网络模型的知 识,训练得到更新的第一神经网络模型,其中,第一神经网络模型是 学生网络模型。
在上述优选实施方式中,控制参数至少包括训练次数n,n 大于等于2,通讯模块33还将得到的第一神经网络模型的模型参数发 送至服务器进行更新,以及接收来自服务器的更新后的第二模型参数, 以及将所二次训练后的第二神经网络模型的第二模型参数发送至服务 器;算法训练模块32还根据更新后的第二模型参数得到更新后的第二 神经网络模型,以及利用知识蒸馏方法使更新后的第二神经网络模型 学习到更新的第一神经网络模型的知识,训练得到二次训练后的第二 神经网络模型,其中,更新后的第二神经网络模型是学生网络模型, 如此循环,直至n次训练结束。图3中的数据加载模块可用于在训练 模型时加载样本数据data1。
继续参考图3,根据本发明的实施例的服务器(server端) 300,可以包括:训练控制模块36,生成用于模型训练的控制参数;通 讯模块39,向第一客户端和第二客户端分别发送控制参数,接收来自 第一客户端的经过更新的第一神经网络模型的第一模型参数,以及接 收来自第二客户端的经过更新的第二神经网络模型的第二模型参数; 参数更新模块37,保存所述第一模型参数和所述第二模型参数。日志 管理模块用于保存运行日志。
在上述优选实施方式中,包括:所述控制参数至少包括训 练次数n,n大于等于2;所述通讯模块还用于将更新后的第一模型参 数发送至所述第二客户端,将更新后的第二模型参数发送至所述第一 客户端,接收并保存来自所述第二客户端的经过二次训练的第一模型 参数,以及接收来自所述第一客户端的经过二次训练的第二模型参数, 如此循环,直至所述n次训练结束;所述参数更新模块保存所述经过 二次训练的第二模型参数。
在上述优选实施方式中,还可以包括:模型优选模块38, 与所述参数更新模块37和通讯模块39连接,在保存第一模型参数和 第二模型参数时,利用指标评估方法选取一个或多个第一模型参数和 第二模型参数进行保存更新;所述通讯模块将选取保存的模型参数发 送至对应客户端。
如图3所示,是根据本发明的实施例的基于知识蒸馏的联 邦学习模型训练系统,包括多个如上述任一技术方案中所述的基于知 识蒸馏的联邦学习模型训练客户端(客户端31、客户端32)和如上述 任一技术方案中所述的服务器300。
本发明进一步还提供了一种计算机可读存储介质,所述存 储介质中存储有多条程序代码,所述程序代码适用于由处理器加载并 运行以执行基于知识蒸馏的联邦学习模型训练方法和联邦学习模型训 练方法。
本发明将模型训练设置在本地,并通过中心服务器完成模 型参数交互,解决了数据孤岛和数据隐私的问题,可在数据不离本地 的情况下完成算法模型的训练及优化。能够支持双模型之间的相互蒸 馏,充分利用全部数据的知识,提高算法模型在联邦框架下的训练效 果,同时一次训练过程可以完成两个神经网络模型的训练,打破数据 孤岛和传统训练框架对算法能力造成的瓶颈。
至此,已经结合附图所示的一个实施方式描述了本发明的 技术方案,但是,本领域技术人员容易理解的是,本发明的保护范围 显然不局限于这些具体实施方式。在不偏离本发明的原理的前提下, 本领域技术人员可以对相关技术特征作出等同的更改或替换,这些更 改或替换之后的技术方案都将落入本发明的保护范围之内。
Claims (14)
1.一种基于知识蒸馏的联邦学习模型训练方法,其特征在于,包括:
接收服务器获取的用于模型训练的控制参数;
根据所述控制参数以及本地数据样本对初始的第一神经网络模型进行训练,得到第一模型参数;
将所述第一模型参数发送至所述服务器;
接收所述服务器获取的第二神经网络模型的第二模型参数;
利用知识蒸馏方法使所述第一神经网络模型学习到所述第二神经网络模型的知识,训练得到更新的第一神经网络模型,其中,所述第一神经网络模型是学生网络模型。
2.根据权利要求1所述的基于知识蒸馏的联邦学习模型训练方法,其特征在于,所述控制参数至少包括训练次数n,n大于等于2,还包括:
将所述更新的第一神经网络模型的第一模型参数发送至所述服务器;
接收所述服务器获取的更新后的第二模型参数;
根据所述更新后的第二模型参数得到更新后的第二神经网络模型;
利用知识蒸馏方法使所述更新后的第二神经网络模型学习到所述更新的第一神经网络模型的知识,训练得到二次训练后的第二神经网络模型,其中,所述更新后的第二神经网络模型是学生网络模型;
将所述二次训练后的第二神经网络模型的第二模型参数发送至所述服务器;
如此循环,直至所述n次训练结束。
3.根据权利要求1或2所述的基于知识蒸馏的联邦学习模型训练方法,其特征在于,所述第一模型参数和所述第二模型参数至少包括神经网络的权重参数。
4.根据权利要求1或2所述的基于知识蒸馏的联邦学习模型训练方法,其特征在于,还包括:所述知识蒸馏方法所采用的损失函数包括以下任意一种:均方误差损失函数、平均绝对误差损失函数。
5.一种联邦学习模型训练方法,其特征在于,包括:
向第一客户端和第二客户端分别发送用于模型训练的控制参数;
接收来自第一客户端的经过更新的第一神经网络模型的第一模型参数;
接收来自第二客户端的经过更新的第二神经网络模型的第二模型参数;
将所述第一模型参数发送至所述第二客户端,将所述第二模型参数发送至所述第一客户端;
保存所述第一模型参数和所述第二模型参数。
6.根据权利要求5所述的联邦学习模型训练方法,其特征在于,包括:所述控制参数至少包括训练次数n,n大于等于2;
将更新后的第一模型参数发送至所述第二客户端;
将更新后的第二模型参数发送至所述第一客户端;
接收并保存来自所述第二客户端的经过二次训练的第一模型参数;
接收并保存来自所述第一客户端的经过二次训练的第二模型参数;
如此循环,直至所述n次训练结束。
7.根据权利要求6所述的联邦学习模型训练方法,其特征在于,在保存第一模型参数和第二模型参数时,利用指标评估方法选取一个或多个第一模型参数和第二模型参数进行保存更新;
将选取保存的模型参数发送至对应客户端。
8.一种基于知识蒸馏的联邦学习模型训练客户端,其特征在于,包括:
通讯模块,接收来自服务器的用于模型训练的控制参数,以及接收来自所述服务器的第二神经网络模型的第二模型参数;
算法训练模块,与所述通讯模块连接,根据所述控制参数以及本地数据样本对初始的第一神经网络模型进行训练,得到第一模型参数,根据所述第二模型参数得到所述第二神经网络模型,以及利用知识蒸馏方法使所述第一神经网络模型学习到所述第二神经网络模型的知识,训练得到更新的第一神经网络模型,其中,所述第一神经网络模型是学生网络模型。
9.根据权利要求8所述的一种基于知识蒸馏的联邦学习模型训练客户端,其特征在于,所述控制参数至少包括训练次数n,n大于等于2,所述通讯模块还将得到的第一神经网络模型的模型参数发送至所述服务器进行更新,以及接收来自所述服务器的更新后的第二模型参数,以及将所二次训练后的第二神经网络模型的第二模型参数发送至所述服务器;
所述算法训练模块还根据所述更新后的第二模型参数得到更新后的第二神经网络模型,以及利用知识蒸馏方法使所述更新后的第二神经网络模型学习到所述更新的第一神经网络模型的知识,训练得到二次训练后的第二神经网络模型,其中,所述更新后的第二神经网络模型是学生网络模型,如此循环,直至所述n次训练结束。
10.一种服务器,其特征在于,包括:
训练控制模块,生成用于模型训练的控制参数;
通讯模块,向第一客户端和第二客户端分别发送所述控制参数,接收来自第一客户端的经过更新的第一神经网络模型的第一模型参数,以及接收来自第二客户端的经过更新的第二神经网络模型的第二模型参数;
参数更新模块,保存所述第一模型参数和所述第二模型参数。
11.根据权利要求10所述的服务器,其特征在于,包括:所述控制参数至少包括训练次数n,n大于等于2;
所述通讯模块还用于将更新后的第一模型参数发送至所述第二客户端,将更新后的第二模型参数发送至所述第一客户端,接收并保存来自所述第二客户端的经过二次训练的第一模型参数,以及接收来自所述第一客户端的经过二次训练的第二模型参数,如此循环,直至所述n次训练结束;
所述参数更新模块保存所述经过二次训练的第二模型参数。
12.根据权利要求11所述的服务器,其特征在于,还包括:
模型优选模块,与所述参数更新模块和通讯模块连接,在保存第一模型参数和第二模型参数时,利用指标评估方法选取一个或多个第一模型参数和第二模型参数进行保存更新;
所述通讯模块将选取保存的模型参数发送至对应客户端。
13.一种基于知识蒸馏的联邦学习模型训练系统,其特征在于,包括多个如权利要求8或9所述的基于知识蒸馏的联邦学习模型训练客户端和如权利要求10至12中任一项所述的服务器。
14.一种计算机可读存储介质,其特征在于,所述存储介质中存储有多条程序代码,所述程序代码适用于由处理器加载并运行以执行权利要求1至4中任一项所述的基于知识蒸馏的联邦学习模型训练方法和权利要求5至7中任一项所述的联邦学习模型训练方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202110391127.0A CN113298229A (zh) | 2021-04-12 | 2021-04-12 | 联邦学习模型训练方法、客户端、服务器及存储介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202110391127.0A CN113298229A (zh) | 2021-04-12 | 2021-04-12 | 联邦学习模型训练方法、客户端、服务器及存储介质 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN113298229A true CN113298229A (zh) | 2021-08-24 |
Family
ID=77319667
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202110391127.0A Pending CN113298229A (zh) | 2021-04-12 | 2021-04-12 | 联邦学习模型训练方法、客户端、服务器及存储介质 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN113298229A (zh) |
Cited By (6)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN113919508A (zh) * | 2021-10-15 | 2022-01-11 | 河南工业大学 | 一种基于移动式服务器的联邦学习系统及方法 |
CN114220438A (zh) * | 2022-02-22 | 2022-03-22 | 武汉大学 | 基于bottleneck和通道切分的轻量级说话人识别方法及系统 |
CN115775010A (zh) * | 2022-11-23 | 2023-03-10 | 国网江苏省电力有限公司信息通信分公司 | 基于横向联邦学习的电力数据共享方法 |
CN115907001A (zh) * | 2022-11-11 | 2023-04-04 | 中南大学 | 基于知识蒸馏的联邦图学习方法及自动驾驶方法 |
CN117094355A (zh) * | 2023-10-20 | 2023-11-21 | 网络通信与安全紫金山实验室 | 模型更新方法、非易失性存储介质及计算机设备 |
CN117829320A (zh) * | 2024-03-05 | 2024-04-05 | 中国海洋大学 | 一种基于图神经网络和双向深度知识蒸馏的联邦学习方法 |
Citations (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN109167695A (zh) * | 2018-10-26 | 2019-01-08 | 深圳前海微众银行股份有限公司 | 基于联邦学习的联盟网络构建方法、设备及可读存储介质 |
CN110572253A (zh) * | 2019-09-16 | 2019-12-13 | 济南大学 | 一种联邦学习训练数据隐私性增强方法及系统 |
CN110795477A (zh) * | 2019-09-20 | 2020-02-14 | 平安科技(深圳)有限公司 | 数据的训练方法及装置、系统 |
US20200302230A1 (en) * | 2019-03-21 | 2020-09-24 | International Business Machines Corporation | Method of incremental learning for object detection |
CN112580821A (zh) * | 2020-12-10 | 2021-03-30 | 深圳前海微众银行股份有限公司 | 一种联邦学习方法、装置、设备及存储介质 |
-
2021
- 2021-04-12 CN CN202110391127.0A patent/CN113298229A/zh active Pending
Patent Citations (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN109167695A (zh) * | 2018-10-26 | 2019-01-08 | 深圳前海微众银行股份有限公司 | 基于联邦学习的联盟网络构建方法、设备及可读存储介质 |
US20200302230A1 (en) * | 2019-03-21 | 2020-09-24 | International Business Machines Corporation | Method of incremental learning for object detection |
CN110572253A (zh) * | 2019-09-16 | 2019-12-13 | 济南大学 | 一种联邦学习训练数据隐私性增强方法及系统 |
CN110795477A (zh) * | 2019-09-20 | 2020-02-14 | 平安科技(深圳)有限公司 | 数据的训练方法及装置、系统 |
CN112580821A (zh) * | 2020-12-10 | 2021-03-30 | 深圳前海微众银行股份有限公司 | 一种联邦学习方法、装置、设备及存储介质 |
Non-Patent Citations (1)
Title |
---|
星辰大海与ZH: "Deep Mutual Learning", pages 1, Retrieved from the Internet <URL:https://zhuanlan.zhihu.com/p/86602170> * |
Cited By (9)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN113919508A (zh) * | 2021-10-15 | 2022-01-11 | 河南工业大学 | 一种基于移动式服务器的联邦学习系统及方法 |
CN114220438A (zh) * | 2022-02-22 | 2022-03-22 | 武汉大学 | 基于bottleneck和通道切分的轻量级说话人识别方法及系统 |
CN115907001A (zh) * | 2022-11-11 | 2023-04-04 | 中南大学 | 基于知识蒸馏的联邦图学习方法及自动驾驶方法 |
CN115907001B (zh) * | 2022-11-11 | 2023-07-04 | 中南大学 | 基于知识蒸馏的联邦图学习方法及自动驾驶方法 |
CN115775010A (zh) * | 2022-11-23 | 2023-03-10 | 国网江苏省电力有限公司信息通信分公司 | 基于横向联邦学习的电力数据共享方法 |
CN115775010B (zh) * | 2022-11-23 | 2024-03-19 | 国网江苏省电力有限公司信息通信分公司 | 基于横向联邦学习的电力数据共享方法 |
CN117094355A (zh) * | 2023-10-20 | 2023-11-21 | 网络通信与安全紫金山实验室 | 模型更新方法、非易失性存储介质及计算机设备 |
CN117094355B (zh) * | 2023-10-20 | 2024-03-29 | 网络通信与安全紫金山实验室 | 模型更新方法、非易失性存储介质及计算机设备 |
CN117829320A (zh) * | 2024-03-05 | 2024-04-05 | 中国海洋大学 | 一种基于图神经网络和双向深度知识蒸馏的联邦学习方法 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN113298229A (zh) | 联邦学习模型训练方法、客户端、服务器及存储介质 | |
CN109990790B (zh) | 一种无人机路径规划方法及装置 | |
US20190095780A1 (en) | Method and apparatus for generating neural network structure, electronic device, and storage medium | |
EP3528433B1 (en) | Data analyses using compressive sensing for internet of things (iot) networks | |
CN110598859B (zh) | 基于门控循环神经网络的非线性均衡方法 | |
CN113286275A (zh) | 一种基于多智能体强化学习的无人机集群高效通信方法 | |
CN113238867A (zh) | 一种基于网络卸载的联邦学习方法 | |
CN113657607A (zh) | 一种面向联邦学习的连续学习方法 | |
Karim et al. | Rl-ncs: Reinforcement learning based data-driven approach for nonuniform compressed sensing | |
CN109981596B (zh) | 一种主机外联检测方法及装置 | |
Fouda et al. | A lightweight hierarchical AI model for UAV-enabled edge computing with forest-fire detection use-case | |
Li et al. | Respipe: Resilient model-distributed dnn training at edge networks | |
CN116089652B (zh) | 视觉检索模型的无监督训练方法、装置和电子设备 | |
Gutierrez-Estevez et al. | Learning to communicate with intent: An introduction | |
CN116541779A (zh) | 个性化公共安全突发事件检测模型训练方法、检测方法及装置 | |
CN116433470A (zh) | 模型训练方法、数据增强方法、目标检测方法及相关设备 | |
CN115908522A (zh) | 基于终身学习的单目深度估计方法及相关设备 | |
CN115905978A (zh) | 基于分层联邦学习的故障诊断方法及系统 | |
CN115426635A (zh) | 一种不可靠传输场景下无人机通信网络推断方法及系统 | |
CN115134114A (zh) | 基于离散混淆自编码器的纵向联邦学习攻击防御方法 | |
CN115001937A (zh) | 面向智慧城市物联网的故障预测方法及装置 | |
CN114580661A (zh) | 基于联邦学习的数据处理方法、装置和计算机设备 | |
US20230018893A1 (en) | Multitask distributed learning system and method based on lottery ticket neural network | |
WO2022105374A1 (zh) | 信息处理方法、模型的生成及训练方法、电子设备和介质 | |
CN115631529B (zh) | 人脸特征隐私保护方法、人脸识别方法及装置 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |