CN114021473A - 机器学习模型的训练方法、装置、电子设备及存储介质 - Google Patents

机器学习模型的训练方法、装置、电子设备及存储介质 Download PDF

Info

Publication number
CN114021473A
CN114021473A CN202111360276.7A CN202111360276A CN114021473A CN 114021473 A CN114021473 A CN 114021473A CN 202111360276 A CN202111360276 A CN 202111360276A CN 114021473 A CN114021473 A CN 114021473A
Authority
CN
China
Prior art keywords
training
machine learning
learning model
sample set
shared
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202111360276.7A
Other languages
English (en)
Inventor
黄安埠
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
WeBank Co Ltd
Original Assignee
WeBank Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by WeBank Co Ltd filed Critical WeBank Co Ltd
Priority to CN202111360276.7A priority Critical patent/CN114021473A/zh
Publication of CN114021473A publication Critical patent/CN114021473A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F30/00Computer-aided design [CAD]
    • G06F30/20Design optimisation, verification or simulation
    • G06F30/27Design optimisation, verification or simulation using machine learning, e.g. artificial intelligence, neural networks, support vector machines [SVM] or training a model
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • G06F18/245Classification techniques relating to the decision surface
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/25Fusion techniques
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N20/00Machine learning

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Data Mining & Analysis (AREA)
  • Physics & Mathematics (AREA)
  • Evolutionary Computation (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • General Engineering & Computer Science (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Evolutionary Biology (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Software Systems (AREA)
  • Medical Informatics (AREA)
  • Computing Systems (AREA)
  • Mathematical Physics (AREA)
  • Computer Hardware Design (AREA)
  • Geometry (AREA)
  • Management, Administration, Business Operations System, And Electronic Commerce (AREA)

Abstract

本申请提供了一种机器学习模型的训练方法、装置;方法包括:基于共享样本集以及训练方设备的私有样本集,对训练方设备的机器学习模型进行训练,得到训练后的机器学习模型;基于共享样本集,调用训练后的机器学习模型进行预测处理,得到预测值集合;向服务方设备发送预测值集合,预测值集合用于供服务方设备结合其他训练方设备发送的预测值集合进行融合处理,得到融合值集合;接收服务方设备发送的融合值集合,并根据融合值集合更新共享样本集;更新后的共享样本集与私有样本集,用于供训练方设备对机器学习模型进行下一轮的训练。通过本申请,能够充分利用多方样本数据减少模型训练耗时,提升模型训练效率。

Description

机器学习模型的训练方法、装置、电子设备及存储介质
技术领域
本申请涉及人工智能技术,尤其涉及一种机器学习模型的训练方法、装置、电子设备、计算机可读存储介质及计算机程序产品。
背景技术
人工智能(Artificial Intelligence,AI)涉及领域广泛,并发挥越来越重要的价值。机器学习作为人工智能的技术子集,在多个应用领域取得了突破性成果,尤其是在金融领域,能够基于有限的用户数据来预测用户信用,从而为相关业务的开展提供重要依据,以规避金融风险。
以联邦学习为例,是一种基于数据隐私保护技术实现的分布式训练范式,它能保证训练数据在不出本地的前提下,联合多个参与方共同训练一个全局的共享模型。然而,由于共享全局模型,每个参与方只能使用统一的模型结构。但在很多应用场景中,每个参与方要解决的问题可能不相同,需要的模型也不同,相关技术提供的联邦学习训练方案,无法满足各参与方训练个性化机器学习模型的需求。
发明内容
本申请实施例提供一种机器学习模型的训练方法、装置、电子设备、计算机可读存储介质及计算机程序产品,能够满足训练方设备训练个性化机器学习模型的需求,并且能够充分利用多方样本数据减少模型训练耗时,提升模型训练效率。
本申请实施例的技术方案是这样实现的:
本申请实施例提供一种机器学习模型的训练方法,应用于训练方设备,包括:
基于共享样本集以及所述训练方设备的私有样本集,对所述训练方设备的机器学习模型进行训练,得到训练后的所述机器学习模型;
基于所述共享样本集,调用训练后的所述机器学习模型进行预测处理,得到预测值集合;
向服务方设备发送所述预测值集合,所述预测值集合用于供所述服务方设备结合其他训练方设备发送的预测值集合进行融合处理,得到融合值集合;
接收所述服务方设备发送的所述融合值集合,并根据所述融合值集合更新所述共享样本集;其中,更新后的所述共享样本集与所述私有样本集,用于供所述训练方设备对所述机器学习模型进行下一轮的训练。
本申请实施例提供一种机器学习模型的训练装置,包括:
训练模块,用于基于共享样本集以及所述训练方设备的私有样本集,对所述训练方设备的机器学习模型进行训练,得到训练后的所述机器学习模型;
预测模块,用于接收各所述训练方设备发送的对应所述共享样本集的共享预测值;并对所述共享预测值进行融合处理,得到与各所述共享样本对应的融合值;
第一发送模块,用于向服务方设备发送所述预测值集合;其中,所述预测值集合用于供所述服务方设备结合其他训练方设备发送的预测值集合进行融合处理,得到融合值集合;
接收模块,用于接收所述服务方设备发送的所述融合值集合,并根据所述融合值集合更新所述共享样本集;其中,更新后的所述共享样本集与所述私有样本集,用于供所述训练方设备对所述机器学习模型进行下一轮的训练。
上述方案中,所述接收模块,还用于针对所述共享样本集中每个共享样本执行以下处理:
获取所述融合值集合与所述共享样本对应的融合值,基于所述融合值替换所述共享样本的标签。
上述方案中,所述训练模块,在训练所述训练方设备的机器学习模型之前,还用于向所述服务方设备发送样本获取请求;
接收所述服务方设备响应于所述样本获取请求而发送的所述共享样本集;
其中,所述共享样本集包括每个所述训练方设备的私有样本集中的部分样本。
上述方案中,所述训练模块,在训练所述训练方设备的机器学习模型之前,还用于接收所述服务方设备发送的样本上传请求;
从所述训练方设备的私有样本集中选取部分样本,并向所述服务方设备发送所述部分样本。
上述方案中,所述训练模块,还用于基于样本上传比例,从所述训练方设备的私有样本集中选取与所述样本上传比例对应的样本;
或者,
从所述训练方设备的所述私有样本集中选取目标数量的样本。
上述方案中,所述训练模块,还用于在对所述机器学习模型进行第一轮训练之前,向所述服务方设备发送模型获取请求;其中,所述模型获取请求包括模型标识;
接收所述服务方设备发送的与所述模型标识对应的初始机器学习模型,将所述初始机器学习模型作为待训练的机器学习模型。
上述方案中,所述训练模块,还用于在对所述机器学习模型进行第一轮训练之前,向所述服务方设备发送模型参数获取请求;其中,所述模型参数获取请求包括模型标识;
接收所述服务方设备发送的与所述模型标识对应的模型参数;其中,所述模型参数是所述服务方设备进行初始化得到的。
本申请实施例提供一种机器学习模型的训练方法,应用于服务方设备,包括:
向每个训练方设备发送共享样本集;其中,所述共享样本集和所述训练方设备的私有样本集,用于供所述训练方设备训练所述训练方设备的机器学习模型;
接收每个所述训练方设备发送的预测值集合并进行融合处理,得到融合值集合;其中,所述预测值集合是所述训练方设备基于所述共享样本集调用训练后机器学习学习模型进行预测处理得到的,所述机器学习模型是基于所述共享样本集和所述训练方设备的私有样本集训练的;
向所述训练方设备发送所述融合值集合;其中,所述融合值集合用于供所述训练方设备更新所述共享样本集,并结合对所述训练方设备的私有样本集对所述机器学习模型进行下一轮的训练。
本申请实施例提供一种机器学习模型的训练装置,包括:
第二发送模块,用于向每个所述训练方设备发送共享样本集;其中,所述共享样本集和所述训练方设备的私有样本集,用于供所述训练方设备训练所述训练方设备的机器学习模型;
融合模块,用于接收每个所述训练方设备发送的预测值集合并进行融合处理,得到融合值集合;其中,所述预测值集合是所述训练方设备基于所述共享样本集调用训练后机器学习学习模型进行预测处理得到的,所述机器学习模型是基于所述共享样本集和所述训练方设备的私有样本集训练的;
第三发送模块,用于向所述训练方设备发送所述融合值集合;其中,所述融合值集合用于供所述训练方设备更新所述共享样本集,并结合对所述训练方设备的私有样本集对所述机器学习模型进行下一轮的训练。
上述方案中,所述融合模块,还用于针对所述共享样本集中每个共享样本执行以下处理:
从每个所述训练方设备发送的预测值集合中,获取所述共享样本对应的多个预测值;
对所述多个预测值求平均值,将所述平均值作为所述共享样本对应的融合值,所述融合值与所述共享样本存在一一对应的关系。
本申请实施例提供一种电子设备,包括:
存储器,用于存储可执行指令;
处理器,用于执行所述存储器中存储的可执行指令时,实现本申请实施例提供的机器学习模型的训练方法。
本申请实施例提供一种计算机可读存储介质,存储有可执行指令,用于引起处理器执行时,实现本申请实施例提供的机器学习模型的训练方法。
本申请实施例提供一种计算机程序产品,包括计算机程序,该计算机程序被处理器执行时实现本申请实施例提供的机器学习模型的训练方法。
本申请实施例具有以下有益效果:
与相关技术中多个训练方训练全局机器学习模型的技术相比,本申请实施例中训练方设备结合本地私有样本集,以及从服务方设备下载得到的共享样本集训练自身对应的机器学习模型,使得训练方设备能够在本地训练符合自身样本特性的个性化机器学习模型;训练方设备通过从服务方设备获取的融合值集合更新已下载的共享样本集,并结合更新后的共享样本集对机器学习模型进行下一轮的训练,能够在不泄露数据隐私的前提下,充分利用多方样本数据进行模型训练,能够有效防止模型过拟合,并能减少模型训练耗时,提升模型训练效率。
附图说明
图1是本申请实施例提供的机器学习模型的训练系统的架构示意图;
图2A是本申请实施例提供的服务方设备的结构示意图;
图2B是本申请实施例提供的训练方设备的结构示意图;
图3是本申请实施例提供的机器学习模型的训练方法的流程示意图;
图4是本申请实施例提供的机器学习模型的训练方法的流程示意图;
图5是本申请实施例提供的确定融合值的流程示意图;
图6是本申请实施例提供的分布式学习系统结构示意图;
图7是本申请实施例提供的机器学习模型的训练方法的流程示意图;
图8是本申请实施例提供的联邦学习模型的结构示意图。
具体实施方式
为了使本申请的目的、技术方案和优点更加清楚,下面将结合附图对本申请作进一步地详细描述,所描述的实施例不应视为对本申请的限制,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其它实施例,都属于本申请保护的范围。
在以下的描述中,涉及到“一些实施例”,其描述了所有可能实施例的子集,但是可以理解,“一些实施例”可以是所有可能实施例的相同子集或不同子集,并且可以在不冲突的情况下相互结合。
如果申请文件中出现“第一/第二”的类似描述则增加以下的说明,在以下的描述中,所涉及的术语“第一\第二\第三”仅仅是是区别类似的对象,不代表针对对象的特定排序,可以理解地,“第一\第二\第三”在允许的情况下可以互换特定的顺序或先后次序,以使这里描述的本申请实施例能够以除了在这里图示或描述的以外的顺序实施。
除非另有定义,本文所使用的所有的技术和科学术语与属于本申请的技术领域的技术人员通常理解的含义相同。本文中所使用的术语只是为了描述本申请实施例的目的,不是旨在限制本申请。
对本申请实施例进行进一步详细说明之前,对本申请实施例中涉及的名词和术语进行说明,本申请实施例中涉及的名词和术语适用于如下的解释。
1)服务方,是分布式学习中的一种特殊的参与方,是一种管理者角色,负责与其他参与方之间的机器学习模型的同步,服务方用于训练机器学习模型的设备称为服务方设备,例如服务器。
2)训练方,是分布式学习中为基于自身存储的样本来训练模型,为服务方整合得到全局机器学习模型做出贡献的一方,训练方用于训练机器学习模型的设备称为训练方设备,例如服务器。
3)联邦学习,是一种分布式学习的机器学习框架,在保障数据交换时的信息安全、保护终端数据和个人数据隐私、保证合法合规的前提下,在多参与方的计算设备之间开展高效率的机器学习。
4)横向联邦学习也称为特征对齐的联邦学习(Feature-Aligned FederatedLearning),即横向联邦学习的参与者的数据特征是对齐的,适用于参与者的数据特征重叠较多,而样本标识(ID,Identity document)重叠较少的情况。纵向联邦学习也称为样本对齐的联邦学习(Sample-Aligned Federated Learning),即纵向联邦学习的参与者的样本是对齐的,适用于参与者样本ID重叠较多,而数据特征重叠较少的情况。
联邦学习能够保证训练数据在不出本地的前提下,联合多个参与方共同训练一个全局的共享模型。然而申请人在实际实施时发现,由于在很多实际应用场景中,每个训练方要解决的问题可能不相同,因此需要利用其它训练方的数据信息训练一个符合自身样本数据特点的个性化模型。
基于此,本申请实施例提供了一种机器学习模型的训练方法、装置、电子设备、计算机可读存储介质及计算机程序产品,能够满足各训练方设备训练个性化机器学习模型的需求,并且能够充分利用多方样本数据减少模型训练耗时,提升模型训练效率。
首先对本申请实施例提供的机器学习模型的训练系统进行说明,参见图1,图1是本申请实施例提供的机器学习模型的训练系统的架构示意图,在机器学习模型的训练系统100中,训练方设备400通过网络300连接服务方设备200,网络300可以是广域网或者局域网,又或者是二者的组合,使用无线链路实现数据传输。在一些实施例中,训练方设备400可以是终端设备(例如台式机电脑、笔记本电脑)、服务器、服务器集群或者分布式系统。服务方设备200可以是独立的物理服务器,也可以是多个物理服务器构成的服务器集群或者分布式系统,还可以是提供云服务、云数据库、云计算、云函数、云存储、网络服务、云通信、中间件服务、域名服务、安全服务、内容分发网络(CDN,Content Delivery Network)、以及大数据和人工智能平台等基础云计算服务的云服务器。网络300可以是广域网或者局域网,又或者是二者的组合。训练方设备400以及服务方设备200可以通过有线或无线通信方式进行直接或间接地连接,本申请实施例中不做限制。
示例性地,训练方设备400-1和400-2从服务方设备200获取共享样本集,并根据共享样本集以及各自的私有样本集,训练从服务方设备200下载的机器学习模型;然后调用训练后的机器学习模型对共享样本集进行预测处理,得到预测值集合;随后向服务方设备200发送预测值集合;服务方设备200对训练方设备400-1和400-2发送的预测值集合进行融合处理,得到融合值集合,并向训练方设备400-1和400-2发送融合值集合,训练方设备400-1和400-2接收到服务方设备200发送的融合值集合后,根据融合值集合更新共享样本集,并根据更新后的共享样本集与私有样本集对机器学习模型进行下一轮的训练。
接下来说明服务方设备200的结构。参见图2A,图2A是本申请实施例提供的服务方设备的结构示意图,图2A所示的服务方设备200包括:至少一个处理器210、存储器230和至少一个网络接口220。服务方设备200中的各个组件通过总线系统240耦合在一起。可理解,总线系统240用于实现这些组件之间的连接通信。总线系统240除包括数据总线之外,还包括电源总线、控制总线和状态信号总线。但是为了清楚说明起见,在图2A中将各种总线都标为总线系统240。
处理器210可以是一种集成电路芯片,具有信号的处理能力,例如通用处理器、数字信号处理器(DSP,Digital Signal Processor),或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件等,其中,通用处理器可以是微处理器或者任何常规的处理器等。
存储器230包括易失性存储器或非易失性存储器,也可包括易失性和非易失性存储器两者。非易失性存储器可以是只读存储器(ROM,Read Only Memory),易失性存储器可以是随机存取存储器(RAM,Random Access Memory)。本申请实施例描述的存储器230旨在包括任意适合类型的存储器。
存储器230可以是可移除的,不可移除的或其组合。示例性的硬件设备包括固态存储器,硬盘驱动器,光盘驱动器等。存储器230可选地包括在物理位置上远离处理器210的一个或多个存储设备。
存储器230包括易失性存储器或非易失性存储器,也可包括易失性和非易失性存储器两者。非易失性存储器可以是只读存储器(ROM,Read Only Memory),易失性存储器可以是随机存取存储器(RAM,Random Access Memory)。本申请实施例描述的存储器230旨在包括任意适合类型的存储器。
在一些实施例中,存储器230能够存储数据以支持各种操作,这些数据的示例包括程序、模块和数据结构或者其子集或超集,下面示例性说明。
操作系统231,包括用于处理各种基本系统服务和执行硬件相关任务的系统程序,例如框架层、核心库层、驱动层等,用于实现各种基础业务以及处理基于硬件的任务;
网络通信模块232,用于经由一个或多个(有线或无线)网络接口220到达其他计算设备,示例性的网络接口220包括:蓝牙、无线相容性认证(WiFi)、和通用串行总线(USB,Universal Serial Bus)等;
在一些实施例中,本申请实施例提供的机器学习模型的训练装置可以在服务方设备200中采用软件方式实现,图2A示出了存储在存储器230中的机器学习模型的训练装置233,其可以是计算机程序和插件等形式的软件。机器学习模型训练装置233包括以下软件模块:第二发送模块2331、融合模块2332和第三发送模块2333。这些模块是可以是逻辑功能模块,因此根据所实现的功能可以进行任意的组合或进一步拆分。将在下文中说明各个模块的功能。
参见图2B,图2B是本申请实施例提供的训练方设备的结构示意图,训练方设备在图1中示例性示出了400-1和400-2,图2B所示的训练方设备400包括:至少一个处理器410、存储器430和至少一个网络接口420。训练方设备400中的各个组件通过总线系统440耦合在一起。可理解,总线系统440用于实现这些组件之间的连接通信。
训练方设备400中的处理器410、网络接口420以及总线系统440和服务方设备200中的处理器210、网络接口220以及总线系统440的实现方式相似,在此处将不再进行赘述。
在一些实施例中,本申请实施例提供的机器学习模型的训练装置可以在训练方设备400中采用软件方式实现,图2B示出了存储在存储器430中的机器学习模型的训练装置433,其可以是计算机程序和插件等形式的软件。机器学习模型的训练装置433包括以下软件模块:训练模块4331、预测模块4332、第一发送模块4333和接收模块4334,这些模块是逻辑上的,因此根据所实现的功能可以进行任意的组合或进一步拆分,将在下文中说明各个模块的功能。
在另一些实施例中,本申请实施例提供的机器学习模型的训练装置可以采用硬件方式实现,作为示例,本申请实施例提供的机器学习模型的训练装置可以是采用硬件译码处理器形式的处理器,其被编程以执行本申请实施例提供的机器学习模型的训练方法,例如,硬件译码处理器形式的处理器可以采用一个或多个应用专用集成电路(ASIC,Application Specific Integrated Circuit)、DSP、可编程逻辑器件(PLD,ProgrammableLogic Device)、复杂可编程逻辑器件(CPLD,Complex Programmable Logic Device)、现场可编程门阵列(FPGA,Field-Programmable Gate Array)或其他电子元件。
将结合本申请实施例提供的训练方设备的示例性应用和实施,说明本申请实施例提供的机器学习模型的训练方法。参见图3,图3是本申请实施例提供的机器学习模型的训练方法的流程示意图,将结合图3示出的步骤对模型训练的任意轮迭代过程进行说明。
在步骤101中,训练方设备基于共享样本集以及训练方设备的私有样本集,训练训练方设备的机器学习模型。
在一些实施例中,训练方设备对符合自身样本数据的机器学习模型进行训练之前,可以先从服务方设备中下载初始机器学习模型。需要说明的是,当训练方设备的数量为多个时,各训练方设备可以共同训练一个全局机器学习模型,即各训练方设备均从服务方设备中下载一个全局机器学习模型到本地进行模型训练。另外,各训练方设备也可以各自训练符合自身样本特性的个性化的机器学习模型,即各训练方设备训练的机器学习模型不同,如金融机构中预测用户信用评分的风控模型,联合建模的各金融机构可以采用不同形式(如神经网络模型、决策树模型、随机森模型)的风控模型来预测用户的信用评分;在如,医学领域预测疾病的死亡率的诊断模型,联合建模的各诊疗机构同样可以使用不同形式的诊断模型进行死亡率的预测。需要说明的是,个性化的机器学习模型输入的样本信息包含的特征属性是一样的,输出的模型预测值的含义也是一样的。
示例性地,以金融机构中确定用户的贷款风险等级的风控模型为例,通过使用多个训练方的样本数据来以联邦学习的方式训练风控模型,风控模型是机器学习模型的一种,可以基于用户数据(用户个人信息、银行账号流水数据以及征信记录等)预测用户的信用评分、贷款风险等级。以风控模型为多分类模型为例,风控模型输入信息是训练方存储的样本数据,样本数据可以是用于风控模型的用户数据,记载了脱敏后的用户信息。例如,各训练方设备存储的样本数据可以包括用户的ID、年龄、性别、职业、收入信息、交易记录、银行账号的流水数据,以及征信信息等。风控模型的输出信息是用户的信用评分、贷款风险等级等。例如,贷款风险等级可以分为{A、B、C、D、E}五种等级。对于用于预测用户对应的贷款风险等级的风控模型的实现方式可以有多种,例如风控模型为神经网络模型、树模型或者随机森林模型等。在实际应用中,各训练方设备可以利用本地私有样本集训练一个全局风控模型,即各训练方设备训练的风控模型的类型是一样的。同时,由于各训练方设备可能需要解决的实际问题不同,因此需要的机器学习模型也不一样,举例来说,训练方设备ClientA预测用户的贷款风险等级时,根据自身样本数据的特点以及实际的计算能力等,使用的风控模型是线性分类模型;而与ClientA进行联合训练的训练方设备ClientB根据自身情况,预测用户的贷款风险等级时,可以通过树模型实现。
在一些实施例中,各训练方设备在对机器学习模型进行第一轮训练之前,可以通过以下方式从服务方设备获取对应的机器学习模型:训练方设备在对机器学习模型进行第一轮训练之前,向服务方设备发送模型获取请求;其中,模型获取请求包括模型标识;接收服务方设备发送的与模型标识对应的初始机器学习模型,将初始机器学习模型作为待训练的机器学习模型。
在实际实施时,各个训练方设备的模型,可以全部相同、部分相同或者完全不同。当各训练方设备共同训练同一个全局机器学习模型时,可以直接向服务方设备发送模型获取请求,此时的模型获取请求可以不携带模型标识。当各训练方设备各自训练符合自身条件的机器学习模型时,可以向服务方设备发送携带有模型标识的模型获取请求,同时,为了确定相应的训练方设备,模型获取请求中还可以携带设备标识,以供服务方设备解析模型获取请求,得到模型标识以及设备标识,并向设备标识对应的训练方设备发送与模型标识匹配的机器学习模型。训练方设备将从服务器获取的初始机器学习模型作为自身待训练的机器学习模型。如此,训练方设备能够训练符合自身个性化的机器学习模型。
在一些实施例中,各训练方设备在获取得到机器学习模型后,对机器学习模型进行第一轮训练之前,可以通过随机设置的方式对机器学习模型的模型参数进行初始化;另外,还可以通过以下方式从服务方设备获取机器学习模型对应的模型参数对机器学习模型进行初始化:训练方设备向服务方设备发送模型参数获取请求,模型参数获取请求包括模型标识;接收服务方设备发送的与模型标识对应的模型参数,模型参数是服务方设备进行初始化得到的。如此,训练方设备在对本地机器学习模型进行初始化时,不使用随机设置的模型参数,而是使用服务方设备经过共享样本集进行预设轮次的训练后的模型参数,能够有效减少训练方设备的计算量以及与服务方设备之间的通信量。
在一些实施例中,各训练方设备在获取得到机器学习模型后,在结合共享样本集以及本地私有样本集对机器学习模型进行第一轮训练之前,还需要从服务方设备下载共享样本集。可以通过以下方式从服务方设备获取共享样本集:训练方设备向服务方设备发送样本获取请求;接收服务方设备响应于样本获取请求而发送的共享样本集,共享样本集包括每个训练方设备的私有样本集中的部分样本。
在实际实施时,训练方设备与服务方设备之间可以建立互信机制,当训练方设备通过服务方设备的授权且被服务方设备验证成功时,训练方设备可以向服务方设备发送样本获取请求,并能接收到服务方设备响应样本获取请求所发送的共享样本集。如此,各训练方设备通过从服务方设备中下载共享样本集,能够提升训练样本集中样本的数量,并能够增加样本标签的均匀分布。
示例性地,以确定用户贷款风险等级的风控模型为例,贷款风险等级被划分为{A、B、C、D、E}五个等级,针对训练方设备ClientA而言,本地私有样本集中样本标签主要集中在{A、B、C}等级,针对训练方设备ClientB而言,本地私有样本集中样本标签主要集中在{C、D、E}等级,针对训练方设备ClientC而言,本地私有样本集中样本标签主要集中在{A、C、D}等级。因此,在训练方设备ClientA进行模型训练时,为了使得本地训练样本的样本标签能够覆盖{A、B、C、D、E}五个等级,进而提升机器学习模型的模型效果,可以从服务方设备获取共享样本集,用于补充本地私有样本集缺少的或者数据量少的样本标签为{D、E}的样本数据。相应地,训练方设备ClientB进行模型训练时,需要从服务方设备获取共享样本集,用于补充本地私有样本集缺少的或者数据量少的样本标签为{A、B}的样本数据;训练方设备ClientC进行模型训练时,需要从服务方设备获取共享样本集,用于补充本地私有样本集缺少的或者数据量少的样本标签为{B、E}的样本数据。
继续对共享样本集进行说明,在一些实施例中,各训练方设备可以根据共同约定提取一部分本地样本上传至服务方设备,训练方设备可以通过以下方式实现样本上传:训练方设备接收服务方设备发送的样本上传请求;从训练方设备的私有样本集中选取部分样本,并向服务方设备发送部分样本。
在实际实施时,各训练方设备在接收到样本上传请求后,向服务方设备上传私有样本集中的部分样本,需要说明的是,训练方设备是在服务方设备授权后,才能向服务方设备上传样本,如此,服务方设备可以根据训练方设备的授权信息,向训练方设备发送共享样本集,通过这种授权的方式能够有效保证共享样本集的数据隐私性和安全性。
在一些实施例中,各训练方设备可以通过以下方式提取部分本地样本:基于样本上传比例,训练方设备从本地私有样本集中选取与样本上传比例对应的样本;或者,从训练方设备的私有样本集中选取目标数量的样本。
在实际实施时,训练方设备可以根据预先设定的上传比例从私有样本集中选取与样本比例对应的样本。另外,还可以根据样本上传请求中携带的目标数量上传本地样本。
示例性地,当训练方设备包含的样本数据量大且样本标签信息分布均匀时,可以上传数量较多的样本;当训练方设备包含的样本数据量大且样本标签信息分布不均匀时,可以上传其他训练方设备缺少或数量相对少的样本标签对应的样本。需要说明的是,在一些实施例中,可以由一个训练方设备上传所有的共享样本集,或者由服务器从其它不参与本轮训练的客户端获取。
在一些实施例中,训练方设备在从服务方设备获取得到对应的待训练机器学习模型以及共享样本集后,结合本地私有样本集对待训练机器学习模型进行训练,得到训练后的机器学习模型。如此,能够满足各训练方设备在不泄露用户数据隐私的前提下,利用其它训练方设备的数据提升自身的模型效果。
在步骤102中,训练方设备基于共享样本集,调用训练后的机器学习模型进行预测处理,得到预测值集合。
在一些实施例中,训练方设备经过步骤101得到训练后的机器学习模型。训练方设备调用训练得到的机器学习模型,对共享样本集进行预测,得到共享样本集中各共享样本对应的预测值,作为预测值集合。
在步骤103中,训练方设备向服务方设备发送预测值集合,预测值集合用于供服务方设备结合其他训练方设备发送的预测值集合进行融合处理,得到融合值集合。
在一些实施例中,训练方设备向服务方设备发送共享样本集对应的预测值集合,以供服务方设备结合其他训练方设备发送的共享样本集对应的预测值结合进行融合处理,从而得到融合值集合。
在实际实施时,为了使得共享样本对应的标签数据更接近真实值,可以通过至少两个训练方设备确定每条共享样本对应的至少两个预测值,以使服务方设备根据每条共享样本对应的至少两个预测值,确定该条共享样本对应的融合值。如此,可以进一步的依靠服务方的计算能力来快速处理,降低了训练方设备的计算量,提升了训练效率。并且,由于训练参与方只传输了共享样本集对应的预测值集合,而未传输样本数据,进一步保证了机器学习模型的安全性。
示例性地,存在三个训练方设备ClientA、ClientB以及ClientC,以及包含1000条共享样本的共享样本集,ClientA、ClientB以及ClientC分别调用自身训练后的机器学习模型,对这1000条共享样本进行预测,则针对每条共享样本得到3个预测值(ClientA得到的预测值V1、ClientB得到的预测值V2、ClientC得到的预测值V3),ClientA、ClientB以及ClientC分别将每条共享样本对应的预测值上传到服务方设备,即服务方设备中针对每条共享样本都对应有三个预测值。
在步骤104中,训练方设备接收服务方设备发送的融合值集合,并根据融合值集合更新共享样本集;其中,更新后的共享样本集与各训练方的私有样本集用于供训练方设备对机器学习模型进行下一轮的训练。
在一些实施例中,训练方设备根据获取的融合值集合更新本地共享样本集,可以通过以下方式实现本地共享样本集的更新:训练方设备针对共享样本集中每个共享样本执行以下处理:获取融合值集合与共享样本对应的融合值,基于融合值替换共享样本的标签。
在实际实施时,训练方设备从服务方设备获取共享样本集对应的融合值集合,并使用每条共享样本的融合值替换共享样本的标签信息,得到更新后的共享样本集,训练方设备基于更新后的共享样本集以及本地私有样本集对相应的机器学习模型进行下一轮的训练。
承接上例,训练方设备ClientA从服务方设备得到针对1000条共享样本的融合值集合
Figure BDA0003358959530000131
并使用这个融合值集合替换1000条共享样本的标签信息,即将共享样本集{[X1,Y1]、[X2,Y2]、......、[X1000,Y1000]}更新成共享样本集
Figure BDA0003358959530000132
训练方设备ClientA再次根据更新后的共享样本集以及私有样本集对机器学习模型进行下一轮训练。
需要说明的是,训练方设备使用存储的私有样本集以及共享样本集训练相应的机器学习模型后,如果机器学习模型不满足收敛条件,可以继续重复上述步骤101至步骤104,直至训练方设备对应的机器学习模型满足收敛条件,其中,收敛条件可以是模型收敛,或迭代训练次数到达预设迭代训练次数。
将结合本申请实施例提供的服务方设备的示例性应用和实施,说明本申请实施例提供的机器学习模型的训练方法。参见图4,图4是本申请实施例提供的机器学习模型的训练方法的流程示意图,将结合图4示出的步骤进行说明。
在步骤201中,服务方设备向每个训练方设备发送共享样本集;共享样本集和训练方设备的私有样本集,用于供训练方设备训练训练方设备的机器学习模型。
在一些实施例中,服务方设备在接收到训练方设备的样本获取请求后,会向相应的训练方设备发送共享样本集,以供训练方设备根据共享样本集以及本地的私有样本集,训练从服务器方下载的待训练的机器学习模型。
在实际实施时,服务方存储有共享样本集D0={[X1,Y1]、[X2,Y2]、......、[XN,YN]},D0中包含有N(N≥1且N为整数)条共享样本,其中,这N条共享样本是由训练方设备上传的。需要说明的是,各训练方设备上传的样本包含的特征数据以及标签信息是完全相同的类型或者是具有相同含义的。另外,在实际应用中,这N条共享样本可以由至少两个训练方设备提供;也可以是由一个训练方设备提供的,若只有一个训练方作为共享样本的提供方时,该训练方具有的样本数据通常具有数据量大且样本特征数据以及标签数据分布均匀的特点。样本特征数据以及标签数据分布均匀可以理解成对标签的不同含义能够全覆盖。
示例性地,以风控模型对用户的信用进行等级划分的多分类模型为例,用户的信用等级就是样本对应的标签数据,用户的信用等级被划分为{A、B、C、D、E}五个等级,则共享样本集中需要包括这五个等级对应的样本,若以一个训练方作为共享样本集的提供方,该训练方中应该包含这五个等级对应的所有的样本信息。
在步骤202中,服务方设备结合每个训练方设备发送的预测值集合进行融合处理,得到融合值集合;预测值集合是训练方设备基于共享样本集调用训练后机器学习学习模型进行预测处理得到的,机器学习模型是基于共享样本集和训练方设备的私有样本集训练的。
在一些实施例中,服务方设备将共享样本集发送至训练方设备后,在训练方设备对本地的机器学习模型进行一定轮次的模型训练后,会接收到训练方设备上传的共享样本集对应的预测值集合,然后,服务方设备对预测值集合进行融合处理,得到融合值集合。参见图5,图5是本申请实施例提供的确定融合值的流程示意图,基于图4,步骤202中针对共享样本集中每个共享样本,服务方设备可以执行步骤2021至步骤2022获取每个共享样本对应的融合值,结合图5示出的步骤进行说明。
步骤2021,服务方设备从每个训练方设备发送的预测值集合中,获取共享样本对应的多个预测值。
在实际实施时,针对共享样本对应的预测值的融合方式可以有多种。可以是针对训练相同模型的训练方上传的预测值集合进行融合;还可以是针对全部训练方上传的预测集合进行融合;同样也可以是针对彼此设为信任方的训练方上传的预测集合进行融合。
步骤2022,服务方设备对多个预测值求平均值,将平均值作为共享样本对应的融合值,融合值与共享样本存在一一对应的关系。
在实际实施时,针对N条共享样本对应的包含N个预测值的预测值集合,得到包含N个融合值的融合值集合,可记作
Figure BDA0003358959530000151
需要说明的是,在一些实施例中,服务方设备接收到训练方设备上传的共享样本集的预测值集合后,可以结合自身存储的共享样本对应的标签值,对预测值集合进行融合处理。
示例性地,服务方设备存储的共享样本集D0={[X1,Y1]、[X2,Y2]、......、[XN,YN]},训练方设备ClientA上传的针对共享样本集的预测值集合为
Figure BDA0003358959530000152
训练方设备ClientB上传的针对共享样本集的预测值集合为
Figure BDA0003358959530000153
训练方设备ClientC上传的针对共享样本集的预测值集合为
Figure BDA0003358959530000154
针对共享样本集D0的预测值集合为
Figure BDA0003358959530000155
在另一些实施例中,服务方设备还可以获取共享样本集中各共享样本对应的至少两个预测值,并获取至少两个预测值的中位数,将中位数作为当前共享样本的融合值;还可以使获取至少两个预测值的方差,将方差作为当前共享样本的融合值。需要说明的是,各共享样本对应的预测值的数量与进行模型训练的训练方设备的数量有关。
承接上例,共享样本集D0中包含1000条共享样本,服务方设备将这10000条共享样本发送到三个训练方设备{ClientA、ClientB、ClientC},每个训练方设备调用相应的机器学习模型,得到这1000条共享样本对应的1000个预测值。由于存在三个训练方设备,则服务方设备会接收到3000个预测值,其中,针对每条共享样本对应3个预测值,服务方设备针对每条共享样本对应的3个预测值求均值、中位数或方差等,可以得到每个共享样本对应的融合值。
需要说明的是,各训练方设备对应的机器学习模型类型可以相同,也可以不同,由于每个训练方设备的私有样本集不同,因此,根据私有样本集以及共享样本集训练得到的针对同样模型结构的机器学习模型的模型参数不同。
在步骤203中,服务方设备向训练方设备发送融合值集合;融合值集合用于供训练方设备更新共享样本集,并结合训练方设备的私有样本集对机器学习模型进行下一轮的训练。
在一些实施例中,服务方设备向各训练方设备发送共享样本集对应的融合值集合,以使训练方设备根据融合值集合更新本地对应的共享样本集对应的样本标签,得到新的共享样本集。
承接上例,以融合值是均值为例,对应的融合值集合可记作
Figure BDA0003358959530000161
服务方设备分别向三个训练方设备发送1000个融合值,以供训练方设备根据接收到的融合值集合,更新共享样本集中的样本标签。
训练方设备通过步骤101至步骤104实现机器学习模型的训练过程,服务方设备通过步骤201至步骤203实现针对共享样本集的融合值集合获取操作,并协助训练方设备完成对机器学习模型的训练过程。
以确定用户贷款风险等级的风控模型为例,风控模型是多分类模型,参见图6,图6是本申请实施例提供的分布式学习系统结构示意图,图中整个分布式学习(联邦学习)架构中包含三个训练方设备(训练方1、训练方2、训练方3),且每个训练方设备使用的风控模型的结构不同,分别对应模型M1、模型M2以及模型M3。其中,模型M1、模型M2以及模型M3虽然模型结构不同,但输入的样本具有相同的特征,只是特征对应的实际数据不同,并且模型输出的预测值的含义相同。
设服务方设备存储的共享训练集包含5条样本D0=(A1、A2、A3、A4、A5),其中,每条记录Ai(i≥1)都包含特征数据X(x1,x2,……,xp)和标签数据Y,而不是单独的值,所以D0的完整写法是D0={A1=(X1,Y1)、A2=(X2,Y2)、A3=(X3,Y3)、A4=(X4,Y4)、A5=(X5,Y5)}。训练方1的本地私有样本集D1有三条样本:D1={A1-1=(X1-1,Y1-1)、A1-2=(X1-2,Y1-2)、A1-3=(X1-3,Y1-3)};训练方2的本地私有样本集D2有四条样本:D2={A2-1=(X2-1,Y2-1)、A2-2=(X2-2,Y2-2)、A2-3=(X2-3,Y2-3)、A2-4=(X2-4,,Y2-4)};训练方3的本地私有样本集D3有两条样本:D3={A3-1=(X3-1,Y3-1)、A3-2=(X3-2,Y3-2)}。
训练方1对模型1进行第q(q≥1且q为整数)轮训练过程如下:训练方1将D0和D1都作为训练样本对模型M1进行一定轮次的训练,模型M1经过训练后,模型M1变为M1-1。然后训练方1将共享数据集D0的特征数据X=(X1,X2,X3,X4,X5)输入至模型M1-1中,得到对应共享样本集的5个预测值{Y1-1、Y1-2、Y1-3、Y1-4、Y1-5},并将这5个预测值上传至服务方设备。同时,训练方2执行与训练方1同样的操作,根据D0和D2对模型M2进行训练,得到模型M2-1,通过模型M2-1,得到对应D0的5个预测值{Y2-1、Y2-2、Y2-3、Y2-4、Y2-5},并将这5个预测值上传至服务方设备。另外,训练方3执行与训练方1(或训练方2)同样的操作,得到对应D0的5个预测值{Y3-1、Y3-2、Y3-3、Y3-4、Y3-5},并将这5个预测值上传至服务方设备。服务方设备对{Y1-1、Y1-2、Y1-3、Y1-4、Y1-5}、{Y2-1、Y2-2、Y2-3、Y2-4、Y2-5}、{Y3-1、Y3-2、Y3-3、Y3-4、Y3-5}共15个数据求平均,得到对应共享样本集{A1、A2、A3、A4、A5}每条共享样本的均值,{(Y1-1+Y2-1+Y3-1)/3,(Y1-2+Y2-2+Y3-2)/3,(Y1-3+Y2-3+Y3-3)/3,(Y1-4+Y2-4+Y3-4)/3,(Y1-5+Y2-5+Y3-5)/3},将这5个均值称为融合值,服务方设备将这5个融合值分发至各训练方设备,训练方设备使用融合值替换原来共享数据集D0的标签,也就是训练方1中的共享样本集D0变为:D'0={A1=(X1,(Y1-1+Y2-1+Y3-1)/3)、A2=(X2,(Y1-2+Y2-2+Y3-2)/3)、A3=(X3,(Y1-3+Y2-3+Y3-3)/3)、A4=(X4,(Y1-4+Y2-4+Y3-4)/3)、A5=(X5,(Y1-5+Y2-5+Y3-5)/3)},训练方1对模型M1的第q轮训练结束。训练方1、2、3分别对各自的模型M1、M2、M3重复上面的训练过程,直至模型收敛。需要说明的是,训练方1、2、3对各自的模型进行第q+1轮训练时,使用的共享数据集由D0变为D'0,后续训练过程与第q轮相似。
本申请实施例中训练方设备通过结合私有样本集以及共享样本集对机器学习模型进行训练,如此,能够定制并训练符合自身实际情况的个性化的机器学习模型;另外,训练方设备使用共享样本集对应的融合值集合更新,能够在不泄露用户数据隐私的前提下,充分利用其它训练方的数据信息提升自身的模型效果。
接下来以横向联邦学习的应用场景为例,对本发明实施例提供的机器学习模型的训练方法进行说明。在横向联邦学习的场景下,可以包括服务方设备以及至少两个训练方设备。参见图7,图7为本申请实施例提供的机器学习模型的训练方法的流程示意图,本申请实施例提供的机器学习模型的训练方法由服务方设备和训练方设备协同实施,将结合图7示出的步骤进行说明。
步骤401,服务方设备向各训练方设备发送样本上传请求。
在实际实施时,服务方设备为了存储共享样本集,可以向经过授权且验证成功的训练方设备发送样本上传请求。在横向联邦学习的应用场景下,多个训练方设备作为模型训练数据的拥有方,所拥有的数据集中用户重叠相对少而用户特征重叠相对较多,比如多个训练方设备可以为多家不同的地区银行,它们的用户来自各自所在的地区(即样本不同),但业务相同(即特征相同)。在办理相同业务时,多个训练方设备可以根据私有样本的数据量、实际网络情况以及自身设备运算能力,采用不同结构的机器学习模型。同时,为了保证参与模型训练的样本分布的均衡性,防止模型过拟合,联合建模的多个训练方设备可以在与服务方设备建立互信机制的情况下,在接收到服务方设备发送的样本上传请求时,从本地私有样本中选取部分样本存储到服务方设备,组成共享样本集,供各训练方下载使用。
示例性地,以两家不同地区的银行作为训练方,并根据自身存储的用户特征数据,各自训练用于确定用户的贷款风险等级的风控模型为例。其中,用户特征数据可以包括用户的基础信息、业务数据、行为数据等。基础信息可以包括用户的姓名、性别、年龄、学历、职业、手机号、身份证号、地址、收入数据等。业务数据可以包括贷款业务对应的贷款数据及还款数据等。用户特征数据中贷款风险等级可以包括a、b、c、d四个等级。可以理解的是,两家银行(后续称为训练方A、训练方B)可以训练一个符合自身实际情况的,用于确定用户贷款等级的多分类模型。如,训练方A适合使用采用线性分类模型MA确定用户贷款风险等级,MA模型结构可以为:y=a1x1+a2x2+a3x3+……+amxm,其中,{x1、x2、……、xm}为对应的m个特征,m≥1且m为整数,{a1、a2、……、am}为对应各特征的权重。训练方B使用非线性分类模型MB确定用户贷款风险等级,MB的模型结构可以为:
Figure BDA0003358959530000181
其中,{b1、b2、……、bm}为对应各特征的权重。由于训练方A以及训练B所拥有的用户特征数据,都不能全部覆盖a、b、c、d四个等级,为了防止模型过拟合,训练方A和训练方B在接收到服务方设备发送的样本上传请求后,可以将本地部分用户特征数据存储到服务方设备。
步骤402,各训练方设备接收到样本上传请求后,分别从本地私有样本集中选取目标数量的样本。
在实际实施时,接收到样本上传请求的训练方设备可以按照预先约定的规则,从本地私有样本集中选取目标数量的样本。例如,可以按照约定的上传比例,选取待上传的样本。
承接上例,训练方A和训练方B各自选取本地部分数据存储到服务方设备,选取的规则可以是,能够尽量覆盖a、b、c、d四个贷款风险等级的数据。
步骤403,各训练方设备向服务方设备发送目标数量的样本。
在实际实施时,各训练方设备将从本地选取的目标数量的样本作为共享样本发送至服务方设备。
通过步骤401至步骤403,各训练方设备可以在接收到服务方设备发送的样本上传请求后,根据预先约定的规则,分别上传本地私有训练样本集中的部分样本至服务方设备,进而得到共享样本集。
承接上例,服务方设备存储训练方A和训练方B各自选取本地部分数据作为共享样本集Ds,训练方A和训练方B都可以从下载共享样本集训练各自的模型。
步骤404,各训练方设备向服务方设备发送携带模型标识的模型获取请求。
在实际实施时,训练方设备在进行机器学习模型训练前,会先向服务方设备发送携带有模型标识的模型获取请求,以便从服务方设备下载与模型标识匹配的机器学习模型作为训练方设备的待训练的机器学习模型。
承接上例,训练方A和训练方B对应的模型是从服务方设备中下载的,训练方A向服务方设备发送携带有线性分类模型MA对应的模型标识的模型获取请求,训练方B向服务方设备发送携带有非线性分类模型MB对应的模型标识的模型获取请求。
步骤405,服务方设备接收并解析模型获取请求,得到模型标识。
在实际实施时,模型获取请求中还可以携带有发送该模型获取请求的训练方设备的设备标识,服务方设备解析模型获取请求,并根据得到的模型标识,向设备标识对应的训练方设备发送匹配的机器学习模型。
步骤406,服务方设备向各训练方设备发送与模型标识对应的机器学习模型。
承接上例,服务方设备向训练方A发送对应的线性分类模型MA,服务方设备向训练方B发送对应的非线性分类模型MB
通过步骤404至步骤406,各训练方设备在进行模型训练之前,可以通过向服务方设备发送模型获取请求,从服务方设备获取与各训练方设备对应的机器学习模型作为各训练方设备的待训练模型。
需要说明的是,步骤401至步骤403确定共享样本集的过程,与步骤404至步骤406获取待训练机器学习模型的过程,不存在严格意义的先后关系。
步骤407,训练方设备向服务方设备发送样本获取请求。
在实际实施时,训练方设备为了使用服务方设备存储的共享样本集,可以向服务方设备发送样本获取请求,其中,样本获取请求中可以携带训练方设备的设备标识。
承接上例,训练方A向服务方设备发送携带有训练方A的设备标识的样本获取请求,训练方B向服务方设备发送携带有训练方B的设备标识的样本获取请求。
步骤408,服务方设备响应于样本获取请求,向训练方设备发送的共享样本集。
在实际实施时,服务方设备解析样本获取请求,为了保证共享样本集的数据隐私和安全性,首先会通过样本获取请求中携带的设备标识,验证设备标识对应的训练方设备是否是授权且验证成功的设备,只有设备标识对应的训练方设备是授权且验证成功的设备,才会继续进行后续操作。
承接上例,服务方设备向训练方A和训练方B发送共享样本集Ds
通过步骤407以及步骤408,各训练方设备从服务方设备下载共享样本集至本地,能够有效增加本地用于模型训练的样本集的数量。
步骤409,各训练方设备基于共享样本集以及本地私有样本集,训练各自的机器学习模型。
在实际实施时,训练方设备根据共享样本集以及本地私有样本集,训练自身的机器学习模型,得到训练后的机器学习模型。
承接上例,训练方A使用共享样本集Ds以及本地数据训练线性分类模型MA,训练方B使用共享样本集Ds以及本地数据训练非线性分类模型MB
步骤410,各训练方设备调用训练后的机器学习模型,对共享样本集进行预测处理,得到共享样本集对应的预测值集合。
步骤411,各训练方设备向服务方设备发送各自确定的预测值集合。
步骤412,服务方设备对接收到的各训练方设备发送的预测值集合进行融合处理,得到融合值集合。
在实际实施时,服务方设备对接收到的预测值集合进行融合处理,常见的融合处理包括求均值、求中位数或者求方差等。
步骤413,服务方设备向各训练方设备发送融合值集合。
步骤414,各训练方设备接收服务方设备发送的融合值集合,并根据融合值集合更新共享样本集。
在实际实施时,训练方设备使用融合值更新共享样本的标签信息。
步骤415,各训练方设备根据更新后的共享样本集与本地私有样本集,对机器学习模型进行下一轮的训练。
通过步骤407至步骤414,各训练方设备对从服务方设备获取的机器学习模型进行一轮完整的模型训练。然后通过步骤415,根据更新后的共享样本集以及本地私有样本集重复执行步骤407至步骤414,直至本地机器学习模型满足收敛条件,得到训练完成的机器学习模型,其中,收敛条件可以是模型收敛,或迭代训练次数到达预设迭代训练次数限制。
本申请实施例中服务方设备存储有各训练方设备上传的共享样本集,能够有效减少模型训练时的通信量;同时各训练方设备仅需要上传共享样本集对应的预测值,不需要上传本地样本,有效保证了本地样本数据的隐私性以及安全性;另外,训练方设备通过使用融合值更新本地共享样本集,进而训练对应的机器学习模型,能够提高模型训练速率,加快模型收敛。
下面,将说明本申请实施例在一个实际的应用场景中的示例性应用。
联邦学习是一种基于数据隐私保护技术实现的分布式训练范式,它能保证训练数据在不出本地的前提下,联合多个参与方共同训练一个全局的共享模型。
经典的横向联邦学习训练方法存在的一个问题是,由于共享全局模型,因此每一个客户端都知道模型的结构,并且只能使用统一的模型结构。
但在很多场景中,每一个用户可能要解决的问题不相同,因此需要的模型也不一样。他们使用联邦学习的目的,并不是为了训练一个统一的全局模型,而是希望在不泄露用户数据隐私的前提下,利用其它参与方的数据信息提升自身的模型效果。
在实际实施时,参见图8,图8是本申请实施例提供的联邦学习模型的结构示意图,图中以一个服务端(也可称服务方)以及三个客户端(也可称训练方)为例,假设三个客户端都是分类模型,但它们使用的模型结构不一样。由于三个客户端的模型均不一样,因此,无法使用传统的联邦学习直接进行训练。
基于此,本申请实施例提供一种机器学习模型的训练方法,该方法可应用于联邦学习场景,来实现不同模型网络结构的协同训练。
步骤一:首先服务端准备一份共享数据集,设为D0,客户端i的本地数据集记为Di。其中,共享数据集D0的数据可来自于各客户端的本地数据,例如,所有客户端共同约定,提取一部分本地数据上传服务端生成数据集D0
步骤二:服务端将共享数据集D0发送给每个客户端,如此,每个客户端i既有共享的数据集D0,也有本地数据集Di
步骤三:每个客户端i利用D0和Di进行本地训练,设客户端i的本地模型为fi
步骤四:客户端i对本地模型fi训练后,将
Figure BDA0003358959530000221
代入fi中,得到:
Figure BDA0003358959530000222
步骤五:服务端接收到所有客户端上传的分数值
Figure BDA0003358959530000223
后,执行平均计算:
Figure BDA0003358959530000224
上述公式中的m值表示客户端的数量。这样对于任意的j,
Figure BDA0003358959530000225
值可称为共识(或者称为融合值)。将得到的
Figure BDA0003358959530000226
重新下发给每个客户端。
步骤六:服务端向每个客户端发送均值
Figure BDA0003358959530000227
每个客户端将接收到的
Figure BDA0003358959530000228
替换原来的D0的标签值,这样D0变更为:
Figure BDA0003358959530000229
将D0与Di重新作为客户端i的训练数据,重新训练本地模型fi。重复执行上面的步骤,直到模型收敛为止。
本申请实施例中服务端存储有各客户端上传的共享数据集,能够有效减少模型训练时的通信量;同时各客户端仅需要上传共享样本集对应的预测值,不需要上传本地样本,有效保证了本地样本数据的隐私性以及安全性;另外,客户端通过使用融合值更新本地共享样本集,进而训练对应的机器学习模型,能够提高模型训练速率,加快模型收敛。
下面结合图2A和图2B说明本申请实施例提供的服务方设备200和训练参与方设备(400-1和400-2)实施为软件模块的示例性结构。
在一些实施例中,如图2B所示,图2B示出的是是本申请实施例提供的训练方设备400的结构示意图,存储在存储器430的机器学习模型的训练装置433中的软件模块可以包括:
训练模块4331,用于基于共享样本集以及所述训练方设备的私有样本集,对所述训练方设备的机器学习模型进行训练,得到训练后的所述机器学习模型;
预测模块4332,用于接收各所述训练方设备发送的对应所述共享样本集的共享预测值;并对所述共享预测值进行融合处理,得到与各所述共享样本对应的融合值;
第一发送模块4333,用于向服务方设备发送所述预测值集合;其中,所述预测值集合用于供所述服务方设备结合其他训练方设备发送的预测值集合进行融合处理,得到融合值集合;
接收模块4334,用于接收所述服务方设备发送的所述融合值集合,并根据所述融合值集合更新所述共享样本集;其中,更新后的所述共享样本集与所述私有样本集,用于供所述训练方设备对所述机器学习模型进行下一轮的训练。
在一些实施例中,所述接收模块,还用于针对所述共享样本集中每个共享样本执行以下处理:获取所述融合值集合与所述共享样本对应的融合值,基于所述融合值替换所述共享样本的标签。
在一些实施例中,所述训练模块,在训练所述训练方设备的机器学习模型之前,还用于向所述服务方设备发送样本获取请求;接收所述服务方设备响应于所述样本获取请求而发送的所述共享样本集;其中,所述共享样本集包括每个所述训练方设备的私有样本集中的部分样本。
在一些实施例中,所述训练模块,在训练所述训练方设备的机器学习模型之前,还用于接收所述服务方设备发送的样本上传请求;从所述训练方设备的私有样本集中选取部分样本,并向所述服务方设备发送所述部分样本。
在一些实施例中,所述训练模块,还用于基于样本上传比例,从所述训练方设备的私有样本集中选取与所述样本上传比例对应的样本;或者,从所述训练方设备的所述私有样本集中选取目标数量的样本。
在一些实施例中,所述训练模块,还用于在对所述机器学习模型进行第一轮训练之前,向所述服务方设备发送模型获取请求;其中,所述模型获取请求包括模型标识;接收所述服务方设备发送的与所述模型标识对应的初始机器学习模型,将所述初始机器学习模型作为待训练的机器学习模型。
在一些实施例中,所述训练模块,还用于在对所述机器学习模型进行第一轮训练之前,向所述服务方设备发送模型参数获取请求;其中,所述模型参数获取请求包括模型标识;接收所述服务方设备发送的与所述模型标识对应的模型参数;其中,所述模型参数是所述服务方设备进行初始化得到的。
在一些实施例中,如图2A所示,图2A示出的是是本申请实施例提供的服务方设备200的结构示意图,存储在存储器230的机器学习模型的训练装置233中的软件模块可以包括:
第二发送模块2331,用于向每个所述训练方设备发送共享样本集;其中,所述共享样本集和所述训练方设备的私有样本集,用于供所述训练方设备训练所述训练方设备的机器学习模型;
融合模块2332,用于接收每个所述训练方设备发送的预测值集合并进行融合处理,得到融合值集合;其中,所述预测值集合是所述训练方设备基于所述共享样本集调用训练后机器学习学习模型进行预测处理得到的,所述机器学习模型是基于所述共享样本集和所述训练方设备的私有样本集训练的;
第三发送模块2333,用于向所述训练方设备发送所述融合值集合;其中,所述融合值集合用于供所述训练方设备更新所述共享样本集,并结合对所述训练方设备的私有样本集对所述机器学习模型进行下一轮的训练。
在一些实施例中,所述融合模块,还用于针对所述共享样本集中每个共享样本执行以下处理:从每个所述训练方设备发送的预测值集合中,获取所述共享样本对应的多个预测值;对所述多个预测值求平均值,将所述平均值作为所述共享样本对应的融合值,所述融合值与所述共享样本存在一一对应的关系。
需要说明的是,本申请实施例装置的描述,与上述方法实施例的描述是类似的,具有同方法实施例相似的有益效果,因此不做赘述。
本申请实施例提供了一种计算机程序产品,包括计算机程序,其特征在于,该计算机程序被处理器执行时实现本申请实施例提供的机器学习模型的训练方法。
本申请实施例提供一种存储有可执行指令的计算机可读存储介质,其中存储有可执行指令,当可执行指令被处理器执行时,将引起处理器执行本申请实施例提供的方法,例如,如图3示出的机器学习模型的训练方法。
在一些实施例中,计算机可读存储介质可以是FRAM、ROM、PROM、EPROM、EEPROM、闪存、磁表面存储器、光盘、或CD-ROM等存储器;也可以是包括上述存储器之一或任意组合的各种设备。
在一些实施例中,可执行指令可以采用程序、软件、软件模块、脚本或代码的形式,按任意形式的编程语言(包括编译或解释语言,或者声明性或过程性语言)来编写,并且其可按任意形式部署,包括被部署为独立的程序或者被部署为模块、组件、子例程或者适合在计算环境中使用的其它单元。
作为示例,可执行指令可以但不一定对应于文件系统中的文件,可以可被存储在保存其它程序或数据的文件的一部分,例如,存储在超文本标记语言(HTML,Hyper TextMarkup Language)文档中的一个或多个脚本中,存储在专用于所讨论的程序的单个文件中,或者,存储在多个协同文件(例如,存储一个或多个模块、子程序或代码部分的文件)中。
作为示例,可执行指令可被部署为在一个计算设备上执行,或者在位于一个地点的多个计算设备上执行,又或者,在分布在多个地点且通过通信网络互连的多个计算设备上执行。
综上所述,通过本申请实施例能够有效减少模型训练时的通信量;同时各训练方设备仅需要上传共享样本集对应的预测值,不需要上传本地样本,有效保证了本地样本数据的隐私性以及安全性;另外,训练方设备通过使用融合值更新本地共享样本集,进而训练对应的机器学习模型,能够提高模型训练速率,加快模型收敛。
以上所述,仅为本申请的实施例而已,并非用于限定本申请的保护范围。凡在本申请的精神和范围之内所作的任何修改、等同替换和改进等,均包含在本申请的保护范围之内。

Claims (14)

1.一种机器学习模型的训练方法,其特征在于,应用于训练方设备,所述方法包括:
基于共享样本集以及所述训练方设备的私有样本集,训练所述训练方设备的机器学习模型;
基于所述共享样本集,调用训练后的所述机器学习模型进行预测处理,得到预测值集合;
向服务方设备发送所述预测值集合;其中,所述预测值集合用于供所述服务方设备结合其他训练方设备发送的预测值集合进行融合处理,得到融合值集合;
接收所述服务方设备发送的所述融合值集合,并根据所述融合值集合更新所述共享样本集;其中,更新后的所述共享样本集与所述私有样本集,用于供所述训练方设备对所述机器学习模型进行下一轮的训练。
2.根据权利要求1所述的方法,其特征在于,所述根据所述融合值集合更新所述共享样本集,包括:
针对所述共享样本集中每个共享样本执行以下处理:
获取所述融合值集合与所述共享样本对应的融合值,基于所述融合值替换所述共享样本的标签。
3.根据权利要求1所述的方法,其特征在于,在训练所述训练方设备的机器学习模型之前,所述方法还包括:
向所述服务方设备发送样本获取请求;
接收所述服务方设备响应于所述样本获取请求而发送的所述共享样本集;
其中,所述共享样本集包括每个所述训练方设备的私有样本集中的部分样本。
4.根据权利要求1所述的方法,其特征在于,在训练所述训练方设备的机器学习模型之前,所述方法还包括:
接收所述服务方设备发送的样本上传请求;
从所述训练方设备的私有样本集中选取部分样本,并向所述服务方设备发送所述部分样本。
5.根据权利要求4所述的方法,其特征在于,所述从所述训练方设备的私有样本集中选取部分样本,包括:
基于样本上传比例,从所述训练方设备的私有样本集中选取与所述样本上传比例对应的样本;
或者,
从所述训练方设备的所述私有样本集中选取目标数量的样本。
6.根据权利要求1所述的方法,其特征在于,在训练所述训练方设备的机器学习模型之前,所述方法还包括:
向所述服务方设备发送模型获取请求;其中,所述模型获取请求包括模型标识;
接收所述服务方设备发送的与所述模型标识对应的初始机器学习模型,将所述初始机器学习模型作为待训练的机器学习模型。
7.根据权利要求1所述的方法,其特征在于,在训练所述训练方设备的机器学习模型之前,所述方法还包括:
向所述服务方设备发送模型参数获取请求;其中,所述模型参数获取请求包括模型标识;
接收所述服务方设备发送的与所述模型标识对应的模型参数;其中,所述模型参数是所述服务方设备进行初始化得到的。
8.一种机器学习模型的训练方法,其特征在于,应用于服务方设备,所述方法包括:
向每个训练方设备发送共享样本集;其中,所述共享样本集和所述训练方设备的私有样本集,用于供所述训练方设备训练所述训练方设备的机器学习模型;
接收每个所述训练方设备发送的预测值集合并进行融合处理,得到融合值集合;其中,所述预测值集合是所述训练方设备基于所述共享样本集调用训练后机器学习学习模型进行预测处理得到的,所述机器学习模型是基于所述共享样本集和所述训练方设备的私有样本集训练的;
向所述训练方设备发送所述融合值集合;其中,所述融合值集合用于供所述训练方设备更新所述共享样本集,并结合对所述训练方设备的私有样本集对所述机器学习模型进行下一轮的训练。
9.根据权利要求8所述的方法,其特征在于,所述接收每个所述训练方设备发送的预测值集合并进行融合处理,得到融合值集合,包括:
针对所述共享样本集中每个共享样本执行以下处理:
从每个所述训练方设备发送的预测值集合中,获取所述共享样本对应的多个预测值;
对所述多个预测值求平均值,将所述平均值作为所述共享样本对应的融合值,所述融合值与所述共享样本存在一一对应的关系。
10.一种机器学习模型的训练装置,其特征在于,包括:
训练模块,用于基于共享样本集以及所述训练方设备的私有样本集,对所述训练方设备的机器学习模型进行训练,得到训练后的所述机器学习模型;
预测模块,用于接收各所述训练方设备发送的对应所述共享样本集的共享预测值;并对所述共享预测值进行融合处理,得到与各所述共享样本对应的融合值;
第一发送模块,用于向服务方设备发送所述预测值集合;其中,所述预测值集合用于供所述服务方设备结合其他训练方设备发送的预测值集合进行融合处理,得到融合值集合;
接收模块,用于接收所述服务方设备发送的所述融合值集合,并根据所述融合值集合更新所述共享样本集;其中,更新后的所述共享样本集与所述私有样本集,用于供所述训练方设备对所述机器学习模型进行下一轮的训练。
11.一种机器学习模型的训练装置,其特征在于,包括:
第二发送模块,用于向每个所述训练方设备发送共享样本集;其中,所述共享样本集和所述训练方设备的私有样本集,用于供所述训练方设备训练所述训练方设备的机器学习模型;
融合模块,用于接收每个所述训练方设备发送的预测值集合并进行融合处理,得到融合值集合;其中,所述预测值集合是所述训练方设备基于所述共享样本集调用训练后机器学习学习模型进行预测处理得到的,所述机器学习模型是基于所述共享样本集和所述训练方设备的私有样本集训练的;
第三发送模块,用于向所述训练方设备发送所述融合值集合;其中,所述融合值集合用于供所述训练方设备更新所述共享样本集,并结合对所述训练方设备的私有样本集对所述机器学习模型进行下一轮的训练。
12.一种电子设备,其特征在于,包括:
存储器,用于存储可执行指令;
处理器,用于执行所述存储器中存储的可执行指令时,实现权利要求1至9任一项所述的机器学习模型的训练方法。
13.一种计算机可读存储介质,其特征在于,存储有可执行指令,用于被处理器执行时,实现权利要求1至9任一项所述的机器学习模型的训练方法。
14.一种计算机程序产品,包括计算机程序,其特征在于,该计算机程序被处理器执行时实现权利要求1至9任一项所述的机器学习模型的训练方法。
CN202111360276.7A 2021-11-17 2021-11-17 机器学习模型的训练方法、装置、电子设备及存储介质 Pending CN114021473A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202111360276.7A CN114021473A (zh) 2021-11-17 2021-11-17 机器学习模型的训练方法、装置、电子设备及存储介质

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202111360276.7A CN114021473A (zh) 2021-11-17 2021-11-17 机器学习模型的训练方法、装置、电子设备及存储介质

Publications (1)

Publication Number Publication Date
CN114021473A true CN114021473A (zh) 2022-02-08

Family

ID=80064826

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202111360276.7A Pending CN114021473A (zh) 2021-11-17 2021-11-17 机器学习模型的训练方法、装置、电子设备及存储介质

Country Status (1)

Country Link
CN (1) CN114021473A (zh)

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN117153312A (zh) * 2023-10-30 2023-12-01 神州医疗科技股份有限公司 基于模型平均算法的多中心临床试验方法及系统

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN117153312A (zh) * 2023-10-30 2023-12-01 神州医疗科技股份有限公司 基于模型平均算法的多中心临床试验方法及系统

Similar Documents

Publication Publication Date Title
Xu et al. Unleashing the power of edge-cloud generative ai in mobile networks: A survey of aigc services
CN110443375A (zh) 一种联邦学习方法及装置
CN112257873A (zh) 机器学习模型的训练方法、装置、系统、设备及存储介质
Luck et al. Agent technology: enabling next generation computing (a roadmap for agent based computing)
Pérez et al. Group decision making problems in a linguistic and dynamic context
Hayyolalam et al. Single‐objective service composition methods in cloud manufacturing systems: Recent techniques, classification, and future trends
CN112288097A (zh) 联邦学习数据处理方法、装置、计算机设备及存储介质
CN113222175B (zh) 信息处理方法及系统
O’Hara The contradictions of digital modernity
US20210067497A1 (en) System and method for matching dynamically validated network data
CN112906864A (zh) 信息处理方法、装置、设备、存储介质及计算机程序产品
CN114021473A (zh) 机器学习模型的训练方法、装置、电子设备及存储介质
Kim R-learning-based team game model for Internet of things quality-of-service control scheme
Bala et al. A novel game theory based reliable proof‐of‐stake consensus mechanism for blockchain
CN112418929A (zh) 一种数据共享方法及装置
Gershon Intelligent networks and international business communication: A systems theory interpretation
CN114611015B (zh) 交互信息处理方法、装置和云服务器
CN116471092A (zh) 一种基于模块化区块链的自适应元宇宙系统
CN114153491A (zh) 应用程序接口的编排方法、装置、设备以及存储介质
US10984061B2 (en) Systems and methods for providing communications to and from verified entities
Fernandez et al. Collaborative, distributed simulations of agri-food supply chains. Analysis on how linking theory and practice by using multi-agent structures
US10360498B2 (en) Unsupervised training sets for content classification
Yang An Agent-Based Simulation of Heterogeneous Games and Social Systems in Politics, Fertility and Economic Development
Calcaterra et al. Future of Decentralization
KR102637198B1 (ko) 인공지능 모델 제작 플랫폼을 통한 인공지능 모델 공유, 대여 및 판매방법, 장치 및 컴퓨터프로그램

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination