CN114565810A - 一种基于数据保护场景下的模型压缩方法及系统 - Google Patents

一种基于数据保护场景下的模型压缩方法及系统 Download PDF

Info

Publication number
CN114565810A
CN114565810A CN202210220060.9A CN202210220060A CN114565810A CN 114565810 A CN114565810 A CN 114565810A CN 202210220060 A CN202210220060 A CN 202210220060A CN 114565810 A CN114565810 A CN 114565810A
Authority
CN
China
Prior art keywords
model
teacher
loss function
generator
student
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202210220060.9A
Other languages
English (en)
Inventor
林绍辉
林振元
何高奇
王长波
马利庄
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
East China Normal University
Original Assignee
East China Normal University
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by East China Normal University filed Critical East China Normal University
Priority to CN202210220060.9A priority Critical patent/CN114565810A/zh
Publication of CN114565810A publication Critical patent/CN114565810A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/217Validation; Performance evaluation; Active pattern learning techniques
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F16/00Information retrieval; Database structures therefor; File system structures therefor
    • G06F16/50Information retrieval; Database structures therefor; File system structures therefor of still image data
    • G06F16/53Querying
    • G06F16/535Filtering based on additional data, e.g. user or group profiles
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • G06F18/241Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N20/00Machine learning
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Data Mining & Analysis (AREA)
  • Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Artificial Intelligence (AREA)
  • Evolutionary Computation (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Software Systems (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Computing Systems (AREA)
  • Mathematical Physics (AREA)
  • Biomedical Technology (AREA)
  • Evolutionary Biology (AREA)
  • Biophysics (AREA)
  • Computational Linguistics (AREA)
  • General Health & Medical Sciences (AREA)
  • Molecular Biology (AREA)
  • Health & Medical Sciences (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Medical Informatics (AREA)
  • Databases & Information Systems (AREA)
  • Image Analysis (AREA)

Abstract

本发明涉及一种基于数据保护场景下的模型压缩方法及系统,属于模型压缩领域,将传统的模型反演扩展到基于多教师集成的模型反演,充分反演来自教师的更丰富的信息以生成可泛化的数据。另外教师内部对比用于逐步合成具有与历史样本不同模式的新样本,师生对比旨在推动学生与同构教师之间的关系远离表示空间中的非同构关系,以提高合成数据的多样性。并且以对抗的方式训练图像生成和知识转移的过程,以同时学习学生模型和生成合成数据。本发明不依赖于模型原始的训练数据,通过多教师模型的知识蒸馏以及引入基于对比学习的损失函数,对无数据的模型压缩方法进行有效压缩和压缩后的模型具有更高的准确率。

Description

一种基于数据保护场景下的模型压缩方法及系统
技术领域
本发明涉及模型压缩领域,特别是涉及一种基于数据保护场景下的模型压缩方法及系统。
背景技术
近年来随着深度学习算力的不断发展,深度学习模型愈发庞大,而当我们需要将模型部署到终端设备上时,就不得不进行模型的压缩。知识蒸馏(KnowledgeDistillation,KD)是一种流行的压缩方法,通过从冗余的教师模型中转移知识来学习轻量级学生模型来模仿表征能力。在大多数现有的KD方法中,使用logits或来自教师的特征信息成功地将知识转移到学生模型,但是在其中需要访问整个训练数据。
不幸的是,由于隐私、保密或传输限制,预训练模型的原始训练样本通常不可用。例如,患者的医疗数据是保密的,不会公开共享以泄露患者的隐私。如果没有数据的帮助,这些方法可能无法适用。
现有技术的方法是通过人为合成的训练数据来代替原始数据。但现有的方法所生成的数据都和原始数据有一定差距,缺乏数据的多样性和泛化性。压缩后模型的准确率不够令人满意。
发明内容
本发明的目的是提供一种基于数据保护场景下的模型压缩方法及系统,不依赖于模型原始的训练数据,对无数据的模型进行有效压缩同时提高模型压缩的准确率。
为实现上述目的,本发明提供了如下方案:
一种基于数据保护场景下的模型压缩方法,所述方法包括:
预设在同一数据集下多个训练好的教师模型以及随机初始化的学生模型和生成器;
分别构建one-hot预测损失函数、特征正则化损失函数和对抗蒸馏损失函数;
组合one-hot预测损失函数、特征正则化损失函数和对抗蒸馏损失函数,获得多教师条件下生成器的无数据蒸馏的模型反转损失函数;
构建学生模型模仿教师模型输出的多教师集成蒸馏损失函数;
分别建立教师内部对比损失函数和师生对比损失函数;
利用无数据蒸馏的模型反转损失函数、教师内部对比损失函数和师生对比损失函数对生成器进行优化操作,输出本次优化后生成器合成的图像,并将本次优化后生成器合成的图像分别输入至学生模型和每个训练好的教师模型,通过多教师集成蒸馏损失函数对学生模型进行优化操作;
重复进行优化操作,直至学生模型收敛,获得压缩后的学生模型。
可选的,所述one-hot预测损失函数为
Figure BDA0003536854210000021
式中,
Figure BDA0003536854210000022
为one-hot预测损失,CE为交叉熵损失,
Figure BDA0003536854210000023
为生成器合成的图像
Figure BDA0003536854210000024
输入到训练好的教师模型后的输出,c为预定义的类;
所述特征正则化损失函数为
Figure BDA0003536854210000025
式中,
Figure BDA0003536854210000026
为特征正则化损失,
Figure BDA0003536854210000027
为生成器合成的图像
Figure BDA0003536854210000028
输入到训练好的教师模型后第l个BN层得到的均值,
Figure BDA0003536854210000029
为生成器合成的图像
Figure BDA00035368542100000210
输入到训练好的教师模型后第l个BN层得到的方差,F(μl(x)|X)为输入图像x输入到训练好的教师模型后第l个BN层得到的均值,
Figure BDA00035368542100000211
为输入图像x输入到训练好的教师模型后第l个BN层得到的方差;
所述对抗蒸馏损失函数为
Figure BDA00035368542100000212
式中,
Figure BDA00035368542100000213
为对抗蒸馏损失,KL为库尔贝克-莱布尔散度,
Figure BDA00035368542100000214
为生成器合成的图像
Figure BDA00035368542100000215
输入到学生模型的输出,τ为温度。
可选的,所述组合one-hot预测损失函数、特征正则化损失函数和对抗蒸馏损失函数,获得多教师条件下生成器的无数据蒸馏的模型反转损失函数,具体包括:
组合one-hot预测损失函数、特征正则化损失函数和对抗蒸馏损失函数,获得单教师条件下生成器的无数据蒸馏的模型反转损失函数为
Figure BDA0003536854210000035
式中,
Figure BDA0003536854210000036
为单教师条件下生成器的无数据蒸馏的模型反转损失,λ1、λ2和λ3分别为第一平衡参数、第二平衡参数和第三平衡参数;
根据单教师条件下生成器的无数据蒸馏的模型反转损失函数,构建多教师条件下生成器的无数据蒸馏的模型反转损失函数为
Figure BDA0003536854210000037
式中,
Figure BDA0003536854210000038
为多教师条件下生成器的无数据蒸馏的模型反转损失,
Figure BDA0003536854210000039
为具有多教师信息的one-hot预测损失,
Figure BDA00035368542100000310
为具有多教师信息的特征正则化损失,
Figure BDA00035368542100000311
为具有多教师信息的对抗蒸馏损失,z为噪声输入,θg为生成器的参数;
Figure BDA0003536854210000031
Figure BDA0003536854210000032
Figure BDA0003536854210000033
其中,
Figure BDA00035368542100000312
为M数量的集成教师模型的输出,
Figure BDA00035368542100000313
Figure BDA00035368542100000314
为第m个训练好的教师模型的输出;
Figure BDA00035368542100000315
为生成器合成的图像
Figure BDA00035368542100000317
输入到第m个训练好的教师模型后第l个BN层得到的均值,
Figure BDA00035368542100000316
为生成器合成的图像
Figure BDA00035368542100000318
输入到第m个训练好的教师模型后第l个BN层得到的方差,
Figure BDA00035368542100000319
为输入图像x输入到第m个训练好的教师模型后第l个BN层得到的均值,
Figure BDA00035368542100000320
为输入图像x输入到第m个训练好的教师模型后第l个BN层得到的方差。
可选的,所述多教师集成蒸馏损失函数为
Figure BDA0003536854210000034
式中,
Figure BDA00035368542100000321
为多教师集成蒸馏损失。
可选的,所述教师内部对比损失函数的建立过程为:
对生成器合成的同一批量图像中的每张图像进行数据增强,并将每张所述图像和数据增强后的图像分别输入至每个训练好的教师模型中,获得每个训练好的教师模型输出的每张所述图像的表示和数据增强后的图像的表示;
任意选取生成器合成的同一批量图像中的一张图像为待测图像;
将待测图像的表示和数据增强后的待测图像的表示作为正样本对,生成器合成的同一批量图像中待测图像以外的图像的表示作为负样本,并将生成器合成的历史图像的表示作为负样本;
根据正样本对和负样本,确定教师内部对比损失函数为
Figure BDA0003536854210000041
式中,
Figure BDA0003536854210000043
为教师内部对比损失,
Figure BDA0003536854210000044
为正样本对,
Figure BDA0003536854210000045
为第m个训练好的教师模型输出的待测图像的表示,
Figure BDA0003536854210000046
Figure BDA0003536854210000047
为第m个头部映射网络的参数,h为头部投影网络,
Figure BDA0003536854210000048
为第m个训练好的教师模型输出的数据增强后的待测图像的表示,
Figure BDA0003536854210000049
为第m个训练好的教师模型对应的第i个负样本,K为负样本的数量,τ1为第一温度超参数,sim()为余弦相似度。
可选的,所述师生对比损失函数的建立过程为:
对生成器合成的同一批量图像中的每张图像分别输入至每个训练好的教师模型和学生模型,获得每个训练好的教师模型对每张图像的表示和学生模型对每张图像的表示;
将学生模型和与学生模型同构的训练好的教师模型对同一张图像的表示作为负样本对;所述同构为学生模型和教师模型属于同一网络结构系列;
将学生模型和与学生模型异构的训练好的教师模型对同一张图像的表示定义为正样本对;所述异构为学生模型和教师模型属于不同网络结构系列;
根据所述负样本对和定义的正样本对,确定师生对比损失函数为
Figure BDA0003536854210000042
式中,
Figure BDA00035368542100000410
为师生对比损失,
Figure BDA00035368542100000411
为从生成器合成的当前批量图像中第r张图像构造出的学生模型的查询,
Figure BDA00035368542100000413
Figure BDA00035368542100000414
为生成器合成的图像
Figure BDA00035368542100000412
输入到学生模型的输出,θh为头部映射网络的参数,h为头部投影网络,
Figure BDA00035368542100000415
为从生成器合成的当前批量图像中第r张图像构造出的与学生模型异构的第m个训练好的教师模型的查询,D(s)为与学生模型异构的教师索引集,N为当前批量图像中图像的数量,τ2为第一温度超参数,Neg为负对集合,
Figure BDA0003536854210000051
I(s)为与学生模型同构的教师索引集,
Figure BDA0003536854210000054
为学生模型输出的历史图像的表示中第j个负样本,J为学生模型输出的历史图像的表示中负样本的数量,
Figure BDA0003536854210000052
为从生成器合成的历史图像中第j张图像构造出的学生网络的查询,
Figure BDA0003536854210000053
为与学生网络模型同构的教师模型的查询。
可选的,所述利用无数据蒸馏的模型反转损失函数、教师内部对比损失函数和师生对比损失函数对生成器进行优化操作,输出本次优化后生成器合成的图像,并将本次优化后生成器合成的图像分别输入至学生模型和每个训练好的教师模型,通过多教师集成蒸馏损失函数对学生模型进行优化操作,具体包括:
根据无数据蒸馏的模型反转损失函数、教师内部对比损失函数和师生对比损失函数,确定生成器的优化损失函数为
Figure BDA0003536854210000055
式中,
Figure BDA0003536854210000056
为生成器的优化损失,
Figure BDA0003536854210000057
为教师内部对比损失函数和师生对比损失函数的总函数,
Figure BDA0003536854210000058
λ为
Figure BDA0003536854210000059
Figure BDA00035368542100000510
之间的平衡参数;
初始化生成器的参数θg、学生模型的参数θs和图像库;
根据噪声输入z,利用生成器合成当前批量的图像;
根据所述当前批量的图像,利用所述优化损失函数计算生成器的当前优化损失;
将生成器的参数θg更新为
Figure BDA00035368542100000511
其中,η为系数,
Figure BDA00035368542100000512
为梯度算子;
将生成器合成的当前批量的图像存储至图像库;
从图像库中抽取一批量图像;
根据抽取的一批量图像,利用多教师集成蒸馏损失函数计算学生模型的当前损失;
将学生模型的参数θs更新为
Figure BDA00035368542100000513
可选的,所述学生模型收敛是指学生模型的当前损失与前一次迭代计算的学生模型的损失相等。
一种基于数据保护场景下的模型压缩系统,所述系统包括:
预设模块,用于预设在同一数据集下多个训练好的教师模型以及随机初始化的学生模型和生成器;
三种损失函数构建模块,用于分别构建one-hot预测损失函数、特征正则化损失函数和对抗蒸馏损失函数;
组合模块,用于组合one-hot预测损失函数、特征正则化损失函数和对抗蒸馏损失函数,获得多教师条件下生成器的无数据蒸馏的模型反转损失函数;
多教师集成蒸馏损失函数构建模块,用于构建学生模型模仿教师模型输出的多教师集成蒸馏损失函数;
对比损失函数建立模块,用于分别建立教师内部对比损失函数和师生对比损失函数;
优化模块,用于利用无数据蒸馏的模型反转损失函数、教师内部对比损失函数和师生对比损失函数对生成器进行优化操作,输出本次优化后生成器合成的图像,并将本次优化后生成器合成的图像分别输入至学生模型和每个训练好的教师模型,通过多教师集成蒸馏损失函数对学生模型进行优化操作;
循环模块,用于重复进行优化操作,直至学生模型收敛,获得压缩后的学生模型。
可选的,所述优化模块,具体包括:
优化损失函数确定子模块,用于根据无数据蒸馏的模型反转损失函数、教师内部对比损失函数和师生对比损失函数,确定生成器的优化损失函数为
Figure BDA0003536854210000061
式中,
Figure BDA0003536854210000062
为生成器的优化损失,
Figure BDA0003536854210000063
为教师内部对比损失函数和师生对比损失函数的总函数,
Figure BDA0003536854210000064
λ为
Figure BDA0003536854210000065
Figure BDA0003536854210000066
之间的平衡参数;
初始化子模块,用于初始化生成器的参数θg、学生模型的参数θs和图像库;
合成子模块,用于根据噪声输入z,利用生成器合成当前批量的图像;
优化损失计算子模块,用于根据所述当前批量的图像,利用所述优化损失函数计算生成器的当前优化损失;
生成器参数更新子模块,用于将生成器的参数θg更新为
Figure BDA0003536854210000071
其中,η为系数,
Figure BDA0003536854210000072
为梯度算子;
存储子模块,用于将生成器合成的当前批量的图像存储至图像库;
抽取子模块,用于从图像库中抽取一批量图像;
学生模型当前损失计算子模块,用于根据抽取的一批量图像,利用多教师集成蒸馏损失函数计算学生模型的当前损失;
学生模型参数更新子模块,用于将学生模型的参数θs更新为
Figure BDA0003536854210000073
根据本发明提供的具体实施例,本发明公开了以下技术效果:
本发明公开一种基于数据保护场景下的模型压缩方法及系统,首先将传统的模型反演扩展到基于多教师集成的模型反演,充分反演来自教师的更丰富的信息以生成可泛化的数据。另外提出了多教师和学生之间的对比交互正则化,其中包含教师内对比和师生对比,教师内部对比用于逐步合成具有与历史样本不同模式的新样本,而师生对比旨在推动学生与同构教师之间的关系远离表示空间中的非同构关系,以提高合成数据的多样性。并且以对抗的方式训练图像生成和知识转移的过程,以同时学习学生模型和生成合成数据。本发明不依赖于模型原始的训练数据,通过多教师模型的知识蒸馏以及引入基于对比学习的损失函数,对无数据的模型压缩方法进行有效压缩和压缩后的模型具有更高的准确率。
附图说明
为了更清楚地说明本发明实施例或现有技术中的技术方案,下面将对实施例中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动性的前提下,还可以根据这些附图获得其他的附图。
图1为本发明提供的基于数据保护场景下的模型压缩方法框架图;
图2为本发明提供的优化操作的流程图;
图3为本发明实施例提供的图片生成效果比较图;
图4为本发明实施例提供的数据分布比较图;图4(a)为MTCKI的数据分布图,图4(b)为CMI的数据分布图,图4(c)为CIFAR-10的数据分布图;
图5为不同方法训练损失曲线图;
图6为不同epoch合成的图像对比图;图6(a)为第10个epoch合成的图像,图6(b)为第100个epoch合成的图像。
具体实施方式
下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
本发明的目的是提供一种基于数据保护场景下的模型压缩方法及系统,不依赖于模型原始的训练数据,对无数据的模型进行有效压缩同时提高模型压缩的准确率。
为使本发明的上述目的、特征和优点能够更加明显易懂,下面结合附图和具体实施方式对本发明作进一步详细的说明。
本发明提供了一种基于数据保护场景下的模型压缩方法,参照图1,包括以下步骤:
步骤1,预设在同一数据集下多个训练好的教师模型以及随机初始化的学生模型和生成器。
随机挑选一些在同一数据集下训练好的教师模型,以及随机初始化的学生模型和生成器。将随机向量输入到生成器得到合成的图片。
步骤2,分别构建one-hot预测损失函数、特征正则化损失函数和对抗蒸馏损失函数。
2-1,构建one-hot预测损失函数
生成器所产生的每一个图片都应当属于一个类别,所以我们将图片输入到教师网络后得到的logits与logits值最大的那个类别计算交叉熵损失CE。
one-hot预测损失函数为
Figure BDA0003536854210000091
式中,
Figure BDA0003536854210000092
为one-hot预测损失,CE为交叉熵损失,
Figure BDA0003536854210000093
为生成器合成的图像
Figure BDA0003536854210000094
输入到训练好的教师模型后的输出,c为预定义的类;
2-2,构建特征正则化损失函数
BN层已广泛用于CNN,通过运行平均统计量(例如:运行均值μ(x)和运行方差σ2(x))在训练期间。训练后,这些统计数据存储了有关X的丰富信息。
对于生成器产生的每一个批量的图片输入到所有的教师网络中,BatchNormalization(BN)层得到的均值
Figure BDA0003536854210000095
和方差
Figure BDA0003536854210000096
和预先训练好的教师网络中的均值E(μl(x)|X)和方差
Figure BDA0003536854210000097
计算二范数,再对每一个Batch Normalization层得到的二范数损失累计求和。
特征正则化损失函数为
Figure BDA0003536854210000098
式中,
Figure BDA0003536854210000099
为特征正则化损失,
Figure BDA00035368542100000910
为生成器合成的图像
Figure BDA00035368542100000911
输入到训练好的教师模型后第l个BN层得到的均值,
Figure BDA00035368542100000912
为生成器合成的图像
Figure BDA00035368542100000913
输入到训练好的教师模型后第l个BN层得到的方差,E(μl(x)|X)为输入图像x输入到训练好的教师模型后第l个BN层得到的均值,
Figure BDA00035368542100000914
为输入图像x输入到训练好的教师模型后第l个BN层得到的方差;
2-3,构建对抗蒸馏损失函数
提出对抗性蒸馏损失以鼓励合成图像使学生-教师产生较大的分歧,生成器合成的图片通过所有的教师网络和学生网络后得到的分布拉远来保证生成的图片具有多样性。
对抗蒸馏损失函数为
Figure BDA00035368542100000915
式中,
Figure BDA00035368542100000916
为对抗蒸馏损失,KL为库尔贝克-莱布尔散度,
Figure BDA00035368542100000917
为生成器合成的图像
Figure BDA00035368542100000918
输入到学生模型的输出,τ为温度。
生成器G可以通过最小化方程来生成广义图像。因为它反转了来自多个预训练教师的知识。然而,合成图像仍然缺乏多样性,这可能导致在重新训练期间过度拟合。为此,提出了多名教师和一名学生之间的对比互动,以提高数据多样性并产生高保真图像。
步骤3,组合one-hot预测损失函数、特征正则化损失函数和对抗蒸馏损失函数,获得多教师条件下生成器的无数据蒸馏的模型反转损失函数。
多视图结构非常普遍存在于许多现实世界的数据集中。这些数据中存在多个特征,可用于正确分类图像。通过观察翅膀、身体大小或嘴巴的形状,可以将鸟类图像分类为鸟类。即使学生可以提取老师学习的所有特征,他们仍然无法“看到”老师没有发现的特征,从而限制了学生的表现。即使某些模型缺少单个学生可以学习多视图知识的视图,集成也可以收集几乎所有这些视图。我们首先考虑多个集成教师来构建一个可靠的多分支模型。我们选择所有教师的平均最终输出作为模型预测,此外,我们使用不同的教师来获取各种统计知识,以提高合成图像的可生成性和多样性。
在一个示例中,多教师条件下生成器的无数据蒸馏的模型反转损失函数的获得步骤为:
组合one-hot预测损失函数、特征正则化损失函数和对抗蒸馏损失函数,获得单教师条件下生成器的无数据蒸馏的模型反转损失函数为
Figure BDA0003536854210000104
式中,
Figure BDA0003536854210000105
为单教师条件下生成器的无数据蒸馏的模型反转损失,λ1、λ2和λ3分别为第一平衡参数、第二平衡参数和第三平衡参数;
根据单教师条件下生成器的无数据蒸馏的模型反转损失函数,构建多教师条件下生成器的无数据蒸馏的模型反转损失函数为
Figure BDA0003536854210000106
式中,
Figure BDA0003536854210000107
为多教师条件下生成器的无数据蒸馏的模型反转损失,
Figure BDA0003536854210000108
为具有多教师信息的one-hot预测损失,
Figure BDA0003536854210000109
为具有多教师信息的特征正则化损失,
Figure BDA00035368542100001010
为具有多教师信息的对抗蒸馏损失,z为噪声输入,θg为生成器的参数;
Figure BDA0003536854210000101
Figure BDA0003536854210000102
Figure BDA0003536854210000103
其中,
Figure BDA0003536854210000112
为M数量的集成教师模型的输出,
Figure BDA0003536854210000113
Figure BDA0003536854210000114
为第m个训练好的教师模型的输出;
Figure BDA0003536854210000115
为生成器合成的图像
Figure BDA0003536854210000116
输入到第m个训练好的教师模型后第l个BN层得到的均值,
Figure BDA0003536854210000117
为生成器合成的图像
Figure BDA0003536854210000118
输入到第m个训练好的教师模型后第l个BN层得到的方差,
Figure BDA0003536854210000119
为输入图像x输入到第m个训练好的教师模型后第l个BN层得到的均值,
Figure BDA00035368542100001110
为输入图像x输入到第m个训练好的教师模型后第l个BN层得到的方差。
通过多教师集成蒸馏损失优化学生网络,使得学生网络模仿教师网络的输出。
步骤4,构建学生模型模仿教师模型输出的多教师集成蒸馏损失函数。
对于每个预先训练好的教师模型以及学生模型输入图片得到教师网络和学生网络输出的logits,对教师网络输出的logits取平均得到集成后的logits,再对集成后的logits与学生网络输出到logits计算KL散度作为多教师集成蒸馏损失函数。
示例性的,多教师集成蒸馏损失函数为
Figure BDA0003536854210000111
式中,
Figure BDA00035368542100001111
为多教师集成蒸馏损失。
步骤5,分别建立教师内部对比损失函数和师生对比损失函数。
5-1,教师内部对比损失函数的建立过程为:
对生成器合成的同一批量图像中的每张图像进行数据增强,并将每张所述图像和数据增强后的图像分别输入至每个训练好的教师模型中,获得每个训练好的教师模型输出的每张所述图像的表示和数据增强后的图像的表示;
任意选取生成器合成的同一批量图像中的一张图像为待测图像;
将待测图像的表示和数据增强后的待测图像的表示作为正样本对,生成器合成的同一批量图像中待测图像以外的图像的表示作为负样本,并将生成器合成的历史图像的表示作为负样本;
根据正样本对和负样本,确定教师内部对比损失函数为
Figure BDA0003536854210000123
式中,
Figure BDA0003536854210000124
为教师内部对比损失,
Figure BDA0003536854210000125
为正样本对,
Figure BDA0003536854210000126
为第m个训练好的教师模型输出的待测图像的表示,
Figure BDA0003536854210000127
Figure BDA0003536854210000128
为第m个头部映射网络的参数,h为头部投影网络,
Figure BDA0003536854210000129
为第m个训练好的教师模型输出的数据增强后的待测图像的表示,
Figure BDA00035368542100001210
为第m个训练好的教师模型对应的第i个负样本,K为负样本的数量,τ1为第一温度超参数,sim()为余弦相似度。
5-2,师生对比损失函数的建立过程为:
对生成器合成的同一批量图像中的每张图像分别输入至每个训练好的教师模型和学生模型,获得每个训练好的教师模型对每张图像的表示和学生模型对每张图像的表示;
将学生模型和与学生模型同构的训练好的教师模型对同一张图像的表示作为负样本对;所述同构为学生模型和教师模型属于同一网络结构系列;学生模型和教师模型属于哪一个网络结构系列在选取的时候是已知的;
将学生模型和与学生模型异构的训练好的教师模型对同一张图像的表示定义为正样本对;所述异构为学生模型和教师模型属于不同网络结构系列;
根据所述负样本对和定义的正样本对,确定师生对比损失函数为
Figure BDA0003536854210000121
式中,
Figure BDA00035368542100001211
为师生对比损失,
Figure BDA00035368542100001212
为从生成器合成的当前批量图像中第r张图像构造出的学生模型的查询,
Figure BDA00035368542100001215
Figure BDA00035368542100001216
为生成器合成的图像
Figure BDA00035368542100001213
输入到学生模型的输出,θh为头部映射网络的参数,h为头部投影网络,
Figure BDA00035368542100001214
为从生成器合成的当前批量图像中第r张图像构造出的与学生模型异构的第m个训练好的教师模型的查询,D(s)为与学生模型异构的教师索引集,N为当前批量图像中图像的数量,τ2为第一温度超参数,Neg为负对集合,
Figure BDA0003536854210000122
I(s)为与学生模型同构的教师索引集,
Figure BDA00035368542100001217
为学生模型输出的历史图像的表示中第j个负样本,J为学生模型输出的历史图像的表示中负样本的数量,
Figure BDA00035368542100001218
为从生成器合成的历史图像中第j张图像构造出的学生网络的查询,
Figure BDA0003536854210000131
为与学生网络模型同构的教师模型的查询。
步骤6,利用无数据蒸馏的模型反转损失函数、教师内部对比损失函数和师生对比损失函数对生成器进行优化操作,输出本次优化后生成器合成的图像,并将本次优化后生成器合成的图像分别输入至学生模型和每个训练好的教师模型,通过多教师集成蒸馏损失函数对学生模型进行优化操作。
示例性的,步骤6具体包括:
根据无数据蒸馏的模型反转损失函数、教师内部对比损失函数和师生对比损失函数,确定生成器的优化损失函数为
Figure BDA0003536854210000132
式中,
Figure BDA0003536854210000133
为生成器的优化损失,
Figure BDA0003536854210000134
为教师内部对比损失函数和师生对比损失函数的总函数,
Figure BDA0003536854210000135
λ为
Figure BDA0003536854210000136
Figure BDA0003536854210000137
之间的平衡参数;优选地,λ的值为0.2;
初始化生成器的参数θg、学生模型的参数θs和图像库;
根据噪声输入z,利用生成器合成当前批量的图像;
根据所述当前批量的图像,利用所述优化损失函数计算生成器的当前优化损失;
将生成器的参数θg更新为
Figure BDA0003536854210000138
其中,η为系数,
Figure BDA0003536854210000139
为梯度算子;
将生成器合成的当前批量的图像存储至图像库;
从图像库中抽取一批量图像;
根据抽取的一批量图像,利用多教师集成蒸馏损失函数计算学生模型的当前损失;
将学生模型的参数θs更新为
Figure BDA00035368542100001310
步骤7,重复进行优化操作,直至学生模型收敛,获得压缩后的学生模型。
学生模型收敛是指学生模型的当前损失与前一次迭代计算的学生模型的损失相等。
参照图2,步骤6和步骤7的无数据蒸馏方法(Multi-teacher contrastiveKnowledgeinversion,MTCKI)的整体流程为:
输入:预训练教师模型,子集大小M
随机选择教师模型
Figure BDA0003536854210000141
初始化:G(·;θg),z←(0,1),学生模型fs(·;θs),图像库。
对于e=1:最大纪元
对于i=1:最大迭代次数
对于t=1:最大步数
从噪声z生成一批样本G(z)。
用等式
Figure BDA0003536854210000142
计算损耗LG。
更新G的参数
Figure BDA0003536854210000143
将样本G(z)存储到图像库中。
结束
对于k=1:最大步数
从图像库中抽取一批图像G(z)。
用等式
Figure BDA0003536854210000144
计算损耗Ls。
更新S的参数
Figure BDA0003536854210000145
结束
结束
输出:学生模型fS(·;θs)和图像库。
使用测试集数据对压缩后得到的学生模型以及合成的数据进行测试,所述测试指标为分类的Accuracy为和FID。
本发明的目的是针对现有技术的不足而提出的新的无数据蒸馏方法,基于多教师对比学习的模型压缩方法,在实施的过程中,一个学生可以访问多个教师,多个教师网络提供全面的指导,有利于训练一个对模型偏差具有鲁棒性的学生模型。首先将传统的模型反演扩展到基于多教师集成的模型反演,充分反演来自教师的更丰富的信息以生成可泛化的数据。另外,提出了多教师和学生之间的对比交互正则化,其中包含教师内对比和师生对比,以提高合成数据的多样性。具体来说,教师内部对比用于逐步合成具有与历史样本不同模式的新样本,而师生对比旨在推动学生与同构教师之间的关系远离表示空间中的非同构关系。以对抗的方式训练图像生成和知识转移的过程,以同时学习学生模型和生成合成数据。
本发明不依赖于模型原始的训练数据,通过多教师模型的知识蒸馏以及引入基于对比学习的损失函数,对无数据的模型压缩方法进行有效压缩和压缩后的模型具有更高的准确率,反演生成的图片具有多样性和泛化性。
本发明与现有技术相比压缩后的模型准确率更高,且适用于不同网络结构的学生网络,可以一次生成适用于多种不同的学生网络,多次生成不同数据集的计算开销以及时间,这种优点来自于生成器所合成的图片具有多样性和泛化性,由于生成器合成的图片与历史的图片进行对比,使得图片和之前生成的图片差异尽可能大,产生更丰富样式的图片,由于多个教师的正则化,以及同源的学生与其教师网络之间图片分布的拉远,消除了生成数据只适用于某一特定网络的缺点,使得图片更具有泛化性,和真实的数据集分布更接近,使得网络性能得到提升。
下面的实施例在3个公开数据集上进行测试,分别为CIFAR10、CIFAR100、Caltech-101,同时对模型进行压缩得到不同的学生网络,其测试结果如下表1:
表1数据集上性能提升对比
Figure BDA0003536854210000161
图像合成的评价指标通常使用FID来衡量。数值越低,表示合成的图片越接近真实的图片,生成质量越好。采用本发明通过无数据的知识蒸馏对模型进行压缩,使用该方法压缩后的模型的准确率是最高的,而且合成的图片的FID值更低,这也就意味着合成的图片更接近原始数据集。
总结以上的在CIFAR-10、CIFAR-100和Caltech-101上的比较结果。可以观察到:(1)本发明的方法在所有三个数据集上都优于现有方法。例如,当在CIFAR-10上提取相同的WRN-16-1时,本发明的方法达到了91.59%。(2)在本发明的多教师集成和CIFAR-10上的ResNet-34几乎相同的准确度(~95.7%)的情况下,本发明的方法在压缩到同一个小网络的准确率比其他基线有了显着的提升。这也就说明性能的提升来自于多教师结构和提出的对比交互损失,而不是教师的预测的提升。(3)教师和学生之间的同构结构有助于进一步提高学生在所有基线中的表现。例如,与ResNet-34相比,本发明使用相同的WRN-16-1作为学生,WRN-40-2作为教师显着提高了WRN-16-1在CIFAR-10上的准确度。(4)本发明的预训练教师没有使用MobileNet-V2;本发明的合成图像仍然可以有效地训练模型。本发明的方法比使用原始数据训练的MobileNet-V2实现了3.22%的准确度提升。这意味着使用提出的MTCKI的合成图像对于各种模型的训练具有很高的泛化性。(5)与其他方法相比,本发明用不同的学生模型生成的数据集的FID值都是最低的,并且方差要小得多。这意味着本发明的合成图像与原始数据集更接近。
参阅附图3,本发明可以更好合成图像的细节,在视觉效果上具有优越性。将该方法与SOTA方法的合成质量进行了比较显然,本发明的MTCKI反转的图像质量最高。例如,DAFL使用CIFAR-10上的预训练教师生成类噪声图像。Deepinv能够生成满意的视觉图像,但物体颜色与背景颜色接近,风格单一。因此,它与原始的CIFAR-10数据集相距甚远。DFQ和MTCKI的合成图像之间的比较表明,MTCKI可以生成更多样化的图像,而DFQ则遭受更严重的模式崩溃。尽管CMI采样的图像在颜色和风格上似乎有了一些改进,但它们仍然过于模糊而无法区分。本发明的方法在对象轮廓的清晰度、颜色匹配的合理性以及丰富、详细的信息方面提高了图像质量。对于CIFAR-10,MTCKI生成更多样化的语义图像,例如不同姿势的马的特写和各种类型的卡车。即使是像船后面的天际线这样的微小细节也能够被合成的。对于CIFAR-100,合成图像提供了丰富的语义信息,本发明可以轻松识别图中显示的对象,如熊猫、自行车、鲜花。
参阅附图4,本发明通过VGG16对合成数据集的特征画了t-sne图,可以看到的是本发明合成的数据的同一个类别聚类特点明显,与原数据的分布相似。
参阅附图5,本发明进一步分析了该方法的收敛性和变化的合成图像。与其他基线相比,本发明的方法需要更少的训练epoch来收敛,并且还实现了最低的损失。值得注意的是,在训练过程中,由于丰富的多教师信息和对比交互的有效性,如图6所示,第10个epoch合成的图像已经具有多样化的语义信息和组织良好的对象轮廓。
本发明的方法从可用的教师模型中提取特定于模型的知识并将其融合到学生模型中,以消除模型偏差。此外,使用多教师和学生之间的对比交互来提高合成图像的多样性,这鼓励合成图像与以前存储的图像区分开来。本发明与现有技术相比具有在图像生成过程中能产生更具有多样性和泛化性的图片,以及只需要生成一次就可以为各种网络而不是特定网络提供全面的指导的优势。大量实验表明,该方法不仅生成视觉上令人满意的图像,而且优于现有的最先进的方法。
本发明的方法作为一个用于无数据蒸馏的新框架,从多个可用的教师模型中提取“多视角”知识并将其融合到表现良好的学生模型中。上述方法设计了一种对比交互,充分利用来自多位师生的知识,生成具有高泛化性和多样性的合成数据。大量实验表明上述方法(MTCKI)优于现有的最先进方法。不仅合成了更接近原始数据集分布的高保真数据,而且还产生了与在原始数据集上训练的预训练模型相当的结果。本发明与现有技术相比压缩后的模型准确率更高,且适用于不同网络结构的学生网络,可以一次生成适用于多种不同的学生网络,多次生成不同数据集的计算开销以及时间,这种优点来自于生成器所合成的图片具有多样性和泛化性,由于生成器合成的图片与历史的图片进行对比,使得图片和之前生成的图片差异尽可能大,产生更丰富样式的图片,由于多个教师的正则化,以及同源的学生与其教师网络之间图片分布的拉远,消除了生成数据只适用于某一特定网络的缺点,使得图片更具有泛化性,和真实的数据集分布更接近,使得网络性能得到提升。
本发明还提供了一种基于数据保护场景下的模型压缩系统,系统包括:
预设模块,用于预设在同一数据集下多个训练好的教师模型以及随机初始化的学生模型和生成器;
三种损失函数构建模块,用于分别构建one-hot预测损失函数、特征正则化损失函数和对抗蒸馏损失函数;
组合模块,用于组合one-hot预测损失函数、特征正则化损失函数和对抗蒸馏损失函数,获得多教师条件下生成器的无数据蒸馏的模型反转损失函数;
多教师集成蒸馏损失函数构建模块,用于构建学生模型模仿教师模型输出的多教师集成蒸馏损失函数;
对比损失函数建立模块,用于分别建立教师内部对比损失函数和师生对比损失函数;
优化模块,用于利用无数据蒸馏的模型反转损失函数、教师内部对比损失函数和师生对比损失函数对生成器进行优化操作,输出本次优化后生成器合成的图像,并将本次优化后生成器合成的图像分别输入至学生模型和每个训练好的教师模型,通过多教师集成蒸馏损失函数对学生模型进行优化操作;
循环模块,用于重复进行优化操作,直至学生模型收敛,获得压缩后的学生模型。
优化模块,具体包括:
优化损失函数确定子模块,用于根据无数据蒸馏的模型反转损失函数、教师内部对比损失函数和师生对比损失函数,确定生成器的优化损失函数为
Figure BDA0003536854210000191
式中,
Figure BDA0003536854210000192
为生成器的优化损失,
Figure BDA0003536854210000193
为教师内部对比损失函数和师生对比损失函数的总函数,
Figure BDA0003536854210000194
λ为
Figure BDA0003536854210000195
Figure BDA0003536854210000196
之间的平衡参数;
初始化子模块,用于初始化生成器的参数θg、学生模型的参数θs和图像库;
合成子模块,用于根据噪声输入z,利用生成器合成当前批量的图像;
优化损失计算子模块,用于根据所述当前批量的图像,利用所述优化损失函数计算生成器的当前优化损失;
生成器参数更新子模块,用于将生成器的参数θg更新为
Figure BDA0003536854210000197
其中,η为系数,
Figure BDA0003536854210000198
为梯度算子;
存储子模块,用于将生成器合成的当前批量的图像存储至图像库;
抽取子模块,用于从图像库中抽取一批量图像;
学生模型当前损失计算子模块,用于根据抽取的一批量图像,利用多教师集成蒸馏损失函数计算学生模型的当前损失;
学生模型参数更新子模块,用于将学生模型的参数θs更新为
Figure BDA0003536854210000199
本说明书中各个实施例采用递进的方式描述,每个实施例重点说明的都是与其他实施例的不同之处,各个实施例之间相同相似部分互相参见即可。对于实施例公开的系统而言,由于其与实施例公开的方法相对应,所以描述的比较简单,相关之处参见方法部分说明即可。
本文中应用了具体个例对本发明的原理及实施方式进行了阐述,以上实施例的说明只是用于帮助理解本发明的方法及其核心思想;同时,对于本领域的一般技术人员,依据本发明的思想,在具体实施方式及应用范围上均会有改变之处。综上所述,本说明书内容不应理解为对本发明的限制。

Claims (10)

1.一种基于数据保护场景下的模型压缩方法,其特征在于,所述方法包括:
预设在同一数据集下多个训练好的教师模型以及随机初始化的学生模型和生成器;
分别构建one-hot预测损失函数、特征正则化损失函数和对抗蒸馏损失函数;
组合one-hot预测损失函数、特征正则化损失函数和对抗蒸馏损失函数,获得多教师条件下生成器的无数据蒸馏的模型反转损失函数;
构建学生模型模仿教师模型输出的多教师集成蒸馏损失函数;
分别建立教师内部对比损失函数和师生对比损失函数;
利用无数据蒸馏的模型反转损失函数、教师内部对比损失函数和师生对比损失函数对生成器进行优化操作,输出本次优化后生成器合成的图像,并将本次优化后生成器合成的图像分别输入至学生模型和每个训练好的教师模型,通过多教师集成蒸馏损失函数对学生模型进行优化操作;
重复进行优化操作,直至学生模型收敛,获得压缩后的学生模型。
2.根据权利要求1所述的基于数据保护场景下的模型压缩方法,其特征在于,
所述one-hot预测损失函数为
Figure FDA0003536854200000011
式中,
Figure FDA0003536854200000012
为one-hot预测损失,CE为交叉熵损失,
Figure FDA0003536854200000013
为生成器合成的图像
Figure FDA0003536854200000014
输入到训练好的教师模型后的输出,c为预定义的类;
所述特征正则化损失函数为
Figure FDA0003536854200000015
式中,
Figure FDA0003536854200000016
为特征正则化损失,
Figure FDA0003536854200000017
为生成器合成的图像
Figure FDA0003536854200000018
输入到训练好的教师模型后第l个BN层得到的均值,
Figure FDA0003536854200000019
为生成器合成的图像
Figure FDA00035368542000000110
输入到训练好的教师模型后第l个BN层得到的方差,F(μl(x)|X)为输入图像x输入到训练好的教师模型后第l个BN层得到的均值,
Figure FDA00035368542000000111
为输入图像x输入到训练好的教师模型后第l个BN层得到的方差;
所述对抗蒸馏损失函数为
Figure FDA0003536854200000021
式中,
Figure FDA0003536854200000022
为对抗蒸馏损失,KL为库尔贝克-莱布尔散度,
Figure FDA0003536854200000023
为生成器合成的图像
Figure FDA0003536854200000024
输入到学生模型的输出,τ为温度。
3.根据权利要求2所述的基于数据保护场景下的模型压缩方法,其特征在于,所述组合one-hot预测损失函数、特征正则化损失函数和对抗蒸馏损失函数,获得多教师条件下生成器的无数据蒸馏的模型反转损失函数,具体包括:
组合one-hot预测损失函数、特征正则化损失函数和对抗蒸馏损失函数,获得单教师条件下生成器的无数据蒸馏的模型反转损失函数为
Figure FDA0003536854200000025
式中,
Figure FDA0003536854200000026
为单教师条件下生成器的无数据蒸馏的模型反转损失,λ1、λ2和λ3分别为第一平衡参数、第二平衡参数和第三平衡参数;
根据单教师条件下生成器的无数据蒸馏的模型反转损失函数,构建多教师条件下生成器的无数据蒸馏的模型反转损失函数为
Figure FDA0003536854200000027
式中,
Figure FDA0003536854200000028
为多教师条件下生成器的无数据蒸馏的模型反转损失,
Figure FDA0003536854200000029
为具有多教师信息的one-hot预测损失,
Figure FDA00035368542000000210
为具有多教师信息的特征正则化损失,
Figure FDA00035368542000000211
为具有多教师信息的对抗蒸馏损失,z为噪声输入,θg为生成器的参数;
Figure FDA00035368542000000212
Figure FDA00035368542000000213
Figure FDA00035368542000000214
其中,
Figure FDA00035368542000000215
为M数量的集成教师模型的输出,
Figure FDA00035368542000000216
Figure FDA00035368542000000217
为第m个训练好的教师模型的输出;
Figure FDA00035368542000000218
为生成器合成的图像
Figure FDA00035368542000000219
输入到第m个训练好的教师模型后第l个BN层得到的均值,
Figure FDA00035368542000000220
为生成器合成的图像
Figure FDA00035368542000000221
输入到第m个训练好的教师模型后第l个BN层得到的方差,
Figure FDA00035368542000000222
为输入图像x输入到第m个训练好的教师模型后第l个BN层得到的均值,
Figure FDA00035368542000000223
为输入图像x输入到第m个训练好的教师模型后第l个BN层得到的方差。
4.根据权利要求3所述的基于数据保护场景下的模型压缩方法,其特征在于,所述多教师集成蒸馏损失函数为
Figure FDA0003536854200000031
式中,
Figure FDA0003536854200000032
为多教师集成蒸馏损失。
5.根据权利要求4所述的基于数据保护场景下的模型压缩方法,其特征在于,所述教师内部对比损失函数的建立过程为:
对生成器合成的同一批量图像中的每张图像进行数据增强,并将每张所述图像和数据增强后的图像分别输入至每个训练好的教师模型中,获得每个训练好的教师模型输出的每张所述图像的表示和数据增强后的图像的表示;
任意选取生成器合成的同一批量图像中的一张图像为待测图像;
将待测图像的表示和数据增强后的待测图像的表示作为正样本对,生成器合成的同一批量图像中待测图像以外的图像的表示作为负样本,并将生成器合成的历史图像的表示作为负样本;
根据正样本对和负样本,确定教师内部对比损失函数为
Figure FDA0003536854200000033
式中,
Figure FDA0003536854200000034
为教师内部对比损失,
Figure FDA0003536854200000035
为正样本对,
Figure FDA0003536854200000036
为第m个训练好的教师模型输出的待测图像的表示,
Figure FDA0003536854200000037
Figure FDA0003536854200000038
为第m个头部映射网络的参数,h为头部投影网络,
Figure FDA0003536854200000039
为第m个训练好的教师模型输出的数据增强后的待测图像的表示,
Figure FDA00035368542000000310
为第m个训练好的教师模型对应的第i个负样本,K为负样本的数量,τ1为第一温度超参数,sim()为余弦相似度。
6.根据权利要求5所述的基于数据保护场景下的模型压缩方法,其特征在于,所述师生对比损失函数的建立过程为:
对生成器合成的同一批量图像中的每张图像分别输入至每个训练好的教师模型和学生模型,获得每个训练好的教师模型对每张图像的表示和学生模型对每张图像的表示;
将学生模型和与学生模型同构的训练好的教师模型对同一张图像的表示作为负样本对;所述同构为学生模型和教师模型属于同一网络结构系列;
将学生模型和与学生模型异构的训练好的教师模型对同一张图像的表示定义为正样本对;所述异构为学生模型和教师模型属于不同网络结构系列;
根据所述负样本对和定义的正样本对,确定师生对比损失函数为
Figure FDA0003536854200000041
式中,
Figure FDA0003536854200000042
为师生对比损失,
Figure FDA0003536854200000043
为从生成器合成的当前批量图像中第r张图像构造出的学生模型的查询,
Figure FDA0003536854200000044
为生成器合成的图像
Figure FDA0003536854200000045
输入到学生模型的输出,θh为头部映射网络的参数,h为头部投影网络,
Figure FDA0003536854200000046
为从生成器合成的当前批量图像中第r张图像构造出的与学生模型异构的第m个训练好的教师模型的查询,D(s)为与学生模型异构的教师索引集,N为当前批量图像中图像的数量,τ2为第一温度超参数,Neg为负对集合,
Figure FDA0003536854200000047
I(s)为与学生模型同构的教师索引集,
Figure FDA0003536854200000048
为学生模型输出的历史图像的表示中第j个负样本,J为学生模型输出的历史图像的表示中负样本的数量,
Figure FDA0003536854200000049
为从生成器合成的历史图像中第j张图像构造出的学生网络的查询,
Figure FDA00035368542000000410
为与学生网络模型同构的教师模型的查询。
7.根据权利要求6所述的基于数据保护场景下的模型压缩方法,其特征在于,所述利用无数据蒸馏的模型反转损失函数、教师内部对比损失函数和师生对比损失函数对生成器进行优化操作,输出本次优化后生成器合成的图像,并将本次优化后生成器合成的图像分别输入至学生模型和每个训练好的教师模型,通过多教师集成蒸馏损失函数对学生模型进行优化操作,具体包括:
根据无数据蒸馏的模型反转损失函数、教师内部对比损失函数和师生对比损失函数,确定生成器的优化损失函数为
Figure FDA00035368542000000411
式中,
Figure FDA00035368542000000412
为生成器的优化损失,
Figure FDA00035368542000000413
为教师内部对比损失函数和师生对比损失函数的总函数,
Figure FDA00035368542000000414
λ为
Figure FDA00035368542000000415
Figure FDA00035368542000000416
之间的平衡参数;
初始化生成器的参数θg、学生模型的参数θs和图像库;
根据噪声输入z,利用生成器合成当前批量的图像;
根据所述当前批量的图像,利用所述优化损失函数计算生成器的当前优化损失;
将生成器的参数θg更新为
Figure FDA0003536854200000051
其中,η为系数,
Figure FDA0003536854200000052
为梯度算子;
将生成器合成的当前批量的图像存储至图像库;
从图像库中抽取一批量图像;
根据抽取的一批量图像,利用多教师集成蒸馏损失函数计算学生模型的当前损失;
将学生模型的参数θs更新为
Figure FDA0003536854200000053
8.根据权利要求7所述的基于数据保护场景下的模型压缩方法,其特征在于,所述学生模型收敛是指学生模型的当前损失与前一次迭代计算的学生模型的损失相等。
9.一种基于数据保护场景下的模型压缩系统,其特征在于,所述系统包括:
预设模块,用于预设在同一数据集下多个训练好的教师模型以及随机初始化的学生模型和生成器;
三种损失函数构建模块,用于分别构建one-hot预测损失函数、特征正则化损失函数和对抗蒸馏损失函数;
组合模块,用于组合one-hot预测损失函数、特征正则化损失函数和对抗蒸馏损失函数,获得多教师条件下生成器的无数据蒸馏的模型反转损失函数;
多教师集成蒸馏损失函数构建模块,用于构建学生模型模仿教师模型输出的多教师集成蒸馏损失函数;
对比损失函数建立模块,用于分别建立教师内部对比损失函数和师生对比损失函数;
优化模块,用于利用无数据蒸馏的模型反转损失函数、教师内部对比损失函数和师生对比损失函数对生成器进行优化操作,输出本次优化后生成器合成的图像,并将本次优化后生成器合成的图像分别输入至学生模型和每个训练好的教师模型,通过多教师集成蒸馏损失函数对学生模型进行优化操作;
循环模块,用于重复进行优化操作,直至学生模型收敛,获得压缩后的学生模型。
10.根据权利要求9所述的基于数据保护场景下的模型压缩系统,其特征在于,所述优化模块,具体包括:
优化损失函数确定子模块,用于根据无数据蒸馏的模型反转损失函数、教师内部对比损失函数和师生对比损失函数,确定生成器的优化损失函数为
Figure FDA0003536854200000061
式中,
Figure FDA0003536854200000062
为生成器的优化损失,
Figure FDA0003536854200000063
为教师内部对比损失函数和师生对比损失函数的总函数,
Figure FDA0003536854200000064
λ为
Figure FDA0003536854200000065
Figure FDA0003536854200000066
之间的平衡参数;
初始化子模块,用于初始化生成器的参数θg、学生模型的参数θs和图像库;
合成子模块,用于根据噪声输入z,利用生成器合成当前批量的图像;
优化损失计算子模块,用于根据所述当前批量的图像,利用所述优化损失函数计算生成器的当前优化损失;
生成器参数更新子模块,用于将生成器的参数θg更新为
Figure FDA0003536854200000067
其中,η为系数,
Figure FDA0003536854200000068
为梯度算子;
存储子模块,用于将生成器合成的当前批量的图像存储至图像库;
抽取子模块,用于从图像库中抽取一批量图像;
学生模型当前损失计算子模块,用于根据抽取的一批量图像,利用多教师集成蒸馏损失函数计算学生模型的当前损失;
学生模型参数更新子模块,用于将学生模型的参数θs更新为
Figure FDA0003536854200000069
CN202210220060.9A 2022-03-08 2022-03-08 一种基于数据保护场景下的模型压缩方法及系统 Pending CN114565810A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202210220060.9A CN114565810A (zh) 2022-03-08 2022-03-08 一种基于数据保护场景下的模型压缩方法及系统

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202210220060.9A CN114565810A (zh) 2022-03-08 2022-03-08 一种基于数据保护场景下的模型压缩方法及系统

Publications (1)

Publication Number Publication Date
CN114565810A true CN114565810A (zh) 2022-05-31

Family

ID=81718168

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202210220060.9A Pending CN114565810A (zh) 2022-03-08 2022-03-08 一种基于数据保护场景下的模型压缩方法及系统

Country Status (1)

Country Link
CN (1) CN114565810A (zh)

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN117573908A (zh) * 2024-01-16 2024-02-20 卓世智星(天津)科技有限公司 基于对比学习的大语言模型蒸馏方法

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN117573908A (zh) * 2024-01-16 2024-02-20 卓世智星(天津)科技有限公司 基于对比学习的大语言模型蒸馏方法
CN117573908B (zh) * 2024-01-16 2024-03-19 卓世智星(天津)科技有限公司 基于对比学习的大语言模型蒸馏方法

Similar Documents

Publication Publication Date Title
CN108875807B (zh) 一种基于多注意力多尺度的图像描述方法
CN111798369B (zh) 一种基于循环条件生成对抗网络的人脸衰老图像合成方法
CN114398961B (zh) 一种基于多模态深度特征融合的视觉问答方法及其模型
CN110097178A (zh) 一种基于熵注意的神经网络模型压缩与加速方法
CN109166144A (zh) 一种基于生成对抗网络的图像深度估计方法
CN112784929B (zh) 一种基于双元组扩充的小样本图像分类方法及装置
CN111210002B (zh) 一种基于生成对抗网络模型的多层学术网络社区发现方法、系统
CN112527993B (zh) 一种跨媒体层次化深度视频问答推理框架
CN114386534A (zh) 一种基于变分自编码器和对抗生成网络的图像增广模型训练方法及图像分类方法
CN109978021A (zh) 一种基于文本不同特征空间的双流式视频生成方法
CN109871504A (zh) 一种基于异构信息网络与深度学习的课程推荐系统
CN114511737B (zh) 图像识别域泛化模型的训练方法
CN111694977A (zh) 一种基于数据增强的车辆图像检索方法
CN113822953A (zh) 图像生成器的处理方法、图像生成方法及装置
CN109214442A (zh) 一种基于列表和身份一致性约束的行人重识别算法
CN110210540A (zh) 基于注意力机制的跨社交媒体用户身份识别方法及系统
CN113849725B (zh) 一种基于图注意力对抗网络的社会化推荐方法及系统
WO2022166840A1 (zh) 人脸属性编辑模型的训练方法、人脸属性编辑方法及设备
CN114565810A (zh) 一种基于数据保护场景下的模型压缩方法及系统
CN110197226B (zh) 一种无监督图像翻译方法及系统
CN117036901A (zh) 一种基于视觉自注意力模型的小样本微调方法
CN113822790B (zh) 一种图像处理方法、装置、设备及计算机可读存储介质
CN113283584B (zh) 一种基于孪生网络的知识追踪方法及系统
CN115909201A (zh) 一种基于多分支联合学习的遮挡行人重识别方法及系统
CN114742292A (zh) 面向知识追踪过程的双态协同演化预测学生未来表现方法

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination