CN116415653A - 一种基于知识蒸馏的类别增量神经网络模型聚合方法 - Google Patents
一种基于知识蒸馏的类别增量神经网络模型聚合方法 Download PDFInfo
- Publication number
- CN116415653A CN116415653A CN202111627025.0A CN202111627025A CN116415653A CN 116415653 A CN116415653 A CN 116415653A CN 202111627025 A CN202111627025 A CN 202111627025A CN 116415653 A CN116415653 A CN 116415653A
- Authority
- CN
- China
- Prior art keywords
- model
- aggregation
- aggregated
- increment
- initial
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
Images
Landscapes
- Image Analysis (AREA)
Abstract
本发明提供一种基于知识蒸馏的类别增量神经网络模型聚合方法,具体包括以下步骤:步骤S1,获取聚合所需的增量类别信息以及聚合所得模型的具体结构信息;步骤S2,根据增量类别信息以及结构信息选择多个异构模型作为待聚合模型,并构建初始聚合模型;步骤S3,采用无监督数据对待聚合模型以及初始聚合模型进行类别增量模型聚合,并基于聚合时的类别增量蒸馏损失以及特征过滤损失更新初始聚合模型直至生成增量模型;该方法支持异构待聚合模型间的聚合任务,同时只需单轮聚合就能使聚合模型收敛,大大降低模型聚合过程中的通信成本。
Description
技术领域
本发明属于人工智能领域,具体涉及一种基于知识蒸馏的类别增量神经网络模型聚合方法。
背景技术
随着AI技术的发展,部署有深度神经网络的智能终端(如手机、自动驾驶汽车,物联网终端等)已经逐渐广泛应用到生活中的各个场景中,并结合端云协同计算方法一同构成整体的分布式群智学习框架。这些分布式的智能终端被分散地部署在不同的应用场景中,并且各自拥有相互独立的不同分布的域数据,模型能够通过这些独立域的数据学习到其中的知识。然而,出于隐私保护的考虑,不同域之间的数据无法相互共享用于新模型的训练,为了使聚合模型能够在不接触各个数据域的原始训练数据的前提下学习到汇聚的知识,需要一种方法能仅通过模型参数实现多模型知识的汇聚,这一过程被称为模型聚合。
模型聚合这一概念最早在联邦学习框架中被提出,目前主流的模型聚合方法主要基于联邦参数平均构建,然而基于参数平均的模型聚合方法存在着诸多限制。不仅要求所有参与聚合的待聚合模型拥有相同的模型结构、支持类别以及初始化参数,也需要大量轮次的迭代才能够使得聚合模型最终收敛,极大增加了聚合过程的通信传输成本。与此同时,在实际应用场景中,由于不同数据域的模型应用范围的差异,不可避免地会出现域间模型异构或支持类别不同的现象。而基于参数聚合的模型聚合方法则由于其自身的限制难以应用在真实复杂场景下的模型聚合任务上。
知识蒸馏框架通过构建一个教师-学生的联合学习框架,将训练数据同时输入到教师模型和学生模型中,通过学生模型对教师模型输出的软标签的模仿训练学生模型尽可能接近教师模型。融入模型集成的思想拓展后的知识蒸馏框架能够支持同时将多个教师模型的知识蒸馏到单个学生模型中,这一框架被称为多教师蒸馏。
发明内容
为解决上述问题,提供一种能够在隐私保护的前提下聚合支持不同类别以及异构模型的模型聚合方法,本发明采用了如下技术方案:
本发明提供了一种基于知识蒸馏的类别增量神经网络模型聚合方法,其特征在于,包括以下步骤:步骤S1,获取聚合所需的增量类别信息以及聚合所得模型的具体结构信息;步骤S2,根据增量类别信息以及结构信息选择多个异构模型作为待聚合模型,并构建初始聚合模型;步骤S3,采用无监督数据对待聚合模型以及初始聚合模型进行类别增量模型聚合,并基于聚合时的累加损失更新初始聚合模型直至生成增量模型,累加损失包括类别增量蒸馏损失以及特征过滤损失。
本发明提供的一种基于知识蒸馏的类别增量神经网络模型聚合方法,还可以具有这样的技术特征,其中,待聚合模型与初始聚合模型均由特征提取模块以及分类器组成,待聚合模型还具有特征过滤模块。
本发明提供的一种基于知识蒸馏的类别增量神经网络模型聚合方法,还可以具有这样的技术特征,其中,步骤S3包括以下子步骤:步骤S3-1,将无标注数据输入至待聚合模型和初始聚合模型中分别获取待聚合模型的预测输出z0-zn、初始聚合模型的输出zM以及中间层特征Fj;步骤S3-2,采用特征过滤模块对中间层特征Fj进行过滤后分别输入至待聚合模型中,获取前向传播输出的中间层特征Fj的logits步骤S3-3,分别计算预测输出z0-zn和/>在重合类别上的交叉熵损失并累加,得到特征过滤损失;步骤S3-4,分别计算待聚合模型以及初始聚合模型对无标注数据输出间的类别增量蒸馏损失;步骤S3-5,基于特征过滤损失以及类别增量蒸馏损失计算累加损失,并采用反向传播不断更新初始聚合模型直至生成增量模型。
本发明提供的一种基于知识蒸馏的类别增量神经网络模型聚合系统,还可以具有这样的技术特征,其中,在聚合过程中,仅更新初始聚合模型的模型参数,待聚合模型仅用于推断。
本发明提供的一种基于知识蒸馏的类别增量神经网络模型聚合方法,还可以具有这样的技术特征,其中,特征过滤损失的损失函数为:
式中,M代表聚合后的增量模型,支持标签集合为LM,输出为zM;C1-Cn代表参与聚合过程的待聚合模型,支持标签集合为Li(其中 输出为zi,/>代表初始聚合模型第j个模块的输出的中间层特征输入到待聚合模型i的下一层后得到的输出。
本发明提供的一种基于知识蒸馏的类别增量神经网络模型聚合方法,还可以具有这样的技术特征,其中,类别增量蒸馏损失的损失函数为:
本发明还提供一种基于知识蒸馏的类别增量神经网络模型聚合系统,其特征在于,包括:端侧以及云侧聚合中心,端侧用于将聚合申请以及根据预先存储的端侧私有数据训练得到多个端侧模型传输至云侧聚合中心,并将云侧聚合中心发送的增量模型进行更新存储,聚合申请包括需要增量的类别信息以及聚合所得模型的具体结构信息,云侧聚合中心用于根据聚合申请在接收的多个端侧模型中选择待聚合模型并构建初始聚合模型,采用公开的无标注数据对待聚合模型以及初始聚合模型进行类别增量模型聚合以获取增量模型,将聚合得到的增量模型发至端侧。
发明作用与效果
根据本发明的基于知识蒸馏的类别增量神经网络模型聚合方法,由于引入了基于知识蒸馏的方法来实现模型聚合过程,因此只需单轮聚合就能使聚合模型收敛,大大降低了模型聚合过程中的通信成本;同时,本发明提出的模型聚合方法改进设计了类别增量蒸馏损失以降低待聚合模型不支持类别所带来的误监督影响,提升了增量模型的性能,且能够应用于待聚合模型支持类别各不相同的场景即类别增量场景;还由于提出了基于特征过滤模块的特征过滤损失,使增量模型能够选择性学习多个待聚合模型中的知识,进一步提升了模型性能;另外,还由于特征过滤模块能够自适应地调整特征图尺寸的大小,使得本方法能够进一步支持异构待聚合模型间的聚合任务,更具有真实复杂场景下的实用价值。
附图说明
图1是本发明实施例中基于知识蒸馏的类别增量神经网络模型聚合方法的框架示意图;
图2是本发明实施例中的基于知识蒸馏的类别增量神经网络模型聚合方法的流程图;
图3是本发明实施例中聚合过程的流程图;
图4是本发明实施例中特征过滤模块的结构示意图;
图5是本发明实施例中的类别增量蒸馏损失的示意图;
图6是本发明实施例中无监督类别增量模型聚合过程的示意图。
具体实施方式
为了使本发明实现的技术手段、创作特征、达成目的与功效易于明白了解,以下结合实施例及附图对本发明的基于知识蒸馏的类别增量神经网络模型聚合方法作具体阐述。
<实施例>
图1是本发明实施例中基于知识蒸馏的类别增量神经网络模型聚合方法的框架示意图。
本实施例中,基于知识蒸馏的类别增量神经网络模型聚合方法被部署在分布式云端协同学习框架的云端聚合中心。
如图1所示,基于知识蒸馏的类别增量神经网络模型聚合系统具有端侧和云侧。其中,端侧为分布式的智能端侧设备,每个端侧被视为一个独立的数据域,拥有相互独立的端侧私有数据以及根据私有数据训练得到的多个端侧模型。由于私有数据的类别存在一定差异看,因此,多个端侧模型所支持的类别也各不相同。
端侧用于向云侧发送包含需要增量的类别信息和聚合所得模型的具体结构信息的聚合申请以及多个端侧模型。
云侧聚合中心用于根据聚合申请在接收的多个端侧模型中选择待聚合模型并构建初始聚合模型,采用公开的无标注数据对待聚合模型以及初始聚合模型进行类别增量模型聚合以获取增量模型,并将聚合得到的增量模型发至端侧。
图2是本发明实施例中的基于知识蒸馏的类别增量神经网络模型聚合方法的流程图。
如图2所示,类别增量神经网络模型聚合方法引用知识蒸馏的框架进行多模型的聚合,过程如下:
步骤S1,获取聚合所需的增量类别信息以及聚合所得模型的具体结构信息。
本实施例中,由特定端侧发起聚合申请,并指定本次聚合需要增量的类别以及聚合所得模型具体结构,并将多个异构的端侧模型上传至云侧聚合中心。
步骤S2,根据增量类别信息以及结构信息选择多个异构模型作为待聚合模型,并构建初始聚合模型。
本实施例中,共有多个端侧模型被选中作为待聚合模型参与本轮的模型聚合,云侧聚合中心构建支持增量类别的新模型并初始化,作为初始聚合模型。
步骤S3,采用无监督公开数据对待聚合模型以及初始聚合模型进行类别增量模型聚合,直至初始聚合模型收敛生成增量模型。
图3是本发明实施例中聚合过程的流程图。
如图3所示,本实施例的聚合过程包括以下子步骤:
步骤S3-1,将无标注数据输入至待聚合模型C0-Cn和初始聚合模型C′0中分别获取待聚合模型的预测输出z0-zn、初始聚合模型的输出zM以及中间层特征Fj。
其中,无标注数据可以为任意来源的无标注公开的自然图像,只需满足无标注数据中包含自然图像的纹理即可。同时,越接近参与聚合端侧模型原始训练数据分布的无标注数据,用于本步骤时得到的聚合效果也会更好一些。
图4是本发明实施例中特征过滤模块的结构示意图。
如图4所示,本实施例中的特征过滤模块由一个简单的两层子网络组成,第一层为一个自适应池化层,用于在聚合相互异构的端侧模型时调整聚合模型输出的中间层特征图的尺寸并使其与端侧模型对应下一个模块的输入尺寸相符。自适应池化层的输入尺寸大小为当前模块聚合模型输出的特征图尺寸大小,输出尺寸大小为对应的端侧模型下一个模块输入的特征图的尺寸大小。第二层由一个1x1卷积层构成。同时保证输入的特征通道数和输出的特征通道数与聚合模型和端侧模型对应的通道数保持一致。1x1卷积用于对聚合模型输出的中间层特征进行通道维度上的重排,参数随机初始化并在聚合的过程中不断更新。
本实施例中,上述的单个端侧模型与初始聚合模型间的特征过滤损失函数为:
式中,M代表聚合后的增量模型,支持标签集合为LM,输出为zM;C1-Cn代表参与聚合过程的端侧模型,支持标签集合为Li(其中 输出为zi,/>代表聚合模型第j个模块的输出的中间层特征输入到端侧模型i的下一层后得到的输出。
本实施例中,在对每批无标注数据计算特征过滤损失时,同时将无标注数据输入到每一个端侧模型以及初始聚合模型中,分别计算每个端侧模型和初始聚合模型间的特征过滤损失后,对所有的特征过滤损失进行累加并平均,得到最终的特征过滤损失。
步骤S3-4,分别计算待聚合模型以及初始聚合模型在LM∩Li上的交叉熵损失并累加,得到类别增量蒸馏损失。
本实施例中,上述的类别增量蒸馏损失函数为:
在计算类别增量蒸馏损失时,为了将zi与LM上的概率zM关联起来,本实施例将pi(Y=l)近似为pi(Y=l)在Y∈LM∩Li的条件下的概率,即:从而在计算损失时避免了端侧模型不支持类别间损失的计算,降低了端侧模型不支持类别所带来的误监督影响。
图5是本发明实施例中的类别增量蒸馏损失的示意图。
如图5所示,在本实施例中,对于多个支持不同类别的端侧模型,在代入上述的近似估计的前提下,分别两两计算每个端侧模型与初始聚合模型间重合类别上的交叉熵损失并平均,得到类别增量蒸馏损失。
步骤S3-5,对于每批输入的无标注数据,在分别计算得到特征过滤损失与类别增量蒸馏损失后,计算累加损失并使用反向传播更新初始聚合模型的参数以及所有的特征过滤模块的参数。
在上述聚合的过程中,参与聚合的所有端侧模型只用于推断,而不更新其模型参数。
图6是本发明实施例中无监督类别增量模型聚合过程的示意图。
如图6所示,基于知识蒸馏的类别增量神经网络模型聚合方法支持多个异构端侧模型间的类别增量模型聚合。在端侧数据隐私保护的端云协同学习框架中承担端侧知识汇聚的核心任务,同时,在实际使用过程中,本实施例提出的方法能够支持相互异构的端侧模型间的聚合,能够部署在更加复杂的实际场景下。
本实施例中,使用本实施例所提出的多模型聚合方法与现有的同类聚合方法在公开数据集ImageNet和Cifar-100上进行聚合实验结果,实验结果如表1所示,表明本实施例的多模型聚合方法优于现有的同类聚合方法。表1中使用的评价指标为Top-1分类准确率(%)。
表1
实施例作用与效果
根据本实施例提供的基于知识蒸馏的类别增量神经网络模型聚合方法,由于引入了知识蒸馏的框架用于模型聚合任务,通过让聚合模型模仿多个待聚合模型的中间层特征以及网络输出从而让聚合模型同时学习多个待聚合模型中的知识,同时仅通过单轮的聚合过程便能够让聚合模型达到收敛,因此在能够支持异构待聚合模型间的聚合的同时,也大大降低了现有的参数聚合方法所需求的大量模型传输的通信成本;还由于提出并使用了基于类别增量场景下设计的类别增量蒸馏损失以及特征过滤损失,能够降低因为待聚合模型不支持类别所带来的聚合性能下降问题,同时实现聚合过程中多数据域模型知识的选择性学习,因此能够在类别增量场景下大幅提升聚合模型的性能,更具有现实场景下的使用价值。
实施例中,由于聚合过程中聚合中心仅使用了公开的自然图像作为无监督数据对聚合模型进行训练,因此能够在实现各数据域知识汇聚的同时,保证各数据域原始训练数据的隐私不被泄露。
实施例中,还由于在多轮迭代的聚合过程中,本实施例所提出的聚合方法能够在极少的遗忘基础上实现多轮次的知识汇聚,更适合部署于实际应用场景下的持续学习框架之中。
上述实施例仅用于举例说明本发明的具体实施方式,而本发明不限于上述实施例的描述范围。
Claims (7)
1.一种基于知识蒸馏的类别增量神经网络模型聚合方法,其特征在于,包括以下步骤:
步骤S1,获取聚合所需的增量类别信息以及聚合所得模型的具体结构信息;
步骤S2,根据所述增量类别信息以及所述结构信息选择多个异构模型作为待聚合模型,并构建初始聚合模型;
步骤S3,采用无监督数据对所述待聚合模型以及所述初始聚合模型进行类别增量模型聚合,并基于聚合时的累加损失更新所述初始聚合模型直至生成增量模型,
所述累加损失包括类别增量蒸馏损失以及特征过滤损失。
2.根据权利要求1所述的一种基于知识蒸馏的类别增量神经网络模型聚合方法,其特征在于:
其中,所述待聚合模型与所述初始聚合模型均由特征提取模块以及分类器组成,所述待聚合模型还具有特征过滤模块。
3.根据权利要求2所述的一种基于知识蒸馏的类别增量神经网络模型聚合方法,其特征在于:
其中,所述步骤S3包括以下子步骤:
步骤S3-1,将无标注数据输入至所述待聚合模型和所述初始聚合模型中分别获取所述待聚合模型的预测输出z0-zn、所述初始聚合模型的输出zM以及中间层特征Fj;
步骤S3-4,分别计算所述待聚合模型以及所述初始聚合模型对无标注数据输出间的所述类别增量蒸馏损失;
步骤S3-5,基于所述特征过滤损失以及所述类别增量蒸馏损失计算累加损失,并采用反向传播不断更新所述初始聚合模型直至生成所述增量模型。
4.根据权利要求3所述的一种基于知识蒸馏的类别增量神经网络模型聚合方法,其特征在于:
其中,在聚合过程中,仅更新所述初始聚合模型的模型参数,所述待聚合模型仅用于推断。
7.一种基于知识蒸馏的类别增量神经网络模型聚合系统,其特征在于,包括:
端侧以及云侧聚合中心,
所述端侧用于将聚合申请以及根据预先存储的端侧私有数据训练得到多个端侧模型传输至所述云侧聚合中心,并将所述云侧聚合中心发送的增量模型进行更新存储,
所述聚合申请包括需要增量的类别信息以及聚合所得模型的具体结构信息,
所述云侧聚合中心用于根据所述聚合申请在接收的多个所述端侧模型中选择待聚合模型并构建初始聚合模型,采用公开的无标注数据对所述待聚合模型以及所述初始聚合模型进行类别增量模型聚合以获取所述增量模型,将聚合得到的所述增量模型发至端侧。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202111627025.0A CN116415653A (zh) | 2021-12-28 | 2021-12-28 | 一种基于知识蒸馏的类别增量神经网络模型聚合方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202111627025.0A CN116415653A (zh) | 2021-12-28 | 2021-12-28 | 一种基于知识蒸馏的类别增量神经网络模型聚合方法 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN116415653A true CN116415653A (zh) | 2023-07-11 |
Family
ID=87049674
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202111627025.0A Pending CN116415653A (zh) | 2021-12-28 | 2021-12-28 | 一种基于知识蒸馏的类别增量神经网络模型聚合方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN116415653A (zh) |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116977635A (zh) * | 2023-07-19 | 2023-10-31 | 中国科学院自动化研究所 | 类别增量语义分割学习方法及语义分割方法 |
-
2021
- 2021-12-28 CN CN202111627025.0A patent/CN116415653A/zh active Pending
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116977635A (zh) * | 2023-07-19 | 2023-10-31 | 中国科学院自动化研究所 | 类别增量语义分割学习方法及语义分割方法 |
CN116977635B (zh) * | 2023-07-19 | 2024-04-16 | 中国科学院自动化研究所 | 类别增量语义分割学习方法及语义分割方法 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN110263280B (zh) | 一种基于多视图的动态链路预测深度模型及应用 | |
CN110674869B (zh) | 分类处理、图卷积神经网络模型的训练方法和装置 | |
CN108376392B (zh) | 一种基于卷积神经网络的图像运动模糊去除方法 | |
CN112181971A (zh) | 一种基于边缘的联邦学习模型清洗和设备聚类方法、系统、设备和可读存储介质 | |
CN113435472A (zh) | 车载算力网络用户需求预测方法、系统、设备、介质 | |
CN110266620A (zh) | 基于卷积神经网络的3d mimo-ofdm系统信道估计方法 | |
CN112491442B (zh) | 一种自干扰消除方法及装置 | |
Shi et al. | Machine learning for large-scale optimization in 6g wireless networks | |
CN111680702B (zh) | 一种使用检测框实现弱监督图像显著性检测的方法 | |
CN104091340A (zh) | 一种模糊图像的快速检测方法 | |
CN109787699B (zh) | 一种基于混合深度模型的无线传感器网络路由链路状态预测方法 | |
CN111224905A (zh) | 一种大规模物联网中基于卷积残差网络的多用户检测方法 | |
CN116415653A (zh) | 一种基于知识蒸馏的类别增量神经网络模型聚合方法 | |
CN114359073A (zh) | 一种低光照图像增强方法、系统、装置及介质 | |
Lei et al. | Oes-fed: a federated learning framework in vehicular network based on noise data filtering | |
CN113194493B (zh) | 基于图神经网络的无线网络数据缺失属性恢复方法及装置 | |
CN113486724A (zh) | 基于cnn-lstm多支流结构和多种信号表示的调制识别模型 | |
Liang et al. | Generative AI-driven semantic communication networks: Architecture, technologies and applications | |
CN111079900A (zh) | 一种基于自适应连接神经网络的图像处理方法及装置 | |
CN116070136A (zh) | 基于深度学习的多模态融合无线信号自动调制识别方法 | |
Yıldırım et al. | Deep receiver design for multi-carrier waveforms using cnns | |
CN115955375A (zh) | 基于cnn-gru和ca-vgg特征融合的调制信号识别方法及系统 | |
CN114595815A (zh) | 一种面向传输友好的云-端协作训练神经网络模型方法 | |
CN114758141A (zh) | 一种协同学习的带噪声标签图像分类方法 | |
Guo et al. | Device-edge digital semantic communication with trained non-linear quantization |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |