CN114092918A - Model training method, device, equipment and storage medium - Google Patents

Model training method, device, equipment and storage medium

Info

Publication number
CN114092918A
CN114092918A
Authority
CN
China
Prior art keywords
model
target
student
training
data
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202210024447.7A
Other languages
Chinese (zh)
Inventor
袁振国
刘国清
杨广
王启程
朱爱晨
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Shenzhen Minieye Innovation Technology Co Ltd
Original Assignee
Shenzhen Minieye Innovation Technology Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Shenzhen Minieye Innovation Technology Co Ltd filed Critical Shenzhen Minieye Innovation Technology Co Ltd
Priority to CN202210024447.7A
Publication of CN114092918A
Legal status: Pending

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/082Learning methods modifying the architecture, e.g. adding, deleting or silencing nodes or connections

Abstract

The application discloses a model training method, a device, equipment and a storage medium. A training data set is obtained, and a preset teacher model is trained with the labeled data until the teacher model reaches a preset first convergence condition, so as to obtain a target teacher model; this allows the teacher model to learn deeper model features. BN layer weight sharing is then performed between the target teacher model and a student model, so that the target teacher model guides the training of the student model and the student model trains with the BN layer weights of the target teacher model. Finally, the student model and the target teacher model are jointly trained with the labeled data and the unlabeled data until the student model reaches a preset second convergence condition, so as to obtain a target student model. In this way, the expression capability of the student model is improved while a low model complexity is maintained, the model is effectively compressed, and the consumption of computing resources and the labor cost of manual labeling are reduced.

Description

Model training method, device, equipment and storage medium
Technical Field
The present application relates to the field of artificial intelligence technologies, and in particular, to a model training method, apparatus, device, and storage medium.
Background
With the rapid development of artificial intelligence, convolutional neural networks are widely applied in the field of vehicle driving, for example in vehicle detection and lane line detection. Training a convolutional neural network requires a large amount of high-quality labeled data to obtain a high-complexity model and thereby improve model accuracy. However, a large amount of labeled data requires considerable storage space, and the training process also consumes a huge amount of computing resources.
At present, due to cost constraints, entry-level intelligent driving-assistance systems usually adopt computing platforms with relatively low computing power, and deploying a high-complexity model on such platforms causes high latency. Therefore, how to compress a high-complexity model to a level acceptable to the end-side computing platform is an urgent problem to be solved.
Disclosure of Invention
The application provides a model training method, a model training device, model training equipment and a storage medium, and aims to solve the technical problem that a convolutional neural network is high in computing resource consumption.
In order to solve the above technical problem, in a first aspect, an embodiment of the present application provides a model training method, including:
acquiring a training data set, wherein the training data set comprises marked data and unmarked data;
training a preset teacher model by using the marked data until the teacher model reaches a preset first convergence condition to obtain a target teacher model;
carrying out BN layer weight sharing on the target teacher model and the student model, wherein the model complexity of the target teacher model is greater than that of the student model;
and performing joint training on the student model and the target teacher model by using the labeled data and the unlabeled data until the student model reaches a preset second convergence condition to obtain a target student model, wherein the target student model can be used for being deployed to the end-side computing platform.
In this embodiment, a training data set is acquired, and a preset teacher model is trained with the labeled data until the teacher model reaches a preset first convergence condition, so as to obtain a target teacher model; this allows the teacher model to learn deeper model features. BN layer weight sharing is then performed between the target teacher model and the student model, so that the target teacher model guides the training of the student model and the student model trains with the BN layer weights of the target teacher model. Finally, the student model and the target teacher model are jointly trained with the labeled data and the unlabeled data until the student model reaches a preset second convergence condition, so as to obtain the target student model. In this way, the expression capability of the student model is improved while a low model complexity is maintained, the model is effectively compressed, and the consumption of computing resources and the labor cost of manual labeling are further reduced.
In one embodiment, the target teacher model and the student model both have multiple BN layers, and the BN layer weight sharing for the target teacher model and the student model includes:
the multi-level BN layer weights of the target teacher model are shared to the student model.
In this embodiment, the weights of the multiple BN layers are shared, so that the student model can efficiently acquire the feature expression capability of the target teacher model in the training stage, which effectively alleviates the poor expression capability caused by the small number of convolutional layers in the student model.
In one embodiment, the student model and the target teacher model fix the multi-level BN layer weights of the target teacher model and the multi-level BN layer weights of the student model during the joint training.
In this embodiment, the BN layer weights are fixed, so that adverse effects on the student model caused by updates to the BN layer weights are avoided.
In an embodiment, the jointly training the student model and the target teacher model by using the labeled data and the unlabeled data until the student model reaches a preset second convergence condition to obtain the target student model includes:
taking the training data set as input data of the student model and the target teacher model, and outputting a first prediction result of the student model and a second prediction result of the target teacher model;
calculating the total loss value of the target loss function according to the first prediction result and the second prediction result;
and updating the student model according to the total loss value until the student model converges to obtain the target student model.
In this embodiment, the target teacher model and the student model are jointly trained with the labeled data and the unlabeled data, so that the manual labeling cost can be reduced, the input data distribution of the student model and the target teacher model is unified, and the expression capability of the student model is improved.
In one embodiment, calculating a total loss value of the target loss function based on the first prediction and the second prediction comprises:
determining the data type of input data, wherein the data type is marked data or unmarked data;
and calculating the total loss value of the target loss function according to the data type.
In one embodiment, the target loss function is:
Ltotal = λ · (Ls + Lt) + MSE(Ps, Pt)

wherein Ls is the prediction loss function of the student model, Lt is the prediction loss function of the target teacher model, Ps is the first prediction result, Pt is the second prediction result, MSE(Ps, Pt) is the mean square error between the first prediction result and the second prediction result, and λ is a data-type weight: if the data type is the marked data, λ = 1; if the data type is the unmarked data, λ = 0.
In one embodiment, updating the student model according to the total loss value until the student model converges to obtain the target student model includes:
if the total loss value is not less than the preset threshold value, updating the first feature layer weight of the student model and the second feature layer weight of the target teacher model to obtain a new student model and a new target teacher model;
and predicting the training data set by using the new student model and the new target teacher model until the total loss value is smaller than the preset threshold value, to obtain the target student model.
In a second aspect, an embodiment of the present application provides a model training apparatus, including:
the acquisition module is used for acquiring a training data set, and the training data set comprises marked data and unmarked data;
the first training module is used for training a preset teacher model by using the marked data until the teacher model reaches a preset first convergence condition to obtain a target teacher model;
the sharing module is used for carrying out BN layer weight sharing on the target teacher model and the student model, and the model complexity of the target teacher model is greater than that of the student model;
and the second training module is used for performing combined training on the student model and the target teacher model by using the labeled data and the unlabeled data until the student model reaches a preset second convergence condition to obtain a target student model, and the target student model can be used for being deployed to the end-side computing platform.
In a third aspect, an embodiment of the present application provides a computer device, including a processor and a memory, where the memory is used to store a computer program, and the computer program, when executed by the processor, implements the model training method according to the first aspect.
In a fourth aspect, embodiments of the present application provide a computer-readable storage medium, which stores a computer program, and when the computer program is executed by a processor, the computer program implements the model training method according to the first aspect.
Please refer to the relevant description of the first aspect for the beneficial effects of the second aspect to the fourth aspect, which are not described herein again.
Drawings
Fig. 1 is a schematic flowchart of a model training method according to an embodiment of the present disclosure;
FIG. 2 is a diagram illustrating multi-level weight sharing according to an embodiment of the present disclosure;
FIG. 3 is a schematic structural diagram of a model training apparatus according to an embodiment of the present disclosure;
fig. 4 is a schematic structural diagram of a computer device according to an embodiment of the present application.
Detailed Description
The technical solutions in the embodiments of the present application will be clearly and completely described below with reference to the drawings in the embodiments of the present application, and it is obvious that the described embodiments are only a part of the embodiments of the present application, and not all of the embodiments. All other embodiments, which can be derived by a person skilled in the art from the embodiments given herein without making any creative effort, shall fall within the protection scope of the present application.
As described in the related art, due to cost constraints, entry-level intelligent driving-assistance systems often adopt computing platforms with relatively low computing power, and adopting a high-complexity model causes a high-latency problem.
Therefore, the embodiments of the application provide a model training method, a device, equipment and a storage medium. A training data set is obtained, and a preset teacher model is trained with the labeled data until the teacher model reaches a preset first convergence condition, so as to obtain a target teacher model; this allows the teacher model to learn deeper model features. BN layer weight sharing is then performed between the target teacher model and the student model, so that the target teacher model guides the training of the student model and the student model trains with the BN layer weights of the target teacher model. Finally, the student model and the target teacher model are jointly trained with the labeled data and the unlabeled data until the student model reaches a preset second convergence condition, so as to obtain a target student model. In this way, the expression capability of the student model is improved while a low model complexity is maintained, the model is effectively compressed, and the consumption of computing resources and the labor cost of manual labeling are reduced.
Referring to fig. 1, fig. 1 is a schematic flow chart of a model training method according to an embodiment of the present disclosure. The model training method can be applied to computer equipment including but not limited to smart phones, tablet computers, notebook computers, desktop computers, physical servers, cloud servers and other equipment. As shown in fig. 1, the model training method of the present embodiment includes steps S101 to S104, which are detailed as follows:
step S101, a training data set is obtained, wherein the training data set comprises marked data and unmarked data.
In this step, a training data set and a joint training model are constructed, the joint training model including a teacher model and a student model. Optionally, actual scene data is collected for the classification problem to be solved. For scenarios in which raw data can be obtained continuously, part of the data is extracted according to a certain rule for manual labeling in consideration of the manual labeling cost, for example, extraction at equal intervals.
Optionally, according to the actually deployed hardware platform and the scene requirements, a student model meeting the storage requirement and the calculation delay requirement is established first, and the complexity of the student model is then increased to obtain the teacher model.
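As a purely illustrative sketch (not taken from the patent), the following PyTorch code shows one way a four-stage student network and a higher-complexity teacher network of the kind depicted in fig. 2 could be built; all class names, channel widths and layer counts are assumptions.

```python
import torch.nn as nn

def conv_block(in_ch, out_ch, num_convs):
    """One feature extraction stage (convolutions + activations + pooling) followed by a BN layer."""
    layers = []
    for i in range(num_convs):
        layers += [nn.Conv2d(in_ch if i == 0 else out_ch, out_ch, 3, padding=1),
                   nn.ReLU(inplace=True)]
    layers += [nn.MaxPool2d(2), nn.BatchNorm2d(out_ch)]
    return nn.Sequential(*layers)

class FourStageNet(nn.Module):
    """Four feature extraction stages; the teacher simply uses more convolutions per stage."""
    def __init__(self, num_classes=10, convs_per_stage=1):
        super().__init__()
        chs = [3, 16, 32, 64, 128]
        self.stages = nn.ModuleList(
            [conv_block(chs[i], chs[i + 1], convs_per_stage) for i in range(4)])
        self.head = nn.Linear(chs[-1], num_classes)

    def forward(self, x):
        for stage in self.stages:
            x = stage(x)
        return self.head(x.mean(dim=(2, 3)))  # global average pooling + classifier

student = FourStageNet(convs_per_stage=1)  # low complexity, suited to the end-side platform
teacher = FourStageNet(convs_per_stage=3)  # same layout, higher complexity
```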
And S102, training a preset teacher model by using the marked data until the teacher model reaches a preset first convergence condition, so as to obtain a target teacher model.
Alternatively, the first convergence condition may be that the loss function of the teacher model is less than a preset value, or that the number of iterations of the teacher model reaches a preset number.
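For illustration, a minimal sketch of step S102 under these conditions might look as follows; `labeled_loader`, `criterion` and the concrete convergence values are placeholders, not values given in the patent.

```python
import torch

def train_teacher(teacher, labeled_loader, criterion, loss_eps=0.01, max_iters=10_000, lr=1e-3):
    """Train the teacher on labeled data until the loss falls below a preset value
    or the number of iterations reaches a preset number (the first convergence condition)."""
    opt = torch.optim.SGD(teacher.parameters(), lr=lr, momentum=0.9)
    teacher.train()
    it = 0
    while it < max_iters:
        for x, y in labeled_loader:
            opt.zero_grad()
            loss = criterion(teacher(x), y)
            loss.backward()
            opt.step()
            it += 1
            if loss.item() < loss_eps or it >= max_iters:
                return teacher  # target teacher model
    return teacher
```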
And S103, carrying out BN layer weight sharing on the target teacher model and the student model, wherein the model complexity of the target teacher model is greater than that of the student model.
In this step, the target teacher model and the student model both have a feature layer and a Batch Normalization (BN) layer. The BN layer weight sharing of this embodiment is to assign the BN layer weight of the target teacher model to the student model, so that the BN layer weight of the student model is the same as that of the target teacher model, and thus the student model can have the feature expression capability of the target teacher model.
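A minimal sketch of this assignment for a single BN layer, assuming PyTorch `nn.BatchNorm2d` modules; copying the running statistics together with the affine parameters is an assumption, since the text only speaks of the "BN layer weight".

```python
import torch
import torch.nn as nn

@torch.no_grad()
def copy_bn(teacher_bn: nn.BatchNorm2d, student_bn: nn.BatchNorm2d):
    """Assign one BN layer's weights from the target teacher model to the student model."""
    student_bn.weight.copy_(teacher_bn.weight)              # scale (gamma)
    student_bn.bias.copy_(teacher_bn.bias)                  # shift (beta)
    student_bn.running_mean.copy_(teacher_bn.running_mean)  # batch statistics
    student_bn.running_var.copy_(teacher_bn.running_var)
```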
And S104, performing joint training on the student model and the target teacher model by using the labeled data and the unlabeled data until the student model reaches a preset second convergence condition to obtain a target student model, wherein the target student model can be used for being deployed to an end-side computing platform.
In this step, the student model and the target teacher model are trained together with the labeled data and the unlabeled data, the total loss value between the student model and the target teacher model is calculated, and when the total loss value is smaller than a preset threshold value, the student model reaches the second convergence condition. In this embodiment, BN layer weight sharing and joint training unify the input data distribution of the student model and the teacher model and improve the expression capability of the student model, so that the accuracy of the end-side deployment model is effectively improved without increasing its complexity, while the consumption of computing resources is reduced.
In an embodiment, on the basis of the embodiment shown in fig. 1, the step S103 includes:
sharing the multi-level BN layer weights of the target teacher model to the student model.
In this step, as shown in fig. 2, a common four-stage feature extraction neural network is taken as an example for weight sharing of the multi-level BN layers. The feature extraction layer is used to extract features of the input data and includes a convolutional layer, a pooling layer, and an activation function layer, so as to improve the nonlinear feature expression capability of the model; the pooling layer can reduce feature dimensions and enrich the feature information output after convolution calculation. Based on the student model on the left side of fig. 2, the BN layer weight sharing relationship is as follows: <BN layer 1, BN layer 5>, <BN layer 2, BN layer 6>, <BN layer 3, BN layer 7>, <BN layer 4, BN layer 8>. Compared with the student model, the teacher model performs more convolution calculations and activation functions in feature extraction layer 5, feature extraction layer 6, feature extraction layer 7 and feature extraction layer 8 than the student model does in feature extraction layer 1, feature extraction layer 2, feature extraction layer 3 and feature extraction layer 4, so that practical application requirements can be better met.
In this embodiment, the weights of the multiple BN layers are shared, so that the student model can efficiently acquire the feature expression capability of the target teacher model in the training stage, which effectively alleviates the poor expression capability caused by the small number of convolutional layers in the student model.
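Following the pairing in fig. 2, the multi-level sharing could be sketched as a loop over the four stages; the module paths (`stages[i][-1]`) are assumptions that only hold for the illustrative networks defined above.

```python
import torch

@torch.no_grad()
def share_bn_weights(teacher, student):
    """Share the BN layer weight of each teacher stage with the corresponding student stage,
    i.e. <BN layer 1, BN layer 5>, ..., <BN layer 4, BN layer 8> in the notation of fig. 2."""
    for t_stage, s_stage in zip(teacher.stages, student.stages):
        t_bn, s_bn = t_stage[-1], s_stage[-1]     # the last module of each stage is its BN layer
        s_bn.load_state_dict(t_bn.state_dict())   # copies gamma, beta and running statistics

share_bn_weights(teacher, student)
```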
Optionally, the student model and the target teacher model fix the multi-level BN layer weights of the target teacher model and the multi-level BN layer weights of the student model during joint training.
In this alternative embodiment, the student model and the teacher model share BN layer weights, and the sharing relationship is shown in fig. 2. In the joint training stage, the weights of BN layer 1 to BN layer 8 are fixed and are not updated by the loss function of the joint training. In this embodiment, the BN layer weights are fixed, so that adverse effects on the student model caused by updates to the BN layer weights are avoided.
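A sketch of fixing the shared BN layer weights during joint training, again assuming PyTorch modules; freezing both the affine parameters (via `requires_grad`) and the running statistics (by keeping the BN layers in eval mode) is one reasonable reading of "fixed".

```python
import torch.nn as nn

def freeze_bn_layers(model):
    """Fix all BN layer weights so the joint-training loss does not update them."""
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.eval()                      # stop updating running statistics (re-apply after model.train())
            for p in m.parameters():
                p.requires_grad = False   # exclude gamma/beta from gradient updates

freeze_bn_layers(student)
freeze_bn_layers(teacher)
```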
In an embodiment, based on the embodiment shown in fig. 1, the step S104 includes:
taking the training data set as input data of the student model and the target teacher model, and outputting a first prediction result of the student model and a second prediction result of the target teacher model;
calculating the total loss value of the target loss function according to the first prediction result and the second prediction result;
and updating the student model according to the total loss value until the student model converges to obtain the target student model.
In this embodiment, the student model and the target teacher model are jointly trained using the labeled data and the unlabeled data as input data. Optionally, the calculating of a total loss value of the target loss function according to the first prediction result and the second prediction result includes: determining the data type of the input data, the data type being labeled data or unlabeled data; and calculating the total loss value of the target loss function according to the data type.
For each single training stage, the classification prediction results Ps and Pt of the student model and the target teacher model are calculated respectively. If the input data is labeled data, the prediction losses Ls and Lt of the student model and the target teacher model are calculated respectively; if the input data is unlabeled data, the prediction losses of the student model and the teacher model are not calculated.
In this embodiment, the target teacher model and the student model are jointly trained with the labeled data and the unlabeled data, so that the manual labeling cost can be reduced, the input data distribution of the student model and the target teacher model is unified, and the expression capability of the student model is improved.
Optionally, the target loss function is:
Ltotal = λ · (Ls + Lt) + MSE(Ps, Pt)

wherein Ls is the prediction loss function of the student model, Lt is the prediction loss function of the target teacher model, Ps is the first prediction result, Pt is the second prediction result, MSE(Ps, Pt) is the mean square error between the first prediction result and the second prediction result, and λ is a data-type weight: if the data type is the labeled data, λ = 1; if the data type is the unlabeled data, λ = 0.
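Under the loss given above, the total loss computation could be sketched as follows; whether the mean square error is taken over logits or over softmax outputs, and the loss reductions, are assumptions not specified in the text.

```python
import torch
import torch.nn.functional as F

def total_loss(ps, pt, labels=None):
    """Total loss = lambda * (Ls + Lt) + MSE(Ps, Pt), with lambda = 1 for labeled and 0 for unlabeled data."""
    consistency = F.mse_loss(ps, pt)          # mean square error between the two prediction results
    if labels is None:                        # unlabeled data: lambda = 0
        return consistency
    ls = F.cross_entropy(ps, labels)          # prediction loss Ls of the student model
    lt = F.cross_entropy(pt, labels)          # prediction loss Lt of the target teacher model
    return ls + lt + consistency              # labeled data: lambda = 1
```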
Optionally, the updating the student model according to the total loss value until the student model converges to obtain the target student model includes: if the total loss value is not less than a preset threshold value, updating a first feature layer weight of the student model and a second feature layer weight of the target teacher model to obtain a new student model and a new target teacher model; and predicting the training data set by using the new student model and the new target teacher model until the total loss value is smaller than the preset threshold value to obtain the target student model.
In this embodiment, the calculated total loss value is used to update, through a back-propagation algorithm, the first feature layer weights and the second feature layer weights corresponding to the feature extraction layers of the student model and the target teacher model. If the total loss value is not less than the preset threshold value, the next single training stage is entered; otherwise, the training ends. The trained target student model is then used as the end-side deployment model.
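Putting the pieces together, a single-stage joint-training loop consistent with this description might look as follows; it reuses the `total_loss` sketch above, and the data-loader behaviour (yielding `labels=None` for unlabeled batches), optimizer choice and threshold value are assumptions.

```python
import itertools
import torch

def joint_train(student, teacher, mixed_loader, threshold=0.05, lr=1e-3, max_iters=100_000):
    """Jointly train the student and target teacher model until the total loss value
    is smaller than the preset threshold (the second convergence condition)."""
    # Only feature-layer weights are updated; the shared BN layer weights stay frozen.
    params = [p for p in itertools.chain(student.parameters(), teacher.parameters())
              if p.requires_grad]
    opt = torch.optim.SGD(params, lr=lr, momentum=0.9)
    for it, (x, labels) in enumerate(mixed_loader):
        if it >= max_iters:
            break
        ps, pt = student(x), teacher(x)        # first and second prediction results
        loss = total_loss(ps, pt, labels)
        if loss.item() < threshold:
            break                              # second convergence condition reached
        opt.zero_grad()
        loss.backward()
        opt.step()
    return student                             # target student model for end-side deployment
```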
An embodiment of the present application further provides a model training apparatus for executing the model training method of the foregoing method embodiment and achieving the corresponding functions and technical effects. Referring to fig. 3, fig. 3 is a block diagram illustrating a structure of a model training apparatus according to an embodiment of the present application. For convenience of explanation, only the parts related to this embodiment are shown. The model training apparatus provided in the embodiment of the present application includes:
an obtaining module 301, configured to obtain a training data set, where the training data set includes labeled data and unlabeled data;
a first training module 302, configured to train a preset teacher model by using the labeled data until the teacher model reaches a preset first convergence condition, so as to obtain a target teacher model;
a sharing module 303, configured to perform BN layer weight sharing on the target teacher model and the student model, where a model complexity of the target teacher model is greater than a model complexity of the student model;
a second training module 304, configured to perform joint training on the student model and the target teacher model by using the labeled data and the unlabeled data until the student model reaches a preset second convergence condition, so as to obtain a target student model, where the target student model can be used for deployment to an end-side computing platform.
In an embodiment, the sharing module 303 is specifically configured to:
sharing the multi-level BN layer weights of the target teacher model to the student model.
In an embodiment, the student model and the target teacher model fix the multi-level BN layer weights of the target teacher model and the multi-level BN layer weights of the student model during the joint training.
In an embodiment, the second training module 304 includes:
an output unit, configured to output a first prediction result of the student model and a second prediction result of the target teacher model using the training data set as input data of the student model and the target teacher model;
a calculating unit, configured to calculate a total loss value of a target loss function according to the first prediction result and the second prediction result;
and the updating unit is used for updating the student model according to the total loss value until the student model converges to obtain the target student model.
In one embodiment, the computing unit includes:
the determining subunit is used for determining the data type of the input data, wherein the data type is marked data or unmarked data;
and the calculating subunit is used for calculating the total loss value of the target loss function according to the data type.
In one embodiment, the target loss function is:
Ltotal = λ · (Ls + Lt) + MSE(Ps, Pt)

wherein Ls is the prediction loss function of the student model, Lt is the prediction loss function of the target teacher model, Ps is the first prediction result, Pt is the second prediction result, MSE(Ps, Pt) is the mean square error between the first prediction result and the second prediction result, and λ is a data-type weight: if the data type is the marked data, λ = 1; if the data type is the unmarked data, λ = 0.
In one embodiment, the update unit includes:
an updating subunit, configured to update a first feature layer weight of the student model and a second feature layer weight of the target teacher model if the total loss value is not less than a preset threshold value, so as to obtain a new student model and a new target teacher model;
and the iteration subunit is used for predicting the training data set by using the new student model and the new target teacher model until the total loss value is smaller than the preset threshold value, so as to obtain the target student model.
The model training device can implement the model training method of the method embodiment. The alternatives in the above-described method embodiments are also applicable to this embodiment and will not be described in detail here. The rest of the embodiments of the present application may refer to the contents of the above method embodiments, and in this embodiment, details are not described again.
Fig. 4 is a schematic structural diagram of a computer device according to an embodiment of the present application. As shown in fig. 4, the computer apparatus 400 of this embodiment includes: at least one processor 401 (only one shown in fig. 4), a memory 402, and a computer program 403 stored in the memory 402 and executable on the at least one processor 401, the processor 401 implementing the steps in any of the method embodiments described above when executing the computer program 403.
The computer device 400 may be a computing device such as a smartphone, a tablet computer, a desktop computer, or a cloud server. The computer device may include, but is not limited to, the processor 401 and the memory 402. Those skilled in the art will appreciate that fig. 4 is merely an example of the computer device 400 and does not limit the computer device 400, which may include more or fewer components than those shown, or combine some of the components, or include different components, such as input/output devices and network access devices.
The processor 401 may be a Central Processing Unit (CPU), or may be another general-purpose processor, a Digital Signal Processor (DSP), an Application Specific Integrated Circuit (ASIC), a Field-Programmable Gate Array (FPGA) or other programmable logic device, a discrete gate or transistor logic device, or discrete hardware components. A general-purpose processor may be a microprocessor, or the processor may be any conventional processor.
The memory 402 may, in some embodiments, be an internal storage unit of the computer device 400, such as a hard disk or a memory of the computer device 400. The memory 402 may also, in other embodiments, be an external storage device of the computer device 400, such as a plug-in hard disk, a Smart Media Card (SMC), a Secure Digital (SD) Card, or a Flash memory Card (Flash Card) provided on the computer device 400. Further, the memory 402 may also include both an internal storage unit and an external storage device of the computer device 400. The memory 402 is used for storing an operating system, an application program, a boot loader (BootLoader), data, and other programs, such as the program code of the computer program. The memory 402 may also be used to temporarily store data that has been output or is to be output.
In addition, an embodiment of the present application further provides a computer-readable storage medium, where a computer program is stored, and when the computer program is executed by a processor, the computer program implements the steps in any of the method embodiments described above.
The embodiments of the present application provide a computer program product, which when executed on a computer device, enables the computer device to implement the steps in the above method embodiments.
In several embodiments provided herein, it will be understood that each block in the flowchart or block diagrams may represent a module, segment, or portion of code, which comprises one or more executable instructions for implementing the specified logical function(s). It should also be noted that, in some alternative implementations, the functions noted in the block may occur out of the order noted in the figures. For example, two blocks shown in succession may, in fact, be executed substantially concurrently, or the blocks may sometimes be executed in the reverse order, depending upon the functionality involved.
The functions, if implemented in the form of software functional modules and sold or used as a stand-alone product, may be stored in a computer readable storage medium. Based on such understanding, the technical solution of the present application or portions thereof that substantially contribute to the prior art may be embodied in the form of a software product stored in a storage medium and including instructions for causing a computer device to perform all or part of the steps of the method according to the embodiments of the present application. And the aforementioned storage medium includes: a U-disk, a removable hard disk, a Read-Only Memory (ROM), a Random Access Memory (RAM), a magnetic disk or an optical disk, and other various media capable of storing program codes.
The above-mentioned embodiments are further detailed to explain the objects, technical solutions and advantages of the present application, and it should be understood that the above-mentioned embodiments are only examples of the present application and are not intended to limit the scope of the present application. It should be understood that any modifications, equivalents, improvements and the like, which come within the spirit and principle of the present application, may occur to those skilled in the art and are intended to be included within the scope of the present application.

Claims (10)

1. A method of model training, comprising:
acquiring a training data set, wherein the training data set comprises marked data and unmarked data;
training a preset teacher model by using the marked data until the teacher model reaches a preset first convergence condition to obtain a target teacher model;
carrying out BN layer weight sharing on the target teacher model and the student model, wherein the model complexity of the target teacher model is greater than that of the student model;
and performing joint training on the student model and the target teacher model by using the labeled data and the unlabeled data until the student model reaches a preset second convergence condition to obtain a target student model, wherein the target student model can be used for being deployed to an end-side computing platform.
2. The model training method of claim 1, wherein the target teacher model and the student model each have a plurality of BN layers, and wherein the BN layer weight sharing of the target teacher model and the student model comprises:
sharing the multi-level BN layer weights of the target teacher model to the student model.
3. The model training method of claim 2, wherein the student model and the target teacher model fix the multi-level BN layer weights of the target teacher model and the multi-level BN layer weights of the student model in a joint training.
4. The model training method of claim 1, wherein the jointly training the student model and the target teacher model using the labeled data and the unlabeled data until the student model reaches a second predetermined convergence condition to obtain a target student model comprises:
taking the training data set as input data of the student model and the target teacher model, and outputting a first prediction result of the student model and a second prediction result of the target teacher model;
calculating the total loss value of the target loss function according to the first prediction result and the second prediction result;
and updating the student model according to the total loss value until the student model converges to obtain the target student model.
5. The model training method of claim 4, wherein said calculating a total loss value for an objective loss function based on said first predictor and said second predictor comprises:
determining the data type of the input data, wherein the data type is marked data or unmarked data;
and calculating the total loss value of the target loss function according to the data type.
6. The model training method of claim 5, wherein the objective loss function is:
Ltotal = λ · (Ls + Lt) + MSE(Ps, Pt)

wherein Ls is a prediction loss function of the student model, Lt is a prediction loss function of the target teacher model, Ps is the first prediction result, Pt is the second prediction result, MSE(Ps, Pt) is a mean square error between the first prediction result and the second prediction result, and λ is a data-type weight; if the data type is the marked data, λ = 1; and if the data type is the unmarked data, λ = 0.
7. The model training method of claim 4, wherein said updating the student model according to the total loss value until the student model converges to obtain the objective student model comprises:
if the total loss value is not less than a preset threshold value, updating a first feature layer weight of the student model and a second feature layer weight of the target teacher model to obtain a new student model and a new target teacher model;
and predicting the training data set by using the new student model and the new target teacher model until the total loss value is smaller than the preset threshold value to obtain the target student model.
8. A model training apparatus, comprising:
the system comprises an acquisition module, a storage module and a processing module, wherein the acquisition module is used for acquiring a training data set, and the training data set comprises marked data and unmarked data;
the first training module is used for training a preset teacher model by using the marked data until the teacher model reaches a preset first convergence condition, so as to obtain a target teacher model;
the sharing module is used for carrying out BN layer weight sharing on the target teacher model and the student model, and the model complexity of the target teacher model is greater than that of the student model;
and the second training module is used for performing combined training on the student model and the target teacher model by using the labeled data and the unlabeled data until the student model reaches a preset second convergence condition to obtain a target student model, and the target student model can be used for deploying to an end-side computing platform.
9. A computer device comprising a processor and a memory for storing a computer program which, when executed by the processor, implements the model training method of any one of claims 1 to 7.
10. A computer-readable storage medium, characterized in that it stores a computer program which, when being executed by a processor, carries out the model training method according to any one of claims 1 to 7.
CN202210024447.7A 2022-01-11 2022-01-11 Model training method, device, equipment and storage medium Pending CN114092918A (en)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202210024447.7A CN114092918A (en) 2022-01-11 2022-01-11 Model training method, device, equipment and storage medium

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202210024447.7A CN114092918A (en) 2022-01-11 2022-01-11 Model training method, device, equipment and storage medium

Publications (1)

Publication Number Publication Date
CN114092918A true CN114092918A (en) 2022-02-25

Family

ID=80308508

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202210024447.7A Pending CN114092918A (en) 2022-01-11 2022-01-11 Model training method, device, equipment and storage medium

Country Status (1)

Country Link
CN (1) CN114092918A (en)

Patent Citations (7)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20180268292A1 (en) * 2017-03-17 2018-09-20 Nec Laboratories America, Inc. Learning efficient object detection models with knowledge distillation
US20190287515A1 (en) * 2018-03-16 2019-09-19 Microsoft Technology Licensing, Llc Adversarial Teacher-Student Learning for Unsupervised Domain Adaptation
US20200134506A1 (en) * 2018-10-29 2020-04-30 Fujitsu Limited Model training method, data identification method and data identification device
CN113205002A (en) * 2021-04-08 2021-08-03 南京邮电大学 Low-definition face recognition method, device, equipment and medium for unlimited video monitoring
CN113111968A (en) * 2021-04-30 2021-07-13 北京大米科技有限公司 Image recognition model training method and device, electronic equipment and readable storage medium
CN113281048A (en) * 2021-06-25 2021-08-20 华中科技大学 Rolling bearing fault diagnosis method and system based on relational knowledge distillation
CN113724242A (en) * 2021-09-10 2021-11-30 吉林大学 Combined grading method for diabetic retinopathy and diabetic macular edema

Non-Patent Citations (1)

* Cited by examiner, † Cited by third party
Title
王金甲等: "基于平均教师模型的弱标记半监督声音事件检测", 《复旦学报(自然科学版)》 *

Similar Documents

Publication Publication Date Title
US9990558B2 (en) Generating image features based on robust feature-learning
US20180181867A1 (en) Artificial neural network class-based pruning
US20160358070A1 (en) Automatic tuning of artificial neural networks
JP2019528502A (en) Method and apparatus for optimizing a model applicable to pattern recognition and terminal device
CN113159073B (en) Knowledge distillation method and device, storage medium and terminal
WO2021089013A1 (en) Spatial graph convolutional network training method, electronic device and storage medium
CN111406264A (en) Neural architecture search
KR102250728B1 (en) Sample processing method and device, related apparatus and storage medium
CN113723589A (en) Hybrid precision neural network
CN116644804B (en) Distributed training system, neural network model training method, device and medium
CN112966754B (en) Sample screening method, sample screening device and terminal equipment
WO2022246986A1 (en) Data processing method, apparatus and device, and computer-readable storage medium
CN116681127B (en) Neural network model training method and device, electronic equipment and storage medium
CN116912923B (en) Image recognition model training method and device
CN113409307A (en) Image denoising method, device and medium based on heterogeneous noise characteristics
CN112801107A (en) Image segmentation method and electronic equipment
CN115953651A (en) Model training method, device, equipment and medium based on cross-domain equipment
CN114241411B (en) Counting model processing method and device based on target detection and computer equipment
CN114092918A (en) Model training method, device, equipment and storage medium
CN114065913A (en) Model quantization method and device and terminal equipment
CN112561050B (en) Neural network model training method and device
CN110222693B (en) Method and device for constructing character recognition model and recognizing characters
CN109800873B (en) Image processing method and device
CN112669270A (en) Video quality prediction method and device and server
CN113570053A (en) Neural network model training method and device and computing equipment

Legal Events

Date Code Title Description
PB01 Publication
SE01 Entry into force of request for substantive examination