CN114565797A - Neural network training and image classification method and device for classification - Google Patents


Publication number
CN114565797A
Authority
CN
China
Prior art keywords
network
updated
parameter
prediction result
student
Legal status
Pending
Application number
CN202210209488.3A
Other languages
Chinese (zh)
Inventor
刘吉豪
刘宇
刘博晓
Current Assignee
Shanghai Sensetime Intelligent Technology Co Ltd
Original Assignee
Shanghai Sensetime Intelligent Technology Co Ltd
Application filed by Shanghai Sensetime Intelligent Technology Co Ltd
Priority to CN202210209488.3A
Publication of CN114565797A

Classifications

    • G PHYSICS
    • G06 COMPUTING; CALCULATING OR COUNTING
    • G06F ELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00 Pattern recognition
    • G06F18/20 Analysing
    • G06F18/24 Classification techniques
    • G06F18/241 Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
    • G06N COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00 Computing arrangements based on biological models
    • G06N3/02 Neural networks
    • G06N3/04 Architecture, e.g. interconnection topology
    • G06N3/08 Learning methods

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • Evolutionary Computation (AREA)
  • Biophysics (AREA)
  • Computational Linguistics (AREA)
  • Software Systems (AREA)
  • Mathematical Physics (AREA)
  • Health & Medical Sciences (AREA)
  • Biomedical Technology (AREA)
  • Computing Systems (AREA)
  • Molecular Biology (AREA)
  • General Health & Medical Sciences (AREA)
  • Evolutionary Biology (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Image Analysis (AREA)

Abstract

The present disclosure provides a method, an apparatus, a computer device and a storage medium for training a neural network for classification. The method comprises: acquiring sample data and inputting it to a teacher network and to a student network to be trained, to obtain a first prediction result and a second prediction result, respectively; determining a pre-updated first network parameter based on the first prediction result, the second prediction result and a current network parameter of the student network; and updating the temperature coefficients corresponding to the student network and the teacher network based on the pre-updated first network parameter, then updating the current network parameter of the student network based on the updated temperature coefficients to obtain an updated student network.

Description

Neural network training and image classification method and device for classification
Technical Field
The present disclosure relates to the field of computer technologies, and in particular, to a method and an apparatus for training a neural network for classification, and to an image classification method and apparatus.
Background
A trained neural network generally achieves high accuracy but is relatively large, so deploying it on a lower-performance device requires further compressing it to reduce its size. Because compression lowers network accuracy, knowledge distillation is then performed on the compressed neural network (the student network) using the neural network before compression (the teacher network), in order to improve the accuracy of the compressed network.
In the related art, knowledge distillation is generally performed by feeding the same input data to the teacher network and the student network and training the student network with the output of the teacher network as supervision data.
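For reference, the related-art baseline can be sketched in Python as distillation with one fixed temperature coefficient shared by both networks. This is a minimal illustration, not code from the disclosure: the function names, the default temperature of 4.0 and the T² scaling are assumptions (the scaling is a common convention in the distillation literature).

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float], temperature: float = 1.0) -> List[float]:
    """Temperature-scaled softmax: larger temperatures give softer distributions."""
    scaled = [z / temperature for z in logits]
    peak = max(scaled)  # subtract the maximum for numerical stability
    exps = [math.exp(z - peak) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kd_loss(teacher_logits: Sequence[float], student_logits: Sequence[float],
            temperature: float = 4.0) -> float:
    """KL(teacher || student) between softened outputs, with one fixed temperature.

    The teacher's softened output is the supervision signal; the T**2 factor
    (an assumed convention) keeps gradient scale comparable across temperatures.
    """
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    kl = sum(pt * (math.log(pt) - math.log(ps)) for pt, ps in zip(p_teacher, p_student))
    return temperature ** 2 * kl


if __name__ == "__main__":
    # Positive, because the student's softened output does not yet match the teacher's.
    print(kd_loss([3.0, 1.0, -2.0], [1.0, 1.0, 0.0]))
```

The temperature here is a hand-picked constant, which is exactly the limitation the disclosure sets out to remove.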
Disclosure of Invention
The embodiments of the present disclosure provide at least a method and an apparatus for training a neural network for classification, and an image classification method and apparatus.
In a first aspect, an embodiment of the present disclosure provides a meta-learning-based method for training a neural network for classification, including:
acquiring sample data, and respectively inputting the sample data to a teacher network and a student network to be trained to obtain a first prediction result and a second prediction result;
determining a pre-updated first network parameter based on the first prediction result, the second prediction result, and a current network parameter of the student network;
and updating the temperature coefficients corresponding to the student network and the teacher network based on the pre-updated first network parameter, and updating the current network parameter of the student network based on the updated temperature coefficients to obtain an updated student network.
In this method, the pre-updated first network parameter is determined from the first prediction result and the second prediction result, the meta-parameters are updated using the first network parameter, and the temperature coefficients are then adjusted based on the meta-parameters. The temperature coefficients thereby become trainable parameters during student network training: the temperature coefficients of the teacher network and the student network are adjusted automatically, optimal temperature coefficients are determined for both networks, and the network accuracy of the teacher network and the student network is further improved.
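The three steps above can be pictured as a single meta-learning iteration. The toy Python sketch below rests on several assumptions that the text does not fix: a linear student with 3 classes and 2 features, a fixed linear teacher, meta-parameters that are simply the log-temperatures of the two networks (so temperatures stay positive), a soft cross-entropy distillation loss, and the misclassified-sample squared-error verification loss described later in this section. Central finite differences stand in for automatic differentiation to keep it dependency-free. All names (`train_step`, `alpha`, `beta`, and so on) are hypothetical.

```python
import math
from typing import Callable, List, Sequence, Tuple

K, D = 3, 2  # number of classes and input features in this toy setup


def softmax(z: Sequence[float], t: float = 1.0) -> List[float]:
    peak = max(v / t for v in z)
    e = [math.exp(v / t - peak) for v in z]
    s = sum(e)
    return [v / s for v in e]


def logits(w: Sequence[float], x: Sequence[float]) -> List[float]:
    """Linear model; w is a flat list of K*D weights."""
    return [sum(w[k * D + j] * x[j] for j in range(D)) for k in range(K)]


def num_grad(f: Callable[[List[float]], float], params: List[float],
             eps: float = 1e-5) -> List[float]:
    """Central finite-difference gradient (a stand-in for autodiff)."""
    grad = []
    for i in range(len(params)):
        up, down = params[:], params[:]
        up[i] += eps
        down[i] -= eps
        grad.append((f(up) - f(down)) / (2 * eps))
    return grad


def train_loss(w, teacher_w, xs, phi) -> float:
    """Distillation loss; phi = (log tau_student, log tau_teacher) are the meta-parameters."""
    t_student, t_teacher = math.exp(phi[0]), math.exp(phi[1])
    total = 0.0
    for x in xs:
        p_t = softmax(logits(teacher_w, x), t_teacher)  # first prediction result
        p_s = softmax(logits(w, x), t_student)          # second prediction result
        total -= sum(a * math.log(b) for a, b in zip(p_t, p_s))
    return total / len(xs)


def val_loss(w, xs, ys) -> float:
    """Squared error against one-hot labels, counted only on misclassified samples."""
    total = 0.0
    for x, y in zip(xs, ys):
        p = softmax(logits(w, x))                        # third prediction result
        if max(range(K), key=p.__getitem__) != y:
            total += sum((pk - (1.0 if k == y else 0.0)) ** 2 for k, pk in enumerate(p))
    return total


def train_step(w, teacher_w, phi, train_xs, val_xs, val_ys,
               alpha: float = 0.5, beta: float = 0.5) -> Tuple[List[float], List[float]]:
    def pre_update(phi_):
        # Pre-update a copy of the student parameters; w itself is never modified.
        g = num_grad(lambda w_: train_loss(w_, teacher_w, train_xs, phi_), w)
        return [wi - alpha * gi for wi, gi in zip(w, g)]

    # Update the meta-parameters (hence both temperatures) from the
    # pre-updated student's loss on the verification data.
    meta_g = num_grad(lambda phi_: val_loss(pre_update(phi_), val_xs, val_ys), phi, eps=1e-3)
    phi = [p - beta * g for p, g in zip(phi, meta_g)]

    # Update the real student parameters with the re-derived temperatures.
    g = num_grad(lambda w_: train_loss(w_, teacher_w, train_xs, phi), w)
    return [wi - alpha * gi for wi, gi in zip(w, g)], phi
```

A call such as `train_step([0.0] * 6, teacher_w, [0.0, 0.0], xs, xs, ys)` returns the updated student weights and meta-parameters. In a real network the meta-gradient would normally be obtained by back-propagating through the pre-update step rather than by finite differences.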
In one possible embodiment, the determining a pre-updated first network parameter based on the first prediction result, the second prediction result, and a current network parameter of the student network includes:
acquiring current network parameters of the student network;
determining a training loss based on the first prediction result and the second prediction result;
and adjusting the acquired current network parameters of the student network based on the training loss to determine the pre-updated first network parameter.
Here, what is adjusted is a copy of the acquired current network parameters of the student network; the parameters held inside the student network itself are not directly modified. The pre-updated first network parameter is used only to update the meta-parameters, so that the meta-parameters associated with the student network can be learned through distillation from the teacher network.
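The distinction drawn here, adjusting a copy of the parameters rather than the network's own parameters, is the functional update familiar from meta-learning. A minimal Python sketch, assuming parameters are stored in a plain dict of floats and the gradients of the training loss are already available (the names `pre_update` and `lr` are hypothetical):

```python
from typing import Dict


def pre_update(params: Dict[str, float], grads: Dict[str, float], lr: float) -> Dict[str, float]:
    """Return theta' = theta - lr * grad as a new dict; `params` is left unchanged."""
    return {name: value - lr * grads[name] for name, value in params.items()}


student_params = {"w": 1.0, "b": 0.5}
grads = {"w": 0.2, "b": -0.4}
first_network_params = pre_update(student_params, grads, lr=0.1)
# student_params still holds the current parameters; first_network_params holds
# the pre-updated values, used only for the subsequent meta-parameter update.
```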
In a possible embodiment, the updating the temperature coefficients corresponding to the student network and the teacher network based on the pre-updated first network parameter includes:
updating meta-parameters associated with the training of the student network based on the pre-updated first network parameter, the meta-parameters being used to determine the temperature coefficients corresponding to the student network and the teacher network;
and updating the temperature coefficients based on the updated meta-parameters.
In one possible embodiment, the sample data includes training sample data and verification sample data;
the step of inputting the sample data into a teacher network and a student network to be trained respectively to obtain a first prediction result and a second prediction result comprises the following steps:
respectively inputting the training sample data to a teacher network and a student network to be trained to obtain a first prediction result and a second prediction result;
the updating the meta-parameters associated with the student network training based on the pre-updated first network parameters comprises:
taking the pre-updated first network parameter as a current network parameter of the student network to obtain a pre-updated student network;
inputting the verification sample data into the pre-updated student network to obtain a third prediction result;
and updating the meta-parameters based on the third prediction result and the annotation information corresponding to the verification sample data.
In this way, the meta-parameters are updated automatically, so that temperature coefficients subsequently calculated from them better improve the performance of the student network and the teacher network.
In a possible implementation manner, the updating the meta-parameters based on the third prediction result and the annotation information corresponding to the verification sample data includes:
determining a verification loss based on the third prediction result and the annotation information corresponding to the verification sample data;
updating the meta-parameter based on the verification loss.
In a possible embodiment, the determining, based on the third prediction result and the annotation information corresponding to the verification sample data, a verification loss includes:
determining, based on the third prediction result and the annotation information corresponding to the verification sample data, error sample data, namely the verification samples that the pre-updated student network predicts incorrectly;
and determining the verification loss based on the confidence information, in the third prediction result, of the predictions corresponding to the error sample data and the preset confidence information corresponding to the annotation information of the verification sample data.
Updating the meta-parameters based only on the error sample data speeds up the meta-parameter updates, and the temperature coefficients determined from these meta-parameters better improve the classification accuracy of the student network and the teacher network on such error samples.
In a possible embodiment, the determining the verification loss based on the confidence information, in the third prediction result, of the predictions corresponding to the error sample data and the preset confidence information corresponding to the annotation information of the verification sample data includes:
for each piece of error sample data, calculating the sum of squared differences between the confidence of each classification result in the third prediction result of that sample and the preset confidence of the same classification result in the annotation information of that sample;
and taking the total of these sums of squares over all error sample data as the verification loss.
In the above embodiment, the verification loss accounts for the difference between the confidence of each classification result in the third prediction result and the preset confidence of that classification result in the annotation information of the error sample data, and squaring these differences amplifies them. Training the neural network with this verification loss therefore concentrates on improving the student network's classification of error samples, which gives the trained student network higher accuracy.
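The verification loss described above can be sketched in a few lines of Python. The sketch assumes that the "preset confidence" of the annotation is a one-hot vector (1 for the annotated class, 0 for every other class) and that per-sample sums of squares are added together; the function name is hypothetical.

```python
from typing import Sequence


def verification_loss(probs: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Sum of squared confidence errors, computed only over misclassified samples.

    probs[i] holds the class confidences of the third prediction result for sample i;
    labels[i] is its annotated class, whose preset confidences are assumed one-hot.
    """
    loss = 0.0
    for p, y in zip(probs, labels):
        predicted = max(range(len(p)), key=lambda k: p[k])
        if predicted == y:
            continue  # correctly classified samples do not contribute
        target = [1.0 if k == y else 0.0 for k in range(len(p))]
        loss += sum((pk - tk) ** 2 for pk, tk in zip(p, target))
    return loss


probs = [[0.7, 0.2, 0.1],   # correct: predicts class 0, label 0
         [0.6, 0.3, 0.1]]   # error: predicts class 0, label 1
print(verification_loss(probs, [0, 1]))  # approximately 0.36 + 0.49 + 0.01 = 0.86
```

Only the second sample contributes, and its largest term comes from the annotated class whose confidence (0.3) falls far short of 1.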
In a possible embodiment, the updating the current network parameter of the student network based on the updated temperature coefficient includes:
inputting the sample data into the student network and determining a fourth prediction result using the student network and the updated temperature coefficient; inputting the sample data into the teacher network and determining a fifth prediction result using the teacher network and the updated temperature coefficient;
re-determining a training loss based on the fourth prediction result and the fifth prediction result, and updating a current network parameter of the student network based on the re-determined training loss.
In this way, the temperature coefficients of the student network and the teacher network are updated automatically, and the student network is trained again with the updated temperature coefficients, so that its current network parameters are updated adaptively and its performance is improved.
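A minimal Python sketch of the re-determined training loss, assuming the student and teacher each soften their logits with their own updated temperature and that the loss is the KL divergence from the teacher's softened output to the student's (the disclosure does not name the loss form; all function names are hypothetical):

```python
import math
from typing import List, Sequence


def softened(logits: Sequence[float], temperature: float) -> List[float]:
    """Softmax of logits / temperature, stabilized by subtracting the maximum logit."""
    peak = max(logits)
    exps = [math.exp((z - peak) / temperature) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def redetermined_training_loss(student_logits: Sequence[float],
                               teacher_logits: Sequence[float],
                               tau_student: float, tau_teacher: float) -> float:
    """KL(fifth || fourth) with separate temperatures for the two networks."""
    fourth = softened(student_logits, tau_student)  # fourth prediction result (student)
    fifth = softened(teacher_logits, tau_teacher)   # fifth prediction result (teacher)
    return sum(q * (math.log(q) - math.log(p)) for p, q in zip(fourth, fifth))
```

Because the two temperatures are separate arguments, the student's softening can differ from the teacher's. Even identical logits give a non-zero loss when `tau_student` and `tau_teacher` differ.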
In one possible embodiment, the meta-parameters include network parameters of a parameter generation network associated with the student network training;
the updating the temperature coefficient based on the updated meta-parameters includes:
taking the updated meta-parameters as the network parameters of the parameter generation network to obtain an updated parameter generation network;
and re-determining the temperature coefficient based on the updated parameter generation network and an initial temperature coefficient.
In this way, the parameter generation network is trained by training the meta-parameters. Re-determining the temperature coefficient with the updated parameter generation network improves the accuracy of the generated temperature coefficient and avoids the loss of accuracy that directly adjusting the temperature coefficient can cause when it becomes too large or too small.
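One way to picture the parameter generation network is as a tiny function tau = G_meta(tau0) whose weights are the meta-parameters. The Python sketch below is an assumption throughout: a one-hidden-layer tanh network, and an output squashed into hypothetical bounds `TAU_MIN`/`TAU_MAX`. The bounds reflect the stated aim of keeping the temperature from becoming too large or too small.

```python
import math
from typing import Sequence, Tuple

TAU_MIN, TAU_MAX = 1.0, 10.0  # hypothetical bounds on the generated temperature


def sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def generate_temperature(meta: Tuple[Sequence[float], Sequence[float], Sequence[float], float],
                         tau0: float) -> float:
    """Tiny generator network: tau = G_meta(tau0).

    meta = (w1, b1, w2, b2) are the meta-parameters; the output is mapped into
    [TAU_MIN, TAU_MAX] so the temperature cannot drift to extreme values.
    """
    w1, b1, w2, b2 = meta
    hidden = [math.tanh(w * tau0 + b) for w, b in zip(w1, b1)]
    out = sum(w * h for w, h in zip(w2, hidden)) + b2
    return TAU_MIN + (TAU_MAX - TAU_MIN) * sigmoid(out)
```

With all output weights and the output bias at zero, the generator returns the midpoint of the range (5.5 here) whatever `tau0` is. Training then moves the meta-parameters rather than the temperature itself.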
In one possible implementation, the temperature coefficient includes a first temperature coefficient corresponding to the student network and a second temperature coefficient corresponding to the teacher network;
the re-determining the temperature coefficient based on the updated parameter generation network and the initial temperature coefficient includes:
determining the first temperature coefficient and the second temperature coefficient based on the updated parameter generation network and the initial temperature coefficient.
In this way, a single parameter generation network generates the first temperature coefficient and the second temperature coefficient simultaneously, which speeds up the determination of the temperature coefficients.
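A single generator that yields both coefficients can be sketched as a shared hidden layer with two output heads. The architecture, the bounds and the function names below are assumptions made for illustration only.

```python
import math
from typing import Sequence, Tuple

TAU_MIN, TAU_MAX = 1.0, 10.0  # hypothetical bounds


def bounded(x: float) -> float:
    """Map a real number into [TAU_MIN, TAU_MAX] with a numerically stable sigmoid."""
    s = 1.0 / (1.0 + math.exp(-x)) if x >= 0 else math.exp(x) / (1.0 + math.exp(x))
    return TAU_MIN + (TAU_MAX - TAU_MIN) * s


def generate_both(w1: Sequence[float], b1: Sequence[float],
                  w_student: Sequence[float], w_teacher: Sequence[float],
                  b_out: Tuple[float, float], tau0: float) -> Tuple[float, float]:
    """One generator with a shared hidden layer and two output heads.

    Returns (first temperature for the student, second temperature for the teacher)
    from a single forward pass over the initial temperature tau0.
    """
    hidden = [math.tanh(w * tau0 + b) for w, b in zip(w1, b1)]
    tau_student = bounded(sum(w * h for w, h in zip(w_student, hidden)) + b_out[0])
    tau_teacher = bounded(sum(w * h for w, h in zip(w_teacher, hidden)) + b_out[1])
    return tau_student, tau_teacher
```

Sharing the hidden layer means one forward pass produces both coefficients, which matches the speed argument made above.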
In one possible implementation, the meta-parameter includes a first meta-parameter corresponding to the student network and a second meta-parameter corresponding to the teacher network;
the taking the updated meta-parameters as the network parameters of the parameter generation network to obtain the updated parameter generation network includes:
taking the updated first meta-parameter as the network parameters of a parameter generation network to obtain an updated first parameter generation network; and taking the updated second meta-parameter as the network parameters of a parameter generation network to obtain an updated second parameter generation network;
the re-determining the temperature coefficient based on the updated parameter generation network and the initial temperature coefficient includes:
determining the first temperature coefficient corresponding to the student network based on the updated first parameter generation network and the initial temperature coefficient; and determining the second temperature coefficient corresponding to the teacher network based on the updated second parameter generation network and the initial temperature coefficient.
By adopting the method, the temperature coefficients of the teacher network and the student network can be respectively generated through the two parameter generation networks, and the accuracy of the generated temperature coefficients is improved.
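With two generators, the first and second meta-parameters are simply two independent weight sets fed through the same kind of network. A Python sketch under the same hypothetical architecture and bounds as a single small generator (all names are illustrative):

```python
import math
from typing import Dict, Sequence, Tuple

TAU_MIN, TAU_MAX = 1.0, 10.0  # hypothetical bounds

Meta = Tuple[Sequence[float], Sequence[float], Sequence[float], float]  # (w1, b1, w2, b2)


def generate_temperature(meta: Meta, tau0: float) -> float:
    """One small generator network; its weights are one set of meta-parameters."""
    w1, b1, w2, b2 = meta
    hidden = [math.tanh(w * tau0 + b) for w, b in zip(w1, b1)]
    out = sum(w * h for w, h in zip(w2, hidden)) + b2
    s = 1.0 / (1.0 + math.exp(-out)) if out >= 0 else math.exp(out) / (1.0 + math.exp(out))
    return TAU_MIN + (TAU_MAX - TAU_MIN) * s


def generate_pair(meta_student: Meta, meta_teacher: Meta, tau0: float) -> Dict[str, float]:
    """Two independent generators: first meta-parameters -> student, second -> teacher."""
    return {
        "student": generate_temperature(meta_student, tau0),
        "teacher": generate_temperature(meta_teacher, tau0),
    }
```

Compared with a single two-headed generator, this doubles the number of meta-parameters but lets each network's temperature be shaped without interference from the other.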
In a second aspect, an embodiment of the present disclosure provides an image classification method, including:
acquiring an image to be detected;
and recognizing the image to be detected with the student network trained by the method for training a neural network for classification according to the first aspect or any one of its possible embodiments, to obtain a classification result corresponding to the image to be detected.
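At inference time the trained student is used as an ordinary classifier. A minimal Python sketch, assuming the student is any callable that maps a flattened image to class logits and that no distillation temperature is applied at inference (equivalent to a temperature of 1); `classify` and `toy_student` are hypothetical names.

```python
import math
from typing import Callable, List, Sequence, Tuple


def classify(image: Sequence[float],
             student: Callable[[Sequence[float]], List[float]],
             class_names: Sequence[str]) -> Tuple[str, float]:
    """Run the trained student on a (flattened) image and return (label, confidence)."""
    logits = student(image)
    peak = max(logits)
    exps = [math.exp(z - peak) for z in logits]  # plain softmax, no distillation temperature
    total = sum(exps)
    probs = [e / total for e in exps]
    best = max(range(len(probs)), key=probs.__getitem__)
    return class_names[best], probs[best]


# Hypothetical stand-in for a trained student: a fixed linear scorer over 4 pixels.
def toy_student(pixels: Sequence[float]) -> List[float]:
    return [sum(pixels), -sum(pixels), pixels[0] - pixels[-1]]


label, confidence = classify([0.9, 0.8, 0.7, 0.6], toy_student, ["cat", "dog", "bird"])
```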
In a third aspect, an embodiment of the present disclosure further provides a training apparatus for a neural network for classification, including:
the acquisition module is used for acquiring sample data and inputting the sample data to a teacher network and a student network to be trained respectively to obtain a first prediction result and a second prediction result;
a determining module for determining a pre-updated first network parameter based on the first prediction result, the second prediction result and a current network parameter of the student network;
and the updating module is used for updating the temperature coefficients corresponding to the student network and the teacher network based on the pre-updated first network parameters, and updating the current network parameters of the student network based on the updated temperature coefficients to obtain an updated student network.
In a possible embodiment, the determining module, when determining the pre-updated first network parameter based on the first prediction result, the second prediction result and the current network parameter of the student network, is configured to:
acquiring current network parameters of the student network;
determining a training loss based on the first prediction result and the second prediction result;
and adjusting the acquired current network parameters of the student network based on the training loss to determine the pre-updated first network parameter.
In a possible implementation manner, the updating module, when updating the temperature coefficients corresponding to the student network and the teacher network based on the pre-updated first network parameter, is configured to:
updating meta-parameters associated with the training of the student network based on the pre-updated first network parameter, the meta-parameters being used to determine the temperature coefficients corresponding to the student network and the teacher network;
and updating the temperature coefficients based on the updated meta-parameters.
In one possible embodiment, the sample data includes training sample data and verification sample data;
the acquisition module, when inputting the sample data into the teacher network and the student network to be trained to obtain the first prediction result and the second prediction result, is configured to:
respectively inputting the training sample data to a teacher network and a student network to be trained to obtain a first prediction result and a second prediction result;
the updating module, when updating the meta-parameters associated with the student network training based on the pre-updated first network parameters, is configured to:
taking the pre-updated first network parameter as the current network parameter of the student network to obtain a pre-updated student network;
inputting the verification sample data into the pre-updated student network to obtain a third prediction result;
and updating the meta-parameters based on the third prediction result and the annotation information corresponding to the verification sample data.
In a possible implementation manner, when updating the meta-parameters based on the third prediction result and the annotation information corresponding to the verification sample data, the updating module is configured to:
determining a verification loss based on the third prediction result and the annotation information corresponding to the verification sample data;
updating the meta-parameter based on the verification loss.
In a possible implementation manner, when determining the verification loss based on the third prediction result and the annotation information corresponding to the verification sample data, the updating module is configured to:
determining, based on the third prediction result and the annotation information corresponding to the verification sample data, error sample data, namely the verification samples that the pre-updated student network predicts incorrectly;
and determining the verification loss based on the confidence information, in the third prediction result, of the predictions corresponding to the error sample data and the preset confidence information corresponding to the annotation information of the verification sample data.
In a possible embodiment, when determining the verification loss based on the confidence information, in the third prediction result, of the predictions corresponding to the error sample data and the preset confidence information corresponding to the annotation information of the verification sample data, the updating module is configured to:
for each piece of error sample data, calculating the sum of squared differences between the confidence of each classification result in the third prediction result of that sample and the preset confidence of the same classification result in the annotation information of that sample;
and taking the total of these sums of squares over all error sample data as the verification loss.
In a possible embodiment, the update module, when updating the current network parameters of the student network based on the updated temperature coefficients, is configured to:
inputting the sample data into the student network and determining a fourth prediction result using the student network and the updated temperature coefficient; inputting the sample data into the teacher network and determining a fifth prediction result using the teacher network and the updated temperature coefficient;
re-determining a training loss based on the fourth prediction result and the fifth prediction result, and updating a current network parameter of the student network based on the re-determined training loss.
In one possible embodiment, the meta-parameters include network parameters of a parameter generation network associated with the student network training;
the updating module, when updating the temperature coefficient based on the updated meta-parameter, is configured to:
taking the updated meta-parameters as the network parameters of the parameter generation network to obtain an updated parameter generation network;
and re-determining the temperature coefficient based on the updated parameter generation network and an initial temperature coefficient.
In one possible implementation, the temperature coefficient includes a first temperature coefficient corresponding to the student network and a second temperature coefficient corresponding to the teacher network;
the updating module, when re-determining the temperature coefficient based on the updated parameter generation network and the initial temperature coefficient, is configured to:
determining the first temperature coefficient and the second temperature coefficient based on the updated parameter generation network and the initial temperature coefficient.
In one possible implementation, the meta-parameter includes a first meta-parameter corresponding to the student network and a second meta-parameter corresponding to the teacher network;
the updating module, when taking the updated meta-parameters as the network parameters of the parameter generation network to obtain the updated parameter generation network, is configured to:
taking the updated first meta-parameter as the network parameters of a parameter generation network to obtain an updated first parameter generation network; and taking the updated second meta-parameter as the network parameters of a parameter generation network to obtain an updated second parameter generation network;
the updating module, when re-determining the temperature coefficient based on the updated parameter generation network and the initial temperature coefficient, is configured to:
determining the first temperature coefficient corresponding to the student network based on the updated first parameter generation network and the initial temperature coefficient; and determining the second temperature coefficient corresponding to the teacher network based on the updated second parameter generation network and the initial temperature coefficient.
In a fourth aspect, an embodiment of the present disclosure further provides an image classification device, including:
the second acquisition module is used for acquiring an image to be detected;
and the recognition module is used for recognizing the image to be detected based on the student network obtained by training the neural network training method for classification according to the first aspect or any one of the possible implementation manners of the first aspect, so as to obtain a classification result corresponding to the image to be detected.
In a fifth aspect, an embodiment of the present disclosure further provides a computer device, including: a processor, a memory and a bus, the memory storing machine-readable instructions executable by the processor, the processor and the memory communicating via the bus when the computer device is running, the machine-readable instructions, when executed by the processor, performing the steps of the first aspect, or any one of the possible implementations of the first aspect, or performing the steps of the second aspect, or any one of the possible implementations of the second aspect.
In a sixth aspect, an embodiment of the present disclosure further provides a computer-readable storage medium storing a computer program which, when executed by a processor, performs the steps of the first aspect or any one of its possible embodiments, or the steps of the second aspect or any one of its possible embodiments.
For descriptions of the training apparatus for a neural network for classification, the image classification apparatus, the computer device, and the computer-readable storage medium, refer to the above descriptions of the training method for a neural network for classification and the image classification method; they are not repeated here.
It is to be understood that both the foregoing general description and the following detailed description are exemplary and explanatory only and are not restrictive of the technical aspects of the disclosure.
In order to make the aforementioned objects, features and advantages of the present disclosure more comprehensible, preferred embodiments accompanied with figures are described in detail below.
Drawings
The accompanying drawings, which are incorporated in and constitute a part of this specification, illustrate embodiments consistent with the present disclosure and, together with the description, serve to explain the principles of the disclosure.
In order to illustrate the technical solutions of the embodiments of the present disclosure more clearly, the drawings required in the embodiments are briefly described below. These drawings, which are incorporated in and form a part of the specification, illustrate embodiments consistent with the present disclosure and, together with the description, serve to explain the technical solutions of the present disclosure. It should be appreciated that the following drawings depict only certain embodiments of the disclosure and are therefore not to be considered limiting of its scope; those skilled in the art can derive other related drawings from them without inventive effort.
FIG. 1 illustrates a flow chart of a method of training a neural network for classification provided by an embodiment of the present disclosure;
FIG. 2 is a schematic diagram illustrating an architecture of a training apparatus for a neural network for classification according to an embodiment of the present disclosure;
FIG. 3 is a schematic diagram illustrating an architecture of an image classification apparatus provided in an embodiment of the present disclosure;
FIG. 4 is a schematic structural diagram of a computer device provided by an embodiment of the present disclosure;
FIG. 5 shows a schematic structural diagram of another computer device provided by an embodiment of the present disclosure.
Detailed Description
In order to make the objects, technical solutions and advantages of the embodiments of the present disclosure clearer, the technical solutions of the embodiments of the present disclosure will be described clearly and completely with reference to the drawings in the embodiments of the present disclosure. Obviously, the described embodiments are only some, not all, of the embodiments of the present disclosure. The components of the embodiments of the present disclosure, as generally described and illustrated in the figures herein, can be arranged and designed in a wide variety of configurations. Thus, the following detailed description of the embodiments of the present disclosure, presented in the figures, is not intended to limit the scope of the claimed disclosure, but is merely representative of selected embodiments of the disclosure. All other embodiments that a person skilled in the art can derive from the embodiments of the disclosure without creative effort shall fall within the protection scope of the disclosure.
In the related art, knowledge distillation is generally performed by feeding the same input data to the teacher network and the student network and training the student network with the output of the teacher network as supervision data.
Research shows that the temperature coefficient strongly affects network accuracy: even when model accuracy is already high, adjusting the temperature coefficient may lower or raise it, so choosing a suitable temperature coefficient can improve model accuracy. In the related art, however, the student network and the teacher network generally use fixed temperature coefficients.
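Why the temperature matters can be seen directly: dividing the logits by a temperature before the softmax controls how soft the resulting distribution is, and therefore how much of the teacher's information about non-target classes reaches the student. The short Python illustration below uses arbitrary example logits and temperatures, chosen for illustration only; it shows the entropy of the softened distribution rising as the temperature grows.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float], temperature: float) -> List[float]:
    scaled = [z / temperature for z in logits]
    peak = max(scaled)  # numerical stability
    exps = [math.exp(z - peak) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(p: Sequence[float]) -> float:
    """Shannon entropy in nats; higher means a softer distribution."""
    return -sum(pi * math.log(pi) for pi in p if pi > 0)


logits = [4.0, 1.5, 0.5]
temperatures = (0.5, 1.0, 4.0, 16.0)
entropies = [entropy(softmax(logits, t)) for t in temperatures]
for t, h in zip(temperatures, entropies):
    print(f"T={t:>4}: entropy={h:.3f} nats")
```

A fixed temperature commits to one point on this curve for the whole of training, whereas the approach described here lets the point move.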
Based on this, the present disclosure provides a training method and apparatus for a neural network for classification, a computer device, and a storage medium, which determine a pre-updated first network parameter from the first prediction result and the second prediction result and adjust the temperature coefficients through the first network parameter. The temperature coefficients thereby become trainable parameters during student network training: the temperature coefficients of the teacher network and the student network are adjusted automatically, optimal temperature coefficients are determined for both networks, and the network accuracy of the teacher network and the student network is further improved.
It should be noted that like reference numbers and letters refer to like items in the following figures; therefore, once an item is defined in one figure, it need not be further defined or explained in subsequent figures.
The term "and/or" herein merely describes an associative relationship and indicates that three relationships may exist; for example, "A and/or B" may mean that A exists alone, that A and B both exist, or that B exists alone. In addition, the term "at least one" herein means any one of a plurality of items or any combination of at least two of them; for example, "including at least one of A, B, and C" may mean including any one or more elements selected from the set consisting of A, B, and C.
To facilitate understanding of the present embodiment, the training method for a neural network for classification disclosed in the embodiments of the present disclosure is first described in detail. The execution subject of this training method is generally a computer device with a certain computing capability, for example a terminal device, a server, or another processing device; the terminal device may be User Equipment (UE), a mobile device, a user terminal, a cellular phone, a cordless phone, a Personal Digital Assistant (PDA), a handheld device, a computing device, a vehicle-mounted device, a wearable device, and so on. In some possible implementations, the methods for training a neural network for classification and for image classification may be implemented by a processor invoking computer-readable instructions stored in a memory.
Referring to fig. 1, which shows a flowchart of a training method for a neural network for classification provided in an embodiment of the present disclosure, the method includes steps 101 to 103, where:
Step 101: obtain sample data, and input the sample data to a teacher network and to a student network to be trained, respectively, to obtain a first prediction result and a second prediction result;
Step 102: determine a pre-updated first network parameter based on the first prediction result, the second prediction result, and the current network parameter of the student network;
Step 103: update the temperature coefficients corresponding to the student network and the teacher network based on the pre-updated first network parameter, and update the current network parameter of the student network based on the updated temperature coefficients to obtain an updated student network.
The following is a detailed description of the above steps:
For step 101,
In one possible embodiment, the sample data may include unprocessed original samples and data-enhanced samples.
The data enhancement methods may include Label Smoothing Regularization (LSR), mixed data enhancement (Mixup), cropping data enhancement (CutMix), image flipping, image rotation, image cropping, image deformation, image scaling, noise addition, image blurring, color transformation, and the like. Training the student network with data-enhanced samples improves its ability to recognize such enhanced samples and thereby improves its network accuracy.
In one possible implementation, the student network may be obtained by compressing the teacher network, and the student network and the teacher network have the same inference target.
In one possible embodiment, the sample data includes training sample data and verification sample data. In this case, when the sample data is input to the teacher network and to the student network to be trained to obtain the first prediction result and the second prediction result, it is the training sample data that is input to the two networks.
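As an illustration of two of the listed techniques, the following minimal Python sketch implements label smoothing and Mixup on plain lists. It is a sketch under stated assumptions, not the disclosure's implementation: the function names, the smoothing factor `eps`, and the Beta parameter `alpha` are hypothetical choices, and images are represented as flattened lists of floats.

```python
import random
from typing import List, Optional, Tuple

Vector = List[float]


def label_smooth(one_hot: Vector, eps: float = 0.1) -> Vector:
    """Label smoothing: keep (1 - eps) of the target mass and spread eps uniformly over K classes."""
    k = len(one_hot)
    return [(1.0 - eps) * y + eps / k for y in one_hot]


def mixup(x1: Vector, y1: Vector, x2: Vector, y2: Vector,
          lam: Optional[float] = None, alpha: float = 0.2) -> Tuple[Vector, Vector]:
    """Mixup: blend two samples and their label vectors with the same weight lam."""
    if lam is None:
        lam = random.betavariate(alpha, alpha)  # lam ~ Beta(alpha, alpha), always in [0, 1]
    x = [lam * a + (1.0 - lam) * b for a, b in zip(x1, x2)]
    y = [lam * a + (1.0 - lam) * b for a, b in zip(y1, y2)]
    return x, y


# Usage: a flattened 2-pixel "image" of class 0 mixed with one of class 1.
x, y = mixup([0.0, 0.0], [1.0, 0.0], [2.0, 4.0], [0.0, 1.0], lam=0.25)
print(x, y)  # [1.5, 3.0] [0.25, 0.75]
```

The same `lam` is applied to the image and to its label vector, which is why `mixup` returns both.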
The first prediction result may include a first prediction value and a first confidence corresponding to the first prediction value; the second prediction result may include a second prediction value and a second confidence corresponding to the second prediction value.
Specifically, to determine the first prediction result, the training sample data may be input into the teacher network, which outputs the first prediction value and the corresponding first confidence; to determine the second prediction result, the training sample data may be input into the student network to be trained, which outputs the second prediction value and the corresponding second confidence.
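To make "prediction value" and "confidence" concrete, the sketch below assumes that each network outputs raw scores (logits) normalized with a softmax; the prediction value is then the highest-scoring class and the confidence is its probability. The helper names and the example scores are hypothetical.

```python
import math
from typing import List, Sequence, Tuple


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax: subtract the max logit before exponentiating."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predict(logits: Sequence[float], class_names: Sequence[str]) -> Tuple[str, float]:
    """Return (prediction value, confidence): the top class and its softmax probability."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return class_names[best], probs[best]


label, conf = predict([0.2, 2.5, -1.0], ["cat", "dog", "bird"])
print(label, round(conf, 3))
```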
With respect to step 102,
In a possible implementation manner, when determining the pre-updated first network parameter based on the first prediction result, the second prediction result, and the current network parameter of the student network, the current network parameter of the student network may be obtained first, then the training loss may be determined based on the first prediction result and the second prediction result, and finally the obtained current network parameter of the student network may be adjusted based on the training loss to determine the pre-updated first network parameter.
Here, what is adjusted is the obtained copy of the student network's current network parameters; the network parameters inside the student network itself are not adjusted directly. The pre-updated first network parameters are used to update the meta-parameters, so that the meta-parameters corresponding to the student network can be learned through distillation from the teacher network.
To determine the training loss based on the first prediction result and the second prediction result, the cross-entropy loss between the first prediction result and the second prediction result may be calculated, for example.
When adjusting the obtained current network parameters based on the training loss, a gradient descent method may be used, for example with the following formula:
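A minimal sketch of such a distillation cross-entropy follows, assuming temperature-scaled softmax outputs and treating the teacher's softened distribution as the target. The separate temperatures `tau_t` and `tau_s` mirror the teacher and student temperature coefficients discussed below; the function names are illustrative, and the sketch does not guard against extreme logits underflowing to a zero probability.

```python
import math
from typing import List, Sequence


def softmax_t(logits: Sequence[float], tau: float) -> List[float]:
    """Temperature-scaled softmax: divide logits by tau (> 0), then normalize."""
    m = max(logits)
    exps = [math.exp((z - m) / tau) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def distill_ce(teacher_logits: Sequence[float], student_logits: Sequence[float],
               tau_t: float = 1.0, tau_s: float = 1.0) -> float:
    """Cross-entropy H(p_t, p_s) = -sum_j p_t[j] * log(p_s[j]).

    p_t is the teacher's softened output (first prediction result), used as the
    soft target; p_s is the student's softened output (second prediction result).
    """
    p_t = softmax_t(teacher_logits, tau_t)
    p_s = softmax_t(student_logits, tau_s)
    return -sum(pt * math.log(ps) for pt, ps in zip(p_t, p_s))


print(distill_ce([1.0, 2.0, 0.0], [0.5, 1.5, 0.2], tau_t=2.0, tau_s=2.0))
```

The loss is smallest, and equal to the entropy of the teacher's distribution, when the student's softened output matches the teacher's.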
$$\theta_s' = \theta_s - \alpha \nabla_{\theta_s} L_t(\theta_s, \phi)$$
where $\theta_s'$ is the pre-updated first network parameter, $\theta_s$ is the current network parameter, $\alpha$ is the learning rate (i.e., the step size), $L_t$ is the training loss, and $L_t(\theta_s, \phi)$ is the training loss as a function of the parameter $\theta_s$ and the meta-parameter $\phi$. The meta-parameter is used to determine the temperature coefficients corresponding to the student network and the teacher network; its initial value is a preset value, and it may be generated through the parameter generation network in step 104. The meta-parameters are described below and are not detailed here.
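The pre-update can be sketched as a single gradient-descent step that returns a new parameter vector and leaves the student's own parameters untouched. In this sketch the meta-parameter $\phi$ is held fixed, so the training loss is treated as a function of $\theta_s$ only, and the gradient is estimated by central differences as a stand-in for backpropagation; the toy loss and all names are hypothetical.

```python
from typing import Callable, List

Params = List[float]


def numerical_grad(loss: Callable[[Params], float], theta: Params, h: float = 1e-6) -> Params:
    """Central-difference estimate of dL/dtheta (stand-in for backpropagation)."""
    grad = []
    for k in range(len(theta)):
        up, down = list(theta), list(theta)
        up[k] += h
        down[k] -= h
        grad.append((loss(up) - loss(down)) / (2 * h))
    return grad


def pre_update(theta: Params, loss: Callable[[Params], float], alpha: float) -> Params:
    """theta_s' = theta_s - alpha * grad L_t(theta_s); returns a new list, theta is untouched."""
    g = numerical_grad(loss, theta)
    return [t - alpha * gk for t, gk in zip(theta, g)]


def toy_loss(theta: Params) -> float:
    """Hypothetical training loss with its minimum at theta = [1, 1, ...]."""
    return sum((t - 1.0) ** 2 for t in theta)


theta_s = [3.0, -1.0]
theta_s_prime = pre_update(theta_s, toy_loss, alpha=0.25)
print(theta_s, theta_s_prime)  # theta_s unchanged; theta_s' ≈ [2.0, 0.0]
```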
For step 103,
The temperature coefficient controls the softness of the output of a neural network (which may be the student network or the teacher network). Softness characterizes how much the confidences of the multiple prediction values output by the network differ from one another: the higher the temperature coefficient, the smaller the differences between these confidences; the lower the temperature coefficient, the larger the differences.
For example, after a sample is input into the neural network, the network may output three prediction values, "cat", "dog", and "bird". With a temperature coefficient of 1 their confidences might be 1%, 98%, and 1%, respectively; with the usual temperature-scaled softmax and a temperature coefficient of 10, the same outputs give about 28%, 44%, and 28%. The ranking is unchanged, but the confidences are much closer together.
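The effect of the temperature coefficient can be checked with a standard temperature-scaled softmax, $p_j = \exp(z_j/\tau) / \sum_k \exp(z_k/\tau)$. This sketch assumes the networks use this usual formulation; the logits are made up so that $\tau = 1$ gives confidences of 1%, 98%, and 1%. Dividing by $\tau$ never changes which class scores highest; it only narrows or widens the gaps.

```python
import math
from typing import List, Sequence


def softmax_t(logits: Sequence[float], tau: float) -> List[float]:
    """Temperature-scaled softmax; a larger tau gives a flatter (softer) distribution."""
    m = max(logits)
    exps = [math.exp((z - m) / tau) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Illustrative logits for ("cat", "dog", "bird"), chosen so that tau = 1 gives 1%, 98%, 1%.
logits = [0.0, math.log(98.0), 0.0]
for tau in (1.0, 10.0):
    print(tau, [round(p, 4) for p in softmax_t(logits, tau)])
# 1.0 [0.01, 0.98, 0.01]
# 10.0 [0.2792, 0.4416, 0.2792]
```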
In a possible implementation manner, when the temperature coefficients corresponding to the student network and the teacher network are updated based on the pre-updated first network parameter, the first network parameter may be directly used as a current network parameter of the student network to obtain a pre-updated student network, then verification sample data is input into the pre-updated student network and the teacher network to obtain a sixth prediction result and a seventh prediction result, and then the temperature coefficient is adjusted based on the sixth prediction result and the seventh prediction result.
Here, the temperature coefficient may, for example, be adjusted directly by gradient descent based on the sixth prediction result and the seventh prediction result.
Because directly adjusting the temperature coefficient may hurt the network accuracy of the student network if the adjustment is too large or too small, in one possible embodiment, when the temperature coefficients corresponding to the student network and the teacher network are updated based on the pre-updated first network parameter, the meta-parameters associated with the student network training may first be updated based on the pre-updated first network parameter, where the meta-parameters are used to determine the temperature coefficients corresponding to the student network and the teacher network; the temperature coefficients are then updated based on the updated meta-parameters.
Here, the meta-parameters are parameters learned in the meta-learning process. Specifically, the meta-parameters include the network parameters of a parameter generation network used to calculate the temperature coefficient. The present disclosure therefore has two learning objectives: the first is the classification capability of the student network, and the second is the capability of the parameter generation network to generate temperature coefficients, whose purpose is to better improve the classification capability of the student network.
In a possible implementation manner, to update the meta-parameters associated with the student network training based on the pre-updated first network parameters, the pre-updated first network parameters may be used as the current network parameters of the student network to obtain a pre-updated student network; the verification sample data is then input to the pre-updated student network to obtain a third prediction result; finally, the meta-parameters are updated based on the third prediction result and the labeling information corresponding to the verification sample data.
With this method, the meta-parameters are updated automatically, so that the temperature coefficients subsequently calculated from them better improve the performance of the student network and the teacher network.
Specifically, continuing the above example, to use the first network parameter as the current network parameter of the student network, $\theta_s$ in the student network may be updated to $\theta_s'$ to obtain the pre-updated student network. After the verification sample data is input to the pre-updated student network, the pre-updated student network outputs the third prediction result, which includes a third prediction value and a corresponding third confidence indicating the probability that the verification sample data belongs to the third prediction value.
It should be noted that the pre-updated student network is distinct from the student network and is obtained only in order to update the temperature coefficient. The pre-updated student network can be understood as a copy of the student network whose network parameters are set to the first network parameter: the network parameter of the pre-updated student network is $\theta_s'$, while the current network parameter of the student network remains $\theta_s$.
Further, in a possible implementation manner, when the meta-parameter is updated based on the third prediction result and the label information corresponding to the verification sample data, a verification loss may be determined based on the third prediction result and the label information corresponding to the verification sample data, and then the meta-parameter is updated based on the verification loss.
Wherein the labeling information may be a correct category of the sample data that is manually labeled.
In a possible implementation manner, to determine the verification loss based on the third prediction result and the labeling information corresponding to the verification sample data, the error samples, i.e., the verification samples that the pre-updated student network predicts incorrectly, may first be determined from the third prediction result and the labeling information; the verification loss is then determined from the confidence information of the prediction results for the error samples in the third prediction result and the preset confidence information corresponding to the labeling information of the verification sample data.
Updating the meta-parameters based on the error samples speeds up the meta-parameter update, and the temperature coefficient determined from these meta-parameters better improves the classification accuracy of the student network and the teacher network on error samples.
The specific implementation includes the following steps:
Step 1: To determine the error samples, the third prediction value in the third prediction result is compared with the labeling information; if they differ, the third prediction result is an erroneous prediction result, and the sample data corresponding to it is error sample data.
Step 2: To determine the verification loss from the confidence information of the prediction results for the error samples and the preset confidence information corresponding to the labeling information, for each error sample, calculate the sum of squared differences between the confidence of each classification result in its third prediction result and the preset confidence of that classification result in its labeling information; then take the sum of these per-sample sums of squares over all error samples as the verification loss.
Illustratively, the calculation may be made by the following formula:
$$L_v(\theta_s) = \sum_{i \in G} \sum_{j=1}^{c} \left( P_s^{(ij)} - y^{(ij)} \right)^2$$
where $L_v$ is the verification loss, $\theta_s$ is the current network parameter of the student network, $i$ is the number of an error sample, $G$ is the set of numbers of the error samples, $c$ is the maximum data category number of the verification sample data (i.e., the number of categories), and $j$ denotes the $j$-th data category. $P_s$ is the confidence information of the prediction results corresponding to the error samples, with $P_s^{(ij)}$ the probability that the prediction result for the $i$-th error sample is the $j$-th data category. $y$ is the confidence information of the labeling information, with $y^{(ij)}$ the probability that the label of the $i$-th error sample is the $j$-th data category; that is, $y^{(ij)} = 1$ when the label is the $j$-th data category and $y^{(ij)} = 0$ otherwise.
Here, the data category number indicates the sequence number of a data category. Each data category corresponds to one data category number; for example, if the student network handles three data categories "cat", "dog", and "bird", their data category numbers are "1", "2", and "3", respectively.
In this way, the verification loss accounts for the difference between the confidence of each classification result in the third prediction result of an error sample and the preset confidence of the corresponding classification result in its labeling information, and squaring these differences before summing weights large deviations more heavily than small ones. Training with this verification loss therefore focuses on improving the student network's classification of error samples, so that the trained student network achieves higher accuracy.
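The verification loss can be sketched directly from its definition: only samples whose top prediction disagrees with the label (the set $G$) contribute, and each contributes the squared distance between its confidence vector and the one-hot label. This is a minimal sketch; the function name and list-based inputs are assumptions, and classes are indexed from 0 rather than from 1.

```python
from typing import Sequence


def verification_loss(probs: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """L_v = sum over misclassified samples i, sum over classes j, of (P_s^(ij) - y^(ij))^2.

    probs[i][j] is the pre-updated student's confidence that sample i is class j;
    labels[i] is the annotated class index (0-based here; the text numbers classes from 1).
    """
    loss = 0.0
    for p, label in zip(probs, labels):
        predicted = max(range(len(p)), key=lambda j: p[j])
        if predicted == label:
            continue  # correctly classified samples are not in G and add nothing
        for j, p_ij in enumerate(p):
            y_ij = 1.0 if j == label else 0.0  # one-hot label confidence
            loss += (p_ij - y_ij) ** 2
    return loss


# Sample 0 is correct; sample 1 is predicted as class 1 but labeled class 0.
print(verification_loss([[0.7, 0.2, 0.1], [0.2, 0.5, 0.3]], [0, 0]))  # ≈ 0.98
```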
In a possible embodiment, when updating the meta-parameter based on the verification loss, a new meta-parameter may be calculated first by the following formula:
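$$\phi' = \phi - \beta \nabla_{\phi} L_v(\theta_s')$$
where $\phi'$ is the updated meta-parameter, $\phi$ is the meta-parameter to be updated (i.e., the meta-parameter before updating), $\beta$ is the learning rate (i.e., the step size), $L_v$ is the verification loss, $\theta_s'$ is the first network parameter, and $L_v(\theta_s')$ denotes the verification loss under the first network parameter. Because $\theta_s'$ is obtained from $\theta_s$ with the training loss $L_t(\theta_s, \phi)$, $L_v(\theta_s')$ depends on $\phi$ through $\theta_s'$, which is what makes this gradient well defined.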
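Because $\theta_s'$ is itself computed from a loss that depends on $\phi$, the meta-gradient has to pass through the pre-update step. The toy sketch below makes this visible with hypothetical quadratic stand-ins, $L_t = (\theta - \phi)^2$ and $L_v = (\theta - 3)^2$, in place of the real distillation and verification losses, and estimates the derivative with respect to $\phi$ by central differences rather than automatic differentiation.

```python
def inner_step(theta: float, phi: float, alpha: float) -> float:
    """Pre-update theta' = theta - alpha * dL_t/dtheta for the toy loss L_t = (theta - phi)^2."""
    return theta - alpha * 2.0 * (theta - phi)


def val_loss(theta: float) -> float:
    """Toy verification loss, minimized at theta = 3 (hypothetical)."""
    return (theta - 3.0) ** 2


def meta_update(theta: float, phi: float, alpha: float, beta: float, h: float = 1e-5) -> float:
    """phi' = phi - beta * dL_v(theta'(phi))/dphi, with the derivative taken through the pre-update."""
    def outer(p: float) -> float:
        return val_loss(inner_step(theta, p, alpha))

    grad_phi = (outer(phi + h) - outer(phi - h)) / (2.0 * h)  # central difference
    return phi - beta * grad_phi


print(meta_update(theta=0.0, phi=1.0, alpha=0.25, beta=0.1))  # ≈ 1.25
```

Here $\theta' = 0.5$ and $dL_v/d\phi = 2(\theta' - 3) \cdot 2\alpha = -2.5$, so $\phi' = 1 + 0.1 \times 2.5 = 1.25$.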
In a possible implementation manner, the meta-parameters include network parameters of a parameter generation network associated with the student network training, and when the temperature coefficient is re-determined based on the updated meta-parameters, the updated meta-parameters may be used as the network parameters of the parameter generation network to obtain an updated parameter generation network, and then the temperature coefficient is re-determined based on the updated parameter generation network and the initial temperature coefficient.
The parameter generation network is a pre-trained neural network, for example a multi-layer perceptron (MLP).
In one possible implementation, the meta-parameters include a first meta-parameter corresponding to the student network and a second meta-parameter corresponding to the teacher network. When the updated meta-parameters are used as the network parameters of the parameter generation network to obtain an updated parameter generation network, the updated first meta-parameter may be used as the network parameters of a parameter generation network to obtain an updated first parameter generation network, and the updated second meta-parameter may be used as the network parameters of a parameter generation network to obtain an updated second parameter generation network.
When the temperature coefficient is re-determined based on the updated parameter generation network and the initial temperature coefficient, a first temperature coefficient corresponding to the student network may be determined based on the updated first parameter generation network and the initial temperature coefficient, and a second temperature coefficient corresponding to the teacher network may be determined based on the updated second parameter generation network and the initial temperature coefficient.
The first parameter generation network and the second parameter generation network may be Multi-Layer perceptrons (MLPs) with different network parameters.
With this method, the parameter generation network is trained by training the meta-parameters. Re-determining the temperature coefficient with the updated parameter generation network improves the accuracy of the generated temperature coefficient and avoids the accuracy loss caused by adjusting the temperature coefficient directly by too large or too small an amount.
In a possible implementation, the meta-parameters further include a learnable embedding, which is used to characterize image noise. Specifically, after the updated parameter generation network is obtained, the embedding in the meta-parameters may be input into the parameter generation network, which extracts features from the embedding and outputs parameter values used to calculate the temperature coefficients.
Specifically, the embedding may be input to the first parameter generation network and to the second parameter generation network; the first parameter generation network outputs a first parameter value and the second parameter generation network outputs a second parameter value. The first temperature coefficient is then calculated from the first parameter value and the initial temperature coefficient, and the second temperature coefficient from the second parameter value and the initial temperature coefficient.
In another possible implementation, the temperature coefficients include a first temperature coefficient corresponding to the student network and a second temperature coefficient corresponding to the teacher network, and the student network and the teacher network may share the same initial temperature coefficient. When the temperature coefficient is re-determined based on the updated parameter generation network and the initial temperature coefficient, both the first temperature coefficient and the second temperature coefficient may be determined from the updated parameter generation network and the initial temperature coefficient.
Here, it should be noted that, in the iterative training process of the student neural network, the initial temperature coefficients applied each time the first temperature coefficient and the second temperature coefficient are re-determined are the same and are preset values, and the initial temperature coefficients may be regarded as preset hyper-parameters.
Specifically, to determine the first temperature coefficient and the second temperature coefficient based on the updated parameter generation network and the initial temperature coefficient, the embedding may be input to the parameter generation network, which outputs a first parameter value and a second parameter value; the first temperature coefficient and the second temperature coefficient are then calculated from the first parameter value, the second parameter value, and the initial temperature coefficient.
For example, when determining the first temperature coefficient and the second temperature coefficient, the following formula may be used for calculation:
$$\{\tau_s, \tau_t\} = \tau_{init} + \sigma(\mathrm{MLP}(e)) - 0.5$$
where $\tau_s$ is the first temperature coefficient, $\tau_t$ is the second temperature coefficient, $\tau_{init}$ is the initial temperature coefficient (a preset value), $\sigma$ is the Sigmoid function, MLP is the multi-layer perceptron, $e$ is the embedding, and $\mathrm{MLP}(e)$ is the parameter value output by the parameter generation network based on the embedding.
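A minimal sketch of this temperature generator follows, assuming a two-layer perceptron with a ReLU hidden layer whose two outputs are used for the student and the teacher, respectively. The layer sizes, the activation, the default `tau_init = 4.0`, and all function names are assumptions. Because the sigmoid lies in (0, 1), each generated temperature stays within 0.5 of `tau_init`.

```python
import math
from typing import List, Sequence, Tuple

Matrix = List[List[float]]


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def mlp(e: Sequence[float], w1: Matrix, b1: Sequence[float],
        w2: Matrix, b2: Sequence[float]) -> List[float]:
    """Two-layer perceptron with a ReLU hidden layer; returns one raw output per row of w2."""
    hidden = [max(0.0, sum(w * x for w, x in zip(row, e)) + b) for row, b in zip(w1, b1)]
    return [sum(w * h for w, h in zip(row, hidden)) + b for row, b in zip(w2, b2)]


def temperatures(e: Sequence[float], w1: Matrix, b1: Sequence[float],
                 w2: Matrix, b2: Sequence[float], tau_init: float = 4.0) -> Tuple[float, float]:
    """{tau_s, tau_t} = tau_init + sigmoid(MLP(e)) - 0.5; w2 must have exactly two rows."""
    out_s, out_t = mlp(e, w1, b1, w2, b2)
    return tau_init + sigmoid(out_s) - 0.5, tau_init + sigmoid(out_t) - 0.5


# With all-zero weights the MLP outputs 0, sigmoid(0) = 0.5, and both temperatures equal tau_init.
zeros = [[0.0, 0.0], [0.0, 0.0]]
print(temperatures([0.3, -0.2], zeros, [0.0, 0.0], zeros, [0.0, 0.0]))  # (4.0, 4.0)
```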
In one possible embodiment, to update the current network parameters of the student network based on the re-determined temperature coefficients, the sample data may first be input to the student network to determine a fourth prediction result using the student network and the re-determined temperature coefficient, and input to the teacher network to determine a fifth prediction result using the teacher network and the re-determined temperature coefficient; a training loss is then re-determined based on the fourth prediction result and the fifth prediction result, and the current network parameters of the student network are updated based on the re-determined training loss.
By adopting the method, the temperature coefficients of the student network and the teacher network can be automatically updated, and the student network is trained again based on the updated temperature coefficients, so that the student network can intelligently update the current network parameters, and the performance of the student network is improved.
The sample data may be the training sample data, and the fourth prediction result includes a fourth prediction value and a fourth confidence corresponding to the fourth prediction value, where the fourth confidence is used to indicate a probability that the sample data belongs to the fourth prediction value; the fifth prediction result comprises a fifth prediction value and a fifth confidence corresponding to the fifth prediction value, and the fifth confidence is used for representing the probability that the sample data belongs to the fifth prediction value.
For example, the training loss may be re-determined by calculating the cross-entropy loss between the fourth prediction result and the fifth prediction result.
Then, when updating the current network parameters of the student network based on the re-determined training loss, a target network parameter for updating may be calculated based on the re-determined training loss and the current network parameters, and then the current network parameters are updated to the target network parameters.
For example, the target network parameter may be calculated by gradient descent, specifically with the following formula:
$$\theta_s'' = \theta_s - \alpha \nabla_{\theta_s} L_t(\theta_s, \phi')$$
where $\theta_s''$ is the target network parameter, $\theta_s$ is the current network parameter, $\alpha$ is the learning rate (i.e., the step size), $L_t$ is the re-determined training loss, and $L_t(\theta_s, \phi')$ is the cross-entropy loss as a function of the parameter $\theta_s$ and the updated meta-parameter $\phi'$, which enters through the re-determined temperature coefficients.
It should be noted that the network parameter of the pre-updated student network is $\theta_s'$, whereas the current network parameter of the student network used here is the un-updated network parameter, i.e., $\theta_s$.
With respect to step 104,
The training cutoff condition may be, for example, that the number of training iterations exceeds a preset number, that the difference between the accuracy of the first prediction result and the accuracy of the second prediction result is smaller than a preset difference, or that the training loss of the first prediction result is smaller than a preset loss value.
Specifically, after the current network parameter is updated, the method returns to step 101 and executes steps 101 to 104 again to update the current network parameter the next time; after multiple updates, a student network with highly accurate output is obtained.
It should be noted that after returning to step 101, the current network parameter of the student network to be trained is the network parameter obtained in the last update.
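The example cut-off conditions can be combined into a single check, as in the hedged sketch below. The parameter names and thresholds are hypothetical, and combining the conditions with "or" (stop as soon as any one holds) is an assumption, since the text lists them only as alternatives.

```python
def should_stop(num_iters: int, max_iters: int,
                teacher_acc: float, student_acc: float, max_acc_gap: float,
                train_loss: float, min_loss: float) -> bool:
    """Return True if any of the example cut-off conditions holds."""
    return (num_iters > max_iters                                # too many iterations
            or abs(teacher_acc - student_acc) < max_acc_gap   # accuracy gap small enough
            or train_loss < min_loss)                          # training loss low enough


# The student is within 0.005 of the teacher's accuracy, so training would stop here.
print(should_stop(5, 100, teacher_acc=0.90, student_acc=0.895, max_acc_gap=0.01,
                  train_loss=1.0, min_loss=0.1))  # True
```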
In summary, the above-mentioned updating process of the network parameters of the student network can be summarized as follows:
1. The current network parameter $\theta_s$ of the student network is pre-updated to the first network parameter $\theta_s'$.
2. Based on the pre-updated first network parameter $\theta_s'$, the meta-parameter is updated from $\phi$ to $\phi'$.
3. Based on the updated meta-parameter $\phi'$, the current network parameter $\theta_s$ of the student network is updated to $\theta_s''$.
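The three-step order can be sketched end to end on a toy problem. In the illustrative sketch below, a single scalar stands in for each of $\theta_s$ and $\phi$, hypothetical quadratic losses replace the distillation and verification losses, and derivatives are estimated by central differences. Step 3 restarts from the un-updated $\theta_s$ rather than from $\theta_s'$, since the pre-updated parameters serve only to update the meta-parameter.

```python
from typing import Callable, Tuple


def deriv(f: Callable[[float], float], x: float, h: float = 1e-5) -> float:
    """Central-difference derivative, used here as a stand-in for autograd."""
    return (f(x + h) - f(x - h)) / (2.0 * h)


def train_step(theta: float, phi: float,
               train_loss: Callable[[float, float], float],
               val_loss: Callable[[float], float],
               alpha: float, beta: float) -> Tuple[float, float]:
    # Step 1: pre-update theta' = theta - alpha * dL_t(theta, phi)/dtheta (theta itself is kept).
    def pre_update(p: float) -> float:
        return theta - alpha * deriv(lambda t: train_loss(t, p), theta)

    # Step 2: phi' = phi - beta * dL_v(theta'(phi))/dphi, differentiating through step 1.
    phi_new = phi - beta * deriv(lambda p: val_loss(pre_update(p)), phi)

    # Step 3: real update from the un-updated theta using the new meta-parameter phi'.
    theta_new = theta - alpha * deriv(lambda t: train_loss(t, phi_new), theta)
    return theta_new, phi_new


def toy_train_loss(theta: float, phi: float) -> float:
    """Hypothetical stand-in for L_t: pulls the student toward phi."""
    return (theta - phi) ** 2


def toy_val_loss(theta: float) -> float:
    """Hypothetical stand-in for L_v: minimized at theta = 3."""
    return (theta - 3.0) ** 2


theta, phi = 0.0, 1.0
for _ in range(200):  # a fixed iteration budget serves as the cut-off condition
    theta, phi = train_step(theta, phi, toy_train_loss, toy_val_loss, alpha=0.25, beta=0.5)
print(round(theta, 4), round(phi, 4))  # 3.0 3.0
```

In this toy, the meta-parameter drifts until the pre-updated student does well on the verification loss, after which the real updates settle at the same point.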
For the specific updating process, reference is made to the description of the above embodiments, which will not be described herein again.
With the training method for a neural network for classification described above, the pre-updated first network parameter can be determined from the first prediction result and the second prediction result obtained in pre-training, the meta-parameters are updated based on the first network parameter, and the temperature coefficients are then adjusted based on the meta-parameters. By treating the temperature coefficient as a training parameter during student network training, the method can automatically adjust the temperature coefficients of the teacher network and the student network and determine the optimal temperature coefficients for both, which further improves the network accuracy of the teacher network and the student network.
It will be understood by those skilled in the art that, in the method of the present invention, the order in which the steps are written does not imply a strict order of execution or impose any limitation on the implementation; the specific order of execution of the steps should be determined by their functions and possible inherent logic.
Based on the same inventive concept, the embodiments of the present disclosure further provide an image classification method, which includes the following steps:
Step A: obtain an image to be detected;
Step B: recognize the image to be detected with the student network obtained by the above embodiment, to obtain a classification result corresponding to the image to be detected.
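The image classification method reduces to a forward pass through the trained student followed by picking the top-scoring class. The sketch below assumes the student is any callable that maps an image to per-class scores; `dummy_student` is a hypothetical placeholder for the trained network. No temperature is applied because dividing scores by a positive temperature does not change which class is highest.

```python
from typing import Callable, List, Sequence

Image = List[List[float]]


def classify(image: Image, student: Callable[[Image], Sequence[float]],
             class_names: Sequence[str]) -> str:
    """Step A: take the image; Step B: run the trained student and return the top class."""
    scores = student(image)
    best = max(range(len(scores)), key=lambda j: scores[j])
    return class_names[best]


def dummy_student(image: Image) -> List[float]:
    """Hypothetical stand-in for a trained student: scores classes by mean brightness."""
    mean = sum(sum(row) for row in image) / sum(len(row) for row in image)
    return [1.0 - mean, mean, 0.5 * mean]


print(classify([[0.9, 0.8], [0.7, 1.0]], dummy_student, ["cat", "dog", "bird"]))  # dog
```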
Based on the same inventive concept, the embodiments of the present disclosure further provide apparatuses for training a neural network for classification and for image classification, corresponding to the methods above. Since these apparatuses solve the problem on a principle similar to that of the methods described above, their implementation may refer to the implementation of the methods, and repeated details are omitted.
Referring to fig. 2, which shows an architecture diagram of a training apparatus for a neural network for classification according to an embodiment of the present disclosure, the apparatus includes a first obtaining module 201, a determining module 202, and an updating module 203, where:
the first obtaining module 201 is configured to obtain sample data, and input the sample data to a teacher network and a student network to be trained respectively to obtain a first prediction result and a second prediction result;
a determining module 202, configured to determine a pre-updated first network parameter based on the first prediction result, the second prediction result, and a current network parameter of the student network;
and the updating module 203 is configured to update the temperature coefficients corresponding to the student network and the teacher network based on the pre-updated first network parameter, and update the current network parameter of the student network based on the updated temperature coefficient to obtain an updated student network.
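The module structure can be pictured as three pluggable components called in sequence. The dataclass below is a hypothetical sketch, not the disclosure's apparatus: each module is an injected callable, and the types are left as `Any` because the text does not fix a parameter representation.

```python
from dataclasses import dataclass
from typing import Any, Callable, Tuple


@dataclass
class TrainingApparatus:
    """Hypothetical composition of modules 201 to 203, each injected as a callable."""

    # Module 201: (sample, student params) -> (first prediction, second prediction).
    first_obtaining: Callable[[Any, Any], Tuple[Any, Any]]
    # Module 202: (first, second, current params) -> pre-updated first network parameter.
    determining: Callable[[Any, Any, Any], Any]
    # Module 203: (pre-updated params, current params) -> updated student params.
    updating: Callable[[Any, Any], Any]

    def step(self, sample: Any, params: Any) -> Any:
        """One training iteration: run module 201, then 202, then 203."""
        first, second = self.first_obtaining(sample, params)
        theta_prime = self.determining(first, second, params)
        return self.updating(theta_prime, params)


# Usage with toy scalar stand-ins for each module (all hypothetical):
app = TrainingApparatus(
    first_obtaining=lambda s, p: (2.0 * s, p * s),
    determining=lambda first, second, p: p - 0.1 * (second - first),
    updating=lambda theta_prime, p: 0.5 * (theta_prime + p),
)
print(app.step(1.0, 1.0))  # 1.05
```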
In a possible implementation, the determining module 202, when determining the pre-updated first network parameter based on the first prediction result, the second prediction result, and the current network parameter of the student network, is configured to:
acquiring current network parameters of the student network;
determining a training loss based on the first prediction result and the second prediction result;
and adjusting the acquired current network parameters of the student network based on the training loss to determine the pre-updated first network parameters.
In a possible implementation manner, the updating module 203, when updating the temperature coefficients corresponding to the student network and the teacher network based on the pre-updated first network parameter, is configured to:
updating meta-parameters associated with the student network training based on the pre-updated first network parameters, where the meta-parameters are used to determine the temperature coefficients corresponding to the student network and the teacher network;
and updating the temperature coefficients based on the updated meta-parameters.
In one possible embodiment, the sample data includes training sample data and verification sample data;
the first obtaining module 201, when the sample data is input to a teacher network and a student network to be trained respectively to obtain a first prediction result and a second prediction result, is configured to:
respectively inputting the training sample data to a teacher network and a student network to be trained to obtain a first prediction result and a second prediction result;
the updating module 203, when updating the meta-parameters associated with the student network training based on the pre-updated first network parameters, is configured to:
taking the pre-updated first network parameter as the current network parameter of the student network to obtain a pre-updated student network;
inputting the verification sample data into the pre-updated student network to obtain a third prediction result;
and updating the meta-parameters based on the third prediction result and the label information corresponding to the verification sample data.
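The substitution step can be pictured with the following sketch, which assumes a linear student and a temperature-1 softmax for the confidences on verification data (both assumptions, not stated in the disclosure); `third_prediction` is a hypothetical name. Comparing the argmax of each returned confidence vector with the label information identifies the misclassified samples.

```python
import math

def softmax(logits):
    """Plain softmax (temperature 1), max-subtracted for numerical stability."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def third_prediction(pre_updated_weights, verification_inputs):
    """Forward the verification samples through the pre-updated student network.

    The pre-updated first network parameters simply replace the current ones for
    this forward pass; the current parameters are not touched.
    """
    results = []
    for x in verification_inputs:
        logits = [sum(w * xi for w, xi in zip(row, x)) for row in pre_updated_weights]
        results.append(softmax(logits))
    return results
```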
In a possible implementation manner, the updating module 203, when updating the meta-parameters based on the third prediction result and the label information corresponding to the verification sample data, is configured to:
determining a verification loss based on the third prediction result and the label information corresponding to the verification sample data;
updating the meta-parameter based on the verification loss.
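A minimal sketch of one meta-parameter update, assuming plain gradient descent on the verification loss. Because the verification loss depends on the meta-parameters only through the temperatures used in the pre-update, a framework implementation would typically back-propagate through that pre-update; here central finite differences are swapped in so the sketch stays dependency-free. `verification_loss_fn` is a hypothetical callable that reruns the pre-update for given meta-parameters and returns the verification loss.

```python
def update_meta_parameters(meta_params, verification_loss_fn, meta_lr, eps=1e-4):
    """One gradient-descent step on the meta-parameters.

    The gradient of verification_loss_fn is estimated by central finite differences,
    one coordinate at a time.
    """
    grads = []
    for i in range(len(meta_params)):
        up, down = list(meta_params), list(meta_params)
        up[i] += eps
        down[i] -= eps
        grads.append((verification_loss_fn(up) - verification_loss_fn(down)) / (2 * eps))
    return [p - meta_lr * g for p, g in zip(meta_params, grads)]
```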
In a possible implementation manner, the updating module 203, when determining a verification loss based on the third prediction result and the label information corresponding to the verification sample data, is configured to:
determining, based on the third prediction result and the label information corresponding to the verification sample data, error sample data, that is, the verification samples that the pre-updated student network predicts incorrectly;
and determining the verification loss based on the confidence information that the third prediction result assigns to the error sample data and the preset confidence information corresponding to the label information of the verification sample data.
In a possible implementation manner, the updating module 203, when determining the verification loss based on the confidence information that the third prediction result assigns to the error sample data and the preset confidence information corresponding to the label information of the verification sample data, is configured to:
for each piece of error sample data, calculating the sum of squares of the differences between the confidence of each classification result in the third prediction result of that sample and the preset confidence of the same classification result in the label information of that sample;
and taking the total of these sums of squares over all error sample data as the verification loss.
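The verification loss described above can be written as below, assuming that the preset confidence is 1.0 for the labeled class and 0.0 for every other class (a one-hot target; the disclosure does not state the preset values) and that the predicted class is the argmax of the confidences. The function name is hypothetical.

```python
def verification_loss(confidences, labels):
    """Sum of squared errors to one-hot targets, over misclassified samples only.

    confidences: per-sample confidence vectors (the third prediction result).
    labels: index of the labeled class for each sample.
    """
    loss = 0.0
    for conf, label in zip(confidences, labels):
        predicted = max(range(len(conf)), key=conf.__getitem__)
        if predicted == label:
            continue  # correctly classified, so not error sample data
        preset = [1.0 if k == label else 0.0 for k in range(len(conf))]  # assumed one-hot
        loss += sum((c - p) ** 2 for c, p in zip(conf, preset))
    return loss
```

For example, with confidences `[[0.7, 0.3], [0.6, 0.4]]` and labels `[0, 1]`, only the second sample is misclassified, giving 0.6² + 0.6² = 0.72.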
In a possible embodiment, the updating module 203, when updating the current network parameter of the student network based on the updated temperature coefficient, is configured to:
inputting the sample data to the student network to determine a fourth prediction result through the updated temperature coefficient and the student network; inputting the sample data into the teacher network to determine a fifth prediction result through the teacher network and the updated temperature coefficient;
re-determining a training loss based on the fourth prediction result and the fifth prediction result, and updating a current network parameter of the student network based on the re-determined training loss.
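A sketch of the re-determined training loss, assuming the same KL-divergence form as in the pre-update (an assumption) and hypothetical helper names. The real update would then take an ordinary gradient step on this loss starting from the current, not the pre-updated, network parameters. The `entropy` helper only shows how a higher temperature flattens the fifth prediction result.

```python
import math

def softmax(logits, temperature):
    """Temperature-scaled softmax (max-subtracted for numerical stability)."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def redetermined_training_loss(student_logits, teacher_logits, t_student, t_teacher):
    """Fourth/fifth prediction results under the updated temperatures, and their KL loss."""
    fourth = softmax(student_logits, t_student)  # student network + updated student temperature
    fifth = softmax(teacher_logits, t_teacher)   # teacher network + updated teacher temperature
    loss = sum(p * (math.log(p) - math.log(q)) for p, q in zip(fifth, fourth) if p > 0)
    return fourth, fifth, loss

def entropy(p):
    """Shannon entropy in nats; larger means a flatter distribution."""
    return -sum(pi * math.log(pi) for pi in p if pi > 0)
```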
In one possible embodiment, the meta-parameters include network parameters of a parameter generation network associated with the student network training;
the updating module 203, when updating the temperature coefficient based on the updated meta-parameter, is configured to:
taking the updated meta-parameters as the network parameters of the parameter generation network to obtain an updated parameter generation network;
and re-determining the temperature coefficient based on the updated parameter generation network and an initial temperature coefficient.
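One possible shape for the parameter generation network, offered only as an illustration since the disclosure does not specify its architecture. This sketch assumes a single-input MLP with one tanh hidden layer whose weights and biases are the meta-parameters, mapping the initial temperature coefficient to a new one; a softplus output keeps the temperature positive. All names are hypothetical.

```python
import math

def softplus(v):
    """log(1 + e^v), arranged to avoid overflow for large |v|."""
    return max(v, 0.0) + math.log1p(math.exp(-abs(v)))

def generate_temperature(meta_params, initial_temperature):
    """Hypothetical generation network: scalar in -> tanh hidden layer -> softplus out.

    meta_params = (w1, b1, w2, b2): hidden weights, hidden biases and output weights
    (lists of equal length) plus an output bias (float).
    """
    w1, b1, w2, b2 = meta_params
    hidden = [math.tanh(w * initial_temperature + b) for w, b in zip(w1, b1)]
    return softplus(sum(w * h for w, h in zip(w2, hidden)) + b2)
```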
In one possible implementation, the temperature coefficient includes a first temperature coefficient corresponding to the student network and a second temperature coefficient corresponding to the teacher network;
the updating module 203, when re-determining the temperature coefficient based on the updated parameter generation network and the initial temperature coefficient, is configured to:
determining the first temperature coefficient and the second temperature coefficient based on the updated parameter generation network and the initial temperature coefficient.
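If a single parameter generation network determines both coefficients, one way to do it (an assumption for illustration) is a shared hidden layer with two output heads, the first giving the student's first temperature coefficient and the second the teacher's second temperature coefficient. Names and architecture are hypothetical.

```python
import math

def softplus(v):
    """log(1 + e^v), arranged to avoid overflow for large |v|."""
    return max(v, 0.0) + math.log1p(math.exp(-abs(v)))

def generate_temperature_pair(meta_params, initial_temperature):
    """Hypothetical shared generation network with two output heads.

    meta_params = (w1, b1, w_out, b_out): w_out holds two weight rows and b_out two
    biases; head 0 gives the student temperature, head 1 the teacher temperature.
    """
    w1, b1, w_out, b_out = meta_params
    hidden = [math.tanh(w * initial_temperature + b) for w, b in zip(w1, b1)]
    t_student, t_teacher = (
        softplus(sum(w * h for w, h in zip(row, hidden)) + b) for row, b in zip(w_out, b_out)
    )
    return t_student, t_teacher
```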
In one possible implementation, the meta-parameter includes a first meta-parameter corresponding to the student network and a second meta-parameter corresponding to the teacher network;
the updating module 203, when taking the updated meta-parameters as the network parameters of the parameter generation network to obtain the updated parameter generation network, is configured to:
taking the updated first meta-parameter as the network parameters of the parameter generation network to obtain an updated first parameter generation network, and taking the updated second meta-parameter as the network parameters of the parameter generation network to obtain an updated second parameter generation network;
the updating module 203, when re-determining the temperature coefficient based on the updated parameter generation network and the initial temperature coefficient, is configured to:
determining a first temperature coefficient corresponding to the student network based on the updated first parameter generation network and the initial temperature coefficient, and determining a second temperature coefficient corresponding to the teacher network based on the updated second parameter generation network and the initial temperature coefficient.
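In the two-network variant, the two generators can share an architecture and differ only in their meta-parameters. The sketch below assumes the same hypothetical tanh/softplus MLP for both; the first meta-parameters affect only the student temperature and the second only the teacher temperature.

```python
import math

def softplus(v):
    """log(1 + e^v), arranged to avoid overflow for large |v|."""
    return max(v, 0.0) + math.log1p(math.exp(-abs(v)))

def generation_network(meta_params, initial_temperature):
    """Hypothetical one-hidden-layer generator; meta_params = (w1, b1, w2, b2)."""
    w1, b1, w2, b2 = meta_params
    hidden = [math.tanh(w * initial_temperature + b) for w, b in zip(w1, b1)]
    return softplus(sum(w * h for w, h in zip(w2, hidden)) + b2)

def determine_temperatures(first_meta, second_meta, initial_temperature):
    """Same architecture, two independent parameter sets, one per network."""
    t_student = generation_network(first_meta, initial_temperature)   # first temperature
    t_teacher = generation_network(second_meta, initial_temperature)  # second temperature
    return t_student, t_teacher
```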
Referring to fig. 3, which shows an architecture diagram of an image classification apparatus provided in an embodiment of the present disclosure, the apparatus includes a second obtaining module 301 and an identifying module 302, wherein:
the second obtaining module 301 is configured to acquire an image to be detected;
the identifying module 302 is configured to identify the image to be detected by using the student network trained by the training method for a neural network for classification described in the above embodiments, so as to obtain a classification result corresponding to the image to be detected.
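At inference time only the trained student network is used. A minimal sketch, assuming the image has already been turned into a feature vector and the student is linear (both hypothetical simplifications): the classification result is the class with the largest logit. Since dividing logits by a positive temperature does not change their order, the distillation temperatures play no role here.

```python
def classify(student_weights, features, class_names):
    """Return the name of the class whose student logit is largest.

    `features` stands in for the preprocessed image to be detected (assumed).
    """
    logits = [sum(w * f for w, f in zip(row, features)) for row in student_weights]
    best = max(range(len(logits)), key=logits.__getitem__)
    return class_names[best]
```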
The description of the processing flow of each module in the device and the interaction flow between the modules may refer to the related description in the above method embodiments, and will not be described in detail here.
Based on the same technical concept, an embodiment of the present disclosure further provides a computer device. Referring to fig. 4, which is a schematic structural diagram of a computer device 400 provided in an embodiment of the present disclosure, the computer device 400 includes a processor 401, a memory 402 and a bus 403. The memory 402 is configured to store execution instructions and includes a memory 4021 (also referred to as an internal memory) and an external memory 4022. The internal memory 4021 is configured to temporarily store operation data of the processor 401 and data exchanged with the external memory 4022, such as a hard disk; the processor 401 exchanges data with the external memory 4022 through the internal memory 4021. When the computer device 400 runs, the processor 401 communicates with the memory 402 through the bus 403, so that the processor 401 executes the following instructions:
acquiring sample data, and respectively inputting the sample data to a teacher network and a student network to be trained to obtain a first prediction result and a second prediction result;
determining a pre-updated first network parameter based on the first prediction result, the second prediction result, and a current network parameter of the student network;
and updating the temperature coefficients corresponding to the student network and the teacher network based on the pre-updated first network parameters, and updating the current network parameters of the student network based on the updated temperature coefficients to obtain an updated student network.
Based on the same technical concept, an embodiment of the present disclosure further provides a computer device. Referring to fig. 5, which is a schematic structural diagram of a computer device 500 provided in an embodiment of the present disclosure, the computer device 500 includes a processor 501, a memory 502 and a bus 503. The memory 502 is configured to store execution instructions and includes a memory 5021 (also referred to as an internal memory) and an external memory 5022. The internal memory 5021 is configured to temporarily store operation data of the processor 501 and data exchanged with the external memory 5022, such as a hard disk; the processor 501 exchanges data with the external memory 5022 through the internal memory 5021. When the computer device 500 runs, the processor 501 communicates with the memory 502 through the bus 503, so that the processor 501 executes the following instructions:
acquiring an image to be detected;
identifying the image to be detected by using the student network trained by the training method for a neural network for classification, to obtain a classification result corresponding to the image to be detected.
An embodiment of the present disclosure further provides a computer-readable storage medium storing a computer program which, when executed by a processor, performs the steps of the training method for a neural network for classification or of the image classification method described in the above method embodiments. The storage medium may be a volatile or non-volatile computer-readable storage medium.
An embodiment of the present disclosure further provides a computer program product carrying program code, the instructions of which may be used to execute the steps of the training method for a neural network for classification or of the image classification method described in the above method embodiments; for details, reference may be made to the above method embodiments, which are not repeated here.
The computer program product may be implemented by hardware, software or a combination thereof. In an optional embodiment, the computer program product is embodied as a computer storage medium; in another optional embodiment, it is embodied as a software product, such as a Software Development Kit (SDK).
It can be clearly understood by those skilled in the art that, for convenience and simplicity of description, the specific working process of the system and the apparatus described above may refer to the corresponding process in the foregoing method embodiment, and details are not described herein again. In the several embodiments provided in the present disclosure, it should be understood that the disclosed system, apparatus, and method may be implemented in other ways. The above-described embodiments of the apparatus are merely illustrative, and for example, the division of the units is only one logical division, and there may be other divisions when actually implemented, and for example, a plurality of units or components may be combined or integrated into another system, or some features may be omitted, or not executed. In addition, the shown or discussed mutual coupling or direct coupling or communication connection may be an indirect coupling or communication connection of devices or units through some communication interfaces, and may be in an electrical, mechanical or other form.
The units described as separate parts may or may not be physically separate, and parts displayed as units may or may not be physical units, may be located in one place, or may be distributed on a plurality of network units. Some or all of the units can be selected according to actual needs to achieve the purpose of the solution of the embodiment.
In addition, functional units in the embodiments of the present disclosure may be integrated into one processing unit, or each unit may exist alone physically, or two or more units are integrated into one unit.
The functions, if implemented in the form of software functional units and sold or used as a stand-alone product, may be stored in a non-volatile computer-readable storage medium executable by a processor. Based on such understanding, the technical solution of the present disclosure may be embodied in the form of a software product, which is stored in a storage medium and includes several instructions for causing a computer device (which may be a personal computer, a server, or a network device) to execute all or part of the steps of the method according to the embodiments of the present disclosure. The aforementioned storage medium includes various media capable of storing program code, such as a USB flash drive, a removable hard disk, a Read-Only Memory (ROM), a Random Access Memory (RAM), a magnetic disk, or an optical disk.
Finally, it should be noted that the above embodiments are merely specific embodiments of the present disclosure, intended to illustrate rather than limit its technical solutions, and the protection scope of the present disclosure is not limited thereto. Although the present disclosure has been described in detail with reference to the foregoing embodiments, those skilled in the art should understand that anyone familiar with the art may, within the technical scope disclosed herein, still modify the technical solutions described in the foregoing embodiments, readily conceive of changes to them, or make equivalent substitutions for some of their technical features; such modifications, changes or substitutions do not cause the corresponding technical solutions to depart from the spirit and scope of the technical solutions of the embodiments of the present disclosure, and shall all be covered within the protection scope of the present disclosure. Therefore, the protection scope of the present disclosure shall be subject to the protection scope of the claims.
If the technical solution of the present application involves personal information, a product applying the technical solution clearly informs individuals of the personal information processing rules before processing personal information and obtains their separate consent. If the technical solution involves sensitive personal information, a product applying it obtains the individual's separate consent before processing the sensitive personal information and also satisfies the requirement of 'express consent'. For example, a clear and conspicuous sign is placed at a personal information collection device such as a camera, informing people that they are entering the personal information collection range and that personal information will be collected; a person who voluntarily enters the collection range is regarded as consenting to the collection of their personal information. Alternatively, on a device that processes personal information, where the personal information processing rules are communicated by conspicuous signs or messages, personal authorization is obtained by means such as a pop-up window or by asking the person to upload their personal information themselves. The personal information processing rules may include information such as the personal information processor, the purpose of processing, the processing method, and the types of personal information processed.

Claims (16)

1. A method of training a neural network for classification, comprising:
acquiring sample data, and respectively inputting the sample data to a teacher network and a student network to be trained to obtain a first prediction result and a second prediction result;
determining a pre-updated first network parameter based on the first prediction result, the second prediction result, and a current network parameter of the student network;
and updating the temperature coefficients corresponding to the student network and the teacher network based on the pre-updated first network parameter, and updating the current network parameter of the student network based on the updated temperature coefficient to obtain an updated student network.
2. The method of claim 1, wherein determining the pre-updated first network parameter based on the first prediction result, the second prediction result, and a current network parameter of the student network comprises:
acquiring current network parameters of the student network;
determining a training loss based on the first prediction result and the second prediction result;
and adjusting the acquired current network parameters of the student network based on the training loss to determine the pre-updated first network parameter.
3. The method according to claim 1 or 2, wherein the updating the temperature coefficients corresponding to the student network and the teacher network based on the pre-updated first network parameters comprises:
updating meta-parameters associated with the training of the student network based on the pre-updated first network parameter, the meta-parameters being used to determine the temperature coefficients corresponding to the student network and the teacher network;
and updating the temperature coefficients based on the updated meta-parameters.
4. The method of claim 3, wherein the sample data comprises training sample data and verification sample data;
the inputting the sample data into a teacher network and a student network to be trained, respectively, to obtain a first prediction result and a second prediction result comprises:
respectively inputting the training sample data to a teacher network and a student network to be trained to obtain a first prediction result and a second prediction result;
the updating the meta-parameters associated with the student network training based on the pre-updated first network parameters comprises:
taking the pre-updated first network parameter as the current network parameter of the student network to obtain a pre-updated student network;
inputting the verification sample data into the pre-updated student network to obtain a third prediction result;
and updating the meta-parameters based on the third prediction result and the label information corresponding to the verification sample data.
5. The method according to claim 4, wherein the updating the meta-parameter based on the third prediction result and the label information corresponding to the verification sample data comprises:
determining a verification loss based on the third prediction result and the label information corresponding to the verification sample data;
updating the meta-parameter based on the verification loss.
6. The method according to claim 5, wherein the determining a verification loss based on the third prediction result and the label information corresponding to the verification sample data comprises:
determining, based on the third prediction result and the label information corresponding to the verification sample data, error sample data, that is, the verification samples that the pre-updated student network predicts incorrectly;
and determining the verification loss based on the confidence information that the third prediction result assigns to the error sample data and the preset confidence information corresponding to the label information of the verification sample data.
7. The method according to claim 6, wherein the determining the verification loss based on the confidence information that the third prediction result assigns to the error sample data and the preset confidence information corresponding to the label information of the verification sample data comprises:
for each piece of error sample data, calculating the sum of squares of the differences between the confidence of each classification result in the third prediction result of that sample and the preset confidence of the same classification result in the label information of that sample;
and taking the total of these sums of squares over all error sample data as the verification loss.
8. The method according to any one of claims 3 to 7, wherein the updating the current network parameters of the student network based on the updated temperature coefficients comprises:
inputting the sample data to the student network to determine a fourth prediction result by the updated temperature coefficient and the student network; inputting the sample data into the teacher network to determine a fifth prediction result through the teacher network and the updated temperature coefficient;
re-determining a training loss based on the fourth prediction result and the fifth prediction result, and updating a current network parameter of the student network based on the re-determined training loss.
9. The method according to any one of claims 4 to 8, wherein the meta-parameters comprise network parameters of a parameter generation network associated with the student network training;
the updating the temperature coefficients based on the updated meta-parameters comprises:
taking the updated meta-parameters as the network parameters of the parameter generation network to obtain an updated parameter generation network;
and re-determining the temperature coefficients based on the updated parameter generation network and an initial temperature coefficient.
10. The method of claim 9, wherein the temperature coefficients comprise a first temperature coefficient corresponding to the student network and a second temperature coefficient corresponding to the teacher network;
the re-determining the temperature coefficients based on the updated parameter generation network and the initial temperature coefficient comprises:
determining the first temperature coefficient and the second temperature coefficient based on the updated parameter generation network and the initial temperature coefficient.
11. The method of claim 9 or 10, wherein the meta-parameters comprise a first meta-parameter corresponding to the student network and a second meta-parameter corresponding to the teacher network;
the taking the updated meta-parameters as the network parameters of the parameter generation network to obtain the updated parameter generation network comprises:
taking the updated first meta-parameter as the network parameters of the parameter generation network to obtain an updated first parameter generation network, and taking the updated second meta-parameter as the network parameters of the parameter generation network to obtain an updated second parameter generation network;
the re-determining the temperature coefficients based on the updated parameter generation network and the initial temperature coefficient comprises:
determining a first temperature coefficient corresponding to the student network based on the updated first parameter generation network and the initial temperature coefficient, and determining a second temperature coefficient corresponding to the teacher network based on the updated second parameter generation network and the initial temperature coefficient.
12. An image classification method, comprising:
acquiring an image to be detected;
identifying the image to be detected by using the student network trained by the training method for a neural network for classification according to any one of claims 1 to 11, to obtain a classification result corresponding to the image to be detected.
13. An apparatus for training a neural network for classification, comprising:
the first acquisition module is used for acquiring sample data and inputting the sample data into a teacher network and a student network to be trained respectively to obtain a first prediction result and a second prediction result;
a determining module for determining a pre-updated first network parameter based on the first prediction result, the second prediction result and a current network parameter of the student network;
and the updating module is used for updating the temperature coefficients corresponding to the student network and the teacher network based on the pre-updated first network parameters, and updating the current network parameters of the student network based on the updated temperature coefficients to obtain an updated student network.
14. An image classification apparatus, comprising:
the second acquisition module is used for acquiring an image to be detected;
an identification module, configured to identify the image to be detected by using the student network trained by the training method for a neural network for classification according to any one of claims 1 to 11, so as to obtain a classification result corresponding to the image to be detected.
15. A computer device, comprising a processor, a memory and a bus, the memory storing machine-readable instructions executable by the processor, the processor and the memory communicating over the bus when the computer device runs, and the machine-readable instructions, when executed by the processor, performing the steps of the training method for a neural network for classification according to any one of claims 1 to 11 or the steps of the image classification method according to claim 12.
16. A computer-readable storage medium, characterized in that a computer program is stored thereon, which computer program, when being executed by a processor, performs the steps of the method for training a neural network for classification as claimed in any one of claims 1 to 11, or performs the steps of the method for image classification as claimed in claim 12.
CN202210209488.3A 2022-03-04 2022-03-04 Neural network training and image classification method and device for classification Pending CN114565797A (en)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202210209488.3A CN114565797A (en) 2022-03-04 2022-03-04 Neural network training and image classification method and device for classification

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202210209488.3A CN114565797A (en) 2022-03-04 2022-03-04 Neural network training and image classification method and device for classification

Publications (1)

Publication Number Publication Date
CN114565797A (en) 2022-05-31

Family

ID=81718398

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202210209488.3A Pending CN114565797A (en) 2022-03-04 2022-03-04 Neural network training and image classification method and device for classification

Country Status (1)

Country Link
CN (1) CN114565797A (en)

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN114821204A (en) * 2022-06-30 2022-07-29 山东建筑大学 Meta-learning-based embedded semi-supervised learning image classification method and system

Similar Documents

Publication Publication Date Title
CN111091175A (en) Neural network model training method, neural network model classification method, neural network model training device and electronic equipment
CN110363084A (en) A kind of class state detection method, device, storage medium and electronics
CN110135505B (en) Image classification method and device, computer equipment and computer readable storage medium
CN111767883B (en) Question correction method and device
CN112329679B (en) Face recognition method, face recognition system, electronic equipment and storage medium
CN109726291B (en) Loss function optimization method and device of classification model and sample classification method
CN111694954B (en) Image classification method and device and electronic equipment
CN116229530A (en) Image processing method, device, storage medium and electronic equipment
WO2023123847A1 (en) Model training method and apparatus, image processing method and apparatus, and device, storage medium and computer program product
CN112819011A (en) Method and device for identifying relationships between objects and electronic system
CN112749737A (en) Image classification method and device, electronic equipment and storage medium
WO2022252527A1 (en) Neural network training method and apparatus, facial recognition method and apparatus, and device and storage medium
CN114565797A (en) Neural network training and image classification method and device for classification
CN110717407A (en) Human face recognition method, device and storage medium based on lip language password
CN113221695B (en) Method for training skin color recognition model, method for recognizing skin color and related device
CN114299304A (en) Image processing method and related equipment
CN115909336A (en) Text recognition method and device, computer equipment and computer-readable storage medium
CN113435531A (en) Zero sample image classification method and system, electronic equipment and storage medium
CN113011532A (en) Classification model training method and device, computing equipment and storage medium
CN113053395A (en) Pronunciation error correction learning method and device, storage medium and electronic equipment
CN113591892A (en) Training data processing method and device
CN109657710B (en) Data screening method and device, server and storage medium
CN115830618A (en) Text recognition method and device, computer equipment and storage medium
CN114970732A (en) Posterior calibration method and device for classification model, computer equipment and medium
CN112446428B (en) Image data processing method and device

Legal Events

Date Code Title Description
PB01 Publication
SE01 Entry into force of request for substantive examination