CN112381209B - 一种模型压缩方法、系统、终端及存储介质 - Google Patents

一种模型压缩方法、系统、终端及存储介质 Download PDF

Info

Publication number
CN112381209B
CN112381209B CN202011269682.8A CN202011269682A CN112381209B CN 112381209 B CN112381209 B CN 112381209B CN 202011269682 A CN202011269682 A CN 202011269682A CN 112381209 B CN112381209 B CN 112381209B
Authority
CN
China
Prior art keywords
network
compression
sample
training
student network
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN202011269682.8A
Other languages
English (en)
Other versions
CN112381209A (zh
Inventor
郑强
王晓锐
高鹏
谢国彤
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Ping An Technology Shenzhen Co Ltd
Original Assignee
Ping An Technology Shenzhen Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Ping An Technology Shenzhen Co Ltd filed Critical Ping An Technology Shenzhen Co Ltd
Priority to CN202011269682.8A priority Critical patent/CN112381209B/zh
Publication of CN112381209A publication Critical patent/CN112381209A/zh
Priority to PCT/CN2021/083230 priority patent/WO2021197223A1/zh
Application granted granted Critical
Publication of CN112381209B publication Critical patent/CN112381209B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/082Learning methods modifying the architecture, e.g. adding, deleting or silencing nodes or connections
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/084Backpropagation, e.g. using gradient descent

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • General Health & Medical Sciences (AREA)
  • Computing Systems (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Computational Linguistics (AREA)
  • Data Mining & Analysis (AREA)
  • Evolutionary Computation (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Molecular Biology (AREA)
  • Artificial Intelligence (AREA)
  • General Engineering & Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Health & Medical Sciences (AREA)
  • Management, Administration, Business Operations System, And Electronic Commerce (AREA)
  • Other Investigation Or Analysis Of Materials By Electrical Means (AREA)

Abstract

本发明公开了一种模型压缩方法、系统、终端及存储介质。所述方法包括:通过样本生成器生成训练样本;基于至少一组超参组合,将所述训练样本分别输入学生网络和教师网络,对所述学生网络和教师网络进行对抗知识蒸馏训练,生成粗压缩学生网络;通过随机样本生成器生成样本,将所述生成样本输入所述教师网络,由所述教师网络生成合成样本集;通过所述合成样本集对所述粗压缩学生网络进行有监督学习训练,得到所述学生网络的压缩结果。本发明实施例可完全不依赖于原始训练数据集实现对模型的压缩,解决了因为原始训练数据集敏感性和数据海量问题导致模型压缩工作无法展开的问题,有效降低了模型压缩的精度损失。本发明还涉及区块链技术领域。

Description

一种模型压缩方法、系统、终端及存储介质
技术领域
本发明涉及人工智能技术领域,特别是涉及一种模型压缩方法、系统、终端及存储介质。
背景技术
在人工智能领域,模型生命周期通常可分为模型训练和模型推理两个环节。在模型训练环节,为追求模型具有更高的预测精准度,模型往往不可避免的存在冗余。而在模型推理环节,由于受到推理应用环境的苛刻要求,除了关注模型预测的精准度外,还希望模型具有推理速度快、资源占用省、文件尺寸小等高性能特点。模型压缩恰恰是将模型从模型训练环节向模型推理环节转变的常用优化手段。
目前,业界主流的模型压缩技术包括剪枝、量化和知识蒸馏等,这些主流技术都需要通过原始训练数据集参与才能完成对模型的优化过程模型;其中,剪枝技术需要通过原始训练数据集完成剪枝决策和剪枝后重建(Fine-Tune);模型量化需要通过原始训练数据集来完成Quantization-aware training(训练中引入量化)过程或者通过原始训练数据集Post-training quantization(训练后的量化)的Calibration(校准)过程;知识蒸馏需要通过原始训练数据集分别送入Teacher网络和Student网络完成Knowledge-Transfer((知识转移))的过程。
从行业的发展状况来看,模型训练和模型压缩往往由不同的职能团队承担,且分工比较明确。而由于训练数据涉及私密性或者数据海量(难于传输和存储)等原因,获得原始训练数据集的难度较大,影响模型压缩工作的进展。
近期,虽然逐渐也有不依赖于原始训练数据集的模型压缩技术出现,例如对抗知识蒸馏;但由于该技术成熟度不高,还存在如下不足:
1.对抗知识蒸馏过程波动性和随机性大,难于稳定复现;
2.对抗知识蒸馏的精度损失较大,难于满足实际应用要求。
发明内容
本发明提供了一种模型压缩方法、系统、终端及存储介质,能够在一定程度上解决现有技术中存在的不足。
为解决上述技术问题,本发明采用的技术方案为:
一种模型压缩方法,包括:
通过样本生成器生成训练样本;
基于至少一组超参组合,将所述训练样本分别输入学生网络和教师网络,对所述学生网络和教师网络进行对抗知识蒸馏训练,生成粗压缩学生网络;
通过随机样本生成器生成样本,将所述生成样本输入所述教师网络,由所述教师网络生成合成样本集;
通过所述合成样本集对所述粗压缩学生网络进行有监督学习训练,得到所述学生网络的压缩结果。
本发明实施例采取的技术方案还包括:所述基于至少一组超参组合,将所述训练样本分别输入学生网络和教师网络,对所述学生网络和教师网络进行对抗知识蒸馏训练包括:
基于超参组合,将随机数生成器生成的第一随机数输入样本生成器,由所述样本生成器生成第一训练样本;
将所述第一训练样本分别输入教师网络和学生网络,分别由所述教师网络和学生网络输出第一预测结果y和y_hat,根据所述第一预测结果y和y_hat计算出第一损失值;
根据所述第一损失值对所述学生网络进行反向传播,对所述学生网络进行参数更新;
将所述随机数生成器生成的第二随机数输入样本生成器,由所述样本生成器生成第二训练样本;
将所述第二训练样本分别输入至所述参数更新后的学生网络和教师网络,分别由所述教师网络和学生网络输出第二预测结果y和y_hat,根据所述第二预测结果y和y_hat计算出第二损失值;
根据所述第二损失值对所述样本生成器进行反向传播,对所述样本生成器进行参数更新;
迭代执行所述学生网络和样本生成器的更新,直到满足迭代结束条件时结束迭代,得到基于超参组合对所述学生网络进行的对抗知识蒸馏的训练结果,并保存所述学生网络的参数。
本发明实施例采取的技术方案还包括:所述根据所述第一损失值对所述学生网络进行反向传播,对所述学生网络的参数进行更新还包括:
根据预设迭代次数迭代执行所述第一损失值的计算以及所述学生网络的反向传播,对所述学生网络的参数进行预设迭代次数的更新。
本发明实施例采取的技术方案还包括:所述根据所述第二损失值对所述样本生成器进行反向传播,对所述样本生成器进行参数更新还包括:
根据预设迭代次数迭代执行所述第二损失值的计算以及所述样本生成器的反向传播,对所述样本生成器的参数进行预设迭代次数的更新。
本发明实施例采取的技术方案还包括:所述由所述教师网络生成合成样本集包括:
通过Label生成器生成期望标签;
通过随机样本生成器生成样本;
将所述生成样本输入教师网络,由所述教师网络输出第一预测标签;
基于所述期望标签和第一预测标签计算出第三损失函数;
基于所述第三损失函数对所述生成样本进行梯度更新;
根据预设迭代次数迭代执行所述生成样本的梯度更新,直到所述生成样本满足预设条件。
本发明实施例采取的技术方案还包括:所述将所述合成样本集输入所述粗压缩学生网络,对所述粗压缩学生网络进行有监督学习训练包括:
将所述合成样本输入粗压缩学生网络,由所述粗压缩学生网络输出第二预测标签;
基于所述第二预测标签计算出所述学生网络最终的损失函数。
本发明实施例采取的技术方案还包括:将所述合成样本集输入所述粗压缩学生网络,对所述粗压缩学生网络进行有监督学习训练后还包括:
根据模型评价指标对所述学生网络压缩结果进行评价;所述模型评价指标包括准确率。
本发明实施例采取的另一技术方案为:一种模型压缩系统,包括:
蒸馏压缩模块:用于基于至少一组超参组合,将样本生成器生成的训练样本分别输入学生网络和教师网络,对所述学生网络和教师网络进行对抗知识蒸馏训练,生成粗压缩学生网络;
样本生成模块:用于通过随机样本生成器生成样本,将所述生成样本输入所述教师网络,由所述教师网络生成合成样本集;
监督学习模块:用于通过所述合成样本集对所述粗压缩学生网络进行有监督学习训练,得到所述学生网络的压缩结果。
本发明实施例采取的又一技术方案为:一种终端,所述终端包括处理器、与所述处理器耦接的存储器,其中,
所述存储器存储有用于实现上述的模型压缩方法的程序指令;
所述处理器用于执行所述存储器存储的所述程序指令以执行所述模型压缩操作。
本发明实施例采取的又一技术方案为:一种存储介质,存储有处理器可运行的程序指令,所述程序指令用于执行上述的模型压缩方法。
本发明的有益效果是:本发明实施例的模型压缩方法、系统、终端及存储介质通过粗压缩和精压缩两个阶段进行模型压缩,在粗压缩阶段,采用对抗知识蒸馏方法对学生网络进行蒸馏压缩,生成粗压缩学生网络;在精压缩阶段,通过教师网络生成高质量的合成样本集,并通过合成样本集对粗压缩学生网络进行有监督学习,实现无需原始训练数据集的模型压缩。相对于现有技术,本发明实施例至少具有以下优点:
1、可完全不依赖于原始训练数据集实现对模型的压缩,解决了因为原始训练数据集敏感性和数据海量问题导致模型压缩工作无法展开的问题;
2、弥补了对抗知识蒸馏方法进行模型压缩时随机性大、波动性大以及难以控制和调试的问题;
3、有效降低了模型压缩的精度损失,做到几乎无损。
附图说明
图1是本发明第一实施例的模型压缩方法的流程示意图;
图2是本发明第二实施例的模型压缩方法的流程示意图;
图3为本发明实施例对学生模型进行蒸馏压缩的实现过程示意图;
图4为本发明实施例的合成样本集生成过程示意图;
图5是本发明第三实施例的模型压缩方法的流程示意图;
图6是本发明实施例的学生网络训练过程示意图;
图7是本发明一个实施例中的学生模型性能示意图;
图8是本发明一个实施例中的合成样本示意图;
图9是本发明一个实施例中的模型压缩结果示意图;
图10是本发明实施例模型压缩系统的结构示意图;
图11是本发明实施例的本发明实施例的终端结构示意图;
图12是本发明实施例的存储介质结构示意图。
具体实施方式
下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅是本发明的一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
本发明中的术语“第一”、“第二”、“第三”仅用于描述目的,而不能理解为指示或暗示相对重要性或者隐含指明所指示的技术特征的数量。由此,限定有“第一”、“第二”、“第三”的特征可以明示或者隐含地包括至少一个该特征。本发明的描述中,“多个”的含义是至少两个,例如两个,三个等,除非另有明确具体的限定。本发明实施例中所有方向性指示(诸如上、下、左、右、前、后……)仅用于解释在某一特定姿态(如附图所示)下各部件之间的相对位置关系、运动情况等,如果该特定姿态发生改变时,则该方向性指示也相应地随之改变。此外,术语“包括”和“具有”以及它们任何变形,意图在于覆盖不排他的包含。例如包含了一系列步骤或单元的过程、方法、系统、产品或设备没有限定于已列出的步骤或单元,而是可选地还包括没有列出的步骤或单元,或可选地还包括对于这些过程、方法、产品或设备固有的其它步骤或单元。
在本文中提及“实施例”意味着,结合实施例描述的特定特征、结构或特性可以包含在本发明的至少一个实施例中。在说明书中的各个位置出现该短语并不一定均是指相同的实施例,也不是与其它实施例互斥的独立的或备选的实施例。本领域技术人员显式地和隐式地理解的是,本文所描述的实施例可以与其它实施例相结合。
针对现有技术存在的不足,本发明实施例的模型压缩方法通过将整个压缩阶段分成粗压缩和精压缩两个阶段,首先,在粗压缩阶段,采用对抗知识蒸馏方法实现对被压缩模型的粗略压缩,得到粗压缩结果;在精压缩阶段,采用监督学习的方法对粗压缩结果进行微调,得到更高精度的压缩结果,从而在不依赖原始训练数据集的情况下完成对模型的高精度压缩。
具体的,请参阅图1,是本发明第一实施例的模型压缩方法的流程示意图。本发明第一实施例的模型压缩方法包括以下步骤:
S10:通过样本生成器生成训练样本;
S11:基于至少一组超参组合,将训练样本分别输入学生网络和教师网络,对学生网络和教师网络进行对抗知识蒸馏训练,生成粗压缩学生网络;
本步骤中,对学生网络和教师网络进行对抗知识蒸馏训练具体包括以下步骤:
S11a:基于超参组合H1,将随机数生成器生成的第一随机数r1输入G网络(Generator,样本生成器),由G网络生成第一训练样本x1;
S11b:将第一训练样本x1分别输入至T网络(教师网络)和S网络(学生网络),分别由T网络和S网络输出第一预测结果y和y_hat,损失函数计算器L根据第一预测结果y和y_hat计算出第一损失值loss_s;
S11c:根据第一损失值loss_s对S网络进行反向传播,对S网络进行参数更新;其中,S网络的参数更新目标是使得loss_s越来越小,即与S网络与T网络越来越接近;
S11d:迭代执行S11a至S11c K次,通过反向传播对S网络的参数进行K次更新;此时G网络系数不更新;
S11e:将随机数生成器生成的第二随机数r2输入G网络,由G网络生成第二训练样本x2;
S11f:将第二训练样本x2分别输入至参数更新后的T网络和S网络,分别由T网络和S网络输出第二预测结果y和y_hat,损失函数计算器L根据第二预测结果y和y_hat计算出第二损失值loss_g1;
S11g:根据第二损失值loss_g1对G网络进行反向传播,对G网络的参数进行更新;其中,G网络的参数更新目标是使得第一损失值loss_s越来越大;
S11h:迭代执行S11e至S11g M次,通过反向传播对G网络的参数进行M次更新;此时G网络系数不更新;
S11i:迭代执行S11a至S11h,直到S网络的ACC1(评价指标,如Accuracy)不再明显增加时结束迭代,得到基于超参组合H1进行对抗知识蒸馏的训练结果{H1,S1,ACC1},并保存S网络参数。
S12:通过随机样本生成器生成样本,将生成样本输入教师网络,由教师网络生成合成样本集;
本步骤中,合成样本集的生成方式具体包括:
S12a:通过Label生成器生成期望标签label;
S12b:通过随机样本生成器生成样本Sample(B,H,W,C),其中B=Batch(图片数)、H=Height(图片长)、W=Width(图片宽)、C=Channel(通道数);
S12c:将Sample输入T网络,由T网络输出第一预测标签label_hat1;
S12d:损失函数计算器L基于label生成器生成的label和T网络输出的label_hat1计算出第三损失函数loss_g2;
S12e:基于第三损失函数loss_g2对生成样本Sample进行梯度更新;
S12f:迭代执行S12c至S12e M次,对生成样本Sample进行M次梯度更新,直到生成样本Sample满足预设要求。
S13:通过合成样本集对粗压缩学生网络进行有监督学习训练,得到学生网络的压缩结果。
请参阅图2,是本发明第二实施例的模型压缩方法的流程示意图。本发明第二实施例的模型压缩方法包括以下步骤:
S20:构造学生网络结构;
本步骤中,本发明实施例基于预训练模型结构和模型压缩目标构造出更轻量级的学生网络模型结构。
S21:在不使用原始训练数据集的情况下,通过对抗知识蒸馏方式在多种超参组合方式下对学生网络进行蒸馏压缩,得到粗压缩S网络;
本步骤中,请一并参阅图3,为本发明实施例对学生模型进行蒸馏压缩的实现过程,其具体包括:
S21a:从超参组合簇(H1,H2,H3…HN)中取一个超参组合H1,用于对抗知识蒸馏的训练;
S21b:将随机数生成器生成的第一随机数r1输入G网络(Generator,样本生成器),由G网络生成第一训练样本x1;
S21c:将第一训练样本x1分别输入至T网络(教师网络)和S网络(学生网络),分别由T网络和S网络输出第一预测结果y和y_hat,损失函数计算器L根据第一预测结果y和y_hat计算出第一损失值loss_s;
S21d:根据第一损失值loss_s对S网络进行反向传播,对S网络进行参数更新;其中,S网络的参数更新目标是使得loss_s越来越小,即与S网络与T网络越来越接近;
S21e:迭代执行步骤S21b至S21d K次,通过反向传播对S网络的参数进行K次更新;此时G网络系数不更新;
S21f:将随机数生成器生成的第二随机数r2输入G网络,由G网络生成第二训练样本x2;
S21g:将第二训练样本x2分别输入至参数更新后的T网络和S网络,分别由T网络和S网络输出第二预测结果y和y_hat,损失函数计算器L根据第二预测结果y和y_hat计算出第二损失值loss_g1;
S21h:根据第二损失值loss_g1对G网络进行反向传播,对G网络的参数进行更新;其中,G网络的参数更新目标是使得第一损失值loss_s越来越大;
S21i:迭代执行步骤S21f至S21h M次,通过反向传播对G网络的参数进行M次更新;此时G网络系数不更新;
S21j:迭代执行步骤S21b至S21i,直到S网络的ACC1(评价指标,如Accuracy)不再明显增加时结束迭代,得到基于超参组合H1进行对抗知识蒸馏的训练结果{H1,S1,ACC1},并保存S网络参数;
S21k:基于超参组合簇(H1,H2,H3…HN)中的超参组合,迭代执行步骤S21b至步骤S21i n(n∈N)次,得到基于超参组合簇(H1,H2,H3…HN)进行对抗知识蒸馏的训练结果{Hn,Sn,ACCn}。
S22:通过随机样本生成器生成样本,将生成样本输入教师网络,由教师网络生成合成样本集;
本步骤中,请一并参阅图4,为本发明实施例的合成样本集生成过程示意图。合成样本集生成方式具体包括:
S22a:通过Label生成器生成期望标签label;
S22b:通过随机样本生成器生成样本Sample(B,H,W,C),其中B=Batch(图片数)、H=Height(图片长)、W=Width(图片宽)、C=Channel(通道数);
S22c:将Sample输入T网络,由T网络输出第一预测标签label_hat1;
S22d:损失函数计算器L基于label生成器生成的label和T网络输出的label_hat1计算出第三损失函数loss_g2;
S22e:基于第三损失函数loss_g2对生成样本Sample进行梯度更新;
S22f:迭代执行S22c至S22e M次,对生成样本Sample进行M次梯度更新,直到生成样本Sample满足预设要求;
S22g:迭代执行S22a至S22f,生成合成样本集{Sample(B,H,W,C),label}。
S23:通过合成样本集对粗压缩后的S网络进行有监督学习训练,得到S网络的压缩结果。
请参阅图5,是本发明第三实施例的模型压缩方法的流程示意图。本发明第三实施例的模型压缩方法包括以下步骤:
S30:构造学生网络结构;
本步骤中,本发明实施例基于预训练模型结构和模型压缩目标构造出更轻量级的学生网络模型结构。
S31:在不使用原始训练数据集的情况下,通过对抗知识蒸馏方式在多种超参组合方式下对学生网络进行蒸馏压缩,得到粗压缩S网络;
本步骤中,对学生模型进行蒸馏压缩的实现过程,其具体包括:
S31a:从超参组合簇(H1,H2,H3…HN)中取一个超参组合H1,用于对抗知识蒸馏的训练;
S31b:将随机数生成器生成的第一随机数r1输入G网络(Generator,样本生成器),由G网络生成第一训练样本x1;
S31c:将第一训练样本x1分别输入至T网络(教师网络)和S网络(学生网络),分别由T网络和S网络输出第一预测结果y和y_hat,损失函数计算器L根据第一预测结果y和y_hat计算出第一损失值loss_s;
S31d:根据第一损失值loss_s对S网络进行反向传播,对S网络进行参数更新;其中,S网络的参数更新目标是使得loss_s越来越小,即与S网络与T网络越来越接近;
S31e:迭代执行步骤S31b至S31d K次,通过反向传播对S网络的参数进行K次更新;此时G网络系数不更新;
S31f:将随机数生成器生成的第二随机数r2输入G网络,由G网络生成第二训练样本x2;
S31g:将第二训练样本x2分别输入至参数更新后的T网络和S网络,分别由T网络和S网络输出第二预测结果y和y_hat,损失函数计算器L根据第二预测结果y和y_hat计算出第二损失值loss_g1;
S31h:根据第二损失值loss_g1对G网络进行反向传播,对G网络的参数进行更新;其中,G网络的参数更新目标是使得第一损失值loss_s越来越大;
S31i:迭代执行步骤S31f至S31h M次,通过反向传播对G网络的参数进行M次更新;此时G网络系数不更新;
S31j:迭代执行步骤S31b至S31i,直到S网络的ACC1(评价指标,如Accuracy)不再明显增加时结束迭代,得到基于超参组合H1进行对抗知识蒸馏的训练结果{H1,S1,ACC1},并保存S网络参数;
S31k:基于超参组合簇(H1,H2,H3…HN)中的超参组合,迭代执行步骤S31b至步骤S31i n(n∈N)次,得到基于超参组合簇(H1,H2,H3…HN)进行对抗知识蒸馏的训练结果{Hn,Sn,ACCn}。
S32:通过随机样本生成器生成样本,将生成样本输入教师网络,由教师网络生成合成样本集;
本步骤中,合成样本集生成过程具体包括:
S32a:通过Label生成器生成期望标签label;
S32b:通过随机样本生成器生成样本Sample(B,H,W,C),其中B=Batch(图片数)、H=Height(图片长)、W=Width(图片宽)、C=Channel(通道数);
S32c:将Sample输入T网络,由T网络输出第一预测标签label_hat1;
S32d:损失函数计算器L基于label生成器生成的label和T网络输出的label_hat1计算出第三损失函数loss_g2;
S32e:基于第三损失函数loss_g2对生成样本Sample进行梯度更新;
S32f:迭代执行S32c至S32e M次,对生成样本Sample进行M次梯度更新,直到生成样本Sample满足预设要求;
S32g:迭代执行S32a至S32f,生成合成样本集{Sample(B,H,W,C),label}。
S33:通过合成样本集对粗压缩后的S网络进行有监督学习训练,得到S网络的压缩结果;
本步骤中,请一并参阅图6,是本发明实施例的学生网络有监督学习训练过程示意图。首先,将合成样本Sample(B,H,W,C)输入粗压缩后的S网络,由S网络输出第二预测标签label_hat2;损失函数计算器L基于第二预测标签label_hat2计算出S网络的损失函数loss_s。
S34:根据模型评价指标对学生网络压缩结果进行评价;
其中,模型评价指标包括但不限于Accuracy(准确率)等。
可以理解,上述实施例中的K次、M次等迭代次数可根据实际应用场景进行设定。
为了验证本发明实施例的可行性和有效性,以将本发明实施例的模型压缩方法应用于基于Transformer架构进行的OCR任务模型压缩为例进行实验。在经过8组超参组合配置下,通过对抗知识蒸馏后,在系统框图中A节点可得到如图7所示性能的学生模型。通过样本合成迭代训练后,在系统框图中B节点处可得到如图8所示的高质量的合成样本。通过第二阶段的监督学习后,在系统框图中C节点处可得到如图9所示的模型压缩结果。实验结果证明,本发明实施例可以在不依赖于原始训练数据集的情况下得到高精度的模型压缩结果。
基于上述,本发明实施例的模型压缩方法通过粗压缩和精压缩两个阶段进行模型压缩,在粗压缩阶段,采用对抗知识蒸馏方法对学生网络进行蒸馏压缩,生成粗压缩后的学生网络;在精压缩阶段,通过教师网络生成高质量的合成样本集,并通过合成样本集对粗压缩后的学生网络进行有监督学习,实现无需原始训练数据集的模型压缩。相对于现有技术,本发明实施例至少具有以下优点:
1、可完全不依赖于原始训练数据集实现对模型的压缩,解决了因为原始训练数据集敏感性和数据海量问题导致模型压缩工作无法展开的问题;
2、弥补了对抗知识蒸馏方法进行模型压缩时随机性大、波动性大以及难以控制和调试的问题;
3、有效降低了模型压缩的精度损失,做到几乎无损。
在一个可选的实施方式中,还可以:将所述的模型压缩方法的结果上传至区块链中。
具体地,基于所述的模型压缩方法的结果得到对应的摘要信息,具体来说,摘要信息由所述的模型压缩方法的结果进行散列处理得到,比如利用sha256s算法处理得到。将摘要信息上传至区块链可保证其安全性和对用户的公正透明性。用户可以从区块链中下载得该摘要信息,以便查证所述的模型压缩方法的结果是否被篡改。本示例所指区块链是分布式数据存储、点对点传输、共识机制、加密算法等计算机技术的新型应用模式。区块链(Blockchain),本质上是一个去中心化的数据库,是一串使用密码学方法相关联产生的数据块,每一个数据块中包含了一批次网络交易的信息,用于验证其信息的有效性(防伪)和生成下一个区块。区块链可以包括区块链底层平台、平台产品服务层以及应用服务层等。
请参阅图10,是本发明实施例模型压缩系统的结构示意图。本发明实施例模型压缩系统40包括:
蒸馏压缩模块41:用于基于至少一组超参组合,采用对抗知识蒸馏方法对学生网络进行蒸馏压缩,生成粗压缩学生网络;具体的,蒸馏压缩模块42对学生模型进行蒸馏压缩的实现过程具体为:
第一步:从超参组合簇(H1,H2,H3…HN)中取一个超参组合H1,用于对抗知识蒸馏的训练;
第二步:将随机数生成器生成的第一随机数r1输入G网络(Generator,样本生成器),由G网络生成第一训练样本x1;
第三步:将第一训练样本x1分别输入至T网络(教师网络)和S网络(学生网络),分别由T网络和S网络输出第一预测结果y和y_hat,损失函数计算器L根据第一预测结果y和y_hat计算出第一损失值loss_s;
第四步:根据第一损失值loss_s对S网络进行反向传播,对S网络进行参数更新;其中,S网络的参数更新目标是使得loss_s越来越小,即与S网络与T网络越来越接近;
第五步:迭代执行第二步至第四步K次,通过反向传播对S网络的参数进行K次更新;此时G网络系数不更新;
第六步:将随机数生成器生成的第二随机数r2输入G网络,由G网络生成第二训练样本x2;
第七步:将第二训练样本x2分别输入至参数更新后的T网络和S网络,分别由T网络和S网络输出第二预测结果y和y_hat,损失函数计算器L根据第二预测结果y和y_hat计算出第二损失值loss_g1;
第八步:根据第二损失值loss_g1对G网络进行反向传播,对G网络的参数进行更新;其中,G网络的参数更新目标是使得第一损失值loss_s越来越大;
第九步:迭代执行第六步至第八步M次,通过反向传播对G网络的参数进行M次更新;此时G网络系数不更新;
第十步:迭代执行第二步至第九步,直到S网络的ACC1(评价指标,如Accuracy)不再明显增加时结束迭代,得到基于超参组合H1进行对抗知识蒸馏的训练结果{H1,S1,ACC1},并保存S网络参数;
第十一步:基于超参组合簇(H1,H2,H3…HN)中的超参组合,迭代执行步骤第二步至第九步n(n∈N)次,得到基于超参组合簇(H1,H2,H3…HN)进行对抗知识蒸馏的训练结果{Hn,Sn,ACCn}。
样本生成模块42:用于通过随机样本生成器生成样本,将生成样本输入教师网络,由教师网络生成合成样本集;其中,合成样本集的生成过程具体包括:
第一步:通过Label生成器生成期望标签label;
第二步:通过随机样本生成器生成样本Sample(B,H,W,C),其中B=Batch(图片数)、H=Height(图片长)、W=Width(图片宽)、C=Channel(通道数);
第三步:将Sample输入T网络,由T网络输出第一预测标签label_hat1;
第四步:损失函数计算器L基于label生成器生成的label和T网络输出的label_hat1计算出第三损失函数loss_g2;
第五步:基于第三损失函数loss_g2对生成样本Sample进行梯度更新;
第六步:迭代执行第三步至第五步M次,对生成样本Sample进行M次梯度更新,直到生成样本Sample满足预设要求;
第七步:迭代执行第一步至第六步,生成合成样本集{Sample(B,H,W,C),label}。
监督学习模块43:用于将合成样本集输入粗压缩学生网络,对粗压缩学生网络进行有监督学习训练,得到学生网络压缩结果;其中,学生网络的有监督学习训练过程具体为:首先,将合成样本Sample(B,H,W,C)输入粗压缩后的S网络,由S网络输出第二预测标签label_hat2;损失函数计算器L基于第二预测标签label_hat2计算出S网络的损失函数loss_s。
请参阅图11,为本发明实施例的终端结构示意图。该终端50包括处理器51、与处理器51耦接的存储器52。
存储器52存储有用于实现上述模型压缩方法的程序指令。
处理器51用于执行存储器52存储的程序指令以执行模型压缩操作。
其中,处理器51还可以称为CPU(Central Processing Unit,中央处理单元)。处理器51可能是一种集成电路芯片,具有信号的处理能力。处理器51还可以是通用处理器、数字信号处理器(DSP)、专用集成电路(ASIC)、现成可编程门阵列(FPGA)或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件。通用处理器可以是微处理器或者该处理器也可以是任何常规的处理器等。
请参阅图12,图12为本发明实施例的存储介质的结构示意图。本发明实施例的存储介质存储有能够实现上述所有方法的程序文件61,其中,该程序文件61可以以软件产品的形式存储在上述存储介质中,包括若干指令用以使得一台计算机设备(可以是个人计算机,服务器,或者网络设备等)或处理器(processor)执行本发明各个实施方式方法的全部或部分步骤。而前述的存储介质包括:U盘、移动硬盘、只读存储器(ROM,Read-OnlyMemory)、随机存取存储器(RAM,Random Access Memory)、磁碟或者光盘等各种可以存储程序代码的介质,或者是计算机、服务器、手机、平板等终端设备。
在本发明所提供的几个实施例中,应该理解到,所揭露的系统,装置和方法,可以通过其它的方式实现。例如,以上所描述的系统实施例仅仅是示意性的,例如,单元的划分,仅仅为一种逻辑功能划分,实际实现时可以有另外的划分方式,例如多个单元或组件可以结合或者可以集成到另一个系统,或一些特征可以忽略,或不执行。另一点,所显示或讨论的相互之间的耦合或直接耦合或通信连接可以是通过一些接口,装置或单元的间接耦合或通信连接,可以是电性,机械或其它的形式。
另外,在本发明各个实施例中的各功能单元可以集成在一个处理单元中,也可以是各个单元单独物理存在,也可以两个或两个以上单元集成在一个单元中。上述集成的单元既可以采用硬件的形式实现,也可以采用软件功能单元的形式实现。以上仅为本发明的实施方式,并非因此限制本发明的专利范围,凡是利用本发明说明书及附图内容所作的等效结构或等效流程变换,或直接或间接运用在其他相关的技术领域,均同理包括在本发明的专利保护范围内。

Claims (8)

1.一种模型压缩方法,其特征在于,包括:
通过样本生成器生成训练样本;
基于至少一组超参组合,将所述训练样本分别输入学生网络和教师网络,对所述学生网络和教师网络进行对抗知识蒸馏训练,生成粗压缩学生网络;
通过随机样本生成器生成样本,将所述生成样本输入所述教师网络,由所述教师网络生成合成样本集;
通过所述合成样本集对所述粗压缩学生网络进行有监督学习训练,得到所述学生网络的压缩结果;
所述基于至少一组超参组合,将所述训练样本分别输入学生网络和教师网络,对所述学生网络和教师网络进行对抗知识蒸馏训练包括:
基于超参组合,将随机数生成器生成的第一随机数输入样本生成器,由所述样本生成器生成第一训练样本;
将所述第一训练样本分别输入教师网络和学生网络,分别由所述教师网络和学生网络输出第一预测结果y和y_hat,根据所述第一预测结果y和y_hat计算出第一损失值;
根据所述第一损失值对所述学生网络进行反向传播,对所述学生网络进行参数更新;
将所述随机数生成器生成的第二随机数输入样本生成器,由所述样本生成器生成第二训练样本;
将所述第二训练样本分别输入至所述参数更新后的学生网络和教师网络,分别由所述教师网络和学生网络输出第二预测结果y和y_hat,根据所述第二预测结果y和y_hat计算出第二损失值;
根据所述第二损失值对所述样本生成器进行反向传播,对所述样本生成器进行参数更新;
迭代执行所述学生网络和样本生成器的更新,直到满足迭代结束条件时结束迭代,得到基于超参组合对所述学生网络进行的对抗知识蒸馏的训练结果,并保存所述学生网络的参数;
所述由所述教师网络生成合成样本集包括:
通过Label生成器生成期望标签;
通过随机样本生成器生成样本;
将所述生成样本输入教师网络,由所述教师网络输出第一预测标签;
基于所述期望标签和第一预测标签计算出第三损失函数;
基于所述第三损失函数对所述生成样本进行梯度更新;
根据预设迭代次数迭代执行所述生成样本的梯度更新,直到所述生成样本满足预设条件。
2.根据权利要求1所述的模型压缩方法,其特征在于,所述根据所述第一损失值对所述学生网络进行反向传播,对所述学生网络进行参数更新还包括:
根据预设迭代次数迭代执行所述第一损失值的计算以及所述学生网络的反向传播,对所述学生网络的参数进行预设迭代次数的更新。
3.根据权利要求2所述的模型压缩方法,其特征在于,所述根据所述第二损失值对所述样本生成器进行反向传播,对所述样本生成器的参数进行更新还包括:
根据预设迭代次数迭代执行所述第二损失值的计算以及所述样本生成器的反向传播,对所述样本生成器的参数进行预设迭代次数的更新。
4.根据权利要求1所述的模型压缩方法,其特征在于,所述通过所述合成样本集对所述粗压缩学生网络进行有监督学习训练后包括:
将所述合成样本集输入粗压缩学生网络,由所述粗压缩学生网络输出第二预测标签;
基于所述第二预测标签计算出所述学生网络最终的损失函数。
5.根据权利要求1至4任一项所述的模型压缩方法,其特征在于,所述通过所述合成样本集对所述粗压缩学生网络进行有监督学习训练后还包括:
根据模型评价指标对所述学生网络压缩结果进行评价;所述模型评价指标包括准确率。
6.一种模型压缩系统,其用于实现如权利要求1至5任一项所述的模型压缩方法,其特征在于,包括:
蒸馏压缩模块:用于基于至少一组超参组合,将样本生成器生成的训练样本分别输入学生网络和教师网络,对所述学生网络和教师网络进行对抗知识蒸馏训练,生成粗压缩学生网络;
样本生成模块:用于通过随机样本生成器生成样本,将所述生成样本输入所述教师网络,由所述教师网络生成合成样本集;
监督学习模块:用于通过所述合成样本集对所述粗压缩学生网络进行有监督学习训练,得到所述学生网络的压缩结果。
7.一种终端,其特征在于,所述终端包括处理器、与所述处理器耦接的存储器,其中,
所述存储器存储有用于实现权利要求1~5任一项所述的模型压缩方法的程序指令;
所述处理器用于执行所述存储器存储的所述程序指令以执行所述模型压缩方法。
8.一种存储介质,其特征在于,存储有处理器可运行的程序指令,所述程序指令用于执行权利要求1~5任一项所述的模型压缩方法。
CN202011269682.8A 2020-11-13 2020-11-13 一种模型压缩方法、系统、终端及存储介质 Active CN112381209B (zh)

Priority Applications (2)

Application Number Priority Date Filing Date Title
CN202011269682.8A CN112381209B (zh) 2020-11-13 2020-11-13 一种模型压缩方法、系统、终端及存储介质
PCT/CN2021/083230 WO2021197223A1 (zh) 2020-11-13 2021-03-26 一种模型压缩方法、系统、终端及存储介质

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202011269682.8A CN112381209B (zh) 2020-11-13 2020-11-13 一种模型压缩方法、系统、终端及存储介质

Publications (2)

Publication Number Publication Date
CN112381209A CN112381209A (zh) 2021-02-19
CN112381209B true CN112381209B (zh) 2023-12-22

Family

ID=74583913

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202011269682.8A Active CN112381209B (zh) 2020-11-13 2020-11-13 一种模型压缩方法、系统、终端及存储介质

Country Status (2)

Country Link
CN (1) CN112381209B (zh)
WO (1) WO2021197223A1 (zh)

Families Citing this family (6)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN112381209B (zh) * 2020-11-13 2023-12-22 平安科技(深圳)有限公司 一种模型压缩方法、系统、终端及存储介质
CN113255763B (zh) * 2021-05-21 2023-06-09 平安科技(深圳)有限公司 基于知识蒸馏的模型训练方法、装置、终端及存储介质
US11599794B1 (en) * 2021-10-20 2023-03-07 Moffett International Co., Limited System and method for training sample generator with few-shot learning
CN114240892B (zh) * 2021-12-17 2024-07-02 华中科技大学 一种基于知识蒸馏的无监督工业图像异常检测方法及系统
CN114495245B (zh) * 2022-04-08 2022-07-29 北京中科闻歌科技股份有限公司 人脸伪造图像鉴别方法、装置、设备以及介质
CN115908955B (zh) * 2023-03-06 2023-06-20 之江实验室 基于梯度蒸馏的少样本学习的鸟类分类系统、方法与装置

Citations (7)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN109710544A (zh) * 2017-10-26 2019-05-03 杭州华为数字技术有限公司 内存访问方法、计算机系统以及处理装置
CN110084281A (zh) * 2019-03-31 2019-08-02 华为技术有限公司 图像生成方法、神经网络的压缩方法及相关装置、设备
CN110674880A (zh) * 2019-09-27 2020-01-10 北京迈格威科技有限公司 用于知识蒸馏的网络训练方法、装置、介质与电子设备
CN111027060A (zh) * 2019-12-17 2020-04-17 电子科技大学 基于知识蒸馏的神经网络黑盒攻击型防御方法
CN111126573A (zh) * 2019-12-27 2020-05-08 深圳力维智联技术有限公司 基于个体学习的模型蒸馏改进方法、设备及存储介质
CN111160474A (zh) * 2019-12-30 2020-05-15 合肥工业大学 一种基于深度课程学习的图像识别方法
CN111461226A (zh) * 2020-04-01 2020-07-28 深圳前海微众银行股份有限公司 对抗样本生成方法、装置、终端及可读存储介质

Family Cites Families (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US11410029B2 (en) * 2018-01-02 2022-08-09 International Business Machines Corporation Soft label generation for knowledge distillation
CN109711544A (zh) * 2018-12-04 2019-05-03 北京市商汤科技开发有限公司 模型压缩的方法、装置、电子设备及计算机存储介质
US11604984B2 (en) * 2019-11-18 2023-03-14 Shanghai United Imaging Intelligence Co., Ltd. Systems and methods for machine learning based modeling
CN111598216B (zh) * 2020-04-16 2021-07-06 北京百度网讯科技有限公司 学生网络模型的生成方法、装置、设备及存储介质
CN112381209B (zh) * 2020-11-13 2023-12-22 平安科技(深圳)有限公司 一种模型压缩方法、系统、终端及存储介质

Patent Citations (7)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN109710544A (zh) * 2017-10-26 2019-05-03 杭州华为数字技术有限公司 内存访问方法、计算机系统以及处理装置
CN110084281A (zh) * 2019-03-31 2019-08-02 华为技术有限公司 图像生成方法、神经网络的压缩方法及相关装置、设备
CN110674880A (zh) * 2019-09-27 2020-01-10 北京迈格威科技有限公司 用于知识蒸馏的网络训练方法、装置、介质与电子设备
CN111027060A (zh) * 2019-12-17 2020-04-17 电子科技大学 基于知识蒸馏的神经网络黑盒攻击型防御方法
CN111126573A (zh) * 2019-12-27 2020-05-08 深圳力维智联技术有限公司 基于个体学习的模型蒸馏改进方法、设备及存储介质
CN111160474A (zh) * 2019-12-30 2020-05-15 合肥工业大学 一种基于深度课程学习的图像识别方法
CN111461226A (zh) * 2020-04-01 2020-07-28 深圳前海微众银行股份有限公司 对抗样本生成方法、装置、终端及可读存储介质

Non-Patent Citations (2)

* Cited by examiner, † Cited by third party
Title
Data-Free Learning of Student Networks;Hanting Chen 等;ICCV 2019;全文 *
Knowledge Distillation with Adversarial Samples Supporting Decision Boundary;Byeongho Heo 等;AAAI 2019;全文 *

Also Published As

Publication number Publication date
CN112381209A (zh) 2021-02-19
WO2021197223A1 (zh) 2021-10-07

Similar Documents

Publication Publication Date Title
CN112381209B (zh) 一种模型压缩方法、系统、终端及存储介质
CN109902706B (zh) 推荐方法及装置
US20220222531A1 (en) Asynchronous neural network training
CN111461168B (zh) 训练样本扩充方法、装置、电子设备及存储介质
CN111062489A (zh) 一种基于知识蒸馏的多语言模型压缩方法、装置
US11704570B2 (en) Learning device, learning system, and learning method
US20230196202A1 (en) System and method for automatic building of learning machines using learning machines
CN110188910A (zh) 利用机器学习模型提供在线预测服务的方法及系统
US20190102658A1 (en) Hierarchical image classification method and system
JP2022529178A (ja) 人工知能推奨モデルの特徴処理方法、装置、電子機器、及びコンピュータプログラム
WO2022127474A1 (en) Providing explainable machine learning model results using distributed ledgers
CN111401940A (zh) 特征预测方法、装置、电子设备及存储介质
JPWO2018062265A1 (ja) 音響モデル学習装置、その方法、及びプログラム
CN111783873A (zh) 基于增量朴素贝叶斯模型的用户画像方法及装置
CN111967941B (zh) 一种构建序列推荐模型的方法和序列推荐方法
CN110598869A (zh) 基于序列模型的分类方法、装置、电子设备
CN103782290A (zh) 建议值的生成
Huai et al. Latency-constrained DNN architecture learning for edge systems using zerorized batch normalization
Zhou et al. LightAdam: Towards a fast and accurate adaptive momentum online algorithm
JP7215966B2 (ja) ハイパーパラメータ管理装置、ハイパーパラメータ管理方法及びハイパーパラメータ管理プログラム製品
KR20220040295A (ko) 메트릭 학습을 위한 가상의 학습 데이터 생성 방법 및 시스템
CN110990256A (zh) 开源代码检测方法、装置及计算机可读存储介质
CN113010687A (zh) 一种习题标签预测方法、装置、存储介质以及计算机设备
JP2021135683A (ja) 学習装置、推論装置、学習方法及び推論方法
CN111768220A (zh) 生成车辆定价模型的方法和装置

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant