CN113191479A - 联合学习的方法、系统、节点及存储介质 - Google Patents
联合学习的方法、系统、节点及存储介质 Download PDFInfo
- Publication number
- CN113191479A CN113191479A CN202010038486.3A CN202010038486A CN113191479A CN 113191479 A CN113191479 A CN 113191479A CN 202010038486 A CN202010038486 A CN 202010038486A CN 113191479 A CN113191479 A CN 113191479A
- Authority
- CN
- China
- Prior art keywords
- model
- network
- node
- models
- central
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000000034 method Methods 0.000 title claims abstract description 146
- 238000003860 storage Methods 0.000 title claims description 11
- 238000012549 training Methods 0.000 claims abstract description 198
- 230000004927 fusion Effects 0.000 claims abstract description 182
- 238000013528 artificial neural network Methods 0.000 claims abstract description 70
- 230000006870 function Effects 0.000 claims description 258
- 239000013598 vector Substances 0.000 claims description 129
- 238000012545 processing Methods 0.000 claims description 36
- 230000009467 reduction Effects 0.000 claims description 22
- 238000012360 testing method Methods 0.000 claims description 19
- 230000000694 effects Effects 0.000 claims description 13
- 238000013473 artificial intelligence Methods 0.000 abstract description 21
- 230000015654 memory Effects 0.000 description 45
- 230000008569 process Effects 0.000 description 44
- 238000004422 calculation algorithm Methods 0.000 description 38
- 238000004891 communication Methods 0.000 description 30
- 238000013140 knowledge distillation Methods 0.000 description 18
- 238000004364 calculation method Methods 0.000 description 15
- 238000010586 diagram Methods 0.000 description 14
- 239000011159 matrix material Substances 0.000 description 13
- 230000006835 compression Effects 0.000 description 12
- 238000007906 compression Methods 0.000 description 12
- 238000011176 pooling Methods 0.000 description 8
- 238000013527 convolutional neural network Methods 0.000 description 7
- 206010028980 Neoplasm Diseases 0.000 description 6
- 238000004590 computer program Methods 0.000 description 6
- 238000012544 monitoring process Methods 0.000 description 6
- 238000003062 neural network model Methods 0.000 description 6
- 238000013135 deep learning Methods 0.000 description 5
- 238000009826 distribution Methods 0.000 description 5
- 238000005516 engineering process Methods 0.000 description 4
- 238000007499 fusion processing Methods 0.000 description 4
- 238000010606 normalization Methods 0.000 description 4
- 230000003287 optical effect Effects 0.000 description 4
- 238000007500 overflow downdraw method Methods 0.000 description 4
- 230000005540 biological transmission Effects 0.000 description 3
- 238000001514 detection method Methods 0.000 description 3
- 238000011161 development Methods 0.000 description 3
- 230000018109 developmental process Effects 0.000 description 3
- 238000012546 transfer Methods 0.000 description 3
- 238000012935 Averaging Methods 0.000 description 2
- 230000004913 activation Effects 0.000 description 2
- 239000000872 buffer Substances 0.000 description 2
- 238000001816 cooling Methods 0.000 description 2
- 230000003993 interaction Effects 0.000 description 2
- 238000010801 machine learning Methods 0.000 description 2
- 238000002156 mixing Methods 0.000 description 2
- 230000008447 perception Effects 0.000 description 2
- 230000002093 peripheral effect Effects 0.000 description 2
- 238000000513 principal component analysis Methods 0.000 description 2
- 238000011160 research Methods 0.000 description 2
- 230000004044 response Effects 0.000 description 2
- 238000012216 screening Methods 0.000 description 2
- 230000011218 segmentation Effects 0.000 description 2
- 238000010187 selection method Methods 0.000 description 2
- 230000001133 acceleration Effects 0.000 description 1
- 238000004458 analytical method Methods 0.000 description 1
- 238000013459 approach Methods 0.000 description 1
- 230000015572 biosynthetic process Effects 0.000 description 1
- 230000001413 cellular effect Effects 0.000 description 1
- 239000003795 chemical substances by application Substances 0.000 description 1
- 238000013145 classification model Methods 0.000 description 1
- 238000012790 confirmation Methods 0.000 description 1
- 238000011217 control strategy Methods 0.000 description 1
- 238000013500 data storage Methods 0.000 description 1
- 230000006837 decompression Effects 0.000 description 1
- 238000013461 design Methods 0.000 description 1
- 238000006073 displacement reaction Methods 0.000 description 1
- 238000004821 distillation Methods 0.000 description 1
- 238000005538 encapsulation Methods 0.000 description 1
- 238000000605 extraction Methods 0.000 description 1
- 239000000835 fiber Substances 0.000 description 1
- 230000008570 general process Effects 0.000 description 1
- 230000010354 integration Effects 0.000 description 1
- 238000012886 linear function Methods 0.000 description 1
- 239000007788 liquid Substances 0.000 description 1
- 238000004519 manufacturing process Methods 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 238000003058 natural language processing Methods 0.000 description 1
- 238000005457 optimization Methods 0.000 description 1
- 238000007781 pre-processing Methods 0.000 description 1
- 230000000644 propagated effect Effects 0.000 description 1
- 230000000306 recurrent effect Effects 0.000 description 1
- 238000011946 reduction process Methods 0.000 description 1
- 238000005070 sampling Methods 0.000 description 1
- 239000004065 semiconductor Substances 0.000 description 1
- 238000012163 sequencing technique Methods 0.000 description 1
- 238000000638 solvent extraction Methods 0.000 description 1
- 238000006467 substitution reaction Methods 0.000 description 1
- 238000013519 translation Methods 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N20/00—Machine learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/084—Backpropagation, e.g. using gradient descent
Abstract
本申请提供了人工智能领域中的一种联合学习的方法,应用于联合学习系统,所述联合学习系统包括一个或者多个计算节点以及一个或多个中心节点,所述方法包括:计算节点接收由中心节点发送的第一初级模型,其中,所述第一初级模型是所述中心节点根据中心节点的中心数据库对神经网络进行训练后得到的;所述计算节点使用所述计算节点的本地数据库对所述第一初级模型进行增量学习从而获得第一中级模型;所述计算节点向所述中心节点发送所述第一中级模型,使得所述中心节点对接收到的多个所述第一中级模型进行模型融合从而获得第一高级模型。
Description
技术领域
本申请涉及人工智能(Artificial Intelligence,AI)领域,尤其涉及联合学习的方法、系统、节点及存储介质。
背景技术
AI是利用数字计算机或者数字计算机控制的机器模拟、延伸和扩展人的智能,感知环境、获取知识并使用知识获得最佳结果的理论、方法、技术及应用系统。换句话说,人工智能是计算机科学的一个分支,它企图了解智能的实质,并生产出一种新的能以人类智能相似的方式作出反应的智能机器。人工智能也就是研究各种智能机器的设计原理与实现方法,使机器具有感知、推理与决策的功能。人工智能领域的研究包括机器人,自然语言处理,计算机视觉,决策与推理,人机交互,推荐与搜索,AI基础理论等。
随着大规模数据集的出现以及算力(比如英伟达显卡的计算性能)的快速发展,深度学习作为人工智能领域的重要算法之一,开始在数据分类、物体分割、物体检测等任务上展现出它惊人的统治力,让世界各地的人们看到了人工智能的巨大潜力,也让人们对于无人驾驶、远程手术、场景理解等先进技术产生了美好的憧憬。
人工智能需要大量的数据来进行训练,而大量的数据往往分布在不同地点的计算节点上,传统方法中,训练数据往往是通过大规模采集与合并各个计算节点的数据获得的,这也带来了一些数据隐私的问题。特别是欧盟通用数据保护条例(General DataProtection Regulation,GDPR)正式执行后,数据隐私保护问题已经成为了社会关注的焦点,之前大规模采集与合并数据的方式已经不适用于当前的法律规范,如何在遵守隐私与安全规定(例如GDPR)的前提下,使用分散的训练数据和计算资源对神经网络进行训练成为了亟需解决的问题。
发明内容
本申请提供了一种联合学习的方法、系统、节点及存储介质,该方法可以解决如何在遵守隐私与安全规定的前提下,使用分散的训练数据和计算资源对神经网络进行训练的问题。
第一方面,提供了一种联合学习的方法,应用于联合学习系统,该联合学习系统包括一个或者多个计算节点以及一个或多个中心节点,该方法包括以下步骤:
计算节点接收由中心节点发送的第一初级模型,其中,该第一初级模型是中心节点根据中心节点的中心数据库对神经网络进行训练后得到的;
计算节点使用其本地数据库对该第一初级模型进行增量学习从而获得第一中级模型;
计算节点向中心节点发送该第一中级模型,使得中心节点对接收到的多个第一中级模型进行模型融合从而获得第一高级模型。
上述方法中,通过中心节点将训练好的第一初级模型下发给多个计算节点,每个计算节点可以根据本地数据库中的数据对第一初级模型进行增量学习,得到多个第一中级模型并将其返回至中心节点,使得中心节点可以对接收到的多个第一中级模型进行模型融合,从而获得训练好的第一高级模型,该第一高级模型学习了每个计算节点的本地数据库中的训练数据的数据特征,模型性能优越,解决了在遵守隐私与安全规定的前提下,使用分散的训练数据和计算资源对神经网络进行训练的问题。
在第一方面的一种可能的实现方式中,该计算节点使用其本地数据库对该第一初级模型进行增量学习从而获得第一中级模型包括:计算节点将本地数据库中的训练数据输入第一初级模型中的第一网络,获得第一网络的输出结果,其中,第一初级模型包括第一网络和第一预测函数,第一预测函数用于根据第一网络的输出结果生成第一初级模型的预测值;计算节点将第一网络的输出结果输入软目标预测函数,获得本地数据库中的训练数据的软目标,其中,软目标预测函数是根据温度参数和第一预测函数得到的,温度参数用于使得第一网络的输出结果输入软目标预测函数得到的软目标,不大于第一网络的输出结果输入第一预测函数得到的第一初级模型的预测值;计算节点使用本地数据库对第一中级网络进行训练,根据软目标以及本地数据库中的训练数据的真实标签得到混合损失函数,其中,第一中级网络的网络类型与神经网络的网络类型相同;计算节点根据混合损失函数对第一中级网络进行反向传播,获得第一中级模型。
具体实现中,第一中级网络包括第二网络和第二预测函数,第二预测函数用于根据第二网络的输出结果生成第一中级网络的预测值,混合损失函数是根据第一损失函数以及第二损失函数确定的,第一损失函数是计算节点将本地数据库中的训练数据输入第一中级网络的第二网络获得第二网络的输出结果后,将第二网络的输出结果输入软目标预测函数生成的软目标预测值与软目标之间的差距确定的,第二损失函数是将第二网络的输出结果输入第二预测函数生成的预测值与真实标签之间的差距确定的。
可以理解的,在传统的神经网络训练过程中,将使用数据集对神经网络进行训练,训练过程中根据数据集的真实标签来对预测结果进行调整,因此可以理解为训练好的神经网络模型学到了数据集中的“知识”。而上述实现方式中,使用预先训练好的第一初级模型对本地数据库中的训练数据进行预测,获得每个训练数据的软目标,对第一中级网络进行训练的过程中,将会根据真实标签和软目标同时对第一中级网络的网络参数进行调整,因此可以理解为训练好的神经网络模型不仅学习到了数据集中的“知识”,也学习到了第一初级模型的“知识”,从而达到增量学习的目的。
在第一方面的一种可能的实现方式中,第一高级模型是中心节点根据多个第一中级模型获得多教师融合模型后,使用中心节点的中心数据库对多教师融合模型进行增量学习后获得的,其中,多教师融合模型的全连接层提取的特征向量是多个第一中级模型的全连接层提取的特征向量进行拼接并降维处理后得到的,多教师融合模型的全连接层提取的特征向量与第一高级模型的特征向量的维度相同。
上述实现方式中,通过将多个增量学习后的第一中级模型的全连接层提取的特征向量拼接为一个特征向量获得多教师融合模型,结合中心节点的中心数据库,对该多教师融合模型进行增量学习,使得最终得到的第一高级模型不但学习到了中心数据库中的数据特征,也学习到了多个第一中级模型的预测能力,从而达到将多个第一中级模型融合为一个第一高级模型的目的,进而实现了使用分散的训练数据和计算资源对神经网络进行训练的问题。并且,由于模型融合过程中对第一中级模型的模型结构不作限定,因此计算节点在进行增量学习时,用于结合本地数据库进行训练的第一中级网络结构可以不作任何限制,使得本申请适用的应用场景更加广泛。
在第一方面的一种可能的实现方式中,多教师融合模型的全连接层提取的特征向量是m个正向提升模型的全连接层提取的特征向量进行拼接并降维处理后得到的,m个正向提升模型是根据多个第一中级模型得到的,正向提升模型是对基础模型有指导作用的模型,基础模型是将中心数据库中的测试数据输入多个第一中级模型后,多个第一中级模型中预测准确率最高的模型。
上述实现方式中,通过对多个第一中级模型进行筛选,获得m个对基础模型有指导作用的正向提升模型,可以理解的,如果正向提升模型能够指导第一中级模型里预测准确率最高的基础模型进行训练,并使得基础模型的预测准确率得到进一步提升,那么正向提升模型一定可以指导未经过训练集的训练、仅为简单的神经网络结构甚至不具备模型预测能力的第一高级网络进行训练,并使得第一高级网络的预测准确率能够得到提升,从而减少模型融合过程中参与的模型数量,提高模型融合的效率,进而提高联合学习的效率。
在第一方面的一种可能的实现方式中,计算节点向中心节点发送第一中级模型,使得中心节点对接收到的多个第一中级模型进行模型融合,获得第一高级模型之后,方法还包括:计算节点在本地数据库中存在新数据的情况下,接收由中心节点发送的第二初级模型,其中,第二初级模型是第一高级模型;计算节点结合新数据对第二初级模型进行增量学习,获得第二中级模型;计算节点向中心节点发送第二中级模型,使得中心节点对接收到的多个更新后的第二中级模型进行模型融合,获得第二高级模型。
上述实现方式中,通过在计算节点的本地数据库中存在新数据的情况下,可以将上一个训练周期获得的第一高级模型作为第二初级模型,再进行第二个训练周期的联合学习,从而获得性能更好的第二高级模型。可以理解的,本申请提供的联合学习方法,在一个训练周期内,中心节点可以向计算节点发送一次第一初级模型数据,计算节点可以向中心节点发送一次第一中级模型数据即可完成一次联合学习的过程,使得整个联合学习的过程占用的通信资源数量极少,使得本申请适用的应用场景更加广泛。
在第一方面的一种可能的实现方式中,一个或多个中心节点包括第一中心节点和第二中心节点,第一中心节点和第二中心节点部署于不同地理地区,计算节点接收由中心节点发送的第一初级模型包括:计算节点接收由第一中心节点发送的第一初级模型;计算节点向中心节点发送第一中级模型,使得中心节点对接收到的多个第一中级模型进行模型融合从而获得第一高级模型包括:计算节点向第二中心节点发送第一中级模型,使得第二中心节点对接收到的多个第一中级模型进行模型融合从而获得第一高级模型。
上述实现方式中,训练初级模型的中心节点和进行模型融合的中心节点可以是不同的中心节点,也可以是同一个中心节点,具体可以根据应用场景的实际情况确定,使得本申请适用的应用场景更加广泛。
第二方面,提供了一种联合学习的方法,应用于联合学习系统,联合学习系统包括一个或者多个计算节点以及一个或多个中心节点,方法包括:
中心节点向计算节点发送第一初级模型,其中,第一初级模型是中心节点根据中心节点的中心数据库对神经网络进行训练后得到的;
中心节点接收计算节点发送的第一中级模型,其中,第一中级模型是计算节点使用计算节点的本地数据库对第一初级模型进行增量学习后获得的;
中心节点对接收到的多个第一中级模型进行模型融合从而获得第一高级模型。
上述方法中,通过中心节点将训练好的第一初级模型下发给多个计算节点,每个计算节点可以根据本地数据库中的数据对第一初级模型进行增量学习,得到多个第一中级模型并将其返回至中心节点,使得中心节点可以对接收到的多个第一中级模型进行模型融合,从而获得训练好的第一高级模型,该第一高级模型学习了每个计算节点的本地数据库中的训练数据的数据特征,模型性能优越,解决了在遵守隐私与安全规定的前提下,使用分散的训练数据和计算资源对神经网络进行训练的问题。
在第二方面的一种可能的实现方式中,第一中级模型是计算节点将本地数据库中的训练数据输入第一初级模型中的第一网络获得第一网络的输出结果,并将第一网络的输出结果输入软目标预测函数获得本地数据库中的训练数据的软目标之后,使用本地数据库对第一中级网络进行训练,根据所述软目标以及所述本地数据库中的训练数据的真实标签得到混合损失函数,根据混合损失函数对第一中级网络进行反向传播从而获得第一中级模型,其中,第一初级模型包括第一网络和第一预测函数,第一预测函数用于根据第一网络的输出结果生成第一初级模型的预测值,软目标预测函数是根据温度参数和第一预测函数得到的,温度参数用于使得第一网络的输出结果输入软目标预测函数得到的软目标,不大于第一网络的输出结果输入第一预测函数得到的第一初级模型的预测值。
具体实现中,第一中级网络包括第二网络和第二预测函数,第二预测函数用于根据第二网络的输出结果生成第一中级网络的预测值,混合损失函数是根据第一损失函数以及第二损失函数确定的,第一损失函数是计算节点将本地数据库中的训练数据输入第一中级网络的第二网络获得第二网络的输出结果后,将第二网络的输出结果输入软目标预测函数生成的软目标预测值与软目标之间的差距确定的,第二损失函数是将第二网络的输出结果输入第二预测函数生成的预测值与真实标签之间的差距确定的。
可以理解的,在传统的神经网络训练过程中,将使用数据集对神经网络进行训练,训练过程中根据数据集的真实标签来对预测结果进行调整,因此可以理解为训练好的神经网络模型学到了数据集中的“知识”。而上述实现方式中,使用预先训练好的第一初级模型对本地数据库中的训练数据进行预测,获得每个训练数据的软目标,对第一中级网络进行训练的过程中,将会根据真实标签和软目标同时对第一中级网络的网络参数进行调整,因此可以理解为训练好的神经网络模型不仅学习到了数据集中的“知识”,也学习到了教师模型的“知识”,从而达到增量学习的目的。
在第二方面的一种可能的实现方式中,中心节点对接收到的多个第一中级模型进行模型融合从而获得第一高级模型包括:中心节点根据多个第一中心模型,获得多教师融合模型,其中,多教师融合模型的全连接层提取的特征向量是多个第一中级模型的全连接层提取的特征向量进行拼接并降维处理后得到的,多教师融合模型的全连接层提取的特征向量与第一高级模型的特征向量维度相同;中心节点使用中心数据库对多教师融合模型进行增量学习,获得第一高级模型。
上述实现方式中,通过将多个增量学习后的第一中级模型的全连接层提取的特征向量拼接为一个特征向量获得多教师融合模型,结合中心节点的中心数据库,对该多教师融合模型进行增量学习,使得最终得到的第一高级模型不但学习到了中心数据库中的数据特征,也学习到了多个第一中级模型的预测能力,从而达到将多个第一中级模型融合为一个第一高级模型的目的,进而实现了使用分散的训练数据和计算资源对神经网络进行训练的问题。并且,由于模型融合过程中对第一中级模型的模型结构不作限定,因此计算节点在进行增量学习时,用于结合本地数据库进行训练的第一中级网络结构可以不作任何限制,使得本申请适用的应用场景更加广泛。
在第二方面的一种可能的实现方式中,根据多个第一中级模型,获得多教师融合模型包括:中心节点在多个第一中级模型中筛选出基础模型和m个正向提升模型,其中,基础模型是中心节点将中心数据库中的测试数据输入多个第一中级模型后,多个第一中级模型中预测准确率最高的模型,正向提升模型是对基础模型有指导作用的模型;中心节点根据m个正向提升模型,获得多教师融合模型,其中,多教师融合模型的全连接层提取的特征向量是m个正向提升模型的全连接层提取的特征向量进行拼接并降维处理后得到的,多教师融合模型的全连接层提取的特征向量与第一高级模型的特征向量维度相同。
上述实现方式中,通过对多个第一中级模型进行筛选,获得m个对基础模型有指导作用的正向提升模型,可以理解的,如果正向提升模型能够指导第一中级模型里预测准确率最高的基础模型进行训练,并使得基础模型的预测准确率得到进一步提升,那么正向提升模型一定可以指导未经过训练集的训练、仅为简单的神经网络结构甚至不具备模型预测能力的第一高级网络进行训练,并使得第一高级网络的预测准确率能够得到提升,从而减少模型融合过程中参与的模型数量,提高模型融合的效率,进而提高联合学习的效率。
在第二方面的一种可能的实现方式中,中心节点对接收到的多个第一中级模型进行模型融合从而获得第一高级模型之后,方法还包括:中心节点向计算节点发送第二初级模型,使得计算节点结合本地数据库中的新数据对第二初级模型进行增量学习,获得第二中级模型,其中,第二初级模型是高级模型;中心节点接收计算节点发送的第二中级模型;中心节点对接收到的多个第二中级模型进行模型融合,获得第二高级模型。
上述实现方式中,通过在计算节点的本地数据库中存在新数据的情况下,可以将上一个训练周期获得的第一高级模型作为第二初级模型,再进行第二个训练周期的联合学习,从而获得性能更好的第二高级模型。可以理解的,本申请提供的联合学习方法,在一个训练周期内,中心节点可以向计算节点发送一次第一初级模型数据,计算节点可以向中心节点发送一次第一中级模型数据即可完成一次联合学习的过程,使得整个联合学习的过程占用的通信资源数量极少,使得本申请适用的应用场景更加广泛。
在第二方面的一种可能的实现方式中,一个或多个中心节点包括第一中心节点和第二中心节点,中心节点向计算节点发送第一初级模型包括:第一中心节点向计算节点发送第一初级模型,其中,第一初级模型是第一中心节点根据第一中心节点的中心数据库对神经网络进行训练后得到的;中心节点接收计算节点发送的第一中级模型包括:第二中心节点接收计算节点发送的第一中级模型;中心节点对接收到的多个第一中级模型进行模型融合从而获得第一高级模型包括:第二中心节点对接收到的多个第一中级模型进行模型融合,从而获得第一高级模型。
上述实现方式中,训练初级模型的中心节点和进行模型融合的中心节点可以是不同的中心节点,也可以是同一个中心节点,具体可以根据应用场景的实际情况确定,使得本申请适用的应用场景更加广泛。
第三方面,提供了一种计算节点,应用于联合学习系统,联合学习系统包括一个或者多个计算节点以及一个或多个中心节点,计算节点包括接收单元、学习单元以及发送单元,其中,
接收单元用于接收由中心节点发送的第一初级模型,其中,第一初级模型是中心节点根据中心节点的中心数据库对神经网络进行训练后得到的;
学习单元用于使用计算节点的本地数据库对第一初级模型进行增量学习从而获得第一中级模型;
发送单元用于向中心节点发送第一中级模型,使得中心节点对接收到的多个第一中级模型进行模型融合从而获得第一高级模型。
在第三方面的一种可能的实现方式中,学习单元用于将本地数据库中的训练数据输入第一初级模型中的第一网络,获得第一网络的输出结果,其中,第一初级模型包括第一网络和第一预测函数,第一预测函数用于根据第一网络的输出结果生成第一初级模型的预测值;学习单元用于将第一网络的输出结果输入软目标预测函数,获得本地数据库中的训练数据的软目标,其中,软目标预测函数是根据温度参数和第一预测函数得到的,温度参数用于使得第一网络的输出结果输入软目标预测函数得到的软目标,不大于第一网络的输出结果输入第一预测函数得到的第一初级模型的预测值;学习单元用于使用本地数据库对第一中级网络进行训练,根据所述软目标以及所述本地数据库中的训练数据的真实标签得到混合损失函数,其中,第一中级网络的网络类型与神经网络的网络类型相同;学习单元用于根据混合损失函数对第一中级网络进行反向传播,获得训练好的第一中级模型。
在第三方面的一种可能的实现方式中,第一中级网络包括第二网络和第二预测函数,第二预测函数用于根据第二网络的输出结果生成第一中级网络的预测值,混合损失函数是根据第一损失函数以及第二损失函数确定的,第一损失函数是计算节点将本地数据库中的训练数据输入第一中级网络的第二网络获得第二网络的输出结果后,将第二网络的输出结果输入软目标预测函数生成的软目标预测值与软目标之间的差距确定的,第二损失函数是将第二网络的输出结果输入第二预测函数生成的预测值与真实标签之间的差距确定的。
在第三方面的一种可能的实现方式中,第一高级模型是中心节点根据多个第一中级模型获得多教师融合模型后,使用中心节点的中心数据库对多教师融合模型进行增量学习后获得的,其中,多教师融合模型的全连接层提取的特征向量是多个第一中级模型的全连接层提取的特征向量进行拼接并降维处理后得到的,多教师融合模型的全连接层提取的特征向量与第一高级模型的特征向量的维度相同。
在第三方面的一种可能的实现方式中,多教师融合模型的全连接层提取的特征向量是m个正向提升模型的全连接层提取的特征向量进行拼接并降维处理后得到的,m个正向提升模型是根据多个第一中级模型得到的,正向提升模型是对基础模型有指导作用的模型,基础模型是将中心数据库中的测试数据输入多个第一中级模型后,多个第一中级模型中预测准确率最高的模型。
在第三方面的一种可能的实现方式中,接收单元还用于在向中心节点发送第一中级模型,使得中心节点对接收到的多个第一中级模型进行模型融合,获得第一高级模型之后,接收单元在本地数据库中存在新数据的情况下,接收由中心节点发送的第二初级模型,其中,第二初级模型是第一高级模型;学习单元还用于结合新数据对第二初级模型进行增量学习,获得第二中级模型;发送单元还用于向中心节点发送第二中级模型,使得中心节点对接收到的多个更新后的第二中级模型进行模型融合,获得第二高级模型。
在第三方面的一种可能的实现方式中,一个或多个中心节点包括第一中心节点和第二中心节点,接收单元还用于接收由第一中心节点发送的第一初级模型;发送单元还用于向第二中心节点发送第一中级模型,使得第二中心节点对接收到的多个第一中级模型进行模型融合从而获得第一高级模型。
第四方面,提供了一种中心节点,应用于联合学习系统,联合学习系统包括一个或者多个计算节点以及一个或多个中心节点,中心节点包括发送单元、接收单元以及融合单元,其中,
发送单元用于向计算节点发送第一初级模型,其中,第一初级模型是中心节点根据中心节点的中心数据库对神经网络进行训练后得到的;
接收单元用于接收计算节点发送的第一中级模型,其中,第一中级模型是计算节点使用计算节点的本地数据库对第一初级模型进行增量学习后获得的;
融合单元用于对接收到的多个第一中级模型进行模型融合从而获得第一高级模型。
在第四方面的一种可能的实现方式中,第一中级模型是计算节点将本地数据库中的训练数据输入第一初级模型中的第一网络获得第一网络的输出结果,并将第一网络的输出结果输入软目标预测函数获得本地数据库中的训练数据的软目标之后,使用本地数据库对第一中级网络进行训练,根据所述软目标以及所述本地数据库中的训练数据的真实标签得到混合损失函数,根据混合损失函数对第一中级网络进行反向传播从而获得的第一中级模型,其中,第一初级模型包括第一网络和第一预测函数,第一预测函数用于根据第一网络的输出结果生成第一初级模型的预测值,软目标预测函数是根据温度参数和第一预测函数得到的,温度参数用于使得第一网络的输出结果输入软目标预测函数得到的软目标,不大于第一网络的输出结果输入第一预测函数得到的第一初级模型的预测值。
在第四方面的一种可能的实现方式中,第一中级网络包括第二网络和第二预测函数,第二预测函数用于根据第二网络的输出结果生成第一中级网络的预测值,混合损失函数是根据第一损失函数以及第二损失函数确定的,第一损失函数是计算节点将本地数据库中的训练数据输入第一中级网络的第二网络获得第二网络的输出结果后,将第二网络的输出结果输入软目标预测函数生成的软目标预测值与软目标之间的差距确定的,第二损失函数是将第二网络的输出结果输入第二预测函数生成的预测值与真实标签之间的差距确定的。
在第四方面的一种可能的实现方式中,融合单元用于根据多个第一中级模型,获得多教师融合模型,其中,多教师融合模型的全连接层提取的特征向量是多个第一中级模型的全连接层提取的特征向量进行拼接并降维处理后得到的,多教师融合模型的全连接层提取的特征向量与第一高级模型的特征向量维度相同;融合单元用于使用中心数据库对多教师融合模型进行增量学习,获得第一高级模型。
在第四方面的一种可能的实现方式中,第一中级网络包括第二网络和第二预测函数,第二预测函数用于根据第二网络的输出结果生成第一中级网络的预测值,混合损失函数是根据第一损失函数以及第二损失函数确定的,第一损失函数是计算节点将本地数据库中的训练数据输入第一中级网络的第二网络获得第二网络的输出结果后,将第二网络的输出结果输入软目标预测函数生成的软目标预测值与软目标之间的差距确定的,第二损失函数是将第二网络的输出结果输入第二预测函数生成的预测值与本地数据库中的训练数据的真实标签之间的差距确定的。
在第四方面的一种可能的实现方式中,融合单元用于在多个第一中级模型中筛选出基础模型和m个正向提升模型,其中,基础模型是中心节点将中心数据库中的测试数据输入多个第一中级模型后,多个第一中级模型中预测准确率最高的模型,正向提升模型是对基础模型有指导作用的模型;融合单元用于根据m个正向提升模型,获得多教师融合模型,其中,多教师融合模型的全连接层提取的特征向量是m个正向提升模型的全连接层提取的特征向量进行拼接并降维处理后得到的,多教师融合模型的全连接层提取的特征向量与第一高级模型的特征向量维度相同。
在第四方面的一种可能的实现方式中,接收单元还用于在点对接收到的多个第一中级模型进行模型融合从而获得第一高级模型之后,接收由计算节点发送的计算节点的本地数据库中存在新数据的消息;发送单元还用于根据本地数据库中存在新数据的消息,向计算节点发送第二初级模型,使得计算节点结合新数据对第二初级模型进行增量学习,获得第二中级模型,其中,第二初级模型是高级模型;接收单元还用于接收计算节点发送的第二中级模型;融合单元还用于对接收到的多个第二中级模型进行模型融合,获得第二高级模型。
第五方面,提供了一种联合学习系统,系统包括一个或多个如第二方面描述的中心节点以及一个或多个如第一方面描述的计算节点,该中心节点连接该计算节点。
第六方面,提供了一种计算机可读存储介质,包括指令,当指令在计算设备上运行时,使得计算设备执行如第一方面和/或第二方面描述的方法。
第七方面,提供了一种计算机程序产品,包括计算机程序,当计算机程序被计算设备读取并执行时,实现如第一方面和/或第二方面所描述的方法。
第八方面,提供了一种电子设备,包括处理器和存储器,处理器执行存储器中的代码实现如第一方面和/或第二方面描述的方法。
附图说明
为了更清楚地说明本申请实施例或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍。
图1是本申请实施例提供的一种人工智能主体框架示意图;
图2是本申请实施例提供的一种联合学习系统的结构示意图;
图3A-图3C是本申请实施例提供的联合学习方法在一应用场景下的流程示意图;
图4是本申请实施例提供的一种联合学习方法的流程示意图;
图5是本申请实施例提供的一种增量学习的流程示意图;
图6是本申请实施例提供的一种模型融合方法的流程示意图;
图7是本申请实施例提供的一种确定正向提升模型的方法流程示意图;
图8是本申请实施例提供的一种多教师融合模型的结构示意图;
图9是本申请提供的一种计算节点的结构示意图;
图10是本申请提供的一种中心节点的结构示意图;
图11是本申请提供的一种芯片的硬件结构示意图;
图12是本申请提供的一种电子设备的结构示意图。
具体实施方式
下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行描述,显然,所描述的实施例仅仅是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
图1示出一种人工智能主体框架示意图,该主体框架描述了人工智能系统总体工作流程,适用于通用的人工智能领域需求。
下面从“智能信息链”(水平轴)和“IT价值链”(垂直轴)两个维度对上述人工智能主题框架进行阐述。
“智能信息链”反映从数据的获取到处理的一列过程。举例来说,可以是智能信息感知、智能信息表示与形成、智能推理、智能决策、智能执行与输出的一般过程。在这个过程中,数据经历了“数据—信息—知识—智慧”的凝练过程。
“IT价值链”从人智能的底层基础设施、信息(提供和处理技术实现)到系统的产业生态过程,反映人工智能为信息技术产业带来的价值。
(1)基础设施:
基础设施为人工智能系统提供计算能力支持,实现与外部世界的沟通,并通过基础平台实现支撑。通过传感器与外部沟通;计算能力由智能芯片提供,比如中央处理器(Central Processing Unit,CPU)、神经网络处理器(Neural-network Processing Unit,NPU)、图像处理器(Graphics Processing Unit,GPU)、专用集成电路(ApplicationSpecific Integrated Circuit,ASIC)、可编程逻辑门阵列(Field Programmable GateArray,FPGA)等硬件加速芯片;基础平台包括分布式计算框架及网络等相关的平台保障和支持,可以包括云存储和计算、互联互通网络等。举例来说,传感器和外部沟通获取数据,这些数据提供给基础平台提供的分布式计算系统中的智能芯片进行计算。
(2)数据
基础设施的上一层的数据用于表示人工智能领域的数据来源。数据涉及到图形、图像、语音、文本,还涉及到传统设备的物联网数据,包括已有系统的业务数据以及力、位移、液位、温度、湿度等感知数据。
(3)数据处理
数据处理通常包括数据训练,机器学习,深度学习,搜索,推理,决策等方式。
其中,机器学习和深度学习可以对数据进行符号化和形式化的智能信息建模、抽取、预处理、训练等。
推理是指在计算机或智能系统中,模拟人类的智能推理方式,依据推理控制策略,利用形式化的信息进行机器思维和求解问题的过程,典型的功能是搜索与匹配。
决策是指智能信息经过推理后进行决策的过程,通常提供分类、排序、预测等功能。
(4)通用能力
对数据经过上面提到的数据处理后,进一步基于数据处理的结果可以形成一些通用的能力,比如可以是算法或者一个通用系统,例如,翻译,文本的分析,计算机视觉的处理,语音识别,图像的识别等等。
(5)智能产品及行业应用
智能产品及行业应用指人工智能系统在各领域的产品和应用,是对人工智能整体解决方案的封装,将智能信息决策产品化、实现落地应用,其应用领域主要包括:智能制造、智能交通、智能家居、智能医疗、智能安防、自动驾驶,平安城市,智能终端等。
随着大规模数据集的出现以及算力(比如英伟达显卡的计算性能)的快速发展,深度学习作为人工智能领域的重要算法之一,开始在数据分类、物体分割、物体检测等任务上展现出它惊人的统治力。尤其是2016年AlphaGo的横空出世,让世界各地的人们看到了人工智能的巨大潜力,也让人们对于无人驾驶、远程手术、场景理解等先进技术产生了美好的憧憬。
随着社会的发展,一方面人工智能带来了巨大的生产力,另一方面它也需要大量的数据来进行训练,这些训练数据往往是用户在使用产品的过程中产生,其中涉及到一些用户的使用频率、使用时间、使用偏好等隐私数据,如果将用户的这些数据统一收集到服务器端然后再进行大规模的训练,容易使用户对个人数据的隐私问题而产生担忧。特别是GDPR条例正式执行后,数据隐私保护问题已经成为了社会关注的焦点,之前大规模采集与合并数据的方式已经不适用于当前的法律规范,如何在遵守隐私与安全规定(例如GDPR条例)的前提下,使用分散的训练数据和计算资源对神经网络进行训练成为了亟需解决的问题。
联合学习(Federated Learning)是谷歌公司提出的一种解决上述问题的方法,它能使多台客户端计算设备以协作的形式,训练出共享的预测模型。首先,服务器端给客户端计算设备下发初始模型;其次,客户端计算设备根据本地数据对初始模型进行训练,训练完成后将训练好的初始模型的梯度更新数据发送到服务端;接着,服务端整合接收到的所有梯度更新数据,作为模型的一次更新;最后,将更新后的模型再次发送到客户端计算设备,重复上述训练过程,直至服务端的模型收敛。从而实现了在不传输个人数据的同时,完成模型的训练,保证了用户数据的安全。
然而,上述方法仍存在缺陷,个别情况下,依旧无法解决在遵守隐私与安全规定的前提下,使用分散的训练数据和计算资源对神经网络进行训练的问题。第一,梯度更新数据蕴含部分原始图像信息,即使在梯度数据加密的情况下,在某些情况下仍有可能存在加密信息被破解的可能性,进而存在从梯度更新数据反推出原始数据的可能性,换句话说,在某些情况下仍然可能泄露用户信息。第二,需要客户端计算设备向服务端发送很多次梯度更新数据,直至服务端的模型收敛,传输频率很高,这将会占用大量的通信资源,使得该方法的应用场景得到极大地限制,换句话说,在一些通信带宽受限的应用场景下依旧无法解决在遵守隐私与安全规定的前提下,使用分散的训练数据和计算资源对神经网络进行训练的问题。第三,由于最终训练得到的模型,是根据各个客户端计算设备发送的梯度更新数据整合后获得的,具体地,将会对同一个参数的梯度或者权重进行累加然后求平均,因此各个客户端计算设备和服务端必须使用完全相同的模型结构参数进行训练,各个客户端计算设备和服务端使用的训练数据也必须是非常类似的独立同分布数据(Independent andIdentically Distributed,IID),比如必须都是黄种人的人脸数据,而不能是黄种人和白种人的人脸数据,否则将会因为数据差异较大使得更新后的梯度数据差异很大,导致对同一个参数的梯度或者权重进行累加然后求平均而确定的最终的模型性能很差。而实际应用中,如果不同客户端计算设备属于不同的组织,那么每一个客户端计算设备采用的模型结构无法确保完全相同,比如超参数设置的迭代次数、输入图像参数大小等等都无法确保其统一,训练数据也无法确保是一定是独立同分布数据,比如不同省市的人脸数据中可能存在不同民族、不同人种的数据,使得该方法的应用场景得到极大限制。换句话说,在客户端计算设备由不同组织进行管理的应用场景下,该方法依旧无法解决在遵守隐私与安全规定的前提下使用分散的训练数据和计算资源对神经网络进行训练的问题。
为了解决上述如何在遵守隐私与安全规定的前提下,使用分散的训练数据和计算资源对神经网络进行训练的问题,本申请提供了一种联合学习系统,客户端计算设备接收到服务端发送的初始模型之后,可以根据各自的本地数据对初始模型进行增量学习得到增量模型,并将其传输至服务端,服务端经过模型融合算法得到联合学习后的模型。从而实现在遵守隐私与安全规定的前提下,使用分散的训练数据和计算资源对神经网络进行训练。
本申请提供的联合学习系统内部的单元模块也可以有多种划分,各个模块可以是软件模块,也可以是硬件模块,也可以部分是软件模块部分是硬件模块,本申请不对其进行限制。图2为一种示例性的划分方式,如图2所示,联合学习系统100包括中心节点110和计算节点120。其中,中心节点110可以是前述内容中的服务端,计算节点120可以是前述内容中的客户端计算设备,下面分别介绍每个节点的功能。
中心节点110可以是单独设置于单个终端或服务器内,也可以设置于服务器集群内,也可以设置于数据中心。当中心节点110设置于单个终端或服务器内时,该终端或服务器可以设置于世界的任意一个区域;当中心节点110设置于服务器集群内时,服务器集群可以设置于世界的任意一个区域;当中心节点110设置与数据中心时,中心节点110可以设置于电力供应充足、土地价格低廉、冷却条件较好的区域,当然,在不考虑成本等因素时,中心节点110也可以设置于世界的任意一个区域。
计算节点120可以是单独设置于单个服务器内,也可以设置于服务器集群内,也可以设置于数据中心。当计算节点120设置于单个服务器内时,服务器可以设置于世界的任意一个区域;当计算节点120设置于服务器集群内时,服务器集群可以设置于世界的任意一个区域;当计算节点120设置与数据中心时,计算节点120可以设置于电力供应充足、土地价格低廉、冷却条件较好的区域,当然,在不考虑成本等因素时,计算节点120也可以设置于世界的任意一个区域。
其中,中心节点110和计算节点120可以分别设置在多个不同的区域。举例而言,图2中的中心节点110可设置在深圳市南山区,计算节点1设置在深圳市福田区,计算节点2设置在深圳市宝安区,计算节点3设置在深圳市龙华区等等,此处不作具体限定。
具体实现中,中心节点110可以包括中心数据库111、模型训练模块112以及模型融合模块113,其中,中心数据库111中存储有中心节点本地的训练数据;计算节点120可以包括本地数据库121和增量学习模块122,其中,本地数据库121中存储有计算节点本地的训练数据。在本申请实施例中,中心节点的模型训练模块112可以使用中心数据库111训练得到第一初级模型,并将其加密后发送至多个计算节点120处;计算节点120可以通过增量学习模块122,结合本地数据库121对该第一初级模型进行增量学习,从而生成训练好的第一中级模型,并将其加密后发送至中心节点110;最后中心节点110的模型融合模块113可以根据接收到多个计算节点120发送的第一中级模型,结合中心数据库111中的数据通过模型融合算法进行模型融合,从而获得最终的第一高级模型。该第一高级模型学习到了各个计算节点120的本地数据库中训练数据的数据特征,模型性能非常优秀。
需要说明的,图2仅以1个中心节点110和3个计算节点120为例进行了举例说明,具体实现中,中心节点110的数量可以是一个或者多个,计算节点的数量也可以是一个或者多个,具体可以根据实际情况确定中心节点和计算节点的数量,本申请不作具体限定。其中,中心节点110的数量是多个的情况下,训练并下发第一初级模型的节点,与融合多个第一中级模型的节点可以是不同的中心节点110,举例来说,中心节点1位于深圳市,用于根据中心节点1的中心数据库训练第一初级模型,中心节点2位于上海市,用于根据中心节点2的中心数据库,对多个计算节点120发送的第一中级模型进行模型融合,获得最终的第一高级模型。上述举例仅用于说明,并不能构成具体限定。
可以理解的,本申请提供的联合学习系统,中心节点只需要向计算节点发送一次初始模型数据,计算节点只需要向中心节点发送一次第一中级模型数据,整个系统不需要占用过多的通信资源,并且,由于初始模型数据和第一中级模型数据均为加密传输,并且模型数据与梯度数据不同,模型数据是无法恢复出原始的训练数据的,因此本申请提供的联合系统完全实现了在遵守隐私与安全规定的前提下,使用分散的训练数据和计算资源对神经网络进行训练。
下面对本申请提供的联合学习系统适用的应用场景进行说明。
本申请提供的联合学习系统可以应用于多地区、多用户联合学习的场景。该场景下每个用户的数据由于隐私问题无法聚集到某处进行统一训练。举例来说,比如应用在手机用户智能识图模型更新的场景,终端用户的相册数据是隐私敏感数据,但各个终端用户拥有图像数量较少,此外用户的使用习惯各不相同,导致数据类别分布不均匀,不是独立同分布的数据。在这一应用场景下,考虑到用户的隐私问题,无法将各个手机中的相册数据统一上传到云端进行训练,并且,考虑通信带宽/流量费用问题,又难以承担前述内容中的传统联合学习方法带来的巨大通信量。而使用本申请提供的联合学习系统,如图3A所示,首先,云端(中心节点110)可以提前使用云端的数据库(中心数据库111)中的相册数据训练好多个第一初级模型,并将其加密;其次,每个用户的手机(计算节点120)可以在空闲时间从云端接收该第一初级模型,再根据每个手机中的用户数据(本地数据库121)中的相册数据进行增量学习得到第一中级模型,具体可以参考图3A中手机1描绘的增量学习的过程,应理解,手机2~手机N根据第一初级模型和本地数据库121进行增量学习生成第一中级模型的过程与手机1相同,因此图3A中手机2~手机N并没有将该步骤再展开描述;接着,每个用户的手机可以在空闲时间将第一中级模型加密后发送至云端;最后,云端可以通过模型融合算法将接收到的多个第一中级模型进行模型融合,从而得到模型性能比第一初级模型以及每个第一中级模型都要好的第一高级模型。可以理解的,由于每个手机只能根据本地数据库121进行增量学习,这样训练出来第一中级模型因为缺少了一部分数据,可能对于某些类别识别的并不准确,而中心节点通过对多个第一中级模型进行模型融合后获得的第一高级模型,学习了各个手机的相册数据特征,因此最终获得的第一高级模型的模型性能优良,完全实现了在遵守隐私与安全规定的前提下,使用分散的训练数据和计算资源对神经网络进行训练。应理解,上述举例仅用于说明,本申请还可以使用在其他多地区、多用户的联合学习场景,本申请不作具体限定。
本申请提供的联合学习系统还可以应用于多地区、多用户的异构模型训练的联合学习场景。这里,异构模型指的是该场景下,每个计算节点训练模型的超参数设置可能不同,训练模型的方法可能不同。举例来说,各个医院的本地肿瘤患者数据较少,使用本地数据库121训练的模型性能很差,而肿瘤患者数据是高度隐私敏感数据,在这一应用场景下,考虑到肿瘤患者的隐私问题,无法将各个医院的肿瘤患者数据统一上传到云端进行训练,并且,各医院使用的系统为自行采购,因此各个医院训练模型的超参数设置可能不同,训练模型的方法也不尽相同,也不适用于前述内容中的传统联合学习。而使用本申请提供的联合学习系统,如图3B所示,可以将计算能力较强、肿瘤患者数据量最多的中心医院作为中心节点110,其他医院作为计算节点120。首先,中心医院可以提前使用中心医院的数据库(中心数据库111)中的肿瘤患者数据训练好多个第一初级模型,并将其加密;其次,每个医院可以根据自身的情况选择是否参加联合学习,确认参加联合学习的医院可以向中心医院获取第一初级模型,再根据各自本地数据库121中的肿瘤数据结合第一初级模型进行增能量学习获得第一中级模型,具体可以参考图3B中医院1描绘的增量学习的过程,同理,医院2~医院N根据第一初级模型和本地数据库121进行增量学习生成第一中级模型的过程与医院1相同,因此图3B中医院2~医院N并没有将该步骤再展开描述;接着,每个医院可以在空闲时间将第一中级模型加密后发送至中心医院;最后,中心医院可以通过模型融合算法将接收到的多个第一中级模型进行模型融合,从而得到模型性能比第一初级模型以及每个第一中级模型都要好的第一高级模型。可以理解的,由于中心节点通过对多个第一中级模型进行模型融合后获得的第一高级模型,因此解决了传统联合学习无法对异构模型进行联合学习的问题。应理解,上述举例仅用于说明,本申请还可以使用在其他多地区、多用户的异构模型训练的联合学习场景,本申请不作具体限定。
本申请提供的联合学习系统还可以应用于多地区、多用户、异构模型训练以及无通信带宽的场景。这里,无通信带宽指的是该场景下,计算节点与中心节点之间无法进行网络通信。举例来说,用于训练人脸识别网络的视频监控数据往往存储于各个公安局的本地数据库121中,而视频监控数据涉及到用户的身份信息,属于高度隐私数据,各地公安局可能存在不允许互相通信的情况,在这一应用场景下,考虑到用户隐私问题,无法将各地市公安局将本地的视频监控数据汇总在一起统一训练;同时各省市公安局合作的算法厂商也不相同,这就会导致每个公安局系统训练人脸识别网络的模型架构不同、超参数设置也不尽相同,因此也不适用于前述内容中的传统联合学习。而使用本申请提供的联合学习系统,如图3C所示,可以将计算能力较强、视频监控数据量最多的省公安厅作为中心节点110,其他市公安局(分局)或者派出所作为计算节点120。首先,省公安厅可以提前使用省公安厅的数据库(中心数据库111)中的视频监控数据训练好多个第一初级模型,并将其加密;其次,每个公安局可以派工作人员去省公安厅将第一初级模型刻录至光盘内,再将其取回至自己所属的公安局,使得每个公安局可以根据各自本地数据库121中的视频监控数据,结合该第一初级模型进行增能量学习获得第一中级模型,具体可以参考图3C中公安局1描绘的增量学习的过程,同理,公安局2~公安局N根据第一初级模型和本地数据库121进行增量学习生成第一中级模型的过程与公安局1相同,因此图3C中公安局2~公安局N并没有将该步骤再展开描述;接着,每个公安局可以将第一中级模型加密后派工作人员将其刻录至光盘内,并将刻录好的光盘送达至省公安厅;最后,省公安厅可以通过模型融合算法将接收到的多个第一中级模型进行模型融合,从而得到模型性能比第一初级模型以及每个第一中级模型都强的第一高级模型。可以理解的,传统的联合学习由于是计算节点传输梯度数据至中心节点,需要多次迭代直至误差收敛,因此计算节点需要传输很多次梯度数据至中心节点,如果在无通信带宽的情况下,将完全无法使用人工代替,而本申请提供的联合学习系统由于中心节点只向计算节点发送一次第一初级模型数据,计算节点只向中心节点发送一次第一中级模型数据,因此中心节点和计算节点之间的通信过程在无通信带宽的应用场景下,完全可以使用人工代替,从而解决了无通信带宽应用场景下无法进行联合学习的问题。应理解,上述举例仅用于说明,本申请还可以使用在其他多地区、多用户的异构模型训练的联合学习场景,本申请不作具体限定。
下面结合附图,对本申请提供的上述联合学习系统如何解决在遵守隐私与安全规定的前提下,使用分散的训练数据和计算资源对神经网络进行训练的问题,进行详细介绍。
如图4所示,本申请提供了一种联合学习方法,该方法应用于联合学习系统,所述联合学习系统包括一个或者多个计算节点以及一个或多个中心节点,所述计算节点与所述中心节点的地理位置不同,该方法包括以下步骤:
S301:中心节点110向计算节点120发送第一初级模型,计算节点120接收由中心节点发送的第一初级模型,其中,所述第一初级模型是所述中心节点110根据中心节点110的中心数据库111对神经网络进行训练后得到的。
在本申请实施例中,在步骤S301之前,中心节点可以使用本地的中心数据库111对神经网络进行训练,获得第一初级模型,这里的神经网络可以根据实际需要确定。举例来说,在图3C的应用场景下,中心节点(省公安厅)需要根据监控数据对该神经网络进行训练从而获得人脸识别模型(即第一初级模型),目前该方面在计算机视觉领域已经非常成熟,形成了各种人脸识别网络结构,比如Resnet、DenseNet、GoogleNet等,因此该神经网络可以根据实际需要进行选择,例如该神经网络可以是Resnet50,第一初级模型可以是使中心节点使用本地的中心数据库111对Resnet50进行训练获得的。应理解,上述举例仅用于说明,本申请不对待训练网络的具体结构进行限定。
下面以残差网络Resnet为例,对第一初级模型进行详细解释。Resnet是一种特殊的卷积神经网络(Convolutional Neuron Network,CNN),与CNN类似的,Resnet也包括卷积层、池化层、全连接层以及输出层,其中,卷积层用于提取图像特征,比如轮廓特征、颜色特征以及纹理特征等等;池化层则对输出图像的大小进行压缩,从而保证图像特征的高度紧凑;全连接层用于根据卷积层和池化层提取的图像特征,进行线性整合,输出用于分类的特征向量;输出层(或者softmax层)用于将全连接层提取的特征向量,输入预先设定的预测函数(Softmax)中,从而得到一个概率分布,确定最终的分类结果。
但是,Resnet与传统的CNN不同的是,全连接层之间除了逐层相连之外,例如第1层全连接层连接第2层全连接层,第2层全连接层连接第3层全连接层,第3层全连接层连接第4层全连接层(这是一条神经网络的数据运算通路,也可以形象的称为神经网络传输),Resnet还多了一条直连支路,这条直连支路从第1层全连接层直接连到第4层全连接层,即跳过第2层和第3层全连接层的处理,将第1层全连接层的数据直接传输给第4层全连接层进行运算。需要说明的,全连接层输出的特征向量可以是多个维度的特征向量,具体可以根据实际情况进行设置,比如输出256维的特征向量,后续计算过程中这个256维的特征向量就可以表示每一个人脸图像蕴含的信息,该256维的向量输入softmax层之后,即可获得确定输入的人脸图像所属类别的概率分布,比如五分类模型中,输出层的输出结果可以是[0.1,0.1,0.1,0.6,0.1],即预测结果显示该人脸图像所述的类别为第四类。应理解,上述举例仅用于说明,并不能构成具体限定。
因此,在上述结构的Resnet具体训练过程中,可以将中心数据库111中的每一个训练数据输入Resnet网络,依次经过卷积层、池化层、全连接层和输出层,获得每一个训练数据的预测结果。因为希望Resnet的预测结果尽可能的接近真正想要预测的值,所以可以通过比较当前网络的预测值和真正想要的目标值,再根据两者之间的差异情况来更新每一层神经网络的权重向量(当然,在第一次更新之前通常会有初始化的过程,即为Resnet中的各层预先配置参数),比如,如果Resnet的预测值高了,就调整权重向量让它预测低一些,不断的调整,直到Resnet能够预测出真正想要的目标值(即训练数据的真实标签)。因此,就需要预先定义“如何比较预测值和目标值之间的差异”,这便是损失函数(Loss Function)或目标函数(Objective Function),它们是用于衡量预测值和目标值的差异的重要方程。其中,以损失函数举例,损失函数的输出值(Loss)越高表示差异越大,那么在对Resnet的训练过程中,可以不断根据输出的预测结果与目标值之间LOSS,对权重参数进行调整,直至Loss缩小为某一阈值,从而完成该第一初级模型的训练。
在本申请实施例中,中心节点110可以将训练好的第一初级模型进行加密压缩处理后,再向计算节点120发送加密压缩处理后的第一初级模型。可以理解的,本申请提供的联合学习方法中,中心节点110与计算节点120之间传输模型数据而不是梯度数据,训练好的模型数据是无法复原出原始训练数据的,因此在步骤S301处可以直接将训练好的模型发送至各个计算节点120,也可以将训练好的模型使用加密算法和压缩算法对其进行加密压缩处理后再进行传输,从而进一步增强隐私保护功能,减少通信量,其中,加密算法可以是RC6加密算法、二进制加密算法等等,压缩算法可以是LZ4(Lempel-Ziv-4-algorithm)压缩算法、LZMA(Lempel-Ziv-Markov chain-Algorithm)压缩算法等等,本申请不作具体限定。
S302:所述计算节点120使用所述计算节点120的本地数据库121对所述第一初级模型进行增量学习,从而获得第一中级模型。
在本申请实施例中,增量学习可以在保持原有第一初级模型的基础上,额外学习本地数据库中的数据特征,从而达到提高第一初级模型的性能的目的,获得增量学习后的第一中级模型。增量学习使用的具体算法包括模型微调、知识蒸馏(KnowledgeDistillation,KD)等等,本申请不作具体限定。该步骤的具体实现方法将在下文中的步骤S3021-步骤S3023进行详细解释。
S303:计算节点120向中心节点110发送所述第一中级模型。
可以理解的,参考步骤S301可知,中心节点110与计算节点120之间传输模型数据而不是梯度数据,训练好的模型数据是无法复原出原始训练数据的,因此步骤S303处可以直接将训练好的第一中级模型发送至中心节点110,也可以将训练好的第一中级模型使用加密算法和压缩算法对其进行加密压缩处理后再进行传输,从而进一步增强隐私保护功能,减少通信量,其中,加密算法可以是RC6加密算法、二进制加密算法等等,压缩算法可以是LZ4压缩算法、LZMA压缩算法等等,本申请不作具体限定。
S304:中心节点110对接收到的多个所述第一中级模型进行模型融合从而获得第一高级模型。
具体实现中,模型融合可使用的具体算法可以包括多教师示范方法、投票(Voting)、回归(Averaging)、抽样(Bagging)、迭代(Boosting)、分层(Stacking)、等等,本申请不作具体限定。该步骤的具体实现方法将在下文中的步骤S3041-步骤S3045进行详细解释。
在本申请实施例中,参考前述内容可知,在中心节点是多个的情况下,步骤S301以及步骤S304可以是不同的两个中心节点执行的,具体地,所述中心节点包括第一中心节点和第二中心节点,所述第一中心节点和所述第二中心节点部署于不同地理地区,所述计算节点接收由中心节点发送的第一初级模型包括:所述计算节点接收由第一中心节点发送的第一初级模型;所述计算节点向所述中心节点发送所述第一中级模型包括:所述计算节点向所述第二中心节点发送所述第一中级模型。所述中心节点向计算节点发送第一初级模型包括:所述第一中心节点向所述计算节点发送第一初级模型;所述中心节点接收所述计算节点发送的第一中级模型包括:所述第二中心节点接收所述计算节点发送的第一中级模型;所述中心节点对接收到的多个所述第一中级模型进行模型融合从而获得第一高级模型包括:所述第二中心节点对接收到的多个所述第一中级模型进行模型融合从而获得所述第一高级模型。应理解,实际情况中,中心节点可以是一个节点,也可以是两个节点,具体可以根据应用场景的实际情况确定,本申请不对此进行具体限定。
在本申请实施例中,所述中心节点对接收到的多个所述第一中级模型进行模型融合从而获得第一高级模型之后,所述方法还包括:所述中心节点接收由所述计算节点发送的所述计算节点的所述本地数据库中存在新数据的消息;所述中心节点根据所述本地数据库中存在新数据的消息,向所述计算节点发送第二初级模型,其中,所述第二初级模型是所述第一高级模型;所述计算节点结合所述新数据对所述第二初级模型进行增量学习,获得第二中级模型;所述计算节点向所述中心节点发送所述第二中级模型;所述中心节点对接收到的多个所述第二中级模型进行模型融合,获得第二高级模型。简单来说,在计算节点的本地数据库中存在新数据的情况下,还可以将步骤S304中训练好的第一高级模型作为第二初级模型,重复步骤S301-步骤S304,从而获得性能更好的第二高级模型。
可以理解的,本申请提供的联合学习方法,解决了在遵守隐私与安全规定的前提下,使用分散的训练数据和计算资源对神经网络进行训练的问题,并且将计算压力分担在计算节点上,从而减轻中心节点压力,降低中心节点的故障率,提高联合学习的效率。整个联合学习的过程中心节点只需要向计算节点发送一次模型数据,计算节点也只需要向中心节点发送一次数据,通信量非常低,如果用计算节点数量S=50,迭代次数N=10000,模型/梯度数据大小M=50MB,参与训练的计算节点比例α=30%,数据压缩后占比β=50%,传统联合学习方法的理论通信量与本申请提供的联合学习方法的理论通信量可以如下表1所示:
表1传统联合学习方法以及本申请提供的联合学习方法的理论通信量
联合学习方法 | 传统 | 本申请 |
理论通信量公式 | T=2αSNβM | T=2αSβM |
一次完整联合学习所需的通信量 | 7.15TB | 2.44GB |
其中,一次完成联合学习所需的通信量指的是中心节点110将训练好的第一初级模型发送至计算节点120,以及计算节点120将增量学习后的第一中级模型发送至中心节点110,两次通信所需的总通信量。由上表可知,本申请提供的联合学习方法可以大大减少数据通信量,相比于传统的联合学习方法,本申请提供的联合学习方法的应用前景更加广泛。
下面以第一初级模型为Resnet,增量学习选择的方法为知识蒸馏为例,对步骤S302中计算节点120根据第一初级模型和本地数据库121进行增量学习的具体过程进行阐述说明。
在传统的神经网络训练过程中,将使用数据集对神经网络进行训练,训练过程中根据数据集的真实标签来对预测结果进行调整,因此可以理解为训练好的神经网络模型学到了数据集中的“知识”。而本申请实施例中,使用预先训练好的第一初级模型对本地数据库中的训练数据进行预测,获得每个训练数据的软目标,对第一中级网络进行训练的过程中,将会根据真实标签和软目标同时对第一中级网络的网络参数进行调整,因此可以理解为训练好的神经网络模型不仅学习到了数据集中的“知识”,也学习到了第一初级模型的“知识”,从而达到增量学习的目的。
因此,在本申请实施例中,知识蒸馏算法中使用的教师模型可以是中心节点110向计算节点120发送的第一初级模型,学生模型可以是计算节点进行增量学习后生成的第一中级模型,如图5所示,步骤S302可以分为以下步骤:
S3021:所述计算节点将所述本地数据库中的训练数据输入所述第一初级模型中的第一网络,获得第一网络的输出结果,其中,所述第一初级模型包括第一网络和第一预测函数,所述第一预测函数用于根据所述第一网络的输出结果生成所述第一初级模型的预测值。
参考前述内容可知,训练数据输入第一初级模型之后,将会依次通过输入层、卷积层、池化层以及全连接层,全连接层提取出特征向量之后,输入输出层中的第一预测函数(即softmax),即可获得最终的预测结果。因此,第一网络即可包括输入层、卷积层、池化层以及全连接层。而在步骤S3021,第一网络的输出结果即为第一初级模型的全连接层提取的特征向量。
S3022:所述计算节点将所述第一网络的输出结果输入软目标预测函数,获得所述本地数据库中的训练数据的软目标,其中,所述软目标预测函数是根据温度参数和所述第一预测函数得到的,所述温度参数用于使得所述第一网络的输出结果输入所述软目标预测函数得到的软目标,不大于所述第一网络的输出结果输入所述第一预测函数得到的所述第一初级模型的预测值。
可以理解的,计算节点120将第一初级模型的全连接层提取的特征向量输入软目标预测函数,获得每个训练数据的软目标,使得计算节点120的本地数据库121中的训练数据既包含对应的真实标签,也包含软目标。其中,软目标指的是教师网络使用软目标预测函数产生的输出结果,而所述软目标预测函数是根据温度参数T(Temperature)和所述第一预测函数得到的,所述软目标用于供所述第一中级网络学习所述第一初级网络的预测能力。可以理解的,使用知识蒸馏方法进行增量学习的情况下,本地数据库121中的训练数据中的每一数据均标注有真实标签,该真实标签可以是人工标注的,也可以是机器标注的,本申请不作具体限定。
其中,知识蒸馏中的温度参数T的含义可以用跑步的例子来解释,某运动员每次跑步均为负重跑步,那么在取下负重正常跑步的时候,就会非常轻松,也可以比其他运动员跑步速度更快。同理,温度参数T就是这个负重包,对于一个结构复杂的教师网络来说,训练结束后往往能够得到很好的学习效果,但是对于一个结构简单的学生网络来说,训练结束后往往无法得到很好的学习效果,因此,为了帮助结构简单的学生网络进行学习,可以在学生网络的softmax中增加一个温度参数T,加上这个温度参数以后,错误分类再经过softmax以后错误输出会被“放大”,正确分类会被“缩小”,也就是说,增大了训练的难度,一旦将T重新设置为1,分类结果会非常接近于老师网络的分类效果。而软目标预测函数即为包含温度参数T的softmax,软目标则指的是教师网络使用软目标预测函数产生的输出结果。
具体地,软目标预测函数的公式可以是:
其中,q为训练数据输入模型后使用软目标预测函数获得的软目标预测值,z为训练数据输入模型后使用softmax获得的预测值,T为预设的蒸馏学习温度参数。其中,softmax获得的预测值为概率分布。例如,一个5分类问题,输入的一张图片的真实分类结果为第4类,那么这张图片的真实标签可以是y=[0,0,0,1,0],当模型的softmax获得的预测值z=[0.1,0.15,0.05,0.6,0.1]时,如果T=5,软目标预测值可以是z=[0.5,0.51,0.5,0.53,0.5]。因此,在softmax中增加温度参数T后,得到的软目标预测值q相比softmax获得的预测值z的概率分布更缓和、均匀,而神经网络是根据预测值与真实标签之间的差距来调整网络模型的权重参数的,因此使用第一初级模型(教师模型)预测的软目标指导第一中级模型(学生模型)的训练,可以增大训练的难度,使得训练好的学生模型用不包含温度参数T的softmax进行预测时,输出值的准确率可以更高。上述举例仅用于说明,并不能构成具体限定。
S3023:计算节点120使用所述本地数据库121对第一中级网络进行训练,根据所述软目标以及所述本地数据库中的训练数据的真实标签得到混合损失函数,其中,所述第一中级网络的网络类型与所述神经网络的网络类型相同。
具体地,所述第一中级网络包括第二网络和第二预测函数,所述第二预测函数用于根据所述第二网络的输出结果生成所述第一中级网络的预测值,所述混合损失函数是根据第一损失函数以及第二损失函数确定的,所述第一损失函数是所述计算节点将所述本地数据库中的训练数据输入所述第一中级网络的所述第二网络获得第二网络的输出结果后,将所述第二网络的输出结果输入所述软目标预测函数生成的软目标预测值与软目标之间的差距确定的,所述第二损失函数是将所述第二网络的输出结果输入所述第二预测函数生成的预测值与真实标签之间的差距确定的。
具体实现中,该中级网络与第一初级模型的网络类型相同,但是网络结构可能不同。其中,所述网络类型可以是多层感知机(Multi-Layer Perceptron,MLP)、卷积神经网络(Convolutional Neural Networks,CNN)、循环神经网络(Recurrent Neural Network,RNN)、图神经网络(Graph Neural Network,GNN)以及深度信念网络(Deep Belief Nets,DBN)等等,本申请不作具体限定。不同的网络结构可以是指第一中级网络与第一初级模型虽然是同一种网络类型(比如都是CNN网络),但是第一中级网络与第一初级模型拥有不同的网络层数、不同的通道(Channel)数、不同优化方法选择、不同的训练逻辑、不同的学习率、不同的迭代次数或者不同的正则化参数等等,本申请也不对此进行具体限定。
可以理解的,第一中级网络作为学生网络,其网络结构相比于第一初级模型可以更加简单,比如网络层数更少,迭代次数更少等等,例如:第一初级模型(教师模型)可以是用于人脸识别的残差神经网络Resnet101,第一中级模型(学生模型)可以是用于人脸检测的Resnet50,其中,第一初级模型的网络层数为101,第一中级模型的网络层数为50。可以理解的,选择结构更加简单的第一中级模型对第一初级模型和本地数据库进行增量学习,可以大大减少计算节点的处理压力,使得本申请提供的联合学习方法可以在移动终端上对第一初级模型进行增量学习,获得第一中级模型,应用场景更加广泛。
因此步骤S3023的具体流程可以如图5所示,首先将本地数据库中的数据输入第一中级网络的第二网络生成第二网络的输出结果(即第一中级网络的全连接层提取的特征向量),并将其输入包含温度参数T的软目标预测函数生成软目标预测值,根据该软目标预测值与软目标之间的差距生成第一损失函数L1,再根据使用不包含温度参数T的softmax生成预测值,根据该预测值与真实标签之间的差距生成第二损失函数L2,最后根据第一损失函数L1以及第二损失函数L2确定第一中级网络训练时的混合损失函数L。其中,步骤S3023的温度参数T和步骤S3021的温度参数T是相同的。具体实现中,混合损失函数L可以是第一损失函数L1以及第二损失函数L2的加权平均,具体可以如公式(2)所示:
L=γL1+(1-γ)L2 (2)
其中,γ为加权系数,加权系数γ越大,根据软目标预测值与软目标之间的差距生成的第一损失函数L1对于混合损失函数L的影响越大,因此在训练过程中,可以在训练初期选择较大的加权系数γ,有助于让学生模型能够快速模仿教师模型来鉴别样本,在训练后期选择较小的加权系数γ,让真实标签帮助鉴别困难样本,从而在提高学生模型的训练速度的同时,提高学生模型的模型性能。
S3024:所述计算节点根据混合损失函数对第一中级网络进行反向传播,获得训练好的第一中级模型.
在本申请实施例中,根据混合损失函数L对所述第一中级网络进行反向传播,根据所述混合损失函数L调整所述第一中级网络的模型参数,直到混合损失函数L的输出值达到预设的阈值,从而获得增量学习好的第一中级模型。
可以理解的,计算节点120通过上述知识蒸馏的方法进行增量学习,用于控制第一中级网络的权重参数调整的混合损失函数是根据第一损失函数和第二损失函数确定的,第一损失函数是根据第一初级模型预测的软目标确定的,第二损失函数是根据本地数据库的真实标签确定的,因此第一中级模型的训练过程中既学习到了教师网络的“知识”,也学习到了本地数据集的“知识”,可以达到增量学习的目的。并且,计算节点可以选择结构更加简单的第一中级网络作为学生网络进行知识蒸馏,可以大大减少计算节点的处理压力,使得本申请提供的联合学习方法可以拥有更广的应用范围,甚至可以在移动终端上对第一初级模型进行增量学习,获得训练好的第一中级模型。而传统的联合学习只有中心节点对模型进行训练,因此中心节点的计算压力非常大,本申请提供的联合学习方法还可以将计算压力分担在计算节点上,从而减轻中心节点压力,降低中心节点的故障率,提高联合学习的效率。
下面以第一初级模型为Resnet,增量学习选择的方法为知识蒸馏,模型融合选择的方法为多教师示教融合方法为例,对步骤S304的模型融合步骤进行详细解释。
多教师示教融合方法通过将多个增量学习后的第一中级模型的全连接层提取的特征向量拼接为一个特征向量向量,获得多教师融合模型,再将多教师融合模型作为知识蒸馏算法中的教师模型,第一高级网络作为知识蒸馏算法中的学生网络,对第一高级网络进行知识蒸馏,从而获得训练好的第一高级模型。训练过程中不仅使用数据集的真实标签来对第一高级模型的预测结果进行调整,还会使用多教师融合模型对数据集进行预测获得每个数据的软目标,对第一高级模型的预测结果进行调整,因此训练好的第一高级模型不仅学习到了数据集中的“知识”,也学习到了多个教师模型(这是是多个第一中级模型)的“知识”,从而达到模型融合的目的。
如图6所示,步骤S302中对多个第一中级模型进行模型融合获得第一高级模型的具体过程,可以分为以下步骤:
S3041:中心节点110在多个第一中级模型中筛选出基础模型和m个正向提升模型。其中,所述基础模型是所述中心节点将所述中心数据库中的测试数据输入所述多个第一中级模型后,所述多个第一中级模型中预测准确率最高的模型,所述正向提升模型是对所述基础模型有指导作用的模型。
可以理解的,如果正向提升模型能够指导第一中级模型里预测准确率最高的基础模型进行训练,并使得基础模型的预测准确率得到进一步提升,那么正向提升模型一定可以指导未经过训练集的训练、仅为简单的神经网络结构甚至不具备模型预测能力的第一高级网络进行训练,并使得第一高级网络的预测准确率能够得到提升。因此,所述正向提升模型的获取方法可以如图7所示:
首先,中心节点110将其他第一中级模型(比如图6中的第i个第一中级模型)的x维特征向量与基础模型的y维特征向量进行水平方向的拼接,获得拼接后特征向量,该特征向量为x+y维的特征向量,其中,特征向量可以是全连接层提取的特征向量,该特征向量用于输入softmax从而输出最终的概率分布。举例来说,如果第i个第一中级模型的全连接层提取的特征向量为4维的特征向量(1,1,1,1),基础模型的全连接层提取的特性向量为8维的特征向量(1,2,3,4,5,6,7,8),那么两个特征向量进行水平方向的拼接后,将会获得12维的特征向量(1,1,1,1,1,2,3,4,5,6,7,8)。
其次,对拼接后特征向量进行降维处理,使得拼接后的特征向量的维度与基础模型提取的特征向量的维度一致。换句话说,将拼接后的x+y维的特征向量降为处理获得x维特征向量。具体实现中,可以通过主成分分析算法(Principal Components Analysis,PCA)对拼接后的特征向量进行降维处理,还可以通过其他降维算法,本申请不作具体限定。
最后,将拼接后的特征向量输入基础模型的softmax,得到拼接后的预测结果,根据所述预测结果确定所述第x个第一中级模型与所述基础模型进行特征拼接后的预测准确率,在所述特征拼接后的预测准确率高于所述基础模型的预测准确率的情况下,将所述第x个第一中级模型确认为正向提升模型,依次类推,即可从所述多个第一中级模型中筛选出符合条件的m个正向提升模型。
应理解,图7所示的正向提升模型确认方法仅用于举例说明,具体实现中,还可以使用其他方法对多个第一中级模型进行筛选,确定出可以作为所述第一高级网络的教师模型的正向提升模型,本申请不作具体限定。
需要说明的,步骤S3041用于在第一中级模型中筛选出m个适合作为老师模型对第一高级网络进行知识蒸馏的模型,使得多教师融合模型的结构可以得到缩减,减少第一高级模型训练所需的时间。具体实现中,步骤S3041也可以被省略,直接使用多个第一中级模型作为正向提升模型执行步骤S3042,具体可以根据实际情况确认步骤S3041是否被执行,本申请不作具体限定。
S3042:根据m个正向提升模型,获得多教师融合模型,其中,所述多教师融合模型的全连接层提取的特征向量是所述m个正向提升模型的全连接层提取的特征向量进行拼接并降维处理后得到的,所述多教师融合模型的全连接层提取的特征向量与所述第一高级模型的特征向量维度相同。
举例来说,如图8所示,多教师融合模型可以包括m个正向提升模型,该多教师融合模型可以将第一个正向提升模型的全连接层提取的x1维特征向量,第二个正向提升模型的全连接层提取的x2维特征向量,…,以及第m个正向提升模型的全连接层提取的xm维特征向量进行水平方向的拼接,从而获得一个(x1+x2+…+xm)维的特征向量,该多教师融合模型将根据所述(x1+x2+…+xm)维的特征向量确定输入数据的预测结果。应理解,由m个正向提升模型组成的多教师融合模型,在接下来步骤S3043-步骤S3045的知识蒸馏步骤中,将会作为老师模型对第一高级网络进行知识蒸馏,从而获得最终的第一高级模型,这样可以使得训练好的第一高级模型不仅学习到了数据集中的“知识”,也学习到了这m个正向提升模型的“知识”,从而达到模型融合的目的。
S3043:使用多教师融合模型对中心数据库111进行预测,获得每个训练数据的软目标。其中,软目标的定义以及确定方法可以参考前述实施例中的步骤S3021,这里不再展开赘述。
S3044:使用中心数据库对基础模型进行训练,得到混合损失函数L,具体可以参考前述内容的步骤S3022,这里不再展开赘述。
S3045:根据混合损失函数L对第一高级网络进行反向传播,获得训练好的第一高级模型,具体可以参考前述内容的步骤S3023,这里不再展开赘述。
举例来说,如果联合学习系统只包括一个中心节点H和一个计算节点A,中心节点H的中心数据库中包括67万ID,图像共5000万张,计算节点A的本地数据库中包括62万ID,图像共600万张,首先使用本地数据库训练好第一初级模型H,模型结构可以选择前述内容中的Resnet网络;然后将第一初级模型H发送至计算节点A,计算节点A可以根据本地数据库对第一初级模型H进行增量学习(比如前述内容中步骤S3021-步骤S3024描述的的知识蒸馏的方法),获得第一中级模型A;接着计算节点A可以将增量学习后的第一中级模型A发送至中心节点H;最后中心节点H可以对第一中级模型A和第一初级模型H进行模型融合(比如前述内容中步骤S3041-步骤S3025的多教师示教融合方法),获得模型融合好的模型H’。使用节点H和节点A的测试集对第一初级模型H、第一中级模型A以及第一高级模型H’进行测试后,测试结果如下表2所示:
表2基于1个中心节点和1个计算节点的人脸识别模型的测试
模型 | 中心节点H测试结果 | 计算节点A测试结果 | 平均结果 |
模型H | 92.29% | 93.06% | 92.68% |
模型A | 91.72% | 94.38% | 93.05% |
模型H’ | 92.63% | 93.56% | 93.10% |
由上表可知,本申请提供的联合学习方法,不但可以在低通信代价的情况下,解决了在遵守隐私与安全规定的前提下,使用分散的训练数据和计算资源对神经网络进行训练的问题,而且训练好的神经网络的模型性能可以得到提升。
综上可知,使用本申请提供的联合学习方法,不但解决了在遵守隐私与安全规定的前提下,使用分散的训练数据和计算资源对神经网络进行训练的问题,而且整个联合学习的过程中,中心节点最少可以向计算节点发送一次第一初级模型数据,计算节点最少可以向中心节点发送一次第一中级模型数据即可完成一次联合学习的过程,使得整个联合学习的过程占用的通信资源数量极少。并且,由于服务端通过模型融合算法得到联合学习后的模型,因此每一个计算节点在进行增量学习时,用于结合本地数据库进行训练的第一中级网络结构可以不作任何限制,使得本申请适用的应用场景更加广泛。
上述详细阐述了本申请实施例的方法,为了便于更好的实施本申请实施例上述方案,相应地,下面还提供用于配合实施上述方案的相关设备。
图9是本申请提供的一种计算节点800的结构示意图,其中,该计算节点800应用于图2所示的联合学习系统100,该计算节点800可以是前述内容中的计算节点120,所述计算节点包括接收单元810、学习单元820、发送单元830以及本地数据库840,其中,
所述接收单元810用于接收由中心节点发送的第一初级模型,其中,所述第一初级模型是所述中心节点根据中心节点的中心数据库对神经网络进行训练后得到的;
所述学习单元820用于使用所述计算节点的本地数据库840对所述第一初级模型进行增量学习从而获得第一中级模型,其中,学习单元820可以是图2实施例中的增量学习模块122,本地是数据库840可以是图2实施例中的本地数据库121;
所述发送单元830用于向所述中心节点发送所述第一中级模型,使得所述中心节点对接收到的多个所述第一中级模型进行模型融合从而获得第一高级模型。
在本申请实施例中,所述学习单元820用于将所述本地数据库中的训练数据输入所述第一初级模型中的第一网络,获得第一网络的输出结果,其中,所述第一初级模型包括第一网络和第一预测函数;所述学习单元820用于将所述第一网络的输出结果输入软目标预测函数,获得所述本地数据库中的训练数据的软目标,其中,将所述第一网络的输出结果输入软目标预测函数,获得所述本地数据库中的训练数据的软目标,其中,所述软目标预测函数是根据温度参数和所述第一预测函数得到的,所述温度参数用于使得所述第一网络的输出结果输入所述软目标预测函数得到的软目标,不大于所述第一网络的输出结果输入所述第一预测函数得到的所述第一初级模型的预测值;所述学习单元820用于使用所述本地数据库对第一中级网络进行训练,根据所述软目标以及所述本地数据库中的训练数据的真实标签得到混合损失函数,其中,所述第一中级网络的网络类型与所述神经网络的网络类型相同;所述学习单元820用于根据所述混合损失函数对所述第一中级网络进行反向传播,获得训练好的第一中级模型。
在本申请实施例中,所述第一中级网络包括第二网络和第二预测函数,所述第二预测函数用于根据所述第二网络的输出结果生成所述第一中级网络的预测值,所述混合损失函数是根据第一损失函数以及第二损失函数确定的,所述第一损失函数是所述计算节点将所述本地数据库中的训练数据输入所述第一中级网络的所述第二网络获得第二网络的输出结果后,将所述第二网络的输出结果输入所述软目标预测函数生成的软目标预测值与软目标之间的差距确定的,所述第二损失函数是将所述第二网络的输出结果输入所述第二预测函数生成的预测值与所述真实标签之间的差距确定的。
在本申请实施例中,所述第一高级模型是所述中心节点根据所述多个第一中级模型获得多教师融合模型后,使用所述中心节点的所述中心数据库对所述多教师融合模型进行增量学习后获得的,其中,所述多教师融合模型的全连接层提取的特征向量是所述多个第一中级模型的全连接层提取的特征向量进行拼接并降维处理后得到的,所述多教师融合模型的全连接层提取的特征向量与所述第一高级模型的特征向量的维度相同。
在本申请实施例中,所述多教师融合模型的全连接层提取的特征向量是m个正向提升模型的全连接层提取的特征向量进行拼接并降维处理后得到的,所述m个正向提升模型是根据所述多个第一中级模型得到的,所述正向提升模型是对基础模型有指导作用的模型,所述基础模型是将所述中心数据库中的测试数据输入所述多个第一中级模型后,所述多个第一中级模型中预测准确率最高的模型。
在本申请实施例中,所述接收单元810还用于在向所述中心节点发送所述第一中级模型,使得所述中心节点对接收到的多个所述第一中级模型进行模型融合,获得第一高级模型之后,所述接收单元在所述本地数据库中存在新数据的情况下,接收由中心节点发送的第二初级模型,其中,所述第二初级模型是所述第一高级模型;所述学习单元820还用于结合所述新数据对所述第二初级模型进行增量学习,获得第二中级模型;所述发送单元830还用于向所述中心节点发送所述第二中级模型,使得所述中心节点对接收到的多个所述更新后的第二中级模型进行模型融合,获得第二高级模型。
在本申请实施例中,所述一个或多个中心节点包括第一中心节点和第二中心节点,所述接收单元810还用于接收由第一中心节点发送的第一初级模型;所述发送单元830还用于向所述第二中心节点发送所述第一中级模型,使得所述第二中心节点对接收到的多个所述第一中级模型进行模型融合从而获得第一高级模型。
图10是本申请提供的一种中心节点900的结构示意图,其中,该中心节点900可以是前述内容中的中心节点110,如图10所示,中心节900可以应用于如图2所示的联合学习系统100,所述中心节点900包括发送单元910、接收单元920,融合单元930,以及中心数据库940,其中,
所述发送单元910用于向所述计算节点发送第一初级模型,其中,所述第一初级模型是所述中心节点根据中心节点的中心数据库940对神经网络进行训练后得到的;
所述接收单元920用于接收所述计算节点发送的第一中级模型,其中,所述第一中级模型是所述计算节点使用所述计算节点的本地数据库对所述第一初级模型进行增量学习后获得的;
所述融合单元930用于对接收到的多个所述第一中级模型进行模型融合从而获得第一高级模型。其中,融合单元930可以是前述图2实施例中的模型融合模块113,中心数据库940可以是前述内容中的中心数据库111。
在本申请实施例中,所述第一中级模型是所述计算节点将所述本地数据库中的训练数据输入所述第一初级模型中的第一网络获得第一网络的输出结果,并将所述第一网络的输出结果输入软目标预测函数获得所述本地数据库中的训练数据的软目标之后,使用所述本地数据库对第一中级网络进行训练,根据所述软目标以及所述本地数据库中的训练数据的真实标签得到混合损失函数,根据所述混合损失函数对所述第一中级网络进行反向传播从而获得的所述第一中级模型,其中,所述第一初级模型包括第一网络和第一预测函数,所述第一预测函数用于根据所述第一网络的输出结果生成所述第一初级模型的预测值,所述软目标预测函数是根据温度参数和所述第一预测函数得到的,所述温度参数用于使得所述第一网络的输出结果输入所述软目标预测函数得到的软目标,不大于所述第一网络的输出结果输入所述第一预测函数得到的所述第一初级模型的预测值。
在本申请实施例中,所述第一中级网络包括第二网络和第二预测函数,所述第二预测函数用于根据所述第二网络的输出结果生成所述第一中级网络的预测值,所述混合损失函数是根据第一损失函数以及第二损失函数确定的,所述第一损失函数是所述计算节点将所述本地数据库中的训练数据输入所述第一中级网络的所述第二网络获得第二网络的输出结果后,将所述第二网络的输出结果输入所述软目标预测函数生成的软目标预测值与软目标之间的差距确定的,所述第二损失函数是将所述第二网络的输出结果输入所述第二预测函数生成的预测值与所述真实标签之间的差距确定的。
在本申请实施例中,所述融合单元930用于根据所述多个第一中级模型,获得多教师融合模型,其中,所述多教师融合模型的全连接层提取的特征向量是所述多个第一中级模型的全连接层提取的特征向量进行拼接并降维处理后得到的,所述多教师融合模型的全连接层提取的特征向量与第一高级模型的特征向量维度相同;所述融合单元930用于使用所述中心数据库940对所述多教师融合模型进行增量学习,获得所述第一高级模型。
在本申请实施例中,所述融合单元930用于在多个第一中级模型中筛选出基础模型和m个正向提升模型,其中,所述基础模型是所述中心节点将所述中心数据库940中的测试数据输入所述多个第一中级模型后,所述多个第一中级模型中预测准确率最高的模型,所述正向提升模型是对所述基础模型有指导作用的模型;所述融合单元930用于根据所述m个正向提升模型,获得所述多教师融合模型,其中,所述多教师融合模型的全连接层提取的特征向量是所述m个正向提升模型的全连接层提取的特征向量进行拼接并降维处理后得到的,所述多教师融合模型的全连接层提取的特征向量与所述第一高级模型的特征向量维度相同。
在本申请实施例中,所述接收单元920还用于在点对接收到的多个所述第一中级模型进行模型融合从而获得第一高级模型之后,接收由所述计算节点发送的所述计算节点的所述本地数据库中存在新数据的消息;所述发送单元910还用于根据所述本地数据库中存在新数据的消息,向所述计算节点发送第二初级模型,使得所述计算节点结合所述新数据对所述第二初级模型进行增量学习,获得第二中级模型,其中,所述第二初级模型是所述高级模型;所述接收单元920还用于接收所述计算节点发送的所述第二中级模型;所述融合单元930还用于对接收到的多个所述第二中级模型进行模型融合,获得第二高级模型。
在本申请实施例中,所述中心节点900包括第一中心节点和第二中心节点,所述第一中心节点和所述第二中心节点部署于不同地理地区,所述第一中心节点包括所述发送单元910,所述第二中心节点包括所述接收单元920和所述融合单元930。
可以理解的,使用本申请提供的联合学习方法,不但解决了在遵守隐私与安全规定的前提下,使用分散的训练数据和计算资源对神经网络进行训练的问题,而且整个联合学习的过程中,中心节点只需要向计算节点发送一次第一初级模型数据,计算节点只需要向中心节点发送一次第一中级模型数据,整个联合学习的过程占用的通信资源数量极少。并且,由于服务端通过模型融合算法得到联合学习后的模型,因此每一个计算节点在进行增量学习时,用于结合本地数据库进行训练的第一中级网络结构可以不作任何限制,使得本申请适用的应用场景更加广泛。
参见图11,图11是本申请提供的一种芯片硬件结构图,该芯片为NPU芯片,其中,图2-图8实施例中关于神经网络的算法(比如第一初级模型的训练、第一中级模型的知识蒸馏以及第一高级模型的模型融合算法等等)可以在图11所示的NPU芯片中实现。
如图11所示,NPU 50作为协处理器挂载到主CPU(Host CPU)上,由Host CPU分配任务。NPU的核心部分为运算电路50,通过控制器504控制运算电路503提取存储器中的矩阵数据并进行乘法运算。
在一些实现中,运算电路503内部包括多个处理单元(Process Engine,PE)。在一些实现中,运算电路503是二维脉动阵列。运算电路503还可以是一维脉动阵列或者能够执行例如乘法和加法这样的数学运算的其它电子线路。在一些实现中,运算电路503是通用的矩阵处理器。
举例来说,假设有输入矩阵A,权重矩阵B,输出矩阵C。运算电路从权重存储器502中取矩阵B相应的数据,并缓存在运算电路中每一个PE上。运算电路从输入存储器501中取矩阵A数据与矩阵B进行矩阵运算,得到的矩阵的部分结果或最终结果,保存在累加器(Accumulator)508中。
统一存储器506用于存放输入数据以及输出数据。权重数据直接通过存储单元访问控制器(Direct Memory Access Controller,DMAC)505,DMAC 505被搬运到权重存储器502中。输入数据也通过DMAC 505被搬运到统一存储器506中。
总线接口单元(Bus Interface Unit,BIU)510用于AXI(Advanced eXtensibleInterface)总线与DMAC 505和取指存储器(Instruction Fetch Buffer,IFB)509的交互。具体用于取指存储器509从外部存储器获取指令,还用于存储单元访问控制器505从外部存储器获取输入矩阵A或者权重矩阵B的原数据。
DMAC主要用于将外部存储器DDR中的输入数据搬运到统一存储器506或将权重数据搬运到权重存储器502中或将输入数据数据搬运到输入存储器501中。
向量计算单元507多个运算处理单元,在需要的情况下,对运算电路的输出做进一步处理,如向量乘,向量加,指数运算,对数运算,大小比较等等。主要用于神经网络中非卷积/FC层网络计算,如Pooling(池化),Batch Normalization(批归一化),Local ResponseNormalization(局部响应归一化)等。
在一些实现种,向量计算单元能507将经处理的输出的向量存储到统一缓存器506。例如,向量计算单元507可以将非线性函数应用到运算电路503的输出,例如累加值的向量,用以生成激活值。在一些实现中,向量计算单元507生成归一化的值、合并值,或二者均有。在一些实现中,处理过的输出的向量能够用作到运算电路503的激活输入,例如用于在神经网络中的后续层中的使用。
控制器504连接的取指存储器509,用于存储控制器504使用的指令;
统一存储器506,输入存储器501,权重存储器502以及取指存储器509均为嵌入式(On-Chip)存储器。外部存储器私有于该NPU硬件架构。
其中,图2–图8实施例中,第一初级模型、第一中级模型以及第一高级模型的网络各层的运算可以由矩阵计算单元212或向量计算单元507执行。
图12为本申请实施例提供的一种电子设备1200的结构示意图。其中,所述电子设备1200可以是前述内容中的中心节点或者计算节点。如图12所示,电子设备1200包括:处理器1210、通信接口1220、存储器1230、神经网络处理器1240以及总线1250.。其中,处理器1210、通信接口1220存储器1230以及神经网络处理器1240可以通过内部总线1250相互连接,也可通过无线传输等其他手段实现通信。本申请实施例以通过总线1250连接为例,总线1250可以是外设部件互连标准(Peripheral Component Interconnect,PCI)总线或扩展工业标准结构(Extended Industry Standard Architecture,EISA)总线等。所述总线1250可以分为地址总线、数据总线、控制总线等。为便于表示,图12中仅用一条粗线表示,但并不表示仅有一根总线或一种类型的总线。
所述处理器1210可以由一个或者多个通用处理器构成,例如中央处理器(CentralProcessing Unit,CPU),或者CPU和硬件芯片的组合。上述硬件芯片可以是专用集成电路(Application-Specific Inegrated Circuit,ASIC)、可编程逻辑器件(ProgrammableLogic Device,PLD)或其组合。上述PLD可以是复杂可编程逻辑器件(ComplexProgrammable Logic Device,CPLD)、现场可编程逻辑门阵列(Field-Programmable GateArray,FPGA)、通用阵列逻辑(Generic Array Logic,GAL)或其任意组合。处理器1210执行各种类型的数字存储指令,例如存储在存储器1230中的软件或者固件程序,它能使电子设备1200提供较宽的多种服务。
在电子设备1200是前述内容中的中心节点的情况下,所述处理器1210可以包括融合单元,该融合单元可以通过调用存储器1230中的程序代码以实现处理功能,包括图2中的模型融合模块113所描述的功能以及图10中的融合单元930所描述的功能,例如根据多个第一中级模型进行模型融合,从而获得第一高级模型,具体可用于执行前述方法的S3041-步骤S3045及其可选步骤,还可以用于执行图6-图8实施例描述的其他步骤,这里不再进行赘述。
在电子设备1200是前述内容中的计算节点的情况下,所述处理器12100可以包括学习单元,该学习单元可以通过调用存储器1230中的程序代码以实现处理功能,包括图2中的增量学习模块122所描述的功能以及图9中学习单元820所描述的功能,例如根据本地数据库对第一初级模型进行增量学习,从而获得第一中级模型,具体可用于执行前述方法的步骤S3021-步骤S3023及其可选步骤,还可以用于执行图4-图8实施例描述的其他步骤,这里不再进行赘述。
所述存储器1230可以包括易失性存储器(Volatile Memory),例如随机存取存储器(Random Access Memory,RAM);存储器1230也可以包括非易失性存储器(Non-VolatileMemory),例如只读存储器(Read-Only Memory,ROM)、快闪存储器(Flash Memory)、硬盘(Hard Disk Drive,HDD)或固态硬盘(Solid-State Drive,SSD);存储器1230还可以包括上述种类的组合。
在电子设备1200是前述内容中的中心节点的情况下,存储器1230可以存储有图2实施例中的中心数据库111或者图10实施例中的中心数据库940,中心数据库中包括训练数据和测试数据,训练数据用于训练第一初级模型,用于根据多个第一中级模型进行模型融合从而获得第一高级模型等等,测试数据用于对中级模型进行测试,从而获得预测准确率最高的基础模型等等;存储器1230还可以存储有程序代码。程序代码可以是训练第一初级模型的代码、压缩/解压缩模型的代码、对多个第一中级模型进行模型融合的代码等等,还可以包括其他用于执行图2-图8实施例描述的其他步骤的程序代码,这里不再进行赘述。
在电子设备1200是前述内容中的计算节点的情况下,存储器1230可以存储有图2实施例中的本地数据库121或者图9实施例中的本地数据库840,本地数据库中包括训练数据,用于对第一初级模型进行增量学习从而获得第一中级模型等等;存储器1230还可以存储有应用程序代码。程序代码可以是增量学习的代码,压缩/解压缩模型的代码等等,还可以包括其他用于执行图2-图8实施例描述的其他步骤的程序代码,这里不再进行赘述。
神经网络处理器1240可以用于通过存储器1230的程序代码以及中心数据库或者本地数据库中的训练数据执行各种神经网络的算法(比如第一初级模型的训练、第一中级模型的知识蒸馏以及第一高级模型的模型融合算法等等),以执行本文讨论的方法的至少一部分。其中,神经网络处理器1240的硬件结构具体可参考图11,这里不再进行赘述。
通信接口1220可以为有线接口(例如以太网接口),可以为内部接口(例如高速串行计算机扩展总线(Peripheral Component Interconnect express,PCIe)总线接口)、有线接口(例如以太网接口)或无线接口(例如蜂窝网络接口或使用无线局域网接口),用于与与其他设备或模块进行通信。
需要说明的,图12仅仅是本申请实施例的一种可能的实现方式,实际应用中,所述电子设备还可以包括更多或更少的部件,这里不作限制。关于本申请实施例中未示出或未描述的内容,可参见前述图2-图8所述实施例中的相关阐述,这里不再赘述。
应理解,图12所示的电子设备还可以是多个服务器构成的计算机集群,本申请不作具体限定。
本申请实施例还提供一种计算机可读存储介质,所述计算机可读存储介质中存储有指令,当其在处理器上运行时,图2-图8所示的方法流程得以实现。
本申请实施例还提供一种计算机程序产品,当所述计算机程序产品在处理器上运行时,图2-图8所示的方法流程得以实现。
上述实施例,可以全部或部分地通过软件、硬件、固件或其他任意组合来实现。当使用软件实现时,上述实施例可以全部或部分地以计算机程序产品的形式实现。所述计算机程序产品包括一个或多个计算机指令。在计算机上加载或执行所述计算机程序指令时,全部或部分地产生按照本发明实施例所述的流程或功能。所述计算机可以为通用计算机、专用计算机、计算机网络、或者其他可编程装置。所述计算机指令可以存储在计算机可读存储介质中,或者从一个计算机可读存储介质向另一个计算机可读存储介质传输,例如,所述计算机指令可以从一个网站站点、计算机、服务器或数据中心通过有线(例如同轴电缆、光纤、数字用户线(Digital Subscriber Line,DSL))或无线(例如红外、无线、微波等)方式向另一个网站站点、计算机、服务器或数据中心进行传输。所述计算机可读存储介质可以是计算机能够存取的任何可用介质或者是包含一个或多个可用介质集合的服务器、数据中心等数据存储设备。所述可用介质可以是磁性介质(例如,软盘、硬盘、磁带)、光介质(例如,高密度数字视频光盘(Digital Video Disc,DVD)、或者半导体介质。半导体介质可以是SSD。
以上所述,仅为本发明的具体实施方式,但本发明的保护范围并不局限于此,任何熟悉本技术领域的技术人员在本发明揭露的技术范围内,可轻易想到各种等效的修改或替换,这些修改或替换都应涵盖在本发明的保护范围之内。因此,本发明的保护范围应以权利要求的保护范围为准。
Claims (29)
1.一种联合学习的方法,其特征在于,应用于联合学习系统,所述联合学习系统包括一个或者多个计算节点以及一个或多个中心节点,所述方法包括:
计算节点接收由中心节点发送的第一初级模型,其中,所述第一初级模型是所述中心节点根据中心节点的中心数据库对神经网络进行训练后得到的;
所述计算节点使用所述计算节点的本地数据库对所述第一初级模型进行增量学习从而获得第一中级模型;
所述计算节点向所述中心节点发送所述第一中级模型,使得所述中心节点对接收到的多个所述第一中级模型进行模型融合从而获得第一高级模型。
2.根据权利要求1所述的方法,其特征在于,所述计算节点使用所述计算节点的本地数据库对所述第一初级模型进行增量学习从而获得第一中级模型包括:
所述计算节点将所述本地数据库中的训练数据输入所述第一初级模型中的第一网络,获得第一网络的输出结果,其中,所述第一初级模型包括第一网络和第一预测函数,所述第一预测函数用于根据所述第一网络的输出结果生成所述第一初级模型的预测值;
所述计算节点将所述第一网络的输出结果输入软目标预测函数,获得所述本地数据库中的训练数据的软目标,其中,所述软目标预测函数是根据温度参数和所述第一预测函数得到的,所述温度参数用于使得所述第一网络的输出结果输入所述软目标预测函数得到的软目标,不大于所述第一网络的输出结果输入所述第一预测函数得到的所述第一初级模型的预测值;
所述计算节点使用所述本地数据库对第一中级网络进行训练,根据所述软目标以及所述本地数据库中的训练数据的真实标签得到混合损失函数,其中,所述第一中级网络的网络类型与所述神经网络的网络类型相同;
所述计算节点根据所述混合损失函数对所述第一中级网络进行反向传播,获得所述第一中级模型。
3.根据权利要求2所述的方法,其特征在于,所述第一中级网络包括第二网络和第二预测函数,所述第二预测函数用于根据所述第二网络的输出结果生成所述第一中级网络的预测值,所述混合损失函数是根据第一损失函数以及第二损失函数确定的,所述第一损失函数是所述计算节点将所述本地数据库中的训练数据输入所述第一中级网络的所述第二网络获得第二网络的输出结果后,将所述第二网络的输出结果输入所述软目标预测函数生成的软目标预测值与软目标之间的差距确定的,所述第二损失函数是将所述第二网络的输出结果输入所述第二预测函数生成的预测值与所述真实标签之间的差距确定的。
4.根据权利要求1所述的方法,其特征在于,所述第一高级模型是所述中心节点根据所述多个第一中级模型获得多教师融合模型后,使用所述中心节点的所述中心数据库对所述多教师融合模型进行增量学习后获得的,其中,所述多教师融合模型的全连接层提取的特征向量是所述多个第一中级模型的全连接层提取的特征向量进行拼接并降维处理后得到的,所述多教师融合模型的全连接层提取的特征向量与所述第一高级模型的特征向量的维度相同。
5.根据权利要求4所述的方法,其特征在于,所述多教师融合模型的全连接层提取的特征向量是m个正向提升模型的全连接层提取的特征向量进行拼接并降维处理后得到的,所述m个正向提升模型是根据所述多个第一中级模型得到的,所述正向提升模型是对基础模型有指导作用的模型,所述基础模型是将所述中心数据库中的测试数据输入所述多个第一中级模型后,所述多个第一中级模型中预测准确率最高的模型。
6.根据权利要求1至5任一权利要求所述的方法,其特征在于,所述计算节点向所述中心节点发送所述第一中级模型,使得所述中心节点对接收到的多个所述第一中级模型进行模型融合,获得第一高级模型之后,所述方法还包括:
所述计算节点在所述本地数据库中存在新数据的情况下,接收由中心节点发送的第二初级模型,其中,所述第二初级模型是所述第一高级模型;
所述计算节点结合所述新数据对所述第二初级模型进行增量学习,获得第二中级模型;
所述计算节点向所述中心节点发送所述第二中级模型,使得所述中心节点对接收到的多个所述更新后的第二中级模型进行模型融合,获得第二高级模型。
7.根据权利要求1至5任一权利要求所述的方法,其特征在于,所述一个或多个中心节点包括第一中心节点和第二中心节点,所述第一中心节点和所述第二中心节点部署于不同地理地区,所述计算节点接收由中心节点发送的第一初级模型包括:
所述计算节点接收由第一中心节点发送的第一初级模型;
所述计算节点向所述中心节点发送所述第一中级模型,使得所述中心节点对接收到的多个所述第一中级模型进行模型融合从而获得第一高级模型包括:
所述计算节点向所述第二中心节点发送所述第一中级模型,使得所述第二中心节点对接收到的多个所述第一中级模型进行模型融合从而获得第一高级模型。
8.一种联合学习的方法,其特征在于,应用于联合学习系统,所述联合学习系统包括一个或者多个计算节点以及一个或多个中心节点,所述方法包括:
所述中心节点向所述计算节点发送第一初级模型,其中,所述第一初级模型是所述中心节点根据中心节点的中心数据库对神经网络进行训练后得到的;
所述中心节点接收所述计算节点发送的第一中级模型,其中,所述第一中级模型是所述计算节点使用所述计算节点的本地数据库对所述第一初级模型进行增量学习后获得的;
所述中心节点对接收到的多个所述第一中级模型进行模型融合从而获得第一高级模型。
9.根据权利要求8所述的方法,其特征在于,所述第一中级模型是所述计算节点将所述本地数据库中的训练数据输入所述第一初级模型中的第一网络获得第一网络的输出结果,并将所述第一网络的输出结果输入软目标预测函数获得所述本地数据库中的训练数据的软目标之后,使用所述本地数据库对第一中级网络进行训练,根据所述软目标以及所述本地数据库中的训练数据的真实标签得到混合损失函数,根据所述混合损失函数对所述第一中级网络进行反向传播从而获得所述第一中级模型,其中,所述第一初级模型包括第一网络和第一预测函数,所述第一预测函数用于根据所述第一网络的输出结果生成所述第一初级模型的预测值,所述软目标预测函数是根据温度参数和所述第一预测函数得到的,所述温度参数用于使得所述第一网络的输出结果输入所述软目标预测函数得到的软目标,不大于所述第一网络的输出结果输入所述第一预测函数得到的所述第一初级模型的预测值。
10.根据权利要求9所述的方法,其特征在于,所述第一中级网络包括第二网络和第二预测函数,所述第二预测函数用于根据所述第二网络的输出结果生成所述第一中级网络的预测值,所述混合损失函数是根据第一损失函数以及第二损失函数确定的,所述第一损失函数是所述计算节点将所述本地数据库中的训练数据输入所述第一中级网络的所述第二网络获得第二网络的输出结果后,将所述第二网络的输出结果输入所述软目标预测函数生成的软目标预测值与软目标之间的差距确定的,所述第二损失函数是将所述第二网络的输出结果输入所述第二预测函数生成的预测值与所述真实标签之间的差距确定的。
11.根据权利要求8所述的方法,其特征在于,所述中心节点对接收到的多个所述第一中级模型进行模型融合从而获得第一高级模型包括:
所述中心节点根据所述多个第一中级模型,获得多教师融合模型,其中,所述多教师融合模型的全连接层提取的特征向量是所述多个第一中级模型的全连接层提取的特征向量进行拼接并降维处理后得到的,所述多教师融合模型的全连接层提取的特征向量与第一高级模型的特征向量维度相同;
所述中心节点使用所述中心数据库对所述多教师融合模型进行增量学习,获得所述第一高级模型。
12.根据权利要求11所述的方法,其特征在于,所述根据所述多个第一中级模型,获得多教师融合模型包括:
所述中心节点在多个第一中级模型中筛选出基础模型和m个正向提升模型,其中,所述基础模型是所述中心节点将所述中心数据库中的测试数据输入所述多个第一中级模型后,所述多个第一中级模型中预测准确率最高的模型,所述正向提升模型是对所述基础模型有指导作用的模型;
所述中心节点根据所述m个正向提升模型,获得所述多教师融合模型,其中,所述多教师融合模型的全连接层提取的特征向量是所述m个正向提升模型的全连接层提取的特征向量进行拼接并降维处理后得到的,所述多教师融合模型的全连接层提取的特征向量与所述第一高级模型的特征向量维度相同。
13.根据权利要求8至12任一权利要求所述的方法,其特征在于,所述中心节点对接收到的多个所述第一中级模型进行模型融合从而获得第一高级模型之后,所述方法还包括:
所述中心节点向所述计算节点发送第二初级模型,使得所述计算节点结合所述本地数据库中的新数据对所述第二初级模型进行增量学习,获得第二中级模型,其中,所述第二初级模型是所述高级模型;
所述中心节点接收所述计算节点发送的所述第二中级模型;
所述中心节点对接收到的多个所述第二中级模型进行模型融合,获得第二高级模型。
14.根据权利要求8至12任一权利要求所述的方法,其特征在于,所述一个或多个中心节点包括第一中心节点和第二中心节点,所述中心节点向计算节点发送第一初级模型包括:
所述第一中心节点向所述计算节点发送第一初级模型,其中,所述第一初级模型是所述第一中心节点根据第一中心节点的中心数据库对神经网络进行训练后得到的;
所述中心节点接收所述计算节点发送的第一中级模型包括:
所述第二中心节点接收所述计算节点发送的第一中级模型;
所述中心节点对接收到的多个所述第一中级模型进行模型融合从而获得第一高级模型包括:
所述第二中心节点对接收到的多个所述第一中级模型进行模型融合,从而获得所述第一高级模型。
15.一种计算节点,其特征在于,应用于联合学习系统,所述联合学习系统包括一个或者多个所述计算节点以及一个或多个中心节点,所述计算节点包括接收单元、学习单元以及发送单元,其中,
所述接收单元用于接收由中心节点发送的第一初级模型,其中,所述第一初级模型是所述中心节点根据中心节点的中心数据库对神经网络进行训练后得到的;
所述学习单元用于使用所述计算节点的本地数据库对所述第一初级模型进行增量学习从而获得第一中级模型;
所述发送单元用于向所述中心节点发送所述第一中级模型,使得所述中心节点对接收到的多个所述第一中级模型进行模型融合从而获得第一高级模型。
16.根据权利要求15所述的节点,其特征在于,所述学习单元用于将所述本地数据库中的训练数据输入所述第一初级模型中的第一网络,获得第一网络的输出结果,其中,所述第一初级模型包括第一网络和第一预测函数,所述第一预测函数用于根据所述第一网络的输出结果生成所述第一初级模型的预测值;
所述学习单元用于将所述第一网络的输出结果输入软目标预测函数,获得所述本地数据库中的训练数据的软目标,其中,所述软目标预测函数是根据温度参数和所述第一预测函数得到的,所述温度参数用于使得所述第一网络的输出结果输入所述软目标预测函数得到的软目标,不大于所述第一网络的输出结果输入所述第一预测函数得到的所述第一初级模型的预测值;
所述学习单元用于使用所述本地数据库对第一中级网络进行训练,根据所述软目标以及所述本地数据库中的训练数据的真实标签得到混合损失函数,其中,所述第一中级网络的网络类型与所述神经网络的网络类型相同;
所述学习单元用于根据所述混合损失函数对所述第一中级网络进行反向传播,获得训练好的第一中级模型。
17.根据权利要求16所述的节点,其特征在于,所述第一中级网络包括第二网络和第二预测函数,所述第二预测函数用于根据所述第二网络的输出结果生成所述第一中级网络的预测值,所述混合损失函数是根据第一损失函数以及第二损失函数确定的,所述第一损失函数是所述计算节点将所述本地数据库中的训练数据输入所述第一中级网络的所述第二网络获得第二网络的输出结果后,将所述第二网络的输出结果输入所述软目标预测函数生成的软目标预测值与软目标之间的差距确定的,所述第二损失函数是将所述第二网络的输出结果输入所述第二预测函数生成的预测值与所述真实标签之间的差距确定的。
18.根据权利要求15所述的节点,其特征在于,所述第一高级模型是所述中心节点根据所述多个第一中级模型获得多教师融合模型后,使用所述中心节点的所述中心数据库对所述多教师融合模型进行增量学习后获得的,其中,所述多教师融合模型的全连接层提取的特征向量是所述多个第一中级模型的全连接层提取的特征向量进行拼接并降维处理后得到的,所述多教师融合模型的全连接层提取的特征向量与所述第一高级模型的特征向量的维度相同。
19.根据权利要求18所述的节点,其特征在于,所述多教师融合模型的全连接层提取的特征向量是m个正向提升模型的全连接层提取的特征向量进行拼接并降维处理后得到的,所述m个正向提升模型是根据所述多个第一中级模型得到的,所述正向提升模型是对基础模型有指导作用的模型,所述基础模型是将所述中心数据库中的测试数据输入所述多个第一中级模型后,所述多个第一中级模型中预测准确率最高的模型。
20.根据权利要求15至19任一权利要求所述的节点,其特征在于,所述接收单元还用于在向所述中心节点发送所述第一中级模型,使得所述中心节点对接收到的多个所述第一中级模型进行模型融合,获得第一高级模型之后,所述接收单元在所述本地数据库中存在新数据的情况下,接收由中心节点发送的第二初级模型,其中,所述第二初级模型是所述第一高级模型;
所述学习单元还用于结合所述新数据对所述第二初级模型进行增量学习,获得第二中级模型;
所述发送单元还用于向所述中心节点发送所述第二中级模型,使得所述中心节点对接收到的多个所述更新后的第二中级模型进行模型融合,获得第二高级模型。
21.根据权利要求15至19任一权利要求所述的节点,其特征在于,所述一个或多个中心节点包括第一中心节点和第二中心节点,所述接收单元还用于接收由第一中心节点发送的第一初级模型;
所述发送单元还用于向所述第二中心节点发送所述第一中级模型,使得所述第二中心节点对接收到的多个所述第一中级模型进行模型融合从而获得第一高级模型。
22.一种中心节点,其特征在于,应用于联合学习系统,所述联合学习系统包括一个或者多个计算节点以及一个或多个所述中心节点,所述中心节点包括发送单元、接收单元以及融合单元,其中,
所述发送单元用于向所述计算节点发送第一初级模型,其中,所述第一初级模型是所述中心节点根据中心节点的中心数据库对神经网络进行训练后得到的;
所述接收单元用于接收所述计算节点发送的第一中级模型,其中,所述第一中级模型是所述计算节点使用所述计算节点的本地数据库对所述第一初级模型进行增量学习后获得的;
所述融合单元用于对接收到的多个所述第一中级模型进行模型融合从而获得第一高级模型。
23.根据权利要求22所述的节点,其特征在于,所述第一中级模型是所述计算节点将所述本地数据库中的训练数据输入所述第一初级模型中的第一网络获得第一网络的输出结果,并将所述第一网络的输出结果输入软目标预测函数获得所述本地数据库中的训练数据的软目标之后,使用所述本地数据库对第一中级网络进行训练,根据所述软目标以及所述本地数据库中的训练数据的真实标签得到混合损失函数,根据所述混合损失函数对所述第一中级网络进行反向传播从而获得的所述第一中级模型,其中,所述第一初级模型包括第一网络和第一预测函数,所述第一预测函数用于根据所述第一网络的输出结果生成所述第一初级模型的预测值,所述软目标预测函数是根据温度参数和所述第一预测函数得到的,所述温度参数用于使得所述第一网络的输出结果输入所述软目标预测函数得到的软目标,不大于所述第一网络的输出结果输入所述第一预测函数得到的所述第一初级模型的预测值。
24.根据权利要求23所述的节点,其特征在于,所述第一中级网络包括第二网络和第二预测函数,所述第二预测函数用于根据所述第二网络的输出结果生成所述第一中级网络的预测值,所述混合损失函数是根据第一损失函数以及第二损失函数确定的,所述第一损失函数是所述计算节点将所述本地数据库中的训练数据输入所述第一中级网络的所述第二网络获得第二网络的输出结果后,将所述第二网络的输出结果输入所述软目标预测函数生成的软目标预测值与软目标之间的差距确定的,所述第二损失函数是将所述第二网络的输出结果输入所述第二预测函数生成的预测值与所述真实标签之间的差距确定的。
25.根据权利要求22所述的节点,其特征在于,所述融合单元用于根据所述多个第一中级模型,获得多教师融合模型,其中,所述多教师融合模型的全连接层提取的特征向量是所述多个第一中级模型的全连接层提取的特征向量进行拼接并降维处理后得到的,所述多教师融合模型的全连接层提取的特征向量与第一高级模型的特征向量维度相同;
所述融合单元用于使用所述中心数据库对所述多教师融合模型进行增量学习,获得所述第一高级模型。
26.根据权利要求25所述的节点,其特征在于,所述融合单元用于在多个第一中级模型中筛选出基础模型和m个正向提升模型,其中,所述基础模型是所述中心节点将所述中心数据库中的测试数据输入所述多个第一中级模型后,所述多个第一中级模型中预测准确率最高的模型,所述正向提升模型是对所述基础模型有指导作用的模型;
所述融合单元用于根据所述m个正向提升模型,获得所述多教师融合模型,其中,所述多教师融合模型的全连接层提取的特征向量是所述m个正向提升模型的全连接层提取的特征向量进行拼接并降维处理后得到的,所述多教师融合模型的全连接层提取的特征向量与所述第一高级模型的特征向量维度相同。
27.根据权利要求22至26所述的节点,其特征在于,所述接收单元还用于在点对接收到的多个所述第一中级模型进行模型融合从而获得第一高级模型之后,接收由所述计算节点发送的所述计算节点的所述本地数据库中存在新数据的消息;
所述发送单元还用于根据所述本地数据库中存在新数据的消息,向所述计算节点发送第二初级模型,使得所述计算节点结合所述新数据对所述第二初级模型进行增量学习,获得第二中级模型,其中,所述第二初级模型是所述高级模型;
所述接收单元还用于接收所述计算节点发送的所述第二中级模型;
所述融合单元还用于对接收到的多个所述第二中级模型进行模型融合,获得第二高级模型。
28.一种联合学习系统,其特征在于,所述系统包括一个或多个如权利要求22至27任一所述的中心节点以及一个或多个如权利要求15至21任一所述的计算节点。
29.一种计算机可读存储介质,其特征在于,包括指令,当所述指令在计算设备上运行时,使得所述计算设备执行如权利要求1至14任一权利要求所述的方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202010038486.3A CN113191479A (zh) | 2020-01-14 | 2020-01-14 | 联合学习的方法、系统、节点及存储介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202010038486.3A CN113191479A (zh) | 2020-01-14 | 2020-01-14 | 联合学习的方法、系统、节点及存储介质 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN113191479A true CN113191479A (zh) | 2021-07-30 |
Family
ID=76972383
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202010038486.3A Pending CN113191479A (zh) | 2020-01-14 | 2020-01-14 | 联合学习的方法、系统、节点及存储介质 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN113191479A (zh) |
Cited By (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN113984741A (zh) * | 2021-09-29 | 2022-01-28 | 西北大学 | 基于深度学习和声音信号液体识别系统及方法 |
CN114782960A (zh) * | 2022-06-22 | 2022-07-22 | 深圳思谋信息科技有限公司 | 模型训练方法、装置、计算机设备及计算机可读存储介质 |
WO2023102714A1 (en) * | 2021-12-07 | 2023-06-15 | Intel Corporation | Decentralized active-learning model update and broadcast mechanism in internet-of-things environment |
WO2023184958A1 (zh) * | 2022-03-29 | 2023-10-05 | 上海商汤智能科技有限公司 | 目标识别及神经网络的训练 |
Citations (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN109711544A (zh) * | 2018-12-04 | 2019-05-03 | 北京市商汤科技开发有限公司 | 模型压缩的方法、装置、电子设备及计算机存储介质 |
CN110263936A (zh) * | 2019-06-14 | 2019-09-20 | 深圳前海微众银行股份有限公司 | 横向联邦学习方法、装置、设备及计算机存储介质 |
CN110442457A (zh) * | 2019-08-12 | 2019-11-12 | 北京大学深圳研究生院 | 基于联邦学习的模型训练方法、装置及服务器 |
CN110443063A (zh) * | 2019-06-26 | 2019-11-12 | 电子科技大学 | 自适性保护隐私的联邦深度学习的方法 |
-
2020
- 2020-01-14 CN CN202010038486.3A patent/CN113191479A/zh active Pending
Patent Citations (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN109711544A (zh) * | 2018-12-04 | 2019-05-03 | 北京市商汤科技开发有限公司 | 模型压缩的方法、装置、电子设备及计算机存储介质 |
CN110263936A (zh) * | 2019-06-14 | 2019-09-20 | 深圳前海微众银行股份有限公司 | 横向联邦学习方法、装置、设备及计算机存储介质 |
CN110443063A (zh) * | 2019-06-26 | 2019-11-12 | 电子科技大学 | 自适性保护隐私的联邦深度学习的方法 |
CN110442457A (zh) * | 2019-08-12 | 2019-11-12 | 北京大学深圳研究生院 | 基于联邦学习的模型训练方法、装置及服务器 |
Non-Patent Citations (1)
Title |
---|
MENG-CHIEH WU ET AL.: "Multi-teacher Knowledge Distillation for Compressed Video Action Recognition on Deep Neural Networks", 2019 IEEE INTERNATIONAL CONFERENCE ON ACOUSTICS, SPEECH AND SIGNAL PROCESSING, 12 May 2019 (2019-05-12), pages 1 - 5 * |
Cited By (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN113984741A (zh) * | 2021-09-29 | 2022-01-28 | 西北大学 | 基于深度学习和声音信号液体识别系统及方法 |
WO2023102714A1 (en) * | 2021-12-07 | 2023-06-15 | Intel Corporation | Decentralized active-learning model update and broadcast mechanism in internet-of-things environment |
WO2023184958A1 (zh) * | 2022-03-29 | 2023-10-05 | 上海商汤智能科技有限公司 | 目标识别及神经网络的训练 |
CN114782960A (zh) * | 2022-06-22 | 2022-07-22 | 深圳思谋信息科技有限公司 | 模型训练方法、装置、计算机设备及计算机可读存储介质 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
WO2022083536A1 (zh) | 一种神经网络构建方法以及装置 | |
CN112651511B (zh) | 一种训练模型的方法、数据处理的方法以及装置 | |
US20220319154A1 (en) | Neural network model update method, image processing method, and apparatus | |
CN113191479A (zh) | 联合学习的方法、系统、节点及存储介质 | |
CN113011282A (zh) | 图数据处理方法、装置、电子设备及计算机存储介质 | |
WO2022105714A1 (zh) | 数据处理方法、机器学习的训练方法及相关装置、设备 | |
JP2022553252A (ja) | 画像処理方法、画像処理装置、サーバ、及びコンピュータプログラム | |
WO2021218471A1 (zh) | 一种用于图像处理的神经网络以及相关设备 | |
WO2022156561A1 (zh) | 一种自然语言处理方法以及装置 | |
CN109034206A (zh) | 图像分类识别方法、装置、电子设备及计算机可读介质 | |
US20220351039A1 (en) | Federated learning using heterogeneous model types and architectures | |
CN110222717A (zh) | 图像处理方法和装置 | |
CN113516227B (zh) | 一种基于联邦学习的神经网络训练方法及设备 | |
CN112862828B (zh) | 一种语义分割方法、模型训练方法及装置 | |
WO2022012668A1 (zh) | 一种训练集处理方法和装置 | |
CN113505883A (zh) | 一种神经网络训练方法以及装置 | |
WO2021169366A1 (zh) | 数据增强方法和装置 | |
CN112395979A (zh) | 基于图像的健康状态识别方法、装置、设备及存储介质 | |
US20200210754A1 (en) | Cloud device, terminal device, and method for classifyiing images | |
CN113536970A (zh) | 一种视频分类模型的训练方法及相关装置 | |
CN116057542A (zh) | 分布式机器学习模型 | |
CN115018039A (zh) | 一种神经网络蒸馏方法、目标检测方法以及装置 | |
CN113657272B (zh) | 一种基于缺失数据补全的微视频分类方法及系统 | |
WO2021136058A1 (zh) | 一种处理视频的方法及装置 | |
WO2023231753A1 (zh) | 一种神经网络的训练方法、数据的处理方法以及设备 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |