CN114882245B - 一种联邦多任务学习中基于特征提取-子任务分类器的数据标签分类方法及系统 - Google Patents
一种联邦多任务学习中基于特征提取-子任务分类器的数据标签分类方法及系统 Download PDFInfo
- Publication number
- CN114882245B CN114882245B CN202210438889.6A CN202210438889A CN114882245B CN 114882245 B CN114882245 B CN 114882245B CN 202210438889 A CN202210438889 A CN 202210438889A CN 114882245 B CN114882245 B CN 114882245B
- Authority
- CN
- China
- Prior art keywords
- feature extraction
- network
- classifier
- data
- training
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/40—Extraction of image or video features
- G06V10/46—Descriptors for shape, contour or point-related descriptors, e.g. scale invariant feature transform [SIFT] or bags of words [BoW]; Salient regional features
- G06V10/462—Salient features, e.g. scale invariant feature transforms [SIFT]
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
- G06F18/241—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
- G06F18/2415—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches based on parametric or probabilistic models, e.g. based on likelihood ratio or false acceptance rate versus a false rejection rate
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F21/00—Security arrangements for protecting computers, components thereof, programs or data against unauthorised activity
- G06F21/60—Protecting data
- G06F21/62—Protecting access to data via a platform, e.g. using keys or access control rules
- G06F21/6218—Protecting access to data via a platform, e.g. using keys or access control rules to a system of files or objects, e.g. local or distributed file system or database
- G06F21/6245—Protecting personal data, e.g. for financial or medical purposes
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N20/00—Machine learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/047—Probabilistic or stochastic networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/764—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/77—Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
- G06V10/774—Generating sets of training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/82—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/94—Hardware or software architectures specially adapted for image or video understanding
- G06V10/95—Hardware or software architectures specially adapted for image or video understanding structured as a network, e.g. client-server architectures
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02D—CLIMATE CHANGE MITIGATION TECHNOLOGIES IN INFORMATION AND COMMUNICATION TECHNOLOGIES [ICT], I.E. INFORMATION AND COMMUNICATION TECHNOLOGIES AIMING AT THE REDUCTION OF THEIR OWN ENERGY USE
- Y02D10/00—Energy efficient computing, e.g. low power processors, power management or thermal management
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- General Physics & Mathematics (AREA)
- Software Systems (AREA)
- Evolutionary Computation (AREA)
- General Health & Medical Sciences (AREA)
- Health & Medical Sciences (AREA)
- Artificial Intelligence (AREA)
- Computing Systems (AREA)
- Computer Vision & Pattern Recognition (AREA)
- General Engineering & Computer Science (AREA)
- Data Mining & Analysis (AREA)
- Medical Informatics (AREA)
- Multimedia (AREA)
- Mathematical Physics (AREA)
- Databases & Information Systems (AREA)
- Life Sciences & Earth Sciences (AREA)
- Computational Linguistics (AREA)
- Biomedical Technology (AREA)
- Molecular Biology (AREA)
- Biophysics (AREA)
- Probability & Statistics with Applications (AREA)
- Bioethics (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Bioinformatics & Computational Biology (AREA)
- Evolutionary Biology (AREA)
- Computer Hardware Design (AREA)
- Computer Security & Cryptography (AREA)
- Image Analysis (AREA)
- Information Retrieval, Db Structures And Fs Structures Therefor (AREA)
Abstract
本发明涉及一种联邦多任务学习中基于特征提取‑子任务分类器的数据标签分类方法及系统,适用于中央节点式联邦学习系统。为了提升整体模型的有效性和精度并解决标签缺失数据的问题,本发明通过两步分离式的联邦多任务学习训练方式,实现了一个“特征提取‑子任务分类器”的统一网络架构设计。该设计方法能够解决联邦多任务学习中多标签数据的部分标签缺失问题并拥有较高的模型性能以及测试精度,最终训练出一个高性能多标签分类器网络,同时保护了用户节点的数据隐私。
Description
技术领域
本发明涉及一种联邦多任务学习中基于特征提取-子任务分类器的数据标签分类方法及系统,属于人工智能技术领域。
背景技术
联邦学习属于分布式机器学习,是一种新兴的机器学习框架。随着大数据时代的到来,用户的数据安全和隐私保护越来越重要,诸多国家也出台了隐私保护相关法律法规。而对于训练大规模机器学习模型,传统的分布式机器学习往往不涉及数据隐私问题,中央服务器对计算节点以及其中的数据具有较高的控制权。2016年,谷歌公司提出了联邦学习,旨在每个用户数据不出本地仍可参与到模型的训练中去,以实现保护各参与者数据安全的目的。联邦学习中各个用户节点通过本地的私有数据训练模型,经中央服务器协调,聚合各个用户节点的模型参数,并更新全局的模型。这期间不涉及数据的传输,在很大的程度上保护了数据安全。详见文献[1]:Mcmahan H B,Moore E,D Ramage,et al.Communication-Efficient Learning of Deep Networks from Decentralized Data[J].2016.。
传统的机器学习训练中,数据往往是单标签的,即每个实例都与仅一个标签相关联,以表示其概念类的归属。然而,在许多实际应用中,一个对象通常会附带有多个标签,即一个实例对应于一组标签。例如,在文本分类任务中,一个文档可能属于多个主题,如“小说”、“社会”;在图像分类任务中,一幅图像可能属于多个语义,如“猫”、“白色”。使用多标签数据的多标签学习在从文档分类到基因功能预测和自动图像注释等诸多应用中起着至关重要的作用。多标签分类中,一种常见方法是问题转换,即将一个多标签问题转化为一个或多个单标签分类器来进行分类,再将其转换为多标签表示。详见文献[2]:Read J,Pfahringer B,Holmes G,et al.Classifier chains for multi-label classification[J].Machine Learning,2011,85(3):333-359.。
在多标签学习中,一个常见的假设是,所有的类标签及其值在训练过程之前都被观测到。然而,在一些实际应用中,由于标签的标注成本很高、一些标签在标注过程中的刻意省略以及部分标签存在未知性等因素,故一些观测到的标签是缺失的,甚至有部分标签并没有被观测到。这给多标签分类任务造成了巨大的困难。因此,如何在多标签分类任务中,解决标签缺失的问题并保证较好的分类精度,得到了广泛关注。用现有技术来解决标签缺失数据的多标签学习问题,详见文献[3]:Sun Y Y,Zhang Y,Zhou Z H.Multi-labellearning with weak label[C]//Twenty-fourth AAAI conference on artificialintelligence.2010.,一个基本的先决条件是每个标签至少有一个正向的数据示例,即每个标签至少在数据中出现一次。但此类方法无法解决某一标签完全缺失的问题,有一定的局限性,实用性不足。
发明内容
针对现有技术的不足,本发明提供了一种联邦多任务学习中基于特征提取-子任务分类器的数据标签分类方法。
本发明通过两步分离式的联邦多任务学习训练方式,实现了一个“特征提取-子任务分类器”的统一网络架构设计。所有用户节点参与构建适用于所有用户数据的特征提取网络,该特征提取网络在所给用户数据上具有普适性。原始训练图像经过特征提取网络,输出提取显著特征后的图像数据,使得在下一步分类器网络的训练中降低训练损失,提高测试精度。通过子任务分类器网络,使得某些用户节点的数据缺失某些标签而不能完成模型训练的问题得以解决。其中,每个子任务分类器网络不再单独训练特征提取层,而是通过训练一个针对于所有用户的特征提取网络,完成输入图像的特征提取,降低子任务分类器网络模型复杂度。该设计方法能够解决联邦多任务学习中多标签数据的部分标签缺失问题并拥有较高的模型性能以及测试精度,最终训练出一个高性能多标签分类器网络,同时保护了用户节点的数据隐私。
本发明还提供了一种联邦多任务学习中基于特征提取-子任务分类器的数据标签分类系统。
术语解释:
1、MBGD:小批量梯度下降法;
2、MSELoss:均方误差损失函数;
3、CrossEntropyLoss:交叉熵损失函数;
4、One-Hot:独热编码,又称一位有效编码,其方法是使用N位状态寄存器来对N个状态进行编码,每个状态都有它独立的寄存器位,并且在任意时候,其中只有一位有效。
本发明的技术方案为:
一种联邦多任务学习中基于特征提取-子任务分类器的数据标签分类方法,适用于中央节点式联邦学习系统,所述中央节点式联邦学习系统包括M个用户节点和1个中央服务器,每个用户节点均与中心服务器相连接;设定所有用户的训练数据均为多标签数据,且来自于同一个特征空间,标签的总数为L;对任意的用户,其每一个本地数据点都拥有相同种类的标签;第m个用户拥有的本地数据数目用Km来表示,且所有用户的数据数之和为K,即满足:在第m个用户上的本地数据集用Dm表示,即|Dm|=Km;包括:
构建并训练全局模型,全局模型包括特征提取网络以及多个分类器网络;
将待分类的图像输入到训练好的全局模型,图像数据经过特征提取网络,提取特征;提取特征后的图像数据再经过所有分类器网络,每个分类器网络分别输出该待分类的图像对于每一种标签中,属于各个类别的概率输出值;每一种标签选择概率输出值最大的类别作为此标签的分类结果,最终输出每一种标签的分类结果;
其中,全局模型的训练过程为:
第一步,训练特征提取网络:
在第t个特征提取网络训练周期中,用户节点m收到由中央服务器广播的最新特征提取网络的模型参数wt,并以此为初始模型,使用本地数据集Dm,在多轮本地迭代训练中,通过MBGD法,得到更新后的本地特征提取网络,其模型参数为wm,t,m的取值为1,2,3...M,且m为正整数;在所有用户节点完成一轮训练后,各个用户节点将各自更新后的本地特征提取网络的模型参数上传至中央服务器,并在中央服务器进行参数的聚合,得到一个新的特征提取网络,其模型参数为wt+1;重复上述过程,直至中央服务器端的特征提取网络收敛;每个用户节点都有对应的本地特征提取网络,其网络架构和特征提取网络的网络架构相同;
第二步,训练多个分类器网络:
根据每个用户节点对应的数据标签对用户节点进行分组,设定分为L组,第i组的用户节点个数记为Mi,第i组中第mi个用户节点的本地数据集记声每组用户对应一个分类器网络,共训练出L个分类器网络;
对于第i组的所有用户节点,目标为训练一个分类器网络i,其中,i表示在所有分类器网络中此分类器网络的索引号;分类器网络i在第t个训练周期中,第mi个用户节点收到由中央服务器广播的最新分类器网络i的模型参数并以此为初始模型,使用本地数据集/>在多轮本地迭代训练中,通过MBGD法,得到更新后的本地分类器网络i,模型参数为mi的取值为1,2,3...Mi,且mi为正整数;
在该用户组所有用户节点完成一轮训练后,各用户节点将各自更新后的本地分类器网络i的模型参数上传至中央服务器,并在中央服务器进行参数的聚合,得到一个新的分类器网络i,其模型参数为重复上述过程,直至中央服务器端的分类器网络i收敛;
对于全部L组用户节点均进行上述分类器网络的训练过程,直至中央服务器端的所有分类器网络收敛。
全局模型的训练过程包括:
在中央节点式联邦学习系统中,基于所有用户拥有的数据,构建一个统一的、适用于所有用户的特征提取网络
所有用户节点使用此特征提取网络,执行对本地数据的特征提取,得到提取显著特征后的图像数据;
在中央节点式联邦学习系统中,设定每一个用户节点的训练数据只拥有部分标签,即有一些标签是缺失的,且同一用户的数据所缺失标签是一致的;首先,根据用户节点拥有的标签进行分组,将拥有同一标签的用户节点集合称为一个用户组,形成多个用户组;之后;对于每一用户组的用户节点,均通过联邦学习的形式,训练出一个适用于此组标签的分类器网络;对于第i个标签的用户组所训练的分类器网络表示为
根据本发明优选的,定义特征提取网络的学习目标是最小化一个经验损失函数,如式(I)、(II)所示:
式(I)中,F(w)表示全局的平均训练损失,w表示d维的模型参数向量,Fm(w;Dm)表示第m个用户节点的本地平均训练损失;式(II)中,f(w;xmk,ymk)是第m个用户节点中第k个训练数据点(xmk,ymk)的训练损失,Dm={(xmk,ymk):1≤k≤Km}。
根据本发明优选的,在用户节点m收到由中央服务器广播的最新特征提取网络的模型参数wt,之后,每个用户节点根据其拥有的本地数据以及本地特征提取网络,计算出本地特征提取网络训练损失Fm(wt;Dm),同时,根据式(III),计算出本地特征提取网络训练损失的梯度gm,t:
式(III)中,友示训练损失Fm(w;Dm)在w=wt时的梯度;
在第t个特征提取网络训练周期中,所有的用户节点选择在本地通过MBGD法进行多次的本地特征提取网络训练损失的梯度更新;然后再将最新本地特征提取网络训练损失的梯度{gm,t}上传至中央服务器,并通过(IV)式完成参数的聚合:
根据本发明优选的,特征提取网络为卷积自编码器网络,包括编码器和解码器,编码器包括两个卷积层和池化层,实现对图片数据的特征提取;解码器的输入为特征提取后的图像数据,恢复出与原图片特征维度一致的图像数据,完成对原图像的重构过程。
根据本发明优选的,本地特征提取网络的损失函数选用MSELoss损失函数f(xi,yi),如式(V)所示:
f(xi,yi)=(xi-yi)2 (V)
其中,xi表示第i个原始图像数据,yi表示经过特征提取网络之后恢复出的第i个图像数据。
根据本发明优选的,第i个分类器网络的局平均训练损失Fi(w)以及第i个分类器网络平均训练损失分别如式(VI)和(VII)所示:
式(VI)和(VII)中,上标i表示该变量对应了第i个分类器网络;Fi(w)表示第i个分类器网络的全局平均训练损失,wi表示第i个分类器网络的参数向量,表示第i个分类器网络中,第m个用户节点的本地平均训练损失,/>则表示第k个训练数据点/>的训练损失,/>表示第i个分类器网络训练中第m个用户节点的数据集。
根据本发明优选的,每个分类器网络包括线性层和激活层,输入特征提取后的图像数据后,分类器网络分别输出图片对于特定一种标签中,属于各个类别的概率输出值,每一种标签选择概率最大的类别作为此标签的分类结果。
根据本发明优选的,每个分类器网络的本地损失函数均选用CrossEntropyLoss损失函数,其计算方法如(VIII)式所示:
如(VIII)中,输入xi是一个维度为j的向量,即经过分类器网络的输出结果;yi是One-Hot形式的标签向量,维度也为j。
一种计算机设备,包括存储器和处理器,所述存储器存储有计算机程序,所述处理器执行所述计算机程序时实现联邦多任务学习中解决子任务数据标签缺失的方法的步骤。
一种计算机可读存储介质,其上存储有计算机程序,所述计算机程序被处理器执行时实现联邦多任务学习中解决子任务数据标签缺失的方法的步骤。
一种联邦多任务学习中基于特征提取-子任务分类器的数据标签分类系统,包括:
特征提取模块,被配置为,对待分类的图像进行特征提取,提取出图像数据的主要特征;使得图片数据的RGB特征分量增加,总特征数目明显提升;
标签分类模块,被配置为,从分类器网络中输出对应某一标签的分类结果。
本发明的有益效果为:
本发明针对中央节点式联邦学习系统应用场景,提出了一种联邦多任务学习中基于特征提取-子任务分类器的数据标签分类方法。通过两步分离式的联邦多任务学习架构,实现了联合“特征提取网络”、“分类器网络”的统一设计。原始训练图像经过统一的特征提取网络,输出提取显著特征后的图像数据,使得在下一步分类器网络的训练中降低训练损失,提高测试精度以及模型的有效性,同时,每个子任务分类器网络不再单独训练特征提取层,而是通过训练一个针对于所有用户的特征提取网络,降低了模型复杂度。通过分类器网络,使得部分用户节点数据缺失某些标签而不能完成模型训练的问题得以解决。以联邦学习的方式进行分组多任务训练,能够在保护用户数据隐私的前提下,训练出一个高性能的多标签分类器网络。
附图说明
图1是本发明联邦多任务学习中解决子任务数据标签缺失的方法的流程框图;
图2(a)是本发明在CelebA数据集上,对于其中标签1的分类子任务中的训练损失示意图;
图2(b)是本发明在CelebA数据集上,对于其中标签2的分类子任务中的训练损失示意图;
图3(a)是本发明在CelebA数据集上,对于其中标签1的分类子任务中的测试精度示意图;
图3(b)是本发明在CelebA数据集上,对于其中标签2的分类子任务中的测试精度示意图;
图4是卷积自编码器网络的结构示意图;
图5是分类器网络的网络的结构示意图。
具体实施方式
下面结合说明书附图和实施例对本发明作进一步限定,但不限于此。
实施例1
一种联邦多任务学习中基于特征提取-子任务分类器的数据标签分类方法,适用于中央节点式联邦学习系统,所述中央节点式联邦学习系统包括M个用户节点和1个中央服务器,每个用户节点均与中心服务器相连接;设定所有用户的训练数据均为多标签数据,且来自于同一个特征空间,标签的总数为L;对任意的用户,其每一个本地数据点都拥有相同种类的标签;第m个用户拥有的本地数据数目用Km来表示,且所有用户的数据数之和为K,即满足:在第m个用户上的本地数据集用Dm表示,即|Dm|=Km;如图1所示,包括:
构建并训练全局模型,全局模型包括特征提取网络以及多个分类器网络;特征提取-子任务分类器即全局模型。
将待分类的图像输入到训练好的全局模型,图像数据经过特征提取网络,提取特征;提取特征后的图像数据再经过所有分类器网络,每个分类器网络分别输出该待分类的图像对于每一种标签中,属于各个类别的概率输出值;每一种标签选择概率输出值最大的类别作为此标签的分类结果,最终输出每一种标签的分类结果;
其中,全局模型的训练过程为:
第一步,训练特征提取网络:
在第t个特征提取网络训练周期中,用户节点m收到由中央服务器广播的最新特征提取网络的模型参数wt,并以此为初始模型,使用本地数据集Dm,在多轮本地迭代训练中,通过MBGD法,得到更新后的本地特征提取网络,其模型参数为wm,t,m的取值为1,2,3...M,且m为正整数;在所有用户节点完成一轮训练后,各个用户节点将各自更新后的本地特征提取网络的模型参数上传至中央服务器,并在中央服务器进行参数的聚合,得到一个新的特征提取网络,其模型参数为wt+1;重复上述过程,直至中央服务器端的特征提取网络收敛;每个用户节点都有对应的本地特征提取网络,其网络架构和特征提取网络的网络架构相同;
第二步,训练多个分类器网络:
根据每个用户节点对应的数据标签对用户节点进行分组,设定分为L组,第i组的用户节点个数记为Mi,第i组中第mi个用户节点的本地数据集记为每组用户对应一个分类器网络,共训练出L个分类器网络;
对于第i组的所有用户节点,目标为训练一个分类器网络i,其中,i表示在所有分类器网络中此分类器网络的索引号;分类器网络i在第t个训练周期中,第mi个用户节点收到由中央服务器广播的最新分类器网络i的模型参数并以此为初始模型,使用本地数据集/>在多轮本地迭代训练中,通过MBGD法,得到更新后的本地分类器网络i,模型参数为mi的取值为1,2,3...Mi,且mi为正整数;
在该用户组所有用户节点完成一轮训练后,各用户节点将各自更新后的本地分类器网络i的模型参数上传至中央服务器,并在中央服务器进行参数的聚合,得到一个新的分类器网络i,其模型参数为重复上述过程,直至中央服务器端的分类器网络i收敛;
对于全部L组用户节点均进行上述分类器网络的训练过程,直至中央服务器端的所有分类器网络收敛。
实施例2
根据实施例1所述的一种联邦多任务学习中基于特征提取-子任务分类器的数据标签分类方法,其区别在于:
全局模型的训练过程包括:
为了提升整体模型的精度和有效性并解决标签缺失数据的问题,本发明采用“特征提取、子任务分类器”的两步分离式的网络架构;
为了降低后续子任务分类器网络的训练损失,提高测试精度,在中央节点式联邦学习系统中,基于所有用户拥有的数据,构建一个统一的、适用于所有用户的特征提取网络
所有用户节点使用此特征提取网络,执行对本地数据的特征提取,得到提取显著特征后的图像数据;用于后续的分类器模型训练;
在中央节点式联邦学习系统中,设定每一个用户节点的训练数据只拥有部分标签,即有一些标签是缺失的,且同一用户的数据所缺失标签是一致的;首先,运用多任务学习的思想,根据用户节点拥有的标签进行分组,将拥有同一标签的用户节点集合称为一个用户组,形成多个用户组;之后;对于每一用户组的用户节点,均通过联邦学习的形式,训练出一个适用于此组标签的分类器网络;对于第i个标签的用户组所训练的分类器网络表示为
将特征提取网络以及后续训练完毕的所有分类器网络相连接,构成一个统一的特征提取、子任务分类器的两步分离式的全局模型。
定义特征提取网络的学习目标是最小化一个经验损失函数,如式(I)、(II)所示:
式(I)中,F(w)表示全局的平均训练损失,w表示d维的模型参数向量,Fm(w;Dm)表示第m个用户节点的本地平均训练损失;式(II)中,f(w;xmk,ymk)是第m个用户节点中第k个训练数据点(xmk,ymk)的训练损失,Dm={(xmk,ymk):1≤k≤Km}。
在用户节点m收到由中央服务器广播的最新特征提取网络的模型参数wt,之后,每个用户节点根据其拥有的本地数据以及本地特征提取网络,计算出本地特征提取网络训练损失Fm(wt;Dm),同时,根据式(III),计算出本地特征提取网络训练损失的梯度gm,t:
式(III)中,表示训练损失Fm(w;Dm)在w=wt时的梯度;
在第t个特征提取网络训练周期中,所有的用户节点选择在本地通过MBGD法进行多次的本地特征提取网络训练损失的梯度更新;然后再将最新本地特征提取网络训练损失的梯度{gm,t}上传至中央服务器,并通过(IV)式完成参数的聚合:
特征提取网络为卷积自编码器网络,包括编码器和解码器,其网络结构如图4所示;编码器对输入数据进行压缩编码操作,解码器将编码后的数据恢复成原始数据;卷积神经网络通常模型在结构上可以分为:卷积层、池化层与全连接层。其中卷积层和池化层用能够实现对输入图像的特征提取。在特征提取网络的训练过程中,图片数据经过编码器、解码器两个部分。编码器包括两个卷积层和池化层,若输入图片的特征数为3×128×128,那么输出的特征数为328×32×32,图片数据的RGB特征分量增加,总特征数目明显提升,实现对图片数据的特征提取;解码器的输入为特征提取后的图像数据,通过与编码器相反的设置,恢复出与原图片特征维度一致的图像数据,完成对原图像的重构过程。其中,通过对比编码器恢复的图像数据与原始图像数据之间的差异程度,来衡量特征提取网络的性能。
对于训练完毕的特征提取网络,只采用编码器部分。将原始图像数据输入到编码器,得到提取显著特征后的图像数据,用于后续的分类器网络训练。
本地特征提取网络的损失函数选用MSELoss损失函数f(xi,yi),如式(V)所示:
f(xi,yi)=(xi-yi)2 (V)
其中,xi表示第i个原始图像数据,yi表示经过特征提取网络之后恢复出的第i个图像数据。
第i个分类器网络的局平均训练损失Fi(w)以及第i个分类器网络平均训练损失分别如式(VI)和(VII)所示:
式(VI)和(VII)中,上标i表示该变量对应了第i个分类器网络;Fi(w)表示第i个分类器网络的全局平均训练损失,wi表示第i个分类器网络的参数向量,表示第i个分类器网络中,第m个用户节点的本地平均训练损失,/>则表示第k个训练数据点/>的训练损失,/>表示第i个分类器网络训练中第m个用户节点的数据集。
因为每个用户节点的训练数据已经经过特征提取,每个分类器网络包括线性层和激活层,网络结构如图5所示。对于某一个分类器网络,输入特征提取后的图像数据后,分类器网络分别输出图片对于特定一种标签中,属于各个类别的概率输出值,每一种标签选择概率最大的类别作为此标签的分类结果。
每个分类器网络的本地损失函数均选用CrossEntropyLoss损失函数,其计算方法如(VIII)式所示:
如(VIII)中,输入xi是一个维度为j的向量,即经过分类器网络的输出结果;yi是One-Hot形式的标签向量,维度也为j。
选取CelebA数据集中的40000个数据点,下放到所有用户节点。这些数据点只附带有40个原始标签的中两个标签,每一个数据点只有一个标签,即为标签缺失。
图2(a)是本发明在CelebA数据集上,对于其中标签1的分类子任务中的训练损失示意图;图2(b)是本发明在CelebA数据集上,对于其中标签2的分类子任务中的训练损失示意图;横坐标为训练轮次,纵坐标为训练数据的损失。
图3(a)是本发明在CelebA数据集上,对于其中标签1的分类子任务中的测试精度示意图;图3(b)是本发明在CelebA数据集上,对于其中标签2的分类子任务中的测试精度示意图;其中,横坐标为训练轮次,纵坐标为测试数据的精度。
由图2(a)、图2(b)、图3(a)、图3(b)可知,由于在特征提取网络中,经卷积编码器提取出图像的显著特征,能提高子任务分类器的性能、稳定性以及分类精度。通过此“特征提取-子任务分类器”的网络架构,能够有效解决多标签数据的标签缺失的问题。同时以联邦学习的方式,在保证了参与训练用户的数据隐私的前提下,仍然能够保持较高模型性能以及测试精度,可见此设计的有效性。
将本发明应用到医学图像的标签识别上,训练数据为所有用户某医学图像,多类标签为不同病症或科室的诊断结果(存在标签缺失状况)。在训练医学图像智能诊断模型中,实现联邦多任务学习中解决子任务数据标签缺失的方法的步骤。
实施例3
一种计算机设备,包括存储器和处理器,存储器存储有计算机程序,处理器执行计算机程序时实现实施例1或2一种联邦多任务学习中基于特征提取-子任务分类器的数据标签分类方法的步骤。
实施例4
一种计算机可读存储介质,其上存储有计算机程序,计算机程序被处理器执行时实现实施例1或2一种联邦多任务学习中基于特征提取-子任务分类器的数据标签分类方法的步骤。
实施例5
一种联邦多任务学习中基于特征提取-子任务分类器的数据标签分类系统,包括:
特征提取模块,被配置为,对待分类的图像进行特征提取,提取出图像数据的主要特征;使得图片数据的RGB特征分量增加,总特征数目明显提升;
标签分类模块,被配置为,从分类器网络中输出对应某一标签的分类结果。
Claims (11)
1.一种联邦多任务学习中基于特征提取-子任务分类器的数据标签分类方法,其特征在于,适用于中央节点式联邦学习系统,所述中央节点式联邦学习系统包括M个用户节点和1个中央服务器,每个用户节点均与中心服务器相连接;设定所有用户的训练数据均为多标签数据,且来自于同一个特征空间,标签的总数为L;对任意的用户,其每一个本地数据点都拥有相同种类的标签;第m个用户拥有的本地数据数目用Km来表示,且所有用户的数据数之和为K,即满足:在第m个用户上的本地数据集用Dm表示,即|Dm|=Km;包括:
构建并训练全局模型,全局模型包括特征提取网络以及多个分类器网络;
将待分类的图像输入到训练好的全局模型,图像数据经过特征提取网络,提取特征;提取特征后的图像数据再经过所有分类器网络,每个分类器网络分别输出该待分类的图像对于每一种标签中,属于各个类别的概率输出值;每一种标签选择概率输出值最大的类别作为此标签的分类结果,最终输出每一种标签的分类结果;
其中,全局模型的训练过程为:
第一步,训练特征提取网络:
在第t个特征提取网络训练周期中,用户节点m收到由中央服务器广播的最新特征提取网络的模型参数wt,并以此为初始模型,使用本地数据集Dm,在多轮本地迭代训练中,通过MBGD法,得到更新后的本地特征提取网络,其模型参数为wm,t,m的取值为1,2,3...M,且m为正整数;在所有用户节点完成一轮训练后,各个用户节点将各自更新后的本地特征提取网络的模型参数上传至中央服务器,并在中央服务器进行参数的聚合,得到一个新的特征提取网络,其模型参数为wt+1;重复上述过程,直至中央服务器端的特征提取网络收敛;每个用户节点都有对应的本地特征提取网络,其网络架构和特征提取网络的网络架构相同;
第二步,训练多个分类器网络:
根据每个用户节点对应的数据标签对用户节点进行分组,设定分为L组,第i组的用户节点个数记为Mi,第i组中第mi个用户节点的本地数据集记为每组用户对应一个分类器网络,共训练出L个分类器网络;
对于第i组的所有用户节点,目标为训练一个分类器网络i,其中,i表示在所有分类器网络中此分类器网络的索引号;分类器网络i在第t个训练周期中,第mi个用户节点收到由中央服务器广播的最新分类器网络i的模型参数并以此为初始模型,使用本地数据集在多轮本地迭代训练中,通过MBGD法,得到更新后的本地分类器网络i,模型参数为的取值为1,2,3...Mi,且mi为正整数;
在该用户组所有用户节点完成一轮训练后,各用户节点将各自更新后的本地分类器网络i的模型参数上传至中央服务器,并在中央服务器进行参数的聚合,得到一个新的分类器网络i,其模型参数为重复上述过程,直至中央服务器端的分类器网络i收敛;
对于全部L组用户节点均进行上述分类器网络的训练过程,直至中央服务器端的所有分类器网络收敛。
2.根据权利要求1所述的一种联邦多任务学习中基于特征提取-子任务分类器的数据标签分类方法,其特征在于,定义特征提取网络的学习目标是最小化一个经验损失函数,如式(I)、(II)所示:
式(I)中,F(w)表示全局的平均训练损失,w表示d维的模型参数向量,Fm(w;Dm)表示第m个用户节点的本地平均训练损失;式(II)中,f(w;xm k,ym k)是第m个用户节点中第k个训练数据点(xm k,ym k)的训练损失,Dm={(xm k,ym k):1≤k≤Km}。
3.根据权利要求1所述的一种联邦多任务学习中基于特征提取-子任务分类器的数据标签分类方法,其特征在于,在用户节点m收到由中央服务器广播的最新特征提取网络的模型参数wt,之后,每个用户节点根据其拥有的本地数据以及本地特征提取网络,计算出本地特征提取网络训练损失Fm(wt;Dm),同时,根据式(III),计算出本地特征提取网络训练损失的梯度gm,t:
式(III)中,表示训练损失Fm(w;Dm)在w=wt时的梯度;
在第t个特征提取网络训练周期中,所有的用户节点选择在本地通过MBGD法进行多次的本地特征提取网络训练损失的梯度更新;然后再将最新本地特征提取网络训练损失的梯度{gm,t}上传至中央服务器,并通过(IV)式完成参数的聚合:
4.根据权利要求1所述的一种联邦多任务学习中基于特征提取-子任务分类器的数据标签分类方法,其特征在于,特征提取网络为卷积自编码器网络,包括编码器和解码器,编码器包括两个卷积层和池化层,实现对图片数据的特征提取;解码器的输入为特征提取后的图像数据,恢复出与原图片特征维度一致的图像数据,完成对原图像的重构过程。
5.根据权利要求1所述的一种联邦多任务学习中基于特征提取-子任务分类器的数据标签分类方法,其特征在于,本地特征提取网络的损失函数选用MSELoss损失函数f(xi,yi),如式(V)所示:
f(xi,yi)=(xi-yi)2 (V)
其中,xi表示第i个原始图像数据,yi表示经过特征提取网络之后恢复出的第i个图像数据。
6.根据权利要求1所述的一种联邦多任务学习中基于特征提取-子任务分类器的数据标签分类方法,其特征在于,第i个分类器网络的局平均训练损失Fi(w)以及第i个分类器网络平均训练损失分别如式(VI)和(VII)所示:
式(VI)和(VII)中,上标i表示该变量对应了第i个分类器网络;Fi(w)表示第i个分类器网络的全局平均训练损失,wi表示第i个分类器网络的参数向量,表示第i个分类器网络中,第m个用户节点的本地平均训练损失,/>则表示第k个训练数据点的训练损失,/>表示第i个分类器网络训练中第m个用户节点的数据集。
7.根据权利要求1-6任一所述的一种联邦多任务学习中基于特征提取-子任务分类器的数据标签分类方法,其特征在于,每个分类器网络包括线性层和激活层,输入特征提取后的图像数据后,分类器网络分别输出图片对于特定一种标签中,属于各个类别的概率输出值,每一种标签选择概率最大的类别作为此标签的分类结果。
8.根据权利要求7所述的一种联邦多任务学习中基于特征提取-子任务分类器的数据标签分类方法,其特征在于,每个分类器网络的本地损失函数均选用CrossEntropyLoss损失函数,其计算方法如(VIII)式所示:
如(VIII)中,输入xi是一个维度为j的向量,即经过分类器网络的输出结果;yi是One-Hot形式的标签向量,维度也为j。
9.一种计算机设备,包括存储器和处理器,所述存储器存储有计算机程序,其特征在于,所述处理器执行所述计算机程序时实现权利要求1-8任一所述的一种联邦多任务学习中基于特征提取-子任务分类器的数据标签分类方法的步骤。
10.一种计算机可读存储介质,其上存储有计算机程序,其特征在于,所述计算机程序被处理器执行时实现权利要求1-8任一所述的一种联邦多任务学习中基于特征提取-子任务分类器的数据标签分类方法的步骤。
11.一种联邦多任务学习中基于特征提取-子任务分类器的数据标签分类系统,用于实现权利要求1-8任一所述的一种联邦多任务学习中基于特征提取-子任务分类器的数据标签分类方法,其特征在于,包括:
特征提取模块,被配置为,对待分类的图像进行特征提取,提取出图像数据的主要特征;使得图片数据的RGB特征分量增加,总特征数目明显提升;
标签分类模块,被配置为,从分类器网络中输出对应某一标签的分类结果。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210438889.6A CN114882245B (zh) | 2022-04-22 | 2022-04-22 | 一种联邦多任务学习中基于特征提取-子任务分类器的数据标签分类方法及系统 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210438889.6A CN114882245B (zh) | 2022-04-22 | 2022-04-22 | 一种联邦多任务学习中基于特征提取-子任务分类器的数据标签分类方法及系统 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN114882245A CN114882245A (zh) | 2022-08-09 |
CN114882245B true CN114882245B (zh) | 2023-08-25 |
Family
ID=82670960
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202210438889.6A Active CN114882245B (zh) | 2022-04-22 | 2022-04-22 | 一种联邦多任务学习中基于特征提取-子任务分类器的数据标签分类方法及系统 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN114882245B (zh) |
Citations (6)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN112862011A (zh) * | 2021-03-31 | 2021-05-28 | 中国工商银行股份有限公司 | 基于联邦学习的模型训练方法、装置及联邦学习系统 |
CN112949837A (zh) * | 2021-04-13 | 2021-06-11 | 中国人民武装警察部队警官学院 | 一种基于可信网络的目标识别联邦深度学习方法 |
CN113420888A (zh) * | 2021-06-03 | 2021-09-21 | 中国石油大学(华东) | 一种基于泛化域自适应的无监督联邦学习方法 |
CN113705712A (zh) * | 2021-09-02 | 2021-11-26 | 广州大学 | 一种基于联邦半监督学习的网络流量分类方法和系统 |
CN113792892A (zh) * | 2021-09-29 | 2021-12-14 | 深圳前海微众银行股份有限公司 | 联邦学习建模优化方法、设备、可读存储介质及程序产品 |
CN113850272A (zh) * | 2021-09-10 | 2021-12-28 | 西安电子科技大学 | 基于本地差分隐私的联邦学习图像分类方法 |
-
2022
- 2022-04-22 CN CN202210438889.6A patent/CN114882245B/zh active Active
Patent Citations (6)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN112862011A (zh) * | 2021-03-31 | 2021-05-28 | 中国工商银行股份有限公司 | 基于联邦学习的模型训练方法、装置及联邦学习系统 |
CN112949837A (zh) * | 2021-04-13 | 2021-06-11 | 中国人民武装警察部队警官学院 | 一种基于可信网络的目标识别联邦深度学习方法 |
CN113420888A (zh) * | 2021-06-03 | 2021-09-21 | 中国石油大学(华东) | 一种基于泛化域自适应的无监督联邦学习方法 |
CN113705712A (zh) * | 2021-09-02 | 2021-11-26 | 广州大学 | 一种基于联邦半监督学习的网络流量分类方法和系统 |
CN113850272A (zh) * | 2021-09-10 | 2021-12-28 | 西安电子科技大学 | 基于本地差分隐私的联邦学习图像分类方法 |
CN113792892A (zh) * | 2021-09-29 | 2021-12-14 | 深圳前海微众银行股份有限公司 | 联邦学习建模优化方法、设备、可读存储介质及程序产品 |
Non-Patent Citations (1)
Title |
---|
基于联邦学习和卷积神经网络的入侵检测方法;王蓉;马春光;武朋;;信息网络安全(第04期);全文 * |
Also Published As
Publication number | Publication date |
---|---|
CN114882245A (zh) | 2022-08-09 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN109241317B (zh) | 基于深度学习网络中度量损失的行人哈希检索方法 | |
WO2020238293A1 (zh) | 图像分类方法、神经网络的训练方法及装置 | |
WO2021164772A1 (zh) | 训练跨模态检索模型的方法、跨模态检索的方法和相关装置 | |
JP6395158B2 (ja) | シーンの取得画像を意味的にラベル付けする方法 | |
CN109711426B (zh) | 一种基于gan和迁移学习的病理图片分类装置及方法 | |
CN117456297A (zh) | 图像生成方法、神经网络的压缩方法及相关装置、设备 | |
CN110059206A (zh) | 一种基于深度表征学习的大规模哈希图像检索方法 | |
CN113095370B (zh) | 图像识别方法、装置、电子设备及存储介质 | |
CN110555060A (zh) | 基于成对样本匹配的迁移学习方法 | |
CN112784929B (zh) | 一种基于双元组扩充的小样本图像分类方法及装置 | |
CN107480723B (zh) | 基于局部二进制阈值学习网络的纹理识别方法 | |
CN112926675B (zh) | 视角和标签双重缺失下的深度不完整多视角多标签分类方法 | |
CN113283590B (zh) | 一种面向后门攻击的防御方法 | |
CN111738169A (zh) | 一种基于端对端网络模型的手写公式识别方法 | |
CN114896434B (zh) | 一种基于中心相似度学习的哈希码生成方法及装置 | |
Yang et al. | Local label descriptor for example based semantic image labeling | |
CN111126464A (zh) | 一种基于无监督域对抗领域适应的图像分类方法 | |
Fan | Research and realization of video target detection system based on deep learning | |
CN103942214B (zh) | 基于多模态矩阵填充的自然图像分类方法及装置 | |
CN116758005A (zh) | 一种面向pet/ct医学图像的检测方法 | |
CN112990340B (zh) | 一种基于特征共享的自学习迁移方法 | |
KR20210040604A (ko) | 행위 인식 방법 및 장치 | |
CN114168773A (zh) | 一种基于伪标签和重排序的半监督草图图像检索方法 | |
CN117611838A (zh) | 一种基于自适应超图卷积网络的多标签图像分类方法 | |
CN111143544B (zh) | 一种基于神经网络的柱形图信息提取方法及装置 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |