Disclosure of Invention
The invention addresses the defects of the prior art. Its aim is to provide a deep-learning-based automatic segmentation method for three-dimensional medical images that relieves the instability and overfitting arising during deep convolutional neural network training, improves the segmentation precision of medium and small targets in three-dimensional medical images, and offers good robustness.
First, technical principle
Current segmentation models based on convolutional neural networks mostly consist of an encoder and a decoder. The encoder identifies the target by obtaining high-dimensional semantic information: a cascade of convolutional and pooling layers progressively condenses spatial information, but the pooling operations reduce the size of the feature maps. To obtain the final voxel-level segmentation result, the identified target must be restored to the original size. The decoder therefore has to restore the original size from the high-dimensional semantic information together with the spatial information; in the prior art, however, a large amount of spatial information is lost in the stage of obtaining the semantic information, so the spatial information of medium and small targets cannot be effectively recovered. To better acquire both spatial and semantic information, the invention provides a multi-scale layer cross-connection structure: the multi-level feature maps of the encoder stage are connected directly to the decoder stage by shortcuts, so that the model automatically learns and selects the hierarchical features it needs when recovering spatial information, realizing reuse of the multi-level feature maps.
To solve the class-imbalance problem caused by overly small targets, the invention introduces a region-of-interest adaptive attention mechanism based on an auxiliary loss function. By predicting the bounding rectangle of the large target, the encoder of the model is made to pay more attention to the position of the large target that contains the small target; this narrows the region in which the small target must be segmented, realizes the model's adaptive attention mechanism, and improves the segmentation precision of small targets. Meanwhile, to prevent instability and overfitting during network training, a residual convolution module, a channel adaptive attention (SE) module, and an anti-aliasing pooling operation are introduced.
Secondly, according to the principle, the invention is realized by the following scheme:
a three-dimensional medical image automatic segmentation method based on deep learning comprises the following steps:
(1) acquiring an original training data set from a three-dimensional medical image segmentation public database, extracting boundary rectangular frame information of an interested area by reading annotation data in the original training data set, and forming a sample data set by using case images, segmentation annotations of the case images and the boundary information of the interested area;
(2) randomly dicing the three-dimensional medical image, and expanding a sample data set:
due to the limitation of video memory, the whole three-dimensional medical image cannot be fed directly into the segmentation model; the original sample data set is therefore scaled and randomly cropped into blocks multiple times to form an expanded sample data set;
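Step (2) can be sketched as follows. This is a minimal NumPy illustration of the random cropping only (the scaling step is omitted); the volume and patch sizes are assumed example values, not values fixed by the method.

```python
import numpy as np

def random_patches(volume, patch_size, n_patches, seed=0):
    """Randomly crop n_patches sub-volumes of patch_size from a 3-D volume."""
    rng = np.random.default_rng(seed)
    D, H, W = volume.shape
    pd, ph, pw = patch_size
    patches = []
    for _ in range(n_patches):
        # pick a random corner so the patch lies fully inside the volume
        z = rng.integers(0, D - pd + 1)
        y = rng.integers(0, H - ph + 1)
        x = rng.integers(0, W - pw + 1)
        patches.append(volume[z:z + pd, y:y + ph, x:x + pw])
    return patches

# hypothetical sizes: a 64^3 volume cropped into four 32^3 patches
vol = np.zeros((64, 64, 64), dtype=np.float32)
patches = random_patches(vol, (32, 32, 32), n_patches=4)
```

Repeating this over every case image multiplies the number of training samples while keeping each input small enough for video memory.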
(3) constructing a new feature extraction network, which specifically comprises the following steps:
(3-a) taking a 3D U-Net network as a basic network, wherein the basic network comprises ten convolutional layers and four splicing layers, the output of the first convolutional layer is connected with the output of the eleventh convolutional layer to form a twelfth splicing layer, the output of the second convolutional layer is connected with the output of the ninth convolutional layer to form a tenth splicing layer, the output of the third convolutional layer is connected with the output of the seventh convolutional layer to form an eighth splicing layer, and the output of the fourth convolutional layer is connected with the output of the fifth convolutional layer to form a sixth splicing layer;
(3-b) adding cross-connection among multiple layers on the basis network described in the step (3-a), and constructing a new feature extraction network: connecting the outputs of the first, second, third and fourth convolutional layers, performing dimension reduction by using convolution, and then connecting the outputs of the first, second, third and fourth convolutional layers with the output of the eleventh convolutional layer to form a twelfth splicing layer, connecting the outputs of the second, third and fourth convolutional layers, performing dimension reduction by using convolution, and then connecting the outputs of the second, third and fourth convolutional layers with the output of the ninth convolutional layer to form a tenth splicing layer, and connecting the outputs of the third and fourth convolutional layers with the output of the seventh convolutional layer to form an eighth splicing layer;
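The cross-connection of step (3-b) can be sketched as below. This is an illustrative NumPy model under stated assumptions: nearest-neighbour upsampling aligns the encoder maps to a common resolution, and a random-weight channel mixing stands in for the learned 1 × 1 × 1 dimension-reduction convolution; the channel counts in the usage example are assumptions.

```python
import numpy as np

def upsample_nearest(x, factor):
    # x: (C, D, H, W) feature map; nearest-neighbour upsampling per spatial axis
    for axis in (1, 2, 3):
        x = np.repeat(x, factor, axis=axis)
    return x

def cross_connect(encoder_feats, reduce_to, seed=0):
    """Concatenate encoder feature maps from several scales at the resolution
    of the first (largest) map, then reduce channels; the 1x1x1 convolution is
    modelled as a per-voxel matrix multiply with placeholder random weights."""
    target = encoder_feats[0].shape[1]
    aligned = [upsample_nearest(f, target // f.shape[1]) if f.shape[1] < target else f
               for f in encoder_feats]
    merged = np.concatenate(aligned, axis=0)             # channel-wise concatenation
    w = np.random.default_rng(seed).standard_normal((reduce_to, merged.shape[0]))
    return np.einsum('oc,cdhw->odhw', w, merged)         # 1x1x1 conv = channel mixing

# assumed encoder outputs at scales 16, 8, 4, 2 with 32, 48, 64, 96 channels
feats = [np.ones((c, s, s, s)) for c, s in [(32, 16), (48, 8), (64, 4), (96, 2)]]
fused = cross_connect(feats, reduce_to=32)
```

The decoder then concatenates `fused` with its own upsampled features, letting the network weight each encoder scale as needed.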
(3-c) reconstructing all convolutional layers in the base network described in step (3-a) as follows:
I. replacing the convolution module in the original convolutional layer with a residual convolution module;
II. replacing the pooling module in the original convolutional layer with an anti-aliasing pooling module, where anti-aliasing pooling adds a smoothing convolution operation before the maximum pooling operation;
III. adding a channel adaptive attention (SE) module after the residual convolution module and before the anti-aliasing pooling module;
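Minimal NumPy sketches of modules II and III above, under stated assumptions: the SE weights are placeholders, the pooling is shown in 1-D for brevity (the model operates on 3-D maps), and the smoothing kernel values are an assumed example, not fixed by the method.

```python
import numpy as np

def se_attention(x, w1, w2):
    """Channel adaptive attention (SE): squeeze with global average pooling,
    excite through two small dense layers, then rescale channels.
    x: (C, D, H, W); w1: (C//r, C); w2: (C, C//r) are placeholder weights."""
    s = x.mean(axis=(1, 2, 3))                  # squeeze: one value per channel
    h = np.maximum(w1 @ s, 0.0)                 # ReLU
    g = 1.0 / (1.0 + np.exp(-(w2 @ h)))         # sigmoid gate in (0, 1)
    return x * g[:, None, None, None]           # per-channel rescaling

def antialiased_maxpool1d(x, stride=2, kernel=(0.25, 0.5, 0.25)):
    """Anti-aliasing pooling as described in II: a smoothing convolution
    applied before strided max pooling (1-D illustration)."""
    k = np.asarray(kernel)
    pad = len(k) // 2
    xp = np.pad(np.asarray(x, float), pad, mode='edge')
    smoothed = np.array([xp[i:i + len(k)] @ k for i in range(len(x))])
    return np.array([smoothed[i:i + stride].max()
                     for i in range(0, len(smoothed) - stride + 1, stride)])

y = se_attention(np.ones((4, 2, 2, 2)), np.ones((2, 4)), np.ones((4, 2)))
pooled = antialiased_maxpool1d([0, 1, 0, 1, 0, 1, 0, 1])
```

The smoothing step damps the high-frequency content that strided max pooling would otherwise alias, which is what stabilizes training here.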
(3-d) adding a thirteenth convolutional layer and a fourteenth convolutional layer on the basis network described in the step (3-a), wherein the output of the twelfth convolutional layer is connected with the input of the thirteenth convolutional layer, the output of the thirteenth convolutional layer is connected with the input of the fourteenth convolutional layer, and the output of the fourteenth convolutional layer and the segmentation golden standard labeled by the data set are used for constructing a segmentation loss function by using a Dice coefficient loss function, which is shown as the following formula:
L_seg = 1 − (1/K) Σ_{i=1..K} 2|y_i ∩ p_i| / (|y_i| + |p_i|)

where K is the number of segmentation classes, y_i is the segmentation gold standard of class i in the data set, p_i is the output of the fourteenth convolutional layer, i.e., the class-i segmentation result produced by the network, and the symbol ∩ denotes the intersection of the two regions.
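A minimal NumPy sketch of the Dice coefficient loss of step (3-d), assuming the gold standard and prediction are flattened per-class maps; `eps` is an assumed smoothing constant, not specified by the method.

```python
import numpy as np

def dice_loss(pred, gold, eps=1e-6):
    """Soft multi-class Dice loss. pred, gold: shape (K, N) — K classes,
    N voxels; gold is one-hot, pred holds per-voxel class scores."""
    inter = (pred * gold).sum(axis=1)              # |y_i ∩ p_i| per class
    denom = pred.sum(axis=1) + gold.sum(axis=1)    # |y_i| + |p_i|
    dice = (2.0 * inter + eps) / (denom + eps)     # per-class Dice coefficient
    return float(1.0 - dice.mean())                # average over the K classes
```

A perfect prediction gives a loss near 0; a fully wrong one gives a loss near 1, which makes the loss directly penalize per-class overlap rather than per-voxel accuracy.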
(4) constructing the region-of-interest adaptive attention network, which specifically comprises the following steps:
(4-a) adding a region-of-interest adaptive attention network to the feature extraction network obtained in step (3), the network comprising two new convolutional layers, a first new convolutional layer and a second new convolutional layer; the output of the fifth convolutional layer is connected with the input of the first new convolutional layer, and the output of the first new convolutional layer is connected with the input of the second new convolutional layer;
(4-b) constructing an auxiliary loss function from the output of the second new convolutional layer and the region-of-interest bounding rectangle information labeled in the data set, using a mean square error function, as shown in the following equation:
L_roi = (t − t_p)²
where t is the bounding rectangle information of the region of interest obtained in step (1), and t_p is the output of the second new convolutional layer, i.e., the network's prediction of the bounding rectangle.
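A minimal sketch of this auxiliary loss (NumPy). The 6-component box encoding is an assumption, chosen to match the six kernels of the second new convolutional layer described in the embodiment; the exact encoding is not fixed here.

```python
import numpy as np

def roi_loss(t, t_pred):
    """Auxiliary loss L_roi of step (4-b): mean squared error between the
    labelled bounding-rectangle vector t and the network prediction t_pred.
    A 6-vector (assumed) could encode a 3-D box corner plus its extents."""
    t = np.asarray(t, dtype=float)
    t_pred = np.asarray(t_pred, dtype=float)
    return float(((t - t_pred) ** 2).mean())
```

Driving this loss to zero forces the shared encoder to localize the large target, which is the mechanism behind the adaptive attention.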
(5) Obtaining a segmentation model, specifically comprising the following steps:
(5-a) constructing a new total loss function by using the segmentation loss function obtained in the step (3-d) and the auxiliary loss function obtained in the step (4-b), as shown in the following formula:
L_total = α·L_roi + (1 − α)·L_seg
where α is a hyperparameter that balances the segmentation loss function L_seg of step (3-d) against the auxiliary loss function L_roi of step (4-b); α is a constant greater than 0 and less than 1.
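The total loss of step (5-a) is a direct weighted sum; a one-line sketch, using the preferred value α = 0.18 as the default:

```python
def total_loss(l_roi, l_seg, alpha=0.18):
    """L_total = alpha * L_roi + (1 - alpha) * L_seg, with 0 < alpha < 1.
    alpha = 0.18 is the preferred value stated for step (5-a)."""
    assert 0.0 < alpha < 1.0
    return alpha * l_roi + (1.0 - alpha) * l_seg
```

Because the weights sum to 1, the total loss stays on the same scale as the two component losses regardless of α.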
(5-b) combining the feature extraction network constructed in the step (3) and the region-of-interest adaptive attention network constructed in the step (4) to obtain a final segmentation model;
(6) training a segmentation model:
training the segmentation model constructed in step (5-b) on the expanded sample data set obtained in step (2) using the total loss function obtained in step (5-a), and optimizing the weight parameters of each layer by gradient back-propagation to obtain a trained segmentation model;
(7) three-dimensional medical image segmentation:
segmenting each three-dimensional medical image in the test data set with the trained segmentation model to obtain the final segmentation result of each medical image.
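Step (7) can be sketched as patch-based sliding-window inference, consistent with the dicing of step (2): since whole volumes do not fit in video memory, the trained model is run over overlapping patches and overlapping predictions are averaged. `predict` is a stand-in for the trained model, and the patch and stride values are assumptions.

```python
import numpy as np

def sliding_window_segment(volume, predict, patch=(4, 4, 4), stride=2):
    """Run `predict` over overlapping patches of `volume` and average the
    per-voxel outputs where patches overlap."""
    D, H, W = volume.shape
    pd, ph, pw = patch
    out = np.zeros(volume.shape, dtype=float)
    count = np.zeros(volume.shape, dtype=float)
    for z in range(0, D - pd + 1, stride):
        for y in range(0, H - ph + 1, stride):
            for x in range(0, W - pw + 1, stride):
                sub = volume[z:z + pd, y:y + ph, x:x + pw]
                out[z:z + pd, y:y + ph, x:x + pw] += predict(sub)
                count[z:z + pd, y:y + ph, x:x + pw] += 1
    return out / np.maximum(count, 1)   # average overlapping predictions
```

With an identity `predict` the averaged output reproduces the input exactly, which is a convenient sanity check of the stitching logic.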
In step (5-a), the hyperparameter α is preferably 0.18.
The invention has the following advantages:
first, for the problem that existing image segmentation models lose too much spatial information at the encoder stage, so that the segmentation result cannot be effectively recovered at the decoder stage, the invention provides a layer cross-connection method that allows multi-scale feature maps to be reused at the decoder stage and lets the network select the hierarchical features it needs, improving the segmentation precision and robustness of the model.
Second, for the class-imbalance problem that can arise in training a three-dimensional model, the invention designs an attention mechanism based on an auxiliary loss function, so that the model automatically learns to attend to the position of the large target containing the small target, thereby improving the segmentation accuracy of medium and small targets.
Third, the method introduces a residual convolution module, a channel adaptive attention (SE) module, and an anti-aliasing pooling operation, which effectively stabilize the training process and reduce both the likelihood of overfitting and the difficulty of training the network.
Detailed Description
The following describes specific embodiments of the present invention:
example 1
Fig. 1 is a flowchart of an embodiment of an automatic segmentation method for a three-dimensional medical image based on deep learning, which includes the following specific steps:
step 1, acquiring a three-dimensional medical image.
The method comprises the steps of obtaining an original training data set from a three-dimensional medical image segmentation public database, extracting boundary rectangular frame information of an interested area by reading label data in the original training data set, and forming a sample data set by using case images, segmentation labels of the case images and the boundary information of the interested area.
And 2, randomly cutting the three-dimensional medical image into blocks and expanding a sample data set.
Due to the limitation of video memory, the whole three-dimensional medical image cannot be fed directly into the segmentation model; the original sample data set is therefore scaled and randomly cropped into three-dimensional blocks multiple times to form an expanded sample data set.
And 3, constructing a new feature extraction network.
Fig. 2 shows a segmentation model based on a convolutional neural network constructed in the embodiment of the present invention, which includes the following specific steps:
(3-a) taking a 3D U-Net network as a basic network, wherein the basic network comprises ten convolutional layers and four splicing layers, the output of the first convolutional layer is connected with the output of the eleventh convolutional layer to form a twelfth splicing layer, the output of the second convolutional layer is connected with the output of the ninth convolutional layer to form a tenth splicing layer, the output of the third convolutional layer is connected with the output of the seventh convolutional layer to form an eighth splicing layer, and the output of the fourth convolutional layer is connected with the output of the fifth convolutional layer to form a sixth splicing layer;
(3-b) adding cross-connection among multiple layers on the basis network described in the step (3-a), and constructing a new feature extraction network: connecting the outputs of the first, second, third and fourth convolutional layers, performing dimension reduction by using convolution, and then connecting the outputs of the first, second, third and fourth convolutional layers with the output of the eleventh convolutional layer to form a twelfth splicing layer, connecting the outputs of the second, third and fourth convolutional layers, performing dimension reduction by using convolution, and then connecting the outputs of the second, third and fourth convolutional layers with the output of the ninth convolutional layer to form a tenth splicing layer, and connecting the outputs of the third and fourth convolutional layers with the output of the seventh convolutional layer to form an eighth splicing layer;
(3-c) reconstructing all convolutional layers in the base network described in the step (3-a).
Fig. 3 shows a method for constructing a convolutional layer according to an embodiment of the present invention, which includes the following steps:
I. replacing the convolution module in the original convolutional layer with a residual convolution module;
II. replacing the pooling module in the original convolutional layer with an anti-aliasing pooling module, where anti-aliasing pooling adds a smoothing convolution operation before the maximum pooling operation;
III. adding a channel adaptive attention (SE) module after the residual convolution module and before the anti-aliasing pooling module;
(3-d) adding a thirteenth convolutional layer and a fourteenth convolutional layer on the basis network described in the step (3-a), wherein the output of the twelfth convolutional layer is connected with the input of the thirteenth convolutional layer, the output of the thirteenth convolutional layer is connected with the input of the fourteenth convolutional layer, and the output of the fourteenth convolutional layer and the segmentation golden standard labeled by the data set are used for constructing a segmentation loss function by using a Dice coefficient loss function, which is shown as the following formula:
L_seg = 1 − (1/K) Σ_{i=1..K} 2|y_i ∩ p_i| / (|y_i| + |p_i|)

where K is the number of segmentation classes, y_i is the segmentation gold standard of class i in the data set, p_i is the output of the fourteenth convolutional layer, i.e., the class-i segmentation result produced by the network, and the symbol ∩ denotes the intersection of the two regions.
The structure of each layer of this part of the network is constructed as follows:
in the first convolutional layer, two residually connected groups of 32 convolution kernels of size 3 × 3 × 3 are used to output 32 feature maps, followed by a channel adaptive attention (SE) module and finally an anti-aliasing pooling operation with stride 2;
in the second convolutional layer, two residually connected groups of 48 convolution kernels of size 3 × 3 × 3 are used to output 48 feature maps, followed by a channel adaptive attention (SE) module and finally an anti-aliasing pooling operation with stride 2;
in the third convolutional layer, two residually connected groups of 64 convolution kernels of size 3 × 3 × 3 are used to output 64 feature maps, followed by a channel adaptive attention (SE) module and finally an anti-aliasing pooling operation with stride 2;
in the fourth convolutional layer, two residually connected groups of 96 convolution kernels of size 3 × 3 × 3 are used to output 96 feature maps, followed by a channel adaptive attention (SE) module and finally an anti-aliasing pooling operation with stride 2;
in the fifth convolutional layer, two residually connected groups of 128 convolution kernels of size 3 × 3 × 3 are used to output 128 feature maps, followed by a channel adaptive attention (SE) module and finally a bilinear interpolation operation with a scaling factor of 2;
the sixth splicing layer concatenates the output of the fourth convolutional layer with the output of the fifth convolutional layer, then applies 128 convolution kernels of size 1 × 1 × 1 for feature-map fusion and outputs 128 feature maps;
in the seventh convolutional layer, two residually connected groups of 96 convolution kernels of size 3 × 3 × 3 are used to output 96 feature maps, followed by a channel adaptive attention (SE) module and finally a bilinear interpolation operation with a scaling factor of 2;
the eighth splicing layer concatenates the outputs of the third and fourth convolutional layers with the output of the seventh convolutional layer, then applies 96 convolution kernels of size 1 × 1 × 1 for feature-map fusion and outputs 96 feature maps;
in the ninth convolutional layer, two residually connected groups of 64 convolution kernels of size 3 × 3 × 3 are used to output 64 feature maps, followed by a channel adaptive attention (SE) module and finally a bilinear interpolation operation with a scaling factor of 2;
the tenth splicing layer concatenates the outputs of the second, third and fourth convolutional layers with the output of the ninth convolutional layer, then applies 64 convolution kernels of size 1 × 1 × 1 for feature-map fusion and outputs 64 feature maps;
in the eleventh convolutional layer, two residually connected groups of 48 convolution kernels of size 3 × 3 × 3 are used to output 48 feature maps, followed by a channel adaptive attention (SE) module and finally a bilinear interpolation operation with a scaling factor of 2;
the twelfth splicing layer concatenates the outputs of the first, second, third and fourth convolutional layers with the output of the eleventh convolutional layer, then applies 32 convolution kernels of size 1 × 1 × 1 for feature-map fusion and outputs 32 feature maps;
the thirteenth convolutional layer uses two residually connected groups of 32 convolution kernels of size 3 × 3 × 3 to output 32 feature maps, followed by a channel adaptive attention (SE) module;
the fourteenth convolutional layer uses n convolution kernels of size 1 × 1 × 1 to output the n-class segmentation result, where n is the number of segmentation classes.
And 4, constructing the region-of-interest adaptive attention network.
(4-a) adding a region-of-interest adaptive attention network to the feature extraction network obtained in step (3), the network comprising two new convolutional layers, a first new convolutional layer and a second new convolutional layer; the output of the fifth convolutional layer is connected with the input of the first new convolutional layer, and the output of the first new convolutional layer is connected with the input of the second new convolutional layer;
(4-b) constructing an auxiliary loss function from the output of the second new convolutional layer and the region-of-interest bounding rectangle information labeled in the data set, using a mean square error function, as shown in the following equation:
L_roi = (t − t_p)²
where t is the bounding rectangle information of the region of interest obtained in step (1), and t_p is the output of the second new convolutional layer, i.e., the network's prediction of the bounding rectangle.
The structure of each layer of this region-of-interest adaptive attention network is constructed as follows:
in the first new convolutional layer, two residually connected groups of 32 convolution kernels of size 3 × 3 × 3 are used to output 32 feature maps, followed by a channel adaptive attention (SE) module and finally a global average pooling operation;
the second new convolutional layer uses 6 convolution kernels of size 1 × 1 × 1 to obtain the prediction of the region-of-interest bounding box.
And 5, obtaining a segmentation model.
(5-a) constructing a new total loss function from the loss functions obtained in steps (3-d) and (4-b), as shown in the following formula:
L_total = α·L_roi + (1 − α)·L_seg
where α is a hyperparameter that balances the segmentation loss function L_seg of step (3-d) against the auxiliary loss function L_roi of step (4-b). In this embodiment, α is preferably 0.18.
And (5-b) combining the feature extraction network constructed in the step (3) and the region-of-interest adaptive attention network constructed in the step (4) to obtain a final segmentation model.
And 6, training a segmentation model.
The segmentation model constructed in step (5-b) is trained on the expanded sample data set obtained in step (2) using the total loss function obtained in step (5-a), and the weight parameters of each layer are optimized by gradient back-propagation to obtain the trained segmentation model.
And 7, segmenting the three-dimensional medical image.
Each three-dimensional medical image in the test data set is segmented with the trained segmentation model to obtain the final segmentation result of each medical image.
Example 2
Liver tumor segmentation experiments were performed on the public data set LiTS (Liver Tumor Segmentation Challenge) using the method of Example 1. Three classes are segmented: background, liver, and tumor. Computer environment of this experiment: the operating system is Linux Ubuntu 16.06; two NVIDIA 1080Ti 11 GB GPUs; the software platform is Python with PyTorch.
Fig. 4 compares the liver and tumor segmentation results of the embodiment of the present invention with those of other methods. Each medical image in the test data set is segmented with the trained segmentation model; sample results are shown in Fig. 4, where Figs. 4(a)-(e) are the results of 3D U-Net, FCN+SegSE, CE-Net, Attention U-Net, and the method of the invention, respectively. The small-tumor segmentation in Figs. 4(a) and 4(d) is poor, with missed detections; Fig. 4(c) shows mis-segmentation; and the liver segmentation edge in Fig. 4(b) is rough.
TABLE 1
The average segmentation precision on the test sample set is compared across the above methods using the Dice coefficient, Sensitivity, and Average Symmetric Surface Distance (ASD); the results are shown in Table 1.
Example 3
Brain tumor segmentation experiments were performed on the public data set BraTS (Brain Tumor Segmentation Challenge) using the method of Example 1. Five classes are segmented: background, necrotic tissue (Necrosis), edema (Edema), non-enhancing tumor (Non-enhancing Tumor), and enhancing tumor (Enhancing Tumor). Computer environment of this experiment: the operating system is Linux Ubuntu 16.06; two NVIDIA 1080Ti 11 GB GPUs; the software platform is Python with PyTorch.
Fig. 5 compares the brain tumor segmentation results of the embodiment of the present invention with those of other methods. Each medical image in the test data set is segmented with the trained segmentation model; sample results are shown in Fig. 5, where Figs. 5(a)-(e) are the results of 3D U-Net, FCN+SegSE, CE-Net, Attention U-Net, and the method of the present invention, respectively. Figs. 5(a)-(c) show obvious over-segmentation, with non-tumor regions mistakenly segmented as tumor, and Fig. 5(d) misses part of the enhancing tumor.
The average segmentation precision on the test sample set is compared across the above methods using the Dice coefficient, Sensitivity, and Hausdorff distance (HD); the results are shown in Table 2, where the whole tumor (Whole Tumor) comprises necrotic tissue, edema, non-enhancing tumor, and enhancing tumor, and the tumor core (Tumor Core) comprises necrotic tissue, non-enhancing tumor, and enhancing tumor.
TABLE 2