CN112784677A - Model training method and device, storage medium and computing equipment - Google Patents


Info

Publication number
CN112784677A
Authority
CN
China
Prior art keywords
model
category
reference model
probability
error
Legal status
Pending
Application number
CN202011415641.5A
Other languages
Chinese (zh)
Inventor
段魁
蔡涛
陈新泽
黄冠
都大龙
Current Assignee
Shanghai Xinyi Intelligent Technology Co ltd
Original Assignee
Shanghai Xinyi Intelligent Technology Co ltd
Application filed by Shanghai Xinyi Intelligent Technology Co ltd
Priority to CN202011415641.5A
Publication of CN112784677A
Legal status: Pending (current)

Classifications

    • G PHYSICS
    • G06 COMPUTING; CALCULATING OR COUNTING
    • G06V IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V40/00 Recognition of biometric, human-related or animal-related patterns in image or video data
    • G06V40/10 Human or animal bodies, e.g. vehicle occupants or pedestrians; Body parts, e.g. hands
    • G06V40/103 Static body considered as a whole, e.g. static pedestrian or occupant recognition
    • G PHYSICS
    • G06 COMPUTING; CALCULATING OR COUNTING
    • G06F ELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00 Pattern recognition
    • G06F18/20 Analysing
    • G06F18/21 Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214 Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • G PHYSICS
    • G06 COMPUTING; CALCULATING OR COUNTING
    • G06F ELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00 Pattern recognition
    • G06F18/20 Analysing
    • G06F18/24 Classification techniques
    • G PHYSICS
    • G06 COMPUTING; CALCULATING OR COUNTING
    • G06N COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00 Computing arrangements based on biological models
    • G06N3/02 Neural networks
    • G06N3/04 Architecture, e.g. interconnection topology
    • G06N3/045 Combinations of networks
    • G PHYSICS
    • G06 COMPUTING; CALCULATING OR COUNTING
    • G06N COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00 Computing arrangements based on biological models
    • G06N3/02 Neural networks
    • G06N3/08 Learning methods
    • G06N3/084 Backpropagation, e.g. using gradient descent

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • General Physics & Mathematics (AREA)
  • Evolutionary Computation (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • General Engineering & Computer Science (AREA)
  • Computing Systems (AREA)
  • Software Systems (AREA)
  • Molecular Biology (AREA)
  • Computational Linguistics (AREA)
  • Biophysics (AREA)
  • Biomedical Technology (AREA)
  • Mathematical Physics (AREA)
  • General Health & Medical Sciences (AREA)
  • Health & Medical Sciences (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Evolutionary Biology (AREA)
  • Human Computer Interaction (AREA)
  • Multimedia (AREA)
  • Image Analysis (AREA)

Abstract

A model training method and device, a storage medium and computing equipment are provided. The model training method comprises: inputting training data into a constructed reference model and a teacher model, the reference model having fewer network layers than the teacher model; obtaining a first output result of the reference model for the training data and a second output result of the teacher model for the training data; generating, for each category, a third classification probability (the probability of not belonging to that category) from the first classification probability of that category, and a fourth classification probability from the second classification probability of that category, thereby obtaining a first probability distribution and a second probability distribution for each category; calculating the KL divergence between the first and second probability distributions under each category, and calculating the error of the reference model itself; and performing back propagation in the reference model using the KL divergence and the error of the reference model itself so as to adjust the network parameters of the reference model. The technical scheme can improve both the accuracy and the real-time performance of model classification.

Description

Model training method and device, storage medium and computing equipment
Technical Field
The invention relates to the technical field of data processing, in particular to a model training method and device, a storage medium and computing equipment.
Background
Deep network models are commonly used for feature extraction and classification of data, in particular pedestrian attribute data. Pedestrian attributes are characteristics that every person carries, and a good attribute model can greatly broaden the range of application scenarios.
At present, mainstream pedestrian attribute models on the market generally acquire video through a camera, obtain pedestrian bounding boxes through a pedestrian detection module, and then obtain pedestrian attributes through an attribute recognition module.
However, current human body attribute models usually depend on the bounding boxes produced by a preceding human body detection model. Actual scenes are complex and the quality of the detection model is hard to guarantee, so attribute prediction is unsatisfactory when part of the body is missing or a body is falsely detected (a conclusion drawn from experiments on a large number of existing open-source interfaces). Second, pedestrian attribute models on the market usually sacrifice real-time performance to achieve high accuracy, and vice versa. Third, pedestrian attribute models generalize poorly in cross-domain scenes.
Disclosure of Invention
The technical problem solved by the invention is how to improve, through model training, both the accuracy and the real-time performance of model classification.
In order to solve the above technical problem, an embodiment of the present invention provides a model training method, which includes: inputting training data into a constructed reference model and a teacher model, where the reference model has fewer network layers than the teacher model; obtaining a first output result of the reference model for the training data and a second output result of the teacher model for the training data, the first output result including a first classification probability for each category and the second output result including a second classification probability for each category; generating, based on the first classification probability of each category, a third classification probability of not belonging to that category, and generating, based on the second classification probability of each category, a fourth classification probability of not belonging to that category, so as to obtain a first probability distribution and a second probability distribution for each category, where the first probability distribution includes the category with its first classification probability and the non-category with its third classification probability, and the second probability distribution includes the category with its second classification probability and the non-category with its fourth classification probability; calculating the KL divergence using the first probability distribution and the second probability distribution under each category, and calculating the error of the reference model itself; and performing back propagation in the reference model using the KL divergence and the error of the reference model itself, so as to adjust the network parameters of the reference model.
Optionally, performing back propagation in the reference model using the KL divergence and the error of the reference model itself includes: calculating, as a response error, the sum of the product of the KL divergence and a first weight and the product of the reference model's own error and a second weight; and performing back propagation in the reference model using the response error.
Optionally, calculating the error of the reference model itself includes: calculating the error of the reference model using the Focal loss.
Optionally, calculating the error of the reference model itself includes: obtaining the sample ratio of the training data for each category, the sample ratio being the ratio of the number of samples containing the category to the total number of valid samples for that category; calculating the original error of the reference model according to the first output result; and weighting the original error by the sample ratio to obtain the error of the reference model itself.
Optionally, before the training data are input into the constructed reference model and teacher model, the method further includes: obtaining original sample data, the original sample data being labeled pedestrian images that include key points; and, according to the key point coordinates of the original sample data, randomly erasing the upper-body or lower-body image of the pedestrian and changing the attribute values of the pedestrian image to obtain the training data.
Optionally, a pedestrian re-identification (ReID) model is used as the pre-training model of the human body attribute model, and the network parameters of the backbone of the reference model's network architecture are taken directly from the pedestrian re-identification model.
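The parameter reuse described above can be sketched as a name-based copy between two parameter dictionaries. This is a minimal illustration only: the dictionary format, the `backbone.` name prefix and the function `init_backbone_from_reid` are hypothetical and do not come from the original; in a framework such as PyTorch the same idea would operate on the models' `state_dict()` and `load_state_dict()`.

```python
def init_backbone_from_reid(reference_params, reid_params, prefix="backbone."):
    """Copy backbone parameters from a ReID checkpoint into the reference model.

    Both arguments map parameter names to values (hypothetical format); only
    names starting with the assumed 'backbone.' prefix and already present in
    the reference model are copied, so task-specific heads stay untouched.
    """
    updated = dict(reference_params)  # do not mutate the caller's dictionary
    copied = []
    for name, value in reid_params.items():
        # The ReID head and the attribute head differ, so only the backbone is shared.
        if name.startswith(prefix) and name in updated:
            updated[name] = value
            copied.append(name)
    return updated, copied


reference = {"backbone.conv1.weight": [0.0], "fc.weight": [0.0]}
reid = {"backbone.conv1.weight": [0.3], "reid_head.weight": [1.0]}
new_params, copied_names = init_backbone_from_reid(reference, reid)
```

Here only `backbone.conv1.weight` is taken over; the attribute classifier (`fc.weight`) keeps its own initialization.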
Optionally, before the training data are input into the constructed reference model and teacher model, the method further includes: obtaining original sample data that includes samples with various attributes; and inputting samples with a first attribute into a pre-trained generative adversarial network (GAN) to generate samples with a second attribute, where the first attribute and the second attribute belong to the same category and the number of samples with the second attribute is smaller than a preset threshold.
Optionally, the constructed reference model is initialized with the Kaiming method, and the weights of its fully connected layer are initialized from a normal distribution.
Optionally, the reference model is constructed based on ResNet18, and the teacher model is constructed based on ResNet101.
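The two initialization rules can be sketched in plain Python. He (Kaiming) normal initialization draws weights with standard deviation sqrt(2 / fan_in) for ReLU networks; the fully connected layer here uses a plain normal distribution whose standard deviation (0.01) is an assumed value. In PyTorch the corresponding calls would be `torch.nn.init.kaiming_normal_` and `torch.nn.init.normal_`; the helper names below are hypothetical.

```python
import math
import random


def kaiming_normal_std(fan_in, gain=math.sqrt(2.0)):
    # He/Kaiming normal initialization for ReLU layers: std = gain / sqrt(fan_in).
    return gain / math.sqrt(fan_in)


def init_conv_weights(out_channels, in_channels, kernel_size, rng):
    """Kaiming-normal weights for a conv layer, flattened to out_channels x fan_in."""
    fan_in = in_channels * kernel_size * kernel_size
    std = kaiming_normal_std(fan_in)
    return [[rng.gauss(0.0, std) for _ in range(fan_in)] for _ in range(out_channels)]


def init_fc_weights(out_features, in_features, rng, std=0.01):
    """Plain normal initialization for the fully connected layer (std is an assumption)."""
    return [[rng.gauss(0.0, std) for _ in range(in_features)] for _ in range(out_features)]


rng = random.Random(0)
conv = init_conv_weights(64, 64, 3, rng)  # e.g. a 3x3 conv with 64 input channels
fc = init_fc_weights(10, 512, rng)        # e.g. 512 pooled features -> 10 attributes
```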
In order to solve the above technical problem, an embodiment of the present invention further provides a model training apparatus, which includes: an input module, configured to input training data into a constructed reference model and a teacher model, the reference model having fewer network layers than the teacher model; an output result obtaining module, configured to obtain a first output result of the reference model for the training data and a second output result of the teacher model for the training data, where the first output result includes a first classification probability for each category and the second output result includes a second classification probability for each category; a probability generating module, configured to generate, based on the first classification probability of each category, a third classification probability of not belonging to that category, and to generate, based on the second classification probability of each category, a fourth classification probability of not belonging to that category, so as to obtain a first probability distribution and a second probability distribution for each category, where the first probability distribution includes the category with its first classification probability and the non-category with its third classification probability, and the second probability distribution includes the category with its second classification probability and the non-category with its fourth classification probability; a KL divergence calculation module, configured to calculate the KL divergence using the first and second probability distributions under each category and to calculate the error of the reference model itself; and a parameter adjusting module, configured to perform back propagation in the reference model using the KL divergence and the error of the reference model itself, so as to adjust the network parameters of the reference model.
An embodiment of the present invention further provides a storage medium storing a computer program which, when executed by a processor, performs the steps of the model training method.
An embodiment of the present invention further provides a computing device including a memory and a processor, the memory storing a computer program executable on the processor, where the processor performs the steps of the model training method when executing the computer program.
Compared with the prior art, the technical scheme of the embodiment of the invention has the following beneficial effects:
In the technical scheme of the invention, training data are input into a reference model and a teacher model that have different numbers of network layers, and the KL divergence calculated from the probability distributions of the two models' outputs is used for back propagation in the reference model, finally yielding a reference model with optimized network parameters. Because this model has few network layers, it runs fast and can meet real-time requirements; because it is adjusted using a teacher model with more network layers, its classification accuracy is ensured. In other words, the model trained by the technical scheme of the invention balances real-time performance and accuracy in data classification.
Further, the sample ratio of the training data for each category is obtained, the sample ratio being the ratio of the number of samples containing the category to the total number of valid samples for that category; the original error of the reference model is calculated according to the first output result; and the original error is weighted by the sample ratio to obtain the error of the reference model itself. Because the sample ratio is applied to the original error when calculating the error used for back propagation, categories with fewer samples are still trained effectively, which further improves the final classification accuracy of the optimized reference model across all data.
Further, original sample data, namely labeled pedestrian images, are obtained; the upper-body or lower-body image of the pedestrian in the original sample data is randomly erased, and the attribute values of the pedestrian image are randomly changed to obtain the training data. By augmenting the sample data online through random erasing, the technical scheme diversifies the sample types, improves the training effect, and improves the classification performance of the final trained model when part of the body is missing or a body is falsely detected.
Further, original sample data including samples with various attributes are obtained, and samples with a first attribute are input into a pre-trained generative adversarial network (GAN) to generate samples with a second attribute, where the first attribute and the second attribute belong to the same category and the number of samples with the second attribute is smaller than a preset threshold. When samples are scarce or missing, the technical scheme uses the GAN to supplement them, ensuring the completeness and diversity of the samples and thus the model training effect.
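Deciding which samples to feed the pre-trained GAN can be sketched as below. The GAN itself is out of scope; the sample format (a dict per sample), the function name `plan_gan_augmentation`, and the choice of topping each rare attribute value up to the threshold are all hypothetical. Each returned pair would then be passed to a generator, which is also hypothetical here, to synthesize a sample with the rare second attribute from a sample carrying a common first attribute of the same category.

```python
from collections import Counter


def plan_gan_augmentation(samples, category, values, threshold):
    """Return (source_index, target_value) requests for under-represented values.

    samples: list of dicts mapping category -> attribute value (assumed format).
    values: all attribute values of the category, so missing values count as 0.
    """
    counts = Counter(s[category] for s in samples if category in s)
    rare = [v for v in values if counts[v] < threshold]
    # Sources carry a common (first) attribute of the same category.
    sources = [i for i, s in enumerate(samples)
               if s.get(category) is not None and s[category] not in rare]
    requests = []
    for target in rare:
        needed = threshold - counts[target]  # top the rare value up to the threshold
        requests.extend((i, target) for i in sources[:needed])
    return requests


data = [{"hair": "short"}, {"hair": "short"}, {"hair": "short"}, {"hair": "long"}]
plan = plan_gan_augmentation(data, "hair", ["short", "long"], threshold=3)
```

With a threshold of 3, "long hair" has one sample, so two short-hair samples are selected as GAN inputs.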
Drawings
FIG. 1 is a flow chart of a model training method according to an embodiment of the present invention;
FIG. 2 is a flowchart of one embodiment of step S104 shown in FIG. 1;
FIG. 3 is a partial flow diagram of an embodiment of a model training method according to the present invention;
FIG. 4 is a diagram of a model network architecture according to an embodiment of the present invention;
FIG. 5 is a schematic structural diagram of a model training apparatus according to an embodiment of the present invention.
Detailed Description
As described in the background art, current human body attribute models often depend on the bounding boxes produced by a preceding human body detection model, but actual scenes are complex and the quality of the detection model is hard to guarantee, so attribute prediction is unsatisfactory when part of the body is missing or a body is falsely detected (a conclusion drawn from experiments on a large number of existing open-source interfaces). Second, pedestrian attribute models on the market sacrifice real-time performance to achieve high accuracy, and vice versa. Third, pedestrian attribute models generalize poorly in cross-domain scenes.
The technical scheme of the invention first provides a stronger reference model pipeline and then optimizes this reference model using knowledge distillation.
First, in designing the reference model, ResNet18 is used as the backbone of the network architecture, followed by an average pooling layer, a fully connected layer and an output layer. In the training phase, a strong reference model is obtained using techniques such as the Focal loss, the sample ratio, key-point-based online data augmentation, ReID-model-based pre-training and GANs. The reference model can solve problems that existing models on the market cannot, such as partially missing human bodies and pedestrian recognition in cross-domain scenes.
Second, for knowledge distillation, training data are input into a reference model with fewer network layers and a teacher model with more network layers, the KL divergence is calculated from the probability distributions of the two models' outputs and used for back propagation in the reference model, finally yielding a reference model with optimized network parameters. Because this model has few network layers, it runs fast and ensures real-time performance; because it is adjusted using the teacher model with more network layers, classification accuracy is ensured. That is, the model trained by the technical scheme of the invention balances real-time performance and accuracy in data classification.
To make the above objects, features and advantages of the present invention more apparent, embodiments are described in detail below with reference to the accompanying figures.
FIG. 1 is a flowchart of a model training method according to an embodiment of the present invention.
The technical scheme of the invention can be applied to a computing device, that is, the computing device can execute each step of the method. The computing device may be any suitable terminal, such as, but not limited to, a mobile phone, a computer or an Internet of Things device.
Specifically, the model training method may include the steps of:
step S101: inputting training data into a constructed reference model and a teacher model, where the reference model has fewer network layers than the teacher model;
step S102: obtaining a first output result of the reference model for the training data and a second output result of the teacher model for the training data, the first output result including a first classification probability for each category and the second output result including a second classification probability for each category;
step S103: generating, based on the first classification probability of each category, a third classification probability of not belonging to that category, and generating, based on the second classification probability of each category, a fourth classification probability of not belonging to that category, to obtain a first probability distribution and a second probability distribution for each category, where the first probability distribution includes the category with its first classification probability and the non-category with its third classification probability, and the second probability distribution includes the category with its second classification probability and the non-category with its fourth classification probability;
step S104: calculating the KL divergence using the first and second probability distributions under each category, and calculating the error of the reference model itself;
step S105: performing back propagation in the reference model using the KL divergence and the error of the reference model itself, so as to adjust the network parameters of the reference model.
It should be noted that the step numbers in this embodiment do not limit the execution order of the steps.
In this embodiment, the training data may be pre-labeled data, for example, pre-labeled pedestrian images.
In the implementation of step S101, the reference model and the teacher model may be constructed in advance, with the reference model having fewer network layers than the teacher model. Generally, the more network layers a model has, the higher its accuracy but the slower it runs. The embodiment of the invention aims to give a model with fewer network layers the classification accuracy of a model with more network layers.
In a specific example, the reference model is constructed based on the deep residual network (ResNet) ResNet18, and the teacher model is constructed based on ResNet101. ResNet18 has 18 network layers and ResNet101 has 101 network layers.
It should be noted that, in practical applications, other deep network models may also be used, for example models similar to AlexNet, MobileNet, ShuffleNet, HRNet, VGGNet or Darknet, which is not limited in the embodiments of the present invention.
The reference model and the teacher model each produce an output result for the training data. In the implementation of step S102, the first output result of the reference model for the training data and the second output result of the teacher model for the training data are obtained. The first output result includes a first classification probability for each category and the second output result includes a second classification probability for each category. When there are N categories, the first and second output results are both N-dimensional vectors, each entry being the classification probability of the corresponding category.
Specifically, the categories may differ between application scenarios. For example, for pedestrian attribute recognition, the categories may include gender (male, female), age (child, juvenile, young adult, middle-aged, elderly), hairstyle (long hair, short hair), jacket color (white, black) and the like, which are not enumerated further in the embodiments of the present invention.
In a specific example, the first output result may give a probability of 0.9 that the jacket color is white, and the second output result may give a probability of 0.99 that the jacket color is white.
Calculating the KL divergence requires probability distributions whose probabilities sum to 1, whereas each output result only gives the probability of the category itself. Therefore, in the implementation of step S103, the first and second output results are processed: for each category, a third classification probability of not belonging to the category is generated from the first classification probability, and a fourth classification probability of not belonging to the category is generated from the second classification probability.
Specifically, the sum of the first classification probability and the third classification probability is 1, and the sum of the second classification probability and the fourth classification probability is 1.
Continuing the above example, where the first output result gives a probability of 0.9 and the second output result a probability of 0.99 that the jacket color is white, the probability of a non-white jacket is 0.1 in the first probability distribution and 0.01 in the second probability distribution.
Further, in the implementation of step S104, the KL divergence (also called relative entropy) is calculated for each category; that is, this embodiment uses the KL divergence to supervise the probability distribution of each label, where the labels are the categories in the output results of the reference model and the teacher model. The KL divergence measures how far the first probability distribution deviates from the second probability distribution, and it is zero when the two distributions are identical.
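Step S103 can be sketched as turning each per-category probability into a two-point distribution [p, 1 - p]. This assumes, as the example suggests, that each entry of an output vector is an independent probability in [0, 1] (for instance a sigmoid output); the function name is illustrative.

```python
def binary_distributions(class_probs):
    """Turn N per-category probabilities into N two-point distributions (p, 1 - p)."""
    dists = []
    for p in class_probs:
        if not 0.0 <= p <= 1.0:
            raise ValueError(f"probability out of range: {p}")
        # Index 0: the category itself; index 1: "not this category".
        dists.append((p, 1.0 - p))
    return dists


student = binary_distributions([0.9])   # reference model: white jacket with 0.9
teacher = binary_distributions([0.99])  # teacher model: white jacket with 0.99
```

The non-white probabilities come out as 0.1 and 0.01, up to floating-point rounding.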
In addition, the error of the reference model itself can be calculated. The error of the reference model itself may refer to an error of an output value of the reference model from a corresponding desired value.
It should be noted that, as for a specific algorithm for calculating the KL divergence and the error of the reference model itself, reference may be made to the prior art, and the embodiment of the present invention is not limited thereto.
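As one possible choice consistent with the above, the per-category KL divergence between two-point distributions has a closed form and can be summed over categories. The direction KL(teacher || student), the summation over categories and the clipping constant `eps` are assumptions made for this sketch.

```python
import math


def binary_kl(p, q, eps=1e-7):
    """KL(P || Q) for two-point distributions P = (p, 1 - p) and Q = (q, 1 - q)."""
    p = min(max(p, eps), 1.0 - eps)  # clip to avoid log(0) and division by zero
    q = min(max(q, eps), 1.0 - eps)
    return p * math.log(p / q) + (1.0 - p) * math.log((1.0 - p) / (1.0 - q))


def distillation_kl(student_probs, teacher_probs):
    """Sum over categories of KL(teacher || student); the direction is an assumption."""
    if len(student_probs) != len(teacher_probs):
        raise ValueError("both models must output one probability per category")
    return sum(binary_kl(t, s) for s, t in zip(student_probs, teacher_probs))


kl_white_jacket = distillation_kl([0.9], [0.99])  # jacket example above
```

For the jacket example the divergence is roughly 0.0713, and it shrinks to zero as the reference model's probability approaches the teacher's.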
In the implementation of step S105, when performing back propagation in the reference model, a weighted sum of the KL divergence and the error of the reference model itself is used to adjust the network parameters of the reference model.
Those skilled in the art will appreciate that inputting the training data into the reference model to obtain the first output result is a forward propagation process, in which the input passes from the input layer through the hidden layers, is processed layer by layer, and reaches the output layer. If the output layer does not produce the expected output, the sum of squared errors between the output values and the expected values is taken as the objective function and the process switches to back propagation: the partial derivatives of the objective function with respect to each neuron weight are calculated layer by layer, forming the gradient of the objective function with respect to the weight vector, which serves as the basis for modifying the weights. The network learns through this weight modification process, and training ends when the error falls within a predetermined range.
In the embodiment of the invention, the reference model has few network layers, so it runs fast and can ensure real-time performance; because it is adjusted using the teacher model with more network layers, classification accuracy is ensured. That is, the reference model trained by the embodiment of the invention balances real-time performance and accuracy of data classification.
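The forward and backward passes described above can be illustrated on the smallest possible case: a single linear neuron y = w * x + b trained by gradient descent on a sum-of-squared-errors objective. This is a toy sketch for intuition only; the data, learning rate and stopping tolerance are arbitrary choices, and a real reference model would rely on a framework's automatic differentiation.

```python
def train_linear_neuron(samples, lr=0.05, tol=1e-6, max_epochs=1000):
    """Gradient descent on the sum of squared errors of one linear neuron y = w * x + b."""

    def forward(w, b):
        # Forward propagation: per-sample error between output and expected value.
        return [(w * x + b) - t for x, t in samples]

    w, b = 0.0, 0.0
    for _ in range(max_epochs):
        errors = forward(w, b)
        if sum(e * e for e in errors) < tol:
            break  # error within the predetermined range: stop training
        # Back propagation: partial derivatives of the objective w.r.t. w and b.
        grad_w = sum(2.0 * e * x for e, (x, _) in zip(errors, samples))
        grad_b = sum(2.0 * e for e in errors)
        # Modify the weights along the negative gradient.
        w -= lr * grad_w
        b -= lr * grad_b
    loss = sum(e * e for e in forward(w, b))
    return w, b, loss


w, b, loss = train_linear_neuron([(0.0, 1.0), (1.0, 3.0), (2.0, 5.0)])  # target t = 2x + 1
```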
In one non-limiting embodiment, step S105 shown in FIG. 1 may include: calculating, as a response error, the sum of the product of the KL divergence and a first weight and the product of the reference model's own error and a second weight; and performing back propagation in the reference model using the response error. The ratio of the first weight to the second weight may be determined according to actual application requirements.
In a preferred embodiment, the first weight is greater than the second weight.
In this embodiment, to enable the reference model to better learn the classification capability of the teacher model, the KL divergence may be given a larger share than the reference model's own error during back propagation. That is, when the weighted sum of the KL divergence and the reference model's own error is calculated, the first weight is set greater than the second weight.
In a specific example, the first weight and the second weight may be 7 and 1, respectively.
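With these example weights, the response error of step S105 is a plain weighted sum. The sketch assumes both terms have already been reduced to scalars for a batch; in an autograd framework such as PyTorch, calling `backward()` on the resulting scalar tensor would propagate both terms into the reference model. The input values in the usage line are made up.

```python
def response_error(kl_divergence, own_error, first_weight=7.0, second_weight=1.0):
    """w1 * KL + w2 * own error; the 7:1 defaults follow the example weights above."""
    # The first weight scales the distillation (KL) term, the second weight the
    # reference model's own error; the preferred embodiment keeps w1 > w2.
    return first_weight * kl_divergence + second_weight * own_error


loss = response_error(0.0713, 0.2)  # made-up KL term and own-error values
```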
In one non-limiting embodiment, calculating the error of the reference model itself in step S104 shown in FIG. 1 may include: calculating the error of the reference model using the Focal loss.
Using the Focal loss effectively mitigates the imbalance between positive and negative samples in the training data. The Focal loss may be calculated as follows:
[Equation image: Focal loss formula]
where y = 1 indicates that the sample image is a positive sample, i.e. the attribute is present in the image, and y = 0 indicates that it is a negative sample, i.e. the attribute is absent; p is the predicted probability.
In a specific example, λ = 1.5 and α = 0.5. The parameter λ is an adjustment factor that tunes the relative importance of positive and negative samples: the larger λ is, the more weight is placed on the less numerous samples. The parameter α is an adjustment factor that counterbalances λ and prevents it from over-adjusting.
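Because the formula itself is only available as an image, the sketch below assumes the standard binary Focal loss of Lin et al.: FL = -α(1 - p)^λ log p for y = 1 and FL = -(1 - α)p^λ log(1 - p) for y = 0, mapping the document's λ to the focusing exponent and α to the balancing weight. The formula in the original image may differ; the defaults use the example values λ = 1.5 and α = 0.5.

```python
import math


def focal_loss(y, p, lam=1.5, alpha=0.5, eps=1e-7):
    """Binary Focal loss for one attribute (standard form, assumed); y is 1 or 0."""
    p = min(max(p, eps), 1.0 - eps)  # avoid log(0)
    if y == 1:
        # Confident correct positives (p near 1) are down-weighted by (1 - p) ** lam.
        return -alpha * (1.0 - p) ** lam * math.log(p)
    return -(1.0 - alpha) * p ** lam * math.log(1.0 - p)


def mean_focal_loss(labels, probs, **kwargs):
    """Average Focal loss over a batch of binary attribute labels."""
    return sum(focal_loss(y, p, **kwargs) for y, p in zip(labels, probs)) / len(labels)


batch_loss = mean_focal_loss([1, 0, 1], [0.9, 0.2, 0.3])
```

Compared with plain cross-entropy, an easy positive with p = 0.9 contributes far less than a hard positive with p = 0.1, which is what shifts the training signal toward difficult and under-represented samples.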
In a non-limiting embodiment, referring to fig. 2, step S104 shown in fig. 1 may include the following steps:
step S201: acquiring a sample ratio of the training data for each category, where the sample ratio is the ratio of the number of samples containing the category to the total number of valid samples in the category;
step S202: calculating an original error of the reference model according to the first output result;
step S203: weighting the original error with the sample ratio to obtain the error of the reference model itself.
In this embodiment, weighting the original error by the sample ratio when computing the error used for back-propagation preserves the training effect on categories with few samples, which in turn improves the final classification accuracy of the reference model on all data.
In one embodiment, the sample ratio is the number of samples in the category divided by the total number of valid samples for that category. For example, for the category "upper-garment color: white", the sample ratio is the number of images whose upper garment is white divided by the number of images that carry an upper-garment color label; images without an upper-garment color label (for example, unlabeled samples) are invalid samples and do not participate in the calculation of the sample ratio.
It should be noted that the original error may be calculated with any feasible error calculation algorithm; the embodiments of the present invention do not limit this.
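The sample-ratio computation in steps S201 to S203 can be sketched as follows. The description does not specify how the original error is weighted by the ratio, so the exponential weighting in `weighted_error` (similar to weighted cross-entropy schemes used in pedestrian attribute recognition) is an assumption. Labels are encoded as 1 (attribute present), 0 (absent), or None (invalid, e.g. unlabeled).

```python
import math
from typing import List, Optional


def sample_ratio(labels: List[Optional[int]]) -> float:
    """Positives / valid samples for one category; None marks an invalid (unlabeled) sample."""
    valid = [y for y in labels if y is not None]
    if not valid:
        return 0.0
    return sum(valid) / len(valid)


def weighted_error(raw_error: float, y: int, ratio: float) -> float:
    """Assumed weighting: rare positives (small ratio) get a larger weight, exp(1 - ratio);
    negatives get exp(ratio)."""
    weight = math.exp(1.0 - ratio) if y == 1 else math.exp(ratio)
    return weight * raw_error


# "upper-garment color: white": 2 of the 6 labeled images are white, 2 images are unlabeled
labels = [1, 0, 0, None, 0, 1, None, 0]
r = sample_ratio(labels)
```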
In a non-limiting embodiment, the following steps may further be performed before step S101 shown in fig. 1: acquiring original sample data, where the original sample data are annotated pedestrian images; and randomly erasing the upper-body or lower-body region of the pedestrian in the original sample data and randomly changing the attribute values of the pedestrian image to obtain the training data.
To ensure that the trained reference model maintains high recognition accuracy in scenes where part of the human body is missing from the image, this embodiment preprocesses the original sample data by randomly erasing the upper-body or lower-body region and randomly changing the attribute values (i.e., performing keypoint-based online data augmentation). This increases the diversity and flexibility of the training data and improves the training effect.
In one non-limiting embodiment, a pedestrian re-identification model is used as the pre-trained model for the human body attribute model, and its network parameters are loaded directly into the backbone of the reference model's network architecture.
In a specific implementation, the reference model may include a backbone (the core network used for feature extraction), a pooling layer, and a fully connected layer. A pedestrian re-identification (Person Re-identification, ReID) model is used as the pre-trained model for human body attribute recognition, i.e., the network parameters of the backbone (such as ResNet18) of the ReID model are loaded directly. Because the ReID model is trained on a large-scale dataset, and the ReID model and the reference model for human body attribute recognition share similar high-dimensional features in the feature extraction stage, this effectively addresses the low accuracy of the reference model in cross-domain scenes.
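A minimal sketch of the keypoint-based random erasing, using a plain list-of-lists image and a single hypothetical hip keypoint row `hip_y` to separate the upper and lower body. The description says only that attribute values are randomly changed; marking the erased half's attributes as unlabeled (None) is one possible interpretation and is labeled as such in the code.

```python
import random
from typing import Dict, List, Optional, Sequence, Tuple

Image = List[List[int]]  # H x W grayscale image, plain lists for illustration


def erase_body_half(image: Image, hip_y: int, labels: Dict[str, Optional[int]],
                    upper_keys: Sequence[str], lower_keys: Sequence[str],
                    rng: random.Random, fill: int = 0) -> Tuple[Image, Dict[str, Optional[int]]]:
    """Randomly blank the rows above or below the hip keypoint and update labels.

    hip_y is the row of a (hypothetical) hip keypoint separating upper and lower body.
    Assumption: attributes belonging to the erased half are marked unlabeled (None).
    """
    out = [row[:] for row in image]      # copy so the original sample is untouched
    new_labels = dict(labels)
    if rng.random() < 0.5:
        rows, keys = range(0, hip_y), upper_keys           # erase upper body
    else:
        rows, keys = range(hip_y, len(image)), lower_keys  # erase lower body
    for r in rows:
        out[r] = [fill] * len(out[r])
    for k in keys:
        new_labels[k] = None
    return out, new_labels
```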
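Loading only the backbone parameters from a ReID checkpoint can be sketched as a filtered merge of two parameter dictionaries. The `backbone.` key prefix and the example key names are assumptions that depend on how both models are defined; the pooling and fully connected layers keep their own initialization.

```python
from typing import Any, Dict, List, Tuple


def transfer_backbone(reid_state: Dict[str, Any], target_state: Dict[str, Any],
                      prefix: str = "backbone.") -> Tuple[Dict[str, Any], List[str]]:
    """Copy backbone parameters from a ReID checkpoint into the reference model.

    Only keys starting with `prefix` that also exist in the target are copied;
    the pooling/fully connected head keeps its own initialization.
    """
    merged = dict(target_state)          # do not mutate the caller's dictionary
    copied = []
    for name, value in reid_state.items():
        if name.startswith(prefix) and name in merged:
            merged[name] = value
            copied.append(name)
    return merged, copied
```

In PyTorch the two dictionaries would typically come from `torch.load(...)` and `model.state_dict()`, and the merged result would be passed to `model.load_state_dict(merged)`; a real implementation should also check that tensor shapes match before copying.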
In a non-limiting embodiment, referring to fig. 3, before step S101 shown in fig. 1, the following steps may be further included:
step S301: acquiring original sample data, wherein the original sample data comprises samples with various attributes;
step S302: inputting samples with a first attribute into a pre-trained Generative Adversarial Network (GAN) to generate samples with a second attribute, where the first attribute and the second attribute belong to the same category, and the second attribute is one whose number of samples is smaller than a preset threshold.
To ensure the training effect when samples are scarce or missing, this embodiment uses a generative adversarial network to supplement the samples, which ensures the comprehensiveness and diversity of the samples and thus the model training effect.
Specifically, the generative adversarial network may be pre-trained. Its input and output are attributes of the same category: for example, the input is a red upper garment and the output is a green upper garment; the input is a backpack and the output is a single-shoulder bag; the input is female gender and the output is male gender; and so on.
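The patent does not say how under-represented attributes are chosen or how many samples are generated. The following sketch assumes that, within one category, every value with fewer than `threshold` samples becomes a GAN target, that the most frequent value serves as the GAN input, and that enough samples are generated to reach the threshold. The function name is hypothetical and the GAN itself is not shown.

```python
from collections import Counter
from typing import List, Sequence, Tuple


def plan_gan_augmentation(labels: Sequence[str], possible_values: Sequence[str],
                          threshold: int) -> List[Tuple[str, str, int]]:
    """Return (source_value, target_value, num_to_generate) triples for one category.

    Values with fewer than `threshold` samples (including values with none at all)
    become GAN targets; the most frequent value is used as the GAN input (an assumption).
    """
    counts = Counter(labels)  # Counter returns 0 for values that never occur
    source = max(possible_values, key=lambda v: counts[v])
    return [(source, v, threshold - counts[v])
            for v in possible_values
            if v != source and counts[v] < threshold]
```

For example, with 50 red and 3 green upper garments, no blue ones, and a threshold of 10, the plan asks the GAN to turn red samples into 7 green and 10 blue samples.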
In one non-limiting embodiment, the weights of the constructed reference model are initialized with the Kaiming algorithm, while the weights of its fully connected layer are initialized from a Normal distribution.
This keeps the output values of each layer Gaussian-distributed when ReLU activation layers are present, which avoids vanishing gradients during training. Specifically, Kaiming initialization sets the scaling factor, i.e., the standard deviation of the initial weights, to
$$\sqrt{\frac{2}{n}},$$
where $n$ is the number of input connections (fan-in) of the layer. As a result, the variances of each layer's input and output remain consistent, the data of preceding and subsequent layers both follow Gaussian distributions, and the vanishing gradients that a progressively shrinking variance would cause during back-propagation are avoided.
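Kaiming (He) initialization for ReLU networks draws weights from a zero-mean normal distribution with standard deviation sqrt(2 / fan_in). The sketch below shows this in plain Python, together with a plain normal initialization for the fully connected layer; the layer sizes are illustrative and the FC standard deviation of 0.001 is an assumption, since the text only says "Normal".

```python
import math
import random
from typing import List


def kaiming_std(fan_in: int) -> float:
    """Standard deviation sqrt(2 / fan_in) used by Kaiming (He) init for ReLU layers."""
    return math.sqrt(2.0 / fan_in)


def init_weights(fan_out: int, fan_in: int, rng: random.Random,
                 std: float) -> List[List[float]]:
    """Draw a fan_out x fan_in weight matrix from N(0, std^2)."""
    return [[rng.gauss(0.0, std) for _ in range(fan_in)] for _ in range(fan_out)]


rng = random.Random(0)
conv_like = init_weights(64, 512, rng, kaiming_std(512))  # Kaiming for hidden layers
fc = init_weights(10, 512, rng, std=0.001)                # plain Normal for the FC layer (std assumed)
```

In PyTorch the same choices would correspond to `torch.nn.init.kaiming_normal_(weight, mode="fan_in", nonlinearity="relu")` for the convolutional layers and `torch.nn.init.normal_(weight, mean=0.0, std=0.001)` for the fully connected layer.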
In one non-limiting embodiment, the error of the reference model itself is the Focal loss.
Instead of the conventional SoftmaxBCELoss, this embodiment computes the error with the Focal loss, which mitigates the imbalance in the training data and thereby improves the accuracy of the model.
In a specific application scenario, referring to fig. 4, a reference model (baseline) 41 and a teacher model 42 are constructed in advance. The reference model 41 is constructed based on ResNet18, and the teacher model 42 is constructed based on ResNet101. The training data may be a batch of images.
As shown in fig. 4, the reference model 41 may include a backbone (i.e., ResNet18), a pooling layer (Pooling), a fully connected layer (FC), and output layers (Outputs). The teacher model 42 may include a backbone (i.e., ResNet101).
During training, the training data are input to both the reference model 41 and the teacher model 42, which output a first output result and a second output result, respectively. The reference model 41 computes its own error, Focalloss; on the teacher model 42 side, the KL divergence KLDivloss is computed. Both Focalloss and KLDivloss are back-propagated through the reference model 41.
This embodiment of the invention first trains a larger ResNet101 model on the dataset and, following the principle of Knowledge Distillation (KD), uses ResNet101 as the teacher to teach the ResNet18 model online, achieving knowledge transfer. In this embodiment, KD is not applied before the pooling layer but to the final labels; since L2 loss cannot be used directly to learn the labels, the KL divergence is used to supervise the probability distribution of the labels.
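The per-category distillation term can be sketched as follows: each category probability p is expanded into the two-point distribution {category: p, not-category: 1 - p}, and the KL divergence is computed between the teacher's and the reference model's distributions. The direction KL(teacher || reference), the averaging over categories, and the function names are assumptions; the text does not fix them.

```python
import math
from typing import Sequence


def binary_kl(p_teacher: float, p_student: float, eps: float = 1e-7) -> float:
    """KL(teacher || student) over the two-point distribution {category, not-category}."""
    pt = min(max(p_teacher, eps), 1.0 - eps)  # clamp to avoid log(0) and division by 0
    ps = min(max(p_student, eps), 1.0 - eps)
    return pt * math.log(pt / ps) + (1.0 - pt) * math.log((1.0 - pt) / (1.0 - ps))


def kd_loss(teacher_probs: Sequence[float], student_probs: Sequence[float]) -> float:
    """Average the per-category binary KL divergence over all attribute categories."""
    terms = [binary_kl(t, s) for t, s in zip(teacher_probs, student_probs)]
    return sum(terms) / len(terms)
```

For reference, PyTorch's `torch.nn.KLDivLoss` expects the student's log-probabilities as input and the teacher's probabilities as target, which matches the KL(teacher || student) direction used here.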
Referring to fig. 5, an embodiment of the present invention further discloses a model training apparatus 50, where the model training apparatus 50 may include:
an input module 501, configured to input training data into a constructed reference model and a teacher model, where the number of network layers of the reference model is smaller than that of the teacher model;
an output result obtaining module 502, configured to obtain a first output result of the reference model for the training data and a second output result of the teacher model for the training data, where the first output result includes a first classification probability for each category, and the second output result includes a second classification probability for each category;
a probability generating module 503, configured to generate, based on the first classification probability of each category, a third classification probability of not belonging to that category, and to generate, based on the second classification probability of each category, a fourth classification probability of not belonging to that category, so as to obtain a first probability distribution and a second probability distribution for each category, where the first probability distribution includes the category with its first classification probability and the non-category with its third classification probability, and the second probability distribution includes the category with its second classification probability and the non-category with its fourth classification probability;
a KL divergence calculation module 504, configured to calculate a KL divergence by using the first probability distribution and the second probability distribution in each category, and calculate an error of the reference model itself;
a parameter adjusting module 505, configured to utilize the KL divergence and the error of the reference model to perform back propagation in the reference model, so as to adjust a network parameter of the reference model.
Because the reference model has fewer network layers, it runs fast and can meet real-time requirements; and because it is adjusted using a teacher model with more network layers, its classification accuracy is ensured. That is, the reference model trained with the technical solution of the invention balances real-time performance and accuracy in data classification.
For more details of the working principle and the working mode of the model training apparatus 50, reference may be made to the related descriptions in fig. 1 to fig. 4, which are not repeated herein.
The embodiment of the invention also discloses a storage medium, which is a computer-readable storage medium storing a computer program; when run, the computer program performs the steps of the methods shown in fig. 1 to fig. 3. The storage medium may include ROM, RAM, a magnetic disk, an optical disk, etc. The storage medium may further include a non-volatile memory, a non-transitory memory, and the like.
The embodiment of the invention also discloses a computing device which can comprise a memory and a processor, wherein the memory is stored with a computer program which can run on the processor. The processor, when running the computer program, may perform the steps of the methods shown in fig. 1-3. The computing device includes, but is not limited to, a mobile phone, a computer, a tablet computer, and other terminal devices.
It should be understood that the processor may be a general-purpose processor, a Digital Signal Processor (DSP), an Application Specific Integrated Circuit (ASIC), a Field Programmable Gate Array (FPGA) or other programmable logic device, a discrete gate or transistor logic device, a discrete hardware component, a System on Chip (SoC), a Central Processing Unit (CPU), a Network Processor (NP), a Micro Controller Unit (MCU), a Programmable Logic Device (PLD), or another integrated chip, and may implement or perform the various methods, steps, and logic blocks disclosed in the embodiments of the present application. A general-purpose processor may be a microprocessor, or the processor may be any conventional processor. The steps of the methods disclosed in connection with the embodiments of the present application may be performed directly by a hardware decoding processor, or by a combination of hardware and software modules in the decoding processor. The software module may reside in a storage medium well known in the art, such as random access memory, flash memory, read-only memory, programmable read-only memory, erasable programmable read-only memory, or registers. The storage medium is located in a memory; the processor reads the information in the memory and completes the steps of the method in combination with its hardware.
It will also be appreciated that the memory referred to in the embodiments of the invention may be volatile memory or nonvolatile memory, or may include both. The nonvolatile memory may be a Read-Only Memory (ROM), a Programmable ROM (PROM), an Erasable PROM (EPROM), an Electrically Erasable PROM (EEPROM), or a flash memory. The volatile memory may be a Random Access Memory (RAM), which acts as an external cache. By way of example but not limitation, many forms of RAM are available, such as Static RAM (SRAM), Dynamic RAM (DRAM), Synchronous DRAM (SDRAM), Double Data Rate SDRAM (DDR SDRAM), Enhanced SDRAM (ESDRAM), Synchlink DRAM (SLDRAM), and Direct Rambus RAM (DR RAM). It should be noted that the memory of the systems and methods described herein is intended to include, without being limited to, these and any other suitable types of memory.
It should be noted that when the processor is a general-purpose processor, a DSP, an ASIC, an FPGA or other programmable logic device, a discrete gate or transistor logic device, or a discrete hardware component, the memory (storage module) is integrated in the processor.
Those of ordinary skill in the art will appreciate that the various illustrative elements and algorithm steps described in connection with the embodiments disclosed herein may be implemented as electronic hardware or combinations of computer software and electronic hardware. Whether such functionality is implemented as hardware or software depends upon the particular application and design constraints imposed on the implementation. Skilled artisans may implement the described functionality in varying ways for each particular application, but such implementation decisions should not be interpreted as causing a departure from the scope of the present application.
It is clear to those skilled in the art that, for convenience and brevity of description, the specific working processes of the above-described systems, apparatuses and units may refer to the corresponding processes in the foregoing method embodiments, and are not described herein again.
The functions, if implemented in the form of software functional units and sold or used as a stand-alone product, may be stored in a computer-readable storage medium. Based on such understanding, the technical solution of the present application, or the portion thereof that contributes to the prior art, may be embodied in the form of a software product stored in a storage medium and including instructions for causing a computer device (which may be a personal computer, a server, or a network device) to execute all or part of the steps of the methods according to the embodiments of the present application. The aforementioned storage medium includes various media capable of storing program code, such as a USB flash drive, a removable hard disk, a read-only memory (ROM), a random access memory (RAM), a magnetic disk, or an optical disk.
Although the present invention is disclosed above, the present invention is not limited thereto. Various changes and modifications may be effected therein by one skilled in the art without departing from the spirit and scope of the invention as defined in the appended claims.

Claims (12)

1. A method of model training, comprising:
inputting training data into a constructed reference model and a teacher model, wherein the number of network layers of the reference model is smaller than that of the teacher model;
obtaining a first output result of the reference model for the training data and a second output result of the teacher model for the training data, the first output result comprising a first classification probability for each category, and the second output result comprising a second classification probability for each category;
generating, based on the first classification probability of each category, a third classification probability of not belonging to that category, and generating, based on the second classification probability of each category, a fourth classification probability of not belonging to that category, to obtain a first probability distribution and a second probability distribution for each category, wherein the first probability distribution comprises the category with its first classification probability and the non-category with its third classification probability, and the second probability distribution comprises the category with its second classification probability and the non-category with its fourth classification probability;
calculating KL divergence by using the first probability distribution and the second probability distribution under each category, and calculating the error of the reference model;
and performing back propagation in the reference model by using the KL divergence and the error of the reference model so as to adjust the network parameters of the reference model.
2. The model training method according to claim 1, wherein said back-propagating in the reference model using the KL divergence and the error of the reference model itself comprises:
calculating the sum of the product of the KL divergence and the first weight and the product of the error of the reference model and the second weight as a response error;
and utilizing the response error to perform back propagation in the reference model.
3. The model training method of claim 1, wherein the calculating the error of the reference model itself comprises:
calculating the error of the reference model itself using the Focal loss.
4. The model training method of claim 1, wherein the calculating the error of the reference model itself comprises:
acquiring a sample proportion of the training data for each category, wherein the sample proportion is a ratio of the number of samples containing the category to the total number of effective samples in the category;
calculating an original error of the reference model according to the first output result;
weighting the original error with the sample ratio to obtain the error of the reference model itself.
5. The model training method of claim 1, further comprising, before inputting the training data into the constructed reference model and the teacher model:
acquiring original sample data, wherein the original sample data are annotated pedestrian images comprising key points;
randomly erasing the upper-body image or the lower-body image of the pedestrian according to the coordinates of the key points of the original sample data, and changing the attribute values in the pedestrian image, to obtain the training data.
6. The model training method according to claim 1, wherein a pedestrian re-identification model is used as the pre-trained model of the human body attribute model, and its network parameters are loaded directly into the backbone of the network architecture of the reference model.
7. The model training method of claim 1, further comprising, before inputting the training data into the constructed reference model and the teacher model:
acquiring original sample data, wherein the original sample data comprises samples with various attributes;
inputting samples with a first attribute into a pre-trained generative adversarial network to generate samples with a second attribute, wherein the first attribute and the second attribute belong to the same category, and the number of samples with the second attribute is smaller than a preset threshold.
8. The model training method according to any one of claims 1 to 7, wherein the weights of the constructed reference model are initialized using the Kaiming algorithm, and the weights of the fully connected layer of the constructed reference model are initialized using a Normal distribution.
9. The model training method of any one of claims 1 to 7, wherein the reference model is constructed based on ResNet18 and the teacher model is constructed based on ResNet101.
10. A model training apparatus, comprising:
the input module is used for inputting training data into a constructed reference model and a teacher model, and the number of network layers of the reference model is smaller than that of the teacher model;
an output result obtaining module, configured to obtain a first output result of the reference model for the training data and a second output result of a teacher model for the training data, where the first output result includes a first classification probability for each category, and the second output result includes a second classification probability for each category;
a probability generating module, configured to generate, based on the first classification probability of each category, a third classification probability of not belonging to that category, and to generate, based on the second classification probability of each category, a fourth classification probability of not belonging to that category, so as to obtain a first probability distribution and a second probability distribution for each category, where the first probability distribution includes the category with its first classification probability and the non-category with its third classification probability, and the second probability distribution includes the category with its second classification probability and the non-category with its fourth classification probability;
the KL divergence calculation module is used for calculating KL divergence by utilizing the first probability distribution and the second probability distribution under each category and calculating the error of the reference model;
and the parameter adjusting module is used for performing back propagation in the reference model by using the KL divergence and the error of the reference model so as to adjust the network parameters of the reference model.
11. A storage medium having a computer program stored thereon, the computer program, when being executed by a processor, performing the steps of the model training method according to any one of claims 1 to 9.
12. A computing device comprising a memory and a processor, the memory having stored thereon a computer program operable on the processor, wherein the processor, when executing the computer program, performs the steps of the model training method of any one of claims 1 to 9.
CN202011415641.5A 2020-12-04 2020-12-04 Model training method and device, storage medium and computing equipment Pending CN112784677A (en)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202011415641.5A CN112784677A (en) 2020-12-04 2020-12-04 Model training method and device, storage medium and computing equipment

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202011415641.5A CN112784677A (en) 2020-12-04 2020-12-04 Model training method and device, storage medium and computing equipment

Publications (1)

Publication Number Publication Date
CN112784677A true CN112784677A (en) 2021-05-11

Family

ID=75750750

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202011415641.5A Pending CN112784677A (en) 2020-12-04 2020-12-04 Model training method and device, storage medium and computing equipment

Country Status (1)

Country Link
CN (1) CN112784677A (en)

Citations (10)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN108921092A * 2018-07-02 2018-11-30 Zhejiang University of Technology Melanoma classification method based on a two-level ensemble of convolutional neural network models
CN109711544A * 2018-12-04 2019-05-03 Beijing SenseTime Technology Development Co., Ltd. Model compression method and apparatus, electronic device, and computer storage medium
CN110021051A * 2019-04-01 2019-07-16 Zhejiang University Method for generating object images from text descriptions based on a generative adversarial network
CN110147456A * 2019-04-12 2019-08-20 Shenzhen Institutes of Advanced Technology, Chinese Academy of Sciences Image classification method and device, readable storage medium, and terminal device
CN110197212A * 2019-05-20 2019-09-03 Beijing University of Posts and Telecommunications Image classification method, system and computer readable storage medium
CN110321928A * 2019-06-03 2019-10-11 Shenzhen ZTE NetView Technology Co., Ltd. Environment detection model generation method, computer device, and readable storage medium
CN110659573A * 2019-08-22 2020-01-07 Beijing Sinovoice Technology Co., Ltd. Face recognition method and device, electronic equipment and storage medium
CN111008654A * 2019-11-26 2020-04-14 Jiangsu Aijia Household Products Co., Ltd. Method and system for identifying rooms in a floor plan
CN111488945A * 2020-04-17 2020-08-04 Shanghai Eye Control Technology Co., Ltd. Image processing method and device, computer equipment, and computer readable storage medium
CN111860147A * 2020-06-11 2020-10-30 Beijing Weifu Security Technology Co., Ltd. Pedestrian re-identification model optimization processing method and device, and computer equipment

Non-Patent Citations (2)

Title
Kaiming He et al.: "Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification", arXiv *
Li Hanbing et al.: "Real-time vehicle detection method based on improved YOLOv3", Laser & Optoelectronics Progress *

Similar Documents

Publication Publication Date Title
WO2019100724A1 (en) Method and device for training multi-label classification model
WO2019100723A1 (en) Method and device for training multi-label classification model
US20230196117A1 (en) Training method for semi-supervised learning model, image processing method, and device
US20220058426A1 (en) Object recognition method and apparatus, electronic device, and readable storage medium
US20210012198A1 (en) Method for training deep neural network and apparatus
CN109871781B (en) Dynamic gesture recognition method and system based on multi-mode 3D convolutional neural network
US11417148B2 (en) Human face image classification method and apparatus, and server
US11232286B2 (en) Method and apparatus for generating face rotation image
Yun et al. Focal loss in 3d object detection
US11244191B2 (en) Region proposal for image regions that include objects of interest using feature maps from multiple layers of a convolutional neural network model
US11755889B2 (en) Method, system and apparatus for pattern recognition
CN111191526B (en) Pedestrian attribute recognition network training method, system, medium and terminal
WO2017096753A1 (en) Facial key point tracking method, terminal, and nonvolatile computer readable storage medium
CN111133453B (en) Artificial neural network
US20180144246A1 (en) Neural Network Classifier
CN109961107B (en) Training method and device for target detection model, electronic equipment and storage medium
CN111222487B (en) Video target behavior identification method and electronic equipment
US11093800B2 (en) Method and device for identifying object and computer readable storage medium
CN111401521B (en) Neural network model training method and device, and image recognition method and device
CN110968734A (en) Pedestrian re-identification method and device based on depth measurement learning
CN113781164B (en) Virtual fitting model training method, virtual fitting method and related devices
Gorijala et al. Image generation and editing with variational info generative Adversarial Networks
CN114170654A (en) Training method of age identification model, face age identification method and related device
CN112749737A (en) Image classification method and device, electronic equipment and storage medium
Nida et al. Video augmentation technique for human action recognition using genetic algorithm

Legal Events

Date Code Title Description
PB01 Publication
SE01 Entry into force of request for substantive examination
WD01 Invention patent application deemed withdrawn after publication

Application publication date: 20210511