CN115115058A - 模型训练方法、装置、设备及介质 - Google Patents

模型训练方法、装置、设备及介质 Download PDF

Info

Publication number
CN115115058A
CN115115058A CN202210374553.8A CN202210374553A CN115115058A CN 115115058 A CN115115058 A CN 115115058A CN 202210374553 A CN202210374553 A CN 202210374553A CN 115115058 A CN115115058 A CN 115115058A
Authority
CN
China
Prior art keywords
noise
round
current
current round
model
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202210374553.8A
Other languages
English (en)
Inventor
黄志超
樊艳波
张卫忠
张勇
王珏
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Tencent Technology Shenzhen Co Ltd
Original Assignee
Tencent Technology Shenzhen Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Tencent Technology Shenzhen Co Ltd filed Critical Tencent Technology Shenzhen Co Ltd
Priority to CN202210374553.8A priority Critical patent/CN115115058A/zh
Publication of CN115115058A publication Critical patent/CN115115058A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N20/00Machine learning
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Software Systems (AREA)
  • Physics & Mathematics (AREA)
  • Mathematical Physics (AREA)
  • Artificial Intelligence (AREA)
  • Data Mining & Analysis (AREA)
  • Evolutionary Computation (AREA)
  • Computing Systems (AREA)
  • General Engineering & Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Medical Informatics (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Health & Medical Sciences (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Computational Linguistics (AREA)
  • General Health & Medical Sciences (AREA)
  • Molecular Biology (AREA)
  • Image Analysis (AREA)

Abstract

本申请公开了一种模型训练方法、装置、设备及介质,涉及机器学习领域。该方法包括:从样本数据集中提取出本轮输入样本和本轮输出样本;通过机器学习模型对本轮输入样本和上一轮对抗噪声进行数据处理,得到本轮预测样本;通过本轮输出样本和本轮预测样本之间的损失值更新上一轮对抗噪声,得到本轮对抗噪声;根据本轮对抗噪声、本轮输入样本和本轮输出样本更新机器学习模型的上一轮模型参数,得到本轮模型参数;迭代上述四个步骤,直至满足训练完成条件,完成对机器学习模型的训练;其中,在第1轮迭代中,上一轮对抗噪声是初始化的对抗噪声;在第i轮迭代中,上一轮对抗噪声是第i‑1轮迭代中的本轮对抗噪声。本申请可以提高模型的鲁棒性。

Description

模型训练方法、装置、设备及介质
技术领域
本申请涉及机器学习领域,特别涉及一种模型训练方法、装置、设备及介质。
背景技术
对抗训练(adversarial training)是增强神经网络鲁棒性的重要方式。在对抗训练的过程中,输入样本会被混合一些微小的对抗噪声,然后使神经网络适应这种改变,从而增强神经网络的鲁棒性。对抗噪声是添加到输入样本中的,用于干扰机器学习模型功能的噪声。
相关技术采用单步的梯度下降来确定对抗噪声以实现对抗训练,单步的梯度下降指每次梯度下降的幅度相同。相关技术会先获取随机的对抗噪声,在机器学习模型的训练过程中,通过梯度下降的对抗噪声对输入样本进行扰动,得到预测样本,并根据预测样本和输入样本之间的差值对机器学习模型进行训练。
相关技术受单步的梯度下降的步长影响,步长过大,会导致模型参数的过拟合,降低机器学习模型的鲁棒性;而步长过小,会减弱对抗噪声的作用,同样会降低机器学习模型的鲁棒性。
发明内容
本申请实施例提供了一种模型训练方法、装置、设备及介质,该方法可以在机器学习模型训练的过程中,为输入样本添加适合的对抗噪声,在保证训练效率的前提下,增强机器学习模型的鲁棒性,所述技术方案如下:
根据本申请的一个方面,提供了一种模型训练方法,该方法包括:
从样本数据集中提取出本轮输入样本和本轮输出样本;
通过所述机器学习模型对所述本轮输入样本和上一轮对抗噪声进行数据处理,得到本轮预测样本;
通过所述本轮输出样本和所述本轮预测样本之间的损失值更新所述上一轮对抗噪声,得到本轮对抗噪声;
根据所述本轮对抗噪声、所述本轮输入样本和所述本轮输出样本更新所述机器学习模型的上一轮模型参数,得到本轮模型参数;
迭代上述四个步骤,直至满足训练完成条件,完成对所述机器学习模型的训练;
其中,在第1轮迭代中,所述上一轮对抗噪声是初始化的对抗噪声;在第i轮迭代中,所述上一轮对抗噪声是第i-1轮迭代中的本轮对抗噪声,i为大于1的整数。
根据本申请的一个方面,提供了一种模型训练装置,该装置包括:
提取模块,用于从样本数据集中提取出本轮输入样本和本轮输出样本;
训练模块,用于通过所述机器学习模型对所述本轮输入样本和上一轮对抗噪声进行数据处理,得到本轮预测样本;
更新模块,用于通过所述本轮输出样本和所述本轮预测样本之间的损失值更新所述上一轮对抗噪声,得到本轮对抗噪声;
所述更新模块,还用于根据所述本轮对抗噪声、所述本轮输入样本和所述本轮输出样本更新所述机器学习模型的上一轮模型参数,得到本轮模型参数;
所述训练模块,还用于迭代上述四个步骤,直至满足训练完成条件,完成对所述机器学习模型的训练;
其中,在第1轮迭代中,所述上一轮对抗噪声是初始化的对抗噪声;在第i轮迭代中,所述上一轮对抗噪声是第i-1轮迭代中的本轮对抗噪声,i为大于1的整数。
根据本申请的另一方面,提供了一种计算机设备,该计算机设备包括:处理器和存储器,存储器中存储有至少一条指令、至少一段程序、代码集或指令集,至少一条指令、至少一段程序、代码集或指令集由处理器加载并执行以实现如上方面所述的模型训练方法。
根据本申请的另一方面,提供了一种计算机存储介质,计算机可读存储介质中存储有至少一条程序代码,程序代码由处理器加载并执行以实现如上方面所述的模型训练方法。
根据本申请的另一方面,提供了一种计算机程序产品或计算机程序,上述计算机程序产品或计算机程序包括计算机指令,上述计算机指令存储在计算机可读存储介质中。计算机设备的处理器从上述计算机可读存储介质读取上述计算机指令,上述处理器执行上述计算机指令,使得上述计算机设备执行如上方面所述的模型训练方法。
本申请实施例提供的技术方案带来的有益效果至少包括:
在训练机器学习模型时,会在输入样本中添加对抗噪声,而且该对抗噪声会随着训练的进程发生自适应调整,相较于相关技术,既不会出现过拟合问题,而且还可以在较大步长下完成机器学习模型的训练,训练效率较高,且由于添加了对抗噪声,提高了机器学习模型的鲁棒性。
附图说明
为了更清楚地说明本申请实施例中的技术方案,下面将对实施例描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本申请的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1是本申请一个示例性实施例提供的计算机系统的示意图;
图2是本申请一个示例性实施例提供的模型训练方法的流程示意图;
图3是本申请一个示例性实施例提供的模型训练方法的流程示意图;
图4是本申请一个示例性实施例提供的模型训练方法的示意图;
图5是本申请一个示例性实施例提供的面部识别模型的训练方法的流程示意图;
图6是本申请一个示例性实施例提供的自动驾驶模型的训练方法的流程示意图;
图7是本申请一个示例性实施例提供的物品推荐模型的训练方法的流程示意图;
图8是本申请一个示例性实施例提供的模型训练装置的结构示意图;
图9是本申请一个示例性实施例提供的计算机设备的结构框图。
具体实施方式
为使本申请的目的、技术方案和优点更加清楚,下面将结合附图对本申请实施方式作进一步地详细描述。
首先,对本申请实施例中涉及的名词进行介绍:
人工智能(Artificial Intelligence,AI):是利用数字计算机或者数字计算机控制的机器模拟、延伸和扩展人的智能,感知环境、获取知识并使用知识获得最佳结果的理论、方法、技术及应用系统。换句话说,人工智能是计算机科学的一个综合技术,它企图了解智能的实质,并生产出一种新的能以人类智能相似的方式做出反应的智能机器。人工智能也就是研究各种智能机器的设计原理与实现方法,使机器具有感知、推理与决策的功能。
人工智能技术是一门综合学科,涉及领域广泛,既有硬件层面的技术也有软件层面的技术。人工智能基础技术一般包括如传感器、专用人工智能芯片、云计算、分布式存储、大数据处理技术、操作/交互系统、机电一体化等技术。人工智能软件技术主要包括计算机视觉技术、语音处理技术、自然语言处理技术以及机器学习/深度学习等几大方向。
计算机视觉技术(Computer Vision,CV):计算机视觉是一门研究如何使机器“看”的科学,更进一步的说,就是指用摄影机和电脑代替人眼对目标进行识别和测量等机器视觉,并进一步做图形处理,使电脑处理成为更适合人眼观察或传送给仪器检测的图像。作为一个科学学科,计算机视觉研究相关的理论和技术,试图建立能够从图像或者多维数据中获取信息的人工智能系统。计算机视觉技术通常包括图像处理、图像识别、图像语义理解、图像检索、OCR(Optical Character Recognition,光学字符识别)、视频处理、视频语义理解、视频内容/行为识别、三维物体重建、3D技术、虚拟现实、增强现实、同步定位与地图构建等技术,还包括常见的人脸识别、指纹识别等生物特征识别技术。
机器学习(Machine Learning,ML):是一门多领域交叉学科,涉及概率论、统计学、逼近论、凸分析、算法复杂度理论等多门学科。专门研究计算机怎样模拟或实现人类的学习行为,以获取新的知识或技能,重新组织已有的知识结构使之不断改善自身的性能。机器学习是人工智能的核心,是使计算机具有智能的根本途径,其应用遍及人工智能的各个领域。机器学习和深度学习通常包括人工神经网络、置信网络、强化学习、迁移学习、归纳学习、示教学习等技术。
需要说明的是,本申请所涉及的信息(包括但不限于用户设备信息、用户个人信息等)、数据(包括但不限于用于分析的数据、存储的数据、展示的数据等)以及信号,均为经用户授权或者经过各方充分授权的,且相关数据的收集、使用和处理需要遵守相关国家和地区的相关法律法规和标准。例如,本申请中涉及到的输入样本、输出样本以及用户信息都是在充分授权的情况下获取的。
图1示出了本申请一个示例性实施例提供的计算机系统的结构示意图。计算机系统100包括:终端120和服务器140。
终端120上安装有与模型训练的应用程序。该应用程序可以是app(application,应用程序)中的小程序,也可以是专门的应用程序,也可以是网页客户端。示例性的,服务器140将训练完成的模型提供给终端120。终端120是智能手机、平板电脑、电子书阅读器、MP3播放器、MP4播放器、膝上型便携计算机和台式计算机中的至少一种。
终端120通过无线网络或有线网络与服务器140相连。
服务器140可以是独立的物理服务器,也可以是多个物理服务器构成的服务器集群或者分布式系统,还可以是提供云服务、云数据库、云计算、云函数、云存储、网络服务、云通信、中间件服务、域名服务、安全服务、CDN(Content Delivery Network,内容分发网络)、以及大数据和人工智能平台等基础云计算服务的云服务器。服务器140用于进行机器学习模型的训练,并将训练完成的机器学习模型提供给终端120。可选地,服务器140承担主要计算工作,终端120承担次要计算工作;或者,服务器140承担次要计算工作,终端120承担主要计算工作;或者,服务器140和终端120两者采用分布式计算架构进行协同计算。
图2示出了本申请一个示例性实施例提供的模型训练方法的流程示意图。该方法可由图1所示的计算机系统100执行,该方法包括:
步骤202:从样本数据集中提取出本轮输入样本和本轮输出样本。
样本数据集包括成对的输入样本和输出样本。本申请实施例对样本数据集不做具体限定。示例性的,在机器学习模型用于图像识别的情况下,样本数据集包括成对的样本图像和图像类型标签,或者,样本数据集包括成对的面部图像和面部身份。示例性的,在机器学习模型用于实现自动驾驶的情况下,样本数据集包括成对的路况信息和驾驶策略。示例性的,在机器学习模型用于物品推荐的情况下,样本数据集包括成对的帐号信息和推荐物品。
可选地,本申请实施例的机器学习模型是通过样本数据和历史模型参数的来优化算法的模型。可选地,机器学习模型是模拟生物学习方式的模型。可选地,机器学习模型是在经验学习中改进算法性能的模型。
示例性的,机器学习模型包括但不限于ResNet(残差网络)模型、CNN(Convolutional Neural Networks,卷积神经网络)模型、RNN(Recurrent NeuralNetwork,循环神经网络)模型、Transformer(变换)模型中的至少一种。示例性的,机器学习模型的作用包括但不限于图像识别、自动驾驶、活体检测、自然语言理解、自然语言生成中的至少一种。需要说明的是,本申请对机器学习模型的类型和作用均不作具体限定。
本轮输出样本是本轮输入样本的真实标签。示例性的,在机器学习模型是身份识别模型的情况下,本轮输入样本是脸部图像,则本轮输出样本是该脸部图像的身份。示例性的,在机器学习模型是自动驾驶模型的情况下,本轮输入样本是路况信息,则本轮输出样本是参考的驾驶路线。
示例性的,样本数据集可表示为
Figure BDA0003589768830000061
其中,xi是输入样本,yi是输出样本,n表示样本数据集的样本总数。
可选地,从样本数据集中随机提取出本轮输入样本和本轮输出样本。或者,按照样本排列顺序,从样本数据集中提取出本轮输入样本和本轮输出样本。示例性的,样本数据集包括按序排列的n个成对的输入样本和输出样本,根据迭代次数,从样本数据集中选出本轮输入样本和本轮输出样本。
步骤204:通过机器学习模型对本轮输入样本和上一轮对抗噪声进行数据处理,得到本轮预测样本。
对抗噪声用于干扰机器学习模型的功能。示例性的,在机器学习模型用于实现图像识别的情况下,输入样本是“山”的图像,机器学习模型识别该输入样本,输出“山”,而机器学习模型识别添加对抗噪声的输入样本后,输出“海”,故对抗噪声干扰了机器学习模型的功能。
可选地,将上一轮对抗噪声添加到本轮输入样本中,得到带噪的本轮输入样本;通过面部识别模型对带噪的本轮输入样本进行数据处理,得到本轮预测样本。
示例性的,将机器学习模型记为fθ,θ是机器学习模型的模型参数,本轮输入样本记为xk,上一轮对抗噪声是δk-1,则本轮预测样本是fθ(xkk-1)。
可选地,初始化的对抗噪声δ0=Uniform(-∈,∈),其中,Uniform代表均匀分布,∈是预设的常量。
步骤206:通过本轮输出样本和本轮预测样本之间的损失值更新上一轮对抗噪声,得到本轮对抗噪声。
可选地,通过损失函数确定本轮输出样本和本轮预测样本之间的损失值;通过损失值更新上一轮对抗噪声,得到本轮对抗噪声。
示例性的,将机器学习模型的损失函数定义为L(xkk,ykk-1),其中,xkk表示将本轮输入样本xk和本轮对抗噪声δk输入到机器学习模型后得到的输出,yk是本轮输出样本,θk-1是上一轮模型参数,k表示本轮的迭代次数。
步骤208:根据本轮对抗噪声、本轮输入样本和本轮输出样本更新机器学习模型的上一轮模型参数,得到本轮模型参数。
可选地,通过机器学习模型对本轮对抗噪声和本轮输入样本进行数据处理,得到本轮更新样本;根据本轮输出样本和本轮更新样本之间的损失更新值更新上一轮模型参数,得到本轮模型参数。
可选地,通过函数
Figure BDA0003589768830000071
更新模型参数,则本轮模型参数
Figure BDA0003589768830000072
Figure BDA0003589768830000073
θk-1是上一轮模型参数,k表示本轮的迭代次数,xkk表示将本轮输入样本xk和本轮对抗噪声δk输入到机器学习模型后得到的输出,yk是本轮输出样本,函数
Figure BDA0003589768830000074
的形式可由技术人员根据实际需求确定。
步骤210:迭代上述四个步骤,直至满足训练完成条件,完成对机器学习模型的训练。
可选地,迭代上述四个步骤,直至损失值达到收敛,完成对机器学习模型的训练。
可选地,迭代上述四个步骤N次,完成对机器学习模型的训练,N为常数。N的取值可由技术人员根据实际需求进行设置。
综上所述,本实施例在训练机器学习模型时,会在输入样本中添加对抗噪声,而且该对抗噪声会随着训练的进程发生自适应调整,相较于相关技术,既不会出现过拟合问题,而且还可以在较大步长下完成机器学习模型的训练,训练效率较高,且由于添加了对抗噪声,提高了机器学习模型的鲁棒性。
在接下来的实施例中,会对机器学习模型训练方法进行详细介绍,对抗噪声是通过自适应步长得到的,为不同迭代阶段的机器学习模型训练提供自适应的对抗噪声。而且,会对对抗噪声进行下采样,方便对抗噪声的存储,节约计算机设备的内存。
图3示出了本申请一个示例性实施例提供的模型训练方法的流程示意图。该方法可由图1所示的计算机系统100执行,该方法包括:
步骤301:从样本数据集中提取出本轮输入样本和本轮输出样本。
样本数据集包括成对的输入样本和输出样本。本申请实施例对样本数据集不做具体限定。
可选地,从样本数据集中随机提取出本轮输入样本和本轮输出样本。或者,按照样本排列顺序,从样本数据集中提取出本轮输入样本和本轮输出样本。
步骤302:在第1轮迭代时,通过机器学习模型对本轮输入样本和初始对抗噪声进行数据处理,得到本轮预测样本。
可选地,在第1轮迭代时,初始对抗噪声δi=Uniform(-∈,∈),其中,Uniform代表均匀分布,∈为常量,∈可由技术人员根据实际需求进行设置。
步骤303:在第i轮迭代时,通过机器学习模型对本轮输入样本和上一轮对抗噪声进行数据处理,得到本轮预测样本。
其中,i为大于1的整数。
可选地,在第i轮迭代时,上一轮对抗噪声是在第i-1轮迭代中的本轮对抗噪声。
示例性的,将机器学习模型记为fθ,θ是机器学习模型的模型参数,本轮输入样本记为xk,上一轮对抗噪声是δk-1,则本轮预测样本是fθ(xkk-1)。
可选地,如图4所示,在使用对抗噪声时,将存储的对抗噪声
Figure BDA0003589768830000081
上采样为对抗噪声δk,由于对抗噪声
Figure BDA0003589768830000082
的信息量较小,对抗噪声
Figure BDA0003589768830000083
所占的存储区域也较小,采用对抗噪声
Figure BDA0003589768830000084
来存储对抗噪声可以节约计算机设备的内存。可选地,通过上采样分辨率对存储的上一轮对抗噪声进行上采样,得到上一轮对抗噪声。示例性的,以样本数据集中的输入样本是图像为例,输入样本的分辨率是w×h,对抗噪声的分辨率与输入样本的分辨率保持一致,若存储的上一轮对抗噪声
Figure BDA0003589768830000091
的分辨率是p×q,将存储的上一轮对抗噪声
Figure BDA0003589768830000092
上采样至w×h得到上一轮对抗噪声δk-1,使上一轮对抗噪声δk-1的分辨率与输入样本的分辨率保持一致。
需要说明的是,步骤302和步骤303相互排斥,执行步骤302时,不执行步骤303;执行步骤303时,不执行步骤302。
步骤304:确定本轮输出样本和本轮预测样本之间的损失值。
可选地,采用损失函数确定本轮输出样本和本轮预测样本之间的损失值。损失函数的形式可由技术人员根据实际需求确定。
示例性的,损失值是L(xkk-1,ykk-1),其中,xkk-1表示将本轮输入样本xk和上一轮对抗噪声δk-1输入到模型后得到的输出,yk是本轮输出样本,θk-1是上一轮模型参数,k表示本轮的迭代次数。
步骤305:确定损失值的梯度在对抗噪声维度上的噪声分量。
示例性的,噪声分量为
Figure BDA0003589768830000093
Figure BDA0003589768830000094
是梯度在对抗噪声维度上的分量。
步骤306:通过噪声分量更新上一轮动量参数,得到本轮动量参数。
本轮动量参数用于表示上一轮对抗噪声的衰减率。可选地,动量参数是momentum(动量)参数。
可选地,根据噪声分量的范数和上一轮动量参数的和,得到本轮动量参数。例如,根据噪声分量的二范数和上一轮动量参数的和,得到本轮动量参数。或者,噪声分量的无穷范数和上一轮动量参数的和,得到本轮动量参数。或者,噪声分量的一范数和上一轮动量参数的和,得到本轮动量参数。需要说明的是,本申请实施例对用于计算本轮动量参数的噪声分量的范数不做具体限定。
示例性的,本轮动量参数为
Figure BDA0003589768830000095
其中,vk-1是上一轮动量参数,β是超参数,β的取值可由技术人员根据实际需求进行调整。
步骤307:根据本轮动量参数更新上一轮对抗噪声,得到本轮对抗噪声。
可选地,本步骤包括以下子步骤:
1、根据本轮动量参数计算自适应步长。
自适应步长与本轮动量参数呈反比。示例性的,本轮动量参数为
Figure BDA0003589768830000096
其中,γ和c是超参数,γ和c的取值可由技术人员根据实际需求进行调整,vk是本轮动量参数。
2、通过噪声分量,计算自适应步长和上一轮对抗噪声的和,得到本轮对抗噪声。
示例性的,本轮对抗噪声为
Figure BDA0003589768830000101
其中,δk-1是上一轮对抗噪声,sgn表示sign函数,其中,
Figure BDA0003589768830000102
可选地,如图4所示,在获取对抗噪声时,将对抗噪声δk下采样为
Figure BDA0003589768830000103
Figure BDA0003589768830000104
存储到对抗噪声的存储区域中,由于对抗噪声
Figure BDA0003589768830000105
的信息量较小,对抗噪声
Figure BDA0003589768830000106
所占的存储区域也较小,采用对抗噪声
Figure BDA0003589768830000107
来存储对抗噪声可以节约计算机设备的内存。可选地,基于下采样分辨率对本轮对抗噪声进行下采样,得到下采样后的本轮对抗噪声;存储下采样后的本轮对抗噪声。下采样分辨率和上采样分辨率相对应,上采样分辨率可以还原经过下采样分辨率采样的对抗噪声。
示例性的,以样本数据集中的输入样本是图像为例,输入样本的分辨率是w×h,对抗噪声的分辨率与输入样本的分辨率保持一致,则本轮对抗噪声δk也为w×h,将本轮对抗噪声δk下采样至p×q得到下采样后的本轮对抗噪声
Figure BDA0003589768830000108
步骤308:将本轮对抗噪声的噪声值剪裁在预设区间内。
其中,预设区间是基于初始化的对抗噪声的噪声分布确定的。
示例性的,本轮对抗噪声
Figure BDA0003589768830000109
其中,Clip[-∈,∈]代表将函数值剪裁到[-∈,∈]的范围内,
Figure BDA00035897688300001010
其中,∈是预设的常量。
步骤309:通过机器学习模型对本轮对抗噪声和本轮输入样本进行数据处理,得到本轮更新样本。
示例性的,将机器学习模型记为fθ,θ是机器学习模型的模型参数,本轮输入样本记为xk,本轮对抗噪声是δk,则本轮预测样本是fθ(xkk)。
步骤310:确定本轮输出样本和本轮更新样本之间的损失更新值。
可选地,计算损失值的损失函数与计算损失更新值的损失函数是同一个损失函数,或者,计算损失值的损失函数与计算损失更新值的损失函数是不同的损失函数。
示例性的,损失更新值为
L(xkk,ykk-1);
其中,xkk表示将本轮输入样本xk和本轮对抗噪声δk输入到机器学习模型后得到的输出,yk是本轮输出样本,θk-1是上一轮模型参数。
步骤311:根据损失更新值更新上一轮模型参数,得到本轮模型参数。
示例性的,本轮模型参数为
Figure BDA0003589768830000111
其中,η是机器学习模型的学习率,也是一个超参数。
步骤312:迭代上述十个步骤,直至满足训练完成条件,完成对机器学习模型的训练。
可选地,迭代上述十个步骤,直至损失值达到收敛,完成对机器学习模型的训练。
可选地,迭代上述十个步骤N次,完成对机器学习模型的训练,N为常数。N的取值可由技术人员根据实际需求进行设置。
示例性的,如表1所示,本申请实施例提供的模型训练方法可以提高机器学习模型的准确率。
表1通过相关技术和本申请实施例提供的方法得到的模型的准确率
Figure BDA0003589768830000112
在表1中,使用的样本数据集是ImageNet数据集(ImageNet数据集是一个大型的自然图像数据集,包含涵盖了动物、物品、交通工具、地点等等各式各样的自然图像),L表示对抗噪声的噪声幅度,ResNet 18是ResNet模型的一种,“18”表示ResNet模型中卷积层和全连接层的数量,ResNet 50也是ResNet模型的一种,“50”表示ResNet模型中卷积层和全连接层的数量。从表1的数据中可以明确地看出在不同模型类型和不同噪声幅度下,通过本申请实施例的方法得到的机器学习模型优于相关技术得到的机器学习模型。
综上所述,本实施例在训练机器学习模型时,会在输入样本中添加对抗噪声,而且该对抗噪声会随着训练的进程发生自适应调整,相较于相关技术,既不会出现过拟合问题,而且还可以在较大步长下完成机器学习模型的训练,训练效率较高,且由于添加了对抗噪声,提高了机器学习模型的鲁棒性。
而且,在存储对抗噪声时,会下采样对抗噪声,在需要使用对抗噪声时,会将下采样后的对抗噪声进行上采样还原。由于下采样后的对抗噪声的信息量较少,对抗噪声采用下采样的方法进行存储时,可以节约对抗噪声的内存,而且,不会影响机器学习模型训练的效果。
需要说明的是,本申请提供的模型训练方法不依赖于模型类型、损失函数类型或者样本数据集。在接下来的实施例中,以训练面部识别模型为例进行说明,面部识别模型应用在面部识别系统中,本申请实施例提供的模型训练方法可以在不增加额外耗时的情况下增加面部识别系统的鲁棒性,提升这一面部识别系统的安全性。
图5示出了本申请一个示例性实施例提供的面部识别模型的训练方法的流程示意图。该方法可由图1所示的计算机系统100执行,该方法包括:
步骤501:从样本面部图像集中提取出本轮输入图像和本轮输出标签。
样本面部图像集包括成对的输入图像和输出标签,输入图像是包括生物面部的图像。示例性的,输入图像是人脸图像,则人脸图像包括一张人脸,或者人脸图像包括多张人脸,或者人脸图像包括部分人脸。示例性的,输出标签用于表示输入图像中人脸的身份,在输入图像包括多张人脸的情况下,输出标签用于表示输入图像中面积最大的人脸的身份,或者,输出标签用于表示输入图像中靠近图像中心的人脸的身份。
输出标签用于表示输入图像中生物面部的身份。示例性的,样本面部图像集包括“输入图像1-标签A”,标签A用于表示输入图像1中的面部是属于用户A的。
可选地,从样本面部图像集中随机提取出本轮输入图像和本轮输出标签。或者,按照样本面部图像排列顺序,从样本面部图像集中提取出本轮输入图像和本轮输出标签。
步骤502:在第1轮迭代时,通过面部识别模型对本轮输入图像和初始化的对抗噪声进行数据处理,得到本轮预测标签。
可选地,在第1轮迭代时,初始化的对抗噪声δi=Uniform(-∈,∈),其中,Uniform代表均匀分布,∈为常量,∈可由技术人员根据实际需求进行设置。
步骤503:在第i轮迭代时,通过面部识别模型对本轮输入图像和上一轮对抗噪声进行数据处理,得到本轮预测标签。
其中,i为大于1的整数。
可选地,在第i轮迭代时,上一轮对抗噪声是在第i-1轮迭代的本轮对抗噪声。
可选地,通过上采样分辨率对存储的上一轮对抗噪声进行上采样,得到上一轮对抗噪声,以节约计算机设备的内存。
需要说明的是,步骤502和步骤503相互排斥,执行步骤502时,不执行步骤503;执行步骤503时,不执行步骤502。
步骤504:确定本轮输出标签和本轮预测标签之间的损失值。
示例性的,损失值是L(xkk-1,ykk-1),其中,xkk-1表示将本轮输入图像xk和上一轮对抗噪声δk-1输入到面部识别模型后得到的输出,yk是本轮输出标签,θk-1是上一轮模型参数,k表示本轮的迭代次数。
步骤505:确定损失值的梯度在对抗噪声维度上的噪声分量。
示例性的,噪声分量为
Figure BDA0003589768830000131
Figure BDA0003589768830000132
是梯度在对抗噪声维度上的分量。
步骤506:通过噪声分量更新上一轮动量参数,得到本轮动量参数。
本轮动量参数用于表示上一轮对抗噪声的衰减率。可选地,动量参数是momentum(动量)参数。
可选地,根据噪声分量的范数和上一轮动量参数的和,得到本轮动量参数。例如,根据噪声分量的二范数和上一轮动量参数的和,得到本轮动量参数。或者,噪声分量的无穷范数和上一轮动量参数的和,得到本轮动量参数。或者,噪声分量的一范数和上一轮动量参数的和,得到本轮动量参数。
步骤507:根据本轮动量参数更新上一轮对抗噪声,得到本轮对抗噪声。
可选地,基于下采样分辨率对本轮对抗噪声进行下采样,得到下采样后的本轮对抗噪声;存储下采样后的本轮对抗噪声,以节约计算机设备的内存。
可选地,本步骤包括以下子步骤:
1、根据本轮动量参数计算自适应步长。
2、通过噪声分量,计算自适应步长和上一轮对抗噪声的和,得到本轮对抗噪声。
步骤508:将本轮对抗噪声的噪声值剪裁在预设区间内。
示例性的,通过Clip[-∈,∈]将本轮对抗噪声的噪声值剪裁在预设区间内,Clip[-∈,∈]代表将函数值剪裁到[-∈,∈]的范围内,∈是预设的常量。
步骤509:通过面部识别模型对本轮对抗噪声和本轮输入图像进行数据处理,得到本轮更新图像。
示例性的,将面部识别模型记为fθ,θ是面部识别模型的模型参数,本轮输入图像记为xk,本轮对抗噪声是δk,则本轮预测标签是fθ(xkk)。
步骤510:确定本轮输出图像和本轮更新图像之间的损失更新值。
可选地,计算损失值的损失函数与计算损失更新值的损失函数是同一个损失函数,或者,计算损失值的损失函数与计算损失更新值的损失函数是不同的损失函数。
示例性的,损失更新值为
L(xkk,ykk-1);
其中,xkk表示将本轮输入图像xk和本轮对抗噪声δk输入到面部识别模型后得到的输出,yk是本轮输出标签,θk-1是上一轮模型参数。
步骤511:根据损失更新值更新上一轮模型参数,得到本轮模型参数。
示例性的,本轮模型参数为
Figure BDA0003589768830000141
其中,η是面部识别模型的学习率,也是一个超参数。
步骤512:迭代上述十个步骤,直至满足训练完成条件,完成对面部识别模型的训练。
可选地,迭代上述十个步骤,直至损失值达到收敛,完成对面部识别模型的训练。
可选地,迭代上述十个步骤N次,完成对面部识别模型的训练,N为常数。N的取值可由技术人员根据实际需求进行设置。
综上所述,本申请实施例在训练面部识别模型时,会在输入图像中添加对抗噪声,而且该对抗噪声会随着训练的进程发生自适应调整,相较于相关技术,可以在不增加额外耗时的情况下增加面部识别模型的鲁棒性,提高面部识别模型的准确度。
在接下来的实施例中,以训练自动驾驶模型为例进行说明,自动驾驶模型应用在自动驾驶系统中,自动驾驶系统对安全性的要求很高,相关技术所需的训练时长过长,导致相关技术无法应用到大规模的训练中,本申请实施例提供的模型训练方法可以大大加速训练过程,减少训练时长,使得大规模训练自动驾驶模型的开销减少。
图6示出了本申请一个示例性实施例提供的自动驾驶模型的训练方法的流程示意图。该方法可由图1所示的计算机系统100执行,该方法包括:
步骤601:从样本数据集中提取出本轮输入路况信息和本轮输出驾驶策略。
样本数据集包括成对的路况信息和驾驶策略。路况信息包括车辆的位置、车辆的速度、天气、信号灯、其它车辆的位置、其它车辆的速度、交通标志、道路规定中的至少一种。驾驶策略包括直行、左转、右转、后退、减速、加速、急停、开灯、关灯、打开雨刮器、关闭雨刮器中的至少一种。
示例性的,路况信息用于表示车辆A前方30米处有人行横道,则对应的驾驶策略是降低车辆A的速度。示例性的,路况信息用于表示车辆B与车辆C之间的距离小于60米,车辆B和车辆C的速度均大于60km/h,则对应的驾驶策略是降低车辆B和车辆C的速度。
可选地,从样本数据集中随机提取出本轮输入路况信息和本轮输出驾驶策略。或者,按照路况信息排列顺序,从样本数据集中提取出本轮输入路况信息和本轮输出驾驶策略。
步骤602:在第1轮迭代时,通过自动驾驶模型对本轮输入路况信息和初始化的对抗噪声进行数据处理,得到本轮预测驾驶策略。
可选地,在第1轮迭代时,初始化的对抗噪声δi=Uniform(-∈,∈),其中,Uniform代表均匀分布,∈为常量,∈可由技术人员根据实际需求进行设置。
步骤603:在第i轮迭代时,通过自动驾驶模型对本轮输入路况信息和上一轮对抗噪声进行数据处理,得到本轮预测驾驶策略,i为大于1的整数。
可选地,在第i轮迭代时,上一轮对抗噪声是第i-1轮迭代中的本轮迭代噪声。
可选地,通过上采样分辨率对存储的上一轮对抗噪声进行上采样,得到上一轮对抗噪声,以节约计算机设备的内存。
需要说明的是,步骤602和步骤603相互排斥,执行步骤602时,不执行步骤603;执行步骤603时,不执行步骤602。
步骤604:确定本轮输出驾驶策略和本轮预测驾驶策略之间的损失值。
示例性的,损失值是L(xkk-1,ykk-1),其中,xkk-1表示将本轮输入驾驶策略xk和上一轮对抗噪声δk-1输入到自动驾驶模型后得到的输出,yk是本轮输出驾驶策略,θk-1是上一轮模型参数,k表示本轮的迭代次数。
步骤605:确定损失值的梯度在对抗噪声维度上的噪声分量。
示例性的,噪声分量为
Figure BDA0003589768830000161
Figure BDA0003589768830000162
是梯度在对抗噪声维度上的分量。
步骤606:通过噪声分量更新上一轮动量参数,得到本轮动量参数。
本轮动量参数用于表示上一轮对抗噪声的衰减率。可选地,动量参数是momentum(动量)参数。
可选地,根据噪声分量的范数和上一轮动量参数的和,得到本轮动量参数。例如,根据噪声分量的二范数和上一轮动量参数的和,得到本轮动量参数。或者,噪声分量的无穷范数和上一轮动量参数的和,得到本轮动量参数。或者,噪声分量的一范数和上一轮动量参数的和,得到本轮动量参数。
步骤607:根据本轮动量参数更新上一轮对抗噪声,得到本轮对抗噪声。
可选地,基于下采样分辨率对本轮对抗噪声进行下采样,得到下采样后的本轮对抗噪声;存储下采样后的本轮对抗噪声,以节约计算机设备的内存。
可选地,本步骤包括以下子步骤:
1、根据本轮动量参数计算自适应步长。
2、通过噪声分量,计算自适应步长和上一轮对抗噪声的和,得到本轮对抗噪声。
步骤608:将本轮对抗噪声的噪声值剪裁在预设区间内。
示例性的,通过Clip[-∈,∈]将本轮对抗噪声的噪声值剪裁在预设区间内,Clip[-∈,∈]代表将函数值剪裁到[-∈,∈]的范围内,∈是预设的常量。
步骤609:通过自动驾驶模型对本轮对抗噪声和本轮输入路况信息进行数据处理,得到本轮更新驾驶策略。
示例性的,将自动驾驶模型记为fθ,θ是自动驾驶模型的模型参数,本轮输入路况信息记为xk,本轮对抗噪声是δk,则本轮预测驾驶策略是fθ(xkk)。
步骤610:确定本轮输出驾驶策略和本轮更新驾驶策略之间的损失更新值。
可选地,计算损失值的损失函数与计算损失更新值的损失函数是同一个损失函数,或者,计算损失值的损失函数与计算损失更新值的损失函数是不同的损失函数。
示例性的,损失更新值为
L(xkk,ykk-1);
其中,xkk表示将本轮输入路况信息xk和本轮对抗噪声δk输入到自动驾驶模型后得到的输出,yk是本轮输出驾驶策略,θk-1是上一轮模型参数。
步骤611:根据损失更新值更新上一轮模型参数,得到本轮模型参数。
示例性的,本轮模型参数为
Figure BDA0003589768830000171
其中,η是自动驾驶模型的学习率,也是一个超参数。
步骤612:迭代上述十个步骤,直至满足训练完成条件,完成对自动驾驶模型的训练。
可选地,迭代上述十个步骤,直至损失值达到收敛,完成对自动驾驶模型的训练。
可选地,迭代上述十个步骤N次,完成对自动驾驶模型的训练,N为常数。N的取值可由技术人员根据实际需求进行设置。
综上所述,本申请实施例在训练自动驾驶模型时,会在输入路况信息中添加对抗噪声,而且该对抗噪声会随着训练的进程发生自适应调整,相较于相关技术,可以在保证自动驾驶模型的准确度的情况下,加快自动驾驶模型的训练速度,减少了训练自动驾驶模型的耗时。
在接下来的实施例中,以训练物品推荐模型为例进行说明,物品推荐模型应用在物品推荐系统中,物品推荐模型需要对大量的样本进行训练才能保证具有较高的准确度,导致相关技术所需的训练时长过长,而本申请实施例提供的模型训练方法可以大大加速训练过程,减少训练时长。
图7示出了本申请一个示例性实施例提供的物品推荐模型的训练方法的流程示意图。该方法可由图1所示的计算机系统100执行,该方法包括:
步骤701:从样本数据集中提取出本轮输入帐号信息和本轮输出推荐物品。
样本数据集包括成对的输入帐号信息和输出推荐物品。
可选地,输入帐号信息包括帐号名称、帐号ID(Identity Document,身份标识号)、帐号识别码、帐号使用时长、帐号注册时间、用户年龄、用户籍贯、用户身份中的至少一种。
输出推荐物品是现实物品、虚拟物品、现实服务、虚拟服务中的至少一种。
示例性的,输入帐号信息包括用户年龄是70岁,与该输入帐号信息对应的输出推荐物品是医疗服务。或者,输入帐号信息包括用户身份是司机,与该输入帐号信息对应的输出推荐物品是车辆保险。
可选地,从样本数据集中随机提取出本轮输入帐号信息和本轮输出推荐物品。或者,按照样本数据集排列顺序,从样本数据集中提取出本轮输入帐号信息和本轮输出推荐物品。
步骤702:在第1轮迭代时,通过物品推荐模型对本轮输入帐号信息和初始化的对抗噪声进行数据处理,得到本轮预测推荐物品。
可选地,在第1轮迭代时,初始化的对抗噪声δi=Uniform(-∈,∈),其中,Uniform代表均匀分布,∈为常量,∈可由技术人员根据实际需求进行设置。
步骤703:在第i轮迭代时,通过物品推荐模型对本轮输入帐号信息和上一轮对抗噪声进行数据处理,得到本轮预测推荐物品。
其中,i为大于1的整数。
可选地,在第i轮迭代时,上一轮对抗噪声是在第i-1轮迭代中的本轮对抗噪声。
可选地,通过上采样分辨率对存储的上一轮对抗噪声进行上采样,得到上一轮对抗噪声,以节约计算机设备的内存。
需要说明的是,步骤702和步骤703相互排斥,执行步骤702时,不执行步骤703;执行步骤703时,不执行步骤702。
步骤704:确定本轮输出推荐物品和本轮预测推荐物品之间的损失值。
示例性的,损失值是L(xkk-1,ykk-1),其中,xkk-1表示将本轮输入帐号信息xk和上一轮对抗噪声δk-1输入到物品推荐模型后得到的输出,yk是本轮输出推荐物品,θk-1是上一轮模型参数,k表示本轮的迭代次数。
步骤705:确定损失值的梯度在对抗噪声维度上的噪声分量。
示例性的,噪声分量为
Figure BDA0003589768830000191
Figure BDA0003589768830000192
是梯度在对抗噪声维度上的分量。
步骤706:通过噪声分量更新上一轮动量参数,得到本轮动量参数。
本轮动量参数用于表示上一轮对抗噪声的衰减率。可选地,动量参数是momentum(动量)参数。
可选地,根据噪声分量的范数和上一轮动量参数的和,得到本轮动量参数。例如,根据噪声分量的二范数和上一轮动量参数的和,得到本轮动量参数。或者,噪声分量的无穷范数和上一轮动量参数的和,得到本轮动量参数。或者,噪声分量的一范数和上一轮动量参数的和,得到本轮动量参数。
步骤707:根据本轮动量参数更新上一轮对抗噪声,得到本轮对抗噪声。
可选地,基于下采样分辨率对本轮对抗噪声进行下采样,得到下采样后的本轮对抗噪声;存储下采样后的本轮对抗噪声,以节约计算机设备的内存。
可选地,本步骤包括以下子步骤:
1、根据本轮动量参数计算自适应步长。
2、通过噪声分量,计算自适应步长和上一轮对抗噪声的和,得到本轮对抗噪声。
步骤708:将本轮对抗噪声的噪声值剪裁在预设区间内。
示例性的,通过Clip[-∈,∈]将本轮对抗噪声的噪声值剪裁在预设区间内,Clip[-∈,∈]代表将函数值剪裁到[-∈,∈]的范围内,∈是预设的常量。
步骤709:通过物品推荐模型对本轮对抗噪声和本轮输入帐号信息进行数据处理,得到本轮更新推荐物品。
示例性的,将物品推荐模型记为fθ,θ是物品推荐模型的模型参数,本轮输入帐号信息记为xk,本轮对抗噪声是δk,则本轮预测推荐物品是fθ(xkk)。
步骤710:确定本轮输出推荐物品和本轮更新推荐物品之间的损失更新值。
可选地,计算损失值的损失函数与计算损失更新值的损失函数是同一个损失函数,或者,计算损失值的损失函数与计算损失更新值的损失函数是不同的损失函数。
示例性的,损失更新值为
L(xkk,ykk-1);
其中,xkk表示将本轮输入帐号信息xk和本轮对抗噪声δk输入到物品推荐模型后得到的输出,yk是本轮输出推荐物品,θk-1是上一轮模型参数。
步骤711:根据损失更新值更新上一轮模型参数,得到本轮模型参数。
示例性的,本轮模型参数为
Figure BDA0003589768830000201
其中,η是面部识别模型的学习率,也是一个超参数。
步骤712:迭代上述十个步骤,直至满足训练完成条件,完成对物品推荐模型的训练。
可选地,迭代上述十个步骤,直至损失值达到收敛,完成对物品推荐模型的训练。
可选地,迭代上述十个步骤N次,完成对物品推荐模型的训练,N为常数。N的取值可由技术人员根据实际需求进行设置。
综上所述,本申请实施例在训练物品推荐模型时,会在输入帐号信息中添加对抗噪声,而且该对抗噪声会随着训练的进程发生自适应调整,相较于相关技术,可以在不增加额外耗时的情况下增加物品推荐模型的鲁棒性,提高物品推荐模型的准确度。
下述为本申请装置实施例,可以用于执行本申请方法实施例。对于本申请装置实施例中未披露的细节,请参照本申请方法实施例。
请参考图8,其示出了本申请一个实施例提供的模型训练装置的框图。上述功能可以由硬件实现,也可以由硬件执行相应的软件实现。该装置800包括:
提取模块801,用于从样本数据集中提取出本轮输入样本和本轮输出样本;
训练模块802,用于通过所述机器学习模型对所述本轮输入样本和上一轮对抗噪声进行数据处理,得到本轮预测样本;
更新模块803,用于通过所述本轮输出样本和所述本轮预测样本之间的损失值更新所述上一轮对抗噪声,得到本轮对抗噪声;
所述更新模块803,还用于根据所述本轮对抗噪声、所述本轮输入样本和所述本轮输出样本更新所述机器学习模型的上一轮模型参数,得到本轮模型参数;
所述训练模块802,还用于迭代上述四个步骤,直至满足训练完成条件,完成对所述机器学习模型的训练;
其中,在第1轮迭代中,所述上一轮对抗噪声是初始化的对抗噪声;在第i轮迭代中,所述上一轮对抗噪声是第i-1轮迭代中的本轮对抗噪声,i为大于1的整数。
在本申请的一个可选设计中,所述更新模块803,还用于确定所述本轮输出样本和所述本轮预测样本之间的所述损失值;确定所述损失值的梯度在对抗噪声维度上的噪声分量;根据所述噪声分量更新所述上一轮对抗噪声,得到所述本轮对抗噪声。
在本申请的一个可选设计中,所述更新模块803,还用于通过所述噪声分量更新上一轮动量参数,得到本轮动量参数,所述本轮动量参数用于表示所述上一轮对抗噪声的衰减率;根据所述本轮动量参数更新所述上一轮对抗噪声,得到所述本轮对抗噪声。
在本申请的一个可选设计中,所述更新模块803,还用于根据所述本轮动量参数计算所述自适应步长,所述自适应步长与所述本轮动量参数呈反比;通过所述噪声分量,计算所述自适应步长和所述上一轮对抗噪声的和,得到所述本轮对抗噪声。
在本申请的一个可选设计中,所述更新模块803,还用于根据所述噪声分量的范数和所述上一轮动量参数的和,得到所述本轮动量参数。
在本申请的一个可选设计中,所述更新模块803,还用于将所述本轮对抗噪声的噪声值剪裁在预设区间内,所述预设区间是基于所述初始化的对抗噪声的噪声分布确定的。
在本申请的一个可选设计中,所述更新模块803,还用于通过所述机器学习模型对所述本轮对抗噪声和所述本轮输入样本进行数据处理,得到本轮更新样本;确定所述本轮输出样本和所述本轮更新样本之间的损失更新值;根据所述损失更新值更新所述上一轮模型参数,得到所述本轮模型参数。
在本申请的一个可选设计中,所述更新模块803,还用于确定所述损失更新值的梯度在模型参数维度上的模型参数分量;根据所述模型参数分量和所述上一轮模型参数的和,得到所述本轮模型参数。
在本申请的一个可选设计中,所述装置还包括采样模块804;
所述采样模块804,用于基于下采样分辨率对所述本轮对抗噪声进行下采样,得到下采样后的本轮对抗噪声;存储所述下采样后的本轮对抗噪声。
在本申请的一个可选设计中,所述采样模块804,还用于通过上采样分辨率对存储的所述上一轮对抗噪声进行上采样,得到所述上一轮对抗噪声。
在本申请的一个可选设计中,所述训练模块802,还用于迭代上述四个步骤,直至所述损失值达到收敛,完成对所述机器学习模型的训练;或,迭代上述四个步骤N次,完成对所述机器学习模型的训练,N为常数。
综上所述,本实施例在训练机器学习模型时,会在输入样本中添加对抗噪声,而且该对抗噪声会随着训练的进程发生自适应调整,相较于相关技术,既不会出现过拟合问题,而且还可以在较大步长下完成机器学习模型的训练,训练效率较高,且由于添加了对抗噪声,提高了机器学习模型的鲁棒性。
图9是根据一示例性实施例示出的一种计算机设备的结构示意图。所述计算机设备900包括中央处理单元(Central Processing Unit,CPU)901、包括随机存取存储器(Random Access Memory,RAM)902和只读存储器(Read-Only Memory,ROM)903的系统存储器904,以及连接系统存储器904和中央处理单元901的系统总线905。所述计算机设备900还包括帮助计算机设备内的各个器件之间传输信息的基本输入/输出系统(Input/Output,I/O系统)906,和用于存储操作系统913、应用程序914和其他程序模块915的大容量存储设备907。
所述基本输入/输出系统906包括有用于显示信息的显示器908和用于用户输入信息的诸如鼠标、键盘之类的输入设备909。其中所述显示器908和输入设备909都通过连接到系统总线905的输入输出控制器910连接到中央处理单元901。所述基本输入/输出系统906还可以包括输入输出控制器910以用于接收和处理来自键盘、鼠标、或电子触控笔等多个其他设备的输入。类似地,输入输出控制器910还提供输出到显示屏、打印机或其他类型的输出设备。
所述大容量存储设备907通过连接到系统总线905的大容量存储控制器(未示出)连接到中央处理单元901。所述大容量存储设备907及其相关联的计算机设备可读介质为计算机设备900提供非易失性存储。也就是说,所述大容量存储设备907可以包括诸如硬盘或者只读光盘(Compact Disc Read-Only Memory,CD-ROM)驱动器之类的计算机设备可读介质(未示出)。
不失一般性,所述计算机设备可读介质可以包括计算机设备存储介质和通信介质。计算机设备存储介质包括以用于存储诸如计算机设备可读指令、数据结构、程序模块或其他数据等信息的任何方法或技术实现的易失性和非易失性、可移动和不可移动介质。计算机设备存储介质包括RAM、ROM、可擦除可编程只读存储器(Erasable Programmable ReadOnly Memory,EPROM)、带电可擦可编程只读存储器(Electrically ErasableProgrammable Read-Only Memory,EEPROM),CD-ROM、数字视频光盘(Digital Video Disc,DVD)或其他光学存储、磁带盒、磁带、磁盘存储或其他磁性存储设备。当然,本领域技术人员可知所述计算机设备存储介质不局限于上述几种。上述的系统存储器904和大容量存储设备907可以统称为存储器。
根据本申请的各种实施例,所述计算机设备900还可以通过诸如因特网等网络连接到网络上的远程计算机设备运行。也即计算机设备900可以通过连接在所述系统总线905上的网络接口单元912连接到网络911,或者说,也可以使用网络接口单元912来连接到其他类型的网络或远程计算机设备系统(未示出)。
所述存储器还包括一个或者一个以上的程序,所述一个或者一个以上程序存储于存储器中,中央处理器901通过执行该一个或一个以上程序来实现上述模型训练方法的全部或者部分步骤。
在示例性实施例中,还提供了一种计算机可读存储介质,所述计算机可读存储介质中存储有至少一条指令、至少一段程序、代码集或指令集,所述至少一条指令、所述至少一段程序、所述代码集或指令集由处理器加载并执行以实现上述各个方法实施例提供的模型训练方法。
本申请还提供一种计算机可读存储介质,所述存储介质中存储有至少一条指令、至少一段程序、代码集或指令集,所述至少一条指令、所述至少一段程序、所述代码集或指令集由所述处理器加载并执行以实现上述方法实施例提供的模型训练方法。
本申请还提供一种计算机程序产品或计算机程序,上述计算机程序产品或计算机程序包括计算机指令,上述计算机指令存储在计算机可读存储介质中。计算机设备的处理器从上述计算机可读存储介质读取上述计算机指令,上述处理器执行上述计算机指令,使得上述计算机设备执行如上方面实施例提供的模型训练方法。
上述本申请实施例序号仅仅为了描述,不代表实施例的优劣。
本领域普通技术人员可以理解实现上述实施例的全部或部分步骤可以通过硬件来完成,也可以通过程序来指令相关的硬件完成,所述的程序可以存储于一种计算机可读存储介质中,上述提到的存储介质可以是只读存储器,磁盘或光盘等。
以上所述仅为本申请的可选实施例,并不用以限制本申请,凡在本申请的精神和原则之内,所作的任何修改、等同替换、改进等,均应包含在本申请的保护范围之内。

Claims (15)

1.一种模型训练方法,其特征在于,所述方法包括:
从样本数据集中提取出本轮输入样本和本轮输出样本;
通过机器学习模型对所述本轮输入样本和上一轮对抗噪声进行数据处理,得到本轮预测样本;
通过所述本轮输出样本和所述本轮预测样本之间的损失值更新所述上一轮对抗噪声,得到本轮对抗噪声;
根据所述本轮对抗噪声、所述本轮输入样本和所述本轮输出样本更新所述机器学习模型的上一轮模型参数,得到本轮模型参数;
迭代上述四个步骤,直至满足训练完成条件,完成对所述机器学习模型的训练;
其中,在第1轮迭代中,所述上一轮对抗噪声是初始化的对抗噪声;在第i轮迭代中,所述上一轮对抗噪声是第i-1轮迭代中的本轮对抗噪声,i为大于1的整数。
2.根据权利要求1所述的方法,其特征在于,所述通过所述本轮输出样本和所述本轮预测样本之间的损失值更新所述上一轮对抗噪声,得到本轮对抗噪声,包括:
确定所述本轮输出样本和所述本轮预测样本之间的所述损失值;
确定所述损失值的梯度在对抗噪声维度上的噪声分量;
根据所述噪声分量更新所述上一轮对抗噪声,得到所述本轮对抗噪声。
3.根据权利要求2所述的方法,其特征在于,所述根据所述噪声分量更新所述上一轮对抗噪声,得到所述本轮对抗噪声,包括:
通过所述噪声分量更新上一轮动量参数,得到本轮动量参数,所述本轮动量参数用于表示所述上一轮对抗噪声的衰减率;
根据所述本轮动量参数更新所述上一轮对抗噪声,得到所述本轮对抗噪声。
4.根据权利要求3所述的方法,其特征在于,所述根据所述本轮动量参数更新所述上一轮对抗噪声,得到所述本轮对抗噪声,包括:
根据所述本轮动量参数计算所述自适应步长,所述自适应步长与所述本轮动量参数呈反比;
通过所述噪声分量,计算所述自适应步长和所述上一轮对抗噪声的和,得到所述本轮对抗噪声。
5.根据权利要求3所述的方法,其特征在于,所述通过所述噪声分量更新上一轮动量参数,得到本轮动量参数,包括:
根据所述噪声分量的范数和所述上一轮动量参数的和,得到所述本轮动量参数。
6.根据权利要求2所述的方法,其特征在于,所述方法还包括:
将所述本轮对抗噪声的噪声值剪裁在预设区间内,所述预设区间是基于所述初始化的对抗噪声的噪声分布确定的。
7.根据权利要求1至6任一项所述的方法,其特征在于,所述根据所述本轮对抗噪声、所述本轮输入样本和所述本轮输出样本更新所述机器学习模型的上一轮模型参数,得到本轮模型参数,包括:
通过所述机器学习模型对所述本轮对抗噪声和所述本轮输入样本进行数据处理,得到本轮更新样本;
确定所述本轮输出样本和所述本轮更新样本之间的损失更新值;
根据所述损失更新值更新所述上一轮模型参数,得到所述本轮模型参数。
8.根据权利要求7所述的方法,其特征在于,所述根据所述损失更新值更新所述上一轮模型参数,得到所述本轮模型参数,包括:
确定所述损失更新值的梯度在模型参数维度上的模型参数分量;
根据所述模型参数分量和所述上一轮模型参数的和,得到所述本轮模型参数。
9.根据权利要求1至6任一项所述的方法,其特征在于,所述方法还包括:
基于下采样分辨率对所述本轮对抗噪声进行下采样,得到下采样后的本轮对抗噪声;
存储所述下采样后的本轮对抗噪声。
10.根据权利要求1至6任一项所述的方法,其特征在于,所述方法还包括:
通过上采样分辨率对存储的所述上一轮对抗噪声进行上采样,得到所述上一轮对抗噪声。
11.根据权利要求1至6任一项所述的方法,其特征在于,所述迭代上述四个步骤,直至满足训练完成条件,完成对所述机器学习模型的训练,包括:
迭代上述四个步骤,直至所述损失值达到收敛,完成对所述机器学习模型的训练;
或,迭代上述四个步骤N次,完成对所述机器学习模型的训练,N为常数。
12.一种模型训练装置,其特征在于,所述装置包括:
提取模块,用于从样本数据集中提取出本轮输入样本和本轮输出样本;
训练模块,用于通过机器学习模型对所述本轮输入样本和上一轮对抗噪声进行数据处理,得到本轮预测样本;
更新模块,用于通过所述本轮输出样本和所述本轮预测样本之间的损失值更新所述上一轮对抗噪声,得到本轮对抗噪声;
所述更新模块,还用于根据所述本轮对抗噪声、所述本轮输入样本和所述本轮输出样本更新所述机器学习模型的上一轮模型参数,得到本轮模型参数;
所述训练模块,还用于迭代上述四个步骤,直至满足训练完成条件,完成对所述机器学习模型的训练;
其中,其中,在第1轮迭代中,所述上一轮对抗噪声是初始化的对抗噪声;在第i轮迭代中,所述上一轮对抗噪声是第i-1轮迭代中的本轮对抗噪声,i为大于1的整数。
13.一种计算机设备,其特征在于,所述计算机设备包括:处理器和存储器,所述存储器中存储有至少一条指令、至少一段程序、代码集或指令集,所述至少一条指令、所述至少一段程序、所述代码集或指令集由所述处理器加载并执行以实现如权利要求1至11中任一项所述的模型训练方法。
14.一种计算机可读存储介质,其特征在于,所述计算机可读存储介质中存储有至少一条程序代码,所述程序代码由处理器加载并执行以实现如权利要求1至11中任一项所述的模型训练方法。
15.一种计算机程序产品,包括计算机程序或指令,其特征在于,所述计算机程序或指令被处理器执行时实现权利要求1至11中任一项所述的模型训练方法。
CN202210374553.8A 2022-04-11 2022-04-11 模型训练方法、装置、设备及介质 Pending CN115115058A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202210374553.8A CN115115058A (zh) 2022-04-11 2022-04-11 模型训练方法、装置、设备及介质

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202210374553.8A CN115115058A (zh) 2022-04-11 2022-04-11 模型训练方法、装置、设备及介质

Publications (1)

Publication Number Publication Date
CN115115058A true CN115115058A (zh) 2022-09-27

Family

ID=83324773

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202210374553.8A Pending CN115115058A (zh) 2022-04-11 2022-04-11 模型训练方法、装置、设备及介质

Country Status (1)

Country Link
CN (1) CN115115058A (zh)

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115530842A (zh) * 2022-11-30 2022-12-30 合肥心之声健康科技有限公司 一种增强分类心电信号的神经网络模型鲁棒性的方法

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115530842A (zh) * 2022-11-30 2022-12-30 合肥心之声健康科技有限公司 一种增强分类心电信号的神经网络模型鲁棒性的方法

Similar Documents

Publication Publication Date Title
CN110414432B (zh) 对象识别模型的训练方法、对象识别方法及相应的装置
CN111898635A (zh) 神经网络的训练方法、数据获取方法和装置
CN111325664B (zh) 风格迁移方法、装置、存储介质及电子设备
CN112949647B (zh) 三维场景描述方法、装置、电子设备和存储介质
CN113947764B (zh) 一种图像处理方法、装置、设备及存储介质
CN114418030B (zh) 图像分类方法、图像分类模型的训练方法及装置
CN114241459B (zh) 一种驾驶员身份验证方法、装置、计算机设备及存储介质
DK201770681A1 (en) A method for (re-) training a machine learning component
CN110210493A (zh) 基于非经典感受野调制神经网络的轮廓检测方法及系统
JP2023506169A (ja) 視覚入力に対する形式的安全シンボリック強化学習
CN111325766A (zh) 三维边缘检测方法、装置、存储介质和计算机设备
CN114612902A (zh) 图像语义分割方法、装置、设备、存储介质及程序产品
CN116958323A (zh) 图像生成方法、装置、电子设备、存储介质及程序产品
CN114742224A (zh) 行人重识别方法、装置、计算机设备及存储介质
CN116353623A (zh) 一种基于自监督模仿学习的驾驶控制方法
CN115115058A (zh) 模型训练方法、装置、设备及介质
CN117058723B (zh) 掌纹识别方法、装置及存储介质
CN117079276B (zh) 一种基于知识蒸馏的语义分割方法、系统、设备及介质
CN116958712A (zh) 基于先验概率分布的图像生成方法、系统、介质及设备
CN113537267A (zh) 对抗样本的生成方法和装置、存储介质及电子设备
Fatkhulin et al. Analysis of the Basic Image Generation Methods by Neural Networks
CN112990123B (zh) 图像处理方法、装置、计算机设备和介质
CN111461091B (zh) 万能指纹生成方法和装置、存储介质及电子装置
CN114373098A (zh) 一种图像分类方法、装置、计算机设备及存储介质
CN114580715A (zh) 一种基于生成对抗网络和长短期记忆模型的行人轨迹预测方法

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination