CN114330510A - 模型训练方法、装置、电子设备和存储介质 - Google Patents

模型训练方法、装置、电子设备和存储介质 Download PDF

Info

Publication number
CN114330510A
CN114330510A CN202111511703.7A CN202111511703A CN114330510A CN 114330510 A CN114330510 A CN 114330510A CN 202111511703 A CN202111511703 A CN 202111511703A CN 114330510 A CN114330510 A CN 114330510A
Authority
CN
China
Prior art keywords
prediction result
model
prediction
credibility
teacher
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Granted
Application number
CN202111511703.7A
Other languages
English (en)
Other versions
CN114330510B (zh
Inventor
李磊
林衍凯
任宣丞
赵光香
李鹏
周杰
孙栩
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Peking University
Tencent Technology Shenzhen Co Ltd
Original Assignee
Peking University
Tencent Technology Shenzhen Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Peking University, Tencent Technology Shenzhen Co Ltd filed Critical Peking University
Priority to CN202111511703.7A priority Critical patent/CN114330510B/zh
Publication of CN114330510A publication Critical patent/CN114330510A/zh
Application granted granted Critical
Publication of CN114330510B publication Critical patent/CN114330510B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Landscapes

  • Management, Administration, Business Operations System, And Electronic Commerce (AREA)

Abstract

本申请实施例公开了一种模型训练方法、装置、电子设备和存储介质,该方法涉及人工智能领域中的深度学习方向,包括:获取教师模型和学生模型;获取第一预测结果,第一预测结果由教师模型对样本数据集中的样本数据进行预测得到;获取教师模型对第一预测结果的可信度,可信度用于表征第一预测结果的可信程度;根据可信度更新第一预测结果,将更新后的第一预测结果作为第二预测结果;获取第二预测结果与学生模型预测样本数据的结果之间的差异;基于差异,更新学生模型的参数,以训练学生模型。本申请实施例通过教师模型对样本数据预测的可信度,能够准确地确定出样本数据对应的标签,以提升对学生模型的训练效果和效率。

Description

模型训练方法、装置、电子设备和存储介质
技术领域
本申请涉及计算机技术领域,具体涉及一种模型训练方法、装置、电子设备和存储介质。
背景技术
随着人工智能技术的发展,深度学习成为国内外多个领域的研究热点,知识融合(Knowledge Amalgamation,KA)是深度学习的一个重要研究方向。知识融合是指使用多个教师模型对学生模型进行训练,将教师模型的计算能力融合到学生模型中,使学生模型具备与教师模型同等水平的计算能力。
然而,在目前的相关技术中,会因教师模型对训练样本的过自信预测、各教师模型之间的结构存在差异等原因,导致知识融合的效率较低、效果较差。
发明内容
本申请实施例提供一种模型训练方法、装置、电子设备和存储介质,可以提升对学生模型的训练效果和效率。
本申请实施例提供一种模型训练方法,包括:
获取教师模型和学生模型;
获取第一预测结果,所述第一预测结果由所述教师模型对样本数据集中的样本数据进行预测得到;
获取所述教师模型对所述第一预测结果的可信度,所述可信度用于表征所述第一预测结果的可信程度;
根据所述可信度更新所述第一预测结果,将更新后的所述第一预测结果作为第二预测结果;
获取所述第二预测结果与学生模型预测所述样本数据的结果之间的差异;
基于所述差异,更新所述学生模型的参数,以训练所述学生模型。
相应的,本申请实施例提供一种模型训练装置,包括:
模型获取单元,用于获取教师模型和学生模型;
第一预测结果获取单元,用于获取第一预测结果,所述第一预测结果由所述教师模型对样本数据集中的样本数据进行预测得到;
可信度获取单元,用于获取所述教师模型对所述第一预测结果的可信度,所述可信度用于表征所述第一预测结果的可信程度;
第二预测结果获取单元,用于根据所述可信度更新所述第一预测结果,将更新后的所述第一预测结果作为第二预测结果;
差异获取单元,用于获取所述第二预测结果与学生模型预测所述样本数据的结果之间的差异;
模型训练单元,用于基于所述差异,更新所述学生模型的参数,以训练所述学生模型。
可选的,在一些实施例中,所述第二预测结果获取单元可以包括第一计算子单元和加权子单元,如下:
第一计算子单元,用于根据每个第一预测结果的可信度,计算所述每个第一预测结果对应的加权值;
第二计算子单元,用于对所述每个第一预测结果结合所述加权值进行加权求和,得到更新后的所述第一预测结果作为第二预测结果。
可选的,在一些实施例中,所述第一计算子单元具体可以用于获取所述每个第一预测结果的可信度参数,以及所述第一预测结果的可信度参数总和;计算所述每个第一预测结果的可信度参数在所述可信度参数总和中的占比,将所述占比确定为所述每个第一预测结果对应的加权值。
可选的,在一些实施例中,所述第一计算子单元具体可以用于获取所述每个第一预测结果的类别数量,所述类别数量表征所述每个第一预测结果包含预测类别的数目;计算所述每个第一预测结果的可信度与所述类别数量的对数的比值;将预设值与所述比值的差值确定为所述每个第一预测结果的可信度参数。
可选的,在一些实施例中,所述第一计算子单元具体可以用于对所述每个第一预测结果的可信度参数做e次幂运算,得到第一参数;对所述可信度参数总和做e次幂运算,得到第二参数;对所述第一参数与第二参数进行相除,得到所述占比。
可选的,在一些实施例中,所述第二预测结果获取单元还可以包括对比子单元和结果确定子单元,如下:
对比子单元,用于将不同的所述可信度中的每个可信度进行逐一对比;
结果确定子单元,用于将最小可信度对应的第一预测结果,确定为所述第二预测结果。
可选的,在一些实施例中,所述可信度获取单元可以包括获取子单元和第三计算子单元,如下:
获取子单元,用于获取第一概率值,所述第一概率值表征所述教师模型在目标状态下对所述样本数据进行预测的准确率;其中,在所述目标状态下的所述教师模型中包括至少一个被随机掩盖的神经节点;
第三计算子单元,用于计算所述第一概率值的熵,将所述熵确定为所述教师模型对所述第一预测结果的可信度,返回并执行获取所述第一概率值。
可选的,在一些实施例中,所述获取子单元具体可以用于获取每个教师模型的第二概率值和预测次数,所述第二概率值表征所述每个教师模型的在不同目标状态下对所述样本数据进行预测的准确率;将所述预测次数对应数量的所述第二概率值进行求和运算,得到第二概率值总和;将所述得到第二概率值总和与所述预测次数进行相除,得到所述第一概率值。
可选的,在一些实施例中,所述第三计算子单元具体可以用于获取所述每个第一预测结果的类别数量,所述类别数量表征所述每个第一预测结果包含预测类别的数目;计算所述第一概率值与其对数的乘积值;将所述预测次数对应数量的所述乘积值进行求和运算,得到所述乘积值总和,将所述乘积值总和的负数作为所述概率值的熵。
本申请实施例提供一种电子设备,包括处理器和存储器,所述存储器存储有多条指令;所述处理器从所述存储器中加载指令,以执行本申请实施例提供的模型训练方法中的步骤。
本申请实施例还提供一种计算机可读存储介质,所述计算机可读存储介质存储有多条指令,所述指令适于处理器进行加载,以执行本申请实施例提供的模型训练方法中的步骤。
此外,本申请实施例还提供一种计算机程序产品,包括计算机程序或指令,所述计算机程序或指令被处理器执行时实现本申请实施例提供的模型训练方法中的步骤。
本申请实施例提供了一种模型训练方法、装置、电子设备和存储介质,可以获取教师模型和学生模型;获取第一预测结果,所述第一预测结果由所述教师模型对样本数据集中的样本数据进行预测得到;获取所述教师模型对所述第一预测结果的可信度,所述可信度用于表征所述第一预测结果的可信程度;根据所述可信度更新所述第一预测结果,将更新后的所述第一预测结果作为第二预测结果;获取所述第二预测结果与学生模型预测所述样本数据的结果之间的差异;基于所述差异,更新所述学生模型的参数,以训练所述学生模型。
在本申请中,可以通过教师模型对样本数据预测的可信度,准确地确定出样本数据对应的标签。由此,能够在克服教师模型过自信预测和结构差异的基础上,提升对学生模型的训练效果和效率。
附图说明
为了更清楚地说明本申请实施例中的技术方案,下面将对实施例描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本申请的一些实施例,对于本领域技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1a是本申请实施例提供的模型训练方法的场景示意图;
图1b是本申请实施例提供的模型训练方法的流程示意图;
图1c是本申请实施例提供的训练学生模型的过程示意图;
图1d是本申请实施例提供的训练学生模型的过程示意图;
图1e是本申请实施例提供的获取可信度的过程示意图;
图2是本申请实施例提供的模型训练方法的另一流程示意图;
图3是本申请实施例提供的模型训练装置的结构示意图;
图4是本申请实施例提供的电子设备的结构示意图。
具体实施方式
下面将结合本申请实施例中的附图,对本申请实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本申请一部分实施例,而不是全部的实施例。基于本申请中的实施例,本领域技术人员在没有作出创造性劳动前提下所获得的所有其他实施例,都属于本申请保护的范围。
下面简单介绍一下本申请实施例可能用到的技术。
人工智能(Artificial Intelligence,AI)是一种利用数字计算机来模拟人类感知环境、获取知识并使用知识的技术,该技术可以使机器具有类似于人类的感知、推理与决策的功能。人工智能基础技术一般包括如传感器、专用人工智能芯片、云计算、分布式存储、大数据处理技术、操作/交互系统、机电一体化等技术。人工智能软件技术主要包括计算机视觉技术、语音处理技术、自然语言处理技术以及机器学习/深度学习、自动驾驶、智慧交通等几大方向。
自然语言处理(Nature Language processing,NLP)是计算机科学领域与人工智能领域中的一个重要方向。它研究能实现人与计算机之间用自然语言进行有效通信的各种理论和方法。自然语言处理是一门融语言学、计算机科学、数学于一体的科学。因此,这一领域的研究将涉及自然语言,即人们日常使用的语言,所以它与语言学的研究有着密切的联系。自然语言处理技术通常包括文本处理、语义理解、机器翻译、机器人问答、知识图谱等技术。
机器学习(Machine Learning,ML)是一门多领域交叉学科,涉及概率论、统计学、逼近论、凸分析、算法复杂度理论等多门学科。专门研究计算机怎样模拟或实现人类的学习行为,以获取新的知识或技能,重新组织已有的知识结构使之不断改善自身的性能。机器学习是人工智能的核心,是使计算机具有智能的根本途径,其应用遍及人工智能的各个领域。机器学习和深度学习通常包括人工神经网络、置信网络、强化学习、迁移学习、归纳学习、示教学习等技术。
本申请实施例提供一种模型训练方法、装置、电子设备和存储介质。
其中,该模型训练装置具体可以集成在电子设备中,该电子设备可以为终端、服务器等设备。其中,终端可以为手机、平板电脑、智能蓝牙设备、笔记本电脑、或者个人电脑(Personal Computer,PC)等设备;服务器可以是单一服务器,也可以是由多个服务器组成的服务器集群。
在一些实施例中,该模型训练装置还可以集成在多个电子设备中,比如,模型训练装置可以集成在多个服务器中,由多个服务器来实现本申请的模型训练方法。
可以理解的是,本实施例的模型训练方法可以是在终端上执行的,也可以是在服务器上执行,还可以由终端和服务器共同执行的。以上举例不应理解为对本申请的限制。
请参考图1a,以服务器11执行该模型训练方法为例。
如图1a所示,服务器11和终端12可以通过有线或无线通信方式进行直接或间接地连接,本申请在此不做限制。服务器11可以是独立的物理服务器,也可以是多个物理服务器构成的服务器集群或者分布式系统,还可以是提供云服务、云数据库、云计算、云函数、云存储、网络服务、云通信、中间件服务、域名服务、安全服务、CDN、以及大数据和人工智能平台等基础云计算服务的云服务器。服务器11用于为终端12运行的应用程序提供后台服务。
终端12可以是智能手机、平板电脑、笔记本电脑、台式计算机、智能音箱、智能手表等,但并不局限于此。终端12可以安装和运行有支持数据上传的应用程序。该应用程序可以是相册类应用程序、社交类应用程序、购物类应用程序以及检索类应用程序等。示意性的,终端12是用户使用的终端,终端12中运行的应用程序内登录有用户账户。
可选地,服务器11承担主要模型训练工作,终端12承担次要模型训练工作;或者,服务器11承担次要模型训练工作,终端12承担主要模型训练工作;或者,服务器11或终端12分别可以单独承担模型训练工作。
终端12可以泛指多个终端中的一个,本实施例仅以终端12来举例说明。本领域技术人员可以知晓,上述终端的数量可以更多或更少。比如上述终端可以仅为一个,或者上述终端为几十个或几百个,或者更多数量,此时上述图像分类方法的实施环境还包括其他终端。本申请实施例对终端的数量和设备类型不加以限定。
可选的,上述的无线网络或有线网络使用标准通信技术和/或协议。网络通常为因特网、但也可以是任何网络,包括但不限于局域网(Local Area Network,LAN)、城域网(Metropolitan Area Network,MAN)、广域网(Wide Area Network,WAN)、移动、有线或者无线网络、专用网络或者虚拟专用网络的任何组合)。在一些实施例中,使用包括超文本标记语言(Hyper Text Mark-up Language,HTML)、可扩展标记语言(ExtensibleMarkupLanguage,XML)等的技术和/或格式来代表通过网络交换的数据。此外还可以使用诸如安全套接字层(Secure Socket Layer,SSL)、传输层安全(Transport Layer Security,TLS)、虚拟专用网络(Virtual Private Network,VPN)、网际协议安全(InternetProtocolSecurity,IPsec)等常规加密技术来加密所有或者一些链路。在另一些实施例中,还可以使用定制和/或专用数据通信技术取代或者补充上述数据通信技术。
其中,服务器11可以用于:获取教师模型和学生模型;获取第一预测结果,所述第一预测结果由所述教师模型对样本数据集中的样本数据进行预测得到;获取所述教师模型对所述第一预测结果的可信度,所述可信度用于表征所述第一预测结果的可信程度;根据所述可信度更新所述第一预测结果,将更新后的所述第一预测结果作为第二预测结果;获取所述第二预测结果与学生模型预测所述样本数据的结果之间的差异;基于所述差异,更新所述学生模型的参数,以训练所述学生模型。
在一种可选的实现方式中,本申请实施例提供的模型训练方法,能够用于文本分类场景,下面以训练文本分类模型的场景进行介绍。
首先,通过有标注的训练集预训练得到多个教师模型,教师模型的结构较为复杂、性能较强。其中,有标注的训练集中包括多个文本的训练样本数据。然后通过教师模型分别对另一无标注的训练集中的文本进行预测,获取预测得到的第一预测结果,即各文本被教师模型预测的输出结果。例如,第一预测结果可以是该文本的内容分类,例如该文本的内容是财经类、体育类、医疗类等。
然后再确定各教师模型对各第一预测结果的可信度,从而确定各第一预测结果的可信程度。基于各第一预测结果的可信度,确定出多个教师模型对该文本进行预测的唯一的第二预测结果,来训练学生模型。通过第一预测结果的可信度进一步得到更准确的第二预测结果,作为训练学生模型的标签,从而能够在克服教师模型预测过自信,以及教师模型与学生模型异构的基础上,提高学生模型的训练效率和效果。
在一种可选的实现方式中,本申请实施例提供的模型训练方法,还能够应用于情感分类场景,下面以训练情感分类模型的场景进行介绍。
首先,通过有标注的训练集预训练得到多个教师模型,教师模型的结构较为复杂、性能较强。然后通过多个教师模型分别对无标注的训练集进行预测,获取预测得到的第一预测结果,第一预测结果可以为文本对应的情感分类,例如积极情感、中立情感、消极情感等。
然后再确定出各第一预测结果的可信度,从而确定各教师模型对其第一预测结果的可信程度。基于各第一预测结果的可信程度,确定出多个教师模型对该文本进行预测的唯一的第二预测结果,来训练学生模型。
其中,在完成对学生模型的训练之后,可以将该学生模型部署到服务器11中,为终端12提供相应的服务,比如分类服务或分析服务等等。也可以部署到需要该学生模型的其它服务器中,比如部署到人工智能客服中。
其中,以通过该学生模型进行文本分类为例,终端12可以用于获取待分类对象,将该待分类对象发送给服务器11;服务器11接收待分类对象,将该待分类对象输入到训练好的学生模型,通过训练好的学生模型对该待分类对象的分类进行预测,获取学生模型的输出,将其作为分类结果,将该分类结果发送给终端12,通过终端12的显示器显示该分类结果。
上述服务器11执行的步骤,也可以由终端12执行。
以下分别进行详细说明。需说明的是,以下实施例的序号不作为对实施例优选顺序的限定。
本申请实施例提供的模型训练方法可以应用于各种类型的神经网络知识融合应用场景中,通过多个教师模型对学生模型进行训练。比如,可以训练用于文本分类的学生模型,可以将训练好的学生模型部署在线上阅读的相关应用中,以实现对不同类型的线上阅读内容进行分类;又例如,还可以将该模型训练方法用于智能客服中,通过部署训练好的学生模对用户输入的文本内容进行语义分析、分类,为用户提供准确的回复内容。
本实施例将从模型训练装置的角度进行描述,该模型训练装置具体可以集成在电子设备中,该电子设备可以是服务器或终端等设备。
如图1b所示,该模型训练方法的具体流程可以如下:
110、获取教师模型和学生模型。
其中,教师模型(Teacher Model)是指在知识融合中具有指导作用和参考作用的模型,学生模型(Student Model)是指在知识融合中待融合的目标模型。具体地,可以通过多个教师模型对学生模型进行知识融合,也即,以教师模型作为参考对学生模型进行知识融合的训练,使学生模型能够得到更好的性能和通用性。
其中,教师模型和学生模型可以是各种神经网络模型,例如循环神经网络(Recurrent Neural Network,RNN)、卷积神经网络(Convolutional Neural Networks,CNN)、深度神经网络(Deep Neural Networks,DNN)等,本实施例对此不作限制。
其中,教师模型可以是由其他设备进行训练后,发送给该模型训练装置。例如,可以获取由多对训练样本及其标签组成的训练集,使用该训练集对初始模型进行预训练,待初始模型完成预训练后,得到性能较好的教师模型,并将其发送给该模型训练装置。
其中,教师网络在预训练后,可以存储在数据库中,例如可以存储在区块链的共享账本中。在需要对学生模型进行训练时,可以从区块链中选取多个与该学生网络作用相同或相近且高性能的模型,将其作为教师模型,并从区块链的共享账本中调用上述多个教师模型,提供给该模型训练装置。
其中,学生模型可以通过系统开发者在其他设备上构建后,发送给该模型训练装置。例如,系统开发者可以根据模型的功能需求和基础数据,构建出模型的初始参数,经过对初始参数中的部分参数进行适当调整后,使得该模型能够执行目标任务或类似的基础任务后,获取该模型作为学生模型。由于学生模型中的模型参数为初始值或简单调整后的参数值,使得学生模型的性能较低,需要对学生模型进行训练。
其中,教师模型和学生模型还可以通过深度学习得到,深度学习为通过建立具有阶层结构的神经网络,在计算系统中实现人工智能的机器学习。由于具有阶层结构的神经网络能够对输入信息进行逐层提取和筛选,因此深度学习具有表征学习能力,可以实现端到端的监督学习和非监督学习。
需要说明的是,使用教师模型对学生模型进行训练的过程中,教师模型的参数是固定不变的,学生模型的参数可以随着训练的进行而迭代更新。此外,可以获取任意数量的教师模型对学生模型进行训练,本实施例对此不作限制。
120、获取第一预测结果。
其中,第一预测结果是指教师模型对样本数据集中的样本数据进行预测得到的输出结果。
其中,样本数据是指训练学生模型所用到的训练样本,样本数据集是指多个样本数据所组成的集合。
其中,样本数据可以是随机抽取的无标签文本数据,例如一个句子、一段文本、一篇文章等。其中,第一预测结果可以是教师模型对样本数据进行预测得到的分类结果。例如,分类结果可以是该文本数据对应的领域分类、情感分类、意图分类等。
在一具体场景中,以第一预测结果是样本数据的领域分类为例,样本数据可以为一段描述医疗器材的文本,例如该文本的内容为“随着科技的进步,现在的主流助听器都是带有芯片的数字助听器,可以根据佩戴者的听力检测结果通过电脑编程来调整参数”。则教师模型可以预测出该样本数据的第一预测结果为:医疗领域,或者预测出的第一预测结果为:80%为医疗领域、20%为财经领域。
在一具体场景中,以第一预测结果是样本数据的情感分类为例,样本数据可以为一段商品订单评价的文本,例如该文本的内容为:“该商品的质量很好、性价比高,非常值得推荐”。则教师模型可以预测出该样本数据的第一预测结果为:积极情感,或者预测出的第一预测结果为:85%为积极情感、10%为消息情感、5%为中立情感。
其中,第一预测结果中包含的单个分类可以称为独立类别,包含的多个分类可以称为联合类别。可以理解的是,为保证对学生模型训练的全面性和通用性,获取的教师模型通常是用于预测出联合类别的模型,因此第一预测结果相当于教师模型对样本数据进行预测,得到样本数据在联合类别集合中的概率分布。
需要说明的是,对于联合类别集合的概率分布,其包含的独立类别可以继续向下扩展转换为新的联合类别。例如,独立类别为医疗领域,可以将其继续向下扩展为“临床医疗领域”和“科研医疗领域”,又例如独立类别为积极情感,可以将其继续向下扩展为“持续积极情感”和“短暂积极情感”。对于向下扩展的维度和数量,本实施例不作限制。
在一具体场景中,获取的多个教师模型可以分别用于预测不同的联合类别,例如,教师模型A可以用于预测医疗领域和财经领域的联合类别,教师模型B可以用于预测体育领域和娱乐领域的联合类别。
由此,在本申请实施例中,通过使用不同预测能力的教师模型对学生模型进行知识融合的训练,能够有效地提升对学生模型训练的通用性和全面性,从而能够提升对学生模型的训练效果;并结合大量随机抽取的无标签样本数据对学生模型进行训练,无需对样本数据进行标注、样本数据获取方便,从而能够提升对学生模型的训练效率。
130、获取所述教师模型对所述第一预测结果的可信度。
其中,可信度可以用于表征第一预测结果的可信程度。具体地,可信度为“uncertainty”,用于表征第一预测结果的不确定性。其中,可信度可以具有对应的值,例如0.1、
Figure BDA0003395256630000111
30%等。可以理解的是,可信度的值越大,则表示可信程度越低。
可以理解的是,虽然教师模型的性能较好,能够对属于自身预测类别的样本数据进行高准确率预测。但对于不是自身预测类别的样本数据,教师模型可能仍会给出该样本数据在自身预测类别的较高概率分布,导致预测结果不准确,该情况称为教师模型的过自信预测。
需要说明的是,可以通过比较样本数据在联合类别上的真实概率分布,与教师模型输出的第一预测结果之间的关系,来确定教师模型对第一预测结果的可信度。
在一具体实施例中,可以获取N个教师模型T1至Tn,以及样本数据集D中的任意一个样本数据x。对于其中任意一个教师模型Ti,其能够在类别集合Yi对样本数据x进行预测;对于所有教师模型,能够在联合类别集合
Figure BDA0003395256630000121
对样本数据x进行预测。
其中,教师模型Ti的预测结果与样本数据x在联合类别集合Y中的真实概率分布y之间的关系可以用公式(1)表示:
Figure BDA0003395256630000122
其中,
Figure BDA0003395256630000123
表示对样本数据x进行预测,预测出样本数据x为y的概率,
Figure BDA0003395256630000124
表示y属于类别集合Yi的概率。
根据贝叶斯定理,可以进一步得到公式(2):
Figure BDA0003395256630000125
其中,
Figure BDA0003395256630000126
表示样本数据x属于类别集合Yi的真实概率,
Figure BDA0003395256630000127
表示预测出样本数据x属于类别集合Yi的概率。
可以理解的是,只需比较样本数据x属于类别集合Yi的预测概率与真实概率之间的差异,即可得到教师模型对第一预测结果的可信度,进而判断出教师模型是否进行过自信预测。但是在本申请实施例中,样本数据是随机抽取的无标签数据,因此无法直接通过上述方式获取可信度。
为避免出现教师模型的过自信预测,可以获取教师模型对第一预测结果的可信度。具体地,步骤“获取所述教师模型对所述第一预测结果的可信度”,可以包括:
获取第一概率值,所述第一概率值表征所述教师模型在目标状态下对所述样本数据进行预测的准确率;其中,在所述目标状态下的所述教师模型中包括至少一个被随机掩盖的神经节点;
计算所述第一概率值的熵,将所述熵确定为所述教师模型对所述第一预测结果的可信度,返回并执行获取所述第一概率值。
其中,步骤“获取第一概率值”,可以包括:
获取每个教师模型的第二概率值和预测次数,所述第二概率值表征所述每个教师模型的在不同目标状态下对所述样本数据进行预测的准确率;
将所述预测次数对应数量的所述第二概率值进行求和运算,得到第二概率值总和;
将所述得到第二概率值总和与所述预测次数进行相除,得到所述第一概率值。
其中,步骤“计算所述第一概率值的熵”,可以包括:
获取所述每个第一预测结果的类别数量,所述类别数量表征所述每个第一预测结果包含预测类别的数目;
计算所述第一概率值与其对数的乘积值;
将所述预测次数对应数量的所述乘积值进行求和运算,得到所述乘积值总和,将所述乘积值总和的负数作为所述概率值的熵。
其中,可以通过M-C Dropout(Monte-Carlo Dropout)算法来使教师模型处于目标状态,通过获取教师模型在目标状态下对样本数据预测的准确率,得到第一概率值和第二概率值。
如图1e所示,教师模型的结构中可以包括多个神经节点13,每次对教师模型执行M-C Dropout算法时,可以将至少一个神经节点13进行随机遮盖,得到被遮盖的失活神经节点14,失活神经节点14不参与教师模型对样本数据x的预测。其中,随机遮盖可以包括对随机数量的神经节点13进行遮盖;也可以包括在不变的预设数量下,对随机位置的神经节点13进行遮盖。
其中,第二概率值可以表征每个教师模型的在不同目标状态下对样本数据进行预测的准确率。其中,不同目标状态的个数可以取决于M-C Dropout算法的执行次数,每个目标状态均对应一个第二概率值,也即,对教师模型每执行一次M-C Dropout算法,即可得到一个第二概率值。
其中,第一概率值可以根据第二概率值得到的概率。例如,可以对第二概率值进行统计、处理后,得到第一概率值。
具体地,可以计算教师模型在目标状态下对样本数据进行预测的准确率,将该准确率作为第一概率值。需要说明的是,由于本申请实施例中的样本数据为无标签数据,因此该准确率需要人工记录。
具体地,第一概率值可以用公式(3)表示:
Figure BDA0003395256630000141
其中,
Figure BDA0003395256630000142
表示教师模型Ti的第一概率值,
Figure BDA0003395256630000143
表示教师模型Ti的第二概率值,
Figure BDA0003395256630000144
表示对教师模型Ti执行M-C Dropout算法进行第K次采样的参数,
Figure BDA0003395256630000145
Figure BDA0003395256630000146
表示对教师模型Ti在目标状态下进行K次预测的准确率总和进行求平均。
相应的,可以将第一概率值的熵确定为教师模型对第一预测结果的可信度,可以用公式(4)表示:
Figure BDA0003395256630000147
其中,ui表示教师模型Ti对第一预测结果的可信度,H(pi)表示第一概率值的熵,|Yi|表示类别集合Yi所包含独立类别的个数,
Figure BDA0003395256630000148
表示教师模型Ti在目标状态下预测类别v的第一概率值。
其中,公式(4)中的可信度u以具体数值呈现,也即,计算第一概率值的熵相当于对可信度继续赋值。
可选的,对于N个教师模型T1至Tn,可以重复执行上述步骤“获取第一概率值”和步骤“计算第一概率值的熵”,并结合公式(3)和公式(4)获取教师模型T1至Tn对应的可信度u1至un
由上可知,本申请实施例可以通过随机抽取的无标签样本数据,使用M-C Dropout算法获取教师模型对第一预测结果的可信度,无需对样本数据进行标注、获取样本数据方便,并且可以通过可信度对教师模型的预测能力和模型性能进行判断,以避免因教师模型的过自信预测导致学生模型的训练效果不佳。
140、根据所述可信度更新所述第一预测结果,将更新后的所述第一预测结果作为第二预测结果。
第二预测结果可以是对第一预测结果进行处理后的结果。例如,可以对第一预测结果进行加权、求平均、筛选、修改数值等处理,得到第二预测结果。对于第一预测结果的处理方式,本实施例对此不作限制。
其中,第二预测结果可以作为对学生模型进行训练的标签数据。可以理解的是,由于获取的教师模型本身具有较好的性能,在对其第一预测结果进行优化处理后,得到的第二预测结果对学生模型的训练具有较高的参考性、指导性,因此将其作为训练学生模型的标签数据,能够更好地提升学生模型的性能。
具体地,步骤“根据所述可信度更新所述第一预测结果,将更新后的所述第一预测结果作为第二预测结果”,可以包括:
根据每个第一预测结果的可信度,计算所述每个第一预测结果对应的加权值;
对所述每个第一预测结果结合所述加权值进行加权求和,得到更新后的所述第一预测结果作为第二预测结果。
具体地,步骤“根据每个第一预测结果的可信度,计算所述每个第一预测结果对应的加权值”,可以包括:
获取所述每个第一预测结果的可信度参数,以及所述第一预测结果的可信度参数总和;
计算所述每个第一预测结果的可信度参数在所述可信度参数总和中的占比,将所述占比确定为所述每个第一预测结果对应的加权值。
具体地,步骤“获取所述每个第一预测结果的可信度参数”,可以包括:
获取所述每个第一预测结果的类别数量,所述类别数量表征所述每个第一预测结果包含预测类别的数目;
计算所述每个第一预测结果的可信度与所述类别数量的对数的比值;
将预设值与所述比值的差值确定为所述每个第一预测结果的可信度参数。
其中,可信度参数是指在计算加权值的过程中,用到的与可信度相关的中间参数。具体地,可信度参数可以用公式(5)表示:
ci=1-ui/log|Yi| (5);
其中,ci表示教师模型Ti的可信度参数,ui表示教师模型Ti对第一预测结果的可信度,|Yi|表示类别集合Yi所包含独立类别的个数。
其中,可信度参数与可信度的表征相似,也可以用于表征教师模型对样本数据预测的可信程度。需要说明的是,可信度参数的值越大,表示对象的可信程度越高,相当于机器学习领域中的置信度。
可选的,步骤“计算所述每个第一预测结果的可信度参数在所述可信度参数总和中的占比”,可以包括:
对所述每个第一预测结果的可信度参数做e次幂运算,得到第一参数;
对所述可信度参数总和做e次幂运算,得到第二参数;
对所述第一参数与第二参数进行相除,得到所述占比。
具体地,该占比可以用公式(6)表示:
Figure BDA0003395256630000161
其中,wi表示教师模型Ti的第一预测结果对应的加权值,cj表示第j个教师模型的可信度参数,τ为计算加权值用到的超参数(通常预设τ=1),exp表示以e为底的指数运算符号,exp(ci/τ)表示计算加权值用到的第一参数,
Figure BDA0003395256630000162
表示计算加权值用到的第二参数。
具体地,在计算出各教师模型对各自第一预测结果的加权值后,第二预测结果可以用公式(7)表示:
Figure BDA0003395256630000163
其中,
Figure BDA0003395256630000164
表示综合多个教师模型输出后获取的第二预测结果,Ti(x)为第i个教师模型输出的第一预测结果,N表示获取教师模型的总数。
在一具体场景中,请参阅图1c,图1c是本申请实施例提供的训练学生模型的过程示意图。
如图1c所示,可以获取两个教师模型对学生模型进行训练,第一个教师模型对样本数据的第一预测结果为y1和y2,第二个教师模型对样本数据的第一预测结果用y3、y4、y5。例如,上述y1至y5可以表示文本属于5个不同的领域类别、情感类别、意图类别等,本实施例对此不作限制。
假设第一个教师模型的第一预测结果为该样本数据属于y1的概率为30%、属于y2的概率为70%,假设第二个教师模型的第一预测结果为该样本数据属于y3的概率为10%、属于y4的概率为35%、属于y5的概率为55%。
具体地,可以计算出上述两个教师模型对第一预测结果的可信度分别为u1为0.2、u2为0.8。可以理解的是,由于本申请实施例中的可信度表征第一预测结果的不确定性,因此相较于第二个教师模型,第一个教师模型的第一预测结果的可信程度较高。由此,可以认为该样本数据的类别接近或属于第一个教师模型的可预测类别。还可以理解的是,若不同教师模型之间的可信度相差越大,越能够推测出该样本数据本身与某个教师模型的可预测类别相同或越接近,使得该样本数据对于学生模型训练的指导性、倾向性更强。
具体地,可以根据上述公式(5)、(6)、(7)计算出对该样本数据的第二预测结果,即,对两个教师模型的第一预测结果进行加权求和后的预测结果。例如,最终求得的第二预测结果可以为:该样本数属于y1的概率为10%、属于y2的概率为50%、y3的概率为5%、属于y4的概率为15%、属于y5的概率为20%。可以看出,相较于第二个教师模型,由于第一个教师模型的可信度较小,使得第一个教师模型可以获得更大的加权值,在该第二预测结果中,即可呈现出对第一个教师模型的第一预测结果更可信。
可选的,步骤“根据所述可信度更新所述第一预测结果,将更新后的所述第一预测结果作为第二预测结果”,可以包括:
将不同的所述可信度中的每个可信度进行逐一对比;
将最小可信度对应的第一预测结果,确定为所述第二预测结果。
其中,对于最小可信度以外对应的第一预测结果,其第二预测结果中概率分布可以用0表示。
具体地,本实施例中的第二预测结果可以用公式(8)表示:
Figure BDA0003395256630000171
其中,
Figure BDA0003395256630000173
表示第i*个教师模型对样本数据x的第一预测结果,i*表示目标教师模型的代号,
Figure BDA0003395256630000172
表示最小可信度对应的教师模型,/log|Yi|表示目标输出的激活函数,|Yi|表示类别集合Yi所包含独立类别的个数。
在一具体场景中,请参阅图1d,图1d是本申请实施例提供的训练学生模型的过程示意图。
如图1d所示,继续采用图1c中的示例,可以获取两个教师模型对学生模型进行训练,第一个教师模型对样本数据的第一预测结果为y1和y2,第二个教师模型对样本数据的第一预测结果用y3、y4、y5。仍假设第一个教师模型的第一预测结果为该样本数据属于y1的概率为30%、属于y2的概率为70%,假设第二个教师模型的第一预测结果为该样本数据属于y3的概率为10%、属于y4的概率为35%、属于y5的概率为55%。上述两个教师模型对第一预测结果的可信度分别为u1为0.2、u2为0.8。
具体地,相较于第一个教师模型,第一个教师模型对第一预测结果的可信度较低,可以直接采用第一个教师模型的第一预测结果作为第二预测结果,丢弃第二个教师模型的第一预测结果,也即,将第二个教师模型的第一预测结果的概率设置为0。例如,最终得到的第二预测结果可以为该样本数据属于y1的概率为30%、属于y2的概率为70%,假设第二个教师模型的第一预测结果为该样本数据属于y3的概率为0、属于y4的概率为0、属于y5的概率为0。
可以理解的是,在该示例中,由于第一个教师模型的第一预测结果的可信度较小,可以认为该样本数据的类别接近或属于第一个教师模型的可预测类别,第一个教师模型的输出相较于第二个教师模型的输出更准确、指导性和参考性更强,因此可以直接将第一个教师模型的第一预测结果确定为训练学生模型的标签数据,从而后续能够提升对学生模型的训练效果和效率。
150、获取所述第二预测结果与学生模型预测所述样本数据的结果之间的差异。
其中,差异可以是指第二预测结果表征的概率分布,与学生模型预测样本数据的概率分布之间的差值。
如图1c和图1d所示,第二预测结果与学生模型的预测结果是不同的,可以计算两者之间的差值得到该差异。例如,以独立类别y1为例,图1c的第二预测结果中y1的概率为10%,学生模型的预测结果中y1的概率可以为12%,则该差异可以为2%或-2%。相应的,可以计算出学生模型的预测结果与第二预测结果中y2、y3、y4和y5之间的差异。
160、基于所述差异,更新所述学生模型的参数,以训练所述学生模型。
其中,学生模型的参数可以包括神经元的权值、偏置量等。
其中,可以通过反向传播(Backpropagation)算法对学生模型的参数进行更新,得到更新后的学生模型的参数。具体地,在反向传播过程中,可以通过对学生模型参数的修改,改善学生模型的损失函数。其中,当学生模型对样本数据的预测结果与第二预测结果的差异足够小时,或当损失函数小于预设阈值,或迭代次数大于预设次数时,停止训练该学生模型。
其中,损失函数可以为KL散度损失函数,KL散度损失函数可以由学生模型的预测结果与第二预测结果之间的相对熵来构建。
具体的,训练学生模型的KL散度损失函数可以用公式(9)表示:
Figure BDA0003395256630000191
其中,
Figure BDA0003395256630000192
表示学生模型的预测结果与第二预测结果之间的相对熵,S(x)表示学生模型对样本数据x的预测结果,T(x)表示教师模型对样本数据x的第二预测结果。
可以理解的是,该KL散度损失函数可以实现后续不断调整、更新学生模型的参数,不断降低相对熵的大小,直至相对熵收敛。
其中,损失函数还可以是其他损失函数,例如均方误差损失函数、JS散度损失函数等,本实施例对此不作限制。
需要说明的是,对于训练好的学生模型,可以将其理解为是集成了多个教师模型能力后的“全能模型”,该全能模型的参数量小于参与训练教师模型参数量的总和,从而实现降低部署开销的目的。
由上可知,本申请实施例中可以通过获取多个教师模型对第一预测结果的可信度,判断出各教师模型对第一预测结果的可信程度,并根据可信度对第一预测结果进行优化处理得到第二预测结果。由于可信度能够表征教师模型对第一预测结果的可信程度,使得第二预测结果相较于第一预测结果更准确,使用第二预测结果作为训练学生模型的标签数据或监督信号,能够避免教师模型的过自信预测。
并且,本申请实施例的方法只需借助可信度对教师模型的预测结果进行调整,无需考虑各教师模型之间、教师模型与学生模型之间的结构差异,例如在现有对模型的知识融合中,需要将学生模型的隐层输出向量与多个教师模型的输出向量进行对齐,导致对学生模型进行知识融合的效率低、泛化性较差。从而本申请实施例能够在克服教师模型的过自信预测,以及模型之间结构差异的基础上,提升对学生模型的训练效果和效率。
根据上述实施例所描述的方法,以下将作进一步详细说明。
在本实施例中,将以该模型训练装置具体集成在服务器为例,对本申请实施例的方法进行详细说明。
如图2所示,一种模型训练方法具体流程如下:
210、获取教师模型和学生模型。
其中,教师模型可以是知识融合中具有指导作用的参考模型,学生模型可以是知识融合中待融合的目标模型。
其中,教师模型可以是由其他设备进行训练后,发送服务器。学生模型可以通过系统开发者在其他设备上构建后,发送给服务器。
220、获取第一预测结果。
其中,第一预测结果可以由教师模型对样本数据集中的样本数据进行预测得到。
其中,样本数据是指训练学生模型所用到的训练样本,样本数据集是指多个样本数据所组成的集合。
其中,样本数据可以是随机抽取的无标签文本数据,例如一个句子、一段文本、一篇文章等。其中,第一预测结果可以是教师模型对样本数据进行预测得到的分类结果。例如,分类结果可以是该文本数据对应的领域分类、情感分类、意图分类等。
230、根据所述教师模型输出的第一概率值,获取所述教师模型对所述第一预测结果的可信度。
可选地,本实施例中,步骤“获取所述教师模型对所述第一预测结果的可信度”,可以包括:
获取第一概率值,所述第一概率值表征所述教师模型在目标状态下对所述样本数据进行预测的准确率;其中,在所述目标状态下的所述教师模型中包括至少一个被随机掩盖的神经节点;
计算所述第一概率值的熵,将所述熵确定为所述教师模型对所述第一预测结果的可信度,返回并执行获取所述第一概率值。
可选地,本实施例中,步骤“获取第一概率值”,可以包括:
获取每个教师模型的第二概率值和预测次数,所述第二概率值表征所述每个教师模型的在不同目标状态下对所述样本数据进行预测的准确率;
将所述预测次数对应数量的所述第二概率值进行求和运算,得到第二概率值总和;
将所述得到第二概率值总和与所述预测次数进行相除,得到所述第一概率值。
可选地,本实施例中,步骤“计算所述第一概率值的熵”,可以包括:
获取所述每个第一预测结果的类别数量,所述类别数量表征所述每个第一预测结果包含预测类别的数目;
计算所述第一概率值与其对数的乘积值;
将所述预测次数对应数量的所述乘积值进行求和运算,得到所述乘积值总和,将所述乘积值总和的负数作为所述概率值的熵。
其中,具体地,可以通过M-C Dropout(Monte-Carlo Dropout)算法来使教师模型处于目标状态,通过获取教师模型在目标状态下对样本数据预测的准确率,得到第一概率值和第二概率值,并将第一概率值的熵确定为教师模型对第一预测结果的可信度。
241、根据多个所述可信度之间的比较结果,确定第二预测结果。
其中,比较结果可以是多个可信度之间的大小。具体地,可以将不同的可信度中的每个可信度进行逐一对比,得到比较结果。可以理解的是,通过逐一对比的方式,可以获取所有可信度中的最小可信度,进而能够获取可信程度最高的第一预测结果。
其中,具体地,可以将最小可信度对应的第一预测结果,确定为所述第二预测结果。相应地,对于最小可信度以外对应的第一预测结果,其第二预测结果中概率分布可以用0表示。
可选的,在步骤230后,还可以将步骤241替换为步骤242。
242、根据所述可信度对每个第一预测结果进行加权,确定第二预测结果。
其中,加权可以是对每个第一预测结果结合加权值进行加权求和。
可选地,本实施例中,加权值可以通过以下步骤得到:
获取所述每个第一预测结果的可信度参数,以及所述第一预测结果的可信度参数总和;
计算所述每个第一预测结果的可信度参数在所述可信度参数总和中的占比,将所述占比确定为所述每个第一预测结果对应的加权值。
其中,步骤“计算所述每个第一预测结果的可信度参数在所述可信度参数总和中的占比”,可以包括:
对所述每个第一预测结果的可信度参数做e次幂运算,得到第一参数;
对所述可信度参数总和做e次幂运算,得到第二参数;
对所述第一参数与第二参数进行相除,得到所述占比。
250、获取所述第二预测结果与学生模型预测所述样本数据的结果之间的差异。
其中,差异可以是指第二预测结果表征的概率分布,与学生模型预测样本数据的概率分布之间的差值。
260、基于所述差异,更新所述学生模型的参数,以训练所述学生模型。
其中,学生模型的参数可以包括神经元的权值、偏置量等。
其中,可以通过反向传播算法对学生模型的参数进行更新,得到更新后的学生模型的参数。具体地,在反向传播过程中,可以通过对学生模型参数的修改,改善学生模型的损失函数。
其中,损失函数可以为KL散度损失函数,KL散度损失函数可以由学生模型的预测结果与第二预测结果之间的相对熵来构建。
由上可知,本申请实施例可以本实施例可以通过服务器获取教师模型和学生模型;获取第一预测结果;根据所述教师模型输出的第一概率值,获取所述教师模型对所述第一预测结果的可信度;根据多个所述可信度之间的比较结果,确定第二预测结果,或者根据所述可信度对每个第一预测结果进行加权,确定第二预测结果;获取所述第二预测结果与学生模型预测所述样本数据的结果之间的差异;基于所述差异,更新所述学生模型的参数,以训练所述学生模型。由此,本申请实施例可以通过教师模型对样本数据预测的可信度,准确地确定出样本数据对应的标签,从而能够在克服教师模型过自信预测和结构差异的基础上,提升对学生模型的训练效果和效率。
为了更好地实施以上方法,本申请实施例还提供一种模型训练装置,该模型训练装置具体可以集成在电子设备中,该电子设备可以为终端、服务器等设备。其中,终端可以为手机、平板电脑、智能蓝牙设备、笔记本电脑、个人电脑等设备;服务器可以是单一服务器,也可以是由多个服务器组成的服务器集群。
比如,在本实施例中,将以模型训练装置具体集成在服务器为例,对本申请实施例的方法进行详细说明。
例如,如图3所示,该模型训练装置可以包括模型获取单元301、第一预测结果获取单元302、可信度获取单元303、第二预测结果获取单元304、差异获取单元305以及模型训练单元306,如下:
(一)模型获取单元301;
模型获取单元301,用于获取教师模型和学生模型。
(二)第一预测结果获取单元302;
第一预测结果获取单元302,用于根据所述可信度更新所述第一预测结果,将更新后的所述第一预测结果作为第二预测结果。
(三)可信度获取单元303;
可信度获取单元303,用于获取所述教师模型对所述第一预测结果的可信度,所述可信度用于表征所述第一预测结果的可信程度。
可选的,在一些实施例中,所述可信度获取单元303可以包括获取子单元和第三计算子单元,如下:
获取子单元,用于获取第一概率值,所述第一概率值表征所述教师模型在目标状态下对所述样本数据进行预测的准确率;其中,在所述目标状态下的所述教师模型中包括至少一个被随机掩盖的神经节点;
第三计算子单元,用于计算所述第一概率值的熵,将所述熵确定为所述教师模型对所述第一预测结果的可信度,返回并执行获取所述第一概率值。
可选的,在一些实施例中,所述获取子单元具体可以用于获取每个教师模型的第二概率值和预测次数,所述第二概率值表征所述每个教师模型的在不同目标状态下对所述样本数据进行预测的准确率;将所述预测次数对应数量的所述第二概率值进行求和运算,得到第二概率值总和;将所述得到第二概率值总和与所述预测次数进行相除,得到所述第一概率值。
可选的,在一些实施例中,所述第三计算子单元具体可以用于获取所述每个第一预测结果的类别数量,所述类别数量表征所述每个第一预测结果包含预测类别的数目;计算所述第一概率值与其对数的乘积值;将所述预测次数对应数量的所述乘积值进行求和运算,得到所述乘积值总和,将所述乘积值总和的负数作为所述概率值的熵。
(四)第二预测结果获取单元304;
第二预测结果获取单元304,用于根据所述可信度更新所述第一预测结果,将更新后的所述第一预测结果作为第二预测结果。
可选的,在一些实施例中,所述第二预测结果获取单元304可以包括第一计算子单元和加权子单元,如下:
第一计算子单元,用于根据每个第一预测结果的可信度,计算所述每个第一预测结果对应的加权值;
第二计算子单元,用于对所述每个第一预测结果结合所述加权值进行加权求和,得到更新后的所述第一预测结果作为第二预测结果。
可选的,在一些实施例中,所述第一计算子单元具体可以用于获取所述每个第一预测结果的可信度参数,以及所述第一预测结果的可信度参数总和;计算所述每个第一预测结果的可信度参数在所述可信度参数总和中的占比,将所述占比确定为所述每个第一预测结果对应的加权值。
可选的,在一些实施例中,所述第一计算子单元具体可以用于获取所述每个第一预测结果的类别数量,所述类别数量表征所述每个第一预测结果包含预测类别的数目;计算所述每个第一预测结果的可信度与所述类别数量的对数的比值;将预设值与所述比值的差值确定为所述每个第一预测结果的可信度参数。
可选的,在一些实施例中,所述第一计算子单元具体可以用于对所述每个第一预测结果的可信度参数做e次幂运算,得到第一参数;对所述可信度参数总和做e次幂运算,得到第二参数;对所述第一参数与第二参数进行相除,得到所述占比。
可选的,在一些实施例中,所述第二预测结果获取单元304还可以包括对比子单元和结果确定子单元,如下:
对比子单元,用于将不同的所述可信度中的每个可信度进行逐一对比;
结果确定子单元,用于将最小可信度对应的第一预测结果,确定为所述第二预测结果。
(五)差异获取单元305;
差异获取单元305,用于获取所述第二预测结果与学生模型预测所述样本数据的结果之间的差异。
(六)模型训练单元306;
模型训练单元306,用于基于所述差异,更新所述学生模型的参数,以训练所述学生模型。
具体实施时,以上各个单元可以作为独立的实体来实现,也可以进行任意组合,作为同一或若干个实体来实现,以上各个单元的具体实施可参见前面的方法实施例,在此不再赘述。
由上可知,本实施例的模型训练装置可以由模型获取单元301获取教师模型和学生模型;由第一预测结果获取单元302获取第一预测结果,所述第一预测结果由所述教师模型对样本数据集中的样本数据进行预测得到;由可信度获取单元303获取所述教师模型对所述第一预测结果的可信度,所述可信度用于表征所述第一预测结果的可信程度;由第二预测结果获取单元304根据所述可信度更新所述第一预测结果,将更新后的所述第一预测结果作为第二预测结果;由差异获取单元305获取所述第二预测结果与学生模型预测所述样本数据的结果之间的差异;由模型训练单元306基于所述差异,更新所述学生模型的参数,以训练所述学生模型。
由此,本申请实施例可以通过教师模型对样本数据预测的可信度,准确地确定出样本数据对应的标签,从而能够在克服教师模型过自信预测和结构差异的基础上,提升对学生模型的训练效果和效率。
本申请实施例还提供一种电子设备,该电子设备可以为终端、服务器等设备。其中,终端可以为手机、平板电脑、智能蓝牙设备、笔记本电脑、个人电脑,等等;服务器可以是单一服务器,也可以是由多个服务器组成的服务器集群,等等。
在一些实施例中,该模型训练装置还可以集成在多个电子设备中,比如,模型训练装置可以集成在多个服务器中,由多个服务器来实现本申请的模型训练方法。
在本实施例中,将以本实施例的电子设备为例进行详细描述,比如,如图4所示,其示出了本申请实施例所涉及的电子设备的结构示意图,具体来讲:
该电子设备可以包括一个或者一个以上处理核心的处理器401、一个或一个以上计算机可读存储介质的存储器402、电源403、输入模块404以及通信模块405等部件。本领域技术人员可以理解,图4中示出的电子设备结构并不构成对电子设备的限定,可以包括比图示更多或更少的部件,或者组合某些部件,或者不同的部件布置。其中:
处理器401是该电子设备的控制中心,利用各种接口和线路连接整个电子设备的各个部分,通过运行或执行存储在存储器402内的软件程序和/或模块,以及调用存储在存储器402内的数据,执行电子设备的各种功能和处理数据,从而对电子设备进行整体监控。在一些实施例中,处理器401可包括一个或多个处理核心;在一些实施例中,处理器401可集成应用处理器和调制解调处理器,其中,应用处理器主要处理操作系统、用户界面和应用程序等,调制解调处理器主要处理无线通信。可以理解的是,上述调制解调处理器也可以不集成到处理器401中。
存储器402可用于存储软件程序以及模块,处理器401通过运行存储在存储器402的软件程序以及模块,从而执行各种功能应用以及数据处理。存储器402可主要包括存储程序区和存储数据区,其中,存储程序区可存储操作系统、至少一个功能所需的应用程序(比如声音播放功能、图像播放功能等)等;存储数据区可存储根据电子设备的使用所创建的数据等。此外,存储器402可以包括高速随机存取存储器,还可以包括非易失性存储器,例如至少一个磁盘存储器件、闪存器件、或其他易失性固态存储器件。相应地,存储器402还可以包括存储器控制器,以提供处理器401对存储器402的访问。
电子设备还包括给各个部件供电的电源403,在一些实施例中,电源403可以通过电源管理系统与处理器401逻辑相连,从而通过电源管理系统实现管理充电、放电、以及功耗管理等功能。电源403还可以包括一个或一个以上的直流或交流电源、再充电系统、电源故障检测电路、电源转换器或者逆变器、电源状态指示器等任意组件。
该电子设备还可包括输入模块404,该输入模块404可用于接收输入的数字或字符信息,以及产生与用户设置以及功能控制有关的键盘、鼠标、操作杆、光学或者轨迹球信号输入。
该电子设备还可包括通信模块405,在一些实施例中通信模块405可以包括无线模块,电子设备可以通过该通信模块405的无线模块进行短距离无线传输,从而为用户提供了无线的宽带互联网访问。比如,该通信模块405可以用于帮助用户收发电子邮件、浏览网页和访问流式媒体等。
尽管未示出,电子设备还可以包括显示单元等,在此不再赘述。具体在本实施例中,电子设备中的处理器401会按照如下的指令,将一个或一个以上的应用程序的进程对应的可执行文件加载到存储器402中,并由处理器401来运行存储在存储器402中的应用程序,从而实现各种功能,如下:
获取教师模型和学生模型;获取第一预测结果,所述第一预测结果由所述教师模型对样本数据集中的样本数据进行预测得到;获取所述教师模型对所述第一预测结果的可信度,所述可信度用于表征所述第一预测结果的可信程度;根据所述可信度更新所述第一预测结果,将更新后的所述第一预测结果作为第二预测结果;获取所述第二预测结果与学生模型预测所述样本数据的结果之间的差异;基于所述差异,更新所述学生模型的参数,以训练所述学生模型。
在一些实施例中,还提出一种计算机程序产品,包括计算机程序或指令,该计算机程序或指令被处理器执行时实现上述任一种模型训练方法中的步骤。
以上各个操作的具体实施可参见前面的实施例,在此不再赘述。
由上可知,本申请实施例可以通过教师模型对样本数据预测的可信度,准确地确定出样本数据对应的标签,从而能够在克服教师模型过自信预测和结构差异的基础上,提升对学生模型的训练效果和效率。
本领域普通技术人员可以理解,上述实施例的各种方法中的全部或部分步骤可以通过指令来完成,或通过指令控制相关的硬件来完成,该指令可以存储于一计算机可读存储介质中,并由处理器进行加载和执行。
为此,本申请实施例提供一种计算机可读存储介质,其中存储有多条指令,该指令能够被处理器进行加载,以执行本申请实施例所提供的任一种模型训练方法中的步骤。例如,该指令可以执行如下步骤:
获取教师模型和学生模型;获取第一预测结果,所述第一预测结果由所述教师模型对样本数据集中的样本数据进行预测得到;获取所述教师模型对所述第一预测结果的可信度,所述可信度用于表征所述第一预测结果的可信程度;根据所述可信度更新所述第一预测结果,将更新后的所述第一预测结果作为第二预测结果;获取所述第二预测结果与学生模型预测所述样本数据的结果之间的差异;基于所述差异,更新所述学生模型的参数,以训练所述学生模型。
其中,该存储介质可以包括:只读存储器(ROM,Read Only Memory)、随机存取记忆体(RAM,Random Access Memory)、磁盘或光盘等。
根据本申请的一个方面,提供了一种计算机程序产品或计算机程序,该计算机程序产品或计算机程序包括计算机指令,该计算机指令存储在计算机可读存储介质中。计算机设备的处理器从计算机可读存储介质读取该计算机指令,处理器执行该计算机指令,使得该计算机设备执行上述实施例中提供的模型训练方面的各种可选实现方式中提供的方法。
由于该存储介质中所存储的指令,可以执行本申请实施例所提供的任一种模型训练方法中的步骤,因此,可以实现本申请实施例所提供的任一种模型训练方法所能实现的有益效果,详见前面的实施例,在此不再赘述。
以上对本申请实施例所提供的一种模型训练方法、装置、电子设备和计算机可读存储介质进行了详细介绍,本文中应用了具体个例对本申请的原理及实施方式进行了阐述,以上实施例的说明只是用于帮助理解本申请的方法及其核心思想;同时,对于本领域的技术人员,依据本申请的思想,在具体实施方式及应用范围上均会有改变之处,综上所述,本说明书内容不应理解为对本申请的限制。

Claims (13)

1.一种模型训练方法,其特征在于,包括:
获取教师模型和学生模型;
获取第一预测结果,所述第一预测结果由所述教师模型对样本数据集中的样本数据进行预测得到;
获取所述教师模型对所述第一预测结果的可信度,所述可信度用于表征所述第一预测结果的可信程度;
根据所述可信度更新所述第一预测结果,将更新后的所述第一预测结果作为第二预测结果;
获取所述第二预测结果与学生模型预测所述样本数据的结果之间的差异;
基于所述差异,更新所述学生模型的参数,以训练所述学生模型。
2.根据权利要求1所述的方法,其特征在于,所述根据所述可信度更新所述第一预测结果,将更新后的所述第一预测结果作为第二预测结果,包括:
根据每个第一预测结果的可信度,计算所述每个第一预测结果对应的加权值;
对所述每个第一预测结果结合所述加权值进行加权求和,得到更新后的所述第一预测结果作为第二预测结果。
3.根据权利要求2所述的方法,其特征在于,所述根据每个第一预测结果的可信度,计算所述每个第一预测结果对应的加权值,包括:
获取所述每个第一预测结果的可信度参数,以及所述第一预测结果的可信度参数总和;
计算所述每个第一预测结果的可信度参数在所述可信度参数总和中的占比,将所述占比确定为所述每个第一预测结果对应的加权值。
4.根据权利要求3所述的方法,其特征在于,所述获取所述每个第一预测结果的可信度参数,包括:
获取所述每个第一预测结果的类别数量,所述类别数量表征所述每个第一预测结果包含预测类别的数目;
计算所述每个第一预测结果的可信度与所述类别数量的对数的比值;
将预设值与所述比值的差值确定为所述每个第一预测结果的可信度参数。
5.根据权利要求3所述的方法,其特征在于,所述计算所述每个第一预测结果的可信度参数在所述可信度参数总和中的占比,包括:
对所述每个第一预测结果的可信度参数做e次幂运算,得到第一参数;
对所述可信度参数总和做e次幂运算,得到第二参数;
对所述第一参数与第二参数进行相除,得到所述占比。
6.根据权利要求1所述的方法,其特征在于,所述根据所述可信度更新所述第一预测结果,将更新后的所述第一预测结果作为第二预测结果,包括:
将不同的所述可信度中的每个可信度进行逐一对比;
将最小可信度对应的第一预测结果,确定为所述第二预测结果。
7.根据权利要求1所述的方法,其特征在于,所述获取所述教师模型对所述第一预测结果的可信度,包括:
获取第一概率值,所述第一概率值表征所述教师模型在目标状态下对所述样本数据进行预测的准确率;其中,在所述目标状态下的所述教师模型中包括至少一个被随机掩盖的神经节点;
计算所述第一概率值的熵,将所述熵确定为所述教师模型对所述第一预测结果的可信度,返回并执行获取所述第一概率值。
8.根据权利要求7所述的方法,其特征在于,所述获取第一概率值包括:
获取每个教师模型的第二概率值和预测次数,所述第二概率值表征所述每个教师模型的在不同目标状态下对所述样本数据进行预测的准确率;
将所述预测次数对应数量的所述第二概率值进行求和运算,得到第二概率值总和;
将所述得到第二概率值总和与所述预测次数进行相除,得到所述第一概率值。
9.根据权利要求8所述的方法,其特征在于,所述计算所述第一概率值的熵,包括:
获取所述每个第一预测结果的类别数量,所述类别数量表征所述每个第一预测结果包含预测类别的数目;
计算所述第一概率值与其对数的乘积值;
将所述预测次数对应数量的所述乘积值进行求和运算,得到所述乘积值的总和,将所述乘积值总和的负数作为所述概率值的熵。
10.一种模型训练装置,其特征在于,包括:
模型获取单元,用于获取教师模型和学生模型;
第一预测结果获取单元,用于获取第一预测结果,所述第一预测结果由所述教师模型对样本数据集中的样本数据进行预测得到;
可信度获取单元,用于获取所述教师模型对所述第一预测结果的可信度,所述可信度用于表征所述第一预测结果的可信程度;
第二预测结果获取单元,用于根据所述可信度更新所述第一预测结果,将更新后的所述第一预测结果作为第二预测结果;
差异获取单元,用于获取所述第二预测结果与学生模型预测所述样本数据的结果之间的差异;
模型训练单元,用于基于所述差异,更新所述学生模型的参数,以训练所述学生模型。
11.一种电子设备,其特征在于,包括处理器和存储器,所述存储器存储有多条指令;所述处理器从所述存储器中加载指令,以执行如权利要求1~9任一项所述的模型训练方法中的步骤。
12.一种计算机可读存储介质,其特征在于,所述计算机可读存储介质存储有多条指令,所述指令适于处理器进行加载,以执行权利要求1~9任一项所述的模型训练方法中的步骤。
13.一种计算机程序产品,包括计算机程序或指令,其特征在于,所述计算机程序或指令被处理器执行时实现权利要求1~9任一项所述的模型训练方法中的步骤。
CN202111511703.7A 2021-12-06 2021-12-06 模型训练方法、装置、电子设备和存储介质 Active CN114330510B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202111511703.7A CN114330510B (zh) 2021-12-06 2021-12-06 模型训练方法、装置、电子设备和存储介质

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202111511703.7A CN114330510B (zh) 2021-12-06 2021-12-06 模型训练方法、装置、电子设备和存储介质

Publications (2)

Publication Number Publication Date
CN114330510A true CN114330510A (zh) 2022-04-12
CN114330510B CN114330510B (zh) 2024-06-25

Family

ID=81051536

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202111511703.7A Active CN114330510B (zh) 2021-12-06 2021-12-06 模型训练方法、装置、电子设备和存储介质

Country Status (1)

Country Link
CN (1) CN114330510B (zh)

Cited By (7)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115099988A (zh) * 2022-06-28 2022-09-23 腾讯科技(深圳)有限公司 模型训练方法、数据处理方法、设备及计算机介质
CN115879446A (zh) * 2022-12-30 2023-03-31 北京百度网讯科技有限公司 文本处理方法、深度学习模型训练方法、装置以及设备
CN116206182A (zh) * 2023-01-03 2023-06-02 北京航空航天大学 一种面向单通道图像的高性能深度学习模型及训练方法
CN116594349A (zh) * 2023-07-18 2023-08-15 中科航迈数控软件(深圳)有限公司 机床预测方法、装置、终端设备以及计算机可读存储介质
CN117009534A (zh) * 2023-10-07 2023-11-07 之江实验室 文本分类方法、装置、计算机设备以及存储介质
CN117274615A (zh) * 2023-09-21 2023-12-22 书行科技(北京)有限公司 人体动作预测方法及相关产品
TWI835638B (zh) * 2022-05-04 2024-03-11 國立清華大學 於非對稱策略架構下以階層式強化學習訓練主策略的方法

Citations (6)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN111126592A (zh) * 2018-10-30 2020-05-08 三星电子株式会社 输出预测结果、生成神经网络的方法及装置和存储介质
CN111639744A (zh) * 2020-04-15 2020-09-08 北京迈格威科技有限公司 学生模型的训练方法、装置及电子设备
CN112132170A (zh) * 2019-06-25 2020-12-25 国际商业机器公司 使用教师-学生学习模式进行模型训练
CN112733945A (zh) * 2021-01-13 2021-04-30 中国科学技术大学 一种基于高斯dropout的深度学习模型不确定度计算方法
CN112749728A (zh) * 2020-08-13 2021-05-04 腾讯科技(深圳)有限公司 学生模型训练方法、装置、计算机设备及存储介质
CN113744798A (zh) * 2021-09-01 2021-12-03 腾讯医疗健康(深圳)有限公司 组织样本的分类方法、装置、设备和存储介质

Patent Citations (7)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN111126592A (zh) * 2018-10-30 2020-05-08 三星电子株式会社 输出预测结果、生成神经网络的方法及装置和存储介质
CN112132170A (zh) * 2019-06-25 2020-12-25 国际商业机器公司 使用教师-学生学习模式进行模型训练
US20200410388A1 (en) * 2019-06-25 2020-12-31 International Business Machines Corporation Model training using a teacher-student learning paradigm
CN111639744A (zh) * 2020-04-15 2020-09-08 北京迈格威科技有限公司 学生模型的训练方法、装置及电子设备
CN112749728A (zh) * 2020-08-13 2021-05-04 腾讯科技(深圳)有限公司 学生模型训练方法、装置、计算机设备及存储介质
CN112733945A (zh) * 2021-01-13 2021-04-30 中国科学技术大学 一种基于高斯dropout的深度学习模型不确定度计算方法
CN113744798A (zh) * 2021-09-01 2021-12-03 腾讯医疗健康(深圳)有限公司 组织样本的分类方法、装置、设备和存储介质

Non-Patent Citations (2)

* Cited by examiner, † Cited by third party
Title
LEI LI等: "Dynamic Knowledge Distillation for Pre-trained Language Models", 《ARXIV》, 23 September 2021 (2021-09-23) *
LEI LI等: "From Mimicking to Integrating: Knowledge Integration for Pre-Trained Language Models", 《ARXIV》, 11 October 2022 (2022-10-11) *

Cited By (11)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
TWI835638B (zh) * 2022-05-04 2024-03-11 國立清華大學 於非對稱策略架構下以階層式強化學習訓練主策略的方法
CN115099988A (zh) * 2022-06-28 2022-09-23 腾讯科技(深圳)有限公司 模型训练方法、数据处理方法、设备及计算机介质
CN115879446A (zh) * 2022-12-30 2023-03-31 北京百度网讯科技有限公司 文本处理方法、深度学习模型训练方法、装置以及设备
CN115879446B (zh) * 2022-12-30 2024-01-12 北京百度网讯科技有限公司 文本处理方法、深度学习模型训练方法、装置以及设备
CN116206182A (zh) * 2023-01-03 2023-06-02 北京航空航天大学 一种面向单通道图像的高性能深度学习模型及训练方法
CN116594349A (zh) * 2023-07-18 2023-08-15 中科航迈数控软件(深圳)有限公司 机床预测方法、装置、终端设备以及计算机可读存储介质
CN116594349B (zh) * 2023-07-18 2023-10-03 中科航迈数控软件(深圳)有限公司 机床预测方法、装置、终端设备以及计算机可读存储介质
CN117274615A (zh) * 2023-09-21 2023-12-22 书行科技(北京)有限公司 人体动作预测方法及相关产品
CN117274615B (zh) * 2023-09-21 2024-03-22 书行科技(北京)有限公司 人体动作预测方法及相关产品
CN117009534A (zh) * 2023-10-07 2023-11-07 之江实验室 文本分类方法、装置、计算机设备以及存储介质
CN117009534B (zh) * 2023-10-07 2024-02-13 之江实验室 文本分类方法、装置、计算机设备以及存储介质

Also Published As

Publication number Publication date
CN114330510B (zh) 2024-06-25

Similar Documents

Publication Publication Date Title
CN114330510B (zh) 模型训练方法、装置、电子设备和存储介质
CN111897941B (zh) 对话生成方法、网络训练方法、装置、存储介质及设备
Pasupa et al. Thai sentiment analysis with deep learning techniques: A comparative study based on word embedding, POS-tag, and sentic features
Setyawan et al. Comparison of multinomial naive bayes algorithm and logistic regression for intent classification in chatbot
Al-Hmouz et al. Modeling and simulation of an adaptive neuro-fuzzy inference system (ANFIS) for mobile learning
CN110705255B (zh) 检测语句之间的关联关系的方法和装置
CN111382190B (zh) 一种基于智能的对象推荐方法、装置和存储介质
CN113761250A (zh) 模型训练方法、商户分类方法及装置
Napoli et al. An agent-driven semantical identifier using radial basis neural networks and reinforcement learning
CN111340112A (zh) 分类方法、装置、服务器
Guo et al. Who is answering whom? Finding “Reply-To” relations in group chats with deep bidirectional LSTM networks
Vinod et al. Natural disaster prediction by using image based deep learning and machine learning
Ramkissoon et al. Legitimacy: an ensemble learning model for credibility based fake news detection
Ding et al. [Retracted] College English Online Teaching Model Based on Deep Learning
Lin et al. A novel personality detection method based on high-dimensional psycholinguistic features and improved distributed Gray Wolf Optimizer for feature selection
CN112989024B (zh) 文本内容的关系提取方法、装置、设备及存储介质
CN111046655A (zh) 一种数据处理方法、装置及计算机可读存储介质
Shen et al. Student public opinion management in campus commentary based on deep learning
CN113821634A (zh) 内容分类方法、装置、电子设备和存储介质
Cuomo et al. A biologically inspired model for describing the user behaviors in a Cultural Heritage environment.
Chiramel et al. Detection of social media platform insults using Natural language processing and comparative study of machine learning algorithms
Karimi et al. Relevant question answering in community based networks using deep lstm neural networks
Wang et al. [Retracted] Evaluation and Analysis of College Students’ Mental Health from the Perspective of Deep Learning
CN114548382B (zh) 迁移训练方法、装置、设备、存储介质及程序产品
Tang [Retracted] Big Data Analysis and Modeling of Higher Education Reform Based on Cloud Computing Technology

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant