CN112016699B - 一种深度学习模型训练方法、工作节点和参数服务器 - Google Patents
一种深度学习模型训练方法、工作节点和参数服务器 Download PDFInfo
- Publication number
- CN112016699B CN112016699B CN202010896348.9A CN202010896348A CN112016699B CN 112016699 B CN112016699 B CN 112016699B CN 202010896348 A CN202010896348 A CN 202010896348A CN 112016699 B CN112016699 B CN 112016699B
- Authority
- CN
- China
- Prior art keywords
- target
- statistical
- parameter
- training
- statistical parameter
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
- 238000012549 training Methods 0.000 title claims abstract description 262
- 238000000034 method Methods 0.000 title claims abstract description 84
- 238000013136 deep learning model Methods 0.000 title claims abstract description 44
- 238000004364 calculation method Methods 0.000 claims description 15
- 238000005070 sampling Methods 0.000 claims description 5
- 238000010606 normalization Methods 0.000 description 11
- 238000010586 diagram Methods 0.000 description 6
- 230000003993 interaction Effects 0.000 description 5
- 238000013135 deep learning Methods 0.000 description 4
- 238000011478 gradient descent method Methods 0.000 description 4
- 230000001360 synchronised effect Effects 0.000 description 4
- 230000009286 beneficial effect Effects 0.000 description 3
- 238000004891 communication Methods 0.000 description 3
- 238000012937 correction Methods 0.000 description 3
- 238000005516 engineering process Methods 0.000 description 2
- 230000004913 activation Effects 0.000 description 1
- 230000007547 defect Effects 0.000 description 1
- 238000011161 development Methods 0.000 description 1
- 230000000694 effects Effects 0.000 description 1
- 230000010354 integration Effects 0.000 description 1
- 230000003287 optical effect Effects 0.000 description 1
- 238000012545 processing Methods 0.000 description 1
- 230000000717 retained effect Effects 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N20/00—Machine learning
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Software Systems (AREA)
- Data Mining & Analysis (AREA)
- Evolutionary Computation (AREA)
- Medical Informatics (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Physics & Mathematics (AREA)
- Computing Systems (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Mathematical Physics (AREA)
- Artificial Intelligence (AREA)
- Management, Administration, Business Operations System, And Electronic Commerce (AREA)
Abstract
本发明实施例提供一种深度学习模型训练方法、工作节点和参数服务器,其中,应用于工作节点的深度学习模型训练方法,包括:接收参数服务器发送的第一统计参数,其中,所述第一统计参数是所述参数服务器根据目标模型的目标层的历史训练数据确定的;在基于目标批量训练样本对所述目标层进行训练时,获取所述目标层的目标统计参数,其中,所述目标统计参数为所述目标批量训练样本的统计参数;基于所述第一统计参数和所述目标统计参数确定所述目标层的实际统计参数,并基于所述实际统计参数对所述目标批量训练样本进行批标准化,以及将所述目标统计参数发送至所述参数服务器。本申请实施例能够提升深度学习模型的训练效率。
Description
技术领域
本发明涉及深度学习技术领域,尤其涉及一种深度学习模型训练方法、工作节点和参数服务器。
背景技术
随着信息科技的发展,采用深度学习模型进行训练,以使用训练出的模型对目标数据进行预测,已经得到越来越广泛的使用,为了进一步提升训练出的模型的准确性等,训练样本的数量也越来越大,这就造成了训练的复杂程度和训练时间较长。
在相关技术中,通常可以采用多个工作节点对同一模型进行训练,例如:不同的工作节点负责训练同一模型中的不同训练层,此时,下一训练层需要等待上一训练层训练完成才能够执行训练过程,其等待时间大大增加了模型训练的总时间,从而降低了模型训练的效率。
由此可知,相关技术中采用多个工作节点对同一模型进行训练的过程中,存在模型训练效率低的缺陷。
发明内容
本发明实施例提供一种深度学习模型训练方法、工作节点和参数服务器,以解决相关技术中采用多个工作节点对同一深度学习模型进行训练的过程中,存在的模型训练效率低的问题。
为了解决上述技术问题,本发明是这样实现的:
第一方面,本发明实施例提供了一种深度学习模型训练方法,应用于工作节点,所述方法包括:
接收参数服务器发送的第一统计参数,其中,所述第一统计参数是所述参数服务器根据目标模型的目标层的历史训练数据确定的;
在基于目标批量训练样本对所述目标层进行训练时,获取所述目标层的目标统计参数,其中,所述目标统计参数为所述目标批量训练样本的统计参数;
基于所述第一统计参数和所述目标统计参数确定所述目标层的实际统计参数,并基于所述实际统计参数对所述目标批量训练样本进行批标准化,以及将所述目标统计参数发送至所述参数服务器。
第二方面,本发明实施例还提供了一种深度学习模型训练方法,应用于参数服务器,所述方法包括:
向第一工作节点分别发送第一统计参数,其中,所述第一统计参数是所述参数服务器预先根据目标模型的目标层的历史训练数据确定的,所述第一工作节点为与所述参数服务器共同训练所述目标模型的工作节点;
接收所述第一工作节点分别发送的目标统计参数,其中,所述目标统计参数为所述第一工作节点分别基于不同的批量训练样本训练至所述目标层时,对所述不同的批量训练样本分别进行统计后得出的统计参数;
基于所述目标统计参数对所述第一统计参数进行更新,得到更新后的第一统计参数。
第三方面,本发明实施例还提供了一种工作节点,包括:
第一接收模块,用于接收参数服务器发送的第一统计参数,其中,所述第一统计参数是所述参数服务器根据目标模型的目标层的历史训练数据确定的;
第一获取模块,用于在基于目标批量训练样本对所述目标层进行训练时,获取所述目标层的目标统计参数,其中,所述目标统计参数为所述目标批量训练样本的统计参数;
确定模块,用于基于所述第一统计参数和所述目标统计参数确定所述目标层的实际统计参数,并基于所述实际统计参数对所述目标批量训练样本进行批标准化,以及将所述目标统计参数发送至所述参数服务器。
第四方面,本发明实施例还提供了一种参数服务器,包括:
发送模块,用于向第一工作节点分别发送第一统计参数,其中,所述第一统计参数是所述参数服务器预先根据目标模型的目标层的历史训练数据确定的,所述第一工作节点为与所述参数服务器共同训练所述目标模型的工作节点;
第二接收模块,用于接收所述第一工作节点分别发送的目标统计参数,其中,所述目标统计参数为所述第一工作节点分别基于不同的批量训练样本训练至所述目标层时,对所述不同的批量训练样本分别进行统计后得出的统计参数;
更新模块,用于基于所述目标统计参数对所述第一统计参数进行更新,得到更新后的第一统计参数。
第五方面,本发明实施例还提供了一种电子设备,包括处理器,存储器及存储在所述存储器上并可在所述处理器上运行的程序或指令,所述程序或指令被所述处理器执行时实现第一方面所述的深度学习模型训练方法的步骤,或者所述程序或指令被所述处理器执行时实现第二方面所述的深度学习模型训练方法的步骤。
第六方面,本发明实施例还提供了一种可读存储介质,所述可读存储介质上存储程序或指令,所述程序或指令被处理器执行时实现第一方面所述的深度学习模型训练方法的步骤,或者所述程序或指令被处理器执行时实现第二方面所述的深度学习模型训练方法的步骤。
在本发明实施例中,工作节点接收参数服务器根据历史训练数据确定的第一统计参数,在该工作节点基于目标批量训练样本对目标模型的目标层进行训练时,基于已经接收到的第一统计参数和目标批量训练样本的目标统计参数,确定所述目标层的实际统计参数,并基于该实际统计参数执行所述目标模型的前向和反向传播训练,这样,工作节点无需等待参数服务器获取到目标模型的全部工作节点对目标层进行训练时的训练样本统计参数时,依据这些训练样本统计参数更新统计参数并下发至各个工作节点时,才能够基于该参数服务器下发的统计参数执行前向和反向传播训练,从而大大减小了工作节点等待参数服务器下发统计参数的时间,能够提升深度学习模型的训练效率。
附图说明
为了更清楚地说明本发明实施例的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本发明的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动性的前提下,还可以根据这些附图获得其他的附图。
图1是本发明实施例提供的一种深度学习模型训练方法的流程图;
图2是本发明实施例提供的另一种深度学习模型训练方法的流程图;
图3是本发明实施例提供的深度学习模型训练方法中工作节点与参数服务器的连接架构示意图;
图4是本发明实施例提供的深度学习模型训练方法中工作节点与参数服务器的数据交互示意图;
图5是本发明实施例提供的一种工作节点的结构图;
图6是本发明实施例提供的一种参数服务器的结构图。
具体实施方式
下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有作出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
在大型深度学习模型的训练过程中,为了加快模型收敛速度,提升训练效率,同时考虑到样本总量可能很大(无法利用所有样本数据进行模型迭代),通常使用小批量梯度下降方法进行模型训练。其中,使用小批量梯度下降方法进行模型训练时,每次迭代使用batch size(批尺寸:一次训练所选取的样本数)个样本来对参数进行更新。然而对于参数很多或中间激活数据很多的大型深度学习模型来说,该深度学习模型无法在单个工作节点进行小批量计算,通常需要将训练过程放在多台工作节点上进行。
举例来说,将训练过程放在多台工作节点上进行的过程中,可以采用针对大型深度学习模型的数据并行训练方法,例如,可以是如图1所示训练网络,该训练网络包括多个工作节点10,以将同一个模型放在各个工作节点10上分别进行训练,然后对训练数据集进行划分(生成小批量训练样本),并将划分后的小批量训练样本20分别分配到不同的工作节点10上,以使该工作点10基于分配到的小批量训练样本20进行模型训练,且各个工作节点10在训练完成后与参数服务器30进行数据交互,以上报训练结果,或者各个工作节点10在训练的过程中与参数服务器30进行数据交互,以使批量训练样本标准化(以下简称:批标准化),该批标准化可以对训练过程的前向传播进行各批标准化层的处理,同时上述实际统计参数在反向传播时也会用到。
在一种实施方式中,单个工作节点能够基于小批量训练样本独立的对深度学习进行训练,并在将更新数据同步到参数服务器,该实施方式仅适用于小型深度学习训练模型。
在另一种可选的实施方式中,在大型深度学习网络模型的训练过程中,使用小批量梯度下降方法进行模型训练时,各工作节点使用相同的模型,分别从数据库中取出数据进行训练,共同完成样本数量为batch size的小批量训练(即batch size为一次模型迭代过程中,所有工作节点训练的训练样本数量的总和),各工作节点运行完后,将模型或参数更新数据同步到参数服务器,参数服务器得到所有工作节点数据后,进行模型更新,并将更新后的模型同步到各工作节点。其中,在各个工作节点执行模型训练的过程中,如果网络层需要对batch size个样本进行全局batch size统计值(以下以批标准化层(batch norm)为例进行举例说明)的时候,需要各工作节点在运行到这一层时,将该层的数据同步到参数服务器,由参数服务器完成统计值的计算后,再同步到各工作节点。
例如:对模型训练网络中含有的批标准化层(当前的模型训练网络通常会包含多层批标准化层)进行训练时,需要参数服务器对全部工作节点的训练数据进行统计,以使各个工作节点基于该统计结果分别对本层的当前数据进行校正计算,利用校正后的统计值进行批标准化。
网络层需要对batch size个样本进行全局统计值的时候,工作节点需要进行以下等待过程:
等待其他工作节点训练至同一批标准化层,以使参数服务器能够获取各个工作节点分别上报的批标准化层的数据;然后等待参数服务器对各个工作节点分别上报的批标准化层的数据进行统计计算,并下发统计结果。
这样,极大的增加了工作节点与参数服务器之间的通信量,且增加了各个工作节点在训练过程中的等待时间,对于以后日趋增多的大型模型的训练来说,是个严重的瓶颈问题,本申请能够解决采用小批量梯度下降方法进行模型训练的过程中的存在的模型训练效率低的问题。
请参见图2,图2是本发明实施例提供的一种深度学习模型训练方法的流程图,该方法应用于工作节点。如图2所示,所述深度学习模型训练方法可以包括以下步骤:
步骤201、接收参数服务器发送的第一统计参数,其中,所述第一统计参数是所述参数服务器根据目标模型的目标层的历史训练数据确定的。
举例来说,上述目标模型的目标层的历史训练数据可以是对历史训练样本的统计参数。
例如,训练网络需要对目标模型进行多次模型迭代训练,则每一次迭代训练中:该训练网络中的各个工作节点分别在前向训练至目标标准化层时,对各自当前使用的训练样本进行统计计算,以得到目标统计参数,并将该目标统计参数上报至参数服务器,该参数服务器可以根据接收到的各个工作节点的目标统计参数更新所述第一统计参数。这样,参数服务器更新后的第一统计参数能够反映全部历史训练样本的统计特征,则在工作节点下一次对目标标准化层进行迭代训练的过程中,工作节点从参数服务器获取的第一统计参数即为已经对目标标准化层训练过的样本数据的历史统计参数,该工作节点将目标标准化层的目标批量训练样本的目标统计参数与该第一统计参数进行结合,以得到本标准化层的实际统计参数。
需要说明的是,目标模型中可以包括多个标准化层,参数服务器可以分别对每一个标准化层的第一统计参数进行更新,且各个工作节点在前向训练至m1标准化层时,可以从参数服务器获取m1标准化层对应的第一统计参数;各个工作节点在前向训练至m2标准化层时,可以从参数服务器获取m2标准化层对应的第一统计参数,其中,目标模型中的标准化层包括m1标准化层和m2标准化层。
上述第一统计参数可以是对历史数据进行求方差、求和求积分等计算中的一个或者多个后得出的统计值或者统计值序列,或者还可以包括已完成训练的训练样本数量,其中,统计值序列可以包括多个统计值,例如:包括方差值和求和值等。
需要说明的是,上述第一统计参数可以为工作节点在对目标层进行训练之前接收到的历史统计参数,其未对目标层当前使用的训练样本进行统计。
步骤202、在基于目标批量训练样本对所述目标层进行训练时,获取所述目标层的目标统计参数,其中,所述目标统计参数为所述目标批量训练样本的统计参数。
举例来说,上述目标层可以是目标批标准化层,所述目标批量训练样本可以是包括至少一个训练样本的样本集合,且应用本申请提供的深度学习模型训练方法的工作节点,基于该样本集合中的训练样本对目标模型的目标层进行训练,即所述目标批量训练样本又可以称之为所述目标层进行训练的当前训练样本。上述基于目标批量训练样本对目标模型的目标层进行训练,也可以理解为:目标批量训练样本前向训练至所述目标模型的目标层。
另外,上述目标统计参数可以是对该目标批量训练样本进行统计计算后得出的统计值或者统计值序列,该统计计算的方式可以与参数服务器对历史训练数据执行的统计计算的方式相同,在此不再赘述。
需要说明的是,在执行步骤202之前,工作节点还需要获取目标批量数据,举例来说,可以通过以下方式提升各个工作节点获取到的目标批量数据的全局性,即训练过程中能够使用到全局数据信息。
在一种可选的实施方式中,可以通过随机采样方式从数据库中获取所述目标批量训练样本。
例如,可以采取有放回的方式采样,本实施方式中,工作节点在数据库中获取所述目标批量训练样本之后,可以不删除所述目标批量训练样本,其他工作节点也可以随机选取到目标批量训练样本中的训练样本。
在另一种可选的实施方式中,还可以从数据库获取排列于预设位置处的所述目标批量训练样本,其中,所述数据库存储的训练样本是乱序排列的。
其中,上述列于预设位置处的所述目标批量训练样本可以是从排列于第一位的训练样本开始获取,并获取排列于该训练样本之后的N-1个训练样本,其中,N表示批量训练样本中包含的训练样本数量。举例来说,数据库可以对其内的训练样本进行分配,以向不同的工作节点分配不同的训练样本。
在一些可选的实施例中,所述第一统计参数的参数种类与所述目标统计参数的参数种类相同,举例来说,所述第一统计参数和所述目标统计参数分别包括:训练样本统计参数值和训练样本数量。即第一统计参数包括对历史训练样本的统计参数值和历史训练样本含有的样本数量,目标统计参数包括对目标批量训练样本的统计参数值和目标批量训练样本含有的样本数量。其中,上述样本数量可以是指训练样本使用过的次数,即如果同一个样本训练过n次,则样本数量为n。
在一些可选的实施例中,上述目标层的实际统计参数,可以包括统计参数值。
步骤203、基于所述第一统计参数和所述目标统计参数确定所述目标层的实际统计参数,并基于所述实际统计参数对所述目标批量训练样本进行批标准化,以及将所述目标统计参数发送至所述参数服务器。
在一些可选的实施例中,上述基于所述第一统计参数和所述目标统计参数确定所述目标层的实际统计参数,可以是采用第一统计参数对目标统计参数进行校正计算,以得到所述目标层的实际统计参数。
其中,上述基于所述实际统计参数对所述目标层的数据进行批标准化,也可以称之为:基于所述实际统计参数对所述目标层的训练样本进行批标准化。
在一种可能的实现方式中,上述批标准化的校正公式可以根据目标层的统计值方差,以及与第一统计值之间的偏差程度等信息进行调整,该批标准化的具体过程与现有技术中批标准化层的批标准化的过程具有相同含义,在此不再赘述。
在一些可选的实施例中,工作节点在获取到目标统计参数之后,还将该目标统计参数发送至参数服务器,以使参数服务器根据该目标统计参数对第一统计参数进行更新,以便于参数服务器将更新后的第一统计参数发送至各个工作节点,实现各个工作节点之间的数据同步,便于在下一次迭代中使用更新后的第一统计参数对该迭代次序中的目标统计参数进行校准计算,直至模型训练完成。如前所述,第一统计参数以及目标统计参数可以包括训练样本统计参数值和训练样本数量,这样,工作节点和参数服务器之间的通信量较小,提高通信效率。
本实施方式中,并不具体限定工作节点执行:基于所述第一统计参数和所述目标统计参数确定所述目标层的实际统计参数、基于所述实际统计参数对所述目标层的数据进行批标准化以及将所述目标统计参数发送至所述参数服务器这三个步骤的先后顺序。
其中,每次迭代训练完成之后,各个工作节点可以将训练结果上报至参数服务器,以由参数服务器根据各个工作节点上报的训练结果对模型进行更新,并将更新后的模型下发至各个工作节点,以使各个工作节点分别基于不同的训练样本继续对更新后的模型进行训练。
在一种可能的实现方式中,各个工作节点在得到所述实际统计参数之后,可以将该实际统计参数存储在本工作节点,以利用该实际统计参数进行反向传播。
在本发明实施例中,工作节点接收参数服务器根据历史训练数据确定的第一统计参数,在该工作节点基于目标批量训练样本对目标模型的目标层进行训练时,基于已经接收到的第一统计参数和目标批量训练样本的目标统计参数,确定所述目标层的实际统计参数,并基于该实际统计参数执行所述目标模型的前向和反向传播训练,这样,工作节点无需等待参数服务器获取到目标模型的全部工作节点对目标层进行训练时的训练样本统计参数时,依据这些训练样本统计参数更新统计参数并下发至各个工作节点时,才能够基于该参数服务器下发的统计参数执行前向和反向传播训练,从而大大减小了工作节点等待参数服务器下发统计参数的时间,能够提升深度学习模型的训练效率。
需要说明的是,模型训练网络中的每一个工作节点可以分别执行上述深度学习模型训练方法中的各个步骤。
请参阅图3,是本申请提供的另一种深度学习模型训练方法的流程图,该深度学习模型训练方法应用于参数服务器,如图3所示,该方法可以包括以下步骤:
步骤301、向第一工作节点分别发送第一统计参数,其中,所述第一统计参数是所述参数服务器预先根据目标模型的目标层的历史训练数据确定的,所述第一工作节点为与所述参数服务器共同训练所述目标模型的工作节点。
本实施方式中,上述第一工作节点可以是与所述参数服务器一起对目标模型进行训练的工作节点,其可以是执行如图2所示方法的工作节点。
步骤301中,参数服务器可以向对目标模型进行训练的全部第一工作节点发送第一统计参数。其中,所述第一统计参数与如图2所示方法实施例中的第一统计参数具有相同含义,在此不在赘述。
步骤302、接收所述第一工作节点分别发送的目标统计参数,其中,所述目标统计参数为所述第一工作节点分别基于不同的批量训练样本训练至所述目标层时,对所述不同的批量训练样本分别进行统计后得出的统计参数。
其中,各个第一工作节点发送的目标统计参数可以为其对目标层进行训练时使用的批量训练样本的统计参数,其可以与如图2所示方法实施例中的目标统计参数具有相同含义,在此不再赘述。
在一种可选的实施方式中,上述接收所述第一工作节点分别基于不同的批量训练样本对目标模型的目标层进行训练时发送的目标统计参数,可以是在接收到每一个第一工作节点对目标模型的目标层进行训练时发送的目标统计参数后,才执行步骤303。
在另一种可选的实施方式中,上述接收所述第一工作节点分别基于不同的批量训练样本对目标模型的目标层进行训练时发送的目标统计参数,可以是在接收到预设数量个第一工作节点发送的目标统计参数的情况下,基于所述第一统计参数和所述预设数量个目标统计参数对所述第一统计参数进行更新,其中,所述预设数量小于或者等于所述第一工作节点的总数量。
本实施方式中,预留一些工作节点,当预设数量的第一工作节点提交数据到参数服务器后,即可根据提交的数据更新模型,而对于超出预设数量的第一工作节点提交数据不再等待和接收,这样,能够减少因为某个工作节点运行慢或死机等导致的整体模型训练过程停止等待,能够提升模型训练效率。
步骤303、基于所述目标统计参数对所述第一统计参数进行更新,得到更新后的第一统计参数。
上述基于所述目标统计参数对所述第一统计参数进行更新,可以是根据所述目标统计参数对所述第一统计参数进行偏差校正。
应理解,参数服务器可以重复执行步骤301至步骤303,以实现在进行模型数据更新时,将每一个批标准化层的参数同步到各个工作节点。其中,参数服务器可以在得到更新后的第一统计参数时,执行步骤301。本公开对参数服务器执行步骤301、步骤302以及步骤303的顺序不做限制。
在具体实施中,参数服务器在得到更新后的第一统计参数之后,可以将该更新后的第一统计参数分别发送至第一工作节点,具体可以是在各工作节点前向运行至目标标准化层时,将更新后的第一统计参数分别发送至该工作节点,其中,所述目标标准化层为与所述更新的第一统计参数相关联的标准化层。
本申请实施例提供的深度学习模型训练方法,在对批标准化层进行训练的过程中,无需各个工作节点将训练样本数据全部发送至参数服务器,并等待参数服务器进行全局统计后返回统计参数才能够执行批标准化,而是由各个工作节点对各自使用的批标准化层的批标准化数据进行统计,并上报统计数据,从而减少了参数服务器与各个工作节点之间的数据交互量,且参数服务器能够向各个工作节点发送第一统计参数,以使工作节点能够根据接收到的第一统计参数对目标统计参数进行校正,以得到各自使用的实际统计参数,并采用该实际统计参数进行批标准化,能够大大减少工作节点的等待时间。
下面结合工作节点与参数服务器的数据交互过程,对本申请实施例提供的深度学习模型训练方法进行举例说明,如图4所示,该方法包括以下过程:
步骤401、工作节点A从数据库中获取当前需要训练的目标训练数据。
其中,目标训练数据即如图2和图3所示方法实施例中的目标批量训练样本。
步骤402、工作节点A基于目标训练数据对目标模型进行训练。
其中,目标模型可以包括多个需要全局batch size统计参数的层(如batchnorm批标准化层,以下仅以batchnorm批标准化层为例)。
其中,全局batch size统计参数包括:方差,求和,平均值等统计值,其还包括:目标训练数据的数量。
步骤403、在工作节点A前向运行至目标批标准化层时,获取目标统计参数,并基于历史统计参数对目标统计参数进行校正,得到实际统计参数。
本步骤中,工作节点A获取目标批标准化层的目标训练数据,并对目标训练数据进行统计计算,以得到目标统计参数,然后使用从参数服务器同步来的目标批标准化层的第一统计参数,对本工作节点的目标统计参数进行校正,求得当前实际使用的统计参数,即实际统计参数,此统计参数在本机保留,反向传播时需要用到。
步骤404、在工作节点A将目标统计参数发送至参数服务器。
步骤405、参数服务器根据目标统计参数对存储的第一统计参数进行更新,并将更新后的第一统计参数发送至各个工作节点(包括工作节点A)。
本申请实施例中,工作节点和参数服务器相互配合以执行如图2和图3所示深度学习模型训练方法的各个过程,且能够取得相同的有益效果,为避免重复,在此不再赘述。
请参阅图5,是本申请实施例提供的一种工作节点的结构图,如图5所示,该工作节点500包括:
第一接收模块501,用于接收参数服务器发送的第一统计参数,其中,所述第一统计参数是所述参数服务器根据目标模型的目标层的历史训练数据确定的;
第一获取模块502,用于在基于目标批量训练样本对所述目标层进行训练时,获取所述目标层的目标统计参数,其中,所述目标统计参数为所述目标批量训练样本的统计参数;
确定模块503,用于基于所述第一统计参数和所述目标统计参数确定所述目标层的实际统计参数,并基于所述实际统计参数对所述目标批量训练样本进行批标准化,以及将所述目标统计参数发送至所述参数服务器。
可选的,工作节点500还包括:
第二获取模块,用于通过随机采样方式从数据库中获取所述目标批量训练样本;
或者,
第三获取模块,用于从数据库获取排列于预设位置处的所述目标批量训练样本,其中,所述数据库存储的训练样本是乱序排列的。
可选的,所述第一统计参数和所述目标统计参数分别包括:训练样本统计参数值和训练样本数量。
本申请实施例提供的工作节点500能够执行如图2所示方法实施例中的各个过程,且能够取得相同的有益效果,为避免重复,在此不再赘述。
请参阅图6,是本申请实施例提供的一种参数服务器的结构图,如图6所示,该参数服务器600包括:
发送模块601,用于向第一工作节点分别发送第一统计参数,其中,所述第一统计参数是所述参数服务器预先根据目标模型的目标层的历史训练数据确定的,所述第一工作节点为与所述参数服务器共同训练所述目标模型的工作节点;
第二接收模块602,用于接收所述第一工作节点分别发送的目标统计参数,其中,所述目标统计参数为所述第一工作节点分别基于不同的批量训练样本训练至所述目标层时,对所述不同的批量训练样本分别进行统计后得出的统计参数;
更新模块603,用于基于所述目标统计参数对所述第一统计参数进行更新,得到更新后的第一统计参数。
可选的,更新模块603,包括:
更新单元,用于在接收到预设数量个第一工作节点发送的目标统计参数的情况下,基于所述第一统计参数和所述预设数量个目标统计参数对所述第一统计参数进行更新,其中,所述预设数量小于或者等于所述第一工作节点的总数量。
本申请实施例提供的参数服务器600能够执行如图3所示方法实施例中的各个过程,且能够取得相同的有益效果,为避免重复,在此不再赘述。
发明实施例还提供一种电子设备,包括处理器、存储器,存储在存储器上并可在所述处理器上运行的程序或指令,该程序或指令被处理器执行时实现如图1或图2所示方法实施例的各个过程,且能达到相同的技术效果,为避免重复,这里不再赘述。
需要说明的是,在本文中,术语“包括”、“包含”或者其任何其他变体意在涵盖非排他性的包含,从而使得包括一系列要素的过程、方法、物品或者装置不仅包括那些要素,而且还包括没有明确列出的其他要素,或者是还包括为这种过程、方法、物品或者装置所固有的要素。在没有更多限制的情况下,由语句“包括一个……”限定的要素,并不排除在包括该要素的过程、方法、物品或者装置中还存在另外的相同要素。
通过以上的实施方式的描述,本领域的技术人员可以清楚地了解到上述实施例方法可借助软件加必需的通用硬件平台的方式来实现,当然也可以通过硬件,但很多情况下前者是更佳的实施方式。基于这样的理解,本发明的技术方案本质上或者说对现有技术做出贡献的部分可以以软件产品的形式体现出来,该计算机软件产品存储在一个存储介质(如ROM/RAM、磁碟、光盘)中,包括若干指令用以使得一台移动终端(可以是手机,计算机,服务器,空调器,或者网络设备等)执行本发明各个实施例所述的方法。
上面结合附图对本发明的实施例进行了描述,但是本发明并不局限于上述的具体实施方式,上述的具体实施方式仅仅是示意性的,而不是限制性的,本领域的普通技术人员在本发明的启示下,在不脱离本发明宗旨和权利要求所保护的范围情况下,还可做出很多形式,均属于本发明的保护之内。
Claims (12)
1.一种深度学习模型训练方法,应用于工作节点,其特征在于,所述方法包括:
接收参数服务器发送的第一统计参数,其中,所述第一统计参数是所述参数服务器根据目标模型的目标层的历史训练数据确定的;
在基于目标批量训练样本对所述目标层进行训练时,获取所述目标层的目标统计参数,其中,所述目标统计参数为所述目标批量训练样本的统计参数;
基于所述第一统计参数和所述目标统计参数确定所述目标层的实际统计参数,并基于所述实际统计参数对所述目标批量训练样本进行批标准化,以及将所述目标统计参数发送至所述参数服务器;
所述第一统计参数包括对历史训练数据进行计算后得出的统计值或者统计值序列,或者,所述第一统计参数包括已完成训练的训练样本数量;
所述目标层包括目标批标准化层;
所述目标统计参数包括对所述目标批量训练样本进行统计计算后得出的统计值或者统计值序列;
所述基于所述第一统计参数和所述目标统计参数确定所述目标层的实际统计参数,包括:
基于所述第一统计参数对所述目标统计参数进行校正计算,得到所述目标层的实际统计参数。
2.根据权利要求1所述的深度学习模型训练方法,其特征在于,所述在基于目标批量训练样本对所述目标层进行训练时,获取所述目标层的目标统计参数之前,所述方法还包括:
通过随机采样方式从数据库中获取所述目标批量训练样本;
或者,
从数据库获取排列于预设位置处的所述目标批量训练样本,其中,所述数据库存储的训练样本是乱序排列的。
3.根据权利要求1所述的深度学习模型训练方法,其特征在于,所述第一统计参数和所述目标统计参数分别包括:训练样本统计参数值和训练样本数量。
4.一种深度学习模型训练方法,应用于参数服务器,其特征在于,所述方法包括:
向第一工作节点分别发送第一统计参数,其中,所述第一统计参数是所述参数服务器预先根据目标模型的目标层的历史训练数据确定的,所述第一工作节点为与所述参数服务器共同训练所述目标模型的工作节点;
接收所述第一工作节点分别发送的目标统计参数,其中,所述目标统计参数为所述第一工作节点分别基于不同的批量训练样本训练至所述目标层时,对所述不同的批量训练样本分别进行统计后得出的统计参数;
基于所述目标统计参数对所述第一统计参数进行更新,得到更新后的第一统计参数;
所述第一统计参数包括对历史训练数据进行计算后得出的统计值或者统计值序列,或者,所述第一统计参数包括已完成训练的训练样本数量;
所述目标层包括目标批标准化层;
所述目标统计参数包括对所述批量训练样本进行统计计算后得出的统计值或者统计值序列;
基于所述目标统计参数对所述第一统计参数进行更新,得到更新后的第一统计参数,包括:
根据所述目标统计参数对所述第一统计参数进行偏差校正,得到更新后的第一统计参数。
5.根据权利要求4所述的深度学习模型训练方法,其特征在于,所述基于所述目标统计参数对所述第一统计参数进行更新,包括:
在接收到预设数量个第一工作节点发送的目标统计参数的情况下,基于所述第一统计参数和所述预设数量个目标统计参数对所述第一统计参数进行更新,其中,所述预设数量小于或者等于所述第一工作节点的总数量。
6.一种工作节点,其特征在于,包括:
第一接收模块,用于接收参数服务器发送的第一统计参数,其中,所述第一统计参数是所述参数服务器根据目标模型的目标层的历史训练数据确定的;
第一获取模块,用于在基于目标批量训练样本对所述目标层进行训练时,获取所述目标层的目标统计参数,其中,所述目标统计参数为所述目标批量训练样本的统计参数;
确定模块,用于基于所述第一统计参数和所述目标统计参数确定所述目标层的实际统计参数,并基于所述实际统计参数对所述目标批量训练样本进行批标准化,以及将所述目标统计参数发送至所述参数服务器;
所述第一统计参数包括对历史训练数据进行计算后得出的统计值或者统计值序列,或者,所述第一统计参数包括已完成训练的训练样本数量;
所述目标层包括目标批标准化层;
所述目标统计参数包括对所述目标批量训练样本进行统计计算后得出的统计值或者统计值序列;
所述基于所述第一统计参数和所述目标统计参数确定所述目标层的实际统计参数,包括:
基于所述第一统计参数对所述目标统计参数进行校正计算,得到所述目标层的实际统计参数。
7.根据权利要求6所述的工作节点,其特征在于,还包括:
第二获取模块,用于通过随机采样方式从数据库中获取所述目标批量训练样本;
或者,
第三获取模块,用于从数据库获取排列于预设位置处的所述目标批量训练样本,其中,所述数据库存储的训练样本是乱序排列的。
8.根据权利要求6所述的工作节点,其特征在于,所述第一统计参数和所述目标统计参数分别包括:训练样本统计参数值和训练样本数量。
9.一种参数服务器,其特征在于,包括:
发送模块,用于向第一工作节点分别发送第一统计参数,其中,所述第一统计参数是所述参数服务器预先根据目标模型的目标层的历史训练数据确定的,所述第一工作节点为与所述参数服务器共同训练所述目标模型的工作节点;
第二接收模块,用于接收所述第一工作节点分别发送的目标统计参数,其中,所述目标统计参数为所述第一工作节点分别基于不同的批量训练样本训练至所述目标层时,对所述不同的批量训练样本分别进行统计后得出的统计参数;
更新模块,用于基于所述目标统计参数对所述第一统计参数进行更新,得到更新后的第一统计参数;
所述第一统计参数包括对历史训练数据进行计算后得出的统计值或者统计值序列,或者,所述第一统计参数包括已完成训练的训练样本数量;
所述目标层包括目标批标准化层;
所述目标统计参数包括对所述批量训练样本进行统计计算后得出的统计值或者统计值序列;
基于所述目标统计参数对所述第一统计参数进行更新,得到更新后的第一统计参数,包括:
根据所述目标统计参数对所述第一统计参数进行偏差校正,得到更新后的第一统计参数。
10.根据权利要求9所述的参数服务器,其特征在于,所述更新模块,包括:
更新单元,用于在接收到预设数量个第一工作节点发送的目标统计参数的情况下,基于所述第一统计参数和所述预设数量个目标统计参数对所述第一统计参数进行更新,其中,所述预设数量小于或者等于所述第一工作节点的总数量。
11.一种电子设备,其特征在于,包括处理器,存储器及存储在所述存储器上并可在所述处理器上运行的程序或指令,所述程序或指令被所述处理器执行时实现如权利要求1-3中任一项所述的深度学习模型训练方法的步骤,或者所述程序或指令被所述处理器执行时实现如权利要求4或5所述的深度学习模型训练方法的步骤。
12.一种可读存储介质,其特征在于,所述可读存储介质上存储程序或指令,所述程序或指令被处理器执行时实现如权利要求1-3中任一项所述的深度学习模型训练方法的步骤,或者所述程序或指令被处理器执行时实现如权利要求4或5所述的深度学习模型训练方法的步骤。
Priority Applications (2)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202010896348.9A CN112016699B (zh) | 2020-08-31 | 2020-08-31 | 一种深度学习模型训练方法、工作节点和参数服务器 |
PCT/CN2021/115544 WO2022042741A1 (zh) | 2020-08-31 | 2021-08-31 | 学习模型训练方法、工作节点、服务器、设备、介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202010896348.9A CN112016699B (zh) | 2020-08-31 | 2020-08-31 | 一种深度学习模型训练方法、工作节点和参数服务器 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN112016699A CN112016699A (zh) | 2020-12-01 |
CN112016699B true CN112016699B (zh) | 2024-02-02 |
Family
ID=73503128
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202010896348.9A Active CN112016699B (zh) | 2020-08-31 | 2020-08-31 | 一种深度学习模型训练方法、工作节点和参数服务器 |
Country Status (2)
Country | Link |
---|---|
CN (1) | CN112016699B (zh) |
WO (1) | WO2022042741A1 (zh) |
Families Citing this family (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN112016699B (zh) * | 2020-08-31 | 2024-02-02 | 北京灵汐科技有限公司 | 一种深度学习模型训练方法、工作节点和参数服务器 |
CN114004358B (zh) * | 2021-12-29 | 2022-06-14 | 粤港澳大湾区数字经济研究院(福田) | 一种深度学习模型训练方法 |
CN116663639B (zh) * | 2023-07-31 | 2023-11-03 | 浪潮电子信息产业股份有限公司 | 一种梯度数据同步方法、系统、装置及介质 |
CN117370471B (zh) * | 2023-12-07 | 2024-02-27 | 苏州元脑智能科技有限公司 | 基于修剪平均的全局预测方法、装置、设备及存储介质 |
Citations (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN107688493A (zh) * | 2016-08-05 | 2018-02-13 | 阿里巴巴集团控股有限公司 | 训练深度神经网络的方法、装置及系统 |
CN108122032A (zh) * | 2016-11-29 | 2018-06-05 | 华为技术有限公司 | 一种神经网络模型训练方法、装置、芯片和系统 |
US20190026657A1 (en) * | 2016-03-26 | 2019-01-24 | Alibaba Group Holding Limited | Distributed Cluster Training Method and Apparatus |
CN109754060A (zh) * | 2017-11-06 | 2019-05-14 | 阿里巴巴集团控股有限公司 | 一种神经网络机器学习模型的训练方法及装置 |
Family Cites Families (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN107578094A (zh) * | 2017-10-25 | 2018-01-12 | 济南浪潮高新科技投资发展有限公司 | 基于参数服务器和fpga实现神经网络分布式训练的方法 |
CN108491928B (zh) * | 2018-03-29 | 2019-10-25 | 腾讯科技(深圳)有限公司 | 模型参数发送方法、装置、服务器及存储介质 |
CN112016699B (zh) * | 2020-08-31 | 2024-02-02 | 北京灵汐科技有限公司 | 一种深度学习模型训练方法、工作节点和参数服务器 |
-
2020
- 2020-08-31 CN CN202010896348.9A patent/CN112016699B/zh active Active
-
2021
- 2021-08-31 WO PCT/CN2021/115544 patent/WO2022042741A1/zh active Application Filing
Patent Citations (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US20190026657A1 (en) * | 2016-03-26 | 2019-01-24 | Alibaba Group Holding Limited | Distributed Cluster Training Method and Apparatus |
CN107688493A (zh) * | 2016-08-05 | 2018-02-13 | 阿里巴巴集团控股有限公司 | 训练深度神经网络的方法、装置及系统 |
CN108122032A (zh) * | 2016-11-29 | 2018-06-05 | 华为技术有限公司 | 一种神经网络模型训练方法、装置、芯片和系统 |
CN109754060A (zh) * | 2017-11-06 | 2019-05-14 | 阿里巴巴集团控股有限公司 | 一种神经网络机器学习模型的训练方法及装置 |
Non-Patent Citations (4)
Title |
---|
"Tianjic: A Unified and Scalable Chip Bridging Spike-Based and Continuous Neural Computation";Deng Lei 等;《IEEE Journal of Solid-State Circuits》;全文 * |
"基于深度学习的车辆驾驶状态识别算法研究";郭耀华;《中国优秀硕士学位论文全文数据库 工程科技Ⅱ辑》;全文 * |
"多标准高性能前向纠错码处理器";吴臻志;《中国博士学位论文全文数据库 信息科技辑》;全文 * |
"面向云计算的分布式机器学习任务调度算法研究";孟彬彬 等;《西安文理学院学报( 自然科学版)》;第23卷(第1期);全文 * |
Also Published As
Publication number | Publication date |
---|---|
WO2022042741A1 (zh) | 2022-03-03 |
CN112016699A (zh) | 2020-12-01 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN112016699B (zh) | 一种深度学习模型训练方法、工作节点和参数服务器 | |
CN113139662A (zh) | 联邦学习的全局及局部梯度处理方法、装置、设备和介质 | |
CN111914936B (zh) | 语料数据的数据特征增强方法、装置及计算机设备 | |
CN106549810A (zh) | 云服务平台新版本发布前测试方法、装置以及系统 | |
CN111030861A (zh) | 一种边缘计算分布式模型训练方法、终端和网络侧设备 | |
US20100217734A1 (en) | Method and system for calculating value of website visitor | |
CN109709985B (zh) | 一种无人机任务优化方法、装置及系统 | |
CN107911251B (zh) | 一种网络设备配置方法、装置和介质 | |
CN101808167B (zh) | 一种流程跟踪方法以及装置和系统 | |
CN110046091A (zh) | 一种自动测试方法和装置 | |
CN114945817A (zh) | 基于缺陷检测的任务处理方法、装置及设备及存储介质 | |
Waeber et al. | A Bayesian approach to stochastic root finding | |
CN115660115A (zh) | 一种联邦学习模型训练方法、装置、设备及存储介质 | |
CN109001694A (zh) | 一种动态自适应天线扫描特性模拟方法及系统 | |
CN115577797B (zh) | 一种基于本地噪声感知的联邦学习优化方法及系统 | |
CN110753366A (zh) | 行业短信网关容量的预测处理方法及装置 | |
CN113128696A (zh) | 分布式机器学习通信优化方法、装置、服务器及终端设备 | |
CN116050554A (zh) | 一种景区客流量预测方法、装置、计算设备和存储介质 | |
CN110780859B (zh) | 一种基于自定义表单的服务架构的实现方法 | |
Ridder | Asymptotic optimality of the cross-entropy method for Markov chain problems | |
CN114528893A (zh) | 机器学习模型训练方法、电子设备及存储介质 | |
CN110852418A (zh) | 神经网络模型的数据处理方法及装置、存储介质、终端 | |
CN113115231B (zh) | 基于lbs的数据处理系统 | |
CN117493583B (zh) | 结合事件日志和知识图谱的流程操作序列生成方法及系统 | |
CN108259393B (zh) | 一种流数据处理中乱序纠正方法及系统 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |