CN114287009A - 协同训练数据属性的推断方法、装置、设备及存储介质 - Google Patents
协同训练数据属性的推断方法、装置、设备及存储介质 Download PDFInfo
- Publication number
- CN114287009A CN114287009A CN202180004174.3A CN202180004174A CN114287009A CN 114287009 A CN114287009 A CN 114287009A CN 202180004174 A CN202180004174 A CN 202180004174A CN 114287009 A CN114287009 A CN 114287009A
- Authority
- CN
- China
- Prior art keywords
- model
- gradient
- training
- data
- attribute
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Granted
Links
- 238000012549 training Methods 0.000 title claims abstract description 186
- 238000000034 method Methods 0.000 title claims abstract description 76
- 238000003860 storage Methods 0.000 title claims abstract description 11
- 230000006870 function Effects 0.000 claims description 26
- 238000004891 communication Methods 0.000 claims description 13
- 238000009826 distribution Methods 0.000 claims description 12
- 238000013527 convolutional neural network Methods 0.000 claims description 7
- 238000004364 calculation method Methods 0.000 claims description 4
- 238000010801 machine learning Methods 0.000 abstract description 5
- 230000008569 process Effects 0.000 description 20
- 238000013135 deep learning Methods 0.000 description 9
- 238000010586 diagram Methods 0.000 description 8
- 238000005457 optimization Methods 0.000 description 6
- 230000000694 effects Effects 0.000 description 5
- 238000013145 classification model Methods 0.000 description 4
- 230000002776 aggregation Effects 0.000 description 3
- 238000004220 aggregation Methods 0.000 description 3
- 238000012935 Averaging Methods 0.000 description 2
- 238000013473 artificial intelligence Methods 0.000 description 2
- 238000007405 data analysis Methods 0.000 description 2
- 238000013136 deep learning model Methods 0.000 description 2
- 238000011161 development Methods 0.000 description 2
- 238000012986 modification Methods 0.000 description 2
- 230000004048 modification Effects 0.000 description 2
- 238000012545 processing Methods 0.000 description 2
- 230000009286 beneficial effect Effects 0.000 description 1
- 230000015556 catabolic process Effects 0.000 description 1
- 230000008859 change Effects 0.000 description 1
- 238000006731 degradation reaction Methods 0.000 description 1
- 230000004069 differentiation Effects 0.000 description 1
- 238000005516 engineering process Methods 0.000 description 1
- 238000000605 extraction Methods 0.000 description 1
- 238000003062 neural network model Methods 0.000 description 1
- 101150050759 outI gene Proteins 0.000 description 1
- 230000009467 reduction Effects 0.000 description 1
- 238000006467 substitution reaction Methods 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N5/00—Computing arrangements using knowledge-based models
- G06N5/04—Inference or reasoning models
- G06N5/046—Forward inferencing; Production systems
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/0464—Convolutional networks [CNN, ConvNet]
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N5/00—Computing arrangements using knowledge-based models
- G06N5/04—Inference or reasoning models
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/77—Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
- G06V10/774—Generating sets of training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/82—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/94—Hardware or software architectures specially adapted for image or video understanding
-
- G—PHYSICS
- G10—MUSICAL INSTRUMENTS; ACOUSTICS
- G10L—SPEECH ANALYSIS TECHNIQUES OR SPEECH SYNTHESIS; SPEECH RECOGNITION; SPEECH OR VOICE PROCESSING TECHNIQUES; SPEECH OR AUDIO CODING OR DECODING
- G10L25/00—Speech or voice analysis techniques not restricted to a single one of groups G10L15/00 - G10L21/00
- G10L25/27—Speech or voice analysis techniques not restricted to a single one of groups G10L15/00 - G10L21/00 characterised by the analysis technique
- G10L25/30—Speech or voice analysis techniques not restricted to a single one of groups G10L15/00 - G10L21/00 characterised by the analysis technique using neural networks
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02T—CLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
- Y02T10/00—Road transport of goods or passengers
- Y02T10/10—Internal combustion engine [ICE] based vehicles
- Y02T10/40—Engine management systems
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Evolutionary Computation (AREA)
- Artificial Intelligence (AREA)
- Software Systems (AREA)
- General Physics & Mathematics (AREA)
- Computing Systems (AREA)
- Computational Linguistics (AREA)
- Data Mining & Analysis (AREA)
- Mathematical Physics (AREA)
- General Engineering & Computer Science (AREA)
- Health & Medical Sciences (AREA)
- General Health & Medical Sciences (AREA)
- Biophysics (AREA)
- Molecular Biology (AREA)
- Biomedical Technology (AREA)
- Life Sciences & Earth Sciences (AREA)
- Multimedia (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Databases & Information Systems (AREA)
- Medical Informatics (AREA)
- Signal Processing (AREA)
- Audiology, Speech & Language Pathology (AREA)
- Human Computer Interaction (AREA)
- Acoustics & Sound (AREA)
- Image Analysis (AREA)
Abstract
本申请涉及机器学习技术领域,公开了一种协同训练数据属性的推断方法、装置、计算设备及存储介质。方法包括:将服务器预训练的共享模型分发给分布式协同训练的参与设备;获取参与设备上传的第一梯度;基于更新后的共享模型,根据第一梯度重建样本数据的深度特征;采用共享模型提取带有属性标签的辅助数据的深度特征,训练属性推断模型;根据训练完成的属性推断模型对重建的深度特征进行属性推断。本申请无需重构输入样本,就能推断参与设备本地样本数据的相关属性,且不受参与设备每次训练更新的样本数据批量大小的影响,尤其在大批量样本数据下表现突出,性能稳定,可以对单个训练样本作属性推断。
Description
技术领域
本申请涉及机器学习技术领域,特别涉及一种协同训练数据属性的推断方法、装置、计算设备及存储介质。
背景技术
随着硬件设备的飞速发展以及大数据的广泛应用,人工智能领域受到人们的广泛关注。其中,深度学习作为一种重要的数据分析工具,广泛应用于生物特征识别、汽车自动驾驶、机器视觉等多个应用领域。在深度学习训练的过程中,包括中心式训练以及分布式训练两种方式。中心式训练由一个中心服务器收集训练所要求的数据然后进行集中训练;分布式训练(也可以称为协同训练)不需要收集数据,而是利用分布式训练参与者的本地数据在其本地设备(以下称为参与设备)上训练模型,然后将训练的梯度或者模型的参数信息发送给中心服务器进行聚合,以此来达到分布式训练同一个模型的目的。
在协同训练的过程中,训练参与者的数据分布往往不平衡,导致本地训练的模型有一定的偏差,从而导致协同训练的模型性能下降。此外,在深度学习中,模型的应用场景需要与模型的数据分布相似才能最大化模型的性能。统计训练数据的属性也可以将模型部署到更加适用的场景中。现有技术中,一般需要对协同训练的更新样本数据进行重建,才可以对参与设备本地的每个单独样本进行属性推断,且其技术方法仅适用于参与设备使用单个或极小批量样本数据进行迭代更新的情况,不符合协同训练的一般情况;其他基于梯度更新的属性推断技术方法则无法获得整个批量中单个训练样本的数据属性,推断有效性较差。
发明内容
本申请实施方式的目的在于提供一种协同训练数据属性的推断方法、装置、计算设备及存储介质,以解决现有技术中需要对协同训练的样本数据进行重建,才可以对训练数据中每个单独样本进行属性推断的技术问题,并且克服了目前只能在单个或极小批量的参与设备更新样本上进行推断的限制。
为解决上述技术问题,本申请实施例提供了一种协同训练数据属性的推断方法,应用于模型分布式协同训练的中心服务器,所述方法包括:将预训练的共享模型分发给分布式协同训练的参与设备,以使所述参与设备采用样本数据对所述共享模型进行训练;获取所述参与设备上传的第一梯度,所述第一梯度为所述参与设备进行模型训练时计算的模型损失相对于模型参数的梯度;基于所述共享模型,根据所述第一梯度重建所述样本数据的深度特征;采用共享模型提取带有属性标签的辅助数据的深度特征,训练属性推断模型,其中,所述共享模型是经过协同训练若干次迭代更新得到的;根据训练完成的属性推断模型对重建的所述深度特征进行属性推断。
在一些实施例中,所述基于所述共享模型,根据所述第一梯度重建所述样本数据的深度特征,包括:随机初始化待优化的第一深度特征;将所述第一深度特征输入所述共享模型,获取第二梯度;
最小化所述第一梯度和所述第二梯度之间的差距,对所述第一深度特征进行优化。
在一些实施例中,所述共享模型为卷积神经网络模型,所述共享模型包括特征提取器和分类器fc,所述特征提取器包括(n+1)个卷积块;所述将所述第一深度特征输入所述共享模型,获取第二梯度,包括:将所述第一深度特征输入所述特征提取器的所述多个卷积块中的最后一个卷积块fn+1,再将所述卷积块fn+1输出的特征E(X)输入所述分类器fc;分别计算损失函数对应于fn+1的参数的梯度与所述损失函数对应于fc的参数的梯度其中,所述第二梯度包括所述梯度和所述梯度
在一些实施例中,所述最小化所述第一梯度和所述第二梯度之间的差距,包括:最小化目标函数以最小化所述第一梯度和所述第二梯度之间的差距,所述目标函数为:其中,λ为超参数,gn+1和gc为所述参与设备上传的第一梯度,和均为衡量两个梯度之间差别的距离函数,衡量两个梯度g与之间差别的距离函数d为:其中,σ2=Var(g),Var(g)为梯度g的方差。
在一些实施例中,将所述超参数λ和所述学习率α设置为相同的数值。
在一些实施例中,所述第一梯度为所述参与设备随机采样的第一样本集进行模型训练,并计算所述第一样本集对应的反向传播的模型损失相对于模型参数的梯度;所述基于所述共享模型,根据所述第一梯度重建样本数据的深度特征,包括:基于所述共享模型,根据所述第一梯度重建所述第一样本集的深度特征。
在一些实施例中,所述样本数据和所述辅助数据为图片或语音。
本申请实施例还提供了一种协同训练数据属性的推断装置,应用于模型分布式协同训练的中心服务器,所述装置包括:分发模块,用于将预训练的共享模型分发给分布式协同训练的参与设备,以使所述参与设备采用本地样本的批量数据对所述共享模型进行训练迭代更新;获取模块,用于获取所述参与设备上传的第一梯度,所述第一梯度为所述参与设备进行模型训练时计算的模型损失相对于模型参数的梯度;重建模块,用于基于所述共享模型,根据所述第一梯度重建所述样本数据的深度特征;训练模块,用于采用当前共享模型提取带有属性标签的辅助数据的深度特征,训练属性推断模型,其中,所述共享模型是经过协同训练若干次迭代更新得到的;推断模块,用于根据训练完成的属性推断模型以及根据重建的所述深度特征对参与设备本地的单个训练样本进行数据属性推断。
本申请实施例还提供了一种计算设备,包括处理器、存储器、通信接口和通信总线,所述处理器、所述存储器和所述通信接口通过所述通信总线完成相互间的通信;所述存储器用于存放至少一条可执行指令,所述可执行指令使所述处理器执行如上所述的协同训练数据属性的推断方法对应的操作。
本申请实施例还提供了一种计算机可读存储介质,所述存储介质中存储有至少一条可执行指令,所述可执行指令使处理器执行如上所述的协同训练数据属性的推断方法对应的操作。
本申请实施例通过获取设备反馈的进行模型训练时计算的模型损失相对于模型参数的梯度,根据该梯度和共享模型重建样本数据的深度特征,并通过带有属性标签的辅助数据的深度特征训练属性推断模型,最后根据训练完成的属性推断模型对重建的深度特征进行属性推断,可以利用重建的深度特征中包含的冗余特征进行额外的属性推断,无需重构输入样本,就能够推断每一个样本数据的相关属性,且不受参与设备每次训练更新的样本数据批量大小(batch size)的影响,尤其在大批量样本数据下表现突出,性能稳定,且可以对单个训练样本作属性推断。
附图说明
一个或多个实施方式通过与之对应的附图中的图片进行示例性说明,这些示例性说明并不构成对实施方式的限定,附图中具有相同参考数字标号的元件表示为类似的元件,除非有特别申明,附图中的图不构成比例限制。
图1是人脸分类模型的样本数据的性别特征图;
图2是本申请实施例的应用场景示意图;
图3是本申请实施例提供的协同训练数据属性的推断方法的流程图;
图4是共享模型的结构示意图;
图5是本申请和相关技术1在不同样本数据批量大小下重建深度特征的成功率的统计图;
图6是本申请实施例提供的协同训练数据属性的推断装置的结构图;
图7是本申请实施例提供的一种计算设备的结构示意图。
具体实施方式
为使本申请实施方式的目的、技术方案和优点更加清楚,下面将结合附图对本申请的各实施方式进行详细的阐述。然而,本领域的普通技术人员可以理解,在本申请各实施方式中,为了使读者更好地理解本申请而提出了许多技术细节。但是,即使没有这些技术细节和基于以下各实施方式的种种变化和修改,也可以实现本申请所要求保护的技术方案。
随着硬件设备的飞速发展以及大数据的广泛应用,人工智能领域受到人们的广泛关注。其中,深度学习作为一种重要的数据分析工具,广泛应用于生物特征识别、汽车自动驾驶、机器视觉等多个应用领域。
在深度学习训练的过程中,包括中心式训练以及分布式训练两种方式。中心式训练由一个中心服务器收集训练所要求的数据然后进行集中训练;分布式训练(也可以称为协同训练或协作学习)是指多个参与者利用自己的本地数据共同训练同一个机器学习模型,在此过程中,参与者不需要收集数据,也不需要交换自己的本地数据,而是利用参与者的本地数据在其本地设备上训练模型,然后将训练的梯度或者模型的参数信息发送给中心服务器进行聚合,相当于参与者之间交换用于模型参数更新的梯度信息,以此来达到分布式训练同一个模型的目的。协同训练由于参与者无需将本地数据上传,保证了数据的私密性,数据安全性较高。
在协同训练的过程中,训练参与者的数据分布往往不平衡,例如当协同训练一个人脸识别模型时,不同训练参与者的数据中的男女性别比例可能不同,导致本地训练的模型有一定的偏差,从而导致协同训练的模型性能下降。统计各个训练参与者本地数据中的男女比例可以根据其数据分布为本地模型添加约束,从而提高模型性能。
此外,在深度学习中,模型的应用场景需要与模型的数据分布相似才能最大化模型的性能。统计训练数据的属性也可以将模型部署到更加适用的场景中。例如,在协同训练的人脸识别模型中,如果参与者的数据大多数是年轻人的数据,那么将其部署到应用场景大多数是老年人的应用中是不太合适的。经过统计训练数据的属性,可以将模型部署到更加适合的场景中或者在微调模型后再进行相应的部署。
在深度学习中,不同的学习任务训练的模型提取出的特征具有一定的泛化性,也就是说,任务一所提取的特征可以应用于任务二的学习中。综上,为了提高模型性能,以及将模型部署到更加适用的场景中,在协同训练中需要推断出训练数据的分布以及相关属性。由于深度特征不仅编码含有协同训练主任务相关的信息,也含有其他额外信息,可以利用深度特征对数据进行相关的推断。
相关技术1中,利用模型前向传播后得到的中间层特征或者最后输出的概率进行数据的属性推断。该方法通过带有属性标签的数据,经过模型前向传播得到的特征或者模型输出的概率,然后利用这些信息训练属性推断分类器,以此推断数据的相关属性。这种数据属性推断方式的应用场景更多是机器学习即服务(Machine Learning As a Service)。在这种场景下,更多的是利用数据对一个训练完成的模型进行查询,并不涉及到利用参与者的数据更新模型参数或者重新部署模型的问题。另外,通常来说,此类方法通常需要修改模型的训练过程,使得模型的中间层输出或者最终输出编码含有数据属性相关的信息。
相关技术2中,直接利用模型反向传播时的梯度进行数据属性的推断。该方法通过将带有数据标签的数据输入模型,然后计算该数据所对应的损失梯度,直接利用梯度信息训练属性推断分类器,以此推断数据的相关属性。在深度学习训练的过程中,大多采用的训练方式是小批量(mini-batch)训练,即一个训练流程中,输入多个数据,然后计算多个数据所对应的平均梯度。协同训练中分发的梯度是由多个数据的梯度加权平均而来。因此,此类方法仅仅能够判断一个整个批次中数据的平均属性,而无法获取单一某一个数据点的属性。
相关技术3中,利用协作训练时子模型上传的梯度重建原始训练数据。该方法通过将随机初始化的训练数据输入模型,计算该数据所对应的损失梯度,然后最小化该损失梯度与上传梯度之间的差距,以此来优化随机初始化的训练数据,从而重建原始训练数据,并将训练数据用于属性推断。在这种方法中,会受到模型结构以及训练数据批次大小(batchsize)的非常大影响,而影响重建数据的效果,最终导致将重建的数据用于属性推断并不准确。
因此,本申请实施例提出了一种方案,利用协同训练过程中分发的梯度信息重建训练数据的深度特征,并且利用重建的深度特征对数据进行额外的信息推断,以此推测出训练数据的分布以及相关属性,来调整训练过程中的参数设置,以及将训练好的模型更好地部署到实际场景中。
上述相关技术1进行数据属性推断时通常需要修改模型的训练过程,使得模型的中间层输出或者最终输出编码含有数据属性相关的信息。而修改模型训练过程的方法在协同训练中往往是不可行的,因为所有的参与者需要有一个共同的学习目标,如果单一一个参与者修改了训练过程,会影响整体模型的训练效果。通过本申请实施例提供的利用梯度重建深度特征,且根据重建的深度特征对每个训练数据进行属性推断的方法,不需要修改模型的训练过程,即可达到推断数据属性的目的。
上述相关技术2利用经过加权平均后的梯度信息进行数据属性推断,只能推断出批次数据中的平均属性,而不能精确到具体某一特定数据的属性。通过本申请实施例利用梯度重建出每一个数据点对应的深度特征,然后利用重建的深度特征对数据进行属性推断,可以精确推断特定数据点的属性。
上述相关技术3会受到模型结构以及训练数据批次大小(batch size)的非常大影响,而影响重建数据的效果,最终导致将重建的数据用于属性推断并不准确。通过本申请实施例所利用模型结构仅仅为模型的部分子块,涉及到的模型结构更为简单,且重建深度特征的任务比重建原始数据的任务也更为简单,通过本申请实施例提供的重建方法重构深度特征,可以避免模型结构的影响,且在大批量样本数据下表现突出,性能稳定。
图1是人脸分类模型的样本数据的性别特征图。t分布随机近邻嵌入(t-distributed stochastic neighbor embedding,t-sne)算法是将高维度的特征映射到二维,然后将坐标归一化到(0,1)之间,图1中横坐标和纵坐标为归一化后的竖直,无具体含义。训练普通的人脸分类模型的主要任务是分辨人的身份信息,而性别信息在模型训练过程中并没有提供。但是,如图中所示,可以看出,即使没有提供性别信息,模型提取的特征经过t-sne降维可视化后,男女样本提取的特征有着一定的差异,并且可以很容易地进行区分。因此,利用深度特征可以进行一定的数据属性推断,这验证了利用模型的特征进行数据属性推断的可能性。
本申请实施例的应用场景是深度学习中的协同训练过程中。其中,协同训练的目标是利用各个协同训练参与者的本地数据共同训练一个模型,并且训练数据不需要离开参与者本地。深度学习模型可以是各种神经网络模型,例如卷积神经网络(ConvolutionalNeural Network,CNN)模型。深度学习模型可以用于数据处理,例如图像处理中的特征提取和分类。进一步的,可以用于人脸识别、物体识别等。其中,物体识别可以是动物、植物、物品等。
图2是本申请实施例的应用场景示意图。如图中所示,中心服务器将需要协同训练的共享模型分发给协同训练参与者的参与设备(也可称为训练设备),由参与设备采用本地存储的训练数据进行模型训练。参与设备将训练的梯度或者模型的参数信息发送给中心服务器进行聚合,最终完成模型的训练。为提高协同训练效率和避免较大的偏差,所述共享模型一般在服务器端的公共数据集上进行预训练。所述公共数据集一般认为与协同训练的所有参与设备拥有不一样的样本但相似的数据分布。
图3是本申请实施例提供的协同训练数据属性的推断方法的流程图。该方法应用于模型分布式协同训练的中心服务器。如图中所示,该方法包括如下步骤:
S11:将预训练的共享模型分发给分布式协同训练的参与设备,以使参与设备采用本地样本的批量数据对共享模型进行训练迭代更新;
S12:获取参与设备上传的第一梯度,第一梯度为参与设备进行模型迭代更新训练时计算的模型损失相对于模型参数的梯度;
S13:基于更新后的共享模型,根据第一梯度重建样本数据的深度特征;
S14:采用当前共享模型提取带有属性标签的辅助数据的深度特征,训练属性推断模型,其中,共享模型是经过协同训练若干次迭代更新得到的;
S15:根据训练完成的属性推断模型以及根据重建的深度特征对参与设备本地的单个训练样本进行数据属性推断。
本申请实施例通过获取设备反馈的进行模型训练时计算的模型损失相对于模型参数的梯度,根据该梯度和共享模型重建样本数据的深度特征,并通过带有属性标签的辅助数据的深度特征训练属性推断模型,最后根据训练完成的属性推断模型对重建的深度特征进行属性推断,可以利用重建的深度特征中包含的冗余特征进行额外的属性推断,无需重构输入样本,就能够推断每一个样本数据的相关属性,且不受参与设备每次训练更新的样本数据批量大小的影响,尤其在大批量样本数据下表现突出,性能稳定,且可以对参与设备的单个本地训练样本作数据属性推断。
首先,对本申请实施例中的协同训练过程进行简要介绍。中心服务器将第一共享模型(也即初始化模型)分发给所有的参与设备。每个参与设备分别从本地存储的样本数据中随机选择一批样本数据进行模型训练。此次训练完成后,参与设备将训练更新的模型参数发送给中心服务器。中心服务器将获得的所有参与设备更新的模型参数进行参数平均后,得到优化的第二共享模型。中心服务器继续将第二共享模型分发给所有的参与设备。由参与设备继续进行模型训练的过程。在后续的训练中,参与设备每次都将随机选取一批新的本地样本数据进行训练。经过多次迭代训练后,最终得到收敛完成的训练好的模型。
S11中,首先,中心服务器对需要训练的模型进行初始化,并将初始化的共享模型分发给各分布式协同训练的参与设备。每个参与设备本地存储有用于训练该模型的样本数据,各参与设备上存储的样本数据通常不同,且不平衡。需要训练的模型可以为图片识别模型或语音识别模型,则用于训练的数据为图片或语音。经过每次迭代后,中心服务器再将更新的共享模型分发给参与设备。
图4是共享模型的结构示意图。如图中所示,共享模型为用于图片识别的卷积神经网络模型,共享模型包括特征提取器E和分类器C。特征提取器E包括(n+1)个卷积块,分别表示为f1,f2,…,fn,fn+1,特征提取器E用于提取输入样本数据X的特征E(X)。分类器C包括卷积块fc,其可以是二分类模型,可以根据中心服务器的数据提取的深度特征进行训练,然后对重建的深度特征的属性进行预测,也即用于将提取的特征E(X)按照模型目的构建的分类器进行识别。输入样本数据X(图片)至特征提取器E,经过(n+1)个卷积块的卷积操作,得到样本数据X对应的深度特征E(X)。深度特征E(X)输入至分类器C,得到最终的图片识别结果。
参与设备接收中心服务器发送的共享模型,且在本地存储的数据中随机采样一批数据,在本地进行训练,并且计算随机采样的数据对应的反向传播的损失梯度g,并且将g分享给中心服务器用于协同训练模型。请参考图4,可以根据损失函数计算损失梯度g,损失梯度g包括gc和gn+1,其中,gc为参与设备计算的损失函数对应于fc的参数的梯度,gn+1为参与设备计算的损失函数对应于fn+1的参数的梯度,二者均为真实的梯度。为后续的属性推断提供了有用的信息,特别对于大批量样本数据的训练和更新尤为有用。
在一些实施例中,第一梯度为参与设备随机采样的第一样本集进行模型训练,并计算第一样本集对应的反向传播的模型损失相对于模型参数的梯度。第一样本集中的数据可以是小批量的数据,这样可以提高计算速度和效率。当参与设备随机采样的第一样本集进行模型训练时,在S13中,则基于更新后的共享模型,根据第一梯度重建第一样本集的深度特征。
参与设备的训练数据始终保持在本地,且不与其他参与设备或者服务器共享,以此达到保护训练数据的隐私安全的目的。一个小批次(mini-batch)的数据中包含多个样本,样本数量取决于批次的大小。每一个样本中包含的内容与协同训练的目标模型相关,例如,协同训练的目的是共同训练人脸识别模型,则每一个训练样本包含的内容就是一张人脸图片,并且对应于一个标签,标签的内容取决于协同训练模型的目的。
在一些实施例中,S13进一步可以包括:
S131:随机初始化待优化的第一深度特征;
S132:将第一深度特征输入共享模型,获取第二梯度;
S133:最小化第一梯度和第二梯度之间的差距,对第一深度特征进行优化。
请参考图4,共享模型可以为卷积神经网络模型,共享模型包括特征提取器和分类器fc,特征提取器包括(n+1)个卷积块。第一深度特征为数据对其代表的是一对可优化的数据对,是欲重建的深度特征,是伪标签。由于不知道样本数据(原始数据)的真实标签,因此此处需提供一个可优化的伪标签用于计算交叉熵损失。最开始是随机初始化得到的,经过优化之后,最终获得的重建的与原来的相似,即为深度特征。
本申请实施例需要利用的信息是最后一个卷积块fn+1以及最后的分类器fc的信息。利用fn+1和fc的前向传播信息以及这两层网络相对应的反向传播梯度信息gn+1和gc,将数据对输入由fn+1和fc组成的子模型,并且分别计算损失函数对应于fn+1和fc的参数的梯度。具体的,S132包括:
S1321:将第一深度特征输入特征提取器的多个卷积块中的最后一个卷积块fn+1,再将卷积块fn+1输出的特征E(X)输入分类器fc;
其中,σ2=Var(g),Var(g)为梯度g的方差。
S133中对第一深度特征进行优化,进一步可以包括:
在一些实施例中,可以将超参数λ和学习率α设置为相同的数值。例如,将超参数λ设为0.1且将学习率α也设为0.1。可以理解的是,超参数和学习率的值一般可根据经验调整,也可以设置为其他数值。
其中,优化的次数可以根据经验设置,例如设置为5000次。当然,也可以设置为更高的次数,优化结果将更加逼近真实值,但会导致时间成本的增加。若设置为较少的次数,则优化结果可能没那么逼近真实值,但是会降低时间成本。
可以理解的是,S11~S13的执行无需改变协同训练过程,协同训练正常进行即可。且S11~S13可以对每个参与设备的每次训练的样本数据的深度特征进行重建,从而对于所有用于训练的样本数据,都可以进行后续的属性推断。
S14中,中心服务器存储有带有属性标签的辅助数据,训练属性推断模型(也可以称为属性分类模型,其功能是对数据的属性进行识别或分类)需要先利用特征提取器提取带有属性标签的辅助数据的特征。然后利用提取的带有属性标签的辅助数据的深度特征训练属性推断模型,用于推断参与设备中的样本数据的属性。
S15中,中心服务器将重建的样本数据的深度特征输入属性推断模型,以此实现对协同训练数据属性的推断。
本步骤可以对参与设备本地的所有参与了模型训练的样本数据的属性进行推断。
综上,本申请实施例通过协同训练的参与设备上传的梯度进行深度特征重建,并且通过属性推断模型对分布式协同训练的参与设备的训练数据的属性进行推断,从而实现对于协同训练数据属性的推断。
统计了本申请实施例的方法在不同样本数据批量大小下重建深度特征的成功率。具体为,统计重建深度特征与原始真实特征的余弦相似度>0.95的比例。与相关技术1的对比结果请参考图5所示。图5是本申请和相关技术1(图中标记为方法1)在不同样本数据批量大小下重建深度特征的成功率的统计图。可知,相比相关技术1,本申请实施例的方法在不同批量大小下对深度特征都有良好的重建效果。尤其在大批量大小(例如批量大小=512)下表现突出,性能稳定。
使用本申请实施例的方法以及相关技术1、相关技术2分别对第一数据集、第二数据集和第三数据集分别进行协同训练数据属性推断,对应的属性推断准确率如表1所示:
表1
由此可见,本申请提高了属性推断的准确率。
综上,相比现有技术,本申请实施例具有如下有益效果:
(1)利用模型训练过程中的前向传播以及反向传播信息重建训练数据对应的深度特征,可以不受小批次的大小影响而精准地重建出每个数据对应的深度特征;相比重建输入样本的方式,重建的数据量小,效率更高,例如重建输入样本的方式在样本数据的批量大小达到8时,其重建结果几乎不能用于属性推断。
(2)利用了更少的模型结构重建深度特征,因此重建效果受到模型具体结构影响较小,使其可以应用于多个不同的卷积神经网络模型,提高了应用的广泛性。
(3)与其他基于反向传播信息的推断方法相比,本申请实施例提出利用梯度重建深度特征的方法,可以重建出每一个训练样本对应的深度特征然后利用其进行数据相关属性的推断,可以推断小批次训练中的每一个数据相关的属性,并且提高了推断准确率。而现有的一些方式仅能推断批量样本数据中是否存在某个属性,无法获知该属性属于哪个具体的样本,或者一次仅能针对一个数量的样本数据进行属性推断。
图6是本申请实施例提供的协同训练数据属性的推断装置的结构图。如图中所示,该数据属性的推断装置应用于模型分布式协同训练的中心服务器,装置500包括分发模块501、获取模块502、重建模块503、训练模块504和推断模块505。其中:
分发模块501用于将预训练的共享模型分发给分布式协同训练的参与设备,以使所述参与设备采用本地样本的批量数据对所述共享模型进行训练迭代更新;
获取模块502,用于获取所述参与设备上传的第一梯度,所述第一梯度为所述参与设备进行模型训练时计算的模型损失相对于模型参数的梯度;
重建模块503,用于基于所述共享模型,根据所述第一梯度重建所述样本数据的深度特征;
训练模块504,用于采用当前共享模型提取带有属性标签的辅助数据的深度特征,训练属性推断模型,其中,所述共享模型是经过协同训练若干次迭代更新得到的;
推断模块505,用于根据训练完成的属性推断模型以及根据重建的所述深度特征对参与设备本地的单个训练样本进行数据属性推断。
本装置的具体实现方式和工作原理可参考前述的方法实施例,此处不再赘述。
图7是本申请实施例提供的一种计算设备的结构示意图。如图中所示,计算设备600包括处理器601、存储器602、通信接口603和通信总线604,处理器601、存储器602和通信接口603通过通信总线604完成相互间的通信。存储器602是非易失性计算机可读存储介质,可用于存储非易失性软件程序、非易失性计算机可执行程序以及模块。本申请实施例中,存储器602用于存放至少一条可执行指令,可执行指令使处理器601执行如上的协同训练数据属性的推断方法对应的操作。
本申请实施例还提供了一种计算机可读存储介质,存储介质中存储有至少一条可执行指令,可执行指令使处理器执行如上的协同训练数据属性的推断方法对应的操作。
最后应说明的是:以上实施方式仅用以说明本申请的技术方案,而非对其限制;在本申请的思路下,以上实施方式或者不同实施方式中的技术特征之间也可以进行组合,步骤可以以任意顺序实现,并存在如上所述的本申请的不同方面的许多其它变化,为了简明,它们没有在细节中提供;尽管参照前述实施方式对本申请进行了详细的说明,本领域的普通技术人员应当理解:其依然可以对前述各实施方式所记载的技术方案进行修改,或者对其中部分技术特征进行等同替换;而这些修改或者替换,并不使相应技术方案的本质脱离本申请各实施方式技术方案的范围。
Claims (13)
1.一种协同训练数据属性的推断方法,应用于模型分布式协同训练的中心服务器,其特征在于,所述方法包括:
将预训练的共享模型分发给分布式协同训练的参与设备,以使所述参与设备采用本地样本的批量数据对所述共享模型进行训练迭代更新;
获取所述参与设备上传的第一梯度,所述第一梯度为所述参与设备进行模型训练时计算的模型损失相对于模型参数的梯度;
基于所述共享模型,根据所述第一梯度重建所述样本数据的深度特征;
采用所述共享模型提取带有属性标签的辅助数据的深度特征,训练属性推断模型,其中,所述共享模型是经过协同训练若干次迭代更新得到的;
根据训练完成的属性推断模型以及根据重建的所述深度特征对参与设备本地的单个训练样本进行数据属性推断。
2.根据权利要求1所述的方法,其特征在于,所述基于所述共享模型,根据所述第一梯度重建所述样本数据的深度特征,包括:
随机初始化待优化的第一深度特征;
将所述第一深度特征输入所述共享模型,获取第二梯度;
最小化所述第一梯度和所述第二梯度之间的差距,对所述第一深度特征进行优化。
8.根据权利要求7所述的方法,其特征在于,将所述超参数λ和所述学习率α设置为相同的数值。
9.根据权利要求1~9任一项所述的方法,其特征在于,所述第一梯度为所述参与设备随机采样的第一样本集进行模型训练,并计算所述第一样本集对应的反向传播的模型损失相对于模型参数的梯度;
所述基于所述共享模型,根据所述第一梯度重建所述样本数据的深度特征,包括:
基于所述共享模型,根据所述第一梯度重建所述第一样本集的深度特征。
10.根据权利要求1~9任一项所述的方法,其特征在于,所述样本数据和所述辅助数据为图片或语音。
11.一种协同训练数据属性的推断装置,应用于模型分布式协同训练的中心服务器,其特征在于,所述装置包括:
分发模块,用于将预训练的共享模型分发给分布式协同训练的参与设备,以使所述参与设备采用本地样本的批量数据对所述共享模型进行训练迭代更新;
获取模块,用于获取所述参与设备上传的第一梯度,所述第一梯度为所述参与设备进行模型训练时计算的模型损失相对于模型参数的梯度;
重建模块,用于基于所述共享模型,根据所述第一梯度重建所述样本数据的深度特征;
训练模块,用于采用当前共享模型提取带有属性标签的辅助数据的深度特征,训练属性推断模型,其中,所述共享模型是经过协同训练若干次迭代更新得到的;
推断模块,用于根据训练完成的属性推断模型以及根据重建的所述深度特征对参与设备本地的单个训练样本进行数据属性推断。
12.一种计算设备,其特征在于,包括处理器、存储器、通信接口和通信总线,所述处理器、所述存储器和所述通信接口通过所述通信总线完成相互间的通信;
所述存储器用于存放至少一条可执行指令,所述可执行指令使所述处理器执行如权利要求1~10中任一项所述的协同训练数据属性的推断方法对应的操作。
13.一种计算机可读存储介质,其特征在于,所述存储介质中存储有至少一条可执行指令,所述可执行指令使处理器执行如权利要求1~10中任一项所述的协同训练数据属性的推断方法对应的操作。
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
PCT/CN2021/135055 WO2023097602A1 (zh) | 2021-12-02 | 2021-12-02 | 协同训练数据属性的推断方法、装置、设备及存储介质 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN114287009A true CN114287009A (zh) | 2022-04-05 |
CN114287009B CN114287009B (zh) | 2024-08-02 |
Family
ID=80880015
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202180004174.3A Active CN114287009B (zh) | 2021-12-02 | 2021-12-02 | 协同训练数据属性的推断方法、装置、设备及存储介质 |
Country Status (3)
Country | Link |
---|---|
US (1) | US20240232665A1 (zh) |
CN (1) | CN114287009B (zh) |
WO (1) | WO2023097602A1 (zh) |
Citations (6)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN110008696A (zh) * | 2019-03-29 | 2019-07-12 | 武汉大学 | 一种面向深度联邦学习的用户数据重建攻击方法 |
CN110580496A (zh) * | 2019-07-11 | 2019-12-17 | 南京邮电大学 | 一种基于熵最小化的深度迁移学习系统及方法 |
CN112101489A (zh) * | 2020-11-18 | 2020-12-18 | 天津开发区精诺瀚海数据科技有限公司 | 一种联邦学习与深度学习融合驱动的设备故障诊断方法 |
CN112600794A (zh) * | 2020-11-23 | 2021-04-02 | 南京理工大学 | 一种联合深度学习中检测gan攻击的方法 |
CN112634341A (zh) * | 2020-12-24 | 2021-04-09 | 湖北工业大学 | 多视觉任务协同的深度估计模型的构建方法 |
CN113065581A (zh) * | 2021-03-18 | 2021-07-02 | 重庆大学 | 基于参数共享对抗域自适应网络的振动故障迁移诊断方法 |
Family Cites Families (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
WO2021228404A1 (en) * | 2020-05-15 | 2021-11-18 | Huawei Technologies Co., Ltd. | Generating high-dimensional, high utility synthetic data |
CN112016632B (zh) * | 2020-09-25 | 2024-04-26 | 北京百度网讯科技有限公司 | 模型联合训练方法、装置、设备和存储介质 |
-
2021
- 2021-12-02 WO PCT/CN2021/135055 patent/WO2023097602A1/zh unknown
- 2021-12-02 CN CN202180004174.3A patent/CN114287009B/zh active Active
-
2024
- 2024-03-22 US US18/613,118 patent/US20240232665A1/en active Pending
Patent Citations (6)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN110008696A (zh) * | 2019-03-29 | 2019-07-12 | 武汉大学 | 一种面向深度联邦学习的用户数据重建攻击方法 |
CN110580496A (zh) * | 2019-07-11 | 2019-12-17 | 南京邮电大学 | 一种基于熵最小化的深度迁移学习系统及方法 |
CN112101489A (zh) * | 2020-11-18 | 2020-12-18 | 天津开发区精诺瀚海数据科技有限公司 | 一种联邦学习与深度学习融合驱动的设备故障诊断方法 |
CN112600794A (zh) * | 2020-11-23 | 2021-04-02 | 南京理工大学 | 一种联合深度学习中检测gan攻击的方法 |
CN112634341A (zh) * | 2020-12-24 | 2021-04-09 | 湖北工业大学 | 多视觉任务协同的深度估计模型的构建方法 |
CN113065581A (zh) * | 2021-03-18 | 2021-07-02 | 重庆大学 | 基于参数共享对抗域自适应网络的振动故障迁移诊断方法 |
Non-Patent Citations (1)
Title |
---|
MINGXUE XU 等: "Subject Property Inference Attack in Collaborative Learning", IEEE, 23 August 2020 (2020-08-23), pages 227 - 231, XP033830291, DOI: 10.1109/IHMSC49165.2020.00057 * |
Also Published As
Publication number | Publication date |
---|---|
CN114287009B (zh) | 2024-08-02 |
WO2023097602A1 (zh) | 2023-06-08 |
US20240232665A1 (en) | 2024-07-11 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
Garnelo et al. | Conditional neural processes | |
CN108133330B (zh) | 一种面向社交众包任务分配方法及其系统 | |
Baytas et al. | Asynchronous multi-task learning | |
Kolesnikov et al. | PixelCNN models with auxiliary variables for natural image modeling | |
Hu et al. | Mixnorm: Test-time adaptation through online normalization estimation | |
CN106548159A (zh) | 基于全卷积神经网络的网纹人脸图像识别方法与装置 | |
Kazemi et al. | Unsupervised facial geometry learning for sketch to photo synthesis | |
CN114358250A (zh) | 数据处理方法、装置、计算机设备、介质及程序产品 | |
CN114677535A (zh) | 域适应图像分类网络的训练方法、图像分类方法及装置 | |
CN116452333A (zh) | 异常交易检测模型的构建方法、异常交易检测方法及装置 | |
Muhammad et al. | Early Stopping Effectiveness for YOLOv4. | |
Baghirli et al. | Satdm: Synthesizing realistic satellite image with semantic layout conditioning using diffusion models | |
CN116523002A (zh) | 多源异构数据的动态图生成对抗网络轨迹预测方法和系统 | |
Wang et al. | Comment: Variational autoencoders as empirical bayes | |
CN114287009A (zh) | 协同训练数据属性的推断方法、装置、设备及存储介质 | |
CN104200222B (zh) | 一种基于因子图模型的图片中对象识别方法 | |
CN114140848B (zh) | 基于knn和dsn的微表情识别方法、系统、设备及存储介质 | |
CN113706290A (zh) | 在区块链上采用神经架构搜索的信用评估模型构建方法、系统、设备及存储介质 | |
CN115908600A (zh) | 基于先验正则化的大批量图像重建方法 | |
CN117033997A (zh) | 数据切分方法、装置、电子设备和介质 | |
Faye et al. | Regularization by denoising: Bayesian model and Langevin-within-split Gibbs sampling | |
Kong et al. | Learning Deep Contrastive Network for Facial Age Estimation | |
Yang | Feature sharing attention 3d face reconstruction with unsupervised learning from in-the-wild photo collection | |
Wu et al. | Multi-rater prism: Learning self-calibrated medical image segmentation from multiple raters | |
CN113569887B (zh) | 图片识别模型训练和图片识别方法、装置和存储介质 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |