CN117236421A - 一种基于联邦知识蒸馏的大模型训练方法 - Google Patents
一种基于联邦知识蒸馏的大模型训练方法 Download PDFInfo
- Publication number
- CN117236421A CN117236421A CN202311512843.5A CN202311512843A CN117236421A CN 117236421 A CN117236421 A CN 117236421A CN 202311512843 A CN202311512843 A CN 202311512843A CN 117236421 A CN117236421 A CN 117236421A
- Authority
- CN
- China
- Prior art keywords
- model
- training
- student
- server
- small
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Granted
Links
- 238000012549 training Methods 0.000 title claims abstract description 94
- 238000000034 method Methods 0.000 title claims abstract description 49
- 238000013140 knowledge distillation Methods 0.000 title claims abstract description 39
- 230000007246 mechanism Effects 0.000 claims abstract description 24
- 230000006870 function Effects 0.000 claims description 27
- 230000008520 organization Effects 0.000 claims description 25
- 230000002776 aggregation Effects 0.000 claims description 9
- 238000004220 aggregation Methods 0.000 claims description 9
- 230000004931 aggregating effect Effects 0.000 claims description 5
- 238000004821 distillation Methods 0.000 claims description 4
- 238000004364 calculation method Methods 0.000 abstract description 6
- 230000008569 process Effects 0.000 description 10
- 238000010586 diagram Methods 0.000 description 2
- 238000010187 selection method Methods 0.000 description 2
- 238000012935 Averaging Methods 0.000 description 1
- 230000009286 beneficial effect Effects 0.000 description 1
- 230000003247 decreasing effect Effects 0.000 description 1
- 238000013135 deep learning Methods 0.000 description 1
- 230000000694 effects Effects 0.000 description 1
- 238000010801 machine learning Methods 0.000 description 1
- 238000003058 natural language processing Methods 0.000 description 1
- 238000005457 optimization Methods 0.000 description 1
- 238000011160 research Methods 0.000 description 1
- 238000006467 substitution reaction Methods 0.000 description 1
- 238000012546 transfer Methods 0.000 description 1
Classifications
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02D—CLIMATE CHANGE MITIGATION TECHNOLOGIES IN INFORMATION AND COMMUNICATION TECHNOLOGIES [ICT], I.E. INFORMATION AND COMMUNICATION TECHNOLOGIES AIMING AT THE REDUCTION OF THEIR OWN ENERGY USE
- Y02D10/00—Energy efficient computing, e.g. low power processors, power management or thermal management
Landscapes
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
Abstract
本发明实施例中提供了一种基于联邦知识蒸馏的大模型训练方法,属于计算技术领域,具体包括:步骤1,将预设机构中的数据集联合,构建联邦大模型系统;步骤2,在联邦大模型系统中的服务器中部署一个知识蒸馏,以用户的本地数据训练出来的模型参数作为输入,训练得到一个教师模型,在知识蒸馏的控制下使用教师模型的输出和本地数据的真实标签训练学生模型;步骤3,将训练好的学生模型通过服务器发送给小机构的客户端;步骤4,根据小机构的数据量和训练需要,结合学生模型确定训练方案进行训练,得到目标模型。通过本发明的方案,提高了训练效率和安全性。
Description
技术领域
本发明实施例涉及计算技术领域,尤其涉及一种基于联邦知识蒸馏的大模型训练方法。
背景技术
目前,大模型是一种比传统机器学习模型拥有更多参数和结构更复杂的模型,它的参数数量通常是数亿到数万亿,常用于解决复杂的自然语言处理、计算机视觉和语音识别等任务。通过使用大模型,深度学习算法可以更好地处理这些任务,提高模型的准确性和性能。虽然大模型有着很好的发展前景,但也同样面临着很多的挑战。由于大模型在训练时需要使用大量的数据,为了保证数据隐私不被泄露,可以通过联邦学习平台来进行大模型的训练过程。因为联邦学习具有保存数据在本地进行迭代训练而不需要上传至服务器的特点,这大大降低了数据泄露的风险。但是,由于大模型庞大的参数数量,在训练时需要使用大量的计算资源进行优化和调整,这对于一些小医院等小机构而言,它们可能没有足够的算力条件去训练大模型,这就导致了它们无法参与到大模型的训练过程,也可能由于数据量太少而无法训练出性能很好的模型。
研究表明,如今的大模型由于参数数量的庞大需要大量的计算资源,但对于很多机构而言,它们没有足够的计算资源去支撑它们去训练自己的大模型,但是他们又拥有自己的数据集,而且还有可能这些数据涉及隐私问题,无法将数据给出。
可见,亟需一种训练效率和安全性高的基于联邦知识蒸馏的大模型训练方法。
发明内容
有鉴于此,本发明实施例提供一种基于联邦知识蒸馏的大模型训练方法,至少部分解决现有技术中存在训练效率和安全性较差的问题。
本发明实施例提供了一种基于联邦知识蒸馏的大模型训练方法,包括:
步骤1,将预设机构中的数据集联合,构建联邦大模型系统;
步骤2,在联邦大模型系统中的服务器中部署一个知识蒸馏,以用户的本地数据训练出来的模型参数作为输入,训练得到一个教师模型,在知识蒸馏的控制下使用教师模型的输出和本地数据的真实标签训练学生模型;
步骤3,将训练好的学生模型通过服务器发送给小机构的客户端;
步骤4,根据小机构的数据量和训练需要,结合学生模型确定训练方案进行训练,得到目标模型。
根据本发明实施例的一种具体实现方式,所述步骤2具体包括:
步骤2.1,将知识蒸馏部署在联邦大模型系统的服务器端,以用户的本地数据训练出来的模型参数作为输入,训练得到教师模型;
步骤2.2,在知识蒸馏的控制下,将教师模型的预测输出作为软标签输入学生模型进行学习并计算第一损失函数,将真实标签作为硬标签输入学生模型进行学习并计算第二损失函数,然后将第一损失函数和第二损失函数加权求和,作为最终损失函数更新学生模型的参数。
根据本发明实施例的一种具体实现方式,所述第一损失函数的表达式为
;
其中,
其中,N表示模型训练样本数量,p表示教师模型的输出,q表示学生模型的输出,T表示温度,表示教师模型输出的对于第i个类别的概率预测值,/>是学生模型输出的对于第i个类别的概率预测值,/>表示教师网络中第i个样本在温度T时的输出,/>表示学生网络中第i个样本在温度T时的输出,/>表示第i个样本,/>表示第j个样本,/>表示第k个样本。
根据本发明实施例的一种具体实现方式,所述第二损失函数的表达式为
;
其中
其中,c表示真实标签,表示第j个样本的真实标签。
根据本发明实施例的一种具体实现方式,所述最终损失函数的表达式为
其中,和/>是平衡蒸馏损失和学生损失的参数,且/>。
根据本发明实施例的一种具体实现方式,所述步骤4具体包括:
步骤4.1,判断小机构的数据量是否满足模型训练条件且需要训练符合其对应要求的模型,若是,则执行步骤4.2,若否,则执行步骤4.3;
步骤4.2,将训练好的学生模型发送给待训练的小机构,根据待训练的小机构的本地数据训练学生模型并根据小机构的本地数据集的特点和要求进行微调,得到目标模型模型;
步骤4.3,将待训练的小机构与其他的小机构共同构成一个新的联邦学习平台,服务器利用学生模型作为初始共享模型下发给小机构的各个客户端,每个客户端利用本地私有数据在本地进行训练,训练好之后再将模型参数发送回给服务器,服务器进行聚合更新全局模型,然后再次发送回客户端,迭代直至全局模型收敛,得到的全局模型通过客户端发送至参与该联邦大模型系统中的所有小机构。
根据本发明实施例的一种具体实现方式,所述将待训练的小机构与其他的小机构共同构成一个新的联邦学习平台,服务器利用学生模型作为初始共享模型下发给小机构的各个客户端,每个客户端利用本地私有数据在本地进行训练,训练好之后再将模型参数发送回给服务器,服务器进行聚合更新全局模型的步骤,包括:
小机构一共有M个客户机,中心服务器初始化模型参数,执行预设数量的轮次,每轮选取至少1个至多M个客户机参与训练,接下来每个被选中的客户机同时在自己的本地根据服务器下发的本轮模型用自己的数据训练自己的模型/>,上传回服务器,服务器将收集来的各客户机的模型根据各方样本数量用加权平均的方式进行聚合,得到下一轮的模型/>:
其中,为客户机m上的样本数量,n为所有被选中客户机的总样本数量。
本发明实施例中的基于联邦知识蒸馏的大模型训练方案,包括:步骤1,将预设机构中的数据集联合,构建联邦大模型系统;步骤2,在联邦大模型系统中的服务器中部署一个知识蒸馏,以用户的本地数据训练出来的模型参数作为输入,训练得到一个教师模型,在知识蒸馏的控制下使用教师模型的输出和本地数据的真实标签训练学生模型;步骤3,将训练好的学生模型通过服务器发送给小机构的客户端;步骤4,根据小机构的数据量和训练需要,结合学生模型确定训练方案进行训练,得到目标模型。
本发明实施例的有益效果为:
1)帮助一些没有足够算力条件的小机构使用大模型蒸馏后的模型进行训练,使其同样可以用自己的数据集去训练这个小模型进行微调后得到适合自己机构所需的个性化模型或者和其他的小机构构成一个新的联邦大模型系统使用蒸馏后的模型作为联邦学习的初始共享模型训练出一个性能更好的全局模型;
2)使小机构即使在没有足够算力的情况下同样可以使用大模型训练出准确性更好、性能更高的模型;
3)这些机构不用将自己的数据给出去,减少了隐私泄露的风险。
附图说明
为了更清楚地说明本发明实施例的技术方案,下面将对实施例中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其它的附图。
图1为本发明实施例提供的一种基于联邦知识蒸馏的大模型训练方法的流程示意图;
图2为本发明实施例提供的一种基于联邦知识蒸馏的大模型训练方法的具体实施流程示意图。
具体实施方式
下面结合附图对本发明实施例进行详细描述。
以下通过特定的具体实例说明本发明的实施方式,本领域技术人员可由本说明书所揭露的内容轻易地了解本发明的其他优点与功效。显然,所描述的实施例仅仅是本发明一部分实施例,而不是全部的实施例。本发明还可以通过另外不同的具体实施方式加以实施或应用,本说明书中的各项细节也可以基于不同观点与应用,在没有背离本发明的精神下进行各种修饰或改变。需说明的是,在不冲突的情况下,以下实施例及实施例中的特征可以相互组合。基于本发明中的实施例,本领域普通技术人员在没有作出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
需要说明的是,下文描述在所附权利要求书的范围内的实施例的各种方面。应显而易见,本文中所描述的方面可体现于广泛多种形式中,且本文中所描述的任何特定结构及/或功能仅为说明性的。基于本发明,所属领域的技术人员应了解,本文中所描述的一个方面可与任何其它方面独立地实施,且可以各种方式组合这些方面中的两者或两者以上。举例来说,可使用本文中所阐述的任何数目个方面来实施设备及/或实践方法。另外,可使用除了本文中所阐述的方面中的一或多者之外的其它结构及/或功能性实施此设备及/或实践此方法。
还需要说明的是,以下实施例中所提供的图示仅以示意方式说明本发明的基本构想,图式中仅显示与本发明中有关的组件而非按照实际实施时的组件数目、形状及尺寸绘制,其实际实施时各组件的型态、数量及比例可为一种随意的改变,且其组件布局型态也可能更为复杂。
另外,在以下描述中,提供具体细节是为了便于透彻理解实例。然而,所属领域的技术人员将理解,可在没有这些特定细节的情况下实践所述方面。
本发明实施例提供一种基于联邦知识蒸馏的大模型训练方法,所述方法可以应用于社区、医院等场景的本地客户端模型训练过程。
参见图1,为本发明实施例提供的一种基于联邦知识蒸馏的大模型训练方法的流程示意图。如图1和图2所示,所述方法主要包括以下步骤:
步骤1,将预设机构中的数据集联合,构建联邦大模型系统;
具体实施时,可以将多个拥有足够算力的机构中的数据集联合起来构成一个联邦大模型系统。
步骤2,在联邦大模型系统中的服务器中部署一个知识蒸馏,以用户的本地数据训练出来的模型参数作为输入,训练得到一个教师模型,在知识蒸馏的控制下使用教师模型的输出和本地数据的真实标签训练学生模型;
在上述实施例的基础上,所述步骤2具体包括:
步骤2.1,将知识蒸馏部署在联邦大模型系统的服务器端,以用户的本地数据训练出来的模型参数作为输入,训练得到教师模型;
步骤2.2,在知识蒸馏的控制下,将教师模型的预测输出作为软标签输入学生模型进行学习并计算第一损失函数,将真实标签作为硬标签输入学生模型进行学习并计算第二损失函数,然后将第一损失函数和第二损失函数加权求和,作为最终损失函数更新学生模型的参数。
进一步的,所述第一损失函数的表达式为
;
其中,
;
其中,N表示模型训练样本数量,p表示教师模型的输出,q表示学生模型的输出,T表示温度,表示教师模型输出的对于第i个类别的概率预测值,/>是学生模型输出的对于第i个类别的概率预测值,/>表示教师网络中第i个样本在温度T时的输出,/>表示学生网络中第i个样本在温度T时的输出,/>表示第i个样本,/>表示第j个样本,/>表示第k个样本。
进一步的,所述第二损失函数的表达式为
其中
其中,c表示真实标签,表示第j个样本的真实标签。
进一步的,所述最终损失函数的表达式为
其中,和/>是平衡蒸馏损失和学生损失的参数,且/>。
具体实施时,(1)大模型联邦训练的具体过程:
首先,大模型采用分布式训练的方式部署在联邦学习框架中,本地客户端是拥有足够算力的大机构,大机构中的客户端都拥有庞大的数据量。接下来就是联邦大模型的具体训练过程,主要分为下面3个步骤:
①任务初始化:在训练开始之前,服务器首先要确定训练的任务和目标,并选择参与联邦学习的设备,然后把挑选好的共享大模型发送给已选择的设备。
②本地训练与共享:每个设备利用私有数据训练本地模型。训练的目标就是找到最佳的本地模型。设备训练完之后把模型参数上传到服务器,进行下一步操作。
③全局聚合与更新:服务器收集到来自所有参与设备的本地模型后,进行模型参数聚合。典型的聚合操作是平均算法 FedAvg,联邦学习服务器通过平均本地模型参数得到下一轮的共享全局模型,目标是找到最佳的全局模型。
上述步骤将会依次迭代进行,当全局模型收敛或者达到一定的准确率时结束训练。
但是对于很多的小机构而言,它们可能没有足够的条件去满足这些大模型所需要的算力或是对于它们来说这种规模的算力成本太高。在这种情况下我们就可以通过本发明的基于知识蒸馏的联邦大模型训练方法来蒸馏出一个模型规模更小的学生模型来交给小机构训练,在这同时还通过联邦学习平台保护了数据的隐私性。
(2)联邦大模型中知识蒸馏的具体过程:
知识蒸馏部署在联邦大模型系统的服务器端,以用户的本地数据训练出来的模型参数作为输入,以训练得到一个教师模型,在知识蒸馏的控制下以相同的输入训练得出最终的学生模型。也就是说,将同一批数据放入两个模型中,将教师模型的预测输出作为软标签,将真实标签作为硬标签,分别计算学生模型的两种损失,最后将两个损失加权求和,作为最终损失更新网络参数。预测的时候,仅使用学生模型。
知识蒸馏,可以将一个网络的知识转移到另一个网络,两个网络可以是同构或者异构。做法是先训练一个教师网络,然后使用这个教师网络的输出和数据的真实标签去训练学生网络。知识蒸馏,可以用来将网络从大网络转化成一个小网络,并保留接近于大网络的性能;也可以将多个网络的学到的知识转移到一个网络中。
知识蒸馏的具体过程:
设模型训练样本数为N,类别数为C。
是教师模型教学生学习的损失函数:
;
其中
其中p表示教师模型的输出,q表示学生模型的输出。将教师模型的输出结果p作为学生模型的目标,使学生模型的输出结果q尽可能接近p,具体就是计算教师和学生的交叉熵。其中T是通常设置为1的温度,使用较高的T值可以在类上产生较软的概率分布。是教师模型输出的logit值,/>是学生模型输出的logit值,logit是模型输出的对于各个类别的概率预测值。
是学生自己跟着真实标签学习的损失函数:
其中
其实和常规模型是一样的,就是根据训练集的/>来学习。上面公式中c就是真实/>,也就是计算学生模型的输出结果q和标签c的交叉熵。
总损失为:
其中,和/>是平衡蒸馏损失和学生损失的参数,且/>。
步骤3,将训练好的学生模型通过服务器发送给小机构的客户端;
具体实施时,在通过知识蒸馏得到训练好的学生模型后,可以将训练好的学生模型通过服务器发送给需要参与联邦学习的小机构的客户端,以便于进行后续操作流程。
步骤4,根据小机构的数据量和训练需要,结合学生模型确定训练方案进行训练,得到目标模型。
在上述实施例的基础上,所述步骤4具体包括:
步骤4.1,判断小机构的数据量是否满足模型训练条件且需要训练符合其对应要求的模型,若是,则执行步骤4.2,若否,则执行步骤4.3;
步骤4.2,将训练好的学生模型发送给待训练的小机构,根据待训练的小机构的本地数据训练学生模型并根据小机构的本地数据集的特点和要求进行微调,得到目标模型模型;
步骤4.3,将待训练的小机构与其他的小机构共同构成一个新的联邦学习平台,服务器利用学生模型作为初始共享模型下发给小机构的各个客户端,每个客户端利用本地私有数据在本地进行训练,训练好之后再将模型参数发送回给服务器,服务器进行聚合更新全局模型,然后再次发送回客户端,迭代直至全局模型收敛,得到的全局模型通过客户端发送至参与该联邦大模型系统中的所有小机构。
进一步的,所述将待训练的小机构与其他的小机构共同构成一个新的联邦学习平台,服务器利用学生模型作为初始共享模型下发给小机构的各个客户端,每个客户端利用本地私有数据在本地进行训练,训练好之后再将模型参数发送回给服务器,服务器进行聚合更新全局模型的步骤,包括:
小机构一共有M个客户机,中心服务器初始化模型参数,执行预设数量的轮次,每轮选取至少1个至多M个客户机参与训练,接下来每个被选中的客户机同时在自己的本地根据服务器下发的本轮模型用自己的数据训练自己的模型/>,上传回服务器,服务器将收集来的各客户机的模型根据各方样本数量用加权平均的方式进行聚合,得到下一轮的模型/>:
其中,为客户机m上的样本数量,n为所有被选中客户机的总样本数量。
具体实施时,对于蒸馏出来的学生模型和小机构数据,可以通过两种方法来帮助小机构训练模型。方法1:将知识蒸馏后的学生模型发送给各个小机构,各个小机构根据自己本地的数据训练这个模型并根据自己数据集的特点和要求进行微调,得到属于自己机构的个性化模型;方法2:与其他的小机构共同构成一个新的联邦学习平台,服务器用这个学生模型作为初始共享模型下发给小机构的各个客户端,客户端用自己本地私有数据在本地进行训练,训练好之后再将模型参数发送回给服务器,服务器进行聚合更新全局模型,然后再次发送回客户端,如此迭代直至全局模型收敛。当小机构的数据不够多,但又想使用别人的隐私数据帮助训练出更好的模型的时候可以使用方法2;当小机构自己拥有足够的数据,不需要借助其他数据的帮助,又想要训练出符合自己机构要求的个性化模型时可以使用方法1。针对不同小机构的不同条件和要求可以自主选择训练方法,甚至在这个场景下,会出现部分机构选择方法2参与联邦学习,部分机构选择方法1训练个性化模型的情况。
具体的,方法1:小机构训练自己个性化模型的过程:
小机构通过自己所拥有的数据训练服务器发送的学生模型,从头训练输出层,而其余层的参数都是基于学生模型的参数微调得到的。通过模型微调的方式,训练出速度更快、模型精度更高的适用于自己场景和任务的个性化模型。
方法2:小机构联邦训练过程:
小机构一共有M个客户机,中心服务器初始化模型参数,执行若干轮,每轮选取至少1个至多M个客户机参与训练,接下来每个被选中的客户机同时在自己的本地根据服务器下发的本轮(t轮)模型用自己的数据训练自己的模型/>,上传回服务器。服务器将收集来的各客户机的模型根据各方样本数量用加权平均的方式进行聚合,得到下一轮的模型/>:
;
其中,为客户机m上的样本数量,n为所有被选中客户机的总样本数量。
为了增加客户机计算量,可以在中心服务器做聚合(加权平均)操作前在每个客户机上多迭代更新几次。计算量由三个参数决定:
·C,每一轮参与计算的客户机比例。
·E,每一轮每个客户机投入其全部本地数据训练一遍的次数。
·B,用于客户机更新的batch大小。B = ∞表示batch为全部样本,此时就是full-batch梯度下降了。
当E = 1, B = ∞时,对应的就是FedSGD,即每一轮客户机一次性将所有本地数据投入训练,更新模型参数。
对于一个有着个本地样本的客户机m来说,每轮的本地更新次数为。
对于小机构联邦学习的过程,这里使用的是联邦平均算法,当然除此之外还有很多联邦学习的算法,实际应用中可以选择其他合适的算法进行联邦学习训练出更好的模型。
本实施例提供的基于联邦知识蒸馏的大模型训练方法,通过使用大模型蒸馏后的模型在小机构进行训练,使其同样可以用自己的数据集去训练这个小模型进行微调后得到适合自己机构所需的个性化模型,或者和其他的小机构构成一个新的联邦大模型系统使用蒸馏后的模型作为联邦学习的初始共享模型训练出一个性能更好的全局模型,使小机构即使在没有足够算力的情况下同样可以使用大模型训练出准确性更好、性能更高的模型,提高了训练效率和安全性。
应当理解,本发明的各部分可以用硬件、软件、固件或它们的组合来实现。
以上所述,仅为本发明的具体实施方式,但本发明的保护范围并不局限于此,任何熟悉本技术领域的技术人员在本发明揭露的技术范围内,可轻易想到的变化或替换,都应涵盖在本发明的保护范围之内。因此,本发明的保护范围应以权利要求的保护范围为准。
Claims (6)
1.一种基于联邦知识蒸馏的大模型训练方法,其特征在于,包括:
步骤1,将预设机构中的数据集联合,构建联邦大模型系统;
步骤2,在联邦大模型系统中的服务器中部署一个知识蒸馏,以用户的本地数据训练出来的模型参数作为输入,训练得到一个教师模型,在知识蒸馏的控制下使用教师模型的输出和本地数据的真实标签训练学生模型;
步骤3,将训练好的学生模型通过服务器发送给小机构的客户端;
步骤4,根据小机构的数据量和训练需要,结合学生模型确定训练方案进行训练,得到目标模型;
所述步骤4具体包括:
步骤4.1,判断小机构的数据量是否满足模型训练条件且需要训练符合其对应要求的模型,若是,则执行步骤4.2,若否,则执行步骤4.3;
步骤4.2,将训练好的学生模型发送给待训练的小机构,根据待训练的小机构的本地数据训练学生模型并根据小机构的本地数据集的特点和要求进行微调,得到目标模型模型;
步骤4.3,将待训练的小机构与其他的小机构共同构成一个新的联邦学习平台,服务器利用学生模型作为初始共享模型下发给小机构的各个客户端,每个客户端利用本地私有数据在本地进行训练,训练好之后再将模型参数发送回给服务器,服务器进行聚合更新全局模型,然后再次发送回客户端,迭代直至全局模型收敛,得到的全局模型通过客户端发送至参与该联邦大模型系统中的所有小机构。
2.根据权利要求1所述的方法,其特征在于,所述步骤2具体包括:
步骤2.1,将知识蒸馏部署在联邦大模型系统的服务器端,以用户的本地数据训练出来的模型参数作为输入,训练得到教师模型;
步骤2.2,在知识蒸馏的控制下,将教师模型的预测输出作为软标签输入学生模型进行学习并计算第一损失函数,将真实标签作为硬标签输入学生模型进行学习并计算第二损失函数,然后将第一损失函数和第二损失函数加权求和,作为最终损失函数更新学生模型的参数。
3.根据权利要求2所述的方法,其特征在于,所述第一损失函数的表达式为
;
其中,
其中,N表示模型训练样本数量,p表示教师模型的输出,q表示学生模型的输出,T表示温度,表示教师模型输出的对于第i个类别的概率预测值,/>是学生模型输出的对于第i个类别的概率预测值,/>表示教师网络中第i个样本在温度T时的输出,/>表示学生网络中第i个样本在温度T时的输出,/>表示第i个样本,/>表示第j个样本,/>表示第k个样本。
4.根据权利要求3所述的方法,其特征在于,所述第二损失函数的表达式为
其中
其中,c表示真实标签,表示第j个样本的真实标签。
5.根据权利要求4所述的方法,其特征在于,所述最终损失函数的表达式为
其中,和/>是平衡蒸馏损失和学生损失的参数,且/>。
6.根据权利要求5所述的方法,其特征在于,所述将待训练的小机构与其他的小机构共同构成一个新的联邦学习平台,服务器利用学生模型作为初始共享模型下发给小机构的各个客户端,每个客户端利用本地私有数据在本地进行训练,训练好之后再将模型参数发送回给服务器,服务器进行聚合更新全局模型的步骤,包括:
小机构一共有M个客户机,中心服务器初始化模型参数,执行预设数量的轮次,每轮选取至少1个至多M个客户机参与训练,接下来每个被选中的客户机同时在自己的本地根据服务器下发的本轮模型用自己的数据训练自己的模型/>,上传回服务器,服务器将收集来的各客户机的模型根据各方样本数量用加权平均的方式进行聚合,得到下一轮的模型:
其中,为客户机m上的样本数量,n为所有被选中客户机的总样本数量。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202311512843.5A CN117236421B (zh) | 2023-11-14 | 2023-11-14 | 一种基于联邦知识蒸馏的大模型训练方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202311512843.5A CN117236421B (zh) | 2023-11-14 | 2023-11-14 | 一种基于联邦知识蒸馏的大模型训练方法 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN117236421A true CN117236421A (zh) | 2023-12-15 |
CN117236421B CN117236421B (zh) | 2024-03-12 |
Family
ID=89086460
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202311512843.5A Active CN117236421B (zh) | 2023-11-14 | 2023-11-14 | 一种基于联邦知识蒸馏的大模型训练方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN117236421B (zh) |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117521856A (zh) * | 2023-12-29 | 2024-02-06 | 南京邮电大学 | 一种基于本地特征的大模型切割联邦学习方法及系统 |
Citations (7)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN113689000A (zh) * | 2021-08-25 | 2021-11-23 | 深圳前海微众银行股份有限公司 | 联邦学习模型的训练方法、装置、电子设备及存储介质 |
CN114429219A (zh) * | 2021-12-09 | 2022-05-03 | 之江实验室 | 一种面向长尾异构数据的联邦学习方法 |
CN114863092A (zh) * | 2022-04-29 | 2022-08-05 | 广州广电运通金融电子股份有限公司 | 一种基于知识蒸馏的联邦目标检测方法及系统 |
CN115630361A (zh) * | 2022-09-19 | 2023-01-20 | 扬州大学 | 一种基于注意力蒸馏的联邦学习后门防御方法 |
CN115907001A (zh) * | 2022-11-11 | 2023-04-04 | 中南大学 | 基于知识蒸馏的联邦图学习方法及自动驾驶方法 |
CN116681144A (zh) * | 2023-06-09 | 2023-09-01 | 安徽师范大学 | 基于动态自适应知识蒸馏的联邦学习模型聚合方法 |
CN116957064A (zh) * | 2023-05-09 | 2023-10-27 | 南京邮电大学 | 基于知识蒸馏的联邦学习隐私保护模型训练方法及系统 |
-
2023
- 2023-11-14 CN CN202311512843.5A patent/CN117236421B/zh active Active
Patent Citations (7)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN113689000A (zh) * | 2021-08-25 | 2021-11-23 | 深圳前海微众银行股份有限公司 | 联邦学习模型的训练方法、装置、电子设备及存储介质 |
CN114429219A (zh) * | 2021-12-09 | 2022-05-03 | 之江实验室 | 一种面向长尾异构数据的联邦学习方法 |
CN114863092A (zh) * | 2022-04-29 | 2022-08-05 | 广州广电运通金融电子股份有限公司 | 一种基于知识蒸馏的联邦目标检测方法及系统 |
CN115630361A (zh) * | 2022-09-19 | 2023-01-20 | 扬州大学 | 一种基于注意力蒸馏的联邦学习后门防御方法 |
CN115907001A (zh) * | 2022-11-11 | 2023-04-04 | 中南大学 | 基于知识蒸馏的联邦图学习方法及自动驾驶方法 |
CN116957064A (zh) * | 2023-05-09 | 2023-10-27 | 南京邮电大学 | 基于知识蒸馏的联邦学习隐私保护模型训练方法及系统 |
CN116681144A (zh) * | 2023-06-09 | 2023-09-01 | 安徽师范大学 | 基于动态自适应知识蒸馏的联邦学习模型聚合方法 |
Non-Patent Citations (1)
Title |
---|
徐梦炜;刘渊强;黄康;刘?哲;黄罡;: "面向移动终端智能的自治学习系统", 软件学报, no. 10, 14 October 2020 (2020-10-14), pages 28 - 42 * |
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117521856A (zh) * | 2023-12-29 | 2024-02-06 | 南京邮电大学 | 一种基于本地特征的大模型切割联邦学习方法及系统 |
CN117521856B (zh) * | 2023-12-29 | 2024-03-15 | 南京邮电大学 | 一种基于本地特征的大模型切割联邦学习方法及系统 |
Also Published As
Publication number | Publication date |
---|---|
CN117236421B (zh) | 2024-03-12 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN109902222A (zh) | 一种推荐方法及装置 | |
CN110969250A (zh) | 一种神经网络训练方法及装置 | |
CN111259738B (zh) | 人脸识别模型构建方法、人脸识别方法及相关装置 | |
CN117236421B (zh) | 一种基于联邦知识蒸馏的大模型训练方法 | |
WO2020151310A1 (zh) | 文本生成方法、装置、计算机设备及介质 | |
CN109918663A (zh) | 一种语义匹配方法、装置及存储介质 | |
CN114912705A (zh) | 一种联邦学习中异质模型融合的优化方法 | |
CN109983480A (zh) | 使用聚类损失训练神经网络 | |
EP3688673A1 (en) | Neural architecture search | |
EP4350572A1 (en) | Method, apparatus and system for generating neural network model, devices, medium and program product | |
CN114091667A (zh) | 一种面向非独立同分布数据的联邦互学习模型训练方法 | |
CN113190688A (zh) | 基于逻辑推理和图卷积的复杂网络链接预测方法及系统 | |
CN115344883A (zh) | 一种用于处理不平衡数据的个性化联邦学习方法和装置 | |
CN115587633A (zh) | 一种基于参数分层的个性化联邦学习方法 | |
CN109510610A (zh) | 一种基于软投影加权核递归最小二乘的核自适应滤波方法 | |
CN115511109A (zh) | 一种高泛化性的个性化联邦学习实现方法 | |
CN107862329A (zh) | 一种基于深度置信网络的雷达一维距离像真假目标识别方法 | |
CN114758180B (zh) | 一种基于知识蒸馏的轻量化花卉识别方法 | |
US11941867B2 (en) | Neural network training using the soft nearest neighbor loss | |
WO2020220692A1 (zh) | 深度神经网络及其训练 | |
CN113947214A (zh) | 一种基于客户端知识蒸馏的联邦学习实现方法 | |
CN116975686A (zh) | 训练学生模型的方法、行为预测方法和装置 | |
CN116976461A (zh) | 联邦学习方法、装置、设备及介质 | |
CN116645130A (zh) | 基于联邦学习与gru结合的汽车订单需求量预测方法 | |
CN116911459A (zh) | 适应于虚拟电厂的多输入多输出超短期电力负荷预测方法 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |