CN114492788A - 训练深度学习模型的方法和装置、电子设备及存储介质 - Google Patents
训练深度学习模型的方法和装置、电子设备及存储介质 Download PDFInfo
- Publication number
- CN114492788A CN114492788A CN202111683696.9A CN202111683696A CN114492788A CN 114492788 A CN114492788 A CN 114492788A CN 202111683696 A CN202111683696 A CN 202111683696A CN 114492788 A CN114492788 A CN 114492788A
- Authority
- CN
- China
- Prior art keywords
- training
- data
- sample data
- inference
- correct
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N5/00—Computing arrangements using knowledge-based models
- G06N5/04—Inference or reasoning models
Abstract
本公开提供用于训练深度学习模型的方法和装置、电子设备及存储介质,该方法包括:使用第一训练数据集开始对待训练模型进行迭代训练;筛选推理错误的样本数据和推理正确的样本数据;对所筛选的推理错误的样本数据进行数据增强;基于经所述数据增强的推理错误的样本数据和所筛选的推理正确的样本数据生成第二训练数据集;以及在使用所述第二训练数据集进行进一步训练后,使用所述第一训练数据集继续进行迭代训练。本公开能够在深度学习模型的训练过程中基于推理错误的样本数据进行动态调整生成新的训练数据集,从而优化深度学习模型的训练过程,进而改善深度学习模型的训练效率和训练后的预测性能。
Description
技术领域
本公开涉及计算机技术领域,具体涉及人工智能和深度学习技术,尤其涉及训练深度学习模型的方法和装置、电子设备及存储介质。
背景技术
随着人工智能技术的发展和人工智能应用领域的不断拓展,对深度学习模型提出了更高的要求。深度学习模型需要通过基于现有数据的训练过程收敛到所需的预测性能(例如分类、检测),因此用于训练过程的数据集(即训练数据集)直接影响深度学习模型的训练效率和任务预测性能。通常,可获得的用于训练数据集的样本数据分布不均,并且在训练过程中训练数据集固定不变,导致深度学习模型在诸如训练效率和预测性能等方面不够理想。
发明内容
本公开旨在提供一种用于训练深度学习模型的方法和装置、电子设备及存储介质,以至少解决上述技术问题。
根据本公开的一方面,提供了一种训练深度学习模型的方法,包括:
使用第一训练数据集开始对待训练模型进行迭代训练;筛选推理错误的样本数据和推理正确的样本数据;对所筛选的推理错误的样本数据进行数据增强;基于经所述数据增强的推理错误的样本数据和所筛选的推理正确的样本数据生成第二训练数据集;以及在使用所述第二训练数据集进行进一步训练后,使用所述第一训练数据集继续进行迭代训练。
根据本公开的另一方面,提供了一种用于训练深度学习模型的装置,包括训练单元、数据筛选单元、数据增强单元以及生成单元,其中:
所述训练单元用于使用第一训练数据集开始对待训练模型进行迭代训练;所述数据筛选单元用于筛选推理错误的样本数据和推理正确的样本数据;所述数据增强单元用于对所筛选的推理错误的样本数据进行数据增强;所述生成单元用于基于经所述数据增强的推理错误的样本数据和所筛选的推理正确的样本数据生成第二训练数据集;所述训练单元还用于使用所述第二训练数据集进行进一步训练,并且还用于在所述进一步训练后使用所述第一训练数据集继续进行迭代训练。
根据本公开的又一方面,提供了一种电子设备,包括:至少一个处理器;以及与所述至少一个处理器通信连接的存储器,其中所述存储器存储有可被至少一个处理器执行的指令,所述指令在被所述至少一个处理器执行时使得所述电子设备执行前述的训练深度学习模型的方法。
根据本公开的又一方面,提供了一种存储有计算机指令的非瞬时计算机可读存储介质,其中所述计算机指令用于使计算机执行前述的训练深度学习模型的方法。
根据本公开的技术方案能够在深度学习模型的训练过程中基于推理错误的样本数据进行动态调整生成新的训练数据集,从而优化深度学习模型的训练过程,进而改善深度学习模型的训练效率和训练后的预测性能。
应当理解,本部分所描述的内容并非旨在表示本公开的实施方式的关键或重要的目的、特征以及技术效果,也不用于限制本公开的范围。本公开的其它目的、特征以及技术效果将通过以下的说明书而变得容易理解。
附图说明
附图仅用于更好地理解本公开的实施方式和实施例,并不旨在构成对本公开的限定。其中:
图1是根据本公开的一些实施方式的训练深度学习模型的方法的示意性流程图;
图2是根据本公开的一些实施例的训练深度学习模型的方法的示意图;
图3是根据本公开的一些实施方式的用于训练深度学习模型的装置的示意性结构图;
图4是根据本公开的一些实施方式的电子设备的示意性结构图。
具体实施方式
以下结合附图对本公开的示意性实施方式和实施例做出说明,其中包括本公开实施方式和实施例的多种细节以帮助理解,应当将它们认为仅仅是示例性的。因此,本领域普通技术人员应当认识到,可以对这里说明的实施方式和实施例做出多种改变和修改,而不会背离本公开的范围和构思。同样,为了清楚和简洁,以下的说明中省略了对公知功能和结构的描述。
通常,可获得的作为用于训练深度学习模型的训练数据集的样本数据分布不均,例如在分类任务中,所有需要预测的类别中有一个或者多个类别的样本量非常少或不同类别样本的比例相差悬殊,会导致样本量较少的一类的高错分率,即样本量较少的一类的样本会有较大的比例会被错误地预测成样本量较多的一类。目前,针对样本数据分布不均的这种问题主要采用数据增强,以图像为例,在人工观察到某一类样本稀少的时候,可以使用图片翻转、缩放、裁剪、位移、添加高斯噪声、对抗网络生成等方式,但即使在进行数据增强后也会存在分布依然失衡的情况,因此进一步适当地复制稀少的样本数据来进行弥补,然而,这种方式一方面需要大量的人力和时间成本来筛选和增强稀少的样本数据,另一方面样本数据分布的均衡化效果仍不理想且均衡化难度较大。以图像为例,通过人工方式可容易观察到图片在特征维度A(实物类别)、B(实物形状)、C(亮度情况)等上的分布并刻意地去对这些特征维度上稀少的样本数据进行数据增强以使训练数据集尽可能在这些可见的特征维度接近均衡,但实际上该训练数据集的样本数据可能还具有更多不易观察到的特征维度α、β、γ等,因此完全依赖人工观察的数据增强方式无法真正接近理想化的均衡。
另外,在深度学习模型的训练过程中训练数据集通常固定不变,但实际上到了深度学习的训练后期,权重的更新仅仅来自于其中一小部分样本数据的贡献,因此诸如准确率的性能提升也会逐步放缓,这意味着硬件上的算力也在很大程度上浪费掉,导致训练效率不高。
鉴于此,本公开旨在提供改进的深度学习模型训练方案,包括用于训练深度学习模型的方法和装置、电子设备及存储介质。
本公开提供的训练深度学习模型的方法可适用于各种基于深度学习的模型,例如包括当前已知的模型和未来基于这些已知的模型改进或演进的模型。
上述方法可应用于服务器端。在这里,服务器可以是硬件或软件,可以实现为单个硬件模块或单个软件模块,也可以实现为提供分布式服务的多个硬件或软件模块。服务器端可以获取模型的初始信息,包括模型的结构信息和参数信息,还可以获取用于训练模型的训练数据,然后构造监督函数,利用训练数据来对模型进行迭代训练。
上述方法也可以应用于具有数据处理能力的终端设备,终端设备也可以利用处理器(如图形处理器(GPU)、中央处理单元(CPU)等)对模型进行迭代训练。
图1是根据本公开的一些实施方式的训练深度学习模型的方法的示意性流程图。如图1所示,该方法100包括以下步骤:
S110:使用第一训练数据集开始对待训练模型进行迭代训练。
在一些实施方式中,首先可由训练深度学习模型的方法的执行主体(例如服务器)获取第一训练数据集,例如,第一训练数据集可以预先存储在执行主体本地,也可以通过网络爬虫获取。第一训练数据集可以包括诸如图像、文本、语音等多媒体数据的样本数据,样本数据可以包含该样本数据所属类别的标签信息,例如图像中的目标对象的类别信息、图像的质量类别信息、文本语句的语言类别信息、语音信号对应的个人身份信息,等等。第一训练数据集可以是未经数据增强的,也可以是经数据增强的。
在一些实施方式中,待训练模型可以是未经训练的模型或经过初步训练的模型,在实际场景中,可以是线上运行的模型或者未上线的模型。待训练模型可以是用于执行分类或检测任务的模型,例如,图像分类模型、语音识别模型等等。
可以理解,本文的迭代训练是指多轮(epoch)的迭代训练。
S120:筛选推理错误的样本数据和推理正确的样本数据。
在一些实施方式中,步骤S120可以包括如下子步骤:
S1201:基于在至少一轮的迭代训练中推理错误的样本数据构成推理错误信息集,其中所述推理错误信息集中包括所述推理错误的样本数据的索引和在所述至少一轮的迭代训练中的推理错误置信度。
在一些实施方式中,可以在确定第N轮迭代训练的结果满足预设条件时,从第N+1轮迭代训练开始确定每轮迭代训练中推理错误的样本数据以用于构成推理错误信息集,直到基于推理错误信息集确定截止到第M轮迭代训练的总推理错误次数超过预设次数,其中N是不小于1的正整数,M是大于N的正整数。可以理解,在第N+1轮迭代训练中确定的推理错误的样本数据用来创建推理错误信息集,而从第N+2轮迭代训练开始,在每轮迭代训练中确定的推理错误的样本数据则用来更新推理错误信息集。在一些实施方式中,可以从第L轮迭代训练开始判断每轮迭代训练的结果是否满足预设条件,直到确定第N轮迭代训练的结果满足预设条件,其中L是不大于N的正整数。
应当理解,推理错误的样本数据表示在使用它对模型进行训练后由模型输出的推理结果和其实际标签信息不一致的样本数据,其中推理结果可以表现为多个置信度中最高的一个置信度对应的类别。
在一些实施方式中,迭代训练的结果可以是训练精度(accuracy)、训练损失(loss)和已训练轮数中的至少一者,且预设条件可以是不小于预设的训练精度阈值、不大于预设的训练损失阈值以及不小于预设的训练轮数阈值中相应的至少一者。预设的阈值可以根据待训练模型和第一训练数据集合理地设置,由此能够使得所筛选的推理错误的样本数据能够更接近在第一训练数据集中稀少的样本数据。
在一些实施方式中,在每轮迭代训练中确定推理错误的样本数据时,可提取其推理错误信息并记录在推理错误信息集中,该推理错误信息包括该样本数据的索引(例如存储路径)和该轮迭代训练中的推理错误置信度(例如,假设标签类别是A,但实际推理结果为类别B,此时实际推理成类别B的置信度就是“推理错误置信度”)。可以理解,即使相同的样本数据在多轮迭代训练中均被推理错误,在推理错误信息集中仍会记录为多条推理错误信息(即分别包括各自的置信度)。
在一些实施方式中,总推理错误次数可以是推理错误信息集中累积的推理错误信息的总条数(即不管相同的样本数据是否重复推理错误)。预设次数可以视待训练模型的训练效果和第一训练数据集的分布情况而合理地设置,例如设置为第一训练数据集中的样本数据数量的80%,本公开对此不作特别限定。
S1202:在所述至少一轮的迭代训练中选择推理正确的样本数据以构成推理正确信息集,使得所选的推理正确的样本数据在标签信息和推理正确置信度方面分布基本均衡,其中所述推理正确信息集包括所选择的推理正确的样本数据的索引。
在一些实施方式中,可以在基于推理错误信息集确定截止到第M轮迭代训练的总推理错误次数超过预设次数时,在第M轮迭代训练中确定推理正确的样本数据并在从中进行选择。
在一些实施方式中,可以在第M轮迭代训练中确定推理正确的样本数据时,提取它们的推理正确信息,其可以分别包括相应样本数据的索引(例如存储路径)、标签信息和在该轮迭代训练中的推理正确置信度(例如,假设标签类别是A,实际推理结果也是类别A,此时实际推理成类别A对应的置信度就是“推理正确置信度”),然后可以为每个标签信息(例如,类别)按推理正确置信度(例如,置信度区间)均匀地选择推理正确的样本数据,并记录它们的索引以生成推理正确信息集。作为示例,对每个标签类别,可以按≥0.9、[0.8,0.9)以及[0.7,0.8)的推理正确置信度区间分别随机选取10个。
S130:对所筛选的推理错误的样本数据进行数据增强。
在一些实施方式中,步骤S130可以包括子步骤S1301:基于所述推理错误信息集中的索引获取所筛选的推理错误的样本数据。
在一些实施方式中,步骤S130还可以包括子步骤S1302:对于每个所获取的推理错误的样本数据,进行与基于在所述推理错误信息集中的推理错误重复次数和各自的推理错误置信度正相关的数据增强。作为示例,假设在所述推理错误信息集中,样本数据A记录为推理错误3次,推理错误置信度分别为0.6、0.7、0.8,样本数据B记录为推理错误1次,置信度为0.9,那么对A和B进行数据增强的次数比例保持(0.6+0.7+0.8):0.9(即7:3)。数据增强技术可以根据样本数据集的特点在已知的方式中进行选取,例如随机裁剪、随机翻转、位移、添加高斯噪声等。
在一些实施方式中,相对于子步骤S1301附加地或替代地,步骤S130还可以包括子步骤S1303:对于每个所获取的推理错误的样本数据,基于与所述推理错误置信度正相关的加权值进行数据增强。在一些实施方式中,可以按推理错误置信度区间预先设置加权值,使得越高的推理错误置信度区间被赋予越大的加权值,且对每个样本数据进行的数据增强次数设为与推理错误置信度和相应加权值的乘积成正比。作为示例,推理错误置信度>0.9时的加权值为1.5,在(0.8,0.9]区间时的加权值为1.2,在(0.7,0.8]区间时的加权值为1.1,>0.5时的加权值为1,此时假设在所述推理错误信息集中,样本数据A的推理错误置信度为0.95,样本数据B的推理错误置信度为0.85,样本数据C的推理错误置信度为0.75,样本数据D的推理错误置信度为0.65,则对它们进行数据增强的次数之比为()0.95*1.5):(0.85*1.2):(0.75*1.1):(0.65*1)。数据增强技术可以根据样本数据集的特点在已知的方式中进行选取,例如随机裁剪、随机翻转、位移、添加高斯噪声等。
S140:基于经所述数据增强的推理错误的样本数据和所筛选的推理正确的样本数据生成第二训练数据集。
在一些实施方式中,在S140之前,本公开的方法还可以包括:基于步骤S1202中的推理正确信息集中的索引获取所筛选的推理正确的样本数据。
在一些实施方式中,可以按预设比例组合经数据增强的推理错误的样本数据和所筛选的推理正确的样本数据以生成第二训练数据集。该预设比例可以视待训练模型和其任务进行合理设置,例如,可以设为1:1,本公开对此不作特别限定。
S150:在使用所述第二训练数据集进行进一步训练后,使用所述第一训练数据集继续进行迭代训练。
在一些实施方式中,基于第二训练数据集的进一步训练可以是至少一轮的训练。例如,可以使用第二训练数据集进行仅一轮的训练。又例如,也可以使用第二训练数据集进行多轮的迭代训练,但应当理解,基于第二训练数据集的迭代训练轮数不宜过大,以免反而降低训练优化效果。
在一些实施方式中,可以在第M+1轮迭代训练中使用第二训练数据集进行训练后,从第M+2轮迭代训练开始重新使用第一训练数据集继续进行迭代训练。
在一些实施方式中,可以在使用第二训练数据集进行进一步训练后,使用第一训练数据集继续进行迭代训练,直到训练完成。可以理解,训练完成意味着达到收敛,在此前提下可以指完成预设轮次的迭代训练,也可以指性能不再提升,例如精度不再提升或损失不再减少。
在一些实施方式中,训练深度学习模型的方法还可以包括:以迭代方式重复执行步骤S120至S150,直到训练完成。作为示例,在重新使用第一训练数据集继续进行迭代训练后,经过如步骤S120的样本数据筛选、如步骤S130的数据增强和如步骤S140的新训练数据集(例如,第三训练数据集)生成,如步骤S150使用新训练数据集进行进一步训练后重新使用第一训练数据集继续进行迭代训练,以此类推,直到训练完成。
图2是根据本公开的一些实施例的训练深度学习模型的方法的示意图。如图2所示,在使用第一训练数据集210开始对待训练模型进行迭代训练后,从第69轮迭代训练开始每轮确定推理错误的样本数据,由此生成和更新推理错误信息集211,直到第73轮迭代训练;然后,基于推理错误信息集211和另外筛选的推理正确的样本数据集(未图示)生成第二训练数据集220,并在第74轮迭代训练中使用第二训练数据集220进行进一步训练后重新使用第一训练数据集210继续进行迭代训练。
本公开提供的训练深度学习模型的方法由于通过筛选推理错误的样本数据和推理正确的样本数据且对所筛选的推理错误的样本数据的数据增强来组合生成新的训练数据集以用于进行补充训练,因此尤其在训练数据分布不均的情况下,不仅能够避免或减少人工筛选稀少的样本数据的成本,而且针对稀少的样本数据的数据增强效果优异,从而能够显著改善训练后模型的泛化能力。
另外,本公开提供的训练深度学习模型的方法由于在迭代训练过程中动态优化训练数据集,因此能够加快权重迭代更新以加快收敛,从而节省算力资源以提升训练效率。
作为训练深度学习模型的方法的执行装置,本公开提供的用于训练深度学习模型的装置的实施方式与前述训练深度学习模型的方法的实施方式相对应,且具体可以应用于各种电子设备。
图3是根据本公开的一些实施方式的用于训练深度学习模型的装置的示意性结构图。如图3所示,该装置300包括训练单元310、数据筛选单元320、数据增强单元330以及生成单元340,其中训练单元310用于使用第一训练数据集开始对待训练模型进行迭代训练;数据筛选单元320用于筛选推理错误的样本数据和推理正确的样本数据;数据增强单元330用于对所筛选的推理错误的样本数据进行数据增强;生成单元340用于基于经数据增强的推理错误的样本数据和所筛选的推理正确的样本数据生成第二训练数据集;训练单元310还用于使用第二训练数据集进行进一步训练,并且还用于在所述进一步训练后使用第一训练数据集继续进行迭代训练。
在一些实施方式中,数据筛选单元320可以具体用于:基于在至少一轮的迭代训练中推理错误的样本数据构成推理错误信息集,其中推理错误信息集中包括推理错误的样本数据的索引和在所述至少一轮的迭代训练中的推理错误置信度;以及在所述至少一轮的迭代训练中选择推理正确的样本数据以构成推理正确信息集,使得所选的推理正确的样本数据在标签信息和推理正确置信度方面分布基本均衡,其中推理正确信息集包括所选择的推理正确的样本数据的索引。
在一些实施方式中,数据筛选单元320可以具体用于基于推理错误信息集中的索引获取所筛选的推理错误的样本数据,且数据增强单元330可以进一步用于:对于每个所获取的推理错误的样本数据,进行与基于在推理错误信息集中的推理错误重复次数和各自的推理错误置信度正相关的数据增强;和/或对于每个所获取的推理错误的样本数据,基于与推理错误置信度正相关的加权值进行数据增强。
在一些实施方式中,装置300还可以包括获取单元,其用于基于推理正确信息集中的索引获取所筛选的推理正确的样本数据。
在一些实施方式中,训练单元310可以具体用于使用第二训练数据集进行至少一轮的训练。
在一些实施方式中,训练单元310可以具体用于在所述进一步训练后使用第一训练数据集继续进行迭代训练,直到训练完成。
在一些实施方式中,训练单元310、数据筛选单元320、数据增强单元330以及生成单元340可以具体用于以迭代方式重复执行训练优化操作,直到训练完成,其中训练优化操作包括:数据筛选单元320筛选推理错误的样本数据和推理正确的样本数据;数据增强单元330对所筛选的推理错误的样本数据进行数据增强;生成单元340基于经数据增强的推理错误的样本数据和所筛选的推理正确的样本数据生成优化训练数据集;以及训练单元310在使用优化训练数据集进行进一步训练后使用第一训练数据集继续进行迭代训练。
在一些实施方式中,数据筛选单元320可以具体用于:在确定第N轮迭代训练的结果满足预设条件时,从第N+1轮迭代训练开始确定每轮迭代训练中推理错误的样本数据以用于构成推理错误信息集,直到基于推理错误信息集确定截止到第M轮迭代训练的总推理错误次数超过预设次数,其中N是不小于1的正整数,M是大于N的正整数。
在一些实施方式中,数据筛选单元320可以具体用于:在基于推理错误信息集确定截止到第M轮迭代训练的总推理错误次数超过预设次数时,在第M轮迭代训练中确定推理正确的样本数据并在从中进行选择。
上述装置300的实施方式与前述方法的实施方式相对应。由此,上文针对方法描述的操作、特征及所能达到的技术效果同样适用于装置300及其中包含的单元,在此不再赘述。
本公开还提供一种电子设备和一种可读存储介质。
图4是根据本公开的一些实施方式的电子设备的示意性结构图。电子设备400旨在表示各种形式的数字计算机,诸如膝上型计算机、台式计算机、工作台、个人数字助理、服务器、刀片式服务器、大型计算机、和其它适合的计算机。电子设备400还可以表示各种形式的移动装置,诸如,个人数字处理、蜂窝电话、智能电话、可穿戴设备和其它类似的计算装置。本文所示的部件、它们的连接和关系、以及它们的功能仅仅作为示例,并且不意在限制本文中描述的实施方式。
如图4所示,该电子设备400包括:一个或多个处理器410、存储器420,以及用于连接各部件的接口,包括高速接口和低速接口。各个部件利用不同的总线互相连接,并且可以被安装在公共主板上或者根据需要以其它方式安装。处理器可以对在电子设备内执行的指令进行处理,包括存储在存储器中或者存储器上以在外部输入/输出装置(诸如,耦合至接口的显示设备)上显示GUI的图形信息的指令。在其它实施方式中,若需要,可以将多个处理器和/或多条总线与多个存储器和多个存储器一起使用。同样,可以连接多个电子设备,各个设备提供部分必要的操作(例如,作为服务器阵列、一组刀片式服务器、或者多处理器系统)。图4中以一个处理器410为例。
存储器420即为本公开所提供的非瞬时计算机可读存储介质。其中,存储器存储有可由至少一个处理器执行的指令,以使至少一个处理器执行本公开所提供的模型训练方法。本公开的非瞬时计算机可读存储介质存储计算机指令,该计算机指令用于使计算机执行本公开所提供的模型训练方法。
存储器420作为一种非瞬时计算机可读存储介质,可用于存储非瞬时软件程序、非瞬时计算机可执行程序以及模块,如本公开实施方式中的模型训练方法对应的程序指令/模块(例如,附图3所示的训练单元310、数据筛选单元320、数据增强单元330、生成单元340)。处理器410通过运行存储在存储器420中的非瞬时软件程序、指令以及模块,从而执行服务器的各种功能应用以及数据处理,即实现上述方法实施方式中的模型训练方法。
存储器420可以包括存储程序区和存储数据区,其中,存储程序区可存储操作系统、至少一个功能所需要的应用程序;存储数据区可存储根据模型训练的电子设备的使用所创建的数据等。此外,存储器420可以包括高速随机存取存储器,还可以包括非瞬时存储器,例如至少一个磁盘存储器件、闪存器件、或其他非瞬时固态存储器件。在一些实施方式中,存储器420可选包括相对于处理器410远程设置的存储器,这些远程存储器可以通过网络连接至模型训练的电子设备。上述网络的实例包括但不限于互联网、企业内部网、局域网、移动通信网及其组合。
模型训练方法的电子设备还可以包括:输入装置430和输出装置440。处理器410、存储器420、输入装置430和输出装置440可以通过总线或者其他方式连接,图4中以通过总线450连接为例。
输入装置430可接收输入的数字或字符信息,以及产生与模型优化的电子设备的用户设置以及功能控制有关的键信号输入,例如触摸屏、小键盘、鼠标、轨迹板、触摸板、指示杆、一个或者多个鼠标按钮、轨迹球、操纵杆等输入装置。输出装置440可以包括显示设备、辅助照明装置(例如,LED)和触觉反馈装置(例如,振动电机)等。该显示设备可以包括但不限于,液晶显示器(LCD)、发光二极管(LED)显示器和等离子体显示器。在一些实施方式中,显示设备可以是触摸屏。
此处描述的系统和技术的各种实施方式可以在数字电子电路系统、集成电路系统、专用ASIC(专用集成电路)、计算机硬件、固件、软件、和/或它们的组合中实现。这些各种实施方式可以包括:实施在一个或者多个计算机程序中,该一个或者多个计算机程序可在包括至少一个可编程处理器的可编程系统上执行和/或解释,该可编程处理器可以是专用或者通用可编程处理器,可以从存储系统、至少一个输入装置、和至少一个输出装置接收数据和指令,并且将数据和指令传输至该存储系统、该至少一个输入装置、和该至少一个输出装置。
这些计算程序(也称作程序、软件、软件应用、或者代码)包括可编程处理器的机器指令,并且可以利用高级过程和/或面向对象的编程语言、和/或汇编/机器语言来实施这些计算程序。如本文使用的,术语“机器可读介质”和“计算机可读介质”指的是用于将机器指令和/或数据提供给可编程处理器的任何计算机程序产品、设备、和/或装置(例如,磁盘、光盘、存储器、可编程逻辑装置(PLD)),包括,接收作为机器可读信号的机器指令的机器可读介质。术语“机器可读信号”指的是用于将机器指令和/或数据提供给可编程处理器的任何信号。
为了提供与用户的交互,可以在计算机上实施此处描述的系统和技术,该计算机具有:用于向用户显示信息的显示装置(例如,CRT(阴极射线管)或者LCD(液晶显示器)监视器);以及键盘和指向装置(例如,鼠标或者轨迹球),用户可以通过该键盘和该指向装置来将输入提供给计算机。其它种类的装置还可以用于提供与用户的交互;例如,提供给用户的反馈可以是任何形式的传感反馈(例如,视觉反馈、听觉反馈、或者触觉反馈);并且可以用任何形式(包括声输入、语音输入或者、触觉输入)来接收来自用户的输入。
可以将此处描述的系统和技术实施在包括后台部件的计算系统(例如,作为数据服务器)、或者包括中间件部件的计算系统(例如,应用服务器)、或者包括前端部件的计算系统(例如,具有图形用户界面或者网络浏览器的用户计算机,用户可以通过该图形用户界面或者该网络浏览器来与此处描述的系统和技术的实施方式交互)、或者包括这种后台部件、中间件部件、或者前端部件的任何组合的计算系统中。可以通过任何形式或者介质的数字数据通信(例如,通信网络)来将系统的部件相互连接。通信网络的示例包括:局域网(LAN)、广域网(WAN)和互联网。
计算机系统可以包括客户端和服务器。客户端可以是智能手机、平板电脑、笔记本电脑、台式计算机、智能音箱、智能手表等,但并不局限于此。服务器可以是独立的物理服务器,也可以是多个物理服务器构成的服务器集群或者分布式系统,还可以是提供云计算、云服务、云数据库、云存储等基础云计算服务的云服务器。客户端和服务器一般远离彼此并且通常通过通信网络进行交互。通过在相应的计算机上运行并且彼此具有客户端-服务器关系的计算机程序来产生客户端和服务器的关系。
根据本公开的技术方案能够在深度学习模型的训练过程中基于推理错误的样本数据进行动态调整生成新的训练数据集,从而优化深度学习模型的训练过程,进而改善深度学习模型的训练效率和训练后的预测性能。。
应该理解,可以使用上面所示的各种形式的流程,重新排序、增加或删除步骤。例如,本公开中记载的各步骤可以并行地执行也可以顺序地执行也可以不同的次序执行,只要能够实现本公开的技术方案所期望的结果,本文在此不进行限制。
上述具体实施方式,并不构成对本公开保护范围的限制。本领域普通技术人员应当明白,根据设计要求和其他因素,可以进行多种修改、组合、子组合和替代。任何在本公开的构思和原则之内所作的修改、等同替换和改进等,均应包含在本公开保护范围之内。
Claims (12)
1.一种训练深度学习模型的方法,包括:
使用第一训练数据集开始对待训练模型进行迭代训练;
筛选推理错误的样本数据和推理正确的样本数据;
对所筛选的推理错误的样本数据进行数据增强;
基于经所述数据增强的推理错误的样本数据和所筛选的推理正确的样本数据生成第二训练数据集;以及
在使用所述第二训练数据集进行进一步训练后,使用所述第一训练数据集继续进行迭代训练。
2.根据权利要求1所述的方法,其中所述筛选推理错误的样本数据和推理正确的样本数据包括:
基于在至少一轮的迭代训练中推理错误的样本数据构成推理错误信息集,其中所述推理错误信息集中包括所述推理错误的样本数据的索引和在所述至少一轮的迭代训练中的推理错误置信度;以及
在所述至少一轮的迭代训练中选择推理正确的样本数据以构成推理正确信息集,使得所选的推理正确的样本数据在标签信息和推理正确置信度方面分布基本均衡,其中所述推理正确信息集包括所选择的推理正确的样本数据的索引。
3.根据权利要求2所述的方法,其中所述对所筛选的推理错误的样本数据进行数据增强包括:基于所述推理错误信息集中的索引获取所筛选的推理错误的样本数据,
其中所述对所筛选的推理错误的样本数据进行数据增强还包括:
对于每个所获取的推理错误的样本数据,进行与基于在所述推理错误信息集中的推理错误重复次数和各自的推理错误置信度正相关的数据增强;和/或
对于每个所获取的推理错误的样本数据,基于与所述推理错误置信度正相关的加权值进行数据增强。
4.根据权利要求2所述的方法,其中在基于经所述数据增强的推理错误的样本数据和所筛选的推理正确的样本数据生成第二训练数据集之前,所述方法还包括:
基于所述推理正确信息集中的索引获取所筛选的推理正确的样本数据。
5.根据权利要求1所述的方法,其中使用所述第二训练数据集进行至少一轮的训练。
6.根据权利要求1所述的方法,其中在使用所述第二训练数据集进行进一步训练后,使用所述第一训练数据集继续进行迭代训练,直到训练完成。
7.根据权利要求1所述的方法,还包括:以迭代方式重复执行训练优化操作,直到训练完成,其中所述训练优化操作包括:
筛选推理错误的样本数据和推理正确的样本数据;
对所筛选的推理错误的样本数据进行数据增强;
基于经所述数据增强的推理错误的样本数据和所筛选的推理正确的样本数据生成优化训练数据集;以及
在使用所述优化训练数据集进行进一步训练后,使用所述第一训练数据集继续进行迭代训练。
8.根据权利要求2所述的方法,其中所述基于在至少一轮的迭代训练中推理错误的样本数据构成推理错误信息集包括:
在确定第N轮迭代训练的结果满足预设条件时,从第N+1轮迭代训练开始确定每轮迭代训练中推理错误的样本数据以用于构成所述推理错误信息集,直到基于所述推理错误信息集确定截止到第M轮迭代训练的总推理错误次数超过预设次数,其中N是不小于1的正整数,M是大于N的正整数。
9.根据权利要求8所述的方法,其中所述在所述至少一轮的迭代训练中选择推理正确的样本数据以构成推理正确信息集包括:
在基于所述推理错误信息集确定截止到第M轮迭代训练的总推理错误次数超过预设次数时,在第M轮迭代训练中确定推理正确的样本数据并在从中进行选择。
10.一种用于训练深度学习模型的装置,包括训练单元、数据筛选单元、数据增强单元以及生成单元,其中:
所述训练单元用于使用第一训练数据集开始对待训练模型进行迭代训练;
所述数据筛选单元用于筛选推理错误的样本数据和推理正确的样本数据;
所述数据增强单元用于对所筛选的推理错误的样本数据进行数据增强;
所述生成单元用于基于经所述数据增强的推理错误的样本数据和所筛选的推理正确的样本数据生成第二训练数据集;
所述训练单元还用于使用所述第二训练数据集进行进一步训练,并且还用于在所述进一步训练后使用所述第一训练数据集继续进行迭代训练。
11.一种电子设备,包括:至少一个处理器;以及与所述至少一个处理器通信连接的存储器,其中所述存储器存储有可被所述至少一个处理器执行的指令,所述指令在被所述至少一个处理器执行时使得所述电子设备执行如权利要求1至9中任一项所述的训练深度学习模型的方法。
12.一种存储有计算机指令的非瞬时计算机可读存储介质,其中所述计算机指令用于使计算机执行如权利要求1至9中任一项所述的训练深度学习模型的方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202111683696.9A CN114492788A (zh) | 2021-12-31 | 2021-12-31 | 训练深度学习模型的方法和装置、电子设备及存储介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202111683696.9A CN114492788A (zh) | 2021-12-31 | 2021-12-31 | 训练深度学习模型的方法和装置、电子设备及存储介质 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN114492788A true CN114492788A (zh) | 2022-05-13 |
Family
ID=81509575
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202111683696.9A Pending CN114492788A (zh) | 2021-12-31 | 2021-12-31 | 训练深度学习模型的方法和装置、电子设备及存储介质 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN114492788A (zh) |
Cited By (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN115031363A (zh) * | 2022-05-27 | 2022-09-09 | 约克广州空调冷冻设备有限公司 | 预测空调性能的方法和装置 |
CN115391450A (zh) * | 2022-08-26 | 2022-11-25 | 百度在线网络技术(北京)有限公司 | 推理信息生成方法、装置、设备、可读存储介质及产品 |
CN116010669A (zh) * | 2023-01-18 | 2023-04-25 | 深存科技(无锡)有限公司 | 向量库重训练的触发方法、装置、检索服务器及存储介质 |
-
2021
- 2021-12-31 CN CN202111683696.9A patent/CN114492788A/zh active Pending
Cited By (6)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN115031363A (zh) * | 2022-05-27 | 2022-09-09 | 约克广州空调冷冻设备有限公司 | 预测空调性能的方法和装置 |
CN115031363B (zh) * | 2022-05-27 | 2023-11-28 | 约克广州空调冷冻设备有限公司 | 预测空调性能的方法和装置 |
CN115391450A (zh) * | 2022-08-26 | 2022-11-25 | 百度在线网络技术(北京)有限公司 | 推理信息生成方法、装置、设备、可读存储介质及产品 |
CN115391450B (zh) * | 2022-08-26 | 2024-01-09 | 百度在线网络技术(北京)有限公司 | 推理信息生成方法、装置、设备、可读存储介质及产品 |
CN116010669A (zh) * | 2023-01-18 | 2023-04-25 | 深存科技(无锡)有限公司 | 向量库重训练的触发方法、装置、检索服务器及存储介质 |
CN116010669B (zh) * | 2023-01-18 | 2023-12-08 | 深存科技(无锡)有限公司 | 向量库重训练的触发方法、装置、检索服务器及存储介质 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN111539223B (zh) | 语言模型的训练方法、装置、电子设备及可读存储介质 | |
CN111737994B (zh) | 基于语言模型获取词向量的方法、装置、设备及存储介质 | |
US20210201198A1 (en) | Method, electronic device, and storage medium for generating node representations in heterogeneous graph | |
US11928432B2 (en) | Multi-modal pre-training model acquisition method, electronic device and storage medium | |
CN111737995B (zh) | 基于多种词向量训练语言模型的方法、装置、设备及介质 | |
US11526668B2 (en) | Method and apparatus for obtaining word vectors based on language model, device and storage medium | |
US20210166136A1 (en) | Method, apparatus, electronic device and storage medium for obtaining question-answer reading comprehension model | |
CN111967256B (zh) | 事件关系的生成方法、装置、电子设备和存储介质 | |
CN111708922A (zh) | 用于表示异构图节点的模型生成方法及装置 | |
CN114492788A (zh) | 训练深度学习模型的方法和装置、电子设备及存储介质 | |
CN111667056B (zh) | 用于搜索模型结构的方法和装置 | |
CN111709252B (zh) | 基于预训练的语义模型的模型改进方法及装置 | |
CN111860769A (zh) | 预训练图神经网络的方法以及装置 | |
CN111738419B (zh) | 神经网络模型的量化方法和装置 | |
CN113723278B (zh) | 表格信息提取模型的训练方法及装置 | |
CN111931520B (zh) | 自然语言处理模型的训练方法和装置 | |
CN111127191B (zh) | 风险评估方法及装置 | |
US20220300763A1 (en) | Method, apparatus, electronic device and storage medium for training semantic similarity model | |
CN112560499B (zh) | 语义表示模型的预训练方法、装置、电子设备及存储介质 | |
CN111783949A (zh) | 基于迁移学习的深度神经网络的训练方法和装置 | |
KR20220003444A (ko) | 옵티마이저 학습 방법, 장치, 전자 기기 및 판독 가능 기록 매체 | |
CN111667428A (zh) | 基于自动搜索的噪声生成方法和装置 | |
CN111783872B (zh) | 训练模型的方法、装置、电子设备及计算机可读存储介质 | |
CN111160552B (zh) | 新闻信息的推荐处理方法、装置、设备和计算机存储介质 | |
CN111914882A (zh) | 支持向量机的生成方法、装置、电子设备及存储介质 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |