CN115797732B - 用于开放类别场景下的图像检索模型训练方法及系统 - Google Patents
用于开放类别场景下的图像检索模型训练方法及系统 Download PDFInfo
- Publication number
- CN115797732B CN115797732B CN202310113191.1A CN202310113191A CN115797732B CN 115797732 B CN115797732 B CN 115797732B CN 202310113191 A CN202310113191 A CN 202310113191A CN 115797732 B CN115797732 B CN 115797732B
- Authority
- CN
- China
- Prior art keywords
- training
- data
- model
- classification
- training set
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
Images
Classifications
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02D—CLIMATE CHANGE MITIGATION TECHNOLOGIES IN INFORMATION AND COMMUNICATION TECHNOLOGIES [ICT], I.E. INFORMATION AND COMMUNICATION TECHNOLOGIES AIMING AT THE REDUCTION OF THEIR OWN ENERGY USE
- Y02D10/00—Energy efficient computing, e.g. low power processors, power management or thermal management
Landscapes
- Image Analysis (AREA)
Abstract
本发明属于图像检索技术领域,具体涉及用于开放类别场景下的图像检索模型训练方法及系统。方法包括:S1,采用独立进行新增数据集的数据类别标注,并获得N个类别定义没有依赖的独立的训练集;S2,采用动态扩展分类头的方式构建模型,所述模型包含1个特征提取网络和N个分类头;S3,根据N个训练集数据构建出N个dataloader工具,分别对应着N个分类头;在一个epoch的训练过程中,计算出各个训练集的batch批处理数;根据其中一次训练的轮次,确定从对应的dataloader中采样数据。本发明具有能够提升开放类别场景下图像检索模型训练数据的扩充效率,能够满足实际应用场景下频繁的训练数据扩充和频繁的模型优化迭代需求的特点。
Description
技术领域
本发明属于图像检索技术领域,具体涉及用于开放类别场景下的图像检索模型训练方法及系统。
背景技术
图像检索(CBIR)技术是深度学习领域里的一个重大研究课题。而深度学习技术是一种依赖于大量数据样本的技术,在实际应用中,为了训练出性能优异的图像检索模型,往往需要大量的标注数据作为训练集;为了持续优化模型,往往需要不断收集来自各个场景的新样本数据加入现有训练集中用于模型训练,以期训练出更加鲁棒和准确的图像检索模型,实现模型的不断迭代优化。
对于图像检索技术,扩充训练集主要有以下两种方案:第一种是直接将新的训练数据加入到原训练集中,和原训练集数据一起用于模型训练;第二种是通过迁移学习的方式,使用已经训练好的模型在新的数据集上微调,来更新模型。
在图像检索技术领域,从数据集类别上可分为封闭类别场景和开放类别场景。所谓封闭类别,即训练集的样本类别数是固定的,扩充的训练集数据的样本类别是包含在原训练集样本类别内的,在数据扩充时不会出现也无需考虑新的类别样本。所谓开放类别场景,即训练集的样本类别数是不固定的,扩充的训练集数据样本可能包含新的类别数据。而以上方案虽然都能达到训练集数据扩充的目的,但是对于第一种方案来说,当在大规模开放类别场景下扩充训练数据的时候,一般会面临两个问题,第一:标注人员在标注新数据时,需要将新数据逐个与原训练集中的样本对比,以确定新数据的样本是属于一个新的类别,还是属于原训练集中的某个类别,当实际生产应用中,为了保证模型更新优化速度,需要频繁扩充训练数据,这就使得数据扩充的效率非常低下,导致模型的迭代效率低下。第二:实际生产任务中,为增强深度学习模型在不同场景下的识别能力,往往会为训练集扩充不同场景的数据,而不同场景之间的相同或相似样本的类别定义一般难以把握,就很有可能出现类别的错误定义,从而导致模型训练效果不佳。
对于第二种方案来说,虽然在新数据标注的时候无需关心与原训练集数据的类别是否重复,从而能够提升开放类别场景数据扩充标注效率。但是由于每次更新模型,都是使用前一次训练的模型仅在新增数据上微调迭代,当模型迭代达到一定次数之后,由于深度学习模型的遗忘机制,最早几批的训练数据对模型的贡献会逐渐减弱,从而使得新模型在最早几批训练数据对应的测试数据上效果变差,达不到持续扩充训练数据从而持续优化模型的目的。
针对开放类别场景下的图像检索模型训练任务,有以下问题:
第一,现有的传统数据扩充方案受标注效率、标注准确率以及深度学习模型遗忘机制的限制,无法满足实际使用中频繁新增数据、持续优化模型的需求;
第二,现有的模型构建是静态的,分类头输出类别数由训练集全部样本决定,由于传统数据扩充方案的局限性,使得当前的模型构建方案无法满足实际应用中频繁新增数据、持续优化模型的需求。
第三,现有的模型训练方案中的训练数据采样及模型参数更新策略,在实际应用中,仅适用于传统的数据扩充策略和模型构建方案。由于传统数据扩充及模型构建方案的局限性,使得现有的模型训练策略无法满足实际应用中频繁新增数据、持续优化模型的需求。具体来说:
1.数据扩充策略存在局限性
对于图像检索技术,扩充训练集主要有以下两种方案:第一种是直接将新的训练数据加入到原训练集中,和原训练集数据一起用于模型训练;第二种是通过迁移学习的方式,使用已经训练好的模型在新的数据集上微调,来更新模型。
以上方案都能达到训练集数据扩充的目的,但是对于第一种方案来说,在大规模开放类别场景下扩充训练数据的时候,一般会面临两个问题,第一:标注人员在标注新数据时,需要将新数据逐个与原训练集中的样本对比,以确定新数据的样本是属于一个新的类别,还是属于原训练集中的某个类别,当实际生产应用中,为了保证模型更新优化速度,需要频繁扩充训练数据,这就使得数据扩充的效率非常低下,导致模型的迭代效率低下。第二:实际生产任务中,为增强深度学习模型在不同场景下的识别能力,往往会为训练集扩充不同场景的数据,而不同场景之间的相同或相似样本的类别定义一般难以把握,就很有可能出现类别的错误定义,从而导致模型训练效果不佳。
对于第二种方案来说,虽然在新数据标注的时候无需关心与原训练集数据的类别是否重复的问题,但是由于每次更新模型,都是使用前一次训练的模型仅在新增数据上微调,当模型更新达到一定次数之后,由于深度学习模型的遗忘机制,最早几批的训练数据对模型的贡献会逐渐减弱,从而使得新模型在最早几批训练数据对应的测试数据上效果变差,达不到持续扩充训练数据从而持续优化模型的目的。
2.模型构建方案存在局限性
目前的图像检索模型的训练结构一般由特征提取网络和分类头组成,分类头的输出类别数由训练集全部样本决定,传统单一分类头的模型构建形式,当使用上述第一种方案的数据扩充方式时,由于开放类别场景下的新增元素的类别定义往往难以把握,容易出现类别定义冲突(如相同类别样本被标注为不同类别,或不同类别样本被标注为相同类别),因此会影响分类损失和对比损失计算的正确性。当使用上述第二种方案的数据扩充方式时,在新增训练集的类别数和原训练集类别数不一致的情况下,原训练集所训练的模型的分类头权重也无法用于新增训练集的迁移学习,从而会加重模型的遗忘现象。这样一来,当前模型构建方式在开放类别场景的实际使用中就有较大的局限性。
3.模型训练方案存在局限性
在当前的模型训练方案中,一方面,训练数据的采样策略是基于单一训练集的,这就要求数据扩充需按照传统扩充策略将新增数据扩充至原训练集数据中,或者仅使用新增数据集进行迁移学习微调迭代模型,而如上所述,这两种数据扩充策略都具有一定的局限性。另一方面,当前对于训练数据的采样及训练方案是基于单一分类头的模型结构,如上所述,单一分类头的模型结构在训练时存在一定的局限性。综上,当前的模型训练方案存在一定的局限性。
因此,设计一种能够提升开放类别场景下图像检索模型训练数据的扩充效率,能够满足实际应用场景下频繁的训练数据扩充和频繁的模型优化迭代需求的用于开放类别场景下的图像检索模型训练方法及系统,就显得十分重要。
发明内容
本发明是为了克服现有技术中,现有传统数据扩充方案存在标注效率低、标注准确度难以把控,传统模型构建方案只能适应传统数据扩充方式,传统采样方式只能基于单一训练集的问题,提供了一种能够提升开放类别场景下图像检索模型训练数据的扩充效率,能够满足实际应用场景下频繁的训练数据扩充和频繁的模型优化迭代需求的用于开放类别场景下的图像检索模型训练方法及系统。
为了达到上述发明目的,本发明采用以下技术方案:
用于开放类别场景下的图像检索模型训练方法,包括如下步骤;
S1,数据扩充:
其中,所述数据扩充采用独立进行新增数据集的数据类别标注,并获得N个类别定义没有依赖的独立的训练集;
S2,模型构建:
其中,所述模型构建采用动态扩展分类头的方式构建模型,所述模型包含1个特征提取网络和N个分类头;所述分类头的输入维度与特征提取网络输出的特征维度一致;各个分类头的输出维度与对应的训练集样本类别数一致;
S3,模型训练:
其中,所述模型训练过程如下:
根据N个训练集数据构建出N个dataloader工具,分别对应着N个分类头;在一个epoch的训练过程中,计算出各个训练集的batch批处理数;一个epoch的总训练轮次为N个训练集的batch数目之和;
根据其中一次训练的轮次,确定从对应的dataloader中采样数据;所述采样数据经过特征提取网络提取特征。
作为优选,步骤S1中,每个新增数据集的类别标注索引均可从头开始,数据集之间的类别定义互不干涉;每个训练集分别独立贡献最终的损失函数计算。
作为优选,步骤S2中,模型的构建是动态的,且随着训练集数量的扩充,分类头的数量始终与训练集的数量保持一致。
作为优选,步骤S2中,各个分类头均由一个或多个全连接层组成;各个分类头均只参与训练过程。
作为优选,步骤S3包括如下步骤:
训练数据采样策略:
S31,设定有原训练集D_0, 扩充训练集D_1, 扩充训练集D_2,共三个训练集;三个训练集分别含有N_0,N_1, N_2个样本,分别含有C_1, C_2, C_3个类别;
S32,计算各个训练集在一个epoch内被采样的batch批处理数;设定批处理数为b,即batch_size=b,则三个训练集分别需要被采样B0=N_0/b,B1=N_1/b,B2=N_2/b次,一个epoch内总共需要采样次数B=(B0+B1+B2)次;
S33,在每个epoch训练期间,记录当前采样次数为b_n;当b_n<=B0时,从训练集D_0中采样数据;当B0<b_n<=B0+B1时,从训练集D_1中采样数据;当B0+B1<b_n<=B0+B1+B2时,从训练集D_2中采样数据;
S34,在下一个epoch训练时,重复步骤S32至步骤S33采样过程;
其中,一个epoch等于使用训练集中的全部样本训练一次。
作为优选,步骤S3还包括如下步骤:
在训练数据采样策略的基础上,一个轮次的模型训练策略具体过程如下:
S35,设定当前批处理数batch_size为b的采样数据,来自于训练集D_1,经过特征提取网络,得到b个特征向量V_b;
S36,打开分类头c1, 输入为V_b,输出为b个采样数据各自的所属类别概率分布,与b个采样数据的标签计算得到分类损失;根据b个采样数据的类别标签,在V_b内部给各个特征向量构建正负样本对关系,进而计算得到对比损失;
S37,分类损失回传更新分类头c1参数和特征提取网络参数,对比损失反向传播更新特征提取网络参数,完成一个轮次模型的训练。
本发明还提供了用于开放类别场景下的图像检索模型训练系统包括;
数据扩充模块,用于独立进行新增数据集的数据类别标注,获得N个类别定义没有依赖的独立的训练集;
模型构建模块,用于采用动态扩展分类头的方式构建模型,所述模型包含1个特征提取网络和N个分类头;所述分类头的输入维度与特征提取网络输出的特征维度一致;各个分类头的输出维度与对应的训练集样本类别数一致;
模型训练模块,用于根据N个训练集数据构建出N个dataloader工具,分别对应着N个分类头;在一个epoch的训练过程中,计算出各个训练集的batch批处理数;一个epoch的总训练轮次为N个训练集的batch数目之和;根据其中一次训练的轮次,确定从对应的dataloader中采样数据;所述采样数据经过特征提取网络提取特征。
本发明与现有技术相比,有益效果是:(1)在数据扩充阶段,本发明所提出的数据扩充方案大大提升了新增数据的标注效率,并且不会存在由不同场景相似图像带来的类别定义歧义问题,从而保证了新增数据标注的准确率;(2)在模型构建阶段,本发明所提出的动态扩展分类头的模型训练结构设计能够灵活有效地随时应对动态的训练集扩充;(3)在模型训练阶段,本发明所提出的“循环读取、定向采样”的训练集数据采样策略,一方面确保了每次采样的一个batch的训练数据都来自于N个训练集中的一个,无论是涉及到分类损失的类别定义,还是涉及到对比损失的正负样本对的定义,都不会存在新旧数据的类别定义歧义而引发的模型训练问题;另一方面保证了一个epoch内,所有的训练集样本都参与了训练,在整个训练过程中,所有的训练集保持一个循环参与训练的状态,而不是像迁移学习那样先学习完旧训练集,再学习新训练集的轮流学习状态,避免了由于深度学习模型的遗忘机制引发的遗忘现象;同时,本发明所提出的“定向更新”的模型参数更新策略,实现了所有训练集各自训练各自的分类头,共同训练特征提取模型的目的。
附图说明
图1为本发明中用于开放类别场景下的图像检索模型训练方法的一种流程图;
图2为本发明方法中数据扩充的一种示意图;
图3为本发明方法中模型构建的一种示意图;
图4为本发明方法中一个轮次的模型训练过程的一种示意图;
图5为本发明实施例所提供的用于开放类别场景下的图像检索模型训练方法的一种实际业务流程图。
实施方式
为了更清楚地说明本发明实施例,下面将对照附图说明本发明的具体实施方式。显而易见地,下面描述中的附图仅仅是本发明的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图,并获得其他的实施方式。
如图1所示,本发明提供了用于开放类别场景下的图像检索模型训练方法,包括如下步骤;
S1,数据扩充:
其中,所述数据扩充采用独立进行新增数据集的数据类别标注,并获得N个类别定义没有依赖的独立的训练集;
S2,模型构建:
其中,所述模型构建采用动态扩展分类头的方式构建模型,所述模型包含1个特征提取网络和N个分类头;所述分类头的输入维度与特征提取网络输出的特征维度一致;各个分类头的输出维度与对应的训练集样本类别数一致;
S3,模型训练:
其中,所述模型训练过程如下:
根据N个训练集数据构建出N个dataloader工具,分别对应着N个分类头;在一个epoch的训练过程中,计算出各个训练集的batch批处理数;一个epoch的总训练轮次为N个训练集的batch数目之和;
根据其中一次训练的轮次,确定从对应的dataloader中采样数据;所述采样数据经过特征提取网络提取特征。
DataLoader是深度学习框架Pytorch中用来处理模型输入数据的一个工具类,一般在模型训练之前,会利用训练集构建一个Dataloader可迭代对象,在训练时,从Dataloader按批次取数据进行模型训练。
如图2所示,扩充数据阶段所要完成的功能是,标注和扩充训练集数据。这里与传统训练集标注和扩充方案不一样的是,由于本发明所设计的多训练集独立采样、联合训练的训练方案能够允许新增训练集数据与原训练集数据有类别定义重叠,甚至类别定义歧义,因此可以对新增数据进行独立的数据类别标注,无需考虑原训练集。每个新增数据集的类别标注索引均可从头开始,数据集之间的类别定义互不干涉,这样在经过多次数据新增之后,就得到了多个独立标注的训练集数据。这样的新数据类别定义方式大大提升了新增数据的标注效率,并且不会存在由不同场景相似图像带来的类别定义歧义问题,从而保证了新增数据标注的准确率。
如图3所示,为了适应上述扩充数据方式,模型采用动态扩展分类头的方式构建,所构建的模型包含一个特征提取网络和N个分类头(N是训练集个数),各个分类头由一个或多个全连接层组成,分类头的输入维度为特征提取网络输出的特征向量的维度,分类头的输出维度与对应训练集样本类别数一致,N个分类头均只参与训练过程,不参与推理过程,扩展的分类头不增加推理耗时。这种网络结构的设计能够灵活有效地随时应对动态的训练集扩充。
模型训练阶段所要完成的功能主要包括两个方面,一方面是训练数据采样策略设计,所设计的采样策略有以下三个优点:1.保证了所有的训练数据都能够参与训练;2.不会在数据的采样上发生类别冲突(类别采样冲突是指一个batch内既含有训练集n1的类别i的数据,又含有训练集n2的类别i的数据);3.保证一个epoch内所有的训练集参与训练,使得在整个训练期间所有的训练集是循环参与训练,避免了某个或者某些训练集主导模型训练过程,从而避免了模型遗忘现象的出现。另一方面是训练策略设计,在训练时,所采样的数据来自于哪个训练集,就只由与那个训练集对应的分类头产生分类损失,并与对比损失一起进行回传,完成一次模型参数的更新。
具体的,训练数据采样策略过程如下:
S31,设定有原训练集D_0, 扩充训练集D_1, 扩充训练集D_2,共三个训练集;三个训练集分别含有N_0,N_1, N_2个样本,分别含有C_1, C_2, C_3个类别;
S32,计算各个训练集在一个epoch内被采样的batch批处理数;设定批处理数为b,即batch_size=b,则三个训练集分别需要被采样B0=N_0/b,B1=N_1/b,B2=N_2/b次,一个epoch内总共需要采样次数B=(B0+B1+B2)次;
S33,在每个epoch训练期间,记录当前采样次数为b_n;当b_n<=B0时,从训练集D_0中采样数据;当B0<b_n<=B0+B1时,从训练集D_1中采样数据;当B0+B1<b_n<=B0+B1+B2时,从训练集D_2中采样数据;
S34,在下一个epoch训练时,重复步骤S32至步骤S33采样过程;
其中,一个epoch等于使用训练集中的全部样本训练一次。
在上述采样策略的基础上,一个轮次的模型训练策略可以描述为如下过程,流程如图4所示:
S35,设定当前批处理数batch_size为b的采样数据,来自于训练集D_1,经过特征提取网络,得到b个特征向量V_b;
S36,打开分类头c1, 输入为V_b,输出为b个采样数据各自的所属类别概率分布,与b个采样数据的标签计算得到分类损失;根据b个采样数据的类别标签,在V_b内部给各个特征向量构建正负样本对关系,进而计算得到对比损失;
S37,分类损失回传更新分类头c1参数和特征提取网络参数,对比损失反向传播更新特征提取网络参数,完成一个轮次模型的训练。
通过这种方式,首先,确保了每次采样的一个batch的训练数据都来自于N个训练集中的一个,无论是涉及到分类损失的类别定义,还是涉及到对比损失的正负样本对的定义,都不会存在由不同场景相似图像导致的类别定义歧义而引发的模型训练问题。其次,这种采样方式保证了一个epoch内,所有的训练集样本都参与了训练,在整个训练过程中,所有的训练集保持一个循环参与训练的状态,而不是像迁移学习那样先学习完旧训练集,再学习新训练集的轮流学习状态,避免了由于深度学习模型的遗忘机制引发的遗忘现象。
本发明还提供了用于开放类别场景下的图像检索模型训练系统包括;
数据扩充模块,用于独立进行新增数据集的数据类别标注,获得N个类别定义没有依赖的独立的训练集;
模型构建模块,用于采用动态扩展分类头的方式构建模型,所述模型包含1个特征提取网络和N个分类头;所述分类头的输入维度与特征提取网络输出的特征维度一致;各个分类头的输出维度与对应的训练集样本类别数一致;
模型训练模块,用于根据N个训练集数据构建出N个dataloader工具,分别对应着N个分类头;在一个epoch的训练过程中,计算出各个训练集的batch批处理数;一个epoch的总训练轮次为N个训练集的batch数目之和;根据其中一次训练的轮次,确定从对应的dataloader中采样数据;所述采样数据经过特征提取网络提取特征。
基于本发明的技术方案,如图5所示,通过一个实际使用案例展示了本发明的实施方案,及其部分有益效果:
1.数据扩充阶段:在RPA(Robotic Process Automation)场景下,在已经有了软件“企业微信”图标训练集D1的基础上,新增软件“钉钉”图标的标注训练集D2,D2的类别定义从头开始,不考虑和D1中的样本类别是否有重合,得到两个完全独立的训练集D1和D2。新增软件“Excel”图标的标注训练集D3,D3的类别定义从头开始,不考虑和D1以及D2中的样本类别是否有重合,至此,完成对训练集D1的扩充,得到三个训练集D1、D2和D3,包含样本数分别为N1、N2和N3,样本类别数分别为C1、C2和C3。如图5中数据扩充部分所示。
2. 模型构建阶段:以resnet18作为特征提取网络,resnet18是一种具有残差结构的卷积神经网络模型,残差结构的存在使其能够缓解模型训练过程中的梯度消失的问题。输入图像大小为64*64,输出特征向量长度为512,在resnet18后,为D1、D2和D3设计分类头CLS_1, CLS_2, CLS_3,其中CLS_1由一个输入维度为512,输出维度为C1的全连接层构成,CLS_2由一个输入维度为512,输出维度为C2的全连接层构成,CLS_3由一个输入维度为512,输出维度为C3的全连接层构成。如图5中模型构建部分所示。
3.模型训练阶段:设置batch_size=64, 在一个epoch的训练过程中,首先每个训练集需要采样训练数据次数分别为N1//64=b1次、N2//64=b2次、N3//64=b3次,一个epoch的总训练轮次为b1+b2+b3次。在训练轮次iter<b1时,训练数据采样自训练集D1;在训练轮次b1<=iter<b1+b2时,训练数据采样自训练集D2;在训练轮次b1+b2<=iter<b1+b2+b3时,训练数据采样自训练集D3。图5以训练轮次34(b1<34<b1+b2)为例,采样数据输入特征提取网络resnet18,得到64个训练数据的512维特征,计算Triplet Loss(对比损失)并回传损失更新resnet18模型参数,由于采样自训练集D2,因此打开分类头CLS_2, 关闭另外两个分类头,计算Cross Entropy Loss(分类损失),然后通过反向传播完成对resnet18和分类头CLS_2参数的更新。如图5中一个轮次的训练数据采样部分及双线箭头流程所示。
4.一个epoch内其他轮次的训练过程与上述第34轮次训练过程一致,这样在一个epoch内就能够使所有训练集参与了特征提取模型的训练。在整个训练过程中,循环执行一个epoch的训练过程,使所有训练集循环参与模型训练,直至训练结束。
本发明创新性的提出一种开放类别场景下的图像检索模型训练方法,针对开放类别场景下的图像检索模型的训练集扩充难题,在方法中设计了一整套的数据扩充、模型构建以及模型训练策略,大大简化了开放类别场景下的图像检索训练集扩充流程,提升了训练集扩充效率,提升了训练集标注的准确性,从而大大提升了开放类别场景下的图像检索模型迭代优化效率,大大提升了模型对不同场景数据的适应能力,增强了模型在实际应用场景中落地的能力。
本发明的创新点具体如下:
1.本发明创新性地提出开放类别场景下图像检索模型训练数据扩充的独立标注方案,以及多训练集的数据扩充方式。简化了训练集数据扩充流程,提升了扩充数据的标注效率,避免了扩充数据的标注歧义,提升了标注准确率。
2.本发明创新性地提出动态分类头的图像检索训练模型结构设计,提升了特征提取模型的训练对训练集形式的适应能力。
3.本发明创新性地提出针对完全独立的多训练集的训练数据采样及模型训练方案,保证了多个训练集在模型训练过程中合理的参与机制。
以上所述仅是对本发明的优选实施例及原理进行了详细说明,对本领域的普通技术人员而言,依据本发明提供的思想,在具体实施方式上会有改变之处,而这些改变也应视为本发明的保护范围。
Claims (3)
1.用于开放类别场景下的图像检索模型训练方法,其特征在于,包括如下步骤;
S1,数据扩充:
其中,所述数据扩充采用独立进行新增数据集的数据类别标注,并获得N个类别定义没有依赖的独立的训练集;
S2,模型构建:
其中,所述模型构建采用动态扩展分类头的方式构建模型,所述模型包含1个特征提取网络和N个分类头;所述分类头的输入维度与特征提取网络输出的特征维度一致;各个分类头的输出维度与对应的训练集样本类别数一致;
S3,模型训练:
其中,所述模型训练过程如下:
根据N个训练集数据构建出N个dataloader工具,分别对应着N个分类头;在一个epoch的训练过程中,计算出各个训练集的batch批处理数;一个epoch的总训练轮次为N个训练集的batch数目之和;
根据其中一次训练的轮次,确定从对应的dataloader中采样数据;所述采样数据经过特征提取网络提取特征;
步骤S2中,模型的构建是动态的,且随着训练集数量的扩充,分类头的数量始终与训练集的数量保持一致;
步骤S2中,各个分类头均由一个或多个全连接层组成;各个分类头均只参与训练过程;
步骤S3包括如下步骤:
训练数据采样策略:
S31,设定有原训练集D_0,扩充训练集D_1,扩充训练集D_2,共三个训练集;三个训练集分别含有N_0,N_1,N_2个样本,分别含有C_1,C_2,C_3个类别;
S32,计算各个训练集在一个epoch内被采样的batch批处理数;设定批处理数为b,即batch_size=b,则三个训练集分别需要被采样B0=N_0/b,B1=N_1/b,B2=N_2/b次,一个epoch内总共需要采样次数B=(B0+B1+B2)次;
S33,在每个epoch训练期间,记录当前采样次数为b_n;当b_n<=B0时,从训练集D_0中采样数据;当B0<b_n<=B0+B1时,从训练集D_1中采样数据;当B0+B1<b_n<=B0+B1+B2时,从训练集D_2中采样数据;
S34,在下一个epoch训练时,重复步骤S32至步骤S33采样过程;
其中,一个epoch等于使用训练集中的全部样本训练一次;
步骤S3还包括如下步骤:
在训练数据采样策略的基础上,一个轮次的模型训练策略具体过程如下:
S35,设定当前批处理数batch_size为b的采样数据,来自于训练集D_1,经过特征提取网络,得到b个特征向量V_b;
S36,打开分类头c1,输入为V_b,输出为b个采样数据各自的所属类别概率分布,与b个采样数据的标签计算得到分类损失;根据b个采样数据的类别标签,在V_b内部给各个特征向量构建正负样本对关系,进而计算得到对比损失;
S37,分类损失回传更新分类头c1参数和特征提取网络参数,对比损失反向传播更新特征提取网络参数,完成一个轮次模型的训练。
2.根据权利要求1所述的用于开放类别场景下的图像检索模型训练方法,其特征在于,步骤S1中,每个新增数据集的类别标注索引均可从头开始,数据集之间的类别定义互不干涉;每个训练集分别独立贡献最终的损失函数计算。
3.用于开放类别场景下的图像检索模型训练系统,用于实现权利要求1-2任一项所述的用于开放类别场景下的图像检索模型训练方法,其特征在于,所述用于开放类别场景下的图像检索模型训练系统包括;
数据扩充模块,用于独立进行新增数据集的数据类别标注,获得N个类别定义没有依赖的独立的训练集;
模型构建模块,用于采用动态扩展分类头的方式构建模型,所述模型包含1个特征提取网络和N个分类头;所述分类头的输入维度与特征提取网络输出的特征维度一致;各个分类头的输出维度与对应的训练集样本类别数一致;
模型训练模块,用于根据N个训练集数据构建出N个dataloader工具,分别对应着N个分类头;在一个epoch的训练过程中,计算出各个训练集的batch批处理数;一个epoch的总训练轮次为N个训练集的batch数目之和;根据其中一次训练的轮次,确定从对应的dataloader中采样数据;所述采样数据经过特征提取网络提取特征。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310113191.1A CN115797732B (zh) | 2023-02-15 | 2023-02-15 | 用于开放类别场景下的图像检索模型训练方法及系统 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310113191.1A CN115797732B (zh) | 2023-02-15 | 2023-02-15 | 用于开放类别场景下的图像检索模型训练方法及系统 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN115797732A CN115797732A (zh) | 2023-03-14 |
CN115797732B true CN115797732B (zh) | 2023-06-09 |
Family
ID=85430984
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202310113191.1A Active CN115797732B (zh) | 2023-02-15 | 2023-02-15 | 用于开放类别场景下的图像检索模型训练方法及系统 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN115797732B (zh) |
Families Citing this family (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116935107A (zh) * | 2023-07-12 | 2023-10-24 | 中国科学院自动化研究所 | 基于互联网搜索的检测类别自扩展目标检测方法及装置 |
Citations (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN115661539A (zh) * | 2022-11-03 | 2023-01-31 | 南京邮电大学 | 一种嵌入不确定性信息的少样本图像识别方法 |
Family Cites Families (7)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN109685121B (zh) * | 2018-12-11 | 2023-07-18 | 中国科学院苏州纳米技术与纳米仿生研究所 | 图像检索模型的训练方法、图像检索方法、计算机设备 |
US11295171B2 (en) * | 2019-10-18 | 2022-04-05 | Google Llc | Framework for training machine-learned models on extremely large datasets |
CN111914928B (zh) * | 2020-07-30 | 2024-04-09 | 南京大学 | 一种为图像分类器进行对抗样本防御的方法 |
CN111898547B (zh) * | 2020-07-31 | 2024-04-16 | 平安科技(深圳)有限公司 | 人脸识别模型的训练方法、装置、设备及存储介质 |
CN111814913A (zh) * | 2020-08-20 | 2020-10-23 | 深圳市欢太科技有限公司 | 图像分类模型的训练方法、装置、电子设备及存储介质 |
CN114491036A (zh) * | 2022-01-25 | 2022-05-13 | 四川启睿克科技有限公司 | 一种基于自监督和有监督联合训练的半监督文本分类方法及系统 |
CN115471700A (zh) * | 2022-09-16 | 2022-12-13 | 中国科学院计算技术研究所 | 一种基于知识传输的图像分类模型训练方法及分类方法 |
-
2023
- 2023-02-15 CN CN202310113191.1A patent/CN115797732B/zh active Active
Patent Citations (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN115661539A (zh) * | 2022-11-03 | 2023-01-31 | 南京邮电大学 | 一种嵌入不确定性信息的少样本图像识别方法 |
Non-Patent Citations (3)
Title |
---|
A Simple and Efficient Ensemble Classifier Combining Multiple Neural Network Models on Social Media Datasets in Vietnamese;Huynh H D 等;arxiv;全文 * |
动态置信度的序列选择增量学习方法;李念;廖闻剑;彭艳兵;;计算机系统应用(02);全文 * |
结合场景分类数据的高分遥感图像语义分割方法;秦亿青 等;计算机应用与软件(06);全文 * |
Also Published As
Publication number | Publication date |
---|---|
CN115797732A (zh) | 2023-03-14 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN108647577B (zh) | 一种自适应难例挖掘的行人重识别方法与系统 | |
CN115797732B (zh) | 用于开放类别场景下的图像检索模型训练方法及系统 | |
CN112685504B (zh) | 一种面向生产过程的分布式迁移图学习方法 | |
CN111581454A (zh) | 基于深度图压缩算法的并行查询表现预测系统及方法 | |
CN113761221B (zh) | 基于图神经网络的知识图谱实体对齐方法 | |
CN110674326A (zh) | 一种基于多项式分布学习的神经网络结构检索方法 | |
CN115391553B (zh) | 一种自动搜索时序知识图谱补全模型的方法 | |
CN112131402A (zh) | 一种基于蛋白质家族聚类的ppi知识图谱表示学习方法 | |
CN112381208A (zh) | 一种深度渐进且逐步寻优的神经网络架构搜索方法与系统 | |
CN111191785A (zh) | 一种基于拓展搜索空间的结构搜索方法 | |
CN113297429A (zh) | 一种基于神经网络架构搜索的社交网络链路预测方法 | |
CN116151324A (zh) | 基于图神经网络的rc互连延时预测方法 | |
CN114969367A (zh) | 基于多方面子任务交互的跨语言实体对齐方法 | |
Lu et al. | Research on optimization method of computer network service quality based on feature matching algorithm | |
CN113987203A (zh) | 一种基于仿射变换与偏置建模的知识图谱推理方法与系统 | |
CN110543478A (zh) | 公共层宽表建设方法、装置及服务器 | |
Sood et al. | Neunets: An automated synthesis engine for neural network design | |
CN112163069B (zh) | 一种基于图神经网络节点特征传播优化的文本分类方法 | |
CN114065770B (zh) | 一种基于图神经网络构建语义知识库的方法及系统 | |
WO2023273171A1 (zh) | 图像处理方法、装置、设备和存储介质 | |
CN115759470A (zh) | 一种基于机器学习的航班飞行全过程燃油消耗预测方法 | |
CN115860119A (zh) | 基于动态元学习的少样本知识图谱补全方法和系统 | |
CN112307914B (zh) | 一种基于文本信息指导的开放域图像内容识别方法 | |
CN114611668A (zh) | 一种基于异质信息网络随机游走的向量表示学习方法及系统 | |
CN114943328A (zh) | 基于bp神经网络非线性组合的sarima-gru时序预测模型 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |