WO2021197223A1 - Model compression method, system, terminal and storage medium


Info

Publication number
WO2021197223A1
Authority
WO
WIPO (PCT)
Application number
PCT/CN2021/083230
Other languages
English (en)
French (fr)
Inventor
郑强
王晓锐
高鹏
谢国彤
Original Assignee
平安科技(深圳)有限公司
Priority date
Filing date
Publication date
Application filed by 平安科技(深圳)有限公司
Publication of WO2021197223A1


Classifications

    • G PHYSICS
    • G06 COMPUTING; CALCULATING OR COUNTING
    • G06N COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00 Computing arrangements based on biological models
    • G06N3/02 Neural networks
    • G06N3/04 Architecture, e.g. interconnection topology
    • G06N3/045 Combinations of networks
    • G06N3/08 Learning methods
    • G06N3/082 Learning methods modifying the architecture, e.g. adding, deleting or silencing nodes or connections
    • G06N3/084 Backpropagation, e.g. using gradient descent

Definitions

  • This application relates to the field of artificial intelligence technology, in particular to a model compression method, system, terminal and storage medium.
  • In the field of artificial intelligence, a model's life cycle can usually be divided into two stages: model training and model inference.
  • During model training, in the pursuit of higher prediction accuracy, the resulting model is often inevitably redundant.
  • In the model inference stage, the inventor realized that, owing to the demanding requirements of the inference environment, beyond prediction accuracy the model is also expected to exhibit high-performance characteristics such as fast inference, low resource consumption, and small file size.
  • Model compression is precisely the common optimization technique for carrying a model from the training stage to the inference stage.
  • At present, the mainstream model compression technologies in the industry include pruning, quantization, and knowledge distillation, all of which require the participation of the original training data set to complete the optimization of the model. Pruning requires the original training data set to complete pruning decisions and post-pruning reconstruction (fine-tuning); model quantization requires the original training data set to complete the quantization-aware training process, or the calibration step of post-training quantization; knowledge distillation requires feeding the original training data set into the teacher network and the student network respectively to complete the knowledge-transfer process.
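As a concrete illustration of the knowledge-transfer step mentioned above, the following sketch computes the classic temperature-softened distillation loss between teacher and student logits. This is a minimal stdlib-only illustration, not the patent's implementation; the temperature value and logits are assumptions.

```python
import math

def softmax(logits, T=1.0):
    # Temperature-scaled, numerically stable softmax.
    scaled = [z / T for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def distillation_loss(student_logits, teacher_logits, T=4.0):
    # KL(teacher || student) on temperature-softened distributions,
    # scaled by T^2 as is conventional in knowledge distillation.
    p = softmax(teacher_logits, T)
    q = softmax(student_logits, T)
    return T * T * sum(pi * (math.log(pi) - math.log(qi)) for pi, qi in zip(p, q))
```

The loss is zero when the student reproduces the teacher's logits and grows as the two distributions diverge, which is what drives the knowledge transfer.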
  • From the state of the industry, model training and model compression are often undertaken by different functional teams, with a relatively clear division of labor.
  • Because the training data may involve privacy concerns or be massive in volume (difficult to transmit and store), obtaining the original training data set is difficult, which hinders the progress of model compression work.
  • This application provides a model compression method, system, terminal, and storage medium, which can solve the deficiencies in the prior art to a certain extent.
  • A model compression method, including: generating training samples through a sample generator; based on at least one set of hyperparameter combinations, inputting the training samples into a student network and a teacher network respectively, and performing adversarial knowledge distillation training on the student network and the teacher network to generate a coarsely compressed student network; generating samples through a random sample generator, inputting the generated samples into the teacher network, and generating a synthetic sample set by the teacher network; and performing supervised learning training on the coarsely compressed student network through the synthetic sample set to obtain the compression result of the student network.
  • A model compression system, including:
  • Distillation compression module: used to input, based on at least one set of hyperparameter combinations, the training samples generated by the sample generator into the student network and the teacher network respectively, and to perform adversarial knowledge distillation training on the student network and the teacher network, generating a coarsely compressed student network;
  • Sample generation module: used to generate samples through a random sample generator, input the generated samples into the teacher network, and generate a synthetic sample set by the teacher network;
  • Supervised learning module: used to perform supervised learning training on the coarsely compressed student network through the synthetic sample set to obtain the compression result of the student network.
  • A terminal, including a processor and a memory coupled to the processor, wherein:
  • the memory stores program instructions for implementing the steps of the model compression method described above;
  • the processor is configured to execute the program instructions stored in the memory to perform the model compression operation.
  • A storage medium storing program instructions executable by a processor, the program instructions being used to execute the steps of the model compression method described above.
  • The model compression method, system, terminal, and storage medium of the embodiments of this application perform model compression in two stages: coarse compression and fine compression.
  • In the coarse compression stage, the adversarial knowledge distillation method is used to distill and compress the student network, generating a coarsely compressed student network;
  • In the fine compression stage, a high-quality synthetic sample set is generated through the teacher network, and supervised learning is performed on the coarsely compressed student network using the synthetic sample set, achieving model compression without the original training data set.
  • the embodiments of the present application have at least the following advantages:
  • The compression of the model can be achieved entirely without relying on the original training data set, which solves the problem that model compression cannot proceed when the original training data set is sensitive or massive;
  • FIG. 1 is a schematic flowchart of a model compression method according to a first embodiment of the present application
  • FIG. 2 is a schematic flowchart of a model compression method according to a second embodiment of the present application
  • Figure 3 is a schematic diagram of the implementation process of distilling and compressing the student model in an embodiment of the application
  • FIG. 4 is a schematic diagram of a synthetic sample set generation process according to an embodiment of the application.
  • FIG. 5 is a schematic flowchart of a model compression method according to a third embodiment of the present application.
  • Fig. 6 is a schematic diagram of a student network training process in an embodiment of the present application.
  • FIG. 7 is a schematic diagram of the performance of the student model in an embodiment of the present application.
  • Fig. 8 is a schematic diagram of a synthesized sample in an embodiment of the present application.
  • FIG. 9 is a schematic diagram of model compression results in an embodiment of the present application.
  • FIG. 10 is a schematic structural diagram of a model compression system according to an embodiment of the present application.
  • FIG. 11 is a schematic structural diagram of a terminal according to an embodiment of the present application.
  • FIG. 12 is a schematic diagram of the structure of a storage medium according to an embodiment of the present application.
  • The terms "first", "second", and "third" in this application are used for descriptive purposes only and are not to be understood as indicating or implying relative importance or implicitly indicating the number of technical features referred to; thus, a feature qualified by "first", "second", or "third" may explicitly or implicitly include at least one such feature.
  • In this application, "a plurality of" means at least two, such as two or three, unless otherwise specifically defined. All directional indications (such as up, down, left, right, front, back...) in the embodiments of this application are only used to explain the relative positional relationship and movement state of the components in a specific posture (as shown in the figures); if that specific posture changes, the directional indication changes accordingly.
  • The model compression method in the embodiments of the present application divides the compression process into two stages: coarse compression and fine compression.
  • In the coarse compression stage, an adversarial knowledge distillation method is used to obtain a coarse compression result; in the fine compression stage, supervised learning is used to fine-tune the coarse compression result into a higher-precision compression result, thereby completing high-precision compression of the model without relying on the original training data set.
  • This application can also be applied to smart contract scenarios to advance the construction of blockchains.
  • FIG. 1 is a schematic flowchart of the model compression method according to the first embodiment of the present application.
  • the model compression method of the first embodiment of the present application includes the following steps:
  • S11: Based on at least one set of hyperparameter combinations, the training samples are input into the student network and the teacher network respectively, and adversarial knowledge distillation training is performed on the student network and the teacher network to generate a coarsely compressed student network;
  • The adversarial knowledge distillation training of the student network and the teacher network specifically includes the following steps:
  • S11c: Backpropagate through the S network according to the first loss value loss_s and update the parameters of the S network; the update goal of the S network is to make loss_s smaller and smaller, that is, to bring the S network ever closer to the T network;
  • S11d: Iteratively execute S11a to S11c K times, updating the parameters of the S network K times through backpropagation; during these updates the G network coefficients are not updated;
  • S11g: Backpropagate through the G network according to the second loss value loss_g1 and update the parameters of the G network; the update goal of the G network is to make the first loss value loss_s larger and larger;
  • S11h: Iteratively execute S11e to S11g M times, updating the parameters of the G network M times through backpropagation; during these updates the S network coefficients are not updated;
  • S11i: Iteratively execute S11a to S11h until the evaluation index ACC1 (e.g., accuracy) of the S network no longer increases significantly; the iteration then ends, yielding the adversarial knowledge distillation training result {H1, S1, ACC1} for hyperparameter combination H1, and the S network parameters are saved.
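The alternating S/G loop of steps S11a–S11i can be sketched as follows. This is a minimal NumPy illustration under strong simplifying assumptions (a linear teacher and student, a tanh generator, MSE as loss_s, and a random-search ascent for the generator update instead of backpropagating loss_g1 through G); it is not the patent's implementation.

```python
import numpy as np

rng = np.random.default_rng(0)
D = 8                                     # sample dimensionality (illustrative)

W_t = rng.normal(size=(D, D))             # frozen teacher network T (a linear map here)
W_s = np.zeros((D, D))                    # student network S, to be distilled from T
W_g = rng.normal(size=(D, D))             # sample generator network G

def g(r):                                 # G: random number r -> training sample x
    return np.tanh(r @ W_g)

def loss_s(x):                            # first loss value: T/S output discrepancy
    return float(np.mean((x @ W_t - x @ W_s) ** 2))

lr, K, M = 0.05, 10, 5
for _ in range(30):                       # outer adversarial iterations (S11i)
    for _ in range(K):                    # S11a-S11d: K student updates; G stays frozen
        x = g(rng.normal(size=(16, D)))
        grad = -2.0 / x.size * x.T @ (x @ W_t - x @ W_s)   # d loss_s / d W_s
        W_s -= lr * grad                  # minimise loss_s: S gets closer to T
    for _ in range(M):                    # S11e-S11h: M generator updates; S stays frozen
        r = rng.normal(size=(16, D))
        trial = W_g + 0.1 * rng.normal(size=W_g.shape)
        # keep the perturbation only if it enlarges loss_s (the adversarial goal);
        # a real system would backpropagate loss_g1 = -loss_s through G instead
        if loss_s(np.tanh(r @ trial)) > loss_s(g(r)):
            W_g = trial
```

After the loop the student weights have moved toward the teacher's on the generator-induced data distribution, while the generator keeps searching for samples on which the two networks still disagree.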
  • S12: Generate samples through a random sample generator, input the generated samples into the teacher network, and have the teacher network generate a synthetic sample set;
  • The method for generating the synthetic sample set specifically includes:
  • S12a: Generate the desired label through the label generator;
  • The loss function calculator L calculates the third loss function loss_g2 based on the label generated by the label generator and the label_hat1 output by the T network;
  • S12f: Iteratively execute S12c to S12e M times, performing M gradient updates on the generated sample until it meets the preset requirement.
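Steps S12a–S12f amount to optimising the sample itself by gradient descent so that the frozen teacher predicts the desired label. A minimal NumPy sketch, assuming (purely for illustration) a linear-softmax teacher and cross-entropy as loss_g2:

```python
import numpy as np

rng = np.random.default_rng(1)
D, C = 16, 4                              # sample dimension and class count (illustrative)
W_t = rng.normal(size=(D, C))             # frozen teacher T: linear-softmax classifier

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def synthesize(label, steps=500, lr=0.05):
    """S12a-S12f: gradient-update a random sample until the teacher's
    prediction label_hat1 matches the desired label."""
    y = np.eye(C)[label]                  # S12a: desired label from the label generator
    x = rng.normal(size=D)                # random initial sample
    for _ in range(steps):                # S12c-S12e, repeated M times
        p = softmax(x @ W_t)              # teacher forward pass -> label_hat1
        # loss_g2 = cross-entropy(label, label_hat1); its gradient w.r.t. x is W_t @ (p - y)
        x -= lr * (W_t @ (p - y))
    return x

sample = synthesize(label=2)              # one entry of the synthetic set {Sample, label}
```

Repeating this for many (label, random seed) pairs yields the synthetic sample set used later for supervised fine-tuning.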
  • FIG. 2 is a schematic flowchart of a model compression method according to a second embodiment of the present application.
  • the model compression method of the second embodiment of the present application includes the following steps:
  • the embodiment of the application constructs a more lightweight student network model structure based on the pre-training model structure and the model compression target.
  • Figure 3 shows the implementation process of distilling and compressing the student model in this embodiment of the application, which specifically includes:
  • S21a: Take a hyperparameter combination H1 from the hyperparameter combination cluster (H1, H2, H3...HN) for adversarial knowledge distillation training;
  • S21b: Input the first random number r1 generated by the random number generator into the G network (Generator, sample generator); the G network generates the first training sample x1;
  • S21c: Input the first training sample x1 into the T network (teacher network) and the S network (student network) respectively; the T network and S network output the first prediction results y and y_hat, and the loss function calculator L calculates the first loss value loss_s from the first prediction results y and y_hat;
  • S21d: Backpropagate through the S network according to the first loss value loss_s and update the parameters of the S network; the update goal of the S network is to make loss_s smaller and smaller, that is, to bring the S network ever closer to the T network;
  • S21e: Iteratively execute steps S21b to S21d K times, updating the parameters of the S network K times through backpropagation; during these updates the G network coefficients are not updated;
  • S21h: Backpropagate through the G network according to the second loss value loss_g1 and update the parameters of the G network; the update goal of the G network is to make the first loss value loss_s larger and larger;
  • S21i: Iteratively execute steps S21f to S21h M times, updating the parameters of the G network M times through backpropagation; during these updates the S network coefficients are not updated;
  • S21j: Iteratively execute steps S21b to S21i until the evaluation index ACC1 (e.g., accuracy) of the S network no longer increases significantly; the iteration then ends, yielding the adversarial knowledge distillation training result {H1, S1, ACC1} for hyperparameter combination H1, and the S network parameters are saved;
  • S21k: For each hyperparameter combination in the hyperparameter combination cluster (H1, H2, H3...HN), iteratively execute steps S21b to S21i n (n ≤ N) times to obtain the adversarial knowledge distillation training results {Hn, Sn, ACCn} for the hyperparameter combination cluster (H1, H2, H3...HN).
  • S22: Generate samples through a random sample generator, input the generated samples into the teacher network, and have the teacher network generate a synthetic sample set;
  • FIG. 4 is a schematic diagram of the synthetic sample set generation process according to this embodiment of the application.
  • The synthetic sample set generation method specifically includes:
  • The loss function calculator L calculates the third loss function loss_g2 based on the label generated by the label generator and the label_hat1 output by the T network;
  • S22f: Iteratively execute S22c to S22e M times, performing M gradient updates on the generated sample Sample until it meets the preset requirements;
  • S22g: Iteratively execute S22a to S22f to generate the synthetic sample set {Sample(B,H,W,C), label}.
  • FIG. 5 is a schematic flowchart of a model compression method according to a third embodiment of the present application.
  • the model compression method of the third embodiment of the present application includes the following steps:
  • the embodiment of the application constructs a more lightweight student network model structure based on the pre-training model structure and the model compression target.
  • The implementation process of distilling and compressing the student model includes:
  • S31a: Take a hyperparameter combination H1 from the hyperparameter combination cluster (H1, H2, H3...HN) for adversarial knowledge distillation training;
  • S31b: Input the first random number r1 generated by the random number generator into the G network (Generator, sample generator); the G network generates the first training sample x1;
  • S31d: Backpropagate through the S network according to the first loss value loss_s and update the parameters of the S network; the update goal of the S network is to make loss_s smaller and smaller, that is, to bring the S network ever closer to the T network;
  • S31e: Iteratively execute steps S31b to S31d K times, updating the parameters of the S network K times through backpropagation; during these updates the G network coefficients are not updated;
  • S31h: Backpropagate through the G network according to the second loss value loss_g1 and update the parameters of the G network; the update goal of the G network is to make the first loss value loss_s larger and larger;
  • S31i: Iteratively execute steps S31f to S31h M times, updating the parameters of the G network M times through backpropagation; during these updates the S network coefficients are not updated;
  • S31j: Iteratively execute steps S31b to S31i until the evaluation index ACC1 (e.g., accuracy) of the S network no longer increases significantly; the iteration then ends, yielding the adversarial knowledge distillation training result {H1, S1, ACC1} for hyperparameter combination H1, and the S network parameters are saved;
  • S31k: For each hyperparameter combination in the hyperparameter combination cluster (H1, H2, H3...HN), iteratively execute steps S31b to S31i n (n ≤ N) times to obtain the adversarial knowledge distillation training results {Hn, Sn, ACCn} for the hyperparameter combination cluster (H1, H2, H3...HN).
  • S32: Generate samples through a random sample generator, input the generated samples into the teacher network, and have the teacher network generate a synthetic sample set;
  • The synthetic sample set generation process specifically includes:
  • The loss function calculator L calculates the third loss function loss_g2 based on the label generated by the label generator and the label_hat1 output by the T network;
  • S32f: Iteratively execute S32c to S32e M times, performing M gradient updates on the generated sample until it meets the preset requirements;
  • S32g: Iteratively execute S32a to S32f to generate the synthetic sample set {Sample(B,H,W,C), label}.
  • FIG. 6 is a schematic diagram of the supervised learning and training process of the student network in an embodiment of the present application.
  • In the supervised training process, the synthetic sample Sample(B, H, W, C) is input into the coarsely compressed S network, and the S network outputs the second predicted label label_hat2; the loss function calculator L then calculates the S network's loss function loss_s from the second predicted label label_hat2 and the sample's label.
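The fine-compression step described above is ordinary supervised training on the synthetic set. A minimal NumPy sketch, assuming (purely for illustration) a linear-softmax student head, teacher-produced labels, and cross-entropy as loss_s:

```python
import numpy as np

rng = np.random.default_rng(2)
D, C, B = 8, 3, 64                        # feature dim, classes, synthetic-set size (illustrative)

X = rng.normal(size=(B, D))               # stands in for the synthetic samples Sample(B, H, W, C)
W_teach = rng.normal(size=(D, C))         # frozen teacher, used here only to label the samples
labels = np.argmax(X @ W_teach, axis=1)   # labels attached to the synthetic set
Y = np.eye(C)[labels]

W_s = np.zeros((D, C))                    # coarsely compressed student (its linear head, here)

def predict(X):                           # student forward pass -> label_hat2 probabilities
    Z = X @ W_s
    E = np.exp(Z - Z.max(axis=1, keepdims=True))
    return E / E.sum(axis=1, keepdims=True)

for _ in range(300):                      # supervised learning on {Sample, label}
    P = predict(X)
    # loss_s = cross-entropy(label, label_hat2); its gradient w.r.t. W_s is X.T @ (P - Y) / B
    W_s -= 0.5 * X.T @ (P - Y) / B

accuracy = float((predict(X).argmax(axis=1) == labels).mean())
```

Because the labels were themselves produced by the teacher, this stage pulls the coarsely compressed student further toward the teacher's behaviour without ever touching the original training data set.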
  • Model evaluation indicators include, but are not limited to, accuracy.
  • Experiments were carried out taking as an example the application of the model compression method of the embodiments of the present application to compressing a Transformer-based OCR task model.
  • At node A in the system block diagram, a student model with the performance shown in Figure 7 can be obtained.
  • At node B in the system block diagram, high-quality synthesized samples as shown in Figure 8 can be obtained.
  • At node C in the system block diagram, the model compression result shown in Figure 9 can be obtained.
  • The model compression method of the embodiments of the present application performs model compression in two stages: coarse compression and fine compression.
  • In the coarse compression stage, the adversarial knowledge distillation method is used to distill and compress the student network, generating a coarsely compressed student network;
  • In the fine compression stage, a high-quality synthetic sample set is generated through the teacher network, and supervised learning is performed on the coarsely compressed student network using the synthetic sample set, achieving model compression without the original training data set.
  • the embodiments of the present application have at least the following advantages:
  • The compression of the model can be achieved entirely without relying on the original training data set, which solves the problem that model compression cannot proceed when the original training data set is sensitive or massive;
  • The corresponding digest information is obtained based on the result of the model compression method.
  • The digest information is obtained by hashing the result of the model compression method, for example with the SHA-256 algorithm.
  • Uploading the digest information to the blockchain ensures its security, fairness, and transparency to users.
  • The user can download the digest information from the blockchain to verify whether the result of the model compression method has been tampered with.
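The digest-and-verify flow described above can be sketched with Python's standard hashlib; the result dictionary and its field names below are hypothetical stand-ins for the compression result.

```python
import hashlib
import json

# Hypothetical compression result; in practice this could be the compressed
# model file together with its evaluation metrics.
result = {"model": "student_compressed", "accuracy": 0.95, "size_mb": 3.2}

# Serialise deterministically, then hash with SHA-256 to obtain the digest
# information that would be uploaded to the blockchain.
payload = json.dumps(result, sort_keys=True).encode("utf-8")
digest = hashlib.sha256(payload).hexdigest()

def verify(result_obj, on_chain_digest):
    # Anyone holding the on-chain digest re-hashes the downloaded result;
    # a mismatch means the result was tampered with.
    recomputed = hashlib.sha256(
        json.dumps(result_obj, sort_keys=True).encode("utf-8")
    ).hexdigest()
    return recomputed == on_chain_digest
```

Deterministic serialisation (`sort_keys=True`) matters here: two parties must produce byte-identical payloads for their digests to match.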
  • The blockchain referred to in this example is a new application mode of computer technologies such as distributed data storage, peer-to-peer transmission, consensus mechanisms, and encryption algorithms.
  • A blockchain is essentially a decentralized database: a chain of data blocks linked by cryptographic methods, in which each data block contains a batch of network transaction information used to verify the validity of the information (anti-tampering) and to generate the next block.
  • the blockchain can include the underlying platform of the blockchain, the platform product service layer, and the application service layer.
  • FIG. 10 is a schematic structural diagram of a model compression system according to an embodiment of the present application.
  • the model compression system 40 of the embodiment of the present application includes:
  • Distillation compression module 41: used to distill and compress the student network using the adversarial knowledge distillation method based on at least one set of hyperparameter combinations, generating a coarsely compressed student network. Specifically, the distillation compression module 41 performs distillation compression on the student model as follows:
  • Step 1: Take a hyperparameter combination H1 from the hyperparameter combination cluster (H1, H2, H3...HN) for adversarial knowledge distillation training;
  • Step 2: Input the first random number r1 generated by the random number generator into the G network (Generator, sample generator); the G network generates the first training sample x1;
  • Step 3: Input the first training sample x1 into the T network (teacher network) and the S network (student network) respectively; the T network and S network output the first prediction results y and y_hat, and the loss function calculator L calculates the first loss value loss_s from the first prediction results y and y_hat;
  • Step 4: Backpropagate through the S network according to the first loss value loss_s and update the parameters of the S network; the update goal of the S network is to make loss_s smaller and smaller, that is, to bring the S network ever closer to the T network;
  • Step 5: Iteratively execute steps 2 to 4 K times, updating the parameters of the S network K times through backpropagation; during these updates the G network coefficients are not updated;
  • Step 6: Input the second random number r2 generated by the random number generator into the G network; the G network generates the second training sample x2;
  • Step 7: Input the second training sample x2 into the T network and the updated S network respectively; the T network and S network output the second prediction results y and y_hat, and the loss function calculator L calculates the second loss value loss_g1 from the second prediction results y and y_hat;
  • Step 8: Backpropagate through the G network according to the second loss value loss_g1 and update the parameters of the G network; the update goal of the G network is to make the first loss value loss_s larger and larger;
  • Step 9: Iteratively execute steps 6 to 8 M times, updating the parameters of the G network M times through backpropagation; during these updates the S network coefficients are not updated;
  • Step 10: Iteratively execute steps 2 to 9 until the evaluation index ACC1 (e.g., accuracy) of the S network no longer increases significantly; the iteration then ends, yielding the adversarial knowledge distillation training result {H1, S1, ACC1} for hyperparameter combination H1, and the S network parameters are saved;
  • Step 11: For each hyperparameter combination in the hyperparameter combination cluster (H1, H2, H3...HN), iteratively execute steps 2 to 9 n (n ≤ N) times to obtain the adversarial knowledge distillation training results {Hn, Sn, ACCn} for the hyperparameter combination cluster (H1, H2, H3...HN).
  • Sample generation module 42: used to generate samples through a random sample generator, input the generated samples into the teacher network, and have the teacher network generate a synthetic sample set; the synthetic sample set generation process specifically includes:
  • Step 1: Generate the desired label through the label generator;
  • Step 3: Input the sample Sample into the T network; the T network outputs the first predicted label label_hat1;
  • Step 4: The loss function calculator L calculates the third loss function loss_g2 based on the label generated by the label generator and the label_hat1 output by the T network;
  • Step 5: Perform a gradient update on the generated sample Sample based on the third loss function loss_g2;
  • Step 6: Iteratively execute steps 3 to 5 M times, updating the gradient of the generated sample M times until it meets the preset requirements;
  • Step 7: Iteratively execute steps 1 to 6 to generate the synthetic sample set {Sample(B,H,W,C), label}.
  • Supervised learning module 43: used to input the synthetic sample set into the coarsely compressed student network and perform supervised learning training on it to obtain the compression result of the student network. In this supervised training process, the synthetic sample Sample(B, H, W, C) is input into the coarsely compressed S network, and the S network outputs the second predicted label label_hat2; the loss function calculator L then calculates the S network's loss function loss_s from the second predicted label label_hat2 and the sample's label.
  • FIG. 11 is a schematic diagram of a terminal structure according to an embodiment of the application.
  • the terminal 50 includes a processor 51 and a memory 52 coupled to the processor 51.
  • the memory 52 stores program instructions for realizing the above-mentioned model compression method.
  • the processor 51 is configured to execute program instructions stored in the memory 52 to perform a model compression operation.
  • the processor 51 may also be referred to as a CPU (Central Processing Unit, central processing unit).
  • the processor 51 may be an integrated circuit chip with signal processing capability.
  • The processor 51 may also be a general-purpose processor, a digital signal processor (DSP), an application-specific integrated circuit (ASIC), a field-programmable gate array (FPGA) or other programmable logic device, a discrete gate or transistor logic device, or a discrete hardware component.
  • the general-purpose processor may be a microprocessor or the processor may also be any conventional processor or the like.
  • FIG. 12 is a schematic structural diagram of a storage medium according to an embodiment of the application.
  • the storage medium of the embodiment of the present application stores a program file 61 that can implement all the above methods.
  • The program file 61 can be stored in the above storage medium in the form of a software product and includes a number of instructions to make a computer device (which may be a personal computer, a server, a network device, etc.) or a processor execute all or part of the steps of the methods in the various embodiments of the present application.
  • The aforementioned storage media include: USB flash drives, removable hard disks, read-only memory (ROM), random access memory (RAM), magnetic disks, optical disks, and other media that can store program code, as well as terminal devices such as computers, servers, mobile phones, and tablets.
  • the storage medium may be non-volatile or volatile.
  • the disclosed system, device, and method can be implemented in other ways.
  • the system embodiment described above is only illustrative.
  • The division of units is only a logical functional division, and there may be other divisions in actual implementation; for example, multiple units or components may be combined or integrated into another system, or some features may be ignored or not implemented.
  • the displayed or discussed mutual coupling or direct coupling or communication connection may be indirect coupling or communication connection through some interfaces, devices or units, and may be in electrical, mechanical or other forms.
  • the functional units in the various embodiments of the present application may be integrated into one processing unit, or each unit may exist alone physically, or two or more units may be integrated into one unit.
  • The above-mentioned integrated unit can be implemented in the form of hardware or of a software functional unit. The above are only implementations of this application and do not limit the scope of this application. Any equivalent structure or equivalent process transformation made using the contents of the description and drawings of this application, or any direct or indirect application in other related technical fields, is likewise included within the scope of patent protection of this application.


Abstract

This application discloses a model compression method, system, terminal, and storage medium. The method includes: generating training samples through a sample generator; based on at least one set of hyperparameter combinations, inputting the training samples into a student network and a teacher network respectively, and performing adversarial knowledge distillation training on the student network and the teacher network to generate a coarsely compressed student network; generating samples through a random sample generator, inputting the generated samples into the teacher network, and having the teacher network generate a synthetic sample set; and performing supervised learning training on the coarsely compressed student network through the synthetic sample set to obtain the compression result of the student network. The embodiments of this application can compress a model entirely without relying on the original training data set, solving the problem that model compression cannot proceed because the original training data set is sensitive or massive, and effectively reducing the accuracy loss of model compression. This application also relates to the field of blockchain technology.

Description

Model compression method, system, terminal and storage medium
This application claims priority to Chinese patent application No. 202011269682.8, filed with the China National Intellectual Property Administration on November 13, 2020 and entitled "Model compression method, system, terminal and storage medium", the entire contents of which are incorporated herein by reference.
Technical Field
This application relates to the field of artificial intelligence technology, and in particular to a model compression method, system, terminal, and storage medium.
Background
In the field of artificial intelligence, a model's life cycle can generally be divided into two stages: model training and model inference. During training, redundancy in the model is often unavoidable because the model is optimized for higher prediction accuracy. During inference, the inventors realized that, owing to the demanding requirements of deployment environments, a model is expected not only to predict accurately but also to offer fast inference, low resource usage, and a small file size. Model compression is precisely the common optimization means for moving a model from the training stage to the inference stage.
Currently, the mainstream model compression techniques in the industry include pruning, quantization, and knowledge distillation, all of which require the original training data set to complete the model optimization process. Pruning needs the original training data set for pruning decisions and post-pruning reconstruction (fine-tuning); quantization needs it either for the quantization-aware training process or for the calibration step of post-training quantization; and knowledge distillation needs it to be fed into the teacher network and the student network respectively to complete the knowledge-transfer process.
As the industry has developed, model training and model compression are often handled by different functional teams with a clear division of labor. Because the training data may be private or simply massive (hard to transfer and store), obtaining the original training data set is difficult, which hinders the progress of model compression work.
Recently, model compression techniques that do not depend on the original training data set have begun to appear, such as adversarial knowledge distillation; however, this technique is not yet mature and still has the following shortcomings:
1. The adversarial knowledge distillation process is highly volatile and random, and is difficult to reproduce stably;
2. The accuracy loss of adversarial knowledge distillation is large, making it hard to meet practical application requirements.
Technical Problem
This application provides a model compression method, system, terminal, and storage medium that can, to a certain extent, overcome the shortcomings of the prior art.
Technical Solution
To solve the above technical problem, this application adopts the following technical solution:
A model compression method, including:
generating training samples with a sample generator;
based on at least one hyperparameter combination, feeding the training samples into a student network and a teacher network respectively, and performing adversarial knowledge distillation training on the student and teacher networks to generate a coarsely compressed student network;
generating samples with a random sample generator, feeding the generated samples into the teacher network, and having the teacher network generate a synthetic sample set;
performing supervised learning training on the coarsely compressed student network with the synthetic sample set to obtain the compression result of the student network.
Another technical solution adopted by the embodiments of this application is a model compression system, including:
a distillation compression module, configured to, based on at least one hyperparameter combination, feed training samples generated by a sample generator into a student network and a teacher network respectively, and perform adversarial knowledge distillation training on the student and teacher networks to generate a coarsely compressed student network;
a sample generation module, configured to generate samples with a random sample generator, feed the generated samples into the teacher network, and have the teacher network generate a synthetic sample set;
a supervised learning module, configured to perform supervised learning training on the coarsely compressed student network with the synthetic sample set to obtain the compression result of the student network.
Yet another technical solution adopted by the embodiments of this application is a terminal, the terminal including a processor and a memory coupled to the processor, wherein
the memory stores program instructions for implementing the following steps:
generating training samples with a sample generator;
based on at least one hyperparameter combination, feeding the training samples into a student network and a teacher network respectively, and performing adversarial knowledge distillation training on the student and teacher networks to generate a coarsely compressed student network;
generating samples with a random sample generator, feeding the generated samples into the teacher network, and having the teacher network generate a synthetic sample set;
performing supervised learning training on the coarsely compressed student network with the synthetic sample set to obtain the compression result of the student network;
the processor is configured to execute the program instructions stored in the memory to perform the model compression operation.
Yet another technical solution adopted by the embodiments of this application is a storage medium storing processor-executable program instructions, the program instructions being used to perform the following steps:
generating training samples with a sample generator;
based on at least one hyperparameter combination, feeding the training samples into a student network and a teacher network respectively, and performing adversarial knowledge distillation training on the student and teacher networks to generate a coarsely compressed student network;
generating samples with a random sample generator, feeding the generated samples into the teacher network, and having the teacher network generate a synthetic sample set;
performing supervised learning training on the coarsely compressed student network with the synthetic sample set to obtain the compression result of the student network.
Beneficial Effects
The beneficial effects of this application are as follows: the model compression method, system, terminal, and storage medium of the embodiments of this application compress a model in two stages, coarse compression and fine compression. In the coarse compression stage, adversarial knowledge distillation is used to distill and compress the student network, producing a coarsely compressed student network; in the fine compression stage, the teacher network generates a high-quality synthetic sample set, which is then used for supervised learning on the coarsely compressed student network, achieving model compression without the original training data set. Compared with the prior art, the embodiments of this application have at least the following advantages:
1. The model can be compressed entirely without the original training data set, resolving the situation where compression work cannot proceed because the original training data is sensitive or too massive;
2. The large randomness and volatility of adversarial knowledge distillation for model compression, and the difficulty of controlling and debugging it, are remedied;
3. The accuracy loss of model compression is effectively reduced, to the point of being nearly lossless.
Brief Description of the Drawings
Fig. 1 is a schematic flowchart of the model compression method of the first embodiment of this application;
Fig. 2 is a schematic flowchart of the model compression method of the second embodiment of this application;
Fig. 3 is a schematic diagram of the implementation of distillation compression of the student model in an embodiment of this application;
Fig. 4 is a schematic diagram of the synthetic sample set generation process in an embodiment of this application;
Fig. 5 is a schematic flowchart of the model compression method of the third embodiment of this application;
Fig. 6 is a schematic diagram of the student network training process in an embodiment of this application;
Fig. 7 is a schematic diagram of student model performance in one embodiment of this application;
Fig. 8 is a schematic diagram of synthetic samples in one embodiment of this application;
Fig. 9 is a schematic diagram of a model compression result in one embodiment of this application;
Fig. 10 is a schematic structural diagram of the model compression system of an embodiment of this application;
Fig. 11 is a schematic structural diagram of the terminal of an embodiment of this application;
Fig. 12 is a schematic structural diagram of the storage medium of an embodiment of this application.
Embodiments of the Invention
The technical solutions in the embodiments of this application will be described clearly and completely below with reference to the accompanying drawings. Obviously, the described embodiments are only some, not all, of the embodiments of this application. All other embodiments obtained by a person of ordinary skill in the art based on the embodiments of this application without creative effort fall within the scope of protection of this application.
The terms "first", "second", and "third" in this application are used for descriptive purposes only and shall not be understood as indicating or implying relative importance or implicitly specifying the number of the technical features indicated. Thus, a feature qualified by "first", "second", or "third" may explicitly or implicitly include at least one such feature. In the description of this application, "multiple" means at least two, for example two or three, unless otherwise expressly and specifically limited. All directional indications (such as up, down, left, right, front, back, ...) in the embodiments of this application are only used to explain the relative positional relationships, movements, and the like between components in a particular posture (as shown in the drawings); if that posture changes, the directional indication changes accordingly. In addition, the terms "include" and "have" and any variations thereof are intended to cover non-exclusive inclusion. For example, a process, method, system, product, or device comprising a series of steps or units is not limited to the listed steps or units, but optionally also includes steps or units not listed, or optionally also includes other steps or units inherent to the process, method, product, or device.
Reference herein to an "embodiment" means that a particular feature, structure, or characteristic described in connection with the embodiment may be included in at least one embodiment of this application. The appearance of this phrase in various places in the specification does not necessarily refer to the same embodiment, nor to independent or alternative embodiments mutually exclusive of other embodiments. Those skilled in the art understand, explicitly and implicitly, that the embodiments described herein may be combined with other embodiments.
To address the shortcomings of the prior art, the model compression method of the embodiments of this application divides the whole compression into two stages, coarse compression and fine compression. First, in the coarse compression stage, adversarial knowledge distillation is used to roughly compress the model to be compressed, yielding a coarse compression result; then, in the fine compression stage, supervised learning is used to fine-tune the coarse compression result, yielding a higher-accuracy compression result, thereby completing high-accuracy compression of the model without relying on the original training data set. This application can also be applied in smart contract scenarios, thereby advancing the construction of blockchains.
Specifically, please refer to Fig. 1, a schematic flowchart of the model compression method of the first embodiment of this application. The model compression method of the first embodiment includes the following steps:
S10: generate training samples with a sample generator;
S11: based on at least one hyperparameter combination, feed the training samples into the student network and the teacher network respectively, and perform adversarial knowledge distillation training on the student and teacher networks to generate a coarsely compressed student network;
In this step, the adversarial knowledge distillation training of the student and teacher networks specifically includes the following steps:
S11a: based on hyperparameter combination H1, feed the first random number r1 produced by a random number generator into the G network (Generator, the sample generator), which generates the first training sample x1;
S11b: feed the first training sample x1 into the T network (teacher network) and the S network (student network) respectively; the T network and the S network output the first prediction results y and y_hat respectively, and the loss function calculator L computes the first loss value loss_s from y and y_hat;
S11c: backpropagate the first loss value loss_s through the S network and update the S network's parameters; the goal of the update is to make loss_s smaller and smaller, i.e., to bring the S network ever closer to the T network;
S11d: iterate S11a to S11c K times, updating the S network's parameters K times via backpropagation; the G network's coefficients are not updated during these iterations;
S11e: feed the second random number r2 produced by the random number generator into the G network, which generates the second training sample x2;
S11f: feed the second training sample x2 into the T network and the parameter-updated S network respectively; the T network and the S network output the second prediction results y and y_hat respectively, and the loss function calculator L computes the second loss value loss_g1 from them;
S11g: backpropagate the second loss value loss_g1 through the G network and update the G network's parameters; the goal of the update is to make the first loss value loss_s larger and larger;
S11h: iterate S11e to S11g M times, updating the G network's parameters M times via backpropagation; the S network's coefficients are not updated during these iterations;
S11i: iterate S11a to S11h until the S network's ACC1 (an evaluation metric such as accuracy) no longer increases noticeably, then end the iteration, obtain the training result {H1, S1, ACC1} of adversarial knowledge distillation under hyperparameter combination H1, and save the S network's parameters.
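The alternating update in S11a to S11i can be illustrated with a deliberately tiny numerical sketch. This is not the patented implementation: the "networks" here are one-parameter linear functions with hand-derived gradients, and the learning rates, K, M, and round counts are arbitrary illustrative choices.

```python
import random

def adversarial_distill(rounds=50, K=5, M=3, lr_s=0.1, lr_g=0.01):
    """Toy adversarial knowledge distillation.

    The teacher T(x) = 2x is fixed; the student S(x) = w*x should imitate it;
    the generator G(r) = g*r produces training samples from random numbers r.
    """
    random.seed(0)
    w, g = 0.0, 1.0  # student / generator parameters
    for _ in range(rounds):
        # S11a-S11d: update the student K times (generator frozen),
        # descending loss_s = (S(x) - T(x))^2.
        for _ in range(K):
            r = random.uniform(0.8, 1.2)
            x = g * r                            # training sample from G
            grad_w = 2 * x * (w * x - 2 * x)     # d loss_s / d w
            w -= lr_s * grad_w
        # S11e-S11h: update the generator M times (student frozen),
        # ascending the same loss so G produces harder samples.
        for _ in range(M):
            r = random.uniform(0.8, 1.2)
            grad_g = 2 * g * r * r * (w - 2) ** 2  # d loss_s / d g
            g += lr_g * grad_g
    return w, g

w, g = adversarial_distill()
print(round(w, 4))  # → 2.0: the student's weight converges to the teacher's
```

The adversarial pressure is visible in the sign of the two updates: the student descends loss_s while the generator ascends it, and each side is frozen while the other trains, exactly as steps S11d and S11h prescribe.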
S12: generate samples with a random sample generator, feed the generated samples into the teacher network, and have the teacher network generate a synthetic sample set;
In this step, the synthetic sample set is generated as follows:
S12a: generate the desired label label with a label generator;
S12b: generate a sample Sample(B, H, W, C) with the random sample generator, where B = batch (number of images), H = height, W = width, and C = number of channels;
S12c: feed Sample into the T network, which outputs the first predicted label label_hat1;
S12d: the loss function calculator L computes the third loss function loss_g2 from the label produced by the label generator and the label_hat1 output by the T network;
S12e: apply a gradient update to the generated sample Sample based on the third loss function loss_g2;
S12f: iterate S12c to S12e M times, applying M gradient updates to Sample, until Sample meets the preset requirements.
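Steps S12a to S12f amount to gradient descent on the input itself rather than on network weights: the sample is nudged until the frozen teacher predicts the desired label. A minimal one-dimensional sketch follows; the teacher here is a stand-in linear scorer, and the desired label, learning rate, and iteration count are illustrative assumptions.

```python
def synthesize_sample(desired_label, steps=30, lr=0.05):
    """Drive a starting sample toward one the teacher labels as desired.

    The 'teacher' is a fixed stand-in scorer T(x) = 3x; only the sample x
    is updated (as in S12e) while the teacher's weights stay frozen.
    """
    x = 0.0  # initial sample (fixed here for determinism)
    for _ in range(steps):
        label_hat = 3 * x                              # S12c: teacher's prediction
        grad_x = 2 * 3 * (label_hat - desired_label)   # d (label_hat - label)^2 / d x
        x -= lr * grad_x                               # S12e: gradient-update the sample
    return x

sample = synthesize_sample(desired_label=6.0)
print(round(sample, 3))  # → 2.0: the teacher now outputs 3 * 2.0 = 6.0, the desired label
```

Repeating this with different desired labels (step S12a of each outer round) yields the labeled synthetic set {Sample, label} used later for supervised fine-tuning.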
S13:通过合成样本集对粗压缩学生网络进行有监督学习训练,得到学生网络的压缩结果。
Please refer to Fig. 2, a schematic flowchart of the model compression method of the second embodiment of this application. The model compression method of the second embodiment includes the following steps:
S20: construct the student network structure;
In this step, a more lightweight student network model structure is constructed based on the pre-trained model structure and the model compression target.
S21: without using the original training data set, distill and compress the student network via adversarial knowledge distillation under multiple hyperparameter combinations to obtain the coarsely compressed S network;
In this step, please also refer to Fig. 3 for the implementation of the distillation compression of the student model in this embodiment, which specifically includes:
S21a: take one hyperparameter combination H1 from the hyperparameter combination cluster (H1, H2, H3, ..., HN) for adversarial knowledge distillation training;
S21b: feed the first random number r1 produced by the random number generator into the G network (Generator, the sample generator), which generates the first training sample x1;
S21c: feed the first training sample x1 into the T network (teacher network) and the S network (student network) respectively; the T network and the S network output the first prediction results y and y_hat respectively, and the loss function calculator L computes the first loss value loss_s from y and y_hat;
S21d: backpropagate the first loss value loss_s through the S network and update the S network's parameters; the goal of the update is to make loss_s smaller and smaller, i.e., to bring the S network ever closer to the T network;
S21e: iterate steps S21b to S21d K times, updating the S network's parameters K times via backpropagation; the G network's coefficients are not updated during these iterations;
S21f: feed the second random number r2 produced by the random number generator into the G network, which generates the second training sample x2;
S21g: feed the second training sample x2 into the T network and the parameter-updated S network respectively; the T network and the S network output the second prediction results y and y_hat respectively, and the loss function calculator L computes the second loss value loss_g1 from them;
S21h: backpropagate the second loss value loss_g1 through the G network and update the G network's parameters; the goal of the update is to make the first loss value loss_s larger and larger;
S21i: iterate steps S21f to S21h M times, updating the G network's parameters M times via backpropagation; the S network's coefficients are not updated during these iterations;
S21j: iterate steps S21b to S21i until the S network's ACC1 (an evaluation metric such as accuracy) no longer increases noticeably, then end the iteration, obtain the training result {H1, S1, ACC1} of adversarial knowledge distillation under hyperparameter combination H1, and save the S network's parameters;
S21k: based on the hyperparameter combinations in the cluster (H1, H2, H3, ..., HN), iterate steps S21b to S21i n (n ∈ N) times to obtain the training results {Hn, Sn, ACCn} of adversarial knowledge distillation over the cluster (H1, H2, H3, ..., HN).
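Step S21k's sweep over the hyperparameter cluster can be sketched as a plain loop that records one {Hn, Sn, ACCn} triple per combination and keeps the best-scoring student. The `train_and_evaluate` stub below stands in for a full adversarial-distillation run (S21b to S21j); its hyperparameters and accuracy numbers are fabricated purely for illustration.

```python
def train_and_evaluate(hparams):
    """Stub for one full adversarial-distillation run (S21b-S21j).

    Returns (student_params, accuracy); the accuracy here is a fabricated
    deterministic function of the hyperparameters, for illustration only.
    """
    lr, temperature = hparams
    accuracy = 1.0 - abs(lr - 0.01) * 10 - abs(temperature - 4.0) * 0.05
    student_params = {"lr": lr, "T": temperature}  # placeholder for saved S-network weights
    return student_params, accuracy

# Hyperparameter cluster (H1, H2, H3, ... HN): here (learning rate, distillation temperature).
cluster = [(0.1, 2.0), (0.01, 4.0), (0.001, 8.0)]

results = []  # the {Hn, Sn, ACCn} triples of S21k
for h in cluster:
    s, acc = train_and_evaluate(h)
    results.append((h, s, acc))

best_h, best_s, best_acc = max(results, key=lambda t: t[2])
print(best_h)  # → (0.01, 4.0): the combination whose student scored highest
```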
S22: generate samples with a random sample generator, feed the generated samples into the teacher network, and have the teacher network generate the synthetic sample set;
In this step, please also refer to Fig. 4, a schematic diagram of the synthetic sample set generation process of this embodiment, which specifically includes:
S22a: generate the desired label label with the label generator;
S22b: generate a sample Sample(B, H, W, C) with the random sample generator, where B = batch (number of images), H = height, W = width, and C = number of channels;
S22c: feed Sample into the T network, which outputs the first predicted label label_hat1;
S22d: the loss function calculator L computes the third loss function loss_g2 from the label produced by the label generator and the label_hat1 output by the T network;
S22e: apply a gradient update to the generated sample Sample based on the third loss function loss_g2;
S22f: iterate S22c to S22e M times, applying M gradient updates to Sample, until Sample meets the preset requirements;
S22g: iterate S22a to S22f to generate the synthetic sample set {Sample(B, H, W, C), label}.
S23: perform supervised learning training on the coarsely compressed S network with the synthetic sample set to obtain the compression result of the S network.
Please refer to Fig. 5, a schematic flowchart of the model compression method of the third embodiment of this application. The model compression method of the third embodiment includes the following steps:
S30: construct the student network structure;
In this step, a more lightweight student network model structure is constructed based on the pre-trained model structure and the model compression target.
S31: without using the original training data set, distill and compress the student network via adversarial knowledge distillation under multiple hyperparameter combinations to obtain the coarsely compressed S network;
In this step, the distillation compression of the student model is implemented as follows:
S31a: take one hyperparameter combination H1 from the hyperparameter combination cluster (H1, H2, H3, ..., HN) for adversarial knowledge distillation training;
S31b: feed the first random number r1 produced by the random number generator into the G network (Generator, the sample generator), which generates the first training sample x1;
S31c: feed the first training sample x1 into the T network (teacher network) and the S network (student network) respectively; the T network and the S network output the first prediction results y and y_hat respectively, and the loss function calculator L computes the first loss value loss_s from y and y_hat;
S31d: backpropagate the first loss value loss_s through the S network and update the S network's parameters; the goal of the update is to make loss_s smaller and smaller, i.e., to bring the S network ever closer to the T network;
S31e: iterate steps S31b to S31d K times, updating the S network's parameters K times via backpropagation; the G network's coefficients are not updated during these iterations;
S31f: feed the second random number r2 produced by the random number generator into the G network, which generates the second training sample x2;
S31g: feed the second training sample x2 into the T network and the parameter-updated S network respectively; the T network and the S network output the second prediction results y and y_hat respectively, and the loss function calculator L computes the second loss value loss_g1 from them;
S31h: backpropagate the second loss value loss_g1 through the G network and update the G network's parameters; the goal of the update is to make the first loss value loss_s larger and larger;
S31i: iterate steps S31f to S31h M times, updating the G network's parameters M times via backpropagation; the S network's coefficients are not updated during these iterations;
S31j: iterate steps S31b to S31i until the S network's ACC1 (an evaluation metric such as accuracy) no longer increases noticeably, then end the iteration, obtain the training result {H1, S1, ACC1} of adversarial knowledge distillation under hyperparameter combination H1, and save the S network's parameters;
S31k: based on the hyperparameter combinations in the cluster (H1, H2, H3, ..., HN), iterate steps S31b to S31i n (n ∈ N) times to obtain the training results {Hn, Sn, ACCn} of adversarial knowledge distillation over the cluster (H1, H2, H3, ..., HN).
S32: generate samples with the random sample generator, feed the generated samples into the teacher network, and have the teacher network generate the synthetic sample set;
In this step, the synthetic sample set generation process specifically includes:
S32a: generate the desired label label with the label generator;
S32b: generate a sample Sample(B, H, W, C) with the random sample generator, where B = batch (number of images), H = height, W = width, and C = number of channels;
S32c: feed Sample into the T network, which outputs the first predicted label label_hat1;
S32d: the loss function calculator L computes the third loss function loss_g2 from the label produced by the label generator and the label_hat1 output by the T network;
S32e: apply a gradient update to the generated sample Sample based on the third loss function loss_g2;
S32f: iterate S32c to S32e M times, applying M gradient updates to Sample, until Sample meets the preset requirements;
S32g: iterate S32a to S32f to generate the synthetic sample set {Sample(B, H, W, C), label}.
S33: perform supervised learning training on the coarsely compressed S network with the synthetic sample set to obtain the compression result of the S network;
In this step, please also refer to Fig. 6, a schematic diagram of the supervised learning training process of the student network in this embodiment. First, the synthetic sample Sample(B, H, W, C) is fed into the coarsely compressed S network, which outputs the second predicted label label_hat2; the loss function calculator L then computes the S network's loss function loss_s from the second predicted label label_hat2 and the synthetic sample's label.
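The fine-compression stage of S33 is ordinary supervised training, just with teacher-synthesised pairs in place of the original data set. A one-parameter sketch follows; the synthetic pairs, starting weight, learning rate, and epoch count are illustrative assumptions, while the real networks and loss are as described above.

```python
def fine_tune(w, synthetic_set, epochs=40, lr=0.1):
    """Supervised fine-tuning of a coarsely compressed student S(x) = w*x.

    Each pass feeds a synthetic sample, takes label_hat2 = S(sample), and
    descends the squared loss against the synthetic label (loss_s).
    """
    for _ in range(epochs):
        for sample, label in synthetic_set:
            label_hat2 = w * sample                     # student's prediction
            grad_w = 2 * sample * (label_hat2 - label)  # d loss_s / d w
            w -= lr * grad_w
    return w

# Synthetic pairs consistent with a stand-in teacher T(x) = 2x (label = 2 * sample).
synthetic_set = [(1.0, 2.0), (0.5, 1.0), (1.5, 3.0)]
w = fine_tune(w=1.7, synthetic_set=synthetic_set)  # 1.7: an imperfect coarse student
print(round(w, 3))  # → 2.0: fine compression recovers the teacher's behaviour
```

The point of the two-stage design shows up here: the coarse student starts close to the target, so the supervised pass only has to close a small residual gap, which is why the overall accuracy loss can be kept near zero.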
S34: evaluate the student network compression result according to model evaluation metrics;
The model evaluation metrics include, but are not limited to, accuracy.
It can be understood that the iteration counts K, M, and so on in the above embodiments can be set according to the actual application scenario.
To verify the feasibility and effectiveness of the embodiments of this application, an experiment was conducted by applying the model compression method of the embodiments to the compression of an OCR task model based on the Transformer architecture. With eight hyperparameter combination configurations, after adversarial knowledge distillation a student model with the performance shown in Fig. 7 is obtained at node A of the system block diagram. After iterative sample-synthesis training, high-quality synthetic samples as shown in Fig. 8 are obtained at node B. After the second-stage supervised learning, the model compression result shown in Fig. 9 is obtained at node C. The experimental results demonstrate that the embodiments of this application can obtain a high-accuracy model compression result without relying on the original training data set.
Based on the above, the model compression method of the embodiments of this application compresses a model in two stages, coarse compression and fine compression. In the coarse compression stage, adversarial knowledge distillation is used to distill and compress the student network, producing a coarsely compressed student network; in the fine compression stage, the teacher network generates a high-quality synthetic sample set, which is then used for supervised learning on the coarsely compressed student network, achieving model compression without the original training data set. Compared with the prior art, the embodiments of this application have at least the following advantages:
1. The model can be compressed entirely without the original training data set, resolving the situation where compression work cannot proceed because the original training data is sensitive or too massive;
2. The large randomness and volatility of adversarial knowledge distillation for model compression, and the difficulty of controlling and debugging it, are remedied;
3. The accuracy loss of model compression is effectively reduced, to the point of being nearly lossless.
In an optional implementation, the result of the model compression method may also be uploaded to a blockchain.
Specifically, digest information is derived from the result of the model compression method; concretely, the digest is obtained by hashing the result, for example with the SHA-256 algorithm. Uploading the digest to the blockchain ensures its security and its fairness and transparency toward users. A user can download the digest from the blockchain to verify whether the result of the model compression method has been tampered with. The blockchain referred to in this example is a new application mode of computer technologies such as distributed data storage, peer-to-peer transmission, consensus mechanisms, and cryptographic algorithms. A blockchain is, in essence, a decentralized database: a chain of data blocks generated in association with one another using cryptographic methods, where each data block contains a batch of network transaction information used to verify the validity (anti-tampering) of its information and to generate the next block. A blockchain may comprise an underlying blockchain platform, a platform product service layer, an application service layer, and the like.
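The digest step can be sketched with Python's standard hashlib. The payload layout (a small metadata dict serialised as JSON) is an assumption for illustration; only the hashing itself is prescribed by the text.

```python
import hashlib
import json

def digest_of_result(result: dict) -> str:
    """Hash a model-compression result into a fixed-size digest for the chain.

    Serialising with sorted keys makes the digest reproducible, so any
    later re-hash can confirm the result was not tampered with.
    """
    payload = json.dumps(result, sort_keys=True).encode("utf-8")
    return hashlib.sha256(payload).hexdigest()

result = {"model": "student_v1", "accuracy": 0.97}  # hypothetical compression result
digest = digest_of_result(result)
print(len(digest))  # → 64: SHA-256 always yields 64 hex characters
```

Because the serialisation is canonical, re-hashing the downloaded result and comparing digests is all a user needs to do to check for tampering.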
Please refer to Fig. 10, a schematic structural diagram of the model compression system of an embodiment of this application. The model compression system 40 of this embodiment includes:
a distillation compression module 41, configured to distill and compress the student network via adversarial knowledge distillation based on at least one hyperparameter combination, generating a coarsely compressed student network; specifically, the distillation compression module 41 distills and compresses the student model as follows:
Step 1: take one hyperparameter combination H1 from the hyperparameter combination cluster (H1, H2, H3, ..., HN) for adversarial knowledge distillation training;
Step 2: feed the first random number r1 produced by the random number generator into the G network (Generator, the sample generator), which generates the first training sample x1;
Step 3: feed the first training sample x1 into the T network (teacher network) and the S network (student network) respectively; the T network and the S network output the first prediction results y and y_hat respectively, and the loss function calculator L computes the first loss value loss_s from y and y_hat;
Step 4: backpropagate the first loss value loss_s through the S network and update the S network's parameters; the goal of the update is to make loss_s smaller and smaller, i.e., to bring the S network ever closer to the T network;
Step 5: iterate Steps 2 to 4 K times, updating the S network's parameters K times via backpropagation; the G network's coefficients are not updated during these iterations;
Step 6: feed the second random number r2 produced by the random number generator into the G network, which generates the second training sample x2;
Step 7: feed the second training sample x2 into the T network and the parameter-updated S network respectively; the T network and the S network output the second prediction results y and y_hat respectively, and the loss function calculator L computes the second loss value loss_g1 from them;
Step 8: backpropagate the second loss value loss_g1 through the G network and update the G network's parameters; the goal of the update is to make the first loss value loss_s larger and larger;
Step 9: iterate Steps 6 to 8 M times, updating the G network's parameters M times via backpropagation; the S network's coefficients are not updated during these iterations;
Step 10: iterate Steps 2 to 9 until the S network's ACC1 (an evaluation metric such as accuracy) no longer increases noticeably, then end the iteration, obtain the training result {H1, S1, ACC1} of adversarial knowledge distillation under hyperparameter combination H1, and save the S network's parameters;
Step 11: based on the hyperparameter combinations in the cluster (H1, H2, H3, ..., HN), iterate Steps 2 to 9 n (n ∈ N) times to obtain the training results {Hn, Sn, ACCn} of adversarial knowledge distillation over the cluster (H1, H2, H3, ..., HN).
a sample generation module 42, configured to generate samples with a random sample generator, feed the generated samples into the teacher network, and have the teacher network generate a synthetic sample set; the generation of the synthetic sample set specifically includes:
Step 1: generate the desired label label with the label generator;
Step 2: generate a sample Sample(B, H, W, C) with the random sample generator, where B = batch (number of images), H = height, W = width, and C = number of channels;
Step 3: feed Sample into the T network, which outputs the first predicted label label_hat1;
Step 4: the loss function calculator L computes the third loss function loss_g2 from the label produced by the label generator and the label_hat1 output by the T network;
Step 5: apply a gradient update to the generated sample Sample based on the third loss function loss_g2;
Step 6: iterate Steps 3 to 5 M times, applying M gradient updates to Sample, until Sample meets the preset requirements;
Step 7: iterate Steps 1 to 6 to generate the synthetic sample set {Sample(B, H, W, C), label}.
a supervised learning module 43, configured to feed the synthetic sample set into the coarsely compressed student network and perform supervised learning training on it to obtain the student network compression result; the supervised learning training of the student network proceeds as follows: first, the synthetic sample Sample(B, H, W, C) is fed into the coarsely compressed S network, which outputs the second predicted label label_hat2; the loss function calculator L then computes the S network's loss function loss_s based on the second predicted label label_hat2.
Please refer to Fig. 11, a schematic structural diagram of the terminal of an embodiment of this application. The terminal 50 includes a processor 51 and a memory 52 coupled to the processor 51.
The memory 52 stores program instructions for implementing the above model compression method.
The processor 51 is configured to execute the program instructions stored in the memory 52 to perform the model compression operation.
The processor 51 may also be called a CPU (Central Processing Unit). The processor 51 may be an integrated circuit chip with signal processing capability. The processor 51 may also be a general-purpose processor, a digital signal processor (DSP), an application-specific integrated circuit (ASIC), a field-programmable gate array (FPGA) or other programmable logic device, a discrete gate or transistor logic device, or a discrete hardware component. The general-purpose processor may be a microprocessor, or the processor may be any conventional processor or the like.
Please refer to Fig. 12, a schematic structural diagram of the storage medium of an embodiment of this application. The storage medium of this embodiment stores a program file 61 capable of implementing all of the above methods, where the program file 61 may be stored in the storage medium in the form of a software product and comprises several instructions for causing a computer device (which may be a personal computer, a server, a network device, or the like) or a processor to perform all or part of the steps of the methods of the embodiments of this application. The aforementioned storage medium includes various media that can store program code, such as a USB flash drive, a portable hard disk, a read-only memory (ROM), a random access memory (RAM), a magnetic disk, or an optical disc, or terminal devices such as computers, servers, mobile phones, and tablets. The storage medium may be non-volatile or volatile.
In the several embodiments provided in this application, it should be understood that the disclosed system, apparatus, and method may be implemented in other ways. For example, the system embodiments described above are merely illustrative; for instance, the division into units is only a logical functional division, and there may be other divisions in actual implementation: multiple units or components may be combined or integrated into another system, or some features may be omitted or not executed. Further, the couplings or direct couplings or communication connections shown or discussed may be indirect couplings or communication connections through interfaces, devices, or units, and may be electrical, mechanical, or in other forms.
In addition, the functional units in the embodiments of this application may be integrated into one processing unit, each unit may exist alone physically, or two or more units may be integrated into one unit. The integrated unit may be implemented in the form of hardware or of a software functional unit. The above are only implementations of this application and do not limit its patent scope; any equivalent structure or equivalent process transformation made using the content of the description and drawings of this application, whether applied directly or indirectly in other related technical fields, is likewise included within the scope of patent protection of this application.

Claims (20)

1. A model compression method, comprising:
generating training samples with a sample generator;
based on at least one hyperparameter combination, feeding the training samples into a student network and a teacher network respectively, and performing adversarial knowledge distillation training on the student network and the teacher network to generate a coarsely compressed student network;
generating samples with a random sample generator, feeding the generated samples into the teacher network, and having the teacher network generate a synthetic sample set;
performing supervised learning training on the coarsely compressed student network with the synthetic sample set to obtain a compression result of the student network.
2. The model compression method according to claim 1, wherein said feeding the training samples into a student network and a teacher network respectively based on at least one hyperparameter combination and performing adversarial knowledge distillation training on the student network and the teacher network comprises:
based on a hyperparameter combination, feeding a first random number produced by a random number generator into the sample generator, which generates a first training sample;
feeding the first training sample into the teacher network and the student network respectively, the teacher network and the student network outputting first prediction results y and y_hat respectively, and computing a first loss value from the first prediction results y and y_hat;
backpropagating the first loss value through the student network and updating the student network's parameters;
feeding a second random number produced by the random number generator into the sample generator, which generates a second training sample;
feeding the second training sample into the parameter-updated student network and the teacher network respectively, the teacher network and the student network outputting second prediction results y and y_hat respectively, and computing a second loss value from the second prediction results y and y_hat;
backpropagating the second loss value through the sample generator and updating the sample generator's parameters;
iteratively performing the updates of the student network and the sample generator until an iteration end condition is met, thereby obtaining a training result of the adversarial knowledge distillation of the student network based on the hyperparameter combination, and saving the student network's parameters.
3. The model compression method according to claim 2, wherein said backpropagating the first loss value through the student network and updating the student network's parameters further comprises:
iteratively performing the computation of the first loss value and the backpropagation through the student network for a preset number of iterations, updating the student network's parameters the preset number of times.
4. The model compression method according to claim 3, wherein said backpropagating the second loss value through the sample generator and updating the sample generator's parameters further comprises:
iteratively performing the computation of the second loss value and the backpropagation through the sample generator for a preset number of iterations, updating the sample generator's parameters the preset number of times.
5. The model compression method according to claim 1, wherein said having the teacher network generate a synthetic sample set comprises:
generating a desired label with a label generator;
generating a sample with a random sample generator;
feeding the generated sample into the teacher network, which outputs a first predicted label;
computing a third loss function from the desired label and the first predicted label;
applying a gradient update to the generated sample based on the third loss function;
iteratively performing the gradient update of the generated sample for a preset number of iterations until the generated sample meets a preset condition.
6. The model compression method according to claim 5, wherein said feeding the synthetic sample set into the coarsely compressed student network and performing supervised learning training on the coarsely compressed student network comprises:
feeding the synthetic sample into the coarsely compressed student network, which outputs a second predicted label;
computing the student network's final loss function based on the second predicted label.
7. The model compression method according to any one of claims 1 to 6, wherein after said feeding the synthetic sample set into the coarsely compressed student network and performing supervised learning training on the coarsely compressed student network, the method further comprises:
evaluating the student network compression result according to a model evaluation metric, the model evaluation metric including accuracy.
8. A model compression system, comprising:
a distillation compression module, configured to, based on at least one hyperparameter combination, feed training samples generated by a sample generator into a student network and a teacher network respectively, and perform adversarial knowledge distillation training on the student network and the teacher network to generate a coarsely compressed student network;
a sample generation module, configured to generate samples with a random sample generator, feed the generated samples into the teacher network, and have the teacher network generate a synthetic sample set;
a supervised learning module, configured to perform supervised learning training on the coarsely compressed student network with the synthetic sample set to obtain a compression result of the student network.
9. A terminal, comprising a processor and a memory coupled to the processor, wherein
the memory stores program instructions for implementing the following steps:
generating training samples with a sample generator;
based on at least one hyperparameter combination, feeding the training samples into a student network and a teacher network respectively, and performing adversarial knowledge distillation training on the student network and the teacher network to generate a coarsely compressed student network;
generating samples with a random sample generator, feeding the generated samples into the teacher network, and having the teacher network generate a synthetic sample set;
performing supervised learning training on the coarsely compressed student network with the synthetic sample set to obtain a compression result of the student network;
and the processor is configured to execute the program instructions stored in the memory.
10. The terminal according to claim 9, wherein said feeding the training samples into a student network and a teacher network respectively based on at least one hyperparameter combination and performing adversarial knowledge distillation training on the student network and the teacher network comprises:
based on a hyperparameter combination, feeding a first random number produced by a random number generator into the sample generator, which generates a first training sample;
feeding the first training sample into the teacher network and the student network respectively, the teacher network and the student network outputting first prediction results y and y_hat respectively, and computing a first loss value from the first prediction results y and y_hat;
backpropagating the first loss value through the student network and updating the student network's parameters;
feeding a second random number produced by the random number generator into the sample generator, which generates a second training sample;
feeding the second training sample into the parameter-updated student network and the teacher network respectively, the teacher network and the student network outputting second prediction results y and y_hat respectively, and computing a second loss value from the second prediction results y and y_hat;
backpropagating the second loss value through the sample generator and updating the sample generator's parameters;
iteratively performing the updates of the student network and the sample generator until an iteration end condition is met, thereby obtaining a training result of the adversarial knowledge distillation of the student network based on the hyperparameter combination, and saving the student network's parameters.
11. The terminal according to claim 10, wherein said backpropagating the first loss value through the student network and updating the student network's parameters further comprises:
iteratively performing the computation of the first loss value and the backpropagation through the student network for a preset number of iterations, updating the student network's parameters the preset number of times.
12. The terminal according to claim 11, wherein said backpropagating the second loss value through the sample generator and updating the sample generator's parameters further comprises:
iteratively performing the computation of the second loss value and the backpropagation through the sample generator for a preset number of iterations, updating the sample generator's parameters the preset number of times.
13. The terminal according to claim 9, wherein said having the teacher network generate a synthetic sample set comprises:
generating a desired label with a label generator;
generating a sample with a random sample generator;
feeding the generated sample into the teacher network, which outputs a first predicted label;
computing a third loss function from the desired label and the first predicted label;
applying a gradient update to the generated sample based on the third loss function;
iteratively performing the gradient update of the generated sample for a preset number of iterations until the generated sample meets a preset condition.
14. The terminal according to claim 13, wherein said feeding the synthetic sample set into the coarsely compressed student network and performing supervised learning training on the coarsely compressed student network comprises:
feeding the synthetic sample into the coarsely compressed student network, which outputs a second predicted label;
computing the student network's final loss function based on the second predicted label.
15. The terminal according to any one of claims 9 to 14, wherein after said feeding the synthetic sample set into the coarsely compressed student network and performing supervised learning training on the coarsely compressed student network, the steps further comprise:
evaluating the student network compression result according to a model evaluation metric, the model evaluation metric including accuracy.
16. A storage medium storing processor-executable program instructions, the program instructions being used to perform the following steps:
generating training samples with a sample generator;
based on at least one hyperparameter combination, feeding the training samples into a student network and a teacher network respectively, and performing adversarial knowledge distillation training on the student network and the teacher network to generate a coarsely compressed student network;
generating samples with a random sample generator, feeding the generated samples into the teacher network, and having the teacher network generate a synthetic sample set;
performing supervised learning training on the coarsely compressed student network with the synthetic sample set to obtain a compression result of the student network.
17. The storage medium according to claim 16, wherein said feeding the training samples into a student network and a teacher network respectively based on at least one hyperparameter combination and performing adversarial knowledge distillation training on the student network and the teacher network comprises:
based on a hyperparameter combination, feeding a first random number produced by a random number generator into the sample generator, which generates a first training sample;
feeding the first training sample into the teacher network and the student network respectively, the teacher network and the student network outputting first prediction results y and y_hat respectively, and computing a first loss value from the first prediction results y and y_hat;
backpropagating the first loss value through the student network and updating the student network's parameters;
feeding a second random number produced by the random number generator into the sample generator, which generates a second training sample;
feeding the second training sample into the parameter-updated student network and the teacher network respectively, the teacher network and the student network outputting second prediction results y and y_hat respectively, and computing a second loss value from the second prediction results y and y_hat;
backpropagating the second loss value through the sample generator and updating the sample generator's parameters;
iteratively performing the updates of the student network and the sample generator until an iteration end condition is met, thereby obtaining a training result of the adversarial knowledge distillation of the student network based on the hyperparameter combination, and saving the student network's parameters.
18. The storage medium according to claim 17, wherein said backpropagating the first loss value through the student network and updating the student network's parameters further comprises:
iteratively performing the computation of the first loss value and the backpropagation through the student network for a preset number of iterations, updating the student network's parameters the preset number of times.
19. The storage medium according to claim 18, wherein said backpropagating the second loss value through the sample generator and updating the sample generator's parameters further comprises:
iteratively performing the computation of the second loss value and the backpropagation through the sample generator for a preset number of iterations, updating the sample generator's parameters the preset number of times.
20. The storage medium according to claim 16, wherein said having the teacher network generate a synthetic sample set comprises:
generating a desired label with a label generator;
generating a sample with a random sample generator;
feeding the generated sample into the teacher network, which outputs a first predicted label;
computing a third loss function from the desired label and the first predicted label;
applying a gradient update to the generated sample based on the third loss function;
iteratively performing the gradient update of the generated sample for a preset number of iterations until the generated sample meets a preset condition.
PCT/CN2021/083230 2020-11-13 2021-03-26 Model compression method, system, terminal and storage medium WO2021197223A1 (zh)

Applications Claiming Priority (2)

Application Number Priority Date Filing Date Title
CN202011269682.8A CN112381209B (zh) 2020-11-13 2020-11-13 Model compression method, system, terminal and storage medium
CN202011269682.8 2020-11-13

Publications (1)

Publication Number Publication Date
WO2021197223A1 true WO2021197223A1 (zh) 2021-10-07

Family

ID=74583913

Family Applications (1)

Application Number Title Priority Date Filing Date
PCT/CN2021/083230 WO2021197223A1 (zh) 2020-11-13 2021-03-26 Model compression method, system, terminal and storage medium

Country Status (2)

Country Link
CN (1) CN112381209B (zh)
WO (1) WO2021197223A1 (zh)

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN114495245A (zh) * 2022-04-08 2022-05-13 北京中科闻歌科技股份有限公司 Face-forged image identification method, apparatus, device, and medium
CN115908955A (zh) * 2023-03-06 2023-04-04 之江实验室 Bird classification system, method, and apparatus based on gradient-distillation few-shot learning

Families Citing this family (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN112381209B (zh) 2020-11-13 2023-12-22 平安科技(深圳)有限公司 Model compression method, system, terminal and storage medium
CN113255763B (zh) 2021-05-21 2023-06-09 平安科技(深圳)有限公司 Knowledge-distillation-based model training method, apparatus, terminal, and storage medium
US11599794B1 (en) 2021-10-20 2023-03-07 Moffett International Co., Limited System and method for training sample generator with few-shot learning

Citations (7)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN109711544A (zh) * 2018-12-04 2019-05-03 北京市商汤科技开发有限公司 Model compression method and apparatus, electronic device, and computer storage medium
US20190205748A1 (en) * 2018-01-02 2019-07-04 International Business Machines Corporation Soft label generation for knowledge distillation
CN110084281A (zh) * 2019-03-31 2019-08-02 华为技术有限公司 Image generation method, neural network compression method, and related apparatus and device
CN111178542A (zh) * 2019-11-18 2020-05-19 上海联影智能医疗科技有限公司 System and method based on machine learning modeling
CN111461226A (zh) * 2020-04-01 2020-07-28 深圳前海微众银行股份有限公司 Adversarial sample generation method, apparatus, terminal, and readable storage medium
CN111598216A (zh) * 2020-04-16 2020-08-28 北京百度网讯科技有限公司 Method, apparatus, device, and storage medium for generating a student network model
CN112381209A (zh) * 2020-11-13 2021-02-19 平安科技(深圳)有限公司 Model compression method, system, terminal and storage medium

Family Cites Families (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN109710544B (zh) * 2017-10-26 2021-02-09 华为技术有限公司 Memory access method, computer system, and processing apparatus
CN110674880B (zh) * 2019-09-27 2022-11-11 北京迈格威科技有限公司 Network training method, apparatus, medium, and electronic device for knowledge distillation
CN111027060B (zh) * 2019-12-17 2022-04-29 电子科技大学 Knowledge-distillation-based defense method against black-box attacks on neural networks
CN111126573B (zh) * 2019-12-27 2023-06-09 深圳力维智联技术有限公司 Model distillation improvement method, device, and storage medium based on individual learning
CN111160474B (zh) * 2019-12-30 2023-08-29 合肥工业大学 Image recognition method based on deep curriculum learning

Non-Patent Citations (2)

* Cited by examiner, † Cited by third party
Title
BYEONGHO HEO; MINSIK LEE; SANGDOO YUN; JIN YOUNG CHOI: "Knowledge Distillation with Adversarial Samples Supporting Decision Boundary", ARXIV.ORG, CORNELL UNIVERSITY LIBRARY, 201 OLIN LIBRARY CORNELL UNIVERSITY ITHACA, NY 14853, 15 May 2018 (2018-05-15), 201 Olin Library Cornell University Ithaca, NY 14853, XP080878249 *
CHEN HANTING; WANG YUNHE; XU CHANG; YANG ZHAOHUI; LIU CHUANJIAN; SHI BOXIN; XU CHUNJING; XU CHAO; TIAN QI: "Data-Free Learning of Student Networks", 2019 IEEE/CVF INTERNATIONAL CONFERENCE ON COMPUTER VISION (ICCV), IEEE, 27 October 2019 (2019-10-27), pages 3513 - 3521, XP033723721, DOI: 10.1109/ICCV.2019.00361 *

Also Published As

Publication number Publication date
CN112381209B (zh) 2023-12-22
CN112381209A (zh) 2021-02-19

Similar Documents

Publication Publication Date Title
WO2021197223A1 (zh) Model compression method, system, terminal and storage medium
Liu et al. Fedcoin: A peer-to-peer payment system for federated learning
KR102342604B1 (ko) Neural network generation method and apparatus
CN109902706B (zh) Recommendation method and apparatus
WO2022089256A1 (zh) Federated neural network model training method, apparatus and device, computer program product, and computer-readable storage medium
CN105446896B (zh) Cache management method and apparatus for MapReduce applications
Wu et al. Using fractional order accumulation to reduce errors from inverse accumulated generating operator of grey model
JP2020523619A (ja) Distributed multi-party security model training framework for privacy protection
JP2020525814A (ja) Logistic regression modeling scheme using secret sharing
WO2023124296A1 (zh) Knowledge-distillation-based joint learning training method, apparatus, device, and medium
CN113408746A (zh) Blockchain-based distributed federated learning method and apparatus, and terminal device
CN108431832A (zh) Augmenting neural networks with external memory
CN108694201A (zh) Entity alignment method and apparatus
US20190087723A1 (en) Variable isa vector-based compaction in distributed training of neural networks
US20210150351A1 (en) Isa-based compression in distributed training of neural networks
CN110175469A (zh) Social media user privacy leakage detection method, system, device, and medium
Hommes et al. Genetic algorithm learning in a New Keynesian macroeconomic setup
CN103782290A (zh) Generation of suggested values
TWI758223B (zh) Computing method with dynamic minimum batch size, and computing system and computer-readable storage medium for performing the same
CN116108697B (zh) Accelerated test data processing method, apparatus, and device based on multivariate performance degradation
US20220230092A1 (en) Fast converging gradient compressor for federated learning
JP6321216B2 (ja) Matrix and key generation device, matrix and key generation system, matrix combination device, matrix and key generation method, and program
Soleymani Efficient semi-discretization techniques for pricing European and American basket options
Dong et al. Weighted least squares model averaging for accelerated failure time models
US20150234686A1 (en) Exploiting parallelism in exponential smoothing of large-scale discrete datasets

Legal Events

Date Code Title Description
121 Ep: the epo has been informed by wipo that ep was designated in this application

Ref document number: 21780406

Country of ref document: EP

Kind code of ref document: A1

NENP Non-entry into the national phase

Ref country code: DE

122 Ep: pct application non-entry in european phase

Ref document number: 21780406

Country of ref document: EP

Kind code of ref document: A1