CN114724007A - 训练分类模型、数据分类方法、装置、设备、介质及产品 - Google Patents
训练分类模型、数据分类方法、装置、设备、介质及产品 Download PDFInfo
- Publication number
- CN114724007A CN114724007A CN202210336174.XA CN202210336174A CN114724007A CN 114724007 A CN114724007 A CN 114724007A CN 202210336174 A CN202210336174 A CN 202210336174A CN 114724007 A CN114724007 A CN 114724007A
- Authority
- CN
- China
- Prior art keywords
- classification
- training image
- image
- loss
- training
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000012549 training Methods 0.000 title claims abstract description 297
- 238000013145 classification model Methods 0.000 title claims abstract description 207
- 238000000034 method Methods 0.000 title claims abstract description 81
- 238000002372 labelling Methods 0.000 claims description 29
- 230000006870 function Effects 0.000 claims description 24
- 238000004590 computer program Methods 0.000 claims description 13
- 238000003860 storage Methods 0.000 claims description 13
- 238000000605 extraction Methods 0.000 claims description 6
- 230000000694 effects Effects 0.000 abstract description 11
- 238000012545 processing Methods 0.000 abstract description 10
- 238000013473 artificial intelligence Methods 0.000 abstract description 6
- 238000013135 deep learning Methods 0.000 abstract description 4
- 238000010586 diagram Methods 0.000 description 17
- 206010012689 Diabetic retinopathy Diseases 0.000 description 9
- 238000004891 communication Methods 0.000 description 8
- 238000012360 testing method Methods 0.000 description 7
- 238000011161 development Methods 0.000 description 5
- 238000003745 diagnosis Methods 0.000 description 5
- 201000010099 disease Diseases 0.000 description 5
- 208000037265 diseases, disorders, signs and symptoms Diseases 0.000 description 5
- 230000008569 process Effects 0.000 description 5
- 238000010200 validation analysis Methods 0.000 description 4
- 238000009826 distribution Methods 0.000 description 3
- 230000003902 lesion Effects 0.000 description 3
- 238000012986 modification Methods 0.000 description 3
- 230000004048 modification Effects 0.000 description 3
- 230000003287 optical effect Effects 0.000 description 3
- 229910015234 MoCo Inorganic materials 0.000 description 2
- 230000033228 biological regulation Effects 0.000 description 2
- 238000010276 construction Methods 0.000 description 2
- 238000013461 design Methods 0.000 description 2
- 238000005516 engineering process Methods 0.000 description 2
- 230000002708 enhancing effect Effects 0.000 description 2
- 235000008434 ginseng Nutrition 0.000 description 2
- 230000003993 interaction Effects 0.000 description 2
- 230000009466 transformation Effects 0.000 description 2
- 230000000007 visual effect Effects 0.000 description 2
- 241000208340 Araliaceae Species 0.000 description 1
- 240000004371 Panax ginseng Species 0.000 description 1
- 235000002789 Panax ginseng Nutrition 0.000 description 1
- 235000005035 Panax pseudoginseng ssp. pseudoginseng Nutrition 0.000 description 1
- 235000003140 Panax quinquefolius Nutrition 0.000 description 1
- 230000002776 aggregation Effects 0.000 description 1
- 238000004220 aggregation Methods 0.000 description 1
- 238000003491 array Methods 0.000 description 1
- 230000009286 beneficial effect Effects 0.000 description 1
- 230000008901 benefit Effects 0.000 description 1
- 230000001413 cellular effect Effects 0.000 description 1
- 238000012512 characterization method Methods 0.000 description 1
- 239000003086 colorant Substances 0.000 description 1
- 238000010924 continuous production Methods 0.000 description 1
- 238000005520 cutting process Methods 0.000 description 1
- 230000007423 decrease Effects 0.000 description 1
- 230000001419 dependent effect Effects 0.000 description 1
- 230000009977 dual effect Effects 0.000 description 1
- 239000000284 extract Substances 0.000 description 1
- 238000011478 gradient descent method Methods 0.000 description 1
- 230000006872 improvement Effects 0.000 description 1
- 230000007786 learning performance Effects 0.000 description 1
- 239000004973 liquid crystal related substance Substances 0.000 description 1
- 238000010801 machine learning Methods 0.000 description 1
- 238000012544 monitoring process Methods 0.000 description 1
- 210000002569 neuron Anatomy 0.000 description 1
- 239000013307 optical fiber Substances 0.000 description 1
- 238000005457 optimization Methods 0.000 description 1
- 238000011160 research Methods 0.000 description 1
- 238000012216 screening Methods 0.000 description 1
- 239000004065 semiconductor Substances 0.000 description 1
- 230000001953 sensory effect Effects 0.000 description 1
- 238000006467 substitution reaction Methods 0.000 description 1
- 238000000844 transformation Methods 0.000 description 1
- 238000013519 translation Methods 0.000 description 1
- 238000012800 visualization Methods 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
- G06F18/2155—Generating training patterns; Bootstrap methods, e.g. bagging or boosting characterised by the incorporation of unlabelled data, e.g. multiple instance learning [MIL], semi-supervised techniques using expectation-maximisation [EM] or naïve labelling
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
- G06F18/241—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
- G06F18/2415—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches based on parametric or probabilistic models, e.g. based on likelihood ratio or false acceptance rate versus a false rejection rate
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/047—Probabilistic or stochastic networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/084—Backpropagation, e.g. using gradient descent
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/088—Non-supervised learning, e.g. competitive learning
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- Evolutionary Computation (AREA)
- Life Sciences & Earth Sciences (AREA)
- Artificial Intelligence (AREA)
- General Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- Computing Systems (AREA)
- Software Systems (AREA)
- Molecular Biology (AREA)
- Computational Linguistics (AREA)
- Biophysics (AREA)
- Biomedical Technology (AREA)
- Mathematical Physics (AREA)
- General Health & Medical Sciences (AREA)
- Health & Medical Sciences (AREA)
- Probability & Statistics with Applications (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Bioinformatics & Computational Biology (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Evolutionary Biology (AREA)
- Image Analysis (AREA)
- Information Retrieval, Db Structures And Fs Structures Therefor (AREA)
Abstract
本公开提供了训练分类模型方法、数据分类方法、装置、设备、介质及产品,涉及人工智能技术领域,具体为深度学习、计算机视觉技技术领域,可应用于医学影像处理场景。具体实现方案为:利用分类模型的第一网络分支对第一训练图像进行分类预测,并基于得到的分类预测结果,确定第一训练图像对应的分类损失;利用分类模型的第二网络分支提取第一训练图像的第一图像特征,以及第二训练图像的第二图像特征;基于第一图像特征与第二图像特征,确定对比损失;基于分类损失以及对比损失,更新分类模型的参数,得到训练完成的分类模型。本公开提升了分类模型的图像分类效果。
Description
技术领域
本公开涉及人工智能技术领域,尤其涉及深度学习、计算机视觉技术领域,可应用于医学影像处理场景。
背景技术
自监督对比学习是无监督学习的一种,能够从无标注的数据中学习知识,随着自监督对比学习的发展,从特征层面取得了很好的效果。例如,在人工智能的数据分类中,使用自监督对比学习方式能够对影像等图像数据进行分类,比如对医学图像数据进行分级。
对于有监督模型,需要大量高质量标注的样本提高学习效果。对于数据标注成本高,有标注样本相对较少、标注质量差的情况,有监督模型往往泛化能力不够强,且标注本身的噪声限制了有监督分类模型的上限。
发明内容
本公开提供了一种用于训练分类模型方法、数据分类方法、装置、设备、介质及产品。
根据本公开的一方面,提供了一种训练分类模型的方法,包括:利用分类模型的第一网络分支对第一训练图像进行分类预测,并基于得到的分类预测结果,确定所述第一训练图像对应的分类损失;利用所述分类模型的第二网络分支提取所述第一训练图像的第一图像特征,并提取第二训练图像的第二图像特征;基于所述第一图像特征与所述第二图像特征,确定对比损失;基于所述分类损失以及所述对比损失,更新所述分类模型的参数,得到训练完成的分类模型。
根据本公开的另一方面,提供了一种数据分类方法,包括:
确定待分类数据;将所述待分类数据输入至分类模型,得到所述分类模型的输出结果;基于所述分类模型的输出结果,确定所述待分类数据的分类结果;其中,所述分类模型包括第一网络分支和第二网络分支,并基于分类损失以及对比损失进行参数更新后预先训练得到;所述第一网络分支用于对第一训练图像进行分类预测,得到所述第一训练图像的分类预测结果;所述第二网络分支用于提取所述第一训练图像的第一图像特征以及第二训练图像的第二图像特征;其中,所述分类损失通过所述第一训练图像的分类预测结果确定,所述对比损失基于所述第一图像特征与所述第二图像特征确定。
根据本公开的又一方面,提供了一种训练分类模型的装置,包括:确定模块,用于利用分类模型的第一网络分支对第一训练图像进行分类预测,并基于得到的分类预测结果,确定所述第一训练图像对应的分类损失,以及,基于第一图像特征与第二图像特征,确定对比损失;提取模块,利用所述分类模型的第二网络分支提取所述第一训练图像的第一图像特征,并提取第二训练图像的第二图像特征;更新模块,用于基于所述分类损失以及所述对比损失,更新所述分类模型的参数,得到训练完成的分类模型。
根据本公开的又一方面,提供了一种数据分类装置,包括:
确定模块,用于确定待分类数据;分类模块,用于将所述待分类数据输入至分类模型,得到所述分类模型的输出结果,并基于所述分类模型的输出结果,确定所述待分类数据的分类结果;其中,所述分类模型包括第一网络分支和第二网络分支,并基于分类损失以及对比损失进行参数更新后预先训练得到;所述第一网络分支用于对第一训练图像进行分类预测,得到所述第一训练图像的分类预测结果;所述第二网络分支用于提取所述第一训练图像的第一图像特征以及第二训练图像的第二图像特征;其中,所述分类损失通过所述第一训练图像的分类预测结果确定,所述对比损失基于所述第一图像特征与所述第二图像特征确定。
根据本公开的又一方面,提供了一种电子设备,包括:至少一个处理器;以及与所述至少一个处理器通信连接的存储器;其中,所述存储器存储有可被所述至少一个处理器执行的指令,所述指令被所述至少一个处理器执行,以使所述至少一个处理器能够执行本公开中的训练分类模型的方法。
根据本公开的又一方面,提供了一种电子设备,包括:至少一个处理器;以及与所述至少一个处理器通信连接的存储器;其中,所述存储器存储有可被所述至少一个处理器执行的指令,所述指令被所述至少一个处理器执行,以使所述至少一个处理器能够执行本公开中的数据分类方法。
根据本公开的又一方面,提供了一种存储有计算机指令的非瞬时计算机可读存储介质,其中,所述计算机指令用于使所述计算机执行本公开中训练分类模型的方法。
根据本公开的又一方面,提供了一种存储有计算机指令的非瞬时计算机可读存储介质,其中,所述计算机指令用于使所述计算机执行本公开中数据分类方法。
根据本公开的又一方面,提供了一种计算机程序产品,包括计算机程序,所述计算机程序在被处理器执行时实现本公开中的训练分类模型的方法。
根据本公开的又一方面,提供了一种计算机程序产品,包括计算机程序,所述计算机程序在被处理器执行时实现本公开中的数据分类方法。
应当理解,本部分所描述的内容并非旨在标识本公开的实施例的关键或重要特征,也不用于限制本公开的范围。本公开的其它特征将通过以下的说明书而变得容易理解。
附图说明
附图用于更好地理解本方案,不构成对本公开的限定。其中:
图1是根据本公开一示例性实施例示出的训练分类模型的方法流程示意图;
图2是根据本公开一示例性实施例示出的分类模型结构示意图;
图3示出了一种有监督分类模型结构的示意图;
图4示出了一种无监督对比学习的分类模型结构的示意图;
图5是根据本公开提供的一种基于分类损失以及对比损失,更新分类模型参数的方法流程示意图;
图6是根据本公开的利用分类模型的第一网络分支对第一训练图像进行分类预测的分类预测结果,确定第一训练图像对应的分类损失的方法流程示意图;
图7是根据本公开的利用所述分类模型的第二网络分支分别提取第一训练图像的第一图像特征,以及第二训练图像的第二图像特征的方法流程示意图;
图8是在糖尿病视网膜病变分级数据集上,有监督分类器的骨干网络的输出特征可视化的结果示意图;
图9示出了糖尿病视网膜病变分级数据集中包括的数据集详细信息;
图10为有监督分类器分类模型结构和本公开中双流结构的分类模型结构训练集损失值(loss)随训练过程的移动平均线曲线对比示意图;
图11为有监督分类器分类模型结构和本公开中双流结构的分类模型结构验证集中Kappa值的移动平均曲线。
图12展现了有监督分类器分类模型结构和本公开中双流结构的分类模型结构在测试集上的Kappa值对比示意图;
图13是根据本公开一示例性实施例示出的一种数据分类方法流程图;
图14是根据本公开一示例性实施例示出的一种训练分类模型的装置框图;
图15是根据本公开一示例性实施例示出的一种数据分类装置框图;
图16是用来实现本公开实施例的训练分类模型的方法或数据分类方法的电子设备的框图。
具体实施方式
以下结合附图对本公开的示范性实施例做出说明,其中包括本公开实施例的各种细节以助于理解,应当将它们认为仅仅是示范性的。因此,本领域普通技术人员应当认识到,可以对这里描述的实施例做出各种改变和修改,而不会背离本公开的范围和精神。同样,为了清楚和简明,以下的描述中省略了对公知功能和结构的描述。
目前数据分类应用于多领域。例如应用在医学影像的分类。各类医学影像在临床上应用广泛,为临床中疾病的诊断提供有效的依据,医学影像在医疗发展中占据重要地位。当前技术中,使用医学影像进行诊断依赖于医生进行的人工判读,医学影像的大数据量给医生带来巨大的负担,且对于医学影像的判断依赖于医生的经验,其准确程度无法保证。
近年来深度学习技术在医学影像处理领域内的应用日益广泛,通过对医学影像进行初步筛查或者辅助医生诊断的方式有效的减轻了医生的负担并提高了疾病诊断的准确率。通过深度学习实现医学影像诊断在实现中,深度学习方法依赖于大量的高质量的有标注数据,对医学影像数据进行标注非常耗时且代价高昂。由于影像中显示病灶的发病过程连续,而标注往往是离散的,且病灶的形态多样,不同区域之间相互关联,很难精确定义病症发展程度的分界线,相邻等级之间界限难以区分,对病症的程度进行分类困难。
随着人工智能技术的发展,各应用场景下需要进行人工智能模型的训练。对于一般的有监督分类模型,需要大量高质量标注的样本,对于医学图像而言,有监督模型往往泛化能力不够强,标注本身的噪声限制了有监督分类模型的上限。随着自监督对比学习的发展,实现从特征层面对比学习。
鉴于此,本公开实施例提供了一种训练分类模型的方法,通过结合非监督对比学习和有监督分类的网络结构,有效提升分类模型的性能。
图1是根据本公开一示例性实施例示出的训练分类模型的方法流程示意图,参照图1,该方法包括以下步骤。
在步骤S101中,利用分类模型的第一网络分支对第一训练图像进行分类预测,并基于得到的分类预测结果,确定第一训练图像对应的分类损失。
在步骤S102中,利用分类模型的第二网络分支提取第一训练图像的第一图像特征,并提取第二训练图像的第二图像特征。
在步骤S103中,基于第一图像特征与第二图像特征,确定对比损失。
在步骤S104中,基于分类损失以及对比损失,更新分类模型的参数,得到训练完成的分类模型。
在本公开实施例中,基于分类模型的第一网络分支确定训练图像的分类损失,并基于分类模型的第二网络分支确定训练图像对应的对比损失。并基于分类损失与对比损失更新初始分类模型的参数,得到训练完成的分类模型。
在本公开实施例中,对于作为训练图像的图像集,可以进行数据增强,第一训练图像和第二训练图像。其中,数据增强可以为模糊处理、旋转、平移、剪切等操作。可以理解地,第一训练图像与第二训练图像为对相同的训练图像经过不同的数据增强得到的,第一训练图像与第二训练图像来自于同一训练图像。其中,训练图像可采用已有图像数据集,也可以根据需要进行采集并新构建等方式获取,本实施例中对此不做具体限定。
其中,为描述方便,将得到第一训练图像所使用的数据增强称为第一数据增强。可以理解的是,第一训练图像为对训练图像进行第一数据增强得到的。将得到第二训练图像所使用的数据增强称为第二数据增强。可以理解的是,第二训练图像为对训练图像进行第二数据增强得到的。
其中,第一数据增强和第二数据增强可以理解为是训练图像经过的两次不同的数据增强。
本公开实施例中,对图像集中的图像分别进行不同的数据增强,得到第一训练图像和第二训练图像,能够增强训练图像的多样性。
图2是根据本公开一示例性实施例示出的分类模型结构示意图,参照图2,本公开实施例中的分类模型可以为双流结构,即分类模型包括第一网络分支以及第二网络分支。本公开实施例中的第一网络分支以及第二网络分支分别包括骨干网络模块,可以用fθ(x)表示。第一网络分支以及第二网络分支中的骨干网络模块可以是相同的网络结构,不同的网络参数。其中,第一网络分支可以是有监督分类模型结构。第二网络分支可以是无监督对比学习的分类模型结构。
其中,图3示出了一种有监督分类模型结构的示意图。其中,有监督分类模型需要大量的数据和标签信息。如图3所示输入的图像数据(有监督分类模型结构获取的训练图像),经过数据增强,增加了图像的多样性,经过骨干网络fθ(x)和分类器后,与标注信息计算交叉熵损失(Cross-entropy loss,CE),进行梯度反向传播来更新网络权重。这种结构依赖标注信息,如果标注信息质量不高,那么分类器性能也难以提升,标注信息限制了监督学习的性能。
其中,图4示出了一种无监督对比学习的分类模型结构的示意图。无监督对比学习的分类模型结构中,针对获取到的训练图像经过两次不同的数据增强(图4中以第一数据增强和第二数据增强进行示意性说明)。经过两次不同的数据增强得到的训练图像,分别经过各自的骨干网络fθ(x),各自接入投影器(其中,投影器有时也称为投影层)。参阅图4所示,无监督对比学习的分类模型结构中,上分支的骨干网络模块通过投影器后接入预测器(其中,预测器有时也称为预测层)。下分支的骨干网络模块通过投影器与上分支预测器的输出特征计算对比损失。梯度反向传播只在上分支进行,下分支的网络权重通过上分支网络参数进行指数移动平均(exponential moving average,EMA)来更新。其中,无监督对比学习的分类模型结构不需要标注信息,可以利用大量无标注数据来进行训练,因此常被用来做预训练模型。这种方式预训练的模型往往比有监督预训练的模型要好,能够拉近相似样本的特征和推远不相似样本特征,这一点和有监督训练有相同的方向,优点是不需要标注信息。
无监督对比学习的分类模型结构例如可以是基于动量比对的非监督式视觉表征学习(Momentum Contrast,MoCo)v3结构。通过MoCo v3结构,可以对每一张原始图片进行不同的随机变换得到两个样例,来源于同一张图片的两个样例之间构成一个正样本对,来源于不同图片的两个样例之间构成一个负样本对,之后使用相同的网络作为编码器对每一个样例进行特征抽取和编码得到样本的特征表示,之后优化模型增大正样本对之间的相似度并降低负样本对之间的相似度,通过以上方法来实现无监督的样本中有价值、有区分度的特征的抽取。
参阅图2所示,本公开实施例提供的具有双流结构的分类模型结构中,结合了有监督分类模型结构和无监督对比学习的分类模型结构。其中,第一网络分支中,对获取到的训练图像进行第一数据增强后得到第一训练图像。第一训练图像通过第一网络分支的骨干网络模块后接分类器。其中,结合第一网络分支中分类器输出结果与标注信息,可以实现对第一训练图像进行分类预测,得到第一训练图像的分类预测结果。通过第一训练图像的分类预测结果,确定第一训练图像对应的分类损失,例如计算交叉熵损失。其中,第二网络分支是无监督对比学习的分类模型结构。第二网络分支可以用于提取图像本身的特征信息,后续称为图像特征。例如,本公开实施例中第二网络分支包括上分支与下分支两个分支,以下可以称为第一分支和第二分支。其中,第一分支中包括投影器和预测器。第二分支中包括投影器。通过第二网络分支的不同分支可以提取第一训练图像的第一图像特征,以及第二训练图像的第二图像特征。
在本公开实施例中,通过第二网络分支的不同分支提取的第一训练图像的第一图像特征,以及第二训练图像的第二图像特征,确定训练图像对应的对比损失。第二网络分支的对比损失使得第二网络分支中骨干网络提取的特征中相似特征聚集,不相似特征远离,更好地学习相似特征,远离不相似特征。
本公开实施例图2所示具有双流结构的分类模型结构同时学习标注信息和图像本身的特征信息,不再完全依赖标注信息来训练模型。例如,可以在无监督对比学习的分类模型结构MoCov3的基础之上,增加分类功能输出头,实现分类功能。并且,对比损失头使得骨干网络提取的特征具备相似特征聚集,不相似特征远离。在对比学习最近的研究表明,这种结构下的骨干网络作为预训练模型,在下游任务中微调后效果最好。
综上,根据本公开实施例提供的训练分类模型的方法,获取训练图像利用分类模型的第一网络分支对第一训练图像进行分类预测,得到分类预测结果,通过分类预测结果,确定第一训练图像对应的分类损失。利用分类模型的第二网络分支提取第一训练图像的第一图像特征、第二训练图像的第二图像特征。基于第一图像特征与第二图像特征,确定训练图像对应的对比损失,可以直接比较图像特征之间的关系,弱化有监督分类模型中标签的作用。基于分类损失与对比损失更新分类模型的参数,得到训练完成的分类模型。
本实施例中的训练分类模型的方法,通过在分类模型训练中,基于分类损失与对比损失更新分类模型的参数,实现利用对图像的标注信息以及对图像特征的学习,有效提升训练得到的分类模型的分类性能,提升通过分类模型进行图像分类的效果。
在本公开示例性的实施方式中,训练图像对应的对比损失通过多分类版本的噪声对比估计损失函数确定。利用对比学习中的多分类版本的噪声对比估计(Info Noise-contrastive estimation,Info NCE)损失函数计算训练图像对应的对比损失。Info NCE损失函数采用将正样本与负样本对比的方式,更好地学习图像特征。即,计算第一训练图像进行分类预测的分类预测结果,与对训练图像进行标注的标签对应向量的差异,即损失值,基于获得的损失值训练分类模型,能够更为有效地优化分类模型的训练。
本公开实施例以下将对基于本公开实施例提供的具有双流结构的分类模型结构的训练过程进行分别说明。
图5是根据本公开提供的一种基于分类损失以及对比损失,更新分类模型参数的方法流程示意图,参照图5,该方法包括以下步骤。
在步骤S501中,基于分类损失,利用反向传播更新第一网络分支的参数。
在步骤S502中,基于对比损失,利用反向传播更新第一网络分支的参数。
在步骤S503中,基于更新后的所述第一网络分支的参数,通过动量更新第二网络分支的参数。
在本公开实施例中,分类模型包括第一网络分支与第二网络分支,第一训练图像通过第一网络分支,得到第一训练图像的分类预测结果,基于分类预测结果计算分类损失。分类模型通过第二网络分支的不同分支提取第一训练图像的第一图像特征,以及第二训练图像的第二图像特征,确定训练图像对应的对比损失。基于分类损失,利用反向传播更新第一网络分支的参数,并基于对比损失,利用反向传播更新第一网络分支的参数。即得到的对比损失利用反向传播算法,更新第一网络分支的参数。反向传播(Back propagationalgorithm,BP)算法,是适合于多层神经元网络的学习算法。BP算法建立在梯度下降法的基础,通过最小化网络的损失,并通过后反馈调整更新第一网络分支的参数。
在本公开实施例中,通过第一网络分支的参数,动量更新第二网络分支的参数,从而得到训练完成的分类模型。例如,第二网络分支的网络权重通过第一网络分支网络参数进行指数移动平均(Exponential Moving Average,EMA)更新。利用分类损失更新第一网络分支的参数,可以实现对图像的标注信息的利用,利用对比损失更新第一网络分支的参数,可以拉近相似样本的特征、推远不相似样本特征,以充分学习图像特征,从而提升训练得到的分类模型的分类性能。
图6是根据本公开的利用分类模型的第一网络分支对第一训练图像进行分类预测的分类预测结果,确定第一训练图像对应的分类损失的方法流程示意图,参照图6,该方法包括以下步骤。
在步骤S601中,将第一训练图像输入至分类模型的第一网络分支的分类器,得到对第一训练图像进行分类预测的分类预测结果。
在步骤S602中,基于分类预测结果与所述训练图像的标注信息,确定所述第一训练图像对应的交叉熵分类损失。
在本公开实施例中,分类模型包括第一网络分支,第一训练图像通过第一网络分支的骨干网络模块和分类器,对第一训练图像进行分类预测,得到第一训练图像的分类预测结果。将分类预测结果与对训练图像的标注信息计算分类损失,其中,分类损失可以是交叉熵损失。交叉熵描述了两个概率分布之间的距离,交叉熵越小说明两者之间越接近,交叉熵损失用于确定每次训练的准确率。
在本公开实施例中,分类损失函数,即交叉熵函数可以通过如下式确定:
其中,pq表示模型输出softmax变换后与标注信息对应的概率。
图7是根据本公开的利用所述分类模型的第二网络分支分别提取第一训练图像的第一图像特征,以及第二训练图像的第二图像特征的方法流程示意图,参照图7,该方法包括以下步骤。
在步骤S701中,利用第二网络分支中第一分支的投影器与预测器,提取第一训练图像的第一图像特征。
在步骤S702中,利用第二网络分支中第二分支的投影器,提取第二训练图像的第二图像特征。
在本公开实施例中,第二网络分支是无监督对比学习模型,第二网络分支包括第一分支与第二分支两个分支,通过第二网络分支的第一分支与第二分支分别提取第一训练图像的第一图像特征,以及第二训练图像的第二图像特征。第一训练图像、第二训练图像为对相同的训练图像进行不同的数据增强分别得到的,第二网络分支的第一分支包括骨干网络、投影器以及预测器,第二分支包括骨干网络以及投影器。即,在第二网络分支中,第一训练图像在第一分支经过骨干网络,接投影器后再接预测器,输出第一图像特征;在第二网络分支的第二分支中,第二训练图像经过骨干网络接投影器,输出第二图像特征,计算第一图像特征与第二图像特征的对比损失。基于对比损失,利用反向传播更新第一分支的参数,并基于第一分支的参数,通过动量更新第二分支的参数。
在本公开示例性的实施方式中,对比损失函数可以通过如下式确定:
其中,对于一个训练批次的K个样本,样本q对应的正样本为k+,τ为温度系数,用于调节输出的概率分布,根据训练批次大小调节,批次越大τ值越小,q·k+表示输入的训练图像与其正样本的相似度。
分类损失函数与对比损失函数,可以联合表示为:
其中,α为权重,用于加权对比损失,可以根据实际训练需要设定。
本实施例中的训练分类模型的方法,通过第二网络分支的结构,利用相同的图像经过不同分支输出特征的对比损失,对网络参数进行更新,实现直接比较图像特征之间的关系,提升了分类模型的学习性能,进而提升模型精度。
本公开实施例以下将结合实际应用对上述涉及的分类模型训练过程所能达到的效果进行说明。
本公开实施例以下以上述进行分类模型训练的训练图像为医学影像为例进行说明。例如,本公开实施例中以糖尿病视网膜病变的医学影像数据为例进行说明。当然,在实际应用中并不局限于医学影像数据,也并不局限于糖尿病视网膜病变的医学影像数据,还可以是其他需要进行分类/分级的数据。
图8是在糖尿病视网膜病变分级数据集上,有监督分类器的骨干网络的输出特征可视化的结果示意图。参阅图8所示,糖尿病视网膜病变分级数据集中包括0、1、2、3、4和5等五种等级数据。其中。不同颜色亮度代表不同等级,可以看出相邻等级之间混杂在一起,难以区分,这就是数据相邻级别差异不大,标注质量不高导致的,相邻等级之间相似度高,病变本身是一个连续的过程,连续等级边界难以区分。
本公开实施例中,基于具有双流结构的分类模型结构能够在一定程度上缓解这种问题。图2中损失函数包括两部分,一个是交叉熵损失函数,使得分类模型结构具备分级功能,另一个是对比损失函数,不需要标注信息,能够直接对图像特征进行对比,学习相似特征区分不相似特征,以下结合实验数据进行说明。
其中,图9示出了糖尿病视网膜病变分级数据集中包括的数据集详细信息。参阅图9所示,数据集由不同类型的相机拍摄,分别为相机类型1、2、3、4和5。并且主要包括5种级别类型。5种级别类型分别为0、1、2、3、和4。其中,第0级表示阴性,1到4级表示病变的加重程度。
一示例中,相机类型1的相机所拍摄图像的数据子集上,具有交叉熵损失的有监督分类器对比本专利网络的结果,骨干网络均为Resnet50,数据增强和超参数设置均相同,学习率均为0.001,训练次数为500,采用随机梯度下降优化器,迭代训练40000次。其中,针对有监督分类器分类模型结构,在测试集上最终计算得到的交叉熵损失对应的Kappa值为78.24。应用本公开中双流结构的分类模型结构Kappa值为80.01,高于有监督分类器分类模型结构1.77%。其中,Kappa值越大说明分类性能越好,即说明应用本公开中双流结构的分类模型结构具有更高的分类性能。
进一步的试验结果如图10和图11所示。图10为有监督分类器分类模型结构和本公开中双流结构的分类模型结构训练集损失值(loss)随训练过程的移动平均线曲线对比示意图。图11为有监督分类器分类模型结构和本公开中双流结构的分类模型结构验证集中Kappa值的移动平均曲线。结合图10,可以得到在分类模型结构训练开始时,本公开中双流结构的分类模型结构具有较高的损失值,随着训练的进行,本公开中双流结构的分类模型结构loss值快速下降,最终低于有监督分类器分类模型结构。一般而言同等条件下loss越小说明训练集收敛效果好。结合图11,本公开中双流结构的分类模型结构的损失在训练后期有更小的损失值,但是验证集指标是更好的,说明本公开中双流结构的分类模型结构相对于有监督分类器分类模型结构具有更好的泛化性能。
基于上述示例,本公开中双流结构的分类模型结构交叉熵损失和对比损失的权重是1:1,在这种情况下,最终损失收敛小于有监督分类器分类模型结构。本公开中双流结构的分类模型结构同时学到了两种损失下的特征,这两种损失并不冲突。本公开中双流结构的分类模型结构同时兼具标注信息和特征本身的对比信息。最终应用本公开提供的分类模型训练方法的效果优于有监督分类器分类模型结构,也说明对比损失通过学习不同样本之间的相似性,有利于矫正原始标注信息。
在另一示例中,合并5种相机类型的相机所拍摄图像的训练集和验证集,训练一个模型。在5种相机的测试集上分别进行测试。本公开中双流结构的分类模型结构与有监督分类器分类模型结构的配置参数一致,详细参数配置与上述示例中相同,试验结果如图12所示。图12展现了5种相机测试集上的Kappa值,该值越大表示分类效果越好。数据集合并后数据集变大,有监督分类器分类模型结构和本公开中双流结构的分类模型结构的网络分类性能均有明显提升,这符合数据集越大泛化性能越好的常识。结合图10,可以确定应用本公开双流结构的分类模型结构,在每种类型的相机上都取得了更好的结果。Kappa值平均上升1.2%,这进一步证实了应用本公开双流结构的分类模型结构的有效性。
基于本公开实施例提供的分类模型,本公开实施例还提供一种数据分类方法。
图13是根据本公开一示例性实施例示出的一种数据分类方法流程图。参阅图13所示,包括如下步骤:
在步骤S1301中,确定待分类数据。
在步骤S1302中,将待分类数据输入至分类模型,得到分类模型的输出结果。
本公开实施例中,该分类模型为具有双流结构的分类模型结构。
在步骤S1303中,基于分类模型的输出结果,确定待分类数据的分类结果。
本公开实施例提供的数据分类方法,应用双流结构的分类模型结构进行数据分类,由于对图像的标注信息以及对图像特征进行学习,故能够提升通过分类模型进行图像分类的效果。
其中,本公开实施例中涉及的具有双流结构的分类模型结构的训练方法可以采用本公开上述各实施例涉及的训练方法进行预先训练得到,在此不再详述,具体可参阅上述实施例的相关描述。
例如,本公开实施例中涉及的具有双流结构的分类模型包括第一网络分支和第二网络分支,并基于分类损失以及对比损失进行参数更新后预先训练得到。其中,第一网络分支用于对第一训练图像进行分类预测,得到第一训练图像的分类预测结果;第二网络分支用于提取第一训练图像的第一图像特征以及第二训练图像的第二图像特征;其中,分类损失通过第一训练图像的分类预测结果确定,对比损失基于第一图像特征与第二图像特征确定。
在本公开实施例中,基于分类模型的第一网络分支确定训练图像的分类损失,并基于分类模型的第二网络分支确定训练图像对应的对比损失。并基于分类损失与对比损失更新初始分类模型的参数,得到训练完成的分类模型。
其中,一种实施方式中,第一训练图像的分类预测结果基于第一网络分支的分类器得到;分类损失为第一训练图像的交叉熵分类损失,第一训练图像的交叉熵分类损失基于第一训练图像的分类预测结果以及训练图像的标注信息确定;第一训练图像通过对训练图像进行数据增强得到。
在本公开实施例中,分类模型包括第一网络分支,第一训练图像通过第一网络分支的骨干网络模块和分类器,对第一训练图像进行分类预测,得到第一训练图像的分类预测结果。将分类预测结果与对训练图像的标注信息计算分类损失,其中,分类损失可以是交叉熵损失。交叉熵描述了两个概率分布之间的距离,交叉熵越小说明两者之间越接近,交叉熵损失用于确定每次训练的准确率。
其中,一种实施方式中,第一训练图像的第一图像特征利用第二网络分支中第一分支的投影器与预测器提取;第二训练图像的第二图像特征利用第二网络分支中第二分支的投影器提取。
在本公开实施例中,通过第二网络分支的不同分支提取的第一训练图像的第一图像特征,以及第二训练图像的第二图像特征,确定训练图像对应的对比损失。第二网络分支的对比损失使得第二网络分支中骨干网络提取的特征中相似特征聚集,不相似特征远离,更好地学习相似特征,远离不相似特征。
其中,本公开一示例性实施例中,上述待分类数据可以是医学影像数据。对于医学影像数据,通过应用本公开实施例提供的具有双流结构的分类模型结构进行医学影像数据的分类,例如进行病灶严重程度分级,能有效提升分类器性能,缓解医学图像标注质量差的问题。一典型应用场景中,可以对糖尿病视网膜病变分级。
结合本公开上述示例可知,应用本公开实施例提供的具有双流结构的分类模型结构进行糖尿病视网膜病变分级,分类性能指标Kappa值提升明显。
基于相同的构思,本公开实施例还提供一种训练分类模型的装置。
可以理解的是,本公开实施例提供的装置为了实现上述功能,其包含了执行各个功能相应的硬件结构和/或软件模块。结合本公开实施例中所公开的各示例的单元及算法步骤,本公开实施例能够以硬件或硬件和计算机软件的结合形式来实现。某个功能究竟以硬件还是计算机软件驱动硬件的方式来执行,取决于技术方案的特定应用和设计约束条件。本领域技术人员可以对每个特定的应用来使用不同的方法来实现所描述的功能,但是这种实现不应认为超出本公开实施例的技术方案的范围。
图14是根据本公开一示例性实施例示出的一种训练分类模型的装置框图。如图14所示,本公开实施例的训练分类模型的装置1400,包括:确定模块1401、提取模块1402以及更新模块1403。
确定模块1401,用于利用分类模型的第一网络分支对第一训练图像进行分类预测,并基于得到的分类预测结果,确定第一训练图像对应的分类损失,以及,基于第一图像特征与第二图像特征,确定对比损失;提取模块1402,利用分类模型的第二网络分支提取第一训练图像的第一图像特征,以及第二训练图像的第二图像特征;更新模块1403,用于基于分类损失以及对比损失,更新分类模型的参数,得到训练完成的分类模型。
其中,训练图像对应的对比损失通过多分类版本的噪声对比估计损失函数确定。
其中,更新模块1403采用如下方式基于分类损失以及对比损失,更新分类模型的参数:
基于分类损失,利用反向传播更新第一网络分支的参数,并基于对比损失,利用反向传播更新第一网络分支的参数;
基于更新后的第一网络分支的参数,通过动量更新第二网络分支的参数。
其中,确定模块1401用于采用如下方式利用分类模型的第一网络分支对第一训练图像进行分类预测,并基于得到的分类预测结果,确定第一训练图像对应的分类损失:
将第一训练图像输入至分类模型的第一网络分支的分类器,得到对第一训练图像进行分类预测的分类预测结果;
基于分类预测结果与训练图像的标注信息,确定第一训练图像对应的交叉熵分类损失。
其中,提取模块1402用于采用如下方式利用分类模型的第二网络分支提取第一训练图像的第一图像特征,以及第二训练图像的第二图像特征:利用第二网络分支中第一分支的投影器与预测器,提取第一训练图像的第一图像特征;利用第二网络分支中第二分支的投影器,提取第二训练图像的第二图像特征。
图15是根据本公开一示例性实施例示出的一种数据分类装置框图。如图15所示,本公开实施例的数据分类装置1500,包括:确定模块1501和分类模块1502。
确定模块1501,用于确定待分类数据;分类模块1502,用于将待分类数据输入至分类模型,得到分类模型的输出结果,并基于分类模型的输出结果,确定待分类数据的分类结果。
其中,分类模型具有双流结构的分类模型结构。
本公开实施例中涉及的具有双流结构的分类模型包括第一网络分支和第二网络分支,并基于分类损失以及对比损失进行参数更新后预先训练得到。其中,第一网络分支用于对第一训练图像进行分类预测,得到第一训练图像的分类预测结果;第二网络分支用于提取第一训练图像的第一图像特征以及第二训练图像的第二图像特征;其中,分类损失通过第一训练图像的分类预测结果确定,对比损失基于第一图像特征与第二图像特征确定。
其中,一种实施方式中,第一训练图像的分类预测结果基于第一网络分支的分类器得到;分类损失为第一训练图像的交叉熵分类损失,第一训练图像的交叉熵分类损失基于第一训练图像的分类预测结果以及训练图像的标注信息确定;第一训练图像通过对训练图像进行数据增强得到。
其中,一种实施方式中,第一训练图像的第一图像特征利用第二网络分支中第一分支的投影器与预测器提取;第二训练图像的第二图像特征利用第二网络分支中第二分支的投影器提取。
本公开实施例中涉及的具有双流结构的分类模型结构的训练方法可以采用本公开上述各实施例涉及的训练方法进行预先训练得到,在此不再详述,具体可参阅上述实施例的相关描述。
其中,待分类数据为医学影像数据。
关于本公开上述涉及的装置,其中各个模块执行操作的具体方式已经在有关该方法的实施例中进行了详细描述,此处将不做详细阐述说明。
综上,根据本公开实施例提供的训练分类模型的装置,获取训练图像利用分类模型的第一网络分支对第一训练图像进行分类预测,得到分类预测结果,通过分类预测结果,确定第一训练图像对应的分类损失。利用分类模型的第二网络分支分别提取第一训练图像的第一图像特征、第二训练图像的第二图像特征,第一训练图像为对训练图像进行第一数据增强得到的,第二训练图像为对训练图像进行第二数据增强得到的,并基于第一图像特征与第二图像特征,确定训练图像对应的对比损失,可以直接比较图像特征之间的关系,弱化有监督分类模型中标签的作用,最后基于分类损失与对比损失更新分类模型的参数,得到训练完成的分类模型。本实施例中的训练分类模型的方法,通过在分类模型训练中,基于分类损失与对比损失更新分类模型的参数,实现利用对图像的标注信息以及对图像特征的学习,有效提升训练得到的分类模型的分类性能,提升通过分类模型进行图像分类的效果。
本公开的技术方案中,所涉及的用户个人信息的获取,存储和应用等,均符合相关法律法规的规定,且不违背公序良俗。
根据本公开的实施例,本公开还提供了一种电子设备、一种可读存储介质和一种计算机程序产品。
图16示出了可以用来实施本公开的实施例的示例电子设备1600的示意性框图。电子设备旨在表示各种形式的数字计算机,诸如,膝上型计算机、台式计算机、工作台、个人数字助理、服务器、刀片式服务器、大型计算机、和其它适合的计算机。电子设备还可以表示各种形式的移动装置,诸如,个人数字处理、蜂窝电话、智能电话、可穿戴设备和其它类似的计算装置。本文所示的部件、它们的连接和关系、以及它们的功能仅仅作为示例,并且不意在限制本文中描述的和/或者要求的本公开的实现。
如图16所示,设备1600包括计算单元1601,其可以根据存储在只读存储器(ROM)1602中的计算机程序或者从存储单元1608加载到随机访问存储器(RAM)1603中的计算机程序,来执行各种适当的动作和处理。在RAM 1603中,还可存储设备1600操作所需的各种程序和数据。计算单元1601、ROM 1602以及RAM 1603通过总线1604彼此相连。输入/输出(I/O)接口1605也连接至总线1604。
设备1600中的多个部件连接至I/O接口1605,包括:输入单元1606,例如键盘、鼠标等;输出单元1607,例如各种类型的显示器、扬声器等;存储单元1608,例如磁盘、光盘等;以及通信单元1609,例如网卡、调制解调器、无线通信收发机等。通信单元1609允许设备1600通过诸如因特网的计算机网络和/或各种电信网络与其他设备交换信息/数据。
计算单元1601可以是各种具有处理和计算能力的通用和/或专用处理组件。计算单元1601的一些示例包括但不限于中央处理单元(CPU)、图形处理单元(GPU)、各种专用的人工智能(AI)计算芯片、各种运行机器学习模型算法的计算单元、数字信号处理器(DSP)、以及任何适当的处理器、控制器、微控制器等。计算单元1601执行上文所描述的各个方法和处理,例如的训练分类模型的方法或数据分类方法。例如,在一些实施例中,训练分类模型的方法或数据分类方法可被实现为计算机软件程序,其被有形地包含于机器可读介质,例如存储单元1608。在一些实施例中,计算机程序的部分或者全部可以经由ROM 1602和/或通信单元1609而被载入和/或安装到设备1600上。当计算机程序加载到RAM 1603并由计算单元1601执行时,可以执行上文描述的训练分类模型的方法或数据分类方法的一个或多个步骤。备选地,在其他实施例中,计算单元1601可以通过其他任何适当的方式(例如,借助于固件)而被配置为执行训练分类模型的方法或数据分类方法。
本文中以上描述的系统和技术的各种实施方式可以在数字电子电路系统、集成电路系统、场可编程门阵列(FPGA)、专用集成电路(ASIC)、专用标准产品(ASSP)、芯片上系统的系统(SOC)、负载可编程逻辑设备(CPLD)、计算机硬件、固件、软件、和/或它们的组合中实现。这些各种实施方式可以包括:实施在一个或者多个计算机程序中,该一个或者多个计算机程序可在包括至少一个可编程处理器的可编程系统上执行和/或解释,该可编程处理器可以是专用或者通用可编程处理器,可以从存储系统、至少一个输入装置、和至少一个输出装置接收数据和指令,并且将数据和指令传输至该存储系统、该至少一个输入装置、和该至少一个输出装置。
用于实施本公开的方法的程序代码可以采用一个或多个编程语言的任何组合来编写。这些程序代码可以提供给通用计算机、专用计算机或其他可编程数据处理装置的处理器或控制器,使得程序代码当由处理器或控制器执行时使流程图和/或框图中所规定的功能/操作被实施。程序代码可以完全在机器上执行、部分地在机器上执行,作为独立软件包部分地在机器上执行且部分地在远程机器上执行或完全在远程机器或服务器上执行。
在本公开的上下文中,机器可读介质可以是有形的介质,其可以包含或存储以供指令执行系统、装置或设备使用或与指令执行系统、装置或设备结合地使用的程序。机器可读介质可以是机器可读信号介质或机器可读储存介质。机器可读介质可以包括但不限于电子的、磁性的、光学的、电磁的、红外的、或半导体系统、装置或设备,或者上述内容的任何合适组合。机器可读存储介质的更具体示例会包括基于一个或多个线的电气连接、便携式计算机盘、硬盘、随机存取存储器(RAM)、只读存储器(ROM)、可擦除可编程只读存储器(EPROM或快闪存储器)、光纤、便捷式紧凑盘只读存储器(CD-ROM)、光学储存设备、磁储存设备、或上述内容的任何合适组合。
为了提供与用户的交互,可以在计算机上实施此处描述的系统和技术,该计算机具有:用于向用户显示信息的显示装置(例如,CRT(阴极射线管)或者LCD(液晶显示器)监视器);以及键盘和指向装置(例如,鼠标或者轨迹球),用户可以通过该键盘和该指向装置来将输入提供给计算机。其它种类的装置还可以用于提供与用户的交互;例如,提供给用户的反馈可以是任何形式的传感反馈(例如,视觉反馈、听觉反馈、或者触觉反馈);并且可以用任何形式(包括声输入、语音输入或者、触觉输入)来接收来自用户的输入。
可以将此处描述的系统和技术实施在包括后台部件的计算系统(例如,作为数据服务器)、或者包括中间件部件的计算系统(例如,应用服务器)、或者包括前端部件的计算系统(例如,具有图形用户界面或者网络浏览器的用户计算机,用户可以通过该图形用户界面或者该网络浏览器来与此处描述的系统和技术的实施方式交互)、或者包括这种后台部件、中间件部件、或者前端部件的任何组合的计算系统中。可以通过任何形式或者介质的数字数据通信(例如,通信网络)来将系统的部件相互连接。通信网络的示例包括:局域网(LAN)、广域网(WAN)和互联网。
计算机系统可以包括客户端和服务器。客户端和服务器一般远离彼此并且通常通过通信网络进行交互。通过在相应的计算机上运行并且彼此具有客户端-服务器关系的计算机程序来产生客户端和服务器的关系。服务器可以是云服务器,也可以为分布式系统的服务器,或者是结合了区块链的服务器。
应该理解,可以使用上面所示的各种形式的流程,重新排序、增加或删除步骤。例如,本发公开中记载的各步骤可以并行地执行也可以顺序地执行也可以不同的次序执行,只要能够实现本公开公开的技术方案所期望的结果,本文在此不进行限制。
上述具体实施方式,并不构成对本公开保护范围的限制。本领域技术人员应该明白的是,根据设计要求和其他因素,可以进行各种修改、组合、子组合和替代。任何在本公开的精神和原则之内所作的修改、等同替换和改进等,均应包含在本公开保护范围之内。
Claims (22)
1.一种训练分类模型的方法,包括:
利用分类模型的第一网络分支对第一训练图像进行分类预测,并基于得到的分类预测结果,确定所述第一训练图像对应的分类损失;
利用所述分类模型的第二网络分支提取所述第一训练图像的第一图像特征,并提取第二训练图像的第二图像特征;
基于所述第一图像特征与所述第二图像特征,确定对比损失;
基于所述分类损失以及所述对比损失,更新所述分类模型的参数,得到训练完成的分类模型。
2.根据权利要求1所述的方法,其中,所述对比损失通过多分类版本的噪声对比估计损失函数确定。
3.根据权利要求1或2所述的方法,其中,所述基于所述分类损失以及所述对比损失,更新所述分类模型的参数,包括:
基于所述分类损失,利用反向传播更新所述第一网络分支的参数,并基于所述对比损失,利用反向传播更新所述第一网络分支的参数;
基于更新后的所述第一网络分支的参数,通过动量更新所述第二网络分支的参数。
4.根据权利要求1所述的方法,其中,所述利用分类模型的第一网络分支对所述第一训练图像进行分类预测,并基于得到的分类预测结果,确定所述第一训练图像对应的分类损失,包括:
将所述第一训练图像输入至分类模型的第一网络分支的分类器,得到对所述第一训练图像进行分类预测的分类预测结果;
基于所述分类预测结果与所述训练图像的标注信息,确定所述第一训练图像对应的交叉熵分类损失。
5.根据权利要求1所述的方法,其中,所述利用所述分类模型的第二网络分支提取所述第一训练图像的第一图像特征,以及所述第二训练图像的第二图像特征,包括:
利用所述第二网络分支中第一分支的投影器与预测器,提取所述第一训练图像的第一图像特征;
利用所述第二网络分支中第二分支的投影器,提取所述第二训练图像的第二图像特征。
6.根据权利要求1-5中任意一项所述的方法,其中,所述第一训练图像和所述第二训练图像通过对训练图像进行不同次的数据增强得到。
7.一种数据分类方法,包括:
确定待分类数据;
将所述待分类数据输入至分类模型,得到所述分类模型的输出结果;
基于所述分类模型的输出结果,确定所述待分类数据的分类结果;
其中,所述分类模型包括第一网络分支和第二网络分支,并基于分类损失以及对比损失进行参数更新后预先训练得到;
所述第一网络分支用于对第一训练图像进行分类预测,得到所述第一训练图像的分类预测结果;
所述第二网络分支用于提取所述第一训练图像的第一图像特征以及第二训练图像的第二图像特征;
其中,所述分类损失通过所述第一训练图像的分类预测结果确定,所述对比损失基于所述第一图像特征与所述第二图像特征确定。
8.根据权利要求7所述的方法,其中,
所述第一训练图像的分类预测结果基于所述第一网络分支的分类器得到;
所述分类损失为所述第一训练图像的交叉熵分类损失,所述第一训练图像的交叉熵分类损失基于所述第一训练图像的分类预测结果以及训练图像的标注信息确定;
所述第一训练图像通过对所述训练图像进行数据增强得到。
9.根据权利要求7所述的方法,其中,
所述第一训练图像的第一图像特征利用所述第二网络分支中第一分支的投影器与预测器提取;
所述第二训练图像的第二图像特征利用所述第二网络分支中第二分支的投影器提取。
10.根据权利要求7所述的方法,其中,所述待分类数据为医学影像数据。
11.一种训练分类模型的装置,包括:
确定模块,用于利用分类模型的第一网络分支对第一训练图像进行分类预测,并基于得到的分类预测结果,确定所述第一训练图像对应的分类损失,以及,基于第一图像特征与第二图像特征,确定对比损失;
提取模块,利用所述分类模型的第二网络分支提取所述第一训练图像的第一图像特征,并提取第二训练图像的第二图像特征;
更新模块,用于基于所述分类损失以及所述对比损失,更新所述分类模型的参数,得到训练完成的分类模型。
12.根据权利要求11所述的装置,其中,所述对比损失通过多分类版本的噪声对比估计损失函数确定。
13.根据权利要求11或12所述的装置,其中,所述更新模块采用如下方式基于所述分类损失以及所述对比损失,更新所述分类模型的参数:
基于所述分类损失,利用反向传播更新所述第一网络分支的参数,并基于所述对比损失,利用反向传播更新所述第一网络分支的参数;
基于更新后的所述第一网络分支的参数,通过动量更新所述第二网络分支的参数。
14.根据权利要求11所述的装置,其中,所述确定模块用于采用如下方式利用分类模型的第一网络分支对所述第一训练图像进行分类预测,并基于得到的分类预测结果,确定所述第一训练图像对应的分类损失:
将所述第一训练图像输入至分类模型的第一网络分支的分类器,得到对所述第一训练图像进行分类预测的分类预测结果;
基于所述分类预测结果与所述训练图像的标注信息,确定所述第一训练图像对应的交叉熵分类损失。
15.根据权利要求11所述的装置,其中,所述提取模块用于采用如下方式利用所述分类模型的第二网络分支提取所述第一训练图像的第一图像特征,以及所述第二训练图像的第二图像特征:
利用所述第二网络分支中第一分支的投影器与预测器,提取所述第一训练图像的第一图像特征;
利用所述第二网络分支中第二分支的投影器,提取所述第二训练图像的第二图像特征。
16.一种数据分类装置,包括:
确定模块,用于确定待分类数据;
分类模块,用于将所述待分类数据输入至分类模型,得到所述分类模型的输出结果,并基于所述分类模型的输出结果,确定所述待分类数据的分类结果;
其中,所述分类模型包括第一网络分支和第二网络分支,并基于分类损失以及对比损失进行参数更新后预先训练得到;
所述第一网络分支用于对第一训练图像进行分类预测,得到所述第一训练图像的分类预测结果;
所述第二网络分支用于提取所述第一训练图像的第一图像特征以及第二训练图像的第二图像特征;
其中,所述分类损失通过所述第一训练图像的分类预测结果确定,所述对比损失基于所述第一图像特征与所述第二图像特征确定。
17.根据权利要求16所述的装置,其中,
所述第一训练图像的分类预测结果基于所述第一网络分支的分类器得到;
所述分类损失为所述第一训练图像的交叉熵分类损失,所述第一训练图像的交叉熵分类损失基于所述第一训练图像的分类预测结果以及训练图像的标注信息确定;
所述第一训练图像通过对所述训练图像进行数据增强得到。
18.根据权利要求16所述的装置,其中,
所述第一训练图像的第一图像特征利用所述第二网络分支中第一分支的投影器与预测器提取;
所述第二训练图像的第二图像特征利用所述第二网络分支中第二分支的投影器提取。
19.根据权利要求16所述的装置,其中,所述待分类数据为医学影像数据。
20.一种电子设备,包括:
至少一个处理器;以及
与所述至少一个处理器通信连接的存储器;其中,
所述存储器存储有可被所述至少一个处理器执行的指令,所述指令被所述至少一个处理器执行,以使所述至少一个处理器能够执行权利要求1-6中任一项所述的方法,或者执行权利要求7-10中任一项所述的方法。
21.一种存储有计算机指令的非瞬时计算机可读存储介质,其中,所述计算机指令用于使所述计算机执行根据权利要求1-6中任一项所述的方法,或者执行权利要求7-10中任一项所述的方法。
22.一种计算机程序产品,包括计算机程序,所述计算机程序在被处理器执行时实现根据权利要求1-6中任一项所述的方法,或者执行权利要求7-10中任一项所述的方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210336174.XA CN114724007A (zh) | 2022-03-31 | 2022-03-31 | 训练分类模型、数据分类方法、装置、设备、介质及产品 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202210336174.XA CN114724007A (zh) | 2022-03-31 | 2022-03-31 | 训练分类模型、数据分类方法、装置、设备、介质及产品 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN114724007A true CN114724007A (zh) | 2022-07-08 |
Family
ID=82242042
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202210336174.XA Pending CN114724007A (zh) | 2022-03-31 | 2022-03-31 | 训练分类模型、数据分类方法、装置、设备、介质及产品 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN114724007A (zh) |
Cited By (6)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN115082740A (zh) * | 2022-07-18 | 2022-09-20 | 北京百度网讯科技有限公司 | 目标检测模型训练方法、目标检测方法、装置、电子设备 |
CN115240036A (zh) * | 2022-09-22 | 2022-10-25 | 武汉珈鹰智能科技有限公司 | 一种裂缝图像识别网络的训练方法、应用方法及存储介质 |
CN115294396A (zh) * | 2022-08-12 | 2022-11-04 | 北京百度网讯科技有限公司 | 骨干网络的训练方法以及图像分类方法 |
CN115457329A (zh) * | 2022-09-23 | 2022-12-09 | 北京百度网讯科技有限公司 | 图像分类模型的训练方法、图像分类方法和装置 |
CN115496954A (zh) * | 2022-11-03 | 2022-12-20 | 中国医学科学院阜外医院 | 眼底图像分类模型构建方法、设备及介质 |
WO2024099032A1 (zh) * | 2022-11-09 | 2024-05-16 | 腾讯科技(深圳)有限公司 | 图像分类方法、装置和计算机设备 |
Citations (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN112241764A (zh) * | 2020-10-23 | 2021-01-19 | 北京百度网讯科技有限公司 | 图像识别方法、装置、电子设备及存储介质 |
US20210319266A1 (en) * | 2020-04-13 | 2021-10-14 | Google Llc | Systems and methods for contrastive learning of visual representations |
CN113516181A (zh) * | 2021-07-01 | 2021-10-19 | 北京航空航天大学 | 一种数字病理图像的表征学习方法 |
CN113627483A (zh) * | 2021-07-09 | 2021-11-09 | 武汉大学 | 基于自监督纹理对比学习的宫颈oct图像分类方法及设备 |
CN114020950A (zh) * | 2021-11-03 | 2022-02-08 | 北京百度网讯科技有限公司 | 图像检索模型的训练方法、装置、设备以及存储介质 |
-
2022
- 2022-03-31 CN CN202210336174.XA patent/CN114724007A/zh active Pending
Patent Citations (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US20210319266A1 (en) * | 2020-04-13 | 2021-10-14 | Google Llc | Systems and methods for contrastive learning of visual representations |
CN112241764A (zh) * | 2020-10-23 | 2021-01-19 | 北京百度网讯科技有限公司 | 图像识别方法、装置、电子设备及存储介质 |
CN113516181A (zh) * | 2021-07-01 | 2021-10-19 | 北京航空航天大学 | 一种数字病理图像的表征学习方法 |
CN113627483A (zh) * | 2021-07-09 | 2021-11-09 | 武汉大学 | 基于自监督纹理对比学习的宫颈oct图像分类方法及设备 |
CN114020950A (zh) * | 2021-11-03 | 2022-02-08 | 北京百度网讯科技有限公司 | 图像检索模型的训练方法、装置、设备以及存储介质 |
Non-Patent Citations (2)
Title |
---|
ASHRAFUL ISLAM 等: "A Broad Study on the Transferability of Visual Representations with Contrastive Learning", 《ARXIV:2103.13517V3 [CS.CV]》, 16 August 2021 (2021-08-16), pages 1 - 18 * |
JEAN-BASTIEN GRILL 等: "Bootstrap Your Own Latent A New Approach to Self-Supervised Learning", 《ARXIV:2006.07733V3 [CS.LG]》, 10 September 2020 (2020-09-10), pages 1 - 35 * |
Cited By (9)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN115082740A (zh) * | 2022-07-18 | 2022-09-20 | 北京百度网讯科技有限公司 | 目标检测模型训练方法、目标检测方法、装置、电子设备 |
CN115082740B (zh) * | 2022-07-18 | 2023-09-01 | 北京百度网讯科技有限公司 | 目标检测模型训练方法、目标检测方法、装置、电子设备 |
CN115294396A (zh) * | 2022-08-12 | 2022-11-04 | 北京百度网讯科技有限公司 | 骨干网络的训练方法以及图像分类方法 |
CN115294396B (zh) * | 2022-08-12 | 2024-04-23 | 北京百度网讯科技有限公司 | 骨干网络的训练方法以及图像分类方法 |
CN115240036A (zh) * | 2022-09-22 | 2022-10-25 | 武汉珈鹰智能科技有限公司 | 一种裂缝图像识别网络的训练方法、应用方法及存储介质 |
CN115457329A (zh) * | 2022-09-23 | 2022-12-09 | 北京百度网讯科技有限公司 | 图像分类模型的训练方法、图像分类方法和装置 |
CN115457329B (zh) * | 2022-09-23 | 2023-11-10 | 北京百度网讯科技有限公司 | 图像分类模型的训练方法、图像分类方法和装置 |
CN115496954A (zh) * | 2022-11-03 | 2022-12-20 | 中国医学科学院阜外医院 | 眼底图像分类模型构建方法、设备及介质 |
WO2024099032A1 (zh) * | 2022-11-09 | 2024-05-16 | 腾讯科技(深圳)有限公司 | 图像分类方法、装置和计算机设备 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN114724007A (zh) | 训练分类模型、数据分类方法、装置、设备、介质及产品 | |
WO2020253127A1 (zh) | 脸部特征提取模型训练方法、脸部特征提取方法、装置、设备及存储介质 | |
CN113326764A (zh) | 训练图像识别模型和图像识别的方法和装置 | |
CN110490239B (zh) | 图像质控网络的训练方法、质量分类方法、装置及设备 | |
EP3664019A1 (en) | Information processing device, information processing program, and information processing method | |
CN111160407B (zh) | 一种深度学习目标检测方法及系统 | |
CN112784778B (zh) | 生成模型并识别年龄和性别的方法、装置、设备和介质 | |
CN113222149A (zh) | 模型训练方法、装置、设备和存储介质 | |
JP2023547010A (ja) | 知識の蒸留に基づくモデルトレーニング方法、装置、電子機器 | |
CN117611932B (zh) | 基于双重伪标签细化和样本重加权的图像分类方法及系统 | |
CN112508126A (zh) | 深度学习模型训练方法、装置、电子设备及可读存储介质 | |
CN114549985A (zh) | 一种基于自监督对比学习的目标检测方法及系统 | |
CN112650885A (zh) | 视频分类方法、装置、设备和介质 | |
CN115457329B (zh) | 图像分类模型的训练方法、图像分类方法和装置 | |
EP4343616A1 (en) | Image classification method, model training method, device, storage medium, and computer program | |
CN115797637A (zh) | 基于模型间和模型内不确定性的半监督分割模型 | |
CN114462598A (zh) | 深度学习模型的训练方法、确定数据类别的方法和装置 | |
CN113837965B (zh) | 图像清晰度识别方法、装置、电子设备及存储介质 | |
JP6600288B2 (ja) | 統合装置及びプログラム | |
CN113657248A (zh) | 人脸识别模型的训练方法、装置及计算机程序产品 | |
WO2024060839A1 (zh) | 对象操作方法、装置、计算机设备以及计算机存储介质 | |
CN117456272A (zh) | 一种基于对比学习的自监督异常检测方法 | |
CN115294405B (zh) | 农作物病害分类模型的构建方法、装置、设备及介质 | |
Fan et al. | [Retracted] Accurate Recognition and Simulation of 3D Visual Image of Aerobics Movement | |
CN115100731B (zh) | 一种质量评价模型训练方法、装置、电子设备及存储介质 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |