CN112488183A - 一种模型优化方法、装置、计算机设备及存储介质 - Google Patents
一种模型优化方法、装置、计算机设备及存储介质 Download PDFInfo
- Publication number
- CN112488183A CN112488183A CN202011359384.8A CN202011359384A CN112488183A CN 112488183 A CN112488183 A CN 112488183A CN 202011359384 A CN202011359384 A CN 202011359384A CN 112488183 A CN112488183 A CN 112488183A
- Authority
- CN
- China
- Prior art keywords
- user
- text
- gradient
- decision parameter
- representing
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Granted
Links
- 238000005457 optimization Methods 0.000 title claims abstract description 87
- 238000000034 method Methods 0.000 title claims abstract description 65
- 238000012549 training Methods 0.000 claims abstract description 88
- 239000011159 matrix material Substances 0.000 claims description 64
- 238000005070 sampling Methods 0.000 claims description 47
- 230000006870 function Effects 0.000 claims description 34
- 238000000354 decomposition reaction Methods 0.000 claims description 22
- 238000004364 calculation method Methods 0.000 claims description 8
- 230000002159 abnormal effect Effects 0.000 claims description 5
- 238000010276 construction Methods 0.000 claims description 3
- 238000013473 artificial intelligence Methods 0.000 abstract description 2
- 238000012790 confirmation Methods 0.000 description 5
- 238000010586 diagram Methods 0.000 description 5
- 230000003287 optical effect Effects 0.000 description 3
- 230000005856 abnormality Effects 0.000 description 2
- 239000003086 colorant Substances 0.000 description 2
- 238000012545 processing Methods 0.000 description 2
- 238000003672 processing method Methods 0.000 description 2
- 230000009897 systematic effect Effects 0.000 description 2
- 238000012360 testing method Methods 0.000 description 2
- 230000009286 beneficial effect Effects 0.000 description 1
- 238000004891 communication Methods 0.000 description 1
- 230000003247 decreasing effect Effects 0.000 description 1
- 238000013135 deep learning Methods 0.000 description 1
- 238000013136 deep learning model Methods 0.000 description 1
- 230000010365 information processing Effects 0.000 description 1
- 230000003993 interaction Effects 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 238000011160 research Methods 0.000 description 1
- 230000003068 static effect Effects 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F16/00—Information retrieval; Database structures therefor; File system structures therefor
- G06F16/20—Information retrieval; Database structures therefor; File system structures therefor of structured data, e.g. relational data
- G06F16/23—Updating
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02T—CLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
- Y02T10/00—Road transport of goods or passengers
- Y02T10/10—Internal combustion engine [ICE] based vehicles
- Y02T10/40—Engine management systems
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Data Mining & Analysis (AREA)
- Physics & Mathematics (AREA)
- General Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Evolutionary Biology (AREA)
- Evolutionary Computation (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Bioinformatics & Computational Biology (AREA)
- Artificial Intelligence (AREA)
- Life Sciences & Earth Sciences (AREA)
- Databases & Information Systems (AREA)
- Machine Translation (AREA)
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
Abstract
本申请实施例属于人工智能的模型优化技术领域,涉及一种应用于动量梯度下降的模型优化方法、装置、计算机设备及存储介质。本申请提供的应用于动量梯度下降的模型优化方法,由于带动量的随机梯度下降在训练过程中,当前轮次的训练数据没有被采样到,而该轮次梯度更新仍然会使用历史动量来更新,这可能导致Embedding层过拟合,本申请在更新梯度之前,通过确认梯度数据是否已经更新,从而确认该轮次的训练数据确定被采样,才进行该梯度更新操作,从而有效避免在训练时当前batch中没被采样到的词,依然会使用历史动量来更新导致Embedding层过拟合的问题。
Description
技术领域
本申请涉及人工智能的模型优化,尤其涉及一种应用于动量梯度下降的模型优化方法、装置、计算机设备及存储介质。
背景技术
最优化问题是计算数学中最为重要的研究方向之一。在深度学习领域,优化算法同样是关键环节之一。即使完全相同的数据集与模型架构,不同的优化算法也很可能导致不同的训练结果,甚至有的模型出现不收敛现象。
现有一种模型优化方法,在深度学习的模型训练过程中,采用指数加权移动平均的方式,基于积攒了历史梯度的动量对该模型进行训练,以提高该模型的准确率。
然而,申请人发现传统的模型优化方法普遍不智能,在模型优化的过程中Embedding层会出现过拟合的问题。
发明内容
本申请实施例的目的在于提出一种应用于动量梯度下降的模型优化方法、装置、计算机设备及存储介质,以解决传统的模型优化方法在模型优化的过程中Embedding层会出现过拟合的问题。
为了解决上述技术问题,本申请实施例提供一种应用于动量梯度下降的模型优化方法,采用了如下所述的技术方案:
接收用户终端发送的模型优化请求,所述模型优化请求至少携带有原始预测模型以及原始训练数据集;
在所述原始训练数据集中进行采样操作,得到本轮训练数据集;
基于所述本轮训练数据集定义目标函数;
初始化所述原始预测模型的模型优化参数,得到初始速度参数以及初始决策参数;
计算本轮需要更新所述初始决策参数对应的梯度数据;
判断所述梯度数据是否已更新;
若所述梯度数据未更新,则输出采样异常信号;
若所述梯度数据已更新,则基于所述梯度数据更新所述初始速度参数,得到更新速度;
基于所述更新速度更新所述初始决策参数,得到更新决策参数;
当所述初始决策参数以及所述更新决策参数满足收敛条件时,得到目标预测模型。
为了解决上述技术问题,本申请实施例还提供一种应用于动量梯度下降的模型优化装置,采用了如下所述的技术方案:
请求接收模块,用于接收用户终端发送的模型优化请求,所述模型优化请求至少携带有原始预测模型以及原始训练数据集;
采样操作模块,用于在所述原始训练数据集中进行采样操作,得到本轮训练数据集;
函数定义模块,用于基于所述本轮训练数据集定义目标函数;
初始化模块,用于初始化所述原始预测模型的模型优化参数,得到初始速度参数以及初始决策参数;
梯度计算模块,用于计算本轮需要更新所述初始决策参数对应的梯度数据;
梯度判断模块,用于判断所述梯度数据是否已更新;
异常确认模块,用于若所述梯度数据未更新,则输出采样异常信号;
速度参数更新模块,用于若所述梯度数据已更新,则基于所述梯度数据更新所述初始速度参数,得到更新速度;
决策参数更新模块,用于基于所述更新速度更新所述初始决策参数,得到更新决策参数;
目标模型获取模块,用于当所述初始决策参数以及所述更新决策参数满足收敛条件时,得到目标预测模型。
为了解决上述技术问题,本申请实施例还提供一种计算机设备,采用了如下所述的技术方案:
包括存储器和处理器,所述存储器中存储有计算机可读指令,所述处理器执行所述计算机可读指令时实现如上所述的应用于动量梯度下降的模型优化方法的步骤。
为了解决上述技术问题,本申请实施例还提供一种计算机可读存储介质,采用了如下所述的技术方案:
所述计算机可读存储介质上存储有计算机可读指令,所述计算机可读指令被处理器执行时实现如上所述的应用于动量梯度下降的模型优化方法的步骤。
与现有技术相比,本申请实施例提供的应用于动量梯度下降的模型优化方法、装置、计算机设备及存储介质主要有以下有益效果:
本申请提供了一种应用于动量梯度下降的模型优化方法,接收用户终端发送的模型优化请求,所述模型优化请求至少携带有原始预测模型以及原始训练数据集;在所述原始训练数据集中进行采样操作,得到本轮训练数据集;基于所述本轮训练数据集定义目标函数;初始化模型优化算法参数,得到初始速度参数以及初始决策参数;计算本轮需要更新所述初始决策参数对应的梯度数据;判断所述梯度数据是否已更新;若所述梯度数据未更新,则输出采样异常信号;若所述梯度数据已更新,则基于所述梯度数据更新所述初始速度参数,得到更新速度;基于所述更新速度更新所述初始决策参数,得到更新决策参数;当所述初始决策参数以及所述更新决策参数满足收敛条件时,得到目标预测模型。由于带动量的随机梯度下降在训练过程中,当前轮次的训练数据没有被采样到,而该轮次梯度更新仍然会使用历史动量来更新,这可能导致Embedding层过拟合,本申请在更新梯度之前,通过确认梯度数据是否已经更新,从而确认该轮次的训练数据确定被采样,才进行该梯度更新操作,从而有效避免在训练时当前batch中没被采样到的词,依然会使用历史动量来更新导致Embedding层过拟合的问题。
附图说明
为了更清楚地说明本申请中的方案,下面将对本申请实施例描述中所需要使用的附图作一个简单介绍,显而易见地,下面描述中的附图是本申请的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1是本申请实施例一提供的应用于动量梯度下降的模型优化方法的实现流程图;
图2是图1中步骤S103的实现流程图;
图3是图1中步骤S110的实现流程图;
图4是本申请实施例二提供的应用于动量梯度下降的模型优化装置的结构示意图;
图5是图4中函数定义模块103的结构示意图;
图6是根据本申请的计算机设备的一个实施例的结构示意图。
具体实施方式
除非另有定义,本文所使用的所有的技术和科学术语与属于本申请的技术领域的技术人员通常理解的含义相同;本文中在申请的说明书中所使用的术语只是为了描述具体的实施例的目的,不是旨在于限制本申请;本申请的说明书和权利要求书及上述附图说明中的术语“包括”和“具有”以及它们的任何变形,意图在于覆盖不排他的包含。本申请的说明书和权利要求书或上述附图中的术语“第一”、“第二”等是用于区别不同对象,而不是用于描述特定顺序。
在本文中提及“实施例”意味着,结合实施例描述的特定特征、结构或特性可以包含在本申请的至少一个实施例中。在说明书中的各个位置出现该短语并不一定均是指相同的实施例,也不是与其它实施例互斥的独立的或备选的实施例。本领域技术人员显式地和隐式地理解的是,本文所描述的实施例可以与其它实施例相结合。
为了使本技术领域的人员更好地理解本申请方案,下面将结合附图,对本申请实施例中的技术方案进行清楚、完整地描述。
实施例一
如图1所示,示出了根据本申请实施例一提供的应用于动量梯度下降的模型优化方法的实现流程图,为了便于说明,仅示出与本申请相关的部分。
在步骤S101中,接收用户终端发送的模型优化请求,模型优化请求至少携带有原始预测模型以及原始训练数据集。
在本申请实施例中,用户终端指的是用于执行本申请提供的预防证件滥用的图像处理方法的终端设备,该当前终端可以是诸如移动电话、智能电话、笔记本电脑、数字广播接收器、PDA(个人数字助理)、PAD(平板电脑)、PMP(便携式多媒体播放器)、导航装置等等的移动终端以及诸如数字TV、台式计算机等等的固定终端,应当理解,此处对用户终端的举例仅为方便理解,不用于限定本申请。
在本申请实施例中,原始预测模型未进行梯度下降优化的预测模型。
在步骤S102中,在原始训练数据集中进行采样操作,得到本轮训练数据集。
在本申请实施例中,采样操作是指从总体训练数据中抽取个体或样品的过程,也即对总体训练数据进行试验或观测的过程。分随机抽样和非随机抽样两种类型。前者指遵照随机化原则从总体中抽取样本的抽样方法,它不带任何主观性,包括简单随机抽样、系统抽样、整群抽样和分层抽样。后者是一种凭研究者的观点、经验或者有关知识来抽取样本的方法,带有明显主观色彩。
在本申请实施例中,本轮训练数据集指的是经过上述采样操作后筛选出的数据量较小的训练数据集,以减少模型的训练时间。
在步骤S103中,基于本轮训练数据集定义目标函数。
在本申请实施例中,可基于用户文本的数据集生成用户-文本矩阵R,基于奇异值分解法对用户-文本矩阵R进行分解操作,得到用户-隐特征矩阵P以及隐特征-文本矩阵Q,基于用户-文本矩阵R构造目标函数目标函数表示为:
其中,R(Λ)表示用户-文本矩阵R用户对文本的评分数据集合;pm`表示用户-隐特征矩阵P中第m个用户对应的隐特征;qn`表示隐特征-文本矩阵Q中第n个文本对应的隐特征;rm,n表示用户m对文本n的评分数据;表示评分数据集合中用户m对文本n的评分数据;λ2表示隐特征矩阵的正则化因子。
在步骤S104中,初始化原始预测模型的模型优化参数,得到初始速度参数以及初始决策参数。
在本申请实施例中,初始化就是把变量赋为默认值,把控件设为默认状态,具体的,包括初始化学习率∈、动量参数a、初始决策参数θ和初始速度v。
在步骤S105中,计算本轮需要更新初始决策参数对应的梯度数据。
在本申请实施例中,梯度数据表示为:
在步骤S106中,判断梯度数据是否已更新。
在本申请实施例中,当一个训练数据被采样过后,它的Embedding的梯度不为0,基于该采样的特征,通过判断梯度数据是否已更新,即可获知该训练数据是否被采样过。
在步骤S107中,若梯度数据未更新,则输出采样异常信号。
在本申请实施例中,若梯度数据未更新,则说明该训练数据没有被采样过便进行后续的更新操作,没有被反复采样的训练数据,对应的Embedding层基于历史动量也会被被反复训练更新,导致了过拟合情况发生。
在步骤S108中,若梯度数据已更新,则基于梯度数据更新初始速度参数,得到更新速度。
在本申请实施例中,更新速度表示为:
vnew=αvold-∈g
其中,vnew表示更新速度;vold表示初始速度参数;α表示动量参数;∈表示学习率;g表示梯度数据。
在步骤S109中,基于更新速度更新初始决策参数,得到更新决策参数。
在本申请实施例中,更新决策参数表示为:
θnew=θold+vnew
其中,θnew表示更新决策参数;θold表示初始决策参数;vnew表示更新速度。
在步骤S110中,当初始决策参数以及更新决策参数满足收敛条件时,得到目标预测模型。
本申请实施例一提供的应用于动量梯度下降的模型优化方法,接收用户终端发送的模型优化请求,模型优化请求至少携带有原始预测模型以及原始训练数据集;在原始训练数据集中进行采样操作,得到本轮训练数据集;基于本轮训练数据集定义目标函数;初始化模型优化算法参数,得到初始速度参数以及初始决策参数;计算本轮需要更新初始决策参数对应的梯度数据;判断梯度数据是否已更新;若梯度数据未更新,则输出采样异常信号;若梯度数据已更新,则基于梯度数据更新初始速度参数,得到更新速度;基于更新速度更新初始决策参数,得到更新决策参数;当初始决策参数以及更新决策参数满足收敛条件时,得到目标预测模型。由于带动量的随机梯度下降在训练过程中,当前轮次的训练数据没有被采样到,而该轮次梯度更新仍然会使用历史动量来更新,这可能导致Embedding层过拟合,本申请在更新梯度之前,通过确认梯度数据是否已经更新,从而确认该轮次的训练数据确定被采样,才进行该梯度更新操作,从而有效避免在训练时当前batch中没被采样到的词,依然会使用历史动量来更新导致Embedding层过拟合的问题。
继续参阅图2,示出了图1中步骤S103的实现流程图,为了便于说明,仅示出与本申请相关的部分。
在本申请实施例一的一些可选的实现方式中,上述步骤S103具体包括:步骤S201、步骤S202以及步骤S203。
在步骤S201中,基于用户文本的数据集生成用户-文本矩阵R。
在步骤S202中,基于奇异值分解法对用户-文本矩阵R进行分解操作,得到用户-隐特征矩阵P以及隐特征-文本矩阵Q。
在本申请实施例中,奇异值分解(Singular Value Decomposition)是线性代数中一种重要的矩阵分解,奇异值分解则是特征分解在任意矩阵上的推广。
在步骤S203中,基于用户-文本矩阵R构造目标函数。
其中,R(Λ)表示用户-文本矩阵R用户对文本的评分数据集合;Pm`表示用户-隐特征矩阵P中第m个用户对应的隐特征;qn`表示隐特征-文本矩阵Q中第n个文本对应的隐特征;rm,n表示用户m对文本n的评分数据;表示评分数据集合中用户m对文本n的评分数据;λ2表示隐特征矩阵的正则化因子。
继续参阅图3,示出了图1中步骤S110的实现流程图,为了便于说明,仅示出与本申请相关的部分。
在本申请实施例一的一些可选的实现方式中,上述步骤S110具体包括:步骤S301、步骤S302、步骤S303以及步骤S304。
在步骤S301中,计算初始决策参数以及更新决策参数的决策参数差值。
在本申请实施例中,决策参数差值主要用于判断当前模型参数与上轮模型参数的变化量,当该变化量小于一定数值时,则认为决策参数趋向于某个稳定的数值,以使得该预测模型达到稳定。
在步骤S302中,判断决策参数差值是否小于预设收敛阈值。
在本申请实施例中,用户可以根据实际情况调整预设收敛阈值。
在步骤S303中,若决策参数差值小于或等于预设收敛阈值,则确定当前的预测模型收敛,并将当前的预测模型作为目标预测模型。
在本申请实施例中,当决策参数差值小于或等于预设收敛阈值,则说明决策参数趋向于某个稳定的数值,该预测模型达到稳定。
在步骤S304中,若决策参数差值大于预设收敛阈值,则则确定当前的预测模型未收敛,继续执行参数优化操作。
在本申请实施例中,当决策参数差值大于预设收敛阈值,则说明决策参数未达到某个稳定的数值,该预测模型的参数仍然需要进行优化。
在本申请实施例一的一些可选的实现方式中,梯度数据表示为:
在本申请实施例一的一些可选的实现方式中,更新速度表示为:
vnew=αvold-∈g
其中,vnew表示更新速度;vold表示初始速度参数;α表示动量参数;∈表示学习率;g表示梯度数据。
在本申请实施例一的一些可选的实现方式中,更新决策参数表示为:
θnew=θold+vnew
其中,θnew表示更新决策参数;θold表示初始决策参数;vnew表示更新速度。
综上,本申请实施例一提供的应用于动量梯度下降的模型优化方法,接收用户终端发送的模型优化请求,模型优化请求至少携带有原始预测模型以及原始训练数据集;在原始训练数据集中进行采样操作,得到本轮训练数据集;基于本轮训练数据集定义目标函数;初始化模型优化算法参数,得到初始速度参数以及初始决策参数;计算本轮需要更新初始决策参数对应的梯度数据;判断梯度数据是否已更新;若梯度数据未更新,则输出采样异常信号;若梯度数据已更新,则基于梯度数据更新初始速度参数,得到更新速度;基于更新速度更新初始决策参数,得到更新决策参数;当初始决策参数以及更新决策参数满足收敛条件时,得到目标预测模型。由于带动量的随机梯度下降在训练过程中,当前轮次的训练数据没有被采样到,而该轮次梯度更新仍然会使用历史动量来更新,这可能导致Embedding层过拟合,本申请在更新梯度之前,通过确认梯度数据是否已经更新,从而确认该轮次的训练数据确定被采样,才进行该梯度更新操作,从而有效避免在训练时当前batch中没被采样到的词,依然会使用历史动量来更新导致Embedding层过拟合的问题。
本领域普通技术人员可以理解实现上述实施例方法中的全部或部分流程,是可以通过计算机可读指令来指令相关的硬件来完成,该计算机可读指令可存储于一计算机可读取存储介质中,该计算机可读指令在执行时,可包括如上述各方法的实施例的流程。其中,前述的存储介质可为磁碟、光盘、只读存储记忆体(Read-Only Memory,ROM)等非易失性存储介质,或随机存储记忆体(Random Access Memory,RAM)等。
应该理解的是,虽然附图的流程图中的各个步骤按照箭头的指示依次显示,但是这些步骤并不是必然按照箭头指示的顺序依次执行。除非本文中有明确的说明,这些步骤的执行并没有严格的顺序限制,其可以以其他的顺序执行。而且,附图的流程图中的至少一部分步骤可以包括多个子步骤或者多个阶段,这些子步骤或者阶段并不必然是在同一时刻执行完成,而是可以在不同的时刻执行,其执行顺序也不必然是依次进行,而是可以与其他步骤或者其他步骤的子步骤或者阶段的至少一部分轮流或者交替地执行。
实施例二
进一步参考图4,作为对上述图1所示方法的实现,本申请提供了一种应用于动量梯度下降的模型优化装置的一个实施例,该装置实施例与图1所示的方法实施例相对应,该装置具体可以应用于各种电子设备中。
如图4所示,本实施例的应用于动量梯度下降的模型优化装置100包括:请求接收模块101、采样操作模块102、函数定义模块103、初始化模块104、梯度计算模块105、梯度判断模块106、异常确认模块107、速度参数更新模块108、决策参数更新模块109以及目标模型获取模块110。其中:
请求接收模块101,用于接收用户终端发送的模型优化请求,模型优化请求至少携带有原始预测模型以及原始训练数据集;
采样操作模块102,用于在原始训练数据集中进行采样操作,得到本轮训练数据集;
函数定义模块103,用于基于本轮训练数据集定义目标函数;
初始化模块104,用于初始化原始预测模型的模型优化参数,得到初始速度参数以及初始决策参数;
梯度计算模块105,用于计算本轮需要更新初始决策参数对应的梯度数据;
梯度判断模块106,用于判断梯度数据是否已更新;
异常确认模块107,用于若梯度数据未更新,则输出采样异常信号;
速度参数更新模块108,用于若梯度数据已更新,则基于梯度数据更新初始速度参数,得到更新速度;
决策参数更新模块109,用于基于更新速度更新初始决策参数,得到更新决策参数;
目标模型获取模块110,用于当初始决策参数以及更新决策参数满足收敛条件时,得到目标预测模型。
在本申请实施例中,用户终端指的是用于执行本申请提供的预防证件滥用的图像处理方法的终端设备,该当前终端可以是诸如移动电话、智能电话、笔记本电脑、数字广播接收器、PDA(个人数字助理)、PAD(平板电脑)、PMP(便携式多媒体播放器)、导航装置等等的移动终端以及诸如数字TV、台式计算机等等的固定终端,应当理解,此处对用户终端的举例仅为方便理解,不用于限定本申请。
在本申请实施例中,原始预测模型未进行梯度下降优化的预测模型。
在本申请实施例中,采样操作是指从总体训练数据中抽取个体或样品的过程,也即对总体训练数据进行试验或观测的过程。分随机抽样和非随机抽样两种类型。前者指遵照随机化原则从总体中抽取样本的抽样方法,它不带任何主观性,包括简单随机抽样、系统抽样、整群抽样和分层抽样。后者是一种凭研究者的观点、经验或者有关知识来抽取样本的方法,带有明显主观色彩。
在本申请实施例中,本轮训练数据集指的是经过上述采样操作后筛选出的数据量较小的训练数据集,以减少模型的训练时间。
在本申请实施例中,可基于用户文本的数据集生成用户-文本矩阵R,基于奇异值分解法对用户-文本矩阵R进行分解操作,得到用户-隐特征矩阵P以及隐特征-文本矩阵Q,基于用户-文本矩阵R构造目标函数目标函数表示为:
其中,R(Λ)表示用户-文本矩阵R用户对文本的评分数据集合;pm`表示用户-隐特征矩阵P中第m个用户对应的隐特征;qn`表示隐特征-文本矩阵Q中第n个文本对应的隐特征;rm,n表示用户m对文本n的评分数据;表示评分数据集合中用户m对文本n的评分数据;λ2表示隐特征矩阵的正则化因子。
在本申请实施例中,初始化就是把变量赋为默认值,把控件设为默认状态,具体的,包括初始化学习率∈、动量参数α、初始决策参数θ和初始速度v。
在本申请实施例中,梯度数据表示为:
在本申请实施例中,当一个训练数据被采样过后,它的Embedding的梯度不为0,基于该采样的特征,通过判断梯度数据是否已更新,即可获知该训练数据是否被采样过。
在本申请实施例中,若梯度数据未更新,则说明该训练数据没有被采样过便进行后续的更新操作,没有被反复采样的训练数据,对应的Embedding层基于历史动量也会被被反复训练更新,导致了过拟合情况发生。
在本申请实施例中,更新速度表示为:
vnew=αvold-∈g
其中,vnew表示更新速度;vold表示初始速度参数;α表示动量参数;∈表示学习率;g表示梯度数据。
在本申请实施例中,更新决策参数表示为:
θnew=θold+vnew
其中,θnew表示更新决策参数;θold表示初始决策参数;vnew表示更新速度。
本申请实施例二提供的应用于动量梯度下降的模型优化装置,由于带动量的随机梯度下降在训练过程中,当前轮次的训练数据没有被采样到,而该轮次梯度更新仍然会使用历史动量来更新,这可能导致Embedding层过拟合,本申请在更新梯度之前,通过确认梯度数据是否已经更新,从而确认该轮次的训练数据确定被采样,才进行该梯度更新操作,从而有效避免在训练时当前batch中没被采样到的词,依然会使用历史动量来更新导致Embedding层过拟合的问题。
继续参阅图5,示出了图4中函数定义模块103的结构示意图,为了便于说明,仅示出与本申请相关的部分。
在本申请实施例一的一些可选的实现方式中,上述函数定义模块103具体包括:矩阵生成子模块1031、矩阵分解子模块1032以及函数构造子模块1033。其中:
矩阵生成子模块1031,用于基于用户文本的数据集生成用户-文本矩阵;
矩阵分解子模块1032,用于基于奇异值分解法对用户-文本矩阵进行分解操作,得到用户-隐特征矩阵以及隐特征-文本矩阵;
函数构造子模块1033,用于基于用户-文本矩阵构造目标函数。
在本申请实施例中,奇异值分解(Singular Value Decomposition)是线性代数中一种重要的矩阵分解,奇异值分解则是特征分解在任意矩阵上的推广。
其中,R(Λ)表示用户-文本矩阵R用户对文本的评分数据集合;pm`表示用户-隐特征矩阵P中第m个用户对应的隐特征;qn`表示隐特征-文本矩阵Q中第n个文本对应的隐特征;rm,n表示用户m对文本n的评分数据;表示评分数据集合中用户m对文本n的评分数据;λ2表示隐特征矩阵的正则化因子。
在本申请实施例二的一些可选的实现方式中,梯度数据表示为:
在本申请实施例二的一些可选的实现方式中,更新速度表示为:
vnew=αvold-∈g
其中,vnew表示更新速度;vold表示初始速度参数;α表示动量参数;∈表示学习率;g表示梯度数据。
在本申请实施例二的一些可选的实现方式中,更新决策参数表示为:
θnew=θold+vnew
其中,θnew表示更新决策参数;θold表示初始决策参数;vnew表示更新速度。
在本申请实施例二的一些可实现方式中,上述目标模型获取模块110具体包括:差值计算子模块、收敛判断子模块、收敛确认子模块以及未收敛确认子模块。其中:
差值计算子模块,用于计算所述初始决策参数以及所述更新决策参数的决策参数差值;
收敛判断子模块,用于判断所述决策参数差值是否小于所述预设收敛阈值;
收敛确认子模块,用于若所述决策参数差值小于或等于所述预设收敛阈值,则确定当前的预测模型收敛,并将所述当前的预测模型作为所述目标预测模型;
未收敛确认子模块,用于若所述决策参数差值大于所述预设收敛阈值,则则确定当前的预测模型未收敛,继续执行参数优化操作。
综上,本申请实施例二提供的应用于动量梯度下降的模型优化装置,包括:请求接收模块,用于接收用户终端发送的模型优化请求,模型优化请求至少携带有原始预测模型以及原始训练数据集;采样操作模块,用于在原始训练数据集中进行采样操作,得到本轮训练数据集;函数定义模块,用于基于本轮训练数据集定义目标函数;初始化模块,用于初始化原始预测模型的模型优化参数,得到初始速度参数以及初始决策参数;梯度计算模块,用于计算本轮需要更新初始决策参数对应的梯度数据;梯度判断模块,用于判断梯度数据是否已更新;异常确认模块,用于若梯度数据未更新,则输出采样异常信号;速度参数更新模块,用于若梯度数据已更新,则基于梯度数据更新初始速度参数,得到更新速度;决策参数更新模块,用于基于更新速度更新初始决策参数,得到更新决策参数;目标模型获取模块,用于当初始决策参数以及更新决策参数满足收敛条件时,得到目标预测模型。由于带动量的随机梯度下降在训练过程中,当前轮次的训练数据没有被采样到,而该轮次梯度更新仍然会使用历史动量来更新,这可能导致Embedding层过拟合,本申请在更新梯度之前,通过确认梯度数据是否已经更新,从而确认该轮次的训练数据确定被采样,才进行该梯度更新操作,从而有效避免在训练时当前batch中没被采样到的词,依然会使用历史动量来更新导致Embedding层过拟合的问题。
为解决上述技术问题,本申请实施例还提供计算机设备。具体请参阅图6,图6为本实施例计算机设备基本结构框图。
所述计算机设备200包括通过系统总线相互通信连接存储器210、处理器220、网络接口230。需要指出的是,图中仅示出了具有组件210-230的计算机设备200,但是应理解的是,并不要求实施所有示出的组件,可以替代的实施更多或者更少的组件。其中,本技术领域技术人员可以理解,这里的计算机设备是一种能够按照事先设定或存储的指令,自动进行数值计算和/或信息处理的设备,其硬件包括但不限于微处理器、专用集成电路(Application Specific Integrated Circuit,ASIC)、可编程门阵列(Field-Programmable Gate Array,FPGA)、数字处理器(Digital Signal Processor,DSP)、嵌入式设备等。
所述计算机设备可以是桌上型计算机、笔记本、掌上电脑及云端服务器等计算设备。所述计算机设备可以与用户通过键盘、鼠标、遥控器、触摸板或声控设备等方式进行人机交互。
所述存储器210至少包括一种类型的可读存储介质,所述可读存储介质包括闪存、硬盘、多媒体卡、卡型存储器(例如,SD或DX存储器等)、随机访问存储器(RAM)、静态随机访问存储器(SRAM)、只读存储器(ROM)、电可擦除可编程只读存储器(EEPROM)、可编程只读存储器(PROM)、磁性存储器、磁盘、光盘等。在一些实施例中,所述存储器210可以是所述计算机设备200的内部存储单元,例如该计算机设备200的硬盘或内存。在另一些实施例中,所述存储器210也可以是所述计算机设备200的外部存储设备,例如该计算机设备200上配备的插接式硬盘,智能存储卡(Smart Media Card,SMC),安全数字(Secure Digital,SD)卡,闪存卡(Flash Card)等。当然,所述存储器210还可以既包括所述计算机设备200的内部存储单元也包括其外部存储设备。本实施例中,所述存储器210通常用于存储安装于所述计算机设备200的操作系统和各类应用软件,例如应用于动量梯度下降的模型优化方法的计算机可读指令等。此外,所述存储器210还可以用于暂时地存储已经输出或者将要输出的各类数据。
所述处理器220在一些实施例中可以是中央处理器(Central Processing Unit,CPU)、控制器、微控制器、微处理器、或其他数据处理芯片。该处理器220通常用于控制所述计算机设备200的总体操作。本实施例中,所述处理器220用于运行所述存储器210中存储的计算机可读指令或者处理数据,例如运行所述应用于动量梯度下降的模型优化方法的计算机可读指令。
所述网络接口230可包括无线网络接口或有线网络接口,该网络接口230通常用于在所述计算机设备200与其他电子设备之间建立通信连接。
本申请提供的应用于动量梯度下降的模型优化方法,由于带动量的随机梯度下降在训练过程中,当前轮次的训练数据没有被采样到,而该轮次梯度更新仍然会使用历史动量来更新,这可能导致Embedding层过拟合,本申请在更新梯度之前,通过确认梯度数据是否已经更新,从而确认该轮次的训练数据确定被采样,才进行该梯度更新操作,从而有效避免在训练时当前batch中没被采样到的词,依然会使用历史动量来更新导致Embedding层过拟合的问题。
本申请还提供了另一种实施方式,即提供一种计算机可读存储介质,所述计算机可读存储介质存储有计算机可读指令,所述计算机可读指令可被至少一个处理器执行,以使所述至少一个处理器执行如上述的应用于动量梯度下降的模型优化方法的步骤。
本申请提供的应用于动量梯度下降的模型优化方法,由于带动量的随机梯度下降在训练过程中,当前轮次的训练数据没有被采样到,而该轮次梯度更新仍然会使用历史动量来更新,这可能导致Embedding层过拟合,本申请在更新梯度之前,通过确认梯度数据是否已经更新,从而确认该轮次的训练数据确定被采样,才进行该梯度更新操作,从而有效避免在训练时当前batch中没被采样到的词,依然会使用历史动量来更新导致Embedding层过拟合的问题。
通过以上的实施方式的描述,本领域的技术人员可以清楚地了解到上述实施例方法可借助软件加必需的通用硬件平台的方式来实现,当然也可以通过硬件,但很多情况下前者是更佳的实施方式。基于这样的理解,本申请的技术方案本质上或者说对现有技术做出贡献的部分可以以软件产品的形式体现出来,该计算机软件产品存储在一个存储介质(如ROM/RAM、磁碟、光盘)中,包括若干指令用以使得一台终端设备(可以是手机,计算机,服务器,空调器,或者网络设备等)执行本申请各个实施例所述的方法。
显然,以上所描述的实施例仅仅是本申请一部分实施例,而不是全部的实施例,附图中给出了本申请的较佳实施例,但并不限制本申请的专利范围。本申请可以以许多不同的形式来实现,相反地,提供这些实施例的目的是使对本申请的公开内容的理解更加透彻全面。尽管参照前述实施例对本申请进行了详细的说明,对于本领域的技术人员来而言,其依然可以对前述各具体实施方式所记载的技术方案进行修改,或者对其中部分技术特征进行等效替换。凡是利用本申请说明书及附图内容所做的等效结构,直接或间接运用在其他相关的技术领域,均同理在本申请专利保护范围之内。
Claims (10)
1.一种应用于动量梯度下降的模型优化方法,其特征在于,包括下述步骤:
接收用户终端发送的模型优化请求,所述模型优化请求至少携带有原始预测模型以及原始训练数据集;
在所述原始训练数据集中进行采样操作,得到本轮训练数据集;
基于所述本轮训练数据集定义目标函数;
初始化所述原始预测模型的模型优化参数,得到初始速度参数以及初始决策参数;
计算本轮需要更新所述初始决策参数对应的梯度数据;
判断所述梯度数据是否已更新;
若所述梯度数据未更新,则输出采样异常信号;
若所述梯度数据已更新,则基于所述梯度数据更新所述初始速度参数,得到更新速度;
基于所述更新速度更新所述初始决策参数,得到更新决策参数;
当所述初始决策参数以及所述更新决策参数满足收敛条件时,得到目标预测模型。
2.根据权利要求1所述的应用于动量梯度下降的模型优化方法,其特征在于,所述本轮训练数据集包括用户文本的数据集,所述基于所述本轮训练数据集定义目标函数的步骤,具体包括:
基于所述用户文本的数据集生成用户-文本矩阵;
基于奇异值分解法对所述用户-文本矩阵进行分解操作,得到用户-隐特征矩阵以及隐特征-文本矩阵;
4.根据权利要求3所述的应用于动量梯度下降的模型优化方法,其特征在于,所述更新速度表示为:
vnew=αvold-∈g
其中,vnew表示所述更新速度;vold表示所述初始速度参数;α表示动量参数;∈表示学习率;g表示所述梯度数据。
5.根据权利要求1所述的应用于动量梯度下降的模型优化方法,其特征在于,所述更新决策参数表示为:
θnew=θold+vnew
其中,θnew表示更新决策参数;θold表示初始决策参数;vnew表示所述更新速度。
6.根据权利要求5所述的应用于动量梯度下降的模型优化方法,其特征在于,所述收敛条件为预设收敛阈值;所述当所述初始决策参数以及所述更新决策参数满足收敛条件时,得到目标预测模型的步骤,具体包括:
计算所述初始决策参数以及所述更新决策参数的决策参数差值;
判断所述决策参数差值是否小于所述预设收敛阈值;
若所述决策参数差值小于或等于所述预设收敛阈值,则确定当前的预测模型收敛,并将所述当前的预测模型作为所述目标预测模型;
若所述决策参数差值大于所述预设收敛阈值,则则确定当前的预测模型未收敛,继续执行参数优化操作。
7.一种应用于动量梯度下降的模型优化装置,其特征在于,包括:
请求接收模块,用于接收用户终端发送的模型优化请求,所述模型优化请求至少携带有原始预测模型以及原始训练数据集;
采样操作模块,用于在所述原始训练数据集中进行采样操作,得到本轮训练数据集;
函数定义模块,用于基于所述本轮训练数据集定义目标函数;
初始化模块,用于初始化所述原始预测模型的模型优化参数,得到初始速度参数以及初始决策参数;
梯度计算模块,用于计算本轮需要更新所述初始决策参数对应的梯度数据;
梯度判断模块,用于判断所述梯度数据是否已更新;
异常确认模块,用于若所述梯度数据未更新,则输出采样异常信号;
速度参数更新模块,用于若所述梯度数据已更新,则基于所述梯度数据更新所述初始速度参数,得到更新速度;
决策参数更新模块,用于基于所述更新速度更新所述初始决策参数,得到更新决策参数;
目标模型获取模块,用于当所述初始决策参数以及所述更新决策参数满足收敛条件时,得到目标预测模型。
8.根据权利要求7所述的应用于动量梯度下降的模型优化装置,其特征在于,所述函数定义模块包括:
矩阵生成子模块,用于基于所述用户文本的数据集生成用户-文本矩阵;
矩阵分解子模块,用于基于奇异值分解法对所述用户-文本矩阵进行分解操作,得到用户-隐特征矩阵以及隐特征-文本矩阵;
9.一种计算机设备,包括存储器和处理器,所述存储器中存储有计算机可读指令,所述处理器执行所述计算机可读指令时实现如权利要求1至6中任一项所述的应用于动量梯度下降的模型优化方法的步骤。
10.一种计算机可读存储介质,其特征在于,所述计算机可读存储介质上存储有计算机可读指令,所述计算机可读指令被处理器执行时实现如权利要求1至6中任一项所述的应用于动量梯度下降的模型优化方法的步骤。
Priority Applications (2)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202011359384.8A CN112488183B (zh) | 2020-11-27 | 2020-11-27 | 一种模型优化方法、装置、计算机设备及存储介质 |
PCT/CN2021/090501 WO2022110640A1 (zh) | 2020-11-27 | 2021-04-28 | 一种模型优化方法、装置、计算机设备及存储介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202011359384.8A CN112488183B (zh) | 2020-11-27 | 2020-11-27 | 一种模型优化方法、装置、计算机设备及存储介质 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN112488183A true CN112488183A (zh) | 2021-03-12 |
CN112488183B CN112488183B (zh) | 2024-05-10 |
Family
ID=74935992
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202011359384.8A Active CN112488183B (zh) | 2020-11-27 | 2020-11-27 | 一种模型优化方法、装置、计算机设备及存储介质 |
Country Status (2)
Country | Link |
---|---|
CN (1) | CN112488183B (zh) |
WO (1) | WO2022110640A1 (zh) |
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
WO2022110640A1 (zh) * | 2020-11-27 | 2022-06-02 | 平安科技(深圳)有限公司 | 一种模型优化方法、装置、计算机设备及存储介质 |
CN117077598A (zh) * | 2023-10-13 | 2023-11-17 | 青岛展诚科技有限公司 | 一种基于Mini-batch梯度下降法的3D寄生参数的优化方法 |
Families Citing this family (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116068903B (zh) * | 2023-04-06 | 2023-06-20 | 中国人民解放军国防科技大学 | 一种闭环系统鲁棒性能的实时优化方法、装置及设备 |
CN116451872B (zh) * | 2023-06-08 | 2023-09-01 | 北京中电普华信息技术有限公司 | 碳排放预测分布式模型训练方法、相关方法及装置 |
CN117596156B (zh) * | 2023-12-07 | 2024-05-07 | 机械工业仪器仪表综合技术经济研究所 | 一种工业应用5g网络的评估模型的构建方法 |
Citations (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN110390561A (zh) * | 2019-07-04 | 2019-10-29 | 四川金赞科技有限公司 | 基于动量加速随机梯度下降的用户-金融产品选用倾向高速预测方法和装置 |
CN110889509A (zh) * | 2019-11-11 | 2020-03-17 | 安徽超清科技股份有限公司 | 一种基于梯度动量加速的联合学习方法及装置 |
CN111639710A (zh) * | 2020-05-29 | 2020-09-08 | 北京百度网讯科技有限公司 | 图像识别模型训练方法、装置、设备以及存储介质 |
Family Cites Families (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US10282513B2 (en) * | 2015-10-13 | 2019-05-07 | The Governing Council Of The University Of Toronto | Methods and systems for 3D structure estimation |
CN110730037B (zh) * | 2019-10-21 | 2021-02-26 | 苏州大学 | 一种基于动量梯度下降法的相干光通信系统光信噪比监测方法 |
CN111507530B (zh) * | 2020-04-17 | 2022-05-31 | 集美大学 | 基于分数阶动量梯度下降的rbf神经网络船舶交通流预测方法 |
CN111695295A (zh) * | 2020-06-01 | 2020-09-22 | 中国人民解放军火箭军工程大学 | 一种光栅耦合器的入射参数反演模型的构建方法 |
CN112488183B (zh) * | 2020-11-27 | 2024-05-10 | 平安科技(深圳)有限公司 | 一种模型优化方法、装置、计算机设备及存储介质 |
-
2020
- 2020-11-27 CN CN202011359384.8A patent/CN112488183B/zh active Active
-
2021
- 2021-04-28 WO PCT/CN2021/090501 patent/WO2022110640A1/zh active Application Filing
Patent Citations (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN110390561A (zh) * | 2019-07-04 | 2019-10-29 | 四川金赞科技有限公司 | 基于动量加速随机梯度下降的用户-金融产品选用倾向高速预测方法和装置 |
CN110889509A (zh) * | 2019-11-11 | 2020-03-17 | 安徽超清科技股份有限公司 | 一种基于梯度动量加速的联合学习方法及装置 |
CN111639710A (zh) * | 2020-05-29 | 2020-09-08 | 北京百度网讯科技有限公司 | 图像识别模型训练方法、装置、设备以及存储介质 |
Non-Patent Citations (1)
Title |
---|
小刘同学: ""机器学习优化方法:Momentum栋梁梯度下降", pages 5 - 10, Retrieved from the Internet <URL:https://blog.csdn.net/sweetseven_/article/details/103353990> * |
Cited By (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
WO2022110640A1 (zh) * | 2020-11-27 | 2022-06-02 | 平安科技(深圳)有限公司 | 一种模型优化方法、装置、计算机设备及存储介质 |
CN117077598A (zh) * | 2023-10-13 | 2023-11-17 | 青岛展诚科技有限公司 | 一种基于Mini-batch梯度下降法的3D寄生参数的优化方法 |
CN117077598B (zh) * | 2023-10-13 | 2024-01-26 | 青岛展诚科技有限公司 | 一种基于Mini-batch梯度下降法的3D寄生参数的优化方法 |
Also Published As
Publication number | Publication date |
---|---|
WO2022110640A1 (zh) | 2022-06-02 |
CN112488183B (zh) | 2024-05-10 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN112488183B (zh) | 一种模型优化方法、装置、计算机设备及存储介质 | |
US10936949B2 (en) | Training machine learning models using task selection policies to increase learning progress | |
CN111758105A (zh) | 学习数据增强策略 | |
CN112101172A (zh) | 基于权重嫁接的模型融合的人脸识别方法及相关设备 | |
CN113408743A (zh) | 联邦模型的生成方法、装置、电子设备和存储介质 | |
CN114780727A (zh) | 基于强化学习的文本分类方法、装置、计算机设备及介质 | |
CN111488985A (zh) | 深度神经网络模型压缩训练方法、装置、设备、介质 | |
CN112861012B (zh) | 基于上下文和用户长短期偏好自适应学习的推荐方法及装置 | |
WO2020191001A1 (en) | Real-world network link analysis and prediction using extended probailistic maxtrix factorization models with labeled nodes | |
US11941867B2 (en) | Neural network training using the soft nearest neighbor loss | |
CN115766104A (zh) | 一种基于改进的Q-learning网络安全决策自适应生成方法 | |
CN116684330A (zh) | 基于人工智能的流量预测方法、装置、设备及存储介质 | |
CN114238656A (zh) | 基于强化学习的事理图谱补全方法及其相关设备 | |
CN111144473A (zh) | 训练集构建方法、装置、电子设备及计算机可读存储介质 | |
CN114443896B (zh) | 数据处理方法和用于训练预测模型的方法 | |
CN113887535B (zh) | 模型训练方法、文本识别方法、装置、设备和介质 | |
CN115099875A (zh) | 基于决策树模型的数据分类方法及相关设备 | |
CN114817476A (zh) | 语言模型的训练方法、装置、电子设备和存储介质 | |
CN114241411A (zh) | 基于目标检测的计数模型处理方法、装置及计算机设备 | |
CN114120367A (zh) | 元学习框架下基于圆损失度量的行人重识别方法及系统 | |
CN114730380A (zh) | 神经网络的深度并行训练 | |
CN111178630A (zh) | 一种负荷预测方法及装置 | |
CN115630687B (zh) | 模型训练方法、交通流量预测方法和装置 | |
CN114971095B (zh) | 在线教育效果预测方法、装置、设备及存储介质 | |
CN113420628B (zh) | 一种群体行为识别方法、装置、计算机设备及存储介质 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |