CN113849641A - 一种跨领域层次关系的知识蒸馏方法和系统 - Google Patents

一种跨领域层次关系的知识蒸馏方法和系统 Download PDF

Info

Publication number
CN113849641A
CN113849641A CN202111131585.7A CN202111131585A CN113849641A CN 113849641 A CN113849641 A CN 113849641A CN 202111131585 A CN202111131585 A CN 202111131585A CN 113849641 A CN113849641 A CN 113849641A
Authority
CN
China
Prior art keywords
layer
domain
student
field
prototype
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Granted
Application number
CN202111131585.7A
Other languages
English (en)
Other versions
CN113849641B (zh
Inventor
董晨鹤
沈颖
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Sun Yat Sen University
Original Assignee
Sun Yat Sen University
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Sun Yat Sen University filed Critical Sun Yat Sen University
Priority to CN202111131585.7A priority Critical patent/CN113849641B/zh
Publication of CN113849641A publication Critical patent/CN113849641A/zh
Application granted granted Critical
Publication of CN113849641B publication Critical patent/CN113849641B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F16/00Information retrieval; Database structures therefor; File system structures therefor
    • G06F16/30Information retrieval; Database structures therefor; File system structures therefor of unstructured textual data
    • G06F16/35Clustering; Classification
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • G06F18/241Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
    • YGENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
    • Y02TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
    • Y02DCLIMATE CHANGE MITIGATION TECHNOLOGIES IN INFORMATION AND COMMUNICATION TECHNOLOGIES [ICT], I.E. INFORMATION AND COMMUNICATION TECHNOLOGIES AIMING AT THE REDUCTION OF THEIR OWN ENERGY USE
    • Y02D10/00Energy efficient computing, e.g. low power processors, power management or thermal management

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Data Mining & Analysis (AREA)
  • Physics & Mathematics (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • Artificial Intelligence (AREA)
  • Evolutionary Biology (AREA)
  • Evolutionary Computation (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Databases & Information Systems (AREA)
  • Management, Administration, Business Operations System, And Electronic Commerce (AREA)

Abstract

本发明公开了一种跨领域层次关系的知识蒸馏方法和系统,为各个领域构建了一系列参考原型特征,建立多个领域关系图网络来充分地学习不同领域间的关系,每个图节点代表一个领域的原型特征,每个边的权重代表相连的两个原型特征的相似度,这样,不同领域的关系便可以同时被捕捉生成一系列领域关系系数对各个领域在知识蒸馏过程中的权重进行重分配,引导模型动态地关注更重要的领域信息,可以更方便高效地处理多领域环境下模型压缩,大幅度地提升模型的性能,解决了现有的跨领域知识蒸馏方法在捕捉不同领域之间的关系信息方面能力较差,泛化性能较低,难以提高压缩语言模型的表达能力的技术问题。

Description

一种跨领域层次关系的知识蒸馏方法和系统
技术领域
本发明涉及知识蒸馏技术领域,尤其涉及一种跨领域层次关系的知识蒸馏方法和系统。
背景技术
语言是人类重要的沟通和表达方式。随着互联网的发展,当今时代的信息量不断增加,信息的增长速度已经远远超越了人类的理解速度。如果计算机能够处理、理解语言,不仅为海量信息的处理提供了可能,也有助于深化对语言能力和人类智能的认识。
自然语言处理,旨在将人的语言形式转化为机器可理解的、结构化的、完整的语义表示,目的是让计算机能够理解和生成人类语言。大型预训练语言模型在大量的自然语言处理任务上取得了显著的效果,例如机器翻译、文本摘要、对话生成等任务。然而,过大的模型尺寸和过低的推理时间阻碍了其实际应用的脚步,很难在资源受限的设备上进行部署。因而,涌现了许多针对预训练语言模型的压缩技术,例如量化、权重裁剪、知识蒸馏等技术。由于知识蒸馏即插即用的特性,其在实际中得到了广泛的应用。
知识蒸馏的目的在于将知识从更大尺寸的教师模型迁移到更小尺寸的学生模型中。传统的知识蒸馏局限于单领域知识蒸馏,然而,对于人类而言,常常迁移不同领域的相关知识,例如学过钢琴的人学小提琴比别人学得快,会骑自行车的人更容易学会骑摩托车。来自不同领域的文本数据在文本、句式术语上有显著的差异,但是自然语言又具备跨领域的共性知识,如词汇、句法等,这为跨领域的知识迁移提供了可能。因而,现有的知识蒸馏技术已经从传统的单领域知识蒸馏扩展到了跨领域知识蒸馏。然而,现有的跨领域知识蒸馏方法在捕捉不同领域之间的关系信息方面能力较差,泛化性能较低,难以提高压缩语言模型的表达能力。
发明内容
本发明实施例提供了一种跨领域层次关系的知识蒸馏方法和系统,用于解决现有的跨领域知识蒸馏方法在捕捉不同领域之间的关系信息方面能力较差,泛化性能较低,难以提高压缩语言模型的表达能力的技术问题。
有鉴于此,本发明第一方面提供了一种跨领域层次关系的知识蒸馏方法方法,所述方法包括:
获取不同领域的训练样本;
对各领域的训练样本分别计算学生层的原型特征;
对学生模型中的除了预测层外的每个学生层建立一个基于图注意力网络的两层领域关系图网络;
将每个领域的训练样本的原型特征输入领域关系图网络,得到每个学生层的领域关系系数;
将每个学生层的领域关系系数作为教师模型和学生模型的对应层的权重系数,确定蒸馏损失函数;
根据蒸馏损失函数对学生模型进行迭代训练。
可选地,在第一层领域关系图网络中,每个节点上应用一个共享参数矩阵和注意力机制,并将节点的输出送入ELU非线性函数和多头拼接机制,在第二层领域关系图网络中,去除多头拼接机制,使用softmax对输出归一化得到领域关系系数。
可选地,学生层的原型特征计算公式为:
Figure BDA0003280647000000021
其中,hm,d为d领域的第m个学生层的原型特征,
Figure BDA0003280647000000026
为第d个领域的训练集,L为句子长度,
Figure BDA0003280647000000022
Figure BDA0003280647000000023
中第i个采样学生嵌入的第l个单词,
Figure BDA0003280647000000024
Figure BDA0003280647000000025
中第i个采样的第m个学生层的前馈网络层输出,M为学生层总数。
可选地,还包括:
基于自注意力机制建立每个领域的参考原型特征;
将每个学生层的原型特征和参考原型特征进行对比聚合处理,得到每个学生层每个领域的聚合原型特征;
将每个学生层每个领域的聚合原型特征输入领域关系图网络,得到每个学生层的领域关系系数并更新。
可选地,参考原型特征为:
Figure BDA0003280647000000031
Figure BDA0003280647000000032
其中,
Figure BDA0003280647000000033
为第m层的注意力矩阵,
Figure BDA0003280647000000034
为第m层所有领域的原型特征,
Figure BDA0003280647000000035
为第m层的一个可学习参考矩阵。
可选地,聚合原型特征为:
Figure BDA0003280647000000036
Figure BDA0003280647000000037
其中,
Figure BDA0003280647000000038
为第m层、第d个领域的相似系数,
Figure BDA0003280647000000039
为第m层和之前层在第d个领域的原型特征,
Figure BDA00032806470000000310
为第m层第d个领域的一个可学习参考矩阵,
Figure BDA00032806470000000311
为第m层、第d个领域的参考原型特征。
可选地,蒸馏损失函数为:
Figure BDA00032806470000000312
其中,rm,d为第m层、第d个领域的领域关系系数,
Figure BDA00032806470000000313
为d个领域的嵌入层损失,
Figure BDA00032806470000000314
为预测层损失,
Figure BDA00032806470000000315
为第d个领域中的第m个学生层的注意力层损失,
Figure BDA00032806470000000316
为第d个领域中的第m个学生层的前馈网络层损失,D为总领域数,γ为用来控制预测损失
Figure BDA00032806470000000317
的权重。
本发明第二方面提供一种跨领域层次关系的知识蒸馏系统,所述系统包括:
训练样本获取模块,用于获取不同领域的训练样本;
原型特征生成模块,用于对各领域的训练样本分别计算学生层的原型特征;
领域关系网络构建模块,用于对学生模型中的除了预测层外的每个学生层建立一个基于图注意力网络的两层领域关系图网络;
领域关系系数获取模块,用于将每个领域的训练样本的原型特征输入领域关系图网络,得到每个学生层的领域关系系数;
蒸馏损失函数生成模块,用于将每个学生层的领域关系系数作为教师模型和学生模型的对应层的权重系数,确定蒸馏损失函数;
模型训练模块,用于根据蒸馏损失函数对学生模型进行迭代训练。
可选地,领域关系网络构建模块建立的两层领域关系图网络的网络结构包括:
在第一层领域关系图网络中,每个节点上应用一个共享参数矩阵和注意力机制,并将节点的输出送入ELU非线性函数和多头拼接机制,在第二层领域关系图网络中,去除多头拼接机制,使用softmax对输出归一化得到领域关系系数。
可选地,还包括:
参考原型特征生成模块,用于基于自注意力机制建立每个领域的参考原型特征;
对比聚合模块,用于将每个学生层的原型特征和参考原型特征进行对比聚合处理,得到每个学生层每个领域的聚合原型特征;
领域关系系数更新模块,用于将每个学生层每个领域的聚合原型特征输入领域关系图网络,得到每个学生层的领域关系系数并更新。
从以上技术方案可以看出,本发明实施例具有以下优点:
由于不同领域的层原型特征会有不同的偏好,因此本发明实施例中提供的跨领域层次关系的知识蒸馏方法中为各个领域构建了一系列参考原型特征,建立多个领域关系图网络来充分地学习不同领域间的关系,每个图节点代表一个领域的原型特征,每个边的权重代表相连的两个原型特征的相似度,这样,不同领域的关系便可以同时被捕捉生成一系列领域关系系数对各个领域在知识蒸馏过程中的权重进行重分配,引导模型动态地关注更重要的领域信息,可以更方便高效地处理多领域环境下模型压缩,大幅度地提升模型的性能,解决了现有的跨领域知识蒸馏方法在捕捉不同领域之间的关系信息方面能力较差,泛化性能较低,难以提高压缩语言模型的表达能力的技术问题。
同时,本发明中还引入了一个层次化对比-聚合机制挖掘出各个领域更具有代表性的层原型特征,进一步提升压缩语言模型的表达能力。
附图说明
图1为本发明实施例中提供的跨领域层次关系的知识蒸馏方法的一个流程示意图;
图2为本发明实施例中提供的跨领域层次关系的知识蒸馏方法的一个模型结构原理图;
图3为本发明实施例中提供的跨领域层次关系的知识蒸馏方法的令一个流程示意图;
图4为本发明实施例中提供的跨领域层次关系的知识蒸馏方法的另一个模型结构原理图;
图5为本发明实施例中提供的跨领域层次关系的知识蒸馏系统的结构示意图。
具体实施方式
为了使本技术领域的人员更好地理解本发明方案,下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
基础的跨领域知识蒸馏方法中,对教师模型和学生模型的嵌入层、注意力矩阵、前馈网络层、预测概率分布进行联合蒸馏,并且运用多任务学习的训练策略进行跨领域知识蒸馏。具体地,将所有领域的嵌入层和前馈网络层的权重进行共享,但为不同的预测层赋予不同的权重。第d个领域的嵌入层损失
Figure BDA0003280647000000051
和预测层损失
Figure BDA0003280647000000052
可以表示为:
Figure BDA0003280647000000053
Figure BDA0003280647000000054
其中,MSE和CE分别代表均方差损失和交叉熵损失,ES
Figure BDA0003280647000000055
分别代表第d个领域中学生模型和教师模型的嵌入层,
Figure BDA0003280647000000056
Figure BDA0003280647000000057
分别代表第d个领域中学生模型和教师模型的预测概率分布,Wembd是一个可学习转换矩阵,用于将学生嵌入与维度不匹配的教师嵌入进行对齐,t为温度系数。
第d个领域中的第m个学生层的注意力层损失
Figure BDA0003280647000000058
和前馈网络层输出损失
Figure BDA0003280647000000059
可表示为:
Figure BDA00032806470000000510
Figure BDA00032806470000000511
其中,h为注意力头数,
Figure BDA0003280647000000061
Figure BDA0003280647000000062
分别是在第d个领域内的第m个学生层和其匹配的第n个教师层的第i个注意力头矩阵,
Figure BDA0003280647000000063
Figure BDA0003280647000000064
分别是第d个领域内的第m个学生层和第n个老师层的前馈网络层输出,
Figure BDA0003280647000000065
是一个转换矩阵,用于将第m个学生层输出与维度不匹配的第n个老师层输出进行对齐。
采用均匀策略对学生模型和教师模型间的模型层进行匹配。最后,总体的蒸馏损失函数可以表示为:
Figure BDA0003280647000000066
其中,D为总领域数,M是学生模型中的层数,γ为用来控制预测损失
Figure BDA0003280647000000067
的权重。
上述基础跨领域知识蒸馏方法虽然可以将不同领域的学生模型进行蒸馏,但不同领域间的关系信息被忽略了,也就降低了模型的泛化能力。为了提高模型捕捉不同领域之间的关系信息方面能力,增强模型的泛化能力,本发明提供了一种新的跨领域层次关系的知识蒸馏方法。
为了便于理解,请参阅图1和图2,图1为本发明实施例中跨领域层次关系的知识蒸馏方法的一个流程示意图,如图1所示,本发明实施例中跨领域层次关系的知识蒸馏方法包括:
步骤101、获取不同领域的训练样本。
步骤102、对各领域的训练样本分别计算学生层的原型特征。
本发明中,对所有学生模型层都执行知识蒸馏。使用原型特征来反映各个领域数据的特点,对不同的学生层计算不同的原型特征。在实际中,为不同批量的训练样本计算不同的原型特征。在d领域的第m个学生层的原型特征hm,d可以由下式计算得到:
Figure BDA0003280647000000068
其中,hm,d为d领域的第m个学生层的原型特征,
Figure BDA0003280647000000069
为第d个领域的训练集,L为句子长度,
Figure BDA00032806470000000610
Figure BDA00032806470000000611
中第i个采样学生嵌入的第l个单词,
Figure BDA00032806470000000612
Figure BDA00032806470000000613
中第i个采样的第m个学生层的前馈网络层输出,M为学生层总数。
步骤103、对学生模型中的除了预测层外的每个学生层建立一个基于图注意力网络的两层领域关系图网络。
计算的这些领域原型特征被用来挖掘不同领域间的关系,为一次性地同时找到跨领域关系,本发明中用基于图注意力网络的领域关系网络同时处理所有领域的原型特征。
步骤104、将每个领域的训练样本的原型特征输入领域关系图网络,得到每个学生层的领域关系系数。
在图注意力网络中,每个图节点代表一个领域的原型特征,每个边的权重代表相连的两个原型特征的相似度。这样,不同领域的关系便可以同时被捕捉。如图2所示,除了预测层外,为每个学生层建立一个两层领域关系图网络,第m层图网络的输入hm是一系列包含所有第m个学生层的领域原型特征的节点特征,即
Figure BDA0003280647000000071
Figure BDA0003280647000000072
表示大小为D×F的参数矩阵,D是总领域数,F是每个原型特征的通道数。通过领域关系图网络来生成在d领域的第m个学生层的领域关系系数rm,d
在第m个学生层的第一层领域关系网络中,一个共享参数矩阵
Figure BDA0003280647000000073
以及一个自注意力机制首先被应用到每个节点上,其中,F′是中间层通道数。之后用一个包含K个头的多头拼接机制使得训练过程更加稳定。具体地,每个输入原型特征hm,d首先经过参数矩阵Wm的转换,接着两个节点i,j之间的注意力系数αi,j,m通过为二者拼接后的转换特征应用一个参数向量
Figure BDA0003280647000000074
得到,并且在之后送入LeakyReLU非线性函数和softmax函数,αi,j,m的表达式如下所示:
Figure BDA0003280647000000075
其中,
Figure BDA0003280647000000076
表示拼接操作,
Figure BDA0003280647000000077
是所有节点i的一阶邻居(包括节点i本身)。
之后,节点i最终的输出
Figure BDA0003280647000000078
可以通过节点i和其邻居的转换特征加权和得到,并在之后送入ELU非线性函数和一个多头拼接机制,表示为:
Figure BDA0003280647000000079
其中,k表示头序号。
在第m个学生层的第二层领域关系图网络中,为了得到领域关系系数,将第一层图网络中用到的参数Wm和am分别变换为
Figure BDA0003280647000000081
并且不采用多头拼接机制。用softmax操作对输出归一化并最终得到如下的领域关系系数
Figure BDA0003280647000000085
Figure BDA0003280647000000083
Figure BDA0003280647000000084
步骤105、将每个学生层的领域关系系数作为教师模型和学生模型的对应层的权重系数,确定蒸馏损失函数。
根据每个学生层生成的领域关系系数,对每个领域进行权重重分配,因而可以得到各层的损失,从而确定总体损失,得到蒸馏损失函数。
步骤106、根据蒸馏损失函数对学生模型进行迭代训练。
根据蒸馏损失函数损失函数对学生模型进行训练,即,根据蒸馏损失函数更新学生模型的参数。
由于不同领域的层原型特征会有不同的偏好,因此本发明实施例中提供的跨领域层次关系的知识蒸馏方法中为各个领域构建了一系列参考原型特征,建立多个领域关系图网络来充分地学习不同领域间的关系,每个图节点代表一个领域的原型特征,每个边的权重代表相连的两个原型特征的相似度,这样,不同领域的关系便可以同时被捕捉生成一系列领域关系系数对各个领域在知识蒸馏过程中的权重进行重分配,引导模型动态地关注更重要的领域信息,可以更方便高效地处理多领域环境下模型压缩,大幅度地提升模型的性能,解决了现有的跨领域知识蒸馏方法在捕捉不同领域之间的关系信息方面能力较差,泛化性能较低,难以提高压缩语言模型的表达能力的技术问题。
同时,为进一步提升压缩语言模型的表达能力,本发明中还引入了一个层次化对比-聚合机制挖掘出各个领域更具有代表性的层原型特征,为每个领域建立了一系列参考原型特征,并根据与对应参考原型特征的相似度层次化地聚合当前层和其之前层的原型特征,从而得到各个领域更具有代表性的聚合原型特征。
请参阅图3-图4,在一个实施例中,还包括:
步骤107、基于自注意力机制建立每个领域的参考原型特征;
步骤108、将每个学生层的原型特征和参考原型特征进行对比聚合处理,得到每个学生层每个领域的聚合原型特征;
步骤109、将每个学生层每个领域的聚合原型特征输入领域关系图网络,得到每个学生层的领域关系系数并更新。
对于每个学生层,当前层和其之前层原型特征的参考原型特征可以简单地设定为当前层的原始领域原型特征。然而,这种方法没有考虑到其他领域的信息,而该信息对于提高模型在不同领域的泛化性能起着很重要的作用。为此,本发明实施例中为同一层的所有领域原型特征引入了一个自注意力机制以加入不同领域的信息。具体地,第m个学生层发参考原型特征
Figure BDA0003280647000000091
可表示为:
Figure BDA0003280647000000092
Figure BDA0003280647000000093
其中,
Figure BDA0003280647000000094
为第m层的注意力矩阵,
Figure BDA0003280647000000095
为第m层所有领域的原型特征,
Figure BDA0003280647000000096
为第m层的一个可学习参考矩阵,softmax操作在最后一个向量维度执行。
在得到参考原型特征之后,使用一个对比-聚合机制来动态地聚合层原型特征,该过程通过将其与对应的参考原型特征进行对比完成,可以使得模型注意到每个领域中更具代表性的层原型特征。具体地,第m层第d个领域的聚合原型特征
Figure BDA0003280647000000097
可以表示为;
Figure BDA0003280647000000098
Figure BDA0003280647000000099
其中,
Figure BDA00032806470000000910
为第m层、第d个领域的相似系数,
Figure BDA00032806470000000911
为第m层和之前层在第d个领域的原型特征,
Figure BDA00032806470000000912
为第m层第d个领域的一个可学习参考矩阵,
Figure BDA00032806470000000913
为第m层、第d个领域的参考原型特征。
聚合后的原型特征
Figure BDA00032806470000000914
被送入领域关系图网络中以得到领域关系系数
Figure BDA00032806470000000915
最后,得到的总体损失(即最终确定的蒸馏损失函数)可以表示为:
Figure BDA0003280647000000101
其中,rm,d为第m层、第d个领域的领域关系系数,
Figure BDA0003280647000000102
为d个领域的嵌入层损失,
Figure BDA0003280647000000103
为预测层损失,
Figure BDA0003280647000000104
为第d个领域中的第m个学生层的注意力层损失,
Figure BDA0003280647000000105
为第d个领域中的第m个学生层的前馈网络层损失,D为总领域数,γ为用来控制预测损失
Figure BDA0003280647000000106
的权重。
为了便于理解,请参阅图5,本发明中提供了一种跨领域层次关系的知识蒸馏系统的实施例,包括:
训练样本获取模块501,用于获取不同领域的训练样本;
原型特征生成模块502,用于对各领域的训练样本分别计算学生层的原型特征;
领域关系网络构建模块503,用于对学生模型中的除了预测层外的每个学生层建立一个基于图注意力网络的两层领域关系图网络;
领域关系系数获取模块504,用于将每个领域的训练样本的原型特征输入领域关系图网络,得到每个学生层的领域关系系数;
蒸馏损失函数生成模块505,用于将每个学生层的领域关系系数作为教师模型和学生模型的对应层的权重系数,确定蒸馏损失函数;
模型训练模块506,用于根据蒸馏损失函数对学生模型进行迭代训练。
领域关系网络构建模块503建立的两层领域关系图网络的网络结构包括:
在第一层领域关系图网络中,每个节点上应用一个共享参数矩阵和注意力机制,并将节点的输出送入ELU非线性函数和多头拼接机制,在第二层领域关系图网络中,去除多头拼接机制,使用softmax对输出归一化得到领域关系系数。
还包括:
参考原型特征生成模块507,用于基于自注意力机制建立每个领域的参考原型特征;
对比聚合模块508,用于将每个学生层的原型特征和参考原型特征进行对比聚合处理,得到每个学生层每个领域的聚合原型特征;
领域关系系数更新模块509,用于将每个学生层每个领域的聚合原型特征输入领域关系图网络,得到每个学生层的领域关系系数并更新。
由于不同领域的层原型特征会有不同的偏好,因此本发明实施例中提供的跨领域层次关系的知识蒸馏系统中为各个领域构建了一系列参考原型特征,建立多个领域关系图网络来充分地学习不同领域间的关系,每个图节点代表一个领域的原型特征,每个边的权重代表相连的两个原型特征的相似度,这样,不同领域的关系便可以同时被捕捉生成一系列领域关系系数对各个领域在知识蒸馏过程中的权重进行重分配,引导模型动态地关注更重要的领域信息,可以更方便高效地处理多领域环境下模型压缩,大幅度地提升模型的性能,解决了现有的跨领域知识蒸馏方法在捕捉不同领域之间的关系信息方面能力较差,泛化性能较低,难以提高压缩语言模型的表达能力的技术问题。
同时,本发明中还引入了一个层次化对比-聚合机制挖掘出各个领域更具有代表性的层原型特征,进一步提升压缩语言模型的表达能力。
本发明实施例中提供的跨领域层次关系的知识蒸馏系统用于执行前述跨领域层次关系的知识蒸馏方法实施例中的跨领域层次关系的知识蒸馏方法,可取得与前述跨领域层次关系的知识蒸馏方法实施例相同的技术效果,在此不再进行赘述。
以上所述,以上实施例仅用以说明本发明的技术方案,而非对其限制;尽管参照前述实施例对本发明进行了详细的说明,本领域的普通技术人员应当理解:其依然可以对前述各实施例所记载的技术方案进行修改,或者对其中部分技术特征进行等同替换;而这些修改或者替换,并不使相应技术方案的本质脱离本发明各实施例技术方案的精神和范围。

Claims (10)

1.一种跨领域层次关系的知识蒸馏方法,其特征在于,包括:
获取不同领域的训练样本;
对各领域的训练样本分别计算学生层的原型特征;
对学生模型中的除了预测层外的每个学生层建立一个基于图注意力网络的两层领域关系图网络;
将每个领域的训练样本的原型特征输入领域关系图网络,得到每个学生层的领域关系系数;
将每个学生层的领域关系系数作为教师模型和学生模型的对应层的权重系数,确定蒸馏损失函数;
根据蒸馏损失函数对学生模型进行迭代训练。
2.根据权利要求1所述的跨领域层次关系的知识蒸馏方法,其特征在于,在第一层领域关系图网络中,每个节点上应用一个共享参数矩阵和注意力机制,并将节点的输出送入ELU非线性函数和多头拼接机制,在第二层领域关系图网络中,去除多头拼接机制,使用softmax对输出归一化得到领域关系系数。
3.根据权利要求1所述的跨领域层次关系的知识蒸馏方法,其特征在于,学生层的原型特征计算公式为:
Figure FDA0003280646990000011
其中,hm,d为d领域的第m个学生层的原型特征,
Figure FDA0003280646990000012
为第d个领域的训练集,L为句子长度,
Figure FDA0003280646990000013
Figure FDA0003280646990000016
中第i个采样学生嵌入的第l个单词,
Figure FDA0003280646990000014
Figure FDA0003280646990000015
中第i个采样的第m个学生层的前馈网络层输出,M为学生层总数。
4.根据权利要求1所述的跨领域层次关系的知识蒸馏方法,其特征在于,还包括:
基于自注意力机制建立每个领域的参考原型特征;
将每个学生层的原型特征和参考原型特征进行对比聚合处理,得到每个学生层每个领域的聚合原型特征;
将每个学生层每个领域的聚合原型特征输入领域关系图网络,得到每个学生层的领域关系系数并更新。
5.根据权利要求4所述的跨领域层次关系的知识蒸馏方法,其特征在于,参考原型特征为:
Figure FDA0003280646990000021
Figure FDA0003280646990000022
其中,
Figure FDA0003280646990000023
为第m层的注意力矩阵,
Figure FDA0003280646990000024
为第m层所有领域的原型特征,
Figure FDA0003280646990000025
为第m层的一个可学习参考矩阵。
6.根据权利要求5所述的跨领域层次关系的知识蒸馏方法,其特征在于,聚合原型特征为:
Figure FDA0003280646990000026
Figure FDA0003280646990000027
其中,
Figure FDA0003280646990000028
为第m层、第d个领域的相似系数,
Figure FDA0003280646990000029
为第m层和之前层在第d个领域的原型特征,
Figure FDA00032806469900000210
为第m层第d个领域的一个可学习参考矩阵,
Figure FDA00032806469900000211
为第m层、第d个领域的参考原型特征。
7.根据权利要求6所述的跨领域层次关系的知识蒸馏方法,其特征在于,蒸馏损失函数为:
Figure FDA00032806469900000212
其中,rm,d为第m层、第d个领域的领域关系系数,
Figure FDA00032806469900000213
为d个领域的嵌入层损失,
Figure FDA00032806469900000214
为预测层损失,
Figure FDA00032806469900000215
为第d个领域中的第m个学生层的注意力层损失,
Figure FDA00032806469900000216
为第d个领域中的第m个学生层的前馈网络层损失,D为总领域数,γ为用来控制预测损失
Figure FDA00032806469900000217
的权重。
8.一种跨领域层次关系的知识蒸馏系统,其特征在于,包括:
训练样本获取模块,用于获取不同领域的训练样本;
原型特征生成模块,用于对各领域的训练样本分别计算学生层的原型特征;
领域关系网络构建模块,用于对学生模型中的除了预测层外的每个学生层建立一个基于图注意力网络的两层领域关系图网络;
领域关系系数获取模块,用于将每个领域的训练样本的原型特征输入领域关系图网络,得到每个学生层的领域关系系数;
蒸馏损失函数生成模块,用于将每个学生层的领域关系系数作为教师模型和学生模型的对应层的权重系数,确定蒸馏损失函数;
模型训练模块,用于根据蒸馏损失函数对学生模型进行迭代训练。
9.根据权利要求8所述的跨领域层次关系的知识蒸馏系统,其特征在于,领域关系网络构建模块建立的两层领域关系图网络的网络结构包括:
在第一层领域关系图网络中,每个节点上应用一个共享参数矩阵和注意力机制,并将节点的输出送入ELU非线性函数和多头拼接机制,在第二层领域关系图网络中,去除多头拼接机制,使用softmax对输出归一化得到领域关系系数。
10.根据权利要求8所述的跨领域层次关系的知识蒸馏系统,其特征在于,还包括:
参考原型特征生成模块,用于基于自注意力机制建立每个领域的参考原型特征;
对比聚合模块,用于将每个学生层的原型特征和参考原型特征进行对比聚合处理,得到每个学生层每个领域的聚合原型特征;
领域关系系数更新模块,用于将每个学生层每个领域的聚合原型特征输入领域关系图网络,得到每个学生层的领域关系系数并更新。
CN202111131585.7A 2021-09-26 2021-09-26 一种跨领域层次关系的知识蒸馏方法和系统 Active CN113849641B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202111131585.7A CN113849641B (zh) 2021-09-26 2021-09-26 一种跨领域层次关系的知识蒸馏方法和系统

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202111131585.7A CN113849641B (zh) 2021-09-26 2021-09-26 一种跨领域层次关系的知识蒸馏方法和系统

Publications (2)

Publication Number Publication Date
CN113849641A true CN113849641A (zh) 2021-12-28
CN113849641B CN113849641B (zh) 2023-10-24

Family

ID=78980247

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202111131585.7A Active CN113849641B (zh) 2021-09-26 2021-09-26 一种跨领域层次关系的知识蒸馏方法和系统

Country Status (1)

Country Link
CN (1) CN113849641B (zh)

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN114610500A (zh) * 2022-03-22 2022-06-10 重庆邮电大学 一种基于模型蒸馏的边缘缓存方法

Citations (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20200302295A1 (en) * 2019-03-22 2020-09-24 Royal Bank Of Canada System and method for knowledge distillation between neural networks
CN111767711A (zh) * 2020-09-02 2020-10-13 之江实验室 基于知识蒸馏的预训练语言模型的压缩方法及平台
CN112241455A (zh) * 2020-12-17 2021-01-19 之江实验室 基于多层级知识蒸馏预训练语言模型自动压缩方法及平台
CN112712099A (zh) * 2020-10-10 2021-04-27 江苏清微智能科技有限公司 一种基于双层知识蒸馏说话人模型压缩系统和方法
CN113281048A (zh) * 2021-06-25 2021-08-20 华中科技大学 一种基于关系型知识蒸馏的滚动轴承故障诊断方法和系统

Patent Citations (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20200302295A1 (en) * 2019-03-22 2020-09-24 Royal Bank Of Canada System and method for knowledge distillation between neural networks
CN111767711A (zh) * 2020-09-02 2020-10-13 之江实验室 基于知识蒸馏的预训练语言模型的压缩方法及平台
CN112712099A (zh) * 2020-10-10 2021-04-27 江苏清微智能科技有限公司 一种基于双层知识蒸馏说话人模型压缩系统和方法
CN112241455A (zh) * 2020-12-17 2021-01-19 之江实验室 基于多层级知识蒸馏预训练语言模型自动压缩方法及平台
CN113281048A (zh) * 2021-06-25 2021-08-20 华中科技大学 一种基于关系型知识蒸馏的滚动轴承故障诊断方法和系统

Non-Patent Citations (1)

* Cited by examiner, † Cited by third party
Title
孙红等: "基于知识蒸馏的短文本分类方法", 软件导刊, vol. 20, no. 6, pages 23 - 27 *

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN114610500A (zh) * 2022-03-22 2022-06-10 重庆邮电大学 一种基于模型蒸馏的边缘缓存方法
CN114610500B (zh) * 2022-03-22 2024-04-30 重庆邮电大学 一种基于模型蒸馏的边缘缓存方法

Also Published As

Publication number Publication date
CN113849641B (zh) 2023-10-24

Similar Documents

Publication Publication Date Title
CN111554268B (zh) 基于语言模型的语言识别方法、文本分类方法和装置
CN109840287B (zh) 一种基于神经网络的跨模态信息检索方法和装置
CN109284506B (zh) 一种基于注意力卷积神经网络的用户评论情感分析系统及方法
CN110534087B (zh) 一种文本韵律层级结构预测方法、装置、设备及存储介质
CN111553479B (zh) 一种模型蒸馏方法、文本检索方法及装置
CN116415654A (zh) 一种数据处理方法及相关设备
CN112115687A (zh) 一种结合知识库中的三元组和实体类型的生成问题方法
CN113204633B (zh) 一种语义匹配蒸馏方法及装置
CN112560456B (zh) 一种基于改进神经网络的生成式摘要生成方法和系统
CN110659411A (zh) 一种基于神经注意力自编码器的个性化推荐方法
CN115408603A (zh) 一种基于多头自注意力机制的在线问答社区专家推荐方法
CN115809464A (zh) 基于知识蒸馏的轻量级源代码漏洞检测方法
CN116821294A (zh) 一种基于隐式知识反刍的问答推理方法和装置
CN117033602A (zh) 一种多模态的用户心智感知问答模型的构建方法
CN117009545A (zh) 一种持续多模态知识图谱的构建方法
CN113849641B (zh) 一种跨领域层次关系的知识蒸馏方法和系统
CN114169408A (zh) 一种基于多模态注意力机制的情感分类方法
CN112132075B (zh) 图文内容处理方法及介质
CN116863920B (zh) 基于双流自监督网络的语音识别方法、装置、设备及介质
CN117271745A (zh) 一种信息处理方法、装置及计算设备、存储介质
CN116543289A (zh) 一种基于编码器-解码器及Bi-LSTM注意力模型的图像描述方法
CN115455162A (zh) 层次胶囊与多视图信息融合的答案句子选择方法与装置
CN114239575B (zh) 语句分析模型的构建方法、语句分析方法、装置、介质和计算设备
US11941508B2 (en) Dialog system with adaptive recurrent hopping and dual context encoding
CN113051353A (zh) 一种基于注意力机制的知识图谱路径可达性预测方法

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant