CN113469340A - 一种模型处理方法、联邦学习方法及相关设备 - Google Patents
一种模型处理方法、联邦学习方法及相关设备 Download PDFInfo
- Publication number
- CN113469340A CN113469340A CN202110763965.6A CN202110763965A CN113469340A CN 113469340 A CN113469340 A CN 113469340A CN 202110763965 A CN202110763965 A CN 202110763965A CN 113469340 A CN113469340 A CN 113469340A
- Authority
- CN
- China
- Prior art keywords
- model
- pruning
- loss function
- neural network
- models
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000000034 method Methods 0.000 title claims abstract description 135
- 238000003672 processing method Methods 0.000 title abstract description 10
- 238000013138 pruning Methods 0.000 claims abstract description 284
- 238000012549 training Methods 0.000 claims abstract description 143
- 238000003062 neural network model Methods 0.000 claims abstract description 119
- 230000006870 function Effects 0.000 claims description 222
- 238000011144 upstream manufacturing Methods 0.000 claims description 122
- 238000012545 processing Methods 0.000 claims description 51
- 210000002569 neuron Anatomy 0.000 claims description 39
- 230000002776 aggregation Effects 0.000 claims description 10
- 238000004220 aggregation Methods 0.000 claims description 10
- 230000004931 aggregating effect Effects 0.000 claims description 9
- 238000004590 computer program Methods 0.000 claims description 6
- 230000000379 polymerizing effect Effects 0.000 claims description 3
- 238000006116 polymerization reaction Methods 0.000 claims description 2
- 230000008569 process Effects 0.000 abstract description 68
- 238000013528 artificial neural network Methods 0.000 description 37
- 244000141353 Prunus domestica Species 0.000 description 35
- 238000004891 communication Methods 0.000 description 25
- 238000004422 calculation algorithm Methods 0.000 description 20
- 238000010586 diagram Methods 0.000 description 17
- 238000010801 machine learning Methods 0.000 description 16
- 238000013527 convolutional neural network Methods 0.000 description 15
- 238000013473 artificial intelligence Methods 0.000 description 14
- 239000013598 vector Substances 0.000 description 12
- 230000001537 neural effect Effects 0.000 description 11
- 238000004364 calculation method Methods 0.000 description 8
- 238000005516 engineering process Methods 0.000 description 8
- 230000001133 acceleration Effects 0.000 description 7
- 230000005540 biological transmission Effects 0.000 description 7
- 239000011159 matrix material Substances 0.000 description 7
- 230000004913 activation Effects 0.000 description 5
- 238000005457 optimization Methods 0.000 description 5
- 230000001413 cellular effect Effects 0.000 description 4
- 238000012937 correction Methods 0.000 description 4
- 238000011156 evaluation Methods 0.000 description 4
- 238000012804 iterative process Methods 0.000 description 4
- 230000001154 acute effect Effects 0.000 description 3
- 230000008878 coupling Effects 0.000 description 3
- 238000010168 coupling process Methods 0.000 description 3
- 238000005859 coupling reaction Methods 0.000 description 3
- 238000013135 deep learning Methods 0.000 description 3
- 238000001514 detection method Methods 0.000 description 3
- 238000007726 management method Methods 0.000 description 3
- 238000010295 mobile communication Methods 0.000 description 3
- 210000004027 cell Anatomy 0.000 description 2
- 230000008859 change Effects 0.000 description 2
- 230000006835 compression Effects 0.000 description 2
- 238000007906 compression Methods 0.000 description 2
- 238000013500 data storage Methods 0.000 description 2
- 230000000694 effects Effects 0.000 description 2
- 239000000835 fiber Substances 0.000 description 2
- 230000007774 longterm Effects 0.000 description 2
- 230000008447 perception Effects 0.000 description 2
- 230000000306 recurrent effect Effects 0.000 description 2
- 238000011160 research Methods 0.000 description 2
- 239000007787 solid Substances 0.000 description 2
- 230000009471 action Effects 0.000 description 1
- 238000004458 analytical method Methods 0.000 description 1
- 238000013459 approach Methods 0.000 description 1
- 230000008901 benefit Effects 0.000 description 1
- 230000015572 biosynthetic process Effects 0.000 description 1
- 238000011217 control strategy Methods 0.000 description 1
- 125000004122 cyclic group Chemical group 0.000 description 1
- 230000007423 decrease Effects 0.000 description 1
- 238000013136 deep learning model Methods 0.000 description 1
- 238000013461 design Methods 0.000 description 1
- 238000007599 discharging Methods 0.000 description 1
- 238000006073 displacement reaction Methods 0.000 description 1
- 238000005538 encapsulation Methods 0.000 description 1
- 238000000605 extraction Methods 0.000 description 1
- 230000004927 fusion Effects 0.000 description 1
- 230000008570 general process Effects 0.000 description 1
- 238000011478 gradient descent method Methods 0.000 description 1
- 230000005484 gravity Effects 0.000 description 1
- 230000003993 interaction Effects 0.000 description 1
- 239000007788 liquid Substances 0.000 description 1
- 239000004973 liquid crystal related substance Substances 0.000 description 1
- 230000007246 mechanism Effects 0.000 description 1
- 238000012544 monitoring process Methods 0.000 description 1
- 238000003058 natural language processing Methods 0.000 description 1
- 238000010606 normalization Methods 0.000 description 1
- 230000003287 optical effect Effects 0.000 description 1
- 238000007781 pre-processing Methods 0.000 description 1
- 230000001902 propagating effect Effects 0.000 description 1
- 238000010079 rubber tapping Methods 0.000 description 1
- 238000005070 sampling Methods 0.000 description 1
- 239000004576 sand Substances 0.000 description 1
- 239000004065 semiconductor Substances 0.000 description 1
- 238000012163 sequencing technique Methods 0.000 description 1
- 230000005236 sound signal Effects 0.000 description 1
- 238000010897 surface acoustic wave method Methods 0.000 description 1
- 238000013519 translation Methods 0.000 description 1
- 230000000007 visual effect Effects 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N20/00—Machine learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/082—Learning methods modifying the architecture, e.g. adding, deleting or silencing nodes or connections
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Software Systems (AREA)
- Computing Systems (AREA)
- Artificial Intelligence (AREA)
- Mathematical Physics (AREA)
- General Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- Evolutionary Computation (AREA)
- General Engineering & Computer Science (AREA)
- Biomedical Technology (AREA)
- Molecular Biology (AREA)
- General Health & Medical Sciences (AREA)
- Computational Linguistics (AREA)
- Biophysics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Health & Medical Sciences (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Medical Informatics (AREA)
- Image Analysis (AREA)
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
Abstract
本申请实施例公开了一种模型处理方法,可以应用于模型训练与剪枝场景,该方法可以由客户端执行,还可以由客户端的部件(例如处理器、芯片或芯片系统等)执行,该方法包括:根据第一损失函数与训练数据训练神经网络模型,得到第一模型;基于第二损失函数以及约束条件对第一模型进行剪枝,以得到第二模型,约束条件用于约束第二模型的精度不低于第一模型的精度。在对第一模型进行剪枝的过程中,考虑基于数据损失函数的约束条件,相当于为第一模型的剪枝提供一个方向,使得剪枝得到的第二模型的精度不低于第一模型的精度,减少后续通过微调调整模型精度的步骤,从而在保证剪枝后模型的精度的同时提升模型剪枝过程的效率。
Description
技术领域
本申请实施例涉及终端人工智能领域,尤其涉及一种模型处理方法、联邦学习方法及相关设备。
背景技术
人工智能(artificial intelligence,AI)是利用数字计算机或者数字计算机控制的机器模拟、延伸和扩展人的智能,感知环境、获取知识并使用知识获得最佳结果的理论、方法、技术及应用系统。换句话说,人工智能是计算机科学的一个分支,它企图了解智能的实质,并生产出一种新的能以人类智能相似的方式作出反应的智能机器。人工智能也就是研究各种智能机器的设计原理与实现方法,使机器具有感知、推理与决策的功能。人工智能领域的研究包括机器人,自然语言处理,计算机视觉,决策与推理,人机交互,推荐与搜索,AI基础理论等。
目前,深度学习作为机器学习的主流分支之一应用广泛,但深度神经网络模型大、参数多以及其带来的计算、存储、功耗、时延等方面的问题阻碍了深度学习模型的产品化。为解决该问题,需要对深度神经网络进行简化。其中,剪枝技术应用最广,剪枝技术通过移除部分参数和模块实现对深度神经网络的压缩。经典的剪枝过程包括三个步骤,首先基于本地数据集完成对模型的训练,然后根据预设的规则对训练好的模型剪枝,最后还需要利用本地数据集对剪枝后的模型进行微调以避免模型精度损失太多。整个剪枝流程较为繁琐,效率较低。
发明内容
本申请实施例提供了一种模型处理方法及相关设备,该方法可以和联邦学习方法结合起来使用。在对第一模型进行剪枝的过程中,考虑基于数据损失函数的约束条件,相当于约束了第一模型的剪枝方向,使得剪枝得到的第二模型的精度不低于第一模型的精度,减少后续通过微调调整模型精度的步骤,从而在保证剪枝后模型的精度的同时提升模型剪枝过程的效率。
本申请第一方面提供了一种模型处理方法,可以应用于模型训练与剪枝场景,该方法可以由模型处理设备(例如客户端)执行,还可以由客户端的部件(例如处理器、芯片或芯片系统等)执行,该方法包括:获取包括标签值的训练数据;以训练数据为输入,根据第一损失函数训练神经网络模型,得到第一模型,第一模型包括多个子结构,多个子结构中的每个子结构包括至少两个神经元;基于第二损失函数以及约束条件对第一模型进行剪枝,以得到第二模型,第二损失函数用于指示将多个子结构中的至少一个子结构进行剪枝,约束条件用于约束第二模型的精度不低于第一模型的精度,精度指示模型的输出值与标签值之间的差异程度。
其中,第一损失函数也可以理解为是数据损失函数,主要用于在使用数据训练模型的过程中评估模型的精度。第二损失函数可以理解为是稀疏损失函数,主要用于对模型进行稀疏(或称为剪枝)。子结构可以是神经网络模型的通道、特征图、网络层、子网络、或者预定义的由多个神经元组成的其他网络结构;当神经网络模型是卷积神经网络时,子结构还可以是卷积核。总之,一个子结构可以看作一个功能整体,在剪枝时对子结构进行剪枝,是指将该子结构包括的所有神经元都进行剪枝。
本申请实施例中,在对第一模型进行剪枝的过程中,考虑基于数据损失函数的约束条件,相当于约束了第一模型的剪枝方向,使得剪枝得到的第二模型的精度不低于第一模型的精度,减少后续通过微调调整模型精度的步骤,而且是对子结构进行剪枝,比对神经元逐个剪枝效率更高,从而在保证剪枝后模型的精度的同时提升模型剪枝过程的效率,得到的模型结构也更简洁。
可选地,在第一方面的一种可能的实现方式中,上述的约束条件具体用于约束第一损失函数的下降方向与第二损失函数的下降方向之间的夹角小于或等于90度。其中,第一损失函数的下降方向可以是对第一损失函数求导得到的梯度方向,第二损失函数的下降方向可以是对第二损失函数求导得到的梯度方向。
该种可能的实现方式中,通过调整第一损失函数的下降方向与第二损失函数的下降方向的夹角小于或等于90度,可以保证剪枝后的第二模型精度较剪枝前的第一模型的精度不会有所下降,减少了后续微调模型精度的步骤。
可选地,在第一方面的一种可能的实现方式中,上述的约束条件具体用于约束第二模型的第一损失函数的值小于或等于第一模型的第一损失函数的值。换句话说,用第一模型和第二模型对相同的数据做预测,再用第一损失函数衡量第一模型和第二模型的精度,对应第一损失函数的值越小的模型的精度越高。
该种可能的实现方式中,具体精度的确定方式可以是用第二模型的第一损失函数的值与第一模型的第一损失函数的值进行确定。当然,也可以使用与第一损失函数不同的评价方式来比较第一模型和第二模型的精度,本申请不限定具体的评价方式。
可选地,在第一方面的一种可能的实现方式中,上述的第二损失函数包括第一稀疏项,第一稀疏项与多个子结构中的至少一个子结构的权重相关。
该种可能的实现方式中,在对第一模型进行剪枝时,第二损失函数中的第一稀疏项是将子结构作为一个整体进行处理,所以剪枝的时候被剪掉的都是通道、卷积核、特征图、网络层等网络结构,而不是单个的神经元,大大提高了剪枝的效率,得到的模型也更加精炼、轻量。
可选地,在第一方面的一种可能的实现方式中,上述的第二损失函数还包括差异项,该差异项指示第一模型与第二模型的差异。
该种可能的实现方式中,通过在第二损失函数增加差异项,可以一定程度上约束剪枝前后的模型的差异程度,保证剪枝前后模型的相似性,进而保证剪枝后模型的精度。
可选地,在第一方面的一种可能的实现方式中,上述步骤:基于第二损失函数以及约束条件对第一模型进行剪枝,以得到第二模型,包括:基于约束条件计算更新系数,更新系数用于调整第一稀疏项的方向;使用更新系数更新第二损失函数中的第一稀疏项,以得到第三损失函数,第三损失函数包括差异项与第二稀疏项,第二稀疏项基于更新系数与第一稀疏项更新得到;基于第三损失函数对第一模型进行剪枝,以得到第二模型。
该种可能的实现方式中,通过引入更新系数可以调整剪枝方向,从而使得剪枝后的第二模型满足约束条件。该情况下的剪枝也可以理解为是定向剪枝。
可选地,在第一方面的一种可能的实现方式中,上述的第三损失函数包括:
该种可能的实现方式中,通过更新系数更新第一稀疏项得到第二稀疏项,且更新后的第二稀疏项可以对模型进行定向剪枝,保证剪枝后模型的精度不损失。
可选地,在第一方面的一种可能的实现方式中,上述步骤:基于第二损失函数以及约束条件对第一模型进行剪枝,以得到第二模型,包括:基于第二损失函数对第一模型进行至少一次随机剪枝,直至剪枝第一模型后得到的第二模型满足约束条件。具体可以是基于第二损失函数对第一模型进行随机剪枝,以得到第二模型;若满足约束条件,则输出第二模型;若不满足约束条件,则重复基于第二损失函数对第一模型进行随机剪枝的步骤直至满足约束条件。
该种可能的实现方式中,可以通过随机剪枝加上约束条件判断的方式进行剪枝,只有满足约束条件的情况下,才输出剪枝后的模型。这种方式不需要使用数据来微调剪枝后的模型,通用性更高。
可选地,在第一方面的一种可能的实现方式中,上述方法应用于联邦学习系统中的客户端,用来训练神经网络模型的数据是客户端本地的数据,例如客户端的传感器采集的数据或客户端的程序应用等运行过程中生成的数据等,该方法还包括:接收上游设备发送的神经网络模型;向上游设备发送第二模型。上游设备是可以与客户端通信的服务器等设备。
该种可能的实现方式中,该方法可以应用于联邦学习场景,一方面通过引入约束条件对模型子结构进行剪枝,可以帮助上游设备筛选出精度不损失、结构简化的且用于聚合的多个模型,减少了上行链路(即客户端到上游设备的通信链路)的通信负担。
可选地,在第一方面的一种可能的实现方式中,上述的训练数据包括:图像数据、音频数据或者文本数据等。可以理解的是,上述三种只是对训练数据的举例,在实际应用中,根据神经网络模型的所处理的任务类型不同,训练数据的具体形式不同,此处不做限定。
可选地,在第一方面的一种可能的实现方式中,上述的神经网络模型用于对图像数据进行分类和/或识别等。可以理解的是,在实际应用中,神经网络模型还可以用于目标检测、信息推荐、语音识别、文字识别、问答任务、人机游戏等,具体此处不做限定。
该种可能的实现方式中,对于应用于任何场景(例如:智能终端、智能交通、智能医疗、自动驾驶、智慧城市等)的神经网络模型,都可以适用于本申请实施例提供的剪枝方法,有利于提升神经网络模型的剪枝效率,在减少神经网络占用的存储空间外,还可以保证神经网络模型的精度。
在一种可能的实现方式中,上述模型训练与剪枝的步骤可以是一次也可以是多次,具体根据需要设置,该情况下可以得到更加符合用户预期的模型,在保证模型精度不损失的情况下完成剪枝,节省存储和通信成本。
本申请第二方面提供了一种联邦学习方法,可以应用于模型剪枝场景,该方法可以由上游设备(云服务器或边服务器等)执行,还可以由上游设备的部件(例如处理器、芯片或芯片系统等)执行,该方法可以理解为是先约束剪枝,后聚合的操作。该方法包括:向多个下游设备发送神经网络模型,神经网络模型包括多个子结构,每个子结构包括至少两个神经元;接收来自多个下游设备的多个第一模型,多个第一模型由神经网络模型训练得到,其中,训练过程中使用的损失函数可以称为第一损失函数;基于第二损失函数)以及约束条件对多个第一模型分别进行剪枝,其中,第二损失函数用于指示对多个第一模型的子结构进行剪枝,约束条件用于约束每个第一模型剪枝后的精度不低于剪枝前的精度;将剪枝后的多个第一模型进行聚合,以得到第二模型。
可选地,在第二方面的一种可能的实现方式中,上述步骤:基于第二损失函数以及约束条件对第一模型进行剪枝,包括:基于第二损失函数对第一模型进行至少一次随机剪枝,直至剪枝第一模型后得到的模型满足约束条件。具体可以是基于第二损失函数对第一模型进行随机剪枝,以得到剪枝前后的第一模型;若满足约束条件,则输出模型;若不满足约束条件,则重复基于第二损失函数对第一模型进行随机剪枝的步骤直至满足约束条件。
本申请第三方面提供了一种联邦学习方法,可以应用于模型剪枝场景,该方法可以由服务器(云服务器或边服务器等)执行,还可以由服务器的部件(例如处理器、芯片或芯片系统等)执行,该方法可以理解为是先聚合,后约束剪枝的操作。该方法包括:向多个下游设备发送神经网络模型,神经网络模型包括多个子结构,每个子结构包括至少两个神经元;接收来自多个下游设备的多个第一模型,多个第一模型由神经网络模型训练得到,其中,训练过程中使用的损失函数可以称为第一损失函数;将多个第一模型进行聚合,以得到第二模型;基于损失函数(后续称为第二损失函数)以及约束条件对第二模型进行剪枝,其中,第二损失函数用于指示对第二模型的子结构进行剪枝,约束条件用于约束第二模型剪枝后的精度不低于剪枝前的精度。
可选地,在第三方面的一种可能的实现方式中,上述步骤:基于第二损失函数以及约束条件对第二模型进行剪枝,包括:基于第二损失函数对第二模型进行至少一次随机剪枝,直至剪枝第二模型后得到的模型满足约束条件。具体可以是基于第二损失函数对第二模型进行随机剪枝,以得到模型;若满足约束条件,则输出剪枝后的第二模型;若不满足约束条件,则重复基于第二损失函数对第二模型进行随机剪枝的步骤直至满足约束条件。
在第二方面/第三方面提供的实现方式中,服务器使用本申请实施例提供的方法对模型进行剪枝之后,不需要再利用训练数据来调整模型以保证模型精度,也就是说,服务器可以在不使用客户端的训练数据的情况下对模型进行剪枝,且保证模型的精度,这样就避免了剪枝时要将客户端的训练数据传输到上游设备,可以保护客户端的数据隐私。在上述第二方面/第三方面提供的实现方式中,子结构可以是神经网络模型的通道、特征图、网络层、子网络、或者预定义的由多个神经元组成的其他网络结构;当神经网络模型是卷积神经网络时,子结构还可以是卷积核。总之,一个子结构可以看作一个功能整体,在剪枝时对子结构进行剪枝,是指将该子结构包括的所有神经元都进行剪枝。本申请提供的方法将每个子结构作为一个整体进行剪枝,比对神经元逐个剪枝效率更高,得到的模型结构也更简洁。
可选地,在第二方面/第三方面的一种可能的实现方式中,上述的多个第一模型是根据第一损失函数训练得到,剪枝所使用的损失函数称为第二损失函数,约束条件具体用于约束第一损失函数的下降方向与第二损失函数的下降方向之间的夹角小于或等于90度。其中,第一损失函数的下降方向可以是对第一损失函数求导得到的梯度方向,第二损失函数的下降方向可以是对第二损失函数求导得到的梯度方向。
该种可能的实现方式中,通过调整第一损失函数的下降方向与第二损失函数的下降方向的夹角小于或等于90度,可以保证剪枝前后的第一模型的精度不会有所下降,减少了后续微调模型精度的步骤。
可选地,在第二方面/第三方面的一种可能的实现方式中,上述的约束条件具体用于约束剪枝后的模型的第一损失函数的值小于或等于剪枝前模型的第一损失函数的值。换句话说,用剪枝前后的模型对相同的数据做预测,再用第一损失函数衡量剪枝前后模型的精度,对应第一损失函数的值越小的模型的精度越高。
该种可能的实现方式中,具体精度的确定方式可以是用剪枝前后的模型的第一损失函数的值进行确定。当然,也可以使用与第一损失函数不同的评价方式来比较剪枝前后模型的精度,本申请不限定具体的评价方式。
可选地,在第二方面/第三方面的一种可能的实现方式中,上述步骤还包括:向所述多个下游设备发送剪枝后的模型。
该种可能的实现方式中,可以应用于云服务器或边服务器进行剪枝、聚合的场景,剪枝聚合之后,向多个下游设备发送剪枝后的模型。以使下游设备使用剪枝后的模型进行推理,或者对剪枝后的模型进行再训练。将模型剪枝后再发送给下游设备,一方面可以减少通信负担,另一方面可以降低对下游设备的存储空间、处理能力的要求。
可选地,在第二方面/第三方面的一种可能的实现方式中,上述步骤还包括:向上游设备发送剪枝后的模型。
该种可能的实现方式中,可以应用于边服务器进行剪枝、聚合的场景,剪枝聚合之后,向上游设备发送剪枝后的模型。以使上游服务器继续对模型进行聚合、剪枝等操作,以综合来自更多客户端设备的信息。
可选地,在第二方面/第三方面的一种可能的实现方式中,上述的第二损失函数包括第一稀疏项,第一稀疏项与多个子结构中的至少一个子结构的权重相关。
该种可能的实现方式中,在对模型进行剪枝时,第二损失函数中的第一稀疏项是将子结构作为一个整体进行处理,所以剪枝的时候被剪掉的都是通道、卷积核、特征图、网络层等网络结构,而不是单个的神经元,大大提高了剪枝的效率,得到的模型也更加精炼、轻量。
可选地,在第二方面/第三方面的一种可能的实现方式中,上述的第二损失函数还包括差异项,差异项指示剪枝前后模型的差异。
该种可能的实现方式中,通过在第二损失函数增加差异项,可以一定程度上约束剪枝前后的模型的差异程度,保证剪枝前后模型的相似性。
可选地,在第二方面/第三方面的一种可能的实现方式中,上述步骤:基于第二损失函数以及约束条件对多个第一模型分别进行剪枝,包括:基于约束条件计算更新系数,更新系数用于调整第一稀疏项的方向;使用更新系数更新第二损失函数中的第一稀疏项,以得到第三损失函数,第三损失函数包括差异项与第二稀疏项,第二稀疏项基于更新系数与第一稀疏项更新得到;基于第三损失函数对模型进行剪枝。
该种可能的实现方式中,通过引入更新系数可以调整剪枝方向,从而使得剪枝后的第一模型满足约束条件。该情况下的剪枝也可以理解为是定向剪枝。
可选地,在第二方面/第三方面的一种可能的实现方式中,上述的第三损失函数包括:
该种可能的实现方式中,通过更新系数更新第一稀疏项得到第二稀疏项,且更新后的第二稀疏项可以对模型进行定向剪枝,保证剪枝后模型的精度不损失。
可选地,在第二方面/第三方面的一种可能的实现方式中,上述的训练数据包括:图像数据、音频数据或者文本数据等。可以理解的是,上述三种只是对训练数据的举例,在实际应用中,根据神经网络模型的输入不同,训练数据的具体形式不同,此处不做限定。
可选地,在第二方面/第三方面的一种可能的实现方式中,上述的神经网络模型用于对图像数据进行分类和/或识别等。可以理解的是,在实际应用中,神经网络模型还可以用于预测、编码、解码等,具体此处不做限定。
在一种可能的实现方式中,上述接收、剪枝、聚合、发送的步骤可以是一次也可以是多次,具体根据需要设置,该情况下可以得到更加符合用户预期的模型,在保证模型精度不损失的情况下完成剪枝,节省存储和通信成本。
本申请第四方面提供了一种模型处理设备,可以应用于模型训练与剪枝场景,该模型处理设备可以是客户端,该模型处理设备包括:获取单元,用于获取包括标签值的训练数据;训练单元,用于以训练数据为输入,根据第一损失函数训练神经网络模型,得到第一模型,第一模型包括多个子结构,每个子结构包括至少两个神经元;剪枝单元,用于基于第二损失函数以及约束条件对第一模型进行剪枝,以得到第二模型,第二损失函数用于指示将多个子结构中的至少一个子结构进行剪枝,约束条件用于约束第二模型的精度不低于第一模型的精度,精度指示模型的输出值与标签值之间的差异程度。其中,第一损失函数也可以理解为是数据损失函数,主要用于在使用数据训练模型的过程中评估模型的精度。第二损失函数可以理解为是稀疏损失函数,主要用于对模型进行稀疏(或称为剪枝)。
可选地,上述第四方面提供的模型处理设备的各个单元可以被配置为用于实现前述第一方面的任意可能的实现方式中的方法。
本申请第五方面提供了一种上游设备,可以应用于模型训练与剪枝场景、联邦学习场景等,该上游设备可以是联邦学习场景中的云服务器或边服务器,该上游设备包括:发送单元,用于向多个下游设备发送神经网络模型,神经网络模型包括多个子结构,每个子结构包括至少两个神经元;接收单元,用于接收来自多个下游设备的多个第一模型,多个第一模型由神经网络模型训练得到,其中,训练过程中使用的损失函数可以称为第一损失函数;剪枝单元,用于基于损失函数(后续称为第二损失函数)以及约束条件对多个第一模型分别进行剪枝,其中,第二损失函数用于指示对多个第一模型的子结构进行剪枝,约束条件用于约束每个第一模型剪枝后的精度不低于剪枝前的精度;聚合单元,用于将剪枝后的多个第一模型进行聚合,以得到第二模型。
可选地,上述第五方面提供的上游设备的各个单元可以被配置为用于实现前述第二方面的任意可能的实现方式中的方法。
本申请第六方面提供了一种上游设备,可以应用于模型训练与剪枝场景、联邦学习场景等,该上游设备可以是联邦学习场景中的云服务器或边服务器,该上游设备包括:发送单元,用于向多个下游设备发送神经网络模型,神经网络模型包括多个子结构,每个子结构包括至少两个神经元;接收单元,用于接收来自多个下游设备的多个第一模型,多个第一模型由神经网络模型训练得到,其中,训练过程中使用的损失函数可以称为第一损失函数;聚合单元,用于将多个第一模型进行聚合,以得到第二模型;剪枝单元,用于基于损失函数(后续称为第二损失函数)以及约束条件对第二模型进行剪枝,其中,第二损失函数用于指示对第二模型的子结构进行剪枝,约束条件用于约束第二模型剪枝后的精度不低于剪枝前的精度。
可选地,上述第六方面提供的上游设备的各个单元可以被配置为用于实现前述第三方面的任意可能的实现方式中的方法。
本申请第七方面提供了一种电子设备,包括:处理器,处理器与存储器耦合,存储器用于存储程序或指令,当程序或指令被处理器执行时,使得该电子设备实现上述第一方面、第二方面、第三方面的任意可能的实现方式中的方法。
本申请第八方面提供了一种计算机可读介质,其上存储有计算机程序或指令,当计算机程序或指令在计算机上运行时,使得计算机执行前述第一方面、第二方面、第三方面的任意可能的实现方式中的方法。
本申请第九方面提供了一种计算机程序产品,该计算机程序产品在计算机上执行时,使得计算机执行前述第一方面、第二方面或第三方面的任意可能的实现方式中的方法。
上述第四方面、第五方面、第六方面、第七方面、第八方面、第九方面的任一种可能的实现方式所带来的技术效果可参考前面第一方面、第二方面、第三方面中对应的实现方式所带来的技术效果,此处不再赘述。
附图说明
图1为人工智能主体框架的一种结构示意图;
图2为本申请提供的一种联邦学习系统的架构示意图;
图3为本申请提供的另一种联邦学习系统的架构示意图;
图4为本申请提供的另一种联邦学习系统的架构示意图;
图5为本申请提供的另一种联邦学习系统的架构示意图;
图6为本申请提供的联邦学习方法的一个流程示意图;
图7A、图8-图10为本申请提供的剪枝过程中剪枝方向的几种示意图;
图7B为本申请提供的剪枝前后模型的结构示意图;
图11为本申请提供的联邦学习方法的另一个流程示意图;
图12为本申请提供的模型处理方法一个流程示意图;
图13为本申请提供的模型处理设备的一个结构示意图;
图14为本申请提供的上游设备的一个结构示意图;
图15为本申请提供的上游设备的另一个结构示意图;
图16为本申请提供的模型处理设备的另一个结构示意图;
图17为本申请提供的上游设备的另一个结构示意图。
具体实施方式
本申请实施例提供了一种模型处理方法、联邦学习方法及相关设备。在对第一模型进行剪枝的过程中,考虑基于数据损失函数的约束条件,相当于为第一模型的剪枝提供一个方向,使得剪枝得到的第二模型的精度不低于第一模型的精度,减少后续通过微调调整模型精度的步骤,从而在保证剪枝后模型的精度的同时提升模型剪枝过程的效率。
下面将结合本申请实施例中的附图,对本申请实施例中的技术方案进行描述,显然,所描述的实施例仅仅是本申请一部分实施例,而不是全部的实施例。基于本申请中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本申请保护的范围。
首先对人工智能系统总体工作流程进行描述,请参见图1,图1示出的为人工智能主体框架的一种结构示意图,下面从“智能信息链”(水平轴)和“IT价值链”(垂直轴)两个维度对上述人工智能主题框架进行阐述。其中,“智能信息链”反映从数据的获取到处理的一列过程。举例来说,可以是智能信息感知、智能信息表示与形成、智能推理、智能决策、智能执行与输出的一般过程。在这个过程中,数据经历了“数据—信息—知识—智慧”的凝练过程。“IT价值链”从人智能的底层基础设施、信息(提供和处理技术实现)到系统的产业生态过程,反映人工智能为信息技术产业带来的价值。
(1)基础设施
基础设施为人工智能系统提供计算能力支持,实现与外部世界的沟通,并通过基础平台实现支撑。通过传感器与外部沟通;计算能力由智能芯片,如中央处理器(centralprocessing unit,CPU)、网络处理器(neural-network processing unit,NPU)、图形处理器(英语:graphics processing unit,GPU)、专用集成电路(application specificintegrated circuit,ASIC)或现场可编程逻辑门阵列(field programmable gate array,FPGA)等硬件加速芯片)提供;基础平台包括分布式计算框架及网络等相关的平台保障和支持,可以包括云存储和计算、互联互通网络等。举例来说,传感器和外部沟通获取数据,这些数据提供给基础平台提供的分布式计算系统中的智能芯片进行计算。
(2)数据
基础设施的上一层的数据用于表示人工智能领域的数据来源。数据涉及到图形、图像、语音、文本,还涉及到传统设备的物联网数据,包括已有系统的业务数据以及力、位移、液位、温度、湿度等感知数据。
(3)数据处理
数据处理通常包括数据训练,机器学习,深度学习,搜索,推理,决策等方式。
其中,机器学习和深度学习可以对数据进行符号化和形式化的智能信息建模、抽取、预处理、训练等。
推理是指在计算机或智能系统中,模拟人类的智能推理方式,依据推理控制策略,利用形式化的信息进行机器思维和求解问题的过程,典型的功能是搜索与匹配。
决策是指智能信息经过推理后进行决策的过程,通常提供分类、排序、预测等功能。
(4)通用能力
对数据经过上面提到的数据处理后,进一步基于数据处理的结果可以形成一些通用的能力,比如可以是算法或者一个通用系统,例如,翻译,文本的分析,计算机视觉的处理,语音识别,图像的识别等等。
(5)智能产品及行业应用
智能产品及行业应用指人工智能系统在各领域的产品和应用,是对人工智能整体解决方案的封装,将智能信息决策产品化、实现落地应用,其应用领域主要包括:智能终端、智能交通、智能医疗、自动驾驶、智慧城市等。
本申请实施例可以应用于客户端、也可以应用于云端,还可以应用于对各种应用联邦学习场景中采用到的机器学习模型进行训练,训练后的机器学习模型可以应用于上述各种应用领域中以实现分类、回归或其他功能,训练后的机器学习模型的处理对象可以为图像样本、离散数据样本、文本样本或语音样本等,此处不做穷举。其中机器学习模型具体可以表现为神经网络、线性模型或其他类型的机器学习模型等,对应的,组成机器学习模型的多个模块具体可以表现为神经网络模块、现行模型模块或组成其他类型的机器学习模型的模块等,此处不做穷举。在后续实施例中,仅以机器学习模型表现为神经网络为例进行说明,对于机器学习模型表现为除神经网络之外的其他类型时可以类推理解,本申请实施例中不再赘述。
本申请实施例可以应用于客户端、云端、或联邦学习等,主要是对神经网络进行训练与剪枝,因此涉及了大量神经网络的相关应用。为了更好地理解本申请实施例的方案,下面先对本申请实施例可能涉及的神经网络的相关术语和概念进行介绍。
1、神经网络
神经网络可以是由神经单元组成的,神经单元可以是指以Xs和截距1为输入的运算单元,该运算单元的输出可以如公式(1-1)所示:
其中,s=1、2、……n,n为大于1的自然数,Ws为xs的权重,b为神经单元的偏置。f为神经单元的激活函数(activation functions),用于将非线性特性引入神经网络中,来将神经单元中的输入信号转换为输出信号。该激活函数的输出信号可以作为下一层卷积层的输入,激活函数可以是sigmoid函数。神经网络是将多个上述单一的神经单元联结在一起形成的网络,即一个神经单元的输出可以是另一个神经单元的输入。每个神经单元的输入可以与前一层的局部接受域相连,来提取局部接受域的特征,局部接受域可以是由若干个神经单元组成的区域。
2、深度神经网络
深度神经网络(deep neural network,DNN),也称多层神经网络,可以理解为具有多层中间层的神经网络。按照不同层的位置对DNN进行划分,DNN内部的神经网络可以分为三类:输入层,中间层,输出层。一般来说第一层是输入层,最后一层是输出层,中间的层数都是中间层,或者称为隐藏层。在未进行剪枝或压缩的神经网络中,层与层之间是全连接的,也就是说,第i层的任意一个神经元一定与第i+1层的任意一个神经元相连。
虽然DNN看起来很复杂,其每一层可以表示为线性关系表达式:其中,是输入向量,是输出向量,是偏移向量或者称为偏置参数,w是权重矩阵(也称系数),α()是激活函数。每一层仅仅是对输入向量经过如此简单的操作得到输出向量由于DNN层数多,系数W和偏移向量的数量也比较多。这些参数在DNN中的定义如下所述:以系数w为例:假设在一个三层的DNN中,第二层的第4个神经元到第三层的第2个神经元的线性系数定义为上标3代表系数W所在的层数,而下标对应的是输出的第三层索引2和输入的第二层索引4。
需要注意的是,输入层是没有W参数的。在深度神经网络中,更多的中间层让网络更能够刻画现实世界中的复杂情形。理论上而言,参数越多的模型复杂度越高,“容量”也就越大,也就意味着它能完成更复杂的学习任务。训练深度神经网络的也就是学习权重矩阵的过程,其最终目的是得到训练好的深度神经网络的所有层的权重矩阵(由很多层的向量W形成的权重矩阵)。
3、卷积神经网络
卷积神经网络(convolutional neuron network,CNN)是一种带有卷积结构的深度神经网络。卷积神经网络包含了一个由卷积层和子采样层构成的特征抽取器,该特征抽取器可以看作是滤波器。卷积层是指卷积神经网络中对输入信号进行卷积处理的神经元层。在卷积神经网络的卷积层中,一个神经元可以只与部分邻层神经元连接。一个卷积层中,通常包含若干个特征平面,每个特征平面可以由一些矩形排列的神经单元组成。同一特征平面的神经单元共享权重,这里共享的权重就是卷积核。共享权重可以理解为提取图像信息的方式与位置无关。卷积核可以以随机大小的矩阵的形式初始化,在卷积神经网络的训练过程中卷积核可以通过学习得到合理的权重。另外,共享权重带来的直接好处是减少卷积神经网络各层之间的连接,同时又降低了过拟合的风险。
4、循环神经网络(recurrent neural network,RNN)
在传统的神经网络中模型中,层与层之间是全连接的,每层之间的设备是无连接的。但是这种普通的神经网络对于很多问题是无法解决的。比如,预测句子的下一个单词是什么,因为一个句子中前后单词并不是独立的,一般需要用到前面的单词。循环神经网络指的是一个序列当前的输出与之前的输出也有关。具体的表现形式为网络会对前面的信息进行记忆,保存在网络的内部状态中,并应用于当前输出的计算中。
5、损失函数
在训练深度神经网络的过程中,因为希望深度神经网络的输出尽可能的接近真正想要预测的值,所以可以通过比较当前网络的预测值和真正想要的目标值,再根据两者之间的差异情况来更新每一层神经网络的权重向量(当然,在第一次更新之前通常会有初始化的过程,即为深度神经网络中的各层预先配置参数),比如,如果网络的预测值高了,就调整权重向量让它预测低一些,不断地调整,直到深度神经网络能够预测出真正想要的目标值或与真正想要的目标值非常接近的值。因此,就需要预先定义“如何比较预测值和目标值之间的差异”,这便是损失函数(loss function)或目标函数(objective function),它们是用于衡量预测值和目标值的差异的重要方程。其中,以损失函数举例,损失函数的输出值(loss)越高表示差异越大,那么深度神经网络的训练就变成了尽可能缩小这个loss的过程。该损失函数通常可以包括误差平方均方、交叉熵、对数、指数等损失函数。例如,可以使用均方误差作为损失函数,定义为具体可以根据实际应用场景选择具体的损失函数。
6、反向传播算法
神经网络可以采用误差反向传播(back propagation,BP)算法在训练过程中修正初始的神经网络模型中参数的大小,使得神经网络模型的重建误差损失越来越小。具体地,前向传递输入信号直至输出会产生误差损失,通过反向传播误差损失信息来更新初始的神经网络模型中参数,从而使误差损失收敛。反向传播算法是以误差损失为主导的反向传播运动,旨在得到最优的神经网络模型的参数,例如权重矩阵。
在本申请中,客户端在进行模型训练时,即可通过损失函数或者通过BP算法来对全局模型进行训练,以得到训练后的全局模型。
7、联邦学习(federated learning,FL)
一种分布式机器学习算法,通过多个客户端,如移动设备或边缘服务器,和服务器在数据不出域的前提下,协作式完成模型训练和算法更新,以得到训练后的全局模型。可以理解为,在进行机器学习的过程中,各参与方可借助其他方数据进行联合建模。各方无需共享数据资源,即数据不出本地的情况下,进行数据联合训练,建立共享的机器学习模型。
首先,本申请实施例可以应用于模型处理设备(例如客户端、云端)或联邦学习系统。下面先对本申请提供的联邦学习系统进行介绍。
参阅图2,本申请提供的一种联邦学习系统的架构示意图。该系统(或者也可以简称为集群)中可以包括多个服务器,该多个服务器之间可以互相建立连接,即各个服务器之间也可以进行通信。每个服务器可以和一个或者多个客户端通信,客户端可以部署于各种设备中,如部署于移动终端或者服务器等,如图2中所示出的客户端1、客户端2、…、客户端N-1以及客户端N等。
具体地,服务器之间或者服务器与客户端之间,可以通过任何通信机制/通信标准的通信网络进行交互,通信网络可以是广域网、局域网、点对点连接等方式,或它们的任意组合。具体地,该通信网络可以包括无线网络、有线网络或者无线网络与有线网络的组合等。该无线网络包括但不限于:第五代移动通信技术(5th-Generation,5G)系统,长期演进(long term evolution,LTE)系统、全球移动通信系统(global system for mobilecommunication,GSM)或码分多址(code division multiple access,CDMA)网络、宽带码分多址(wideband code division multiple access,WCDMA)网络、无线保真(wirelessfidelity,WiFi)、蓝牙(bluetooth)、紫蜂协议(Zigbee)、射频识别技术(radio frequencyidentification,RFID)、远程(Long Range,Lora)无线通信、近距离无线通信(near fieldcommunication,NFC)中的任意一种或多种的组合。该有线网络可以包括光纤通信网络或同轴电缆组成的网络等。
通常,客户端可以部署于各种服务器或者终端中,以下所提及的客户端也可以是指部署了客户端软件程序的服务器或者终端,该终端可以包括移动终端或者固定安装的终端等,例如,该终端具体可以包括手机、平板、个人计算机(personal computer,PC)、智能手环、音响、电视、智能手表或其他终端等。
在进行联邦学习时,每个服务器可以向与其建立了连接的客户端下发待训练的模型,客户端可以使用本地存储的训练样本对该模型进行训练,并将训练后的模型的参数等数据反馈至服务器,服务器在接收到一个或者多个客户端反馈的训练后的一个或者多个模型之后,可以对收到的一个或者多个模型进行剪枝,并对剪枝后的一个或者多个模型的数据进行聚合,以得到聚合后的数据,相当于聚合后的模型。在满足停止条件之后,即可输出最终的模型,完成联邦学习。
通常,为解决客户端和服务器之间距离较远而导致的传输时延大的问题,一般在服务器和客户端之间引入中间层服务器(本申请称为边服务器),形成多层架构,即客户端-边服务器-云服务器的架构,从而通过边服务器来减少客户端和联邦学习系统之间的传输时延。
具体地,本申请提供的联邦学习方法可以应用的联邦学习系统可以包括多种拓扑关系,如联邦学习系统可以包括两层或者两层以上的架构,下面对一些可能的架构进行示例性介绍。
一、两层架构
如图3所示,本申请提供的一种联邦学习系统的结构示意图。
其中,该联邦学习系统内包括服务器-客户端形成的两层架构。服务器可以直接与一个或者多个客户端直接建立连接。
在联邦学习的过程中,服务器向与其建立了连接的一个或者多个客户端下发全局模型。
一种可能实现的方式中,客户端使用本地存储的训练样本对接收到的全局模型进行训练,并将训练后的全局模型反馈至服务器,服务器基于接收到的训练后的全局模型进行剪枝,并对本地存储的全局模型进行更新,以得到最终的全局模型。
另一种可能实现的方式中,客户端使用本地存储的训练样本对接收到的全局模型进行训练与剪枝,并将剪枝后的全局模型反馈至服务器,服务器基于接收到的训练与剪枝后的全局模型对本地存储的全局模型进行更新,以得到最终的全局模型。
二、三层架构
如图4所示,本申请提供的一种联邦学习系统的结构示意图。
其中,联邦学习系统中包括了一个或多个云服务器、一个或多个边服务器以及一个或者多个客户端,形成云服务器-边服务器-客户端三层架构。
在该系统中,一个或者多个边服务器接入云服务器,一个或者多个客户端接入边服务器。
在进行联邦学习的过程中,云服务器将本地保存的全局模型下发给边服务器,然后边服务器将该全局模型下发给与其连接的客户端。
一种可能实现的方式中,客户端使用本地存储的训练样本对接收到的全局模型进行训练,并将训练后的全局模型反馈给边服务器,边服务器对接收到的训练后的全局模型进行剪枝,并用剪枝后的全局模型对本地存储的全局模型进行更新,并将边服务器更新后的全局模型再反馈至云服务器,完成联邦学习。
另一种可能实现的方式中,客户端使用本地存储的训练样本对接收到的全局模型进行训练与剪枝,并将训练与剪枝后的全局模型反馈给边服务器,边服务器根据接收到的训练后的全局模型对本地存储的全局模型进行更新,并将边服务器更新后的全局模型再反馈至云服务器,完成联邦学习。
一种可能实现的方式中,客户端使用本地存储的训练样本对接收到的全局模型进行训练,并将训练后的全局模型反馈给边服务器,边服务器根据接收到的训练后的全局模型对本地存储的全局模型进行更新,并将边服务器更新后的全局模型再反馈至云服务器,云服务器再对接收到的全局模型进行剪枝,得到剪枝后的全局模型,完成联邦学习。
模型剪枝的过程可以在客户端,也可以在边服务器,还可以在云服务器,除了上面举例的只在客户端或边服务器或云服务器进行模型剪枝的方式,在联邦学习的过程中,也可以在多个环节都进行剪枝的过程,例如客户端训练模型时进行剪枝后再发送给边服务器,边服务器聚合模型时也进行剪枝,再发给云服务器处理,具体此处不做限定。
三、三层以上架构
如图5所示,本申请提供的另一种联邦学习系统的结构示意图。
其中,该联邦学习系统中包括了三层以上内的架构,其中一层包括一个或者多个云服务器,多个边服务器形成两层或者两层以上的架构,如一个或者多个上游边服务器组成一层架构,每个上游边服务器与一个或者多个下游边服务器连接。边服务器形成的最后一层架构中的每个边服务器和一个或者多个客户端连接,从而客户端形成一层架构。
在联邦学习的过程中,最上游的云服务器将本地存储的最新的全局模型下发给下一层的边服务器,随后边服务器向下一层逐层下发全局模型,直至下发至客户端。客户端在接收到边服务器下发的全局模型之后,使用本地存储的训练样本对接收到的全局模型进行训练与剪枝,并将训练与剪枝后的全局模型反馈给上一层的边服务器,然后上一层边服务器基于接收到的训练后的全局模型对本地存储的全局模型进行更新之后,即可将更新后的全局模型上传至更上一层的边服务器,以此类推,直到第二层边服务器将更新后的全局模型上传至云服务器,云服务器基于接收到的全局模型更新本地的全局模型,以得到最终的全局模型,完成联邦学习。可以理解的是,这里仅以客户端对模型进行训练与剪枝为例进行说明,与上述三层架构类似,剪枝的过程可以在联邦学习系统中的任意一层,具体此处不做限定。
需要说明的是,在本申请中,针对联邦学习架构中的每个设备,将向云服务器传输数据的方向称为上游,将向客户端传输数据的方向称为下游,例如,如图3中所示,服务器是客户端的上游设备,客户端是服务器的下游设备,如图4所示,云服务器可以称为边服务器的上游设备,客户端可以称为边服务器的下游设备等,以此类推。
另外,对本申请实施例中的模型(例如:神经网络模型、第一模型、第二模型、第三模型等等)所应用的场景做下简单介绍,该模型可以应用于前述的智能终端、智能交通、智能医疗、自动驾驶、智慧城市等任何需要神经网络模型对文本、图像或语音等输入数据进行分类、识别、预测、推荐、翻译、编码、解码等场景。
前述对本申请提供的联邦学习系统以及模型的应用场景进行了介绍,下面对该联邦学习系统中各个设备执行的详细步骤进行介绍。
本申请实施例中,在联邦学习系统下,根据剪枝步骤是由上游设备(云服务器或边服务器)还是由客户端执行可以分为两种情况,下面分别描述:
第一种,上游设备执行剪枝步骤。
参阅图6,本申请实施例提供的联邦学习方法一个实施例,该实施例包括步骤601至步骤608。
步骤601,上游设备向客户端发送神经网络模型。相应的,客户端接收上游设备发送的神经网络模型。
本申请实施例中的上游设备可以是前述图2-图5中的联邦学习系统中的服务器。例如,该上游设备可以是如前述图2中所示出的多个服务器中的任意一个服务器,也可以是前述图3中所示出的两层架构中的任意一个服务器,也可以是如图4中所示的云服务器或者边服务器中的任意一个,还可以是如图5中所述示出的云服务器或者边服务器中的任意一个。该客户端的数量可以是一个或者多个,若上游设备与多个客户端建立了连接,则上游设备可以向每个客户端发送神经网络模型。
其中,上述的神经网络模型可以是上游设备本地存储的模型,如云服务器本地存储的全局模型,或者上游设备可以接收到其他服务器发送的模型之后,将接收到的模型保存在本地或者更新本地存储的模型。具体地,上游设备可以向客户端发送神经网络模型的结构参数(如神经网络的宽度、深度或者卷积核大小等)或者初始权重参数等,可选地,上游设备还可以向客户端发送训练配置参数,如学习率、epoch数量或者安全算法中类别等参数,以使最终进行训练的客户端可以使用该训练配置参数来对神经网络模型进行训练。
例如,当上游设备为云服务器时,该神经网络模型可以是云服务器上保存的全局模型,为便于区分,以下将云服务器上保存的全局模型称为云侧模型。
又例如,当该上游设备是边服务器时,该神经网络模型可以是边服务器上保存的本地模型,或者称为边服务器模型,在边服务器接收到上一层边服务器或者云服务器下发的模型之后,使用接收到的模型作为边服务器模型或更新已有的边服务器模型,以得到新的边服务器模型,并向客户端发送新的边服务器模型(即神经网络模型)。还需要说明的是,当上游设备是边服务器时,上游设备可以直接向客户端下发边服务器模型(或者称为神经网络模型)。
在本申请实施例中,所提及的神经网络模型,如第一模型、第二模型或者第三模型等等,具体可以包括卷积神经网络(convolutional neural networks,CNN),深度卷积神经网络(deep convolutional neural networks,DCNN),循环神经网络(recurrent neuralnetwork,RNN)等神经网络,具体可以根据实际应用场景确定待学习的模型,本申请对此并不作限定。
可选地,上游设备可以主动向与其连接的客户端发送神经网络模型,也可以是在客户端的请求下向客户端发送神经网络模型。例如,若上游设备为边服务器,客户端可以向边服务器发送请求消息,以请求参与联邦学习,边服务器在接收到请求消息之后,若确认允许该客户端参与联邦学习,则可以向该客户端下发神经网络模型。又例如,若上游设备为云服务器,客户端可以向云服务器发送请求消息,以请求参与联邦学习,云服务器在接收到该请求消息,并确认允许该客户端参与联邦学习,则可以将本地存储的云侧模型下发给边服务器,边服务器根据接收到的模型更新本地的网络模型得到神经网络模型,并将神经网络模型下发给客户端。
步骤602,客户端以训练数据为输入,根据第一损失函数训练神经网络模型,得到第一模型。
客户端在接收到上游设备发送的神经网络模型之后,即可基于该神经网络模型更新本地存储的端侧模型,例如通过替换、加权融合等方式得到新的神经网络模型。并使用带标签值的训练数据与第一损失函数对该新的神经网络模型进行训练,从而得到第一模型。
本申请实施例中的训练数据可以有多种类型或形式,具体与模型所应用的场景相关。例如:当模型的作用是音频识别,则训练数据的具体形式可以是音频数据等。又例如:当模型的作用是图像分类,则训练数据的具体形式可以是图像数据等。再例如:当模型的作用是预测语音,则训练数据的具体形式可以是文本数据等。可以理解的是,上述几种情况只是举例,并且并不一定是一一对应的关系,例如对于音频识别,训练数据的具体形式还可以是图像数据或文本数据等(例如:若应用于教育领域中的看图播放语音场景,则模型的作用是识别图像对应的语音,则训练数据的具体形式可以是图像数据),在实际应用中,还有其他的场景,例如:当模型的作用的电影推荐场景,则训练数据可以是电影对应的词向量等。在一些应用场景,上述训练数据还可以同时包括不同模态的数据,比如在自动驾驶场景,训练数据可以包括摄像头采集的图像/视频数据,还可以包括用户发出指示的语音/文本数据等。本申请实施例中对于训练数据的具体形式或类型不做限定。
客户端可以使用本地保存的训练样本对神经网络模型(或者是上述的新的神经网络模型)进行训练,以得到第一模型。例如,客户端可以部署于移动终端中,该移动终端在运行过程中可以采集到大量数据,客户端可以将采集到的数据作为训练样本,从而对神经网络模型进行个性化的训练,以得到客户端的个性化模型。
其中,客户端(以其中一个第一客户端为例)对神经网络模型进行训练的过程具体可以包括:以训练数据作为神经网络模型的输入,以减小第一损失函数的值为目标对神经网络模型进行训练,以得到第一模型。第一损失函数用于指示神经网络模型的输出值与标签值之间的差异。进一步的,使用训练数据与优化算法对神经网络进行训练,以得到第一模型。该第一损失函数可以理解为是数据损失函数(简称训练loss)。
本申请实施例中的第一损失函数可以是均方误差损失,也可以是交叉熵损失等可以用来衡量神经网络模型输出值与标签值(或真实值)之间差异的函数。
上述的优化算法可以是梯度下降方法,也可以是牛顿法,还可以是自适应矩估计法等可用于机器学习中的优化算法,具体此处不做限定,下面以梯度算法为例进行描述。
可选地,梯度算法的一种具体形式如下:
其中,vi与分别表示联邦学习第n轮训练过程中更新前与更新后的神经网络模型参数,γ是梯度下降优化算法的学习率或每一步更新的步长。f为第一损失函数(具体形式可以是上述的均方误差损失、交叉熵损失等)。为f的梯度,例如对f求导得到。
可选地,上述公式一只是一种梯度算法的举例,在实际应用中,梯度算法还可以是其他类型的公式,具体此处不做限定。
可选地,上述公式二只是一种计算梯度的举例,在实际应用中,计算梯度还可以是其他类型的公式,具体此处不做限定。
本申请实施例中的第一模型可以是上述训练过程中的神经网络模型,也可以是第一损失函数的值小于第一阈值后的第一模型,换句话说,第一模型可以是上述训练过程中的神经网络模型,也可以是基于客户端的本地数据集训练结束后得到的模型,具体此处不做限定。
步骤603,客户端向上游设备发送第一模型。相应的,上游设备接收客户端发送的第一模型。
可选地,客户端向上游设备发送第一模型或者是第一模型的信息,例如权重参数、梯度参数等。相应的,上游设备接收客户端发送的第一模型。
步骤604,上游设备基于第二损失函数以及约束条件对第一模型进行剪枝,以得到第二模型。
上游设备接收客户端发送的第一模型之后,可以基于第二损失函数与约束条件对第一模型进行剪枝,以得到第二模型。
可选地,在接收到客户端发送的第一模型之后,可以先确定第一模型的子结构,再对第一模型进行子结构上的剪枝。该子结构包括至少两个神经元,且该子结构可以根据实际需要设置,子结构可以是神经网络模型的通道、特征图、网络层、子网络、或者预定义的由多个神经元组成的其他网络结构;当神经网络模型是卷积神经网络时,子结构还可以是卷积核。总之,一个子结构可以看作一个功能整体,在剪枝时对子结构进行剪枝,是指将该子结构包括的所有神经元都进行剪枝。通过对模型子结构的剪枝,能够从模型结构上对模型进行压缩,便于底层硬件加速的实现。
上述的第二损失函数可以理解为是稀疏损失函数(简称稀疏loss),该第二损失函数包括差异项与第一稀疏项,其中,差异项用于表示第一模型与第二模型之间参数的差异。第一稀疏项用于将第一模型的多个子结构中的至少一个子结构进行剪枝。约束条件用于约束第二模型的精度不低于第一模型的精度,该精度指示模型的输出值与标签值之间的差异程度。
可选地,第二损失函数的一种具体形式如下:
其中,n是迭代的次数,n为正整数,为差异项,为第一稀疏项。|| ||2为L2范数,Vn为第一模型的参数,Wn为第二模型的参数,λ为超参数,用于调节第一稀疏项的权重,本申请实施例中的超参数可以取任意非负实数。例如,当λ=0时,表示稀疏项的权重为0,即训练过程中不要求子结构稀疏,通常可以适用于第一模型较小,传输通信成本较小,无需要求第一模型子结构稀疏的场景。为第二模型中第i个子结构。
可以理解的是,上述的第二损失函数只是一种举例,在实际应用中,还可以有其他形式的第二损失函数,例如,第一稀疏项中的L2范数可以更换为L0范数、L1范数、L0范数的近似、L1范数的近似、L0与Lp混合范数、L1与Lp混合范数等可以用于引导变量稀疏性的函数。又例如,差异项可以替换为欧氏距离、马氏距离、互信息、余弦相似度、内积或者范数等其他任何衡量两个变量相似度或距离的函数。差异项与第一稀疏项的选择具体可以根据实际应用场景适配,具体此处不做限定。
可选地,若客户端的数量为多个,即上游设备接收的第一模型的数量为多个,则上游设备可以用相同的稀疏损失函数与约束条件对多个第一模型分别进行剪枝,以得到多个第二模型,多个第二模型与多个第一模型一一对应;当然,上游设备也可以根据子结构的类型对多个第一模型进行分组,每组对应的稀疏损失函数可以相同或不同,然后将同一组的多个第一模型结合起来进行剪枝;或者,上游设备也可以将所接收到的所有第一模型结合起来一起进行剪枝。对于多个第一模型的剪枝方式,具体此处不做限定。下面以一个第一模型为例描述具体的剪枝过程。
本申请实施例中的第一模型的剪枝方向可以理解为是对第二损失函数计算得到的Wn的下降方向,第一模型的训练数据方向可以理解为是对第一损失函数计算得到Vn的下降方向。例如,对于梯度下降来说,第一模型的剪枝方向可以理解为是对第二损失函数求导得到的Wn的梯度方向,第一模型的训练数据方向可以理解为是对第一损失函数求导得到Vn的梯度方向。
本申请实施例中,基于第二损失函数与约束条件对第一模型进行剪枝的方式有多种,下面举例描述:
第一种,通过引入更新系数si的方式对第一模型进行剪枝。
一种可能实现的方式中,基于约束条件计算更新系数,该更新系数用于调整第一稀疏项的方向。使用更新系数更新第二损失函数中的第一稀疏项,以得到第三损失函数,该第三损失函数包括差异项与第二稀疏项,第二稀疏项基于更新系数与第一稀疏项更新得到。获取第三损失函数之后,基于第三损失函数对第一模型进行剪枝,以得到第二模型。具体的,可以根据约束条件确定子空间,该子空间内的第二模型精度与第一模型精度相同。
可选地,第三损失函数的一种形式具体如下:
其中,si为所述更新系数,通过调节si以满足所述约束条件。其余描述可参考前述关于第二损失函数的描述,具体此处不再赘述。
可选地,上述公式四只是第三损失函数的一种举例,在实际应用中可以根据如前述第二损失函数的描述所设置,具体此处不做限定。
为了更直观的理解si,下面结合附图进行描述。
示例性的,请参阅图7A,假设第一模型参数Vn包括三个子结构,分别为V1、V2以及V3,为了方便在图7A中示出剪枝方向,先对Vn进行分组,假设分为两组(或者理解为2个子结构):a与b。其中,a={1,2},b={3}。这样,对第一模型参数Vn进行剪枝,可以理解为是对Va和/或Vb进行剪枝。图7A以对Va进行剪枝为例进行描述,即对第一模型进行剪枝直至Va变为0。其中,第二损失函数的下降方向即图7A中校正前的剪枝方向。
上述举例用数学表达式表示如下:
Va=(V1,V2);Vb=(V3);
其中,E()可以理解为是组内归一化算子。另外,海森(Hessian)矩阵有很多近似0的特征值,在这些特征值对应的方向上对模型参数加扰动几乎不会改变模型的精度,如图7A所示的P0表示这些方向生成的子空间(图7A中以P0是平面为例),该平面内第一模型的精度与第二模型的精度相同。Π0表示投影到子空间P0的投影算子,sa为a组对应的si,计算方式可参考下述公式:
可以理解的是,上述计算si的公式只是一种示例,实际应用中,还可以有其他类型的公式,具体此处不做限定。
示例性的,以第一模型包括输入层、隐藏层1-3、输出层为例,展示剪枝前的模型(即第一模型)与剪枝后的模型(即第二模型)对比图,可以参考图7B,其中,子结构可以是神经网络模型的通道、特征图、网络层、子网络、或者预定义的由多个神经元组成的其他网络结构;当神经网络模型是卷积神经网络时,子结构还可以是卷积核。总之,一个子结构可以看作一个功能整体,在剪枝时对子结构进行剪枝,是指将该子结构包括的所有神经元都进行剪枝。图7B仅以一个子结构包括2个神经元为例进行描述。可以理解的是,一个子结构可以包括更多或更少的神经元。从图7B可以看出,剪枝后的模型相较于剪枝前的模型减少了两个子结构。当然,图7B只是为了更加直观的描述剪枝前后模型的变化,剪枝子结构的数量可以是一个或多个,具体此处不做限定。
另一种可能实现的方式中,还可以通过si调整Vb对Va进行剪枝,进而使得预测的训练数据方向V′n在进行剪枝后得到的方向为矫正后的剪枝方向,基于该方向对第一模型进行剪枝,对第一模型的精度影响较小。示例性的,请参阅图8。
第二种,基于约束条件校正第一模型的剪枝方向,得到校正后的剪枝方向,并基于校正后的剪枝方向对第一模型进行剪枝。
梯度方向梯度方向本申请实施例中基于约束条件校正第一模型的剪枝方向的方式有多种,下面分别描述:
1、基于约束条件确定第一模型的剪枝方向。
除了上述第一种方式中引入更新系数si之外,还可以通过下述的方式确定较优的第二模型(即剪枝后的模型):
可选地,假设是一边训练数据一边进行剪枝,则步骤602中梯度算法的另一种形式可以如下所示:
其中,n为迭代次数,Zn+1为第n组训练数据,且Zn+1∈Z,γ是梯度下降优化算法的学习率或每一步更新的步长。f为第一损失函数(具体形式可以是上述的均方误差损失、交叉熵损失等)。为f的梯度,例如对f求导得到。其余参数的解释可以参考前述步骤602中对于梯度算法的解释,此处不再赘述。g(n,γ)为调节剪枝方向的函数,c、μ是控制稀疏惩罚强度的两个超参数,μ∈(0.5,1],i表示第i个子结构,
换句话说,使用训练数据Z更新Vn得到第一模型Vn+1,并基于Vn+1更新Wn+1,再用Wn+1替换上述梯度算法中的Vn,从而不断对第一模型进行剪枝直至得到满足实际需要的第二模型(例如迭代次数达到阈值或第二模型精度/准确度达到阈值等)。
可选地,上述公式六与公式七只是解较优/最优第二模型的一种举例,在实际应用中可以有其他方式,具体此处不做限定。例如:上述的公式七可以替换为下述公式八:
其中,()+表示只取其中大于0的数,小于0的数置零。
2、确定第一模型精度与第二模型精度一致的子空间,基于子空间校正第一模型的剪枝方向。
基于约束条件确定第一模型精度与第二模型精度相同(或一致)的子空间(如上述描述的P0,此处不再赘述),并基于子空间校正第一模型的剪枝方向。
本申请实施例中,基于子空间校正第一模型的剪枝方向的方法的方式有多种,下面分别描述:
2.1、将第一模型的剪枝方向投影到子空间得到校正后的剪枝方向。
可选地,可以将第一模型的剪枝方向投影到子空间得到校正后的剪枝方向,第一模型根据校正后的剪枝方向进行剪枝,可以保证剪枝前的第一模型精度与剪枝后的第二模型精度相同。
示例性的,如图7A所示,确定子空间P0之后,将校正前的剪枝方向投影至子空间P0,从而得到如图7A所示的校正后的剪枝方向。
2.2、第一模型的剪枝方向基于子空间做镜像得到校正后的剪枝方向。
可选地,确定子空间后,可以基于子空间做镜像得到校正后的剪枝方向,第一模型根据校正后的剪枝方向进行剪枝,可以保证剪枝前的第一模型精度与剪枝后的第二模型精度相近。或者说第二模型的第一损失函数的值小于第一模型的第一损失函数的值。
示例性的,如图9所示,确定子空间P0之后,基于子空间P0对校正前的剪枝方向进行镜像处理,从而得到如图9所示镜像后的剪枝方向(即修正后的剪枝方向)。
3、若第一模型的剪枝方向与第一模型的训练数据方向之间的夹角为钝角,调整至锐角。
可选地,在对第一模型剪枝之前,可以先确定第一模型的训练数据方向与第一模型的剪枝方向之间的夹角,若该夹角为钝角(即剪枝方向与数据训练方向相反,这也是现有技术中对模型剪枝后需要对模型进行微调的原因),则调整第一模型的剪枝方向以满足校正后的剪枝方向与第一模型的训练数据方向之间的夹角为锐角或直角。再根据校正后的剪枝方向对第一模型进行剪枝,可以保证剪枝前的第一模型精度与剪枝后的第二模型精度相近。或者说第二模型的第一损失函数的值小于或等于第一模型的第一损失函数的值。
示例性的,如图10所示,校正前的剪枝方向可能与第一模型的数据训练方向之间的夹角为钝角,即两个方向不一致。为了保证剪枝方向与数据训练方向在一个大的方向上一致,减少后续对模型进行微调的步骤。将剪枝方向调整到可以与第一模型的数据训练方向之间夹角为锐角或直角的方向范围内。
可以理解的是,基于第二损失函数与约束条件对第一模型进行剪枝的方式有多种,上述几种只是举例说明,在实际应用中,基于第二损失函数与约束条件对第一模型进行剪枝还可以有其他方式,例如:可以重复随机优化第一模型,直至剪枝后的第二模型满足约束条件等,具体此处不做限定。
步骤605,上游设备聚合多个第二模型,得到第三模型。本步骤是可选地。
上游设备基于步骤605获取多个第二模型之后,可以聚合多个第二模型,得到第三模型。并将第三模型作为全局模型。
可选地,将第三模型作为步骤601中的神经网络模型重复执行步骤601至步骤605。换句话说,步骤601至步骤605算一次迭代,本申请实施例中的步骤601至步骤605可以执行多次。进一步的,若步骤601至步骤605循环执行,可以设置步骤601至步骤605的停止条件(也可以理解为是剪枝更新的停止条件),该停止条件可以是循环次数、剪枝后模型的稀疏程度达到某个阈值等,具体此处不做限定。
本申请实施例中,上述聚合的方式可以是求多个第二模型的加权平均,也可以是求多个第二模型的平均等,具体此处不做限定。
可选地,若每个客户端的模型的数据是同分布的,则聚合得到的第三模型的精度高于第一模型的精度,若每个客户端的模型的数据是非同步的,则聚合得到的第三模型的精度可能低于第一模型的精度。
可选地,上游设备聚合多个第二模型得到第三模型之后,还可以通过损失函数与约束条件对第三模型进行剪枝,得到第四模型。若上游设备为云服务器,则上游设备可以向边服务器或客户端发送第四模型。若上游设备为边服务器,则边服务器可以向云服务发送第四模型,以便于上层服务器对第四模型进行处理。
步骤606,是否满足第一预设条件,若是,训练结束。若否,以第三模型作为神经网络模型重复执行前述步骤。本步骤是可选地。
可选地,上游设备获取第三模型之后,可以判断是否满足第一预设条件,若是(满足),则模型的训练过程结束。训练过程结束后,可以执行步骤607与步骤608,或者可以向上游的边服务器或云服务器发送第三模型。本申请实施例中对于训练结束后执行的步骤不做限定。
若否(不满足),将第三模型作为步骤601中的神经网络模型重复执行图6所示步骤601至步骤606的步骤(或者理解为一次迭代)。即上游设备向客户端发送第三模型,客户端使用本地数据集训练第三模型得到第四模型,并向上游设备发送第四模型。上游设备接收多个客户端发送的多个第四模型,上游设备基于稀疏损失函数与约束条件对多个第四模型进行剪枝,得到多个第五模型。并聚合多个第五模型以得到第六模型,再判断是否满足第一预设条件,若满足,训练结束。若不满足,将第六模型作为上述第一次迭代的神经网络模型或第二次迭代的第三模型重复执行图6所示的步骤,直至满足第一预设条件。
可选地,若不满足第一预设条件,迭代过程中,上游设备还可以向客户端发送用于指示训练未结束的第一指示信息,以便于客户端根据第一指示信息确定是继续训练模型。
其中,第一预设条件可以是第三模型收敛、步骤601至步骤605的循环次数达到阈值、全局模型准确率达到阈值等,具体此处不做限定。
步骤607,上游设备向客户端发送第三模型与第二指示信息。本步骤是可选地,
可选地,若满足第一预设条件,上游设备向客户端发送第三模型与第二指示信息,该第二指示信息用于指示第三模型的训练过程结束。
步骤608,客户端根据第二指示信息使用第三模型进行推理。本步骤是可选地。
可选地,客户端接收到第三模型与第二指示信息之后,根据第二指示信息可以获知第三模型的训练过程已结束,并使用第三模型进行推理。
本实施例中,在联邦学习场景下,客户端通过本地数据对神经网络模型进行训练得到第一模型,并向上游设备发送该第一模型。上游设备再根据约束条件对第一模型进行剪枝。一方面,上游设备在剪枝过程中考虑约束条件,使得剪枝后的第二模型精度高于或等于第一模型,也可以理解为是不会增加训练损失的剪枝,减少后续通过微调调整模型精度的步骤,从而在保证剪枝后模型的精度的同时提升模型剪枝过程的效率。另一方面,通过对模型子结构的剪枝,能够从模型结构上对模型进行压缩,便于底层硬件加速的实现。并且减小了模型体积,降低了客户端的存储和计算开销。
第二种,客户端执行剪枝步骤。
参阅图11,本申请实施例提供的联邦学习方法的另一个实施例,该实施例包括步骤1101至步骤1108。
步骤1101,上游设备向客户端发送神经网络模型。相应的,客户端接收上游设备发送的神经网络模型。
步骤1102,客户端以训练数据为输入,根据第一损失函数训练神经网络模型,得到第一模型。
本实施例中的步骤1101与步骤1102与前述图6所示实施例中的步骤601与步骤602类似,此处不再赘述。
步骤1103,客户端基于第二损失函数以及约束条件对第一模型进行剪枝,以得到第二模型。
本实施例中客户端执行的步骤1103与前述图6所示实施例中上游设备执行的步骤604类似,此处不再赘述。
可选地,可以将步骤1102与步骤1103视为一次迭代过程。客户端获取第二模型之后,可以判断是否满足第二预设条件,若是(满足),则执行步骤1104。若否(不满足),将第三模型作为步骤1102中的神经网络模型重复执行步骤1102与步骤1103(或者理解为一次迭代)。即以训练数据为输入,根据第一损失函数训练第三模型,得到训练后的第三模型。并基于约束条件对训练后的第三模型进行剪枝得到第七模型,再判断是否满足第一预设条件,若满足,执行步骤1104。若不满足,将第七模型作为上述第一次迭代的神经网络模型或第二次迭代的第三模型重复执行步骤1102与步骤1103,直至满足第二预设条件。
其中,第二预设条件可以是模型收敛、步骤1102与步骤1103的循环次数达到阈值、模型准确率达到阈值等,具体此处不做限定。
可选地,也可以将步骤1103视为一次迭代过程。客户端获取第二模型之后,可以判断是否满足第三预设条件,若是(满足),则执行步骤1104。若否(不满足),将第三模型作为步骤1102中的神经网络模型重复执行步骤1102,直至满足第三预设条件。
其中,第三预设条件可以是模型收敛、步骤1102的循环次数达到阈值、模型准确率达到阈值等,具体此处不做限定。
步骤1104,客户端向上游设备发送第二模型。相应的,上游设备接收客户端发送的第二模型。
可选地,客户端向上游设备发送第二模型或者是第二模型的信息,例如权重参数、梯度参数等。相应的,上游设备接收客户端发送的第二模型。
步骤1105,上游设备聚合多个第二模型,得到第三模型。
本实施例中的步骤1105与前述图6所示实施例中的步骤605类似,此处不再赘述。
步骤1106,是否满足第一预设条件,若是,训练结束。若否,以第三模型作为神经网络模型重复执行前述步骤。本步骤是可选地。
可选地,上游设备获取第三模型之后,可以判断是否满足第一预设条件,若是(满足),则训练结束。训练过程结束后,可以执行步骤1107与步骤1108,或者可以向上游的边服务器或云服务器发送第三模型。本申请实施例中对于训练结束后执行的步骤不做限定。
若否(不满足),将第三模型作为步骤601中的神经网络模型重复执行图6所示的步骤(或者理解为一次迭代)。即上游设备向客户端发送第三模型,客户端使用本地数据集训练第三模型得到第四模型,并基于稀疏损失函数与约束条件对多个第四模型进行剪枝,得到多个第五模型。向上游设备发送第五模型。上游设备接收多个客户端发送的多个第五模型,并聚合多个第五模型以得到第六模型,再判断是否满足第一预设条件,若满足,训练结束。若不满足,将第六模型作为上述第一次迭代的神经网络模型或第二次迭代的第三模型重复执行图6所示的步骤,直至满足第一预设条件。
可选地,若不满足第一预设条件,迭代过程中,上游设备还可以向客户端发送用于指示训练未结束的第一指示信息,以便于客户端根据第一指示信息确定是继续训练模型。
其中,第一预设条件可以是第三模型收敛、步骤1101至步骤1105的循环次数达到阈值、全局模型准确率达到阈值等,具体此处不做限定。
步骤1107,上游设备向客户端发送第二指示信息。本步骤是可选地。
步骤1108,客户端根据第二指示信息使用第三模型进行推理。本步骤是可选地。
本实施例中的步骤1107、步骤1108与前述图6所示实施例中的步骤607、步骤608类似,此处不再赘述。
该实施例与前述图6所示的实施例的区别主要是,图6所示实施例中的剪枝步骤由上游设备执行,本实施例中的剪枝步骤由客户端执行。
本实施例中,在联邦学习场景下,客户端通过本地数据对神经网络模型进行训练得到第一模型,再根据约束条件对第一模型进行剪枝得到第二模型,并向上游设备发送第二模型,上游设备根据第二模型进行聚合得到全局模型。一方面,上游设备在剪枝过程中考虑约束条件,使得剪枝后的第二模型精度高于或等于第一模型,也可以理解为是不会增加训练损失的剪枝,减少后续通过微调调整模型精度的步骤,从而在保证剪枝后模型的精度的同时提升模型剪枝过程的效率。另一方面,通过对模型子结构的剪枝,能够从模型结构上对模型进行压缩,便于底层硬件加速的实现。并且减小了模型体积,降低了客户端的存储和计算开销。
上述对本申请实施例中的方法应用于联邦学习场景进行了描述,下面对本申请实施例还提供的一种模型处理方法进行描述。请参阅图12,本申请实施例提供的模型处理方法一个实施例,该方法可以由模型处理设备(例如客户端)执行,也可以由模型处理设备的部件(例如处理器、芯片、或芯片系统等)执行,其中,模型处理设备可以是云服务器或客户端,该实施例包括步骤1201至步骤1203。
步骤1201,获取包括标签值的训练数据。
本申请实施例中基于标签值的训练数据可以存储在服务器等其他设备中,模型处理设备通过服务器等其他设备获取该训练数据。也可以通过模型处理设备在运行过程中采集得到,具体此处不做限定。
步骤1202,以训练数据为输入,根据第一损失函数训练神经网络模型,得到第一模型。
本实施例中模型处理设备执行的步骤1202与前述图6所示实施例中客户端执行的步骤602类似,此处不再赘述。
步骤1203,基于第二损失函数以及约束条件对第一模型进行剪枝,以得到第二模型。
本实施例中模型处理设备执行的步骤1203与前述图6所示实施例中上游设备执行的步骤604类似,此处不再赘述。
本实施例中,模型处理设备通过本地数据对神经网络模型进行训练得到第一模型,再根据约束条件对第一模型进行剪枝得到第二模型。模型处理设备在剪枝过程中考虑约束条件,使得剪枝后的第二模型精度高于或等于第一模型,也可以理解为是不会增加训练损失的剪枝,减少后续模型微调的过程。
上面对本申请实施例中的模型处理方法与联邦学习方法进行了描述,下面对本申请实施例中的模型处理设备与上游设备进行描述,请参阅图13,本申请实施例中模型处理设备的一个实施例包括:
获取单元1301,用于获取包括标签值的训练数据;
训练单元1302,用于以训练数据为输入,根据第一损失函数训练神经网络模型,得到第一模型,第一模型包括多个子结构,多个子结构中的每个子结构包括至少两个神经元;
剪枝单元1303,用于基于第二损失函数以及约束条件对第一模型进行剪枝,以得到第二模型,第二损失函数用于指示将多个子结构中的至少一个子结构进行剪枝,约束条件用于约束第二模型的精度不低于第一模型的精度,精度指示模型的输出值与标签值之间的差异程度。
可选地,模型处理设备还包括:
接收单元1304,用于接收上游设备发送的神经网络模型;
发送单元1305,用于向上游设备发送第二模型。
本实施例中,模型处理设备中各单元所执行的操作与前述图2至图5或图11中客户端或图12所示实施例中模型处理设备执行的步骤、相关描述类似,此处不再赘述。
本实施例中,剪枝单元1303在对第一模型进行剪枝的过程中,考虑基于数据损失函数的约束条件,相当于为第一模型的剪枝提供一个方向,使得剪枝得到的第二模型的精度不低于第一模型的精度,减少后续通过微调调整模型精度的步骤,从而在保证剪枝后模型的精度的同时提升模型剪枝过程的效率。
请参阅图14,本申请实施例中上游设备的一个实施例,该上游设备可以是前述的云服务器,也可以是前述的边服务器,该上游设备包括:
发送单元1401,用于向多个下游设备发送神经网络模型,神经网络模型包括多个子结构,多个子结构中的每个子结构包括至少两个神经元;
接收单元1402,用于接收来自多个下游设备的多个第一模型,多个第一模型由神经网络模型训练得到;
剪枝单元1403,用于基于损失函数以及约束条件对多个第一模型分别进行剪枝,其中,损失函数用于指示对多个第一模型的子结构进行剪枝,约束条件用于约束每个第一模型剪枝后的精度不低于剪枝前的精度;
聚合单元1404,用于将剪枝后的多个第一模型进行聚合,以得到第二模型。
本实施例中,上游设备中各单元所执行的操作与前述图2至图11所示实施例中云服务器或边服务器执行的步骤、相关描述类似,此处不再赘述。
本实施例中,在联邦学习场景下,客户端通过本地数据对神经网络模型进行训练得到第一模型,并向上游设备发送该第一模型。剪枝单元1403在根据约束条件对第一模型进行剪枝。一方面,剪枝单元1403在剪枝过程中考虑约束条件,使得剪枝前后第一模型的精度近似,也可以理解为是不会增加训练损失的剪枝,减少后续通过微调调整模型精度的步骤,从而在保证剪枝后模型的精度的同时提升模型剪枝过程的效率。另一方面,剪枝单元1403通过对模型子结构的剪枝,能够从模型结构上对模型进行压缩,便于底层硬件加速的实现。并且减小了模型体积,降低了客户端的存储和计算开销。
请参阅图15,本申请实施例中上游设备的另一个实施例,该上游设备可以是前述的云服务器,也可以是前述的边服务器,该上游设备包括:
发送单元1501,用于向多个下游设备发送神经网络模型,所述神经网络模型包括多个子结构,多个子结构中的每个子结构包括至少两个神经元;
接收单元1502,用于接收来自所述多个下游设备的多个第一模型,所述多个第一模型由所述神经网络模型训练得到;
聚合单元1503,用于将所述多个第一模型进行聚合,以得到第二模型;
剪枝单元1504,用于基于损失函数以及约束条件对所述第二模型进行剪枝,其中,所述损失函数用于指示对所述第二模型的所述子结构进行剪枝,所述约束条件用于约束所述第二模型剪枝后的精度不低于剪枝前的精度。
本实施例中,上游设备中各单元所执行的操作与前述图2至图11所示实施例中云服务器或边服务器执行的步骤、相关描述类似,此处不再赘述。
本实施例中,在联邦学习场景下,客户端通过本地数据对神经网络模型进行训练得到第一模型,并向上游设备发送该第一模型。剪枝单元1504在根据约束条件对第二模型进行剪枝。一方面,剪枝单元1504在剪枝过程中考虑约束条件,使得剪枝前后的第二模型精度近似,也可以理解为是不会增加训练损失的剪枝,减少后续通过微调调整模型精度的步骤,从而在保证剪枝后模型的精度的同时提升模型剪枝过程的效率。另一方面,剪枝单元1504通过对模型子结构的剪枝,能够从模型结构上对模型进行压缩,便于底层硬件加速的实现。并且减小了模型体积,降低了客户端的存储和计算开销。
本申请实施例还提供了一种模型处理设备,如图16所示,为了便于说明,仅示出了与本申请实施例相关的部分,具体技术细节未揭示的,请参照本申请实施例方法部分(即前述图2至图11中客户端或图12所示实施例中模型处理设备执行的步骤与相关描述类似)。该模型处理设备可以为包括手机、平板电脑等任意终端设备,以模型处理设备是客户端,客户端是手机为例:
图16示出的是与本申请实施例提供的模型处理设备-手机的部分结构的框图。参考图16,手机包括:射频(radio frequency,RF)电路1610、存储器1620、输入单元1630、显示单元1640、传感器1650、音频电路1660、无线保真(wireless fidelity,WiFi)模块1670、处理器1680、以及电源1690等部件。本领域技术人员可以理解,图16中示出的手机结构并不构成对手机的限定,可以包括比图示更多或更少的部件,或者组合某些部件,或者不同的部件布置。
下面结合图16对手机的各个构成部件进行具体的介绍:
RF电路1610可用于收发信息或通话过程中,信号的接收和发送,特别地,将基站的下行信息接收后,给处理器1680处理;另外,将设计上行的数据发送给基站。通常,RF电路1610包括但不限于天线、至少一个放大器、收发信机、耦合器、低噪声放大器(low noiseamplifier,LNA)、双工器等。此外,RF电路1610还可以通过无线通信与网络和其他设备通信。上述无线通信可以使用任一通信标准或协议,包括但不限于全球移动通讯系统(globalsystem of mobile communication,GSM)、通用分组无线服务(general packet radioservice,GPRS)、码分多址(code division multiple access,CDMA)、宽带码分多址(wideband code division multiple access,WCDMA)、长期演进(long term evolution,LTE)、电子邮件、短消息服务(short messaging service,SMS)等。
存储器1620可用于存储软件程序以及模块,处理器1680通过运行存储在存储器1620的软件程序以及模块,从而执行手机的各种功能应用以及数据处理。存储器1620可主要包括存储程序区和存储数据区,其中,存储程序区可存储操作系统、至少一个功能所需的应用程序(比如声音播放功能、图像播放功能等)等;存储数据区可存储根据手机的使用所创建的数据(比如音频数据、电话本等)等。此外,存储器1620可以包括高速随机存取存储器,还可以包括非易失性存储器,例如至少一个磁盘存储器件、闪存器件、或其他易失性固态存储器件。
输入单元1630可用于接收输入的数字或字符信息,以及产生与手机的用户设置以及功能控制有关的键信号输入。具体地,输入单元1630可包括触控面板1631以及其他输入设备1632。触控面板1631,也称为触摸屏,可收集用户在其上或附近的触摸操作(比如用户使用手指、触笔等任何适合的物体或附件在触控面板1631上或在触控面板1631附近的操作),并根据预先设定的程式驱动相应的连接装置。可选的,触控面板1631可包括触摸检测装置和触摸控制器两个部分。其中,触摸检测装置检测用户的触摸方位,并检测触摸操作带来的信号,将信号传送给触摸控制器;触摸控制器从触摸检测装置上接收触摸信息,并将它转换成触点坐标,再送给处理器1680,并能接收处理器1680发来的命令并加以执行。此外,可以采用电阻式、电容式、红外线以及表面声波等多种类型实现触控面板1631。除了触控面板1631,输入单元1630还可以包括其他输入设备1632。具体地,其他输入设备1632可以包括但不限于物理键盘、功能键(比如音量控制按键、开关按键等)、轨迹球、鼠标、操作杆等中的一种或多种。
显示单元1640可用于显示由用户输入的信息或提供给用户的信息以及手机的各种菜单。显示单元1640可包括显示面板1641,可选的,可以采用液晶显示器(liquidcrystal display,LCD)、有机发光二极管(organic light-emitting diode,OLED)等形式来配置显示面板1641。进一步的,触控面板1631可覆盖显示面板1641,当触控面板1631检测到在其上或附近的触摸操作后,传送给处理器1680以确定触摸事件的类型,随后处理器1680根据触摸事件的类型在显示面板1641上提供相应的视觉输出。虽然在图16中,触控面板1631与显示面板1641是作为两个独立的部件来实现手机的输入和输入功能,但是在某些实施例中,可以将触控面板1631与显示面板1641集成而实现手机的输入和输出功能。
手机还可包括至少一种传感器1650,比如光传感器、运动传感器以及其他传感器。具体地,光传感器可包括环境光传感器及接近传感器,其中,环境光传感器可根据环境光线的明暗来调节显示面板1641的亮度,接近传感器可在手机移动到耳边时,关闭显示面板1641和/或背光。作为运动传感器的一种,加速计传感器可检测各个方向上(一般为三轴)加速度的大小,静止时可检测出重力的大小及方向,可用于识别手机姿态的应用(比如横竖屏切换、相关游戏、磁力计姿态校准)、振动识别相关功能(比如计步器、敲击)等;至于手机还可配置的陀螺仪、气压计、湿度计、温度计、红外线、IMU、SLAM传感器等其他传感器,在此不再赘述。
音频电路1660、扬声器1661,传声器1662可提供用户与手机之间的音频接口。音频电路1660可将接收到的音频数据转换后的电信号,传输到扬声器1661,由扬声器1661转换为声音信号输出;另一方面,传声器1662将收集的声音信号转换为电信号,由音频电路1660接收后转换为音频数据,再将音频数据输出处理器1680处理后,经RF电路1610以发送给比如另一手机,或者将音频数据输出至存储器1620以便进一步处理。
WiFi属于短距离无线传输技术,手机通过WiFi模块1670可以帮助用户收发电子邮件、浏览网页和访问流式媒体等,它为用户提供了无线的宽带互联网访问。虽然图16示出了WiFi模块1670,但是可以理解的是,其并不属于手机的必须构成。
处理器1680是手机的控制中心,利用各种接口和线路连接整个手机的各个部分,通过运行或执行存储在存储器1620内的软件程序和/或模块,以及调用存储在存储器1620内的数据,执行手机的各种功能和处理数据,从而对手机进行整体监控。可选的,处理器1680可包括一个或多个处理单元;优选的,处理器1680可集成应用处理器和调制解调处理器,其中,应用处理器主要处理操作系统、用户界面和应用程序等,调制解调处理器主要处理无线通信。可以理解的是,上述调制解调处理器也可以不集成到处理器1680中。
手机还包括给各个部件供电的电源1690(比如电池),优选的,电源可以通过电源管理系统与处理器1680逻辑相连,从而通过电源管理系统实现管理充电、放电、以及功耗管理等功能。
尽管未示出,手机还可以包括摄像头、蓝牙模块等,在此不再赘述。
在本申请实施例中,该手机所包括的处理器1680可以执行前述图2至图5、图11中客户端或图12所示实施例中模型处理设备的功能,此处不再赘述。
参阅图17,本申请提供的另一种上游设备的结构示意图。该上游设备可以包括处理器1701、存储器1702和通信接口1703。该处理器1701、存储器1702和通信接口1703通过线路互联。其中,存储器1702中存储有程序指令和数据。
存储器1702中存储了前述图2至图6或图11所示实施例中云服务器或边服务器执行的步骤对应的程序指令以及数据。
处理器1701,用于执行前述图2至图6或图11所示实施例中云服务器或边服务器执行的步骤。
通信接口1703可以用于进行数据的接收和发送,用于执行前述图2至图6或图11所示实施例中云服务器或边服务器与获取、发送、接收相关的步骤。
一种实现方式中,上游设备可以包括相对于图17更多或更少的部件,本申请对此仅仅是示例性说明,并不作限定。
在本申请所提供的几个实施例中,应该理解到,所揭露的系统,装置和方法,可以通过其它的方式实现。例如,以上所描述的装置实施例仅仅是示意性的,例如,单元的划分,仅仅为一种逻辑功能划分,实际实现时可以有另外的划分方式,例如多个单元或组件可以结合或者可以集成到另一个系统,或一些特征可以忽略,或不执行。另一点,所显示或讨论的相互之间的耦合或直接耦合或通信连接可以是通过一些接口,装置或单元的间接耦合或通信连接,可以是电性,机械或其它的形式。
所述作为分离部件说明的单元可以是或者也可以不是物理上分开的,作为单元显示的部件可以是或者也可以不是物理单元,即可以位于一个地方,或者也可以分布到多个网络单元上。可以根据实际的需要选择其中的部分或者全部单元来实现本实施例方案的目的。
另外,在本申请各个实施例中的各功能单元可以集成在一个处理单元中,也可以是各个单元单独物理存在,也可以两个或两个以上单元集成在一个单元中。上述集成的单元可以全部或部分地通过软件、硬件、固件或者其任意组合来实现。
当使用软件实现所述集成的单元时,可以全部或部分地以计算机程序产品的形式实现。所述计算机程序产品包括一个或多个计算机指令。在计算机上加载和执行所述计算机程序指令时,全部或部分地产生按照本发明实施例所述的流程或功能。所述计算机可以是通用计算机、专用计算机、计算机网络、或者其他可编程装置。所述计算机指令可以存储在计算机可读存储介质中,或者从一个计算机可读存储介质向另一个计算机可读存储介质传输,例如,所述计算机指令可以从一个网站站点、计算机、服务器或数据中心通过有线(例如同轴电缆、光纤、数字用户线(digital subscriber line,DSL))或无线(例如红外、无线、微波等)方式向另一个网站站点、计算机、服务器或数据中心进行传输。所述计算机可读存储介质可以是计算机能够存取的任何可用介质或者是包含一个或多个可用介质集成的服务器、数据中心等数据存储设备。所述可用介质可以是磁性介质,(例如,软盘、硬盘、磁带)、光介质(例如,DVD)、或者半导体介质(例如固态硬盘(solid state disk,SSD))等。
本申请的说明书和权利要求书及上述附图中的术语“第一”、“第二”等是用于区别类似的对象,而不必用于描述特定的顺序或先后次序。应该理解这样使用的术语在适当情况下可以互换,这仅仅是描述本申请的实施例中对相同属性的对象在描述时所采用的区分方式。此外,术语“包括”和“具有”以及他们的任何变形,意图在于覆盖不排他的包含,以便包含一系列单元的过程、方法、系统、产品或设备不必限于那些单元,而是可包括没有清楚地列出的或对于这些过程、方法、产品或设备固有的其它单元。
Claims (24)
1.一种模型处理方法,其特征在于,所述方法包括:
获取包括标签值的训练数据;
以所述训练数据为输入,根据第一损失函数训练神经网络模型,得到第一模型,所述第一模型包括多个子结构,所述多个子结构中的每个子结构包括至少两个神经元;
基于第二损失函数以及约束条件对所述第一模型进行剪枝,以得到第二模型,所述第二损失函数用于指示将所述多个子结构中的至少一个子结构进行剪枝,所述约束条件用于约束所述第二模型的精度不低于所述第一模型的精度,所述精度指示模型的输出值与所述标签值之间的差异程度。
2.根据权利要求1所述的方法,其特征在于,所述约束条件具体用于约束所述第一损失函数的下降方向与所述第二损失函数的下降方向之间的夹角小于或等于90度。
3.根据权利要求1所述的方法,其特征在于,所述约束条件具体用于约束所述第二模型的所述第一损失函数的值小于或等于所述第一模型的所述第一损失函数的值。
4.根据权利要求1至3中任一项所述的方法,其特征在于,所述第二损失函数包括第一稀疏项,所述第一稀疏项与所述多个子结构中的至少一个子结构的权重相关。
5.根据权利要求1至4中任一项所述的方法,其特征在于,所述基于第二损失函数以及约束条件对所述第一模型进行剪枝,以得到第二模型,包括:
基于所述第二损失函数对所述第一模型进行至少一次随机剪枝,直至剪枝所述第一模型后得到的所述第二模型满足所述约束条件。
6.根据权利要求1至5中任一项所述的方法,其特征在于,所述方法应用于客户端,所述训练数据是所述客户端本地的数据,所述方法还包括:
接收上游设备发送的所述神经网络模型;
向所述上游设备发送所述第二模型。
7.根据权利要求1至6中任一项所述的方法,其特征在于,所述训练数据包括:图像数据,音频数据或者文本数据。
8.根据权利要求1至7中任一项所述的方法,其特征在于,所述神经网络模型用于对图像数据进行分类和/或识别。
9.一种联邦学习方法,其特征在于,所述方法包括:
向多个下游设备发送神经网络模型,所述神经网络模型包括多个子结构,所述多个子结构中的每个子结构包括至少两个神经元;
接收来自所述多个下游设备的多个第一模型,所述多个第一模型由所述神经网络模型训练得到;
基于损失函数以及约束条件对所述多个第一模型分别进行剪枝,其中,所述损失函数用于指示对所述多个第一模型的所述子结构进行剪枝,所述约束条件用于约束每个所述第一模型剪枝后的精度不低于剪枝前的精度;
将剪枝后的所述多个第一模型进行聚合,以得到第二模型。
10.一种联邦学习方法,其特征在于,所述方法包括:
向多个下游设备发送神经网络模型,所述神经网络模型包括多个子结构,所述多个子结构中的每个子结构包括至少两个神经元;
接收来自所述多个下游设备的多个第一模型,所述多个第一模型由所述神经网络模型训练得到;
将所述多个第一模型进行聚合,以得到第二模型;
基于损失函数以及约束条件对所述第二模型进行剪枝,其中,所述损失函数用于指示对所述第二模型的所述子结构进行剪枝,所述约束条件用于约束所述第二模型剪枝后的精度不低于剪枝前的精度。
11.一种模型处理设备,其特征在于,所述设备包括:
获取单元,用于获取包括标签值的训练数据;
训练单元,用于以所述训练数据为输入,根据第一损失函数训练神经网络模型,得到第一模型,所述第一模型包括多个子结构,所述多个子结构中的每个子结构包括至少两个神经元;
剪枝单元,用于基于第二损失函数以及约束条件对所述第一模型进行剪枝,以得到第二模型,所述第二损失函数用于指示将所述多个子结构中的至少一个子结构进行剪枝,所述约束条件用于约束所述第二模型的精度不低于所述第一模型的精度,所述精度指示模型的输出值与所述标签值之间的差异程度。
12.根据权利要求11所述的设备,其特征在于,所述约束条件具体用于约束所述第一损失函数的下降方向与所述第二损失函数的下降方向之间的夹角小于或等于90度。
13.根据权利要求11所述的设备,其特征在于,所述约束条件具体用于约束所述第二模型的所述第一损失函数的值小于或等于所述第一模型的所述第一损失函数的值。
14.根据权利要求11至13中任一项所述的设备,其特征在于,所述第二损失函数包括第一稀疏项,所述第一稀疏项与所述多个子结构中的至少一个子结构的权重相关。
15.根据权利要求11至14中任一项所述的设备,其特征在于,所述剪枝单元,具体用于基于所述第二损失函数对所述第一模型进行至少一次随机剪枝,直至剪枝所述第一模型后得到的所述第二模型满足所述约束条件。
16.根据权利要求11至15中任一项所述的设备,其特征在于,所述模型处理设备为客户端,所述训练数据是所述客户端本地的数据,所述模型处理设备还包括:
接收单元,用于接收上游设备发送的所述神经网络模型;
发送单元,用于向所述上游设备发送所述第二模型。
17.根据权利要求11至16中任一项所述的设备,其特征在于,所述训练数据包括:图像数据,音频数据或者文本数据。
18.根据权利要求11至17中任一项所述的设备,其特征在于,所述神经网络模型用于对图像数据进行分类和/或识别。
19.一种上游设备,其特征在于,所述上游设备应用于联邦学习方法,所述上游设备包括:
发送单元,用于向多个下游设备发送神经网络模型,所述神经网络模型包括多个子结构,所述多个子结构中的每个子结构包括至少两个神经元;
接收单元,用于接收来自所述多个下游设备的多个第一模型,所述多个第一模型由所述神经网络模型训练得到;
剪枝单元,用于基于损失函数以及约束条件对所述多个第一模型分别进行剪枝,其中,所述损失函数用于指示对所述多个第一模型的所述子结构进行剪枝,所述约束条件用于约束每个所述第一模型剪枝后的精度不低于剪枝前的精度;
聚合单元,用于将剪枝后的所述多个第一模型进行聚合,以得到第二模型。
20.一种上游设备,其特征在于,所述上游设备应用于联邦学习方法,所述上游设备包括:
发送单元,用于向多个下游设备发送神经网络模型,所述神经网络模型包括多个子结构,所述多个子结构中的每个子结构包括至少两个神经元;
接收单元,用于接收来自所述多个下游设备的多个第一模型,所述多个第一模型由所述神经网络模型训练得到;
聚合单元,用于将所述多个第一模型进行聚合,以得到第二模型;
剪枝单元,用于基于损失函数以及约束条件对所述第二模型进行剪枝,其中,所述损失函数用于指示对所述第二模型的所述子结构进行剪枝,所述约束条件用于约束所述第二模型剪枝后的精度不低于剪枝前的精度。
21.一种电子设备,其特征在于,包括:处理器,所述处理器与存储器耦合,所述存储器用于存储程序或指令,当所述程序或指令被所述处理器执行时,使得所述电子设备执行如权利要求1至8中任一项所述的方法。
22.一种电子设备,其特征在于,包括:处理器,所述处理器与存储器耦合,所述存储器用于存储程序或指令,当所述程序或指令被所述处理器执行时,使得所述电子设备执行如权利要求9或10所述的方法。
23.一种计算机可读存储介质,其特征在于,所述计算机可读存储介质中存储有指令,所述指令在计算机上执行时,使得所述计算机执行如权利要求1至10中任一项所述的方法。
24.一种计算机程序产品,其特征在于,所述计算机程序产品在计算机上执行时,使得所述计算机执行如权利要求1至10中任一项所述的方法。
Priority Applications (2)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202110763965.6A CN113469340A (zh) | 2021-07-06 | 2021-07-06 | 一种模型处理方法、联邦学习方法及相关设备 |
PCT/CN2022/100682 WO2023279975A1 (zh) | 2021-07-06 | 2022-06-23 | 一种模型处理方法、联邦学习方法及相关设备 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202110763965.6A CN113469340A (zh) | 2021-07-06 | 2021-07-06 | 一种模型处理方法、联邦学习方法及相关设备 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN113469340A true CN113469340A (zh) | 2021-10-01 |
Family
ID=77878843
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202110763965.6A Pending CN113469340A (zh) | 2021-07-06 | 2021-07-06 | 一种模型处理方法、联邦学习方法及相关设备 |
Country Status (2)
Country | Link |
---|---|
CN (1) | CN113469340A (zh) |
WO (1) | WO2023279975A1 (zh) |
Cited By (7)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN114492847A (zh) * | 2022-04-18 | 2022-05-13 | 奥罗科技(天津)有限公司 | 一种高效个性化联邦学习系统和方法 |
CN114580632A (zh) * | 2022-03-07 | 2022-06-03 | 腾讯科技(深圳)有限公司 | 模型优化方法和装置、计算设备及存储介质 |
CN115115064A (zh) * | 2022-07-11 | 2022-09-27 | 山东大学 | 一种半异步联邦学习方法及系统 |
CN115170917A (zh) * | 2022-06-20 | 2022-10-11 | 美的集团(上海)有限公司 | 图像处理方法、电子设备及存储介质 |
WO2023279975A1 (zh) * | 2021-07-06 | 2023-01-12 | 华为技术有限公司 | 一种模型处理方法、联邦学习方法及相关设备 |
WO2024087573A1 (zh) * | 2022-10-29 | 2024-05-02 | 华为技术有限公司 | 一种联邦学习方法及装置 |
CN118101501A (zh) * | 2024-04-23 | 2024-05-28 | 山东大学 | 一种工业物联网异构联邦学习的通信方法和系统 |
Families Citing this family (6)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116148193B (zh) * | 2023-04-18 | 2023-07-18 | 天津中科谱光信息技术有限公司 | 水质监测方法、装置、设备及存储介质 |
CN116484922B (zh) * | 2023-04-23 | 2024-02-06 | 深圳大学 | 一种联邦学习方法、系统、设备及存储介质 |
CN116797829B (zh) * | 2023-06-13 | 2024-06-14 | 北京百度网讯科技有限公司 | 一种模型生成方法、图像分类方法、装置、设备及介质 |
CN117910536B (zh) * | 2024-03-19 | 2024-06-07 | 浪潮电子信息产业股份有限公司 | 文本生成方法及其模型梯度剪枝方法、装置、设备、介质 |
CN118504713A (zh) * | 2024-06-07 | 2024-08-16 | 北京天融信网络安全技术有限公司 | 面向全局优化的联邦学习方法、电子设备及存储介质 |
CN118396085A (zh) * | 2024-06-26 | 2024-07-26 | 中山大学 | 在线文字识别模型训练方法、在线文字识别方法及装置 |
Citations (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US20210125071A1 (en) * | 2019-10-25 | 2021-04-29 | Alibaba Group Holding Limited | Structured Pruning for Machine Learning Model |
CN112906853A (zh) * | 2019-12-03 | 2021-06-04 | 中国移动通信有限公司研究院 | 模型自动优化的方法及装置、设备、存储介质 |
CN112966818A (zh) * | 2021-02-25 | 2021-06-15 | 苏州臻迪智能科技有限公司 | 一种定向引导模型剪枝方法、系统、设备及存储介质 |
Family Cites Families (6)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US10832123B2 (en) * | 2016-08-12 | 2020-11-10 | Xilinx Technology Beijing Limited | Compression of deep neural networks with proper use of mask |
CN110874550A (zh) * | 2018-08-31 | 2020-03-10 | 华为技术有限公司 | 数据处理方法、装置、设备和系统 |
CN112101487B (zh) * | 2020-11-17 | 2021-07-16 | 深圳感臻科技有限公司 | 一种细粒度识别模型的压缩方法和设备 |
CN112396179A (zh) * | 2020-11-20 | 2021-02-23 | 浙江工业大学 | 一种基于通道梯度剪枝的柔性深度学习网络模型压缩方法 |
CN113065636B (zh) * | 2021-02-27 | 2024-06-07 | 华为技术有限公司 | 一种卷积神经网络的剪枝处理方法、数据处理方法及设备 |
CN113469340A (zh) * | 2021-07-06 | 2021-10-01 | 华为技术有限公司 | 一种模型处理方法、联邦学习方法及相关设备 |
-
2021
- 2021-07-06 CN CN202110763965.6A patent/CN113469340A/zh active Pending
-
2022
- 2022-06-23 WO PCT/CN2022/100682 patent/WO2023279975A1/zh active Application Filing
Patent Citations (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US20210125071A1 (en) * | 2019-10-25 | 2021-04-29 | Alibaba Group Holding Limited | Structured Pruning for Machine Learning Model |
CN112906853A (zh) * | 2019-12-03 | 2021-06-04 | 中国移动通信有限公司研究院 | 模型自动优化的方法及装置、设备、存储介质 |
CN112966818A (zh) * | 2021-02-25 | 2021-06-15 | 苏州臻迪智能科技有限公司 | 一种定向引导模型剪枝方法、系统、设备及存储介质 |
Non-Patent Citations (2)
Title |
---|
YU-CHENG WU ET AL.: "Constraint-Aware Importance Estimation for Global Filter Pruning under Multiple Resource Constraints", 《2020 IEEE/CVF CONFERENCE ON COMPUTER VISION AND PATTERN RECOGNITION WORKSHOPS (CVPRW)》, 28 July 2020 (2020-07-28), pages 2935 - 2943 * |
徐嘉荟: "基于模型剪枝的神经网络压缩技术研究", 《信息通信》, no. 12, 15 December 2019 (2019-12-15), pages 165 - 167 * |
Cited By (10)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
WO2023279975A1 (zh) * | 2021-07-06 | 2023-01-12 | 华为技术有限公司 | 一种模型处理方法、联邦学习方法及相关设备 |
CN114580632A (zh) * | 2022-03-07 | 2022-06-03 | 腾讯科技(深圳)有限公司 | 模型优化方法和装置、计算设备及存储介质 |
CN114492847A (zh) * | 2022-04-18 | 2022-05-13 | 奥罗科技(天津)有限公司 | 一种高效个性化联邦学习系统和方法 |
CN115170917A (zh) * | 2022-06-20 | 2022-10-11 | 美的集团(上海)有限公司 | 图像处理方法、电子设备及存储介质 |
CN115170917B (zh) * | 2022-06-20 | 2023-11-07 | 美的集团(上海)有限公司 | 图像处理方法、电子设备及存储介质 |
CN115115064A (zh) * | 2022-07-11 | 2022-09-27 | 山东大学 | 一种半异步联邦学习方法及系统 |
CN115115064B (zh) * | 2022-07-11 | 2023-09-05 | 山东大学 | 一种半异步联邦学习方法及系统 |
WO2024087573A1 (zh) * | 2022-10-29 | 2024-05-02 | 华为技术有限公司 | 一种联邦学习方法及装置 |
CN118101501A (zh) * | 2024-04-23 | 2024-05-28 | 山东大学 | 一种工业物联网异构联邦学习的通信方法和系统 |
CN118101501B (zh) * | 2024-04-23 | 2024-07-05 | 山东大学 | 一种工业物联网异构联邦学习的通信方法和系统 |
Also Published As
Publication number | Publication date |
---|---|
WO2023279975A1 (zh) | 2023-01-12 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN113469340A (zh) | 一种模型处理方法、联邦学习方法及相关设备 | |
CN110009052B (zh) | 一种图像识别的方法、图像识别模型训练的方法及装置 | |
US11763599B2 (en) | Model training method and apparatus, face recognition method and apparatus, device, and storage medium | |
WO2022083536A1 (zh) | 一种神经网络构建方法以及装置 | |
CN111816159B (zh) | 一种语种识别方法以及相关装置 | |
US11403510B2 (en) | Processing sensor data | |
CN109918684A (zh) | 模型训练方法、翻译方法、相关装置、设备及存储介质 | |
CN110516113B (zh) | 一种视频分类的方法、视频分类模型训练的方法及装置 | |
CN113065636A (zh) | 一种卷积神经网络的剪枝处理方法、数据处理方法及设备 | |
CN113065635A (zh) | 一种模型的训练方法、图像增强方法及设备 | |
CN113469367A (zh) | 一种联邦学习方法、装置及系统 | |
CN113505883A (zh) | 一种神经网络训练方法以及装置 | |
WO2022012668A1 (zh) | 一种训练集处理方法和装置 | |
CN113536970A (zh) | 一种视频分类模型的训练方法及相关装置 | |
CN113191479A (zh) | 联合学习的方法、系统、节点及存储介质 | |
CN113869496A (zh) | 一种神经网络的获取方法、数据处理方法以及相关设备 | |
CN115879508A (zh) | 一种数据处理方法及相关装置 | |
WO2023051678A1 (zh) | 一种推荐方法及相关装置 | |
CN115866291A (zh) | 一种数据处理方法及其装置 | |
CN115983362A (zh) | 一种量化方法、推荐方法以及装置 | |
CN117764190A (zh) | 一种数据处理方法及其装置 | |
CN114254724A (zh) | 一种数据处理方法、神经网络的训练方法以及相关设备 | |
CN114254176B (zh) | 分层因子分解机模型训练方法、信息推荐方法及相关设备 | |
US20240104915A1 (en) | Long duration structured video action segmentation | |
CN115265881B (zh) | 压力检测方法和装置 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |