CN117371511A - 图像分类模型的训练方法、装置、设备及存储介质 - Google Patents
图像分类模型的训练方法、装置、设备及存储介质 Download PDFInfo
- Publication number
- CN117371511A CN117371511A CN202311443497.XA CN202311443497A CN117371511A CN 117371511 A CN117371511 A CN 117371511A CN 202311443497 A CN202311443497 A CN 202311443497A CN 117371511 A CN117371511 A CN 117371511A
- Authority
- CN
- China
- Prior art keywords
- image
- domain image
- model
- target domain
- features
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000012549 training Methods 0.000 title claims abstract description 164
- 238000000034 method Methods 0.000 title claims abstract description 143
- 238000013145 classification model Methods 0.000 title claims abstract description 49
- 238000013508 migration Methods 0.000 claims abstract description 45
- 230000005012 migration Effects 0.000 claims abstract description 45
- 230000008569 process Effects 0.000 claims abstract description 34
- 230000002708 enhancing effect Effects 0.000 claims abstract description 24
- 230000001575 pathological effect Effects 0.000 claims abstract description 13
- 239000013598 vector Substances 0.000 claims description 128
- 230000006978 adaptation Effects 0.000 claims description 34
- 230000009466 transformation Effects 0.000 claims description 24
- 239000011159 matrix material Substances 0.000 claims description 22
- 238000012216 screening Methods 0.000 claims description 17
- 230000003321 amplification Effects 0.000 claims description 15
- 238000003199 nucleic acid amplification method Methods 0.000 claims description 15
- 230000003044 adaptive effect Effects 0.000 claims description 13
- 238000004590 computer program Methods 0.000 claims description 10
- 238000013507 mapping Methods 0.000 claims description 7
- 238000013473 artificial intelligence Methods 0.000 abstract description 12
- 238000012545 processing Methods 0.000 description 64
- 238000005516 engineering process Methods 0.000 description 18
- 206010028980 Neoplasm Diseases 0.000 description 14
- 210000004027 cell Anatomy 0.000 description 14
- 238000010586 diagram Methods 0.000 description 13
- 238000013136 deep learning model Methods 0.000 description 12
- 201000011510 cancer Diseases 0.000 description 10
- 239000000284 extract Substances 0.000 description 10
- 208000002154 non-small cell lung carcinoma Diseases 0.000 description 10
- 238000000605 extraction Methods 0.000 description 9
- 208000006265 Renal cell carcinoma Diseases 0.000 description 8
- 230000007170 pathology Effects 0.000 description 8
- 230000006870 function Effects 0.000 description 6
- 238000010801 machine learning Methods 0.000 description 5
- 238000013528 artificial neural network Methods 0.000 description 4
- 238000013135 deep learning Methods 0.000 description 4
- 230000004069 differentiation Effects 0.000 description 4
- 238000011156 evaluation Methods 0.000 description 4
- 206010006187 Breast cancer Diseases 0.000 description 3
- 208000026310 Breast neoplasm Diseases 0.000 description 3
- 238000013459 approach Methods 0.000 description 3
- 230000009286 beneficial effect Effects 0.000 description 3
- 230000004927 fusion Effects 0.000 description 3
- 238000010606 normalization Methods 0.000 description 3
- 238000013526 transfer learning Methods 0.000 description 3
- 208000010507 Adenocarcinoma of Lung Diseases 0.000 description 2
- 208000008839 Kidney Neoplasms Diseases 0.000 description 2
- 206010058467 Lung neoplasm malignant Diseases 0.000 description 2
- 206010038389 Renal cancer Diseases 0.000 description 2
- 238000004458 analytical method Methods 0.000 description 2
- 230000033228 biological regulation Effects 0.000 description 2
- 210000003855 cell nucleus Anatomy 0.000 description 2
- 238000004891 communication Methods 0.000 description 2
- 238000012790 confirmation Methods 0.000 description 2
- 239000003623 enhancer Substances 0.000 description 2
- 238000002474 experimental method Methods 0.000 description 2
- 210000003734 kidney Anatomy 0.000 description 2
- 201000010982 kidney cancer Diseases 0.000 description 2
- 238000002372 labelling Methods 0.000 description 2
- 201000005249 lung adenocarcinoma Diseases 0.000 description 2
- 201000005202 lung cancer Diseases 0.000 description 2
- 208000020816 lung neoplasm Diseases 0.000 description 2
- 201000005243 lung squamous cell carcinoma Diseases 0.000 description 2
- 210000004940 nucleus Anatomy 0.000 description 2
- 238000006467 substitution reaction Methods 0.000 description 2
- 238000012546 transfer Methods 0.000 description 2
- 238000000844 transformation Methods 0.000 description 2
- 210000004881 tumor cell Anatomy 0.000 description 2
- 208000029729 tumor suppressor gene on chromosome 11 Diseases 0.000 description 2
- 206010055113 Breast cancer metastatic Diseases 0.000 description 1
- 201000009030 Carcinoma Diseases 0.000 description 1
- 208000030808 Clear cell renal carcinoma Diseases 0.000 description 1
- 239000013255 MILs Substances 0.000 description 1
- 241001465754 Metazoa Species 0.000 description 1
- 238000013475 authorization Methods 0.000 description 1
- 230000006399 behavior Effects 0.000 description 1
- 201000010240 chromophobe renal cell carcinoma Diseases 0.000 description 1
- 206010073251 clear cell renal cell carcinoma Diseases 0.000 description 1
- 238000010276 construction Methods 0.000 description 1
- 238000013527 convolutional neural network Methods 0.000 description 1
- 125000004122 cyclic group Chemical group 0.000 description 1
- 230000001419 dependent effect Effects 0.000 description 1
- 238000013461 design Methods 0.000 description 1
- 238000001514 detection method Methods 0.000 description 1
- 238000011161 development Methods 0.000 description 1
- 238000004043 dyeing Methods 0.000 description 1
- 230000000694 effects Effects 0.000 description 1
- 238000002546 full scan Methods 0.000 description 1
- 238000003384 imaging method Methods 0.000 description 1
- 230000006698 induction Effects 0.000 description 1
- 230000003993 interaction Effects 0.000 description 1
- 238000013140 knowledge distillation Methods 0.000 description 1
- 230000003902 lesion Effects 0.000 description 1
- 238000011068 loading method Methods 0.000 description 1
- 210000004072 lung Anatomy 0.000 description 1
- 238000002156 mixing Methods 0.000 description 1
- 239000000203 mixture Substances 0.000 description 1
- 238000003058 natural language processing Methods 0.000 description 1
- 238000003062 neural network model Methods 0.000 description 1
- 230000003287 optical effect Effects 0.000 description 1
- 230000008447 perception Effects 0.000 description 1
- 208000019465 refractory cytopenia of childhood Diseases 0.000 description 1
- 230000002787 reinforcement Effects 0.000 description 1
- 238000011160 research Methods 0.000 description 1
- 230000004044 response Effects 0.000 description 1
- 239000007787 solid Substances 0.000 description 1
- 238000010186 staining Methods 0.000 description 1
- 210000001519 tissue Anatomy 0.000 description 1
- 230000001131 transforming effect Effects 0.000 description 1
- 230000007704 transition Effects 0.000 description 1
- 238000012795 verification Methods 0.000 description 1
- 230000000007 visual effect Effects 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/09—Supervised learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/096—Transfer learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/764—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/77—Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
- G06V10/774—Generating sets of training patterns; Bootstrap methods, e.g. bagging or boosting
- G06V10/7753—Incorporation of unlabelled data, e.g. multiple instance learning [MIL]
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/82—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
Abstract
本申请公开了一种图像分类模型的训练方法、装置、设备及存储介质,属于人工智能领域。该方法包括:获取源域图像和目标域图像;使用所述目标域图像增强所述源域图像,得到增强图像;使用所述源域图像和所述增强图像训练所述教师模型,得到训练后的教师模型;使用所述训练后的教师模型监督所述学生模型的训练过程,所述学生模型是使用所述目标域图像训练的病理图像分类模型。本申请中,通过利用源域图像和增强图像对教师模型进行训练,再对学生模型进行知识迁移,这种方式可以提高学生模型在图像分类中的性能。
Description
技术领域
本申请涉及人工智能技术领域,特别涉及一种图像分类模型的训练方法、装置、设备及存储介质。
背景技术
知识蒸馏技术能够借助预训练的教师模型,来增强学生模型在相关领域的训练,从而有效地将特定知识传递给学生模型。
在图像分类领域,可以使用源域图像集训练一个教师模型,使用目标域图像集训练一个学生模型,使用教师模型来监督学生模型的训练过程。
在源域和目标域的特征差异较大的情况下,上述方法训练得到的学生模型的精度较差,教师模型的迁移性不高。
发明内容:
本申请提供了一种图像分类模型的训练方法、装置、设备及存储介质,所述技术方案如下:
根据本申请的一方面,提供了一种图像分类模型的训练方法,所述方法由部署有教师模型和学生模型的计算机设备执行,所述方法包括:
获取源域图像和目标域图像;
使用所述目标域图像增强所述源域图像,得到增强图像;
使用所述源域图像和所述增强图像训练所述教师模型,得到训练后的教师模型;
使用所述训练后的教师模型监督所述学生模型的训练过程,所述学生模型是使用所述目标域图像训练的病理图像分类模型。
根据本申请的另一方面,提供了一种图像分类模型的训练装置,所述装置包括:
获取模块,用于获取源域图像和目标域图像;
增强模块,用于使用所述目标域图像增强所述源域图像,得到增强图像;
教师训练模块,用于使用所述源域图像和所述增强图像训练所述教师模型,得到训练后的教师模型;
学生训练模块,用于使用所述训练后的教师模型监督所述学生模型的训练过程,所述学生模型是使用所述目标域图像训练的病理图像分类模型。
根据本申请的另一方面,提供了一种计算机设备,所述计算机设备包括:处理器和存储器,所述存储器中存储有至少一段程序;所述处理器,用于执行如上所述的图像分类模型的训练方法。
根据本申请的另一方面,提供了一种计算机可读存储介质,所述计算机可读存储介质中存储有可执行指令,所述可执行指令由处理器加载并执行以实现如上所述的图像分类模型的训练方法。
根据本申请的另一方面,提供了一种计算机程序产品,所述计算机程序产品包括计算机指令,所述计算机指令存储在计算机可读存储介质中,处理器从所述计算机可读存储介质读取并执行所述计算机指令,以实现上述如上所述的图像分类模型的训练方法。
本申请提供的技术方案带来的有益效果至少包括:
计算机设备中部署有教师模型和学生模型,教师模型对应着源域图像,学生模型对应目标域图像,分别获取源域图像和目标域图像,使用目标域图像增强源域图像,得到增强图像,对教师模型中的源域图像和增强图像进行训练,可以得到训练后的教师模型。通过利用源域图像和增强图像对教师模型进行训练,其中增强图像是利用目标域图像对源域图像进行增强得到的,因此教师模型在学习过程中能够学习到目标域的特征,在源域对目标域进行迁移时,迁移性会变得更好;再使用训练后的教师模型监督学生模型的训练,从教师模型中提取和学生模型关联性高的特征向量,对学生模型进行训练,可以提高学生模型预测结果的准确性。
附图说明
为了更清楚地说明本申请实施例中的技术方案,下面将对实施例描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本申请的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1示出了本申请一个示例性实施例提供的教师模型训练过程的示意图;
图2示出了本申请一个示例性实施例提供的学生模型训练过程的示意图;
图3示出了本申请一个示例性实施例提供的计算机系统的示意图;
图4示出了本申请一个示例性实施例提供的一种图像分类模型的训练方法的流程图;
图5示出了本申请一个示例性实施例提供的一种图像分类模型的训练方法的流程图;
图6示出了本申请一个示例性实施例提供的一种图像分类模型的训练方法的流程图;
图7示出了本申请一个示例性实施例提供的一种图像分类模型的训练方法的示意图;
图8示出了本申请一个示例性实施例提供的一种图像分类模型的训练方法的流程图;
图9示出了本申请一个示例性实施例提供的一种图像分类模型的训练方法的流程图;
图10示出了本申请一个示例性实施例提供的一种图像分类模型的训练方法的流程图;
图11示出了本申请一个示例性实施例提供的一种图像分类模型的训练方法的示意图;
图12示出了本申请一个示例性实施例提供的一种图像分类模型的训练方法的示意图;
图13示出了本申请一个示例性实施例提供的一种图像分类模型的训练方法的示意图;
图14示出了本申请一个示例性实施例提供的一种图像分类模型的训练方法的示意图;
图15示出了本申请一个示例性实施例提供的一种图像分类模型的训练装置的结构框图;
图16示出了本申请一个示例性实施例提供的一种计算机设备的结构示意图。
具体实施方式
为使本申请的目的、技术方案和优点更加清楚,下面将结合附图对本申请实施方式作进一步地详细描述。
这里将详细地对示例性实施例进行说明,其示例表示在附图中。下面的描述涉及附图时,除非另有表示,不同附图中的相同数字表示相同或相似的要素。以下示例性实施例中所描述的实施方式并不代表与本申请一致的所有实施方式。相反,它们仅是与如所附权利要求书中所详述的、本申请的一些方面相一致的装置和方法的例子。
在本申请使用的术语是仅仅出于描述特定实施例的目的,而非旨在限制本申请。在本申请和所附权利要求书中使用的单数形式的“一种”、“所述”和“该”也旨在包括多数形式,除非上下文清楚地表示其他含义。还应当理解,本文中使用的术语“和/或”是指并包含一个或多个相关联的列出项目的任何或所有可能组合。
需要进行说明的是,本申请在收集用户的相关数据之前以及在收集用户的相关数据的过程中,都可以显示提示界面、弹窗或输出语音提示信息,该提示界面、弹窗或语音提示信息用于提示用户当前正在搜集其相关数据,使得本申请仅仅在获取到用户对该提示界面或弹窗发出的确认操作后,才开始执行获取用户相关数据的相关步骤,否则(即未获取到用户对该提示界面或弹窗发出的确认操作时),结束获取用户相关数据的相关步骤,即不获取用户的相关数据。换句话说,本申请所采集的所有用户数据都是在用户同意并授权的情况下进行采集的,且相关用户数据的收集、使用和处理需要遵守相关国家和地区的相关法律法规和标准。
首先,对本申请涉及的相关名词做出介绍:
人工智能(Artificial Intelligence,AI):是利用数字计算机或者数字计算机控制的机器模拟、延伸和扩展人的智能,感知环境、获取知识并使用知识获得最佳结果的理论、方法、技术及应用系统。换句话说,人工智能是计算机科学的一个综合技术,它企图了解智能的实质,并生产出一种新的能以人类智能相似的方式做出反应的智能机器。人工智能也就是研究各种智能机器的设计原理与实现方法,使机器具有感知、推理与决策的功能。
人工智能技术是一门综合学科,涉及领域广泛,既有硬件层面的技术也有软件层面的技术。人工智能基础技术一般包括如传感器、专用人工智能芯片、云计算、分布式存储、大数据处理技术、预训练模型技术、操作/交互系统、机电一体化等技术。其中,预训练模型又称大模型、基础模型,经过微调后可以广泛应用于人工智能各大方向下游任务。人工智能软件技术主要包括计算机视觉技术、语音处理技术、自然语言处理技术以及机器学习/深度学习等几大方向。
机器学习(Machine Learning,ML):是一门多领域交叉学科,涉及概率论、统计学、逼近论、凸分析、算法复杂度理论等多门学科。专门研究计算机怎样模拟或实现人类的学习行为,以获取新的知识或技能,重新组织已有的知识结构使之不断改善自身的性能。机器学习是人工智能的核心,是使计算机具有智能的根本途径,其应用遍及人工智能的各个领域。机器学习和深度学习通常包括人工神经网络、置信网络、强化学习、迁移学习、归纳学习、示教学习等技术。预训练模型是深度学习的最新发展成果,融合了以上技术。
域:对一个训练集中样本子集产生整体性分布偏差的因素,比如在病理图像分类中,源域图像可以是肿瘤细胞的图像,目标域图像可以是癌症细胞的图像。
迁移学习:当数据中有域的差别时,构建学习系统来处理域的差异。
源域:即源数据集,源域是基于多个源域图像构建的数据集。
目标域:即目标数据集,目标域图像是基于多个目标域图像构建的数据集。
教师模型:在源数据集上进行预训练并具有较高性能的模型。在学生模型的训练中,它可以作为指导和参考,帮助学生模型学习泛化性更好的特征。
学生模型:在目标数据集上进行训练的模型,它通过从教师模型中获取知识和指导来提升自己的性能。学生模型的目标是在目标任务或领域上表现出与教师模型相似甚至更好的性能,从而实现知识迁移的目的。
多实例学习(Multiple Instance Learning,MIL):是一种监督学习的方法,其中训练样本被组织成包(bag)和实例(instance)的集合。在多实例学习中,每个训练样本被视为一个包,而包中的实例则是样本的子集。例如在病理图像分类中,一张病理图像代表一个组织的多个切片,每个切片作为一个实例。
本申请提供了一种图像分类模型的训练方法的方案。首先对教师模型进行训练,获取目标域图像和源域图像,使用目标域图像增强源域图像,得到增强图像,将源域图像和增强图像输入教师模型进行训练,可以得到训练后的教师模型;然后再使用训练后的教师模型监督学生模型的训练,教师模型在训练的过程中能够学习到目标域的特征,因此在源域对目标域进行迁移时,迁移性较好。
以该方法应用于终端设备进行举例说明。以下实施例对该方法的步骤进行简述。
相关技术中,采用源域图像对教师模型进行训练,本申请实施例中中采用源域和增强图像对教师模型进行训练。
教师模型的训练过程10:
结合参考图1,为了提升教师模型11的迁移性,使用源域图像12和增强图像14共同对教师模型11进行训练。其中,增强图像14包含了目标域图像13的特征,教师模型11在训练的过程中学习了目标域图像13的特征,在后续教师模型11进行迁移学习时,教师模型11会提升对目标域图像13的预测准确性。
首先获取源域图像12和目标域图像13,然后在目标域图像13中查找和源域图像12相似度符合筛选条件的相似目标域图像,其中相似度可以根据特征向量进行计算。从源域图像12中提取源域图像特征,再从目标域图像13中提取和源域图像特征距离最近的目标域图像特征,将该目标域图像特征作为相似目标域图像特征。
使用相似目标域图像对源域图像12进行增强,即使用相似目标域图像特征对源域图像特征进行增强。增强的方式有多种,示例性的,可以将相似目标域图像特征加入到源域图像特征中,或将相似目标域图像特征替换源域图像特征,得到增强图像14的增强图像特征。
将源域图像12和增强图像14一起到输入教师模型11中,对教师模型11进行训练,教师模型11可以同时学习到源域图像12和目标域图像13的相关信息,训练后的教师模型11后续进行迁移学习时,对目标域图像13的预测的准确性会提升。
训练后的教师模型11是一个优化后的模型,它可以对输入的图像进行特征提取和分类预测。训练后的教师模型可以提取教师图像特征22,教师图像特征22是指通过训练后的教师模型11对输入的图像进行处理,从输入的图像中提取出的具有代表性的特征向量;分类预测是指教师模型根据提取的特征向量,对目标域图像13进行分类或预测,可以得到教师分类结果15。
将训练后的教师模型11得出的教师分类结果15和真实的标签16进行比较,可以得到一个误差,这个误差反映了教师模型11预测的准确性;可以根据误差反向传播算法调整教师模型11的模型参数,以提高教师模型预测的准确性。
学生模型的训练过程20:
结合参考图2,在一些实施例中,使用学生模型21中的目标域图像13,增强教师模型11中的源域图像12,对教师模型11进行训练,可以得到训练后的教师模型11,再使用训练后的教师模型11监督学生模型21的训练,这是一种迁移学习的方法,可以将教师模型11的知识和经验传递给学生模型21,从而提高学生模型21在目标域图像13上的性能。
在一些实施例中,教师模型11的预测任务和学生模型21的预测任务可能存在差异,因此,教师模型11提取出的教师图像特征22不适合直接输入到学生模型21中,需要对该教师图像特征22进行多头自适应处理,得到和学生模型21的预测任务相关的图像特征。
可选地,通过多头特征适应网络对教师图像特征22进行多头自适应处理,可以将教师图像特征22中与学生模型21预测任务相关的图像特征进行线性变换或者放大处理。其中,通过多头特征适应网络中的多头放大层24进行放大处理,通过多头特征适应网络的线性变换层25进行线性变换处理,将进行多大处理和线性变换后得到的特征向量再通过加权求和层26进行求和处理,可以得到和学生模型21的预测任务相关的图像特征,即可得到用于迁移到学生模型21的迁移特征组合,将该迁移特征组合输入学生模型21,可对学生模型21进行训练。
在一些实施例中,训练后的教师模型可以得到一个教师模型的预测结果,即教师分类结果15,提取教师模型中的教师图像特征22进行多头自适应处理得到迁移特征组合,使用迁移特征组合对学生模型21进行训练。学生模型21对输入的图像进行预测,可得到学生模型21的预测结果,即学生分类结果23。
将教师模型11的教师分类结果15作为真实的标签,将学生模型21的学生分类结果23和教师模型11的教师分类结果15进行比较,可以得到一个误差,这个误差反映了学生模型21预测的准确性,根据误差反向传播算法调整学生模型21的模型参数,以提高学生模型预测的准确性。
图3示出了本申请一个示例性实施例提供的计算机系统的示意图。
该计算机系统包括模型训练系统10和终端设备20,终端设备20和模型训练系统10之间通过网络40进行通信连接,终端设备20为用户侧设备,需要进行模型训练的用户可以通过终端设备20登录模型训练系统10。模型训练系统10中包括多个数据处理单元101,例如图1所示的数据处理单元101-1、数据处理单元101-2等。模型训练系统10中还需要设置至少一个控制单元102,例如图1所示的控制单元102-1、控制单元102-2等,以及至少一个用于存储数据的存储单元103,例如存储单元103-1等。模型训练系统10通过网络接口单元105和终端设备20通信,网络接口单元105例如图1所示的网络接口单元105-1、网络接口单元105-2。模型训练系统10中还包括至少一个处理单元106,各单元之间可以通过总线104通信连接。
图3所示的应用场景中,各单元距离较近,可以通过总线方式连接,如果各单元分布在不同地域,各单元之间也可以通过网络实现异地通信连接。图1所示的应用场景中,处理单元106可以通过总线104对模型训练系统10中的其它单元(包括数据处理单元101、控制单元102、网络接口单元105和存储单元103等)进行控制。一个控制单元102可执行处理单元106下达的一个训练任务,一个控制单元102可同时调用多个数据处理单元101,以协调多个数据处理单元101共同完成一个训练任务,处理单元106和多个控制单元102可以合并设置。存储单元103用于存储样本数据和深度学习模型,还可以用来存储模型训练系统10中需要存储的其它数据,如训练过程中产生的一些中间数据、用户信息、用户对应的训练任务以及系统参数等等。
模型训练系统10通过终端设备20向用户提供上传深度学习模型和样本数据的接口,以使用户通过终端设备20上显示的接口将需要训练的深度学习模型和训练用的样本数据上传到模型训练系统10,处理单元106将上传的深度学习模型和样本数据存储入存储单元103中。用户可以选择使用模型训练系统10提供的样本数据和深度学习模型。处理单元106还需要为训练任务配置用户上传或选择的样本数据和深度学习模型的存储地址。可选地,样本数据以及配置参数也可以由训练系统10提供,不需要用户输入。
处理单元106配置好训练任务的各项参数后,将配置参数导入对应的控制单元102中,由控制单元102调用相应的数据处理单元101,并控制数据处理单元101利用样本数据对深度学习模型进行训练。具体地,数据处理单元101根据配置的存储地址,从对应的存储单元103中读取深度学习模型和样本数据,利用样本数据对深度学习模型进行训练。完成训练后,将训练好的深度学习模型存储到对应的存储单元103中,同时,处理单元106结束该训练任务,释放对应的数据处理单元101和控制单元102。完成训练后,处理单元106可通过网络向终端设备20发送完成训练的提示信息,用户可通过终端设备20再次登录模型训练系统10,通过网络接口单元105访问用户账号对应的存储单元103,获取训练好的深度学习模型。
需要说明的是,实际应用中,一个数据处理单元可以对应一块芯片,或者,多个数据处理单元可集成在同一芯片中,集成在同一芯片中的各个数据处理单元间可相互独立地执行不同的计算任务。同一芯片中可集成同种数据处理精度的数据处理单元,同一芯片也可以集成多种数据处理精度的数据处理单元。以V100芯片为例,其集成有多个半精度的数据处理单元和多个单精度的数据处理单元。本申请实施例中,集成有数据处理单元101的芯片不限于V100芯片,还可以是任意一种具备训练深度学习模型的计算能力的芯片,如神经网络处理器(Neural network Processing Unit,NPU)、张量处理器(Tensor ProcessingUnit,TPU)或现场可编程逻辑门阵列(Field-Programmable Gate Array,FPGA)等。
实际应用中,控制单元102和处理单元106可以是CPU(中央处埋器)、专用集成电路(Application Specific Integrated Circuit,ASIC)、FPGA或复杂可编程逻辑器件(Complex Programmable Logic Device,CPLD)等。控制单元102也可以是处理单元106中设置的多个独立的进程,处理单元106为每个训练任务分配一个进程,以控制多个数据处理单元101执行训练任务。
终端设备20和模型训练系统10之间通过网络进行通信连接,该网络可以为局域网、广域网等。用户均可以通过在该终端设备20中安装的登录模型训练系统10的客户端进行登录,或者可通过终端设备20中的浏览器访问模型训练系统10的主页,并通过模型训练系统10的主页登录模型训练系统10。模型训练系统10可以部署在一台服务器30、若干台服务器组成的服务器集群或云计算中心。
下面,对本申请实施例中提供的图像分类模型的训练方法进行说明。
图4示出了本申请一个示例性实施例提供的图像分类模型的训练方法的流程图,以该方法用于终端设备进行举例说明,该方法包括如下步骤中的至少部分步骤:
步骤210:获取源域图像和目标域图像;
源域图像和目标域图像指的是两个不同域的图像数据集。
其中,源域图像是用于训练教师模型的图像数据集。源域图像上通常带有标签,教师模型会在源域图像上进行训练并学习出对应的特征。
目标域图像是用于训练学生模型的图像数据集。目标域图像是指学生模型能应用并进行预测的图像。
在一些实施例中,对于图像分类任务,会对全部或部分图像进行分配标签操作,标签指的是用于描述图像所属类别或类别标识的信息。部分图像会被标注一个或多个标签,用于标识其所属的类别或属性。
在一些实施例中,可以通过公开数据集的方法获取源域图像和目标域图像,即从公开的图像数据集中获取源域图像和目标域图像;源域图像和目标域图像用于训练和评估图像分类模型,公开数据集中通常提供了图像和对应的标签,可以方便地进行数据加载和处理。
步骤220:使用目标域图像增强源域图像,得到增强图像;
在图像分类中,增强图像是指通过对源域图像进行一系列变换和处理,生成具有多样性和泛化能力的新图像。可选地,可以通过使用目标域图像增强源域图像。
在一些实施例中,预训练的教师模型在源域图像上表现良好,但在目标域图像上可能不够适应。通过使用目标域图像增强源域图像,教师模型可以更好地了解目标任务的特点和要求,从而提升其在目标域图像上的性能。
示例性的,假设有一个预训练的教师模型,该教师模型在乳腺癌病理图像分类任务上表现良好。然而,当本申请将该教师模型应用于其他类型的癌症病理图像时,它的性能可能下降。可以使用其他类型癌症的图像数据进行数据增强,让教师模型更好地适应目标任务。
在一些实施例中,可以通过在源域图像训练教师模型时将目标域图像的目标数据纳入训练,将无标注的目标域图像中的目标数据加入教师模型的训练中,以此提供给教师模型一个明确的迁移目标。
可选地,数据标注可以通过人工标注的方式进行。人工标注是指由人工根据预定义的标注规则为图像分配标签。无标注指的是没有经过标注或者没有标签的数据。这些数据通常是从目标域图像中收集而来,但没有与之对应的标签信息。
步骤230:使用源域图像和增强图像训练教师模型,得到训练后的教师模型;
训练是指通过将模型与数据进行交互,使其逐渐调整和优化模型参数的过程。在深度学习中,训练是指使用已知输入和期望输出(标签)的数据集来调整教师模型的权重,以使教师模型能够更好地拟合数据。
在一些实施例中,使用源域图像和增强图像共同训练教师模型,训练教师模型如下:首先准备源域图像和增强图像的数据集;然后可选择适合任务的教师模型,如卷积神经网络、循环神经网络等模型架构;对教师模型的参数进行初始化;然后将源域图像输入到教师模型中,通过教师模型的前向传播过程,得到教师模型对源域图像的预测输出结果;然后将增强图像输入到教师模型中,通过教师模型的前向传播过程,得到教师模型对增强图像的预测输出结果;将教师模型的预测输出结果和真实标签进行比较,可以得到一个误差值,这个误差值反映了教师模型预测的准确性;可以根据误差反向传播算法调整教师模型的模型参数,以提高教师模型预测的准确性。通过训练教师模型,教师模型会逐渐学习到源域图像和增强图像之间的关系,训练后的教师模型可以用于后续的预测、分类或其他相关任务。
步骤240:使用训练后的教师模型监督学生模型的训练过程。
其中,学生模型是使用目标域图像训练的图像分类模型。
使用训练后的教师模型来监督学生模型的训练过程,这是一种迁移学习的方法,可以将教师模型的知识和经验传递给学生模型,从而提高学生模型在目标域图像上的性能。
综上所述,本实施例提供的方法,计算机设备中部署有教师模型和学生模型,教师模型对应着源域图像,学生模型对应目标域图像,分别获取源域图像和目标域图像,使用目标域图像增强源域图像,得到增强图像,对教师模型中的源域图像和增强图像进行训练,可以得到训练后的教师模型。通过利用源域图像和增强图像对教师模型进行训练,其中增强图像是利用目标域图像对源域图像进行增强得到的,因此教师模型在学习过程中能够学习到目标域的特征,在源域对目标域进行迁移时,迁移性会变得更好;再使用训练后的教师模型监督学生模型的训练,从教师模型中提取和学生模型关联性高的特征向量,对学生模型进行训练,可以提高学生模型预测结果的准确性。
本申请实施例包括两个阶段:
阶段一:教师模型的训练过程;
阶段二:学生模型的训练过程。
教师模型的训练过程:图5至图7是教师模型的训练过程。
在基于图4实施例的基础上,图5示出了本申请一个示例性实施例提供的图像分类模型的训练方法的流程图。在本实施例中,步骤220替换实现为步骤221和步骤222:
步骤221:在目标域图像中查找与源域图像的相似度符合筛选条件的相似目标域图像;
在目标域图像中,根据相似度筛选条件,找到与源域图像相似度符合筛选条件的目标域图像。可选地,相似度可以根据目标域图像的特征向量进行计算和比较。
其中,筛选条件有多种,可选地,筛选相似目标域图像时,可以通过相似度阈值进行筛选:设置一个相似度阈值,只选择与源域图像相似度高于该阈值的目标域图像;筛选相似目标域图像时,可以通过相似度进行筛选:对于源域图像和目标域图像,将目标域图像中和源域图像相似度最大的筛选出来。上述对筛选相似目标域图像进行示例性说明,本申请对筛选条件不作限定。
示例性的,在病理图像分类中,在使用目标域图像增强源域图像得到增强图像时,可以主动检索目标域图像中与源域相似的图像部分,将目标域图像中和源域图像相似的图像部分增强到源域图像中。以源域图像是癌细胞,目标域图像是肿瘤为例进行举例说明,首先,主动检索可以在大多数情况下防止肿瘤和正常图像的混合,因为癌细胞之间的相似性高于正常细胞和肿瘤细胞之间的相似性。其次,将目标域图像的特征整合至源域图像,这种方式增强了特征在特定方向(从源域到目标域)上的可迁移性。通过目标感知的数据增强,教师模型从源域中提取到了更多源域和目标域间共有的知识。
在一个可选的示例中,对于至少两个源域图像特征中的任一源域图像特征,在至少两个目标域图像特征中查找与源域图像特征距离最近的目标域特征,作为相似目标域图像特征。
其中,源域图像特征是指从源域图像中提取的表示图像内容的特征向量。这些特征向量可以是通过图像特征提取算法得到的,用于描述源域图像的局部或全局特征。目标域图像特征则是指从目标域图像中提取的类似的特征向量。目标域图像可以是与源域图像相关的、具有相似内容或样式的图像。
示例性的,假设有两个源域图像特征,分别为A和B,以及两个目标域图像特征,分别为X和Y。源域图像特征A与目标域图像特征X的距离为2,与目标域图像特征Y的距离为5。源域图像特征B与目标域图像特征X的距离为3,与目标域图像特征Y的距离为1。选择源域图像特征A作为给定的源域图像特征。然后,在目标域图像特征X和Y中找到与A距离最近的目标域图像特征。在这种情况下,目标域图像特征X与源域图像特征A的距离最近,因此本申请实施例选择目标域图像特征X作为与源域图像特征A距离最近的目标域特征。选择源域图像特征A作为源域图像特征。然后,在目标域图像特征X和Y中找到与A距离最近的目标域图像特征。在这种情况下,目标域图像特征X与源域图像特征A的距离最近,因此本申请实施例选择目标域图像特征X作为与给定源域图像特征A距离最近的目标域特征。
步骤222:使用相似目标域图像增强源域图像,得到增强图像。
在一些实施例中,使用相似目标域图像增强源域图像,增强的方式有多种,可选地,目标域图像可以通过风格迁移的方式增强源域图像:
·使用目标域图像的样式特征,将其应用到源域图像上,使得目标域图像和源域图像具有相似的样式特征;
·目标域图像可以通过内容替换的方式增强源域图像:将目标域图像中的某些内容或特征与源域图像中的对应内容或特征进行替换,从而使得源域图像具有与目标域图像相似的特征;
·目标域图像可以通过特征融合的方式增强源域图像:将源域图像特征和目标域图像特征进行融合,以增强源域图像特征,可以使用图像融合算法,将源域图像特征和目标域图像特征进行加权融合,从而得到增强图像。
在一个可选的示例中,使用相似目标域图像特征增强源域图像特征,得到增强图像的增强图像特征。
在一些实施例中,利用相似目标域图像特征来增强源域图像特征,从而得到一个更好的增强图像特征。这样的增强图像特征可以用于各种图像处理任务,如图像分类、目标检测、图像生成等。
在一个可选的示例中,使用相似目标域图像特征对源域图像特征采用如下增强方式中的至少一种,得到增强图像的增强图像特征:
·将相似目标域图像特征加入到至少两个源域图像特征中;
·使用相似目标域图像特征替换源域图像特征;
·计算源域图像特征和相似目标域图像特征的插值图像特征,使用插值图
像特征替换源域图像特征;
·使用相似目标域图像特征的协方差矩阵更新源域图像特征。方式如下:
方式一(追加):将相似目标域图像特征加入到至少两个源域图像特征中:
在一些实施例中,将相似目标域图像特征加入到至少两个源域图像特征中,意味着将目标域图像的特征与源域图像的特征进行融合或组合,以增强源域图像特征。
示例性的,假设有一组源域图像,每个源域图像都有对应的特征向量,同时还有一组相似的目标域图像,目标域图像也有对应的特征向量。现在想要将相似目标域图像的特征加入到源域图像的特征中,可以将相似目标域图像的特征向量与源域图像的特征向量进行拼接或连接。例如源域图像特征向量的维度为n,相似目标域图像特征向量的维度也为n,可以将相似目标域图像的特征向量与源域图像的特征向量进行拼接,形成一个新的特征向量,维度为2n。通过这种方式可以将相似目标域图像的特征加入到了源域图像的特征中。
方式二(替换):使用相似目标域图像特征替换源域图像特征:
在一些实施例中,使用相似目标域图像特征替换源域图像特征,可以直接使用相似目标域图像特征替换源域图像特征。这种替换方式可以将目标域图像的特征直接应用到源域图像上,以增强源域图像特征。
示例性的,假设有两个源域图像A和B,每个源域图像都有一个特征向量。同时,有一个相似的目标域图像C,目标域图像C也有一个特征向量。可以将源域图像A的特征向量替换为相似目标域图像C的特征向量,或将源域图像B的特征向量替换为相似目标域图像C的特征向量。
方式三(插值):计算源域图像特征和相似目标域图像特征的插值图像特征,使用插值图像特征替换源域图像特征:
在一些实施例中,计算源域图像特征和相似目标域图像特征的插值图像特征,使用插值图像特征替换源域图像特征,可以通过计算源域图像特征和相似目标域图像特征之间的插值,得到一个新的特征向量,然后,可以使用这个插值图像特征替换源域图像特征,以获得增强源域图像特征。
示例性的,假设有一个源域图像A的特征表示为Fa,以及一个相似目标域图像B的特征表示为Fb。将A的特征过渡到B的特征,生成一个中间状态的特征,以通过插值来实现。例如,常见的插值方法之一是线性插值。可以采用如下公式进行特征的线性插值:Fi=(1-t)*Fa+t*Fb,其中,Fi是插值图像特征,t是一个范围在[0,1]的插值参数。当t=0时,插值结果为源域图像A的特征Fa;当t=1时,插值结果为相似目标域图像B的特征Fb。因此,在t变化的过程中,插值图像特征Fi将从源域图像特征逐渐过渡到相似目标域图像特征。
方式四(协变):使用相似目标域图像特征的协方差矩阵更新源域图像特征:
在一些实施例中,使用相似目标域图像特征的协方差矩阵更新源域图像特征,可以使用相似目标域图像特征的协方差矩阵来更新源域图像特征。协方差矩阵可以提供特征之间的相关性信息,通过使用它来更新源域图像特征,可以增强源域图像特征。具体来说,本申请实施例可以通过以下步骤进行更新:首先计算相似目标域图像特征的协方差矩阵。协方差矩阵反映了特征之间的相关性和方差大小;其次使用相似目标域图像特征的协方差矩阵来分析源域图像特征的结构,可以通过计算源域图像特征和相似目标域图像特征的协方差、相关系数等来评估它们之间的相似性;然后基于分析结果,对源域图像特征进行调整或更新。一种常见的方法是使用线性变换,将源域图像特征映射到相似目标域图像特征的结构上。新的特征向量fC如下:
fC=f+λδ,δ~N(0,ΣC)
其中ΣC是聚类中心向量c的协方差矩阵,N(0,ΣC)是多元正态分布。生成的特征fC被追加到包中。
示例性的,结合参考图7,源域图像中的每个切片对应一个特征向量,图7中的切片(1)对应的特征向量为F1,图7中的切片(2)对应的特征向量为F2,图7中的切片(3)对应的特征向量为Fn;当目标域图像中的切片数量较多时,可对目标域图像中的切片进行聚类,图7中的聚类中心(4)、聚类中心(5)和聚类中心(6)也对应着聚类中心对应的特征向量C1、C2和Cn。可以将目标域图像的一部分和源域图像相似的特征加入源域图像,对源域图像的特征向量进行增强。可选地,使用相似目标域图像特征对源域图像特征进行增强时,可采用替换、追加、插值和协变的方式,对源域图像的特征向量进行增强。
综上所述,本实施例提供的方法,在目标域图像中查找与源域图像的相似度符合筛选条件的相似目标域图像;使用相似目标域图像增强源域图像,得到增强图像。通过查找与源域图像相似度符合筛选条件的相似目标域图像,可以找到与源域图像相关的目标域图像,使用相似目标域图像增强源域图像得到增强图像,使用相似目标域图像对教师模型进行训练可以提升教师模型迁移的准确性。
以多实例学习为全扫描数字病理切片(Whole Slide Imaging,WSI)为例进行说明,在多实例学习中,WSI通常是千兆像素的图像,将WSI分成多个切片,每个切片作为一个实例。
例如在病理图像分类中,有一组病理图像数据集,其中包含10个WSI,每个WSI都是千兆像素级别的图像,将每个WSI分成多个切片,可以将多个切片作为实例,每个实例代表一个切片,它是病理图像中的一个局部区域。将这些实例组织成多个包,每个包是多个实例的集合。通过多实例学习模型从切片实例中提取特征,并根据这些特征进行分类判断,从而对病理图像进行分类。
在基于图5实施例的基础上,图6示出了本申请一个示例性实施例提供的图像分类模型的训练方法的流程图。在本实施例中,还包括如下步骤:
步骤310:将源域图像的多个第一切片特征,确定为源域图像的至少两个源域图像特征;
其中,源域图像包括多个第一切片,多个第一切片具有一一对应的第一切片特征;目标域图像包括多个第二切片,多个第二切片具有一一对应的第二切片特征。
第一切片指的是源域图像中的图像块或者图像区域。第二切片指的是目标域图像中的图像块或者图像区域。
第一切片特征指源域图像中每个图像的第一切片的特征向量,这些特征向量可以从图像中提取得到。第一切片特征指目标域图像中每个图像的第二切片的特征向量,这些特征向量可以从目标域图像中提取得到。
源域图像由多个第一切片组成,每个第一切片都有对应的第一切片特征。这意味着源域图像被分割成多个图像块,并且每个图像块都有与之对应的特征向量。目标域图像由多个第二切片组成,每个第二切片都有对应的第二切片特征。这表示目标域图像也被分割成多个图像块,并且每个图像块都有与之对应的特征向量。
在一些实施例中,一个图像用多个切片对应的切片特征(即特征向量)进行表示,将目标域图像的一部分特征加入源域图像,可对源域图像的特征向量进行增强。示例性的,在病理图像分类中,会将病理图像切成许多个切片,每个细胞所在的区域为一个切片,可从一个切片中提取一个特征向量,一个切片对应一个特征向量,同一个图像包含的多个切片分别对应的特征向量的集合称为一个包。可选地,把目标域图像的特征加入到源域图像中的特征中,可对源域图像的特征进行增强。
步骤320:将目标域图像的多个第二切片特征进行聚类,得到至少两个聚类中心特征;
第二切片特征是指目标域图像每个图像的第二切片的特征向量。
其中,聚类是一种将相似的特征分组的方法,可将一组数据样本分成若干个类别或群组,使得同一类别内的样本具有较高的相似度,而不同类别的样本之间具有较大的差异性。通过聚类可以将目标域图像的多个第二切片特征划分为不同的类别,并计算每个类别的聚类中心特征。聚类中心特征是由聚类算法得到的每个聚类中心的代表性特征。一个聚类中心特征通常用于表示该聚类中心的特征,它可以是该聚类中心内特征的平均值或中心位置。
在一些实施例中,目标域图像由切片的特征向量组成,这些特征向量可以用来表示目标域图像包中的一些特征,目标域图像中的细胞含量较多,切片的数量巨大。为了减少计算量,可将目标域图像中的多个切片特征进行聚类操作,目标域图像由多个聚类中心表示,这些聚类中心代表了目标域图像包中的一些特征,并且聚类中心数目远小于切片数目,因此可以减少特征向量的计算量。
示例性的,假设有一个包含100个肺部病理切片图像的目标域图像数据集。对于每个图像,本申请实施例提取了第二切片的特征向量,得到了100个特征向量。然后,本申请实施例使用K均值算法对这100个特征向量进行聚类。假设设置聚类数目为3,则K均值算法将这100个特征向量划分为3个聚类中心,它们分别代表了3个不同的聚类中心特征。本申请实施例可以将这3个聚类中心特征确定为目标域图像的3个目标域图像特征,即可确定为目标域图像的3个特征向量,用于表示整个目标域图像数据集的关键特征。
步骤330:将至少两个聚类中心特征确定为目标域图像的至少两个目标域图像特征。
综上所述,本实施例提供的方法,源域图像和目标域图像中包括多个切片,每个切片对应着一个特征向量,通过提取源域图像的多个第一切片特征,并将其确定为源域图像的特征可以捕捉到源域图像的不同局部特征,丰富了源域图像的特征表示;对目标域图像的多个第二切片特征进行聚类,得到至少两个聚类中心特征,这样可以将目标域图像的相似特征进行聚合,减少特征维度,提高特征的表达能力。
学生模型的训练过程:图8至图12是教师模型的训练过程。
图8示出了本申请一个示例性实施例提供的图像分类模型的训练方法的流程图。在本实施例中,该方法包括如下步骤中的至少部分步骤:
步骤410:将目标域图像输入训练后的教师模型,得到训练后的教师模型提取的教师图像特征和教师分类结果;
训练后的教师模型是一个已经通过训练数据进行训练和优化的模型,具有较高的准确性和泛化能力,它可以对输入的图像进行特征提取和分类预测。
特征提取是指教师模型从输入图像中提取出具有代表性的特征表示。这些特征可以是高层次的语义特征,也可以是低层次的视觉特征。教师模型通过学习从训练数据中提取有用信息的方式,将目标域图像转换为特征向量。
分类预测是指教师模型根据提取的特征向量,对目标域图像进行分类或预测。教师模型可以根据其训练过程中学到的知识和模式,将目标域图像分为不同的类别或给出相应的预测结果。
教师图像特征是指通过训练后的教师模型对输入的图像进行处理,从图像中提取出的具有代表性的特征向量,这些特征向量可以用来表示源域图像的内容、结构或其他相关信息。教师分类结果是指通过训练后的教师模型对输入的图像进行分类预测,判断源域图像属于哪个类别或给出相应的预测结果。
步骤420:基于教师图像特征对学生模型训练,得到学生模型的学生分类结果;
学生分类结果指的是经过学生模型处理后,对输入的目标域图像进行分类预测得到的结果。学生模型通过学习和训练,根据输入的目标域图像的特征进行分类预测,并给出输入的目标域图像所属的类别或相应的预测结果。
在一些实施例中,教师图像特征被用作学生模型的输入数据,以指导学生模型的训练。学生模型会根据教师图像特征进行学习,并尝试预测与教师模型相似的分类结果。
步骤430:基于教师分类结果和学生分类结果之间的误差,更新学生模型的模型参数。
在一些实施例中,使用教师模型的分类结果与学生模型的分类结果之间的误差,来更新学生模型的参数。目的是最小化预测结果与真实结果之间的误差。可以通过计算学生模型预测的分类结果与教师模型分类结果之间的差异,来评估学生模型在当前数据上的表现。然后,本申请实施例可以根据这种差异来更新学生模型的模型参数,以提高其性能。
可选地,更新学生模型的模型参数,可以通过误差反向传播算法调整学生模型的参数,使得学生模型能够更好地逼近教师模型的分类结果。
在一个可选的示例中,基于教师分类结果和学生分类结果之间的误差,更新学生模型和多头特征适应网络的模型参数。
示例性的,教师模型包括特征提取网络和预测网络。教师模型输入的图像用特征向量所组成的矩阵表示,将此作为教师图像特征。将教师图像特征输入到教师模型的预测模块对教师图像进行预测。首先向教师模型中输入一个教师图像,得到一个预测结果,该预测结果是基于特征向量得到的。在教师模型对学生模型进行知识迁移时,将矩阵输入学生模型中,让学生模型进行预测,得到一个预测结果。假设教师模型的预测结果为预测结果A,学生模型的预测结果为预测结果B,将预测结果A和预测结果B进行比较,根据比较得出的误差调整学生模型的模型参数,基于模型参数对学生模型进行训练。
综上所述,本实施例提供的方法,将目标域图像输入训练后的教师模型,可以提取出目标域图像的特征表示,并给出对应的分类结果,这样可以利用教师模型提取的特征和分类结果来指导学生模型的训练;基于教师图像特征对学生模型进行训练,可以利用教师模型提取的特征来指导学生模型的学习过程;根据教师分类结果和学生分类结果之间的误差,可以计算出学生模型的误差,并利用该误差来更新学生模型的模型参数。这样可以通过不断的更新模型参数来提升学生模型的性能。
在基于图8实施例的基础上,图9示出了本申请一个示例性实施例提供的图像分类模型的训练方法的流程图。在本实施例中,步骤420替换实现为步骤421、步骤422和步骤423:
步骤421:将教师图像特征从源域特征空间映射到目标域特征空间,得到映射后的图像特征;
其中,特征空间是指用来表示数据的一个数学空间。特征空间是所有的特征向量所构成的空间。通过选择合适的特征空间,可以提取有用的特征,并且能够更好地对学生模型进行预测。
源域特征空间是指源域图像在特征空间中的表示,目标域特征空间是指目标域图像在特征空间中的表示。将源域特征空间映射到目标域特征空间,意味着将源域图像在特征空间中的表示转换为适用于目标域图像的表示。
在一些实施例中,假设有一个教师模型在源域中训练,用于图像分类任务。该教师模型可以提取出一组源域特征,这些特征在源域中对图像分类非常有效。若本申请实施例希望将这个教师模型应用于目标域的图像分类任务,本申请实施例可以使用一个映射函数或者神经网络模型来将教师图像特征从源域特征空间映射到目标域特征空间。
步骤422:通过多头特征适应网络提取映射后的图像特征中,用于迁移至学生模型的迁移特征组合;
多头特征适应网络是一个用于特征适应的模型,它可以从映射后的图像特征中提取出用于迁移的特征组合。多头特征适应网络通过在网络中引入多个头来学习不同的特征适应映射,每个头都负责学习一种相适应的映射,从而负责提取不同的特征表示。
在一些实施例中,在通过多头特征适应网络将图像特征从源域映射到目标域后,可以从映射后的图像特征中选择一组特征作为迁移特征,用于迁移至学生模型。
使用多头特征适应网络将源域图像特征映射到目标域特征空间,在映射后的图像特征中,可以选择一组特征作为迁移特征,用于迁移至学生模型。可选地,这组迁移特征可以是在源域和目标域上都具有较高区分度的特征,例如动物尾巴的形状。通过选择合适的迁移特征,可以将多头特征适应网络学到的源域和目标域之间的特征转换迁移到学生模型中,从而提高学生模型在目标域上的性能。
步骤423:将迁移特征组合输入学生模型,得到学生模型的学生分类结果。
其中,迁移特征组合是指从映射后的图像特征中选择的特征,用于迁移至学生模型,这些特征被组合在一起作为输入,传递给学生模型进行分类。
在一些实施例中,将选择的迁移特征组合作为输入,输入到学生模型中,然后通过学生模型进行分类,得到学生模型的分类结果。可选地,迁移特征组合可以是在源域图像和目标域图像中具有高区分度的特征。
示例性的,假设有一个源域是乳腺癌细胞图像分类任务,目标域是肺癌细胞图像分类任务。使用多头特征适应网络将源域图像特征映射到目标域特征空间。在映射后的特征中,本申请实施例选择一组迁移特征组合作为输入,输入到学生模型中。迁移特征组合可以是在源域和目标域中具有高区分度的癌细胞特征的组合。例如,在源域的乳腺癌细胞图像和目标域的肺癌细胞图像中,细胞核的形状和大小、细胞核的染色程度可以是具有高区分度的特征。可选地,可将细胞核的形状和大小、细胞核的染色程度作为迁移特征组合输入学生模型中,得到学生模型的分类结果。
综上所述,本实施例提供的方法,将教师图像特征从源域特征空间映射到目标域特征空间,可以将源域图像的特征表示转化为目标域图像的特征表示;通过多头特征适应网络提取映射后的图像特征中用于迁移至学生模型的迁移特征组合,多头特征适应网络可以从映射后的图像特征中提取适合迁移至学生模型的迁移特征组合,可以选择对学生模型训练有益的特征,提高学生模型在目标域上的分类性能。
在基于图9实施例的基础上,图10示出了本申请一个示例性实施例提供的图像分类模型的训练方法的流程图。在本实施例中,步骤422替换实现为步骤422-1、步骤422-2、步骤422-3和步骤422-4:
步骤422-1:确定映射后的图像特征中的多头特征向量中的每个头内特征向量对应的放大因子;
在一些实施例中,多头特征适应网络包括线性变化层和多头放大层。
多头放大层是多头特征适应网络中的一层,用于对目标特征组合中的每个头内特征向量进行自适应放大。它通常由多个放大器组成,每个放大器对应一个头内特征向量。多头放大层的作用是增强每个头内特征向量的和权重,以便在迁移至学生模型时更好地适应目标域的任务。头内特征向量指的是每个头所产生的特征向量。
放大因子指的是对多头特征向量中的每个头内特征向量进行缩放或放大的系数。放大因子可以用来调整特征向量的权重或幅度,以增强或减弱每个头内特征的重要性。
在将教师模型中与目标域图像有关的目标特征组合进行放大时:
在一些实施例中,通过多头特征适应网络中的多头放大层对教师模型中与目标域图像有关的目标特征组合进行放大。示例性的,与目标特征组合相关的特征在教师模型中可能具有较低的响应,因此为多头特征适应网络中与目标特征组合相关的特征的头部添加一个放大因子,计算方式如下:
其中,和/>是可学习的参数矩阵,它们的值可通过误差反向传播算法进行调整和优化。h′t是归一化后的头内特征向量,ds代表了头内特征向量的维度,T用于控制放大因子的范围。较小的T值会使放大因子的范围更集中,较大的T值会使放大因子的范围更分散。/>
在一个可选的示例中,在进行映射之前对教师图像特征进行归一化处理。
在一些实施例中,归一化处理有助于消除源域图像特征中的知识偏差,不同的图像特征具有差异,这些差异会对模型的学习产生负面影响,通过归一化,可以将图像特征的分布范围调整到一个统一的尺度,减少不同图像特征之间的差异,从而消除知识偏差。归一化后的头内特征向量h′t为:
其中,ht是教师特征向量,即初始的头内特征向量。T=0.1,t=0.3,T用于控制特征的分布形状。较小的T值会使特征的分布更加集中,而较大的T值会使图像特征的分布更加平均。t是幂指数,用于调整图像特征的分布形状。较小的t值会使图像特征的分布更加尖锐,而较大的t值会使图像特征的分布更加平滑。sign(ht)是符号函数,它将头内特征向量中的每个元素的正数变为1,负数变为-1,零保持不变。
在一些实施例中,通过多头放大层对头内特征组合中的每个头内特征向量进行放大,可以增强教师图像特征中重要特征的影响力,减弱教师图像特征中不重要特征的影响力。其中,迁移至学生模型的迁移特征组合是经过不同头的特征适应后的头内特征向量进行加权求和得到的。
步骤422-2:对映射后的图像特征中的多头特征向量中的每头特征向量进行线性变换,得到提取后的头内特征组合;
多头特征适应网络中通过线性变换层对对映射后的图像特征进行线性变换。
线性变换层是多头特征适应网络中的另一层,用于对映射后的图像特征进行线性变换。线性变换层的作用是将映射后的图像特征转换为头内特征组合,以供后续处理和使用。头内特征组合是由经过线性变换后的头内特征向量组成的。
在一些实施例中,线性变换是指对每个头内特征向量应用线性操作,通常是矩阵乘法,这个线性操作可以通过一个权重矩阵来实现,将头内特征向量与权重矩阵相乘,得到提取后的头内特征组合。
在一些实施例中,在使用源域图像和增强图像对教师模型进行训练时,教师模型中的源域图像占的权重大于增强图像在教师模型中的权重。向教师模型输入目标域图像时,提取的特征向量是源域图像特征和增强图像特征混合的特征向量所构成的矩阵。在教师模型对学生模型进行知识迁移时,为了提高学生模型的预测结果,在将矩阵输入学生模型之前,可以将教师模型中与目标域图像有关的目标特征组合进行提取,得到提取后的头内特征组合。
步骤422-3:对于每个头内特征向量,基于放大因子和头内特征组合,得到特征适应后的头内特征向量;
在一些实施例中,每个头内特征向量都对应着不同的特征空间,头内特征组合是由多个头内特征向量组成的。通过将放大后的头内特征向量和头内特征组合进行组合,得到特征适应后的头内特征向量。
在一些实施例中,对于多头特征适应网络中的第i个头,基于放大后的头内特征向量和头内特征组合进行组合,可以得到特征适应后的头内特征向量如下所示:/>
其中,是特征适应后的头内特征向量,ai是放大因子,h′t是归一化后的头内特征向量,/>是一个可学习的参数矩阵,用于对多头特征向量中的每头特征向量进行线性变换。
示例性的,结合参考图11,教师特征向量ht经过归一化操作得到归一化后的头内特征向量h′t,对归一化后的头内特征向量进行放大和线性变换,放大后得到可学习的参数矩阵线性变换得到可学习的参数矩阵/>其中,对参数矩阵/>进行运算可得到表示每个头内特征向量对应的放大因子ai,参数矩阵/>和头内特征向量运算后得到头内特征组合,通过将放大后的头内特征向量和头内特征组合进行组合,可得到特征适应后的头内特征向量/>
步骤422-4:将不同头的特征适应后的头内特征向量进行求和,得到用于迁移至学生模型的迁移特征组合。
迁移特征组合可表示为将不同头的特征适应后的头内特征向量进行求和得到的结果。在多头特征适应网络中,多头特征适应网络的输出hMHFA可表示为:
其中,ai是每个头对应的放大因子,是特征适应后的头内特征向量,对每个头对应的头内特征向量进行加权求和处理,可得到多头特征适应网络的输出。hMHFA是多头特征适应网络中不同头的特征适应后的头内特征向量进行加权求和得到的结果。
结合参考图12,多头特征适应网络中包括多个头,引入多个头来学习不同的特征适应映射,每个头都负责学习一种相适应的映射,多头特征适应网络从教师模型中提取具有代表性的头内特征向量FA,对提取出来的FA进行放大处理和线性变换处理后,得到特征适应后的头内特征向量将不同头的特征适应后的头内特征向量进行加权求和,得到多头特征适应网络的输出hMHFA。
综上所述,本实施例提供的方法,通过多头特征适应网络提取映射后的图像特征用于迁移至学生模型的迁移特征组合,多头特征适应网络包括线性变换层和多头放大层,通过线性变换层提取头内特征组合,并通过多头放大层对头内特征向量进行自适应放大,可以提高学生模型在目标域上的特征表示能力。这种方法使得学生模型更好地利用映射后的图像特征,并提高在目标域上的分类性能。
本技术方案在病理图像分类上产生的有益效果:
本申请实施例在方法的验证中使用了三个数据集。第一个是Camelyon16数据集(https://camelyon16.grand-challenge.org),它由399张病理图像组成,Camelyon16数据集的任务是检测转移性乳腺癌病变。第二个是肾细胞癌(Renal Cell Carcinoma,RCC)数据集,包含940张WSI。它由来自109例癌症基因组图谱(The Cancer Genome Atlas,TCGA)中的肾脏发色细胞肾癌(Kidney Chromophobe,KICH)病例的121张WSI,来自513例TCGA中的肾脏透明细胞肾癌(Kidney Renal Clear Cell Carcinoma,KIRC)病例的519张WSI,以及来自276例TCGA中的肾脏乳头状细胞肾癌(Kidney Renal Papillary Cell Carcinoma,KIRP)病例的300张WSI组成。RCC数据集的目的是亚型分类,即确定癌症属于哪个类别。第三个是非小细胞肺癌(Non-Small Cell Lung Cancer,NSCLC)数据集,NSCLC数据集由478例TCGA中的肺鳞状细胞癌(Lung Squamous Cell Carcinoma,LUSC)病理的512张WSI和478例TCGA中的肺腺癌(Lung Adenocarcinoma,LUAD)病例的541张WSI组成。NSCLC数据集同样用于进行癌症亚型分类任务。对比时采用了三个评价指标:AUC(Area Under the Curve,曲线面积)、F1值(F1-score)和准确率(Accuracy)。
AUC是一种常用的评价指标,用于衡量二分类模型的性能。它表示ROC曲线下的面积,范围在0到1之间,数值越接近1表示模型性能越好。F1是综合考虑了模型的精确率(Precision)和召回率(Recall)的评价指标。精确率表示模型预测为正例的样本中真正为正例的比例,召回率表示模型能够正确预测为正例的样本的比例。F1值是精确率和召回率的调和平均值,数值范围在0到1之间,数值越接近1表示模型性能越好。准确率是指模型在所有样本中正确分类的比例,数值范围在0到1之间,数值越接近1表示模型性能越好。准确率是最常用的评价指标之一,但在不平衡数据集中可能会受到样本分布的影响。
本申请实施例提出的方法与其他知识迁移方法在四个设置中进行比较:NSCLC到Camelyon16,RCC到Camelyon16,RCC到NSCLC和NSCLC到RCC。结果如表一和表二所示。本申请实施例的方法在每个指标上都取得了最佳性能,特别是在Camelyon16上,本申请实施例的方法明显优于其他方法。此外,由于肿瘤大小较小且样本数量有限,Camelyon16比TCGA数据集更具挑战性,这在指标的数值上有所体现。该结果证明本申请实施例的方法能够有效地将共有知识从一个较简单的数据集传递到一个更困难的数据集。
表一
表二
/>
除此之外,本申请实施例进行了消融研究,以评估本申请实施例的框架和多头特征适应(Multi-Head Feature Adaptation,MHFA)网络对性能的影响。本申请实施例使用上述的数据增强方法进行了相同的实验,以评估它们的性能。实验结果通过箱线图在图13中呈现。总体而言,与没有数据增强相比,引入数据增强方法表现出了改善的性能。值得注意的是,本申请实施例观察到在从NSCLC到Camelyon16的迁移中,联合数据增强方法取得了最高的性能,超过其他方法很大程度上(最佳AUC得分为0.958)。然而,在其他设置中,它的性能较差。另一方面,替换数据增强方法展示了整体上更优越的性能。具体而言,它在从RCC到NSCLC和NSCLC到RCC的知识迁移中取得了最佳性能,并在迁移到Camelyon16数据集时有不错的性能。
图13箱线图显示了不同数据增强方法的性能。X轴上的“O”,“R”,“A”,“I”,“C”,“J”分别代表不同的数据增强方法,即“原始”、“替换”、“追加”、“插值”、“协变”和“联合”。
表三
表三表示不同投影模块和本申请实施例的MHFA网络中不同头数量的结果。为了进一步探究MHFA网络的效果,本申请实施例的方法与不同的投影头进行了性能比较,包括线性投影头(Linear Projection Head,LP)、多层投影头(Multi-Layer PerceptronProjection Head,MLP)和1×1卷积投影头(Convolution Projection Head,Conv)。本申请实施例还对模块中头的数量进行了消融研究。这些实验是在没有数据增强方法的情况下进行的,并且实验结果报告在表三中。如表中所示,与其他选择相比,1×1Conv头的性能较差。另一方面,尽管LP和MLP投影头的结构相对简单,但它们都取得了不错的结果。本申请实施例提出的MHFA网络相比其他方法展现出了卓越的性能,其中采用PTS的MHFA配置实现了最佳性能。本申请实施例注意到增加MHFA网络中头的数量并不一定会导致性能的提升,例如当头的数量从两个增加到四个时,性能反而些许下降。总体而言,本申请实施例在不同头数量下的性能仍然相互竞争,突显了本申请实施例捕捉到的目标特征的鲁棒性。图14中展示了采用MHFA网络前后的特征分布,可以看到MHFA网络显著增大了目标域图像中的特征向量在特征空间中分布的方差。第二行中展示了每个头输出特征的分布,可以看到每个头分别捕获不同的特征,加入放大因子后每个头内的特征分别也更加分散。通过多头的方式以及加入放大因子,原先紧凑的分布的目标域数据被扩散开。
下述为本申请装置实施例,可以用于执行本申请方法实施例。对于本申请装置实施例中未披露的细节,请参照本申请方法实施例。
请参考图15,其示出了本申请一个实施例提供的图像分类模型的训练装置的结构框图。该装置具有实现上述图像分类模型的训练方法示例的功能,功能可以由硬件实现,也可以由硬件执行相应的软件实现。该装置可以是上文介绍的服务器,也可以设置在服务器中。如图15所示,该装置1100可以包括:获取模块1110、增强模块1120、教师训练模块1130和学生训练模块1140;
获取模块1110,用于获取源域图像和目标域图像;
增强模块1120,用于使用所述目标域图像增强所述源域图像,得到增强图像;
教师训练模块1130,用于使用所述源域图像和所述增强图像训练所述教师模型,得到训练后的教师模型;
学生训练模块1140,用于使用所述训练后的教师模型监督所述学生模型的训练过程,所述学生模型是使用所述目标域图像训练的图像分类模型。
在一些可选的实施例中,增强模块1120还包括查找子模块和增强子模块。
在一个可选的实施例中,查找子模块,用于在所述目标域图像中查找与所述源域图像的相似度符合筛选条件的相似目标域图像;增强子模块,用于使用所述相似目标域图像增强所述源域图像,得到所述增强图像。
在一些可选的实施例中,查找子模块还包括查找单元,增强子模块还包括增强单元。
在一个可选的实施例中,查找单元,用于对于所述至少两个源域图像特征中的任一源域图像特征,在所述至少两个目标域图像特征中查找与所述源域图像特征距离最近的目标域特征,作为相似目标域图像特征;增强单元,用于使用所述相似目标域图像特征增强所述源域图像特征,得到所述增强图像的增强图像特征。
在一些可选的实施例中,装置1100还包括第一确定模块、第二确定模块和聚类模块。
在一个可选的实施例中,第一确定模块,用于将所述源域图像的多个第一切片特征,确定为所述源域图像的所述至少两个源域图像特征;聚类模块,用于将所述目标域图像的多个第二切片特征进行聚类,得到至少两个聚类中心特征;第二确定模块,用于将所述至少两个聚类中心特征确定为所述目标域图像的所述至少两个目标域图像特征。
在一个可选的实施例中,增强模块1120,用于使用所述相似目标域图像特征对所述源域图像特征采用如下增强方式中的至少一种,得到所述增强图像的增强图像特征:将所述相似目标域图像特征加入到所述至少两个源域图像特征中;使用所述相似目标域图像特征替换所述源域图像特征;计算所述源域图像特征和所述相似目标域图像特征的插值图像特征,使用所述插值图像特征替换所述源域图像特征;使用所述相似目标域图像特征的协方差矩阵更新所述源域图像特征。
在一些可选的实施例中,教师训练模块1130还包括教师训练子模块、学生训练子模块和更新子模块。
在一个可选的实施例中,教师训练子模块,用于将所述目标域图像输入所述训练后的教师模型,得到所述训练后的教师模型提取的教师图像特征和教师分类结果;学生训练子模块,用于基于所述教师图像特征对所述学生模型训练,得到所述学生模型的学生分类结果;更新子模块,用于基于所述教师分类结果和所述学生分类结果之间的误差,更新所述学生模型的模型参数。
在一些可选的实施例中,学生训练模块1140还包括映射子模块、提取子模块和输入子模块。
在一个可选的实施例中,映射子模块,用于将所述教师图像特征从源域特征空间映射到目标域特征空间,得到映射后的图像特征;提取子模块,用于通过所述多头特征适应网络提取所述映射后的图像特征中,用于迁移至所述学生模型的迁移特征组合;输入子模块,用于将所述迁移特征组合输入所述学生模型,得到所述学生模型的学生分类结果。
在一些可选的实施例中,提取子模块还包括确定单元、变换单元、组合单元和求和单元。
在一个可选的实施例中,确定单元,确定所述映射后的图像特征中的多头特征向量中的每个头内特征向量对应的放大因子;变换单元,对所述映射后的图像特征中的多头特征向量中的每头特征向量进行线性变换,得到提取后的头内特征组合;组合单元,对于所述每个头内特征向量,基于所述放大因子和所述头内特征组合,得到特征适应后的头内特征向量;求和单元,用于将不同头的所述特征适应后的头内特征向量进行加权求和,得到用于迁移至所述学生模型的迁移特征组合。
在一些可选的实施例中,更新子模块还包括更新单元。
在一个可选的实施例中,更新单元,用于基于所述教师分类结果和所述学生分类结果之间的误差,更新所述学生模型和所述多头特征适应网络的模型参数。
在一个可选的实施例中,将所述教师图像特征进行归一化。
综上所述,本实施例提供的装置,计算机设备中部署有教师模型和学生模型,教师模型对应着源域图像,学生模型对应目标域图像,分别获取源域图像和目标域图像,使用目标域图像增强源域图像,得到增强图像,对教师模型中的源域图像和增强图像进行训练,可以得到训练后的教师模型。通过利用源域图像和增强图像对教师模型进行训练,其中增强图像是利用目标域图像对源域图像进行增强得到的,因此教师模型在学习过程中能够学习到目标域的特征,在源域对目标域进行迁移时,迁移性会变得更好;再使用训练后的教师模型监督学生模型的训练,从教师模型中提取和学生模型关联性高的特征向量,对学生模型进行训练,可以提高学生模型预测结果的准确性。
图16示出了本申请一个示例性实施例示出的计算机设备1500的结构框图。该计算机设备可用于实施上述实施例中提供的图像分类模型的训练方法。所述计算机设备1500包括中央处理单元(Central Processing Unit,CPU)1501、包括随机存取存储器(RandomAccess Memory,RAM)1502和只读存储器(Read-Only Memory,ROM)1503的系统存储器1504,以及连接系统存储器1504和中央处理单元1501的系统总线1505。所述计算机设备1500还包括帮助计算机设备内的各个器件之间传输信息的基本输入/输出系统(Input/Output系统,I/O系统)1506,和用于存储操作系统1513、应用程序1514和其他程序模块1515的大容量存储设备1507。
所述基本输入/输出系统1506包括有用于显示信息的显示器1508和用于用户输入信息的诸如鼠标、键盘之类的输入设备1509。其中所述显示器1508和输入设备1509都通过连接到系统总线1505的输入输出控制器1510连接到中央处理单元1501。所述基本输入/输出系统1506还可以包括输入输出控制器1510以用于接收和处理来自键盘、鼠标、或电子触控笔等多个其他设备的输入。类似地,输入输出控制器1510还提供输出到显示屏、打印机或其他类型的输出设备。
所述大容量存储设备1507通过连接到系统总线1505的大容量存储控制器(未示出)连接到中央处理单元1501。所述大容量存储设备1507及其相关联的计算机可读存储介质为终端设备1500提供非易失性存储。也就是说,所述大容量存储设备1507可以包括诸如硬盘或者只读光盘(Compact Disc Read-Only Memory,CD-ROM)驱动器之类的计算机可读存储介质(未示出)。
不失一般性,所述计算机可读存储介质可以包括计算机存储介质和通信介质。计算机存储介质包括以用于存储诸如计算机可读存储指令、数据结构、程序模块或其他数据等信息的任何方法或技术实现的易失性和非易失性、可移动和不可移动介质。计算机存储介质包括RAM、ROM、可擦除可编程只读寄存器(Erasable Programmable Read OnlyMemory,EPROM)、电子抹除式可复写只读存储器(Electrically-Erasable ProgrammableRead-Only Memory,EEPROM)、闪存或其他固态存储其技术,CD-ROM、数字多功能光盘(Digital Versatile Disc,DVD)或其他光学存储、磁带盒、磁带、磁盘存储或其他磁性存储设备。当然,本领域技术人员可知所述计算机存储介质不局限于上述几种。上述的系统存储器1504和大容量存储设备1507可以统称为存储器。
存储器存储有一个或多个程序,一个或多个程序被配置成由一个或多个中央处理单元1501执行,一个或多个程序包含用于实现上述方法实施例的指令,中央处理单元1501执行该一个或多个程序实现上述各个方法实施例提供的方法。
根据本申请的各种实施例,所述计算机设备1500还可以通过诸如因特网等网络连接到网络上的远程终端设备运行。也即计算机设备1500可以通过连接在所述系统总线1505上的网络接口单元1511连接到网络1512,或者说,也可以使用网络接口单元1511来连接到其他类型的网络或远程终端设备系统(未示出)。
所述存储器还包括一个或者一个以上的程序,所述一个或者一个以上程序存储于存储器中,所述一个或者一个以上程序包含用于进行本申请实施例提供的方法中由终端设备所执行的步骤。
本申请实施例还提供一种计算机可读存储介质,该计算机可读存储介质中存储有至少一条计算机程序,该至少一条计算机程序由处理器加载并执行以实现上述各方法实施例提供的图像分类模型的训练方法。
本申请实施例还提供一种计算机程序产品,所述计算机程序产品包括计算机程序,所述计算机程序存储在计算机可读存储介质中;所述计算机程序由计算机设备的处理器从所述计算机可读存储介质读取并执行,使得所述计算机设备执行以实现上述各方法实施例提供的图像分类模型的训练方法。
可以理解的是,在本申请的具体实施方式中,涉及到的数据,历史数据,以及画像等与用户身份或特性相关的用户数据处理等相关的数据,当本申请以上实施例运用到具体产品或技术中时,需要获得用户许可或者同意,且相关数据的收集、使用和处理需要遵守相关国家和地区的相关法律法规和标准。
需要说明的是,除非本文中另外明确定义,否则用于权利要求中的所有术语根据它们在技术领域中的普通含义来解释。除非另外明确叙述,否则对“一个元件、装置、部件、设备、步骤等”的所有参考将被开放地解释为指代元件、装置、部件、设备、步骤等的至少一个实例。除非明确叙述,否则本文所公开的任意方法的步骤不是必须以所公开的确切顺序来执行。
应当理解的是,在本文中提及的“多个”是指两个或两个以上。“和/或”,描述关联对象的关联关系,表示可以存在三种关系,例如,A和/或B,可以表示:单独存在A,同时存在A和B,单独存在B这三种情况。字符“/”一般表示前后关联对象是一种“或”的关系。
Claims (14)
1.一种图像分类模型的训练方法,其特征在于,所述方法由部署有教师模型和学生模型的计算机设备执行,所述方法包括:
获取源域图像和目标域图像;
使用所述目标域图像增强所述源域图像,得到增强图像;
使用所述源域图像和所述增强图像训练所述教师模型,得到训练后的教师模型;
使用所述训练后的教师模型监督所述学生模型的训练过程,所述学生模型是使用所述目标域图像训练的图像分类模型。
2.根据权利要求1所述的方法,其特征在于,所述使用所述目标域图像增强所述源域图像,得到增强图像,包括:
在所述目标域图像中查找与所述源域图像的相似度符合筛选条件的相似目标域图像;
使用所述相似目标域图像增强所述源域图像,得到所述增强图像。
3.根据权利要求2所述的方法,其特征在于,所述源域图像包括至少两个源域图像特征,所述目标域图像包括至少两个目标域图像特征;
所述在所述目标域图像中查找与所述源域图像的相似度符合筛选条件的相似目标域图像,包括:
对于所述至少两个源域图像特征中的任一源域图像特征,在所述至少两个目标域图像特征中查找与所述源域图像特征距离最近的目标域特征,作为相似目标域图像特征;
所述使用所述相似目标域图像增强所述源域图像,得到所述增强图像,包括:
使用所述相似目标域图像特征增强所述源域图像特征,得到所述增强图像的增强图像特征。
4.根据权利要求3所述的方法,其特征在于,所述源域图像包括多个第一切片,所述多个第一切片具有一一对应的第一切片特征;所述目标域图像包括多个第二切片,所述多个第二切片具有一一对应的第二切片特征;
所述方法还包括:
将所述源域图像的多个第一切片特征,确定为所述源域图像的所述至少两个源域图像特征;
将所述目标域图像的多个第二切片特征进行聚类,得到至少两个聚类中心特征;
将所述至少两个聚类中心特征确定为所述目标域图像的所述至少两个目标域图像特征。
5.根据权利要求3所述的方法,其特征在于,所述使用所述相似目标域图像特征增强所述源域图像特征,得到所述增强图像的增强图像特征,包括:
使用所述相似目标域图像特征对所述源域图像特征采用如下增强方式中的至少一种,得到所述增强图像的增强图像特征:
将所述相似目标域图像特征加入到所述至少两个源域图像特征中;
使用所述相似目标域图像特征替换所述源域图像特征;
计算所述源域图像特征和所述相似目标域图像特征的插值图像特征,使用所述插值图像特征替换所述源域图像特征;
使用所述相似目标域图像特征的协方差矩阵更新所述源域图像特征。
6.根据权利要求1至5任一所述的方法,其特征在于,所述使用所述训练后的教师模型监督学生模型的训练过程,包括:
将所述目标域图像输入所述训练后的教师模型,得到所述训练后的教师模型提取的教师图像特征和教师分类结果;
基于所述教师图像特征对所述学生模型训练,得到所述学生模型的学生分类结果;
基于所述教师分类结果和所述学生分类结果之间的误差,更新所述学生模型的模型参数。
7.根据权利要求6所述的方法,其特征在于,所述计算机设备还部署有多头特征适应网络;
所述基于所述教师图像特征对所述学生模型训练,得到所述学生模型的学生分类结果,包括:
将所述教师图像特征从源域特征空间映射到目标域特征空间,得到映射后的图像特征;
通过所述多头特征适应网络提取所述映射后的图像特征中,用于迁移至所述学生模型的迁移特征组合;
将所述迁移特征组合输入所述学生模型,得到所述学生模型的学生分类结果。
8.根据权利要求7所述的方法,其特征在于,所述多头特征适应网络包括多头放大层和线性变换层;
所述通过所述多头特征适应网络提取所述映射后的图像特征中,用于迁移至所述学生模型的迁移特征组合,包括:
确定所述映射后的图像特征中的多头特征向量中的每个头内特征向量对应的放大因子;
对所述映射后的图像特征中的多头特征向量中的每头特征向量进行线性变换,得到提取后的头内特征组合;
对于所述每个头内特征向量,基于所述放大因子和所述头内特征组合,得到特征适应后的头内特征向量;
将不同头的所述特征适应后的头内特征向量进行求和,得到用于迁移至所述学生模型的迁移特征组合。
9.根据权利要求6所述的方法,其特征在于,所述基于所述教师分类结果和所述学生分类结果之间的误差,更新所述学生模型的模型参数,包括:
基于所述教师分类结果和所述学生分类结果之间的误差,更新所述学生模型和所述多头特征适应网络的模型参数。
10.根据权利要求7所述的方法,其特征在于,所述方法还包括:
将所述教师图像特征进行归一化。
11.一种图像分类模型的训练装置,其特征在于,所述装置包括:
获取模块,用于获取源域图像和目标域图像;
增强模块,用于使用所述目标域图像增强所述源域图像,得到增强图像;
教师训练模块,用于使用所述源域图像和所述增强图像训练所述教师模型,得到训练后的教师模型;
学生训练模块,用于使用所述训练后的教师模型监督所述学生模型的训练过程,所述学生模型是使用所述目标域图像训练的病理图像分类模型。
12.一种计算机设备,其特征在于,所述计算机设备包括处理器和存储器,所述存储器中存储有计算机程序,所述计算机程序由所述处理器加载并执行以实现如权利要求1至10任一项所述的图像分类模型的训练方法。
13.一种计算机可读存储介质,所述计算机可读存储介质中存储有计算机程序,其特征在于,所述计算机程序被处理器执行时实现如权利要求1至10任一项所述的图像分类模型的训练方法。
14.一种计算机程序产品,其特征在于,包括计算机程序,所述计算机程序被处理器执行时实现如权利要求1至10任一项所述的图像分类模型的训练方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202311443497.XA CN117371511A (zh) | 2023-11-01 | 2023-11-01 | 图像分类模型的训练方法、装置、设备及存储介质 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202311443497.XA CN117371511A (zh) | 2023-11-01 | 2023-11-01 | 图像分类模型的训练方法、装置、设备及存储介质 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN117371511A true CN117371511A (zh) | 2024-01-09 |
Family
ID=89402178
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202311443497.XA Pending CN117371511A (zh) | 2023-11-01 | 2023-11-01 | 图像分类模型的训练方法、装置、设备及存储介质 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN117371511A (zh) |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117609887A (zh) * | 2024-01-19 | 2024-02-27 | 腾讯科技(深圳)有限公司 | 数据增强模型训练及数据处理方法、装置、设备、介质 |
-
2023
- 2023-11-01 CN CN202311443497.XA patent/CN117371511A/zh active Pending
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117609887A (zh) * | 2024-01-19 | 2024-02-27 | 腾讯科技(深圳)有限公司 | 数据增强模型训练及数据处理方法、装置、设备、介质 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN107945204B (zh) | 一种基于生成对抗网络的像素级人像抠图方法 | |
US9519868B2 (en) | Semi-supervised random decision forests for machine learning using mahalanobis distance to identify geodesic paths | |
CN110738247B (zh) | 一种基于选择性稀疏采样的细粒度图像分类方法 | |
CN109993102B (zh) | 相似人脸检索方法、装置及存储介质 | |
CN110796199B (zh) | 一种图像处理方法、装置以及电子医疗设备 | |
CN114841257B (zh) | 一种基于自监督对比约束下的小样本目标检测方法 | |
CN111127385A (zh) | 基于生成式对抗网络的医学信息跨模态哈希编码学习方法 | |
CN111242948B (zh) | 图像处理、模型训练方法、装置、设备和存储介质 | |
Liang et al. | Comparison detector for cervical cell/clumps detection in the limited data scenario | |
CN110738132B (zh) | 一种具备判别性感知能力的目标检测质量盲评价方法 | |
CN108021930A (zh) | 一种自适应的多视角图像分类方法及系统 | |
CN110210625A (zh) | 基于迁移学习的建模方法、装置、计算机设备和存储介质 | |
CN114330499A (zh) | 分类模型的训练方法、装置、设备、存储介质及程序产品 | |
CN110765882A (zh) | 一种视频标签确定方法、装置、服务器及存储介质 | |
CN117371511A (zh) | 图像分类模型的训练方法、装置、设备及存储介质 | |
CN112819024B (zh) | 模型处理方法、用户数据处理方法及装置、计算机设备 | |
CN113222149A (zh) | 模型训练方法、装置、设备和存储介质 | |
CN110781970A (zh) | 分类器的生成方法、装置、设备及存储介质 | |
CN111222847A (zh) | 基于深度学习与非监督聚类的开源社区开发者推荐方法 | |
CN115063664A (zh) | 用于工业视觉检测的模型学习方法、训练方法及系统 | |
CN113128564A (zh) | 一种基于深度学习的复杂背景下典型目标检测方法及系统 | |
Pereira et al. | Assessing active learning strategies to improve the quality control of the soybean seed vigor | |
CN116451081A (zh) | 数据漂移的检测方法、装置、终端及存储介质 | |
CN116129189A (zh) | 一种植物病害识别方法、设备、存储介质及装置 | |
CN116188428A (zh) | 一种桥接多源域自适应的跨域组织病理学图像识别方法 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication |