CN113379071A - 一种基于联邦学习的噪声标签修正方法 - Google Patents
一种基于联邦学习的噪声标签修正方法 Download PDFInfo
- Publication number
- CN113379071A CN113379071A CN202110666751.7A CN202110666751A CN113379071A CN 113379071 A CN113379071 A CN 113379071A CN 202110666751 A CN202110666751 A CN 202110666751A CN 113379071 A CN113379071 A CN 113379071A
- Authority
- CN
- China
- Prior art keywords
- local
- class
- global
- model parameters
- sample
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Granted
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N20/00—Machine learning
- G06N20/20—Ensemble learning
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Software Systems (AREA)
- Data Mining & Analysis (AREA)
- Evolutionary Computation (AREA)
- Medical Informatics (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Physics & Mathematics (AREA)
- Computing Systems (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Mathematical Physics (AREA)
- Artificial Intelligence (AREA)
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
Abstract
本发明提供一种基于联邦学习的噪声标签修正方法,包括:将客户端根据本地训练数据更新的本地模型参数和对应的样本数据量发送给服务端;获取服务端根据客户端以及其他客户端更新的本地模型参数和对应的样本数据量计算的全局模型参数;由客户端根据全局模型参数和本地训练数据,计算指示不同类别的平均预测概率的多个本地类基准并发送给服务端;获取服务端根据多个本地类基准计算的多个全局类基准,并基于全局模型参数和全局类基准对客户端的本地训练数据进行噪声标签修正。将该修正方法应用于联邦学习系统中,实现了信息增强,减少了数据的损失以及对额外参照集的依赖,有效地提高了联邦学习训练结构测试的准确率。
Description
技术领域
本发明涉及的是分布式机器学习领域,具体涉及一种基于联邦学习的噪声标签修正方法。
背景技术
随着分布式机器学习和大数据分析的发展,联邦学习作为一种新型的分布式机器学习框架,满足了多个客户端(机构)在数据安全的要求下进行模型训练。在模型训练过程中,服务端和客户端之间仅交换模型参数,各客户端无需上传任何原始数据。在实际的联邦学习场景下,多个客户端的加入虽然带来了更多的知识,但同时也增加了数据噪声的风险,如多分类任务中的标签噪声问题。在实际操作中,标签噪声的问题难以避免,比如众包标定的电商货品图片,或者是医学生、非专家标定的医学影像,它们的类别标签往往依赖于操作人员的水平以及标定过程的准确性。这些现实存在的标签噪声往往会影响模型训练的准确性,除此之外,由于联邦学习规定服务端对原始数据不可见,检测标签噪声进而对其进行修正变得更具挑战。
现有研究多认为具有错误标签的样本是总体数据中的异常点,常基于服务端提供的一个额外的干净数据集来生成一个参照模型,用该参照模型来度量本地数据和服务端标准数据的差异,从而进行数据的筛选或者降低一部分样本参与训练的权重。
现有的解决标签噪声问题的技术,往往高度依赖于一个完全干净的参照数据集,这样的数据集要求标注信息完全准确,获取数据的开销极大。当参照数据集规模有限时,其类别分布和总体分布不一定一致,对于多分类任务来说其参考价值也会大打折扣。另外,现有的技术侧重于选择性地降低噪声数据的参与度,本质上是一种损失信息的方式。
在进行联邦学习中的标签噪声研究时,发现现有技术中的信息损失问题是由于没有对噪声标签进行修正而导致的。标签修正往往需要模型达到一定的预测水平,由于服务端并不总能提供一个理想的干净参照集,通过参照模型进行数据预筛选一类的方法在现实联邦场景中很可能失效,因此在模型达到稳定的预测水平之前的这一动态过程中,各客户端的数据需要遵循一个噪声留出机制以支持动态过程中的噪声学习,并快速提升模型性能,从而实现最终的修正。而关于现有技术的参照集依赖问题,本质上是忽略了联邦学习自身的合作特性所导致的。现有的技术仍处于传统机器学习方法中的依赖参照集模式,没有将重点转移到联邦学习的“联邦”优势上来,从而不得不依托于一个额外的参照数据集。
因此,亟需一种既能减少信息损失,又能不依赖额外参照集的联邦学习系统。
发明内容
因此,本发明的目的在于克服上述现有技术的存在的信息损失问题和参照集依赖等缺陷,提供一种基于联邦学习的噪声标签修正方法。
本发明的目的是通过以下技术方案实现的:
根据本发明的第一方面,提供一种基于联邦学习的噪声标签修正方法,包括:将客户端根据本地训练数据更新的本地模型参数和对应的样本数据量发送给服务端;获取服务端根据所述客户端以及其他客户端更新的本地模型参数和对应的样本数据量计算的全局模型参数;由所述客户端根据所述全局模型参数和本地训练数据,计算指示不同类别的平均预测概率的多个本地类基准并发送给服务端;获取服务端根据所述多个本地类基准计算的多个全局类基准,并基于所述全局模型参数和全局类基准对所述客户端的本地训练数据进行噪声标签修正。
在本发明的一些实施例中,所述样本数据量包括本次更新本地模型所采用的本地训练数据的第一样本量,所述全局模型参数是由服务端对所有客户端的本地模型参数进行加权求和得到的,其中,所述客户端的第一样本量除以所述客户端以及其他客户端的第一样本量之和作为该所述客户端的本地模型参数的权值。
在本发明的一些实施例中,所述由所述客户端根据所述全局模型参数和本地训练数据,计算指示不同类别的平均预测概率的多个本地类基准并发送给服务端的步骤包括:在所述客户端通过最新全局模型参数对本地训练数据进行预测,得到本地训练数据中每个样本在各个类别的预测概率;基于每个给定标签类下所有样本属于该给定标签类的预测概率计算平均值,得到该给定标签类对应的本地类基准。
在本发明的一些实施例中,所述样本数据量包括本次更新本地模型所采用的本地训练数据中每个给定标签类下的第二样本量,所述多个全局类基准包括每个给定标签类的全局类基准,每个给定标签类的全局类基准是由服务端对所有客户端的相应类别的本地类基准进行加权求和得到的,其中,所述客户端相应类别下的第二样本量除以所述客户端及其他客户端在该相应类别下的第二样本量之和作为该所述客户端的相应类别的本地类基准的权值。
在本发明的一些实施例中,所述基于所述全局模型参数和全局类基准对所述客户端的本地训练数据进行噪声标签修正的步骤包括:根据所述全局模型参数对本地训练数据中的每个样本进行预测得到预测结果,根据预测结果和全局类基准生成每个样本的伪标签,计算伪标签与当前标签不一致的样本的边际值,所述边际值等于根据所述全局模型参数对该样本预测得到的最大预测概率与该样本的当前标签对应的预测概率之差,将边际值大于预定阈值的样本的当前标签修改为伪标签。
在本发明的一些实施例中,所述根据预测结果和全局类基准生成每个样本的伪标签的步骤包括:在一个样本的预测结果中有任何类别的预测概率超过该类别的全局类基准时,该样本的伪标签为所有超过全局类基准的类中最大预测概率对应的类别,否则,该样本伪标签为预测结果中最大预测概率对应的类别。
可选的,全局类基准的计算方法如下:
其中,表示在第t轮训练时第k个客户端的类别l的本地类基准,表示第k个客户端的本地训练数据对应的数据集Dk中属于给定标签类l的所有样本的集合,表示全局模型在样本给定标签类l的预测概率,x表示特征数据,y表示类别,表示第t轮训练时对应的全局模型参数,n表示第n个客户端,N表示客户端的总数。
根据本发明的第二方面,提供一种联邦学习方法,包括:组织多个客户端进行联邦学习,在联邦学习的过程中各客户端利用第一方面的方法对本地训练数据的噪声标签进行修正以及获得全局模型参数;相应客户端利用获得的全局模型参数替换本地模型参数,利用修正噪声标签后的本地训练数据对本地模型进行训练,更新本地模型参数;在服务端根据多个客户端更新的本地模型参数更新全局模型参数。
根据本发明的第三方面,提供了一种联邦学习系统,包括:服务端和多个客户端,所述服务端,被配置为组织多个客户端以隐私保护的方式进行参数交换,生成中间参数,所述中间参数包括全局模型参数和全局类基准;所述多个客户端,被配置为基于中间参数以进行联邦学习,并利用第一方面的方法对本地训练数据的噪声标签进行修正,以及相应客户端利用获得的全局模型参数替换本地模型参数,利用修正噪声标签后的本地训练数据对本地模型进行训练,更新本地模型参数;服务端,还被配置为根据多个客户端更新的本地模型参数更新全局模型参数。
在本发明的一些实施例中,所述客户端包括:模型训练模块,用于将当前客户端根据本地训练数据更新的本地模型参数和对应的样本数据量发送给服务端;以及本地类基准计算模块,用于获取服务端根据多个客户端更新的本地模型参数和对应的样本数据量计算的全局模型参数,根据最新的全局模型参数和本地训练数据计算指示不同类别的平均预测概率的多个本地类基准并发送给服务端;噪声修正模块,用于获取服务端根据所述多个本地类基准计算的多个全局类基准,并基于最新的全局模型参数和全局类基准对当前客户端的本地训练数据进行噪声标签修正;其中,模型训练模块,还用于利用标签修正后的本地训练数据对当前客户端本地模型进行训练,更新本地模型参数。
在本发明的一些实施例中,所述服务端包括:模型聚合模块,用于获取多个客户端发送的根据客户端自身本地训练数据更新的本地模型参数和对应的样本数据量,根据多个客户端更新的本地模型参数和对应的样本数据量计算全局模型参数并发送给多个客户端;以及全局类基准聚合模块,用于获取所述多个客户端发送的多个本地类基准,根据所述多个本地类基准计算多个全局类基准,相应类别的本地类基准是客户端根据最新的全局模型参数对本地训练数据计算的该类别的平均预测概率;其中,所述模型聚合模块,还用于获取一个或者多个客户端根据最新的全局模型参数和全局类基准对标签进行修正后更新的本地模型参数。
根据本发明的第四方面,提供一种电子设备,包括:一个或多个处理器;以及存储器,其中存储器用于存储一个或多个可执行指令;所述一个或多个处理器被配置为经由执行所述一个或多个可执行指令以实现利用第二方面的方法或者第三方面的联邦学习系统更新好的全局模型进行分类预测。
与现有技术相比,本发明的优点在于:
1、与传统噪声标签处理系统相比,不用丢弃数据或减少数据的参与度,而是通过修正噪声标签来实现信息增强,减少了数据的损失。
2、本发明利用联邦的“合作”特性来估计样本的真实标签,使得噪声估计过程摆脱对额外参照集的依赖,增强模型部署的可行性,有效地提高了联邦学习训练结构测试的准确率。
3、本发明既不需要上传原始数据也不需要上传样本位(samp l e-wi se,即原始数据,例如样本特征数据,涉及到用户隐私)的参数,通过交换类别位(c l ass-wi se)的中间参数(模型参数和类基准,不涉及用户隐私)来构建共识,修改噪声标签,满足了联邦学习系统的数据隐私要求。
附图说明
以下参照附图对本发明实施例作进一步说明,其中:
图1为根据本发明实施例的基于联邦学习的噪声标签修正方法的流程图;
图2为根据本发明实施例的联邦学习系统的示意图;
图3为根据本发明实施例的基于噪声标签修正方法的联邦学习系统示意图;
图4为根据本发明实施例在公开的行为识别数据集USC-HAD上,在噪声强度(数据含噪比例0.3)下的实验结果示意图;
图5为根据本发明实施例在公开的行为识别数据集USC-HAD上,在噪声强度(数据含噪比例0.4)下的实验结果示意图;
图6为根据本发明实施例在公开的行为识别数据集USC-HAD上,在噪声强度(数据含噪比例0.5)下的实验结果示意图。
具体实施方式
为了使本发明的目的,技术方案及优点更加清楚明白,以下结合附图通过具体实施例对本发明进一步详细说明。应当理解,此处所描述的具体实施例仅用以解释本发明,并不用于限定本发明。
在对本发明的实施例进行具体介绍之前,先对其中使用到的部分术语作如下解释:
客户端,也称为用户端,是指为客户提供服务的节点。客户端可以是不同的工作站(如医学组织机构、金融组织机构、地理分布的数据中心等),这些工作站一般存在数据上的屏障,无法直接交换数据或上传数据到一个信任的中心节点。客户端也可以是大量的移动设备或物联网设备,原始数据同样都保存在本地设备上。本发明实施例中的客户端不局限于任何应用场景。
服务端,是指为客户端提供服务的节点。服务端可以用于协调多个客户端在不泄露或者汇聚各方原始数据的情况下进行联合建模的终端。例如,用于组织一些中间参数(如本发明中的模型参数)的交换,承担着参数的更新、分发等工作,且对客户端的原始数据不可见。
机器学习模型,是由多个处理单元连接而形成的复杂人工神经网络。机器学习模型反映了人脑功能的基本特征,是一个高度复杂的非线性学习系统。其中,在客户端更新的模型称为本地模型,在服务端更新的模型称为全局模型。
如背景技术中提到的,发明人在进行联邦学习中的标签噪声研究时,发现现有技术中的信息损失问题是由于没有对噪声标签进行修正而导致的。标签修正往往需要模型达到一定的预测水平,由于服务端并不总能提供一个理想的干净参照集,通过参照模型进行数据预筛选一类的方法在现实联邦场景中很可能失效,因此在模型达到稳定的预测水平之前的这一动态过程中,各客户端的数据需要遵循一个噪声留出机制以支持动态过程中的噪声学习,并快速提升模型性能,从而实现最终的修正。
而关于现有技术的参照集依赖问题,本质上是忽略了联邦学习自身的合作特性所导致的。现有的技术仍处于传统机器学习方法中的依赖参照集模式,没有将重点转移到联邦学习的“联邦”优势上来,从而不得不依托于一个额外的参照数据集。发明人经过对联邦学习中的标签噪声的研究发现,解决该项缺陷可以通过共识的方法来实现。在一个现实的联邦场景中,由于各客户端的数据其数量和质量参差不齐,某一方基于本地有限的数据能识别出的噪声很有限,因此,各客户端之间可以通过共识方法生成全局模型参数以及全局类基准,从而支持利用全局模型参数以及全局类基准来修正噪声标签。依据本发明提出的方法,可以摆脱对额外干净参照集的依赖,降低人工修正噪声标签所带来的工作量,增强客户端数据的标签质量,提高模型的精度。
基于上述研究,根据本发明的一个实施例,如图1所示,提供一种基于联邦学习的噪声标签修正方法,包括执行步骤S1、S2、S3、S4,下面详细说明每个步骤。
步骤S1:将客户端根据本地训练数据更新的本地模型参数和对应的样本数据量发送给服务端。
根据本发明的一个实施例,将本地训练数据分批次投入客户端的本地模型中进行多轮训练,直到本地模型训练到指定轮次或本地模型达到收敛时,停止训练,得到本地模型参数,并发送给服务端。例如,可以是本地模型训练到指定的迭代轮次(Local_epoch)时,即视为本地模型收敛。如本地模型的迭代轮次(Local_epoch)可以指定为20轮(该值是一个经验值,一般取20,即认为在20轮内收敛,但其他实施者根据对不同模型训练的经验,可以自定义设置),每一轮里是以批次(Batch)为单位投入数据,每一批次10条,即每一个全局模型的迭代轮次(Global_epoch)对应着20个本地模型的迭代轮次(Local_epoch),20轮结束后得到更新后的本地模型参数并发送给服务端。本领域技术人员可以理解,此处的轮(Epoch)表示使用训练集的全部数据对模型进行一次完整的训练,被称为“一轮训练”。批次(Batch)表示使用训练集中的一小部分样本对模型权重参数进行一次反向传播的参数更新,这一小部分样本为“一批数据”。其中,本地模型的模型参数以一种随机梯度下降的方式进行更新,更新方式如下:
步骤S2:获取服务端根据所述客户端以及其他客户端更新的本地模型参数和对应的样本数据量计算的全局模型参数。
根据本发明的一个实施例,首先,需初始化服务端的全局模型,将本地模型在本地训练数据上训练指定轮次(例如上述训练20轮次)后,将训练好的本地模型参数发送给服务端,由服务端进行加权聚合得到全局模型并分发至各客户端。初始化后,将客户端的本地模型在本地训练数据上训练多轮,直到本地模型收敛后,将收敛后更新的本地模型参数和对应的样本数据量发送给服务端进行加权聚合。此处对应的样本数据量包括本次更新本地模型所采用的本地训练数据的第一样本量,所述全局模型参数是由服务端对所有客户端的本地模型参数进行加权求和得到的,其中,所述客户端的第一样本量除以所述客户端以及其他客户端的第一样本量之和作为该所述客户端的本地模型参数的权值。
根据本发明的一个实施例,全局模型参数可以按照以下公式计算:
步骤S3:由所述客户端根据所述全局模型参数和本地训练数据,计算指示不同类别的平均预测概率的多个本地类基准并发送给服务端。
根据本发明的一个实施例,步骤S3中可以包括:
S32、基于每个给定标签类下所有样本属于该给定标签类的预测概率计算平均值,得到该给定标签类对应的本地类基准。
根据本发明的一个实施例,本地类基准可以按照以下公式计算:
其中,表示第k个客户端在第t轮计算的给定标签类l对应的本地类基准,表示第k个客户端数据集Dk中属于给定标签类l的所有样本的集合,表示模型在样本给定标签类l的预测概率,x表示特征数据,y表示类别,表示所述模型预测单元通过第t轮对应的全局模型参数。
步骤S4:获取服务端根据所述多个本地类基准计算的多个全局类基准,并基于所述全局模型参数和全局类基准对所述客户端的本地训练数据进行噪声标签修正。样本数据量包括本次更新本地模型所采用的本地训练数据中每个给定标签类下的第二样本量,所述多个全局类基准包括每个给定标签类的全局类基准,每个给定标签类的全局类基准是由服务端对所有客户端的相应类别的本地类基准进行加权求和得到的,其中,所述客户端相应类别下的第二样本量除以所述客户端及其他客户端在该相应类别下的第二样本量之和作为该所述客户端的相应类别的本地类基准的权值。该实施例的技术方案至少能够实现以下有益技术效果:按照该方式计算全局类基准,各个客户端没有彼此暴露原始数据,仅向服务端发送自身的本地类基准和第二样本量,服务端基于第二样本量对本地类基准进行加权聚合,得到全局类基准下发给多个客户端;客户端彼此并不知道其余客户端贡献的本地类基准和第二样本量,难以反推相应的数据,可以实现隐私保护。其中,在服务端根据多个本地类基准计算的多个全局类基准中,其中一个全局类基准的计算方式如下:
其中,表示在第t轮训练时第k个客户端的类别l的本地类基准,表示第k个客户端的本地训练数据对应的数据集Dk中属于给定标签类l的所有样本的集合,表示全局模型在样本给定标签类l的预测概率,x表示特征数据,y表示类别,表示第t轮训练时对应的全局模型参数,n表示第n个客户端,N表示客户端的总数。
根据本发明的一个实施例,步骤S4中包括:
S41、根据所述全局模型参数对本地训练数据中的每个样本进行预测得到预测结果。
S42、根据预测结果和全局类基准生成每个样本的伪标签,其中,在一个样本的预测结果中有任何类别的预测概率超过该类别的全局类基准时,该样本的伪标签为所有超过全局类基准的类中最大预测概率对应的类别,否则,该样本伪标签为预测结果中最大预测概率对应的类别。伪标签的计算方法为:
其中,m表示总类别数。
S43、计算伪标签与当前标签不一致的样本的边际值,边际值等于根据所述全局模型参数对该样本预测得到的最大预测概率与该样本的当前标签对应的预测概率之差。其中,将伪标签与当前标签一致的样本作为干净样本分选至训练集合中。
根据本发明的一个实施例,边际值可以按照以下公式计算:
即当样本的伪标签与当前标签不一致时,通过比较样本的边际值m(x)与经验阈值v的大小来分选出噪声样本和干净样本,当m(x)>τ,将当前样本视为噪声样本,转到步骤S44。当m(x)≤τ,则将当前样本视为干净样本。经验阈值τ可以由用户根据需要自定义设置,取值范围在(0,1)之间。
S44、将边际值大于预定阈值的样本的当前标签修改为伪标签。边际值大于预定阈值的样本即为噪声样本。
在一个实施例中,将边际值大于预定阈值的样本作为噪声样本并分选至噪声样本集合中,当客户端当轮的全局模型已经收敛且为第一次收敛(全局模型在验证集上的误差和准确率稳定时即为收敛)时,对边际值大于预定阈值的样本的当前标签进行修改,赋予修改后的伪标签,得到修正样本,将修正样本分选至训练集合中。将边际值小于预定阈值的样本作为干净样本并分选至训练集合中,收集完毕后,客户端利用标签修正后的本地训练数据对本地模型进行训练,更新本地模型参数。
在一个实施例中,本发明还提供一种联邦学习方法,包括以下步骤:
A1、组织多个客户端进行联邦学习,在联邦学习的过程中各客户端利用上述基于联邦学习的噪声标签修正方法对本地训练数据的噪声标签进行修正以及获得全局模型参数。
A2、相应客户端利用获得的全局模型参数替换本地模型参数,利用修正噪声标签后的本地训练数据对本地模型进行训练,更新本地模型参数。即:在修正噪声标签后,用最新获得的全局模型替换之前的本地模型,作为修正噪声标签后待训练的本地模型。
A3、在服务端根据多个客户端更新的本地模型参数更新全局模型参数。
根据本发明的一个实施例,本发明还提供了一种联邦学习系统,如图2所示,该系统可以包括:服务端和多个客户端。其中,每个客户端部署有本地模型和本地训练数据,服务端部署有全局模型,初始化时,服务端将初始全局模型下发给客户端,客户端基于下发的全局模型计算得到本地类基准,将本地类基准发送给服务端,服务端将各个客户端发送的本地类基准进行加权聚合得到全局类基准,客户端将获得的全局模型作为新的本地模型并基于全局类基准对本地训练数据进行噪声标签修正,采用噪声标签修正后的本地训练数据训练本地模型,更新的本地模型发送给服务端进行加权聚合,获得新的全局模型。在进行噪声标签修正时,可以基于全局类基准筛选出干净样本和噪声样本,对噪声样本的标签进行修正。本发明实施例中的本地模型和全局模型不局限任何结构、类型、应用场景,模型结构如随机森林。应用场景例如小微信贷(金融)对用户进行信用分类或者信用分级。
服务端,被配置为组织多个客户端以隐私保护的方式进行参数交换,生成中间参数,包括全局模型参数和全局类基准;
多个客户端,被配置为基于中间参数以进行联邦学习,联邦学习过程中各客户端利用上述基于联邦学习的噪声标签修正方法对本地训练数据的噪声标签进行修正,并且相应客户端利用获得的全局模型参数替换本地模型参数,利用修正噪声标签后的本地训练数据对本地模型进行训练,更新本地模型参数;
服务端,还被配置为根据多个客户端更新的本地模型参数更新全局模型参数。
在一个实施例中,如图3所示,包括至少两个客户端和服务端。
每个客户端可以包括:本地类基准计算模块、噪声修正模块和模型训练模块。
其中,本地类基准计算模块可以包括模型预测单元和本地类基准计算单元。
模型预测单元,可以用于通过第t轮分发的全局模型参数,预测各客户端的本地数据的每一样本在各个类别的概率,得到相应的预测概率。
本地类基准计算单元,可以用于基于每个给定标签类下所有样本属于该给定标签类的预测概率计算平均值,得到该给定标签类对应的本地类基准。
噪声修正模块可以包括伪标签生成单元、标签判断单元、边际判断单元、噪声样本留出单元、模型收敛判断单元、标签修正单元、修正样本等候训练单元和干净样本等候训练单元。
伪标签生成单元,可以用于将一个样本根据所述全局类基准生成一个伪标签。
标签判断单元,用于判断所述样本的给定标签与所述伪标签是否相等。
边际判断单元,用于计算样本的边际值,以及基于样本的边际值大小,分选出噪声样本和干净样本。
噪声样本留出单元,用于留出噪声样本。
模型收敛判断单元,用于判断当轮全局模型是否收敛。
标签修正单元,用于修正当轮留出所述噪声样本的噪声标签,得到修正样本。
修正样本等候训练单元,用于存放修正样本训练集。
干净样本等候训练单元,用于存放干净样本训练集。
模型训练模块包括模型训练单元,用于将当前客户端根据本地训练数据更新的本地模型参数和对应的样本数据量发送给服务端;还用于利用标签修正后的本地训练数据对当前客户端本地模型进行训练,更新本地模型参数。
根据本发明的一个实施例,服务端可以包括:模型聚合模块、全局类基准聚合模块。
其中,模型聚合模块,可以用于获取多个客户端发送的根据客户端自身本地训练数据更新的本地模型参数和对应的样本数据量,根据多个客户端更新的本地模型参数和对应的样本数据量计算全局模型参数并发送给多个客户端。
全局类基准聚合模块,可以用于获取多个客户端发送的多个本地类基准,根据多个本地类基准计算多个全局类基准,相应类别的本地类基准是客户端根据最新的全局模型参数对本地训练数据计算的该类别的平均预测概率。
模型聚合模块,还可以用于获取一个或者多个客户端根据最新的全局模型参数和全局类基准对标签进行修正后更新的本地模型参数,并将基于更新的本地模型参数聚合得到全局模型参数发送给多个客户端。
在联邦学习系统中,全局模型是通过加权聚合多个本地模型得到的,而本地模型是基于全局类基准对本地训练数据进行噪声标签修正,并采用噪声标签修正后的本地训练数据训练得到的。其中,在进行噪声标签修正时,可以基于全局类基准筛选出干净样本和噪声样本,进而对噪声样本的标签进行修正。根据本发明的一个实施例,参见图3,在上述联邦学习系统中执行的一种联邦学习方法,可以包括以下步骤(由于附图中所能记载的文字有限,步骤号内仅显示其对应的模块或单元名称):
B1、通过模型预测单元,基于第t轮分发的全局模型参数,预测各客户端的本地数据的每一样本在各个类别的概率,得到相应的预测概率。
B2、通过本地类基准计算单元,计算每个给定标签类下所有样本属于该给定标签类的预测概率的平均值,得到该给定标签类对应的本地类基准。
B3、通过全局类基准聚合模块,获取多个客户端发送的多个本地类基准,根据多个本地类基准计算得到多个全局类基准。
B4、通过伪标签生成单元,为一个样本根据所述全局类基准生成一个伪标签。
B5、标签判断单元判断伪标签是否等于给定标签。例如采用标签判断单元判断样本的给定标签与所述伪标签是否相等,将给定标签与伪标签相等的样本筛选为干净样本,并发送给干净样本等候训练单元;若样本的给定标签与伪标签不相等,则继续下一步。
B6、采用边际判断单元判断样本对应的边际值是否高于经验阈值,计算给定标签与伪标签不相等的样本的边际值,比较样本的边际值与经验阈值的大小来识别样本为噪声样本还是干净样本,若样本的边际值大于经验阈值,则样本筛选为噪声样本并发送至噪声样本留出单元,转至步骤B7,否则筛选为干净样本并发送至干净样本等候训练单元,并转至步骤B8。
B9、采用模型收敛判断单元判断当轮全局模型是否第一次达到收敛,若收敛则继续下一步,否则跳转至步骤B12。
B10、采用标签修正单元修正当轮筛选出的所述噪声样本的标签,得到修正样本。
B12、采用模型训练单元将当轮用于训练的数据(即训练集合)训练本地模型,直到本地模型收敛或达到指定的本地迭代轮次,例如,可以指定以20轮为本地模型的迭代轮次,每一轮里是将本地训练数据分为多个批次,以批次为单位投入数据,每一批次10条,即每一个全局模型的迭代轮次(Global_epoch)对应着20个本地模型的迭代轮次(Local_epoch),20轮结束后得到更新后的本地模型参数。
B13、采用模型聚合模块将当轮各客户端提供的数据量和本地模型参数采用加权聚合的方式得到全局模型参数,将全局模型参数分发至各客户端。重复上述过程,直至全局模型再次收敛或全局模型达到指定迭代轮次时停止训练。例如,可以是以500轮为全局迭代轮次(Global epoch)。
根据本发明一个实施例,在实际的应用场景中,最终是为了获得全局模型,并部署在不同的客户端,用于客户端的本地分类预测。因此,本发明还提供了一种电子设备,可以包括:一个或多个处理器;以及存储器,其中存储器用于存储一个或多个可执行指令;所述一个或多个处理器被配置为经由执行所述一个或多个可执行指令以实现利用上述的联邦学习方法或者上述的联邦学习系统更新好的全局模型进行分类预测。
为了验证本发明的效果,发明人进行了以下实验,实验前,选择数据集USC-HAD(University of Southern California Human Activity Dataset),该人类活动数据集中的类别(标签)包括:向前走、向左走、向右走、走上楼、走下楼、向前奔跑、跳跃、坐着、站着、睡觉、电梯向上、电梯下来。
实验时,选择不同的噪声强度进行多次预测对比,每个噪声强度下,均通过将本发明设计的联邦学习方法与传统的联邦学习以及其他方法进行在同样的迭代轮次下对测试集进行预测,得到实验结果如图4-6所示,图4、图5和图6分别表示在公开的行为识别数据集USC-HAD上,在不同噪声强度(数据含噪比例0.3、0.4、0.5)下的实验结果,其他对比方法均部署在联邦学习的框架下。图中横坐标为迭代轮数,纵坐标为测试集准确率,测试准确率越大说明联邦学习的训练过程设计越好。从图中可以发现,本发明可以有效地提高联邦学习训练结构的测试准确率,达到更准确地预测数据的效果,并且随着迭代次数的增长,其相对于传统联邦学习以及其他方法的优势更显著。即使在严重的噪声强度下,本发明设计仍保持高度优越性。
需要说明的是,虽然上文按照特定顺序描述了各个步骤,但是并不意味着必须按照上述特定顺序来执行各个步骤,实际上,这些步骤中的一些可以并发执行,甚至改变顺序,只要能够实现所需要的功能即可。
本发明可以是系统、方法和/或计算机程序产品。计算机程序产品可以包括计算机可读存储介质,其上载有用于使处理器实现本发明的各个方面的计算机可读程序指令。
计算机可读存储介质可以是保持和存储由指令执行设备使用的指令的有形设备。计算机可读存储介质例如可以包括但不限于电存储设备、磁存储设备、光存储设备、电磁存储设备、半导体存储设备或者上述的任意合适的组合。计算机可读存储介质的更具体的例子(非穷举的列表)包括:便携式计算机盘、硬盘、随机存取存储器(RAM)、只读存储器(ROM)、可擦式可编程只读存储器(EPROM或闪存)、静态随机存取存储器(SRAM)、便携式压缩盘只读存储器(CD-ROM)、数字多功能盘(DVD)、记忆棒、软盘、机械编码设备、例如其上存储有指令的打孔卡或凹槽内凸起结构、以及上述的任意合适的组合。
以上已经描述了本发明的各实施例,上述说明是示例性的,并非穷尽性的,并且也不限于所披露的各实施例。在不偏离所说明的各实施例的范围和精神的情况下,对于本技术领域的普通技术人员来说许多修改和变更都是显而易见的。本文中所用术语的选择,旨在最好地解释各实施例的原理、实际应用或对市场中的技术改进,或者使本技术领域的其它普通技术人员能理解本文披露的各实施例。
Claims (12)
1.一种基于联邦学习的噪声标签修正方法,其特征在于,包括:
将客户端根据本地训练数据更新的本地模型参数和对应的样本数据量发送给服务端;
获取服务端根据所述客户端以及其他客户端更新的本地模型参数和对应的样本数据量计算的全局模型参数;
由所述客户端根据所述全局模型参数和本地训练数据,计算指示不同类别的平均预测概率的多个本地类基准并发送给服务端;
获取服务端根据所述多个本地类基准计算的多个全局类基准,并基于所述全局模型参数和全局类基准对所述客户端的本地训练数据进行噪声标签修正。
2.根据权利要求1所述的噪声标签修正方法,其特征在于,所述样本数据量包括本次更新本地模型所采用的本地训练数据的第一样本量,所述全局模型参数是由服务端对所有客户端的本地模型参数进行加权求和得到的,其中,所述客户端的第一样本量除以所述客户端以及其他客户端的第一样本量之和作为该所述客户端的本地模型参数的权值。
3.根据权利要求1所述的噪声标签修正方法,其特征在于,所述由所述客户端根据所述全局模型参数和本地训练数据,计算指示不同类别的平均预测概率的多个本地类基准并发送给服务端的步骤包括:
在所述客户端通过最新全局模型参数对本地训练数据进行预测,得到本地训练数据中每个样本在各个类别的预测概率;
基于每个给定标签类下所有样本属于该给定标签类的预测概率计算平均值,得到该给定标签类对应的本地类基准。
4.根据权利要求1或权利要求3所述的噪声标签修正方法,其特征在于,所述样本数据量包括本次更新本地模型所采用的本地训练数据中每个给定标签类下的第二样本量,所述多个全局类基准包括每个给定标签类的全局类基准,
每个给定标签类的全局类基准是由服务端对所有客户端的相应类别的本地类基准进行加权求和得到的,其中,所述客户端相应类别下的第二样本量除以所述客户端及其他客户端在该相应类别下的第二样本量之和作为该所述客户端的相应类别的本地类基准的权值。
5.根据权利要求1所述的噪声标签修正方法,其特征在于,所述基于所述全局模型参数和全局类基准对所述客户端的本地训练数据进行噪声标签修正的步骤包括:
根据所述全局模型参数对本地训练数据中的每个样本进行预测得到预测结果,
根据预测结果和全局类基准生成每个样本的伪标签,
计算伪标签与当前标签不一致的样本的边际值,所述边际值等于根据所述全局模型参数对该样本预测得到的最大预测概率与该样本的当前标签对应的预测概率之差,
将边际值大于预定阈值的样本的当前标签修改为伪标签。
6.根据权利要求5所述的噪声标签修正方法,其特征在于,所述根据预测结果和全局类基准生成每个样本的伪标签的步骤包括:
在一个样本的预测结果中有任何类别的预测概率超过该类别的全局类基准时,该样本的伪标签为所有超过全局类基准的类中最大预测概率对应的类别,否则,该样本伪标签为预测结果中最大预测概率对应的类别。
8.一种联邦学习方法,其特征在于,包括:
组织多个客户端进行联邦学习,在联邦学习的过程中各客户端利用权利要求1-7任一项所述基于联邦学习的噪声标签修正方法对本地训练数据的噪声标签进行修正以及获得全局模型参数;
相应客户端利用获得的全局模型参数替换本地模型参数,利用修正噪声标签后的本地训练数据对本地模型进行训练,更新本地模型参数;
在服务端根据多个客户端更新的本地模型参数更新全局模型参数。
9.一种联邦学习系统,其特征在于,包括:服务端和多个客户端,
所述服务端,被配置为组织多个客户端以隐私保护的方式进行参数交换,生成中间参数,所述中间参数包括全局模型参数和全局类基准;
所述多个客户端,被配置为基于中间参数以进行联邦学习,并利用权利要求1-7任一项所述基于联邦学习的噪声标签修正方法对本地训练数据的噪声标签进行修正,以及相应客户端利用获得的全局模型参数替换本地模型参数,利用修正噪声标签后的本地训练数据对本地模型进行训练,更新本地模型参数;
服务端,还被配置为根据多个客户端更新的本地模型参数更新全局模型参数。
10.根据权利要求9所述的联邦学习系统,其特征在于,所述客户端包括:
模型训练模块,用于将当前客户端根据本地训练数据更新的本地模型参数和对应的样本数据量发送给服务端;
本地类基准计算模块,用于获取服务端根据多个客户端更新的本地模型参数和对应的样本数据量计算的全局模型参数,根据最新的全局模型参数和本地训练数据计算指示不同类别的平均预测概率的多个本地类基准并发送给服务端;以及
噪声修正模块,用于获取服务端根据所述多个本地类基准计算的多个全局类基准,并基于最新的全局模型参数和全局类基准对当前客户端的本地训练数据进行噪声标签修正;
其中,模型训练模块还利用标签修正后的本地训练数据对当前客户端本地模型进行训练,更新本地模型参数。
11.根据权利要求9所述的联邦学习系统,其特征在于,所述服务端包括:模型聚合模块,用于获取多个客户端发送的根据客户端自身本地训练数据更新的本地模型参数和对应的样本数据量,根据多个客户端更新的本地模型参数和对应的样本数据量计算全局模型参数并发送给多个客户端;以及
全局类基准聚合模块,用于获取所述多个客户端发送的多个本地类基准,根据所述多个本地类基准计算多个全局类基准,相应类别的本地类基准是客户端根据最新的全局模型参数对本地训练数据计算的该类别的平均预测概率;
其中所述模型聚合模块还用于获取一个或者多个客户端根据最新的全局模型参数和全局类基准对标签进行修正后更新的本地模型参数。
12.一种电子设备,其特征在于,包括:
一个或多个处理器;以及
存储器,其中存储器用于存储一个或多个可执行指令;
所述一个或多个处理器被配置为经由执行所述一个或多个可执行指令以实现利用权利要求8所述的联邦学习方法或者权利要求9-11任一项所述的联邦学习系统更新好的全局模型进行分类预测。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202110666751.7A CN113379071B (zh) | 2021-06-16 | 2021-06-16 | 一种基于联邦学习的噪声标签修正方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202110666751.7A CN113379071B (zh) | 2021-06-16 | 2021-06-16 | 一种基于联邦学习的噪声标签修正方法 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN113379071A true CN113379071A (zh) | 2021-09-10 |
CN113379071B CN113379071B (zh) | 2022-11-29 |
Family
ID=77574723
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202110666751.7A Active CN113379071B (zh) | 2021-06-16 | 2021-06-16 | 一种基于联邦学习的噪声标签修正方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN113379071B (zh) |
Cited By (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN114827289A (zh) * | 2022-06-01 | 2022-07-29 | 深圳大学 | 一种通信压缩方法、系统、电子装置和存储介质 |
CN115577797A (zh) * | 2022-10-18 | 2023-01-06 | 东南大学 | 一种基于本地噪声感知的联邦学习优化方法及系统 |
WO2023216900A1 (zh) * | 2022-05-13 | 2023-11-16 | 北京字节跳动网络技术有限公司 | 用于模型性能评估的方法、装置、设备和存储介质 |
Citations (6)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN111275207A (zh) * | 2020-02-10 | 2020-06-12 | 深圳前海微众银行股份有限公司 | 基于半监督的横向联邦学习优化方法、设备及存储介质 |
CN112274925A (zh) * | 2020-10-28 | 2021-01-29 | 超参数科技(深圳)有限公司 | Ai模型训练方法、调用方法、服务器及存储介质 |
WO2021022707A1 (zh) * | 2019-08-06 | 2021-02-11 | 深圳前海微众银行股份有限公司 | 一种混合联邦学习方法及架构 |
US20210073639A1 (en) * | 2018-12-04 | 2021-03-11 | Google Llc | Federated Learning with Adaptive Optimization |
CN112862011A (zh) * | 2021-03-31 | 2021-05-28 | 中国工商银行股份有限公司 | 基于联邦学习的模型训练方法、装置及联邦学习系统 |
CN112906911A (zh) * | 2021-02-03 | 2021-06-04 | 厦门大学 | 联邦学习的模型训练方法 |
-
2021
- 2021-06-16 CN CN202110666751.7A patent/CN113379071B/zh active Active
Patent Citations (6)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US20210073639A1 (en) * | 2018-12-04 | 2021-03-11 | Google Llc | Federated Learning with Adaptive Optimization |
WO2021022707A1 (zh) * | 2019-08-06 | 2021-02-11 | 深圳前海微众银行股份有限公司 | 一种混合联邦学习方法及架构 |
CN111275207A (zh) * | 2020-02-10 | 2020-06-12 | 深圳前海微众银行股份有限公司 | 基于半监督的横向联邦学习优化方法、设备及存储介质 |
CN112274925A (zh) * | 2020-10-28 | 2021-01-29 | 超参数科技(深圳)有限公司 | Ai模型训练方法、调用方法、服务器及存储介质 |
CN112906911A (zh) * | 2021-02-03 | 2021-06-04 | 厦门大学 | 联邦学习的模型训练方法 |
CN112862011A (zh) * | 2021-03-31 | 2021-05-28 | 中国工商银行股份有限公司 | 基于联邦学习的模型训练方法、装置及联邦学习系统 |
Cited By (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
WO2023216900A1 (zh) * | 2022-05-13 | 2023-11-16 | 北京字节跳动网络技术有限公司 | 用于模型性能评估的方法、装置、设备和存储介质 |
CN114827289A (zh) * | 2022-06-01 | 2022-07-29 | 深圳大学 | 一种通信压缩方法、系统、电子装置和存储介质 |
CN115577797A (zh) * | 2022-10-18 | 2023-01-06 | 东南大学 | 一种基于本地噪声感知的联邦学习优化方法及系统 |
CN115577797B (zh) * | 2022-10-18 | 2023-09-26 | 东南大学 | 一种基于本地噪声感知的联邦学习优化方法及系统 |
Also Published As
Publication number | Publication date |
---|---|
CN113379071B (zh) | 2022-11-29 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN111124840B (zh) | 业务运维中告警的预测方法、装置与电子设备 | |
CN113379071B (zh) | 一种基于联邦学习的噪声标签修正方法 | |
CN107636690B (zh) | 基于卷积神经网络的全参考图像质量评估 | |
CN109271958B (zh) | 人脸年龄识别方法及装置 | |
US10565525B2 (en) | Collaborative filtering method, apparatus, server and storage medium in combination with time factor | |
CN111178523A (zh) | 一种行为检测方法、装置、电子设备及存储介质 | |
CN112465043B (zh) | 模型训练方法、装置和设备 | |
CN111177473B (zh) | 人员关系分析方法、装置和可读存储介质 | |
US20210374582A1 (en) | Enhanced Techniques For Bias Analysis | |
CN112990478B (zh) | 联邦学习数据处理系统 | |
CN112365007A (zh) | 模型参数确定方法、装置、设备及存储介质 | |
CN111159241B (zh) | 一种点击转化预估方法及装置 | |
CN110688484B (zh) | 一种基于不平衡贝叶斯分类的微博敏感事件言论检测方法 | |
CN115391561A (zh) | 图网络数据集的处理方法、装置、电子设备、程序及介质 | |
CN115730947A (zh) | 银行客户流失预测方法及装置 | |
CN114548300B (zh) | 解释业务处理模型的业务处理结果的方法和装置 | |
CN112836750A (zh) | 一种系统资源分配方法、装置及设备 | |
Ahamed et al. | ATTL: an automated targeted transfer learning with deep neural networks | |
CN115577797A (zh) | 一种基于本地噪声感知的联邦学习优化方法及系统 | |
CN112528500B (zh) | 一种场景图构造模型的评估方法及评估设备 | |
CN112035736B (zh) | 信息推送方法、装置及服务器 | |
CN114170000A (zh) | 信用卡用户风险类别识别方法、装置、计算机设备和介质 | |
CN112308466A (zh) | 企业资质审核方法、装置、计算机设备和存储介质 | |
CN110837847A (zh) | 用户分类方法及装置、存储介质、服务器 | |
CN116610484B (zh) | 一种模型训练方法、故障预测方法、系统、设备以及介质 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |