CN116258695A - 基于Transformer与CNN交互的半监督医学影像分割方法 - Google Patents

基于Transformer与CNN交互的半监督医学影像分割方法 Download PDF

Info

Publication number
CN116258695A
CN116258695A CN202310129552.1A CN202310129552A CN116258695A CN 116258695 A CN116258695 A CN 116258695A CN 202310129552 A CN202310129552 A CN 202310129552A CN 116258695 A CN116258695 A CN 116258695A
Authority
CN
China
Prior art keywords
cnn
network
branch
transducer
feature
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202310129552.1A
Other languages
English (en)
Inventor
林兰芬
解仕奥
牛子未
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Zhejiang University ZJU
Original Assignee
Zhejiang University ZJU
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Zhejiang University ZJU filed Critical Zhejiang University ZJU
Priority to CN202310129552.1A priority Critical patent/CN116258695A/zh
Publication of CN116258695A publication Critical patent/CN116258695A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06TIMAGE DATA PROCESSING OR GENERATION, IN GENERAL
    • G06T7/00Image analysis
    • G06T7/0002Inspection of images, e.g. flaw detection
    • G06T7/0012Biomedical image inspection
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06TIMAGE DATA PROCESSING OR GENERATION, IN GENERAL
    • G06T7/00Image analysis
    • G06T7/10Segmentation; Edge detection
    • G06T7/11Region-based segmentation
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06TIMAGE DATA PROCESSING OR GENERATION, IN GENERAL
    • G06T2207/00Indexing scheme for image analysis or image enhancement
    • G06T2207/20Special algorithmic details
    • G06T2207/20081Training; Learning
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06TIMAGE DATA PROCESSING OR GENERATION, IN GENERAL
    • G06T2207/00Indexing scheme for image analysis or image enhancement
    • G06T2207/20Special algorithmic details
    • G06T2207/20084Artificial neural networks [ANN]
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06TIMAGE DATA PROCESSING OR GENERATION, IN GENERAL
    • G06T2207/00Indexing scheme for image analysis or image enhancement
    • G06T2207/30Subject of image; Context of image processing
    • G06T2207/30004Biomedical image processing
    • YGENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
    • Y02TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
    • Y02TCLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
    • Y02T10/00Road transport of goods or passengers
    • Y02T10/10Internal combustion engine [ICE] based vehicles
    • Y02T10/40Engine management systems

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • General Physics & Mathematics (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • General Health & Medical Sciences (AREA)
  • Health & Medical Sciences (AREA)
  • Artificial Intelligence (AREA)
  • Data Mining & Analysis (AREA)
  • Radiology & Medical Imaging (AREA)
  • Nuclear Medicine, Radiotherapy & Molecular Imaging (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Medical Informatics (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Computational Linguistics (AREA)
  • Quality & Reliability (AREA)
  • Evolutionary Computation (AREA)
  • Molecular Biology (AREA)
  • Computing Systems (AREA)
  • General Engineering & Computer Science (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Image Analysis (AREA)

Abstract

本发明公开了一种基于Transformer与CNN交互的半监督医学影像分割方法,包括:本发明将Transformer引入到医学影像半监督分割框架中,并设计了Transformer分支与CNN分支之间特征交互的模块C2T模块和T2C模块,实现两分支之间高效的知识共享,提高分割网络同时捕捉细节特征和建立全局依赖关系的能力;同时增加了特征一致性分布约束,利用Teacher模型的Transformer分支的特征协方差矩阵去约束Student的CNN分支特征协方差矩阵,同样Teacher模型的CNN分支的特征协方差矩阵去约束Transformer的特征协方差矩阵,通过这种交叉教学的方式,使得引入Transformer分支之后的半监督框架更加稳定,同时产生更加准确的伪标签。

Description

基于Transformer与CNN交互的半监督医学影像分割方法
技术领域
本发明属于医学影像分析领域,具体设计一种基于Transformer与CNN交互的半监督医学影像分割方法。
背景技术
近年来,随着医学影像技术的发展和普及,医学影像为医生进行疾病诊断提供了重要的参考信息。医学影像技术利用不同的成像原理和特制的设备,以非侵入方式获取病人内部组织结构信息,可以查看病人的生理机能状况。医学影像分割是影像辅助治疗中关键的一步,它是从医学影像中识别出病变器官的像素点,获取这些病变的位置、大小等信息,是一项具有一定技术难度的医学影像分析任务。在现实应用场景中,由于医学影像中的不同器官之间以及不同组织之间的边界很难明显区分,导致在标注过程中存在较大的不确定性和主观性,因此很难获得大量完整标注的、精确度高的医学影像数据集。除此之外,由于医学影像数据往往涉及病患隐私,导致获取渠道和数量会比较受限。因此如何充分利用大量的未标注数据是当前医学影像分割丞待解决的问题。
半监督学习则是基于如何通过利用少量的标注数据和大量的未标注数据来训练网络提升模型性能的重要的研究方向。总体来说,现有的半监督医学影像分割框架主要包括:基于CNN的半监督医学影像分割方法和基于Transformer结构的半监督医学影像分割方法两个大类。由于CNN往往只注重局部特征的提取,忽略了上下文信息,而Transformer则具有建立全局依赖关系的能力,因此如何将Transformer结构引入到半监督学习中成为了一个重要的研究课题。论文《Semi-Supervised Medical Image Segmentation via CrossTeaching between CNN and Transformer》中提出搭建了基于交叉教学框架的网络,分别让Transformer产生的预测作为CNN的伪标签以及CNN产生的预测作为Transformer的伪标签,在ACDC等数据集上证明优于现有的半监督框架。如中国专利CN114882047 A,公开日为2022年8月9日,提出搭建CNN和Transformer混合的U型分割网络,通过在Transformer结构中加入残差模块并且在跳跃连接处使用支持向量机的方式对信息进行进一步的筛选和简化,极大的提升了缺少标注数据的医学影像的分割准确性。
发明内容
本发明提供了一种基于Transformer与CNN交互的半监督医学影像分割方法。
基于Transformer与CNN交互的半监督医学影像分割方法的过程为:首先利用有标签的数据进行训练为学生网络提供初始化参数;其次将学生网络参数复制给教师网络,并利用教师网络为无标签数据提供伪标签;此后对无标签数据进行数据增强;有标签数据和增强后的无标签数据同时输入至学生网络,学生网络利用有标签数据优化网络参数,提取无标签数据特征并输出其预测分布。增强后的无标签数据输入至教师网络,提取无标签数据特征并输出其预测分布。此外引入一致性损失约束来自教师网络和学生网络的无标签数据特征,并计算教师网络预测输出的信息熵,通过设定信息熵阈值以提升伪标签置信度,利用伪标签和学生网络的预测分布计算半监督损失。最后,学生网络将其网络参数通过指数移动平均传递给教师网络以进行进一步迭代训练。
一种基于Transformer与CNN交互的半监督医学影像分割方法,包括以下步骤:
(1)教师学生网络构建。教师和学生网络都由一个CNN分支和一个Transformer分支组成。并于每个阶段引入C2T(CNN to Transformer)和T2C(Transformer to CNN)模块以实现两个分支提取特征的交互。
(2)学生网络训练。利用有标签的数据
Figure BDA0004083470350000021
训练学生网络,其中N1为无标签数据的数量,/>
Figure BDA0004083470350000022
为第i个训练样本。
以初始化学生网络参数S(θ),并将其参数复制给教师网络T(θ)。
(3)生成伪标签。将无标签的数据
Figure BDA0004083470350000031
输入至教师网络,以教师网络的预测输出作为无标签的数据的伪标签/>
Figure BDA0004083470350000032
其中Nu为无标签数据的数量,/>
Figure BDA0004083470350000033
为第i张无标签图像。
(4)数据增强。利用CutMix数据增强方法,为无标签的数据进行数据增强。
(5)学生网络优化及预测。将有标签数据和增强后的无标签数据同时输入至学生网络,学生网络利用有标签数据更新网络参数,提取无标签数据特征FS并输出其预测分布PS
(6)特征一致性约束。将增强后的无标签数据输入至教师网络,提取无标签数据特征FT并输出其预测分布PT。将教师网络提取的特征与学生网络进行一致性约束。
(7)伪标签监督训练。对步骤6中教师模型的预测分布计算信息熵
Figure BDA0004083470350000034
并设定选择阈值对伪标签进行过滤选择,以仅保留高置信度伪标签/>
Figure BDA0004083470350000035
并利用/>
Figure BDA0004083470350000036
监督学生网络的预测分布PS
(8)整体损失优化。训练上述网络时,对总体损失进行优化以及反向梯度传播,以更新网络参数。
(9)网络参数传递。网络参数更新后,学生网络将其网络参数通过指数移动平均传递给教师网络以进行进一步迭代训练。
(10)目标图像分割。对于给定目标域图像,分割模型输出目标图像每个像素点所属类别概率,选取最大概率的类别作为该像素点预测类别。
步骤1所述的Transformer和CNN分支交互过程(C2T模块)为:
(i)在每个阶段将CNN分支学习到的特征通过C2T模块传递给Transformer分支,记学生网络中CNN分支第i个阶段输出的特征图为
Figure BDA0004083470350000037
通过全连接层预测的概率图为/>
Figure BDA0004083470350000038
其中/>
Figure BDA0004083470350000039
的维度为B×D×H×W,/>
Figure BDA00040834703500000310
的维度为B×Nc×H×W,B为训练批次的大小,H,W分别为特征图的高和宽,D为特征图的通道数,Nc为类别数,通过两者进行加权平均,得到各个类别的中心/>
Figure BDA00040834703500000311
(ii)对于学生网络中Transformer分支第i个阶段,其输入特征图为
Figure BDA0004083470350000041
分别经过三个全连接层WQ,WK,WV计算Query(Q),Key(K),Value(V),Q、K、V表示为三个不同的特征图,并且计算得到Q与K的相似度矩阵SQ_K
(iii)分别计算步骤(ii)中的Q,K与步骤(i)得到的各个类别的中心
Figure BDA0004083470350000042
的相似度得到相似度矩阵SQ_cls和相似度矩阵SK_cls
(iv)通过步骤(iii)得到的相似度矩阵SQ_cls和相似度矩阵SK_cls融合了CNN分支的类别中心的信息,将该信息传递给Transformer分支相应的层来优化矩阵SQ_K,得到融合后的相似度矩阵为
Figure BDA0004083470350000043
将该相似度矩阵与特征图V相乘得到增强后的特征FC2T,最终与CNN分支第i个阶段输出的特征/>
Figure BDA0004083470350000044
相加得到Transformer分支第i个阶段输出的特征图/>
Figure BDA0004083470350000045
步骤(i)所述的各个类别的中心
Figure BDA0004083470350000046
计算为:
Figure BDA0004083470350000047
其中,
Figure BDA0004083470350000048
为学生网络第i个阶段中CNN分支预测的分割概率图上索引位置j上的类别概率,/>
Figure BDA0004083470350000049
为学生网络CNN分支第i个阶段对应的特征图上索引位置j对应的特征向量,H,W为特征图的高和宽。
步骤(ii)中,Q与K的相似度矩阵SQ_K计算为:
Figure BDA00040834703500000410
Figure BDA00040834703500000411
其中D为特征图的通道维度,
Figure BDA00040834703500000412
为第i-1个阶段的Transformer分支输出的特征图,WQ,WK,WV为三个全连接层。softmax(·)为激活函数。
步骤(iii)中,分别计算步骤(ii)中的Q,K与步骤(i)得到的各个类别的中心
Figure BDA00040834703500000413
的相似度得到相似度矩阵SQ_cls和相似度矩阵SK_cls,具体包括:
Figure BDA00040834703500000414
其中,D为特征图的通道数。
步骤(iv)中,融合后的相似度矩阵为
Figure BDA00040834703500000415
的计算包括:
Figure BDA00040834703500000416
特征图
Figure BDA00040834703500000417
作为最终融合输出,具体包括:
Figure BDA0004083470350000051
其中
Figure BDA0004083470350000052
为学生网络Transformer分支第i个阶段融合来自CNN分支的信息后最终输出的特征图,/>
Figure BDA0004083470350000053
为学生网络中CNN分支第i个阶段的输出特征图,FC2T为增强后的特征图,其中,/>
Figure BDA0004083470350000054
为融合后的相似度矩阵,V为特征图。/>
步骤1中,Transformer分支通过T2C模块与CNN分支交互,具体包括:
(i)学生网络在每个阶段将Transformer学习到的特征通过T2C模块传递给CNN分支,记学生网络中Transformer分支第i个阶段输出的特征图为
Figure BDA0004083470350000055
其中/>
Figure BDA0004083470350000056
的维度为B×N×D,N为序列长度,并且有N=H×W,特征图经过展平和形状变换得到维度B×D×H×W,B为训练批次大小,D为特征通道数,H,W分别为特征图的高和宽,通过全连接层预测的分割概率图为/>
Figure BDA0004083470350000057
其维度为B×Nc×H×W,Nc为类别数,通过两者进行加权平均,得到各个类别的中心/>
Figure BDA0004083470350000058
(ii)学生网络的CNN分支第i个阶段的输出特征图为
Figure BDA0004083470350000059
利用Transformer分支通过步骤(i)得到的类别中心和CNN分支的特征图计算交叉注意力机制得到增强后的特征为FT2C
(iii)将步骤(ii)得到的特征FT2C与CNN分支特征
Figure BDA00040834703500000510
进行融合得到最终输出
Figure BDA00040834703500000511
步骤(i)中,所各个类别的中心
Figure BDA00040834703500000512
计算过程为:
Figure BDA00040834703500000513
其中,
Figure BDA00040834703500000514
为学生网络中的Transformer分支的第i个阶段通过全连接层预测的分割概率图索引位置j上的类别概率,/>
Figure BDA00040834703500000515
为学生网络中的Transformer分支的第i个阶段输出的特征图索引位置j对应的特征向量,H,W分别为特征图的高和宽;
步骤(ii)中,计算交叉注意力机制得到增强后的特征FT2C,具体包括:
Figure BDA00040834703500000516
Figure BDA00040834703500000517
其中WQ,WK,WV为三个全连接层,
Figure BDA00040834703500000518
为学生网络中的CNN分支的第i个阶段输出的特征图,D为特征图的通道数,softmax(·)为激活函数,/>
Figure BDA0004083470350000061
为步骤(i)中计算得到的各个类别的中心,Q,K,V为计算得到的三个不同的特征图;
步骤(iii)中,输出特征
Figure BDA0004083470350000062
的计算包括:
Figure BDA0004083470350000063
其中FT2C为增强后的特征,
Figure BDA0004083470350000064
为学生网络中的CNN分支的第i个阶段输出的特征图。
步骤2所述的初始化学生网络和将学生网络的参数复制给教师网络具体过程为:
(i)采用有标签数据
Figure BDA0004083470350000065
训练学生网络S(θ),并采用交叉熵损失函数计算监督训练损失/>
Figure BDA0004083470350000066
其中Nl为有标签数据的数量。/>
(ii)训练结束后将其参数复制给教师网络T(θ)。
Figure BDA0004083470350000067
Figure BDA0004083470350000068
表示模型训练输入的批次大小为/>
Figure BDA0004083470350000069
的有标签数据,lce表示交叉熵损失函数,/>
Figure BDA00040834703500000610
表示第i张有标签图像的真实标签,/>
Figure BDA00040834703500000611
为第i张有标签的图像。交叉熵损失函数可以表示为:/>
Figure BDA00040834703500000612
其中H,W为预测输出分割概率图的高和宽,yi表示图像上位置i处的真实类别,pi为位置i处的预测类别概率。
步骤3所述的得到无标签数据的伪标签具体过程为:
Figure BDA00040834703500000613
其中
Figure BDA00040834703500000614
为第i张无标签图像,/>
Figure BDA00040834703500000615
为教师网络T(θ)预测的伪标签。
步骤4中所述的数据增强方法CutMix具体步骤为:
Figure BDA00040834703500000616
其中M∈{0,1}W×H是为了裁剪掉某些区域产生的掩码,·为元素点乘,λ服从Beta分布,即λ~Beta(α,α),
Figure BDA00040834703500000617
分别为第i,j张无标签图像,/>
Figure BDA00040834703500000618
Figure BDA00040834703500000619
分别为第i,j张无标签图像的伪标签。
步骤5将有标签数据和增强后的无标签数据同时输入至学生网络,学生网络利用有标签数据更新网络参数,提取无标签数据特征FS并输出其预测分布PS
(i)学生网络利用有标签数据更新网络参数,并采用交叉熵损失函数计算监督训练损失
Figure BDA0004083470350000071
(ii)学生网络中的两个分支分别提取无标签数据特征FS-T和FS-C并输出其预测分布PS-T和PS-C,最终以二者均值作为学生网络的预测分布PS
步骤(i)中监督损失
Figure BDA0004083470350000072
为:
Figure BDA0004083470350000073
Figure BDA0004083470350000074
表示模型训练输入的批次大小为/>
Figure BDA0004083470350000075
的有标签数据,lce表示交叉熵损失函数,可表示为:/>
Figure BDA0004083470350000076
其中H,W为预测输出分割概率图的高和宽,ym表示图像上位置m处的真实类别,pm为位置m处的预测概率。/>
Figure BDA0004083470350000077
表示将第i张有标签的图像输入到学生网络后的预测分割结果,/>
Figure BDA0004083470350000078
表示第i张有标签图像的标签。
步骤(ii)中预测分布PS计算过程为:
Figure BDA0004083470350000079
其中
Figure BDA00040834703500000710
和/>
Figure BDA00040834703500000711
分别为学生网络的Transformer和CNN分支的特征提取器,/>
Figure BDA00040834703500000712
和/>
Figure BDA00040834703500000713
分别为学生网络的两个分支的预测分割头。PS-T和PS-C为学生网络两个分支的预测输出,xu为无标签图像。
步骤6将增强后的无标签数据输入至教师网络,提取无标签数据特征FT并输出其预测分布PT。将教师网络提取的特征与学生网络进行一致性约束,具体步骤为:
(i)教师网络提取无标签数据特征FT并输出其预测分布PT
(ii)将教师网络Transformer分支和CNN分支提取的特征FT-T,FT-C,和学生网络中Transformer分支和CNN分支提取的特征FS-T,FS-C分别与其自身转置相乘获得协方差特征图
Figure BDA00040834703500000714
(iii)教师网络和学生网络对应分支的协方差特征图利用MSE损失函数进行一致性约束,并计算一致性损失
Figure BDA00040834703500000818
步骤(i)中预测分布PT计算过程为:
Figure BDA0004083470350000081
其中
Figure BDA0004083470350000082
和/>
Figure BDA0004083470350000083
分别为教师网络的Transformer和CNN分支的特征提取器,/>
Figure BDA0004083470350000084
和/>
Figure BDA0004083470350000085
分别为教师网络的两个分支的分割预测头。PS-T和PS-C为教师网络两个分支的预测输出,xu为无标签输入图像。
步骤(ii)中协方差矩阵的计算过程为:
FΣ=F·(F)T
步骤(iii)中一致性损失
Figure BDA0004083470350000086
计算过程为:
Figure BDA0004083470350000087
其中MSE为计算平方差,即MSE(x,y)=x2-y2
步骤7所述的针对步骤6中教师模型的预测分布计算信息熵
Figure BDA0004083470350000088
并设定选择阈值对伪标签进行过滤选择,以仅保留高置信度伪标签/>
Figure BDA0004083470350000089
并利用/>
Figure BDA00040834703500000810
监督学生网络的预测分布PS。具体步骤为:
(i)计算教师模型的预测分布的信息熵
Figure BDA00040834703500000811
(ii)设定选择阈值对伪标签进行过滤选择,以获得高置信度的预测输出
Figure BDA00040834703500000812
(iii)利用高置信度的标签
Figure BDA00040834703500000813
监督学生网络的预测分布PS。/>
步骤(i)中计算教师模型的预测分布的信息熵:
Figure BDA00040834703500000814
其中ij为预测输出的索引,i表示第i张无标签图像,j表示第j个像素点。C为输出通道数,c为通道数索引。
Figure BDA00040834703500000815
为第i张无标签图像的伪标签。
步骤ii中并设定选择阈值对伪标签进行过滤选择,以获得高置信度的预测输出
Figure BDA00040834703500000816
为:
Figure BDA00040834703500000817
其中γ为设定的阈值,
Figure BDA0004083470350000091
为计算得到的熵。/>
Figure BDA0004083470350000092
表示为第i张无标签图像的伪标签的第j个位置上的第c个类别的预测值。argmaxc(·)表示计算最大值对应的类别索引。/>
Figure BDA0004083470350000093
为经过筛选后的第i张无标签图像第j个像素点的伪标签。
步骤(iii)中利用高置信度的标签
Figure BDA0004083470350000094
监督学生网络的预测分布PS,具体步骤为:
Figure BDA0004083470350000095
其中
Figure BDA0004083470350000096
为无监督损失函数,/>
Figure BDA0004083470350000097
表示批次大小为/>
Figure BDA0004083470350000098
的无标签训练数据,lce为交叉熵损失函数。
步骤8中整体损失优化中的整体损失
Figure BDA0004083470350000099
为步骤2和步骤5中的监督损失/>
Figure BDA00040834703500000910
步骤6中的一致性损失/>
Figure BDA00040834703500000911
以及步骤7中的无监督损失/>
Figure BDA00040834703500000912
的线性组合:
Figure BDA00040834703500000913
其中,λ1和λ2为超参数,用于平衡两种损失对总损失的影响。
步骤9中,网络参数更新后,学生网络将其网络参数通过指数移动平均(EMA)传递给教师网络以进行进一步迭代训练,具体过程为:
T(θ)=EMA(S(θ))
与现有技术相比,本发明的有益效果:
(1)相比较于其他方法提到的基于Transformer的半监督医学影像分割方法仅仅在最后预测结果进行交互,本方法可以将CNN提取的局部特征与Transformer捕捉的全局特征进行更加充分的融合,提升了网络性能。
(2)相比较于其他方法提到的基于Transformer的半监督医学影像分割方法通常计算像素级别的损失,本方法改进为计算协方差矩阵的损失,增强了模型的鲁棒性。
(3)本方法同时分别用学生网络中的两个分支和教师网络中的两个分支进行交叉教学,能产生更加稳定准确的伪标签。
附图说明
图1为基于Transformer与CNN交互的半监督医学影像分割方法整体结构图;
图2为Transformer分支和CNN分支在每个阶段交互的C2T模块结构图;
图3为Transformer分支和CNN分支在每个阶段交互的T2C模块结构图;
图4为基于特征分布一致性损失约束图;
图5为本发明基于Transformer与CNN交互的半监督医学影像分割方法的流程示意图;
图6为本发明在心脏分割数据集(ACDC)上的分割结果图;
图7为本发明在皮肤病分割数据集(ISIC)上的分割结果图。
具体实施方式
本发明方法的整体框架如图1和图5所示,包括以下步骤:
1、教师学生网络构建。教师和学生网络都由一个CNN分支和一个Transformer分支组成,并于每个阶段引入C2T(CNN to Transformer)和T2C(Transformer to CNN)模块以实现两个分支提取特征的交互。
2、学生网络训练。利用有标签的数据
Figure BDA0004083470350000101
训练学生网络,以初始化学生网络参数S(θ),并将其参数复制给教师网络T(θ)。
3、生成伪标签。将无标签的数据
Figure BDA0004083470350000102
输入至教师网络,以教师网络的预测输出作为无标签的数据的伪标签/>
Figure BDA0004083470350000103
4、数据增强。利用CutMix数据增强方法,为有标签和无标签的数据进行数据增强。
5、学生网络优化及预测。将增强后的有标签和无标签数据同时输入至学生网络,学生网络利用有标签数据更新网络参数,提取无标签数据特征FS并输出其预测分布PS
6、特征一致性约束。将增强后的无标签数据输入至教师网络,提取无标签数据特征FT并输出其预测分布PT。将教师网络提取的特征与学生网络进行一致性约束。
7、伪标签监督训练。对步骤6中教师模型的预测分布计算信息熵
Figure BDA0004083470350000111
并设定选择阈值对伪标签进行过滤选择,以仅保留高置信度伪标签/>
Figure BDA0004083470350000112
并利用/>
Figure BDA0004083470350000113
监督学生网络的预测分布PS
8、整体损失优化。训练上述网络时,对总体损失进行优化以及反向梯度传播,以更新网络参数。
9、网络参数传递。网络参数更新后,学生网络将其网络参数通过指数移动平均传递给教师网络以进行进一步迭代训练。
10、目标图像分割。对于给定目标域图像,分割模型输出目标图像每个像素点所属类别概率,并选择概率最大的对应的类别作为该像素点的预测类别。
步骤1中,教师和学生网络都由一个CNN分支和一个Transformer分支组成。并于每个阶段引入C2T(CNN to Transformer)和T2C(Transformer to CNN)模块以实现两个分支提取特征的交互。
具体而言:
(i)C2T模块
在每个阶段将CNN分支学习到的特征通过C2T模块传递给Transformer分支,记学生网络中CNN分支第i个阶段输出的特征图为
Figure BDA0004083470350000114
通过全连接层预测的概率图为/>
Figure BDA0004083470350000115
通过两者进行加权平均,可以得到各个类别的中心/>
Figure BDA0004083470350000116
Figure BDA0004083470350000117
其中,
Figure BDA0004083470350000118
为学生网络第i个阶段中CNN分支预测的分割概率图上索引位置j上的类别概率,/>
Figure BDA0004083470350000119
为学生网络CNN分支第i个阶段对应的特征图上索引位置j对应的特征向量,H,W为特征图的高和宽。
对于学生网络中Transformer分支第i个阶段,其输入特征图为
Figure BDA00040834703500001110
分别经过三个全连接层WQ,WK,WV计算Query(Q),Key(K),Value(V),并且计算可得Q与K的相似度矩阵SQ_K
Figure BDA00040834703500001111
Figure BDA00040834703500001112
其中D为特征图的通道维度,
Figure BDA0004083470350000121
为第i-1个阶段学生网络Transformer分支输出的特征图。WQ,WK,WV为三个全连接层。softmax(·)为激活函数。
分别计算公式(2)中的Q,K和通过公式(1)计算得到的类别中心的相似度得到相似度矩阵SQ_cls和相似度矩阵SK_cls
Figure BDA0004083470350000122
其中D为特征图的通道数。
将该信息传递给Transformer相应的层来优化矩阵SQ_K,得到融合后的相似度矩阵为
Figure BDA0004083470350000123
Figure BDA0004083470350000124
最终将CNN分支第i个阶段的特征
Figure BDA0004083470350000125
和增强后的特征FC2T融合得到Transformer分支第i个阶段的最终的融合输出特征/>
Figure BDA0004083470350000126
Figure BDA0004083470350000127
其中
Figure BDA0004083470350000128
为学生网络Transformer分支第i个阶段融合来自CNN分支的信息后最终输出的特征图,/>
Figure BDA0004083470350000129
为学生网络中CNN分支第i个阶段的输出特征图,FC2T为增强后的特征图,其中,/>
Figure BDA00040834703500001210
为融合后的相似度矩阵,V为特征图。
(ii)T2C模块
学生网络在每个阶段将Transformer学习到的特征通过T2C模块传递给CNN分支,记学生网络中Transformer分支第i个阶段输出的特征图为
Figure BDA00040834703500001211
该特征图经过展平和形状变换得到维度B×D×H×W,B为训练批次大小,D为特征通道数,H,W分别为特征图的高和宽。
Figure BDA00040834703500001212
为通过全连接层预测的分割概率图,其维度为B×Nc×H×W,Nc为类别数。通过两者进行加权平均,可以得到各个类别的中心/>
Figure BDA00040834703500001213
Figure BDA00040834703500001214
其中,
Figure BDA00040834703500001215
为学生网络中的Transformer分支的第i个阶段通过全连接层预测的分割概率图索引位置j上的类别概率,/>
Figure BDA00040834703500001216
为学生网络中的Transformer分支的第i个阶段输出的特征图索引位置j对应的特征向量,H,W分别为特征图的高和宽;
学生网络的CNN分支第i个阶段的输出特征图为
Figure BDA00040834703500001217
可以利用Transformer分支通过公式(6)得到的类别中心和CNN分支输出的特征图计算交叉注意力机制得到增强后的特征为FT2C
Figure BDA0004083470350000131
Figure BDA0004083470350000132
其中WQ,WK,WV为三个全连接层,
Figure BDA0004083470350000133
为公式(6)计算得到的类别中心,D为特征图的通道数。
将公式(7)得到的特征FT2C与学生网络CNN分支第i个阶段输出的特征图
Figure BDA0004083470350000134
进行融合得到最终输出特征/>
Figure BDA0004083470350000135
Figure BDA0004083470350000136
步骤2中,采用有标签数据
Figure BDA0004083470350000137
训练学生网络S(θ),并采用交叉熵损失函数计算监督训练损失/>
Figure BDA0004083470350000138
训练结束后将其参数复制给教师网络T(θ),如公式(9)所示:训练结束后将其参数复制给教师网络T(θ)。
Figure BDA0004083470350000139
Figure BDA00040834703500001310
表示模型训练输入的批次大小为/>
Figure BDA00040834703500001311
的有标签数据,lce表示交叉熵损失函数,/>
Figure BDA00040834703500001312
表示第i张有标签图像的真实标签,/>
Figure BDA00040834703500001313
为第i张有标签的图像。交叉熵损失函数可以表示为:/>
Figure BDA00040834703500001314
其中H,W为预测输出分割概率图的高和宽,yi表示图像上位置i处的真实类别,pi为位置i处的预测类别概率。
步骤3中,将无标签的数据
Figure BDA00040834703500001315
输入至教师网络,以教师网络的预测输出作为无标签的数据的伪标签/>
Figure BDA00040834703500001316
如公式(10)所示:
Figure BDA00040834703500001317
其中
Figure BDA0004083470350000141
为第i张无标签图像,/>
Figure BDA0004083470350000142
为教师网络T(θ)预测的伪标签。
步骤4中,利用CutMix数据增强方法,为无标签的数据进行数据增强,如公式(11)所示:
Figure BDA0004083470350000143
其中M∈{0,1}W×H是为了裁剪掉某些区域产生的掩码,·为元素点乘,λ服从Beta分布,即λ~Beta(α,α),
Figure BDA0004083470350000144
分别为第i,j张无标签图像,/>
Figure BDA0004083470350000145
Figure BDA0004083470350000146
分别为第i,j张无标签图像的伪标签。
步骤5中,将有标签数据和增强后的无标签数据同时输入至学生网络,学生网络利用有标签数据更新网络参数,提取无标签数据特征FS并输出其预测分布PS。具体而言:
5.1学生网络利用有标签数据更新网络参数,并采用交叉熵损失函数计算监督训练损失
Figure BDA00040834703500001413
如公式(12)所示:
Figure BDA0004083470350000147
Figure BDA0004083470350000148
表示模型训练输入的批次大小为/>
Figure BDA0004083470350000149
的有标签数据,lce表示交叉熵损失函数,可表示为:/>
Figure BDA00040834703500001410
其中H,W为预测输出分割概率图的高和宽,ym表示图像上位置m处的真实类别,pm为位置m处的预测概率。/>
Figure BDA00040834703500001411
表示将第i张有标签的图像输入到学生网络后的预测分割结果,/>
Figure BDA00040834703500001412
表示第i张有标签图像的标签。
5.2学生网络中的两个分支分别提取无标签数据特征FS-T和FS-C并输出其预测分布PS-T和PS-C,最终以二者均值作为学生网络的预测分布PS如公式(13)所示:
Figure BDA0004083470350000151
其中
Figure BDA0004083470350000152
和/>
Figure BDA0004083470350000153
分别为学生网络的Transformer和CNN分支的特征提取器,/>
Figure BDA0004083470350000154
和/>
Figure BDA0004083470350000155
分别为学生网络的两个分支的分类器。PS-T和PS-C为学生网络两个分支的预测输出,xu表示输入的无标签图像。
步骤6中,将增强后的无标签数据输入至教师网络,提取无标签数据特征FT并输出其预测分布PT。将教师网络提取的特征与学生网络进行一致性约束。具体而言:
6.1教师网络提取无标签数据特征FT并输出其预测分布PT,如公式(14)所示:
Figure BDA0004083470350000156
其中
Figure BDA0004083470350000157
和/>
Figure BDA0004083470350000158
分别为教师网络的Transformer和CNN分支的特征提取器,/>
Figure BDA0004083470350000159
和/>
Figure BDA00040834703500001510
分别为教师网络的两个分支的分类器。PT-T和PT-C为教师网络两个分支的预测输出,xu表示输入的无标签图像。
6.2将教师网络和学生网络提取的特征FT-T,FT-C,FS-T,FS-C与其自身转置相乘获得协方差特征图
Figure BDA00040834703500001511
如公式(15)所示;教师网络和学生网络对应分支的协方差特征图利用MSE损失函数进行一致性约束,并计算一致性损失/>
Figure BDA00040834703500001512
以提升模型的鲁棒性及泛化能力,如公式(16)所示。
Figure BDA00040834703500001513
Figure BDA00040834703500001514
其中MSE为计算平方差,即MSE(x,y)=x2-y2
步骤7中,对步骤6中教师模型的预测分布计算信息熵
Figure BDA00040834703500001515
并设定选择阈值对伪标签进行过滤选择,以仅保留高置信度伪标签/>
Figure BDA00040834703500001516
并利用/>
Figure BDA0004083470350000161
监督学生网络的预测分布PS。具体而言:
7.1计算教师模型的预测分布的信息熵,如公式(17)所示;并设定选择阈值对伪标签进行过滤选择,以获得高置信度的预测输出,如公式(18)所示:
Figure BDA0004083470350000162
其中ij为预测输出的索引,i表示第i张无标签图像,j表示第j个像素点。C为输出通道数,c为通道数索引。
Figure BDA0004083470350000163
为第i张无标签图像第j个像素点的伪标签。
Figure BDA0004083470350000164
其中γ为设定的阈值,
Figure BDA0004083470350000165
为计算得到的熵。/>
Figure BDA0004083470350000166
表示为第i张无标签图像的伪标签的第j个位置上的第c个类别的预测值。argmaxc(·)表示计算最大值对应的类别索引。/>
Figure BDA0004083470350000167
为经过筛选后的第i张无标签图像第j个像素点的伪标签。
7.2利用高置信度的标签
Figure BDA0004083470350000168
监督学生网络的预测分布PS,如公式(19)所示:/>
Figure BDA0004083470350000169
其中
Figure BDA00040834703500001610
为无监督损失函数,/>
Figure BDA00040834703500001611
表示批次大小为/>
Figure BDA00040834703500001612
的无标签训练数据,lce为交叉熵损失函数。
步骤8中的整体损失
Figure BDA00040834703500001613
为步骤2和步骤5中的监督损失/>
Figure BDA00040834703500001614
步骤6中的一致性损失/>
Figure BDA00040834703500001615
以及步骤7中的无监督损失/>
Figure BDA00040834703500001616
的线性组合,如公式(20)所示:
Figure BDA00040834703500001617
其中,λ1和λ2为超参数,用于平衡两种损失对总损失的影响。
步骤9中,网络参数更新后,学生网络将其网络参数通过指数移动平均(EMA)传递给教师网络以进行进一步迭代训练,如公式(21)所示:
T(θ)=EMA(S(θ)) (21)
步骤10中,所述的目标图像分割,给定目标域图像数据,分割模型为每个像素点预测其类别概率,选择预测概率最大的类别作为预测类别,以获得最终的分割掩码。
进一步地,如图1所示,一种基于Transformer与CNN交互的半监督医学影像分割方法具体过程如下:
1、教师学生网络构建。教师和学生网络都由一个CNN分支和一个Transformer分支组成。并于每个阶段引入C2T(CNN to Transformer)如图2所示,具体步骤为:
(i)在每个阶段将CNN分支学习到的特征通过C2T模块传递给Transformer分支,记学生网络中CNN分支第i个阶段输出的特征图为
Figure BDA0004083470350000171
通过全连接层预测的概率图为/>
Figure BDA0004083470350000172
通过两者进行加权平均,可以得到各个类别的中心/>
Figure BDA0004083470350000173
如下所示:
Figure BDA0004083470350000174
其中,
Figure BDA0004083470350000175
为学生网络第i个阶段中CNN分支预测的分割概率图上索引位置j上的类别概率,/>
Figure BDA0004083470350000176
为学生网络CNN分支第i个阶段对应的特征图上索引位置j对应的特征向量,H,W为特征图的高和宽。
(ii)对于学生网络中Transformer分支第i个阶段,其输入特征图为
Figure BDA0004083470350000177
分别经过三个全连接层WQ,WK,WV计算得到三个不同特征图Query(Q),Key(K),Value(V),并且计算可得Q与K的相似度矩阵SQ_K
Figure BDA0004083470350000178
Figure BDA0004083470350000179
其中D为特征图的通道维度,
Figure BDA00040834703500001710
为第i-1个阶段的Transformer分支输出的特征图。WQ,WK,WV为三个全连接层。
(iii)分别计算步骤(ii)中的Q,K和通过步骤(i)计算得到的类别中心的相似度得到相似度矩阵SQ_cls和相似度矩阵SK_cls
Figure BDA0004083470350000181
其中D为特征图的通道数。
(iv)将该信息传递给Transformer相应的层来优化SQ_K,得到融合后的相似度矩阵为
Figure BDA0004083470350000182
Figure BDA0004083470350000183
(v)最终将CNN分支第i个阶段特征和Transformer分支第i个阶段特征融合得到
Figure BDA0004083470350000184
Figure BDA0004083470350000185
其中
Figure BDA0004083470350000186
为学生网络Transformer分支第i个阶段融合来自CNN分支的信息后最终输出的特征图,/>
Figure BDA0004083470350000187
为学生网络中CNN分支第i个阶段的输出特征图,FC2T为增强后的特征图,其中,/>
Figure BDA0004083470350000188
为融合后的相似度矩阵,V为特征图。
同样,在每个阶段引入T2C模块,如图3所示,具体步骤为:
(i)学生网络在每个阶段将Transformer学习到的特征通过T2C模块传递给CNN分支,记学生网络中Transformer分支第i个阶段输出的特征图为
Figure BDA0004083470350000189
该特征图经过展平和形状变换得到维度B×D×H×W,B为训练批次大小,D为特征通道数,H,W分别为特征图的高和宽。/>
Figure BDA00040834703500001810
为通过全连接层预测的分割概率图,其维度为B×Nc×H×W,Nc为类别数。通过两者进行加权平均,可以得到各个类别的中心/>
Figure BDA00040834703500001811
Figure BDA00040834703500001812
其中,
Figure BDA00040834703500001813
为学生网络中的Transformer分支的第i个阶段通过全连接层预测的分割概率图索引位置j上的类别概率,/>
Figure BDA00040834703500001814
为学生网络中的Transformer分支的第i个阶段输出的特征图索引位置j对应的特征向量,H,W分别为特征图的高和宽;
(ii)学生网络的CNN分支第i个阶段的输出特征图为
Figure BDA00040834703500001815
可以利用Transformer分支通过步骤(i)得到的类别中心/>
Figure BDA00040834703500001816
和CNN分支的特征图计算交叉注意力机制得到增强后的特征为FT2C
Figure BDA00040834703500001817
Figure BDA00040834703500001818
其中WQ,WK,WV为三个全连接层,
Figure BDA00040834703500001819
为中步骤(i)计算得到的类别中心,D为特征图的通道数。/>
(iii)将步骤(ii)得到的特征FT2C与CNN分支特征进行融合得到最终输出特征
Figure BDA0004083470350000191
Figure BDA0004083470350000192
2、学生网络训练。利用有标签的数据
Figure BDA0004083470350000193
训练学生网络,以初始化学生网络参数S(θ),并将其参数复制给教师网络T(θ)。
3、生成伪标签。将无标签的数据
Figure BDA0004083470350000194
输入至教师网络,以教师网络的预测输出作为无标签的数据的伪标签/>
Figure BDA0004083470350000195
4、数据增强。利用CutMix数据增强方法,为无标签的数据进行数据增强。
5、学生网络优化及预测。将有标签和增强后的无标签数据同时输入至学生网络,学生网络利用有标签数据更新网络参数,提取无标签数据特征FS并输出其预测分布PS
6、特征一致性约束。将增强后的无标签数据输入至教师网络,提取无标签数据特征FT并输出其预测分布PT。将教师网络提取的特征与学生网络进行一致性约束,如图4所示,具体过程为:
将教师网络和学生网络提取的特征FT-T,FT-C,FS-T,FS-C与其自身转置相乘获得协方差特征图
Figure BDA0004083470350000196
教师网络和学生网络对应分支的协方差特征图利用MSE损失函数进行一致性约束,并计算一致性损失/>
Figure BDA0004083470350000197
以提升模型的鲁棒性及泛化能力,如下公式所示。
FΣ=F·(F)T
Figure BDA0004083470350000198
其中MSE为计算平方差,即MSE(x,y)=x2-y2
7、伪标签监督训练。对步骤6中教师模型的预测分布计算信息熵
Figure BDA0004083470350000199
并设定选择阈值对伪标签进行过滤选择,以仅保留高置信度伪标签/>
Figure BDA00040834703500001910
并利用/>
Figure BDA00040834703500001911
监督学生网络的预测分布PS
8、整体损失优化。训练上述网络时,对总体损失进行优化以及反向梯度传播,以更新网络参数。
9、网络参数传递。网络参数更新后,学生网络将其网络参数通过指数移动平均传递给教师网络以进行进一步迭代训练。
10、目标图像分割。对于给定目标域图像,分割模型输出目标图像每个像素点所属类别概率,并选取概率最大的对应的类别作为该点的预测类别。
本发明采用两个多中心的公开数据集(皮肤病分割数据集ISIC和心脏分割数据集ACDC)对本发明的性能进行评估。皮肤病分割数据集ISIC包含两个类别,共2594张图像,选择其中的1838张图像用于训练,剩余的756张用于验证模型。心脏分割数据集ACDC包含三个类别(左心房、右心房、心室壁),包含100个病人的图像,选择其中的70个病人的影像用来训练,10个用来验证,20个用来测试。对于半监督任务来说,分别选择训练集中的3%和10%作为有标签数据,剩余部分作为无标签数据。我们和目前性能最好的半监督框架CTCT分别在两个数据集上进行了对比,其中真实标签为领域专家所标注,分割结果如图6和图7所示。可以发现,本发明在两个数据集上,模型的泛化能力和分割性能均优于现有的表现最好的模型结构,验证了本发明的性能。

Claims (10)

1.一种基于Transformer与CNN交互的半监督医学影像分割方法,其特征在于,包括以下步骤:
(1)教师学生网络构建:教师网络和学生网络都由CNN分支和Transformer分支结合而成,CNN分支提取医学影像特征和Transformer分支提取医学影像特征并进行交互,所述的CNN分支通过C2T模块与Transformer分支交互,所述的Transformer分支通过T2C模块与CNN分支交互;
(2)学生网络训练:利用有标签的数据训练学生网络,以初始化学生网络参数,并将参数复制给教师网络;
(3)生成伪标签:将医学影像中无标签的数据输入至教师网络,以教师网络的预测输出作为无标签的数据的伪标签;
(4)数据增强:利用数据增强方法,为无标签的数据进行数据增强;
(5)学生网络优化及预测:将有标签数据和增强后的无标签数据同时输入至学生网络,学生网络利用有标签数据更新网络参数,提取无标签数据特征FS并输出其预测分布PS
(6)特征一致性约束:将增强后的无标签数据输入至教师网络,提取无标签数据特征FT并输出其预测分布PT,将教师网络提取的特征与学生网络提取的特征进行一致性约束;
(7)伪标签监督训练;对步骤6中教师网络的预测分布PT计算信息熵
Figure FDA0004083470340000013
并设定选择阈值对伪标签进行过滤选择,以仅保留高置信度伪标签/>
Figure FDA0004083470340000011
并利用/>
Figure FDA0004083470340000012
监督学生网络的预测分布PS
(8)整体损失优化:训练上述网络时,对总体损失进行优化以及反向梯度传播,以更新网络参数:
(9)网络参数传递:网络参数更新后,学生网络将其网络参数通过指数移动平均传递给教师网络以进行进一步迭代训练;
(10)目标图像分割:对于给定目标域图像,分割模型输出目标图像每个像素点所属类别概率,选取最大概率的类别作为该像素点预测类别。
2.根据权利要求1所述的基于Transformer与CNN交互的半监督医学影像分割方法,其特征在于,步骤1中,CNN分支通过C2T模块与Transformer分支交互,具体包括:
(i)在每个阶段将CNN分支学习到的特征通过C2T模块传递给Transformer分支,记学生网络中CNN分支第i个阶段输出的特征图为
Figure FDA0004083470340000021
通过全连接层预测的概率图为/>
Figure FDA0004083470340000022
其中/>
Figure FDA0004083470340000023
的维度为B×D×H×W,/>
Figure FDA0004083470340000024
的维度为B×Nc×H×W,B为训练批次的大小,H,W分别为特征图的高和宽,D为特征图的通道数,Nc为类别数,通过两者进行加权平均,得到各个类别的中心/>
Figure FDA0004083470340000025
(ii)对于学生网络中Transformer分支第i个阶段,其输入特征图为
Figure FDA0004083470340000026
分别经过三个全连接层WQ,WK,WV计算Query(Q),Key(K),Value(V),Q、K、V表示为三个不同的特征图,并且计算得到Q与K的相似度矩阵SQ_K
(iii)分别计算步骤(ii)中的Q,K与步骤(i)得到的各个类别的中心
Figure FDA0004083470340000027
的相似度得到相似度矩阵SQ_cls和相似度矩阵SK_cls;/>
(iv)通过步骤(iii)得到的SQ_cls和SK_cls融合了CNN分支的类别中心的信息,将该信息传递给Transformer分支相应的层来优化矩阵SQ_K,得到融合后的相似度矩阵为
Figure FDA0004083470340000028
将该相似度矩阵与特征图V相乘得到增强后的特征FC2T,最终与CNN分支第i个阶段输出的特征
Figure FDA0004083470340000029
相加得到Transformer分支第i个阶段输出的特征图/>
Figure FDA00040834703400000210
3.根据权利要求2所述的基于Transformer与CNN交互的半监督医学影像分割方法,其特征在于,步骤(i)中,各个类别的中心
Figure FDA00040834703400000211
计算为:
Figure FDA00040834703400000212
其中,
Figure FDA00040834703400000213
为学生网络第i个阶段中CNN分支预测的分割概率图上索引位置j上的类别概率,/>
Figure FDA00040834703400000214
为学生网络CNN分支第i个阶段对应的特征图上索引位置j对应的特征向量,H,W为特征图的高和宽。
4.根据权利要求2所述的基于Transformer与CNN交互的半监督医学影像分割方法,其特征在于,步骤(ii)中,Q与K的相似度矩阵SQ_K计算为:
Figure FDA00040834703400000215
Figure FDA00040834703400000216
其中D为特征图的通道维度,
Figure FDA00040834703400000217
为第i-1个阶段的Transformer分支输出的特征图,WQ,WK,WV为三个全连接层。softmax(·)为激活函数。
5.根据权利要求2所述的基于Transformer与CNN交互的半监督医学影像分割方法,其特征在于,步骤(iii)中,分别计算步骤(ii)中的Q,K与步骤(i)得到的各个类别的中心
Figure FDA0004083470340000031
的相似度得到相似度矩阵SQ_cls和相似度矩阵SK_cls,具体包括:
Figure FDA0004083470340000032
其中,D为特征图的通道数。
6.根据权利要求2所述的基于Transformer与CNN交互的半监督医学影像分割方法,其特征在于,步骤(iv)中,融合后的相似度矩阵为
Figure FDA0004083470340000033
的计算包括:
Figure FDA0004083470340000034
步骤(iv)中,特征图
Figure FDA0004083470340000035
作为最终融合输出,具体包括:
Figure FDA0004083470340000036
其中
Figure FDA0004083470340000037
为学生网络Transformer分支第i个阶段融合来自CNN分支的信息后最终输出的特征图,/>
Figure FDA0004083470340000038
为学生网络中CNN分支第i个阶段的输出特征图,FC2T为增强后的特征图,其中,/>
Figure FDA0004083470340000039
为融合后的相似度矩阵,V为特征图。/>
7.根据权利要求1所述的基于Transformer与CNN交互的半监督医学影像分割方法,其特征在于,步骤1中,Transformer分支通过T2C模块与CNN分支交互,具体包括:
(i)学生网络在每个阶段将Transformer学习到的特征通过T2C模块传递给CNN分支,记学生网络中Transformer分支第i个阶段输出的特征图为
Figure FDA00040834703400000310
其中/>
Figure FDA00040834703400000311
的维度为B×N×D,N为序列长度,并且有N=H×W,特征图经过展平和形状变换得到维度B×D×H×W,B为训练批次大小,D为特征通道数,H,W分别为特征图的高和宽,通过全连接层预测的分割概率图为/>
Figure FDA00040834703400000312
其维度为B×Nc×H×W,Nc为类别数,通过两者进行加权平均,得到各个类别的中心/>
Figure FDA00040834703400000313
(ii)学生网络的CNN分支第i个阶段的输出特征图为
Figure FDA00040834703400000314
利用Transformer分支通过步骤i得到的类别中心和CNN分支的特征图计算交叉注意力机制得到增强后的特征为FT2C
(iii)将步骤(ii)得到的特征FT2C与CNN分支特征
Figure FDA00040834703400000315
进行融合得到最终输出特征图
Figure FDA00040834703400000316
8.根据权利要求7所述的基于Transformer与CNN交互的半监督医学影像分割方法,其特征在于,步骤(i)中,各个类别的中心
Figure FDA0004083470340000041
计算过程为:
Figure FDA0004083470340000042
其中,
Figure FDA0004083470340000043
为学生网络中的Transformer分支的第i个阶段通过全连接层预测的分割概率图索引位置j上的类别概率,/>
Figure FDA0004083470340000044
为学生网络中的Transformer分支的第i个阶段输出的特征图索引位置j对应的特征向量,H,W分别为特征图的高和宽;
步骤(ii)中,计算交叉注意力机制得到增强后的特征FT2C,具体包括:
Figure FDA0004083470340000045
Figure FDA0004083470340000046
其中WQ,WK,WV为三个全连接层,
Figure FDA0004083470340000047
为学生网络中的CNN分支的第i个阶段输出的特征图,D为特征图的通道数,softmax(·)为激活函数,/>
Figure FDA0004083470340000048
为类别中心;
步骤(iii)中,融合后输出的特征图
Figure FDA0004083470340000049
的计算包括:
Figure FDA00040834703400000410
其中,FT2C为增强后的特征,
Figure FDA00040834703400000411
为学生网络中的CNN分支的第i个阶段输出的特征图。
9.根据权利要求1所述的基于Transformer与CNN交互的半监督医学影像分割方法,其特征在于,步骤6中,将增强后的无标签数据输入至教师网络,提取无标签数据特征FT并输出其预测分布PT,将教师网络提取的特征与学生网络提取的特征进行一致性约束,具体步骤为:
(i)教师网络提取无标签数据特征FT并输出其预测分布PT
(ii)将教师网络中Transformer分支和CNN分支提取的特征FT-T,FT-C、学生网络中Transformer分支和CNN分支提取的特征FS-T,FS-C分别与自身转置相乘获得协方差特征图
Figure FDA00040834703400000412
(iii)教师网络和学生网络对应分支的协方差特征图利用MSE损失函数进行一致性约束,并计算一致性损失
Figure FDA00040834703400000413
通过一致性损失/>
Figure FDA00040834703400000414
将教师网络提取的特征与学生网络提取的特征进行一致性约束。
10.根据权利要求9所述的基于Transformer与CNN交互的半监督医学影像分割方法,其特征在于,步骤(iii)中,一致性损失
Figure FDA0004083470340000051
计算过程为:
Figure FDA0004083470340000052
其中MSE为计算平方差。
CN202310129552.1A 2023-02-03 2023-02-03 基于Transformer与CNN交互的半监督医学影像分割方法 Pending CN116258695A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202310129552.1A CN116258695A (zh) 2023-02-03 2023-02-03 基于Transformer与CNN交互的半监督医学影像分割方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202310129552.1A CN116258695A (zh) 2023-02-03 2023-02-03 基于Transformer与CNN交互的半监督医学影像分割方法

Publications (1)

Publication Number Publication Date
CN116258695A true CN116258695A (zh) 2023-06-13

Family

ID=86678835

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202310129552.1A Pending CN116258695A (zh) 2023-02-03 2023-02-03 基于Transformer与CNN交互的半监督医学影像分割方法

Country Status (1)

Country Link
CN (1) CN116258695A (zh)

Cited By (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN116741372A (zh) * 2023-07-12 2023-09-12 东北大学 一种基于双分支表征一致性损失的辅助诊断系统及装置
CN116862931A (zh) * 2023-09-04 2023-10-10 北京壹点灵动科技有限公司 医学图像分割方法、装置、存储介质及电子设备
CN117253044A (zh) * 2023-10-16 2023-12-19 安徽农业大学 一种基于半监督交互学习的农田遥感图像分割方法

Cited By (6)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN116741372A (zh) * 2023-07-12 2023-09-12 东北大学 一种基于双分支表征一致性损失的辅助诊断系统及装置
CN116741372B (zh) * 2023-07-12 2024-01-23 东北大学 一种基于双分支表征一致性损失的辅助诊断系统及装置
CN116862931A (zh) * 2023-09-04 2023-10-10 北京壹点灵动科技有限公司 医学图像分割方法、装置、存储介质及电子设备
CN116862931B (zh) * 2023-09-04 2024-01-23 北京壹点灵动科技有限公司 医学图像分割方法、装置、存储介质及电子设备
CN117253044A (zh) * 2023-10-16 2023-12-19 安徽农业大学 一种基于半监督交互学习的农田遥感图像分割方法
CN117253044B (zh) * 2023-10-16 2024-05-24 安徽农业大学 一种基于半监督交互学习的农田遥感图像分割方法

Similar Documents

Publication Publication Date Title
Wang et al. Mixed transformer u-net for medical image segmentation
Mou et al. Vehicle instance segmentation from aerial image and video using a multitask learning residual fully convolutional network
Huang et al. Instance-aware image and sentence matching with selective multimodal lstm
CN116258695A (zh) 基于Transformer与CNN交互的半监督医学影像分割方法
CN109766840A (zh) 人脸表情识别方法、装置、终端及存储介质
Wang et al. When cnn meet with vit: Towards semi-supervised learning for multi-class medical image semantic segmentation
Cai et al. A robust interclass and intraclass loss function for deep learning based tongue segmentation
Wang et al. An uncertainty-aware transformer for MRI cardiac semantic segmentation via mean teachers
Wang et al. Medical matting: a new perspective on medical segmentation with uncertainty
Xu et al. Vision transformers for computational histopathology
CN114898407A (zh) 一种基于深度学习牙齿目标实例分割及其智能预览的方法
Zhao et al. Deeply supervised active learning for finger bones segmentation
Li et al. NIA-Network: Towards improving lung CT infection detection for COVID-19 diagnosis
CN111126155A (zh) 一种基于语义约束生成对抗网络的行人再识别方法
Shu et al. Privileged multi-task learning for attribute-aware aesthetic assessment
Yang et al. GGAC: Multi-relational image gated GCN with attention convolutional binary neural tree for identifying disease with chest X-rays
CN111898756B (zh) 一种多目标信息关联神经网络损失函数计算方法及装置
CN116759076A (zh) 一种基于医疗影像的无监督疾病诊断方法及系统
CN113590971B (zh) 一种基于类脑时空感知表征的兴趣点推荐方法及系统
CN114299342B (zh) 一种基于深度学习的多标记图片分类中未知标记分类方法
Wang et al. Optimized lightweight CA-transformer: Using transformer for fine-grained visual categorization
Cheng et al. Double attention for pathology image diagnosis network with visual interpretability
Zhao et al. VCMix-Net: A hybrid network for medical image segmentation
Yang et al. Robust feature mining transformer for occluded person re-identification
WO2024108522A1 (zh) 基于自监督学习的多模态脑肿瘤影像分割方法

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination