CN112990298B - Key point detection model training method, key point detection method and device


Publication number
CN112990298B
Authority
CN
China
Prior art keywords
key point
loss function
network
prediction result
training sample
Legal status
Active
Application number
CN202110263320.6A
Other languages
Chinese (zh)
Other versions
CN112990298A (en)
Inventor
刘京
张慧
王雅丽
Current Assignee
Beijing Irisking Co ltd
Original Assignee
Beijing Irisking Co ltd
Application filed by Beijing Irisking Co ltd
Priority to CN202110263320.6A
Publication of CN112990298A
Application granted
Publication of CN112990298B
Legal status: Active (current)


Classifications

    • G PHYSICS
    • G06 COMPUTING; CALCULATING OR COUNTING
    • G06F ELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00 Pattern recognition
    • G06F18/20 Analysing
    • G06F18/21 Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214 Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • G PHYSICS
    • G06 COMPUTING; CALCULATING OR COUNTING
    • G06N COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00 Computing arrangements based on biological models
    • G06N3/02 Neural networks
    • G06N3/04 Architecture, e.g. interconnection topology
    • G06N3/045 Combinations of networks
    • G PHYSICS
    • G06 COMPUTING; CALCULATING OR COUNTING
    • G06N COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00 Computing arrangements based on biological models
    • G06N3/02 Neural networks
    • G06N3/08 Learning methods
    • Y GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
    • Y02 TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
    • Y02T CLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
    • Y02T10/00 Road transport of goods or passengers
    • Y02T10/10 Internal combustion engine [ICE] based vehicles
    • Y02T10/40 Engine management systems

Abstract

The invention provides a keypoint detection model training method, a keypoint detection method and a device. The training method comprises: acquiring a training sample set, in which a first type of training sample comprises an image, keypoint positions and keypoint availability, and a second type of training sample comprises only an image; training a teacher network using first-type training samples and a first total loss function, the first total loss function including losses that constrain keypoint location and keypoint availability; when a selected training sample is of the second type, inputting it into the trained teacher network and into a student network, respectively, to obtain keypoint location and availability prediction results; and inputting these prediction results into a second total loss function to train the student network through knowledge distillation, the second total loss function including losses that constrain the student network to learn the teacher network's keypoint location and availability prediction results. This reduces annotation time, improves detection on low-quality images, and improves the real-time performance of detection.

Description

Key point detection model training method, key point detection method and device
Technical Field
The present invention relates to the field of image processing technologies, and in particular, to a method and apparatus for training a key point detection model.
Background
Currently, the key point detection task mainly faces the following three problems:
First, annotating keypoint training data is extremely time-consuming, and annotation becomes even more time-consuming and complex when a new keypoint detection task is introduced. Moreover, when different annotators label the data, differences in personal judgment and labeling habits introduce keypoint labeling errors. Because of these errors, the labeled position of the same keypoint fluctuates within a small range, which makes keypoint training harder to converge.
Second, detection on low-quality data is difficult. For example, in tasks with multiple keypoints, some keypoints are usually occluded or of poor quality because of occlusion, pose or illumination, or because the labeled part is not present in the image. In such cases the annotator cannot determine the exact position of the keypoint, and if all keypoints are fed to the network for learning, the accuracy of keypoint learning as a whole is affected.
Third, to keep improving keypoint detection, most existing networks raise accuracy by stacking ever more modules, for example Hourglass and CPN (Cascaded Pyramid Network). In practical deployment, however, such large networks infer very slowly and demand more memory, GPU memory and other resources, so they cannot meet the performance requirements of real applications.
Disclosure of Invention
In view of the foregoing, the present invention provides a method for training a keypoint detection model, a method for detecting keypoints, and a device thereof, so as to reduce one or more drawbacks of the prior art.
In order to achieve the above purpose, the invention is realized by adopting the following scheme:
According to one aspect of the embodiments of the present invention, a keypoint detection model training method is provided, including:
acquiring a training sample set, the training sample set comprising first-type training samples and second-type training samples, wherein a first-type training sample comprises a training image, positions of keypoints in the training image and availability of the corresponding keypoints, and a second-type training sample comprises a training image;
training a teacher network using the first-type training samples and a first total loss function to obtain a trained teacher network, wherein the first total loss function includes a loss function constraining keypoint availability and a loss function constraining keypoint location;
selecting a first training sample from the training sample set and, when the first training sample is a second-type training sample, inputting it into the trained teacher network to obtain a first keypoint location prediction result and a corresponding first keypoint availability prediction result, and inputting it into a student network to obtain a second keypoint location prediction result and a corresponding second keypoint availability prediction result;
inputting the first keypoint location prediction result, the first keypoint availability prediction result, the second keypoint location prediction result and the second keypoint availability prediction result into a second total loss function to obtain a first value of the second total loss function, and feeding the first value back to the student network, so that the student network is trained with second-type training samples by knowledge distillation from the trained teacher network; and obtaining a keypoint detection model from the trained student network, or from a network comprising the trained student network and the trained teacher network; wherein the second total loss function includes a loss function constraining the student network's keypoint location prediction result to approach the teacher network's keypoint location prediction result, and a loss function constraining the student network's keypoint availability prediction result to approach the teacher network's keypoint availability prediction result.
In some embodiments, the second total loss function further comprises a loss function constraining the student network's keypoint location prediction result to approach the known keypoint location, and a loss function constraining the student network's keypoint availability prediction result to approach the known keypoint availability;
before the keypoint detection model is obtained from the trained student network or from the network comprising the trained student network and the trained teacher network, the method further comprises:
selecting a second training sample from the training sample set and, when the second training sample is a first-type training sample, inputting it into the student network to obtain a third keypoint location prediction result and a corresponding third keypoint availability prediction result; inputting the third keypoint location prediction result and the third keypoint availability prediction result into the second total loss function to obtain a second value of the second total loss function; and feeding the second value back to the student network, so that the student network is trained with first-type training samples.
In some embodiments, the above step of selecting a second training sample and training the student network with first-type training samples includes:
selecting a second training sample from the training sample set and, when the second training sample is a first-type training sample, masking, in the second total loss function, the loss function that constrains the student network's keypoint location prediction result to approach the teacher network's and the loss function that constrains the student network's keypoint availability prediction result to approach the teacher network's; inputting the second training sample into the student network to obtain a third keypoint location prediction result and a corresponding third keypoint availability prediction result; inputting these into the second total loss function to obtain a second value of the second total loss function; and feeding the second value back to the student network, so that the student network is trained with first-type training samples.
In some embodiments, the loss function constraining the student network's keypoint location prediction result to approach the teacher network's keypoint location prediction result is a mean square error loss function, and/or the loss function constraining the student network's keypoint location prediction result to approach the known keypoint location is a mean square error loss function; the loss function constraining the student network's keypoint availability prediction result to approach the teacher network's keypoint availability prediction result is a cross-entropy loss function, and/or the loss function constraining the student network's keypoint availability prediction result to approach the known keypoint availability is a cross-entropy loss function.
In some embodiments, each first-type training sample includes the positions of a plurality of keypoints, and/or the teacher network requires more computing resources than the student network.
According to another aspect of the embodiments of the present invention, a keypoint detection method is also provided, including:
obtaining a keypoint detection model produced by training a teacher network and a student network with the keypoint detection model training method of any of the above embodiments; and
detecting keypoints in an image to be detected using the obtained keypoint detection model.
In some embodiments, detecting keypoints in the image to be detected using the obtained keypoint detection model comprises:
detecting the keypoints in the image to be detected in real time using the trained student network in the obtained keypoint detection model.
According to another aspect of the embodiments of the present invention, a keypoint detection device is further provided, which uses a keypoint detection model obtained by training a teacher network and a student network with the keypoint detection model training method of any of the above embodiments.
According to an aspect of the embodiments of the present invention, there is also provided an electronic device including a memory, a processor and a computer program stored on the memory and executable on the processor, the processor implementing the steps of the method according to any of the embodiments described above when the program is executed.
According to an aspect of an embodiment of the present invention, there is also provided a computer-readable storage medium having stored thereon a computer program which, when executed by a processor, implements the steps of the method of any of the embodiments described above.
The keypoint detection model training method, keypoint detection method, device, electronic equipment and computer-readable storage medium described above can alleviate three problems in keypoint detection: the need to annotate large amounts of training data, keypoint detection in low-quality images, and the trade-off between model accuracy and speed at deployment.
Drawings
In order to more clearly illustrate the embodiments of the invention or the technical solutions in the prior art, the drawings that are required in the embodiments or the description of the prior art will be briefly described, it being obvious that the drawings in the following description are only some embodiments of the invention, and that other drawings may be obtained according to these drawings without inventive effort for a person skilled in the art. In the drawings:
FIG. 1 is a flow chart of a method for training a key point detection model according to an embodiment of the present invention;
FIG. 2 is a flow chart of a method for training a keypoint detection model in accordance with an embodiment of the present invention.
Detailed Description
For the purpose of making the objects, technical solutions and advantages of the embodiments of the present invention more apparent, the embodiments of the present invention will be described in further detail with reference to the accompanying drawings. The exemplary embodiments of the present invention and their descriptions herein are for the purpose of explaining the present invention, but are not to be construed as limiting the invention.
It is to be noted in advance that the description of the embodiments or examples below or the features mentioned therein may be combined with or replace features in other embodiments or examples in the same or similar way to form possible implementations. Furthermore, the term "comprises/comprising" as used herein means the presence of a feature, element, step or component, but does not exclude the presence of one or more other features, elements, steps or components.
To address the long time needed to annotate keypoints, the loss of keypoint learning accuracy caused by low-quality images, and poor real-time performance, the invention provides a keypoint detection model training method based on knowledge distillation, which reduces keypoint annotation time, reduces the influence of low-quality images on keypoint learning accuracy, and helps improve the real-time performance of keypoint detection.
Fig. 1 is a flow chart of a keypoint detection model training method according to an embodiment of the invention. As shown in Fig. 1, the training method of these embodiments may include the following steps S110 to S140.
Specific embodiments of step S110 to step S140 will be described in detail below.
Step S110: acquiring a training sample set; the training sample set comprises a first type of training sample and a second type of training sample, the first type of training sample comprises a training image, positions of key points in the training image and availability of corresponding key points, and the second type of training sample comprises a training image.
In step S110, a first-type training sample is one in which the positions and availability of the keypoints in the image are known; they may be labeled manually or obtained in other ways. Each training sample includes a training image, and a training image may contain multiple keypoints, so a first-type training sample may include the positions of multiple keypoints (positions in the image, such as pixel coordinates) together with the availability of each keypoint, for example 1 if the keypoint is available and 0 if it is not.
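To make the two sample types concrete, here is a minimal Python sketch of how they might be represented. The class name `TrainingSample`, its field names, and the list-of-lists image format are illustrative assumptions, not structures defined by the patent; the only rule taken from the text is that a first-type sample carries keypoint positions plus a 0/1 availability flag per keypoint, while a second-type sample carries only the image.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass
class TrainingSample:
    """Hypothetical container for one training sample.

    First-type samples carry keypoint positions and availability flags;
    second-type samples carry only the image.
    """
    image: List[List[float]]  # grayscale image as rows of pixel values (assumed format)
    positions: Optional[List[Tuple[int, int]]] = None  # (x, y) pixel position per keypoint
    availability: Optional[List[int]] = None  # 1 = available, 0 = not available

    def __post_init__(self) -> None:
        # Labels must come as a pair: both present (first type) or both absent (second type).
        if (self.positions is None) != (self.availability is None):
            raise ValueError("positions and availability must be given together")
        if self.positions is not None and len(self.positions) != len(self.availability):
            raise ValueError("one availability flag is needed per keypoint")

    @property
    def is_first_type(self) -> bool:
        """True if the sample is labeled, i.e. a first-type training sample."""
        return self.positions is not None


# Usage: a labeled sample with two keypoints (the second unavailable) and an unlabeled one.
img = [[0.0] * 4 for _ in range(4)]
labeled = TrainingSample(img, positions=[(1, 2), (3, 0)], availability=[1, 0])
unlabeled = TrainingSample(img)
```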
Step S120: training a teacher network by using a first type of training sample and a first total loss function to obtain a trained teacher network; wherein the first total loss function includes a loss function for constraining availability of the keypoint and a loss function for constraining a location of the keypoint.
In step S120, the term of the first total loss function that constrains keypoint location trains the teacher network to predict the positions of the keypoints in the input training sample, and the term that constrains keypoint availability trains it to predict whether each of those keypoints is available.
Any suitable loss function may be used for each of these two terms, and the first total loss function may be formed from the two.
For example, the loss function constraining keypoint location may be a mean square error loss function and the loss function constraining keypoint availability may be a cross-entropy loss function. The first total loss function $L_T$ can then be expressed as:
$$L_T = L_{T\_class} + L_{T\_mse}$$
where $L_{T\_class}$ is the loss function constraining keypoint availability, computed from the known availability of the $k$-th keypoint in the $i$-th training image and the teacher network's predicted availability of that keypoint; $L_{T\_mse}$ is the loss function constraining keypoint location; $K$ is the number of keypoints in a training image; $m_k$ is the known confidence map of the $k$-th keypoint in each training image; and $m_k^t$ is the teacher network's predicted confidence map of the $k$-th keypoint in each training image.
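The first total loss can be sketched as follows for a single training image. This assumes the common forms of the two named losses, namely mean binary cross-entropy over the $K$ availability flags and mean squared error over the $K$ confidence maps (each flattened to a list); the averaging scheme, the clipping constant `eps`, and the function names are assumptions rather than the patent's exact definitions.

```python
import math
from typing import Sequence


def availability_bce(known: Sequence[float], predicted: Sequence[float], eps: float = 1e-7) -> float:
    """Mean binary cross-entropy between known 0/1 availability and predicted probabilities."""
    total = 0.0
    for y, p in zip(known, predicted):
        p = min(max(p, eps), 1.0 - eps)  # clip so log() never sees 0
        total += -(y * math.log(p) + (1.0 - y) * math.log(1.0 - p))
    return total / len(known)


def confidence_map_mse(known_maps: Sequence[Sequence[float]],
                       predicted_maps: Sequence[Sequence[float]]) -> float:
    """Mean squared error over all pixels of the K flattened confidence maps (m_k vs m_k^t)."""
    total, count = 0.0, 0
    for m_k, m_k_t in zip(known_maps, predicted_maps):
        for a, b in zip(m_k, m_k_t):
            total += (a - b) ** 2
            count += 1
    return total / count


def teacher_total_loss(known_avail, teacher_avail, known_maps, teacher_maps) -> float:
    """L_T = L_T_class + L_T_mse for one training image."""
    l_class = availability_bce(known_avail, teacher_avail)
    l_mse = confidence_map_mse(known_maps, teacher_maps)
    return l_class + l_mse


# Usage: two keypoints, the second one labeled unavailable.
loss = teacher_total_loss(
    known_avail=[1, 0], teacher_avail=[0.9, 0.2],
    known_maps=[[0.0, 1.0, 0.0], [0.0, 0.0, 0.0]],
    teacher_maps=[[0.1, 0.8, 0.0], [0.0, 0.1, 0.0]],
)
print(f"L_T = {loss:.4f}")
```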
Specifically, step S120 may include: inputting the training images of the first-type training samples into the teacher network, which outputs keypoint location prediction results and the corresponding keypoint availability prediction results; substituting the known keypoint positions and availability of the first-type training samples, together with these prediction results, into the first total loss function to obtain its value; and feeding that value back to the teacher network to train it. Training of the teacher network can be considered complete when a set number of training iterations is reached or when the value of the loss function falls within a set threshold range.
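The stopping rule for teacher training (a fixed iteration budget, or the loss reaching a threshold) can be sketched as a small driver loop. `step` is a hypothetical callback standing in for one forward pass, loss evaluation, and parameter update on first-type samples; the function name and the single scalar threshold are assumptions.

```python
from typing import Callable, List


def train_until_done(step: Callable[[int], float], max_iters: int,
                     loss_threshold: float) -> List[float]:
    """Run training steps until max_iters is reached or the loss drops to loss_threshold.

    `step(i)` is a hypothetical callback that performs one update of the teacher
    network and returns the current value of the first total loss.
    """
    history: List[float] = []
    for i in range(max_iters):
        loss = step(i)
        history.append(loss)
        if loss <= loss_threshold:  # loss reached the target range: stop early
            break
    return history
```

With a stand-in loss that decays as 1/(i+1) and a threshold of 0.25, the loop stops after four steps rather than exhausting a budget of ten; with a loss that never falls below the threshold it simply runs `max_iters` steps.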
Step S130: selecting a first training sample from the training sample set and, when the first training sample is a second-type training sample, inputting it into the trained teacher network to obtain a first keypoint location prediction result and a corresponding first keypoint availability prediction result, and inputting it into a student network to obtain a second keypoint location prediction result and a corresponding second keypoint availability prediction result.
In step S130, because a second-type training sample contains only the training image and no keypoint positions or availability information, its training image can be input into the trained teacher network to obtain the teacher network's output, namely the first keypoint location prediction result and the corresponding first keypoint availability prediction result, which serve as the learning target for the student network. The same training image is also input into the student network to obtain its prediction, namely the second keypoint location prediction result and the corresponding second keypoint availability prediction result, so that the difference between the student's prediction and the target can be computed.
Step S140: inputting the first keypoint location prediction result, the first keypoint availability prediction result, the second keypoint location prediction result and the second keypoint availability prediction result into a second total loss function to obtain a first value of the second total loss function, and feeding the first value back to the student network, so that the student network is trained with second-type training samples by knowledge distillation from the trained teacher network; and obtaining a keypoint detection model from the trained student network, or from a network comprising the trained student network and the trained teacher network; wherein the second total loss function includes a loss function constraining the student network's keypoint location prediction result to approach the teacher network's keypoint location prediction result, and a loss function constraining the student network's keypoint availability prediction result to approach the teacher network's keypoint availability prediction result.
In step S140, for the same input training image, the term of the second total loss function that constrains the student network's keypoint location prediction result to approach the teacher network's measures the difference between the two networks' keypoint location predictions, and the term that constrains the student network's keypoint availability prediction result to approach the teacher network's measures the difference between their keypoint availability predictions. Substituting the student network's second keypoint location and availability prediction results and the teacher network's first keypoint location and availability prediction results into the second total loss function lets the student network be adjusted toward the teacher network's predictions, so that the knowledge learned by the teacher network is distilled into the student network.
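For a second-type sample, the teacher's outputs are the only targets. Below is a minimal sketch of the teacher-guided part of the second total loss, assuming the teacher's availability probabilities are used directly as soft targets in a binary cross-entropy and that the location term is a mean squared error between confidence maps. These forms match the loss types named in the text, but the exact formulas and all function names are assumptions.

```python
import math
from typing import Sequence


def soft_availability_ce(teacher_prob: Sequence[float], student_prob: Sequence[float],
                         eps: float = 1e-7) -> float:
    """Cross-entropy of student availability probabilities against the teacher's soft targets."""
    total = 0.0
    for pt, ps in zip(teacher_prob, student_prob):
        ps = min(max(ps, eps), 1.0 - eps)  # clip so log() never sees 0
        total += -(pt * math.log(ps) + (1.0 - pt) * math.log(1.0 - ps))
    return total / len(teacher_prob)


def map_mse(teacher_maps: Sequence[Sequence[float]],
            student_maps: Sequence[Sequence[float]]) -> float:
    """Mean squared error between teacher and student flattened confidence maps."""
    total, count = 0.0, 0
    for mt, ms in zip(teacher_maps, student_maps):
        for a, b in zip(mt, ms):
            total += (a - b) ** 2
            count += 1
    return total / count


def distillation_terms(teacher_avail, teacher_maps, student_avail, student_maps) -> float:
    """Teacher-guided terms: availability toward the teacher plus location toward the teacher."""
    return soft_availability_ce(teacher_avail, student_avail) + map_mse(teacher_maps, student_maps)
```

With soft targets, the cross-entropy term is minimized rather than zeroed when the student matches the teacher: its minimum equals the entropy of the teacher's prediction (log 2 ≈ 0.693 for a teacher probability of 0.5). Only the gradient matters for training, so this constant offset is harmless.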
The teacher network and the student network may be existing networks for detecting keypoints in images, and may include convolutional neural networks. The teacher network may require more computing resources than the student network. The teacher network is generally more complex than the student network, for example having more network layers, so it takes longer to learn, while the student network learns faster; the teacher network's predictions are also more accurate. Distilling the knowledge learned by the teacher network into the student network therefore preserves accuracy as far as possible while reducing time consumption as far as possible. Furthermore, because both networks learn to predict keypoint availability as well as keypoint location, they can identify which keypoints are usable, so keypoints can be detected accurately even in low-quality images.
In a further embodiment of the training method shown in Fig. 1, the second total loss function may further include a loss function constraining the student network's keypoint location prediction result to approach the known keypoint location and a loss function constraining the student network's keypoint availability prediction result to approach the known keypoint availability.
Here the "known keypoint location" and the "known keypoint availability" are the keypoint position and availability obtained by labeling or by other means. Because the second total loss function also contains these two terms, it can train the student network on training samples labeled with keypoint information. This makes the training results more accurate and allows the student network to be trained with both labeled and unlabeled training samples.
Specifically, the two loss functions that pull the student network's keypoint location and availability prediction results toward the known values may share one weight, and the two loss functions that pull them toward the teacher network's prediction results may share another weight; adjusting the ratio of these two weights adjusts the relative influence of the two groups of loss functions on the total loss.
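Writing the two weights as $w_1$ (for the teacher-guided pair) and $w_2$ (for the known-label pair), with $A$ and $B$ the sums of the respective pairs of loss functions (notation introduced here for illustration, with $w_1, w_2 > 0$), shows why only their ratio matters:

$$w_1 A + w_2 B = (w_1 + w_2)\bigl[\alpha A + (1-\alpha) B\bigr], \qquad \alpha = \frac{w_1}{w_1 + w_2} \in (0, 1)$$

The common factor $w_1 + w_2$ only rescales the loss, which for plain gradient descent is equivalent to rescaling the learning rate, so the balance between the two groups is captured by a single weight parameter $\alpha$.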
In some embodiments, the loss function constraining the student network's keypoint location prediction result to approach the teacher network's keypoint location prediction result is a mean square error loss function, and/or the loss function constraining the student network's keypoint location prediction result to approach the known keypoint location is a mean square error loss function; the loss function constraining the student network's keypoint availability prediction result to approach the teacher network's keypoint availability prediction result is a cross-entropy loss function, and/or the loss function constraining the student network's keypoint availability prediction result to approach the known keypoint availability is a cross-entropy loss function.
For example, the loss function constraining the student network's keypoint location prediction result to approach the known keypoint location may be a mean square error loss function, and the loss function constraining the student network's keypoint availability prediction result to approach the known keypoint availability may be a cross-entropy loss function. In addition, the keypoint location prediction results of the teacher network and of the student network may be represented as confidence maps.
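Since location predictions are confidence maps, a pixel position has to be read out of each map at inference time. A common choice, assumed here rather than taken from the patent, is to take the pixel with the highest confidence and to report the keypoint only if its predicted availability probability reaches a threshold (0.5 here, also an assumption). The function name is hypothetical.

```python
from typing import List, Optional, Tuple


def decode_keypoint(conf_map: List[List[float]], avail_prob: float,
                    threshold: float = 0.5) -> Optional[Tuple[int, int]]:
    """Return the (x, y) pixel of the confidence-map peak, or None if predicted unavailable."""
    if not conf_map or not conf_map[0]:
        raise ValueError("confidence map must be non-empty")
    if avail_prob < threshold:  # availability head says this keypoint is unusable
        return None
    best_val, best_xy = float("-inf"), (0, 0)
    for y, row in enumerate(conf_map):
        for x, v in enumerate(row):
            if v > best_val:  # strict '>' keeps the first peak on ties
                best_val, best_xy = v, (x, y)
    return best_xy
```

Gating the location on the availability prediction is one way to use the availability output so that occluded or missing keypoints are not reported at a spurious peak.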
For example, the second total loss function $L$ may be expressed as:
$$L = \alpha\,(L_{ST\_class} + L_{kd}) + (1-\alpha)\,(L_{SG\_class} + L_{S\_mse})$$
where $L_{ST\_class}$ is the loss function constraining the student network's keypoint availability prediction result to approach the teacher network's keypoint availability prediction result; $L_{kd}$ is the loss function constraining the student network's keypoint location prediction result to approach the teacher network's keypoint location prediction result; $L_{SG\_class}$ is the loss function constraining the student network's keypoint availability prediction result to approach the known keypoint availability; $L_{S\_mse}$ is the loss function constraining the student network's keypoint location prediction result to approach the known keypoint location; and $\alpha$ is a weight parameter. These terms use the known availability of the $k$-th keypoint of the $i$-th training image, the teacher network's predicted availability probability for that keypoint, and the student network's predicted availability probability for that keypoint; $m_k^t$ and $m_k^s$ are the teacher network's and the student network's predicted confidence maps of the $k$-th keypoint in each training image, and $m_k$ is the known confidence map of the $k$-th keypoint in each training image. The proportions of the different loss functions can be adjusted through the value of $\alpha$; for example, $\alpha$ can be set to 0.5.
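The second total loss can be sketched by combining both groups of terms with the weight α, teacher-guided terms weighted by α and known-label terms by 1 − α. This minimal sketch assumes binary cross-entropy for the availability terms and mean squared error for the confidence-map terms; the function and argument names are hypothetical. Passing `None` for a group masks it, which covers first-type samples (teacher terms masked) and second-type samples (no known labels exist).

```python
import math
from typing import Optional, Sequence

Maps = Sequence[Sequence[float]]  # K flattened confidence maps


def _bce(target: Sequence[float], prob: Sequence[float], eps: float = 1e-7) -> float:
    """Mean binary cross-entropy; targets may be hard (0/1) or soft (teacher probabilities)."""
    total = 0.0
    for t, p in zip(target, prob):
        p = min(max(p, eps), 1.0 - eps)  # clip so log() never sees 0
        total += -(t * math.log(p) + (1.0 - t) * math.log(1.0 - p))
    return total / len(target)


def _mse(maps_a: Maps, maps_b: Maps) -> float:
    """Mean squared error over all pixels of the flattened confidence maps."""
    total, count = 0.0, 0
    for ma, mb in zip(maps_a, maps_b):
        for a, b in zip(ma, mb):
            total += (a - b) ** 2
            count += 1
    return total / count


def second_total_loss(student_avail: Sequence[float], student_maps: Maps,
                      teacher_avail: Optional[Sequence[float]] = None,
                      teacher_maps: Optional[Maps] = None,
                      known_avail: Optional[Sequence[float]] = None,
                      known_maps: Optional[Maps] = None,
                      alpha: float = 0.5) -> float:
    """alpha * (teacher-guided terms) + (1 - alpha) * (known-label terms); None masks a group."""
    loss = 0.0
    if teacher_avail is not None and teacher_maps is not None:
        # availability toward the teacher, plus L_kd: student map toward teacher map
        loss += alpha * (_bce(teacher_avail, student_avail) + _mse(teacher_maps, student_maps))
    if known_avail is not None and known_maps is not None:
        # availability toward the known labels, plus student map toward the known map
        loss += (1.0 - alpha) * (_bce(known_avail, student_avail) + _mse(known_maps, student_maps))
    return loss
```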
In some embodiments, the student network may be trained directly using a first type of training sample. Illustratively, before the key point detection model is obtained according to the trained student network or the network including the trained student network and the trained teacher network in step S140, the method shown in fig. 1 may further include the steps of: and S150, selecting a second training sample from the training sample set, inputting the second training sample into a student network to obtain a third key point position prediction result and a corresponding third key point availability prediction result under the condition that the second training sample is the first type training sample, inputting the third key point position prediction result and the third key point availability prediction result into the second total loss function to obtain a second value of the second total loss function, and feeding back the second value of the second total loss function to the student network to train the student network by using the first type training sample.
In this embodiment, for already annotated training samples, such samples may be used directly to train the student network.
In a further embodiment, the proportions of the various loss functions in the second total loss function may be adjusted to adjust the magnitude of the effect of the various loss functions. For example, the step S150 may specifically include the steps of: and selecting a second training sample from the training sample set, shielding a loss function of a key point position prediction result used for restraining the student network from approaching a key point position prediction result of a teacher network and a loss function of a key point availability prediction result used for restraining the student network from approaching the teacher network in the second total loss function under the condition that the second training sample is the first type training sample, inputting the second training sample into the student network to obtain a third key point position prediction result and a corresponding third key point availability prediction result, inputting the third key point position prediction result and the third key point availability prediction result into the second total loss function to obtain a second value of the second total loss function, and feeding back the second value of the second total loss function to the student network to train the student network by using the first type training sample.
In this embodiment, when training the student network using the training image labeled with the keypoints, the teacher network may be masked by adjusting the weights to guide the relevant loss function of the student network, for example, the weight parameter α in the second total loss function L may be adjusted to 1 to mask the loss function L used to constrain the keypoint availability prediction result of the student network to be close to the keypoint availability prediction result of the teacher network SG_class And a loss function L for constraining the student network's keypoint location prediction result to approach the teacher network's keypoint location prediction result S_mse
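The weighted second total loss can be sketched in NumPy as follows. The component formulas are not reproduced in this text, so this is a minimal sketch under stated assumptions: mean-reduced binary cross-entropy for the availability terms and mean-reduced squared error for the confidence-map terms (the loss types named in the description), with the terms grouped as in the embodiment of Fig. 2, where $L_{ST\_class}$ and $L_{kd}$ compare the student with the teacher and $L_{SG\_class}$ and $L_{S\_mse}$ compare the student with the labels. The names `bce`, `mse` and `second_total_loss` are hypothetical.

```python
import numpy as np

EPS = 1e-7  # assumed clipping constant that keeps log() finite


def bce(target, prob):
    """Mean binary cross-entropy; `target` may be hard labels or teacher probabilities."""
    prob = np.clip(prob, EPS, 1.0 - EPS)
    return float(-np.mean(target * np.log(prob) + (1.0 - target) * np.log(1.0 - prob)))


def mse(target_maps, pred_maps):
    """Mean squared error over confidence maps shaped (K, H, W)."""
    return float(np.mean((pred_maps - target_maps) ** 2))


def second_total_loss(alpha, student_avail, student_maps,
                      teacher_avail, teacher_maps,
                      gt_avail=None, gt_maps=None):
    """L = alpha*(L_ST_class + L_kd) + (1 - alpha)*(L_SG_class + L_S_mse)."""
    # Teacher-guided terms (student vs. teacher outputs).
    l_st_class = bce(teacher_avail, student_avail)
    l_kd = mse(teacher_maps, student_maps)
    # Label-guided terms (student vs. ground truth); zero for unlabeled images.
    if gt_avail is None or gt_maps is None:
        l_sg_class = l_s_mse = 0.0
    else:
        l_sg_class = bce(gt_avail, student_avail)
        l_s_mse = mse(gt_maps, student_maps)
    return alpha * (l_st_class + l_kd) + (1.0 - alpha) * (l_sg_class + l_s_mse)
```

With α = 1 the label-guided terms drop out entirely, so an unlabeled image contributes only through the teacher; with α = 0 only the labels remain.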
Based on the same inventive concept as the keypoint detection model training method shown in Fig. 1, an embodiment of the invention further provides a keypoint detection model training device, which includes a keypoint detection model obtained by training a teacher network and a student network with the keypoint detection model training method of any of the above embodiments. Because the device solves the problem on a principle similar to that of the method, its implementation can refer to the implementation of the method, and repeated details are omitted.
In addition, an embodiment of the invention further provides a keypoint detection method, which may include the following steps:
S210: acquiring a keypoint detection model obtained by training a teacher network and a student network with the keypoint detection model training method of any of the above embodiments;
S220: detecting keypoints in an image to be detected by using the acquired keypoint detection model.
In a further embodiment, step S220 may specifically include: S221, detecting keypoints in the image to be detected in real time by using the trained student network in the acquired keypoint detection model. Because the student network can have a simpler structure, it detects faster and is therefore better suited to real-time detection.
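At inference time the student network outputs, per keypoint, a confidence map and an availability probability. The description does not say how positions are read out of the maps; the sketch below assumes the common argmax decoding and an availability threshold of 0.5, both of which are assumptions, and the function name `decode_keypoints` is hypothetical.

```python
import numpy as np


def decode_keypoints(conf_maps, avail_probs, threshold=0.5):
    """Turn per-keypoint confidence maps (K, H, W) into (x, y) positions.

    Keypoints whose predicted availability is below `threshold` are
    reported as None so that downstream code can skip them.
    """
    results = []
    for k in range(conf_maps.shape[0]):
        if avail_probs[k] < threshold:
            results.append(None)  # predicted unavailable (e.g. occluded)
            continue
        # Row index is y, column index is x.
        y, x = np.unravel_index(np.argmax(conf_maps[k]), conf_maps[k].shape)
        results.append((int(x), int(y)))
    return results
```

Returning None for unavailable keypoints is one way to pass the availability information on to later processing instead of silently emitting a meaningless position.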
In addition, an embodiment of the invention further provides an electronic device comprising a memory, a processor, and a computer program stored in the memory and executable on the processor; when executing the program, the processor implements the steps of the keypoint detection model training method of any of the above embodiments or of the keypoint detection method of any of the above embodiments.
In addition, an embodiment of the present application further provides a computer-readable storage medium storing a computer program that, when executed by a processor, implements the steps of the keypoint detection model training method of any of the above embodiments or of the keypoint detection method of any of the above embodiments.
The above method is described below with reference to a specific embodiment. It should be noted, however, that this specific embodiment serves only to better illustrate the present application and is not meant to unduly limit it.
FIG. 2 is a flow chart of a method for training a keypoint detection model in accordance with an embodiment of the present application. Referring to fig. 2, the keypoint detection model training method of an embodiment may include the following process:
1) First, a subset of images (the training images of training samples) is selected from the training set (training sample set), and their keypoints are labeled with an annotation tool to form the keypoint training set for training the teacher network. The labeled training set obtained in this way contains N training images; for each training image $I_i$, K keypoints and the availability of each keypoint are labeled.
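The keypoint position set $P_i$ of the i-th training image $I_i$ can be expressed as:
$$P_i=\{z_1,z_2,\dots,z_K\}\in\mathbb{R}^{K\times 2},$$
where $z_1,\dots,z_K$, with $z_k=(x_k,y_k)$, are the positions of the K keypoints and $\mathbb{R}^{K\times 2}$ denotes the set of K×2 real matrices.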
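The keypoint availability set $\beta_i$ of the i-th training image $I_i$ can be expressed as:
$$\beta_i=\{\beta_1,\beta_2,\dots,\beta_K\}\in\mathbb{R}^{K},$$
where $\beta_1,\dots,\beta_K$ are the availabilities of the K keypoints and $\mathbb{R}^{K}$ denotes the K-dimensional real space.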
$\beta_k\in\{0,1\}$ can be used to indicate the availability of each keypoint: when a position is occluded, lies outside the image, or the like, the keypoint at that position is labeled $\beta_k=0$, and keypoints at other positions are labeled $\beta_k=1$. A confidence map $m_k$ is used to represent the position of the k-th keypoint $z_k=(x_k,y_k)$.
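Turning one image's annotations into $P_i$ and $\beta_i$ can be sketched as follows. The input format (an (x, y) tuple per visible keypoint, None for an occluded or out-of-image keypoint) and the (0, 0) placeholder position for unavailable keypoints are assumptions for illustration; `encode_annotation` is a hypothetical name.

```python
import numpy as np


def encode_annotation(points):
    """Convert a list of K annotations into (P_i, beta_i).

    `points` holds an (x, y) tuple for a visible keypoint, or None for a
    keypoint that is occluded or outside the image.
    """
    K = len(points)
    P = np.zeros((K, 2), dtype=np.float32)   # P_i in R^{K x 2}
    beta = np.zeros(K, dtype=np.float32)     # beta_i in R^K
    for k, pt in enumerate(points):
        if pt is not None:
            P[k] = pt
            beta[k] = 1.0                    # available keypoint
        # Unavailable keypoints keep beta = 0 and the placeholder (0, 0).
    return P, beta
```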
In the x-y coordinate system, the value $m_k(x,y)$ of the confidence map at position (x, y) can be expressed as:
where $x_k$ is the x-coordinate of the k-th keypoint, $y_k$ is its y-coordinate, and σ is a parameter that controls the extent and size of the keypoint's peak on the confidence map.
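The exact formula for $m_k(x,y)$ is not reproduced in this text. A common choice, assumed here and not confirmed by the source, is an unnormalized Gaussian $m_k(x,y)=\exp\big(-((x-x_k)^2+(y-y_k)^2)/(2\sigma^2)\big)$, which equals 1 at the keypoint and decays with distance at a rate set by σ. The sketch builds such a map with NumPy; the name `confidence_map` and the default σ = 2 are hypothetical.

```python
import numpy as np


def confidence_map(x_k, y_k, height, width, sigma=2.0):
    """Gaussian confidence map peaking at 1.0 at keypoint (x_k, y_k).

    Assumed form exp(-d^2 / (2 sigma^2)); the source's exact
    normalisation (e.g. 2*sigma^2 vs sigma^2) is not given here.
    """
    ys, xs = np.mgrid[0:height, 0:width]   # ys: row index (y), xs: column index (x)
    d2 = (xs - x_k) ** 2 + (ys - y_k) ** 2
    return np.exp(-d2 / (2.0 * sigma ** 2))
```

A larger σ spreads the peak over more pixels, which matches the description of σ as controlling the extent of the keypoint's response on the map.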
2) The teacher network is chosen mainly for keypoint detection accuracy, so a larger, more computationally demanding network with good performance can be selected. The teacher network is trained with the labeled keypoints. A teacher network trained from such a complex model generally infers slowly and, when deployed, places high demands on resources such as memory and GPU memory.
During training, the teacher network can use a cross-entropy loss to constrain the classification learning of keypoint availability. For an input training image i, the loss function the teacher network uses to learn keypoint availability can be expressed as:
where $\beta_k$ is the ground-truth label indicating whether the labeled k-th keypoint is available, and $\hat{\beta}^{t}_{k}$ is the probability, predicted by the teacher network, that the k-th keypoint is available.
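The availability loss $L_{T\_class}$ can be sketched as a binary cross-entropy over the K keypoints of one image. Averaging over K (rather than summing) and the clipping constant are assumptions; the source formula is not reproduced here. `availability_loss` is a hypothetical name.

```python
import numpy as np


def availability_loss(beta, beta_hat, eps=1e-7):
    """Binary cross-entropy L_T_class averaged over the K keypoints of one image.

    beta:     ground-truth availability labels (0 or 1), shape (K,)
    beta_hat: predicted availability probabilities, shape (K,)
    """
    p = np.clip(beta_hat, eps, 1.0 - eps)   # avoid log(0)
    return float(-np.mean(beta * np.log(p) + (1.0 - beta) * np.log(1.0 - p)))
```

For example, a prediction of 0.5 for every keypoint costs ln 2 ≈ 0.693 regardless of the labels, and confident correct predictions cost less than confident wrong ones.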
During training, the teacher network can use a mean squared error (MSE) loss function to constrain keypoint location learning. The loss function the teacher network uses to learn keypoint locations can be expressed as:
where $m_k$ is the keypoint confidence map obtained from image annotation and $m_k^{t}$ is the keypoint confidence map predicted by the teacher network.
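The location loss $L_{T\_mse}$ compares predicted and labeled confidence maps pixel by pixel. The sketch below averages over all keypoints and pixels; whether the source sums or averages is not stated here, and the two differ only by a constant factor that can be absorbed into the learning rate or loss weights. `location_loss` is a hypothetical name.

```python
import numpy as np


def location_loss(m, m_hat):
    """Mean squared error L_T_mse between labeled and predicted confidence maps.

    Both arrays are shaped (K, H, W): one map per keypoint.
    """
    return float(np.mean((m_hat - m) ** 2))
```

Worked example: for one 2×2 map whose prediction is off by 1 at a single pixel, the loss is 1/4 = 0.25.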
Finally, the total loss function of the training teacher network may be expressed as:
$$L_T=L_{T\_class}+L_{T\_mse}$$
3) For the student network, a network with fewer parameters and a relatively simple structure can be selected, one that can run in real time in actual deployment and has lower requirements on resources such as memory and GPU memory.
4) The student network can be learned from the teacher network by distillation. Let X be the input training image. If X has labeled keypoints, the labels are denoted G. If the image fed into the overall network has no keypoint labels, it can first be fed into the teacher network to obtain a relatively accurate keypoint result Y. For the same input training image X, the learning goal of the student network is then that its output approaches the ground truth G and the teacher network's output Y.
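How the supervision targets for one student training step might be assembled is sketched below. The `teacher` callable returning an (availability probabilities, confidence maps) pair, the dictionary layout, and the name `build_targets` are assumptions for illustration; the trained teacher is only run forward here and is not updated.

```python
from typing import Callable, Optional
import numpy as np


def build_targets(image: np.ndarray,
                  teacher: Callable[[np.ndarray], tuple],
                  labels: Optional[tuple] = None) -> dict:
    """Collect the supervision targets for one student training step.

    teacher(image) is assumed to return (availability_probs, confidence_maps),
    i.e. the teacher output Y. `labels` is the ground truth G as
    (availability, confidence_maps), or None for an unlabeled image.
    """
    teacher_avail, teacher_maps = teacher(image)   # Y: available for every image
    targets = {"teacher": (teacher_avail, teacher_maps)}
    if labels is not None:
        targets["ground_truth"] = labels           # G: only for labeled images
    return targets
```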
The labeling results of the training images, $\beta_k$ and $m_k$, guide the student network's learning of keypoint availability and keypoint accuracy through the loss functions $L_{SG\_class}$ and $L_{S\_mse}$, which can be expressed as:
where $\hat{\beta}^{s}_{k}$ is the probability, predicted by the student network, that the k-th keypoint is available.
The teacher network's own outputs, $\hat{\beta}^{t}_{k}$ and $m_k^{t}$, guide the student network's learning of keypoint availability and keypoint accuracy through the loss functions $L_{ST\_class}$ and $L_{kd}$, which can be expressed as:
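The formulas for these four terms are not reproduced in this text. As a sketch only, assuming standard binary cross-entropy for the availability terms, mean squared error for the confidence-map terms, averaging over the K keypoints and over H×W map pixels (H and W are assumed map dimensions, not symbols from the source), and the notation defined above, they could take the following form:

```latex
\begin{aligned}
L_{SG\_class} &= -\frac{1}{K}\sum_{k=1}^{K}\Big[\beta_k\log\hat{\beta}^{s}_{k} + (1-\beta_k)\log\big(1-\hat{\beta}^{s}_{k}\big)\Big],\\
L_{S\_mse} &= \frac{1}{KHW}\sum_{k=1}^{K}\sum_{x,y}\big(m^{s}_{k}(x,y) - m_{k}(x,y)\big)^2,\\
L_{ST\_class} &= -\frac{1}{K}\sum_{k=1}^{K}\Big[\hat{\beta}^{t}_{k}\log\hat{\beta}^{s}_{k} + (1-\hat{\beta}^{t}_{k})\log\big(1-\hat{\beta}^{s}_{k}\big)\Big],\\
L_{kd} &= \frac{1}{KHW}\sum_{k=1}^{K}\sum_{x,y}\big(m^{s}_{k}(x,y) - m^{t}_{k}(x,y)\big)^2.
\end{aligned}
```

In this assumed form the teacher-guided terms use the teacher's probabilities and maps as soft targets in place of the labels.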
These loss constraints drive the keypoints output by the student network toward the output of the teacher network. In actual training, when a training image has labeling information, both the labeled keypoints and the keypoint information predicted by the teacher network can guide the training of the student network; when a training image has no keypoint labels, the keypoint information predicted by the teacher network alone guides the training. The final loss function of the network can be expressed as:
$$L=\alpha\,(L_{ST\_class}+L_{kd})+(1-\alpha)\,(L_{SG\_class}+L_{S\_mse}),$$
where α is a weight parameter whose magnitude controls the proportion of the two groups of loss terms. When the input image has ground-truth labels, α can be set to 0.5, and the relative influence of the teacher network and the ground truth can be adjusted by changing α. When the input image has no ground-truth labels, α can be set to 1, so that only the teacher network guides the learning of the student network.
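The per-sample choice of α described above can be sketched as follows. The function names `select_alpha` and `weighted_total`, and the range check on `mix`, are hypothetical; the values 1 for unlabeled images and 0.5 for labeled images come from the embodiment.

```python
def select_alpha(has_ground_truth: bool, mix: float = 0.5) -> float:
    """Weight alpha for L = alpha*(L_ST_class + L_kd) + (1 - alpha)*(L_SG_class + L_S_mse).

    Unlabeled images: alpha = 1, so only the teacher-guided terms remain.
    Labeled images:   alpha = mix (0.5 in the embodiment), blending the
                      teacher-guided and ground-truth-guided terms.
    """
    if not has_ground_truth:
        return 1.0
    if not 0.0 <= mix <= 1.0:
        raise ValueError("mix must lie in [0, 1]")
    return mix


def weighted_total(alpha, l_st_class, l_kd, l_sg_class, l_s_mse):
    """Combine the four loss values with the weight alpha."""
    return alpha * (l_st_class + l_kd) + (1.0 - alpha) * (l_sg_class + l_s_mse)
```

For instance, with α = 0.5 and term values 1, 1, 3, 3 the total is 0.5·2 + 0.5·6 = 4.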
The trained student network provides accurate keypoint detection and also indicates whether each keypoint is available, which can guide the subsequent use of the keypoints.
In this embodiment, through knowledge distillation the teacher network provides the student network with guidance on both keypoint positions and keypoint availability, which better guides the student network's keypoint learning, particularly on low-quality images. Knowledge distillation thus addresses the difficulty of labeling keypoint training data, especially low-quality data: both labeled and unlabeled data can be used for training, which increases the richness of the training data. In practical applications, knowledge distillation reduces the time spent labeling keypoint data, so the model can be trained and iterated quickly, and the trained keypoint model generalizes well. The method also handles keypoint detection on low-quality data well and provides guidance for the subsequent use of the keypoints. Knowledge distillation transfers what the large model has learned to the small model, so the small model infers quickly while its keypoint detection accuracy approaches that of the large model. The method and device therefore address the labeling of large amounts of training data in keypoint detection, keypoint detection on low-quality images, and the accuracy and deployment speed of the keypoint model.
In summary, the keypoint detection model training method, the keypoint detection method and device, the electronic device, and the computer-readable storage medium of the embodiments of the present invention train the teacher network with images whose keypoint information is known, and then train the student network with a loss function that includes the teacher network's prediction results; that is, images without known keypoint information are used with the teacher network's predictions as targets. This distills the knowledge learned by the teacher network into the student network and allows the student network to be trained on unlabeled training samples, reducing the time-consuming and cumbersome work of labeling image keypoints. Moreover, because the teacher network and the student network learn not only keypoint positions but also keypoint availability, the student network detects well even on low-quality images. Further, because the teacher network's predictions are accurate, using them as the student network's learning targets makes keypoint learning stable. In addition, because the student network has a simpler network structure and a shorter learning time, the real-time performance of keypoint detection is improved.
In the description of the present specification, reference to the terms "one embodiment," "one particular embodiment," "some embodiments," "for example," "an example," "a particular example," or "some examples," etc., means that a particular feature, structure, material, or characteristic described in connection with the embodiment or example is included in at least one embodiment or example of the invention. In this specification, schematic representations of the above terms do not necessarily refer to the same embodiments or examples. Furthermore, the particular features, structures, materials, or characteristics described may be combined in any suitable manner in any one or more embodiments or examples. The order of steps involved in the embodiments is illustrative of the practice of the invention, and is not limited and may be suitably modified as desired.
It will be appreciated by those skilled in the art that embodiments of the present invention may be provided as a method, system, or computer program product. Accordingly, the present invention may take the form of an entirely hardware embodiment, an entirely software embodiment or an embodiment combining software and hardware aspects. Furthermore, the present invention may take the form of a computer program product embodied on one or more computer-usable storage media (including, but not limited to, disk storage, CD-ROM, optical storage, and the like) having computer-usable program code embodied therein.
The present invention is described with reference to flowchart illustrations and/or block diagrams of methods, apparatus (systems) and computer program products according to embodiments of the invention. It will be understood that each flow and/or block of the flowchart illustrations and/or block diagrams, and combinations of flows and/or blocks in the flowchart illustrations and/or block diagrams, can be implemented by computer program instructions. These computer program instructions may be provided to a processor of a general purpose computer, special purpose computer, embedded processor, or other programmable data processing apparatus to produce a machine, such that the instructions, which execute via the processor of the computer or other programmable data processing apparatus, create means for implementing the functions specified in the flowchart flow or flows and/or block diagram block or blocks.
These computer program instructions may also be stored in a computer-readable memory that can direct a computer or other programmable data processing apparatus to function in a particular manner, such that the instructions stored in the computer-readable memory produce an article of manufacture including instruction means which implement the function specified in the flowchart flow or flows and/or block diagram block or blocks.
These computer program instructions may also be loaded onto a computer or other programmable data processing apparatus to cause a series of operational steps to be performed on the computer or other programmable apparatus to produce a computer implemented process such that the instructions which execute on the computer or other programmable apparatus provide steps for implementing the functions specified in the flowchart flow or flows and/or block diagram block or blocks.
The foregoing embodiments are provided to illustrate the general principles of the invention and are not intended to limit its scope or to restrict the invention to the particular embodiments; any modifications, equivalent substitutions, improvements, and the like made within the spirit and principles of the invention are intended to fall within its scope.

Claims (10)

1. A keypoint detection model training method, characterized by comprising the following steps:
acquiring a training sample set, the training sample set comprising first-type training samples and second-type training samples, wherein a first-type training sample comprises a training image, the positions of keypoints in the training image, and the availability of the corresponding keypoints, and a second-type training sample comprises a training image;
training a teacher network by using first-type training samples and a first total loss function to obtain a trained teacher network, wherein the first total loss function includes a loss function for constraining keypoint availability and a loss function for constraining keypoint locations;
selecting a first training sample from the training sample set and, if the first training sample is a second-type training sample, inputting the first training sample into the trained teacher network to obtain a first keypoint location prediction result and a corresponding first keypoint availability prediction result, and inputting the first training sample into a student network to obtain a second keypoint location prediction result and a corresponding second keypoint availability prediction result;
inputting the first keypoint location prediction result, the first keypoint availability prediction result, the second keypoint location prediction result, and the second keypoint availability prediction result into a second total loss function to obtain a first value of the second total loss function; feeding the first value of the second total loss function back to the student network so as to train the student network with second-type training samples by knowledge distillation from the trained teacher network; and obtaining a keypoint detection model from the trained student network or from a network comprising the trained student network and the trained teacher network; wherein the second total loss function includes a loss function for constraining the student network's keypoint location prediction result to approach the teacher network's keypoint location prediction result and a loss function for constraining the student network's keypoint availability prediction result to approach the teacher network's keypoint availability prediction result;
wherein, before the keypoint detection model is obtained from the trained student network or from the network comprising the trained student network and the trained teacher network, the method further comprises:
selecting a second training sample from the training sample set and, if the second training sample is a first-type training sample, inputting the second training sample into the student network to obtain a third keypoint location prediction result and a corresponding third keypoint availability prediction result; inputting the third keypoint location prediction result and the third keypoint availability prediction result into the second total loss function to obtain a second value of the second total loss function; and feeding the second value of the second total loss function back to the student network to train the student network with first-type training samples.
2. The keypoint detection model training method according to claim 1, wherein the second total loss function further comprises a loss function for constraining the student network's keypoint location prediction result to approach the known keypoint locations and a loss function for constraining the student network's keypoint availability prediction result to approach the known keypoint availability.
3. The keypoint detection model training method according to claim 2, wherein selecting a second training sample from the training sample set, inputting the second training sample into the student network to obtain a third keypoint location prediction result and a corresponding third keypoint availability prediction result if the second training sample is a first-type training sample, inputting the third keypoint location prediction result and the third keypoint availability prediction result into the second total loss function to obtain a second value of the second total loss function, and feeding the second value of the second total loss function back to the student network to train the student network with first-type training samples comprises:
selecting a second training sample from the training sample set; if the second training sample is a first-type training sample, masking, in the second total loss function, the loss function for constraining the student network's keypoint location prediction result to approach the teacher network's keypoint location prediction result and the loss function for constraining the student network's keypoint availability prediction result to approach the teacher network's keypoint availability prediction result; inputting the second training sample into the student network to obtain a third keypoint location prediction result and a corresponding third keypoint availability prediction result; inputting the third keypoint location prediction result and the third keypoint availability prediction result into the second total loss function to obtain a second value of the second total loss function; and feeding the second value of the second total loss function back to the student network to train the student network with first-type training samples.
4. The keypoint detection model training method according to claim 2, wherein
the loss function for constraining the student network's keypoint location prediction result to approach the teacher network's keypoint location prediction result is a mean squared error loss function, and/or the loss function for constraining the student network's keypoint location prediction result to approach the known keypoint locations is a mean squared error loss function;
the loss function for constraining the student network's keypoint availability prediction result to approach the teacher network's keypoint availability prediction result is a cross-entropy loss function, and/or the loss function for constraining the student network's keypoint availability prediction result to approach the known keypoint availability is a cross-entropy loss function.
5. The keypoint detection model training method according to any of the claims 1 to 4, characterized in that each training sample of the first type comprises a plurality of positions of keypoints and/or that the computing resources required by the teacher network are larger than those required by the student network.
6. A key point detection method, comprising:
acquiring a key point detection model obtained by training a teacher network and a student network by using the key point detection model training method according to any one of claims 1 to 5;
And detecting the key points in the image to be detected by using the acquired key point detection model.
7. The keypoint detection method as claimed in claim 6, wherein detecting the keypoints in the image to be detected using the acquired keypoint detection model comprises:
detecting the keypoints in the image to be detected in real time by using the trained student network in the acquired keypoint detection model.
8. A keypoint detection apparatus, characterized by comprising a keypoint detection model obtained by training a teacher network and a student network with the keypoint detection model training method according to any one of claims 1 to 5.
9. An electronic device comprising a memory, a processor and a computer program stored on the memory and executable on the processor, characterized in that the processor implements the steps of the method according to any one of claims 1 to 7 when executing the program.
10. A computer readable storage medium, on which a computer program is stored, characterized in that the program, when being executed by a processor, implements the steps of the method according to any one of claims 1 to 7.
CN202110263320.6A 2021-03-11 2021-03-11 Key point detection model training method, key point detection method and device Active CN112990298B (en)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202110263320.6A CN112990298B (en) 2021-03-11 2021-03-11 Key point detection model training method, key point detection method and device


Publications (2)

Publication Number Publication Date
CN112990298A CN112990298A (en) 2021-06-18
CN112990298B true CN112990298B (en) 2023-11-24

Family

ID=76334854







Similar Documents

Publication Publication Date Title
CN112990298B (en) Key point detection model training method, key point detection method and device
Bae et al. Rethinking class activation mapping for weakly supervised object localization
Hao et al. An end-to-end architecture for class-incremental object detection with knowledge distillation
CN109741332B (en) Man-machine cooperative image segmentation and annotation method
CN110674714B (en) Human face and human face key point joint detection method based on transfer learning
CN112052787B (en) Target detection method and device based on artificial intelligence and electronic equipment
CN110059672B (en) Method for class-enhanced learning of microscope cell image detection model
CN113269073B (en) Ship multi-target tracking method based on YOLO V5 algorithm
CN110506276A (en) Efficient image analysis using environmental sensor data
US20200134377A1 (en) Logo detection
CN109271539B (en) Image automatic labeling method and device based on deep learning
CN106845430A (en) Pedestrian detection and tracking based on faster region-based convolutional neural networks
CN112734803B (en) Single target tracking method, device, equipment and storage medium based on character description
Gu et al. Class-incremental instance segmentation via multi-teacher networks
CN116563738A (en) Uncertainty-based multi-stage guided small target semi-supervised learning detection method
Li et al. Real-time detection tracking and recognition algorithm based on multi-target faces
CN115393625A (en) Semi-supervised training of image segmentation from coarse markers
Fouad et al. A fish detection approach based on BAT algorithm
CN116681961A (en) Weak supervision target detection method based on semi-supervision method and noise processing
CN114898290A (en) Real-time detection method and system for marine ship
CN114708645A (en) Object identification device and object identification method
CN112287938A (en) Text segmentation method, system, device and medium
Liu et al. An improved method for small target recognition based on faster RCNN
CN110751197A (en) Picture classification method, picture model training method and equipment
Gong et al. Water Surface Object Detection Based on Neural Style Learning Algorithm

Legal Events

Date Code Title Description
PB01 Publication
SE01 Entry into force of request for substantive examination
GR01 Patent grant