CN110807529A - 一种机器学习模型的训练方法、装置、设备及存储介质 - Google Patents

一种机器学习模型的训练方法、装置、设备及存储介质 Download PDF

Info

Publication number
CN110807529A
CN110807529A CN201911046952.6A CN201911046952A CN110807529A CN 110807529 A CN110807529 A CN 110807529A CN 201911046952 A CN201911046952 A CN 201911046952A CN 110807529 A CN110807529 A CN 110807529A
Authority
CN
China
Prior art keywords
machine learning
learning model
model
current
data set
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN201911046952.6A
Other languages
English (en)
Inventor
牛帅程
吴家祥
谭明奎
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Tencent Technology Shenzhen Co Ltd
Original Assignee
Tencent Technology Shenzhen Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Tencent Technology Shenzhen Co Ltd filed Critical Tencent Technology Shenzhen Co Ltd
Priority to CN201911046952.6A priority Critical patent/CN110807529A/zh
Publication of CN110807529A publication Critical patent/CN110807529A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N20/00Machine learning
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • G06F18/241Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Data Mining & Analysis (AREA)
  • Evolutionary Computation (AREA)
  • General Physics & Mathematics (AREA)
  • Artificial Intelligence (AREA)
  • Software Systems (AREA)
  • Physics & Mathematics (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • General Engineering & Computer Science (AREA)
  • Computing Systems (AREA)
  • Mathematical Physics (AREA)
  • Medical Informatics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Evolutionary Biology (AREA)
  • Image Analysis (AREA)

Abstract

本申请实施例公开了一种机器学习模型的训练方法、装置、设备及存储介质,在利用当前机器学习模型对训练数据集进行处理得到当前处理结果之后,需要确定是否满足停止条件。如果不满足停止条件,则基于施蒂费尔流形和目标损失函数更新该当前机器学习模型的模型参数,并继续执行所述利用当前机器学习模型对训练数据集进行处理得到当前处理结果及后续步骤直至满足停止条件。其中,因模型参数是基于施蒂费尔流形确定的,使得该模型参数能够严格满足正交约束,从而能够有效地降低机器学习模型中模型参数的冗余程度,从而能够有效地提高机器学习模型的表达能力。

Description

一种机器学习模型的训练方法、装置、设备及存储介质
技术领域
本申请涉及数据处理技术领域,尤其涉及一种机器学习模型的训练方法、装置、设备及存储介质。
背景技术
随着机器学习技术的发展,机器学习技术已广泛地应用于图像分类、语音识别、机器翻译等应用场景中;而且,在这些应用场景中,利用较好的机器学习模型能够得到准确的图像分类结果、语音识别结果以及机器翻译结果等,但是利用较差的机器学习模型却只能给出错误的图像分类结果、语音识别结果以及机器翻译结果等。
基于此可知,在机器学习技术的应用过程中,只有使用较好的机器学习模型才能得到较好的应用效果,如此使得获取到较好的机器学习模型是十分重要的。然而,如何获取到较好的机器学习模型仍是一个亟待解决的技术问题。
发明内容
本申请实施例提供了一种机器学习模型的训练方法、装置、设备及存储介质,能够准确地获取到较好的机器学习模型。
有鉴于此,本申请第一方面提供了一种机器学习模型的训练方法,包括:
利用当前机器学习模型对训练数据集进行处理得到当前处理结果;
若确定不满足停止条件,则基于施蒂费尔流形和目标损失函数更新所述当前机器学习模型的模型参数,并继续执行所述利用当前机器学习模型对训练数据集进行处理得到当前处理结果及后续步骤直至满足停止条件;其中,所述停止条件包括:基于所述目标损失函数和所述当前处理结果确定的损失值低于第一阈值,或基于所述目标损失函数和所述当前处理结果确定的损失值的变化率低于第二阈值,或所述模型参数的更新次数达到第三阈值。
本申请第二方面提供了一种机器学习模型的训练装置,包括:
处理单元,用于利用当前机器学习模型对训练数据集进行处理得到当前处理结果;
更新单元,用于若确定未满足停止条件,则基于施蒂费尔流形和目标损失函数更新所述当前机器学习模型的模型参数,并继续由所述处理单元执行所述利用当前机器学习模型对训练数据集进行处理得到当前处理结果直至满足停止条件;其中,所述停止条件包括:基于所述目标损失函数和所述当前处理结果确定的损失值低于第一阈值,或基于所述目标损失函数和所述当前处理结果确定的损失值的变化率低于第二阈值,或所述模型参数的更新次数达到第三阈值。
本申请第三方面提供了一种设备,其特征在于,所述设备包括处理器以及存储器:
所述存储器用于存储计算机程序;
所述处理器用于根据所述计算机程序执行上述第一方面所述的机器学习模型的训练方法。
本申请第四方面提供了一种计算机可读存储介质,所述计算机可读存储介质用于存储计算机程序,所述计算机程序用于执行上述第一方面所述的机器学习模型的训练方法。
本申请第五方面提供了一种包括指令的计算机程序产品,当其在计算机上运行时,使得所述计算机执行上述第一方面所述的机器学习模型的训练方法。
从以上技术方案可以看出,本申请实施例具有以下优点:
本申请实施例提供的机器学习模型的训练方法中,在利用当前机器学习模型对训练数据集进行处理得到当前处理结果之后,需要确定是否满足停止条件。如果不满足停止条件,则基于施蒂费尔流形和目标损失函数更新该当前机器学习模型的模型参数,并继续执行所述利用当前机器学习模型对训练数据集进行处理得到当前处理结果及后续步骤直至满足停止条件。其中,因停止条件包括:基于目标损失函数和当前处理结果确定的损失值低于第一阈值,或基于目标损失函数和当前处理结果确定的损失值的变化率低于第二阈值,或模型参数的更新次数达到第三阈值,使得基于该停止条件最后获得的机器学习模型能够达到较好的处理效果,如此能够准确地获取到较好的机器学习模型。另外,因模型参数是基于施蒂费尔流形确定的,使得该模型参数能够严格满足正交约束,从而能够有效地降低机器学习模型中模型参数的冗余程度,从而能够有效地提高机器学习模型的表达能力。
附图说明
图1为本申请实施例提供的机器学习模型的训练方法的一种应用场景示意图;
图2为本申请实施例提供的机器学习模型的训练方法的另一种应用场景示意图;
图3为本申请实施例提供的一种机器学习模型的训练方法的流程图;
图4为本申请实施例提供的在Stiefel流形上的参数更新过程特点示意图;
图5为本申请实施例提供的模型参数更新的一种实施方式;
图6为本申请实施例提供的模型参数更新的另一种实施方式;
图7为本申请实施例提供的应用于图像分类的机器学习模型的训练方法的流程图;
图8为本申请实施例提供的机器学习模型验证结果示意图;
图9为本申请实施例提供的一种机器学习模型的训练装置的结构示意图;
图10为本申请实施例提供的一种终端设备的结构示意图;
图11为本申请实施例提供的一种服务器的结构示意图。
具体实施方式
为了使本技术领域的人员更好地理解本申请方案,下面将结合本申请实施例中的附图,对本申请实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅是本申请一部分实施例,而不是全部的实施例。基于本申请中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本申请保护的范围。
本申请的说明书和权利要求书及上述附图中的术语“第一”、“第二”、“第三”、“第四”等(如果存在)是用于区别类似的对象,而不必用于描述特定的顺序或先后次序。应该理解这样使用的数据在适当情况下可以互换,以便这里描述的本申请的实施例能够以除了在这里图示或描述的那些以外的顺序实施。此外,术语“包括”和“具有”以及他们的任何变形,意图在于覆盖不排他的包含,例如,包含了一系列步骤或单元的过程、方法、系统、产品或设备不必限于清楚地列出的那些步骤或单元,而是可包括没有清楚地列出的或对于这些过程、方法、产品或设备固有的其它步骤或单元。
相关技术中,通常采用正则化技术或循环丢弃再训练的方式进行机器学习模型的训练,但是上述两种模型训练方式具有的技术问题为:因采用正则化技术或循环丢弃再训练的方式进行模型训练时只能使得模型参数尽可能地接近正交约束条件,却无法严格地满足正交约束条件,如此使得模型参数的冗余程度只能尽可能地降低,却无法最大程度地降低模型参数的冗余程度,从而导致采用正则化技术或循环丢弃再训练的方式确定的模型参数仍然具有较高的冗余程度,从而降低了机器学习模型的表达能力。
针对背景技术部分的技术问题以及上述两种技术方案所存在的技术问题,本申请实施例提供了一种机器学习模型的训练方法,既能够获取到较好的机器学习模型,又能够最大程度地降低模型参数的冗余程度,从而能够有效地提高机器学习模型的表达能力。
具体的,在本申请实施例提供的机器学习模型的训练方法中,在利用当前机器学习模型对训练数据集进行处理得到当前处理结果之后,需要确定是否满足停止条件。如果不满足停止条件,则基于施蒂费尔流形和目标损失函数更新该当前机器学习模型的模型参数,并继续执行所述利用当前机器学习模型对训练数据集进行处理得到当前处理结果及后续步骤直至满足停止条件。
在本申请实施例提供的机器学习模型的训练方法中,因停止条件包括:基于目标损失函数和当前处理结果确定的损失值低于第一阈值,或基于目标损失函数和当前处理结果确定的损失值的变化率低于第二阈值,或模型参数的更新次数达到第三阈值,使得基于该停止条件最后获得的机器学习模型能够达到较好的处理效果,如此能够准确地获取到较好的机器学习模型。另外,因模型参数是基于施蒂费尔流形确定的,使得该模型参数能够严格满足正交约束条件,从而能够有效地最大程度地降低机器学习模型中模型参数的冗余程度,从而能够有效地提高机器学习模型的表达能力。
应理解,本申请实施例提供的机器学习模型的训练方法可以应用于数据处理设备,如终端设备、服务器等;其中,终端设备具体可以为智能手机、计算机、个人数字助理(Personal Digital Assitant,PDA)、平板电脑等;服务器具体可以为应用服务器,也可以为Web服务器,在实际部署时,该服务器可以为独立服务器,也可以为集群服务器。
若本申请实施例提供的机器学习模型的训练方法由终端设备执行时,则终端设备可以直接基于施蒂费尔流形、目标损失函数以及停止条件更新当前机器学习模型的模型参数。如此,终端设备执行该机器学习模型的训练方法能够周期性地基于施蒂费尔流形进行机器学习模型的模型参数的更新,使得模型训练过程中模型参数能够时刻严格满足正交约束条件,从而能够在获取到较好的机器学习模型的前提下最大程度地降低机器学习模型中模型参数的冗余程度,从而能够有效地提高机器学习模型的表达能力。若本申请实施例提供的机器学习模型的训练方法由服务器执行时,则服务器先基于施蒂费尔流形、目标损失函数以及停止条件更新当前机器学习模型的模型参数,并将最终确定的模型参数发送给终端设备,以便终端设备根据接收的处理后的模型参数更新终端设备的本地机器学习模型。
为了便于理解本申请实施例提供的技术方案,下面结合图1以本申请实施例提供的机器学习模型的训练方法应用于终端设备为例,对本申请实施例提供的机器学习模型的训练方法适用的应用场景进行示例性介绍。其中,图1为本申请实施例提供的机器学习模型的训练方法的一种应用场景示意图。
如图1所示,该应用场景包括:终端设备101和用户102;终端设备101用于执行本申请实施例提供的机器学习模型的训练方法;而且用户102能够在终端设备101上利用训练获得的机器学习模型执行图像分类、语音识别、机器翻译等任务。
终端设备101在利用当前机器学习模型对训练数据集进行处理得到当前处理结果之后,需要确定是否满足停止条件。如果不满足停止条件,则基于施蒂费尔流形和目标损失函数更新该当前机器学习模型的模型参数,并继续执行所述利用当前机器学习模型对训练数据集进行处理得到当前处理结果及后续步骤直至满足停止条件。如此用户102能够利用终端设备101训练获得的机器学习模型执行图像分类、语音识别、机器翻译等任务。需要说明的是,用户102可以是人和/或机器,本申请实施例对此不做具体限定。
应理解,在实际应用中,也可以将本申请实施例提供的机器学习模型的训练方法应用于服务器,参见图2,该图为本申请实施例提供的机器学习模型的训练方法的另一种应用场景示意图。如图2所示,服务器201能够基于施蒂费尔流形、目标损失函数以及停止条件重复地更新当前机器学习模型的模型参数,并将最终确定的模型参数发送给终端设备202,使得终端设备202能够利用接收地模型参数更新本地机器学习模型的模型参数,以便用户203能够利用该参数更新后的本地机器学习模型执行图像分类、语音识别、机器翻译等任务。其中,服务器201基于施蒂费尔流形、目标损失函数以及停止条件重复地更新当前机器学习模型的模型参数,具体可以为:服务器201在利用当前机器学习模型对训练数据集进行处理得到当前处理结果之后,需要确定是否满足停止条件。如果不满足停止条件,则基于施蒂费尔流形和目标损失函数更新该当前机器学习模型的模型参数,并继续执行所述利用当前机器学习模型对训练数据集进行处理得到当前处理结果及后续步骤直至满足停止条件。
应理解,图1和图2所示的应用场景仅为示例,在实际应用中,本申请实施例提供的机器学习模型的训练方法还可以应用于其他需要训练机器学习模型的应用场景,在此不对本申请实施例提供的机器学习模型的训练方法做任何限定。
下面通过实施例对本申请提供的机器学习模型的训练方法进行介绍。
方法实施例一
参见图3,该图为本申请实施例提供的一种机器学习模型的训练方法的流程图。
本申请实施例提供的机器学习模型的训练方法,包括步骤S301-S304:
S301:利用当前机器学习模型对训练数据集进行处理得到当前处理结果。
当前机器学习模型是指在本轮训练过程中使用的机器学习模型。需要说明的是,机器学习模型的训练过程通常包括多轮训练,而且每轮训练中使用的机器学习模型的模型参数是不同的。另外,后一轮训练中机器学习模型的模型参数均是基于前一轮机器学习模型的模型参数确定的。此外,在第一轮训练中,当前机器学习模型的模型参数可以预先设定。
另外,因不同的应用场景需要使用不同的机器学习模型,使得在对机器学习模型的训练过程中也需要参考机器学习模型的应用场景,从而使得不同应用场景下的机器学习模型所具有的处理手段以及所使用的训练数据集均不同。基于此,本申请实施例将以四个应用场景为例进行解释和说明。
作为第一示例,当机器学习模型需要应用于图像分类时,则步骤S301具体可以为:利用当前机器学习模型对训练数据集进行图像分类处理得到当前处理结果。其中,训练数据集包括至少一张图像。需要说明的是,训练数据集不仅包括至少一张图像,还包括该训练数据集对应的实际图像分类结果,以便后续能够基于该训练数据集对应的实际图像分类结果以及当前机器学习模型的当前处理结果进行模型参数的更新或模型优劣程度的判断。
作为第二示例,当机器学习模型需要应用于语音识别时,则步骤S301具体可以为:利用当前机器学习模型对训练数据集进行语音识别处理得到当前处理结果。其中,训练数据集包括至少一段语音。需要说明的是,训练数据集不仅包括至少一段语音,还包括该训练数据集对应的实际语音识别结果,以便后续能够基于该训练数据集对应的实际语音识别结果以及当前机器学习模型的当前处理结果进行模型参数的更新或模型优劣程度的判断。
作为第三示例,当机器学习模型需要应用于机器翻译时,则步骤S301具体可以为:利用当前机器学习模型对训练数据集进行机器翻译处理得到当前处理结果。其中,训练数据集包括至少一份文本。需要说明的是,训练数据集不仅包括至少一份文本,还包括该训练数据集对应的实际文本翻译结果,以便后续能够基于该训练数据集对应的实际文本翻译结果以及当前机器学习模型的当前处理结果进行模型参数的更新或模型优劣程度的判断。
作为第四示例,当机器学习模型需要应用于人脸识别时,则步骤S301具体可以为:利用当前机器学习模型对训练数据集进行人脸识别处理得到当前处理结果。其中,训练数据集包括至少一张图像。需要说明的是,训练数据集不仅包括至少一张图像,还包括该训练数据集对应的实际人脸识别结果,以便后续能够基于该训练数据集对应的实际人脸识别结果以及当前机器学习模型的当前处理结果进行模型参数的更新或模型优劣程度的判断。
以上为本申请实施例提供的四个应用场景的相关内容。需要说明的是,本申请实施例提供的机器学习模型的训练方法不仅能够应用于图像分类、语音识别、机器翻译以及人脸识别等应用场景中,还可以应用于其他应用场景中,本申请实施例对此不做具体限定。
还需要说明的是,本申请实施例中的机器学习模型可以是任一种神经网络模型(例如,残差神经网络(Residual Neural Network,ResNet)、应用于嵌入式设备的轻量级的深层神经网络(例如,MobileNet)、网络架构搜索网络(Neural Architecture Search,NASNet)、长短期记忆网络(Long Short-Term Memory,LSTM)等),也可以是任一种深度学习模型,还可以是其他类型的机器学习模型,本申请实施例对此不做具体限定。
S302:确定是否满足停止条件,若是,则执行步骤S304;若否,则执行步骤S303。
停止条件是指用于衡量是否停止对机器模型进行训练的条件;而且该停止条件可以包括:基于目标损失函数和当前处理结果确定的损失值低于第一阈值,或基于目标损失函数和当前处理结果确定的损失值的变化率低于第二阈值,或模型参数的更新次数达到第三阈值。其中,损失值是指根据当前处理结果和目标损失函数确定的用于衡量当前机器学习模型优劣程度的值。
基于此可知,在本申请实施例中只要确定当前轮训练满足以下三种情况中的至少一种情况即可确定当前轮训练已满足停止条件,该三种情况具体为:
(1)当确定基于目标损失函数和当前处理结果确定的损失值低于第一阈值时,则表征当前机器学习模型的发生错误的概率很低,从而确定当前机器学习模型的处理效果较好,从而可以确定当前机器学习模型已处于较好的状态,从而可以确定满足停止条件。此时无需再对机器学习模型进行更新优化,可以结束对机器学习模型的训练过程,以便提高机器学习模型的训练效率。
(2)当确定基于目标损失函数和当前处理结果确定的损失值的变化率低于第二阈值时,则表征当前机器学习模型的模型参数基本不发生什么变化,从而表征当前机器学习模型的模型参数已收敛,从而可以确定当前机器学习模型的处理效果较好,从而可以确定当前机器学习模型已处于较好的状态,从而可以确定满足停止条件。此时无需再对机器学习模型进行更新优化,可以结束对机器学习模型的训练过程,以便提高机器学习模型的训练效率。
(3)当模型参数的更新次数达到第三阈值时,则表征机器学习模型的模型参数已经历了足够多的的更新次数,从而确定当前机器学习模型的处理效果较好,从而可以确定当前机器学习模型已处于较好的状态,从而可以确定满足停止条件。此时无需再对机器学习模型进行更新优化,可以结束对机器学习模型的训练过程,以便提高机器学习模型的训练效率。
然而,只有在确定当前轮训练不满足上述每一种情况时,才确定当前轮训练不满足停止条件,其具体为:当确定目标损失函数和当前处理结果确定的损失值不低于第一阈值、目标损失函数和当前处理结果确定的损失值的变化率不低于第二阈值、且模型参数的更新次数未达到第三阈值时,则确定不满足停止条件。此时需要继续对当前机器学习模型进行更新优化,以便提高当前机器学习模型的处理效果。
需要说明的是,在本申请实施例中,目标损失函数可以预先设定,尤其可以根据应用场景设定。例如,图像分类、语音识别、机器翻译、以及人脸识别等应用场景均对应于不同的目标损失函数。另外,为了防止模型过拟合,可以在目标损失函数中加入正则化项,如此使得目标损失函数可以包括正则化项。此外,第一阈值、第二阈值和第三阈值可以预先设定,尤其可以根据应用场景设定。
S303:基于施蒂费尔流形和目标损失函数更新当前机器学习模型的模型参数,并返回执行步骤S301。
施蒂费尔(Stiefel)流形是指由满足条件WTW=In的所有矩阵W所组成的集合(如图4所示)。其中,W为m×n的矩阵,而且,在本申请实施例中W为根据模型参数确定的参数矩阵。另外,In为n×n的单位矩阵。需要说明的是,根据图4所示可知,基于施蒂费尔流形更新的模型参数能够始终满足正交约束条件。
在本申请实施例中,在确定当前轮训练不满足停止条件时,则可以基于施蒂费尔流形和目标损失函数更新当前机器学习模型的模型参数,以便后续能够利用更新后的当前机器学习模型进行新一轮训练。其中,因当前机器学习模型的模型参数是基于施蒂费尔流形进行更新的,使得更新后的当前机器学习模型的模型参数也严格地满足正交约束条件,从而保证在每轮更新后的模型参数均严格地满足正交约束条件,从而有效地降低了模型参数的冗余程度。
需要说明的是,在本申请实施例中,对机器学习模型的模型参数的更新过程可以按照全体参数进行更新,也可以按照层级参数进行更新,还可以采用其他方式进行更新。其中,所谓按照层级参数进行更新是指当机器学习模型包括多层网络结构时,可以将对不同层网络结构中的所有参数分别进行更新。例如,当机器学习模型包括第一卷积层、第二卷积层和第一全连接层时,则步骤S303具体可以为:基于施蒂费尔流形和目标损失函数分别更新第一卷积层中的模型参数、第二卷积层中的模型参数和第一全连接层中的模型参数,而且该三个更新过程可以串行执行也可以并行执行。
还需要说明的是,在本申请实施例中,为了便于进行模型参数的更新可以采用矩阵的形式来表示模型参数。其中,当模型参数的更新过程按照全体参数进行更新时,则由机器学习模型中的所有模型参数构成一个参数矩阵。然而,当模型参数的更新过程按照层级参数进行更新时,则由机器学习模型中每层网络结构的模型参数构成一个参数矩阵,例如,可以将全连接层中的所有模型参数构成一个参数矩阵(其中,Cin和Cout分别为输入和输出神经元的个数);也可以将卷积层中的所有模型参数构成一个参数矩阵
Figure BDA0002254363330000102
(其中,Cin和Cout分别为输入和输出层特征图的数目,且h为卷积核的高度,且w为卷积核的宽度)。基于此可知,参数矩阵可以由机器学习模型中所有模型参数构成,也可以由机器学习模型中部分模型参数(例如,一个网络层的模型参数)构成,本申请实施例对此不做具体限定;而且,此处关于“参数矩阵”的内容适用于全文中提及到的“参数矩阵”。
另外,为了提高模型参数的更新效率,本申请实施例还提供了模型参数更新的一种实施方式,如图5所示,在该实施方式中,“基于施蒂费尔流形和目标损失函数更新当前机器学习模型的模型参数”具体可以包括步骤S3031-S3033:
S3031:获取目标损失函数对模型参数在施蒂费尔流形上的黎曼梯度。
在本申请实施例中,步骤S3031可以采用现有的或未来出现的任一种能够获取目标损失函数对模型参数在施蒂费尔流形上的黎曼梯度的方法进行实施。
另外,为了提高黎曼梯度的获取效率以及准确率,可以先获取欧几里得空间中的梯度,再获取黎曼梯度。基于此,本申请实施例还提供了步骤S3031的一种实施方式,其具体可以包括以下两个步骤:
第一步:获取目标损失函数对模型参数在欧几里得空间中的梯度,作为初始梯度。
在本申请实施例中,可以采用公式(1)来获取目标损失函数对模型参数在欧几里得空间中的梯度,且公式(1)具体为:
Figure BDA0002254363330000111
其中,GW表示目标损失函数对模型参数在欧几里得空间中的梯度;
Figure BDA0002254363330000112
表示方向导数;L(·)表示目标损失函数;W表示参数矩阵。
需要说明的是,本申请实施例不限定欧几里得空间中梯度的计算过程,可以采用现有的或未来出现的任一方法进行计算。
第二步:根据初始梯度,确定目标损失函数对模型参数在施蒂费尔流形上的黎曼梯度。
目标损失函数对模型参数的黎曼梯度表示目标损失函数对模型参数在欧几里得空间中的梯度在模型参数切空间中的投影;而且,在典范(Canonical)内积下,黎曼梯度可以利用公式(2)进行计算。
式中,
Figure BDA0002254363330000114
表示目标损失函数对模型参数在施蒂费尔流形上的黎曼梯度;GW表示目标损失函数对模型参数在欧几里得空间中的梯度;W表示参数矩阵。
还需要说明的是,在本申请实施例中,在求解黎曼梯度时,公式(2)不仅可以在用Canonical内积下进行计算,还可以在其他內积下进行计算,本申请实施例对此不做具体限定。
S3032:基于凯莱变换和黎曼梯度确定正交矩阵。
在本申请实施例中,在获取到黎曼梯度之后,可以根据该黎曼梯度计算凯莱变换所需的正交矩阵,以便后续能够基于该正交矩阵更新模型参数。其中,基于凯莱变换的新模型参数的计算公式如公式(3)所示。
Figure BDA0002254363330000121
式中,Wnew表示更新后的模型参数;Q表示正交矩阵;τ表示参数更新步长;W表示参数矩阵;I表示单位矩阵;
Figure BDA0002254363330000122
基于公式(2)和(3)可知,在获取到黎曼梯度之后,可以根据黎曼梯度的公式项
Figure BDA0002254363330000123
确定A,再基于
Figure BDA0002254363330000124
确定正交矩阵Q,以便后续能够基于正交矩阵Q与当前机器学习模型的模型参数W的乘积QW,确定更新后的模型参数Wnew
需要说明的是,由于YA(τ)满足YA(τ)TYA(τ)=I,使得YA(τ)满足正交约束条件,从而使得更新后的模型参数Wnew也满足正交约束条件,从而保证了在机器模型的训练过程中,机器模型的模型参数时刻严格满足正交约束条件,从而有效地降低了模型参数的冗余程度。
另外,模型参数W的维度决定了矩阵A的维度(例如,若W为m×n的矩阵,则A为m×m的矩阵),而且矩阵A的维度能够直接影响求逆的复杂度。如此,当模型参数W的维度m值较大时,会导致
Figure BDA0002254363330000126
求逆的复杂度急剧增加,从而导致正交矩阵Q的计算复杂度。
基于此,为了降低正交矩阵Q的计算复杂度,可以对求逆矩阵进行降维,以便降低求逆矩阵的复杂度,从而降低正交矩阵Q的计算复杂度。如此,本申请实施例还提供了获取正交矩阵的一种实施方式,在该实施方式中,当模型参数使用参数矩阵进行表示时,步骤S3032具体可以为:若参数矩阵的行数超过参数矩阵的列数的二倍,则基于应用谢尔曼莫里森公式后的凯莱变换和所述黎曼梯度,确定正交矩阵;若参数矩阵的行数不超过参数矩阵的列数的二倍,则基于凯莱变换和所述黎曼梯度确定正交矩阵。
在一些情况下,参数矩阵的行数可以根据模型参数中输入层的参数个数、卷积核的深度以及宽度确定,且参数矩阵的列数可以由模型参数中输出层的参数个数确定。例如,当模型参数为卷积层对应的参数
Figure BDA0002254363330000127
时,则Cout表示参数矩阵的列数,且Cin*h*w表示参数矩阵的行数。
谢尔曼莫里森公式(Sherman Morrison Formula,SMF)用于对求逆矩阵进行降维处理。
基于应用谢尔曼莫里森公式后的凯莱变换的新模型参数的计算公式如公式(4)所示。
Figure BDA0002254363330000131
式中,Wnew表示更新后的模型参数;Q表示正交矩阵;τ表示参数更新步长;W表示参数矩阵;I表示单位矩阵;A=LRT;L=[GW,W];R=[W,-GW]。
基于上述内容可知,在该实施方式中,假设参数矩阵W的维度为m×n,则当参数矩阵的行数m不超过参数矩阵的列数n的二倍(也就是,m≤2n)时,正交矩阵中的需要求逆的矩阵为而且需要求逆的矩阵
Figure BDA0002254363330000134
的维度为m×m;然而,当参数矩阵的行数超过参数矩阵的列数的二倍(也就是,m>2n)时,则正交矩阵
Figure BDA0002254363330000135
Figure BDA0002254363330000136
中需要求逆的矩阵为而且需要求逆的矩阵
Figure BDA0002254363330000138
的维度为2n×2n。
也就是,当m<2n时,则利用包括m×m维度需要求逆的矩阵
Figure BDA0002254363330000139
的正交矩阵计算公式
Figure BDA00022543633300001310
求解正交矩阵Q;当m>2n时,则利用包括2n×2n维需要求逆的矩阵
Figure BDA00022543633300001311
的正交矩阵计算公式
Figure BDA00022543633300001312
Figure BDA00022543633300001313
求解正交矩阵Q。如此能够有效地降低求解正交矩阵时逆矩阵的计算复杂度,从而能够有效地降低新模型参数的计算复杂度,从而提高了当前机器学习模型的更新效率。
S3033:利用正交矩阵和当前机器学习模型的模型参数的乘积,更新当前机器学习模型的模型参数。
在本申请实施例中,在获取到正交矩阵之后,可以根据正交矩阵与当前机器学习模型的模型参数的乘积(详情请参见公式(3)和(4)),确定更新后的模型参数,以便能够利用更新后的模型参数更新当前机器学习模型的模型参数。
另外,为了提高机器学习模型的训练效率,逆矩阵的求解过程可以采用迭代方法实现。基于此,本申请实施例还提供了模型参数更新的另一种实施方式,如图6所示,在该实施方式中,“基于施蒂费尔流形和目标损失函数更新当前机器学习模型的模型参数”除了包括步骤S3031-S3033以外,还包括步骤S3034和S3035:
S3034:利用目标迭代公式,获取正交矩阵中逆矩阵的值。
目标迭代公式是指利用迭代法求解逆矩阵的过程中所使用的公式;而且,目标迭代公式具体如公式(5)所示。
Xk+1=Xk(3I-ZXk(3I-ZXk)) (5)
式中,Xk+1表示第K次迭代所得的结果;Xk表示第K-1次迭代所得的结果;I表示单位矩阵;Z为求解正交矩阵过程中需要求逆的矩阵,且
Figure BDA0002254363330000141
Figure BDA0002254363330000142
K为迭代的次数,且K=0,1,2,3,……;需要说明的是,随着K的增大,Xk将收敛于逆矩阵Z-1
需要说明的是,由于在利用迭代法求解逆矩阵时,可以在计算矩阵与矩阵之间的乘积时采用多个乘积并行计算的手段,从而能够有效地提高逆矩阵的求解效率,从而能够有效地提高模型参数的更新效率。
还需要说明的是,本申请实施例不限定目标迭代公式,目标迭代公式可以采用公式(5)进行计算,也可以采用其他公式进行计算。
S3035:根据逆矩阵的值,确定正交矩阵和当前机器学习模型的模型参数的乘积。
在本申请实施例中,在获取到逆矩阵
Figure BDA0002254363330000145
之后,可以根据公式
Figure BDA0002254363330000146
Figure BDA0002254363330000147
计算出正交矩阵Q,以便根据QW确定正交矩阵Q和当前机器学习模型的模型参数W的乘积。
以上为本申请实施例提供的模型参数更新的另一种实施方式,在该实施方式中,通过利用迭代法求解逆矩阵并利用处理器(例如,图形处理器(Graphics ProcessingUnit,GPU))并行处理的优势,能够有效地提高逆矩阵的求解效率,从而能够有效地提高模型参数的更新效率。
S304:结束对机器学习模型的训练过程。
以上为本申请实施例提供的机器学习模型的训练方法的具体实施方式,在该实施方式中,在利用当前机器学习模型对训练数据集进行处理得到当前处理结果之后,需要确定是否满足停止条件。如果不满足停止条件,则基于施蒂费尔流形和目标损失函数更新该当前机器学习模型的模型参数,并继续执行所述利用当前机器学习模型对训练数据集进行处理得到当前处理结果及后续步骤直至满足停止条件。其中,因停止条件包括:基于目标损失函数和当前处理结果确定的损失值低于第一阈值,或基于目标损失函数和当前处理结果确定的损失值的变化率低于第二阈值,或模型参数的更新次数达到第三阈值,使得基于该停止条件最后获得的机器学习模型能够达到较好的处理效果,如此能够准确地获取到较好的机器学习模型。另外,因模型参数是基于施蒂费尔流形确定的,使得该模型参数能够严格满足正交约束,从而能够有效地降低机器学习模型中模型参数的冗余程度,从而能够有效地提高机器学习模型的表达能力。
为了便于进一步理解本申请实施例提供的机器学习模型的训练方法,方法实施例二将结合图7对本申请实施例提供的机器学习模型的训练方法进行整体介绍。
方法实施例二
方法实施例二是在方法实施例一的基础上进行的改进,为了简要起见,方法实施例二中与方法实施例一中部分内容相同,在此不再赘述,该内容相同的部分的技术详情请参照方法实施例一中的相关内容。
参见图7,该图为本申请实施例提供的应用于图像分类的机器学习模型的训练方法的流程图。
本申请实施例提供的机器学习模型的训练方法,应用于图像分类,且该机器学习模型为神经网络模型,该方法包括步骤S701-S710:
S701:利用当前机器学习模型对训练数据集进行图像分类处理得到当前处理结果。
S702:确定是否满足停止条件,若是,则执行步骤S710;若否,则执行步骤S703。
S703:获取当前机器学习模型中每层网络结构的模型参数,并分别将每层网络结构的模型参数生成参数矩阵。
需要说明的是,参数矩阵就是上文中的W;而且每层网络均对应一个参数矩阵W。
S704:获取目标损失函数对每个参数矩阵中参数在欧几里得空间中的梯度,作为每个参数矩阵对应的初始梯度。
S705:根据每个参数矩阵对应的初始梯度,确定目标损失函数对每个参数矩阵中参数在施蒂费尔流形上的黎曼梯度。
S706:基于黎曼梯度确定每个参数矩阵对应的正交矩阵,其具体为:若参数矩阵的行数超过参数矩阵的列数的二倍,则基于应用谢尔曼莫里森公式后的凯莱变换和所述黎曼梯度,确定正交矩阵(也就是基于公式(4)确定正交矩阵Q);若参数矩阵的行数不超过参数矩阵的列数的二倍,则基于凯莱变换和黎曼梯度确定正交矩阵(也就是基于公式(3)确定正交矩阵Q)。
S707:利用目标迭代公式,获取每个参数矩阵对应的正交矩阵中逆矩阵的值。
S708:根据每个参数矩阵对应的正交矩阵中逆矩阵的值,确定每个参数矩阵对应的正交矩阵和该参数矩阵的乘积。
S709:利用每个参数矩阵对应的正交矩阵和该参数矩阵的乘积,更新当前机器学习模型中每层网络结构的参数,并返回步骤S701。
S710:结束对机器学习模型的训练过程。
以上为本申请实施例提供的应用于图像分类的神经网络模型的训练方法,在该方法中,通过基于施蒂费尔流形更新每层网络结构的参数,如此能够保证神经网络模型的训练过程中参数能够时刻严格地满足正交约束条件,从而能够降低神经网络模型的模型参数的冗余程度,从而能够有效地提高神经网络模型的表达能力。另外,还通过凯莱变换、逆矩阵的降维处理以及逆矩阵的迭代求解,使得该训练过程能够有效地使用处理器的并行计算资源,从而能够有效地提高模型训练效率,从而有效地提高了神经网络模型的构建效率。
需要说明的是,上述方法实施例二是以机器学习模型为神经网路模型且应用场景为图像分类为例进行说明的。然而,本申请实施例不限定机器学习模型的类型以及应用场景,而且,本申请实施例提供的机器学习模型的训练方法适用于各类机器学习模型的训练,也适用于各种应用场景下的机器学习模型的训练。
另外,为了能够清楚的表明本申请实施例提供的机器学习模型的训练方法的有效性,如图8所示,本申请实施例还提供了利用本申请实施例提供的机器学习模型的训练方法在CIFAR-10数据集上训练获得的MobileNet模型的验证结果、在WebFace数据集上训练获得的MobileFaceNet模型在LFW、CFP和Agedb数据集上做人脸验证任务的准确率。基于图8可知,MobileNet模型的性能提升了1.59个百分点,且MobileFaceNet模型在LFW、CFP和Agedb数据集上均有2.5个百分点以上的性能提升。
基于上述方法实施例提供的机器学习模型的训练方法,本申请还提供了对应的机器学习模型的训练装置,以使得上述方法实施例提供的机器学习模型的训练方法在实际中得以应用和实现。
装置实施例
需要说明的是,本实施例提供的机器学习模型的训练装置的技术详情可以参照上述方法实施例提供的机器学习模型的训练方法。
参见图9,该图为本申请实施例提供的一种机器学习模型的训练装置的结构示意图。
本申请实施例提供的机器学习模型的训练装置900,包括:
处理单元901,用于利用当前机器学习模型对训练数据集进行处理得到当前处理结果;
更新单元902,用于若确定未满足停止条件,则基于施蒂费尔流形和目标损失函数更新所述当前机器学习模型的模型参数,并继续由所述处理单元901执行所述利用当前机器学习模型对训练数据集进行处理得到当前处理结果直至满足停止条件;其中,所述停止条件包括:基于所述目标损失函数和所述当前处理结果确定的损失值低于第一阈值,或基于所述目标损失函数和所述当前处理结果确定的损失值的变化率低于第二阈值,或所述模型参数的更新次数达到第三阈值。
可选的,在图9所示的机器学习模型的训练装置900的基础上,所述更新单元902,具体用于:
获取目标损失函数对所述模型参数在施蒂费尔流形上的黎曼梯度;
基于凯莱变换和所述黎曼梯度确定正交矩阵;
利用所述正交矩阵和所述当前机器学习模型的模型参数的乘积,更新所述当前机器学习模型的模型参数。
可选的,在图9所示的机器学习模型的训练装置900的基础上,所述更新单元902,具体用于:
获取目标损失函数对所述模型参数在欧几里得空间中的梯度,作为初始梯度;
根据所述初始梯度,确定目标损失函数对所述模型参数在施蒂费尔流形上的黎曼梯度。
可选的,在图9所示的机器学习模型的训练装置900的基础上,所述更新单元902,具体用于:
当模型参数使用参数矩阵进行表示时,若所述参数矩阵的行数超过所述参数矩阵的列数的二倍,则基于应用谢尔曼莫里森公式后的凯莱变换和所述黎曼梯度,确定正交矩阵。
可选的,在图9所示的机器学习模型的训练装置900的基础上,所述更新单元902,还用于:
利用目标迭代公式,获取所述正交矩阵中逆矩阵的值;
根据所述逆矩阵的值,确定所述正交矩阵和所述当前机器学习模型的模型参数的乘积。
可选的,在图9所示的机器学习模型的训练装置900的基础上,所述目标损失函数包括正则化项。
可选的,在图9所示的机器学习模型的训练装置900的基础上,所述处理单元901,具体用于:
利用当前机器学习模型对训练数据集进行图像分类处理得到当前处理结果;其中,所述训练数据集包括至少一张图像;
或,
利用当前机器学习模型对训练数据集进行语音识别处理得到当前处理结果;其中,所述训练数据集包括至少一段语音;
或,
利用当前机器学习模型对训练数据集进行机器翻译处理得到当前处理结果;其中,所述训练数据集包括至少一份文本;
或,
利用当前机器学习模型对训练数据集进行人脸识别处理得到当前处理结果;其中,所述训练数据集包括至少一张图像。
以上为本申请实施例提供的机器学习模型的训练装置900的具体实施方式,在该实施方式中,在利用当前机器学习模型对训练数据集进行处理得到当前处理结果之后,需要确定是否满足停止条件。如果不满足停止条件,则基于施蒂费尔流形和目标损失函数更新该当前机器学习模型的模型参数,并继续执行所述利用当前机器学习模型对训练数据集进行处理得到当前处理结果直至满足停止条件。其中,因停止条件包括:基于目标损失函数和当前处理结果确定的损失值低于第一阈值,或基于目标损失函数和当前处理结果确定的损失值的变化率低于第二阈值,或模型参数的更新次数达到第三阈值,使得基于该停止条件最后获得的机器学习模型能够达到较好的处理效果,如此能够准确地获取到较好的机器学习模型。另外,因模型参数是基于施蒂费尔流形确定的,使得该模型参数能够严格满足正交约束,从而能够有效地降低机器学习模型中模型参数的冗余程度,从而能够有效地提高机器学习模型的表达能力。
本申请实施例还提供了一种用于训练机器学习模型的终端设备和服务器,下面将从硬件实体化的角度对本申请实施例提供的用于训练机器学习模型的终端设备和服务器进行介绍。
参见图10,为本申请实施例提供的一种终端设备的结构示意图。为了便于说明,仅示出了与本申请实施例相关的部分,具体技术细节未揭示的,请参照本申请实施例方法部分。该终端可以为包括手机、平板电脑、个人数字助理(英文全称:Personal DigitalAssistant,英文缩写:PDA)、销售终端(英文全称:Point of Sales,英文缩写:POS)、车载电脑等任意终端设备,以终端为平板电脑为例:
图10示出的是与本申请实施例提供的终端相关的平板电脑的部分结构的框图。参考图10,平板电脑包括:射频(英文全称:Radio Frequency,英文缩写:RF)电路1010、存储器1020、输入单元1030、显示单元1040、传感器1050、音频电路1060、无线保真(英文全称:wireless fidelity,英文缩写:WiFi)模块1070、处理器1080、以及电源1090等部件。本领域技术人员可以理解,图10中示出的平板电脑结构并不构成对平板电脑的限定,可以包括比图示更多或更少的部件,或者组合某些部件,或者不同的部件布置。
存储器1020可用于存储软件程序以及模块,处理器1080通过运行存储在存储器1020的软件程序以及模块,从而执行平板电脑的各种功能应用以及数据处理。存储器1020可主要包括存储程序区和存储数据区,其中,存储程序区可存储操作系统、至少一个功能所需的应用程序(比如声音播放功能、图像播放功能等)等;存储数据区可存储根据平板电脑的使用所创建的数据(比如音频数据、电话本等)等。此外,存储器1020可以包括高速随机存取存储器,还可以包括非易失性存储器,例如至少一个磁盘存储器件、闪存器件、或其他易失性固态存储器件。
处理器1080是平板电脑的控制中心,利用各种接口和线路连接整个平板电脑的各个部分,通过运行或执行存储在存储器1020内的软件程序和/或模块,以及调用存储在存储器1020内的数据,执行平板电脑的各种功能和处理数据,从而对平板电脑进行整体监控。可选的,处理器1080可包括一个或多个处理单元;优选的,处理器1080可集成应用处理器和调制解调处理器,其中,应用处理器主要处理操作系统、用户界面和应用程序等,调制解调处理器主要处理无线通信。可以理解的是,上述调制解调处理器也可以不集成到处理器1080中。
在本申请实施例中,该终端所包括的处理器1080还具有以下功能:
利用当前机器学习模型对训练数据集进行处理得到当前处理结果;
若确定不满足停止条件,则基于施蒂费尔流形和目标损失函数更新所述当前机器学习模型的模型参数,并继续执行所述利用当前机器学习模型对训练数据集进行处理得到当前处理结果直至满足停止条件;其中,所述停止条件包括:基于所述目标损失函数和所述当前处理结果确定的损失值低于第一阈值,或基于所述目标损失函数和所述当前处理结果确定的损失值的变化率低于第二阈值,或所述模型参数的更新次数达到第三阈值。
可选的,所述处理器1080还用于执行本申请实施例提供的机器学习模型的训练方法的任意一种实现方式的步骤。
本申请实施例还提供了一种服务器,图11是本申请实施例提供的一种服务器的结构示意图,该服务器1100可因配置或性能不同而产生比较大的差异,可以包括一个或一个以上中央处理器(central processing units,CPU)1122(例如,一个或一个以上处理器)和存储器1132,一个或一个以上存储应用程序1142或数据1144的存储介质1130(例如一个或一个以上海量存储设备)。其中,存储器1132和存储介质1130可以是短暂存储或持久存储。存储在存储介质1130的程序可以包括一个或一个以上模块(图示没标出),每个模块可以包括对服务器中的一系列指令操作。更进一步地,中央处理器1122可以设置为与存储介质1130通信,在服务器1100上执行存储介质1130中的一系列指令操作。
服务器1100还可以包括一个或一个以上电源1126,一个或一个以上有线或无线网络接口1150,一个或一个以上输入输出接口1158,和/或,一个或一个以上操作系统1141,例如Windows ServerTM,Mac OS XTM,UnixTM,LinuxTM,FreeBSDTM等等。
上述实施例中由服务器所执行的步骤可以基于该图11所示的服务器结构。
其中,CPU 1122用于执行如下步骤:
利用当前机器学习模型对训练数据集进行处理得到当前处理结果;
若确定不满足停止条件,则基于施蒂费尔流形和目标损失函数更新所述当前机器学习模型的模型参数,并继续执行所述利用当前机器学习模型对训练数据集进行处理得到当前处理结果直至满足停止条件;其中,所述停止条件包括:基于所述目标损失函数和所述当前处理结果确定的损失值低于第一阈值,或基于所述目标损失函数和所述当前处理结果确定的损失值的变化率低于第二阈值,或所述模型参数的更新次数达到第三阈值。
可选的,CPU 1122还可以用于执行本申请实施例中机器学习模型的训练方法的任意一种实现方式的步骤。
本申请实施例还提供一种计算机可读存储介质,用于存储计算机程序,该计算机程序用于执行前述各个实施例所述的一种机器学习模型的训练方法中的任意一种实施方式。
本申请实施例还提供一种包括指令的计算机程序产品,当其在计算机上运行时,使得计算机执行前述各个实施例所述的一种机器学习模型的训练方法中的任意一种实施方式。
所属领域的技术人员可以清楚地了解到,为描述的方便和简洁,上述描述的系统,装置和单元的具体工作过程,可以参考前述方法实施例中的对应过程,在此不再赘述。
在本申请所提供的几个实施例中,应该理解到,所揭露的系统,装置和方法,可以通过其它的方式实现。例如,以上所描述的装置实施例仅仅是示意性的,例如,所述单元的划分,仅仅为一种逻辑功能划分,实际实现时可以有另外的划分方式,例如多个单元或组件可以结合或者可以集成到另一个系统,或一些特征可以忽略,或不执行。另一点,所显示或讨论的相互之间的耦合或直接耦合或通信连接可以是通过一些接口,装置或单元的间接耦合或通信连接,可以是电性,机械或其它的形式。
所述作为分离部件说明的单元可以是或者也可以不是物理上分开的,作为单元显示的部件可以是或者也可以不是物理单元,即可以位于一个地方,或者也可以分布到多个网络单元上。可以根据实际的需要选择其中的部分或者全部单元来实现本实施例方案的目的。
另外,在本申请各个实施例中的各功能单元可以集成在一个处理单元中,也可以是各个单元单独物理存在,也可以两个或两个以上单元集成在一个单元中。上述集成的单元既可以采用硬件的形式实现,也可以采用软件功能单元的形式实现。
所述集成的单元如果以软件功能单元的形式实现并作为独立的产品销售或使用时,可以存储在一个计算机可读取存储介质中。基于这样的理解,本申请的技术方案本质上或者说对现有技术做出贡献的部分或者该技术方案的全部或部分可以以软件产品的形式体现出来,该计算机软件产品存储在一个存储介质中,包括若干指令用以使得一台计算机设备(可以是个人计算机,服务器,或者网络设备等)执行本申请各个实施例所述方法的全部或部分步骤。而前述的存储介质包括:U盘、移动硬盘、只读存储器(英文全称:Read-OnlyMemory,英文缩写:ROM)、随机存取存储器(英文全称:Random AccessMemory,英文缩写:RAM)、磁碟或者光盘等各种可以存储计算机程序的介质。
以上所述,以上实施例仅用以说明本申请的技术方案,而非对其限制;尽管参照前述实施例对本申请进行了详细的说明,本领域的普通技术人员应当理解:其依然可以对前述各实施例所记载的技术方案进行修改,或者对其中部分技术特征进行等同替换;而这些修改或者替换,并不使相应技术方案的本质脱离本申请各实施例技术方案的精神和范围。

Claims (10)

1.一种机器学习模型的训练方法,其特征在于,包括:
利用当前机器学习模型对训练数据集进行处理得到当前处理结果;
若确定不满足停止条件,则基于施蒂费尔流形和目标损失函数更新所述当前机器学习模型的模型参数,并继续执行所述利用当前机器学习模型对训练数据集进行处理得到当前处理结果及后续步骤直至满足停止条件;其中,所述停止条件包括:基于所述目标损失函数和所述当前处理结果确定的损失值低于第一阈值,或基于所述目标损失函数和所述当前处理结果确定的损失值的变化率低于第二阈值,或所述模型参数的更新次数达到第三阈值。
2.根据权利要求1所述的方法,其特征在于,所述基于施蒂费尔流形和目标损失函数更新所述当前机器学习模型的模型参数,具体包括:
获取目标损失函数对所述模型参数在施蒂费尔流形上的黎曼梯度;
基于凯莱变换和所述黎曼梯度确定正交矩阵;
利用所述正交矩阵和所述当前机器学习模型的模型参数的乘积,更新所述当前机器学习模型的模型参数。
3.根据权利要求2所述的方法,其特征在于,所述获取目标损失函数对所述模型参数在施蒂费尔流形上的黎曼梯度,包括:
获取目标损失函数对所述模型参数在欧几里得空间中的梯度,作为初始梯度;
根据所述初始梯度,确定目标损失函数对所述模型参数在施蒂费尔流形上的黎曼梯度。
4.根据权利要求2所述的方法,其特征在于,当模型参数使用参数矩阵进行表示时,若所述参数矩阵的行数超过所述参数矩阵的列数的二倍,所述基于凯莱变换和所述黎曼梯度确定正交矩阵,包括:
基于应用谢尔曼莫里森公式后的凯莱变换和所述黎曼梯度,确定正交矩阵。
5.根据权利要求2所述的方法,其特征在于,所述方法还包括:
利用目标迭代公式,获取所述正交矩阵中逆矩阵的值;
根据所述逆矩阵的值,确定所述正交矩阵和所述当前机器学习模型的模型参数的乘积。
6.根据权利要求1所述的方法,其特征在于,所述目标损失函数包括正则化项。
7.根据权利要求1所述的方法,其特征在于,所述利用当前机器学习模型对训练数据集进行处理得到当前处理结果,包括:
利用当前机器学习模型对训练数据集进行图像分类处理得到当前处理结果;其中,所述训练数据集包括至少一张图像;
或,
利用当前机器学习模型对训练数据集进行语音识别处理得到当前处理结果;其中,所述训练数据集包括至少一段语音;
或,
利用当前机器学习模型对训练数据集进行机器翻译处理得到当前处理结果;其中,所述训练数据集包括至少一份文本;
或,
利用当前机器学习模型对训练数据集进行人脸识别处理得到当前处理结果;其中,所述训练数据集包括至少一张图像。
8.一种机器学习模型的训练装置,其特征在于,包括:
处理单元,用于利用当前机器学习模型对训练数据集进行处理得到当前处理结果;
更新单元,用于若确定未满足停止条件,则基于施蒂费尔流形和目标损失函数更新所述当前机器学习模型的模型参数,并继续由所述处理单元执行所述利用当前机器学习模型对训练数据集进行处理得到当前处理结果直至满足停止条件;其中,所述停止条件包括:基于所述目标损失函数和所述当前处理结果确定的损失值低于第一阈值,或基于所述目标损失函数和所述当前处理结果确定的损失值的变化率低于第二阈值,或所述模型参数的更新次数达到第三阈值。
9.一种设备,其特征在于,所述设备包括处理器以及存储器:
所述存储器用于存储计算机程序;
所述处理器用于根据所述计算机程序执行权利要求1-7中任一项所述的方法。
10.一种计算机可读存储介质,其特征在于,所述计算机可读存储介质用于存储计算机程序,所述计算机程序用于执行权利要求1-7中任一项所述的方法。
CN201911046952.6A 2019-10-30 2019-10-30 一种机器学习模型的训练方法、装置、设备及存储介质 Pending CN110807529A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN201911046952.6A CN110807529A (zh) 2019-10-30 2019-10-30 一种机器学习模型的训练方法、装置、设备及存储介质

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN201911046952.6A CN110807529A (zh) 2019-10-30 2019-10-30 一种机器学习模型的训练方法、装置、设备及存储介质

Publications (1)

Publication Number Publication Date
CN110807529A true CN110807529A (zh) 2020-02-18

Family

ID=69489692

Family Applications (1)

Application Number Title Priority Date Filing Date
CN201911046952.6A Pending CN110807529A (zh) 2019-10-30 2019-10-30 一种机器学习模型的训练方法、装置、设备及存储介质

Country Status (1)

Country Link
CN (1) CN110807529A (zh)

Cited By (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN111325354A (zh) * 2020-03-13 2020-06-23 腾讯科技(深圳)有限公司 机器学习模型压缩方法、装置、计算机设备和存储介质
CN111970335A (zh) * 2020-07-30 2020-11-20 腾讯科技(深圳)有限公司 一种信息推荐的方法、装置及存储介质
CN112329072A (zh) * 2020-12-31 2021-02-05 支付宝(杭州)信息技术有限公司 一种基于安全多方计算的模型联合训练方法
CN113537492A (zh) * 2021-07-19 2021-10-22 第六镜科技(成都)有限公司 模型训练及数据处理方法、装置、设备、介质、产品

Cited By (8)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN111325354A (zh) * 2020-03-13 2020-06-23 腾讯科技(深圳)有限公司 机器学习模型压缩方法、装置、计算机设备和存储介质
CN111325354B (zh) * 2020-03-13 2022-10-25 腾讯科技(深圳)有限公司 机器学习模型压缩方法、装置、计算机设备和存储介质
CN111970335A (zh) * 2020-07-30 2020-11-20 腾讯科技(深圳)有限公司 一种信息推荐的方法、装置及存储介质
CN111970335B (zh) * 2020-07-30 2021-09-07 腾讯科技(深圳)有限公司 一种信息推荐的方法、装置及存储介质
CN112329072A (zh) * 2020-12-31 2021-02-05 支付宝(杭州)信息技术有限公司 一种基于安全多方计算的模型联合训练方法
CN112329072B (zh) * 2020-12-31 2021-03-30 支付宝(杭州)信息技术有限公司 一种基于安全多方计算的模型联合训练方法
CN113537492A (zh) * 2021-07-19 2021-10-22 第六镜科技(成都)有限公司 模型训练及数据处理方法、装置、设备、介质、产品
CN113537492B (zh) * 2021-07-19 2024-04-26 第六镜科技(成都)有限公司 模型训练及数据处理方法、装置、设备、介质、产品

Similar Documents

Publication Publication Date Title
CN110807529A (zh) 一种机器学习模型的训练方法、装置、设备及存储介质
JP6811894B2 (ja) ニューラルネットワーク構造の生成方法および装置、電子機器、ならびに記憶媒体
US10970617B2 (en) Deep convolutional neural network acceleration and compression method based on parameter quantification
WO2017219991A1 (zh) 适用于模式识别的模型的优化方法、装置及终端设备
JP2021006980A (ja) スパース性制約及び知識の蒸留に基づくスパースかつ圧縮されたニューラルネットワーク
CN111553215B (zh) 人员关联方法及其装置、图卷积网络训练方法及其装置
US20210342696A1 (en) Deep Learning Model Training Method and System
CN111105017A (zh) 神经网络量化方法、装置及电子设备
CN114282666A (zh) 基于局部稀疏约束的结构化剪枝方法和装置
WO2022251317A1 (en) Systems of neural networks compression and methods thereof
CN114332500A (zh) 图像处理模型训练方法、装置、计算机设备和存储介质
CN115564017A (zh) 模型数据处理方法、电子设备及计算机存储介质
CN116090536A (zh) 神经网络的优化方法、装置、计算机设备及存储介质
Wang et al. Towards efficient convolutional neural networks through low-error filter saliency estimation
CN111401569B (zh) 超参数优化方法、装置和电子设备
CN107977628B (zh) 神经网络训练方法、人脸检测方法及人脸检测装置
US20220044125A1 (en) Training in neural networks
US11544563B2 (en) Data processing method and data processing device
CN113642592A (zh) 一种训练模型的训练方法、场景识别方法、计算机设备
CN115601550B (zh) 模型确定方法、装置、计算机设备及计算机可读存储介质
US20240070521A1 (en) Layer freezing &amp; data sieving for sparse training
CN113761934B (zh) 一种基于自注意力机制的词向量表示方法及自注意力模型
CN117222005B (zh) 指纹定位方法、装置、电子设备及存储介质
Huang et al. FedMef: Towards Memory-efficient Federated Dynamic Pruning
US20240135180A1 (en) Systems of neural networks compression and methods thereof

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
REG Reference to a national code

Ref country code: HK

Ref legal event code: DE

Ref document number: 40022085

Country of ref document: HK