CN115861617A - Semantic segmentation model training method and device, computer equipment and storage medium - Google Patents

Info

Publication number
CN115861617A
Authority
CN
China
Prior art keywords
image
sample
type
sample support
mask
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202211591464.5A
Other languages
Chinese (zh)
Inventor
李瑞敏
王照
吴丹
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Industrial and Commercial Bank of China Ltd ICBC
Original Assignee
Industrial and Commercial Bank of China Ltd ICBC
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Industrial and Commercial Bank of China Ltd ICBC filed Critical Industrial and Commercial Bank of China Ltd ICBC
Priority to CN202211591464.5A priority Critical patent/CN115861617A/en
Publication of CN115861617A publication Critical patent/CN115861617A/en
Pending legal-status Critical Current

Landscapes

  • Image Analysis (AREA)

Abstract

The application relates to a semantic segmentation model training method, a semantic segmentation model training device, a computer device, a storage medium and a computer program product. The method comprises the following steps: based on a sample support image in a training task and a mask label corresponding to the sample support image, acquiring a class prototype vector corresponding to each image type; performing semantic segmentation on each sample image based on the class prototype vector to obtain a mask prediction result corresponding to the sample support image and a mask prediction result corresponding to the sample query image; constructing a target loss function based on a mask prediction result corresponding to the sample support image and a mask prediction result corresponding to the sample query image; and training the semantic segmentation model based on the target loss function to obtain the trained semantic segmentation model. By adopting the method, the accuracy of the semantic segmentation of the loan monitoring image can be ensured under the condition of only a small amount of sample data.

Description

Semantic segmentation model training method and device, computer equipment and storage medium
Technical Field
The present application relates to the field of image recognition technologies, and in particular, to a semantic segmentation model training method, apparatus, computer device, storage medium, and computer program product.
Background
Financial institutions such as banks often need to perform risk analysis on projects applying for loans. For example, a background investigation is performed before the loan to decide whether to grant it, which avoids the direct economic loss to the financial institution that arises when customers cannot repay on time; after the loan is granted, the condition of the loan project is monitored in real time so that lending can be halted in time.
Business personnel usually carry out pre-loan background investigation and post-loan real-time monitoring by comparing images of a loan project taken at different times. However, in lending areas such as road maintenance and farmland planting, background investigation in remote regions is difficult and little of the monitoring data is useful, so only a small amount of sample data can be obtained. To monitor loan projects in remote areas accurately, a few-shot semantic segmentation model is generally built to segment the small number of loan monitoring images and obtain a mask prediction result for each pixel of the target monitoring image, so that the image undergoes feature enhancement.
Existing schemes are based on prototype learning: a class prototype is extracted directly from the support image, and it often deviates when matched against the semantic classes in the query image. With only a small amount of sample data, a semantic segmentation model built by conventional prototype learning does not segment loan monitoring images accurately enough.
Disclosure of Invention
In view of the above, there is a need to provide a semantic segmentation model training method, device, computer readable storage medium and computer program product capable of accurately performing semantic segmentation based on a small amount of loan monitoring images.
In a first aspect, the present application provides a semantic segmentation model training method. The method comprises the following steps:
acquiring training tasks, wherein each training task comprises a sample image set of at least one image type, and each sample image set comprises at least one sample support image and at least one sample query image;
determining a class prototype vector corresponding to each image type based on the sample support image of each image type and the mask label corresponding to the sample support image;
based on the class prototype vector corresponding to each image type and the sample support image of each image type, performing semantic segmentation on the sample image of each image type through a semantic segmentation model to obtain a mask prediction result corresponding to the sample image of each image type, wherein the sample image is a sample support image or a sample query image;
constructing a first loss function based on the difference between the mask prediction result corresponding to the sample support image of each image type and the corresponding mask label, and constructing a second loss function based on the difference between the mask prediction result corresponding to the sample query image of each image type and the corresponding mask label;
and constructing a target loss function based on the first loss function and the second loss function, and training the semantic segmentation model based on the target loss function to obtain the trained semantic segmentation model.
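As a rough illustration of the last two steps, the sketch below builds a first loss from the support-image predictions, a second loss from the query-image predictions, and combines them into a target loss. The text above does not state the form of either loss or how they are combined; per-pixel binary cross-entropy and a weighted sum are assumptions made here, and the names `binary_cross_entropy`, `target_loss` and `support_weight` are hypothetical.

```python
import numpy as np


def binary_cross_entropy(pred, label, eps=1e-7):
    """Mean per-pixel binary cross-entropy between foreground probabilities and a 0/1 mask."""
    pred = np.clip(pred, eps, 1.0 - eps)  # avoid log(0)
    return float(-np.mean(label * np.log(pred) + (1.0 - label) * np.log(1.0 - pred)))


def target_loss(support_preds, support_labels, query_preds, query_labels, support_weight=1.0):
    """Combine the support-image loss and the query-image loss (weighted sum is an assumption)."""
    # First loss: mask predictions of sample support images vs. their mask labels.
    first = np.mean([binary_cross_entropy(p, y) for p, y in zip(support_preds, support_labels)])
    # Second loss: mask predictions of sample query images vs. their mask labels.
    second = np.mean([binary_cross_entropy(p, y) for p, y in zip(query_preds, query_labels)])
    return float(second + support_weight * first)
```

With `support_weight=0.0` the sketch reduces to a query-only loss, which makes it easy to compare against training without the support-image term.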
In one embodiment, determining a class prototype vector corresponding to each image type based on the sample support image of each image type and the mask label corresponding to the sample support image includes: for each sample support image of the current image type, extracting a multi-channel feature map of the sample support image; acquiring a class prototype vector corresponding to each sample support image based on the multi-channel feature map of that sample support image and the mask label corresponding to it; and integrating the class prototype vectors corresponding to all the sample support images to obtain the class prototype vector corresponding to the current image type.
In one embodiment, obtaining a class prototype vector corresponding to each sample support image based on a multi-channel feature map of each sample support image and a mask label corresponding to each sample support image includes: up-sampling the multi-channel feature map of a current sample support image of the current image type and, based on the mask label corresponding to the current sample support image, extracting from the up-sampling result of each channel a first foreground feature map within the first mask region corresponding to the mask label; acquiring a first type prototype feature of each channel based on the first foreground feature map of each channel, and constructing a first type prototype vector corresponding to the current sample support image based on the first type prototype features of the channels; down-sampling the mask label corresponding to the current sample support image and, based on the down-sampling result, extracting from the multi-channel feature map of the current sample support image a second foreground feature map within the second mask region corresponding to the down-sampling result; acquiring a second type prototype feature of each channel based on the second foreground feature map of each channel, and constructing a second type prototype vector corresponding to the current sample support image based on the second type prototype features of the channels; and integrating the first type prototype vector and the second type prototype vector to obtain the class prototype vector corresponding to the current sample support image.
In one embodiment, the sample image is a sample support image, and performing semantic segmentation on the sample image of each image type through a semantic segmentation model, based on the class prototype vector corresponding to each image type and the sample support image of each image type, to obtain a mask prediction result corresponding to the sample image of each image type includes: extracting a multi-channel feature map of a current sample support image of the current image type; obtaining, based on the class prototype vector of the current image type, a similarity feature map of the current sample support image, wherein the similarity feature map represents the degree of similarity between the multi-channel feature map of the current sample support image and the class prototype vector of the current image type; performing feature enhancement processing on the multi-channel feature map of the current sample support image based on the similarity feature map to obtain enhanced features; and performing atrous spatial pyramid pooling (ASPP) on the enhanced features, and performing convolutional classification on the processing result to obtain a mask prediction result corresponding to the current sample support image.
In one embodiment, the sample image is a sample query image, and performing semantic segmentation on the sample image of each image type through a semantic segmentation model, based on the class prototype vector corresponding to each image type and the sample support image of each image type, to obtain a mask prediction result corresponding to the sample image of each image type includes: for a current sample query image of a current image type, extracting a multi-channel feature map of the current sample query image, and extracting a multi-channel feature map of each sample support image of each image type; down-sampling the mask labels corresponding to the sample support images of the image types, and extracting foreground feature maps corresponding to the sample support images of the image types based on the corresponding down-sampling results; performing attention-guided processing based on the multi-channel feature map of the current sample query image and the foreground feature maps corresponding to the sample support images of all image types to obtain an attention-guided processing result; obtaining, based on the class prototype vectors of all image types, a similarity feature map of the current sample query image for each class prototype vector, wherein each similarity feature map represents the degree of similarity between the multi-channel feature map of the current sample query image and the corresponding class prototype vector; performing feature enhancement processing on the multi-channel feature map of the current sample query image based on the similarity feature maps for the respective class prototype vectors to obtain enhanced features; and performing atrous spatial pyramid pooling (ASPP) on the enhanced features, and performing convolutional classification on the processing result to obtain a mask prediction result corresponding to the current sample query image.
In one embodiment, the training method of the semantic segmentation model further includes: after the trained semantic segmentation model is obtained, acquiring a target image set to be semantically segmented, wherein the target image type corresponding to the target image set was not involved in the training process of the semantic segmentation model, and the target image set comprises at least one target support image and at least one target query image; determining a class prototype vector corresponding to the target image type based on the target support image and the mask label corresponding to the target support image; performing semantic segmentation on the target support image through the trained semantic segmentation model, based on the target support image and the class prototype vector corresponding to the target image type, to obtain a mask prediction result corresponding to the target support image; constructing a third loss function based on the difference between the mask prediction result corresponding to the target support image and the corresponding mask label, and retraining the trained semantic segmentation model based on the third loss function; and performing semantic segmentation on the target query image based on the retrained semantic segmentation model.
In one embodiment, the training method of the semantic segmentation model further includes: the sample image set and the target image set are obtained by performing pre-loan investigation or post-loan monitoring on the loan object.
In a second aspect, the application further provides a semantic segmentation model training device. The device comprises:
the task acquisition module is used for acquiring training tasks, wherein each training task comprises a sample image set of at least one image type, and each sample image set comprises at least one sample support image and at least one sample query image;
the prototype fusion module is used for determining a class prototype vector corresponding to each image type based on the sample support image of each image type and the mask label corresponding to the sample support image;
the mask prediction module is used for carrying out semantic segmentation on the sample image of each image type through a semantic segmentation model based on the class prototype vector corresponding to each image type and the sample support image of each image type to obtain a mask prediction result corresponding to the sample image of each image type, wherein the sample image is a sample support image or a sample query image;
the loss construction module is used for constructing a first loss function based on the difference between the mask prediction result corresponding to the sample support image of each image type and the corresponding mask label, and constructing a second loss function based on the difference between the mask prediction result corresponding to the sample query image of each image type and the corresponding mask label;
and the model optimization module is used for constructing a target loss function based on the first loss function and the second loss function, training the semantic segmentation model based on the target loss function and obtaining the trained semantic segmentation model.
In a third aspect, the present application also provides a computer device. The computer device comprises a memory storing a computer program and a processor implementing the following steps when executing the computer program:
acquiring training tasks, wherein each training task comprises a sample image set of at least one image type, and each sample image set comprises at least one sample support image and at least one sample query image;
determining a class prototype vector corresponding to each image type based on the sample support image of each image type and the mask label corresponding to the sample support image;
based on the class prototype vector corresponding to each image type and the sample support image of each image type, performing semantic segmentation on the sample image of each image type through a semantic segmentation model to obtain a mask prediction result corresponding to the sample image of each image type, wherein the sample image is a sample support image or a sample query image;
constructing a first loss function based on the difference between the mask prediction result corresponding to the sample support image of each image type and the corresponding mask label, and constructing a second loss function based on the difference between the mask prediction result corresponding to the sample query image of each image type and the corresponding mask label;
and constructing a target loss function based on the first loss function and the second loss function, and training the semantic segmentation model based on the target loss function to obtain the trained semantic segmentation model.
In a fourth aspect, the present application further provides a computer-readable storage medium. The computer-readable storage medium having stored thereon a computer program which, when executed by a processor, performs the steps of:
acquiring training tasks, wherein each training task comprises a sample image set of at least one image type, and each sample image set comprises at least one sample support image and at least one sample query image;
determining a class prototype vector corresponding to each image type based on the sample support image of each image type and the mask label corresponding to the sample support image;
based on the class prototype vector corresponding to each image type and the sample support image of each image type, performing semantic segmentation on the sample image of each image type through a semantic segmentation model to obtain a mask prediction result corresponding to the sample image of each image type, wherein the sample image is a sample support image or a sample query image;
constructing a first loss function based on the difference between the mask prediction result corresponding to the sample support image of each image type and the corresponding mask label, and constructing a second loss function based on the difference between the mask prediction result corresponding to the sample query image of each image type and the corresponding mask label;
and constructing a target loss function based on the first loss function and the second loss function, and training the semantic segmentation model based on the target loss function to obtain the trained semantic segmentation model.
In a fifth aspect, the present application further provides a computer program product. The computer program product comprising a computer program which when executed by a processor performs the steps of:
acquiring training tasks, wherein each training task comprises a sample image set of at least one image type, and each sample image set comprises at least one sample support image and at least one sample query image;
determining a class prototype vector corresponding to each image type based on the sample support image of each image type and the mask label corresponding to the sample support image;
based on the class prototype vector corresponding to each image type and the sample support image of each image type, performing semantic segmentation on the sample image of each image type through a semantic segmentation model to obtain a mask prediction result corresponding to the sample image of each image type, wherein the sample image is a sample support image or a sample query image;
constructing a first loss function based on the difference between the mask prediction result corresponding to the sample support image of each image type and the corresponding mask label, and constructing a second loss function based on the difference between the mask prediction result corresponding to the sample query image of each image type and the corresponding mask label;
and constructing a target loss function based on the first loss function and the second loss function, and training the semantic segmentation model based on the target loss function to obtain the trained semantic segmentation model.
The semantic segmentation model training method, the semantic segmentation model training device, the computer equipment, the storage medium and the computer program product are used for acquiring class prototype vectors corresponding to all image types based on sample support images in training tasks and mask labels corresponding to the sample support images; performing semantic segmentation on each sample image based on the class prototype vector to obtain a mask prediction result corresponding to the sample support image and a mask prediction result corresponding to the sample query image; constructing a target loss function based on a mask prediction result corresponding to the sample support image and a mask prediction result corresponding to the sample query image; and training the semantic segmentation model based on the target loss function to obtain the trained semantic segmentation model. The whole semantic segmentation model training process only needs a small number of training samples, for each training task, class prototype vectors corresponding to all image types in the training task are obtained through a sample support image and a corresponding mask label, then the sample support image and a sample query image are segmented based on the class prototype vectors, finally a target loss function combining sample support image loss and sample query image loss is constructed, and the accuracy of the semantic segmentation of the loan monitoring image is guaranteed under the condition of only a small number of sample data. 
Drawings
FIG. 1 is a diagram of an application environment of a semantic segmentation model training method in one embodiment;
FIG. 2 is a schematic flow chart diagram of a semantic segmentation model training method in one embodiment;
FIG. 3 is a flow diagram illustrating a process for obtaining a mask prediction result in one embodiment;
FIG. 4 is a block diagram of a semantic segmentation model training method according to an embodiment;
FIG. 5 is a block diagram of an attention mechanism in one embodiment;
FIG. 6 is a block diagram showing the structure of a semantic segmentation model training apparatus according to an embodiment;
FIG. 7 is a diagram illustrating an internal structure of a computer device according to an embodiment.
Detailed Description
In order to make the objects, technical solutions and advantages of the present application more apparent, the present application is described in further detail below with reference to the accompanying drawings and embodiments. It should be understood that the specific embodiments described herein are merely illustrative of the present application and are not intended to limit the present application.
The semantic segmentation model training method provided by the embodiment of the application can be applied to the application environment shown in fig. 1, in which the terminal 102 communicates with the server 104 via a network. A data storage system may store the data that the server 104 needs to process; it may be integrated on the server 104 or placed on the cloud or another network server. The terminal 102 obtains training tasks, each of which includes a sample image set of at least one image type, each sample image set including at least one sample support image and at least one sample query image, and sends the training tasks to the server 104 through the communication network.
The server 104 determines a class prototype vector corresponding to each image type based on the sample support image of each image type and the mask label corresponding to the sample support image; based on the class prototype vector corresponding to each image type and the sample support image of each image type, performing semantic segmentation on the sample image of each image type through a semantic segmentation model to obtain a mask prediction result corresponding to the sample image of each image type, wherein the sample image is a sample support image or a sample query image; constructing a first loss function based on the difference between the mask prediction result corresponding to the sample support image of each image type and the corresponding mask label, and constructing a second loss function based on the difference between the mask prediction result corresponding to the sample query image of each image type and the corresponding mask label; and constructing a target loss function based on the first loss function and the second loss function, training the semantic segmentation model based on the target loss function to obtain a trained semantic segmentation model, and storing the trained semantic segmentation model to a data storage system.
The terminal 102 may be, but is not limited to, various personal computers, notebook computers, smart phones, tablet computers, and the like. The server 104 may be implemented as a stand-alone server or as a server cluster comprised of multiple servers.
In one embodiment, as shown in fig. 2, a semantic segmentation model training method is provided. Taking the application of the method to the server 104 in fig. 1 as an example, the method includes the following steps:
S100: Training tasks are obtained, wherein each training task comprises a sample image set of at least one image type, and each sample image set comprises at least one sample support image and at least one sample query image.
The training tasks are extracted from a training set, each training task at least comprises one image type, each training task is composed of a support image set and a query image set, the support image set comprises at least one sample support image, and the query image set comprises at least one sample query image.
Optionally, a plurality of training tasks are extracted from a training set $\{(I_i, Y_i)\}_{i=1}^{m}$ (where $I_i$ denotes an image, $Y_i$ denotes the segmentation mask corresponding to that image, and $m$ is the number of images in the training set), and each training task is composed of a support image set and a query image set. For an N-way K-shot training task, the task contains N categories and each category is provided with K sample support images; if each category also corresponds to M sample query images, the training task has N × (K + M) sample images. All the sample images under the training task and the segmentation masks corresponding to all the sample images are used as the input of the semantic segmentation model.
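The N-way K-shot task construction can be sketched as follows. This is only an illustration under stated assumptions: the training set is represented as a list of `(image, mask, class_id)` triples (the `class_id` field and the function name `sample_episode` are hypothetical and do not appear in the application), and sampling is uniform without replacement.

```python
import random
from collections import defaultdict


def sample_episode(dataset, n_way, k_shot, m_query, rng=None):
    """Draw one N-way K-shot task from a list of (image, mask, class_id) triples."""
    rng = rng or random.Random()
    by_class = defaultdict(list)
    for image, mask, class_id in dataset:
        by_class[class_id].append((image, mask))

    # Only classes with at least K + M samples can supply both sets.
    eligible = [c for c, items in by_class.items() if len(items) >= k_shot + m_query]
    if len(eligible) < n_way:
        raise ValueError("not enough classes with K + M samples")

    support, query = {}, {}
    for class_id in rng.sample(eligible, n_way):
        picked = rng.sample(by_class[class_id], k_shot + m_query)
        support[class_id] = picked[:k_shot]   # K sample support images
        query[class_id] = picked[k_shot:]     # M sample query images
    return support, query
```

Each returned task holds N × (K + M) image and mask pairs in total, and the support and query sets of a class never share a sample.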
In this embodiment, a plurality of training tasks are acquired from a training set, each training task consists of a support image set and a query image set, and all sample images under each training task together with their corresponding segmentation masks are used as input of the semantic segmentation model. By learning across multiple training tasks, the semantic segmentation model becomes able to perform semantic segmentation on new image types under new tasks, and this few-shot learning improves loan monitoring accuracy when only a small amount of sample data is available.
S200: A class prototype vector corresponding to each image type is determined based on the sample support image of each image type and the mask label corresponding to the sample support image.
The mask of a sample support image may be a foreground mask whose pixels are labeled 0 or 1; the class prototype vector is the representative prototype vector of each class of sample support images.
Optionally, in a semantic segmentation model training stage, all sample support images under each training task are input to a feature extraction network to obtain deep features of the sample support images, the deep features of the sample support images are input to a prototype fusion module, and a prototype representative vector of each type of image in the sample support images is obtained through the deep features of the sample support images and foreground mask labels corresponding to the sample support images.
The feature extraction network consists of the first five layers of the VGG16 network followed by a convolution layer with a 3 × 3 kernel. This convolution layer is the last layer of the feature extraction network; it adjusts the dimension of the sample support image features and reduces the number of parameters of the semantic segmentation model. After the feature extraction network produces a feature map containing the deep features of a sample support image, the feature map is up-sampled to the size of the mask of the sample support image, and masked average pooling is performed on the up-sampled feature map with the mask label to obtain a first type prototype vector. Because up-sampling may affect the integrity of the foreground features of the sample support image, the mask of the sample support image is also down-sampled to the size of the feature map, and masked average pooling is performed on the feature map again with the down-sampled mask to obtain a second type prototype vector. The dimensionality of both the first type and the second type prototype vectors equals the number of channels of the feature map of the sample support image, and the two vectors are averaged to obtain the final representative prototype vector, which serves as the class prototype vector.
In this embodiment, the deep features of the sample support images in the training task are extracted through the feature extraction network and input to the prototype fusion module. The prototype fusion module up-samples the feature map of a sample support image to the size of its mask and obtains a first type prototype vector by masked average pooling; it then down-samples the mask to the size of the feature map and obtains a second type prototype vector in the same way; finally it averages the two to obtain the class prototype vector used to guide segmentation of the query image. The whole class prototype vector acquisition process is based on an improved prototype fusion strategy, which improves the accuracy of semantic segmentation of loan monitoring images when only a small amount of sample data is available.
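The two-branch prototype fusion can be sketched in NumPy as below. This is a minimal sketch, not the applicant's implementation: nearest-neighbour resizing stands in for whatever up-sampling and down-sampling the model actually uses, the integration over several support images is taken to be a plain mean, and all function names are hypothetical.

```python
import numpy as np


def resize_nearest(x, out_h, out_w):
    """Nearest-neighbour resize over the last two axes (assumed interpolation)."""
    rows = np.arange(out_h) * x.shape[-2] // out_h
    cols = np.arange(out_w) * x.shape[-1] // out_w
    return x[..., rows[:, None], cols[None, :]]


def masked_average_pooling(features, mask, eps=1e-8):
    """Average each channel of `features` (C, H, W) over pixels where `mask` (H, W) is 1."""
    weighted = (features * mask[None]).sum(axis=(1, 2))
    return weighted / (mask.sum() + eps)


def fused_prototype(features, mask):
    """Average the first type and second type prototype vectors of one support image."""
    _, fh, fw = features.shape
    mh, mw = mask.shape
    # Branch 1: up-sample the feature map to the mask resolution, then pool.
    p1 = masked_average_pooling(resize_nearest(features, mh, mw), mask)
    # Branch 2: down-sample the mask to the feature-map resolution, then pool.
    p2 = masked_average_pooling(features, resize_nearest(mask, fh, fw))
    return (p1 + p2) / 2.0


def class_prototype(support_features, support_masks):
    """Integrate per-image prototypes of one image type (plain mean is an assumption)."""
    protos = [fused_prototype(f, m) for f, m in zip(support_features, support_masks)]
    return np.mean(protos, axis=0)
```

The result has one entry per feature channel. If the foreground vanishes after down-sampling, the second branch in this sketch returns a zero vector, so a real implementation would likely need to handle very small foreground regions separately.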
S300: Semantic segmentation is performed on the sample image of each image type through a semantic segmentation model, based on the class prototype vector corresponding to each image type and the sample support image of each image type, to obtain a mask prediction result corresponding to the sample image of each image type, wherein the sample image is a sample support image or a sample query image.
The task of semantic segmentation is to train a neural network to output a pixel-level mask for the target image. In the method, a plurality of training tasks are obtained based on a training set. All sample images under each training task and the segmentation masks corresponding to those sample images are used as input of the semantic segmentation model, and the class prototype vector corresponding to each image type is obtained from the sample support images and their segmentation masks.
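The construction of training tasks from a training set can be sketched as follows. This is only an illustration under assumptions: it presumes the N-way K-shot task structure described later in this description, it draws M query images per class (the source does not state whether M is per class or per task), and the function name `sample_episode` and the dictionary layout are hypothetical.

```python
import random


def sample_episode(images_by_class, n_way, k_shot, m_query, rng=None):
    """Sample one N-way K-shot task with M query images per class.

    images_by_class maps a class name to a list of image identifiers.
    This layout is an assumption; the source does not specify a data format.
    """
    rng = rng or random.Random()
    classes = rng.sample(sorted(images_by_class), n_way)
    support, query = {}, {}
    for cls in classes:
        picked = rng.sample(images_by_class[cls], k_shot + m_query)
        support[cls] = picked[:k_shot]  # images whose masks build the class prototype
        query[cls] = picked[k_shot:]    # images to be segmented
    return support, query
```

Drawing support and query images from a single sample without replacement keeps the two sets disjoint within a task.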
On one hand, the sample support image of each image type is segmented by the semantic segmentation model: a similarity map between the deep features of the sample support image and the class prototype of that type is calculated, the similarity map is applied to the deep features of the sample support image to perform feature enhancement, and an ASPP (Atrous Spatial Pyramid Pooling) module and two convolutional classification layers follow to obtain the mask prediction result of the sample support image. On the other hand, the sample query image of each image type is segmented by the semantic segmentation model: a similarity map between the deep features of the sample query image and the class prototype of that type is calculated, the similarity map is applied to the deep features of the sample query image to perform feature enhancement, and an ASPP module and two convolutional classification layers, which share parameters with the classification layers for the sample support image, follow to obtain the mask prediction result of the sample query image.
In the embodiment, mask prediction results of the sample support image and the sample query image are respectively obtained through the support image classification layer and the query image classification layer in the semantic segmentation model, and the query image classification layer and the support image classification layer share parameters, so that the accuracy of the semantic segmentation of the loan monitoring image is improved under the condition of only a small amount of sample data.
S400: constructing a first loss function based on the difference between the mask prediction result corresponding to the sample support image of each image type and the corresponding mask label, and constructing a second loss function based on the difference between the mask prediction result corresponding to the sample query image of each image type and the corresponding mask label.
Optionally, a cross entropy loss function is employed to calculate the prediction loss for the sample support image and the sample query image. For each training task, constructing a first loss function, and adding the prediction losses of all sample support images in the training task to be used as the prediction losses of the sample support images; and constructing a second loss function, and adding the prediction losses of all the sample query images in the training task to be used as the prediction losses of the sample query images.
In the embodiment, a cross entropy loss function for measuring a predicted value of a sample image mask and a true value of the sample image mask is constructed to obtain the prediction loss of all sample support images and the prediction loss of all sample query images under a certain training task, so that the semantic segmentation model is updated for multiple times through multiple training tasks, and the accuracy of the semantic segmentation of the loan monitoring image is improved under the condition of only a small amount of sample data.
S500: constructing a target loss function based on the first loss function and the second loss function, and training the semantic segmentation model based on the target loss function to obtain the trained semantic segmentation model.
The target loss function is used for measuring the sum of prediction losses of all sample support images and sample query images under a certain training task.
Optionally, based on the target loss function, the sum of the prediction loss of the sample support images and the prediction loss of the sample query images in each training task is calculated and taken as the total loss of the semantic segmentation model in that training task, and the parameters of the whole semantic segmentation model are updated in the training stage according to the total loss. Specifically, the target loss function is minimized so that the predicted sample masks approach the true sample masks as closely as possible. This learns the parameters of the support image classification layer and the query image classification layer; at the same time, the gradient is back-propagated to the feature extraction network and its parameters are updated, so that the extracted features become more effective. In the model training stage, the parameter values of the model are updated once for each training task, and the trained semantic segmentation model is obtained after all the training tasks.
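The total loss for one training task can be sketched as below. This is a minimal illustration, assuming that each prediction is a grid of per-pixel class probabilities and that the per-pixel cross entropy is summed rather than averaged (the source only says that the losses of all images are added). The function names are hypothetical.

```python
import math


def pixel_cross_entropy(probs, labels):
    """Cross entropy of one image, summed over pixels (summing is an assumption).

    probs[y][x] is a list of class probabilities for pixel (y, x);
    labels[y][x] is the true class index of that pixel.
    """
    return -sum(
        math.log(probs[y][x][labels[y][x]])
        for y in range(len(labels))
        for x in range(len(labels[0]))
    )


def task_loss(support_preds, support_labels, query_preds, query_labels):
    """Target loss: first loss (support images) plus second loss (query images)."""
    first = sum(pixel_cross_entropy(p, t) for p, t in zip(support_preds, support_labels))
    second = sum(pixel_cross_entropy(p, t) for p, t in zip(query_preds, query_labels))
    return first + second
```

Because both terms share the classification layers, minimizing their sum lets the support images supervise the same parameters that segment the query images.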
In the embodiment, the self-segmentation loss of the sample support image is added on the basis of the prediction loss of the sample query image and is used as the total loss of the semantic segmentation model under each training task, so that prototype alignment is realized in the training stage, and the accuracy of the semantic segmentation of the loan monitoring image is improved under the condition of only a small amount of sample data.
The semantic segmentation model training method is based on a sample support image in a training task and a mask label corresponding to the sample support image, and class prototype vectors corresponding to all image types are obtained; performing semantic segmentation on each sample image based on the class prototype vector to obtain a mask prediction result corresponding to the sample support image and a mask prediction result corresponding to the sample query image; constructing a target loss function based on a mask prediction result corresponding to the sample support image and a mask prediction result corresponding to the sample query image; and training the semantic segmentation model based on the target loss function to obtain the trained semantic segmentation model. The whole semantic segmentation model training process only needs a small number of training samples, for each training task, class prototype vectors corresponding to all image types in the training task are obtained through a sample support image and a corresponding mask label, then the sample support image and a sample query image are segmented based on the class prototype vectors, finally a target loss function combining sample support image loss and sample query image loss is constructed, and the accuracy of the semantic segmentation of the loan monitoring image is guaranteed under the condition of only a small number of sample data.
In one embodiment, determining a class prototype vector corresponding to each image type based on a sample supporting image of each image type and a mask label corresponding to the sample supporting image includes:
extracting a multi-channel feature map of each sample support image aiming at each sample support image of the current image type;
acquiring a class prototype vector corresponding to each sample support image based on a multi-channel feature map of each sample support image and a mask label corresponding to each sample support image;
and integrating the class prototype vectors corresponding to the support images of all samples to obtain the class prototype vector corresponding to the current image type.
Wherein, the current image type refers to a sample type currently used for training the semantic segmentation model.
In the training stage of the semantic segmentation model, every sample support image $I_i \in \mathbb{R}^{3\times h\times w}$ under each training task and its corresponding mask $Y_i \in \{0,1\}^{h\times w}$ are taken as input, and the deep features of the sample support image $I_i$ are extracted by the feature extraction network to obtain a feature map $F_i' \in \mathbb{R}^{c\times h'\times w'}$ containing the deep features of the sample. Here $\mathbb{R}^{3\times h\times w}$ denotes the size of the sample support image, where 3 indicates that the sample support image has three channels and h and w are the height and width of the sample support image, respectively; the mask of the sample support image is a foreground mask whose labels are 0 or 1; $\mathbb{R}^{c\times h'\times w'}$ denotes the size of the feature map, where c is the number of channels of the feature map and h' and w' are the height and width of the feature map, respectively.
Further, for each sample support image of the current image type, after the feature map $F_i' \in \mathbb{R}^{c\times h'\times w'}$ containing the deep features of the sample is obtained, the feature map $F_i'$ is up-sampled by bilinear interpolation to the same size as the sample support image mask $Y_i$, so that $F_i'$ is adjusted to $F_i \in \mathbb{R}^{c\times h\times w}$. Based on $F_i$ and $Y_i$, a first type prototype vector is obtained by a masked average pooling operation. Specifically, the masked average pooling operation averages the pixels of the target region of the feature map channel by channel to obtain the i-th element $v_i$ of the first type prototype vector $v$, calculated as follows:
$$v_i = \frac{\sum_{x,y} Y_{x,y}\, F_{i,x,y}}{\sum_{x,y} Y_{x,y}}$$
where $Y_{x,y}$ is the value at row x, column y of the sample support image mask, $F_{i,x,y}$ is the value at row x, column y of the i-th channel of the adjusted sample support image feature map, and $v_i$ is the i-th element of the first type prototype vector $v$. In this formula the subscript i indexes the channel.
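The masked average pooling described above, which averages each channel over the pixels whose mask value is 1, can be sketched in plain Python. The nested-list layout (channels, rows, columns) and the function name are assumptions made for illustration.

```python
def masked_average_pooling(feature, mask):
    """Average each channel of `feature` over the foreground pixels of `mask`.

    feature: c x h x w nested lists; mask: h x w nested lists of 0/1.
    Returns a list of c values, one prototype element per channel.
    """
    area = sum(sum(row) for row in mask)
    if area == 0:
        raise ValueError("mask has no foreground pixels")
    height, width = len(mask), len(mask[0])
    return [
        sum(channel[y][x] * mask[y][x] for y in range(height) for x in range(width)) / area
        for channel in feature
    ]
```

For example, with two 2 x 2 channels and a mask selecting the diagonal, the result holds the mean of the two selected pixels in each channel.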
Generally, the first type prototype vector obtained by the masked average pooling operation is used directly as the class prototype to guide segmentation of the query image. However, the up-sampling operation can affect the integrity of the extracted foreground features of the sample. To obtain richer foreground features, the present application adds a down-sampling branch alongside the up-sampling branch. For each sample support image of the current image type, the sample support image mask is down-sampled to the same size as the feature map of the sample support image, and a second type prototype vector $v'$ is obtained by the masked average pooling operation. The first type prototype vector $v$ and the second type prototype vector $v'$ are averaged to obtain the class prototype vector $v_{ave} \in \mathbb{R}^{c\times 1}$ of the sample support image, which guides segmentation of the query image. Here $\mathbb{R}^{c\times 1}$ denotes the size of the class prototype vector, whose dimension is the number of channels c of the feature map of the sample support image. The class prototype vectors of all sample support images are then averaged to obtain the class prototype vector $V \in \mathbb{R}^{c\times 1}$ corresponding to the current image type.
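The two averaging stages, first across the two branches of one support image and then across all support images of a class, can be sketched as follows. The function names and the list-of-pairs input are hypothetical and chosen only for illustration.

```python
def mean_vectors(vectors):
    """Element-wise mean of equally long vectors."""
    count = len(vectors)
    return [sum(values) / count for values in zip(*vectors)]


def class_prototype(per_image_pairs):
    """Fuse prototypes for one image type.

    per_image_pairs: one (first_type, second_type) pair of prototype vectors
    per sample support image of the class.
    """
    # Stage 1: average the up-sampling and down-sampling prototypes per image.
    per_image = [mean_vectors([first, second]) for first, second in per_image_pairs]
    # Stage 2: average over all support images of the class.
    return mean_vectors(per_image)
```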
In the embodiment, based on an improved prototype fusion strategy, on the basis of traditional prototype learning, down-sampling operation branches are added for improving the integrity of foreground feature extraction, and the class prototype vectors corresponding to each sample support image in the current image type are averaged and integrated, so that the accuracy of the semantic segmentation of the loan monitoring image is improved under the condition of only a small amount of sample data.
In one embodiment, obtaining a class prototype vector corresponding to each sample support image based on a multi-channel feature map of each sample support image and a mask label corresponding to each sample support image includes:
performing up-sampling processing on the multi-channel feature map of a current sample support image of the current image type, and, based on the mask label corresponding to the current sample support image, extracting from the up-sampling result of each channel a first foreground feature map within the first mask region corresponding to the mask label;
acquiring a first type of prototype feature of each channel based on the first foreground feature map of each channel, and constructing a first type of prototype vector corresponding to the current sample support image based on the first type of prototype feature of each channel;
performing downsampling processing on a mask label corresponding to the current sample support image, and extracting a second foreground feature map in a second mask area corresponding to a downsampling result from a multi-channel feature map of the current sample support image based on a downsampling result;
acquiring a second type of prototype feature of each channel based on the second foreground feature map of each channel, and constructing a second type of prototype vector corresponding to the current sample support image based on the second type of prototype feature of each channel;
and integrating the first type prototype vector and the second type prototype vector to obtain a type prototype vector corresponding to the current sample support image.
The first mask area refers to an area of which a mask label corresponding to the current sample support image is 1; the second mask region refers to a region where a corresponding mask label is 1 after a segmentation mask corresponding to the current sample support image is subjected to downsampling operation.
The method obtains a plurality of training tasks based on a training set. All sample support images $I_i \in \mathbb{R}^{3\times h\times w}$ under each training task and the segmentation masks $Y_i \in \{0,1\}^{h\times w}$ corresponding to the sample support images are taken as input of the semantic segmentation model. After the deep features of a sample support image $I_i$ are extracted by the VGG-16 feature extraction network, the feature map $F_i' \in \mathbb{R}^{c\times h'\times w'}$ corresponding to the sample support image is obtained. The feature map $F_i'$ of the sample support image and the segmentation mask $Y_i$ corresponding to the sample support image are input into the prototype fusion module to extract the class prototype vector $V \in \mathbb{R}^{c\times 1}$ corresponding to each image type under each training task.
The prototype fusion module extracts the class prototype vector corresponding to each image type through an up-sampling branch and a down-sampling branch. In the up-sampling branch, the multi-channel feature map $F_i' \in \mathbb{R}^{c\times h'\times w'}$ of the sample support image is up-sampled channel by channel to the same size as the sample support image mask $Y_i \in \{0,1\}^{h\times w}$. For each channel, the first foreground feature map within the first mask region of the sample support image mask is extracted from the up-sampling result and averaged to give the first type prototype feature of that channel. The first type prototype features of all channels together form the first type prototype vector $v_i$ corresponding to the current sample support image.
In the down-sampling branch, the sample support image mask $Y_i \in \{0,1\}^{h\times w}$ is down-sampled to the same size as the multi-channel feature map $F_i' \in \mathbb{R}^{c\times h'\times w'}$ of the sample support image. For each channel, the second foreground feature map within the second mask region corresponding to the down-sampling result is extracted from the multi-channel feature map of the current sample support image and averaged to give the second type prototype feature of that channel. The second type prototype features of all channels together form the second type prototype vector $v_i'$ corresponding to the current sample support image. The first type prototype vector $v_i$ and the second type prototype vector $v_i'$ are averaged to obtain the class prototype vector corresponding to the current sample support image.
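The down-sampling branch can be sketched as follows. The source does not state how the mask is down-sampled, so nearest-neighbour sampling is used here purely as an assumption, and both function names are hypothetical.

```python
def downsample_mask(mask, out_h, out_w):
    """Nearest-neighbour resize of a 0/1 mask (the interpolation is an assumption)."""
    in_h, in_w = len(mask), len(mask[0])
    return [
        [mask[y * in_h // out_h][x * in_w // out_w] for x in range(out_w)]
        for y in range(out_h)
    ]


def second_type_prototype(feature, mask):
    """Average each channel of the unresized feature map over the down-sampled mask.

    feature: c x h' x w' nested lists; mask: h x w nested lists of 0/1.
    """
    out_h, out_w = len(feature[0]), len(feature[0][0])
    small = downsample_mask(mask, out_h, out_w)
    area = sum(sum(row) for row in small)
    if area == 0:
        raise ValueError("down-sampled mask has no foreground pixels")
    return [
        sum(channel[y][x] * small[y][x] for y in range(out_h) for x in range(out_w)) / area
        for channel in feature
    ]
```

Because the feature map itself is never interpolated in this branch, the pooled values come directly from the extracted features.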
In the embodiment, a first type prototype vector corresponding to a current sample support image is obtained through an up-sampling operation branch; and obtaining a second type prototype vector corresponding to the current sample support image through a down-sampling operation branch, obtaining a type prototype vector corresponding to the current sample support image through averaging the first type prototype vector and the second type prototype vector, and improving the accuracy of the semantic segmentation of the loan monitoring image under the condition of only a small amount of sample data based on an improved prototype fusion strategy.
In one embodiment, the sample image is a sample support image; performing semantic segmentation on the sample image of each image type through the semantic segmentation model, based on the class prototype vector corresponding to each image type and the sample support image of each image type, to obtain a mask prediction result corresponding to the sample image of each image type, includes:
extracting a multi-channel feature map of a current sample support image of the current image type;
obtaining, based on the class prototype vector of the current image type, a similarity feature map of the current sample support image, wherein the similarity feature map represents the degree of similarity between the multi-channel feature map of the current sample support image and the class prototype vector of the current image type;
performing feature enhancement processing on the multi-channel feature map of the current sample support image based on the similarity feature map to obtain enhanced features;
and performing atrous spatial pyramid pooling on the enhanced features, and performing convolution classification on the processing result to obtain a mask prediction result corresponding to the current sample support image.
An N-way K-shot training task contains N categories with K sample support images under each category. For the current sample support image of the current image type, a multi-channel feature map $F_i' \in \mathbb{R}^{c\times h'\times w'}$ containing deep features is extracted by the VGG-16 feature extraction network, and the class prototype vector $V \in \mathbb{R}^{c\times 1}$ corresponding to the type of the current sample support image is obtained by the prototype fusion module. Feature enhancement processing is then performed on the sample support image based on the class prototype vector and the feature map $F_i'$ corresponding to the current sample support image.
Specifically, the cosine distance (cosine similarity) between the feature map $F_i'$, obtained from the support image $I_i$ by the feature extraction module, and the class prototype vector V of its category is calculated, yielding a similarity map between the support image and the class prototype. The similarity map represents the relationship weights between the deep features of the current sample support image and the corresponding class prototype: the stronger the relationship, the larger the weight value. The weight values are multiplied onto the feature map $F_i'$ of the sample support image to obtain the enhanced features of the target region. The enhanced features of the target region and the features extracted by the feature extraction network are fused by addition, and the fused features are input into the ASPP and the two convolutional classification layers to obtain the mask prediction result of the current sample support image.
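The similarity weighting and additive fusion described above can be sketched as follows. This is an illustration only: it interprets the cosine distance as cosine similarity (larger means more alike, as the text implies), assigns a similarity of 0 to all-zero feature vectors, and uses hypothetical function names and a nested-list layout.

```python
import math


def cosine_similarity_map(feature, prototype):
    """Cosine similarity between each pixel's feature vector and the prototype.

    feature: c x h x w nested lists; prototype: list of c values.
    Returns an h x w map; zero-norm vectors get similarity 0.0 (an assumption).
    """
    channels, height, width = len(feature), len(feature[0]), len(feature[0][0])
    proto_norm = math.sqrt(sum(p * p for p in prototype))
    sim = []
    for y in range(height):
        row = []
        for x in range(width):
            vec = [feature[k][y][x] for k in range(channels)]
            norm = math.sqrt(sum(v * v for v in vec)) * proto_norm
            dot = sum(v * p for v, p in zip(vec, prototype))
            row.append(dot / norm if norm else 0.0)
        sim.append(row)
    return sim


def enhance(feature, sim):
    """Multiply features by the similarity weights, then add the original features."""
    return [
        [[channel[y][x] * sim[y][x] + channel[y][x] for x in range(len(sim[0]))]
         for y in range(len(sim))]
        for channel in feature
    ]
```

Adding the original features back preserves information at pixels whose similarity is low, so the classifier still sees the unweighted signal.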
The ASPP comprises four branches: a 3 × 3 convolution layer with a dilation rate of 6, a 3 × 3 convolution layer with a dilation rate of 12, a 3 × 3 convolution layer with a dilation rate of 18, and a 1 × 1 convolution layer. The features of the four branches are added to obtain the fused feature. The ASPP is followed by two convolutional classification layers that produce the final prediction mask $Y_i^p$. The first convolution layer has a 3 × 3 kernel and 256 channels and is followed by a ReLU layer. The last convolution layer has a 1 × 1 kernel, and its number of channels is the sum of the number of object classes and the number of background classes.
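To illustrate why the four branches capture context at different scales, the spatial extent covered by each dilated kernel can be computed with the standard relation k + (k - 1)(d - 1), where k is the kernel size and d the dilation rate. The branch list below restates the configuration given above; the helper name is hypothetical.

```python
def effective_kernel_size(kernel, dilation):
    """Spatial extent covered by a dilated convolution kernel along one axis."""
    return kernel + (kernel - 1) * (dilation - 1)


# (kernel size, dilation rate) of the four ASPP branches described in the text.
ASPP_BRANCHES = [(3, 6), (3, 12), (3, 18), (1, 1)]
sizes = [effective_kernel_size(k, d) for k, d in ASPP_BRANCHES]
print(sizes)  # [13, 25, 37, 1]
```

Each 3 × 3 branch keeps nine weights while covering a 13, 25 or 37 pixel window, which is how the module gathers multi-scale context without adding parameters.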
In this embodiment, a similarity map characterizing the relationship weights between the current sample support image and the corresponding class prototype is obtained, the similarity map is applied to the feature map of the sample support image to obtain the enhanced features of the target region, and the mask prediction result of the current sample support image is then obtained through the ASPP and the two convolutional classification layers. The similarity map serves as an attention guide map, which improves the accuracy of semantic segmentation of loan monitoring images when only a small amount of sample data is available.
In one embodiment, as shown in FIG. 3, the sample image is a sample query image; performing semantic segmentation on the sample image of each image type through the semantic segmentation model, based on the class prototype vector corresponding to each image type and the sample support image of each image type, to obtain a mask prediction result corresponding to the sample image of each image type, includes:
S301: for a current sample query image of the current image type, extracting a multi-channel feature map of the current sample query image, and extracting a multi-channel feature map of each sample support image of each image type;
S302: performing down-sampling processing on the mask label corresponding to each sample support image of each image type, and extracting a foreground feature map corresponding to each sample support image of each image type based on the down-sampling result corresponding to that sample support image;
S303: performing attention guidance processing based on the multi-channel feature map of the current sample query image and the foreground feature maps corresponding to the sample support images of all image types to obtain an attention guidance processing result;
S304: obtaining, based on the class prototype vectors of all image types, a similarity feature map of the current sample query image for each class prototype vector, wherein the similarity feature map represents the degree of similarity between the multi-channel feature map of the current sample query image and the corresponding class prototype vector;
S305: performing feature enhancement processing on the multi-channel feature map of the current sample query image based on the similarity feature maps for the class prototype vectors to obtain enhanced features;
S306: performing atrous spatial pyramid pooling on the enhanced features, and performing convolution classification on the processing result to obtain a mask prediction result corresponding to the current sample query image.
In the method, M sample query images are provided for the N categories of an N-way K-shot training task. For each training task, the multi-channel feature map of the current sample query image and the multi-channel feature maps of the sample support images of every image type under the training task are extracted by the VGG-16 feature extraction network. For the feature map of each sample support image, the corresponding foreground feature map is obtained through the down-sampling branch. The foreground feature map of each sample support image and the multi-channel feature map of the current sample query image are input into the attention guidance module for attention guidance processing, which calculates the similarity relation between each pixel of the sample query image and all foreground pixels of the sample support images.
The attention guidance processing is implemented in two steps. The first step is based on an attention mechanism. Its inputs are the multi-channel feature map of the sample query image and the foreground feature map $F'_{mask}$ corresponding to each sample support image, where $F'_{mask}$ is obtained by multiplying the feature map $F_i'$ of each sample support image point-wise with the foreground mask $Y_i$ of that sample support image. The implementation principle is expressed by the following formula:
[Formula image BDA0003994622830000153 of the original publication; not reproduced in this text]
In this formula, the query feature is the value of the multi-channel feature map of the sample query image at the i-th spatial position, and its dimension equals the number of channels of that feature map; $x_j$ is the value of the foreground feature $F'_{mask}$ at the j-th spatial position, and its dimension equals the number of channels of $F'_{mask}$; f(·) is a similarity calculation function that computes the similarity between the query feature and the foreground feature $x_j$; g(·) is a mapping function that maps each point of the foreground feature $F'_{mask}$ to a vector; and C(x) is a normalization coefficient. Each sample query image is compared with the class prototypes of all image types under the current training task to perform a forward feature enhancement operation, the maximum feature response at each point of the feature map is then calculated, and the attention feature map is finally output based on that maximum.
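The formula for the first step is not reproduced in this text. Judging only from the terms defined around it (a similarity function f, a mapping function g, and a normalization coefficient C(x)), it plausibly follows the form of a non-local attention operation. The expression below is a reconstruction under that assumption, not the formula of the publication; $q_i$ (the query feature at position i) and $y_i$ (the attended output at position i) are hypothetical symbols introduced here for illustration.

```latex
y_i = \frac{1}{C(x)} \sum_{j} f(q_i, x_j)\, g(x_j)
```

Under this reading, each query position gathers the mapped foreground features $g(x_j)$ of the support images, weighted by how similar they are to the query feature at that position.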
In the second step, the cosine distance between the class prototype vector V and the feature map of the query image is calculated to obtain a similarity map between the sample support image and the sample query image, and the similarity map is applied to the attention feature map. Specifically, the weight values measuring the similarity between the query features and the class prototype are multiplied onto that feature map to obtain the enhanced features of the target region of the query image. The enhanced features are then fused by addition with a second feature map (both operands appear only as formula images in the original publication), and the fused features are input into the ASPP (Atrous Spatial Pyramid Pooling) module and the two convolutional classification layers to obtain the mask prediction result of the current sample query image.
In this embodiment, for each sample query image in the training task set, the query image is first guided by the foreground features of the support images based on an attention mechanism, so that feature enhancement is achieved in the target region and an attention feature map is output. A similarity map between the sample support image and the sample query image is then obtained from the similarity relation between the class prototype vector and the deep features of the query image. Finally, the similarity map is applied to the attention feature map to obtain the enhanced feature map of the target region, and the fused features are input into the ASPP and the two convolutional classification layers to obtain the mask prediction result of the current sample query image. The attention guidance module strengthens the information interaction between the sample query image and the sample support image and their correlation during segmentation, so that the sample query image is segmented better. This improves the accuracy of semantic segmentation of loan monitoring images when only a small amount of sample data is available.
In an embodiment, the training method of the semantic segmentation model further includes:
after the trained semantic segmentation model is obtained, a target image set to be subjected to semantic segmentation is obtained, the type of a target image corresponding to the target image set is not involved in a training process aiming at the semantic segmentation model, and the target image set comprises at least one target support image and at least one target query image;
determining a class prototype vector corresponding to the type of the target image based on the target support image and the mask label corresponding to the target support image;
performing semantic segmentation on the target support image through the trained semantic segmentation model based on the target support image and the class prototype vector corresponding to the type of the target image to obtain a mask prediction result corresponding to the target support image;
constructing a third loss function based on the difference between the mask prediction result corresponding to the target support image and the corresponding mask label, and retraining the trained semantic segmentation model based on the third loss function;
and performing semantic segmentation on the target query image based on the retrained semantic segmentation model.
A plurality of test tasks are obtained from the test set $\{(I_j, Y_j)\}_{j=1}^{m}$ (where $I_j$ denotes an image, $Y_j$ denotes the segmentation mask corresponding to the image, and m is the number of images in the test set). Each test task includes a target image set of at least one image type, and the image categories of the target image set do not intersect with those of the sample image set. Each test task also comprises a support image set and a query image set, wherein the support image set comprises at least one target support image, and the query image set comprises at least one target query image.
All target images under each test task and the segmentation masks corresponding to those target images are input into the trained semantic segmentation model, so that the semantic segmentation model performs semantic segmentation on the new image types under the new test task. For the current test task, to enable the semantic segmentation model to extract more representative class prototypes under the task, the network is further optimized using the prediction loss of the target support images of the new classes so as to adapt to the unseen classes.
Specifically, for each test task, as in the training phase, all target support images under each test task are input to the feature extraction network to obtain deep features of the target support images, the deep features of the target support images are input to the prototype fusion module, and prototype vectors corresponding to each image type in the target support images are obtained through the deep features of the target support images and foreground mask labels corresponding to the target support images.
Based on the target support image and the class prototype vectors corresponding to the types of the target images, the target support image of each image type is segmented through a semantic segmentation model, a similarity graph between deep features of the target support image and the class prototypes of the class is calculated, the similarity graph is applied to the deep features of the target support image to perform feature enhancement on the target support image, and then mask prediction results of the target support image are obtained through ASPP and two convolution classification layers.
And constructing a cross entropy loss function for measuring the predicted value of the target support image mask and the true value of the target support image mask as a third loss function. In the testing task, the prediction loss of the testing stage is only used for updating the classification layer of the model, and the parameters of other layers are fixed and are obtained by learning in the training stage. The updated model is used for segmenting the target query image under the test task, and finally the prediction mask of each target query image is obtained.
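Updating only the classification layers during the test stage can be sketched as a gradient step that skips every other parameter. This is an illustration under assumptions: parameters are scalars held in a dictionary, plain gradient descent is used (the source does not name an optimizer), and the `classifier.` name prefix is hypothetical.

```python
def finetune_step(params, grads, lr, trainable_prefix="classifier."):
    """One gradient-descent step that changes only the classification-layer parameters.

    params and grads map parameter names to scalar values. Parameters whose
    names do not start with `trainable_prefix` are returned unchanged.
    """
    return {
        name: value - lr * grads[name] if name.startswith(trainable_prefix) else value
        for name, value in params.items()
    }


# Example: the feature extraction weight stays fixed, the classifier weight moves.
params = {"backbone.w": 1.0, "classifier.w": 2.0}
grads = {"backbone.w": 0.5, "classifier.w": 0.5}
updated = finetune_step(params, grads, lr=0.1)
```

Keeping the feature extraction network fixed limits the number of parameters adapted from the few target support images available in a test task.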
In this embodiment, the trained semantic segmentation model is used as a base model to segment the new image types in the target image set. To allow the model to extract more representative class prototypes, the parameters of the semantic segmentation model are further adjusted using the prediction loss of the target support images of the new classes so as to adapt to the unseen classes. Adapting the model parameters on the basis of the trained semantic segmentation model improves the accuracy of semantic segmentation of loan monitoring images when only a small amount of sample data is available.
In one embodiment, the training method of the semantic segmentation model further includes:
the sample image set and the target image set are obtained by photographing the loan object during pre-loan investigation or post-loan monitoring.
The sample image set and the target image set each contain at least one image type, and the image types of the sample image set do not overlap with the image types of the target image set.
Usually, business personnel compare images of credit projects taken at different periods to carry out pre-loan background investigation and post-loan real-time monitoring. However, in loan fields such as road maintenance and farmland planting, background investigation is difficult in remote areas and little of the monitoring data is usable, so only a small amount of sample data can be obtained. Therefore, the investigation pictures taken during pre-loan investigation or post-loan monitoring are divided into a sample image set and a target image set, and a semantic segmentation model is built to segment the loan monitoring images; comparing the segmentation results of different periods assists business personnel in pre-loan and post-loan risk analysis.
Specifically, a semantic segmentation model may be trained on a sample image set containing two image types (e.g., road and farmland). After training on the sample image set, the trained semantic segmentation model is used as a base model to segment the new image types (e.g., lake and forest) in the target image set, and the parameters of the semantic segmentation model are further adjusted using the prediction loss on the new image types.
In this embodiment, to address the situation that only a small amount of sample data can be obtained during pre-loan background investigation and post-loan real-time monitoring in remote areas, the images obtained during pre-loan investigation or post-loan monitoring are divided into a sample image set and a target image set of different image types. A semantic segmentation model is built based on the sample image set and the target image set, which improves the accuracy of semantic segmentation of loan monitoring images and assists business personnel in pre-loan and post-loan risk analysis by comparing the segmentation results of different periods.
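As an illustration of how the images might be organised, the following minimal Python sketch checks that the image types of the sample set and the target set do not overlap and draws one task. It assumes the 2-way 3-shot layout with 5 query images per type used in the application example; the function names, the file names and the fixed random seed are hypothetical and do not come from the application.

```python
import random


def split_is_disjoint(train_types, test_types):
    """True when no image type appears in both the sample set and the target set."""
    return set(train_types).isdisjoint(test_types)


def sample_task(images_by_type, n_way=2, k_shot=3, n_query=5, seed=0):
    """Draw one N-way K-shot task: k_shot support and n_query query images per type."""
    rng = random.Random(seed)
    chosen = rng.sample(sorted(images_by_type), n_way)
    task = {"support": [], "query": []}
    for image_type in chosen:
        picked = rng.sample(images_by_type[image_type], k_shot + n_query)
        task["support"] += [(image_type, name) for name in picked[:k_shot]]
        task["query"] += [(image_type, name) for name in picked[k_shot:]]
    return task


# Hypothetical file names standing in for investigation photographs.
pool = {
    "road": [f"road_{i}.png" for i in range(10)],
    "farmland": [f"farmland_{i}.png" for i in range(10)],
}
task = sample_task(pool)
total_images = len(task["support"]) + len(task["query"])  # 2 * (3 + 5)
```

With two image types, three support images and five query images per type, the task holds 2 × (3 + 5) = 16 images, which matches the count given in the application example.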
To explain the technical solution of the semantic segmentation model training method in detail, the whole process is described below with a specific application example and with reference to Figs. 4 and 5. It includes the following steps:
1. A plurality of training tasks are extracted from the training set [formula image], where each training task consists of a support image set and a query image set. The training phase of the model is explained taking one 2-way 3-shot training task as an example; the other training tasks follow the same training process. The training task has two image types, road and farmland; each type has 3 sample support images and 5 sample query images, so the task has 16 sample images in total. Each sample query image may show both types (road and farmland) or only one of them. All sample images under the training task and their corresponding segmentation masks are input into the semantic segmentation model for training, where the segmentation masks may be foreground masks labeled 0 or 1.
2. For the sample support images under the task and the mask labels corresponding to the sample support images, the class prototype vector corresponding to each image type under the task is determined. Specifically:
a) The sample support image $I_i \in \mathbb{R}^{3\times400\times400}$ and its corresponding mask $Y_i \in \{0,1\}^{400\times400}$ are taken as model input. The deep features of the sample support image $I_i$ are first extracted through the feature extraction network VGG-16 to obtain a feature map $F_i' \in \mathbb{R}^{512\times345\times345}$ containing the deep features of the sample. The feature map $F_i'$ and the sample support image mask $Y_i$ are then input to the prototype fusion module.
b) The prototype fusion module extracts the class prototype vector corresponding to each sample support image, by image type, through an upsampling operation branch and a downsampling operation branch. In the upsampling operation branch, the multi-channel feature map $F_i' \in \mathbb{R}^{512\times345\times345}$ of the sample support image is upsampled channel by channel to the size of the sample support image mask $Y_i \in \{0,1\}^{400\times400}$ to obtain the first-type prototype features of each channel; the first-type prototype features of all channels are averaged to construct the first-type prototype vector $v_i$ corresponding to the current sample support image. In the downsampling operation branch, the sample support image mask $Y_i \in \{0,1\}^{400\times400}$ is downsampled to the size of the multi-channel feature map $F_i' \in \mathbb{R}^{512\times345\times345}$ of the sample support image to obtain the second-type prototype features of each channel; the second-type prototype features of all channels are averaged to construct the second-type prototype vector $v_i'$ corresponding to the current sample support image. The first-type prototype vector $v_i$ and the second-type prototype vector $v_i'$ are averaged to obtain the class prototype vector corresponding to the current sample support image. The class prototype vectors of the sample support images are then averaged again by image type, giving the class prototype vectors $V \in \mathbb{R}^{512\times1}$ corresponding to the two image types, road and farmland, under the training task.
3. According to the class prototype vectors corresponding to the two image types and the sample support images, the mask prediction result corresponding to each sample support image under the task is obtained and a loss function is constructed. Specifically:
a) The cosine distance between the feature map $F_i'$ obtained from the sample support image $I_i$ by the feature extraction module and the class prototype vector $V$ of its image type is calculated, giving a similarity map between the sample support image and the class prototype. The similarity map is applied to the feature map $F_i'$ of the sample support image to obtain the enhanced features of the target region of the sample support image. The enhanced features and the deep features of the sample support image are fused by addition and input to ASPP and two convolutional classification layers, giving the mask prediction result of each sample support image for the two image types, road and farmland, under the training task. A cross-entropy loss function measuring the difference between the predicted mask and the ground-truth mask of the sample support image is constructed as the first loss function.
4. According to the class prototype vectors corresponding to the two image types, the sample support images and the sample query images, the mask prediction result corresponding to each sample query image under the task is obtained and a loss function is constructed. Specifically:
a) The sample query image [formula image] and its corresponding mask [formula image] are taken as model input. The deep features of the sample query image $I_i \in \mathbb{R}^{3\times400\times400}$ are first extracted through the feature extraction network VGG-16 to obtain a feature map [formula image] containing the deep features of the sample. The feature map $F_i'$ containing the deep features of the sample and the foreground feature map $F'_{mask}$ of the sample support images are input to the attention guidance module. Here, the foreground feature map $F'_{mask}$ corresponding to the sample support images is obtained by multiplying, point by point, the feature map $F_i'$ of each sample support image with the mask $Y_i$ of that sample support image as downsampled in the downsampling operation branch.
b) Referring to Fig. 5, the attention guidance module first calculates, based on the attention mechanism, the similarity between the deep features of the query image and the foreground features of the sample support images, compares each sample query image with the class prototypes of road and farmland to perform a forward feature enhancement operation, then finds the maximum feature response at each point on the feature map, and finally outputs the attention feature map [formula image]. Next, the cosine distance between the class prototype vector $V$ and the feature map [formula image] of the query image [formula image] is calculated, giving a similarity map between the sample support image and the sample query image. The similarity map is applied to the attention feature map [formula image], the result [formula image] and $F_i$ are fused by addition, and the fused features are input to ASPP and two convolutional classification layers to obtain the mask prediction result of the current sample query image. A cross-entropy loss function measuring the difference between the predicted mask and the ground-truth mask of the sample query image is constructed as the second loss function.
5. A target loss function is constructed based on the first loss function and the second loss function. Minimizing the target loss function brings the sample mask prediction results as close as possible to the ground-truth sample masks, so that the parameters of the support image classification layer and the query image classification layer are learned; at the same time, the gradient is back-propagated to the feature extraction network and its parameters are updated so that the extracted features become more effective. In the model training stage, the parameter values of the model are updated once for each training task, and the trained semantic segmentation model is obtained based on all the training tasks.
6. Test tasks are extracted from the test set [formula image] ($I_j$ denotes an image, $Y_j$ denotes the segmentation mask corresponding to the image, and $m$ is the number of images in the test set). Each test task includes a target image set of at least one image type, and the image categories of the target image set do not intersect those of the sample image set. Each test task also consists of a support image set and a query image set. The test stage of the model is illustrated taking one 2-way 3-shot test task as an example; the test task contains two image types, lake and forest. The semantic segmentation model has not learned the new classes in the test task. So that the semantic segmentation model extracts a more representative class prototype under the task, the network is further optimized using the prediction loss on the target support images of the new classes to adapt to the unseen classes. Specifically:
a) For each test task, as in the training stage, all target support images under the test task are input to the feature extraction network to obtain their deep features. The deep features are input to the prototype fusion module, which obtains the class prototype vector corresponding to each image type in the target support images from the deep features of the target support images and the foreground mask labels corresponding to the target support images.
b) Based on the target support images and the class prototype vector corresponding to each target image type, the target support images of each image type are segmented by the semantic segmentation model: a similarity map between the deep features of a target support image and the class prototype of its image type is calculated, the similarity map is applied to the deep features of the target support image to perform feature enhancement, and the mask prediction result of the target support image is then obtained through ASPP and two convolutional classification layers.
c) A cross-entropy loss function measuring the difference between the predicted mask and the ground-truth mask of the target support image is constructed as a third loss function. In the test task, the prediction loss of the test stage is used only to update the classification layer of the model; the parameters of the other layers are fixed at the values learned in the training stage. The updated model is used to segment the target query images under the test task, finally yielding the predicted mask of each target query image.
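The test-stage adaptation in step 6, where only the classification layer is updated, can be sketched as follows. This is a simplified illustration under stated assumptions, not the implementation of the application: the classification layer is reduced to a single per-pixel linear classifier with a sigmoid (a stand-in for a 1×1 convolution), the features are assumed to come from an already frozen backbone, and the parameter dictionary, its keys, the learning rate and the step count are all hypothetical.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def bce(prob, mask, eps=1e-7):
    """Mean binary cross-entropy between foreground probabilities and a 0/1 mask."""
    prob = np.clip(prob, eps, 1.0 - eps)
    return float(-(mask * np.log(prob) + (1 - mask) * np.log(1 - prob)).mean())


def adapt_classifier(params, features, mask, lr=0.5, steps=50):
    """Gradient steps on the support loss that change only the classifier weights."""
    w = params["classifier_w"].copy()
    b = float(params["classifier_b"])
    flat = features.reshape(features.shape[0], -1)   # (C, H*W) frozen features
    target = mask.reshape(-1)
    for _ in range(steps):
        prob = sigmoid(w @ flat + b)
        grad = prob - target                         # d(BCE)/d(logit) per pixel
        w -= lr * (flat @ grad) / target.size
        b -= lr * grad.mean()
    # Every other entry (e.g. the backbone) is passed through untouched.
    return {**params, "classifier_w": w, "classifier_b": b}


# Hypothetical toy data: 8-channel features whose first channel decides the mask.
rng = np.random.default_rng(0)
features = rng.normal(size=(8, 16, 16))
mask = (features[0] > 0).astype(float)
params = {
    "backbone": rng.normal(size=(8, 3)),   # placeholder for frozen layers
    "classifier_w": np.zeros(8),
    "classifier_b": 0.0,
}
adapted = adapt_classifier(params, features, mask)
```

Keeping the backbone entry out of the update loop mirrors the statement that the parameters of the other layers stay fixed at their trained values, so a handful of support images only has to fit a small number of weights.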
It should be understood that, although the steps in the flowcharts related to the embodiments described above are displayed sequentially as indicated by the arrows, the steps are not necessarily performed in the order indicated by the arrows. Unless explicitly stated otherwise herein, the steps are not strictly limited to the order shown and may be performed in other orders. Moreover, at least a part of the steps in the flowcharts related to the embodiments described above may include multiple sub-steps or multiple stages, which are not necessarily performed at the same time but may be performed at different times; the execution order of these sub-steps or stages is not necessarily sequential, and they may be performed in turn or alternately with other steps or with at least a part of the sub-steps or stages of other steps.
Based on the inventive concept of the semantic segmentation model training method, as shown in fig. 6, the embodiment of the present application further provides a semantic segmentation model training device for implementing the above-mentioned semantic segmentation model training method. The device comprises:
a task obtaining module 601, configured to obtain training tasks, where each training task includes a sample image set of at least one image type, and the sample image set includes at least one sample support image and at least one sample query image;
a prototype fusion module 602, configured to determine a class prototype vector corresponding to each image type based on the sample support image of each image type and the mask label corresponding to the sample support image;
the mask prediction module 603 is configured to perform semantic segmentation on the sample image of each image type through a semantic segmentation model based on a class prototype vector corresponding to each image type and the sample support image of each image type, to obtain a mask prediction result corresponding to the sample image of each image type, where the sample image is a sample support image or a sample query image;
a loss constructing module 604, configured to construct a first loss function based on a difference between a mask prediction result corresponding to a sample support image of each image type and a corresponding mask tag, and construct a second loss function based on a difference between a mask prediction result corresponding to a sample query image of each image type and a corresponding mask tag;
and the model optimization module 605 is configured to construct a target loss function based on the first loss function and the second loss function, and train the semantic segmentation model based on the target loss function to obtain the trained semantic segmentation model.
The semantic segmentation model training device acquires class prototype vectors corresponding to image types based on sample support images and mask labels corresponding to the sample support images in a training task; performing semantic segmentation on each sample image based on the class prototype vector to obtain a mask prediction result corresponding to the sample support image and a mask prediction result corresponding to the sample query image; constructing a target loss function based on a mask prediction result corresponding to the sample support image and a mask prediction result corresponding to the sample query image; and training the semantic segmentation model based on the target loss function to obtain the trained semantic segmentation model. The whole semantic segmentation model training process only needs a small number of training samples, for each training task, class prototype vectors corresponding to all image types in the training task are obtained through a sample support image and a corresponding mask label, then the sample support image and a sample query image are segmented based on the class prototype vectors, finally a target loss function combining sample support image loss and sample query image loss is constructed, and the accuracy of the semantic segmentation of the loan monitoring image is guaranteed under the condition of only a small number of sample data. 
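How the target loss function could combine the two losses is sketched below in plain Python. The application states only that the target loss is constructed from the first (support) and second (query) loss functions; the weighted sum and the `weight` parameter are assumptions made for illustration, as are the function names. Masks and predictions are flattened to lists of per-pixel values.

```python
import math


def binary_cross_entropy(pred, target, eps=1e-7):
    """Mean binary cross-entropy between foreground probabilities and 0/1 labels."""
    total = 0.0
    for p, y in zip(pred, target):
        p = min(max(p, eps), 1.0 - eps)   # keep log() finite
        total -= y * math.log(p) + (1 - y) * math.log(1.0 - p)
    return total / len(pred)


def target_loss(support_pred, support_mask, query_pred, query_mask, weight=1.0):
    """First (support) loss plus a weighted second (query) loss; weight is assumed."""
    first = binary_cross_entropy(support_pred, support_mask)
    second = binary_cross_entropy(query_pred, query_mask)
    return first + weight * second


# Hypothetical flattened predictions for a 4-pixel support mask and query mask.
good = target_loss([0.9, 0.1, 0.8, 0.2], [1, 0, 1, 0], [0.7, 0.3, 0.9, 0.1], [1, 0, 1, 0])
poor = target_loss([0.4, 0.6, 0.5, 0.5], [1, 0, 1, 0], [0.3, 0.7, 0.5, 0.6], [1, 0, 1, 0])
```

Predictions closer to the mask labels give a smaller combined value, which is the quantity minimized during training.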
In one embodiment, the prototype fusion module 602 is further configured to extract, for each sample support image of the current image type, a multi-channel feature map of each sample support image; acquiring a class prototype vector corresponding to each sample support image based on a multi-channel feature map of each sample support image and a mask label corresponding to each sample support image; and integrating the class prototype vectors corresponding to the support images of all samples to obtain the class prototype vector corresponding to the current image type.
In one embodiment, the prototype fusion module 602 is further configured to, for a current sample support image of a current image type, perform upsampling processing on the multi-channel feature map of the current sample support image, and extract, based on a mask label corresponding to the current sample support image, a first foreground feature map in a first mask region corresponding to the mask label in the upsampling processing results of the multiple channels; acquiring a first type of prototype feature of each channel based on the first foreground feature map of each channel, and constructing a first type of prototype vector corresponding to the current sample support image based on the first type of prototype feature of each channel; performing downsampling processing on a mask label corresponding to the current sample support image, and extracting a second foreground feature map in a second mask area corresponding to a downsampling result from a multi-channel feature map of the current sample support image based on a downsampling result; acquiring a second type of prototype feature of each channel based on the second foreground feature map of each channel, and constructing a second type of prototype vector corresponding to the current sample support image based on the second type of prototype feature of each channel; and integrating the first type prototype vector and the second type prototype vector to obtain a type prototype vector corresponding to the current sample support image.
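The two branches of the prototype fusion module can be illustrated with a small NumPy sketch. It assumes that the prototype feature of each channel is the average of that channel over the foreground mask region (masked average pooling), that resizing uses nearest-neighbour sampling on square maps, and that "integrating" means a plain average; none of these choices is confirmed by the application, and all function names are hypothetical.

```python
import numpy as np


def resize_nearest(x, size):
    """Nearest-neighbour resize of the last two axes to (size, size)."""
    rows = np.arange(size) * x.shape[-2] // size
    cols = np.arange(size) * x.shape[-1] // size
    return x[..., rows[:, None], cols[None, :]]


def masked_average(features, mask):
    """Average each channel of (C, H, W) features over the foreground of an (H, W) mask."""
    area = mask.sum()
    if area == 0:
        return np.zeros(features.shape[0])
    return (features * mask).sum(axis=(1, 2)) / area


def fused_prototype(features, mask):
    """Average the prototype vectors of the upsampling and downsampling branches."""
    up = masked_average(resize_nearest(features, mask.shape[-1]), mask)
    down = masked_average(features, resize_nearest(mask, features.shape[-1]))
    return (up + down) / 2.0


def class_prototype(feature_maps, masks):
    """Average the per-image prototype vectors of one image type."""
    return np.mean([fused_prototype(f, m) for f, m in zip(feature_maps, masks)], axis=0)


# Toy data: 4 channels where channel c is constant c, and a square foreground region.
features = np.stack([np.full((6, 6), float(c)) for c in range(4)])
mask = np.zeros((12, 12))
mask[2:10, 2:10] = 1.0
prototype = class_prototype([features, features], [mask, mask])
```

In this toy case every channel is constant, so both branches return the channel value and the fused prototype is simply `[0, 1, 2, 3]`; with real features the two branches differ because one interpolates features and the other coarsens the mask.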
In one embodiment, the mask prediction module 603 is further configured to extract, for a current sample support image of the current image type, a multi-channel feature map of the current sample support image; obtain a similarity feature map of the current sample support image based on the class prototype vector of the current image type, wherein the similarity feature map is used for representing the degree of similarity between the multi-channel feature map of the current sample support image and the class prototype vector of the current image type; perform feature enhancement processing on the multi-channel feature map of the current sample support image based on the similarity feature map to obtain enhanced features; and perform atrous spatial pyramid pooling (ASPP) on the enhanced features, and perform convolution classification on the processing result to obtain a mask prediction result corresponding to the current sample support image.
In one embodiment, the mask prediction module 603 is further configured to, for a current sample query image of a current image type, extract a multi-channel feature map of the current sample query image, and extract a multi-channel feature map of each sample support image of each image type; perform downsampling processing on the mask label corresponding to each sample support image of each image type, and extract a foreground feature map corresponding to each sample support image of each image type based on the downsampling result corresponding to each sample support image of each image type; perform attention guidance processing based on the multi-channel feature map of the current sample query image and the foreground feature maps corresponding to the sample support images of all image types to obtain an attention guidance processing result; obtain, based on the class prototype vectors of all image types, a similarity feature map of the current sample query image for each class prototype vector, wherein the similarity feature map is used for representing the degree of similarity between the multi-channel feature map of the current sample query image and the corresponding class prototype vector; perform feature enhancement processing on the multi-channel feature map of the current sample query image based on the similarity feature maps for the class prototype vectors to obtain enhanced features; and perform atrous spatial pyramid pooling (ASPP) on the enhanced features, and perform convolution classification on the processing result to obtain a mask prediction result corresponding to the current sample query image.
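The similarity feature map and the feature enhancement for a support image can be sketched in NumPy as below. The sketch uses cosine similarity between the class prototype vector and the feature vector at each pixel, multiplies the features by the resulting map, and adds the original features back; the element-wise multiplication, the `eps` term and the function names are assumptions, and the ASPP and convolution classification stages are left out.

```python
import numpy as np


def cosine_similarity_map(features, prototype, eps=1e-8):
    """Cosine similarity between a (C,) prototype and each pixel of (C, H, W) features."""
    dots = np.tensordot(prototype, features, axes=([0], [0]))       # (H, W)
    norms = np.linalg.norm(features, axis=0) * np.linalg.norm(prototype)
    return dots / (norms + eps)


def enhance(features, prototype):
    """Weight the features by the similarity map, then add the original features back."""
    similarity = cosine_similarity_map(features, prototype)
    return features * similarity[None, :, :] + features


# Toy data: three pixels that are aligned with, orthogonal to, and opposite to the prototype.
prototype = np.array([1.0, 0.0])
features = np.array([[[2.0, 0.0, -1.0]],
                     [[0.0, 3.0, 0.0]]])        # shape (2, 1, 3)
similarity = cosine_similarity_map(features, prototype)
enhanced = enhance(features, prototype)
```

Pixels aligned with the prototype are amplified, orthogonal pixels pass through unchanged, and opposed pixels are suppressed, which is the intended effect of highlighting the target region before classification.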
In one embodiment, the model optimization module 605 is further configured to, after obtaining the trained semantic segmentation model, obtain a target image set to be subjected to semantic segmentation, where a target image type corresponding to the target image set is not involved in a training process for the semantic segmentation model, and the target image set includes at least one target support image and at least one target query image; determining a class prototype vector corresponding to the type of the target image based on the target support image and the mask label corresponding to the target support image; performing semantic segmentation on the target support image through the trained semantic segmentation model based on the target support image and the class prototype vector corresponding to the type of the target image to obtain a mask prediction result corresponding to the target support image; constructing a third loss function based on the difference between the mask prediction result corresponding to the target support image and the corresponding mask label, and retraining the trained semantic segmentation model based on the third loss function; and performing semantic segmentation on the target query image based on the retrained semantic segmentation model.
In one embodiment, the task obtaining module 601 is further configured to obtain a sample image set and a target image set, where the sample image set and the target image set are obtained by performing a pre-loan investigation or a post-loan monitoring on a loan object.
The modules in the semantic segmentation model training device may be implemented wholly or partially by software, hardware, or a combination thereof. The modules may be embedded in, or independent of, a processor in the computer device in hardware form, or may be stored in a memory in the computer device in software form, so that the processor can call and execute the operations corresponding to the modules.
In one embodiment, a computer device is provided, which may be a server, and the internal structure thereof may be as shown in fig. 7. The computer device comprises a processor, a memory, an Input/Output (I/O) interface and a communication interface. The processor, the memory and the input/output interface are connected through a system bus, and the communication interface is connected to the system bus through the input/output interface. Wherein the processor of the computer device is configured to provide computing and control capabilities. The memory of the computer device includes a non-volatile storage medium and an internal memory. The non-volatile storage medium stores an operating system, a computer program, and a database. The internal memory provides an environment for the operation of an operating system and computer programs in the non-volatile storage medium. The database of the computer device is used for storing semantic segmentation model training data. The input/output interface of the computer device is used for exchanging information between the processor and an external device. The communication interface of the computer device is used for connecting and communicating with an external terminal through a network. The computer program is executed by a processor to implement a semantic segmentation model training method.
Those skilled in the art will appreciate that the architecture shown in Fig. 7 is merely a block diagram of part of the structure associated with the disclosed solution and does not limit the computer devices to which the disclosed solution applies; a particular computer device may include more or fewer components than those shown, may combine certain components, or may have a different arrangement of components.
In one embodiment, a computer device is provided, comprising a memory and a processor, the memory having stored therein a computer program, the processor implementing the steps of the above-described method embodiments when executing the computer program.
In an embodiment, a computer-readable storage medium is provided, on which a computer program is stored which, when being executed by a processor, carries out the steps of the above-mentioned method embodiments.
In an embodiment, a computer program product is also provided, comprising a computer program which, when being executed by a processor, carries out the steps of the above-mentioned method embodiments.
It should be noted that, the user information (including but not limited to user equipment information, user personal information, etc.) and data (including but not limited to data for analysis, stored data, displayed data, etc.) referred to in the present application are information and data authorized by the user or sufficiently authorized by each party, and the collection, use and processing of the related data need to comply with the relevant laws and regulations and standards of the relevant country and region.
It will be understood by those skilled in the art that all or part of the processes of the methods of the embodiments described above can be implemented by a computer program instructing relevant hardware; the computer program can be stored in a non-volatile computer-readable storage medium and, when executed, can include the processes of the method embodiments described above. Any reference to memory, databases, or other media used in the embodiments provided herein can include at least one of non-volatile and volatile memory. The non-volatile memory may include Read-Only Memory (ROM), magnetic tape, floppy disk, flash memory, optical memory, high-density embedded non-volatile memory, Resistive Random Access Memory (ReRAM), Magnetic Random Access Memory (MRAM), Ferroelectric Random Access Memory (FRAM), Phase Change Memory (PCM), graphene memory, and the like. Volatile memory can include Random Access Memory (RAM), external cache memory, and the like. By way of illustration and not limitation, RAM can take many forms, such as Static Random Access Memory (SRAM) or Dynamic Random Access Memory (DRAM). The databases involved in the embodiments provided herein may include at least one of relational and non-relational databases. The non-relational database may include, but is not limited to, a blockchain-based distributed database, and the like. The processors referred to in the embodiments provided herein may be general-purpose processors, central processing units, graphics processors, digital signal processors, programmable logic devices, quantum-computing-based data processing logic devices, and the like, without limitation.
The technical features of the above embodiments can be combined arbitrarily. For brevity, not all possible combinations of the technical features in the above embodiments are described; however, as long as a combination of these technical features involves no contradiction, it should be considered within the scope of this specification.
The above embodiments express only several implementations of the present application, and their description is relatively specific and detailed, but they should not be construed as limiting the scope of the present application. It should be noted that a person skilled in the art can make several variations and improvements without departing from the concept of the present application, all of which fall within the scope of protection of the present application. Therefore, the protection scope of the present application shall be subject to the appended claims.

Claims (11)

1. A semantic segmentation model training method, the method comprising:
acquiring training tasks, wherein the training tasks all comprise a sample image set of at least one image type, and the sample image set comprises at least one sample support image and at least one sample query image;
determining a class prototype vector corresponding to each image type based on the sample support image of each image type and the mask label corresponding to the sample support image;
performing semantic segmentation on the sample image of each image type through a semantic segmentation model based on a class prototype vector corresponding to each image type and the sample support image of each image type to obtain a mask prediction result corresponding to the sample image of each image type, wherein the sample image is a sample support image or a sample query image;
constructing a first loss function based on the difference between the mask prediction result corresponding to the sample support image of each image type and the corresponding mask label, and constructing a second loss function based on the difference between the mask prediction result corresponding to the sample query image of each image type and the corresponding mask label;
and constructing a target loss function based on the first loss function and the second loss function, and training the semantic segmentation model based on the target loss function to obtain the trained semantic segmentation model.
2. The method of claim 1, wherein determining the class prototype vector corresponding to each image type based on the sample support image and the mask label corresponding to the sample support image for each image type comprises:
extracting a multi-channel feature map of each sample support image aiming at each sample support image of the current image type;
acquiring a class prototype vector corresponding to each sample support image based on a multi-channel feature map of each sample support image and a mask label corresponding to each sample support image;
and integrating the class prototype vectors corresponding to the sample support images to obtain the class prototype vector corresponding to the current image type.
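The steps of claim 2 can be sketched as follows. The sketch assumes that the feature map and the mask label already share the same spatial resolution, that each per-image prototype is a masked average of the features, and that "integrating" means taking the arithmetic mean; none of these choices is stated in the claim, and the function names are hypothetical.

```python
import numpy as np

def masked_average_pooling(features, mask):
    """Average the features (C, H, W) over the foreground pixels of a 0/1 mask (H, W)."""
    weights = mask.astype(features.dtype)
    area = weights.sum()
    if area == 0:
        # No foreground pixels: return a zero prototype instead of dividing by zero.
        return np.zeros(features.shape[0], dtype=features.dtype)
    return (features * weights).sum(axis=(1, 2)) / area

def class_prototype(support_features, support_masks):
    """Integrate per-image prototypes into one class prototype (assumed: plain mean)."""
    per_image = [masked_average_pooling(f, m) for f, m in zip(support_features, support_masks)]
    return np.mean(per_image, axis=0)

# Hypothetical example: two support images with 3 channels and a 4x4 feature map.
mask = np.zeros((4, 4))
mask[1:3, 1:3] = 1
prototype = class_prototype([np.ones((3, 4, 4)), 3 * np.ones((3, 4, 4))], [mask, mask])
```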
3. The method according to claim 2, wherein acquiring the class prototype vector corresponding to each sample support image based on the multi-channel feature map of each sample support image and the mask label corresponding to each sample support image comprises:
for a current sample support image of the current image type, up-sampling the multi-channel feature map of the current sample support image, and, based on the mask label corresponding to the current sample support image, extracting from the up-sampling result of each channel a first foreground feature map within a first mask region corresponding to the mask label;
acquiring a first class prototype feature of each channel based on the first foreground feature map of each channel, and constructing a first class prototype vector corresponding to the current sample support image based on the first class prototype feature of each channel;
down-sampling the mask label corresponding to the current sample support image, and, based on the down-sampling result, extracting from the multi-channel feature map of the current sample support image a second foreground feature map within a second mask region corresponding to the down-sampling result;
acquiring a second class prototype feature of each channel based on the second foreground feature map of each channel, and constructing a second class prototype vector corresponding to the current sample support image based on the second class prototype feature of each channel;
and integrating the first class prototype vector and the second class prototype vector to obtain the class prototype vector corresponding to the current sample support image.
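The two prototype branches of claim 3 can be sketched as below. The sketch assumes nearest-neighbour up-sampling of the features, strided down-sampling of the mask label, a per-channel foreground mean as the prototype feature, and an equal-weight average as the integration step. These are illustrative choices, and `two_branch_prototype` and `scale` are hypothetical names.

```python
import numpy as np

def _masked_mean(features, mask):
    """Per-channel mean of features (C, H, W) over the foreground of a 0/1 mask (H, W)."""
    area = mask.sum()
    if area == 0:
        return np.zeros(features.shape[0])
    return (features * mask).sum(axis=(1, 2)) / area

def two_branch_prototype(features, mask, scale):
    """features: (C, h, w); mask: (h * scale, w * scale) with 0/1 entries."""
    # Branch 1: bring the features up to mask resolution (nearest neighbour),
    # then pool inside the full-resolution mask region.
    upsampled = features.repeat(scale, axis=1).repeat(scale, axis=2)
    first = _masked_mean(upsampled, mask)
    # Branch 2: bring the mask down to feature resolution (keep every scale-th pixel),
    # then pool the original features inside the reduced mask region.
    small_mask = mask[::scale, ::scale]
    second = _masked_mean(features, small_mask)
    # Integration step (assumed: equal-weight average of the two prototypes).
    return 0.5 * (first + second)

# Hypothetical example: 2 channels, a 2x2 feature map, a 4x4 mask label.
features = np.array([[[1.0, 2.0], [3.0, 4.0]],
                     [[10.0, 20.0], [30.0, 40.0]]])
mask = np.zeros((4, 4))
mask[:2, :2] = 1
prototype = two_branch_prototype(features, mask, scale=2)
```

The first branch keeps the fine detail of the mask label, while the second keeps the features untouched; averaging the two is one simple way to balance them.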
4. The method of claim 1, wherein the sample image is a sample support image, and wherein performing semantic segmentation on the sample image of each image type through the semantic segmentation model based on the class prototype vector corresponding to each image type and the sample support image of each image type to obtain the mask prediction result corresponding to the sample image of each image type comprises:
extracting a multi-channel feature map of a current sample support image of a current image type;
obtaining a similarity feature map of the current sample support image based on the class prototype vector of the current image type, wherein the similarity feature map represents the degree of similarity between the multi-channel feature map of the current sample support image and the class prototype vector of the current image type;
performing feature enhancement processing on the multi-channel feature map of the current sample support image based on the similarity feature map to obtain enhanced features;
and performing atrous (dilated) spatial convolution pooling on the enhanced features, and performing convolutional classification on the processing result to obtain a mask prediction result corresponding to the current sample support image.
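The similarity and enhancement steps of claim 4 can be sketched as follows. The claim does not name a similarity measure or an enhancement operation; cosine similarity and channel-wise concatenation are assumptions made here, and the function names are hypothetical. The subsequent pooling and classification stages are omitted.

```python
import numpy as np

def similarity_map(features, prototype, eps=1e-8):
    """Cosine similarity between every pixel feature of (C, H, W) and a prototype (C,)."""
    dot = np.tensordot(prototype, features, axes=([0], [0]))           # (H, W)
    norms = np.linalg.norm(features, axis=0) * np.linalg.norm(prototype)
    return dot / (norms + eps)

def enhance(features, sim):
    """Append the similarity map as an extra channel, giving (C + 1, H, W)."""
    return np.concatenate([features, sim[None]], axis=0)

# Hypothetical example: two pixels, one aligned with the prototype and one orthogonal to it.
features = np.array([[[1.0, 0.0]],
                     [[0.0, 1.0]]])       # shape (2, 1, 2)
prototype = np.array([1.0, 0.0])
sim = similarity_map(features, prototype)
enhanced = enhance(features, sim)
```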
5. The method of claim 1, wherein the sample image is a sample query image, and wherein performing semantic segmentation on the sample image of each image type through the semantic segmentation model based on the class prototype vector corresponding to each image type and the sample support image of each image type to obtain the mask prediction result corresponding to the sample image of each image type comprises:
for a current sample query image of a current image type, extracting a multi-channel feature map of the current sample query image, and extracting a multi-channel feature map of each sample support image of each image type;
down-sampling the mask label corresponding to each sample support image of each image type, and extracting a foreground feature map corresponding to each sample support image of each image type based on the down-sampling result corresponding to that sample support image;
performing attention guidance processing based on the multi-channel feature map of the current sample query image and the foreground feature map corresponding to each sample support image of each image type to obtain an attention guidance processing result;
obtaining, based on the class prototype vector of each image type, a similarity feature map of the current sample query image with respect to each class prototype vector, wherein each similarity feature map represents the degree of similarity between the multi-channel feature map of the current sample query image and the corresponding class prototype vector;
performing feature enhancement processing on the multi-channel feature map of the current sample query image based on the similarity feature maps for the respective class prototype vectors to obtain enhanced features;
and performing atrous (dilated) spatial convolution pooling on the enhanced features, and performing convolutional classification on the processing result to obtain a mask prediction result corresponding to the current sample query image.
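The attention guidance step of claim 5 is not specified in detail. One plausible reading, sketched below purely as an assumption, is scaled dot-product attention in which every query pixel attends over the foreground pixels of a support image. The function name `attention_guide` and the assumption that the mask has already been down-sampled to feature resolution are hypothetical.

```python
import numpy as np

def attention_guide(query_features, support_features, support_mask):
    """query_features: (C, H, W); support_features: (C, h, w); support_mask: (h, w), 0/1.

    Returns a (C, H, W) map in which each query pixel is a softmax-weighted
    mixture of the support foreground features.
    """
    c, h, w = query_features.shape
    q = query_features.reshape(c, -1).T                  # (H*W, C)
    fg = support_features[:, support_mask > 0].T         # (N, C) foreground pixels only
    if fg.shape[0] == 0:
        return np.zeros_like(query_features)
    scores = q @ fg.T / np.sqrt(c)                       # (H*W, N)
    scores = scores - scores.max(axis=1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=1, keepdims=True)
    attended = weights @ fg                              # (H*W, C)
    return attended.T.reshape(c, h, w)

# Hypothetical example with random features and a single foreground pixel.
rng = np.random.default_rng(0)
query = rng.normal(size=(3, 2, 2))
support = rng.normal(size=(3, 2, 2))
mask = np.zeros((2, 2))
mask[0, 1] = 1
guided = attention_guide(query, support, mask)
```

With several support images, the foreground pixels of all of them could be stacked into `fg` before the attention weights are computed.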
6. The method of claim 1, further comprising:
after the trained semantic segmentation model is obtained, acquiring a target image set to be semantically segmented, wherein the target image type corresponding to the target image set is a type that did not participate in the training of the semantic segmentation model, and the target image set comprises at least one target support image and at least one target query image;
determining a class prototype vector corresponding to the target image type based on the target support image and a mask label corresponding to the target support image;
performing semantic segmentation on the target support image through the trained semantic segmentation model based on the target support image and the class prototype vector corresponding to the target image type to obtain a mask prediction result corresponding to the target support image;
constructing a third loss function based on the difference between the mask prediction result corresponding to the target support image and the corresponding mask label, and retraining the trained semantic segmentation model based on the third loss function;
and performing semantic segmentation on the target query image based on the retrained semantic segmentation model.
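The retraining flow of claim 6 can be illustrated with a deliberately small stand-in model. In the sketch below the "model" is only a per-pixel logistic classification head, the third loss is assumed to be binary cross-entropy, and retraining is plain gradient descent on the target support images. The real segmentation network, its optimizer, and the names `fine_tune_head` and `predict_mask` are not taken from the source.

```python
import numpy as np

def _sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def fine_tune_head(weights, bias, support_features, support_masks, lr=0.5, steps=200):
    """Retrain a per-pixel logistic head on target support images.

    weights: (C,); each feature map: (C, H, W); each mask label: (H, W) with 0/1 entries.
    """
    w, b = np.asarray(weights, dtype=float).copy(), float(bias)
    x = np.concatenate([f.reshape(f.shape[0], -1) for f in support_features], axis=1)  # (C, N)
    y = np.concatenate([m.reshape(-1) for m in support_masks]).astype(float)           # (N,)
    for _ in range(steps):
        grad = _sigmoid(w @ x + b) - y       # gradient of the cross-entropy w.r.t. the logits
        w -= lr * (x @ grad) / y.size
        b -= lr * grad.mean()
    return w, b

def predict_mask(weights, bias, features, threshold=0.5):
    """Segment a target query image with the retrained head."""
    logits = np.tensordot(weights, features, axes=([0], [0])) + bias
    return (_sigmoid(logits) >= threshold).astype(np.uint8)

# Hypothetical unseen image type: channel 0 is +1 on the foreground and -1 elsewhere.
support_mask = np.zeros((4, 4))
support_mask[1:3, 1:3] = 1
support_feat = np.stack([2 * support_mask - 1, np.zeros((4, 4))])
w, b = fine_tune_head(np.zeros(2), 0.0, [support_feat], [support_mask])

query_mask = np.zeros((4, 4))
query_mask[0:2, 2:4] = 1
query_feat = np.stack([2 * query_mask - 1, np.zeros((4, 4))])
prediction = predict_mask(w, b, query_feat)
```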
7. The method of claim 6, wherein the sample image set and the target image set are obtained from pre-loan surveys or post-loan surveys of loan objects.
8. An apparatus for training a semantic segmentation model, the apparatus comprising:
the task acquisition module is used for acquiring training tasks, wherein each training task comprises a sample image set of at least one image type, and the sample image set comprises at least one sample support image and at least one sample query image;
the prototype fusion module is used for determining a class prototype vector corresponding to each image type based on the sample support image of each image type and the mask label corresponding to the sample support image;
the mask prediction module is used for carrying out semantic segmentation on the sample image of each image type through a semantic segmentation model based on a class prototype vector corresponding to each image type and the sample support image of each image type to obtain a mask prediction result corresponding to the sample image of each image type, wherein the sample image is a sample support image or a sample query image;
the loss construction module is used for constructing a first loss function based on the difference between the mask prediction result corresponding to the sample support image of each image type and the corresponding mask label, and constructing a second loss function based on the difference between the mask prediction result corresponding to the sample query image of each image type and the corresponding mask label;
and the model optimization module is used for constructing a target loss function based on the first loss function and the second loss function, training the semantic segmentation model based on the target loss function and obtaining the trained semantic segmentation model.
9. A computer device comprising a memory and a processor, the memory storing a computer program, characterized in that the processor, when executing the computer program, implements the steps of the method of any of claims 1 to 7.
10. A computer-readable storage medium, on which a computer program is stored which, when being executed by a processor, carries out the steps of the method according to any one of claims 1 to 7.
11. A computer program product comprising a computer program, characterized in that the computer program realizes the steps of the method of any one of claims 1 to 7 when executed by a processor.
CN202211591464.5A 2022-12-12 2022-12-12 Semantic segmentation model training method and device, computer equipment and storage medium Pending CN115861617A (en)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202211591464.5A CN115861617A (en) 2022-12-12 2022-12-12 Semantic segmentation model training method and device, computer equipment and storage medium

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202211591464.5A CN115861617A (en) 2022-12-12 2022-12-12 Semantic segmentation model training method and device, computer equipment and storage medium

Publications (1)

Publication Number Publication Date
CN115861617A true CN115861617A (en) 2023-03-28

Family

ID=85672201

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202211591464.5A Pending CN115861617A (en) 2022-12-12 2022-12-12 Semantic segmentation model training method and device, computer equipment and storage medium

Country Status (1)

Country Link
CN (1) CN115861617A (en)

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN116227573A (en) * 2023-04-25 2023-06-06 智慧眼科技股份有限公司 Segmentation model training method, image segmentation device and related media
CN116912638A (en) * 2023-09-13 2023-10-20 深圳金三立视频科技股份有限公司 Multi-data-set combined training method and terminal

Cited By (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN116227573A (en) * 2023-04-25 2023-06-06 智慧眼科技股份有限公司 Segmentation model training method, image segmentation device and related media
CN116227573B (en) * 2023-04-25 2023-08-08 智慧眼科技股份有限公司 Segmentation model training method, image segmentation device and related media
CN116912638A (en) * 2023-09-13 2023-10-20 深圳金三立视频科技股份有限公司 Multi-data-set combined training method and terminal
CN116912638B (en) * 2023-09-13 2024-01-12 深圳金三立视频科技股份有限公司 Multi-data-set combined training method and terminal

Similar Documents

Publication Publication Date Title
CN115861617A (en) Semantic segmentation model training method and device, computer equipment and storage medium
Zeng et al. Single image super-resolution using a polymorphic parallel CNN
CN110838125B (en) Target detection method, device, equipment and storage medium for medical image
CN112330684B (en) Object segmentation method and device, computer equipment and storage medium
CN112288011A (en) Image matching method based on self-attention deep neural network
CN112287954A (en) Image classification method, training method of image classification model and device thereof
Shu et al. LVC-Net: Medical image segmentation with noisy label based on local visual cues
Wu et al. A deep residual convolutional neural network for facial keypoint detection with missing labels
CN109754357B (en) Image processing method, processing device and processing equipment
CN115797781A (en) Crop identification method and device, computer equipment and storage medium
CN115861248A (en) Medical image segmentation method, medical model training method, medical image segmentation device and storage medium
Ye et al. An improved efficientNetV2 model based on visual attention mechanism: application to identification of cassava disease
Peretroukhin et al. Inferring sun direction to improve visual odometry: A deep learning approach
Kate et al. A 3 Tier CNN model with deep discriminative feature extraction for discovering malignant growth in multi-scale histopathology images
US20180114109A1 (en) Deep convolutional neural networks with squashed filters
Selvakumar et al. Automated mango leaf infection classification using weighted and deep features with optimized recurrent neural network concept
CN116522143B (en) Model training method, clustering method, equipment and medium
Anjun et al. SRAD-CNN for adaptive synthetic aperture radar image classification
CN116310308A (en) Image segmentation method, device, computer equipment and storage medium
CN116128895A (en) Medical image segmentation method, apparatus and computer readable storage medium
CN114549174A (en) User behavior prediction method and device, computer equipment and storage medium
CN114529399A (en) User data processing method, device, computer equipment and storage medium
CN114219184A (en) Product transaction data prediction method, device, equipment, medium and program product
Tang et al. Triple-branch ternary-attention mechanism network with deformable 3D convolution for hyperspectral image classification
CN117314756B (en) Verification and protection method and device based on remote sensing image, computer equipment and storage medium

Legal Events

Date Code Title Description
PB01 Publication
SE01 Entry into force of request for substantive examination