CN114003949B - 基于隐私数据集的模型训练方法和装置 - Google Patents
基于隐私数据集的模型训练方法和装置 Download PDFInfo
- Publication number
- CN114003949B CN114003949B CN202111189306.2A CN202111189306A CN114003949B CN 114003949 B CN114003949 B CN 114003949B CN 202111189306 A CN202111189306 A CN 202111189306A CN 114003949 B CN114003949 B CN 114003949B
- Authority
- CN
- China
- Prior art keywords
- model
- public data
- target
- output
- data set
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F21/00—Security arrangements for protecting computers, components thereof, programs or data against unauthorised activity
- G06F21/60—Protecting data
- G06F21/62—Protecting access to data via a platform, e.g. using keys or access control rules
- G06F21/6218—Protecting access to data via a platform, e.g. using keys or access control rules to a system of files or objects, e.g. local or distributed file system or database
- G06F21/6245—Protecting personal data, e.g. for financial or medical purposes
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- General Health & Medical Sciences (AREA)
- Health & Medical Sciences (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Artificial Intelligence (AREA)
- Software Systems (AREA)
- Evolutionary Computation (AREA)
- Biomedical Technology (AREA)
- Computing Systems (AREA)
- Bioethics (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Mathematical Physics (AREA)
- Molecular Biology (AREA)
- Evolutionary Biology (AREA)
- Bioinformatics & Computational Biology (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Medical Informatics (AREA)
- Databases & Information Systems (AREA)
- Computer Hardware Design (AREA)
- Computer Security & Cryptography (AREA)
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
Abstract
本发明涉及多方数据合作的技术领域,提供一种基于隐私数据集的模型训练方法和装置。其中,方法包括:基于公开数据集和与公开数据集对应的真实标签,对服务器端模型进行训练;获取各个客户端发送的第一模型输出;第一模型输出是将公开数据集输入本地学习模型得到的;本地学习模型为基于隐私数据集和对应标签训练得到的;基于各第一模型输出的对应的公开数据,对服务器端模型进行训练;将公开数据集输入服务器端模型,得到第二模型输出;将第二模型输出发送至各客户端,以供各客户端基于第二模型输出和公开数据集,进行本地学习模型的再训练。如此在避免隐私数据集泄露的前提下,基于知识蒸馏和知识融合以隐私数据集为部分训练样本进行模型训练。
Description
技术领域
本发明涉及多方数据合作的技术领域,尤其涉及一种基于隐私数据集的模型训练方法和装置。
背景技术
在数据分析、数据挖掘、经济预测等领域,机器学习模型可被用来分析、发现潜在的数据价值。由于单个数据拥有方持有的数据可能是不完整的,由此难以准确地刻画目标,为了得到更好的模型预测结果,通过多个数据拥有方的数据合作,来进行模型的联合训练的方式得到了广泛的使用。但是在多方数据合作的过程中,涉及到数据安全和模型安全等问题。
特别是在医疗领域,一些数据集涉及隐私无法公开,只可以在医院内部使用。若想基于各个医院的隐私数据集搭建一个学习模型十分困难。现有的方案中,存在利用隐私数据集和将隐私数据集输入学习模型后得到的模型输出(一般为学习模型的最后一层神经网络的输出)而非模型结果和对应标签作为交换的信息,通过知识蒸馏和知识融合的方式进行模型的训练。但是这种方式下,不仅仍然存在隐私泄露的问题。
因此,目前缺少基于多方的隐私数据集进行模型训练方案。
发明内容
本发明实施例提供一种基于隐私数据集的模型训练方法和装置,用以解决现有缺少基于多方的隐私数据集进行模型训练方案问题。
第一方面,本发明实施例提供一种基于隐私数据集的模型训练方法,包括:
基于公开数据集和与所述公开数据集对应的真实标签,对服务器端模型进行训练;
获取各个客户端发送的第一模型输出;所述第一模型输出为客户端将所述公开数据集输入本地学习模型得到的;所述本地学习模型为客户端基于隐私数据集和对应标签对预设模型训练得到的;
基于与各所述第一模型输出对应的公开数据和各所述第一模型输出,对服务器端模型进行再训练;
将所述公开数据集输入再训练后的所述服务器端模型,得到第二模型输出;
将所述第二模型输出发送至各所述客户端,以供各所述客户端基于所述第二模型输出和所述公开数据集,进行所述本地学习模型的再训练。
可选的,所述基于公开数据集和与所述公开数据集对应的真实标签,对服务器端模型进行训练,包括:
将所述公开数据集输入服务器端模型得到预测结果;
基于所述预测结果与所述真实标签之间的交叉熵损失函数,对所述服务器端模型进行训练。
可选的,所述基于公开数据集和与所述公开数据集对应的真实标签,对服务器端模型进行训练,还包括:
确定并存储第一目标模型输出;所述第一目标模型输出为与目标公开数据对应的模型输出;所述目标公开数据为所述公开数据集中,被输入服务器端模型后得到的预测结果符合对应真实标签的公开数据;
确定目标待蒸馏公开数据;所述目标待蒸馏公开数据为所述公开数据集中,输入服务器端模型后得到的预测结果不符合对应真实标签的公开数据;
确定第一待蒸馏公开数据;所述第一待蒸馏公开数据为所述目标待蒸馏公开数据中,具有对应的第一目标模型输出的部分数据;
基于所述第一待蒸馏公开数据和与所述第一待蒸馏公开数据对应的第一目标模型输出,对所述服务器端模型进行训练。
可选的,所述获取各个客户端发送的第一模型输出,包括:
确定第二待蒸馏公开数据;所述第二待蒸馏公开数据为所述目标待蒸馏公开数据中,不具有对应的第一目标模型输出的部分数据;
向各所述客户端发送请求;所述请求用于请求客户端回传第一模型的第一输出;所述第一模型的第一输出为各本地学习模型的模型输出中对应所述第二待蒸馏公开数据的部分模型输出;
接收各所述客户端回传的第一模型的第一输出。
可选的,所述基于与各所述第一模型输出对应的公开数据和各所述第一模型输出,对服务器端模型进行再训练,包括:
对所述第一模型的第一输出进行筛选,得到第二目标模型输出;所述第二目标模型输出为所述第一模型的第一输出中对应的预测结果符合对应的真实标签的部分模型输出;
确定第三待蒸馏公开数据;所述第三待蒸馏公开数据为所述第二待蒸馏公开数据中具有对应的第二目标模型输出的部分数据;
基于第三待蒸馏公开数据和各所述第二目标模型输出,对服务器端模型进行再训练。
可选的,所述基于第三待蒸馏公开数据和各所述第二目标模型输出,对服务器端模型进行再训练,包括:
确定各所述第二目标模型输出中的模型输出的信息熵;
基于所述信息熵的大小确定各所述第二目标模型输出中的模型输出的权值;
基于所述权值对各所述第二目标模型输出进行融合,得到第三目标模型输出;
基于所述第三待蒸馏公开数据和所述第三目标模型输出,对服务器端模型进行再训练。
可选的,所述公开数据集和所述隐私数据集包括:与实体相关的图像数据、文本数据或声音数据。
第二方面,本发明实施例提供一种基于隐私数据集的模型训练装置包括:
第一训练单元,用于基于公开数据集和与所述公开数据集对应的真实标签,对服务器端模型进行训练;
获取单元,用于获取各个客户端发送的第一模型输出;所述第一模型输出为客户端将所述公开数据集输入本地学习模型得到的;所述本地学习模型为客户端基于隐私数据集和对应标签对预设模型训练得到的;
第二训练单元,用于基于与各所述第一模型输出对应的公开数据和各所述第一模型输出,对服务器端模型进行再训练;
输入单元,用于将所述公开数据集输入再训练后的所述服务器端模型,得到第二模型输出;
发送单元,用于将所述第二模型输出发送至各所述客户端,以供各所述客户端基于所述第二模型输出和所述公开数据集,进行所述本地学习模型的再训练。
第三方面,本发明实施例提供一种电子设备,包括存储器、处理器及存储在存储器上并可在处理器上运行的计算机程序,其特征在于,所述处理器执行所述程序时实现如本发明提供的基于隐私数据集的模型训练方法的步骤。
第四方面,本发明实施例提供一种非暂态计算机可读存储介质,其上存储有计算机程序,其特征在于,该计算机程序被处理器执行时实现如本发明提供的基于隐私数据集的模型训练方法的步骤。
本发明实施例提供一种基于隐私数据集的模型训练方法,通过公开数据集、第一模型输出和第二模型输出作为各个本地学习模型和服务器端模型信息交换的渠道和媒介,充分发挥服务器端模型的自主训练能力,基于各个第一模型输出进行知识蒸馏和知识融合,之后将融合后得到的知识,基于第二模型输出发送回各个本地学习模型,使得各个本地学习模型可以得到融合后的知识。即:通过公开数据集、第一模型输出和第二模型输出作为知识传输的媒介,将所有的知识都存储在一个强大的模型(服务器端模型)里作为通用的知识库来帮助联邦学习。服务器端模型不仅仅利用充分的计算资源去训练自身,同时也会将所有的客户端作为多个老师来学习知识,帮助服务器端的模型的效果进一步提升。作为回报,服务器端的积累的知识也会进一步传递给客户端来帮助所有客户端的本地学习模型效果提升,使得最后训练得到的各个本地学习模型包含多方隐私数据集的知识,即各个本地学习模型是基于多方隐私数据集训练得到的。如此,本发明实施例提供了一种可行的基于隐私数据集的模型训练方法,所述基于隐私数据集的模型训练方法可以具体应用于医疗领域中一些有关隐私数据的模型的训练。
附图说明
为了更清楚地说明本发明实施例或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作一简单地介绍,显而易见地,下面描述中的附图是本发明的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1为本发明实施例提供的基于隐私数据集的模型训练方法的流程示意图之一;
图2为本发明实施例提供的基于隐私数据集的模型训练方法的流程示意图之二;
图3为本发明实施例提供的基于隐私数据集的模型训练方法的流程示意图之三;
图4为本发明实施例提供的基于隐私数据集的模型训练方法的流程示意图之四;
图5为本发明实施例提供的基于隐私数据集的模型训练方法的流程示意图之五;
图6为本发明实施例提供的基于隐私数据集的模型训练装置的结构示意图;
图7为本发明实施例提供的电子设备的结构示意图。
具体实施方式
为使本发明实施例的目的、技术方案和优点更加清楚,下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有作出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
在经济、文化、教育、医疗、公共管理等各行各业充斥的大量信息数据,对其进行例如数据分析、数据挖掘、以及趋势预测等的数据处理在越来越多场景中广泛应用。其中,通过数据合作的方式可以使多个数据拥有方获得更好的数据处理结果。例如,可以通过多方数据的联合训练来获得更为准确的模型参数。
在一些实施例中,基于隐私数据进行模型的联合训练系统可以应用于在保证各方数据安全的情况下,各方协同训练机器学习模型供多方使用的场景。在这个场景中,多个数据方拥有自己的数据,他们想共同使用彼此的数据来统一建模(例如,线性回归模型、逻辑回归模型等),但并不想各自的数据(尤其是隐私数据)被泄露。例如,医院A拥有一批患者数据(例如患者病患部位的照片)因为患者隐私问题不适合公开,医院B拥有一批患者数据同样因为患者隐私问题不适合公开,基于医院A和医院B的患者数据确定的训练样本集可以训练得到比较好的机器学习模型。A和B都愿意通过彼此的患者数据共同参与模型训练,但是医院A和医院B需要保证患者数据不会遭到泄露,不可以或者不愿意让对方知道自己的患者数据。 因此需要一种基于隐私数据集的模型训练方法可以使多方的隐私数据在不受到泄露的情况下,通过多方数据的联合训练来得到共同使用的机器学习模型,达到一种共赢的合作状态。 基于此,本发明实施例基于知识蒸馏和联邦学习提供一种基于隐私数据集的模型训练方法和装置。
其中,在传统的联邦学习设置中,客户端上传模型参数或者模型梯度给中心服务器端,由服务器端按照一定的形式聚合后分发回客户端,并且在本地化数据上进一步更新。传递参数或者梯度会带来一系列隐私、异质性以及通讯成本的问题,目前有工作采用知识蒸馏的方式在终端和服务器端传递知识来解决。但是由于客户端实际上是资源受限的,直接在客户端使用大模型是不可能的,因此如何解决资源受限问题仍然是一个巨大的挑战。只有通过尽可能地去挖掘服务器端的计算资源,在服务器端利用辅助的大模型传输和累积知识,才能实现和用大模型进行中心化训练一样的知识融合的效果。
图1为本发明实施例提供的基于隐私数据集的模型训练方法的流程示意图之一,如图1所示,该方法包括:
步骤110,基于公开数据集和与所述公开数据集对应的真实标签,对服务器端模型进行训练;
其中,公开数据集和隐私数据集为同一类数据,只是公开数据集为可以进行公开的数据,隐私数据集为不可以或者不适合进行公开的数据。具体的,公开数据集和隐私数据集可以为与实体相关的图像数据、文本数据或声音数据。例如,一些医院的患者疾病图片数据,一些互联网公司的用户数据。所述服务器端模型为大模型,即服务器端模型较为复杂,可以尽可能的挖掘和学习知识。
步骤120,获取各个客户端发送的第一模型输出;所述第一模型输出为客户端将所述公开数据集输入本地学习模型得到的;所述本地学习模型为客户端基于隐私数据集和对应标签对预设模型训练得到的;
步骤130,基于与各所述第一模型输出对应的公开数据和各所述第一模型输出,对服务器端模型进行再训练;
步骤140,将所述公开数据集输入再训练后的所述服务器端模型,得到第二模型输出;
步骤150,将所述第二模型输出发送至各所述客户端,以供各所述客户端基于所述第二模型输出和所述公开数据集,进行所述本地学习模型的再训练。
通过公开数据集、第一模型输出和第二模型输出作为各个本地学习模型和服务器端模型信息交换的渠道和媒介,充分发挥服务器端模型的自主训练能力,基于各个第一模型输出进行知识蒸馏和知识融合,之后将融合后得到的知识,基于第二模型输出发送回各个本地学习模型,使得各个本地学习模型可以得到融合后的知识。即:通过公开数据集、第一模型输出和第二模型输出作为知识传输的媒介,将所有的知识都存储在一个强大的模型(服务器端模型)里作为通用的知识库来帮助联邦学习。服务器端模型不仅仅利用充分的计算资源去训练自身,同时也会将所有的客户端作为多个老师来学习知识,帮助服务器端的模型的效果进一步提升。作为回报,服务器端的积累的知识也会进一步传递给客户端来帮助所有客户端的本地学习模型效果提升,使得最后训练得到的各个本地学习模型包含多方隐私数据集的知识,即各个本地学习模型是基于多方隐私数据集训练得到的。
本发明实施例提供的方案中,服务器端模型作为知识聚合的中心,其学习的知识直接影响了最终各个本地学习模型基于第二模型输出获取的知识;因此服务器端模型的训练是比较重要的一部分。
具体的,步骤110中“基于公开数据集和与所述公开数据集对应的真实标签,对服务器端模型进行训练”和步骤130“基于与各所述第一模型输出对应的公开数据和各所述第一模型输出,对服务器端模型进行再训练”是服务器端模型的训练部分:
服务器端模型的训练主要分为3部分:初步训练、自蒸馏和聚集蒸馏(再训练)。需要说明的是,这3个部分并非是严格按照时间顺序执行的,而是互相融合进行的。
参照图2,初步训练、自蒸馏的步骤具体如下: 步骤111,将所述公开数据集输入服务器端模型得到预测结果;
步骤112,基于所述预测结果与所述真实标签之间的交叉熵损失函数,对所述服务器端模型进行训练。
这一部分的训练比较常规,简单地采用预测结果和真实标签之间的交叉熵损失函数来训练服务器端模型。具体的可以参照一些现有的训练实施例。
步骤113,确定并存储第一目标模型输出;所述第一目标模型输出为与目标公开数据对应的模型输出;所述目标公开数据为所述公开数据集中,被输入服务器端模型后得到的预测结果符合对应真实标签的公开数据;
具体的,步骤114,进行第一目标模型输出的存储,为步骤114、步骤115和步骤116中得自蒸馏进行准备。将本次预测正确的模型输出(即:第一目标模型输出)保存到全局的模型输出中作为记忆,帮助之后纠正预测错误但曾经做对过的样本。
关于自蒸馏的具体说明如下:针对公开数据集中模型预测错误的样本(即目标待蒸馏公开数据),我们首先去寻找全局的模型输出的记忆(即:第一目标模型输出)中是否存在该样本对应的模型输出,如果存在的话说明这部分的知识模型曾经是包含的,因此温习自身曾经掌握的知识可以帮助模型纠正自身的错误,针对该思路我们采用了自蒸馏的方式对模型进行蒸馏训练,用目前的模型输出去接近之前做对的时候的模型输出,同时结合交叉熵损失。具体步骤如下:
步骤114,确定目标待蒸馏公开数据;所述目标待蒸馏公开数据为所述公开数据集中,输入服务器端模型后得到的预测结果不符合对应真实标签的公开数据;
步骤115,确定第一待蒸馏公开数据;所述第一待蒸馏公开数据为所述目标待蒸馏公开数据中,具有对应的第一目标模型输出的部分数据;
步骤116,基于所述第一待蒸馏公开数据和与所述第一待蒸馏公开数据对应的第一目标模型输出,对所述服务器端模型进行训练。
通过上述方式完成服务器端模型的自蒸馏。对于目标待蒸馏公开数据中除了第一待蒸馏公开数据的其他数据进行聚焦蒸馏;聚焦蒸馏是本发明实施例方案的一个核心,只要用于获取其他的客户端的隐私数据集的知识。
参照图3,进行聚集蒸馏之前需要执行步骤120中“获取各个客户端发送的第一模型输出”,具体步骤如下:
步骤121,客户端基于隐私数据集和对应标签对预设的预设服务器端模型训练;
步骤122,客户端将所述公开数据集输入本地学习模型得到的模型输出;
步骤123,确定第二待蒸馏公开数据;所述第二待蒸馏公开数据为所述目标待蒸馏公开数据中,不具有对应的第一目标模型输出的部分数据;
步骤124,向各所述客户端发送请求;
步骤125;客户端回传第一模型的第一输出;其中,所述第一模型的第一输出为各本地学习模型的模型输出中对应所述第二待蒸馏公开数据的部分模型输出;
步骤126,接收各所述客户端回传的第一模型的第一输出。
如此设置,各个客户端发送的数据为进行聚集蒸馏使用的数据,将各个基于隐私数据集训练的到的本地学习模型包含的知识,基于这些第一模型输出将这些知识发送至服务器端模型。如此设置,不仅仅避免了知识的传输过程中的隐私数据泄露的问题,还减少了需要传输的数据的量。
针对服务器端从始至终从未做对过的样本(即:第二待蒸馏公开数据),我们认为服务器端暂时不具备仅依靠自身预测正确的能力,因此选择聚集来自客户端的知识来帮助引导服务器端学习。首先我们从所有的客户端选择出能预测正确答案的模型,然后根据模型的输出的信息熵高低,以信息熵越高则其相应的置信度越低为原则,对其进行加权。
具体的,参照图4,进行聚集蒸馏,步骤主要包括:
步骤131,对所述第一模型的第一输出进行筛选,得到第二目标模型输出;所述第二目标模型输出为所述第一模型的第一输出中对应的预测结果符合对应的真实标签的部分模型输出;
这一步骤的目的是剔除第一模型输出中,无法对的第二待蒸馏公开数据的训练起到好的教导作用的模型输出。
步骤132,确定第三待蒸馏公开数据;所述第三待蒸馏公开数据为所述第二待蒸馏公开数据中具有对应的第二目标模型输出的部分数据;
步骤133,确定各所述第二目标模型输出中的模型输出的信息熵;
步骤134,基于所述信息熵的大小确定各所述第二目标模型输出中的模型输出的权值;
步骤135,基于所述权值对各所述第二目标模型输出进行融合,得到第三目标模型输出;
步骤136,基于所述第三待蒸馏公开数据和所述第三目标模型输出,对服务器端模型进行再训练。即:利用加权得到的模型输出结合交叉熵损失对服务器端进行蒸馏。
其中,步骤133到步骤136,基于第三待蒸馏公开数据和各所述第二目标模型输出,对服务器端模型进行再训练。在具体的融合过程中,根据模型的输出的信息熵高低对其进行加权,认为信息熵越高则其相应的置信度越低,有选择的进行知识的融合。之后执行步骤140 和步骤150完成本地学习模型的再训练。
基于上述方案,本发明实施例提供一种新颖的方法,采用有选择地知识融合的方式将所有的知识都存储在一个强大的模型里作为通用的知识库来帮助联邦学习。服务器端模型不仅仅利用充分的计算资源去训练自身,同时也会将所有的客户端作为多个老师来学习知识,帮助服务器端的模型的效果进一步提升。作为回报,服务器端的积累的知识也会进一步传递给客户端来帮助所有客户端本地学习模型效果提升。与此同时,还能增加两端模型的鲁棒性,并且减少从客户端上传知识到服务器端的通讯成本。
下面结合具体的实施例对本发明实施例提供的方案进行说明:
参照图5,基于隐私数据集的模型训练系统包括:一个服务器端和多个客户端(图5中以医院A和医院B来表示多个客户端)
步骤501:医院A基于隐私数据集A和对应标签对预设模型训练得到本地学习模型A;
步骤502:医院B基于隐私数据集B和对应标签对预设模型训练得到本地学习模型B;
步骤503:服务器端基于公开数据集和与所述公开数据集对应的真实标签,对服务器端模型进行训练;
其中,隐私数据集A、隐私数据集B和公开数据集为患者伤患处的图片,本发明实施例的主要目的是得到一种可以识别伤患,对伤患进行预测的模型;
步骤504:将公开数据集输入服务器端模型进行预测;
步骤505:将预测正确的模型输出保存到全局的模型输出的记忆中;
步骤506:基于全局的模型输出的记忆,对部分预测错误的样本进行自蒸馏;
步骤507:获取医院A发送的模型输出A;
步骤508:获取医院B发送的模型输出B;
其中,模型输出A是将公开数据集输入本地学习模型A得到的;模型输出B是将公开数据集输入本地学习模型B得到的;
步骤509:对模型输出A和对模型输出B进行剔除操作和加权融合。
步骤510:基于融合后的模型输出,对部分预测错误的样本进行聚集蒸馏;
需说明的是,进行聚集蒸馏的数据可以为多张图片,模型输出A和模型输出B中具有针对每一张进行聚集蒸馏的图片的模型输出;在进行融合和剔除时,应该一张图片一张图片的进行。即首先确定进行聚集蒸馏的一张图片,之后找出医院A和医院B获取对应这张图片的模型输出;判断这两个模型输出得到的预测结果是否与真实标签匹配,如果匹配,则确定这两种图片的信息熵,基于信息熵的高低对其进行加权,认为信息熵越高则其相应的置信度越低。
步骤511:将公开数据集输入服务器端模型进行预测得到第二模型输出;
步骤512:发送第二模型输出至医院A;
步骤513:基于第二模型输出训练本地学习模型A
步骤514:发送第二模型输出至医院B;
步骤515:基于第二模型输出训练本地学习模型B。
如此循环进行,采用有选择地知识融合的方式将所有的知识都存储在一个强大的模型里作为通用的知识库来帮助联邦学习。之后传递给客户端医院A和医院B来帮助提示本地学习模型A本地学习模型B的效果。使得医院A和医院B在不泄露自身隐私数据集的情况下,进行联合训练分别得到实际预测效果较好的本地学习模型A本地学习模型B。
基于上述任一实施例,图6为本发明实施例提供的基于隐私数据集的模型训练装置的结构示意图,如图6所示,该装置包括:
第一训练单元61,用于基于公开数据集和与所述公开数据集对应的真实标签,对服务器端模型进行训练;
获取单元62,用于获取各个客户端发送的第一模型输出;所述第一模型输出为客户端将所述公开数据集输入本地学习模型得到的;所述本地学习模型为客户端基于隐私数据集和对应标签对预设模型训练得到的;
第二训练单元63,用于基于与各所述第一模型输出对应的公开数据和各所述第一模型输出,对服务器端模型进行再训练;
输入单元64,用于将所述公开数据集输入再训练后的所述服务器端模型,得到第二模型输出;
发送单元65,用于将所述第二模型输出发送至各所述客户端,以供各所述客户端基于所述第二模型输出和所述公开数据集,进行所述本地学习模型的再训练。
其中,第一训练单元61,具体用于:
将所述公开数据集输入服务器端模型得到预测结果;
基于所述预测结果与所述真实标签之间的交叉熵损失函数,对所述服务器端模型进行训练。
确定并存储第一目标模型输出;所述第一目标模型输出为与目标公开数据对应的模型输出;所述目标公开数据为所述公开数据集中,被输入服务器端模型后得到的预测结果符合对应真实标签的公开数据;
确定目标待蒸馏公开数据;所述目标待蒸馏公开数据为所述公开数据集中,输入服务器端模型后得到的预测结果不符合对应真实标签的公开数据;
确定第一待蒸馏公开数据;所述第一待蒸馏公开数据为所述目标待蒸馏公开数据中,具有对应的第一目标模型输出的部分数据;
基于所述第一待蒸馏公开数据和与所述第一待蒸馏公开数据对应的第一目标模型输出,对所述服务器端模型进行训练。
其中,所述获取各个客户端发送的第一模型输出,包括:
确定第二待蒸馏公开数据;所述第二待蒸馏公开数据为所述目标待蒸馏公开数据中,不具有对应的第一目标模型输出的部分数据;
向各所述客户端发送请求;所述请求用于请求客户端回传第一模型输出;所述第一模型的第一输出为各本地学习模型的模型输出中对应所述第二待蒸馏公开数据的部分模型输出;
接收各所述客户端回传的第一模型的第一输出。
可选的,第二训练单元63,具体用于:
对所述第一模型的第一输出进行筛选,得到第二目标模型输出;所述第二目标模型输出为所述第一模型的第一输出中对应的预测结果符合对应的真实标签的部分模型输出;
确定第三待蒸馏公开数据;所述第三待蒸馏公开数据为所述第二待蒸馏公开数据中具有对应的第二目标模型输出的部分数据;
基于第三待蒸馏公开数据和各所述第二目标模型输出,对服务器端模型进行再训练。
可选的,所述基于第三待蒸馏公开数据和各所述第二目标模型输出,对服务器端模型进行再训练,包括:
确定各所述第二目标模型输出中的模型输出的信息熵;
基于所述信息熵的大小确定各所述第二目标模型输出中的模型输出的权值;
基于所述权值对各所述第二目标模型输出进行融合,得到第三目标模型输出;
基于所述第三待蒸馏公开数据和所述第三目标模型输出,对服务器端模型进行再训练。
可选的,所述公开数据集和所述隐私数据集包括:与实体相关的图像数据、文本数据或声音数据。
图7为本发明实施例提供的电子设备的结构示意图,如图7所示,该电子设备可以包括:处理器(processor)710、通信接口(Communications Interface)720、存储器(memory)730和通信总线740,其中,处理器710,通信接口720,存储器730通过通信总线740完成相互间的通信。处理器710可以调用存储器730中的逻辑命令,以执行如下方法:基于公开数据集和与所述公开数据集对应的真实标签,对服务器端模型进行训练;获取各个客户端发送的第一模型输出;所述第一模型输出为客户端将所述公开数据集输入本地学习模型得到的;所述本地学习模型为客户端基于隐私数据集和对应标签对预设模型训练得到的;基于与各所述第一模型输出对应的公开数据和各所述第一模型输出,对服务器端模型进行再训练;将所述公开数据集输入再训练后的所述服务器端模型,得到第二模型输出;将所述第二模型输出发送至各所述客户端,以供各所述客户端基于所述第二模型输出和所述公开数据集,进行所述本地学习模型的再训练。
此外,上述的存储器730中的逻辑命令可以通过软件功能单元的形式实现并作为独立的产品销售或使用时,可以存储在一个计算机可读取存储介质中。基于这样的理解,本发明的技术方案本质上或者说对现有技术做出贡献的部分或者该技术方案的部分可以以软件产品的形式体现出来,该计算机软件产品存储在一个存储介质中,包括若干命令用以使得一台计算机设备(可以是个人计算机,服务器,或者网络设备等)执行本发明各个实施例所述方法的全部或部分步骤。而前述的存储介质包括:U盘、移动硬盘、只读存储器(ROM,Read-Only Memory)、随机存取存储器(RAM,Random Access Memory)、磁碟或者光盘等各种可以存储程序代码的介质。
本发明实施例还提供一种非暂态计算机可读存储介质,其上存储有计算机程序,该计算机程序被处理器执行时实现以执行上述各实施例提供的方法,例如包括:基于公开数据集和与所述公开数据集对应的真实标签,对服务器端模型进行训练;获取各个客户端发送的第一模型输出;所述第一模型输出为客户端将所述公开数据集输入本地学习模型得到的;所述本地学习模型为客户端基于隐私数据集和对应标签对预设模型训练得到的;基于与各所述第一模型输出对应的公开数据和各所述第一模型输出,对服务器端模型进行再训练;将所述公开数据集输入再训练后的所述服务器端模型,得到第二模型输出;将所述第二模型输出发送至各所述客户端,以供各所述客户端基于所述第二模型输出和所述公开数据集,进行所述本地学习模型的再训练。
以上所描述的装置实施例仅仅是示意性的,其中所述作为分离部件说明的单元可以是或者也可以不是物理上分开的,作为单元显示的部件可以是或者也可以不是物理单元,即可以位于一个地方,或者也可以分布到多个网络单元上。可以根据实际的需要选择其中的部分或者全部模块来实现本实施例方案的目的。本领域普通技术人员在不付出创造性的劳动的情况下,即可以理解并实施。
通过以上的实施方式的描述,本领域的技术人员可以清楚地了解到各实施方式可借助软件加必需的通用硬件平台的方式来实现,当然也可以通过硬件。基于这样的理解,上述技术方案本质上或者说对现有技术做出贡献的部分可以以软件产品的形式体现出来,该计算机软件产品可以存储在计算机可读存储介质中,如ROM/RAM、磁碟、光盘等,包括若干命令用以使得一台计算机设备(可以是个人计算机,服务器,或者网络设备等)执行各个实施例或者实施例的某些部分所述的方法。
最后应说明的是:以上实施例仅用以说明本发明的技术方案,而非对其限制;尽管参照前述实施例对本发明进行了详细的说明,本领域的普通技术人员应当理解:其依然可以对前述各实施例所记载的技术方案进行修改,或者对其中部分技术特征进行等同替换;而这些修改或者替换,并不使相应技术方案的本质脱离本发明各实施例技术方案的精神和范围。
Claims (8)
1.一种基于隐私数据集的模型训练方法,其特征在于,包括:
基于公开数据集和与所述公开数据集对应的真实标签,对服务器端模型进行训练;
获取各个客户端发送的第一模型输出;所述第一模型输出为客户端将所述公开数据集输入本地学习模型得到的;所述本地学习模型为客户端基于隐私数据集和对应标签对预设模型训练得到的;
基于与各所述第一模型输出对应的公开数据和各所述第一模型输出,对服务器端模型进行再训练;
将所述公开数据集输入再训练后的所述服务器端模型,得到第二模型输出;
将所述第二模型输出发送至各所述客户端,以供各所述客户端基于所述第二模型输出和所述公开数据集,进行所述本地学习模型的再训练;
所述基于公开数据集和与所述公开数据集对应的真实标签,对服务器端模型进行训练,包括:
将所述公开数据集输入服务器端模型得到预测结果;
基于所述预测结果与所述真实标签之间的交叉熵损失函数,对所述服务器端模型进行训练;
所述基于公开数据集和与所述公开数据集对应的真实标签,对服务器端模型进行训练,还包括:
确定并存储第一目标模型输出;所述第一目标模型输出为与目标公开数据对应的模型输出;所述目标公开数据为所述公开数据集中,被输入服务器端模型后得到的预测结果符合对应真实标签的公开数据;
确定目标待蒸馏公开数据;所述目标待蒸馏公开数据为所述公开数据集中,输入服务器端模型后得到的预测结果不符合对应真实标签的公开数据;
确定第一待蒸馏公开数据;所述第一待蒸馏公开数据为所述目标待蒸馏公开数据中,具有对应的第一目标模型输出的部分数据;
基于所述第一待蒸馏公开数据和与所述第一待蒸馏公开数据对应的第一目标模型输出,对所述服务器端模型进行训练。
2.根据权利要求1所述的基于隐私数据集的模型训练方法,其特征在于,所述获取各个客户端发送的第一模型输出,包括:
确定第二待蒸馏公开数据;所述第二待蒸馏公开数据为所述目标待蒸馏公开数据中,不具有对应的第一目标模型输出的部分数据;
向各所述客户端发送请求;所述请求用于请求客户端回传第一模型的第一输出;所述第一模型的第一输出为各本地学习模型的模型输出中对应所述第二待蒸馏公开数据的部分模型输出;
接收各所述客户端回传的第一模型的第一输出。
3.根据权利要求2所述的基于隐私数据集的模型训练方法,其特征在于,所述基于与各所述第一模型输出对应的公开数据和各所述第一模型输出,对服务器端模型进行再训练,包括:
对所述第一模型的第一输出进行筛选,得到第二目标模型输出;所述第二目标模型输出为所述第一模型的第一输出中对应的预测结果符合对应的真实标签的部分模型输出;
确定第三待蒸馏公开数据;所述第三待蒸馏公开数据为所述第二待蒸馏公开数据中具有对应的第二目标模型输出的部分数据;
基于第三待蒸馏公开数据和各所述第二目标模型输出,对服务器端模型进行再训练。
4.根据权利要求3所述的基于隐私数据集的模型训练方法,其特征在于,所述基于第三待蒸馏公开数据和各所述第二目标模型输出,对服务器端模型进行再训练,包括:
确定各所述第二目标模型输出中的模型输出的信息熵;
基于所述信息熵的大小确定各所述第二目标模型输出中的模型输出的权值;
基于所述权值对各所述第二目标模型输出进行融合,得到第三目标模型输出;
基于所述第三待蒸馏公开数据和所述第三目标模型输出,对服务器端模型进行再训练。
5.根据权利要求1所述的基于隐私数据集的模型训练方法,其特征在于,所述公开数据集和所述隐私数据集包括:与实体相关的图像数据、文本数据或声音数据。
6.一种基于隐私数据集的模型训练装置,其特征在于,包括:
第一训练单元,用于基于公开数据集和与所述公开数据集对应的真实标签,对服务器端模型进行训练;
获取单元,用于获取各个客户端发送的第一模型输出;所述第一模型输出为客户端将所述公开数据集输入本地学习模型得到的;所述本地学习模型为客户端基于隐私数据集和对应标签对预设模型训练得到的;
第二训练单元,用于基于与各所述第一模型输出对应的公开数据和各所述第一模型输出,对服务器端模型进行再训练;
输入单元,用于将所述公开数据集输入再训练后的所述服务器端模型,得到第二模型输出;
发送单元,用于将所述第二模型输出发送至各所述客户端,以供各所述客户端基于所述第二模型输出和所述公开数据集,进行所述本地学习模型的再训练;
其中,第一训练单元,具体用于:
将所述公开数据集输入服务器端模型得到预测结果;基于所述预测结果与所述真实标签之间的交叉熵损失函数,对所述服务器端模型进行训练;确定并存储第一目标模型输出;所述第一目标模型输出为与目标公开数据对应的模型输出;所述目标公开数据为所述公开数据集中,被输入服务器端模型后得到的预测结果符合对应真实标签的公开数据;确定目标待蒸馏公开数据;所述目标待蒸馏公开数据为所述公开数据集中,输入服务器端模型后得到的预测结果不符合对应真实标签的公开数据;确定第一待蒸馏公开数据;所述第一待蒸馏公开数据为所述目标待蒸馏公开数据中,具有对应的第一目标模型输出的部分数据;基于所述第一待蒸馏公开数据和与所述第一待蒸馏公开数据对应的第一目标模型输出,对所述服务器端模型进行训练。
7.一种电子设备,包括存储器、处理器及存储在存储器上并可在处理器上运行的计算机程序,其特征在于,所述处理器执行所述程序时实现如权利要求1至5中任一项所述的基于隐私数据集的模型训练方法的步骤。
8.一种非暂态计算机可读存储介质,其上存储有计算机程序,其特征在于,该计算机程序被处理器执行时实现如权利要求1至5中任一项所述的基于隐私数据集的模型训练方法的步骤。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
PCT/CN2022/085131 WO2023050754A1 (zh) | 2021-09-30 | 2022-04-02 | 隐私数据集的模型训练方法和装置 |
Applications Claiming Priority (2)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN2021111656796 | 2021-09-30 | ||
CN202111165679 | 2021-09-30 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN114003949A CN114003949A (zh) | 2022-02-01 |
CN114003949B true CN114003949B (zh) | 2022-08-30 |
Family
ID=79922769
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202111189306.2A Active CN114003949B (zh) | 2021-09-30 | 2021-10-12 | 基于隐私数据集的模型训练方法和装置 |
Country Status (2)
Country | Link |
---|---|
CN (1) | CN114003949B (zh) |
WO (1) | WO2023050754A1 (zh) |
Families Citing this family (6)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN114003949B (zh) * | 2021-09-30 | 2022-08-30 | 清华大学 | 基于隐私数据集的模型训练方法和装置 |
CN115238826B (zh) * | 2022-09-15 | 2022-12-27 | 支付宝(杭州)信息技术有限公司 | 一种模型训练的方法、装置、存储介质及电子设备 |
CN115270001B (zh) * | 2022-09-23 | 2022-12-23 | 宁波大学 | 基于云端协同学习的隐私保护推荐方法及系统 |
CN115578369B (zh) * | 2022-10-28 | 2023-09-15 | 佐健(上海)生物医疗科技有限公司 | 一种基于联邦学习的在线宫颈细胞tct切片检测方法和系统 |
CN116797829B (zh) * | 2023-06-13 | 2024-06-14 | 北京百度网讯科技有限公司 | 一种模型生成方法、图像分类方法、装置、设备及介质 |
CN117313869B (zh) * | 2023-10-30 | 2024-04-05 | 浙江大学 | 一种基于模型分割的大模型隐私保护推理方法 |
Citations (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN112329052A (zh) * | 2020-10-26 | 2021-02-05 | 哈尔滨工业大学(深圳) | 一种模型隐私保护方法及装置 |
CN112862011A (zh) * | 2021-03-31 | 2021-05-28 | 中国工商银行股份有限公司 | 基于联邦学习的模型训练方法、装置及联邦学习系统 |
WO2021184836A1 (zh) * | 2020-03-20 | 2021-09-23 | 深圳前海微众银行股份有限公司 | 识别模型的训练方法、装置、设备及可读存储介质 |
Family Cites Families (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US11580453B2 (en) * | 2020-02-27 | 2023-02-14 | Omron Corporation | Adaptive co-distillation model |
CN113052334B (zh) * | 2021-04-14 | 2023-09-29 | 中南大学 | 一种联邦学习实现方法、系统、终端设备及可读存储介质 |
CN113222175B (zh) * | 2021-04-29 | 2023-04-18 | 深圳前海微众银行股份有限公司 | 信息处理方法及系统 |
CN114003949B (zh) * | 2021-09-30 | 2022-08-30 | 清华大学 | 基于隐私数据集的模型训练方法和装置 |
-
2021
- 2021-10-12 CN CN202111189306.2A patent/CN114003949B/zh active Active
-
2022
- 2022-04-02 WO PCT/CN2022/085131 patent/WO2023050754A1/zh active Application Filing
Patent Citations (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
WO2021184836A1 (zh) * | 2020-03-20 | 2021-09-23 | 深圳前海微众银行股份有限公司 | 识别模型的训练方法、装置、设备及可读存储介质 |
CN112329052A (zh) * | 2020-10-26 | 2021-02-05 | 哈尔滨工业大学(深圳) | 一种模型隐私保护方法及装置 |
CN112862011A (zh) * | 2021-03-31 | 2021-05-28 | 中国工商银行股份有限公司 | 基于联邦学习的模型训练方法、装置及联邦学习系统 |
Also Published As
Publication number | Publication date |
---|---|
WO2023050754A1 (zh) | 2023-04-06 |
CN114003949A (zh) | 2022-02-01 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN114003949B (zh) | 基于隐私数据集的模型训练方法和装置 | |
CN111275207B (zh) | 基于半监督的横向联邦学习优化方法、设备及存储介质 | |
CN111291897B (zh) | 基于半监督的横向联邦学习优化方法、设备及存储介质 | |
CN110472647B (zh) | 基于人工智能的辅助面试方法、装置及存储介质 | |
CN111860829A (zh) | 联邦学习模型的训练方法及装置 | |
CN110797124A (zh) | 一种模型多端协同训练方法、医疗风险预测方法和装置 | |
CN113408209A (zh) | 跨样本联邦分类建模方法及装置、存储介质、电子设备 | |
US11615463B2 (en) | Artificial intelligence based digital leasing assistant | |
CN108229718A (zh) | 一种信息预测方法及装置 | |
CN116664930A (zh) | 基于自监督对比学习的个性化联邦学习图像分类方法及系统 | |
CN115098692B (zh) | 跨域推荐方法、装置、电子设备及存储介质 | |
CN113742488B (zh) | 基于多任务学习的嵌入式知识图谱补全方法和装置 | |
CN113391992B (zh) | 测试数据的生成方法和装置、存储介质及电子设备 | |
CN113014566A (zh) | 恶意注册的检测方法、装置、计算机可读介质及电子设备 | |
CN110049006A (zh) | 实时与非实时结合的多人远程咨询系统与方法 | |
CN116644167A (zh) | 目标答案的生成方法和装置、存储介质及电子装置 | |
CN114358250A (zh) | 数据处理方法、装置、计算机设备、介质及程序产品 | |
CN115563259A (zh) | 一种多模态问答数据采集方法及装置 | |
JPWO2020032125A1 (ja) | 議論支援装置および議論支援装置用のプログラム | |
Serhani et al. | Dynamic Data Sample Selection and Scheduling in Edge Federated Learning | |
Agarwal et al. | A novel approach to big data veracity using crowdsourcing techniques and Bayesian predictors | |
CN116415064A (zh) | 双目标域推荐模型的训练方法及装置 | |
Velagapudi et al. | FedDHr: Improved Adaptive Learning Strategy Using Federated Learning for Image Processing | |
CN114528392A (zh) | 一种基于区块链的协同问答模型构建方法、装置及设备 | |
CN115114467B (zh) | 图片神经网络模型的训练方法以及装置 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
CB03 | Change of inventor or designer information | ||
CB03 | Change of inventor or designer information |
Inventor after: Liu Yang Inventor before: Liu Yang Inventor before: Cheng Sijie Inventor before: Wu Jingwen |
|
GR01 | Patent grant | ||
GR01 | Patent grant |