CN114580578B - 具有约束的分布式随机优化模型训练方法、装置及终端 - Google Patents
具有约束的分布式随机优化模型训练方法、装置及终端 Download PDFInfo
- Publication number
- CN114580578B CN114580578B CN202210486474.6A CN202210486474A CN114580578B CN 114580578 B CN114580578 B CN 114580578B CN 202210486474 A CN202210486474 A CN 202210486474A CN 114580578 B CN114580578 B CN 114580578B
- Authority
- CN
- China
- Prior art keywords
- gradient
- data
- agent
- training
- agents
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
- 238000012549 training Methods 0.000 title claims abstract description 89
- 238000000034 method Methods 0.000 title claims abstract description 79
- 238000005457 optimization Methods 0.000 title claims abstract description 70
- 238000004422 calculation algorithm Methods 0.000 claims description 37
- 238000004364 calculation method Methods 0.000 claims description 21
- 238000003860 storage Methods 0.000 claims description 16
- 230000002776 aggregation Effects 0.000 claims description 9
- 238000004220 aggregation Methods 0.000 claims description 9
- 238000005070 sampling Methods 0.000 claims description 5
- 238000011156 evaluation Methods 0.000 abstract description 2
- 239000003795 chemical substances by application Substances 0.000 description 138
- 230000006870 function Effects 0.000 description 19
- 238000011478 gradient descent method Methods 0.000 description 8
- 238000010586 diagram Methods 0.000 description 7
- 230000008569 process Effects 0.000 description 7
- 238000004590 computer program Methods 0.000 description 6
- 238000005516 engineering process Methods 0.000 description 5
- 238000012360 testing method Methods 0.000 description 5
- 238000009826 distribution Methods 0.000 description 4
- 230000004044 response Effects 0.000 description 4
- 230000003902 lesion Effects 0.000 description 3
- 239000011159 matrix material Substances 0.000 description 3
- 238000006116 polymerization reaction Methods 0.000 description 3
- 102100029469 WD repeat and HMG-box DNA-binding protein 1 Human genes 0.000 description 2
- 101710097421 WD repeat and HMG-box DNA-binding protein 1 Proteins 0.000 description 2
- 238000004891 communication Methods 0.000 description 2
- 230000007423 decrease Effects 0.000 description 2
- 238000013461 design Methods 0.000 description 2
- 230000000694 effects Effects 0.000 description 2
- 238000010801 machine learning Methods 0.000 description 2
- 238000007781 pre-processing Methods 0.000 description 2
- 230000001133 acceleration Effects 0.000 description 1
- 238000013459 approach Methods 0.000 description 1
- 230000002238 attenuated effect Effects 0.000 description 1
- 230000003247 decreasing effect Effects 0.000 description 1
- 238000001514 detection method Methods 0.000 description 1
- 238000011161 development Methods 0.000 description 1
- 239000004973 liquid crystal related substance Substances 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 230000003287 optical effect Effects 0.000 description 1
- 239000012466 permeate Substances 0.000 description 1
- 238000012545 processing Methods 0.000 description 1
- 239000000047 product Substances 0.000 description 1
- 230000002787 reinforcement Effects 0.000 description 1
- 238000011160 research Methods 0.000 description 1
- 238000013468 resource allocation Methods 0.000 description 1
- 238000006467 substitution reaction Methods 0.000 description 1
- 238000012795 verification Methods 0.000 description 1
- XLYOFNOQVPJJNP-UHFFFAOYSA-N water Substances O XLYOFNOQVPJJNP-UHFFFAOYSA-N 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
- G06F18/241—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N20/00—Machine learning
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02D—CLIMATE CHANGE MITIGATION TECHNOLOGIES IN INFORMATION AND COMMUNICATION TECHNOLOGIES [ICT], I.E. INFORMATION AND COMMUNICATION TECHNOLOGIES AIMING AT THE REDUCTION OF THEIR OWN ENERGY USE
- Y02D30/00—Reducing energy consumption in communication networks
- Y02D30/70—Reducing energy consumption in communication networks in wireless communication networks
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Data Mining & Analysis (AREA)
- Evolutionary Computation (AREA)
- Physics & Mathematics (AREA)
- General Physics & Mathematics (AREA)
- Computer Vision & Pattern Recognition (AREA)
- General Engineering & Computer Science (AREA)
- Artificial Intelligence (AREA)
- Evolutionary Biology (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Bioinformatics & Computational Biology (AREA)
- Life Sciences & Earth Sciences (AREA)
- Software Systems (AREA)
- Medical Informatics (AREA)
- Computing Systems (AREA)
- Mathematical Physics (AREA)
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
Abstract
本发明公开了一种具有约束的分布式随机优化模型训练方法、装置及终端,上述方法包括:循环获取所述智能体集合中的每一个智能体,对所述智能体进行迭代训练并在所述智能体中保存生成的迭代数据,直至所述模型的迭代次数达到设定次数;其中,对智能体进行迭代训练时获取随机训练样本,在每次迭代中通过随机选择的无偏随机局部梯度来计算局部目标函数的全局梯度。与现有技术相比,大大降低了梯度评估的成本和计算复杂性,可用于高维、大规模优化问题。
Description
技术领域
本发明涉及机器学习技术领域,尤其涉及的是一种具有约束的分布式随机优化模型训练方法、装置及终端。
背景技术
近年来,随着高科技的蓬勃发展,特别是云计算和大数据等新兴领域的出现。具有约束和随机因素的分布式优化理论和应用得到了越来越多的重视,并逐渐渗透到科学研究、工程应用和社会生活的各个方面。分布式优化是通过多智能体之间的合作协调有效地实现优化的任务,可用来解决许多集中式算法难以胜任的大规模复杂的优化问题。可以创建具有约束的随机优化模型用来解决带有约束和随机因素的优化问题,目前对该优化模型训练时常采用基于投影的随机梯度下降法和条件梯度法。
采用基于投影的随机梯度下降法训练时,在向负随机梯度方向迈出一步后,迭代将被投射回约束集上。当执行投影的计算成本较低时(如投影单纯形上),这种方法是有效的。但在许多实际情况下,如处理迹范数球、基多面体等可行域情况时,投影到约束集上的成本可能很高,计算效能低。
条件梯度法,通过求解约束集上的线性最小化子问题来避免投影的计算以获得一个条件梯度,然后通过当前迭代和条件梯度的凸组合来更新下一步迭代。虽然条件梯度法及其变体可以解决具有约束和随机因素的随机优化问题,但是只能应用在集中式环境,无法应用在分布式环境中,并且收敛速率不快。
发明内容
本发明的主要目的在于提供一种具有约束的分布式随机优化模型训练方法、装置、智能终端及存储介质,能够对随机优化模型进行训练,以在分布式环境中解决具有复杂约束和随机因素的优化问题。
为了实现上述目的,本发明第一方面提供一种具有约束的分布式随机优化模型训练方法,所述模型中包括由至少两个智能体组成的智能体集合,所述方法包括:
循环获取所述智能体集合中的每一个智能体,对所述智能体进行迭代训练并在所述智能体中保存生成的迭代数据,直至所述模型的迭代次数达到设定次数;
所述对智能体进行迭代训练包括如下步骤:
获取训练样本,所述训练样本为随机的样本数据;
基于所述智能体集合中各智能体之间的关联关系,获得与当前迭代的智能体对应的邻居节点数据;
基于所述邻居节点数据,根据平均一致性算法获得平均状态数据;
基于所述平均状态数据和当前迭代的智能体的迭代数据,计算随机梯度并根据所述随机梯度计算局部梯度;
计算局部梯度的梯度下降值并根据所述邻居节点数据获得全局梯度;
基于所述全局梯度,根据条件梯度算法更新当前迭代的智能体的待优化参数。
可选的,所述基于所述平均状态数据和当前迭代的智能体的迭代数据,计算随机梯度并根据所述随机梯度计算局部梯度,包括:
基于所述平均状态数据,获得第一随机梯度;
基于所述迭代数据中的平均状态数据,获得第二随机梯度;
基于预设的衰减步长、所述迭代数据中的局部梯度、所述第一随机梯度、所述第二随机梯度,获得所述局部梯度。
可选的,所述基于预设的衰减步长、所述迭代数据中的局部梯度、所述第一随机梯度、所述第二随机梯度,获得所述局部梯度的表达式为:
可选的,所述获取训练样本,包括:
在已获取的训练样本集中随机选取设定数量的训练样本或通过在线采样获得训练样本。
可选的,所述基于所述邻居节点数据,根据平均一致性算法获得平均状态数据,包括:
获取所述邻居节点数据中各智能体的待优化参数并根据平均一致性算法获得所述平均状态数据。
可选的,所述计算局部梯度的梯度下降值并根据所述邻居节点数据获得全局梯度,包括:
基于所述局部梯度和所述当前迭代的智能体的迭代数据中的局部梯度,获得所述梯度下降值;
基于所述智能体集合中各智能体之间的关联关系,获得与所述当前迭代的智能体关联的邻居节点;
基于所述梯度下降值和所述邻居节点的迭代数据中的聚合梯度,获得所述当前迭代的智能体的聚合梯度;
基于所述邻居节点数据,根据平均一致性方法获得所述全局梯度。
可选的,所述基于所述全局梯度,根据条件梯度算法更新当前迭代的智能体的待优化参数,包括:
获取目标场景的约束集合;
基于所述全局梯度与所述约束集合的相关性,获得可行方向;
基于所述可行方向与所述平均状态数据的凸组合,更新所述当前迭代的智能体的待优化参数。
本发明第二方面提供一种具有约束的分布式随机优化模型训练装置,其中,上述装置包括:
迭代模块,用于循环获取所述智能体集合中的每一个智能体,对所述智能体进行迭代训练并在所述智能体中保存生成的迭代数据,直至所述模型的迭代次数达到设定次数;
样本数据获取模块,用于获取训练样本,所述训练样本为随机的样本数据;
邻居节点数据获取模块,用于基于所述智能体集合中各智能体之间的关联关系,获得与当前迭代的智能体对应的邻居节点数据;
平均状态数据计算模块,用于基于所述邻居节点数据,根据平均一致性算法获得平均状态数据;
局部梯度计算模块,用于基于所述平均状态数据和当前迭代的智能体的迭代数据,计算随机梯度并根据所述随机梯度计算局部梯度;
全局梯度计算模块,用于计算局部梯度的梯度下降值并根据所述邻居节点数据获得全局梯度;
更新模块,用于基于所述全局梯度,根据条件梯度算法更新当前迭代的智能体的待优化参数。
本发明第三方面提供一种智能终端,上述智能终端包括存储器、处理器以及存储在上述存储器上并可在上述处理器上运行的具有约束的分布式随机优化模型训练程序,上述具有约束的分布式随机优化模型训练程序被上述处理器执行时实现任意一项上述具有约束的分布式随机优化模型训练方法的步骤。
本发明第四方面提供一种计算机可读存储介质,上述计算机可读存储介质上存储有具有约束的分布式随机优化模型训练程序,上述具有约束的分布式随机优化模型训练程序被处理器执行时实现任意一项上述具有约束的分布式随机优化模型训练方法的步骤。
由上可见,与现有技术相比,本发明在训练每个智能体时使用随机的样本数据,采用随机梯度来计算梯度下降值并根据模型中各个智能体之间的关联关系来更新智能体的待优化参数,每次迭代时只需要计算一次样本梯度,也不需要存储样本的梯度信息或状态信息。因此,本发明的模型训练方法不仅可以在分布式环境中解决具有复杂约束和随机因素的优化问题,而且收敛速度快、计算效能高、存储开销小。
附图说明
为了更清楚地说明本发明实施例中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动性的前提下,还可以根据这些附图获得其它的附图。
图1是本发明实施例提供的具有约束的分布式随机优化模型训练方法中的智能体训练流程示意图;
图2是图1实施例的步骤S500具体流程示意图;
图3是图1实施例的步骤S600具体流程示意图;
图4是图1实施例的步骤S400具体流程示意图;
图5是图1实施例的流程框图;
图6是图1实施例的测试集准确率曲线图;
图7是本发明实施例提供的具有约束的分布式随机优化模型训练装置的结构示意图;
图8是本发明实施例提供的一种智能终端的内部结构原理框图。
具体实施方式
以下描述中,为了说明而不是为了限定,提出了诸如特定系统结构、技术之类的具体细节,以便透彻理解本发明实施例。然而,本领域的技术人员应当清楚,在没有这些具体细节的其它实施例中也可以实现本发明。在其它情况下,省略对众所周知的系统、装置、电路以及方法的详细说明,以免不必要的细节妨碍本发明的描述。
应当理解,当在本说明书和所附权利要求书中使用时,术语“包括”指示所描述特征、整体、步骤、操作、元素和/或组件的存在,但并不排除一个或多个其它特征、整体、步骤、操作、元素、组件和/或其集合的存在或添加。
还应当理解,在本发明说明书中所使用的术语仅仅是出于描述特定实施例的目的而并不意在限制本发明。如在本发明说明书和所附权利要求书中所使用的那样,除非上下文清楚地指明其它情况,否则单数形式的“一”、“一个”及“该”意在包括复数形式。
还应当进一步理解,在本发明说明书和所附权利要求书中使用的术语“和/或”是指相关联列出的项中的一个或多个的任何组合以及所有可能组合,并且包括这些组合。
如在本说明书和所附权利要求书中所使用的那样,术语“如果”可以依据上下文被解释为“当…时”或“一旦”或“响应于确定”或“响应于检测到”。类似的,短语“如果确定”或“如果检测到[所描述条件或事件]”可以依据上下文被解释为意指“一旦确定”或“响应于确定”或“一旦检测到[所描述的条件或事件]”或“响应于检测到[所描述条件或事件]”。
下面结合本发明实施例的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本发明的一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其它实施例,都属于本发明保护的范围。
在下面的描述中阐述了很多具体细节以便于充分理解本发明,但是本发明还可以采用其它不同于在此描述的其它方式来实施,本领域技术人员可以在不违背本发明内涵的情况下做类似推广,因此本发明不受下面公开的具体实施例的限制。
具有约束和随机因素的分布式随机优化问题广泛存在于各种工程应用中,如无人系统、无线通信、分布式机器学习、多智能体强化学习等等。由于在分布式网络或分布式控制等目标场景(比如资源分配, 传感器网络中的定位)中的可行域可以是复杂的迹范数球、基本多面体等,现有的基于投影的随机梯度下降方法很难处理或无法处理这些优化问题。
示例性方法
如图1所示,本发明实施例提供一种具有约束的分布式随机优化模型训练方法,使用该方法创建用于在分布式环境中解决具有复杂约束集的随机优化问题的网络模型。在分布式环境中,每个个体往往有一个代价函数,且整个网络的代价由这些个体的代价函数和来表示。网络模型的目的是通过个体间的局部信息交流而完成整个网络模型代价函数的优化。其中每个个体只知道自己的代价函数,在给定的分布式优化算法下得出保证其收敛的条件。上述个体也称为智能体。
用于分布式环境中的网络模型包括至少两个智能体,所有的智能体组成智能体集合。通过循环获取智能体集合中的每一个智能体,对获取到的智能体进行迭代训练并在智能体中保存生成的迭代数据,直至网络模型的迭代次数达到设定次数,完成网络模型的创建,应用在目标场景中。本领域技术人员熟知,网络模型创建好后,还可以通过预先设定的测试数据对网络模型进行测试。
具体的,上述迭代训练包括如下步骤:
步骤S100:获取训练样本,所述训练样本为随机的样本数据;
具体的,本发明解决的是带有随机因素的优化问题,在实际问题中并不知道样本的真实概率分布,因此传统随机优化方法(SAG、SAGA、SVRG或其变体)无法解决该问题。而本发明方案无需获取到所有样本的信息或者随机样本的分布情况,训练样本可以为随机的样本数据,每次迭代只需随机选取一个样本或是在线获取一个样本即可。具体实施时, 可以在已获取到的训练样本集中随机选取设定数量的训练样本;或者在线采样获取训练样本。因此本发明的训练方法既可用于带有随机因素的优化问题(stochastic),也可用于有限和问题(finite-sum)中。
在本实施例中,网络模型中设定了10个智能体,以公开数据集a9a作为训练数据集,每次迭代训练时,从该训练数据集中随机获取10%的样本数据作为训练数据。
步骤S200:基于智能体集合中各智能体之间的关联关系,获得与当前迭代的智能体对应的邻居节点数据;
其中,邻居节点数据包括各个智能体之间的连接权重,各个智能体保存的上一次迭代时获得的迭代数据。迭代数据包括下述迭代步骤中生成的平均状态数据、局部梯度和聚合梯度等。
具体的,本发明中将多智能体表示为一个由个智能体组成的集合。集合中的智能体通过通信网络互相交换信息,其中表示网络中边的集合,。根据边的集合,可以获得整个模型的权重连接矩阵,权重连接矩阵是一个双随机矩阵,且其中的行列和均为1。表示第个智能体和第个智能体之间的连接权重。因此,可以通过该集合获得与每个智能体关联的所有邻居节点、连接权重等数据。
步骤S300:基于邻居节点数据,根据平均一致性算法获得平均状态数据;
其中,平均一致性算法用于根据一个智能体的所有邻居节点数据预估当前迭代的智能体的平均状态。具体的,通过获取当前迭代的智能体的邻居节点数据中的待优化参数并根据平均一致性算法获得当前迭代的智能体的平均状态数据。
例如:第个智能体的平均状态数据具体可以表示为:,其中,表示第个智能体的邻居节点集合,表示第个智能体和第个智能体之间的连接权重,为第个智能体第次迭代时的待优化参数,为第个智能体第次迭代时的平均状态数据。
步骤S400:基于平均状态数据和当前迭代的智能体的迭代数据,计算随机梯度并根据所述随机梯度计算局部梯度;
具体的,由于随机变量的分布未知,只能得到目标函数的随机梯度,即对于给定的和随机变量,可以得到实际梯度的无偏估计。众所周知,条件梯度算法的朴素随机实现是可以将得到的随机梯度替换实际梯度,但是由于无法消失的方差的存在导致算法很可能会发散。为了解决这个问题,本发明利用递归动量思想设计局部梯度迭代公式,局部梯度不仅与当前迭代样本的梯度有关,且与上一次局部梯度有关。
基于解决有约束的随机优化问题,计算局部梯度时本发明创新性地采用了随机变量并根据递归动量策略纳入上一次迭代的局部梯度,不仅可以消除随机梯度方差带来的影响,而且可以起到动量加速的效果,保证在目标函数是凸函数的情况下得到与基于投影的随机梯度下降法一致的收敛速率。
步骤S500:计算局部梯度的梯度下降值并根据邻居节点数据获得全局梯度;
具体的,本发明基于传统梯度跟踪技术的原理,利用上一次迭代的迭代数据来计算全局梯度。首先将当前迭代的智能体的局部梯度与该智能体保存的迭代数据中的局部梯度比较,计算出局部梯度的梯度下降值,根据梯度下降值结合邻居节点数据更新当前迭代的智能体的变量和全局梯度,使得当前迭代的智能体的变量更新体现对全局梯度的追踪。
在本实施例中,如图2所示,获得全局梯度具体包括如下步骤:
步骤S510:基于局部梯度和迭代智能体的迭代数据中的局部梯度,获得梯度下降值;
步骤S520:基于智能体之间的关联信息,获得迭代智能体关联的邻居节点;
步骤S530:基于梯度下降值和邻居节点的迭代数据中的聚合梯度,获得迭代智能体的聚合梯度;
具体的,根据梯度跟踪方法,计算聚合梯度。具体的计算公式为:
其中,为第个智能体第次迭代的聚合梯度,为第个智能体第次迭代时的聚合梯度,表示第个智能体的邻居节点集合,表示第个智能体和第个智能体之间的连接权重,为第个智能体第次迭代时的局部梯度,为第个智能体第次迭代时的局部梯度。
步骤S540:基于邻居节点数据,根据平均一致性方法获得全局梯度。
具体的,根据聚合梯度,采用平均一致性算法计算全局梯度。具体计算公式为:,其中为第个智能体第次迭代时的全局梯度,为第个智能体第次迭代的聚合梯度,表示第个智能体的邻居节点集合,表示第个智能体和第个智能体之间的连接权重。
步骤S600:基于全局梯度,根据条件梯度算法更新当前迭代的智能体的待优化参数;
在本实施例中,如图3所示,更新当前迭代的智能体的待优化参数具体包括如下步骤:
步骤S610:获取目标场景的约束集合;
步骤S620:基于全局梯度与约束集合的相关性,获得可行方向;
步骤S630:基于可行方向与平均状态数据的凸组合,更新迭代智能体的待优化参数。
由上所述,本发明可应用于具有复杂约束集的随机环境中,通过利用条件梯度技术避免了代价高的投影计算,大大提高计算效能。同时,本发明在每次迭代中通过随机选择的无偏随机局部梯度来计算局部目标函数的全局梯度,大大降低了梯度评估的成本和计算复杂性,可用于高维、大规模优化问题。并且每次迭代随机选取部分样本(样本个数大于等于1)来计算样本梯度,无需对所有样本或者批量样本进行梯度计算,更适用于高维、大规模优化问题。该方法也可以用于有限和问题中,与算法SAG和SAGA不同,该方法无需为每个样本维护一个旧的梯度,具有更小的存储开销;该方法通过利用动量更新思想,对于随机凸优化(stochastic)问题可以达到与基于投影的随机梯度法一致的收敛速率。
在一些实施例中,如图4所示,上述步骤S400中计算局部梯度,具体包括步骤:
步骤S410:基于平均状态数据,获得第一随机梯度;
步骤S420:基于迭代数据中的平均状态数据,获得第二随机梯度;
步骤S430:基于预设的衰减步长、迭代数据中的局部梯度、第一随机梯度、第二随机梯度,获得局部梯度;
具体的,首先根据随机变量和平均状态数据,计算第一随机梯度,即:;根据迭代数据中的平均状态数据(上一次迭代时获得的平均状态数据)和随机变量,获得第二随机梯度。然后基于预设的衰减步长、迭代数据中的局部梯度、第一随机梯度、第二随机梯度,获得局部梯度,具体表达式为:
本发明利用动量更新方法不仅可以消除随机梯度方差带来的影响,也可以起到加速的效果,理论推导得出该方法在处理随机凸优化(stochastic)问题时具有与基于投影的随机梯度下降法一致的收敛速率。同时,该方法无需为每个样本存储一个样本梯度或状态信息,大大提高了算法收敛性能并降低了存储开销。
下表提供了本发明的具有约束的分布式随机优化方法(DMFW)与随机梯度下降法(RSA、RSG、SPPDM)以及随机无投影方法(OFW、STORC、SFW、NSFW)在解决随机优化问题时的收敛速率对比。从表可以看出,本发明方法DMFW与传统无投影算法相比具有更快的收敛速率,且与随机梯度下降法具有一致的收敛速率。
优化方法 | 适用环境 | 有无投影 | 目标函数 | 收敛速率 |
RSA | 集中式 | 无约束 | 光滑凸 | |
RSG | 集中式 | 有投影 | 光滑非凸 | |
SPPDM | 分布式 | 无约束 | 非光滑非凸 | |
OFW | 集中式 | 无投影 | 光滑凸 | |
STORC | 集中式 | 无投影 | 光滑凸 | |
SFW | 集中式 | 无投影 | 光滑凸 | |
NSFW | 集中式 | 无投影 | 光滑非凸 | |
DMFW(本发明方法) | 分布式 | 无投影 | 光滑凸 |
也就是说,本发明利用梯度跟踪技术将条件梯度无投影方法扩展到分布式上,从而避免了投影计算,提高了算法的计算效能。不仅可以在梯度近似下衰减噪声,而且可以在凸情况下获得与基于投影的梯度下降法相当的收敛保证,收敛速度更快、计算复杂度更低和存储开销更小。
参考图5,以下以对公开数据集a9a进行二分类在线学习为例,对本发明的具体实施过程做详细描述。
二分类操作需要解决的是一个分布式凸优化问题,其可以用下述公式来表述:
其中,数据预处理的具体方法:将二分类数据标签数值改为1和-1,1表示正样本,-1表示负样本。由于a9a数据集的正负样本比例为1:3左右,采用smote方法将其调整为1:1左右,具体步骤包括:在7800个正样本里面,随机选择一个正样本点,每次循环找到距离该点最近的m个点,随机选其中一个连线,再在连线上随机找1个点作为插值点,重复32561-7800=24761次,得到smote后正负样本均匀的数据,。
由于智能体个数为10,将训练集均匀的分成10份,即:以及。每个智能体每次随机选取训练集中的10%的数据进行训练。即每次迭代时只能随机获取部分样本信息,无法知道所有的样本信息,包括样本总数、除本次迭代抽取的样本之外的其他样本特征和标签等信息。
对所有的智能体进行第2次迭代,第3次迭代,……,第次迭代。其中,每个智能体每次随机选取训练集中的10%的数据进行每次迭代训练。对所有智能体完成迭代一次后,保存当前迭代的每个智能体的数据,,,并令,然后进行下一次迭代。
以下对本方法的收敛性的验证过程进行详细叙述:
假设条件:
引理3:设假设1-4成立。那么有
根据引理4,可以得到如下收敛性定理
下面对定理1进行证明。
示例性设备
如图7所示,对应于上述具有约束的分布式随机优化模型训练方法,本发明实施例还提供一种具有约束的分布式随机优化模型训练装置,上述具有约束的分布式随机优化模型训练装置包括:
迭代模块600,用于循环获取所述智能体集合中的每一个智能体,对所述智能体进行迭代训练并在所述智能体中保存生成的迭代数据,直至所述模型的迭代次数达到设定次数;
样本数据获取模块610,用于获取训练样本,所述训练样本为随机的样本数据;
邻居节点数据获取模块620,用于基于所述智能体集合中各智能体之间的关联关系,获得与当前迭代的智能体对应的邻居节点数据;
平均状态数据计算模块630,用于基于所述邻居节点数据,根据平均一致性算法获得平均状态数据;
局部梯度计算模块640,用于基于所述平均状态数据和当前迭代的智能体的迭代数据,计算随机梯度并根据所述随机梯度计算局部梯度;
全局梯度计算模块650,用于计算局部梯度的梯度下降值并根据所述邻居节点数据获得全局梯度;
更新模块660,用于基于所述全局梯度,根据条件梯度算法更新当前迭代的智能体的待优化参数。
具体的,本实施例中,上述具有约束的分布式随机优化模型训练装置的各模块的具体功能可以参照上述具有约束的分布式随机优化模型训练方法中的对应描述,在此不再赘述。
基于上述实施例,本发明还提供了一种智能终端,其原理框图可以如图8所示。上述智能终端包括通过系统总线连接的处理器、存储器、网络接口以及显示屏。其中,该智能终端的处理器用于提供计算和控制能力。该智能终端的存储器包括非易失性存储介质、内存储器。该非易失性存储介质存储有操作系统和具有约束的分布式随机优化模型训练程序。该内存储器为非易失性存储介质中的操作系统和具有约束的分布式随机优化模型训练程序的运行提供环境。该智能终端的网络接口用于与外部的终端通过网络连接通信。该具有约束的分布式随机优化模型训练程序被处理器执行时实现上述任意一种具有约束的分布式随机优化模型训练方法的步骤。该智能终端的显示屏可以是液晶显示屏或者电子墨水显示屏。
本领域技术人员可以理解,图8中示出的原理框图,仅仅是与本发明方案相关的部分结构的框图,并不构成对本发明方案所应用于其上的智能终端的限定,具体的智能终端可以包括比图中所示更多或更少的部件,或者组合某些部件,或者具有不同的部件布置。
在一个实施例中,提供了一种智能终端,上述智能终端包括存储器、处理器以及存储在上述存储器上并可在上述处理器上运行的具有约束的分布式随机优化模型训练程序,上述具有约束的分布式随机优化模型训练程序被上述处理器执行时进行以下操作指令:
循环获取所述智能体集合中的每一个智能体,对所述智能体进行迭代训练并在所述智能体中保存生成的迭代数据,直至所述模型的迭代次数达到设定次数;
所述对智能体进行迭代训练包括如下步骤:
获取训练样本,所述训练样本为随机的样本数据;
基于所述智能体集合中各智能体之间的关联关系,获得与当前迭代的智能体对应的邻居节点数据;
基于所述邻居节点数据,根据平均一致性算法获得平均状态数据;
基于所述平均状态数据和当前迭代的智能体的迭代数据,计算随机梯度并根据所述随机梯度计算局部梯度;
计算局部梯度的梯度下降值并根据所述邻居节点数据获得全局梯度;
基于所述全局梯度,根据条件梯度算法更新当前迭代的智能体的待优化参数。
可选的,所述基于所述平均状态数据和当前迭代的智能体的迭代数据,计算随机梯度并根据所述随机梯度计算局部梯度,包括:
基于所述平均状态数据,获得第一随机梯度;
基于所述迭代数据中的平均状态数据,获得第二随机梯度;
基于预设的衰减步长、所述迭代数据中的局部梯度、所述第一随机梯度、所述第二随机梯度,获得所述局部梯度。
可选的,所述基于预设的衰减步长、所述迭代数据中的局部梯度、所述第一随机梯度、所述第二随机梯度,获得所述局部梯度的表达式为:
可选的,所述获取训练样本,包括:
在已获取的训练样本集中随机选取设定数量的训练样本或通过在线采样获得设定数量的训练样本。
可选的,所述基于所述邻居节点数据,根据平均一致性算法获得平均状态数据,包括:
获取所述邻居节点数据中各智能体的待优化参数并根据平均一致性算法获得所述平均状态数据。
可选的,所述计算局部梯度的梯度下降值并根据所述邻居节点数据获得全局梯度,包括:
基于所述局部梯度和所述当前迭代的智能体的迭代数据中的局部梯度,获得所述梯度下降值;
基于所述智能体集合中各智能体之间的关联关系,获得与所述当前迭代的智能体关联的邻居节点;
基于所述梯度下降值和所述邻居节点的迭代数据中的聚合梯度,获得所述当前迭代的智能体的聚合梯度;
基于所述邻居节点数据,根据平均一致性方法获得所述全局梯度。
可选的,所述基于所述全局梯度,根据条件梯度算法更新当前迭代的智能体的待优化参数,包括:
获取目标场景的约束集合;
基于所述全局梯度与所述约束集合的相关性,获得可行方向;
基于所述可行方向与所述平均状态数据的凸组合,更新所述当前迭代的智能体的待优化参数。
本发明实施例还提供一种计算机可读存储介质,上述计算机可读存储介质上存储有具有约束的分布式随机优化模型训练程序,上述具有约束的分布式随机优化模型训练程序被处理器执行时实现本发明实施例提供的任意一种具有约束的分布式随机优化模型训练方法的步骤。
应理解,上述实施例中各步骤的序号大小并不意味着执行顺序的先后,各过程的执行顺序应以其功能和内在逻辑确定,而不应对本发明实施例的实施过程构成任何限定。
所属领域的技术人员可以清楚地了解到,为了描述的方便和简洁,仅以上述各功能单元、模块的划分进行举例说明,实际应用中,可以根据需要而将上述功能分配由不同的功能单元、模块完成,即将上述装置的内部结构划分成不同的功能单元或模块,以完成以上描述的全部或者部分功能。实施例中的各功能单元、模块可以集成在一个处理单元中,也可以是各个单元单独物理存在,也可以两个或两个以上单元集成在一个单元中,上述集成的单元既可以采用硬件的形式实现,也可以采用软件功能单元的形式实现。另外,各功能单元、模块的具体名称也只是为了便于相互区分,并不用于限制本发明的保护范围。上述系统中单元、模块的具体工作过程,可以参考前述方法实施例中的对应过程,在此不再赘述。
在上述实施例中,对各个实施例的描述都各有侧重,某个实施例中没有详述或记载的部分,可以参见其它实施例的相关描述。
本领域普通技术人员可以意识到,结合本文中所公开的实施例描述的各实例的单元及算法步骤,能够以电子硬件、或者计算机软件和电子硬件的结合来实现。这些功能究竟是以硬件还是软件方式来执行,取决于技术方案的特定应用和设计约束条件。专业技术人员可以对每个特定的应用来使用不同法来实现所描述的功能,但是这种实现不应认为超出本发明的范围。
在本发明所提供的实施例中,应该理解到,所揭露的装置/终端设备和方法,可以通过其它的方式实现。例如,以上所描述的装置/终端设备实施例仅仅是示意性的,例如,上述模块或单元的划分,仅仅为一种逻辑功能划分,实际实现时可以由另外的划分方式,例如多个单元或组件可以结合或者可以集成到另一个系统,或一些特征可以忽略,或不执行。
上述集成的模块/单元如果以软件功能单元的形式实现并作为独立的产品销售或使用时,可以存储在一个计算机可读存储介质中。基于这样的理解,本发明实现上述实施例方法中的全部或部分流程,也可以通过计算机程序来指令相关的硬件来完成,上述计算机程序可存储于一计算机可读存储介质中,该计算机程序在被处理器执行时,可实现上述各个方法实施例的步骤。其中,上述计算机程序包括计算机程序代码,上述计算机程序代码可以为源代码形式、对象代码形式、可执行文件或某些中间形式等。上述计算机可读介质可以包括:能够携带上述计算机程序代码的任何实体或装置、记录介质、U盘、移动硬盘、磁碟、光盘、计算机存储器、只读存储器(ROM,Read-Only Memory)、随机存取存储器(RAM,RandomAccess Memory)、电载波信号、电信信号以及软件分发介质等。需要说明的是,上述计算机可读存储介质包含的内容可以根据司法管辖区内立法和专利实践的要求进行适当的增减。
以上所述实施例仅用以说明本发明的技术方案,而非对其限制;尽管参照前述实施例对本发明进行了详细的说明,本领域的普通技术人员应当理解;其依然可以对前述各实施例所记载的技术方案进行修改,或者对其中部分技术特征进行等同替换;而这些修改或者替换,并不是相应技术方案的本质脱离本发明各实施例技术方案的精神和范围,均应包含在本发明的保护范围之内。
Claims (9)
1.具有约束的分布式随机优化模型训练方法,所述方法应用于传感器网络中的定位,模型中包括由至少两个智能体组成的智能体集合,所述智能体为传感器网络中的智能体,其特征在于,所述方法包括:
循环获取所述智能体集合中的每一个智能体,对所述智能体进行迭代训练并在所述智能体中保存生成的迭代数据,直至所述模型的迭代次数达到设定次数;
所述对智能体进行迭代训练包括如下步骤:
获取训练样本,所述训练样本为随机的样本数据,所述随机的样本数据为在训练集中随机选取的样本数据或在线采样获取的样本数据;
基于所述智能体集合中各智能体之间的关联关系,获得与当前迭代的智能体对应的邻居节点数据,所述邻居节点数据包括各个智能体之间的连接权重、各个智能体保存的上一次迭代时获得的迭代数据;
基于所述邻居节点数据,根据平均一致性算法获得平均状态数据;
基于所述平均状态数据和当前迭代的智能体的迭代数据,计算随机梯度并根据所述随机梯度计算局部梯度;
计算局部梯度的梯度下降值并根据所述邻居节点数据获得全局梯度;
基于所述全局梯度,根据条件梯度算法更新当前迭代的智能体的待优化参数。
2.如权利要求1所述的具有约束的分布式随机优化模型训练方法,其特征在于,所述基于所述平均状态数据和当前迭代的智能体的迭代数据,计算随机梯度并根据所述随机梯度计算局部梯度,包括:
基于所述平均状态数据,获得第一随机梯度;
基于所述迭代数据中的平均状态数据,获得第二随机梯度;
基于预设的衰减步长、所述迭代数据中的局部梯度、所述第一随机梯度、所述第二随机梯度,获得所述局部梯度。
4.如权利要求1所述的具有约束的分布式随机优化模型训练方法,其特征在于,所述基于所述邻居节点数据,根据平均一致性算法获得平均状态数据,包括:
获取所述邻居节点数据中各智能体的待优化参数并根据平均一致性算法获得所述平均状态数据。
5.如权利要求1所述的具有约束的分布式随机优化模型训练方法,其特征在于,所述计算局部梯度的梯度下降值并根据所述邻居节点数据获得全局梯度,包括:
基于所述局部梯度和所述当前迭代的智能体的迭代数据中的局部梯度,获得所述梯度下降值;
基于所述智能体集合中各智能体之间的关联关系,获得与所述当前迭代的智能体关联的邻居节点;
基于所述梯度下降值和所述邻居节点的迭代数据中的聚合梯度,获得所述当前迭代的智能体的聚合梯度;
基于所述邻居节点数据,根据平均一致性方法获得所述全局梯度。
6.如权利要求1所述的具有约束的分布式随机优化模型训练方法,其特征在于,所述基于所述全局梯度,根据条件梯度算法更新当前迭代的智能体的待优化参数,包括:
获取目标场景的约束集合;
基于所述全局梯度与所述约束集合的相关性,获得可行方向;
基于所述可行方向与所述平均状态数据的凸组合,更新所述当前迭代的智能体的待优化参数。
7.具有约束的分布式随机优化模型训练装置,应用于传感器网络中的定位,其特征在于,所述装置包括:
迭代模块,用于循环获取智能体集合中的每一个智能体,对所述智能体进行迭代训练并在所述智能体中保存生成的迭代数据,直至模型的迭代次数达到设定次数,所述智能体为传感器网络中的智能体;
样本数据获取模块,用于获取训练样本,所述训练样本为随机的样本数据,所述随机的样本数据为在训练集中随机选取的样本数据或在线采样获取的样本数据;
邻居节点数据获取模块,用于基于所述智能体集合中各智能体之间的关联关系,获得与当前迭代的智能体对应的邻居节点数据,所述邻居节点数据包括各个智能体之间的连接权重、各个智能体保存的上一次迭代时获得的迭代数据;
平均状态数据计算模块,用于基于所述邻居节点数据,根据平均一致性算法获得平均状态数据;
局部梯度计算模块,用于基于所述平均状态数据和当前迭代的智能体的迭代数据,计算随机梯度并根据所述随机梯度计算局部梯度;
全局梯度计算模块,用于计算局部梯度的梯度下降值并根据所述邻居节点数据获得全局梯度;
更新模块,用于基于所述全局梯度,根据条件梯度算法更新当前迭代的智能体的待优化参数。
8.智能终端,其特征在于,所述智能终端包括存储器、处理器以及存储在所述存储器上并可在所述处理器上运行的具有约束的分布式随机优化模型训练程序,所述具有约束的分布式随机优化模型训练程序被所述处理器执行时实现如权利要求1-6任意一项所述具有约束的分布式随机优化模型训练方法的步骤。
9.计算机可读存储介质,其特征在于,所述计算机可读存储介质上存储有具有约束的分布式随机优化模型训练程序,所述具有约束的分布式随机优化模型训练程序被处理器执行时实现如权利要求1-6任意一项所述具有约束的分布式随机优化模型训练方法的步骤。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210486474.6A CN114580578B (zh) | 2022-05-06 | 2022-05-06 | 具有约束的分布式随机优化模型训练方法、装置及终端 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210486474.6A CN114580578B (zh) | 2022-05-06 | 2022-05-06 | 具有约束的分布式随机优化模型训练方法、装置及终端 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN114580578A CN114580578A (zh) | 2022-06-03 |
CN114580578B true CN114580578B (zh) | 2022-08-23 |
Family
ID=81769205
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202210486474.6A Active CN114580578B (zh) | 2022-05-06 | 2022-05-06 | 具有约束的分布式随机优化模型训练方法、装置及终端 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN114580578B (zh) |
Citations (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN109952582A (zh) * | 2018-09-29 | 2019-06-28 | 区链通网络有限公司 | 一种强化学习模型的训练方法、节点、系统及存储介质 |
WO2019144046A1 (en) * | 2018-01-19 | 2019-07-25 | Hyperdyne, Inc. | Distributed high performance computing using distributed average consensus |
CN111950611A (zh) * | 2020-07-30 | 2020-11-17 | 西南大学 | 基于随机梯度追踪技术的大数据二分类分布式优化方法 |
CN112381218A (zh) * | 2020-11-20 | 2021-02-19 | 中国人民解放军国防科技大学 | 一种用于分布式深度学习训练的本地更新方法 |
WO2022037337A1 (zh) * | 2020-08-19 | 2022-02-24 | 腾讯科技(深圳)有限公司 | 机器学习模型的分布式训练方法、装置以及计算机设备 |
-
2022
- 2022-05-06 CN CN202210486474.6A patent/CN114580578B/zh active Active
Patent Citations (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
WO2019144046A1 (en) * | 2018-01-19 | 2019-07-25 | Hyperdyne, Inc. | Distributed high performance computing using distributed average consensus |
CN109952582A (zh) * | 2018-09-29 | 2019-06-28 | 区链通网络有限公司 | 一种强化学习模型的训练方法、节点、系统及存储介质 |
CN111950611A (zh) * | 2020-07-30 | 2020-11-17 | 西南大学 | 基于随机梯度追踪技术的大数据二分类分布式优化方法 |
WO2022037337A1 (zh) * | 2020-08-19 | 2022-02-24 | 腾讯科技(深圳)有限公司 | 机器学习模型的分布式训练方法、装置以及计算机设备 |
CN112381218A (zh) * | 2020-11-20 | 2021-02-19 | 中国人民解放军国防科技大学 | 一种用于分布式深度学习训练的本地更新方法 |
Non-Patent Citations (3)
Title |
---|
Distributed stochastic optimization with gradient tracking over strongly-connected networks;Ran Xin等;《2019 IEEE 58th Conference on Decision and Control (CDC)》;20200312;第8353-8358页 * |
基于多智能体网络的分布式优化研究;卢开红;《中国优秀硕士学位论文全文数据库 信息科技辑》;20200229;I140-16 * |
基于强化学习的多智能体协同关键技术及应用研究;李盛祥;《中国优秀硕士学位论文全文数据库 信息科技辑》;20220430;I140-35 * |
Also Published As
Publication number | Publication date |
---|---|
CN114580578A (zh) | 2022-06-03 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN110276442B (zh) | 一种神经网络架构的搜索方法及装置 | |
CN111406264B (zh) | 神经架构搜索 | |
CN112101530B (zh) | 神经网络训练方法、装置、设备及存储介质 | |
CN112633511B (zh) | 用于计算量子配分函数的方法、相关装置及程序产品 | |
CN111259738A (zh) | 人脸识别模型构建方法、人脸识别方法及相关装置 | |
CN110009486B (zh) | 一种欺诈检测的方法、系统、设备及计算机可读存储介质 | |
CN106934722A (zh) | 基于k节点更新与相似度矩阵的多目标社区检测方法 | |
CN106227043A (zh) | 自适应最优控制方法 | |
Hajek et al. | Community recovery in a preferential attachment graph | |
CN116112563A (zh) | 一种基于流行度预测的双策略自适应缓存替换方法 | |
CN116125279A (zh) | 一种电池健康状态的确定方法、装置、设备及存储介质 | |
Bhatnagar et al. | Stochastic algorithms for discrete parameter simulation optimization | |
CN117151208B (zh) | 基于自适应学习率的异步联邦学习参数更新方法、电子设备及存储介质 | |
Tembine | Mean field stochastic games: Convergence, Q/H-learning and optimality | |
CN114580578B (zh) | 具有约束的分布式随机优化模型训练方法、装置及终端 | |
CN113220466A (zh) | 一种基于长短期记忆模型的云服务负载通用预测方法 | |
CN116702925A (zh) | 一种基于事件触发机制的分布式随机梯度优化方法及系统 | |
CN116224126A (zh) | 一种电动汽车锂离子电池健康状态估计方法及系统 | |
Liu et al. | Online quantification of input model uncertainty by two-layer importance sampling | |
Ho et al. | Adaptive communication for distributed deep learning on commodity GPU cluster | |
CN115545168A (zh) | 基于注意力机制和循环神经网络的动态QoS预测方法及系统 | |
CN113112092A (zh) | 一种短期概率密度负荷预测方法、装置、设备和存储介质 | |
Amrane et al. | On the use of ensembles of metamodels for estimation of the failure probability | |
CN118133936A (zh) | 基于光滑化和动量技术的分布式随机非光滑优化方法 | |
CN113836359B (zh) | 动态图嵌入方法、装置、电子设备及存储介质 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |