CN111967492A - Method and device for training classification model, electronic equipment and storage medium


Info

Publication number
CN111967492A
Authority
CN
China
Prior art keywords
classification model
loss function
type
obtaining
model
Prior art date
Legal status
Pending
Application number
CN202010606425.2A
Other languages
Chinese (zh)
Inventor
希滕
张刚
温圣召
Current Assignee
Beijing Baidu Netcom Science and Technology Co Ltd
Original Assignee
Beijing Baidu Netcom Science and Technology Co Ltd
Priority date
Filing date
Publication date
Application filed by Beijing Baidu Netcom Science and Technology Co Ltd
Priority to CN202010606425.2A
Publication of CN111967492A
Legal status: Pending

Classifications

    • G - PHYSICS
    • G06 - COMPUTING; CALCULATING OR COUNTING
    • G06F - ELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00 - Pattern recognition
    • G06F18/20 - Analysing
    • G06F18/21 - Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214 - Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • G06F18/24 - Classification techniques

Landscapes

  • Engineering & Computer Science (AREA)
  • Data Mining & Analysis (AREA)
  • Theoretical Computer Science (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Artificial Intelligence (AREA)
  • Evolutionary Biology (AREA)
  • Evolutionary Computation (AREA)
  • Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Information Retrieval, Db Structures And Fs Structures Therefor (AREA)
  • Image Analysis (AREA)

Abstract

The application discloses a method and an apparatus for training a classification model, an electronic device, and a storage medium, relating to the technical fields of artificial intelligence, deep learning, and image processing. The specific implementation scheme is as follows: a classification model is trained using a current sample set, and a first loss function and a second loss function of the classification model are obtained respectively, wherein the first loss function is a loss function with the type as the granularity, and the second loss function is a loss function with the sample as the granularity; a total loss function of the classification model is obtained according to the first loss function and the second loss function; and model parameters of the classification model are updated according to the total loss function so as to train the classification model. The method generates a total loss function from a type-granularity loss function and a sample-granularity loss function, comprehensively considering both class balance and the balance between hard and easy samples, and trains the classification model according to this total loss function, thereby accelerating the convergence of the classification model.

Description

Method and device for training classification model, electronic equipment and storage medium
Technical Field
The present application relates to the technical field of artificial intelligence, deep learning, and image processing in the field of computer technologies, and in particular, to a method and an apparatus for training a classification model, an electronic device, and a storage medium.
Background
The loss function is critical to the convergence of a model. At present, loss functions with the sample as the granularity are mostly adopted: when the accuracy for a certain sample is already very high, that sample is not back-propagated, or is back-propagated with only a small value. This approach can alleviate the imbalance between hard and easy samples, but the resulting model has poor applicability and low precision.
Disclosure of Invention
A training method and device for a classification model, an electronic device and a storage medium are provided.
According to a first aspect, there is provided a training method of a classification model, comprising: training a classification model by using a current sample set, and respectively obtaining a first loss function and a second loss function of the classification model, wherein the first loss function is a loss function with type as granularity, and the second loss function is a loss function with sample as granularity; obtaining a total loss function of the classification model according to the first loss function and the second loss function; and updating the model parameters of the classification model according to the total loss function so as to train the classification model.
According to a second aspect, there is provided a training apparatus for a classification model, comprising: a first obtaining module, configured to train a classification model using a current sample set and to obtain, respectively, a first loss function and a second loss function of the classification model, wherein the first loss function is a loss function with the type as the granularity, and the second loss function is a loss function with the sample as the granularity; a second obtaining module, configured to obtain a total loss function of the classification model according to the first loss function and the second loss function; and an updating module, configured to update the model parameters of the classification model according to the total loss function so as to train the classification model.
According to a third aspect, there is provided an electronic device comprising: at least one processor; and a memory communicatively coupled to the at least one processor; wherein the memory stores instructions executable by the at least one processor to enable the at least one processor to perform the method of training a classification model according to the first aspect of the present application.
According to a fourth aspect, there is provided a non-transitory computer readable storage medium having stored thereon computer instructions for causing a computer to perform the method of training a classification model according to the first aspect of the present application.
The embodiment provided by the application at least has the following beneficial technical effects:
according to the training method and device for the classification model, the electronic equipment and the storage medium, the total loss function can be generated according to the loss function with the type as the granularity and the loss function with the sample as the granularity, the problems of class balance and difficulty and easiness of the sample are comprehensively considered, the classification model can be trained according to the total loss function, the convergence rate of the classification model can be increased, the precision of the classification model is remarkably improved, and hardware benefits can be brought.
It should be understood that the statements in this section do not necessarily identify key or critical features of the embodiments of the present disclosure, nor do they limit the scope of the present disclosure. Other features of the present disclosure will become apparent from the following description.
Drawings
The drawings are included to provide a better understanding of the present solution and are not intended to limit the present application. Wherein:
FIG. 1 is a schematic flow chart diagram of a method for training a classification model according to a first embodiment of the present application;
FIG. 2 is a schematic flow chart of a first loss function and a second loss function of a classification model respectively obtained in a training method of the classification model according to a second embodiment of the present application;
FIG. 3 is a schematic flowchart of obtaining the model identification accuracy of the classification model corresponding to the current sample set in a method for training a classification model according to a third embodiment of the present application;
FIG. 4 is a schematic flowchart of obtaining a plurality of type identification accuracy rates of the classification model for different types of samples in a method for training a classification model according to a fourth embodiment of the present application;
FIG. 5 is a schematic flowchart of obtaining the first loss function according to the plurality of type identification accuracy rates and the model identification accuracy rate in a method for training a classification model according to a fifth embodiment of the present application;
FIG. 6 is a block diagram of a training apparatus for a classification model according to a first embodiment of the present application;
FIG. 7 is a block diagram of an electronic device for implementing a method for training a classification model according to an embodiment of the present application.
Detailed Description
The following description of exemplary embodiments of the present application, taken in conjunction with the accompanying drawings, includes various details of those embodiments to aid understanding, and these are to be considered exemplary only. Accordingly, those of ordinary skill in the art will recognize that various changes and modifications of the embodiments described herein can be made without departing from the scope and spirit of the present application. Descriptions of well-known functions and constructions are likewise omitted below for clarity and conciseness.
Fig. 1 is a flowchart illustrating a training method of a classification model according to a first embodiment of the present application.
As shown in fig. 1, a method for training a classification model according to a first embodiment of the present application includes:
s101, training the classification model by using the current sample set, and respectively obtaining a first loss function and a second loss function of the classification model, wherein the first loss function is a loss function with the type as the granularity, and the second loss function is a loss function with the sample as the granularity.
It is understood that the sample set (Batch) may be a set of all samples that are trained on the classification model each time, and both the first loss function and the second loss function may be calibrated according to actual conditions.
And S102, acquiring a total loss function of the classification model according to the first loss function and the second loss function.
Optionally, obtaining the total loss function of the classification model according to the first loss function and the second loss function may include superimposing the first loss function and the second loss function, and using the superimposed function as the total loss function of the classification model.
And S103, updating the model parameters of the classification model according to the total loss function so as to train the classification model.
Optionally, the updating of the model parameters of the classification model according to the total loss function may include obtaining gradient information of the total loss function, and updating the model parameters according to the gradient information. Updating the model parameters according to the gradient information may include back-propagating according to the gradient information to update the model parameters.
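As a purely illustrative sketch of steps S101 to S103 (the application does not prescribe any framework; the PyTorch-style calls and the helper names type_loss and sample_loss below are assumptions, not part of this application), a single training step might look as follows:

    import torch

    def train_step(model, optimizer, samples, labels, type_loss, sample_loss):
        # S101: train on the current sample set (Batch) and obtain both losses
        logits = model(samples)
        l_c = type_loss(logits, labels)    # first loss function (type granularity)
        l_f = sample_loss(logits, labels)  # second loss function (sample granularity)
        total = l_c + l_f                  # S102: superimpose the two loss functions
        optimizer.zero_grad()
        total.backward()                   # S103: obtain gradient information of the total loss
        optimizer.step()                   # update the model parameters from the gradients
        return total.item()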
In summary, the training method of the embodiments of the present application generates a total loss function from a loss function with the type as the granularity and a loss function with the sample as the granularity, comprehensively considering both class balance and the balance between hard and easy samples. Training the classification model according to this total loss function accelerates convergence, remarkably improves the precision of the classification model, and further increases its operation speed.
Both the precision and the running speed of a model are related to its complexity: the more complex the model, the higher its precision, but the slower it runs. The above training method, however, can markedly improve precision without adding complexity, so even a small model can achieve high precision and a fast running speed. This makes the method convenient to apply on hardware with limited space and resources, such as mobile phones and other mobile terminals, and thus able to bring greater hardware benefits.
When an image processing model obtained by this training method is used for image recognition, its simple structure and high precision improve the accuracy and reliability of image recognition without occupying excessive resources, increase the image processing speed, and lower the model's demands on the hardware performance of the electronic device, thereby reducing the hardware cost of the electronic device.
On the basis of the above embodiment, the obtaining of the first loss function and the second loss function of the classification model in step S101 may include, as shown in fig. 2:
s201, obtaining model identification accuracy of the classification model corresponding to the current sample set.
The model identification accuracy refers to the percentage of correctly recognized samples among all samples after the classification model is trained using all samples in the sample set. For example, if the current sample set contains 1000 samples and, after the classification model is trained with them, 800 samples are recognized correctly and the rest incorrectly, the model identification accuracy of the classification model corresponding to the current sample set is 80%.
It will be appreciated that different sample sets may correspond to different model identification accuracies.
S202, respectively obtaining a plurality of type identification accuracy rates of the classification model aiming at different types of samples.
It will be appreciated that a sample set may include multiple types of samples. For example, the sample set may include three types of samples, namely "gender", "age" and "interest".
The type identification accuracy refers to, for a given type, the percentage of that type's samples that are recognized correctly after the classification model is trained using all samples of that type in the sample set. For example, if the current sample set contains 300 samples of the "gender" type and, after the classification model is trained with them, 100 are recognized correctly and the rest incorrectly, the type identification accuracy of the classification model for "gender"-type samples is 33.3%.
S203, acquiring a first loss function according to the multiple type identification accuracy rates and the model identification accuracy rate.
Therefore, the method obtains the first loss function by comprehensively considering the plurality of type identification accuracy rates together with the model identification accuracy rate, which helps improve the adaptability of the first loss function to different types of samples; since different type identification accuracy rates and model identification accuracy rates correspond to different first loss functions, the first loss function can be obtained more flexibly and accurately.
And S204, acquiring a second loss function according to the model identification accuracy.
For example, the second loss function may be a focal loss function, given by the following formula:
L_f = -(1 - P) × K₂ × log(P)
wherein L_f represents the second loss function, P represents the model identification accuracy, and K₂ represents a constant.
Therefore, the method can obtain the second loss function according to the model identification accuracy, and different model identification accuracies can correspond to different second loss functions, so that the second loss function can be obtained more flexibly and accurately.
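As a minimal sketch, the formula above can be transcribed directly; the natural logarithm and the default value of K₂ are assumptions, since the text only states that K₂ is a constant:

    import math

    def second_loss(p: float, k2: float = 1.0) -> float:
        # L_f = -(1 - P) * K2 * log(P); p is the model identification accuracy,
        # with 0 < p <= 1. Natural log and k2's default value are assumptions.
        return -(1.0 - p) * k2 * math.log(p)

    # e.g. second_loss(0.8) is roughly 0.045; the loss vanishes as p approaches 1.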
On the basis of the foregoing embodiment, the obtaining of the model identification accuracy of the classification model corresponding to the current sample set in step S201, as shown in fig. 3, may include:
s301, obtaining the mark type of each sample in the current sample set and the identification type of each sample identified by the classification model.
It can be understood that the corresponding label type can be preset for each sample in the current sample set in advance, and after the classification model is trained by using the current sample set, the identification type identified by the classification model for each sample can be obtained.
S302, comparing the identification type of each sample with the mark type to obtain the number of first samples with correct identification of the classification model, wherein when the mark type is consistent with the identification type, the classification model is judged to be correct.
It can be understood that when the mark type and the identification type of a sample are consistent, the classification model has recognized that sample correctly. The identification type of each sample can therefore be compared with its mark type, and the first sample number, i.e., the number of samples recognized correctly by the classification model, can be obtained from the comparison results: whenever the mark type is consistent with the identification type, the recognition is judged correct.
As another possible implementation manner, if the mark type and the recognition type are not consistent, the classification model can be judged to be wrongly recognized.
And S303, acquiring the model identification accuracy according to the total sample number and the first sample number of the current sample set.
Optionally, obtaining the model identification accuracy according to the total sample number and the first sample number of the current sample set, which may include using a percentage of the first sample number to the total sample number as the model identification accuracy.
Therefore, the method obtains the first sample number from the mark type of each sample in the current sample set and the identification type the classification model assigns to each sample, and then obtains the model identification accuracy from the total sample number and the first sample number; this accuracy can in turn be used to obtain the second loss function with the sample as the granularity.
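An illustrative sketch of steps S301 to S303 (function and variable names are hypothetical, not part of this application):

    def model_accuracy(mark_types, identification_types):
        # S301-S302: count the first sample number, i.e. samples whose
        # identification type matches their mark type.
        first_sample_number = sum(
            1 for mark, ident in zip(mark_types, identification_types)
            if mark == ident)
        # S303: percentage of the total sample number in the current sample set.
        return first_sample_number / len(mark_types)

    # Worked example from the text: 1000 samples, 800 matches -> 0.8 (80%).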
On the basis of the foregoing embodiment, the obtaining, in step S202, a plurality of type identification accuracy rates of the classification model for different types of samples respectively may include, as shown in fig. 4:
s401, aiming at each type, obtaining the second sample number of the first sample belonging to the type and identified by the classification model in the current sample set.
It can be understood that the classification model identifies a type for each sample; after the classification model is trained with the current sample set, the second sample number, i.e., the number of first samples the classification model identifies as belonging to each type, can be obtained.
S402, obtaining the mark type of each first sample, and counting the number of the third samples with correct identification of the classification model under the type according to the mark type of the first sample, wherein when the mark type is consistent with the identification type, the classification model is judged to be correct in identification.
For the related content of step S402, please refer to the above embodiment, which is not described herein again.
And S403, acquiring the type identification accuracy under the type according to the second sample quantity and the third sample quantity.
Optionally, obtaining the type identification accuracy under the type according to the second sample number and the third sample number may include taking the percentage of the third sample number to the second sample number as the type identification accuracy under that type.
Therefore, the method obtains the type identification accuracy under each type from the second sample number and the third sample number of that type's samples in the current sample set; these accuracies can in turn be used to obtain the first loss function with the type as the granularity.
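An illustrative sketch of steps S401 to S403, grouping samples by the type the classification model assigns (names are again hypothetical):

    from collections import defaultdict

    def type_accuracies(mark_types, identification_types):
        # S401: second sample number = samples the model identifies as each type.
        # S402: third sample number = those whose mark type agrees.
        second, third = defaultdict(int), defaultdict(int)
        for mark, ident in zip(mark_types, identification_types):
            second[ident] += 1
            if mark == ident:
                third[ident] += 1
        # S403: type identification accuracy under each type.
        return {t: third[t] / second[t] for t in second}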
On the basis of the foregoing embodiment, the obtaining the first loss function in step S203 according to the plurality of type recognition accuracy rates and model recognition accuracy rates may include, as shown in fig. 5:
s501, weighting the multiple type identification accuracy rates to obtain the type accumulation identification accuracy rate of the classification model.
Therefore, the type accumulation identification accuracy of the classification model can be obtained from the multiple type identification accuracy rates, so that it reflects all the types of samples present, which is more flexible and accurate.
Optionally, weighting the multiple type identification accuracy rates to obtain the type accumulation identification accuracy of the classification model may include: obtaining a weighting coefficient corresponding to each type identification accuracy rate, computing the product of each type identification accuracy rate and its corresponding weighting coefficient, and taking the sum of all the products as the type accumulation identification accuracy of the classification model.
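One possible form of this weighting, assuming the weighting coefficients are supplied externally (their values are not fixed by the text), operating on the per-type accuracies from the sketch above:

    def type_cumulative_accuracy(type_accs, weights):
        # S501: weighted sum of the per-type identification accuracy rates;
        # `weights` maps each type to its weighting coefficient (assumed given).
        return sum(acc * weights[t] for t, acc in type_accs.items())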
And S502, acquiring a first loss function according to the type accumulated identification accuracy and the model identification accuracy.
Alternatively, the first loss function may be obtained using the following equation:
L_c = -(1 - P_class) × K₁ × log(P)
wherein L_c represents the first loss function, P_class represents the type accumulation identification accuracy, P represents the model identification accuracy, and K₁ represents a constant.
Therefore, the method can weight the plurality of type identification accuracy rates to obtain the type accumulation identification accuracy rate of the classification model, and obtain the first loss function according to the type accumulation identification accuracy rate and the model identification accuracy rate.
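Mirroring the sketch of the second loss function above, the formula for the first loss function may be transcribed as follows (the natural logarithm and the value of K₁ are again assumptions):

    import math

    def first_loss(p_class: float, p: float, k1: float = 1.0) -> float:
        # L_c = -(1 - P_class) * K1 * log(P); p_class is the type accumulation
        # identification accuracy, p the model identification accuracy.
        return -(1.0 - p_class) * k1 * math.log(p)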
On the basis of the foregoing embodiments, after the model parameters of the classification model are updated according to the total loss function in step S103, the method may further include obtaining a next sample set and continuing to train the updated classification model with it, until the classification model converges or the cumulative training times reach a preset number of training times, at which point training ends and the target classification model is generated. The preset number of training times can be calibrated according to actual conditions. The method can thus train the classification model over a plurality of sample sets, which helps improve the adaptability of the classification model to different samples as well as its precision.
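Putting these pieces together, a hedged sketch of the outer training loop, reusing the hypothetical train_step above; the concrete convergence test and the preset number of training times are assumptions:

    def train(model, optimizer, sample_sets, type_loss, sample_loss,
              preset_times=10000, tol=1e-4):
        # Train on successive sample sets until the total loss falls below
        # `tol` (taken here as convergence) or the cumulative training times
        # reach `preset_times`; both thresholds are assumptions.
        for times, (samples, labels) in enumerate(sample_sets, start=1):
            loss = train_step(model, optimizer, samples, labels,
                              type_loss, sample_loss)
            if loss < tol or times >= preset_times:
                break
        return model  # the target classification model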
Fig. 6 is a block diagram of a training apparatus of a classification model according to a first embodiment of the present application.
As shown in fig. 6, the training apparatus 600 for a classification model according to the embodiment of the present application includes: a first obtaining module 601, a second obtaining module 602, and an updating module 603.
The first obtaining module 601 is configured to train a classification model by using a current sample set, and obtain a first loss function and a second loss function of the classification model respectively, where the first loss function is a loss function with a type as a granularity, and the second loss function is a loss function with a sample as a granularity.
The second obtaining module 602 is configured to obtain a total loss function of the classification model according to the first loss function and the second loss function.
The updating module 603 is configured to update the model parameters of the classification model according to the total loss function, so as to train the classification model.
In an embodiment of the present application, the first obtaining module 601 includes a first obtaining unit, configured to obtain a model identification accuracy of the classification model corresponding to the current sample set; the second acquisition unit is used for respectively acquiring a plurality of type identification accuracy rates of the classification model for different types of samples; a third obtaining unit, configured to obtain the first loss function according to the multiple type recognition accuracy rates and the model recognition accuracy rate; and a fourth obtaining unit, configured to obtain the second loss function according to the model identification accuracy.
In an embodiment of the application, the third obtaining unit includes: the first obtaining subunit is configured to weight the multiple type identification accuracy rates to obtain a type accumulation identification accuracy rate of the classification model; and the second obtaining subunit is used for obtaining the first loss function according to the type accumulated identification accuracy and the model identification accuracy.
In an embodiment of the application, the second obtaining subunit obtains the first loss function by using the following formula:
L_c = -(1 - P_class) × K₁ × log(P)
wherein L_c represents the first loss function, P_class represents the type cumulative identification accuracy, P represents the model identification accuracy, and K₁ represents a constant.
In an embodiment of the present application, the first obtaining unit includes a third obtaining subunit, configured to obtain a label type of each sample in the current sample set and an identification type identified for each sample by the classification model; a fourth obtaining subunit, configured to compare the identification type of each sample with the mark type, and obtain a first number of samples that are correctly identified by the classification model, where when the mark type is consistent with the identification type, it is determined that the classification model is correctly identified; and a fifth obtaining subunit, configured to obtain the model identification accuracy according to the total number of samples in the current sample set and the first number of samples.
In an embodiment of the present application, the second obtaining unit includes a sixth obtaining subunit, configured to obtain, for each type, a second number of samples in the current sample set, which are identified by the classification model and belong to the first sample of the type; a seventh obtaining subunit, configured to obtain a mark type of each first sample, and count, according to the mark type of the first sample, the number of third samples in which the classification model is correctly identified in the type, where when the mark type is consistent with the identification type, it is determined that the classification model is correctly identified; and an eighth obtaining subunit, configured to obtain the type identification accuracy of the type according to the second sample number and the third sample number.
In an embodiment of the present application, the updating module includes a fifth obtaining unit, configured to obtain gradient information of the total loss function; and a sixth obtaining unit, configured to update the model parameter according to the gradient information.
In an embodiment of the application, the first obtaining module is further configured to obtain a next sample set after the parameters of the classification model are updated according to the total loss function, and continue training the updated classification model by using the next sample set until the classification model converges or an accumulated training number reaches a preset training number, so as to finish training, thereby generating the target classification model.
The training apparatus for a classification model of the embodiments of the present application generates a total loss function from a loss function with the type as the granularity and a loss function with the sample as the granularity, comprehensively considering both class balance and the balance between hard and easy samples. Training the classification model according to this total loss function accelerates convergence, remarkably improves the precision of the classification model, and can bring hardware benefits.
According to an embodiment of the present application, an electronic device and a readable storage medium are also provided.
Fig. 7 is a block diagram of an electronic device according to an embodiment of the present disclosure. Electronic devices are intended to represent various forms of digital computers, such as laptops, desktops, workstations, personal digital assistants, servers, blade servers, mainframes, and other appropriate computers. The electronic device may also represent various forms of mobile devices, such as personal digital processing, cellular phones, smart phones, wearable devices, and other similar computing devices. The components shown herein, their connections and relationships, and their functions, are meant to be examples only, and are not meant to limit implementations of the present application that are described and/or claimed herein.
As shown in fig. 7, the electronic apparatus includes: one or more processors 701, a memory 702, and interfaces for connecting the various components, including a high-speed interface and a low-speed interface. The various components are interconnected using different buses and may be mounted on a common motherboard or in other manners as desired. The processor 701 may process instructions for execution within the electronic device, including instructions stored in or on the memory to display graphical information of a GUI on an external input/output apparatus (such as a display device coupled to an interface). In other embodiments, multiple processors and/or multiple buses may be used, as desired, along with multiple memories. Also, multiple electronic devices may be connected, with each device providing portions of the necessary operations (e.g., as a server array, a group of blade servers, or a multi-processor system). In fig. 7, one processor 701 is taken as an example.
The memory 702 is a non-transitory computer readable storage medium as provided herein. Wherein the memory stores instructions executable by at least one processor to cause the at least one processor to perform the method of training a classification model provided herein. The non-transitory computer-readable storage medium of the present application stores computer instructions for causing a computer to perform the training method of the classification model provided herein.
The memory 702, which is a non-transitory computer readable storage medium, may be used to store non-transitory software programs, non-transitory computer executable programs, and modules, such as program instructions/modules (e.g., the first obtaining module 601, the second obtaining module 602, and the updating module 603 shown in fig. 6) corresponding to the training method of the classification model in the embodiment of the present application. The processor 701 executes various functional applications of the server and data processing, i.e., implements the training method of the classification model in the above method embodiments, by running non-transitory software programs, instructions, and modules stored in the memory 702.
The memory 702 may include a storage program area and a storage data area, wherein the storage program area may store an operating system, an application program required for at least one function; the storage data area may store data created according to use of the electronic device of the training method of the classification model, and the like. Further, the memory 702 may include high speed random access memory, and may also include non-transitory memory, such as at least one magnetic disk storage device, flash memory device, or other non-transitory solid state storage device. In some embodiments, the memory 702 may optionally include memory located remotely from the processor 701, and such remote memory may be connected to the electronic device of the training method of the classification model via a network. Examples of such networks include, but are not limited to, the internet, intranets, local area networks, mobile communication networks, and combinations thereof.
The electronic device of the training method of the classification model may further include: an input device 703 and an output device 704. The processor 701, the memory 702, the input device 703 and the output device 704 may be connected by a bus or other means, and fig. 7 illustrates an example of a connection by a bus.
The input device 703 may receive input numeric or character information and generate key signal inputs related to user settings and function control of the electronic device of the training method of the classification model, such as a touch screen, a keypad, a mouse, a track pad, a touch pad, a pointing stick, one or more mouse buttons, a track ball, a joystick, or other input devices. The output devices 704 may include a display device, auxiliary lighting devices (e.g., LEDs), and tactile feedback devices (e.g., vibrating motors), among others. The display device may include, but is not limited to, a Liquid Crystal Display (LCD), a Light Emitting Diode (LED) display, and a plasma display. In some implementations, the display device can be a touch screen.
Various implementations of the systems and techniques described here can be realized in digital electronic circuitry, integrated circuitry, application-specific integrated circuits (ASICs), computer hardware, firmware, software, and/or combinations thereof. These various embodiments may include: implementation in one or more computer programs that are executable and/or interpretable on a programmable system including at least one programmable processor, which may be special or general purpose, receiving data and instructions from, and transmitting data and instructions to, a storage system, at least one input device, and at least one output device.
These computer programs (also known as programs, software applications, or code) include machine instructions for a programmable processor, and may be implemented using high-level procedural and/or object-oriented programming languages, and/or assembly/machine languages. As used herein, the terms "machine-readable medium" and "computer-readable medium" refer to any computer program product, apparatus, and/or device (e.g., magnetic discs, optical disks, memory, Programmable Logic Devices (PLDs)) used to provide machine instructions and/or data to a programmable processor, including a machine-readable medium that receives machine instructions as a machine-readable signal. The term "machine-readable signal" refers to any signal used to provide machine instructions and/or data to a programmable processor.
To provide for interaction with a user, the systems and techniques described here can be implemented on a computer having: a display device (e.g., a CRT (cathode ray tube) or LCD (liquid crystal display) monitor) for displaying information to a user; and a keyboard and a pointing device (e.g., a mouse or a trackball) by which a user can provide input to the computer. Other kinds of devices may also be used to provide for interaction with a user; for example, feedback provided to the user can be any form of sensory feedback (e.g., visual feedback, auditory feedback, or tactile feedback); and input from the user may be received in any form, including acoustic, speech, or tactile input.
The systems and techniques described here can be implemented in a computing system that includes a back-end component (e.g., as a data server), or that includes a middleware component (e.g., an application server), or that includes a front-end component (e.g., a user computer having a graphical user interface or a web browser through which a user can interact with an implementation of the systems and techniques described here), or any combination of such back-end, middleware, or front-end components. The components of the system can be interconnected by any form or medium of digital data communication (e.g., a communication network). Examples of communication networks include: local Area Networks (LANs), Wide Area Networks (WANs), and the Internet.
The computer system may include clients and servers. A client and server are generally remote from each other and typically interact through a communication network. The relationship of client and server arises by virtue of computer programs running on the respective computers and having a client-server relationship to each other.
According to the technical scheme of the embodiments of the present application, a total loss function can be generated from a loss function with the type as the granularity and a loss function with the sample as the granularity, comprehensively considering both class balance and the balance between hard and easy samples. The classification model can be trained according to this total loss function, which accelerates the convergence of the classification model, remarkably improves its precision, and can bring hardware benefits.
It should be understood that various forms of the flows shown above may be used, with steps reordered, added, or deleted. For example, the steps described in the present application may be executed in parallel, sequentially, or in different orders; this is not limited herein, as long as the desired results of the technical solutions disclosed in the present application can be achieved.
The above-described embodiments should not be construed as limiting the scope of the present application. It should be understood by those skilled in the art that various modifications, combinations, sub-combinations and substitutions may be made in accordance with design requirements and other factors. Any modification, equivalent replacement, and improvement made within the spirit and principle of the present application shall be included in the protection scope of the present application.

Claims (18)

1. A training method of a classification model comprises the following steps:
training a classification model by using a current sample set, and respectively obtaining a first loss function and a second loss function of the classification model, wherein the first loss function is a loss function with type as granularity, and the second loss function is a loss function with sample as granularity;
obtaining a total loss function of the classification model according to the first loss function and the second loss function; and
and updating the model parameters of the classification model according to the total loss function so as to train the classification model.
2. The training method of the classification model according to claim 1, wherein the separately obtaining the first loss function and the second loss function of the classification model comprises:
obtaining the model identification accuracy of the classification model corresponding to the current sample set;
respectively obtaining a plurality of type identification accuracy rates of the classification model aiming at different types of samples;
obtaining the first loss function according to the multiple type identification accuracy rates and the model identification accuracy rate; and
and acquiring the second loss function according to the model identification accuracy.
3. The method for training a classification model according to claim 2, wherein the obtaining the first loss function according to the plurality of type recognition accuracy rates and the model recognition accuracy rate comprises:
weighting the multiple type identification accuracy rates to obtain the type accumulation identification accuracy rate of the classification model; and
and acquiring the first loss function according to the type accumulated identification accuracy and the model identification accuracy.
4. The training method of a classification model according to claim 3, wherein the first loss function is obtained using the following formula:
L_c = -(1 - P_class) × K₁ × log(P)
wherein L_c represents the first loss function, P_class represents the type cumulative identification accuracy, P represents the model identification accuracy, and K₁ represents a constant.
5. The training method of the classification model according to any one of claims 2 to 4, wherein the obtaining of the model identification accuracy of the classification model corresponding to the current sample set comprises:
obtaining the mark type of each sample in the current sample set and the identification type of each sample identified by the classification model;
comparing the identification type of each sample with the mark type to obtain the number of first samples with correct identification of the classification model, wherein when the mark type is consistent with the identification type, the classification model is judged to be correct; and
and obtaining the model identification accuracy according to the total sample number of the current sample set and the first sample number.
6. The training method of the classification model according to any one of claims 2 to 4, wherein the obtaining of the plurality of type recognition accuracy rates of the classification model for different types of samples respectively comprises:
for each type, acquiring a second sample number which is identified by the classification model in the current sample set and belongs to a first sample of the type;
obtaining the mark type of each first sample, and counting the number of third samples with correct identification of the classification model under the type according to the mark type of the first sample, wherein when the mark type is consistent with the identification type, the classification model is judged to be correctly identified; and
and obtaining the type identification accuracy under the type according to the second sample quantity and the third sample quantity.
7. The training method of the classification model according to claim 1, wherein the updating the model parameters of the classification model according to the total loss function comprises:
acquiring gradient information of the total loss function; and
and updating the model parameters according to the gradient information.
8. The training method of the classification model according to claim 1 or 7, wherein after the updating the parameters of the classification model according to the total loss function, further comprising:
and obtaining a next sample set, and continuously training the updated classification model by using the next sample set until the classification model converges or the accumulated training times reach the preset training times, and finishing the training to generate a target classification model.
9. A training apparatus for classification models, comprising:
a first obtaining module, configured to train a classification model using a current sample set and to obtain, respectively, a first loss function and a second loss function of the classification model, wherein the first loss function is a loss function with the type as the granularity, and the second loss function is a loss function with the sample as the granularity;
a second obtaining module, configured to obtain a total loss function of the classification model according to the first loss function and the second loss function; and
and the updating module is used for updating the model parameters of the classification model according to the total loss function so as to train the classification model.
10. The training device of the classification model according to claim 9, wherein the first obtaining module comprises:
the first obtaining unit is used for obtaining the model identification accuracy of the classification model corresponding to the current sample set;
the second acquisition unit is used for respectively acquiring a plurality of type identification accuracy rates of the classification model for different types of samples;
a third obtaining unit, configured to obtain the first loss function according to the multiple type recognition accuracy rates and the model recognition accuracy rate; and
and the fourth obtaining unit is used for obtaining the second loss function according to the model identification accuracy.
11. The training device of the classification model according to claim 10, wherein the third obtaining unit comprises:
the first obtaining subunit is configured to weight the multiple type identification accuracy rates to obtain a type accumulation identification accuracy rate of the classification model; and
and the second obtaining subunit is configured to obtain the first loss function according to the type accumulation identification accuracy and the model identification accuracy.
12. The training apparatus for classification models according to claim 11, wherein the second obtaining subunit obtains the first loss function by using the following formula:
L_c = -(1 - P_class) × K₁ × log(P)
wherein L_c represents the first loss function, P_class represents the type cumulative identification accuracy, P represents the model identification accuracy, and K₁ represents a constant.
13. The training apparatus of a classification model according to any one of claims 10-12, wherein the first obtaining unit includes:
a third obtaining subunit, configured to obtain a label type of each sample in the current sample set and a recognition type recognized by the classification model for each sample;
a fourth obtaining subunit, configured to compare the identification type of each sample with the mark type, and obtain a first number of samples that are correctly identified by the classification model, where when the mark type is consistent with the identification type, it is determined that the classification model is correctly identified; and
and the fifth obtaining subunit is configured to obtain the model identification accuracy according to the total number of samples in the current sample set and the first number of samples.
14. The training apparatus of a classification model according to any one of claims 10-12, wherein the second obtaining unit includes:
a sixth obtaining subunit, configured to obtain, for each type, a second number of samples, identified by the classification model, in the current sample set and belonging to the first sample of the type;
a seventh obtaining subunit, configured to obtain a mark type of each first sample, and count, according to the mark type of the first sample, the number of third samples in which the classification model is correctly identified in the type, where when the mark type is consistent with the identification type, it is determined that the classification model is correctly identified; and
and the eighth obtaining subunit is configured to obtain the type identification accuracy under the type according to the second sample number and the third sample number.
15. The training device of classification model according to claim 9, wherein the updating module comprises:
a fifth obtaining unit, configured to obtain gradient information of the total loss function; and
and the sixth acquisition unit is used for updating the model parameters according to the gradient information.
16. The training apparatus for classification model according to claim 9 or 15, wherein the first obtaining module is further configured to:
and after updating the parameters of the classification model according to the total loss function, obtaining a next sample set, and continuing training the updated classification model by using the next sample set until the classification model converges or the cumulative training times reach a preset training time, and finishing training to generate a target classification model.
17. An electronic device, comprising:
at least one processor; and
a memory communicatively coupled to the at least one processor; wherein the content of the first and second substances,
the memory stores instructions executable by the at least one processor to enable the at least one processor to perform a method of training a classification model according to any one of claims 1 to 8.
18. A non-transitory computer readable storage medium storing computer instructions for causing a computer to perform the training method of the classification model according to any one of claims 1 to 8.
CN202010606425.2A 2020-06-29 2020-06-29 Method and device for training classification model, electronic equipment and storage medium Pending CN111967492A (en)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202010606425.2A CN111967492A (en) 2020-06-29 2020-06-29 Method and device for training classification model, electronic equipment and storage medium

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202010606425.2A CN111967492A (en) 2020-06-29 2020-06-29 Method and device for training classification model, electronic equipment and storage medium

Publications (1)

Publication Number Publication Date
CN111967492A 2020-11-20

Family

ID=73360999

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202010606425.2A Pending CN111967492A (en) 2020-06-29 2020-06-29 Method and device for training classification model, electronic equipment and storage medium

Country Status (1)

Country Link
CN (1) CN111967492A (en)


Patent Citations (6)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
WO2019169816A1 (en) * 2018-03-09 2019-09-12 中山大学 Deep neural network for fine recognition of vehicle attributes, and training method thereof
CN110443280A (en) * 2019-07-05 2019-11-12 北京达佳互联信息技术有限公司 Training method, device and the storage medium of image detection model
CN110751175A (en) * 2019-09-12 2020-02-04 上海联影智能医疗科技有限公司 Method and device for optimizing loss function, computer equipment and storage medium
CN111046959A (en) * 2019-12-12 2020-04-21 上海眼控科技股份有限公司 Model training method, device, equipment and storage medium
CN111160448A (en) * 2019-12-26 2020-05-15 北京达佳互联信息技术有限公司 Training method and device for image classification model
CN111291817A (en) * 2020-02-17 2020-06-16 北京迈格威科技有限公司 Image recognition method and device, electronic equipment and computer readable medium

Non-Patent Citations (2)

* Cited by examiner, † Cited by third party
Title
Thomas Deselaers; Tobias Gass; Georg Heigold; Hermann Ney: "Latent Log-Linear Models for Handwritten Digit Classification", IEEE Transactions on Pattern Analysis and Machine Intelligence, 8 November 2011 (2011-11-08) *
汪颖; 孙建风; 肖先勇; 卢宏; 杨晓梅: "Classification and recognition of early cable faults based on an optimized convolutional neural network", 电力系统保护与控制 (Power System Protection and Control), no. 07, 1 April 2020 (2020-04-01) *

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN112331261A (en) * 2021-01-05 2021-02-05 北京百度网讯科技有限公司 Drug prediction method, model training method, device, electronic device, and medium

Similar Documents

Publication Publication Date Title
CN111428008B (en) Method, apparatus, device and storage medium for training a model
CN112036509A (en) Method and apparatus for training image recognition models
CN111539514A (en) Method and apparatus for generating structure of neural network
CN112001169B (en) Text error correction method and device, electronic equipment and readable storage medium
KR20210132578A (en) Method, apparatus, device and storage medium for constructing knowledge graph
CN111079945B (en) End-to-end model training method and device
CN111461343B (en) Model parameter updating method and related equipment thereof
CN110852379B (en) Training sample generation method and device for target object recognition
CN111968203A (en) Animation driving method, animation driving device, electronic device, and storage medium
CN112149741A (en) Training method and device of image recognition model, electronic equipment and storage medium
CN111090991A (en) Scene error correction method and device, electronic equipment and storage medium
CN111241810A (en) Punctuation prediction method and device
CN111127191A (en) Risk assessment method and device
CN111241838B (en) Semantic relation processing method, device and equipment for text entity
CN114417194A (en) Recommendation system sorting method, parameter prediction model training method and device
CN111640103A (en) Image detection method, device, equipment and storage medium
CN111708477B (en) Key identification method, device, equipment and storage medium
CN110909136A (en) Satisfaction degree estimation model training method and device, electronic equipment and storage medium
CN111967492A (en) Method and device for training classification model, electronic equipment and storage medium
CN112580723A (en) Multi-model fusion method and device, electronic equipment and storage medium
CN110889392B (en) Method and device for processing face image
CN112561059A (en) Method and apparatus for model distillation
CN112016524A (en) Model training method, face recognition device, face recognition equipment and medium
CN111767946A (en) Medical image hierarchical model training and prediction method, device, equipment and medium
CN111177479A (en) Method and device for acquiring feature vectors of nodes in relational network graph

Legal Events

Date Code Title Description
PB01 Publication
SE01 Entry into force of request for substantive examination