CN115049076A - 基于原型网络的迭代聚类式联邦学习方法 - Google Patents
基于原型网络的迭代聚类式联邦学习方法 Download PDFInfo
- Publication number
- CN115049076A CN115049076A CN202210824020.5A CN202210824020A CN115049076A CN 115049076 A CN115049076 A CN 115049076A CN 202210824020 A CN202210824020 A CN 202210824020A CN 115049076 A CN115049076 A CN 115049076A
- Authority
- CN
- China
- Prior art keywords
- user
- global
- prototype network
- prototype
- model
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N20/00—Machine learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Software Systems (AREA)
- Physics & Mathematics (AREA)
- Mathematical Physics (AREA)
- Artificial Intelligence (AREA)
- Data Mining & Analysis (AREA)
- Evolutionary Computation (AREA)
- Computing Systems (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Medical Informatics (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Health & Medical Sciences (AREA)
- Life Sciences & Earth Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- General Health & Medical Sciences (AREA)
- Molecular Biology (AREA)
- Information Retrieval, Db Structures And Fs Structures Therefor (AREA)
Abstract
本公开提供了一种基于原型网络的迭代聚类式联邦学习方法,该方法包括:用户端接收由服务器端发送的本轮全局原型网络;用户端根据本地样本数据集,训练全局原型网络以及确定用户嵌入表示向量和用户类别分布,其中,每一个用户具有一个用户嵌入表示向量;用户端向服务器端发送训练后的全局原型网络、用户嵌入表示向量、用户类别分布,以使服务器端执行如下操作:根据用户嵌入表示向量和用户类别分布,确定基于训练后的全局原型网络的模型距离;根据模型距离确定用户聚类结果,进而执行全局模型参数聚合,得到最新的全局模型参数;循环执行T轮上述步骤,直至利用基于原型网络的迭代聚类式联邦学习方法训练的全局原型网络满足预设收敛条件。
Description
技术领域
本公开涉及联邦学习与分布式个性化模型训练领域,具体地涉及一种基于原型网络的迭代聚类式联邦学习方法。
背景技术
随着大众数据隐私意识的觉醒与相关法律的出台,机器学习的数据隐私安全问题逐渐凸显,这导致收集大量且有标注的数据集变得十分困难。联邦学习的出现使得终端用户利用本地数据联合训练机器学习模型成为可能。在传统联邦学习算法的范式中,所有用户分布式地联合训练一个唯一的全局模型,该模型随后被部署于所有的终端用户。
在实现本公开发明构思的过程中,发明人发现相关技术中至少存在如下技术问题:用户之间的数据异构问题使得仅仅训练一个全局模型不能符合所有用户的数据分布,而且目前的聚类式联邦学习方法尚未考虑到用户之间类别分布异构,即类别分布不平衡且不一致的问题,使得对用户的聚类划分准确率低,不能表现出优异的用户聚类效果,不能满足个性化联邦学习的需求。
发明内容
鉴于上述问题,本公开提供了一种个性化联邦学习方法,可以适用于用户类别分布异构场景以及提高用户聚类效果的基于原型网络的迭代聚类式联邦学习方法。该方法基于数据分布差异对用户群体进行聚类,实现了个性化的联邦学习训练,可以适用于用户类别分布异构场景,提高聚类式联邦学习方法的用户聚类效果。
本公开提供了一种基于原型网络的迭代聚类式联邦学习方法(PN-ICFL),包括:步骤S1,用户端接收由服务器端发送的本轮全局原型网络;步骤S2,上述用户端根据本地样本数据集,训练上述全局原型网络以及确定用户嵌入表示向量和用户类别分布,其中,每一个用户具有一个上述用户嵌入表示向量;步骤S3,上述用户端向上述服务器端发送训练后的全局原型网络、上述用户嵌入表示向量、上述用户类别分布,以使上述服务器端执行如下操作:根据上述用户嵌入表示向量和上述用户类别分布,确定基于上述训练后的全局原型网络的模型距离;根据上述模型距离确定用户聚类结果,进而执行全局模型参数聚合,得到最新的全局模型参数,其中,上述全局模型包括上述全局原型网络模型和多个全局类原型矩阵;步骤S4,循环执行T轮的步骤S1至步骤S3,直至利用基于原型网络的迭代聚类式联邦学习方法训练的全局原型网络满足预设收敛条件,其中,T为大于1的整数。
根据本公开的实施例,上述服务器端还用于执行如下操作:获取初始全局原型网络;利用公有样本数据集对上述初始全局原型进行预训练,得到上述全局原型网络。
根据本公开的实施例,上述用户嵌入表示向量包括上述本地样本数据集中的每一类样本数据集在上述全局原型网络中的类原型。
根据本公开的实施例,上述服务器端还包括多个聚类簇;上述根据上述用户嵌入标识和上述用户类别分布,确定基于上述训练后的全局原型网络的模型距离包括:根据上述用户类别分布、上述本地样本数据集中的每一类样本数据集在上述全局原型网络中的类原型、上述聚类簇中的全局类原型矩阵确定上述用户与上述聚类簇之间的模型距离。
根据本公开的实施例,上述服务器端还包括如下操作:基于最短模型距离的原则,确定上述用户与上述聚类簇之间最短的模型距离;根据上述最短的模型距离,确定目标聚类簇;根据上述目标聚类簇,确定上述用户聚类结果;根据用户聚类结果,进行全局模型参数聚合,得到最新的全局模型参数。
根据本公开的实施例,用户端根据本地样本数据集,进行局部更新,训练上述全局原型网络包括:根据元学习训练方法,确定元学习任务数量;根据上述元学习任务数量,划分上述本地样本数据集;从划分后的上述本地样本数据集中选取样本数据组成支持集数据和查询集数据;根据上述支持集数据和上述查询集数据训练上述全局原型网络。
根据本公开的实施例,通过为每一个用户计算一个用户嵌入表示向量,而且服务器端可以根据用户嵌入表示向量和用户类别分布,确定基于训练后的全局原型网络的模型距离;根据模型距离确定用户聚类结果。因为在计算模型距离的过程中考虑到了用户嵌入表示向量和用户类别分布,并根据得到的模型距离进行用户聚类划分,可以有效应对用户类别分布不平衡且不一致的问题,因此可以适用于用户分布异构场景,实现个性化的联邦训练。至少部分地克服了相关技术中由于未考虑到用户之间类别分布异构而导致的聚类效果准确率低的问题,达到了优化用户聚类效果,提高个性化模型训练性能的技术效果。
附图说明
通过以下参照附图对本公开实施例的描述,本公开的上述内容以及其他目的、特征和优点将更为清楚,在附图中:
图1示意性示出了根据本公开实施例的基于原型网络的迭代聚类式联邦学习方法的系统架构图;
图2示意性示出了根据本公开实施例的基于原型网络的迭代聚类式联邦学习方法的流程图;
图3示意性示出了根据本公开实施例的基于原型网络的模型距离的示意图;
图4示意性示出了根据本公开实施例的基于原型网络的迭代聚类式联邦学习方法的框架示意图;
图5示意性示出了根据本公开实施例的本地测试样本数据集;
图6示意性示出了根据本公开实施例的在小样本数据场景中平均测试准确率随通信轮数变化的曲线图。
具体实施方式
以下,将参照附图来描述本公开的实施例。但是应该理解,这些描述只是示例性的,而并非要限制本公开的范围。在下面的详细描述中,为便于解释,阐述了许多具体的细节以提供对本公开实施例的全面理解。然而,明显地,一个或多个实施例在没有这些具体细节的情况下也可以被实施。此外,在以下说明中,省略了对公知结构和技术的描述,以避免不必要地混淆本公开的概念。
在此使用的术语仅仅是为了描述具体实施例,而并非意在限制本公开。在此使用的术语“包括”、“包含”等表明了所述特征、步骤、操作和/或部件的存在,但是并不排除存在或添加一个或多个其他特征、步骤、操作或部件。
在此使用的所有术语(包括技术和科学术语)具有本领域技术人员通常所理解的含义,除非另外定义。应注意,这里使用的术语应解释为具有与本说明书的上下文相一致的含义,而不应以理想化或过于刻板的方式来解释。
在使用类似于“A、B和C等中至少一个”这样的表述的情况下,一般来说应该按照本领域技术人员通常理解该表述的含义来予以解释(例如,“具有A、B和C中至少一个的系统”应包括但不限于单独具有A、单独具有B、单独具有C、具有A和B、具有A和C、具有B和C、和/或具有A、B、C的系统等)。在使用类似于“A、B或C等中至少一个”这样的表述的情况下,一般来说应该按照本领域技术人员通常理解该表述的含义来予以解释(例如,“具有A、B或C中至少一个的系统”应包括但不限于单独具有A、单独具有B、单独具有C、具有A和B、具有A和C、具有B和C、和/或具有A、B、C的系统等)。
随着大众数据隐私意识的觉醒与相关法律的出台,机器学习的数据隐私安全问题逐渐凸显,这导致收集大量且有标注的数据集变得十分困难。联邦学习的出现使得终端用户利用本地数据联合训练机器学习模型成为可能。在传统联邦学习算法的范式中,所有用户分布式地联合训练一个唯一的全局模型,该模型随后被部署于所有的终端用户。然而,用户之间的数据异构问题使得仅仅训练一个全局模型不能符合所有用户的数据分布。因此,近年来一些个性化联邦学习方法被陆续提出。
现实场景中,地域,文化,性别等个性化偏向可能导致用户产生不同的数据分布,聚类式联邦学习考虑到此因素,将用户群体看作是由多个不同的全局特征分布所产生,进而根据用户数据分布的相似性进行用户聚类,实现聚类簇级别的个性化训练。现有的聚类式联邦学习算法可以大致分为两类,分别为层级聚类式联邦学习算法和迭代聚类式联邦学习算法。层级聚类式联邦学习算法可以包括IEEE Transactions on neural networks andlearning systems 2020期刊上发表的CFL(Clustered Federated Learning,聚类式联邦学习)框架,DASFAA 2020会议上发表的CIF-FL(Class Imbalance-Aware ClusteredFederated Learning,类别不平衡感知的聚类式联邦学习)框架;迭代聚类式联邦学习算法主要为NIPS 2020上发表的IFCA(Iterative Federated Clustering Algorithm,迭代联邦聚类算法)框架。这些算法均采用一定的技术在未直接访问用户数据的条件下,实现用户聚类以及个性化训练。然而,现有的层级聚类式联邦学习算法需要多阶段联邦训练,计算代价较高,超参数较多,实时性差。迭代式方法IFCA框架符合传统FedAvg(联邦学习)算法的训练范式,由于需要终端用户同时对所有的全局簇模型进行本地测试,也带来了额外的通信代价。目前的聚类式联邦学习方法还尚未考虑到用户之间类别分布异构,即类别分布不平衡且不一致的问题(也称为Label Non-IID)。此外,在用户样本稀少的小样本数据场景中,现有的上述方法也不能表现出优异的用户聚类效果,不能满足个性化联邦学习的要求。
有鉴于此,本公开提供了一种个性化联邦学习方法,具体地,提供了一种基于原型网络的迭代聚类式联邦学习方法,以适用于用户分布异构场景,优化对用户的聚类效果。该方法可以包括:步骤S1,用户端接收由服务器端发送的本轮全局原型网络;步骤S2,用户端根据本地样本数据集,训练全局原型网络以及确定用户嵌入表示向量和用户类别分布,其中,每一个用户具有一个用户嵌入表示向量;步骤S3,用户端向服务器端发送训练后的全局原型网络、用户嵌入表示向量、用户类别分布,以使服务器端执行如下操作:根据用户嵌入表示向量和用户类别分布,确定基于训练后的全局原型网络的模型距离;根据模型距离确定用户聚类结果;根据用户聚类结果,进行全局模型参数聚合,得到最新的全局模型参数,其中,全局模型包括全局原型网络和多个全局类原型矩阵;步骤S4,循环执行T轮的步骤S1至步骤S3,直至利用基于原型网络的迭代聚类式联邦学习方法训练的全局原型网络满足预设收敛条件,其中,T为大于1的整数。
需要说明的是,在本公开的技术方案中,所涉及的用户个人信息的收集、存储、使用、加工、传输、提供、公开和应用等处理,均符合相关法律法规的规定,采取了必要保密措施,且不违背公序良俗。在本公开的技术方案中,在获取或采集用户个人信息之前,均获取了用户的授权或同意。
图1示意性示出了根据本公开实施例的基于原型网络的迭代聚类式联邦学习方法的系统架构图。
如图1所示,根据该实施例的系统架构100可以包括终端设备101、102、103,用户端服务器104、云端服务器105。
终端设备101、102、103可以是公共环境或各种独立环境中的各种电子设备,包括但不限于智能手机、平板电脑、膝上型便携计算机和台式计算机等等。
终端设备上可以安装有各种通讯客户端应用,例如基于联邦学习模型对用户进行聚类的应用、购物类应用、网页浏览器应用、搜索类应用、即时通信工具、邮箱客户端、社交平台软件等(仅为示例)。
用户可以通过在终端设备101、102、103的客户端应用中进行操作,例如物品推荐、物品识别等等,基于此操作,终端设备101、102、103可以向用户端服务器104发送对用户进行聚类的请求,以便用户端服务器104可以根据基于原型网络的迭代聚类式联邦学习方法对基于联邦学习模型的物品推荐模型、物品识别模型进行训练。
用户端服务器104在终端设备101、102、103和云端服务器105之间可以通过有线、无线通信链路或者光纤电缆等等。
用户端服务器104和云端服务器105可以是提供各种服务的服务器。在一实施例中,用户端服务器104可以接收由云端服务器发送的本轮全局原型网络;用户端服务器104根据本地样本数据集训练本轮全局原型网络,并确定用户嵌入表示向量和用户类别分布;用户端服务器104将训练后的全局原型网络和确定的用户嵌入表示向量、用户类别分布上传至云端服务器105;云端服务器105可以根据用户嵌入表示向量和用户类别分布确定基于训练后的全局原型网络的模型聚类,并根据模型距离确定用户的聚类结果;根据用户聚类结果,进行全局模型参数聚合,得到最新的全局模型参数。循环执行T轮的上述方法,直至利用基于原型网络的迭代聚类式联邦学习方法训练的联邦学习模型满足预设收敛条件。用户端服务器104还可以将用户聚类结果反馈至终端设备101、102、103。
需要说明的是,本公开实施例所提供的基于原型网络的迭代聚类式联邦学习方法一般可以由用户端服务器104执行。本公开实施例所提供的基于原型网络的迭代聚类式联邦学习方法也可以由不同于用户端服务器104且能够与终端设备101、102、103和/或用户端服务器104通信的服务器或服务器集群执行。
应该理解,图1中的终端设备、网络和服务器的数目仅仅是示意性的。根据实现需要,可以具有任意数目的终端设备、网络和服务器。
以下将基于图1描述的场景,通过图2~图6对公开实施例的基于原型网络的联邦学习方法进行详细描述。
图2示意性示出了根据本公开实施例的基于原型网络的联邦学习方法的流程图。
如图2所示,该实施例的基于原型网络的联邦学习方法包括步骤S1~步骤S4。
在步骤S1,用户端接收由服务器端发送的本轮全局原型网络。
在步骤S2,用户端根据本地样本数据集,训练全局原型网络以及确定用户嵌入表示向量和用户类别分布,其中,每一个用户具有一个用户嵌入表示向量。
在步骤S3,用户端向服务器端发送训练后的全局原型网络、用户嵌入表示向量、用户类别分布,以使服务器端执行如下操作:根据用户嵌入表示向量和用户类别分布,确定基于训练后的全局原型网络的模型距离;根据模型距离确定用户聚类结果;根据用户聚类结果,进行全局模型参数聚合,得到最新的全局模型参数,其中,全局模型包括全局原型网络和多个全局类原型矩阵。
在步骤S4,循环执行T轮的步骤S1至步骤S3,直至利用基于原型网络的迭代聚类式联邦学习方法训练的全局原型网络满足预设收敛条件,其中,T为大于1的整数。
根据本公开的实施例,基于原型网络的联邦学习方法可以执行多轮。本轮全局原型网络可以是在每一轮的训练过程中,用户端从服务器端下载或接收的当前所在轮的全局原型网络。
根据本公开的实施例,本地样本数据集可以是多分类数据集,本地样本数据集中可以包括用户的兴趣爱好、用户在终端设备上的行为操作数据、多种复杂类型的物体的特征数据等,本地数据集还可以根据实际应用场景进行适应性调整。
根据本公开的实施例,用户嵌入表示向量可以是第一训练样本数据集中的每一类样本数据集在全局原型网络中的类原型。用户类别分布可以是本地样本数据集的类别分布。该学习方法可以基于由全局原型网络产出的用户类原型为每个用户构造一个用户嵌入表示向量,利于实现对用户的个性化联邦训练。
根据本公开的实施例,用户端还可以对接收到的原型网络进行训练,以更新该全局原型网络,并将更新后的全局原型网络再上传至服务器端。
根据本公开的实施例,模型距离可以是用户与聚类簇之间的距离,用于度量用户与聚类簇之间的数据(特征)分布差异性。在模型距离最短的情况下,可以将与该最短模型距离对应的聚类簇作为对用户进行聚类后的用户聚类结果。
根据本公开的实施例,利用本公开实施例提供的基于原型网络的联邦学习方法还可以训练其他模型,例如可以训练各类物品推荐模型、图像识别模型、文字识别模型等等。
根据本公开的实施例,预设收敛条件可以根据损失函数进行判断,在损失函数的损失结果表示为收敛的条件下,可以认为满足预设收敛条件。在本公开的实施例中,可以定义原型网络为hθ:其中D与M分别为原始样本特征大小和原型网络空间特征大小,θ为原型网络的参数。损失函数可以如公式(1)所示。
其中,表示第j个簇的全局类原型矩阵,为该簇内第k个类别的类原型。{S1,…,SK}为当前聚类划分。基于原型网络原始优化目标,当机器学习任务为多分类时,可以定义每个簇对应的损失函数Fj(θ,Ej,Sj)如公式(2)所示。
其中,hθ(x)为原始样本x在原型网络空间中的特征表示,d(,)函数为距离度量函数,为当前簇Sj内所有用户局部数据集中第k类样本的集合,与分别为第j个簇内的第k与k′个类别的类原型。随后,假设距离度量函数d(,)为欧式距离的平方,采用期望最大化(EM)算法的思想,交替地优化公式(1)中的模型参数θ,E1,…,Ek与隐变量{S1,…,SK}。
根据本公开的实施例,预设收敛条件还可以是联邦学习模型所对应的损失函数,在联邦学习模型所对应的损失函数的损失结果表示为收敛的条件下,也可以结束基于原型网络的联邦学习方法对联邦学习模型的训练。
根据本公开的实施例,还可以设置预设迭代轮数,当循环执行了预设迭代轮数后,也可以结束利用基于原型网络的联邦学习方法对联邦学习模型的训练。预设迭代轮数还可以根据实际需要进行适应性调整。
根据本公开的实施例,在基于原型网络的迭代聚类式联邦学习方法执行结束后,可以同时输出用户聚类结果与训练好的原型网络的模型参数。
根据本公开的实施例,通过为每一个用户计算一个用户嵌入表示向量,而且服务器端可以根据用户嵌入表示向量和用户类别分布,确定基于训练后的全局原型网络的模型距离;根据模型距离确定用户聚类结果。因为在计算模型距离的过程中考虑到了用户嵌入表示向量和用户类别分布,并根据得到的模型距离进行用户聚类划分,可以有效应对用户类别分布不平衡且不一致的问题,因此可以适用于由于用户数据量较少、身份差异或样本获取能力有限等因素导致的用户类别分布异构场景,实现个性化的联邦训练。至少部分地克服了相关技术中由于未考虑到用户之间类别分布异构而导致的聚类效果准确率低的问题,达到了优化用户聚类效果,提高个性化模型训练性能的技术效果。
根据本公开的实施例,在一些物品推荐模型的训练任务中,各用户由于地域、年龄差异,对同一事物可以有不同的评价或标签,例如不同年龄段的人可能对同一电影评价不同;地域差异也可能会导致用户对不同口味的食物的喜爱程度不同。利用本公开实施例提供的个性化联邦训练方法,即利用基于原型网络的迭代聚类式联邦学习方法在对物品推荐模型进行训练的过程中,可以根据用户的评价习惯、行为数据等操作将用户群体从年龄或地域的角度聚类为不同的簇,以便准确地对用户进行分类,进而可以提高物品推荐模型推荐物品或推荐信息的准确率。
根据本公开的实施例,在一些轻量级的图像识别模型或文字识别模型的训练任务中,移动端设备(例如手机)由于计算与存储能力的有限,通常只能部署轻量级的神经网络,然而用户的个人数据由于地域、身份或采样设备等差异可能存在不同的风格,数据分布复杂多样,例如各种样式类型的数字、字符、花草识别、宠物识别或其他图片识别等。轻量级模型的模型复杂度有限,很难同时拟合所有训练数据的分布,导致该移动端设备的轻量级模型预测准确率较低。利用本公开实施例提供的个性化联邦训练方法,即利用基于原型网络的迭代聚类式联邦学习方法可以在数据分布复杂的情况下对不同类型的数据进行建模,准确地对不同类型的数据进行聚类,提高轻量级模型的预测准确率。
根据本公开的实施例,在有对抗用户存在的情况下,例如,部分场景中,在隐私约束的情况下,分布式训练中存在行为异常的用户时,该类行为异常的用户可能会通过意图错误标注样本等恶意行文,影响最终训练得到的联邦模型的决策,该类用户一般被称为对抗用户。在这种情况下,利用本公开实施例提供的个性化联邦训练方法,即利用基于原型网络的迭代聚类式联邦学习方法可以将用户群体准确地聚类为正确用户与对抗用户,可以起到防御恶意攻击的作用。
根据本公开的实施例,服务器端还可以获取初始全局原型网络;利用公有样本数据集对初始全局原型进行训练,得到全局原型网络。
根据本公开的实施例,初始全局原型网络可以包括4个卷积块,每个卷积块可以包括一个卷积层、一个批标准化层、一个ReLU激活函数层。服务器端可以利用第二训练样本数据集对原型网络进行预训练,具体地,可以根据元学习训练方法,以元学习任务为训练单位进行预训练,得到原型网络参数θ(0)。对初始全局原型网络的预训练,可以使得原型网络成为一个在其他相关任务上具有泛化能力的空间映射函数,该函数对随后进行联邦训练的数据,可以起到相同类别的样本之间的欧式距离相对较近的作用。利用预训练得到的全局原型网络可以作为基于原型网络的迭代聚类式联邦学习框架的初始参数。
根据本公开的实施例,用户嵌入表示向量可以是本地样本数据集中的每一类样本数据集在全局原型网络中的类原型。具体地,用户i的用户嵌入表示向量可以如公式(3)所示。
根据本公开的实施例,服务器端包括多个聚类簇;根据用户嵌入标识和用户类别分布,确定基于训练后的全局原型网络的模型距离包括:根据用户类别分布、本地样本数据集中的每一类样本数据集在训练后的全局原型网络中的类原型、聚类簇中的全局类原型矩阵确定用户与聚类簇之间的模型距离。
图3示意性示出了根据本公开实施例的基于原型网络的模型距离的示意图。
如图3所示,服务器端可以根据用户i的用户嵌入表示向量ui与类别分布向量pi(y)分别计算用户i与K个聚类簇或全局簇模型之间的距离,还可以理解为计算用户i与(全局)类原型矩阵{Ej}的距离di1,di2,…,diK。模型距离的具体计算公式可以如公式(5)所示。
其中,dij为用户i与第j个全局簇模型之间的模型距离,pi(y=k)为用户i本地数据集的类别分布向量,C为类别总数,表示用户i的本地数据集中第k类样本子集在原型网络θ空间中的类原型,为第j个全局簇模型中的第k个类原型。
其中,类别分布向量pi(y)=[pi(y=1),…,pi(y=K)]T,可以基于用户本的每类样本占本地样本数据集中所有样本的经验比例计算得到,类别分布向量pi(y)的计算可过可以如公式(6)所示。
由于该模型距离dij显式地考虑了用户的类别分布pi(y=k),因此可以有效应对用户本地数据集类别分布不平衡且不一致,即类别分布异构的场景,进而可以实现个性化的联邦训练。
根据本公开的实施例,操作S230中还可以包括:基于最短模型距离的原则,确定用户与聚类簇之间最短的模型距离;根据最短的模型距离,确定目标聚类簇;根据目标聚类簇,确定用户聚类结果;根据用户聚类结果,进行全局模型参数聚合,得到最新的全局模型参数。
根据本公开的实施例,根据最短的模型距离,可以确定相关联的聚类簇的标识,例如聚类标识,具体地,可以通过最短模型距离,将与该模型距离对应的全局类原型矩阵认为是目标全局类原型矩阵,因为每个聚类簇均可对应一个全局类原型矩阵,将与该目标全局类原型矩阵对应的聚类簇认为是目标聚类簇,用户聚类结果可以是该目标聚类簇,还可以理解的是,用户聚类结果是与该样本用户对应的聚类簇,用户可以被聚类到该聚类簇中。
根据本公开的实施例,根据本地样本数据集,训练全局原型网络包括:根据元学习训练方法,确定元学习任务数量;根据元学习任务数量,划分本地样本数据集;从划分后的本地样本数据集中选取样本数据组成支持集数据和查询集数据;根据支持集数据和查询集数据训练全局原型网络。
根据本公开的实施例,在更新全局原型网络时,可以输入原型网络参数θ,学习率μ,迭代轮数E,用户元学习任务数量B;输出更新后的原型网络参数θ,类别分布向量pi(y)。具体地,元学习任务可以包括支持集和查询集在训练时数据集可以划分成B个小样本元学习任务,根据划分后的数据集选取支持集和查询集中分别用到的训练样本组成支持集数据和查询集数据,根据支持集数据基于支持集计算类原型,根据支查询集数据基于查询集计算与训练该原型网络相关的损失函数损失函数的计算过程可以如公式(7)所示,再基于该损失函数更新原型网络,更新过程可以如公式(8)所示。最后,用户端可以将每个用户的ui、pi(y)、和更新后的原型网络参数θ(t)上传给服务器端。
图4示意性示出了根据本公开实施例的基于原型网络的迭代聚类式联邦学习方法的框架示意图。
如图4所示,图中可以表示在每一轮联邦训练中,用户端可以首先下载当前的全局原型网络θ(t),计算用户嵌入表示向量ui并进行局部更新(图中最右侧)。在该框架中,用户端的操作包括计算用户嵌入表示和局部更新原型网络参数,服务器端的操作包括预训练原型网络、计算用户与聚类簇之间的模型距离、估计聚类标识、与模型参数聚合。具体地,每个用户可以在本地以元学习的训练方式更新原型网络θ。服务器在接收到用户上传的后,为每一个“用户i-聚类簇j”对计算模型距离dij、估计聚类标识si,再基于模型距离、聚类标识执行模型聚合(图中最左侧)。
图5示意性示出了根据本公开实施例的本地测试样本数据集。图中包含四种数字数据集,分别为两种手写体数据集MNIST和USPS,一种现实街景数字数据集SVHN,和一种手语数字数据集SIGN。每一种数据集代表了聚类式联邦学习问题中一个全局特征分布,即一种个性化偏向。每个用户的本地数据集可以生成自这四种数据集的其中一种,服务器端也可以对用户进行采样,得到这四种数据集的其中一种或多种数据。本公开的基于原型网络的迭代聚类式联邦学习方法还可以作为其他基于联邦学习的识别模型的一个预处理环节,可以对各种类型的数字、文本、字符、宠物、花草等物体的特征进行聚类,以便提高识别模型识别的准确率。
根据本公开的实施例,聚类式联邦学习主要用于解决数据分布异构中的特征分布异构问题。特征分布异构是指联邦学习中用户数据分布的来源不同,也可以看作数据的类条件分布,或称为特征分布的不同,如图5所示的多源数字数据集,包含4种不同形式的数字数据集,其中每一种可以视为一种个性化偏向。本公开的目标则是在未知用户真实特征分布的前提下,对用户群体进行聚类,使得特征分布相同的用户聚集到同一个聚类簇,进行实现个性化的联邦训练。
根据本公开的实施例,基于原型网络的迭代聚类式联邦学习方法还可以是如下操作:输入簇的个数K,用户个数N,通信轮数T,局部更新轮数E,局部学习率μ,用户元学习任务数量B,共有数据集G;输出的可以是一个全局原型网络θ(T),K个全局类原型矩阵用户群体的聚类划分结果{S1,…,SK}。具体地,在基于原型网络的迭代聚类式联邦学习模型的训练过程中,服务器端可以使用公有数据集G预训练初始原型网络,得到θ(0);服务器端初始化K个全局类原型矩阵服务器端进行用户采样,得到当前轮参与训练的用户集合具体地,采样策略还可以根据实际情况进行适应性调整;用户从服务器端下载当前原型网络参数θ(t),计算自身的用户嵌入表示ui,以元学习训练的方式对原型网络进行局部更新,可以参考上述更新原型网络的方法,之后,用户端可以将每个用户的ui、pi(y)、和更新后的原型网络参数θ(t)上传给服务器端;服务器收到用户端上传的每个用户的ui、pi(y)、和更新后的原型网络参数θ(t),依次计算用户与聚类簇之间的模型距离、估计用户聚类标识、与基于用户聚类结果进行全局模型参数聚合,全局模型参数包括全局原型网络和全局类原型矩阵,基于用户聚类结果进行全局模型参数聚合可以包括:基于用户聚类结果进行全局类原型矩阵的参数聚合以及基于用户聚类结果进行全局原型网络的参数聚合,其中,估计用户聚类标识的过程可以如公式(9)所示,基于用户聚类结果进行的全局类原型矩阵的参数聚合的过程可以如公式(10)所示,全局原型网络的参数聚合过程可以如公式(11)所示;在确定了本轮的模型距离、聚类标识、并基于聚类结果进行模型参数聚合后可以进入下一轮联邦训练。在联邦训练满足收敛条件后,可以结束此次的联邦训练,输出一个全局原型网络θ(T),K个全局类原型矩阵用户群体的聚类划分结果{S1,…,SK}。
si←argminj∈[K]dij (9)
其中,si表示聚类标识,argmin函数用于确定最短的模型距离。
其中,t表示迭代轮数,表示t+1轮第j个全局簇模型的全局类原型矩阵,为t+1轮第j个全局簇模型中的第k个类原型,pi(y=k)与pr(y=k)分别表示用户i与用户r的类别分布,表示在t轮用户i在原型网络θ空间中的类原型,Sj表示该轮第j个聚类簇的用户集合,即当前聚类划分。
根据本公开的实施例,基于原型网络的迭代聚类式联邦学习方法可以输出一个全局原型网络θ(T),K个全局类原型矩阵以及用户群体的聚类划分结果{S1,…,SK}。在模型训练时,用户i可以从服务器端下载第si个簇的全局类原型矩阵和全局原型网络参数θ(T)。测试样本x经过原型网络映射得到特征表示hθ(x),然后基于最短距离原则对用户进行类别预测。对于未参与联邦训练的新用户,在利用该基于原型网络的迭代聚类式联邦学习方法对该新用户进行聚类时,可以首先使用本地数据集计算该新用户嵌入表示向量、用户类别分布,再由服务器端根据该用户嵌入表示向量和用户类别分布为该新用户估计聚类标识,再根据聚类标识得到的全局类原型矩阵发送给用户端,也可以理解为将与该新用户对应的聚类簇反馈至用户端。该新用户随后也可以作为测试样本参与联邦训练。可以注意到,由于从服务器端下载的全局类原型矩阵包含所有类别的类原型,因此,在训练阶段用户端也可以对局部数据集中缺失类别的样本进行预测。
根据本公开的实施例,用户在本地计算得到用户嵌入表示向量ui之后,即可立即将其与类别分布向量pi(y)发送给服务器端。因此,服务器端在计算模型距离,并对全局类原型以最新聚类划分进行聚合的同时,各用户可以在本地进行局部更新。这可以提高训练过程的并行性,在一定程度上可以节省训练时间。本公开实例假设类别分布信息不为隐私敏感信息,可以由用户上传给服务器端。
图6示意性示出了根据本公开实施例的在小样本数据场景中平均测试准确率随通信轮数变化的曲线图。
如图6所示,其中,Local model为用户基于本地数据集训练的个性化模型,FedFoMo为一种个性化联邦学习算法,在目前大部分异构数据集上达到SOTA的结果。此场景中,本公开实施例提供的基于原型网络的迭代式聚类联邦学习模型(PN-ICFL)有最优的测试结果,且在通信轮数T=5时已到达接近0.7的测试准确率。通过对比其他算法多源数字数据集(图6)上的平均测试准确率,其结果证明了在采用本公开实施例所提出的方法的在小样本场景中的优越性。
根据本公开的实施例,本公开实施例提供的一个基于原型网络的迭代聚类式联邦学习方法,基于原型网络方法对聚类式联邦学习问题进行建模,还提出了一个“基于原型网络的模型距离”。由于该模型距离可以在服务器端直接进行计算或估计,相较于现有另一迭代式算法IFCA,在PN-ICFL框架中用户每一轮通信时有更低(大约1/K)的下载通信代价;由于该模型距离的计算直接考虑到了用户的类别分布信息,因此可以有效应对类别分布异构(Label Non-IID)的场景;同时由于元学习方法原型网络模型的使用,PN-ICFL模型可以有效应对小样本数据场景(如图6所示)。
根据本公开的实施例,由于具有通信量较低,且适用于小样本数据、类别分布异构数据场景的优点,本公开可以应用于现实中通信能力有限的移动端个性化联邦学习模型的训练中,根据用户数据特点快速地对用户实施聚类式个性化训练。此外,本公开的用户聚类操作还可以作为其他相关联邦学习算法的一个预处理环节,例如,在推荐系统中,根据数据相似性对用户群体进行聚类,得到不同的聚类簇,分别代表不同的个性化偏向,之后针对每个聚类簇进行不同的操作,形成多阶段的处理。
需要说明的是,本公开实施例中的流程图所示的操作除非明确说明不同操作之间存在执行的先后顺序,或者不同操作在技术实现上存在执行的先后顺序,否则,多个操作之间的执行顺序可以不分先后,多个操作也可以同时执行。
还需要说明的是,实施例中提到的方向用语,例如“上”、“下”、“前”、“后”、“左”、“右”等,仅是参考附图的方向,并非用来限制本公开的保护范围。贯穿附图,相同的元素由相同或相近的附图标记来表示。在可能导致对本公开的理解造成混淆时,将省略常规结构或构造。
并且图中各部件的形状和尺寸不反映真实大小和比例,而仅示意本公开实施例的内容。再者,单词"包含"不排除存在未列在权利要求中的元件或步骤。位于元件之前的单词“一”或“一个”不排除存在多个这样的元件。
类似地,应当理解,为了精简本公开并帮助理解各个发明方面中的一个或多个,在上面对本公开的示例性实施例的描述中,本公开的各个特征有时被一起分组到单个实施例、图、或者对其的描述中。然而,并不应将该发明的方法解释成反映如下意图:即所要求保护的本公开要求比在每个权利要求中所明确记载的特征更多的特征。更确切地说,如下面的权利要求书所反映的那样,发明方面在于少于前面发明的单个实施例的所有特征。因此,遵循具体实施方式的权利要求书由此明确地并入该具体实施方式,其中每个权利要求本身都作为本公开的单独实施例。
本领域技术人员可以理解,本公开的各个实施例和/或权利要求中记载的特征可以进行多种组合或/或结合,即使这样的组合或结合没有明确记载于本公开中。特别地,在不脱离本公开精神和教导的情况下,本公开的各个实施例和/或权利要求中记载的特征可以进行多种组合和/或结合。所有这些组合和/或结合均落入本公开的范围。
以上对本公开的实施例进行了描述。但是,这些实施例仅仅是为了说明的目的,而并非为了限制本公开的范围。尽管在以上分别描述了各实施例,但是这并不意味着各个实施例中的措施不能有利地结合使用。本公开的范围由所附权利要求及其等同物限定。不脱离本公开的范围,本领域技术人员可以做出多种替代和修改,这些替代和修改都应落在本公开的范围之内。
Claims (6)
1.一种基于原型网络的迭代聚类式联邦学习方法,包括:
步骤S1,用户端接收由服务器端发送的本轮全局原型网络;
步骤S2,所述用户端根据本地样本数据集,训练所述全局原型网络以及确定用户嵌入表示向量和用户类别分布,其中,每一个用户具有一个所述用户嵌入表示向量;
步骤S3,所述用户端向所述服务器端发送训练后的全局原型网络、所述用户嵌入表示向量、所述用户类别分布,以使所述服务器端执行如下操作:根据所述用户嵌入表示向量和所述用户类别分布,确定基于所述训练后的全局原型网络的模型距离;根据所述模型距离确定用户聚类结果,进而执行全局模型参数聚合,得到最新的全局模型参数,其中,所述全局模型包括所述全局原型网络和多个全局类原型矩阵;
步骤S4,循环执行T轮的步骤S1至步骤S3,直至利用基于原型网络的迭代聚类式联邦学习方法训练的全局原型网络满足预设收敛条件,其中,T为大于1的整数。
2.根据权利要求1所述的方法,所述服务器端还用于执行如下操作:
获取初始全局原型网络;
利用公有样本数据集对所述初始全局原型进行预训练,得到所述全局原型网络。
3.根据权利要求1所述的方法,其中,所述用户嵌入表示向量包括所述本地样本数据集中的每一类样本数据集在所述全局原型网络中的类原型。
4.根据权利要求3所述的方法,其中,所述服务器端还包括多个聚类簇;
所述根据所述用户嵌入标识和所述用户类别分布,确定基于所述训练后的全局原型网络的模型距离包括:
根据所述用户类别分布、所述本地样本数据集中的每一类样本数据集在所述全局原型网络中的类原型、所述聚类簇中的全局类原型矩阵确定所述用户与所述聚类簇之间的模型距离。
5.根据权利要求4所述的方法,其中,所述服务器端还包括如下操作:
基于最短模型距离的原则,确定所述用户与所述聚类簇之间最短的模型距离;
根据所述最短的模型距离,确定目标聚类簇;
根据所述目标聚类簇,确定所述用户聚类结果;
根据用户聚类结果,进行全局模型参数聚合,得到最新的全局模型参数。
6.根据权利要求1所述的方法,其中,所述用户端根据本地样本数据集,进行局部更新,训练所述全局原型网络包括:
根据元学习训练方法,确定元学习任务数量;
根据所述元学习任务数量,划分所述本地样本数据集;
从划分后的所述本地样本数据集中选取样本数据组成支持集数据和查询集数据;
根据所述支持集数据和所述查询集数据训练所述全局原型网络。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210824020.5A CN115049076A (zh) | 2022-07-13 | 2022-07-13 | 基于原型网络的迭代聚类式联邦学习方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210824020.5A CN115049076A (zh) | 2022-07-13 | 2022-07-13 | 基于原型网络的迭代聚类式联邦学习方法 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN115049076A true CN115049076A (zh) | 2022-09-13 |
Family
ID=83166080
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202210824020.5A Pending CN115049076A (zh) | 2022-07-13 | 2022-07-13 | 基于原型网络的迭代聚类式联邦学习方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN115049076A (zh) |
Cited By (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN115994226A (zh) * | 2023-03-21 | 2023-04-21 | 杭州金智塔科技有限公司 | 基于联邦学习的聚类模型训练系统及方法 |
CN116226540A (zh) * | 2023-05-09 | 2023-06-06 | 浙江大学 | 一种基于用户兴趣域的端到端联邦个性化推荐方法和系统 |
SE2230332A1 (en) * | 2022-10-17 | 2024-04-18 | Atlas Copco Ind Technique Ab | Estimation of class-imbalance in training data of an iterative learning process |
-
2022
- 2022-07-13 CN CN202210824020.5A patent/CN115049076A/zh active Pending
Cited By (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
SE2230332A1 (en) * | 2022-10-17 | 2024-04-18 | Atlas Copco Ind Technique Ab | Estimation of class-imbalance in training data of an iterative learning process |
CN115994226A (zh) * | 2023-03-21 | 2023-04-21 | 杭州金智塔科技有限公司 | 基于联邦学习的聚类模型训练系统及方法 |
CN115994226B (zh) * | 2023-03-21 | 2023-10-20 | 杭州金智塔科技有限公司 | 基于联邦学习的聚类模型训练系统及方法 |
CN116226540A (zh) * | 2023-05-09 | 2023-06-06 | 浙江大学 | 一种基于用户兴趣域的端到端联邦个性化推荐方法和系统 |
CN116226540B (zh) * | 2023-05-09 | 2023-09-26 | 浙江大学 | 一种基于用户兴趣域的端到端联邦个性化推荐方法和系统 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN109919316B (zh) | 获取网络表示学习向量的方法、装置和设备及存储介质 | |
CN115049076A (zh) | 基于原型网络的迭代聚类式联邦学习方法 | |
US10229357B2 (en) | High-capacity machine learning system | |
WO2020094060A1 (zh) | 推荐方法及装置 | |
CN105608179B (zh) | 确定用户标识的关联性的方法和装置 | |
CN105631707A (zh) | 基于决策树的广告点击率预估方法与应用推荐方法及装置 | |
US20140363075A1 (en) | Image-based faceted system and method | |
AU2016218947A1 (en) | Learning from distributed data | |
WO2022166115A1 (en) | Recommendation system with adaptive thresholds for neighborhood selection | |
CN111310074B (zh) | 兴趣点的标签优化方法、装置、电子设备和计算机可读介质 | |
CN110020022B (zh) | 数据处理方法、装置、设备及可读存储介质 | |
Lin et al. | Content recommendation algorithm for intelligent navigator in fog computing based IoT environment | |
CN114298122B (zh) | 数据分类方法、装置、设备、存储介质及计算机程序产品 | |
WO2023231542A1 (zh) | 表示信息的确定方法、装置、设备及存储介质 | |
CN110929806A (zh) | 基于人工智能的图片处理方法、装置及电子设备 | |
WO2023020214A1 (zh) | 检索模型的训练和检索方法、装置、设备及介质 | |
CN114358109A (zh) | 特征提取模型训练、样本检索方法、装置和计算机设备 | |
CN103399900A (zh) | 基于位置服务的图片推荐方法 | |
CN114332550A (zh) | 一种模型训练方法、系统及存储介质和终端设备 | |
CN109271555A (zh) | 信息聚类方法、系统、服务器及计算机可读存储介质 | |
CN115439770A (zh) | 一种内容召回方法、装置、设备及存储介质 | |
CN117035059A (zh) | 一种通信高效的隐私保护推荐系统及方法 | |
WO2023087933A1 (zh) | 内容推荐方法、装置、设备、存储介质及程序产品 | |
US20230044035A1 (en) | Model pool for multimodal distributed learning | |
CN111935259A (zh) | 目标帐号集合的确定方法和装置、存储介质及电子设备 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |