CN115731424B - 基于强化联邦域泛化的图像分类模型训练方法及系统 - Google Patents

基于强化联邦域泛化的图像分类模型训练方法及系统 Download PDF

Info

Publication number
CN115731424B
CN115731424B CN202211539820.9A CN202211539820A CN115731424B CN 115731424 B CN115731424 B CN 115731424B CN 202211539820 A CN202211539820 A CN 202211539820A CN 115731424 B CN115731424 B CN 115731424B
Authority
CN
China
Prior art keywords
image
sample
strategy
data
federal
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN202211539820.9A
Other languages
English (en)
Other versions
CN115731424A (zh
Inventor
李雅文
管泽礼
薛哲
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Beijing University of Posts and Telecommunications
Original Assignee
Beijing University of Posts and Telecommunications
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Beijing University of Posts and Telecommunications filed Critical Beijing University of Posts and Telecommunications
Priority to CN202211539820.9A priority Critical patent/CN115731424B/zh
Publication of CN115731424A publication Critical patent/CN115731424A/zh
Application granted granted Critical
Publication of CN115731424B publication Critical patent/CN115731424B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Classifications

    • YGENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
    • Y02TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
    • Y02TCLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
    • Y02T10/00Road transport of goods or passengers
    • Y02T10/10Internal combustion engine [ICE] based vehicles
    • Y02T10/40Engine management systems

Landscapes

  • Image Analysis (AREA)
  • Information Retrieval, Db Structures And Fs Structures Therefor (AREA)

Abstract

本申请提供一种基于强化联邦域泛化的图像分类模型训练方法及系统,方法包括:本申请基于强化学习设计了特征去相关策略,将样本加权转化为在联邦学习客户端间共享的参数化策略。通过经验回放,补充特征全局信息,在联邦学习的过程中从全局的角度对特征去相关,各客户端基于加权后的样本训练模型。使训练后的全局模型学习到特征与标签的根本关联,泛化到未知域图像数据。本申请能够针对使联邦学习中的图像分类模型训练过程中的在未知域图像数据进行域具有泛化能力,能够防止图像分类模型在联邦训练过程中学习到数据中的虚假关联,能够有效提高图像分类模型训练过程的有效性及可靠性,进而能够提高应用图像分类模型进行图像分类的有效性及准确性。

Description

基于强化联邦域泛化的图像分类模型训练方法及系统
技术领域
本申请涉及图像处理技术领域,尤其涉及基于强化联邦域泛化的图像分类模型训练方法及系统。
背景技术
联邦学习是用于进行图像分类的有效方式,联邦学习可以在图像数据不出本地的条件下,采用多方协作的方式,学习各个客户端本地图像数据中隐含的知识,共同训练一个有效的图像分类模型。在实际应用中,客户端数据除了类别的不均衡,还存在着数据风格不同得情况,每种风格可以被看作一种域,每个客户端数据的域分布可能是不同的。这种分布的偏移会导致模型退化,主要是模型学习到了特征与标签之间的虚假关联,而这种虚假关联本质上是由与类别标签不相关特征和相关特征之间的关联导致的。针对这种域分布偏移问题,研究人员提出了域泛化技术。它可以在测试集域未知的情况下,训练一个泛化能力强的模型,根本上是学习一个不随着域改变的规则。
由于,特征各维度变量之间的相关性的存在会使小的误差膨胀到任意大,从而导致不同域的测试数据分类性能不稳定。现有的域泛化方式通常采用基于样本加权的方式,通过去除相关特征和不相关特征之间的相关性来实现分布外泛化,但是这种数据全局感知方法,由于保存的特征和权重信息不能在各客户端之间分享,无法有效应用到联邦学习中,而其他能够应用到联邦学习的域泛化方式也无法适用于图像分类任务,进而无法保证基于联邦学习的图像分类模型的训练过程的有效性及可靠性。
发明内容
鉴于此,本申请实施例提供了基于强化联邦域泛化的图像分类模型训练方法及系统,以消除或改善现有技术中存在的一个或更多个缺陷。
本申请的一个方面提供了一种基于强化联邦域泛化的图像分类模型训练方法,包括:
根据本地历史图像数据随机选取对应的补充数据,作为特征去相关策略的全局视角补充;
以本训练批次数据作为局部信息,补充数据作为全局视角补充,在历史图像数据权重固定的条件下,学习一种消除样本特征各维度之间的相关性的样本加权策略;
应用根据强化特征去相关策略学习模块对图像数据加权,并基于加权后的图像训练特征提取器和分类器,以得到更新后的图像分类模型的模型参数;
将所述模型参数发送至自身所在的联邦学习系统中的服务器,以使该服务器对该模型参数和其接收的其他多个模型参数进行聚合以得到所述图像分类模型当前的总模型参数。
在本申请的一些实施例中,在所述根据本地历史图像数据随机选取对应的补充数据,作为特征去相关策略的全局视角补充之前,还包括:
接收自身所在的联邦学习系统中的服务器发送的基于强化联邦域泛化的图像分类模型的总模型参数;
根据所述总模型参数对本地的基于强化联邦域泛化的图像分类模型进行初始化处理;
自历史图像样本中选取预设数量的已设有类型标签的历史图像样本以作为当前的目标图像样本,形成包含有各个所述目标图像样本及对应的类型标签的历史图像数据。
在本申请的一些实施例中,所述根据本地历史图像数据随机选取对应的补充数据,作为特征去相关策略的全局视角补充,包括:
在本地的数据缓冲区中随机选取与所述目标图像样本的数量相同的已设有权重的历史图像样本以作为当前的补充图像样本,形成包含有各个所述补充图像样本和对应的权重的用于补充强化学习环境的补充数据。
在本申请的一些实施例中,所述基于强化联邦域泛化的图像分类模型包括:特征提取器、基于强化学习的策略梯度模块和分类器;
所述特征提取器用于根据输入的各个目标图像样本和各个补充图像样本,对应输出各个目标图像样本和各个补充图像样本各自对应的特征向量;
所述基于强化学习的策略梯度模块用于根据当前的样本加权策略计算得到各个目标图像样本的特征向量各自对应的初始权重,并计算各个所述特征向量之间协方差矩阵的Frobenius范数以确定奖励,根据该奖励更新所述样本加权策略,并根据更新后的样本加权策略对所述初始权重进行优化,得到各个所述目标图像样本的特征向量各自对应的目标权重;
所述分类器用于根据各个目标图像样本的特征向量,对应输出各个所述目标图像样本的类型预测标签,并计算各个所述目标图像样本的类型预测标签和所述类型标签之间的交叉熵,再分别计算各个所述目标图像样本的交叉熵与所述目标权重的乘积之和,得到对应的分类损失。
在本申请的一些实施例中,所述以本训练批次数据作为局部信息,补充数据作为全局视角补充,在历史图像数据权重固定的条件下,学习一种消除样本特征各维度之间的相关性的样本加权策略;应用根据强化特征去相关策略学习模块对图像数据加权,并基于加权后的图像训练特征提取器和分类器,以得到更新后的图像分类模型的模型参数,包括:
将各个所述目标图像样本和各个所述补充图像样本输入所述特征提取器,以使该特征提取器输出各个目标图像样本和各个补充图像样本各自对应的特征向量;
将各个目标图像样本的特征向量输入所述分类器,以使该分类器输出各个所述目标图像样本的类型预测标签;
以及,将各个目标图像样本和各个补充图像样本各自对应的特征向量,以及各个补充图像样本的权重,输入所述基于强化学习的策略梯度模块,以使该策略梯度模块根据当前的样本加权策略计算得到各个目标图像样本的特征向量各自对应的初始权重,并计算各个所述特征向量之间协方差矩阵的Frobenius范数以确定奖励,根据该奖励更新所述样本加权策略,并根据更新后的样本加权策略对所述初始权重进行优化,得到各个所述目标图像样本的特征向量各自对应的目标权重;
计算所述分类器输出的各个所述目标图像样本的类型预测标签和所述类型标签之间的交叉熵,再分别计算各个所述目标图像样本的交叉熵与所述目标权重的乘积之和,得到对应的分类损失,并得到更新后的图像分类模型的模型参数。
在本申请的一些实施例中,所述模型参数包括:所述特征提取器的参数、所述分类器的参数和所述基于强化学习的策略梯度模块的样本加权策略的参数。
本申请的另一个方面提供了一种基于强化联邦域泛化的图像分类模型训练装置,包括:
经验回放模块,用于根据本地历史图像数据随机选取对应的补充数据,作为特征去相关策略的全局视角补充;
强化特征去相关策略学习模块,用于以本训练批次数据作为局部信息,经验回放模块获取的补充数据作为全局视角补充,在历史图像数据权重固定的条件下,学习一种可以消除样本特征各维度之间的相关性的样本加权策略;
强化联邦训练模块,用于应用根据强化特征去相关策略学习模块对图像数据加权,并基于加权后的图像训练特征提取器和分类器,以得到更新后的图像分类模型的模型参数;
数据发送模块,用于将所述模型参数发送至自身所在的联邦学习系统中的服务器,以使该服务器对该模型参数和其接收的其他多个模型参数进行聚合以得到所述图像分类模型当前的总模型参数。
本申请的第三个方面提供了一种用于图像分类的联邦学习系统,包括:服务器和与该服务器之间通信连接的多个客户端设备;
各个所述客户端设备分别用于实现所述的基于强化联邦域泛化的图像分类模型训练方法;
所述服务器用于接收各个所述客户端设备分别在各自本地训练得到的当前训练轮次的模型参数,并对各个所述模型参数进行聚合处理,以得到当前训练轮次的所述图像分类模型当前的总模型参数,并在下一个训练轮次时将所述总模型参数分别发送至各个所述客户端设备。
本申请的第四个方面提供了一种电子设备,包括存储器、处理器及存储在存储器上并可在处理器上运行的计算机程序,所述处理器执行所述计算机程序时实现所述的基于强化联邦域泛化的图像分类模型训练方法。
本申请的第五个方面提供了一种计算机可读存储介质,其上存储有计算机程序,该计算机程序被处理器执行时实现所述的基于强化联邦域泛化的图像分类模型训练方法。
本申请提供的基于强化联邦域泛化的图像分类模型训练方法,根据历史图像数据随机选取对应的补充数据;应用所述历史图像数据和所述补充数据训练本地当前基于强化联邦域泛化的图像分类模型,使得该图像分类模型采用预设的基于强化学习的策略梯度算法对提取到的所述历史图像数据和补充数据分别对应的特征进行域泛化处理,并基于域泛化后的特征训练分类器,以得到更新后的图像分类模型的模型参数;将所述模型参数发送至自身所在的联邦学习系统中的服务器,以使该服务器对该模型参数和其接收的其他多个模型参数进行聚合以得到所述图像分类模型当前的总模型参数;可以帮助联邦模型在未知域数据泛化,能够针对联邦学习中的图像分类任务中的未知域图像数据进行域泛化,实现基于策略的强化特征去相关,能够避免数据域分布与各客户端的数据域分布存在偏移而导致模型退化,能够防止图像分类模型学习到数据中的虚假关联,能够提高联邦学习训练的图像分类模型对未知域图像分类的能力,能够有效提高图像分类模型训练过程的有效性及可靠性,进而能够提高应用图像分类模型进行图像分类的有效性及准确性。
本申请的附加优点、目的,以及特征将在下面的描述中将部分地加以阐述,且将对于本领域普通技术人员在研究下文后部分地变得明显,或者可以根据本申请的实践而获知。本申请的目的和其它优点可以通过在说明书以及附图中具体指出的结构实现到并获得。
本领域技术人员将会理解的是,能够用本申请实现的目的和优点不限于以上具体所述,并且根据以下详细说明将更清楚地理解本申请能够实现的上述和其他目的。
附图说明
此处所说明的附图用来提供对本申请的进一步理解,构成本申请的一部分,并不构成对本申请的限定。附图中的部件不是成比例绘制的,而只是为了示出本申请的原理。为了便于示出和描述本申请的一些部分,附图中对应部分可能被放大,即,相对于依据本申请实际制造的示例性装置中的其它部件可能变得更大。在附图中:
图1为本申请一实施例中的基于强化联邦域泛化的图像分类模型训练方法的第一种流程示意图。
图2为本申请一实施例中的基于强化联邦域泛化的图像分类模型训练方法的第二种具体流程示意图。
图3为本申请一实施例中的基于强化联邦域泛化的图像分类模型的结构示意图。
图4为本申请另一实施例中的基于强化联邦域泛化的图像分类模型训练装置的结构示意图。
图5为本申请另一实施例中的用于图像分类的联邦学习系统的架构示意图。
具体实施方式
为使本申请的目的、技术方案和优点更加清楚明白,下面结合实施方式和附图,对本申请做进一步详细说明。在此,本申请的示意性实施方式及其说明用于解释本申请,但并不作为对本申请的限定。
在此,还需要说明的是,为了避免因不必要的细节而模糊了本申请,在附图中仅仅示出了与根据本申请的方案密切相关的结构和/或处理步骤,而省略了与本申请关系不大的其他细节。
应该强调,术语“包括/包含”在本文使用时指特征、要素、步骤或组件的存在,但并不排除一个或更多个其它特征、要素、步骤或组件的存在或附加。
在此,还需要说明的是,如果没有特殊说明,术语“连接”在本文不仅可以指直接连接,也可以表示存在中间物的间接连接。
在下文中,将参考附图描述本申请的实施例。在附图中,相同的附图标记代表相同或类似的部件,或者相同或类似的步骤。
深度学习模型的效果与数据的质与量是高度相关的。模型可以从高质量有代表性的数据中心学习有用的知识。但是,高质量的数据通常掌握在众多公司、组织与设备中,由于隐私、法规与利益等因素,这些数据不能在各方自由流动,难以集中起来训练模型。
联邦学习可以在数据不出本地的条件下,采用多方协作的方式,学习各个客户端数据中隐含的知识,共同训练一个有效的模型。联邦学习的核心问题是非独立同分布问题,现有方法主要关注各客户端训练数据分布的差异。有研究人员从样本类别不均衡和本地异常更新限制的角度来解决客户端数据非独立同分布的问题,取得了较好的效果。但是在实际应用中,客户端数据除了类别的不均衡,还存在着数据风格不同得情况,每种风格可以被看作一种域,每个客户端数据的域分布可能是不同的。特别是联邦学习模型,由于各客户端收集数据的时间和策略不同,实际应用中通常存在较严重的非独立同分布文题。并且,联邦模型通常还面临着应用目标数据集不可见,也就是测试数据域分布会与各客户端的域分布可能发生偏移是情况。这种分布的偏移会导致模型退化,主要是模型学习到了特征与标签之间的虚假关联,而这种虚假关联本质上是由与类别标签不相关特征和相关特征之间的关联导致的。
针对这种域分布偏移问题,研究人员提出了域泛化技术。它可以在测试集域未知的情况下,训练一个泛化能力强的模型,根本上是学习一个不随着域改变的规则。传统的域泛化模型,通常包含了域平衡的假设或需要手工标记域信息的标签,这些条件在实际应用中都很难达到。
具体来说,近年来有研究人员证明了,特征之间相关性的存在,可以将一个小的虚假关联误差膨胀到任意大,从而导致不同分布式测试数据的预测性能不稳定。根据该发现他们提出了基于样本加权的域泛化模型,通过去除相关特征和不相关特征之间的相关性来实现分布外泛化。该方法进一步将此类基于样本加权去相关的方法扩展到深度学习模型中,通过迭代地保存和重载数据的特征和权重来全局感知和消除特征相关性,帮助深度学习模型特征与标签之间的虚假相关性。但是,这种数据全局感知方法,由于保存的特征和权重信息不能在各客户端之间分享,不能应用到联邦学习中。如何让每个客户端上域泛化方法感知数据的全局信息是一个新的挑战。
另外,还有研究将域泛化引入到联邦学习中,应用到物体边缘识别任务中,在未知域中测试取得了很好的效果。另有学者采用生成对抗的方式,学习一个全局生成器,在不直接共享数据的前提下,使各个客户端可以在本地训练过程中感知全局数据。然而,第一种方法只针对物体边缘识别任务,不适用于图像分类任务。第二种,通过生成全局数据的方法并不能防止模型学到虚假关联。并且,这两种方法都假设每个客户端的数据为同一个域,这些条件在实际应用中都很难达到。
本申请考虑如何实现适用于基于联邦学习的图像分类任务的域泛化方式,并将这种域泛化方式用在基于联邦学习的图像分类模型训练过程中,来针对联邦学习中的图像分类任务中的未知域图像数据进行域泛化,实现基于策略的强化特征去相关。
通过大量的研究分析工作,本申请首先提出了一种基于强化学习的域泛化方式,并将这种方式改进后用在基于联邦学习的图像分类任务中,这种强化学习应用在联邦学习过程中实现域泛化的方式,可以称之为强化联邦域泛化。在此基础上,本申请进一步提出基于强化联邦域泛化的图像分类模型训练方法,以针对联邦学习中的图像分类任务,实现基于策略的强化特征去相关,以解决在联邦学习框架下模型应用目标数据集不可见,测试数据域分布与各客户端的数据域分布存在偏移,导致模型退化的问题。
在本申请的一个或多个实施例中,强化学习RL(Reinforcement Learning)是一种可以通过智能体与环境的交互,根据反馈学习一种策略的技术,以奖励最大化为目标,迭代更新智能体策略,智能体每一个动作都会从环境中获得奖励值的反馈。它在机器人、控制和在线学习领域有着广泛的应用。近年来,有研究将强化学习应用到联邦学习的通信策略和本地学习中,还有研究通过强化学习样本选择策略,取得了不错的效果。强化学习可以很好的根据环境与反馈,学习到一种策略,这种策略可以表示为一组参数,这种由参数表示的策略,就可以通过联邦学习的方式迭代训练。
在本申请的一个或多个实施例中,域泛化可以从多个已知域的数据集中学习到可以泛化到其他领域的模型,本质上是学习域不变的特征与标签之间的关联。
在本申请的一个或多个实施例中,联邦学习可以采用多方协作的方式,在不需要分享本地数据的条件下,训练一个全局模型。最有代表性的方法是FedAvg,它先通过服务端参数下发对客户端参数进行初始化,采用客户端的本地数据训练本地模型并发送到服务端,服务端对收到的参数进行聚合,完成一轮的联邦训练。
具体通过下述实施例进行详细说明。
本申请实施例提供一种可以由基于强化联邦域泛化的图像分类模型训练装置执行的基于强化联邦域泛化的图像分类模型训练方法,参见图1,所述基于强化联邦域泛化的图像分类模型训练方法具体包含有如下内容:
步骤100:根据本地历史图像数据随机选取对应的补充数据,作为特征去相关策略的全局视角补充。
步骤200:以本训练批次数据作为局部信息,补充数据作为全局视角补充,在历史图像数据权重固定的条件下,学习一种消除样本特征各维度之间的相关性的样本加权策略。
在步骤100和步骤200中,历史图像数据(也可以称之为目标历史图像数据)是指在客户端设备本地的历史图像数据中选取的当前轮次用于训练图形分类模型的训练用数据,且这些历史图像数据的域均是未知状态并设有对应的分类标签。分类标签是指用于标识或描述图像数据所属类型的标签,具体可以根据实际应用需求设置。本训练批次数据是客户端本地图像数据,历史数据是经过加权后的本地图像数据。
可以理解的是,补充数据可以指在客户端设备本地的以用于训练图形分类模型后的历史图像数据,还可以包含有这些历史图像数据的特征向量对应的权重,但不包含特征向量,是因为每个轮次训练后,特征提取器的参数已经发生了变化,特征的空间也会发生偏移,所以本申请选择保留历史图像数据作为补充数据,能够进一步提高联邦学习训练的图像分类模型对未知域图像分类的能力,并能够有效提高图像分类模型训练过程的有效性及可靠性。
另外,步骤100不同于现有的单纯的加入以往数据重新训练,而是随机在历史图像数据中随机选取一个轮次的补充数据作为新数据加权的条件,能够有效提高补充强化学习环境信息的有效性及可靠性。
步骤300:应用根据强化特征去相关策略学习模块对图像数据加权,并基于加权后的图像训练特征提取器和分类器,以得到更新后的图像分类模型的模型参数。
在步骤300中,基于强化联邦域泛化的图像分类模型是指能够采用强化学习的方式对联邦学习中的图像分类训练任务进行数据域泛化并实现图像分类的模型。
可以理解的是,在步骤300之后,步骤100中选用的历史图像数据可以存储在数据缓存区中,作为下一个轮次的补充数据的备选之一。
步骤400:将所述模型参数发送至自身所在的联邦学习系统中的服务器,以使该服务器对该模型参数和其接收的其他多个模型参数进行聚合以得到所述图像分类模型当前的总模型参数。
在本申请的一个或多个实施例中,所述模型参数或所述总模型参数均包括:所述特征提取器的参数、所述分类器的参数和所述基于强化学习的策略梯度模块的样本加权策略的参数。
从上述描述可知,本申请实施例提供的基于强化联邦域泛化的图像分类模型训练方法,能够针对联邦学习中的图像分类任务中的未知域图像数据进行域泛化,实现基于策略的强化特征去相关,能够避免数据域分布与各客户端的数据域分布存在偏移而导致模型退化,能够防止图像分类模型学习到数据中的虚假关联,能够提高联邦学习训练的图像分类模型对未知域图像分类的能力,能够有效提高图像分类模型训练过程的有效性及可靠性,进而能够提高应用图像分类模型进行图像分类的有效性及准确性。
为了进一步图像分类模型训练的有效性及可靠性,在本申请实施例提供的一种基于强化联邦域泛化的图像分类模型训练方法中,参见图2,所述基于强化联邦域泛化的图像分类模型训练方法中的步骤100之前还具体包含有如下内容:
步骤010:接收自身所在的联邦学习系统中的服务器发送的基于强化联邦域泛化的图像分类模型的总模型参数。
步骤020:根据所述总模型参数对本地的基于强化联邦域泛化的图像分类模型进行初始化处理。
步骤030:自历史图像样本中选取预设数量的已设有类型标签的历史图像样本以作为当前的目标图像样本,形成包含有各个所述目标图像样本及对应的类型标签的历史图像数据。
为了进一步训练数据选取的有效性及可靠性,在本申请实施例提供的一种基于强化联邦域泛化的图像分类模型训练方法中,参见图2,所述基于强化联邦域泛化的图像分类模型训练方法中的步骤100具体包含有如下内容:
步骤110:在本地的数据缓冲区中随机选取与所述目标图像样本的数量相同的已设有权重的历史图像样本以作为当前的补充图像样本,形成包含有各个所述补充图像样本和对应的权重的用于补充强化学习环境的补充数据。
为了进一步图像分类模型的应用有效性及可靠性,在本申请实施例提供的一种基于强化联邦域泛化的图像分类模型训练方法中,参见图3,所述基于强化联邦域泛化的图像分类模型包括:特征提取器、基于强化学习的策略梯度模块和分类器。
所述特征提取器用于根据输入的各个目标图像样本和各个补充图像样本,对应输出各个目标图像样本和各个补充图像样本各自对应的特征向量;其中,所述特征提取器可以称之为特征提取模块。
所述基于强化学习的策略梯度模块用于根据当前的样本加权策略计算得到各个目标图像样本的特征向量各自对应的初始权重,并计算各个所述特征向量之间协方差矩阵的Frobenius范数以确定奖励,根据该奖励更新所述样本加权策略,并根据更新后的样本加权策略对所述初始权重进行优化,得到各个所述目标图像样本的特征向量各自对应的目标权重;可以理解的是,Frobenius范数(F-范数)是一种矩阵范数,即矩阵中每项数的平方和的开方值。
所述分类器用于根据各个目标图像样本的特征向量,对应输出各个所述目标图像样本的类型预测标签,并计算各个所述目标图像样本的类型预测标签和所述类型标签之间的交叉熵,再分别计算各个所述目标图像样本的交叉熵与所述目标权重的乘积之和,得到对应的分类损失。
基于上述基于强化联邦域泛化的图像分类模型的结构,为了进一步图像分类模型训练的有效性及可靠性,在本申请实施例提供的一种基于强化联邦域泛化的图像分类模型训练方法中,参见图2,所述基于强化联邦域泛化的图像分类模型训练方法中的步骤200和步骤300具体包含有如下内容:
步骤210:将各个所述目标图像样本和各个所述补充图像样本输入所述特征提取器,以使该特征提取器输出各个目标图像样本和各个补充图像样本各自对应的特征向量。
具体来说,特征提取需要将难以直接计算的图像信息转化为易于计算的向量形式,本质上是学习一种从图像到特征H的映射:f:x→h,其中,mh是特征维度。
特征提取模块需要有提取图像丰富的深层特征的能力。以往的工作设计了众多通用的特征选择模型结构,如VGG和Resnet,其中Resnet在众多应用中展示的优异效果,充分说明了它提取图像深层特征的能力。本申请选择Resnet作为本申请的特征提取器,客户端设备(client)中的特征提取可以形式化定义为:
s=f(x)
步骤220:将各个目标图像样本的特征向量输入所述分类器,以使该分类器输出各个所述目标图像样本的类型预测标签。
具体来说,分类器c由一个单层感知机来构造,输入维度为特征维度,输出维度为样本类别数量。图像通过特征提取模块得到样本特征,将样本特征输入图像分类器,得到样本预测的标签,根据预测结果与真实标签计算交叉熵,乘以样本权重w作为分类损失Lc
通过对样本加权降低特征相关性,相关的特征不会由于特定数据的原因被放大与强化,使模型学习到错误的不能在未知域上泛化的规则。
步骤230:将各个目标图像样本和各个补充图像样本各自对应的特征向量,以及各个补充图像样本的权重,输入所述基于强化学习的策略梯度模块,以使该策略梯度模块根据当前的样本加权策略计算得到各个目标图像样本的特征向量各自对应的初始权重,并计算各个所述特征向量之间协方差矩阵的Frobenius范数以确定奖励,根据该奖励更新所述样本加权策略,并根据更新后的样本加权策略对所述初始权重进行优化,得到各个所述目标图像样本的特征向量各自对应的目标权重。
具体来说,通过强化学习的策略梯度算法,学习一种可以使样本特征尽可能独立的样本加权策略。它根据策略智能体与环境的交互产生的基于特征独立性计算的奖励,学习样本到权重的计算策略。这种策略是参数化的,可以通过联邦学习分享与聚合的。本申请通过独立性检验,衡量特征之间的相关性。以消除特征之间的依赖关系为目标,使模型学习到可以泛化的知识。
本方法采用强化学习的策略梯度方法来拟合这种策略。策略梯度主要包含三个元素,状态、环境、反馈和动作,智能体通过与环境的交互感知状态,根据状态与规则产生动作,然后根据动作环境会产生反馈,智能体根据状态和对应的反馈调整参数,来学习参数化的规则。
步骤240:计算所述分类器输出的各个所述目标图像样本的类型预测标签和所述类型标签之间的交叉熵,再分别计算各个所述目标图像样本的交叉熵与所述目标权重的乘积之和,得到对应的分类损失,并得到更新后的图像分类模型的模型参数。
从软件层面来说,本申请还提供一种用于执行所述基于强化联邦域泛化的图像分类模型训练方法中全部或部分内的基于强化联邦域泛化的图像分类模型训练装置,参见图4,所述基于强化联邦域泛化的图像分类模型训练装置具体包含有如下内容:
经验回放模块10,用于根据本地历史图像数据随机选取对应的补充数据,作为特征去相关策略的全局视角补充;
强化特征去相关策略学习模块20,用于以本训练批次数据作为局部信息,经验回放模块获取的补充数据作为全局视角补充,在历史图像数据权重固定的条件下,学习一种可以消除样本特征各维度之间的相关性的样本加权策略;
强化联邦训练模块30,用于应用根据强化特征去相关策略学习模块对图像数据加权,并基于加权后的图像训练特征提取器和分类器,以得到更新后的图像分类模型的模型参数;
数据发送模块40,用于将所述模型参数发送至自身所在的联邦学习系统中的服务器,以使该服务器对该模型参数和其接收的其他多个模型参数进行聚合以得到所述图像分类模型当前的总模型参数。
本申请提供的基于强化联邦域泛化的图像分类模型训练装置的实施例具体可以用于执行上述实施例中的基于强化联邦域泛化的图像分类模型训练方法的实施例的处理流程,其功能在此不再赘述,可以参照上述基于强化联邦域泛化的图像分类模型训练方法实施例的详细描述。
所述基于强化联邦域泛化的图像分类模型训练装置进行基于强化联邦域泛化的图像分类模型训练的部分可以在客户端设备中完成。具体可以根据所述客户端设备的处理能力,以及用户使用场景的限制等进行选择。本申请对此不作限定。若所有的操作都在所述客户端设备中完成,所述客户端设备还可以包括处理器,用于基于强化联邦域泛化的图像分类模型训练的具体处理。
上述的客户端设备可以具有通信模块(即通信单元),可以与远程的服务器进行通信连接,实现与所述服务器的数据传输。所述服务器可以包括任务调度中心一侧的服务器,其他的实施场景中也可以包括中间平台的服务器,例如与任务调度中心服务器有通信链接的第三方服务器平台的服务器。所述的服务器可以包括单台计算机设备,也可以包括多个服务器组成的服务器集群,或者分布式装置的服务器结构。
上述服务器与所述客户端设备端之间可以使用任何合适的网络协议进行通信,包括在本申请提交日尚未开发出的网络协议。所述网络协议例如可以包括TCP/IP协议、UDP/IP协议、HTTP协议、HTTPS协议等。当然,所述网络协议例如还可以包括在上述协议之上使用的RPC协议(Remote Procedure Call Protocol,远程过程调用协议)、REST协议(Representational State Transfer,表述性状态转移协议)等。
从上述描述可知,本申请实施例提供的基于强化联邦域泛化的图像分类模型训练装置,能够针对联邦学习中的图像分类任务中的未知域图像数据进行域泛化,实现基于策略的强化特征去相关,能够避免数据域分布与各客户端的数据域分布存在偏移而导致模型退化,能够防止图像分类模型学习到数据中的虚假关联,能够提高联邦学习训练的图像分类模型对未知域图像分类的能力,能够有效提高图像分类模型训练过程的有效性及可靠性,进而能够提高应用图像分类模型进行图像分类的有效性及准确性。
另外,本申请还提供一种用于图像分类的联邦学习系统的实施例,参见图5,所述联邦学习系统具体包含有如下内容:
服务器和与该服务器之间通信连接的多个客户端设备;在图5中,多个客户端设备可以包含有客户端1至客户端P,P为大于2的正整数。
各个所述客户端设备分别用于实现前述实施例提及的所述的基于强化联邦域泛化的图像分类模型训练方法;
所述服务器用于接收各个所述客户端设备分别在各自本地训练得到的当前训练轮次的模型参数,并对各个所述模型参数进行聚合处理,以得到当前训练轮次的所述图像分类模型当前的总模型参数,并在下一个训练轮次时将所述总模型参数分别发送至各个所述客户端设备。
具体来说,服务端聚合,接收客户端发回的参数,根据客户端数据量对参数加权后聚合,将聚合后的模型参数发回客户端。在首次迭代时,对特征提取器、分类器和基于强化学习的策略梯度模块的参数进行初始化得到f0,c0和Policy0,并下发给各个联邦客户端,各客户端训练后,将训练好的参数返回服务端,服务端根据客户端的数据量占数据总量的比例对每个客户端返回的参数加权后求和,作为新一轮迭代的初始参数。
客户端训练,接收服务端发送的聚合参数来初始化本地模型,每训练一个时期(epoch),将模型参数发回服务端。客户端训练,从客户端i的“localdatasetDi”获取一个批次的数据。通过特征提取器fi提取图像特征,以特征之间协方差矩阵的Frobenius范数为损失,训练基于策略梯度的强化学习样本加权策略,使特征之间去相关。通过样本加权策略得到样本权重,并计算分类器得到的预测标签与真实标签的交叉熵,与样本权重相乘求和作为分类损失。更新特征提取、分类器和策略的参数,发送回服务端。
为了进一步说明本方案,本申请还提供一种基于强化联邦域泛化的图像分类模型训练方法及用于图像分类的联邦学习系统的具体应用实例,本申请应用实例在通过对训练样本加权,使训练样本的每一维特征尽可能独立,联邦的训练图像分类模型,学习一个不随着域改变的图像分类规则。将样本加权工作通过强化学习技术转化为一个参数化的可以在在各客户端共享的样本加权策略,可以更好地从全局的角度对特征去相关。本方法基于联邦学习框架,支持以多方协作的方式,以加权后的样本联邦的训练模型,从多个客户端私有数据集中,学习域不变的知识。
参见图5,用于图像分类的联邦学习系统中的客户端包括有三个部分,特征提取、基于策略的强化特征去相关和分类器,通过特征提取模块提取图像特征,采用强化特征去相关技术,学习一个可以使样本特征尽可能独立的样本加权参数化的策略,来预测样本权重。分类器计算加权后的损失,更新模型特征提取和分类器参数,使模型可以学习到域不变的知识。服务端主要对客户端中的三个部分的参数采用FedAvg方法进行聚合,分享编码与分类知识的同时也分享样本加权的策略。在联邦学习任务中,本申请应用实例通常有一个服务器(serverS)和一组客户端设备(client);C={C1,C2...Cn},每个客户端设备(client)独自维护一个包含标签的图像数据集(Xi,Yi),其中i表示数据集属于客户端设备(client)i,mx1mx2是图像的长和宽,d是图像的通道数量,my是类别数量。在本地训练过程中,本申请应用实例通过批量方法训练本地模型,在一个轮次(batch)中包含mb条数据,通过特征提取模块得到一组特征/>本申请应用实例可以将第k维特征表示为:/>其中,/>为特征/>第k维的元素。
具体来说,所述用于图像分类的联邦学习系统的具体内容如下:
(一)特征提取
特征提取需要将难以直接计算的图像信息转化为易于计算的向量形式,本质上是学习一种从图像到特征H的映射:f:x→h,其中,mh是特征维度。特征提取模块需要有提取图像丰富的深层特征的能力。以往的工作设计了众多通用的特征选择模型结构,如VGG和Resnet,其中Resnet在众多应用中展示的优异效果,充分说明了它提取图像深层特征的能力。本申请应用实例选择Resnet作为本申请应用实例的特征提取器,客户端设备(client)中的特征提取可以形式化定义为:
s=f(x)
(二)基于策略的强化特征去相关
通过强化学习的策略梯度算法,学习一种可以使样本特征尽可能独立的样本加权策略。它根据策略智能体与环境的交互产生的基于特征独立性计算的奖励,学习样本到权重的计算策略。这种策略是参数化的,可以通过联邦学习分享与聚合的。本申请应用实例通过独立性检验,衡量特征之间的相关性。以消除特征之间的依赖关系为目标,使模型学习到可以泛化的知识。
本方法采用强化学习的策略梯度方法来拟合这种策略。策略梯度主要包含三个元素,状态、环境、反馈和动作,智能体通过与环境的交互感知状态,根据状态与规则产生动作,然后根据动作环境会产生反馈,智能体根据状态和对应的反馈调整参数,来学习参数化的规则。
(1)状态,样本的状态是样本在本轮的特征s,由于特征提取模块的参数每一轮是会发生变化的,所以本申请应用实例每次接收一个轮次(batch)的样本,并且重新提取特征。
(2)动作,在本文中动作空间的设计是连续的,智能体根据样本的状态对每个样本赋予权重。本申请应用实例构造了一个参数化的策略函数,来捕获从样本特征到权重的映射策略:
a=π(s|θp)
其中,π为参数化的强化学习策略,θp为参数,a为根据样本的状态信息得到样本加权动作,也就是样本权重。
(3)奖励,根据状态与动作,反馈一定的奖励,以奖励最大化为目标,更新策略的参数。样本加权后特征之间越独立,奖励值越高,反之奖励值越低。satblenet通过协方差来计算特征之间的相关性,利用随机傅立叶特征(RFF)和样本加权的特性,消除了特征之间的线性和非线性依赖关系。本申请应用实例将经过RFF特征映射的g定义为
本申请应用实例采用对样本加权的方式,使特征之间尽可能的独立,添加了样本权重的协方差矩阵计算方式:
其中,n为样本数量,am为第m个样本的权重,为m个样本的第j维特征经过RFF映射后的结果。
所有特征之间的独立性计算结果可以通过协方差矩阵的Frobenius范数来计算,结果越接近于0,特征越独立,根据特征之间独立性评分构建reward函数:
(4)经验回放,理想状态下的特征去相关是需要对每个样本加权后,使特征之间独立的。但是由于联邦学习数据不出本地的限制,不能直接从全局的角度计算样本权重。为了更好的拟合全局样本加权策略,本申请应用实例设计了新的经验回放方法,在计算reward时考虑历史数据,构建一个历史数据缓冲区B保留历史数据与权重。在历史数据已经被赋予权重的条件下,学习新数据的加权策略。而与传统不同,本方法的经验回放不是单纯的加入以往数据重新训练,而是随机在历史数据缓冲区中提取一个轮次(batch)的数据sr与对应权重ar,作为新数据加权的条件,补充强化学习环境信息。也就是在训练时,连接随机提取的数据与当前轮次(batch)的数据se的特征,固定历史数据的权重ar,计算通过策略π得到的当前轮次(batch)数据的权重ae,计算本轮的reward,模型以reward最大化为目标函数训练模型,如式:
本方法在缓冲区中直接保留的是数据,没有直接保留数据特征作为历史信息的补充,是因为每个轮次(batch)训练后,特征提取模块的参数已经发生了变化,特征的空间也会发生偏移,所以本申请应用实例选择保留历史数据,每一个轮次(batch)重新提取特征的方式保留历史信息。
由于样本加权策略的学习是一个连续空间优化问题,所以本方法采用策略梯度算法,与经典的DQN算法通过近似估算状态-动作值函数π(s|θp)来推断最优策略不同,策略梯度方法则是直接优化策略参数,可以根据简单的状态、动作和奖励来更好地拟合策略。
(三)分类器
分类器c由一个单层感知机来构造,输入维度为特征维度,输出维度为样本类别数量。图像通过特征提取模块得到样本特征,将样本特征输入图像分类器,得到样本预测的标签,根据预测结果与真实标签计算交叉熵,乘以样本权重w作为分类损失Lc
通过对样本加权降低特征相关性,相关的特征不会由于特定数据的原因被放大与强化,使模型学习到错误的不能在未知域上泛化的规则。
进一步地,基于所述用于图像分类的联邦学习系统,本申请应用提供的基于强化联邦域泛化的图像分类模型训练方法的具体内容如下:
联邦学习主要分为两个过程客户端训练和服务端聚合,具体步骤如表1所述的算法1。
(1)服务端聚合,接收客户端发回的参数,根据客户端数据量对参数加权后聚合,将聚合后的模型参数发回客户端。在首次迭代时,对特征提取、分类器和策略的参数进行初始化得到f0,c0和Policy0,并下发给各个联邦客户端,各客户端训练后,将训练好的参数返回服务端,服务端根据客户端的数据量占数据总量的比例对每个客户端返回的参数加权后求和,作为新一轮迭代的初始参数。
(2)客户端训练,接收服务端发送的聚合参数来初始化本地模型,每训练一个epoch,将模型参数发回服务端。客户端训练,从客户端i的“localdatasetDi”获取一个批次的数据。通过特征提取器fi提取图像特征,以特征之间协方差矩阵的Frobenius范数为损失,训练基于策略梯度的强化学习样本加权策略,使特征之间去相关。通过样本加权策略得到样本权重,并计算分类器得到的预测标签与真实标签的交叉熵,与样本权重相乘求和作为分类损失。更新特征提取、分类器和策略的参数,发送回服务端。
具体来说,客户端的完整训练过程如下:
接收服务端发送的模型参数来初始化本地模型。根据数据缓冲区中数据量决定缓冲区初始化方式,当缓冲区非空时随机保留一个轮次(batch)的数据以及数据对应的权重,当缓冲区中没有数据时不采取操作(只有在联邦模型第一轮训练的初始化会出现缓冲区为空的情况)。从客户端i的本地数据Di获取一个批次的数据{x,y},其中x,y分别为一个轮次(batch)的数据集合和标签集合,数据量为batch size,从数据缓冲区中随机选择batchsize数量的样本集合xe,r与他们的权重集合ae,r,作为历史信息的补充。通过特征提取器fi提取数据集合x和xe,r中图像的特征得到特征集合se和se,r,固定权重集合ae,r,通过样本加权策略πi得到x中样本的权重ae。计算特征之间协方差矩阵的Frobenius范数,取它的负值作为策略的奖励,更新样本加权策略πi的参数,重新计算样本权重ae,使特征之间去相关。计算分类器根据输入se得到的预测标签与真实标签y的交叉熵,与样本权重相乘求和,作为分类损失。更新特征提取、分类器的参数。特征提取、分类器和样本加权策略的参数将发送回服务端。其中,e和t分别表示迭代的epoch和batch编号。
表1
/>
本申请应用实例提出一种强化联邦域泛化方法,首次将基于样本重加权的特征去相关技术引入联邦学习中,防止模型学习到数据中的虚假关联,提高了联邦训练的图像分类模型对未知域图像分类的能力。还设计了基于强化学习的特征去相关策略学习方法,通过策略梯度方法学习参数化的样本到权重的映射策略,通过样本加权策略参与联邦训练的方式,进一步从全局的角度强化了特征去相关工作。
综上所述,本申请应用实例提供的基于强化联邦域泛化的图像分类模型训练方法,针对联邦学习中的图像分类任务,提出了基于强化联邦域泛化的图像分类模型,在目标数据域不可知的条件下,学习域不变的图像分类知识,提高了联邦训练的图像分类模型对未知域图像分类的泛化能力。可以实现基于策略的强化特征去相关方法,以样本的特征为状态信息,设计可以根据样本状态得到样本权重的参数化策略,以特征独立性最大化为目标设计奖励函数,学习样本加权策略,并采用经验回放技巧从全局的角度对特征去相关。
本申请实施例还提供了一种电子设备(也即电子设备),例如中心服务器,该电子设备可以包括处理器、存储器、接收器及发送器,处理器用于执行上述实施例提及的基于强化联邦域泛化的图像分类模型训练方法,其中处理器和存储器可以通过总线或者其他方式连接,以通过总线连接为例。该接收器可通过有线或无线方式与处理器、存储器连接。
处理器可以为中央处理器(Central Processing Unit,CPU)。处理器还可以为其他通用处理器、数字信号处理器(Digital Signal Processor,DSP)、专用集成电路(Application Specific Integrated Circuit,ASIC)、现场可编程门阵列(Field-Programmable Gate Array,FPGA)或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件等芯片,或者上述各类芯片的组合。
存储器作为一种非暂态计算机可读存储介质,可用于存储非暂态软件程序、非暂态计算机可执行程序以及模块,如本申请实施例中的基于强化联邦域泛化的图像分类模型训练方法对应的程序指令/模块。处理器通过运行存储在存储器中的非暂态软件程序、指令以及模块,从而执行处理器的各种功能应用以及数据处理,即实现上述方法实施例中的基于强化联邦域泛化的图像分类模型训练方法。
存储器可以包括存储程序区和存储数据区,其中,存储程序区可存储操作系统、至少一个功能所需要的应用程序;存储数据区可存储处理器所创建的数据等。此外,存储器可以包括高速随机存取存储器,还可以包括非暂态存储器,例如至少一个磁盘存储器件、闪存器件、或其他非暂态固态存储器件。在一些实施例中,存储器可选包括相对于处理器远程设置的存储器,这些远程存储器可以通过网络连接至处理器。上述网络的实例包括但不限于互联网、企业内部网、局域网、移动通信网及其组合。
所述一个或者多个模块存储在所述存储器中,当被所述处理器执行时,执行实施例中的基于强化联邦域泛化的图像分类模型训练方法。
在本申请的一些实施例中,用户设备可以包括处理器、存储器和收发单元,该收发单元可包括接收器和发送器,处理器、存储器、接收器和发送器可通过总线系统连接,存储器用于存储计算机指令,处理器用于执行存储器中存储的计算机指令,以控制收发单元收发信号。
作为一种实现方式,本申请中接收器和发送器的功能可以考虑通过收发电路或者收发的专用芯片来实现,处理器可以考虑通过专用处理芯片、处理电路或通用芯片实现。
作为另一种实现方式,可以考虑使用通用计算机的方式来实现本申请实施例提供的服务器。即将实现处理器,接收器和发送器功能的程序代码存储在存储器中,通用处理器通过执行存储器中的代码来实现处理器,接收器和发送器的功能。
本申请实施例还提供一种计算机可读存储介质,其上存储有计算机程序,该计算机程序被处理器执行时以实现前述基于强化联邦域泛化的图像分类模型训练方法的步骤。该计算机可读存储介质可以是有形存储介质,诸如随机存储器(RAM)、内存、只读存储器(ROM)、电可编程ROM、电可擦除可编程ROM、寄存器、软盘、硬盘、可移动存储盘、CD-ROM、或技术领域内所公知的任意其它形式的存储介质。
本领域普通技术人员应该可以明白,结合本文中所公开的实施方式描述的各示例性的组成部分、系统和方法,能够以硬件、软件或者二者的结合来实现。具体究竟以硬件还是软件方式来执行,取决于技术方案的特定应用和设计约束条件。专业技术人员可以对每个特定的应用来使用不同方法来实现所描述的功能,但是这种实现不应认为超出本申请的范围。当以硬件方式实现时,其可以例如是电子电路、专用集成电路(ASIC)、适当的固件、插件、功能卡等等。当以软件方式实现时,本申请的元素是被用于执行所需任务的程序或者代码段。程序或者代码段可以存储在机器可读介质中,或者通过载波中携带的数据信号在传输介质或者通信链路上传送。
需要明确的是,本申请并不局限于上文所描述并在图中示出的特定配置和处理。为了简明起见,这里省略了对已知方法的详细描述。在上述实施例中,描述和示出了若干具体的步骤作为示例。但是,本申请的方法过程并不限于所描述和示出的具体步骤,本领域的技术人员可以在领会本申请的精神后,作出各种改变、修改和添加,或者改变步骤之间的顺序。
本申请中,针对一个实施方式描述和/或例示的特征,可以在一个或更多个其它实施方式中以相同方式或以类似方式使用,和/或与其他实施方式的特征相结合或代替其他实施方式的特征。
以上所述仅为本申请的优选实施例,并不用于限制本申请,对于本领域的技术人员来说,本申请实施例可以有各种更改和变化。凡在本申请的精神和原则之内,所作的任何修改、等同替换、改进等,均应包含在本申请的保护范围之内。

Claims (6)

1.一种基于强化联邦域泛化的图像分类模型训练方法,其特征在于,包括:
根据本地历史图像数据随机选取对应的补充数据,作为特征去相关策略的全局视角补充;
以本训练批次数据作为局部信息,补充数据作为全局视角补充,在历史图像数据权重固定的条件下,学习一种消除样本特征各维度之间的相关性的样本加权策略;
应用根据强化特征去相关策略学习模块对图像数据加权,并基于加权后的图像训练特征提取器和分类器,以得到更新后的图像分类模型的模型参数;
将所述模型参数发送至自身所在的联邦学习系统中的服务器,以使该服务器对该模型参数和其接收的其他多个模型参数进行聚合以得到所述图像分类模型当前的总模型参数以应用所述图像分类模型进行图像分类;
在所述根据本地历史图像数据随机选取对应的补充数据,作为特征去相关策略的全局视角补充之前,还包括:
接收自身所在的联邦学习系统中的服务器发送的基于强化联邦域泛化的图像分类模型的总模型参数;
根据所述总模型参数对本地的基于强化联邦域泛化的图像分类模型进行初始化处理;
自历史图像样本中选取预设数量的已设有类型标签的历史图像样本以作为当前的目标图像样本,形成包含有各个所述目标图像样本及对应的类型标签的历史图像数据;
所述根据本地历史图像数据随机选取对应的补充数据,作为特征去相关策略的全局视角补充,包括:
在本地的数据缓冲区中随机选取与所述目标图像样本的数量相同的已设有权重的历史图像样本以作为当前的补充图像样本,形成包含有各个所述补充图像样本和对应的权重的用于补充强化学习环境的补充数据;
所述基于强化联邦域泛化的图像分类模型包括:特征提取器、基于强化学习的策略梯度模块和分类器;
所述特征提取器用于根据输入的各个目标图像样本和各个补充图像样本,对应输出各个目标图像样本和各个补充图像样本各自对应的特征向量;
所述基于强化学习的策略梯度模块用于根据当前的样本加权策略计算得到各个目标图像样本的特征向量各自对应的初始权重,并计算各个所述特征向量之间协方差矩阵的Frobenius范数以确定奖励,根据该奖励更新所述样本加权策略,并根据更新后的样本加权策略对所述初始权重进行优化,得到各个所述目标图像样本的特征向量各自对应的目标权重;
所述分类器用于根据各个目标图像样本的特征向量,对应输出各个所述目标图像样本的类型预测标签,并计算各个所述目标图像样本的类型预测标签和所述类型标签之间的交叉熵,再分别计算各个所述目标图像样本的交叉熵与所述目标权重的乘积之和,得到对应的分类损失;
所述以本训练批次数据作为局部信息,补充数据作为全局视角补充,在历史图像数据权重固定的条件下,学习一种消除样本特征各维度之间的相关性的样本加权策略;应用根据强化特征去相关策略学习模块对图像数据加权,并基于加权后的图像训练特征提取器和分类器,以得到更新后的图像分类模型的模型参数,包括:
将各个所述目标图像样本和各个所述补充图像样本输入所述特征提取器,以使该特征提取器输出各个目标图像样本和各个补充图像样本各自对应的特征向量;
将各个目标图像样本的特征向量输入所述分类器,以使该分类器输出各个所述目标图像样本的类型预测标签;
以及,将各个目标图像样本和各个补充图像样本各自对应的特征向量,以及各个补充图像样本的权重,输入所述基于强化学习的策略梯度模块,以使该策略梯度模块根据当前的样本加权策略计算得到各个目标图像样本的特征向量各自对应的初始权重,并计算各个所述特征向量之间协方差矩阵的Frobenius范数以确定奖励,根据该奖励更新所述样本加权策略,并根据更新后的样本加权策略对所述初始权重进行优化,得到各个所述目标图像样本的特征向量各自对应的目标权重;
计算所述分类器输出的各个所述目标图像样本的类型预测标签和所述类型标签之间的交叉熵,再分别计算各个所述目标图像样本的交叉熵与所述目标权重的乘积之和,得到对应的分类损失,并得到更新后的图像分类模型的模型参数。
2.根据权利要求1所述的基于强化联邦域泛化的图像分类模型训练方法,其特征在于,所述模型参数包括:所述特征提取器的参数、所述分类器的参数和所述基于强化学习的策略梯度模块的样本加权策略的参数。
3.一种基于强化联邦域泛化的图像分类模型训练装置,其特征在于,包括:
经验回放模块,用于根据本地历史图像数据随机选取对应的补充数据,作为特征去相关策略的全局视角补充;
强化特征去相关策略学习模块,用于以本训练批次数据作为局部信息,经验回放模块获取的补充数据作为全局视角补充,在历史图像数据权重固定的条件下,学习一种消除样本特征各维度之间的相关性的样本加权策略;
强化联邦训练模块,用于应用根据强化特征去相关策略学习模块对图像数据加权,并基于加权后的图像训练特征提取器和分类器,以得到更新后的图像分类模型的模型参数;
数据发送模块,用于将所述模型参数发送至自身所在的联邦学习系统中的服务器,以使该服务器对该模型参数和其接收的其他多个模型参数进行聚合以得到所述图像分类模型当前的总模型参数以应用所述图像分类模型进行图像分类;
所述基于强化联邦域泛化的图像分类模型训练装置还用于执行下述内容:
在所述根据本地历史图像数据随机选取对应的补充数据,作为特征去相关策略的全局视角补充之前,还包括:
接收自身所在的联邦学习系统中的服务器发送的基于强化联邦域泛化的图像分类模型的总模型参数;
根据所述总模型参数对本地的基于强化联邦域泛化的图像分类模型进行初始化处理;
自历史图像样本中选取预设数量的已设有类型标签的历史图像样本以作为当前的目标图像样本,形成包含有各个所述目标图像样本及对应的类型标签的历史图像数据;
其中,所述根据本地历史图像数据随机选取对应的补充数据,作为特征去相关策略的全局视角补充,包括:
在本地的数据缓冲区中随机选取与所述目标图像样本的数量相同的已设有权重的历史图像样本以作为当前的补充图像样本,形成包含有各个所述补充图像样本和对应的权重的用于补充强化学习环境的补充数据;
所述基于强化联邦域泛化的图像分类模型包括:特征提取器、基于强化学习的策略梯度模块和分类器;
所述特征提取器用于根据输入的各个目标图像样本和各个补充图像样本,对应输出各个目标图像样本和各个补充图像样本各自对应的特征向量;
所述基于强化学习的策略梯度模块用于根据当前的样本加权策略计算得到各个目标图像样本的特征向量各自对应的初始权重,并计算各个所述特征向量之间协方差矩阵的Frobenius范数以确定奖励,根据该奖励更新所述样本加权策略,并根据更新后的样本加权策略对所述初始权重进行优化,得到各个所述目标图像样本的特征向量各自对应的目标权重;
所述分类器用于根据各个目标图像样本的特征向量,对应输出各个所述目标图像样本的类型预测标签,并计算各个所述目标图像样本的类型预测标签和所述类型标签之间的交叉熵,再分别计算各个所述目标图像样本的交叉熵与所述目标权重的乘积之和,得到对应的分类损失;
其中,所述以本训练批次数据作为局部信息,补充数据作为全局视角补充,在历史图像数据权重固定的条件下,学习一种消除样本特征各维度之间的相关性的样本加权策略;应用根据强化特征去相关策略学习模块对图像数据加权,并基于加权后的图像训练特征提取器和分类器,以得到更新后的图像分类模型的模型参数,包括:
将各个所述目标图像样本和各个所述补充图像样本输入所述特征提取器,以使该特征提取器输出各个目标图像样本和各个补充图像样本各自对应的特征向量;
将各个目标图像样本的特征向量输入所述分类器,以使该分类器输出各个所述目标图像样本的类型预测标签;
以及,将各个目标图像样本和各个补充图像样本各自对应的特征向量,以及各个补充图像样本的权重,输入所述基于强化学习的策略梯度模块,以使该策略梯度模块根据当前的样本加权策略计算得到各个目标图像样本的特征向量各自对应的初始权重,并计算各个所述特征向量之间协方差矩阵的Frobenius范数以确定奖励,根据该奖励更新所述样本加权策略,并根据更新后的样本加权策略对所述初始权重进行优化,得到各个所述目标图像样本的特征向量各自对应的目标权重;
计算所述分类器输出的各个所述目标图像样本的类型预测标签和所述类型标签之间的交叉熵,再分别计算各个所述目标图像样本的交叉熵与所述目标权重的乘积之和,得到对应的分类损失,并得到更新后的图像分类模型的模型参数。
4.一种用于图像分类的联邦学习系统,其特征在于,包括:服务器和与该服务器之间通信连接的多个客户端设备;
各个所述客户端设备分别用于实现权利要求1或2所述的基于强化联邦域泛化的图像分类模型训练方法;
所述服务器用于接收各个所述客户端设备分别在各自本地训练得到的当前训练轮次的模型参数,并对各个所述模型参数进行聚合处理,以得到当前训练轮次的所述图像分类模型当前的总模型参数,并在下一个训练轮次时将所述总模型参数分别发送至各个所述客户端设备。
5.一种电子设备,包括存储器、处理器及存储在存储器上并可在处理器上运行的计算机程序,其特征在于,所述处理器执行所述计算机程序时实现如权利要求1或2所述的基于强化联邦域泛化的图像分类模型训练方法。
6.一种计算机可读存储介质,其上存储有计算机程序,其特征在于,该计算机程序被处理器执行时实现如权利要求1或2所述的基于强化联邦域泛化的图像分类模型训练方法。
CN202211539820.9A 2022-12-03 2022-12-03 基于强化联邦域泛化的图像分类模型训练方法及系统 Active CN115731424B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202211539820.9A CN115731424B (zh) 2022-12-03 2022-12-03 基于强化联邦域泛化的图像分类模型训练方法及系统

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202211539820.9A CN115731424B (zh) 2022-12-03 2022-12-03 基于强化联邦域泛化的图像分类模型训练方法及系统

Publications (2)

Publication Number Publication Date
CN115731424A CN115731424A (zh) 2023-03-03
CN115731424B true CN115731424B (zh) 2023-10-31

Family

ID=85299855

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202211539820.9A Active CN115731424B (zh) 2022-12-03 2022-12-03 基于强化联邦域泛化的图像分类模型训练方法及系统

Country Status (1)

Country Link
CN (1) CN115731424B (zh)

Families Citing this family (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115952442B (zh) * 2023-03-09 2023-06-13 山东大学 基于全局鲁棒加权的联邦域泛化故障诊断方法及系统
CN116363421A (zh) * 2023-03-15 2023-06-30 北京邮电大学 图像的特征分类方法、装置、电子设备及介质
CN116452922B (zh) * 2023-06-09 2023-09-22 深圳前海环融联易信息科技服务有限公司 模型训练方法、装置、计算机设备及可读存储介质
CN116541779B (zh) * 2023-07-07 2023-10-31 北京邮电大学 个性化公共安全突发事件检测模型训练方法、检测方法及装置
CN117992873A (zh) * 2024-03-20 2024-05-07 合肥工业大学 基于异构联邦学习的变压器故障分类方法及模型训练方法

Citations (10)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113420888A (zh) * 2021-06-03 2021-09-21 中国石油大学(华东) 一种基于泛化域自适应的无监督联邦学习方法
CN113571203A (zh) * 2021-07-19 2021-10-29 复旦大学附属华山医院 多中心基于联邦学习的脑肿瘤预后生存期预测方法及系统
CN113688862A (zh) * 2021-07-09 2021-11-23 深圳大学 一种基于半监督联邦学习的脑影像分类方法及终端设备
CN113779563A (zh) * 2021-08-05 2021-12-10 国网河北省电力有限公司信息通信分公司 联邦学习的后门攻击防御方法及装置
CN114943345A (zh) * 2022-06-10 2022-08-26 西安电子科技大学 基于主动学习和模型压缩的联邦学习全局模型训练方法
CN115034836A (zh) * 2022-08-12 2022-09-09 腾讯科技(深圳)有限公司 一种模型训练方法及相关装置
CN115062710A (zh) * 2022-06-22 2022-09-16 西安电子科技大学 基于深度确定性策略梯度的联邦学习分类模型训练方法
CN115081532A (zh) * 2022-07-01 2022-09-20 西安电子科技大学 基于记忆重放和差分隐私的联邦持续学习训练方法
CN115310121A (zh) * 2022-07-12 2022-11-08 华中农业大学 车联网中基于MePC-F模型的实时强化联邦学习数据隐私安全方法
CN115331069A (zh) * 2022-07-01 2022-11-11 中银金融科技有限公司 一种基于联邦学习的个性化图像分类模型训练方法

Family Cites Families (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US10922409B2 (en) * 2018-04-10 2021-02-16 Microsoft Technology Licensing, Llc Deep reinforcement learning technologies for detecting malware

Patent Citations (10)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113420888A (zh) * 2021-06-03 2021-09-21 中国石油大学(华东) 一种基于泛化域自适应的无监督联邦学习方法
CN113688862A (zh) * 2021-07-09 2021-11-23 深圳大学 一种基于半监督联邦学习的脑影像分类方法及终端设备
CN113571203A (zh) * 2021-07-19 2021-10-29 复旦大学附属华山医院 多中心基于联邦学习的脑肿瘤预后生存期预测方法及系统
CN113779563A (zh) * 2021-08-05 2021-12-10 国网河北省电力有限公司信息通信分公司 联邦学习的后门攻击防御方法及装置
CN114943345A (zh) * 2022-06-10 2022-08-26 西安电子科技大学 基于主动学习和模型压缩的联邦学习全局模型训练方法
CN115062710A (zh) * 2022-06-22 2022-09-16 西安电子科技大学 基于深度确定性策略梯度的联邦学习分类模型训练方法
CN115081532A (zh) * 2022-07-01 2022-09-20 西安电子科技大学 基于记忆重放和差分隐私的联邦持续学习训练方法
CN115331069A (zh) * 2022-07-01 2022-11-11 中银金融科技有限公司 一种基于联邦学习的个性化图像分类模型训练方法
CN115310121A (zh) * 2022-07-12 2022-11-08 华中农业大学 车联网中基于MePC-F模型的实时强化联邦学习数据隐私安全方法
CN115034836A (zh) * 2022-08-12 2022-09-09 腾讯科技(深圳)有限公司 一种模型训练方法及相关装置

Non-Patent Citations (3)

* Cited by examiner, † Cited by third party
Title
Selecting a Suitable Feature Subset for Classification using Multi-Agent Reinforcement Learning;Minwoo Kim 等;《ICTC 2021》;501-504 *
融合强化学习和关系网络的样本分类;张碧陶 等;《计算机工程与应用》;第55卷(第21期);189-196 *
面向加密数据的安全图像分类模型研究综述;孙隆隆 等;《密码学报》;第7卷(第4期);525-540 *

Also Published As

Publication number Publication date
CN115731424A (zh) 2023-03-03

Similar Documents

Publication Publication Date Title
CN115731424B (zh) 基于强化联邦域泛化的图像分类模型训练方法及系统
CN109726794B (zh) 基于关注的图像生成神经网络
US10909380B2 (en) Methods and apparatuses for recognizing video and training, electronic device and medium
US11379722B2 (en) Method for training generative adversarial network (GAN), method for generating images by using GAN, and computer readable storage medium
US10402469B2 (en) Systems and methods of distributed optimization
CN110926782B (zh) 断路器故障类型判断方法、装置、电子设备及存储介质
CN108229591A (zh) 神经网络自适应训练方法和装置、设备、程序和存储介质
US20220351039A1 (en) Federated learning using heterogeneous model types and architectures
WO2023174036A1 (zh) 联邦学习模型训练方法、电子设备及存储介质
CN116229170A (zh) 基于任务迁移的联邦无监督图像分类模型训练方法、分类方法及设备
CN110704599B (zh) 为预测模型生成样本、预测模型训练的方法及装置
CN116310530A (zh) 基于语义聚类的联邦无监督图像分类模型训练方法、分类方法及设备
US20220337455A1 (en) Enhancement of channel estimation in wireless communication based on supervised learning
CN117999562A (zh) 用于量化联邦学习中的客户端贡献的方法和系统
CN114612688A (zh) 对抗样本生成方法、模型训练方法、处理方法及电子设备
US11539504B2 (en) Homomorphic operation accelerator and homomorphic operation performing device including the same
CN116432039B (zh) 协同训练方法及装置、业务预测方法及装置
CN115965078A (zh) 分类预测模型训练方法、分类预测方法、设备及存储介质
CN106878403B (zh) 基于最近探索的启发式服务组合方法
EP4002213A1 (en) System and method for training recommendation policies
CN112541129B (zh) 处理交互事件的方法及装置
JP7148078B2 (ja) 属性推定装置、属性推定方法、属性推定器学習装置、及びプログラム
JP7024687B2 (ja) データ分析システム、学習装置、方法、及びプログラム
CN115081626B (zh) 基于表征学习的个性化联邦少样本学习系统及方法
Hossain et al. Fedavo: Improving communication efficiency in federated learning with african vultures optimizer

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant