CN116244484B - 一种面向不平衡数据的联邦跨模态检索方法及系统 - Google Patents
一种面向不平衡数据的联邦跨模态检索方法及系统 Download PDFInfo
- Publication number
- CN116244484B CN116244484B CN202310523580.1A CN202310523580A CN116244484B CN 116244484 B CN116244484 B CN 116244484B CN 202310523580 A CN202310523580 A CN 202310523580A CN 116244484 B CN116244484 B CN 116244484B
- Authority
- CN
- China
- Prior art keywords
- global
- sample
- cross
- client
- training
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
- 238000000034 method Methods 0.000 title claims abstract description 54
- 238000012549 training Methods 0.000 claims abstract description 98
- 238000004364 calculation method Methods 0.000 claims abstract description 16
- 230000006870 function Effects 0.000 claims description 26
- 238000012935 Averaging Methods 0.000 claims description 10
- 238000004220 aggregation Methods 0.000 claims description 10
- 238000000605 extraction Methods 0.000 claims description 10
- 230000002776 aggregation Effects 0.000 claims description 9
- 239000011159 matrix material Substances 0.000 claims description 8
- 230000003044 adaptive effect Effects 0.000 claims description 6
- 238000013527 convolutional neural network Methods 0.000 claims description 3
- 239000000284 extract Substances 0.000 claims description 3
- 238000013139 quantization Methods 0.000 claims description 3
- 230000000007 visual effect Effects 0.000 claims description 3
- 230000004931 aggregating effect Effects 0.000 claims 1
- 238000009827 uniform distribution Methods 0.000 abstract description 5
- 238000009826 distribution Methods 0.000 description 12
- 230000008569 process Effects 0.000 description 5
- 238000010801 machine learning Methods 0.000 description 4
- 230000000694 effects Effects 0.000 description 3
- 230000004927 fusion Effects 0.000 description 2
- 238000012986 modification Methods 0.000 description 2
- 230000004048 modification Effects 0.000 description 2
- 238000013528 artificial neural network Methods 0.000 description 1
- 230000009286 beneficial effect Effects 0.000 description 1
- 238000013480 data collection Methods 0.000 description 1
- 230000007547 defect Effects 0.000 description 1
- 238000010586 diagram Methods 0.000 description 1
- 230000002708 enhancing effect Effects 0.000 description 1
- 230000006872 improvement Effects 0.000 description 1
- 238000012423 maintenance Methods 0.000 description 1
- 238000005259 measurement Methods 0.000 description 1
- 230000007246 mechanism Effects 0.000 description 1
- 238000010606 normalization Methods 0.000 description 1
- 238000005457 optimization Methods 0.000 description 1
- 238000012545 processing Methods 0.000 description 1
- 238000005070 sampling Methods 0.000 description 1
- 238000003860 storage Methods 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F16/00—Information retrieval; Database structures therefor; File system structures therefor
- G06F16/90—Details of database functions independent of the retrieved data types
- G06F16/907—Retrieval characterised by using metadata, e.g. metadata not derived from the content or metadata generated manually
- G06F16/908—Retrieval characterised by using metadata, e.g. metadata not derived from the content or metadata generated manually using metadata automatically derived from the content
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02D—CLIMATE CHANGE MITIGATION TECHNOLOGIES IN INFORMATION AND COMMUNICATION TECHNOLOGIES [ICT], I.E. INFORMATION AND COMMUNICATION TECHNOLOGIES AIMING AT THE REDUCTION OF THEIR OWN ENERGY USE
- Y02D10/00—Energy efficient computing, e.g. low power processors, power management or thermal management
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Library & Information Science (AREA)
- Databases & Information Systems (AREA)
- Data Mining & Analysis (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Artificial Intelligence (AREA)
- General Health & Medical Sciences (AREA)
- Health & Medical Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Evolutionary Computation (AREA)
- Life Sciences & Earth Sciences (AREA)
- Molecular Biology (AREA)
- Computing Systems (AREA)
- Mathematical Physics (AREA)
- Software Systems (AREA)
- Information Retrieval, Db Structures And Fs Structures Therefor (AREA)
- Image Analysis (AREA)
Abstract
本发明提出了一种面向不平衡数据的联邦跨模态检索方法及系统,涉及联邦学习领域、跨模态检索领域,解决跨模态检索任务中数据非独立同分布带来的影响,基于训练后的全局跨模态检索模型,对待查询目标的查询样本进行编码,获得查询哈希码;对查询哈希码与检索数据集中的数据哈希码进行相似度计算,基于相似度,获得检索结果;全局跨模态检索模型是基于联邦学习训练得到的;本发明面向非独立同分布数据,通过将全局特征类别原型嵌入样本特征中丰富增强样本的特征表示;且充分利用监督学习标签的语义信息,使生成的哈希码更具有判别力和准确性;还提出一种新的服务器端加权平均本地模型参数的方法,有效提升跨设备情况下联邦跨模态检索模型的性能。
Description
技术领域
本发明属于联邦学习领域、跨模态检索领域,尤其涉及一种面向不平衡数据的联邦跨模态检索方法及系统。
背景技术
本部分的陈述仅仅是提供了与本发明相关的背景技术信息,不必然构成在先技术。
为满足日益严格的隐私保护要求,并避免传统的集中式机器学习中的隐私泄露问题,联邦学习应运而生;在联邦学习中,一系列本地设备在中央服务器的协调下协作训练机器学习模型,但同时联邦学习也存在挑战,由于各客户端间数据生成和采样方式存在差异,因此其本地数据是非独立同分布的,非独立同分布数据可能会在联邦学习训练过程中产生客户端模型偏移问题,导致全局模型的性能下降,甚至难以收敛。
在很大程度上,深度神经网络的成功依赖于大量的训练样本,但数据样本通常存储在不同的设备或机构上,将大量分布式的数据样本进行收集并进行集中式存储,不仅耗时和昂贵,并且违反了法律限制或隐私安全保护的要求;因此联邦学习作为一个分布式的机器学习框架,在保证各客户端的数据样本不离开本地的情况下,通过联合一系列客户端来协作训练一个全局模型,它的出现解决了在隐私安全要求下数据样本无法共享的问题。
在联邦学习中,最基础的挑战之一就是数据异质性,数据的异质性和不平衡性是联邦学习中常见的情况;不同于传统机器学习中集中收集和处理数据的工作模式,在联邦学习中各客户端的数据产生于客户端本地,因此不同客户端间的数据分布在很大程度上不同,由此会导致各客户端间本地数据的非独立同分布性。此外,非独立同分布的数据也存在不同的情况,例如特征分布偏差、标签分布偏差或数量偏差等,这些异质性会在不同程度上影响联邦学习算法的稳定性、收敛性和效果;
现有的深度跨模态检索方法通常都需要大量的训练数据,但直接将大量数据进行聚合不仅会带来巨大的隐私风险,还需要高昂的维护成本;利用联邦学习来完成深度跨模态检索模型的训练是一个可行的方案,不仅解决了隐私保护问题,还能继承深度跨模态检索中检索效率高和存储成本低的优点,能将其应用于大规模跨模态检索任务。
但基于联邦学习的跨模态检索任务,同样面临着各客户端间本地数据的非独立同分布性,从而严重影响了跨设备情况下联邦跨模态检索模型的性能。
发明内容
为克服上述现有技术的不足,解决基于联邦学习的跨模态检索任务中数据非独立同分布带来的影响,本发明提供了一种面向不平衡数据的联邦跨模态检索方法及系统,面向非独立同分布数据,通过将全局特征类别原型嵌入样本特征中来丰富增强样本的特征表示;且充分利用监督学习标签的语义信息,使生成的哈希码更具有判别力和准确性;还提出了一种新的服务器端加权平均本地模型参数的方法,有效提升跨设备情况下联邦跨模态检索模型的性能。
为实现上述目的,本发明的一个或多个实施例提供了如下技术方案:
本发明第一方面提供了一种面向不平衡数据的联邦跨模态检索方法;
一种面向不平衡数据的联邦跨模态检索方法,包括:
基于训练后的全局跨模态检索模型,对待查询目标的查询样本进行编码,获得查询哈希码;
对所述查询哈希码与检索数据集中的数据哈希码进行相似度计算,基于所述相似度,获得检索结果;
其中,所述全局跨模态检索模型是基于联邦学习训练得到的,在每一轮迭代训练中,基于上一轮输出的全局模型参数和全局特征类别原型,将全局特征类别原型嵌入到各客户端本地的样本特征,得到样本的增强特征,利用增强特征生成样本的哈希码,利用哈希码构建损失函数进行本轮训练;当参与训练的客户端都完成本轮迭代训练后,采用在服务器端加权平均本地模型参数的方法,得到下一轮的全局模型参数,并更新全局特征类别原型。
进一步的,所述全局跨模态检索模型,包括特征提取层、特征增强层、哈希层、分类层和原型计算层。
进一步的,所述特征提取层,用于各客户端基于上一轮训练后的全局模型参数,本地提取样本特征;
对于图像模态,利用卷积神经网络来提取原始视觉特征,对于文本模态,利用两个全连接层提取原始文本特征。
进一步的,所述特征增强层,用于基于提取的样本特征和上一轮的全局特征类别原型,计算样本的增强特征,具体步骤为:
将样本标签与全局特征类别原型相融合,得到富含全局记忆信息的记忆特征;
引入一个自适应选择器,将样本特征与记忆特征相融合,得到样本的增强特征。
进一步的,所述分类层,以样本哈希码为输入,计算样本哈希码的分类标签,利用分类标签与样本原始标签之间的偏差,构造交叉熵损失函数,进行监督学习。
进一步的,所述在服务器端加权平均本地模型参数的方法,是通过相似性权重和类别数量权重方法得出各客户端模型权重,并使用该权重进行加权聚合得到下一轮的全局模型参数。
进一步的,所述更新全局特征类别原型,具体为:
(1)客户端的原型计算层,基于本轮训练后的本地模型参数,提取样本特征,计算本地特征类别原型表示为:
其中,表示第i个样本的样本特征,li表示第i个样本的标签,c表示类别,nk表示第k个客户端的样本数量;
(2)服务器端,基于各个客户端的本地特征类别原型,计算全局特征类别原型Pglobal,表示为:
其中,K表示客户端的个数,|D|表示K个客户端的样本总量,|Dk|表示第k个客户端的样本数量。
本发明第二方面提供了一种面向不平衡数据的联邦跨模态检索系统。
一种面向不平衡数据的联邦跨模态检索系统,包括编码单元和检索单元;
编码单元,被配置为:基于训练后的全局跨模态检索模型,对待查询目标的查询样本进行编码,获得查询哈希码;
检索单元,被配置为:对所述查询哈希码与检索数据集中的数据哈希码进行相似度计算,基于所述相似度,获得检索结果;
其中,所述全局跨模态检索模型是基于联邦学习训练得到的,在每一轮迭代训练中,基于上一轮输出的全局模型参数和全局特征类别原型,将全局特征类别原型嵌入到各客户端本地的样本特征,得到样本的增强特征,利用增强特征生成样本的哈希码,利用哈希码构建损失函数进行本轮训练;当参与训练的客户端都完成本轮迭代训练后,采用在服务器端加权平均本地模型参数的方法,得到下一轮的全局模型参数,并更新全局特征类别原型。
以上一个或多个技术方案存在以下有益效果:
本发明提出了一种动态元嵌入模块,使全局语义知识能在各个参与训练的客户端间传输,通过全局特征类别原型的嵌入不仅丰富了样本的特征表示,还缓解了联邦学习中各客户端存在的数据分布不平衡问题。
本发明充分利用监督学习中标签的语义信息,使各个模态生成的哈希码更具有判别力和准确性。
本发明提出了一种新的在服务器端加权平均本地模型参数的方法,使得到的全局模型更具有泛化能力,并且能有效解决联邦学习中数据非独立同分布问题。
本发明附加方面的优点将在下面的描述中部分给出,部分将从下面的描述中变得明显,或通过本发明的实践了解到。
附图说明
构成本发明的一部分的说明书附图用来提供对本发明的进一步理解,本发明的示意性实施例及其说明用于解释本发明,并不构成对本发明的不当限定。
图1为第一个实施例的方法流程图。
图2为第二个实施例的系统结构图。
具体实施方式
下面结合附图与实施例对本发明做进一步说明。
应该指出,以下详细说明都是例示性的,旨在对本发明提供进一步的说明。除非另有指明,本文使用的所有技术和科学术语具有与本发明所属技术领域的普通技术人员通常理解的相同含义。
需要注意的是,这里所使用的术语仅是为了描述具体实施方式,而非意图限制根据本发明的示例性实施方式。如在这里所使用的,除非上下文另外明确指出,否则单数形式也意图包括复数形式,此外,还应当理解的是,当在本说明书中使用术语“包含”和/或“包括”时,其指明存在特征、步骤、操作、器件、组件和/或它们的组合。
在不冲突的情况下,本发明中的实施例及实施例中的特征可以相互组合。
实施例一
本实施例公开了一种面向不平衡数据的联邦跨模态检索方法;
如图1所示,一种面向不平衡数据的联邦跨模态检索方法,包括:
步骤S1:基于训练后的全局跨模态检索模型,对待查询目标的查询样本进行编码,获得查询哈希码。
将查询样本输入到训练后的全局跨模态检索模型中,输出样本的实值哈希码h,并通过sign()函数将其转为二值哈希码,即查询哈希码Bq,用公式表示为:
Bq=sign(h)
其中,sign()为符号函数,Bq为查询样本的哈希码。
也就是说,训练后的全局跨模态检索模型,用于检索时,以样本为输入、样本的哈希码为最终输出。
步骤S2:对所述查询哈希码与检索数据集中的数据哈希码进行相似度计算,基于所述相似度,获得检索结果。
基于海明距离,度量样本的查询哈希码Bq与检索集中样本的哈希码间的相似性,海明距离越小则样本间的相似性越高,从而实现快速的跨模态检索任务。
其中,所述全局跨模态检索模型是基于联邦学习训练得到的,在每一轮迭代训练中,基于上一轮输出的全局模型参数和全局特征类别原型,将全局特征类别原型嵌入到各客户端本地的样本特征,得到样本的增强特征,利用增强特征生成样本的哈希码,利用哈希码构建损失函数进行本轮训练;当参与训练的客户端都完成本轮迭代训练后,采用在服务器端加权平均本地模型参数的方法,得到下一轮的全局模型参数,并更新全局特征类别原型。
具体的,为了便于理解,以下结合附图对本实施例所述方案进行详细说明。
联邦学习训练过程为:中央服务器首先初始化全局模型参数,并将初始化的模型参数分发到本轮各个参与训练的客户端中,各客户端首先加载由中央服务器发来的参数,然后利用本地样本对模型进行迭代训练,优化方法可选用随机梯度下降等方法,本地训练完成后,将更新后的本地模型参数传回中央服务器端,中央服务器根据收到的各客户端本地模型参数对服务器的全局模型进行更新,至此完成联邦学习中的一轮训练,然后迭代更新,直到到达目标效果或者规定轮数时终止训练。
本实施例是在联邦学习训练的基础上,构建全局跨模态检索模型,用于将样本编码成哈希码,通过哈希码进行比对,找到目标数据,所以全局跨模态检索模型是本实施例的关键。
本实施例中的全局跨模态检索模型,关键点在于:首先将代表全局特征类别信息的全局特征类别原型嵌入各客户端本地的样本特征中,丰富不同模态的特征表示,得到样本的增强特征,然后利用增强特征学习样本的哈希码,同时在模型中添加分类层,充分利用标签信息进行监督学习,生成更有判别力的哈希码,为缓解数据不平衡对全局模型性能的影响;还设计了一种加权平均方法来融合各个客户端的模型参数、求取全局模型参数;所以,全局跨模态检索模型,主要包含两个模块,分别为不同模态样本的特征增强部分和全局模型聚合部分。
样本特征增强部分,对应下面模型结构中的“特征增强层”,通过将服务器上的全局特征类别原型嵌入到客户端样本的特征表示,让全局语义知识能在各客户端间统一;全局模型聚合部分,对应下面模型迭代训练中的“新的在服务器端加权平均本地模型参数的方法”,将各客户端更新的模型参数进行加权平均得到新一轮的全局模型参数。
全局跨模态检索模型,尽管检索时以样本为输入、样本的哈希码为最终输出,但在迭代训练的过程中,每次迭代的输入为上一轮训练后的全局模型参数和上一轮训练后更新的全局特征类别原型,输出为本轮训练后的全局模型参数和本轮训练后更新的全局特征类别原型。
为了便于理解,分别从模型结构和模型迭代训练的角度说明全局跨模态检索模型。
不失一般性,以两种模态为例介绍面向不平衡数据的联邦跨模态检索方法;在联邦学习的场景下,假设有k个客户端,每个客户端的本地数据样本表示为其中,nk表示第k个客户端的样本数量,xi代表图像模态的样本实例,yi代表文本模态的样本实例,样本的标签用/>表示,其中,C是类别数量。
一、全局跨模态检索模型的结构
全局跨模态检索模型,包括依次连接的特征提取层、特征增强层、哈希层、分类层和原型计算层:
(1)特征提取层,用于各客户端基于上一轮训练后的全局模型参数,本地提取样本特征。
对于图像模态,首先利用卷积神经网络来提取原始视觉特征对于文本模态利用两个全连接层提取原始文本特征/>其中d为样本特征的维度。
(2)特征增强层,用于基于提取的样本特征和上一轮的全局特征类别原型,计算样本的增强特征。
在联邦学习中,由于各客户端拥有各自的数据样本,一般情况下各客户端的数据分布不是独立同分布的,为了缓解这个问题,本实施例将动态元嵌入思想运用于基于联邦学习的跨模态检索任务中,将上一轮更新后的全局特征类别原型和样本的标签语义信息充分嵌入进样本的特征中,增强了样本的特征表示,使各客户端拥有统一的全局语义知识,从而缓解客户端本地模型漂移问题,具体为:
1)将样本标签与全局特征类别原型相融合,得到富含全局记忆信息的记忆特征。
由于标签中存在充分的语义信息,因此将样本标签与记忆矩阵Pglobal相融合能得到富含全局记忆信息的记忆特征,记忆特征被设计为:
VM=L⊙Pglobal
其中,L是实例的标签矩阵,Pglobal为全局的原始特征类别原型,在两者间使用Hadamard乘积⊙可得到记忆特征VM。
2)引入一个自适应选择器,将样本特征与记忆特征相融合,得到样本的增强特征。
在此基础上再引入一个自适应选择器,将原始特征与记忆特征相融合,从而增强原始特征,得到最终输出的增强特征VE被表示为:
VE=VF+A⊙VM
其中A=tanh(FC(VF)),使用(FC+Tanh)不仅能从原始特征VF中直接得到自适应选择器的权值,还能避免对参数的复杂调整。
(3)哈希层
在增强特征VE后进一步附加一个哈希层,用于生成样本的实值哈希码其中b为哈希码的位数。
(4)分类层
由于更好的哈希码能促进更准确的分类,因此,在模型的最后加入了一个分类层用于监督学习。
分类层以样本哈希码为输入,计算样本哈希码的分类标签,也就是根据样本哈希码预测类别,得到预测的分类标签lp,利用分类标签lp与样本原始标签li之间的偏差,构造交叉熵损失函数,进行监督学习,更多细节在下面迭代训练的损失函数中说明。
(5)原型计算层,用于在每轮客户端本地训练结束后,计算本地特征类别原型,上传到服务器端,用于服务器端计算全局特征类别原型,用于下一轮训练,其中,计算本地特征类别原型的方法,具体为:
对于每一个参与联邦训练的客户端,在本地训练结束后,计算不同模态的本地特征类别原型,以图像模态为例,为第k个客户端图像模态的本地特征类别原型,表示如下:
其中,表示第i个样本的原始特征,li表示第i个样本的标签,c表示类别,nk表示第k个客户端的样本数量。
需要特别说明的是,是本轮客户端训练后的模型提取的样本特征,而不是本轮开始训练前提取的样本特征。
同理,根据上述方法能求出第k个客户端文本模态的本地特征类别原型为
二、全局跨模态检索模型的迭代训练
遵从联邦学习的方法,在每一轮的迭代训练中,先在每个客户端进行本地训练,得到每个客户端本地模型参数,然后服务器将各个客户端上传的模型参数进行加权融合,得到本轮的全局模型参数,用于下一轮的迭代训练,直至满足迭代要求,获得最优的全局跨模态检索模型。
不失一般性,以第t轮联邦学习为例,对本实施例中所设计的方法进行说明,具体步骤为:
(1)在联邦学习第t轮训练过程中,中央服务器首先向本轮参与训练的各个客户端发送第t-1轮训练后的全局模型参数以及全局特征类别原型Pglpbal。
对于第一轮训练,全局模型参数和全局特征类别原型Pglobal,是随机生成的。
(2)第k个客户端在接到中央服务器的数据后,用全局模型参数更新客户端本地的跨模态检索模型,然后开始进行本地的跨模态检索模型的训练:
将客户端本地的训练样本输入到更新后的跨模态检索模型中,通过特征提取层提取出训练样本的原始特征,通过特征增强层,将全局特征类别原型Pglobal和训练样本的标签语义信息相融合得到富含全局记忆信息的记忆特征VM,然后采用一个自适应选择器将原始特征与记忆特征相融合,得到样本的增强特征VE(对应模型结构中的特征增强层)。
将样本的增强特征VE送入哈希层,生成样本的实值哈希码。
将样本的实值哈希码送入分类层,得到预测的分类标签lp。
在训练过程中,损失函数会不断下降,当损失函数满足停止条件时,客户端停止本次迭代训练,所以,为了提升训练的效果和效率,需要设计合适的损失函数适当地惩罚模型。
本实施例本地模型训练使用的损失函数Lall,包括分类损失LCE和哈希损失Lhash:
分类损失LCE,利用分类标签lp与样本原始标签li之间的偏差构建交叉熵损失函数,用来生成更具有准确性的哈希码,分类损失LCE的具体公式如下:
其中,lp表示分类标签,li表示样本原始标签,|Dk|表示第k个客户端的样本数量。
哈希损失Lhash,可用现阶段跨模态检索中常用的相似性损失、量化损失、平衡损失等损失函数,损失函数Lall的具体公式为:
Lall=ηLhash+LCE
其中,η为超参数,用来控制哈希损失的权重。
当第t轮中各个参与训练的客户端都完成设定的本地迭代训练后,各客户端通过原型计算层,生成不同模态的本地特征类别原型,并将其与本地更新模型参数上传至中央服务器,客户端的本次迭代训练结束。
(3)在中央服务器上,使用新的在服务器端加权平均本地模型参数的方法,计算第(t+1)轮的全局模型参数,并基于客户端上传的本地特征类别原型,更新全局特征类别原型Pglobal。
新的在服务器端加权平均本地模型参数的方法,具体为:
在中央服务器对各客户端更新的本地模型进行聚合得到新一轮的全局模型时,传统的加权聚合方式可表示为下式:
其中,为第(t+1)轮的全局模型参数,/>为第k个客户端第t轮的本地模型参数。
为缓解各客户端数据分布不平衡带来的模型漂移问题,本实施例设计了一种新的加权聚合方式,通过利用各客户端本地特征类别原型和全局特征类别原型间的负对数似然函数来度量两者间的相似性,公式如下所示:
其中,表示第k个客户端的原型相似性,/> 为第k个客户端的本地特征类别原型矩阵,S为一个对角线元素为1的相似性矩阵。
当负对数似然函数的值越小时,代表第K个客户端的本地特征类别原型与全局特征类别原型间的相似性越大,则该本地样本的数据分布更符合全局数据分布,此时给该本地跨模态检索模型赋予更高的权重,有助于生成更符合全局数据分布的跨模态检索模型,提高了全局模型的泛化能力,公式如下所示:
其中,Simk表示第k个客户端初始的相似性权重,表示第k个客户端的原型相似性,K表示客户端的个数。
当计算出K个参与训练的客户端的相似性权重后,对其进行归一化能得到第k个客户端最终的相似性权重
此外,不仅考虑到了相似性权重,还考虑了各客户端样本中类别数量的权重,公式如下所示:
其中,l_Numk为第k个客户端中所含有的样本的类别数量总数,K表示客户端的个数,本发明设计的样本类别数量权重在单标签数据集和多标签数据集上都能实现,在多标签数据集情况下,样本类别数量越多的客户端会拥有更多的类别或样本,能训练出更具有泛化能力的跨模态检索模型;在单标签数据集情况下,样本类别数量即客户端样本数量;最终的客户端模型权重表示为如下形式:
其中,Sk表示第k个客户端的相似性权重,Lk表示第k个客户端的类别数量的权重,K表示客户端的个数。
在客户端模型权重的基础上,通过全局模型加权聚合,得到第(t+1)轮的全局模型参数具体公式如下:
其中,为第(t+1)轮的全局模型参数,ak表示第k个客户端的模型权重,为第k个客户端第t轮的本地模型参数,K表示客户端的个数。
这种加权聚合机制能帮助在每一轮中学习一个更具有全局泛化能力的跨模态检索模型,缓解在联邦学习中广泛存在的数据非独立同分布带来的模型偏移问题。
基于客户端上传的本地特征类别原型,更新全局特征类别原型Pglobal的具体方法为:
在每一轮各客户端本地训练结束后,客户端会上传各自的本地特征类别原型到服务器端,在中央服务器端对其进行聚合,以图像模态为例,图像模态的全局特征类别原型Pglobal-img表示为:
其中,K表示客户端的个数,|D|表示K个客户端的样本总量,|Dk|表示第k个客户端图像模态的样本数量。
通过上述公式能计算出图像模态的全局特征类别原型Pglobal-img。同理,根据上述方法能求出文本模态的全局特征类别原型为Pglobal-txt。
将图像和文本模态的全局特征类别原型中的信息充分结合,得到最终的全局特征类别原型,用于下一轮的迭代训练,具体公式为:
在本实施例设计的方法中,每一轮迭代训练都会更新得到一个全局特征类别原型并用于下一轮的训练。
实施例二
本实施例公开了一种面向不平衡数据的联邦跨模态检索系统;
如图2所示,一种面向不平衡数据的联邦跨模态检索系统,包括编码单元和检索单元;
编码单元,被配置为:基于训练后的全局跨模态检索模型,对待查询目标的查询样本进行编码,获得查询哈希码;
检索单元,被配置为:对所述查询哈希码与检索数据集中的数据哈希码进行相似度计算,基于所述相似度,获得检索结果;
其中,所述全局跨模态检索模型是基于联邦学习训练得到的,在每一轮迭代训练中,基于上一轮输出的全局模型参数和全局特征类别原型,将全局特征类别原型嵌入到各客户端本地的样本特征,得到样本的增强特征,利用增强特征生成样本的哈希码,利用哈希码构建损失函数进行本轮训练;当参与训练的客户端都完成本轮迭代训练后,采用在服务器端加权平均本地模型参数的方法,得到下一轮的全局模型参数,并更新全局特征类别原型。
以上所述仅为本发明的优选实施例而已,并不用于限制本发明,对于本领域的技术人员来说,本发明可以有各种更改和变化。凡在本发明的精神和原则之内,所作的任何修改、等同替换、改进等,均应包含在本发明的保护范围之内。
Claims (6)
1.一种面向不平衡数据的联邦跨模态检索方法,其特征在于,包括:
基于训练后的全局跨模态检索模型,对待查询目标的查询样本进行编码,获得查询哈希码;
对所述查询哈希码与检索数据集中的数据哈希码进行相似度计算,基于所述相似度,获得检索结果;
其中,所述全局跨模态检索模型是基于联邦学习训练得到的,在每一轮迭代训练中,基于上一轮输出的全局模型参数和全局特征类别原型,将全局特征类别原型嵌入到各客户端本地的样本特征,得到样本的增强特征,利用增强特征生成样本的哈希码,利用哈希码构建损失函数进行本轮训练;当参与训练的客户端都完成本轮迭代训练后,采用在服务器端加权平均本地模型参数的方法,得到下一轮的全局模型参数,并更新全局特征类别原型;
所述全局跨模态检索模型,包括特征提取层、特征增强层、哈希层、分类层和原型计算层;
所述将全局特征类别原型嵌入到各客户端本地的样本特征,得到样本的增强特征,利用增强特征生成样本的哈希码,利用哈希码构建损失函数进行本轮训练,具体为用全局模型参数更新客户端本地的跨模态检索模型,然后开始进行本地的跨模态检索模型的训练,步骤为:
将客户端本地的训练样本输入到更新后的跨模态检索模型中,通过特征提取层提取出训练样本的原始特征,通过特征增强层,将全局特征类别原型和训练样本的标签语义信息相融合得到富含全局记忆信息的记忆特征,然后采用一个自适应选择器将原始特征与记忆特征相融合,得到样本的增强特征;
将样本的增强特征送入哈希层,生成样本的实值哈希码;
将样本的实值哈希码送入分类层,得到预测的分类标签lp;
本地模型训练使用的损失函数Lall,包括分类损失LCE和哈希损失Lhash:
分类损失LCE的具体公式如下:
其中,lp表示分类标签,li表示样本原始标签,|Dk|表示第k个客户端的样本数量;
哈希损失Lhash,可用现阶段跨模态检索中常用的相似性损失、量化损失、平衡损失等损失函数,损失函数Lall的具体公式为:
Lall=ηLhash+LCE
其中,η为超参数,用来控制哈希损失的权重;
所述采用在服务器端加权平均本地模型参数的方法,得到下一轮的全局模型参数,具体为:
在客户端模型权重的基础上,通过全局模型加权聚合,得到下一轮的全局模型参数,具体公式如下:
其中,为下一轮的全局模型参数,ak表示第k个客户端的模型权重,/>为第k个客户端当前轮的本地模型参数,K表示客户端的个数;
其中,Sk表示第k个客户端的相似性权重,Lk表示第k个客户端的类别数量的权重,K表示客户端的个数;
其中,l_Numk为第k个客户端中所含有的样本的类别数量总数,K表示客户端的个数;
其中,Simk表示第k个客户端初始的相似性权重,LSimk表示第k个客户端的原型相似性,K表示客户端的个数;
其中,为第k个客户端的本地特征类别原型矩阵,S为一个对角线元素为1的相似性矩阵;
所述更新全局特征类别原型,具体为:
(1)客户端的原型计算层,基于本轮训练后的本地模型参数,提取样本特征,计算本地特征类别原型表示为:
其中,表示第i个样本的样本特征,li表示第i个样本的标签,c表示类别,nk表示第k个客户端的样本数量;
(2)服务器端,基于各个客户端的本地特征类别原型,计算全局特征类别原型Pglobal,表示为:
其中,K表示客户端的个数,|D|表示K个客户端的样本总量,|Dk|表示第k个客户端的样本数量。
2.如权利要求1所述的一种面向不平衡数据的联邦跨模态检索方法,其特征在于,所述特征提取层,用于各客户端基于上一轮训练后的全局模型参数,本地提取样本特征;
对于图像模态,利用卷积神经网络来提取原始视觉特征,对于文本模态,利用两个全连接层提取原始文本特征。
3.如权利要求1所述的一种面向不平衡数据的联邦跨模态检索方法,其特征在于,所述特征增强层,用于基于提取的样本特征和上一轮的全局特征类别原型,计算样本的增强特征,具体步骤为:
将样本标签与全局特征类别原型相融合,得到富含全局记忆信息的记忆特征;
引入一个自适应选择器,将样本特征与记忆特征相融合,得到样本的增强特征。
4.如权利要求1所述的一种面向不平衡数据的联邦跨模态检索方法,其特征在于,所述分类层,以样本哈希码为输入,计算样本哈希码的分类标签,利用分类标签与样本原始标签之间的偏差,构造交叉熵损失函数,进行监督学习。
5.如权利要求1所述的一种面向不平衡数据的联邦跨模态检索方法,其特征在于,所述在服务器端加权平均本地模型参数的方法,是通过相似性权重和类别数量权重方法得出各客户端模型权重,并使用该权重进行加权聚合得到下一轮的全局模型参数。
6.一种面向不平衡数据的联邦跨模态检索系统,其特征在于,包括编码单元和检索单元;
编码单元,被配置为:基于训练后的全局跨模态检索模型,对待查询目标的查询样本进行编码,获得查询哈希码;
检索单元,被配置为:对所述查询哈希码与检索数据集中的数据哈希码进行相似度计算,基于所述相似度,获得检索结果;
其中,所述全局跨模态检索模型是基于联邦学习训练得到的,在每一轮迭代训练中,基于上一轮输出的全局模型参数和全局特征类别原型,将全局特征类别原型嵌入到各客户端本地的样本特征,得到样本的增强特征,利用增强特征生成样本的哈希码,利用哈希码构建损失函数进行本轮训练;当参与训练的客户端都完成本轮迭代训练后,采用在服务器端加权平均本地模型参数的方法,得到下一轮的全局模型参数,并更新全局特征类别原型;
所述全局跨模态检索模型,包括特征提取层、特征增强层、哈希层、分类层和原型计算层;
所述将全局特征类别原型嵌入到各客户端本地的样本特征,得到样本的增强特征,利用增强特征生成样本的哈希码,利用哈希码构建损失函数进行本轮训练,具体为用全局模型参数更新客户端本地的跨模态检索模型,然后开始进行本地的跨模态检索模型的训练,步骤为:
将客户端本地的训练样本输入到更新后的跨模态检索模型中,通过特征提取层提取出训练样本的原始特征,通过特征增强层,将全局特征类别原型和训练样本的标签语义信息相融合得到富含全局记忆信息的记忆特征,然后采用一个自适应选择器将原始特征与记忆特征相融合,得到样本的增强特征;
将样本的增强特征送入哈希层,生成样本的实值哈希码;
将样本的实值哈希码送入分类层,得到预测的分类标签lp;
本地模型训练使用的损失函数Lall,包括分类损失LCE和哈希损失Lhash:
分类损失LCE的具体公式如下:
其中,lp表示分类标签,li表示样本原始标签,|Dk|表示第k个客户端的样本数量;
哈希损失Lhash,可用现阶段跨模态检索中常用的相似性损失、量化损失、平衡损失等损失函数,损失函数Lall的具体公式为:
Lall=ηLhash+LCE
其中,η为超参数,用来控制哈希损失的权重;
所述采用在服务器端加权平均本地模型参数的方法,得到下一轮的全局模型参数,具体为:
在客户端模型权重的基础上,通过全局模型加权聚合,得到下一轮的全局模型参数,具体公式如下:
其中,为下一轮的全局模型参数,ak表示第k个客户端的模型权重,/>为第k个客户端当前轮的本地模型参数,K表示客户端的个数;
其中,Sk表示第k个客户端的相似性权重,Lk表示第k个客户端的类别数量的权重,K表示客户端的个数;
其中,l_Numk为第k个客户端中所含有的样本的类别数量总数,K表示客户端的个数;
其中,Simk表示第k个客户端初始的相似性权重,LSimk表示第k个客户端的原型相似性,K表示客户端的个数;
其中,为第k个客户端的本地特征类别原型矩阵,S为一个对角线元素为1的相似性矩阵;
所述更新全局特征类别原型,具体为:
(1)客户端的原型计算层,基于本轮训练后的本地模型参数,提取样本特征,计算本地特征类别原型表示为:
其中,表示第i个样本的样本特征,li表示第i个样本的标签,c表示类别,nk表示第k个客户端的样本数量;
(2)服务器端,基于各个客户端的本地特征类别原型,计算全局特征类别原型Pglobal,表示为:
其中,K表示客户端的个数,|D|表示K个客户端的样本总量,|Dk|表示第k个客户端的样本数量。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310523580.1A CN116244484B (zh) | 2023-05-11 | 2023-05-11 | 一种面向不平衡数据的联邦跨模态检索方法及系统 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310523580.1A CN116244484B (zh) | 2023-05-11 | 2023-05-11 | 一种面向不平衡数据的联邦跨模态检索方法及系统 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN116244484A CN116244484A (zh) | 2023-06-09 |
CN116244484B true CN116244484B (zh) | 2023-08-08 |
Family
ID=86629883
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202310523580.1A Active CN116244484B (zh) | 2023-05-11 | 2023-05-11 | 一种面向不平衡数据的联邦跨模态检索方法及系统 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN116244484B (zh) |
Families Citing this family (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117708681B (zh) * | 2024-02-06 | 2024-04-26 | 南京邮电大学 | 基于结构图指导的个性化联邦脑电信号分类方法及系统 |
Citations (10)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN109299216A (zh) * | 2018-10-29 | 2019-02-01 | 山东师范大学 | 一种融合监督信息的跨模态哈希检索方法和系统 |
CN110059198A (zh) * | 2019-04-08 | 2019-07-26 | 浙江大学 | 一种基于相似性保持的跨模态数据的离散哈希检索方法 |
WO2019231624A2 (en) * | 2018-05-30 | 2019-12-05 | Quantum-Si Incorporated | Methods and apparatus for multi-modal prediction using a trained statistical model |
WO2022104540A1 (zh) * | 2020-11-17 | 2022-05-27 | 深圳大学 | 一种跨模态哈希检索方法、终端设备及存储介质 |
WO2022155994A1 (zh) * | 2021-01-21 | 2022-07-28 | 深圳大学 | 基于注意力的深度跨模态哈希检索方法、装置及相关设备 |
CN114925238A (zh) * | 2022-07-20 | 2022-08-19 | 山东大学 | 一种基于联邦学习的视频片段检索方法及系统 |
CN114943017A (zh) * | 2022-06-20 | 2022-08-26 | 昆明理工大学 | 一种基于相似性零样本哈希的跨模态检索方法 |
CN115080801A (zh) * | 2022-07-22 | 2022-09-20 | 山东大学 | 基于联邦学习和数据二进制表示的跨模态检索方法及系统 |
CN115686868A (zh) * | 2022-12-28 | 2023-02-03 | 中南大学 | 一种基于联邦哈希学习的面向跨节点多模态检索方法 |
CN115795065A (zh) * | 2022-11-04 | 2023-03-14 | 山东建筑大学 | 基于带权哈希码的多媒体数据跨模态检索方法及系统 |
-
2023
- 2023-05-11 CN CN202310523580.1A patent/CN116244484B/zh active Active
Patent Citations (10)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
WO2019231624A2 (en) * | 2018-05-30 | 2019-12-05 | Quantum-Si Incorporated | Methods and apparatus for multi-modal prediction using a trained statistical model |
CN109299216A (zh) * | 2018-10-29 | 2019-02-01 | 山东师范大学 | 一种融合监督信息的跨模态哈希检索方法和系统 |
CN110059198A (zh) * | 2019-04-08 | 2019-07-26 | 浙江大学 | 一种基于相似性保持的跨模态数据的离散哈希检索方法 |
WO2022104540A1 (zh) * | 2020-11-17 | 2022-05-27 | 深圳大学 | 一种跨模态哈希检索方法、终端设备及存储介质 |
WO2022155994A1 (zh) * | 2021-01-21 | 2022-07-28 | 深圳大学 | 基于注意力的深度跨模态哈希检索方法、装置及相关设备 |
CN114943017A (zh) * | 2022-06-20 | 2022-08-26 | 昆明理工大学 | 一种基于相似性零样本哈希的跨模态检索方法 |
CN114925238A (zh) * | 2022-07-20 | 2022-08-19 | 山东大学 | 一种基于联邦学习的视频片段检索方法及系统 |
CN115080801A (zh) * | 2022-07-22 | 2022-09-20 | 山东大学 | 基于联邦学习和数据二进制表示的跨模态检索方法及系统 |
CN115795065A (zh) * | 2022-11-04 | 2023-03-14 | 山东建筑大学 | 基于带权哈希码的多媒体数据跨模态检索方法及系统 |
CN115686868A (zh) * | 2022-12-28 | 2023-02-03 | 中南大学 | 一种基于联邦哈希学习的面向跨节点多模态检索方法 |
Non-Patent Citations (1)
Title |
---|
基于哈希学习的大规模媒体检索研究;罗昕;《中国优秀博士学位论文全文数据库》;全文 * |
Also Published As
Publication number | Publication date |
---|---|
CN116244484A (zh) | 2023-06-09 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN111537945B (zh) | 基于联邦学习的智能电表故障诊断方法及设备 | |
CN112508085B (zh) | 基于感知神经网络的社交网络链路预测方法 | |
CN104318340B (zh) | 基于文本履历信息的信息可视化方法及智能可视分析系统 | |
CN113177132B (zh) | 基于联合语义矩阵的深度跨模态哈希的图像检索方法 | |
CN113326377B (zh) | 一种基于企业关联关系的人名消歧方法及系统 | |
CN111898703B (zh) | 多标签视频分类方法、模型训练方法、装置及介质 | |
CN113821670B (zh) | 图像检索方法、装置、设备及计算机可读存储介质 | |
CN111950622B (zh) | 基于人工智能的行为预测方法、装置、终端及存储介质 | |
CN107194422A (zh) | 一种结合正反向实例的卷积神经网络关系分类方法 | |
CN107947921A (zh) | 基于递归神经网络和概率上下文无关文法的密码生成系统 | |
CN111026887B (zh) | 一种跨媒体检索的方法及系统 | |
CN114580663A (zh) | 面向数据非独立同分布场景的联邦学习方法及系统 | |
CN113822315A (zh) | 属性图的处理方法、装置、电子设备及可读存储介质 | |
CN115114409B (zh) | 一种基于软参数共享的民航不安全事件联合抽取方法 | |
CN116244484B (zh) | 一种面向不平衡数据的联邦跨模态检索方法及系统 | |
CN115080801A (zh) | 基于联邦学习和数据二进制表示的跨模态检索方法及系统 | |
CN112364889A (zh) | 一种基于云平台的制造资源智能匹配系统 | |
CN113887694A (zh) | 一种注意力机制下基于特征表征的点击率预估模型 | |
CN116362325A (zh) | 一种基于模型压缩的电力图像识别模型轻量化应用方法 | |
CN113705242B (zh) | 面向教育咨询服务的智能语义匹配方法和装置 | |
CN117371481A (zh) | 一种基于元学习的神经网络模型检索方法 | |
CN116777646A (zh) | 基于人工智能的风险识别方法、装置、设备及存储介质 | |
CN116340516A (zh) | 实体关系的聚类提取方法、装置、设备及存储介质 | |
CN114997920A (zh) | 广告文案生成方法及其装置、设备、介质、产品 | |
CN114398980A (zh) | 跨模态哈希模型的训练方法、编码方法、装置及电子设备 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |