CN116258695A - 基于Transformer与CNN交互的半监督医学影像分割方法 - Google Patents
基于Transformer与CNN交互的半监督医学影像分割方法 Download PDFInfo
- Publication number
- CN116258695A CN116258695A CN202310129552.1A CN202310129552A CN116258695A CN 116258695 A CN116258695 A CN 116258695A CN 202310129552 A CN202310129552 A CN 202310129552A CN 116258695 A CN116258695 A CN 116258695A
- Authority
- CN
- China
- Prior art keywords
- cnn
- network
- branch
- transducer
- feature
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000000034 method Methods 0.000 title claims abstract description 64
- 238000003709 image segmentation Methods 0.000 title claims abstract description 29
- 230000003993 interaction Effects 0.000 title claims abstract description 25
- 239000011159 matrix material Substances 0.000 claims abstract description 41
- 230000011218 segmentation Effects 0.000 claims abstract description 40
- 238000012549 training Methods 0.000 claims description 47
- 238000010586 diagram Methods 0.000 claims description 30
- 230000006870 function Effects 0.000 claims description 20
- 238000004364 calculation method Methods 0.000 claims description 14
- 230000008569 process Effects 0.000 claims description 12
- 230000004927 fusion Effects 0.000 claims description 9
- 238000000605 extraction Methods 0.000 claims description 7
- 230000007246 mechanism Effects 0.000 claims description 6
- 230000004913 activation Effects 0.000 claims description 5
- 239000000284 extract Substances 0.000 claims description 5
- 230000007704 transition Effects 0.000 claims description 5
- 239000000758 substrate Substances 0.000 claims description 4
- 238000005457 optimization Methods 0.000 claims description 3
- 101150064138 MAP1 gene Proteins 0.000 claims description 2
- 101150077939 mapA gene Proteins 0.000 claims description 2
- 238000012805 post-processing Methods 0.000 claims description 2
- 239000000126 substance Substances 0.000 claims 1
- 238000012935 Averaging Methods 0.000 description 4
- 230000000747 cardiac effect Effects 0.000 description 3
- 238000002059 diagnostic imaging Methods 0.000 description 3
- 230000000694 effects Effects 0.000 description 2
- 238000005516 engineering process Methods 0.000 description 2
- 238000010191 image analysis Methods 0.000 description 2
- 238000002372 labelling Methods 0.000 description 2
- 230000003902 lesion Effects 0.000 description 2
- 210000000056 organ Anatomy 0.000 description 2
- 238000012216 screening Methods 0.000 description 2
- NAWXUBYGYWOOIX-SFHVURJKSA-N (2s)-2-[[4-[2-(2,4-diaminoquinazolin-6-yl)ethyl]benzoyl]amino]-4-methylidenepentanedioic acid Chemical compound C1=CC2=NC(N)=NC(N)=C2C=C1CCC1=CC=C(C(=O)N[C@@H](CC(=C)C(O)=O)C(O)=O)C=C1 NAWXUBYGYWOOIX-SFHVURJKSA-N 0.000 description 1
- 230000009286 beneficial effect Effects 0.000 description 1
- 238000006243 chemical reaction Methods 0.000 description 1
- 238000011161 development Methods 0.000 description 1
- 230000018109 developmental process Effects 0.000 description 1
- 201000010099 disease Diseases 0.000 description 1
- 208000037265 diseases, disorders, signs and symptoms Diseases 0.000 description 1
- 238000003384 imaging method Methods 0.000 description 1
- 210000005246 left atrium Anatomy 0.000 description 1
- 239000000203 mixture Substances 0.000 description 1
- 238000012544 monitoring process Methods 0.000 description 1
- 230000035790 physiological processes and functions Effects 0.000 description 1
- 238000011160 research Methods 0.000 description 1
- 210000005245 right atrium Anatomy 0.000 description 1
- 238000012706 support-vector machine Methods 0.000 description 1
- 238000012360 testing method Methods 0.000 description 1
- 238000012795 verification Methods 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06T—IMAGE DATA PROCESSING OR GENERATION, IN GENERAL
- G06T7/00—Image analysis
- G06T7/0002—Inspection of images, e.g. flaw detection
- G06T7/0012—Biomedical image inspection
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06T—IMAGE DATA PROCESSING OR GENERATION, IN GENERAL
- G06T7/00—Image analysis
- G06T7/10—Segmentation; Edge detection
- G06T7/11—Region-based segmentation
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06T—IMAGE DATA PROCESSING OR GENERATION, IN GENERAL
- G06T2207/00—Indexing scheme for image analysis or image enhancement
- G06T2207/20—Special algorithmic details
- G06T2207/20081—Training; Learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06T—IMAGE DATA PROCESSING OR GENERATION, IN GENERAL
- G06T2207/00—Indexing scheme for image analysis or image enhancement
- G06T2207/20—Special algorithmic details
- G06T2207/20084—Artificial neural networks [ANN]
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06T—IMAGE DATA PROCESSING OR GENERATION, IN GENERAL
- G06T2207/00—Indexing scheme for image analysis or image enhancement
- G06T2207/30—Subject of image; Context of image processing
- G06T2207/30004—Biomedical image processing
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02T—CLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
- Y02T10/00—Road transport of goods or passengers
- Y02T10/10—Internal combustion engine [ICE] based vehicles
- Y02T10/40—Engine management systems
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- General Physics & Mathematics (AREA)
- Computer Vision & Pattern Recognition (AREA)
- General Health & Medical Sciences (AREA)
- Health & Medical Sciences (AREA)
- Artificial Intelligence (AREA)
- Data Mining & Analysis (AREA)
- Radiology & Medical Imaging (AREA)
- Nuclear Medicine, Radiotherapy & Molecular Imaging (AREA)
- Life Sciences & Earth Sciences (AREA)
- Medical Informatics (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Quality & Reliability (AREA)
- Evolutionary Computation (AREA)
- Molecular Biology (AREA)
- Computing Systems (AREA)
- General Engineering & Computer Science (AREA)
- Mathematical Physics (AREA)
- Software Systems (AREA)
- Image Analysis (AREA)
Abstract
本发明公开了一种基于Transformer与CNN交互的半监督医学影像分割方法,包括:本发明将Transformer引入到医学影像半监督分割框架中,并设计了Transformer分支与CNN分支之间特征交互的模块C2T模块和T2C模块,实现两分支之间高效的知识共享,提高分割网络同时捕捉细节特征和建立全局依赖关系的能力;同时增加了特征一致性分布约束,利用Teacher模型的Transformer分支的特征协方差矩阵去约束Student的CNN分支特征协方差矩阵,同样Teacher模型的CNN分支的特征协方差矩阵去约束Transformer的特征协方差矩阵,通过这种交叉教学的方式,使得引入Transformer分支之后的半监督框架更加稳定,同时产生更加准确的伪标签。
Description
技术领域
本发明属于医学影像分析领域,具体设计一种基于Transformer与CNN交互的半监督医学影像分割方法。
背景技术
近年来,随着医学影像技术的发展和普及,医学影像为医生进行疾病诊断提供了重要的参考信息。医学影像技术利用不同的成像原理和特制的设备,以非侵入方式获取病人内部组织结构信息,可以查看病人的生理机能状况。医学影像分割是影像辅助治疗中关键的一步,它是从医学影像中识别出病变器官的像素点,获取这些病变的位置、大小等信息,是一项具有一定技术难度的医学影像分析任务。在现实应用场景中,由于医学影像中的不同器官之间以及不同组织之间的边界很难明显区分,导致在标注过程中存在较大的不确定性和主观性,因此很难获得大量完整标注的、精确度高的医学影像数据集。除此之外,由于医学影像数据往往涉及病患隐私,导致获取渠道和数量会比较受限。因此如何充分利用大量的未标注数据是当前医学影像分割丞待解决的问题。
半监督学习则是基于如何通过利用少量的标注数据和大量的未标注数据来训练网络提升模型性能的重要的研究方向。总体来说,现有的半监督医学影像分割框架主要包括:基于CNN的半监督医学影像分割方法和基于Transformer结构的半监督医学影像分割方法两个大类。由于CNN往往只注重局部特征的提取,忽略了上下文信息,而Transformer则具有建立全局依赖关系的能力,因此如何将Transformer结构引入到半监督学习中成为了一个重要的研究课题。论文《Semi-Supervised Medical Image Segmentation via CrossTeaching between CNN and Transformer》中提出搭建了基于交叉教学框架的网络,分别让Transformer产生的预测作为CNN的伪标签以及CNN产生的预测作为Transformer的伪标签,在ACDC等数据集上证明优于现有的半监督框架。如中国专利CN114882047 A,公开日为2022年8月9日,提出搭建CNN和Transformer混合的U型分割网络,通过在Transformer结构中加入残差模块并且在跳跃连接处使用支持向量机的方式对信息进行进一步的筛选和简化,极大的提升了缺少标注数据的医学影像的分割准确性。
发明内容
本发明提供了一种基于Transformer与CNN交互的半监督医学影像分割方法。
基于Transformer与CNN交互的半监督医学影像分割方法的过程为:首先利用有标签的数据进行训练为学生网络提供初始化参数;其次将学生网络参数复制给教师网络,并利用教师网络为无标签数据提供伪标签;此后对无标签数据进行数据增强;有标签数据和增强后的无标签数据同时输入至学生网络,学生网络利用有标签数据优化网络参数,提取无标签数据特征并输出其预测分布。增强后的无标签数据输入至教师网络,提取无标签数据特征并输出其预测分布。此外引入一致性损失约束来自教师网络和学生网络的无标签数据特征,并计算教师网络预测输出的信息熵,通过设定信息熵阈值以提升伪标签置信度,利用伪标签和学生网络的预测分布计算半监督损失。最后,学生网络将其网络参数通过指数移动平均传递给教师网络以进行进一步迭代训练。
一种基于Transformer与CNN交互的半监督医学影像分割方法,包括以下步骤:
(1)教师学生网络构建。教师和学生网络都由一个CNN分支和一个Transformer分支组成。并于每个阶段引入C2T(CNN to Transformer)和T2C(Transformer to CNN)模块以实现两个分支提取特征的交互。
以初始化学生网络参数S(θ),并将其参数复制给教师网络T(θ)。
(4)数据增强。利用CutMix数据增强方法,为无标签的数据进行数据增强。
(5)学生网络优化及预测。将有标签数据和增强后的无标签数据同时输入至学生网络,学生网络利用有标签数据更新网络参数,提取无标签数据特征FS并输出其预测分布PS。
(6)特征一致性约束。将增强后的无标签数据输入至教师网络,提取无标签数据特征FT并输出其预测分布PT。将教师网络提取的特征与学生网络进行一致性约束。
(8)整体损失优化。训练上述网络时,对总体损失进行优化以及反向梯度传播,以更新网络参数。
(9)网络参数传递。网络参数更新后,学生网络将其网络参数通过指数移动平均传递给教师网络以进行进一步迭代训练。
(10)目标图像分割。对于给定目标域图像,分割模型输出目标图像每个像素点所属类别概率,选取最大概率的类别作为该像素点预测类别。
步骤1所述的Transformer和CNN分支交互过程(C2T模块)为:
(i)在每个阶段将CNN分支学习到的特征通过C2T模块传递给Transformer分支,记学生网络中CNN分支第i个阶段输出的特征图为通过全连接层预测的概率图为/>其中/>的维度为B×D×H×W,/>的维度为B×Nc×H×W,B为训练批次的大小,H,W分别为特征图的高和宽,D为特征图的通道数,Nc为类别数,通过两者进行加权平均,得到各个类别的中心/>
(ii)对于学生网络中Transformer分支第i个阶段,其输入特征图为分别经过三个全连接层WQ,WK,WV计算Query(Q),Key(K),Value(V),Q、K、V表示为三个不同的特征图,并且计算得到Q与K的相似度矩阵SQ_K;
(iv)通过步骤(iii)得到的相似度矩阵SQ_cls和相似度矩阵SK_cls融合了CNN分支的类别中心的信息,将该信息传递给Transformer分支相应的层来优化矩阵SQ_K,得到融合后的相似度矩阵为将该相似度矩阵与特征图V相乘得到增强后的特征FC2T,最终与CNN分支第i个阶段输出的特征/>相加得到Transformer分支第i个阶段输出的特征图/>
步骤(ii)中,Q与K的相似度矩阵SQ_K计算为:
其中,D为特征图的通道数。
其中为学生网络Transformer分支第i个阶段融合来自CNN分支的信息后最终输出的特征图,/>为学生网络中CNN分支第i个阶段的输出特征图,FC2T为增强后的特征图,其中,/>为融合后的相似度矩阵,V为特征图。/>
步骤1中,Transformer分支通过T2C模块与CNN分支交互,具体包括:
(i)学生网络在每个阶段将Transformer学习到的特征通过T2C模块传递给CNN分支,记学生网络中Transformer分支第i个阶段输出的特征图为其中/>的维度为B×N×D,N为序列长度,并且有N=H×W,特征图经过展平和形状变换得到维度B×D×H×W,B为训练批次大小,D为特征通道数,H,W分别为特征图的高和宽,通过全连接层预测的分割概率图为/>其维度为B×Nc×H×W,Nc为类别数,通过两者进行加权平均,得到各个类别的中心/>
其中,为学生网络中的Transformer分支的第i个阶段通过全连接层预测的分割概率图索引位置j上的类别概率,/>为学生网络中的Transformer分支的第i个阶段输出的特征图索引位置j对应的特征向量,H,W分别为特征图的高和宽;
步骤(ii)中,计算交叉注意力机制得到增强后的特征FT2C,具体包括:
其中WQ,WK,WV为三个全连接层,为学生网络中的CNN分支的第i个阶段输出的特征图,D为特征图的通道数,softmax(·)为激活函数,/>为步骤(i)中计算得到的各个类别的中心,Q,K,V为计算得到的三个不同的特征图;
步骤2所述的初始化学生网络和将学生网络的参数复制给教师网络具体过程为:
(ii)训练结束后将其参数复制给教师网络T(θ)。
表示模型训练输入的批次大小为/>的有标签数据,lce表示交叉熵损失函数,/>表示第i张有标签图像的真实标签,/>为第i张有标签的图像。交叉熵损失函数可以表示为:/>其中H,W为预测输出分割概率图的高和宽,yi表示图像上位置i处的真实类别,pi为位置i处的预测类别概率。
步骤3所述的得到无标签数据的伪标签具体过程为:
步骤4中所述的数据增强方法CutMix具体步骤为:
步骤5将有标签数据和增强后的无标签数据同时输入至学生网络,学生网络利用有标签数据更新网络参数,提取无标签数据特征FS并输出其预测分布PS。
(ii)学生网络中的两个分支分别提取无标签数据特征FS-T和FS-C并输出其预测分布PS-T和PS-C,最终以二者均值作为学生网络的预测分布PS。
表示模型训练输入的批次大小为/>的有标签数据,lce表示交叉熵损失函数,可表示为:/>其中H,W为预测输出分割概率图的高和宽,ym表示图像上位置m处的真实类别,pm为位置m处的预测概率。/>表示将第i张有标签的图像输入到学生网络后的预测分割结果,/>表示第i张有标签图像的标签。
步骤(ii)中预测分布PS计算过程为:
步骤6将增强后的无标签数据输入至教师网络,提取无标签数据特征FT并输出其预测分布PT。将教师网络提取的特征与学生网络进行一致性约束,具体步骤为:
(i)教师网络提取无标签数据特征FT并输出其预测分布PT。
(ii)将教师网络Transformer分支和CNN分支提取的特征FT-T,FT-C,和学生网络中Transformer分支和CNN分支提取的特征FS-T,FS-C分别与其自身转置相乘获得协方差特征图
步骤(i)中预测分布PT计算过程为:
步骤(ii)中协方差矩阵的计算过程为:
FΣ=F·(F)T
其中MSE为计算平方差,即MSE(x,y)=x2-y2。
步骤(i)中计算教师模型的预测分布的信息熵:
其中γ为设定的阈值,为计算得到的熵。/>表示为第i张无标签图像的伪标签的第j个位置上的第c个类别的预测值。argmaxc(·)表示计算最大值对应的类别索引。/>为经过筛选后的第i张无标签图像第j个像素点的伪标签。
其中,λ1和λ2为超参数,用于平衡两种损失对总损失的影响。
步骤9中,网络参数更新后,学生网络将其网络参数通过指数移动平均(EMA)传递给教师网络以进行进一步迭代训练,具体过程为:
T(θ)=EMA(S(θ))
与现有技术相比,本发明的有益效果:
(1)相比较于其他方法提到的基于Transformer的半监督医学影像分割方法仅仅在最后预测结果进行交互,本方法可以将CNN提取的局部特征与Transformer捕捉的全局特征进行更加充分的融合,提升了网络性能。
(2)相比较于其他方法提到的基于Transformer的半监督医学影像分割方法通常计算像素级别的损失,本方法改进为计算协方差矩阵的损失,增强了模型的鲁棒性。
(3)本方法同时分别用学生网络中的两个分支和教师网络中的两个分支进行交叉教学,能产生更加稳定准确的伪标签。
附图说明
图1为基于Transformer与CNN交互的半监督医学影像分割方法整体结构图;
图2为Transformer分支和CNN分支在每个阶段交互的C2T模块结构图;
图3为Transformer分支和CNN分支在每个阶段交互的T2C模块结构图;
图4为基于特征分布一致性损失约束图;
图5为本发明基于Transformer与CNN交互的半监督医学影像分割方法的流程示意图;
图6为本发明在心脏分割数据集(ACDC)上的分割结果图;
图7为本发明在皮肤病分割数据集(ISIC)上的分割结果图。
具体实施方式
本发明方法的整体框架如图1和图5所示,包括以下步骤:
1、教师学生网络构建。教师和学生网络都由一个CNN分支和一个Transformer分支组成,并于每个阶段引入C2T(CNN to Transformer)和T2C(Transformer to CNN)模块以实现两个分支提取特征的交互。
4、数据增强。利用CutMix数据增强方法,为有标签和无标签的数据进行数据增强。
5、学生网络优化及预测。将增强后的有标签和无标签数据同时输入至学生网络,学生网络利用有标签数据更新网络参数,提取无标签数据特征FS并输出其预测分布PS。
6、特征一致性约束。将增强后的无标签数据输入至教师网络,提取无标签数据特征FT并输出其预测分布PT。将教师网络提取的特征与学生网络进行一致性约束。
8、整体损失优化。训练上述网络时,对总体损失进行优化以及反向梯度传播,以更新网络参数。
9、网络参数传递。网络参数更新后,学生网络将其网络参数通过指数移动平均传递给教师网络以进行进一步迭代训练。
10、目标图像分割。对于给定目标域图像,分割模型输出目标图像每个像素点所属类别概率,并选择概率最大的对应的类别作为该像素点的预测类别。
步骤1中,教师和学生网络都由一个CNN分支和一个Transformer分支组成。并于每个阶段引入C2T(CNN to Transformer)和T2C(Transformer to CNN)模块以实现两个分支提取特征的交互。
具体而言:
(i)C2T模块
在每个阶段将CNN分支学习到的特征通过C2T模块传递给Transformer分支,记学生网络中CNN分支第i个阶段输出的特征图为通过全连接层预测的概率图为/>通过两者进行加权平均,可以得到各个类别的中心/>
分别计算公式(2)中的Q,K和通过公式(1)计算得到的类别中心的相似度得到相似度矩阵SQ_cls和相似度矩阵SK_cls。
其中D为特征图的通道数。
其中为学生网络Transformer分支第i个阶段融合来自CNN分支的信息后最终输出的特征图,/>为学生网络中CNN分支第i个阶段的输出特征图,FC2T为增强后的特征图,其中,/>为融合后的相似度矩阵,V为特征图。
(ii)T2C模块
学生网络在每个阶段将Transformer学习到的特征通过T2C模块传递给CNN分支,记学生网络中Transformer分支第i个阶段输出的特征图为该特征图经过展平和形状变换得到维度B×D×H×W,B为训练批次大小,D为特征通道数,H,W分别为特征图的高和宽。为通过全连接层预测的分割概率图,其维度为B×Nc×H×W,Nc为类别数。通过两者进行加权平均,可以得到各个类别的中心/>
其中,为学生网络中的Transformer分支的第i个阶段通过全连接层预测的分割概率图索引位置j上的类别概率,/>为学生网络中的Transformer分支的第i个阶段输出的特征图索引位置j对应的特征向量,H,W分别为特征图的高和宽;
表示模型训练输入的批次大小为/>的有标签数据,lce表示交叉熵损失函数,/>表示第i张有标签图像的真实标签,/>为第i张有标签的图像。交叉熵损失函数可以表示为:/>其中H,W为预测输出分割概率图的高和宽,yi表示图像上位置i处的真实类别,pi为位置i处的预测类别概率。
步骤4中,利用CutMix数据增强方法,为无标签的数据进行数据增强,如公式(11)所示:
步骤5中,将有标签数据和增强后的无标签数据同时输入至学生网络,学生网络利用有标签数据更新网络参数,提取无标签数据特征FS并输出其预测分布PS。具体而言:
表示模型训练输入的批次大小为/>的有标签数据,lce表示交叉熵损失函数,可表示为:/>其中H,W为预测输出分割概率图的高和宽,ym表示图像上位置m处的真实类别,pm为位置m处的预测概率。/>表示将第i张有标签的图像输入到学生网络后的预测分割结果,/>表示第i张有标签图像的标签。
5.2学生网络中的两个分支分别提取无标签数据特征FS-T和FS-C并输出其预测分布PS-T和PS-C,最终以二者均值作为学生网络的预测分布PS如公式(13)所示:
步骤6中,将增强后的无标签数据输入至教师网络,提取无标签数据特征FT并输出其预测分布PT。将教师网络提取的特征与学生网络进行一致性约束。具体而言:
6.1教师网络提取无标签数据特征FT并输出其预测分布PT,如公式(14)所示:
6.2将教师网络和学生网络提取的特征FT-T,FT-C,FS-T,FS-C与其自身转置相乘获得协方差特征图如公式(15)所示;教师网络和学生网络对应分支的协方差特征图利用MSE损失函数进行一致性约束,并计算一致性损失/>以提升模型的鲁棒性及泛化能力,如公式(16)所示。
其中MSE为计算平方差,即MSE(x,y)=x2-y2。
7.1计算教师模型的预测分布的信息熵,如公式(17)所示;并设定选择阈值对伪标签进行过滤选择,以获得高置信度的预测输出,如公式(18)所示:
其中γ为设定的阈值,为计算得到的熵。/>表示为第i张无标签图像的伪标签的第j个位置上的第c个类别的预测值。argmaxc(·)表示计算最大值对应的类别索引。/>为经过筛选后的第i张无标签图像第j个像素点的伪标签。
其中,λ1和λ2为超参数,用于平衡两种损失对总损失的影响。
步骤9中,网络参数更新后,学生网络将其网络参数通过指数移动平均(EMA)传递给教师网络以进行进一步迭代训练,如公式(21)所示:
T(θ)=EMA(S(θ)) (21)
步骤10中,所述的目标图像分割,给定目标域图像数据,分割模型为每个像素点预测其类别概率,选择预测概率最大的类别作为预测类别,以获得最终的分割掩码。
进一步地,如图1所示,一种基于Transformer与CNN交互的半监督医学影像分割方法具体过程如下:
1、教师学生网络构建。教师和学生网络都由一个CNN分支和一个Transformer分支组成。并于每个阶段引入C2T(CNN to Transformer)如图2所示,具体步骤为:
(i)在每个阶段将CNN分支学习到的特征通过C2T模块传递给Transformer分支,记学生网络中CNN分支第i个阶段输出的特征图为通过全连接层预测的概率图为/>通过两者进行加权平均,可以得到各个类别的中心/>如下所示:
(ii)对于学生网络中Transformer分支第i个阶段,其输入特征图为分别经过三个全连接层WQ,WK,WV计算得到三个不同特征图Query(Q),Key(K),Value(V),并且计算可得Q与K的相似度矩阵SQ_K。
(iii)分别计算步骤(ii)中的Q,K和通过步骤(i)计算得到的类别中心的相似度得到相似度矩阵SQ_cls和相似度矩阵SK_cls。
其中D为特征图的通道数。
其中为学生网络Transformer分支第i个阶段融合来自CNN分支的信息后最终输出的特征图,/>为学生网络中CNN分支第i个阶段的输出特征图,FC2T为增强后的特征图,其中,/>为融合后的相似度矩阵,V为特征图。
同样,在每个阶段引入T2C模块,如图3所示,具体步骤为:
(i)学生网络在每个阶段将Transformer学习到的特征通过T2C模块传递给CNN分支,记学生网络中Transformer分支第i个阶段输出的特征图为该特征图经过展平和形状变换得到维度B×D×H×W,B为训练批次大小,D为特征通道数,H,W分别为特征图的高和宽。/>为通过全连接层预测的分割概率图,其维度为B×Nc×H×W,Nc为类别数。通过两者进行加权平均,可以得到各个类别的中心/>
其中,为学生网络中的Transformer分支的第i个阶段通过全连接层预测的分割概率图索引位置j上的类别概率,/>为学生网络中的Transformer分支的第i个阶段输出的特征图索引位置j对应的特征向量,H,W分别为特征图的高和宽;
4、数据增强。利用CutMix数据增强方法,为无标签的数据进行数据增强。
5、学生网络优化及预测。将有标签和增强后的无标签数据同时输入至学生网络,学生网络利用有标签数据更新网络参数,提取无标签数据特征FS并输出其预测分布PS。
6、特征一致性约束。将增强后的无标签数据输入至教师网络,提取无标签数据特征FT并输出其预测分布PT。将教师网络提取的特征与学生网络进行一致性约束,如图4所示,具体过程为:
将教师网络和学生网络提取的特征FT-T,FT-C,FS-T,FS-C与其自身转置相乘获得协方差特征图教师网络和学生网络对应分支的协方差特征图利用MSE损失函数进行一致性约束,并计算一致性损失/>以提升模型的鲁棒性及泛化能力,如下公式所示。
FΣ=F·(F)T
其中MSE为计算平方差,即MSE(x,y)=x2-y2。
8、整体损失优化。训练上述网络时,对总体损失进行优化以及反向梯度传播,以更新网络参数。
9、网络参数传递。网络参数更新后,学生网络将其网络参数通过指数移动平均传递给教师网络以进行进一步迭代训练。
10、目标图像分割。对于给定目标域图像,分割模型输出目标图像每个像素点所属类别概率,并选取概率最大的对应的类别作为该点的预测类别。
本发明采用两个多中心的公开数据集(皮肤病分割数据集ISIC和心脏分割数据集ACDC)对本发明的性能进行评估。皮肤病分割数据集ISIC包含两个类别,共2594张图像,选择其中的1838张图像用于训练,剩余的756张用于验证模型。心脏分割数据集ACDC包含三个类别(左心房、右心房、心室壁),包含100个病人的图像,选择其中的70个病人的影像用来训练,10个用来验证,20个用来测试。对于半监督任务来说,分别选择训练集中的3%和10%作为有标签数据,剩余部分作为无标签数据。我们和目前性能最好的半监督框架CTCT分别在两个数据集上进行了对比,其中真实标签为领域专家所标注,分割结果如图6和图7所示。可以发现,本发明在两个数据集上,模型的泛化能力和分割性能均优于现有的表现最好的模型结构,验证了本发明的性能。
Claims (10)
1.一种基于Transformer与CNN交互的半监督医学影像分割方法,其特征在于,包括以下步骤:
(1)教师学生网络构建:教师网络和学生网络都由CNN分支和Transformer分支结合而成,CNN分支提取医学影像特征和Transformer分支提取医学影像特征并进行交互,所述的CNN分支通过C2T模块与Transformer分支交互,所述的Transformer分支通过T2C模块与CNN分支交互;
(2)学生网络训练:利用有标签的数据训练学生网络,以初始化学生网络参数,并将参数复制给教师网络;
(3)生成伪标签:将医学影像中无标签的数据输入至教师网络,以教师网络的预测输出作为无标签的数据的伪标签;
(4)数据增强:利用数据增强方法,为无标签的数据进行数据增强;
(5)学生网络优化及预测:将有标签数据和增强后的无标签数据同时输入至学生网络,学生网络利用有标签数据更新网络参数,提取无标签数据特征FS并输出其预测分布PS;
(6)特征一致性约束:将增强后的无标签数据输入至教师网络,提取无标签数据特征FT并输出其预测分布PT,将教师网络提取的特征与学生网络提取的特征进行一致性约束;
(8)整体损失优化:训练上述网络时,对总体损失进行优化以及反向梯度传播,以更新网络参数:
(9)网络参数传递:网络参数更新后,学生网络将其网络参数通过指数移动平均传递给教师网络以进行进一步迭代训练;
(10)目标图像分割:对于给定目标域图像,分割模型输出目标图像每个像素点所属类别概率,选取最大概率的类别作为该像素点预测类别。
2.根据权利要求1所述的基于Transformer与CNN交互的半监督医学影像分割方法,其特征在于,步骤1中,CNN分支通过C2T模块与Transformer分支交互,具体包括:
(i)在每个阶段将CNN分支学习到的特征通过C2T模块传递给Transformer分支,记学生网络中CNN分支第i个阶段输出的特征图为通过全连接层预测的概率图为/>其中/>的维度为B×D×H×W,/>的维度为B×Nc×H×W,B为训练批次的大小,H,W分别为特征图的高和宽,D为特征图的通道数,Nc为类别数,通过两者进行加权平均,得到各个类别的中心/>
(ii)对于学生网络中Transformer分支第i个阶段,其输入特征图为分别经过三个全连接层WQ,WK,WV计算Query(Q),Key(K),Value(V),Q、K、V表示为三个不同的特征图,并且计算得到Q与K的相似度矩阵SQ_K;
7.根据权利要求1所述的基于Transformer与CNN交互的半监督医学影像分割方法,其特征在于,步骤1中,Transformer分支通过T2C模块与CNN分支交互,具体包括:
(i)学生网络在每个阶段将Transformer学习到的特征通过T2C模块传递给CNN分支,记学生网络中Transformer分支第i个阶段输出的特征图为其中/>的维度为B×N×D,N为序列长度,并且有N=H×W,特征图经过展平和形状变换得到维度B×D×H×W,B为训练批次大小,D为特征通道数,H,W分别为特征图的高和宽,通过全连接层预测的分割概率图为/>其维度为B×Nc×H×W,Nc为类别数,通过两者进行加权平均,得到各个类别的中心/>
其中,为学生网络中的Transformer分支的第i个阶段通过全连接层预测的分割概率图索引位置j上的类别概率,/>为学生网络中的Transformer分支的第i个阶段输出的特征图索引位置j对应的特征向量,H,W分别为特征图的高和宽;
步骤(ii)中,计算交叉注意力机制得到增强后的特征FT2C,具体包括:
9.根据权利要求1所述的基于Transformer与CNN交互的半监督医学影像分割方法,其特征在于,步骤6中,将增强后的无标签数据输入至教师网络,提取无标签数据特征FT并输出其预测分布PT,将教师网络提取的特征与学生网络提取的特征进行一致性约束,具体步骤为:
(i)教师网络提取无标签数据特征FT并输出其预测分布PT;
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310129552.1A CN116258695A (zh) | 2023-02-03 | 2023-02-03 | 基于Transformer与CNN交互的半监督医学影像分割方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310129552.1A CN116258695A (zh) | 2023-02-03 | 2023-02-03 | 基于Transformer与CNN交互的半监督医学影像分割方法 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN116258695A true CN116258695A (zh) | 2023-06-13 |
Family
ID=86678835
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202310129552.1A Pending CN116258695A (zh) | 2023-02-03 | 2023-02-03 | 基于Transformer与CNN交互的半监督医学影像分割方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN116258695A (zh) |
Cited By (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116741372A (zh) * | 2023-07-12 | 2023-09-12 | 东北大学 | 一种基于双分支表征一致性损失的辅助诊断系统及装置 |
CN116862931A (zh) * | 2023-09-04 | 2023-10-10 | 北京壹点灵动科技有限公司 | 医学图像分割方法、装置、存储介质及电子设备 |
CN117253044A (zh) * | 2023-10-16 | 2023-12-19 | 安徽农业大学 | 一种基于半监督交互学习的农田遥感图像分割方法 |
-
2023
- 2023-02-03 CN CN202310129552.1A patent/CN116258695A/zh active Pending
Cited By (6)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116741372A (zh) * | 2023-07-12 | 2023-09-12 | 东北大学 | 一种基于双分支表征一致性损失的辅助诊断系统及装置 |
CN116741372B (zh) * | 2023-07-12 | 2024-01-23 | 东北大学 | 一种基于双分支表征一致性损失的辅助诊断系统及装置 |
CN116862931A (zh) * | 2023-09-04 | 2023-10-10 | 北京壹点灵动科技有限公司 | 医学图像分割方法、装置、存储介质及电子设备 |
CN116862931B (zh) * | 2023-09-04 | 2024-01-23 | 北京壹点灵动科技有限公司 | 医学图像分割方法、装置、存储介质及电子设备 |
CN117253044A (zh) * | 2023-10-16 | 2023-12-19 | 安徽农业大学 | 一种基于半监督交互学习的农田遥感图像分割方法 |
CN117253044B (zh) * | 2023-10-16 | 2024-05-24 | 安徽农业大学 | 一种基于半监督交互学习的农田遥感图像分割方法 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
Wang et al. | Mixed transformer u-net for medical image segmentation | |
Mou et al. | Vehicle instance segmentation from aerial image and video using a multitask learning residual fully convolutional network | |
Huang et al. | Instance-aware image and sentence matching with selective multimodal lstm | |
CN116258695A (zh) | 基于Transformer与CNN交互的半监督医学影像分割方法 | |
CN109766840A (zh) | 人脸表情识别方法、装置、终端及存储介质 | |
Wang et al. | When cnn meet with vit: Towards semi-supervised learning for multi-class medical image semantic segmentation | |
Cai et al. | A robust interclass and intraclass loss function for deep learning based tongue segmentation | |
Wang et al. | An uncertainty-aware transformer for MRI cardiac semantic segmentation via mean teachers | |
Wang et al. | Medical matting: a new perspective on medical segmentation with uncertainty | |
Xu et al. | Vision transformers for computational histopathology | |
CN114898407A (zh) | 一种基于深度学习牙齿目标实例分割及其智能预览的方法 | |
Zhao et al. | Deeply supervised active learning for finger bones segmentation | |
Li et al. | NIA-Network: Towards improving lung CT infection detection for COVID-19 diagnosis | |
CN111126155A (zh) | 一种基于语义约束生成对抗网络的行人再识别方法 | |
Shu et al. | Privileged multi-task learning for attribute-aware aesthetic assessment | |
Yang et al. | GGAC: Multi-relational image gated GCN with attention convolutional binary neural tree for identifying disease with chest X-rays | |
CN111898756B (zh) | 一种多目标信息关联神经网络损失函数计算方法及装置 | |
CN116759076A (zh) | 一种基于医疗影像的无监督疾病诊断方法及系统 | |
CN113590971B (zh) | 一种基于类脑时空感知表征的兴趣点推荐方法及系统 | |
CN114299342B (zh) | 一种基于深度学习的多标记图片分类中未知标记分类方法 | |
Wang et al. | Optimized lightweight CA-transformer: Using transformer for fine-grained visual categorization | |
Cheng et al. | Double attention for pathology image diagnosis network with visual interpretability | |
Zhao et al. | VCMix-Net: A hybrid network for medical image segmentation | |
Yang et al. | Robust feature mining transformer for occluded person re-identification | |
WO2024108522A1 (zh) | 基于自监督学习的多模态脑肿瘤影像分割方法 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |