CN116994018A - 模型训练方法、分类预测方法以及装置 - Google Patents
模型训练方法、分类预测方法以及装置 Download PDFInfo
- Publication number
- CN116994018A CN116994018A CN202211167683.0A CN202211167683A CN116994018A CN 116994018 A CN116994018 A CN 116994018A CN 202211167683 A CN202211167683 A CN 202211167683A CN 116994018 A CN116994018 A CN 116994018A
- Authority
- CN
- China
- Prior art keywords
- server
- model parameters
- servers
- data set
- image
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000012549 training Methods 0.000 title claims abstract description 214
- 238000000034 method Methods 0.000 title claims abstract description 105
- 238000013145 classification model Methods 0.000 claims abstract description 98
- 238000004364 calculation method Methods 0.000 claims abstract description 69
- 238000012545 processing Methods 0.000 claims abstract description 46
- 238000012360 testing method Methods 0.000 claims description 103
- 238000005070 sampling Methods 0.000 claims description 22
- 230000006870 function Effects 0.000 claims description 18
- 238000004590 computer program Methods 0.000 claims description 16
- 238000012937 correction Methods 0.000 claims description 6
- 230000003247 decreasing effect Effects 0.000 claims 1
- 230000000694 effects Effects 0.000 abstract description 22
- 238000013473 artificial intelligence Methods 0.000 abstract description 17
- 238000010801 machine learning Methods 0.000 abstract description 8
- 238000004422 calculation algorithm Methods 0.000 description 18
- 238000005516 engineering process Methods 0.000 description 16
- 230000008569 process Effects 0.000 description 13
- 238000010586 diagram Methods 0.000 description 8
- 238000004891 communication Methods 0.000 description 5
- 238000010276 construction Methods 0.000 description 4
- 238000011160 research Methods 0.000 description 4
- 238000006467 substitution reaction Methods 0.000 description 4
- 230000001360 synchronised effect Effects 0.000 description 3
- 238000013528 artificial neural network Methods 0.000 description 2
- 230000006399 behavior Effects 0.000 description 2
- 238000013135 deep learning Methods 0.000 description 2
- 238000013461 design Methods 0.000 description 2
- 230000007246 mechanism Effects 0.000 description 2
- 238000012986 modification Methods 0.000 description 2
- 230000004048 modification Effects 0.000 description 2
- 238000004458 analytical method Methods 0.000 description 1
- 238000013459 approach Methods 0.000 description 1
- 238000003491 array Methods 0.000 description 1
- 230000003190 augmentative effect Effects 0.000 description 1
- 238000013480 data collection Methods 0.000 description 1
- 238000011161 development Methods 0.000 description 1
- 239000000835 fiber Substances 0.000 description 1
- 230000006698 induction Effects 0.000 description 1
- 230000003993 interaction Effects 0.000 description 1
- 238000012804 iterative process Methods 0.000 description 1
- 239000011159 matrix material Substances 0.000 description 1
- 238000003058 natural language processing Methods 0.000 description 1
- 238000005457 optimization Methods 0.000 description 1
- 238000003672 processing method Methods 0.000 description 1
- 230000002787 reinforcement Effects 0.000 description 1
- 230000003068 static effect Effects 0.000 description 1
- 238000013526 transfer learning Methods 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/764—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N20/00—Machine learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/084—Backpropagation, e.g. using gradient descent
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/40—Extraction of image or video features
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/77—Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
- G06V10/774—Generating sets of training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/82—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Software Systems (AREA)
- Evolutionary Computation (AREA)
- General Physics & Mathematics (AREA)
- Computing Systems (AREA)
- Artificial Intelligence (AREA)
- Health & Medical Sciences (AREA)
- General Health & Medical Sciences (AREA)
- Medical Informatics (AREA)
- Multimedia (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Databases & Information Systems (AREA)
- Data Mining & Analysis (AREA)
- General Engineering & Computer Science (AREA)
- Mathematical Physics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Molecular Biology (AREA)
- Image Analysis (AREA)
Abstract
本申请提供了一种模型训练方法、分类预测方法以及装置,其涉及人工智能中的机器学习,该模型训练方法包括:获取n张图像,并对该n张图像进行划分得到k个图像集;利用k个服务器分别对该k个图像集进行图像处理,以得到该k个服务器中的各个服务器的样本数据集;基于该k个服务器中经过第t次迭代计算后得到的k个模型参数更新该各个服务器中的模型参数,并基于该各个服务器的样本数据集对该各个服务器中更新后的模型参数进行第t+1次迭代计算,得到训练后的k个图像分类模型。该模型训练方法能够在降低训练成本、每次迭代计算对计算资源和存储资源的要求以及每次迭代计算的计算量的基础上保证图像分类模型的训练效果。
Description
技术领域
本申请实施例涉及人工智能中的机器学习领域,并且更具体地,涉及模型训练方法、分类预测方法以及装置。
背景技术
在自动驾驶领域中,针对用于目标分类模型的训练中,需要对大规模数据进行迭代训练计算需要大量的计算资源。例如,通常训练一个目标分类模型需要进行10000次以上的迭代运算才能得到较稳定的模型求解,而每次迭代训练的图像样本数量往往达到10000张以上。因此,对于如此大规模的图像数据集,如果采用传统的梯度下降算法进行模型的迭代计算,需要训练较长时间。
发明内容
本申请实施例提供了一种模型训练方法、分类预测方法以及装置,能够在降低图像分类模型的训练成本、每次迭代计算对计算资源和存储资源的要求以及每次迭代计算的计算量的基础上保证图像分类模型的训练效果。
第一方面,本申请实施例提供了一种模型训练方法,包括:
获取n张图像,并对该n张图像进行划分得到k个图像集;n>k>1;
利用k个服务器分别对该k个图像集进行图像处理,以得到该k个服务器中的各个服务器的样本数据集;
基于该k个服务器中经过第t次迭代计算后得到的k个模型参数更新该各个服务器中的模型参数,并基于该各个服务器的样本数据集对该各个服务器中更新后的模型参数进行第t+1次迭代计算,得到训练后的k个图像分类模型。
第二方面,本申请实施例提供了一种分类预测方法,包括:
获取目标图像;
利用k个服务器中的目标服务器对该目标图像进行图像处理,以得到该目标图像的特征数据;其中,该k个服务器配置有k个图像分类模型,该k个图像分类模型为按照第一方面提供的方法训练的图像分类模型;
基于该目标图像的特征数据,利用该目标服务器中的图像分类模型,对该目标图像进行分类预测,得到该目标图像的图像分类结果。
第三方面,本申请实施例提供了一种模型训练装置,包括:
获取单元,用于获取n张图像,并对该n张图像进行划分得到k个图像集;n>k>1;
处理单元,用于利用k个服务器分别对该k个图像集进行图像处理,以得到该k个服务器中的各个服务器的样本数据集;
训练单元,用于基于该k个服务器中经过第t次迭代计算后得到的k个模型参数更新该各个服务器中的模型参数,并基于该各个服务器的样本数据集对该各个服务器中更新后的模型参数进行第t+1次迭代计算,得到训练后的k个图像分类模型。
第四方面,本申请实施例提供了一种分类预测装置,包括:
获取单元,用于获取目标图像;
处理单元,用于利用k个服务器中的目标服务器对该目标图像进行图像处理,以得到该目标图像的特征数据;其中,该k个服务器配置有k个图像分类模型,该k个图像分类模型为按照第一方面描述的方法训练的图像分类模型;
预测单元,用于基于该目标图像的特征数据,利用该目标服务器中的图像分类模型,对该目标图像进行分类预测,得到该目标图像的图像分类结果。
第五方面,本申请实施例提供了一种电子设备,包括:
处理器,适于实现计算机指令;以及,
计算机可读存储介质,计算机可读存储介质存储有计算机指令,计算机指令适于由处理器加载并执行上文涉及的第一方面或第二方面所提供的方法。
第六方面,本申请实施例提供了一种计算机可读存储介质,该计算机可读存储介质存储有计算机指令,该计算机指令被计算机设备的处理器读取并执行时,使得计算机设备执行上文涉及的第一方面或第二方面所提供的方法。
第七方面,本申请实施例提供了一种计算机程序产品或计算机程序,该计算机程序产品或计算机程序包括计算机指令,该计算机指令存储在计算机可读存储介质中。计算机设备的处理器从计算机可读存储介质读取该计算机指令,处理器执行该计算机指令,使得该计算机设备执行上文涉及的第一方面或第二方面所提供的方法。
本申请实施例中,利用k个服务器分别对该k个图像集进行图像处理以及利用各个服务器对各个服务器的样本数据集进行迭代计算的方式训练图像分类模型,相当于,通过各个服务器分摊模型训练所需的计算量,不仅能够合理利用各个服务器中的空闲资源,进而降低了该k个图像分类模型的训练成本,还能够降低每次迭代计算对计算资源和存储资源的要求以及降低每次迭代计算的计算量。此外,基于该k个服务器中经过第t次迭代计算后得到的k个模型参数更新该各个服务器中的模型参数,并基于该各个服务器的样本数据集对该各个服务器中更新后的模型参数进行第t+1次迭代计算,得到训练后的k个图像分类模型,相当于,在第t+1次迭代计算中,可以通过共享该k个服务器中经过第t次迭代计算后得到的k个模型参数,来实现对该k个图像分类模型的联合训练,能够保证该k个图像分类模型的训练效果。
简言之,本申请实施例提供的模型训练方法能够在降低图像分类模型的训练成本、每次迭代计算对计算资源和存储资源的要求以及每次迭代计算的计算量的基础上保证图像分类模型的训练效果。
附图说明
图1是本申请实施例提供的系统框架的示例。
图2是本申请实施例提供的模型训练方法的示意性流程图。
图3是本申请实施例提供的模型训练方法的另一示意性流程图。
图4是本申请实施例提供的模型训练方法涉及的各个阶段的示意图。
图5是本申请实施例提供的分类预测方法的示意性流程图。
图6是本申请实施例提供的模型训练装置的示意性框图。
图7是本申请实施例提供的分类预测装置的示意性框图。
图8是本申请实施例提供的电子设备的示意性框图。
具体实施方式
本申请提供的方案可涉及人工智能(Artificial Intelligence,AI)技术领域。
其中,AI是利用数字计算机或者数字计算机控制的机器模拟、延伸和扩展人的智能,感知环境、获取知识并使用知识获得最佳结果的理论、方法、技术及应用系统。换句话说,人工智能是计算机科学的一个综合技术,它企图了解智能的实质,并生产出一种新的能以人类智能相似的方式做出反应的智能机器。人工智能也就是研究各种智能机器的设计原理与实现方法,使机器具有感知、推理与决策的功能。
应理解,人工智能技术是一门综合学科,涉及领域广泛,既有硬件层面的技术也有软件层面的技术。人工智能基础技术一般包括如传感器、专用人工智能芯片、云计算、分布式存储、大数据处理技术、操作/交互系统、机电一体化等技术。人工智能软件技术主要包括计算机视觉技术、语音处理技术、自然语言处理技术以及机器学习/深度学习等几大方向。
随着人工智能技术研究和进步,人工智能技术在多个领域展开研究和应用,例如常见的智能家居、智能穿戴设备、虚拟助理、智能音箱、智能营销、无人驾驶、自动驾驶、无人机、机器人、智能医疗、智能客服等,相信随着技术的发展,人工智能技术将在更多的领域得到应用,并发挥越来越重要的价值。
本申请实施例可涉及人工智能技术中的计算机视觉(Computer Vision,CV)技术,计算机视觉是一门研究如何使机器“看”的科学,更进一步的说,就是指用摄影机和电脑代替人眼对目标进行识别和测量等机器视觉,并进一步做图形处理,使电脑处理成为更适合人眼观察或传送给仪器检测的图像。作为一个科学学科,计算机视觉研究相关的理论和技术,试图建立能够从图像或者多维数据中获取信息的人工智能系统。计算机视觉技术通常包括图像处理、图像识别、图像语义理解、图像检索、OCR、视频处理、视频语义理解、视频内容/行为识别、三维物体重建、3D技术、虚拟现实、增强现实、同步定位与地图构建等技术,还包括常见的人脸识别、指纹识别等生物特征识别技术。
本申请实施例也可以涉及人工智能技术中的机器学习(Machine Learning,ML),ML是一门多领域交叉学科,涉及概率论、统计学、逼近论、凸分析、算法复杂度等多门学科。专门研究计算机怎样模拟或实现人类的学习行为,以获取新的知识或技能,重新组织已有的知识结构使之不断改善自身的性能。机器学习是人工智能的核心,是使计算机具有智能的根本途径,其应用遍及人工智能的各个领域。机器学习和深度学习通常包括人工神经网络、置信网络、强化学习、迁移学习、归纳学习等技术。
为便于理解本申请的方案,下面对本申请涉及的相关术语进行说明。
1、梯度下降算法:
梯度下降算法的核心思想是:先随便初始化一个w0,然后给定一个步长η,通过不断地修改wt+1←wt,从而最后靠近到达取得最大值的点,即不断进行下面的迭代过程,直到达到指定次数,或者梯度等于0为止,常用的梯度下降公式可以表示为:
2、随机梯度下降算法:
随机梯度下降算法不直接使用梯度而是采用另一个输出为随机变量的替代函数:g(w(t))。
值得注意的是,该替代函数g(w(t))需要满足它的期望值等于相当于这个函数围绕着/>的输出值随机波动。常用的随机梯度下降公式可以表示为:w(t+1)=w(t)-ηg(w(t))。
3、牛顿法:
牛顿法主要是为了解决非线性优化问题,其收敛速度比梯度下降算法速度更快。牛顿法的主要思想是:在极小值的当前估计值的附近对损失函数:做二阶的泰勒展开,进而可以找到极小点的下一个估计值。假设为当前的极小值估计值/>则:对齐进行求导得到:进一步的,J′(w)=0时,可得到:/>由此,时得到:/>其中,/>表示海森矩阵。
图1是本申请实施例提供的系统框架100的示例。
该系统框架100可以是一个应用程序系统,本申请实施例对该应用程序的具体类型不加以限定。该系统框架100包括:终端131、终端132和服务器集群110。终端131和终端132均可通过无线或有线网络120与服务器集群110相连。
终端131和终端132可以是智能手机、游戏主机、台式计算机、平板电脑、电子书阅读器、MP4播放器、MP4播放器和膝上型便携计算机中的至少一种。终端131和终端132可以是客户端,其安装和运行有应用程序。该应用程序可以是在线视频程序、短视频程序、图片分享程序、声音社交程序、漫画程序、壁纸程序、新闻推送程序、供求信息推送程序、学术交流程序、技术交流程序、政策交流程序、包含评论机制的程序、包含观点发布机制的程序、知识分享程序中的任意一种。终端131和终端132可以分别是用户141、用户142使用的终端,终端131和终端132中运行的应用程序内登录有注册账户。
服务器集群110包括k个服务器,其中各个服务器可以是物理服务器、云平台或虚拟化中心。服务器集群110用于为应用程序(例如终端131和终端132上的应用程序)提供后台服务。可选地,服务器集群110承担主要计算工作,终端131和终端132承担次要计算工作;或者,服务器集群110承担次要计算工作,终端131和终端132承担主要计算工作;或者,终端131和终端132和服务器集群110之间采用分布式计算架构进行协同计算。例如集合本申请来说,服务器集群110中的中的各个服务器均可配置有图像分类模型,各个服务器可对各个服务器的图像集进行图像处理,并利用图像处理得到的样本数据集进行图像分类模型的训练。
图2示出了根据本申请实施例的模型训练方法200的示意性流程图,该方法200可以由模型训练装置执行,该模型训练装置可与k个服务器相连,该k个服务器可以用于为该模型训练装置提供计算服务,该k个服务器可以是独立的物理服务器,也可以是多个物理服务器构成的服务器集群或者分布式系统,还可以是云服务器,其中,该k个服务器可以通过有线或无线通信方式进行直接或间接地连接,本申请在此不做限制。
如图2所示,该方法200可包括:
S210,模型训练装置获取n张图像,并对该n张图像进行划分得到k个图像集;n>k>1。
示例性地,该n张图像可以是图像采集车或用户上传的图像。
示例性地,模型训练装置可以将该n张图像等分划分为该k个图像集,即该k个图像集中的各个图像集中包括n/k张图像。
示例性地,该模型训练装置可以将该n张图像随机划分为该k个图像集。
示例性地,该模型训练装置可以将该n张图像按照预设的比例划分为该k个图像集。
示例性地,该模型训练装置可以基于k个服务器的可用资源、可用计算资源和可用存储资源将该n张图像划分为k个图像集。例如,该模型训练装置可以按照该k个服务器的可用资源的比例、该k个服务器的可用计算资源的比例、或该k个服务器的可用存储资源的比例,将该n张图像划分为该k个图像集。其中,该k个服务器中的各个服务器的可用资源、可用计算资源或可用存储资源越多时,为该各个服务器划分的图像集中的图像越多。例如,该k个服务器的可用资源的比例、该k个服务器的可用计算资源的比例、或该k个服务器的可用存储资源的比例,可以等于该k个图像集中图像数量的比例。
S220,模型训练装置利用k个服务器分别对该k个图像集进行图像处理,以得到该k个服务器中的各个服务器的样本数据集。
示例性地,该模型训练装置可以将该k个数据集分别上传至该k个服务器,并利用该k个服务器中各个服务器对该各个服务器获取的图像集中的图像进行图像处理,以得到各个服务器的样本数据集。
示例性地,该各个服务器采用的图像处理方法包括但不限于利用各种神经网络或编码器进行图像处理的方法。
示例性地,该各个服务器的样本数据集可包括特征数据和标签数据。该特征数据可以是用于描述图像的特征的数据,该标签数据可以是用于描述图像的分类结果的数据。
示例性地,假设该k个图像集中的各个图像集中包括n/k张图像,则该各个服务器的样本数据集可表示为:{Yi,Xi}[n/k]×m。其中,Yi表示第i个服务器的样本数据集中的标签数据,Xi表示第i个服务器的样本数据集中的特征数据,n/k表示第i个服务器的样本集中包括的图像的数量,m表示特征数据和标签数据的维度。即该k个服务器的样本数据集可表示为:{Y1,X1}[n/k]×m,{Y2,X2}[n/k]×m,…,{Yk,Xk}[n/k]×m
S230,模型训练装置基于该k个服务器中经过第t次迭代计算后得到的k个模型参数更新该各个服务器中的模型参数,并基于该各个服务器的样本数据集对该各个服务器中更新后的模型参数进行第t+1次迭代计算,得到训练后的k个图像分类模型。
示例性地,模型训练装置基于该k个服务器中经过第t次迭代计算后得到的k个模型参数确定待迭代的模型参数,并利用待迭代的模型参数更新该各个服务器中的模型参数,然后基于该各个服务器的样本数据集对该各个服务器中更新后的模型参数进行第t+1次迭代计算,以得到该k个服务器中经过第t次迭代计算后得到的k个模型参数,以此类推,直至得到训练后的k个图像分类模型。
示例性地,该各个服务器中的模型参数的初始值相同。换言之,该各个服务器进行第1次迭代计算时,该各个服务器中待迭代的模型参数相同。
示例性地,该k个服务器中的模型参数的初始值分别表示为W1,W2,...,WK。其中,W1=W2=...=WK。
示例性地,在第t次迭代计算中,该各个服务器可以采用小批量梯度下降算法(MBGD)对各个服务器的模型参数进行并行独立的进行迭代计算,并得到该k个服务器中经过第t次迭代计算后得到的k个模型参数。
示例性地,该k个服务器中经过第t次迭代计算后得到的k个模型参数为:{Wi t|i=1,...,k}。
具体而言,模型训练装置可以先将含有n张图像的图像集,按照图像数据划分成k等分并得到k个图像集,每个图像集含有n/k张图像,并将k个图像集上传至k个服务器;然后,由各个服务器对各个服务器的图像集中的图像进行图像处理并得到各个服务器的样本数据集,该各个服务器的样本数据集可成标签与特征数据,k个数据集分别表示为:{Yi,Xi}[n/k]×m。其中,Yi表示第i个服务器的样本数据集中的标签数据,Xi表示第i个服务器的样本数据集中的特征数据,n/k表示第i个服务器的样本集中包括的图像的数量,m表示特征数据和标签数据的维度;接着,模型训练装置基于该k个服务器中经过第t次迭代计算后得到的k个模型参数确定待迭代的模型参数,并利用待迭代的模型参数更新该各个服务器中的模型参数,然后基于该各个服务器的样本数据集对该各个服务器中更新后的模型参数进行第t+1次迭代计算,以得到该k个服务器中经过第t次迭代计算后得到的k个模型参数,以此类推,直至得到训练后的k个图像分类模型。
本申请实施例中,利用k个服务器分别对该k个图像集进行图像处理以及利用各个服务器对各个服务器的样本数据集进行迭代计算的方式训练图像分类模型,相当于,通过各个服务器分摊模型训练所需的计算量,不仅能够合理利用各个服务器中的空闲资源,进而降低了该k个图像分类模型的训练成本,还能够降低每次迭代计算对计算资源和存储资源的要求以及降低每次迭代计算的计算量。此外,基于该k个服务器中经过第t次迭代计算后得到的k个模型参数更新该各个服务器中的模型参数,并基于该各个服务器的样本数据集对该各个服务器中更新后的模型参数进行第t+1次迭代计算,得到训练后的k个图像分类模型,相当于,在第t+1次迭代计算中,可以通过共享该k个服务器中经过第t次迭代计算后得到的k个模型参数,来实现对该k个图像分类模型的联合训练,能够保证该k个图像分类模型的训练效果。
简言之,本申请实施例提供的模型训练方法能够在降低图像分类模型的训练成本、每次迭代计算对计算资源和存储资源的要求以及每次迭代计算的计算量的基础上保证图像分类模型的训练效果。
在一些实施例中,该S230可包括:
模型训练装置可先利用该k个服务器中经过第t次迭代计算后得到的k个模型参数中的最小参数,更新该各个服务器中的模型参数;然后基于该各个服务器的样本数据集,对该各个服务器中的模型参数进行第t+1次迭代计算,直至该各个服务器中的模型参数的迭代次数等于预设迭代次数或该各个服务器中的模型参数的下降梯度为零时,得到该k个图像分类模型。
示例性地,模型训练装置可先利用更新该各个服务器中的模型参数;然后基于该各个服务器的样本数据集,对该各个服务器中的模型参数进行第t+1次迭代计算,直至该各个服务器中的模型参数的迭代次数等于预设迭代次数或该各个服务器中的模型参数的下降梯度为零时,得到该k个图像分类模型。
本实施例中,先利用该k个服务器中经过第t次迭代计算后得到的k个模型参数中的最小参数,更新该各个服务器中的模型参数;然后基于该各个服务器的样本数据集,对该各个服务器中的模型参数进行第t+1次迭代计算,不仅能够实现对k个图像分类模型的联合训练,还能够提升该k个图像分类模型的收敛速度以及训练速度。
在一些实施例中,将该各个服务器中的模型参数与该各个服务器的样本数据集中的特征数据相乘后减去该各个服务器的样本数据集中的标签数据,并将得到的结果与该各个服务器的样本数据集中的特征数据进行相乘,得到该各个服务器中的模型参数的梯度;将该各个服务器中的模型参数的学习率和该各个服务器中的模型参数的梯度进行相乘,得到该各个服务器中的模型参数的修正值;利用该各个服务器中的模型参数减去该各个服务器中的模型参数的修正值,得到该k个服务器中经过第t+1次迭代计算后得到的k个模型参数;若t+1等于该预设迭代次数或该各个服务器中的模型参数的梯度为零时,则得到该k个图像分类模型;若t+1小于该预设值或该各个服务器中的模型参数的梯度为零时,则基于该各个服务器的样本数据集和该k个服务器中经过第t+1次迭代计算后得到的k个模型参数,对该各个服务器中的模型参数进行第t+2次迭代计算,直至该各个服务器中的模型参数的迭代次数等于预设迭代次数或该各个服务器中的模型参数的下降梯度为零时,得到该k个图像分类模型。
示例性地,在第t+1次迭代计算中,可以利用该k个服务器中经过第t次迭代计算后得到的k个模型参数中的最小参数构建第t+1次迭代计算的迭代模型:
其中,X表示第i个服务器中的特征数据,Yi表示横向联邦学习中的第i个服务器(master worker)中的标签数据,η表示学习率,Wi t表示第i个服务器中经过第t次迭代后的模型参数,表示该k个服务器中经过第t次迭代计算后得到的k个模型参数中的最小参数。
换言之,模型训练装置可先利用更新该各个服务器中的模型参数;然后利用该各个服务器按照/>对该各个服务器中的模型参数进行迭代计算,直至该各个服务器中的模型参数的迭代次数等于预设迭代次数或该各个服务器中的模型参数的下降梯度为零时,得到该k个图像分类模型。值得注意的是,该各个服务器进行第1次迭代计算时,可利用该各个服务器中的模型参数的初始值的最小值作为待迭代参数进行迭代计算,即各个服务器可通过/>对该各个服务器中的模型参数进行第1次迭代计算,并得到该k个服务器中经过第1次迭代计算后得到的k个模型参数。
在一些实施例中,该S230可包括:
对该各个服务器的样本数据集进行抽样,得到该各个服务器的训练数据集和该各个服务器的测试数据集;利用该k个服务器中经过第t次迭代计算后得到的k个模型参数中的最小参数,更新该各个服务器中的模型参数;基于该各个服务器的训练数据集,对该各个服务器中的模型参数进行第t+1次迭代计算,得到该k个服务器中经过第t+1次迭代计算后得到的k个模型参数;基于该各个服务器的测试数据集,对该各个服务器中的模型参数的损失进行测试,得到用于表征该k个服务器中的模型参数的损失的目标损失值;确定该目标损失值是否满足测试要求;若该目标损失值满足该测试要求,则得到该k个图像分类模型;若该目标损失值不满足测试要求,则利用该k个服务器中经过第t+1次迭代计算后得到的k个模型参数中的最小参数更新该各个服务器中的模型参数,并基于该各个服务器的训练数据集,对该各个服务器中的模型参数进行第t+2次迭代计算,直至该目标损失值满足该测试要求时,得到该k个图像分类模型。
示例性地,模型训练装置可先利用更新该各个服务器中的模型参数;然后基于该各个服务器的训练数据集,对该各个服务器中的模型参数进行第t+1次迭代计算,以得到{Wi t+1|i=1,...,k};接着基于该各个服务器的测试数据集,对{Wi t+1|i=1,...,k}的损失进行测试,得到目标损失值,并确定该目标损失值是否满足测试要求;若该目标损失值满足该测试要求,则得到该k个图像分类模型;若该目标损失值不满足测试要求,则进行下一次迭代计算,直至该目标损失值满足该测试要求时,得到该k个图像分类模型。
本实施例中,通过对各个服务器的样本数据集进行划分得到各个服务器的训练数据集和各个服务器的测试数据集,一方面,通过各个服务器的训练数据集能够实现对各个服务器中的模型参数进行联合训练,另一方面,通过各个服务器的测试数据集能够实现对训练效果的联合测试,即能够使得测试过程和训练过程保持一致,进而,能够提升该目标损失值的准确度以及能够保证图像分类模型的训练效果。
值得注意的是,模型训练装置在基于该各个服务器的测试数据集,对该各个服务器中的模型参数的损失进行测试时,可以先利用该k个服务器中经过第t+1次迭代计算后得到的k个模型参数中的最小参数更新该各个服务器的模型参数。当然,在其他可替代实施例中,也可以直接对该k个服务器中经过第t+1次迭代计算后得到的k个模型参数进行损失测试,本申请对此不作具体限定。
在一些实施例中,模型训练装置基于该各个服务器的样本数据集的数据类型,按照预设比例通过对该各个服务器的样本数据集进行抽样,得到该各个服务器的训练数据集和该各个服务器的测试数据集。
示例性地,该各个服务器的样本数据集的数据类型包括适用于横向联邦学习模型的数据类型和适用于纵向联邦学习模型的数据类型。
本实施例中,模型训练装置基于该各个服务器的样本数据集的数据类型,对该各个服务器的样本数据集进行抽样,能够使得该各个服务器的训练数据集和该各个服务器的测试数据集的数据结构符合联邦学习模型要求的数据结构,进而能够提升对图像分类模型的训练效果。
在一些实施例中,若该各个服务器的样本数据集的数据类型为适用于横向联邦学习模型的数据类型,则按照该预设比例通过对该各个服务器的样本数据集中的特征数据和标签数据进行抽样,得到该各个服务器的训练数据集和该各个服务器的测试数据集;若该各个服务器的样本数据集的数据类型为适用于纵向联邦学习模型的数据类型,则按照该预设比例通过对该k个服务器中的第一服务器的样本数据集中的特征数据和标签数据进行抽样,并按照该预设比例通过对该k个服务器中除该第一服务器之外的服务器器的样本数据集中的特征数据进行抽样,以得到该各个服务器的训练数据集和该各个服务器的测试数据集。
示例性地,若该k个服务器的样本数据集为:{Y1,X1}[n/k]×m,{Y2,X2}[n/k]×m,…,{Yk,Xk}[n/k]×m,则可以利用各个服务器按照一定比例对各个服务器的样本数据集进行随机抽样,划分为训练数据集和测试数据集。
若该各个服务器的样本数据集的数据类型为适用于横向联邦学习模型的数据类型,则通过对第1个服务器的样本数据集{Y1,X1}[n/k]×m进行抽样得到训练数据集和测试数据集/>其中d1表示训练数据集的比例,d2表示测试数据集的比例。可选的,可通过设置d1>d2来保证第1个服务器的训练效果;类似的,可同样按照一定比例对第i个服务器(i=2,…,k)中的样本数据集进行抽样,得到该第i个服务器的训练数据集/>和第i个服务器的测试数据集若该各个服务器的样本数据集的数据类型为适用于纵向联邦学习模型的数据类型,则可以通过对第1个服务器的样本数据集{Y1,X1}[n/k]×m进行抽样得到训练数据集/>和测试数据集/>其中d1表示训练数据集的比例,d2表示测试数据集的比例。可选的,可通过设置d1>d2来保证第1个服务器的训练效果。类似的,可同样按照一定比例对第i个服务器(i=2,…,k)中的样本数据集进行抽样,得到该第i个服务器的训练数据集/>和第i个服务器的测试数据集/>
在一些实施例中,利用该各个服务器中的模型参数乘以该各个服务器的测试数据集中的特征数据,得到该各个服务器的中间预测结果;利用二分类函数对该各个服务器的中间预测结果进行计算,得到该各个服务器输出的图像分类结果;基于该各个服务器输出的图像分类结果和该各个服务器的测试数据集中的标签数据,确定该目标损失值。
示例性地,假设第1个服务器的训练数据集为第1个服务器的测试数据集为/>该第i个服务器(i=2,…,k)中的训练数据集为该第i个服务器(i=2,…,k)中的测试数据集为
在模型训练过程中,模型训练装置可利用第1个服务器,按照基于第1个服务器的训练数据集/>对第1个服务器的模型参数进行第t+1次迭代计算,并利用第i个服务器,按照基于第i个服务器的训练数据集/>对第i个服务器的模型参数进行第t+1次迭代计算,以得到{Wi t+1|i=1,...,k}。
在模型测试过程中,模型训练装置可利用第1个服务器,按照基于第1个服务器的训练数据集中的特征数据/>对第1个服务器的模型参数进行测试并得到该第1个服务器输出的图像分类结果;此外,该模型训练装置可利用第i个服务器,按照/>基于第i个服务器的训练数据集中的特征数据/>对第i个服务器的模型参数进行测试并得到该第i个服务器输出的图像分类结果。其中,sigmod()表示该二分类函数,/>表示该k个服务器中经过第t+1次迭代计算后得到的k个模型参数中的最小参数,/>表示第i个服务器的测试数据集。
在一些实施例中,利用该各个服务器的测试数据集中的标签数据减去该各个服务器输出的图像分类结果,得到该各个服务器中的模型参数的损失特征数据;对该各个服务器中的模型参数的损失特征数据进行转置,并利用转置后的数据与该各个服务器中的模型参数的损失特征数据进行相乘,得到该各个服务器中的模型参数的损失值;将该各个服务器中的模型参数的损失值的平均值,确定为该目标损失值。
示例性地,模型训练装置获取各个服务器输出的图像分类结果后,可按照计算该各个服务器中的模型参数的损失值。其中,/>表示第i个服务器输出的图像分类结果,Yi test表示第i个服务器的训练数据集中的标签数据。
当然,在其他可替代实施例中,模型训练装置也可以将该各个服务器中的模型参数的损失值的最小值或最大值,确定为该目标损失值,也可以对该各个服务器中的模型参数的损失值进行加权计算,以得到该目标损失值,本申请对此不作具体限定。
在一些实施例中,若该目标损失值小于或等于预设阈值,则确定该图像分类结果满足该测试要求;否则,确定该图像分类结果不满足该测试要求。
当然,模型训练装置也可以按照其他方式确定该图像分类结果是否满足该测试要求。例如,在其他可替代实施例中,若该目标损失值在预设范围内,则确定该图像分类结果满足该测试要求;否则,确定该图像分类结果不满足该测试要求。
在一些实施例中,该各个服务器中的模型参数的初始值相同,且该k个图像集中的各个图像集中包括n/k张图像。
本实施例中,将该各个服务器中的模型参数的初始值设计为相同,能够保证每一次迭代计算可以使用的迭代模型,进而能够降低模型训练的复杂度。此外,将该k个图像集中的各个图像集设计为包括n/k张图像,能够保证该各个服务器在每次迭代计算中各个服务器采用的样本数据集的数量相等,与该各个服务器的样本数据集不相等的情况相比,避免了该k个服务器中经过第t次迭代计算后得到的k个模型参数中的最小参数,为基于较少的样本数据集进行迭代计算得到的模型参数,相当于,避免了利用基于较少的样本数据集进行迭代计算得到的模型参数,更新该各个服务器中的模型参数,进而能够保证该k个图像分类模型进行联合训练时的训练效果。
图3是本申请实施例提供的模型训练方法300的示意性流程图。该方法300可以由模型训练装置执行,该模型训练装置可与k个服务器相连,该k个服务器可以用于为该模型训练装置提供计算服务,该k个服务器可以是独立的物理服务器,也可以是多个物理服务器构成的服务器集群或者分布式系统,还可以是云服务器,其中,该k个服务器可以通过有线或无线通信方式进行直接或间接地连接,本申请在此不做限制。
如图3所示,该模型训练方法300可包括:
S301,模型训练装置获取各个服务器的图像集。
示例性地,模型训练装置图像采集车或用户上传的n张图像,并将该n张图像等分划分为该k个图像集,即该k个图像集中的各个图像集中包括n/k张图像。
S302,该模型训练装置利用各个服务器获取各个服务器的样本数据集。
示例性地,利用k个服务器分别对该k个图像集进行图像处理,以得到该k个服务器中的各个服务器的样本数据集。例如,该模型训练装置可以将该k个数据集分别上传至该k个服务器,并利用该k个服务器中各个服务器对该各个服务器获取的图像集中的图像进行图像处理,以得到各个服务器的样本数据集。
S3031,服务器1将样本数据集1存储至服务器1。
示例性地,服务器1可将样本数据集1存储至服务器1的本地存储空间内。
S3041,服务器1初始化模型的参数为W1。
示例性地,W1可以是随机值,也可以是预定定义的值。可选的,预先定义的值也可以是由用户预先设置的值,或模型训练装置配置的默认值。本申请对此不作具体限定。
示例性地, 表示各个服务器中的模型参数的初始值的最小值。
S3051,服务器1利用梯度下降算法进行第1次迭代计算。
示例性地,服务器1进行第1次迭代计算时,可利用各个服务器中的模型参数的初始值的最小值作为待迭代参数进行迭代计算,即服务器1可基于样本数据集1,按照对该服务器1中的模型参数进行第1次迭代计算,并得到该服务器1中经过第1次迭代计算后得到的模型参数。
S3032,服务器2将样本数据集2存储至服务器1。
示例性地,服务器2可将样本数据集2存储至服务器2的本地存储空间内。
S3042,服务器2初始化模型的参数为W2。
示例性地,W2可以是随机值,也可以是预定定义的值。可选的,预先定义的值也可以是由用户预先设置的值,或模型训练装置配置的默认值。本申请对此不作具体限定。
示例性地, 表示各个服务器中的模型参数的初始值的最小值。
S3052,服务器2利用梯度下降算法进行第1次迭代计算。
示例性地,服务器2进行第1次迭代计算时,可利用各个服务器中的模型参数的初始值的最小值作为待迭代参数进行迭代计算,即服务器2可基于样本数据集2,按照对该服务器2中的模型参数进行第1次迭代计算,并得到该服务器2中经过第1次迭代计算后得到的模型参数。
S303k,服务器k将样本数据集k存储至服务器k。
示例性地,服务器k可将样本数据集k存储至服务器k的本地存储空间内。
S304k,服务器k初始化模型的参数为W1。
示例性地,Wk可以是随机值,也可以是预定定义的值。可选的,预先定义的值也可以是由用户预先设置的值,或模型训练装置配置的默认值。本申请对此不作具体限定。
示例性地, 表示各个服务器中的模型参数的初始值的最小值。
S305k,服务器k利用梯度下降算法进行第1次迭代计算。
示例性地,服务器k进行第1次迭代计算时,可利用各个服务器中的模型参数的初始值的最小值作为待迭代参数进行迭代计算,即服务器k可基于样本数据集k,按照对该服务器k中的模型参数进行第1次迭代计算,并得到该服务器k中经过第1次迭代计算后得到的模型参数。
S306,该模型训练装置选择各个服务器中经过第1次迭代计算后得到的模型参数中的最小值。
示例性地,假设各个服务器中经过第1次迭代计算后得到的模型参数可表示为{Wi 1|i=1,...,k},则该模型训练装置获取{Wi 1|i=1,...,k}后,可通过min{Wi 1|i=1,...,k}计算{Wi 1|i=1,...,k}的最小值
S3071,服务器1利用梯度下降算法进行第2次迭代计算。
示例性地,服务器1进行第2次迭代计算时,可利用各个服务器中经过第1次迭代计算后的模型参数的最小值(即)作为待迭代参数进行迭代计算,即服务器1可基于样本数据集1,按照/>对该服务器1中的模型参数进行第2次迭代计算,并得到该服务器2中经过第1次迭代计算后得到的模型参数。
S3072,服务器2利用梯度下降算法进行第2次迭代计算。
示例性地,服务器2进行第2次迭代计算时,可利用各个服务器中经过第1次迭代计算后的模型参数的最小值(即)作为待迭代参数进行迭代计算,即服务器2可基于样本数据集2,按照/>对该服务器2中的模型参数进行第2次迭代计算,并得到该服务器2中经过第2次迭代计算后得到的模型参数。
S307k,服务器k利用梯度下降算法进行第2次迭代计算。
示例性地,服务器k进行第2次迭代计算时,可利用各个服务器中经过第1次迭代计算后的模型参数的最小值(即)作为待迭代参数进行迭代计算,即服务器k可基于样本数据集k,按照/>对该服务器k中的模型参数进行第2次迭代计算,并得到该服务器k中经过第2次迭代计算后得到的模型参数。
S308,该模型训练装置选择各个服务器中经过第H-1次迭代计算后得到的模型参数中的最小值。
示例性地,假设各个服务器中经过第H-1次迭代计算后得到的模型参数可表示为{Wi H-1|i=1,...,k},则该模型训练装置获取{Wi H-1|i=1,...,k}后,可通过min{Wi H-1|i=1,...,k}计算{Wi H-1|i=1,...,k}的最小值
S3091,服务器1利用梯度下降算法进行第H次迭代计算。
示例性地,服务器1进行第H次迭代计算时,可利用各个服务器中经过第H-1次迭代计算后的模型参数的最小值(即)作为待迭代参数进行迭代计算,即可基于样本数据集1,按照/>对该服务器1中的模型参数进行第H次迭代计算,并得到该服务器1中经过第H次迭代计算后得到的模型参数。
S3092,服务器2利用梯度下降算法进行第H次迭代计算。
示例性地,服务器2进行第H次迭代计算时,可利用各个服务器中经过第H-1次迭代计算后的模型参数的最小值(即)作为待迭代参数进行迭代计算,即可基于样本数据集2,按照/>对该服务器2中的模型参数进行第H次迭代计算,并得到该服务器2中经过第H次迭代计算后得到的模型参数。
S309k,服务器k利用梯度下降算法进行第H次迭代计算。
示例性地,服务器k进行第H次迭代计算时,可利用各个服务器中经过第H-1次迭代计算后的模型参数的最小值(即)作为待迭代参数进行迭代计算,即可基于样本数据集k,按照/>对该服务器k中的模型参数进行第H次迭代计算,并得到该服务器k中经过第H次迭代计算后得到的模型参数。
S310,该模型训练装置选择各个服务器中经过第H次迭代计算后得到的模型参数中的最小值。
示例性地,假设H等于预设的迭代次数,各个服务器中经过第H次迭代计算后得到的模型参数可表示为{Wi H|i=1,...,k},则该模型训练装置获取{Wi H|i=1,...,k}后,可通过min{Wi H|i=1,...,k}计算{Wi H|i=1,...,k}的最小值进一步的,该模型训练装置可利用/>更新各个服务器中的模型参数,以得到各个服务器中训练后的图像分类模型。
本实施例中,利用k个服务器分别对该k个图像集进行图像处理以及利用各个服务器对各个服务器的样本数据集进行迭代计算的方式训练图像分类模型,相当于,通过各个服务器分摊模型训练所需的计算量,不仅能够合理利用各个服务器中的空闲资源,进而降低了该k个图像分类模型的训练成本,还能够降低每次迭代计算对计算资源和存储资源的要求以及降低每次迭代计算的计算量。此外,基于该k个服务器中经过第t次迭代计算后得到的k个模型参数更新该各个服务器中的模型参数,并基于该各个服务器的样本数据集对该各个服务器中更新后的模型参数进行第t+1次迭代计算,得到训练后的k个图像分类模型,相当于,在第t+1次迭代计算中,可以通过共享该k个服务器中经过第t次迭代计算后得到的k个模型参数,来实现对该k个图像分类模型的联合训练,能够保证该k个图像分类模型的训练效果。简言之,本申请实施例提供的模型训练方法能够在降低图像分类模型的训练成本、每次迭代计算对计算资源和存储资源的要求以及每次迭代计算的计算量的基础上保证图像分类模型的训练效果。
图4是本申请实施例提供的模型训练方法400涉及的各个阶段的示意图。该方法400可以由模型训练装置执行,该模型训练装置可与k个服务器相连,该k个服务器可以用于为该模型训练装置提供计算服务,该k个服务器可以是独立的物理服务器,也可以是多个物理服务器构成的服务器集群或者分布式系统,还可以是云服务器,其中,该k个服务器可以通过有线或无线通信方式进行直接或间接地连接,本申请在此不做限制。
如图4所示,本申请实施例提供的模型训练方法400涉及的以下七个阶段:
410、图像获取及图像集划分阶段。
在图像获取及图像集划分阶段,模型训练装置图像采集车或用户上传的n张图像,并将该n张图像等分划分为该k个图像集,即该k个图像集中的各个图像集中包括n/k张图像。
420、图像处理阶段。
在图像处理阶段,模型训练装置利用k个服务器分别对该k个图像集进行图像处理,以得到该k个服务器中的各个服务器的样本数据集。示例性地,该模型训练装置可以将该k个数据集分别上传至该k个服务器,并利用该k个服务器中各个服务器对该各个服务器获取的图像集中的图像进行图像处理,以得到各个服务器的样本数据集。
430、训练数据集和测试数据集的构建阶段。
在训练数据集和测试数据集的构建阶段,模型训练装置基于该各个服务器的样本数据集的数据类型,按照预设比例通过对该各个服务器的样本数据集进行抽样,得到该各个服务器的训练数据集和该各个服务器的测试数据集。其中,该各个服务器的样本数据集的数据类型包括适用于横向联邦学习模型的数据类型和适用于纵向联邦学习模型的数据类型。
示例性地,假设该k个服务器的样本数据集为:{Y1,X1}[n/k]×m,{Y2,X2}[n/k]×m,…,{Yk,Xk}[n/k]×m,则该模型训练装置可以利用各个服务器按照一定比例对各个服务器的样本数据集进行随机抽样,划分为训练数据集和测试数据集。
若该各个服务器的样本数据集的数据类型为适用于横向联邦学习模型的数据类型,则通过对第1个服务器的样本数据集{Y1,X1}[n/k]×m进行抽样得到训练数据集和测试数据集/>其中d1表示训练数据集的比例,d2表示测试数据集的比例。可选的,可通过设置d1>d2来保证第1个服务器的训练效果;类似的,可同样按照一定比例对第i个服务器(i=2,…,k)中的样本数据集进行抽样,得到该第i个服务器的训练数据集/>和第i个服务器的测试数据集若该各个服务器的样本数据集的数据类型为适用于纵向联邦学习模型的数据类型,则可以通过对第1个服务器的样本数据集{Y1,X1}[n/k]×m进行抽样得到训练数据集/>和测试数据集/>其中d1表示训练数据集的比例,d2表示测试数据集的比例。可选的,可通过设置d1>d2来保证第1个服务器的训练效果。类似的,可同样按照一定比例对第i个服务器(i=2,…,k)中的样本数据集进行抽样,得到该第i个服务器的训练数据集/>和第i个服务器的测试数据集/>
440、梯度下降算法的构建阶段。
示例性地,在第t+1次迭代计算中,可以利用该k个服务器中经过第t次迭代计算后得到的k个模型参数中的最小参数构建第t+1次迭代计算的迭代模型:
其中,X表示第i个服务器中的特征数据,Yi表示横向联邦学习中的第i个服务器(master worker)中的标签数据,η表示学习率,Wi t表示第i个服务器中经过第t次迭代后的模型参数,表示该k个服务器中经过第t次迭代计算后得到的k个模型参数中的最小参数。
换言之,模型训练装置可先利用更新该各个服务器中的模型参数;然后利用该各个服务器按照/>对该各个服务器中的模型参数进行迭代计算,直至该各个服务器中的模型参数的迭代次数等于预设迭代次数或该各个服务器中的模型参数的下降梯度为零时,得到该k个图像分类模型。值得注意的是,该各个服务器进行第1次迭代计算时,可利用该各个服务器中的模型参数的初始值的最小值作为待迭代参数进行迭代计算,即各个服务器可通过/>对该各个服务器中的模型参数进行第1次迭代计算,并得到该k个服务器中经过第1次迭代计算后得到的k个模型参数。
450、模型训练阶段。
示例性地,假设第1个服务器的训练数据集为第1个服务器的测试数据集为/>该第i个服务器(i=2,…,k)中的训练数据集为该第i个服务器(i=2,…,k)中的测试数据集为
在模型训练过程中,模型训练装置可利用第1个服务器,按照基于第1个服务器的训练数据集/>对第1个服务器的模型参数进行第t+1次迭代计算,并利用第i个服务器,按照基于第i个服务器的训练数据集/>对第i个服务器的模型参数进行第t+1次迭代计算,以得到{Wi t+1|i=1,...,k}。
460、模型测试阶段。
示例性地,假设第1个服务器的训练数据集为第1个服务器的测试数据集为/>该第i个服务器(i=2,…,k)中的训练数据集为该第i个服务器(i=2,…,k)中的测试数据集为
在模型测试过程中,模型训练装置可利用第1个服务器,按照基于第1个服务器的训练数据集中的特征数据/>对第1个服务器的模型参数进行测试并得到该第1个服务器输出的图像分类结果;此外,该模型训练装置可利用第i个服务器,按照/>基于第i个服务器的训练数据集中的特征数据/>对第i个服务器的模型参数进行测试并得到该第i个服务器输出的图像分类结果。其中,sigmod()表示该二分类函数,/>表示该k个服务器中经过第t+1次迭代计算后得到的k个模型参数中的最小参数,/>表示第i个服务器的测试数据集。进一步的,模型训练装置获取各个服务器输出的图像分类结果后,可按照计算该各个服务器中的模型参数的损失值。其中,/>表示第i个服务器输出的图像分类结果,Yi test表示第i个服务器的训练数据集中的标签数据。
470、预测阶段。
示例性地,该模型训练装置获取目标图像后,利用k个服务器中的目标服务器对该目标图像进行图像处理,以得到该目标图像的特征数据;并基于该目标图像的特征数据,利用该目标服务器中的图像分类模型,对该目标图像进行分类预测,得到该目标图像的图像分类结果。
本实施例中,利用k个服务器分别对该k个图像集进行图像处理以及利用各个服务器对各个服务器的样本数据集进行迭代计算的方式训练图像分类模型,相当于,通过各个服务器分摊模型训练所需的计算量,不仅能够合理利用各个服务器中的空闲资源,进而降低了该k个图像分类模型的训练成本,还能够降低每次迭代计算对计算资源和存储资源的要求以及降低每次迭代计算的计算量。此外,基于该k个服务器中经过第t次迭代计算后得到的k个模型参数更新该各个服务器中的模型参数,并基于该各个服务器的样本数据集对该各个服务器中更新后的模型参数进行第t+1次迭代计算,得到训练后的k个图像分类模型,相当于,在第t+1次迭代计算中,可以通过共享该k个服务器中经过第t次迭代计算后得到的k个模型参数,来实现对该k个图像分类模型的联合训练,能够保证该k个图像分类模型的训练效果。简言之,本申请实施例提供的模型训练方法能够在降低图像分类模型的训练成本、每次迭代计算对计算资源和存储资源的要求以及每次迭代计算的计算量的基础上保证图像分类模型的训练效果。
另外,通过对各个服务器的样本数据集进行划分得到各个服务器的训练数据集和各个服务器的测试数据集,一方面,通过各个服务器的训练数据集能够实现对各个服务器中的模型参数进行联合训练,另一方面,通过各个服务器的测试数据集能够实现对训练效果的联合测试,即能够使得测试过程和训练过程保持一致,进而,能够提升该目标损失值的准确度以及能够保证图像分类模型的训练效果。此外,模型训练装置基于该各个服务器的样本数据集的数据类型,对该各个服务器的样本数据集进行抽样,能够使得该各个服务器的训练数据集和该各个服务器的测试数据集的数据结构符合联邦学习模型要求的数据结构,进而能够提升对图像分类模型的训练效果。
图5是本申请实施例提供的分类预测方法500的示意性流程图。该方法300可以由分类预测装置执行,该分类预测装置可与k个服务器相连,该k个服务器可以用于为该模型训练装置提供计算服务,该k个服务器可以是独立的物理服务器,也可以是多个物理服务器构成的服务器集群或者分布式系统,还可以是云服务器,其中,该k个服务器可以通过有线或无线通信方式进行直接或间接地连接,本申请在此不做限制。当然,在一些实施例中,上文涉及的模型训练装置在预测阶段,也可作为分量预测装置,本申请对此不作具体限定。
如图5所示,该分类预测方法500可包括:
S510,分类预测装置获取目标图像。
S520,分类预测装置利用k个服务器中的目标服务器对该目标图像进行图像处理,以得到该目标图像的特征数据;其中,该k个服务器配置有k个图像分类模型;应当理解,该k个图像分类模型可以按照上文涉及的模型训练方法得到的图像分类模型,为避免重复,此处不再赘述。
S530,分类预测装置基于该目标图像的特征数据,利用该目标服务器中的图像分类模型,对该目标图像进行分类预测,得到该目标图像的图像分类结果。
示例性地,该分类预测装置可基于目标图像的数量,确定目标服务器的数量。例如,若目标图像包括大于k的j张图像时,可以确定目标服务器为该k个服务器。此时,该分类预测装置可先将j张图像划分为k个图像集,并利用该k个服务器对该k个图像集中的图像进行预测,以提升该j张图像的预测效率。
以上结合附图详细描述了本申请的优选实施方式,但是,本申请并不限于上文涉及的实施方式中的具体细节,在本申请的技术构思范围内,可以对本申请的技术方案进行多种简单变型,这些简单变型均属于本申请的保护范围。例如,在上文涉及的具体实施方式中所描述的各个具体技术特征,在不矛盾的情况下,可以通过任何合适的方式进行组合,为了避免不必要的重复,本申请对各种可能的组合方式不再另行说明。又例如,本申请的各种不同的实施方式之间也可以进行任意组合,只要其不违背本申请的思想,其同样应当视为本申请所公开的内容。
还应理解,在本申请的各种方法实施例中,上文涉及的各过程的序号的大小并不意味着执行顺序的先后,各过程的执行顺序应以其功能和内在逻辑确定,而不应对本申请实施例的实施过程构成任何限定。
上文对本申请实施例提供的方法进行了说明,下面对本申请实施例提供的装置进行说明。
图6是本申请实施例提供的模型训练装置600的示意性框图。
如图6所示,该模型训练装置600可包括:
获取单元610,用于获取n张图像,并对该n张图像进行划分得到k个图像集;n>k>1;
处理单元620,用于利用k个服务器分别对该k个图像集进行图像处理,以得到该k个服务器中的各个服务器的样本数据集;
训练单元630,用于基于该k个服务器中经过第t次迭代计算后得到的k个模型参数更新该各个服务器中的模型参数,并基于该各个服务器的样本数据集对该各个服务器中更新后的模型参数进行第t+1次迭代计算,得到训练后的k个图像分类模型。
在一些实施例中,该训练单元630具体用于:
利用该k个服务器中经过第t次迭代计算后得到的k个模型参数中的最小参数,更新该各个服务器中的模型参数;
基于该各个服务器的样本数据集,对该各个服务器中的模型参数进行第t+1次迭代计算,直至该各个服务器中的模型参数的迭代次数等于预设迭代次数或该各个服务器中的模型参数的下降梯度为零时,得到该k个图像分类模型。
在一些实施例中,该训练单元630具体用于:
将该各个服务器中的模型参数与该各个服务器的样本数据集中的特征数据相乘后减去该各个服务器的样本数据集中的标签数据,并将得到的结果与该各个服务器的样本数据集中的特征数据进行相乘,得到该各个服务器中的模型参数的梯度;
将该各个服务器中的模型参数的学习率和该各个服务器中的模型参数的梯度进行相乘,得到该各个服务器中的模型参数的修正值;
利用该各个服务器中的模型参数减去该各个服务器中的模型参数的修正值,得到该k个服务器中经过第t+1次迭代计算后得到的k个模型参数;
若t+1等于该预设迭代次数或该各个服务器中的模型参数的梯度为零时,则得到该k个图像分类模型;
若t+1小于该预设值或该各个服务器中的模型参数的梯度为零时,则基于该各个服务器的样本数据集和该k个服务器中经过第t+1次迭代计算后得到的k个模型参数,对该各个服务器中的模型参数进行第t+2次迭代计算,直至该各个服务器中的模型参数的迭代次数等于预设迭代次数或该各个服务器中的模型参数的下降梯度为零时,得到该k个图像分类模型。
在一些实施例中,该训练单元630具体用于:
对该各个服务器的样本数据集进行抽样,得到该各个服务器的训练数据集和该各个服务器的测试数据集;
利用该k个服务器中经过第t次迭代计算后得到的k个模型参数中的最小参数,更新该各个服务器中的模型参数;
基于该各个服务器的训练数据集,对该各个服务器中的模型参数进行第t+1次迭代计算,得到该k个服务器中经过第t+1次迭代计算后得到的k个模型参数;
基于该各个服务器的测试数据集,对该各个服务器中的模型参数的损失进行测试,得到用于表征该k个服务器中的模型参数的损失的目标损失值;
确定该目标损失值是否满足测试要求;
若该目标损失值满足该测试要求,则得到该k个图像分类模型;
若该目标损失值不满足测试要求,则利用该k个服务器中经过第t+1次迭代计算后得到的k个模型参数中的最小参数更新该各个服务器中的模型参数,并基于该各个服务器的训练数据集,对该各个服务器中的模型参数进行第t+2次迭代计算,直至该目标损失值满足该测试要求时,得到该k个图像分类模型。
在一些实施例中,该训练单元630具体用于:
基于该各个服务器的样本数据集的数据类型,按照预设比例通过对该各个服务器的样本数据集进行抽样,得到该各个服务器的训练数据集和该各个服务器的测试数据集。
在一些实施例中,该训练单元630具体用于:
若该各个服务器的样本数据集的数据类型为适用于横向联邦学习模型的数据类型,则按照该预设比例通过对该各个服务器的样本数据集中的特征数据和标签数据进行抽样,得到该各个服务器的训练数据集和该各个服务器的测试数据集;
若该各个服务器的样本数据集的数据类型为适用于纵向联邦学习模型的数据类型,则按照该预设比例通过对该k个服务器中的第一服务器的样本数据集中的特征数据和标签数据进行抽样,并按照该预设比例通过对该k个服务器中除该第一服务器之外的服务器器的样本数据集中的特征数据进行抽样,以得到该各个服务器的训练数据集和该各个服务器的测试数据集。
在一些实施例中,该训练单元630具体用于:
利用该各个服务器中的模型参数乘以该各个服务器的测试数据集中的特征数据,得到该各个服务器的中间预测结果;
利用二分类函数对该各个服务器的中间预测结果进行计算,得到该各个服务器输出的图像分类结果;
基于该各个服务器输出的图像分类结果和该各个服务器的测试数据集中的标签数据,确定该目标损失值。
在一些实施例中,该训练单元630具体用于:
利用该各个服务器的测试数据集中的标签数据减去该各个服务器输出的图像分类结果,得到该各个服务器中的模型参数的损失特征数据;
对该各个服务器中的模型参数的损失特征数据进行转置,并利用转置后的数据与该各个服务器中的模型参数的损失特征数据进行相乘,得到该各个服务器中的模型参数的损失值;
将该各个服务器中的模型参数的损失值的平均值,确定为该目标损失值。
在一些实施例中,该训练单元630具体用于:
若该目标损失值小于或等于预设阈值,则确定该图像分类结果满足该测试要求;
否则,确定该图像分类结果不满足该测试要求。
在一些实施例中,该各个服务器中的模型参数的初始值相同,且该k个图像集中的各个图像集中包括n/k张图像。
应理解,装置实施例与方法实施例可以相互对应,类似的描述可以参照方法实施例。为避免重复,此处不再赘述。具体地,模型训练装置600可以对应于执行本申请实施例的方法200~400中的相应主体,并且模型训练装置600中的各个单元分别为了实现方法200~400中的相应流程,为了简洁,在此不再赘述。
图7是本申请实施例提供的分类预测装置700的示意性框图。
如图7所示,该分类预测装置700可包括:
获取单元710,用于获取目标图像;
处理单元720,用于利用k个服务器中的目标服务器对该目标图像进行图像处理,以得到该目标图像的特征数据;其中,该k个服务器配置有k个图像分类模型;
预测单元720,用于基于该目标图像的特征数据,利用该目标服务器中的图像分类模型,对该目标图像进行分类预测,得到该目标图像的图像分类结果。
应理解,装置实施例与方法实施例可以相互对应,类似的描述可以参照方法实施例。为避免重复,此处不再赘述。具体地,该分类预测装置700可以对应于执行本申请实施例的方法500中的相应主体,并且该分类预测装置700中的各个单元分别为了实现方法500中的相应流程,为了简洁,在此不再赘述。
还应当理解,本申请实施例涉及的模型训练装置600或分类预测装置700中的各个单元可以分别或全部合并为一个或若干个另外的单元来构成,或者其中的某个(些)单元还可以再拆分为功能上更小的多个单元来构成,这可以实现同样的操作,而不影响本申请的实施例的技术效果的实现。上文涉及的单元是基于逻辑功能划分的,在实际应用中,一个单元的功能也可以由多个单元来实现,或者多个单元的功能由一个单元实现。在本申请的其它实施例中,该模型训练装置600或分类预测装置700也可以包括其它单元,在实际应用中,这些功能也可以由其它单元协助实现,并且可以由多个单元协作实现。根据本申请的另一个实施例,可以通过在包括例如中央处理单元(CPU)、随机存取存储介质(RAM)、只读存储介质(ROM)等处理元件和存储元件的通用计算机的通用计算设备上运行能够执行相应方法所涉及的各步骤的计算机程序(包括程序代码),来构造本申请实施例涉及的模型训练装置600或分类预测装置700,以及来实现本申请实施例提供的方法。计算机程序可以记载于例如计算机可读存储介质上,并通过计算机可读存储介质装载于电子设备中,并在其中运行,来实现本申请实施例的相应方法。
换言之,上文涉及的单元可以通过硬件形式实现,也可以通过软件形式的指令实现,还可以通过软硬件结合的形式实现。具体地,本申请实施例中的方法实施例的各步骤可以通过处理器中的硬件的集成逻辑电路和/或软件形式的指令完成,结合本申请实施例公开的方法的步骤可以直接体现为硬件译码处理器执行完成,或者用译码处理器中的硬件及软件组合执行完成。可选地,软件可以位于随机存储器,闪存、只读存储器、可编程只读存储器、电可擦写可编程存储器、寄存器等本领域的成熟的存储介质中。该存储介质位于存储器,处理器读取存储器中的信息,结合其硬件完成上文涉及的方法实施例中的步骤。
图8是本申请实施例提供的电子设备800的示意结构图。
如图8所示,该电子设备800至少包括处理器810以及计算机可读存储介质820。其中,处理器810以及计算机可读存储介质820可通过总线或者其它方式连接。计算机可读存储介质820用于存储计算机程序821,计算机程序821包括计算机指令,处理器810用于执行计算机可读存储介质820存储的计算机指令。处理器810是电子设备800的计算核心以及控制核心,其适于实现一条或多条计算机指令,具体适于加载并执行一条或多条计算机指令从而实现相应方法流程或相应功能。
作为示例,处理器810也可称为中央处理器(Central Processing Unit,CPU)。处理器810可以包括但不限于:通用处理器、数字信号处理器(Digital Signal Processor,DSP)、专用集成电路(Application Specific Integrated Circuit,ASIC)、现场可编程门阵列(Field Programmable Gate Array,FPGA)或者其他可编程逻辑器件、分立元件门或者晶体管逻辑器件、分立硬件组件等等。
作为示例,计算机可读存储介质820可以是高速RAM存储器,也可以是非不稳定的存储器(Non-VolatileMemory),例如至少一个磁盘存储器;可选的,还可以是至少一个位于远离前述处理器810的计算机可读存储介质。具体而言,计算机可读存储介质820包括但不限于:易失性存储器和/或非易失性存储器。其中,非易失性存储器可以是只读存储器(Read-Only Memory,ROM)、可编程只读存储器(Programmable ROM,PROM)、可擦除可编程只读存储器(Erasable PROM,EPROM)、电可擦除可编程只读存储器(Electrically EPROM,EEPROM)或闪存。易失性存储器可以是随机存取存储器(Random Access Memory,RAM),其用作外部高速缓存。通过示例性但不是限制性说明,许多形式的RAM可用,例如静态随机存取存储器(Static RAM,SRAM)、动态随机存取存储器(Dynamic RAM,DRAM)、同步动态随机存取存储器(Synchronous DRAM,SDRAM)、双倍数据速率同步动态随机存取存储器(Double DataRate SDRAM,DDR SDRAM)、增强型同步动态随机存取存储器(Enhanced SDRAM,ESDRAM)、同步连接动态随机存取存储器(synch link DRAM,SLDRAM)和直接内存总线随机存取存储器(Direct Rambus RAM,DR RAM)。
如图8所示,该电子设备800还可以包括收发器830。
其中,处理器810可以控制该收发器830与其他设备进行通信,具体地,可以向其他设备发送信息或数据,或接收其他设备发送的信息或数据。收发器830可以包括发射机和接收机。收发器830还可以进一步包括天线,天线的数量可以为一个或多个。
应当理解,该通信设备800中的各个组件通过总线系统相连,其中,总线系统除包括数据总线之外,还包括电源总线、控制总线和状态信号总线。值得注意的是,该电子设备800可以是任一具有数据处理能力的电子设备;该计算机可读存储介质820中存储有第一计算机指令;由处理器810加载并执行计算机可读存储介质820中存放的第一计算机指令,以实现图1所示方法实施例中的相应步骤;具体实现中,计算机可读存储介质820中的第一计算机指令由处理器810加载并执行相应步骤,为避免重复,此处不再赘述。
根据本申请的另一方面,本申请实施例还提供了一种计算机可读存储介质(Memory),计算机可读存储介质是电子设备800中的记忆设备,用于存放程序和数据。例如,计算机可读存储介质820。可以理解的是,此处的计算机可读存储介质820既可以包括电子设备800中的内置存储介质,当然也可以包括电子设备800所支持的扩展存储介质。计算机可读存储介质提供存储空间,该存储空间存储了电子设备800的操作系统。并且,在该存储空间中还存放了适于被处理器810加载并执行的一条或多条的计算机指令,这些计算机指令可以是一个或多个的计算机程序821(包括程序代码)。
根据本申请的另一方面,本申请实施例还提供了一种计算机程序产品或计算机程序,该计算机程序产品或计算机程序包括计算机指令,该计算机指令存储在计算机可读存储介质中。例如,计算机程序821。此时,数据处理设备800可以是计算机,处理器810从计算机可读存储介质820读取该计算机指令,处理器810执行该计算机指令,使得该计算机执行上文涉及的各种可选方式中提供的各种方法。换言之,当使用软件实现时,可以全部或部分地以计算机程序产品的形式实现。该计算机程序产品包括一个或多个计算机指令。在计算机上加载和执行该计算机程序指令时,全部或部分地运行本申请实施例的流程或实现本申请实施例的功能。该计算机可以是通用计算机、专用计算机、计算机网络、或者其他可编程装置。该计算机指令可以存储在计算机可读存储介质中,或者从一个计算机可读存储介质向另一个计算机可读存储介质进行传输,例如,该计算机指令可以从一个网站站点、计算机、服务器或数据中心通过有线(例如同轴电缆、光纤、数字用户线(digital subscriberline,DSL))或无线(例如红外、无线、微波等)方式向另一个网站站点、计算机、服务器或数据中心进行传输。
本领域普通技术人员可以意识到,结合本文中所公开的实施例描述的各示例的单元以及流程步骤,能够以电子硬件、或者计算机软件和电子硬件的结合来实现。这些功能究竟以硬件还是软件方式来执行,取决于技术方案的特定应用和设计约束条件。专业技术人员可以对每个特定的应用来使用不同方法来实现所描述的功能,但是这种实现不应认为超出本申请的范围。
最后需要说明的是,以上内容,仅为本申请的具体实施方式,但本申请的保护范围并不局限于此,任何熟悉本技术领域的技术人员在本申请揭露的技术范围内,可轻易想到变化或替换,都应涵盖在本申请的保护范围之内。因此,本申请的保护范围应以权利要求的保护范围为准。
Claims (15)
1.一种模型训练方法,其特征在于,包括:
获取n张图像,并对所述n张图像进行划分得到k个图像集;n>k>1;
利用k个服务器分别对所述k个图像集进行图像处理,以得到所述k个服务器中的各个服务器的样本数据集;
基于所述k个服务器中经过第t次迭代计算后得到的k个模型参数更新所述各个服务器中的模型参数,并基于所述各个服务器的样本数据集对所述各个服务器中更新后的模型参数进行第t+1次迭代计算,得到训练后的k个图像分类模型。
2.根据权利要求1所述的方法,其特征在于,所述基于所述k个服务器中经过第t次迭代计算后得到的k个模型参数更新所述各个服务器中的模型参数,并基于所述各个服务器的样本数据集对所述各个服务器中更新后的模型参数进行第t+1次迭代计算,得到训练后的k个图像分类模型,包括:
利用所述k个服务器中经过第t次迭代计算后得到的k个模型参数中的最小参数,更新所述各个服务器中的模型参数;
基于所述各个服务器的样本数据集,对所述各个服务器中的模型参数进行第t+1次迭代计算,直至所述各个服务器中的模型参数的迭代次数等于预设迭代次数或所述各个服务器中的模型参数的下降梯度为零时,得到所述k个图像分类模型。
3.根据权利要求2所述的方法,其特征在于,所述基于所述各个服务器的样本数据集,对所述各个服务器中的模型参数进行第t+1次迭代计算,直至所述各个服务器中的模型参数的迭代次数等于预设迭代次数或所述各个服务器中的模型参数的下降梯度为零时,得到所述k个图像分类模型,包括:
将所述各个服务器中的模型参数与所述各个服务器的样本数据集中的特征数据相乘后减去所述各个服务器的样本数据集中的标签数据,并将得到的结果与所述各个服务器的样本数据集中的特征数据进行相乘,得到所述各个服务器中的模型参数的梯度;
将所述各个服务器中的模型参数的学习率和所述各个服务器中的模型参数的梯度进行相乘,得到所述各个服务器中的模型参数的修正值;
利用所述各个服务器中的模型参数减去所述各个服务器中的模型参数的修正值,得到所述k个服务器中经过第t+1次迭代计算后得到的k个模型参数;
若t+1等于所述预设迭代次数或所述各个服务器中的模型参数的梯度为零时,则得到所述k个图像分类模型;
若t+1小于所述预设值或所述各个服务器中的模型参数的梯度为零时,则基于所述各个服务器的样本数据集和所述k个服务器中经过第t+1次迭代计算后得到的k个模型参数,对所述各个服务器中的模型参数进行第t+2次迭代计算,直至所述各个服务器中的模型参数的迭代次数等于预设迭代次数或所述各个服务器中的模型参数的下降梯度为零时,得到所述k个图像分类模型。
4.根据权利要求1所述的方法,其特征在于,所述基于所述k个服务器中经过第t次迭代计算后得到的k个模型参数更新所述各个服务器中的模型参数,并基于所述各个服务器的样本数据集对所述各个服务器中更新后的模型参数进行第t+1次迭代计算,得到训练后的k个图像分类模型,包括:
对所述各个服务器的样本数据集进行抽样,得到所述各个服务器的训练数据集和所述各个服务器的测试数据集;
利用所述k个服务器中经过第t次迭代计算后得到的k个模型参数中的最小参数,更新所述各个服务器中的模型参数;
基于所述各个服务器的训练数据集,对所述各个服务器中的模型参数进行第t+1次迭代计算,得到所述k个服务器中经过第t+1次迭代计算后得到的k个模型参数;
基于所述各个服务器的测试数据集,对所述各个服务器中的模型参数的损失进行测试,得到用于表征所述k个服务器中的模型参数的损失的目标损失值;
确定所述目标损失值是否满足测试要求;
若所述目标损失值满足所述测试要求,则得到所述k个图像分类模型;
若所述目标损失值不满足测试要求,则利用所述k个服务器中经过第t+1次迭代计算后得到的k个模型参数中的最小参数更新所述各个服务器中的模型参数,并基于所述各个服务器的训练数据集,对所述各个服务器中的模型参数进行第t+2次迭代计算,直至所述目标损失值满足所述测试要求时,得到所述k个图像分类模型。
5.根据权利要求4所述的方法,其特征在于,所述对所述各个服务器的样本数据集进行抽样,得到所述各个服务器的训练数据集和所述各个服务器的测试数据集,包括:
基于所述各个服务器的样本数据集的数据类型,按照预设比例通过对所述各个服务器的样本数据集进行抽样,得到所述各个服务器的训练数据集和所述各个服务器的测试数据集。
6.根据权利要求5所述的方法,其特征在于,所述基于所述各个服务器的样本数据集的数据类型,按照预设比例通过对所述各个服务器的样本数据集进行抽样,得到所述各个服务器的训练数据集和所述各个服务器的测试数据集,包括:
若所述各个服务器的样本数据集的数据类型为适用于横向联邦学习模型的数据类型,则按照所述预设比例通过对所述各个服务器的样本数据集中的特征数据和标签数据进行抽样,得到所述各个服务器的训练数据集和所述各个服务器的测试数据集;
若所述各个服务器的样本数据集的数据类型为适用于纵向联邦学习模型的数据类型,则按照所述预设比例通过对所述k个服务器中的第一服务器的样本数据集中的特征数据和标签数据进行抽样,并所述预设按照比例通过对所述k个服务器中除所述第一服务器之外的服务器器的样本数据集中的特征数据进行抽样,以得到所述各个服务器的训练数据集和所述各个服务器的测试数据集。
7.根据权利要求4所述的方法,其特征在于,所述基于所述各个服务器的测试数据集,对所述各个服务器中的模型参数的损失进行测试,得到用于表征所述k个服务器中的模型参数的损失的目标损失值,包括:
利用所述各个服务器中的模型参数乘以所述各个服务器的测试数据集中的特征数据,得到所述各个服务器的中间预测结果;
利用二分类函数对所述各个服务器的中间预测结果进行计算,得到所述各个服务器输出的图像分类结果;
基于所述各个服务器输出的图像分类结果和所述各个服务器的测试数据集中的标签数据,确定所述目标损失值。
8.根据权利要求7所述的方法,其特征在于,所述基于所述各个服务器输出的图像分类结果和所述各个服务器的测试数据集中的标签数据,确定所述目标损失值,包括:
利用所述各个服务器的测试数据集中的标签数据减去所述各个服务器输出的图像分类结果,得到所述各个服务器中的模型参数的损失特征数据;
对所述各个服务器中的模型参数的损失特征数据进行转置,并利用转置后的数据与所述各个服务器中的模型参数的损失特征数据进行相乘,得到所述各个服务器中的模型参数的损失值;
将所述各个服务器中的模型参数的损失值的平均值,确定为所述目标损失值。
9.根据权利要求1至8中任一项所述的方法,其特征在于,所述各个服务器中的模型参数的初始值相同,且所述k个图像集中的各个图像集中包括n/k张图像。
10.一种分类预测方法,其特征在于,包括:
获取目标图像;
利用k个服务器中的目标服务器对所述目标图像进行图像处理,以得到所述目标图像的特征数据;其中,所述k个服务器配置有k个图像分类模型,所述k个图像分类模型为按照权利要求1至9中任一项所述的方法训练的图像分类模型;
基于所述目标图像的特征数据,利用所述目标服务器中的图像分类模型,对所述目标图像进行分类预测,得到所述目标图像的图像分类结果。
11.一种模型训练装置,其特征在于,包括:
获取单元,用于获取n张图像,并对所述n张图像进行划分得到k个图像集;n>k>1;
处理单元,用于利用k个服务器分别对所述k个图像集进行图像处理,以得到所述k个服务器中的各个服务器的样本数据集;
训练单元,用于基于所述k个服务器中经过第t次迭代计算后得到的k个模型参数更新所述各个服务器中的模型参数,并基于所述各个服务器的样本数据集对所述各个服务器中更新后的模型参数进行第t+1次迭代计算,得到训练后的k个图像分类模型。
12.一种分类预测装置,其特征在于,包括:
获取单元,用于获取目标图像;
处理单元,用于利用k个服务器中的目标服务器对所述目标图像进行图像处理,以得到所述目标图像的特征数据;其中,所述k个服务器配置有k个图像分类模型,所述k个图像分类模型为按照权利要求1至10中任一项所述的方法训练的图像分类模型;
预测单元,用于基于所述目标图像的特征数据,利用所述目标服务器中的图像分类模型,对所述目标图像进行分类预测,得到所述目标图像的图像分类结果。
13.一种电子设备,其特征在于,包括:
处理器,适于执行计算机程序;
计算机可读存储介质,所述计算机可读存储介质中存储有计算机程序,所述计算机程序被所述处理器执行时,实现权利要求1至9中任一项所述的方法或权利要求10所述的方法。
14.一种计算机可读存储介质,其特征在于,用于存储计算机程序,所述计算机程序使得计算机执行权利要求1至9中任一项所述的方法或权利要求10所述的方法。
15.一种计算机程序产品,包括计算机指令,其特征在于,所述计算机指令被处理器执行时实现权利要求1至9中任一项所述的方法或权利要求10所述的方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202211167683.0A CN116994018A (zh) | 2022-09-23 | 2022-09-23 | 模型训练方法、分类预测方法以及装置 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202211167683.0A CN116994018A (zh) | 2022-09-23 | 2022-09-23 | 模型训练方法、分类预测方法以及装置 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN116994018A true CN116994018A (zh) | 2023-11-03 |
Family
ID=88527137
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202211167683.0A Pending CN116994018A (zh) | 2022-09-23 | 2022-09-23 | 模型训练方法、分类预测方法以及装置 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN116994018A (zh) |
-
2022
- 2022-09-23 CN CN202211167683.0A patent/CN116994018A/zh active Pending
Similar Documents
Publication | Publication Date | Title |
---|---|---|
WO2022022274A1 (zh) | 一种模型训练方法及装置 | |
WO2019232772A1 (en) | Systems and methods for content identification | |
CN108920717B (zh) | 用于显示信息的方法及装置 | |
WO2022016556A1 (zh) | 一种神经网络蒸馏方法以及装置 | |
EP4303767A1 (en) | Model training method and apparatus | |
US20240135174A1 (en) | Data processing method, and neural network model training method and apparatus | |
WO2024001806A1 (zh) | 一种基于联邦学习的数据价值评估方法及其相关设备 | |
CN114266897A (zh) | 痘痘类别的预测方法、装置、电子设备及存储介质 | |
CN114547267A (zh) | 智能问答模型的生成方法、装置、计算设备和存储介质 | |
CN112785585A (zh) | 基于主动学习的图像视频质量评价模型的训练方法以及装置 | |
CN116684330A (zh) | 基于人工智能的流量预测方法、装置、设备及存储介质 | |
CN108475346A (zh) | 神经随机访问机器 | |
WO2023246735A1 (zh) | 一种项目推荐方法及其相关设备 | |
WO2023185541A1 (zh) | 一种模型训练方法及其相关设备 | |
CN116739154A (zh) | 一种故障预测方法及其相关设备 | |
WO2023050143A1 (zh) | 一种推荐模型训练方法及装置 | |
CN113947185B (zh) | 任务处理网络生成、任务处理方法、装置、电子设备及存储介质 | |
CN114169906B (zh) | 电子券推送方法、装置 | |
CN116994018A (zh) | 模型训练方法、分类预测方法以及装置 | |
CN114707070A (zh) | 一种用户行为预测方法及其相关设备 | |
CN112417260B (zh) | 本地化推荐方法、装置及存储介质 | |
CN114298961A (zh) | 图像处理方法、装置、设备及存储介质 | |
CN114299517A (zh) | 图像处理方法、装置、设备、存储介质及计算机程序产品 | |
CN113822291A (zh) | 一种图像处理方法、装置、设备及存储介质 | |
CN111310794A (zh) | 目标对象的分类方法、装置和电子设备 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |