CN114565810A - 一种基于数据保护场景下的模型压缩方法及系统 - Google Patents
一种基于数据保护场景下的模型压缩方法及系统 Download PDFInfo
- Publication number
- CN114565810A CN114565810A CN202210220060.9A CN202210220060A CN114565810A CN 114565810 A CN114565810 A CN 114565810A CN 202210220060 A CN202210220060 A CN 202210220060A CN 114565810 A CN114565810 A CN 114565810A
- Authority
- CN
- China
- Prior art keywords
- model
- teacher
- loss function
- generator
- student
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/217—Validation; Performance evaluation; Active pattern learning techniques
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F16/00—Information retrieval; Database structures therefor; File system structures therefor
- G06F16/50—Information retrieval; Database structures therefor; File system structures therefor of still image data
- G06F16/53—Querying
- G06F16/535—Filtering based on additional data, e.g. user or group profiles
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
- G06F18/241—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N20/00—Machine learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Data Mining & Analysis (AREA)
- Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Artificial Intelligence (AREA)
- Evolutionary Computation (AREA)
- Life Sciences & Earth Sciences (AREA)
- Software Systems (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Computing Systems (AREA)
- Mathematical Physics (AREA)
- Biomedical Technology (AREA)
- Evolutionary Biology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- General Health & Medical Sciences (AREA)
- Molecular Biology (AREA)
- Health & Medical Sciences (AREA)
- Bioinformatics & Computational Biology (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Medical Informatics (AREA)
- Databases & Information Systems (AREA)
- Image Analysis (AREA)
Abstract
本发明涉及一种基于数据保护场景下的模型压缩方法及系统,属于模型压缩领域,将传统的模型反演扩展到基于多教师集成的模型反演,充分反演来自教师的更丰富的信息以生成可泛化的数据。另外教师内部对比用于逐步合成具有与历史样本不同模式的新样本,师生对比旨在推动学生与同构教师之间的关系远离表示空间中的非同构关系,以提高合成数据的多样性。并且以对抗的方式训练图像生成和知识转移的过程,以同时学习学生模型和生成合成数据。本发明不依赖于模型原始的训练数据,通过多教师模型的知识蒸馏以及引入基于对比学习的损失函数,对无数据的模型压缩方法进行有效压缩和压缩后的模型具有更高的准确率。
Description
技术领域
本发明涉及模型压缩领域,特别是涉及一种基于数据保护场景下的模型压缩方法及系统。
背景技术
近年来随着深度学习算力的不断发展,深度学习模型愈发庞大,而当我们需要将模型部署到终端设备上时,就不得不进行模型的压缩。知识蒸馏(KnowledgeDistillation,KD)是一种流行的压缩方法,通过从冗余的教师模型中转移知识来学习轻量级学生模型来模仿表征能力。在大多数现有的KD方法中,使用logits或来自教师的特征信息成功地将知识转移到学生模型,但是在其中需要访问整个训练数据。
不幸的是,由于隐私、保密或传输限制,预训练模型的原始训练样本通常不可用。例如,患者的医疗数据是保密的,不会公开共享以泄露患者的隐私。如果没有数据的帮助,这些方法可能无法适用。
现有技术的方法是通过人为合成的训练数据来代替原始数据。但现有的方法所生成的数据都和原始数据有一定差距,缺乏数据的多样性和泛化性。压缩后模型的准确率不够令人满意。
发明内容
本发明的目的是提供一种基于数据保护场景下的模型压缩方法及系统,不依赖于模型原始的训练数据,对无数据的模型进行有效压缩同时提高模型压缩的准确率。
为实现上述目的,本发明提供了如下方案:
一种基于数据保护场景下的模型压缩方法,所述方法包括:
预设在同一数据集下多个训练好的教师模型以及随机初始化的学生模型和生成器;
分别构建one-hot预测损失函数、特征正则化损失函数和对抗蒸馏损失函数;
组合one-hot预测损失函数、特征正则化损失函数和对抗蒸馏损失函数,获得多教师条件下生成器的无数据蒸馏的模型反转损失函数;
构建学生模型模仿教师模型输出的多教师集成蒸馏损失函数;
分别建立教师内部对比损失函数和师生对比损失函数;
利用无数据蒸馏的模型反转损失函数、教师内部对比损失函数和师生对比损失函数对生成器进行优化操作,输出本次优化后生成器合成的图像,并将本次优化后生成器合成的图像分别输入至学生模型和每个训练好的教师模型,通过多教师集成蒸馏损失函数对学生模型进行优化操作;
重复进行优化操作,直至学生模型收敛,获得压缩后的学生模型。
所述特征正则化损失函数为式中,为特征正则化损失,为生成器合成的图像输入到训练好的教师模型后第l个BN层得到的均值,为生成器合成的图像输入到训练好的教师模型后第l个BN层得到的方差,F(μl(x)|X)为输入图像x输入到训练好的教师模型后第l个BN层得到的均值,为输入图像x输入到训练好的教师模型后第l个BN层得到的方差;
可选的,所述组合one-hot预测损失函数、特征正则化损失函数和对抗蒸馏损失函数,获得多教师条件下生成器的无数据蒸馏的模型反转损失函数,具体包括:
组合one-hot预测损失函数、特征正则化损失函数和对抗蒸馏损失函数,获得单教师条件下生成器的无数据蒸馏的模型反转损失函数为式中,为单教师条件下生成器的无数据蒸馏的模型反转损失,λ1、λ2和λ3分别为第一平衡参数、第二平衡参数和第三平衡参数;
根据单教师条件下生成器的无数据蒸馏的模型反转损失函数,构建多教师条件下生成器的无数据蒸馏的模型反转损失函数为式中,为多教师条件下生成器的无数据蒸馏的模型反转损失,为具有多教师信息的one-hot预测损失,为具有多教师信息的特征正则化损失,为具有多教师信息的对抗蒸馏损失,z为噪声输入,θg为生成器的参数;
其中,为M数量的集成教师模型的输出, 为第m个训练好的教师模型的输出;为生成器合成的图像输入到第m个训练好的教师模型后第l个BN层得到的均值,为生成器合成的图像输入到第m个训练好的教师模型后第l个BN层得到的方差,为输入图像x输入到第m个训练好的教师模型后第l个BN层得到的均值,为输入图像x输入到第m个训练好的教师模型后第l个BN层得到的方差。
可选的,所述多教师集成蒸馏损失函数为
可选的,所述教师内部对比损失函数的建立过程为:
对生成器合成的同一批量图像中的每张图像进行数据增强,并将每张所述图像和数据增强后的图像分别输入至每个训练好的教师模型中,获得每个训练好的教师模型输出的每张所述图像的表示和数据增强后的图像的表示;
任意选取生成器合成的同一批量图像中的一张图像为待测图像;
将待测图像的表示和数据增强后的待测图像的表示作为正样本对,生成器合成的同一批量图像中待测图像以外的图像的表示作为负样本,并将生成器合成的历史图像的表示作为负样本;
根据正样本对和负样本,确定教师内部对比损失函数为式中,为教师内部对比损失,为正样本对,为第m个训练好的教师模型输出的待测图像的表示, 为第m个头部映射网络的参数,h为头部投影网络,为第m个训练好的教师模型输出的数据增强后的待测图像的表示,为第m个训练好的教师模型对应的第i个负样本,K为负样本的数量,τ1为第一温度超参数,sim()为余弦相似度。
可选的,所述师生对比损失函数的建立过程为:
对生成器合成的同一批量图像中的每张图像分别输入至每个训练好的教师模型和学生模型,获得每个训练好的教师模型对每张图像的表示和学生模型对每张图像的表示;
将学生模型和与学生模型同构的训练好的教师模型对同一张图像的表示作为负样本对;所述同构为学生模型和教师模型属于同一网络结构系列;
将学生模型和与学生模型异构的训练好的教师模型对同一张图像的表示定义为正样本对;所述异构为学生模型和教师模型属于不同网络结构系列;
根据所述负样本对和定义的正样本对,确定师生对比损失函数为式中,为师生对比损失,为从生成器合成的当前批量图像中第r张图像构造出的学生模型的查询, 为生成器合成的图像输入到学生模型的输出,θh为头部映射网络的参数,h为头部投影网络,为从生成器合成的当前批量图像中第r张图像构造出的与学生模型异构的第m个训练好的教师模型的查询,D(s)为与学生模型异构的教师索引集,N为当前批量图像中图像的数量,τ2为第一温度超参数,Neg为负对集合,I(s)为与学生模型同构的教师索引集,为学生模型输出的历史图像的表示中第j个负样本,J为学生模型输出的历史图像的表示中负样本的数量,为从生成器合成的历史图像中第j张图像构造出的学生网络的查询,为与学生网络模型同构的教师模型的查询。
可选的,所述利用无数据蒸馏的模型反转损失函数、教师内部对比损失函数和师生对比损失函数对生成器进行优化操作,输出本次优化后生成器合成的图像,并将本次优化后生成器合成的图像分别输入至学生模型和每个训练好的教师模型,通过多教师集成蒸馏损失函数对学生模型进行优化操作,具体包括:
初始化生成器的参数θg、学生模型的参数θs和图像库;
根据噪声输入z,利用生成器合成当前批量的图像;
根据所述当前批量的图像,利用所述优化损失函数计算生成器的当前优化损失;
将生成器合成的当前批量的图像存储至图像库;
从图像库中抽取一批量图像;
根据抽取的一批量图像,利用多教师集成蒸馏损失函数计算学生模型的当前损失;
可选的,所述学生模型收敛是指学生模型的当前损失与前一次迭代计算的学生模型的损失相等。
一种基于数据保护场景下的模型压缩系统,所述系统包括:
预设模块,用于预设在同一数据集下多个训练好的教师模型以及随机初始化的学生模型和生成器;
三种损失函数构建模块,用于分别构建one-hot预测损失函数、特征正则化损失函数和对抗蒸馏损失函数;
组合模块,用于组合one-hot预测损失函数、特征正则化损失函数和对抗蒸馏损失函数,获得多教师条件下生成器的无数据蒸馏的模型反转损失函数;
多教师集成蒸馏损失函数构建模块,用于构建学生模型模仿教师模型输出的多教师集成蒸馏损失函数;
对比损失函数建立模块,用于分别建立教师内部对比损失函数和师生对比损失函数;
优化模块,用于利用无数据蒸馏的模型反转损失函数、教师内部对比损失函数和师生对比损失函数对生成器进行优化操作,输出本次优化后生成器合成的图像,并将本次优化后生成器合成的图像分别输入至学生模型和每个训练好的教师模型,通过多教师集成蒸馏损失函数对学生模型进行优化操作;
循环模块,用于重复进行优化操作,直至学生模型收敛,获得压缩后的学生模型。
可选的,所述优化模块,具体包括:
优化损失函数确定子模块,用于根据无数据蒸馏的模型反转损失函数、教师内部对比损失函数和师生对比损失函数,确定生成器的优化损失函数为式中,为生成器的优化损失,为教师内部对比损失函数和师生对比损失函数的总函数,λ为和之间的平衡参数;
初始化子模块,用于初始化生成器的参数θg、学生模型的参数θs和图像库;
合成子模块,用于根据噪声输入z,利用生成器合成当前批量的图像;
优化损失计算子模块,用于根据所述当前批量的图像,利用所述优化损失函数计算生成器的当前优化损失;
存储子模块,用于将生成器合成的当前批量的图像存储至图像库;
抽取子模块,用于从图像库中抽取一批量图像;
学生模型当前损失计算子模块,用于根据抽取的一批量图像,利用多教师集成蒸馏损失函数计算学生模型的当前损失;
根据本发明提供的具体实施例,本发明公开了以下技术效果:
本发明公开一种基于数据保护场景下的模型压缩方法及系统,首先将传统的模型反演扩展到基于多教师集成的模型反演,充分反演来自教师的更丰富的信息以生成可泛化的数据。另外提出了多教师和学生之间的对比交互正则化,其中包含教师内对比和师生对比,教师内部对比用于逐步合成具有与历史样本不同模式的新样本,而师生对比旨在推动学生与同构教师之间的关系远离表示空间中的非同构关系,以提高合成数据的多样性。并且以对抗的方式训练图像生成和知识转移的过程,以同时学习学生模型和生成合成数据。本发明不依赖于模型原始的训练数据,通过多教师模型的知识蒸馏以及引入基于对比学习的损失函数,对无数据的模型压缩方法进行有效压缩和压缩后的模型具有更高的准确率。
附图说明
为了更清楚地说明本发明实施例或现有技术中的技术方案,下面将对实施例中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动性的前提下,还可以根据这些附图获得其他的附图。
图1为本发明提供的基于数据保护场景下的模型压缩方法框架图;
图2为本发明提供的优化操作的流程图;
图3为本发明实施例提供的图片生成效果比较图;
图4为本发明实施例提供的数据分布比较图;图4(a)为MTCKI的数据分布图,图4(b)为CMI的数据分布图,图4(c)为CIFAR-10的数据分布图;
图5为不同方法训练损失曲线图;
图6为不同epoch合成的图像对比图;图6(a)为第10个epoch合成的图像,图6(b)为第100个epoch合成的图像。
具体实施方式
下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
本发明的目的是提供一种基于数据保护场景下的模型压缩方法及系统,不依赖于模型原始的训练数据,对无数据的模型进行有效压缩同时提高模型压缩的准确率。
为使本发明的上述目的、特征和优点能够更加明显易懂,下面结合附图和具体实施方式对本发明作进一步详细的说明。
本发明提供了一种基于数据保护场景下的模型压缩方法,参照图1,包括以下步骤:
步骤1,预设在同一数据集下多个训练好的教师模型以及随机初始化的学生模型和生成器。
随机挑选一些在同一数据集下训练好的教师模型,以及随机初始化的学生模型和生成器。将随机向量输入到生成器得到合成的图片。
步骤2,分别构建one-hot预测损失函数、特征正则化损失函数和对抗蒸馏损失函数。
2-1,构建one-hot预测损失函数
生成器所产生的每一个图片都应当属于一个类别,所以我们将图片输入到教师网络后得到的logits与logits值最大的那个类别计算交叉熵损失CE。
2-2,构建特征正则化损失函数
BN层已广泛用于CNN,通过运行平均统计量(例如:运行均值μ(x)和运行方差σ2(x))在训练期间。训练后,这些统计数据存储了有关X的丰富信息。
对于生成器产生的每一个批量的图片输入到所有的教师网络中,BatchNormalization(BN)层得到的均值和方差和预先训练好的教师网络中的均值E(μl(x)|X)和方差计算二范数,再对每一个Batch Normalization层得到的二范数损失累计求和。
特征正则化损失函数为式中,为特征正则化损失,为生成器合成的图像输入到训练好的教师模型后第l个BN层得到的均值,为生成器合成的图像输入到训练好的教师模型后第l个BN层得到的方差,E(μl(x)|X)为输入图像x输入到训练好的教师模型后第l个BN层得到的均值,为输入图像x输入到训练好的教师模型后第l个BN层得到的方差;
2-3,构建对抗蒸馏损失函数
提出对抗性蒸馏损失以鼓励合成图像使学生-教师产生较大的分歧,生成器合成的图片通过所有的教师网络和学生网络后得到的分布拉远来保证生成的图片具有多样性。
生成器G可以通过最小化方程来生成广义图像。因为它反转了来自多个预训练教师的知识。然而,合成图像仍然缺乏多样性,这可能导致在重新训练期间过度拟合。为此,提出了多名教师和一名学生之间的对比互动,以提高数据多样性并产生高保真图像。
步骤3,组合one-hot预测损失函数、特征正则化损失函数和对抗蒸馏损失函数,获得多教师条件下生成器的无数据蒸馏的模型反转损失函数。
多视图结构非常普遍存在于许多现实世界的数据集中。这些数据中存在多个特征,可用于正确分类图像。通过观察翅膀、身体大小或嘴巴的形状,可以将鸟类图像分类为鸟类。即使学生可以提取老师学习的所有特征,他们仍然无法“看到”老师没有发现的特征,从而限制了学生的表现。即使某些模型缺少单个学生可以学习多视图知识的视图,集成也可以收集几乎所有这些视图。我们首先考虑多个集成教师来构建一个可靠的多分支模型。我们选择所有教师的平均最终输出作为模型预测,此外,我们使用不同的教师来获取各种统计知识,以提高合成图像的可生成性和多样性。
在一个示例中,多教师条件下生成器的无数据蒸馏的模型反转损失函数的获得步骤为:
组合one-hot预测损失函数、特征正则化损失函数和对抗蒸馏损失函数,获得单教师条件下生成器的无数据蒸馏的模型反转损失函数为式中,为单教师条件下生成器的无数据蒸馏的模型反转损失,λ1、λ2和λ3分别为第一平衡参数、第二平衡参数和第三平衡参数;
根据单教师条件下生成器的无数据蒸馏的模型反转损失函数,构建多教师条件下生成器的无数据蒸馏的模型反转损失函数为式中,为多教师条件下生成器的无数据蒸馏的模型反转损失,为具有多教师信息的one-hot预测损失,为具有多教师信息的特征正则化损失,为具有多教师信息的对抗蒸馏损失,z为噪声输入,θg为生成器的参数;
其中,为M数量的集成教师模型的输出, 为第m个训练好的教师模型的输出;为生成器合成的图像输入到第m个训练好的教师模型后第l个BN层得到的均值,为生成器合成的图像输入到第m个训练好的教师模型后第l个BN层得到的方差,为输入图像x输入到第m个训练好的教师模型后第l个BN层得到的均值,为输入图像x输入到第m个训练好的教师模型后第l个BN层得到的方差。
通过多教师集成蒸馏损失优化学生网络,使得学生网络模仿教师网络的输出。
步骤4,构建学生模型模仿教师模型输出的多教师集成蒸馏损失函数。
对于每个预先训练好的教师模型以及学生模型输入图片得到教师网络和学生网络输出的logits,对教师网络输出的logits取平均得到集成后的logits,再对集成后的logits与学生网络输出到logits计算KL散度作为多教师集成蒸馏损失函数。
示例性的,多教师集成蒸馏损失函数为
步骤5,分别建立教师内部对比损失函数和师生对比损失函数。
5-1,教师内部对比损失函数的建立过程为:
对生成器合成的同一批量图像中的每张图像进行数据增强,并将每张所述图像和数据增强后的图像分别输入至每个训练好的教师模型中,获得每个训练好的教师模型输出的每张所述图像的表示和数据增强后的图像的表示;
任意选取生成器合成的同一批量图像中的一张图像为待测图像;
将待测图像的表示和数据增强后的待测图像的表示作为正样本对,生成器合成的同一批量图像中待测图像以外的图像的表示作为负样本,并将生成器合成的历史图像的表示作为负样本;
根据正样本对和负样本,确定教师内部对比损失函数为式中,为教师内部对比损失,为正样本对,为第m个训练好的教师模型输出的待测图像的表示, 为第m个头部映射网络的参数,h为头部投影网络,为第m个训练好的教师模型输出的数据增强后的待测图像的表示,为第m个训练好的教师模型对应的第i个负样本,K为负样本的数量,τ1为第一温度超参数,sim()为余弦相似度。
5-2,师生对比损失函数的建立过程为:
对生成器合成的同一批量图像中的每张图像分别输入至每个训练好的教师模型和学生模型,获得每个训练好的教师模型对每张图像的表示和学生模型对每张图像的表示;
将学生模型和与学生模型同构的训练好的教师模型对同一张图像的表示作为负样本对;所述同构为学生模型和教师模型属于同一网络结构系列;学生模型和教师模型属于哪一个网络结构系列在选取的时候是已知的;
将学生模型和与学生模型异构的训练好的教师模型对同一张图像的表示定义为正样本对;所述异构为学生模型和教师模型属于不同网络结构系列;
根据所述负样本对和定义的正样本对,确定师生对比损失函数为式中,为师生对比损失,为从生成器合成的当前批量图像中第r张图像构造出的学生模型的查询, 为生成器合成的图像输入到学生模型的输出,θh为头部映射网络的参数,h为头部投影网络,为从生成器合成的当前批量图像中第r张图像构造出的与学生模型异构的第m个训练好的教师模型的查询,D(s)为与学生模型异构的教师索引集,N为当前批量图像中图像的数量,τ2为第一温度超参数,Neg为负对集合,I(s)为与学生模型同构的教师索引集,为学生模型输出的历史图像的表示中第j个负样本,J为学生模型输出的历史图像的表示中负样本的数量,为从生成器合成的历史图像中第j张图像构造出的学生网络的查询,为与学生网络模型同构的教师模型的查询。
步骤6,利用无数据蒸馏的模型反转损失函数、教师内部对比损失函数和师生对比损失函数对生成器进行优化操作,输出本次优化后生成器合成的图像,并将本次优化后生成器合成的图像分别输入至学生模型和每个训练好的教师模型,通过多教师集成蒸馏损失函数对学生模型进行优化操作。
示例性的,步骤6具体包括:
根据无数据蒸馏的模型反转损失函数、教师内部对比损失函数和师生对比损失函数,确定生成器的优化损失函数为式中,为生成器的优化损失,为教师内部对比损失函数和师生对比损失函数的总函数,λ为和之间的平衡参数;优选地,λ的值为0.2;
初始化生成器的参数θg、学生模型的参数θs和图像库;
根据噪声输入z,利用生成器合成当前批量的图像;
根据所述当前批量的图像,利用所述优化损失函数计算生成器的当前优化损失;
将生成器合成的当前批量的图像存储至图像库;
从图像库中抽取一批量图像;
根据抽取的一批量图像,利用多教师集成蒸馏损失函数计算学生模型的当前损失;
步骤7,重复进行优化操作,直至学生模型收敛,获得压缩后的学生模型。
学生模型收敛是指学生模型的当前损失与前一次迭代计算的学生模型的损失相等。
参照图2,步骤6和步骤7的无数据蒸馏方法(Multi-teacher contrastiveKnowledgeinversion,MTCKI)的整体流程为:
输入:预训练教师模型,子集大小M
初始化:G(·;θg),z←(0,1),学生模型fs(·;θs),图像库。
对于e=1:最大纪元
对于i=1:最大迭代次数
对于t=1:最大步数
从噪声z生成一批样本G(z)。
将样本G(z)存储到图像库中。
结束
对于k=1:最大步数
从图像库中抽取一批图像G(z)。
结束
结束
输出:学生模型fS(·;θs)和图像库。
使用测试集数据对压缩后得到的学生模型以及合成的数据进行测试,所述测试指标为分类的Accuracy为和FID。
本发明的目的是针对现有技术的不足而提出的新的无数据蒸馏方法,基于多教师对比学习的模型压缩方法,在实施的过程中,一个学生可以访问多个教师,多个教师网络提供全面的指导,有利于训练一个对模型偏差具有鲁棒性的学生模型。首先将传统的模型反演扩展到基于多教师集成的模型反演,充分反演来自教师的更丰富的信息以生成可泛化的数据。另外,提出了多教师和学生之间的对比交互正则化,其中包含教师内对比和师生对比,以提高合成数据的多样性。具体来说,教师内部对比用于逐步合成具有与历史样本不同模式的新样本,而师生对比旨在推动学生与同构教师之间的关系远离表示空间中的非同构关系。以对抗的方式训练图像生成和知识转移的过程,以同时学习学生模型和生成合成数据。
本发明不依赖于模型原始的训练数据,通过多教师模型的知识蒸馏以及引入基于对比学习的损失函数,对无数据的模型压缩方法进行有效压缩和压缩后的模型具有更高的准确率,反演生成的图片具有多样性和泛化性。
本发明与现有技术相比压缩后的模型准确率更高,且适用于不同网络结构的学生网络,可以一次生成适用于多种不同的学生网络,多次生成不同数据集的计算开销以及时间,这种优点来自于生成器所合成的图片具有多样性和泛化性,由于生成器合成的图片与历史的图片进行对比,使得图片和之前生成的图片差异尽可能大,产生更丰富样式的图片,由于多个教师的正则化,以及同源的学生与其教师网络之间图片分布的拉远,消除了生成数据只适用于某一特定网络的缺点,使得图片更具有泛化性,和真实的数据集分布更接近,使得网络性能得到提升。
下面的实施例在3个公开数据集上进行测试,分别为CIFAR10、CIFAR100、Caltech-101,同时对模型进行压缩得到不同的学生网络,其测试结果如下表1:
表1数据集上性能提升对比
图像合成的评价指标通常使用FID来衡量。数值越低,表示合成的图片越接近真实的图片,生成质量越好。采用本发明通过无数据的知识蒸馏对模型进行压缩,使用该方法压缩后的模型的准确率是最高的,而且合成的图片的FID值更低,这也就意味着合成的图片更接近原始数据集。
总结以上的在CIFAR-10、CIFAR-100和Caltech-101上的比较结果。可以观察到:(1)本发明的方法在所有三个数据集上都优于现有方法。例如,当在CIFAR-10上提取相同的WRN-16-1时,本发明的方法达到了91.59%。(2)在本发明的多教师集成和CIFAR-10上的ResNet-34几乎相同的准确度(~95.7%)的情况下,本发明的方法在压缩到同一个小网络的准确率比其他基线有了显着的提升。这也就说明性能的提升来自于多教师结构和提出的对比交互损失,而不是教师的预测的提升。(3)教师和学生之间的同构结构有助于进一步提高学生在所有基线中的表现。例如,与ResNet-34相比,本发明使用相同的WRN-16-1作为学生,WRN-40-2作为教师显着提高了WRN-16-1在CIFAR-10上的准确度。(4)本发明的预训练教师没有使用MobileNet-V2;本发明的合成图像仍然可以有效地训练模型。本发明的方法比使用原始数据训练的MobileNet-V2实现了3.22%的准确度提升。这意味着使用提出的MTCKI的合成图像对于各种模型的训练具有很高的泛化性。(5)与其他方法相比,本发明用不同的学生模型生成的数据集的FID值都是最低的,并且方差要小得多。这意味着本发明的合成图像与原始数据集更接近。
参阅附图3,本发明可以更好合成图像的细节,在视觉效果上具有优越性。将该方法与SOTA方法的合成质量进行了比较显然,本发明的MTCKI反转的图像质量最高。例如,DAFL使用CIFAR-10上的预训练教师生成类噪声图像。Deepinv能够生成满意的视觉图像,但物体颜色与背景颜色接近,风格单一。因此,它与原始的CIFAR-10数据集相距甚远。DFQ和MTCKI的合成图像之间的比较表明,MTCKI可以生成更多样化的图像,而DFQ则遭受更严重的模式崩溃。尽管CMI采样的图像在颜色和风格上似乎有了一些改进,但它们仍然过于模糊而无法区分。本发明的方法在对象轮廓的清晰度、颜色匹配的合理性以及丰富、详细的信息方面提高了图像质量。对于CIFAR-10,MTCKI生成更多样化的语义图像,例如不同姿势的马的特写和各种类型的卡车。即使是像船后面的天际线这样的微小细节也能够被合成的。对于CIFAR-100,合成图像提供了丰富的语义信息,本发明可以轻松识别图中显示的对象,如熊猫、自行车、鲜花。
参阅附图4,本发明通过VGG16对合成数据集的特征画了t-sne图,可以看到的是本发明合成的数据的同一个类别聚类特点明显,与原数据的分布相似。
参阅附图5,本发明进一步分析了该方法的收敛性和变化的合成图像。与其他基线相比,本发明的方法需要更少的训练epoch来收敛,并且还实现了最低的损失。值得注意的是,在训练过程中,由于丰富的多教师信息和对比交互的有效性,如图6所示,第10个epoch合成的图像已经具有多样化的语义信息和组织良好的对象轮廓。
本发明的方法从可用的教师模型中提取特定于模型的知识并将其融合到学生模型中,以消除模型偏差。此外,使用多教师和学生之间的对比交互来提高合成图像的多样性,这鼓励合成图像与以前存储的图像区分开来。本发明与现有技术相比具有在图像生成过程中能产生更具有多样性和泛化性的图片,以及只需要生成一次就可以为各种网络而不是特定网络提供全面的指导的优势。大量实验表明,该方法不仅生成视觉上令人满意的图像,而且优于现有的最先进的方法。
本发明的方法作为一个用于无数据蒸馏的新框架,从多个可用的教师模型中提取“多视角”知识并将其融合到表现良好的学生模型中。上述方法设计了一种对比交互,充分利用来自多位师生的知识,生成具有高泛化性和多样性的合成数据。大量实验表明上述方法(MTCKI)优于现有的最先进方法。不仅合成了更接近原始数据集分布的高保真数据,而且还产生了与在原始数据集上训练的预训练模型相当的结果。本发明与现有技术相比压缩后的模型准确率更高,且适用于不同网络结构的学生网络,可以一次生成适用于多种不同的学生网络,多次生成不同数据集的计算开销以及时间,这种优点来自于生成器所合成的图片具有多样性和泛化性,由于生成器合成的图片与历史的图片进行对比,使得图片和之前生成的图片差异尽可能大,产生更丰富样式的图片,由于多个教师的正则化,以及同源的学生与其教师网络之间图片分布的拉远,消除了生成数据只适用于某一特定网络的缺点,使得图片更具有泛化性,和真实的数据集分布更接近,使得网络性能得到提升。
本发明还提供了一种基于数据保护场景下的模型压缩系统,系统包括:
预设模块,用于预设在同一数据集下多个训练好的教师模型以及随机初始化的学生模型和生成器;
三种损失函数构建模块,用于分别构建one-hot预测损失函数、特征正则化损失函数和对抗蒸馏损失函数;
组合模块,用于组合one-hot预测损失函数、特征正则化损失函数和对抗蒸馏损失函数,获得多教师条件下生成器的无数据蒸馏的模型反转损失函数;
多教师集成蒸馏损失函数构建模块,用于构建学生模型模仿教师模型输出的多教师集成蒸馏损失函数;
对比损失函数建立模块,用于分别建立教师内部对比损失函数和师生对比损失函数;
优化模块,用于利用无数据蒸馏的模型反转损失函数、教师内部对比损失函数和师生对比损失函数对生成器进行优化操作,输出本次优化后生成器合成的图像,并将本次优化后生成器合成的图像分别输入至学生模型和每个训练好的教师模型,通过多教师集成蒸馏损失函数对学生模型进行优化操作;
循环模块,用于重复进行优化操作,直至学生模型收敛,获得压缩后的学生模型。
优化模块,具体包括:
优化损失函数确定子模块,用于根据无数据蒸馏的模型反转损失函数、教师内部对比损失函数和师生对比损失函数,确定生成器的优化损失函数为式中,为生成器的优化损失,为教师内部对比损失函数和师生对比损失函数的总函数,λ为和之间的平衡参数;
初始化子模块,用于初始化生成器的参数θg、学生模型的参数θs和图像库;
合成子模块,用于根据噪声输入z,利用生成器合成当前批量的图像;
优化损失计算子模块,用于根据所述当前批量的图像,利用所述优化损失函数计算生成器的当前优化损失;
存储子模块,用于将生成器合成的当前批量的图像存储至图像库;
抽取子模块,用于从图像库中抽取一批量图像;
学生模型当前损失计算子模块,用于根据抽取的一批量图像,利用多教师集成蒸馏损失函数计算学生模型的当前损失;
本说明书中各个实施例采用递进的方式描述,每个实施例重点说明的都是与其他实施例的不同之处,各个实施例之间相同相似部分互相参见即可。对于实施例公开的系统而言,由于其与实施例公开的方法相对应,所以描述的比较简单,相关之处参见方法部分说明即可。
本文中应用了具体个例对本发明的原理及实施方式进行了阐述,以上实施例的说明只是用于帮助理解本发明的方法及其核心思想;同时,对于本领域的一般技术人员,依据本发明的思想,在具体实施方式及应用范围上均会有改变之处。综上所述,本说明书内容不应理解为对本发明的限制。
Claims (10)
1.一种基于数据保护场景下的模型压缩方法,其特征在于,所述方法包括:
预设在同一数据集下多个训练好的教师模型以及随机初始化的学生模型和生成器;
分别构建one-hot预测损失函数、特征正则化损失函数和对抗蒸馏损失函数;
组合one-hot预测损失函数、特征正则化损失函数和对抗蒸馏损失函数,获得多教师条件下生成器的无数据蒸馏的模型反转损失函数;
构建学生模型模仿教师模型输出的多教师集成蒸馏损失函数;
分别建立教师内部对比损失函数和师生对比损失函数;
利用无数据蒸馏的模型反转损失函数、教师内部对比损失函数和师生对比损失函数对生成器进行优化操作,输出本次优化后生成器合成的图像,并将本次优化后生成器合成的图像分别输入至学生模型和每个训练好的教师模型,通过多教师集成蒸馏损失函数对学生模型进行优化操作;
重复进行优化操作,直至学生模型收敛,获得压缩后的学生模型。
2.根据权利要求1所述的基于数据保护场景下的模型压缩方法,其特征在于,
所述特征正则化损失函数为式中,为特征正则化损失,为生成器合成的图像输入到训练好的教师模型后第l个BN层得到的均值,为生成器合成的图像输入到训练好的教师模型后第l个BN层得到的方差,F(μl(x)|X)为输入图像x输入到训练好的教师模型后第l个BN层得到的均值,为输入图像x输入到训练好的教师模型后第l个BN层得到的方差;
3.根据权利要求2所述的基于数据保护场景下的模型压缩方法,其特征在于,所述组合one-hot预测损失函数、特征正则化损失函数和对抗蒸馏损失函数,获得多教师条件下生成器的无数据蒸馏的模型反转损失函数,具体包括:
组合one-hot预测损失函数、特征正则化损失函数和对抗蒸馏损失函数,获得单教师条件下生成器的无数据蒸馏的模型反转损失函数为式中,为单教师条件下生成器的无数据蒸馏的模型反转损失,λ1、λ2和λ3分别为第一平衡参数、第二平衡参数和第三平衡参数;
根据单教师条件下生成器的无数据蒸馏的模型反转损失函数,构建多教师条件下生成器的无数据蒸馏的模型反转损失函数为式中,为多教师条件下生成器的无数据蒸馏的模型反转损失,为具有多教师信息的one-hot预测损失,为具有多教师信息的特征正则化损失,为具有多教师信息的对抗蒸馏损失,z为噪声输入,θg为生成器的参数;
5.根据权利要求4所述的基于数据保护场景下的模型压缩方法,其特征在于,所述教师内部对比损失函数的建立过程为:
对生成器合成的同一批量图像中的每张图像进行数据增强,并将每张所述图像和数据增强后的图像分别输入至每个训练好的教师模型中,获得每个训练好的教师模型输出的每张所述图像的表示和数据增强后的图像的表示;
任意选取生成器合成的同一批量图像中的一张图像为待测图像;
将待测图像的表示和数据增强后的待测图像的表示作为正样本对,生成器合成的同一批量图像中待测图像以外的图像的表示作为负样本,并将生成器合成的历史图像的表示作为负样本;
6.根据权利要求5所述的基于数据保护场景下的模型压缩方法,其特征在于,所述师生对比损失函数的建立过程为:
对生成器合成的同一批量图像中的每张图像分别输入至每个训练好的教师模型和学生模型,获得每个训练好的教师模型对每张图像的表示和学生模型对每张图像的表示;
将学生模型和与学生模型同构的训练好的教师模型对同一张图像的表示作为负样本对;所述同构为学生模型和教师模型属于同一网络结构系列;
将学生模型和与学生模型异构的训练好的教师模型对同一张图像的表示定义为正样本对;所述异构为学生模型和教师模型属于不同网络结构系列;
根据所述负样本对和定义的正样本对,确定师生对比损失函数为式中,为师生对比损失,为从生成器合成的当前批量图像中第r张图像构造出的学生模型的查询,为生成器合成的图像输入到学生模型的输出,θh为头部映射网络的参数,h为头部投影网络,为从生成器合成的当前批量图像中第r张图像构造出的与学生模型异构的第m个训练好的教师模型的查询,D(s)为与学生模型异构的教师索引集,N为当前批量图像中图像的数量,τ2为第一温度超参数,Neg为负对集合,I(s)为与学生模型同构的教师索引集,为学生模型输出的历史图像的表示中第j个负样本,J为学生模型输出的历史图像的表示中负样本的数量,为从生成器合成的历史图像中第j张图像构造出的学生网络的查询,为与学生网络模型同构的教师模型的查询。
7.根据权利要求6所述的基于数据保护场景下的模型压缩方法,其特征在于,所述利用无数据蒸馏的模型反转损失函数、教师内部对比损失函数和师生对比损失函数对生成器进行优化操作,输出本次优化后生成器合成的图像,并将本次优化后生成器合成的图像分别输入至学生模型和每个训练好的教师模型,通过多教师集成蒸馏损失函数对学生模型进行优化操作,具体包括:
初始化生成器的参数θg、学生模型的参数θs和图像库;
根据噪声输入z,利用生成器合成当前批量的图像;
根据所述当前批量的图像,利用所述优化损失函数计算生成器的当前优化损失;
将生成器合成的当前批量的图像存储至图像库;
从图像库中抽取一批量图像;
根据抽取的一批量图像,利用多教师集成蒸馏损失函数计算学生模型的当前损失;
8.根据权利要求7所述的基于数据保护场景下的模型压缩方法,其特征在于,所述学生模型收敛是指学生模型的当前损失与前一次迭代计算的学生模型的损失相等。
9.一种基于数据保护场景下的模型压缩系统,其特征在于,所述系统包括:
预设模块,用于预设在同一数据集下多个训练好的教师模型以及随机初始化的学生模型和生成器;
三种损失函数构建模块,用于分别构建one-hot预测损失函数、特征正则化损失函数和对抗蒸馏损失函数;
组合模块,用于组合one-hot预测损失函数、特征正则化损失函数和对抗蒸馏损失函数,获得多教师条件下生成器的无数据蒸馏的模型反转损失函数;
多教师集成蒸馏损失函数构建模块,用于构建学生模型模仿教师模型输出的多教师集成蒸馏损失函数;
对比损失函数建立模块,用于分别建立教师内部对比损失函数和师生对比损失函数;
优化模块,用于利用无数据蒸馏的模型反转损失函数、教师内部对比损失函数和师生对比损失函数对生成器进行优化操作,输出本次优化后生成器合成的图像,并将本次优化后生成器合成的图像分别输入至学生模型和每个训练好的教师模型,通过多教师集成蒸馏损失函数对学生模型进行优化操作;
循环模块,用于重复进行优化操作,直至学生模型收敛,获得压缩后的学生模型。
10.根据权利要求9所述的基于数据保护场景下的模型压缩系统,其特征在于,所述优化模块,具体包括:
优化损失函数确定子模块,用于根据无数据蒸馏的模型反转损失函数、教师内部对比损失函数和师生对比损失函数,确定生成器的优化损失函数为式中,为生成器的优化损失,为教师内部对比损失函数和师生对比损失函数的总函数,λ为和之间的平衡参数;
初始化子模块,用于初始化生成器的参数θg、学生模型的参数θs和图像库;
合成子模块,用于根据噪声输入z,利用生成器合成当前批量的图像;
优化损失计算子模块,用于根据所述当前批量的图像,利用所述优化损失函数计算生成器的当前优化损失;
存储子模块,用于将生成器合成的当前批量的图像存储至图像库;
抽取子模块,用于从图像库中抽取一批量图像;
学生模型当前损失计算子模块,用于根据抽取的一批量图像,利用多教师集成蒸馏损失函数计算学生模型的当前损失;
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210220060.9A CN114565810A (zh) | 2022-03-08 | 2022-03-08 | 一种基于数据保护场景下的模型压缩方法及系统 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210220060.9A CN114565810A (zh) | 2022-03-08 | 2022-03-08 | 一种基于数据保护场景下的模型压缩方法及系统 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN114565810A true CN114565810A (zh) | 2022-05-31 |
Family
ID=81718168
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202210220060.9A Pending CN114565810A (zh) | 2022-03-08 | 2022-03-08 | 一种基于数据保护场景下的模型压缩方法及系统 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN114565810A (zh) |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117573908A (zh) * | 2024-01-16 | 2024-02-20 | 卓世智星(天津)科技有限公司 | 基于对比学习的大语言模型蒸馏方法 |
-
2022
- 2022-03-08 CN CN202210220060.9A patent/CN114565810A/zh active Pending
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117573908A (zh) * | 2024-01-16 | 2024-02-20 | 卓世智星(天津)科技有限公司 | 基于对比学习的大语言模型蒸馏方法 |
CN117573908B (zh) * | 2024-01-16 | 2024-03-19 | 卓世智星(天津)科技有限公司 | 基于对比学习的大语言模型蒸馏方法 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN108875807B (zh) | 一种基于多注意力多尺度的图像描述方法 | |
CN111798369B (zh) | 一种基于循环条件生成对抗网络的人脸衰老图像合成方法 | |
CN114398961B (zh) | 一种基于多模态深度特征融合的视觉问答方法及其模型 | |
CN110097178A (zh) | 一种基于熵注意的神经网络模型压缩与加速方法 | |
CN109166144A (zh) | 一种基于生成对抗网络的图像深度估计方法 | |
CN112784929B (zh) | 一种基于双元组扩充的小样本图像分类方法及装置 | |
CN111210002B (zh) | 一种基于生成对抗网络模型的多层学术网络社区发现方法、系统 | |
CN112527993B (zh) | 一种跨媒体层次化深度视频问答推理框架 | |
CN114386534A (zh) | 一种基于变分自编码器和对抗生成网络的图像增广模型训练方法及图像分类方法 | |
CN109978021A (zh) | 一种基于文本不同特征空间的双流式视频生成方法 | |
CN109871504A (zh) | 一种基于异构信息网络与深度学习的课程推荐系统 | |
CN114511737B (zh) | 图像识别域泛化模型的训练方法 | |
CN111694977A (zh) | 一种基于数据增强的车辆图像检索方法 | |
CN113822953A (zh) | 图像生成器的处理方法、图像生成方法及装置 | |
CN109214442A (zh) | 一种基于列表和身份一致性约束的行人重识别算法 | |
CN110210540A (zh) | 基于注意力机制的跨社交媒体用户身份识别方法及系统 | |
CN113849725B (zh) | 一种基于图注意力对抗网络的社会化推荐方法及系统 | |
WO2022166840A1 (zh) | 人脸属性编辑模型的训练方法、人脸属性编辑方法及设备 | |
CN114565810A (zh) | 一种基于数据保护场景下的模型压缩方法及系统 | |
CN110197226B (zh) | 一种无监督图像翻译方法及系统 | |
CN117036901A (zh) | 一种基于视觉自注意力模型的小样本微调方法 | |
CN113822790B (zh) | 一种图像处理方法、装置、设备及计算机可读存储介质 | |
CN113283584B (zh) | 一种基于孪生网络的知识追踪方法及系统 | |
CN115909201A (zh) | 一种基于多分支联合学习的遮挡行人重识别方法及系统 | |
CN114742292A (zh) | 面向知识追踪过程的双态协同演化预测学生未来表现方法 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |