CN112348203A - Model training method and device, terminal device and storage medium - Google Patents
- Publication number
- CN112348203A (application number CN202011225367.5A)
- Authority
- CN
- China
- Prior art keywords
- sample
- uncertainty
- sample data
- labeled
- model
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Classifications
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N20/00—Machine learning
- G06N20/10—Machine learning using kernel methods, e.g. support vector machines [SVM]
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Software Systems (AREA)
- Data Mining & Analysis (AREA)
- Evolutionary Computation (AREA)
- Medical Informatics (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Physics & Mathematics (AREA)
- Computing Systems (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Mathematical Physics (AREA)
- Artificial Intelligence (AREA)
- Information Retrieval, Db Structures And Fs Structures Therefor (AREA)
Abstract
The application is applicable to the technical field of artificial intelligence and provides a model training method, an apparatus, a terminal device and a storage medium. The method includes: performing classification prediction on sample data with a target model to obtain a classification prediction result; calculating the uncertainty of each piece of sample data according to the classification prediction result; and, if the uncertainty is greater than a preset value, setting the corresponding sample data as a sample to be labeled, labeling the sample to be labeled, and training the target model with the labeled samples until the target model converges. By setting the sample data whose uncertainty is greater than the preset value as samples to be labeled and labeling them according to the obtained labeling information, the effective samples in the sample data are labeled accurately and redundant samples are left unlabeled, which reduces the time and labor cost of sample labeling and improves model training efficiency. The application also relates to blockchain technology.
Description
Technical Field
The present application relates to the field of artificial intelligence technologies, and in particular, to a model training method and apparatus, a terminal device, and a storage medium.
Background
Training a model is an indispensable step in the field of artificial intelligence: the model adjusts its parameters according to sample data and the corresponding labels, thereby learning the intrinsic relationship between the sample data and the labels. For a classification task, the model learns the class boundaries of the sample data.
In the existing model training process, sample data is selected at random, also called random sampling. Some sample data contains more information and is more helpful for determining the class boundary during training, while other sample data contains little or redundant information and contributes less to determining the class boundary. Random sampling therefore inevitably selects redundant samples for labeling, which increases the time and labor cost of labeling the sample data and lowers model training efficiency.
Disclosure of Invention
In view of this, embodiments of the present application provide a model training method and apparatus, a terminal device, and a storage medium, so as to solve the prior-art problem that training a model on randomly sampled data leads to low training efficiency.
A first aspect of an embodiment of the present application provides a model training method, including:
inputting a sample pool into a target model, and performing classification prediction on sample data in the sample pool according to the target model to obtain a classification prediction result, wherein the classification prediction result comprises classification probabilities between the sample data and different preset classifications;
respectively calculating the uncertainty of different sample data according to the classification prediction result, wherein the uncertainty is used for representing the confidence degree between the corresponding sample data and the preset classification;
if the uncertainty is larger than a preset value, setting the sample data corresponding to the uncertainty as a sample to be labeled, and respectively obtaining labeling information of the sample to be labeled;
and labeling the sample to be labeled according to the labeling information, and performing model training on the target model according to the labeled sample to be labeled until the target model converges.
Further, the classifying and predicting the sample data in the sample pool according to the target model to obtain a classifying and predicting result includes:
sample preprocessing is carried out on the sample data, and the sample data after sample preprocessing is input into the target model;
controlling a convolution layer in the target model to perform feature extraction on the sample data to obtain sample features, and controlling a full-connection layer in the target model to perform feature combination on the sample features to obtain combination features;
and calculating the similarity between the combined features and preset features corresponding to different preset classifications to obtain the classification prediction result.
Further, the sample preprocessing on the sample data includes:
acquiring sample characteristics of the sample data, and calculating a characteristic mean value and a characteristic standard deviation of the sample characteristics;
normalizing the sample features according to the feature mean and the feature standard deviation;
the calculation formula adopted for carrying out standardization processing on the sample characteristics according to the characteristic mean value and the characteristic standard deviation is as follows:
z_ij = (x_ij - x̄_i) / s_i
where z_ij is the sample feature after normalization, x_ij is the sample feature before normalization, x̄_i is the feature mean, and s_i is the feature standard deviation.
Further, the calculating the similarity between the combined feature and the preset features corresponding to different preset classifications to obtain the classification prediction result includes:
respectively obtaining the feature vectors of the combined features and the preset features to obtain combined vectors and preset vectors;
and respectively calculating the distances between the combined vector and different preset vectors according to an Euclidean distance formula to obtain the classification prediction result.
Further, the calculation formula for calculating the uncertainty of the different sample data according to the classification prediction result is as follows:
uncertainty = 1 - max(softmax(M(Sample_m)))
where uncertainty is the uncertainty corresponding to the m-th sample data, M(Sample_m) is the classification prediction of the target model M for the m-th sample data, and max(softmax(M(Sample_m))) is the maximum classification probability between the m-th sample data and the different preset classifications.
Further, the performing model training on the target model according to the sample data after the sample labeling includes:
generating sample training data according to the marked sample to be marked, and inputting the sample training data into the target model for model training;
if the model loss value of the target model after model training is larger than a loss threshold value, calculating the uncertainty of the to-be-labeled sample after labeling in the sample training data respectively;
if the uncertainty of the marked sample to be marked is smaller than or equal to the preset value, deleting the marked sample to be marked in the sample training data;
and inputting the sample training data of the to-be-labeled sample after the label is deleted into the target model for model training, and stopping the model training of the target model until the model loss value of the target model after the model training is smaller than or equal to the loss threshold value.
Further, after the classification prediction is performed on the sample data in the sample pool according to the target model and a classification prediction result is obtained, the method further includes:
and if the uncertainty is less than or equal to the preset value, deleting the sample data corresponding to the uncertainty in the sample pool.
A second aspect of an embodiment of the present application provides a model training apparatus, including:
the classification prediction unit is used for inputting a sample pool into a target model and performing classification prediction on sample data in the sample pool according to the target model to obtain a classification prediction result, wherein the classification prediction result comprises classification probabilities between the sample data and different preset classifications;
the uncertainty calculation unit is used for calculating the uncertainty of different sample data according to the classification prediction result, and the uncertainty is used for representing the confidence degree between the corresponding sample data and the preset classification;
a labeling information obtaining unit, configured to set the sample data corresponding to the uncertainty as a sample to be labeled and obtain labeling information of the sample to be labeled respectively, if the uncertainty is greater than a preset value;
and the model training unit is used for labeling the sample to be labeled according to the labeling information and performing model training on the target model according to the labeled sample to be labeled until the target model converges.
A third aspect of the embodiments of the present application provides a terminal device, which includes a memory, a processor, and a computer program stored in the memory and executable on the processor, where the processor implements the steps of the model training method provided by the first aspect when executing the computer program.
A fourth aspect of the embodiments of the present application provides a storage medium, which stores a computer program that, when executed by a processor, implements the steps of the model training method provided by the first aspect.
The model training method, the model training device, the terminal equipment and the storage medium have the following advantages that:
the model training method provided by the embodiment of the application can effectively represent the confidence degree between corresponding sample data and the preset classification by respectively calculating the uncertainty of different sample data according to the classification prediction result based on the uncertainty, and because the higher the confidence degree between the sample data and the preset classification is, the more accurate the classification between the target model and the sample data and the preset classification is, the less the help of the sample data for determining the class boundary in the model training process of the target model is, therefore, the redundant samples in the sample data can be effectively identified based on the uncertainty, and by setting the sample data with the uncertainty larger than the preset value as the samples to be labeled and labeling the samples to be labeled according to the obtained labeling information, the effective samples in the sample data can be accurately labeled, and the labeling of the redundant samples can be prevented, and further, the time cost and the labor cost for carrying out sample labeling on the sample data are reduced, and the model training efficiency is improved.
Drawings
In order to more clearly illustrate the technical solutions in the embodiments of the present application, the drawings required for describing the embodiments or the prior art are briefly introduced below. The drawings in the following description are only some embodiments of the present application; other drawings can be obtained by those skilled in the art from these drawings without inventive effort.
FIG. 1 is a flowchart of an implementation of a model training method provided in an embodiment of the present application;
FIG. 2 is a flowchart illustrating an implementation of a model training method according to another embodiment of the present disclosure;
fig. 3 is a block diagram of a model training apparatus according to an embodiment of the present disclosure;
fig. 4 is a block diagram of a terminal device according to an embodiment of the present disclosure.
Detailed Description
In order to make the objects, technical solutions and advantages of the present application more apparent, the present application is described in further detail below with reference to the accompanying drawings and embodiments. It should be understood that the specific embodiments described herein are merely illustrative of the present application and are not intended to limit the present application.
The model training method according to the embodiment of the present application may be executed by a control device or a terminal (hereinafter referred to as a "mobile terminal").
Referring to fig. 1, fig. 1 shows a flowchart of an implementation of a model training method provided in an embodiment of the present application, including:
and step S10, inputting the sample pool into a target model, and carrying out classification prediction on the sample data in the sample pool according to the target model to obtain a classification prediction result.
The classification prediction result comprises the classification probabilities between the sample data and the different preset classifications. The number and types of the preset classifications configured in the target model, as well as the number and content of the sample data in the sample pool, can be set as required.
For example, suppose the preset classifications include class A, class B and class C, and the sample pool contains sample data D and sample data E. Performing classification prediction on sample data D and sample data E against classes A, B and C with the target model yields a classification prediction result containing classification probabilities D1, D2, D3, E1, E2 and E3, where D1, D2 and D3 represent the probabilities that sample data D belongs to class A, class B and class C respectively, and E1, E2 and E3 represent the probabilities that sample data E belongs to class A, class B and class C respectively.
Optionally, in this step, a model training instruction of a user is received, a target model and a sample pool are respectively queried according to a model identifier and a sample identifier carried in the model training instruction, and the queried sample pool is input into the queried target model for classification prediction, so as to obtain a classification prediction result.
Further, in this step, the calculating a similarity between the combined feature and preset features corresponding to different preset classifications to obtain the classification prediction result includes: respectively obtaining the feature vectors of the combined features and the preset features to obtain combined vectors and preset vectors; and respectively calculating the distances between the combined vector and different preset vectors according to an Euclidean distance formula to obtain the classification prediction result.
And step S20, respectively calculating the uncertainty of different sample data according to the classification prediction result.
The uncertainty represents the confidence between the corresponding sample data and the preset classifications. The higher this confidence, the more accurately the target model already classifies the sample data with respect to the preset classifications, and the less the sample data helps the target model determine class boundaries during training. Conversely, the lower the confidence, the less accurate the classification of the sample data, and the more the sample data helps the target model determine class boundaries during training.
Optionally, in this step, the calculation formula for calculating the uncertainty of the different sample data according to the classification prediction result is as follows:
uncertainty = 1 - max(softmax(M(Sample_m)))
where uncertainty is the uncertainty corresponding to the m-th sample data, M(Sample_m) is the classification prediction of the target model M for the m-th sample data, and max(softmax(M(Sample_m))) is the maximum classification probability between the m-th sample data and the different preset classifications.
For example, in this step, if the maximum classification probability between sample data D and the preset classifications is 0.8, the uncertainty of sample data D is 0.2; if the maximum classification probability between sample data E and the preset classifications is 0.7, the uncertainty of sample data E is 0.3. In other words, the target model classifies sample data D more accurately than sample data E, the uncertainty of sample data D is lower than that of sample data E, and sample data E is more helpful than sample data D for determining class boundaries during training of the target model.
Further, in this embodiment, information entropy may also be used as the uncertainty of the sample data; that is, the entropy of each piece of sample data is calculated from the classification prediction result to obtain its uncertainty.
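For illustration only (this sketch is not part of the original disclosure), the two uncertainty measures above can be computed from per-sample class scores as follows; the score values and the softmax helper are assumptions, not requirements of the application:

```python
import numpy as np

def softmax(logits):
    """Convert raw class scores into classification probabilities."""
    e = np.exp(logits - logits.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def least_confidence(logits):
    """uncertainty = 1 - max(softmax(M(Sample_m))) for each sample."""
    return 1.0 - softmax(logits).max(axis=-1)

def entropy(logits):
    """Alternative measure: information entropy of the predicted class distribution."""
    p = softmax(logits)
    return -(p * np.log(p + 1e-12)).sum(axis=-1)

# Hypothetical class scores for sample data D and E over classes A, B, C
scores = np.array([[2.0, 0.5, 0.1],    # sample D: model fairly confident
                   [1.0, 0.9, 0.8]])   # sample E: model unsure
print(least_confidence(scores))        # D has lower uncertainty than E
print(entropy(scores))                 # entropy agrees: E is more uncertain
```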
Step S30, if the uncertainty is larger than a preset value, setting the sample data corresponding to the uncertainty as a sample to be labeled.
The preset value can be set as required. If the uncertainty is greater than the preset value, the corresponding sample data is judged to be highly uncertain and therefore more helpful for determining class boundaries during training of the target model. Setting this sample data as a sample to be labeled effectively facilitates the generation of subsequent sample training data and improves the accuracy of that training data.
And step S40, respectively obtaining the labeling information of the samples to be labeled, and labeling the samples to be labeled according to the labeling information.
The samples to be labeled are matched against a pre-stored labeling database to obtain the labeling information corresponding to each sample to be labeled. The labeling database stores the correspondence between different sample data and their labeling information, so matching the samples to be labeled against the labeling database effectively improves labeling accuracy.
Optionally, in this step, if the sample to be labeled is an image sample, the sample to be labeled is matched against the preset images in the labeling database, and the labeling information corresponding to the matched preset image is queried, so as to obtain the labeling information corresponding to each sample to be labeled;
and if the sample to be labeled is an audio sample, the audio features of the sample to be labeled are obtained and matched against the preset audio features in the labeling database, and the labeling information corresponding to the matched preset audio features is queried, so as to obtain the labeling information corresponding to each sample to be labeled.
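For illustration only, a minimal sketch of matching a sample's feature vector against a labeling database by Euclidean nearest neighbour; the feature vectors, labels and distance threshold below are hypothetical, and the application itself does not prescribe this particular matching scheme:

```python
import numpy as np

# Hypothetical labeling database: label -> stored (preset) feature vector
label_db = {
    "cat":  np.array([0.9, 0.1, 0.0]),
    "dog":  np.array([0.1, 0.8, 0.1]),
    "bird": np.array([0.0, 0.2, 0.9]),
}

def match_label(sample_feature, db, max_distance=0.5):
    """Return the label of the closest stored feature, or None if nothing is close enough."""
    best_label, best_dist = None, float("inf")
    for label, feature in db.items():
        dist = float(np.linalg.norm(sample_feature - feature))   # Euclidean distance
        if dist < best_dist:
            best_label, best_dist = label, dist
    return best_label if best_dist <= max_distance else None

print(match_label(np.array([0.85, 0.15, 0.05]), label_db))       # -> "cat"
```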
Step S50, performing model training on the target model according to the labeled sample to be labeled until the target model converges;
the uncertainty corresponding to the sample data is higher, the higher the uncertainty of the sample data is, the higher the help of the sample data on determining the class boundary in the model training process of the target model is, therefore, in the step, the sample marking can be performed on the sample data with the uncertainty higher than the preset value, and the model training can be performed on the target model by taking the sample data with the sample marking as the model training set, so that the training efficiency of the target model is effectively improved.
Specifically, in the step, sample training data is generated according to a labeled sample to be labeled, and the sample training data is input into a target model for model training;
if the model loss value of the target model after model training is larger than the loss threshold value, respectively calculating the uncertainty of the to-be-labeled sample labeled in the sample training data;
if the uncertainty of the marked sample to be marked is less than or equal to a preset value, deleting the marked sample to be marked in the sample training data;
inputting the sample training data of the to-be-labeled sample after the label is deleted into the target model for model training, and stopping the model training of the target model until the model loss value of the target model after the model training is smaller than or equal to the loss threshold value;
Specifically, the loss of the target model after training is calculated with a loss function to obtain the model loss value, and convergence is judged by comparing the model loss value with the loss threshold. If the model loss value of the trained target model is greater than the loss threshold, the target model has not converged. The uncertainty of each labeled sample in the sample training data is then recalculated to judge whether that sample has become a redundant sample for the trained target model: when the uncertainty of a labeled sample is less than or equal to the preset value, the sample is redundant and is deleted from the sample training data. This effectively reduces the redundant samples in the sample training data and accelerates the subsequent retraining of the target model.
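For illustration only, a minimal sketch of the training loop described above, using scikit-learn's SGDClassifier (version 1.1 or later, where loss="log_loss" is available) purely as a stand-in for the target model; the classifier choice, thresholds and interface are assumptions, while the convergence check and redundant-sample pruning follow the step described here:

```python
import numpy as np
from sklearn.linear_model import SGDClassifier
from sklearn.metrics import log_loss

def train_until_converged(X, y, loss_threshold=0.3, preset_value=0.2, max_rounds=50):
    """Train on labeled samples, pruning samples whose uncertainty drops to or
    below preset_value after each round, until the loss reaches loss_threshold."""
    model = SGDClassifier(loss="log_loss", random_state=0)
    classes = np.unique(y)
    keep = np.ones(len(X), dtype=bool)          # labeled samples still in the training data
    for _ in range(max_rounds):
        if not keep.any():                      # nothing left to train on
            break
        model.partial_fit(X[keep], y[keep], classes=classes)   # one round of model training
        probs = model.predict_proba(X[keep])
        if log_loss(y[keep], probs, labels=classes) <= loss_threshold:
            break                               # model loss value <= loss threshold: stop training
        uncertainty = 1.0 - probs.max(axis=1)   # least-confidence uncertainty of each kept sample
        kept_idx = np.where(keep)[0]
        keep[kept_idx[uncertainty <= preset_value]] = False     # delete redundant samples
    return model
```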
Further, for step S10, after the classifying and predicting the sample data in the sample pool according to the target model and obtaining a classification prediction result, the method further includes:
If the uncertainty is less than or equal to the preset value, the sample data corresponding to that uncertainty is deleted from the sample pool. Such sample data contributes little to determining class boundaries during training of the target model and is therefore a redundant sample; deleting it from the sample pool effectively reduces the redundant samples in the pool.
In this step, whether a piece of sample data is a redundant sample is judged by comparing its uncertainty with the preset value. If the uncertainty is greater than the preset value, the sample data is set as a sample to be labeled, which effectively facilitates the generation of subsequent sample training data and improves its accuracy; matching the samples to be labeled against the labeling database further improves labeling accuracy.
In this embodiment, the uncertainty of each piece of sample data is calculated from the classification prediction result, so that the uncertainty effectively represents the confidence between that sample data and the preset classifications. Since a higher confidence means the target model already classifies the sample accurately, and such a sample helps little in determining class boundaries during training, redundant samples in the sample data can be identified effectively from the uncertainty. By setting the sample data whose uncertainty is greater than the preset value as samples to be labeled and labeling them according to the obtained labeling information, the effective samples are labeled accurately, redundant samples are left unlabeled, the time and labor cost of sample labeling is reduced, and model training efficiency is improved.
Referring to fig. 2, fig. 2 is a flowchart illustrating an implementation of a model training method according to another embodiment of the present application. Compared with the embodiment corresponding to fig. 1, the model training method provided in this embodiment further details step S10, and includes:
step S11, sample preprocessing is carried out on the sample data, and the sample data after sample preprocessing is input into the target model.
The accuracy of the sample data is improved by sample preprocessing of the sample data.
Specifically, in this step, the sample preprocessing on the sample data includes:
acquiring sample characteristics of the sample data, and calculating a characteristic mean value and a characteristic standard deviation of the sample characteristics;
normalizing the sample features according to the feature mean and the feature standard deviation;
the characteristic mean value and the characteristic standard deviation of the sample characteristics are calculated, and the sample characteristics are subjected to standardization processing according to the characteristic mean value and the characteristic standard deviation, so that the data characteristics in the sample data are mapped to a specified interval, the unit limitation of the data is removed, the data are converted into dimensionless pure values, and the subsequent target model is effectively convenient for extracting the characteristics of the sample data.
Optionally, in this embodiment, the calculation formula adopted for normalizing the sample feature according to the feature mean and the feature standard deviation is as follows:
z_ij = (x_ij - x̄_i) / s_i
where z_ij is the sample feature after normalization, x_ij is the sample feature before normalization, x̄_i is the feature mean, and s_i is the feature standard deviation.
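For illustration only (not part of the original disclosure), a minimal sketch of the per-feature standardization z_ij = (x_ij - x̄_i) / s_i; the feature values below are hypothetical:

```python
import numpy as np

def standardize(X):
    """z_ij = (x_ij - mean_i) / std_i, computed per feature column (rows are samples)."""
    mean = X.mean(axis=0)                 # feature mean of each column
    std = X.std(axis=0)                   # feature standard deviation of each column
    return (X - mean) / (std + 1e-12)     # small epsilon avoids division by zero

X = np.array([[170.0, 65.0],              # e.g. height (cm), weight (kg)
              [180.0, 80.0],
              [160.0, 50.0]])
print(standardize(X))                     # dimensionless, zero-mean features
```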
Further, in this embodiment, the sample data may also be preprocessed using data normalization, which likewise maps the data features in the sample data to a specified interval and converts them into dimensionless values.
Furthermore, in this embodiment, the data type of the sample data may be obtained, a corresponding preprocessing policy may be queried based on the data type, and the sample data may be preprocessed based on the preprocessing policy.
And step S12, controlling the convolution layer in the target model to perform feature extraction on the sample data to obtain sample features, and controlling the full-link layer in the target model to perform feature combination on the sample features to obtain combined features.
The sample features are obtained by performing feature extraction on the sample data with the convolution kernels of the convolution layers in the target model, and the combined features are obtained by combining the sample features output by the different convolution layers with the fully connected layer of the target model.
And step S13, calculating the similarity between the combined features and preset features corresponding to different preset classifications to obtain the classification prediction result.
The similarity between the combined features and the preset features corresponding to the different preset classifications can be calculated with a Euclidean distance formula, so as to obtain the classification probabilities between the sample data corresponding to the combined features and the different preset classifications.
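For illustration only, one way (not prescribed by the application) to turn the Euclidean distances between the combined feature vector and the per-class preset feature vectors into classification probabilities is a softmax over the negated distances:

```python
import numpy as np

def classify_by_distance(combined_vec, preset_vecs):
    """preset_vecs holds one preset feature vector per preset classification.
    Returns one probability per class: a closer preset vector gives a higher probability."""
    dists = np.linalg.norm(preset_vecs - combined_vec, axis=1)   # Euclidean distances
    scores = -dists                                              # smaller distance = larger score
    e = np.exp(scores - scores.max())
    return e / e.sum()                                           # softmax over classes

preset = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])          # classes A, B, C (hypothetical)
print(classify_by_distance(np.array([0.9, 0.2]), preset))        # highest probability for class A
```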
In this embodiment, sample preprocessing effectively removes the unit restrictions of the sample data and converts it into dimensionless values, which facilitates feature extraction by the target model; and calculating the similarity between the combined features and the preset features corresponding to the different preset classifications effectively yields the classification probabilities between the corresponding sample data and those classifications.
In all embodiments of the present application, the target model is obtained by performing model training on the sample data after sample labeling. Uploading the converged target model to a blockchain ensures its security and its fairness and transparency to the user. User equipment may download the converged target model from the blockchain to verify whether it has been tampered with. The blockchain referred to in this example is a novel application mode of computer technologies such as distributed data storage, point-to-point transmission, consensus mechanisms and encryption algorithms. A blockchain is essentially a decentralized database: a chain of data blocks linked by cryptographic methods, where each block contains the information of a batch of network transactions, used to verify the validity (anti-counterfeiting) of the information and to generate the next block. The blockchain may include a blockchain underlying platform, a platform product service layer, an application service layer, and the like.
Referring to fig. 3, fig. 3 is a block diagram illustrating a model training apparatus 100 according to an embodiment of the present disclosure. The model training apparatus 100 in this embodiment includes units for performing the steps in the embodiments corresponding to fig. 1 to fig. 2; please refer to the descriptions of those embodiments for details. For convenience of explanation, only the portions related to the present embodiment are shown. Referring to fig. 3, the model training apparatus 100 includes: a classification prediction unit 10, an uncertainty calculation unit 11, a labeling information acquisition unit 12, and a model training unit 13, wherein:
the classification prediction unit 10 is configured to input a sample pool into a target model, and perform classification prediction on sample data in the sample pool according to the target model to obtain a classification prediction result, where the classification prediction result includes classification probabilities between the sample data and different preset classifications.
Wherein the classification prediction unit 10 is further configured to: sample preprocessing is carried out on the sample data, and the sample data after sample preprocessing is input into the target model;
controlling a convolution layer in the target model to perform feature extraction on the sample data to obtain sample features, and controlling a full-connection layer in the target model to perform feature combination on the sample features to obtain combination features;
and calculating the similarity between the combined features and preset features corresponding to different preset classifications to obtain the classification prediction result.
Optionally, the classification predicting unit 10 is further configured to: acquiring sample characteristics of the sample data, and calculating a characteristic mean value and a characteristic standard deviation of the sample characteristics;
normalizing the sample features according to the feature mean and the feature standard deviation;
the calculation formula adopted for carrying out standardization processing on the sample characteristics according to the characteristic mean value and the characteristic standard deviation is as follows:
z_ij = (x_ij - x̄_i) / s_i
where z_ij is the sample feature after normalization, x_ij is the sample feature before normalization, x̄_i is the feature mean, and s_i is the feature standard deviation.
Optionally, when calculating the similarity between the combined features and the preset features corresponding to the different preset classifications to obtain the classification prediction result, the classification prediction unit 10 is further configured to:
respectively obtaining the feature vectors of the combined features and the preset features to obtain combined vectors and preset vectors;
and respectively calculating the distances between the combined vector and different preset vectors according to an Euclidean distance formula to obtain the classification prediction result.
And an uncertainty calculation unit 11, configured to calculate uncertainties of different sample data according to the classification prediction result, where the uncertainties are used to represent confidence degrees between corresponding sample data and the preset classification.
Optionally, in the uncertainty calculating unit 11, the calculation formulas for calculating the uncertainties of different sample data according to the classification prediction result are as follows:
uncertainty = 1 - max(softmax(M(Sample_m)))
where uncertainty is the uncertainty corresponding to the m-th sample data, M(Sample_m) is the classification prediction of the target model M for the m-th sample data, and max(softmax(M(Sample_m))) is the maximum classification probability between the m-th sample data and the different preset classifications.
Optionally, the uncertainty calculation unit 11 is further configured to: and if the uncertainty is less than or equal to the preset value, deleting the sample data corresponding to the uncertainty in the sample pool.
And the labeling information obtaining unit 12 is configured to set the sample data corresponding to the uncertainty as a sample to be labeled and obtain labeling information of the sample to be labeled, respectively, if the uncertainty is greater than a preset value.
And the model training unit 13 is configured to label the sample to be labeled according to the labeling information, and perform model training on the target model according to the labeled sample to be labeled until the target model converges.
Optionally, when performing model training on the target model according to the sample data after sample labeling, the model training unit 13 is further configured to:
generating sample training data according to the marked sample to be marked, and inputting the sample training data into the target model for model training;
if the model loss value of the target model after model training is larger than a loss threshold value, calculating the uncertainty of the to-be-labeled sample after labeling in the sample training data respectively;
if the uncertainty of the marked sample to be marked is smaller than or equal to the preset value, deleting the marked sample to be marked in the sample training data;
and inputting the sample training data of the to-be-labeled sample after the label is deleted into the target model for model training, and stopping the model training of the target model until the model loss value of the target model after the model training is smaller than or equal to the loss threshold value.
As can be seen from the above, calculating the uncertainty of each piece of sample data from the classification prediction result allows the uncertainty to effectively represent the confidence between that sample data and the preset classifications. Since a higher confidence means the target model already classifies the sample accurately, and such a sample helps little in determining class boundaries during training, redundant samples in the sample data can be identified effectively from the uncertainty. Labeling the sample data according to its uncertainty therefore labels the effective samples accurately and leaves redundant samples unlabeled, which reduces the time and labor cost of sample labeling and improves model training efficiency.
Fig. 4 is a block diagram of a terminal device 2 according to another embodiment of the present application. As shown in fig. 4, the terminal device 2 of this embodiment includes: a processor 20, a memory 21 and a computer program 22, such as a program of a model training method, stored in the memory 21 and executable on the processor 20. When executing the computer program 22, the processor 20 implements the steps of the above embodiments of the model training method, such as S10 to S50 shown in fig. 1 or S11 to S13 shown in fig. 2. Alternatively, when executing the computer program 22, the processor 20 implements the functions of the units in the embodiment corresponding to fig. 3, for example the functions of the units 10 to 13 shown in fig. 3; reference is made to the relevant description of that embodiment, which is not repeated here.
Illustratively, the computer program 22 may be divided into one or more units, which are stored in the memory 21 and executed by the processor 20 to accomplish the present application. The one or more units may be a series of computer program instruction segments capable of performing specific functions, which are used to describe the execution of the computer program 22 in the terminal device 2. For example, the computer program 22 may be divided into the classification prediction unit 10, the uncertainty calculation unit 11, the labeling information acquisition unit 12 and the model training unit 13, each of which functions as described above.
The terminal device may include, but is not limited to, a processor 20 and a memory 21. It will be appreciated by those skilled in the art that fig. 4 is merely an example of the terminal device 2 and does not constitute a limitation of the terminal device 2, which may include more or fewer components than those shown, may combine some components, or may use different components; for example, the terminal device may also include input/output devices, network access devices, buses, etc.
The processor 20 may be a Central Processing Unit (CPU), another general-purpose processor, a Digital Signal Processor (DSP), an Application Specific Integrated Circuit (ASIC), a Field-Programmable Gate Array (FPGA) or other programmable logic device, a discrete gate or transistor logic device, a discrete hardware component, or the like. A general-purpose processor may be a microprocessor, or the processor may be any conventional processor.
The memory 21 may be an internal storage unit of the terminal device 2, such as a hard disk or a memory of the terminal device 2. The memory 21 may also be an external storage device of the terminal device 2, such as a plug-in hard disk, a Smart Media Card (SMC), a Secure Digital (SD) Card, a Flash memory Card (Flash Card), and the like, which are provided on the terminal device 2. Further, the memory 21 may also include both an internal storage unit and an external storage device of the terminal device 2. The memory 21 is used for storing the computer program and other programs and data required by the terminal device. The memory 21 may also be used to temporarily store data that has been output or is to be output.
The above-mentioned embodiments are only used for illustrating the technical solutions of the present application, and not for limiting the same; although the present application has been described in detail with reference to the foregoing embodiments, it should be understood by those of ordinary skill in the art that: the technical solutions described in the foregoing embodiments may still be modified, or some technical features may be equivalently replaced; such modifications and substitutions do not substantially depart from the spirit and scope of the embodiments of the present application and are intended to be included within the scope of the present application.
Claims (10)
1. A method of model training, comprising:
inputting a sample pool into a target model, and performing classification prediction on sample data in the sample pool according to the target model to obtain a classification prediction result, wherein the classification prediction result comprises classification probabilities between the sample data and different preset classifications;
respectively calculating the uncertainty of different sample data according to the classification prediction result, wherein the uncertainty is used for representing the confidence degree between the corresponding sample data and the preset classification;
if the uncertainty is larger than a preset value, setting the sample data corresponding to the uncertainty as a sample to be labeled, and respectively obtaining labeling information of the sample to be labeled;
and labeling the sample to be labeled according to the labeling information, and performing model training on the target model according to the labeled sample to be labeled until the target model converges.
2. The model training method according to claim 1, wherein the performing classification prediction on the sample data in the sample pool according to the target model to obtain a classification prediction result comprises:
sample preprocessing is carried out on the sample data, and the sample data after sample preprocessing is input into the target model;
controlling a convolution layer in the target model to perform feature extraction on the sample data to obtain sample features, and controlling a full-connection layer in the target model to perform feature combination on the sample features to obtain combination features;
and calculating the similarity between the combined features and preset features corresponding to different preset classifications to obtain the classification prediction result.
3. The model training method of claim 2, wherein the sample preprocessing of the sample data comprises:
acquiring sample characteristics of the sample data, and calculating a characteristic mean value and a characteristic standard deviation of the sample characteristics;
normalizing the sample features according to the feature mean and the feature standard deviation;
the calculation formula adopted for carrying out standardization processing on the sample characteristics according to the characteristic mean value and the characteristic standard deviation is as follows:
z_ij = (x_ij - x̄_i) / s_i
where z_ij is the sample feature after normalization, x_ij is the sample feature before normalization, x̄_i is the feature mean, and s_i is the feature standard deviation.
4. The model training method according to claim 2, wherein the calculating the similarity between the combined features and the preset features corresponding to different preset classifications to obtain the classification prediction result comprises:
respectively obtaining the feature vectors of the combined features and the preset features to obtain combined vectors and preset vectors;
and respectively calculating the distances between the combined vector and different preset vectors according to an Euclidean distance formula to obtain the classification prediction result.
5. The model training method according to claim 1, wherein the calculation formula for calculating the uncertainty of each sample data according to the classification prediction result is:
uncertainty = 1 - max(softmax(M(Sample_m)))
where uncertainty is the uncertainty corresponding to the m-th sample data, M(Sample_m) is the classification prediction of the target model M for the m-th sample data, and max(softmax(M(Sample_m))) is the maximum classification probability between the m-th sample data and the different preset classifications.
6. The model training method according to claim 1, wherein the model training of the target model according to the sample data after sample labeling comprises:
generating sample training data according to the marked sample to be marked, and inputting the sample training data into the target model for model training;
if the model loss value of the target model after model training is larger than a loss threshold value, calculating the uncertainty of the to-be-labeled sample after labeling in the sample training data respectively;
if the uncertainty of the marked sample to be marked is smaller than or equal to the preset value, deleting the marked sample to be marked in the sample training data;
and inputting the sample training data of the to-be-labeled sample after the label is deleted into the target model for model training, and stopping the model training of the target model until the model loss value of the target model after the model training is smaller than or equal to the loss threshold value.
7. The method of claim 1, wherein after the performing classification prediction on the sample data in the sample pool according to the target model to obtain a classification prediction result, the method further comprises:
and if the uncertainty is less than or equal to the preset value, deleting the sample data corresponding to the uncertainty in the sample pool.
8. A model training apparatus, comprising:
the classification prediction unit is used for inputting a sample pool into a target model and performing classification prediction on sample data in the sample pool according to the target model to obtain a classification prediction result, wherein the classification prediction result comprises classification probabilities between the sample data and different preset classifications;
the uncertainty calculation unit is used for calculating the uncertainty of different sample data according to the classification prediction result, and the uncertainty is used for representing the confidence degree between the corresponding sample data and the preset classification;
a labeling information obtaining unit, configured to set the sample data corresponding to the uncertainty as a sample to be labeled and obtain labeling information of the sample to be labeled respectively, if the uncertainty is greater than a preset value;
and the model training unit is used for labeling the sample to be labeled according to the labeling information and performing model training on the target model according to the labeled sample to be labeled until the target model converges.
9. A terminal device comprising a memory, a processor and a computer program stored in the memory and executable on the processor, characterized in that the processor implements the steps of the method according to any of claims 1 to 7 when executing the computer program.
10. A storage medium storing a computer program, characterized in that the computer program realizes the steps of the method according to any one of claims 1 to 7 when executed by a processor.
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202011225367.5A CN112348203A (en) | 2020-11-05 | 2020-11-05 | Model training method and device, terminal device and storage medium |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202011225367.5A CN112348203A (en) | 2020-11-05 | 2020-11-05 | Model training method and device, terminal device and storage medium |
Publications (1)
Publication Number | Publication Date |
---|---|
CN112348203A true CN112348203A (en) | 2021-02-09 |
Family
ID=74429688
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202011225367.5A Pending CN112348203A (en) | 2020-11-05 | 2020-11-05 | Model training method and device, terminal device and storage medium |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN112348203A (en) |
Patent Citations (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
WO2020199591A1 (en) * | 2019-03-29 | 2020-10-08 | 平安科技(深圳)有限公司 | Text categorization model training method, apparatus, computer device, and storage medium |
CN110232678A (en) * | 2019-05-27 | 2019-09-13 | 腾讯科技(深圳)有限公司 | A kind of image uncertainty prediction technique, device, equipment and storage medium |
CN111178407A (en) * | 2019-12-19 | 2020-05-19 | 中国平安人寿保险股份有限公司 | Road condition data screening method and device, computer equipment and storage medium |
CN111210024A (en) * | 2020-01-14 | 2020-05-29 | 深圳供电局有限公司 | Model training method and device, computer equipment and storage medium |
Non-Patent Citations (1)
Title |
---|
Jiang Jianzhong (蒋建忠): "Empirical Research Methods in International Relations" (国际关系实证研究方法), vol. 2020, Shanghai Far East Publishers (上海远东出版社), 30 October 2020, pages 283-384 * |
Cited By (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN112906817A (en) * | 2021-03-16 | 2021-06-04 | 中科海拓(无锡)科技有限公司 | Intelligent image labeling method |
CN113469265A (en) * | 2021-07-14 | 2021-10-01 | 浙江大华技术股份有限公司 | Data category attribute determining method and device, storage medium and electronic device |
CN113645107A (en) * | 2021-07-27 | 2021-11-12 | 广州市威士丹利智能科技有限公司 | Gateway conflict resolution method and system based on smart home |
CN113645107B (en) * | 2021-07-27 | 2022-12-02 | 广州市威士丹利智能科技有限公司 | Gateway conflict resolution method and system based on smart home |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||