CN116091449A - 一种基于无监督异构蒸馏框架的视网膜oct图像病变分类方法 - Google Patents

一种基于无监督异构蒸馏框架的视网膜oct图像病变分类方法 Download PDF

Info

Publication number
CN116091449A
CN116091449A CN202310020402.7A CN202310020402A CN116091449A CN 116091449 A CN116091449 A CN 116091449A CN 202310020402 A CN202310020402 A CN 202310020402A CN 116091449 A CN116091449 A CN 116091449A
Authority
CN
China
Prior art keywords
stage
network
feature
teacher
student
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202310020402.7A
Other languages
English (en)
Inventor
李慧琦
陆帅
赵赫
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Beijing Institute of Technology BIT
Original Assignee
Beijing Institute of Technology BIT
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Beijing Institute of Technology BIT filed Critical Beijing Institute of Technology BIT
Priority to CN202310020402.7A priority Critical patent/CN116091449A/zh
Publication of CN116091449A publication Critical patent/CN116091449A/zh
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06TIMAGE DATA PROCESSING OR GENERATION, IN GENERAL
    • G06T7/00Image analysis
    • G06T7/0002Inspection of images, e.g. flaw detection
    • G06T7/0012Biomedical image inspection
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/40Extraction of image or video features
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/764Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
    • G06V10/765Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects using rules for classification or partitioning the feature space
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/82Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06TIMAGE DATA PROCESSING OR GENERATION, IN GENERAL
    • G06T2207/00Indexing scheme for image analysis or image enhancement
    • G06T2207/30Subject of image; Context of image processing
    • G06T2207/30004Biomedical image processing
    • G06T2207/30041Eye; Retina; Ophthalmic
    • YGENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
    • Y02TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
    • Y02TCLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
    • Y02T10/00Road transport of goods or passengers
    • Y02T10/10Internal combustion engine [ICE] based vehicles
    • Y02T10/40Engine management systems

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • General Physics & Mathematics (AREA)
  • Health & Medical Sciences (AREA)
  • Evolutionary Computation (AREA)
  • General Health & Medical Sciences (AREA)
  • Software Systems (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Multimedia (AREA)
  • Medical Informatics (AREA)
  • Artificial Intelligence (AREA)
  • Computing Systems (AREA)
  • Databases & Information Systems (AREA)
  • Quality & Reliability (AREA)
  • Radiology & Medical Imaging (AREA)
  • Nuclear Medicine, Radiotherapy & Molecular Imaging (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Computational Linguistics (AREA)
  • Data Mining & Analysis (AREA)
  • Molecular Biology (AREA)
  • General Engineering & Computer Science (AREA)
  • Mathematical Physics (AREA)
  • Image Analysis (AREA)

Abstract

本发明涉及一种基于无监督异构蒸馏框架的视网膜OCT图像病变分类方法,属于图像分类技术领域。该方法包含一个教师网络和一个学生网络,教师网络以在自然图像上预先训练的参数作为教师网络的初始参数,训练所述方法时只需要少量正常的视网膜OCT图像,并且在训练时教师网络不更新参数,学生网络以教师网络的特征为输入,并且学习教师网络产生的浅层特征,在测试阶段,通过对比教师网络和学生网络产生特征的差异来判断待测图像是否是病变图像,这样能有缓解训练深度学习网络需要大量医学图像标注的问题。

Description

一种基于无监督异构蒸馏框架的视网膜OCT图像病变分类方法
技术领域
本发明涉及一种基于无监督异构蒸馏框架的视网膜OCT图像病变分类方法,属于图像分类技术领域。
背景技术
据世界卫生组织统计,在2010年全球大约有3亿人受到眼疾困扰,其中包括3900万人失明。大约80%的视力损伤可以通过预防可以得到避免。在所有可能导致视力受损的因素中,眼底病变是一个重要因素。常见的眼底病有糖尿病黄斑水肿、视网膜阻塞和青光眼等。
眼底疾病的预防和早期针对可以避免失明和视力的损伤。光学相干断层扫描技术(Optical Coherence Tomography,OCT)作为一种新型的医学成像技术被用于眼科疾病的诊断和治疗。OCT成像技术具有无创、无侵入的优点,因此适用于眼底织成像。它可通过采集到的二维扫描切片对视网膜进行三维建模,极大便利了医生的诊断。通过OCT技术获取黄斑中心和视盘中心附近的扫描图像,就可初步地对视网膜形态进行评估。由于OCT技术可以获得更深层和更细致的视网膜结构信息,使得对眼部疾病的定性以及判断更加精确。
视网膜OCT图像分类方法可以分为基于手工特征的传统方法和基于卷积神经网络(CNN)的深度学习方法。传统方法主要包括边缘检测方法、阈值方法、色差方法和超像素方法。这些方法主要基于手工特征进行图像分类,容易受到图像质量和噪声损伤的影响。与传统方法相比,卷积神经网络可以自动从图像中提取特征。许多基于CNN的变体已经被提出来分类视网膜OCT图像。虽然基于CNN的方法比手工制作的基于特征的方法具有更好的性能,但是基于CNN的OCT图像分类方法需要大量的医学标注用于模型的训练才能提高模型的性能。
不同于自然图像的标注,医学图像的病变标注必须要有经验丰富的医生进行标注。经典的深度学习方法用于视网膜OCT图像分类需要大量带有标注的视网膜OCT图像。然而标注大量的视网膜OCT图像将会给医生带来巨大的负担,一些具有较高准确度的无监督方法成为当前医学图像分析的焦点。
发明内容
针对现有技术中的缺陷,本发明提供一种视网膜OCT图像病变分类方法,用于解决现有技术需要大量医学标注才能实现模型较高性能的问题。
本发明的技术解决方案是:
一种基于无监督特征蒸馏框架的视网膜OCT图像病变分类方法,该方法中使用一个教师网络和一个学生网络,教师网络以在自然图像上预先训练的参数作为教师网络的初始参数,训练所述方法时只需要少量正常的视网膜OCT图像,并且在训练时教师网络不更新参数,学生网络以教师网络的特征为输入,并且学习教师网络产生的浅层特征,在测试阶段,通过对比教师网络和学生网络产生特征的差异来判断待测图像是否是病变图像,这样能有缓解训练深度学习网络需要大量医学图像标注的问题;
该方法具体包含以下步骤:
S1,对输入的视网膜OCT图像进行预处理,教师网络提取预处理后的OCT图像的特征,得到四组不同尺度的特征,四组不同尺度的特征分别为教师网络的第一阶段特征
Figure BDA0004041638510000021
教师网络的第二阶段特征
Figure BDA0004041638510000022
教师网络的第三阶段特征
Figure BDA0004041638510000023
和教师网络的第四阶段特征
Figure BDA0004041638510000024
S2,学生网络以步骤S1得到的教师网络的第四阶段特征
Figure BDA0004041638510000025
作为输入,生成三组不同尺度的特征,三组不同尺度的特征分别为学生网络的第三阶段特征
Figure BDA0004041638510000026
学生网络的第二阶段特征
Figure BDA0004041638510000027
和学生网络的第一阶段特征
Figure BDA0004041638510000028
学生网络是一个卷积神经网络和transformer模块混合而成的混合网络;
S3,对学生网络进行参数优化,学生网络的优化目标是使得学生网络的第一阶段特征
Figure BDA0004041638510000029
与教师网络的第一阶段特征
Figure BDA00040416385100000210
更接近、学生网络的第二阶段特征
Figure BDA0004041638510000031
与教师网络的第二阶段特征
Figure BDA0004041638510000032
更接近、学生网络的第三阶段特征
Figure BDA0004041638510000033
与教师网络的第三阶段特征
Figure BDA0004041638510000034
更接近,最终得到优化后的学生网络;
S4,使用教师网络提取待测视网膜OCT图像的特征,得到四组不同尺度的特征,四组不同尺度的特征分别为教师网络的第一阶段特征
Figure BDA0004041638510000035
教师网络的第二阶段特征
Figure BDA0004041638510000036
教师网络的第三阶段特征
Figure BDA0004041638510000037
和教师网络的第四阶段特征
Figure BDA0004041638510000038
S5,以步骤S4得到的教师网络的第四阶段特征
Figure BDA0004041638510000039
作为步骤S3优化后的学生网络的输入,生成三组不同尺度的特征,三组不同尺度的特征分别为学生网络的第三阶段特征
Figure BDA00040416385100000310
学生网络的第二阶段特征
Figure BDA00040416385100000311
和学生网络的第一阶段特征
Figure BDA00040416385100000312
S6,计算步骤S4得到的教师网络的第一阶段特征
Figure BDA00040416385100000313
与步骤S5得到的学生网络的第一阶段特征
Figure BDA00040416385100000314
的相似度,进而用于计算第一阶段病变得分Score1;同理,计算教师网络的第二阶段特征
Figure BDA00040416385100000315
与步骤S5得到的学生网络的第二阶段特征
Figure BDA00040416385100000316
的相似度,进而计算第二阶段病变得分Score2;计算教师网络的第三阶段特征
Figure BDA00040416385100000317
与步骤S5得到的学生网络的第三阶段特征
Figure BDA00040416385100000318
的相似度,进而计算第三阶段病变得分Score3,并将第一阶段病变得分Score1、第二阶段病变得分Score2和第三阶段病变得分Score3相加,得到待测图像的病变得分Score。
所述的步骤S1中,对输入的视网膜OCT图像进行预处理具体为:将输入的OCT图像压缩至分辨率为(H,W)的大小,其中H的取值范围为112~448像素,W与H相同;
所述的步骤S1中,教师网络是一个在ImageNet大规模数据集上预先训练过的卷积神经网络(也称作CNN),优选地,教师网络可选择ResNet、DenseNet和VGGNet等分类卷积神经网络;
所述教师网络使用在ImageNet数据集上预先训练好的权重作为初始化,并且训练阶段教师网络的参数权重不更新;
所述教师网络继承了经典分类卷积神经网络四个阶段中特征提取块的结构,但是教师网络将经典卷积分类网络中最后的全连接层删除,教师网络的四个阶段特征提取块分别产生步骤S1中所述的教师网络产生的四组不同尺度的特征,分别记为教师网络的第一阶段特征
Figure BDA0004041638510000041
教师网络第二阶段特征
Figure BDA0004041638510000042
教师网络第三阶段特征
Figure BDA0004041638510000043
和教师网络第四阶段特征
Figure BDA0004041638510000044
四个阶段特征提取块分别为第一个阶段特征器、第二个阶段特征器、第三个阶段特征器、第四个阶段特征器;
所述S1中教师网络产生所述四组不同尺度的特征,其中所述教师网络第一阶段特征、第二阶段特征、第三阶段特征和第四阶段特征的提取方法为:
将预处理后分辨率为(H,W)的OCT图像输入到教师网络第一个阶段特征器后得到教师网络的第一阶段特征
Figure BDA0004041638510000045
其中特征
Figure BDA0004041638510000046
的维度为(H/4,W/4,64);教师网络的第二阶段特征提取器将第一阶段特征
Figure BDA0004041638510000047
提取压缩成教师网络的第二阶段特征
Figure BDA0004041638510000048
其中
Figure BDA0004041638510000049
维度为(H/8,W/8,128);进一步,教师网络的第三阶段特征提取器将第二阶段特征
Figure BDA00040416385100000410
提取压缩为教师网络的第三阶段特征
Figure BDA00040416385100000411
其中
Figure BDA00040416385100000412
维度为(H/16,W/16,256);最后,教师网络的第四阶段特征提取器将第三阶段特征
Figure BDA00040416385100000413
提取压缩为教师网络的第四阶段特征
Figure BDA00040416385100000414
其中
Figure BDA00040416385100000415
特征维度为(H/32,W/32,512);
在所述步骤S2中,学生网络生成的三组不同尺度的特征依次为学生网络的第三阶段特征
Figure BDA00040416385100000416
学生网络的第二阶段特征
Figure BDA00040416385100000417
和学生网络的第一阶段特征
Figure BDA00040416385100000418
其中生成三组不同尺度的特征的方法为:
所述学生网络以教师网络第四阶段特征
Figure BDA00040416385100000419
为输入,然后学生网络对输入的第四阶段的特征进行处理,学生网络共包括三个阶段来生成多尺度的特征,学生网络将第四阶段特征
Figure BDA00040416385100000420
依次解码至与所述教师网络三个阶段特征相一致的尺度,其中教师网络的三个阶段特征分别为教师的第一阶段特征
Figure BDA00040416385100000421
第二阶段特征
Figure BDA00040416385100000422
和第三阶段特征
Figure BDA00040416385100000423
所述学生网络将教师网络的第四阶段特征
Figure BDA00040416385100000424
处理为学生网络的第三阶段特征
Figure BDA00040416385100000425
其中
Figure BDA00040416385100000426
维度为
Figure BDA00040416385100000427
H3、W3和C3分别代表特征的高度、宽度和通道数目;进一步,学生网络将第三阶段特征
Figure BDA00040416385100000428
生成第二阶段特征
Figure BDA00040416385100000429
其中
Figure BDA00040416385100000430
维度为
Figure BDA00040416385100000431
最后,学生网络将第二阶段特征
Figure BDA00040416385100000432
生成第一阶段特征
Figure BDA00040416385100000433
其中
Figure BDA00040416385100000434
维度为
Figure BDA00040416385100000435
在所述的步骤S2中,学生网络是一个卷积神经网络和transformer模块混合而成的混合网络,学生网络的三个阶段都是由模块单元组成,模块单元包括CNN子块(也称作卷积子块)和transformer子块,所述模块单元为:
所述的模块单元中包含两个并行的子块,它们分别是CNN子块和transformer子块;
首先模块单元通过1×1卷积对输入的特征进行维度调整,调整后的新特征被分成两个特征组,分别记为混合模块的第一组特征和混合模块的第二组特征,卷积模块的第一组特征通过CNN子块后产生特征FConv,混合模块的第二组特征通过transformer子块后产生特征FTran,最终特征FConv和特征FTran并列堆叠到一起生成特征FTran-Conv,特征FTran-conv的通道数通过使用1×1的卷积来调整;
所述的模块单元中CNN子块具体结构为:卷积子块包含两个连续的卷积核大小为3×3的卷积,其中卷积的步长为1×1和填充padding为1;
如图3所示,所述模块单元中transformer子块为多尺度稀疏transformer模块,具体结构为:
多尺度稀疏transformer模块包括特征聚合模块、多头注意力机制(MCA)和多层感知机(MLP)三部分,假定输入到多尺度稀疏transformer模块的特征为输入特征F,特征F维度为
Figure BDA0004041638510000051
其中(Hi,Wi)表示在学生网络中第i个阶段的特征F的分辨率,Ci表示在学生网络中第i个阶段通道的维度,其中4C1=2C2=C3,H1=2H2=4H3和W1=2W2=4W3;优选地,H1和W1的取值范围56~224像素,C3的范围256~1024;
其中,所述特征聚合模块生成两种类型的特征,生成第一种特征是局部特征Flocal和生成的第二种特征是区域特征Fregion,其中Flocal的维度为
Figure BDA0004041638510000052
和Fregion的维度为RCi ×(Hj·j)
具体地,所述局部特征Flocal(第一种特征)如下方式获得:
Figure BDA0004041638510000061
j=1,…,Nl,Nl=Hi·Wi,
其中,特征
Figure BDA0004041638510000062
是多尺度稀疏transformer模块的输入特征F形变后的特征,
Figure BDA0004041638510000063
的维度是
Figure BDA0004041638510000064
fi表示
Figure BDA0004041638510000065
中特征的分量,fi维度为
Figure BDA0004041638510000066
Nl=Hi·Wi代表分量的个数,
Figure BDA0004041638510000067
代表位置嵌入特征;
具体得,所述区域特征Fregion(第二种特征)如下方式获得:
首先,所述输入特征
Figure BDA0004041638510000068
使用大小为(p,p)的块分成互不相交的特征块序列Fp,特征序列Fp的维度为
Figure BDA0004041638510000069
其中
Figure BDA00040416385100000610
表示序列块的数量和
Figure BDA00040416385100000611
表示第i个块;优选地,p的取值范围1~8像素;
然后,互不相交的特征块列Fp被矩阵E映射变换成序列块
Figure BDA00040416385100000612
其中矩阵E的维度为
Figure BDA00040416385100000613
进一步,一个可学习的位置嵌入Epos与特征序列块
Figure BDA00040416385100000614
相加生成区域特征的分量
Figure BDA00040416385100000615
的维度为
Figure BDA00040416385100000616
和Epos的维度为
Figure BDA00040416385100000617
区域特征的分量
Figure BDA00040416385100000618
如下计算:
Figure BDA00040416385100000619
j=1,…,Np,
所述多尺度稀疏transformer模块中的多头注意力机制为:
首先,所述局部特征Flocal被矩阵
Figure BDA00040416385100000620
线性映射到Queriy值(Queriy值记为Qj),其中矩阵
Figure BDA00040416385100000621
维度为
Figure BDA00040416385100000622
Qj的维度为
Figure BDA00040416385100000623
Figure BDA00040416385100000624
代表单头注意力机制的维度和Ni代表第i个阶段单头注意机制的数目;
同时Fregion被矩阵
Figure BDA00040416385100000625
线性映射到Key键值(Key键值记作
Figure BDA00040416385100000626
),其中
Figure BDA00040416385100000627
的维度为
Figure BDA00040416385100000628
Figure BDA00040416385100000629
的维度为
Figure BDA00040416385100000630
Fregion也被矩阵
Figure BDA00040416385100000631
线性映射到Value值(Value值记作
Figure BDA00040416385100000632
),其中
Figure BDA00040416385100000633
的维度为
Figure BDA00040416385100000634
维度为
Figure BDA00040416385100000635
Query值(Qj),key值
Figure BDA00040416385100000636
和value值
Figure BDA00040416385100000637
可以被如下定义:
Figure BDA0004041638510000071
Figure BDA0004041638510000072
Figure BDA0004041638510000073
所述多头注意力机制中的计算单头注意力机制计算Query值(Qj),key值
Figure BDA0004041638510000074
和value值
Figure BDA0004041638510000075
的关系,如下所示:
Figure BDA0004041638510000076
进一步,
Figure BDA0004041638510000077
个单头注意力被合并在一起来获得多头注意力机制(MCA),多头注意力被如下表示:
Figure BDA0004041638510000078
Figure BDA0004041638510000079
进一步,两个尺度的多头注意力机制被合并在一起。具体的说,具有分块大小为p1的多头注意力机制
Figure BDA00040416385100000710
和具有分块大小为p2的多头注意力机制
Figure BDA00040416385100000711
被合并在一起来得到特征Z,公式如下所示:
Figure BDA00040416385100000712
最后,所述特征Z通过正则化层(LN)和多层感知机(MLP)来增强特征获得最终的特征
Figure BDA00040416385100000713
特征
Figure BDA00040416385100000714
可以由如下公式表示:
Figure BDA00040416385100000715
其中p1和p2表示分块的大小,MLP表示多层感知机和LN表示层正则化;
所述步骤S3中对学生网络进行参数优化,学生网络的优化目标是使得学生网络的三组特征和教师网络的三组特征更接近的方法具体为:
Figure BDA00040416385100000716
代表第k阶段(h,w)位置的教师网络的特征,
Figure BDA00040416385100000717
代表第k阶段(h,w)位置的学生网络的特征;Lk(h,w)表示第k阶段(h,w)位置教师网络特征和学生网络特征的损失,Lk(h,w)损失主要由余弦函数cos和平方损失mse加权组成,具体的数学公式如下:
Figure BDA0004041638510000081
其中α代表权重,优选地,取值范围在(0,1)之间;
最终,学生网络完整的损失
Figure BDA0004041638510000082
是由三个阶段损失进行相加,如下公式表示:
Figure BDA0004041638510000083
其中(Hk,Wk)代表第k个阶段特征的分辨率大小,K代表学生所有的阶段数目。
所述S6中计算教师网络和学生网络的特征相似度最终获得待测图像的病变得分,其中病变得分具体如下实现:
Figure BDA0004041638510000084
代表第k阶段(h,w)位置的教师网络的特征,
Figure BDA00040416385100000818
代表第k阶段(h,w)位置的学生网络的特征;
Figure BDA0004041638510000085
代表特征
Figure BDA0004041638510000086
和特征
Figure BDA0004041638510000087
的相似度,
Figure BDA0004041638510000088
代表病变得分;
计算教师网络的第一阶段特征
Figure BDA0004041638510000089
与步骤S5得到的学生网络的第一阶段特征
Figure BDA00040416385100000810
的相似度,进而计算第一阶段病变得分Score1,如下表示,
Figure BDA00040416385100000811
计算教师网络的第二阶段特征
Figure BDA00040416385100000812
与步骤S5得到的学生网络的第二阶段特征
Figure BDA00040416385100000813
的相似度,进而计算第二阶段病变得分Score2,如下所示,
Figure BDA00040416385100000814
计算教师网络的第三阶段特征
Figure BDA00040416385100000815
与步骤S5得到的学生网络的第三阶段特征
Figure BDA00040416385100000816
的相似度,
Figure BDA00040416385100000817
最终学生三个阶段的特征和教师网络三个阶段的病变得分进行求和为最终图像的病变得分Score,具体公式如下表示,
Figure BDA0004041638510000091
有益效果
本发明方法,与相关技术相比较,具有以下优点:
1.所述方法是一个基于无监督异构知识蒸馏的框架用于视网膜OCT图像病变分类。所述方法只需要对少量的正常样本的特征分布进行学习就能实现较高性能的病变分类性能。在所述框架中教师网络是一个通用的基于卷积神经网络的分类网络,而学生网络是一个基于CNN和transformer混合的网络。学生网络和教师网络是异构结构,在训练过程中,只使用正常样本来让学生网络学习教师网络的特征。当在测试中通过计算学生网络和教师网络的特征差异来实现病变检测。
2.为了充分发挥所述无监督异构知识蒸馏框架中异构的优点,一个多尺度稀疏transformer被设计来提升所述分类方法的病变分类性能。所述多尺度稀疏transformer在能够建模长距离特征依赖的基础上,还能够缓解transformer方法消耗较高计算量和较高内存占用的问题。
3.所述分类方法中将教师网络的特征输出作为学生网络的特征输入等价于将自编码器特征压缩和特征重构的思想引入到蒸馏方法中,这种特征压缩和特征恢复的结构能够实现缓解无监督蒸馏方法在视网膜OCT图像病变中过检测的问题。
附图说明
图1为本发明方法及实施例中的流程示意图;
图2为本发明实施例提供的一种基于无监督异构蒸馏网络的视网膜OCT图像病变分类方法的结构示意图;
图3为本发明实施例中基于多尺度稀疏transformer模块结构示意图。
具体实施方式
为使本发明实施例的目的、技术方案和优点更加清楚,下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有作出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
本发明实施例设计了一种基于无监督异构蒸馏框架的视网膜OCT图像病变分类方法,该方法设计了结构相异的教师网络和学生网络作为蒸馏框架的两个子网络。为了让学生网络与教师网络具有更大的差异,一个多尺度稀疏的transformer模块被提出来建模长距离特征关系和降低transformer固有的高计算代价的问题。进一步,教师模型的输出被作为学生网络的输入,这实现了将特征压缩与特征恢复引入到了蒸馏框架中,这能够有效实现对正常OCT图像特征的保留和对异常OCT图像特征的去处的目的,从而实现最终OCT图像的病变分类。
图1为本发明方法及实施例中的流程示意图,如图1所示,它包含以下6个步骤:
S1,对输入的视网膜OCT图像进行预处理,教师网络提取预处理后的OCT图像的特征,得到四组不同尺度的特征,四组不同尺度的特征分别为教师网络的第一阶段特征
Figure BDA0004041638510000101
教师网络的第二阶段特征
Figure BDA0004041638510000102
教师网络的第三阶段特征
Figure BDA0004041638510000103
和教师网络的第四阶段特征
Figure BDA0004041638510000104
S2,学生网络以步骤S1得到的教师网络的第四阶段特征
Figure BDA0004041638510000105
作为输入,生成三组不同尺度的特征,三组不同尺度的特征分别为学生网络的第三阶段特征
Figure BDA0004041638510000106
学生网络的第二阶段特征
Figure BDA0004041638510000107
和学生网络的第一阶段特征
Figure BDA0004041638510000108
学生网络是一个卷积神经网络和transformer模块混合而成的混合网络;
S3,对学生网络进行参数优化,学生网络的优化目标是使得学生网络的第一阶段特征
Figure BDA0004041638510000109
与教师网络的第一阶段特征
Figure BDA00040416385100001010
更接近、学生网络的第二阶段特征
Figure BDA0004041638510000111
与教师网络的第二阶段特征
Figure BDA0004041638510000112
更接近、学生网络的第三阶段特征
Figure BDA0004041638510000113
与教师网络的第三阶段特征
Figure BDA0004041638510000114
更接近,最终得到优化后的学生网络;
S4,使用教师网络提取待测视网膜OCT图像的特征,得到四组不同尺度的特征,四组不同尺度的特征分别为教师网络的第一阶段特征
Figure BDA0004041638510000115
教师网络的第二阶段特征
Figure BDA0004041638510000116
教师网络的第三阶段特征
Figure BDA0004041638510000117
和教师网络的第四阶段特征
Figure BDA0004041638510000118
S5,以步骤S4得到的教师网络的第四阶段特征
Figure BDA0004041638510000119
作为步骤S3优化后的学生网络的输入,生成三组不同尺度的特征,三组不同尺度的特征分别为学生网络的第三阶段特征
Figure BDA00040416385100001110
学生网络的第二阶段特征
Figure BDA00040416385100001111
和学生网络的第一阶段特征
Figure BDA00040416385100001112
S6,计算步骤S4得到的教师网络的第一阶段特征
Figure BDA00040416385100001113
与步骤S5得到的学生网络的第一阶段特征
Figure BDA00040416385100001114
的相似度,进而用于计算第一阶段病变得分Score1;同理,计算教师网络的第二阶段特征
Figure BDA00040416385100001115
与步骤S5得到的学生网络的第二阶段特征
Figure BDA00040416385100001116
的相似度,进而计算第二阶段病变得分Score2;计算教师网络的第三阶段特征
Figure BDA00040416385100001117
与步骤S5得到的学生网络的第三阶段特征
Figure BDA00040416385100001118
的相似度,进而计算第三阶段病变得分Score3,并将第一阶段病变得分Score1、第二阶段病变得分Score2和第三阶段病变得分Score3相加,得到待测图像的病变得分Score。
步骤1:对输入的视网膜OCT图像进行预处理,然后教师网络对预处理后的OCT图像提取特征并且依次产生四组不同尺度的特征;
步骤1.1:对输入的视网膜OCT图像进行预处理,将输入的视网膜OCT图像压缩至分辨率为(256,256)的大小;
步骤1.2:教师网络对预处理后的OCT图像提取特征并且依次产生四组不同尺度的特征;其中的教师网络是一个在ImageNet大规模数据集上进行预先训练的卷积神经网络;教师网络使用ResNet系列网络,并且ResNet网络在ImageNet大规模数据集上进行预先训练的参数被用做网络的权重初始化,在方法的训练阶段教师网络的参数停止更新;所述教师网络将ResNet分类网络中最后的全连接层删除,仅仅继承了ResNet分类网络四个阶段特征特征提取块的结构;四个阶段特征提取块分别为第一个阶段特征器、第二个阶段特征器、第三个阶段特征器、第四个阶段特征器;
进一步,如图2中(a)所示四个特征提取块分别产生步骤1.2中所述的教师网络产生的四组不同尺度的特征,分别记为教师网络的第一阶段特征
Figure BDA0004041638510000121
教师网络第二阶段特征
Figure BDA0004041638510000122
教师网络的第三阶段特征
Figure BDA0004041638510000123
和教师网络第四阶段特征
Figure BDA0004041638510000124
其中所述教师网络第一阶段特征、第二阶段特征、第三阶段特征和第四阶段特征的提取方法为:
将预处理后分辨率为(256,256)的OCT图像输入到教师网络第一个阶段特征器后特征的维度转变成(64,64,64)得到教师网络的第一阶段特征
Figure BDA0004041638510000125
教师网络的第二阶段特征提取器将第一阶段特征提取压缩成教师网络的第二阶段特征
Figure BDA0004041638510000126
其维度为(32,32,128);进一步,教师网络的第三阶段特征提取器将第二阶段特征提取压缩为教师网络的第三阶段特征
Figure BDA0004041638510000127
其维度为(16,16,256);最后,教师网络的第四阶段特征提取器将第三阶段特征提取压缩为教师网络的第四阶段特征
Figure BDA0004041638510000128
其中特征维度为(8,8,512)。
步骤2:如图2中(a)所示,学生网络生成的三组不同尺度的特征依次为学生网络的第三阶段特征
Figure BDA0004041638510000129
学生网络的第二阶段特征
Figure BDA00040416385100001210
和学生网络的第一阶段特征
Figure BDA00040416385100001211
具体的实现方式为:
所述学生网络以教师网络第四阶段特征为输入,然后学生网络对输入的特征进行处理,学生网络共包括三个阶段来生成多尺度的特征。学生模型将特征依次解码至与所述教师网络三个阶段相一致的尺度;所述教师网络的第四阶段特征经过学生模型的第一个阶段变成了特征
Figure BDA00040416385100001212
其中维度为
Figure BDA00040416385100001213
同理,学生网络的第二阶段生成特征
Figure BDA00040416385100001214
其中
Figure BDA00040416385100001215
维度为
Figure BDA00040416385100001216
学生网络的第三阶段生成特征
Figure BDA00040416385100001217
其中
Figure BDA00040416385100001218
维度为
Figure BDA00040416385100001219
在所述步骤中学生网络是一个卷积神经网络和transformer模块混合而成的混合网络,学生网络的三个阶段都是由模块单元组成,模块单元包括CNN子块(也称作卷积子块)和transformer子块,所述模块单元为:
所述的模块单元中包含两个并行的子块,它们分别是CNN子块和transformer子块;
首先模块单元通过1×1卷积对输入的特征进行维度调整,调整后的新特征被分成两个特征组,分别记为混合模块的第一组特征和混合模块的第二组特征,卷积模块的第一组特征通过CNN子块后产生特征FConv,混合模块的第二组特征通过transformer子块后产生特征FTran,最终特征FConv和特征FTran并列堆叠到一起生成特征FTran-Conv,特征FTran-Conv的通道数通过使用1×1的卷积来调整;
所述的模块单元中CNN子块具体结构为:卷积子块包含两个连续的卷积核大小为3×3的卷积,其中卷积的步长为1×1和填充padding为1;
所述模块单元中transformer子块为多尺度稀疏transformer模块,具体结构为:
多尺度稀疏transformer模块包括特征聚合模块、多头注意力机制计算(MCA)和多层感知机(MLP)三部分,假定输入到多尺度稀疏transformer模块的特征为输入特征F,特征F维度为
Figure BDA0004041638510000131
其中(Hi,Wi)表示在学生网络中第i个阶段的特征F的分辨率,Ci表示在学生网络中第i个阶段通道的维度,其中4C1=2C2=C3,H1=2H2=4H3和W1=2W2=4W3;H1=64,W1=64和C3=512;
其中,所述特征聚合模块生成两种类型的特征,生成第一种特征是局部特征Flocal和生成的第二种特征是区域特征Fregion,其中Flocal的维度为
Figure BDA0004041638510000132
和Fregion的维度为
Figure BDA0004041638510000133
具体地,所述局部特征Flocal(第一种特征)如下方式获得:
Figure BDA0004041638510000134
j=1,…,Nl,Nl=Hi·Wi,
其中,特征
Figure BDA0004041638510000135
是多尺度稀疏transformer模块的输入特征F形变后的特征,
Figure BDA0004041638510000136
的维度是
Figure BDA0004041638510000137
fi表示
Figure BDA0004041638510000138
中特征的分量,fi维度为
Figure BDA0004041638510000139
Nl=Hi·Wi代表分量的个数,
Figure BDA00040416385100001310
代表位置嵌入特征;
具体得,所述区域特征Fregion(第二种特征)如下方式获得:
首先,所述输入特征
Figure BDA0004041638510000141
使用大小为(p,p)的块分成互不相交的特征块序列Fp,特征序列Fp的维度为
Figure BDA0004041638510000142
其中
Figure BDA0004041638510000143
表示序列块的数量和
Figure BDA0004041638510000144
表示第i个块;
然后,互不相交的特征块列Fp被矩阵E映射变换成序列块
Figure BDA0004041638510000145
其中矩阵E的维度为
Figure BDA0004041638510000146
进一步,一个可学习的位置嵌入Epos与特征序列块
Figure BDA0004041638510000147
相加生成区域特征的分量
Figure BDA0004041638510000148
的维度为
Figure BDA0004041638510000149
和Epos的维度为
Figure BDA00040416385100001410
区域特征的分量
Figure BDA00040416385100001411
如下计算:
Figure BDA00040416385100001412
j=1,…,Np,
所述多尺度稀疏transformer模块中的多头注意力机制为:
首先,所述局部特征Flocal被矩阵
Figure BDA00040416385100001413
线性映射到Queriy值(Queriy值记为Qj),其中矩阵
Figure BDA00040416385100001414
维度为
Figure BDA00040416385100001415
Qj的维度为
Figure BDA00040416385100001416
Figure BDA00040416385100001417
代表单头注意力机制的维度和Ni代表第i个阶段单头注意机制的数目;
其中,不同阶段学生网络的总通道数目为:C1=64,C2=128,C3=256.
其中,多头注意力机制的数目Ni在不同的i阶段为:N1=2,N2=4,N3=8.
同时Fregion被矩阵
Figure BDA00040416385100001418
线性映射到Key键值(Key键值记作
Figure BDA00040416385100001419
),其中
Figure BDA00040416385100001420
的维度为
Figure BDA00040416385100001421
Figure BDA00040416385100001422
的维度为
Figure BDA00040416385100001423
Fregion也被矩阵
Figure BDA00040416385100001424
线性映射到Value值(Value值记作
Figure BDA00040416385100001425
),其中
Figure BDA00040416385100001426
的维度为
Figure BDA00040416385100001427
Figure BDA00040416385100001428
维度为
Figure BDA00040416385100001429
Query值(Qj),key值
Figure BDA00040416385100001430
和value值
Figure BDA00040416385100001431
可以被如下定义:
Figure BDA00040416385100001432
Figure BDA00040416385100001433
Figure BDA00040416385100001434
所述多头注意力机制中的计算单头注意力机制计算Query值(Qj),key值
Figure BDA0004041638510000151
和value值
Figure BDA0004041638510000152
的关系,如下所示:
Figure BDA0004041638510000153
进一步,
Figure BDA0004041638510000154
个单头注意力被合并在一起来获得多头注意力(MCA),多头注意力被如下表示:
Figure BDA0004041638510000155
Figure BDA0004041638510000156
进一步,两个尺度的多头注意力机制被合并在一起。具体的说,具有分块大小为p1的多头注意力机制
Figure BDA0004041638510000157
和具有分块大小为p2的多头注意力机制
Figure BDA0004041638510000158
被合并在一起来得到特征Z,公式如下所示;
Figure BDA0004041638510000159
最后,所述特征Z通过正则化层(LN)和多层感知机(MLP)来增强特征获得最终的特征
Figure BDA00040416385100001510
特征
Figure BDA00040416385100001511
可以由如下公式表示:
Figure BDA00040416385100001512
其中,p1和p2表示分块的大小,MLP表示多层感知机和LN表示层正则化;在学生网络的第一阶段p1=4,p2=8;在学生网络的第二阶段p1=2,p2=4;在学生网络的第三阶段p1=1,p2=2;
步骤3:对学生网络进行参数优化,学生网络的优化目标是使得学生网络的三组特征和教师网络的三组特征更接近的方法具体为:
Figure BDA00040416385100001513
代表第k阶段(h,w)位置的教师网络的特征,
Figure BDA00040416385100001514
代表第k阶段(h,w)位置的学生网络的特征;Lk(h,w)表示第k阶段(h,w)位置教师网络特征和学生网络特征的损失,Lk(h,w)损失主要由余弦函数cos和平方损失mse加权组成,具体的数学公式如下:
Figure BDA00040416385100001515
最终,学生网络完整的损失是由三个阶段损失进行相加,如下公式表示:
Figure BDA0004041638510000161
其中(Hk,Wk)代表第k个阶段特征的分辨率大小,K=3代表学生所有的阶段数目;H1=2H2=4H3,W1=2W2=4W3,H1=64,W1=64;
步骤4:使用教师网络提取待测视网膜OCT图像的特征,得到四组不同尺度的特征,四组不同尺度的特征分别为教师网络的第一阶段特征
Figure BDA0004041638510000162
教师网络的第二阶段特征
Figure BDA0004041638510000163
教师网络的第三阶段特征
Figure BDA0004041638510000164
和教师网络的第四阶段特征
Figure BDA0004041638510000165
步骤5:以步骤4得到的教师网络的第四阶段特征
Figure BDA0004041638510000166
作为步骤3优化后的学生网络的输入,生成三组不同尺度的特征,三组不同尺度的特征分别为学生网络的第三阶段特征
Figure BDA0004041638510000167
学生网络的第二阶段特征
Figure BDA0004041638510000168
和学生网络的第一阶段特征
Figure BDA0004041638510000169
步骤6:计算教师网络和学生网络的特征相似度最终获得待测图像的病变得分,其中病变得分具体如下实现:
Figure BDA00040416385100001610
代表第k阶段(h,w)位置的教师网络的特征,
Figure BDA00040416385100001611
代表第k阶段(h,w)位置的学生网络的特征;
Figure BDA00040416385100001612
代表特征
Figure BDA00040416385100001613
和特征
Figure BDA00040416385100001614
的相似度,
Figure BDA00040416385100001615
代表病变得分;
计算教师网络的第一阶段特征
Figure BDA00040416385100001616
与步骤5得到的学生网络的第一阶段特征
Figure BDA00040416385100001617
的相似度,进而计算第一阶段病变得分Score1,如下表示,
Figure BDA00040416385100001618
计算教师网络的第二阶段特征
Figure BDA00040416385100001619
与步骤5得到的学生网络的第二阶段特征
Figure BDA00040416385100001620
的相似度,进而计算第二阶段病变得分Score2,如下所示,
Figure BDA00040416385100001621
计算教师网络的第三阶段特征
Figure BDA00040416385100001622
与步骤5得到的学生网络的第三阶段特征
Figure BDA00040416385100001623
的相似度,
Figure BDA00040416385100001624
最终学生三个阶段的特征和教师网络三个阶段的病变得分进行求和为最终图像的病变得分Score,具体公式如下表示,
Figure BDA0004041638510000171
综上所述,以上仅为本发明的较佳实施例而已,并非用于限定本发明的保护范围。凡在本发明的精神和原则之内,所作的任何修改、等同替换、改进等,均应包含在本发明的保护范围之内。

Claims (10)

1.一种基于无监督特征蒸馏框架的视网膜OCT图像病变分类方法,其特征在于该方法的步骤包括:
S1,对输入的视网膜OCT图像进行预处理,教师网络提取预处理后的视网膜OCT图像的特征,得到四组不同尺度的特征,四组不同尺度的特征分别为教师网络的第一阶段特征、教师网络的第二阶段特征、教师网络的第三阶段特征和教师网络的第四阶段特征;
S2,学生网络以步骤S1得到的教师网络的第四阶段特征作为输入,生成三组不同尺度的特征,三组不同尺度的特征分别为学生网络的第三阶段特征、学生网络的第二阶段特征和学生网络的第一阶段特征;
S3,对学生网络进行参数优化,学生网络的优化目标是使得学生网络的第一阶段特征更接近教师网络的第一阶段特征、学生网络的第二阶段特征更接近教师网络的第二阶段特征、学生网络的第三阶段特征更接近教师网络的第三阶段特征,最终得到优化后的学生网络;
S4,使用教师网络提取待测视网膜OCT图像的特征,得到四组不同尺度的特征,四组不同尺度的特征分别为教师网络的第一阶段特征、教师网络的第二阶段特征、教师网络的第三阶段特征和教师网络的第四阶段特征;
S5,以步骤S4得到的教师网络的第四阶段特征作为步骤S3优化后的学生网络的输入,生成三组不同尺度的特征,三组不同尺度的特征分别为学生网络的第三阶段特征、学生网络的第二阶段特征和学生网络的第一阶段特征;
S6,计算步骤S4得到的教师网络的第一阶段特征与步骤S5得到的学生网络的第一阶段特征的相似度,进而用于计算第一阶段病变得分Score1,计算教师网络的第二阶段特征与步骤S5得到的学生网络的第二阶段特征的相似度,进而计算第二阶段病变得分Score2,计算教师网络的第三阶段特征与步骤S5得到的学生网络的第三阶段特征的相似度,进而计算第三阶段病变得分Score3,将第一阶病变得分Score1、第二阶段病变得分Score2和第三阶段病变得分Score3相加,得到待测图像最终的病变得分Score。
2.根据权利要求1所述的一种基于无监督特征蒸馏框架的视网膜OCT图像病变分类方法,其特征在于:
所述的步骤S1中,对输入的视网膜OCT图像进行预处理具体为:将输入的视网膜OCT图像压缩至分辨率为(H,W)的大小;H的取值范围为224~448像素,W与H相同。
3.根据权利要求1或2所述的一种基于无监督特征蒸馏框架的视网膜OCT图像病变分类方法,其特征在于:
所述的步骤S1中,教师网络为ResNet分类卷积神经网络、DenseNet分类卷积神经网络或VGGNet等分类卷积神经网络,
教师网络使用在ImageNet数据集上预先训练好的权重作为初始化,并且训练阶段教师网络的参数权重不更新。
4.根据权利要求1或2所述的一种基于无监督特征蒸馏框架的视网膜OCT图像病变分类方法,其特征在于:
所述的步骤S1中,教师网络提取特征的方法为:将预处理后分辨率为(H,W)的视网膜OCT图像输入到教师网络的第一个阶段特征提取器后,特征的维度转变成(H/4,W/4,64)得到教师网络的第一阶段特征;
教师网络的第二阶段特征提取器将第一阶段特征提取压缩成教师网络的第二阶段特征,维度为(H/8,W/8,128);
教师网络的第三阶段特征提取器将第二阶段特征提取压缩为教师网络的第三阶段特征,维度为(H/16,W/16,256);
教师网络的第四阶段特征提取器将第三阶段特征提取压缩为教师网络的第四阶段特征,维度为(H/32,W/32,512)。
5.根据权利要求1所述的一种基于无监督特征蒸馏框架的视网膜OCT图像病变分类方法,其特征在于:
所述步骤S2中,学生网络由卷积神经网络和Transformer模块混合而成,卷积神经网络包含两个连续的卷积核大小为3×3的卷积,其中卷积的步长为1×1,填充padding为1;Transformer模块为多尺度稀疏transformer模块。
6.根据权利要求5所述的一种基于无监督特征蒸馏框架的视网膜OCT图像病变分类方法,其特征在于:
所述的多尺度稀疏transformer模块包括特征聚合模块、多头注意力机制和多层感知机;
特征聚合模块包括局部特征Flocal∈RC×(H·W)和区域特征Fregion∈RC×(H·W)
局部特征Flocal∈RC×(HW)为:
Figure FDA0004041638500000031
其中特征
Figure FDA0004041638500000032
是多尺度稀疏transformer模块的输入特征F形变后的特征,
Figure FDA0004041638500000033
的维度是
Figure FDA0004041638500000034
fi表示
Figure FDA0004041638500000035
中特征的分量,fi维度为
Figure FDA0004041638500000036
Nl=Hi·Wi代表分量的个数,
Figure FDA0004041638500000037
代表位置嵌入特征;
区域特征Fregion∈RC×(H·W)的计算方法如下:
首先,特征
Figure FDA0004041638500000038
被使用大小为(p,p)的块分成互不相交的特征块列
Figure FDA0004041638500000039
其中(Hi,Wi)表示第i个阶段特征F的分辨率,Ci表示第i个阶段通道的维度,
Figure FDA00040416385000000310
表示块的数量和
Figure FDA00040416385000000311
表示第i个块;
然后,互不相交的特征块列Fp被矩阵E映射变换成序列块
Figure FDA00040416385000000312
其中矩阵E的维度为
Figure FDA00040416385000000313
可学习的位置嵌入
Figure FDA00040416385000000314
与特征序列块
Figure FDA00040416385000000315
相加生成区域特征的分量
Figure FDA00040416385000000316
区域特征的分量
Figure FDA00040416385000000317
如下计算:
Figure FDA00040416385000000318
7.根据权利要求6所述的一种基于无监督特征蒸馏框架的视网膜OCT图像病变分类方法,其特征在于:
所述多尺度稀疏transformer模块中的多头注意力机制中的输入(Queriy值、Key值和Value值)如下计算:
首先,所述局部特征Flocal被矩阵
Figure FDA0004041638500000041
线性映射到Queriy值(Queriy值记为Qj),其中矩阵
Figure FDA0004041638500000042
维度为
Figure FDA0004041638500000043
Qj的维度为
Figure FDA0004041638500000044
Figure FDA0004041638500000045
代表单头注意力机制的维度和Ni代表第i个阶段单头注意机制的数目;
同时Fregion被矩阵
Figure FDA0004041638500000046
线性映射到Key值(Key值记作
Figure FDA0004041638500000047
),其中
Figure FDA0004041638500000048
的维度为
Figure FDA0004041638500000049
Figure FDA00040416385000000410
的维度为
Figure FDA00040416385000000411
Fregion也被矩阵
Figure FDA00040416385000000412
线性映射到Value值(Value值记作
Figure FDA00040416385000000413
),其中
Figure FDA00040416385000000414
的维度为
Figure FDA00040416385000000415
Figure FDA00040416385000000416
维度为
Figure FDA00040416385000000417
Query值(Qj),key值
Figure FDA00040416385000000418
和value值
Figure FDA00040416385000000419
可以被如下定义:
Figure FDA00040416385000000420
8.根据权利要求6所述的一种基于无监督特征蒸馏框架的视网膜OCT图像病变分类方法,其特征在于:
所述多头注意力机制被用于计算单头注意力机制的局部信息和区域信息,单头注意力如下所示:
Figure FDA00040416385000000421
Figure FDA00040416385000000422
个单头注意力被合并在一起来获得多头注意力机制(MCA),多头注意力被如下表示:
Figure FDA00040416385000000423
Figure FDA00040416385000000424
两个尺度的多头注意力机制被合并在一起,具体的说,具有分块大小为p1的多头注意力机制
Figure FDA0004041638500000051
和具有分块大小为p2的多头注意力机制
Figure FDA0004041638500000052
被合并在一起来得到特征Z,公式如下所示:
Figure FDA0004041638500000053
最后,所述特征Z通过正则化层LN和多层感知机来增强特征获得最终的特征
Figure FDA0004041638500000054
特征
Figure FDA0004041638500000055
由如下公式表示:
Figure FDA0004041638500000056
其中p1和p2表示分块的大小,MLP表示多层感知机和LN表示层正则化。
9.根据权利要求6所述的一种基于无监督特征蒸馏框架的视网膜OCT图像病变分类方法,其特征在于:
所述步骤S3中,对学生网络进行参数优化,学生网络的优化目标是使得学生网络的三组特征和教师网络的三组特征相似具体为:
Figure FDA0004041638500000057
代表第k阶段(h,w)位置的教师网络的特征,
Figure FDA0004041638500000058
代表第k阶段(h,w)位置的学生网络的特征;Lk(h,w)表示第k阶段(h,w)位置教师网络特征和学生网络特征的损失,Lk(h,w)损失主要由余弦函数cos和平方损失mse加权组成,具体的数学公式如下:
Figure FDA0004041638500000059
其中α代表权重,优选地,取值范围在(0,1)之间;
最终,学生网络完整的损失
Figure FDA00040416385000000510
是由三个阶段损失进行相加,如下公式表示:
Figure FDA00040416385000000511
其中(Hk,Wk)代表第k个阶段特征的分辨率大小,K代表学生所有的阶段数目。
10.根据权利要求6所述的一种基于无监督特征蒸馏框架的视网膜OCT图像病变分类方法,其特征在于:
Figure FDA0004041638500000061
代表第k阶段(h,w)位置的教师网络的特征,
Figure FDA00040416385000000616
代表第k阶段(h,w)位置的学生网络的特征;
Figure FDA0004041638500000062
代表特征
Figure FDA0004041638500000063
和特征
Figure FDA0004041638500000064
的相似度,
Figure FDA0004041638500000065
代表病变得分;
计算教师网络的第一阶段特征
Figure FDA0004041638500000066
与步骤S5得到的学生网络的第一阶段特征
Figure FDA0004041638500000067
的相似度,进而计算第一阶段病变得分Score1,如下表示,
Figure FDA0004041638500000068
计算教师网络的第二阶段特征
Figure FDA0004041638500000069
与步骤S5得到的学生网络的第二阶段特征
Figure FDA00040416385000000610
的相似度,进而计算第二阶段病变得分Score2,如下所示,
Figure FDA00040416385000000611
计算教师网络的第三阶段特征
Figure FDA00040416385000000612
与步骤S5得到的学生网络的第三阶段特征
Figure FDA00040416385000000613
的相似度,
Figure FDA00040416385000000614
最终学生三个阶段的特征和教师网络三个阶段的病变得分进行求和为最终图像的病变得分Score,具体公式如下表示,
Figure FDA00040416385000000615
CN202310020402.7A 2023-01-06 2023-01-06 一种基于无监督异构蒸馏框架的视网膜oct图像病变分类方法 Pending CN116091449A (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202310020402.7A CN116091449A (zh) 2023-01-06 2023-01-06 一种基于无监督异构蒸馏框架的视网膜oct图像病变分类方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202310020402.7A CN116091449A (zh) 2023-01-06 2023-01-06 一种基于无监督异构蒸馏框架的视网膜oct图像病变分类方法

Publications (1)

Publication Number Publication Date
CN116091449A true CN116091449A (zh) 2023-05-09

Family

ID=86203948

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202310020402.7A Pending CN116091449A (zh) 2023-01-06 2023-01-06 一种基于无监督异构蒸馏框架的视网膜oct图像病变分类方法

Country Status (1)

Country Link
CN (1) CN116091449A (zh)

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN116342859A (zh) * 2023-05-30 2023-06-27 安徽医科大学第一附属医院 一种基于影像学特征识别肺部肿瘤区域的方法及系统

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN116342859A (zh) * 2023-05-30 2023-06-27 安徽医科大学第一附属医院 一种基于影像学特征识别肺部肿瘤区域的方法及系统
CN116342859B (zh) * 2023-05-30 2023-08-18 安徽医科大学第一附属医院 一种基于影像学特征识别肺部肿瘤区域的方法及系统

Similar Documents

Publication Publication Date Title
Kwasigroch et al. Deep CNN based decision support system for detection and assessing the stage of diabetic retinopathy
Li et al. Automatic detection of diabetic retinopathy in retinal fundus photographs based on deep learning algorithm
CN109345538A (zh) 一种基于卷积神经网络的视网膜血管分割方法
CN109949235A (zh) 一种基于深度卷积神经网络的胸部x光片去噪方法
CN111104961A (zh) 基于改进的MobileNet网络对乳腺癌进行分类的方法
CN112101424B (zh) 一种视网膜病变识别模型的生成方法、识别装置及设备
CN104636580A (zh) 一种基于人脸的健康监控手机
Wang et al. Learning two-stream CNN for multi-modal age-related macular degeneration categorization
CN110070531A (zh) 用于检测眼底图片的模型训练方法、眼底图片的检测方法及装置
CN113689954A (zh) 高血压风险预测方法、装置、设备及介质
CN113012163A (zh) 一种基于多尺度注意力网络的视网膜血管分割方法、设备及存储介质
Ovreiu et al. Deep learning & digital fundus images: Glaucoma detection using DenseNet
Das et al. CA-Net: A novel cascaded attention-based network for multi-stage glaucoma classification using fundus images
CN114998651A (zh) 基于迁移学习的皮肤病变图像分类识别方法、系统及介质
CN113782184A (zh) 一种基于面部关键点与特征预学习的脑卒中辅助评估系统
CN116091449A (zh) 一种基于无监督异构蒸馏框架的视网膜oct图像病变分类方法
Zeng et al. Automated detection of diabetic retinopathy using a binocular siamese-like convolutional network
CN115409764A (zh) 一种基于域自适应的多模态眼底血管分割方法及装置
Zhuang et al. Classification of diabetic retinopathy via fundus photography: Utilization of deep learning approaches to speed up disease detection
CN109994202A (zh) 一种基于深度学习的人脸生成中药处方的方法
CN115937590A (zh) 一种并联融合CNN和Transformer的皮肤病图像分类方法
CN110507288A (zh) 基于一维卷积神经网络的视觉诱导晕动症检测方法
Nandy Pal et al. Content based retrieval of retinal OCT scans using twin CNN
CN115619814A (zh) 一种视盘和视杯联合分割方法与系统
Mathina Kani et al. Classification of skin lesion images using modified Inception V3 model with transfer learning and augmentation techniques

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination