CN116050508A - 神经网络训练方法以及装置 - Google Patents
神经网络训练方法以及装置 Download PDFInfo
- Publication number
- CN116050508A CN116050508A CN202111261376.4A CN202111261376A CN116050508A CN 116050508 A CN116050508 A CN 116050508A CN 202111261376 A CN202111261376 A CN 202111261376A CN 116050508 A CN116050508 A CN 116050508A
- Authority
- CN
- China
- Prior art keywords
- image
- loss
- network
- hash
- neural network
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000013528 artificial neural network Methods 0.000 title claims abstract description 123
- 238000000034 method Methods 0.000 title claims abstract description 102
- 238000012549 training Methods 0.000 title claims abstract description 43
- 238000013139 quantization Methods 0.000 claims abstract description 102
- 238000000605 extraction Methods 0.000 claims abstract description 18
- 238000005457 optimization Methods 0.000 claims description 10
- 230000004044 response Effects 0.000 claims description 8
- 239000013598 vector Substances 0.000 description 23
- 238000012545 processing Methods 0.000 description 19
- 230000008569 process Effects 0.000 description 15
- 238000012512 characterization method Methods 0.000 description 14
- 238000005516 engineering process Methods 0.000 description 14
- 230000000694 effects Effects 0.000 description 11
- 238000013473 artificial intelligence Methods 0.000 description 10
- 238000003860 storage Methods 0.000 description 10
- 238000013135 deep learning Methods 0.000 description 8
- 238000004422 calculation algorithm Methods 0.000 description 7
- 230000006870 function Effects 0.000 description 7
- 238000007726 management method Methods 0.000 description 7
- 230000005540 biological transmission Effects 0.000 description 6
- 230000009467 reduction Effects 0.000 description 6
- 230000006978 adaptation Effects 0.000 description 5
- 238000010586 diagram Methods 0.000 description 5
- 230000014509 gene expression Effects 0.000 description 5
- 238000010801 machine learning Methods 0.000 description 5
- 238000005259 measurement Methods 0.000 description 5
- 230000008901 benefit Effects 0.000 description 4
- 238000004590 computer program Methods 0.000 description 4
- 238000012544 monitoring process Methods 0.000 description 4
- 238000013145 classification model Methods 0.000 description 3
- 238000004891 communication Methods 0.000 description 3
- 241000282472 Canis lupus familiaris Species 0.000 description 2
- 230000002776 aggregation Effects 0.000 description 2
- 238000004220 aggregation Methods 0.000 description 2
- 238000013459 approach Methods 0.000 description 2
- 238000012550 audit Methods 0.000 description 2
- 238000013500 data storage Methods 0.000 description 2
- 238000013461 design Methods 0.000 description 2
- 238000011478 gradient descent method Methods 0.000 description 2
- 230000003993 interaction Effects 0.000 description 2
- 238000012423 maintenance Methods 0.000 description 2
- 238000012986 modification Methods 0.000 description 2
- 230000004048 modification Effects 0.000 description 2
- 230000011218 segmentation Effects 0.000 description 2
- 241000271566 Aves Species 0.000 description 1
- 241000282326 Felis catus Species 0.000 description 1
- 230000009471 action Effects 0.000 description 1
- 238000004458 analytical method Methods 0.000 description 1
- 230000001174 ascending effect Effects 0.000 description 1
- 238000013475 authorization Methods 0.000 description 1
- 230000006399 behavior Effects 0.000 description 1
- 238000004364 calculation method Methods 0.000 description 1
- 230000015556 catabolic process Effects 0.000 description 1
- 238000013527 convolutional neural network Methods 0.000 description 1
- 238000000354 decomposition reaction Methods 0.000 description 1
- 238000013136 deep learning model Methods 0.000 description 1
- 238000006731 degradation reaction Methods 0.000 description 1
- 238000009826 distribution Methods 0.000 description 1
- 238000009499 grossing Methods 0.000 description 1
- 230000003862 health status Effects 0.000 description 1
- 230000006872 improvement Effects 0.000 description 1
- 230000006698 induction Effects 0.000 description 1
- 239000004973 liquid crystal related substance Substances 0.000 description 1
- 238000013507 mapping Methods 0.000 description 1
- 230000007246 mechanism Effects 0.000 description 1
- 238000010295 mobile communication Methods 0.000 description 1
- 230000008450 motivation Effects 0.000 description 1
- 238000003058 natural language processing Methods 0.000 description 1
- 230000003287 optical effect Effects 0.000 description 1
- 238000011176 pooling Methods 0.000 description 1
- 238000007781 pre-processing Methods 0.000 description 1
- 230000000750 progressive effect Effects 0.000 description 1
- 230000002787 reinforcement Effects 0.000 description 1
- 238000011160 research Methods 0.000 description 1
- 238000012954 risk control Methods 0.000 description 1
- 238000000638 solvent extraction Methods 0.000 description 1
- 238000013526 transfer learning Methods 0.000 description 1
- 230000000007 visual effect Effects 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- General Health & Medical Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Evolutionary Computation (AREA)
- Artificial Intelligence (AREA)
- Molecular Biology (AREA)
- Computing Systems (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Mathematical Physics (AREA)
- Software Systems (AREA)
- Health & Medical Sciences (AREA)
- Image Analysis (AREA)
Abstract
本申请实施例公开了神经网络训练方法以及装置,上述方法包括获取第一图像组,上述第一图像组为至少一个第一样本图像组中任一图像组,上述第一样本图像组包括第一图像、第二图像和第三图像,上述第二图像与上述第一图像的类别相同,上述第三图像与上述第一图像的类别不相同;基于第一网络对第一图像组进行哈希特征提取得到哈希特征组;基于第二网络对哈希特征组进行量化得到量化结果组,上述第二网络的参数表征类别对应的码本;确定神经网路损失,基于上述神经网络损失,优化上述神经网络参数;上述神经网络损失包括基于上述哈希特征组确定的第一损失、基于上述量化结果组确定的第二损失。本申请实施例提升基于码本的图像检索准确度。
Description
技术领域
本申请实施例涉及人工智能技术领域,尤其涉及神经网络训练方法以及装置。
背景技术
乘积量化算法是在矢量量化的基础上发展而来的一种检索方法,可以用于加快图像的检索速度,但是相关技术中乘积量化算法存在特征割裂的问题,也就是说,相似的图像产生的特征可能被量化为不同的编码或者编码的差异大,这影响了基于乘积量化进行图像检索的准确度,尤其是在端到端量化场景中,性能下降尤为明显。
发明内容
为了提升基于乘积量化进行图像检索的准确度,提升乘积量化对于图像的表征能力,本申请实施例提供神经网络训练方法以及装置。
一方面,本申请实施例提供了一种神经网络训练方法,所述方法包括:
获取第一图像组,所述第一图像组为至少一个第一样本图像组中任一图像组,所述第一样本图像组包括第一图像、第二图像和第三图像,所述第二图像与所述第一图像的类别相同,所述第三图像与所述第一图像的类别不相同;
基于所述第一网络对所述第一图像组进行哈希特征提取,得到哈希特征组;
基于所述第二网络对所述哈希特征组进行量化,得到量化结果组,所述第二网络的参数表征类别对应的码本;
确定神经网络损失,以及基于所述神经网络损失,优化所述神经网络参数;
其中,所述神经网络损失包括基于所述哈希特征组确定的第一损失,以及基于所述量化结果组确定的第二损失;所述第一损失和所述第二损失均为三元组损失。
另一方面,本申请实施例提供一种神经网络训练装置,所述装置包括:
图像组获取模块,用于获取第一图像组,所述第一图像组为至少一个第一样本图像组中任一图像组,所述第一样本图像组包括第一图像、第二图像和第三图像,所述第二图像与所述第一图像的类别相同,所述第三图像与所述第一图像的类别不相同;
哈希提取模块,用于基于所述第一网络对所述第一图像组进行哈希特征提取,得到哈希特征组;
量化模块,用于基于所述第二网络对所述哈希特征组进行量化,得到量化结果组,所述第二网络的参数表征类别对应的码本;
优化模块,用于确定神经网络损失,以及基于所述神经网络损失,优化所述神经网络参数;
其中,所述神经网络损失包括基于所述哈希特征组确定的第一损失,以及基于所述量化结果组确定的第二损失;所述第一损失和所述第二损失均为三元组损失。
另一方面,本申请实施例提供了一种计算机可读存储介质,所述计算机可读存储介质中存储有至少一条指令或至少一段程序,所述至少一条指令或至少一段程序由处理器加载并执行以实现上述的一种神经网络训练方法。
另一方面,本申请实施例提供了一种电子设备,包括至少一个处理器,以及与所述至少一个处理器通信连接的存储器;其中,所述存储器存储有可被所述至少一个处理器执行的指令,所述至少一个处理器通过执行所述存储器存储的指令实现上述的一种神经网络训练方法。
另一方面,本申请实施例提供了一种计算机程序产品,包括计算机程序或指令,该计算机程序或指令被处理器执行时实现上述的一种神经网络训练方法。
本申请实施例提供了神经网络训练方法、装置、存储介质及设备。本申请实施例提供的神经网络训练方法通过在第一网络上基于多标签的多原型产生乘积量化码本的语义哈希特征的深度量化手段,参考多标签数据原型高效得到码本,避免了码本表征不明确进而产生特征割裂的问题。进一步地,还可以借助多标签语义原型产生码本,使得哈希特征分布在这些多标签原型中心对应的码本附近,可提升码本的聚合效果,最终提升基于码本的图像检索准确度。
附图说明
为了更清楚地说明本申请实施例或相关技术中的技术方案和优点,下面将对实施例或相关技术描述中所需要使用的附图作简单的介绍,显而易见地,下面描述中的附图仅仅是本申请实施例的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其它附图。
图1是本申请实施例提供的神经网络训练方法流程示意图;
图2是本申请实施例提供的神经网络的结构示意图;
图3是本申请实施例提供的第二网络参数的更新方法流程示意图;
图4是本申请实施例提供的基于上述第二样本图像组更新第二网络的参数方法流程示意图;
图5是本申请实施例提供的根据聚类处理结果更新第二网络的参数方法流程示意图;
图6是本申请实施例提供的基于各标签对应的原型平均距离更新第二网络的参数方法流程示意图;
图7是本申请实施例提供的神经网络训练装置的框图;
图8是本申请实施例提供的一种用于实现本申请实施例所提供的方法的设备的硬件结构示意图。
具体实施方式
下面将结合本申请实施例中的附图,对本申请实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅仅是本申请实施例一部分实施例,而不是全部的实施例。基于本申请实施例中的实施例,本领域普通技术人员在没有做出创造性劳动的前提下所获得的所有其他实施例,都属于本申请实施例保护的范围。
需要说明的是,本申请实施例的说明书和权利要求书及上述附图中的术语“第一”、“第二”等是用于区别类似的对象,而不必用于描述特定的顺序或先后次序。应该理解这样使用的数据在适当情况下可以互换,以便这里描述的本申请实施例的实施例能够以除了在这里图示或描述的那些以外的顺序实施。此外,术语“包括”和“具有”以及他们的任何变形,意图在于覆盖不排他的包含,例如,包含了一系列步骤或单元的过程、方法、系统、产品或服务器不必限于清楚地列出的那些步骤或单元,而是可包括没有清楚地列出的或对于这些过程、方法、产品或设备固有的其它步骤或单元。
为了使本申请实施例公开的目的、技术方案及优点更加清楚明白,以下结合附图及实施例,对本申请实施例进行进一步详细说明。应当理解,此处描述的具体实施例仅仅用以解释本申请实施例,并不用于限定本申请实施例。
以下,术语“第一”、“第二”仅用于描述目的,而不能理解为指示或暗示相对重要性或者隐含指明所指示的技术特征的数量。由此,限定有“第一”、“第二”的特征可以明示或者隐含地包括一个或者更多个该特征。在本实施例的描述中,除非另有说明,“多个”的含义是两个或两个以上。为了便于理解本申请实施例上述的技术方案及其产生的技术效果,本申请实施例首先对于相关专业名词进行解释:
人工智能(Artificial Intelligence,AI)是利用数字计算机或者数字计算机控制的机器模拟、延伸和扩展人的智能,感知环境、获取知识并使用知识获得最佳结果的理论、方法、技术及应用系统。换句话说,人工智能是计算机科学的一个综合技术,它企图了解智能的实质,并生产出一种新的能以人类智能相似的方式做出反应的智能机器。人工智能也就是研究各种智能机器的设计原理与实现方法,使机器具有感知、推理与决策的功能。
人工智能技术是一门综合学科,涉及领域广泛,既有硬件层面的技术也有软件层面的技术。人工智能基础技术一般包括如传感器、专用人工智能芯片、云计算、分布式存储、大数据处理技术、操作/交互系统、机电一体化等技术。人工智能软件技术主要包括计算机视觉技术、语音处理技术、自然语言处理技术以及机器学习/深度学习、自动驾驶、智慧交通等几大方向。
机器学习(Machine Learning,ML)是一门多领域交叉学科,涉及概率论、统计学、逼近论、凸分析、算法复杂度理论等多门学科。专门研究计算机怎样模拟或实现人类的学习行为,以获取新的知识或技能,重新组织已有的知识结构使之不断改善自身的性能。机器学习是人工智能的核心,是使计算机具有智能的根本途径,其应用遍及人工智能的各个领域。机器学习和深度学习通常包括人工神经网络、置信网络、强化学习、迁移学习、归纳学习、式教学习等技术。
端到端学习:传统的图像识别问题往往通过分治法将其分解为预处理,特征提取和选择,分类器设计等若干步骤。分治法的动机是将图像识别的母问题分解为简单、可控且清晰的若干小的子问题。不过分步解决子问题时,尽管可以在子问题上得到最优解,但子问题上的最优解并不意味着就能得到全局问题的最后解。深度学习提供了一种端到端的学习范式,整个学习的流程并不进行人为的子问题划分,而是完全交给深度学习模型直接学习从原始数据到期望输出的映射。
图像识别:类别级别的识别,不考虑对象的特定实例,仅考虑对象的类别(如人、狗、猫、鸟等)进行的识别并给出对象所属类别。一个典型的例子是大型通用物体识别开源数据集ImageNet中的识别任务,识别出某个物体是200个类别中的哪一个。
图像多标签识别:通过计算机识别出图像是否具有指定属性标签的组合。一张图像可能具有多个属性,多标签识别任务是判断某张图具有哪些预设的属性标签。
ImageNet:大型通用物体识别开源数据集。
ImageNet预训练模型:基于ImageNet训练一个深度学习网络模型,得到该模型的参数权重即为ImageNet预训练模型。
哈希模型:是一种学习二值嵌入(Embedding)特征的模型及方法,用于替代传统浮点型的Embedding特征进行检索,这种哈希的过程也称为特征量化,即把浮点特征量化为二值特征。
基于向量量化的检索技术:通过对Embedding特征向量划分N(N为大于1的正整数)个不重叠的区域、每个区域用一个向量代表(常用聚类中心),从而把特征量化到N个量化向量。检索时先通过召回对应的量化向量,再对比Embedding特征向量对应的图像与量化向量下的图像的相似性,根据该相似性确定检索结果。
乘积量化:Product Quantization,PQ量化,意思是指把原来的向量空间分解为若干个低维向量空间的笛卡尔积,并利用聚类算法对分解得到的低维向量空间分别做量化。这样每个向量就能由多个低维空间的量化特征组合表示。把原始D维向量分成N组,那么每组就是D/N维的子向量,各自用聚类算法学习到一个码本,然后这些码本的笛卡尔积就是原始D维向量对应的码本,码本可以用于进行基于PQ量化的图像检索。
基于乘积量化的检索技术:首先把一个D维的向量分成N个子空间,每个空间特征维度是D/N,对每个子空间分别做聚类,聚类为K类,得到N个空间中聚类中心,形成码本。D、N、K都是正整数。检索时对待检索图像对应的特征分成N个维度,然后每个维度分别在该维度下的K个中心找到最近的中心,召回该中心下的所有样本;N维度下的所有样本分别与待检索图像对应的特征在各维度下计算距离,得到N个距离,对上述N个距离求和得到待检索图像与召回样本的距离,按照距离升序排序,根据排序结果选择排序在前的若干个样本序作为检索结果。该方法在分N个子空间后,每个子空间都分别进行量化,一种简单的量化方法是对子空间中D/N维的特征进行符号量化,即当某一维度特征大于0则量化为1,小于0则量化为0,如对[-1,1,0.5,-0.2]特征向量,量化后得到[0,1,1,0]编码,该方法不再需要计算距离,可以直接根据量化出的编码差异判断两个被量化的特征之间的距离,提升检索速度。
深度学习乘积量化:深度学习进行PQ量化的技术,优点是可以同时学习哈希特征以及PQ量化特征。
Triplet Loss,三元组损失,是深度学习中的一种损失函数,用于训练差异性较小的样本,通过优化锚样本对应的特征与正样本对应的特征的距离小于锚样本对应的特征与负样本对应的特征的距离,实现样本的相似性计算。
相关技术中可以基于深度神经网络训练图像哈希特征的过程中常借助度量学习方法,基于Triplet Loss优化神经网络以使得不相似样本产生的特征的距离相较于相似样本产生的特征的距离大于某个边界值。应用中常利用乘积量化PQ对哈希特征进行计算多个子空间聚类中心,利用这些聚类中心作为索引进行相似度检索。这种方法下哈希特征学习与PQ量化过程是二阶段的、PQ量化过程是无监督的,从而造成PQ量化在相似度度量中不准确,如把一个相似样本对分别量化到不同PQ码本中,这就造成了样本对的割裂。具体来说,PQ量化从特征出发直接做子空间的分割以及每个空间的切分,容易产生相似的样本由于特征相似度不足(如两个相似样本的特征向量分别是[-1,1,0.5,-0.03],[-1,1,0.5,0.01],由特征向量直接做被量化则得到[0,1,1,0]和[0,1,1,1]两个编码,而不是量化到相同的编码,这就导致了样本对的割裂。样本对的割裂降低了码本对于类别的表征能力,降低了基于码本进行检索的准确度,并且量化性能的下降在端到端训练中尤为明显。
有鉴于此,本申请提供一种神经网络训练方法,该方法在神经网络的训练过程中将PQ量化能力和哈希特征提取能力一并进行训练,从而避免了传统二阶段量化没有直接相似样本度量能力,通过基于码本的重建特征进行码本相似度度量学习实现了有监督的码本学习,并且通过多标签获取丰富的图像原型,根据图像原型优化码本,最终使得PQ量化能力与哈希特征提取能力均显著提升,强化码本的语义表征能力,提升对于图像的语义信息的表达能力,进而提升针对图像的语义相似度检索的准确度。
本申请实施例并不限定上述图像的来源,比如上述图像可以来自于用户终端,上述用户终端包括但不限于手机、电脑、智能语音交互设备、智能家电、车载终端等。
本申请实施例所提供的方法还可以涉及区块链,即本申请实施例提供的方法可以基于区块链实现,或者本申请实施例提供的方法中涉及到的数据可以基于区块链存储,或本申请实施例中提供的方法的执行主体可以位于区块链中。区块链是分布式数据存储、点对点传输、共识机制、加密算法等计算机技术的新型应用模式。区块链(Blockchain),本质上是一个去中心化的数据库,是一串使用密码学方法相关联产生的数据块,每一个数据块中包含了一批次网络交易的信息,用于验证其信息的有效性(防伪)和生成下一个区块。区块链可以包括区块链底层平台、平台产品服务层以及应用服务层。
区块链底层平台可以包括用户管理、基础服务、智能合约以及运营监控等处理模块。其中,用户管理模块负责所有区块链参与者的身份信息管理,包括维护公私钥生成(账户管理)、密钥管理以及用户真实身份和区块链地址对应关系维护(权限管理)等,并且在授权的情况下,监管和审计某些真实身份的交易情况,提供风险控制的规则配置(风控审计);基础服务模块部署在所有区块链节点设备上,用来验证业务请求的有效性,并对有效请求完成共识后记录到存储上,对于一个新的业务请求,基础服务先对接口适配解析和鉴权处理(接口适配),然后通过共识算法将业务信息加密(共识管理),在加密之后完整一致的传输至共享账本上(网络通信),并进行记录存储;智能合约模块负责合约的注册发行以及合约触发和合约执行,开发人员可以通过某种编程语言定义合约逻辑,发布到区块链上(合约注册),根据合约条款的逻辑,调用密钥或者其它的事件触发执行,完成合约逻辑,同时还提供对合约升级注销的功能;运营监控模块主要负责产品发布过程中的部署、配置的修改、合约设置、云适配以及产品运行中的实时状态的可视化输出,例如:告警、监控网络情况、监控节点设备健康状态等。
平台产品服务层提供典型应用的基本能力和实现框架,开发人员可以基于这些基本能力,叠加业务的特性,完成业务逻辑的区块链实现。应用服务层提供基于区块链方案的应用服务给业务参与方进行使用。
以下介绍本申请实施例的一种神经网络训练方法,上述神经网络包括第一网络和第二网络,图1示出了本申请实施例提供的一种神经网络训练方法的流程示意图,本申请实施例提供了如实施例或流程图上述的方法操作步骤,但基于常规或者无创造性的劳动可以包括更多或者更少的操作步骤。实施例中列举的步骤顺序仅仅为众多步骤执行顺序中的一种方式,不代表唯一的执行顺序。在实际中的系统、终端设备或服务器产品执行时,可以按照实施例或者附图所示的方法顺序执行或者并行执行(例如并行处理器或者多线程处理的环境),上述方法可以包括:
S101.获取第一图像组,上述第一图像组为至少一个第一样本图像组中任一图像组,上述第一样本图像组包括第一图像、第二图像和第三图像,上述第二图像与上述第一图像的类别相同,上述第三图像与上述第一图像的类别不相同。
图像度量学习的训练数据集为三元组样本:anchor、positive、negative,分别表示锚点样本、正样本、负样本,其中anchor和positive组成正样本对,anchor和negative组成负样本对。本申请实施例中每个三元组样本可以生成一个第一样本图像组。为了得到第一样本图像组,可以对于已有的样本图像进行标注,首先标注出正样本对。本申请实施例并不限定正样本对的确定方法,相似图像可以形成正样本对,比如拍摄同一个狗得到的不同图像可以形成正样本对,或者对同一个图像的锐度、清晰度、灰度等参数进行微调得到的不同图像可以形成正样本度。以第一图像组为例,第一图像可以为锚点样本,第二图像为正样本、第三图像为负样本。
本申请实施例中对于某个正样本对的锚点样本,可以计算其它正样本对中的样本与上述锚点样本的相似度,按照相似度由小到大排序选择排序结果中前预设数量个图像作为负样本,与上述正样本对一起形成第一样本图像组。以上述预设数量为20为例,一个正样本对可以形成20个第一样本图像组。如果存在bs个正样本对,则一共可以形成20*bs个第一样本图像组,bs为正整数,本申请不限定bs的值,其可以取值相对大一些,比如,可以取值256。
在一个实施例中,为了在神经网络的训练过程充分利用类别信息,优化神经网络的学习效果,对于每一第一样本图像组,可以直接使用业务上应用的多标签分类模型,将该多标签分类模型对图像进行处理得到的多标签结果作为样本的标签。本申请实施例并不限定上述多标签分类模型的来源,比如,其可以基于ImageNet预训练模型得到。由于业务上的多标签对描述图像具有指导意义,同时也尽可能覆盖了对应业务的海量图像,故通过在神经网络的训练过程中充分使用多标签信息可以提升模型训练效果。
S102.基于上述第一网络对上述第一图像组进行哈希特征提取,得到哈希特征组。
本申请实施例中可以基于第一网络对第一图像组中的第一图像、第二图像和第三图像分别进行哈希特征提取,得到哈希特征组。在一个实施例中,第一图像组中每一图像的哈希特征都可以通过128维度的向量表征。
在一个实施例中可以基于卷积神经网络构建第一网络,比如可以用ResNet-18或者ResNet-101。在一个实施例中,第一网络的结构可以通过表1表征:
表1
其中,Layer Name可以自行定义,表征网络层的名称,Output Size表示网络层输出的数据的维度,ResNet-101表示基于ResNet-101得到的网络结构,以7x7,64,Stride 2为例,7x7表示卷积核大小为7*7,64表示通道数64,Stride 2表示步长为2,Max Pool表示最大池化操作,Block表示多网络层形成的网络层组,x3 Blocks表示三个这样的网络层组的堆叠。
S103.基于上述第二网络对上述哈希特征组进行量化,得到量化结果组,上述第二网络的参数表征类别对应的码本。
本申请实施例中可以基于第二网络对哈希特征组中的各哈希特征分别进行量化,得到量化结果组。在一个实施例中,第二网络的结构可以通过表2表征:
表2
Layer Name | Output Size | Layer |
Code-book | Kx(MxNclass)x(128/K) | Full Connetction |
其中,表1中可以使用Full Connetction(全连接层)的参数表征K个码本,由于128维特征分成了K段,故每个码本对应的特征维度维128/K。对于Nclass各标签对应的大类,可以存在M个小类,每个小类对应一个码本,故共有Nclass*M*K个码本,需要K*(M*Nclass)*(128/K)的全连层参数组成码本。如果每段维护一个64个聚类中心的码本,则有64*K个聚类中心需要学习,该码本中包含64个PQ码。
S104.确定神经网络损失,以及基于上述神经网络损失,优化上述神经网络参数;其中,上述神经网络损失包括基于上述哈希特征组确定的第一损失,以及基于上述量化结果组确定的第二损失;上述第一损失和上述第二损失均为三元组损失。
在一个实施例中,在上述神经网络还包括第三网络,上述方法还包括:基于上述第三网络对上述哈希特征组进行分类预测,得到分类结果组。本申请实施例中可以基于第三网络对哈希特征组中的各哈希特征分别进行分类预测,得到分类结果组。本申请实施例对于第三网络的结构不做限定,比如,其也可以基于卷积网络、残差网络等实现。在一个实施例中,第三网络的结构可以通过表3表征:
表3
Layer Name | Output Size | Layer |
Pool | 1x2048 | Max Pool |
Embedding | 1x128 | Full Connetction |
Classification | 1x200 | Full Connetction |
请参考图2,其示出神经网络的结构示意图。三元组形式的样本可以通过第一网络进行哈希特征提取后得到哈希特征组,哈希特征组经过第二网络进行量化后可以得到量化结果组,哈希特征组经过第三网络进行分类预测后可以得到分类结果组,分类结果组中的分类结果表征样本图像属于任一类别的概率。
本申请不限定对图2所示的神经网络进行训练的具体过程,比如,可以设定第一网络中Conv1-Conv5采用在ImageNet数据集上预训练的ResNet101的参数,第二网络采用方差为0.01,均值为0的高斯分布进行初始化。再比如设置神经网络的学习参数,学习参数请参考表1、2、3。再比如设置学习率为0.00005。
本申请实施例中神经网络中第一损失基于哈希特征组确定。具体来说,第一损失可以通过公式Ltri=max(||xa-xp||-||xa-xn||-∝,0)表征,其中,||Xa-Xp||表示锚点样本与正样本的哈希特征的距离,||xa-xn||表示锚点样本与负样本的哈希特征的距离,∝表征距离余量,本申请不限定∝的值,比如,可以取值为20,这种情况下,第一损失的目的在于使得锚点样本与负样本的哈希特征的距离相较于锚点样本与正样本的哈希特征的距离大至少20。
为了提升神经网络训练速度,本申请实施例中距离可以使用汉明距离度量,汉明距离是使用在数据传输差错控制编码里面的概念,它表示两个相同长度字对应的不同位的数量。对于二值特征为(0,0,0,1)和(1,1,0,1)的两个样本,其汉明距离为不相同位置的个数,即2。
本申请实施例神经网络中第二损失基于量化结果组确定。第二损失与第一损失都是三元组损失,不同之处在于第一损失度量的是哈希特征的距离相似度,而第二损失度量的是量化结果的距离相似度,上述量化结果的相似度体现在基于量化结果得到的量化重建结果的相似度。举例来说,量化结果组包括第一图像、第二图像和第三图像分别对应的量化结果L1、量化结果L2和量化结果L3。基于量化结果L1和第二网络中的码本可以得到对应的量化重建结果基于量化结果L2和第二网络中的码本可以得到对应的量化重建结果基于量化结果L3和第二网络中的码本可以得到对应的量化重建结果量化重建结果和可以被认为是对于第一图像、第二图像和第三图像的基于PQ的量化表达。第二损失就是用于限定与之间的距离相较于与之间的距离大于预设值。参考第一损失,如果将距离余量设置为16,第二损失的含义就是使得锚点样本与负样本的量化重建结果的距离相较于锚点样本与正样本的量化重建结果的距离大至少16。本申请并不限定第二损失的距离余量,其相较于第一损失是允许精确度降低的,也就是说,第二损失的距离余量可以小于第一损失的距离余量。
在一个实施例中,上述神经网络损失还可以包括基于第一哈希特征和第二哈希特征之间的差异确定的第三损失,其中,上述第一哈希特征为至少一个哈希特征组中任一哈希特征,上述第二哈希特征为对上述第一哈希特征进行二值化处理后得到的特征。第三损失的目的是使得哈希特征中的元素二值化差异度增大,比如靠近1或者靠近-1。在一个实施例中,由于第一哈希特征的目标为输出[-1,1]的量化值,由此可对第一哈希特征进行符号量化,将小于0的输出量化为-1,将大于0的输出量化为1,具体来说,可以采用符号函数 对第一哈希特征进行二值化处理,得到第二哈希特征。其中,ui表示第一哈希特征中第i个元素,bi表示第二哈希特征中第i个元素,i为正整数。第三损失可以使用回归损失度量,比如常见的平方损失、绝对值损失都属于回归损失。
在一个实施例中,上述神经网络损失还可以包括基于第一哈希特征和第一量化重建结果之间的差异确定的第四损失;上述第一量化重建结果为根据上述第一哈希特征对应的第一量化结果和上述第二网络的参数重建的特征。本申请实施例中量化重建结果的含义在前文已有说明,在此不做赘述。第一哈希特征和第一量化重建结果是对于相同样本图像的特征表达,因此,第四损失的目的是使得基于码本表达的特征和基于哈希提取得到的特征的距离尽可能小,从而提升码本的表达能力,减小特征的割裂,使得相似的图像的码本表达也是相似的。
由于本申请中对于哈希特征进行K分段,最终重建时由K个分段对应的码本联合重建,以下对i样本的重建结果由1,2,…K个PQ码Cj,j表示某个分段,乘以样本在该PQ码上的权重值之和产生,上述权重值被记录在第一量化结果中。第一量化结果的获取方式为:对哈希特征先分K段,取第1段哈希码,计算该哈希码与对应的码本中的PQ码的汉明距离,取最近的PQ码,作为样本i在第1个分段上激活的量化编码,并使其权重为1,在其他PQ码的权重为0,从而得到Zi1。其他段也同样获得对应Zij,从而得到第一量化结果,。基于第一量化结果中的Zij和上述Cj可以拼接得到与哈希特征相同维度的向量,即为第一量化重建结果,其中第一量化重建结果可以通过表征,R表示第一量化重建结果。第四损失也可以使用回归损失度量,比如均方差损失误差,该损失使得重建结果更接近原始哈希特征,当然由于哈希特征是多变的,而PQ码有限,不可能保证每个重建结果都与哈希码完全一致,只需要保证重建结果都与哈希码尽可能一致即可。
在一个实施例中,上述神经网络损失还包括第五损失或第六损失,上述第六损失为小于分类损失阈值的上述第五损失;上述第五损失为基于第一分类结果与第一标签的差异确定的损失,上述第一分类结果为至少一个分类结果组中的任一分类结果,上述第一标签为上述第一分类结果对应的图像携带的标签。
本申请实施例中第一分类结果可以表征样本图像属于每一类别的概率,第一分类结果的维度可以根据类别数量确定,比如,如果存在5000个类别,则第一分类结果可以被表征为维度为5000的向量,每个元素表征上述样本图像属于对应的类别的概率。本申请实施例中第一标签就是属于上述5000个类别的标签。上述第五损失和第六损失都是描述预测分类结果与标注的分类结果的差异,本申请实施例中可以使用二分类的交叉熵损失函数描述上述第五损失或第六损失。具体来说,可以通过公式表征上述交叉熵损失函数,N为样本数量,i表示样本下标,yi表示样本图像i对应的第一标签,pi表示样本图像i对应的第一分类结果。
本申请实施例认为在标注的标签可能携带错误的信息的情况下,可能会导致神经网络学习到错误的知识,因此在本申请实施例中还可以进行这类样本的平滑处理,避免模型记住这些错误的样本,而产生错误信息回传到神经网络的不良结果。也就是说,本申请实施例可以通过在神经网络中包括第六损失而不包括大于等于分类损失阈值的第五损失来避免神经网络学习到错误的知识。
在上述神经网络损失包括上述第六损失的情况下,上述确定神经网路损失,包括:获取各上述第五损失的平均损失Lmean-j和最大损失Lmax-j。根据上述最大损失Lmax-j和上述平均损失Lmean-j确定上述分类损失阈值。将小于上述分类损失阈值的第五损失确定为上述第六损失。本申请实施例认为,当一个样本存在过大的分类损失,有可能该样本为预测错误的样本,故不再学习该样本。基于上述方法可以对疑似标注错误的样本进行清洗,避免模型学习到错误的知识。本申请实施例并不限定根据上述最大损失Lmax-j和上述平均损失Lmean-j确定上述分类损失阈值的具体方法,比如,可以根据公式(Lmax-j-Lmean-j)*3/4+Lmean-j来确定分类损失阈值。
在一个优选的实施例中,上述神经网络包括第一损失、第二损失、第三损失、第四损失和第六损失,其中,第一损失和第三损失均属于哈希损失Lhash,第二损失和第四损失均属于PQ损失Lpq,第六损失属于分类损失Lmulti-class。则神经网络损失可以通过Ltotal=a*Lhash+b*Lpq+c*Lmulti-class表征,其中a,b,c均为权重值,a的值应当大于b和c,也就是说,优先保证哈希特征学习、避免多标签分类在梯度回传中过多影响哈希表征效果。本申请并不限定a,b,c的具体取值,比如,a为1,b为0.2,c为0.1。
具体来说,Lhash=w1L1+w3L3,L1和L3分别表示第一损失和第三损失,w1和w3分别为对应的权重,为了确保哈希特征的度量学习效果w1大于w3,本申请并不限定w1和w3的具体取值,比如,w1和w3分别为1和0.1。具体来说,Lpq=w2L2+w4L4,L2和L4分别表示第二损失和第四损失,w2和w4分别为对应的权重,为了确保哈希特征的度量学习效果w2应当小于w1,并且基于码本重建的损失重要度较大,因此w2大于w4,本申请并不限定w2和w4的具体取值,比如,w2和w4分别为0.5和0.01。
由于多任务学习中,容易因为多个损失相互影响造成收敛困难问题,故上述方案调整了不同损失的权重,优先保证哈希特征学习的效果(权重大),避免多标签分类任务、量化任务影响神经网络的收敛。在其他可行的实施例中,也可以先训练第一网络及第三网络,第三网络产生的损失相较于第一网络产生的损失依然权重较小,待第一网络产生的损失较小后,再加入Lpq一并学习直至神经网络收敛。
相关技术中,一般基于深度学习的PQ量化在结合语义信息时会优先考虑多分类的聚类中心(原型)作为PQ量化参考,但实际上由于无法穷尽所有图像的分类,故这种相关技术可以对指定图像场景下有限分类图像的检索较为有效,而对海量包罗万象的图像检索而言,由于存在分类覆盖不全导致PQ的表征不准确的问题,从而影响了这一类相关技术的使用。本申请实施例借助多标签模型提供的信息作为参考,同时还考虑到图像可能不携带任何标签的情况,也就是提供了分类覆盖不全的情况下的码本适应能力。
为了提供对分类覆盖不全的情况下的码本适应能力,请参考图3,其示出本申请实施例中第二网络参数的更新方法流程示意图,上述方法包括:
S201.获取第二样本图像组,上述第二样本图像组包括上述至少一个第一样本图像组中的图像以及至少一个第四图像,上述第四图像为是未被识别出所属类别的图像。
本申请实施例中第一样本组的图像中的标签都指向明确的类别,也就是说第一样本组的图像都属于被覆盖的类别,这些图像的类别可以通过预设模型被识别出来,而该预设模型无法识别出对应的类别的图像即为上述第四图像。比如,多标签预测模型可以预测5000个分类,第一样本图像组中的图像的标签可以基于上述多标签预测模型预测得到,也就是说,第一样本图像组中的标签指向上述5000个分类中的类别,而不属于上述5000个分类的情况,上述多标签预测模型无法预测得到,这种情况下的图像就是本申请实施例中的第四图像。为了将第四图像的情况也进行考虑,从而为本申请得到的码本提供分类覆盖不全情况下的适应能力,本申请实施例中可以对于第四图像也加入一个特殊标签,该特殊标签表征第四图像属于未被覆盖的分类,比如others标签。需要注意的是,由于第四图像不存在类别信息,因此,第四图像并不产生第五损失和第六损失。
S202.基于上述第二样本图像组,更新上述第二网络的参数。
本申请实施例中神经网络经过一轮迭代学习之后,即可更新码本,也就是说,神经网络经过一轮迭代学习之后,可以通过执行步骤S201至S202更新上述第二网络的参数。
请参考图4,其示出本申请实施例中基于上述第二样本图像组更新第二网络的参数方法流程示意图,上述方法包括:
S301.确定第一标签集,上述第一标签集包括每一上述类别对应的标签,以及上述第四图像携带的标签。
根据前文可以确定第一标签集中的标签数量,如果多标签预测模型可以预测图像的5000个分类,则第一标签集包括5001个标签,因为还多一个others标签。
S302.针对上述第一标签集中的每个标签,提取上述标签对应的第二哈希特征,上述标签对应的第二哈希特征包括上述第二样本图像组中上述标签对应的各图像的第三哈希特征,上述第三哈希特征为对上述图像经由上述第一网络进行哈希特征提取得到的特征。
沿用前文示例,上述第一标签集中任一标签mi,其中i表示标签下标,可以查找第二样本图像组中携带标签mi的图像,获取第一网络对于这些图像的哈希特征提取结果(第三哈希特征),这些哈希特征提取结果就是标签mi对应的第二哈希特征。
S303.对各上述第二哈希特征进行聚类处理,根据聚类处理结果,更新上述第二网络的参数。
请参考图5,其示出本申请实施例中根据聚类处理结果更新第二网络的参数方法流程示意图,上述方法包括:
S401.针对上述第一标签集中的每个标签,对上述标签对应的第二哈希特征进行聚类,得到第一预设数量个聚类中心。
本申请实施例并不对第一预设数量进行限定,比如,可以取值M=10,其中M表示第一预设数量。对第二哈希特征进行聚类处理,可以得到10个聚类中心。
S402.针对每个聚类中心,将与上述聚类中心距离最近的上述第三哈希特征确定为上述聚类中心对应的原型。
本申请实施例中每个聚类中心对应一个原型,从而针对每个标签,可以得到10个原型。
S403.计算各上述原型的相互距离,得到上述标签对应的原型平均距离。
对标签mi,将对应于标签mi的10个原型两两计算相互距离,取平均值作为原型平均距离。
S404.基于各上述标签对应的原型平均距离,更新上述第二网络的参数。
请参考图6,其示出本申请实施例中基于各标签对应的原型平均距离,更新第二网络的参数方法流程示意图,上述方法包括:
S501.基于各上述标签对应的原型平均距离,确定距离阈值。
沿用前文示例,本申请实施例可以计算5001(包含others)个原型平均距离的均值作为上述距离阈值。
S502.确定上述第一标签集中各标签对应的原型中心,上述原型中心为上述标签下各原型的中心。
本申请实施例中,可以将上述标签下各原型进行聚类操作,得到上述原型中心。
S503.针对上述第一标签集中的每个标签,响应于存在与上述标签的原型中心距离小于上述距离阈值的其它标签下的原型的情况,将上述其它标签下的原型删除。
对标签mi,计算该标签mi的原型中心(即标签mi下10个原型的均值)与其他标签中的原型(5000标签,共50000原型)的距离,当出现距离小于该距离阈值时,删除对应的其他标签中的原型。
S504.响应于各标签下各原型的总数量小于等于原型阈值的情况,基于上述各标签下原型更新上述第二网络的参数。
本申请实施例并不限定原型阈值的大小,比如,可以设定原型阈值为原各标签各原型总数的80%,也就是说,将原型数量缩减80%,以前文为例,原型阈值大小为5001*10*0.8。当缩减到只剩下80%原型时,相当于每个标签只有M=8个原型的数量级,停止缩减。
当S503执行完毕后仍多于80%原型时,对存在大于8个原型的标签,从原型多的标签开始,依次进行下述处理,直至缩减到只剩下80%原型:
在标签内相互计算原型的距离,并找到距离最近的两个原型,删除其中任意一个,执行上述删除方法直至该标签下剩余8个原型。
S504中可以把各标签下原型的数量小于等于原型阈值的情况下的每个原型的哈希特征进行K分段拆分,把不同分段分别存储到全局K个分段对应的码本。
本申请实施例认为神经网络在第一轮学习时,码本的表征能力尚较差,因此也可以不学习码本及相关损失,也就是不学习Lpq,在第一轮学习结束后,根据上述方法提取到码本,更新第二网络的参数后可以加入Lpq训练神经网络。上述神经网络在训练前可以将所有参数都设为需要学习状态,在产生损失后基于梯度下降法(比如随机梯度下降法)更新神经网络参数。由于码本需要对全局样本具备描述能力,故随着哈希特征的学习优化,旧码本未必对当前哈希最优故需要使用步骤S201至S202定期更新。
上述原型缩减的主要目的是去除重复的原型,当出现噪声时,容易使得两个相似标签(标签1、2)下的样本相似,故这两个标签各自的原型可能会出现相似,相似的原型会使得两个相似的图像在量化时一个基于标签1对应的码本内容量化,一个基于标签2对应的码本内容量化,从而造成量化结果的割裂,通过上述原型缩减使得码本的表征能力大幅提升,降低了割裂概率,显著提升基于码本进行图像检索的准确度。
本申请实施例提供的神经网络训练方法通过在第一网络(哈希模型)上基于多标签的多原型产生乘积量化码本的语义哈希特征的深度量化手段,参考多标签数据原型高效找到PQ中心,避免了二阶段PQ表征不明确、空间割裂的问题的同时,使得PQ中心具有实质表征从而可以通过语义损失约束PQ学习,最终实现PQ与哈希联合学习端到端产生特征与量化的任务。具体来说,通过在学习二值哈希特征的同时维护一个量化码本,使得两者在学习中持续优化,实现码本对相似样本的更好支持,避免割裂样本对。借助多标签语义原型产生PQ量化码本,并使得哈希特征分布在这些多标签原型中心对应的码本附近,可提升PQ码本的聚合效果。具体来说,本申请实施例中对二值哈希特征进行学习的同时,维持一个量化的码本,码本由多标签的多原型中心产生并定期根据多标签原型更新。在神经网络学习中把哈希特征通过分段码本重建,使得重建后的特征与原始的哈希特征接近(第三损失)从而维持码本对哈希码表达的准确性,同时使得重建结果同时具备度量效果(第二损失)避免码本对空间割裂。
本申请实施例通过多标签丰富的语义原型表征的量化码本与哈希特征联合度量学习,使得分类即量化、对语义哈希表征、量化表征效果均有提升,同时使得相似样本尽可能量化在相近的量化向量中。检索应用时不需要额外训练一个量化码本,直接采用神经网络中的码本作为索引,对库存样本关联即可。
当然,本申请并不对于神经网络中第一网络、第二网络、第三网络的结构进行限定,以第一网络为例,其可以使用Resnet101,也可以使用Resnet50、Inceptionv4等,对于数据量较大的检索,可采用小网络如Resnet18。神经网络损失中的各项损失的权值在保证着重学习哈希特征的前提下,都可以根据需要调整。在原型缩减环节也可以视图像携带的标签的噪声情况调整缩减比例。
本申请实施例可以基于得到的神经网络进行图像的检索,上述图像检索方法包括:
S601.获取第五图像。
本申请实施例中第五图像是待检索图像,上述图像检索方法的目的在于在图像库中查找与上述待检索图像相似的图像。
S602.将上述第五图像输入上述神经网络,得到上述第五图像对应的第二量化结果。
S603.将图像库中的各第六图像输入上述神经网络,得到每一上述第六图像对应的第三量化结果。
S604.根据上述第二量化结果和上述每一上述第六图像对应的第三量化结果,输出至少一个目标第六图像,上述目标第六图像对应的第三量化结果与上述第二量化结果的距离符合预设要求。
本申请实施例并不限定上述预设要求的具体内容,比如,预设要求的内容可以为目标第六图像的第三量化结果与上述第二量化结果的汉明距离小于预设阈值,对上述预设阈值的大小不做限定,比如,可以为3。再比如,预设要求的内容还可以为基于各第六图像的第三量化结果与上述第二量化结果的汉明距离由小到大的顺序进行排序,将排序在前的第二预设数量个第六图像作为目标第六图像,本申请并不限定上述第二预设数量的具体数值,比如,可以为30。
请参考图7,其示出本实施例中一种神经网络训练装置的框图,上述装置包括:
图像组获取模块101,用于获取第一图像组,上述第一图像组为至少一个第一样本图像组中任一图像组,上述第一样本图像组包括第一图像、第二图像和第三图像,上述第二图像与上述第一图像的类别相同,上述第三图像与上述第一图像的类别不相同;
哈希提取模块102,用于基于上述第一网络对上述第一图像组进行哈希特征提取,得到哈希特征组;
量化模块103,用于基于上述第二网络对上述哈希特征组进行量化,得到量化结果组,上述第二网络的参数表征类别对应的码本;
优化模块104,用于确定神经网络损失,以及基于上述神经网络损失,优化上述神经网络参数;
其中,上述神经网络损失包括基于上述哈希特征组确定的第一损失,以及基于上述量化结果组确定的第二损失;上述第一损失和上述第二损失均为三元组损失。
在一个实施例中,上述神经网络损失还包括下述至少之一:
基于第一哈希特征和第二哈希特征之间的差异确定的第三损失;
基于第一哈希特征和第一量化重建结果之间的差异确定的第四损失;
其中,上述第一哈希特征为至少一个哈希特征组中任一哈希特征,上述第二哈希特征为对上述第一哈希特征进行二值化处理后得到的特征,上述第一量化重建结果为根据上述第一哈希特征对应的第一量化结果和上述第二网络的参数重建的特征。
在一个实施例中,上述神经网络还包括第三网络,上述优化模块104还用于执行下述操作:
基于上述第三网络对上述哈希特征组进行分类预测,得到分类结果组;
上述神经网络损失还包括第五损失或第六损失,上述第六损失为小于分类损失阈值的上述第五损失;
上述第五损失为基于第一分类结果与第一标签的差异确定的损失,上述第一分类结果为至少一个分类结果组中的任一分类结果,上述第一标签为上述第一分类结果对应的图像携带的标签。
在一个实施例中,在上述神经网络损失包括上述第六损失的情况下,上述优化模块104还用于执行下述操作:
获取各上述第五损失的平均损失和最大损失;
根据上述最大损失和上述平均损失确定上述分类损失阈值;
将小于上述分类损失阈值的第五损失确定为上述第六损失。
在一个实施例中,上述优化模块104还用于执行下述操作;
获取第二样本图像组,上述第二样本图像组包括上述至少一个第一样本图像组中的图像以及至少一个第四图像,上述第四图像是未被识别出所属类别的图像;
基于上述第二样本图像组,更新上述第二网络的参数。
在一个实施例中,上述优化模块104还用于执行下述操作:
确定第一标签集,上述第一标签集包括每一上述类别对应的标签,以及上述第四图像携带的标签;
针对上述第一标签集中的每个标签,提取上述标签对应的第二哈希特征,上述标签对应的第二哈希特征包括上述第二样本图像组中上述标签对应的各图像的第三哈希特征,上述第三哈希特征为对上述图像经由上述第一网络进行哈希特征提取得到的特征;
对各上述第二哈希特征进行聚类处理,根据聚类处理结果,更新上述第二网络的参数。
在一个实施例中,上述优化模块104还用于执行下述操作:
针对上述第一标签集中的每个标签,对上述标签对应的第二哈希特征进行聚类,得到第一预设数量个聚类中心;
针对每个聚类中心,将与上述聚类中心距离最近的上述第三哈希特征确定为上述聚类中心对应的原型;
计算各上述原型的相互距离,得到上述标签对应的原型平均距离;
基于各上述标签对应的原型平均距离,更新上述第二网络的参数。
在一个实施例中,上述优化模块104还用于执行下述操作:
基于各上述标签对应的原型平均距离,确定距离阈值;
确定上述第一标签集中各标签对应的原型中心,上述原型中心为上述标签下各原型的中心;
针对上述第一标签集中的每个标签,响应于存在与上述标签的原型中心距离小于上述距离阈值的其它标签下的原型的情况,将上述其它标签下的原型删除;
响应于各标签下各原型的总数量小于等于原型阈值的情况,基于上述各标签下原型更新上述第二网络的参数。
在一个实施例中,上述装置还包括检索模块,上述检索模块用于执行下述操作:
获取第五图像;
将上述第五图像输入上述神经网络,得到上述第五图像对应的第二量化结果;
将图像库中的各第六图像输入上述神经网络,得到每一上述第六图像对应的第三量化结果;
根据上述第二量化结果和上述每一上述第六图像对应的第三量化结果,输出至少一个目标第六图像,上述目标第六图像对应的第三量化结果与上述第二量化结果的距离符合预设要求。
本申请实施例还提供了一种计算机程序产品或计算机程序,该计算机程序产品或计算机程序包括计算机指令,该计算机指令存储在计算机可读存储介质中。计算机设备的处理器从计算机可读存储介质读取该计算机指令,处理器执行该计算机指令,使得该计算机设备执行上述一种神经网络训练方法。
本申请实施例还提供了一种计算机可读存储介质,上述计算机可读存储介质可以存储有多条指令。上述指令可以适于由处理器加载并执行本申请实施例上述的一种神经网络训练方法。
在一个实施例中,上述一种神经网络训练方法,上述神经网络包括第一网络和第二网络,上述方法包括:
获取第一图像组,上述第一图像组为至少一个第一样本图像组中任一图像组,上述第一样本图像组包括第一图像、第二图像和第三图像,上述第二图像与上述第一图像的类别相同,上述第三图像与上述第一图像的类别不相同;
基于上述第一网络对上述第一图像组进行哈希特征提取,得到哈希特征组;
基于上述第二网络对上述哈希特征组进行量化,得到量化结果组,上述第二网络的参数表征类别对应的码本;
确定神经网络损失,以及基于上述神经网络损失,优化上述神经网络参数;
其中,上述神经网络损失包括基于上述哈希特征组确定的第一损失,以及基于上述量化结果组确定的第二损失;上述第一损失和上述第二损失均为三元组损失。
在另一个实施例中,上述神经网络损失还包括下述至少之一:
基于第一哈希特征和第二哈希特征之间的差异确定的第三损失;
基于第一哈希特征和第一量化重建结果之间的差异确定的第四损失;
其中,上述第一哈希特征为至少一个哈希特征组中任一哈希特征,上述第二哈希特征为对上述第一哈希特征进行二值化处理后得到的特征,上述第一量化重建结果为根据上述第一哈希特征对应的第一量化结果和上述第二网络的参数重建的特征。
在另一个实施例中,上述神经网络还包括第三网络,上述方法还包括:
基于上述第三网络对上述哈希特征组进行分类预测,得到分类结果组;
上述神经网络损失还包括第五损失或第六损失,上述第六损失为小于分类损失阈值的上述第五损失;
上述第五损失为基于第一分类结果与第一标签的差异确定的损失,上述第一分类结果为至少一个分类结果组中的任一分类结果,上述第一标签为上述第一分类结果对应的图像携带的标签。
在另一个实施例中,在上述神经网络损失包括上述第六损失的情况下,上述确定神经网络损失,包括:
获取各上述第五损失的平均损失和最大损失;
根据上述最大损失和上述平均损失确定上述分类损失阈值;
将小于上述分类损失阈值的第五损失确定为上述第六损失。
在另一个实施例中,上述方法还包括;
获取第二样本图像组,上述第二样本图像组包括上述至少一个第一样本图像组中的图像以及至少一个第四图像,上述第四图像是未被识别出所属类别的图像;
基于上述第二样本图像组,更新上述第二网络的参数。
在另一个实施例中,上述基于上述第二样本图像组,更新上述第二网络的参数,包括:
确定第一标签集,上述第一标签集包括每一上述类别对应的标签,以及上述第四图像携带的标签;
针对上述第一标签集中的每个标签,提取上述标签对应的第二哈希特征,上述标签对应的第二哈希特征包括上述第二样本图像组中上述标签对应的各图像的第三哈希特征,上述第三哈希特征为对上述图像经由上述第一网络进行哈希特征提取得到的特征;
对各上述第二哈希特征进行聚类处理,根据聚类处理结果,更新上述第二网络的参数。
在另一个实施例中,,上述对各上述第二哈希特征进行聚类处理,根据聚类处理结果,更新上述第二网络的参数,包括:
针对上述第一标签集中的每个标签,对上述标签对应的第二哈希特征进行聚类,得到第一预设数量个聚类中心;
针对每个聚类中心,将与上述聚类中心距离最近的上述第三哈希特征确定为上述聚类中心对应的原型;
计算各上述原型的相互距离,得到上述标签对应的原型平均距离;
基于各上述标签对应的原型平均距离,更新上述第二网络的参数。
在另一个实施例中,上述基于各上述标签对应的原型平均距离,更新上述第二网络的参数,包括:
基于各上述标签对应的原型平均距离,确定距离阈值;
确定上述第一标签集中各标签对应的原型中心,上述原型中心为上述标签下各原型的中心;
针对上述第一标签集中的每个标签,响应于存在与上述标签的原型中心距离小于上述距离阈值的其它标签下的原型的情况,将上述其它标签下的原型删除;
响应于各标签下各原型的总数量小于等于原型阈值的情况,基于上述各标签下原型更新上述第二网络的参数。
在另一个实施例中,上述方法还包括:
获取第五图像;
将上述第五图像输入上述神经网络,得到上述第五图像对应的第二量化结果;
将图像库中的各第六图像输入上述神经网络,得到每一上述第六图像对应的第三量化结果;
根据上述第二量化结果和上述每一上述第六图像对应的第三量化结果,输出至少一个目标第六图像,上述目标第六图像对应的第三量化结果与上述第二量化结果的距离符合预设要求。
进一步地,图8示出了一种用于实现本申请实施例所提供的方法的设备的硬件结构示意图,上述设备可以参与构成或包含本申请实施例所提供的装置或系统。如图8所示,设备10可以包括一个或多个(图中采用102a、102b,……,102n来示出)处理器102(处理器102可以包括但不限于微处理器MCU或可编程逻辑器件FPGA等的处理装置)、用于存储数据的存储器104、以及用于通信功能的传输装置106。除此以外,还可以包括:显示器、输入/输出接口(I/O接口)、通用串行总线(USB)端口(可以作为I/O接口的端口中的一个端口被包括)、网络接口、电源和/或相机。本领域普通技术人员可以理解,图8所示的结构仅为示意,其并不对上述电子装置的结构造成限定。例如,设备10还可包括比图8中所示更多或者更少的组件,或者具有与图8所示不同的配置。
应当注意到的是上述一个或多个处理器102和/或其他数据处理电路在本文中通常可以被称为“数据处理电路”。该数据处理电路可以全部或部分的体现为软件、硬件、固件或其他任意组合。此外,数据处理电路可为单个独立的处理模块,或全部或部分的结合到设备10(或移动设备)中的其他元件中的任意一个内。如本申请实施例中所涉及到的,该数据处理电路作为一种处理器控制(例如与接口连接的可变电阻终端路径的选择)。
存储器104可用于存储应用软件的软件程序以及模块,如本申请实施例中上述的方法对应的程序指令/数据存储装置,处理器102通过运行存储在存储器104内的软件程序以及模块,从而执行各种功能应用以及数据处理,即实现上述的一种神经网络训练方法。存储器104可包括高速随机存储器,还可包括非易失性存储器,如一个或者多个磁性存储装置、闪存、或者其他非易失性固态存储器。在一些实例中,存储器104可进一步包括相对于处理器102远程设置的存储器,这些远程存储器可以通过网络连接至设备10。上述网络的实例包括但不限于互联网、企业内部网、局域网、移动通信网及其组合。
传输装置106用于经由一个网络接收或者发送数据。上述的网络具体实例可包括设备10的通信供应商提供的无线网络。在一个实例中,传输装置106包括一个网络适配器(NetworkInterfaceController,NIC),其可通过基站与其他网络设备相连从而可与互联网进行通讯。在一个实例中,传输装置106可以为射频(RadioFrequency,RF)模块,其用于通过无线方式与互联网进行通讯。
显示器可以例如触摸屏式的液晶显示器(LCD),该液晶显示器可使得用户能够与设备10(或移动设备)的用户界面进行交互。
需要说明的是:上述本申请实施例先后顺序仅仅为了描述,不代表实施例的优劣。且上述对本申请实施例特定实施例进行了描述。其它实施例在所附权利要求书的范围内。在一些情况下,在权利要求书中记载的动作或步骤可以按照不同于实施例中的顺序来执行并且仍然可以实现期望的结果。另外,在附图中描绘的过程不一定要求示出的特定顺序或者连续顺序才能实现期望的结果。在某些实施方式中,多任务处理和并行处理也是可以的或者可能是有利的。
本申请实施例中的各个实施例均采用递进的方式描述,各个实施例之间相同相似的部分互相参见即可,每个实施例重点说明的都是与其他实施例的不同之处。尤其,对于装置和服务器实施例而言,由于其基本相似于方法实施例,所以描述的比较简单,相关之处参见方法实施例的部分说明即可。
本领域普通技术人员可以理解实现上述实施例的全部或部分步骤可以通过硬件来完成,也可以通过程序来指令相关的硬件完成,上述的程序可以存储于一种计算机可读存储介质中,上述提到的存储介质可以是只读存储器,磁盘或光盘等。
以上上述仅为本申请实施例的较佳实施例,并不用以限制本申请实施例,凡在本申请实施例的精神和原则之内,所作的任何修改、等同替换、改进等,均应包含在本申请实施例的保护范围之内。
Claims (10)
1.一种神经网络训练方法,其特征在于,所述神经网络包括第一网络和第二网络,所述方法包括:
获取第一图像组,所述第一图像组为至少一个第一样本图像组中任一图像组,所述第一样本图像组包括第一图像、第二图像和第三图像,所述第二图像与所述第一图像的类别相同,所述第三图像与所述第一图像的类别不相同;
基于所述第一网络对所述第一图像组进行哈希特征提取,得到哈希特征组;
基于所述第二网络对所述哈希特征组进行量化,得到量化结果组,所述第二网络的参数表征类别对应的码本;
确定神经网络损失,以及基于所述神经网络损失,优化所述神经网络参数;
其中,所述神经网络损失包括基于所述哈希特征组确定的第一损失,以及基于所述量化结果组确定的第二损失;所述第一损失和所述第二损失均为三元组损失。
2.根据权利要求1所述的方法,其特征在于,所述神经网络损失还包括下述至少之一:
基于第一哈希特征和第二哈希特征之间的差异确定的第三损失;
基于第一哈希特征和第一量化重建结果之间的差异确定的第四损失;
其中,所述第一哈希特征为至少一个哈希特征组中任一哈希特征,所述第二哈希特征为对所述第一哈希特征进行二值化处理后得到的特征,所述第一量化重建结果为根据所述第一哈希特征对应的第一量化结果和所述第二网络的参数重建的特征。
3.根据权利要求1或2所述的方法,其特征在于,所述神经网络还包括第三网络,所述方法还包括:
基于所述第三网络对所述哈希特征组进行分类预测,得到分类结果组;
所述神经网络损失还包括第五损失或第六损失,所述第六损失为小于分类损失阈值的所述第五损失;
所述第五损失为基于第一分类结果与第一标签的差异确定的损失,所述第一分类结果为至少一个分类结果组中的任一分类结果,所述第一标签为所述第一分类结果对应的图像携带的标签。
4.根据权利要求3所述的方法,其特征在于,在所述神经网络损失包括所述第六损失的情况下,所述确定神经网络损失,包括:
获取各所述第五损失的平均损失和最大损失;
根据所述最大损失和所述平均损失确定所述分类损失阈值;
将小于所述分类损失阈值的第五损失确定为所述第六损失。
5.根据权利要求1所述的方法,其特征在于,所述方法还包括;
获取第二样本图像组,所述第二样本图像组包括所述至少一个第一样本图像组中的图像以及至少一个第四图像,所述第四图像为是未被识别出所属类别的图像;
基于所述第二样本图像组,更新所述第二网络的参数。
6.根据权利要求5所述的方法,其特征在于,所述基于所述第二样本图像组,更新所述第二网络的参数,包括:
确定第一标签集,所述第一标签集包括每一所述类别对应的标签,以及所述第四图像携带的标签;
针对所述第一标签集中的每个标签,提取所述标签对应的第二哈希特征,所述标签对应的第二哈希特征包括所述第二样本图像组中所述标签对应的各图像的第三哈希特征,所述第三哈希特征为对所述图像经由所述第一网络进行哈希特征提取得到的特征;
对各所述第二哈希特征进行聚类处理,根据聚类处理结果,更新所述第二网络的参数。
7.根据权利要求6所述的方法,其特征在于,所述对各所述第二哈希特征进行聚类处理,根据聚类处理结果,更新所述第二网络的参数,包括:
针对所述第一标签集中的每个标签,对所述标签对应的第二哈希特征进行聚类,得到第一预设数量个聚类中心;
针对每个聚类中心,将与所述聚类中心距离最近的所述第三哈希特征确定为所述聚类中心对应的原型;
计算各所述原型的相互距离,得到所述标签对应的原型平均距离;
基于各所述标签对应的原型平均距离,更新所述第二网络的参数。
8.根据权利要求7所述的方法,其特征在于,所述基于各所述标签对应的原型平均距离,更新所述第二网络的参数,包括:
基于各所述标签对应的原型平均距离,确定距离阈值;
确定所述第一标签集中各标签对应的原型中心,所述原型中心为所述标签下各原型的中心;
针对所述第一标签集中的每个标签,响应于存在与所述标签的原型中心距离小于所述距离阈值的其它标签下的原型的情况,将所述其它标签下的原型删除;
响应于各标签下各原型的总数量小于等于原型阈值的情况,基于所述各标签下原型更新所述第二网络的参数。
9.根据权利要求1所述的方法,其特征在于,所述方法还包括:
获取第五图像;
将所述第五图像输入所述神经网络,得到所述第五图像对应的第二量化结果;
将图像库中的各第六图像输入所述神经网络,得到每一所述第六图像对应的第三量化结果;
根据所述第二量化结果和所述每一所述第六图像对应的第三量化结果,输出至少一个目标第六图像,所述目标第六图像对应的第三量化结果与所述第二量化结果的距离符合预设要求。
10.一种神经网络训练装置,其特征在于,所述装置包括:
图像组获取模块,用于获取第一图像组,所述第一图像组为至少一个第一样本图像组中任一图像组,所述第一样本图像组包括第一图像、第二图像和第三图像,所述第二图像与所述第一图像的类别相同,所述第三图像与所述第一图像的类别不相同;
哈希提取模块,用于基于所述第一网络对所述第一图像组进行哈希特征提取,得到哈希特征组;
量化模块,用于基于所述第二网络对所述哈希特征组进行量化,得到量化结果组,所述第二网络的参数表征类别对应的码本;
优化模块,用于确定神经网络损失,以及基于所述神经网络损失,优化所述神经网络参数;
其中,所述神经网络损失包括基于所述哈希特征组确定的第一损失,以及基于所述量化结果组确定的第二损失;所述第一损失和所述第二损失均为三元组损失。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202111261376.4A CN116050508A (zh) | 2021-10-28 | 2021-10-28 | 神经网络训练方法以及装置 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202111261376.4A CN116050508A (zh) | 2021-10-28 | 2021-10-28 | 神经网络训练方法以及装置 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN116050508A true CN116050508A (zh) | 2023-05-02 |
Family
ID=86124151
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202111261376.4A Pending CN116050508A (zh) | 2021-10-28 | 2021-10-28 | 神经网络训练方法以及装置 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN116050508A (zh) |
Citations (8)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
WO2018137358A1 (zh) * | 2017-01-24 | 2018-08-02 | 北京大学 | 基于深度度量学习的目标精确检索方法 |
CN109165306A (zh) * | 2018-08-09 | 2019-01-08 | 长沙理工大学 | 基于多任务哈希学习的图像检索方法 |
CN109241317A (zh) * | 2018-09-13 | 2019-01-18 | 北京工商大学 | 基于深度学习网络中度量损失的行人哈希检索方法 |
CN110321957A (zh) * | 2019-07-05 | 2019-10-11 | 重庆大学 | 融合三元组损失和生成对抗网络的多标签图像检索方法 |
CN110688502A (zh) * | 2019-09-09 | 2020-01-14 | 重庆邮电大学 | 一种基于深度哈希和量化的图像检索方法及存储介质 |
CA3141042A1 (en) * | 2019-06-13 | 2020-12-17 | Luis Eduardo Gutierrez-Sheris | System and method using a fitness-gradient blockchain consensus and providing advanced distributed ledger capabilities via specialized data records |
CN112766458A (zh) * | 2021-01-06 | 2021-05-07 | 南京瑞易智能科技有限公司 | 一种联合分类损失的双流有监督深度哈希图像检索方法 |
CN113190699A (zh) * | 2021-05-14 | 2021-07-30 | 华中科技大学 | 一种基于类别级语义哈希的遥感图像检索方法及装置 |
-
2021
- 2021-10-28 CN CN202111261376.4A patent/CN116050508A/zh active Pending
Patent Citations (8)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
WO2018137358A1 (zh) * | 2017-01-24 | 2018-08-02 | 北京大学 | 基于深度度量学习的目标精确检索方法 |
CN109165306A (zh) * | 2018-08-09 | 2019-01-08 | 长沙理工大学 | 基于多任务哈希学习的图像检索方法 |
CN109241317A (zh) * | 2018-09-13 | 2019-01-18 | 北京工商大学 | 基于深度学习网络中度量损失的行人哈希检索方法 |
CA3141042A1 (en) * | 2019-06-13 | 2020-12-17 | Luis Eduardo Gutierrez-Sheris | System and method using a fitness-gradient blockchain consensus and providing advanced distributed ledger capabilities via specialized data records |
CN110321957A (zh) * | 2019-07-05 | 2019-10-11 | 重庆大学 | 融合三元组损失和生成对抗网络的多标签图像检索方法 |
CN110688502A (zh) * | 2019-09-09 | 2020-01-14 | 重庆邮电大学 | 一种基于深度哈希和量化的图像检索方法及存储介质 |
CN112766458A (zh) * | 2021-01-06 | 2021-05-07 | 南京瑞易智能科技有限公司 | 一种联合分类损失的双流有监督深度哈希图像检索方法 |
CN113190699A (zh) * | 2021-05-14 | 2021-07-30 | 华中科技大学 | 一种基于类别级语义哈希的遥感图像检索方法及装置 |
Non-Patent Citations (2)
Title |
---|
冯兴杰;程毅玮;: "基于深度卷积神经网络与哈希的图像检索", 计算机工程与设计, no. 03, 16 March 2020 (2020-03-16) * |
李泗兰;郭雅;: "基于深度学习哈希算法的快速图像检索研究", 计算机与数字工程, no. 12, 20 December 2019 (2019-12-20) * |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
WO2023065545A1 (zh) | 风险预测方法、装置、设备及存储介质 | |
CN111382283B (zh) | 资源类别标签标注方法、装置、计算机设备和存储介质 | |
CN111259647A (zh) | 基于人工智能的问答文本匹配方法、装置、介质及电子设备 | |
CN113821670B (zh) | 图像检索方法、装置、设备及计算机可读存储介质 | |
CN113806582B (zh) | 图像检索方法、装置、电子设备和存储介质 | |
CN114491039B (zh) | 基于梯度改进的元学习少样本文本分类方法 | |
CN113822315A (zh) | 属性图的处理方法、装置、电子设备及可读存储介质 | |
CN114492601A (zh) | 资源分类模型的训练方法、装置、电子设备及存储介质 | |
CN112819024B (zh) | 模型处理方法、用户数据处理方法及装置、计算机设备 | |
CN114298122A (zh) | 数据分类方法、装置、设备、存储介质及计算机程序产品 | |
CN114611672A (zh) | 模型训练方法、人脸识别方法及装置 | |
CN116244484B (zh) | 一种面向不平衡数据的联邦跨模态检索方法及系统 | |
CN116307078A (zh) | 账户标签预测方法、装置、存储介质及电子设备 | |
CN111161238A (zh) | 图像质量评价方法及装置、电子设备、存储介质 | |
CN116362894A (zh) | 多目标学习方法、装置、电子设备及计算机可读存储介质 | |
CN114898184A (zh) | 模型训练方法、数据处理方法、装置及电子设备 | |
CN115631008B (zh) | 商品推荐方法、装置、设备及介质 | |
CN116050508A (zh) | 神经网络训练方法以及装置 | |
WO2021115269A1 (zh) | 用户集群的预测方法、装置、计算机设备和存储介质 | |
CN111611981A (zh) | 信息识别方法和装置及信息识别神经网络训练方法和装置 | |
WO2022262603A1 (zh) | 多媒体资源的推荐方法、装置、设备、存储介质及计算机程序产品 | |
CN111476037B (zh) | 文本处理方法、装置、计算机设备和存储介质 | |
CN116702016A (zh) | 对象属性识别方法、装置、设备及存储介质 | |
Feng et al. | Construction of Legal Reporting Information Platform Based on Natural Optimization Algorithm | |
Gama et al. | Advances in Knowledge Discovery and Data Mining: 26th Pacific-Asia Conference, PAKDD 2022, Chengdu, China, May 16–19, 2022, Proceedings, Part II |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
REG | Reference to a national code |
Ref country code: HK Ref legal event code: DE Ref document number: 40087996 Country of ref document: HK |