CN117973505A - 一种联邦学习模型训练方法、装置、电子设备和存储介质 - Google Patents
一种联邦学习模型训练方法、装置、电子设备和存储介质 Download PDFInfo
- Publication number
- CN117973505A CN117973505A CN202410267744.3A CN202410267744A CN117973505A CN 117973505 A CN117973505 A CN 117973505A CN 202410267744 A CN202410267744 A CN 202410267744A CN 117973505 A CN117973505 A CN 117973505A
- Authority
- CN
- China
- Prior art keywords
- model
- sample data
- pseudo sample
- federal learning
- learning model
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000012549 training Methods 0.000 title claims abstract description 88
- 238000000034 method Methods 0.000 title claims abstract description 76
- 238000003860 storage Methods 0.000 title claims abstract description 17
- 238000005070 sampling Methods 0.000 claims abstract description 78
- 230000006854 communication Effects 0.000 claims description 130
- 238000004891 communication Methods 0.000 claims description 129
- 230000006870 function Effects 0.000 claims description 105
- 238000005457 optimization Methods 0.000 claims description 20
- 230000000694 effects Effects 0.000 claims description 18
- 238000007906 compression Methods 0.000 claims description 17
- 230000006835 compression Effects 0.000 claims description 17
- 238000004590 computer program Methods 0.000 claims description 16
- 230000008485 antagonism Effects 0.000 claims description 13
- 230000004044 response Effects 0.000 claims description 6
- 230000008569 process Effects 0.000 description 19
- 238000010586 diagram Methods 0.000 description 13
- 238000013140 knowledge distillation Methods 0.000 description 12
- 238000004821 distillation Methods 0.000 description 11
- 238000012545 processing Methods 0.000 description 9
- 238000009826 distribution Methods 0.000 description 4
- 230000005540 biological transmission Effects 0.000 description 3
- 238000004364 calculation method Methods 0.000 description 3
- 238000013144 data compression Methods 0.000 description 3
- 238000005516 engineering process Methods 0.000 description 3
- 238000011478 gradient descent method Methods 0.000 description 2
- 238000007726 management method Methods 0.000 description 2
- 238000012795 verification Methods 0.000 description 2
- 208000009119 Giant Axonal Neuropathy Diseases 0.000 description 1
- 238000004458 analytical method Methods 0.000 description 1
- 238000013473 artificial intelligence Methods 0.000 description 1
- 230000006399 behavior Effects 0.000 description 1
- 230000009286 beneficial effect Effects 0.000 description 1
- 239000002131 composite material Substances 0.000 description 1
- 238000013461 design Methods 0.000 description 1
- 238000011161 development Methods 0.000 description 1
- 201000003382 giant axonal neuropathy 1 Diseases 0.000 description 1
- 238000012804 iterative process Methods 0.000 description 1
- 238000010801 machine learning Methods 0.000 description 1
- 238000012423 maintenance Methods 0.000 description 1
- 238000004519 manufacturing process Methods 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 238000003062 neural network model Methods 0.000 description 1
- 238000010606 normalization Methods 0.000 description 1
- 230000003287 optical effect Effects 0.000 description 1
- 230000000750 progressive effect Effects 0.000 description 1
- 238000013139 quantization Methods 0.000 description 1
- 238000006467 substitution reaction Methods 0.000 description 1
Landscapes
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
Abstract
本发明涉及一种联邦学习模型训练方法、装置、电子设备和存储介质,联邦学习模型包括学生模型和预训练的教师模型;包括:对联邦学习模型的嵌入层进行采样,生成伪样本数据;基于所述教师模型对所述伪样本数据进行分类,确定所述伪样本数据对应的软标签;采用所述伪样本数据和所述软标签对所述学生模型进行训练,生成目标模型参数;依据所述目标模型参数更新所述联邦学习模型;本发明实施例可以在保护隐私的同时,通过嵌入结构的样本提取和学习一组伪数据来提高联邦学习模型的性能。
Description
技术领域
本发明涉及神经网络模型训练技术领域,具体涉及一种联邦学习模型训练方法、一种联邦学习模型训练装置、一种电子设备和一种存储介质。
背景技术
随着技术的发展,大型模型在车端的应用正在快速扩展。如说利用大型模型来分析驾驶者和乘客的行为和偏好,从而提供个性化的娱乐、导航和车内服务。包括语音助手、智能推荐系统、驾驶习惯分析等。而且随着技术的进步,特别是在人工智能、机器学习、传感器技术和数据处理方面,大型模型在车端的应用将继续扩展,为驾驶安全、效率、舒适性和车辆维护带来革命性的变化。
然而,目前训练大模型需要用到大量真实数据才能提升大模型泛化能力,且目前大部分预训练的大模型已经涵盖了开源的数据,因此要得到满足使用要求的大型模型,需为该大型模型提供更多的领域数据。但目前训练大模型的数据因涉及数据安全及成本等问题,往往很难得到大量的优质数据供大模型进行训练导致数据隐私的差。而如常用的基于差分的方式进行基于知识蒸馏的过程,使用梯度裁剪和加噪声等步骤,模型训练的过程复杂,效率低下。
发明内容
本发明的目的之一在于提供一种联邦学习模型训练方法,以解决现有技术中的模型训练的过程复杂,效率低下和隐私性差的问题;目的之二在于提供一种联邦学习模型训练装置;目的之三在于提供一种电子设备;目的之四在于提供一种存储介质。
为了实现上述目的,本发明采用的技术方案如下:
一种联邦学习模型训练方法,联邦学习模型包括学生模型和预训练的教师模型;所述方法包括:
对联邦学习模型的嵌入层进行采样,生成伪样本数据;
基于所述教师模型对所述伪样本数据进行分类,确定所述伪样本数据对应的软标签;
采用所述伪样本数据和所述软标签对所述学生模型进行训练,生成目标模型参数;
依据所述目标模型参数更新所述联邦学习模型。
可选的,所述方法还包括:
获取所述教师模型与所述学生模型之间的通信指标参数和性能指标参数;
依据所述性能指标参数和所述通信指标参数,确定通信效率函数值;
依据所述通信效率函数值控制所述教师模型与所述学生模型之间的通信。
可选的,所述方法还包括:
响应于所述联邦学习模型的更新,生成迭代计数值;
基于更新后的联邦学习模型,循环执行对联邦学习模型的嵌入层进行采样,生成伪样本数据的步骤和所述基于所述教师模型对所述伪样本数据进行分类,确定所述伪样本数据对应的软标签的步骤,直至所述迭代计数值满足预设迭代条件。
可选的,所述对联邦学习模型的嵌入层进行采样,生成伪样本数据的步骤包括:
对联邦学习模型的嵌入层进行随机采样,生成初始样本集;
对所述初始样本集进行目标采样,生成伪样本数据。
可选的,所述对联邦学习模型的嵌入层进行采样,生成伪样本数据的步骤还包括:
依据所述教师模型与所述学生模型,确定对抗效应强度;
依据所述对抗效应强度和所述伪样本数据建立对抗损失函数;
对所述对抗损失函数进行梯度下降优化,得到目标伪样本;
采用所述目标伪样本更新所述伪样本数据。
可选的,所述对所述初始样本集进行目标采样,生成伪样本数据的步骤包括:
基于所述联邦学习模型的嵌入层的交叉熵和所述初始样本集建立目标损失函数;
对所述目标损失函数进行梯度下降优化,生成伪样本数据。
可选的,所述基于所述教师模型对所述伪样本数据进行分类,确定所述伪样本数据对应的软标签的步骤包括:
依据所述教师模型的模型参数和前向传播函数计算所述伪样本数据,生成初始分类值;
对所述初始分类值进行归一化,生成所述伪样本数据对应的软标签。
可选的,所述采用所述伪样本数据和所述软标签对所述学生模型进行训练,生成目标模型参数的步骤包括:
采用所述伪样本数据训练所述学生模型,生成输出数据;
依据所述输出数据和所述软标签建立交叉熵损失函数;
循环计算所述交叉熵损失函数直至满足预设次数,生成目标模型参数。
可选的,所述方法还包括:
依据所述通信指标参数,确定压缩策略;
在所述教师模型与所述学生模型通信时,执行所述压缩策略。
可选的,所述依据所述通信效率函数值控制所述教师模型与所述学生模型之间的通信的步骤包括:
依据所述通信效率函数值确定通讯轮次。
可选的,所述依据所述通信效率函数值控制所述教师模型与所述学生模型之间的通信的步骤包括:
依据所述通信效率函数值确定异步通讯策略。
一种联邦学习模型训练装置,联邦学习模型包括学生模型和预训练的教师模型;所述装置包括:
采样模块,用于对联邦学习模型的嵌入层进行采样,生成伪样本数据;
分类模块,用于基于所述教师模型对所述伪样本数据进行分类,确定所述伪样本数据对应的软标签;
训练模块,用于采用所述伪样本数据和所述软标签对所述学生模型进行训练,生成目标模型参数;
更新模块,用于依据所述目标模型参数更新所述联邦学习模型。
一种电子设备,包括处理器、存储器及存储在所述存储器上并能够在所述处理器上运行的计算机程序,所述计算机程序被所述处理器执行时实现如上所述的联邦学习模型训练方法的步骤。
一种计算机可读存储介质,所述计算机可读存储介质上存储计算机程序,所述计算机程序被处理器执行时实现如上所述的联邦学习模型训练方法的步骤。
本发明的有益效果:
(1)本发明通过在对联邦学习模型的嵌入层进行采样,生成用于知识蒸馏的伪样本数据,避免了对真实数据的依赖和使用;知识蒸馏过程不需要使用教师模型的实际数据,强化了整个联邦学习模型在联邦学习过程中的数据隐私保护;进一步地,通过多样性随机采样方法得到对伪样本数据,通过这些伪样本数进行训练可以增强模型的泛化能力。
(2)本发明在蒸馏的过程中,无需依赖生成对抗网络或辅助数据,减轻了计算和通信的负担;提高了训练的效率;不依赖于多个中间模型的协同训练和蒸馏,从而可能减少了模型训练和管理的复杂性;并且不依赖于可用的公共数据集进行蒸馏,在缺乏大量公共数据集的应用场景中使用该方法训练得到的模型可以更具备实用性和准确性。
(3)本发明对通讯协议进行优化,基于通信效率函数值确定不同的通讯策略,以减少在模型训练过程中的通讯开销,并确保在保护隐私的同时实现高效的模型更新。
附图说明
图1为本发明的一种联邦学习模型训练方法实施例的步骤流程图;
图2为本发明的另一种联邦学习模型训练方法实施例的步骤流程图;
图3为本发明的另一种联邦学习模型训练方法实施例的采样过程伪代码执行示意图;
图4为本发明的另一种联邦学习模型训练方法实施例的模型训练示意图;
图5为本发明的另一种联邦学习模型训练方法实施例的模型使用全流程示意图;
图6为本发明的一种联邦学习模型训练装置实施例的结构框图;
图7为本发明的一种车辆实施例的处理器和存储介质的示意图;
图8为本发明的一种计算机存储介质实施例的示意图。
具体实施方式
以下将参照附图和优选实施例来说明本发明的实施方式,本领域技术人员可由本说明书中所揭露的内容轻易地了解本发明的其他优点与功效。本发明还可以通过另外不同的具体实施方式加以实施或应用,本说明书中的各项细节也可以基于不同观点与应用,在没有背离本发明的精神下进行各种修饰或改变。应当理解,优选实施例仅为了说明本发明,而不是为了限制本发明的保护范围。
需要说明的是,以下实施例中所提供的图示仅以示意方式说明本发明的基本构想,遂图式中仅显示与本发明中有关的组件而非按照实际实施时的组件数目、形状及尺寸绘制,其实际实施时各组件的型态、数量及比例可为一种随意的改变,且其组件布局型态也可能更为复杂。
为了后续的表述清晰,针对以下符号的定义进行说明:
参照图1,示出了本发明的一种联邦学习模型训练方法实施例的步骤流程图。联邦学习模型包括学生模型和预训练的教师模型;该联邦学习模型的结构主要由学生模型和预训练的教师模型组成。学生模型提供全局共享的模型,预训练的教师模型下载模型并训练自己的数据集,同时更新模型参数。
所述联邦学习模型训练方法具体可以包括如下:
步骤101,对联邦学习模型的嵌入层进行采样,生成伪样本数据;
在本发明实施例中,可以对从联邦学习模型的嵌入层中抽取数据,对嵌入层中的训练编码器和解码器部分的输入数据进行采样,生成伪样本数据。
步骤102,基于所述教师模型对所述伪样本数据进行分类,确定所述伪样本数据对应的软标签;
对采样生成的伪样本数据,基于教师模型对其进行分类,确定出伪样本数据对应的软标签。软标签包含了教师模型对每个类别的预测概率,能够提供更多的信息和更好的模型泛化能力。
步骤103,采用所述伪样本数据和所述软标签对所述学生模型进行训练,生成目标模型参数;
然后再采用伪样本数据和软标签对学生模型进行训练,得到模型参数,即单轮次训练得到的目标模型参数;
步骤104,依据所述目标模型参数更新所述联邦学习模型。
将得到的目标模型参数不断更新联邦学习模型,以得到满足使用要求的联邦学习模型。
本发明实施例通过对联邦学习模型的嵌入层进行采样,生成伪样本数据;基于教师模型对所述伪样本数据进行分类,确定所述伪样本数据对应的软标签;采用所述伪样本数据和所述软标签对学生模型进行训练,生成目标模型参数;依据所述目标模型参数更新所述联邦学习模型;通过在对联邦学习模型的嵌入层进行采样,生成用于知识蒸馏的伪样本数据,避免了对真实数据的依赖和使用;知识蒸馏过程不需要使用教师模型的实际数据,强化了整个联邦学习模型在联邦学习过程中的数据隐私保护;在蒸馏的过程中,无需依赖生成对抗网络或辅助数据,减轻了计算和通信的负担;提高了训练的效率;不依赖于多个中间模型的协同训练和蒸馏,从而可能减少了模型训练和管理的复杂性;并且不依赖于可用的公共数据集进行蒸馏,在缺乏大量公共数据集的应用场景中使用该方法训练得到的模型可以更具备实用性和准确性。
参照图2,示出了本发明的另一种联邦学习模型训练方法实施例的步骤流程图。联邦学习模型包括学生模型和预训练的教师模型;所述联邦学习模型训练方法具体可以包括如下:
步骤201,对联邦学习模型的嵌入层进行采样,生成伪样本数据;
可以对联邦学习模型的嵌入层进行采样,生成伪样本数据,通过伪样本数据来表征全面考虑各客户端独立的数据。
具体地,所述对联邦学习模型的嵌入层进行采样,生成伪样本数据的步骤包括:对联邦学习模型的嵌入层进行随机采样,生成初始样本集;对所述初始样本集进行目标采样,生成伪样本数据。
在本发明实施例中,可以直接从嵌入层中,基于随机采样的方式抽取数据,生成初始样本集。然后对初始样本集进行目标采样,生成初始样本集;通过这样的采样方法来增强模型的泛化能力。
进一步地,所述对所述初始样本集进行目标采样,生成伪样本数据的步骤包括:基于所述联邦学习模型的嵌入层的交叉熵和所述初始样本集建立目标损失函数;对所述目标损失函数进行梯度下降优化,生成伪样本数据。
可以在嵌入层中提取的伪样本后,使用目标损失进行优化,以便与教师模型在上的输出分布对齐。可以通过联邦学习模型的嵌入层的交叉熵和初始样本集构建目标损失函数:
其中,表示从分布γk中随机生成的伪标签集。
然后,使用梯度下降方法优化目标损失函数。在这个优化过程中,生成伪样本数据。
此外,所述对联邦学习模型的嵌入层进行采样,生成伪样本数据的步骤还包括:依据所述教师模型与所述学生模型,确定对抗效应强度;依据所述对抗效应强度和所述伪样本数据建立对抗损失函数;对所述对抗损失函数进行梯度下降优化,得到目标伪样本;采用所述目标伪样本更新所述伪样本数据。
为了进一步提高样本的质量并增加伪样本的多样性,可以对抗采样的方式(但无需使用对抗生成网络GANs)。可以获得在ωk上显示正确的标签的伪样本数据,同时在ωs上产生显著的损失。
可以依据教师模型与学生模型确定对抗效应强度λ;所述对抗效应强度和所述伪样本数据建立对抗损失函数:
这里,参数λ控制教师(客户端模型)和学生(全局模型)之间的对抗效应强度,表示从分布γk中随机生成的伪标签集。
使用梯度下降方法来优化对抗损失函数得到目标伪样本;再采用目标伪样本更新伪样本数据,以筛选伪样本数据。
对于随机采样、目标采样和对抗采样的使用时机可以在T轮通信中,服务器选择一组在线可训练的客户端)。然后将学生模型发送到客户端进行更新。在一轮通信后,从平均参数开始作为蒸馏的起点。通过I轮的采样和微调,得到该轮的最佳模型。在最后一轮中,通过增加参数I和对抗项来进行后处理,从而在最后一轮中增强对抗效应强度,以实现最佳性能。
综上采样过程可知,输入:通信轮次T,客户端数量K,客户端数据集D^K_k=1,学生ωs的参数,对抗采样迭代次数I和I*,更新步骤η,η*和β。输出:全局模型参数ω_s。
对应的伪代码可以参照图3:
1:对于t=1→T执行
2:S_t←随机均匀选择活跃的客户端
3:对于k∈S_t执行
4:ω_k←ClientUpdate(ω_s;D_k;s)
5:结束
6:ω_s←FedIS({ω}^k∈S_t,I)
7:结束
8:ω_s←FedIS({ω}^K_k=1,I*)
9:返回ω_s
10:FedIS({ω}^M_m=1,I):
11:ω_s←(1/M)∑^M_m=1ω_m
12:对于i=1→I执行
13:采样一个代理数据集{θ^rd,θ^tr,θ^ad}^M_m=1和伪标签{γ^tr,γ^ad}^M_m=1
14:
15:
16:
17:结束
18:返回ω_s
步骤202,基于所述教师模型对所述伪样本数据进行分类,确定所述伪样本数据对应的软标签;
在得到伪样本数据后,可以使用预训练的教师模型来为伪样本数据进行分类,生成软标签。该软标签与伪样本数据对应,即一个伪样本数据至少具有一个软标签。
具体地,所述基于所述教师模型对所述伪样本数据进行分类,确定所述伪样本数据对应的软标签的步骤包括:依据所述教师模型的模型参数和前向传播函数计算所述伪样本数据,生成初始分类值;对所述初始分类值进行归一化,生成所述伪样本数据对应的软标签。
首先,可以依据教师模型的模型参数和前向传播函数计算伪样本数据,对伪样本数据进行分类确定初始分类值;再对初始分类值进行归一化,生成伪样本数据对应的软标签。即给定一个数据样本x,其软标签y可以通过以下公式获得:
y=σ(f(x;θT))
其中,σ是softmax函数,f是教师模型的前向传播函数,θT是教师模型的参数。
步骤203,采用所述伪样本数据和所述软标签对所述学生模型进行训练,生成目标模型参数;
在对模型的知识蒸馏过程中,在得到软标签后,即可以采用伪样本数据和软标签对学生模型进行训练,完成知识蒸馏,生成本次训练对应的目标模型参数。
具体地,所述采用所述伪样本数据和所述软标签对所述学生模型进行训练,生成目标模型参数的步骤包括:采用所述伪样本数据训练所述学生模型,生成输出数据;依据所述输出数据和所述软标签建立交叉熵损失函数;循环计算所述交叉熵损失函数直至满足预设次数,生成目标模型参数。
对于学生模型的训练可以是针对伪样本数据和软标签作为输入训练学生模型。可以首先采用伪样本数据训练学生模型,得到学生模型的预测结果,即输出数据。在通过输出数据和软标签之间的差异,确定交叉熵,进而建立交叉熵损失函数。然后对交叉熵损失函数循环计算每个软标签的类别数量对应的次数,从而生成目标模型参数。其中训练过程中的交叉熵损失函数,其形式如下:
其中,N是类别的数量,y是软标签,是学生模型的预测。
步骤204,依据所述目标模型参数更新所述联邦学习模型;
在得到目标模型参数后,可以采用目标模型参数更新联邦学习模型,以使得联邦学习模型可以更新到更准确的状态。
步骤205,响应于所述联邦学习模型的更新,生成迭代计数值;
在本发明实施例中,还可以在每一次联邦学习模型的更新,进行记录,生成迭代计数值。
步骤206,基于更新后的联邦学习模型,循环执行对联邦学习模型的嵌入层进行采样,生成伪样本数据的步骤和所述基于所述教师模型对所述伪样本数据进行分类,确定所述伪样本数据对应的软标签的步骤,直至所述迭代计数值满足预设迭代条件。
重复执行对联邦学习模型的嵌入层进行采样,生成伪样本数据的步骤和基于所述教师模型对所述伪样本数据进行分类,确定所述伪样本数据对应的软标签的步骤,即在每一轮通信中,首先进行智能采样,然后进行无数据知识蒸馏。具体来说,采样负责生成高质量的合成数据,而无数据知识蒸馏则利用这些数据来进行模型的训练和优化。模型持续优化关注于在整个模型生命周期中不断提升其性能和泛化能力。通过迭代优化过程,在该过程交替进行智能采样和无数据知识蒸馏,以在多轮通信中逐步提升模型的性能。
为此,可以一个迭代优化函数o(θ,φ,T),其中θ表示模型参数,φ表示采样参数,而T表示迭代次数。该函数的目的是在T轮迭代中优化模型参数θ和采样参数φ。在多轮通信的过程中,模型参数θ和采样参数φ将在每一轮中得到更新和优化。通过多轮的迭代,提升模型的性能。
在实际应用时,可以初始化模型参数θ和采样参数φ。然后,在每一轮通信中,首先进行智能采样,生成合成数据。接下来,使用这些合成数据和对应的软标签进行无数据知识蒸馏,优化模型参数θ。在多轮通信的过程中,我们将不断迭代优化模型参数θ和采样参数φ,直到模型性能满足预定的标准或达到预定的迭代次数T即满足预设迭代条件。
在本发明的一可选实施例中,所述方法还包括:
步骤S1,获取所述教师模型与所述学生模型之间的通信指标参数和性能指标参数;
在本发明实施例中,可以确定教师模型与学生模型之间的通信指标参数和性能指标参数,性能指标参数可以包括模型在验证集上的准确率、分数或其他相关指标。通信指标参数可以为通信轮次、传输的数据量等。
步骤S2,依据所述性能指标参数和所述通信指标参数,确定通信效率函数值;
依据能指标参数和通信指标参数之间的关系,确定通信效率函数,基于通信效率函数确定每次训练后的确定通信效率函数值。
其中,通信效率函数C(θ.φ)为:
其中:
P(θ,φ)表示在参数θ和φ下模型的性能度量,即性能指标参数。
T(θ.φ)表示在参数θ和φ下的通信开销,即通信指标参数。
目标是在保持P(θ,φ)达到一定阈值的前提下,最小化T(θ,φ),即:
其中Pmin是模型性能的最低要求。
步骤S3,依据所述通信效率函数值控制所述教师模型与所述学生模型之间的通信。
依据通信效率函数值对不同,确定对应的通信策略来控制教师模型与学生模型之间的通信。通过优化通信的策略,以减少在模型训练过程中的通讯开销,并确保在保护隐私的同时实现高效的模型更新。
具体地,所述依据所述通信效率函数值控制所述教师模型与所述学生模型之间的通信的步骤包括:依据所述通信效率函数值确定通讯轮次。
在本发明实施例中,可以通过优化通信效率函数C(θ,φ)来智能选择通讯轮次。具体来说,我们可以在每个通讯轮次结束后,评估当前的P(θ,φ)和T(θ,φ),并根据通信效率函数的值来决定是否继续进行下一轮的通讯。
具体地,所述依据所述通信效率函数值控制所述教师模型与所述学生模型之间的通信的步骤包括:依据所述通信效率函数值确定异步通讯策略。
考虑到在实际的联邦学习场景中,不同的客户端可能因为网络条件、计算资源等因素,无法保证同时进行模型训练和更新。因此,可以允许客户端在不同的时间点与服务器进行通讯。异步通讯策略不仅能够适应各种不稳定的网络环境,还能够充分利用各个客户端的计算资源,提高系统的整体效率。
在异步通讯策略中,可以根据通信效率函数C(θ,φ)来动态调整客户端的通讯时间点。例如,当某个客户端的模型性能P(θ,φ)达到一定阈值时,可以允许其与服务器进行通讯,从而减少不必要的通讯轮次。
在本发明实施例中,依据所述通信指标参数,确定压缩策略;在所述教师模型与所述学生模型通信时,执行所述压缩策略。
为了减少在通讯过程中的数据传输量可以进行数据压缩,在保证模型更新质量的同时,显著减少需要传输的数据大小。具体来说,在数据传输的压缩方面,我们可以通过分析T(θ,φ)中的数据量部分,来确定合适的数据压缩策略。例如,可以通过调整梯度压缩中的阈值或模型参数量化中的精度,来在保证P(θ,φ)的同时,减小T(θ,φ)。
其中数据压缩策略包括但不限于如下的方式:
梯度压缩,只传输模型参数梯度的重要部分(例如,大于某个阈值的梯度值),而忽略那些对模型更新贡献较小的梯度值。
模型参数量化,将模型参数量化到较低的精度,以减少每个参数的位数,从而减少数据传输量。
稀疏更新,只更新那些在多个通讯轮次中表现出较大变化的模型参数,而忽略那些变化较小的参数。
本发明实施例,通过智能采样策略生成的合成数据能够更好地模拟原始数据的分布和特性,从而在无数据知识蒸馏过程中实现更精确的模型训练。此外,由于本方法不依赖于原始数据,因此在处理大规模模型时,它能够显著降低计算和通信的开销,提高模型训练和优化的效率。
综上,即本发明中的联邦学习模型的训练过程可以参照图4;对于模型的学习训练可以参照图5。
智能采样设计是生成用于知识蒸馏的样本。无文本知识蒸馏使用这些样本来进行模型的蒸馏。
通信效率优化关注在联邦学习环境中的通信效率。模型迭代优化包括智能采样和无文本知识蒸馏的迭代过程。模型验证是在多种NLP理解任务上验证模型性能的步骤。模型部署和模型应用分别关注模型在实际应用场景中的部署和应用。
参照图6,示出了本发明的一种联邦学习模型训练装置实施例的结构框图,联邦学习模型包括学生模型和预训练的教师模型;所述联邦学习模型训练装置具体可以包括如下模块:
采样模块601,用于对联邦学习模型的嵌入层进行采样,生成伪样本数据;
分类模块602,用于基于所述教师模型对所述伪样本数据进行分类,确定所述伪样本数据对应的软标签;
训练模块603,用于采用所述伪样本数据和所述软标签对所述学生模型进行训练,生成目标模型参数;
更新模块604,用于依据所述目标模型参数更新所述联邦学习模型。在本发明的一可选实施例中,所述装置还包括:
获取模块,用于获取所述教师模型与所述学生模型之间的通信指标参数和性能指标参数;
通信效率函数值确定模块,用于依据所述性能指标参数和所述通信指标参数,确定通信效率函数值;
通信控制模块,用于依据所述通信效率函数值控制所述教师模型与所述学生模型之间的通信。
在本发明的一可选实施例中,所述装置还包括:
响应模块,用于响应于所述联邦学习模型的更新,生成迭代计数值;
循环模块,用于基于更新后的联邦学习模型,循环执行对联邦学习模型的嵌入层进行采样,生成伪样本数据的步骤和所述基于所述教师模型对所述伪样本数据进行分类,确定所述伪样本数据对应的软标签的步骤,直至所述迭代计数值满足预设迭代条件。
在本发明的一可选实施例中,所述采样模块601包括:
第一采样子模块,用于对联邦学习模型的嵌入层进行随机采样,生成初始样本集;
第二采样子模块,用于对所述初始样本集进行目标采样,生成伪样本数据。在本发明的一可选实施例中,所述采样模块601还包括:
对抗效应强度确定子模块,用于依据所述教师模型与所述学生模型,确定对抗效应强度;
对抗损失函数建立子模块,用于依据所述对抗效应强度和所述伪样本数据建立对抗损失函数;
第一优化子模块,用于对所述对抗损失函数进行梯度下降优化,得到目标伪样本;
伪样本更新子模块,用于采用所述目标伪样本更新所述伪样本数据。
在本发明的一可选实施例中,所述第二采样子模块包括:
目标损失函数建立单元,用于基于所述联邦学习模型的嵌入层的交叉熵和所述初始样本集建立目标损失函数;
第二优化单元,用于对所述目标损失函数进行梯度下降优化,生成伪样本数据。
在本发明的一可选实施例中,所述分类模块602包括:
初始分类子模块,用于依据所述教师模型的模型参数和前向传播函数计算所述伪样本数据,生成初始分类值;
归一化子模块,用于对所述初始分类值进行归一化,生成所述伪样本数据对应的软标签。
在本发明的一可选实施例中,所述训练模块603包括:
训练子模块,用于采用所述伪样本数据训练所述学生模型,生成输出数据;交叉熵损失函数建立子模块,用于依据所述输出数据和所述软标签建立交叉熵损失函数;
循环子模块,用于循环计算所述交叉熵损失函数直至满足预设次数,生成目标模型参数。
在本发明的一可选实施例中,所述装置还包括:
压缩策略确定模块,用于依据所述通信指标参数,确定压缩策略;
压缩策略执行模块,用于在所述教师模型与所述学生模型通信时,执行所述压缩策略。
在本发明的一可选实施例中,所述通信控制模块包括:
通讯轮次控制子模块,用于依据所述通信效率函数值确定通讯轮次。
在本发明的一可选实施例中,所述通信控制模块包括:
异步通讯策略控制子模块,用于依据所述通信效率函数值确定异步通讯策略。
参照图7,本发明实施例还提供了一种电子设备,包括:
处理器701和存储介质702,所述存储介质702存储有所述处理器701可执行的计算机程序,当电子设备运行时,所述处理器701执行所述计算机程序,以执行如本发明实施例任一项所述的联邦学习模型训练方法。
联邦学习模型包括学生模型和预训练的教师模型;所述联邦学习模型训练方法包括:
对联邦学习模型的嵌入层进行采样,生成伪样本数据;
基于所述教师模型对所述伪样本数据进行分类,确定所述伪样本数据对应的软标签;
采用所述伪样本数据和所述软标签对所述学生模型进行训练,生成目标模型参数;
依据所述目标模型参数更新所述联邦学习模型。
可选的,所述方法还包括:
获取所述教师模型与所述学生模型之间的通信指标参数和性能指标参数;依据所述性能指标参数和所述通信指标参数,确定通信效率函数值;
依据所述通信效率函数值控制所述教师模型与所述学生模型之间的通信。
可选的,所述方法还包括:
响应于所述联邦学习模型的更新,生成迭代计数值;
基于更新后的联邦学习模型,循环执行对联邦学习模型的嵌入层进行采样,生成伪样本数据的步骤和所述基于所述教师模型对所述伪样本数据进行分类,确定所述伪样本数据对应的软标签的步骤,直至所述迭代计数值满足预设迭代条件。
可选的,所述对联邦学习模型的嵌入层进行采样,生成伪样本数据的步骤包括:
对联邦学习模型的嵌入层进行随机采样,生成初始样本集;
对所述初始样本集进行目标采样,生成伪样本数据。
可选的,所述对联邦学习模型的嵌入层进行采样,生成伪样本数据的步骤还包括:
依据所述教师模型与所述学生模型,确定对抗效应强度;
依据所述对抗效应强度和所述伪样本数据建立对抗损失函数;
对所述对抗损失函数进行梯度下降优化,得到目标伪样本;
采用所述目标伪样本更新所述伪样本数据。
可选的,所述对所述初始样本集进行目标采样,生成伪样本数据的步骤包括:
基于所述联邦学习模型的嵌入层的交叉熵和所述初始样本集建立目标损失函数;
对所述目标损失函数进行梯度下降优化,生成伪样本数据。
可选的,所述基于所述教师模型对所述伪样本数据进行分类,确定所述伪样本数据对应的软标签的步骤包括:
依据所述教师模型的模型参数和前向传播函数计算所述伪样本数据,生成初始分类值;
对所述初始分类值进行归一化,生成所述伪样本数据对应的软标签。
可选的,所述采用所述伪样本数据和所述软标签对所述学生模型进行训练,生成目标模型参数的步骤包括:
采用所述伪样本数据训练所述学生模型,生成输出数据;
依据所述输出数据和所述软标签建立交叉熵损失函数;
循环计算所述交叉熵损失函数直至满足预设次数,生成目标模型参数。
可选的,所述方法还包括:
依据所述通信指标参数,确定压缩策略;
在所述教师模型与所述学生模型通信时,执行所述压缩策略。
可选的,所述依据所述通信效率函数值控制所述教师模型与所述学生模型之间的通信的步骤包括:
依据所述通信效率函数值确定通讯轮次。
可选的,所述依据所述通信效率函数值控制所述教师模型与所述学生模型之间的通信的步骤包括:
依据所述通信效率函数值确定异步通讯策略。
上述存储器可以包括随机存取存储器(Random Access Memory,简称RAM),也可以包括非易失性存储器(non-volatile memory),例如至少一个磁盘存储器。可选的,存储器还可以是至少一个位于远离前述处理器的存储装置。
上述的处理器可以是通用处理器,包括中央处理器(Central Processing Unit,简称CPU)、网络处理器(Network Processor,简称NP)等;还可以是数字信号处理器(Digital Signal Processing,简称DSP)、专用集成电路(Application SpecificIntegrated Circuit,简称ASIC)、现场可编程门阵列(Field-Programmable Gate Array,简称FPGA)或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件。
参照图8,本发明实施例还提供了一种计算机可读存储介质801,所述存储介质801上存储有计算机程序,所述计算机程序被处理器运行时执行如本发明实施例任一项所述的联邦学习模型训练方法。
联邦学习模型包括学生模型和预训练的教师模型;所述联邦学习模型训练方法包括:
对联邦学习模型的嵌入层进行采样,生成伪样本数据;
基于所述教师模型对所述伪样本数据进行分类,确定所述伪样本数据对应的软标签;
采用所述伪样本数据和所述软标签对所述学生模型进行训练,生成目标模型参数;
依据所述目标模型参数更新所述联邦学习模型。
可选的,所述方法还包括:
获取所述教师模型与所述学生模型之间的通信指标参数和性能指标参数;依据所述性能指标参数和所述通信指标参数,确定通信效率函数值;
依据所述通信效率函数值控制所述教师模型与所述学生模型之间的通信。
可选的,所述方法还包括:
响应于所述联邦学习模型的更新,生成迭代计数值;
基于更新后的联邦学习模型,循环执行对联邦学习模型的嵌入层进行采样,生成伪样本数据的步骤和所述基于所述教师模型对所述伪样本数据进行分类,确定所述伪样本数据对应的软标签的步骤,直至所述迭代计数值满足预设迭代条件。
可选的,所述对联邦学习模型的嵌入层进行采样,生成伪样本数据的步骤包括:
对联邦学习模型的嵌入层进行随机采样,生成初始样本集;
对所述初始样本集进行目标采样,生成伪样本数据。
可选的,所述对联邦学习模型的嵌入层进行采样,生成伪样本数据的步骤还包括:
依据所述教师模型与所述学生模型,确定对抗效应强度;
依据所述对抗效应强度和所述伪样本数据建立对抗损失函数;
对所述对抗损失函数进行梯度下降优化,得到目标伪样本;
采用所述目标伪样本更新所述伪样本数据。
可选的,所述对所述初始样本集进行目标采样,生成伪样本数据的步骤包括:
基于所述联邦学习模型的嵌入层的交叉熵和所述初始样本集建立目标损失函数;
对所述目标损失函数进行梯度下降优化,生成伪样本数据。
可选的,所述基于所述教师模型对所述伪样本数据进行分类,确定所述伪样本数据对应的软标签的步骤包括:
依据所述教师模型的模型参数和前向传播函数计算所述伪样本数据,生成初始分类值;
对所述初始分类值进行归一化,生成所述伪样本数据对应的软标签。
可选的,所述采用所述伪样本数据和所述软标签对所述学生模型进行训练,生成目标模型参数的步骤包括:
采用所述伪样本数据训练所述学生模型,生成输出数据;
依据所述输出数据和所述软标签建立交叉熵损失函数;
循环计算所述交叉熵损失函数直至满足预设次数,生成目标模型参数。
可选的,所述方法还包括:
依据所述通信指标参数,确定压缩策略;
在所述教师模型与所述学生模型通信时,执行所述压缩策略。
可选的,所述依据所述通信效率函数值控制所述教师模型与所述学生模型之间的通信的步骤包括:
依据所述通信效率函数值确定通讯轮次。
可选的,所述依据所述通信效率函数值控制所述教师模型与所述学生模型之间的通信的步骤包括:
依据所述通信效率函数值确定异步通讯策略。
说明书中的各个实施例均采用递进的方式描述,每个实施例重点说明的都是与其他实施例的不同之处,各个实施例之间相同相似的部分互相参见即可。
本领域内的技术人员应明白,本发明实施例的实施例可提供为方法、装置、或计算机程序产品。因此,本发明实施例可采用完全硬件实施例、完全软件实施例、或结合软件和硬件方面的实施例的形式。而且,本发明实施例可采用在一个或多个其中包含有计算机可用程序代码的计算机可用存储介质(包括但不限于磁盘存储器、CD-ROM、光学存储器等)上实施的计算机程序产品的形式。
本发明实施例是参照根据本发明实施例的方法、终端设备(系统)、和计算机程序产品的流程图和/或方框图来描述的。应理解可由计算机程序指令实现流程图和/或方框图中的每一流程和/或方框、以及流程图和/或方框图中的流程和/或方框的结合。可提供这些计算机程序指令到通用计算机、专用计算机、嵌入式处理机或其他可编程数据处理终端设备的处理器以产生一个机器,使得通过计算机或其他可编程数据处理终端设备的处理器执行的指令产生用于实现在流程图一个流程或多个流程和/或方框图一个方框或多个方框中指定的功能的装置。
这些计算机程序指令也可存储在能引导计算机或其他可编程数据处理终端设备以特定方式工作的计算机可读存储器中,使得存储在该计算机可读存储器中的指令产生包括指令装置的制造品,该指令装置实现在流程图一个流程或多个流程和/或方框图一个方框或多个方框中指定的功能。
这些计算机程序指令也可装载到计算机或其他可编程数据处理终端设备上,使得在计算机或其他可编程终端设备上执行一系列操作步骤以产生计算机实现的处理,从而在计算机或其他可编程终端设备上执行的指令提供用于实现在流程图一个流程或多个流程和/或方框图一个方框或多个方框中指定的功能的步骤。
以上实施例仅是为充分说明本发明而所举的较佳的实施例,本发明的保护范围不限于此。本技术领域的技术人员在本发明基础上所作的等同替代或变换,均在本发明的保护范围之内。
Claims (14)
1.一种联邦学习模型训练方法,其特征在于,联邦学习模型包括学生模型和预训练的教师模型;所述方法包括:
对联邦学习模型的嵌入层进行采样,生成伪样本数据;
基于所述教师模型对所述伪样本数据进行分类,确定所述伪样本数据对应的软标签;
采用所述伪样本数据和所述软标签对所述学生模型进行训练,生成目标模型参数;
依据所述目标模型参数更新所述联邦学习模型。
2.根据权利要求1所述的方法,其特征在于,所述方法还包括:
获取所述教师模型与所述学生模型之间的通信指标参数和性能指标参数;
依据所述性能指标参数和所述通信指标参数,确定通信效率函数值;
依据所述通信效率函数值控制所述教师模型与所述学生模型之间的通信。
3.根据权利要求1所述的方法,其特征在于,所述方法还包括:
响应于所述联邦学习模型的更新,生成迭代计数值;
基于更新后的联邦学习模型,循环执行对联邦学习模型的嵌入层进行采样,生成伪样本数据的步骤和所述基于所述教师模型对所述伪样本数据进行分类,确定所述伪样本数据对应的软标签的步骤,直至所述迭代计数值满足预设迭代条件。
4.根据权利要求1至3任一项所述的方法,其特征在于,所述对联邦学习模型的嵌入层进行采样,生成伪样本数据的步骤包括:
对联邦学习模型的嵌入层进行随机采样,生成初始样本集;
对所述初始样本集进行目标采样,生成伪样本数据。
5.根据权利要求4所述的方法,其特征在于,所述对联邦学习模型的嵌入层进行采样,生成伪样本数据的步骤还包括:
依据所述教师模型与所述学生模型,确定对抗效应强度;
依据所述对抗效应强度和所述伪样本数据建立对抗损失函数;
对所述对抗损失函数进行梯度下降优化,得到目标伪样本;
采用所述目标伪样本更新所述伪样本数据。
6.根据权利要求4所述的方法,其特征在于,所述对所述初始样本集进行目标采样,生成伪样本数据的步骤包括:
基于所述联邦学习模型的嵌入层的交叉熵和所述初始样本集建立目标损失函数;
对所述目标损失函数进行梯度下降优化,生成伪样本数据。
7.根据权利要求1至3任一项所述的方法,其特征在于,所述基于所述教师模型对所述伪样本数据进行分类,确定所述伪样本数据对应的软标签的步骤包括:
依据所述教师模型的模型参数和前向传播函数计算所述伪样本数据,生成初始分类值;
对所述初始分类值进行归一化,生成所述伪样本数据对应的软标签。
8.根据权利要求1至3任一项所述的方法,其特征在于,所述采用所述伪样本数据和所述软标签对所述学生模型进行训练,生成目标模型参数的步骤包括:
采用所述伪样本数据训练所述学生模型,生成输出数据;
依据所述输出数据和所述软标签建立交叉熵损失函数;
循环计算所述交叉熵损失函数直至满足预设次数,生成目标模型参数。
9.根据权利要求2所述的方法,其特征在于,所述方法还包括:
依据所述通信指标参数,确定压缩策略;
在所述教师模型与所述学生模型通信时,执行所述压缩策略。
10.根据权利要求2或9所述的方法,其特征在于,所述依据所述通信效率函数值控制所述教师模型与所述学生模型之间的通信的步骤包括:
依据所述通信效率函数值确定通讯轮次。
11.根据权利要求10所述的方法,其特征在于,所述依据所述通信效率函数值控制所述教师模型与所述学生模型之间的通信的步骤包括:
依据所述通信效率函数值确定异步通讯策略。
12.一种联邦学习模型训练装置,其特征在于,联邦学习模型包括学生模型和预训练的教师模型;所述装置包括:
采样模块,用于对联邦学习模型的嵌入层进行采样,生成伪样本数据;
分类模块,用于基于所述教师模型对所述伪样本数据进行分类,确定所述伪样本数据对应的软标签;
训练模块,用于采用所述伪样本数据和所述软标签对所述学生模型进行训练,生成目标模型参数;
更新模块,用于依据所述目标模型参数更新所述联邦学习模型。
13.一种电子设备,其特征在于,包括处理器、存储器及存储在所述存储器上并能够在所述处理器上运行的计算机程序,所述计算机程序被所述处理器执行时实现如权利要求1至11任一项所述的联邦学习模型训练方法的步骤。
14.一种计算机可读存储介质,其特征在于,所述计算机可读存储介质上存储计算机程序,所述计算机程序被处理器执行时实现如权利要求1至11任一项所述的联邦学习模型训练方法的步骤。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202410267744.3A CN117973505A (zh) | 2024-03-08 | 2024-03-08 | 一种联邦学习模型训练方法、装置、电子设备和存储介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202410267744.3A CN117973505A (zh) | 2024-03-08 | 2024-03-08 | 一种联邦学习模型训练方法、装置、电子设备和存储介质 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN117973505A true CN117973505A (zh) | 2024-05-03 |
Family
ID=90854785
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202410267744.3A Pending CN117973505A (zh) | 2024-03-08 | 2024-03-08 | 一种联邦学习模型训练方法、装置、电子设备和存储介质 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN117973505A (zh) |
-
2024
- 2024-03-08 CN CN202410267744.3A patent/CN117973505A/zh active Pending
Similar Documents
Publication | Publication Date | Title |
---|---|---|
KR102422729B1 (ko) | 학습 데이터 증강 정책 | |
CN110880036B (zh) | 神经网络压缩方法、装置、计算机设备及存储介质 | |
US10984319B2 (en) | Neural architecture search | |
CN111461226A (zh) | 对抗样本生成方法、装置、终端及可读存储介质 | |
CN111602148A (zh) | 正则化神经网络架构搜索 | |
JP2019533257A (ja) | ニューラルアーキテクチャ検索 | |
CN110659678B (zh) | 一种用户行为分类方法、系统及存储介质 | |
US20220164666A1 (en) | Efficient mixed-precision search for quantizers in artificial neural networks | |
Li et al. | Energy-based models for continual learning | |
CN111357018A (zh) | 使用神经网络的图像分段 | |
CN115017178A (zh) | 数据到文本生成模型的训练方法和装置 | |
CN110717582A (zh) | 使用鉴别器神经网络从生成器神经网络采样 | |
CN113377964A (zh) | 知识图谱链接预测方法、装置、设备及存储介质 | |
CN115210717A (zh) | 硬件优化的神经架构搜索 | |
CN116976461A (zh) | 联邦学习方法、装置、设备及介质 | |
CN117973505A (zh) | 一种联邦学习模型训练方法、装置、电子设备和存储介质 | |
KR102393761B1 (ko) | 이미지 처리를 위한 인공 신경망 모델 학습 방법 및 시스템 | |
CN115577797A (zh) | 一种基于本地噪声感知的联邦学习优化方法及系统 | |
CN117033997A (zh) | 数据切分方法、装置、电子设备和介质 | |
CN111104951A (zh) | 一种主动学习方法、装置及终端设备 | |
CN114692888A (zh) | 系统参数处理方法、装置、设备及存储介质 | |
CN111402121A (zh) | 图像风格的转换方法、装置、计算机设备和存储介质 | |
CN111062477A (zh) | 一种数据处理方法、装置及存储介质 | |
CN117196070B (zh) | 一种面向异构数据的双重联邦蒸馏学习方法及装置 | |
CN117726857A (zh) | 一种无数据条件下的模型鲁棒性增强方法 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |