CN117114148B - 一种轻量级联邦学习训练方法 - Google Patents
一种轻量级联邦学习训练方法 Download PDFInfo
- Publication number
- CN117114148B CN117114148B CN202311046071.0A CN202311046071A CN117114148B CN 117114148 B CN117114148 B CN 117114148B CN 202311046071 A CN202311046071 A CN 202311046071A CN 117114148 B CN117114148 B CN 117114148B
- Authority
- CN
- China
- Prior art keywords
- neural network
- convolutional neural
- local
- network model
- model
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
- 238000012549 training Methods 0.000 title claims abstract description 47
- 238000000034 method Methods 0.000 title claims abstract description 38
- 238000013527 convolutional neural network Methods 0.000 claims abstract description 130
- 238000013138 pruning Methods 0.000 claims abstract description 46
- 238000004821 distillation Methods 0.000 claims abstract description 18
- 230000006870 function Effects 0.000 claims description 28
- 238000012545 processing Methods 0.000 claims description 8
- 230000002159 abnormal effect Effects 0.000 claims description 4
- 238000003062 neural network model Methods 0.000 claims description 4
- 238000004891 communication Methods 0.000 abstract description 18
- 230000002776 aggregation Effects 0.000 abstract description 10
- 238000004220 aggregation Methods 0.000 abstract description 10
- 230000008569 process Effects 0.000 description 5
- 238000012360 testing method Methods 0.000 description 5
- 238000013528 artificial neural network Methods 0.000 description 4
- 230000005540 biological transmission Effects 0.000 description 4
- 238000011156 evaluation Methods 0.000 description 4
- 238000013140 knowledge distillation Methods 0.000 description 4
- 230000002457 bidirectional effect Effects 0.000 description 3
- 230000000694 effects Effects 0.000 description 3
- 230000006978 adaptation Effects 0.000 description 2
- 238000010801 machine learning Methods 0.000 description 2
- 238000012986 modification Methods 0.000 description 2
- 230000004048 modification Effects 0.000 description 2
- 238000011176 pooling Methods 0.000 description 2
- 244000141353 Prunus domestica Species 0.000 description 1
- 230000004931 aggregating effect Effects 0.000 description 1
- 238000004458 analytical method Methods 0.000 description 1
- 230000009286 beneficial effect Effects 0.000 description 1
- 230000015556 catabolic process Effects 0.000 description 1
- 238000006731 degradation reaction Methods 0.000 description 1
- 230000001934 delay Effects 0.000 description 1
- 238000012217 deletion Methods 0.000 description 1
- 230000037430 deletion Effects 0.000 description 1
- 238000001514 detection method Methods 0.000 description 1
- 238000010586 diagram Methods 0.000 description 1
- 238000005259 measurement Methods 0.000 description 1
- 210000002569 neuron Anatomy 0.000 description 1
- 230000004044 response Effects 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N20/00—Machine learning
- G06N20/20—Ensemble learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/0464—Convolutional networks [CNN, ConvNet]
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/082—Learning methods modifying the architecture, e.g. adding, deleting or silencing nodes or connections
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/098—Distributed learning, e.g. federated learning
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02D—CLIMATE CHANGE MITIGATION TECHNOLOGIES IN INFORMATION AND COMMUNICATION TECHNOLOGIES [ICT], I.E. INFORMATION AND COMMUNICATION TECHNOLOGIES AIMING AT THE REDUCTION OF THEIR OWN ENERGY USE
- Y02D10/00—Energy efficient computing, e.g. low power processors, power management or thermal management
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Software Systems (AREA)
- Mathematical Physics (AREA)
- Artificial Intelligence (AREA)
- Data Mining & Analysis (AREA)
- Evolutionary Computation (AREA)
- Computing Systems (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Biomedical Technology (AREA)
- Health & Medical Sciences (AREA)
- Life Sciences & Earth Sciences (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- General Health & Medical Sciences (AREA)
- Molecular Biology (AREA)
- Medical Informatics (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Image Analysis (AREA)
Abstract
本发明提供了一种轻量级联邦学习训练方法,包括:中央服务器将深度卷积神经网络模型的参数进行初始化,得到初始化深度卷积神经网络模型并传输至多个客户端;客户端将初始化深度卷积神经网络模型的模型反向蒸馏至本地深度卷积神经网络模型;并将本地图像数据输入本地深度卷积神经网络模型对本地深度卷积神经网络的参数进行更新,得到训练后的本地深度卷积神经网络模型;通过剪枝算法对训练后的本地深度卷积神经网络模型进行剪枝,得到轻量化深度卷积神经网络模型并正向蒸馏至本地局部模型;将本地局部模型输入中央服务器进行聚合,得到全局模型;与现有技术相比,本发明能够在提高通信和聚合效率的同时提升模型的精确性。
Description
技术领域
本发明涉及信息技术领域,特别涉及一种轻量级联邦学习训练方法。
背景技术
随着移动设备的功能越来越强大,越来越多的基于神经网络的智能应用已被开发用于移动设备,例如图像识别、视频分析、目标检测等。为了使智能应用能够达到预计效果,通常会通过大量的数据训练智能应用的神经网络模型,然而,单个移动设备的数据量是有限的,不太可能帮助神经网络达到理想的精度。同时,考虑到隐私保护和通信量过大等原因,将数据从许多移动设备传输到一个中央服务器并进行集中训练将不再可行。在联邦学习中的中央服务器的编排下,以分散的方式训练共享全局模型,实现在保护用户数据隐私的同时,最大化提升模型的训练效率和模型的整体精度。
目前,由于联邦学习在解决隐私保护和数据孤岛等问题方面的优势,已经逐步成为流行的机器学习范式。此类方法通常分为四个步骤:首先,在每轮通信中,每个参与设备从中央服务器下载当前模型;其次,通过本地数据训练局部模型;第三,通过中央服务器聚合所有局部模型;第四,将聚合后的全局模型发送回设备。然而,由于移动设备通信成本高且通信传输不稳定,联邦学习通信负载较大等问题,常规的联邦学习方法难以在一定设备尤其是告诉移动设备中使用。因此,目前的面向移动设备的联邦学习方法却存在以下不可忽略的技术问题:
传统的联邦学习训练方法主要考虑的是稳定通信的设备或者是慢速的移动设备,从而忽略了联邦学习算法应用在高速移动设备上的挑战。在高速移动场景下,例如高速车联网中,车辆的高速移动性带来了信号质量的下降,导致车载网络无法实现最佳带宽和通信速度,这意味着参与训练的设备将会消耗大量的时间和资源在模型的传输过程中。同时,由于不同设备的网络时延不同,中央服务器的聚合过程将会导致更长的等待时间,这将导致联邦学习的效率进一步降低,这些问题严重影响了传统联邦学习在移动场景下的应用效果。
发明内容
本发明提供了一种轻量级联邦学习训练方法,其目的是为了节约模型传输过程中的传输时间和减少模型聚合过程中的等待时间。
为了达到上述目的,本发明提供了一种轻量级联邦学习训练方法,包括:
步骤1,中央服务器将深度卷积神经网络模型的参数进行初始化,得到初始化深度卷积神经网络模型,并将初始化深度卷积神经网络模型传输至多个客户端;
步骤2,客户端通过设定蒸馏温度,将初始化深度卷积神经网络模型的模型参数反向蒸馏至本地深度卷积神经网络模型;
步骤3,客户端将采集的本地图像数据输入本地深度卷积神经网络模型进行图像识别,得到识别结果并计算损失函数,通过损失函数对本地深度卷积神经网络的参数进行更新,得到训练后的本地深度卷积神经网络模型;
步骤4,客户端通过剪枝算法对训练后的本地深度卷积神经网络模型中的编码器和分类器分别进行剪枝,得到剪枝后的编码器和剪枝后的分类器,并将剪枝后的编码器和剪枝后的分类器进行拼接,得到轻量化深度卷积神经网络模型;
步骤5,客户端通过设定蒸馏温度,将轻量化深度卷积神经网络模型的参数正向知识蒸馏至本地局部模型,并将本地局部模型输入中央服务器;
步骤6,中央服务器将多个客户端上传的本地局部模型进行聚合,得到全局模型,并判断全局模型是否满足预设训练条件;若是,则训练结束,将待识别的图像数据输入全局模型进行图像识别,得到识别结果;否则,将全局模型作为步骤1中的初始化深度卷积神经网络模型传输至多个客户端,并返回执行步骤2。
进一步来说,在客户端将采集的本地图像数据输入本地深度卷积神经网络模型进行图像识别之前,还包括:
对采集的本地图像数据进行数据标签规范化处理和异常数据删除处理,得到处理后的本地图像数据;
客户端将采集的本地图像数据输入本地深度卷积神经网络模型进行图像识别。
进一步来说,步骤4包括:
根据训练后的本地深度卷积神经网络模型的网络特性,将训练后的本地深度卷积神经网络模型分为编码器和分类器;
利用结构化剪枝的方式,将编码器的权重绝对值小于预设阈值的权重进行修剪,得到剪枝后的编码器;
利用非结构化剪枝的方式,评估分类器中每个卷积层中每个过滤器的影响系数,并影响系数低于预设值的过滤器进行修剪,得到剪枝后的分类器;
将剪枝后的编码器和剪枝后的分类器进行拼接,得到轻量化深度卷积神经网络模型。
进一步来说,根据训练后的本地深度卷积神经网络模型的网络特性,通过对训练后的本地深度卷积神经网络模型进行正则化,得到编码器和分类器,正则化的表达式为:
R(W)=REnc(WE)+RCls(WC)
其中,R(W)表示本地深度卷积神经网络模型的剪枝权重,REnc表示编码器的剪枝权重,RCls表示分类器的剪枝权重,WE表示编码器的权重,WC表示分类器的权重,||·||g是group Lasso算法,Fl是第l个卷积层中滤波器的数量,Chl是第l个卷积层中通道的个数,Rowl代表分类器中第l层的行数,Col1代表分类器中第l层的列数。
进一步来说,轻量化深度卷积神经网络模型的损失函数为:
F(W)=FD(W)+λR(W)
其中,FD(W)是轻量化深度卷积神经网络模型的损失函数,λ是结构化稀疏正则化的系数。
进一步来说,本地局部模型的损失函数为:
其中,β表示控制来自数据或其他模型知识比例的超参数,表示本地局部模型的交叉熵损失函数,DKL表示KL散度,pl表示本地深度卷积神经网络模型的预测值,pm表示本地局部模型的预测值。
进一步来说,训练终止的条件为:
直至全局模型的精度达到预设训练精度或迭代次数达到预设上限时,终止训练。
本发明的上述方案有如下的有益效果:
本发明通过中央服务器将深度卷积神经网络模型的参数进行初始化,得到初始化深度卷积神经网络模型,并将初始化深度卷积神经网络模型传输至多个客户端;客户端通过设定蒸馏温度,将初始化深度卷积神经网络模型的模型参数反向蒸馏至本地深度卷积神经网络模型;客户端将采集的本地图像数据输入本地深度卷积神经网络模型进行图像识别,得到识别结果并计算损失函数,通过损失函数对本地深度卷积神经网络的参数进行更新,得到训练后的本地深度卷积神经网络模型;客户端通过剪枝算法对训练后的本地深度卷积神经网络模型中的编码器和分类器分别进行剪枝,得到剪枝后的编码器和剪枝后的分类器,并将剪枝后的编码器和剪枝后的分类器进行拼接,得到轻量化深度卷积神经网络模型;客户端通过设定蒸馏温度,将轻量化深度卷积神经网络模型的参数正向知识蒸馏至本地局部模型,并将本地局部模型输入中央服务器;中央服务器将多个客户端上传的本地局部模型进行聚合,得到全局模型,并判断全局模型是否满足预设训练条件;若是,则训练结束,将待识别的图像数据输入全局模型进行图像识别,得到识别结果;否则,将全局模型作为步骤1中的初始化深度卷积神经网络模型传输至多个客户端,并返回执行步骤2;与现有技术相比,本发明采用双向蒸馏的方式压缩模型的参数,极大程度上提高了通信效率并减少了聚合时的等待时间,通过剪枝算法对模型做进一步的压缩,有效的去除了局部模型中多余的参数以减少模型参数量,从而能够在提高通信和聚合效率的同时提升模型的精确性。
本发明的其它有益效果将在随后的具体实施方式部分予以详细说明。
附图说明
图1为本发明实施例的流程示意图;
图2为本发明实施例中轻量级联邦学习训练框架示意图。
具体实施方式
为使本发明要解决的技术问题、技术方案和优点更加清楚,下面将结合附图及具体实施例进行详细描述。显然,所描述的实施例是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
在本发明的描述中,需要说明的是,术语“中心”、“上”、“下”、“左”、“右”、“竖直”、“水平”、“内”、“外”等指示的方位或位置关系为基于附图所示的方位或位置关系,仅是为了便于描述本发明和简化描述,而不是指示或暗示所指的装置或元件必须具有特定的方位、以特定的方位构造和操作,因此不能理解为对本发明的限制。此外,术语“第一”、“第二”、“第三”仅用于描述目的,而不能理解为指示或暗示相对重要性。
在本发明的描述中,需要说明的是,除非另有明确的规定和限定,术语“安装”、“相连”、“连接”应做广义理解,例如,可以是锁定连接,也可以是可拆卸连接,或一体地连接;可以是机械连接,也可以是电连接;可以是直接相连,也可以通过中间媒介间接相连,可以是两个元件内部的连通。对于本领域的普通技术人员而言,可以具体情况理解上述术语在本发明中的具体含义。
此外,下面所描述的本发明不同实施方式中所涉及的技术特征只要彼此之间未构成冲突就可以相互结合。
本发明针对现有的问题,提供了一种轻量级联邦学习训练方法。
如图1所示,本发明的实施例提供了一种轻量级联邦学习训练方法,包括:
步骤1,中央服务器将深度卷积神经网络模型的参数进行初始化,得到初始化深度卷积神经网络模型,并将初始化深度卷积神经网络模型传输至多个客户端;
步骤2,客户端通过设定蒸馏温度,将初始化深度卷积神经网络模型的模型参数反向蒸馏至本地深度卷积神经网络模型;
步骤3,客户端将采集的本地图像数据输入本地深度卷积神经网络模型进行图像识别,得到识别结果并计算损失函数,通过损失函数对本地深度卷积神经网络的参数进行更新,得到训练后的本地深度卷积神经网络模型;
步骤4,客户端通过剪枝算法对训练后的本地深度卷积神经网络模型中的编码器和分类器分别进行剪枝,得到剪枝后的编码器和剪枝后的分类器,并将剪枝后的编码器和剪枝后的分类器进行拼接,得到轻量化深度卷积神经网络模型;
步骤5,客户端通过设定蒸馏温度,将轻量化深度卷积神经网络模型的参数正向知识蒸馏至本地局部模型,并将本地局部模型输入中央服务器;
步骤6,中央服务器将多个客户端上传的本地局部模型进行聚合,得到全局模型,并判断全局模型是否满足预设训练条件;若是,则训练结束,将待识别的图像数据输入全局模型进行图像识别,得到识别结果;否则,将全局模型作为步骤1中的初始化深度卷积神经网络模型传输至多个客户端,并返回执行步骤2。
具体来说,基于移动设备的数据质量、处理器性能、通信质量等因素,选择多个高质量的客户端与中央服务器建立联系并加入联邦学习的训练过程;中央服务器对深度卷积神经网络模型的参数进行初始化,并通过无线网络将初始化深度卷积神经网络模型分别传输至相应的客户端,初始化深度卷积神经网络模型由19个卷积层、5个池化层、3个全连接层和softmax层组成。
需要说明的是,本发明实施例中所提到的客户端搭载于具备摄像功能的网联汽车上,网联汽车通过采集模块采集道路图像数据存储至客户端。
具体来说,在客户端设定合适的蒸馏温度,将初始化深度卷积神经网络模型参数反向蒸馏至本地的深度卷积神经网络模型中;联邦学习初始化阶段通过反向蒸馏的方式间接的将本地深度卷积神经网络模型初始化,进而使得整个联邦学习过程加快。
具体来说,客户端将采集的本地道路图像数据输入本地深度卷积神经网络模型进行图像识别,得到识别结果和每张图像对应的标签值结果,然后反向传播通过导数链式法则计算损失函数对各参数的梯度,并根据梯度进行参数的更新,得到训练后的本地深度卷积神经网络模型。
具体来说,在客户端将采集的本地图像数据输入本地深度卷积神经网络模型进行图像识别之前,还包括:
对采集的本地图像数据进行数据标签规范化处理和异常数据删除处理,得到处理后的本地图像数据;
客户端将采集的本地图像数据输入本地深度卷积神经网络模型进行图像识别。
本发明实施例以客户端i为例,通过客户端i采集本地图像数据;对本地图像数据进行数据标签规范化处理和异常数据删除处理,得到处理后的本地图像数据;将处理后的本地图像数据输入本地深度卷积神经网络模型进行图像识别,得到识别结果并计算损失函数,通过损失函数对本地深度卷积神经网络的参数进行更新,得到训练后的本地深度卷积神经网络模型locali。
具体来说,步骤4包括:
根据训练后的本地深度卷积神经网络模型locali的网络结构,将训练后的本地深度卷积神经网络模型locali分为编码器Encoder和分类器Classifier,如图2所示;其中编码器Encoder由卷积神经网络CNN组成,分类器Classifier由全连接神经网络组成;将模型分为编码器和分类器再分别进行剪枝是根据编码器和分类器在网络中的角色和性质不同。编码器中的滤波器负责提取局部特征,对于图像的不同部分有不同的响应。而分类器中的全连接层负责整合卷积层提取的特征,对于整体任务的影响较大,全连接层剪枝时需要保留对任务性能影响较大的神经元。通过分别考虑滤波器和全连接层的剪枝,以求得最大限度地对模型进行压缩,减少计算复杂度,同时保持模型的性能。
利用结构化剪枝的方式,评估编码器中每个卷积层中每个过滤器的影响系数,并影响系数低于预设值的过滤器进行修剪,得到剪枝后的编码器Encoder;
利用非结构化剪枝的方式,将分类器的权重绝对值小于预设阈值的权重进行修剪,得到剪枝后的分类器Classifier;
将剪枝后的编码器Encoder和剪枝后的分类器Classifier进行拼接,得到轻量化深度卷积神经网络模型,轻量化深度卷积神经网络模型包括13个卷积层、5个池化层、3个全连接层以及softmax层组成。
本发明实施例提出的剪枝算法用于减小模型大小和通信开销,包括基于结构化剪枝的编码器Encoder修剪方法和基于非结构化剪枝的分类器Classifier修剪方法。
非结构化剪枝通常适用于全连接神经网络,根据设定的阈值,将权重绝对值小于阈值的参数定义为不重要的参数直接设为零,具有很高的灵活性;结构化剪枝通常用于卷积神经网络CNN中,通过一些方法评估CNN中每个卷积层过滤器的影响系数,然后将其中影响系数较低的卷积层过滤器移除,该方法虽然灵活性较低,但是能更大程度上压缩模型,基于上述讨论,对于本地深度卷积神经网络模型locali的正则化可以表示为:
R(W)=REnc(WE)+RCls(WC)
其中,R(W)表示本地深度卷积神经网络模型的剪枝权重,REnc表示编码器的剪枝权重,RCls表示分类器的剪枝权重,WE表示编码器Encoder的权重,WC表示分类器Classifier的权重,||·||g表示group Lasso分组最小角回归算法,Fl表示第l个卷积层中过滤器的数量,Chl表示第l个卷积层中通道的个数,Rowl表示分类器中第l层的行数,Col1代表分类器中第l层的列数。
其中Modules={C:Classifer,E:Encoder},代表/>的参数量。
应用上述的正则化方法后,轻量化深度卷积神经网络模型的训练损失函数为:
F(W)=FD(W)+λR(W)
其中,FD(W)是轻量化深度卷积神经网络模型的损失函数,λ是结构化剪枝正则化的系数。
通过使用客户端收集的本地图像数据优化轻量化深度卷积神经网络模型中的损失函数,可以识别轻量化深度卷积神经网络模型中的零值和非零值参数。
具体来说,终止训练的条件为:直至全局模型的精度达到预设训练精度或迭代次数达到预设上限时,终止训练。
本发明实施例通过剪枝算法得到轻量化深度卷积神经网络模型,再使用双向知识蒸馏算法进一步压缩轻量化深度卷积神经网络模型的模型参数,以便于联邦学习过程中知识的上传和下载,首先通过正向知识蒸馏将轻量化深度卷积神经网络模型提取到更紧凑、轻量的本地局部模型中,然后将本地局部模型输入中央服务器进行模型聚合,得到全局模型并传输至各个客户端,用于更新客户端中的初始化深度卷积神经网络模型;客户端将全局模型替换为初始化深度卷积神经网络模型,最后通过反向知识蒸馏将初始化深度卷积神经网络模型的模型参数反向蒸馏至本地深度卷积神经网络模型。
具体来说,本地局部模型的损失函数为:
其中,β表示控制来自数据或其他模型知识比例的超参数,表示本地局部模型的交叉熵损失函数,DKL表示KL散度,pl表示本地深度卷积神经网络模型的预测值,pm表示本地局部模型的预测值。
具体来说,中央服务器基于FedAug算法对本地局部模型进行加权聚合,得到本轮全局模型,FedAug算法如下所示:
其中,是第t+1轮中第i个客户端的本地局部模型的参数,Wt+1是第t+1轮联邦学习中全局模型的参数。
具体来说,中央服务器将全局模型传输至每个客户端中,客户端利用接收到的全局模型代替初始化深度卷积神经网络模型作为下一轮的初始化深度卷积神经网络模型,即
初始化深度卷积神经网络模型的损失函数为:
其中,是初始化深度卷积神经网络模型的交叉熵损失函数,α是控制来自数据或其他模型知识比例的超参数。
本发明实施例通过网联汽车上的采集模块采集道路图像数据,并将道路图像数据输入全局模型进行图像识别,得到识别结果,识别结果包括:道路上存在行人、道路上存在非静止的障碍物、道路上存在静止的障碍物。
下面结合具体的实例对本发明实施例所提出的训练方法进行验证,具体如下:
本发明实施例利用CIFAR10和MNIST数据集进行测试。CIFAR10由60000张32*32彩色图像组成,图像有10个类,每个类有6000个图像,它分别包含有50000个训练图像和10000个测试图像;MNIST由70000张28*28像素的灰度手写数字图像,图像有10个类,每个类有7000个样本,它分别包含有60000个训练图像和10000个测试图像;具体如表1所示:
表1
图像尺寸 | 图像通道数 | 图像类数 | 训练集数量 | 测试集数量 | |
CIFAR10 | 32*32 | 3 | 10 | 50000 | 10000 |
MNIST | 28*28 | 1 | 10 | 60000 | 10000 |
由于在高速移动场景下,每个客户端中间的数据集常常不满足独立同分布不假设,因此本发明实施例额外采用Dirichlet分布来为每个客户端划分数据集,是每个客户端上样本标签分布不同。
表2
为了评估和验证本发明实施例所训练出的全局模型的性能,本发明实施例首先测算了FL(Federated Learining)、FL+KD(Federated Learining+Knowledge Distillation)分别在IID和Non-IID情况下的通讯开销,采用CR(communication rounds)、TCC(totalcommunication cost)作为通信开销的主要评价指标,根据表2可以得出,本发明实施例所提供的方法在评价指标中都取得了较好的数值测算结果,对于模型的性能表现,本发明实施例使用目前流行的Basic(Centralized Machine Learning,集中式机器学习集中式机器学习)、联邦平均算法(FederatedAveragingAlgorithm,FedAVG)、联邦学习框架FedProx作为基准测试模型,并采用Acc、Precision、Recall、F1作为模型的主要评价指标,结果如表3所示:
表3
从上表3可看出,本发明所述方法在评价指标中都取得了较高的性能表现,并超过基准测试(FedAVG、FedProx)模型。
本发明实施例中央服务器将深度卷积神经网络模型的参数进行初始化,得到初始化深度卷积神经网络模型,并将初始化深度卷积神经网络模型传输至多个客户端;客户端通过设定蒸馏温度,将初始化深度卷积神经网络模型的模型参数反向蒸馏至本地深度卷积神经网络模型;客户端将采集的本地图像数据输入本地深度卷积神经网络模型进行图像识别,得到识别结果并计算损失函数,通过损失函数对本地深度卷积神经网络的参数进行更新,得到训练后的本地深度卷积神经网络模型;客户端通过剪枝算法对训练后的本地深度卷积神经网络模型中的编码器和分类器分别进行剪枝,得到剪枝后的编码器和剪枝后的分类器,并将剪枝后的编码器和剪枝后的分类器进行拼接,得到轻量化深度卷积神经网络模型;客户端通过设定蒸馏温度,将轻量化深度卷积神经网络模型的参数正向知识蒸馏至本地局部模型,并将本地局部模型输入中央服务器;中央服务器将多个客户端上传的本地局部模型进行聚合,得到全局模型,并判断全局模型是否满足预设训练条件;若是,则训练结束,将待识别的图像数据输入全局模型进行图像识别,得到识别结果;否则,将全局模型作为步骤1中的初始化深度卷积神经网络模型传输至多个客户端,并返回执行步骤2;与现有技术相比,本发明采用双向蒸馏的方式压缩模型的参数,极大程度上提高了通信效率并减少了聚合时的等待时间,通过剪枝算法对模型做进一步的压缩,有效的去除了局部模型中多余的参数以减少模型参数量,从而能够在提高通信和聚合效率的同时提升模型的精确性。
以上所述是本发明的优选实施方式,应当指出,对于本技术领域的普通技术人员来说,在不脱离本发明所述原理的前提下,还可以作出若干改进和润饰,这些改进和润饰也应视为本发明的保护范围。
Claims (6)
1.一种轻量级联邦学习训练方法,其特征在于,包括:
步骤1,中央服务器将深度卷积神经网络模型的参数进行初始化,得到初始化深度卷积神经网络模型,并将所述初始化深度卷积神经网络模型传输至多个客户端;
步骤2,所述客户端通过设定蒸馏温度,将所述初始化深度卷积神经网络模型的模型参数反向蒸馏至本地深度卷积神经网络模型;
步骤3,所述客户端将采集的本地图像数据输入所述本地深度卷积神经网络模型进行图像识别,得到识别结果并计算损失函数,通过所述损失函数对所述本地深度卷积神经网络的参数进行更新,得到训练后的本地深度卷积神经网络模型;
步骤4,所述客户端通过剪枝算法对训练后的本地深度卷积神经网络模型中的编码器和分类器分别进行剪枝,得到剪枝后的编码器和剪枝后的分类器,并将剪枝后的编码器和剪枝后的分类器进行拼接,得到轻量化深度卷积神经网络模型;
根据所述训练后的本地深度卷积神经网络模型的网络结构,将所述训练后的本地深度卷积神经网络模型分为编码器和分类器;
利用结构化剪枝的方式,评估所述编码器中每个卷积层中每个过滤器的影响系数,并所述影响系数低于预设值的过滤器进行修剪,得到剪枝后的编码器;
利用非结构化剪枝的方式,将所述分类器的权重绝对值小于预设阈值的权重进行修剪,得到剪枝后的分类器;
将所述剪枝后的编码器和所述剪枝后的分类器进行拼接,得到轻量化深度卷积神经网络模型;
步骤5,所述客户端通过设定蒸馏温度,将所述轻量化深度卷积神经网络模型的参数正向知识蒸馏至本地局部模型,并将所述本地局部模型输入所述中央服务器;
步骤6,所述中央服务器将多个所述客户端上传的本地局部模型进行聚合,得到全局模型,并判断所述全局模型是否满足预设训练条件;若是,则训练结束,将待识别的图像数据输入所述全局模型进行图像识别,得到识别结果;否则,将所述全局模型作为所述步骤1中的初始化深度卷积神经网络模型传输至多个客户端,并返回执行步骤2。
2.根据权利要求1所述的轻量级联邦学习训练方法,其特征在于,在所述客户端将采集的本地图像数据输入所述本地深度卷积神经网络模型进行图像识别之前,还包括:
对采集的本地图像数据进行数据标签规范化处理和异常数据删除处理,得到处理后的本地图像数据;
所述客户端将采集的本地图像数据输入所述本地深度卷积神经网络模型进行图像识别。
3.根据权利要求2所述的轻量级联邦学习训练方法,其特征在于,
根据所述训练后的本地深度卷积神经网络模型的网络特性,通过对所述训练后的本地深度卷积神经网络模型进行正则化,得到编码器和分类器,正则化的表达式为:
R(W)=REnc(WE)+RCls(WC)
其中,R(W)表示本地深度卷积神经网络模型的剪枝权重,REnc表示编码器的剪枝权重,RCls表示分类器的剪枝权重,WE表示编码器的权重,WC表示分类器的权重,||·||g是groupLasso算法,Fl是第l个卷积层中滤波器的数量,Chl是第l个卷积层中通道的个数,Rowl代表分类器中第l层的行数,Col1代表分类器中第l层的列数。
4.根据权利要求3所述的轻量级联邦学习训练方法,其特征在于,所述轻量化深度卷积神经网络模型的损失函数为:
F(W)=FD(W)+λR(W)
其中,FD(W)是轻量化深度卷积神经网络模型的损失函数,λ是结构化稀疏正则化的系数。
5.根据权利要求4所述的轻量级联邦学习训练方法,其特征在于,所述所述本地局部模型的损失函数为:
其中,β表示控制来自数据或其他模型知识比例的超参数,表示本地局部模型的交叉熵损失函数,DKL表示KL散度,pl表示本地深度卷积神经网络模型的预测值,pm表示本地局部模型的预测值。
6.根据权利要求5所述的轻量级联邦学习训练方法,其特征在于,所述全局模型训练终止的条件为:
直至所述全局模型的精度达到预设训练精度或迭代次数达到预设上限时,终止训练。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202311046071.0A CN117114148B (zh) | 2023-08-18 | 2023-08-18 | 一种轻量级联邦学习训练方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202311046071.0A CN117114148B (zh) | 2023-08-18 | 2023-08-18 | 一种轻量级联邦学习训练方法 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN117114148A CN117114148A (zh) | 2023-11-24 |
CN117114148B true CN117114148B (zh) | 2024-04-09 |
Family
ID=88794104
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202311046071.0A Active CN117114148B (zh) | 2023-08-18 | 2023-08-18 | 一种轻量级联邦学习训练方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN117114148B (zh) |
Citations (13)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN109389043A (zh) * | 2018-09-10 | 2019-02-26 | 中国人民解放军陆军工程大学 | 一种无人机航拍图片的人群密度估计方法 |
CN109886397A (zh) * | 2019-03-21 | 2019-06-14 | 西安交通大学 | 一种针对卷积层的神经网络结构化剪枝压缩优化方法 |
CN113205863A (zh) * | 2021-06-04 | 2021-08-03 | 广西师范大学 | 基于蒸馏的半监督联邦学习的个性化模型的训练方法 |
CN113505210A (zh) * | 2021-07-12 | 2021-10-15 | 广东工业大学 | 一种基于轻量化Actor-Critic生成式对抗网络的医疗问答生成系统 |
CN113705712A (zh) * | 2021-09-02 | 2021-11-26 | 广州大学 | 一种基于联邦半监督学习的网络流量分类方法和系统 |
CN114154643A (zh) * | 2021-11-09 | 2022-03-08 | 浙江师范大学 | 基于联邦蒸馏的联邦学习模型的训练方法、系统和介质 |
CN114547315A (zh) * | 2022-04-25 | 2022-05-27 | 湖南工商大学 | 一种案件分类预测方法、装置、计算机设备及存储介质 |
CN114663791A (zh) * | 2022-04-19 | 2022-06-24 | 重庆邮电大学 | 一种非结构化环境下面向剪枝机器人的枝条识别方法 |
CN114882582A (zh) * | 2022-04-06 | 2022-08-09 | 南方科技大学 | 基于联邦学习模式的步态识别模型训练方法与系统 |
CN115018039A (zh) * | 2021-03-05 | 2022-09-06 | 华为技术有限公司 | 一种神经网络蒸馏方法、目标检测方法以及装置 |
CN115272738A (zh) * | 2021-04-29 | 2022-11-01 | 华为技术有限公司 | 数据处理方法、模型的训练方法及装置 |
CN115358419A (zh) * | 2022-08-25 | 2022-11-18 | 浙江工业大学 | 一种基于联邦蒸馏的物联网室内定位方法 |
CN115511108A (zh) * | 2022-09-27 | 2022-12-23 | 河南大学 | 一种基于数据集蒸馏的联邦学习个性化方法 |
-
2023
- 2023-08-18 CN CN202311046071.0A patent/CN117114148B/zh active Active
Patent Citations (13)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN109389043A (zh) * | 2018-09-10 | 2019-02-26 | 中国人民解放军陆军工程大学 | 一种无人机航拍图片的人群密度估计方法 |
CN109886397A (zh) * | 2019-03-21 | 2019-06-14 | 西安交通大学 | 一种针对卷积层的神经网络结构化剪枝压缩优化方法 |
CN115018039A (zh) * | 2021-03-05 | 2022-09-06 | 华为技术有限公司 | 一种神经网络蒸馏方法、目标检测方法以及装置 |
CN115272738A (zh) * | 2021-04-29 | 2022-11-01 | 华为技术有限公司 | 数据处理方法、模型的训练方法及装置 |
CN113205863A (zh) * | 2021-06-04 | 2021-08-03 | 广西师范大学 | 基于蒸馏的半监督联邦学习的个性化模型的训练方法 |
CN113505210A (zh) * | 2021-07-12 | 2021-10-15 | 广东工业大学 | 一种基于轻量化Actor-Critic生成式对抗网络的医疗问答生成系统 |
CN113705712A (zh) * | 2021-09-02 | 2021-11-26 | 广州大学 | 一种基于联邦半监督学习的网络流量分类方法和系统 |
CN114154643A (zh) * | 2021-11-09 | 2022-03-08 | 浙江师范大学 | 基于联邦蒸馏的联邦学习模型的训练方法、系统和介质 |
CN114882582A (zh) * | 2022-04-06 | 2022-08-09 | 南方科技大学 | 基于联邦学习模式的步态识别模型训练方法与系统 |
CN114663791A (zh) * | 2022-04-19 | 2022-06-24 | 重庆邮电大学 | 一种非结构化环境下面向剪枝机器人的枝条识别方法 |
CN114547315A (zh) * | 2022-04-25 | 2022-05-27 | 湖南工商大学 | 一种案件分类预测方法、装置、计算机设备及存储介质 |
CN115358419A (zh) * | 2022-08-25 | 2022-11-18 | 浙江工业大学 | 一种基于联邦蒸馏的物联网室内定位方法 |
CN115511108A (zh) * | 2022-09-27 | 2022-12-23 | 河南大学 | 一种基于数据集蒸馏的联邦学习个性化方法 |
Also Published As
Publication number | Publication date |
---|---|
CN117114148A (zh) | 2023-11-24 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
US20190244362A1 (en) | Differentiable Jaccard Loss Approximation for Training an Artificial Neural Network | |
CN108090472B (zh) | 基于多通道一致性特征的行人重识别方法及其系统 | |
CN111506773B (zh) | 一种基于无监督深度孪生网络的视频去重方法 | |
CN107481209B (zh) | 一种基于卷积神经网络的图像或视频质量增强方法 | |
CN111062410B (zh) | 基于深度学习的星型信息桥气象预测方法 | |
CN109544204B (zh) | 一种基于轻量化多任务卷积神经网络的导购行为分析方法 | |
CN108596890B (zh) | 一种基于视觉测量率自适应融合的全参考图像质量客观评价方法 | |
CN109800795A (zh) | 一种果蔬识别方法及系统 | |
CN114998958B (zh) | 一种基于轻量化卷积神经网络的人脸识别方法 | |
CN111127435A (zh) | 基于双流卷积神经网络的无参考图像质量评估方法 | |
CN112749663B (zh) | 基于物联网和ccnn模型的农业果实成熟度检测系统 | |
CN113567159A (zh) | 一种基于边云协同的刮板输送机状态监测及故障诊断方法 | |
CN112767385A (zh) | 基于显著性策略与特征融合无参考图像质量评价方法 | |
CN116362325A (zh) | 一种基于模型压缩的电力图像识别模型轻量化应用方法 | |
CN115953630A (zh) | 一种基于全局-局部知识蒸馏的跨域小样本图像分类方法 | |
CN115358418A (zh) | 基于模型扰动的联邦学习分类模型训练方法 | |
CN117114148B (zh) | 一种轻量级联邦学习训练方法 | |
CN114359167A (zh) | 一种复杂场景下基于轻量化YOLOv4的绝缘子缺陷检测方法 | |
CN117853596A (zh) | 无人机遥感测绘方法及系统 | |
CN116362328A (zh) | 一种基于公平性特征表示的联邦学习异构模型聚合方法 | |
CN113780371B (zh) | 基于边缘计算与深度学习的绝缘子状态边缘识别方法 | |
CN114826949A (zh) | 一种通信网络状况预测方法 | |
CN113486929A (zh) | 基于残差收缩模块与注意力机制的岩石薄片图像识别方法 | |
CN111510740B (zh) | 转码方法、装置、电子设备和计算机可读存储介质 | |
CN114972900A (zh) | 一种电力多源数据筛选方法、装置及终端设备 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |