CN116502621B - 一种基于自适应对比知识蒸馏的网络压缩方法和装置 - Google Patents
一种基于自适应对比知识蒸馏的网络压缩方法和装置 Download PDFInfo
- Publication number
- CN116502621B CN116502621B CN202310758129.8A CN202310758129A CN116502621B CN 116502621 B CN116502621 B CN 116502621B CN 202310758129 A CN202310758129 A CN 202310758129A CN 116502621 B CN116502621 B CN 116502621B
- Authority
- CN
- China
- Prior art keywords
- network
- adaptive
- sample
- distillation
- samples
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
- 238000000034 method Methods 0.000 title claims abstract description 74
- 238000013140 knowledge distillation Methods 0.000 title claims abstract description 72
- 238000007906 compression Methods 0.000 title claims abstract description 61
- 230000006835 compression Effects 0.000 title claims abstract description 60
- 230000003044 adaptive effect Effects 0.000 claims abstract description 95
- 238000004821 distillation Methods 0.000 claims abstract description 68
- 230000000052 comparative effect Effects 0.000 claims abstract description 42
- 238000013528 artificial neural network Methods 0.000 claims abstract description 17
- 238000003058 natural language processing Methods 0.000 claims abstract description 17
- 230000008569 process Effects 0.000 claims abstract description 12
- 230000006870 function Effects 0.000 claims description 17
- 238000012549 training Methods 0.000 claims description 13
- 239000000126 substance Substances 0.000 claims description 8
- 238000004590 computer program Methods 0.000 claims description 5
- 238000004364 calculation method Methods 0.000 claims description 4
- 238000002474 experimental method Methods 0.000 description 10
- 230000000694 effects Effects 0.000 description 6
- 238000010586 diagram Methods 0.000 description 5
- 238000011156 evaluation Methods 0.000 description 5
- 238000002679 ablation Methods 0.000 description 3
- 239000003292 glue Substances 0.000 description 3
- 239000000203 mixture Substances 0.000 description 3
- 230000008451 emotion Effects 0.000 description 2
- 238000003780 insertion Methods 0.000 description 2
- 230000037431 insertion Effects 0.000 description 2
- 238000005457 optimization Methods 0.000 description 2
- 238000012545 processing Methods 0.000 description 2
- 238000012360 testing method Methods 0.000 description 2
- 241000761389 Copa Species 0.000 description 1
- 230000009286 beneficial effect Effects 0.000 description 1
- 238000004422 calculation algorithm Methods 0.000 description 1
- 230000001413 cellular effect Effects 0.000 description 1
- 238000012733 comparative method Methods 0.000 description 1
- 238000013135 deep learning Methods 0.000 description 1
- 238000013461 design Methods 0.000 description 1
- 238000011161 development Methods 0.000 description 1
- 230000006872 improvement Effects 0.000 description 1
- 230000003993 interaction Effects 0.000 description 1
- 230000013011 mating Effects 0.000 description 1
- 230000003278 mimic effect Effects 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 238000003062 neural network model Methods 0.000 description 1
- 230000003068 static effect Effects 0.000 description 1
- 238000012546 transfer Methods 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F40/00—Handling natural language data
- G06F40/20—Natural language analysis
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/213—Feature extraction, e.g. by transforming the feature space; Summarisation; Mappings, e.g. subspace methods
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/22—Matching criteria, e.g. proximity measures
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
- G06F18/241—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02D—CLIMATE CHANGE MITIGATION TECHNOLOGIES IN INFORMATION AND COMMUNICATION TECHNOLOGIES [ICT], I.E. INFORMATION AND COMMUNICATION TECHNOLOGIES AIMING AT THE REDUCTION OF THEIR OWN ENERGY USE
- Y02D10/00—Energy efficient computing, e.g. low power processors, power management or thermal management
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Data Mining & Analysis (AREA)
- Physics & Mathematics (AREA)
- Artificial Intelligence (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Evolutionary Computation (AREA)
- Life Sciences & Earth Sciences (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Bioinformatics & Computational Biology (AREA)
- Evolutionary Biology (AREA)
- General Health & Medical Sciences (AREA)
- Health & Medical Sciences (AREA)
- Computational Linguistics (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Molecular Biology (AREA)
- Computing Systems (AREA)
- Mathematical Physics (AREA)
- Software Systems (AREA)
- Audiology, Speech & Language Pathology (AREA)
- Other Investigation Or Analysis Of Materials By Electrical Means (AREA)
- Vaporization, Distillation, Condensation, Sublimation, And Cold Traps (AREA)
Abstract
本发明公开了一种基于自适应对比知识蒸馏的网络压缩方法及装置。该网络压缩方法包括如下步骤:引入对比性蒸馏损失作为显式监督,以最大化特征负样本对的距离;利用神经网络作为预测器,根据每个样本的学习特征来预测其辨别能力;然后,根据预测的辨别能力对不同样本的损失进行重新加权,以实现样本适应性重加权策略;将样本适应性重加权策略融入到对比性蒸馏损失中,构建自适应对比性蒸馏损失;基于自适应对比性蒸馏损失构建自适应对比知识蒸馏框架,用于实现自然语言处理过程中的神经网络压缩。
Description
技术领域
本发明涉及一种基于自适应对比知识蒸馏的网络压缩方法,同时也涉及相应的网络压缩装置,属于计算机系统技术领域。
背景技术
知识蒸馏(Knowledge Distillation,简写为KD)是一种经典的神经网络压缩方法,其通过引导轻量化的学生网络“模仿”性能更好、结构更复杂的教师网络,在不改变学生网络的情况下提高其性能。在现有技术中,自然语言处理领域常用的BERT模型,其采用的知识蒸馏方法隐含地学习了学生网络的鉴别性特征,也就是说,需要把来自不同类别的样本(负对)的特征推得很远,而把来自相同类别的样本(正对)的特征保持得很近。假设教师网络是经过良好学习的(即在教师网络中,负对的特征是相互远离的),通过最小化教师网络和学生网络之间每个样本的特征距离,使学生网络的特征具有鉴别性,如图1中左侧所示,学生网络中的负对的特征就可以被拉得很远。但是,当常用词出现在具有不同含义的句子中时,会导致教师网络中的负对的特征相互接近,如图1中右侧所示,在这种情况下,使用现有的知识蒸馏范式训练学生网络,将导致学生网络中的负对的特征也是相互接近的。
在自然语言处理(简写为NLP)任务中,现有的知识蒸馏方法在蒸馏过程中没有充分注意到困难样本,类似的句子可能有完全不同的含义。例如,对于语言可接受性任务,虽然句子“我们把自己喊哑了”和“我们把哈里喊哑了”是相似的,因为二者只有一个不同的词,但是,第一个句子在语言上是可接受的,而后一个句子则不是,使二者属于不同类别。由于这类句子的特征是相似的,使得这类句子难以区分,因此辨别能力较差。因此,对困难样本给予更多关注以加强其辨别能力也是亟待解决的一个问题。
在申请号为202210540113.5的中国发明专利申请中,公开了一种基于不确定性估计知识蒸馏的语言模型压缩方法。该方法包括如下步骤:1)对原始语言模型进行对半压缩得到压缩后的神经网络;2)利用原始语言模型合理初始化压缩后神经网络的参数;3)添加前馈网络结构的参数蒸馏损失函数,设计不确定性估计损失函数及自然语言处理任务的交叉熵损失函数;4)利用所设计的损失函数训练压缩后的神经网络模型。该技术方案降低了网络压缩训练过程的计算量,提高了网络压缩率。
发明内容
本发明所要解决的首要技术问题在于提供一种基于自适应对比知识蒸馏的网络压缩方法。
本发明所要解决的另一技术问题在于提供一种基于自适应对比知识蒸馏的网络压缩装置。
为了实现上述目的,本发明采用以下的技术方案:
根据本发明实施例的第一方面,提供一种基于自适应对比知识蒸馏的网络压缩方法,用在自然语言处理任务中,包括如下步骤:
(1)引入对比性蒸馏损失作为显式监督,以最大化特征负样本对的距离;
(2)利用一个神经网络作为预测器,根据每个样本的学习特征来预测其辨别能力;然后,根据预测的辨别能力对不同样本的损失进行重新加权,以实现样本适应性重加权策略;
(3)将所述样本适应性重加权策略融入到所述对比性蒸馏损失中,构建自适应对比性蒸馏损失;
(4)基于所述自适应对比性蒸馏损失构建自适应对比知识蒸馏框架,用于实现自然语言处理过程中的神经网络压缩。
其中较优地,所述步骤(1)中,对于每个样本,利用所述对比性蒸馏损失最大化学生网络中样本的特征与教师网络中样本的特征之间的相似性,并最小化学生网络中样本的特征与教师网络中样本的负样本对特征之间的相似性。
其中较优地,所述步骤(3)中,在自适应对比性蒸馏损失的训练过程中,增加分子项,以使教师网络和学生网络中来自同一样本的特征相互接近,同时减少分母项,以使学生网络中来自不同类别的第个样本的特征远离教师网络中第/>个样本的特征,其中/>为正整数。
其中较优地,所述步骤(3)中,给具有较少鉴别特征的样本分配更高的权重,以形成自适应对比性蒸馏损失。
其中较优地,所述步骤(4)中,在每个批次的计算结束后,将所述批次的特征保存到动态特征存储器中,同时将样本的标签也保存到所述动态特征存储器中,用于识别样本。
其中较优地,所述动态特征存储器中只存储插入自适应对比性蒸馏损失的所在层的特征;在存储空间满了之后,根据先入先出的策略更新存储空间。
根据本发明实施例的第二方面,提供一种基于自适应对比知识蒸馏的网络压缩装置,用在自然语言处理任务中,包括处理器和动态特征存储器,所述处理器和所述动态特征存储器耦接;其中,
所述动态特征存储器用于存储计算机程序;
所述处理器用于运行存储在所述动态特征存储器中的计算机程序,执行上述基于自适应对比知识蒸馏的网络压缩方法。
与现有技术相比较,首先,本发明实施例将对比学习的概念引入到知识蒸馏中,提出一个对比性蒸馏损失作为显式监督,以最大化特征负样本对的距离。特别地,对于每个样本,对比性蒸馏损失旨在最大化学生网络中样本的特征与教师网络中样本的特征之间的相似性,并最小化学生网络中样本的特征与教师网络中样本的负样本对特征之间的相似性。
其次,本发明实施例针对较少鉴别性特征的困难样本,提出样本适应性重加权策略,以适应性地对这些困难样本给予更多的关注,并加强其识别能力。具体地说,利用一个神经网络作为预测器,根据每个样本的学习特征来预测其辨别能力。然后,根据预测的鉴别能力对不同样本的损失进行重新加权。由于这个过程中的所有操作都是可微分的,预测器的参数可以与学生网络共同学习。
最后,本发明实施例将样本适应性重加权策略无缝地融入前述对比性蒸馏损失中,形成自适应对比性蒸馏损失。将自适应对比性蒸馏损失与现有技术中的患者知识蒸馏方法相结合,构建了自适应对比知识蒸馏框架,可以用于BERT压缩。通过在多个自然语言处理任务上的广泛实验,证明了该自适应对比知识蒸馏框架对BERT压缩的有效性。
附图说明
图1为现有技术中,教师网络和学生网络之间负对样本的特征距离示意图;
图2为本发明实施例中,基于自适应对比知识蒸馏的网络压缩方法的逻辑框架图;
图3为本发明实施例中,多项消融实验的测试结果的示意图;
图4为本发明实施例中,教师网络和学生网络之间负对样本的特征距离示意图;
图5为本发明实施例中,基于自适应对比知识蒸馏的网络压缩装置的结构示意图。
具体实施方式
下面结合附图和具体实施例对本发明的技术内容进行详细具体的说明。
在将知识蒸馏应用于BERT模型的过程中,有一些方法被提出来压缩BERT模型,但这些方法缺乏对学习判别学生网络特征的显式监督。同时,虽然有些工作也使用了BERT模型的对比损失,但这些方法没有充分关注自然语言处理任务中的困难样本。为了解决现有技术中存在的上述问题,本发明实施例首先提供一种基于自适应对比知识蒸馏的网络压缩方法,以便使用经过预训练的大容量教师网络更好地来帮助训练轻量级学生网络。
如图2所示,在本发明实施例提供的基于自适应对比知识蒸馏的网络压缩方法中,首先在自然语言处理任务中引入对比性蒸馏损失(简写为CDL)作为显式监督,以最大化特征负样本对的距离,学习更多的鉴别性的学生网络特征;接着利用一个神经网络作为预测器,根据每个样本的学习特征来预测其辨别能力;然后,根据预测的辨别能力对不同样本的损失进行重新加权,以实现样本适应性重加权策略(简写为SAR),以适应性地对鉴别能力较差的困难样本给予更多的关注;将该样本适应性重加权策略无缝地融入到该对比性蒸馏损失中,构建自适应对比性蒸馏损失(简写为A-CDL);基于该自适应对比性蒸馏损失构建自适应对比知识蒸馏(简写为ACKD)框架,用于实现自然语言处理过程中的神经网络压缩。通过在多个自然语言处理任务上的广泛实验,证明了本发明提供的网络压缩方法对BERT压缩的有效性。
在本发明的一个实施例中,自适应对比知识蒸馏框架在训练学生网络时的损失来自四个部分,分别为交叉熵损失(简写为CEL)、知识蒸馏损失(简写为KDL)、患者损失(简写为PTL)和自适应对比性蒸馏损失。下面,以现有技术中的患者知识蒸馏方法(详见SiqiSun, Yu Cheng, Zhe Gan, and Jingjing Liu. 2019.《Patient knowledgedistillation for bert model compression》.arXiv:1908.09355.)为基础,具体说明本发明实施例提供的基于自适应对比知识蒸馏的网络压缩方法的具体实施过程。
在现有技术中的患者知识蒸馏方法中,给定具有个样本的训练数据集,学生网络可以使用如下的损失函数进行训练:
其中,i和N为正整数,为任务特有的损失,/>为相应的损失函数,对于分类任务通常采用交叉熵损失函数;/>为知识蒸馏损失,/>为相应的损失函数,一般采用师生之间输出概率分布的Kullback-Leibler散度;/>为引入的患者损失;是均方误差函数;/>和/>分别为教师网络和学生网络,其参数分别表示为/>和/>;/>和/>分别表示计算患者损失时,教师网络和学生网络在第/>的配对层的第个样本的隐藏状态特征;/>是插入患者损失的层数;/>和/>是控制不同损失函数项的权衡的超参数;损失/>、/>和/>分别对应于图2中的交叉熵损失、知识蒸馏损失和患者损失。
虽然公式(1)中的损失可以将知识从教师网络转移到学生网络身上,但它缺乏显式的监督来学习学生网络的辨别特征。也就是说,它只提供了将教师网络和学生网络的同一样本的特征拉近的监督,而缺乏将不同类的特征推远以获得更多的鉴别性特征学习的监督。为此,在本发明实施例中,首先引入一个对比性蒸馏损失作为显式监督,以学习更多的鉴别性的学生网络特征。
由于对比性蒸馏损失可以在不同的层中引入,下面只关注第个成对的层,为了更好地表达,省略层的索引。例如,使用/>和/>分别表示教师网络和学生网络在同一层的第/>个样本的隐藏状态特征。对比性蒸馏损失/>的表达式如下:
其中, $表示余弦相似度。/>$表示包含来自不同类别的样本与第/>个样本(即负样本对)的隐藏状态特征的集合,i、m、N均为正整数。
由于类似的句子可能具有完全不同的含义,这使得这些样本难以区分。对较少鉴别性特征的困难样本,本发明实施例进一步提出一种样本适应性重加权策略,以适应性地对这些困难样本给予更多的关注,并加强其识别能力。具体来说,使用一个由神经网络构成的预测器来预测每个样本的辨别能力,并将这种预测的辨别能力纳入对比性蒸馏损失,形成自适应对比性蒸馏损失。自适应对比性蒸馏损失的表达式如下:
其中,是第/>个样本的预测分辨能力;/>是预测器的函数,由神经网络实现;/>是预测器的可学习参数;/>是sigmoid函数,用于确保预测的辨别能力为正值。
由于在这个过程中,所有的操作都是可微的,可以在蒸馏中与学生网络共同训练这个预测器。因此,可以适应性地给具有较少鉴别特征的样本分配更高的权重,最后形成自适应对比性蒸馏损失,它对应于图2中的自适应对比性蒸馏损失(A-CDL)。需要说明的是,本发明实施例中的预测器是由一个简单的神经网络实现的。因此,与梯度计算所需的计算相比,该预测器引起的额外计算可以忽略不计。
由于自适应对比性蒸馏损失可以被引入到教师网络和学生网络的不同配对层中,为了更好地表达,下面的公式中另外使用上标来表示插入自适应对比性蒸馏损失的第配对层的相应符号。因此,在本发明实施例提供的自适应对比知识蒸馏框架中,训练学生网络时的损失函数表达式为:
其中,和/>是控制不同成分重要性的超参数;/>、/>和/>分别是交叉熵损失、知识蒸馏损失和患者损失;/>是适应性对比蒸馏损失。
通过使用公式(4)中引入的损失,可以使用显式监督将学生网络中的负样本对的特征推得很远,同时考虑到样本的分辨能力。通过这种方式,进一步构建了用于BERT压缩的自适应对比知识蒸馏框架。
在构建自适应对比知识蒸馏框架时,存在的另一个问题是自适应对比性蒸馏损失需要大量的样本多样性,而现有技术中的患者知识蒸馏方法并不需要。因此,构建自适应对比知识蒸馏框架需要解决这个问题。具体来说,这个组成成分是根据不同样本的特征来计算的。由于目前深度学习网络的特性,特征将在每个小批的计算后被释放。因此,只能根据一个小批次的样本来计算/>。因此,第/>个样本的特征只能从一小部分负样本对的特征中推远,这就造成了优化方向的不准确。为此,本发明实施例中构建了动态特征存储器,用于获得更多的样本多样性。
具体来说,在每个批次的计算结束后,将这一批次的特征保存到动态特征存储器中从而进行的计算。同时,这些样本的标签也保存到动态特征存储器中,用于识别中的样本。由于BERT模型是并行处理一连串的标识,特征维度相对较大,这就给GPU(图形处理器)带来了更大的内存负担。因此,为了进一步节省内存用量,在动态特征存储器中只存储插入自适应对比性蒸馏损失的所在层的特征。在存储空间满了之后,根据先入先出的策略更新存储空间。在本发明的一个实施例中,将存储空间的大小设定为1000。通过这种方式,在计算/>时增加了样本多样性。
在本发明的一个实施例中,自适应对比性蒸馏损失的设计理念是:在蒸馏过程中,损失将被最小化。为了实现这一目标,需要最大化/>函数内部的值。因此,在训练过程中,分子项/>将被增加,这将使教师网络和学生网络中来自同一样本的特征相互接近。同时,分母项/>将被减少,这将使学生网络中来自不同类别的第/>个样本的特征远离教师网络中第/>个样本的特征,其中/>为正整数。此外,通过使用鉴别能力/>,给具有较少鉴别特征的样本分配较高的权重。通过这种方式,在考虑到样本辨别能力的情况下引入了显式监督,以学习更多的学生网络特征的辨别能力。
从另一个角度来看,本发明实施例提供的自适应对比性蒸馏损失也可以被看作是在学习学生网络时“消除”教师网络不正确预测的影响的损失。具体来说,如图1中右侧所示,如果类别2的样本与类别1的样本接近而被老师错误分类,现有技术中的患者知识蒸馏方法将不会意识到这种错误。因此,学生网络的类别2的样本会被老师的类别2的样本 "吸引",导致学生网络也出现错误分类。相反,从公式(3)来看,在计算时,负样本集/>是基于真实标签得到的。因此,如图2所示,尽管类别2的样本被老师错误分类,但学生网络的类别2的样本会被老师的类别1的样本 "排斥"。虽然学生网络的交叉熵损失也是基于真实标签的,但优化方向会受到教师网络错误预测的影响。所以,本发明实施例提供的自适应对比性蒸馏损失可以在一定程度上“消除”教师网络不正确预测的影响。
为了验证本发明实施例提供的基于自适应对比知识蒸馏的网络压缩方法的实际性能,发明人进行了如下的实验测试和评估。
第一,GLUE基准测试。
在本发明的一个实施例中,采用GLUE基准评估自适应对比知识蒸馏框架。具体来说,使用GLUE基准的开发集并使用四个任务进行评估:释义相似性匹配、情绪分类、自然语言推理和语言可接受性。对于释义相似性匹配任务,使用MRPC、QQP和STS-B进行评价;对于情绪分类任务,使用SST-2进行评估;对于自然语言推理任务,使用MNLI、QNLI和RTE进行评估;对于语言可接受性任务,使用CoLA进行评估。
本发明实施例评估了MNLI-m和MNLI-mm在MNLI上的结果。对于MRPC和QQP,同时评估了F1和准确性。对于STS-B,评估了Pearson和Spearman的相关关系。对于CoLA,评估了Matthew的相关性。对于其他数据集,使用准确性作为衡量标准。具体实施过程如下:
首先,在PyTorch框架的基础上实现了自适应对比知识蒸馏框架。按照上述工作任务来评估在特定任务设置下的自适应对比知识蒸馏框架。其中,教师网络首先在下游任务上进行微调,学生网络也在蒸馏过程中基于下游任务进行训练。采用BERT-Base模型作为教师网络,并分别使用3层和6层的BERT模型作为学生网络,分别表示为BERT3和BERT6。在教师网络和学生网络中,隐藏状态的数量都被设定为768,并假设教师网络的低层也包含重要的信息,应该传递给学生网络。因此,选择了“跳过”策略来插入自适应对比性蒸馏损失,这可以带来更强的监督效果。
在下游任务上对预训练的BERT-Base模型进行微调,作为相应的教师网络。最大序列长度被设置为128,并采用AdamW优化器。初始学习率和批大小分别设置为和8。对于不同的下游任务,训练轮次从2到4不等。然后,通过使用自适应对比知识蒸馏框架训练学生网络。在公式(3)中,生成/>的预测器是由一个两层的神经网络实现。动态特征存储器的容量大小被设定为1000,对超参数进行搜索。对于搜索范围,学生网络学习率为,批大小为/>,/>为/>,/>为/>},为/>。其他超参数与训练教师网络时的参数相同。
实验的对比方法采用多种最先进的方法,包括PKD、RCODistilBERT、TinyBERT、CRD、 SFTN、MetaDistill、Annealing KD、ALP-KD和CoDIR,实验结果如表1所示。
表1
从表1中可以看出:(1)当使用BERT3和BERT6作为学生网络时,本发明实施例提供的自适应对比知识蒸馏框架在大多数情况下都优于其他基线方法,这表明了该自适应对比知识蒸馏框架的有效性。特别的,当使用BERT3作为学生网络时,该自适应对比知识蒸馏框架在CoLA上超过了其他基线方法2.9%以上。(2)当使用BERT3作为学生网络时,该自适应对比知识蒸馏框架可以实现更高的性能增益。一个可能的解释是,经过蒸馏的BERT6的性能接近于教师网络BERT-Base,导致其难以更进一步提高性能。另外,BERT3的知识比BERT6少。因此,自适应对比性蒸馏损失作为新知识可以为BERT3带来更多的信息增益,从而带来更多的性能改进。
第二,消融实验测试。
下面通过几个消融实验,对本发明实施例提供的网络压缩方法的有效性进行进一步的验证。在此,使用BERT-Base作为教师网络,使用BERT3作为学生网络,在QNLI上进行实验。
首先是验证公式(4)中的效果。为了验证自适应对比性蒸馏损失的有效性,在本部分实验中去掉了公式(4)中的/>并进行了蒸馏。结果如图3所示,其中,去掉/>的方法表示为“w/o />”。从图3可以看到,本发明实施例提供的网络压缩方法比 “w/o” 方法要好很多,证明了自适应对比性蒸馏损失对于显式监督的有效性,可以将负样本对的学生网络特征推得很远。
其次是验证样本自适应重加权策略的有效性。为了研究样本适应性重加权策略的有效性,在本部分实验中去掉公式(3)中的,然后进行蒸馏。即,在蒸馏中使用对比性蒸馏损失而不是自适应对比性蒸馏损失。结果如图3所示,其中,去掉/>的方法表示为“w/oSAR”。从图3可以看到,本发明实施例提供的网络压缩方法比 “w/o SAR” 方法表现得更好,这证明了样本自适应重加权策略的有效性,即更多地关注鉴别力较差的样本。
再次是验证动态特征存储器的有效性。为了验证自适应对比知识蒸馏框架中使用动态特征存储器的有效性,在本部分实验中进行了去除动态特征存储器的实验。实验结果如图3所示,其中,去掉动态特征存储器的方法表示为“w/o DFS”。从图3中可以看出,本发明实施例提供的网络压缩方法比“w/o DFS”方法表现得更好,证明了使用动态特征存储器的有效性。
最后是验证公式(3)中和/>的有效性。在本部分实验中展示了去除公式(4)中/>和/>时的结果,实验结果如图3所示,其中,去除/>和/>的方法分别表示为 “w/o/>”和“w/o />”。从图3中可以看出,第一,自适应对比知识蒸馏框架的性能优于“w/o”和“w/o />”的方法。这表明使用/>和/>是有益的。第二,“w/o />”的准确性高于“w/o/>”,这表明在自适应对比知识蒸馏框架中压缩BERT时,损失/>比/>更有效果。
第三,算法分析。
在本部分中,同样以BERT-Base为教师网络,以BERT3为学生网络,在QNLI上进行算法分析的实验。其中,使用不同教师网络对学生网络的训练效果如表2所示,从表2中可以看出,在使用不同的教师网络结构时,仍然可以有效地训练学生网络。
表2
教师网络 | BERT12 | BERT10 | BERT8 | BERT6 |
学生网络(BERT3) | 86.2 | 86.1 | 85.8 | 85.5 |
与现有技术相比较,首先,本发明实施例将对比学习的概念引入到知识蒸馏中,提出一个对比性蒸馏损失作为显式监督,以最大化特征负样本对的距离。特别地,对于每个样本,对比性蒸馏损失旨在最大化学生网络中样本的特征与教师网络中样本的特征之间的相似性,并最小化学生网络中样本的特征与教师网络中样本的负样本对特征之间的相似性。如图4所示,对比性蒸馏损失可以有效地将负样本对的特征推得很远。
其次,本发明实施例针对较少鉴别性特征的困难样本,提出样本适应性重加权策略,以适应性地对这些困难样本给予更多的关注,并加强其识别能力。具体地说,利用一个神经网络作为预测器,根据每个样本的学习特征来预测其辨别能力。然后,根据预测的鉴别能力对不同样本的损失进行重新加权。由于这个过程中的所有操作都是可微分的,预测器的参数可以与学生网络共同学习。
最后,本发明实施例将样本适应性重加权策略无缝地融入前述对比性蒸馏损失中,形成自适应对比性蒸馏损失。将自适应对比性蒸馏损失与现有技术中的患者知识蒸馏方法相结合,构建了自适应对比知识蒸馏框架,可以用于BERT压缩。通过在多个自然语言处理任务上的广泛实验,证明了该自适应对比知识蒸馏框架对BERT压缩的有效性。
在上述基于自适应对比知识蒸馏的网络压缩方法的基础上,本发明实施例进一步提供一种基于自适应对比知识蒸馏的网络压缩装置。如图5所示,该网络压缩装置包括一个或多个处理器和动态特征存储器(简写为存储器)。其中,动态特征存储器与处理器耦接,用于存储一个或多个计算机程序,当一个或多个计算机程序被一个或多个处理器执行,使得一个或多个处理器实现如上述实施例中基于自适应对比知识蒸馏的网络压缩方法。
其中,处理器用于控制该基于自适应对比知识蒸馏的网络压缩装置的整体操作,以完成上述基于自适应对比知识蒸馏的网络压缩方法的全部或部分步骤。该处理器可以是中央处理器(CPU)、图形处理器(GPU)、现场可编程逻辑门阵列(FPGA)、专用集成电路(ASIC)、数字信号处理(DSP)芯片等。动态特征存储器用于存储各种类型的数据以支持在该网络压缩装置的操作,这些数据例如可以包括用于网络压缩装置操作的任何应用程序或方法的指令,以及应用程序相关的数据。该动态特征存储器可以由任何类型的易失性或非易失性存储设备或者它们的组合实现,例如静态随机存取动态特征存储器(SRAM)、电可擦除可编程只读动态特征存储器(EEPROM)、可擦除可编程只读动态特征存储器(EPROM)、可编程只读动态特征存储器(PROM)、只读动态特征存储器(ROM)、磁动态特征存储器、快闪动态特征存储器等。
在一个示例性实施例中,基于自适应对比知识蒸馏的网络压缩装置具体可以由计算机芯片或实体实现,或者由具有某种功能的产品来实现,用于执行上述基于自适应对比知识蒸馏的网络压缩方法,并达到如上述方法一致的技术效果。一种典型的实施例为计算机。具体地说,计算机例如可以为个人计算机、膝上型计算机、车载人机交互设备、蜂窝电话、相机电话、智能电话、个人数字助理、媒体播放器、导航设备、电子邮件设备、游戏控制台、平板计算机、可穿戴设备或者这些设备中的任何设备的组合。
在另一个示例性实施例中,本发明还提供一种包括程序指令的计算机可读存储介质,该程序指令被处理器执行时实现上述任意一个实施例中的基于自适应对比知识蒸馏的网络压缩方法的步骤。例如,该计算机可读存储介质可以为包括程序指令的存储器,上述程序指令可以由基于自适应对比知识蒸馏的网络压缩装置的处理器执行,以完成上述基于自适应对比知识蒸馏的网络压缩方法,并达到如上述方法一致的技术效果。
上面对本发明提供的基于自适应对比知识蒸馏的网络压缩方法和装置进行了详细的说明。对本领域的一般技术人员而言,在不背离本发明实质内容的前提下对它所做的任何显而易见的改动,都将构成对本发明专利权的侵犯,将承担相应的法律责任。
Claims (8)
1.一种基于自适应对比知识蒸馏的网络压缩方法,用在自然语言处理任务中,其特征在于包括如下步骤:
(1)引入对比性蒸馏损失作为显式监督,以最大化特征负样本对的距离;该对比性蒸馏损失的表达式如下:
(2)利用一个神经网络作为预测器,根据每个样本的学习特征预测样本的辨别能力;然后,根据预测的辨别能力对不同样本的损失进行重新加权,以实现样本适应性重加权策略;
(3)将所述样本适应性重加权策略融入到所述对比性蒸馏损失中,构建自适应对比性蒸馏损失;该自适应对比性蒸馏损失的表达式如下:
其中,是第/>个样本预测的辨别能力;/>是预测器的函数;/>是预测器的学习参数;/>是Sigmoid函数,用于确保预测的辨别能力为正值;/>表示包含来自不同类别的标签的第/>个样本的隐藏状态特征的集合;/>表示学生网络在配对层的第/>个样本的隐藏状态特征;/>表示教师网络在配对层的第/>个样本的隐藏状态特征;/>表示样本数量;为正整数;
(4)基于所述自适应对比性蒸馏损失构建自适应对比知识蒸馏框架,用于实现自然语言处理过程中的神经网络压缩。
2.如权利要求1所述的网络压缩方法,其特征在于:
所述步骤(1)中,对于每个样本,利用所述对比性蒸馏损失最大化学生网络中样本的特征与教师网络中样本的特征之间的相似性,并最小化学生网络中样本的特征与教师网络中样本的负样本对特征之间的相似性。
3.如权利要求1所述的网络压缩方法,其特征在于:
所述步骤(3)中,在自适应对比性蒸馏损失的训练过程中,增加分子项,以使教师网络和学生网络中来自同一样本的特征相互接近,同时减少分母项,以使学生网络中来自不同类别的样本的特征远离教师网络中的样本的特征。
4.如权利要求3所述的网络压缩方法,其特征在于:
所述步骤(3)中,给具有较少鉴别特征的样本分配更高的权重,以形成自适应对比性蒸馏损失。
5.如权利要求2或3所述的网络压缩方法,其特征在于:
所述教师网络为BERT-Base模型,所述学生网络为BERT模型。
6.如权利要求1所述的网络压缩方法,其特征在于:
所述步骤(4)中,在每个批次的计算结束后,将所述批次的特征保存到动态特征存储器中,同时将样本的标签也保存到所述动态特征存储器中,用于识别样本。
7.如权利要求6所述的网络压缩方法,其特征在于:
所述动态特征存储器中只存储插入自适应对比性蒸馏损失的所在层的特征;在存储空间满了之后,根据先入先出的策略更新存储空间。
8.一种基于自适应对比知识蒸馏的网络压缩装置,用在自然语言处理任务中,其特征在于包括处理器和动态特征存储器,所述处理器和所述动态特征存储器耦接;其中,
所述动态特征存储器用于存储计算机程序;
所述处理器用于运行存储在所述动态特征存储器中的计算机程序,执行如权利要求1~7中任意一项所述的基于自适应对比知识蒸馏的网络压缩方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310758129.8A CN116502621B (zh) | 2023-06-26 | 2023-06-26 | 一种基于自适应对比知识蒸馏的网络压缩方法和装置 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310758129.8A CN116502621B (zh) | 2023-06-26 | 2023-06-26 | 一种基于自适应对比知识蒸馏的网络压缩方法和装置 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN116502621A CN116502621A (zh) | 2023-07-28 |
CN116502621B true CN116502621B (zh) | 2023-10-17 |
Family
ID=87316936
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202310758129.8A Active CN116502621B (zh) | 2023-06-26 | 2023-06-26 | 一种基于自适应对比知识蒸馏的网络压缩方法和装置 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN116502621B (zh) |
Citations (7)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
WO2021023202A1 (zh) * | 2019-08-07 | 2021-02-11 | 交叉信息核心技术研究院(西安)有限公司 | 一种卷积神经网络的自蒸馏训练方法、设备和可伸缩动态预测方法 |
CN113939827A (zh) * | 2020-12-25 | 2022-01-14 | 阿里巴巴集团控股有限公司 | 用于图像到视频重识别的系统和方法 |
CN114299591A (zh) * | 2021-12-30 | 2022-04-08 | 厦门理工学院 | 基于自适应对比知识蒸馏的人脸属性识别方法及系统 |
CN114972839A (zh) * | 2022-03-30 | 2022-08-30 | 天津大学 | 一种基于在线对比蒸馏网络的广义持续分类方法 |
CN115294407A (zh) * | 2022-09-30 | 2022-11-04 | 山东大学 | 基于预习机制知识蒸馏的模型压缩方法及系统 |
CN115526332A (zh) * | 2022-08-17 | 2022-12-27 | 阿里巴巴(中国)有限公司 | 基于预训练语言模型的学生模型训练方法和文本分类系统 |
CN115995018A (zh) * | 2022-12-09 | 2023-04-21 | 厦门大学 | 基于样本感知蒸馏的长尾分布视觉分类方法 |
-
2023
- 2023-06-26 CN CN202310758129.8A patent/CN116502621B/zh active Active
Patent Citations (7)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
WO2021023202A1 (zh) * | 2019-08-07 | 2021-02-11 | 交叉信息核心技术研究院(西安)有限公司 | 一种卷积神经网络的自蒸馏训练方法、设备和可伸缩动态预测方法 |
CN113939827A (zh) * | 2020-12-25 | 2022-01-14 | 阿里巴巴集团控股有限公司 | 用于图像到视频重识别的系统和方法 |
CN114299591A (zh) * | 2021-12-30 | 2022-04-08 | 厦门理工学院 | 基于自适应对比知识蒸馏的人脸属性识别方法及系统 |
CN114972839A (zh) * | 2022-03-30 | 2022-08-30 | 天津大学 | 一种基于在线对比蒸馏网络的广义持续分类方法 |
CN115526332A (zh) * | 2022-08-17 | 2022-12-27 | 阿里巴巴(中国)有限公司 | 基于预训练语言模型的学生模型训练方法和文本分类系统 |
CN115294407A (zh) * | 2022-09-30 | 2022-11-04 | 山东大学 | 基于预习机制知识蒸馏的模型压缩方法及系统 |
CN115995018A (zh) * | 2022-12-09 | 2023-04-21 | 厦门大学 | 基于样本感知蒸馏的长尾分布视觉分类方法 |
Non-Patent Citations (1)
Title |
---|
基于预训练模型与知识蒸馏的法律判决预测算法;潘瑞东;《控制与决策》;第37卷(第1期);67-76 * |
Also Published As
Publication number | Publication date |
---|---|
CN116502621A (zh) | 2023-07-28 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
Mehta et al. | An empirical investigation of the role of pre-training in lifelong learning | |
Yang et al. | Breaking the softmax bottleneck: A high-rank RNN language model | |
Duchi et al. | Distributionally robust losses for latent covariate mixtures | |
CN111356997B (zh) | 具有颗粒化注意力的层次神经网络 | |
Duchi et al. | Distributionally robust losses against mixture covariate shifts | |
Lee et al. | Mutual information-based multi-label feature selection using interaction information | |
US9858534B2 (en) | Weight generation in machine learning | |
US11537930B2 (en) | Information processing device, information processing method, and program | |
CN116909532B (zh) | 一种代码生成与缺陷修复方法和装置 | |
CN116134454A (zh) | 用于使用知识蒸馏训练神经网络模型的方法和系统 | |
CN113837370A (zh) | 用于训练基于对比学习的模型的方法和装置 | |
Sharma et al. | The truth is in there: Improving reasoning in language models with layer-selective rank reduction | |
Tsai et al. | Formalizing generalization and adversarial robustness of neural networks to weight perturbations | |
Bennet et al. | A Hybrid Approach for Gene Selection and Classification Using Support Vector Machine. | |
CN111209402A (zh) | 一种融合迁移学习与主题模型的文本分类方法及系统 | |
Pan et al. | Automatic noisy label correction for fine-grained entity typing | |
Tsai et al. | Formalizing generalization and robustness of neural networks to weight perturbations | |
CN111611796A (zh) | 下位词的上位词确定方法、装置、电子设备及存储介质 | |
CN111259147A (zh) | 基于自适应注意力机制的句子级情感预测方法及系统 | |
Lin et al. | Robust educational dialogue act classifiers with low-resource and imbalanced datasets | |
CN116502621B (zh) | 一种基于自适应对比知识蒸馏的网络压缩方法和装置 | |
Ding et al. | Degradation analysis with nonlinear exponential‐dispersion process: Bayesian offline and online perspectives | |
Galitsky | Customers’ retention requires an explainability feature in machine learning systems they use | |
Yoshikawa et al. | Non-linear regression for bag-of-words data via Gaussian process latent variable set model | |
CN110851600A (zh) | 基于深度学习的文本数据处理方法及装置 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |