CN111461345A - 深度学习模型训练方法及装置 - Google Patents

深度学习模型训练方法及装置 Download PDF

Info

Publication number
CN111461345A
CN111461345A CN202010247381.9A CN202010247381A CN111461345A CN 111461345 A CN111461345 A CN 111461345A CN 202010247381 A CN202010247381 A CN 202010247381A CN 111461345 A CN111461345 A CN 111461345A
Authority
CN
China
Prior art keywords
training
deep learning
learning model
data set
round
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Granted
Application number
CN202010247381.9A
Other languages
English (en)
Other versions
CN111461345B (zh
Inventor
李兴建
熊昊一
安昊哲
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Beijing Baidu Netcom Science and Technology Co Ltd
Original Assignee
Beijing Baidu Netcom Science and Technology Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Beijing Baidu Netcom Science and Technology Co Ltd filed Critical Beijing Baidu Netcom Science and Technology Co Ltd
Priority to CN202010247381.9A priority Critical patent/CN111461345B/zh
Publication of CN111461345A publication Critical patent/CN111461345A/zh
Application granted granted Critical
Publication of CN111461345B publication Critical patent/CN111461345B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N20/00Machine learning
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques

Abstract

本申请公开了一种深度学习模型训练方法及装置,涉及人工智能领域。具体实现方案为:服务器接收到终端设备发送的训练请求后,响应该训练请求,对数据集中的样本进行m轮训练,训练过程中,不断的更新数据集中各样本的软化标签,从而得到新的数据集,进而利用该新的数据集进行下一轮的训练。采用该种方案,通过同时学习深度学习模型和软化标签,得到泛化能力强的深度学习模型,实现提高深度学习模型准确度的目的。

Description

深度学习模型训练方法及装置
技术领域
本申请实施例涉及人工智能(Artificial Intelligence,AI)技术领域,尤其涉及一种深度学习模型训练方法及装置。
背景技术
目前,越来越多的云服务器厂商提供深度学习模型训练平台,用户在深度学习模型训练平台能够使用不同的深度学习框架进行大规模的训练,以训练得到期望的深度学习模型,如语音识别模型、图片分类模型等。
通常情况下,深度学习模型训练平台在训练深度学习模型时,需要用到的特征包括数据集中每个样本的标签(label)。具体的,将数据集中每个样本的“独热码”作为对应样本的原始标签,对该些原始标签进行标签平滑(label Smoothing)处理,得到各样本的软化标签,进而利用该些软化标签进行深度学习模型训练。其中,软化标签是按照固定的公式直接静态生成。
上述通过静态方式生成软化标签的过程,对数据集中样本之间的异同性利用不足,导致采用该些软化标签训练得到的深度学习模型过拟合,即训练出的深度学习模型只能针对同一规律的样本,无法适应其他规律的新鲜样本,导致深度学习模型准确度低。
发明内容
本申请实施例提供了一种深度学习模型训练方法及装置,通过同时学习深度学习模型和软化标签,得到泛化能力强的深度学习模型,实现提高深度学习模型准确度的目的。
第一方面,本申请实施例提供一种深度学习模型训练方法,服务器接收到终端设备发送的训练请求后,响应该训练请求,对数据集中的样本进行m轮训练,训练过程中,不断的更新数据集中各样本的软化标签,从而得到新的数据集,进而利用该新的数据集进行下一轮的训练。采用该种方案,通过同时学习深度学习模型和软化标签,得到泛化能力强的深度学习模型,实现提高深度学习模型准确度的目的。
第二方面,本申请实施例提供一种深度学习模型训练装置,包括:
输入输出单元,用于接收终端设备发送的训练请求,所述训练请求用于请求训练人工智能深度学习模型;
处理单元,用于根据所述训练请求,利用数据集中的样本对深度学习模型进行m轮训练,其中,所述m轮训练中任意相邻的两轮训练中,后一轮训练的输入是利用前一轮训练的训练结果对所述数据集中各样本的软化标签进行更新得到的,所述m≥2且为整数。
第三方面、本申请实施例提供一种电子设备,包括:
至少一个处理器;以及
与所述至少一个处理器通信连接的存储器;其中,
所述存储器存储有可被所述至少一个处理器执行的指令,所述指令被所述至少一个处理器执行,以使所述至少一个处理器能够执行第一方面或第一方面任意可能实现的方法。
第四方面,本申请实施例提供一种包含指令的计算机程序产品,当其在电子设备上运行时,使得电子设备计算机执行上述第一方面或第一方面的各种可能的实现方式中的方法。
第五方面,本申请实施例提供一种存储有计算机指令的非瞬时计算机可读存储介质,所述计算机指令用于使所述电子设备执行上述第一方面或第一方面的各种可能的实现方式中的方法。
第六方面,本申请实施例提供一种深度学习模型训练方法,包括:将用于第x-1轮训练的数据集输入至第x-1轮深度学习模型,以得到第x-1个数据集,所述数据集是所述第x-1轮训练的输入,对所述第x-1数据集中的样本进行模型训练,以得到第x轮深度学习模型。
上述申请中的一个实施例具有如下优点或有益效果:通过同时学习深度学习模型和软化标签,得到泛化能力强的深度学习模型,实现提高深度学习模型准确度的目的。
应当理解,本部分所描述的内容并非旨在标识本公开的实施例的关键或重要特征,也不用于限制本公开的范围。本公开的其它特征将通过以下的说明书而变得容易理解。
附图说明
附图用于更好地理解本方案,不构成对本申请的限定。其中:
图1是本申请实施例提供的深度学习模型训练方法的网络架构示意图;
图2是本申请实施例提供的深度学习模型训练方法的流程图;
图3是本申请实施例提供的深度学习模型训练方法的过程示意图;
图4是本申请实施例提供的深度学习模型训练方法的另一个流程图;
图5为本公开实施例提供的深度学习模型训练装置的结构示意图;
图6是用来实现本公开实施例的深度学习模型训练方法的电子设备的框图。
具体实施方式
以下结合附图对本申请的示范性实施例做出说明,其中包括本申请实施例的各种细节以助于理解,应当将它们认为仅仅是示范性的。因此,本领域普通技术人员应当认识到,可以对这里描述的实施例做出各种改变和修改,而不会背离本申请的范围和精神。同样,为了清楚和简明,以下的描述中省略了对公知功能和结构的描述。
通常情况下,用于深度学习模型训练的数据集包含多个样本和各样本的软化标签(label),软化标签的获取方式有两种:
第一种,静态方式生成软化标签。
该种方式中,将数据集中每个样本的“独热码(one-hot code)”作为对应样本的原始标签,对该些原始标签进行标签平滑(label Smoothing)处理,得到各样本的软化标签。例如,数据集是一个图片集合,其中共有3种样本,分别为棕榈树、松树和男人,则棕榈树的独热码为[1,0,0],松树的独热码为[0,1,0],男人的独热码为[0,0,1]。对该些独热码进行标签平滑处理时,将独热码中为“1”的元素设置为一个特定值,如0.9,将该独热码中其余的元素设置为0.1/(n-1),其中,n为样本的种类数。如此一来,可以得到棕榈树的软化标签为label(棕榈树):[0.9,0.05,0.05],松树的软化标签为label(松树):[0.05,0.9,0.05],男人的软化标签为label(男人):[0.05,0.05,0.9]。
利用该种方式生成的软化标签对数据集中样本之间的异同性利用不足,导致该些软化标签对样本之间的相似性程度不具备代表性,也就是说,该些软化标签不是效果最佳、包含信息最丰富的软化标签。基于该些软化标签训练出的深度学习模型过于拟合,缺乏泛华能力,即无法深度学习模型只能针对同一规律的样本,无法适应其他规律的新鲜样本,导致深度学习模型准确度低。这是因为采用静态方式生成的软化标签无法分辨出具有一定近似关系的样本。继续以上述的包含棕榈树、松树和男人的数据集为例,该数据集中,棕榈树和男人都是树木,具有一定的相似性,但是棕榈树类的图片和男人类的图片、松树类的图片和男人类的图片的差别较大。然而,图片的相似性一般通过软化标签的欧氏距离度量,显然,上述例子中,3个软化标签中俩俩之间的距离(Distance)相等,即Distance(label(棕榈树),label(松树))=Distance(label(棕榈树),label(男人))=Distance(label(松树),label(男人))。
第二种,动态方式生成软化标签。
该种方式中,利用一个已经训练好的、泛华能力强的深度学习模型对数据集中各样本的软化标签进行预测,利用“知识蒸馏”的相似方法,动态产生数据集中各样本的软化标签。
第二种方式虽然实现了动态获得软化标签的优势,但是在得到感谢软化标签之前,必须先训练得到一个泛华能力强的深度学习模型,之后,才能利用该预先训练好的深度学习模型预测软化标签,需要耗费较多的时间和计算资源。
有鉴于此,本申请实施例提供一种深度学习模型训练方法及装置,通过同时学习深度学习模型和软化标签,得到泛化能力强的深度学习模型,实现提高深度学习模型准确度的目的,同时,避免时间和计算资源的浪费。
图1是本申请实施例提供的深度学习模型训练方法的网络架构示意图。该网络架构包括终端设备1和云环境2,云环境2包括云数据中心和云服务平台,所述云数据中心包括云服务提供商拥有的大量基础资源(包括计算资源、存储资源和网络资源),云数据中心包括的计算资源可以是大量的计算设备(例如服务器)。例如,以云数据中心包括的计算资源是运行有虚拟机的服务器为例,则该服务器可以执行本申请实施例所述的深度学习模型训练方法。
深度学习模型训练过程中,由云服务提供商在云服务平台抽象成一种深度学习模型生成服务提供给用户,用户在云服务平台购买该云服务后(例如,可预充值再根据最终资源的使用情况进行结算),云环境利用部署在云数据中心的服务器等向用户提供深度学习模型训练服务。用户在使用深度学习模型训练服务时,可以通过应用程序接口(application program interface,API)或者用户图形界面(Graphical User Interface,GUI)指定需要深度学习模型完成的任务(即任务目标)、并上传数据集至云环境,云环境中的服务器根据训练请求,执行自动训练深度学习模型的操作。服务器在训练深度学习模型的过程中,训练深度学习模型的同时不断的更新数据集中各样本的软化标签,从而得到新的数据集,进而利用该新的数据集进行下一轮的训练。
完成深度学习模型训练后,服务器通过API或者GUI向用户返回训练好的深度学习模型。该训练好的深度学习模型可被用户下载或者在线使用,用于完成特定的任务。
图1中,终端设备1可以为台式终端或移动终端,台式终端可以为电脑等,移动终端可以为手机、平板电脑、笔记本电脑等,服务器可以是独立的服务器、虚拟机或者多个服务器组成的服务器集群等。
图2是本申请实施例提供的深度学习模型训练方法的流程图,本实施例是从服务器的角度进行说明的,本实施例包括:
101、接收终端设备发送的训练请求,所述训练请求用于请求训练人工智能深度学习模型。
示例性,服务器是能够提供深度学习模型训练平台的服务器,用户通过终端设备登录服务器,通过用户界面上传代码等,触发服务器开始深度学习模型训练。其中,深度学习模型可以是用户定制化的模型,可以是图片分类模型、语音分类模型、语义识别模型、商品推荐模型等任意一种人工智能模型。
102、根据所述训练请求,利用数据集中的样本对深度学习模型进行m轮训练。
其中,所述m轮训练中任意相邻的两轮训练中,后一轮训练的输入是利用前一轮训练的训练结果对所述数据集中各样本的软化标签进行更新得到的,所述m≥2且为整数。
示例性的,服务器在训练深度学习模型的过程中,训练深度学习模型的同时不断的更新数据集中各样本的软化标签,从而得到新的数据集,进而利用该新的数据集进行下一轮的训练。该轮流过程可以拆分为m轮,m的大小和数据集中样本的复杂程度相关,样本越复杂,则m越大,样本越简单,则m越小。每一轮训练中,先利用该轮的数据集训练深度学习模型,训练过程中不断的调整深度学习模型的参数以优化深度学习模型。本轮深度学习模型训练完毕后,利用该深度学习模型预测本轮数据集中各样本的软化标签,进而利用预测出的软化标签更新数据集中各样本的软化标签,得到更新后的。之后,下一轮训练过程中,就可以利用该更新后的数据集进行深度学习模型训练。由此可知,本申请实施例中,任意相邻的两轮训练分别为前一轮训练和后一轮训练,前一轮训练结束后,利用训练结果(即前一轮训练得到的深度学习模型)对前一轮的数据集中各样本的标签进行更新,将更新后的数据集作为后一轮训练的输入。
训练好深度学习模型后,将该深度学习模型部署在服务器上,以将该深度学习模型投入使用。或者,返回该深度学习模型,由用户将该深度学习模型部署在其他服务器上;或者,服务器将训练好的深度学习模型直接发送给需要部署该深度学习模型的服务器。
本申请实施例提供的深度学习模型训练方法,服务器接收到终端设备发送的训练请求后,响应该训练请求,对数据集中的样本进行m轮训练,训练过程中,不断的更新数据集中各样本的软化标签,从而得到新的数据集,进而利用该新的数据集进行下一轮的训练。采用该种方案,通过同时学习深度学习模型和软化标签,得到泛化能力强的深度学习模型,实现提高深度学习模型准确度的目的。
上述实施例中,服务器根据所述训练请求,利用数据集中的样本对深度学习模型进行m轮训练时,利用第x-1轮训练得到第x-1轮深度学习模型后,利用所述第x-1轮模型对所述数据集中各样本的软化标签进行更新,得到第x-1数据集,将所述第x-1数据集作为所述第x轮训练的输入,以执行第x-1轮训练,
示例性的,第x轮训练的输入是第x-1数据集,第x-1数据集中各样本的软化标签是第x-1轮深度学习模型对数据集中各样本的软化标签进行更新得到的,第x-1轮深度学习模型是第x-1轮训练的结果,数据集用于第x-1轮训练,第x-1数据集中的样本与数据集中的样本一一对应,第x-1数据集中的任意一个样本的软化标签用于表征对应样本和x-1数据集中其他样本的相似度,2≤x≤m且为整数。
当x-1=1、x=2时,上述的第x-1轮训练和第x训练分别为第1轮训练和第2轮训练,第1轮训练过程中,数据集中各样本的软化标签可以是利用热独码等生成的。服务器将数据集中的样本遍历t次,训练得到深度学习模型称之为第1轮深度学习模型。然后,服务器利用该第1轮深度学习模型,训练数据集中各样本软化标签,进而利用训练得到的软化标签更新数据集中各样本的软化标签,得到第1数据集。之后,服务器开始第2轮的训练。其中,t为一个大于或等于1的数值。
第2轮训练过程,输入是第1数据集,服务器将第1数据集中的样本遍历t次,训练得到的深度学习模型称之为第2轮深度学习模型。然后,服务器利用该第2轮深度学习模型,训练第1数据集中各样本的软化标签,进而利用训练得到的软化标签更新第1数据集中各样本的软化标签,得到第2数据集。
之后,服务器采用上述类似过程完成第3至第m轮的训练。
假设深度学习模型用于图片分类,数据集中包含三类图片,分别为棕榈树、松树和男人,每一类图片有很多张,则数据集中,棕榈树类的各图片的软化标签为:[0.9,0.05,0.05],松树类的各图片的软化标签为:[0.05,0.9,0.05],男人类的各图片的软化标签为[0.05,0.05,0.9]。经过几轮训练后,棕榈树类的各图片的软化标签为:[0.59,0.03,0.02],松树类的各图片的软化标签为:[0.02,0.97,0.01],男人类的各图片的软化标签为[0.005,0.005,0.99],此时,该些软化标签仍然很接近最初利用热独码得到的软化标签。
继续采用本申请实施例所述的方法,训练完成一半时,软化标签逐渐软化并能反应类别间的相似关系,棕榈树类、松树类和男人类的软化标签依次更新为[0.7,0.27,0.03]、[0.24,0.75,0.01]、[0.02,0.01,0.97]。根据该些软化标签可明显得到:Distance(label(棕榈树),label(松树))<Distance(label(棕榈树),label(男人)),Distance(label(棕榈树),label(松树))<Distance(label(松树),label(男人))。
继续训练,完成m轮训练后,软化标签越发软化并能反应类别间的相似关系,最后一次更新后,样本数据集中,棕榈树类、松树类和男人类的软化标签依次更新为[0.5,0.47,0.03]、[0.34,0.65,0.01]、[0.015,0.015,0.97]。根据该些软化标签可明显得到:Distance(label(棕榈树),label(松树))<Distance(label(棕榈树),label(男人)),Distance(label(棕榈树),label(松树))<Distance(label(松树),label(男人))。
由此可知:本申请实施例训练出的深度学习模型能够学习出区分棕榈树和松树的必要特征,这样学得的深度学习模型更加可靠,也就是泛化能力强。
下面,对上述实施例中,每一轮训练中,如何结束本轮训练进行详细说明。
一种可行的实现方式中,对于所述m轮中的第x轮训练,将所述第x-1数据集中的样本输入至所述第x-1轮深度学习模型以训练得到第x轮深度学习模型,利用所述深度学习模型对应的损失函数判断是否完成所述第x轮深度学习模型的训练。
示例性的,对于任意一轮训练,服务器将上一轮更新后的数据集输入至上一轮训练得到的深度学习模型中,以进行本轮的训练。训练过程中,服务器利用该深度学习模型对应的损失函数不断的计算,根据计算结果确定是否完成本轮训练。其中,损失函数例如为交叉熵损失函数。
例如,服务器判断所述第x-1数据集中的样本是否被所述损失函数遍历预设次数,若所述第x-1数据集中的样本被所述损失函数遍历预设次数,则确定完成所述第x轮深度学习模型的训练。
再如,判断所述损失函数的损失值是否小于预设阈值,若所述损失函数的损失值小于预设阈值,则确定完成所述第x轮深度学习模型的训练。
采用该种方案,通过及时停止任意一轮深度学习模型训练,避免时间和计算资源的浪费。
另一种可行的实现方式中,服务器执行所述m轮中的第x-1轮训练之前,判断所述数据集中每一类样本的软化标签的平均值,以得到各类样本的软化标签,根据各类样本的软化标签,确定所述至少两类样本中第一类样本的软化标签和第二类样本的软化标签,所述第一类样本和所述第二类样本是同一种类的不同子类,根据所述第一类样本的软化标签和所述第二类样本的软化标签,确定是否停止所述深度学习模型训练。
示例性的,每轮训练过程中,服务器都会更新数据集中各样本的软化标签,而数据集中属于同一类的各样本的软化标签的平均值表示该类样本的软化标签。软化标签可以视为一个个的向量,两个软化标签之间的距离,即为两个向量之间的距离。当第一类样本和第二类样本是同一种类的不同子类时,例如,棕榈树和松树同属于树木种类,第一类样本的软化标签和第二类样本的软化标签比较相似。而具有相似关系的两类向量的距离比较小,因此,服务器可以从数据集中确定出两类相似样本,即第一类样本和第二类样本,利用该两类样本的软化标签之间的距离,确定是否停止深度学习模型训练。
本申请实施例中,当更新后的数据集中的软化标签最优时,训练得到的深度学习模型也是最优,因此,可以根据软化标签是否最优,确定深度学习模型是否已经达到最优从而停止本轮深度学习模型训练。
例如,服务器确定所述第一类样本的软化标签和所述第二类样本的软化标签之间的距离是否小于预设阈值,若所述距离小于所述预设阈值,则确定停止所述深度学习模型训练,如距离大于或等于预设阈值,则继续训练深度学习模型。由于深度学习模型训练过程实质上是一个不断优化深度学习模型的过程,使得第x轮训练得到的第x轮深度学习模型往往优于第x-1轮深度学习模型,因此,可以针对每轮训练设置不同的预设阈值,第x轮训练对应的预设阈值小于第x-1轮训练对应的预设阈值。
采用该种方案,通过及时停止任意一轮深度学习模型训练,避免时间和计算资源的浪费。
上述实施例中,对于所述m轮中的第x轮训练,服务器将所述第x-1数据集中的样本输入至所述第x-1轮深度学习模型以训练得到第x轮深度学习模型之前,还确定所述第x-1数据集中的同一类别的样本的软化标签的平均值,将所述平均值作为所述第x-1数据集中的同一类别的各样本的软化标签。
示例性的,得到第x-1轮深度学习模型后,服务器将第x-1数据集中每个样本输入第x-1轮深度学习模型,得到第x-1数据集中各样本的软化标签,之后,服务器求取第x-1数据集中同一类别的样本的软化标签的平均值,从而得来一类样本的软化标签。之后,服务器将平均值作为该类样本中每个样本的软化标签。
采用该种方案,实现在每轮训练中确定出软化标签的目的。
图3是本申请实施例提供的深度学习模型训练方法的过程示意图。请参照图3,深度学习模型训练过程中,服务器在得到数据集后,先对数据集中的样本进行数据预处理,如过滤掉重复样本等,得到经过预处理的数据集。之后,利用该经过预处理的数据集进行深度学习模型训练,训练过程中,每一轮都会利用本轮的训练结果,对本轮的输入数据集中各样本的标签进行更新。执行完m轮训练后,对训练好的深度学习模型进行验证,若验证通过,则将该深度学习模型返回给用户;如未通过验证,则重新收集样本、调整算法等继续训练深度学习模型,直到通过验证。后续使用训练好的深度学习模型时,将未见过的数据输入至训练好的深度学习模型,得到预测结果。
图4是本申请实施例提供的深度学习模型训练方法的另一个流程图,包括如下步骤:
201、初始化深度学习模型。
示例性,在执行第1轮的深度学习模型训练之前,服务器先初始化一个深度学习模型。初始化过程中,可以根据需求选择深度学习模型,如语义类模型、图片类模型等。
202、根据数据集中各样本的独热码,确定各样本的软化标签,并进行第1轮的训练。
203、利用本轮得到的深度学习模型更新数据集中个样本的软化标签。
示例性的,每轮训练中,利用本轮所得模型确定本轮输入的数据集中每个样本的软化标签,将同一类样本的软化标签的平均值作为该类样本中每个样本的软化标签。
204、判断是否是最后一轮训练,如是最后一轮训练,则执行步骤205;若不是最后一轮训练,则执行步骤206。
205、训练完成返回深度学习模型。
206、利用更新后的软化标签,开始下一轮的深度学习模型训练,之后,执行步骤203。
上述介绍了本公开实施例提到的深度学习模型训练方法的具体实现,下述为本公开装置实施例,可以用于执行本公开方法实施例。对于本公开装置实施例中未披露的细节,请参照本公开方法实施例。
图5为本公开实施例提供的深度学习模型训练装置的结构示意图。该装置可以集成在服务器中或通过服务器实现。如图5所示,在本实施例中,该深度学习模型训练装置100可以包括:
输入输出单元11,用于接收终端设备发送的训练请求,所述训练请求用于请求训练人工智能深度学习模型;
处理单元12,用于根据所述训练请求,利用数据集中的样本对深度学习模型进行m轮训练,其中,所述m轮训练中任意相邻的两轮训练中,后一轮训练的输入是利用前一轮训练的训练结果对所述数据集中各样本的软化标签进行更新得到的,所述m≥2且为整数。
一种可行的设计中,所述m轮训练包括第x-1轮训练和第x轮训练,所述处理单元12,用于利用第x-1轮训练得到第x-1轮深度学习模型后,利用所述第x-1轮模型对所述数据集中各样本的软化标签进行更新,得到第x-1数据集,将所述第x-1数据集作为所述第x轮训练的输入,以执行第x-1轮训练。
一种可行的设计中,所述处理单元12,用于对于所述m轮中的第x轮训练,将所述第x-1数据集中的样本输入至所述第x-1轮深度学习模型以训练得到第x轮深度学习模型,利用所述深度学习模型对应的损失函数判断是否完成所述第x轮深度学习模型的训练。
一种可行的设计中,所述处理单元12,用于执行所述m轮中的第x-1轮训练之前,判断所述数据集中每一类样本的软化标签的平均值,以得到各类样本的软化标签,根据各类样本的软化标签,确定所述至少两类样本中第一类样本的软化标签和第二类样本的软化标签,所述第一类样本和所述第二类样本是同一种类的不同子类,根据所述第一类样本的软化标签和所述第二类样本的软化标签,确定是否停止所述深度学习模型训练。
一种可行的设计中,所述处理单元12,在根据所述第一类样本的软化标签和所述第二类样本的软化标签,确定是否停止所述深度学习模型训练时,用于确定所述第一类样本的软化标签和所述第二类样本的软化标签之间的距离是否小于预设阈值,若所述距离小于所述预设阈值,则确定停止所述深度学习模型训练。
一种可行的设计中,所述处理单元12,对于所述m轮中的第x轮训练,将所述第x-1数据集中的样本输入至所述第x-1轮深度学习模型以训练得到第x轮深度学习模型之前,还用于确定所述第x-1数据集中的同一类别的样本的软化标签的平均值,将所述平均值作为所述第x-1数据集中的同一类别的各样本的软化标签。
一种可行的设计中,当x=2时,所述数据集中各样本的软化标签是利用对应样本的独热码得到的。
一种可行的设计中,所述输入输出单元11,在所述处理单元12根据所述训练请求,利用数据集中的样本对深度学习模型进行m轮训练之后,还用于输出经过m轮训练的深度学习模型。
本公开实施例提供的深度学习模型训练装置,可用于如上实施例中服务器执行的方法,其实现原理和技术效果类似,在此不再赘述。
图6是用来实现本公开实施例的深度学习模型训练方法的电子设备的框图。电子设备旨在表示各种形式的数字计算机,诸如,膝上型计算机、台式计算机、工作台、个人数字助理、服务器、刀片式服务器、大型计算机、和其它适合的计算机。电子设备还可以表示各种形式的移动装置,诸如,个人数字处理、蜂窝电话、智能电话、可穿戴设备和其它类似的计算装置。本文所示的部件、它们的连接和关系、以及它们的功能仅仅作为示例,并且不意在限制本文中描述的和/或者要求的本申请的实现。
如图6所示,该电子设备包括:一个或多个处理器21、存储器22,以及用于连接各部件的接口,包括高速接口和低速接口。各个部件利用不同的总线互相连接,并且可以被安装在公共主板上或者根据需要以其它方式安装。处理器可以对在电子设备内执行的指令进行处理,包括存储在存储器中或者存储器上以在外部输入/输出装置(诸如,耦合至接口的显示设备)上显示GUI的图形信息的指令。在其它实施方式中,若需要,可以将多个处理器和/或多条总线与多个存储器和多个存储器一起使用。同样,可以连接多个电子设备,各个设备提供部分必要的操作(例如,作为服务器阵列、一组刀片式服务器、或者多处理器系统)。图6中以一个处理器21为例。
存储器22即为本申请所提供的非瞬时计算机可读存储介质。其中,所述存储器存储有可由至少一个处理器执行的指令,以使所述至少一个处理器执行本申请所提供的深度学习模型训练方法。本申请的非瞬时计算机可读存储介质存储计算机指令,该计算机指令用于使计算机执行本申请所提供的深度学习模型训练方法。
存储器22作为一种非瞬时计算机可读存储介质,可用于存储非瞬时软件程序、非瞬时计算机可执行程序以及模块,如本申请实施例中的深度学习模型训练方法对应的程序指令/模块(例如,附图5所示的输入输出单元11、处理单元12)。处理器21通过运行存储在存储器22中的非瞬时软件程序、指令以及模块,从而执行服务器的各种功能应用以及数据处理,即实现上述方法实施例中的深度学习模型训练方法。
存储器22可以包括存储程序区和存储数据区,其中,存储程序区可存储操作系统、至少一个功能所需要的应用程序;存储数据区可存储根据深度学习模型训练电子设备的使用所创建的数据等。此外,存储器22可以包括高速随机存取存储器,还可以包括非瞬时存储器,例如至少一个磁盘存储器件、闪存器件、或其他非瞬时固态存储器件。在一些实施例中,存储器22可选包括相对于处理器21远程设置的存储器,这些远程存储器可以通过网络连接至深度学习模型训练电子设备。上述网络的实例包括但不限于互联网、企业内部网、局域网、移动通信网及其组合。
深度学习模型训练方法的电子设备还可以包括:输入装置23和输出装置24。处理器21、存储器22、输入装置23和输出装置24可以通过总线或者其他方式连接,图6中以通过总线连接为例。
输入装置23可接收输入的数字或字符信息,以及产生与深度学习模型训练电子设备的用户设置以及功能控制有关的键信号输入,例如触摸屏、小键盘、鼠标、轨迹板、触摸板、指示杆、一个或者多个鼠标按钮、轨迹球、操纵杆等输入装置。输出装置24可以包括显示设备、辅助照明装置(例如,LED)和触觉反馈装置(例如,振动电机)等。该显示设备可以包括但不限于,液晶显示器(LCD)、发光二极管(LED)显示器和等离子体显示器。在一些实施方式中,显示设备可以是触摸屏。
此处描述的系统和技术的各种实施方式可以在数字电子电路系统、集成电路系统、专用ASIC(专用集成电路)、计算机硬件、固件、软件、和/或它们的组合中实现。这些各种实施方式可以包括:实施在一个或者多个计算机程序中,该一个或者多个计算机程序可在包括至少一个可编程处理器的可编程系统上执行和/或解释,该可编程处理器可以是专用或者通用可编程处理器,可以从存储系统、至少一个输入装置、和至少一个输出装置接收数据和指令,并且将数据和指令传输至该存储系统、该至少一个输入装置、和该至少一个输出装置。
这些计算程序(也称作程序、软件、软件应用、或者代码)包括可编程处理器的机器指令,并且可以利用高级过程和/或面向对象的编程语言、和/或汇编/机器语言来实施这些计算程序。如本文使用的,术语“机器可读介质”和“计算机可读介质”指的是用于将机器指令和/或数据提供给可编程处理器的任何计算机程序产品、设备、和/或装置(例如,磁盘、光盘、存储器、可编程逻辑装置(PLD)),包括,接收作为机器可读信号的机器指令的机器可读介质。术语“机器可读信号”指的是用于将机器指令和/或数据提供给可编程处理器的任何信号。
为了提供与用户的交互,可以在计算机上实施此处描述的系统和技术,该计算机具有:用于向用户显示信息的显示装置(例如,CRT(阴极射线管)或者LCD(液晶显示器)监视器);以及键盘和指向装置(例如,鼠标或者轨迹球),用户可以通过该键盘和该指向装置来将输入提供给计算机。其它种类的装置还可以用于提供与用户的交互;例如,提供给用户的反馈可以是任何形式的传感反馈(例如,视觉反馈、听觉反馈、或者触觉反馈);并且可以用任何形式(包括声输入、语音输入或者、触觉输入)来接收来自用户的输入。
可以将此处描述的系统和技术实施在包括后台部件的计算系统(例如,作为数据服务器)、或者包括中间件部件的计算系统(例如,应用服务器)、或者包括前端部件的计算系统(例如,具有图形用户界面或者网络浏览器的用户计算机,用户可以通过该图形用户界面或者该网络浏览器来与此处描述的系统和技术的实施方式交互)、或者包括这种后台部件、中间件部件、或者前端部件的任何组合的计算系统中。可以通过任何形式或者介质的数字数据通信(例如,通信网络)来将系统的部件相互连接。通信网络的示例包括:局域网(LAN)、广域网(WAN)和互联网。
计算机系统可以包括客户端和服务器。客户端和服务器一般远离彼此并且通常通过通信网络进行交互。通过在相应的计算机上运行并且彼此具有客户端-服务器关系的计算机程序来产生客户端和服务器的关系。
本申请实施例还提供一种深度学习模型训练方法,包括:将用于第x-1轮训练的数据集输入至第x-1轮深度学习模型,以得到第x-1个数据集,所述数据集是所述第x-1轮训练的输入,对所述第x-1数据集中的样本进行模型训练,以得到第x轮深度学习模型。
该实施例的具体实现原理可以参见上述实施例的记载,此处不再赘述。
根据本申请实施例的技术方案,服务器接收到终端设备发送的训练请求后,响应该训练请求,对数据集中的样本进行m轮训练,训练过程中,不断的更新数据集中各样本的软化标签,从而得到新的数据集,进而利用该新的数据集进行下一轮的训练。采用该种方案,通过同时学习深度学习模型和软化标签,得到泛化能力强的深度学习模型,实现提高深度学习模型准确度的目的。
应该理解,可以使用上面所示的各种形式的流程,重新排序、增加或删除步骤。例如,本发申请中记载的各步骤可以并行地执行也可以顺序地执行也可以不同的次序执行,只要能够实现本申请公开的技术方案所期望的结果,本文在此不进行限制。
上述具体实施方式,并不构成对本申请保护范围的限制。本领域技术人员应该明白的是,根据设计要求和其他因素,可以进行各种修改、组合、子组合和替代。任何在本申请的精神和原则之内所作的修改、等同替换和改进等,均应包含在本申请保护范围之内。

Claims (19)

1.一种深度学习模型训练方法,其特征在于,包括:
接收终端设备发送的训练请求,所述训练请求用于请求训练人工智能深度学习模型;
根据所述训练请求,利用数据集中的样本对深度学习模型进行m轮训练,其中,所述m轮训练中任意相邻的两轮训练中,后一轮训练的输入是利用前一轮训练的训练结果对所述数据集中各样本的软化标签进行更新得到的,所述m≥2且为整数。
2.根据权利要求1所述的方法,其特征在于,所述m轮训练包括第x-1轮训练和第x轮训练,所述根据所述训练请求,利用数据集中的样本对深度学习模型进行m轮训练,包括:
利用第x-1轮训练得到第x-1轮深度学习模型后,利用所述第x-1轮模型对所述数据集中各样本的软化标签进行更新,得到第x-1数据集;
将所述第x-1数据集作为所述第x轮训练的输入,以执行第x-1轮训练。
3.根据权利要求2所述的方法,其特征在于,所述根据所述训练请求,利用数据集中的样本对深度学习模型进行m轮训练,包括:
对于所述m轮中的第x轮训练,将所述第x-1数据集中的样本输入至所述第x-1轮深度学习模型以训练得到第x轮深度学习模型;
利用所述深度学习模型对应的损失函数判断是否完成所述第x轮深度学习模型的训练。
4.根据权利要求2所述的方法,其特征在于,所述根据所述训练请求,利用数据集中的样本对深度学习模型进行m轮训练,包括:
执行所述m轮中的第x-1轮训练之前,判断所述数据集中每一类样本的软化标签的平均值,以得到各类样本的软化标签;
根据各类样本的软化标签,确定至少两类样本中第一类样本的软化标签和第二类样本的软化标签,所述第一类样本和所述第二类样本是同一种类的不同子类;
根据所述第一类样本的软化标签和所述第二类样本的软化标签,确定是否停止所述深度学习模型训练。
5.根据权利要求4所述的方法,其特征在于,所述根据所述第一类样本的软化标签和所述第二类样本的软化标签,确定是否停止所述深度学习模型训练,包括:
确定所述第一类样本的软化标签和所述第二类样本的软化标签之间的距离是否小于预设阈值,若所述距离小于所述预设阈值,则确定停止所述深度学习模型训练。
6.根据权利要求2-5任一项所述的方法,其特征在于,所述对于所述m轮中的第x轮训练,将所述第x-1数据集中的样本输入至所述第x-1轮深度学习模型以训练得到第x轮深度学习模型之前,还包括:
确定所述第x-1数据集中的同一类别的样本的软化标签的平均值,将所述平均值作为所述第x-1数据集中的同一类别的各样本的软化标签。
7.根据权利要求2-5任一项所述的方法,其特征在于,
当x=2时,所述数据集中各样本的软化标签是利用对应样本的独热码得到的。
8.根据权利要求1~5任一项所述的方法,其特征在于,所述根据所述训练请求,利用数据集中的样本对深度学习模型进行m轮训练之后,还包括:
输出经过m轮训练的深度学习模型。
9.一种深度学习模型训练装置,其特征在于,包括:
输入输出单元,用于接收终端设备发送的训练请求,所述训练请求用于请求训练人工智能深度学习模型;
处理单元,用于根据所述训练请求,利用数据集中的样本对深度学习模型进行m轮训练,其中,所述m轮训练中任意相邻的两轮训练中,后一轮训练的输入是利用前一轮训练的训练结果对所述数据集中各样本的软化标签进行更新得到的,所述m≥2且为整数。
10.根据权利要求9所述的装置,其特征在于,所述m轮训练包括第x-1轮训练和第x轮训练,所述处理单元,用于利用第x-1轮训练得到第x-1轮深度学习模型后,利用所述第x-1轮模型对所述数据集中各样本的软化标签进行更新,得到第x-1数据集,将所述第x-1数据集作为所述第x轮训练的输入,以执行第x-1轮训练。
11.根据权利要求10所述的装置,其特征在于,
所述处理单元,用于对于所述m轮中的第x轮训练,将所述第x-1数据集中的样本输入至所述第x-1轮深度学习模型以训练得到第x轮深度学习模型,利用所述深度学习模型对应的损失函数判断是否完成所述第x轮深度学习模型的训练。
12.根据权利要求10所述的装置,其特征在于,
所述处理单元,用于执行所述m轮中的第x-1轮训练之前,判断所述数据集中每一类样本的软化标签的平均值,以得到各类样本的软化标签,根据各类样本的软化标签,确定所述至少两类样本中第一类样本的软化标签和第二类样本的软化标签,所述第一类样本和所述第二类样本是同一种类的不同子类,根据所述第一类样本的软化标签和所述第二类样本的软化标签,确定是否停止所述深度学习模型训练。
13.根据权利要求12所述的装置,其特征在于,
所述处理单元,在根据所述第一类样本的软化标签和所述第二类样本的软化标签,确定是否停止所述深度学习模型训练时,用于确定所述第一类样本的软化标签和所述第二类样本的软化标签之间的距离是否小于预设阈值,若所述距离小于所述预设阈值,则确定停止所述深度学习模型训练。
14.根据权利要求10~13任一项所述的装置,其特征在于,
所述处理单元,对于所述m轮中的第x轮训练,将所述第x-1数据集中的样本输入至所述第x-1轮深度学习模型以训练得到第x轮深度学习模型之前,还用于确定所述第x-1数据集中的同一类别的样本的软化标签的平均值,将所述平均值作为所述第x-1数据集中的同一类别的各样本的软化标签。
15.根据权利要求10~13任一项所述的装置,其特征在于,
当x=2时,所述数据集中各样本的软化标签是利用对应样本的独热码得到的。
16.根据权利要求9~13任一项所述的装置,其特征在于,
所述输入输出单元,在所述处理单元根据所述训练请求,利用数据集中的样本对深度学习模型进行m轮训练之后,还用于输出经过m轮训练的深度学习模型。
17.一种电子设备,其特征在于,包括:
至少一个处理器;以及
与所述至少一个处理器通信连接的存储器;其中,
所述存储器存储有可被所述至少一个处理器执行的指令,所述指令被所述至少一个处理器执行,以使所述至少一个处理器能够执行权利要求1-8中任一项所述的方法。
18.一种存储有计算机指令的非瞬时计算机可读存储介质,其特征在于,所述计算机指令用于使所述计算机执行权利要求1-8中任一项所述的方法。
19.一种深度学习模型训练方法,其特征在于,包括:
将用于第x-1轮训练的数据集输入至第x-1轮深度学习模型,以得到第x-1个数据集,所述数据集是所述第x-1轮训练的输入;
对所述第x-1数据集中的样本进行模型训练,以得到第x轮深度学习模型。
CN202010247381.9A 2020-03-31 2020-03-31 深度学习模型训练方法及装置 Active CN111461345B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202010247381.9A CN111461345B (zh) 2020-03-31 2020-03-31 深度学习模型训练方法及装置

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202010247381.9A CN111461345B (zh) 2020-03-31 2020-03-31 深度学习模型训练方法及装置

Publications (2)

Publication Number Publication Date
CN111461345A true CN111461345A (zh) 2020-07-28
CN111461345B CN111461345B (zh) 2023-08-11

Family

ID=71681403

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202010247381.9A Active CN111461345B (zh) 2020-03-31 2020-03-31 深度学习模型训练方法及装置

Country Status (1)

Country Link
CN (1) CN111461345B (zh)

Cited By (8)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN111881187A (zh) * 2020-08-03 2020-11-03 深圳诚一信科技有限公司 一种自动建立数据处理模型的方法及相关产品
CN112491820A (zh) * 2020-11-12 2021-03-12 新华三技术有限公司 异常检测方法、装置及设备
CN113343787A (zh) * 2021-05-20 2021-09-03 沈阳铸造研究所有限公司 一种基于深度学习的适用于图谱对比场景中等级评定方法
CN113627610A (zh) * 2021-08-03 2021-11-09 北京百度网讯科技有限公司 用于表箱预测的深度学习模型训练方法及表箱预测方法
CN113656669A (zh) * 2021-10-19 2021-11-16 北京芯盾时代科技有限公司 标签更新方法及装置
CN113792883A (zh) * 2021-03-03 2021-12-14 京东科技控股股份有限公司 基于联邦学习的模型训练方法、装置、设备和介质
CN113986561A (zh) * 2021-12-28 2022-01-28 苏州浪潮智能科技有限公司 人工智能任务处理方法、装置、电子设备及可读存储介质
CN113792883B (zh) * 2021-03-03 2024-04-16 京东科技控股股份有限公司 基于联邦学习的模型训练方法、装置、设备和介质

Citations (6)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN104966105A (zh) * 2015-07-13 2015-10-07 苏州大学 一种鲁棒机器错误检索方法与系统
CN105335756A (zh) * 2015-10-30 2016-02-17 苏州大学 一种鲁棒学习模型与图像分类系统
CN107316083A (zh) * 2017-07-04 2017-11-03 北京百度网讯科技有限公司 用于更新深度学习模型的方法和装置
CN108334943A (zh) * 2018-01-03 2018-07-27 浙江大学 基于主动学习神经网络模型的工业过程半监督软测量建模方法
US20180268292A1 (en) * 2017-03-17 2018-09-20 Nec Laboratories America, Inc. Learning efficient object detection models with knowledge distillation
CN110555870A (zh) * 2019-09-09 2019-12-10 北京理工大学 基于神经网络的dcf跟踪置信度评价与分类器更新方法

Patent Citations (6)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN104966105A (zh) * 2015-07-13 2015-10-07 苏州大学 一种鲁棒机器错误检索方法与系统
CN105335756A (zh) * 2015-10-30 2016-02-17 苏州大学 一种鲁棒学习模型与图像分类系统
US20180268292A1 (en) * 2017-03-17 2018-09-20 Nec Laboratories America, Inc. Learning efficient object detection models with knowledge distillation
CN107316083A (zh) * 2017-07-04 2017-11-03 北京百度网讯科技有限公司 用于更新深度学习模型的方法和装置
CN108334943A (zh) * 2018-01-03 2018-07-27 浙江大学 基于主动学习神经网络模型的工业过程半监督软测量建模方法
CN110555870A (zh) * 2019-09-09 2019-12-10 北京理工大学 基于神经网络的dcf跟踪置信度评价与分类器更新方法

Non-Patent Citations (3)

* Cited by examiner, † Cited by third party
Title
DONG WANG等: "Label-Denoising Auto-encoder for Classification with Inaccurate Supervision Information", 《2014 22ND INTERNATIONAL CONFERENCE ON PATTERN RECOGNITION》, pages 3648 - 3653 *
LI DING等: "Weakly-Supervised Action Segmentation with Iterative Soft Boundary Assignment", 《PROCEEDINGS OF THE IEEE CONFERENCE ON COMPUTER VISION AND PATTERN RECOGNITION (CVPR)》, pages 6508 - 6516 *
赵旦峰等: "基于后验概率判决的动态迭代停止算法", 《吉林大学学报(工学版)》, vol. 42, no. 3, pages 766 - 770 *

Cited By (11)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN111881187A (zh) * 2020-08-03 2020-11-03 深圳诚一信科技有限公司 一种自动建立数据处理模型的方法及相关产品
CN112491820A (zh) * 2020-11-12 2021-03-12 新华三技术有限公司 异常检测方法、装置及设备
CN112491820B (zh) * 2020-11-12 2022-07-29 新华三技术有限公司 异常检测方法、装置及设备
CN113792883A (zh) * 2021-03-03 2021-12-14 京东科技控股股份有限公司 基于联邦学习的模型训练方法、装置、设备和介质
CN113792883B (zh) * 2021-03-03 2024-04-16 京东科技控股股份有限公司 基于联邦学习的模型训练方法、装置、设备和介质
CN113343787A (zh) * 2021-05-20 2021-09-03 沈阳铸造研究所有限公司 一种基于深度学习的适用于图谱对比场景中等级评定方法
CN113343787B (zh) * 2021-05-20 2023-09-01 中国机械总院集团沈阳铸造研究所有限公司 一种基于深度学习的适用于图谱对比场景中等级评定方法
CN113627610A (zh) * 2021-08-03 2021-11-09 北京百度网讯科技有限公司 用于表箱预测的深度学习模型训练方法及表箱预测方法
CN113656669A (zh) * 2021-10-19 2021-11-16 北京芯盾时代科技有限公司 标签更新方法及装置
CN113656669B (zh) * 2021-10-19 2023-12-05 北京芯盾时代科技有限公司 标签更新方法及装置
CN113986561A (zh) * 2021-12-28 2022-01-28 苏州浪潮智能科技有限公司 人工智能任务处理方法、装置、电子设备及可读存储介质

Also Published As

Publication number Publication date
CN111461345B (zh) 2023-08-11

Similar Documents

Publication Publication Date Title
CN111461345B (zh) 深度学习模型训练方法及装置
US20210256403A1 (en) Recommendation method and apparatus
US11386128B2 (en) Automatic feature learning from a relational database for predictive modelling
US10762678B2 (en) Representing an immersive content feed using extended reality based on relevancy
US11763084B2 (en) Automatic formulation of data science problem statements
CN109918662B (zh) 一种电子资源的标签确定方法、装置和可读介质
CN111667056B (zh) 用于搜索模型结构的方法和装置
CN111582479B (zh) 神经网络模型的蒸馏方法和装置
CN111104514A (zh) 文档标签模型的训练方法及装置
CN109471978B (zh) 一种电子资源推荐方法及装置
US11599826B2 (en) Knowledge aided feature engineering
US20190156177A1 (en) Aspect Pre-selection using Machine Learning
WO2022018538A1 (en) Identifying source datasets that fit transfer learning process for target domain
US20240112229A1 (en) Facilitating responding to multiple product or service reviews associated with multiple sources
CN111966361A (zh) 用于确定待部署模型的方法、装置、设备及其存储介质
US20190354849A1 (en) Automatic data preprocessing
CN112288483A (zh) 用于训练模型的方法和装置、用于生成信息的方法和装置
US9843837B2 (en) Cross-platform analysis
CN114360027A (zh) 一种特征提取网络的训练方法、装置及电子设备
US20220129794A1 (en) Generation of counterfactual explanations using artificial intelligence and machine learning techniques
CN113642635A (zh) 模型训练方法及装置、电子设备和介质
CN110348581B (zh) 用户特征群中用户特征寻优方法、装置、介质及电子设备
CN114092608B (zh) 表情的处理方法及装置、计算机可读存储介质、电子设备
CN115048425A (zh) 一种基于强化学习的数据筛选方法及其装置
CN112328710A (zh) 实体信息处理方法、装置、电子设备和存储介质

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant