CN114881129A - Model training method and device, electronic equipment and storage medium

Info

Publication number: CN114881129A
Application number: CN202210441633.0A
Authority: CN (China)
Legal status: Pending
Prior art keywords: model, sample, pseudo, label, confidence
Priority date: 2022-04-25
Filing date: 2022-04-25
Publication date: 2022-08-09
Other languages: Chinese (zh)
Inventors: 李莹莹, 蒋旻悦, 杨喜鹏, 孙昊
Current assignee: Beijing Baidu Netcom Science and Technology Co Ltd
Original assignee: Beijing Baidu Netcom Science and Technology Co Ltd
Application filed by Beijing Baidu Netcom Science and Technology Co Ltd


Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting

Abstract

The present disclosure provides a model training method and apparatus, an electronic device, and a storage medium, relating to the technical field of artificial intelligence and in particular to deep learning, image processing, computer vision, and the like. The specific implementation scheme is as follows: obtaining each sample for model training; respectively inputting the obtained samples into a semi-supervised model to be trained to obtain output results; calculating the model loss by using the difference between the output results and the labels of the samples and the loss weight corresponding to the pseudo label; and adjusting the model parameters of the semi-supervised model based on the model loss so as to obtain a target semi-supervised model. With this scheme, the precision of the semi-supervised model can be improved; moreover, the more accurate model loss accelerates model convergence, reduces the time for which computing resources are occupied, and saves the computing resources used in model training.

Description

Model training method and device, electronic equipment and storage medium
Technical Field
The present disclosure relates to the field of artificial intelligence technologies, in particular to the fields of deep learning, image processing, computer vision, and the like, and more particularly to a model training method and apparatus, an electronic device, and a storage medium.
Background
A semi-supervised model is a model trained by semi-supervised learning, that is, deep learning performed with partly labeled samples and partly unlabeled samples. The tasks that a semi-supervised model can perform may include object detection tasks, such as image detection tasks or text detection tasks, but are certainly not limited thereto.
In order to ensure that the semi-supervised model converges quickly and that computing resources are not occupied for a long time, a pseudo label can be generated for each unlabeled sample during the training of the semi-supervised model, and the unlabeled sample and the generated pseudo label are then used for model training.
Disclosure of Invention
The disclosure provides a model training method, a model training device, an electronic device and a storage medium.
According to an aspect of the present disclosure, there is provided a model training method, including:
obtaining each sample for model training; wherein an unlabeled first sample in each sample has a pseudo-label, and the pseudo-label has a confidence level;
respectively inputting the obtained samples into a semi-supervised model to be trained to obtain an output result;
calculating model loss by using the difference between the output result and the label of each sample and the loss weight corresponding to the pseudo label; wherein the loss weight is set based on the confidence level and is positively correlated with the confidence level;
and adjusting model parameters of the semi-supervised model based on the model loss so as to obtain a target semi-supervised model.
According to a second aspect of the present disclosure, there is provided a model training apparatus comprising:
the acquisition module is used for acquiring each sample for model training; wherein an unlabeled first sample in each sample has a pseudo-label, and the pseudo-label has a confidence level;
the input module is used for respectively inputting the obtained samples into a semi-supervised model to be trained to obtain an output result;
the calculation module is used for calculating model loss by using the difference between the output result and the label of each sample and the loss weight corresponding to the pseudo label; wherein the loss weight is set based on the confidence level and is positively correlated with the confidence level;
and the judging module is used for adjusting the model parameters of the semi-supervised model based on the model loss so as to obtain the target semi-supervised model.
According to a third aspect of the present disclosure, there is provided an electronic device comprising:
at least one processor; and
a memory communicatively coupled to the at least one processor; wherein
the memory stores instructions executable by the at least one processor to enable the at least one processor to perform any of the model training methods.
According to a fourth aspect of the present disclosure, there is provided a non-transitory computer readable storage medium having stored thereon computer instructions for causing the computer to perform any of the model training methods.
According to another aspect of the present disclosure, there is also provided a computer program product comprising a computer program which, when executed by a processor, implements any of the model training methods.
It should be understood that the statements in this section do not necessarily identify key or critical features of the embodiments of the present disclosure, nor do they limit the scope of the present disclosure. Other features of the present disclosure will become apparent from the following description.
Drawings
The drawings are included to provide a better understanding of the present solution and are not to be construed as limiting the present disclosure. Wherein:
FIG. 1 is a flow chart of a model training method provided in accordance with the present disclosure;
FIG. 2 is another flow chart diagram of a model training method provided in accordance with the present disclosure;
FIG. 3 is a schematic diagram of a model training apparatus provided in accordance with the present disclosure;
FIG. 4 is a block diagram of an electronic device for implementing a model training method of an embodiment of the present disclosure.
Detailed Description
Exemplary embodiments of the present disclosure are described below with reference to the accompanying drawings, in which various details of the embodiments of the disclosure are included to assist understanding, and which are to be considered as merely exemplary. Accordingly, those of ordinary skill in the art will recognize that various changes and modifications of the embodiments described herein can be made without departing from the scope and spirit of the present disclosure. Also, descriptions of well-known functions and constructions are omitted in the following description for clarity and conciseness.
At present, semi-supervised training is mainly performed on the basis of self-training, of which STAC (a Simple Semi-Supervised Learning Framework for Object Detection) is representative. In this semi-supervised learning architecture, a Teacher model for generating pseudo labels is first trained using the labeled samples; that is, a pseudo label generation model is trained with the labeled samples. Given a sample, this model outputs the confidence of each candidate label, and the label with the highest confidence is used as the pseudo label of the sample. Then, the trained Teacher model is used to generate pseudo labels for the unlabeled samples. Finally, the labeled samples and the unlabeled samples with pseudo labels are input into the student model to be trained, namely the semi-supervised model, and whether the model has finished training is determined by calculating the loss of the semi-supervised model.
However, in this approach the pseudo labels are obtained off-line, and the model that generates the pseudo labels is not updated during the training of the semi-supervised model. Therefore, once the precision of the semi-supervised model exceeds that of the model generating the pseudo labels, the precision of the semi-supervised model can no longer be improved by continued training, so the training precision of the semi-supervised model is limited.
In another semi-supervised training approach, the accuracy of the semi-supervised model is improved by updating the pseudo labels on line. In this training mode, the model generating the pseudo labels is updated on line and the quality of the pseudo labels improves in time, which raises the precision of model training and helps the model converge. However, in this training mode every unlabeled sample used for training has the same supervision effect on the model regardless of the reliability of its pseudo label, that is, regardless of the confidence of the pseudo label, so the achievable precision of the semi-supervised model is still limited.
Based on the above problems, embodiments of the present disclosure provide a model training method and apparatus, an electronic device, and a storage medium, so as to further improve the accuracy of a semi-supervised model.
The following first introduces a model training method provided by the present disclosure.
The model training method related to the embodiment of the present disclosure may be applied to an electronic device, where the electronic device may be a terminal device or a server, and the present disclosure does not limit the specific form of the electronic device. In addition, the model training method provided by the embodiment of the disclosure can be applied to any scene in which a semi-supervised model is trained by using labeled samples and unlabeled samples, and the embodiment of the disclosure is not limited to a specific scene.
Moreover, the semi-supervised model according to the present disclosure is a model trained by means of semi-supervised learning, and the tasks it can perform may include object detection tasks, although they are certainly not limited thereto. An object detection task can be regarded as a classification task, that is, determining the class to which a detected object belongs. In this case, the object class is the label: a labeled sample is a sample annotated with the object class, an unlabeled sample is a sample not annotated with the object class, the pseudo label of an unlabeled sample represents an object class, and the confidence of the pseudo label is the probability that the unlabeled sample belongs to the represented object class.
In addition, in a specific application, the sample may be of an image type, a text type, or the like, but is not limited thereto. For example, if the sample is of an image type, the object detection task may be detecting the category to which an object in the sample image belongs, or detecting both that category and the location information of the object; if the sample is of a text type, the object detection task may be detecting the category of emotion expressed by the text, for example pessimistic, positive, and so on.
The model training method provided by the embodiment of the disclosure can comprise the following steps:
obtaining each sample for model training; wherein an unlabeled first sample in each sample has a pseudo-label, and the pseudo-label has a confidence level;
respectively inputting the obtained samples into a semi-supervised model to be trained to obtain an output result;
calculating model loss by using the difference between the output result and the label of each sample and the loss weight corresponding to the pseudo label; wherein the loss weight is set based on the confidence level and is positively correlated with the confidence level;
and adjusting model parameters of the semi-supervised model based on the model loss so as to obtain a target semi-supervised model.
In this scheme, the pseudo label corresponds to a loss weight that is set based on the confidence of the pseudo label and is positively correlated with it; that is, the supervision effect of the pseudo label of a first sample during model training is governed by its confidence, so the pseudo labels of the first samples are applied more reasonably. Therefore, when training the semi-supervised model, a more accurate model loss can be calculated using the difference between the output result of each sample and the label of each sample, together with the loss weight corresponding to the pseudo label. As a result, the scheme can further improve the precision of the semi-supervised model; moreover, the more accurate model loss accelerates model convergence, reduces the time for which computing resources are occupied, and saves the computing resources used in model training.
A model training method provided by the present disclosure is exemplarily described below with reference to the accompanying drawings.
As shown in fig. 1, a model training method provided by the present disclosure may include the following steps:
s101: obtaining each sample for model training;
wherein an unlabeled first sample in each sample has a pseudo-label, and the pseudo-label has a confidence level;
in the model training process for the semi-supervised model, each sample may be obtained first, and then model training may be performed using each obtained sample. It will be appreciated that each sample may include a first unlabeled sample, and a second labeled sample, as the training is directed to the semi-supervised model. The second sample has a labeled label, and the labeled label is a true value labeled for the second sample; and the first sample is not labeled with a label, but a pseudo label is set for the first sample in advance through a preset pseudo label generation mode, and the pseudo label has confidence.
For example, the samples may be image samples; in that case, the pseudo label of a first sample may be an object class in the image, the confidence of the pseudo label is the probability of belonging to that object class, and the annotated label of a second sample may be the object class in the image. Alternatively, the samples may be text samples; in that case, the pseudo label of a first sample may be the emotion category expressed by the text, the confidence of the pseudo label is the probability of belonging to that emotion category, and the annotated label of a second sample may be the emotion category expressed by the text.
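For illustration only, the two kinds of samples might be represented as follows. This is a minimal sketch: the use of Python dataclasses, the field names, and the example values are assumptions of this illustration rather than part of the disclosed method.

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class TrainingSample:
    data: object                        # raw input, e.g. an image tensor or a piece of text
    label: Optional[int] = None         # annotated label of a labeled (second) sample
    pseudo_label: Optional[int] = None  # pseudo label of an unlabeled (first) sample
    confidence: Optional[float] = None  # confidence of the pseudo label, in [0, 1]

    @property
    def is_labeled(self) -> bool:
        return self.label is not None

# A labeled second sample and an unlabeled first sample with a pseudo label.
second_sample = TrainingSample(data="tree.jpg", label=1)                       # plant class
first_sample = TrainingSample(data="dog.jpg", pseudo_label=0, confidence=0.8)  # animal class, 80%
```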
The second samples may be obtained, for example, from a public data set, although the obtaining manner is not limited to this and this embodiment does not restrict it. In addition, for clarity of the scheme and of the layout, an exemplary manner of obtaining the first samples is described below in combination with other embodiments.
In the technical scheme of the present disclosure, the collection, storage, use, processing, transmission, provision, disclosure and other processing of the related samples all meet the regulations of related laws and regulations, and do not violate the common customs of the public order.
S102: respectively inputting the obtained samples into a semi-supervised model to be trained to obtain an output result;
after obtaining each sample, the semi-supervised model to be trained can be trained, and firstly, each obtained sample can be respectively input into the semi-supervised model to be trained to obtain each output result; thereby performing the subsequent step of model training according to the output result. It should be noted that each sample corresponds to an output result.
For example, taking an object detection task as an example, the output result may be the confidence of each object class. If the input is a text, the output result may be the confidence of each text category, for example a confidence of 40% of belonging to the pessimistic category and a confidence of 60% of belonging to the positive category. The output result is of course not limited thereto.
When the accuracy of the semi-supervised model is high, the output result may be similar to the label corresponding to the input sample, with only a small difference; for example, the output result of a first sample is similar to its pseudo label, and the output result of a second sample is similar to its annotated label. When the model accuracy is low, the output result and the label corresponding to the input sample may differ greatly, or even be completely different.
It should be noted that the output results of different samples may be the same or different, and the disclosure is not limited herein.
S103: calculating model loss by using the difference between the output result and the label of each sample and the loss weight corresponding to the pseudo label;
wherein the loss weight is set based on the confidence of the pseudo label and is positively correlated with it; that is, the loss weight is governed by the confidence of the pseudo label and varies in the same direction: the higher the confidence, the higher the loss weight, and the lower the confidence, the lower the loss weight. It should be emphasized that the label of a first sample is its pseudo label, while the label of a second sample is the label it was annotated with.
After the output result of the semi-supervised model to be trained is obtained, the model loss of the semi-supervised model can be calculated by using the output result, so that whether the model is trained or not is judged according to the model loss.
It should be noted that there is a difference between the output result of each sample and the label of each sample, where the label corresponding to the first sample is a pseudo label, and the label corresponding to the second sample is a labeled label. The loss of the semi-supervised model can be calculated using the difference between the output result of each sample and the label of each sample.
For example, using the difference between the output result of each sample and the label of each sample, together with the loss weight corresponding to the pseudo label, the model loss may be calculated as follows: according to a predetermined loss value calculation function, calculate the model loss from the output result of each sample, the label of each sample, and the loss weight corresponding to the pseudo label.
For example, in one implementation, the loss of the semi-supervised model to be trained may be calculated by calculating cross entropy. At this time, the output result of each sample, the label of each sample, and the loss weight corresponding to the pseudo label may be substituted into a specified cross entropy calculation formula to obtain the model loss.
Wherein, the calculation formula of the cross entropy is as follows:
C = -(1/n) * ∑_x [ y·ln(a) + (1-y)·ln(1-a) ]
where n is the number of samples, x denotes a sample, a is the output result corresponding to the sample, and y is the annotated label of the second sample when x is a second sample, or the product of the pseudo label of the first sample and the loss weight when x is a first sample. In this way, the loss weight corresponding to the pseudo label is applied in the calculation of the model loss.
It should be noted that, the above-mentioned manner for calculating the model loss by using the cross entropy is only an example, and any manner capable of calculating the model loss may be applied to the present disclosure, and is not limited herein.
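As one possible concrete form of applying the loss weight, a minimal PyTorch-style sketch is given below. It weights the per-sample cross-entropy terms by the loss weights rather than multiplying the labels as in the formula above; this per-sample weighting is an assumption of the illustration, not the only form the disclosure allows.

```python
import torch
import torch.nn.functional as F

def weighted_model_loss(logits: torch.Tensor,
                        targets: torch.Tensor,
                        loss_weights: torch.Tensor) -> torch.Tensor:
    """Cross-entropy over a mixed batch of labeled and pseudo-labeled samples.

    logits:       (N, C) raw model outputs for N samples over C classes.
    targets:      (N,) annotated labels for second samples, pseudo labels for first samples.
    loss_weights: (N,) 1.0 for labeled samples; for pseudo-labeled samples, a weight set
                  from the confidence of the pseudo label (positively correlated with it).
    """
    per_sample = F.cross_entropy(logits, targets, reduction="none")  # one loss term per sample
    return (loss_weights * per_sample).mean()
```

With this form, a pseudo label whose confidence maps to a loss weight of 0.5 contributes half as much to the model loss as an annotated label.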
S104: adjusting model parameters of the semi-supervised model based on the model loss so as to obtain a target semi-supervised model;
after the model loss calculation is completed, it may be determined whether the model converges based on the result of the calculation. For example, in one implementation, the determining whether the model converges may include: if the model loss is not higher than a preset threshold value, judging that the model precision meets the requirement, and finishing training; and if the model loss is higher than the preset threshold value, the model precision is too low, the training needs to be continued, at the moment, the model parameters of the semi-supervised model are adjusted, the step of obtaining each sample for model training is returned until the model loss is not higher than the preset threshold value, and the target semi-supervised model is obtained. In adjusting the model parameters, a gradient descent method may be used, but the method is not limited to this.
In this scheme, the pseudo label corresponds to a loss weight that is set based on the confidence of the pseudo label and is positively correlated with it; that is, the supervision effect of the pseudo label of a first sample during model training is governed by its confidence, so the pseudo labels of the first samples are applied more reasonably. Therefore, when training the semi-supervised model, a more accurate model loss can be calculated using the difference between the output result of each sample and the label of each sample, together with the loss weight corresponding to the pseudo label. As a result, the scheme can further improve the precision of the semi-supervised model; moreover, the more accurate model loss accelerates model convergence, reduces the time for which computing resources are occupied, and saves the computing resources used in model training.
Optionally, in another embodiment of the present disclosure, the obtaining of the first sample includes:
determining at least one initial sample belonging to the unlabeled class;
generating a pseudo label for the initial sample by using a pre-trained pseudo label generation model; the pseudo label generation model is a model trained using labeled samples and their corresponding labels;
from the initial samples with the pseudo-label, a first sample is selected.
The pseudo label generation model can be trained on line, that is, it can be updated in time, which improves the training precision of the semi-supervised model. The present disclosure does not limit the specific model structure of the pseudo label generation model. The training process of the pseudo label generation model may include: inputting labeled samples into the pseudo label generation model to be trained to obtain output results, then calculating the model loss of the pseudo label generation model using the difference between the output results and the corresponding labels, and judging that the model has converged if the model loss is less than a set threshold; if it is not less than the set threshold, the model parameters are adjusted and labeled samples with their corresponding labels continue to be obtained, so that training of the model continues.
For example, in one implementation, each initial sample with a pseudo label may be used as a first sample to be applied in a training process of a semi-supervised model, so as to ensure sufficiency of the first sample.
For example, the first samples may instead be selected from the initial samples with pseudo labels as follows: selecting, from the initial samples with pseudo labels, the initial samples whose pseudo label confidence is not lower than a first threshold, to obtain the first samples.
Because a pseudo label with low confidence plays little or no role in training the semi-supervised model, the initial samples whose pseudo label confidence is not lower than the first threshold can be selected from the initial samples with pseudo labels as the first samples. That is, the initial samples are screened so that only pseudo labels that can play a larger role in training are kept, which improves the training precision of the semi-supervised model.
It should be noted that the above-mentioned manner for obtaining the respective samples for model training is only an example, and should not be construed as a limitation of the present disclosure.
Therefore, in this embodiment, the pseudo label generation model is adopted to generate the pseudo label of each initial sample, and the first sample is selected based on the initial sample with the pseudo label, so that the first sample can be quickly and effectively obtained.
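A minimal sketch of this first-sample acquisition step is shown below, assuming a PyTorch-style pseudo label generation (Teacher) model whose outputs are turned into class confidences with a softmax; the 30% value for the first threshold follows the example used later and is not fixed by the disclosure.

```python
import torch

@torch.no_grad()
def build_first_samples(pseudo_label_model, unlabeled_samples, first_threshold: float = 0.3):
    """Generate pseudo labels for unlabeled initial samples and keep only the confident ones."""
    pseudo_label_model.eval()
    first_samples = []
    for x in unlabeled_samples:
        probs = torch.softmax(pseudo_label_model(x.unsqueeze(0)), dim=-1).squeeze(0)
        confidence, pseudo_label = probs.max(dim=-1)    # label with the highest confidence
        if confidence.item() >= first_threshold:        # screen out low-confidence pseudo labels
            first_samples.append((x, int(pseudo_label), float(confidence)))
    return first_samples
```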
Based on the above model training method, in another embodiment of the present disclosure, another model training method is further provided, as shown in fig. 2, the method may include the following steps:
s201: obtaining each sample for model training;
s202: respectively inputting the obtained samples into a semi-supervised model to be trained to obtain an output result;
s203: determining a loss weight corresponding to a pseudo label of a first sample;
in the model loss calculation process, the loss weight corresponding to the pseudo tag needs to be used, so that the loss weight corresponding to the pseudo tag can be determined first, and the subsequent loss calculation step is executed.
For example, the manner of determining the loss weight corresponding to the pseudo label may include steps A1 and A2:
Step A1, determining a target confidence interval in which the confidence of the pseudo label of the first sample is located; the target confidence interval is one of a plurality of confidence intervals, and different confidence intervals correspond to different loss weights;
Step A2, determining the loss weight corresponding to the target confidence interval as the loss weight corresponding to the pseudo label.
Because the confidence degrees are positively correlated with the loss weights, a plurality of confidence degree intervals can be set, and when the loss weight corresponding to the confidence degree of a certain pseudo label needs to be determined, the corresponding confidence degree interval can be determined first, so that the loss weight corresponding to the pseudo label is determined. It can be understood that, on the premise of ensuring that the loss weight is positively correlated with the confidence, the specific value of the loss weight corresponding to each confidence interval may be set according to the actual situation, and the disclosure is not limited.
It should be noted that setting a plurality of confidence intervals avoids assigning a different loss weight to every different confidence value, which would result in too many loss weights and complicate the calculation.
Optionally, in one implementation, the confidence intervals are divided according to a specified step size. For example, the specified step size may be 20%, in which case the confidence intervals are (0, 20%], (20%, 40%], (40%, 60%], (60%, 80%], and (80%, 100%], and the loss weights corresponding to these intervals may be 0.2, 0.4, 0.6, 0.8, and 1, respectively.
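A sketch of this fixed-step mapping is given below; the 20% step and the resulting weight values simply follow the example above, and the function name is an assumption of the illustration.

```python
import math

def interval_loss_weight(confidence: float, step: float = 0.2) -> float:
    """Map a pseudo-label confidence to a loss weight using fixed-step confidence intervals.

    With step = 0.2 the intervals (0, 20%], (20%, 40%], ..., (80%, 100%] are mapped to
    the loss weights 0.2, 0.4, 0.6, 0.8 and 1.0, respectively.
    """
    interval_index = math.ceil(confidence / step)   # 1-based index of the interval the confidence falls in
    return min(interval_index * step, 1.0)
```

For instance, a confidence of 55% falls in the (40%, 60%] interval and is mapped to the 0.6 loss weight.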
Optionally, in another implementation, the confidence of the pseudo label is not lower than a first threshold. In this case, the confidence intervals include a first confidence interval and a second confidence interval, and the loss weight corresponding to the first confidence interval is higher than the loss weight corresponding to the second confidence interval; the confidences in the first confidence interval are all higher than a second threshold, the confidences in the second confidence interval are not lower than the first threshold and not higher than the second threshold, and the second threshold is higher than the first threshold.
For example, suppose the confidence of each pseudo label is not lower than the first threshold, the first threshold is 30%, and the second threshold is 60%. Then the first confidence interval may be (60%, 100%] with a corresponding loss weight of 1, and the second confidence interval may be [30%, 60%] with a corresponding loss weight of 0.5. If the confidence of a certain pseudo label is 80%, it belongs to the first confidence interval and its loss weight is set to 1; if the confidence of a certain pseudo label is 40%, it belongs to the second confidence interval and its loss weight is set to 0.5.
Such a way of setting the intervals with a first threshold and a second threshold may be referred to as a dual-threshold mode; of course, more thresholds may be set according to actual needs, which is equally reasonable. In this mode, the weight of a pseudo label with higher confidence can be determined directly, which improves the efficiency of model training and can also improve model precision. In addition, it should be emphasized that "first" in "first threshold", "second" in "second threshold", "first" in "first confidence interval", and "second" in "second confidence interval" in the embodiments of the present disclosure are used only to distinguish different thresholds and intervals by name and carry no limiting meaning.
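The dual-threshold mode can be sketched in the same way; the 30% and 60% thresholds and the 1 and 0.5 weights follow the example above and are assumptions of this illustration, and pseudo labels below the first threshold are assumed to have already been screened out when the first samples were selected.

```python
def dual_threshold_loss_weight(confidence: float,
                               first_threshold: float = 0.3,
                               second_threshold: float = 0.6) -> float:
    """Map a pseudo-label confidence to a loss weight using the dual-threshold mode."""
    if confidence < first_threshold:
        # Such pseudo labels are assumed to have been screened out when selecting the
        # first samples; a zero weight is used here only as a safeguard for this sketch.
        return 0.0
    if confidence > second_threshold:   # first confidence interval (60%, 100%]
        return 1.0
    return 0.5                          # second confidence interval [30%, 60%]
```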
The above description of the manner of determining the loss weight corresponding to the pseudo tag is merely an example, and should not be construed as limiting the present disclosure.
S204: calculating model loss by using the difference between the output result and the label of each sample and the loss weight corresponding to the pseudo label;
s205: adjusting model parameters of the semi-supervised model based on the model loss so as to obtain a target semi-supervised model;
steps S201 and S202 are similar to steps S101 and S102, and steps S204 and S205 are similar to steps S103 and S104, and are not described herein again.
In this scheme, the pseudo label corresponds to a loss weight that is set based on the confidence of the pseudo label and is positively correlated with it; that is, the supervision effect of the pseudo label of a first sample during model training is governed by its confidence, so the pseudo labels of the first samples are applied more reasonably. Therefore, when training the semi-supervised model, a more accurate model loss can be calculated using the difference between the output result of each sample and the label of each sample, together with the loss weight corresponding to the pseudo label. As a result, the scheme can further improve the precision of the semi-supervised model; moreover, the more accurate model loss accelerates model convergence, reduces the time for which computing resources are occupied, and saves the computing resources used in model training.
Optionally, in another embodiment of the present disclosure, the task for the semi-supervised model may be an object detection task, which is a detection task for an image.
Accordingly, a model training method may include steps a-D:
step A, obtaining each image sample for model training;
Here a first image sample is an unlabeled image sample and a second image sample is a labeled image sample. The label annotated on a second sample may be, for example, an animal image label, a plant image label, and so on. The pseudo label of a first sample may be generated by a predetermined pseudo label generation manner. For example, if the image of a first sample contains a dog, the labels generated by the predetermined pseudo label generation manner may include an animal image label with a confidence of 80% and a plant image label with a confidence of 20%; the animal image label can then be taken as the pseudo label of this first image sample, with a confidence of 80%.
Step B, respectively inputting the acquired image samples into the semi-supervised model to be trained to obtain output results.
For a first image sample that is an image of a cat, the output result may be the animal type; for a second image sample that is an image of a tree, the output result may be the plant type. If the accuracy of the semi-supervised model to be trained is not high, the output result may differ greatly from the actual result, for example, for a certain animal image the output result may be the plant type.
Step C, calculating the model loss by using the difference between the output result and the label of each image sample and the loss weight corresponding to the pseudo label.
When the confidence of a pseudo label is high, its loss weight can be set to a large value so that the corresponding first image sample plays a larger role in the model loss calculation; when the confidence is low, the loss weight can be set to a small value so that the corresponding first image sample plays a smaller role. This increases the accuracy of the model loss calculation.
Step D, adjusting model parameters of the semi-supervised model based on model loss so as to obtain a target semi-supervised model:
when detecting an image sample, if the model loss is very small, for example, 0.1 or 0.2, the target semi-supervised model can be directly obtained without adjusting the parameters of the semi-supervised model. When the model loss is large, such as 0.5, 0.6, or 0.7, etc., the model parameters need to be adjusted, and the semi-supervised model continues to be trained until the obtained image is the target image, so as to obtain the target semi-supervised model.
Therefore, this scheme can further improve the precision of the semi-supervised model; moreover, the more accurate model loss accelerates model convergence, reduces the time for which computing resources are occupied, and saves the computing resources used in model training.
According to an embodiment of the present disclosure, there is also provided a model training apparatus, as shown in fig. 3, including:
an obtaining module 310, configured to obtain each sample for model training; wherein an unlabeled first sample in each sample has a pseudo-label, and the pseudo-label has a confidence level;
the input module 320 is used for respectively inputting the obtained samples into the semi-supervised model to be trained to obtain an output result;
a calculating module 330, configured to calculate a model loss by using a difference between the output result and a label of each sample and a loss weight corresponding to the pseudo label; wherein the loss weight is set based on the confidence level and is positively correlated with the confidence level;
a determining module 340, configured to adjust model parameters of the semi-supervised model based on the model loss, so as to obtain a target semi-supervised model.
In this scheme, the pseudo label corresponds to a loss weight that is set based on the confidence of the pseudo label and is positively correlated with it; that is, the supervision effect of the pseudo label of a first sample during model training is governed by its confidence, so the pseudo labels of the first samples are applied more reasonably. Therefore, when training the semi-supervised model, a more accurate model loss can be calculated using the difference between the output result of each sample and its corresponding label, together with the loss weight corresponding to the pseudo label. As a result, the scheme can further improve the precision of the semi-supervised model; moreover, the more accurate model loss accelerates model convergence, reduces the time for which computing resources are occupied, and saves the computing resources used in model training.
Optionally, the method further comprises:
a first determining module, configured to determine a target confidence interval in which a confidence level of a pseudo tag of the first sample is located before the calculating module calculates a model loss using a difference between an output result of each sample and the tag of each sample and a loss weight corresponding to the pseudo tag; the target confidence interval is one of a plurality of confidence intervals, and different confidence intervals correspond to different loss weights;
and the second determining module is used for determining the loss weight corresponding to the target confidence degree interval as the loss weight corresponding to the pseudo tag.
Optionally, the pseudo tag has a confidence level not below a first threshold;
the confidence intervals comprise a first confidence interval and a second confidence interval, and the loss weight corresponding to the first confidence interval is higher than the loss weight corresponding to the second confidence interval;
the confidence degrees in the first confidence degree intervals are all higher than the second threshold value, the confidence degree in the second confidence degree intervals is not lower than the first threshold value and not higher than a second threshold value, and the second threshold value is higher than the first threshold value.
Optionally, the obtaining manner of the first sample includes:
determining at least one initial sample belonging to the unlabeled class;
generating a pseudo label for the initial sample by using a pre-trained pseudo label generation model; the pseudo label generation model is a model trained using labeled samples and their corresponding labels;
from the initial samples with the pseudo-label, a first sample is selected.
Optionally, selecting a first sample from the initial samples with the pseudo-label comprises:
and selecting the initial samples with the pseudo labels, the confidence degrees of which are not lower than a first threshold value, from the initial samples with the pseudo labels to obtain a first sample.
Optionally, the calculation module is specifically configured to:
and calculating a model loss according to a predetermined loss value calculation function based on the output result, the label of each sample and the loss weight corresponding to the pseudo label.
The present disclosure also provides an electronic device, a readable storage medium, and a computer program product according to embodiments of the present disclosure.
An embodiment of the present disclosure provides an electronic device, including:
at least one processor; and
a memory communicatively coupled to the at least one processor; wherein
the memory stores instructions executable by the at least one processor to enable the at least one processor to perform any one of the model training methods.
The disclosed embodiments provide a non-transitory computer-readable storage medium having stored thereon computer instructions for causing the computer to perform any of the model training methods.
Embodiments of the present disclosure provide a computer program product comprising a computer program which, when executed by a processor, implements any of the methods of model training.
FIG. 4 shows a schematic block diagram of an example electronic device 400 that may be used to implement embodiments of the present disclosure. Electronic devices are intended to represent various forms of digital computers, such as laptops, desktops, workstations, personal digital assistants, servers, blade servers, mainframes, and other appropriate computers. The electronic device may also represent various forms of mobile devices, such as personal digital processing, cellular phones, smart phones, wearable devices, and other similar computing devices. The components shown herein, their connections and relationships, and their functions, are meant to be examples only, and are not meant to limit implementations of the disclosure described and/or claimed herein.
As shown in fig. 4, the apparatus 400 includes a computing unit 401 that can perform various appropriate actions and processes according to a computer program stored in a Read Only Memory (ROM)402 or a computer program loaded from a storage unit 408 into a Random Access Memory (RAM) 403. In the RAM 403, various programs and data required for the operation of the device 400 can also be stored. The computing unit 401, ROM 402, and RAM 403 are connected to each other via a bus 404. An input/output (I/O) interface 405 is also connected to bus 404.
A number of components in device 400 are connected to I/O interface 405, including: an input unit 406 such as a keyboard, a mouse, or the like; an output unit 407 such as various types of displays, speakers, and the like; a storage unit 408 such as a magnetic disk, optical disk, or the like; and a communication unit 409 such as a network card, modem, wireless communication transceiver, etc. The communication unit 409 allows the device 400 to exchange information/data with other devices via a computer network, such as the internet, and/or various telecommunication networks.
Computing unit 401 may be a variety of general and/or special purpose processing components with processing and computing capabilities. Some examples of the computing unit 401 include, but are not limited to, a Central Processing Unit (CPU), a Graphics Processing Unit (GPU), various dedicated Artificial Intelligence (AI) computing chips, various computing units running machine learning model algorithms, a Digital Signal Processor (DSP), and any suitable processor, controller, microcontroller, and so forth. The computing unit 401 performs the various methods and processes described above, such as the model training method. For example, in some embodiments, the model training method may be implemented as a computer software program tangibly embodied in a machine-readable medium, such as storage unit 408. In some embodiments, part or all of the computer program may be loaded and/or installed onto the device 400 via the ROM 402 and/or the communication unit 409. When the computer program is loaded into RAM 403 and executed by computing unit 401, one or more steps of the model training method described above may be performed. Alternatively, in other embodiments, the computing unit 401 may be configured to perform the model training method by any other suitable means (e.g., by means of firmware).
Various implementations of the systems and techniques described here above may be implemented in digital electronic circuitry, integrated circuitry, Field Programmable Gate Arrays (FPGAs), Application Specific Integrated Circuits (ASICs), Application Specific Standard Products (ASSPs), system on a chip (SOCs), Complex Programmable Logic Devices (CPLDs), computer hardware, firmware, software, and/or combinations thereof. These various embodiments may include: implemented in one or more computer programs that are executable and/or interpretable on a programmable system including at least one programmable processor, which may be special or general purpose, receiving data and instructions from, and transmitting data and instructions to, a storage system, at least one input device, and at least one output device.
Program code for implementing the methods of the present disclosure may be written in any combination of one or more programming languages. These program codes may be provided to a processor or controller of a general purpose computer, special purpose computer, or other programmable data processing apparatus, such that the program codes, when executed by the processor or controller, cause the functions/operations specified in the flowchart and/or block diagram to be performed. The program code may execute entirely on the machine, partly on the machine, as a stand-alone software package partly on the machine and partly on a remote machine or entirely on the remote machine or server.
In the context of this disclosure, a machine-readable medium may be a tangible medium that can contain, or store a program for use by or in connection with an instruction execution system, apparatus, or device. The machine-readable medium may be a machine-readable signal medium or a machine-readable storage medium. A machine-readable medium may include, but is not limited to, an electronic, magnetic, optical, electromagnetic, infrared, or semiconductor system, apparatus, or device, or any suitable combination of the foregoing. More specific examples of a machine-readable storage medium would include an electrical connection based on one or more wires, a portable computer diskette, a hard disk, a Random Access Memory (RAM), a read-only memory (ROM), an erasable programmable read-only memory (EPROM or flash memory), an optical fiber, a portable compact disc read-only memory (CD-ROM), an optical storage device, a magnetic storage device, or any suitable combination of the foregoing.
To provide for interaction with a user, the systems and techniques described here can be implemented on a computer having: a display device (e.g., a CRT (cathode ray tube) or LCD (liquid crystal display) monitor) for displaying information to a user; and a keyboard and a pointing device (e.g., a mouse or a trackball) by which a user can provide input to the computer. Other kinds of devices may also be used to provide for interaction with a user; for example, feedback provided to the user can be any form of sensory feedback (e.g., visual feedback, auditory feedback, or tactile feedback); and input from the user may be received in any form, including acoustic, speech, or tactile input.
The systems and techniques described here can be implemented in a computing system that includes a back-end component (e.g., as a data server), or that includes a middleware component (e.g., an application server), or that includes a front-end component (e.g., a user computer having a graphical user interface or a web browser through which a user can interact with an implementation of the systems and techniques described here), or any combination of such back-end, middleware, or front-end components. The components of the system can be interconnected by any form or medium of digital data communication (e.g., a communication network). Examples of communication networks include: local Area Networks (LANs), Wide Area Networks (WANs), and the Internet.
The computer system may include clients and servers. A client and a server are generally remote from each other and typically interact through a communication network. The relationship of client and server arises by virtue of computer programs running on the respective computers and having a client-server relationship to each other. The server may be a cloud server, a server of a distributed system, or a server combined with a blockchain.
It should be understood that the various forms of flow shown above may be used, with steps reordered, added, or deleted. For example, the steps described in the present disclosure may be executed in parallel, sequentially, or in different orders, and no limitation is imposed herein as long as the desired results of the technical solutions disclosed in the present disclosure can be achieved.
The above detailed description should not be construed as limiting the scope of the disclosure. It should be understood by those skilled in the art that various modifications, combinations, sub-combinations and substitutions may be made in accordance with design requirements and other factors. Any modification, equivalent replacement, and improvement made within the spirit and principle of the present disclosure should be included in the scope of protection of the present disclosure.

Claims (15)

1. A model training method, comprising:
obtaining each sample for model training; wherein an unlabeled first sample in each sample has a pseudo-label, and the pseudo-label has a confidence level;
respectively inputting the obtained samples into a semi-supervised model to be trained to obtain an output result;
calculating model loss by using the difference between the output result and the label of each sample and the loss weight corresponding to the pseudo label; wherein the loss weight is set based on the confidence level and is positively correlated with the confidence level;
and adjusting model parameters of the semi-supervised model based on the model loss so as to obtain a target semi-supervised model.
2. The method of claim 1, wherein before calculating model loss using the difference between the output result and the label of each sample and the loss weight corresponding to the pseudo label, the method further comprises:
determining a target confidence interval in which the confidence of the pseudo label of the first sample is located; the target confidence interval is one of a plurality of confidence intervals, and different confidence intervals correspond to different loss weights;
and determining the loss weight corresponding to the target confidence interval as the loss weight corresponding to the pseudo label.
3. The method of claim 2, wherein the pseudo-tag has a confidence level not below a first threshold;
the confidence intervals comprise a first confidence interval and a second confidence interval, and the loss weight corresponding to the first confidence interval is higher than the loss weight corresponding to the second confidence interval;
the confidence degrees in the first confidence degree intervals are all higher than the second threshold, the confidence degree in the second confidence degree intervals is not lower than the first threshold and not higher than a second threshold, and the second threshold is higher than the first threshold.
4. The method of claim 1, wherein the first sample is obtained in a manner comprising:
determining at least one initial sample belonging to the unlabeled class;
generating a pseudo label for the initial sample by using a pre-trained pseudo label generation model; the pseudo label generation model is a model trained using labeled samples and their corresponding labels;
from the initial samples with the pseudo-label, a first sample is selected.
5. The method of claim 4, wherein selecting the first sample from the initial samples with the pseudo-tag comprises:
and selecting the initial samples with the pseudo labels, the confidence degrees of which are not lower than a first threshold value, from the initial samples with the pseudo labels to obtain a first sample.
6. The method according to any one of claims 1-5, wherein said calculating a model loss using a difference between the output result and a label of each sample and a loss weight corresponding to the pseudo label comprises:
calculating the model loss according to a preset loss value calculation function, based on the output result, the label of each sample, and the loss weight corresponding to the pseudo label.
7. A model training apparatus comprising:
the acquisition module is used for acquiring each sample for model training; wherein an unlabeled first sample in each sample has a pseudo-label, and the pseudo-label has a confidence level;
the input module is used for respectively inputting the obtained samples into a semi-supervised model to be trained to obtain an output result;
the calculation module is used for calculating model loss by using the difference between the output result and the label of each sample and the loss weight corresponding to the pseudo label; wherein the loss weight is set based on the confidence level and is positively correlated with the confidence level;
and the judging module is used for adjusting the model parameters of the semi-supervised model based on the model loss so as to obtain the target semi-supervised model.
8. The apparatus of claim 7, further comprising:
a first determining module, configured to determine a target confidence interval in which a confidence level of a pseudo tag of the first sample is located before the calculating module calculates a model loss using a difference between an output result of each sample and the tag of each sample and a loss weight corresponding to the pseudo tag; the target confidence interval is one of a plurality of confidence intervals, and different confidence intervals correspond to different loss weights;
and the second determining module is used for determining the loss weight corresponding to the target confidence degree interval as the loss weight corresponding to the pseudo tag.
9. The apparatus of claim 8, wherein the pseudo tag has a confidence level not below a first threshold;
the confidence intervals comprise a first confidence interval and a second confidence interval, and the loss weight corresponding to the first confidence interval is higher than the loss weight corresponding to the second confidence interval;
the confidence degrees in the first confidence degree intervals are all higher than the second threshold, the confidence degree in the second confidence degree intervals is not lower than the first threshold and not higher than a second threshold, and the second threshold is higher than the first threshold.
10. The apparatus of claim 7, wherein the first sample is obtained in a manner comprising:
determining at least one initial sample belonging to the unlabeled class;
generating a pseudo label for the initial sample by using a pre-trained pseudo label generation model; the pseudo label generation model is a model trained using labeled samples and their corresponding labels;
from the initial samples with the pseudo-label, a first sample is selected.
11. The apparatus of claim 10, wherein selecting a first sample from among the initial samples with the pseudo-tag comprises:
and selecting the initial samples with the pseudo labels, the confidence degrees of which are not lower than a first threshold value, from the initial samples with the pseudo labels to obtain a first sample.
12. The apparatus according to any one of claims 7-11, wherein the computing module is specifically configured to:
and calculating a model loss according to a predetermined loss value calculation function based on the output result, the label of each sample and the loss weight corresponding to the pseudo label.
13. An electronic device, comprising:
at least one processor; and
a memory communicatively coupled to the at least one processor; wherein
the memory stores instructions executable by the at least one processor to enable the at least one processor to perform the method of any one of claims 1-6.
14. A non-transitory computer readable storage medium having stored thereon computer instructions for causing the computer to perform the method of any one of claims 1-6.
15. A computer program product comprising a computer program which, when executed by a processor, implements the method according to any one of claims 1-6.
Priority application: CN202210441633.0A, filed 2022-04-25 (priority date 2022-04-25), Model training method and device, electronic equipment and storage medium, status Pending

Publication: CN114881129A (en), published 2022-08-09

Family ID: 82671154

Country status: CN, CN114881129A (en)

Cited By (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115471717A (en) * 2022-09-20 2022-12-13 北京百度网讯科技有限公司 Model semi-supervised training and classification method and device, equipment, medium and product
CN115471805A (en) * 2022-09-30 2022-12-13 阿波罗智能技术(北京)有限公司 Point cloud processing and deep learning model training method and device and automatic driving vehicle
CN115471805B (en) * 2022-09-30 2023-09-05 阿波罗智能技术(北京)有限公司 Point cloud processing and deep learning model training method and device and automatic driving vehicle
CN116468112A (en) * 2023-04-06 2023-07-21 北京百度网讯科技有限公司 Training method and device of target detection model, electronic equipment and storage medium
CN116468112B (en) * 2023-04-06 2024-03-12 北京百度网讯科技有限公司 Training method and device of target detection model, electronic equipment and storage medium


Legal Events

Code Description
PB01 Publication
SE01 Entry into force of request for substantive examination