CN116363461A - 多视图儿童肿瘤病理图像分类的深度网络增量学习方法 - Google Patents
多视图儿童肿瘤病理图像分类的深度网络增量学习方法 Download PDFInfo
- Publication number
- CN116363461A CN116363461A CN202310356664.0A CN202310356664A CN116363461A CN 116363461 A CN116363461 A CN 116363461A CN 202310356664 A CN202310356664 A CN 202310356664A CN 116363461 A CN116363461 A CN 116363461A
- Authority
- CN
- China
- Prior art keywords
- model
- training
- data
- network
- classification
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000000034 method Methods 0.000 title claims abstract description 55
- 230000001575 pathological effect Effects 0.000 title claims abstract description 30
- 206010028980 Neoplasm Diseases 0.000 title claims abstract description 11
- 238000013140 knowledge distillation Methods 0.000 claims abstract description 47
- 238000004422 calculation algorithm Methods 0.000 claims abstract description 30
- 238000012549 training Methods 0.000 claims description 81
- 230000006870 function Effects 0.000 claims description 24
- 238000005457 optimization Methods 0.000 claims description 19
- 230000008569 process Effects 0.000 claims description 16
- 230000007170 pathology Effects 0.000 claims description 11
- 239000013598 vector Substances 0.000 claims description 8
- ORILYTVJVMAKLC-UHFFFAOYSA-N Adamantane Natural products C1C(C2)CC3CC1CC2C3 ORILYTVJVMAKLC-UHFFFAOYSA-N 0.000 claims description 7
- 230000004913 activation Effects 0.000 claims description 7
- 238000003062 neural network model Methods 0.000 claims description 7
- 238000004364 calculation method Methods 0.000 claims description 5
- 238000013528 artificial neural network Methods 0.000 claims description 4
- 238000012795 verification Methods 0.000 claims description 4
- 230000005540 biological transmission Effects 0.000 claims description 3
- 230000007246 mechanism Effects 0.000 claims description 3
- 238000002156 mixing Methods 0.000 claims description 3
- 238000007781 pre-processing Methods 0.000 claims description 3
- 238000011423 initialization method Methods 0.000 claims description 2
- 230000000717 retained effect Effects 0.000 claims description 2
- 238000012545 processing Methods 0.000 abstract description 10
- 238000012360 testing method Methods 0.000 description 7
- 238000013461 design Methods 0.000 description 4
- 238000010801 machine learning Methods 0.000 description 4
- 238000013459 approach Methods 0.000 description 3
- 238000013527 convolutional neural network Methods 0.000 description 3
- 238000010586 diagram Methods 0.000 description 3
- 230000000694 effects Effects 0.000 description 3
- 238000005516 engineering process Methods 0.000 description 3
- 238000010606 normalization Methods 0.000 description 3
- 206010029260 Neuroblastoma Diseases 0.000 description 2
- 238000002790 cross-validation Methods 0.000 description 2
- 238000013135 deep learning Methods 0.000 description 2
- 238000002372 labelling Methods 0.000 description 2
- 238000012935 Averaging Methods 0.000 description 1
- 238000004458 analytical method Methods 0.000 description 1
- 230000009286 beneficial effect Effects 0.000 description 1
- 230000008859 change Effects 0.000 description 1
- 238000013145 classification model Methods 0.000 description 1
- 238000010276 construction Methods 0.000 description 1
- 238000012937 correction Methods 0.000 description 1
- 238000013136 deep learning model Methods 0.000 description 1
- 230000006735 deficit Effects 0.000 description 1
- 238000004821 distillation Methods 0.000 description 1
- 238000011156 evaluation Methods 0.000 description 1
- 238000002474 experimental method Methods 0.000 description 1
- 238000000605 extraction Methods 0.000 description 1
- 238000011478 gradient descent method Methods 0.000 description 1
- 230000006872 improvement Effects 0.000 description 1
- 238000013507 mapping Methods 0.000 description 1
- 238000013178 mathematical model Methods 0.000 description 1
- 238000003058 natural language processing Methods 0.000 description 1
- 238000003909 pattern recognition Methods 0.000 description 1
- 230000002093 peripheral effect Effects 0.000 description 1
- 238000011176 pooling Methods 0.000 description 1
- 230000002787 reinforcement Effects 0.000 description 1
- 230000011218 segmentation Effects 0.000 description 1
- 238000000638 solvent extraction Methods 0.000 description 1
- 238000013519 translation Methods 0.000 description 1
- 230000000007 visual effect Effects 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/77—Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
- G06V10/774—Generating sets of training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/0464—Convolutional networks [CNN, ConvNet]
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/047—Probabilistic or stochastic networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06T—IMAGE DATA PROCESSING OR GENERATION, IN GENERAL
- G06T7/00—Image analysis
- G06T7/0002—Inspection of images, e.g. flaw detection
- G06T7/0012—Biomedical image inspection
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/764—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V10/00—Arrangements for image or video recognition or understanding
- G06V10/70—Arrangements for image or video recognition or understanding using pattern recognition or machine learning
- G06V10/82—Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06T—IMAGE DATA PROCESSING OR GENERATION, IN GENERAL
- G06T2207/00—Indexing scheme for image analysis or image enhancement
- G06T2207/20—Special algorithmic details
- G06T2207/20081—Training; Learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06T—IMAGE DATA PROCESSING OR GENERATION, IN GENERAL
- G06T2207/00—Indexing scheme for image analysis or image enhancement
- G06T2207/20—Special algorithmic details
- G06T2207/20084—Artificial neural networks [ANN]
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06T—IMAGE DATA PROCESSING OR GENERATION, IN GENERAL
- G06T2207/00—Indexing scheme for image analysis or image enhancement
- G06T2207/30—Subject of image; Context of image processing
- G06T2207/30004—Biomedical image processing
- G06T2207/30096—Tumor; Lesion
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06V—IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
- G06V2201/00—Indexing scheme relating to image or video recognition or understanding
- G06V2201/03—Recognition of patterns in medical or anatomical images
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02T—CLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
- Y02T10/00—Road transport of goods or passengers
- Y02T10/10—Internal combustion engine [ICE] based vehicles
- Y02T10/40—Engine management systems
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- General Physics & Mathematics (AREA)
- Health & Medical Sciences (AREA)
- Evolutionary Computation (AREA)
- General Health & Medical Sciences (AREA)
- Computing Systems (AREA)
- Software Systems (AREA)
- Artificial Intelligence (AREA)
- Medical Informatics (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Data Mining & Analysis (AREA)
- Databases & Information Systems (AREA)
- Mathematical Physics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Multimedia (AREA)
- Molecular Biology (AREA)
- General Engineering & Computer Science (AREA)
- Probability & Statistics with Applications (AREA)
- Nuclear Medicine, Radiotherapy & Molecular Imaging (AREA)
- Radiology & Medical Imaging (AREA)
- Quality & Reliability (AREA)
- Image Analysis (AREA)
Abstract
本发明公开了一种多视图儿童肿瘤病理图像分类的深度网络增量学习方法,该方法首先对儿童肿瘤病理图像特征进行建模,然后在不断获得新数据的情况下,训练得到增量学习模型,由此提高其准确性并获得鲁棒的分类能力。本发明的创新点在于针对医学病理图像的特点进行了设计,提出了一种多层级知识蒸馏正则化方法,以缓解增量学习中的旧模型在新数据上适应时的灾难性遗忘问题,增加算法的鲁棒性和可靠性。本发明的方法可以应用于医学图像处理领域以及其他需要增量学习的任务中。
Description
技术领域
本发明涉及计算机技术领域,涉及模式识别与机器学习技术,特别涉及一种深度网络增量学习方法。
背景技术
背景技术涉及四大块:知识蒸馏,增量学习,DetexNet模型,LwM算法。
1)知识蒸馏(Knowledge Distillation)
知识蒸馏是指从一个大模型中抽取出知识,然后将其传递给一个小模型,以提高小模型的性能。这种方法通常用于压缩模型的大小,提高模型的推理速度,并使模型更适合在边缘设备上运行。知识蒸馏已经成为了深度学习中的一个非常重要的技术,并在许多应用场景中被广泛应用。
知识蒸馏的核心思想是利用一个大型的预训练模型来提取出它所学到的知识,然后将这些知识传递给一个小型的目标模型。这些知识可以是各种形式的,例如网络中的中间层输出、特征映射、梯度等。一般来说,这些知识可以被看作是对目标函数的近似或近似答案。知识蒸馏通过将这些知识与目标模型的训练目标结合起来,来训练一个更小的模型。
具体来说,假设有一个大型的模型F和一个小型的模型f,知识蒸馏的目标是训练f,以使其在执行任务时与F相似。在知识蒸馏中,会将F的输出作为f的目标输出,同时使用F的中间层输出作为f的输入,这些中间层输出通常被称为软目标。知识蒸馏还会将原始训练数据作为f的输入,从而使f学习到更多的信息。
具体而言,知识蒸馏可以用以下公式表示:
其中,θ表示f的参数,L是原始训练数据的损失函数,x是原始训练数据的输入,y是输出标签,H是知识蒸馏的辅助损失函数,g(x)是大型模型F的输出。参数λ控制着两个损失函数之间的平衡。
知识蒸馏在增量学习中也可以得到广泛的应用。增量学习是一种机器学习方法,它允许模型动态地从新的数据中学习,并在不断增加的类别集上进行分类。在增量学习中,一个模型需要不断地学习新的类别,同时保持对已有类别的准确性。这个过程通常被称为持续学习或连续学习。
知识蒸馏可以用于增量学习中,以帮助模型保持对旧类别的记忆,并同时学习新类别。具体而言,可以使用知识蒸馏来训练一个新模型ft,该模型可以在训练新类别时,利用已有模型ft-1的知识,以保持对旧类别的准确性。在增量学习中,可以使用知识蒸馏的方法来传递ft-1中已经学习到的知识,从而加快训练过程和提高模型的准确性。
另外,知识蒸馏还可以用于在保持模型准确性的同时,压缩模型的大小,提高模型在边缘设备上的性能。这对于一些资源受限的应用场景非常重要。在增量学习中,知识蒸馏可以通过将已有模型的知识传递给新模型,来避免在每个新类别的训练中使用完整的模型,从而减少了模型大小和计算复杂度。
总之,知识蒸馏是一种非常重要的技术,它可以用于压缩深度学习模型的大小,提高模型的推理速度和在边缘设备上的性能。
2)增量学习(Incremental Learning)
增量学习(Incremental Learning)是指模型不断地在新数据上进行学习,而不是一次性将所有数据放在一起进行训练。它是机器学习领域的一种重要技术,广泛应用于在线学习、智能推荐、自然语言处理、计算机视觉等领域。
增量学习的优点在于可以及时处理新数据,避免了重复训练的过程,节省了计算资源,同时也可以在原有的基础上不断提升模型的性能。在增量学习中,常用的方法包括在线学习、增量式学习、混合式学习等。
在线学习(Online Learning)是指模型在接收到新的数据后,直接在当前模型的基础上进行更新,而不是重新训练整个模型。在线学习的核心是在线性学习算法,通过随机梯度下降等方法来逐步更新模型参数,适用于大规模数据的场景,能够快速处理新的数据,但可能会面临数据稳定性、过拟合等问题。
增量式学习(Incremental Learning)是指将原有的模型与新数据融合,重新训练整个模型。增量式学习需要选择合适的数据样本来进行增量训练,以避免原有模型被新数据“淹没”,同时也需要考虑如何保持原有模型的稳定性和泛化能力。
混合式学习(Hybrid Learning)是指将在线学习和增量式学习相结合的一种方法,通过在训练过程中动态地调整学习率、样本选择等参数,来平衡新旧数据的影响,提高模型的性能和稳定性。
在增量学习中,模型的性能评估通常采用交叉验证、准确率、召回率等指标来衡量。其中,交叉验证(Cross-validation)是指将数据集分为训练集和测试集,通过多次训练和测试来评估模型的性能。准确率(Precision)和召回率(Recall)是指分类模型中正确分类的样本占总体样本的比例和正确分类的正样本占所有正样本的比例。
在实际应用中,增量学习通常需要考虑数据的稳定性、安全性等问题。例如,如何防止模型被恶意攻击、如何保证数据隐私等都需要进行细致的设计和实现。此外,增量学习还可以结合深度学习、强化学习等技术进行应用,进一步提
3)DetexNet模型
DetexNet深度卷积神经网络模型结构设计基于著名的DenseNet,这个模型可以利用上传统的手工纹理特征提取器的专业的先验知识,并将其嵌入到深度卷积神经网络当中,所以理论上该模型可以发挥出出色和稳健的性能。
DetexNet中的TEM核表示DetexNet设计的单元,它们被放置于整个卷积网络的最底层,TEM核卷积中的内核的值通过Texture Energy Measure(TEM)特征算法计算得到。图像的每个通道的每个M×M滑动窗口的TEM特征由下面的公式给出:
其中的Wx,y是中心位于图像坐标(x,y)的滑动窗口,φ(·)是非线性函数,R(a,b)是掩码过滤器。形式上来说R(a,b)=p(a,b)m(a,b),将(a,b)位置的像素值p与滑动窗口中相应位置的掩码值m相乘。
并且通过以下公式进行归一化:
使用γ·表示新滤波器,并且设计了新的特征提取函数:
输入的图像为RGB图像(红色R,绿色G,和蓝色B),RGB图像会被分割成R、G、B三个通道分别送入三个相互独立的TEM核当中,以获取对应通道的底层纹理特征,随后提取得到的RGB纹理特征将被串联成一个特征图并送入更高的层,如图3中的DenseNet骨干网络中所示。DetexNet网络模型的最后层是一个带有全连接层的分类器,该分类器会输出最后分类结果的概率预测,从而从病理图像中提供预测的诊断结果。
4)LwM算法
Learning without Memorizing(LwM)是一种增量学习(IL)的方法,旨在提高训练模型的能力,使其能够识别更多的类别。该方法的关键问题是在教分类器学习新类别的同时,需要存储与现有类别相关的数据(例如图像)。然而,这是不切实际的,因为它会在每个增量步骤增加内存需求,这使得在内存有限的边缘设备上实现IL算法变得不可能。因此,LwM提出了一种保留现有(基础)类别信息的方法,而不存储它们的任何数据,同时使分类器逐步学习新类别。
LwM的核心思想是利用注意力蒸馏损失(L),来惩罚分类器的注意力图的变化,从而保留基础类别的信息,当新类别被添加时。注意力蒸馏损失(L)的定义如下:
其中N是批次大小,C是基础类别数目,A,t和A,s分别是目标网络和源网络在第i个样本和第j个基础类别上的注意力图。注意力图是通过对特征图进行全局平均池化得到的。LwM还使用了知识蒸馏损失(L),来保留源网络对基础类别的预测能力。知识蒸馏损失(L)的定义如下:
其中P,t和P,s分别是目标网络和源网络在第i个样本和第j个基础类别上的预测概率。
LwM的总损失函数为:
L=LCE+λLD+γLAD
其中L是交叉熵损失,用于训练目标网络对新类别的分类能力,λ和γ是超参数,用于平衡不同损失项的权重。
LwM算法应用到医学病理图像中的主要问题是:这种知识蒸馏损失在自然图像上往往工作地很好,但在病理图像上的表现欠佳。病理与自然图像不同,自然图像的特点是清晰明确的物体,而病理图像依靠细胞或组织的重复模式作为最小组成部分。因此,在反映这些图像的信息方面,神经网络模型的中层甚至浅层纹理等基本特征比高层抽象的物体特征更重要。
发明内容
本发明的目的是提供一种多视图儿童肿瘤病理图像分类的深度网络增量学习方法,该方法提出了一种新的病理图像多层级知识蒸馏正则方法,并将该机制应用到多视图儿童肿瘤病理图像分类中并给出了该方法的数学模型,同时给出了此模型的学习算法。本发明针对病理图像的特点设计了一种多层级知识蒸馏损失,可以充分地利用病理图像的底层纹理信息进行知识蒸馏正则化,提高增量学习算法在病理图像中的应用效果。
实现本发明目的的具体技术方案是:
一种多视图儿童肿瘤病理图像分类的深度网络增量学习方法,该方法包括如下步骤:
步骤一:预处理包含医学病理图像和对应的分类标签的数据集
原始数据为H&E染色的儿童肿瘤组织病理切片(WSI)扫描图片,将每个切片裁剪为若干不重叠子图像,每个子图像的大小为224*224;每个子图像都根据人为标注有一个对应的分类标签,构得所述的数据集;
步骤二:确定用于病理图像分类的模型结构
采用的病理图像分类的神经网络模型为DetexNet模型;DetextNet模型基于DensetNet网络结构,在DenseNet网络的最底层设计了一个TEM多视图特征提取器,通过TEM算法从输入的源图像提取R、G和B三个视图的特征,将特征按通道连接后送入DenseNet网络当中,最后通过一个Softmax激活层得到对应每个类别的分类相对概率,概率最大的类别为模型的分类结果;
步骤三:确定初始训练的优化目标
在增量学习中,训练数据是不断更新的,不同时间点会得到不同的新数据,每一次得到新数据都要对模型进行训练;在第一批次的数据到来时,对步骤二确定的模型结构进行实例化,并使用第一批次的数据对模型进行初始训练,训练使用的优化目标仅为多分类交叉熵损失,该损失由如下公式描述:
步骤四:确定增量训练的优化目标
所述模型的优化目标为多分类交叉熵损失和多层级的知识蒸馏正则化;在每当获得一批新数据时,使用优化目标对当前已有模型进行新数据的增量训练;将当前模型定义为教师模型Mt,将教师模型Mt的模型结构和参数复制一份作为可学习的学生模型Ms;优化目标函数包括学习当前新数据的分类的交叉熵损失和保留旧数据知识的多层级知识蒸馏正则损失;优化目标函数如下公式描述:
其中的Ms为学生模型,x为当前批次的新数据,σ表示网络尾部的Softmax激活函数,K表示类别数量,yk表示one-hot标签向量,s和t分别表示从学生模型网络和教师模型网络中不同的中间层抽取出来的中间特征向量,n表示提取的层数,pi表示对应层级的权重;在本轮新数据训练完成后,舍弃教师模型Mt,得到的学生模型Ms作为当前模型;
步骤五:训练深度网络增量学习模型
首先用第一批数据按照步骤三做一次初始训练,初始训练只使用常规多分类的交叉熵损失进行优化,每当获得一批新数据时,使用步骤四所述的增量学习步骤进行对新数据的训练;
在新数据增量训练当中应用回放策略:使用初始数据X0训练模型M0后,使用模型M0对已经训练过的数据X0提取特征并对特征进行聚类,选定N个最靠近聚类中心最具有代表性的样本构成集合P0;
接下来,当有一批新数据X1到来时,将P0混入X1,并使用混合数据对模型M0进行增量训练得到新模型M1;接着使用新模型M1对新数据X1进行特征提取,再使用聚类算法对特征进行聚类,选取N个最靠近聚类中心的样本作为/>最后,将/>和P0进行聚类,得到最具有代表性的N个样本集P1;每当有新的数据到来时,重复所述过程,使用P1代替P0进行后续增量学习;
每一轮新数据的增量训练根据步骤三的目标函数对各参数计算梯度更新模型,采用Adam算法优化模型的目标函数;具体训练包括如下步骤:
步骤a1:初始训练使用基于均匀分布的xaiver初始化方法,初始化神经网络的参数;
步骤a2:设置超参数,批次大小,最大迭代次数;
步骤a3:设置Adam优化器的学习率超参数,使用Adam算法更新模型的参数;
步骤a4:设置Early stop机制,以验证集上的结果为标准,当验证集的精度不再上升时,中止模型的训练;
步骤六:训练完毕后,进行模型推理,即对儿童肿瘤病理图像进行分类。
所述多层级的知识蒸馏正则化,其使用的病理图像分类的模型,具有如下结构特征:骨干网络模型由若干DenseBlock构成,经过每个DenseBlock的中间特征Xl会被保留下来用作后续知识蒸馏的计算;每个中间特征的计算用公式描述为:
Xl=Hl(X0,X1,...,Xl-1)
其中,Xl表示第l层DenseBlock得到的特征,Hl表示第l个DensBlock;
在病理图像分类的模型中,层级低的特征包含局部纹理信息,而层级高的特征则包含全局信息;在训练过程中,让学生模型Ms(Student Model)和教师模型Mt(TeacherModel)得到相同数量的中间特征s和t,并两两计算s和t之间的知识蒸馏损失,最后将得到的所有知识蒸馏损失进行加权求和;此外,通过调整不同层级之间的权重,来控制网络中不同层级之间的信息传递和组合方式,进一步提升模型的性能;多层级知识蒸馏损失的公式描述为:
其中,si和ti分别代表学生模型和教师模型中提取的多层级特征,K表示类别数量,n表示提取的层数,σ表示神经网络尾部的Softmax激活函数,pi表示对应层级的权重,权重之和为1;p0到pn的取值为从0.1到1的插值,并且归一化到:
本发明的有益效果包括:
1)与传统的方法相比,本发明的创新之处在于针对医学病理图像的图像特点设计了一种多层级知识蒸馏正则化方法,从而有效地缓解了病理图像增量学习中的灾难性遗忘问题,从而实现数据增量效果的提升;
2)本发明提高了医学病理图像分析软件在应用过程中对新数据的适应能力,降低了初始训练对数据量的要求,提高了分类的鲁棒性。医学分析软件训练算法的数据是由专业的医学专家所标注,标注成本很高,所以收集的数据量较小,无法覆盖所有情况,因此算法鲁棒性不够,难以适应各种情况。其次,在软件投入使用时,医生常常需要对结果进行修正,这一步修正所提供的标注信息没有进行有效的利用,利用本发明方法可以较好的解决这一问题。
附图说明
图1为本发明的流程图;
图2为本发明的优化目标结构图;
图3为DetexNet网络结构框架图;
图4为回放策略训练算法流程图。
具体实施方式
结合以下具体实施例和附图,对本发明作进一步的详细说明。实施本发明的过程、条件、实验方法等,除以下专门提及的内容之外,均为本领域的普遍知识和公知常识,本发明没有特别限制内容。
本发明的具体实施方式流程图如图1所示。
具体实施包括如下步骤:
1、预处理医学病理图像和对应的分类标签的数据集
处理原始的病理图像WSI切片扫描图数据,可以采用以下步骤:
切分整张图片:由于整张图片的文件体积非常大,需要将其切分为若干个子图像,以便于后续算法处理。可以使用开源的图像处理库,如OpenCV或Pillow,来实现图片的切分。
调整子图像大小:由于卷积神经网络的输入图像大小为224×224,需要将每个子图像的大小调整为224×224。可以使用图像处理库中的resize函数来实现。
为每个子图像分配标签:需要为每个子图像分配一个对应的分类标签。通过人工标注的方式进行。
数据增强:为了避免模型的过拟合,可以对每个子图像进行数据增强,包括旋转、平移、翻转等操作。这可以通过OpenCV或者torchvision图像处理库中的函数来实现。
划分训练集和测试集:为了评估模型的性能,需要将数据集划分为训练集和测试集。一般来说,会将数据集中的80%用于训练,20%用于测试。可以使用开源的Python库,如scikit-learn,来实现数据集的划分。
数据标准化:为了避免特征尺度的差异对模型的影响,需要对数据进行标准化。一般来说,会将每个像素的值减去数据集的均值,然后除以数据集的标准差。可以使用开源的Python库,如NumPy,来实现数据标准化。
数据加载器:最后,需要将处理好的数据集加载到模型中进行训练和测试。可以使用开源的Python库,如PyTorch中的DataLoader,来实现数据加载器,以便于模型能够从数据集中读取数据进行训练和测试。
2、确定用于多视图病理图像分类的神经网络模型
本发明采用所述的多视图病理分类神经网络模型,如图3所示,具有如下结构:病理图像被划分为RGB三个通道,每个通道分别经过一个TEM核(由TEM算法计算得出),得到三个不同视图的数据,再对其进行通道的合并,送入DenseNet;在DenseNet中,经过每个DenseBlock的中间特征Xl会被保留下来用作后续知识蒸馏的计算;DenseNet中的每个DenseBlock用公式描述为:
Xl=Hl(X0,X1,...,Xl-l)
其中的Xl第l层DenseBlock得到的特征,Hl表示第l个DensBlock。
在整个网络中,层级较低的特征主要包含局部纹理信息,而层级较高的特征则更多地包含全局信息。在训练过程中,会让学生模型和教师模型得到相同数量的中间特征,并计算它们之间的知识蒸馏损失,最后将它们进行加权求和。此外,通过调整不同层级之间的权重,来控制网络中不同层级之间的信息传递和组合方式,以进一步提升模型的性能。
3、确定本模型的优化目标
本发明的优化目标结构图如图2所示。本发明模型的主要优化策略为多分类交叉熵损失和多层级的知识蒸馏正则化;在每当获得一批新数据时,使用本优化目标对当前已有模型进行新数据的增量训练;将当前已在旧数据上训练完成的模型定义为教师模型Mt,将教师模型Mt的模型结构和参数复制一份作为可学习的学生模型Ms;优化目标函数由两部分组成,第一部分为学习当前新数据的分类的交叉熵损失,第二部分为保留旧数据知识的多层级知识蒸馏正则损失;优化目标函数如下公式描述:
其中的Ms为学生模型,x为当前批次的新数据,σ表示网络尾部的Softmax激活函数,K表示类别数量,y表示one-hot标签向量,s和t分别表示从学生模型网络和教师模型网络中不同的中间层抽取出来的中间特征向量,n表示提取的层数,pi表示对应层级的权重;在本轮新数据训练完成后,舍弃教师模型Mt,得到的学生模型Ms作为当前的已有模型。
4、训练模型
本发明模型首先要用第一批数据做一次初始训练,初始训练只使用常规多分类的交叉熵损失进行优化,每当获得一批新数据时,使用步骤三所述的增量学习步骤进行对新数据的训练;在新数据增量训练当中应用回放策略,参阅图4:使用初始数据X0训练模型M0后,使用模型M0对已经训练过的数据X0提取特征并对特征进行聚类,选定N个最靠近聚类中心最具有代表性的样本构成集合P0;接下来,当有一批新数据X1到来时,将P0混入X1,并使用混合数据对模型M0进行增量训练得到新模型M1;接着使用新模型M1对新数据X1进行特征提取,再使用聚类算法对特征进行聚类,选取N个最靠近聚类中心的样本作为/>最后,将/>和P0进行聚类,得到最具有代表性的N个样本集P1;每当有新的数据到来时,重复所述过程,使用P1代替P0进行后续增量学习;
本发明使用Pytorch作为代码实现框架,使用一台具有24GB内存的NVIDIA RTX3090GPU来训练的神经网络模型。为了获得最佳的训练结果,采用了每一批次数据都训练20K次迭代的训练策略,并将批次大小设置为32。这些数据在训练时使用了Adam优化器来优化神经网络模型的各个网络模块,以便获得更好的结果。此外,教师和学生网络的学习率设置为0.0004,以获得更高的准确性和更好的模型效果。
5、训练完毕后,进行模型推理,即病理图像分类
对于步骤1中提到的输入的WSI图像,需要进行一系列处理。首先,需要将其切分为若干互不重叠的子图像,子图像的大小被固定为224×224。这一步骤可以保证在之后的处理中每个小图像都能够被准确地处理。接下来,将这些子图像分批送入模型中进行推理。在推理的过程中,并不需要像训练过程一样做数据增强,但仍需要做数据的归一化处理。这样能够保证模型在进行推理时能够得到准确的结果。推理完成后,每个子图像都会得到一个预测类别。这些预测值可以被用来计算出每个子图像的置信度,从而进一步提高整体预测的准确性。最后,将所有子图像根据其位置进行叠加,得到的综合结果即为整体WSI的预测结果。通过这一系列的处理,能够有效地对WSI图像进行分类预测,并得到准确的结果。
实施例
本发明的实验数据为PNT7数据集,它是一个包含七个类别的外周神经母细胞瘤数据集。该数据集由某某儿童医学中心的163张病理切片组成,涵盖73个患者的病例,记录时间从2014年1月1日到2015年12月31日。每个类别包含8到30个患者,每个患者的片数从2到186不等。所有这些标本都是按照标准的组织学协议收集的,并使用ScanScope T2数字化仪进行数字化。为了使组织学切片图像包含足够的视觉信息供病理学家分类并且舒适地适应人类视觉,每个组织学切片图像被截成多个大小为768×768像素的非重叠图像块,并通过命名贴片标题记录标签。根据国际神经母细胞瘤病理分类的组织学标准,所有贴片都由一位资深儿科病理学家进行组织学标记并分为七个类别:GN、GNBi、UD、PD、D、NOS和UN。
表1不同算法在不同增量情况上的精度百分比
表1是本发明提出的方法和不同算法在不同增量情况上的精度百分比。实验的构造方式是:随机打乱了PNT7数据集,将其分成5个等量的数据集,然后在每个阶段都增加了一份(20%)新数据,以此来模拟数据增量的形式。表中LwF通过在训练过程中保留旧任务的知识来避免网络忘记以前学习的内容。具体来说,LwF使用反向传播算法来更新网络权重,同时使用Kullback-Leibler散度来衡量新旧任务之间的差异,并使用这些差异来调整权重更新的大小。LwM旨在实现无记忆的学习。相比传统的机器学习算法,LwM更加注重对数据的实时处理和分析,而不是依赖于先前的经验或记忆。这种算法可以更好地适应快速变化的环境,并具有更好的泛化能力。iCARL采用了重要样本选择和知识蒸馏的方法,能够有效地缓解类别间的遗忘和干扰问题,并且具有较好的泛化性能。它能够在不断接收新类别数据的情况下,不断地提升模型的准确性。微调是指在已经训练好的模型上,仅用新数据进行轻微的权重调整,以提高模型的准确性和泛化能力。这可以让模型更好地适应新数据的特点,从而提高对新数据的预测能力。在微调过程中,使用较小的学习率以避免对已经学习好的权重造成太大的干扰。联合训练同时使用所有数据来进行训练,包括旧数据和新数据,这种方法的效果最好,可以被视为增量学习的上限。本发明的方法在病理图像的数据增量学习任务上表现明显优于其他算法,最终达到了74.8%的精度,这表明了本发明方法的有效性。
本发明的保护内容不局限于以上实施例。在不背离发明构思的精神和范围下,本领域技术人员能够想到的变化和优点都被包括在本发明中,并且以所附的权利要求书为保护范围。
Claims (2)
1.一种多视图儿童肿瘤病理图像分类的深度网络增量学习方法,其特征在于,该方法包括如下步骤:
步骤一:预处理包含医学病理图像和对应的分类标签的数据集
原始数据为H&E染色的儿童肿瘤组织病理切片即WSI扫描图像,将每个切片裁剪为若干不重叠子图像,每个子图像的大小为224*224;每个子图像都根据人为标注有一个对应的分类标签,构得所述的数据集;
步骤二:确定用于病理图像分类的模型结构
采用的病理图像分类的神经网络模型为DetexNet模型;DetextNet模型基于DensetNet网络结构,在DenseNet网络的最底层设计了一个TEM多视图特征提取器,通过TEM算法从输入的源图像提取R、G和B三个视图的特征,将特征按通道连接后送入DenseNet网络当中,最后通过一个Softmax激活层得到对应每个类别的分类相对概率,概率最大的类别为模型的分类结果;
步骤三:确定初始训练的优化目标
在增量学习中,训练数据是不断更新的,不同时间点会得到不同的新数据,每一次得到新数据都要对模型进行训练;在第一批次的数据到来时,对步骤二确定的模型结构进行实例化,并使用第一批次的数据对模型进行初始训练,训练使用的优化目标仅为多分类交叉熵损失,该损失由如下公式描述:
步骤四:确定增量训练的优化目标
所述模型的优化目标为多分类交叉熵损失和多层级的知识蒸馏正则化;在每当获得一批新数据时,使用优化目标对当前已有模型进行新数据的增量训练;将当前模型定义为教师模型M,将教师模型M的模型结构和参数复制一份作为可学习的学生模型M;优化目标函数包括学习当前新数据的分类的交叉熵损失和保留旧数据知识的多层级知识蒸馏正则损失;优化目标函数如下公式描述:
其中的Ms为学生模型,x为当前批次的新数据,σ表示网络尾部的Softmax激活函数,K表示类别数量,yk表示one-hot标签向量,si和ti分别表示从学生模型网络和教师模型网络中不同的中间层抽取出来的中间特征向量,n表示提取的层数,pi表示对应层级的权重;在本轮新数据训练完成后,舍弃教师模型Mt,得到的学生模型Ms作为当前模型;
步骤五:训练深度网络增量学习模型
首先用第一批数据按照步骤三做一次初始训练,初始训练只使用常规多分类的交叉熵损失进行优化,每当获得一批新数据时,使用步骤四所述的增量学习步骤进行对新数据的训练;
在新数据增量训练当中应用回放策略:使用初始数据X0训练模型M0后,使用模型M0对已经训练过的数据X0提取特征并对特征进行聚类,选定N个最靠近聚类中心最具有代表性的样本构成集合P0;
接下来,当有一批新数据X1到来时,将P0混入X1,并使用混合数据对模型M0进行增量训练得到新模型M1;接着使用新模型M1对新数据X1进行特征提取,再使用聚类算法对特征进行聚类,选取N个最靠近聚类中心的样本作为/>最后,将/>和P0进行聚类,得到最具有代表性的N个样本集P1;每当有新的数据到来时,重复所述过程,使用P1代替P0进行后续增量学习;
每一轮新数据的增量训练根据步骤三的目标函数对各参数计算梯度更新模型,采用Adam算法优化模型的目标函数;具体训练包括如下步骤:
步骤a1:初始训练使用基于均匀分布的xaiver初始化方法,初始化神经网络的参数;
步骤a2:设置超参数,批次大小,最大迭代次数;
步骤a3:设置Adam优化器的学习率超参数,使用Adam算法更新模型的参数;
步骤a4:设置Early stop机制,以验证集上的结果为标准,当验证集的精度不再上升时,中止模型的训练;
步骤六:训练完毕后,进行模型推理,即对儿童肿瘤病理图像进行分类。
2.根据权利要求1所述的深度网络增量学习方法,其特征在于,所述多层级的知识蒸馏正则化,其使用的病理图像分类的模型,具有如下结构特征:骨干网络模型由若干DenseBlock构成,经过每个DenseBlock的中间特征Xl会被保留下来用作后续知识蒸馏的计算;每个中间特征的计算用公式描述为:
Xl=Hl(X0,X1,...,Xl-1)
其中,Xl表示第层DenseBlock得到的特征,Hl表示第l个DensBlock;
在病理图像分类的模型中,层级低的特征包含局部纹理信息,而层级高的特征则包含全局信息;在训练过程中,让学生模型M和教师模型M得到相同数量的中间特征s和t,并两两计算s和t之间的知识蒸馏损失,最后将得到的所有知识蒸馏损失进行加权求和;此外,通过调整不同层级之间的权重,来控制网络中不同层级之间的信息传递和组合方式,进一步提升模型的性能;多层级知识蒸馏损失的公式描述为:
其中,si和ti分别代表学生模型和教师模型中提取的多层级特征,K表示类别数量,n表示提取的层数,σ表示神经网络尾部的Softmax激活函数,pi表示对应层级的权重,权重之和为1;p0到pn的取值为从0.1到1的插值,并且归一化到:
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310356664.0A CN116363461A (zh) | 2023-04-04 | 2023-04-04 | 多视图儿童肿瘤病理图像分类的深度网络增量学习方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310356664.0A CN116363461A (zh) | 2023-04-04 | 2023-04-04 | 多视图儿童肿瘤病理图像分类的深度网络增量学习方法 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN116363461A true CN116363461A (zh) | 2023-06-30 |
Family
ID=86941707
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202310356664.0A Pending CN116363461A (zh) | 2023-04-04 | 2023-04-04 | 多视图儿童肿瘤病理图像分类的深度网络增量学习方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN116363461A (zh) |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117493889A (zh) * | 2023-12-27 | 2024-02-02 | 中国科学院自动化研究所 | 增量式持续学习方法、装置、存储介质和电子设备 |
-
2023
- 2023-04-04 CN CN202310356664.0A patent/CN116363461A/zh active Pending
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117493889A (zh) * | 2023-12-27 | 2024-02-02 | 中国科学院自动化研究所 | 增量式持续学习方法、装置、存储介质和电子设备 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN110674714B (zh) | 基于迁移学习的人脸和人脸关键点联合检测方法 | |
CN109886121B (zh) | 一种遮挡鲁棒的人脸关键点定位方法 | |
CN108647583B (zh) | 一种基于多目标学习的人脸识别算法训练方法 | |
CN105069400B (zh) | 基于栈式稀疏自编码的人脸图像性别识别系统 | |
US20160283842A1 (en) | Neural network and method of neural network training | |
CN108510194A (zh) | 风控模型训练方法、风险识别方法、装置、设备及介质 | |
CN110969086B (zh) | 一种基于多尺度cnn特征及量子菌群优化kelm的手写图像识别方法 | |
CN115017418B (zh) | 基于强化学习的遥感影像推荐系统及方法 | |
CN112150493A (zh) | 一种基于语义指导的自然场景下屏幕区域检测方法 | |
CN108764358A (zh) | 一种太赫兹图像识别方法、装置、设备及可读存储介质 | |
CN111368935B (zh) | 一种基于生成对抗网络的sar时敏目标样本增广方法 | |
CN116110022B (zh) | 基于响应知识蒸馏的轻量化交通标志检测方法及系统 | |
CN116363461A (zh) | 多视图儿童肿瘤病理图像分类的深度网络增量学习方法 | |
CN111833322A (zh) | 一种基于改进YOLOv3的垃圾多目标检测方法 | |
CN115409804A (zh) | 一种乳腺磁共振影像的病灶区域识别标注及疗效预测方法 | |
CN115035341A (zh) | 一种自动选择学生模型结构的图像识别知识蒸馏方法 | |
CN115100165A (zh) | 一种基于肿瘤区域ct图像的结直肠癌t分期方法及系统 | |
Rachmad et al. | Classification of mycobacterium tuberculosis based on color feature extraction using adaptive boosting method | |
CN115115828A (zh) | 数据处理方法、装置、程序产品、计算机设备和介质 | |
CN117636183A (zh) | 一种基于自监督预训练的小样本遥感图像分类方法 | |
CN110866866B (zh) | 图像仿色处理方法、装置、电子设备及存储介质 | |
CN111160161B (zh) | 一种基于噪声剔除的自步学习人脸年龄估计方法 | |
CN115018729B (zh) | 一种面向内容的白盒图像增强方法 | |
Chatzistamatis et al. | Image recoloring of art paintings for the color blind guided by semantic segmentation | |
CN116129189A (zh) | 一种植物病害识别方法、设备、存储介质及装置 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |