CN115238888A - 图像分类模型的训练方法、使用方法、装置、设备及介质 - Google Patents
图像分类模型的训练方法、使用方法、装置、设备及介质 Download PDFInfo
- Publication number
- CN115238888A CN115238888A CN202210885176.4A CN202210885176A CN115238888A CN 115238888 A CN115238888 A CN 115238888A CN 202210885176 A CN202210885176 A CN 202210885176A CN 115238888 A CN115238888 A CN 115238888A
- Authority
- CN
- China
- Prior art keywords
- feature
- feature representation
- image
- sample
- low
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/764—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/77—Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
- G06V10/774—Generating sets of training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/82—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Evolutionary Computation (AREA)
- Physics & Mathematics (AREA)
- Computing Systems (AREA)
- Artificial Intelligence (AREA)
- General Physics & Mathematics (AREA)
- General Health & Medical Sciences (AREA)
- Health & Medical Sciences (AREA)
- Software Systems (AREA)
- Multimedia (AREA)
- Medical Informatics (AREA)
- Databases & Information Systems (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Life Sciences & Earth Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Data Mining & Analysis (AREA)
- Molecular Biology (AREA)
- General Engineering & Computer Science (AREA)
- Mathematical Physics (AREA)
- Image Analysis (AREA)
Abstract
本申请公开了一种图像分类模型的训练方法、使用方法、装置、设备及介质,属于人工智能领域。图像分类模型包括特征提取网络和多示例学习模型,方法包括:获取样本图像集,样本图像集中的样本图像包括至少两个示例;通过基于对比学习的自监督学习,采用样本图像集中的样本图像对特征提取网络进行训练,得到训练后的特征提取网络;通过基于互注意力机制的多示例学习,采用样本图像集中的样本图像对多示例学习模型进行训练,得到训练后的多示例学习模型。上述方案可以降低图像分类模型的计算复杂度,本申请实施例可应用于云技术、人工智能、智慧交通、辅助驾驶等各种场景。
Description
技术领域
本申请实施例涉及人工智能技术领域,特别涉及一种图像分类模型的训练方法、使用方法、装置、计算机设备、计算机可读存储介质和计算机程序产品。
背景技术
在机器学习中,多示例学习MIL(Multiple Instance Learning)是由监督型算法演变出的一种方法。
标准的多示例学习假设一个包有若干示例。多示例学习中只有包具有标签,示例不具有标签。如果包至少含有一个正示例,则该包被标记为正样本。如果包的所有示例都是负示例,则该包被标记为负样本。以包是图像,示例是图像中的动物,正示例是猫为例,由于一张图像中可以有很多个动物,那么不论一张图像有多少其它动物,只要一张图像中具有一个或多个猫,该图像都被标记为正样本;如果图像里有其它动物但一个猫都没有,该图像被标记为负样本。相关技术中通过挖掘同一图像中的所有示例的特征信息,再基于自注意力机制对同一图像中的不同示例之间的特征信息进行处理,以此提高多示例学习模型的特征提取能力。
然而,相关技术中的自注意力机制的计算复杂度比较高,硬件资源和时间消耗都很大,训练困难,如何降低多示例学习方法的计算复杂度是亟待解决的问题。
发明内容
本申请提供了一种图像分类模型的训练方法、使用方法、装置、计算机设备、计算机可读存储介质和计算机程序产品,可以提高图像分类模型的预测精度。技术方案如下:
一个方面,提供了一种图像分类模型的训练方法,图像分类模型包括特征提取网络和多示例学习模型,该方法包括:
获取样本图像集,样本图像集中的样本图像包括至少两个示例;
通过基于对比学习的自监督学习,采用样本图像集中的样本图像对特征提取网络进行训练,得到训练后的特征提取网络;
通过基于互注意力机制的多示例学习,采用样本图像集中的样本图像对多示例学习模型进行训练,得到训练后的多示例学习模型;
其中,互注意力机制是计算至少两个示例的特征表示与低秩隐变量之间的注意力的机制,低秩隐变量的数量小于至少两个示例的数量,低秩隐变量是基于样本图像的秩设置的。
另一个方面,提供了一种图像分类模型的使用方法,该方法包括:
获取待分类的输入图像,输入图像包括至少两个示例;
将输入图像输入图像分类模型中的特征提取网络进行特征提取,得到第三特征表示序列,第三特征表示序列包括至少两个示例的特征表示;
将第三特征表示序列输入图像分类模型中的多示例学习模型进行基于互注意力机制的分类预测,得到输入图像的分类结果;
其中,互注意力机制是对计算至少两个示例的特征表示与低秩隐变量之间的注意力的机制,低秩隐变量的数量小于至少两个示例的数量,低秩隐变量是基于样本图像的秩设置的。
另一个方面,提供了一种图像分类模型的训练装置,图像分类模型包括特征提取网络和多示例学习模型,该装置包括:
获取模块,用于获取样本图像集,所述样本图像集中的样本图像包括至少两个示例;
自监督学习模块,用于通过基于对比学习的自监督学习,采用所述样本图像集中的样本图像对所述特征提取网络进行训练,得到训练后的特征提取网络;
多示例学习模块,用于通过基于互注意力机制的多示例学习,采用所述样本图像集中的样本图像对所述多示例学习模型进行训练,得到训练后的多示例学习模型;
其中,所述互注意力机制是计算所述至少两个示例的特征表示与低秩隐变量之间的注意力的机制,所述低秩隐变量的数量小于所述至少两个示例的数量,所述低秩隐变量是基于所述样本图像的秩设置的。
另一个方面,提供了一种图像分类模型的使用装置,该装置包括:
获取模块,用于获取待分类的输入图像,输入图像包括至少两个示例;
特征提取模块,用于将输入图像输入图像分类模型中的特征提取网络进行特征提取,得到第三特征表示序列,第三特征表示序列包括至少两个示例的特征表示;
预测模块,用于将第三特征表示序列输入图像分类模型中的多示例学习模型进行基于互注意力机制的分类预测,得到输入图像的分类结果;
其中,互注意力机制是对计算至少两个示例的特征表示与低秩隐变量之间的注意力的机制,低秩隐变量的数量小于至少两个示例的数量,低秩隐变量是基于样本图像的秩设置的。
另一个方面,提供了一种计算机设备,计算机设备包括:处理器和存储器,存储器存储有计算机程序,计算机程序由处理器加载并执行以实现如上图像分类模型的训练方法和/或使用方法。
另一个方面,提供了一种计算机可读存储介质,存储介质存储有计算机程序,计算机程序由处理器加载并执行以实现如上图像分类模型的训练方法和/或使用方法。
另一个方面,提供了一种计算机程序产品,计算机程序产品包括计算机指令,该计算机指令存储在计算机可读存储介质中。计算机设备的处理器从计算机可读存储介质读取该计算机指令,处理器执行该计算机指令,使得该计算机设备执行上述方面提供的图像分类模型的训练方法和/或使用方法。
本申请实施例提供的技术方案带来的有益效果至少包括:
通过预先按照样本图像的秩设置低秩隐变量,通过同一个样本图像中的至少两个示例与低秩隐变量之间的互注意力机制,代替相关技术中的自注意力机制,避免直接计算自注意力,在减少计算量的同时也保证模型的学习能力和预测精度。比如,同一个样本图像中的多个示例的示例数量为n,则自注意力机制的计算复杂度为O(n2),而本申请中的低秩隐变量的数量为r,则互注意力机制的计算复杂度为O(rn)。在样本图像为低秩样本图像的情况下,r远小于n,因此能够显著降低运算复杂度。
附图说明
为了更清楚地说明本申请实施例中的技术方案,下面将对实施例描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本申请的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1示出了一个示例性实施例提供的计算机系统的结构框图;
图2示出了一个示例性实施例提供的图像分类模型的训练方法的示意图;
图3示出了一个示例性实施例提供的图像分类模型的训练方法的示意图;
图4示出了一个示例性实施例提供的图像分类模型的训练方法的流程图;
图5示出了一个示例性实施例提供的图像分类模型的训练方法的流程图;
图6示出了一个示例性实施例提供的图像分类模型的训练方法的流程图;
图7示出了一个示例性实施例提供的互注意力网络的示意图;
图8示出了一个示例性实施例提供的门控注意力模块的示意图;
图9示出了一个示例性实施例提供的图像分类模型的训练方法的流程图;
图10示出了一个示例性实施例提供的特征向量的低秩示意图;
图11示出了一个示例性实施例提供的图像分类模型的训练方法的流程图;
图12示出了一个示例性实施例提供的ImageNet预训练特征和自监督学习的秩与概率密度的示意图;
图13示出了一个示例性实施例提供的图像分类模型的使用方法的流程图;
图14示出了一个示例性实施例提供的图像分类模型的训练方法的场景图;
图15示出了一个示例性实施例提供的图像分类模型的训练方法的场景图;
图16示出了一个示例性实施例提供的图像分类模型的训练装置的结构框图;
图17示出了一个示例性实施例提供的图像分类模型的使用装置的结构框图;
图18示出了一个示例性实施例提供的计算机设备的结构框图。
具体实施方式
为使本申请的目的、技术方案和优点更加清楚,下面将结合附图对本申请实施方式作进一步地详细描述。
这里将详细地对示例性实施例进行说明,其示例表示在附图中。下面的描述涉及附图时,除非另有表示,不同附图中的相同数字表示相同或相似的要素。以下示例性实施例中所描述的实施方式并不代表与本申请相一致的所有实施方式。相反,它们仅是与如所附权利要求书中所详述的、本申请的一些方面相一致的装置和方法的例子。
在本公开使用的术语是仅仅出于描述特定实施例的目的,而非旨在限制本公开。在本公开和所附权利要求书中所使用的单数形式的“一种”、“所述”和“该”也旨在包括多数形式,除非上下文清楚地表示其他含义。还应当理解,本文中使用的术语“和/或”是指并包含一个或多个相关联的列出项目的任何或所有可能组合。
需要说明的是,本申请所涉及的用户信息(包括但不限于用户设备信息、用户个人信息等)和数据(包括但不限于用于分析的数据、存储的数据、展示的数据等),均为经用户授权或者经过各方充分授权的信息和数据,且相关数据的收集、使用和处理需要遵守相关国家和地区的相关法律法规和标准。例如,本申请中涉及到的输入图像是经过用户授权或经过各方充分授权的情况下获取的。
应当理解,尽管在本公开可能采用术语第一、第二等来描述各种信息,但这些信息不应限于这些术语。这些术语仅用来将同一类型的信息彼此区分开。例如,在不脱离本公开范围的情况下,第一参数也可以被称为第二参数,类似地,第二参数也可以被称为第一参数。取决于语境,如在此所使用的词语“如果”可以被解释成为“在……时”或“当……时”或“响应于确定”。
图1示出了本申请一个示例性实施例提供的计算机系统的结构框图。该计算机系统可以实现成为图像分类模型的训练方法和/或使用方法的系统架构。该计算机系统可以包括:终端100和服务器200。
终端100包括但不限于手机、电脑、智能语音交互设备、智能家电、车载终端、飞行器等。终端100中可以安装运行目标应用程序的客户端,该目标应用程序可以是图像分类模型的训练和/或使用应用程序,也可以是提供有图像分类模型的训练和/或使用功能的其他应用程序,本申请对此不作限定。另外,本申请对该目标应用程序的形式不作限定,包括但不限于安装在终端100中的App(Application,应用程序)、小程序等,还可以是网页形式。
服务器200可以是独立的物理服务器,也可以是多个物理服务器构成的服务器集群或者分布式系统,还可以是提供云计算服务的云服务器、云数据库、云计算、云函数、云存储、网络服务、云通信、中间件服务、域名服务、安全服务、内容分发网络(Content DeliveryNetwork,CDN)、以及大数据和人工智能平台等基础云计算服务的云服务器。服务器200可以是上述目标应用程序的后台服务器,用于为目标应用程序的客户端提供后台服务。
本申请实施例提供的图像分类模型的训练方法和/或使用方法,各步骤的执行主体可以是计算机设备,计算机设备是指具备数据计算、处理和存储能力的电子设备。以图1所示的方案实施环境为例,可以由终端100执行图像分类模型的训练方法和/或使用方法,比如终端100中安装运行的目标应用程序的客户端执行该图像分类模型的训练方法和/或使用方法,也可以由服务器200执行该图像分类模型的训练方法和/或使用方法,或者由终端100和服务器200交互配合执行,本申请对此不作限定。
此外,本申请技术方案还可以和区块链技术相结合。例如,本申请所公开的图像分类模型的训练方法和/或使用方法,其中涉及的一些数据可以保存于区块链上。终端100和服务器200之间可以通过网络进行通信,如有线或无线网络。本申请实施例可应用于各种场景,包括但不限于云技术、人工智能、智慧交通、辅助驾驶等。
本申请实施例涉及了人工智能技术领域与计算机视觉技术。
人工智能(Artificial Intelligence,AI)是利用数字计算机或者数字计算机控制的机器模拟、延伸和扩展人的智能,感知环境、获取知识并使用知识获得最佳结果的理论、方法、技术及应用系统。换句话说,人工智能是计算机科学的一个综合技术,它企图了解智能的实质,并生产出一种新的能以人类智能相似的方式做出反应的智能机器。人工智能也就是研究各种智能机器的设计原理与实现方法,使机器具有感知、推理与决策的功能。
人工智能技术是一门综合学科,涉及领域广泛,既有硬件层面的技术也有软件层面的技术。人工智能基础技术一般包括如传感器、专用人工智能芯片、云计算、分布式存储、大数据处理技术、操作/交互系统、机电一体化等技术。人工智能软件技术主要包括计算机视觉技术、语音处理技术、自然语言处理技术以及机器学习/深度学习等几大方向。
计算机视觉技术(Computer Vision,CV)是一门研究如何使机器“看”的科学,更进一步的说,就是指用摄影机和电脑代替人眼对目标进行识别和测量等机器视觉,并进一步做图形处理,使电脑处理成为更适合人眼观察或传送给仪器检测的图像。作为一个科学学科,计算机视觉研究相关的理论和技术,试图建立能够从图像或者多维数据中获取信息的人工智能系统。计算机视觉技术通常包括图像处理、图像识别、图像语义理解、图像检索、光符识别(optical character recognition,OCR)、视频处理、视频语义理解、视频内容/行为识别、三维物体重建、3D技术、虚拟现实、增强现实、同步定位与地图构建等技术,还包括常见的人脸识别、指纹识别等生物特征识别技术。
相关技术中通过挖掘同一图像中的所有示例的特征信息,再基于自注意力机制对同一图像中的不同示例之间的特征信息进行处理,以此提高多示例学习模型的特征提取能力。
然而,相关技术中的自注意力机制的计算复杂度比较高,硬件资源和时间消耗都很大,训练困难,如何降低多示例学习方法的计算复杂度是亟待解决的问题。
本申请中,通过预先按照样本图像的秩设置低秩隐变量,通过同一个样本图像中的至少两个示例与低秩隐变量之间的互注意力机制,代替相关技术中的自注意力机制,避免直接计算自注意力,在减少计算量的同时也保证模型的学习能力和预测精度。比如,同一个样本图像中的多个示例的示例数量为n,则自注意力机制的计算复杂度为O(n2),而本申请中的低秩隐变量的数量为r,则互注意力机制的计算复杂度为O(rn)。在样本图像为低秩样本图像的情况下,由于r远小于n,因此能够显著降低运算复杂度。
接下来,对本申请一个示例性实施例提供的图像分类模型进行介绍。图2示出了本申请一个示例性实施例提供的图像分类模型的训练原理图。图像分类模型包括:特征提取网络220和多示例学习模型240。
多示例学习模型240包括互注意力网络242、非局部池化层244和线性分类器246。互注意力网络242包括级联的两个门控注意力组件242-a和242-b。
特征提取网络220的输出端与门控注意力242-a的输入端连接,门控注意力组件242-b的输出端与非局部池化层244的输入端连接,非局部池化层244的输出端与线性分类器246的输入端连接。
图像分类模型对于样本图像200的处理过程包括:预处理阶段、自监督学习阶段、多示例学习阶段。
预处理阶段:
样本图像是指输入至图像分类模型的图像。训练样本集包括多个样本图像。可选地,该样本图像是低秩样本图像,比如细胞切片图像、森林图像、星空图像等图像中存在大量相似或相同的元素的图像。将每个样本图像进行预处理,得到每个样本图像的至少两个示例。比如,将每个样本图像划分为n个大小相同的图像切片,将每个图像切片视为是一个示例,则同一个样本图像具有n个示例;又比如,将每个样本图像中的脸部区域识别出来,每个脸部区域视为是一个示例,若样本图像中有2个动物的脸部区域,则具有2个示例;若样本图像中有4个动物的脸部区域,则具有4个示例。
自监督学习阶段:
通过特征提取网络220对至少两个示例进行特征提取,得到第一特征表示序列。第一特征表示序列包括多个第一特征表示,每个第一特征表示是一个示例的特征表示。若一个样本图像包括n个示例,则该样本图像的第一特征表示序列中包括n个第一特征表示。
相关技术中的传统自监督学习将与锚点第一特征表示完全相同的第一特征表示作为锚点第一特征表示的正样本特征表示;将与锚点第一特征表示不同的的第一特征表示作为锚点第一特征表示的负样本特征表示。其中,锚点第一特征表示是任意一个示例的第一特征表示。也即,传统自监督学习认为一个示例仅有一个正样本。
本申请实施例改进了正负样本的采样方法,针对传统自监督学习中仅考虑单个正样本的局限性,对第一特征表示序列中的第一特征表示进行聚类,将与锚点第一特征表示属于同一聚类结果的第一特征表示作为锚点第一特征表示的正样本特征表示;将与锚点第一特征表示属于不同聚类结果的第一特征表示作为锚点第一特征表示的负样本特征表示,从而引入了多个正样本,符合实际应用情况。之后基于正样本特征表示和负样本特征表示之间的自监督学习损失,对特征提取网络220进行训练。也即,本申请实施例中的自监督学习认为一个示例有多个正样本。
多示例学习阶段:
传统多示例学习引入自注意力机制对同一个样本图像的不同示例之间的相关信息进行建模。假设同一个样本图像具有n个示例,则自注意力机制具有高计算复杂度O(n2),硬件资源和时间消耗都很大,训练困难。
本申请实施例重点改进传统多示例学习中对输入示例的信息交互关注不足而导致的性能下降,具体来说引入互注意力机制。通过学习预设的r个低秩隐变量和n个示例之间的互注意力,实现对同一个样本图像中不同示例之间的相关信息建模,计算复杂度为O(rn),避免直接进行高复杂度的计算,在减少计算量的同时也保证图像分类模型的学习能力和预测精度。
示例性的,通过门控注意力组件242-a和242-b对第一特征表示序列中的第一特征表示与预设的r个低秩隐变量之间进行互注意力计算,输出第二特征表示序列。通过非局部池化层244对第二特征表示序列进行池化,得到样本图像的全局特征表示。通过线性分类器246对全局特征表示进行预测,得到样本图像的预测分类结果;基于样本图像200的预测分类结果和标签分类结果之间的误差损失,对多示例学习模型240的模型参数进行训练。
其中,自监督学习阶段和多示例学习阶段可以分开进行。在自监督学习阶段学习完毕后,再执行多示例学习阶段。
接下来,将通过不同实施例对上述图像分类模型的训练方法进行拆分介绍。
图3示出了本申请一个示例性实施例提供的图像分类模型的框图。该图像分类模型是用于对低秩多示例图像进行图像分类的模型。秩(rank)可以理解为图像所包含的信息的丰富程度。该图像分类模型包括:特征提取网络220和低秩多示例学习(Multi-InstanceLearning,MIL)模型240。
特征提取网络220用于提取低秩多示例图像中的各个示例的特征表示。特征提取网络220的输入为低秩多示例图像,特征提取网络220的输出为特征表示序列。该特征表示序列包括多个特征表示。每个特征表示序列与一个低秩多示例图像对应,每个特征表示与一个示例对应。
其中,低秩MIL模型240包括:互注意力网络242、池化层244、分类器246。特征提取网络220的输出端与互注意力网络242的输入端连接,互注意力网络的输出端与池化层的输入端连接,池化层的输出端与分类器的输入端连接。
图4示出了本申请一个示例性实施例提供的图像分类模型的训练方法的流程图。该方法可以由计算机设备执行,该计算机设备上运行有图3所示的图像分类模型。该方法包括:
步骤320,获取样本图像集,样本图像集中的样本图像包括至少两个示例。
样本图像集是图像分类模型进行学习任务用到的图像集。样本图像集包括大量的低秩样本图像,下文将低秩样本图像简称为样本图像。
在多示例学习中,一个样本图像认为是一个包,一个样本图像中包括多个/多组可分类对象,每个/每组可分类对象认为是一个示例。同一个样本图像包括至少两个示例。在样本图像集中,每个样本图像具有图像级别的标签。
对于一个样本图像,若该样本图像中存在至少一个示例是正样本示例,不论该样本图像是否还具有其它负样本示例,则该样本图像的图像级别的标签为正样本;若该样本图像中的示例全部是负样本示例,则该样本图像的图像级别的标签为负样本。
可选地,将样本图像集中的样本图像按照切割分辨率进行分割,得到至少两个图像切片;将属于同一个样本图像的至少两个图像切片确定为同一个样本图像的至少两个示例。
步骤340,通过基于对比学习的自监督学习,采用样本图像集中的样本图像对特征提取网络进行训练,得到训练后的特征提取网络。
本实施例中的训练阶段分为2个阶段:自监督学习阶段和多示例学习阶段。先采用自监督学习阶段对特征提取网络进行训练,使得特征提取网络具有一定的特征提取能力,再对低秩MIL模型进行训练,训练低秩MIL模型的分类能力。
在对特征提取网络进行训练时,使用基于对比学习的自监督学习对特征提取网络进行训练。示意性的,在同一个图像的多个示例中,对于锚点示例(多个示例中的任意一个),在该多个示例中确定该锚点示例的正样本示例和负样本示例,基于特征提取网络对正样本示例和负样本示例分别提取的特征表示之间的误差,对特征提取网络进行训练。
其中,特征提取网络的输入是样本图像的多个示例,特征提取网络的输出是第一特征表示序列,该第一特征表示序列包括多个第一特征表示,多个示例和多个第一特征表示一一对应。
步骤360,通过基于互注意力机制的多示例学习,采用样本图像集中的样本图像对多示例学习模型进行训练,得到训练后的多示例学习模型。
其中,互注意力机制是计算至少两个示例的特征表示与低秩隐变量之间的注意力的机制,低秩隐变量的数量小于至少两个示例的数量,低秩隐变量是基于样本图像的秩设置的。
由于样本图像的秩比较低,预先通过对样本图像的秩进行分析,可以确定样本图像的秩为r。按照样本图像的秩设置r个低秩隐变量,通过同一个样本图像中的至少两个示例与r个低秩隐变量之间的互注意力机制,代替相关技术中的自注意力机制,避免直接计算自注意力,在减少计算量的同时也保证模型的学习能力和预测精度。
比如,同一个样本图像中的多个示例的示例数量为n,则自注意力机制的计算复杂度为O(n2),而本申请中的低秩隐变量的数量为r,则互注意力机制的计算复杂度为O(rn)。在样本图像为低秩样本图像的情况下,能够显著降低运算复杂度。
在基于图3的可选实施例中,如图5所示,步骤360可以包括步骤362至366。
步骤362,通过互注意力网络对第一特征表示序列中的第一特征表示与低秩隐变量之间进行互注意力计算,输出第二特征表示序列;第一特征表示序列是特征提取网络对至少两个示例提取的。
对于一个样本图像来讲,设第一特征表示序列包括该样本图像中的n个示例的第一特征表示。r维的低秩隐变量是预先基于样本图像的秩来设置的,r为样本图像的秩数,该秩数可以是估计值。
假设输入的样本图像为X,对样本图像X进行分割后得到图像切片的集合{x1,x2…xn},其中,每个图像切片xi称为一个示例(instance),n是示例的个数,为大于1的正整数。通过特征提取函数Ff将每个示例进行特征提取后,得到对应空间的特征表示序列z={z1,z2…zn}:
{z1,z2…zn}=Ff({x1,x2…xn})
其中,zi∈R1×d,默认d=1024。
定义互注意力:
其中WQ,WK,WV是三个可学习的矩阵变换;Zl是第一特征表示序列;L是低秩隐变量,是一个可学习的参数;CAtt是互注意力计算函数;softmax是归一化指数函数;Q是一组query集合组成的矩阵,K是一组key集合组成的矩阵,KT是K的转置,V是一组value集合组成的矩阵。由于低秩隐变量L的维度是r×d,r为预设的超参数,d是第一特征表示的特征维度,比如1024维,则第一特征表示序列Zl的维度是n×d,其中r<<n。所以互注意力的计算复杂度为O(rn),远小于自注意力的复杂度O(n2)。
在一些实施例中,上述互注意力网络是将门控注意力机制和互注意力机制结合起来构建的门控注意力模块GAB(GatedAttention Block)。
GAB是将门控注意力机制和互注意力机制结合起来构建的,具体表示为:
其中,GSB是建立在互注意力机制上的门控输出函数;⊙表示对位点乘;L是低秩隐变量;Zl是第一特征表示序列;φU表示Sigmoid Linear Units激活函数,引入非线性因素,让数据更好的被分类;Wo,WU表示线性变换矩阵;CAtt是互注意力计算函数。
在一些实施例中,上述互注意力网络包括第一注意力组件和第二注意力组件。如图6所示,上述步骤362可以包括:
步骤362-a,通过第一注意力组件使用r维的低秩隐变量作为查询,与输入的n个特征表示进行第一互注意力计算,得到r个低维向量H,r是小于n的正整数。
步骤362-b,通过第二注意力组件使用n个特征表示作为查询,与r个低维向量H进行第二互注意力计算,得到第二特征表示序列中的n个第二特征表示。其中,n个特征表示的初始值为第一特征表示序列中的第一特征表示。
以每个注意力组件为GAB为例,图7示出了由2个GAB级联而成的互注意力网络,图8示出了一个GAB的示例性结构。表示为:
公式中,ILRA先用低秩隐变量L作为query,去与关注高维度的第一特征表示序列Zl进行互注意力计算,产生r个低维的向量H;然后第一特征表示序列Zl,再与新生成的r个低维向量H进行互注意力计算,输出中间特征表示序列Zl+1。
经过t轮迭代之后,输出得到第二特征表示序列Zt:
其中,t的数值以实际设计为准,一般默认6-8;Z0为第一个第二特征表示序列;ILRA为输出函数。
步骤364,通过池化层对第二特征表示序列进行池化,得到样本图像的全局特征表示。
步骤366,通过分类器对全局特征表示进行预测,得到样本图像的预测分类结果;基于样本图像的预测分类结果和标签分类结果之间的误差损失,对多示例学习模型的模型参数进行训练。
根据第二特征表示序列{z1,z2…zn},采用多示例学习模型,进行图像分类,表示为:
Logits=σ(pool(φ(z1),φ(z2)…φ(zn)))
其中logits是输出分类结果;σ是分类器;pool是池化函数,用于去除冗余信息,得到一个低分辨率的特征图;φ是每个示例的得分函数,为矩阵的转置,n是示例的个数,为大于1的正整数。通常采用注意力得分函数,表示为:
其中W,U,V都是可学习的参数矩阵;⊙表示对位点乘;tanh为双曲正切函数;sigm为激活函数;n是示例的个数,为大于1的正整数。
在一些实施例中,上述分类器可以采用多层感知机(Multi-Layer Perceptron,MLP)来实现。上述分类预测过程可表示为:
其中,第二特征表示序列Zt={zt,1,zt,2,…zt,n};MLP是线性分类器;zb是维度等于d的全局特征表示,是可学习的参数;logits是分类器输出的预测结果,n是示例的个数,为大于1的正整数。
综上所述,本实施例通过将门控注意力机制和互注意力机制结合起来,可以对特征表示序列中的特征进行筛选,来提高分类器的分类精度。
本实施例通过使用两个GAB,第一个GAB可以将特征表示序列的n个维度降低到低秩隐变量的r个维度,第二个GAB将r个维度再恢复到n个维度,避免特征表示序列维度越来越小的情况。
本实施例通过迭代t轮可以避免一轮学习训练不充分的问题,从而保证提取出来的全局特征表示具有较好的特征表示能力。
在一些实施例中,上述自监督学习可以采用基于多个正样本的对比学习来实现。图9示出了本申请一个示例性实施例提供的图像分类模型的训练方法的流程图。该方法可以由计算机设备执行。上述步骤340可实现为:
步骤342,通过特征提取网络对至少两个示例进行特征提取,得到第一特征表示序列;第一特征表示序列中的第一特征表示与至少两个示例一一对应;
对于任一样本图像,通过特征提取网络对该样本图像的至少两个示例进行特征提取,得到该样本图像的第一特征表示序列;第一特征表示序列中的第一特征表示与该样本图像的至少两个示例一一对应。
步骤344,对第一特征表示序列中的第一特征表示进行聚类,将与锚点第一特征表示属于同一聚类结果的第一特征表示作为锚点第一特征表示的正样本特征表示;将与锚点第一特征表示属于不同聚类结果的第一特征表示作为锚点第一特征表示的负样本特征表示;
其中,锚点第一特征表示是该样本图像的至少两个示例中的任一示例。对第一特征表示序列中的第一特征表示进行聚类,聚类形式不限。将与锚点第一特征表示属于同一聚类结果的多个第一特征表示作为锚点第一特征表示的正样本特征表示;将与锚点第一特征表示属于不同聚类结果的多个第一特征表示作为锚点第一特征表示的负样本特征表示。
步骤346,基于正样本特征表示和负样本特征表示之间的自监督学习损失,对特征提取网络进行训练。
对比学习是一种常见的自监督学习范式。给定包含n个图像的集合{xi}i∈I,其中I={1,2,…n},n为正整数。每个图像均随机经过两次数据增强(旋转,剪裁,缩放,颜色变换等),得到2n个图像,并分别用两个网络进行特征提取,得到第一特征表示序列,表示为{fi,gi}i∈I。对比学习的优化目标是最小化损失函数:
其中,如果i=j,则ti,j=1,否则为0。sθ(i,j)是fi和gj的余弦相似度函数。对比学习损失函数认为,gi是fi的正样本,仅有一个,而gj,j∈I{i}是负样本,共有n-1个。
对比学习损失函数认为每个图像对应的正样本只有一个,即经过不同图像增强的同一个图像。但实际上,以病理数字化图像WSI(Whole Slide Image)为例的图像存在大量相似的样本,除了图像自身可以作为正样本,应当挖掘此类相似图像补充正样本的数量,使得图像特征的学习过程更加合理。给定图像第一特征表示序列,每个图像特征可以用特征字典D的基向量进行线性表示。示例性的如图10所示,输入包含10个样本示例的第一特征表示序列,可以根据其类别信息,分解成6个基向量的线性组合,对应的系数矩阵会表现出如图10所示的低秩特点。
本申请期望找到这些特征向量的最小的秩:
其中,F=[f1,…,fn];G=[g1,…,gn]。低秩分解表示为:
其中,D是字典;B是块对角的系数矩阵;E是误差矩阵。由于原始数据总有误差和噪声,因此不存在完美的低秩矩阵,矩阵E就是对原始数据的误差进行估计消除。系数矩阵B的秩为r,r<<n。如果能得到图像的类别标签,一般可以认为r等于图像的类别数,由于在对比学习训练的过程中,无法对上述低秩分解表示的公式直接进行计算求解,因此本申请采用如图11所示的方法寻找每个示例的多个正样本。
上述步骤344可以包括如下子步骤,如图11所示:
步骤344-a,将第一特征表示序列中的第一特征表示,按照与锚点第一特征表示的相似度由高到低的顺序进行排序;锚点第一特征表示是第一特征表示序列中的一个。
步骤344-b,将排序后的第一特征表示序列矩阵分解至r个子空间,r为预设的秩数。
步骤344-c,将分解至r个子空间中的第一个子空间中的第一特征表示确定为锚点第一特征表示的正样本特征表示;将分解至r个子空间中的最后一个子空间中的第一特征表示确定为锚点第一特征表示的负样本特征表示。
假设给定锚点第一特征表示fa,锚点第一特征表示为任意一个示例的特征表示,其余第一特征表示是其他示例的特征表示。根据其余第一特征表示和锚点第一特征表示fa的相似度排序,得到集合:
C(a)={i,j∣sθ(a,i)≥sθ(a,j),ifi<j,andi,j∈[1,n]}
即其余第一特征表示和锚点第一特征表示fa的相似度越高,排序越靠前。根据矩阵的低秩分解结果,将上述集合分解为C1(a),C2(a)…Cr(a)共r个子空间,分别与B1,B2…Br对应。由于集合C(a)已经进行了相似度排序,那么C1(a)和Cr(a)就是相距最远的正样本和负样本,应该满足:
其中,ξ=1是相似度函数间隔,p为正样本,n为负样本。加入函数间隔是为了将正样本和负样本分隔得更开。
将上述公式加入最小化损失函数,得到低秩约束损失函数:
该公式将同一样本图像中的同类示例纳入考量,增加同一个示例的正样本数量。当所有的正样本和负样本被正确分开,并达到一定的函数间隔时,损失函数可以被最小化。将低秩约束函数和传统的对比损失函数结合,可以得到最终的自监督学习损失函数:
示例性的,以WSI图像为例,基于数据低秩的特殊属性:随着数据维度的升高,高维数据之间往往存在较多的相关性和冗余度。由于数据本身信息量的增长比数据维度增长慢得多,数据的维度越高,也就使得数据变得越冗余。WSI图像属于超高维度的图像数据,尺寸通常可以达到10000×10000以上,但WSI图像中存在大量相似的组织,细胞等元素,符合低秩数据先验的假设。分析270张来自乳腺癌的WSI图像,并对其进行矩阵低秩分解,得到对应矩阵的秩,绘制出直方图。从图12的(1)所示,采用ImageNet预训练的特征维度是1024,但是矩阵分解后,得到270张WSI图像的秩平均为349,低于1024的满秩情况。图12的(2)是本申请提出的自监督学习模型,WSI图像的秩平均为181,远小于ImageNet特征的秩。从该结果可以看出,低秩是WSI图像存在的典型现象。并且从后续的实验分析得到,自监督学习的秩更低,图像分类效果也更好。
图13示出了本申请一个示例性实施例提供的图像分类模型的使用方法的流程图。本实施例以该方法由计算机设备执行来举例说明。该图像分类模型是根据上述方法训练得到的,所述方法包括:
步骤402,获取待分类的输入图像,输入图像包括至少两个示例。
将输入图像按照切割分辨率进行分割,得到至少两个图像切片;将属于该输入图像的至少两个图像切片确定为该输入图像的至少两个示例。
结合参考图14,假设上述图像分类模型用于病理图像的癌症分类预测。将WSI图像X按照切割分辨率(比如,220像素*220像素)进行分割,得到多个图像切片的集合{x1,x2…xn},n为正整数,每个图像切片是一个示例。
步骤404,将输入图像输入进图像分类模型中的特征提取网络进行特征提取,得到第三特征表示序列,第三特征表示序列包括至少两个示例的特征表示。
将每个示例进行特征提取后得到输入图像的第三特征表示序列z={z1,z2…zn}:
{z1,z2…zn}=Ff({x1,x2…xn})
其中,zu∈R1×d,默认d=1024。
步骤406,将第三特征表示序列输入图像分类模型中的多示例学习模型进行基于互注意力机制的分类预测,得到输入图像的分类结果。
其中,互注意力机制是对计算至少两个示例的特征表示与低秩隐变量之间的注意力的机制,低秩隐变量的数量小于至少两个示例的数量,低秩隐变量是基于样本图像的秩设置的。
在一些实施例中,通过互注意力网络对第三特征表示序列中的第三特征表示与低秩隐变量之间进行互注意力计算,输出第四特征表示序列;第三特征表示序列是特征提取网络对至少两个示例提取得到的;
通过池化层对第四特征表示序列进行池化,得到样本图像的全局特征表示;
通过分类器对全局特征表示进行预测,得到样本图像的预测分类结果;基于样本图像的预测分类结果和标签分类结果之间的误差损失,对多示例学习模型的模型参数进行训练。
在一些实施例中,互注意力网络包括第一注意力组件和第二注意力组件;
通过第一注意力组件使用r维的低秩隐变量作为查询,与输入的n个特征表示进行第一互注意力计算,得到r个低维向量H,r是小于n的正整数;
通过第二注意力组件使用n个特征表示作为查询,与r个低维向量H进行第二互注意力计算,得到第四特征表示序列中的n个第四特征表示;
其中,n个特征表示的初始值为第三特征表示序列中的第三特征表示。
在一些实施例中,通过第二注意力组件使用n个特征表示作为查询,与r个低维向量H进行第二互注意力计算,得到n个中间特征表示;
在不满足迭代结束条件时,将n个中间特征表示作为输入的n个特征表示,再次从通过第一注意力组件使用r维的低秩隐变量作为查询与输入的n个特征表示进行第一互注意力计算,得到r个低维向量H的步骤开始执行;
在满足迭代结束条件时,将n个中间特征表示输出为第四特征表示序列。
在一些实施例中,迭代结束条件为:迭代次数达到预设迭代次数t。
图14示出了本申请一个示例性实施例提供的图像分类模型的训练方法的场景图。在该应用场景中,假设该应用场景是以病理数字化切片进行癌症分析为例。本申请通过多示例学习模型分析病理数字化切片,判断病理图片中是否存在癌变的组织,作为后续治疗的重要依据。本实施例辅助医生诊断病理图像切割得到的病理数字化切片,提供患者是否有癌症的重要参考信息,有助于减小病理医生阅片压力,提高医生判读的一致性,能有效提高分类的精度,提高数字病理诊断的智能化和自动化水平,减少误判、漏判等风险。
图15示出了本申请一个示例性实施例提供的图像分类模型的训练方法的场景图,在该应用场景中,假设该应用场景是以图像示例进行生态环境分析为例。本申请通过多示例学习模型分析图像示例,判断输入图片中是否存在目标动物如鹿、野猪、兔子等,作为生态环境分析的重要依据。本实施例辅助生物学家、护林人等工作人员统计输入图像切割得到的图像示例,提供当地目标动物存在数量的重要参考信息,有助于降低工作人员统计数据难度,能有效提高分类的精度,提高生态环境分析的智能化和自动化水平,减少统计数据人力和物力的消耗。
图16示出了本申请一个示例性实施例提供的图像分类模型的训练装置500的结构框图,图像分类模型包括特征提取网络和多示例学习模型,该装置包括:
获取模块510,用于获取样本图像集,样本图像集中的样本图像包括至少两个示例;
自监督学习模块520,用于通过基于对比学习的自监督学习,采用样本图像集中的样本图像对特征提取网络进行训练,得到训练后的特征提取网络;
多示例学习模块530,用于通过基于互注意力机制的多示例学习,采用样本图像集中的样本图像对多示例学习模型进行训练,得到训练后的多示例学习模型。
在本实施例的一个可选设计中,获取模块510包括:
分割子模块,用于将样本图像集中的样本图像按照切割分辨率进行分割,得到至少两个图像切片;
确定子模块,用于将属于同一个样本图像的至少两个图像切片确定为同一个样本图像的至少两个示例。
在本实施例的一个可选设计中,自监督学习模块520包括:
特征提取子模块,用于通过特征提取网络对至少两个示例进行特征提取,得到第一特征表示序列;第一特征表示序列中的第一特征表示与至少两个示例一一对应;
聚类子模块,用于对第一特征表示序列中的第一特征表示进行聚类,将与锚点第一特征表示属于同一聚类结果的第一特征表示作为锚点第一特征表示的正样本特征表示;将与锚点第一特征表示属于不同聚类结果的第一特征表示作为锚点第一特征表示的负样本特征表示;
训练子模块,用于基于正样本特征表示和负样本特征表示之间的自监督学习损失,对特征提取网络进行训练。
聚类子模块包括:
排序单元,用于将第一特征表示序列中的第一特征表示,按照与锚点第一特征表示的相似度由高到低的顺序进行排序;锚点第一特征表示是第一特征表示序列中的一个;
分解单元,用于将排序后的第一特征表示序列矩阵分解至r个子空间,r为预设的秩数;
确定单元,用于将分解至r个子空间中的第一个子空间中的第一特征表示确定为锚点第一特征表示的正样本特征表示;将分解至r个子空间中的最后一个子空间中的第一特征表示确定为锚点第一特征表示的负样本特征表示。
在本实施例的一个可选设计中,多示例学习模块530包括:
计算子模块,用于通过互注意力网络对第一特征表示序列中的第一特征表示与低秩隐变量之间进行互注意力计算,输出第二特征表示序列;第一特征表示序列是特征提取网络对至少两个示例提取得到的;
池化子模块,用于通过池化层对第二特征表示序列进行池化,得到样本图像的全局特征表示;
预测子模块,用于通过分类器对全局特征表示进行预测,得到样本图像的预测分类结果;基于样本图像的预测分类结果和标签分类结果之间的误差损失,对多示例学习模型的模型参数进行训练。
计算子模块包括:
第一计算单元,用于通过第一注意力组件使用r维的低秩隐变量作为查询,与输入的n个特征表示进行第一互注意力计算,得到r个低维向量H,r是小于n的正整数;
第二计算单元,用于通过第二注意力组件使用n个特征表示作为查询,与r个低维向量H进行第二互注意力计算,得到第二特征表示序列中的n个第二特征表示;
第二计算单元包括:
第一计算子单元,用于通过第二注意力组件使用n个特征表示作为查询,与r个低维向量H进行第二互注意力计算,得到n个中间特征表示;
第二计算子单元,用于在不满足迭代结束条件时,将n个中间特征表示作为输入的n个特征表示,再次从通过第一注意力组件使用r维的低秩隐变量作为查询与输入的n个特征表示进行第一互注意力计算,得到r个低维向量H的步骤开始执行;
输出子单元,用于在满足迭代结束条件时,将n个中间特征表示输出为第二特征表示序列。
图17示出了本申请一个示例性实施例提供的图像分类模型的使用装置600的框图。该装置包括:
获取模块610,用于获取待分类的输入图像,输入图像包括至少两个示例;
特征提取模块620,用于将输入图像输入图像分类模型中的特征提取网络进行特征提取,得到第三特征表示序列,第三特征表示序列包括至少两个示例的特征表示;
预测模块630,用于将第三特征表示序列输入图像分类模型中的多示例学习模型进行基于互注意力机制的分类预测,得到输入图像的分类结果。
在一些实施例中,预测模块630,用于通过互注意力网络对第三特征表示序列中的第三特征表示与低秩隐变量之间进行互注意力计算,输出第四特征表示序列;第三特征表示序列是特征提取网络对至少两个示例提取得到的;通过池化层对第四特征表示序列进行池化,得到样本图像的全局特征表示;通过分类器对全局特征表示进行预测,得到样本图像的预测分类结果;基于样本图像的预测分类结果和标签分类结果之间的误差损失,对多示例学习模型的模型参数进行训练。
在一些实施例中,互注意力网络包括第一注意力组件和第二注意力组件;
预测模块630,用于通过第一注意力组件使用r维的低秩隐变量作为查询,与输入的n个特征表示进行第一互注意力计算,得到r个低维向量H,r是小于n的正整数;
预测模块630,用于通过第二注意力组件使用n个特征表示作为查询,与r个低维向量H进行第二互注意力计算,得到第四特征表示序列中的n个第四特征表示;
其中,n个特征表示的初始值为第三特征表示序列中的第三特征表示。
在一些实施例中,预测模块630,用于通过第二注意力组件使用n个特征表示作为查询,与r个低维向量H进行第二互注意力计算,得到n个中间特征表示;在不满足迭代结束条件时,将n个中间特征表示作为输入的n个特征表示,再次从通过第一注意力组件使用r维的低秩隐变量作为查询与输入的n个特征表示进行第一互注意力计算,得到r个低维向量H的步骤开始执行;在满足迭代结束条件时,将n个中间特征表示输出为第四特征表示序列。
在一些实施例中,迭代结束条件为:迭代次数达到预设迭代次数t。
图18示出了本申请一个示例性实施例提供的计算机设备的结构框图。通常,计算机设备700包括有:处理器701和存储器702。
处理器701可以包括一个或多个处理核心,比如4核心处理器、8核心处理器等。处理器701可以采用数字信号处理(Digital Signal Processing,DSP)、现场可编程门阵列(Field-Programmable Gate Array,FPGA)、可编程逻辑阵列(Programmable Logic Array,PLA)中的至少一种硬件形式来实现。处理器701也可以包括主处理器和协处理器,主处理器是用于对在唤醒状态下的数据进行处理的处理器,也称中央处理器(Central ProcessingUnit,CPU);协处理器是用于对在待机状态下的数据进行处理的低功耗处理器。在一些实施例中,处理器701可以在集成有图像处理器(Graphics Processing Unit,GPU),GPU用于负责显示屏所需要显示的内容的渲染和绘制。一些实施例中,处理器701还可以包括人工智能(Artificial Intelligence,AI)处理器,该AI处理器用于处理有关机器学习的计算操作。
存储器702可以包括一个或多个计算机可读存储介质,该计算机可读存储介质可以是非暂态的。存储器702还可包括高速随机存取存储器,以及非易失性存储器,比如一个或多个磁盘存储设备、闪存存储设备。在一些实施例中,存储器702中的非暂态的计算机可读存储介质用于存储至少一个指令,该至少一个指令用于被处理器701所执行以实现本申请中方法实施例提供的图像分类模型的方法和/或使用方法。
在一些实施例中,服务器700还可选包括有:输入接口703和输出接口704。处理器701、存储器702和输入接口703、输出接口704之间可以通过总线或信号线相连。各个外围设备可以通过总线、信号线或电路板与输入接口703、输出接口704相连。输入接口703、输出接口704可被用于将输入/输出(Input/Output,I/O)相关的至少一个外围设备连接到处理器701和存储器702。在一些实施例中,处理器701、存储器702和输入接口703、输出接口704被集成在同一芯片或电路板上;在一些其他实施例中,处理器701、存储器702和输入接口703、输出接口704中的任意一个或两个可以在单独的芯片或电路板上实现,本申请实施例对此不加以限定。
本领域技术人员可以理解,上述示出的结构并不构成对服务器700的限定,服务器700可以包括比图示更多或更少的组件,或者组合某些组件,或者采用不同的组件布置。
在示例性实施例中,还提供了一种芯片,芯片包括可编程逻辑电路和/或程序指令,当芯片在计算机设备上运行时,用于实现上述方面图像分类模型的训练方法和/或使用方法。
在示例性实施例中,还提供了一种计算机程序产品,该计算机程序产品包括计算机指令,该计算机指令存储在计算机可读存储介质中。计算机设备的处理器从计算机可读存储介质读取该计算机指令,处理器从计算机可读存储介质读取并执行该计算机指令,以实现上述各方法实施例提供的图像分类模型的训练方法和/或使用方法。
在示例性实施例中,还提供了一种计算机可读存储介质,该计算机可读存储介质中存储有计算机程序,计算机程序由处理器加载并执行以实现上述各方法实施例提供的图像分类模型的训练方法和/或使用方法。
本领域普通技术人员可以理解实现上述实施例的全部或部分步骤可以通过硬件来完成,也可以通过程序来指令相关的硬件完成,程序可以存储于一种计算机可读存储介质中,上述提到的存储介质可以是只读存储器,磁盘或光盘等。
本领域技术人员应该可以意识到,在上述一个或多个示例中,本申请实施例所描述的功能可以用硬件、软件、固件或它们的任意组合来实现。当使用软件实现时,可以将这些功能存储在计算机可读介质中或者作为计算机可读介质上的一个或多个指令或代码进行传输。计算机可读介质包括计算机存储介质和通信介质,其中通信介质包括便于从一个地方向另一个地方传送计算机程序的任何介质。存储介质可以是通用或专用计算机能够存取的任何可用介质。
以上仅为本申请的可选实施例,并不用以限制本申请,凡在本申请的精神和原则之内,所作的任何修改、等同替换、改进等,均应包含在本申请的保护范围之内。
Claims (14)
1.一种图像分类模型的训练方法,其特征在于,所述图像分类模型包括特征提取网络和多示例学习模型,所述方法包括:
获取样本图像集,所述样本图像集中的样本图像包括至少两个示例;
通过基于对比学习的自监督学习,采用所述样本图像集中的样本图像对所述特征提取网络进行训练,得到训练后的特征提取网络;
通过基于互注意力机制的多示例学习,采用所述样本图像集中的样本图像对所述多示例学习模型进行训练,得到训练后的多示例学习模型;
其中,所述互注意力机制是计算所述至少两个示例的特征表示与低秩隐变量之间的注意力的机制,所述低秩隐变量的数量小于所述至少两个示例的数量,所述低秩隐变量是基于所述样本图像的秩设置的。
2.根据权利要求1所述的方法,其特征在于,所述多示例学习模型包括:互注意力网络、池化层和分类器;
所述通过基于互注意力机制的多示例学习,采用所述样本图像集中的样本图像对所述多示例学习模型进行训练,得到训练后的多示例学习模型,包括:
通过所述互注意力网络对第一特征表示序列中的第一特征表示与所述低秩隐变量之间进行互注意力计算,输出第二特征表示序列;所述第一特征表示序列是所述特征提取网络对所述至少两个示例提取得到的;
通过所述池化层对所述第二特征表示序列进行池化,得到所述样本图像的全局特征表示;
通过所述分类器对所述全局特征表示进行预测,得到所述样本图像的预测分类结果;基于所述样本图像的预测分类结果和标签分类结果之间的误差损失,对所述多示例学习模型的模型参数进行训练。
3.根据权利要求2所述的方法,其特征在于,所述互注意力网络包括第一注意力组件和第二注意力组件;
所述通过所述互注意力网络对所述第一特征表示序列中的特征表示与所述低秩隐变量之间进行互注意力计算,输出第二特征表示序列,包括:
通过所述第一注意力组件使用r维的所述低秩隐变量作为查询,与输入的n个特征表示进行第一互注意力计算,得到r个低维向量H,r是小于n的正整数;
通过所述第二注意力组件使用所述n个特征表示作为查询,与所述r个低维向量H进行第二互注意力计算,得到第二特征表示序列中的n个第二特征表示;
其中,所述n个特征表示的初始值为所述第一特征表示序列中的第一特征表示。
4.根据权利要求3所述的方法,其特征在于,所述通过所述第二注意力组件使用所述n个特征表示作为查询,与所述r个低维向量H进行第二互注意力计算,得到第二特征表示序列中的n个第二特征表示,包括:
通过所述第二注意力组件使用所述n个特征表示作为查询,与所述r个低维向量H进行第二互注意力计算,得到n个中间特征表示;
在不满足迭代结束条件时,将所述n个中间特征表示作为输入的所述n个特征表示,再次从所述通过所述第一注意力组件使用r维的所述低秩隐变量作为查询与输入的n个特征表示进行第一互注意力计算,得到r个低维向量H的步骤开始执行;
在满足迭代结束条件时,将所述n个中间特征表示输出为第二特征表示序列。
5.根据权利要求3所述的方法,其特征在于,所述第一注意力组件和所述第二注意力组件是基于门控注意力机制和互注意力机制结合构建的。
6.根据权利要求1至5任一所述的方法,其特征在于,
所述通过基于对比学习的自监督学习,采用所述样本图像集中的样本图像对所述特征提取网络进行训练,得到训练后的特征提取网络,包括:
通过所述特征提取网络对所述至少两个示例进行特征提取,得到第一特征表示序列;所述第一特征表示序列中的第一特征表示与所述至少两个示例一一对应;
对所述第一特征表示序列中的第一特征表示进行聚类,将与锚点第一特征表示属于同一聚类结果的第一特征表示作为所述锚点第一特征表示的正样本特征表示;将与锚点第一特征表示属于不同聚类结果的第一特征表示作为所述锚点第一特征表示的负样本特征表示;
基于所述正样本特征表示和所述负样本特征表示之间的自监督学习损失,对所述特征提取网络进行训练。
7.根据权利要求6所述的方法,其特征在于,所述对所述第一特征表示序列中的第一特征表示进行聚类,将与锚点第一特征表示属于同一聚类结果的第一特征表示作为所述锚点第一特征表示的正样本特征表示;将与锚点第一特征表示属于不同聚类结果的第一特征表示作为所述锚点第一特征表示的负样本特征表示,包括:
将所述第一特征表示序列中的第一特征表示,按照与锚点第一特征表示的相似度由高到低的顺序进行排序;所述锚点第一特征表示是所述第一特征表示序列中的一个;
将排序后的所述第一特征表示序列矩阵分解至r个子空间,r为预设的秩数;
将分解至所述r个子空间中的第一个子空间中的第一特征表示确定为所述锚点第一特征表示的正样本特征表示;将分解至所述r个子空间中的最后一个子空间中的第一特征表示确定为所述锚点第一特征表示的负样本特征表示。
8.根据权利要求1至7任一所述的方法,其特征在于,所述方法还包括:
将所述样本图像集中的样本图像按照切割分辨率进行分割,得到至少两个图像切片;
将属于同一个样本图像的所述至少两个图像切片确定为所述同一个样本图像的所述至少两个示例。
9.一种图像分类模型的使用方法,其特征在于,所述图像分类模型是根据权利要求1至8任一所述的方法训练得到的,所述方法包括:
获取待分类的输入图像,所述输入图像包括至少两个示例;
将所述输入图像输入所述图像分类模型中的特征提取网络进行特征提取,得到第三特征表示序列,所述第三特征表示序列包括所述至少两个示例的特征表示;
将所述第三特征表示序列输入所述图像分类模型中的多示例学习模型进行基于互注意力机制的分类预测,得到所述输入图像的分类结果;
其中,所述互注意力机制是对所述计算所述至少两个示例的特征表示与低秩隐变量之间的注意力的机制,所述低秩隐变量的数量小于所述至少两个示例的数量,所述低秩隐变量是基于所述样本图像的秩设置的。
10.一种图像分类模型的训练装置,其特征在于,所述图像分类模型包括特征提取网络和多示例学习模型,所述装置包括:
获取模块,用于获取样本图像集,所述样本图像集中的样本图像包括至少两个示例;
自监督学习模块,用于通过基于对比学习的自监督学习,采用所述样本图像集中的样本图像对所述特征提取网络进行训练,得到训练后的特征提取网络;
多示例学习模块,用于通过基于互注意力机制的多示例学习,采用所述样本图像集中的样本图像对所述多示例学习模型进行训练,得到训练后的多示例学习模型;
其中,所述互注意力机制是计算所述至少两个示例的特征表示与低秩隐变量之间的注意力的机制,所述低秩隐变量的数量小于所述至少两个示例的数量,所述低秩隐变量是基于所述样本图像的秩设置的。
11.一种图像分类模型的使用装置,其特征在于,所述图像分类模型是根据权利要求10所述的装置训练得到的,所述装置包括:
获取模块,用于获取待分类的输入图像,所述输入图像包括至少两个示例;
特征提取模块,用于将所述输入图像输入所述图像分类模型中的特征提取网络进行特征提取,得到第三特征表示序列,所述第三特征表示序列包括所述至少两个示例的特征表示;
预测模块,用于将所述第三特征表示序列输入所述图像分类模型中的多示例学习模型进行基于互注意力机制的分类预测,得到所述输入图像的分类结果;
其中,所述互注意力机制是对所述计算所述至少两个示例的特征表示与低秩隐变量之间的注意力的机制,所述低秩隐变量的数量小于所述至少两个示例的数量,所述低秩隐变量是基于所述样本图像的秩设置的。
12.一种计算机设备,其特征在于,所述计算机设备包括:处理器和存储器,所述存储器存储有计算机程序,所述计算机程序由所述处理器加载并执行以实现如权利要求1至8任一所述图像分类模型的训练方法,和/或如权利要求9所述的图像分类模型的使用方法。
13.一种计算机可读存储介质,其特征在于,所述计算机可读存储介质存储有计算机程序,所述计算机程序由处理器加载并执行以实现如权利要求1至8任一所述图像分类模型的训练方法,和/或如权利要求9所述的图像分类模型的使用方法。
14.一种计算机程序产品,其特征在于,所述计算机程序产品包括计算机指令,所述计算机指令存储在计算机可读存储介质中,处理器从所述计算机可读存储介质中获取所述计算机指令,使得所述处理器加载并执行以实现如权利要求1至8任一所述图像分类模型的训练方法,和/或如权利要求9所述的图像分类模型的使用方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210885176.4A CN115238888A (zh) | 2022-07-26 | 2022-07-26 | 图像分类模型的训练方法、使用方法、装置、设备及介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210885176.4A CN115238888A (zh) | 2022-07-26 | 2022-07-26 | 图像分类模型的训练方法、使用方法、装置、设备及介质 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN115238888A true CN115238888A (zh) | 2022-10-25 |
Family
ID=83674830
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202210885176.4A Pending CN115238888A (zh) | 2022-07-26 | 2022-07-26 | 图像分类模型的训练方法、使用方法、装置、设备及介质 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN115238888A (zh) |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116524302A (zh) * | 2023-05-05 | 2023-08-01 | 广州市智慧城市投资运营有限公司 | 一种场景识别模型的训练方法、装置及存储介质 |
-
2022
- 2022-07-26 CN CN202210885176.4A patent/CN115238888A/zh active Pending
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116524302A (zh) * | 2023-05-05 | 2023-08-01 | 广州市智慧城市投资运营有限公司 | 一种场景识别模型的训练方法、装置及存储介质 |
CN116524302B (zh) * | 2023-05-05 | 2024-01-26 | 广州市智慧城市投资运营有限公司 | 一种场景识别模型的训练方法、装置及存储介质 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
WO2020238293A1 (zh) | 图像分类方法、神经网络的训练方法及装置 | |
CN111898696A (zh) | 伪标签及标签预测模型的生成方法、装置、介质及设备 | |
CN111898703B (zh) | 多标签视频分类方法、模型训练方法、装置及介质 | |
CN113469088A (zh) | 一种无源干扰场景下的sar图像舰船目标检测方法及系统 | |
WO2024060684A1 (zh) | 模型训练方法、图像处理方法、设备及存储介质 | |
CN114693624A (zh) | 一种图像检测方法、装置、设备及可读存储介质 | |
CN111507403A (zh) | 图像分类方法、装置、计算机设备和存储介质 | |
CN112786160A (zh) | 基于图神经网络的多图片输入的多标签胃镜图片分类方法 | |
US20230055263A1 (en) | Stratification in non-classified heterogeneous object labels | |
CN116994021A (zh) | 图像检测方法、装置、计算机可读介质及电子设备 | |
CN115238888A (zh) | 图像分类模型的训练方法、使用方法、装置、设备及介质 | |
CN111445545B (zh) | 一种文本转贴图方法、装置、存储介质及电子设备 | |
CN113704534A (zh) | 图像处理方法、装置及计算机设备 | |
CN115424275B (zh) | 一种基于深度学习技术的渔船船牌号识别方法及系统 | |
CN114973107B (zh) | 基于多鉴别器协同和强弱共享机制的无监督跨域视频动作识别方法 | |
CN111768214A (zh) | 产品属性的预测方法、系统、设备和存储介质 | |
CN112801153B (zh) | 一种嵌入lbp特征的图的半监督图像分类方法及系统 | |
CN115115910A (zh) | 图像处理模型的训练方法、使用方法、装置、设备及介质 | |
CN115080745A (zh) | 基于人工智能的多场景文本分类方法、装置、设备及介质 | |
CN114168780A (zh) | 多模态数据处理方法、电子设备及存储介质 | |
Ju et al. | A novel neutrosophic logic svm (n-svm) and its application to image categorization | |
CN114692715A (zh) | 一种样本标注方法及装置 | |
CN111814865A (zh) | 一种图像识别方法、装置、设备及存储介质 | |
CN111625672B (zh) | 图像处理方法、装置、计算机设备及存储介质 | |
CN109408706B (zh) | 一种图像过滤方法 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |