Disclosure of Invention
Embodiments of the invention provide a medical image segmentation model training method, a segmentation method, and a device, which address the low segmentation precision and stability caused by small sample sizes in medical image segmentation.
The technical scheme of the invention is as follows:
in one aspect, the invention provides a medical image segmentation model training method, which comprises the following steps:
acquiring a plurality of batches of source domain training data, wherein each batch of source domain training data comprises a support set and a query set, each support set and each query set comprises a plurality of medical images, and the block marking a designated human organ in each medical image serves as the label; wherein each batch of source domain training data contains labels for only one human organ;
obtaining a preset reference model, wherein in an inner loop the reference model performs k gradient-descent steps using the support set of a single batch of source domain training data, obtaining a second model parameter set from a first model parameter set, and then obtains a third model parameter set by one gradient-descent step on the query set; in an outer loop, taking the first model parameter set of the original reference model as the starting point and the third model parameter set as the update direction, a fourth model parameter set is obtained by updating according to a set step size, learning rate, and number of cycles; the reference model is trained continuously on each batch of source domain training data following the inner-loop and outer-loop steps to obtain a source domain segmentation model; wherein k is a natural number;
acquiring target domain training data, wherein the target domain training data comprises a plurality of medical images, with the blocks marking the target human organ serving as labels;
and training the source domain segmentation model on the target domain training data using a layer-freezing migration method to obtain a target domain segmentation model.
In some embodiments, the reference model is a U-Net network.
In some embodiments, the learning rate of the U-Net network in the outer loop is set to 1e-3, the step size is set to 0.4, the number of cycles is 300, and k in the inner loop is 3.
In some embodiments, the U-Net network employs a cross entropy function as the loss function.
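The cross-entropy loss mentioned above can be sketched in a few lines of plain Python. This is an illustrative pixelwise binary cross-entropy only; the actual embodiment relies on the loss implementation of the deep learning framework, and the clipping constant `eps` is an assumption added to avoid log(0).

```python
import math

def binary_cross_entropy(preds, labels, eps=1e-7):
    """Mean pixelwise binary cross-entropy over flattened prediction/label maps.

    preds: predicted foreground probabilities in (0, 1)
    labels: ground-truth mask values, 0 (background) or 1 (organ block)
    """
    total = 0.0
    for p, y in zip(preds, labels):
        p = min(max(p, eps), 1.0 - eps)  # clip to keep log() finite
        total += -(y * math.log(p) + (1 - y) * math.log(1 - p))
    return total / len(preds)

# A confident, correct prediction yields a small loss; a wrong one a large loss.
low = binary_cross_entropy([0.9, 0.1], [1, 0])
high = binary_cross_entropy([0.1, 0.9], [1, 0])
```

During training this loss is averaged over all pixels of a batch and minimized by gradient descent.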
In some embodiments, the U-Net network comprises three parts: an encoder, a decoder, and skip connections. The encoder comprises four downsampling modules, each comprising two 3 × 3 convolutional layers, each convolutional layer followed by a batch normalization layer and a rectified linear unit, with a 2 × 2 max-pooling layer of stride 2 at the end of the downsampling module. The decoder comprises four upsampling modules and an activation function layer; each upsampling module comprises one 2 × 2 transposed convolution layer and two 3 × 3 convolutional layers, each convolutional layer followed by a batch normalization layer and a rectified linear unit. The skip connection concatenates the feature map before the max-pooling layer of the downsampling module at the same depth with the feature map output by the transposed convolution layer in the upsampling module.
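The 2 × 2 max-pooling layer with stride 2 that ends each downsampling module halves both spatial dimensions. A minimal sketch on a nested-list feature map (illustrative only; the real model uses a framework pooling layer, and even input sizes such as 256 × 256 are assumed):

```python
def max_pool_2x2(fmap):
    """2x2 max pooling with stride 2 on a 2D feature map (list of lists).

    Height and width are assumed even, as with 256x256 inputs.
    """
    h, w = len(fmap), len(fmap[0])
    return [
        [max(fmap[i][j], fmap[i][j + 1], fmap[i + 1][j], fmap[i + 1][j + 1])
         for j in range(0, w, 2)]
        for i in range(0, h, 2)
    ]

fmap = [
    [1, 3, 2, 0],
    [4, 2, 1, 1],
    [0, 5, 6, 2],
    [1, 2, 3, 7],
]
pooled = max_pool_2x2(fmap)  # 4x4 feature map -> 2x2
```

Four such modules reduce a 256 × 256 input to a 16 × 16 bottleneck before the decoder upsamples it back.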
In some embodiments, training the source domain segmentation model on the target domain training data using the layer-freezing migration method to obtain the target domain segmentation model comprises the following steps:
freezing the first two downsampling modules in the encoder and fine-tuning on the target domain training data at a learning rate of 1e-3;
thawing the second downsampling module in the encoder and fine-tuning again on the target domain training data at an initial learning rate of 1e-3 with a decay rate of 0.0077;
thawing the first two downsampling modules in the encoder and fine-tuning on the target domain training data at a learning rate of 1e-3.
In some embodiments, data augmentation is performed on the medical images of the support set in each batch of source domain training data, including random rotation by 0 to 180 degrees, image translation, shear transformation, and/or image stretching.
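Two of the listed augmentations, flipping and translation, can be sketched as below. The helper names and the shift range are hypothetical; in practice rotation, shear, and stretching would typically come from a library such as Keras's image augmentation utilities, and any geometric transform must be applied identically to the label mask.

```python
import random

def flip_horizontal(img):
    """Mirror the image left-right; the organ label mask must be flipped identically."""
    return [row[::-1] for row in img]

def translate_right(img, shift, fill=0):
    """Shift the image right by `shift` pixels, padding the left edge with `fill`."""
    return [[fill] * shift + row[:len(row) - shift] for row in img]

def augment(img, rng=random):
    """Randomly apply a flip and a small translation to one training image."""
    if rng.random() < 0.5:
        img = flip_horizontal(img)
    return translate_right(img, rng.randint(0, 2))

img = [[1, 2, 3],
       [4, 5, 6]]
flipped = flip_horizontal(img)
shifted = translate_right(img, 1)
```

Calling `augment` on each support-set image before it enters the reference model yields a different variant every epoch, increasing sample diversity without new data.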
In another aspect, the present invention further provides a medical image segmentation method, including:
acquiring a medical image to be segmented, and cropping the image to a preset size;
inputting the cut medical image to be segmented into the target domain segmentation model obtained by the medical image segmentation model training method so as to output a segmentation result.
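The cropping step before inference might be sketched as follows. A center crop is assumed here; the method only specifies cropping to a preset size, so the crop policy and the toy sizes are illustrative assumptions.

```python
def center_crop(img, size):
    """Center-crop a 2D image (list of rows) to size x size.

    The image is assumed to be at least size x size in both dimensions.
    """
    h, w = len(img), len(img[0])
    top, left = (h - size) // 2, (w - size) // 2
    return [row[left:left + size] for row in img[top:top + size]]

# A toy 6x6 "scan" cropped to the 4x4 region a small model might expect;
# the real pipeline would crop to the model's input resolution (e.g. 256x256).
img = [[r * 10 + c for c in range(6)] for r in range(6)]
patch = center_crop(img, 4)
```

The cropped patch is then fed to the trained target domain segmentation model, whose output mask applies to the same cropped region.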
In another aspect, the present invention also provides an electronic device, including a memory, a processor, and a computer program stored on the memory and executable on the processor, the processor implementing the steps of the method as described above when executing the program.
In another aspect, the present invention also provides a computer readable storage medium having stored thereon a computer program, characterized in that the program when executed by a processor implements the steps of the above method.
The invention has the advantages that:
According to the medical image segmentation model training method, segmentation method, and device, based on a meta-learning approach, model parameters are continuously adjusted in the source domain using the support set and query set in the inner loop; in the outer loop, the original base model is adjusted taking the query-set model parameters as the update direction; and through continuous training over multiple batches of source domain training data, model parameters sensitive to target-domain task changes are obtained. When the source domain segmentation model is migrated to the target domain, it can better adapt to the new target-domain task, improving generalization.
Additional advantages, objects, and features of the invention will be set forth in part in the description which follows and in part will become apparent to those having ordinary skill in the art upon examination of the following or may be learned from practice of the invention. The objectives and other advantages of the invention will be realized and attained by the structure particularly pointed out in the written description and claims thereof as well as the appended drawings.
It will be appreciated by those skilled in the art that the objects and advantages that can be achieved with the present invention are not limited to the above-described specific ones, and that the above and other objects that can be achieved with the present invention will be more clearly understood from the following detailed description.
Detailed Description
The present invention will be described in further detail with reference to the following embodiments and the accompanying drawings, in order to make the objects, technical solutions and advantages of the present invention more apparent. The exemplary embodiments of the present invention and the descriptions thereof are used herein to explain the present invention, but are not intended to limit the invention.
It should be noted here that, in order to avoid obscuring the present invention due to unnecessary details, only structures and/or processing steps closely related to the solution according to the present invention are shown in the drawings, while other details not greatly related to the present invention are omitted.
It should be emphasized that the term "comprises/comprising" when used herein is taken to specify the presence of stated features, elements, steps or components, but does not preclude the presence or addition of one or more other features, elements, steps or components.
It is also noted herein that the term "coupled" may refer to not only a direct connection, but also an indirect connection in which an intermediate is present, unless otherwise specified.
In the medical image analysis stage of computer-aided diagnosis and treatment, semantic segmentation is often a fundamental task. Owing to the special nature of medical images, the segmentation task carries particular challenges and complexity; in particular, medical image segmentation demands high precision and stability. However, the number of available medical image samples is relatively small, while traditional deep learning requires a huge data volume to achieve stability and high generalization. One possible solution learns a model initialization from other similar tasks (the source domain) and then fine-tunes on a limited training set of the target task (the target domain), i.e., conventional transfer learning. But when the model migrates from the source domain to the target domain, a domain shift problem arises: the difference between source and target domains inevitably degrades the adaptability of transfer learning, yielding sub-optimal solutions. Domain adaptation and domain generalization address this problem better. Domain adaptation algorithms focus on using unlabeled or small amounts of labeled data in the target domain to quickly fit models initialized on different source domains; domain generalization algorithms train a model on multiple source domains so that it can migrate directly to the target domain. In the present method, the segmentation task is realized on target and source domains with similar data distributions; to optimize segmentation precision, model parameters learned from multiple batches of source domain data are fine-tuned on the target domain, and the domain shift problem is mitigated using a domain adaptation approach.
In the method, model parameters are continuously adjusted using the support set and query set during the inner loop; in the outer loop, the original base model is adjusted taking the query-set model parameters as the update direction; a source domain segmentation model is obtained through continuous training on multiple batches of source domain training data; and the source domain segmentation model is migrated to the target domain for fine-tuning, finally yielding a target domain segmentation model adapted to the target domain segmentation task.
Specifically, the present application provides a medical image segmentation model training method, referring to fig. 1 and 2, including steps S101 to S104:
step S101: acquiring a plurality of batches of source domain training data, wherein each batch of source domain training data comprises a support set and a query set, each support set and each query set comprises a plurality of medical images, and the block marking a designated human organ in each medical image serves as the label; wherein each batch of source domain training data contains labels for only one human organ.
Step S102: obtaining a preset reference model, wherein in an inner loop the reference model performs k gradient-descent steps using the support set of a single batch of source domain training data, obtaining a second model parameter set from a first model parameter set, and then obtains a third model parameter set by one gradient-descent step on the query set; in an outer loop, taking the first model parameter set of the original reference model as the starting point and the third model parameter set as the update direction, a fourth model parameter set is obtained by updating according to a set step size, learning rate, and number of cycles; the reference model is trained continuously on each batch of source domain training data following the inner-loop and outer-loop steps to obtain a source domain segmentation model, where k is a natural number.
Step S103: acquiring target domain training data, wherein the target domain training data comprises a plurality of medical images, with the block marking the target human organ serving as the label.
Step S104: training the source domain segmentation model on the target domain training data using a layer-freezing migration method to obtain a target domain segmentation model.
In step S101, the source domain training data may be constructed from the public dataset of the Medical Segmentation Decathlon. The medical images may be unified to magnetic resonance or CT scan images. The source domain training data comprises a plurality of batches, each of which labels the same organ. For the image segmentation task, the label content comprises the block where the specified organ is located in the image; the labels are marked by professional physicians, and their accuracy meets clinical standards. Further, the images used for training may be adjusted to a uniform resolution that meets the input requirements of the reference model, for example 256 × 256.
Illustratively, 6 batches of source domain training data were created using heart images published by Kings College London, liver images published by IRCAD, prostate images published by Nijmegen Medical Centre, and pancreas, spleen, and cecum images provided by Memorial Sloan Kettering Cancer Center, with each batch further divided into a support set and a query set. The number of medical images in the support set and the query set can be configured for the actual application scenario and data volume; where appropriate, it can be set according to a fixed ratio.
In step S102, the reference model is a network model for image segmentation; FCN, DeepMask, U-Net, and the like may be selected. In the initial state, the parameters of the reference model may be randomly generated. In this application, the preferred reference model is a U-Net network. The most critical part of the U-Net is that each downsampling stage is concatenated with the corresponding upsampling stage, so that feature fusion at different scales helps the upsampling path recover pixels. Specifically, the higher (shallow) layers have a small downsampling factor and finer feature maps, while the lower (deep) layers have a large downsampling factor, highly concentrated information, and larger spatial loss, which nevertheless aids classification of the target region; when high-layer and low-layer features are fused, the segmentation effect is very good.
In training the source domain segmentation model, step S102 of the present application adopts a meta-learning approach so that the source domain segmentation model gains strong generalization across segmentation tasks. Specifically, during source domain training, as shown in fig. 2, the reference model is trained continuously, batch by batch, on the constructed n batches of source domain training data. Each batch of source domain training data is trained in two parts, an inner loop and an outer loop. After the preset reference model is obtained, in the inner loop the reference model performs k gradient-descent steps on the support set of a single batch of source domain training data, going from model parameters θ'_0 to θ'_k, and then obtains model parameters θ'' by one gradient-descent step on the query set. In the outer loop, the model parameters θ'_0 of the original reference model are taken as the starting point and updated according to the set step size, learning rate, and number of cycles to obtain the model parameters θ. The reference model is trained continuously on each batch of source domain training data following the inner- and outer-loop steps to obtain the source domain segmentation model.
In fig. 2, the original U-Net network model is first trained with the 1st batch of source domain training data: in the inner loop, k gradient-descent steps on the support set take the model parameters from θ'_0 to θ'_k, and one gradient-descent step on the query set then yields θ''. The outer loop starts from θ'_0, takes θ'_0 → θ'' as the update direction with step size β, and produces the outer-loop result θ. The original U-Net network model is then trained with the 2nd batch of source domain training data: k gradient-descent steps on the support set yield θ''_k, and one gradient-descent step on the query set yields θ'''. The outer loop starts from θ, takes θ → θ''' as the update direction with step size β, and produces the result θ_1. By analogy, continuous training on all n batches of data yields the final source domain segmentation model.
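The inner-loop/outer-loop update described above can be sketched on a toy one-parameter model. The quadratic task losses, hyperparameter values, and two-batch setup below are stand-ins for illustration; the method applies the same scheme to the full U-Net parameter set.

```python
def grad_step(theta, grad_fn, lr):
    """One gradient-descent step on a scalar parameter."""
    return theta - lr * grad_fn(theta)

def train_one_batch(theta0, support_grad, query_grad, k=3, alpha=0.1, beta=0.4):
    """Inner loop: k steps on the support set, then one step on the query set.
    Outer loop: move theta0 toward the inner-loop result with step size beta.
    """
    theta = theta0
    for _ in range(k):                           # k-step gradient descent on support set
        theta = grad_step(theta, support_grad, alpha)
    theta = grad_step(theta, query_grad, alpha)  # one query-set step -> theta''
    return theta0 + beta * (theta - theta0)      # outer update along theta0 -> theta''

# Toy batches: each "task" pulls theta toward its own optimum near 1.0.
batches = [
    (lambda t: 2 * (t - 1.0), lambda t: 2 * (t - 1.2)),  # (support grad, query grad)
    (lambda t: 2 * (t - 0.8), lambda t: 2 * (t - 1.0)),
]
theta = 5.0  # random initialization
for support_grad, query_grad in batches:  # continuous batch-by-batch training
    theta = train_one_batch(theta, support_grad, query_grad)
```

After each batch the outer loop moves only a fraction β of the way toward the inner-loop result, which is what keeps the final parameters a compromise across all batches rather than an optimum for any single one.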
Through this adaptive learning mode, the final source domain segmentation model is not an optimal solution for any single batch of source domain training data but approaches a global optimum across all batches; it adapts well to various target tasks and avoids domain shift during migration as well as overfitting and poor segmentation.
The method differs clearly from the existing meta-learning methods MAML and Reptile. In step S102, training on the support set and the query set in the inner loop is continuous; compared with MAML, which trains on the support set in the inner loop and adjusts the original reference model using the loss computed on the query set, the present method more effectively transfers the model parameter characteristics learned from both the support set and the query set to the original base model. Compared with Reptile, which does not distinguish the support set from the query set, the present method trains more stably and adapts better to various tasks.
In some embodiments, in step S102, data augmentation is performed on the medical images of the support set in each batch of source domain training data, including random rotation by 0 to 180 degrees, image translation, shear transformation, and/or image stretching.
Training a neural network generally requires a large amount of data to obtain satisfactory results. With limited data, data augmentation increases the diversity of training samples, improving model robustness and avoiding overfitting. Randomly varying the characteristics of training samples also reduces the model's dependence on particular attributes, improving generalization. Medical images in each batch of source domain training data are randomly transformed before being input to the reference model for training; the transformations can also enlarge the effective training set.
In step S103, the target domain training data is training data constructed based on the target organ to be identified, including a plurality of medical images and a block in which the target organ is labeled. In the actual application process, a certain number of medical images can be set in the target domain to serve as a test set.
Specifically, as shown in FIG. 3, in some embodiments the U-Net network comprises three parts: an encoder, a decoder, and skip connections. The encoder comprises four downsampling modules; each downsampling module comprises two 3 × 3 convolutional layers, each followed by a batch normalization layer and a rectified linear unit, with a 2 × 2 max-pooling layer of stride 2 at the end of the module. The decoder comprises four upsampling modules and an activation function layer; each upsampling module comprises one 2 × 2 transposed convolution layer and two 3 × 3 convolutional layers, each convolutional layer followed by a batch normalization layer and a rectified linear unit. The skip connection concatenates the feature map before the max-pooling layer of the downsampling module at the same depth with the feature map output by the transposed convolution layer in the upsampling module. In some embodiments, the U-Net network employs a cross-entropy loss function; the learning rate α of the U-Net network is set to 1e-3, the step size β is set to 0.4, the number of cycles is 300, and k in the inner loop is 3.
Specifically, in the implementation, the basic semantic segmentation model U-Net is built with the deep learning framework Keras; the experimental system environment is Ubuntu, using an NVIDIA GeForce 1080 Ti graphics card with 11 GB of video memory.
In step S104, the source domain segmentation model is trained on the target domain training data using the layer-freezing migration method. When a neural network learns hidden representations across multiple layers, the difference between the source and target domain data distributions is amplified: source domain data act as positive activations, while target domain data cause negative activations. A large difference can corrupt the knowledge and experience learned from the source domain, i.e., cause forgetting. To address this, it is important to keep the difference relatively small throughout the network, which step S104 of the present application achieves by gradually thawing the shallow layers. The layer-freezing method is applied to the shallow layers because medical image segmentation focuses more on low-dimensional features than scene segmentation does, and these low-dimensional, general features are typically extracted in the shallow layers.
Specifically, in some embodiments, step S104, namely training the source domain segmentation model on the target domain training data using the layer-freezing migration method to obtain the target domain segmentation model, comprises steps S1041 to S1043:
step S1041: the first two downsampling modules in the freeze encoder are denoted by 1e -3 Is fine-tuned on the target domain training data.
Step S1042: thawing encoderThe second downsampling block of (1 a) at 1e -3 Is again fine-tuned on the target domain training data at an initial learning rate of 0.0077.
Step S1043: thawing the first two downsampling blocks in the encoder, 1e -3 Is fine-tuned on the target domain training data.
In this embodiment, based on the constructed U-Net model, the shallow layers in the encoder are frozen and then thawed layer by layer for fine-tuning. In the U-Net structure, the encoder captures semantic information during downsampling, and the decoder provides precise localization during upsampling. Since medical image segmentation focuses on low-dimensional features, which are generally extracted in the shallow layers, the first two downsampling modules of the encoder are frozen to reduce negative activation, and thawing proceeds progressively from deep to shallow. Specifically, the target domain training data may be divided into batches and fine-tuned for a set number of epochs, for example with a batch size of 8 for 300 epochs.
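The staged freeze/thaw schedule can be sketched with trainable flags on named parameter groups. The toy gradients and scalar weights below are placeholders; in the actual embodiment the flags would correspond to setting `trainable` on the U-Net encoder modules in Keras.

```python
def sgd_step(params, trainable, grads, lr):
    """Update only the parameter groups whose trainable flag is True."""
    return {name: (value - lr * grads[name] if trainable[name] else value)
            for name, value in params.items()}

# Parameter groups of the encoder path (toy scalar weights and gradients).
params = {"down1": 1.0, "down2": 1.0, "down3": 1.0, "down4": 1.0}
grads = {name: 0.5 for name in params}

# Stage 1: freeze the first two downsampling modules, fine-tune the rest at 1e-3.
trainable = {"down1": False, "down2": False, "down3": True, "down4": True}
params = sgd_step(params, trainable, grads, lr=1e-3)

# Stage 2: thaw the second downsampling module and fine-tune again.
trainable["down2"] = True
params = sgd_step(params, trainable, grads, lr=1e-3)

# Stage 3: thaw the remaining frozen module for a final fine-tuning pass.
trainable["down1"] = True
params = sgd_step(params, trainable, grads, lr=1e-3)
```

Each shallower module therefore accumulates fewer updates than the deeper ones, which is what keeps the general low-dimensional features learned on the source domain largely intact.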
On the other hand, the invention also provides a medical image segmentation method, which comprises the steps of S201 to S202:
step S201: and acquiring the medical image to be segmented, and cutting the image according to a preset size.
Step S202: inputting the cropped medical image to be segmented into the target domain segmentation model obtained by the medical image segmentation model training method of steps S101 to S104, so as to output a segmentation result.
In the present embodiment, steps S103 to S104 obtain, by fine-tuning in the target domain for the specified segmentation task, a target domain segmentation model for segmenting the target organ. In step S201, after the medical image to be segmented is acquired, it is cropped according to the input size requirement of the target domain segmentation model. In step S202, the target domain segmentation model obtained in step S104 performs inference and outputs the segmentation result.
In another aspect, the present invention also provides an electronic device, including a memory, a processor, and a computer program stored on the memory and executable on the processor, the processor implementing the steps of the method as described above when executing the program.
In another aspect, the present invention also provides a computer readable storage medium having stored thereon a computer program, characterized in that the program when executed by a processor implements the steps of the above method.
The invention is illustrated below with reference to specific examples:
The public dataset of the Medical Segmentation Decathlon was used as the dataset for the experiments herein. The dataset comprises magnetic resonance or CT scan images of ten different human organs. All images are labeled by professional physicians, and the accuracy meets clinical standards. The image resolution was adjusted to 256 × 256, and the multi-value labeling was simplified to a binary segmentation task. Images of six organs were selected for the verification experiments: heart images published by Kings College London, liver images published by IRCAD, prostate images published by Nijmegen Medical Centre, and pancreas, spleen, and cecum images provided by Memorial Sloan Kettering Cancer Center.
Cecum and liver were selected as target domain data. To construct a few-sample scenario based on these two tasks, the six organ datasets above were divided into two groups. The first group's target domain training set comprised 214 images randomly sampled from the cecum data, with the target domain test set consisting of the remaining 1070 cecum images. The first group's source domain training set comprised three batches of data, prostate, pancreas, and spleen, totaling 2611 images. The second group's target domain training set comprised 191 images randomly sampled from the liver data, with the target domain test set consisting of the remaining 18791 liver images. The second group's source domain training set comprised three batches of data, prostate, heart, and pancreas, totaling 2877 images.
In this embodiment, a medical image segmentation model training method is provided, as follows:
first, a basic semantic segmentation model U-Net is constructed with the deep learning framework Keras; the network structure is shown in FIG. 3. The U-Net consists of an encoder, a decoder, and skip connections. The encoder comprises four downsampling modules, each comprising two 3 × 3 convolutional layers (Conv2D + BN + ReLU), each followed by a batch normalization layer (Batch Normalization, BN) and a rectified linear unit (Rectified Linear Unit, ReLU), with a 2 × 2 max-pooling layer of stride 2 at the end of each module. The decoder architecture mirrors the encoder, except that the max-pooling layer is replaced with a 2 × 2 transposed convolution layer. The skip connection concatenates the feature map before the max-pooling layer at the same depth with the feature map output by the transposed convolution layer in the upsampling module. Because an upsampling module at a given depth uses the feature map produced by the downsampling module at the corresponding depth, the encoder shallow-freezing method also protects the positive activation of the related decoder layers.
Model parameters of the U-Net are randomly initialized, and the model is trained on the source domain training set using the inner and outer loops described in step S102. The inner loop operates on a support set and a query set: first, k successive gradient-descent steps on the support set take the model parameters from θ'_0 to θ'_k, and one gradient-descent step on the query set then yields model parameters θ''. The outer loop starts from θ'_0, takes θ'_0 → θ'' as the update direction, and uses step size β to obtain the outer-loop result θ. The U-Net model uses a cross-entropy loss function; the source domain meta-learning batch size is 6, and each batch of data is trained continuously through the inner- and outer-loop steps for 300 epochs. Further, the learning rate α of the U-Net model is set to 1e-3, the step size β is set to 0.4, the number of learning cycles in the outer loop is also 300, and k in the inner loop is 3. Finally, training on the source domain training set yields the source domain segmentation model.
Let a function f_θ denote the U-Net. When f_θ is trained on a batch τ, the model parameters θ are updated to θ'_i by i successive gradient-descent steps, and the updated model is written f_{θ'_i}. The gradient update at the i-th iteration can be expressed as follows:

θ'_i = θ'_{i-1} − α ∇_{θ'_{i-1}} L_τ(f_{θ'_{i-1}})

where α is a fixed hyperparameter and L_τ(f_{θ'_{i-1}}) is the loss of the model f_{θ'_{i-1}} on batch τ; θ'_i is obtained by continuously optimizing the model f_{θ'_{i-1}}.
Further, the meta-optimization can be described by the following expression. The whole meta-learning optimizes over the starting parameters θ'_0, which are updated using the parameters θ'' obtained on the batch to give the outer-loop result θ:

θ = θ'_0 + β (θ'' − θ'_0)

where β is the meta-learning step size, representing the update rate toward the final parameters.
Further, the source domain segmentation model is transferred to the target domain for fine-tuning according to the layer-freezing migration method. Following the structure of the U-Net network, the first two downsampling modules are frozen first and the learning rate is set to 1e-3. The second downsampling module in the deep layer is then thawed, and the model is fine-tuned again on the target domain with a learning rate of 1e-3 and a decay rate of 0.0077. Finally, all layers are thawed and fine-tuned a last time on the target domain with a learning rate of 1e-4. Fine-tuning on the target domain uses a batch size of 8 and 300 training epochs, yielding the target domain segmentation model. Finally, the target domain segmentation model is evaluated on the target domain test set.
Further, this embodiment adopts the Dice score as the objective evaluation index; the Dice score measures the degree of overlap between two images and is widely used to evaluate medical image segmentation. For the binary medical image segmentation task, the weight of the region of interest is set to 0.9 and that of the background region to 0.1. Test verification is performed on three modes: mode one directly trains a model on the target domain from randomly initialized model parameters; mode two pre-trains model parameters on the source domain and fine-tunes them on the target domain; mode three applies a meta-learning method in the source-domain training stage, covering three meta-learning structures: Reptile, MAML, and the network structure of the present application (ours). Comparative analysis is performed using the Dice score, and the results are shown in Table 1.
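The Dice score used as the evaluation index can be computed as follows. This is a minimal sketch of the standard definition; the binary masks below are illustrative, not taken from the patent's experiments.

```python
# Minimal sketch of the Dice score: 2*|A ∩ B| / (|A| + |B|) for binary masks.

def dice_score(pred, truth):
    """Dice overlap of two binary masks given as flat 0/1 sequences."""
    intersection = sum(p * t for p, t in zip(pred, truth))
    total = sum(pred) + sum(truth)
    return 2.0 * intersection / total if total else 1.0  # both empty: perfect overlap

# Illustrative 6-pixel masks (hypothetical values).
pred  = [1, 1, 0, 0, 1, 0]
truth = [1, 0, 0, 0, 1, 1]
print(round(dice_score(pred, truth), 3))  # 2*2 / (3+3) = 0.667
```

A Dice score of 1.0 means the predicted and ground-truth regions coincide exactly; 0.0 means no overlap, which is why the metric suits organ-segmentation evaluation where region overlap matters more than per-pixel accuracy.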
TABLE 1
In Table 1, in the small-sample cecum scenario, mode one was trained on the target domain using standard supervised learning, yielding a Dice score of 0.537. Mode-two single-source pre-training was performed on three different source domains (prostate, pancreas, and spleen), and the resulting model parameters were used to initialize target-domain training, yielding Dice scores of 0.591, 0.590, and 0.591, respectively. A mode-two experiment on multi-source domain I, formed by combining the three tasks, achieved a Dice score of 0.611. In mode three, Reptile, MAML, and the algorithm presented herein achieved Dice scores of 0.608, 0.615, and 0.628, respectively. Further experiments measured the gain brought by the layer-freezing migration strategy: for the various methods of modes two and three, this migration method brings roughly a 4% improvement in the Dice score. For the liver few-sample scenario, mode one achieves a Dice score of 0.904. Single-source-domain mode two achieves scores of 0.903, 0.902, and 0.905, respectively, and multi-source domain II mode two achieves 0.905. Reptile, MAML, and the algorithm presented herein achieve Dice scores of 0.904, 0.905, and 0.912, respectively. The best result is achieved when migration is performed with the layer-freezing method, reaching a Dice score of 0.926.
Further, fig. 4 illustrates the convergence of the loss function in the cecum few-sample scenario. Curves A1, A2, A3, and A4 show multi-source domain I mode two, the solution of the present application, Reptile, and MAML, respectively. The other three methods converge to smaller loss values than the method of the present application, yet the present method obtains better Dice scores on the test data, indicating that it is more effective at avoiding overfitting when training with few samples.
The innovation of the present application is the proposed meta-learning algorithm, as shown in fig. 5(a). In MAML, the inner loop obtains the intermediate model parameters θ_s by gradient descent on the support set, then computes the loss and gradient on the query set based on these parameters; this gradient acts directly on the original model parameters θ. Although the parameters on which that gradient is based derive from the support set, the final gradient-descent direction of each MAML outer loop is determined entirely by the query-set loss. MAML therefore over-emphasizes the ability to learn the query set from support-set experience, and this imbalance weakens the effect of the support set to some extent.
As shown in fig. 5(b), Reptile has no division into support and query sets; its inner loop obtains the intermediate model parameters θ' by consecutive gradient descent on the same batch, takes (θ' − θ) as the outer-loop update direction, and updates the original model parameters θ with a certain step size. This learning mode emphasizes the model's learning ability on a single batch and enhances generalization through the step size, but forgoes the stronger generalization that different batches might bring.
The algorithm of this embodiment is shown in fig. 5(c): the outer loop uses the same parameter-update strategy as Reptile, while the inner loop borrows the idea of MAML, adopting a diversified strategy with separate support and query sets, the last gradient-descent step being based on the query set. This approach not only trains the learned "learning to learn" capability on diverse data, but also better balances the contributions of the support set and the query set.
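The three update rules contrasted above can be sketched side by side. This is a hedged illustration under simplifying assumptions: a scalar parameter and quadratic losses replace the U-Net, and the step counts and targets are invented for demonstration, not taken from the patent.

```python
# Hypothetical side-by-side sketch of the three outer-loop update rules.
# g() is the gradient of a quadratic stand-in loss 0.5 * (theta - target)**2.

def g(theta, target):
    return theta - target

def maml_step(theta, support, query, alpha, beta):
    # MAML: inner step on the support set; outer gradient comes from the query loss.
    theta_s = theta - alpha * g(theta, support)
    return theta - beta * g(theta_s, query)

def reptile_step(theta, batch, alpha, beta, k=3):
    # Reptile: k inner steps on one batch; move theta toward the adapted parameters.
    t = theta
    for _ in range(k):
        t -= alpha * g(t, batch)
    return theta + beta * (t - theta)

def proposed_step(theta, support, query, alpha, beta, k=3):
    # This embodiment: MAML-style inner loop (support, then one query step),
    # Reptile-style outer update toward the adapted parameters.
    t = theta
    for _ in range(k):
        t -= alpha * g(t, support)
    t -= alpha * g(t, query)
    return theta + beta * (t - theta)

print(maml_step(0.0, 1.0, 1.0, 0.1, 0.4),
      reptile_step(0.0, 1.0, 0.1, 0.4),
      proposed_step(0.0, 1.0, 1.0, 0.1, 0.4))
```

The contrast makes the structural point concrete: MAML's update is driven by a query-set gradient, Reptile's by the displacement of a single-batch adaptation, and the proposed rule combines the two, so both the support and query sets shape the outer-loop direction.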
Combining the above results, the method of this embodiment shows a clear improvement over advanced meta-learning methods: it mitigates the domain-shift problem and improves segmentation precision, while also reducing overfitting to a certain extent.
In summary, in the medical image segmentation model training method, segmentation method, and device, model parameters are continuously adjusted in the source domain in a meta-learning manner, using the support set and query set in the inner loop; in the outer loop, the original base model is adjusted with the query-set model parameters as the update direction; and through continuous training on multiple batches of source domain training data, model parameters sensitive to target-domain task changes are obtained. When the source domain segmentation model is migrated to the target domain, it can better adapt to the new target-domain task, improving the generalization effect.
Those of ordinary skill in the art will appreciate that the various illustrative components, systems, and methods described in connection with the embodiments disclosed herein can be implemented as hardware, software, or a combination of both. Whether a particular implementation is hardware or software depends on the specific application of the solution and its design constraints. Skilled artisans may implement the described functionality in varying ways for each particular application, but such implementation decisions should not be interpreted as causing a departure from the scope of the present invention. When implemented in hardware, it may be, for example, an electronic circuit, an application-specific integrated circuit (ASIC), suitable firmware, a plug-in, a function card, or the like. When implemented in software, the elements of the invention are the programs or code segments used to perform the required tasks. The program or code segments may be stored in a machine-readable medium or transmitted over transmission media or communication links by a data signal carried in a carrier wave. A "machine-readable medium" may include any medium that can store or transfer information. Examples of machine-readable media include electronic circuitry, semiconductor memory devices, ROM, flash memory, erasable ROM (EROM), floppy disks, CD-ROMs, optical disks, hard disks, fiber optic media, radio frequency (RF) links, and the like. The code segments may be downloaded via computer networks such as the Internet or an intranet.
It should also be noted that the exemplary embodiments mentioned in this disclosure describe some methods or systems based on a series of steps or devices. However, the present invention is not limited to the order of the above-described steps, that is, the steps may be performed in the order mentioned in the embodiments, or may be performed in a different order from the order in the embodiments, or several steps may be performed simultaneously.
In this disclosure, features that are described and/or illustrated with respect to one embodiment may be used in the same way or in a similar way in one or more other embodiments and/or in combination with or instead of the features of the other embodiments.
The above description is only of the preferred embodiments of the present invention and is not intended to limit the present invention, and various modifications and variations can be made to the embodiments of the present invention by those skilled in the art. Any modification, equivalent replacement, improvement, etc. made within the spirit and principle of the present invention should be included in the protection scope of the present invention.