WO2020224403A1 - Training method, apparatus, device and storage medium for a classification task model - Google Patents

Training method, apparatus, device and storage medium for a classification task model

Info

Publication number
WO2020224403A1
WO2020224403A1 (PCT/CN2020/085006; CN2020085006W)
Authority
WO
WIPO (PCT)
Prior art keywords
feature
classification task
samples
task model
training
Prior art date
Application number
PCT/CN2020/085006
Other languages
English (en)
French (fr)
Inventor
沈荣波
周可
田宽
颜克洲
江铖
Original Assignee
Tencent Technology (Shenzhen) Company Limited
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Tencent Technology (Shenzhen) Company Limited
Priority to EP20802264.0A priority Critical patent/EP3968222B1/en
Publication of WO2020224403A1 publication Critical patent/WO2020224403A1/zh
Priority to US17/355,310 priority patent/US20210319258A1/en

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/82Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/22Matching criteria, e.g. proximity measures
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • G06F18/243Classification techniques relating to the number of classes
    • G06F18/2431Multiple classes
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/047Probabilistic or stochastic networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/74Image or video pattern matching; Proximity measures in feature spaces
    • G06V10/761Proximity, similarity or dissimilarity measures
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/764Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/77Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
    • G06V10/7715Feature extraction, e.g. by transforming the feature space, e.g. multi-dimensional scaling [MDS]; Mappings, e.g. subspace methods
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/77Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
    • G06V10/774Generating sets of training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G16INFORMATION AND COMMUNICATION TECHNOLOGY [ICT] SPECIALLY ADAPTED FOR SPECIFIC APPLICATION FIELDS
    • G16HHEALTHCARE INFORMATICS, i.e. INFORMATION AND COMMUNICATION TECHNOLOGY [ICT] SPECIALLY ADAPTED FOR THE HANDLING OR PROCESSING OF MEDICAL OR HEALTHCARE DATA
    • G16H30/00ICT specially adapted for the handling or processing of medical images
    • G16H30/40ICT specially adapted for the handling or processing of medical images for processing medical images, e.g. editing
    • GPHYSICS
    • G16INFORMATION AND COMMUNICATION TECHNOLOGY [ICT] SPECIALLY ADAPTED FOR SPECIFIC APPLICATION FIELDS
    • G16HHEALTHCARE INFORMATICS, i.e. INFORMATION AND COMMUNICATION TECHNOLOGY [ICT] SPECIALLY ADAPTED FOR THE HANDLING OR PROCESSING OF MEDICAL OR HEALTHCARE DATA
    • G16H50/00ICT specially adapted for medical diagnosis, medical simulation or medical data mining; ICT specially adapted for detecting, monitoring or modelling epidemics or pandemics
    • G16H50/20ICT specially adapted for medical diagnosis, medical simulation or medical data mining; ICT specially adapted for detecting, monitoring or modelling epidemics or pandemics for computer-aided diagnosis, e.g. based on medical expert systems

Definitions

  • the embodiments of the present application relate to the field of machine learning technology, and particularly relate to the training of a classification task model.
  • Machine learning has good performance in processing classification tasks.
  • a classification task model is constructed based on deep neural networks, and the model is trained through appropriate training samples.
  • the trained classification task model can be used to process classification tasks, such as image recognition and speech recognition.
  • the categories of the training samples contained in the training data set may be unbalanced. For example, the number of positive samples is far less than the number of negative samples.
  • such a training data set can be called a class-imbalanced data set. If an imbalanced data set is used to train the classification task model, the performance of the finally obtained classification task model will be poor.
  • the embodiments of the present application provide a training method, device, equipment, and storage medium for a classification task model based on artificial intelligence, which can be used to solve the technical problem that the sample upsampling method provided by the related art cannot train a high-precision classification task model.
  • the technical solution is as follows:
  • an embodiment of the present application provides a method for training a classification task model, the method is executed by a computer device, and the method includes:
  • using the first data set to train an initial feature extractor to obtain a feature extractor; where the first data set is an imbalanced data set including samples of a first category and samples of a second category, the number of first-category samples is greater than the number of second-category samples, and the first data set is derived from medical images;
  • constructing a generative adversarial network including the feature extractor and an initial feature generator; where the initial feature generator is used to generate feature vectors of the same dimension as those output by the feature extractor;
  • training the generative adversarial network using the second-category samples to obtain a feature generator;
  • constructing a classification task model including the feature generator and the feature extractor; and
  • using the first data set to train the classification task model; where the feature generator is used to amplify the second-category samples in the feature space during training, and the trained classification task model is used to classify medical images.
  • an embodiment of the present application provides a training device for a classification task model, and the device includes:
  • the first training module is used to train an initial feature extractor using a first data set to obtain a feature extractor; where the first data set is an imbalanced data set including samples of a first category and samples of a second category, the number of first-category samples is greater than the number of second-category samples, and the first data set is derived from medical images;
  • the first construction module is used to construct a generative adversarial network, the generative adversarial network including the feature extractor and an initial feature generator; where the initial feature generator is used to generate feature vectors of the same dimension as those output by the feature extractor;
  • the second training module is configured to train the generative adversarial network using the second-category samples to obtain a feature generator;
  • the second construction module is used to construct a classification task model, the classification task model including the feature generator and the feature extractor;
  • the third training module is used to train the classification task model using the first data set; where the feature generator is used to amplify the second-category samples in the feature space during training, and the trained classification task model is used to classify medical images.
  • an embodiment of the present application provides a computer device, including:
  • a processor, a communication interface, and a memory, which communicate with one another through a communication bus;
  • the communication interface is an interface of a communication module;
  • the memory is used for storing program code and transmitting the program code to the processor; the processor is used to invoke the instructions of the program code in the memory to execute the training method of the classification task model in the above aspect.
  • an embodiment of the present application provides a storage medium, where the storage medium is used to store a computer program, and the computer program is used to execute the training method of the classification task model in the above aspect.
  • an embodiment of the present application provides a computer program product including instructions, which when run on a computer, causes the computer to execute the training method of the classification task model in the above aspect.
  • a feature generator is obtained by training a generative adversarial network, and the minority-class samples (that is, the category of training samples that is under-represented in the imbalanced data set) are amplified in the feature space by the feature generator. Amplifying at the feature level, instead of simply copying minority-class samples through sample upsampling, helps the finally trained classification task model avoid overfitting and improves its accuracy.
  • Fig. 1 is a flowchart of a training method for a classification task model provided by an embodiment of the present application
  • Figure 2 exemplarily shows the structure diagram of the initial classification task model
  • Fig. 3 exemplarily shows a schematic diagram of the structure of the generative adversarial network
  • Figure 4 exemplarily shows a schematic structural diagram of a classification task model
  • Figure 5 exemplarily shows the overall architecture diagram of the technical solution of the present application
  • Figures 6 and 7 exemplarily show schematic diagrams of two sets of experimental results
  • Fig. 8 is a block diagram of a training device for a classification task model provided by an embodiment of the present application.
  • Fig. 9 is a schematic structural diagram of a computer device provided by an embodiment of the present application.
  • ML: machine learning
  • AI: artificial intelligence
  • DL: deep learning
  • ANN: artificial neural network
  • AI is a theory, method, technology and application system that uses digital computers or machines controlled by digital computers to simulate, extend and expand human intelligence, perceive the environment, acquire knowledge, and use knowledge to obtain the best results.
  • AI is a comprehensive technology of computer science, which attempts to understand the essence of intelligence and produce a new kind of intelligent machine that can react in a similar way to human intelligence.
  • AI is to study the design principles and implementation methods of various intelligent machines, so that the machines have the functions of perception, reasoning and decision-making.
  • AI technology is a comprehensive discipline, covering a wide range of fields, including both hardware-level technology and software-level technology.
  • AI basic technologies generally include technologies such as sensors, dedicated artificial intelligence chips, cloud computing, distributed storage, big data processing technologies, operation/interaction systems, and mechatronics.
  • AI software technology mainly includes computer vision technology, speech processing technology, natural language processing technology, and machine learning/deep learning.
  • the computer device that executes the training method of the classification task model may have machine learning capabilities, so as to train the classification task model through the machine learning capabilities.
  • Machine learning is an interdisciplinary field involving probability theory, statistics, approximation theory, convex analysis, algorithm complexity theory, and other disciplines. It specializes in studying how computers simulate or realize human learning behaviors in order to acquire new knowledge or skills and reorganize existing knowledge structures to continuously improve their own performance.
  • Machine learning is the core of artificial intelligence, the fundamental way to make computers intelligent, and its applications cover all fields of artificial intelligence.
  • Machine learning and deep learning usually include technologies such as artificial neural networks.
  • the classification task model involved in the embodiments of the present application refers to a machine learning model obtained through machine learning training and used to process classification tasks.
  • the classification task model may be a deep learning classification task model, that is, a classification task model based on a deep neural network, such as a classification task model based on a deep convolutional neural network.
  • the classification task model can be used to process lesion recognition and classification tasks in medical images, and can also be used to process classification tasks such as image recognition and speech recognition.
  • the specific application scenarios of the classification task model are not limited in the embodiment of the present application.
  • the execution subject of each step may be a computer device, and the computer device refers to an electronic device with data calculation, processing, and storage capabilities, such as a PC (Personal Computer) or a server.
  • FIG. 1 shows a flowchart of a training method for a classification task model provided by an embodiment of the present application.
  • the method can include the following steps (101 to 105):
  • Step 101 Use the first data set to train the initial feature extractor to obtain the feature extractor.
  • the first data set is an imbalanced data set including samples of the first category and samples of the second category, and the number of first-category samples is greater than the number of second-category samples.
  • the first category sample and the second category sample are samples of two different categories in the first data set.
  • the samples of the first type are positive samples and the samples of the second type are negative samples; or, the samples of the first type are negative samples and the samples of the second type are positive samples.
  • the number of samples of the first category is greater than the number of samples of the second category, that is, the samples of the first category can be called samples of the majority category, and the samples of the second category can be called samples of the minority category. In most scenarios, the number of negative samples is greater than or even far greater than the number of positive samples. Therefore, samples of the first category can be negative samples, and samples of the second category are correspondingly positive samples.
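As a simple illustration (not part of the patent), the degree of class imbalance in such a data set can be quantified as the ratio of majority-class to minority-class sample counts; the label values below are hypothetical:

```python
from collections import Counter

def imbalance_ratio(labels):
    """Ratio of the majority-class count to the minority-class count."""
    counts = Counter(labels)
    return max(counts.values()) / min(counts.values())

# Hypothetical labels: 1 = negative (majority), 0 = positive (minority).
labels = [1] * 95 + [0] * 5
print(imbalance_ratio(labels))  # 19.0
```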
  • the feature extractor is the part used to extract sample features in the classification task model, and the feature extractor is also called an encoder.
  • the classification task model includes a feature extractor and a classifier. The output end of the feature extractor is connected to the input end of the classifier. The feature extractor extracts a feature vector from an input sample of the model, and the classifier determines, according to the feature vector, the category to which the input sample belongs. Taking a classification task model for image recognition as an example, the feature extractor maps and encodes the input image and outputs a feature vector whose dimension is much lower than that of the input image's pixels.
  • the feature extractor thus implements a non-linear, local-to-global feature mapping that combines low-level visual features with high-level semantic information.
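To make the dimension-reduction role of the feature extractor concrete, here is a minimal sketch (an illustrative assumption, not the patent's convolutional extractor) in which a single linear map plus ReLU stands in for the encoder:

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical linear "encoder": maps a flattened 64x64 image (4096 pixels)
# to a 20-dimensional feature vector. A real extractor would stack
# convolutional layers; this only illustrates the dimensionality reduction.
W = rng.normal(size=(20, 64 * 64)) * 0.01

def extract_features(image):
    x = image.reshape(-1)           # flatten the pixels
    return np.maximum(W @ x, 0.0)   # linear map + ReLU non-linearity

image = rng.random((64, 64))
features = extract_features(image)
print(features.shape)  # (20,)
```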
  • the classification task model is constructed based on a deep convolutional neural network, and the feature extractor may include multiple convolutional layers.
  • the classification task model is the Inception-v3 model
  • the Inception-v3 model is a deep neural network model, which has good performance on the image classification task.
  • another advantage of the Inception-v3 model is that a pre-trained Inception-v3 model can be used as the initialized classification task model, without having to randomly initialize the parameters of the classification task model, which helps improve the model's training efficiency.
  • the classifier may adopt a normalized exponential function (Softmax) classifier or other classifiers, which is not limited in the embodiment of the present application.
  • step 101 includes the following sub-steps:
  • the initial classification task model can be a pre-trained Inception-v3 model.
  • the first data set includes samples of the first category and samples of the second category, and each training sample is set with a corresponding label according to its category.
  • the label of the first category sample is 1, and the label of the second category sample is 0; or, the label of the first category sample is 0, and the label of the second category sample is 1.
  • the loss function may adopt a cross entropy (Cross Entropy, CE) loss function.
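For reference, the cross-entropy loss for a single binary-labelled sample can be sketched as follows (a standard formula, not code from the patent):

```python
import math

def binary_cross_entropy(y_true, p_pred, eps=1e-12):
    """CE loss for one sample: -[y*log(p) + (1-y)*log(1-p)]."""
    p = min(max(p_pred, eps), 1.0 - eps)  # clamp to avoid log(0)
    return -(y_true * math.log(p) + (1 - y_true) * math.log(1 - p))

print(round(binary_cross_entropy(1, 0.9), 4))  # 0.1054 (confident, correct)
print(round(binary_cross_entropy(1, 0.1), 4))  # 2.3026 (confident, wrong)
```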
  • the classification task model after the initial training includes the initially trained feature extractor, and the initially trained feature extractor is used in the following generative adversarial network.
  • the stop-training conditions of the initial classification task model can be set in advance, such as the model accuracy reaching a preset requirement, the number of training rounds reaching a preset number, or the training duration reaching a preset duration, which are not limited in the embodiments of this application.
  • the initial classification task model includes a feature extractor E I and a classifier C I.
  • the input end of the feature extractor E I is the input end of the model, and the output end of the feature extractor E I is connected to the input end of the classifier C I.
  • the output of the classifier C I is the output of the model.
  • the initial classification task model is trained using the first data set (including the majority class sample and the minority class sample) to obtain the classification task model after the initial training.
  • the classification task model after the initial training includes the feature extractor E I after the initial training and the classifier C I after the initial training.
  • Step 102 Construct a generative adversarial network, which includes the feature extractor and an initial feature generator.
  • the output of the feature generator and the output of the feature extractor are connected to the input of the domain classifier respectively.
  • the feature extractor is the feature extractor obtained by the initial feature extractor through the initial training in step 101 above.
  • the initial feature generator is used to generate feature vectors of the same dimension as those output by the feature extractor. For example, if the dimension of the feature vector output by the feature extractor is 20, the dimension of the feature vector generated by the initial feature generator is also 20.
  • the initial feature generator can also be constructed from multiple convolutional layers, for example six convolutional layers: the convolution kernel size of the first five layers is 3*3, the kernel size of the last layer is 1*1, and the numbers of output feature maps of the six layers are 64, 128, 256, 512, 1024, and 2048, respectively.
  • each convolutional layer can be followed by a batch normalization layer and an activation function layer, such as a Rectified Linear Unit (ReLU) layer.
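The layer configuration described above can be written down as data, which makes the kernel sizes and channel progression easy to check (a sketch of the described architecture, not executable network code):

```python
# Six convolutional layers as described: 3*3 kernels for the first five,
# a 1*1 kernel for the last, with 64..2048 output feature maps.
generator_layers = (
    [{"kernel": (3, 3), "out_maps": c} for c in (64, 128, 256, 512, 1024)]
    + [{"kernel": (1, 1), "out_maps": 2048}]
)

# Each layer would be followed by batch normalization and a ReLU activation.
for layer in generator_layers:
    print(layer["kernel"], layer["out_maps"])
```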
  • the generative adversarial network may also include a domain classifier, and the domain classifier is used to distinguish the feature vectors output by the feature extractor from the feature vectors output by the feature generator.
  • the feature generator is adjusted through adversarial learning against the domain classifier so that its output feature vectors are as close as possible to the feature vectors output by the feature extractor. Through this process of adversarial learning, the model parameters at the max-min game equilibrium are found.
  • Step 103 Use the second-category samples to train the generative adversarial network to obtain a feature generator.
  • the parameters of the feature extractor are fixed, that is, the parameters of the feature extractor are not updated.
  • the input of the feature extractor is the second category sample, that is, the minority category sample, and the output is the feature vector extracted from the second category sample.
  • the input of the initial feature generator includes the superposition of prior data and noise data, and the output is a feature vector of the same dimension as the feature extractor.
  • the prior data can be extracted from the second-category samples in the first data set, or from samples in a second data set that belong to the same category as the second-category samples.
  • the second data set may be another data set different from the first data set in the same task.
  • the noise data may be random noise data. Taking prior data that is a 64*64 image as an example, the noise data can also be a 64*64 image, but the pixel value of each pixel in the noise image is randomly generated.
  • superimposing the prior data and the noise data means performing a weighted sum of the pixel values of the pixels at the same positions in the prior data and the noise data, finally obtaining a superimposed image.
  • the initial feature generator extracts the feature vector from the superimposed image.
  • the prior data can be a small-sized sample image obtained by reducing the sample image, such as a 64*64 sample image.
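The weighted superposition of prior data and noise data described above can be sketched as follows; the mixing weight `alpha` is a hypothetical parameter, since the patent does not specify the weights:

```python
import numpy as np

def superimpose(prior, noise, alpha=0.8):
    """Pixel-wise weighted sum of a prior image and a noise image.
    alpha is an assumed mixing weight, not taken from the patent."""
    assert prior.shape == noise.shape
    return alpha * prior + (1.0 - alpha) * noise

rng = np.random.default_rng(0)
prior = rng.random((64, 64))   # downsized minority-class sample image
noise = rng.random((64, 64))   # randomly generated noise image
mixed = superimpose(prior, noise)
print(mixed.shape)  # (64, 64)
```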
  • in a possible implementation, the input of the initial feature generator is not purely noise data; if feature vectors similar to real samples had to be generated entirely from noise data, there would be no effective constraint.
  • making the input of the initial feature generator the superposition of prior data and noise data can suppress the non-convergence and collapse problems in the training process of the generative adversarial network, and increases the robustness of the generative adversarial network.
  • step 103 includes the following sub-steps:
  • the first parameter update includes: assigning a first label to the input of the feature extractor, and assigning a second label to the input of the initial feature generator;
  • the second parameter update includes: masking the input of the feature extractor, and assigning the first label to the input of the initial feature generator;
  • the initial feature generator and the domain classifier are trained against each other, that is, two back-propagation passes are performed in each round of training.
  • in the first pass, the parameters of the initial feature generator are fixed and the parameters of the domain classifier are updated; in the second pass, the parameters of the domain classifier are fixed and the parameters of the initial feature generator are updated.
  • the first label and the second label are two different labels, for example, the first label is 1 and the second label is 0, or the first label is 0 and the second label is 1.
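The alternating two-pass update can be illustrated with a deliberately tiny numerical toy (the scalar "feature", logistic domain classifier, and all hyperparameters here are illustrative assumptions rather than the patent's networks): pass one fixes the generator and updates the domain classifier with real features labelled 1 and generated features labelled 0; pass two fixes the domain classifier and updates the generator toward the "real" label.

```python
import numpy as np

rng = np.random.default_rng(0)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

real_mean = 3.0   # mean of the "real" scalar features (illustrative)
w, b = 0.1, 0.0   # domain classifier: D(v) = sigmoid(w * v + b)
g = 0.0           # generator parameter: the mean of its generated "feature"
lr = 0.05         # learning rate for both passes

for step in range(2000):
    real = real_mean + rng.normal()   # feature from the frozen extractor
    fake = g + rng.normal()           # feature from the generator

    # Pass 1: generator fixed, update the domain classifier.
    # Real features carry the first label (1), generated ones the second (0).
    for v, y in ((real, 1.0), (fake, 0.0)):
        p = sigmoid(w * v + b)
        w += lr * (y - p) * v
        b += lr * (y - p)

    # Pass 2: domain classifier fixed, update the generator so that its
    # output is pushed toward the "real" label (1).
    p = sigmoid(w * g + b)
    g += lr * (1.0 - p) * w

# g drifts toward real_mean as the adversarial game approaches equilibrium
print(round(g, 1))
```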
  • the generative adversarial network includes a feature extractor E I , a feature generator G, and a domain classifier D.
  • the output terminal of the feature extractor E I and the output terminal of the feature generator G are connected to the input terminal of the domain classifier D respectively.
  • the input of the feature generator G is the superposition of prior data and noise data, and the input of the feature extractor E I is the minority class samples in the first data set.
  • the feature generator G is used in the following classification task model.
  • Step 104 Construct a classification task model, where the classification task model includes a feature generator and a feature extractor.
  • the classification task model may further include a classifier.
  • the output terminal of the feature generator and the output terminal of the feature extractor are respectively connected with the input terminal of the classifier.
  • the feature generator is the feature generator obtained through the generative adversarial network training in step 103.
  • the feature extractor and classifier in this step adopt the same structure and configuration as the initial classification task model in step 101.
  • the feature extractor in this step is initialized with the parameters of the feature extractor obtained through training in step 101.
  • Step 105 Use the first data set to train the classification task model; wherein the feature generator is used to amplify the second category samples in the feature space.
  • the feature generator obtained through the generative adversarial network training is used to amplify the minority-class samples in the feature space, so that the class-imbalanced learning task is transformed into a class-balanced learning task, and the classification task model is obtained by retraining.
  • the classification task model further includes a data cleaning unit through which the abnormal feature vectors output by the feature generator and the feature extractor are filtered.
  • the data cleaning unit can be a functional unit implemented in software, hardware, or a combination of the two. An appropriate data cleaning technique (such as the Tomek Link algorithm) is used to suppress some abnormal feature vectors generated by the feature generator, thereby further improving the accuracy of the finally trained classification task model.
  • the data cleaning unit can filter out the feature vector pairs that meet a preset condition from the feature vectors output by the feature generator and the feature extractor.
  • a feature vector pair that meets the preset condition refers to two feature vectors whose labels are different and whose similarity meets a threshold, for example, the pair of feature vectors with the greatest similarity, or pairs with larger similarity.
  • the feature vector pairs that meet the preset condition are filtered out as abnormal feature vectors.
  • the similarity between the two feature vectors can be calculated by Euclidean distance algorithm or other similarity algorithms, which is not limited in the embodiment of the present application.
  • all feature vectors are traversed; for each feature vector, the other feature vector most similar to it is found, and the labels of the two feature vectors are compared. If the labels are not the same, for example, the label of one feature vector is 1 and the label of the other is 0, the two feature vectors form a feature vector pair that meets the preset condition, and both are filtered out as abnormal feature vectors.
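A minimal version of this nearest-neighbour label check might look as follows (a simplified, Tomek-link-style sketch, not the patent's implementation):

```python
import numpy as np

def filter_abnormal(features, labels):
    """Mark each vector whose most similar other vector carries a
    different label; such pairs are treated as abnormal feature vectors."""
    features = np.asarray(features, dtype=float)
    labels = np.asarray(labels)
    # Pairwise Euclidean distances, with self-distance excluded.
    d = np.linalg.norm(features[:, None, :] - features[None, :, :], axis=-1)
    np.fill_diagonal(d, np.inf)
    abnormal = np.zeros(len(features), dtype=bool)
    for i in range(len(features)):
        j = int(np.argmin(d[i]))      # most similar other feature vector
        if labels[i] != labels[j]:    # labels differ -> abnormal pair
            abnormal[i] = abnormal[j] = True
    return abnormal

feats = [[0.0, 0.0], [0.1, 0.0], [5.0, 5.0], [5.1, 5.0]]
labs = [1, 0, 1, 1]
print(filter_abnormal(feats, labs))  # [ True  True False False]
```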
  • the classification task model includes a feature generator G, a feature extractor E F , a classifier C F and a data cleaning unit.
  • the feature extractor E F of the classification task model has the same structure and configuration as the feature extractor E I shown in Fig. 2, and the classifier C F of the classification task model has the same structure and configuration as the classifier C I shown in Fig. 2.
  • the first data set (including the majority class sample and the minority class sample) is used to train the classification task model, and when the preset stop training condition is met, the training of the classification task model is stopped to obtain the classification task model.
  • the preset stopping training condition may be that the model accuracy reaches the preset requirement, the number of training rounds reaches the preset number of rounds, or the training duration reaches the preset duration, etc., which are not limited in the embodiment of the present application.
  • a feature generator is obtained by training a generative adversarial network, and the minority-class samples (that is, the class of training samples that is under-represented in the class-imbalanced data set) are amplified in the feature space through the feature generator.
  • the abnormal feature vectors output by the feature generator and the feature extractor are also filtered by the data cleaning unit, which suppresses some abnormal feature vectors generated by the feature generator and thereby further improves the accuracy of the finally trained classification task model.
  • the input of the feature generator is not purely noise data; generating feature vectors similar to real samples entirely from noise lacks an effective constraint.
  • instead, the input of the feature generator is a superposition of prior data and noise data, which suppresses the non-convergence and collapse problems in the training of the generative adversarial network and increases its robustness.
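The superposition described here is simply a pixel-wise weighted sum of a (down-scaled) real minority-class sample and a random noise image of the same shape. A minimal sketch under stated assumptions — the 64*64 size, the 0.3 mixing weight, and the function name `make_generator_input` are all illustrative; the patent does not fix the weighting:

```python
import numpy as np

def make_generator_input(prior_image, noise_weight=0.3, rng=None):
    """Superimpose prior data with noise data: a pixel-wise weighted
    sum of a prior sample image and a same-shape random noise image.
    The 0.3 mixing weight is an illustrative choice."""
    rng = np.random.default_rng() if rng is None else rng
    noise = rng.random(prior_image.shape)  # random pixel values in [0, 1)
    return (1.0 - noise_weight) * prior_image + noise_weight * noise
```

The resulting superimposed image, rather than raw noise, is what the feature generator consumes.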
  • the training process of the classification task model provided in the embodiment of the present application may include the following three steps:
  • Step 1: Train the initial feature extractor.
  • an initial classification task model is constructed, including a feature extractor E_I and a classifier C_I, and is trained using the class-imbalanced data set to obtain the feature extractor E_I.
  • Step 2: Train the feature generator.
  • a generative adversarial network is constructed, including the initially trained feature extractor E_I, the initial feature generator G, and the domain classifier D.
  • during training, the parameters of the feature extractor E_I are held fixed, and the feature generator G is obtained by training the generative adversarial network.
  • Step 3: Train the final classification task model.
  • the classification task model is constructed, including the feature generator G, the feature extractor E_F, the data cleaning unit, and the classifier C_F.
  • during training, the parameters of the feature generator G are held fixed; working with the original class-imbalanced data set, the feature generator G amplifies the minority-class samples in the feature space, transforming the class-imbalanced learning task into a class-balanced one, and the final classification task model is trained.
  • the technical solutions provided by the embodiments of the present application can be applied to the model training process of machine learning classification tasks in the AI field, and are especially applicable to training classification task models whose training data set is class-imbalanced.
  • the training data set can include multiple sub-images extracted from medical images. Some of these sub-images are positive samples (that is, images of lesion areas) and some are negative samples (that is, images of non-lesion areas); the number of negative samples is often much larger than the number of positive samples.
  • in this scenario, the classification task model can be called an imaging lesion discrimination model.
  • its input is a sub-image extracted from a medical image
  • its output is the judgment result of whether the sub-image is a lesion area.
  • a feature generator is obtained by training a generative adversarial network and is used to amplify the minority-class samples in the feature space, so that a more accurate imaging lesion discrimination model is finally trained to assist doctors in lesion diagnosis and analysis, such as the detection and analysis of masses in mammography images.
  • Table-1 above shows the test results on the mammography data set
  • Table-2 shows the test results on the camelyon2016 pathology image data set.
  • option 1 represents no processing of the data set
  • option 2 represents sample down-sampling of the data set
  • option 3 represents sample up-sampling of the data set
  • option 4 represents amplifying the data set from the sample space
  • option 5 represents using the technical solution of this application to amplify the data set from the feature space, without the data cleaning step
  • option 6 represents using the technical solution of this application to amplify the data set from the feature space, with the data cleaning step
  • Acc and AUC are model evaluation parameters.
  • Acc represents the accuracy of the finally trained classification task model: the larger the Acc, the better the model performs.
  • AUC represents the area under the ROC (receiver operating characteristic) curve.
  • AUC intuitively reflects the classification ability expressed by the ROC curve: the larger the AUC, the better the performance of the model; the smaller the AUC, the worse the performance of the model.
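Because AUC equals the probability that a randomly chosen positive sample is scored higher than a randomly chosen negative one, it can be computed as a rank statistic without plotting the ROC curve at all. A minimal pure-Python sketch (the function name is illustrative, and the O(P*N) pairwise loop is for clarity, not efficiency):

```python
def auc_score(labels, scores):
    """Area under the ROC curve, computed as the probability that a
    randomly chosen positive sample is scored higher than a randomly
    chosen negative one (ties count as half a win)."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    wins = sum(1.0 if p > q else 0.5 if p == q else 0.0
               for p in pos for q in neg)
    return wins / (len(pos) * len(neg))
```

A perfect ranking gives 1.0, a fully inverted ranking gives 0.0, and a random scorer hovers around 0.5.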
  • Part (a) of Figure 6 shows the ROC curves and corresponding AUC values of the above six schemes on the mammography data set with a class imbalance ratio of 1:10.
  • Part (b) of Figure 6 shows the ROC curves and corresponding AUC values of the above six schemes on the mammography data set with a class imbalance ratio of 1:20.
  • Part (a) of Figure 7 shows the ROC curves and corresponding AUC values of the above six schemes on the camelyon2016 pathology image data set with a class imbalance ratio of 1:10.
  • Part (b) of Figure 7 shows the ROC curves and corresponding AUC values of the above six schemes on the camelyon2016 pathology image data set with a class imbalance ratio of 1:20.
  • FIG. 8 shows a block diagram of a training device for a classification task model provided by an embodiment of the present application.
  • the device has the function of realizing the above method example, and the function can be realized by hardware, or by hardware executing corresponding software.
  • the device can be a computer device, or it can be set in a computer device.
  • the device 800 may include: a first training module 810, a first building module 820, a second training module 830, a second building module 840, and a third training module 850.
  • the first training module 810 is configured to train an initial feature extractor with a first data set to obtain a feature extractor; the first data set is a class-imbalanced data set including first-category samples and second-category samples, the number of first-category samples is greater than the number of second-category samples, and the first data set is determined from medical images.
  • the first construction module 820 is configured to construct a generative adversarial network including the feature extractor and an initial feature generator; the initial feature generator is configured to generate feature vectors of the same dimension as the feature extractor.
  • the second training module 830 is configured to train the generative confrontation network by using the second category samples to obtain a feature generator.
  • the second construction module 840 is configured to construct a classification task model, the classification task model including the feature generator and the feature extractor.
  • the third training module 850 is configured to use the first data set to train the classification task model; wherein the feature generator is configured to amplify the second category samples in the feature space during the training process.
  • a feature generator is obtained by training a generative adversarial network, and the minority-class samples (that is, the class of training samples that is under-represented in the class-imbalanced data set) are amplified in the feature space through the feature generator.
  • the generative confrontation network further includes a domain classifier, and the domain classifier is used to distinguish the feature vector output by the feature extractor from the feature vector output by the feature generator.
  • the second training module 830 is configured to perform a first parameter update and a second parameter update during each round of training of the generative adversarial network. The first parameter update includes: assigning a first label to the input of the feature extractor and a second label to the input of the feature generator; calculating the first loss function value of the domain classifier; and updating the parameters of the domain classifier according to the first loss function value;
  • the second parameter update includes: masking the input of the feature extractor and assigning the first label to the input of the feature generator; calculating the second loss function value of the domain classifier; and updating the parameters of the feature generator according to the second loss function value.
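The two-phase update above (train the domain classifier D on labelled real/generated features, then train the generator G against D with the extractor masked) can be sketched with toy linear models. Everything here is illustrative: the real E, G, and D are convolutional networks trained by back-propagation, whereas this NumPy sketch only mirrors the alternating update schedule and the fixed-parameter constraints, with gradients written out analytically for a logistic domain classifier.

```python
import numpy as np

rng = np.random.default_rng(0)
sigmoid = lambda z: 1.0 / (1.0 + np.exp(-z))

def bce(p, y):  # binary cross-entropy loss
    return -np.mean(y * np.log(p + 1e-9) + (1 - y) * np.log(1 - p + 1e-9))

# Toy stand-ins: a frozen linear feature extractor E, a trainable linear
# feature generator G, and a logistic-regression domain classifier D (w_D).
d_in, d_feat, lr, n = 8, 4, 0.1, 16
E = rng.normal(size=(d_in, d_feat))   # E's parameters stay fixed throughout
G = rng.normal(size=(d_in, d_feat))
w_D = np.zeros(d_feat)

for _ in range(200):
    x_real = rng.normal(size=(n, d_in))   # minority-class samples into E
    x_gen = rng.normal(size=(n, d_in))    # prior + noise input into G
    f_real, f_gen = x_real @ E, x_gen @ G

    # First parameter update: assign label 1 to E's output and label 0 to
    # G's output, compute D's first loss value, and update only D.
    f = np.vstack([f_real, f_gen])
    y = np.concatenate([np.ones(n), np.zeros(n)])
    p = sigmoid(f @ w_D)
    loss_D = bce(p, y)                      # first loss function value
    w_D += lr * (f.T @ (y - p)) / (2 * n)   # gradient step on D only

    # Second parameter update: mask E's input, assign label 1 to G's
    # output, compute D's second loss value, and update only G.
    p_gen = sigmoid(f_gen @ w_D)
    loss_G = bce(p_gen, np.ones(n))         # second loss function value
    grad_f = (p_gen - 1.0)[:, None] * w_D[None, :] / n
    G -= lr * (x_gen.T @ grad_f)            # D's weights are left untouched
```

The point of the sketch is the schedule: within each round D is updated while G is held fixed, then G is updated while D is held fixed, and E is never updated at all.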
  • the input of the initial feature generator includes a superposition of prior data and noise data; the prior data is extracted either from the second-category samples of the first data set or from samples of the same category as the second-category samples in a second data set.
  • the classification task model further includes a data cleaning unit
  • the third training module is further configured to use the data cleaning unit to filter the abnormal feature vectors output by the feature generator and the feature extractor.
  • the third training module is further configured to: through the data cleaning unit, filter out, from the feature vectors output by the feature generator and the feature extractor, feature vector pairs that meet a preset condition, where such a pair includes two feature vectors with different labels and a similarity greater than a threshold; and filter out these pairs as the abnormal feature vectors.
  • the first training module 810 is configured to: construct an initial classification task model including the initial feature extractor and an initial classifier; and train the initial classification task model with the first data set to obtain the feature extractor.
  • when the device provided in the above embodiments implements its functions, the division into the above functional modules is only an example; in practical applications, the above functions can be allocated to different functional modules as needed, that is, the internal structure of the device can be divided into different functional modules to complete all or part of the functions described above.
  • the apparatus and method embodiments provided by the above-mentioned embodiments belong to the same concept, and the specific implementation process is detailed in the method embodiments, which will not be repeated here.
  • FIG. 9 shows a schematic structural diagram of a computer device provided by an embodiment of the present application.
  • the computer device can be any electronic device with data processing and storage functions, such as a PC or server.
  • the computer device is used to implement the training method of the classification task model provided in the foregoing embodiment. Specifically:
  • the computer device 900 includes a central processing unit (CPU) 901, a system memory 904 including a random access memory (RAM) 902 and a read-only memory (ROM) 903, and a system bus 905 connecting the system memory 904 and the central processing unit 901 .
  • the computer device 900 also includes a basic input/output system (I/O system) 906 that helps transfer information between the various components in the computer, and a mass storage device 907 for storing an operating system 913, application programs 914, and other program modules 915.
  • the basic input/output system 906 includes a display 908 for displaying information and an input device 909 such as a mouse and a keyboard for the user to input information.
  • the display 908 and the input device 909 are both connected to the central processing unit 901 through the input and output controller 910 connected to the system bus 905.
  • the basic input/output system 906 may also include an input and output controller 910 for receiving and processing input from multiple other devices such as a keyboard, a mouse, or an electronic stylus.
  • the input and output controller 910 also provides output to a display screen, a printer, or other types of output devices.
  • the mass storage device 907 is connected to the central processing unit 901 through a mass storage controller (not shown) connected to the system bus 905.
  • the mass storage device 907 and its associated computer readable medium provide non-volatile storage for the computer device 900. That is, the mass storage device 907 may include a computer-readable medium (not shown) such as a hard disk or a CD-ROM drive.
  • the computer-readable media may include computer storage media and communication media.
  • Computer storage media includes volatile and nonvolatile, removable and non-removable media implemented in any method or technology for storing information such as computer readable instructions, data structures, program modules or other data.
  • Computer storage media include RAM, ROM, EPROM, EEPROM, flash memory or other solid-state storage technologies, CD-ROM, DVD or other optical storage, tape cartridges, magnetic tape, disk storage or other magnetic storage devices.
  • (RAM: random access memory; ROM: read-only memory; EPROM: erasable programmable read-only memory; EEPROM: electrically erasable programmable read-only memory)
  • the computer device 900 may also run by connecting, through a network such as the Internet, to a remote computer on the network. That is, the computer device 900 can connect to the network 912 through the network interface unit 911 connected to the system bus 905, or the network interface unit 911 can be used to connect to other types of networks or remote computer systems (not shown).
  • the memory stores at least one instruction, at least one program, a code set, or an instruction set, which is configured to be executed by one or more processors to implement the training method of the classification task model provided in the foregoing embodiments.
  • an embodiment of the present application further provides a computer-readable storage medium for storing a computer program, and the computer program is used to execute the training method of the classification task model provided in the foregoing embodiment.
  • the aforementioned computer-readable storage medium may be ROM, RAM, CD-ROM, magnetic tape, floppy disk, optical data storage device, and the like.
  • a computer program product is also provided; when the computer program product is executed, it is used to implement the training method of the classification task model provided in the foregoing embodiments.
  • the "plurality” mentioned herein refers to two or more.
  • “And/or” describes the association relationship of associated objects, indicating that three relationships can exist; for example, A and/or B can mean: A alone, both A and B, or B alone.
  • the character "/” generally indicates that the associated objects are in an "or” relationship.
  • the numbering of the steps described herein only exemplarily shows one possible execution order among the steps. In some other embodiments, the steps may be executed out of the numbered order; for example, two differently numbered steps may be executed at the same time, or in the reverse of the order shown in the figure, which is not limited in the embodiments of the present application.

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Evolutionary Computation (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Health & Medical Sciences (AREA)
  • Data Mining & Analysis (AREA)
  • General Health & Medical Sciences (AREA)
  • Computing Systems (AREA)
  • Software Systems (AREA)
  • Medical Informatics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • General Engineering & Computer Science (AREA)
  • Databases & Information Systems (AREA)
  • Multimedia (AREA)
  • Biomedical Technology (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Evolutionary Biology (AREA)
  • Mathematical Physics (AREA)
  • Biophysics (AREA)
  • Computational Linguistics (AREA)
  • Molecular Biology (AREA)
  • Public Health (AREA)
  • Epidemiology (AREA)
  • Primary Health Care (AREA)
  • Probability & Statistics with Applications (AREA)
  • Pathology (AREA)
  • Nuclear Medicine, Radiotherapy & Molecular Imaging (AREA)
  • Radiology & Medical Imaging (AREA)
  • Information Retrieval, Db Structures And Fs Structures Therefor (AREA)
  • Image Analysis (AREA)

Abstract

A method, apparatus, device, and storage medium for training a classification task model based on artificial intelligence, relating to the field of machine learning. The method includes: training an initial feature extractor with a first data set to obtain a feature extractor (101), the first data set being a class-imbalanced data set; constructing a generative adversarial network including the feature extractor and an initial feature generator (102); training the generative adversarial network with second-category samples to obtain a feature generator (103); constructing a classification task model including the feature generator and the feature extractor (104); and training the classification task model with the first data set (105), where the feature generator amplifies the second-category samples in the feature space during training. By amplifying minority-category samples in the feature space through the feature generator, the accuracy of the finally trained classification task model is improved.

Description

Training method, apparatus, device, and storage medium for a classification task model
This application claims priority to Chinese Patent Application No. 201910377510.3, entitled "Training method, apparatus, device, and storage medium for a classification task model", filed with the China National Intellectual Property Administration on May 7, 2019, the entire contents of which are incorporated herein by reference.
Technical Field
The embodiments of this application relate to the field of machine learning technology, and in particular to the training of classification task models.
Background
Machine learning performs well on classification tasks. For example, a classification task model can be built on a deep neural network and trained with appropriate training samples; the trained classification task model can then be used to handle classification tasks such as image recognition and speech recognition.
When training a classification task model, the categories of the training samples in the training data set may be imbalanced; for example, the number of positive samples may be far smaller than the number of negative samples. Such a training data set may be called a class-imbalanced data set. Training a classification task model on a class-imbalanced data set leads to poor performance of the finally obtained model.
Summary
The embodiments of this application provide a method, apparatus, device, and storage medium for training a classification task model based on artificial intelligence, which can solve the technical problem that the sample up-sampling means provided in the related art cannot train a high-accuracy classification task model. The technical solutions are as follows:
In one aspect, an embodiment of this application provides a method for training a classification task model, the method being executed by a computer device and including:
training an initial feature extractor with a first data set to obtain a feature extractor, where the first data set is a class-imbalanced data set including first-category samples and second-category samples, the number of first-category samples is greater than the number of second-category samples, and the first data set is determined from medical images;
constructing a generative adversarial network including the feature extractor and an initial feature generator, where the initial feature generator is configured to generate feature vectors of the same dimension as the feature extractor;
training the generative adversarial network with the second-category samples to obtain a feature generator;
constructing a classification task model including the feature generator and the feature extractor;
training the classification task model with the first data set, where the feature generator is configured to amplify the second-category samples in the feature space during training, and the trained classification task model is used to classify lesions in medical images.
In another aspect, an embodiment of this application provides an apparatus for training a classification task model, the apparatus including:
a first training module, configured to train an initial feature extractor with a first data set to obtain a feature extractor, where the first data set is a class-imbalanced data set including first-category samples and second-category samples, the number of first-category samples is greater than the number of second-category samples, and the first data set is determined from medical images;
a first construction module, configured to construct a generative adversarial network including the feature extractor and an initial feature generator, where the initial feature generator is configured to generate feature vectors of the same dimension as the feature extractor;
a second training module, configured to train the generative adversarial network with the second-category samples to obtain a feature generator;
a second construction module, configured to construct a classification task model including the feature generator and the feature extractor;
a third training module, configured to train the classification task model with the first data set, where the feature generator is configured to amplify the second-category samples in the feature space during training, and the trained classification task model is used to classify lesions in medical images.
In yet another aspect, an embodiment of this application provides a computer device, including:
a processor, a communication interface, a memory, and a communication bus;
where the processor, the communication interface, and the memory communicate with each other through the communication bus, and the communication interface is an interface of a communication module;
the memory is configured to store program code and transmit the program code to the processor; and the processor is configured to call instructions of the program code in the memory to execute the training method of the classification task model in the above aspect.
In yet another aspect, an embodiment of this application provides a storage medium for storing a computer program, the computer program being used to execute the training method of the classification task model in the above aspect.
In yet another aspect, an embodiment of this application provides a computer program product including instructions that, when run on a computer, cause the computer to execute the training method of the classification task model in the above aspect.
The technical solutions provided by the embodiments of this application include at least the following beneficial effects:
In the technical solutions provided by the embodiments of this application, a feature generator is obtained by training a generative adversarial network, and minority-category samples (that is, the class of training samples that is under-represented in the class-imbalanced data set) are amplified in the feature space through the feature generator. Amplifying at the feature level, rather than simply duplicating the minority-category samples by sample up-sampling, prevents the finally trained classification task model from overfitting and improves its accuracy.
Brief Description of the Drawings
To describe the technical solutions in the embodiments of this application more clearly, the following briefly introduces the drawings required in the description of the embodiments. Obviously, the drawings described below are only some embodiments of this application, and a person of ordinary skill in the art can derive other drawings from them without creative effort.
FIG. 1 is a flowchart of a method for training a classification task model provided by an embodiment of this application;
FIG. 2 exemplarily shows a schematic structural diagram of an initial classification task model;
FIG. 3 exemplarily shows a schematic structural diagram of a generative adversarial network;
FIG. 4 exemplarily shows a schematic structural diagram of a classification task model;
FIG. 5 exemplarily shows an overall architecture diagram of the technical solution of this application;
FIG. 6 and FIG. 7 exemplarily show schematic diagrams of two sets of experimental results;
FIG. 8 is a block diagram of an apparatus for training a classification task model provided by an embodiment of this application;
FIG. 9 is a schematic structural diagram of a computer device provided by an embodiment of this application.
Detailed Description
To make the objectives, technical solutions, and advantages of this application clearer, the following further describes the implementations of this application in detail with reference to the drawings.
In the technical solutions provided by the embodiments of this application, a feature generator is obtained by training a generative adversarial network, and minority-category samples (that is, the class of training samples that is under-represented in the class-imbalanced data set) are amplified in the feature space through the feature generator. Amplifying at the feature level, rather than simply duplicating the minority-category samples by sample up-sampling, prevents the finally trained classification task model from overfitting and improves its accuracy.
The embodiments of this application involve machine learning (ML) within artificial intelligence (AI) technology, as well as deep learning within machine learning, including various artificial neural networks (ANN).
AI is a theory, method, technology, and application system that uses digital computers, or machines controlled by digital computers, to simulate, extend, and expand human intelligence, perceive the environment, acquire knowledge, and use knowledge to obtain the best results. In other words, AI is a comprehensive technology of computer science that attempts to understand the essence of intelligence and produce a new kind of intelligent machine that can react in a way similar to human intelligence. AI studies the design principles and implementation methods of various intelligent machines, so that the machines have the functions of perception, reasoning, and decision-making.
AI technology is a comprehensive discipline covering a wide range of fields, including both hardware-level and software-level technologies. Basic AI technologies generally include sensors, dedicated AI chips, cloud computing, distributed storage, big-data processing, operation/interaction systems, and mechatronics. AI software technologies mainly include computer vision, speech processing, natural language processing, and machine learning/deep learning.
In the embodiments of this application, the computer device executing the training method of the classification task model may have machine learning capabilities, through which the classification task model is trained.
Machine learning is a multi-disciplinary subject involving probability theory, statistics, approximation theory, convex analysis, algorithm complexity theory, and other disciplines. It specializes in studying how computers simulate or implement human learning behaviors to acquire new knowledge or skills and reorganize existing knowledge structures to continuously improve their own performance. Machine learning is the core of AI and the fundamental way to make computers intelligent; its applications span all fields of AI. Machine learning and deep learning usually include technologies such as artificial neural networks.
The classification task model involved in the embodiments of this application refers to a machine learning model obtained through machine learning training and used to handle classification tasks. The classification task model may be a deep learning classification task model, that is, a classification task model built on a deep neural network, such as a deep convolutional neural network. Besides handling lesion recognition and classification in medical images, the classification task model can also handle classification tasks such as image recognition and speech recognition; the embodiments of this application do not limit its specific application scenarios.
In the method provided by the embodiments of this application, each step may be executed by a computer device, which refers to an electronic device with data computing, processing, and storage capabilities, such as a PC (Personal Computer) or a server.
Please refer to FIG. 1, which shows a flowchart of a method for training a classification task model provided by an embodiment of this application. The method may include the following steps (101 to 105):
Step 101: Train an initial feature extractor with a first data set to obtain a feature extractor, where the first data set is a class-imbalanced data set including first-category samples and second-category samples, and the number of first-category samples is greater than the number of second-category samples.
The first-category samples and second-category samples are two different categories of samples in the first data set. For example, the first-category samples are positive samples and the second-category samples are negative samples, or vice versa. Since the number of first-category samples is greater than the number of second-category samples, the first-category samples may be called majority-category samples and the second-category samples minority-category samples. In most scenarios the number of negative samples is greater than, or even far greater than, the number of positive samples; therefore the first-category samples may be negative samples and, correspondingly, the second-category samples positive samples.
The feature extractor is the part of the classification task model used to extract sample features, and is also called an encoder. The classification task model includes a feature extractor and a classifier; the output of the feature extractor is connected to the input of the classifier. The feature extractor extracts a feature vector from the model's input sample, and the classifier determines the category of the input sample according to that feature vector. Taking a classification task model for image recognition as an example, the feature extractor maps and encodes the input image, outputting a feature vector whose dimension is far lower than the number of pixels of the input image; the feature extractor learns a nonlinear, local-to-global feature mapping that fuses low-level visual features and high-level semantic information.
In an exemplary embodiment, the classification task model is built on a deep convolutional neural network, and the feature extractor may include multiple convolutional layers. For example, the classification task model is an Inception-v3 model, a deep neural network model that performs well on image classification tasks. Another advantage of Inception-v3 is that a pre-trained Inception-v3 model can be used as the initialized classification task model, without randomly initializing the model parameters, which helps improve training efficiency. The classifier may adopt a normalized exponential function (Softmax) classifier or another classifier, which is not limited in the embodiments of this application.
In an exemplary embodiment, step 101 includes the following sub-steps:
1. Construct an initial classification task model, the initial classification task model including an initial feature extractor and an initial classifier;
As introduced above, the initial classification task model may be a pre-trained Inception-v3 model.
2. Train the initial classification task model with the first data set to obtain the feature extractor, which is obtained from the initial feature extractor through the aforementioned initial training.
The first data set includes first-category samples and second-category samples, and each training sample is labeled according to its category. For example, the label of the first-category samples is 1 and the label of the second-category samples is 0, or vice versa. The training samples in the first data set (including both categories) are input into the initial classification task model; the classification results output by the model are compared with the labels to compute the model's loss function value; then the gradients of the model parameters are computed from the loss function value using the back-propagation algorithm; finally, the parameters are updated with the gradients, with the update step controlled by the learning rate. The loss function may be the cross-entropy (CE) loss function.
When the initial classification task model meets the stop-training condition, training of the model is stopped and the initially trained classification task model is obtained. This model contains the initially trained feature extractor, which is used in the generative adversarial network described below. The stop-training condition of the initial classification task model can be set in advance, for example, the model accuracy reaching a preset requirement, the number of training rounds reaching a preset number, or the training duration reaching a preset duration, which is not limited in the embodiments of this application.
As shown in FIG. 2, which exemplarily shows the structure of the initial classification task model, the model includes a feature extractor E_I and a classifier C_I. The input of the feature extractor E_I is the input of the model, the output of E_I is connected to the input of the classifier C_I, and the output of C_I is the output of the model. The first data set (including majority-category and minority-category samples) is used to train this initial classification task model, yielding the initially trained classification task model, which includes the initially trained feature extractor E_I and the initially trained classifier C_I.
Step 102: Construct a generative adversarial network, the generative adversarial network including the feature extractor and an initial feature generator.
In the generative adversarial network, the output of the feature generator and the output of the feature extractor are each connected to the input of the domain classifier.
The feature extractor is the one obtained from the initial feature extractor through the initial training in step 101 above.
The initial feature generator is configured to generate feature vectors of the same dimension as the feature extractor. For example, if the dimension of the feature vector output by the feature extractor is 20, the dimension of the feature vector generated by the initial feature generator is also 20. The initial feature generator may also be built from multiple convolutional layers, for example six convolutional layers: the first five use 3*3 convolution kernels and the last uses a 1*1 kernel, with 64, 128, 256, 512, 1024, and 2048 output feature maps respectively; each convolutional layer may be followed by a batch normalization (batch norm) layer and an activation function layer, such as a Rectified Linear Unit (ReLU) layer.
In a possible implementation, the generative adversarial network may further include a domain classifier, which is used to distinguish the feature vectors output by the feature extractor from the feature vectors output by the feature generator. The domain classifier uses adversarial learning to adjust the feature generator so that its output feature vectors are as close as possible to those of the feature extractor; through this adversarial learning process, model parameters at the max-min game equilibrium are found.
Step 103: Train the generative adversarial network with the second-category samples to obtain the feature generator.
During the training of the generative adversarial network, the parameters of the feature extractor are fixed, that is, they are not updated. The input of the feature extractor is the second-category samples, namely the minority-category samples, and the output is the feature vectors extracted from those samples.
The input of the initial feature generator includes a superposition of prior data and noise data, and its output is a feature vector of the same dimension as the feature extractor. The prior data may be extracted from the second-category samples of the first data set, or from samples of the same category as the second-category samples in a second data set, where the second data set may be another data set, different from the first, for the same kind of task. The noise data may be random noise data. Taking prior data in the form of a 64*64 image as an example, the noise data may also be a 64*64 image, but the pixel values of the noise image are randomly generated. Superimposing the prior data and the noise data means taking a weighted sum of the pixel values at the same positions in the two images, finally yielding one superimposed image.
The initial feature generator extracts a feature vector from this superimposed image. In addition, since the feature generator may have few network layers, its input cannot be too large, so the prior data may be a down-scaled, small-size sample image, such as a 64*64 sample image.
In the embodiments of this application, in a possible implementation the input of the initial feature generator is not purely noise data; generating feature vectors similar to real samples entirely from noise lacks an effective constraint. The input of the initial feature generator is a superposition of prior data and noise data, which suppresses the non-convergence and collapse problems in the training of the generative adversarial network and increases its robustness.
In an exemplary embodiment, step 103 includes the following sub-steps:
1. In each round of training of the generative adversarial network, a first parameter update and a second parameter update are performed. The first parameter update includes: assigning a first label to the input of the feature extractor and a second label to the input of the initial feature generator;
2. computing the first loss function value of the domain classifier;
3. updating the parameters of the domain classifier according to the first loss function value;
4. The second parameter update includes: masking the input of the feature extractor and assigning the first label to the input of the initial feature generator;
5. computing the second loss function value of the domain classifier;
6. updating the parameters of the initial feature generator according to the second loss function value.
During the training of the generative adversarial network, the initial feature generator and the domain classifier compete against each other; that is, two back-propagation computations are performed in each training round. The first fixes the parameters of the initial feature generator and updates the parameters of the domain classifier; the second fixes the parameters of the domain classifier and updates the parameters of the initial feature generator. The first label and second label are two different labels, for example the first label is 1 and the second label is 0, or the first label is 0 and the second label is 1.
In one example, first, label 1 is assigned to the input of the feature extractor and label 0 to the input of the initial feature generator; the first loss function value of the domain classifier is computed, and the parameters of the domain classifier are adjusted by back-propagation according to the first loss function value. Then, the input of the feature extractor is masked, label 1 is assigned to the input of the initial feature generator, the second loss function value of the domain classifier is computed, and the parameters of the initial feature generator are adjusted by back-propagation according to the second loss function value.
As shown in FIG. 3, which exemplarily shows the structure of the generative adversarial network, the network includes the feature extractor E_I, the feature generator G, and the domain classifier D. The outputs of E_I and G are each connected to the input of D. The input of G is the superposition of prior data and noise data, and the input of E_I is the minority-category samples in the first data set. The feature generator G is used in the classification task model below.
Step 104: Construct a classification task model, the classification task model including the feature generator and the feature extractor. The classification task model may further include a classifier.
In the classification task model, the output of the feature generator and the output of the feature extractor are each connected to the input of the classifier.
The feature generator is the one obtained by training the generative adversarial network in step 103 above. The feature extractor and the classifier in this step adopt the same structure and configuration as the initial classification task model in step 101. Optionally, the feature extractor in this step is initialized with the parameters of the feature extractor trained in step 101.
Step 105: Train the classification task model with the first data set, where the feature generator is used to amplify the second-category samples in the feature space.
During the training of the classification task model, working with the original class-imbalanced first data set, the feature generator obtained by training the generative adversarial network amplifies the minority-category samples in the feature space, transforming the class-imbalanced learning task into a class-balanced one, and the classification task model is retrained.
In an exemplary embodiment, the classification task model further includes a data cleaning unit, which filters the abnormal feature vectors output by the feature generator and the feature extractor. The data cleaning unit may be a functional unit implemented through software, hardware, or a combination of both; by adopting an appropriate data cleaning technique (such as the Tomek Link algorithm), it suppresses some abnormal feature vectors generated by the feature generator, thereby further improving the accuracy of the finally trained classification task model.
In an exemplary embodiment, through the data cleaning unit, feature vector pairs that meet a preset condition can be filtered out from the feature vectors output by the feature generator and the feature extractor; such a pair refers to two feature vectors whose labels differ and whose similarity meets a threshold, for example the pair with the greatest similarity or several pairs with larger similarities.
The pairs that meet the preset condition are then filtered out as abnormal feature vectors. The similarity between two feature vectors may be computed by the Euclidean distance algorithm or another similarity algorithm, which is not limited in the embodiments of this application. Exemplarily, all the feature vectors output by the feature generator and the feature extractor are traversed; for each feature vector, the other feature vector most similar to it is found and their labels are compared. If the labels differ, for example one label is 1 and the other is 0, the two feature vectors form a pair that meets the preset condition, and both are filtered out as abnormal feature vectors.
As shown in FIG. 4, which exemplarily shows the structure of the classification task model, the model includes the feature generator G, the feature extractor E_F, the classifier C_F, and the data cleaning unit. The outputs of G and E_F are each connected to the input of the data cleaning unit, whose output is connected to the input of C_F. The feature extractor E_F has the same structure and configuration as the feature extractor E_I in the classification task model shown in FIG. 2, and the classifier C_F has the same structure and configuration as the classifier C_I shown in FIG. 2. The first data set (including majority-category and minority-category samples) is used to train this classification task model; when a preset stop-training condition is met, training stops and the classification task model is obtained. The preset stop-training condition may be the model accuracy reaching a preset requirement, the number of training rounds reaching a preset number, or the training duration reaching a preset duration, which is not limited in the embodiments of this application.
In summary, in the technical solutions provided by the embodiments of this application, a feature generator is obtained by training a generative adversarial network, and minority-category samples (that is, the class of training samples that is under-represented in the class-imbalanced data set) are amplified in the feature space through the feature generator. Amplifying at the feature level, rather than simply duplicating the minority-category samples by sample up-sampling, prevents the finally trained classification task model from overfitting and improves its accuracy.
In addition, in the technical solutions provided by the embodiments of this application, during the training of the classification task model, the abnormal feature vectors output by the feature generator and the feature extractor are also filtered by the data cleaning unit, suppressing some abnormal feature vectors generated by the feature generator and thereby further improving the accuracy of the finally trained classification task model.
In addition, in the embodiments of this application, the input of the feature generator is not purely noise data; generating feature vectors similar to real samples entirely from noise lacks an effective constraint. The input of the feature generator is a superposition of prior data and noise data, which suppresses the non-convergence and collapse problems in the training of the generative adversarial network and increases its robustness.
Below, with reference to FIG. 5, the technical solution provided by the embodiments of this application is described as a whole. The training process of the classification task model provided by the embodiments of this application may include the following three steps:
Step 1: Train the initial feature extractor.
In this step, an initial classification task model is constructed, including a feature extractor E_I and a classifier C_I, and is trained with the class-imbalanced data set to obtain the feature extractor E_I.
Step 2: Train the feature generator.
In this step, a generative adversarial network is constructed, including the initially trained feature extractor E_I, the initial feature generator G, and the domain classifier D. During training, the parameters of E_I are held fixed, and the feature generator G is obtained by training the generative adversarial network.
Step 3: Train the final classification task model.
In this step, a classification task model is constructed, including the feature generator G, the feature extractor E_F, the data cleaning unit, and the classifier C_F. During training, the parameters of G are held fixed; working with the original class-imbalanced data set, the feature generator G amplifies the minority-category samples in the feature space, transforming the class-imbalanced learning task into a class-balanced one, and the final classification task model is trained.
The technical solutions provided by the embodiments of this application can be applied to the model training process of machine learning classification tasks in the AI field, and are especially applicable to training classification task models whose training data set is class-imbalanced. Taking the classification of class-imbalanced medical images as an example, the training data set may include multiple sub-images extracted from medical images; some of these sub-images are positive samples (that is, images of lesion areas) and some are negative samples (that is, images of non-lesion areas), and the number of negative samples is often far greater than the number of positive samples. In this application scenario, the classification task model can be called an imaging lesion discrimination model, whose input is a sub-image extracted from a medical image and whose output is the judgment result of whether the sub-image is a lesion area. A feature generator obtained by training a generative adversarial network is used to amplify the minority-category samples in the feature space, finally training a more accurate imaging lesion discrimination model to assist doctors in lesion diagnosis and analysis, such as the detection and analysis of masses in mammography images.
This solution was tested on a data set containing 2194 mammography images and on the camelyon2016 pathology image data set; regions of interest (ROI) were extracted from the images to obtain sub-image collections, using class imbalance ratios of 1:10 and 1:20. The test results are shown in Table-1 and Table-2 below.
Table-1
[Table-1: test results on the mammography data set; the table image (PCTCN2020085006-appb-000001) is not reproduced here.]
Table-2
[Table-2: test results on the camelyon2016 pathology image data set; the table image (PCTCN2020085006-appb-000002) is not reproduced here.]
Table-1 above shows the test results on the mammography data set, and Table-2 shows the test results on the camelyon2016 pathology image data set.
In Table-1 and Table-2, scheme 1 represents no processing of the data set; scheme 2 represents sample down-sampling; scheme 3 represents sample up-sampling; scheme 4 represents amplifying the data set from the sample space; scheme 5 represents using the technical solution of this application to amplify the data set from the feature space without the data cleaning step; and scheme 6 represents using the technical solution of this application to amplify the data set from the feature space with the data cleaning step.
In Table-1 and Table-2, Acc and AUC are model evaluation parameters. Acc (Accuracy) represents the accuracy of the finally trained classification task model: the larger the Acc, the better the model performs; the smaller the Acc, the worse. AUC (Area under the ROC curve) represents the area under the ROC (receiver operating characteristic) curve; AUC intuitively reflects the classification ability expressed by the ROC curve: the larger the AUC, the better the model performs; the smaller the AUC, the worse.
Part (a) of FIG. 6 shows the ROC curves and corresponding AUC values of the above six schemes on the mammography data set with a class imbalance ratio of 1:10. Part (b) of FIG. 6 shows the ROC curves and corresponding AUC values of the above six schemes on the mammography data set with a class imbalance ratio of 1:20.
Part (a) of FIG. 7 shows the ROC curves and corresponding AUC values of the above six schemes on the camelyon2016 pathology image data set with a class imbalance ratio of 1:10. Part (b) of FIG. 7 shows the ROC curves and corresponding AUC values of the above six schemes on the camelyon2016 pathology image data set with a class imbalance ratio of 1:20.
As can be seen from the charts of the above test results, the technical solution of this application mostly outperforms other schemes such as sample up-sampling, sample down-sampling, and sample-space amplification, and the scheme with the added data cleaning step can further improve the performance of the finally trained classification task model.
The following are apparatus embodiments of this application, which can be used to execute the method embodiments of this application. For details not disclosed in the apparatus embodiments, please refer to the method embodiments of this application.
Please refer to FIG. 8, which shows a block diagram of an apparatus for training a classification task model provided by an embodiment of this application. The apparatus has the function of implementing the above method examples; the function may be implemented by hardware, or by hardware executing corresponding software. The apparatus may be a computer device, or may be set in a computer device. The apparatus 800 may include: a first training module 810, a first construction module 820, a second training module 830, a second construction module 840, and a third training module 850.
The first training module 810 is configured to train an initial feature extractor with a first data set to obtain a feature extractor, where the first data set is a class-imbalanced data set including first-category samples and second-category samples, the number of first-category samples is greater than the number of second-category samples, and the first data set is determined from medical images.
The first construction module 820 is configured to construct a generative adversarial network including the feature extractor and an initial feature generator, where the initial feature generator is configured to generate feature vectors of the same dimension as the feature extractor.
The second training module 830 is configured to train the generative adversarial network with the second-category samples to obtain a feature generator.
The second construction module 840 is configured to construct a classification task model including the feature generator and the feature extractor.
The third training module 850 is configured to train the classification task model with the first data set, where the feature generator is configured to amplify the second-category samples in the feature space during training.
In summary, in the technical solutions provided by the embodiments of this application, a feature generator is obtained by training a generative adversarial network, and minority-category samples (that is, the class of training samples that is under-represented in the class-imbalanced data set) are amplified in the feature space through the feature generator. Amplifying at the feature level, rather than simply duplicating the minority-category samples by sample up-sampling, prevents the finally trained classification task model from overfitting and improves its accuracy.
In some possible designs, the generative adversarial network further includes a domain classifier used to distinguish the feature vectors output by the feature extractor from those output by the feature generator, and the second training module 830 is configured to: perform a first parameter update and a second parameter update in each round of training of the generative adversarial network, where the first parameter update includes: assigning a first label to the input of the feature extractor and a second label to the input of the feature generator; computing the first loss function value of the domain classifier; and updating the parameters of the domain classifier according to the first loss function value; and the second parameter update includes: masking the input of the feature extractor and assigning the first label to the input of the feature generator; computing the second loss function value of the domain classifier; and updating the parameters of the feature generator according to the second loss function value.
In some possible designs, the input of the initial feature generator includes a superposition of prior data and noise data, where the prior data is extracted either from the second-category samples of the first data set or from samples of the same category as the second-category samples in a second data set.
In some possible designs, the classification task model further includes a data cleaning unit, and the third training module is further configured to filter, through the data cleaning unit, the abnormal feature vectors output by the feature generator and the feature extractor.
In some possible designs, the third training module is further configured to: through the data cleaning unit, filter out, from the feature vectors output by the feature generator and the feature extractor, feature vector pairs that meet a preset condition, where such a pair includes two feature vectors with different labels and a similarity greater than a threshold; and filter out these pairs as the abnormal feature vectors.
In some possible designs, the first training module 810 is configured to: construct an initial classification task model including the initial feature extractor and an initial classifier; and train the initial classification task model with the first data set to obtain the feature extractor.
It should be noted that when the apparatus provided in the above embodiments implements its functions, the division into the above functional modules is only an example; in practical applications, the above functions may be allocated to different functional modules as needed, that is, the internal structure of the device may be divided into different functional modules to complete all or part of the functions described above. In addition, the apparatus embodiments and method embodiments provided above belong to the same concept; for the specific implementation process, see the method embodiments, which will not be repeated here.
Please refer to FIG. 9, which shows a schematic structural diagram of a computer device provided by an embodiment of this application. The computer device may be any electronic device with data processing and storage functions, such as a PC or a server, and is used to implement the training method of the classification task model provided in the above embodiments. Specifically:
The computer device 900 includes a central processing unit (CPU) 901, a system memory 904 including a random access memory (RAM) 902 and a read-only memory (ROM) 903, and a system bus 905 connecting the system memory 904 and the central processing unit 901. The computer device 900 also includes a basic input/output system (I/O system) 906 that helps transfer information between the various components in the computer, and a mass storage device 907 for storing an operating system 913, application programs 914, and other program modules 915.
The basic input/output system 906 includes a display 908 for displaying information and an input device 909, such as a mouse or keyboard, for the user to input information. The display 908 and the input device 909 are both connected to the central processing unit 901 through an input/output controller 910 connected to the system bus 905. The basic input/output system 906 may also include the input/output controller 910 for receiving and processing input from multiple other devices such as a keyboard, mouse, or electronic stylus. Similarly, the input/output controller 910 also provides output to a display screen, printer, or other type of output device.
The mass storage device 907 is connected to the central processing unit 901 through a mass storage controller (not shown) connected to the system bus 905. The mass storage device 907 and its associated computer-readable media provide non-volatile storage for the computer device 900. That is, the mass storage device 907 may include a computer-readable medium (not shown) such as a hard disk or CD-ROM drive.
Without loss of generality, the computer-readable media may include computer storage media and communication media. Computer storage media include volatile and non-volatile, removable and non-removable media implemented in any method or technology for storing information such as computer-readable instructions, data structures, program modules, or other data. Computer storage media include RAM, ROM, EPROM, EEPROM, flash memory or other solid-state storage technologies, CD-ROM, DVD or other optical storage, tape cartridges, magnetic tape, disk storage, or other magnetic storage devices. Of course, those skilled in the art will know that the computer storage media are not limited to the above. The system memory 904 and the mass storage device 907 above may be collectively referred to as memory.
According to various embodiments of this application, the computer device 900 may also run by connecting, through a network such as the Internet, to a remote computer on the network. That is, the computer device 900 can connect to the network 912 through the network interface unit 911 connected to the system bus 905, or the network interface unit 911 can be used to connect to other types of networks or remote computer systems (not shown).
The memory stores at least one instruction, at least one program, a code set, or an instruction set, which is configured to be executed by one or more processors to implement the training method of the classification task model provided in the above embodiments.
In an exemplary embodiment, an embodiment of this application further provides a computer-readable storage medium for storing a computer program used to execute the training method of the classification task model provided in the above embodiments. In an exemplary embodiment, the computer-readable storage medium may be a ROM, RAM, CD-ROM, magnetic tape, floppy disk, optical data storage device, or the like.
In an exemplary embodiment, a computer program product is also provided; when executed, it is used to implement the training method of the classification task model provided in the above embodiments.
It should be understood that "plurality" mentioned herein refers to two or more. "And/or" describes the association relationship of associated objects, indicating that three relationships can exist; for example, A and/or B can mean: A alone, both A and B, or B alone. The character "/" generally indicates an "or" relationship between the associated objects. In addition, the step numbering described herein only exemplarily shows one possible execution order among the steps; in some other embodiments, the steps may also be executed out of the numbered order, for example two differently numbered steps executed at the same time, or in the reverse of the order shown in the figure, which is not limited in the embodiments of this application.
The above are only exemplary embodiments of this application and are not intended to limit this application. Any modification, equivalent replacement, improvement, etc. made within the spirit and principles of this application shall be included within the protection scope of this application.

Claims (15)

  1. A method for training a classification task model, the method being executed by a computer device and comprising:
    training an initial feature extractor with a first data set to obtain a feature extractor, wherein the first data set is a class-imbalanced data set comprising first-category samples and second-category samples, the number of first-category samples is greater than the number of second-category samples, and the first data set is determined from medical images;
    constructing a generative adversarial network comprising the feature extractor and an initial feature generator, wherein the initial feature generator is configured to generate feature vectors of the same dimension as the feature extractor;
    training the generative adversarial network with the second-category samples to obtain a feature generator;
    constructing a classification task model comprising the feature generator and the feature extractor;
    training the classification task model with the first data set, wherein the feature generator is configured to amplify the second-category samples in the feature space during training, and the trained classification task model is used to classify lesions in medical images.
  2. The method according to claim 1, wherein the generative adversarial network further comprises a domain classifier configured to distinguish between feature vectors output by the feature extractor and feature vectors output by the feature generator, and the training the generative adversarial network with the second-class samples to obtain a feature generator comprises:
    performing, in each training round of the generative adversarial network, a first parameter update and a second parameter update, the first parameter update comprising:
    assigning a first label to the input of the feature extractor and a second label to the input of the initial feature generator;
    computing a first loss value of the domain classifier; and
    updating parameters of the domain classifier according to the first loss value;
    the second parameter update comprising:
    masking the input of the feature extractor and assigning the first label to the input of the initial feature generator;
    computing a second loss value of the domain classifier; and
    updating parameters of the initial feature generator according to the second loss value.
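The alternating updates of claim 2 can be sketched with plain NumPy. Everything here is an illustrative assumption rather than the patented implementation: the domain classifier is a single logistic unit, the feature generator is a linear map, the "first" and "second" labels are 1 and 0, and the loss is binary cross-entropy with hand-written gradients.

```python
import numpy as np

rng = np.random.default_rng(seed=1)
FEAT_DIM, NOISE_DIM, LR = 16, 8, 0.1

# Domain classifier D(x) = sigmoid(x @ w + b); output 1 means "extractor feature".
w = rng.normal(size=FEAT_DIM) * 0.1
b = 0.0
# Feature generator G(z) = z @ Wg.
Wg = rng.normal(size=(NOISE_DIM, FEAT_DIM)) * 0.1

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def d_out(feats):
    return sigmoid(feats @ w + b)

real_feats = rng.normal(size=(32, FEAT_DIM))   # stand-in for extractor output
noise = rng.normal(size=(32, NOISE_DIM))
fake_feats = noise @ Wg                        # feature generator output

# --- First parameter update: train the domain classifier. ---
# Extractor input gets the first label (1), generator input the second label (0).
feats = np.vstack([real_feats, fake_feats])
y = np.concatenate([np.ones(32), np.zeros(32)])
p = d_out(feats)
grad_logit = p - y                             # d(BCE)/d(logit)
w -= LR * feats.T @ grad_logit / len(y)
b -= LR * grad_logit.mean()

# --- Second parameter update: train the generator, classifier frozen. ---
# Extractor input is masked; generated features now get the first label (1).
fake_feats = noise @ Wg
p_fake = d_out(fake_feats)
grad_logit = p_fake - np.ones(32)              # push D towards "real"
grad_feats = np.outer(grad_logit, w)           # backprop through the frozen D
Wg -= LR * noise.T @ grad_feats / 32
```

The key structural point mirrored from the claim is that the second update reuses the domain classifier's loss but flips the label on the generated features and touches only the generator's parameters.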
  3. The method according to claim 1, wherein the input of the initial feature generator comprises a superposition of prior data and noise data,
    wherein the prior data is extracted from the second-class samples of the first data set, or from samples in a second data set that belong to the same class as the second-class samples.
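A minimal sketch of the superposed generator input of claim 3, under the assumption that the prior data are feature vectors drawn from the minority ("second-class") samples; `minority_feats`, the dimensions, and the noise scale are all hypothetical choices, since the claim fixes neither.

```python
import numpy as np

rng = np.random.default_rng(seed=2)
FEAT_DIM = 16

# Hypothetical minority-class feature vectors, e.g. extracted from the
# second-class samples of the first data set.
minority_feats = rng.normal(loc=2.0, size=(10, FEAT_DIM))

def generator_input(n_samples, noise_scale=0.5):
    """Superpose prior data (sampled minority features) with noise data."""
    idx = rng.integers(0, len(minority_feats), size=n_samples)
    prior = minority_feats[idx]                                  # prior data
    noise = rng.normal(scale=noise_scale, size=(n_samples, FEAT_DIM))
    return prior + noise                                         # superposition

z = generator_input(5)
```

Seeding the generator with class-specific prior data rather than pure noise biases its outputs toward the minority-class region of the feature space.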
  4. The method according to any one of claims 1 to 3, wherein the classification task model further comprises a data cleaning unit, and during the training of the classification task model with the first data set, the method further comprises:
    filtering, by the data cleaning unit, anomalous feature vectors output by the feature generator and the feature extractor.
  5. The method according to claim 4, wherein the filtering, by the data cleaning unit, anomalous feature vectors output by the feature generator and the feature extractor comprises:
    selecting, by the data cleaning unit from the feature vectors output by the feature generator and the feature extractor, feature vector pairs that meet a preset condition, a feature vector pair meeting the preset condition comprising two feature vectors that have different labels and a similarity greater than a threshold; and
    filtering out the feature vector pairs meeting the preset condition as the anomalous feature vectors.
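The data-cleaning rule of claim 5 (drop pairs that carry different labels yet have near-identical features) can be sketched as follows. The function name, the use of cosine similarity, and the threshold value are illustrative assumptions; the claim does not fix a particular similarity metric.

```python
import numpy as np

def clean_features(feats, labels, threshold=0.95):
    """Drop pairs of feature vectors whose labels differ but whose cosine
    similarity exceeds `threshold` (the anomalous pairs of claim 5)."""
    normed = feats / np.linalg.norm(feats, axis=1, keepdims=True)
    sim = normed @ normed.T                      # pairwise cosine similarity
    keep = np.ones(len(feats), dtype=bool)
    for i in range(len(feats)):
        for j in range(i + 1, len(feats)):
            if labels[i] != labels[j] and sim[i, j] > threshold:
                keep[i] = keep[j] = False        # filter out both vectors
    return feats[keep], labels[keep]

# Toy batch: the last two vectors are nearly identical but labelled differently,
# so both should be removed as anomalous.
feats = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [1.0, 1.001]])
labels = np.array([0, 1, 0, 1])
clean_feats, clean_labels = clean_features(feats, labels)
```

Such pairs typically indicate either a mislabelled real sample or a generated vector that landed on the wrong side of the class boundary, so removing both is the conservative choice.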
  6. The method according to any one of claims 1 to 3, wherein the training an initial feature extractor with a first data set to obtain a feature extractor comprises:
    constructing an initial classification task model comprising the initial feature extractor; and
    training the initial classification task model with the first data set to obtain the feature extractor.
  7. An apparatus for training a classification task model, the apparatus comprising:
    a first training module, configured to train an initial feature extractor with a first data set to obtain a feature extractor, wherein the first data set is a class-imbalanced data set comprising samples of a first class and samples of a second class, the number of the first-class samples is greater than the number of the second-class samples, and the first data set is determined from medical images;
    a first construction module, configured to construct a generative adversarial network comprising the feature extractor and an initial feature generator, wherein the initial feature generator is configured to generate feature vectors of the same dimension as those output by the feature extractor;
    a second training module, configured to train the generative adversarial network with the second-class samples to obtain a feature generator;
    a second construction module, configured to construct a classification task model comprising the trained feature generator and the feature extractor; and
    a third training module, configured to train the classification task model with the first data set, wherein the feature generator is configured to augment the second-class samples in feature space during the training.
  8. The apparatus according to claim 7, wherein the generative adversarial network further comprises a domain classifier configured to distinguish between feature vectors output by the feature extractor and feature vectors output by the feature generator, and the second training module is configured to:
    perform, in each training round of the generative adversarial network, a first parameter update and a second parameter update, the first parameter update comprising:
    assigning a first label to the input of the feature extractor and a second label to the input of the initial feature generator;
    computing a first loss value of the domain classifier; and
    updating parameters of the domain classifier according to the first loss value;
    the second parameter update comprising:
    masking the input of the feature extractor and assigning the first label to the input of the initial feature generator;
    computing a second loss value of the domain classifier; and
    updating parameters of the initial feature generator according to the second loss value.
  9. The apparatus according to claim 7, wherein the input of the initial feature generator comprises a superposition of prior data and noise data,
    wherein the prior data is extracted from the second-class samples of the first data set, or from samples in a second data set that belong to the same class as the second-class samples.
  10. The apparatus according to any one of claims 7 to 9, wherein the classification task model further comprises a data cleaning unit, and the third training module is further configured to:
    filter, by the data cleaning unit, anomalous feature vectors output by the feature generator and the feature extractor.
  11. The apparatus according to claim 10, wherein the third training module is further configured to:
    select, by the data cleaning unit from the feature vectors output by the feature generator and the feature extractor, feature vector pairs that meet a preset condition, a feature vector pair meeting the preset condition comprising two feature vectors that have different labels and a similarity greater than a threshold; and
    filter out the feature vector pairs meeting the preset condition as the anomalous feature vectors.
  12. The apparatus according to any one of claims 7 to 9, wherein the first training module is further configured to:
    construct an initial classification task model comprising the initial feature extractor; and
    train the initial classification task model with the first data set to obtain the feature extractor.
  13. A computer device, comprising:
    a processor, a communication interface, a memory, and a communication bus,
    wherein the processor, the communication interface, and the memory communicate with one another through the communication bus, and the communication interface is an interface of a communication module;
    the memory is configured to store program code and transmit the program code to the processor; and
    the processor is configured to invoke instructions of the program code in the memory to perform the method according to any one of claims 1 to 6.
  14. A computer-readable storage medium storing a computer program, the computer program being used to perform the method according to any one of claims 1 to 6.
  15. A computer program product comprising instructions, which, when run on a computer, causes the computer to perform the method according to any one of claims 1 to 6.
PCT/CN2020/085006 2019-05-07 2020-04-16 Classification task model training method, apparatus and device and storage medium WO2020224403A1 (zh)

Priority Applications (2)

Application Number Priority Date Filing Date Title
EP20802264.0A EP3968222B1 (en) 2019-05-07 2020-04-16 Classification task model training method, apparatus and device and storage medium
US17/355,310 US20210319258A1 (en) 2019-05-07 2021-06-23 Method and apparatus for training classification task model, device, and storage medium

Applications Claiming Priority (2)

Application Number Priority Date Filing Date Title
CN201910377510.3A CN110097130B (zh) 2019-05-07 2019-05-07 Classification task model training method, apparatus and device and storage medium
CN201910377510.3 2019-05-07

Related Child Applications (1)

Application Number Title Priority Date Filing Date
US17/355,310 Continuation US20210319258A1 (en) 2019-05-07 2021-06-23 Method and apparatus for training classification task model, device, and storage medium

Publications (1)

Publication Number Publication Date
WO2020224403A1 true WO2020224403A1 (zh) 2020-11-12

Family

ID=67447198

Family Applications (1)

Application Number Title Priority Date Filing Date
PCT/CN2020/085006 WO2020224403A1 (zh) 2019-05-07 2020-04-16 Classification task model training method, apparatus and device and storage medium

Country Status (4)

Country Link
US (1) US20210319258A1 (zh)
EP (1) EP3968222B1 (zh)
CN (1) CN110097130B (zh)
WO (1) WO2020224403A1 (zh)

Cited By (7)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN112466436A (zh) * 2020-11-25 2021-03-09 北京小白世纪网络科技有限公司 Recurrent-neural-network-based method and apparatus for training an intelligent traditional Chinese medicine prescription model
CN112463972A (zh) * 2021-01-28 2021-03-09 成都数联铭品科技有限公司 Sample classification method based on class imbalance
CN112905325A (zh) * 2021-02-10 2021-06-04 山东英信计算机技术有限公司 Method, system, and medium for accelerating training with distributed data caching
CN113642621A (zh) * 2021-08-03 2021-11-12 南京邮电大学 Zero-shot image classification method based on generative adversarial networks
CN114360008A (zh) * 2021-12-23 2022-04-15 上海清鹤科技股份有限公司 Face authentication model generation method, authentication method, device, and storage medium
CN114545255A (zh) * 2022-01-18 2022-05-27 广东工业大学 Lithium battery SOC estimation method based on a competitive generative adversarial neural network
CN116934385A (zh) * 2023-09-15 2023-10-24 山东理工昊明新能源有限公司 Method for building a user churn prediction model, and user churn prediction method and apparatus

Families Citing this family (21)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN110598840B (zh) * 2018-06-13 2023-04-18 富士通株式会社 Knowledge transfer method, information processing device, and storage medium
CN110097130B (zh) * 2019-05-07 2022-12-13 深圳市腾讯计算机系统有限公司 Classification task model training method, apparatus and device and storage medium
CN110570492B (zh) * 2019-09-11 2021-09-03 清华大学 Neural-network-based CT artifact suppression method, device, and medium
CN110888911A (zh) * 2019-10-11 2020-03-17 平安科技(深圳)有限公司 Sample data processing method and apparatus, computer device, and storage medium
CN110732139B (zh) * 2019-10-25 2024-03-05 腾讯科技(深圳)有限公司 Detection model training method, and user data detection method and apparatus
CN110706738B (zh) * 2019-10-30 2020-11-20 腾讯科技(深圳)有限公司 Protein structure information prediction method, apparatus, device, and storage medium
CN110807332B (zh) * 2019-10-30 2024-02-27 腾讯科技(深圳)有限公司 Semantic understanding model training method, semantic processing method, apparatus, and storage medium
CN111126503B (zh) * 2019-12-27 2023-09-26 北京同邦卓益科技有限公司 Training sample generation method and apparatus
CN111241969A (zh) * 2020-01-06 2020-06-05 北京三快在线科技有限公司 Target detection method and apparatus, and corresponding model training method and apparatus
CN111444967B (zh) * 2020-03-30 2023-10-31 腾讯科技(深圳)有限公司 Generative adversarial network training method, generation method, apparatus, device, and medium
CN111582647A (zh) * 2020-04-09 2020-08-25 上海淇毓信息科技有限公司 User data processing method and apparatus, and electronic device
CN111291841B (zh) * 2020-05-13 2020-08-21 腾讯科技(深圳)有限公司 Image recognition model training method and apparatus, computer device, and storage medium
CN111832404B (zh) * 2020-06-04 2021-05-18 中国科学院空天信息创新研究院 Few-shot remote sensing ground-object classification method and system based on a feature generation network
CN111950656B (zh) * 2020-08-25 2021-06-25 深圳思谋信息科技有限公司 Image recognition model generation method and apparatus, computer device, and storage medium
CN114666188A (zh) * 2020-12-24 2022-06-24 华为技术有限公司 Information generation method and related apparatus
CN113723519B (zh) * 2021-08-31 2023-07-25 平安科技(深圳)有限公司 Contrastive-learning-based electrocardiogram data processing method, apparatus, and storage medium
CN113610191B (zh) * 2021-09-07 2023-08-29 中原动力智能机器人有限公司 Garbage classification model modeling method and garbage classification method
CN113869398B (zh) * 2021-09-26 2024-06-21 平安科技(深圳)有限公司 Imbalanced text classification method, apparatus, device, and storage medium
CN114186617B (zh) * 2021-11-23 2022-08-30 浙江大学 Mechanical fault diagnosis method based on distributed deep learning
CN113902131B (zh) * 2021-12-06 2022-03-08 中国科学院自动化研究所 Update method for node models resisting discrimination propagation in federated learning
US11983363B1 (en) * 2023-02-09 2024-05-14 Primax Electronics Ltd. User gesture behavior simulation system and user gesture behavior simulation method applied thereto

Citations (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
WO2017215284A1 (zh) * 2016-06-14 2017-12-21 山东大学 Convolutional-neural-network-based method for processing microscopic hyperspectral images of gastrointestinal tumors
CN108537743A (zh) * 2018-03-13 2018-09-14 杭州电子科技大学 Facial image enhancement method based on a generative adversarial network
CN108763874A (zh) * 2018-05-25 2018-11-06 南京大学 Chromosome classification method and apparatus based on a generative adversarial network
CN109165666A (zh) * 2018-07-05 2019-01-08 南京旷云科技有限公司 Multi-label image classification method, apparatus, device, and storage medium
CN110097130A (zh) * 2019-05-07 2019-08-06 深圳市腾讯计算机系统有限公司 Classification task model training method, apparatus and device and storage medium

Family Cites Families (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN106650721B (zh) * 2016-12-28 2019-08-13 吴晓军 Industrial character recognition method based on a convolutional neural network
JP6928371B2 (ja) * 2017-08-01 2021-09-01 国立研究開発法人情報通信研究機構 Classifier, method for training the classifier, and classification method using the classifier
US11120337B2 * 2017-10-20 2021-09-14 Huawei Technologies Co., Ltd. Self-training method and system for semi-supervised learning with generative adversarial networks
CN108805188B (zh) * 2018-05-29 2020-08-21 徐州工程学院 Image classification method based on a feature-recalibration generative adversarial network
CN109522973A (zh) * 2019-01-17 2019-03-26 云南大学 Medical big data classification method and system based on generative adversarial networks and semi-supervised learning


Non-Patent Citations (1)

* Cited by examiner, † Cited by third party
Title
See also references of EP3968222A4

Cited By (10)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN112466436A (zh) * 2020-11-25 2021-03-09 北京小白世纪网络科技有限公司 Recurrent-neural-network-based method and apparatus for training an intelligent traditional Chinese medicine prescription model
CN112466436B (zh) * 2020-11-25 2024-02-23 北京小白世纪网络科技有限公司 Recurrent-neural-network-based method and apparatus for training an intelligent traditional Chinese medicine prescription model
CN112463972A (zh) * 2021-01-28 2021-03-09 成都数联铭品科技有限公司 Sample classification method based on class imbalance
CN112463972B (zh) * 2021-01-28 2021-05-18 成都数联铭品科技有限公司 Text sample classification method based on class imbalance
CN112905325A (zh) * 2021-02-10 2021-06-04 山东英信计算机技术有限公司 Method, system, and medium for accelerating training with distributed data caching
CN113642621A (zh) * 2021-08-03 2021-11-12 南京邮电大学 Zero-shot image classification method based on generative adversarial networks
CN114360008A (zh) * 2021-12-23 2022-04-15 上海清鹤科技股份有限公司 Face authentication model generation method, authentication method, device, and storage medium
CN114545255A (zh) * 2022-01-18 2022-05-27 广东工业大学 Lithium battery SOC estimation method based on a competitive generative adversarial neural network
CN116934385A (zh) * 2023-09-15 2023-10-24 山东理工昊明新能源有限公司 Method for building a user churn prediction model, and user churn prediction method and apparatus
CN116934385B (zh) * 2023-09-15 2024-01-19 山东理工昊明新能源有限公司 Method for building a user churn prediction model, and user churn prediction method and apparatus

Also Published As

Publication number Publication date
CN110097130A (zh) 2019-08-06
CN110097130B (zh) 2022-12-13
EP3968222B1 (en) 2024-01-17
EP3968222A1 (en) 2022-03-16
US20210319258A1 (en) 2021-10-14
EP3968222A4 (en) 2022-06-29

Similar Documents

Publication Publication Date Title
WO2020224403A1 (zh) Classification task model training method, apparatus and device and storage medium
JP7058373B2 (ja) Method, apparatus, device, and storage medium for detecting and locating lesions in medical images
US11798132B2 (en) Image inpainting method and apparatus, computer device, and storage medium
JP6678778B2 (ja) 画像内の物体を検出する方法及び物体検出システム
US11232286B2 (en) Method and apparatus for generating face rotation image
US20210342643A1 (en) Method, apparatus, and electronic device for training place recognition model
WO2021227726A1 (zh) Methods, apparatuses, and devices for training face detection and image detection neural networks
JP6798183B2 (ja) Image analysis apparatus, image analysis method, and program
CN112418390B (zh) 使用单调属性函数对图像进行对比解释
TWI721510B (zh) Depth estimation method for binocular images, device, and storage medium
CN110555481A (zh) Portrait style recognition method and apparatus, and computer-readable storage medium
CN111368672A (zh) Method and apparatus for building a facial recognition model for genetic disorders
EP4322056A1 (en) Model training method and apparatus
WO2021218238A1 (zh) Image processing method and image processing apparatus
US11830187B2 (en) Automatic condition diagnosis using a segmentation-guided framework
US11790492B1 (en) Method of and system for customized image denoising with model interpretations
KR20190126857A (ko) 이미지에서 오브젝트 검출 및 표현
WO2019146057A1 (ja) Learning device, generation system for a real-image classification device, generation device for a real-image classification device, learning method, and program
CN110222718A (zh) Image processing method and apparatus
EP4006777A1 (en) Image classification method and device
WO2020063835A1 (zh) Model generation
CN111862040B (zh) Portrait image quality evaluation method, apparatus, device, and storage medium
CN116452810A (zh) Multi-level semantic segmentation method and apparatus, electronic device, and storage medium
Lee et al. Background subtraction using the factored 3-way restricted Boltzmann machines
WO2023231753A1 (zh) Neural network training method, data processing method, and device

Legal Events

Date Code Title Description
121 Ep: the epo has been informed by wipo that ep was designated in this application

Ref document number: 20802264

Country of ref document: EP

Kind code of ref document: A1

NENP Non-entry into the national phase

Ref country code: DE

ENP Entry into the national phase

Ref document number: 2020802264

Country of ref document: EP

Effective date: 20211207