CN117807555A - 模型预测方法、装置、电子设备和存储介质 - Google Patents
模型预测方法、装置、电子设备和存储介质 Download PDFInfo
- Publication number
- CN117807555A CN117807555A CN202311862724.2A CN202311862724A CN117807555A CN 117807555 A CN117807555 A CN 117807555A CN 202311862724 A CN202311862724 A CN 202311862724A CN 117807555 A CN117807555 A CN 117807555A
- Authority
- CN
- China
- Prior art keywords
- model
- machine learning
- models
- learning models
- consensus
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000000034 method Methods 0.000 title claims abstract description 59
- 238000003860 storage Methods 0.000 title claims abstract description 24
- 238000010801 machine learning Methods 0.000 claims abstract description 140
- 238000012549 training Methods 0.000 claims description 43
- 230000006870 function Effects 0.000 claims description 23
- 238000009826 distribution Methods 0.000 claims description 15
- 230000008569 process Effects 0.000 claims description 14
- 238000004422 calculation algorithm Methods 0.000 claims description 10
- 238000004364 calculation method Methods 0.000 claims description 8
- 238000005516 engineering process Methods 0.000 abstract description 7
- 238000013473 artificial intelligence Methods 0.000 abstract description 2
- 238000004590 computer program Methods 0.000 description 17
- 238000004891 communication Methods 0.000 description 13
- 238000010586 diagram Methods 0.000 description 12
- 230000015654 memory Effects 0.000 description 9
- 238000005457 optimization Methods 0.000 description 7
- 238000012545 processing Methods 0.000 description 6
- 238000013140 knowledge distillation Methods 0.000 description 5
- 241000282326 Felis catus Species 0.000 description 4
- 230000002776 aggregation Effects 0.000 description 4
- 238000004220 aggregation Methods 0.000 description 4
- 238000001514 detection method Methods 0.000 description 4
- 230000000694 effects Effects 0.000 description 4
- 238000005070 sampling Methods 0.000 description 4
- 230000005540 biological transmission Effects 0.000 description 3
- 230000003287 optical effect Effects 0.000 description 3
- 241001465754 Metazoa Species 0.000 description 2
- 230000002411 adverse Effects 0.000 description 2
- 238000001914 filtration Methods 0.000 description 2
- 238000012986 modification Methods 0.000 description 2
- 230000004048 modification Effects 0.000 description 2
- 238000012546 transfer Methods 0.000 description 2
- 238000012935 Averaging Methods 0.000 description 1
- 230000004931 aggregating effect Effects 0.000 description 1
- 238000004458 analytical method Methods 0.000 description 1
- 238000013475 authorization Methods 0.000 description 1
- 238000012512 characterization method Methods 0.000 description 1
- 238000013527 convolutional neural network Methods 0.000 description 1
- 238000013135 deep learning Methods 0.000 description 1
- 238000011161 development Methods 0.000 description 1
- 230000006872 improvement Effects 0.000 description 1
- 230000010354 integration Effects 0.000 description 1
- 230000002452 interceptive effect Effects 0.000 description 1
- 239000004973 liquid crystal related substance Substances 0.000 description 1
- 238000004519 manufacturing process Methods 0.000 description 1
- 230000007246 mechanism Effects 0.000 description 1
- 230000000116 mitigating effect Effects 0.000 description 1
- 238000004806 packaging method and process Methods 0.000 description 1
- 238000006116 polymerization reaction Methods 0.000 description 1
- 239000004065 semiconductor Substances 0.000 description 1
- 239000000758 substrate Substances 0.000 description 1
- 238000013526 transfer learning Methods 0.000 description 1
Landscapes
- Image Analysis (AREA)
Abstract
本公开提供了一种模型预测方法,涉及人工智能领域、金融科技领域或其他技术领域,该方法包括:将待处理的图像或文本输入N个机器学习模型;获得N个所述机器学习模型基于待处理的图像或文本输出的N个初始预测结果;基于N个所述机器学习模型各自的模型权重与初始预测结果,得到待处理的图像或文本的最终预测结果;其中,每个所述机器学习模型的模型权重通过该模型的共识质量得到,所述共识质量通过该模型对图像样本或文本样本的预测结果,及N个所述机器学习模型中输出相同预测结果的模型数量确定。本公开还提供了一种模型预测装置、电子设备和存储介质。
Description
技术领域
本公开涉及人工智能领域、金融科技领域或其他技术领域,更具体地,涉及一种模型预测方法、装置、电子设备和存储介质。
背景技术
联邦学习是一种在多台机器上实现的分布式机器学习技术,目标是在保护数据安全和隐私安全的基础上,实现多机构联合建模,提升机器学习模型的推理效果。联邦学习依赖于分布式设备之间的大量通信训练多个子模型得到集成模型,且每个分布式设备本地所训练得到的子模型依赖于本地样本集的质量。
相关技术中,面对联邦学习中遇到的通信瓶颈、推理较慢和数据需要满足强IID条件等问题,目前常用的解决方案是使用知识蒸馏,知识蒸馏通过迁移学习方式,将教师模型的推理结果作为监督信号,指导学生模型的训练,从而完成知识迁移过程。除了知识蒸馏以外,其他缓解方法还有模型集成、个性化学习等。
在实现本公开发明构思的过程中,发明人发现,相关技术针对联邦学习缺点的解决方案依然受到训练子模型的样本集质量的影响,导致集成模型的共识预测结果受到个别子模型的影响。因此,在保护数据安全和隐私安全的基础上,降低个别子模型对共识预测结果的负面影响,提高共识预测结果的准确性是当前亟待解决的问题。
发明内容
鉴于上述问题,本公开提供了模型预测方法、装置、电子设备和存储介质。
本公开实施例的一个方面,提供了一种模型预测方法,包括:将待处理的图像或文本输入N个机器学习模型,其中,N个所述机器学习模型基于服务器节点及K个客户端节点执行联邦学习算法预先训练得到,N、K皆为大于或等于2的整数;获得N个所述机器学习模型基于待处理的图像或文本输出的N个初始预测结果;基于N个所述机器学习模型各自的模型权重与初始预测结果,得到待处理的图像或文本的最终预测结果;其中,每个所述机器学习模型的模型权重通过该模型的共识质量得到,所述共识质量通过该模型对图像样本或文本样本的预测结果,及N个所述机器学习模型中输出相同预测结果的模型数量确定。
根据本公开的实施例,在得到待处理的图像或文本的最终预测结果之前,确定每个所述机器学习模型的模型权重包括:获得每个所述机器学习模型的共识质量;对于任一个所述机器学习模型,将该机器学习模型的的共识质量输入到归一化的指数函数,以基于计算结果得到该模型的模型权重。
根据本公开的实施例,获得每个所述机器学习模型的共识质量包括:确定N个所述机器学习模型对第一样本集的共识逻辑,所述共识逻辑基于每个所述机器学习模型对所述第一样本集的预测结果进行知识投票得到,所述第一样本集包括至少一个图像样本或至少一个文本样本;基于所述共识逻辑和所述模型数量得到所述共识质量。
根据本公开的实施例,基于计算结果得到模型权重包括:
按照如下归一化的指数函数计算第k个机器学习模型的模型权重
其中,Qk(S)代表共识质量,K≥k≥1,S表征第一样本集,Nk代表第k个客户端节点的数据量,第k个机器学习模型来自第k个客户端节点,LF代表逻辑焦点标识,用于指示为模型权重;
其中,描述了第k个客户端模型对集成模型全局共识的边际贡献。
根据本公开的实施例,在将待处理的图像或文本输入N个机器学习模型之前,还包括:使用在K个所述客户端节点本地训练得到的K个所述机器学习模型,对所述服务器节点存储的无标签第二样本集进行预测,得到K个预测结果,所述第二样本集包括至少一个图像样本或至少一个文本样本;使K个所述客户端节点将K个所述预测结果传输至所述服务器节点,其中,所述服务器节点被配置为基于K个所述预测结果训练第一机器学习模型,N个所述机器学习模型包括该第一机器学习模型。
根据本公开的实施例,所述服务器节点基于K个所述预测结果训练第一机器学习模型包括:基于K个所述预测结果中对所述无标签第二样本集中每个图像样本或每个文本样本的共识结果确定样本标签,得到有标签第二样本集,所述共识结果包括K个所述预测结果中占多数的结果;基于有标签的第二样本集训练所述第一机器学习模型。
根据本公开的实施例,在基于有标签的第二样本集训练所述第一机器学习模型之后,所述方法还包括:从K个所述机器学习模型和所述第一机器学习模型中确定M个待定模型,M为大于或等于1的整数,且M小于或等于K+1;对M个所述待定模型各自的模型参数进行拟合处理,得到M个第二机器学习模型,其中,每个所述第二机器学习模型的模型参数符合拟合处理所依据的数学分布,N个所述机器学习模型包括M个所述第二机器学习模型。
根据本公开的实施例,在将待处理的图像或文本输入N个机器学习模型之前,还包括:获得N个所述机器学习模型对第三样本集的共识预测结果,以及每个所述机器学习模型对所述第三样本集的预测结果,其中,所述第三样本集包括至少一个图像样本或至少一个文本样本;确定每个所述机器学习模型对所述第三样本集的预测结果与所述共识预测结果之间的偏离度;对于任一个所述机器学习模型,当其偏离度大于或等于预设阈值时,对该模型重新训练,直至其偏离度小于所述预设阈值。
本公开实施例的另一方面提供了一种模型预测装置,包括:数据输入模块,用于将待处理的图像或文本输入N个机器学习模型,其中,N个所述机器学习模型基于服务器节点及K个客户端节点执行联邦学习算法预先训练得到,K、N皆为大于或等于2的整数;预测输出模块,用于获得N个所述机器学习模型基于待处理的图像或文本输出的N个初始预测结果;最终预测模块,用于基于N个所述机器学习模型各自的模型权重与初始预测结果,得到待处理的图像或文本的最终预测结果;其中,每个所述机器学习模型的模型权重通过该模型的共识质量得到,所述共识质量通过该模型对图像样本或文本样本的预测结果,及N个所述机器学习模型中输出相同预测结果的模型数量确定。
本公开实施例的另一方面提供了一种电子设备,包括:一个或多个处理器;存储装置,用于存储一个或多个程序,其中,当所述一个或多个程序被所述一个或多个处理器执行时,使得一个或多个处理器执行如上所述的方法。
本公开实施例的另一方面还提供了一种计算机可读存储介质,其上存储有可执行指令,该指令被处理器执行时使处理器执行如上所述的方法。
本公开实施例的另一方面还提供了一种计算机程序产品,包括计算机程序,该计算机程序被处理器执行时实现如上所述的方法。
上述一个或多个实施例具有如下有益效果:通过任一模型对图像样本或文本样本的预测结果,及N个所述机器学习模型中输出相同预测结果的模型数量确定共识质量,进一步得到模型权重,可以不需要访问本地数据,保护数据安全和隐私安全,能够调整各个模型的预测结果对最终预测结果的贡献,有效降低某个模型因接受不相关数据训练而导致对集成模型整体的负面影响,降低了个别子模型的预测结果对最终预测结果的不利影响,提升N个机器学习模型整体的泛化性。
附图说明
通过以下参照附图对本公开实施例的描述,本公开的上述内容以及其他目的、特征和优点将更为清楚,在附图中:
图1示意性示出了根据本公开实施例的模型预测方法的应用场景图;
图2示意性示出了根据本公开实施例的基于联邦学习的训练及模型预测的全局流程图;
图3示意性示出了根据本公开实施例的知识投票的流程图;
图4A示意性示出了根据本公开实施例的训练第一机器学习模型的流程图;
图4B示意性示出了根据本公开实施例的知识投票的图解示例;
图5A示意性示出了根据本公开实施例的模型参数拟合的流程图;
图5B示意性示出了根据本公开实施例的模型参数拟合的图解示例;
图6示意性示出了根据本公开实施例的确定模型权重的流程图;
图7示意性示出了根据本公开实施例的得到共识质量的流程图;
图8示意性示出了根据本公开实施例的模型差异度优化的流程图;
图9示意性示出了根据本公开实施例的模型预测方法的流程图;
图10示意性示出了根据本公开实施例的模型预测装置的结构框图;以及
图11示意性示出了根据本公开实施例的适于实现模型预测方法的电子设备的方框图。
具体实施方式
以下,将参照附图来描述本公开的实施例。但是应该理解,这些描述只是示例性的,而并非要限制本公开的范围。在下面的详细描述中,为便于解释,阐述了许多具体的细节以提供对本公开实施例的全面理解。然而,明显地,一个或多个实施例在没有这些具体细节的情况下也可以被实施。此外,在以下说明中,省略了对公知结构和技术的描述,以避免不必要地混淆本公开的概念。
随着科技的发展,数据安全越发重要,建立在大数据基础之上的机器学习和深度学习技术在不断地革新着当下的生产力,对数据量的需求也越来越大。然而出于数据安全和隐私保护的需要,并非所有数据均能被参与机构共享。但模型的训练需要使用大量数据,对于大部分组织机构来说,很难获取到足够的数据,即使收集到一部分数据,也无法保证数据的质量符合要求,即存在数据孤岛现象。对于上述问题可以应用联邦学习加以解决,但联邦学习依赖于分布式设备之间的大量通信训练子模型,也带来了一些包括训练速度慢,通信代价高昂,数据需要符合独立同分布条件(1ID)等挑战。
本公开一些实施例提供了一种模型预测方法,通过任一模型对图像样本或文本样本的预测结果,及N个所述机器学习模型中输出相同预测结果的模型数量确定共识质量,进一步得到模型权重,可以在不需要访问本地数据的情况下,保护了数据安全和隐私安全,调整各个模型的预测结果对最终预测结果的贡献,有效降低某个模型因接受不相关数据训练而导致对集成模型整体的负面影响,降低了个别子模型对最终预测结果的不利影响,提升N个机器学习模型整体的泛化性。
在本发明的技术方案中,所涉及的用户信息(包括但不限于用户个人信息、用户图像信息、用户设备信息,例如位置信息等)和数据(包括但不限于用于分析的数据、存储的数据、展示的数据等),均为经用户授权或者经过各方充分授权的信息和数据,并且相关数据的收集、存储、使用、加工、传输、提供、公开和应用等处理,均遵守相关国家和地区的相关法律法规和标准,采取了必要保密措施,不违背公序良俗,并提供有相应的操作入口,供用户选择授权或者拒绝。
图1示意性示出了根据本公开实施例的模型预测方法的应用场景图。需要注意的是,图1所示仅为可以应用本公开实施例的示例,以帮助本领域技术人员理解本公开的技术内容,但并不意味着本公开实施例不可以用于其他设备、系统、环境或场景。
如图1所示,根据该实施例的应用场景100可以包括服务器节点和K个客户端节点,可以执行联邦学习算法预先训练得到N个机器学习模型,K、N皆为大于或等于2的整数。
参照图1,在训练阶段,使用至少一台服务器节点(Server)以及若干台客户端节点(Guest和Host),各客户端节点具有相同的网络结构,训练任务首先由Host(如客户端1)发起,除此之外,Guest和Host在模型训练上完全一致。
任意客户端节点可以是具有显示屏并且支持网页浏览的各种电子设备,包括但不限于智能手机、平板电脑、膝上型便携计算机和台式计算机等等。
服务器节点可以是提供各种服务的服务器,例如对用户利用任意客户端节点所浏览的网站提供支持的后台管理服务器(仅为示例)。后台管理服务器可以对接收到的用户请求等数据进行分析等处理,并将处理结果(例如根据用户请求获取或生成的网页、信息、或数据等)反馈给终端设备。
服务器节点可以是独立的物理服务器,也可以是多个物理服务器构成的服务器集群或分布式系统,还可以是提供云服务、云计算、网络服务、中间件服务等基础云计算服务的云服务器。
应该理解,图1中的服务器和客户端的数目仅仅是示意性的。根据实现需要,可以具有任意数目的服务器和客户端。
图2示意性示出了根据本公开实施例的基于联邦学习的训练及模型预测的全局流程图。
参照图2,开始首先由各客户端基于本地数据(如)训练本地模型,即本地的机器学习模型(如/>)。完毕后开始使用各本地模型对统一的无标签样本集预测并执行知识投票,待所有客户端把对样本集的预测结果上传服务器后,由服务器基于样本集的预测结果(如Dτ)训练得到扩展模型,即服务器训练的机器学习模型,如/>(步骤1)。接着对各本地模型和扩展模型进行模型参数拟合,得到若干个泛化模型(步骤2)。接着,计算逻辑焦点并按逻辑焦点分配模型权重,以集成全局模型(步骤3)。进一步执行模型差异度优化步骤,减小子模型和集成模型的差异,得到优化后的集成模型(步骤4)。以上训练过程迭代若干次直到模型收敛或达到最大次数为止。最后部署集成模型进行模型预测。本公开提及的集成模型包括多个子模型,例如N个机器学习模型。
示例性地,本公开训练阶段及预测阶段涉及的待处理的图像或图像样本包括人脸图像、人体姿态图像、遥感图像、动物图像、物体图像或界面截图等,预测结果包括人脸识别、姿态识别、物体分类或风险预测等。本公开所提及待处理的文本或文本样本包括用户数据、应用属性、服务器日志或新闻等,预测结果包括用户分类、应用故障检测或性能检测等。
需要说明的是,本公开实施例中基于联邦学习训练至模型预测的过程,并不限定于图2所示的所有流程。
以下将基于图1和图2的描述,通过图3~图10对本公开实施例的模型预测方法进行详细描述。
图3示意性示出了根据本公开实施例的知识投票的流程图。
如图3所示,该实施例包括:
在操作S310,使用在K个客户端节点本地训练得到的K个机器学习模型,对服务器节点存储的无标签第二样本集进行预测,得到K个预测结果,第二样本集包括至少一个图像样本或至少一个文本样本;
各个客户端节点训练的模型具有相同的架构,例如基于卷积神经网络算法构建得到。该模型可以包括多个卷积层,在同一卷积层中,可以使用多个卷积核来提取不同的图像信息,一般地,卷积核数量越多,卷积操作反映的图像信息越丰富。
在训练过程中,可以采用误差反向传播(back propagation,BP)算法在训练过程中修正初始的模型中参数的大小,使得模型的重建误差损失越来越小。具体地,前向传递输入图像或文本,直至输出预测结果会产生误差损失,通过反向传播误差损失信息来更新初始的超分辨率模型中参数,从而使误差损失收敛。
服务器节点可以将存储的无标签第二样本集发送至各个客户端节点。
在操作S320,使K个客户端节点将K个预测结果传输至服务器节点,其中,服务器节点被配置为基于K个预测结果训练第一机器学习模型,N个机器学习模型包括该第一机器学习模型。
相关技术中,客户端节点与服务器节点之间往往在联邦学习阶段进行不断的通信来交互模型更新,数量较多的客户端节点以及庞大的模型参数会对通信网络造成巨大的带宽负担。导致通信量大、通信代价高昂,模型训练速度较慢。
根据本公开的实施例,先由客户端节点训练完成本地完成,然后客户端节点与服务器节点之间传输对无标签第一样本集的预测结果,以便在服务器节点构造带高质量伪标签的第一样本集实现知识蒸馏,降低通信量,有效缓解了联邦学习中的高通信代价问题。
图4A示意性示出了根据本公开实施例的训练第一机器学习模型的流程图。图4B示意性示出了根据本公开实施例的知识投票的图解示例。
如图4A所示,该实施例包括:
在操作S410,基于K个预测结果中对无标签第二样本集中每个图像样本或每个文本样本的共识结果确定样本标签,得到有标签第二样本集,共识结果包括K个预测结果中占多数的结果;
首先在K个客户端使用本地数据训练得到K个本地模型然后使用这K个模型预测服务器端无标签数据,得到预测logits为/>代表每个模型的预测结果,如某个类别的概率值。接着集成所有客户端模型的预测值获得共识逻辑Ci:
其中|D|=∑k|Dk|,基于共识逻辑Ci可以将服务器无标签数据集扩展为有标签的数据集下述结合图4B介绍共识逻辑Ci。
在操作S420,基于有标签的第二样本集训练第一机器学习模型。
例如在训练第一机器学习模型过程中通过比较该模型的预测与1样本标签之间的差异来计算误差损失,并执行反向传播算法更新模型参数。以KL距离和CE交叉熵损失作为Loss函数训练扩展模型如下公式2所示:
其中,DKL()用于计算KL距离,CE()用于计算交叉熵损失,表示第i个样本,T表示样本空间,q()表示模型的预测值logits,kv表示本地数据。
下面结合图2~图4B说明利用共识投票机制得到共识结果进行训练的过程。
举例说明,知识投票的主要思想是少数服从多数原则,即如果某个预测logit能够被其他更多推理结果佐证,那么就更有可能是真正的标签。如图4B所示,为了提高获取的共识逻辑质量,知识投票一共分为三个步骤:首先,通过一个置信门设置置信阈值,过滤掉模糊的结果。如图4B图中Client 4给出了cat预测概率0.4,dog预测概率0.6,这是一个很模糊的推理值,由于模型不敢给出更高置信度的推理结果,因此直接过滤掉。然后,在剩下的客户端模型中进一步过滤与主流结果相悖的推理值,如图4B中Client 1对cat的预测概率为0.07,明显与主流不符,因此更有可能是有偏答案,直接过滤。最后,对两轮过滤后剩下的预测logits(如cat的概率值0.94和0.92)进行平均聚合,得到共识逻辑Ci。同时记录支持Ci的客户端模型数量nc_i,并作为训练服务端扩展模型的式2Loss函数权重。对于经前两轮过滤的模型预测概率,并非直接舍弃而是给一个很小的权重nc=0.001,降低其影响。
知识投票之后即可得到基于服务器端公共数据的扩展数据集和扩展模型/>再使用前面的权重/>对知识蒸馏损失函数重新加权,可得到最终损失函数如下:
其中,
α′,β′为0~1之间的权重超参数。
可以理解,使用客户端模型预测服务端无标签数据,基于预测结果训练得到扩展模型这个过程相当于将客户端模型知识通过伪标签转移到服务端扩展模型中,完成知识转移,实现知识蒸馏。
图5A示意性示出了根据本公开实施例的模型参数拟合的流程图。图5B示意性示出了根据本公开实施例的模型参数拟合的图解示例。
在基于有标签的第二样本集训练第一机器学习模型之后,如图5A所示,该实施例包括:
在操作S510,从K个机器学习模型和第一机器学习模型中确定M个待定模型,M为大于或等于1的整数,且M小于或等于K+1;例如可以随机确定,也可以根据每个模型的预测准确率来确定M个待定模型。
在操作S520,对M个待定模型各自的模型参数进行拟合处理,得到M个第二机器学习模型,其中,每个第二机器学习模型的模型参数符合拟合处理所依据的数学分布,N个机器学习模型包括M个第二机器学习模型。
示例性地,当数据符合独立同分布条件时,服务器参数聚合阶段能得到正确的梯度使模型正常收敛,但是当数据不符合独立同分布条件时,则常规聚合后的参数梯度不一定能使模型收敛,这种情况下聚合的参数会偏离理想参数w*。
常规的联邦学习训练流程中,有一个阶段是聚合各客户端模型的参数到服务端,加权求平均作为服务端模型参数。而根据本公开的实施例,聚合阶段聚合的是预测logits,再以logits作为伪标签训练服务端模型,不需要聚合模型参数。因此,集成模型是客户端模型和服务端模型加权集成得到的,权重值由逻辑焦点确定。
在一些实施例中,参照图5B,使用贝叶斯模型抽样解决此问题。以高斯分布为例,首先基于对角高斯分布拟合客户端模型参数,如下公式所示:
其中,wi指第i个模型参数,表示第i个客户端对应的数据量,/>表示各个客户端节点和服务器的总数据量。μ根据各个客户端的数据量作为权重,对所有客户端模型的参数求加权均值。
根据上面公式可知客户端模型wi和Fedavg平均聚合得到的模型均有可能为高斯分布的参数样本。然后再依据该分布采样模型/>实现模型参数的拟合。除高斯分布之外,也可使用狄利克雷分布采样模型。
在一些实施例中,可以在对应客户端执行模型参数的拟合,则对应客户端存储本地模型以及拟合后的泛化模型,可以降低数据传输量,降低通信成本。在另一些实施例中,可以使得对应客户端将本地模型的参数上传到服务器,然后由服务器执行参数的拟合,则服务器存储第一机器学习模型和M个泛化模型。下述模型差异度优化阶段对泛化模型的优化在存储其的节点执行。
根据本公开的实施例,按照特定数据分布对模型参数拟合,得到的抽样模型(又称泛化模型)能够克服数据不符合独立同分布条件时模型不收敛的问题,提高集成模型的泛化性。
图6示意性示出了根据本公开实施例的确定模型权重的流程图。
如图6所示,该实施例包括:
在操作S610,获得每个机器学习模型的共识质量,共识质量通过N个机器学习模型中与该模型对图像样本或文本样本的预测结果相同的模型数量确定;
所谓共识质量即如果一个共识类能被更多其他模型推理支持,那么这个共识类别就更有可能是正确结果,即共识质量高,故通过与该模型具有相同预测结果的模型数量确定。
图7示意性示出了根据本公开实施例的得到共识质量的流程图。
如图7所示,该实施例是操作S610的其中一个实施例,包括:
在操作S710,确定N个机器学习模型对第一样本集的共识逻辑,共识逻辑基于每个机器学习模型对第一样本集的预测结果进行知识投票(参照图4B)得到,第一样本集包括至少一个图像样本或至少一个文本样本;
在操作S720,基于共识逻辑和模型数量得到共识质量。
令S=Dk+1表示服务端数据集,则第i个客户端模型在S上预测的logits为Ci(S),共识逻辑为(max(Ci(S)),),其中/>表示支持该共识逻辑的客户端模型数量,进一步地有共识质量为/>则客户端模型k对应的共识质量为:
其中,nc_i为支持Ci的客户端模型数量,Ci(S′)为共识逻辑,max指的是取最大值。
通过使用nc_i可以提高模型预测正确标签的概率,因为权重越大,说明客户端模型支持该label的数量越多,结果越可信。Ci表示对于样本i,客户端模型预测的logit值(logit是未经softmax归一化的推理结果)。Max(Ci)选取其中logit值最大的那个作为伪标签。nc_i表示推理结果等于Max(Ci)的模型数量。用即可增强该伪标签的推理可信度,故定义为共识质量。
在操作S620,对于任一个所述机器学习模型,将该机器学习模型的的共识质量输入到归一化的指数函数,以基于计算结果得到该模型的模型权重。
根据操作S620的其中一个实施例,包括:按照如下归一化的指数函数计算第k个机器学习模型的模型权重
其中,Qk(S)代表第k个机器学习模型的共识质量,K≥k≥1,S表征第一样本集,Nk代表第k个客户端节点的数据量,第k个机器学习模型来自第k个客户端节点,LF代表逻辑焦点标识,用于指示为模型权重。根据本公开的实施例,/>描述了第k个客户端模型对集成模型全局共识的边际贡献,/>值越高对集成模型的影响也越大。通过上述方法计算得到的模型权重不需要访问本地数据,适用于联邦学习情形,同时逻辑焦点的计算用到了数据和标签信息,可以有效降低不相关数据对集成模型的影响,提升集成模型的泛化性。
模型权重决定了每个模型对集成模型的影响程度,一般情况下联邦学习中当所有模型同等重要时,模型权重应取决于各客户端的数据质量和数量,但当数据不符合独立同分布条件时,这个权重就无法衡量了。该实施例提出了共识质量和逻辑焦点的概念,解决此问题。所谓共识质量即如果一个共识类能被更多其他模型推理支持,那么这个共识类别就更有可能是正确结果。
图8示意性示出了根据本公开实施例的模型差异度优化的流程图。
如图8所示,该实施例包括:
在操作S810,获得N个机器学习模型对第三样本集的共识预测结果,以及每个机器学习模型对第三样本集的预测结果,其中,第三样本集包括至少一个图像样本或至少一个文本样本;
本公开涉及的第一样本集、第二样本集和第三样本集可以是相同或不同的样本集,任两个样本集之间可以包括相同的样本。
在操作S820,确定每个机器学习模型对第三样本集的预测结果与共识预测结果之间的偏离度;
示例性地,参照图4B,若某个机器学习模型对cat的预测概率为0.5,共识预测结果为0.95,则单个样本的偏离度为两者差值0.45,操作S820的偏离度为第三样本集中所有样本的偏离度之和。
在操作S830,对于任一个机器学习模型,当其偏离度大于或等于预设阈值时,对该模型重新训练,直至其偏离度小于预设阈值。预设阈值可以灵活确定,本公开不进行限定。
参照图1和图8,为了进一步降低由于客户端数据Non-iid而导致的子模型偏离过大问题,接下来还需要最小化子模型(包括本地模型、泛化模型和扩展模型中一种或多种)和集成模型的参数差距。为此结合模型参数正则项提出了一种新的目标函数,如下公式所示:
其中f(x)是常规损失函数,g(w)用来衡量子模型相对于集成模型的偏离度,偏离度越大,Loss值越大,梯度反向传播就越显著。代表集成模型的参数,α′、β′是f(x)和g(w)的0~1之间的超参数。根据公式,当β>0时,模型参数偏离度才会被纳入目标函数中,并随着β的增大对目标函数的影响越来越大。通过以上Loss2函数的设置,可以有效降低因个别客户端数据质量不高而导致的模型发散问题,增强集成模型性能。
根据本公开的实施例,使用逻辑焦点和模型差异度优化的手段,为子模型精准加权,有效缓解了客户端数据Non-IID问题,同时也提高了模型识别恶意节点的能力。
图9示意性示出了根据本公开实施例的模型预测方法的流程图。
如图9所示,该实施例包括:
在操作S910,将待处理的图像或文本输入N个机器学习模型(即集成模型),其中,N个机器学习模型基于服务器节点及K个客户端节点执行联邦学习算法预先训练得到,K、N皆为大于或等于2的整数;
在操作S920,获得N个机器学习模型基于待处理的图像或文本输出的N个初始预测结果;
在操作S930,基于N个机器学习模型各自的模型权重与初始预测结果,得到待处理的图像或文本的最终预测结果;
其中,每个机器学习模型的模型权重通过将该模型的逻辑焦点输入归一化的指数函数得到,逻辑焦点通过该模型对图像样本或文本样本的预测结果与N个机器学习模型对图像样本或文本样本的共识预测结果之间的差异程度得到。
其中,hF:表征集成模型输出的共识预测结果,N为K、M和1的和,即为图1中的K′,αk表征第k个模型的模型权重,表征第k个机器学习模型的预测结果,c′表征生产环境部署的标识。
例如可以将人脸图像、人体姿态图像、遥感图像、动物图像、物体图像或界面截图等一种或多种图像输入集成模型,预测结果包括人脸识别、姿态识别、物体分类或风险预测等。例如可以将用户数据、应用属性、服务器日志或新闻等一种或多种文本输入集成模型,预测结果包括用户分类、应用故障检测或性能检测等。
根据本公开的实施例,能够以此通过模型权重来决定对应模型初始预测结果对共识预测结果的影响程度。从而不需要访问本地数据,保护了数据安全和隐私安全,有效降低不相关数据对模型的影响,提升N个机器学习模型整体的泛化性。
基于上述模型预测方法,本公开还提供了一种模型预测装置。以下将结合图10对该装置进行详细描述。
图10示意性示出了根据本公开实施例的模型预测装置的结构框图。
如图10所示,该实施例的模型预测装置1000包括数据输入模块1010、预测输出模块1020和最终预测模块1030。
数据输入模块1010可以执行操作S910,用于将待处理的图像或文本输入N个机器学习模型,其中,N个机器学习模型基于服务器节点及K个客户端节点执行联邦学习算法预先训练得到,K、N皆为大于或等于2的整数;
预测输出模块1020可以执行操作S920,用于获得N个机器学习模型基于待处理的图像或文本输出的N个初始预测结果;
最终预测模块1030可以执行操作S930,用于基于N个机器学习模型各自的模型权重与初始预测结果,得到待处理的图像或文本的最终预测结果;
其中,每个所述机器学习模型的模型权重通过该模型的共识质量得到,所述共识质量通过该模型对图像样本或文本样本的预测结果,及N个所述机器学习模型中输出相同预测结果的模型数量确定。
在一些实施例中,模型预测装置1000还可以包括传输模块,该模块可以执行操作S310~操作S320,在此不再赘述。
在一些实施例中,模型预测装置1000还可以包括扩展训练模块,该模块可以执行操作S410~操作S420,在此不再赘述。
在一些实施例中,模型预测装置1000还可以包括拟合模块,该模块可以执行操作S510~操作S520,在此不再赘述。
在一些实施例中,模型预测装置1000还可以包括权重计算模块,该模块可以执行操作S610~操作S630,操作S710~操作S720,在此不再赘述。
在一些实施例中,模型预测装置1000还可以包括差异优化模块,该模块可以执行操作S810~操作S830,在此不再赘述。
需要说明的是,模型预测装置1000包括分别用于执行如上图2~图9描述的任意一个实施例的各个步骤的模块。装置部分实施例中各模块/单元/子单元等的实施方式、解决的技术问题、实现的功能、以及达到的技术效果分别与方法部分实施例中各对应的步骤的实施方式、解决的技术问题、实现的功能、以及达到的技术效果相同或类似,在此不再赘述。
根据本公开的实施例,数据输入模块1010、预测输出模块1020和最终预测模块1030中的任意多个模块可以合并在一个模块中实现,或者其中的任意一个模块可以被拆分成多个模块。或者,这些模块中的一个或多个模块的至少部分功能可以与其他模块的至少部分功能相结合,并在一个模块中实现。
根据本公开的实施例,数据输入模块1010、预测输出模块1020和最终预测模块1030中的至少一个可以至少被部分地实现为硬件电路,例如现场可编程门阵列(FPGA)、可编程逻辑阵列(PLA)、片上系统、基板上的系统、封装上的系统、专用集成电路(ASIC),或可以通过对电路进行集成或封装的任何其他的合理方式等硬件或固件来实现,或以软件、硬件以及固件三种实现方式中任意一种或以其中任意几种的适当组合来实现。或者,数据输入模块1010、预测输出模块1020和最终预测模块1030中的至少一个可以至少被部分地实现为计算机程序模块,当该计算机程序模块被运行时,可以执行相应的功能。
图11示意性示出了根据本公开实施例的适于实现模型预测方法的电子设备的方框图。
如图11所示,根据本公开实施例的电子设备1100包括处理器1101,其可以根据存储在只读存储器(ROM)1102中的程序或者从存储部分1108加载到随机访问存储器(RAM)1103中的程序而执行各种适当的动作和处理。处理器1101例如可以包括通用微处理器(例如CPU)、指令集处理器和/或相关芯片组和/或专用微处理器(例如,专用集成电路(ASIC))等等。处理器1101还可以包括用于缓存用途的板载存储器。处理器1101可以包括用于执行根据本公开实施例的方法流程的不同动作的单一处理单元或者是多个处理单元。
在RAM 1103中,存储有电子设备1100操作所需的各种程序和数据。处理器1101、ROM 1102以及RAM 1103通过总线1104彼此相连。处理器1101通过执行ROM 1102和/或RAM1103中的程序来执行根据本公开实施例的方法流程的各种操作。需要注意,程序也可以存储在除ROM 1102和RAM 1103以外的一个或多个存储器中。处理器1101也可以通过执行存储在一个或多个存储器中的程序来执行根据本公开实施例的方法流程的各种操作。
根据本公开的实施例,电子设备1100还可以包括输入/输出(I/O)接口1105,输入/输出(I/O)接口1105也连接至总线1104。电子设备1100还可以包括连接至I/O接口1105的以下部件中的一项或多项:包括键盘、鼠标等的输入部分1106;包括诸如阴极射线管(CRT)、液晶显示器(LCD)等以及扬声器等的输出部分1107;包括硬盘等的存储部分1108;以及包括诸如LAN卡、调制解调器等的网络接口卡的通信部分1109。通信部分1109经由诸如因特网的网络执行通信处理。驱动器1110也根据需要连接至I/O接口1105。可拆卸介质1111,诸如磁盘、光盘、磁光盘、半导体存储器等等,根据需要安装在驱动器1110上,以便于从其上读出的计算机程序根据需要被安装入存储部分1108。
本公开还提供了一种计算机可读存储介质,该计算机可读存储介质可以是上述实施例中描述的设备/装置/系统中所包含的;也可以是单独存在,而未装配入该设备/装置/系统中。上述计算机可读存储介质承载有一个或者多个程序,当上述一个或者多个程序被执行时,实现根据本公开实施例的方法。
根据本公开的实施例,计算机可读存储介质可以是非易失性的计算机可读存储介质,例如可以包括但不限于:便携式计算机磁盘、硬盘、随机访问存储器(RAM)、只读存储器(ROM)、可擦式可编程只读存储器(EPROM或闪存)、便携式紧凑磁盘只读存储器(CD-ROM)、光存储器件、磁存储器件、或者上述的任意合适的组合。在本公开中,计算机可读存储介质可以是任何包含或存储程序的有形介质,该程序可以被指令执行系统、装置或者器件使用或者与其结合使用。例如,根据本公开的实施例,计算机可读存储介质可以包括上文描述的ROM 1102和/或RAM 1103和/或ROM 1102和RAM 1103以外的一个或多个存储器。
本公开的实施例还包括一种计算机程序产品,其包括计算机程序,该计算机程序包含用于执行流程图所示的方法的程序代码。当计算机程序产品在计算机系统中运行时,该程序代码用于使计算机系统实现本公开实施例所提供的方法。
在该计算机程序被处理器1101执行时执行本公开实施例的系统/装置中限定的上述功能。根据本公开的实施例,上文描述的系统、装置、模块、单元等可以通过计算机程序模块来实现。
在一种实施例中,该计算机程序可以依托于光存储器件、磁存储器件等有形存储介质。在另一种实施例中,该计算机程序也可以在网络介质上以信号的形式进行传输、分发,并通过通信部分1109被下载和安装,和/或从可拆卸介质1111被安装。该计算机程序包含的程序代码可以用任何适当的网络介质传输,包括但不限于:无线、有线等等,或者上述的任意合适的组合。
在这样的实施例中,该计算机程序可以通过通信部分1109从网络上被下载和安装,和/或从可拆卸介质1111被安装。在该计算机程序被处理器1101执行时,执行本公开实施例的系统中限定的上述功能。根据本公开的实施例,上文描述的系统、设备、装置、模块、单元等可以通过计算机程序模块来实现。
根据本公开的实施例,可以以一种或多种程序设计语言的任意组合来编写用于执行本公开实施例提供的计算机程序的程序代码,具体地,可以利用高级过程和/或面向对象的编程语言、和/或汇编/机器语言来实施这些计算程序。程序设计语言包括但不限于诸如Java,C++,python,“C”语言或类似的程序设计语言。程序代码可以完全地在用户计算设备上执行、部分地在用户设备上执行、部分在远程计算设备上执行、或者完全在远程计算设备或服务器上执行。在涉及远程计算设备的情形中,远程计算设备可以通过任意种类的网络,包括局域网(LAN)或广域网(WAN),连接到用户计算设备,或者,可以连接到外部计算设备(例如利用因特网服务提供商来通过因特网连接)。
附图中的流程图和框图,图示了按照本公开各种实施例的系统、方法和计算机程序产品的可能实现的体系架构、功能和操作。在这点上,流程图或框图中的每个方框可以代表一个模块、程序段、或代码的一部分,上述模块、程序段、或代码的一部分包含一个或多个用于实现规定的逻辑功能的可执行指令。也应当注意,在有些作为替换的实现中,方框中所标注的功能也可以以不同于附图中所标注的顺序发生。例如,两个接连地表示的方框实际上可以基本并行地执行,它们有时也可以按相反的顺序执行,这依所涉及的功能而定。也要注意的是,框图或流程图中的每个方框、以及框图或流程图中的方框的组合,可以用执行规定的功能或操作的专用的基于硬件的系统来实现,或者可以用专用硬件与计算机指令的组合来实现。
本领域技术人员可以理解,本公开的各个实施例和/或权利要求中记载的特征可以进行多种组合或/或结合,即使这样的组合或结合没有明确记载于本公开中。特别地,在不脱离本公开精神和教导的情况下,本公开的各个实施例和/或权利要求中记载的特征可以进行多种组合和/或结合。所有这些组合和/或结合均落入本公开的范围。
以上对本公开的实施例进行了描述。但是,这些实施例仅仅是为了说明的目的,而并非为了限制本公开的范围。尽管在以上分别描述了各实施例,但是这并不意味着各个实施例中的措施不能有利地结合使用。本公开的范围由所附权利要求及其等同物限定。不脱离本公开的范围,本领域技术人员可以做出多种替代和修改,这些替代和修改都应落在本公开的范围之内。
Claims (11)
1.一种模型预测方法,包括:
将待处理的图像或文本输入N个机器学习模型,其中,N个所述机器学习模型基于服务器节点及K个客户端节点执行联邦学习算法预先训练得到,N、K皆为大于或等于2的整数;
获得N个所述机器学习模型基于待处理的图像或文本输出的N个初始预测结果;
基于N个所述机器学习模型各自的模型权重与初始预测结果,得到待处理的图像或文本的最终预测结果;
其中,每个所述机器学习模型的模型权重通过该模型的共识质量得到,所述共识质量通过该模型对图像样本或文本样本的预测结果,及N个所述机器学习模型中输出相同预测结果的模型数量确定。
2.根据权利要求1所述的方法,其中,在得到待处理的图像或文本的最终预测结果之前,确定每个所述机器学习模型的模型权重包括:
获得每个所述机器学习模型的共识质量;
对于任一个所述机器学习模型,将该机器学习模型的的共识质量输入到归一化的指数函数,以基于计算结果得到该模型的模型权重。
3.根据权利要求2所述的方法,其中,获得每个所述机器学习模型的共识质量包括:
确定N个所述机器学习模型对第一样本集的共识逻辑,所述共识逻辑基于每个所述机器学习模型对所述第一样本集的预测结果进行知识投票得到,所述第一样本集包括至少一个图像样本或至少一个文本样本;
基于所述共识逻辑和所述模型数量得到所述共识质量。
4.根据权利要求2所述的方法,其中,基于计算结果得到模型权重包括:
按照如下归一化的指数函数计算第k个机器学习模型的模型权重
其中,Qk(S)代表共识质量,K≥k≥1,S表征第一样本集,Nk代表第k个客户端节点的数据量,第k个机器学习模型来自第k个客户端节点,LF代表逻辑焦点标识,用于指示为模型权重;
其中,描述了第k个客户端模型对集成模型全局共识的边际贡献。
5.根据权利要求1或2所述的方法,其中,在将待处理的图像或文本输入N个机器学习模型之前,还包括:
使用在K个所述客户端节点本地训练得到的K个所述机器学习模型,对所述服务器节点存储的无标签第二样本集进行预测,得到K个预测结果,所述第二样本集包括至少一个图像样本或至少一个文本样本;
使K个所述客户端节点将K个所述预测结果传输至所述服务器节点,其中,所述服务器节点被配置为基于K个所述预测结果训练第一机器学习模型,N个所述机器学习模型包括该第一机器学习模型。
6.根据权利要求5所述的方法,其中,所述服务器节点基于K个所述预测结果训练第一机器学习模型包括:
基于K个所述预测结果中对所述无标签第二样本集中每个图像样本或每个文本样本的共识结果确定样本标签,得到有标签第二样本集,所述共识结果包括K个所述预测结果中占多数的结果;
基于有标签的第二样本集训练所述第一机器学习模型。
7.根据权利要求6所述的方法,其中,在基于有标签的第二样本集训练所述第一机器学习模型之后,所述方法还包括:
从K个所述机器学习模型和所述第一机器学习模型中确定M个待定模型,M为大于或等于1的整数,且M小于或等于K+1;
对M个所述待定模型各自的模型参数进行拟合处理,得到M个第二机器学习模型,其中,每个所述第二机器学习模型的模型参数符合拟合处理所依据的数学分布,N个所述机器学习模型包括M个所述第二机器学习模型。
8.根据权利要求1或7所述的方法,其中,在将待处理的图像或文本输入N个机器学习模型之前,还包括:
获得N个所述机器学习模型对第三样本集的共识预测结果,以及每个所述机器学习模型对所述第三样本集的预测结果,其中,所述第三样本集包括至少一个图像样本或至少一个文本样本;
确定每个所述机器学习模型对所述第三样本集的预测结果与所述共识预测结果之间的偏离度;
对于任一个所述机器学习模型,当其偏离度大于或等于预设阈值时,对该模型重新训练,直至其偏离度小于所述预设阈值。
9.一种模型预测装置,包括:
数据输入模块,用于将待处理的图像或文本输入N个机器学习模型,其中,N个所述机器学习模型基于服务器节点及K个客户端节点执行联邦学习算法预先训练得到,K、N皆为大于或等于2的整数;
预测输出模块,用于获得N个所述机器学习模型基于待处理的图像或文本输出的N个初始预测结果;
最终预测模块,用于基于N个所述机器学习模型各自的模型权重与初始预测结果,得到待处理的图像或文本的最终预测结果;
其中,每个所述机器学习模型的模型权重通过该模型的共识质量得到,所述共识质量通过该模型对图像样本或文本样本的预测结果,及N个所述机器学习模型中输出相同预测结果的模型数量确定。
10.一种电子设备,包括:
一个或多个处理器;
存储装置,用于存储一个或多个程序,
其中,当所述一个或多个程序被所述一个或多个处理器执行时,使得所述一个或多个处理器执行根据权利要求1~8中任一项所述的方法。
11.一种计算机可读存储介质,其上存储有可执行指令,该指令被处理器执行时使处理器执行根据权利要求1~8中任一项所述的方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202311862724.2A CN117807555A (zh) | 2023-12-29 | 2023-12-29 | 模型预测方法、装置、电子设备和存储介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202311862724.2A CN117807555A (zh) | 2023-12-29 | 2023-12-29 | 模型预测方法、装置、电子设备和存储介质 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN117807555A true CN117807555A (zh) | 2024-04-02 |
Family
ID=90421479
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202311862724.2A Pending CN117807555A (zh) | 2023-12-29 | 2023-12-29 | 模型预测方法、装置、电子设备和存储介质 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN117807555A (zh) |
-
2023
- 2023-12-29 CN CN202311862724.2A patent/CN117807555A/zh active Pending
Similar Documents
Publication | Publication Date | Title |
---|---|---|
US11868891B2 (en) | Machine-learning techniques for monotonic neural networks | |
US20210150355A1 (en) | Training machine learning models using task selection policies to increase learning progress | |
US10558913B1 (en) | Machine-learning techniques for monotonic neural networks | |
US11823013B2 (en) | Text data representation learning using random document embedding | |
EP3574454B1 (en) | Learning neural network structure | |
US10922609B2 (en) | Semi-supervised learning via deep label propagation | |
CN111279362B (zh) | 胶囊神经网络 | |
US11366990B2 (en) | Time-series representation learning via random time warping | |
US20180189950A1 (en) | Generating structured output predictions using neural networks | |
JP2018190396A (ja) | ネットワークレーティング予測エンジン | |
US10990889B2 (en) | Generating a predictive behavior model for predicting user behavior using unsupervised feature learning and a recurrent neural network | |
CN110852447A (zh) | 元学习方法和装置、初始化方法、计算设备和存储介质 | |
US20230049747A1 (en) | Training machine learning models using teacher annealing | |
US20220215209A1 (en) | Training machine learning models using unsupervised data augmentation | |
US20220094649A1 (en) | Systems and methods for generating dynamic conversational responses using trained machine learning models | |
US20220019888A1 (en) | Unified framework for dynamic clustering and discrete time event prediction | |
US20220230065A1 (en) | Semi-supervised training of machine learning models using label guessing | |
EP3945472A2 (en) | Method of and system for online machine learning with dynamic model evaluation and selection | |
CN110968887B (zh) | 在数据隐私保护下执行机器学习的方法和系统 | |
GB2600817A (en) | Systems and methods for generating dynamic interface options using machine learning models | |
US20240046128A1 (en) | Dynamic causal discovery in imitation learning | |
US20230351119A1 (en) | Systems and methods for generating dynamic conversational responses through aggregated outputs of machine learning models | |
CN117807555A (zh) | 模型预测方法、装置、电子设备和存储介质 | |
US20230017505A1 (en) | Accounting for long-tail training data through logit adjustment | |
US20220358366A1 (en) | Generation and implementation of dedicated feature-based techniques to optimize inference performance in neural networks |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |