CN113344074B - 模型训练方法、装置、设备及存储介质 - Google Patents
模型训练方法、装置、设备及存储介质 Download PDFInfo
- Publication number
- CN113344074B CN113344074B CN202110615506.3A CN202110615506A CN113344074B CN 113344074 B CN113344074 B CN 113344074B CN 202110615506 A CN202110615506 A CN 202110615506A CN 113344074 B CN113344074 B CN 113344074B
- Authority
- CN
- China
- Prior art keywords
- model
- subtree
- trained
- sample set
- sub
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
- 238000012549 training Methods 0.000 title claims abstract description 110
- 238000000034 method Methods 0.000 title claims abstract description 58
- 230000008569 process Effects 0.000 claims abstract description 17
- 230000005540 biological transmission Effects 0.000 claims description 3
- 238000013473 artificial intelligence Methods 0.000 abstract description 6
- 238000005516 engineering process Methods 0.000 abstract description 6
- 238000013135 deep learning Methods 0.000 abstract description 3
- 239000000523 sample Substances 0.000 description 117
- 238000003062 neural network model Methods 0.000 description 11
- 238000004590 computer program Methods 0.000 description 10
- 238000013528 artificial neural network Methods 0.000 description 9
- 238000004891 communication Methods 0.000 description 8
- 238000010586 diagram Methods 0.000 description 7
- 238000012545 processing Methods 0.000 description 7
- 230000006870 function Effects 0.000 description 4
- 238000012986 modification Methods 0.000 description 3
- 230000004048 modification Effects 0.000 description 3
- 230000003287 optical effect Effects 0.000 description 3
- 241001632427 Radiola Species 0.000 description 2
- 230000002776 aggregation Effects 0.000 description 2
- 238000004220 aggregation Methods 0.000 description 2
- 230000002146 bilateral effect Effects 0.000 description 2
- 238000007796 conventional method Methods 0.000 description 2
- 230000003993 interaction Effects 0.000 description 2
- 230000004931 aggregating effect Effects 0.000 description 1
- 238000003491 array Methods 0.000 description 1
- 230000001413 cellular effect Effects 0.000 description 1
- 238000010276 construction Methods 0.000 description 1
- 230000007547 defect Effects 0.000 description 1
- 238000013461 design Methods 0.000 description 1
- 238000011161 development Methods 0.000 description 1
- 239000004973 liquid crystal related substance Substances 0.000 description 1
- 238000010801 machine learning Methods 0.000 description 1
- 239000013307 optical fiber Substances 0.000 description 1
- 239000013074 reference sample Substances 0.000 description 1
- 239000004065 semiconductor Substances 0.000 description 1
- 230000001953 sensory effect Effects 0.000 description 1
- 238000006467 substitution reaction Methods 0.000 description 1
- 230000000007 visual effect Effects 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
- G06F18/243—Classification techniques relating to the number of classes
- G06F18/24323—Tree-organised classifiers
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02D—CLIMATE CHANGE MITIGATION TECHNOLOGIES IN INFORMATION AND COMMUNICATION TECHNOLOGIES [ICT], I.E. INFORMATION AND COMMUNICATION TECHNOLOGIES AIMING AT THE REDUCTION OF THEIR OWN ENERGY USE
- Y02D10/00—Energy efficient computing, e.g. low power processors, power management or thermal management
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- Life Sciences & Earth Sciences (AREA)
- Artificial Intelligence (AREA)
- General Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- Evolutionary Computation (AREA)
- Bioinformatics & Computational Biology (AREA)
- Computational Linguistics (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Health & Medical Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Evolutionary Biology (AREA)
- General Health & Medical Sciences (AREA)
- Molecular Biology (AREA)
- Computing Systems (AREA)
- Mathematical Physics (AREA)
- Software Systems (AREA)
- Image Analysis (AREA)
Abstract
本公开提供了一种模型训练方法、装置、设备及存储介质,涉及计算机技术领域,进一步涉及计算机视觉和深度学习等人工智能技术。具体实现方案为:确定基于树结构的待训练模型在训练过程中需要使用的期望容量;在本地设备中图像处理器的可用容量小于所述期望容量的情况下,对所述待训练模型的树结构进行拆分,得到至少两个子树;通过本地设备中图像处理器,分别采用所述至少两个子树关联的样本集,对待训练模型进行训练。通过本公开的技术,为通过图像处理器对规模较大的TDM模型进行训练提供了一种新思路。
Description
技术领域
本公开涉及计算机技术领域,尤其涉及计算机视觉和深度学习等人工智能领域,具体涉及一种模型训练方法、装置、设备以及存储介质。
背景技术
随着深度学习等人工智能技术的快速发展,人工智能技术已经广泛应用于计算机视觉领域,即基于人工智能技术训练计算机视觉任务模型。比如为了能够从海量数据集(比如海量广告或者商品等)中高效且精准的召回用户感兴趣的内容,基于树的深度模型(Tree-based Deep Model,TDM)被广泛使用。
目前,对于数据量比较大的TDM模型,通常采用分布式CPU集群进行训练。但受限于CPU的硬件能力等,无法支持复杂的TDM模型,因此亟需提供一种新的模型训练方法,用于训练TDM模型。
发明内容
本公开提供了一种模型训练方法、装置、设备及存储介质。
根据本公开的一方面,提供了一种模型训练方法,该方法包括:
确定基于树结构的待训练模型在训练过程中需要使用的期望容量;
在本地设备中图像处理器的可用容量小于所述期望容量的情况下,对所述待训练模型的树结构进行拆分,得到至少两个子树;
通过本地设备中图像处理器,分别采用所述至少两个子树关联的样本集,对待训练模型进行训练。
根据本公开的另一方面,提供了一种模型训练装置,该装置包括:
期望容量确定模块,用于确定基于树结构的待训练模型在训练过程中需要使用的期望容量;
树拆分模块,用于在本地设备中图像处理器的可用容量小于所述期望容量的情况下,对所述待训练模型的树结构进行拆分,得到至少两个子树;
训练模块,用于通过本地设备中图像处理器,分别采用所述至少两个子树关联的样本集,对待训练模型进行训练。
根据本公开的另一方面,提供了一种电子设备,该电子设备包括:
至少一个处理器;以及
与所述至少一个处理器通信连接的存储器;其中,
所述存储器存储有可被所述至少一个处理器执行的指令,所述指令被所述至少一个处理器执行,以使所述至少一个处理器能够执行本公开任一实施例所述的模型训练方法。
根据本公开的另一方面,提供了一种存储有计算机指令的非瞬时计算机可读存储介质,其中,所述计算机指令用于使计算机执行本公开任一实施例所述的模型训练方法。
根据本公开的另一方面,提供了一种计算机程序产品,包括计算机程序,所述计算机程序在被处理器执行时实现本公开任一实施例所述的模型训练方法。
根据本公开的技术,能够支持复杂的TDM模型的训练。
应当理解,本部分所描述的内容并非旨在标识本公开的实施例的关键或重要特征,也不用于限制本公开的范围。本公开的其它特征将通过以下的说明书而变得容易理解。
附图说明
附图用于更好地理解本方案,不构成对本公开的限定。其中:
图1A是根据本公开实施例提供的一种模型训练方法的流程图;
图1B是根据本公开实施例提供的一种树结构的结构示意图;
图1C是根据本公开实施例提供的一种对树结构拆分后子树的结构示意图;
图2是根据本公开实施例提供的另一种模型训练方法的流程图;
图3是根据本公开实施例提供的又一种模型训练方法的流程图;
图4是根据本公开实施例提供的一种模型训练装置的结构示意图;
图5是用来实现本公开实施例的模型训练方法的电子设备的框图。
具体实施方式
以下结合附图对本公开的示范性实施例做出说明,其中包括本公开实施例的各种细节以助于理解,应当将它们认为仅仅是示范性的。因此,本领域普通技术人员应当认识到,可以对这里描述的实施例做出各种改变和修改,而不会背离本公开的范围和精神。同样,为了清楚和简明,以下的描述中省略了对公知功能和结构的描述。
图1A是根据本公开实施例提供的一种模型训练方法的流程图。本公开实施例适用于如何对TDM模型进行训练的情况,尤其适用于在TDM模型的数据量(或者可以说规模)比较大的场景下,如何对TDM模型进行训练的情况。该实施例可以由配置在电子设备中的模型训练装置来执行,该装置可以采用软件和/或硬件来实现。如图1A所示,该模型训练方法包括:
S101,确定基于树结构的待训练模型在训练过程中需要使用的期望容量。
本实施例中,所谓基于树结构的待训练模型即为待训练的TDM模型。其中,待训练模型中的树结构可以是对大量数据进行聚合得到;可选的,需求或者应用场景不同,进行聚合的数据不同,进而待训练模型的用途不同。例如,可以对大量的广告(具体为广告的特征表示)进行聚合得到待训练模型的树结构,进而对待训练模型进行训练可以得到用于召回用户感兴趣广告的一个模型。进一步的,树结构中叶子节点表示真实的场景数据比如广告,树结构中叶子节点之外的其他节点表示广告之间的共性信息(也可称为虚拟广告)。比如图1B所示的树结构,节点8和节点9等代表不同的广告;节点4表示节点8和节点9的共性信息,即节点8所代表的广告与节点9所代表的广告之间的共性信息。
可选的,可根据待训练模型的树结构中所包括的节点数量和节点参数维度等,确定期望容量;其中,所谓节点参数为节点的特征表示(即embedding),比如广告场景下,叶子节点参数可以为广告的特征表示,叶子节点之外的其他节点参数为共性信息的特征表示;待训练模型的树结构中各个节点的参数维度相同。具体的,可以将节点数量、节点参数维度以及每一维度参数的字节数相乘,并将乘积作为期望容量。
作为一种可选实施方式,还可以结合待训练模型的所需的其他数据,比如训练样本等来确定期望容量。例如,可以根据待训练模型的树结构中所包括的节点数量、节点参数维度和所需训练样本数量等,确定期望容量。
S102,在本地设备中图像处理器的可用容量小于期望容量的情况下,对待训练模型的树结构进行拆分,得到至少两个子树。
本实施例中,图像处理器(Graphics Processing Unit,GPU)也可以称为图形处理器,或者微型处理器等。可选的,本地设备即为执行模型训练的任一电子设备;进一步的,本实施例中本地设备中配置有至少两个图像处理器,作为计算资源,以大幅度提升本地设备的硬件能力,进而提升模型训练的效率。
其中,本地设备中图像处理器的可用容量为本地设备中当前可供使用的所有图像处理器的容量之和,也可称为全局容量。
可选的,在确定基于树结构的待训练模型在训练过程中需要使用的期望容量之后,可以将所确定的期望容量与本地设备中图像处理器的可用容量进行比较,若本地设备中图像处理器的可用容量等于或大于期望容量,则说明本地设备的图像处理器中可以容纳下所有节点参数等数据,此时可以控制本地设备中至少两个图像处理器,基于树结构中节点参数,采用大量样本,对待训练模型(具体可以为对待训练模型中的神经网络模型)进行训练,以更新待训练模型的神经网络参数和树结构中节点参数,直至神经网络参数和节点参数等收敛,得到训练好的模型。
进一步的,若本地设备中图像处理器的可用容量小于期望容量,则说明本地设备的图像处理器中无法容纳下所有节点参数等数据,此时对待训练模型的树结构进行拆分,以得到至少两个子树,之后基于拆分的子树关联的样本集,对待训练模型进行训练。
本实施例中,对待训练模型的树结构进行拆分的方式有很多种,作为本公开的一种可选实施方式,可以按照左右对称方式对待训练模型的树结构进行拆分。比如,待训练模型的树结构如图1B所示,按照左右对称方式,可以将待训练模型的树结构拆分成如图1C所示的两个子树。
作为本公开的又一种可选实施方式,还可以按层对待训练模型的树结构进行拆分。比如,继续参见图1B,由于叶子节点的数量比较多,因此可以将前三层(其中根节点作为第一层)划分到一起,最后一层(即叶子节点)划分到一起。
S103,通过本地设备中图像处理器,分别采用至少两个子树关联的样本集,对待训练模型进行训练。
可选的,本实施例中可以基于子树关联的样本集分阶段对待训练模型进行训练。例如,对待训练模型的树结构进行拆分,得到两个子树,分别为子树A和子树B;进而可以控制本地设备中图像处理器(优选为至少两个图像处理器),基于子树A关联的样本集,对待训练模型进行训练;在完成该阶段的训练之后,可以控制本地设备中至少两个图像处理器,在基于子树A关联的样本集对待训练模型进行训练完成后的基础上,基于子树B关联的样本集,再次对待训练模型进行训练,直至神经网络参数和节点参数等收敛,得到训练好的模型。
本公开实施例的技术方案,通过引入图像处理器对待训练模型进行训练,相比于现有采用CPU对TDM模型进行训练而言,降低了模型训练成本,提升了训练模型的效率;同时,本公开实施例根据待训练模型在训练过程中需要使用的期望容量和图像处理器的可用容量之间的比较结果,在待训练模型的数据量比较大即期望容量大于需求容量的情况下,动态对待训练模型的树结构进行拆分,进而控制图像处理器,根据拆分的各个子树关联的样本集,对待训练模型进行训练,为通过图像处理器对规模较大的TDM模型进行训练提供了一种新思路。
可选的,作为本公开实施例的一种可选实施方式,还可以根据本地设备中图像处理器的可用容量,以及待训练模型的树结构中叶子节点分布情况,对待训练模型的树结构进行拆分,得到至少两个子树。本实施例中叶子节点分布情况可以包括叶子节点在树结构中的分布位置,以及叶子节点分布的稀疏程度等。
例如可以是,根据本地设备中图像处理器的可用容量和节点参数维度等,确定本地设备中图像处理器所能容纳的节点数量;根据节点数量和叶子节点分布情况,对待训练模型的树结构进行拆分,得到至少两个子树。可选的,每一子树中均包括叶子节点;进一步的,每一子树中至少包括两个叶子节点。
需要说明的是,本实施例中在对待训练模型的树结构进行拆分时,考虑图像处理器的可用容量,能够使得图像处理器的资源被充分利用,为后续高效训练模型奠定了基础;同时,在对待训练模型的树结构进行拆分时,还考虑待训练模型的树结构中叶子节点分布情况,保证每一子树中都包括叶子节点,以在模型训练过程中增加树的层次信息,为提升模型的准确度奠定了基础。
图2是根据本公开实施例提供的另一种模型训练方法的流程图。本实施例在上述实施例的基础上,对如何通过本地设备中图像处理器,分别采用至少两个子树关联的样本集,对待训练模型进行训练进行解释说明。如图2所示,该模型训练方法包括:
S201,确定基于树结构的待训练模型在训练过程中需要使用的期望容量。
S202,在本地设备中图像处理器的可用容量小于期望容量的情况下,对待训练模型的树结构进行拆分,得到至少两个子树。
S203,针对每一子树,根据该子树所关联的种子样本和该子树的子树结构,生成该子树的样本集;控制本地设备中图像处理器,基于该子树中节点参数,采用该子树的样本集,对待训练模型进行训练,以更新待训练模型的神经网络参数和该子树中节点参数。
本实施例中,所谓种子样本也可称为基准样本,从实际使用场景中获取。比如广告场景下,基于用户输入的请求消息(即query)召回相应的广告,以此将请求消息和广告作为种子样本;可选的,每一种子样本可携带有一个标签(即label),进而对于任一种子样本而言,可由请求消息、广告和标签等组成。对于种子样本而言,标签为1(表示从大量数据中这一数据被召回的概率比较大);相应的标签为0,表示数据被召回的概率比较小。进一步的,任一种子样本,对应树结构中一个叶子节点。
可选的,对于任一子树的样本集可以包括正样本集和负样本集。进一步的,对于任一子树,可根据该子树所关联的种子样本和该子树的子树结构,生成该子树的样本集。其中,对于任一子树所关联的种子样本优选为多个。作为本公开实施例的一种可选实施方式,对于任一子树的正样本集和负样本集可通过如下方式确定:将种子样本在子树结构中所属的支路,作为目标支路;根据种子样本中的请求消息,以及目标支路所包括的目标节点,构成该子树的正样本集;根据种子样本中的请求消息,以及该子树中除目标节点之外的其他节点,构建该子树的负样本集。
具体的,针对该子树关联的每一种子样本,将该种子样本所对应的叶子节点在该子树的子树结构中所属的支路,作为目标支路。例如图1C中节点8所在的这一子树,种子样本所对应的叶子节点为节点9,进而可以将节点9在该子树的子树结构中所属的支路,即节点1-节点2-节点4-节点9,作为目标支路。进一步的,可以将位于目标支路上的节点作为目标节点,即将节点1、节点2、节点4和节点9作为目标节点。
可选的,可以从位于目标支路上的目标节点中选取一部分或全部构建正样本。比如可以基于每一目标节点构建一个正样本。继续参见图1C,比如由户的请求消息、节点9所表示的广告和标签数据(具体为1)构成一个正样本,由户的请求消息、节点4所表示的共性信息和标签数据(具体为1)构成一个正样本,由户的请求消息、节点2所表示的共性信息和标签数据(具体为1)构成一个正样本,以及由户的请求消息、节点1所表示的共性信息和标签数据(具体为1)构成一个正样本。
进一步的,可以按层从该子树中除目标节点之外的其他节点中随机选取一个或多个,进而基于种子样本中的请求消息,以及随机选取的节点,构建负样本。继续参见图1C,例如随机从每层中抽取一个其他节点,比如抽取节点8和节点5,进而可以基于节点8和节点5构建两个负样本。比如,由用户的请求消息、节点8所表示的广告和标签数据(具体为0)构成一个负样本,由用户的请求消息、节点5所表示的共性信息和标签数据(具体为0)构成一个负样本。
可选的,针对该子树关联的每一种子样本,通过采用上述方式可构建一些正样本和负样本;进而基于该子树关联的所有种子样本所构建的正样本,可得到该子树的正样本集;同理,基于该子树关联的所有种子样本所构建的负样本,可得到该子树的负样本集。
进一步的,基于神经网络模型对输入数据格式的限定,本实施例中任一子树的样本集中的任一样本均向量化即embedding。比如广告场景下,样本集中任一样本可以由用户的请求消息的embedding、广告或广告的共性信息的embedding以及label等按照设定的格式构成。为了降低传输数据所带来的网络开销,任一子树的样本集中的任一样本可以由用户的请求消息的embedding、节点标识以及label等按照设定的格式构成。其中,节点标识可以是节点的编号。示例性的,本地设备中可以存储节点标识与节点参数之间的对应关系表,进而在基于任一子树关联的样本集对待训练模型进行训练之前,可以将该子树中节点参数提前传输至本地设备的图像处理器中。进一步的,本实施例中对应关系表中各个节点参数是基于聚合得到,也可以称为初始参数,在模型训练过程中会动态更新。
需要说明的是,本实施例通过引入构建样本集的方式,丰富了本地设备的功能,为后续进行模型训练提供了数据支撑。
具体的,本实施例中,在得到至少两个子树之后,可以从至少两个子树中随机选取一个子树,比如子树A,进而可以根据子树A所关联的种子样本和子树A的子树结构,生成子树A的样本集;之后向本地设备中的至少两个图像处理器传输子树A的样本集,并控制本地设备中的至少两个图像处理器配合,基于子树A中节点参数,采用子树A的样本集,对待训练模型中的神经网络模型进行训练,以更新神经网络参数和子树A中节点参数。进一步的,根据子树A所关联的种子样本和子树A的子树结构,生成子树A的样本集之前、同时或之后,可以向本地设备中的至少两个图像处理器传输子树A中节点参数。此外,每个图像处理器中均存储神经网络模型,且每个图像处理器所存储的神经网络模型的初始参数相同。
进一步的,在基于子树A关联的样本集对待训练模型进行训练完成后,可以从至少两个图像处理器中导出更新后的子树A中的节点参数并存储。
示例性的,基于子树A关联的样本集对待训练模型进行训练完成后,可以随机从剩余子树中选取一个子树,比如子树B,可以根据子树B所关联的种子样本和子树B的子树结构,生成子树B的样本集;之后向本地设备中的至少两个图像处理器传输子树B的样本集,并控制本地设备中的至少两个图像处理器配合,在基于子树A关联的样本集所更新后的神经网络参数(为便于区分,此处称为当前神经网络参数)的基础上,基于子树B中节点参数,采用子树B的样本集,对待训练模型中的神经网络模型再次进行训练,以更新当前神经网络参数和子树B中节点参数。进一步的,在基于子树B关联的样本集对待训练模型进行训练完成后,可以从至少两个图像处理器中导出更新后的子树B中的节点参数并存储。
重复上述操作,直至待训练模型的树结构中所有节点参数均被更新,停止模型训练。此时可以将训练好的神经网络模型从图像处理器中导出,并与从图像处理器导出的树结构中所有节点参数对应存储。
本公开实施例的技术方案,通过在生成子树的样本集的过程中,结合子树关联的种子样本和子树的子树结构,使得所生成的样本集能够体现树的层次信息,即通过样本集可以呈现出树的结构,为提升模型的准确度奠定了基础。同时通过在待训练模型的数据量比较大即期望容量大于需求容量的情况下,动态对待训练模型的树结构进行拆分,进而控制图像处理器,根据拆分的各个子树关联的样本集,对待训练模型进行训练,为通过图像处理器对规模较大的TDM模型进行训练提供了一种新思路。
图3是根据本公开实施例提供的又一种模型训练方法的流程图。本实施例在上述实施例的基础上,进一步对如何通过本地设备中图像处理器,分别采用至少两个子树关联的样本集,对待训练模型进行训练进行解释说明。如图3所示,该模型训练方法包括:
S301,确定基于树结构的待训练模型在训练过程中需要使用的期望容量。
S302,在本地设备中图像处理器的可用容量小于期望容量的情况下,对待训练模型的树结构进行拆分,得到至少两个子树;其中,本地设置中配置有至少两个图像处理器。
S303,针对每一子树,根据该子树所关联的种子样本和该子树的子树结构,生成该子树的样本集;将该子树的样本集分配给至少两个图像处理器,以得到图像处理器关联的子样本集;通过图像处理器采用关联的子样本集,对待训练模型进行训练。
可选的,本实施例中,在得到至少两个子树之后,可以从至少两个子树中随机选取一个子树,比如子树A;采用子树A中各节点标识(比如节点编号),分别对本地设备中图像处理器的总数量(比如本地设备中包括8个图像处理器,即总数量为8)进行取余运算。可选的,本实施例中为本地设备中每个图像处理器分配一个唯一标识,例如为每个图像处理器分配一个编号(比如0、1…)。进而可以对于子树A中每个节点,可以将该节点参数传输至编号与该节点的取余结果相等的图像处理器中。例如,节点1对8取余的结果为1,进而可以将节点1的参数传输至图像处理器1中。
之后可以根据子树A所关联的种子样本和子树A的子树结构,生成子树A的样本集;对于子树A中的每个节点,可以从样本集中获取该节点所对应的样本,并传输至存储该节点参数的图像处理器中;通过此操作,可以得到每个图像处理器关联的子样本集;采用异步方式,控制各个图像处理器采用关联的子样本集,对待训练模型中的神经网络模型进行训练,以更新子树A中节点参数。
作为本公开实施例的一种可选方式,可以将子树A中每一种子样本关联的正样本和负样本作为一个数据单元,进而子树A的样本集可以看成由多个数据单元组成。将子树A的样本集中的数据单元随机分配给至少两个图像处理器,以此可得到每个图像处理器关联的子样本集。之后,可以控制各个图像处理器采用关联的子样本集,对待训练模型中的神经网络模型进行训练,以更新子树A中节点参数。可选的,任意两个处理器所分配到的子样本集不同。需要说明的是,本实施例中通过将每一种子样本关联的正样本和负样本作为一个整体即数据单元,且以数据单元为分配最小单位,使得在模型训练过程中具有树的层次信息,为提升模型的准确度奠定了基础。
比如,本地设备中包括8个图像处理器,即GPU0至GPU7。假设子树A为图1C中节点8所在的这一子树,某一种子样本对应的叶子节点为图1C中的节点9,该种子样本关联的正样本可以包括基于节点9、节点4、节点2以及节点1分别构建的样本,该种子样本关联的负样本可以包括基于节点8和节点5分别构建的样本。进一步的,节点1和节点9的参数存储于GPU1中,节点2的参数存储于GPU2中,节点4的参数存储于GPU4中,节点5的参数存储于GPU5中,节点8的参数存储于GPU0中。
如果该种子样本关联的这一数据单元被随机分配至GPU1中,GPU1可以与GPU0、GPU2、GPU4和GPU5通信,以获取节点8、节点2、节点4和节点5的参数;GPU1将这些节点参数以及相关样本数据作为待训练模型中神经网络模型的输入,开始进行前向和后向训练;进一步的,后向训练会得到节点参数对应的梯度数据,此时GPU1会用节点1和9的梯度更新本地的节点参数,同时会将节点2、节点4、节点5、节点8的梯度数据分别传递给GPU2、GPU4、GPU5和GPU0,以更新对应节点参数。
在GPU1基于其所关联的子样本集进行训练的过程中,本地设备中的其他GPU也在基于关联的子样本进行训练。进一步的,每个GPU在基于所获取的子样本集进行训练时,若GPU本地存储的节点参数已更新,则采用更新后的节点参数进行后续训练。可选的,由于样本集中样本数据量较大,可以将样本集划分为多个批次;每一批次,均会向各个GPU分配子样本集。直至子树A的样本集中的样本均用于训练模型,且子树A中的节点参数均被更新且收敛,说明基于子树A关联的样本集对待训练模型进行训练完成。此时可以从至少两个图像处理器中导出更新后的子树A中的节点参数并存储。
进一步的,基于子树A关联的样本集对待训练模型进行训练完成后,可以随机从剩余子树中选取一个子树,比如子树B,重复上述操作,以子树B中节点参数等。
重复上述操作,直至待训练模型的树结构中所有节点参数均被更新,停止模型训练。此时可以将训练好的神经网络模型从图像处理器中导出,并与从图像处理器导出的树结构中所有节点参数对应存储。
本公开实施例的技术方案,通过至少两个图像处理器作为计算资源,配合对待训练模型进行训练,大幅度提升了本地设备的硬件能力,进而提升了模型训练的效率。
图4是根据本公开实施例提供的一种模型训练装置的结构示意图。公开实施例适用于如何对TDM模型进行训练的情况,尤其适用于在TDM模型的数据量(或者可以说规模)比较大的场景下,如何对TDM模型进行训练的情况。该装置可以采用软件和/或硬件来实现,该装置可实现本公开任意实施例所述的模型训练方法。如图4所示,该模型训练装置包括:
期望容量确定模块401,用于确定基于树结构的待训练模型在训练过程中需要使用的期望容量;
树拆分模块402,用于在本地设备中图像处理器的可用容量小于期望容量的情况下,对待训练模型的树结构进行拆分,得到至少两个子树;
训练模块403,用于通过本地设备中图像处理器,分别采用所述至少两个子树关联的样本集,对待训练模型进行训练。
本公开实施例的技术方案,通过引入图像处理器对待训练模型进行训练,相比于现有采用CPU对TDM模型进行训练而言,降低了模型训练成本,提升了训练模型的效率;同时,本公开实施例根据待训练模型在训练过程中需要使用的期望容量和图像处理器的可用容量之间的比较结果,在待训练模型的数据量比较大即期望容量大于需求容量的情况下,动态对待训练模型的树结构进行拆分,进而控制图像处理器,根据拆分的各个子树关联的样本集,对待训练模型进行训练,为通过图像处理器对规模较大的TDM模型进行训练提供了一种新思路。
示例性的,树拆分模块402具体用于:
根据本地设备中图像处理器的可用容量,以及待训练模型的树结构中叶子节点分布情况,对待训练模型的树结构进行拆分,得到至少两个子树。
示例性的,训练模块403包括:
样本集生成单元,用于针对每一子树,根据该子树所关联的种子样本和该子树的子树结构,生成该子树的样本集;
训练单元,用于控制本地设备中图像处理器,基于该子树中节点参数,采用该子树的样本集,对待训练模型进行训练,以更新待训练模型的神经网络参数和该子树中节点参数。
示例性的,样本集生成单元具体用于:
将种子样本在子树结构中所属的支路,作为目标支路;
根据种子样本中的请求消息,以及目标支路所包括的目标节点,构成该子树的正样本集;
根据种子样本中的请求消息,以及该子树中除所述目标节点之外的其他节点,构建该子树的负样本集。
示例性的,训练单元具体用于:
将该子树的样本集分配给至少两个图像处理器,以得到图像处理器关联的子样本集;
通过图像处理器采用关联的子样本集,对待训练模型进行训练。
示例性的,上述装置还包括:
特征传输单元,用于向本地设备中图像处理器传输该子树中节点参数。
根据本公开的实施例,本公开还提供了一种电子设备、一种可读存储介质和一种计算机程序产品。
图5示出了可以用来实施本公开的实施例的示例电子设备500的示意性框图。电子设备旨在表示各种形式的数字计算机,诸如,膝上型计算机、台式计算机、工作台、个人数字助理、服务器、刀片式服务器、大型计算机、和其它适合的计算机。电子设备还可以表示各种形式的移动装置,诸如,个人数字处理、蜂窝电话、智能电话、可穿戴设备和其它类似的计算装置。本文所示的部件、它们的连接和关系、以及它们的功能仅仅作为示例,并且不意在限制本文中描述的和/或者要求的本公开的实现。
如图5所示,电子设备500包括计算单元501,其可以根据存储在只读存储器(ROM)502中的计算机程序或者从存储单元508加载到随机访问存储器(RAM)503中的计算机程序,来执行各种适当的动作和处理。在RAM 503中,还可存储电子设备500操作所需的各种程序和数据。计算单元501、ROM 502以及RAM 503通过总线504彼此相连。输入/输出(I/O)接口505也连接至总线504。
电子设备500中的多个部件连接至I/O接口505,包括:输入单元506,例如键盘、鼠标等;输出单元507,例如各种类型的显示器、扬声器等;存储单元508,例如磁盘、光盘等;以及通信单元509,例如网卡、调制解调器、无线通信收发机等。通信单元509允许电子设备500通过诸如因特网的计算机网络和/或各种电信网络与其他设备交换信息/数据。
计算单元501可以是各种具有处理和计算能力的通用和/或专用处理组件。计算单元501的一些示例包括但不限于中央处理单元(CPU)、图形处理单元(GPU)、各种专用的人工智能(AI)计算芯片、各种运行机器学习模型算法的计算单元、数字信号处理器(DSP)、以及任何适当的处理器、控制器、微控制器等。计算单元501执行上文所描述的各个方法和处理,例如模型训练方法。例如,在一些实施例中,模型训练方法可被实现为计算机软件程序,其被有形地包含于机器可读介质,例如存储单元508。在一些实施例中,计算机程序的部分或者全部可以经由ROM 502和/或通信单元509而被载入和/或安装到电子设备500上。当计算机程序加载到RAM 503并由计算单元501执行时,可以执行上文描述的模型训练方法的一个或多个步骤。备选地,在其他实施例中,计算单元501可以通过其他任何适当的方式(例如,借助于固件)而被配置为执行模型训练方法。
本文中以上描述的系统和技术的各种实施方式可以在数字电子电路系统、集成电路系统、场可编程门阵列(FPGA)、专用集成电路(ASIC)、专用标准产品(ASSP)、芯片上系统的系统(SOC)、负载可编程逻辑设备(CPLD)、计算机硬件、固件、软件、和/或它们的组合中实现。这些各种实施方式可以包括:实施在一个或者多个计算机程序中,该一个或者多个计算机程序可在包括至少一个可编程处理器的可编程系统上执行和/或解释,该可编程处理器可以是专用或者通用可编程处理器,可以从存储系统、至少一个输入装置、和至少一个输出装置接收数据和指令,并且将数据和指令传输至该存储系统、该至少一个输入装置、和该至少一个输出装置。
用于实施本公开的方法的程序代码可以采用一个或多个编程语言的任何组合来编写。这些程序代码可以提供给通用计算机、专用计算机或其他可编程数据处理装置的处理器或控制器,使得程序代码当由处理器或控制器执行时使流程图和/或框图中所规定的功能/操作被实施。程序代码可以完全在机器上执行、部分地在机器上执行,作为独立软件包部分地在机器上执行且部分地在远程机器上执行或完全在远程机器或服务器上执行。
在本公开的上下文中,机器可读介质可以是有形的介质,其可以包含或存储以供指令执行系统、装置或设备使用或与指令执行系统、装置或设备结合地使用的程序。机器可读介质可以是机器可读信号介质或机器可读储存介质。机器可读介质可以包括但不限于电子的、磁性的、光学的、电磁的、红外的、或半导体系统、装置或设备,或者上述内容的任何合适组合。机器可读存储介质的更具体示例会包括基于一个或多个线的电气连接、便携式计算机盘、硬盘、随机存取存储器(RAM)、只读存储器(ROM)、可擦除可编程只读存储器(EPROM或快闪存储器)、光纤、便捷式紧凑盘只读存储器(CD-ROM)、光学储存设备、磁储存设备、或上述内容的任何合适组合。
为了提供与用户的交互,可以在计算机上实施此处描述的系统和技术,该计算机具有:用于向用户显示信息的显示装置(例如,CRT(阴极射线管)或者LCD(液晶显示器)监视器);以及键盘和指向装置(例如,鼠标或者轨迹球),用户可以通过该键盘和该指向装置来将输入提供给计算机。其它种类的装置还可以用于提供与用户的交互;例如,提供给用户的反馈可以是任何形式的传感反馈(例如,视觉反馈、听觉反馈、或者触觉反馈);并且可以用任何形式(包括声输入、语音输入或者、触觉输入)来接收来自用户的输入。
可以将此处描述的系统和技术实施在包括后台部件的计算系统(例如,作为数据服务器)、或者包括中间件部件的计算系统(例如,应用服务器)、或者包括前端部件的计算系统(例如,具有图形用户界面或者网络浏览器的用户计算机,用户可以通过该图形用户界面或者该网络浏览器来与此处描述的系统和技术的实施方式交互)、或者包括这种后台部件、中间件部件、或者前端部件的任何组合的计算系统中。可以通过任何形式或者介质的数字数据通信(例如,通信网络)来将系统的部件相互连接。通信网络的示例包括:局域网(LAN)、广域网(WAN)、区块链网络和互联网。
计算机系统可以包括客户端和服务器。客户端和服务器一般远离彼此并且通常通过通信网络进行交互。通过在相应的计算机上运行并且彼此具有客户端-服务器关系的计算机程序来产生客户端和服务器的关系。服务器可以是云服务器,又称为云计算服务器或云主机,是云计算服务体系中的一项主机产品,以解决了传统物理主机与VPS服务中,存在的管理难度大,业务扩展性弱的缺陷。
应该理解,可以使用上面所示的各种形式的流程,重新排序、增加或删除步骤。例如,本发公开中记载的各步骤可以并行地执行也可以顺序地执行也可以不同的次序执行,只要能够实现本公开公开的技术方案所期望的结果,本文在此不进行限制。
上述具体实施方式,并不构成对本公开保护范围的限制。本领域技术人员应该明白的是,根据设计要求和其他因素,可以进行各种修改、组合、子组合和替代。任何在本公开的精神和原则之内所作的修改、等同替换和改进等,均应包含在本公开保护范围之内。
Claims (10)
1.一种模型训练方法,包括:
确定基于树结构的待训练模型在训练过程中需要使用的期望容量;
在本地设备中图像处理器的可用容量小于所述期望容量的情况下,对所述待训练模型的树结构进行拆分,得到至少两个子树;所述本地设备中配置有至少两个图像处理器;
针对每一子树,根据该子树所关联的种子样本和该子树的子树结构,生成该子树的样本集;将该子树的样本集分配给至少两个图像处理器,以得到图像处理器关联的子样本集;通过图像处理器采用关联的子样本集,对待训练模型进行训练。
2.根据权利要求1所述的方法,其中,所述对所述待训练模型的树结构进行拆分,得到至少两个子树,包括:
根据本地设备中图像处理器的可用容量,以及所述待训练模型的树结构中叶子节点分布情况,对所述待训练模型的树结构进行拆分,得到至少两个子树。
3.根据权利要求1所述的方法,其中,所述根据该子树所关联的种子样本和该子树的子树结构,生成该子树的样本集,包括:
将所述种子样本在子树结构中所属的支路,作为目标支路;
根据所述种子样本中的请求消息,以及所述目标支路所包括的目标节点,构成该子树的正样本集;
根据所述种子样本中的请求消息,以及该子树中除所述目标节点之外的其他节点,构建该子树的负样本集。
4.根据权利要求1所述的方法,所述将该子树的样本集分配给至少两个图像处理器,以得到图像处理器关联的子样本集之前,还包括:
向所述本地设备中图像处理器传输该子树中节点参数。
5.一种模型训练装置,包括:
期望容量确定模块,用于确定基于树结构的待训练模型在训练过程中需要使用的期望容量;
树拆分模块,用于在本地设备中图像处理器的可用容量小于所述期望容量的情况下,对所述待训练模型的树结构进行拆分,得到至少两个子树;所述本地设备中配置有至少两个图像处理器;
训练模块,用于针对每一子树,根据该子树所关联的种子样本和该子树的子树结构,生成该子树的样本集;将该子树的样本集分配给至少两个图像处理器,以得到图像处理器关联的子样本集;通过图像处理器采用关联的子样本集,对待训练模型进行训练。
6.根据权利要求5所述的装置,其中,所述树拆分模块具体用于:
根据本地设备中图像处理器的可用容量,以及所述待训练模型的树结构中叶子节点分布情况,对所述待训练模型的树结构进行拆分,得到至少两个子树。
7.根据权利要求5所述的装置,其中,所述样本集生成单元具体用于:
将所述种子样本在子树结构中所属的支路,作为目标支路;
根据所述种子样本中的请求消息,以及所述目标支路所包括的目标节点,构成该子树的正样本集;
根据所述种子样本中的请求消息,以及该子树中除所述目标节点之外的其他节点,构建该子树的负样本集。
8.根据权利要求5所述的装置,还包括:
特征传输单元,用于向所述本地设备中图像处理器传输该子树中节点参数。
9.一种电子设备,包括:
至少一个处理器;以及
与所述至少一个处理器通信连接的存储器;其中,
所述存储器存储有可被所述至少一个处理器执行的指令,所述指令被所述至少一个处理器执行,以使所述至少一个处理器能够执行权利要求1-4中任一项所述的模型训练方法。
10.一种存储有计算机指令的非瞬时计算机可读存储介质,其中,所述计算机指令用于使计算机执行根据权利要求1-4中任一项所述的模型训练方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202110615506.3A CN113344074B (zh) | 2021-06-02 | 2021-06-02 | 模型训练方法、装置、设备及存储介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202110615506.3A CN113344074B (zh) | 2021-06-02 | 2021-06-02 | 模型训练方法、装置、设备及存储介质 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN113344074A CN113344074A (zh) | 2021-09-03 |
CN113344074B true CN113344074B (zh) | 2023-09-05 |
Family
ID=77473052
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202110615506.3A Active CN113344074B (zh) | 2021-06-02 | 2021-06-02 | 模型训练方法、装置、设备及存储介质 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN113344074B (zh) |
Families Citing this family (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN114356540B (zh) * | 2021-10-30 | 2024-07-02 | 腾讯科技(深圳)有限公司 | 一种参数更新方法、装置、电子设备和存储介质 |
CN114676795B (zh) * | 2022-05-26 | 2022-08-23 | 鹏城实验室 | 一种深度学习模型的训练方法、装置、设备及存储介质 |
Citations (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN112561077A (zh) * | 2020-12-14 | 2021-03-26 | 北京百度网讯科技有限公司 | 多任务模型的训练方法、装置及电子设备 |
CN112560936A (zh) * | 2020-12-11 | 2021-03-26 | 北京百度网讯科技有限公司 | 模型并行训练方法、装置、设备、存储介质和程序产品 |
CN112749325A (zh) * | 2019-10-31 | 2021-05-04 | 北京京东尚科信息技术有限公司 | 搜索排序模型的训练方法、装置、电子设备及计算机介质 |
Family Cites Families (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US9324040B2 (en) * | 2013-01-30 | 2016-04-26 | Technion Research & Development Foundation Limited | Training ensembles of randomized decision trees |
CN110728317A (zh) * | 2019-09-30 | 2020-01-24 | 腾讯科技(深圳)有限公司 | 决策树模型的训练方法、系统、存储介质及预测方法 |
-
2021
- 2021-06-02 CN CN202110615506.3A patent/CN113344074B/zh active Active
Patent Citations (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN112749325A (zh) * | 2019-10-31 | 2021-05-04 | 北京京东尚科信息技术有限公司 | 搜索排序模型的训练方法、装置、电子设备及计算机介质 |
CN112560936A (zh) * | 2020-12-11 | 2021-03-26 | 北京百度网讯科技有限公司 | 模型并行训练方法、装置、设备、存储介质和程序产品 |
CN112561077A (zh) * | 2020-12-14 | 2021-03-26 | 北京百度网讯科技有限公司 | 多任务模型的训练方法、装置及电子设备 |
Also Published As
Publication number | Publication date |
---|---|
CN113344074A (zh) | 2021-09-03 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
JP7194252B2 (ja) | マルチタスクモデルのパラメータ更新方法、装置及び電子機器 | |
EP4160440A1 (en) | Federated computing processing method and apparatus, electronic device, and storage medium | |
CN113344074B (zh) | 模型训练方法、装置、设备及存储介质 | |
CN114202027B (zh) | 执行配置信息的生成方法、模型训练方法和装置 | |
WO2021023149A1 (zh) | 一种动态返回报文的方法和装置 | |
EP4060496A2 (en) | Method, apparatus, device and storage medium for running inference service platform | |
KR20210105315A (ko) | 데이터 주석 방법, 장치, 기기, 저장매체 및 컴퓨터 프로그램 | |
CN112527506A (zh) | 设备资源的处理方法、装置、电子设备及存储介质 | |
CN112560936B (zh) | 模型并行训练方法、装置、设备、存储介质和程序产品 | |
CN114153986A (zh) | 一种知识图谱构建方法、装置、电子设备及存储介质 | |
CN116524165B (zh) | 三维表情模型的迁移方法、装置、设备和存储介质 | |
CN112559632A (zh) | 分布式图数据库的状态同步方法、装置、电子设备及介质 | |
CN114579311B (zh) | 执行分布式计算任务的方法、装置、设备以及存储介质 | |
US20240275848A1 (en) | Content initialization method, electronic device and storage medium | |
CN115905322A (zh) | 业务处理方法、装置、电子设备及存储介质 | |
CN115454971A (zh) | 数据迁移方法、装置、电子设备及存储介质 | |
CN115730681B (zh) | 模型训练方法、装置、设备以及存储介质 | |
CN116560817B (zh) | 任务执行方法、装置、电子设备和存储介质 | |
CN115860114B (zh) | 深度学习模型的训练方法、装置、电子设备及存储介质 | |
CN116894917B (zh) | 虚拟形象的三维发丝模型的生成方法、装置、设备和介质 | |
CN114398130B (zh) | 页面展示方法、装置、设备和存储介质 | |
CN112311833B (zh) | 数据更新方法和装置 | |
CN114650222A (zh) | 参数配置方法、装置、电子设备和存储介质 | |
CN105989185A (zh) | 例行任务及工具生成的系统配置方法及其系统 | |
CN118433039A (zh) | 模型训练时的节点通信方法、装置 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |