WO2023202596A1 - Semi-supervised model training method, system and related device - Google Patents

Semi-supervised model training method, system and related device

Info

Publication number
WO2023202596A1
Authority
WO
WIPO (PCT)
Prior art keywords
pseudo
model
label
sample
pseudo label
Prior art date
Application number
PCT/CN2023/089098
Other languages
English (en)
French (fr)
Inventor
徐强
徐晓忻
黄全充
傅蓉蓉
蔡晨东
纪荣嵘
周奕毅
罗根
Original Assignee
华为技术有限公司
Priority date
Filing date
Publication date
Application filed by 华为技术有限公司
Publication of WO2023202596A1

Classifications

    • G Physics
    • G06 Computing; Calculating or Counting
    • G06F Electric digital data processing
    • G06F18/00 Pattern recognition
    • G06F18/20 Analysing
    • G06F18/24 Classification techniques
    • G06F18/241 Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
    • G06F18/21 Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214 Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • G06F18/25 Fusion techniques

Definitions

  • This application relates to the technical field of artificial intelligence (AI), and in particular to a semi-supervised model training method, system and related equipment.
  • Semi-supervised learning refers to a method of training an AI model (also referred to as a "model" in this application) using both labeled samples and unlabeled samples. Through semi-supervised learning, the number of labeled samples required can be effectively reduced, which lowers the cost of model training.
  • Semi-supervised learning typically adopts a first-model/second-model training scheme (also called the teacher-student scheme, where the first model is the teacher model and the second model is the student model).
  • Specifically, the first model is trained using labeled samples; the unlabeled samples are then input into the trained first model to infer pseudo-labels for the unlabeled samples; the weight parameters of the first model are copied to a second model with the same structure; the pseudo-labels and unlabeled samples are used to train the second model to obtain updated weight parameters of the second model; and some of the weight parameters of the second model are then updated back into the first model.
  • The updated first model then infers new pseudo-labels, which are used to train the second model again, and so on. In this way, the weight parameters of the first model are updated stably, so that the first model finally obtained through training is more robust and performs better.
  • However, the pseudo-labeled samples have a great impact on the training of the second model: if the pseudo-labels generated by the first model contain large errors, the training efficiency of the second model is reduced and its performance is poor, which in turn degrades the training efficiency and performance of the final first model.
  • This application provides a semi-supervised model training method, system and related equipment to solve the problem of low model training efficiency and poor model performance caused by poor pseudo-label quality in the semi-supervised learning process.
  • In a first aspect, a semi-supervised model training method is provided, which may include the following steps: inputting a first unlabeled sample into a first model to obtain a first pseudo-label of the first unlabeled sample; inputting a first expanded sample into the first model to obtain a second pseudo-label of the first expanded sample, where the first model is an artificial intelligence (AI) model trained using labeled samples and the first expanded sample is a sample obtained by performing data enhancement on the first unlabeled sample; obtaining a third pseudo-label of the first unlabeled sample based on the first pseudo-label and the second pseudo-label; and training a second model using the first unlabeled sample and the third pseudo-label, where the second model is an AI model obtained based on the weight parameters of the first model.
  • By implementing the method described in the first aspect, the first expanded sample is obtained by performing data enhancement on the first unlabeled sample, and the first unlabeled sample and the first expanded sample are then input into the first model for inference to obtain the first pseudo-label of the first unlabeled sample and the second pseudo-label of the first expanded sample; the third pseudo-label is then obtained based on the first pseudo-label and the second pseudo-label. The third pseudo-label obtained in this way is of higher quality, so when it is used for semi-supervised training of the second model, the training efficiency and the performance of the model are improved, which in turn improves the training efficiency and performance of the finally obtained first model. The overall flow of this step is sketched below.
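  • The following is a minimal Python sketch of the pseudo-labelling step described above; the callables model, augment, invert, match and fuse, and the threshold, are hypothetical placeholders for illustration rather than elements defined in this application:

      def generate_third_pseudo_label(model, unlabeled_sample, augment, invert, match, fuse, threshold):
          """One pseudo-labelling step: infer on the raw and the augmented sample,
          and keep a fused label only if the two predictions agree well enough."""
          # First pseudo-label: inference on the original unlabeled sample.
          first_pseudo = model(unlabeled_sample)
          # Second pseudo-label: inference on the expanded (augmented) sample.
          expanded_sample = augment(unlabeled_sample)
          second_pseudo = model(expanded_sample)
          # Map the second pseudo-label back to the original coordinate frame
          # (the "reverse operation" of the augmentation) and compare it with the first.
          fourth_pseudo = invert(second_pseudo)
          degree = match(first_pseudo, fourth_pseudo)
          if degree > threshold:
              # High agreement: fuse the two predictions into the third pseudo-label.
              return fuse(first_pseudo, fourth_pseudo)
          # Low agreement: discard both pseudo-labels for this sample.
          return None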
  • In a possible implementation, the data enhancement method may include, but is not limited to, one or more of flip transformation (flip), translation transformation (shift), scale transformation (scale), rotation/reflection transformation (rotation/reflection), zoom transformation (zoom), cropping (crop), color transformation (color space), noise perturbation (noise), and kernel filters (kernel filters).
  • Flip transformation refers to flipping the image horizontally or vertically; horizontal flipping can be further divided into upward and downward horizontal flipping, and vertical flipping into left and right vertical flipping.
  • Translation transformation refers to translating the image, for example to the right in the x direction (by xoffset) or downward in the y direction (by yoffset), where the x and y directions refer to the horizontal and vertical axes of the image coordinate system.
  • Rotation transformation, which can also be called reflection transformation, refers to rotating the image by a certain angle, which can be any angle from 0 to 360 degrees. Scaling transformation refers to enlarging or reducing the image by a certain ratio without changing its content. Cropping includes uniform cropping and random cropping: uniform cropping crops images of different sizes to a set size, while random cropping crops images of different sizes to different sizes.
  • Color transformation refers to modifying a color channel of the image, such as turning the channel off or changing its brightness value; for example, an image usually includes three RGB channels, and color transformation can reduce or increase the R channel value.
  • Noise perturbation refers to adding a matrix of random values sampled from a Gaussian distribution to the RGB pixel matrix of the image. Kernel filtering refers to convolving the image with a kernel filter that has a specific function, such as a sharpening or blurring kernel.
  • It should be noted that the above data enhancement methods are for illustration only. This application may also expand the first unlabeled sample into the first expanded sample through other data enhancement methods, for example adversarial training, feature space augmentation, GAN-based data augmentation and other methods, which will not be described one by one here. A minimal sketch of several of the basic transformations is given below.
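  • A minimal NumPy sketch of a few of the geometric and appearance transformations listed above, assuming an image stored as an H x W x 3 uint8 array (the function names are illustrative only):

      import numpy as np

      def flip_horizontal(img):
          # Flip transformation: mirror the image along its vertical axis.
          return img[:, ::-1]

      def shift(img, xoffset, yoffset):
          # Translation transformation: move right by xoffset and down by yoffset
          # (wrap-around is used here purely for brevity; padding is also common).
          return np.roll(np.roll(img, yoffset, axis=0), xoffset, axis=1)

      def color_scale_channel(img, channel, factor):
          # Color transformation: raise or lower the brightness of one RGB channel.
          out = img.astype(np.float32)
          out[..., channel] *= factor
          return np.clip(out, 0, 255).astype(img.dtype)

      def gaussian_noise(img, sigma=10.0):
          # Noise perturbation: add a matrix of values sampled from a Gaussian
          # distribution to the RGB pixel matrix.
          noise = np.random.normal(0.0, sigma, img.shape)
          return np.clip(img.astype(np.float32) + noise, 0, 255).astype(img.dtype)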
  • Since the label of a target detection model is usually a detection box (bounding box), when the first model is a target detection model, data enhancement methods that affect the detection box, such as flip transformation, translation transformation, scale transformation, rotation transformation and scaling transformation, can be used to perform data enhancement on the first unlabeled sample. Since the label of an image recognition model is the probability distribution over the categories to which the image may belong, when the first model is an image recognition model, data enhancement methods that affect the determination of the image category, such as cropping, color transformation, noise perturbation and kernel filtering, can be used to perform data enhancement on the first unlabeled sample.
  • In the above implementation, performing different data enhancement operations according to the model type allows the resulting expanded samples to increase the generalization ability of the model and improve its robustness, as sketched below.
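  • For instance, a hypothetical mapping from model type to suitable augmentation methods, following the reasoning above, could look as follows:

      # Illustrative only: which augmentations suit which task, per the text above.
      AUGMENTATIONS_BY_TASK = {
          "object_detection": ["flip", "shift", "scale", "rotation", "zoom"],
          "image_recognition": ["crop", "color_space", "noise", "kernel_filter"],
      }

      def pick_augmentations(task):
          # Fall back to an empty list for unknown task types.
          return AUGMENTATIONS_BY_TASK.get(task, [])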
  • In a possible implementation, the matching degree between the first pseudo-label and the second pseudo-label can be obtained according to the first pseudo-label and the second pseudo-label; when the matching degree is higher than a threshold, the first pseudo-label and the second pseudo-label are fused to obtain the third pseudo-label.
  • When the first model is a target detection model, its output may consist of multiple candidate detection boxes. The non-maximum suppression (NMS) method can be used to select the most accurate detection box among them as the first pseudo-label or the second pseudo-label, thereby increasing the accuracy of the first pseudo-label of the first unlabeled sample and of the second pseudo-label of the expanded sample. A sketch of a standard NMS routine is given below.
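  • This application does not prescribe a particular NMS implementation; the following is a standard NumPy sketch that keeps the highest-scoring boxes and suppresses boxes that overlap them too strongly:

      import numpy as np

      def nms(boxes, scores, iou_threshold=0.5):
          """Non-maximum suppression. boxes is an (N, 4) array of (x1, y1, x2, y2),
          scores is an (N,) array; returns indices of the kept boxes."""
          order = scores.argsort()[::-1]
          keep = []
          while order.size > 0:
              i = order[0]
              keep.append(int(i))
              if order.size == 1:
                  break
              rest = order[1:]
              # Intersection of the current box with all remaining boxes.
              xx1 = np.maximum(boxes[i, 0], boxes[rest, 0])
              yy1 = np.maximum(boxes[i, 1], boxes[rest, 1])
              xx2 = np.minimum(boxes[i, 2], boxes[rest, 2])
              yy2 = np.minimum(boxes[i, 3], boxes[rest, 3])
              inter = np.maximum(0.0, xx2 - xx1) * np.maximum(0.0, yy2 - yy1)
              area_i = (boxes[i, 2] - boxes[i, 0]) * (boxes[i, 3] - boxes[i, 1])
              area_r = (boxes[rest, 2] - boxes[rest, 0]) * (boxes[rest, 3] - boxes[rest, 1])
              iou = inter / (area_i + area_r - inter)
              # Keep only boxes whose overlap with the kept box is below the threshold.
              order = rest[iou < iou_threshold]
          return keep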
  • Specifically, the reverse operation of the data enhancement can be performed on the second pseudo-label to obtain a fourth pseudo-label; the fourth pseudo-label is then matched with the first pseudo-label to obtain a matching result between the first pseudo-label and the fourth pseudo-label, and the matching degree is determined based on this matching result. The reverse operation refers to the operation opposite to the data enhancement that was applied. For example, if the data enhancement performed a horizontal upward flip on the first unlabeled sample to obtain the first expanded sample, the detection box corresponding to the second pseudo-label of the first expanded sample can be flipped horizontally downward to obtain the fourth pseudo-label.
  • As another example, if the data enhancement performed a 90° rightward rotation on the first unlabeled sample to obtain the first expanded sample, the detection box corresponding to the second pseudo-label of the first expanded sample can be rotated 90° to the left to obtain the fourth pseudo-label, and so on.
  • The pseudo-label of a target detection model is the detection box of a target in the image. In the first expanded sample obtained through data enhancement, the position of the target has changed, so the detection box needs to be reverse-operated first, so that the first pseudo-label and the second pseudo-label frame the target at the same position before being matched. In this way, first pseudo-labels that locate the target inaccurately can be screened out, which avoids using wrong or low-precision pseudo-labels during semi-supervised training of the second model, thereby improving training efficiency and the accuracy of the final first model.
  • Specifically, the detection box corresponding to the first pseudo-label can be matched with the detection box corresponding to the fourth pseudo-label to obtain the matching degree, which here can be the intersection over union (IoU) between the two detection boxes, as sketched below.
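  • The following sketch, assuming boxes given as (x1, y1, x2, y2) tuples, shows one way to undo a left-right mirror flip on a detection box and to compute the IoU used as the matching degree; it is an illustration under these assumptions, not the exact formulas of this application:

      def invert_horizontal_flip(box, image_width):
          # Reverse operation for a left-right mirror flip: map the box back
          # into the original image's coordinate frame.
          x1, y1, x2, y2 = box
          return (image_width - x2, y1, image_width - x1, y2)

      def iou(box_a, box_b):
          # Intersection over union between two boxes given as (x1, y1, x2, y2).
          ix1, iy1 = max(box_a[0], box_b[0]), max(box_a[1], box_b[1])
          ix2, iy2 = min(box_a[2], box_b[2]), min(box_a[3], box_b[3])
          inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
          area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
          area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
          return inter / (area_a + area_b - inter)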
  • When the matching degree is greater than the threshold, the first pseudo-label and the second pseudo-label can be fused to obtain the third pseudo-label; specifically, the detection box corresponding to the fourth pseudo-label and the detection box corresponding to the first pseudo-label can be averaged coordinate-wise to obtain the third pseudo-label.
  • In the above implementation, the detection box corresponding to the first pseudo-label and the detection box corresponding to the fourth pseudo-label are fused. Because the two detection boxes are obtained in different ways, fusing them further improves the accuracy of the resulting third pseudo-label, and hence the accuracy of the pseudo-labels used in the subsequent semi-supervised training process, which improves training efficiency and the accuracy of the finally obtained first model. A sketch of such a fusion is given below.
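  • A minimal sketch of this coordinate-wise averaging; the box format is assumed to be (x1, y1, x2, y2) and the numeric values are made up for illustration:

      def fuse_boxes(first_box, fourth_box):
          # Fuse the detection box of the first pseudo-label with the
          # reverse-transformed box of the second pseudo-label by averaging
          # each coordinate; the result serves as the third pseudo-label's box.
          return tuple((a + b) / 2.0 for a, b in zip(first_box, fourth_box))

      # Two nearly agreeing boxes average to a single fused box.
      fused = fuse_boxes((10, 20, 110, 220), (14, 18, 114, 218))
      # fused == (12.0, 19.0, 112.0, 219.0)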
  • When the first model is an image recognition model, the first pseudo-label and the second pseudo-label can be matched directly to obtain the matching result between them, and the matching degree is determined according to this result. Specifically, the probability distribution corresponding to the first pseudo-label can be matched with the probability distribution corresponding to the second pseudo-label, and the matching degree can be the similarity or the distance between the two probability distributions, which is not specifically limited in this application. For fusion, the probability distribution of the first pseudo-label and that of the second pseudo-label can be combined by averaging, for example a simple or weighted average.
  • In the above implementation, when the first model is an image recognition model, the first pseudo-label and the second pseudo-label are matched, and if the matching degree is greater than the threshold, the two are fused. This further improves the accuracy of the finally obtained third pseudo-label, and hence the accuracy of the pseudo-labels used in the subsequent semi-supervised training process, thereby improving training efficiency and the accuracy of the finally obtained first model. A minimal sketch is given below.
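  • A minimal sketch for the image recognition case, using cosine similarity as one possible matching degree and a weighted average as the fusion; both choices and the threshold are assumptions, since the application leaves the exact similarity or distance measure open:

      import numpy as np

      def match_distributions(p, q):
          # One possible matching degree: cosine similarity between the two
          # predicted class-probability distributions.
          return float(np.dot(p, q) / (np.linalg.norm(p) * np.linalg.norm(q)))

      def fuse_distributions(p, q, weight=0.5):
          # Fuse by (weighted) averaging and renormalising so it remains a distribution.
          fused = weight * p + (1.0 - weight) * q
          return fused / fused.sum()

      p = np.array([0.7, 0.2, 0.1])   # first pseudo-label (original sample)
      q = np.array([0.6, 0.3, 0.1])   # second pseudo-label (expanded sample)
      if match_distributions(p, q) > 0.9:       # hypothetical threshold
          third_pseudo_label = fuse_distributions(p, q)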
  • In a possible implementation, a second unlabeled sample can also be input into the first model to obtain a fifth pseudo-label of the second unlabeled sample, and a second expanded sample can be input into the first model to obtain a sixth pseudo-label of the second expanded sample, where the second expanded sample is a sample obtained by performing data enhancement on the second unlabeled sample. The matching degree between the fifth pseudo-label and the sixth pseudo-label can then be obtained based on the fifth pseudo-label and the sixth pseudo-label; when this matching degree is not higher than the above threshold, the fifth pseudo-label and the sixth pseudo-label are deleted. The method for determining the matching degree between the fifth and sixth pseudo-labels is the same as that for the first and second pseudo-labels described above, and is not repeated here.
  • Similarly, if the matching degree between the first pseudo-label and the second pseudo-label is not higher than the threshold, the first pseudo-label and the second pseudo-label can also be deleted; and if the matching degree between the fifth pseudo-label and the sixth pseudo-label is higher than the threshold, the fifth and sixth pseudo-labels can be fused in the same way as the first and second pseudo-labels are fused to obtain the third pseudo-label, which is not repeated here. Simply put, the unlabeled sample set and the expanded sample set are matched pairwise: for each unlabeled sample, its pseudo-label is matched with the pseudo-label of the corresponding expanded sample to obtain a matching degree; if the matching degree is higher than the threshold the two pseudo-labels are fused, and if it is not higher than the threshold both pseudo-labels are deleted.
  • The above implementation filters out pseudo-labels with low accuracy by matching the pseudo-labels of the unlabeled sample set against those of the expanded sample set, which avoids using wrong or low-accuracy pseudo-labels during semi-supervised training of the second model, thereby improving training efficiency and the accuracy of the final first model. A sketch of this filter-and-fuse step over a batch of samples is given below.
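  • A minimal sketch of this filter-and-fuse pass over a batch of unlabeled samples, with match and fuse as hypothetical placeholders for the matching and fusion described above:

      def filter_and_fuse(pseudo_pairs, match, fuse, threshold):
          """pseudo_pairs: list of (sample, label_from_original, label_from_expanded).
          Keeps only samples whose two pseudo-labels agree, with fused labels."""
          kept = []
          for sample, label_orig, label_exp in pseudo_pairs:
              degree = match(label_orig, label_exp)
              if degree > threshold:
                  # High agreement: fuse the two pseudo-labels and keep the sample.
                  kept.append((sample, fuse(label_orig, label_exp)))
              # Otherwise both pseudo-labels are dropped for this round; the sample
              # may be pseudo-labelled again by a later, better first model.
          return kept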
  • In a possible implementation, the model structure of the first model may be the same as that of the second model. In this case, the weight parameters of the first model can first be copied to the second model, and then the second model can be iteratively trained using the first unlabeled sample, the third pseudo-label and the labeled samples used to train the first model; the weight parameters of the first model are iteratively updated based on the weight parameters of the second model obtained in each training iteration, until the target model is obtained.
  • Specifically, the first unlabeled sample, the third pseudo-label and the labeled samples are used to perform a first round of training on the second model, yielding the second model's weight parameters after the first round of updates, which are then fed back to the first model to obtain a new first model. The first unlabeled sample and the first expanded sample are then input into the new first model to predict a new first pseudo-label and a new second pseudo-label; if their matching degree is higher than the threshold they are fused into a new third pseudo-label, and the first unlabeled sample, the new third pseudo-label and the labeled samples are used to continue training the second model until convergence, producing the weight parameters after the second round of updates, which are again fed back to the first model, and so on.
  • In addition, in each round the second unlabeled sample and the second expanded sample can be input into the new first model to obtain a new fifth pseudo-label and a new sixth pseudo-label. If the matching degree between them is still not higher than the threshold, the fifth and sixth pseudo-labels continue to be deleted; if the matching degree is higher than the threshold, the fifth and sixth pseudo-labels can be fused to obtain a seventh pseudo-label, and the seventh pseudo-label, the second unlabeled sample and the labeled samples are then used to train the second model until convergence, after which the second model's weight parameters from the second round of updates are fed back to the first model, and so on. Likewise, if the matching degree between the new first pseudo-label and the new second pseudo-label is not higher than the threshold, they can also be deleted; the details are not repeated here.
  • In the above implementation, the first model is updated with the weight parameters of the second model over multiple rounds of iterative training, so that the pseudo-labels inferred by the first model become increasingly accurate, until the prediction accuracy of the first model reaches the standard required by the user and the target model is obtained.
  • It should be noted that all of the weight parameters obtained by the second model in each round of training can be transferred to the first model, or only part of them, so that the first model receives slow, stable weight updates; this makes the trained first model more robust and improves its performance. For example, the student model's weights can be transferred to the first model using the exponential moving average (EMA) method, as sketched below.
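  • A minimal sketch of such an EMA update, assuming the weights are held in plain name-to-array dictionaries (the decay value is illustrative):

      def ema_update(teacher_weights, student_weights, decay=0.999):
          # Exponential moving average: the first (teacher) model receives a slow,
          # stable update driven by the second (student) model's latest weights.
          return {
              name: decay * teacher_weights[name] + (1.0 - decay) * student_weights[name]
              for name in teacher_weights
          }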
  • In the case where the second model has the same model structure as the first model, the machine learning model can be trained using a small number of labeled samples, which are difficult to obtain, together with a large number of unlabeled samples, which are easy to obtain. The resulting target model is robust, performs well and is trained efficiently.
  • In a possible implementation, the model structure of the first model may also contain the model structure of the second model; that is, the second model is a small model, the first model is a large model, and the second model is a submodel of the first model. Similarly, the second model synchronizes the updated weights obtained in each round of training to the first model, and the new first model predicts new pseudo-labels to train the second model, and so on, steadily updating both models. In this case, the trained second model is used as the target model.
  • In the above implementation, because the second model is a small model and the first model is a large model, the finally obtained second model not only has low structural complexity but also achieves performance close to, or even better than, that of the first model, thereby achieving the purpose of model compression.
  • In a possible implementation, when training the second model, the input samples in the labeled sample set can be input into the second model to obtain a first output value, and the first unlabeled sample can be input into the second model to obtain a second output value. The loss value of the second model is determined based on the first output value and the second output value, and the second model is trained by back-propagating this loss value until convergence; the model parameters of the trained second model are then synchronized to the first model, and the next round of training is performed.
  • Specifically, the above loss value L includes a labeled loss L1 and an unlabeled loss L2: the labeled loss L1 is obtained from the difference between the first output value and the real label, and the unlabeled (pseudo-label) loss L2 is obtained from the difference between the second output value and the third pseudo-label. The proportions of the labeled loss L1 and the unlabeled loss L2 in the loss value L can be controlled through coefficient weighting, for example L = L1 + λ·L2, where the larger λ is, the greater the proportion of the unlabeled loss L2 in the loss value L, and the greater the influence of the first unlabeled sample on the performance of the second model and the first model.
  • In the above implementation, the labeled loss and the unlabeled loss jointly determine the training direction of the second model, and the unlabeled loss is computed from the gap between the output value and the third pseudo-label obtained after the filtering and fusion described above. In this way, the second model can exploit a large number of unlabeled samples during semi-supervised learning, reducing the cost of sample acquisition without degrading the performance of the final target model. A minimal sketch of this combined loss is given below.
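  • A minimal sketch of the combined loss L = L1 + λ·L2, using cross-entropy as an illustrative per-sample loss; the application does not fix a particular loss function:

      import numpy as np

      def cross_entropy(pred, target):
          # Per-sample cross-entropy between a predicted distribution and a
          # one-hot (or soft) target distribution.
          return float(-np.sum(target * np.log(pred + 1e-12)))

      def combined_loss(pred_labeled, true_label, pred_unlabeled, third_pseudo_label, lam=1.0):
          # L = L1 + lambda * L2: L1 compares the first output value with the real
          # label, L2 compares the second output value with the third pseudo-label.
          l1 = cross_entropy(pred_labeled, true_label)
          l2 = cross_entropy(pred_unlabeled, third_pseudo_label)
          return l1 + lam * l2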
  • In a possible implementation, the above semi-supervised model training method can also be packaged as a software module used to upgrade the software of existing model training equipment, so that the upgraded model training system gains the pseudo-label filtering and fusion functions and thus better semi-supervised training capabilities.
  • For example, the unit modules used to implement the above semi-supervised model training method can be packaged as an optional configuration feature of a public cloud model training service; when a public cloud user purchases the feature, the corresponding permissions are granted to the user. Alternatively, the unit modules can be packaged as microservices or as a software package; after a user purchases the pseudo-label filtering and fusion functions provided by this application, a corresponding permission license can be granted, and different charging levels can be set for different permissions, which is not specifically limited in this application.
  • In the above implementation, packaging the software as microservices, providing licenses, or providing a cloud service not only makes acquisition simple and fast for users, but also allows developers to add the above functions to an existing model training system through a simple software upgrade, which is convenient for upgrading and maintenance. Therefore, the semi-supervised model training method provided by this application is easy to deploy and highly usable.
  • In a second aspect, a semi-supervised model training system is provided, including: an inference unit, configured to input the first unlabeled sample into the first model to obtain the first pseudo-label of the first unlabeled sample, and to input the first expanded sample into the first model to obtain the second pseudo-label of the first expanded sample, where the first model is an artificial intelligence (AI) model trained using labeled samples and the first expanded sample is a sample obtained by performing data enhancement on the first unlabeled sample; a matching unit, configured to obtain the third pseudo-label of the first unlabeled sample based on the first pseudo-label and the second pseudo-label; and a training unit, configured to train the second model using the first unlabeled sample and the third pseudo-label, where the second model is an AI model obtained based on the weight parameters of the first model.
  • By implementing the system described in the second aspect, the model training system obtains an expanded sample set by performing data enhancement on the unlabeled sample set, inputs the unlabeled sample set and the expanded sample set into the first model, and infers multiple first pseudo-labels for the unlabeled sample set and multiple second pseudo-labels for the expanded sample set. First and second pseudo-labels whose matching degree is higher than the threshold are fused to obtain third pseudo-labels, and first pseudo-labels whose matching degree is not higher than the threshold are filtered out, thereby improving the quality of the pseudo-labels of the unlabeled sample set. When the third pseudo-labels are subsequently used for semi-supervised training of the student model, the training efficiency and performance of the model are improved, which in turn improves the training efficiency and performance of the finally obtained first model.
  • the second model has the same structure as the first model.
  • In a possible implementation, the matching unit is configured to obtain the matching degree between the first pseudo-label and the second pseudo-label according to the first pseudo-label and the second pseudo-label, and, when the matching degree is higher than the threshold, to fuse the first pseudo-label and the second pseudo-label to obtain the third pseudo-label.
  • In a possible implementation, the first model includes a target detection model, and the data enhancement method includes one or more of flip transformation, translation transformation, scale transformation, rotation transformation and scaling transformation. The matching unit is configured to perform the reverse operation of the data enhancement on the second pseudo-label to obtain the fourth pseudo-label, to match the first pseudo-label with the fourth pseudo-label to obtain the matching result between them, and to determine the matching degree according to this matching result.
  • In a possible implementation, the first model includes an image recognition model, and the data enhancement method includes one or more of cropping, color transformation, noise perturbation and kernel filtering. The matching unit is configured to match the first pseudo-label with the second pseudo-label to obtain the matching result between them, and to obtain the matching degree according to this matching result.
  • In a possible implementation, the inference unit is configured to input the second unlabeled sample into the first model to obtain the fifth pseudo-label of the second unlabeled sample, and to input the second expanded sample into the first model to obtain the sixth pseudo-label of the second expanded sample, where the second expanded sample is a sample obtained by performing data enhancement on the second unlabeled sample. The matching unit is configured to obtain the matching degree between the fifth pseudo-label and the sixth pseudo-label based on the fifth pseudo-label and the sixth pseudo-label, and to delete the fifth pseudo-label and the sixth pseudo-label when this matching degree is not higher than the threshold.
  • In a possible implementation, the training unit is configured to iteratively train the second model using the labeled samples, the first unlabeled sample and the third pseudo-label, and to iteratively update the weight parameters of the first model based on the weight parameters of the second model obtained in each training iteration, so as to obtain the target model.
  • In a possible implementation, the training unit is configured to input the input sample into the second model to obtain the first output value, input the first unlabeled sample into the second model to obtain the second output value, and determine the loss value of the second model from the first output value and the second output value, where the loss value includes a labeled loss and a pseudo-label loss: the labeled loss is obtained from the difference between the first output value and the true label, and the pseudo-label loss is obtained from the difference between the second output value and the third pseudo-label. The training unit is configured to iteratively train the second model based on this loss value.
  • In a third aspect, a computing device is provided, including a processor and a memory. The memory stores code, and the processor executes the code to perform the method described in the first aspect or any possible implementation of the first aspect.
  • In a fourth aspect, a computer storage medium is provided, in which instructions are stored; when the instructions are run on a computing device, they cause the computing device to execute the method described in the first aspect or any possible implementation of the first aspect.
  • In a fifth aspect, computer program instructions are provided that, when run on a computing device, cause the computing device to execute the method described in the first aspect or any possible implementation of the first aspect.
  • Figure 1 is a schematic flowchart of the steps of semi-supervised learning
  • Figure 2 is a schematic architectural diagram of a semi-supervised model training system provided by this application.
  • Figure 3 is a schematic flowchart of the steps of a semi-supervised model training method provided by this application.
  • Figure 4 is a schematic flow chart of the steps in an application scenario of a semi-supervised model training method provided by this application;
  • Figure 5 is a schematic diagram of the fusion process of the first pseudo-label and the second pseudo-label in the semi-supervised model training method provided by this application;
  • Figure 6 is a schematic structural diagram of a computing device provided by this application.
  • Labeled samples refer to samples with labels.
  • the label of the sample represents the true value of the sample.
  • During model training, the true value is compared with the model's predicted value to calculate a loss value, based on which the weight parameters of the model are adjusted. For example, when training a classification model, after labeled samples are input into the initial classification model as input data, the prediction results of the initial classification model are compared with the labels of the labeled samples to obtain the loss value of this round of training, and the weight parameters of the classification model are then adjusted based on this loss value.
  • unlabeled samples are samples that do not contain labels.
  • The loss function is used to evaluate the difference between the model's output and the sample label during model training, and the loss value (loss) is the value of the loss function; the lower the loss value, the more robust the model. Therefore, during training, after a sample is input into the model and an output value is obtained, the loss value is determined from the difference between the output value and the sample label, and the weight parameters of the model are adjusted according to the size of the loss value; this is iterated until the loss function of the model is minimized and the target model is obtained.
  • AI is a theory, method, technology and application system that uses digital computers or machines controlled by digital computers to simulate, extend and expand human intelligence, perceive the environment, acquire knowledge and use knowledge to obtain the best results.
  • Application scenarios in the field of artificial intelligence include robotics, natural language processing, computer vision, decision-making and reasoning, human-computer interaction, recommendation and search, etc.
  • Figure 1 is a schematic flowchart of the steps of semi-supervised learning.
  • As shown in Figure 1, the sample set used in the semi-supervised learning process includes labeled samples 101 and unlabeled samples 102, and the network structures of the first model 112 and the second model 113 are the same. Specifically, the steps of semi-supervised learning can be as follows (a sketch of one round is given after the steps):
  • Step 1: Use the labeled samples 101 to train the initial first model, and obtain the first model 112 after the model converges. The initial first model may be an AI model whose network has been initialized but which has not yet been trained.
  • Step 2: Input the unlabeled samples 102 into the first model 112; the first model 112 infers pseudo-labels for the unlabeled samples 102, yielding pseudo-labeled samples 103.
  • Step 3: Copy the weight parameters of the first model 112 to the second model 113.
  • Step 4: Use the labeled samples 101 and the pseudo-labeled samples 103 to train the second model 113 and obtain its updated weight parameters. During this training, the loss of the model includes a labeled loss and a pseudo-label loss, where the labeled loss is obtained from the difference between the model output and the label of the labeled sample 101, and the pseudo-label loss is obtained from the difference between the model output and the pseudo-label of the pseudo-labeled sample 103.
  • Step 5: Feed the updated weight parameters of the second model back to the first model, and update the weight parameters of the first model.
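  • The loop below is a rough Python sketch of one round of this flow; Step 1 is assumed to have already produced the first model, and the callables infer, copy_weights, train and update_weights are hypothetical placeholders:

      def semi_supervised_round(first_model, second_model, labeled_set, unlabeled_set,
                                infer, copy_weights, train, update_weights):
          # Step 2: the first model infers pseudo-labels for the unlabeled samples.
          pseudo_labeled_set = [(x, infer(first_model, x)) for x in unlabeled_set]
          # Step 3: copy the first model's weight parameters to the second model.
          copy_weights(src=first_model, dst=second_model)
          # Step 4: train the second model on labeled plus pseudo-labeled samples.
          train(second_model, labeled_set, pseudo_labeled_set)
          # Step 5: feed the updated weight parameters back to the first model.
          update_weights(src=second_model, dst=first_model)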
  • Because labeled samples 101 must be manually annotated and are limited in number, the number of unlabeled samples 102 is usually much larger than the number of labeled samples 101. Therefore, during the training of the second model 113, pseudo-labeled samples 103 greatly outnumber labeled samples 101, and the weight of the pseudo-label loss is much higher than that of the labeled loss; as a result, the quality of the pseudo-labeled samples determines the optimization direction of both the second model and the first model. However, pseudo-labels are labels inferred by the first model and contain many errors and noise, and semi-supervised learning itself cannot correct them; correcting pseudo-labels manually is not only inefficient but also too costly in labor.
  • To solve this problem, this application provides a semi-supervised model training system. The system obtains expanded samples by performing data enhancement on unlabeled samples, inputs the unlabeled samples and the expanded samples into the first model for inference to obtain the first pseudo-labels of the unlabeled samples and the second pseudo-labels of the expanded samples, and then obtains the third pseudo-labels based on the matching results between the first and second pseudo-labels. The third pseudo-labels obtained in this way are of higher quality, so when they are used for semi-supervised training of the second model, the training efficiency and performance of the model are improved, which in turn improves the training efficiency and performance of the finally obtained first model.
  • FIG 2 is an architectural schematic diagram of a semi-supervised model training system provided by this application.
  • The architecture of the semi-supervised model training system includes an inference device 110, a semi-supervised model training system 120, a user device 140, a data storage system 150 and a data collection device 160. Communication connections can be established between the inference device 110, the semi-supervised model training system 120, the user device 140, the data storage system 150 and the data collection device 160; these connections can be established through a wired or wireless network, which is not specifically limited in this application.
  • the data acquisition device 160 is used to collect original samples and send them to the semi-supervised model training system 120 for model training.
  • the original samples may be original samples of image type.
  • The data acquisition device 160 may include image acquisition devices, radar acquisition devices and other sensors used to collect original samples; image acquisition devices may be surveillance cameras, traffic-enforcement cameras (electronic police), depth cameras, drones and the like, and radar acquisition devices may be radars, satellites and the like, which is not specifically limited in this application. It should be understood that the original samples required for training the machine learning model differ across application scenarios, and the corresponding data collection device 160 differs accordingly, which is also not specifically limited in this application.
  • The semi-supervised model training system 120 is used to receive the original samples collected by the data collection device 160 and process them to obtain the labeled sample set 131, the unlabeled sample set 132 and the expanded sample set 133 shown in Figure 2; after training the first model 125 and the second model 126 with these sample sets, the trained target model 127 is obtained and sent to the inference device 110.
  • The semi-supervised model training system 120 can be deployed on a computing device, which can be a bare metal server (BMS), a virtual machine or a container, where a virtual machine refers to a complete, software-simulated computer system with full hardware functionality implemented through network functions virtualization (NFV) technology, and a container refers to a group of processes that are subject to resource constraints and isolated from each other. The computing device can also be an edge computing device, which is not specifically limited in this application.
  • The semi-supervised model training system 120 may also be a server cluster, such as a centralized or distributed server cluster. The semi-supervised model training system 120 can also be deployed on a public cloud and provided to public cloud users as a model training cloud service; users can obtain the right to use the semi-supervised model training system 120 by purchasing the service, which is not specifically limited in this application.
  • the semi-supervised model training system 120 can also be provided to users in the form of software packaging, and users install the software on their own computing devices, or it can be provided to users in the form of microservices.
  • users can purchase the required software version or capabilities according to their own needs and obtain a license for the corresponding permissions. Different fees can be set for different permissions.
  • The inference device 110 is configured to receive the target model 127 sent by the semi-supervised model training system 120, use the target model 127 to perform inference on the input data sent by the user device 140, and return the output data to the user device 140 or store it in the data storage system 150.
  • the above-mentioned inference device 110 may be a computing device, specifically a BMS, a virtual machine, a container, a terminal device or an edge computing device, which is not specifically limited in this application.
  • The user device 140 may be a terminal device held by the user, including a computer, a smartphone, a handheld processing device, a tablet, a mobile notebook, an augmented reality (AR) device, a virtual reality (VR) device, an integrated handheld device, a wearable device, a vehicle-mounted device, smart conference equipment, smart advertising equipment, smart home appliances and so on, which is not specifically limited here.
  • the user device 140 may also be a data collection device, the input data collected by it may be input to the inference device 110 for target detection or image recognition, and the output results are stored in the data storage system 150 .
  • For example, the user device 140 may be a traffic-enforcement camera (electronic police) on the road, the inference device 110 an edge computing device at the roadside, the target model a license plate recognition model, and the data storage system a database maintained by the traffic police. A picture of a speeding vehicle can then be input into the license plate recognition model on the edge computing device to identify the license plate number of the speeding vehicle, which is stored in the database maintained by the traffic police. It should be understood that the above example is for illustration and is not specifically limited in this application.
  • the user device 140 and the inference device 110 may be the same device.
  • For example, the user's smartphone may download a face recognition model trained by the semi-supervised model training system 120; the user collects face input data through the camera of the smartphone and inputs it into the face recognition model to obtain the face recognition result. The face recognition result can be displayed directly on the user device 140, or stored in a remote server for subsequent authentication and matching such as secure unlocking and secure payment, which is not specifically limited in this application.
  • the data storage system 150 can be a server or storage array with storage function.
  • the server can be a physical server such as an ARM server or an X86 server, or a virtual machine, which is not specifically limited in this application.
  • the data storage system 150 is used to store output data of the inference device 110 .
  • the semi-supervised model training system 120 can be further divided into multiple unit modules.
  • Figure 2 is an exemplary division method.
  • The semi-supervised model training system 120 can include a data enhancement unit 121, an inference unit 122, a matching unit 123 and a training unit 124, among which communication connections are established; these connections may be wired or wireless, which is not specifically limited in this application.
  • the sample database 130 may be stored in the semi-supervised model training system 120 as shown in FIG. 2 , or may be stored in an external memory of the semi-supervised model training system 120 , which is not specifically limited in this application.
  • The semi-supervised model training system 120 may also include a sample database 130, where the sample database 130 includes an unlabeled sample set 132, a labeled sample set 131 and an expanded sample set 133. The unlabeled sample set 132 may include a plurality of unlabeled samples, which may be original samples collected by the data collection device 160 or samples obtained after data preprocessing of the original samples (for example cropping, noise reduction and other preprocessing methods that improve sample quality). The labeled sample set 131 may include a plurality of labeled samples, each consisting of an input sample and a real label; the input sample may be an original sample or a sample obtained after data preprocessing of the original sample, and the real label of the input sample may be obtained through manual annotation. The expanded sample set 133 includes a plurality of unlabeled expanded samples, which are obtained after the data enhancement unit 121 of the semi-supervised model training system 120 performs data enhancement on the unlabeled sample set 132.
  • The data enhancement unit 121 can perform data enhancement on the unlabeled samples in the unlabeled sample set 132 to obtain expanded samples corresponding to the unlabeled samples, and by analogy obtain the expanded sample set 133. The data enhancement method may include, but is not limited to, one or more of flip transformation (flip), translation transformation (shift), scale transformation (scale), rotation/reflection transformation (rotation/reflection), zoom transformation (zoom), cropping (crop), color transformation (color space), noise perturbation (noise) and kernel filters (kernel filters).
  • Flip transformation refers to flipping the image horizontally or vertically; horizontal flipping can be further divided into upward and downward horizontal flipping, and vertical flipping into left and right vertical flipping.
  • Translation transformation refers to translating the image, for example to the right in the x direction (by xoffset) or downward in the y direction (by yoffset), where the x and y directions refer to the horizontal and vertical axes of the image coordinate system.
  • Rotation transformation, which can also be called reflection transformation, refers to rotating the image by a certain angle, which can be any angle from 0 to 360 degrees. Scaling transformation refers to enlarging or reducing the image by a certain ratio without changing its content. Cropping includes uniform cropping and random cropping: uniform cropping crops images of different sizes to a set size, while random cropping crops images of different sizes to different sizes.
  • Color transformation refers to modifying a color channel of the image, such as turning the channel off or changing its brightness value; for example, an image usually includes three RGB channels, and color transformation can reduce or increase the R channel value.
  • Noise perturbation refers to adding a matrix of random values sampled from a Gaussian distribution to the RGB pixel matrix of the image. Kernel filtering refers to convolving the image with a kernel filter that has a specific function, such as a sharpening or blurring kernel.
  • It should be noted that the above data enhancement methods are for illustration only. This application can also expand the unlabeled samples through other data enhancement methods to obtain the expanded samples; for example, the unlabeled sample set 132 can be expanded into the expanded sample set 133 through adversarial training, feature space augmentation, GAN-based data augmentation and other data enhancement methods, which will not be described one by one here.
  • Since the label of a target detection model is usually a detection box (bounding box), when the first model 125 is a target detection model, data enhancement methods that affect the detection box, such as flip transformation, translation transformation, scale transformation, rotation transformation and scaling transformation, can be used to perform data enhancement on the unlabeled samples. Since the label of an image recognition model is the probability distribution over the categories to which the image may belong, when the first model 125 is an image recognition model, data enhancement methods that affect the determination of the image category, such as cropping, color transformation, noise perturbation and kernel filtering, can be used. It should be understood that performing different data enhancement operations according to the model type allows the resulting expanded samples to increase the generalization ability of the model and improve its robustness.
  • Optionally, the data enhancement unit 121 and the sample database 130 can also be deployed outside the semi-supervised model training system 120, for example in a preprocessing system with which the semi-supervised model training system 120 establishes a connection; the sample database 130 is then maintained by the preprocessing system, which also performs the data enhancement operation on the unlabeled sample set 132. This is not specifically limited in this application.
  • the inference unit 122 is configured to input the first unlabeled sample into the first model 125 to generate a first pseudo label of the first unlabeled sample, and input the first expanded sample into the first model 125 to generate a second pseudo label of the first expanded sample.
  • the first model 125 is an AI model obtained after training using labeled samples. It should be noted that when using the labeled sample set to train the first model, the number of training rounds can be controlled so that the first model has a certain detection capability to prevent overfitting of the first model 125 and the second model 126 during subsequent semi-supervised training.
  • the first model may be a target detection model or an image recognition model.
  • For example, the target detection model can be a one-stage unified real-time target detection (You Only Look Once: Unified, Real-Time Object Detection, YOLO) model, a single shot multibox detector (SSD) model, a region-based convolutional neural network (RCNN), or a fast region-based convolutional neural network (Fast-RCNN).
  • The output of the target detection model may consist of multiple candidate detection boxes, and the inference unit 122 may use the non-maximum suppression (NMS) method to select the most accurate detection box among them as the first pseudo-label or the second pseudo-label, thereby increasing the accuracy of the first pseudo-label of the first unlabeled sample and of the second pseudo-label of the expanded sample.
  • the matching unit 123 is configured to match the first pseudo label of the first unlabeled sample with the second pseudo label of the first extended sample to obtain the third pseudo label of the first unlabeled sample.
  • The matching unit 123 may obtain the matching degree between the first pseudo-label and the second pseudo-label based on the first pseudo-label and the second pseudo-label, and when the matching degree is higher than the threshold, fuse the first pseudo-label with the second pseudo-label to obtain the third pseudo-label.
  • Specifically, when the first model 125 is a target detection model, the matching unit 123 can perform the reverse operation of the data enhancement on the second pseudo-label to obtain the fourth pseudo-label, then match the fourth pseudo-label with the first pseudo-label to obtain the matching result between them, and determine the matching degree based on this result. The reverse operation refers to the operation opposite to the data enhancement performed by the data enhancement unit 121. For example, if the data enhancement unit 121 performed a horizontal upward flip on the first unlabeled sample to obtain the first expanded sample, the matching unit 123 can flip the detection box corresponding to the second pseudo-label of the first expanded sample horizontally downward to obtain the fourth pseudo-label; if the data enhancement unit 121 performed a 90° rightward rotation on the first unlabeled sample to obtain the first expanded sample, the matching unit 123 can rotate the detection box corresponding to the second pseudo-label of the first expanded sample 90° to the left to obtain the fourth pseudo-label, and so on; further examples are not given here.
  • The pseudo-label of a target detection model is the detection box of a target in the image. In the first expanded sample obtained through data enhancement, the position of the target has changed, so the detection box needs to be reverse-operated first, so that the first pseudo-label and the second pseudo-label frame the target at the same position before being matched. Matching them then filters out first pseudo-labels that locate the target inaccurately, which avoids using wrong or low-precision pseudo-labels during semi-supervised training of the second model 126, thereby improving training efficiency and the accuracy of the final first model.
  • Specifically, the matching unit 123 may match the detection box corresponding to the first pseudo-label with the detection box corresponding to the fourth pseudo-label to obtain the matching degree, which can be the intersection over union (IoU) between the two detection boxes. When the matching degree is greater than the threshold, the matching unit can fuse the first pseudo-label and the second pseudo-label to obtain the third pseudo-label; specifically, the detection box corresponding to the fourth pseudo-label and the detection box corresponding to the first pseudo-label can be averaged coordinate-wise to obtain the third pseudo-label.
  • For example, if the detection box corresponding to the first pseudo-label has coordinates (x1, y1, x2, y2) and the detection box corresponding to the fourth pseudo-label has coordinates (x1', y1', x2', y2'), the fused third pseudo-label y_u can be obtained by averaging the corresponding coordinates, i.e. y_u = ((x1 + x1')/2, (y1 + y1')/2, (x2 + x2')/2, (y2 + y2')/2).
  • In this way, the detection box corresponding to the first pseudo-label and the detection box corresponding to the fourth pseudo-label are fused. Because the two detection boxes are obtained in different ways, fusing them further improves the accuracy of the resulting third pseudo-label, which avoids using incorrect or low-precision pseudo-labels during semi-supervised training of the second model 126, thereby improving training efficiency and the accuracy of the finally obtained first model.
If the first model 125 is an image recognition model, the first pseudo label and the second pseudo label can be matched directly to obtain a matching result between them, and the matching degree between the two pseudo labels is determined from that matching result. In a specific implementation, the matching unit 123 can match the probability distribution corresponding to the first pseudo label with the probability distribution corresponding to the second pseudo label; the matching degree here can be a similarity or a distance between the two probability distributions, which is not specifically limited in this application. When fusing, the probability distribution of the first pseudo label and the probability distribution of the second pseudo label can be combined by averaging, for example a plain mean or a weighted average, which is likewise not specifically limited in this application.
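For the image recognition case, a hedged sketch of one possible matching degree (cosine similarity) and one possible fusion (a weighted average) is shown below; the particular similarity measure and the function names are assumptions, since the application leaves the concrete similarity or distance open.

```python
# Sketch: match and fuse class-probability pseudo labels of a recognition model.
import numpy as np

def matching_degree(p_first, p_second):
    """Cosine similarity between the two predicted class distributions.
    Any other similarity or (negated) distance could be substituted."""
    p_first, p_second = np.asarray(p_first), np.asarray(p_second)
    return float(np.dot(p_first, p_second) /
                 (np.linalg.norm(p_first) * np.linalg.norm(p_second)))

def fuse_distributions(p_first, p_second, w=0.5):
    """Weighted average of the two distributions (w=0.5 gives a plain mean)."""
    fused = w * np.asarray(p_first) + (1.0 - w) * np.asarray(p_second)
    return fused / fused.sum()  # renormalize so the result stays a distribution
```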
The inference unit 122 may also input a second unlabeled sample into the first model to obtain a fifth pseudo label of the second unlabeled sample, and input a second expanded sample into the first model to obtain a sixth pseudo label of the second expanded sample, where the second expanded sample is the sample obtained by performing data enhancement on the second unlabeled sample. The matching unit 123 may then obtain, based on the fifth pseudo label and the sixth pseudo label, the matching degree between them, and delete the fifth pseudo label and the sixth pseudo label when this matching degree is not higher than the above threshold. The way the matching degree between the fifth pseudo label and the sixth pseudo label is determined may follow the way the matching degree between the first pseudo label and the second pseudo label is determined in the foregoing content, and is not repeated here.

It should be noted that if the matching degree between the first pseudo label and the second pseudo label is not higher than the threshold, the first pseudo label and the second pseudo label can likewise be deleted; similarly, if the matching degree between the fifth pseudo label and the sixth pseudo label is higher than the threshold, the fifth pseudo label and the sixth pseudo label can be fused, in the same way as the first pseudo label and the second pseudo label are fused to obtain the third pseudo label, which is not repeated here. To put it simply, the unlabeled sample set 132 and the expanded sample set 133 are matched pairwise: the pseudo label of each unlabeled sample is matched with the pseudo label of its corresponding expanded sample to obtain a matching degree, the two pseudo labels are fused if the matching degree is higher than the threshold, and both pseudo labels are deleted if it is not. In this way, pseudo labels with lower accuracy can be filtered out, which avoids using wrong or low-precision pseudo labels when the second model 126 is semi-supervised trained, thereby improving training efficiency and the accuracy of the finally obtained first model. This filtering loop is sketched below.
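The overall filter can be pictured as the following sketch, in which teacher.predict, augment, invert, match and fuse stand for the first model's inference and the operations described above; these interfaces are assumptions made for illustration.

```python
# Sketch of the filtering step over the whole unlabeled sample set: each
# unlabeled sample is paired with its augmented (expanded) counterpart, their
# pseudo labels are matched, and they are either fused or both discarded.

def filter_and_fuse(teacher, unlabeled_set, augment, invert, match, fuse, threshold):
    """Return (sample, fused_pseudo_label) pairs that survive the filter."""
    kept = []
    for sample in unlabeled_set:
        augmented = augment(sample)
        label_a = teacher.predict(sample)       # first pseudo label
        label_b = teacher.predict(augmented)    # second pseudo label
        label_b_inv = invert(label_b)           # inverse of the augmentation
        if match(label_a, label_b_inv) > threshold:
            kept.append((sample, fuse(label_a, label_b_inv)))
        # otherwise both pseudo labels are deleted for this round
    return kept
```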
The training unit 124 is used to train the second model 126 using the first unlabeled sample and the third pseudo label, where the second model 126 is initialized with the same weight parameters as the first model 125. To put it simply, the labeled sample set 131 is used to train the machine learning model to obtain the first model 125, the weight parameters of the first model 125 are then copied to the second model 126, and the second model 126 is then trained using the first unlabeled sample and the third pseudo label.

Optionally, the model structure of the first model 125 may be the same as the model structure of the second model 126. In a specific implementation, when the training unit 124 trains the second model 126, it first copies the weight parameters of the first model 125 to the second model 126, then iteratively trains the second model 126 using the first unlabeled sample, the third pseudo label and the labeled samples used when training the first model 125, and iteratively updates the weight parameters of the first model according to the weight parameters of the second model obtained in each iteration of training, so as to obtain the target model 127.
Specifically, the first unlabeled sample, the third pseudo label and the above-mentioned labeled samples are used to perform a first round of training on the second model 126 to obtain the weight parameters of the second model 126 after the first round of updates, which are then sent to the first model 125 to update it and obtain a new first model 125. The first unlabeled sample and the first expanded sample are then input into the new first model 125 to predict a new first pseudo label and a new second pseudo label, the new first pseudo label and the new second pseudo label whose matching degree is higher than the threshold are fused into a new third pseudo label, and training continues in the same manner, and so on.

It should be noted that, in the following rounds of iterative training, for example in the second round, the above-mentioned second unlabeled sample and second expanded sample can be input into the new first model 125 to obtain a new fifth pseudo label and a new sixth pseudo label. If the matching degree between the fifth pseudo label and the sixth pseudo label is still not higher than the threshold, the fifth pseudo label and the sixth pseudo label can again be deleted; if their matching degree is higher than the threshold, the fifth pseudo label and the sixth pseudo label can be fused to obtain a seventh pseudo label, and the seventh pseudo label, the second unlabeled sample and the labeled samples are then used to train the second model until convergence, the weight parameters of the second model after the second round of updates are obtained and updated into the first model 125, and so on. Similarly, a new first pseudo label and a new second pseudo label whose matching degree is not higher than the threshold can also be deleted; this is not repeated here.

Through such iteration, the accuracy of the pseudo labels inferred by the first model 125 becomes higher and higher, until the prediction accuracy of the first model 125 reaches the standard required by the user, so that the target model 127 is obtained and sent to the inference device 110.
Optionally, all of the weight parameters obtained by the second model 126 in each round of training can be updated into the first model 125, or only some of them can be, so that the first model receives slow and stable weight updates; the first model obtained by such training is more robust and performs better. In a specific implementation, the student weights can be propagated to the first model through an exponential moving average (EMA); for example, with EMA = 0.99, one percent of the weight parameters obtained in each round of training is merged into the first model, as sketched below.
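A minimal sketch of such an EMA update, assuming the weights are held in simple name-to-array dictionaries, is:

```python
# Sketch of updating the first (teacher) model from the second (student) model
# with an exponential moving average (EMA) of the weights.

def ema_update(teacher_params, student_params, ema=0.99):
    """teacher <- ema * teacher + (1 - ema) * student, parameter by parameter.
    With ema = 0.99, roughly 1% of the student's weights flow into the teacher
    after each round, giving the slow, stable update described above."""
    for name, t_param in teacher_params.items():
        teacher_params[name] = ema * t_param + (1.0 - ema) * student_params[name]
    return teacher_params
```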
In the above implementation, the second model 126 has the same model structure as the first model 125, so a small number of labeled samples, which are difficult to obtain, and a large number of unlabeled samples, which are easy to obtain, can be used to train the machine learning model; the resulting target model 127 not only has good robustness but also good model performance, and the training efficiency is high.
Optionally, the model structure of the first model 125 may also contain the model structure of the second model 126; that is to say, the second model 126 is a small model and the first model 125 is a large model, for example the second model 126 is a sub-model of the first model 125. In the same way as above, the second model 126 synchronizes the updated weights obtained in each round of training to the first model 125, the new first model 125 then predicts new pseudo labels to train the second model 126, and so on, so that the second model 126 and the first model 125 are updated steadily. The second model finally obtained not only has low structural complexity, but also has model performance close to, or even better than, that of the first model 125, thereby achieving the purpose of model compression.
In this embodiment of the application, when the labeled samples, the first unlabeled sample and the third pseudo label are used to train the second model, the input samples in the labeled sample set 131 can be input into the second model to obtain a first output value, and the first unlabeled sample can be input into the second model to obtain a second output value; the loss value of the second model is determined based on the first output value and the second output value, backpropagation is then performed on the second model based on the loss value until convergence to obtain the trained second model, the model parameters of the trained second model are synchronized to the first model 125, and the next round of model training is then performed. The above loss value L includes a labeled loss L1 and an unlabeled (pseudo label) loss L2, for example L = L1 + λL2, where the labeled loss L1 is obtained based on the gap between the first output value and the real label, the pseudo label loss L2 is obtained based on the gap between the second output value and the third pseudo label, and the coefficient λ controls the weight of the unlabeled loss within the total loss.
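As an illustration, assuming a classification task and a PyTorch-style model, the combined loss could be sketched as follows; for a target detection model the cross-entropy terms would be replaced by the corresponding detection losses, and the function and argument names here are assumptions.

```python
# Sketch of the combined loss for training the second model: a supervised loss
# on labeled samples plus a weighted pseudo-label loss on unlabeled samples.
import torch
import torch.nn.functional as F

def semi_supervised_loss(model, labeled_x, labels, unlabeled_x, pseudo_labels, lam=1.0):
    first_output = model(labeled_x)        # predictions on labeled inputs
    second_output = model(unlabeled_x)     # predictions on unlabeled inputs
    l1 = F.cross_entropy(first_output, labels)          # labeled loss L1
    l2 = F.cross_entropy(second_output, pseudo_labels)  # pseudo-label loss L2
    return l1 + lam * l2                   # L = L1 + lambda * L2
```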
Optionally, the semi-supervised model training system 120 of the present application can also be packaged as a software module used to upgrade existing model training systems so that they gain the pseudo label filtering and fusion functions, giving the upgraded model training system better semi-supervised training performance. For example, in a public cloud scenario, the above data enhancement unit 121, inference unit 122, matching unit 123 and training unit 124 can be packaged as a configuration module, offered as a small configuration function within the public cloud model training service; public cloud users who purchase this function are then granted the corresponding permissions. In a non-public-cloud scenario, the data enhancement unit 121, inference unit 122, matching unit 123 and training unit 124 can be packaged as a microservice or a software package; after users purchase the pseudo label filtering and fusion functions provided by this application, they can be issued a license with the corresponding permissions, and different charging levels can be set for different permissions, which is not specifically limited in this application.
It should be noted that the semi-supervised model training system 120 may include more or fewer unit modules; for example, the sample database 130 and the data enhancement unit 121 may be deployed outside the semi-supervised model training system 120, or the matching unit 123 may be further divided into a threshold judgment unit, a fusion unit and a deletion unit, where the threshold judgment unit is used to determine the matching degree between the first pseudo label and the second pseudo label, the fusion unit fuses the first pseudo label and the second pseudo label into the third pseudo label when the matching degree is higher than the threshold, and the deletion unit deletes the first pseudo label and the second pseudo label when the matching degree is not higher than the threshold; this is not repeated here.
To sum up, the model training system provided by this application obtains an expanded sample set of the unlabeled sample set by performing data enhancement on the unlabeled sample set, and then inputs the unlabeled sample set and the expanded sample set into the first model for inference to obtain the first pseudo labels of the unlabeled samples and the second pseudo labels of the expanded samples; first and second pseudo labels whose matching degree is higher than the threshold are fused into third pseudo labels, and first pseudo labels whose matching degree is lower than or equal to the threshold are filtered out, which improves the quality of the pseudo labels of the unlabeled sample set, so that the training efficiency and performance of the second model, and in turn of the finally obtained first model, are improved.
Figure 3 is a schematic flowchart of the steps of a semi-supervised model training method provided by this application. This method can be applied to the semi-supervised model training system 120 shown in Figure 2. As shown in Figure 3, the method can include the following steps.

Step S310: The semi-supervised model training system 120 inputs the first unlabeled sample into the first model and obtains the first pseudo label of the first unlabeled sample. This step can be implemented by the inference unit 122 in Figure 2. Here, the first model is obtained by training a machine learning model using the labeled sample set, and the machine learning model can be a target detection model or an image recognition model; the descriptions of the target detection model and the image recognition model can be found in the detailed description of the embodiment of Figure 2 and are not repeated here.

Optionally, when the first model is a target detection model, its output may consist of multiple target detection frames, and the semi-supervised model training system 120 may use the non-maximum suppression (NMS) method to select the most accurate detection frame among the multiple target detection frames as the first pseudo label or the second pseudo label, thereby increasing the accuracy of the first pseudo label of the first unlabeled sample and of the second pseudo label of the first expanded sample.
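A generic sketch of such an NMS selection is shown below; the greedy loop and the default IOU threshold are illustrative, and iou() is the helper sketched earlier.

```python
# Sketch of non-maximum suppression (NMS): keep the most confident,
# non-overlapping detection frames, highest-scoring first.

def nms(boxes, scores, iou_threshold=0.5):
    """boxes: list of (x1, y1, x2, y2); scores: confidence per box.
    Returns the indices of the kept boxes, sorted by descending score."""
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    keep = []
    while order:
        best = order.pop(0)
        keep.append(best)
        order = [i for i in order if iou(boxes[best], boxes[i]) <= iou_threshold]
    return keep
```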
Step S320: The first expanded sample is input into the first model to obtain the second pseudo label of the first expanded sample. This step can be implemented by the inference unit 122 in Figure 2. Before step S320, data enhancement can be performed on the first unlabeled sample to obtain the first expanded sample of the first unlabeled sample; that step can be implemented by the data enhancement unit 121 in Figure 2. The description of the semi-supervised model training system 120 may refer to the embodiment of Figure 2, and the description of the first unlabeled sample may refer to the terminology explanation; neither is repeated here.
Optionally, the data enhancement method may include but is not limited to one or more of flip transformation, translation transformation, scale transformation, rotation/reflection transformation, scaling (zoom) transformation, cropping, color transformation, noise perturbation and kernel filtering; the specific description of each data enhancement method can refer to the description in the embodiment of Figure 2 and is not repeated here.
It should be noted that, because the label of a target detection model is usually a detection box (bounding box), when the first model is a target detection model, data enhancement methods that affect the detection box, such as flip transformation, translation transformation, scale transformation, rotation transformation and scaling transformation, can be used to perform data enhancement on the first unlabeled sample. Because the label of an image recognition model is the probability distribution over the categories to which the image may belong, when the first model is an image recognition model, data enhancement methods that affect the category determination of the image, such as cropping, color transformation, noise perturbation and kernel filtering, can be used to perform data enhancement on the first unlabeled sample. It should be understood that choosing different data enhancement operations according to the model type allows the finally obtained expanded samples to increase the generalization performance of the model and improve its robustness.
It should be noted that the above data enhancement step can also be implemented by another data preprocessing system: the semi-supervised model training system 120 can obtain the above first expanded sample together with the first unlabeled sample from that data preprocessing system, or it can itself perform data enhancement on the first unlabeled sample to obtain the corresponding first expanded sample. The details can be determined according to the actual business conditions and are not specifically limited in this application.
Step S330: The third pseudo label of the first unlabeled sample is obtained based on the first pseudo label and the second pseudo label. This step can be implemented by the matching unit 123 in Figure 2.

In a specific implementation, the semi-supervised model training system 120 can obtain, based on the first pseudo label and the second pseudo label, the matching degree between the first pseudo label and the second pseudo label; if the matching degree is higher than the threshold, the first pseudo label and the second pseudo label are fused to obtain the third pseudo label. Optionally, if the first model is a target detection model, the semi-supervised model training system 120 can perform the inverse operation of the data enhancement on the second pseudo label to obtain the fourth pseudo label, and then match the fourth pseudo label with the first pseudo label to obtain the above matching degree, where the inverse operation refers to the operation opposite to the data enhancement performed earlier. For example, if the data enhancement flipped the first unlabeled sample upward about the horizontal axis to obtain the first expanded sample, step S330 can flip the bounding box corresponding to the second pseudo label of the first expanded sample back downward to obtain the fourth pseudo label; if the data enhancement rotated the first unlabeled sample 90° to the right to obtain the first expanded sample, step S330 can rotate the bounding box corresponding to the second pseudo label 90° to the left to obtain the fourth pseudo label, and so on; further examples are not enumerated here.

As explained for the embodiment of Figure 2, the pseudo label of a target detection model is the detection frame of the target in the image, and in the first expanded sample obtained through data enhancement the position of the target has actually changed; the detection frame therefore needs to be inversely transformed so that the first pseudo label and the fourth pseudo label frame the target at the same position. Matching them then filters out first pseudo labels that label the target inaccurately, which avoids using wrong or low-precision pseudo labels during semi-supervised training of the second model, thereby improving training efficiency and the accuracy of the finally obtained first model. In a specific implementation, the matching unit 123 may match the detection frame corresponding to the first pseudo label against the detection frame corresponding to the fourth pseudo label to obtain the above matching degree, where the matching degree may be the IOU between the two detection frames.
In a specific implementation, when the matching degree is greater than the threshold, the matching unit can fuse the first pseudo label and the second pseudo label to obtain the third pseudo label by performing multi-value averaging on the detection frame corresponding to the fourth pseudo label and the detection frame corresponding to the first pseudo label, as described for the embodiment of Figure 2. Because the two detection frames are target detection frames determined in different ways, fusing them can further improve the accuracy of the resulting third pseudo label, which avoids using wrong or low-precision pseudo labels during semi-supervised training of the second model, thereby improving training efficiency and the accuracy of the finally obtained first model.
Optionally, if the first model is an image recognition model, the first pseudo label and the second pseudo label can be matched directly to obtain a matching result, and the matching degree between the two pseudo labels is determined from that matching result. Specifically, the probability distribution corresponding to the first pseudo label can be matched with the probability distribution corresponding to the second pseudo label, and the matching degree can be a similarity or a distance between the two probability distributions, which is not specifically limited in this application. When fusing, the probability distribution of the first pseudo label and the probability distribution of the second pseudo label can be combined by averaging, for example a plain mean or a weighted average, which is likewise not specifically limited in this application.
Optionally, a second unlabeled sample can also be input into the first model to obtain a fifth pseudo label of the second unlabeled sample, and a second expanded sample can be input into the first model to obtain a sixth pseudo label of the second expanded sample, where the second expanded sample is the sample obtained by performing data enhancement on the second unlabeled sample. The matching degree between the fifth pseudo label and the sixth pseudo label is then obtained and, if it is not higher than the above threshold, the fifth pseudo label and the sixth pseudo label are deleted; the way this matching degree is determined may follow the way the matching degree between the first pseudo label and the second pseudo label is determined in the foregoing content, and is not repeated here.

It should be noted that if the matching degree between the first pseudo label and the second pseudo label is not higher than the threshold, the first pseudo label and the second pseudo label can likewise be deleted; similarly, if the matching degree between the fifth pseudo label and the sixth pseudo label is higher than the threshold, the fifth pseudo label and the sixth pseudo label can be fused, in the same way as the first pseudo label and the second pseudo label are fused to obtain the third pseudo label, which is not repeated here. To put it simply, the unlabeled sample set 132 and the expanded sample set 133 are matched pairwise: the pseudo label of each unlabeled sample is matched with the pseudo label of its corresponding expanded sample to obtain a matching degree, the two pseudo labels are fused if the matching degree is higher than the threshold, and both are deleted if it is not. In this way, pseudo labels with lower accuracy can be filtered out, which avoids using wrong or low-precision pseudo labels during semi-supervised training of the second model 126, thereby improving training efficiency and the accuracy of the finally obtained first model.
Step S340: The first unlabeled sample and the third pseudo label are used to train the second model, where the second model is initialized with the same weight parameters as the first model. This step can be implemented by the training unit 124 in Figure 2. To put it simply, the labeled sample set is used to train the machine learning model to obtain the first model, the weight parameters of the first model are then copied to the second model, and the second model is then trained using the first unlabeled sample and the third pseudo label.

Optionally, the model structure of the first model can be the same as the model structure of the second model. In a specific implementation, the weight parameters of the first model 125 are first copied to the second model 126, the second model 126 is then iteratively trained using the first unlabeled sample, the third pseudo label and the labeled samples used when training the first model 125, and the weight parameters of the first model are iteratively updated according to the weight parameters of the second model obtained in each iteration of training, so as to obtain the target model 127.
Specifically, the first unlabeled sample, the third pseudo label and the above-mentioned labeled samples are used to perform a first round of training on the second model 126 to obtain the weight parameters of the second model 126 after the first round of updates, which are then sent to the first model 125 to update it and obtain a new first model 125. The first unlabeled sample and the first expanded sample are then input into the new first model 125 to predict a new first pseudo label and a new second pseudo label, and training continues as described above.

It should be noted that, in the following rounds of iterative training, for example in the second round, the above-mentioned second unlabeled sample and second expanded sample can be input into the new first model 125 to obtain a new fifth pseudo label and a new sixth pseudo label. If the matching degree between the fifth pseudo label and the sixth pseudo label is still not higher than the threshold, the fifth pseudo label and the sixth pseudo label can again be deleted; if their matching degree is higher than the threshold, the fifth pseudo label and the sixth pseudo label can be fused to obtain a seventh pseudo label, and the seventh pseudo label, the second unlabeled sample and the labeled samples are then used to train the second model until convergence, the weight parameters of the second model after the second round of updates are obtained and updated into the first model 125, and so on. Similarly, a new first pseudo label and a new second pseudo label whose matching degree is not higher than the threshold can also be deleted; this is not repeated here.

Through such iteration, the accuracy of the pseudo labels inferred by the first model 125 becomes higher and higher, until the prediction accuracy of the first model 125 reaches the standard required by the user, so that the target model 127 is obtained and sent to the inference device 110.
Optionally, all of the weight parameters obtained by the second model in each round of training can be updated into the first model, or only some of them can be, so that the first model receives slow and stable weight updates; the first model obtained by such training is more robust and performs better. In the above implementation, the second model has the same model structure as the first model, so a small number of labeled samples, which are difficult to obtain, and a large number of unlabeled samples, which are easy to obtain, can be used to train the machine learning model; the resulting target model 127 not only has good robustness but also good model performance, and the training efficiency is high.
Optionally, the model structure of the first model may also contain the model structure of the second model; that is to say, the second model is a sub-model of the first model. Similarly, the second model synchronizes the updated weights obtained in each round of training to the first model, the new first model predicts new pseudo labels to train the second model, and so on, so that the second model and the first model are updated steadily; the second model finally obtained not only has low structural complexity, but also has model performance close to, or even better than, that of the first model, thereby achieving the purpose of model compression.
In this embodiment of the application, when training the second model, the input samples in the labeled sample set 131 can be input into the second model to obtain a first output value, and the first unlabeled sample can be input into the second model to obtain a second output value; the loss value of the second model is determined based on the first output value and the second output value, backpropagation is then performed on the second model based on the loss value until convergence to obtain the trained second model, the model parameters of the trained second model are synchronized to the first model, and the next round of model training is then performed. The above loss value L includes a labeled loss L1 and an unlabeled (pseudo label) loss L2, where the labeled loss L1 is obtained based on the gap between the first output value and the real label, and the pseudo label loss L2 is obtained based on the gap between the second output value and the third pseudo label.
Figure 4 is a schematic flowchart of the semi-supervised model training method provided by this application in one application scenario, and Figure 5 is a schematic diagram of the process of fusing the first pseudo label and the second pseudo label in that method. In this application scenario, the first model is a target detection model, and the second model is a target detection model with the same network structure as the first model. As shown in Figure 4, the semi-supervised model training method in this application scenario may include the following steps.
Step 1: Obtain a training sample set, which may include the labeled sample set 131 and the unlabeled sample set 132 shown in Figure 2.

Step 2: Determine whether each sample is a labeled sample. Specifically, each sample in the training sample set is judged and classified; step 3 is performed for labeled samples and step 4 for unlabeled samples.

Step 3: Use the labeled samples to train the first model, where the first model is the first model 125 in the embodiments of Figures 2 and 3. After step 3 is completed, step 5 or step 6 is performed. It should be noted that the number of training rounds can be controlled so that the first model acquires a certain detection capability while preventing the first model and the second model from overfitting during the subsequent semi-supervised training process.
Step 4: Perform data enhancement on the unlabeled samples to obtain the expanded samples. For details, reference can be made to step S320 in the embodiment of Figure 3; this step may be implemented by the data enhancement unit 121 in the embodiment of Figure 2. After step 4 is completed, step 7 is performed. The data enhancement method used in the application scenario shown in Figure 4 is to flip the first unlabeled sample vertically to the right. It should be noted that steps 3 and 4 can be processed in parallel or serially, which can be determined according to the processing capabilities of the computing devices on which the semi-supervised model training system 120 is deployed and is not specifically limited in this application.
Step 5: Copy the model parameters of the first model to the second model, where the second model is the second model 126 described in the embodiments of Figures 2 and 3.
Step 6: Input the unlabeled sample into the first model to obtain the first pseudo label. This step can be implemented by the inference unit 122 in the embodiment of Figure 2; for details, reference can be made to the description of step S310 in the embodiment of Figure 3, which is not repeated here.
Step 7: Input the expanded sample into the first model to obtain the second pseudo label. This step can be implemented by the inference unit 122 in the embodiment of Figure 2; for details, reference can be made to the description of step S320 in the embodiment of Figure 3, which is not repeated here. It should be noted that step 6 can be processed in parallel or serially with step 7, which is not specifically limited in this application.
Step 8: Perform the inverse operation of the data enhancement on the second pseudo label to obtain the fourth pseudo label. Here, the inverse operation of the data enhancement refers to the operation opposite to the data enhancement performed in step 4; for example, since step 4 flips the first unlabeled sample vertically to the right, step 8 flips the second pseudo label vertically to the left.
Step 9: Match the first pseudo label and the fourth pseudo label to obtain the matching degree. In a specific implementation, step 9 may match the detection frame corresponding to the first pseudo label with the detection frame corresponding to the fourth pseudo label, and the matching degree here may be the IOU between the two detection frames.
Step 10: Determine whether the matching degree is higher than the threshold; if the matching degree is higher than the threshold, step 11 is executed, and if it is not, step 14 is executed. For example, with a threshold of 0.45, step 10 fuses a first pseudo label and fourth pseudo label whose IOU is greater than 0.45, that is, performs step 11, and otherwise performs step 14 to delete them.
Step 11: Fuse the first pseudo label and the fourth pseudo label whose matching degree is higher than the threshold into the third pseudo label. The above steps 8 to 11 can be implemented by the matching unit 123 in the embodiment of Figure 2; for details, reference can be made to the description of step S330 in the embodiment of Figure 3, which is not repeated here.
Figure 5 is a schematic diagram of the process of fusing the first pseudo label and the second pseudo label into the third pseudo label in the semi-supervised model training method provided by this application, and corresponds to steps 6 to 11 in the embodiment of Figure 4. As can be seen from Figure 5, after step 4 expands the first unlabeled sample, step 6 inputs the first unlabeled sample into the first model 125 to obtain the first pseudo label, and step 7 inputs the expanded sample into the first model 125 to obtain the second pseudo label. At this point the accuracy of the first pseudo label and of the second pseudo label is low, and the detection frames corresponding to the two pseudo labels do not completely cover the target (that is, the vehicle). Moreover, Figure 5 clearly shows that the image labeled by the second pseudo label differs from the image labeled by the first pseudo label: in the unlabeled sample labeled by the first pseudo label the truck is in the upper lane, whereas in the expanded sample labeled by the second pseudo label the truck is in the lower lane, so step 10 cannot be performed directly to match the two. Step 8 therefore performs the inverse operation of the data enhancement on the second pseudo label, that is, flips it back, to obtain the fourth pseudo label; the samples labeled by the first pseudo label and by the fourth pseudo label are then the same sample, namely the truck in the upper lane. On this basis, step 10 is performed to match the first pseudo label and the fourth pseudo label, and the matching degree between the two is a good indicator of the accuracy of the first pseudo label. After the first pseudo label and the fourth pseudo label whose matching degree is higher than the threshold are fused, the accuracy of the pseudo label can be further improved, so the third pseudo label obtained in step 11 is more precise and accurate than the first pseudo label and the second pseudo label.
If step 10 determines that the matching degree between the first pseudo label and the second pseudo label is not higher than the threshold, step 14 can be performed to delete the first pseudo label, the second pseudo label and the fourth pseudo label, and the first unlabeled sample does not participate in this round of training; if in a later round the matching degree between its new first pseudo label and new second pseudo label is higher than the threshold, it can participate in the next round of training, and so on, which is not elaborated here.
Step 12: Train the second model using the third pseudo label, the unlabeled samples and the labeled samples.

Step 13: Introduce the trained weight parameters of the second model into the first model.

The above steps 12 and 13 can be implemented by the training unit 124 in the embodiment of Figure 2; for details, reference can be made to the description of step S340 in the embodiment of Figure 3, which is not repeated here. In a specific implementation, part of the weight parameters of the trained second model can be introduced into the first model according to the EMA method to obtain a new first model, and steps 6 to 14 continue to be performed, and so on, until the first model converges.
Step 14: Delete the first pseudo label and the second pseudo label. It should be noted that because steps 6 to 14 are executed repeatedly, even if the first pseudo label and the second pseudo label of the previous round are deleted, the new first pseudo label and new second pseudo label of the next round can continue to be matched and then fused or deleted, which makes the first pseudo label and the second pseudo label inferred by the first model more and more accurate, and thus makes the model training effect better and better. One round of this loop is sketched below.
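Putting the pieces together, one round of the loop in Figure 4 can be sketched as follows; student.fit, the .params attributes and the helpers filter_and_fuse and ema_update refer to the earlier sketches and are assumptions made for illustration.

```python
# Sketch of one training round tying together the steps of Figure 4.
def training_round(teacher, student, labeled_set, unlabeled_set,
                   augment, invert, match, fuse, threshold=0.45, ema=0.99):
    # Steps 6-11 / 14: build fused pseudo labels, discarding poor matches.
    pseudo_data = filter_and_fuse(teacher, unlabeled_set, augment, invert,
                                  match, fuse, threshold)   # sketched earlier
    # Step 12: train the student on labeled data plus fused pseudo labels.
    student.fit(labeled_set, pseudo_data)
    # Step 13: feed part of the student's weights back into the teacher (EMA).
    ema_update(teacher.params, student.params, ema)          # sketched earlier
    return teacher, student
```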
To sum up, the semi-supervised model training method provided by this application obtains expanded samples of the unlabeled samples by performing data enhancement on the unlabeled samples, and then inputs the unlabeled samples and the expanded samples into the first model for inference to obtain the first pseudo labels of the unlabeled samples and the second pseudo labels of the expanded samples. The first pseudo label and second pseudo label whose matching degree is higher than the threshold are then fused to obtain the third pseudo label, and the first pseudo label whose matching degree is lower than or equal to the threshold is filtered out, which improves the quality of the pseudo labels used for semi-supervised training and thus the training efficiency and the accuracy of the finally obtained model.
Figure 6 is a schematic structural diagram of a computing device provided by this application. The computing device 600 may be the semi-supervised model training system 120 in the embodiments of Figures 1 to 5. The computing device 600 includes a processor 601, a storage unit 602, a storage medium 603 and a communication interface 604, where the processor 601, the storage unit 602, the storage medium 603 and the communication interface 604 communicate through a bus 605, or communicate through other means such as wireless transmission.
The processor 601 may be composed of at least one general-purpose processor, for example a CPU or an NPU, or a combination of a CPU and a hardware chip. The above hardware chip may be an application-specific integrated circuit (ASIC), a programmable logic device (PLD) or a combination thereof, and the above PLD may be a complex programmable logic device (CPLD), a field-programmable gate array (FPGA), generic array logic (GAL) or any combination thereof. The processor 601 executes various types of digital storage instructions, such as software or firmware programs stored in the storage unit 602, which enables the computing device 600 to provide a wide variety of services.
In a specific implementation, the processor 601 may include one or more CPUs, such as CPU0 and CPU1 shown in Figure 6, and the computing device 600 may also include multiple processors, such as the processor 601 and the processor 606 shown in Figure 6. Each of these processors may be a single-core processor (single-CPU) or a multi-core processor (multi-CPU), and a processor here refers to one or more devices, circuits and/or processing cores for processing data (for example, computer program instructions).
The storage unit 602 is used to store program code, which is controlled and executed by the processor 601 so as to perform the processing steps of the semi-supervised model training system 120 in any of the above embodiments of Figures 1 to 5. The program code may include one or more software units, and the one or more software units may be the inference unit, the matching unit and the training unit in the embodiment of Figure 2, where the inference unit is used to input the first unlabeled sample and the first expanded sample into the first model, that is, to execute steps S310 to S320 in Figure 3 and steps 6 and 7 in Figures 4 and 5; the matching unit is used to execute step S330 in Figure 3 and the corresponding matching, fusion and deletion steps in Figures 4 and 5; and the training unit is used to execute step S340 in Figure 3 and steps 12 and 13 in Figure 4.
The storage unit 602 may include a read-only memory and a random access memory, and provides instructions and data to the processor 601; the storage unit 602 may also include a non-volatile random access memory. The storage unit 602 may be a volatile memory or a non-volatile memory, or may include both volatile and non-volatile memory. The non-volatile memory may be a read-only memory (ROM), a programmable ROM (PROM), an erasable programmable read-only memory (erasable PROM, EPROM), an electrically erasable programmable read-only memory (electrically EPROM, EEPROM) or a flash memory; the volatile memory may be a random access memory (RAM), which is used as an external cache.
Many forms of RAM may be used, such as static random access memory (static RAM, SRAM), dynamic random access memory (dynamic RAM, DRAM), synchronous dynamic random access memory (synchronous DRAM, SDRAM), double data rate synchronous dynamic random access memory (double data rate SDRAM, DDR SDRAM), enhanced synchronous dynamic random access memory (enhanced SDRAM, ESDRAM), synchronous link dynamic random access memory (synchlink DRAM, SLDRAM) and direct rambus random access memory (direct rambus RAM, DR RAM).
The storage medium 603 is a carrier for storing data, such as a hard disk, a USB flash drive (universal serial bus, USB), a flash memory (flash), an SD card (secure digital memory card, SD card), a memory stick, or the like; the hard disk may be a hard disk drive (HDD), a solid state disk (SSD), a mechanical hard disk, or the like, which is not specifically limited in this application.
The communication interface 604 may be a wired interface (for example, an Ethernet interface), an internal interface (for example, a high-speed serial computer expansion bus (peripheral component interconnect express, PCIe) bus interface), or a wireless interface (for example, a cellular network interface or a wireless local area network interface), and is used for communicating with other servers or units.
The bus 605 may be a peripheral component interconnect express (PCIe) bus, an extended industry standard architecture (EISA) bus, a unified bus (Ubus or UB), a compute express link (CXL) bus, a cache coherent interconnect for accelerators (CCIX) bus, or the like. The bus 605 can be divided into an address bus, a data bus, a control bus, and so on. In addition to the data bus, the bus 605 may further include a power bus, a control bus, a status signal bus, and the like; however, for clarity of description, the various buses are all labeled as the bus 605 in the figure.
It should be noted that Figure 6 is only one possible implementation of the embodiments of this application; in practical applications, the computing device 600 may include more or fewer components, which is not limited here. For contents not shown or not described in the embodiments of this application, refer to the relevant explanations in the embodiments of Figures 1 to 5, which are not described again here.
An embodiment of this application provides a computer storage medium in which instructions are stored; when the instructions are run on a computing device, the computing device is caused to execute the semi-supervised model training method described above in Figures 1 to 5.

An embodiment of this application provides a program product containing instructions; when the program or instructions are run on a computing device, the computing device is caused to execute the semi-supervised model training method described above in Figures 1 to 5.
The above embodiments may be implemented in whole or in part by software, hardware, firmware or any combination thereof. When software is used, the above embodiments may be implemented in whole or in part in the form of a computer program product. The computer program product includes at least one computer instruction, and the computer may be a general-purpose computer, a special-purpose computer, a computer network or another programmable apparatus. The computer instructions may be stored in a computer-readable storage medium or transmitted from one computer-readable storage medium to another; for example, they may be transmitted from one website, computer, server or data center to another website, computer, server or data center by wired means (such as coaxial cable, optical fiber or digital subscriber line (DSL)) or wireless means (such as infrared, radio or microwave). The computer-readable storage medium may be any usable medium that a computer can access, or a data storage node, such as a server or a data center, containing at least one usable medium. The usable medium may be a magnetic medium (for example, a floppy disk, hard disk or magnetic tape), an optical medium (for example, a high-density digital video disc (DVD)) or a semiconductor medium, and the semiconductor medium may be an SSD.

Landscapes

  • Engineering & Computer Science (AREA)
  • Data Mining & Analysis (AREA)
  • Theoretical Computer Science (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Artificial Intelligence (AREA)
  • Evolutionary Biology (AREA)
  • Evolutionary Computation (AREA)
  • Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Image Analysis (AREA)

Abstract

A semi-supervised model training method, system and related device. The method may include the following steps: inputting a first unlabeled sample into a first model to obtain a first pseudo label of the first unlabeled sample; inputting a first expanded sample into the first model to obtain a second pseudo label of the first expanded sample, where the first model is an artificial intelligence (AI) model trained with labeled samples and the first expanded sample is a sample obtained by performing data enhancement on the first unlabeled sample; obtaining a third pseudo label of the first unlabeled sample based on the first pseudo label and the second pseudo label; and training a second model using the first unlabeled sample and the third pseudo label, where the second model is an AI model obtained based on the weight parameters of the first model. The third pseudo label obtained in this way has higher accuracy, thereby improving the training efficiency of semi-supervised model training and the accuracy of the finally obtained model.

Description

一种半监督模型训练方法、系统及相关设备
本申请要求于2022年4月19日提交中国专利局、申请号为202210412186.6、发明名称为“一种半监督模型训练方法、系统及相关设备”的中国专利申请的优先权,所述专利申请的全部内容通过引用结合在本申请中。
技术领域
本申请涉及人工智能(artificial intelligence,AI)技术领域,尤其涉及一种半监督模型训练方法、系统及相关设备。
背景技术
半监督学习(semi-supervised learning,SSL)指的是使用有标签样本和无标签样本对AI模型(本申请也简称为“模型”)进行训练的方法,通过半监督学习,可以有效减少有标签样本数量,降低模型训练的成本。
通常情况下,半监督学习采用第一-第二模型(又可称为教师-学生模型,其中,第一模型为教师模型,第二模型为学生模型)的半监督模型训练方法。具体地:可先使用有标签样本对第一模型进行训练,然后将无标签样本输入至经有标签样本训练后的第一模型,推理出无标签样本的伪标签,再将第一模型的权重参数复制给结构相同的第二模型,使用上述伪标签以及无标签样本对第二模型进行训练,获得更新后的第二模型的权重参数,然后将第二模型的部分权重参数更新给第一模型。进而,在第二轮迭代中,第一模型推理新的伪标签样本给第二模型进行训练,以此迭代,使得第一模型的权重参数得以稳定更新,最终训练获得的第一模型鲁棒性更强、模型性能更强。
但是,由于通常无标签样本的数量远高于有标签样本,伪标签样本对第二模型训练的影响很大,若第一模型生成的伪标签误差较大时,会导致第二模型的训练效率低,模型性能差,进而影响最终获得第一模型的训练效率以及模型性能。
发明内容
本申请提供了一种半监督模型训练方法、系统及相关设备,用于解决半监督学习过程中,伪标签质量差导致模型训练效率低、模型训练差的问题。
第一方面,提供了一种半监督模型训练方法,该方法可包括以下步骤:将第一无标签样本输入第一模型,获得第一无标签样本的第一伪标签,将第一扩充样本输入第一模型,获得第一扩充样本的第二伪标签,其中,第一模型为采用有标签样本进行训练后的人工智能AI模型,第一扩充样本为对第一无标签样本进行数据增强后获得的样本,根据第一伪标签和第二伪标签,获得第一无标签样本的第三伪标签,使用第一无标签样本和第三伪标签对第二模型进行训练,其中,第二模型是根据第一模型的权重参数获得的AI模型。
实施第一方面描述的方法,通过对第一无标签样本进行数据增强获得第一无标签样本的第一扩充样本,然后将第一无标签样本和第一扩充样本输入第一模型推理获得第一无标签样本的第一伪标签以及第一扩充样本的第二伪标签,然后根据第一伪标签和第二伪标签获得第三伪标签,这样获得的第三伪标签质量更高,使用第三伪标签对第二模型进行半监督训练时,模型的训练效率和性能得以提升,进而提升最终获得的第一模型的训练效率以及模型性能。
在一可能的实现方式中,对第一无标签样本进行数据增强时,数据增强方法可包括但不限于翻转变换(flip)、平移变换(shift)、尺度变换(scale)、旋转变换/反射变换(rotation/reflection)、缩放变换(zoom)、修剪(crop)、颜色变换(color space)、噪声扰动(noise)、内核过滤(kernel filters)中的一种或者多种。
其中,翻转变换指的是对图像进行水平或垂直翻转,水平翻转还可分为向上水平翻转和向下水平翻转,垂直翻转还可分为向左垂直翻转和向右垂直翻转;平移变换指的是对图像进行平移操作,比如x方向向右平移(param xoffset),y方向向下平移(param yoffset),其中x方向和y方向指的是图像坐标系的横轴方向和纵轴方向;旋转变换也可称为反射变换,指的是对图像进行某个角度的旋转,该角度可以是0~360度中的任意角度;缩放变换指的是将图像按照一定比例进行放大或者缩小,而不会改变图像中的内容;修剪也可称为裁剪,包括统一裁剪和随机裁剪,统一裁剪指的是将不同尺寸的图像裁剪至设定大小,随机裁剪指的是将不同尺寸的图像随机裁剪成不同尺寸大小;颜色变换指的是对图像某种颜色通道进行修改,比如关闭通道或者改变通道亮度值,举例来说,图像通常包括RGB三个通道,颜色变换可以将R通道值减少或增大;噪声扰动指的是从高斯分布中采样出的随机值矩阵加入到图像的RGB像素矩阵中;内核过滤指的是使用特定功能的内核过滤器与图像进行卷积操作,比如锐化、模糊等内核过滤器。
应理解,上述数据增强方法用于举例说明,本申请还可通过其他数据增强的方法对第一无标签样本进行扩充获得第一扩充样本,比如对图像进行增强还可以通过对抗生成(adversarial training)、特征空间增强(feature space augmentation)、基于GAN的数据增强(gan-based data augmentation)等数据增强方法,本申请不一一展开赘述。
需要说明的,由于目标检测模型的标签通常为检测框(bounding box),因此在在第一模型是目标检测模型时,可使用翻转变换、平移变换、尺度变换、旋转变换、缩放变换等对检测框产生影响的数据增强方法,对第一无标签样本进行数据增强;由于图像识别模型的标签为图像所属类别的概率分布,因此在第一模型是图像识别模型时,可使用修剪、颜色变换、噪声扰动、内核过滤等对图像类别判定产生影响的数据增强方法,对第一无标签样本进行数据增强。
上述实现方式,通过针对模型类型进行不同的数据增强操作,可以使得最终获得的扩充样本可以增加模型的泛化性能,提高模型的鲁棒性。
在一可能的实现方式中,可根据第一伪标签和第二伪标签,获得第一伪标签和第二伪标签之间的匹配度;在匹配度高于阈值的情况下,将第一伪标签和第二伪标签进行融合,获得第三伪标签。
可选地,第一模型为目标检测模型时,目标检测模型的输出结果可能是多个目标检测框,可以通过非极大抑制(non maximum suppression,NMS)方法,选择多个目标检测框中精度最高的检测框作为第一伪标签或者第二伪标签,从而增加第一无标签样本的第一伪标签和扩充样本的第二伪标签的精度。
可选地,若第一模型为目标检测模型,在将第一伪标签和第二伪标签进行匹配时,可以对第二伪标签进行数据增强的逆操作,获得第四伪标签,然后将第四伪标签与第一伪标签进行匹配,获得第一伪标签和第四伪标签之间的匹配结果,根据上述匹配结果确定上述匹配度。其中,逆操作指的是与数据增强单元执行的数据增强方法相反的操作,比如数据增强方法是对第一无标签样本进行了水平向上翻转操作获得第一扩充样本,那么匹配单元此时可以对第一扩充样本的第二伪标签对应的标准框进行水平向下翻转操作获得第四伪标签,再比如数据 增强方法是对第一无标签样本进行了向右旋转90°操作获得第一扩充样本,那么匹配单元123可以对第一扩充样本的第二伪标签对应的标准框进行向左旋转90操作获得第四伪标签,以此类推,这里不一一举例说明。
上述实现方式,目标检测模型的伪标签是图像中目标的检测框,因此通过数据增强方法获得的第一扩充样本,目标的位置实际已发生改变,检测框需要对其进行逆操作,使得第一伪标签和第二伪标签所框选的目标是同一个位置的目标,然后再将其进行匹配,可以筛选出标注目标不准确的第一伪标签,从而避免对第二模型进行半监督训练时使用到错误或者精度低的伪标签,从而提高训练效率、提高最终获得的第一模型的精度。
在一可能的实现方式中,若第一模型为目标检测模型,可以将第一伪标签对应的检测框与第四伪标签对应的检测框进行匹配,获得上述匹配度,这里的匹配度可以是两个检测框之间的交并比(intersection over union,IOU)。
具体实现中,在匹配度大于阈值时,匹配单元可以将第一伪标签和第二伪标签进行融合获得第三伪标签时,可以将上述第四伪标签对应的检测框与第一伪标签对应的检测框进行多值平均处理,获得第三伪标签。
上述实现方式,将第一伪标签对应的检测框与第四伪标签对应的检测框进行融合,两个检测框是使用不同方法确定的目标检测框,因此将二者融合可以进一步提高最终获得的第三伪标签的精度,从而提高后续半监督训练过程中所使用伪标签的精度,从而提高训练效率、提高最终获得的第一模型的精度。
在一可能的实现方式中,若第一模型为图像识别模型,可以将第一伪标签和第二伪标签进行匹配,获得第一伪标签和第二伪标签之间的匹配结果,根据该匹配结果确定二者之间的匹配度。
具体实现中,可以将第一伪标签对应的概率分布与第二伪标签对应的概率分布进行匹配,确定二者之间的匹配度,这里的匹配度可以是两个概率分布之间的相似度或者距离,本申请不作具体限定。将匹配度大于阈值的第一伪标签和第二伪标签进行融合时,可以将第一伪标签的概率分布与第二伪标签的概率分布进行均值处理,比如平均数、加权平均等等,本申请不作具体限定。
上述实现方式,在第一模型为图像识别模型时,将第一伪标签和第二伪标签进行匹配,大于阈值的情况下将二者进行融合,可以进一步提高最终获得的第三伪标签的精度,从而提高后续半监督训练过程中所使用伪标签的精度,从而提高训练效率、提高最终获得的第一模型的精度。
在一可能的实现方式中,也可以将第二无标签样本输入第一模型,获得第二无标签样本的第五伪标签,然后将第二扩充样本输入第一模型,获得第二扩充样本的第六伪标签,这里,第二扩充样本为对第二无标签样本进行数据增强后获得的样本。可以根据第五伪标签和第六伪标签,获得第五伪标签和第六伪标签之间的匹配度,在匹配度不高于上述阈值的情况下,删除第五伪标签和第六伪标签。其中,上述第五伪标签和第六伪标签之间匹配度的确定方式可以参考前述内容中第一伪标签和第二伪标签之间匹配度的确定方式,这里不重复展开赘述。
需要说明的,第一伪标签和第二伪标签之间的匹配度如果不高于阈值,也可以将第一伪标签和第二伪标签删除,同理,如果第五伪标签和第六伪标签之间的匹配度高于阈值,也可以将第五伪标签和第六伪标签进行融合,融合方式和参考前述内容中关于第一伪标签和第二伪标签融合获得第三伪标签的描述这里不重复赘述。简单来说,无标签样本集和扩充样本集进行匹配,呈对应关系的无标签样本的伪标签和扩充样本的伪标签会进行匹配获得相应的匹 配度,若匹配度高于阈值则将二者的伪标签进行融合,匹配度低于阈值则将二者的伪标签都进行删除。
上述实现方式,通过将无标签样本集的伪标签和扩充样本集的伪标签进行匹配的方式,可以过滤出准确度较低的伪标签,从而避免对第二模型进行半监督训练时使用到错误或者精度低的伪标签,从而提高训练效率、提高最终获得的第一模型的精度。
在一可能的实现方式中,第一模型的模型结构可以与第二模型的模型结构相同。
可选地,在对第二模型进行训练时,可先将第一模型的权重参数拷贝给第二模型,然后使用第一无标签样本、第三伪标签以及上述训练第一模型时使用的有标签样本对第二模型进行迭代训练,根据每次迭代训练获得的第二模型的权重参数对第一模型的权重参数进行迭代更新,获得目标模型。
具体地,使用第一无标签样本、第三伪标签以及上述有标签样本对第二模型进行第一轮训练,获得第一轮更新后第二模型的权重参数,然后将其发送给第一模型进行对第一模型的更新获得新的第一模型,再将上述第一无标签样本和第一扩充样本输入新的第一模型,预测出新的第一伪标签和新的第二伪标签,再将匹配度高于阈值的新的第一伪标签和新的第二伪标签进行融合获得新的第三伪标签,再使用第一无标签样本和新的第三伪标签以及有标签样本继续对第二模型训练至收敛,获得第二轮更新后的权重参数,然后再将其更新至第一模型,以此类推,这里不一一展开赘述。
需要说明的,在接下来多轮迭代训练过程中,比如第二轮训练时,可以将上述第二无标签样本和第二扩充样本输入新的第一模型,获得新的第五伪标签和新的第六伪标签,如果第五伪标签和第六伪标签之间的匹配度仍然不高于阈值,可以继续将第五伪标签和第六伪标签删除;如果第五伪标签和第六伪标签之间的匹配度高于阈值,此时可以将第五伪标签和第六伪标签进行融合获得第七伪标签,然后使用第七伪标签、第二无标签样本以及有标签样本对第二模型进行训练至收敛,获得第二轮更新后第二模型的权重参数,再将其更新至第一模型,以此类推。
同理,如果接下来多轮迭代训练过程中,比如第二轮训练时,新的第一伪标签和新的第二伪标签之间的匹配度不高于阈值,也可以将新的第一伪标签和新的第二伪标签删除,这里不重复赘述。
上述实现方式,通过将第二模型的权重参数对第一模型进行更新,然后多轮迭代训练的方式,使得第一模型推理出的伪标签精度越来越高,直至第一模型的预测精度达到用户所需的标准,从而获得目标模型。
在一可能的实现方式中,第二模型每轮训练获得的权重参数可以全部更新至第一模型,也可以将每轮训练获得的部分权重参数更新至第一模型,使得第一模型得到缓慢、稳定的权重更新,这样训练获得的第一模型更具鲁棒性,模型性能更佳。
具体实现中,可通过指数滑动平均(exponential moving average,EMA)方法将学生的权重更新给第一模型,举例来说,假设EMA=0.99,那么每轮训练获得的权重参数的1%将被更新入第一模型。应理解,上述举例用于说明,本申请不作具体限定。
上述实现方式,第二模型与第一模型的模型结构相同,可以使用少量的、获取困难的有标签样本和大量的、容易获取的无标签样本对机器学习模型进行训练,获得的目标模型不仅鲁棒性好,而且模型性能好,训练效率高。
在一可能的实现方式中,第一模型的模型结构也可以包括第二模型的模型结构,也就是说,第二模型是小模型,第一模型是大模型,比如第二模型是第一模型的一个子模型,同样 的,第二模型将每轮训练获得的更新后的权重同步至第一模型,新的第一模型再预测新的伪标签对第二模型进行训练,以此类推,稳步更新第二模型和第一模型。或者,不需要多轮训练,第一轮在第二模型训练收敛后,将训练好的第二模型作为目标模型。
上述实现方式,使用的第二模型是小模型,第一模型是大模型,这样最终获得的第二模型不仅结构复杂度低,而且模型性能与第一模型趋于接近,甚至可以比第一模型性能更好,从而达到模型压缩的目的。
在本申请实施例中,使用有标签样本、第一无标签样本和第三伪标签对第二模型进行训练时,可以将有标签样本集中的输入样本输入第二模型获得第一输出值,将第一无标签样本输入第二模型获得第二输出值,根据第一输出值和第二输出值确定第二模型的损失值,然后根据损失值对第二模型进行反向传播直至收敛,获得训练好的第二模型,然后将训练好的第二模型的模型参数同步至第一模型中,再进行下一轮的模型训练。其中,上述损失值L包括有标签损失L1和无标签损失L2,该损失值L是根据第一输出值和真实标签之间的差距获得的,伪标签损失是根据第二输出值和第三伪标签之间的差距获得的。
具体实现中,可以通过系数加权的方式对有标签损失L1和无标签损失L2在损失值L中的比重进行调控,比如损失值L=L1+λL2,其中,λ越大,无标签损失L2在损失值L中的占比越大,第一无标签样本对第二模型和第一模型的模型性能影响越大,通过对λ的值进行调整,可以认为干预模型训练的学习方向。
上述实现方式,通过有标签损失和无标签损失共同影响第二模型的训练方向,并且无标签损失是基于上述过滤和融合后获得的第三伪标签与输出值之间的差距确定的,使得第二模型在半监督学习过程中,可以使用大量的无标签样本进行训练,从而降低样本获取的成本,同时不影响最终获得目标模型的性能。
在一可能的实现方式中,上述半监督模型训练方法也可以打包成一个软件模块,对现有的一些模型训练的设备进行软件升级,使其能够拥有对伪标签过滤、融合的功能,使得升级后的模型训练系统可以有更好的半监督训练功能。
举例来说,在公有云场景下,用于实现上述半监督模型训练方法的各个单元模块可以打包为一个配置模块,作为公有云模型训练服务中的一个小的配置功能,如果公有云用户购买该功能,即可为用户提供相应的权限。在非公有云场景下,用于实现上述半监督模型训练方法的各个单元模块可以打包为一个微服务或者软件包,用户购买本申请提供的伪标签过滤、融合功能之后,可以向用户提供相应的权限的许可(license),不同权限可设置不同的收费程度。本申请不作具体限定。
上述实现方式,通过软件打包为微服务、提供licencse或者提供云服务的方式,不仅用户获取方法简单快捷,而且开发者可以对原有的模型训练系统进行简单的软件升级即可实现上述各种功能,对开发者来说升级、维护都十分便捷,本申请提供的半监督模型训练方法部署方便,可用性高。
第二方面,提供了一种半监督模型训练系统,该系统包括:推理单元,用于将第一无标签样本输入第一模型,获得第一无标签样本的第一伪标签;推理单元,用于将第一扩充样本输入第一模型,获得扩充样本的第二伪标签,其中,第一模型为采用有标签样本进行训练后的人工智能AI模型,第一扩充样本为对第一无标签样本进行数据增强后获得的样本;匹配单元,用于根据第一伪标签和第二伪标签,获得第一无标签样本的第三伪标签;训练单元,用于使用第一无标签样本和第三伪标签对第二模型进行训练,其中,第二模型是根据第一模型的权重参数获得的AI模型。
实施第二方面描述的方法,本申请提供的模型训练系统,通过对无标签样本集进行数据增强获得无标签样本集的扩充样本集,然后将无标签样本集和扩充样本集输入第一模型推理获得无标签样本的集的多个第一伪标签以及扩充样本集的多个第二伪标签,然后将匹配度高于阈值的第一伪标签和第二伪标签进行融合获得第三伪标签,将匹配度低于或等于阈值的第一伪标签进行过滤,从而提高未标注样本集合的伪标签的质量,使得后续使用第三伪标签对学生模型进行半监督训练时,模型的训练效率和性能得以提升,进而提升最终获得的第一模型的训练效率以及模型性能。
在一可能的实现方式中,第二模型与第一模型具有相同的结构。
在一可能的实现方式中,匹配单元,用于根据第一伪标签和第二伪标签,获得第一伪标签和第二伪标签之间的匹配度;匹配单元,用于在匹配度高于阈值的情况下,将第一伪标签和第二伪标签进行融合,获得第三伪标签。
在一可能的实现方式中,第一模型包括目标检测模型,数据增强方法包括翻转变换、平移变换、尺度变换、旋转变换、缩放变换中的一种或者多种。
在一可能的实现方式中,匹配单元,用于对第二伪标签进行数据增强的逆操作,获得第四伪标签;匹配单元,用于对第一伪标签和第四伪标签进行匹配,获得第一伪标签和第四伪标签之间的匹配结果;匹配单元,用于根据第一伪标签与第四伪标签之间的匹配结果确定匹配度。
在一可能的实现方式中,第一模型包括图像识别模型,数据增强方法包括修剪、颜色变换、噪声扰动、内核过滤中的一种或者多种。
在一可能的实现方式中,匹配单元,用于对第一伪标签和第二伪标签进行匹配,获得第一伪标签和第二伪标签之间的匹配结果;匹配单元,用于根据第一伪标签和第二伪标签之间的匹配结果获得匹配度。
在一可能的实现方式中,推理单元,用于将第二无标签样本输入第一模型,获得第二无标签样本的第五伪标签;推理单元,用于将第二扩充样本输入第一模型,获得第二扩充样本的第六伪标签,第二扩充样本为对第二无标签样本进行数据增强后获得的样本;匹配单元,用于根据第五伪标签和第六伪标签,获得第五伪标签和第六伪标签之间的匹配度;匹配单元,用于在匹配度不高于阈值的情况下,删除第五伪标签和第六伪标签。
在一可能的实现方式中,训练单元,用于使用有标签样本、第一无标签样本和第三伪标签样本对第二模型进行迭代训练,根据每次迭代训练获得的第二模型的权重参数对第一模型的权重参数进行迭代更新,获得目标模型。
在一可能的实现方式中,训练单元,用于将输入样本输入第二模型获得第一输出值,将第一无标签样本输入第二模型获得第二输出值,根据第一输出值和第二输出值确定第二模型的损失值,其中,损失值包括有标签损失和伪标签损失,有标签损失是根据第一输出值和真实标签之间的差值获得的,伪标签损失是根据第二输出值和第三伪标签之间的差值获得的;训练单元,用于根据损失值对第二模型进行迭代训练。
第三方面,提供了一种计算设备,该计算设备包括处理器和存储器,存储器存储有代码,处理器包括用于执行第一方面或第一方面任一种可能实现方式描述的方法。
第四方面,提供了一种计算机存储介质,所述存储介质中存储有指令,当其在计算设备上运行时,使得计算设备执行第一方面或第一方面任一种可能实现方式描述的方法。
第五方面,提供了一种计算机程序指令,该计算机程序指令在计算设备上运行时,使得计算设备执行第一方面或第一方面任一种可能实现方式描述的方法。
本申请在上述各方面提供的实现方式的基础上,还可以进行进一步组合以提供更多实现方式。
附图说明
图1是一种半监督学习的步骤流程示意图;
图2是本申请提供的一种半监督模型训练系统的架构示意图;
图3是本申请提供的一种半监督模型训练方法的步骤流程示意图;
图4是本申请提供的一种半监督模型训练方法在一应用场景下步骤流程示意图;
图5是本申请提供的半监督模型训练方法中第一伪标签和第二伪标签的融合流程示意图;
图6是本申请提供的一种计算设备的结构示意图。
具体实施方式
下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行描述,显然,所描述的实施例仅仅是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域普通技术人员在没有做出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
首先,对本申请涉及的术语进行简单解释。
有标签样本(labeled)和无标签样本(unlabeled):有标签样本指的是拥有标签(label)的样本,该样本的标签表示该样本的真实值,该真实值用于在模型训练时与模型的预测值一起计算损失值,进而对模型的权重参数进行调整。例如:在分类模型训练时,将有标签样本作为输入数据输入至初始分类模型后,初始分类模型提供的预测结果与该有标签样本的标签进行比较,得到该轮训练的损失值,进而根据损失值可以调整分类模型的权重参数。相反,未标注样本即为不包含标签的样本。
损失函数(loss function):损失函数用于在模型训练过程中,评估模型的输出结果与样本标签之间的差距,损失值(loss)即为损失函数对应的值。损失值越低,模型的鲁棒性越好,因此,模型训练过程中通常会在样本输入模型获得输出值后,根据输出值与样本标签之间的差距确定损失值,根据损失值的大小对模型的权重参数进行调整,以此迭代,直至模型的损失函数最小化,获得目标模型。
其次,对本申请涉及的“半监督学习”应用场景进行说明。
AI是利用数字计算机或者数字计算机控制的机器模拟、延伸和扩展人的智能,感知环境、获取知识并使用知识获得最佳结果的理论、方法、技术及应用系统。人工智能领域的应用场景包括机器人,自然语言处理,计算机视觉,决策与推理,人机交互,推荐与搜索等。
一般来说,AI领域的各种应用通过AI模型实现,而AI模型则是通过样本集对模型训练获得。其中,半监督学习/半监督训练是降低模型训练时所需的有标签样本集数量的一种重要方法,因理解,有标签样本集通常是人为标注的样本集,因此有标签样本集的数量越多,模型训练的效率越低、成本越高,因此,半监督学习在标注预算受限下,具有很强的实践价值。
图1是一种半监督学习的步骤流程示意图,如图1所示,半监督学习过程使用的样本集包括有标签样本101和无标签样本102,第一模型112以及第二模型113的网络结构相同,具体地,半监督学习的步骤流程可以如下:
步骤1、使用有标签样本101对初始第一模型进行训练,模型收敛后获得第一模型112。其中,初始第一模型可以是网络初始化后,还未进行训练的AI模型。
步骤2、将无标签样本102输入第一模型112,第一模型112推理出无标签样本102的伪标签,获得伪标签样本103。
步骤3、再将第一模型112的权重参数复制给第二模型113。
步骤4、使用上述有标签样本101和伪标签样本103对第二模型113进行训练,获得更新后的第二模型的权重参数。
具体实现中,使用上述有标签样本101和伪标签样本103对第二模型113进行训练时,模型的损失包括有标签损失和伪标签损失,其中,有标签损失是根据模型输出与有标签样本101的标签之间的差值获得的,伪标签损失是根据模型输出与伪标签样本103的伪标签之间的差值获得的。
步骤5、将更新后的第二模型的权重参数反馈给第一模型,对第一模型的权重参数进行更新。
重复步骤2~步骤5,直至模型收敛,获得训练好的目标模型114。这样,第一模型可以缓慢、稳定的进行权重更新,更具鲁棒性,模型性能更强。
但是,由于有标签样本101需要人工标注,数量有限,通常情况下,无标签样本102的数量远大于有标签样本101,因此第二模型113的训练过程中,伪标签样本103占比远大于有标签样本101,伪标签损失的权重也远高于有标签损失,导致伪标签样本的优劣决定了第二模型和第一模型的优化方向,而伪标签是第一模型推理得到的标签,伪标签会出现较多错误和噪声,半监督学习又无法对伪标签进行校正,若人工对伪标签进行校正,不仅效率低下且人力成本太高。
综上可知,半监督学习过程中,由于无标签样本的数量远高于有标签样本,伪标签样本对第二模型训练的影响很大,若第一模型生成的伪标签误差较大时,会导致第二模型的训练效率低,模型性能差,进而影响最终获得第一模型的训练效率以及模型性能。
为了解决上述问题,本申请提供了一种半监督模型训练系统,该系统通过对无标签样本进行数据增强获得无标签样本的扩充样本,然后将无标签样本和扩充样本输入第一模型推理获得无标签样本的第一伪标签以及扩充样本的第二伪标签,然后根据第一伪标签和第二伪标签的匹配结果获得第三伪标签,这样获得的第三伪标签质量更高,使用第三伪标签对第二模型进行半监督训练时,模型的训练效率和性能得以提升,进而提升最终获得的第一模型的训练效率以及模型性能。
图2是本申请提供的一种半监督模型训练系统的架构示意图,如图2所示,该半监督模型训练系统的架构包括推理设备110、半监督模型训练系统120、用户设备140、数据存储系统150以及数据采集设备160。其中,推理设备110、半监督模型训练系统120、用户设备140、数据存储系统150以及数据采集设备160之间可建立通信连接,具体可通过有线网络或者无线网络的方式建立通信连接,本申请不作具体限定。
数据采集设备160用于采集原始样本,将其发送给半监督模型训练系统120进行模型训练,其中,该原始样本可以是图像类型的原始样本,数据采集设备160可包括图像采集装置、雷达采集装置等用于采集原始样本的传感器等,图像采集装置可以是监控摄像头、电子警察、深度摄像机、无人机等等,雷达采集装置可以是雷达、卫星等,本申请不作具体限定。应理解,在不同的应用场景下,训练机器学习模型所需的原始样本不同,对应的数据采集设备160 也不同,本申请不对此进行具体限定。
半监督模型训练系统120用于接收数据采集设备160采集的原始样本,对原始样本进行处理后获得图2所示的有标签样本集131、无标签样本集132以及扩充样本集133,使用上述样本集对第一模型125和第二模型126进行训练后,获得训练好的目标模型127,并将其发送给推理设备110。
其中,半监督模型训练系统120可以部署于计算设备上,该计算设备可以是裸金属服务器(bare metal server,BMS)、虚拟机或容器。其中,BMS指的是通用的物理服务器,例如,ARM服务器或者X86服务器;虚拟机指的是网络功能虚拟化(network functions virtualization,NFV)技术实现的、通过软件模拟的具有完整硬件系统功能的、运行在一个完全隔离环境中的完整计算机系统,容器指的是一组受到资源限制,彼此间相互隔离的进程,计算设备还可以是边缘计算设备,本申请不作具体限定。可选地,半监督模型训练系统120也可以是服务器集群,比如集中式服务器或者分布式服务器。
可选地,半监督模型训练系统120也可部署于公有云中,作为一项模型训练的云服务提供给公有云用户,用户可通过购买该服务获取半监督模型训练系统120的使用权限,本申请不作具体限定。
可选地,半监督模型训练系统120也可以通过软件打包的方式提供给用户,用户将软件安装于自己的计算设备上,或者,以微服务的方式提供给用户。举例来说,软件打包的方式提供给用户后,用户可根据自己的需求购买所需的软件版本或者能力,获得相应权限的许可(license),不同权限可设置不同的收费。
推理设备110用于接收半监督模型训练系统120发送的目标模型127,使用目标模型127对用户设备140发送的输入数据进行推理,获得输出数据,并将其返回给用户设备140,或者,将其存储于数据存储系统150。上述推理设备110可以是计算设备,具体可以是BMS、虚拟机、容器、终端设备或者边缘计算设备,本申请不作具体限定。
用户设备140可以是用户所持有的终端设备,包括计算机、智能手机、掌上处理设备、平板电脑、移动笔记本、增强现实(augmented reality,AR)设备、虚拟现实(virtual reality,VR)设备、一体化掌机、穿戴设备、车载设备、智能会议设备、智能广告设备、智能家电等等,此处不作具体限定。
可选地,用户设备140也可以是数据采集设备,其采集的输入数据可以输入至推理设备110进行目标检测或者图像识别,输出结果存储至数据存储系统150。举例来说,用户设备140可以是道路上的电子警察,推理设备110是道路两侧的边缘计算设备,目标模型是车牌识别模型,数据存储系统是交警大队维护的数据库,电子警察采集的超速车辆图片可以输入边缘计算设备中的车牌识别模型,识别出超速车辆图片中超速车辆的车牌号,将其存储于交警大队维护的数据库中,应理解,上述举例用于说明,本申请不作具体限定。
可选地,用户设备140与推理设备110可以是同一个设备,比如用户的智能手机下载半监督模型训练系统120训练好的人脸识别模型后,用户通过智能手机上的摄像头采集人脸输入数据,将其输入至人脸识别模型,获得人脸识别结果,人脸识别结果可以直接显示于用户设备140上,或者,存储于远程服务器以便后续进行安全解锁、安全支付等认证匹配,本申请不作具体限定。
数据存储系统150可以是具有存储功能的服务器或者存储阵列,该服务器可以是物理服务器比如ARM服务器或者X86服务器,还可以是虚拟机,本申请不作具体限定。数据存储系统150用于存储推理设备110的输出数据。
进一步地,半监督模型训练系统120可进一步划分为多个单元模块,图2是一种示例性划分方式,如图2所示,半监督模型训练系统120可包括数据增强单元121、推理单元122、匹配单元123以及训练单元124,其中,数据增强单元121、推理单元122、匹配单元123以及训练单元124,之间建立通信连接,具体可以是有线连接或者无线连接,本申请不作具体限定。需要说明的,样本数据库130可以是如图2所示的,存储于半监督模型训练系统120内,也可以存储于半监督模型训练系统120的外部存储器中,本申请不作具体限定。
半监督模型训练系统120还可包括样本数据库130,其中,样本数据库130包括无标签样本集132、有标签样本集131以及扩充样本集133,上述无标签样本集132可以包括多个没有标签的无标签样本,该无标签样本可以是数据采集设备160采集的原始样本,或者,对原始样本进行数据预处理之后获得的(比如裁剪、降噪等提高样本质量的预处理手段);有标签样本集131可以包括多个有标签样本,每个有标签样本包括输入样本和真实标签,其中,输入样本可以是上述原始样本,也可以是对原始样本进行数据预处理之后获得的,输入样本的真实标签可以是人工标注后获得的;扩充样本集133包括多个没有标签的扩充样本,该扩充样本是半监督模型训练系统120的数据增强单元121对无标签样本集132进行数据增强后获得的。
数据增强单元121可通过数据增强方法对无标签样本集132中的无标签样本进行数据增强,获得无标签样本对应的扩充样本,以此类推,获得扩充样本集,其中,数据增强方法可包括但不限于翻转变换(flip)、平移变换(shift)、尺度变换(scale)、旋转变换/反射变换(rotation/reflection)、缩放变换(zoom)、修剪(crop)、颜色变换(color space)、噪声扰动(noise)、内核过滤(kernel filters)中的一种或者多种。
其中,翻转变换指的是对图像进行水平或垂直翻转,水平翻转还可分为向上水平翻转和向下水平翻转,垂直翻转还可分为向左垂直翻转和向右垂直翻转;平移变换指的是对图像进行平移操作,比如x方向向右平移(param xoffset),y方向向下平移(param yoffset),其中x方向和y方向指的是图像坐标系的横轴方向和纵轴方向;旋转变换也可称为反射变换,指的是对图像进行某个角度的旋转,该角度可以是0~360度中的任意角度;缩放变换指的是将图像按照一定比例进行放大或者缩小,而不会改变图像中的内容;修剪也可称为裁剪,包括统一裁剪和随机裁剪,统一裁剪指的是将不同尺寸的图像裁剪至设定大小,随机裁剪指的是将不同尺寸的图像随机裁剪成不同尺寸大小;颜色变换指的是对图像某种颜色通道进行修改,比如关闭通道或者改变通道亮度值,举例来说,图像通常包括RGB三个通道,颜色变换可以将R通道值减少或增大;噪声扰动指的是从高斯分布中采样出的随机值矩阵加入到图像的RGB像素矩阵中;内核过滤指的是使用特定功能的内核过滤器与图像进行卷积操作,比如锐化、模糊等内核过滤器。
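示例性地,上述几类数据增强操作可以用如下基于numpy的示意代码表示(函数名、参数取值均为本示例的假设,并非本申请限定的实现方式):

```python
import numpy as np

def augment(image: np.ndarray, method: str = "hflip", k: int = 1,
            noise_std: float = 0.05) -> np.ndarray:
    """对形状为 (H, W, C) 的图像样本执行一种数据增强,返回扩充样本(示意实现)。"""
    if method == "hflip":      # 左右方向的翻转
        return image[:, ::-1, :]
    if method == "vflip":      # 上下方向的翻转
        return image[::-1, :, :]
    if method == "rotate":     # 逆时针旋转 90° * k
        return np.rot90(image, k=k, axes=(0, 1))
    if method == "noise":      # 高斯噪声扰动:随机值矩阵加到像素矩阵上
        noise = np.random.normal(0.0, noise_std, size=image.shape)
        return np.clip(image + noise, 0.0, 1.0)
    raise ValueError(f"未知的数据增强方法: {method}")
```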
应理解,上述数据增强方法用于举例说明,本申请还可通过其他数据增强方法对无标签样本进行扩充获得扩充样本,比如还可以通过对抗生成(adversarial training)、特征空间增强(feature space augmentation)、基于GAN的数据增强(gan-based data augmentation)等数据增强方法对无标签样本集132进行扩充,获得扩充样本集133,本申请不一一展开赘述。
需要说明的,由于目标检测模型的标签通常为检测框(bounding box),因此在第一模型125是目标检测模型时,可使用翻转变换、平移变换、尺度变换、旋转变换、缩放变换等对检测框产生影响的数据增强方法,对无标签样本进行数据增强;由于图像识别模型的标签为图像所属类别的概率分布,因此在第一模型125是图像识别模型时,可使用修剪、颜色变换、噪声扰动、内核过滤等对图像类别判定产生影响的数据增强方法,对无标签样本进行数据增强。应理解,针对模型类型进行不同的数据增强操作,可以使得最终获得的扩充样本增加模型的泛化性能,提高模型的鲁棒性。
需要说明的,数据增强单元121和样本数据库130也可以部署于半监督模型训练系统120之外,比如半监督模型训练系统120与预处理系统建立连接,数据增强单元121和样本数据库130部署于该预处理系统中,通过预处理系统对样本数据库130进行维护,以及对无标签样本集132进行数据增强操作,本申请不作具体限定。
推理单元122用于将第一无标签样本输入第一模型125生成第一无标签样本的第一伪标签,将第一扩充样本输入第一模型125生成第一扩充样本的第二伪标签。其中,第一模型125是使用有标签样本进行训练后获得的AI模型,需要说明的,使用有标签样本集对第一模型进行训练时,可以控制训练的轮数,使得第一模型具备一定的检测能力,防止后续半监督训练过程中出现第一模型125和第二模型126过拟合的现象。
具体实现中,第一模型可以是目标检测模型或者图像识别模型。目标检测模型可以是一阶段统一实时目标检测(You Only Look Once:Unified,Real-Time Object Detection,Yolo)模型、单镜头多盒检测器(Single Shot multi box Detector,SSD)模型、区域卷积神经网络(Region Convolutional Neural Network,RCNN)模型或快速区域卷积神经网络(Fast Region Convolutional Neural Network,Fast-RCNN)模型等,本申请不作具体限定。
可选地,第一模型125为目标检测模型时,目标检测模型的输出结果可能是多个目标检测框,推理单元122可以通过非极大抑制(non maximum suppression,NMS)方法,选择多个目标检测框中精度最高的检测框作为第一伪标签或者第二伪标签,从而增加第一无标签样本的第一伪标签和扩充样本的第二伪标签的精度。
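示例性地,非极大抑制的处理过程可以用如下纯numpy的示意代码表示(交并比阈值等参数为本示例的假设取值):

```python
import numpy as np

def nms(boxes: np.ndarray, scores: np.ndarray, iou_thr: float = 0.5) -> list:
    """boxes 为 (N, 4) 的 [x1, y1, x2, y2] 检测框,scores 为置信度,返回保留框的下标。"""
    order = scores.argsort()[::-1]          # 按置信度从高到低排序
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(int(i))
        if order.size == 1:
            break
        rest = order[1:]
        # 计算置信度最高的框与其余候选框之间的交并比
        x1 = np.maximum(boxes[i, 0], boxes[rest, 0])
        y1 = np.maximum(boxes[i, 1], boxes[rest, 1])
        x2 = np.minimum(boxes[i, 2], boxes[rest, 2])
        y2 = np.minimum(boxes[i, 3], boxes[rest, 3])
        inter = np.clip(x2 - x1, 0, None) * np.clip(y2 - y1, 0, None)
        area_i = (boxes[i, 2] - boxes[i, 0]) * (boxes[i, 3] - boxes[i, 1])
        area_r = (boxes[rest, 2] - boxes[rest, 0]) * (boxes[rest, 3] - boxes[rest, 1])
        iou = inter / (area_i + area_r - inter + 1e-9)
        order = rest[iou <= iou_thr]        # 抑制与保留框重叠过高的候选框
    return keep
```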
匹配单元123用于将第一无标签样本的第一伪标签和第一扩充样本的第二伪标签进行匹配,获得第一无标签样本的第三伪标签。
可选地,匹配单元123可以根据第一伪标签和第二伪标签,获得第一伪标签和第二伪标签之间的匹配度,在匹配度高于阈值的情况下,将第一伪标签和第二伪标签进行融合,获得第三伪标签。
可选地,若第一模型125为目标检测模型,匹配单元123在将第一伪标签和第二伪标签进行匹配时,可以对第二伪标签进行数据增强的逆操作,获得第四伪标签,然后将第四伪标签与第一伪标签进行匹配,获得第一伪标签和第四伪标签之间的匹配结果,根据上述匹配结果确定上述匹配度。其中,逆操作指的是与数据增强单元121执行的数据增强方法相反的操作,比如数据增强单元121对第一无标签样本进行了水平向上翻转操作获得第一扩充样本,那么匹配单元123此时可以对第一扩充样本的第二伪标签对应的标准框进行水平向下翻转操作获得第四伪标签,再比如数据增强单元121对第一无标签样本进行了向右旋转90°操作获得第一扩充样本,那么匹配单元123可以对第一扩充样本的第二伪标签对应的标准框进行向左旋转90°操作获得第四伪标签,以此类推,这里不一一举例说明。
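示例性地,对检测框执行翻转类数据增强的逆操作可以用如下示意代码表示(检测框格式 (x1, y1, x2, y2) 与函数名均为本示例的假设):

```python
def invert_flip_box(box, img_w, img_h, mode):
    """翻转的逆操作即再次翻转:把扩充样本上的检测框映射回原始无标签样本的坐标系。"""
    x1, y1, x2, y2 = box
    if mode == "hflip":   # 左右翻转的逆操作:x 坐标关于图像宽度对称
        return (img_w - x2, y1, img_w - x1, y2)
    if mode == "vflip":   # 上下翻转的逆操作:y 坐标关于图像高度对称
        return (x1, img_h - y2, x2, img_h - y1)
    raise ValueError(f"未知的翻转方式: {mode}")
```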
可以理解的,目标检测模型的伪标签是图像中目标的检测框,因此通过数据增强方法获得的第一扩充样本,目标的位置实际已发生改变,检测框需要对其进行逆操作,使得第一伪标签和第二伪标签所框选的目标是同一个位置的目标,然后再将其进行匹配,可以筛选出标注目标不准确的第一伪标签,从而避免对第二模型126进行半监督训练时使用到错误或者精度低的伪标签,从而提高训练效率、提高最终获得的第一模型的精度。
具体实现中,匹配单元123可以将第一伪标签对应的检测框与第四伪标签对应的检测框进行匹配,获得上述匹配度,这里的匹配度可以是两个检测框之间的交并比(intersection over union,IOU)。
具体实现中,在匹配度大于阈值时,匹配单元可以将第一伪标签和第二伪标签进行融合获得第三伪标签,具体地,可以将上述第四伪标签对应的检测框与第一伪标签对应的检测框进行多值平均处理,获得第三伪标签。举例来说,第一伪标签对应的检测框坐标为(x1, y1, x2, y2),第四伪标签对应的检测框坐标为(x1', y1', x2', y2'),那么融合后的第三伪标签yu的公式可以如下:
yu = ((x1+x1')/2, (y1+y1')/2, (x2+x2')/2, (y2+y2')/2)    (1)
应理解,上述公式(1)用于辅助说明,还可以通过其他方式对第一伪标签和第四伪标签对应的检测框进行融合处理,比如加权平均,本申请不对此进行具体限定。
应理解,将第一伪标签对应的检测框与第四伪标签对应的检测框进行融合,两个检测框是使用不同方法确定的目标检测框,因此将二者融合可以进一步提高最终获得的第三伪标签的精度,从而避免对第二模型126进行半监督训练时使用到错误或者精度低的伪标签,从而提高训练效率、提高最终获得的第一模型的精度。
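示例性地,上述"计算匹配度(IOU)并按公式(1)逐坐标平均融合"的过程可以用如下示意代码表示(阈值0.5为本示例的假设取值):

```python
def box_iou(a, b):
    """两个检测框 (x1, y1, x2, y2) 之间的交并比 IOU,作为匹配度。"""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / (union + 1e-9)

def fuse_pseudo_labels(box_1, box_4, iou_thr=0.5):
    """匹配度高于阈值时按公式(1)逐坐标取平均得到第三伪标签,否则返回 None 表示删除。"""
    if box_iou(box_1, box_4) <= iou_thr:
        return None
    return tuple((c1 + c4) / 2.0 for c1, c4 in zip(box_1, box_4))
```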
可选地,若第一模型125为图像识别模型,可以将第一伪标签和第二伪标签进行匹配,获得第一伪标签和第二伪标签之间的匹配结果,根据该匹配结果确定二者之间的匹配度。具体实现中,匹配单元123可以将第一伪标签对应的概率分布与第二伪标签对应的概率分布进行匹配,确定二者之间的匹配度,这里的匹配度可以是两个概率分布之间的相似度或者距离,本申请不作具体限定。匹配单元123将匹配度大于阈值的第一伪标签和第二伪标签进行融合时,可以将第一伪标签的概率分布与第二伪标签的概率分布进行均值处理,比如平均数、加权平均等等,本申请不作具体限定。
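示例性地,图像识别场景下概率分布之间的匹配与融合可以用如下示意代码表示(此处以余弦相似度作为匹配度、以均值作为融合方式,均为本示例的假设选择):

```python
import numpy as np

def match_and_fuse_probs(p1: np.ndarray, p2: np.ndarray, thr: float = 0.9):
    """p1、p2 为两个类别概率分布;匹配度高于阈值时取均值融合,否则返回 None 表示删除。"""
    sim = float(np.dot(p1, p2) / (np.linalg.norm(p1) * np.linalg.norm(p2) + 1e-9))
    if sim <= thr:
        return None
    fused = (p1 + p2) / 2.0
    return fused / fused.sum()   # 重新归一化,保持概率分布性质
```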
可选地,推理单元122也可以将第二无标签样本输入第一模型,获得第二无标签样本的第五伪标签,然后将第二扩充样本输入第一模型,获得第二扩充样本的第六伪标签,这里,第二扩充样本为对第二无标签样本进行数据增强后获得的样本。匹配单元123可以根据第五伪标签和第六伪标签,获得第五伪标签和第六伪标签之间的匹配度,在匹配度不高于上述阈值的情况下,删除第五伪标签和第六伪标签。其中,上述第五伪标签和第六伪标签之间匹配度的确定方式可以参考前述内容中第一伪标签和第二伪标签之间匹配度的确定方式,这里不重复展开赘述。
需要说明的,第一伪标签和第二伪标签之间的匹配度如果不高于阈值,也可以将第一伪标签和第二伪标签删除,同理,如果第五伪标签和第六伪标签之间的匹配度高于阈值,也可以将第五伪标签和第六伪标签进行融合,融合方式可参考前述内容中关于第一伪标签和第二伪标签融合获得第三伪标签的描述,这里不重复赘述。简单来说,无标签样本集132和扩充样本集133进行匹配,呈对应关系的无标签样本的伪标签和扩充样本的伪标签会进行匹配获得相应的匹配度,若匹配度高于阈值则将二者的伪标签进行融合,匹配度不高于阈值则将二者的伪标签都进行删除。
可以理解的,通过将无标签样本集132的伪标签和扩充样本集133的伪标签进行匹配的方式,可以过滤出准确度较低的伪标签,从而避免对第二模型126进行半监督训练时使用到错误或者精度低的伪标签,从而提高训练效率、提高最终获得的第一模型的精度。
训练单元124用于使用上述第一无标签样本和第三伪标签对第二模型126进行训练,其中,第二模型126的权重参数与第一模型125相同,简单来说,使用有标签样本集131对机器学习模型进行训练,获得第一模型125,然后将第一模型125的权重参数拷贝给第二模型126,然后使用第一无标签样本和第三伪标签对第二模型126进行训练。
可选地,第一模型125的模型结构可以与第二模型126的模型结构相同。在该应用场景下,在训练单元124对第二模型126进行训练时,先将第一模型125的权重参数拷贝给第二模型126,然后使用第一无标签样本、第三伪标签以及上述训练第一模型125时使用的有标签样本对第二模型126进行迭代训练,根据每次迭代训练获得的第二模型的权重参数对第一模型的权重参数进行迭代更新,获得目标模型127。
具体地,使用第一无标签样本、第三伪标签以及上述有标签样本对第二模型126进行第一轮训练,获得第一轮更新后第二模型126的权重参数,然后将其发送给第一模型125,对第一模型125的权重参数进行更新,获得新的第一模型125,再将上述第一无标签样本和第一扩充样本输入新的第一模型125,预测出新的第一伪标签和新的第二伪标签,再将匹配度高于阈值的新的第一伪标签和新的第二伪标签进行融合获得新的第三伪标签,再使用第一无标签样本和新的第三伪标签以及有标签样本继续对第二模型126训练至收敛,获得第二轮更新后的权重参数,然后再将其更新至第一模型125,以此类推,这里不一一展开赘述。
需要说明的,在接下来多轮迭代训练过程中,比如第二轮训练时,可以将上述第二无标签样本和第二扩充样本输入新的第一模型125,获得新的第五伪标签和新的第六伪标签,如果第五伪标签和第六伪标签之间的匹配度仍然不高于阈值,可以继续将第五伪标签和第六伪标签删除;如果第五伪标签和第六伪标签之间的匹配度高于阈值,此时可以将第五伪标签和第六伪标签进行融合获得第七伪标签,然后使用第七伪标签、第二无标签样本以及有标签样本对第二模型进行训练至收敛,获得第二轮更新后第二模型的权重参数,再将其更新至第一模型125,以此类推。
同理,如果接下来多轮迭代训练过程中,比如第二轮训练时,新的第一伪标签和新的第二伪标签之间的匹配度不高于阈值,也可以将新的第一伪标签和新的第二伪标签删除,这里不重复赘述。
可以理解的,通过将第二模型的权重参数对第一模型进行更新,然后多轮迭代训练的方式,使得第一模型125推理出的伪标签精度越来越高,直至第一模型125的预测精度达到用户所需的标准,从而获得目标模型127,并将其发送至推理设备110。
具体实现中,第二模型126每轮训练获得的权重参数可以全部更新至第一模型125,也可以将每轮训练获得的部分权重参数更新至第一模型125,使得第一模型得到缓慢、稳定的权重更新,这样训练获得的第一模型更具鲁棒性,模型性能更佳。具体实现中,可通过指数滑动平均(exponential moving average,EMA)方法将第二模型126(即学生模型)的权重更新给第一模型125,举例来说,假设EMA=0.99,那么每轮训练获得的权重参数的1%将被更新入第一模型125。应理解,上述举例用于说明,本申请不作具体限定。
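示例性地,EMA方式的权重更新可以用如下示意代码表示(以权重字典表示模型参数,变量名为本示例的假设):

```python
def ema_update(teacher_weights: dict, student_weights: dict, decay: float = 0.99) -> dict:
    """w_first = decay * w_first + (1 - decay) * w_second:
    decay 取 0.99 时,每轮仅有约 1% 的第二模型权重被更新入第一模型。"""
    return {name: decay * teacher_weights[name] + (1.0 - decay) * student_weights[name]
            for name in teacher_weights}
```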
可以理解的,第二模型126与第一模型125的模型结构相同,可以使用少量的、获取困难的有标签样本和大量的、容易获取的无标签样本对机器学习模型进行训练,获得的目标模型127不仅鲁棒性好,而且模型性能好,训练效率高。
可选地,第一模型125的模型结构也可以包括第二模型126的模型结构,也就是说,第二模型126是小模型,第一模型125是大模型,比如第二模型126是第一模型125的一个子模型,同样的,第二模型126将每轮训练获得的更新后的权重同步至第一模型125,新的第一模型125再预测新的伪标签对第二模型126进行训练,以此类推,稳步更新第二模型126和第一模型125,这样最终获得的第二模型不仅结构复杂度低,而且模型性能与第一模型125趋于接近,甚至可以比第一模型125性能更好,从而达到模型压缩的目的。
在本申请实施例中,使用有标签样本、第一无标签样本和第三伪标签对第二模型126进行训练时,可以将有标签样本集131中的输入样本输入第二模型获得第一输出值,将第一无标签样本输入第二模型获得第二输出值,根据第一输出值和第二输出值确定第二模型的损失值,然后根据损失值对第二模型进行反向传播直至收敛,获得训练好的第二模型,然后将训练好的第二模型的模型参数同步至第一模型125中,再进行下一轮的模型训练。其中,上述损失值L包括有标签损失L1和伪标签损失(也称无标签损失)L2,有标签损失L1是根据第一输出值和真实标签之间的差距获得的,伪标签损失L2是根据第二输出值和第三伪标签之间的差距获得的。
具体实现中,可以通过系数加权的方式对有标签损失L1和无标签损失L2在损失值L中的比重进行调控,比如损失值L=L1+λL2,其中,λ越大,无标签损失L2在损失值L中的占比越大,第一无标签样本对第二模型126和第一模型125的模型性能影响越大,通过对λ的值进行调整,可以人为干预模型训练的学习方向。
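示例性地,损失值L=L1+λL2的计算可以用如下示意代码表示(此处以均方误差作为两项损失的具体形式,仅为本示例的假设):

```python
def semi_supervised_loss(first_out, true_label, second_out, pseudo_label, lam=1.0):
    """损失值 L = L1 + λ·L2:L1 由第一输出值与真实标签的差值得到,
    L2 由第二输出值与第三伪标签的差值得到,λ 越大伪标签损失的影响越大。"""
    l1 = sum((o - t) ** 2 for o, t in zip(first_out, true_label)) / len(first_out)
    l2 = sum((o - p) ** 2 for o, p in zip(second_out, pseudo_label)) / len(second_out)
    return l1 + lam * l2
```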
需要说明的,本申请的半监督模型训练系统120也可以打包成一个软件模块,对现有的一些模型训练系统进行软件升级,使其能够拥有对伪标签过滤、融合的功能,使得升级后的模型训练系统可以有更好的半监督训练性能。
参考前述内容可知,在公有云场景下,上述数据增强单元121、推理单元122、匹配单元123以及训练单元124可以打包为一个配置模块,作为公有云模型训练服务中的一个小的配置功能,如果公有云用户购买该功能,即可为用户提供相应的权限。在非公有云场景下,上述数据增强单元121、推理单元122、匹配单元123以及训练单元124可以打包为一个微服务或者软件包,用户购买本申请提供的伪标签过滤、融合功能之后,可以向用户提供相应的权限的许可(license),不同权限可设置不同的收费程度。本申请不作具体限定。
需要说明的,图2展示了模型训练系统的一种示例性划分方式,具体实现中,半监督模型训练系统120可包括更多或者更少的单元模块,比如样本数据库130和数据增强单元121可以部署于半监督模型训练系统120之外,比如匹配单元123可进一步划分为阈值判断单元、融合单元和删除单元,其中,阈值判断单元用于确定第一伪标签和第二伪标签之间的匹配度,在匹配度高于阈值的情况下,通过融合单元将第一伪标签和第二伪标签进行融合获得第三伪标签,在匹配度不高于阈值的情况下,通过删除单元将第一伪标签和第二伪标签删除,这里不重复展开赘述。
综上可知,本申请提供的模型训练系统,通过对无标签样本集进行数据增强获得无标签样本集的扩充样本集,然后将无标签样本集和扩充样本集输入第一模型推理获得无标签样本集的多个第一伪标签以及扩充样本集的多个第二伪标签,然后将匹配度高于阈值的第一伪标签和第二伪标签进行融合获得第三伪标签,将匹配度低于或等于阈值的第一伪标签进行过滤,从而提高未标注样本集合的伪标签的质量,使得后续使用第三伪标签对第二模型进行半监督训练时,模型的训练效率和性能得以提升,进而提升最终获得的第一模型的训练效率以及模型性能。
图3是本申请提供的一种半监督模型训练方法的步骤流程示意图,该方法可应用于图2所示的半监督模型训练系统120中,如图3所示,该方法可包括以下步骤:
步骤S310:半监督模型训练系统120将第一无标签样本输入第一模型,获得第一无标签样本的第一伪标签。该步骤可以由图2中的推理单元122实现。
其中,第一模型是使用有标签样本集对机器学习模型进行训练后获得的,机器学习模型可以是目标检测模型或者图像识别模型。其中,目标检测模型和图像识别模型的描述可参考图2实施例中的详细描述,这里不重复赘述。
可选地,第一模型为目标检测模型时,目标检测模型的输出结果可能是多个目标检测框,半监督模型训练系统120可以通过NMS方法,选择多个目标检测框中精度最高的检测框作为第一伪标签或者第二伪标签,从而增加第一无标签样本的第一伪标签和第一扩充样本的第二伪标签的精度。
步骤S320:将第一扩充样本输入第一模型,获得第一扩充样本的第二伪标签。该步骤可以由图2中的推理单元122实现。
具体实现中,在步骤S320之前,可以对第一无标签样本进行数据增强,获得第一无标签样本的第一扩充样本。该步骤可以由图2中的数据增强单元121实现。
其中,半监督模型训练系统120的描述可参考图2实施例,这里不重复赘述,第一无标签样本的描述可参考术语解释,这里也不重复赘述。
在本申请实施例中,数据增强方法可包括但不限于翻转变换、平移变换、尺度变换、旋转变换/反射变换、缩放变换、修剪、颜色变换、噪声扰动、内核过滤中的一种或者多种,每种数据增强方法的具体描述可参考图2实施例中的描述,这里不重复赘述。
需要说明的,由于目标检测模型的标签通常为检测框(bounding box),因此在第一模型是目标检测模型时,可使用翻转变换、平移变换、尺度变换、旋转变换、缩放变换等对检测框产生影响的数据增强方法,对第一无标签样本进行数据增强;由于图像识别模型的标签为图像所属类别的概率分布,因此在第一模型是图像识别模型时,可使用修剪、颜色变换、噪声扰动、内核过滤等对图像类别判定产生影响的数据增强方法,对第一无标签样本进行数据增强。应理解,针对模型类型进行不同的数据增强操作,可以使得最终获得的扩充样本增加模型的泛化性能,提高模型的鲁棒性。
需要说明的,上述数据增强步骤也可以由其他数据预处理系统实现,换句话说,半监督模型训练系统120也可以向其他数据预处理系统获取上述第一扩充样本和第一无标签样本,也可以自己对第一无标签样本进行数据增强获得相应的第一扩充样本,具体可根据实际业务处理情况决定,本申请不作具体限定。
步骤S330:根据第一伪标签和第二伪标签获得第一无标签样本的第三伪标签。该步骤可以由图2中的匹配单元123实现。
具体实现中,半监督模型训练系统120可以根据第一伪标签和第二伪标签,获得第一伪标签和第二伪标签之间的匹配度,在匹配度高于阈值的情况下,将第一伪标签和第二伪标签进行融合,获得第三伪标签。
可选地,若第一模型为目标检测模型,半监督模型训练系统120在将第一伪标签和第二伪标签进行匹配时,可以对第二伪标签进行数据增强的逆操作,获得第四伪标签,然后将第四伪标签与第一伪标签进行匹配,获得上述匹配度。其中,逆操作指的是与数据增强单元121执行的数据增强方法相反的操作,比如数据增强时对第一无标签样本进行了水平向上翻转操作获得第一扩充样本,那么步骤S330此时可以对第一扩充样本的第二伪标签对应的标准框进行水平向下翻转操作获得第四伪标签,再比如数据增强时对第一无标签样本进行了向右旋转90°操作获得第一扩充样本,那么步骤S330可以对第一扩充样本的第二伪标签对应的标准框进行向左旋转90°操作获得第四伪标签,以此类推,这里不一一举例说明。
可以理解的,目标检测模型的伪标签是图像中目标的检测框,因此通过数据增强方法获得的第一扩充样本,目标的位置实际已发生改变,检测框需要对其进行逆操作,使得第一伪标签和第二伪标签所框选的目标是同一个位置的目标,然后再将其进行匹配,可以筛选出标注目标不准确的第一伪标签,从而避免对第二模型进行半监督训练时使用到错误或者精度低的伪标签,从而提高训练效率、提高最终获得的第一模型的精度。
具体实现中,匹配单元123可以将第一伪标签对应的检测框与第四伪标签对应的检测框进行匹配,获得上述匹配度,这里的匹配度可以是两个检测框之间的IOU。在匹配度大于阈值时,匹配单元可以将第一伪标签和第二伪标签进行融合获得第三伪标签,具体地,可以将上述第四伪标签对应的检测框与第一伪标签对应的检测框进行多值平均处理,获得第三伪标签。具体可参考公式(1)的相关描述,这里不重复赘述。
应理解,将第一伪标签对应的检测框与第四伪标签对应的检测框进行融合,两个检测框是使用不同方法确定的目标检测框,因此将二者融合可以进一步提高最终获得的第三伪标签的精度,从而避免对第二模型进行半监督训练时使用到错误或者精度低的伪标签,从而提高训练效率、提高最终获得的第一模型的精度。
可选地,若第一模型为图像识别模型,可以将第一伪标签和第二伪标签进行匹配,获得第一伪标签和第二伪标签之间的匹配结果,根据该匹配结果确定二者之间的匹配度。具体实现中,可以将第一伪标签对应的概率分布与第二伪标签对应的概率分布进行匹配,确定二者之间的匹配度,这里的匹配度可以是两个概率分布之间的相似度或者距离,本申请不作具体限定。将匹配度大于阈值的第一伪标签和第二伪标签进行融合时,可以将第一伪标签的概率分布与第二伪标签的概率分布进行均值处理,比如平均数、加权平均等等,本申请不作具体限定。
可选地,也可以将第二无标签样本输入第一模型,获得第二无标签样本的第五伪标签,然后将第二扩充样本输入第一模型,获得第二扩充样本的第六伪标签,这里,第二扩充样本为对第二无标签样本进行数据增强后获得的样本。根据第五伪标签和第六伪标签,获得第五伪标签和第六伪标签之间的匹配度,在匹配度不高于上述阈值的情况下,删除第五伪标签和第六伪标签。其中,上述第五伪标签和第六伪标签之间匹配度的确定方式可以参考前述内容中第一伪标签和第二伪标签之间匹配度的确定方式,这里不重复展开赘述。
需要说明的,第一伪标签和第二伪标签之间的匹配度如果不高于阈值,也可以将第一伪标签和第二伪标签删除,同理,如果第五伪标签和第六伪标签之间的匹配度高于阈值,也可以将第五伪标签和第六伪标签进行融合,融合方式可参考前述内容中关于第一伪标签和第二伪标签融合获得第三伪标签的描述,这里不重复赘述。简单来说,无标签样本集132和扩充样本集133进行匹配,呈对应关系的无标签样本的伪标签和扩充样本的伪标签会进行匹配获得相应的匹配度,若匹配度高于阈值则将二者的伪标签进行融合,匹配度不高于阈值则将二者的伪标签都进行删除。
可以理解的,通过将无标签样本集132的伪标签和扩充样本集133的伪标签进行匹配的方式,可以过滤出准确度较低的伪标签,从而避免对第二模型126进行半监督训练时使用到错误或者精度低的伪标签,从而提高训练效率、提高最终获得的第一模型的精度。
步骤S340:使用第一无标签样本和第三伪标签对第二模型进行训练,该第二模型与第一模型的权重参数相同。该步骤可以由图2中的训练单元124实现。简单来说,使用有标签样本集对机器学习模型进行训练,获得第一模型,然后将第一模型的权重参数拷贝给第二模型,然后使用第一无标签样本和第三伪标签对第二模型进行训练。
可选地,第一模型的模型结构可以与第二模型的模型结构相同,在该应用场景下,在对第二模型126进行训练时,先将第一模型125的权重参数拷贝给第二模型126,然后使用第一无标签样本、第三伪标签以及上述训练第一模型125时使用的有标签样本对第二模型126进行迭代训练,根据每次迭代训练获得的第二模型的权重参数对第一模型的权重参数进行迭代更新,获得目标模型127。
具体地,使用第一无标签样本、第三伪标签以及上述有标签样本对第二模型126进行第一轮训练,获得第一轮更新后第二模型126的权重参数,然后将其发送给第一模型125,对第一模型125的权重参数进行更新,获得新的第一模型125,再将上述第一无标签样本和第一扩充样本输入新的第一模型125,预测出新的第一伪标签和新的第二伪标签,再将匹配度高于阈值的新的第一伪标签和新的第二伪标签进行融合获得新的第三伪标签,再使用第一无标签样本和新的第三伪标签以及有标签样本继续对第二模型126训练至收敛,获得第二轮更新后的权重参数,然后再将其更新至第一模型125,以此类推,这里不一一展开赘述。
需要说明的,在接下来多轮迭代训练过程中,比如第二轮训练时,可以将上述第二无标签样本和第二扩充样本输入新的第一模型125,获得新的第五伪标签和新的第六伪标签,如果第五伪标签和第六伪标签之间的匹配度仍然不高于阈值,可以继续将第五伪标签和第六伪标签删除;如果第五伪标签和第六伪标签之间的匹配度高于阈值,此时可以将第五伪标签和第六伪标签进行融合获得第七伪标签,然后使用第七伪标签、第二无标签样本以及有标签样本对第二模型进行训练至收敛,获得第二轮更新后第二模型的权重参数,再将其更新至第一模型125,以此类推。
同理,如果接下来多轮迭代训练过程中,比如第二轮训练时,新的第一伪标签和新的第二伪标签之间的匹配度不高于阈值,也可以将新的第一伪标签和新的第二伪标签删除,这里不重复赘述。
可以理解的,通过将第二模型的权重参数对第一模型进行更新,然后多轮迭代训练的方式,使得第一模型125推理出的伪标签精度越来越高,直至第一模型125的预测精度达到用户所需的标准,从而获得目标模型127,并将其发送至推理设备110。
具体实现中,第二模型每轮训练获得的权重参数可以全部更新至第一模型,也可以将每轮训练获得的部分权重参数更新至第一模型,使得第一模型得到缓慢、稳定的权重更新,这样训练获得的第一模型更具鲁棒性,模型性能更佳。具体实现中,可通过EMA方法将第二模型(即学生模型)的权重更新给第一模型,举例来说,假设EMA=0.99,那么每轮训练获得的权重参数的1%将被更新入第一模型。应理解,上述举例用于说明,本申请不作具体限定。
可以理解的,第二模型与第一模型的模型结构相同,可以使用少量的、获取困难的有标签样本和大量的、容易获取的无标签样本对机器学习模型进行训练,获得的目标模型127不仅鲁棒性好,而且模型性能好,训练效率高。
可选地,第一模型的模型结构也可以包括第二模型的模型结构,也就是说,第二模型是第一模型的一个子模型,同样的,第二模型将每轮训练获得的更新后的权重同步至第一模型,新的第一模型再预测新的伪标签对第二模型进行训练,以此类推,稳步更新第二模型和第一模型,这样最终获得的第二模型不仅结构复杂度低,而且模型性能与第一模型趋于接近,甚至可以比第一模型性能更好,从而达到模型压缩的目的。
在本申请实施例中,使用有标签样本、第一无标签样本和第三伪标签对第二模型进行训练时,可以将有标签样本集131中的输入样本输入第二模型获得第一输出值,将第一无标签样本输入第二模型获得第二输出值,根据第一输出值和第二输出值确定第二模型的损失值,然后根据损失值对第二模型进行反向传播直至收敛,获得训练好的第二模型,然后将训练好的第二模型的模型参数同步至第一模型中,再进行下一轮的模型训练。其中,上述损失值L包括有标签损失L1和伪标签损失(也称无标签损失)L2,有标签损失L1是根据第一输出值和真实标签之间的差距获得的,伪标签损失L2是根据第二输出值和第三伪标签之间的差距获得的。
具体实现中,可以通过系数加权的方式对有标签损失L1和无标签损失L2在损失值L中的比重进行调控,比如损失值L=L1+λL2,其中,λ越大,无标签损失L2在损失值L中的占比越大,第一无标签样本对第二模型和第一模型的模型性能影响越大,通过对λ的值进行调整,可以人为干预模型训练的学习方向。
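示例性地,将上述推理伪标签、匹配融合、训练第二模型、更新第一模型的流程串联起来,一轮迭代可以用如下伪代码风格的示意表示(其中augment、infer、invert、train_one_round、get_weights、set_weights为本示例假设的占位函数,fuse_pseudo_labels、ema_update可沿用前文示意,均并非本申请限定的接口):

```python
def semi_supervised_round(first_model, second_model, labeled_set, unlabeled_set,
                          iou_thr=0.5, ema_decay=0.99):
    """一轮半监督训练的示意流程:推理伪标签 -> 匹配/融合/过滤 -> 训练第二模型 -> 更新第一模型。"""
    pseudo_set = []
    for sample in unlabeled_set:
        aug_sample, aug_op = augment(sample)          # 数据增强得到扩充样本
        label_1 = infer(first_model, sample)          # 第一伪标签
        label_2 = infer(first_model, aug_sample)      # 第二伪标签
        label_4 = invert(label_2, aug_op)             # 数据增强的逆操作得到第四伪标签
        fused = fuse_pseudo_labels(label_1, label_4, iou_thr)
        if fused is not None:                         # 匹配度高于阈值才保留第三伪标签
            pseudo_set.append((sample, fused))
    second_model = train_one_round(second_model, labeled_set, pseudo_set)  # 有标签损失 + 伪标签损失
    new_weights = ema_update(get_weights(first_model), get_weights(second_model), ema_decay)
    set_weights(first_model, new_weights)             # 将第二模型的权重按 EMA 更新入第一模型
    return first_model, second_model
```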
下面结合附图4和图5,对图3所示的半监督模型训练方法进行举例说明,图4是本申请提供的半监督模型训练方法在一应用场景下的步骤流程示意图,图5是本申请提供的半监督模型训练方法中第一伪标签和第二伪标签的融合流程示意图,该应用场景中,第一模型为目标检测模型,第二模型为网络结构与第一模型相同的目标检测模型。
如图4所示,该应用场景下的半监督模型训练方法可包括以下步骤。
步骤1.获取训练样本集,该训练样本集可包括图2所示的有标签样本集131和无标签样本集132。
步骤2.判断是否为有标签样本,具体地,将训练样本集中的每个样本进行判断分类,有标签样本执行步骤3,无标签样本执行步骤4。
步骤3.使用有标签样本训练第一模型。该第一模型即为图2和图3实施例中的第一模型125。步骤3执行完毕后执行步骤5或步骤6。
需要说明的,使用有标签样本集对机器学习模型进行训练时,可以控制训练的轮数,使得第一模型具备一定的检测能力,防止后续半监督训练过程中出现第一模型和第二模型过拟合的现象。
步骤4.对无标签样本进行数据增强,获得扩充样本。该步骤的具体描述可参考图3实施例中的步骤S320,该步骤可以由图2实施例中的数据增强单元121实现。步骤4执行完毕后执行步骤7。
具体实现中,图4所示的应用场景使用的数据增强方法为:将第一无标签样本垂直向右翻转。
需要说明的,步骤3和步骤4可以并行或者串行处理,具体可根据半监督模型训练系统120所部署计算设备的处理能力决定,本申请不作具体限定。
步骤5.将第一模型的模型参数拷贝至第二模型。第二模型即为图2和图3实施例中描述的第二模型126。
步骤6.将无标签样本输入第一模型获得第一伪标签。该步骤可由图2实施例中的推理单元122实现,具体可参考图3实施例步骤S310的描述,这里不重复赘述。
步骤7.将扩充样本输入第一模型获得第二伪标签。该步骤可由图2实施例中的推理单元122实现,具体可参考图3实施例步骤S320的描述,这里不重复赘述。
需要说明的步骤6可以与步骤7并行或串行处理,本申请不作具体限定。
步骤8.对第二伪标签进行数据增强逆操作,获得第四伪标签。
具体实现中,数据增强的逆操作指的是对步骤4中的数据增强对应的逆操作,步骤4将第一无标签样本进行垂直向右的翻转,那么步骤8可以对第二伪标签进行垂直向左的翻转。
步骤9.将第一伪标签和第四伪标签进行匹配,获得匹配度。
具体实现中,步骤9可以将第一伪标签对应的检测框与第四伪标签对应的检测框进行匹配,这里的匹配度可以是两个检测框之间的IOU。
步骤10.判断匹配度是否高于阈值,在匹配度高于阈值的情况下执行步骤11,在匹配度不高于阈值的情况下执行步骤14。
具体地,假设阈值为0.45,那么步骤10可以将IOU大于0.45的第一伪标签和第四伪标签进行融合,即执行步骤11;对于IOU不大于0.45的第一伪标签和第四伪标签,则执行步骤14,将其删除。
步骤11.将匹配度高于阈值的第一伪标签和第四伪标签融合为第三伪标签。上述步骤8~步骤11可以由图2实施例中的匹配单元123实现,具体可参考图3实施例中的步骤S330,这里不重复赘述。
示例性地,如图5所示,图5是本申请提供的半监督模型训练方法中第一伪标签和第二伪标签融合为第三伪标签的步骤流程示意图。对应图4实施例中的步骤4、步骤6~步骤11,由图5可知,步骤4对第一无标签样本进行扩充后,步骤6将第一无标签样本输入第一模型125获得第一伪标签,步骤7将扩充样本输入第一模型125获得第二伪标签,其中,第一伪标签和第二伪标签的精度较低,两个伪标签对应的检测框都没有完整的标注出目标(也就是车辆)。
可以理解的,图5清晰地显示了第二伪标签所标注的样本与第一伪标签所标注的样本是不同的样本,第一伪标签所标注的未标注样本中货车在上方车道,第二伪标签所标注的扩充样本中货车在下方车道,因此不能直接执行步骤10将二者进行匹配。通过步骤8将第二伪标签进行数据增强的逆操作,即进行与步骤4的数据增强相反的翻转操作之后,获得第四伪标签,此时第一伪标签和第四伪标签所标注的样本是同一个样本,即货车在上的样本。这样,执行步骤10将第一伪标签和第四伪标签进行匹配,二者的匹配度可以良好的指示第一伪标签的精度,将匹配度高于阈值的第一伪标签和第四伪标签融合后,还可以进一步提高第一伪标签的精度,如图5所示,步骤11获得的第三伪标签相比第一伪标签和第二伪标签的精度更高,准确度更高,使用第三伪标签训练第二模型不仅可以提高训练效率,提高第二模型的性能,在第二模型的权重参数引入第一模型后,还可以提高第一模型的性能,整个过程不需要人工辅助校正样本,提高整个训练过程的训练效率,降低人力成本。
需要说明的,如果步骤10对第一伪标签和第四伪标签之间的匹配度判定为不高于阈值,那么可执行步骤14将第一伪标签、第二伪标签以及第四伪标签删除,本轮训练过程中,该第一无标签样本可以不参与训练,下一轮训练过程中,如果新的第一伪标签和第四伪标签之间的匹配度高于阈值,那么该第一无标签样本可以参与下一轮的训练,以此类推,这里不展开赘述。
步骤12.使用第三伪标签、无标签样本和有标签样本对第二模型进行训练。
步骤13.将训练好的第二模型的权重参数引入第一模型。上述步骤12和13可以由图1实施例中的训练单元124实现,具体可参考图3实施例中的步骤S340,这里不重复赘述。
可选地,可以根据EMA方法将部分训练好的第二模型的权重参数引入第一模型,获得新的第一模型,继续执行步骤6~步骤14,以此类推,直至第一模型收敛。
步骤14.删除第一伪标签和第二伪标签。
需要说明的,在重复执行步骤6~步骤14时,即使上一轮第一伪标签和第二伪标签进行了删除,下一轮中新的第一伪标签和第二伪标签可继续进行匹配、融合或删除的过程,使得第一模型推理出的第一伪标签和第二伪标签的精度越来越高,从而使得模型训练效果越来越好。
综上可知,本申请提供的半监督模型训练方法,通过对无标签样本进行数据增强获得无标签样本的扩充样本,然后将无标签样本和扩充样本输入第一模型推理获得无标签样本的第一伪标签以及扩充样本的第二伪标签,然后将匹配度高于阈值的第一伪标签和第二伪标签进行融合获得第三伪标签,将匹配度低于或等于阈值的第一伪标签进行过滤,从而提高未标注样本的伪标签的质量,使得后续使用第三伪标签对第二模型进行半监督训练时,模型的训练效率和性能得以提升,进而提升最终获得的第一模型的训练效率以及模型性能。
图6是本申请提供的一种计算设备的结构示意图,该计算设备600是图1至图5实施例中的半监督模型训练系统120。
进一步地,计算设备600包括处理器601、存储单元602、存储介质603和通信接口604,其中,处理器601、存储单元602、存储介质603和通信接口604通过总线605进行通信,也可以通过无线传输等其他手段实现通信。
处理器601可以由至少一个通用处理器构成,例如CPU、NPU或者CPU和硬件芯片的组合。上述硬件芯片可以是专用集成电路(Application-Specific Integrated Circuit,ASIC)、可编程逻辑器件(Programmable Logic Device,PLD)或其组合。上述PLD可以是复杂可编程逻辑器件(Complex Programmable Logic Device,CPLD)、现场可编程逻辑门阵列(Field-Programmable Gate Array,FPGA)、通用阵列逻辑(Generic Array Logic,GAL)或其任意组合。处理器601执行各种类型的数字存储指令,例如存储在存储单元602中的软件或者固件程序,它能使计算设备600提供多种服务。
具体实现中,作为一种实施例,处理器601可以包括一个或多个CPU,例如图6中所示的CPU0和CPU1。
在具体实现中,作为一种实施例,计算设备600也可以包括多个处理器,例如图6中所示的处理器601和处理器606。这些处理器中的每一个可以是一个单核处理器(single-CPU),也可以是一个多核处理器(multi-CPU)。这里的处理器指一个或多个设备、电路、和/或用于处理数据(例如计算机程序指令)的处理核。
存储单元602用于存储程序代码,并由处理器601来控制执行,以执行上述图1-图5中任一实施例中半监督模型训练系统120的处理步骤。程序代码中包括一个或多个软件单元,上述一个或多个软件单元可以是图2实施例中的推理单元、匹配单元和训练单元,其中,推理单元用于将第一无标签样本输入第一模型获得第一无标签样本的第一伪标签,将第一扩充样本输入第一模型获得第一扩充样本的第二伪标签;匹配单元用于将第一伪标签和第二伪标签进行匹配,获得无标签样本的第三伪标签;训练单元用于使用第三伪标签和第一无标签样本对第二模型进行训练。其中,推理单元用于执行图3中的步骤S310~步骤S320以及图4和图5中的步骤6和步骤7,匹配单元用于执行图3中的步骤S330以及图4和图5中的步骤8~步骤11,训练单元用于执行图3中的步骤S340以及图4中的步骤12和步骤13。具体实现方式参考图1~图5实施例,此处不再赘述。
存储单元602可以包括只读存储器和随机存取存储器,并向处理器601提供指令和数据。存储单元602还可以包括非易失性随机存取存储器。存储单元602可以是易失性存储器或非易失性存储器,或者包括易失性和非易失性存储器两者。其中,非易失性存储器可以是只读存储器(read-only memory,ROM)、可编程只读存储器(programmable ROM,PROM)、可擦除可编程只读存储器(erasable PROM,EPROM)、电可擦除可编程只读存储器(electrically EPROM,EEPROM)或闪存。易失性存储器可以是随机存取存储器(random access memory,RAM),其用作外部高速缓存。通过示例性但不是限制性说明,许多形式的RAM可用,例如静态随机存取存储器(static RAM,SRAM)、动态随机存取存储器(DRAM)、同步动态随机存取存储器(synchronous DRAM,SDRAM)、双倍数据速率同步动态随机存取存储器(double data rate SDRAM,DDR SDRAM)、增强型同步动态随机存取存储器(enhanced SDRAM,ESDRAM)、同步连接动态随机存取存储器(synchlink DRAM,SLDRAM)和直接内存总线随机存取存储器(direct rambus RAM,DR RAM)。存储单元602还可以是硬盘(hard disk)、U盘(universal serial bus,USB)、闪存(flash)、SD卡(secure digital memory Card,SD card)、记忆棒等等,硬盘可以是硬盘驱动器(hard disk drive,HDD)、固态硬盘(solid state disk,SSD)、机械硬盘(mechanical hard disk,HDD)等,本申请不作具体限定。
存储介质603是存储数据的载体,比如硬盘(hard disk)、U盘(universal serial bus,USB)、闪存(flash)、SD卡(secure digital memory Card,SD card)、记忆棒等等,硬盘可以是硬盘驱动器(hard disk drive,HDD)、固态硬盘(solid state disk,SSD)、机械硬盘(mechanical hard disk,HDD)等,本申请不作具体限定。
通信接口604可以为有线接口或无线接口,用于与其他服务器或单元进行通信,其中,有线接口可以为内部接口(例如高速串行计算机扩展总线(Peripheral Component Interconnect express,PCIe)总线接口)或以太网接口,无线接口可以为蜂窝网络接口或使用无线局域网接口等。
总线605可以是快捷外围部件互联标准(Peripheral Component Interconnect Express,PCIe)总线,也可以是扩展工业标准结构(extended industry standard architecture,EISA)总线、统一总线(unified bus,Ubus或UB)、计算机快速链接(compute express link,CXL)、缓存一致互联协议(cache coherent interconnect for accelerators,CCIX)总线等。总线605可分为地址总线、数据总线、控制总线等。
总线605除包括数据总线之外,还包括电源总线、控制总线和状态信号总线等。但是为了清楚说明起见,在图中将各种总线都标为总线605。
需要说明的,图6仅仅是本申请实施例的一种可能的实现方式,实际应用中,计算设备600还可以包括更多或更少的部件,这里不作限制。关于本申请实施例中未示出或未描述的内容,可参见前述图1-图5实施例中的相关阐述,这里不再赘述。
本申请实施例提供一种计算机存储介质,包括:该计算机存储介质中存储有指令;当该指令在计算设备上运行时,使得该计算设备执行上述图1至图5描述的半监督模型训练方法。
本申请实施例提供了一种包含指令的程序产品,包括程序或指令,当该程序或指令在计算设备上运行时,使得该计算设备执行上述图1至图5描述的半监督模型训练方法。
上述实施例,可以全部或部分地通过软件、硬件、固件或其他任意组合来实现。当使用软件实现时,上述实施例可以全部或部分地以计算机程序产品的形式实现。计算机程序产品包括至少一个计算机指令。在计算机上加载或执行计算机程序指令时,全部或部分地产生按照本发明实施例的流程或功能。计算机可以为通用计算机、专用计算机、计算机网络、或者其他可编程装置。计算机指令可以存储在计算机可读存储介质中,或者从一个计算机可读存储介质向另一个计算机可读存储介质传输,例如,计算机指令可以从一个网站站点、计算机、服务器或数据中心通过有线(例如同轴电缆、光纤、数字用户线(digital subscriber line,DSL))或无线(例如红外、无线、微波等)方式向另一个网站站点、计算机、服务器或数据中心进行传输。计算机可读存储介质可以是计算机能够存取的任何可用介质或者是包含至少一个可用介质集合的服务器、数据中心等数据存储节点。可用介质可以是磁性介质(例如,软盘、硬盘、磁带)、光介质(例如,高密度数字视频光盘(digital video disc,DVD))或者半导体介质,半导体介质可以是SSD。
以上,仅为本发明的具体实施方式,但本发明的保护范围并不局限于此,任何熟悉本技术领域的技术人员在本发明揭露的技术范围内,可轻易想到各种等效的修改或替换,这些修改或替换都应涵盖在本发明的保护范围之内。因此,本发明的保护范围应以权利要求的保护范围为准。

Claims (21)

  1. 一种半监督模型训练方法,其特征在于,所述方法包括:
    将第一无标签样本输入第一模型,获得所述第一无标签样本的第一伪标签;
    将第一扩充样本输入所述第一模型,获得所述第一扩充样本的第二伪标签,其中,所述第一模型为采用有标签样本进行训练后的人工智能AI模型,所述第一扩充样本为对所述第一无标签样本进行数据增强后获得的样本;
    根据所述第一伪标签和所述第二伪标签,获得所述第一无标签样本的第三伪标签;
    使用所述第一无标签样本和所述第三伪标签对第二模型进行训练,其中,所述第二模型是根据所述第一模型的权重参数获得的AI模型。
  2. 根据权利要求1所述的方法,其特征在于,所述第二模型与所述第一模型具有相同的结构。
  3. 根据权利要求1或2所述的方法,其特征在于,所述根据所述第一伪标签和所述第二伪标签,获得所述第一无标签样本的第三伪标签包括:
    根据所述第一伪标签和所述第二伪标签,获得所述第一伪标签和所述第二伪标签之间的匹配度;
    在所述匹配度高于阈值的情况下,将所述第一伪标签和第二伪标签进行融合,获得所述第三伪标签。
  4. 根据权利要求1至3任一权利要求所述的方法,其特征在于,所述第一模型包括目标检测模型,所述数据增强方法包括翻转变换、平移变换、尺度变换、旋转变换、缩放变换中的一种或者多种。
  5. 根据权利要求4所述的方法,其特征在于,所述根据所述第一伪标签和所述第二伪标签,获得所述第一伪标签和所述第二伪标签之间的匹配度包括:
    对所述第二伪标签进行所述数据增强的逆操作,获得第四伪标签;
    对所述第一伪标签和所述第四伪标签进行匹配,获得所述第一伪标签和所述第四伪标签之间的匹配结果;
    根据所述第一伪标签与所述第四伪标签之间的匹配结果确定所述匹配度。
  6. 根据权利要求1至4任一权利要求所述的方法,其特征在于,所述第一模型包括图像识别模型,所述数据增强方法包括修剪、颜色变换、噪声扰动、内核过滤中的一种或者多种。
  7. 根据权利要求6所述的方法,其特征在于,所述根据所述第一伪标签和所述第二伪标签,获得所述第一伪标签和所述第二伪标签之间的匹配度包括:
    对所述第一伪标签和所述第二伪标签进行匹配,获得所述第一伪标签和所述第二伪标签之间的匹配结果;
    根据所述第一伪标签和所述第二伪标签之间的匹配结果获得所述匹配度。
  8. 根据权利要求3至7任一权利要求所述的方法,其特征在于,所述方法还包括:
    将第二无标签样本输入所述第一模型,获得所述第二无标签样本的第五伪标签;
    将第二扩充样本输入所述第一模型,获得所述第二扩充样本的第六伪标签,所述第二扩充样本为对所述第二无标签样本进行数据增强后获得的样本;
    根据所述第五伪标签和所述第六伪标签,获得所述第五伪标签和所述第六伪标签之间的匹配度;
    在所述匹配度不高于所述阈值的情况下,删除所述第五伪标签和所述第六伪标签。
  9. 根据权利要求1至8任一权利要求所述的方法,其特征在于,所述使用所述第一无标签样本和所述第三伪标签样本对第二模型进行训练,包括:
    使用所述有标签样本、所述第一无标签样本和所述第三伪标签样本对所述第二模型进行迭代训练,根据每次迭代训练获得的所述第二模型的权重参数对所述第一模型的权重参数进行迭代更新,获得目标模型。
  10. 根据权利要求9所述的方法,其特征在于,所述使用所述有标签样本、所述第一无标签样本和所述第三伪标签样本对所述第二模型进行迭代训练包括:
    将所述输入样本输入所述第二模型获得第一输出值,将所述第一无标签样本输入所述第二模型获得第二输出值,根据所述第一输出值和所述第二输出值确定所述第二模型的损失值,其中,所述损失值包括有标签损失和伪标签损失,所述有标签损失是根据所述第一输出值和所述真实标签之间的差值获得的,所述伪标签损失是根据所述第二输出值和所述第三伪标签之间的差值获得的;
    根据所述损失值对所述第二模型进行迭代训练。
  11. 一种半监督模型训练系统,其特征在于,所述系统包括:
    推理单元,用于将第一无标签样本输入第一模型,获得所述第一无标签样本的第一伪标签;
    推理单元,用于将第一扩充样本输入所述第一模型,获得所述扩充样本的第二伪标签,其中,所述第一模型为采用有标签样本进行训练后的人工智能AI模型,所述第一扩充样本为对所述第一无标签样本进行数据增强后获得的样本;
    匹配单元,用于根据所述第一伪标签和所述第二伪标签,获得所述第一无标签样本的第三伪标签;
    训练单元,用于使用所述第一无标签样本和所述第三伪标签对第二模型进行训练,其中,所述第二模型是根据所述第一模型的权重参数获得的AI模型。
  12. 根据权利要求11所述的系统,其特征在于,所述第二模型与所述第一模型具有相同的结构。
  13. 根据权利要求11或12所述的系统,其特征在于,
    所述匹配单元,用于根据所述第一伪标签和所述第二伪标签,获得所述第一伪标签和所述第二伪标签之间的匹配度;
    所述匹配单元,用于在所述匹配度高于阈值的情况下,将所述第一伪标签和第二伪标签进行融合,获得所述第三伪标签。
  14. 根据权利要求11至13任一权利要求所述的系统,其特征在于,所述第一模型包括目标检测模型,所述数据增强方法包括翻转变换、平移变换、尺度变换、旋转变换、缩放变换中的一种或者多种。
  15. 根据权利要求14所述的系统,其特征在于,
    所述匹配单元,用于对所述第二伪标签进行所述数据增强的逆操作,获得第四伪标签;
    所述匹配单元,用于对所述第一伪标签和所述第四伪标签进行匹配,获得所述第一伪标签和所述第四伪标签之间的匹配结果;
    所述匹配单元,用于根据所述第一伪标签与所述第四伪标签之间的匹配结果确定所述匹配度。
  16. 根据权利要求11至14任一权利要求所述的系统,其特征在于,所述第一模型包括图像识别模型,所述数据增强方法包括修剪、颜色变换、噪声扰动、内核过滤中的一种或者多种。
  17. 根据权利要求16所述的系统,其特征在于,
    所述匹配单元,用于对所述第一伪标签和所述第二伪标签进行匹配,获得所述第一伪标签和所述第二伪标签之间的匹配结果;
    所述匹配单元,用于根据所述第一伪标签和所述第二伪标签之间的匹配结果获得所述匹配度。
  18. 根据权利要求11至17任一权利要求所述的系统,其特征在于,
    所述推理单元,用于将第二无标签样本输入所述第一模型,获得所述第二无标签样本的第五伪标签;
    所述推理单元,用于将第二扩充样本输入所述第一模型,获得所述第二扩充样本的第六伪标签,所述第二扩充样本为对所述第二无标签样本进行数据增强后获得的样本;
    所述匹配单元,用于根据所述第五伪标签和所述第六伪标签,获得所述第五伪标签和所述第六伪标签之间的匹配度;
    所述匹配单元,用于在所述匹配度不高于所述阈值的情况下,删除所述第五伪标签和所述第六伪标签。
  19. 根据权利要求11至18任一权利要求所述的系统,其特征在于,所述训练单元,用于使用所述有标签样本、所述第一无标签样本和所述第三伪标签样本对所述第二模型进行迭代训练,根据每次迭代训练获得的所述第二模型的权重参数对所述第一模型的权重参数进行迭代更新,获得目标模型。
  20. 根据权利要求19所述的系统,其特征在于,所述训练单元,用于将所述输入样本输入所述第二模型获得第一输出值,将所述第一无标签样本输入所述第二模型获得第二输出值,根据所述第一输出值和所述第二输出值确定所述第二模型的损失值,其中,所述损失值包括有标签损失和伪标签损失,所述有标签损失是根据所述第一输出值和所述真实标签之间的差值获得的,所述伪标签损失是根据所述第二输出值和所述第三伪标签之间的差值获得的;
    所述训练单元,用于根据所述损失值对所述第二模型进行迭代训练。
  21. 一种计算设备,其特征在于,所述计算设备包括处理器和存储器,所述存储器用于存储代码,所述处理器用于执行所述代码实现如权利要求1至10任一权利要求所述的方法。