CN114626550A - 分布式模型协同训练方法和系统 - Google Patents
分布式模型协同训练方法和系统 Download PDFInfo
- Publication number
- CN114626550A CN114626550A CN202210272759.XA CN202210272759A CN114626550A CN 114626550 A CN114626550 A CN 114626550A CN 202210272759 A CN202210272759 A CN 202210272759A CN 114626550 A CN114626550 A CN 114626550A
- Authority
- CN
- China
- Prior art keywords
- category
- model
- label
- class
- training
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000012549 training Methods 0.000 title claims abstract description 89
- 238000000034 method Methods 0.000 title claims abstract description 29
- 230000004931 aggregating effect Effects 0.000 claims abstract description 3
- 239000013598 vector Substances 0.000 claims description 26
- 238000012935 Averaging Methods 0.000 claims description 18
- 238000003860 storage Methods 0.000 claims description 6
- 230000006870 function Effects 0.000 description 16
- 238000004891 communication Methods 0.000 description 15
- 238000004821 distillation Methods 0.000 description 14
- 238000009826 distribution Methods 0.000 description 14
- 238000013140 knowledge distillation Methods 0.000 description 14
- 238000010586 diagram Methods 0.000 description 13
- 238000010801 machine learning Methods 0.000 description 10
- 230000006835 compression Effects 0.000 description 7
- 238000007906 compression Methods 0.000 description 7
- 230000000694 effects Effects 0.000 description 7
- 238000004364 calculation method Methods 0.000 description 3
- 238000005457 optimization Methods 0.000 description 3
- 238000004458 analytical method Methods 0.000 description 2
- 230000033228 biological regulation Effects 0.000 description 2
- 238000013135 deep learning Methods 0.000 description 2
- 230000008569 process Effects 0.000 description 2
- 238000012546 transfer Methods 0.000 description 2
- 101100289792 Squirrel monkey polyomavirus large T gene Proteins 0.000 description 1
- 230000001133 acceleration Effects 0.000 description 1
- 230000002776 aggregation Effects 0.000 description 1
- 238000004220 aggregation Methods 0.000 description 1
- 238000013528 artificial neural network Methods 0.000 description 1
- 230000009286 beneficial effect Effects 0.000 description 1
- 230000005540 biological transmission Effects 0.000 description 1
- 238000013461 design Methods 0.000 description 1
- 238000005516 engineering process Methods 0.000 description 1
- 238000000605 extraction Methods 0.000 description 1
- 239000000835 fiber Substances 0.000 description 1
- 238000013508 migration Methods 0.000 description 1
- 230000005012 migration Effects 0.000 description 1
- 239000000126 substance Substances 0.000 description 1
- 238000013519 translation Methods 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N20/00—Machine learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
- G06F18/241—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
- G06F18/2415—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches based on parametric or probabilistic models, e.g. based on likelihood ratio or false acceptance rate versus a false rejection rate
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/047—Probabilistic or stochastic networks
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- Evolutionary Computation (AREA)
- Artificial Intelligence (AREA)
- General Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Life Sciences & Earth Sciences (AREA)
- Software Systems (AREA)
- Mathematical Physics (AREA)
- Computing Systems (AREA)
- Probability & Statistics with Applications (AREA)
- Bioinformatics & Computational Biology (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Evolutionary Biology (AREA)
- Biomedical Technology (AREA)
- General Health & Medical Sciences (AREA)
- Molecular Biology (AREA)
- Computational Linguistics (AREA)
- Biophysics (AREA)
- Health & Medical Sciences (AREA)
- Medical Informatics (AREA)
- Information Retrieval, Db Structures And Fs Structures Therefor (AREA)
Abstract
本公开提供了一种分布式模型协同训练方法,包括:在用户端使用本地样本数据进行模型的初始训练,其中本地样本数据具有硬标签;在该用户端将初始训练的结果分类,并获取各个类别的类别特征;将各个类别的类别特征传送至服务端;在服务端聚合多个用户端的各个类别的类别特征,并按类别获取每个类别的全局类别特征作为该类别的软标签;从服务端向多个用户端下发该软标签;以及在各个用户端根据软标签和硬标签进一步训练该模型至收敛。
Description
技术领域
本公开主要涉及机器学习,尤其涉及机器学习中的模型训练。
背景技术
联邦机器学习是具有隐私保护效果的分布式机器学习框架,能有效帮助多个用户在满足隐私保护、数据安全和政府法规的要求下,进行数据使用和机器学习建模。联邦学习作为分布式的机器学习范式,可以有效解决数据孤岛问题,让参与方在不共享数据的基础上联合建模,实现智能协作。
为了在不共享数据的情况下协作训练模型,用户需要上传模型参数更新,然而随着深度学习的不断发展,模型越来越复杂,规模越来越大,联邦学习中的通信开销将不受限制地增长,这将严重阻碍联邦学习的进展。
因此,本领域需要高效的联邦学习框架,在保证用户隐私的前提下提高模型的训练效率。
发明内容
为解决上述技术问题,本公开提供了一种分布式模型协同训练方案,该方案能够通过不同用户端之间的协同知识蒸馏,在保证用户隐私的前提下,仅传输低维模型知识,显著降低通信开销,并支持大规模模型的训练优化,极大提升了联邦学习的隐私计算效率。
在本公开一实施例中,提供了一种分布式模型协同训练方法,包括:在用户端使用本地样本数据进行模型的初始训练,其中本地样本数据具有硬标签;在用户端将初始训练的结果分类,并获取各个类别的类别特征;将各个类别的类别特征传送至服务端;在服务端聚合多个用户端的各个类别的类别特征,并按类别获取每个类别的全局类别特征作为类别的软标签;从服务端向多个用户端下发软标签;以及在各个用户端根据软标签和硬标签进一步训练模型至收敛。
在本公开另一实施例中,服务端部署有教师网络,而多个用户端部署有学生网络。
在本公开又一实施例中,各个类别的类别特征包括类别特征向量和类别权重。
在本公开另一实施例中,在各个用户端根据软标签和硬标签进一步训练模型至收敛包括采用模型输出和硬标签之间的交叉熵损失以及模型输出和软标签之间的KL散度损失的损失函数。
在本公开又一实施例中,在各个用户端根据软标签和硬标签进一步训练模型至收敛包括采用模型输出和硬标签之间的交叉熵损失以及模型输出和软标签之间的交叉熵损失的损失函数。
在本公开另一实施例中,类别特征向量是通过将类别的结果求平均获得的。
在本公开又一实施例中,按类别获取每个类别的全局类别特征是通过将所聚合多个用户端的类别的类别特征求平均获得的。
在本公开另一实施例中,求平均是加权平均或随机加权平均。
在本公开一实施例中,一种分布式模型协同训练系统,包括:训练模块,在用户端使用本地样本数据进行模型的初始训练,其中本地样本数据具有硬标签;特征上传模块,在用户端将初始训练的结果分类,并获取各个类别的类别特征,并且将各个类别的类别特征传送至服务端;以及标签下发模块,在服务端聚合多个用户端的各个类别的类别特征,按类别获取每个类别的全局类别特征作为类别的软标签,并从服务端向多个用户端下发软标签;其中训练模块在各个用户端根据软标签和硬标签进一步训练模型至收敛。
在本公开另一实施例中,服务端部署有教师网络,而多个用户端部署有学生网络。
在本公开又一实施例中,各个类别的类别特征包括类别特征向量和类别权重。
在本公开又一实施例中,训练模块在各个用户端根据软标签和硬标签进一步训练模型至收敛包括训练模块采用模型输出和硬标签之间的交叉熵损失以及模型输出和软标签之间的KL散度损失的损失函数。
在本公开另一实施例中,训练模块在各个用户端根据软标签和硬标签进一步训练模型至收敛包括训练模块采用模型输出和硬标签之间的交叉熵损失以及模型输出和软标签之间的交叉熵损失的损失函数。
在本公开另一实施例中,类别特征向量是通过将类别的结果求平均获得的。
在本公开又一实施例中,标签下发模块按类别获取每个类别的全局类别特征是通过标签下发模块将所聚合多个用户端的类别的类别特征求平均获得的。
在本公开另一实施例中,求平均是加权平均或随机加权平均。
在本公开一实施例中,提供了一种存储有指令的计算机可读存储介质,当这些指令被执行时使得机器执行如前所述的方法。
提供本概述以便以简化的形式介绍以下在详细描述中进一步描述的一些概念。本概述并不旨在标识所要求保护主题的关键特征或必要特征,也不旨在用于限制所要求保护主题的范围。
附图说明
本公开的以上发明内容以及下面的具体实施方式在结合附图阅读时会得到更好的理解。需要说明的是,附图仅作为所请求保护的发明的示例。在附图中,相同的附图标记代表相同或类似的元素。
图1是示出知识协同蒸馏模型框架的示意图;
图2是示出根据本公开一实施例的分布式模型协同训练方法的流程图;
图3是示出根据本公开一实施例的知识协同蒸馏模型框架在用户端和服务端的部署的示意图;
图4是示出根据本公开一实施例的基于分布式协同蒸馏的模型训练框架的示意图;
图5是示出根据本公开一实施例的分布式模型协同训练系统的框图。
具体实施方式
为使得本公开的上述目的、特征和优点能更加明显易懂,以下结合附图对本公开的具体实施方式作详细说明。
在下面的描述中阐述了很多具体细节以便于充分理解本公开,但是本公开还可以采用其它不同于在此描述的其它方式来实施,因此本公开不受下文公开的具体实施例的限制。
联邦机器学习是一种具有隐私保护效果的分布式机器学习框架,能有效帮助多个用户在满足隐私保护、数据安全和政府法规的要求下,进行数据使用和机器学习建模。联邦学习作为分布式的机器学习范式,可以有效解决数据孤岛问题,让参与方在不共享数据的基础上联合建模,实现智能协作。为了在不共享数据的情况下协作训练模型,用户需要上传模型参数更新,然而随着深度学习的不断发展,模型越来越复杂,规模越来越大,联邦学习中的通信开销将不受限制地增长,这将严重阻碍联邦学习的进展。
在工业应用中,通常要求模型要有好的预测,同时希望部署到应用中的模型使用较少的计算资源(存储空间、计算单元等),产生较低的时延。基于以上考虑,在联邦学习的应用中就有了模型压缩的动机:即希望有一个规模较小的模型,能达到和大模型一样或相当的结果。因此,模型压缩希望能先训练一个大而强的模型,然后将其包含的知识转移给小的模型。
知识蒸馏是一种简单有效的模型压缩/训练方法。知识蒸馏采用精度高、模型复杂度高的模型即Teacher网络(教师网络,下文中简称为T网络)的输出训练Student网络(学生网络,下文中简称为S网络),以期使计算量小参数少的小网络精度提升。使用知识蒸馏后的S网络能够达到较高的精度,而且更有利于实际应用部署,尤其是在移动设备中。
本公开提供一种分布式模型协同训练方案,该方案采用基于协同蒸馏的高效联邦学习框架,通过不同用户之间的协同知识蒸馏,在保证用户隐私的前提下,仅传输低维模型知识,显著降低通信开销,并支持大规模模型的训练优化,极大提升了联邦学习的隐私计算效率。
图1是示出知识蒸馏模型框架的示意图。
知识蒸馏指的是对网络模型中参数权重的一些抽取/迁移的操作。知识蒸馏的目的在于将T网络的知识蒸馏到S网络中去,其中T网络是效果比较好(比如拥有良好的性能和泛化能力)、但参数量大、计算成本高的网络,S网络是比较小的网络、表达能力有限。如果单单只用S网络去学习、去训练,那么它的训练效果不是很好,因为其表征能力有限,而用了T网络来指导,则可以提升S网络的模型效果,但是参数数量大幅降低,从而实现模型压缩与加速。
通常,神经网络使用softmax层来实现logits向概率值的转换。原始的softmax函数为:
其中qi是每个类别输出的概率,zi是每个类别输出的logits。
但是,直接使用softmax层的输出值作为软标签会带来的问题是:当softmax输出的概率分布熵相对较小时,负标签的值都接近于0,对损失函数的贡献非常小,小到可以忽略不计。
在知识蒸馏中,需要使用高温将知识“蒸馏”出来。此时,加了温度变量之后的softmax函数:
其中qi是每个类别输出的概率,zi是每个类别输出的logits,T是温度。当温度T=1时,即为标准Softmax公式。T越高,softmax的输出概率分布越趋于平滑,其分布的熵越大,负标签携带的信息会被相对地放大,模型训练将更加关注负标签。
在如图1所示的知识蒸馏模型框架中,先训练好T模型,利用高温Thigh(即大T)产生软标签,然后使用{软标签,Thigh}和{硬标签,T=1}同时训练S模型,最后设置温度T=1,使得S模型线上进行推理。
S网络不仅要跟真实数据的标签(Hard Target,硬标签)作比较,而且要和T网络的输出(Soft Target,软标签)做比较。软标签作为正则化项约束S网络中参数的分布,真实数据的硬标签是离散的,而软标签给出了类别之间的关联性。
如图1所示的知识蒸馏模型框架的损失函数由蒸馏损失(Lsoft,对应软标签,为softmax输出与软标签的误差)和学生损失(Lhard,对应硬标签,为softmax输出与硬标签的误差)加权得到。如下所示:
L=αLSOft+βLhard
在本公开一实施例中,Lsoft可以是S网络模型在高温T下的softmax输出和软标签的交叉熵,而Lhard可以是S网络模型在T=1下的softmax输出和硬标签(ground truth值)的交叉熵,如下所示:
其中,指T网络模型在温度等于T的条件下softmax输出在第i类上的值。指S网络模型在温度等于T的条件下softmax输出在第i类上的值(参见式1),ci指在第i类上的ground truth值,ci∈(0,1),正标签取1,负标签取0。
Lhard的必要性在于:T网络模型也有一定的错误率,使用ground truth可以有效降低错误被传播给S网络模型的可能性。如果S网络模型在T网络模型的教授之外,可以同时参考到标准,就可以有效地降低被T网络模型偶尔的错误“带偏”的可能性。
最后,α和β是关于Lsoft和Lhard的权重。通常而言,当Lhard权重较小时,能产生较好的效果。
图2是示出根据本公开一实施例的分布式模型协同训练方法200的流程图。
为了降低联邦学习的通信开销,当前采用的模型压缩方案通过本地用户对需要上传的模型进行压缩(例如dropout,count sketch等方法)。然而,直接针对模型参数的压缩都是有损压缩,会导致出现全局模型收敛速度降低,收敛后的模型精度降低等问题。这样的方案并不能满足复杂业务场景下业务应用的需要。
本公开的分布式模型协同训练方法采用知识蒸馏模型框架,并不直接对模型参数进行压缩,而是对每个用户端的模型的类别知识(也就是每个模型输出的类别标签)进行提炼蒸馏,并将该知识传输给服务端,由于该知识只和样本类别数量相关,而类别数量大小远远小于模型的参数数量,因此本方案可以极大降低传输成本,并且不会造成模型的精度损失。
在本公开的分布式模型协同训练方法中存在服务端和用户端,服务端进行用户信息的聚合和分析,并将分析结果下传给用户端,用户端利用聚合的信息、分析结果和自己的数据对模型进行本地更新,并将更新信息反馈给服务端。其中图1所述的T网络模型部署在服务端,而S网络模型部署在用户端。以下结合图2的流程图描述根据本公开一实施例的分布式模型协同训练方法。
在202,在用户端使用本地样本数据进行模型的初始训练,其中本地样本数据具有硬标签。
每个用户端在本地使用各自的样本数据训练几轮模型,不用训练至收敛。这些本地样本数据具有硬标签(Hard Target Hi)。硬标签指ground truth值,正标签取1,负标签取0。
在204,在用户端将初始训练的结果分类,并获取各个类别的类别特征。
用户端记录所有本地样本数据经过模型的输出结果(即softmax层的结果),然后将所有结果按类别进行分组,并求出每个组内的平均值作为该类别的特征向量。
在本公开一实施例中,该求平均值可以是求加权平均。
在本公开另一实施例中,该求平均值可以是求随机加权平均。
进一步地,用户端计算不同类别的样本数量的比例,得到每个类别的权重。
由此,用户端获得了各个类别的类别特征,即每个类别的特征向量和权重。
在206,用户端将各个类别的类别特征传送至服务端。
用户端将包括类别权重和类别特征向量的类别特征上传至服务端。
在208,在服务端聚合多个用户端的各个类别的类别特征,并按类别获取每个类别的全局类别特征作为类别的软标签。
服务端对所有用户端上传的类别特征向量按类别分组,然后在每个组内依据类别权重计算加权平均值,该加权平均值即为每个类别的全局类别特征。
该每个类别的全局类别特征即作为每个类别的软标签(Soft Target/Label Si)。
在210,从服务端向多个用户端下发软标签。
每个类别的软标签被从服务端下发至各个用户端。
在212,在各个用户端根据软标签和硬标签进一步训练模型至收敛。
各个用户端根据软标签和硬标签训练模型,将模型训练至收敛。
所采用的损失函数(即总损失)可包括两个部分。在本公开的分布式模型协同训练方案中,损失函数的设计实际上是基于最大化T网络模型和S网络模型的信息熵之间的互信息(Mutual Information)的目标。本领域技术人员可以理解,互信息是度量两个事件集合之间的相关性(Mutual Dependence),其捕捉到的是集合之间非线性的统计相关性。两个事件集合之间的互信息越大,其相关性越强。
在本公开一实施例中,如前所述,损失函数的第一部分是模型输出和硬标签之间的交叉熵损失,第二部分是模型输出和软标签之间的交叉熵损失。
在本公开另一实施例中,损失函数的第一部分是模型输出和硬标签之间的交叉熵损失,第二部分是模型输出和软标签之间的kl散度损失。
交叉熵主要刻画的是实际输出与期望输出的距离,也就是交叉熵的值越小,两个概率分布就越接近。而KL散度(Kullback-Leibler divergence)又称相对熵,用来描述两个概率分布P和Q的差异:
其中,pi是期望分布中每个类别输出的概率,qi是实际分布中每个类别输出的概率。
kl散度衡量了分布差异,kl散度的本质是交叉熵减信息熵,即,使用编码估计分布所需的bit数、与编码真实分布所需的最少bit数的差。当且仅当估计分布与真实分布相同时,kl散度为0。因此可以作为两个分布差异的衡量方法。
由此,本公开的分布式模型协同训练方法在不同用户端之间传输的是类别向量而不是模型参数向量,数据维度显著降低,因此极大提升了模型的通信效率,加速了模型训练。
图3是示出根据本公开一实施例的知识协同蒸馏模型框架在用户端和服务端的部署的示意图。
知识蒸馏在模型训练中的应用通常是先行离线训练好T网络模型,T网络模型的知识离线蒸馏转移至S网络模型,引导训练S网络模型,最后是训练好的S网络模型上线工作。这样的离线蒸馏中,往往T网络模型容量大、模型复杂,需要大量训练时间,当T网络模型和S网络模型之间的容量差异过大时,S网络可能很难学习好这些知识。但实际应用中,T网络模型和S网络模型之间的容量差异往往比较大。
本公开的分布式模型协同训练方案采用的是在线蒸馏。即,T网络模型和S网络模型都在线,其中小而弱的S网络模型部署在用户端,大且强的T网络模型部署在服务端。
在线蒸馏中的集中式协同蒸馏技术需要每个用户端拥有完全相同的样本,在服务端聚合各个用户端模型的所有参数向量和/或权重,由此模型之间的通信效率不高,并且当用户端的数据是非独立同分布时无法应用。
在本公开的分布式框架下,数据无法互通,并且不同用户端之间的数据并非是同分布的,因此,需要有效保护每个用户端的隐私数据,同时可以实现模型的协同训练优化。
如图3所示,用户端先将各自的带硬标签样本数据输入本地的S网络模型,进行初始训练,然后将输出结果(即Softmax层的结果)分类并获取每一类别的类别特征,包括类别特征向量和权重,上传至服务端。
服务端聚合多个用户端的类别特征,按类别分组,并针对每个类别计算出该类别跨这些用户端的全局类别特征。然后以高温T蒸馏出该类别的软标签,下发给各个用户端。
用户端接收到软标签之后,根据硬标签和软标签训练模型,其损失函数的两个部分分别为模型输出与硬标签之间的损失、以及模型输出与软标签之间的损失。
由此,本公开的分布式模型协同训练方案采用的是分布式协同蒸馏,其中采用类别向量进行知识的传输,在不同用户端传输的是类别向量而非模型的所有参数向量,数据维度显著降低,因此极大地提升了模型的通信效率,加速了模型训练。
图4是示出根据本公开一实施例的基于分布式协同蒸馏的模型训练框架的示意图。
如图4所示,小而弱的S网络模型分别部署在用户端1和用户端2,而大且强的T网络模型部署在服务端。本领域技术人员可以理解,用户端可以有多个,不限于图4中所示的个数。
在如图4所示的实施例中,用户端1或用户端2对本地的S网络模型进行初始训练,然后基于输出结果的分类计算每一类别的类别特征向量和类别权重,上传至服务端。
服务端聚合用户端1、用户端2以及其他用户端的类别特征向量和类别权重,用类别权重对类别特征向量进行加权平均,计算出该类别跨这些用户端的全局类别特征,将其作为该类别的软标签,下发给这些用户端。
由此,本公开的分布式模型协同训练方案采用了分布式协同蒸馏技术,避免了在服务端对各个用户模型权重的聚合,因此不同用户的非独立同分布数据并不会影响模型的收敛,也就是说,本公开的分布式模型协同训练方案在非独立同分布的隐私数据上的性能表现更好。
图5是示出根据本公开一实施例的分布式模型协同训练系统500的框图。
根据本公开一实施例的分布式模型协同训练系统500包括训练模块502、特征上传模块506和软标签下发模块508。
训练模块502在用户端接收到具有硬标签的本地样本数据之后进行模型的初始训练。每个用户端在本地使用各自的样本数据训练几轮模型,不用训练至收敛。这些本地样本数据具有硬标签(Hard Target Hi)。训练模块502输出初始训练的结果。
特征上传模块506在用户端对初始训练的结果分类,并获取各个类别的类别特征。特征上传模块506将所有本地样本数据经过模型的输出结果(即softmax层的结果)按类别进行分组,并求出每个组内的平均值作为该类别的特征向量。在本公开一实施例中,该求平均值可以是求加权平均。在本公开另一实施例中,该求平均值可以是求随机加权平均。
进一步地,特征上传模块506在用户端计算不同类别的样本数量的比例,得到每个类别的权重。
由此,特征上传模块506获得了各个类别的类别特征,即每个类别的特征向量和权重。特征上传模块506将各个类别的类别特征传送至服务端。
软标签下发模块508在服务端聚合多个用户端的各个类别的类别特征,并按类别获取每个类别的全局类别特征作为类别的软标签。
软标签下发模块508在服务端对所有用户端上传的类别特征向量按类别分组,然后在每个组内依据类别权重计算加权平均值,该加权平均值即为每个类别的全局类别特征。该每个类别的全局类别特征即作为每个类别的软标签(Soft Target/Label Si)。
软标签下发模块508从服务端向多个用户端下发软标签。每个类别的软标签被从服务端下发至各个用户端。
训练模块502在各个用户端根据软标签和硬标签进一步训练模型至收敛。
训练模块502根据软标签和硬标签训练各个用户端的模型,将模型训练至收敛。
所采用的损失函数(即总损失)可包括两个部分。在本公开一实施例中,如前所述,损失函数的第一部分是模型输出和硬标签之间的交叉熵损失,第二部分是模型输出和软标签之间的交叉熵损失。
在本公开另一实施例中,损失函数的第一部分是模型输出和硬标签之间的交叉熵损失,第二部分是模型输出和软标签之间的kl散度损失。
本公开的分布式模型协同训练系统采用知识蒸馏模型框架,并不直接对模型参数进行压缩,而是对每个用户端的模型的类别知识(也就是每个模型输出的类别标签)进行提炼蒸馏,并将该知识传输给服务端,由于该知识只和样本类别数量相关,而类别数量大小远远小于模型的参数数量,因此本方案可以极大降低传输成本,并且不会造成模型的精度损失。
以上描述的分布式模型协同训练方法和系统的各个步骤和模块可以用硬件、软件、或其组合来实现。如果在硬件中实现,结合本发明描述的各种说明性步骤、模块、以及电路可用通用处理器、数字信号处理器(DSP)、专用集成电路(ASIC)、现场可编程门阵列(FPGA)、或其他可编程逻辑组件、硬件组件、或其任何组合来实现或执行。通用处理器可以是处理器、微处理器、控制器、微控制器、或状态机等。如果在软件中实现,则结合本发明描述的各种说明性步骤、模块可以作为一条或多条指令或代码存储在计算机可读介质上或进行传送。实现本发明的各种操作的软件模块可驻留在存储介质中,如RAM、闪存、ROM、EPROM、EEPROM、寄存器、硬盘、可移动盘、CD-ROM、云存储等。存储介质可耦合到处理器以使得该处理器能从/向该存储介质读写信息,并执行相应的程序模块以实现本发明的各个步骤。而且,基于软件的实施例可以通过适当的通信手段被上载、下载或远程地访问。这种适当的通信手段包括例如互联网、万维网、内联网、软件应用、电缆(包括光纤电缆)、磁通信、电磁通信(包括RF、微波和红外通信)、电子通信或者其他这样的通信手段。
还应注意,这些实施例可能是作为被描绘为流程图、流图、结构图、或框图的过程来描述的。尽管流程图可能会把诸操作描述为顺序过程,但是这些操作中有许多操作能够并行或并发地执行。另外,这些操作的次序可被重新安排。
所公开的方法、装置和系统不应以任何方式被限制。相反,本发明涵盖各种所公开的实施例(单独和彼此的各种组合和子组合)的所有新颖和非显而易见的特征和方面。所公开的方法、装置和系统不限于任何具体方面或特征或它们的组合,所公开的任何实施例也不要求存在任一个或多个具体优点或者解决特定或所有技术问题。
上面结合附图对本发明的实施例进行了描述,但是本发明并不局限于上述的具体实施方式,上述的具体实施方式仅仅是示意性的,而不是限制性的,本领域的普通技术人员在本发明的启示下,在不脱离本发明宗旨和权利要求所保护的范围情况下,还可做出很多更改,这些均落在本发明的保护范围之内。
Claims (17)
1.一种分布式模型协同训练方法,包括:
在用户端使用本地样本数据进行模型的初始训练,其中所述本地样本数据具有硬标签;
在所述用户端将所述初始训练的结果分类,并获取各个类别的类别特征;
将所述各个类别的类别特征传送至服务端;
在所述服务端聚合多个用户端的各个类别的类别特征,并按类别获取每个类别的全局类别特征作为所述类别的软标签;
从所述服务端向所述多个用户端下发所述软标签;以及
在各个用户端根据所述软标签和所述硬标签进一步训练所述模型至收敛。
2.如权利要求1所述的方法,所述服务端部署有教师网络,而所述多个用户端部署有学生网络。
3.如权利要求1所述的方法,所述各个类别的类别特征包括类别特征向量和类别权重。
4.如权利要求1所述的方法,在各个用户端根据所述软标签和所述硬标签进一步训练所述模型至收敛包括采用模型输出和所述硬标签之间的交叉熵损失以及模型输出和所述软标签之间的KL散度损失的损失函数。
5.如权利要求1所述的方法,在各个用户端根据所述软标签和所述硬标签进一步训练所述模型至收敛包括采用模型输出和所述硬标签之间的交叉熵损失以及模型输出和所述软标签之间的交叉熵损失的损失函数。
6.如权利要求3所述的方法,所述类别特征向量是通过将所述类别的结果求平均获得的。
7.如权利要求1所述的方法,按类别获取每个类别的全局类别特征是通过将所聚合多个用户端的所述类别的类别特征求平均获得的。
8.如权利要求6或7所述的方法,所述求平均是加权平均或随机加权平均。
9.一种分布式模型协同训练系统,包括:
训练模块,在用户端使用本地样本数据进行模型的初始训练,其中所述本地样本数据具有硬标签;
特征上传模块,在所述用户端将所述初始训练的结果分类,并获取各个类别的类别特征,并且将所述各个类别的类别特征传送至服务端;以及
标签下发模块,在所述服务端聚合多个用户端的各个类别的类别特征,按类别获取每个类别的全局类别特征作为所述类别的软标签,并从所述服务端向所述多个用户端下发所述软标签;
其中所述训练模块在各个用户端根据所述软标签和所述硬标签进一步训练所述模型至收敛。
10.如权利要求9所述的系统,所述服务端部署有教师网络,而所述多个用户端部署有学生网络。
11.如权利要求9所述的系统,所述各个类别的类别特征包括类别特征向量和类别权重。
12.如权利要求9所述的系统,所述训练模块在各个用户端根据所述软标签和所述硬标签进一步训练所述模型至收敛包括所述训练模块采用模型输出和所述硬标签之间的交叉熵损失以及模型输出和所述软标签之间的KL散度损失的损失函数。
13.如权利要求9所述的系统,所述训练模块在各个用户端根据所述软标签和所述硬标签进一步训练所述模型至收敛包括所述训练模块采用模型输出和所述硬标签之间的交叉熵损失以及模型输出和所述软标签之间的交叉熵损失的损失函数。
14.如权利要求11所述的系统,所述类别特征向量是通过将所述类别的结果求平均获得的。
15.如权利要求9所述的系统,所述标签下发模块按类别获取每个类别的全局类别特征是通过所述求平均获得的。
16.如权利要求14或15所述的系统,所述求平均是加权平均或随机加权平均。
17.一种存储有指令的计算机可读存储介质,当所述指令被执行时使得机器执行如权利要求1-7中任一项所述的方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210272759.XA CN114626550A (zh) | 2022-03-18 | 2022-03-18 | 分布式模型协同训练方法和系统 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210272759.XA CN114626550A (zh) | 2022-03-18 | 2022-03-18 | 分布式模型协同训练方法和系统 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN114626550A true CN114626550A (zh) | 2022-06-14 |
Family
ID=81901181
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202210272759.XA Pending CN114626550A (zh) | 2022-03-18 | 2022-03-18 | 分布式模型协同训练方法和系统 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN114626550A (zh) |
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN115098885A (zh) * | 2022-07-28 | 2022-09-23 | 清华大学 | 数据处理方法、系统及电子设备 |
CN115271033A (zh) * | 2022-07-05 | 2022-11-01 | 西南财经大学 | 基于联邦知识蒸馏医学图像处理模型构建及其处理方法 |
Citations (7)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US20190205748A1 (en) * | 2018-01-02 | 2019-07-04 | International Business Machines Corporation | Soft label generation for knowledge distillation |
CN110223281A (zh) * | 2019-06-06 | 2019-09-10 | 东北大学 | 一种数据集中含有不确定数据时的肺结节图像分类方法 |
CN112486686A (zh) * | 2020-11-30 | 2021-03-12 | 之江实验室 | 基于云边协同的定制化深度神经网络模型压缩方法及系统 |
CN113408209A (zh) * | 2021-06-28 | 2021-09-17 | 淮安集略科技有限公司 | 跨样本联邦分类建模方法及装置、存储介质、电子设备 |
CN113705610A (zh) * | 2021-07-26 | 2021-11-26 | 广州大学 | 一种基于联邦学习的异构模型聚合方法和系统 |
CN114118158A (zh) * | 2021-11-30 | 2022-03-01 | 西安电子科技大学 | 反黑盒探测攻击的稳健电磁信号调制类型识别方法 |
CN114154643A (zh) * | 2021-11-09 | 2022-03-08 | 浙江师范大学 | 基于联邦蒸馏的联邦学习模型的训练方法、系统和介质 |
-
2022
- 2022-03-18 CN CN202210272759.XA patent/CN114626550A/zh active Pending
Patent Citations (7)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US20190205748A1 (en) * | 2018-01-02 | 2019-07-04 | International Business Machines Corporation | Soft label generation for knowledge distillation |
CN110223281A (zh) * | 2019-06-06 | 2019-09-10 | 东北大学 | 一种数据集中含有不确定数据时的肺结节图像分类方法 |
CN112486686A (zh) * | 2020-11-30 | 2021-03-12 | 之江实验室 | 基于云边协同的定制化深度神经网络模型压缩方法及系统 |
CN113408209A (zh) * | 2021-06-28 | 2021-09-17 | 淮安集略科技有限公司 | 跨样本联邦分类建模方法及装置、存储介质、电子设备 |
CN113705610A (zh) * | 2021-07-26 | 2021-11-26 | 广州大学 | 一种基于联邦学习的异构模型聚合方法和系统 |
CN114154643A (zh) * | 2021-11-09 | 2022-03-08 | 浙江师范大学 | 基于联邦蒸馏的联邦学习模型的训练方法、系统和介质 |
CN114118158A (zh) * | 2021-11-30 | 2022-03-01 | 西安电子科技大学 | 反黑盒探测攻击的稳健电磁信号调制类型识别方法 |
Cited By (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN115271033A (zh) * | 2022-07-05 | 2022-11-01 | 西南财经大学 | 基于联邦知识蒸馏医学图像处理模型构建及其处理方法 |
CN115271033B (zh) * | 2022-07-05 | 2023-11-21 | 西南财经大学 | 基于联邦知识蒸馏医学图像处理模型构建及其处理方法 |
CN115098885A (zh) * | 2022-07-28 | 2022-09-23 | 清华大学 | 数据处理方法、系统及电子设备 |
CN115098885B (zh) * | 2022-07-28 | 2022-11-04 | 清华大学 | 数据处理方法、系统及电子设备 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN114626550A (zh) | 分布式模型协同训练方法和系统 | |
CN114580663A (zh) | 面向数据非独立同分布场景的联邦学习方法及系统 | |
Zhang et al. | Prediction for network traffic of radial basis function neural network model based on improved particle swarm optimization algorithm | |
CN114091667A (zh) | 一种面向非独立同分布数据的联邦互学习模型训练方法 | |
CN112784920A (zh) | 云边端协同的旋转部件对抗域自适应故障诊断方法 | |
CN115374853A (zh) | 基于T-Step聚合算法的异步联邦学习方法及系统 | |
Dai et al. | Hybrid deep model for human behavior understanding on industrial internet of video things | |
CN115659254A (zh) | 一种双模态特征融合的配电网电能质量扰动分析方法 | |
CN112766603A (zh) | 一种交通流量预测方法、系统、计算机设备及存储介质 | |
CN115858675A (zh) | 基于联邦学习框架的非独立同分布数据处理方法 | |
Wang et al. | Deep joint source-channel coding for multi-task network | |
CN115146764A (zh) | 一种预测模型的训练方法、装置、电子设备及存储介质 | |
Zhou | Deep embedded clustering with adversarial distribution adaptation | |
Wang et al. | Knowledge-enhanced semi-supervised federated learning for aggregating heterogeneous lightweight clients in iot | |
CN116244484B (zh) | 一种面向不平衡数据的联邦跨模态检索方法及系统 | |
CN113887806B (zh) | 长尾级联流行度预测模型、训练方法及预测方法 | |
CN111652021A (zh) | 一种基于bp神经网络的人脸识别方法及系统 | |
CN115238749A (zh) | 一种基于Transformer的特征融合的调制识别方法 | |
CN115965078A (zh) | 分类预测模型训练方法、分类预测方法、设备及存储介质 | |
CN108734291A (zh) | 一种利用正确性反馈训练神经网络的伪标签生成器 | |
CN115348182A (zh) | 一种基于深度堆栈自编码器的长期频谱预测方法 | |
Hu et al. | Learning Multi-expert Distribution Calibration for Long-tailed Video Classification | |
CN110163249B (zh) | 基于用户参数特征的基站分类识别方法及系统 | |
CN113033653A (zh) | 一种边-云协同的深度神经网络模型训练方法 | |
Lee et al. | Application of end-to-end deep learning in wireless communications systems |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |