CN114418130B - 一种模型训练方法、数据处理方法及相关设备 - Google Patents

一种模型训练方法、数据处理方法及相关设备 Download PDF

Info

Publication number
CN114418130B
CN114418130B CN202210321507.1A CN202210321507A CN114418130B CN 114418130 B CN114418130 B CN 114418130B CN 202210321507 A CN202210321507 A CN 202210321507A CN 114418130 B CN114418130 B CN 114418130B
Authority
CN
China
Prior art keywords
data
target data
output value
inputting
data processing
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN202210321507.1A
Other languages
English (en)
Other versions
CN114418130A (zh
Inventor
刘�东
马海川
吴枫
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
University of Science and Technology of China USTC
Original Assignee
University of Science and Technology of China USTC
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by University of Science and Technology of China USTC filed Critical University of Science and Technology of China USTC
Priority to CN202210321507.1A priority Critical patent/CN114418130B/zh
Publication of CN114418130A publication Critical patent/CN114418130A/zh
Application granted granted Critical
Publication of CN114418130B publication Critical patent/CN114418130B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N20/00Machine learning

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Software Systems (AREA)
  • Data Mining & Analysis (AREA)
  • Evolutionary Computation (AREA)
  • Medical Informatics (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Physics & Mathematics (AREA)
  • Computing Systems (AREA)
  • General Engineering & Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Mathematical Physics (AREA)
  • Artificial Intelligence (AREA)
  • Electrically Operated Instructional Devices (AREA)

Abstract

本发明实施例提供了一种模型训练方法、数据处理方法及相关设备。其中,该方法用于训练机器学习模型,机器学习模型包括生成器和判别器,该方法包括:将第一目标数据输入生成器中进行第一数据处理,获得第二目标数据,将第二目标数据输入判别器中,得到第一输出值;第一目标数据是对第三目标数据进行第二数据处理后得到的数据;第一数据处理为第二数据处理的反向处理过程;第一目标数据和第三目标数据为初始训练数据集中的数据;将第三目标数据输入判别器中,得到第二输出值;若第二输出值大于第一输出值,则将第一目标数据和第三目标数据放入当前训练数据集中;使用当前训练数据集对机器学习模型进行训练。本发明可以提高模型训练效果。

Description

一种模型训练方法、数据处理方法及相关设备
技术领域
本发明涉及计算机处理技术领域,特别是涉及一种模型训练方法、数据处理方法及相关设备。
背景技术
在对机器学习模型(如生成对抗网络模型)进行训练时,通常会从训练数据集中随机选取一些数据对该模型进行训练,这样会导致模型训练效果不佳。
发明内容
本发明实施例的目的在于提供一种模型训练方法、数据处理方法及相关设备,能够提高模型训练效果。具体技术方案如下:
本发明提供了一种模型训练方法,所述模型训练方法用于训练机器学习模型,所述机器学习模型包括生成器和判别器,所述模型训练方法包括:
将第一目标数据输入所述生成器中进行第一数据处理,获得第二目标数据,将所述第二目标数据输入所述判别器中,得到第一输出值;所述第一目标数据是对第三目标数据进行第二数据处理后得到的数据;所述第一数据处理为所述第二数据处理的反向处理过程;所述第一目标数据和所述第三目标数据为初始训练数据集中的数据;
将所述第三目标数据输入所述判别器中,得到第二输出值;
若所述第二输出值大于所述第一输出值,则将所述第一目标数据和第三目标数据放入当前训练数据集中;
使用所述当前训练数据集对所述机器学习模型进行训练。
可选地,所述使用所述当前训练数据集对所述机器学习模型进行训练,具体包括:
从所述当前训练数据集中选取第四目标数据和第五目标数据;所述第四目标数据是对所述第五目标数据进行所述第二数据处理后得到的数据;
将所述第四目标数据输入所述生成器中进行第一数据处理,获得第六目标数据,将所述第六目标数据输入所述判别器中,得到第三输出值;
将所述第五目标数据输入所述判别器中,得到第四输出值;
将所述第三输出值和所述第四输出值输入生成器损失函数中,得到第一损失值;
基于所述第一损失值对所述生成器的参数进行更新;
判断生成器参数更新次数是否达到第一预设次数;
若达到所述第一预设次数,则结束训练过程;
若未达到所述第一预设次数,则返回步骤“将第一目标数据输入所述生成器中进行第一数据处理”。
可选地,在将第一目标数据输入所述生成器中进行第一数据处理之前,所述模型训练方法还包括:
将第七目标数据输入所述生成器中进行所述第一数据处理,获得第八目标数据,将所述第八目标数据输入所述判别器中,得到第五输出值;所述第七目标数据是对第九目标数据进行所述第二数据处理后得到的数据;所述第七目标数据和所述第九目标数据为所述初始训练数据集中的数据;
将所述第九目标数据输入所述判别器中,得到第六输出值;
将所述第五输出值和所述第六输出值输入判别器损失函数中,得到第二损失值;
基于所述第二损失值对所述判别器的参数进行更新。
可选地,在所述基于所述第二损失值对所述判别器的参数进行更新之后,所述方法还包括:
判断判别器参数更新次数是否达到第二预设次数;
若达到所述第二预设次数,则返回步骤“将第一目标数据输入所述生成器中进行第一数据处理”;
若未达到所述第二预设次数,则更换所述第七目标数据,然后返回步骤“将第七目标数据输入所述生成器中进行所述第一数据处理”。
可选地,所述将所述第三输出值和所述第四输出值输入生成器损失函数中,得到第一损失值,具体包括:采用方式一或方式二得到所述第一损失值;
方式一:将所述第三输出值和所述第四输出值输入基于JS散度获得的生成器损失函数中,得到第一损失值;
方式二:将所述第三输出值和所述第四输出值输入基于Wasserstein距离获得的生成器损失函数中,得到第一损失值;
其中,
基于JS散度获得的生成器损失函数为:
Figure 100002_DEST_PATH_IMAGE002
式中,
Figure 100002_DEST_PATH_IMAGE004
为基于JS散度获得的生成器损失函数,
Figure 100002_DEST_PATH_IMAGE006
为第四目标数据,
Figure 100002_DEST_PATH_IMAGE008
为第六目标数据,
Figure 100002_DEST_PATH_IMAGE010
为第三输出值,
Figure 100002_DEST_PATH_IMAGE012
为第五目标数据,
Figure 100002_DEST_PATH_IMAGE014
为第四输出值,
Figure 100002_DEST_PATH_IMAGE016
为第四目标数据服从第一分布
Figure 100002_DEST_PATH_IMAGE018
的期望,第一分布
Figure 984936DEST_PATH_IMAGE018
为第四目标数据的概率分布,
Figure 100002_DEST_PATH_IMAGE020
为第五目标数据服从第二分布
Figure 100002_DEST_PATH_IMAGE022
的期望,第二分布
Figure 777443DEST_PATH_IMAGE022
为第五目标数据的概率分布;
基于Wasserstein距离获得的生成器损失函数为:
Figure 100002_DEST_PATH_IMAGE024
式中,
Figure 100002_DEST_PATH_IMAGE026
为基于Wasserstein距离获得的生成器损失函数。
可选地,所述将所述第五输出值和所述第六输出值输入判别器损失函数中,得到第二损失值,具体包括:采用方式三或方式四得到所述第二损失值;
方式三:将所述第五输出值和所述第六输出值输入基于JS散度获得的判别器损失函数中,得到第二损失值;
方式四:将所述第五输出值和所述第六输出值输入基于Wasserstein距离获得的判别器损失函数中,得到第二损失值;
其中,
基于JS散度获得的判别器损失函数为:
Figure 100002_DEST_PATH_IMAGE028
式中,
Figure 100002_DEST_PATH_IMAGE030
为基于JS散度获得的判别器损失函数,
Figure 100002_DEST_PATH_IMAGE032
为第七目标数据,
Figure 100002_DEST_PATH_IMAGE034
为第八目 标数据,
Figure 100002_DEST_PATH_IMAGE036
为第五输出值,
Figure 100002_DEST_PATH_IMAGE038
为第九目标数据,
Figure 100002_DEST_PATH_IMAGE040
为第六输出值,
Figure 100002_DEST_PATH_IMAGE042
为第七目标 数据服从第三分布
Figure 100002_DEST_PATH_IMAGE044
的期望,第三分布
Figure DEST_PATH_IMAGE045
为第七目标数据的概率分布,
Figure DEST_PATH_IMAGE047
为第九目 标数据服从第四分布
Figure DEST_PATH_IMAGE049
的期望,第四分布
Figure 793065DEST_PATH_IMAGE049
为第九目标数据的概率分布;
基于Wasserstein距离获得的判别器损失函数为:
Figure DEST_PATH_IMAGE051
式中,
Figure DEST_PATH_IMAGE053
为基于Wasserstein距离获得的判别器损失函数。
本发明还提供一种数据处理方法,包括:
将待处理数据输入数据处理模型中,得到处理后的数据;其中,所述数据处理模型是通过上述的模型训练方法获得的。
本发明还提供一种模型训练系统,所述系统用于训练机器学习模型,所述机器学习模型包括生成器和判别器,所述系统包括:
第一输入模块,用于将第一目标数据输入所述生成器中进行第一数据处理,获得第二目标数据,将所述第二目标数据输入所述判别器中,得到第一输出值;所述第一目标数据是对第三目标数据进行第二数据处理后得到的数据;所述第一数据处理为所述第二数据处理的反向处理过程;所述第一目标数据和所述第三目标数据为初始训练数据集中的数据;
第二输入模块,用于将所述第三目标数据输入所述判别器中,得到第二输出值;
判断模块,用于在所述第二输出值大于所述第一输出值时,将所述第一目标数据和第三目标数据放入当前训练数据集中;
训练模块,用于使用所述当前训练数据集对所述机器学习模型进行训练。
本发明还提供一种计算机可读存储介质,所述计算机可读存储介质上存储有程序,所述程序被处理器执行时实现上述的模型训练方法和/或上述的数据处理方法。
本发明还提供一种电子设备,包括:
至少一个处理器、以及与所述处理器连接的至少一个存储器、总线;
所述处理器、所述存储器通过所述总线完成相互间的通信;所述处理器用于调用所述存储器中的程序指令,以执行上述的模型训练方法和/或上述的数据处理方法。
本发明实施例提供的一种模型训练方法、数据处理方法及相关设备,由于训练模型的目的是希望生成器输出期望数据,本发明中,第一目标数据通过生成器进行第一数据处理得到第二目标数据,说明第一目标数据并非期望数据;由于第三目标数据可以通过第二数据处理得到第一目标数据,而第一数据处理为第二数据处理的反向处理过程,因此说明第三目标数据为期望数据。当第三目标数据(即期望数据)输入判别器后,得到的第二输出值应当大于第一目标数据(即非期望数据)输入模型中得到的第一输出值。若第二输出值小于第一输出值,则说明输入模型的第一目标数据会导致模型输出结果出错,此时若使用这些数据对模型训练会影响模型参数的生成,进而导致训练效果不佳。因此,本发明以第二输出值大于第一输出值这一条件作为对初始训练数据集中数据的筛选条件,将满足条件的第一目标数据和第三目标数据放入当前训练数据集中,从而使用当前训练数据集对模型进行训练,可以提高模型训练效果。
当然,实施本发明的任一产品或方法必不一定需要同时达到以上所述的所有优点。
附图说明
为了更清楚地说明本发明实施例或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1为本发明实施例提供的模型训练方法流程图;
图2为本发明实施例提供的图像恢复模型训练方法流程图;
图3为本发明实施例提供的模型训练系统结构图;
图4为本发明实施例提供的一种电子设备的结构示意图。
具体实施方式
下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
本发明提供一种模型训练方法,该模型训练方法用于训练机器学习模型,机器学习模型包括生成器和判别器,如图1所示,该模型训练方法包括:
步骤101:将第一目标数据输入生成器中进行第一数据处理,获得第二目标数据,将第二目标数据输入判别器中,得到第一输出值;第一目标数据是对第三目标数据进行第二数据处理后得到的数据;第一数据处理为第二数据处理的反向处理过程;第一目标数据和第三目标数据为初始训练数据集中的数据。
步骤102:将第三目标数据输入判别器中,得到第二输出值。
步骤103:若第二输出值大于第一输出值,则将第一目标数据和第三目标数据放入当前训练数据集中。
步骤104:使用当前训练数据集对机器学习模型进行训练。
机器学习模型包括生成器和判别器,其中,生成器的作用是尽量生成期望数据,而判别器的作用是判断生成器生成的数据是否为期望数据,判别器输出的是生成器生成数据为期望数据的概率。
在本实施例中,将第一目标数据输入生成器后得到第二目标数据,该第一目标数据可以为非期望数据,在生成器中进行的第一数据处理即是生成器为了输出期望数据而进行的数据处理过程,而该第二目标数据则是生成器生成的与期望数据具有一定相似性的数据。
将第二目标数据输入判别器后得到第一输出值,该第一输出值为判别器得到的第二目标数据为期望数据的概率,第一输出值越大说明生成器得到的第二目标数据越接近期望数据。
在生成器中对第一目标数据采用第一数据处理是为了得到期望数据,而第三目标数据采用第二数据处理可以得到第一目标数据,由于第一数据处理为第二数据处理的反向处理过程,因此,第三目标数据可以为期望数据,将第一目标数据输入生成器后得到的第二目标数据与第三目标数据相似度越高,说明生成器训练效果越好。
将第三目标数据输入判别器后得到第二输出值,该第二输出值为判别器得到的第三目标数据为期望数据的概率,由于第三目标数据可以为期望数据,因此,理论上来说,判别器基于第三目标数据得到的第二输出值会大于该判别器基于第二目标数据得到的第一输出值。
若第二输出值小于第一输出值,则说明基于第一目标数据得到的判别器输出结果出现错误,如果将第一目标数据投入当前训练数据集以对机器学习模型进行训练,则会影响模型参数的训练结果,进而导致训练效果不佳。基于此,本发明在第二输出值大于第一输出值的情况下,才会将第一目标数据和第三目标数据放入当前训练数据集中,以便使用当前训练数据集对机器学习模型进行训练。因此,本发明通过对初始训练集中数据进行筛选,有利于提升模型训练效果。
目标数据可以包括图像数据、音频数据、视频数据、文本数据、尺寸数据、用户信息数据中的至少一种。数据处理可以包括:预设图像处理、预设音频处理、预设视频处理、预设文本处理、预设尺寸处理、预设用户信息处理中的至少一种。
可选的,第一图像数据可以包括降质图像,第一预设图像处理可以包括图像恢复处理,第二图像数据可以包括恢复图像,第三图像数据可以包括原始图像,第二预设图像处理可以包括图像降质处理。
可选的,第一文本数据可以包括目标文本,第一预设文本处理可以包括文本筛选处理,第二文本数据可以包括筛选后的文本,第三文本数据可以包括感兴趣文本,第二预设文本处理可以包括文本扩充处理。
在一可选的实施方式中,该模型训练方法可应用于图像恢复模型中。目前在存储图像时,为了减小图像的存储空间,会对原始图像进行压缩,压缩后的图像会导致图像质量降低。为了提高图像质量(如图像分辨率),使降质图像尽可能恢复到原始图像,通常会利用训练好的图像恢复模型对降质图像进行图像恢复。然而,在训练图像恢复模型时,如果从降质图像集合中随机选取降质图像对模型进行训练,这样会导致图像恢复效果不佳。基于此,本发明提供一种可以提高图像恢复效果的图像恢复模型训练方法,具体方案如下。
在进行图像恢复模型训练时,第一目标数据可以为降质图像,第一数据处理可以为图像质量恢复处理,第二目标数据可以为恢复图像,也就是说,将降质图像输入生成器中进行图像质量恢复处理,获得恢复图像。而第三目标数据可以为原始图像,第二数据处理可以为图像降质处理(如图像压缩),该原始图像的图像质量(如图像分辨率)会高于上述的降质图像,而该降质图像可以是该原始图像经过图像降质处理得到的。
基于此,本发明提供一种图像恢复模型训练方法,如图2所示,该方法用于训练图像恢复模型,该图像恢复模型包括生成器和判别器,该方法包括如下步骤:
步骤201:将降质图像输入生成器中进行图像质量恢复处理,获得恢复图像,将恢复图像输入判别器中,得到第一数值;其中,降质图像是对原始图像进行图像降质处理后得到的,降质图像和原始图像为初始训练图像集中的图像。
步骤202:将原始图像输入判别器中,得到第二数值。
步骤203:若第二数值大于第一数值,则将降质图像和原始图像放入当前训练图像集中。
步骤204:使用当前训练图像集对图像恢复模型进行训练。
将恢复图像输入判别器后得到的第一数值可以是恢复图像为原始图像的概率,概率越大说明恢复图像与原始图像在图像质量方面越接近,当然,该概率值也可以用分数值进行表示,分数值越高,两个图像的图像质量越接近。
由于恢复图像的图像质量理论上不会超过原始图像,因此,将原始图像输入判别器后得到的第二数值应大于将恢复图像输入判别器后得到的第一数值。若第二数值小于第一数值,则说明将降质图像输入图像恢复模型后得到的第一数值出现错误,如果将该降质图像用于训练图像恢复模型,则会对图像恢复模型的参数生成产生一定影响,可能导致训练出的图像恢复模型效果不佳,进而会影响图像恢复质量。因此,本发明在第二数值大于第一数值的条件下,将降质图像和原始图像放入当前训练图像集中,以便使用当前训练图像集对图像恢复模型进行训练。
在另一可选的实施方式中,该模型训练方法可用于文本筛选模型中。在进行文本处理时,第一目标数据可以为目标文本(如法律文书),该目标文本中含有感兴趣文本(如案由),第一数据处理可以为感兴趣文本筛选处理,第二目标数据可以为筛选后的文本,第三目标数据可以为感兴趣文本,第二数据处理可以为文本扩充处理(即将仅含有感兴趣文本扩充为除含感兴趣文本以外还有其他文本内容的文本,如将案由扩充为法律文书)。将筛选后的文本输入判别器后得到第三数值,将感兴趣文本输入判别器中得到第四数值,由于筛选后的文本的准确度理论上不会超过感兴趣文本,因此,第四数值理论上应大于第三数值。若第四数值小于第三数值,则说明将目标文本输入文本筛选模型后得到的第三数值出现错误,如果将该目标文本用于训练文本筛选模型,则会对文本筛选模型的参数生成产生一定影响,可能导致训练出的文本筛选模型效果不佳,进而影响文本筛选效果。因此,本发明在第四数值大于第三数值的条件下,将目标文本和感兴趣文本放入当前训练文本集中,以便使用当前训练文本集对文本筛选模型进行训练。
此外,本发明提供的模型训练方法还可以应用于音频处理模型、视频处理模型、尺寸数据处理模型、用户信息处理模型等中。
在本实施例中,步骤104:使用当前训练数据集对机器学习模型进行训练,具体包括:
从当前训练数据集中选取第四目标数据和第五目标数据;第四目标数据是对第五目标数据进行第二数据处理后得到的数据;将第四目标数据输入生成器中进行第一数据处理,获得第六目标数据,将第六目标数据输入判别器中,得到第三输出值;将第五目标数据输入判别器中,得到第四输出值;将第三输出值和第四输出值输入生成器损失函数中,得到第一损失值;基于第一损失值对生成器的参数进行更新;判断生成器参数更新次数是否达到第一预设次数;若达到第一预设次数,则结束训练过程;若未达到第一预设次数,则返回步骤101。
上述步骤是对生成器进行训练,通过生成器参数更新次数判断是否可以结束模型训练过程,如果可以结束训练,则将本次的生成器参数作为模型训练后的结果,可选地,例如该本次生成器参数可以对待处理数据进行第一数据处理。如果次数未达到第一预设次数,则返回步骤101继续模型的训练。
可选地,基于第一损失值对生成器的参数进行更新,具体包括:计算第一损失值对生成器的参数的梯度,得到第一梯度,基于第一梯度对生成器的参数进行更新。
本发明的模型训练方法还包括:将第七目标数据输入生成器中进行第一数据处理,获得第八目标数据,将第八目标数据输入判别器中,得到第五输出值;第七目标数据是对第九目标数据进行第二数据处理后得到的数据;第七目标数据和第九目标数据为初始训练数据集中的数据;将第九目标数据输入判别器中,得到第六输出值;将第五输出值和第六输出值输入判别器损失函数中,得到第二损失值;基于第二损失值对判别器的参数进行更新。
上述步骤是对判别器进行训练,对判别器进行训练的步骤可以在步骤101之前进行,也可以在执行完步骤104之后,返回步骤101之前进行。
可选地,基于第二损失值对判别器的参数进行更新,具体包括:计算第二损失值对判别器的参数的梯度,得到第二梯度,基于第二梯度对判别器的参数进行更新。
作为一可选的实施方式,在基于第二损失值对判别器的参数进行更新之后,该方法还包括:判断判别器参数更新次数是否达到第二预设次数;若达到第二预设次数,则返回步骤101;若未达到第二预设次数,则更换第七目标数据,然后返回步骤“将第七目标数据输入生成器中进行第一数据处理”。
在本实施例中,将第三输出值和第四输出值输入生成器损失函数中,得到第一损失值,具体包括:采用方式一或方式二得到第一损失值。
方式一:将第三输出值和第四输出值输入基于JS散度获得的生成器损失函数中,得到第一损失值。
方式二:将第三输出值和第四输出值输入基于Wasserstein距离获得的生成器损失函数中,得到第一损失值。
可选地,基于JS散度获得的生成器损失函数为:
Figure DEST_PATH_IMAGE055
式中,
Figure 876690DEST_PATH_IMAGE004
为基于JS散度获得的生成器损失函数,
Figure 284538DEST_PATH_IMAGE006
为第四目标数据,
Figure 575842DEST_PATH_IMAGE008
为第六目标数据,
Figure 390345DEST_PATH_IMAGE010
为第三输出值,
Figure 182721DEST_PATH_IMAGE012
为第五目标数据,
Figure 320441DEST_PATH_IMAGE014
为第四输出值,
Figure 923592DEST_PATH_IMAGE016
为第四目标数据服从第一分布
Figure 287708DEST_PATH_IMAGE018
的期望,第一分布
Figure 493562DEST_PATH_IMAGE018
为第四目标数据的概率分布,
Figure 876001DEST_PATH_IMAGE020
为第五目标数据服从第二分布
Figure 774687DEST_PATH_IMAGE022
的期望,第二分布
Figure 360520DEST_PATH_IMAGE022
为第五目标数据的概率分布。
可选地,基于Wasserstein距离获得的生成器损失函数为:
Figure DEST_PATH_IMAGE057
式中,
Figure 245431DEST_PATH_IMAGE026
为基于Wasserstein距离获得的生成器损失函数。
在本实施例中,将第五输出值和第六输出值输入判别器损失函数中,得到第二损失值,具体包括:采用方式三或方式四得到第二损失值。
方式三:将第五输出值和第六输出值输入基于JS散度获得的判别器损失函数中,得到第二损失值。
方式四:将第五输出值和第六输出值输入基于Wasserstein距离获得的判别器损失函数中,得到第二损失值。
可选地,基于JS散度获得的判别器损失函数为:
Figure DEST_PATH_IMAGE059
式中,
Figure 685640DEST_PATH_IMAGE030
为基于JS散度获得的判别器损失函数,
Figure 568276DEST_PATH_IMAGE032
为第七目标数据,
Figure 890673DEST_PATH_IMAGE034
为第八目标数据,
Figure 438329DEST_PATH_IMAGE036
为第五输出值,
Figure 546093DEST_PATH_IMAGE038
为第九目标数据,
Figure 786582DEST_PATH_IMAGE040
为第六输出值,
Figure DEST_PATH_IMAGE061
为第七目标数据服从第三分布
Figure DEST_PATH_IMAGE063
的期望,第三分布
Figure 347007DEST_PATH_IMAGE063
为第七目标数据的概率分布,
Figure DEST_PATH_IMAGE065
为第九目标数据服从第四分布
Figure DEST_PATH_IMAGE067
的期望,第四分布
Figure 889898DEST_PATH_IMAGE067
为第九目标数据的概率分布。
可选地,基于Wasserstein距离获得的判别器损失函数为:
Figure DEST_PATH_IMAGE069
式中,
Figure 852169DEST_PATH_IMAGE053
为基于Wasserstein距离获得的判别器损失函数。
本发明还提供一种数据处理方法,该方法包括:将待处理数据输入数据处理模型中,得到处理后的数据;其中,数据处理模型是通过上述模型训练方法获得的。
作为一可选的实施方式,在模型为图像恢复模型时,该数据处理方法可以为图像恢复方法,该方法可以包括:将待恢复图像输入图像恢复模型中,得到恢复图像。其中,该图像恢复模型是通过图2所示的模型训练方法得到的。
作为另一可选的实施方式,在模型为文本筛选模型时,该数据处理方法可以为文本筛选方法,该方法可以包括:将待处理文本输入文本筛选模型中,得到感兴趣文本。
此外,本发明提供的数据处理方法还可以为音频处理方法、视频处理方法、尺寸数据处理方法、用户信息处理方法等。
本发明还提供一种模型训练系统,该系统用于训练机器学习模型,机器学习模型包括生成器和判别器,如图3所示,该系统包括:
第一输入模块301,用于将第一目标数据输入生成器中进行第一数据处理,获得第二目标数据,将第二目标数据输入判别器中,得到第一输出值;第一目标数据是对第三目标数据进行第二数据处理后得到的数据;第一数据处理为第二数据处理的反向处理过程;第一目标数据和第三目标数据为初始训练数据集中的数据。
第二输入模块302,用于将第三目标数据输入判别器中,得到第二输出值。
判断模块303,用于在第二输出值大于第一输出值时,将第一目标数据和第三目标数据放入当前训练数据集中。
训练模块304,用于使用当前训练数据集对机器学习模型进行训练。
训练模块304,具体包括:
生成器训练单元,用于从当前训练数据集中选取第四目标数据和第五目标数据;第四目标数据是对第五目标数据进行第二数据处理后得到的数据;将第四目标数据输入生成器中进行第一数据处理,获得第六目标数据,将第六目标数据输入判别器中,得到第三输出值;将第五目标数据输入判别器中,得到第四输出值;将第三输出值和第四输出值输入生成器损失函数中,得到第一损失值;基于第一损失值对生成器的参数进行更新;判断生成器参数更新次数是否达到第一预设次数;若达到第一预设次数,则结束训练过程;若未达到第一预设次数,则执行第一输入模块301。
生成器训练单元,具体用于采用方式一或方式二得到第一损失值;
方式一:将第三输出值和第四输出值输入基于JS散度获得的生成器损失函数中,得到第一损失值。
方式二:将第三输出值和第四输出值输入基于Wasserstein距离获得的生成器损失函数中,得到第一损失值。
其中,
基于JS散度获得的生成器损失函数为:
Figure DEST_PATH_IMAGE055A
式中,
Figure 591455DEST_PATH_IMAGE004
为基于JS散度获得的生成器损失函数,
Figure 576859DEST_PATH_IMAGE006
为第四目标数据,
Figure 590952DEST_PATH_IMAGE008
为第六目标数据,
Figure 532363DEST_PATH_IMAGE010
为第三输出值,
Figure 255599DEST_PATH_IMAGE012
为第五目标数据,
Figure 649672DEST_PATH_IMAGE014
为第四输出值,
Figure 670717DEST_PATH_IMAGE016
为第四目标数据服从第一分布
Figure 76422DEST_PATH_IMAGE018
的期望,第一分布
Figure 829614DEST_PATH_IMAGE018
为第四目标数据的概率分布,
Figure 101196DEST_PATH_IMAGE020
为第五目标数据服从第二分布
Figure 598036DEST_PATH_IMAGE022
的期望,第二分布
Figure 123826DEST_PATH_IMAGE022
为第五目标数据的概率分布。
基于Wasserstein距离获得的生成器损失函数为:
Figure DEST_PATH_IMAGE057A
式中,
Figure 923286DEST_PATH_IMAGE026
为基于Wasserstein距离获得的生成器损失函数。
该系统还包括:
判别器训练模块,用于将第七目标数据输入生成器中进行第一数据处理,获得第八目标数据,将第八目标数据输入判别器中,得到第五输出值;第七目标数据是对第九目标数据进行第二数据处理后得到的数据;第七目标数据和第九目标数据为初始训练数据集中的数据;将第九目标数据输入判别器中,得到第六输出值;将第五输出值和第六输出值输入判别器损失函数中,得到第二损失值;基于第二损失值对判别器的参数进行更新。
判别器训练模块,具体用于采用方式三或方式四得到第二损失值;
方式三:将第五输出值和第六输出值输入基于JS散度获得的判别器损失函数中,得到第二损失值。
方式四:将第五输出值和第六输出值输入基于Wasserstein距离获得的判别器损失函数中,得到第二损失值。
其中,
基于JS散度获得的判别器损失函数为:
Figure DEST_PATH_IMAGE059A
式中,
Figure 619847DEST_PATH_IMAGE030
为基于JS散度获得的判别器损失函数,
Figure 733428DEST_PATH_IMAGE032
为第七目标数据,
Figure 628571DEST_PATH_IMAGE034
为第八目标数据,
Figure 457987DEST_PATH_IMAGE036
为第五输出值,
Figure 454893DEST_PATH_IMAGE038
为第九目标数据,
Figure 293536DEST_PATH_IMAGE040
为第六输出值,
Figure 777607DEST_PATH_IMAGE061
为第七目标数据服从第三分布
Figure 122132DEST_PATH_IMAGE063
的期望,第三分布
Figure 199809DEST_PATH_IMAGE063
为第七目标数据的概率分布,
Figure 966777DEST_PATH_IMAGE065
为第九目标数据服从第四分布
Figure 711879DEST_PATH_IMAGE067
的期望,第四分布
Figure 24043DEST_PATH_IMAGE067
为第九目标数据的概率分布。
基于Wasserstein距离获得的判别器损失函数为:
Figure DEST_PATH_IMAGE069A
式中,
Figure DEST_PATH_IMAGE071
为基于Wasserstein距离获得的判别器损失函数。
判别器训练模块,还用于判断判别器参数更新次数是否达到第二预设次数;若达到第二预设次数,则执行第一输入模块301;若未达到第二预设次数,则更换第七目标数据,然后返回步骤“将第七目标数据输入生成器中进行第一数据处理”。
本发明实施例提供了一种计算机可读存储介质,其上存储有程序,该程序被处理器执行时实现上述模型训练方法和/或数据处理方法。
本发明实施例提供了一种电子设备,如图4所示,电子设备40包括至少一个处理器401、以及与处理器401连接的至少一个存储器402、总线403;其中,处理器401、存储器402通过总线403完成相互间的通信;处理器401用于调用存储器402中的程序指令,以执行上述的模型训练方法和/或数据处理方法。本文中的电子设备可以是服务器、PC、PAD、手机等。
本申请还提供了一种计算机程序产品,当在数据处理设备上执行时,适于执行初始化有上述的模型训练方法和/或数据处理方法包括的步骤的程序。
本申请是参照根据本申请实施例的方法、系统和计算机程序产品的流程图和/或方框图来描述的。应理解可由计算机程序指令实现流程图和/或方框图中的每一流程和/或方框、以及流程图和/或方框图中的流程和/或方框的结合。可提供这些计算机程序指令到通用计算机、专用计算机、嵌入式处理机或其他可编程数据处理设备的处理器以产生一个机器,使得通过计算机或其他可编程数据处理设备的处理器执行的指令产生用于实现在流程图一个流程或多个流程和/或方框图一个方框或多个方框中指定的功能的装置。
在一个典型的配置中,设备包括一个或多个处理器(CPU)、存储器和总线。设备还可以包括输入/输出接口、网络接口等。
存储器可能包括计算机可读介质中的非永久性存储器,随机存取存储器(RAM)和/或非易失性内存等形式,如只读存储器(ROM)或闪存(flash RAM),存储器包括至少一个存储芯片。存储器是计算机可读介质的示例。
计算机可读介质包括永久性和非永久性、可移动和非可移动媒体可以由任何方法或技术来实现信息存储。信息可以是计算机可读指令、数据结构、程序的模块或其他数据。计算机的存储介质的例子包括,但不限于相变内存 (PRAM)、静态随机存取存储器 (SRAM)、动态随机存取存储器 (DRAM)、其他类型的随机存取存储器 (RAM)、只读存储器 (ROM)、电可擦除可编程只读存储器 (EEPROM)、快闪记忆体或其他内存技术、只读光盘只读存储器(CD-ROM)、数字多功能光盘 (DVD) 或其他光学存储、磁盒式磁带,磁带磁盘存储或其他磁性存储设备或任何其他非传输介质,可用于存储可以被计算设备访问的信息。按照本文中的界定,计算机可读介质不包括暂存电脑可读媒体 (transitory media),如调制的数据信号和载波。
本领域技术人员应明白,本申请的实施例可提供为方法、系统或计算机程序产品。因此,本申请可采用完全硬件实施例、完全软件实施例或结合软件和硬件方面的实施例的形式。而且,本申请可采用在一个或多个其中包含有计算机可用程序代码的计算机可用存储介质(包括但不限于磁盘存储器、CD-ROM、光学存储器等)上实施的计算机程序产品的形式。
需要说明的是,在本文中,诸如第一和第二等之类的关系术语仅仅用来将一个实体或者操作与另一个实体或操作区分开来,而不一定要求或者暗示这些实体或操作之间存在任何这种实际的关系或者顺序。还需要说明的是,术语“包括”、“包含”或者其任何其他变体意在涵盖非排他性的包含,从而使得包括一系列要素的过程、方法、商品或者设备不仅包括那些要素,而且还包括没有明确列出的其他要素,或者是还包括为这种过程、方法、商品或者设备所固有的要素。在没有更多限制的情况下,由语句“包括一个……”限定的要素,并不排除在包括要素的过程、方法、商品或者设备中还存在另外的相同要素。
本说明书中的各个实施例均采用相关的方式描述,各个实施例之间相同相似的部分互相参见即可,每个实施例重点说明的都是与其他实施例的不同之处。尤其,对于系统实施例而言,由于其基本相似于方法实施例,所以描述的比较简单,相关之处参见方法实施例的部分说明即可。
以上仅为本申请的实施例而已,并不用于限制本申请。对于本领域技术人员来说,本申请可以有各种更改和变化。凡在本申请的精神和原理之内所作的任何修改、等同替换、改进等,均应包含在本申请的权利要求范围之内。

Claims (9)

1.一种数据处理方法,其特征在于,包括:
将待处理数据输入数据处理模型中,得到处理后的数据;所述数据处理模型包括生成器和判别器,其中,所述数据处理模型是通过如下模型训练方法获得的:
将第一目标数据输入所述生成器中进行第一数据处理,获得第二目标数据,将所述第二目标数据输入所述判别器中,得到第一输出值;所述第一目标数据是对第三目标数据进行第二数据处理后得到的数据;所述第一数据处理为所述第二数据处理的反向处理过程;所述第一目标数据和所述第三目标数据为初始训练数据集中的数据;所述第一目标数据为降质图像,第一数据处理为图像质量恢复处理,第二目标数据为恢复图像,第三目标数据为原始图像,第二数据处理为图像降质处理;
将所述第三目标数据输入所述判别器中,得到第二输出值;
若所述第二输出值大于所述第一输出值,则将所述第一目标数据和第三目标数据放入当前训练数据集中;
使用所述当前训练数据集对所述数据处理模型进行训练。
2.根据权利要求1所述的数据处理方法,其特征在于,所述使用所述当前训练数据集对所述数据处理模型进行训练,具体包括:
从所述当前训练数据集中选取第四目标数据和第五目标数据;所述第四目标数据是对所述第五目标数据进行所述第二数据处理后得到的数据;
将所述第四目标数据输入所述生成器中进行第一数据处理,获得第六目标数据,将所述第六目标数据输入所述判别器中,得到第三输出值;
将所述第五目标数据输入所述判别器中,得到第四输出值;
将所述第三输出值和所述第四输出值输入生成器损失函数中,得到第一损失值;
基于所述第一损失值对所述生成器的参数进行更新;
判断生成器参数更新次数是否达到第一预设次数;
若达到所述第一预设次数,则结束训练过程;
若未达到所述第一预设次数,则返回步骤“将第一目标数据输入所述生成器中进行第一数据处理”。
3.根据权利要求1或2所述的数据处理方法,其特征在于,在将第一目标数据输入所述生成器中进行第一数据处理之前,还包括:
将第七目标数据输入所述生成器中进行所述第一数据处理,获得第八目标数据,将所述第八目标数据输入所述判别器中,得到第五输出值;所述第七目标数据是对第九目标数据进行所述第二数据处理后得到的数据;所述第七目标数据和所述第九目标数据为所述初始训练数据集中的数据;
将所述第九目标数据输入所述判别器中,得到第六输出值;
将所述第五输出值和所述第六输出值输入判别器损失函数中,得到第二损失值;
基于所述第二损失值对所述判别器的参数进行更新。
4.根据权利要求3所述的数据处理方法,其特征在于,在所述基于所述第二损失值对所述判别器的参数进行更新之后,所述方法还包括:
判断判别器参数更新次数是否达到第二预设次数;
若达到所述第二预设次数,则返回步骤“将第一目标数据输入所述生成器中进行第一数据处理”;
若未达到所述第二预设次数,则更换所述第七目标数据,然后返回步骤“将第七目标数据输入所述生成器中进行所述第一数据处理”。
5.根据权利要求2所述的数据处理方法,其特征在于,所述将所述第三输出值和所述第四输出值输入生成器损失函数中,得到第一损失值,具体包括:采用方式一或方式二得到所述第一损失值;
方式一:将所述第三输出值和所述第四输出值输入基于JS散度获得的生成器损失函数中,得到第一损失值;
方式二:将所述第三输出值和所述第四输出值输入基于Wasserstein距离获得的生成器损失函数中,得到第一损失值;
其中,
基于JS散度获得的生成器损失函数为:
Figure DEST_PATH_IMAGE002
式中,
Figure DEST_PATH_IMAGE004
为基于JS散度获得的生成器损失函数,
Figure DEST_PATH_IMAGE006
为第四目标数据,
Figure DEST_PATH_IMAGE008
为第六目标数据,
Figure DEST_PATH_IMAGE010
为第三输出值,
Figure DEST_PATH_IMAGE012
为第五目标数据,
Figure DEST_PATH_IMAGE014
为第四输出值,
Figure DEST_PATH_IMAGE016
为第四目标数据服从第一分布
Figure DEST_PATH_IMAGE018
的期望,第一分布
Figure 452404DEST_PATH_IMAGE018
为第四目标数据的概率分布,
Figure DEST_PATH_IMAGE020
为第五目标数据服从第二分布
Figure DEST_PATH_IMAGE022
的期望,第二分布
Figure 307228DEST_PATH_IMAGE022
为第五目标数据的概率分布;
基于Wasserstein距离获得的生成器损失函数为:
Figure DEST_PATH_IMAGE024
式中,
Figure DEST_PATH_IMAGE026
为基于Wasserstein距离获得的生成器损失函数。
6.根据权利要求3所述的数据处理方法,其特征在于,所述将所述第五输出值和所述第六输出值输入判别器损失函数中,得到第二损失值,具体包括:采用方式三或方式四得到所述第二损失值;
方式三:将所述第五输出值和所述第六输出值输入基于JS散度获得的判别器损失函数中,得到第二损失值;
方式四:将所述第五输出值和所述第六输出值输入基于Wasserstein距离获得的判别器损失函数中,得到第二损失值;
其中,
基于JS散度获得的判别器损失函数为:
Figure DEST_PATH_IMAGE028
式中,
Figure DEST_PATH_IMAGE030
为基于JS散度获得的判别器损失函数,
Figure DEST_PATH_IMAGE032
为第七目标数据,
Figure DEST_PATH_IMAGE034
为第八目标数据,
Figure DEST_PATH_IMAGE036
为第五输出值,
Figure DEST_PATH_IMAGE038
为第九目标数据,
Figure DEST_PATH_IMAGE040
为第六输出值,
Figure DEST_PATH_IMAGE042
为第七目标数据服从第三分布
Figure DEST_PATH_IMAGE044
的期望,第三分布
Figure 336232DEST_PATH_IMAGE044
为第七目标数据的概率分布,
Figure DEST_PATH_IMAGE046
为第九目标数据服从第四分布
Figure DEST_PATH_IMAGE048
的期望,第四分布
Figure 465863DEST_PATH_IMAGE048
为第九目标数据的概率分布;
基于Wasserstein距离获得的判别器损失函数为:
Figure DEST_PATH_IMAGE050
式中,
Figure DEST_PATH_IMAGE052
为基于Wasserstein距离获得的判别器损失函数。
7.一种模型训练系统,其特征在于,所述系统用于训练数据处理模型,所述数据处理模型包括生成器和判别器,所述系统包括:
第一输入模块,用于将第一目标数据输入所述生成器中进行第一数据处理,获得第二目标数据,将所述第二目标数据输入所述判别器中,得到第一输出值;所述第一目标数据是对第三目标数据进行第二数据处理后得到的数据;所述第一数据处理为所述第二数据处理的反向处理过程;所述第一目标数据和所述第三目标数据为初始训练数据集中的数据;所述第一目标数据为降质图像,第一数据处理为图像质量恢复处理,第二目标数据为恢复图像,第三目标数据为原始图像,第二数据处理为图像降质处理;
第二输入模块,用于将所述第三目标数据输入所述判别器中,得到第二输出值;
判断模块,用于在所述第二输出值大于所述第一输出值时,将所述第一目标数据和第三目标数据放入当前训练数据集中;
训练模块,用于使用所述当前训练数据集对所述数据处理模型进行训练;其中,训练得到数据处理模型后,将待处理数据输入所述数据处理模型中,得到处理后的数据。
8.一种计算机可读存储介质,其特征在于,所述计算机可读存储介质上存储有程序,所述程序被处理器执行时实现权利要求1-6任一项所述的数据处理方法。
9.一种电子设备,其特征在于,包括:
至少一个处理器、以及与所述处理器连接的至少一个存储器、总线;
所述处理器、所述存储器通过所述总线完成相互间的通信;所述处理器用于调用所述存储器中的程序指令,以执行权利要求1-6任一项所述的数据处理方法。
CN202210321507.1A 2022-03-30 2022-03-30 一种模型训练方法、数据处理方法及相关设备 Active CN114418130B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202210321507.1A CN114418130B (zh) 2022-03-30 2022-03-30 一种模型训练方法、数据处理方法及相关设备

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202210321507.1A CN114418130B (zh) 2022-03-30 2022-03-30 一种模型训练方法、数据处理方法及相关设备

Publications (2)

Publication Number Publication Date
CN114418130A CN114418130A (zh) 2022-04-29
CN114418130B true CN114418130B (zh) 2022-07-15

Family

ID=81263517

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202210321507.1A Active CN114418130B (zh) 2022-03-30 2022-03-30 一种模型训练方法、数据处理方法及相关设备

Country Status (1)

Country Link
CN (1) CN114418130B (zh)

Family Cites Families (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN108520504B (zh) * 2018-04-16 2020-05-19 湘潭大学 一种基于生成对抗网络端到端的模糊图像盲复原方法
CN110135366B (zh) * 2019-05-20 2021-04-13 厦门大学 基于多尺度生成对抗网络的遮挡行人重识别方法
CN111695605B (zh) * 2020-05-20 2024-05-10 平安科技(深圳)有限公司 基于oct图像的图像识别方法、服务器及存储介质
US11526964B2 (en) * 2020-06-10 2022-12-13 Intel Corporation Deep learning based selection of samples for adaptive supersampling
CN112565777B (zh) * 2020-11-30 2023-04-07 通号智慧城市研究设计院有限公司 基于深度学习模型视频数据传输方法、系统、介质及设备

Also Published As

Publication number Publication date
CN114418130A (zh) 2022-04-29

Similar Documents

Publication Publication Date Title
CN110675399A (zh) 屏幕外观瑕疵检测方法及设备
CN110740356B (zh) 基于区块链的直播数据的监控方法及系统
CN111258905B (zh) 缺陷定位方法、装置和电子设备及计算机可读存储介质
CN109284492B (zh) 一种生成通知文书的方法和装置
CN110807009B (zh) 文件处理方法及装置
CN112651429B (zh) 一种音频信号时序对齐方法和装置
CN114022955A (zh) 一种动作识别方法及装置
CN114418130B (zh) 一种模型训练方法、数据处理方法及相关设备
CN108228869B (zh) 一种文本分类模型的建立方法及装置
CN110969547A (zh) 一种文本生成方法及装置
CN107016028B (zh) 数据处理方法及其设备
CN116955590A (zh) 训练数据筛选方法、模型训练方法、文本生成方法
CN114444725B (zh) 预训练服务系统及基于预训练服务系统的服务提供方法
CN114897723B (zh) 一种基于生成式对抗网络的图像生成与加噪方法
CN105740260A (zh) 提取模板文件数据结构的方法和装置
CN111191007A (zh) 一种基于区块链的文章关键词过滤方法及设备、介质
CN111048065B (zh) 文本纠错数据生成方法及相关装置
CN111078877B (zh) 数据处理、文本分类模型的训练、文本分类方法和装置
CN111260757A (zh) 一种图像处理方法、装置及终端设备
CN109710833B (zh) 用于确定内容节点的方法与设备
CN111667013A (zh) 信息补充方法、装置、电子设备及计算机可读存储介质
CN112579764A (zh) 一种庭审提纲的生成方法、装置、设备及存储介质
CN107948739B (zh) 一种网络电视去重用户数的计算方法及装置
CN109901990B (zh) 一种业务系统的测试方法、装置及设备
CN117523323B (zh) 一种生成图像的检测方法及装置

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant