CN113435208A - 学生模型的训练方法、装置及电子设备 - Google Patents

学生模型的训练方法、装置及电子设备 Download PDF

Info

Publication number
CN113435208A
CN113435208A CN202110662767.0A CN202110662767A CN113435208A CN 113435208 A CN113435208 A CN 113435208A CN 202110662767 A CN202110662767 A CN 202110662767A CN 113435208 A CN113435208 A CN 113435208A
Authority
CN
China
Prior art keywords
model
student
layer
loss function
error
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Granted
Application number
CN202110662767.0A
Other languages
English (en)
Other versions
CN113435208B (zh
Inventor
念天磊
刘丽
阳锋
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Beijing Baidu Netcom Science and Technology Co Ltd
Original Assignee
Beijing Baidu Netcom Science and Technology Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Beijing Baidu Netcom Science and Technology Co Ltd filed Critical Beijing Baidu Netcom Science and Technology Co Ltd
Priority to CN202110662767.0A priority Critical patent/CN113435208B/zh
Publication of CN113435208A publication Critical patent/CN113435208A/zh
Application granted granted Critical
Publication of CN113435208B publication Critical patent/CN113435208B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F40/00Handling natural language data
    • G06F40/30Semantic analysis
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F40/00Handling natural language data
    • G06F40/20Natural language analysis
    • G06F40/205Parsing
    • G06F40/216Parsing using statistical methods
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N7/00Computing arrangements based on specific mathematical models
    • G06N7/01Probabilistic graphical models, e.g. probabilistic networks
    • YGENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
    • Y02TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
    • Y02TCLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
    • Y02T10/00Road transport of goods or passengers
    • Y02T10/10Internal combustion engine [ICE] based vehicles
    • Y02T10/40Engine management systems

Landscapes

  • Engineering & Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Theoretical Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Artificial Intelligence (AREA)
  • General Engineering & Computer Science (AREA)
  • Probability & Statistics with Applications (AREA)
  • General Health & Medical Sciences (AREA)
  • Computational Linguistics (AREA)
  • Audiology, Speech & Language Pathology (AREA)
  • Health & Medical Sciences (AREA)
  • Mathematical Analysis (AREA)
  • Computational Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • Evolutionary Computation (AREA)
  • Algebra (AREA)
  • Mathematical Optimization (AREA)
  • Pure & Applied Mathematics (AREA)
  • Computing Systems (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Image Analysis (AREA)
  • Machine Translation (AREA)

Abstract

本申请提出了一种学生模型的训练方法及装置,涉及人工智能领域,尤其涉及自然语言处理和深度学习技术等领域,可应用于文本生成、机器翻译、模型压缩等场景下,包括将训练样本分别输入学生模型和教师模型中进行训练;获取学生模型和教师模型在嵌入层上的第一误差、在中间层上的第二误差以及在输出层上的损失函数;根据第一误差、第二误差、损失函数,确定学生模型的总损失函数,并基于总损失函数对学生模型的模型参数进行调整,并继续使用下一个训练样本对调整后的学生模型训练,直至训练结束,生成目标学生模型。本申请中,学生模型可以学习到教师模型的中间层的信息,使得学生模型的训练速度加快,优化了模型的训练效果,提高了模型的性能。

Description

学生模型的训练方法、装置及电子设备
技术领域
本申请涉及人工智能领域,尤其涉及自然语言处理和深度学习技术等领域,可应用于文本生成、机器翻译以及模型压缩场景下。
背景技术
目前,语义通顺度模型可以应用于多个领域,针对机器或其他方式生成的文案,进行通顺度的调整,进而过滤出高品质的符合人们阅读习惯的文本文案。但是,现有的语义通顺度模型的结构较为复杂,参数规模、运算量以及硬件的资源配置和消耗均较大,虽然可以实现高品质的文本文案的过滤输出,但是运算耗时较长,无法实现大规模的应用。
相关技术中,提出了一种知识蒸馏的方法,将复杂的语义通顺度模型作为教师模型,构建学生模型对其进行学习。但是,目前构建的学生模型仅可以实现对于教师模型最后一层的对齐学习,忽略了中间层,使得学生模型的学习成本较高,同时无法保证最终学生模型的输出结果的精度。
因此,如何实现对于学生模型对于教师模型的中间层的学习,进而提高模型的精度,是目前需要解决的问题。
发明内容
本申请提出了一种学生模型的训练方法、装置、电子设备、存储介质及计算机程序产品。
根据本申请的第一方面,提出了一种学生模型的训练方法,包括:将训练样本分别输入学生模型和教师模型中进行训练;获取所述学生模型和所述教师模型在嵌入层上的第一误差;获取所述学生模型和所述教师模型在中间层上的第二误差;获取所述学生模型与所述教师模型在输出层上的损失函数;根据所述第一误差、所述第二误差和所述损失函数,确定所述学生模型的总损失函数,并基于所述总损失函数对所述学生模型的模型参数进行调整,并继续使用下一个训练样本对调整模型参数的所述学生模型训练,直至训练结束,生成目标学生模型。
根据本申请的第二方面,提出了一种学生模型的训练装置,包括:输入模块,用于将训练样本分别输入学生模型和教师模型中进行训练;获取模块,用于获取所述学生模型和所述教师模型在嵌入层上的第一误差,以及获取所述学生模型和所述教师模型在中间层上的第二误差,以及获取所述学生模型与所述教师模型在输出层上的损失函数;训练模块,用于根据所述第一误差、所述第二误差和所述损失函数,确定所述学生模型的总损失函数,并基于所述总损失函数对所述学生模型的模型参数进行调整,并继续使用下一个训练样本对调整模型参数的所述学生模型训练,直至训练结束,生成目标学生模型。
根据本申请的第三方面,提出了一种电子设备,包括:包括处理器和存储器;其中,所述处理器通过读取所述存储器中存储的可执行程序代码来运行与所述可执行程序代码对应的程序,以用于实现如上述第一方面中任一项所述的学生模型的训练方法。
根据本申请的第四方面,提出了一种计算机可读存储介质,其上存储有计算机程序,该程序被处理器执行时实现如上述第一方面中任一项所述的学生模型的训练方法。
根据本申请的第五方面,提出了一种计算机程序产品,当所述计算机程序产品中的指令处理器执行时实现如上述第一方面中任一项所述的学生模型的训练方法。
应当理解,本部分所描述的内容并非旨在标识本申请的实施例的关键或重要特征,也不用于限制本申请的范围。本申请的其它特征将通过以下的说明书而变得容易理解。
附图说明
附图用于更好地理解本方案,不构成对本申请的限定。其中:
图1为本申请一实施例的学生模型的结构示意图;
图2为本申请一实施例的学生模型的训练方法的流程示意图;
图3为本申请另一实施例的学生模型的训练方法的流程示意图;
图4为本申请另一实施例的学生模型的训练方法的流程示意图;
图5为本申请另一实施例的学生模型的训练方法的流程示意图;
图6为本申请另一实施例的学生模型的训练方法的流程示意图;
图7为本申请一实施例的学生模型的训练装置的结构示意图;
图8为本申请另一实施例的学生模型的训练装置的结构示意图;
图9为本申请一实施例的电子设备的示意性框图。
具体实施方式
以下结合附图对本申请的示范性实施例做出说明,其中包括本申请实施例的各种细节以助于理解,应当将它们认为仅仅是示范性的。因此,本领域普通技术人员应当认识到,可以对这里描述的实施例做出各种改变和修改,而不会背离本申请的范围和精神。同样,为了清楚和简明,以下的描述中省略了对公知功能和结构的描述。
深度学习(Deep Learning,简称DL),是机器学习(Machine Learning,简称ML)领域中一个新的研究方向,它被引入机器学习使其更接近于最初的目标——人工智能。深度学习是学习样本数据的内在律和表示层次,这些学习过程中获得的信息对诸如文字,图像和声音等数据的解释有很大的帮助。它的最终目标是让机器能够像人一样具有分析学习能力,能够识别文字、图像和声音等数据。深度学习是一个复杂的机器学习算法,在语音和图像识别方面取得的效果,远远超过先前相关技术。
自然语言处理(Natural Language Processing,NLP)是计算机科学领域与人工智能领域中的一个重要方向。它研究能实现人与计算机之间用自然语言进行有效通信的各种理论和方法。自然语言处理是一门融语言学、计算机科学、数学于一体的科学。自然语言处理主要应用于机器翻译、舆情监测、自动摘要、观点提取、文本分类、问题回答、文本语义对比、语音识别等方面。
人工智能(Artificial Intelligence,简称AI),是研究使计算机来模拟人生的某些思维过程和智能行为(如学习、推理、思考、规划等)的学科,既有硬件层面的技术,也有软件层面的技术。人工智能硬件技术一般包括计算机视觉技术、语音识别技术、自然语言处理技术以及及其学习/深度学习、大数据处理技术、知识图谱技术等几大方面。
机器翻译(machine translation),又称为自动翻译,是利用计算机将一种自然语言(源语言)转换为另一种自然语言(目标语言)的过程。它是计算语言学的一个分支,是人工智能的目标之一。
图1为本申请一实施例的学生模型的结构示意图,如图1所示。
教师模型100包括嵌入层11、多个自注意力层(transformer层)12和输出层13,学生模型20包括嵌入层21、多个transformer层22和输出层23。可选地,从教师模型100的多个transformer层12中,基于设定的层间隔,每间隔N层提取一层的教师模型100的transformer层12,并将其设置为学生模型200的transformer层22。
如图1所示,设定教师模型100中共计12个transformer层12,设定学生模型200每间隔4层的层间隔提取教师模型100的一层transformer层的模型参数,构建学生模型的transformer层。将教师模型100的12个transformer层进行顺序编码,获取编码为{0,1,2,3,4,5,6,7,8,9,10,11}的12个transformer层,每2层的层间隔提取一层transformer层,则学生模型200可以将编码为{2,5,8,11}的四层transformer层提取出来,作为其自身构建模型时的transformer层。
需要说明的是,设定的层间隔也可以间隔3层或者其他可以实现使学生模型对教师模型的中间层进行间隔抽取的间隔层数,设定的教师模型的中间层中的transformer层的层数可以为上述的12层,也可以为其他层数,上述的层间隔数以及教师模型的中间层的transformer层的层数,均在此不做限定。
基于上述方法构建的学生模型200,不仅可以学习到教师模型100最后的概率输出,还可以学习到教师模型100的中间的transformer层12的信息,其拟合能力更强,泛化性更加优秀。
需要说明的是,初始状态的学生模型的中间层的参数,与其抽取的教师模型的对应的中间层的模型参数是相同的。在后续的学生模型训练过程中,会基于学生模型的中间层的输出与教师模型中对应的中间层的输出之间的差异程度,对学生模型的中间层的模型参数进行优化调整。
图2为本申请一实施例的学生模型的训练方法的流程示意图,如图2所示,该方法包括:
S201,将训练样本分别输入学生模型和教师模型中进行训练。
本申请实施例中,训练样本可以为文本训练样本,学生模型和教师模型分别为一个语义通顺度模型。实现中,为了实现语义通顺度模型对于文本的高质量的过滤,往往会为语义通顺度模型设置为复杂结构的模型,相应的,结构复杂的模型往往需要庞大的运算量,为了可以满足模型的运算量,对于结构复杂的模型会配置较多的硬件资源,使得部分场景的平台无法承载。相关技术中提出了知识蒸馏的方法,将结构复杂的模型作为教师模型构建学生模型,通过对学生模型的训练,使得学生模型的性能可以接近教师模型的性能。
本申请实施例中,可以对教师模型的中间层进行间隔抽取,并将其设置为学生模型的中间层。在学生模型的训练过程中,可以将教师模型中的某一层的输出与学生模型中与其对应的层的输出进行对比,并以教师模型的层的输出为标准,基于对比的结果调整学生模型的参数,进而实现学生模型的性能优化。
为了达到上述效果,可以将训练样本分别输入教师模型与学生模型,教师模型和学生模型基于相同的训练样本进行运算和输出,通过对比教师模型和学生模型的对应层的输出结果,获取学生模型和教师模型之间的差异程度。
S202,获取学生模型和教师模型在嵌入层上的第一误差。
为了实现语义通顺度的教师模型和学生模型对于训练样本中的文本文案实现高品质文案的过滤提取,可以将输入模型中的文本进行进一步地向量化的处理,使得输入模型的文本可以被教师模型和学生模型运算处理。
可选地,可以对输入模型中的文本进行向量的编码,通过将输入模型中的文本中的各个词语转换成数字化的向量,获取到文本对应的向量表示。
进一步地,可以在教师模型和学生模型配置嵌入层(embedding),教师模型的嵌入层和学生模型的嵌入层会分别将输入的训练样本的文本中的词语转换成向量,进而获取训练样本的文本对应的向量表示。
可选地,嵌入层可以将训练样本的文本中的词语转化生成对应的字向量、文本向量以及位置向量,用于体现训练样本的文本中的每个词语,在正确的语义表达场景中,相应的意思、上下文的语义以及正确的位置等等相关特征。
进一步地,将教师模型和学生模型的嵌入层输出的,训练样本的文本对应的向量表示进行对比,以获取学生模型的嵌入层的输出与教师模型的嵌入层的输出之间的差异程度,基于该差异程度获取学生模型和教师模型之间在嵌入层上的第一误差。
S203,获取学生模型和教师模型在中间层上的第二误差。
为了使得学生模型可以更加充分的学习到教师模型中间层的信息,在构建初始状态的学生模型时,会从教师模型的中间层中按照设定的层间隔抽取其中的部分中间层,设置为学生模型的中间层。其中,中间层可以对训练样本的文本的特征,以及特征与特征之间的关联关系进行进一步的提取和增强。
可选地,中间层可以设置为自注意力层(transformer)。
其中,由于学生模型的中间层是基于间隔抽取的教师模型的中间层进行设置的,因此,初始状态的学生模型的中间层的模型参数与教师模型中与其相对应的中间层的模型参数相同。
进一步地,可以确定教师模型的中间层中,与学生模型的中间层相对应的部分中间层,并获取该部分中间层中每个层的输出,将其与学生模型的每个对应的中间层的输出进行对比,以获取学生模型的中间层与教师模型中与其相对应的中间层之间的差异程度,基于该差异程度可以获取学生模型和教师模型在中间层上的第二误差。
S204,获取学生模型与教师模型在输出层上的损失函数。
本申请实施例中,学生模型和教师模型配置有输出层,输出层可以基于中间层提取出的训练样本的文本的特征以及特征间的关联关系,对训练样本的文本的语义进行进一步的梳理,使得输出层可以实现符合人们阅读习惯的文本的输出。
在正确的语义表达场景中,基于人们的日常阅读习惯,文本中的词语均存在其设定的所处位置,词语与词语之间存在较为稳定的相对位置关系。本申请实施例中,训练样本的文本中的词语之间的相对位置关系可能存在误差,因此,通过输出层可以将文本中的词语之间的相对位置关系进行梳理。
其中,针对其中一个可以进行语义表达的位置,输出层会基于梳理后的正确的语义表达,确定训练样本的文本中属于该位置上的词语。其中,出现在该位置上的词语可以是一个,也可以是多个,因此,输出层会基于正确的语义表达确定每个词语属于该位置的概率。
进一步地,将每个可以进行语义表达的位置上可以出现的词语对应的概率进行整合,可以生成教师模型和学生模型的输出层输出的预测概率分布。
进一步地,可以获取学生模型的输出层的输出结果与教师模型的输出层的输出结果,并获取二者之间的差异程度,基于该差异程度可以获取学生模型和教师模型在输出层上的损失函数。
可选地,由于学生模型和教师模型的输出层的输出结果为概率分布,因此,可以采用交叉熵函数获取二者之间损失函数。
S205,根据第一误差、第二误差和损失函数,确定学生模型的总损失函数,并基于总损失函数对学生模型的模型参数进行调整,并继续使用下一个训练样本对调整模型参数的学生模型训练,直至训练结束,生成目标学生模型。
基于学生模型和教师模型之间的嵌入层上的第一误差、中间层上的第二误差以及输出层上的损失函数后,可以获取学生模型和教师模型之间的总损失函数。
可选地,可以通过设置不同的权重实现对于第一误差、第二误差以及损失函数的整合,进而生成学生模型和教师模型之间的总损失函数。
进一步地,基于获取到的总损失函数,可以调整学生模型的模型参数。比如,通过调整嵌入层的相关参数,可以缩小第一误差。再比如,通过调整中间层的相关参数,可以缩小第二误差。再比如,通过调整输出层的相关参数,可以缩小输出层上的损失函数。进一步地,可以实现学生模型的整体参数的调整。
对于学生模型的参数进行调整,可以使得调整后的学生模型的性能可以更加接近于教师模型的性能。
进一步地,可以将下一个训练样本输入至调整后的学生模型中,并继续对学生模型进行训练,通过多轮次的训练不断的优化学生模型的性能,直至满足训练结束的条件。
可选地,可以基于学生模型的训练轮次设定训练结束的条件,在学生模型进行训练的过程中,统计其训练轮次,当统计到的训练轮次与设定的训练结束条件中的训练轮次的数值相同时,即可结束学生模型训练。
其中,当学生模型的训练轮次达到设置的训练结束的标准轮次时,可以理解为,最后一次的训练结束后的学生模型可以满足实际场景的应用所需。
可选地,可以基于学生模型的总损失函数设定训练结束的条件,将学生模型训练结束需要满足的条件设定为总损失函数对应的阈值,在每一轮次的学生模型训练结束后,将获取到的总损失函数与设定的阈值进行比较,若当前轮次的训练的总损失函数小于或者等于设定的阈值时,即可结束学生模型的训练。
其中,当学生模型对应的总损失函数小于或者等于设定的阈值时,可以理解为,该学生模型的性能可以满足实际场景的应用所需。
进一步地,将满足模型训练的结束条件的学生模型确定为目标学生模型输出。
本申请提出的学生模型的训练方法,获取教师模型和学生模型在嵌入层的第一误差,获取教师模型和学生模型在中间层上的第二误差,获取教师模型和学生模型在输出层上的损失函数,进一步地,生成学生模型对应的总损失函数。基于总损失函数进行学生模型的参数调整,并对调整后的学生模型继续进行训练,直至满足训练结束的条件,并输出目标学生模型。本申请中,通过间隔抽取教师模型的中间层并设置为学生模型的中间层,使得学生模型可以更好的学习到教师模型的中间层的信息,基于教师模型和学生模型中的多个对应层的输出分别获取误差,可以实现对学生模型的模型参数的精确调整,加快了学生模型的训练速度,优化了学生模型的训练效果,有效提高了学生模型的性能。
上述实施例中,关于第一误差的获取,可结合图3进一步理解,图3为本申请另一实施例的学生模型的训练方法的流程示意图,如图3所示,该方法包括:
S301,根据学生模型中嵌入层输出的第一特征表示,以及教师模型中嵌入层输出的第二特征表示,获取第一误差。
实现中,语义通顺度模型通过对输入其中的文本中的词语进行向量化的转换,然后再经过特征提取,以及特征之间的关系的确定,进而对输入其中的文本进行过滤,从而获取到符合人们日常阅读习惯的高品质的文本。
可选地,可以通过嵌入层实现对于输入模型中的文本的向量转化。
可选地,可以在嵌入层中使用一位有效编码(one-hot)的编码方式,对文本进行编码。其中,文本中的每个词语均可以通过One-hot编码方式进行向量转化,进而获取文本对应的向量表示。
本申请实施例中,将训练样本分别输入教师模型和学生模型,教师模型的嵌入层以及学生模型的嵌入层会分别对训练样本的文本进行向量转化,并分别输出训练样本的文本对应的向量表示。
其中,向量表示中可以包括文本对应的字向量、文本向量以及位置向量等相关向量,通过这些向量可以体现训练样本的文本对应的特征,以及特征间的关联关系。
进一步地,将学生模型的嵌入层的输出标记为第一特征表示,将教师模型的嵌入层的输出标记为第二特征表示。将第二特征表示作为标准,获取第一特征表示与第二特征表示之间的差异程度,基于该差异程度可以获取学生模型与教师模型在嵌入层上的第一误差。
嵌入层输出的特征表示可以包括训练样本中每个元素的特征值,其中,根据学生模型中嵌入层输出的特征表示,以及教师模型中嵌入层输出的特征表示,获取第一误差。
进一步地,获取学生模型中嵌入层输出的元素的第一特征值,以及教师模型中嵌入层输出的元素的第二特征值。
可以理解为,学生模型的嵌入层输出的第一特征表示中,每个元素对应的特征值可以标记为第一特征值,相应地,教师模型的嵌入层输出的第二特征表示中,每个元素对应的特征值可以标记为第二特征值。
再进一步,针对同一个元素,根据第一特征值和第二特征值,获取元素的第一均方误差,并将每个元素对应的第一均方误差求和,获取第一误差。
可以理解为,针对文本中的相同的某一个元素,可以获取学生模型提取的其对应的第一特征值,以及教师模型提取的其对应的第二特征值,通过对比该元素的第一特征值与第二特征值之间的差异程度,可以获取学生模型和教师模型对于该元素的特征值提取结果之间的误差。进一步地,通过获取训练样本的文本中每一个元素对应的误差,可以获取学生模型和教师模型在嵌入层上的第一误差。
本申请实施例中,为了使得学生模型可以具有更强的实用性和适用性,可以将学生模型设置为轻量模型。与结构复杂的教师模型相比,学生模型的性能可以实现无限的接近于教师模型的性能,因此,对于教师模型和学生模型之间的误差的估计可以为有偏估计。可选地,可以通过均方误差实现学生模型与教师模型之间的差异程度的度量。
获取训练样本的文本中每个元素对应的第一特征值与第二特征值的均方误差,并将其确定为训练样本的文本中的元素的第一均方误差。通过对每个元素对应的第一均方误差的整合,可以获取学生模型和教师模型在嵌入层上的第一误差。
设定,第一误差为lossemb,则第一误差lossemb的计算公式如下:
Figure BDA0003116023610000091
其中,es(j)为学生模型的嵌入层输出的训练样本的第j个元素的特征值,et(j)为教师模型的嵌入层输出的训练样本的第j个元素的特征值,N为训练样本的文本对应的向量表示的向量维数。
本申请提出的学生模型的训练方法,获取教师模型和学生模型的嵌入层的输出的训练样本的每个元素的特征值之间的第一均方误差,并将获取到的全部的第一均方误差进行整合,获取到学生模型和教师模型之间的嵌入层上的第一误差。通过对第一误差的准确计算,使得学生模型的模型参数可以基于第一误差实现准确的调整,提高了学生模型的训练效率,优化了学生模型的训练效果。
上述实施例中,关于第二误差的获取,可结合图4进一步理解,图4为本申请另一实施例的学生模型的训练方法的流程示意图,如图4所示,该方法包括:
S401,根据学生模型中每个第一中间层输出的第三特征表示,以及教师模型中与第一中间层匹配的第二中间层输出的第四特征表示,获取第二误差。
为了使得学生模型可以学习到教师模型的中间层的信息,在构建初始状态的学生模型时,会基于设定的层间隔,从教师模型的多个中间层中间隔抽取部分中间层设置为学生模型的中间层,因此,初始状态的学生模型的中间层与教师模型中对应的中间层的参数相同。
进一步地,通过计算学生模型的中间层的输出与教师模型的对应的中间层的输出之间的差异程度,可以实现学生模型的性能的优化。
可选地,教师模型和学生模型的中间层可以设置为transformer层。在每个transformer层,通过transformer层对输入的训练样本的文本对应的向量表示中包含的特征进行提取,进而获取在正确的语义表达场景中,训练样本的文本对应的特征表示。
可选地,当教师模型以一种预训练模型(bert)作为模型构建的主结构时,transformer层中的编码器(Encoder)会被一层一层的堆叠,可以在transformer层中添加分类(classification,CLS)字符,进而实现对于训练样本的文本的特征,以及特征间关联关系的提取。
本申请实施例中,可以将学生模型的每个中间层标记为第一中间层,获取到每个第一中间层输出的训练样本的文本的特征表示后,将获取到的特征表示标记为每个第一中间层的第三特征表示。
相应地,可以将教师模型中与学生模型的每个第一中间层对应的中间层标记为第二中间层,获取每个第二中间层输出的训练样本的文本的特征表示后,将获取到的特征表示标记为每个第二中间层的第四特征表示。
需要说明的是,学生模型的第i个第一中间层与教师模型的第j个第二中间层匹配,其中,i与j之间间隔设定层数,其中,i、j均大于或者等于零。
进一步地,根据学生模型中每个第一中间层输出的第三特征表示,以及教师模型中与第一中间层匹配的第二中间层输出的第四特征表示,获取第二误差。
可以理解为,将每个第一中间层输出的第三特征表示,以及教师模型中与第一中间层匹配的第二中间层输出的第四特征表示进行对比,基于对比的结果,可以获取到学生模型的每个第一中间层和教师模型中对应的每个第二中间层之间的差异程度,基于该差异程度可以获取学生模型和教师模型在中间层上的第二误差。
其中,针对每个第一中间层,根据第一中间层的第三特征表示以及匹配的第二中间层的第四特征表示,获取第一中间层的第二均方误差,并将每个第一中间层对应的第二均方误差求和,获取第二误差。
可选地,可以获取学生模型的每个第一中间层输出的第三特征表示,和教师模型中对应的每个第二中间层输出的第四特征表示的均方误差,并将其标记为该第一中间层与对应的第二中间层之间的第二均方误差。
进一步地,获取每一层的第一中间层与对应的第二中间层之间的第二均方误差,可以通过加和的处理方法,获取学生模型的中间层与教师模型的中间层之间的第二误差。
设定,第二误差为lossmiddle,则第二误差lossmiddle的计算公式如下:
Figure BDA0003116023610000101
其中,hs(i)为学生模型的第i个中间层的输出,ht(It(i))为教师模型的第It(i)个中间层的输出,M为学生模型的中间层的层数。
本申请提出的学生模型的训练方法,获取教师模型的中间层中与学生模型对应的每个中间层的输出,以及学生模型的每个中间层的输出,进而获取学会模型的每个中间层的输出与教师模型中与其对应的每个中间层的输出之间的第二均方误差,将获取到的全部的第二均方误差进行整合,获取学生模型和教师模型之间的中间层上的第二误差。通过对第二误差的准确计算,获取到了学生模型的中间层的输出与教师模型的中间层的输出之间的差异程度,使得学生模型的模型参数可以实现准确的调整,提高的学生模型的训练效率,优化了学生模型的训练效果。
上述实施例中,关于输出层上的损失函数的获取,可结合图5进一步理解,图5为本申请另一实施例的学生模型的训练方法的流程示意图,如图5所示,该方法包括:
S501,获取学生模型中输出层输出的第一预测概率分布,以及教师模型中输出层输出的第二预测概率分布。
实现中,语义通顺度模型可以通过输入其中的文本中的词语的提取、词语之间的关联关系提取,以及词语之间的位置关系等相关信息的确定,实现符合人们阅读习惯的高品质文本的生成。
进一步地,基于正确的语义表达,语义通顺度模型会对输入模型中的文本的各个词语之间的位置关系进行重新的梳理确认。可以理解为,针对某一个输入模型中的文本的词语,会对其在正确的语义表达场景中应处的位置进行确认。针对其中一个可以进行语义表达的位置,输出层会基于梳理后的正确的语义表达,确定训练样本的文本中属于该位置上的词语。其中,出现在该位置上的词语可以是一个,也可以是多个,因此,输出层会基于正确的语义表达确定每个词语属于该位置的概率,进而获取文本对应的预测概率分布。
其中,学生模型的输出层可以基于中间层提取到的特征表示,对训练样本的文本进行语义上的重新梳理,并基于梳理后的正确的语义表达,生成每个位置上对应的词语的概率分布,进一步地,将获取到的概率分布确定为学生模型的输出层输出的第一预测概率分布。
相应地,教师模型的输出层可以基于中间层提取出的特征表示,对训练样本的文本进行语义上的梳理,并基于梳理后的正确的语义表达,生成每个位置上对应的词语的概率分布,进一步地,将获取到的概率分布确定为教师模型的输出层输出的第二预测概率分布。
其中,第一预测概率分布与第二预测概率分布中,可以包括训练样本的文本在正确的语义表达场景下,每个位置对应的字符的实际类别的概率以及预测类别的概率。
S502,根据第一预测概率分布和第二预测概率分布,确定实际类别对应的第一损失函数,以及预测类别对应的第二损失函数。
确定学生模型输出的第一预测概率分布以及教师模型输出的第二预测概率分布后,根据其中包含的每个位置上对应的字符的实际类别的概率,可以获取学生模型和教师模型的输出中的实际类别对应的损失函数,将其标记为第一损失函数。
相应的,根据其中包含的每个位置上对应的字符的预测类别的概率,可以获取学生模型和教师模型的输出中的预测类别对应的损失函数,将其标记为第二损失函数。
本申请实施例中,第一损失函数与第二损失函数体现的是学生模型输出的概率分布与教师模型输出的概率分布之间的差异程度,因此,可以使用交叉熵计算学生模型与教师模型之间的损失函数。
进一步地,从第一预测概率分布中,获取训练样本所属的实际类别对应的目标预测概率,并根据目标预测概率,确定第一损失函数。
可以理解为,从第一预测概率分布中,可以获取每个位置上对应的词语的所属的实际类别是正确的实际类别的目标预测概率,基于该目标预测概率可以获取学生模型的输出中,实际类别对应的第一损失函数。
设定,学生模型的第一预测概率为ps(t),第一损失函数为losshard,则计算公式如下:
losshard=-logps(t)
再进一步,针对所有类别中的任一类别,从第一预测概率分布中,获取任一类别对应的第一预测概率,并从第二预测概率分布中,获取任一类别对应的第二预测概率。
可以理解为,第一预测概率分布与第二预测概率分布中,包含对于每个位置对应的文字的可能所属的类别进行预测的预测概率。
其中,第一预测概率分布中的任一类别对应的预测概率可以标记为第一预测概率,第二预测概率分布中的任一类别对应的预测概率可以标记为第二预测概率。
再进一步地,根据任一类别的第一预测概率和第二预测概率,确定任一类别的损失值,并将每个类别的损失值求和,获取第二损失函数。
设定,第一预测概率为ps,第二预测概率为pt,则基于第一预测概率与第二预测概率获取到的第二损失函数为losssoft,公式如下:
Figure BDA0003116023610000121
其中,C为所有类别的集合,pt(i)为教师模型对于第i个类别的预测概率,ps(i)为学生模型对第i个类别的预测概率。
S503,对第一损失函数和第二损失函数进行加权,获取损失函数。
确定实际类别对应的第一损失函数与预测类别对应的第二损失函数后,将二者基于权重进行整合,可以获取到学生模型与教师模型之间在输出层上的损失函数。
设定,第一损失函数为losshard,第二损失函数为losssoft,学生模型与教师模型之间在输出层上的损失函数为,则公式如下:
losslabel=(1-μ)losshard+μlosssoft
其中,μ为设定的第二损失函数的权重参数,通过权重参数的设置可以实现第一损失函数与第二损失函数的平衡。
本申请提出的学生模型的训练方法,基于第一预测概率分布与第二预测概率分布,获取学生模型与教师模型之间的第一损失函数与第二损失函数,进一步地,将二者基于权重占比加和,获取到学生模型与教师模型在输出层上的损失函数。通过对损失函数的准确计算,获取到了学生模型的输出层与教师模型的输出层之间的差异程度,使得学生模型的模型参数可以实现准确的调整,提高的学生模型的训练效率,优化了学生模型的训练效果。
进一步地,在上述实施例的基础之上,获取到嵌入层上的第一误差、中间层上的第二误差以及输出层上的损失函数后,可以确定学生模型的总损失函数,作为一种可能的实现方式,如图6所示,包括:
S601,对第一误差、第二误差和损失函数进行加权,获取总损失函数。
获取学生模型和教师模型在嵌入层上的第一误差、中间层上的第二误差以及输出层上的损失函数后,可以进行进一步地整合,以获取学生模型和教师模型之间的总损失函数。
可选地,可以分别设置第一误差、第二误差与损失函数的权重参数,并基于设置的权重参数对第一误差、第二误差以及损失函数进行加权,进而获取到学生模型和教师模型之间的总损失函数。
设定,第一误差为lossemb,第二误差为lossmiddle,损失函数为losslabel,则学生模型和教师模型之间的总损失函数为,公式如下:
loss=αlossemb+βlossmiddle+γlosslabel
其中,α为第一误差的权重参数,β为第二误差的权重参数,γ为损失函数的权重参数。
基于获取到的总损失函数,可以对学生模型进行参数调整,进一步地,可以通过下一个训练样本对调整后的学生模型继续进行模型训练,直至满足训练结束的条件而停止训练,并将最后一个轮次训练获取到的学生模型作为目标学生模型输出。
本申请提出的学生模型的训练方法,通过对学生模型和教师模型在嵌入层上的第一误差、中间层上的第二误差以及输出层上的损失函数进行加权,获取到学生模型与教师模型之间的总损失函数。通过精准的对总损失函数的计算,可以实现对于学生模型参数的精准调整,有效提高了学生模型的训练效率,优化了学生模型的训练效果。
上述实施例提出的学生模型的训练方法除了应用于文本场景下,还可以应用于语音识别和图像识别等场景下,比如,在语音识别场景下,训练样本可以为语音训练样本,学生模型和教师模型可以为语音识别模型,在图像处理场景下,训练样本可以为图像训练样本,学生模型和教师模型可以为图像分类模型。
与上述几种实施例提出的学生模型的训练方法相对应,本申请的一个实施例还提出了一种学生模型的训练装置,由于本申请实施例提出的学生模型的训练装置与上述几种实施例提出的学生模型的训练方法相对应,因此上述学生模型的训练方法的实施方式也适用于本申请实施例提出的学生模型的训练装置,在下述实施例中不再详细描述。
图7为本申请一实施例的学生模型的训练装置的结构示意图,如图7所示,学生模型的训练装置700,包括输入模块71、获取模块72、训练模块73,其中:
输入模块71,用于将训练样本分别输入学生模型和教师模型中进行训练;
获取模块72,用于获取学生模型和教师模型在嵌入层上的第一误差,以及获取学生模型和教师模型在中间层上的第二误差,以及获取学生模型与教师模型在输出层上的损失函数;
训练模块73,用于根据第一误差、第二误差和损失函数,确定学生模型的总损失函数,并基于总损失函数对学生模型的模型参数进行调整,并继续使用下一个训练样本对调整模型参数的学生模型训练,直至训练结束,生成目标学生模型。
图8为本申请一实施例的学生模型的训练装置的结构示意图,如图8所示,学生模型的训练装置800,包括输入模块81、获取模块82、训练模块83,其中:
需要说明的是,输入模块71、获取模块72、训练模块73与输入模块81、获取模块82、训练模块83,具有相同的结构和功能。
本申请实施例中,获取模块82,还用于:根据学生模型中嵌入层输出的第一特征表示,以及教师模型中嵌入层输出的第二特征表示,获取第一误差。
本申请实施例中,获取模块82,还用于:获取学生模型中嵌入层输出的元素的第一特征值,以及教师模型中嵌入层输出的元素的第二特征值;针对同一个元素,根据第一特征值和第二特征值,获取元素的第一均方误差,并将每个元素对应的第一均方误差求和,获取第一误差。
本申请实施例中,获取模块82,还用于:根据学生模型中每个第一中间层输出的第三特征表示,以及教师模型中与第一中间层匹配的第二中间层输出的第四特征表示,获取第二误差。
本申请实施例中,获取模块82,还用于:针对每个第一中间层,根据第一中间层的第三特征表示以及匹配的第二中间层的第四特征表示,获取第一中间层的第二均方误差,并将每个第一中间层对应的第二均方误差求和,获取第二误差。
本申请实施例中,获取模块82,还用于:获取学生模型中输出层输出的第一预测概率分布,以及教师模型中输出层输出的第二预测概率分布;根据第一预测概率分布和第二预测概率分布,确定实际类别对应的第一损失函数,以及预测类别对应的第二损失函数;对第一损失函数和第二损失函数进行加权,获取损失函数。
本申请实施例中,获取模块82,还用于:从第一预测概率分布中,获取训练样本所属的实际类别对应的目标预测概率,并根据目标预测概率,确定第一损失函数;针对所有类别中的任一类别,从第一预测概率分布中,获取任一类别对应的第一预测概率,并从第二预测概率分布中,获取任一类别对应的第二预测概率;根据任一类别的第一预测概率和第二预测概率,确定任一类别的损失值,并将每个类别的损失值求和,获取第二损失函数。
本申请实施例中,训练模块83,还用于:对第一误差、第二误差和损失函数进行加权,获取总损失函数。
本申请实施例中,学生模型的第i个第一中间层与教师模型的第j个第二中间层匹配,其中,i与j之间间隔设定层数,其中,i、j均大于或者等于零。
本申请实施例中,初始的学生模型中每层的模型参数与教师模型中对应层的模型参数相同。
本申请提出的学生模型的训练装置,获取教师模型和学生模型在嵌入层的第一误差,获取教师模型和学生模型在中间层上的第二误差,获取教师模型和学生模型在输出层上的损失函数,进一步地,生成学生模型对应的总损失函数。基于总损失函数进行学生模型的参数调整,并对调整后的学生模型继续进行训练,直至满足训练结束的条件,并输出目标学生模型。本申请中,通过间隔抽取教师模型的中间层并设置为学生模型的中间层,使得学生模型可以更好的学习到教师模型的中间层的信息,基于教师模型和学生模型中的多个对应层的输出分别获取误差,可以实现对学生模型的模型参数的精确调整,加快了学生模型的训练速度,优化了学生模型的训练效果,有效提高了学生模型的性能。
图9示出了可以用来实施本申请的实施例的示例电子设备900的示意性框图。电子设备旨在表示各种形式的数字计算机,诸如,膝上型计算机、台式计算机、工作台、个人数字助理、服务器、刀片式服务器、大型计算机、和其它适合的计算机。电子设备还可以表示各种形式的移动装置,诸如,个人数字处理、蜂窝电话、智能电话、可穿戴设备和其它类似的计算装置。本文所示的部件、它们的连接和关系、以及它们的功能仅仅作为示例,并且不意在限制本文中描述的和/或者要求的本申请的实现。
如图9所示,设备900包括计算单元901,其可以根据存储在只读存储器(ROM)902中的计算机程序或者从存储单元908加载到随机访问存储器(RAM)903中的计算机程序,来执行各种适当的动作和处理。在RAM 903中,还可存储设备900操作所需的各种程序和数据。计算单元901、ROM902以及RAM903通过总线904彼此相连。输入/输出(I/O)接口905也连接至总线904。
设备900中的多个部件连接至I/O接口905,包括:输入单元906,比如键盘、鼠标等;输出单元907,比如各种类型的显示器、扬声器等;存储单元908,比如磁盘、光盘等;以及通信单元909,比如网卡、调制解调器、无线通信收发机等。通信单元909允许设备900通过诸如因特网的计算机网络和/或各种电信网络与其他设备交换信息/数据。
计算单元901可以是各种具有处理和计算能力的通用和/或专用处理组件。计算单元901的一些示例包括但不限于中央处理单元(CPU)、图形处理单元(GPU)、各种专用的人工智能(AI)计算芯片、各种运行机器学习模型算法的计算单元、数字信号处理器(DSP)、以及任何适当的处理器、控制器、微控制器等。计算单元901执行上文所描述的各个方法和处理,比如学生模型的训练方法。比如,在一些实施例中,学生模型的训练方法可被实现为计算机软件程序,其被有形地包含于机器可读介质,比如存储单元908。在一些实施例中,计算机程序的部分或者全部可以经由ROM902和/或通信单元909而被载入和/或安装到设备900上。当计算机程序加载到RAM903并由计算单元901执行时,可以执行上文描述的学生模型的训练方法一个或多个步骤。备选地,在其他实施例中,计算单元901可以通过其他任何适当的方式(比如,借助于固件)而被配置为执行学生模型的训练方法。
本文中以上描述的系统和技术的各种实施方式可以在数字电子电路系统、集成电路系统、场可编程门阵列(FPGA)、专用集成电路(ASIC)、专用标准产品(ASSP)、芯片上系统的系统(SOC)、负载可编程逻辑设备(CPLD)、计算机硬件、固件、软件、和/或它们的组合中实现。这些各种实施方式可以包括:实施在一个或者多个计算机程序中,该一个或者多个计算机程序可在包括至少一个可编程处理器的可编程系统上执行和/或解释,该可编程处理器可以是专用或者通用可编程处理器,可以从存储系统、至少一个输入装置、和至少一个输出装置接收数据和指令,并且将数据和指令传输至该存储系统、该至少一个输入装置、和该至少一个输出装置。
用于实施本申请的方法的程序代码可以采用一个或多个编程语言的任何组合来编写。这些程序代码可以提供给通用计算机、专用计算机或其他可编程数据处理装置的处理器或控制器,使得程序代码当由处理器或控制器执行时使流程图和/或框图中所规定的功能/操作被实施。程序代码可以完全在机器上执行、部分地在机器上执行,作为独立软件包部分地在机器上执行且部分地在远程机器上执行或完全在远程机器或服务器上执行。
在本申请的上下文中,机器可读介质可以是有形的介质,其可以包含或存储以供指令执行系统、装置或设备使用或与指令执行系统、装置或设备结合地使用的程序。机器可读介质可以是机器可读信号介质或机器可读储存介质。机器可读介质可以包括但不限于电子的、磁性的、光学的、电磁的、红外的、或半导体系统、装置或设备,或者上述内容的任何合适组合。机器可读存储介质的更具体示例会包括基于一个或多个线的电气连接、便携式计算机盘、硬盘、随机存取存储器(RAM)、只读存储器(ROM)、可擦除可编程只读存储器(EPROM或快闪存储器)、光纤、便捷式紧凑盘只读存储器(CD-ROM)、光学储存设备、磁储存设备、或上述内容的任何合适组合。
为了提供与使用者的交互,可以在计算机上实施此处描述的系统和技术,该计算机具有:用于向使用者显示信息的显示装置(比如,CRT(阴极射线管)或者LCD(液晶显示器)监视器);以及键盘和指向装置(比如,鼠标或者轨迹球),使用者可以通过该键盘和该指向装置来将输入提供给计算机。其它种类的装置还可以用于提供与使用者的交互;比如,提供给使用者的反馈可以是任何形式的传感反馈(比如,视觉反馈、听觉反馈、或者触觉反馈);并且可以用任何形式(包括声输入、语音输入或者、触觉输入)来接收来自使用者的输入。
可以将此处描述的系统和技术实施在包括后台部件的计算系统(比如,作为数据服务器)、或者包括中间件部件的计算系统(比如,应用服务器)、或者包括前端部件的计算系统(比如,具有图形使用者界面或者网络浏览器的使用者计算机,使用者可以通过该图形使用者界面或者该网络浏览器来与此处描述的系统和技术的实施方式交互)、或者包括这种后台部件、中间件部件、或者前端部件的任何组合的计算系统中。可以通过任何形式或者介质的数字数据通信(比如,通信网络)来将系统的部件相互连接。通信网络的示例包括:局域网(LAN)、广域网(WAN)、互联网和区块链网络。
计算机系统可以包括客户端和服务器。客户端和服务器一般远离彼此并且通常通过通信网络进行交互。通过在相应的计算机上运行并且彼此具有客户端-服务器关系的计算机程序来产生客户端和服务器的关系。服务端可以是云服务器,又称为云计算服务器或云主机,是云计算服务体系中的一项主机产品,以解决了传统物理主机与VPS服务(“Virtual Private Server”,或简称“VPS”)中,存在的管理难度大,业务扩展性弱的缺陷。服务器也可以为分布式系统的服务器,或者是结合区块链的服务器。
应该理解,可以使用上面所示的各种形式的流程,重新排序、增加或删除步骤。比如,本申请中记载的各步骤可以并行地执行也可以顺序地执行也可以不同的次序执行,只要能够实现本申请公开的技术方案所期望的结果,本文在此不进行限制。
上述具体实施方式,并不构成对本申请保护范围的限制。本领域技术人员应该明白的是,根据设计要求和其他因素,可以进行各种修改、组合、子组合和替代。任何在本申请的精神和原则之内所作的修改、等同替换和改进等,均应包含在本申请保护范围之内。

Claims (23)

1.一种学生模型的训练方法,包括:
将训练样本分别输入学生模型和教师模型中进行训练;
获取所述学生模型和所述教师模型在嵌入层上的第一误差;
获取所述学生模型和所述教师模型在中间层上的第二误差;
获取所述学生模型与所述教师模型在输出层上的损失函数;
根据所述第一误差、所述第二误差和所述损失函数,确定所述学生模型的总损失函数,并基于所述总损失函数对所述学生模型的模型参数进行调整,并继续使用下一个训练样本对调整模型参数的所述学生模型训练,直至训练结束,生成目标学生模型。
2.根据权利要求1所述的方法,其中,所述获取所述学生模型和所述教师模型在嵌入层上的第一误差,包括:
根据所述学生模型中嵌入层输出的第一特征表示,以及所述教师模型中嵌入层输出的第二特征表示,获取所述第一误差。
3.根据权利要求2所述的方法,其中,所述嵌入层输出的特征表示包括所述训练样本中每个元素的特征值,其中,所述根据所述学生模型中嵌入层输出的特征表示,以及所述教师模型中嵌入层输出的特征表示,获取所述第一误差,包括:
获取所述学生模型中嵌入层输出的所述元素的第一特征值,以及所述教师模型中嵌入层输出的所述元素的第二特征值;
针对同一个所述元素,根据所述第一特征值和所述第二特征值,获取所述元素的第一均方误差,并将每个所述元素对应的第一均方误差求和,获取所述第一误差。
4.根据权利要求1所述的方法,其中,所述获取所述学生模型和所述教师模型在中间层上的第二误差,包括:
根据所述学生模型中每个第一中间层输出的第三特征表示,以及所述教师模型中与所述第一中间层匹配的第二中间层输出的第四特征表示,获取所述第二误差。
5.根据权利要求4所述的方法,其中,所述根据所述学生模型中每个第一中间层输出的第三特征表示,以及所述教师模型中与所述第一中间层匹配的第二中间层输出的第四特征表示,获取所述第二误差,包括:
针对每个所述第一中间层,根据所述第一中间层的所述第三特征表示以及所述匹配的第二中间层的所述第四特征表示,获取所述第一中间层的第二均方误差,并将每个所述第一中间层对应的第二均方误差求和,获取所述第二误差。
6.根据权利要求1-5任一项所述的方法,其中,所述获取所述学生模型与所述教师模型在输出层上的损失函数,包括:
获取所述学生模型中输出层输出的第一预测概率分布,以及所述教师模型中输出层输出的第二预测概率分布;
根据所述第一预测概率分布和所述第二预测概率分布,确定实际类别对应的第一损失函数,以及预测类别对应的第二损失函数;
对所述第一损失函数和所述第二损失函数进行加权,获取所述损失函数。
7.根据权利要求6所述的方法,其中,所述根据所述第一预测概率分布和所述第二预测概率分布,确定实际类别对应的第一损失函数,以及预测类别对应的第二损失函数,包括:
从所述第一预测概率分布中,获取所述训练样本所属的实际类别对应的目标预测概率,并根据所述目标预测概率,确定所述第一损失函数;
针对所有类别中的任一类别,从所述第一预测概率分布中,获取所述任一类别对应的第一预测概率,并从所述第二预测概率分布中,获取所述任一类别对应的第二预测概率;
根据所述任一类别的所述第一预测概率和所述第二预测概率,确定所述任一类别的损失值,并将每个所述类别的所述损失值求和,获取所述第二损失函数。
8.根据权利要求1-5任一项所述的方法,其中,所述根据所述第一误差、所述第二误差和所述损失函数,确定所述学生模型的总损失函数,包括:
对所述第一误差、所述第二误差和所述损失函数进行加权,获取所述总损失函数。
9.根据权利要求1-5任一项所述的方法,其中,所述学生模型的第i个第一中间层与所述教师模型的第j个第二中间层匹配,其中,所述i与所述j之间间隔设定层数,其中,所述i、j均大于或者等于零。
10.根据权利要求9所述的方法,其中,初始的所述学生模型中每层的模型参数与所述教师模型中对应层的模型参数相同。
11.一种学生模型的训练装置,包括:
输入模块,用于将训练样本分别输入学生模型和教师模型中进行训练;
获取模块,用于获取所述学生模型和所述教师模型在嵌入层上的第一误差,以及获取所述学生模型和所述教师模型在中间层上的第二误差,以及获取所述学生模型与所述教师模型在输出层上的损失函数;
训练模块,用于根据所述第一误差、所述第二误差和所述损失函数,确定所述学生模型的总损失函数,并基于所述总损失函数对所述学生模型的模型参数进行调整,并继续使用下一个训练样本对调整模型参数的所述学生模型训练,直至训练结束,生成目标学生模型。
12.根据权利要求11所述的装置,其中,所述获取模块,还用于:
根据所述学生模型中嵌入层输出的第一特征表示,以及所述教师模型中嵌入层输出的第二特征表示,获取所述第一误差。
13.根据权利要求12所述的装置,其中,所述获取模块,还用于:
获取所述学生模型中嵌入层输出的所述元素的第一特征值,以及所述教师模型中嵌入层输出的所述元素的第二特征值;
针对同一个所述元素,根据所述第一特征值和所述第二特征值,获取所述元素的第一均方误差,并将每个所述元素对应的第一均方误差求和,获取所述第一误差。
14.根据权利要求11所述的装置,其中,所述获取模块,还用于:
根据所述学生模型中每个第一中间层输出的第三特征表示,以及所述教师模型中与所述第一中间层匹配的第二中间层输出的第四特征表示,获取所述第二误差。
15.根据权利要求14所述的装置,其中,所述获取模块,还用于:
针对每个所述第一中间层,根据所述第一中间层的所述第三特征表示以及所述匹配的第二中间层的所述第四特征表示,获取所述第一中间层的第二均方误差,并将每个所述第一中间层对应的第二均方误差求和,获取所述第二误差。
16.根据权利要求11-15任一项所述的装置,其中,所述获取模块,还用于:
获取所述学生模型中输出层输出的第一预测概率分布,以及所述教师模型中输出层输出的第二预测概率分布;
根据所述第一预测概率分布和所述第二预测概率分布,确定实际类别对应的第一损失函数,以及预测类别对应的第二损失函数;
对所述第一损失函数和所述第二损失函数进行加权,获取所述损失函数。
17.根据权利要求16所述的装置,其中,所述获取模块,还用于:
从所述第一预测概率分布中,获取所述训练样本所属的实际类别对应的目标预测概率,并根据所述目标预测概率,确定所述第一损失函数;
针对所有类别中的任一类别,从所述第一预测概率分布中,获取所述任一类别对应的第一预测概率,并从所述第二预测概率分布中,获取所述任一类别对应的第二预测概率;
根据所述任一类别的所述第一预测概率和所述第二预测概率,确定所述任一类别的损失值,并将每个所述类别的所述损失值求和,获取所述第二损失函数。
18.根据权利要求11-15任一项所述的装置,其中,所述训练模块,还用于:
对所述第一误差、所述第二误差和所述损失函数进行加权,获取所述总损失函数。
19.根据权利要求11-15任一项所述的装置,其中,所述学生模型的第i个第一中间层与所述教师模型的第j个第二中间层匹配,其中,所述i与所述j之间间隔设定层数,其中,所述i、j均大于或者等于零。
20.根据权利要求19所述的装置,其中,初始的所述学生模型中每层的模型参数与所述教师模型中对应层的模型参数相同。
21.一种电子设备,其特征在于,包括处理器和存储器;
其中,所述处理器通过读取所述存储器中存储的可执行程序代码来运行与所述可执行程序代码对应的程序,以用于实现如权利要求1-10中任一所述的方法。
22.一种计算机可读存储介质,其上存储有计算机程序,其特征在于,该程序被处理器执行时实现如权利要求1-10中任一所述的方法。
23.一种计算机程序产品,其特征在于,当所述计算机程序产品中的指令处理器执行时实现如权利要求1-10中任一所述的方法。
CN202110662767.0A 2021-06-15 2021-06-15 学生模型的训练方法、装置及电子设备 Active CN113435208B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202110662767.0A CN113435208B (zh) 2021-06-15 2021-06-15 学生模型的训练方法、装置及电子设备

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202110662767.0A CN113435208B (zh) 2021-06-15 2021-06-15 学生模型的训练方法、装置及电子设备

Publications (2)

Publication Number Publication Date
CN113435208A true CN113435208A (zh) 2021-09-24
CN113435208B CN113435208B (zh) 2023-08-25

Family

ID=77756027

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202110662767.0A Active CN113435208B (zh) 2021-06-15 2021-06-15 学生模型的训练方法、装置及电子设备

Country Status (1)

Country Link
CN (1) CN113435208B (zh)

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
WO2022217853A1 (en) * 2021-04-16 2022-10-20 Huawei Technologies Co., Ltd. Methods, devices and media for improving knowledge distillation using intermediate representations
CN116861302A (zh) * 2023-09-05 2023-10-10 吉奥时空信息技术股份有限公司 一种案件自动分类分拨方法

Citations (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN111611377A (zh) * 2020-04-22 2020-09-01 淮阴工学院 基于知识蒸馏的多层神经网络语言模型训练方法与装置
CN111950302A (zh) * 2020-08-20 2020-11-17 上海携旅信息技术有限公司 基于知识蒸馏的机器翻译模型训练方法、装置、设备及介质
US20200364542A1 (en) * 2019-05-16 2020-11-19 Salesforce.Com, Inc. Private deep learning
CN112508120A (zh) * 2020-12-18 2021-03-16 北京百度网讯科技有限公司 学生模型训练方法、装置、设备、介质和程序产品

Patent Citations (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20200364542A1 (en) * 2019-05-16 2020-11-19 Salesforce.Com, Inc. Private deep learning
CN111611377A (zh) * 2020-04-22 2020-09-01 淮阴工学院 基于知识蒸馏的多层神经网络语言模型训练方法与装置
CN111950302A (zh) * 2020-08-20 2020-11-17 上海携旅信息技术有限公司 基于知识蒸馏的机器翻译模型训练方法、装置、设备及介质
CN112508120A (zh) * 2020-12-18 2021-03-16 北京百度网讯科技有限公司 学生模型训练方法、装置、设备、介质和程序产品

Non-Patent Citations (2)

* Cited by examiner, † Cited by third party
Title
PENG SHEN等: "Interactive Learning of Teacher-student Model for Short Utterance Spoken Language Identification", 《ICASSP 2019 - 2019 IEEE INTERNATIONAL CONFERENCE ON ACOUSTICS, SPEECH AND SIGNAL PROCESSING (ICASSP)》 *
宁尚明;滕飞;李天瑞;: "基于多通道自注意力机制的电子病历实体关系抽取", 计算机学报, no. 05 *

Cited By (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
WO2022217853A1 (en) * 2021-04-16 2022-10-20 Huawei Technologies Co., Ltd. Methods, devices and media for improving knowledge distillation using intermediate representations
CN116861302A (zh) * 2023-09-05 2023-10-10 吉奥时空信息技术股份有限公司 一种案件自动分类分拨方法
CN116861302B (zh) * 2023-09-05 2024-01-23 吉奥时空信息技术股份有限公司 一种案件自动分类分拨方法

Also Published As

Publication number Publication date
CN113435208B (zh) 2023-08-25

Similar Documents

Publication Publication Date Title
CN108536679B (zh) 命名实体识别方法、装置、设备及计算机可读存储介质
US20210312139A1 (en) Method and apparatus of generating semantic feature, method and apparatus of training model, electronic device, and storage medium
CN110427625B (zh) 语句补全方法、装置、介质及对话处理系统
CN111783474B (zh) 一种评论文本观点信息处理方法、装置及存储介质
CN113033622B (zh) 跨模态检索模型的训练方法、装置、设备和存储介质
CN110598713A (zh) 基于深度神经网络的智能图像自动描述方法
CN113792854A (zh) 一种模型训练及字库建立方法、装置、设备及存储介质
CN113268609A (zh) 基于知识图谱的对话内容推荐方法、装置、设备及介质
CN112101010B (zh) 一种基于bert的电信行业oa办公自动化文稿审核的方法
CN115455171B (zh) 文本视频的互检索以及模型训练方法、装置、设备及介质
CN113553412B (zh) 问答处理方法、装置、电子设备和存储介质
CN113435208B (zh) 学生模型的训练方法、装置及电子设备
CN110781413A (zh) 兴趣点确定方法及装置、存储介质、电子设备
CN113204611A (zh) 建立阅读理解模型的方法、阅读理解方法及对应装置
CN111695338A (zh) 基于人工智能的面试内容精炼方法、装置、设备及介质
CN116152833B (zh) 基于图像的表格还原模型的训练方法及表格还原方法
CN112560985A (zh) 神经网络的搜索方法、装置及电子设备
CN111767697A (zh) 文本处理方法、装置、计算机设备以及存储介质
CN113705242B (zh) 面向教育咨询服务的智能语义匹配方法和装置
US20230075339A1 (en) Method of training information generation model, method of generating information, and device
CN115357710B (zh) 表格描述文本生成模型的训练方法、装置及电子设备
CN116431827A (zh) 信息处理方法、装置、存储介质及计算机设备
CN115687934A (zh) 意图识别方法、装置、计算机设备及存储介质
CN115730590A (zh) 意图识别方法以及相关设备
CN114882388A (zh) 多任务模型的训练及预测方法、装置、设备和介质

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant