CN113283598A - Model training method and device, storage medium and electronic equipment - Google Patents


Info

Publication number
CN113283598A
Authority
CN
China
Prior art keywords
processing result
regressor
updating
target
key point
Legal status
Pending
Application number
CN202110656444.0A
Other languages
Chinese (zh)
Inventor
龙明盛
江俊广
刘裕峰
蔡东阳
郭小燕
郑文
王建民
Current Assignee
Tsinghua University
Beijing Dajia Internet Information Technology Co Ltd
Original Assignee
Tsinghua University
Beijing Dajia Internet Information Technology Co Ltd
Application filed by Tsinghua University and Beijing Dajia Internet Information Technology Co Ltd
Priority to CN202110656444.0A
Publication of CN113283598A
Legal status: Pending (current)

Classifications

    • G PHYSICS
    • G06 COMPUTING; CALCULATING OR COUNTING
    • G06N COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00 Computing arrangements based on biological models
    • G06N3/02 Neural networks
    • G06N3/08 Learning methods
    • G PHYSICS
    • G06 COMPUTING; CALCULATING OR COUNTING
    • G06F ELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00 Pattern recognition
    • G06F18/20 Analysing
    • G06F18/21 Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214 Generating training patterns; Bootstrap methods, e.g. bagging or boosting

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • Evolutionary Computation (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Computational Linguistics (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Health & Medical Sciences (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Evolutionary Biology (AREA)
  • General Health & Medical Sciences (AREA)
  • Molecular Biology (AREA)
  • Computing Systems (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Data Exchanges In Wide-Area Networks (AREA)

Abstract

The present disclosure relates to a model training method and apparatus, a storage medium, and an electronic device. The model training method is applied to a keypoint detection model that includes a feature extractor, a regressor, and an adversarial regressor, and comprises: determining a target object to be updated in the keypoint detection model, where the target object includes the feature extractor, the regressor, and the adversarial regressor; and updating the network parameters of the target object so that the keypoint detection model migrates from the source domain to the target domain. This at least solves the problem that domain adaptation methods in the related art cannot migrate a keypoint detection model between different data domains.

Description

Model training method and device, storage medium and electronic equipment
Technical Field
The present disclosure relates to the field of model training technologies, and in particular, to a model training method and apparatus, a storage medium, and an electronic device.
Background
In the related art, annotating keypoints for deep learning models is difficult and costly in real scenes but easy and cheap in virtual scenes, so a promising current direction is to transfer models trained on virtual data to real data.
However, a huge domain gap exists between virtual data and real data, and current domain adaptation methods are mainly designed for classification problems. They are of little help for regression problems such as keypoint detection, so keypoint detection models cannot be migrated between different data domains.
Disclosure of Invention
The present disclosure provides a model training method and apparatus, a storage medium, and an electronic device, to at least solve the problem that domain adaptation methods in the related art cannot migrate a keypoint detection model between different data domains. The technical solutions of the disclosure are as follows:
According to a first aspect of the embodiments of the present disclosure, a model training method is provided. The method is applied to a keypoint detection model that includes a feature extractor, a regressor, and an adversarial regressor, and comprises: determining a target object to be updated in the keypoint detection model, where the target object includes the feature extractor, the regressor, and the adversarial regressor; and updating the network parameters of the target object so that the keypoint detection model migrates from the source domain to the target domain.
Optionally, the step of updating the network parameters of the target object includes: updating first network parameters in the feature extractor, the regressor, and the adversarial regressor.
Optionally, the step of updating the first network parameters in the feature extractor, the regressor, and the adversarial regressor includes: normalizing the heatmaps output by the regressor and the adversarial regressor to obtain a first processing result; calculating the ratio of the heatmap corresponding to each keypoint label on the source domain to the sum of the heatmaps corresponding to all keypoint labels on the source domain to obtain a second processing result; performing a divergence calculation on the first processing result and the second processing result in the spatial dimension to obtain a third processing result; and updating the first network parameters in the feature extractor, the regressor, and the adversarial regressor based on the third processing result and recomputing the third processing result, until the third processing result satisfies a first preset condition, at which point updating of the first network parameters stops.
Optionally, the step of updating the network parameters of the target object includes: updating second network parameters in the adversarial regressor to maximize the disparity between the predictions of the adversarial regressor and those of the regressor on the target domain.
Optionally, the step of updating the second network parameters in the adversarial regressor includes: acquiring the plurality of heatmaps, corresponding to a plurality of keypoints, that the regressor outputs on the target domain, and calculating the sum of the remaining heatmaps corresponding to the keypoints other than a target keypoint; calculating the ratio of the first heatmap, corresponding to the target keypoint, to the sum of the remaining heatmaps to obtain a fourth processing result; acquiring the second heatmap, corresponding to the first heatmap, that the adversarial regressor outputs on the target domain, and normalizing the second heatmap to obtain a fifth processing result; performing a divergence calculation on the fourth processing result and the fifth processing result in the spatial dimension to obtain a sixth processing result; and updating the second network parameters in the adversarial regressor based on the sixth processing result and recomputing the sixth processing result, until the sixth processing result satisfies a second preset condition, at which point updating of the second network parameters stops.
Optionally, the step of updating the network parameters of the target object includes: updating third network parameters in the feature extractor to minimize the disparity between the predictions of the adversarial regressor and those of the regressor on the target domain.
Optionally, the step of updating the third network parameters in the feature extractor includes: acquiring the plurality of heatmaps, corresponding to a plurality of keypoints, that the regressor outputs on the target domain, and calculating the sum of the remaining heatmaps corresponding to the keypoints other than a target keypoint; calculating the ratio of the first heatmap, corresponding to the target keypoint, to the sum of the remaining heatmaps to obtain a seventh processing result; acquiring the second heatmap, corresponding to the first heatmap, that the adversarial regressor outputs on the target domain, and normalizing the second heatmap to obtain an eighth processing result; performing a divergence calculation on the seventh processing result and the eighth processing result in the spatial dimension to obtain a ninth processing result; and updating the third network parameters in the feature extractor based on the ninth processing result and recomputing the ninth processing result, until the ninth processing result satisfies a third preset condition, at which point updating of the third network parameters stops.
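The translated wording of the fourth processing result ("ratio of the first heatmap to the sum of the remaining heatmaps") is ambiguous, so the NumPy sketch below rests on one assumed reading: for each keypoint k, the regressor's heatmaps for all other keypoints are summed and divided by that sum's spatial total, giving a distribution that puts mass where the other keypoints are. It further assumes that the fifth processing result is a spatial Softmax, that the divergence is KL(fourth || fifth), and that the adversarial regressor f' is updated to minimize this value, which pulls f' away from the regressor's prediction for keypoint k and so enlarges the disparity (a minimization with the opposite physical meaning of target 3). All function names are hypothetical, and gradients and the parameter update are omitted.

```python
import numpy as np


def spatial_softmax(logits):
    """Softmax over the H*W positions of each heatmap; input shape (K, H, W)."""
    k = logits.shape[0]
    flat = logits.reshape(k, -1)
    flat = flat - flat.max(axis=1, keepdims=True)  # numerical stability
    e = np.exp(flat)
    return (e / e.sum(axis=1, keepdims=True)).reshape(logits.shape)


def false_targets(reg_heatmaps, eps=1e-8):
    """Fourth processing result under an ASSUMED reading: for each keypoint k,
    sum the regressor heatmaps of all other keypoints, then divide by the spatial total."""
    hm = np.clip(reg_heatmaps, 0.0, None)             # raw heatmaps may be negative
    rest = hm.sum(axis=0, keepdims=True) - hm + eps   # sum over j != k, shape (K, H, W)
    totals = rest.reshape(rest.shape[0], -1).sum(axis=1)
    return rest / totals[:, None, None]


def spatial_kl(p, q, eps=1e-12):
    """KL(p_k || q_k) summed over spatial positions, averaged over keypoints."""
    k = p.shape[0]
    p2, q2 = p.reshape(k, -1), q.reshape(k, -1)
    return float(np.mean(np.sum(p2 * (np.log(p2 + eps) - np.log(q2 + eps)), axis=1)))


def adversarial_regressor_loss(reg_heatmaps, adv_logits):
    """Sixth processing result: spatial KL between the false targets built from f
    and the Softmax-normalized output of f' (the fifth processing result)."""
    return spatial_kl(false_targets(reg_heatmaps), spatial_softmax(adv_logits))


rng = np.random.default_rng(0)
reg_out = rng.normal(size=(3, 8, 8))  # f's heatmaps for K=3 keypoints on one target image
adv_out = rng.normal(size=(3, 8, 8))  # f''s raw heatmaps for the same image
print(adversarial_regressor_loss(reg_out, adv_out))
```

The construction needs at least two keypoints; with K = 1 the "remaining" sum is empty and only the small `eps` floor keeps the division defined.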
According to a second aspect of the embodiments of the present disclosure, a model training apparatus is provided. The apparatus is applied to a keypoint detection model that includes a feature extractor, a regressor, and an adversarial regressor, and comprises: a determining unit configured to determine a target object to be updated in the keypoint detection model, where the target object includes the feature extractor, the regressor, and the adversarial regressor; and an updating unit configured to update the network parameters of the target object so that the keypoint detection model migrates from the source domain to the target domain.
Optionally, the updating unit includes: a first updating subunit configured to update the first network parameters in the feature extractor, the regressor, and the adversarial regressor.
Optionally, the first updating subunit includes: a first processing subunit configured to normalize the heatmaps output by the regressor and the adversarial regressor to obtain a first processing result; a first calculating subunit configured to calculate the ratio of the heatmap corresponding to each keypoint label on the source domain to the sum of the heatmaps corresponding to all keypoint labels on the source domain to obtain a second processing result; a second calculation unit configured to perform a divergence calculation on the first processing result and the second processing result in the spatial dimension to obtain a third processing result; and a third calculation unit configured to update the first network parameters in the feature extractor, the regressor, and the adversarial regressor based on the third processing result and recompute the third processing result, stopping the update of the first network parameters when the third processing result satisfies a first preset condition.
Optionally, the updating unit includes: a second updating subunit configured to update the second network parameters in the adversarial regressor to maximize the disparity between the predictions of the adversarial regressor and those of the regressor on the target domain.
Optionally, the second updating subunit includes: an obtaining subunit configured to obtain the plurality of heatmaps, corresponding to a plurality of keypoints, that the regressor outputs on the target domain, and to calculate the sum of the remaining heatmaps corresponding to the keypoints other than the target keypoint; a fourth calculating unit configured to calculate the ratio of the first heatmap corresponding to the target keypoint to the sum of the remaining heatmaps to obtain a fourth processing result; a second processing subunit configured to obtain the second heatmap, corresponding to the first heatmap, that the adversarial regressor outputs on the target domain, and to normalize the second heatmap to obtain a fifth processing result; a fifth calculating subunit configured to perform a divergence calculation on the fourth processing result and the fifth processing result in the spatial dimension to obtain a sixth processing result; and a third updating subunit configured to update the second network parameters in the adversarial regressor based on the sixth processing result and recompute the sixth processing result, stopping the update of the second network parameters when the sixth processing result satisfies a second preset condition.
Optionally, the updating unit includes: a fourth updating subunit configured to update the third network parameters in the feature extractor to minimize the disparity between the predictions of the adversarial regressor and those of the regressor on the target domain.
Optionally, the fourth updating subunit includes: a sixth calculating subunit configured to obtain the plurality of heatmaps, corresponding to a plurality of keypoints, that the regressor outputs on the target domain, and to calculate the sum of the remaining heatmaps corresponding to the keypoints other than the target keypoint; a seventh calculating subunit configured to calculate the ratio of the first heatmap corresponding to the target keypoint to the sum of the remaining heatmaps to obtain a seventh processing result; a third processing subunit configured to obtain the second heatmap, corresponding to the first heatmap, that the adversarial regressor outputs on the target domain, and to normalize the second heatmap to obtain an eighth processing result; an eighth calculating subunit configured to perform a divergence calculation on the seventh processing result and the eighth processing result in the spatial dimension to obtain a ninth processing result; and a fifth updating subunit configured to update the third network parameters in the feature extractor based on the ninth processing result and recompute the ninth processing result, stopping the update of the third network parameters when the ninth processing result satisfies a third preset condition.
According to a third aspect of the embodiments of the present disclosure, an electronic device is provided, including: a processor; and a memory for storing instructions executable by the processor; where the processor is configured to execute the instructions to implement any of the model training methods described above.
According to a fourth aspect of the embodiments of the present disclosure, a computer-readable storage medium is provided; when the instructions in the computer-readable storage medium are executed by a processor of an electronic device, the electronic device is enabled to perform any of the model training methods described above.
According to a fifth aspect of embodiments of the present disclosure, there is provided a computer program product comprising a computer program which, when executed by a processor, implements any of the above-described model training methods.
The technical solutions provided by the embodiments of the present disclosure bring at least the following beneficial effects:
The model training method provided by the embodiments of the present disclosure is applied to a keypoint detection model that includes a feature extractor, a regressor, and an adversarial regressor, and comprises: determining a target object to be updated in the keypoint detection model, where the target object includes the feature extractor, the regressor, and the adversarial regressor; and updating the network parameters of the target object so that the keypoint detection model migrates from the source domain to the target domain. By providing a new domain adaptation method for regression, the keypoint detection model can be migrated between different data domains, which improves the training effect of the keypoint detection model and solves the problem that domain adaptation methods in the related art cannot migrate keypoint detection models between different data domains.
It is to be understood that both the foregoing general description and the following detailed description are exemplary and explanatory only and are not restrictive of the disclosure.
First, to facilitate understanding of the embodiments of the present disclosure, some terms used in the present disclosure are explained below:
KL divergence: the Kullback-Leibler divergence, an asymmetric measure of the difference between two probability distributions P and Q. It equals the expected number of extra bits needed to encode samples drawn from P when using a code optimized for Q instead of one optimized for P.
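As a small numerical illustration of this definition (not part of the disclosure), the function below computes KL(P || Q) in bits for two discrete distributions; swapping the arguments shows that the measure is asymmetric. The function name and example distributions are hypothetical.

```python
import numpy as np


def kl_divergence_bits(p, q):
    """KL(P || Q) in bits for discrete distributions on the same support.

    Assumes q > 0 wherever p > 0; terms with p == 0 contribute nothing (0 * log 0 = 0).
    """
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    mask = p > 0
    return float(np.sum(p[mask] * np.log2(p[mask] / q[mask])))


p = [0.5, 0.5]
q = [0.25, 0.75]
print(kl_divergence_bits(p, q))  # about 0.2075 extra bits per sample
print(kl_divergence_bits(q, p))  # about 0.1887, so KL(P || Q) != KL(Q || P)
```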
Source domain: a data domain with labeled data, consisting of data and their labels; its data distribution differs from the distribution of the data the model sees at test time.
Target domain: a data source without labeled data, consisting of the data the model sees at test time. For example, the source domain often consists of virtual data with computer-generated labels, while the target domain is the data distribution of real scenes.
Disparity: the inconsistency between the predictions of two models given the same input; it may be defined, for example, by the L1 loss.
DD, Disparity Discrepancy: the disparity on the target domain minus the disparity on the source domain.
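The two definitions above can be sketched in a few lines, using the L1 loss as the disparity measure. The function names and toy outputs below are hypothetical; real inputs would be the heatmaps of the regressor and the adversarial regressor on source and target images.

```python
import numpy as np


def disparity(pred_a, pred_b):
    """Disparity of two predictors on the same inputs, measured here with the L1 loss."""
    a = np.asarray(pred_a, dtype=float)
    b = np.asarray(pred_b, dtype=float)
    return float(np.mean(np.abs(a - b)))


def disparity_discrepancy(f_src, f_adv_src, f_tgt, f_adv_tgt):
    """DD: disparity on the target domain minus disparity on the source domain."""
    return disparity(f_tgt, f_adv_tgt) - disparity(f_src, f_adv_src)


# Toy outputs: the two regressors nearly agree on source inputs
# but disagree strongly on target inputs, so DD is large and positive.
f_src, f_adv_src = [1.0, 2.0], [1.1, 2.1]
f_tgt, f_adv_tgt = [1.0, 2.0], [2.0, 0.0]
dd = disparity_discrepancy(f_src, f_adv_src, f_tgt, f_adv_tgt)
print(dd)  # approximately 1.4 (1.5 on target minus 0.1 on source)
```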
MSE (Mean Squared Error) loss: a quadratic loss function commonly used for regression.
Domain adaptation: a research area within machine learning and transfer learning concerned with adapting a model trained on a source domain to a target domain with a different data distribution.
Drawings
The accompanying drawings, which are incorporated in and constitute a part of this specification, illustrate embodiments consistent with the present disclosure and, together with the description, serve to explain the principles of the disclosure and are not to be construed as limiting the disclosure.
FIG. 1 is a flow diagram illustrating a method of model training in accordance with an exemplary embodiment;
FIG. 2 is a schematic diagram illustrating a keypoint detection model in accordance with an exemplary embodiment;
FIG. 3 is a flow diagram illustrating an alternative model training method in accordance with an exemplary embodiment;
FIG. 4 is a flow diagram illustrating another alternative model training method in accordance with an exemplary embodiment;
FIG. 5 is a schematic diagram illustrating another key point detection model training process in accordance with an exemplary embodiment;
FIG. 6 is a flow diagram illustrating yet another alternative model training method in accordance with an exemplary embodiment;
FIG. 7 is a schematic diagram illustrating the output distribution obtained by maximizing the KL divergence in a high-dimensional space in accordance with an exemplary embodiment;
FIG. 8 is a block diagram illustrating a model training apparatus in accordance with an exemplary embodiment;
FIG. 9 is a block diagram illustrating an electronic device in accordance with an example embodiment.
Detailed Description
In order to make the technical solutions of the present disclosure better understood by those of ordinary skill in the art, the technical solutions in the embodiments of the present disclosure will be clearly and completely described below with reference to the accompanying drawings.
It should be noted that the terms "first," "second," and the like in the description and claims of the present disclosure and in the above-described drawings are used for distinguishing between similar elements and not necessarily for describing a particular sequential or chronological order. It is to be understood that the data so used is interchangeable under appropriate circumstances such that the embodiments of the disclosure described herein are capable of operation in sequences other than those illustrated or otherwise described herein. The implementations described in the exemplary embodiments below are not intended to represent all implementations consistent with the present disclosure. Rather, they are merely examples of apparatus and methods consistent with certain aspects of the present disclosure, as detailed in the appended claims.
As model structures and parameter counts grow, deep networks depend heavily on the quantity and quality of training data, which greatly affect algorithm performance. Unfortunately, in most tasks data collection and annotation are difficult, costly, and inefficient, and sufficient training data is hard to obtain. Generated data can quickly provide large amounts of high-precision, diverse data, which helps AI algorithms reach deployment faster and perform better; generated data therefore plays an increasingly important role in deep-learning keypoint detection.
However, generated data, especially data produced with CG-based methods, is limited by the realism and diversity of the 3D models and environments, and there is often a significant domain gap between it and real data. One remedy is to render the generated data before training so as to reduce the difference between virtual and real data. Another important remedy is to reduce the domain gap at the model and feature level, i.e., domain adaptation (DA).
Domain adaptation for keypoint detection mainly covers two aspects: model-level DA and feature-level DA. Through parameter sharing, domain mapping, prior mining, and similar means, model-level DA methods can effectively reduce the domain gap and greatly reduce obvious prediction errors on virtual data as well as inconsistencies between real and virtual data. Feature-level DA methods are to some extent similar to data-level domain adaptation, but the adaptation takes place in the intermediate feature layers, so they can handle higher-dimensional information and more complex scene transfer.
The academic community has studied domain adaptation extensively for classification problems but much less for regression problems. The features of classification problems exhibit a clear cluster structure, whereas those of regression problems are more complex. How to design a domain adaptation algorithm suited to regression problems remains an unsolved technical difficulty.
FIG. 1 is a flowchart of a model training method according to an exemplary embodiment. The method is applied to a keypoint detection model that includes a feature extractor, a regressor, and an adversarial regressor. As shown in FIG. 1, the model training method comprises the following steps:
in step S11, a target object to be updated in the keypoint detection model is determined, where the target object includes the feature extractor, the regressor, and the adversarial regressor;
in step S12, the network parameters of the target object are updated so that the keypoint detection model is migrated from the source domain to the target domain.
The model training method provided by the embodiments of the present disclosure is applied to a keypoint detection model that includes a feature extractor, a regressor, and an adversarial regressor, and comprises: determining a target object to be updated in the keypoint detection model, where the target object includes the feature extractor, the regressor, and the adversarial regressor; and updating the network parameters of the target object so that the keypoint detection model migrates from the source domain to the target domain. By providing a new domain adaptation method for regression, the keypoint detection model can be migrated between different data domains, which improves the training effect of the keypoint detection model and solves the problem that domain adaptation methods in the related art cannot migrate keypoint detection models between different data domains.
Optionally, in the embodiments of the present disclosure, the feature extractor, the regressor, and the adversarial regressor may be of any type; for example, the feature extractor may be a ResNet or HRNet feature extractor.
The present disclosure observes that the output space of keypoint detection is sparse in the probabilistic sense, which is the key difference between domain adaptation for classification and for regression; the disclosure exploits this sparsity to reduce the size of the adversarial regressor's output space. Because it is difficult to shift the mean of a distribution by maximizing the KL divergence in a high-dimensional space, the disclosure instead performs adversarial training by minimizing two objectives with opposite physical meanings. In the domain adaptation problem of keypoint detection, there is a source domain P with labeled data and a target domain Q without labeled data.
As shown in FIG. 2, the embodiments of the present disclosure divide the keypoint detection model into three parts: the feature extractor ψ, the regressor f, and the adversarial regressor f'. Both the regressor and the adversarial regressor output heatmaps, but only the regressor's output is used for keypoint detection at test time, and the adversarial regressor is used only during training. Adversarial training can therefore compute the supremum of the generalization error gap between the two domains, where DD is the Disparity Discrepancy, i.e., the disparity on the target domain minus the disparity on the source domain.
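A minimal NumPy sketch of this three-part structure follows, assuming toy single-layer linear maps in place of real backbones such as ResNet or HRNet (the class name, layer sizes, and ReLU are hypothetical). It shows that ψ is shared, both heads emit K heatmaps of size H x W, and only f is used to read off keypoints at test time.

```python
import numpy as np


class ToyKeypointModel:
    """Hypothetical toy model with a shared feature extractor psi and two heads,
    the regressor f and the adversarial regressor f_adv (f' in the text)."""

    def __init__(self, in_dim, feat_dim, num_keypoints, height, width, seed=0):
        rng = np.random.default_rng(seed)
        self.K, self.H, self.W = num_keypoints, height, width
        out_dim = num_keypoints * height * width
        self.psi = rng.normal(scale=0.1, size=(in_dim, feat_dim))     # feature extractor
        self.f = rng.normal(scale=0.1, size=(feat_dim, out_dim))      # regressor head
        self.f_adv = rng.normal(scale=0.1, size=(feat_dim, out_dim))  # adversarial head

    def features(self, x):
        """Shared features psi(x): here a linear map followed by ReLU."""
        return np.maximum(x @ self.psi, 0.0)

    def heatmaps(self, x, head):
        """K heatmaps of size H x W from the chosen head ('f' or 'f_adv')."""
        if head not in ("f", "f_adv"):
            raise ValueError("head must be 'f' or 'f_adv'")
        weights = self.f if head == "f" else self.f_adv
        return (self.features(x) @ weights).reshape(-1, self.K, self.H, self.W)

    def predict_keypoints(self, x):
        """Test time: only f is used; each keypoint is the argmax of its heatmap."""
        hm = self.heatmaps(x, "f")
        flat_idx = hm.reshape(hm.shape[0], self.K, -1).argmax(axis=-1)
        rows, cols = np.unravel_index(flat_idx, (self.H, self.W))
        return np.stack([rows, cols], axis=-1)  # shape (N, K, 2) as (row, col)


model = ToyKeypointModel(in_dim=16, feat_dim=32, num_keypoints=3, height=8, width=8)
x = np.random.default_rng(1).normal(size=(2, 16))
print(model.heatmaps(x, "f_adv").shape)  # (2, 3, 8, 8)
print(model.predict_keypoints(x).shape)  # (2, 3, 2)
```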
By providing a new domain adaptation method for regression, the embodiments of the present disclosure update the network parameters of the target object to be updated in the keypoint detection model so that the model migrates from the source domain to the target domain. Updating ψ, f, and f' together makes the loss of the regressor f and the adversarial regressor f' on the source domain (the source risk) as small as possible (target 1). Updating only f' makes f' disagree as much as possible with the predictions of f on the target domain, i.e., makes the disparity discrepancy (DD) as large as possible (target 2). Updating only ψ makes f' agree as much as possible with the predictions of f on the target domain, i.e., makes the disparity discrepancy as small as possible (target 3).
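The division of labor among the three targets can be written down as a small table of which parameter groups each step updates and which it freezes. The sketch below is hypothetical bookkeeping (the names psi, f, and f_adv stand for ψ, f, and f'); running the targets in the order 1, 2, 3 within one iteration is an assumption, since the text only says they are iterated repeatedly until equilibrium.

```python
# Hypothetical bookkeeping for one training iteration: which parameter groups
# each of the three targets updates (the rest stay frozen) and what it aims for.
PARAM_GROUPS = ("psi", "f", "f_adv")

TARGETS = (
    {"target": 1, "update": ("psi", "f", "f_adv"),
     "aim": "minimize the source-domain loss (source risk) of f and f_adv"},
    {"target": 2, "update": ("f_adv",),
     "aim": "make f_adv disagree with f on the target domain (DD as large as possible)"},
    {"target": 3, "update": ("psi",),
     "aim": "make f_adv agree with f on the target domain (DD as small as possible)"},
)


def frozen_groups(step):
    """Parameter groups held fixed while the given target is optimized."""
    return tuple(g for g in PARAM_GROUPS if g not in step["update"])


def one_iteration():
    """Return (target, updated groups, frozen groups) in the assumed execution order."""
    return [(s["target"], s["update"], frozen_groups(s)) for s in TARGETS]


for target, updated, frozen in one_iteration():
    print(f"target {target}: update {updated}, freeze {frozen}")
```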
In the embodiments of the present disclosure, when the disparity on the target domain increases, the outputs of the two regressors differ greatly: for the same input, an adversarial regressor f' whose outputs are close to those of the regressor f on the source domain can give very different predictions on the target domain. This indicates that the keypoint detection model transfers poorly at this point, i.e., its accuracy is high on the source domain but low on the target domain. Making the disparity discrepancy as large as possible therefore amounts to finding, in function space, an f' that transfers as poorly as possible.
In the embodiments of the present disclosure, the feature extractor ψ is updated so that the disparity discrepancy on the target domain decreases. The feature extractor ψ then tries to make the adversarial regressor f' and the regressor f give consistent predictions on the target domain; in other words, making the disparity discrepancy as small as possible lets ψ avoid, as far as possible, the non-transferable or poorly transferring situation described above.
In the embodiments of the present disclosure, based on these overall training targets of the keypoint detection model, the three targets (target 1, target 2, and target 3) are iterated repeatedly during model training until an equilibrium is reached. The implementation of each target is described below by way of example:
in an optional embodiment, the step of updating the network parameters of the target object includes:
in step S21, the first network parameters in the feature extractor, the regressor, and the adversarial regressor are updated.
In the embodiments of the present disclosure, updating the first network parameters in the feature extractor, the regressor, and the adversarial regressor makes the loss of the regressor f and the adversarial regressor f' on the source domain (the source risk) as small as possible (target 1), as part of migrating the keypoint detection model from the source domain to the target domain.
In an optional embodiment, FIG. 3 is a flowchart of an optional model training method according to an exemplary embodiment, in which the step of updating the first network parameters in the feature extractor, the regressor, and the adversarial regressor comprises:
in step S31, the heatmaps output by the regressor and the adversarial regressor are normalized to obtain a first processing result;
in step S32, the ratio of the heatmap corresponding to each keypoint label on the source domain to the sum of the heatmaps corresponding to all keypoint labels on the source domain is calculated to obtain a second processing result;
in step S33, a divergence calculation is performed on the first processing result and the second processing result in the spatial dimension to obtain a third processing result;
in step S34, the first network parameters in the feature extractor, the regressor, and the adversarial regressor are updated based on the third processing result and the third processing result is recomputed, until the third processing result satisfies the first preset condition, at which point updating of the first network parameters stops.
In the embodiment of the present disclosure, Softmax normalization is applied over the spatial positions of the heatmaps output by the regressor f and the adversarial regressor f' of the keypoint detection model to obtain the first processing result, which can be interpreted as a probability distribution over space. Next, a Gaussian heatmap is generated for each keypoint label in the source domain and divided by the spatial sum of the label heatmaps, that is, the ratio of the heatmap of each keypoint label to the sum of the heatmaps of all keypoint labels is taken, giving the second processing result. Finally, the KL divergence between the first processing result and the second processing result is calculated over the spatial dimension to obtain the third processing result Ls.
The first network parameters in the feature extractor, the regressor and the adversarial regressor are then updated by gradient descent based on the third processing result, and the third processing result is recomputed, so that Ls becomes as small as possible. The updating of the first network parameters stops when the third processing result satisfies the first preset condition, at which point the regressor f and the adversarial regressor f' achieve high accuracy on the source domain.
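The target-1 computation (steps S31 to S34) can be sketched in plain Python as follows. This is a minimal illustration under stated assumptions, not the disclosed implementation: heatmaps are lists of rows, the Gaussian width sigma = 2 is a hypothetical choice, each label heatmap is divided by its own spatial sum, the KL terms for f and f' are added and averaged over keypoints, and all function names are hypothetical. The gradient-descent update of ψ, f and f' would in practice rely on an automatic-differentiation framework and is not shown.

```python
import math

def spatial_softmax(heatmap):
    """Softmax over all H*W positions of one heatmap (a list of rows)."""
    flat = [v for row in heatmap for v in row]
    m = max(flat)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in flat]
    total = sum(exps)
    return [e / total for e in exps]

def gaussian_heatmap(height, width, center, sigma=2.0):
    """Unnormalized Gaussian label heatmap centred on center = (row, col)."""
    cy, cx = center
    return [[math.exp(-((y - cy) ** 2 + (x - cx) ** 2) / (2 * sigma ** 2))
             for x in range(width)] for y in range(height)]

def normalize(heatmap):
    """Divide a non-negative heatmap by its spatial sum (flattened distribution)."""
    flat = [v for row in heatmap for v in row]
    total = sum(flat)
    return [v / total for v in flat]

def kl_divergence(p, q, eps=1e-12):
    """KL(p || q) over flattened spatial positions; eps guards against log(0)."""
    return sum(pi * math.log((pi + eps) / (qi + eps)) for pi, qi in zip(p, q) if pi > 0)

def source_loss(pred_f, pred_adv, keypoints, sigma=2.0):
    """Hypothetical target-1 loss L_s: average over keypoints of
    KL(label || softmax(f_k)) + KL(label || softmax(f'_k)).

    pred_f, pred_adv: per-keypoint heatmaps (lists of rows) from f and f'.
    keypoints: ground-truth (row, col) of each keypoint on the source image.
    """
    loss = 0.0
    for k, center in enumerate(keypoints):
        height, width = len(pred_f[k]), len(pred_f[k][0])
        label = normalize(gaussian_heatmap(height, width, center, sigma))  # 2nd result
        for pred in (pred_f, pred_adv):
            first = spatial_softmax(pred[k])          # 1st result
            loss += kl_divergence(label, first)       # term of the 3rd result
    return loss / len(keypoints)
```

Feeding the logarithm of the label heatmap as the prediction gives a loss of (numerically) zero, because the spatial Softmax of log-values reproduces the normalized label.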
Existing keypoint detection models are generally trained with an MSE loss between the heatmap output by the model and the ground-truth heatmap. Because that output is unbounded, using MSE in the adversarial training of target 2 and target 3 can cause gradient explosion. The implementation of target 1 provided by the embodiment of the present disclosure, which compares Softmax-normalized distributions through a KL divergence, effectively avoids the gradient explosion caused by the unbounded output of the keypoint detection model.
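The boundedness argument can be illustrated with one flattened heatmap. For a fixed target distribution q and logits z, the gradient of KL(q || Softmax(z)) with respect to z is Softmax(z) - q, so every component stays within [-1, 1] however large the logits grow, whereas the gradient of a squared error, 2(z - y), grows linearly with the error. The sketch below uses hypothetical function names and 1-D lists in place of heatmaps; the analytic gradient can be checked against finite differences of kl_loss.

```python
import math

def softmax(z):
    """Numerically stable softmax of a list of logits."""
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]

def kl_loss(z, q):
    """KL(q || softmax(z)) for a target distribution q and logits z."""
    p = softmax(z)
    return sum(qi * math.log(qi / pi) for qi, pi in zip(q, p) if qi > 0)

def kl_grad(z, q):
    """Analytic gradient of kl_loss w.r.t. z: softmax(z) - q (entries in [-1, 1])."""
    return [pi - qi for pi, qi in zip(softmax(z), q)]

def mse_grad(z, y):
    """Gradient of sum((z - y)^2) w.r.t. z: grows linearly with the error."""
    return [2 * (zi - yi) for zi, yi in zip(z, y)]
```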
In an optional embodiment, the step of updating the network parameters of the target object includes:
In step S41, the second network parameters in the adversarial regressor are updated so as to maximize the discrepancy between the predictions of the adversarial regressor and the regressor on the target domain.
In the embodiment of the present disclosure, only the second network parameters of the adversarial regressor f' are updated, so that the predictions of the adversarial regressor f' and the regressor f on the target domain are as inconsistent as possible, that is, the disparity discrepancy (DD) is as large as possible (target 2); in this way the disparity (inconsistency) of the predictions on the target domain is maximized.
In an optional embodiment, as shown in fig. 4, which is a flowchart of another optional model training method according to an exemplary embodiment, the step of updating the second network parameters in the adversarial regressor includes:
In step S51, a plurality of heatmaps output by the regressor on the target domain and corresponding to a plurality of keypoints are acquired, and the sum of the remaining heatmaps corresponding to the keypoints other than the target keypoint is calculated;
In step S52, the ratio of the first heatmap corresponding to the target keypoint to the sum of the remaining heatmaps is calculated to obtain a fourth processing result;
In step S53, a second heatmap output by the adversarial regressor on the target domain and corresponding to the first heatmap is acquired and normalized to obtain a fifth processing result;
In step S54, a divergence between the fourth processing result and the fifth processing result is calculated over the spatial dimension to obtain a sixth processing result;
In step S55, the second network parameters in the adversarial regressor are updated based on the sixth processing result and the sixth processing result is recomputed; the updating of the second network parameters stops when the sixth processing result satisfies a second preset condition.
In the embodiment of the present disclosure, assume first that only the prediction inconsistency of the k-th keypoint (the target keypoint) is currently being calculated. The regressor f outputs K heatmaps on the target domain (one for each of the K keypoints), and the heatmaps output for all keypoints other than the k-th keypoint are summed.
Optionally, in the embodiment of the present disclosure, the ratio of the first heatmap corresponding to the target keypoint to the sum of the remaining heatmaps is calculated and normalized, yielding a spatial probability distribution of where the prediction for keypoint k is most likely to be wrong; this is the fourth processing result. The second heatmap output by the adversarial regressor f' on the target domain and corresponding to the first heatmap is acquired and normalized, that is, Softmax normalization is applied to the k-th heatmap output by f', giving the fifth processing result. The KL divergence between the fourth and fifth processing results is then calculated over the spatial dimension to obtain the sixth processing result Lf.
The second network parameters in the adversarial regressor f' are then updated by gradient descent based on the sixth processing result so that Lf is as small as possible, until the sixth processing result satisfies the second preset condition. This encourages the output of f' for the k-th keypoint to move to keypoint positions other than the one predicted by the regressor f, that is, it guides the adversarial regressor f' to make mistakes.
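Steps S51 to S55 can be sketched as follows for a single target keypoint k. The sketch interprets the "ratio" in step S52 as normalizing the summed heatmaps of the other keypoints into a spatial distribution (the "most likely error" distribution described above); this interpretation, the clamping of negative regressor outputs to zero, and all function names are assumptions. The summed heatmaps of the other keypoints must be positive somewhere, otherwise the normalization divides by zero.

```python
import math

def spatial_softmax(heatmap):
    """Softmax over all H*W positions of one heatmap (a list of rows)."""
    flat = [v for row in heatmap for v in row]
    m = max(flat)
    exps = [math.exp(v - m) for v in flat]
    total = sum(exps)
    return [e / total for e in exps]

def kl_divergence(p, q, eps=1e-12):
    """KL(p || q) over flattened spatial positions."""
    return sum(pi * math.log((pi + eps) / (qi + eps)) for pi, qi in zip(p, q) if pi > 0)

def ground_false(heatmaps_f, k):
    """Fourth processing result (one interpretation): sum the regressor's
    target-domain heatmaps for every keypoint except k, then divide by the
    spatial total so the result is a distribution over positions."""
    height, width = len(heatmaps_f[0]), len(heatmaps_f[0][0])
    summed = [0.0] * (height * width)
    for j, heatmap in enumerate(heatmaps_f):
        if j == k:
            continue  # skip the target keypoint itself
        flat = [v for row in heatmap for v in row]
        for i, v in enumerate(flat):
            summed[i] += max(v, 0.0)  # assumption: clamp negative outputs to 0
    total = sum(summed)
    return [v / total for v in summed]

def adversarial_loss(heatmaps_f, heatmaps_adv, k):
    """Sixth processing result L_f = KL(ground_false || softmax(f'_k))."""
    fifth = spatial_softmax(heatmaps_adv[k])
    return kl_divergence(ground_false(heatmaps_f, k), fifth)
```

Only f' would receive gradients from L_f in this step; the regressor's heatmaps act as fixed pseudo-labels, and a lower L_f means f' has moved its k-th output onto the other keypoints.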
Because the output space of a keypoint detection model is large (typically 64 × 64), it is hard in adversarial training to find an f' that fails to transfer (high accuracy on the source domain but a large discrepancy on the target domain). As shown in fig. 5, the embodiment of the present disclosure observes that the output of the regressor on the target domain is sparse in a probabilistic sense: wrong predictions tend to concentrate on the positions of other keypoints rather than on the background. The present disclosure exploits this sparsity by encouraging the output of the adversarial regressor f' to lie at keypoint positions, which reduces the effective size of the output space of f' in expectation.
In an optional embodiment, the step of updating the network parameters of the target object includes:
In step S61, the third network parameters in the feature extractor are updated so as to minimize the discrepancy between the predictions of the adversarial regressor and the regressor on the target domain.
Optionally, in the embodiment of the present disclosure, only the third network parameters in the feature extractor ψ are updated, so that the predictions of the adversarial regressor f' and the regressor f on the target domain are as consistent as possible, that is, the disparity discrepancy is as small as possible (target 3); in this way the disparity (inconsistency) of the predictions on the target domain is minimized.
In an optional embodiment, as shown in fig. 6, which is a flowchart of yet another optional model training method according to an exemplary embodiment, the step of updating the third network parameters in the feature extractor includes:
In step S71, a plurality of heatmaps output by the regressor on the target domain and corresponding to a plurality of keypoints are acquired, and the sum of the remaining heatmaps corresponding to the keypoints other than the target keypoint is calculated;
In step S72, the ratio of the first heatmap corresponding to the target keypoint to the sum of the remaining heatmaps is calculated to obtain a seventh processing result;
In step S73, a second heatmap output by the adversarial regressor on the target domain and corresponding to the first heatmap is acquired and normalized to obtain an eighth processing result;
In step S74, a divergence between the seventh processing result and the eighth processing result is calculated over the spatial dimension to obtain a ninth processing result;
In step S75, the third network parameters in the feature extractor are updated based on the ninth processing result and the ninth processing result is recomputed; the updating of the third network parameters stops when the ninth processing result satisfies a third preset condition.
In the embodiment of the present disclosure, assume again that only the prediction inconsistency of the k-th keypoint is currently being calculated. The heatmaps output by the regressor f on the target domain for the plurality of keypoints are acquired, and the sum of the remaining heatmaps corresponding to the keypoints other than the target keypoint is calculated; in addition, a Gaussian heatmap can be obtained from the k-th keypoint predicted by f on the target domain. The ratio of the first heatmap corresponding to the target keypoint to the sum of the remaining heatmaps is calculated to obtain the seventh processing result. Softmax normalization is applied to the k-th heatmap output by the adversarial regressor f' to obtain the eighth processing result, and the KL divergence between the seventh and eighth processing results is calculated over the spatial dimension to obtain the ninth processing result Lt.
The third network parameters in the feature extractor ψ are updated by gradient descent based on the ninth processing result, and the ninth processing result is recomputed; the updating of the third network parameters stops when the ninth processing result satisfies the third preset condition, namely that the ninth processing result is as small as possible.
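The target-3 pseudo-label can be sketched in the same style. Following the description above that a Gaussian heatmap is obtained from the k-th keypoint predicted by f, this sketch places a Gaussian at the argmax of f's k-th heatmap and compares it with the Softmax of f''s k-th heatmap. The Gaussian width sigma = 1, the KL direction and all function names are assumptions for illustration; the update of ψ itself would again need an automatic-differentiation framework.

```python
import math

def spatial_softmax(heatmap):
    """Softmax over all H*W positions of one heatmap (a list of rows)."""
    flat = [v for row in heatmap for v in row]
    m = max(flat)
    exps = [math.exp(v - m) for v in flat]
    total = sum(exps)
    return [e / total for e in exps]

def kl_divergence(p, q, eps=1e-12):
    """KL(p || q) over flattened spatial positions."""
    return sum(pi * math.log((pi + eps) / (qi + eps)) for pi, qi in zip(p, q) if pi > 0)

def argmax_2d(heatmap):
    """(row, col) of the largest value, i.e. the location predicted by f."""
    _, row, col = max((v, r, c) for r, values in enumerate(heatmap)
                      for c, v in enumerate(values))
    return row, col

def ground_truth(heatmap_f_k, sigma=1.0):
    """Seventh processing result (one interpretation): a Gaussian centred on
    f's predicted location for keypoint k, normalized to sum to one."""
    height, width = len(heatmap_f_k), len(heatmap_f_k[0])
    cy, cx = argmax_2d(heatmap_f_k)
    g = [math.exp(-((y - cy) ** 2 + (x - cx) ** 2) / (2 * sigma ** 2))
         for y in range(height) for x in range(width)]
    total = sum(g)
    return [v / total for v in g]

def alignment_loss(heatmaps_f, heatmaps_adv, k, sigma=1.0):
    """Ninth processing result L_t = KL(ground_truth || softmax(f'_k)).
    In the disclosure only the feature extractor psi is updated with it."""
    eighth = spatial_softmax(heatmaps_adv[k])
    return kl_divergence(ground_truth(heatmaps_f[k], sigma), eighth)
```

When f' agrees with f on keypoint k the loss is lower than when f' points elsewhere, which is the consistency that target 3 asks ψ to produce.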
Through the above embodiments, the feature extractor ψ is encouraged to output domain-invariant features, so that the adversarial regressor f' can no longer distinguish the source domain from the target domain or make mistakes on the target domain (i.e., place its output for the k-th keypoint on keypoints other than the one predicted by f); as a result, the ninth processing result Lt decreases.
Furthermore, as shown in fig. 7, maximizing a KL divergence in a high-dimensional space hardly changes the mean of the output distribution, contrary to the expected behavior. In the embodiment of the present disclosure, the intention is to shift the mean of the output of the adversarial regressor f' when the KL divergence is maximized, but the inventors found during experiments that mainly the variance of f' changed instead. The embodiment of the present disclosure therefore performs the adversarial training as the minimization of two opposite objectives: target 2 minimizes the divergence between the output of f' and the erroneous prediction, while target 3 minimizes the divergence between the output of f' and the correct prediction. Minimizing the KL divergence in the high-dimensional space effectively changes the mean of the distribution, so the output of f' can be efficiently guided to other keypoints, which helps the adversarial training.
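Written compactly, the three targets become three minimizations over different parameter groups. The notation is assumed for illustration: x^s and y^s_k are a source image and its k-th keypoint label, x^t is a target image, σ is the spatial Softmax, \hat{H}(y^s_k) is the normalized Gaussian label heatmap, H^F_k and H^T_k are the target-domain distributions built from the regressor's output for keypoint k (the erroneous-prediction and correct-prediction distributions of targets 2 and 3), and averaging over the K keypoints is an assumption.

```latex
\begin{aligned}
\text{Target 1:}\quad & \min_{\psi,\,f,\,f'}\; L_s = \frac{1}{K}\sum_{k=1}^{K}
  \Big[\mathrm{KL}\big(\hat{H}(y^{s}_{k}) \,\|\, \sigma(f(\psi(x^{s}))_{k})\big)
  + \mathrm{KL}\big(\hat{H}(y^{s}_{k}) \,\|\, \sigma(f'(\psi(x^{s}))_{k})\big)\Big] \\
\text{Target 2:}\quad & \min_{f'}\; L_f = \frac{1}{K}\sum_{k=1}^{K}
  \mathrm{KL}\big(H^{F}_{k} \,\|\, \sigma(f'(\psi(x^{t}))_{k})\big) \\
\text{Target 3:}\quad & \min_{\psi}\; L_t = \frac{1}{K}\sum_{k=1}^{K}
  \mathrm{KL}\big(H^{T}_{k} \,\|\, \sigma(f'(\psi(x^{t}))_{k})\big)
\end{aligned}
```

Targets 2 and 3 fit the same output σ(f'(ψ(x^t))_k) to opposite pseudo-labels while updating different parameters (f' versus ψ), which is how the adversarial game is expressed without an explicit maximization.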
As an optional embodiment, ResNet101 is used as the feature extractor, the synthetic hand dataset RHD as the source domain and the real hand dataset H3D as the target domain. The evaluation metric is PCK, under which a prediction is counted as correct if its distance from the label does not exceed 1/20 of the image size. The inventors found that the PCK of the original model is 61.8%, whereas after training with the method provided by the embodiment of the present disclosure it reaches 72.5%, an absolute improvement of 10.7 percentage points.
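PCK as described can be sketched as follows. The sketch assumes pixel coordinates and reads "1/20 of the image" as 0.05 times a single image side length passed in as image_size; some evaluations normalize by a bounding-box or heatmap size instead. The function name and signature are hypothetical.

```python
import math

def pck(preds, labels, image_size, alpha=0.05):
    """Percentage of keypoints whose predicted (x, y) lies within
    alpha * image_size pixels of the labelled (x, y).

    alpha = 0.05 corresponds to the '1/20 of the image' criterion; using a
    single side length as image_size is an assumption of this sketch.
    """
    if len(preds) != len(labels) or not labels:
        raise ValueError("need equally many, non-zero predictions and labels")
    threshold = alpha * image_size
    correct = sum(1 for (px, py), (lx, ly) in zip(preds, labels)
                  if math.hypot(px - lx, py - ly) <= threshold)
    return 100.0 * correct / len(labels)

if __name__ == "__main__":
    labels = [(100, 100), (50, 60), (200, 10), (0, 0)]
    preds = [(105, 100), (50, 80), (200, 10), (30, 40)]
    print(pck(preds, labels, image_size=256))  # threshold 12.8 px -> 50.0
```

The "does not exceed" wording is implemented with <=, so a prediction exactly on the threshold counts as correct.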
Fig. 8 is a block diagram of a model training apparatus according to an exemplary embodiment. The apparatus is applied to a keypoint detection model that includes a feature extractor, a regressor and an adversarial regressor. Referring to fig. 8, the model training apparatus includes a determining unit 121 and an updating unit 122, wherein:
the determining unit 121 is configured to determine a target object to be updated in the keypoint detection model, wherein the target object includes the feature extractor, the regressor and the adversarial regressor;
the updating unit 122 is configured to update the network parameters of the target object so as to migrate the keypoint detection model from the source domain to the target domain.
In an optional embodiment, the updating unit includes a first updating subunit configured to update the first network parameters in the feature extractor, the regressor and the adversarial regressor.
In an optional embodiment, the first updating subunit includes: a first processing subunit configured to normalize the heatmaps output by the regressor and the adversarial regressor to obtain a first processing result; a first calculating subunit configured to calculate the ratio of the heatmap corresponding to each keypoint label in the source domain to the sum of the heatmaps corresponding to all keypoint labels in the source domain to obtain a second processing result; a second calculating subunit configured to calculate a divergence between the first processing result and the second processing result over the spatial dimension to obtain a third processing result; and a third calculating subunit configured to update the first network parameters in the feature extractor, the regressor and the adversarial regressor based on the third processing result and recompute the third processing result, stopping the updating of the first network parameters when the third processing result satisfies a first preset condition.
In an optional embodiment, the updating unit includes a second updating subunit configured to update the second network parameters in the adversarial regressor so as to maximize the discrepancy between the predictions of the adversarial regressor and the regressor on the target domain.
In an optional embodiment, the second updating subunit includes: an obtaining subunit configured to acquire a plurality of heatmaps output by the regressor on the target domain and corresponding to a plurality of keypoints, and to calculate the sum of the remaining heatmaps corresponding to the keypoints other than the target keypoint; a fourth calculating subunit configured to calculate the ratio of the first heatmap corresponding to the target keypoint to the sum of the remaining heatmaps to obtain a fourth processing result; a second processing subunit configured to acquire a second heatmap output by the adversarial regressor on the target domain and corresponding to the first heatmap, and to normalize the second heatmap to obtain a fifth processing result; a fifth calculating subunit configured to calculate a divergence between the fourth processing result and the fifth processing result over the spatial dimension to obtain a sixth processing result; and a third updating subunit configured to update the second network parameters in the adversarial regressor based on the sixth processing result and recompute the sixth processing result, stopping the updating of the second network parameters when the sixth processing result satisfies a second preset condition.
In an optional embodiment, the updating unit includes a fourth updating subunit configured to update the third network parameters in the feature extractor so as to minimize the discrepancy between the predictions of the adversarial regressor and the regressor on the target domain.
In an optional embodiment, the fourth updating subunit includes: a sixth calculating subunit configured to acquire a plurality of heatmaps output by the regressor on the target domain and corresponding to a plurality of keypoints, and to calculate the sum of the remaining heatmaps corresponding to the keypoints other than the target keypoint; a seventh calculating subunit configured to calculate the ratio of the first heatmap corresponding to the target keypoint to the sum of the remaining heatmaps to obtain a seventh processing result; a third processing subunit configured to acquire a second heatmap output by the adversarial regressor on the target domain and corresponding to the first heatmap, and to normalize the second heatmap to obtain an eighth processing result; an eighth calculating subunit configured to calculate a divergence between the seventh processing result and the eighth processing result over the spatial dimension to obtain a ninth processing result; and a fifth updating subunit configured to update the third network parameters in the feature extractor based on the ninth processing result and recompute the ninth processing result, stopping the updating of the third network parameters when the ninth processing result satisfies a third preset condition.
For the apparatus in the above embodiment, the specific manner in which each module performs its operations has been described in detail in the method embodiments and will not be repeated here.
In an exemplary embodiment, a computer-readable storage medium including instructions is also provided, such as a memory including instructions, which, when executed by a processor of an electronic device, enable the electronic device to perform any of the above model training methods. Optionally, the computer-readable storage medium may be a ROM, a random access memory (RAM), a CD-ROM, a magnetic tape, a floppy disk, an optical data storage device, or the like.
According to an embodiment of the present disclosure, an electronic device is provided. Fig. 9 is a block diagram of an electronic device according to an exemplary embodiment. As shown in fig. 9, the electronic device may include at least one processor 901 (only one is shown), a memory 902 and a peripheral interface 903; that is, it comprises a processor and a memory for storing instructions executable by the processor, wherein the processor is configured to execute the instructions to implement any of the model training methods described above. Optionally, the memory may include high-speed random access memory and may also include non-volatile memory, such as one or more magnetic storage devices, flash memory or other non-volatile solid-state memory. In some examples, the memory may further include memory located remotely from the processor, and such remote memory may be connected to the computer terminal through a network. Examples of such networks include, but are not limited to, the Internet, intranets, local area networks, mobile communication networks, and combinations thereof.
According to an embodiment of the present disclosure, there is provided a computer program product comprising a computer program which, when executed by a processor, implements any of the above-described model training methods.
Other embodiments of the disclosure will be apparent to those skilled in the art from consideration of the specification and practice of the disclosure disclosed herein. This disclosure is intended to cover any variations, uses, or adaptations of the disclosure following, in general, the principles of the disclosure and including such departures from the present disclosure as come within known or customary practice within the art to which the disclosure pertains. It is intended that the specification and examples be considered as exemplary only, with a true scope and spirit of the disclosure being indicated by the following claims.
It will be understood that the present disclosure is not limited to the precise arrangements described above and shown in the drawings and that various modifications and changes may be made without departing from the scope thereof. The scope of the present disclosure is limited only by the appended claims.

Claims (10)

1. A model training method, applied to a keypoint detection model comprising a feature extractor, a regressor and an adversarial regressor, wherein the model training method comprises:
determining a target object to be updated in the keypoint detection model, wherein the target object comprises the feature extractor, the regressor and the adversarial regressor; and
updating network parameters of the target object so that the keypoint detection model is migrated from a source domain to a target domain.
2. The model training method of claim 1, wherein the step of updating the network parameters of the target object comprises:
updating first network parameters in the feature extractor, the regressor and the adversarial regressor.
3. The model training method of claim 2, wherein the step of updating the first network parameters in the feature extractor, the regressor and the adversarial regressor comprises:
normalizing the heatmaps output by the regressor and the adversarial regressor to obtain a first processing result;
calculating the ratio of the heatmap corresponding to each keypoint label in the source domain to the sum of the heatmaps corresponding to all keypoint labels in the source domain to obtain a second processing result;
performing a divergence calculation on the first processing result and the second processing result over the spatial dimension to obtain a third processing result; and
updating the first network parameters in the feature extractor, the regressor and the adversarial regressor based on the third processing result and recomputing the third processing result, and stopping the updating of the first network parameters when the third processing result satisfies a first preset condition.
4. The model training method of claim 1, wherein the step of updating the network parameters of the target object comprises:
updating second network parameters in the adversarial regressor so as to maximize the discrepancy between the predictions of the adversarial regressor and the regressor on the target domain.
5. The model training method of claim 4, wherein the step of updating the second network parameters in the adversarial regressor comprises:
acquiring a plurality of heatmaps output by the regressor on the target domain and corresponding to a plurality of keypoints, and calculating the sum of the remaining heatmaps corresponding to the keypoints other than a target keypoint;
calculating the ratio of a first heatmap corresponding to the target keypoint to the sum of the remaining heatmaps to obtain a fourth processing result;
acquiring a second heatmap output by the adversarial regressor on the target domain and corresponding to the first heatmap, and normalizing the second heatmap to obtain a fifth processing result;
performing a divergence calculation on the fourth processing result and the fifth processing result over the spatial dimension to obtain a sixth processing result; and
updating the second network parameters in the adversarial regressor based on the sixth processing result and recomputing the sixth processing result, and stopping the updating of the second network parameters when the sixth processing result satisfies a second preset condition.
6. The model training method of claim 1, wherein the step of updating the network parameters of the target object comprises:
updating third network parameters in the feature extractor so as to minimize the discrepancy between the predictions of the adversarial regressor and the regressor on the target domain.
7. The model training method of claim 6, wherein the step of updating the third network parameters in the feature extractor comprises:
acquiring a plurality of heatmaps output by the regressor on the target domain and corresponding to a plurality of keypoints, and calculating the sum of the remaining heatmaps corresponding to the keypoints other than a target keypoint;
calculating the ratio of a first heatmap corresponding to the target keypoint to the sum of the remaining heatmaps to obtain a seventh processing result;
acquiring a second heatmap output by the adversarial regressor on the target domain and corresponding to the first heatmap, and normalizing the second heatmap to obtain an eighth processing result;
performing a divergence calculation on the seventh processing result and the eighth processing result over the spatial dimension to obtain a ninth processing result; and
updating the third network parameters in the feature extractor based on the ninth processing result and recomputing the ninth processing result, and stopping the updating of the third network parameters when the ninth processing result satisfies a third preset condition.
8. A model training apparatus, applied to a keypoint detection model comprising a feature extractor, a regressor and an adversarial regressor, wherein the model training apparatus comprises:
a determining unit configured to determine a target object to be updated in the keypoint detection model, wherein the target object comprises the feature extractor, the regressor and the adversarial regressor; and
an updating unit configured to update network parameters of the target object so that the keypoint detection model is migrated from a source domain to a target domain.
9. An electronic device, comprising:
a processor;
a memory for storing the processor-executable instructions;
wherein the processor is configured to execute the instructions to implement the model training method of any one of claims 1 to 7.
10. A computer-readable storage medium, wherein instructions in the computer-readable storage medium, when executed by a processor of an electronic device, enable the electronic device to perform the model training method of any of claims 1-7.
CN202110656444.0A 2021-06-11 2021-06-11 Model training method and device, storage medium and electronic equipment Pending CN113283598A (en)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202110656444.0A CN113283598A (en) 2021-06-11 2021-06-11 Model training method and device, storage medium and electronic equipment

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202110656444.0A CN113283598A (en) 2021-06-11 2021-06-11 Model training method and device, storage medium and electronic equipment

Publications (1)

Publication Number Publication Date
CN113283598A true CN113283598A (en) 2021-08-20

Family

ID=77284656

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202110656444.0A Pending CN113283598A (en) 2021-06-11 2021-06-11 Model training method and device, storage medium and electronic equipment

Country Status (1)

Country Link
CN (1) CN113283598A (en)

Citations (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN109376659A (en) * 2018-10-26 2019-02-22 北京陌上花科技有限公司 Training method, face critical point detection method, apparatus for face key spot net detection model
CN112215255A (en) * 2020-09-08 2021-01-12 深圳大学 Training method of target detection model, target detection method and terminal equipment

Non-Patent Citations (2)

* Cited by examiner, † Cited by third party
Title
JUNGUANG JIANG ET AL.: "Regressive Domain Adaptation for Unsupervised Keypoint Detection", arXiv *
ZHAO YONGQIANG ET AL.: "Review of deep learning object detection methods", Journal of Image and Graphics *

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115546295A (en) * 2022-08-26 2022-12-30 西北大学 Target 6D attitude estimation model training method and target 6D attitude estimation method
CN115546295B (en) * 2022-08-26 2023-11-07 西北大学 Target 6D gesture estimation model training method and target 6D gesture estimation method

Legal Events

Date Code Title Description
PB01 Publication
SE01 Entry into force of request for substantive examination
RJ01 Rejection of invention patent application after publication

Application publication date: 20210820
