CN113609927B - 基于分支学习和分层伪标签的行人重识别网络训练方法 - Google Patents
基于分支学习和分层伪标签的行人重识别网络训练方法 Download PDFInfo
- Publication number
- CN113609927B CN113609927B CN202110812690.0A CN202110812690A CN113609927B CN 113609927 B CN113609927 B CN 113609927B CN 202110812690 A CN202110812690 A CN 202110812690A CN 113609927 B CN113609927 B CN 113609927B
- Authority
- CN
- China
- Prior art keywords
- data
- tag data
- pseudo
- training
- branch
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
- 238000012549 training Methods 0.000 title claims abstract description 101
- 238000000034 method Methods 0.000 title claims abstract description 51
- 230000006870 function Effects 0.000 claims abstract description 20
- 238000012360 testing method Methods 0.000 claims description 6
- 238000012935 Averaging Methods 0.000 claims description 3
- 238000013528 artificial neural network Methods 0.000 claims description 3
- 238000004364 calculation method Methods 0.000 claims description 3
- 238000000605 extraction Methods 0.000 claims description 3
- 230000004927 fusion Effects 0.000 claims description 2
- 238000011160 research Methods 0.000 description 4
- 230000000694 effects Effects 0.000 description 3
- 230000006978 adaptation Effects 0.000 description 2
- 230000000750 progressive effect Effects 0.000 description 2
- 238000004422 calculation algorithm Methods 0.000 description 1
- 238000011840 criminal investigation Methods 0.000 description 1
- 230000007547 defect Effects 0.000 description 1
- 238000010586 diagram Methods 0.000 description 1
- 238000002474 experimental method Methods 0.000 description 1
- 238000003064 k means clustering Methods 0.000 description 1
- 238000005065 mining Methods 0.000 description 1
- 238000003909 pattern recognition Methods 0.000 description 1
- 238000012545 processing Methods 0.000 description 1
- 238000006467 substitution reaction Methods 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/23—Clustering techniques
- G06F18/232—Non-hierarchical techniques
- G06F18/2321—Non-hierarchical techniques using statistics or function optimisation, e.g. modelling of probability density functions
- G06F18/23213—Non-hierarchical techniques using statistics or function optimisation, e.g. modelling of probability density functions with fixed number of clusters, e.g. K-means clustering
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
- G06F18/241—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
Landscapes
- Engineering & Computer Science (AREA)
- Data Mining & Analysis (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Artificial Intelligence (AREA)
- Life Sciences & Earth Sciences (AREA)
- Bioinformatics & Computational Biology (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Evolutionary Biology (AREA)
- Evolutionary Computation (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Probability & Statistics with Applications (AREA)
- Image Analysis (AREA)
Abstract
本发明涉及一种基于分支学习和分层伪标签的行人重识别网络训练方法,行人重识别网络为相互平均教学网络,训练方法包括:获取标签数据集和无标签数据集,将标签数据集作为一层,将无标签数据集分为N层,并对各层的无标签数据分别赋值伪标签,形成N层伪标签数据,N为常数;构建分支学习框架,包括N+1个共享权重的相互平均教学网络分支,其中一个分支用于输入标签数据进行训练,其余N个分支分别对应输入N层伪标签数据进行训练;构建各分支的损失函数,确定分支学习框架的总损失函数,基于总损失函数进行多轮训练,每一轮训练过程中对无标签数据集重新进行分层。与现有技术相比,本发明训练的网络更加精确,训练时网络的收敛速度快。
Description
技术领域
本发明涉及一种行人重识别网络训练方法,尤其是涉及一种基于分支学习和分层伪标签的行人重识别网络训练方法。
背景技术
行人重识别是一项跨域识别同一行人的任务,在目标自动识别中有着重要的地位。最近几年,有许多研究重点关注需要大量标注数据的全监督行人重识别,然而在生活中,大量标注数据往往会消耗大量的人力和时间成本,且在一些情境下,如刑侦调查时,往往缺少大量的标注数据,而每个行人仅有一张标注图像以供网络训练。由此引出单样本行人重识别这一具有意义的研究课题。
目前针对单样本的行人重识别,已经有了一些有价值的研究。有一些研究通过丰富行人的特征来致力于增加识别的精度,一些研究通过扩大训练数据集的规模来提升网络的效果,进而达到提高识别率的效果。通常,扩大训练数据集又有两个思路,一是生成新的可训练数据,二是为无标签数据赋值伪标签将其转换成标签数据参与训练。生成新数据的方法虽然能有效增大训练数据的规模,但其无法充分挖掘已有的标签数据的信息。于是伪标签法成为了应用更为广泛的半监督学习方法。伪标签法分为半监督学习的伪标签法和无监督学习的伪标签法,其中,半监督的伪标签法包括标签传播法和k近邻聚类等,无监督伪标签法包括K-means聚类和DBSCAN聚类等。目前已有的伪标签法大部分仅单用一种方法,然而不同的伪标签法有着不同的适用范围,能从不同视角对无标签数据赋值伪标签,仅用一种方法会限制其使用效果。更重要的是,对于大部分的伪标签法而言,伪标签数据往往被视作和标签数据具有同等的地位,并将它们混合在一起进行训练。实际上,伪标签数据的噪声导致它不能提供和标签数据一样准确的信息,而且不同伪标签法获得的伪标签数据也具有不同的噪声,因而需要将它们分组分别进行训练。不同类型的数据又具有不同的特点,因而对不同组的数据使用相同的损失函数是不合理的,需要针对不同组的特点设计个性化的损失函数。
发明内容
本发明的目的就是为了克服上述现有技术存在的缺陷而提供一种基于分支学习和分层伪标签的行人重识别网络训练方法。
本发明的目的可以通过以下技术方案来实现:
一种基于分支学习和分层伪标签的行人重识别网络训练方法,所述的行人重识别网络为相互平均教学网络,所述的相互平均教学网络包括两个结构相同的网络Net1和Net2以及对应的平均网络Mean Net1和Mean Net2,所述的训练方法包括:
获取标签数据集和无标签数据集,将标签数据集作为一层,将无标签数据集分为N层,并对各层的无标签数据分别赋值伪标签,形成N层伪标签数据,N为常数;
构建分支学习框架,包括N+1个共享权重的相互平均教学网络分支,其中一个分支用于输入标签数据进行训练,其余N个分支分别对应输入N层伪标签数据进行训练;
构建各分支的损失函数,确定分支学习框架的总损失函数,基于总损失函数对分支学习框架进行多轮训练,每一轮训练过程中对无标签数据集重新进行分层。
优选地,所述的无标签数据集分为2层,具体为:将与标签数据集中的标签数据距离较近的若干无标签数据分作一层,剩余无标签数据分作一层。
优选地,所述的无标签数据集分层的具体方式为:
对标签数据集和无标签数据集中的标签数据和无标签数据/>分别采用特征提取器进行特征提取,标签数据特征记作/>无标签数据特征记作/>θo为特征提取器;
计算无标签数据集中任意一个无标签数据和标签数据集中任意一个标签数据间的欧式距离并取最小值,计算公式为:
其中,||·||表示欧氏距离,L表示标签数据集;
将无标签数据对应的/>由小到大排序,选取前p个无标签数据作为第一层伪标签数据,称作最近邻伪标签数据,其余无标签数据剔除掉其中的聚类离群点后作为第二层伪标签数据,称作聚类伪标签数据。
优选地,在每一轮训练过程中更新p的大小,更新方式表示为:
其中,U表示无标签数据集中,|U|表示无标签数据集的样本个数,0<γ<1,epoch为训练轮数。
优选地,对各层无标签数据赋值伪标签的方式为:
对于最近邻伪标签数据,将与其欧式距离最小的有标签数据的标签作为此最近邻伪标签数据的伪标签;
对于聚类伪标签数据,基于提取的特征对全部有标签数据和无标签数据进行聚类,将属于同一聚类类型中的有标签数据的标签作为该聚类类型中聚类伪标签数据的伪标签。
优选地,采用DBSCAN聚类法对全部有标签数据和无标签数据进行聚类。
优选地,在多轮训练过程中,所述的特征提取器θo不断更新,更新方式为:
首轮训练时,采用预设的Resnet50神经网络作为该轮训练的特征提取器;
第k轮训练时,提取k-1轮训练过的相互平均教学网络,选择Net1和Net2中测试指标mAP更高的一个网络去掉分类器作为第k轮训练的特征提取器,k≥2。
优选地,所述的分支学习框架的总损失函数记作L,表示为:
其中,分别表示输入标签数据分支的分类损失、软分类损失、难样本三元组损失和软三元组损失,/>分别表示输入最近邻伪标签数据分支的分类损失、软分类损失、难样本三元组损失和软三元组损失,/> 分别表示输入聚类伪标签数据分支的难样本三元组损失和软三元组损失,LBD表示输入标签数据分支的类间距离损失,LGC表示输入最近邻伪标签数据分支的全局中心损失,λ1,λ2,α1表示权重。
优选地,输入标签数据分支的类间距离损失LBD表示为:
LBD=LBD-1+LBD-2
其中,LBD-1表示用于训练Net1的类间距离损失,LBD-2表示用于训练Net2的类间距离损失,其中,LB表示当前训练批次的训练样本集,NB表示训练样本集LB中的样本数,和/>表示LB中的标签数据样本,/>和/>分别为输入标签数据分支中相互平均教学网络的Net1和Net2提取出标签数据/>的特征,/>分别为输入标签数据分支中相互平均教学网络的Net1和Net2提取出标签数据/>的特征,θ1、θ2表示Net1和Net2的特征提取器,||·||表示欧氏距离。
优选地,输入最近邻伪标签数据分支的全局中心损失LGC通过如下方式获得:
对于标签数据其对应标签为j,输入标签数据分支中的相互平均教学网络的平均网络Mean Net1和Mean Net2提取出标签数据/>的特征为/>和/>ET[θ1]、ET[θ2]分别为所述的平均网络Mean Net1和Mean Net2的特征提取器,将两个特征进行融合,并将融合结果记为标签j的全局类中心Cj,其表达式为:
采用一个记忆模块来存储这些全局类中心,每完成一轮训练更新一次全局类中心的大小;
第一轮训练时输入最近邻伪标签数据分支的全局中心损失LGC取作0;
从第二轮训练开始,输入最近邻伪标签数据分支的全局中心损失LGC通过下式获得:
其中,表示第i个最近邻伪标签数据,NB表示最近邻伪标签数据的总个数,分别为输入最近邻伪标签数据分支中的相互平均教学网络中Net1和Net2提取出最近邻伪标签数据/>的特征,θ1、θ2表示Net1和Net2的特征提取器,yi表示/>的伪标签。
与现有技术相比,本发明具有如下优点:
(1)本发明方法能够充分挖掘无标签数据的信息,为网络提供内容更丰富的训练数据,使得训练的网络更加精确;
(2)本发明方法能有效缩短训练时网络的收敛速度。
附图说明
图1为本发明一种基于分支学习和分层伪标签的行人重识别网络训练方法的流程示意图。
具体实施方式
下面结合附图和具体实施例对本发明进行详细说明。注意,以下的实施方式的说明只是实质上的例示,本发明并不意在对其适用物或其用途进行限定,且本发明并不限定于以下的实施方式。
实施例
本实施例提供一种基于分支学习和分层伪标签的行人重识别网络训练方法,行人重识别网络为相互平均教学网络(MMT网络),相互平均教学网络为现有的网络结构为2020年发表于International Conference on Learning Representations(ICLR)中的文章“Mutual mean-teaching:Pseudo label refinery for unsupervised domainadaptation on person re-identification”提出的一种新型网络结构,其包含两个结构相同的网络Net1和Net2,以及它们对应的的平均网络Mean Net1和Mean Net2。
如图1所示,本实施例提供的训练方法包括:
获取标签数据集和无标签数据集,将标签数据集作为一层,将无标签数据集分为N层,并对各层的无标签数据分别赋值伪标签,形成N层伪标签数据,N为常数;
构建分支学习框架,包括N+1个共享权重的相互平均教学网络分支,其中一个分支用于输入标签数据进行训练,其余N个分支分别对应输入N层伪标签数据进行训练;
构建各分支的损失函数,确定分支学习框架的总损失函数,基于总损失函数对分支学习框架进行多轮训练,每一轮训练过程中对无标签数据集重新进行分层,重复训练直至网络收敛到最好结果。
无标签数据集分为2层,具体为:将与标签数据集中的标签数据距离较近的若干无标签数据分作一层,剩余无标签数据分作一层。
无标签数据集分层的具体方式为:
对标签数据集和无标签数据集中的标签数据和无标签数据/>分别采用特征提取器进行特征提取,标签数据特征记作/>无标签数据特征记作/>θo表示特征提取器;
计算无标签数据集中任意一个无标签数据和标签数据集中任意一个标签数据间的欧式距离并取最小值,计算公式为:
其中,||·||表示欧氏距离,L表示标签数据集;
将无标签数据对应的/>由小到大排序,选取前p个无标签数据作为第一层伪标签数据,称作最近邻伪标签数据,其余无标签数据剔除掉其中的聚类离群点后作为第二层伪标签数据,称作聚类伪标签数据。
在每一轮训练过程中更新p的大小,更新方式表示为:
其中,U表示无标签数据集中,|U|表示无标签数据集的样本个数,0<γ<1,epoch为训练轮数。
对各层无标签数据赋值伪标签的方式为:
对于最近邻伪标签数据,将与其欧式距离最小的有标签数据的标签作为此最近邻伪标签数据的伪标签;
对于聚类伪标签数据,基于提取的特征对全部有标签数据和无标签数据进行聚类,将属于同一聚类类型中的有标签数据的标签作为该聚类类型中聚类伪标签数据的伪标签,采用DBSCAN聚类法对全部有标签数据和无标签数据进行聚类。
在多轮训练过程中,特征提取器θo不断更新,更新方式为:
首轮训练时,采用预设的Resnet50神经网络作为该轮训练的特征提取器;
第k轮训练时,提取k-1轮训练过的相互平均教学网络,选择Net1和Net2中测试指标mAP更高的一个网络去掉分类器作为第k轮训练的特征提取器,k≥2。
由此,本实施例将标签数据集和无标签数据集数据分为3层,分别为标签数据层、最近邻为标签数据层和聚类伪标签数据层,从而分支学习框架包括3个共享权重的相互平均教学网络分支。在训练过程中,不断更新最近邻为标签数据层和聚类伪标签数据层中的数据,从而使得网络识别精度越来越好。
将标签数据层、最近邻为标签数据层和聚类伪标签数据层这三个层上的数据分别输入到不同的共享权重的MMT分支,并对每个分支用不同的损失函数进行训练。对于标签数据分支,采用分类损失、软分类损失、难样本三元组损失、软三元组损失和设计的类间距离损失进行训练;对于最近邻伪标签数据分支,同样采用了分类损失、软分类损失、难样本三元组损失、软三元组损失进行训练,并额外为其设计了全局中心损失以使训练向着缩小类间距的方向进行;对于聚类伪标签数据分支,因为这些数据的伪标签来源于聚类算法而非标签数据,其伪标签不能代表行人身份信息,因而其不能用分类损失和软分类损失而仅能用难样本三元组损失和软三元组损失进行训练。上述分类损失、软分类损失、难样本三元组损失、软三元组损失均为文章“Mutual mean-teaching:Pseudo label refinery forunsupervised domain adaptation on person re-identification”中提出的几种损失函数,在本实施例中不详细说明。
对于最近邻为标签数据层,设计了类间距离损失,目的是让网络对不同的类之间有更好的区分度,且难样本三元组损失仅学习一层中距离最近的负样本对,而忽略了其他负样本对的学习,可能会导致学习信息的丢失。类间距离损失的主要思想是由于所有的标签数据均不属于同一个类别,因而我们将标签数据在特征空间中彼此推开,输入标签数据分支的类间距离损失LBD表示为:
LBD=LBD-1+LBD-2
其中,LBD-1表示用于训练Net1的类间距离损失,LBD-2表示用于训练Net2的类间距离损失,其中,LB表示当前训练批次的训练样本集,NB表示训练样本集LB中的样本数,和/>表示LB中的标签数据样本,/>和/>分别为输入标签数据分支中相互平均教学网络的Net1和Net2提取出标签数据/>的特征,/>分别为输入标签数据分支中相互平均教学网络的Net1和Net2提取出标签数据/>的特征,θ1、θ2表示Net1和Net2的特征提取器,||·||表示欧氏距离。
传统的中心损失仅针对一层中的数据而不是全体训练数据,这会导致其在行人重识别上的应用受到限制。而且多分支学习框架中,上文提到的损失函数仅能学习同一层上的数据而不能学习不同层上的数据。针对这两点,设计了全局中心损失,其核心思想是使第二层的伪标签数据能紧密围绕在对应的标签数据周围。因此,输入最近邻伪标签数据分支的全局中心损失LGC通过如下方式获得:
对于标签数据其对应标签为j,输入标签数据分支中的相互平均教学网络的平均网络Mean Net1和Mean Net2提取出标签数据/>的特征为/>和/>ET[θ1]、ET[θ2]分别为平均网络Mean Net1和Mean Net2的特征提取器,将两个特征进行融合,并将融合结果记为标签j的全局类中心Cj,其表达式为:
采用一个记忆模块来存储这些全局类中心,每完成一轮训练更新一次全局类中心的大小;
第一轮训练时输入最近邻伪标签数据分支的全局中心损失LGC取作0;
从第二轮训练开始,输入最近邻伪标签数据分支的全局中心损失LGC通过下式获得:
其中,表示第i个最近邻伪标签数据,NB表示最近邻伪标签数据的总个数,分别为输入最近邻伪标签数据分支中的相互平均教学网络中Net1和Net2提取出最近邻伪标签数据/>的特征,θ1、θ2表示Net1和Net2的特征提取器,yi表示/>的伪标签。
由此,分支学习框架的总损失函数记作L,表示为:
其中,分别表示输入标签数据分支的分类损失、软分类损失、难样本三元组损失和软三元组损失,/>分别表示输入最近邻伪标签数据分支的分类损失、软分类损失、难样本三元组损失和软三元组损失,/> 分别表示输入聚类伪标签数据分支的难样本三元组损失和软三元组损失,LBD表示输入标签数据分支的类间距离损失,LGC表示输入最近邻伪标签数据分支的全局中心损失,λ1,λ2,α1表示权重。
本实施例在Market-1501和DukeMTMC-reID数据集上进行了实验,和其它最新的单样本行人重识别结果对比如下表所示:
表1不同方法性能对比表
注,上表中【1】~【5】为参考文献,具体如下:
【1】Y.Wu,Y.Lin,X.Dong,Y.Yan,W.Bian,Y.Yang,Progressive learning forperson reidentification with one example,IEEE Transactions on ImageProcessing PP(6)(2019)1–1.
【2】D.Xia,H.Liu,L.Xu,J.Li,L.Wang,Self-training with one-shot stepwiselearning method for person re-identifification,in:CONCURRENCY ANDCOMPUTATION-PRACTICE&EXPERIENCE,2021.doi:10.1002/cpe.6296.
【3】Y.Zhang,B.Ma,L.Liu,X.Yi,Self-Paced Uncertainty Estimation for One-shot Person Re-Identifification,arXiv e-prints(2021)arXiv:2104.09152arXiv:2104.09152.
【4】T.Xu,J.Li,H.Wu,H.Yang,Y.Chen,Feature space regularization forperson re-identifification with one sample,in:2019IEEE 31st InternationalConference on Tools with Artifificial Intelligence(ICTAI),2019.
【5】H.Li,J.Xiao,M.Sun,E.G.Lim,Y.Zhao,Progressive sample mining andrepresentation learning for one-shot person re-identifification,PATTERNRECOGNITION 110.doi:10.1016/j.patcog.2020.107614.
由上表可知,本发明方法能在有限的标签训练样本的条件下,充分利用了全部无标签数据信息,并专业化地对不同类别数据进行分组训练,从而训练出一个性能较好的网络以完成行人重识别任务,我们的方法比目前存在的单样本行人重识别方法更有效和先进。
上述实施方式仅为例举,不表示对本发明范围的限定。这些实施方式还能以其它各种方式来实施,且能在不脱离本发明技术思想的范围内作各种省略、置换、变更。
Claims (5)
1.一种基于分支学习和分层伪标签的行人重识别网络训练方法,所述的行人重识别网络为相互平均教学网络,所述的相互平均教学网络包括两个结构相同的网络Net1和Net2以及对应的平均网络Mean Net1和Mean Net2,其特征在于,所述的训练方法包括:
获取标签数据集和无标签数据集,将标签数据集作为一层,将无标签数据集分为N层,并对各层的无标签数据分别赋值伪标签,形成N层伪标签数据,N为常数;
构建分支学习框架,包括N+1个共享权重的相互平均教学网络分支,其中一个分支用于输入标签数据进行训练,其余N个分支分别对应输入N层伪标签数据进行训练;
构建各分支的损失函数,确定分支学习框架的总损失函数,基于总损失函数对分支学习框架进行多轮训练,每一轮训练过程中对无标签数据集重新进行分层;
所述的无标签数据集分为2层,具体为:将与标签数据集中的标签数据距离较近的若干无标签数据分作一层,剩余无标签数据分作一层;
所述的无标签数据集分层的具体方式为:
对标签数据集和无标签数据集中的标签数据和无标签数据/>分别采用特征提取器进行特征提取,标签数据特征记作/>无标签数据特征记作/>θo为特征提取器;
计算无标签数据集中任意一个无标签数据和标签数据集中任意一个标签数据/>间的欧式距离并取最小值,计算公式为:
其中,||·||表示欧氏距离,L表示标签数据集;
将无标签数据对应的/>由小到大排序,选取前p个无标签数据作为第一层伪标签数据,称作最近邻伪标签数据,其余无标签数据剔除掉其中的聚类离群点后作为第二层伪标签数据,称作聚类伪标签数据;
所述的分支学习框架的总损失函数记作L,表示为:
其中,分别表示输入标签数据分支的分类损失、软分类损失、难样本三元组损失和软三元组损失,/>分别表示输入最近邻伪标签数据分支的分类损失、软分类损失、难样本三元组损失和软三元组损失,/> 分别表示输入聚类伪标签数据分支的难样本三元组损失和软三元组损失,LBD表示输入标签数据分支的类间距离损失,LGC表示输入最近邻伪标签数据分支的全局中心损失,λ1,λ2,α1表示权重;
输入标签数据分支的类间距离损失LBD表示为:
LBD=LBD-1+LBD-2
其中,LBD-1表示用于训练Net1的类间距离损失,LBD-2表示用于训练Net2的类间距离损失,其中,LB表示当前训练批次的训练样本集,NB表示训练样本集LB中的样本数,和/>表示LB中的标签数据样本,/>和/>分别为输入标签数据分支中相互平均教学网络的Net1和Net2提取出标签数据/>的特征,/>分别为输入标签数据分支中相互平均教学网络的Net1和Net2提取出标签数据/>的特征,θ1、θ2表示Net1和Net2的特征提取器,||·||表示欧氏距离;
输入最近邻伪标签数据分支的全局中心损失LGC通过如下方式获得:
对于标签数据其对应标签为j,输入标签数据分支中的相互平均教学网络的平均网络Mean Net1和Mean Net2提取出标签数据/>的特征为/>和/>ET[θ1]、ET[θ2]分别为所述的平均网络Mean Net1和Mean Net2的特征提取器,将两个特征进行融合,并将融合结果记为标签j的全局类中心Cj,其表达式为:
采用一个记忆模块来存储这些全局类中心,每完成一轮训练更新一次全局类中心的大小;
第一轮训练时输入最近邻伪标签数据分支的全局中心损失LGC取作0;
从第二轮训练开始,输入最近邻伪标签数据分支的全局中心损失LGC通过下式获得:
其中,表示第i个最近邻伪标签数据,NB表示最近邻伪标签数据的总个数,分别为输入最近邻伪标签数据分支中的相互平均教学网络中Net1和Net2提取出最近邻伪标签数据/>的特征,θ1、θ2表示Net1和Net2的特征提取器,yi表示/>的伪标签。
2.根据权利要求1所述的一种基于分支学习和分层伪标签的行人重识别网络训练方法,其特征在于,在每一轮训练过程中更新p的大小,更新方式表示为:
其中,U表示无标签数据集,|U|表示无标签数据集的样本个数,0<γ<1,epoch为训练轮数。
3.根据权利要求1所述的一种基于分支学习和分层伪标签的行人重识别网络训练方法,其特征在于,对各层无标签数据赋值伪标签的方式为:
对于最近邻伪标签数据,将与其欧式距离最小的有标签数据的标签作为此最近邻伪标签数据的伪标签;
对于聚类伪标签数据,基于提取的特征对全部有标签数据和无标签数据进行聚类,将属于同一聚类类型中的有标签数据的标签作为该聚类类型中聚类伪标签数据的伪标签。
4.根据权利要求3所述的一种基于分支学习和分层伪标签的行人重识别网络训练方法,其特征在于,采用DBSCAN聚类法对全部有标签数据和无标签数据进行聚类。
5.根据权利要求1所述的一种基于分支学习和分层伪标签的行人重识别网络训练方法,其特征在于,在多轮训练过程中,所述的特征提取器θo不断更新,更新方式为:
首轮训练时,采用预设的Resnet50神经网络作为该轮训练的特征提取器;
第k轮训练时,提取k-1轮训练过的相互平均教学网络,选择Net1和Net2中测试指标mAP更高的一个网络去掉分类器作为第k轮训练的特征提取器,k≥2。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202110812690.0A CN113609927B (zh) | 2021-07-19 | 2021-07-19 | 基于分支学习和分层伪标签的行人重识别网络训练方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202110812690.0A CN113609927B (zh) | 2021-07-19 | 2021-07-19 | 基于分支学习和分层伪标签的行人重识别网络训练方法 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN113609927A CN113609927A (zh) | 2021-11-05 |
CN113609927B true CN113609927B (zh) | 2023-09-29 |
Family
ID=78337875
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202110812690.0A Active CN113609927B (zh) | 2021-07-19 | 2021-07-19 | 基于分支学习和分层伪标签的行人重识别网络训练方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN113609927B (zh) |
Families Citing this family (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN113989596B (zh) * | 2021-12-23 | 2022-03-22 | 深圳佑驾创新科技有限公司 | 图像分类模型的训练方法及计算机可读存储介质 |
Citations (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN111967294A (zh) * | 2020-06-23 | 2020-11-20 | 南昌大学 | 一种无监督域自适应的行人重识别方法 |
CN112016687A (zh) * | 2020-08-20 | 2020-12-01 | 浙江大学 | 一种基于互补伪标签的跨域行人重识别方法 |
CN112131961A (zh) * | 2020-08-28 | 2020-12-25 | 中国海洋大学 | 一种基于单样本的半监督行人重识别方法 |
CN112418331A (zh) * | 2020-11-26 | 2021-02-26 | 国网甘肃省电力公司电力科学研究院 | 一种基于聚类融合的半监督学习伪标签赋值方法 |
Family Cites Families (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN110008842A (zh) * | 2019-03-09 | 2019-07-12 | 同济大学 | 一种基于深度多损失融合模型的行人重识别方法 |
-
2021
- 2021-07-19 CN CN202110812690.0A patent/CN113609927B/zh active Active
Patent Citations (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN111967294A (zh) * | 2020-06-23 | 2020-11-20 | 南昌大学 | 一种无监督域自适应的行人重识别方法 |
CN112016687A (zh) * | 2020-08-20 | 2020-12-01 | 浙江大学 | 一种基于互补伪标签的跨域行人重识别方法 |
CN112131961A (zh) * | 2020-08-28 | 2020-12-25 | 中国海洋大学 | 一种基于单样本的半监督行人重识别方法 |
CN112418331A (zh) * | 2020-11-26 | 2021-02-26 | 国网甘肃省电力公司电力科学研究院 | 一种基于聚类融合的半监督学习伪标签赋值方法 |
Non-Patent Citations (1)
Title |
---|
半监督单样本深度行人重识别方法;单纯;王敏;计算机系统应用(第001期);256-260 * |
Also Published As
Publication number | Publication date |
---|---|
CN113609927A (zh) | 2021-11-05 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN111967294B (zh) | 一种无监督域自适应的行人重识别方法 | |
Roy et al. | Curriculum graph co-teaching for multi-target domain adaptation | |
US10719780B2 (en) | Efficient machine learning method | |
CN108090472B (zh) | 基于多通道一致性特征的行人重识别方法及其系统 | |
CN102314614B (zh) | 一种基于类共享多核学习的图像语义分类方法 | |
CN110532379B (zh) | 一种基于lstm的用户评论情感分析的电子资讯推荐方法 | |
CN109933804A (zh) | 融合主题信息与双向lstm的关键词抽取方法 | |
CN108647595B (zh) | 基于多属性深度特征的车辆重识别方法 | |
CN113326731A (zh) | 一种基于动量网络指导的跨域行人重识别算法 | |
CN110347791B (zh) | 一种基于多标签分类卷积神经网络的题目推荐方法 | |
CN111832511A (zh) | 一种增强样本数据的无监督行人重识别方法 | |
CN113627463A (zh) | 基于多视图对比学习的引文网络图表示学习系统及方法 | |
CN112819065A (zh) | 基于多重聚类信息的无监督行人难样本挖掘方法和系统 | |
Yuan et al. | One-shot learning for fine-grained relation extraction via convolutional siamese neural network | |
CN111080551B (zh) | 基于深度卷积特征和语义近邻的多标签图像补全方法 | |
CN111950528A (zh) | 图表识别模型训练方法以及装置 | |
US20180349766A1 (en) | Prediction guided sequential data learning method | |
CN112766378A (zh) | 一种专注细粒度识别的跨域小样本图像分类模型方法 | |
CN111898665A (zh) | 基于邻居样本信息引导的跨域行人再识别方法 | |
CN113609927B (zh) | 基于分支学习和分层伪标签的行人重识别网络训练方法 | |
CN111125396B (zh) | 一种单模型多分支结构的图像检索方法 | |
CN113779283B (zh) | 一种深度监督与特征融合的细粒度跨媒体检索方法 | |
Wang et al. | Learning multiple semantic knowledge for cross-domain unsupervised vehicle re-identification | |
CN114463812B (zh) | 基于双通道多分支融合特征蒸馏的低分辨率人脸识别方法 | |
CN109543038B (zh) | 一种应用于文本数据的情感分析方法 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |