CN115965817A - 图像分类模型的训练方法、装置及电子设备 - Google Patents
图像分类模型的训练方法、装置及电子设备 Download PDFInfo
- Publication number
- CN115965817A CN115965817A CN202310014934.XA CN202310014934A CN115965817A CN 115965817 A CN115965817 A CN 115965817A CN 202310014934 A CN202310014934 A CN 202310014934A CN 115965817 A CN115965817 A CN 115965817A
- Authority
- CN
- China
- Prior art keywords
- image
- sample
- support
- sample image
- query
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000012549 training Methods 0.000 title claims abstract description 208
- 238000013145 classification model Methods 0.000 title claims abstract description 159
- 238000000034 method Methods 0.000 title claims abstract description 83
- 238000004364 calculation method Methods 0.000 claims description 32
- 238000000605 extraction Methods 0.000 claims description 30
- 230000007246 mechanism Effects 0.000 claims description 29
- 238000012545 processing Methods 0.000 claims description 26
- 230000006870 function Effects 0.000 claims description 20
- 238000004590 computer program Methods 0.000 claims description 11
- 230000000007 visual effect Effects 0.000 claims description 4
- 239000000126 substance Substances 0.000 claims 1
- 238000013473 artificial intelligence Methods 0.000 abstract description 3
- 238000013135 deep learning Methods 0.000 abstract description 2
- 238000003058 natural language processing Methods 0.000 abstract description 2
- 241000894007 species Species 0.000 description 45
- 238000010586 diagram Methods 0.000 description 13
- 230000008569 process Effects 0.000 description 10
- 238000004891 communication Methods 0.000 description 8
- 238000002372 labelling Methods 0.000 description 4
- 230000004048 modification Effects 0.000 description 3
- 238000012986 modification Methods 0.000 description 3
- 230000003287 optical effect Effects 0.000 description 3
- 238000013136 deep learning model Methods 0.000 description 2
- 238000005516 engineering process Methods 0.000 description 2
- 230000003993 interaction Effects 0.000 description 2
- 241000271566 Aves Species 0.000 description 1
- 241000282472 Canis lupus familiaris Species 0.000 description 1
- 241000282693 Cercopithecidae Species 0.000 description 1
- 238000003491 array Methods 0.000 description 1
- 230000005540 biological transmission Effects 0.000 description 1
- 230000001413 cellular effect Effects 0.000 description 1
- 230000009194 climbing Effects 0.000 description 1
- 238000010276 construction Methods 0.000 description 1
- 238000013461 design Methods 0.000 description 1
- 230000000694 effects Effects 0.000 description 1
- 230000006872 improvement Effects 0.000 description 1
- 239000004973 liquid crystal related substance Substances 0.000 description 1
- 238000010801 machine learning Methods 0.000 description 1
- 239000013307 optical fiber Substances 0.000 description 1
- 238000012216 screening Methods 0.000 description 1
- 239000004065 semiconductor Substances 0.000 description 1
- 230000001953 sensory effect Effects 0.000 description 1
- 238000006467 substitution reaction Methods 0.000 description 1
Images
Abstract
本公开提供了图像分类模型的训练方法、装置及电子设备,涉及人工智能技术领域,尤其涉及自然语言处理、计算机视觉、深度学习技术领域。具体实现方案为:获取多个训练数据集,训练数据集包括支持集和查询集;获取初始的图像分类模型;针对每个训练数据集,根据训练数据集中的支持集和查询集,确定训练数据集中的多个样本图像对以及对应的样本相似度;依次针对每个训练数据集,采用其中的多个样本图像对以及对应的样本相似度,对图像分类模型进行训练,得到训练好的图像分类模型,从而能够根据较少的样本图像以及对应的类别,训练得到准确度较高的图像分类模型,能够适用于图像标注数据比较缺乏的图像分类任务,提高图像分类任务下的准确度。
Description
技术领域
本公开涉及人工智能技术领域,尤其涉及自然语言处理、计算机视觉、深度学习技术领域,尤其涉及一种图像分类模型的训练方法、装置及电子设备。
背景技术
目前,针对图像识别任务,需要对大量图像进行标记,得到图像标注数据;采用图像标注数据对深度学习模型进行训练,得到识别准确度较高的图像识别模型,用于图像识别任务。
其中,针对物种细粒度识别任务,由于很多物种的体型、外貌相似,
特征差异较小,只有相应领域的专家才能区分不同的物种,导致该任务下的图像标注数据缺乏,难以训练得到识别准确度较高的物种细粒度识别模型。
发明内容
本公开提供了一种图像分类模型的训练方法、装置及电子设备。
根据本公开的一方面,提供了一种图像分类模型的训练方法,所述方法包括:获取多个训练数据集,所述训练数据集包括支持集和查询集;所述支持集包括支持样本图像以及所述支持样本图像的类别;所述查询集中包括查询样本图像以及所述查询样本图像的类别;获取初始的图像分类模型;针对每个训练数据集,根据所述训练数据集中的所述支持样本图像、所述支持样本图像的类别、所述查询样本图像以及所述查询样本图像的类别,确定所述训练数据集中的多个样本图像对,以及所述样本图像对中支持样本图像以及查询样本图像之间的样本相似度;依次针对每个训练数据集,以所述训练数据集中样本图像对中的支持样本图像和查询样本图像为所述图像分类模型的输入,以所述样本图像对中支持样本图像以及查询样本图像之间的样本相似度为所述图像分类模型的输出,对所述图像分类模型进行训练,得到训练好的图像分类模型。
根据本公开的另一方面,提供了一种图像分类方法,所述方法包括:获取待处理图像以及支持集,所述支持集包括多个支持样本图像以及所述支持样本图像的类别;根据所述待处理图像以及多个所述支持样本图像,生成多个图像对;所述图像对中包括所述待处理图像以及所述支持样本图像;将所述待处理图像对输入图像分类模型的特征提取网络以及注意力机制网络,获取所述图像对中待处理图像的图像特征,以及所述图像对中支持样本图像的支持图像特征;所述图像分类模型基于如上所述的图像分类模型的训练方法训练得到;将所述图像对中待处理图像的图像特征,以及所述图像对中支持样本图像的支持图像特征,输入所述图像分类模型中的相似度计算网络,获取所述待处理图像与所述支持样本图像之间的相似度;根据所述待处理图像与所述支持样本图像之间的相似度,以及所述支持样本图像的类别,确定所述待处理图像的类别。
根据本公开的另一方面,提供了一种图像分类模型的训练装置,所述装置包括:第一获取模块,用于获取多个训练数据集,所述训练数据集包括支持集和查询集;所述支持集包括支持样本图像以及所述支持样本图像的类别;所述查询集中包括查询样本图像以及所述查询样本图像的类别;第二获取模块,用于获取初始的图像分类模型;确定模块,用于针对每个训练数据集,根据所述训练数据集中的所述支持样本图像、所述支持样本图像的类别、所述查询样本图像以及所述查询样本图像的类别,确定所述训练数据集中的多个样本图像对,以及所述样本图像对中支持样本图像以及查询样本图像之间的样本相似度;训练模块,用于依次针对每个训练数据集,以所述训练数据集中样本图像对中的支持样本图像和查询样本图像为所述图像分类模型的输入,以所述样本图像对中支持样本图像以及查询样本图像之间的样本相似度为所述图像分类模型的输出,对所述图像分类模型进行训练,得到训练好的图像分类模型。
根据本公开的另一方面,提供了一种图像分类装置,所述装置包括:
获取模块,用于获取待处理图像以及支持集,所述支持集包括多个支持样本图像以及所述支持样本图像的类别;生成模块,用于根据所述5待处理图像以及多个所述支持样本图像,生成多个图像对;所述图像对中包括所述待处理图像以及所述支持样本图像;第一输入模块,用于将所述待处理图像对输入图像分类模型的特征提取网络以及注意力机制网络,获取所述图像对中待处理图像的图像特征,以及所述图像对中支持样本图像的支持图像特征;所述图像分类模型基于如上所述0的图像分类模型的训练方法训练得到;第二输入模块,用于将所述图像对中待处理图像的图像特征,以及所述图像对中支持样本图像的支持图像特征,输入所述图像分类模型中的相似度计算网络,获取所述待处理图像与所述支持样本图像之间的相似度;确定模块,用于根据所述待处理图像与所述支持样本图像之间的相似度,以及所述支持样5本图像的类别,确定所述待处理图像的类别。
根据本公开的另一方面,提供了一种电子设备,包括:至少一个处理器;以及与所述至少一个处理器通信连接的存储器;其中,所述存储器存储有可被所述至少一个处理器执行的指令,所述指令被所述至少一个处理器执行,以使所述至少一个处理器能够执行本公开上述提0出的图像分类模型的训练方法,或者,执行本公开上述提出的图像分类方法。
根据本公开的另一方面,提供了一种存储有计算机指令的非瞬时计算机可读存储介质,所述计算机指令用于使计算机执行本公开上述提出的图像分类模型的训练方法,或者,执行本公开上述提出的图像分5类方法。
根据本公开的另一方面,提供了一种计算机程序产品,包括计算机程序,所述计算机程序在被处理器执行时实现本公开上述提出的图像分类模型的训练方法,或者,实现本公开上述提出的图像分类方法。
应当理解,本部分所描述的内容并非旨在标识本公开的实施例的关键或重要特征,也不用于限制本公开的范围。本公开的其它特征将通过以下的说明书而变得容易理解。
附图说明
附图用于更好地理解本方案,不构成对本公开的限定。其中:
图1是根据本公开第一实施例的示意图;
图2是根据本公开第二实施例的示意图;
图3是根据本公开第三实施例的示意图;
图4是根据本公开第四实施例的示意图;
图5是根据本公开第五实施例的示意图;
图6是根据本公开第六实施例的示意图;
图7是用来实现本公开实施例的图像分类模型的训练方法或者图像分类方法的电子设备的框图。
具体实施方式
以下结合附图对本公开的示范性实施例做出说明,其中包括本公开实施例的各种细节以助于理解,应当将它们认为仅仅是示范性的。因此,本领域普通技术人员应当认识到,可以对这里描述的实施例做出各种改变和修改,而不会背离本公开的范围和精神。同样,为了清楚和简明,以下的描述中省略了对公知功能和结构的描述。
目前,针对图像识别任务,需要对大量图像进行标记,得到图像标注数据;采用图像标注数据对深度学习模型进行训练,得到识别准确度较高的图像识别模型,用于图像识别任务。
其中,针对物种细粒度识别任务,由于很多物种的体型、外貌相似,
特征差异较小,只有相应领域的专家才能区分不同的物种,导致该任务下的图像标注数据缺乏,难以训练得到识别准确度较高的物种细粒度识别模型。
针对上述问题,本公开提出一种图像分类模型的训练方法、装置及电子设备。
图1是根据本公开第一实施例的示意图,需要说明的是,本公开实施例的图像分类模型的训练方法可应用于图像分类模型的训练装置,该装置可被配置于电子设备中,以使该电子设备可以执行图像分类模型的训练功能。以下实施例中以执行主体为电子设备为例进行说明。
其中,电子设备可以为任一具有计算能力的设备,例如可以为个人电脑(PersonalComputer,简称PC)、移动终端、服务器等,移动终端例如可以为车载设备、手机、平板电脑、个人数字助理、穿戴式设备等具有各种操作系统、触摸屏和/或显示屏的硬件设备。
如图1所示,该图像分类模型的训练方法可以包括如下步骤:
步骤101,获取多个训练数据集,训练数据集包括支持集和查询集;支持集包括支持样本图像以及支持样本图像的类别;查询集中包括查询样本图像以及查询样本图像的类别。
在本公开实施例中,支持集中,支持样本图像的数量可以为多个,类别的数量可以为多个。例如,支持集中可以包括N个类别,每个类别下可以有K个支持样本图像。
在本公开实施例中,查询集中,查询样本图像的数量可以为多个,类别的数量可以为一个或者多个。例如,查询集中可以包括一个类别,该类别下可以有K个查询样本图像。
其中,需要说明的是,查询集中的类别,可以为支持集中的其中一个类别;或者,查询集中的类别,可以与支持集中的类别不同。其中,若查询集中的类别,与支持集中的类别不同,后续基于支持集和查询集生成的图像对中,包括不同类别的两个样本图像的图像对较多,包括相同类别的两个样本图像的图像对较少或者没有,导致两种图像对的数量不均衡,可能训练得到的图像分类模型的准确度会受影响。
其中,若查询集中的类别,为支持集中的其中一个类别,后续基于支持集和查询集生成的图像对中,两种图像对的数量可能比较均衡,确保训练得到的图像分类模型的准确度。
在本公开实施例中,查询集中某个类别下的查询样本图像,与支持集中相同类别下的支持样本图像,可以相同或者不同。
在本公开实施例中,支持样本图像的类别,可以为支持样本图像中目标对象所属的物种;查询样本图像的类别,可以为查询样本图像中目标对象所属的物种。其中,类别为物种的情况下,图像分类模型,可以用于细粒度的物种识别,从而避免在物种识别任务下由专家进行物种识别,降低物种识别任务下的人工成本,提高物种识别任务下的物种识别准确度。
步骤102,获取初始的图像分类模型。
在本公开实施例中,图像分类模型中包括依次连接的特征提取网络、注意力机制网络和相似度计算网络;特征提取网络与注意力机制网络,用于提取样本图像对中支持样本图像的支持图像特征,以及提取样本图像对中查询样本图像的查询图像特征;相似度计算网络,用于对支持图像特征以及查询图像特征进行拼接处理以及相似度计算处理,获取样本图像对中支持样本图像与查询样本图像之间的预测相似度。
其中,特征提取网络与注意力机制网络,可以分别为视觉模型(即,VisionTransformer模型)中的特征提取网络以及注意力机制网络。其中,注意力机制网络,能够对特征提取网络提取到的图像特征进行重要性筛选以及特征提取处理,获取其中比较重要的图像特征。在物种识别任务下,特征提取网络结合注意力机制网络,可以提取到图像中物种间的较小差异,提高后续计算得到的相似度的准确度。其中,Vision Transformer模型的架构在提取物种间较小差异的效果更好,可以进一步提高相似度计算的准确度,进而提高物种识别的准确度。
其中,相似度计算网络,例如可以为4层卷积网络。相似度计算网络具体用于,对支持图像特征以及查询图像特征进行拼接处理,得到拼接图像特征;根据拼接图像特征进行相似度计算处理,得到样本图像对中支持样本图像与查询样本图像之间的预测相似度。
步骤103,针对每个训练数据集,根据训练数据集中的支持样本图像、支持样本图像的类别、查询样本图像以及查询样本图像的类别,确定训练数据集中的多个样本图像对,以及样本图像对中支持样本图像以及查询样本图像之间的样本相似度。
在本公开实施例中,电子设备执行步骤103的过程例如可以为,针对每个训练数据集,根据训练数据集中的支持样本图像以及查询样本图像,生成多个样本图像对;针对每个样本图像对,根据样本图像对中支持样本图像的类别,以及样本图像对中查询样本图像的类别,确定样本图像对中支持样本图像与查询样本图像之间的样本相似度。
其中,针对每个训练数据集,电子设备生成多个样本图像对的过程例如可以为,从训练数据集中的支持集中随机选择一个支持样本图像,从训练数据集中的查询集中随机选择一个查询样本图像;对选择的支持样本图像以及查询样本图像进行配对处理,得到一个样本图像对;重复执行该段中的上述所有步骤,得到所述训练数据集中的多个样本图像对。
其中,电子设备确定样本图像对中支持样本图像与查询样本图像之间的样本相似度的过程例如可以为,确定样本图像对中支持样本图像的类别与查询样本图像的类别是否相同;若该两个类别相同,则确定样本相似度为1;若该两个类别不同,则确定样本相似度为0。
其中,对训练数据集中的支持样本图像以及查询样本图像进行配对处理,得到样本图像对;并结合样本图像对中两个样本图像的类别,确定样本相似度,能够自动且准确确定两个样本图像的样本相似度,方便后续对图像分类模型进行训练处理,降低训练成本。
步骤104,依次针对每个训练数据集,以训练数据集中样本图像对中的支持样本图像和查询样本图像为图像分类模型的输入,以样本图像对中支持样本图像以及查询样本图像之间的样本相似度为图像分类模型的输出,对图像分类模型进行训练,得到训练好的图像分类模型。
本公开实施例的图像分类模型的训练方法,通过获取多个训练数据集,训练数据集包括支持集和查询集;支持集包括支持样本图像以及支持样本图像的类别;查询集中包括查询样本图像以及查询样本图像的类别;获取初始的图像分类模型;针对每个训练数据集,根据训练数据集中的支持样本图像、支持样本图像的类别、查询样本图像以及查询样本图像的类别,确定训练数据集中的多个样本图像对,以及样本图像对中支持样本图像以及查询样本图像之间的样本相似度;依次针对每个训练数据集,以训练数据集中样本图像对中的支持样本图像和查询样本图像为图像分类模型的输入,以样本图像对中支持样本图像以及查询样本图像之间的样本相似度为图像分类模型的输出,对图像分类模型进行训练,得到训练好的图像分类模型,从而能够根据较少的样本图像以及对应的类别,训练得到准确度较高的图像分类模型,能够适用于图像标注数据比较缺乏的任务,例如物种细粒度识别任务等,提高物种细粒度识别任务下的识别准确度。
其中,为了准确获取多个训练数据集,方便后续对图像分类模型进行训练,提高训练得到的图像分类模型的准确度,可以控制训练数据集中类别的数量,以及每个类别下样本图像的数量,进而控制两种样本图像对的比例。如图2所示,图2是根据本公开第二实施例的示意图,图2所示实施例可以包括以下步骤:
步骤201,获取原始数据集,其中,原始数据集中包括大于预设数量的样本图像,以及样本图像的类别。
在本公开实施例中,以物种识别任务或者物种细粒度识别任务为例,原始数据集中样本图像的类别,即样本图像中目标对象的物种,例如,鸟类、狗类、猴类等。以鸟类为例,鸟类又可以细分为,游禽类、涉禽类、攀禽类、陆禽类、猛禽类、鸣禽类等。
步骤202,从原始数据集的多个类别中抽取第一类别,并从原始数据集中具有第一类别的样本图像中抽取支持样本图像,得到支持集。
在本公开实施例中,第一类别的数量可以为N个,第一类别下的支持样本图像的数量例如可以为K个。其中,N的数量,可以根据类别的实际数量、以及训练数据集的数量等确定。K的数量,可以根据各个类别下的样本图像的数量来确定。
步骤203,从第一类别中抽取一个类别作为第二类别,并从原始数据集中具有第二类别的样本图像中抽取查询样本图像,得到支持集对应的查询集。
在本公开实施例中,从第一类别中抽取一个类别作为第二类别,从并原始数据集中具有第二类别的样本图像中抽取查询样本图像,能够确保一个训练数据集的支持集和查询集中存在相同类别的样本图像,从而基于支持集和查询集生成样本图像对时,确保两种样本图像对的数量均衡;基于数量均衡的两种样本图像对,对图像分类模型进行训练,能够确保训练得到的图像分类模型针对两种样本图像对进行相似度计算时的准确度。
另外,作为步骤203的替换方案,电子设备生成支持集对应的查询集的过程例如还可以为,从原始数据集的多个类别中抽取一个类别作为第二类别,并从原始数据集中具有第二类别的样本图像中抽取查询样本图像,得到支持集对应的查询集。
另外,作为步骤203的替换方案,电子设备生成支持集对应的查询集的过程例如还可以为,从第一类别中抽取一个类别作为第二类别,并将支持集中具有第二类别的样本图像,作为查询样本图像,得到支持集对应的查询集。
步骤204,根据支持集以及支持集对应的查询集,生成训练数据集。
步骤205,获取初始的图像分类模型。
步骤206,针对每个训练数据集,根据训练数据集中的支持样本图像、支持样本图像的类别、查询样本图像以及查询样本图像的类别,确定训练数据集中的多个样本图像对,以及样本图像对中支持样本图像以及查询样本图像之间的样本相似度。
步骤207,依次针对每个训练数据集,以训练数据集中样本图像对中的支持样本图像和查询样本图像为图像分类模型的输入,以样本图像对中支持样本图像以及查询样本图像之间的样本相似度为图像分类模型的输出,对图像分类模型进行训练,得到训练好的图像分类模型。
其中,需要说明的是,步骤205至步骤207的详细内容,可以参考图1所示实施例中的步骤102至步骤104,此处不再进行详细说明。
本公开实施例的图像分类模型的训练方法,通过获取原始数据集,其中,原始数据集中包括大于预设数量的样本图像,以及样本图像的类别;从原始数据集的多个类别中抽取第一类别,并从原始数据集中具有第一类别的样本图像中抽取支持样本图像,得到支持集;从第一类别中抽取一个类别作为第二类别,并从原始数据集中具有第二类别的样本图像中抽取查询样本图像,得到支持集对应的查询集;根据支持集以及支持集对应的查询集,生成训练数据集;获取初始的图像分类模型;针对每个训练数据集,根据训练数据集中的支持样本图像、支持样本图像的类别、查询样本图像以及查询样本图像的类别,确定训练数据集中的多个样本图像对,以及样本图像对中支持样本图像以及查询样本图像之间的样本相似度;依次针对每个训练数据集,以训练数据集中样本图像对中的支持样本图像和查询样本图像为图像分类模型的输入,以样本图像对中支持样本图像以及查询样本图像之间的样本相似度为图像分类模型的输出,对图像分类模型进行训练,得到训练好的图像分类模型,从而能够根据较少的样本图像以及对应的类别,训练得到准确度较高的图像分类模型,能够适用于图像标注数据比较缺乏的任务,例如物种细粒度识别任务等,提高物种细粒度识别任务下的识别准确度。
其中,为了准确地根据多个训练数据集对图像分类模型进行训练,进一步提高训练得到的图像分类模型的准确度,可以依次针对每个训练数据集,构建损失函数对图像分类模型进行训练。如图3所示,图3是根据本公开第三实施例的示意图,图3所示实施例可以包括以下步骤:
步骤301,获取多个训练数据集,训练数据集包括支持集和查询集;支持集包括支持样本图像以及支持样本图像的类别;查询集中包括查询样本图像以及查询样本图像的类别。
步骤302,获取初始的图像分类模型。
步骤303,针对每个训练数据集,根据训练数据集中的支持样本图像、支持样本图像的类别、查询样本图像以及查询样本图像的类别,确定训练数据集中的多个样本图像对,以及样本图像对中支持样本图像以及查询样本图像之间的样本相似度。
步骤304,依次针对每个训练数据集,将训练数据集中样本图像对中的支持样本图像和查询样本图像输入图像分类模型,获取样本图像对中支持样本图像与查询样本图像之间的预测相似度。
在本公开实施例中,电子设备执行步骤304的过程例如可以为,将训练数据集中样本图像对中的支持样本图像和查询样本图像输入图像分类模型的特征提取网络以及注意力机制网络,获取支持样本图像的支持图像特征,以及查询样本图像的查询图像特征;将支持图像特征和查询图像特征输入图像分类模型的相似度计算网络,获取相似度计算网络输出的预测相似度。
步骤305,根据预测相似度,以及样本图像对中支持样本图像与查询样本图像之间的样本相似度,构建损失函数。
在本公开实施例中,损失函数可以根据一个训练数据集中样本图像对的预测相似度以及样本相似度进行构建。例如,损失函数可以为一个训练数据集中样本图像对的预测相似度以及样本相似度的差方和。
步骤306,根据损失函数的数值,对图像分类模型进行参数调整,实现训练。
在本公开实施例中,以训练数据集的数量为5个为例,针对第一个训练数据集,根据该训练数据集中样本图像对的预测相似度以及样本相似度,构建损失函数;根据损失函数的数值,对图像分类模型进行参数调整;针对第二个训练数据集至第五个训练数据集,分别参考第一个训练数据集的执行步骤;第五个训练数据集执行上述步骤完成后,得到训练好的图像分类模型。
其中,根据多个训练数据集中样本图像对的预测相似度以及样本相似度,对图像分类模型进行参数调整,可以使得图像分类模型能够学习到不同类别间的较小差异,从而能够完成新的没有接触过的类别的分类任务。
本公开实施例的图像分类模型的训练方法,通过获取多个训练数据集,训练数据集包括支持集和查询集;支持集包括支持样本图像以及支持样本图像的类别;查询集中包括查询样本图像以及查询样本图像的类别;获取初始的图像分类模型;针对每个训练数据集,根据训练数据集中的支持样本图像、支持样本图像的类别、查询样本图像以及查询样本图像的类别,确定训练数据集中的多个样本图像对,以及样本图像对中支持样本图像以及查询样本图像之间的样本相似度;依次针对每个训练数据集,将训练数据集中样本图像对中的支持样本图像和查询样本图像输入图像分类模型,获取样本图像对中支持样本图像与查询样本图像之间的预测相似度;根据预测相似度,以及样本图像对中支持样本图像与查询样本图像之间的样本相似度,构建损失函数;根据损失函数的数值,对图像分类模型进行参数调整,实现训练,从而能够根据较少的样本图像以及对应的类别,训练得到准确度较高的图像分类模型,能够适用于图像标注数据比较缺乏的任务,例如物种细粒度识别任务等,提高物种细粒度识别任务下的识别准确度。
图4是根据本公开第四实施例的示意图,需要说明的是,本公开实施例的图像分类方法可应用于图像分类装置,该装置可被配置于电子设备中,以使该电子设备可以执行图像分类功能。以下实施例中以执行主体为电子设备为例进行说明。
其中,电子设备可以为任一具有计算能力的设备,例如可以为个人电脑(PersonalComputer,简称PC)、移动终端、服务器等,移动终端例如可以为车载设备、手机、平板电脑、个人数字助理、穿戴式设备等具有各种操作系统、触摸屏和/或显示屏的硬件设备。
如图4所示,该图像分类方法可以包括如下步骤:
步骤401,获取待处理图像以及支持集,支持集包括多个支持样本图像以及支持样本图像的类别。
在本公开实施例中,支持集中,支持样本图像的数量可以为多个,类别的数量可以为多个。例如,支持集中可以包括N个类别,每个类别下可以有K个支持样本图像。
在本公开实施例中,支持集的数量可以为多个。电子设备获取支持集的过程例如可以为,从原始数据集的多个类别中抽取第一类别,并从原始数据集中具有第一类别的样本图像中抽取支持样本图像,得到一个支持集。其中,多个支持集中的类别可以相同或者不同。
在本公开实施例中,支持样本图像的类别,为支持样本图像中目标对象所属的物种;查询样本图像的类别,为查询样本图像中目标对象所属的物种。其中,类别为物种的情况下,图像分类模型,可以用于细粒度的物种识别,从而避免在物种识别任务下由专家进行物种识别,降低物种识别任务下的人工成本,提高物种识别任务下的物种识别准确度。
步骤402,根据待处理图像以及多个支持样本图像,生成多个图像对;图像对中包括待处理图像以及支持样本图像。
在本公开实施例中,电子设备执行步骤402的过程例如可以为,针对多个支持样本图像中的每个支持样本图像,对该支持样本图像以及待处理图像进行配对处理,得到一个图像对。
步骤403,将待处理图像对输入图像分类模型的特征提取网络以及注意力机制网络,获取图像对中待处理图像的图像特征,以及图像对中支持样本图像的支持图像特征;图像分类模型基于图像分类模型的训练方法训练得到。
在本公开实施例中,图像分类模型中包括依次连接的特征提取网络、注意力机制网络和相似度计算网络;特征提取网络与注意力机制网络,用于提取图像对中支持样本图像的支持图像特征,以及提取图像对中待处理图像的图像特征;相似度计算网络,用于对支持图像特征以及图像特征进行拼接处理以及相似度计算处理,获取图像对中支持样本图像与待处理图像之间的相似度。
其中,图像分类模型,基于图1至图3中任一实施例的图像分类模型的训练方法训练得到。
其中,需要说明的是,本公开实施例中支持集中支持样本图像的数量,与图1至图3实施例中多个训练数据集中样本图像的数量,之间的比值可以为固定数值或者靠近固定数值。其中,固定数值例如可以为3:7,以同时确保图像分类模型的准确度,以及对待处理图像进行分类时的准确度。
步骤404,将图像对中待处理图像的图像特征,以及图像对中支持样本图像的支持图像特征,输入图像分类模型中的相似度计算网络,获取待处理图像与支持样本图像之间的相似度。
步骤405,根据待处理图像与支持样本图像之间的相似度,以及支持样本图像的类别,确定待处理图像的类别。
在本公开实施例中,一种示例中,电子设备执行步骤405的过程例如可以为,根据待处理图像与支持样本图像之间的相似度,从多个支持样本图像中选择目标样本图像;将目标样本图像的类别,确定为待处理图像的类别。
其中,电子设备从多个支持样本图像中选择目标样本图像的过程例如可以为,根据相似度对多个支持样本图像进行降序排序,得到排序结果;将排序结果中排序在最前的支持样本图像,作为目标样本图像。
其中,与待处理图像相似度最大的支持样本图像,该支持样本图像的类别与待处理图像的类别相同的可能性最大,因此,电子设备将对应的相似度最大的支持样本图像,作为目标样本图像,能够进一步提高待处理图像类别确定的准确度,进一步提高图像分类的准确度。
另一种示例中,电子设备执行步骤405的过程例如可以为,针对支持样本图像的每个类别,对该类别下各个支持样本图像与待处理图像的相似度进行加和求平均值,得到该类别的相似度,进而得到各个类别的相似度;将对应的相似度最大的类别,确定为待处理图像的类别。
本公开实施例的图像分类方法,通过获取待处理图像以及支持集,支持集包括多个支持样本图像以及支持样本图像的类别;根据待处理图像以及多个支持样本图像,生成多个图像对;图像对中包括待处理图像以及支持样本图像;将待处理图像对输入图像分类模型的特征提取网络以及注意力机制网络,获取图像对中待处理图像的图像特征,以及图像对中支持样本图像的支持图像特征;图像分类模型基于图像分类模型的训练方法训练得到;将图像对中待处理图像的图像特征,以及图像对中支持样本图像的支持图像特征,输入图像分类模型中的相似度计算网络,获取待处理图像与支持样本图像之间的相似度;根据待处理图像与支持样本图像之间的相似度,以及支持样本图像的类别,确定待处理图像的类别,从而能够采用基于较少样本图像训练得到的图像分类模型,准确确定待处理图像的类别,在降低人工成本的情况下,确保图像分类的准确度。
为了实现上述实施例,本公开还提供一种图像分类模型的训练装置。如图5所示,图5是根据本公开第五实施例的示意图。该图像分类模型的训练装置50可以包括:第一获取模块501、第二获取模块502、确定模块503和训练模块504。
其中,第一获取模块501,用于获取多个训练数据集,所述训练数据集包括支持集和查询集;所述支持集包括支持样本图像以及所述支持样本图像的类别;所述查询集中包括查询样本图像以及所述查询样本图像的类别;第二获取模块502,用于获取初始的图像分类模型;确定模块503,用于针对每个训练数据集,根据所述训练数据集中的所述支持样本图像、所述支持样本图像的类别、所述查询样本图像以及所述查询样本图像的类别,确定所述训练数据集中的多个样本图像对,以及所述样本图像对中支持样本图像以及查询样本图像之间的样本相似度;训练模块504,用于依次针对每个训练数据集,以所述训练数据集中样本图像对中的支持样本图像和查询样本图像为所述图像分类模型的输入,以所述样本图像对中支持样本图像以及查询样本图像之间的样本相似度为所述图像分类模型的输出,对所述图像分类模型进行训练,得到训练好的图像分类模型。
作为本公开实施例的一种可能实现方式,所述第一获取模块501具体用于,获取原始数据集,其中,所述原始数据集中包括大于预设数量的样本图像,以及所述样本图像的类别;从所述原始数据集的多个类别中抽取第一类别,并从所述原始数据集中具有所述第一类别的样本图像中抽取支持样本图像,得到支持集;从所述第一类别中抽取一个类别作为第二类别,并从所述原始数据集中具有所述第二类别的样本图像中抽取查询样本图像,得到所述支持集对应的查询集;根据所述支持集以及所述支持集对应的查询集,生成训练数据集。
作为本公开实施例的一种可能实现方式,所述确定模块503具体用于,针对每个训练数据集,根据所述训练数据集中的所述支持样本图像以及所述查询样本图像,生成多个所述样本图像对;针对每个样本图像对,根据所述样本图像对中支持样本图像的类别,以及所述样本图像对中查询样本图像的类别,确定所述样本图像对中支持样本图像与查询样本图像之间的样本相似度。
作为本公开实施例的一种可能实现方式,所述图像分类模型中包括依次连接的特征提取网络、注意力机制网络和相似度计算网络;所述特征提取网络与所述注意力机制网络,用于提取样本图像对中支持样本图像的支持图像特征,以及提取所述样本图像对中查询样本图像的查询图像特征;所述相似度计算网络,用于对所述支持图像特征以及所述查询图像特征进行拼接处理以及相似度计算处理,获取所述样本图像对中支持样本图像与查询样本图像之间的预测相似度。
作为本公开实施例的一种可能实现方式,所述特征提取网络与所述注意力机制网络,分别为视觉Vision Transformer模型中的特征提取网络以及注意力机制网络。
作为本公开实施例的一种可能实现方式,所述训练模块504具体用于,依次针对每个训练数据集,将所述训练数据集中样本图像对中的支持样本图像和查询样本图像输入所述图像分类模型,获取所述样本图像对中支持样本图像与查询样本图像之间的预测相似度;根据所述预测相似度,以及所述样本图像对中支持样本图像与查询样本图像之间的样本相似度,构建损失函数;根据所述损失函数的数值,对所述图像分类模型进行参数调整,实现训练。
作为本公开实施例的一种可能实现方式,所述支持样本图像的类别,为所述支持样本图像中目标对象所属的物种;所述查询样本图像的类别,为所述查询样本图像中目标对象所属的物种。
本公开实施例的图像分类模型的训练装置,通过获取多个训练数据集,训练数据集包括支持集和查询集;支持集包括支持样本图像以及支持样本图像的类别;查询集中包括查询样本图像以及查询样本图像的类别;获取初始的图像分类模型;针对每个训练数据集,根据训练数据集中的支持样本图像、支持样本图像的类别、查询样本图像以及查询样本图像的类别,确定训练数据集中的多个样本图像对,以及样本图像对中支持样本图像以及查询样本图像之间的样本相似度;依次针对每个训练数据集,以训练数据集中样本图像对中的支持样本图像和查询样本图像为图像分类模型的输入,以样本图像对中支持样本图像以及查询样本图像之间的样本相似度为图像分类模型的输出,对图像分类模型进行训练,得到训练好的图像分类模型,从而能够根据较少的样本图像以及对应的类别,训练得到准确度较高的图像分类模型,能够适用于图像标注数据比较缺乏的任务,例如物种细粒度识别任务等,提高物种细粒度识别任务下的识别准确度。
为了实现上述实施例,本公开还提供一种图像分类装置,如图6所示,图6是根据本公开第六实施例的示意图。该图像分类装置60可以包括:获取模块601、生成模块602、第一输入模块603、第二输入模块604和确定模块605。
其中,获取模块601,用于获取待处理图像以及支持集,所述支持集包括多个支持样本图像以及所述支持样本图像的类别;生成模块602,用于根据所述待处理图像以及多个所述支持样本图像,生成多个图像对;所述图像对中包括所述待处理图像以及所述支持样本图像;第一输入模块603,用于将所述待处理图像对输入图像分类模型的特征提取网络以及注意力机制网络,获取所述图像对中待处理图像的图像特征,以及所述图像对中支持样本图像的支持图像特征;所述图像分类模型基于图像分类模型的训练方法训练得到;第二输入模块604,用于将所述图像对中待处理图像的图像特征,以及所述图像对中支持样本图像的支持图像特征,输入所述图像分类模型中的相似度计算网络,获取所述待处理图像与所述支持样本图像之间的相似度;确定模块605,用于根据所述待处理图像与所述支持样本图像之间的相似度,以及所述支持样本图像的类别,确定所述待处理图像的类别。
作为本公开实施例的一种可能实现方式,所述确定模块605具体用于,根据所述待处理图像与所述支持样本图像之间的相似度,从多个所述支持样本图像中选择目标样本图像;将所述目标样本图像的类别,确定为所述待处理图像的类别。
作为本公开实施例的一种可能实现方式,所述支持样本图像的类别,为所述支持样本图像中目标对象所属的物种;所述待处理图像的类别,为所述待处理图像中目标对象所属的物种。
本公开实施例的图像分类装置,通过获取待处理图像以及支持集,支持集包括多个支持样本图像以及支持样本图像的类别;根据待处理图像以及多个支持样本图像,生成多个图像对;图像对中包括待处理图像以及支持样本图像;将待处理图像对输入图像分类模型的特征提取网络以及注意力机制网络,获取图像对中待处理图像的图像特征,以及图像对中支持样本图像的支持图像特征;图像分类模型基于图像分类模型的训练方法训练得到;将图像对中待处理图像的图像特征,以及图像对中支持样本图像的支持图像特征,输入图像分类模型中的相似度计算网络,获取待处理图像与支持样本图像之间的相似度;根据待处理图像与支持样本图像之间的相似度,以及支持样本图像的类别,确定待处理图像的类别,从而能够采用基于较少样本图像训练得到的图像分类模型,准确确定待处理图像的类别,在降低人工成本的情况下,确保图像分类的准确度。
本公开的技术方案中,所涉及的用户个人信息的收集、存储、使用、加工、传输、提供和公开等处理,均在征得用户同意的前提下进行,并且均符合相关法律法规的规定,且不违背公序良俗。
根据本公开的实施例,本公开还提供了一种电子设备、一种可读存储介质和一种计算机程序产品。
图7示出了可以用来实施本公开的实施例的示例电子设备700的示意性框图。电子设备旨在表示各种形式的数字计算机,诸如,膝上型计算机、台式计算机、工作台、个人数字助理、服务器、刀片式服务器、大型计算机、和其它适合的计算机。电子设备还可以表示各种形式的移动装置,诸如,个人数字处理、蜂窝电话、智能电话、可穿戴设备和其它类似的计算装置。本文所示的部件、它们的连接和关系、以及它们的功能仅仅作为示例,并且不意在限制本文中描述的和/或者要求的本公开的实现。
如图7所示,设备700包括计算单元701,其可以根据存储在只读存储器(ROM)702中的计算机程序或者从存储单元708加载到随机访问存储器(RAM)703中的计算机程序,来执行各种适当的动作和处理。在RAM 703中,还可存储设备700操作所需的各种程序和数据。计算单元701、ROM 702以及RAM 703通过总线704彼此相连。输入/输出(I/O)接口705也连接至总线704。
设备700中的多个部件连接至I/O接口705,包括:输入单元706,例如键盘、鼠标等;输出单元707,例如各种类型的显示器、扬声器等;存储单元708,例如磁盘、光盘等;以及通信单元709,例如网卡、调制解调器、无线通信收发机等。通信单元709允许设备700通过诸如因特网的计算机网络和/或各种电信网络与其他设备交换信息/数据。
计算单元701可以是各种具有处理和计算能力的通用和/或专用处理组件。计算单元701的一些示例包括但不限于中央处理单元(CPU)、图形处理单元(GPU)、各种专用的人工智能(AI)计算芯片、各种运行机器学习模型算法的计算单元、数字信号处理器(DSP)、以及任何适当的处理器、控制器、微控制器等。计算单元701执行上文所描述的各个方法和处理,例如图像分类模型的训练方法或者图像分类方法。例如,在一些实施例中,图像分类模型的训练方法或者图像分类方法可被实现为计算机软件程序,其被有形地包含于机器可读介质,例如存储单元708。在一些实施例中,计算机程序的部分或者全部可以经由ROM 702和/或通信单元709而被载入和/或安装到设备700上。当计算机程序加载到RAM 703并由计算单元701执行时,可以执行上文描述的图像分类模型的训练方法或者图像分类方法的一个或多个步骤。备选地,在其他实施例中,计算单元701可以通过其他任何适当的方式(例如,借助于固件)而被配置为执行图像分类模型的训练方法或者图像分类方法。
本文中以上描述的系统和技术的各种实施方式可以在数字电子电路系统、集成电路系统、场可编程门阵列(FPGA)、专用集成电路(ASIC)、专用标准产品(ASSP)、芯片上系统的系统(SOC)、负载可编程逻辑设备(CPLD)、计算机硬件、固件、软件、和/或它们的组合中实现。这些各种实施方式可以包括:实施在一个或者多个计算机程序中,该一个或者多个计算机程序可在包括至少一个可编程处理器的可编程系统上执行和/或解释,该可编程处理器可以是专用或者通用可编程处理器,可以从存储系统、至少一个输入装置、和至少一个输出装置接收数据和指令,并且将数据和指令传输至该存储系统、该至少一个输入装置、和该至少一个输出装置。
用于实施本公开的方法的程序代码可以采用一个或多个编程语言的任何组合来编写。这些程序代码可以提供给通用计算机、专用计算机或其他可编程数据处理装置的处理器或控制器,使得程序代码当由处理器或控制器执行时使流程图和/或框图中所规定的功能/操作被实施。程序代码可以完全在机器上执行、部分地在机器上执行,作为独立软件包部分地在机器上执行且部分地在远程机器上执行或完全在远程机器或服务器上执行。
在本公开的上下文中,机器可读介质可以是有形的介质,其可以包含或存储以供指令执行系统、装置或设备使用或与指令执行系统、装置或设备结合地使用的程序。机器可读介质可以是机器可读信号介质或机器可读储存介质。机器可读介质可以包括但不限于电子的、磁性的、光学的、电磁的、红外的、或半导体系统、装置或设备,或者上述内容的任何合适组合。机器可读存储介质的更具体示例会包括基于一个或多个线的电气连接、便携式计算机盘、硬盘、随机存取存储器(RAM)、只读存储器(ROM)、可擦除可编程只读存储器(EPROM或快闪存储器)、光纤、便捷式紧凑盘只读存储器(CD-ROM)、光学储存设备、磁储存设备、或上述内容的任何合适组合。
为了提供与用户的交互,可以在计算机上实施此处描述的系统和技术,该计算机具有:用于向用户显示信息的显示装置(例如,CRT(阴极射线管)或者LCD(液晶显示器)监视器);以及键盘和指向装置(例如,鼠标或者轨迹球),用户可以通过该键盘和该指向装置来将输入提供给计算机。其它种类的装置还可以用于提供与用户的交互;例如,提供给用户的反馈可以是任何形式的传感反馈(例如,视觉反馈、听觉反馈、或者触觉反馈);并且可以用任何形式(包括声输入、语音输入或者、触觉输入)来接收来自用户的输入。
可以将此处描述的系统和技术实施在包括后台部件的计算系统(例如,作为数据服务器)、或者包括中间件部件的计算系统(例如,应用服务器)、或者包括前端部件的计算系统(例如,具有图形用户界面或者网络浏览器的用户计算机,用户可以通过该图形用户界面或者该网络浏览器来与此处描述的系统和技术的实施方式交互)、或者包括这种后台部件、中间件部件、或者前端部件的任何组合的计算系统中。可以通过任何形式或者介质的数字数据通信(例如,通信网络)来将系统的部件相互连接。通信网络的示例包括:局域网(LAN)、广域网(WAN)和互联网。
计算机系统可以包括客户端和服务器。客户端和服务器一般远离彼此并且通常通过通信网络进行交互。通过在相应的计算机上运行并且彼此具有客户端-服务器关系的计算机程序来产生客户端和服务器的关系。服务器可以是云服务器,也可以为分布式系统的服务器,或者是结合了区块链的服务器。
应该理解,可以使用上面所示的各种形式的流程,重新排序、增加或删除步骤。例如,本发公开中记载的各步骤可以并行地执行也可以顺序地执行也可以不同的次序执行,只要能够实现本公开公开的技术方案所期望的结果,本文在此不进行限制。
上述具体实施方式,并不构成对本公开保护范围的限制。本领域技术人员应该明白的是,根据设计要求和其他因素,可以进行各种修改、组合、子组合和替代。任何在本公开的精神和原则之内所作的修改、等同替换和改进等,均应包含在本公开保护范围之内。
Claims (23)
1.一种图像分类模型的训练方法,包括:
获取多个训练数据集,所述训练数据集包括支持集和查询集;所述支持集包括支持样本图像以及所述支持样本图像的类别;所述查询集中包括查询样本图像以及所述查询样本图像的类别;
获取初始的图像分类模型;
针对每个训练数据集,根据所述训练数据集中的所述支持样本图像、所述支持样本图像的类别、所述查询样本图像以及所述查询样本图像的类别,确定所述训练数据集中的多个样本图像对,以及所述样本图像对中支持样本图像以及查询样本图像之间的样本相似度;
依次针对每个训练数据集,以所述训练数据集中样本图像对中的支持样本图像和查询样本图像为所述图像分类模型的输入,以所述样本图像对中支持样本图像以及查询样本图像之间的样本相似度为所述图像分类模型的输出,对所述图像分类模型进行训练,得到训练好的图像分类模型。
2.根据权利要求1所述的方法,其中,所述获取多个训练数据集,包括:
获取原始数据集,其中,所述原始数据集中包括大于预设数量的样本图像,以及所述样本图像的类别;
从所述原始数据集的多个类别中抽取第一类别,并从所述原始数据集中具有所述第一类别的样本图像中抽取支持样本图像,得到支持集;
从所述第一类别中抽取一个类别作为第二类别,并从所述原始数据集中具有所述第二类别的样本图像中抽取查询样本图像,得到所述支持集对应的查询集;
根据所述支持集以及所述支持集对应的查询集,生成训练数据集。
3.根据权利要求1所述的方法,其中,所述针对每个训练数据集,根据所述训练数据集中的所述支持样本图像、所述支持样本图像的类别、所述查询样本图像以及所述查询样本图像的类别,确定所述训练数据集中的多个样本图像对,以及所述样本图像对中支持样本图像以及查询样本图像之间的样本相似度,包括:
针对每个训练数据集,根据所述训练数据集中的所述支持样本图像以及所述查询样本图像,生成多个所述样本图像对;
针对每个样本图像对,根据所述样本图像对中支持样本图像的类别,以及所述样本图像对中查询样本图像的类别,确定所述样本图像对中支持样本图像与查询样本图像之间的样本相似度。
4.根据权利要求1所述的方法,其中,所述图像分类模型中包括依次连接的特征提取网络、注意力机制网络和相似度计算网络;
所述特征提取网络与所述注意力机制网络,用于提取样本图像对中支持样本图像的支持图像特征,以及提取所述样本图像对中查询样本图像的查询图像特征;
所述相似度计算网络,用于对所述支持图像特征以及所述查询图像特征进行拼接处理以及相似度计算处理,获取所述样本图像对中支持样本图像与查询样本图像之间的预测相似度。
5.根据权利要求4所述的方法,其中,所述特征提取网络与所述注意力机制网络,分别为视觉Vision Transformer模型中的特征提取网络以及注意力机制网络。
6.根据权利要求1所述的方法,其中,所述依次针对每个训练数据集,以所述训练数据集中样本图像对中的支持样本图像和查询样本图像为所述图像分类模型的输入,以所述样本图像对中支持样本图像以及查询样本图像之间的样本相似度为所述图像分类模型的输出,对所述图像分类模型进行训练,得到训练好的图像分类模型,包括:
依次针对每个训练数据集,将所述训练数据集中样本图像对中的支持样本图像和查询样本图像输入所述图像分类模型,获取所述样本图像对中支持样本图像与查询样本图像之间的预测相似度;
根据所述预测相似度,以及所述样本图像对中支持样本图像与查询样本图像之间的样本相似度,构建损失函数;
根据所述损失函数的数值,对所述图像分类模型进行参数调整,实现训练。
7.根据权利要求1-6中任一项所述的方法,其中,所述支持样本图像的类别,为所述支持样本图像中目标对象所属的物种;
所述查询样本图像的类别,为所述查询样本图像中目标对象所属的物种。
8.一种图像分类方法,包括:
获取待处理图像以及支持集,所述支持集包括多个支持样本图像以及所述支持样本图像的类别;
根据所述待处理图像以及多个所述支持样本图像,生成多个图像对;所述图像对中包括所述待处理图像以及所述支持样本图像;
将所述待处理图像对输入图像分类模型的特征提取网络以及注意力机制网络,获取所述图像对中待处理图像的图像特征,以及所述图像对中支持样本图像的支持图像特征;所述图像分类模型基于权利要求1-7中任一项所述的方法训练得到;
将所述图像对中待处理图像的图像特征,以及所述图像对中支持样本图像的支持图像特征,输入所述图像分类模型中的相似度计算网络,获取所述待处理图像与所述支持样本图像之间的相似度;
根据所述待处理图像与所述支持样本图像之间的相似度,以及所述支持样本图像的类别,确定所述待处理图像的类别。
9.根据权利要求8所述的方法,其中,所述根据所述待处理图像与所述支持样本图像之间的相似度,以及所述支持样本图像的类别,确定所述待处理图像的类别,包括:
根据所述待处理图像与所述支持样本图像之间的相似度,从多个所述支持样本图像中选择目标样本图像;
将所述目标样本图像的类别,确定为所述待处理图像的类别。
10.根据权利要求8或9所述的方法,其中,所述支持样本图像的类别,为所述支持样本图像中目标对象所属的物种;
所述待处理图像的类别,为所述待处理图像中目标对象所属的物种。
11.一种图像分类模型的训练装置,包括:
第一获取模块,用于获取多个训练数据集,所述训练数据集包括支持集和查询集;所述支持集包括支持样本图像以及所述支持样本图像的类别;所述查询集中包括查询样本图像以及所述查询样本图像的类别;
第二获取模块,用于获取初始的图像分类模型;
确定模块,用于针对每个训练数据集,根据所述训练数据集中的所述支持样本图像、所述支持样本图像的类别、所述查询样本图像以及所述查询样本图像的类别,确定所述训练数据集中的多个样本图像对,以及所述样本图像对中支持样本图像以及查询样本图像之间的样本相似度;
训练模块,用于依次针对每个训练数据集,以所述训练数据集中样本图像对中的支持样本图像和查询样本图像为所述图像分类模型的输入,以所述样本图像对中支持样本图像以及查询样本图像之间的样本相似度为所述图像分类模型的输出,对所述图像分类模型进行训练,得到训练好的图像分类模型。
12.根据权利要求11所述的装置,其中,所述第一获取模块具体用于,
获取原始数据集,其中,所述原始数据集中包括大于预设数量的样本图像,以及所述样本图像的类别;
从所述原始数据集的多个类别中抽取第一类别,并从所述原始数据集中具有所述第一类别的样本图像中抽取支持样本图像,得到支持集;
从所述第一类别中抽取一个类别作为第二类别,并从所述原始数据集中具有所述第二类别的样本图像中抽取查询样本图像,得到所述支持集对应的查询集;
根据所述支持集以及所述支持集对应的查询集,生成训练数据集。
13.根据权利要求11所述的装置,其中,所述确定模块具体用于,
针对每个训练数据集,根据所述训练数据集中的所述支持样本图像以及所述查询样本图像,生成多个所述样本图像对;
针对每个样本图像对,根据所述样本图像对中支持样本图像的类别,以及所述样本图像对中查询样本图像的类别,确定所述样本图像对中支持样本图像与查询样本图像之间的样本相似度。
14.根据权利要求11所述的装置,其中,所述图像分类模型中包括依次连接的特征提取网络、注意力机制网络和相似度计算网络;
所述特征提取网络与所述注意力机制网络,用于提取样本图像对中支持样本图像的支持图像特征,以及提取所述样本图像对中查询样本图像的查询图像特征;
所述相似度计算网络,用于对所述支持图像特征以及所述查询图像特征进行拼接处理以及相似度计算处理,获取所述样本图像对中支持样本图像与查询样本图像之间的预测相似度。
15.根据权利要求14所述的装置,其中,所述特征提取网络与所述注意力机制网络,分别为视觉Vision Transformer模型中的特征提取网络以及注意力机制网络。
16.根据权利要求11所述的装置,其中,所述训练模块具体用于,
依次针对每个训练数据集,将所述训练数据集中样本图像对中的支持样本图像和查询样本图像输入所述图像分类模型,获取所述样本图像对中支持样本图像与查询样本图像之间的预测相似度;
根据所述预测相似度,以及所述样本图像对中支持样本图像与查询样本图像之间的样本相似度,构建损失函数;
根据所述损失函数的数值,对所述图像分类模型进行参数调整,实现训练。
17.根据权利要求11-16中任一项所述的装置,其中,所述支持样本图像的类别,为所述支持样本图像中目标对象所属的物种;
所述查询样本图像的类别,为所述查询样本图像中目标对象所属的物种。
18.一种图像分类装置,包括:
获取模块,用于获取待处理图像以及支持集,所述支持集包括多个支持样本图像以及所述支持样本图像的类别;
生成模块,用于根据所述待处理图像以及多个所述支持样本图像,生成多个图像对;所述图像对中包括所述待处理图像以及所述支持样本图像;
第一输入模块,用于将所述待处理图像对输入图像分类模型的特征提取网络以及注意力机制网络,获取所述图像对中待处理图像的图像特征,以及所述图像对中支持样本图像的支持图像特征;所述图像分类模型基于权利要求1-7中任一项所述的方法训练得到;
第二输入模块,用于将所述图像对中待处理图像的图像特征,以及所述图像对中支持样本图像的支持图像特征,输入所述图像分类模型中的相似度计算网络,获取所述待处理图像与所述支持样本图像之间的相似度;
确定模块,用于根据所述待处理图像与所述支持样本图像之间的相似度,以及所述支持样本图像的类别,确定所述待处理图像的类别。
19.根据权利要求18所述的装置,其中,所述确定模块具体用于,
根据所述待处理图像与所述支持样本图像之间的相似度,从多个所述支持样本图像中选择目标样本图像;
将所述目标样本图像的类别,确定为所述待处理图像的类别。
20.根据权利要求18或19所述的装置,其中,所述支持样本图像的类别,为所述支持样本图像中目标对象所属的物种;
所述待处理图像的类别,为所述待处理图像中目标对象所属的物种。
21.一种电子设备,包括:
至少一个处理器;以及
与所述至少一个处理器通信连接的存储器;其中,
所述存储器存储有可被所述至少一个处理器执行的指令,所述指令被所述至少一个处理器执行,以使所述至少一个处理器能够执行权利要求1-7中任一项所述的方法;或者,执行权利要求8-10中任一项所述的方法。
22.一种存储有计算机指令的非瞬时计算机可读存储介质,其中,所述计算机指令用于使所述计算机执行根据权利要求1-7中任一项所述的方法;或者,执行根据权利要求8-10中任一项所述的方法。
23.一种计算机程序产品,包括计算机程序,所述计算机程序在被处理器执行时实现根据权利要求1-7中任一项所述方法的步骤;
或者,实现根据权利要求8-10中任一项所述方法的步骤。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310014934.XA CN115965817A (zh) | 2023-01-05 | 2023-01-05 | 图像分类模型的训练方法、装置及电子设备 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310014934.XA CN115965817A (zh) | 2023-01-05 | 2023-01-05 | 图像分类模型的训练方法、装置及电子设备 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN115965817A true CN115965817A (zh) | 2023-04-14 |
Family
ID=87357838
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202310014934.XA Pending CN115965817A (zh) | 2023-01-05 | 2023-01-05 | 图像分类模型的训练方法、装置及电子设备 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN115965817A (zh) |
Citations (9)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN107145827A (zh) * | 2017-04-01 | 2017-09-08 | 浙江大学 | 基于自适应距离度量学习的跨摄像机行人再识别方法 |
CN108388888A (zh) * | 2018-03-23 | 2018-08-10 | 腾讯科技(深圳)有限公司 | 一种车辆识别方法、装置和存储介质 |
CN111062424A (zh) * | 2019-12-05 | 2020-04-24 | 中国科学院计算技术研究所 | 小样本食品图像识别模型训练方法及食品图像识别方法 |
CN113627522A (zh) * | 2021-08-09 | 2021-11-09 | 华南师范大学 | 基于关系网络的图像分类方法、装置、设备及存储介质 |
CN113780345A (zh) * | 2021-08-06 | 2021-12-10 | 华中科技大学 | 面向中小企业的基于张量注意力的小样本分类方法和系统 |
CN113902256A (zh) * | 2021-09-10 | 2022-01-07 | 支付宝(杭州)信息技术有限公司 | 训练标签预测模型的方法、标签预测方法和装置 |
CN114187905A (zh) * | 2020-08-27 | 2022-03-15 | 海信视像科技股份有限公司 | 用户意图识别模型的训练方法、服务器及显示设备 |
CN114299363A (zh) * | 2021-12-29 | 2022-04-08 | 京东方科技集团股份有限公司 | 图像处理模型的训练方法、图像分类方法及装置 |
CN115424053A (zh) * | 2022-07-25 | 2022-12-02 | 北京邮电大学 | 小样本图像识别方法、装置、设备及存储介质 |
-
2023
- 2023-01-05 CN CN202310014934.XA patent/CN115965817A/zh active Pending
Patent Citations (9)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN107145827A (zh) * | 2017-04-01 | 2017-09-08 | 浙江大学 | 基于自适应距离度量学习的跨摄像机行人再识别方法 |
CN108388888A (zh) * | 2018-03-23 | 2018-08-10 | 腾讯科技(深圳)有限公司 | 一种车辆识别方法、装置和存储介质 |
CN111062424A (zh) * | 2019-12-05 | 2020-04-24 | 中国科学院计算技术研究所 | 小样本食品图像识别模型训练方法及食品图像识别方法 |
CN114187905A (zh) * | 2020-08-27 | 2022-03-15 | 海信视像科技股份有限公司 | 用户意图识别模型的训练方法、服务器及显示设备 |
CN113780345A (zh) * | 2021-08-06 | 2021-12-10 | 华中科技大学 | 面向中小企业的基于张量注意力的小样本分类方法和系统 |
CN113627522A (zh) * | 2021-08-09 | 2021-11-09 | 华南师范大学 | 基于关系网络的图像分类方法、装置、设备及存储介质 |
CN113902256A (zh) * | 2021-09-10 | 2022-01-07 | 支付宝(杭州)信息技术有限公司 | 训练标签预测模型的方法、标签预测方法和装置 |
CN114299363A (zh) * | 2021-12-29 | 2022-04-08 | 京东方科技集团股份有限公司 | 图像处理模型的训练方法、图像分类方法及装置 |
CN115424053A (zh) * | 2022-07-25 | 2022-12-02 | 北京邮电大学 | 小样本图像识别方法、装置、设备及存储介质 |
Non-Patent Citations (1)
Title |
---|
陆妍等: "基于Transformer 的小样本细粒度图像分类方法", 《计算机工程与应用》, pages 1 - 11 * |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN114549874B (zh) | 多目标图文匹配模型的训练方法、图文检索方法及装置 | |
CN111104514A (zh) | 文档标签模型的训练方法及装置 | |
CN114494784A (zh) | 深度学习模型的训练方法、图像处理方法和对象识别方法 | |
CN114118287A (zh) | 样本生成方法、装置、电子设备以及存储介质 | |
CN113836314B (zh) | 知识图谱构建方法、装置、设备以及存储介质 | |
CN115063875A (zh) | 模型训练方法、图像处理方法、装置和电子设备 | |
CN112528641A (zh) | 建立信息抽取模型的方法、装置、电子设备和可读存储介质 | |
CN113360700A (zh) | 图文检索模型的训练和图文检索方法、装置、设备和介质 | |
KR20230006601A (ko) | 정렬 방법, 정렬 모델의 트레이닝 방법, 장치, 전자 기기 및 매체 | |
CN114861059A (zh) | 资源推荐方法、装置、电子设备及存储介质 | |
CN112528146B (zh) | 内容资源推荐方法、装置、电子设备及存储介质 | |
EP4246365A1 (en) | Webpage identification method and apparatus, electronic device, and medium | |
CN115470900A (zh) | 一种神经网络模型的剪枝方法、装置及设备 | |
CN112100362B (zh) | 文档格式推荐模型训练方法、装置以及电子设备 | |
CN114817476A (zh) | 语言模型的训练方法、装置、电子设备和存储介质 | |
CN113408632A (zh) | 提高图像分类准确性的方法、装置、电子设备及存储介质 | |
CN114612725A (zh) | 图像处理方法、装置、设备及存储介质 | |
CN114707638A (zh) | 模型训练、对象识别方法及装置、设备、介质和产品 | |
CN115965817A (zh) | 图像分类模型的训练方法、装置及电子设备 | |
CN114120416A (zh) | 模型训练方法、装置、电子设备及介质 | |
CN112784600A (zh) | 信息排序方法、装置、电子设备和存储介质 | |
CN112329427B (zh) | 短信样本的获取方法和装置 | |
US20220383626A1 (en) | Image processing method, model training method, relevant devices and electronic device | |
US20220222941A1 (en) | Method for recognizing action, electronic device and storage medium | |
US20230004774A1 (en) | Method and apparatus for generating node representation, electronic device and readable storage medium |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |