CN114580578B - 具有约束的分布式随机优化模型训练方法、装置及终端 - Google Patents

具有约束的分布式随机优化模型训练方法、装置及终端 Download PDF

Info

Publication number
CN114580578B
CN114580578B CN202210486474.6A CN202210486474A CN114580578B CN 114580578 B CN114580578 B CN 114580578B CN 202210486474 A CN202210486474 A CN 202210486474A CN 114580578 B CN114580578 B CN 114580578B
Authority
CN
China
Prior art keywords
gradient
data
agent
training
agents
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN202210486474.6A
Other languages
English (en)
Other versions
CN114580578A (zh
Inventor
侯洁
孙涛
崔金强
丁玉隆
尉越
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Peng Cheng Laboratory
Original Assignee
Peng Cheng Laboratory
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Peng Cheng Laboratory filed Critical Peng Cheng Laboratory
Priority to CN202210486474.6A priority Critical patent/CN114580578B/zh
Publication of CN114580578A publication Critical patent/CN114580578A/zh
Application granted granted Critical
Publication of CN114580578B publication Critical patent/CN114580578B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • G06F18/241Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N20/00Machine learning
    • YGENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
    • Y02TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
    • Y02DCLIMATE CHANGE MITIGATION TECHNOLOGIES IN INFORMATION AND COMMUNICATION TECHNOLOGIES [ICT], I.E. INFORMATION AND COMMUNICATION TECHNOLOGIES AIMING AT THE REDUCTION OF THEIR OWN ENERGY USE
    • Y02D30/00Reducing energy consumption in communication networks
    • Y02D30/70Reducing energy consumption in communication networks in wireless communication networks

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Data Mining & Analysis (AREA)
  • Evolutionary Computation (AREA)
  • Physics & Mathematics (AREA)
  • General Physics & Mathematics (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • General Engineering & Computer Science (AREA)
  • Artificial Intelligence (AREA)
  • Evolutionary Biology (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Software Systems (AREA)
  • Medical Informatics (AREA)
  • Computing Systems (AREA)
  • Mathematical Physics (AREA)
  • Management, Administration, Business Operations System, And Electronic Commerce (AREA)

Abstract

本发明公开了一种具有约束的分布式随机优化模型训练方法、装置及终端,上述方法包括:循环获取所述智能体集合中的每一个智能体,对所述智能体进行迭代训练并在所述智能体中保存生成的迭代数据,直至所述模型的迭代次数达到设定次数;其中,对智能体进行迭代训练时获取随机训练样本,在每次迭代中通过随机选择的无偏随机局部梯度来计算局部目标函数的全局梯度。与现有技术相比,大大降低了梯度评估的成本和计算复杂性,可用于高维、大规模优化问题。

Description

具有约束的分布式随机优化模型训练方法、装置及终端
技术领域
本发明涉及机器学习技术领域,尤其涉及的是一种具有约束的分布式随机优化模型训练方法、装置及终端。
背景技术
近年来,随着高科技的蓬勃发展,特别是云计算和大数据等新兴领域的出现。具有约束和随机因素的分布式优化理论和应用得到了越来越多的重视,并逐渐渗透到科学研究、工程应用和社会生活的各个方面。分布式优化是通过多智能体之间的合作协调有效地实现优化的任务,可用来解决许多集中式算法难以胜任的大规模复杂的优化问题。可以创建具有约束的随机优化模型用来解决带有约束和随机因素的优化问题,目前对该优化模型训练时常采用基于投影的随机梯度下降法和条件梯度法。
采用基于投影的随机梯度下降法训练时,在向负随机梯度方向迈出一步后,迭代将被投射回约束集上。当执行投影的计算成本较低时(如投影单纯形上),这种方法是有效的。但在许多实际情况下,如处理迹范数球、基多面体等可行域情况时,投影到约束集上的成本可能很高,计算效能低。
条件梯度法,通过求解约束集上的线性最小化子问题来避免投影的计算以获得一个条件梯度,然后通过当前迭代和条件梯度的凸组合来更新下一步迭代。虽然条件梯度法及其变体可以解决具有约束和随机因素的随机优化问题,但是只能应用在集中式环境,无法应用在分布式环境中,并且收敛速率不快。
发明内容
本发明的主要目的在于提供一种具有约束的分布式随机优化模型训练方法、装置、智能终端及存储介质,能够对随机优化模型进行训练,以在分布式环境中解决具有复杂约束和随机因素的优化问题。
为了实现上述目的,本发明第一方面提供一种具有约束的分布式随机优化模型训练方法,所述模型中包括由至少两个智能体组成的智能体集合,所述方法包括:
循环获取所述智能体集合中的每一个智能体,对所述智能体进行迭代训练并在所述智能体中保存生成的迭代数据,直至所述模型的迭代次数达到设定次数;
所述对智能体进行迭代训练包括如下步骤:
获取训练样本,所述训练样本为随机的样本数据;
基于所述智能体集合中各智能体之间的关联关系,获得与当前迭代的智能体对应的邻居节点数据;
基于所述邻居节点数据,根据平均一致性算法获得平均状态数据;
基于所述平均状态数据和当前迭代的智能体的迭代数据,计算随机梯度并根据所述随机梯度计算局部梯度;
计算局部梯度的梯度下降值并根据所述邻居节点数据获得全局梯度;
基于所述全局梯度,根据条件梯度算法更新当前迭代的智能体的待优化参数。
可选的,所述基于所述平均状态数据和当前迭代的智能体的迭代数据,计算随机梯度并根据所述随机梯度计算局部梯度,包括:
基于所述平均状态数据,获得第一随机梯度;
基于所述迭代数据中的平均状态数据,获得第二随机梯度;
基于预设的衰减步长、所述迭代数据中的局部梯度、所述第一随机梯度、所述第二随机梯度,获得所述局部梯度。
可选的,所述基于预设的衰减步长、所述迭代数据中的局部梯度、所述第一随机梯度、所述第二随机梯度,获得所述局部梯度的表达式为:
Figure DEST_PATH_IMAGE001
其中
Figure 100002_DEST_PATH_IMAGE002
为衰减步长,
Figure DEST_PATH_IMAGE003
Figure 100002_DEST_PATH_IMAGE004
为迭代数据中的局部梯度,
Figure DEST_PATH_IMAGE005
为第一随机梯度,
Figure 100002_DEST_PATH_IMAGE006
为第二随机梯度,
Figure DEST_PATH_IMAGE007
为随机变量。
可选的,所述获取训练样本,包括:
在已获取的训练样本集中随机选取设定数量的训练样本或通过在线采样获得训练样本。
可选的,所述基于所述邻居节点数据,根据平均一致性算法获得平均状态数据,包括:
获取所述邻居节点数据中各智能体的待优化参数并根据平均一致性算法获得所述平均状态数据。
可选的,所述计算局部梯度的梯度下降值并根据所述邻居节点数据获得全局梯度,包括:
基于所述局部梯度和所述当前迭代的智能体的迭代数据中的局部梯度,获得所述梯度下降值;
基于所述智能体集合中各智能体之间的关联关系,获得与所述当前迭代的智能体关联的邻居节点;
基于所述梯度下降值和所述邻居节点的迭代数据中的聚合梯度,获得所述当前迭代的智能体的聚合梯度;
基于所述邻居节点数据,根据平均一致性方法获得所述全局梯度。
可选的,所述基于所述全局梯度,根据条件梯度算法更新当前迭代的智能体的待优化参数,包括:
获取目标场景的约束集合;
基于所述全局梯度与所述约束集合的相关性,获得可行方向;
基于所述可行方向与所述平均状态数据的凸组合,更新所述当前迭代的智能体的待优化参数。
本发明第二方面提供一种具有约束的分布式随机优化模型训练装置,其中,上述装置包括:
迭代模块,用于循环获取所述智能体集合中的每一个智能体,对所述智能体进行迭代训练并在所述智能体中保存生成的迭代数据,直至所述模型的迭代次数达到设定次数;
样本数据获取模块,用于获取训练样本,所述训练样本为随机的样本数据;
邻居节点数据获取模块,用于基于所述智能体集合中各智能体之间的关联关系,获得与当前迭代的智能体对应的邻居节点数据;
平均状态数据计算模块,用于基于所述邻居节点数据,根据平均一致性算法获得平均状态数据;
局部梯度计算模块,用于基于所述平均状态数据和当前迭代的智能体的迭代数据,计算随机梯度并根据所述随机梯度计算局部梯度;
全局梯度计算模块,用于计算局部梯度的梯度下降值并根据所述邻居节点数据获得全局梯度;
更新模块,用于基于所述全局梯度,根据条件梯度算法更新当前迭代的智能体的待优化参数。
本发明第三方面提供一种智能终端,上述智能终端包括存储器、处理器以及存储在上述存储器上并可在上述处理器上运行的具有约束的分布式随机优化模型训练程序,上述具有约束的分布式随机优化模型训练程序被上述处理器执行时实现任意一项上述具有约束的分布式随机优化模型训练方法的步骤。
本发明第四方面提供一种计算机可读存储介质,上述计算机可读存储介质上存储有具有约束的分布式随机优化模型训练程序,上述具有约束的分布式随机优化模型训练程序被处理器执行时实现任意一项上述具有约束的分布式随机优化模型训练方法的步骤。
由上可见,与现有技术相比,本发明在训练每个智能体时使用随机的样本数据,采用随机梯度来计算梯度下降值并根据模型中各个智能体之间的关联关系来更新智能体的待优化参数,每次迭代时只需要计算一次样本梯度,也不需要存储样本的梯度信息或状态信息。因此,本发明的模型训练方法不仅可以在分布式环境中解决具有复杂约束和随机因素的优化问题,而且收敛速度快、计算效能高、存储开销小。
附图说明
为了更清楚地说明本发明实施例中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动性的前提下,还可以根据这些附图获得其它的附图。
图1是本发明实施例提供的具有约束的分布式随机优化模型训练方法中的智能体训练流程示意图;
图2是图1实施例的步骤S500具体流程示意图;
图3是图1实施例的步骤S600具体流程示意图;
图4是图1实施例的步骤S400具体流程示意图;
图5是图1实施例的流程框图;
图6是图1实施例的测试集准确率曲线图;
图7是本发明实施例提供的具有约束的分布式随机优化模型训练装置的结构示意图;
图8是本发明实施例提供的一种智能终端的内部结构原理框图。
具体实施方式
以下描述中,为了说明而不是为了限定,提出了诸如特定系统结构、技术之类的具体细节,以便透彻理解本发明实施例。然而,本领域的技术人员应当清楚,在没有这些具体细节的其它实施例中也可以实现本发明。在其它情况下,省略对众所周知的系统、装置、电路以及方法的详细说明,以免不必要的细节妨碍本发明的描述。
应当理解,当在本说明书和所附权利要求书中使用时,术语“包括”指示所描述特征、整体、步骤、操作、元素和/或组件的存在,但并不排除一个或多个其它特征、整体、步骤、操作、元素、组件和/或其集合的存在或添加。
还应当理解,在本发明说明书中所使用的术语仅仅是出于描述特定实施例的目的而并不意在限制本发明。如在本发明说明书和所附权利要求书中所使用的那样,除非上下文清楚地指明其它情况,否则单数形式的“一”、“一个”及“该”意在包括复数形式。
还应当进一步理解,在本发明说明书和所附权利要求书中使用的术语“和/或”是指相关联列出的项中的一个或多个的任何组合以及所有可能组合,并且包括这些组合。
如在本说明书和所附权利要求书中所使用的那样,术语“如果”可以依据上下文被解释为“当…时”或“一旦”或“响应于确定”或“响应于检测到”。类似的,短语“如果确定”或“如果检测到[所描述条件或事件]”可以依据上下文被解释为意指“一旦确定”或“响应于确定”或“一旦检测到[所描述的条件或事件]”或“响应于检测到[所描述条件或事件]”。
下面结合本发明实施例的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本发明的一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其它实施例,都属于本发明保护的范围。
在下面的描述中阐述了很多具体细节以便于充分理解本发明,但是本发明还可以采用其它不同于在此描述的其它方式来实施,本领域技术人员可以在不违背本发明内涵的情况下做类似推广,因此本发明不受下面公开的具体实施例的限制。
具有约束和随机因素的分布式随机优化问题广泛存在于各种工程应用中,如无人系统、无线通信、分布式机器学习、多智能体强化学习等等。由于在分布式网络或分布式控制等目标场景(比如资源分配, 传感器网络中的定位)中的可行域可以是复杂的迹范数球、基本多面体等,现有的基于投影的随机梯度下降方法很难处理或无法处理这些优化问题。
本发明利用梯度跟踪技术将条件梯度方法扩展到分布式环境中,不仅避免了投影计算,还提高了模型的计算效能。在解决具有随机因素的凸优化问题方面,达到与基于投影的随机梯度下降方法一致的收敛速率
Figure 100002_DEST_PATH_IMAGE008
示例性方法
如图1所示,本发明实施例提供一种具有约束的分布式随机优化模型训练方法,使用该方法创建用于在分布式环境中解决具有复杂约束集的随机优化问题的网络模型。在分布式环境中,每个个体往往有一个代价函数,且整个网络的代价由这些个体的代价函数和来表示。网络模型的目的是通过个体间的局部信息交流而完成整个网络模型代价函数的优化。其中每个个体只知道自己的代价函数,在给定的分布式优化算法下得出保证其收敛的条件。上述个体也称为智能体。
用于分布式环境中的网络模型包括至少两个智能体,所有的智能体组成智能体集合。通过循环获取智能体集合中的每一个智能体,对获取到的智能体进行迭代训练并在智能体中保存生成的迭代数据,直至网络模型的迭代次数达到设定次数,完成网络模型的创建,应用在目标场景中。本领域技术人员熟知,网络模型创建好后,还可以通过预先设定的测试数据对网络模型进行测试。
具体的,上述迭代训练包括如下步骤:
步骤S100:获取训练样本,所述训练样本为随机的样本数据;
具体的,本发明解决的是带有随机因素的优化问题,在实际问题中并不知道样本的真实概率分布,因此传统随机优化方法(SAG、SAGA、SVRG或其变体)无法解决该问题。而本发明方案无需获取到所有样本的信息或者随机样本的分布情况,训练样本可以为随机的样本数据,每次迭代只需随机选取一个样本或是在线获取一个样本即可。具体实施时, 可以在已获取到的训练样本集中随机选取设定数量的训练样本;或者在线采样获取训练样本。因此本发明的训练方法既可用于带有随机因素的优化问题(stochastic),也可用于有限和问题(finite-sum)中。
在本实施例中,网络模型中设定了10个智能体,以公开数据集a9a作为训练数据集,每次迭代训练时,从该训练数据集中随机获取10%的样本数据作为训练数据。
步骤S200:基于智能体集合中各智能体之间的关联关系,获得与当前迭代的智能体对应的邻居节点数据;
其中,邻居节点数据包括各个智能体之间的连接权重,各个智能体保存的上一次迭代时获得的迭代数据。迭代数据包括下述迭代步骤中生成的平均状态数据、局部梯度和聚合梯度等。
具体的,本发明中将多智能体表示为一个由
Figure DEST_PATH_IMAGE009
个智能体组成的集合
Figure 100002_DEST_PATH_IMAGE010
。集合
Figure DEST_PATH_IMAGE011
中的智能体通过通信网络
Figure 100002_DEST_PATH_IMAGE012
互相交换信息,其中
Figure DEST_PATH_IMAGE013
表示网络
Figure 100002_DEST_PATH_IMAGE014
中边的集合,
Figure DEST_PATH_IMAGE015
。根据边的集合,可以获得整个模型的权重连接矩阵
Figure DEST_PATH_IMAGE016
,权重连接矩阵
Figure DEST_PATH_IMAGE017
是一个双随机矩阵,且其中的行列和均为1。
Figure DEST_PATH_IMAGE018
表示第
Figure DEST_PATH_IMAGE019
个智能体和第
Figure DEST_PATH_IMAGE020
个智能体之间的连接权重。因此,可以通过该集合获得与每个智能体关联的所有邻居节点、连接权重等数据。
步骤S300:基于邻居节点数据,根据平均一致性算法获得平均状态数据;
其中,平均一致性算法用于根据一个智能体的所有邻居节点数据预估当前迭代的智能体的平均状态。具体的,通过获取当前迭代的智能体的邻居节点数据中的待优化参数并根据平均一致性算法获得当前迭代的智能体的平均状态数据。
例如:第
Figure 579417DEST_PATH_IMAGE019
个智能体的平均状态数据具体可以表示为:
Figure DEST_PATH_IMAGE021
,其中,
Figure DEST_PATH_IMAGE022
表示第
Figure 292158DEST_PATH_IMAGE019
个智能体的邻居节点集合,
Figure DEST_PATH_IMAGE023
表示第
Figure 278568DEST_PATH_IMAGE019
个智能体和第
Figure 178391DEST_PATH_IMAGE020
个智能体之间的连接权重,
Figure DEST_PATH_IMAGE024
为第
Figure 275660DEST_PATH_IMAGE020
个智能体第
Figure DEST_PATH_IMAGE025
次迭代时的待优化参数,
Figure DEST_PATH_IMAGE026
为第
Figure 108487DEST_PATH_IMAGE019
个智能体第
Figure 219DEST_PATH_IMAGE025
次迭代时的平均状态数据。
步骤S400:基于平均状态数据和当前迭代的智能体的迭代数据,计算随机梯度并根据所述随机梯度计算局部梯度;
具体的,由于随机变量
Figure DEST_PATH_IMAGE027
的分布未知,只能得到目标函数
Figure DEST_PATH_IMAGE028
的随机梯度,即对于给定的
Figure DEST_PATH_IMAGE029
和随机变量
Figure 195795DEST_PATH_IMAGE027
,可以得到实际梯度
Figure DEST_PATH_IMAGE030
的无偏估计
Figure DEST_PATH_IMAGE031
。众所周知,条件梯度算法的朴素随机实现是可以将得到的随机梯度
Figure 96755DEST_PATH_IMAGE031
替换实际梯度
Figure 518509DEST_PATH_IMAGE030
,但是由于无法消失的方差的存在导致算法很可能会发散。为了解决这个问题,本发明利用递归动量思想设计局部梯度迭代公式,局部梯度不仅与当前迭代样本的梯度有关,且与上一次局部梯度有关。
基于解决有约束的随机优化问题,计算局部梯度时本发明创新性地采用了随机变量并根据递归动量策略纳入上一次迭代的局部梯度,不仅可以消除随机梯度方差带来的影响,而且可以起到动量加速的效果,保证在目标函数是凸函数的情况下得到与基于投影的随机梯度下降法一致的收敛速率
Figure 581143DEST_PATH_IMAGE008
步骤S500:计算局部梯度的梯度下降值并根据邻居节点数据获得全局梯度;
具体的,本发明基于传统梯度跟踪技术的原理,利用上一次迭代的迭代数据来计算全局梯度。首先将当前迭代的智能体的局部梯度与该智能体保存的迭代数据中的局部梯度比较,计算出局部梯度的梯度下降值,根据梯度下降值结合邻居节点数据更新当前迭代的智能体的变量和全局梯度,使得当前迭代的智能体的变量更新体现对全局梯度的追踪。
在本实施例中,如图2所示,获得全局梯度具体包括如下步骤:
步骤S510:基于局部梯度和迭代智能体的迭代数据中的局部梯度,获得梯度下降值;
具体的,将局部梯度
Figure DEST_PATH_IMAGE032
与迭代智能体的迭代数据中的局部梯度
Figure 455558DEST_PATH_IMAGE004
相减,获得的差值即为梯度下降值。
步骤S520:基于智能体之间的关联信息,获得迭代智能体关联的邻居节点;
步骤S530:基于梯度下降值和邻居节点的迭代数据中的聚合梯度,获得迭代智能体的聚合梯度;
具体的,根据梯度跟踪方法,计算聚合梯度。具体的计算公式为:
Figure DEST_PATH_IMAGE033
其中,
Figure DEST_PATH_IMAGE034
为第
Figure 160209DEST_PATH_IMAGE019
个智能体第
Figure 170890DEST_PATH_IMAGE025
次迭代的聚合梯度,
Figure DEST_PATH_IMAGE035
为第
Figure 670004DEST_PATH_IMAGE020
个智能体第
Figure DEST_PATH_IMAGE036
次迭代时的聚合梯度,
Figure 297295DEST_PATH_IMAGE022
表示第
Figure 8899DEST_PATH_IMAGE019
个智能体的邻居节点集合,
Figure 139666DEST_PATH_IMAGE023
表示第
Figure 544103DEST_PATH_IMAGE019
个智能体和第
Figure 393110DEST_PATH_IMAGE020
个智能体之间的连接权重,
Figure 173984DEST_PATH_IMAGE032
为第
Figure 159258DEST_PATH_IMAGE019
个智能体第
Figure 734595DEST_PATH_IMAGE025
次迭代时的局部梯度,
Figure 70899DEST_PATH_IMAGE004
为第
Figure 124305DEST_PATH_IMAGE019
个智能体第
Figure 229665DEST_PATH_IMAGE036
次迭代时的局部梯度。
步骤S540:基于邻居节点数据,根据平均一致性方法获得全局梯度。
具体的,根据聚合梯度,采用平均一致性算法计算全局梯度。具体计算公式为:
Figure DEST_PATH_IMAGE037
,其中
Figure DEST_PATH_IMAGE038
为第
Figure 507062DEST_PATH_IMAGE019
个智能体第
Figure 65083DEST_PATH_IMAGE025
次迭代时的全局梯度,
Figure DEST_PATH_IMAGE039
为第
Figure 187759DEST_PATH_IMAGE020
个智能体第
Figure 147625DEST_PATH_IMAGE025
次迭代的聚合梯度,
Figure 330345DEST_PATH_IMAGE022
表示第
Figure 375661DEST_PATH_IMAGE019
个智能体的邻居节点集合,
Figure 770870DEST_PATH_IMAGE023
表示第
Figure 850822DEST_PATH_IMAGE019
个智能体和第
Figure 938864DEST_PATH_IMAGE020
个智能体之间的连接权重。
步骤S600:基于全局梯度,根据条件梯度算法更新当前迭代的智能体的待优化参数;
具体的,首先通过最小化全局梯度
Figure 471476DEST_PATH_IMAGE038
与可行集(即约束集合)
Figure DEST_PATH_IMAGE040
的相关性得到可行方向
Figure DEST_PATH_IMAGE041
,然后构建可行方向
Figure 201535DEST_PATH_IMAGE041
与平均状态数据
Figure 135993DEST_PATH_IMAGE026
的凸组合,更新待优化参数。具体公式为:
Figure DEST_PATH_IMAGE042
Figure DEST_PATH_IMAGE043
其中
Figure DEST_PATH_IMAGE044
是迭代步长。
在本实施例中,如图3所示,更新当前迭代的智能体的待优化参数具体包括如下步骤:
步骤S610:获取目标场景的约束集合;
步骤S620:基于全局梯度与约束集合的相关性,获得可行方向;
具体的,通过最小化全局梯度
Figure 926094DEST_PATH_IMAGE038
与可行集(即约束集合)
Figure 946003DEST_PATH_IMAGE040
的相关性得到可行方向
Figure 948594DEST_PATH_IMAGE041
,具体表达式为:
Figure DEST_PATH_IMAGE045
步骤S630:基于可行方向与平均状态数据的凸组合,更新迭代智能体的待优化参数。
具体的,更新迭代智能体的待优化参数的表达式为:
Figure 3138DEST_PATH_IMAGE043
,其中
Figure DEST_PATH_IMAGE046
为用于
Figure DEST_PATH_IMAGE047
次迭代的待优化参数,
Figure 229720DEST_PATH_IMAGE026
为所述平均状态数据,
Figure 736924DEST_PATH_IMAGE041
为所述可行方向,
Figure DEST_PATH_IMAGE048
为预设步长。
对模型中的所有智能体完成迭代一次后,令上述变量中的
Figure DEST_PATH_IMAGE049
,进行下一次迭代。直至迭代次数k大于设定的迭代次数K。
由上所述,本发明可应用于具有复杂约束集的随机环境中,通过利用条件梯度技术避免了代价高的投影计算,大大提高计算效能。同时,本发明在每次迭代中通过随机选择的无偏随机局部梯度来计算局部目标函数的全局梯度,大大降低了梯度评估的成本和计算复杂性,可用于高维、大规模优化问题。并且每次迭代随机选取部分样本(样本个数大于等于1)来计算样本梯度,无需对所有样本或者批量样本进行梯度计算,更适用于高维、大规模优化问题。该方法也可以用于有限和问题中,与算法SAG和SAGA不同,该方法无需为每个样本维护一个旧的梯度,具有更小的存储开销;该方法通过利用动量更新思想,对于随机凸优化(stochastic)问题可以达到与基于投影的随机梯度法一致的收敛速率
Figure 808785DEST_PATH_IMAGE008
在一些实施例中,如图4所示,上述步骤S400中计算局部梯度,具体包括步骤:
步骤S410:基于平均状态数据,获得第一随机梯度;
步骤S420:基于迭代数据中的平均状态数据,获得第二随机梯度;
步骤S430:基于预设的衰减步长、迭代数据中的局部梯度、第一随机梯度、第二随机梯度,获得局部梯度;
具体的,首先根据随机变量
Figure 717836DEST_PATH_IMAGE007
和平均状态数据
Figure 318581DEST_PATH_IMAGE026
,计算第一随机梯度,即:
Figure 313082DEST_PATH_IMAGE005
;根据迭代数据中的平均状态数据
Figure DEST_PATH_IMAGE050
(上一次迭代时获得的平均状态数据)和随机变量
Figure 923055DEST_PATH_IMAGE007
,获得第二随机梯度
Figure 686612DEST_PATH_IMAGE006
。然后基于预设的衰减步长、迭代数据中的局部梯度、第一随机梯度、第二随机梯度,获得局部梯度,具体表达式为:
Figure 723838DEST_PATH_IMAGE001
其中
Figure 940055DEST_PATH_IMAGE002
为衰减步长,
Figure 88140DEST_PATH_IMAGE003
Figure 706203DEST_PATH_IMAGE004
为所述迭代数据中的局部梯度,
Figure 914331DEST_PATH_IMAGE005
为第一随机梯度,
Figure 617844DEST_PATH_IMAGE006
为第二随机梯度。
本发明利用动量更新方法不仅可以消除随机梯度方差带来的影响,也可以起到加速的效果,理论推导得出该方法在处理随机凸优化(stochastic)问题时具有与基于投影的随机梯度下降法一致的收敛速率
Figure 304041DEST_PATH_IMAGE008
。同时,该方法无需为每个样本存储一个样本梯度或状态信息,大大提高了算法收敛性能并降低了存储开销。
下表提供了本发明的具有约束的分布式随机优化方法(DMFW)与随机梯度下降法(RSA、RSG、SPPDM)以及随机无投影方法(OFW、STORC、SFW、NSFW)在解决随机优化问题时的收敛速率对比。从表可以看出,本发明方法DMFW与传统无投影算法相比具有更快的收敛速率,且与随机梯度下降法具有一致的收敛速率。
优化方法 适用环境 有无投影 目标函数 收敛速率
RSA 集中式 无约束 光滑凸
Figure DEST_PATH_IMAGE051
RSG 集中式 有投影 光滑非凸
Figure 42190DEST_PATH_IMAGE051
SPPDM 分布式 无约束 非光滑非凸
Figure 421218DEST_PATH_IMAGE051
OFW 集中式 无投影 光滑凸
Figure DEST_PATH_IMAGE052
STORC 集中式 无投影 光滑凸
Figure DEST_PATH_IMAGE053
SFW 集中式 无投影 光滑凸
Figure 143187DEST_PATH_IMAGE053
NSFW 集中式 无投影 光滑非凸
Figure 633074DEST_PATH_IMAGE052
DMFW(本发明方法) 分布式 无投影 光滑凸
Figure 960150DEST_PATH_IMAGE051
也就是说,本发明利用梯度跟踪技术将条件梯度无投影方法扩展到分布式上,从而避免了投影计算,提高了算法的计算效能。不仅可以在梯度近似下衰减噪声,而且可以在凸情况下获得与基于投影的梯度下降法相当的收敛保证,收敛速度更快、计算复杂度更低和存储开销更小。
参考图5,以下以对公开数据集a9a进行二分类在线学习为例,对本发明的具体实施过程做详细描述。
二分类操作需要解决的是一个分布式凸优化问题,其可以用下述公式来表述:
Figure DEST_PATH_IMAGE054
Figure DEST_PATH_IMAGE055
其中
Figure DEST_PATH_IMAGE056
表示智能体的数量,
Figure DEST_PATH_IMAGE057
表示每个智能体的训练样本数,
Figure DEST_PATH_IMAGE058
是数据样本
Figure 572397DEST_PATH_IMAGE019
的(特征,标签)对,
Figure DEST_PATH_IMAGE059
Figure DEST_PATH_IMAGE060
设定
Figure 516082DEST_PATH_IMAGE056
=10,约束集合满足
Figure DEST_PATH_IMAGE061
,待优化参数
Figure DEST_PATH_IMAGE062
。a9a样本集的总数
Figure DEST_PATH_IMAGE063
,数据预处理后的训练集数据为:
Figure DEST_PATH_IMAGE064
以及训练集标签
Figure DEST_PATH_IMAGE065
其中,数据预处理的具体方法:将二分类数据标签
Figure DEST_PATH_IMAGE066
数值改为1和-1,1表示正样本,-1表示负样本。由于a9a数据集的正负样本比例为1:3左右,采用smote方法将其调整为1:1左右,具体步骤包括:在7800个正样本里面,随机选择一个正样本点,每次循环找到距离该点最近的m个点,随机选其中一个连线,再在连线上随机找1个点作为插值点,重复32561-7800=24761次,得到smote后正负样本均匀的数据
Figure DEST_PATH_IMAGE067
Figure DEST_PATH_IMAGE068
由于智能体个数为10,将训练集均匀的分成10份,即:
Figure DEST_PATH_IMAGE069
以及
Figure DEST_PATH_IMAGE070
。每个智能体每次随机选取训练集中的10%的数据进行训练。即每次迭代时只能随机获取部分样本信息,无法知道所有的样本信息,包括样本总数、除本次迭代抽取的样本之外的其他样本特征和标签等信息。
设置待优化参数
Figure DEST_PATH_IMAGE071
,中间变量
Figure DEST_PATH_IMAGE072
,其中
Figure DEST_PATH_IMAGE073
,设置迭代总次数
Figure DEST_PATH_IMAGE074
,迭代步长
Figure DEST_PATH_IMAGE075
Figure DEST_PATH_IMAGE076
。迭代步长为衰减步长,即随着迭代次数的增加,步长逐渐减小。需要说明的是,在测试结果不发散的情况下,为了加快运算可以适当增大迭代步长。
对所有
Figure DEST_PATH_IMAGE077
的智能体进行第2次迭代,第3次迭代,……,第
Figure DEST_PATH_IMAGE078
次迭代。其中,每个智能体每次随机选取训练集中的10%的数据进行每次迭代训练。对所有智能体完成迭代一次后,保存当前迭代的每个智能体的数据
Figure 793349DEST_PATH_IMAGE046
Figure 240510DEST_PATH_IMAGE034
Figure 430183DEST_PATH_IMAGE032
,并令
Figure 595585DEST_PATH_IMAGE049
,然后进行下一次迭代。
经过500次迭代,得到所有智能体最终的优化参数值,随机选取某个智能体
Figure 427275DEST_PATH_IMAGE019
的最终参数
Figure DEST_PATH_IMAGE079
,准备测试集数据
Figure DEST_PATH_IMAGE080
,计算
Figure DEST_PATH_IMAGE081
Figure 791260DEST_PATH_IMAGE081
大于0的判定为正样本,小于0的判定为负样本,获得如图6所示的测试集准确率结果图。
以下对本方法的收敛性的验证过程进行详细叙述:
定义以下辅助变量:
Figure DEST_PATH_IMAGE082
Figure DEST_PATH_IMAGE083
Figure DEST_PATH_IMAGE084
Figure DEST_PATH_IMAGE085
表示网络中所有智能体的状态平均,
Figure DEST_PATH_IMAGE086
表示网络中所有智能体的局部梯度预估平均值,
Figure DEST_PATH_IMAGE087
表示所有智能体的真实梯度平均值。在给出收敛速率证明之前,我们引出如下几个引理。
假设条件:
1. 目标函数
Figure DEST_PATH_IMAGE088
是光滑函数;
2. 权重连接矩阵
Figure DEST_PATH_IMAGE089
第二大特征值
Figure DEST_PATH_IMAGE090
的幅度大小严格小于1;
3. 可行集
Figure DEST_PATH_IMAGE091
是个凸紧集,满足
Figure DEST_PATH_IMAGE092
对于
Figure DEST_PATH_IMAGE093
均成立;
4. 存在常数
Figure DEST_PATH_IMAGE094
使得
Figure DEST_PATH_IMAGE095
Figure DEST_PATH_IMAGE096
成立。
引理1:设假设1-3成立。令
Figure DEST_PATH_IMAGE097
,那么,对于任意的
Figure DEST_PATH_IMAGE098
Figure DEST_PATH_IMAGE099
有,
Figure DEST_PATH_IMAGE100
其中
Figure DEST_PATH_IMAGE101
引理1给出了
Figure DEST_PATH_IMAGE102
,也就是说当
Figure DEST_PATH_IMAGE103
时,
Figure DEST_PATH_IMAGE104
趋于0,即随着迭代次数的增加,智能体
Figure DEST_PATH_IMAGE105
的平均状态估计将趋于真实的状态平均值。下面,我们将通过选择合适的步长
Figure DEST_PATH_IMAGE106
Figure DEST_PATH_IMAGE107
,建立对于所有
Figure 260157DEST_PATH_IMAGE098
Figure DEST_PATH_IMAGE108
的有界性。
引理2:设假设1-4成立。选择步长
Figure DEST_PATH_IMAGE109
Figure 178434DEST_PATH_IMAGE097
,那么对于任意
Figure 813815DEST_PATH_IMAGE098
Figure 969990DEST_PATH_IMAGE099
Figure DEST_PATH_IMAGE110
其中
Figure DEST_PATH_IMAGE111
,
Figure DEST_PATH_IMAGE112
,
Figure DEST_PATH_IMAGE113
.
引理3:设假设1-4成立。那么有
a) 对于任意
Figure DEST_PATH_IMAGE114
Figure DEST_PATH_IMAGE115
的条件期望满足
Figure DEST_PATH_IMAGE116
b) 选择步长
Figure 637818DEST_PATH_IMAGE109
Figure 512233DEST_PATH_IMAGE097
,对于任意的
Figure 685725DEST_PATH_IMAGE099
Figure 961986DEST_PATH_IMAGE115
的期望满足
Figure DEST_PATH_IMAGE117
其中
Figure DEST_PATH_IMAGE118
引理3说明当
Figure DEST_PATH_IMAGE119
时有
Figure DEST_PATH_IMAGE120
的期望收敛到0,也就是说本发明的算法DMFW随着迭代次数的增加,存在的方差在不断的减小,最终可以消除方差带来的影响。结合引理2和引理3,下面的引理给出了
Figure DEST_PATH_IMAGE121
的有界性。
引理4:设假设1-4成立。令
Figure DEST_PATH_IMAGE122
Figure DEST_PATH_IMAGE123
,那么对于任意
Figure DEST_PATH_IMAGE124
Figure DEST_PATH_IMAGE125
Figure DEST_PATH_IMAGE126
根据引理4,可以得到如下收敛性定理
定理1: 设假设1-4成立。目标函数
Figure DEST_PATH_IMAGE127
是凸函数,选择步长
Figure 585734DEST_PATH_IMAGE122
Figure 947445DEST_PATH_IMAGE123
。那么,对于任意的
Figure 924628DEST_PATH_IMAGE125
Figure DEST_PATH_IMAGE128
其中
Figure DEST_PATH_IMAGE129
定理1说明当
Figure 320975DEST_PATH_IMAGE119
时,序列
Figure DEST_PATH_IMAGE130
将收敛至最优解
Figure DEST_PATH_IMAGE131
。除此之外,当引理1中的误差
Figure DEST_PATH_IMAGE132
收敛至0时,序列
Figure DEST_PATH_IMAGE133
与序列
Figure 787728DEST_PATH_IMAGE130
具有相同的收敛保证。根据定理1可以得到当目标函数是凸函数的情况下,本发明方法DMFW可以达到
Figure DEST_PATH_IMAGE134
的收敛速率。
下面对定理1进行证明。
证明:由假设1可以得到目标函数
Figure DEST_PATH_IMAGE135
是光滑函数,根据光滑函数的性质可以得到
Figure DEST_PATH_IMAGE136
其中最后一个不等式是由假设3可得到的。由
Figure DEST_PATH_IMAGE137
的定义可知,上不等式的右边第二项可以重写为
Figure DEST_PATH_IMAGE138
其中
Figure DEST_PATH_IMAGE139
由算法中有关
Figure DEST_PATH_IMAGE140
的最优性可得,
Figure DEST_PATH_IMAGE141
是由于凸函数
Figure DEST_PATH_IMAGE142
的性质(
Figure DEST_PATH_IMAGE143
)。将上不等式代入第一个不等式中,可以得到
Figure DEST_PATH_IMAGE144
上不等式两边同时减去
Figure DEST_PATH_IMAGE145
,可以得到
Figure DEST_PATH_IMAGE146
对上不等式中的项
Figure DEST_PATH_IMAGE147
取期望且利用Jensen不等式可以得到
Figure DEST_PATH_IMAGE148
,则有
Figure DEST_PATH_IMAGE149
Figure 89265DEST_PATH_IMAGE122
Figure 604560DEST_PATH_IMAGE123
代入上不等式,有
Figure DEST_PATH_IMAGE150
从上不等式可以得到
Figure DEST_PATH_IMAGE151
,其中
Figure DEST_PATH_IMAGE152
示例性设备
如图7所示,对应于上述具有约束的分布式随机优化模型训练方法,本发明实施例还提供一种具有约束的分布式随机优化模型训练装置,上述具有约束的分布式随机优化模型训练装置包括:
迭代模块600,用于循环获取所述智能体集合中的每一个智能体,对所述智能体进行迭代训练并在所述智能体中保存生成的迭代数据,直至所述模型的迭代次数达到设定次数;
样本数据获取模块610,用于获取训练样本,所述训练样本为随机的样本数据;
邻居节点数据获取模块620,用于基于所述智能体集合中各智能体之间的关联关系,获得与当前迭代的智能体对应的邻居节点数据;
平均状态数据计算模块630,用于基于所述邻居节点数据,根据平均一致性算法获得平均状态数据;
局部梯度计算模块640,用于基于所述平均状态数据和当前迭代的智能体的迭代数据,计算随机梯度并根据所述随机梯度计算局部梯度;
全局梯度计算模块650,用于计算局部梯度的梯度下降值并根据所述邻居节点数据获得全局梯度;
更新模块660,用于基于所述全局梯度,根据条件梯度算法更新当前迭代的智能体的待优化参数。
具体的,本实施例中,上述具有约束的分布式随机优化模型训练装置的各模块的具体功能可以参照上述具有约束的分布式随机优化模型训练方法中的对应描述,在此不再赘述。
基于上述实施例,本发明还提供了一种智能终端,其原理框图可以如图8所示。上述智能终端包括通过系统总线连接的处理器、存储器、网络接口以及显示屏。其中,该智能终端的处理器用于提供计算和控制能力。该智能终端的存储器包括非易失性存储介质、内存储器。该非易失性存储介质存储有操作系统和具有约束的分布式随机优化模型训练程序。该内存储器为非易失性存储介质中的操作系统和具有约束的分布式随机优化模型训练程序的运行提供环境。该智能终端的网络接口用于与外部的终端通过网络连接通信。该具有约束的分布式随机优化模型训练程序被处理器执行时实现上述任意一种具有约束的分布式随机优化模型训练方法的步骤。该智能终端的显示屏可以是液晶显示屏或者电子墨水显示屏。
本领域技术人员可以理解,图8中示出的原理框图,仅仅是与本发明方案相关的部分结构的框图,并不构成对本发明方案所应用于其上的智能终端的限定,具体的智能终端可以包括比图中所示更多或更少的部件,或者组合某些部件,或者具有不同的部件布置。
在一个实施例中,提供了一种智能终端,上述智能终端包括存储器、处理器以及存储在上述存储器上并可在上述处理器上运行的具有约束的分布式随机优化模型训练程序,上述具有约束的分布式随机优化模型训练程序被上述处理器执行时进行以下操作指令:
循环获取所述智能体集合中的每一个智能体,对所述智能体进行迭代训练并在所述智能体中保存生成的迭代数据,直至所述模型的迭代次数达到设定次数;
所述对智能体进行迭代训练包括如下步骤:
获取训练样本,所述训练样本为随机的样本数据;
基于所述智能体集合中各智能体之间的关联关系,获得与当前迭代的智能体对应的邻居节点数据;
基于所述邻居节点数据,根据平均一致性算法获得平均状态数据;
基于所述平均状态数据和当前迭代的智能体的迭代数据,计算随机梯度并根据所述随机梯度计算局部梯度;
计算局部梯度的梯度下降值并根据所述邻居节点数据获得全局梯度;
基于所述全局梯度,根据条件梯度算法更新当前迭代的智能体的待优化参数。
可选的,所述基于所述平均状态数据和当前迭代的智能体的迭代数据,计算随机梯度并根据所述随机梯度计算局部梯度,包括:
基于所述平均状态数据,获得第一随机梯度;
基于所述迭代数据中的平均状态数据,获得第二随机梯度;
基于预设的衰减步长、所述迭代数据中的局部梯度、所述第一随机梯度、所述第二随机梯度,获得所述局部梯度。
可选的,所述基于预设的衰减步长、所述迭代数据中的局部梯度、所述第一随机梯度、所述第二随机梯度,获得所述局部梯度的表达式为:
Figure 386572DEST_PATH_IMAGE001
其中
Figure 961909DEST_PATH_IMAGE002
为衰减步长,
Figure 298213DEST_PATH_IMAGE003
Figure 617199DEST_PATH_IMAGE004
为迭代数据中的局部梯度,
Figure 722558DEST_PATH_IMAGE005
为第一随机梯度,
Figure 203218DEST_PATH_IMAGE006
为第二随机梯度,
Figure 761238DEST_PATH_IMAGE007
为随机变量。
可选的,所述获取训练样本,包括:
在已获取的训练样本集中随机选取设定数量的训练样本或通过在线采样获得设定数量的训练样本。
可选的,所述基于所述邻居节点数据,根据平均一致性算法获得平均状态数据,包括:
获取所述邻居节点数据中各智能体的待优化参数并根据平均一致性算法获得所述平均状态数据。
可选的,所述计算局部梯度的梯度下降值并根据所述邻居节点数据获得全局梯度,包括:
基于所述局部梯度和所述当前迭代的智能体的迭代数据中的局部梯度,获得所述梯度下降值;
基于所述智能体集合中各智能体之间的关联关系,获得与所述当前迭代的智能体关联的邻居节点;
基于所述梯度下降值和所述邻居节点的迭代数据中的聚合梯度,获得所述当前迭代的智能体的聚合梯度;
基于所述邻居节点数据,根据平均一致性方法获得所述全局梯度。
可选的,所述基于所述全局梯度,根据条件梯度算法更新当前迭代的智能体的待优化参数,包括:
获取目标场景的约束集合;
基于所述全局梯度与所述约束集合的相关性,获得可行方向;
基于所述可行方向与所述平均状态数据的凸组合,更新所述当前迭代的智能体的待优化参数。
本发明实施例还提供一种计算机可读存储介质,上述计算机可读存储介质上存储有具有约束的分布式随机优化模型训练程序,上述具有约束的分布式随机优化模型训练程序被处理器执行时实现本发明实施例提供的任意一种具有约束的分布式随机优化模型训练方法的步骤。
应理解,上述实施例中各步骤的序号大小并不意味着执行顺序的先后,各过程的执行顺序应以其功能和内在逻辑确定,而不应对本发明实施例的实施过程构成任何限定。
所属领域的技术人员可以清楚地了解到,为了描述的方便和简洁,仅以上述各功能单元、模块的划分进行举例说明,实际应用中,可以根据需要而将上述功能分配由不同的功能单元、模块完成,即将上述装置的内部结构划分成不同的功能单元或模块,以完成以上描述的全部或者部分功能。实施例中的各功能单元、模块可以集成在一个处理单元中,也可以是各个单元单独物理存在,也可以两个或两个以上单元集成在一个单元中,上述集成的单元既可以采用硬件的形式实现,也可以采用软件功能单元的形式实现。另外,各功能单元、模块的具体名称也只是为了便于相互区分,并不用于限制本发明的保护范围。上述系统中单元、模块的具体工作过程,可以参考前述方法实施例中的对应过程,在此不再赘述。
在上述实施例中,对各个实施例的描述都各有侧重,某个实施例中没有详述或记载的部分,可以参见其它实施例的相关描述。
本领域普通技术人员可以意识到,结合本文中所公开的实施例描述的各实例的单元及算法步骤,能够以电子硬件、或者计算机软件和电子硬件的结合来实现。这些功能究竟是以硬件还是软件方式来执行,取决于技术方案的特定应用和设计约束条件。专业技术人员可以对每个特定的应用来使用不同法来实现所描述的功能,但是这种实现不应认为超出本发明的范围。
在本发明所提供的实施例中,应该理解到,所揭露的装置/终端设备和方法,可以通过其它的方式实现。例如,以上所描述的装置/终端设备实施例仅仅是示意性的,例如,上述模块或单元的划分,仅仅为一种逻辑功能划分,实际实现时可以由另外的划分方式,例如多个单元或组件可以结合或者可以集成到另一个系统,或一些特征可以忽略,或不执行。
上述集成的模块/单元如果以软件功能单元的形式实现并作为独立的产品销售或使用时,可以存储在一个计算机可读存储介质中。基于这样的理解,本发明实现上述实施例方法中的全部或部分流程,也可以通过计算机程序来指令相关的硬件来完成,上述计算机程序可存储于一计算机可读存储介质中,该计算机程序在被处理器执行时,可实现上述各个方法实施例的步骤。其中,上述计算机程序包括计算机程序代码,上述计算机程序代码可以为源代码形式、对象代码形式、可执行文件或某些中间形式等。上述计算机可读介质可以包括:能够携带上述计算机程序代码的任何实体或装置、记录介质、U盘、移动硬盘、磁碟、光盘、计算机存储器、只读存储器(ROM,Read-Only Memory)、随机存取存储器(RAM,RandomAccess Memory)、电载波信号、电信信号以及软件分发介质等。需要说明的是,上述计算机可读存储介质包含的内容可以根据司法管辖区内立法和专利实践的要求进行适当的增减。
以上所述实施例仅用以说明本发明的技术方案,而非对其限制;尽管参照前述实施例对本发明进行了详细的说明,本领域的普通技术人员应当理解;其依然可以对前述各实施例所记载的技术方案进行修改,或者对其中部分技术特征进行等同替换;而这些修改或者替换,并不是相应技术方案的本质脱离本发明各实施例技术方案的精神和范围,均应包含在本发明的保护范围之内。

Claims (9)

1.具有约束的分布式随机优化模型训练方法,所述方法应用于传感器网络中的定位,模型中包括由至少两个智能体组成的智能体集合,所述智能体为传感器网络中的智能体,其特征在于,所述方法包括:
循环获取所述智能体集合中的每一个智能体,对所述智能体进行迭代训练并在所述智能体中保存生成的迭代数据,直至所述模型的迭代次数达到设定次数;
所述对智能体进行迭代训练包括如下步骤:
获取训练样本,所述训练样本为随机的样本数据,所述随机的样本数据为在训练集中随机选取的样本数据或在线采样获取的样本数据;
基于所述智能体集合中各智能体之间的关联关系,获得与当前迭代的智能体对应的邻居节点数据,所述邻居节点数据包括各个智能体之间的连接权重、各个智能体保存的上一次迭代时获得的迭代数据;
基于所述邻居节点数据,根据平均一致性算法获得平均状态数据;
基于所述平均状态数据和当前迭代的智能体的迭代数据,计算随机梯度并根据所述随机梯度计算局部梯度;
计算局部梯度的梯度下降值并根据所述邻居节点数据获得全局梯度;
基于所述全局梯度,根据条件梯度算法更新当前迭代的智能体的待优化参数。
2.如权利要求1所述的具有约束的分布式随机优化模型训练方法,其特征在于,所述基于所述平均状态数据和当前迭代的智能体的迭代数据,计算随机梯度并根据所述随机梯度计算局部梯度,包括:
基于所述平均状态数据,获得第一随机梯度;
基于所述迭代数据中的平均状态数据,获得第二随机梯度;
基于预设的衰减步长、所述迭代数据中的局部梯度、所述第一随机梯度、所述第二随机梯度,获得所述局部梯度。
3.如权利要求2所述的具有约束的分布式随机优化模型训练方法,其特征在于,所述基于预设的衰减步长、所述迭代数据中的局部梯度、所述第一随机梯度、所述第二随机梯度,获得所述局部梯度的表达式为:
Figure DEST_PATH_IMAGE002
其中
Figure DEST_PATH_IMAGE004
为衰减步长,
Figure DEST_PATH_IMAGE006
Figure DEST_PATH_IMAGE008
为迭代数据中的局部梯度,
Figure DEST_PATH_IMAGE010
为第一随机梯度,
Figure DEST_PATH_IMAGE012
为第二随机梯度,
Figure DEST_PATH_IMAGE014
为随机变量。
4.如权利要求1所述的具有约束的分布式随机优化模型训练方法,其特征在于,所述基于所述邻居节点数据,根据平均一致性算法获得平均状态数据,包括:
获取所述邻居节点数据中各智能体的待优化参数并根据平均一致性算法获得所述平均状态数据。
5.如权利要求1所述的具有约束的分布式随机优化模型训练方法,其特征在于,所述计算局部梯度的梯度下降值并根据所述邻居节点数据获得全局梯度,包括:
基于所述局部梯度和所述当前迭代的智能体的迭代数据中的局部梯度,获得所述梯度下降值;
基于所述智能体集合中各智能体之间的关联关系,获得与所述当前迭代的智能体关联的邻居节点;
基于所述梯度下降值和所述邻居节点的迭代数据中的聚合梯度,获得所述当前迭代的智能体的聚合梯度;
基于所述邻居节点数据,根据平均一致性方法获得所述全局梯度。
6.如权利要求1所述的具有约束的分布式随机优化模型训练方法,其特征在于,所述基于所述全局梯度,根据条件梯度算法更新当前迭代的智能体的待优化参数,包括:
获取目标场景的约束集合;
基于所述全局梯度与所述约束集合的相关性,获得可行方向;
基于所述可行方向与所述平均状态数据的凸组合,更新所述当前迭代的智能体的待优化参数。
7.具有约束的分布式随机优化模型训练装置,应用于传感器网络中的定位,其特征在于,所述装置包括:
迭代模块,用于循环获取智能体集合中的每一个智能体,对所述智能体进行迭代训练并在所述智能体中保存生成的迭代数据,直至模型的迭代次数达到设定次数,所述智能体为传感器网络中的智能体;
样本数据获取模块,用于获取训练样本,所述训练样本为随机的样本数据,所述随机的样本数据为在训练集中随机选取的样本数据或在线采样获取的样本数据;
邻居节点数据获取模块,用于基于所述智能体集合中各智能体之间的关联关系,获得与当前迭代的智能体对应的邻居节点数据,所述邻居节点数据包括各个智能体之间的连接权重、各个智能体保存的上一次迭代时获得的迭代数据;
平均状态数据计算模块,用于基于所述邻居节点数据,根据平均一致性算法获得平均状态数据;
局部梯度计算模块,用于基于所述平均状态数据和当前迭代的智能体的迭代数据,计算随机梯度并根据所述随机梯度计算局部梯度;
全局梯度计算模块,用于计算局部梯度的梯度下降值并根据所述邻居节点数据获得全局梯度;
更新模块,用于基于所述全局梯度,根据条件梯度算法更新当前迭代的智能体的待优化参数。
8.智能终端,其特征在于,所述智能终端包括存储器、处理器以及存储在所述存储器上并可在所述处理器上运行的具有约束的分布式随机优化模型训练程序,所述具有约束的分布式随机优化模型训练程序被所述处理器执行时实现如权利要求1-6任意一项所述具有约束的分布式随机优化模型训练方法的步骤。
9.计算机可读存储介质,其特征在于,所述计算机可读存储介质上存储有具有约束的分布式随机优化模型训练程序,所述具有约束的分布式随机优化模型训练程序被处理器执行时实现如权利要求1-6任意一项所述具有约束的分布式随机优化模型训练方法的步骤。
CN202210486474.6A 2022-05-06 2022-05-06 具有约束的分布式随机优化模型训练方法、装置及终端 Active CN114580578B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202210486474.6A CN114580578B (zh) 2022-05-06 2022-05-06 具有约束的分布式随机优化模型训练方法、装置及终端

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202210486474.6A CN114580578B (zh) 2022-05-06 2022-05-06 具有约束的分布式随机优化模型训练方法、装置及终端

Publications (2)

Publication Number Publication Date
CN114580578A CN114580578A (zh) 2022-06-03
CN114580578B true CN114580578B (zh) 2022-08-23

Family

ID=81769205

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202210486474.6A Active CN114580578B (zh) 2022-05-06 2022-05-06 具有约束的分布式随机优化模型训练方法、装置及终端

Country Status (1)

Country Link
CN (1) CN114580578B (zh)

Citations (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN109952582A (zh) * 2018-09-29 2019-06-28 区链通网络有限公司 一种强化学习模型的训练方法、节点、系统及存储介质
WO2019144046A1 (en) * 2018-01-19 2019-07-25 Hyperdyne, Inc. Distributed high performance computing using distributed average consensus
CN111950611A (zh) * 2020-07-30 2020-11-17 西南大学 基于随机梯度追踪技术的大数据二分类分布式优化方法
CN112381218A (zh) * 2020-11-20 2021-02-19 中国人民解放军国防科技大学 一种用于分布式深度学习训练的本地更新方法
WO2022037337A1 (zh) * 2020-08-19 2022-02-24 腾讯科技(深圳)有限公司 机器学习模型的分布式训练方法、装置以及计算机设备

Patent Citations (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
WO2019144046A1 (en) * 2018-01-19 2019-07-25 Hyperdyne, Inc. Distributed high performance computing using distributed average consensus
CN109952582A (zh) * 2018-09-29 2019-06-28 区链通网络有限公司 一种强化学习模型的训练方法、节点、系统及存储介质
CN111950611A (zh) * 2020-07-30 2020-11-17 西南大学 基于随机梯度追踪技术的大数据二分类分布式优化方法
WO2022037337A1 (zh) * 2020-08-19 2022-02-24 腾讯科技(深圳)有限公司 机器学习模型的分布式训练方法、装置以及计算机设备
CN112381218A (zh) * 2020-11-20 2021-02-19 中国人民解放军国防科技大学 一种用于分布式深度学习训练的本地更新方法

Non-Patent Citations (3)

* Cited by examiner, † Cited by third party
Title
Distributed stochastic optimization with gradient tracking over strongly-connected networks;Ran Xin等;《2019 IEEE 58th Conference on Decision and Control (CDC)》;20200312;第8353-8358页 *
基于多智能体网络的分布式优化研究;卢开红;《中国优秀硕士学位论文全文数据库 信息科技辑》;20200229;I140-16 *
基于强化学习的多智能体协同关键技术及应用研究;李盛祥;《中国优秀硕士学位论文全文数据库 信息科技辑》;20220430;I140-35 *

Also Published As

Publication number Publication date
CN114580578A (zh) 2022-06-03

Similar Documents

Publication Publication Date Title
CN109460793B (zh) 一种节点分类的方法、模型训练的方法及装置
CN110276442B (zh) 一种神经网络架构的搜索方法及装置
CN106411896B (zh) 基于apde-rbf神经网络的网络安全态势预测方法
WO2019018375A1 (en) NEURONAL ARCHITECTURE RESEARCH FOR CONVOLUTION NEURAL NETWORKS
CN110009486B (zh) 一种欺诈检测的方法、系统、设备及计算机可读存储介质
CN111259738A (zh) 人脸识别模型构建方法、人脸识别方法及相关装置
CN112633511A (zh) 用于计算量子配分函数的方法、相关装置及程序产品
CN112639841B (zh) 用于在多方策略互动中进行策略搜索的采样方案
CN116112563A (zh) 一种基于流行度预测的双策略自适应缓存替换方法
Bhatnagar et al. Stochastic algorithms for discrete parameter simulation optimization
CN117151208B (zh) 基于自适应学习率的异步联邦学习参数更新方法、电子设备及存储介质
Tembine Mean field stochastic games: Convergence, Q/H-learning and optimality
CN114580578B (zh) 具有约束的分布式随机优化模型训练方法、装置及终端
CN113220466A (zh) 一种基于长短期记忆模型的云服务负载通用预测方法
CN116453585A (zh) mRNA和药物关联的预测方法、装置、终端设备及介质
CN116125279A (zh) 一种电池健康状态的确定方法、装置、设备及存储介质
Liu et al. Online quantification of input model uncertainty by two-layer importance sampling
Ho et al. Adaptive communication for distributed deep learning on commodity GPU cluster
CN113112092A (zh) 一种短期概率密度负荷预测方法、装置、设备和存储介质
JP7331938B2 (ja) 学習装置、推定装置、学習方法及び学習プログラム
CN113836359B (zh) 动态图嵌入方法、装置、电子设备及存储介质
CN109993313A (zh) 样本标签处理方法及装置、社群划分方法及装置
CN110323743B (zh) 一种暂态功角稳定评估历史数据的聚类方法及装置
CN118133936A (zh) 基于光滑化和动量技术的分布式随机非光滑优化方法
CN112163170B (zh) 一种基于虚节点和元学习改进社交网络对齐方法及系统

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant