一种机器学习工具中间件及机器学习训练方法
技术领域
本发明属于机器学习技术领域,尤其涉及一种机器学习工具中间件及机器学习训练方法。
背景技术
机器学习是人工智能的一个分支,而在很多时候,几乎成为人工智能的代名词。简单来说,机器学习就是通过机器学习算法模型,使得机器能从大量历史数据中学习规律,从而对新的样本做智能识别或对未来做预测。机器学习的一般过程是从输入数据(输入数据)中计算出机器学习算法模型参数,根据计算得到的模型参数形成机器算法模型,并对新的样本做智能识别或对未来做预测。在很多现实应用中,输入数据非常大,必须由多台计算装置同时处理才能在合理的时间内完成计算,因此必须互相交换模型参数,而交换模型参数由参数服务器来收集参数进行汇总和分发。
现有的大规模机器学习平台是一个封闭的训练框架,首先基于一个可共享的存储空间。另外例如数据文件支持的格式是有限的,模型文件支持的格式是有限的,进行机器学习训练时采用的训练目标和算法是在预先实现的有限方法中进行选择,训练过程中的参数调整方法和停止条件也是预先实现的。
而实际中不同的产品或者业务往往需要不同的数据、模型或者训练方法,基于不同的训练工具实现,这些相关文件以及训练方法往往会有很大的不同。如果基于现有的大规模机器学习平台实现,则需要完全用该平台已有的功能替换,或对该平台进行扩展以便兼容实际的机器学习任务。但是这样做,就需要进行大量的实验对比验证,而且需要对已有产品进行修改以兼容该平台的数据、模型格式。另外,也不能够保证该平台的已有实现能够达到业务上的需求。同时还需要对该平台的实现有深入的了解,而且需要花费大量的时间进行数据格式、模型格式以及训练方法的实现,对用户有很高的要求。
发明内容
本发明的目的是提供一种机器学习工具中间件及机器学习训练方法,使得各种机器学习工具不依赖于大规模机器学习平台,不需要改变具体的模型、数据文件解析、以及核心的训练方法、训练目标,就能够完成训练。
为了实现上述目的,本发明技术方案如下:
一种机器学习工具中间件,用于机器学习工具的模型训练,所述机器学习工具包括至少一个训练单元,每个训练单元都设置有与机器学习工具结合的中间件,所述中间件包括底层通信模块,以及数据分发模块、模型参数更新模块、训练参数调整模块和训练停止判断模块中的至少一块,其中:
所述底层通信模块,用于实现训练单元之间对应模块之间的通信,以及训练单元之间的通信;
所述数据分发模块,用于从数据存储设备中分发需要的数据到训练单元能够访问的存储单元,以便训练单元从所述存储单元中读取数据进行训练;
所述模型参数更新模块,用于收集其他训练单元的训练信息,更新本训练单元的模型参数;
所述训练参数调整模块,用于收集其他训练单元的训练信息,对本训练单元的训练参数进行调整;
所述训练停止判断模块,用于收集其他训练单元的训练信息,来进行是否停止训练的判断。
进一步地,所述数据存储设备用于存储机器学习工具所有训练数据,所述数据存储设备位于机器学习工具的主训练单元上。
进一步地,所述主训练单元的数据分发模块用于接收其他训练单元的数据分发模块的请求,向其他训练单元的数据分发模块分发数据,所述其他训练单元数据的数据分发模块接收分发的数据存储在本训练单元的本地存储单元。
通过设置数据分发模块实现数据的分发,训练数据从主训练单元的存储设备分发到各训练单元的本地存储单元,分发在中间件中实现,不影响训练单元的训练过程。各训练单元不需要在每次训练时到共享的存储设备去提取数据,因此降低了存储设备的工作压力,不需要共享一个大规模存储平台。
进一步地,所述模型参数更新模块收集其他训练单元的训练信息,并且将本训练单元的训练信息传送给其他训练单元,对各训练单元的模型参数进行平均更新模型参数。
或者,所述机器学习工具还包括参数服务器,所述模型参数更新模块将本训练单元的训练信息传送到参数服务器,由参数服务器更新模型参数后发回。
进一步地,所述底层通信模块还用于在实现训练单元之间对应模块之间的通信,以及训练单元之间的通信时,为各种通信之间加上互锁机制。使不同的模块不能够同时的进行通信,当一个模块正在进行通信时,其他的模块需要等待其完成才能进行通信。
本发明还提出了一种机器学习训练方法,用于机器学习工具的模型训练,所述机器学习工具包括至少一个训练单元,每个训练单元都设置有与机器学习工具结合的中间件,所述训练单元通过所述中间件进行通信,训练单元之间通过所述中间件执行如下训练操作中的至少一项完成模型训练,所述训练操作包括:
从数据存储设备中分发需要的数据到各个训练单元能够访问的存储单元,以便各个训练单元从所述存储单元中读取数据进行训练;
收集其他训练单元的训练信息,更新本训练单元的模型参数;
收集其他训练单元的训练信息,对本训练单元的训练参数进行调整;
收集其他训练单元的训练信息,来进行是否停止训练的判断。
本发明提出了一种机器学习工具中间件及机器学习训练方法,通过中间件的数据分发模块分发数据到各训练单元本地存储单元,不再依赖于大规模存储平台。中间件负责进行大规模并行训练所需要的处理:数据分发、模型参数更新、训练参数调整、训练停止同步以及训练单元之间的通信,而不改变具体的模型、数据文件解析,以及核心的训练方法、训练目标,从而不再依赖于大规模机器学习平台。本发明对各种机器学习工具方便扩展,而几乎不影响单个训练单元的训练行为,同时支持对各种数据文件格式的扩展。
附图说明
图1为本发明机器学习工具中间件结构示意图;
图2为本发明机器学习训练与中间件对应关系示意图;
图3为本发明实施例机器学习训练方法流程。
具体实施方式
下面结合附图和实施例对本发明技术方案做进一步详细说明,以下实施例不构成对本发明的限定。
机器学习工具在人工智能领域应用非常广泛,常用的机器学习工具包括Caffe、Kaldi等,机器学习工具根据已知的训练数据训练得到机器学习模型,并采用机器学习模型对未知的数据进行分析以便学习到新的知识。本发明的总体思想是提供一种机器学习工具中间件,使得机器学习工具能够适应不同的训练数据文件格式,并且该中间件能够适用于任何机器学习工具,从而满足基于不同的机器学习工具、不同的训练数据、不同的模型或者训练方法,进行机器学习模型的训练。
如图1所示,本实施例一种机器学习工具中间件,包括:数据分发模块、模型参数更新模块、训练参数调整模块、训练停止判断模块和底层通信模块。
在实际的应用中,本实施例机器学习工具通过调用中间件实现两者的结合,然后将中间件与机器学习工具部署在一个或多个服务器上同时进行训练。在进行模型训练时,机器学习工具包括至少一个基本的机器学习工具进程,用于实现对不同训练数据的并行处理,或者对不同的模型分区进行并行处理,本实施例同时支持这两种分布式并行处理方式。每一个基本的机器学习工具进程称为一个训练单元,例如部署在不同服务器上的机器学习工具及其结合的中间件构成一个训练单元,用以处理一个机器学习工具进程。
在图1中,示例性地列举了两个训练单元1和训练单元2,本发明不限于训练单元数量的多少。每个训练单元包括机器学习工具和对应的中间件,训练单元之间通过底层通信模块连接,在一个训练单元中,数据分发模块、模型参数更新模块、训练参数调整模块、训练停止判断模块均分别与机器学习工具连接,并与底层通信模块连接,底层通信模块还与机器学习工具进行连接。本实施例所述的连接,属于软件程序方面的接口调用,这里不再赘述。
其中,数据分发模块,用于从数据存储设备中分发需要的数据到各个训练单元能够访问的存储单元。
对于具有多个训练单元的机器学习工具来说,训练所用到的所有训练数据通常存储在一个主训练单元的数据存储设备中,各训练单元的数据分发模块向主训练单元对应的数据分发模块请求数据,然后通过网络传输数据文件到本地存储单元,提供给本地的训练单元使用。通常每个训练单元具有自己的数据存储单元,训练数据存储在主训练单元的存储设备中,通过数据分发模块将数据分发到各个训练单元本地的存储单元供各个训练单元使用,各训练单元从本地的存储单元读取训练数据进行训练。本实施例的存储设备和存储单元分别设置,优选地存储单元在训练单元服务器本地,也可以位于各训练单元能够访问其他存储设备。这里数据的分发是后台在中间件上进行的,不会影响训练单元实际的训练过程。这样,在训练单元处理完当前数据文件的时候,就可以直接进行下一数据文件的处理,即中间件数据分发模块已经准备好的数据文件。
模型参数更新模块,用于实现各训练单元之间模型参数的更新。当训练单元处理完若干批次数据需要进行多训练单元更新时,可以通过中间件的模型参数更新模块进行参数更新,即收集其他训练单元的训练信息,并且将本训练单元的训练信息告诉其他训练单元。这里的训练信息可以是模型参数本身,也可以是模型参数更新时的相关参数,比如梯度。而参数更新可以是各个训练单元同步进行,也可以各个训练单元异步进行,还可以通过一个虚拟的参数服务器进行。具体来说,更新方法可以是各训练单元上的模型参数进行平均(同步的),也可以是各个训练单元将梯度发送给参数服务器,由参数服务器将最新的模型参数发回,然后进行下一步的训练(异步的)。
训练参数调整模块,用于对各训练单元的训练参数进行调整。训练参数调整模块与模型参数更新模块类似,主要是将本训练单元的训练目标、学习速率等信息与其他训练单元进行交换,然后进行训练参数的调整。这样每次调整是基于所有训练单元的训练信息统一的进行调整,而不是单个训练单元的训练信息,可以提供更好的调整机制。
训练停止判断模块,用于基于所有训练单元的训练信息来进行是否停止训练的判断。与训练参数调整模块类似,训练停止判断模块是基于所有训练单元的训练信息来进行是否停止训练的判断,而不是单个训练单元的训练信息,这样可以提供更好的停止机制。
底层通信模块,用于实现训练单元之间对应模块之间的通信,以及训练单元之间的通信。
该模块主要是用来处理训练单元之间对应模块的通信,例如训练单元1与训练单元2数据分发模块之间的通信,是通过调用底层通信模块来实现数据的分发;又如两个训练单元对应的模型参数更新模块之间、两个训练单元对应的训练参数调整模块之间、两个训练单元对应的训练停止判断模块之间的通信。
同时可以提供训练单元之间进行一些必要的通信。例如:训练单元可以在具体的训练过程中通过调用底层通信模块来不断的同步综合所有训练单元的训练表现,比如训练的客观指标。又例如各个训练单元可以在具体的训练过程中通过调用底层通信模块来进行训练单元之间的统一行为控制,比如何时一致的进行实际的训练,何时一致的进行指定的测试。
同时,为了进行无风险的通信,需要在各种通信之间加上互锁机制,以保证通信安全。在某些底层的系统通信实现上,比如MPI通信协议,并不能够充分的支持多线程自由的调用进行通信。也就是说,存在一些系统底层通信协议使得不允许多个模块同时进行通信。为了保护通信安全,本实施例在底层通信模块上加入了互锁机制,使不同的模块不能够同时的进行通信,当一个模块正在进行通信时,其他的模块需要等待其完成才能进行通信。
如图2所示,采用本实施例中间件,进行一个典型的机器学习训练过程如下:
各个训练单元同时启动,主要的训练单元(能够访问模型文件、数据文件)将模型文件通过中间件底层通信模块传输给其他所有训练单元,各个训练单元读入模型文件。然后各训练单元通过中间件数据分发模块向存储有训练数据的主训练单元数据分发模块请求训练数据,主训练单元中间件数据分发模块响应请求,分发训练数据到各训练单元的本地存储单元。各个训练单元读入中间件数据分发模块准备好的数据文件,进行训练处理;同时,中间件数据分发模块继续在后台进行数据分发,准备下一批次的数据文件。
通过中间件模型参数更新模块进行参数更新,即收集其他训练单元的训练信息,并且将本训练单元的训练信息告诉其他训练单元。训练单元按照自身的训练目标以及训练方法处理完每一批次数据处理之后,通过中间件模型参数更新模块更新模型参数。或各个训练单元模型参数更新模块将梯度发送给参数服务器,由参数服务器将最新的模型参数发回,然后进行下一步的训练。
训练参数调整模块将本训练单元的训练目标、学习速率等信息与其他训练单元进行交换,然后通过中间件训练参数调整模块调整训练参数。
类似地,训练停止判断模块收集其他训练单元的训练信息,并且将本训练单元的训练信息告诉其他训练单元,基于所有训练单元的训练信息来进行是否停止训练的判断。训练单元进行每一批次数据处理的时候,通过中间件训练停止判断模块判断是否停止训练。如果判断停止,则结束训练,输出学习到的模型,否则返回继续读取训练数据,进行下一批训练数据的训练,直到完成训练过程。
上述各模块间相互传送信息数据都通过底层通信模块来进行传输。
通过上述过程,多个训练单元进行机器模型任务处理时,就可以根据自身的训练方法、算法不断的进行模型参数、训练参数的更新,对自身的模型、数据格式文件进行处理,达到大规模并行化处理的目的。
需要说明的是,本实施例的中间件中只有底层通信模块是必须的,其他模块可以根据具体的机器学习工具选择需要的模块组合。
例如:有些机器学习工具有自身的一些训练参数调整方法,这样用户就可以选择不使用本发明中的训练参数调整模块,而采用机器学习工具本身的方法,同时用本发明中的底层通信模块来同步各个机器学习程序上的训练参数,保证整体一致。又如有些机器学习工具不允许在运行时动态的读取新的数据文件,因此用户可以选择不使用本发明中的数据分发模块,而只是在训练开始前把数据先分发到各个机器上,训练时各个训练单元直接读取本机已经分发好的训练数据开始训练即可。
如图3所示,本发明实施例一种机器学习训练方法,用于机器学习工具的模型训练,该机器学习工具包括至少一个训练单元,每个训练单元都设置有与机器学习工具结合的中间件,训练单元通过中间件进行通信,训练单元之间通过所述中间件执行如下训练操作中的至少一项完成模型训练,训练操作包括:
从数据存储设备中分发需要的数据到各个训练单元能够访问的存储单元,以便各个训练单元从所述存储单元中读取数据进行训练;
收集其他训练单元的训练信息,更新本训练单元的模型参数;
收集其他训练单元的训练信息,对本训练单元的训练参数进行调整;
收集其他训练单元的训练信息,来进行是否停止训练的判断。
上述训练操作通过中间件进行,包括数据的分发、进行参数更新、调整训练参数和停止训练的判断。各训练单元通过中间件向存储有训练数据的主训练单元请求训练数据,主训练单元中间件响应请求,分发训练数据到各训练单元的本地存储单元。各个训练单元读入中间件准备好的数据文件,进行训练处理,同时,中间件在后台进行数据分发,准备下一批次的数据文件。在训练过程中,训练单元按照自身的训练目标以及训练方法处理完每一批次数据处理之后,通过中间件更新模型参数。即收集其他训练单元的训练信息,并且将本训练单元的训练信息告诉其他训练单元;或各个训练单元通过中间件将梯度发送给参数服务器,由参数服务器将最新的模型参数发回,然后进行下一步的训练。训练单元通过中间件将本训练单元的训练目标、学习速率等信息与其他训练单元进行交换,然后通过中间件调整训练参数。类似地,训练单元通过中间件收集其他训练单元的训练信息,并且将本训练单元的训练信息告诉其他训练单元,基于所有训练单元的训练信息来进行是否停止训练的判断。训练单元进行每一批次数据处理的时候,通过中间件训练判断是否停止训练,如果判断停止,则结束训练,输出学习到的模型,否则返回继续读取训练数据,进行下一批训练数据的训练,直到完成训练过程。
以上实施例仅用以说明本发明的技术方案而非对其进行限制,在不背离本发明精神及其实质的情况下,熟悉本领域的技术人员当可根据本发明作出各种相应的改变和变形,但这些相应的改变和变形都应属于本发明所附的权利要求的保护范围。