CN115795355A - Classification model training method, device and equipment - Google Patents

Classification model training method, device and equipment

Info

Publication number
CN115795355A
CN115795355A
Authority
CN
China
Prior art keywords
trained
sample
feature vector
classification model
loss
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Granted
Application number
CN202310095677.7A
Other languages
Chinese (zh)
Other versions
CN115795355B (en)
Inventor
王博
张钰
李兵
胡卫明
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Renmin Zhongke Jinan Intelligent Technology Co ltd
Institute of Automation of Chinese Academy of Science
Original Assignee
Renmin Zhongke Jinan Intelligent Technology Co ltd
Institute of Automation of Chinese Academy of Science
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Renmin Zhongke Jinan Intelligent Technology Co ltd and Institute of Automation of Chinese Academy of Science
Priority to CN202310095677.7A
Publication of CN115795355A
Application granted
Publication of CN115795355B
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Landscapes

  • Information Retrieval, Db Structures And Fs Structures Therefor (AREA)

Abstract

The embodiment of the invention relates to the field of artificial intelligence and discloses a classification model training method, device and equipment. The method comprises: obtaining a sample set to be trained; performing feature extraction to obtain first feature vectors; computing prototype feature vectors from them; and calling a preset loss adjuster to adjust the parameters of the loss function corresponding to the classification model to be trained, so as to obtain a target classification model. In this way, corresponding strategy adjustments can be applied in time based on the current learning result during model training, which improves the training accuracy of the classification model in a small-sample learning setting and, in turn, the accuracy of the classification performed by the trained model.

Description

Classification model training method, device and equipment
Technical Field
The embodiment of the invention relates to the field of artificial intelligence, and in particular to a classification model training method, device and equipment.
Background
With the development of artificial intelligence technology, classification models are widely used in many fields. However, in some specialized fields, such as medicine and the military, classification models are typically trained to recognize domain-specific data, and the particularity of these fields means that large numbers of training samples cannot be provided for the corresponding classification model. Therefore, in these fields, classification models are usually trained by means of small-sample learning.
Existing small-sample learning schemes generally adopt a metric-based classification setting and an episodic (segment-based) training setting, so that the combination of classes seen in each training iteration differs. To ensure training efficiency and result accuracy, different learning strategies need to be applied to different samples during model training. However, because the existing small-sample learning scheme is built on the cross-entropy loss function, which treats all segments equally, the learning strategy appropriate to each segment is ignored. The existing small-sample learning scheme therefore lacks adaptive capability, and the classification model's sample recognition results deviate considerably.
Disclosure of Invention
In view of the above problems, the present invention provides a classification model training method, apparatus and device, so as to solve the problem that the existing small-sample learning scheme lacks adaptive capability, which in turn causes large deviations in the classification model's sample recognition results.
In a first aspect, the present invention provides a classification model training method, including:
obtaining a sample set to be trained, wherein the sample set to be trained comprises: at least one sample to be trained and at least one category information;
performing feature extraction on each sample to be trained in the sample set to be trained to obtain a first feature vector, wherein the first feature vector is used for representing features corresponding to the sample to be trained, and each first feature vector corresponds to one category information;
calculating based on the first feature vector corresponding to each sample to be trained to obtain a prototype feature vector corresponding to each category of information;
calling a classification model to be trained, and carrying out similarity comparison on at least one query sample based on the prototype feature vector to obtain at least one first comparison result;
and calling a preset loss adjuster based on each first comparison result to perform parameter adjustment on the loss function corresponding to the classification model to be trained so as to enable the results of all the loss functions to be smaller than or equal to a preset threshold value, and obtaining the target classification model.
In some possible embodiments, the performing similarity comparison for at least one query sample based on the prototype feature vector includes:
performing feature extraction on the at least one query sample to obtain at least one second feature vector, wherein each second feature vector is used for representing the feature corresponding to each query sample;
calling the at least one second feature vector and the prototype feature vector corresponding to each category information respectively to carry out similarity comparison one by one to obtain at least one comparison result, wherein each comparison result comprises: the similarity of each query sample and each prototype feature vector and the prediction probability of the corresponding category information of each query sample.
In some possible embodiments, the calculating based on the first feature vector corresponding to each sample to be trained to obtain a prototype feature vector corresponding to each category information includes:
acquiring all the first characteristic vectors corresponding to each category information;
and calculating the average value of all the first feature vectors corresponding to each category information, and taking the calculation result as the prototype feature vector. In this way, the prototype feature vector is set by calculating the average value of the first feature vector corresponding to each category of information, and the prototype feature vector is used as a reference standard to be compared with the second feature vector, so that the comparison result is more accurate.
In some possible embodiments, the invoking a preset loss adjuster performs parameter adjustment on a loss function corresponding to the classification model to be trained, including:
carrying out similarity comparison on each second feature vector and each prototype feature vector to obtain at least one similarity comparison result;
screening to obtain a first maximum value and a second maximum value of the similarity comparison result;
obtaining a probability adjusting function based on the first maximum value and the second maximum value, wherein the probability adjusting function is used for adjusting an output value of a prediction probability;
if the value of the probability adjustment function is smaller than 1, calling the loss adjuster to apply a first adjustment mode to the loss function;
and if the value of the probability adjustment function is larger than 1, calling the loss adjuster to apply a second adjustment mode to the loss function. In this way, the loss function is adjusted according to the value of the probability adjustment function obtained by comparing the similarity of each second feature vector with each prototype feature vector, which ensures that the classification model adjusts its strategy in time according to the current training state and improves training efficiency and the accuracy of the training result.
In some possible embodiments, after the obtaining of the sample set to be trained, before performing feature extraction on each sample to be trained in the sample set to be trained, the method further includes:
performing data enhancement on the sample set to be trained, including: random cropping, random erasure, and sample normalization adjustment. Therefore, the accuracy and the training efficiency of the classification model training can be improved.
In some possible embodiments, each of the at least one query sample corresponds to one category information.
In some possible embodiments, the loss function corresponding to the classification model to be trained is a cross-entropy loss function.
In a second aspect, the present invention further provides a classification model training apparatus, including:
the acquisition module acquires a sample set to be trained, wherein the sample set to be trained comprises: at least one sample to be trained and at least one category information;
the characteristic extraction module is used for extracting characteristics of each sample to be trained in the sample set to be trained to obtain a first characteristic vector, the first characteristic vector is used for representing characteristics corresponding to the sample to be trained, and each first characteristic vector corresponds to one category information;
the processing module is used for calculating based on the first feature vector corresponding to each sample to be trained to obtain a prototype feature vector corresponding to each category information;
and the comparison module is used for calling a classification model to be trained, carrying out similarity comparison on at least one query sample based on the prototype feature vector to obtain at least one first comparison result, and if the at least one comparison result is greater than a preset threshold, calling a preset loss adjuster to carry out parameter adjustment on a loss function corresponding to the classification model to be trained so that all comparison results are less than or equal to the preset threshold to obtain a target classification model.
In a third aspect, the present invention further provides an electronic device, including: the system comprises a processor, a memory, a communication interface and a communication bus, wherein the processor, the memory and the communication interface complete mutual communication through the communication bus;
the memory is configured to store executable instructions that, when executed, cause the processor to perform the classification model training method of any one of the possible embodiments of the first aspect or the second aspect.
In a fourth aspect, the present invention further provides a computer-readable storage medium, where executable instructions are stored in the storage medium, and when the executable instructions are executed, a computing device executes the classification model training method in any one of the possible embodiments of the first aspect or the second aspect.
The invention provides a classification model training method. In this scheme, a sample set to be trained is first obtained, the sample set to be trained comprising: at least one sample to be trained and at least one piece of category information;
feature extraction is then performed on each sample to be trained in the sample set to obtain first feature vectors, where each first feature vector represents the features of one sample to be trained and corresponds to one piece of category information; a prototype feature vector for each piece of category information is then computed from the first feature vectors; the classification model to be trained is then called to perform similarity comparison for at least one query sample against the prototype feature vectors, obtaining at least one first comparison result; and finally, based on each first comparison result, a preset loss adjuster is called to adjust the parameters of the loss function corresponding to the classification model to be trained, so that the results of all the loss functions are smaller than or equal to a preset threshold, and the target classification model is obtained. In other words, the current calculation result is compared with a preset threshold, and based on that comparison a preset loss adjuster is called to adjust the loss-function parameters of the classification model to be trained, yielding the target classification model. In this way, corresponding strategy adjustments can be applied in time based on the current learning result during model training, which improves the training accuracy of the classification model in a small-sample learning setting and, in turn, the accuracy of the classification performed by the trained model.
Drawings
FIG. 1 is a schematic flow chart of a classification model training method proposed by the present invention;
FIG. 2 is a schematic diagram of a classification model training apparatus according to the present invention;
FIG. 3 is a schematic structural diagram of an electronic device for training a classification model according to the present invention.
Detailed Description
The terminology used in the following embodiments of the invention is for the purpose of describing alternative embodiments and is not intended to limit the invention. As used in the description of the invention and the appended claims, the singular forms "a", "an" and "the" are intended to include the plural forms as well. It should also be understood that although the terms first, second, etc. may be used in the following embodiments to describe a class of objects, the objects are not limited by these terms; the terms are only used to distinguish particular objects of that class from one another. The following embodiments may likewise use the terms first, second, etc. for other classes of objects, which is not repeated here.
With the development of artificial intelligence technology, classification models are widely used in many fields. However, in some specialized fields, such as medicine and the military, classification models are often trained to recognize domain-specific data, and the particularity of these fields means that large numbers of training samples cannot be provided for the corresponding classification model. Therefore, in these fields, small-sample learning is usually adopted for the corresponding classification model training.
The small-sample learning technique can significantly reduce the time and labor cost required to construct a recognition system. Generally speaking, a small-sample learning model is first trained on a source data domain rich in labeled samples to obtain prior knowledge, and is then fine-tuned on a target data domain with only a small number of labeled samples, or applied to the target data domain directly.
In existing small-sample learning, a metric-based classification setting and an episodic (segment-based) training setting are generally adopted; in this case, a traditional training method based on the cross-entropy loss function is difficult to match with these settings.
In the first aspect, the similarity between the query sample and the class prototype is usually calculated with a cosine function in the nearest-neighbor-based metric classification setting. For a backbone network that uses the Rectified Linear Unit (ReLU) as its activation function, the similarity between output features is therefore limited to the interval from 0 to 1, which in turn confines the corresponding output probability distribution to a rather narrow interval.
For example, for 1 query sample and 5 category support samples, the corresponding output prediction probability is only distributed between 0 and 0.4046 after calculation.
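This figure can be checked directly: with similarities confined to the interval [0, 1], the most favourable case for a class is a cosine similarity of 1 to its own prototype and 0 to the other four prototypes, and the softmax of those similarities reproduces the quoted bound. A minimal check in Python:

```python
import math

# Cosine similarities between one query and 5 class prototypes, in the most
# favourable case for class 0 when ReLU features confine similarities to [0, 1].
sims = [1.0, 0.0, 0.0, 0.0, 0.0]

exps = [math.exp(s) for s in sims]
probs = [e / sum(exps) for e in exps]

print(round(probs[0], 4))  # 0.4046 -- the largest probability the model can output
```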
In the second aspect, small-sample learning uses an episodic (segment-based) training setting, so the class combination differs each time the training samples are iterated during training. To ensure training efficiency and result accuracy, different learning strategies need to be applied to different samples during model training. Because the existing small-sample learning scheme is built on the cross-entropy loss function, which treats all segments equally, it lacks adaptive capability and ignores the learning strategy appropriate to each segment.
This problem greatly limits the guiding role of the traditional cross-entropy loss function in feature-space learning under a small-sample metric learning framework, and limits the ability to learn from hard samples. Therefore, the existing small-sample learning scheme lacks adaptive capability, and the classification model's sample recognition results deviate considerably.
The embodiment of the present application may specifically involve Machine Learning (ML) within artificial intelligence. Machine learning is a multi-disciplinary subject involving probability theory, statistics, approximation theory, convex analysis, algorithmic complexity theory and other disciplines. Machine learning is the core of artificial intelligence and the fundamental way to give computers intelligence, and it is applied in all fields of artificial intelligence. Machine learning and deep learning generally include techniques such as artificial neural networks, belief networks, reinforcement learning, transfer learning, inductive learning and formula learning. The training of the model in this application is mainly realized through machine learning.
Of course, the model training method provided by the embodiment of the application can be applied to different fields, and in different fields other specific artificial intelligence technologies may be involved. For example, if the method is applied to fields such as facial expression recognition or face recognition, Computer Vision (CV) may be involved. Computer vision is the science of how to make machines "see": cameras and computers are used in place of human eyes to recognize, track and measure targets, and the resulting images are further processed so that they are better suited to human observation or to transmission to instruments for detection. Computer vision technology generally includes image processing, image recognition, image semantic understanding, image retrieval, OCR, video processing, video semantic understanding, video content/behavior recognition, three-dimensional object reconstruction, 3D technology, virtual reality, augmented reality, simultaneous localization and mapping, automatic driving and intelligent transportation, as well as common biometric technologies such as face recognition and fingerprint recognition.
The classification model training method provided by the embodiment of the application can be executed by one electronic device or by a computer cluster. The computer cluster comprises at least two electronic devices supporting the classification model training method of the embodiment of the application, and any one of them can realize the classification model training function described in the embodiment of the application through this method.
Any electronic device contemplated by the embodiment of the application may be, for example, a mobile phone, a tablet computer, a wearable device (e.g., a smart watch or a smart bracelet), a notebook computer, a desktop computer, or an in-vehicle device. The electronic device is preinstalled with a classification model training application. It is understood that the embodiment of the present application does not limit the specific type of the electronic device.
The following describes several exemplary embodiments, explaining the technical solutions of the embodiments of the present invention and the technical effects they produce.
FIG. 1 is a schematic flow chart of a classification model training method according to the present invention. As shown in fig. 1, the classification model training method proposed by the present invention includes the following steps:
s100: obtaining a sample set to be trained, wherein the sample set to be trained comprises: at least one sample to be trained and at least one category information;
s200: performing feature extraction on each sample to be trained in the sample set to be trained to obtain a first feature vector, wherein the first feature vector is used for representing features corresponding to the sample to be trained, and each first feature vector corresponds to one category information;
s300: calculating based on the first feature vector corresponding to each sample to be trained to obtain a prototype feature vector corresponding to each category of information;
s400: calling a classification model to be trained, and carrying out similarity comparison on at least one query sample based on the prototype feature vector to obtain at least one first comparison result;
s500: and calling a preset loss adjuster to perform parameter adjustment on the loss function corresponding to the classification model to be trained based on each first comparison result, so that the results of all the loss functions are smaller than or equal to a preset threshold value, and obtaining a target classification model.
Illustratively, consider a scenario in which an image classification model is trained under small-sample learning.
For this, a sample set to be trained is first specified. As described above, the way of obtaining the training sample set includes: selecting a data-rich image data set as the source data domain and generating a set of training segments from it.
It is understood that each segment includes a support set and a query set, where the support set is the above-mentioned sample set to be trained and includes at least one training sample (here, an image training sample) and at least one piece of category information.
It can be understood that each training sample corresponds to a category information, that is, each training sample can be classified into a category; each category information may include a plurality of training samples, that is, a plurality of training samples may belong to a category.
It is understood that, for S500, the preset threshold is an evaluation criterion for judging whether the classification model to be trained has become the target classification model, and it is set so that the model being trained can become the target classification model. If the same objective can be achieved by setting a number of training iterations, S500 may be replaced by: calling a preset loss adjuster based on each first comparison result to adjust the parameters of the loss function corresponding to the classification model to be trained until a preset number of training iterations is reached, so as to obtain the target classification model, or by other processing steps of similar meaning.
In a possible embodiment, before S100, it is further necessary to construct a classification model (an initial model or blank model) based on small-sample metric learning, which includes, for example: selecting a deep convolutional neural network as the feature extractor, for example VGG16, InceptionV3 or ResNet50; and selecting a non-parametric nearest-neighbor classifier as the classifier.
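As a concrete illustration of such a construction (one possible choice, not the only one), a ResNet50 backbone from torchvision can be turned into a feature extractor by replacing its classification head with an identity mapping; the helper name below is an assumption for the sketch:

```python
import torch
import torchvision


def build_feature_extractor():
    # ResNet50 backbone used purely as a feature extractor: the final fully
    # connected layer is replaced by an identity mapping, so the network
    # outputs the pooled feature vector instead of classification logits.
    backbone = torchvision.models.resnet50(weights=None)
    backbone.fc = torch.nn.Identity()
    return backbone

# The "classifier" is non-parametric: a query is assigned to the class whose
# prototype feature vector it is most similar to, so no extra layers are added here.
```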
In a possible embodiment, before S100, a pre-training process is further performed on the initial model (or the blank model) to make it have a basic classification capability, including:
an initial data set (which in this scenario may be understood as an image data set) is obtained as the source data domain to pre-train the initial model (or blank model), and the pre-trained model may then be used directly to carry out the classification task of the target data domain (i.e., the training segment set).
It is understood that the categories of the target data domain (i.e., the training segment set) may have no overlap with those of the source data domain.
In a possible embodiment, before S100, performing data enhancement on the sample set to be trained includes: random cropping, random erasure, and sample normalization adjustment.
Illustratively, in the image classification scenario described above, the training segments undergo sample normalization adjustment, with the samples adjusted to 84 pt × 84 pt.
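A possible composition of these enhancement steps with torchvision is sketched below; the 84 × 84 size follows the example above, while the normalization statistics (ImageNet values) and the erasing probability are assumptions:

```python
from torchvision import transforms

# Data enhancement for the training segments: random cropping, random erasure
# and sample normalization, with samples adjusted to 84 x 84.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(84),   # random cropping to 84 x 84
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],   # assumed ImageNet statistics
                         std=[0.229, 0.224, 0.225]),
    transforms.RandomErasing(p=0.5),    # random erasure applied to the tensor
])
```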
In a possible implementation manner, for S400, the performing similarity comparison for at least one query sample based on the prototype feature vector includes:
s410: performing feature extraction on the at least one query sample to obtain at least one second feature vector, wherein each second feature vector is used for representing the feature corresponding to each query sample;
s420: calling the at least one second feature vector and the prototype feature vector corresponding to each category information respectively to carry out similarity comparison one by one to obtain at least one comparison result, wherein each comparison result comprises: similarity of each query sample and each prototype feature vector and prediction probability of corresponding category information of each query sample.
It is understood that, for the purposes of S410 to S420, the purpose here is only to enable the similarity comparison between the query sample and the training sample to obtain the first comparison result. The detailed similarity comparison process and method are well known in the art and are not described in detail herein.
It is understood that, for S410, the meaning characterized by the feature of each query sample includes the category information corresponding to each query sample.
For example, it is assumed that the sample set to be trained includes N classes of samples to be trained, and each class includes K samples (which in this scenario may be understood as image samples), where N and K are integers greater than or equal to 1. Then N × K image feature vectors are obtained after S200.
It is understood that the query sample herein may be obtained from the set of queries in said set of segments.
Optionally, the number of category information of the query sample in the query set herein is equal to the number of category information in the sample to be trained. That is, the number of categories in the query set here may be N.
Further, feature extraction is performed on the query set, and the number of category feature vectors obtained (i.e., of the second feature vectors) is at most N.
Optionally, for S420, the prototype feature vector is an average value of the first feature vectors corresponding to each category of information, and the obtaining manner includes:
s421: acquiring all the first feature vectors corresponding to each category of information;
s422: and calculating the average value of all the first feature vectors corresponding to each category information, and taking the calculation result as the prototype feature vector.
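A minimal sketch of S421-S422 for an N-way, K-shot support set, assuming the feature extractor has already produced one first feature vector per support sample (the helper name and tensor shapes are assumptions):

```python
import torch

def compute_prototypes(support_features, support_labels, num_classes):
    """S421-S422: average the first feature vectors of each class.

    support_features: tensor of shape (N * K, feature_dim)
    support_labels:   tensor of shape (N * K,) with class indices in [0, N)
    """
    prototypes = []
    for c in range(num_classes):
        class_feats = support_features[support_labels == c]
        # S422: the prototype is the mean of all first feature vectors of class c
        # (with a single sample per class, the mean is just that vector).
        prototypes.append(class_feats.mean(dim=0))
    return torch.stack(prototypes)          # shape (N, feature_dim)
```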
Specifically, for S420, each comparison result includes: the similarity between each query sample and each prototype feature vector, and the prediction probability of the category information corresponding to each query sample. The prediction probability is the softmax, over all classes, of the adjusted similarities:

p⁺ = exp(φ · s(x, c⁺)) / Σ_c exp(φ · s(x, c)),

where φ denotes the probability adjuster (i.e., the probability adjustment function described above), p⁺ denotes the positive-class prediction (the predicted probability of the query's true class c⁺), and s(x, c) denotes the similarity between the query image feature vector x and the class prototype feature vector c.
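Under this reconstruction, where the adjuster φ scales the similarities inside the softmax, the prediction probabilities for one query could be computed as in the following sketch; treating φ as a per-query scalar passed in from outside is an assumption:

```python
import torch
import torch.nn.functional as F

def prediction_probabilities(query_feature, prototypes, phi):
    """Prediction probability of each class for one query sample.

    query_feature: tensor of shape (feature_dim,)    -- a second feature vector
    prototypes:    tensor of shape (N, feature_dim)  -- prototype feature vectors
    phi:           probability adjuster value for this query (assumed scalar)
    """
    sims = F.cosine_similarity(query_feature.unsqueeze(0), prototypes, dim=1)  # (N,)
    return F.softmax(phi * sims, dim=0)   # adjusted prediction probability distribution
```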
In particular, the probability adjuster φ (i.e., the probability adjustment function) is expressed as a function of a scaling factor, the first maximum s₁ of the set of similarities between the query image and all the category prototypes, and the second maximum s₂ of that set, where the second maximum is the largest element of the set once s₁ is excluded (the explicit expression is given in the original as a formula image).
It will be appreciated that the probability adjustment function is used to scale the prediction probability distribution output by the model. The probability adjuster acts on the normalization function in the standard cross-entropy loss; according to the similarity ranking between the query sample and the prototype feature vectors of each class, it can automatically determine (or determine according to a preset rule) whether a sample's loss contribution should be amplified or reduced, and it computes the corresponding scaling strength from the difference between the positive similarity (the similarity between the second feature vector and the positive-class prototype feature vector) and the negative similarity (the similarity between the second feature vector and the negative-class prototype feature vectors).
It is understood that the probability adjustment function may dynamically adjust the predicted probability distribution of the model output according to the size and ordering information of the predicted values.
It can be understood that, if a category has only one training sample in the sample set to be trained, i.e., only one first feature vector is available in S422, that single feature vector is taken directly as the corresponding prototype feature vector, and no averaging calculation is required.
In one possible implementation, the standard cross-entropy loss function value is calculated based on the above expression of the predictive probability distribution.
Further, in S500, the invoking a preset loss adjuster to perform parameter adjustment on the loss function corresponding to the classification model to be trained includes:
s510: carrying out similarity comparison on each second feature vector and each prototype feature vector to obtain at least one similarity comparison result;
s520: screening to obtain a first maximum value and a second maximum value of the similarity comparison result;
s530: obtaining a probability adjusting function based on the first maximum value and the second maximum value, wherein the probability adjusting function is used for adjusting an output value of a prediction probability;
s540: if the value of the probability adjustment function is smaller than 1, calling the loss adjuster to apply a first adjustment mode to the loss function;
s550: and if the value of the probability adjustment function is larger than 1, calling the loss adjuster to apply a second adjustment mode to the loss function.
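A skeleton of S510-S550 is sketched below. Because the explicit expression of the probability adjustment function is given in the original only as a formula image, the sketch takes it as a caller-supplied function of the first and second maxima and only illustrates how the two maxima are obtained and how the adjustment mode is selected:

```python
import torch
import torch.nn.functional as F

def adjustment_mode(query_feature, prototypes, probability_adjuster):
    """Skeleton of S510-S550: choose the adjustment mode from the adjuster value.

    `probability_adjuster` stands in for the patent's adjuster phi(s1, s2), whose
    explicit expression is given only as a formula image, so it is passed in
    here as a caller-supplied function of the first and second maxima.
    """
    sims = F.cosine_similarity(query_feature.unsqueeze(0), prototypes, dim=1)  # S510
    s1, s2 = torch.topk(sims, k=2).values    # S520: first and second maxima
    phi = probability_adjuster(s1, s2)       # S530: probability adjustment value
    mode = "first" if phi < 1 else "second"  # S540 / S550
    return phi, mode
```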
Specifically, the adjustment that the loss adjuster (i.e., the loss adjusting function) mentioned in S540 and S550 applies to the loss value may be written as

Loss = ω · (−log p⁺),

where Loss denotes the calculated loss value, ω denotes the loss adjuster, and p⁺ is the positive-class prediction probability defined above.
In particular, the loss adjuster ω is a piecewise function governed by the value range of the probability adjustment function φ, with its branches parameterized by a focus factor and a scaling factor (the explicit branch expressions are given in the original as formula images). According to this expression, the loss adjuster ω applies different adjustments depending on how φ compares with 1: when φ is smaller than 1, ω takes the first branch, which defines the first mode adjustment; when φ is larger than 1, ω takes the second branch, which defines the second mode adjustment. It can therefore be understood that adjusting the loss function according to the value of the probability adjustment function, obtained from the similarity comparison between each second feature vector and each prototype feature vector, ensures that the classification model adjusts its strategy in time according to the current training state, which improves training efficiency and the accuracy of the training result.
It can be understood that the loss adjuster acts on the loss-calculation part of the standard cross-entropy loss and uses a piecewise function to further adjust the loss contributions of correctly predicted and incorrectly predicted samples, so that the model can learn effectively in the small-sample scenario; that is, it alleviates the problem that, because feature learning is very difficult, a large number of samples linger near the decision boundary for a long time. For correctly classified samples, the loss adjuster reduces the loss weight, which avoids an overly compact intra-class distribution; for misclassified samples, the loss adjuster explicitly increases the loss weight so as to strengthen the model's attention to those samples.
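To make the reweighting pattern concrete, the sketch below uses hypothetical focal-style branch formulas; they are assumptions for illustration only, since the patent's actual branch expressions are given as formula images, and only the overall structure (a piecewise weight ω multiplying the cross-entropy term, switched on whether φ is below or above 1) follows the description above:

```python
import torch

def adjusted_cross_entropy(p_positive, phi, gamma=2.0, beta=1.0):
    """Hypothetical sketch of the reweighting pattern: Loss = omega * CE term.

    p_positive: predicted probability of the true (positive) class
    phi:        value of the probability adjustment function for this sample
    gamma, beta: stand-ins for the focus factor and the scaling factor
    The two branch formulas below are assumptions made for illustration only;
    the patent gives its actual branch expressions as formula images.
    """
    ce = -torch.log(p_positive)
    if phi < 1:   # first mode adjustment
        omega = beta * (1.0 + (1.0 - p_positive)) ** gamma
    else:         # second mode adjustment
        omega = beta * p_positive ** gamma
    return omega * ce
```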
In the foregoing embodiment, the training method for the classification model is described in detail, and it should be understood that the model trained by applying the above method may perform a corresponding classification action, which specifically includes:
acquiring a project set to be classified;
calling a target classification model to classify the item set to be classified, wherein the target classification model is obtained by training in S100-S500;
and outputting a classification result.
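Putting the pieces together, the classification action with the trained target model could look like the following sketch (the function and parameter names are assumptions):

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def classify_items(feature_extractor, prototypes, items):
    """Classify a batch of items with the trained target classification model."""
    feature_extractor.eval()
    features = feature_extractor(items)                     # (B, feature_dim)
    # Each item is assigned to the class of its most similar prototype.
    sims = F.cosine_similarity(features.unsqueeze(1),
                               prototypes.unsqueeze(0), dim=2)   # (B, N)
    return sims.argmax(dim=1)                               # predicted class per item
```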
The embodiments described above introduce the classification model training method provided by the present invention from the perspective of the action logic and learning algorithm executed by an electronic device, for example the manner of obtaining the sample set to be trained, the extraction of the first feature vectors, the computation of the prototype feature vectors, the training of the classification model, and so on. It should be understood that, corresponding to these processing steps, the embodiments of the present invention may implement the above functions in hardware or in a combination of hardware and computer software. Whether a function is performed by hardware or by computer software driving hardware depends on the particular application and the design constraints of the solution. Skilled artisans may implement the described functionality in different ways for each particular application, but such implementation decisions should not be interpreted as departing from the scope of the present invention.
For example, the functions realized by the above implementation steps can also be realized by the classification model training device. FIG. 2 is a schematic diagram of a classification model training apparatus according to the present invention. As shown in fig. 2, the classification model training apparatus may include:
the acquisition module acquires a sample set to be trained, wherein the sample set to be trained comprises: at least one sample to be trained and at least one category information;
the characteristic extraction module is used for extracting characteristics of each sample to be trained in the sample set to be trained to obtain a first characteristic vector, the first characteristic vector is used for representing characteristics corresponding to the sample to be trained, and each first characteristic vector corresponds to one category information;
the processing module is used for calculating based on the first feature vector corresponding to each sample to be trained to obtain a prototype feature vector corresponding to each category information;
and the comparison module calls a preset loss adjuster to perform parameter adjustment on the loss function corresponding to the classification model to be trained based on each first comparison result so as to enable the results of all the loss functions to be smaller than or equal to a preset threshold value and obtain the target classification model.
It should be understood that the above modules/units are merely a logical division; in actual implementation, the functions of the above modules may be integrated into one hardware entity. For example, the acquisition module, the feature extraction module, the processing module and the comparison module may be integrated into a processor, and the programs and instructions implementing the functions of these modules may be maintained in a memory. For example, fig. 3 is a schematic structural diagram of an electronic device for training a classification model according to the present invention. As shown in fig. 3, the electronic device includes a processor, a transceiver, and a memory. The transceiver is used to obtain the sample set to be trained in the classification model training method. The memory may be used to store information related to model training, code for execution by the processor, and the like. When the processor executes the code stored in the memory, the electronic device is caused to perform part or all of the operations of the classification model training method described above.
The specific implementation process is described in detail in the embodiment of the method, and is not detailed here.
In a specific implementation, corresponding to the foregoing electronic device, an embodiment of the present invention further provides a computer storage medium. The computer storage medium disposed in the electronic device may store a program, and when the program is executed, part or all of the steps in each embodiment of the classification model training method may be implemented. The storage medium may be a magnetic disk, an optical disk, a read-only memory (ROM), a random access memory (RAM), or the like.
One or more of the above modules or units may be implemented in software, hardware or a combination of both. When any of the above modules or units are implemented in software, which is present as computer program instructions and stored in a memory, a processor may be used to execute the program instructions and implement the above method flows. The processor may include, but is not limited to, at least one of: various computing devices that run software, such as a Central Processing Unit (CPU), a microprocessor, a Digital Signal Processor (DSP), a Microcontroller (MCU), or an artificial intelligence processor, may each include one or more cores for executing software instructions to perform operations or processing. The processor may be built in a SoC (system on chip) or an Application Specific Integrated Circuit (ASIC), or may be a separate semiconductor chip. The processor may further include a necessary hardware accelerator such as a Field Programmable Gate Array (FPGA), a PLD (programmable logic device), or a logic circuit for implementing a dedicated logic operation, in addition to a core for executing software instructions to perform an operation or a process.
When the above modules or units are implemented in hardware, the hardware may be any one or any combination of a CPU, a microprocessor, a DSP, an MCU, an artificial intelligence processor, an ASIC, an SoC, an FPGA, a PLD, a dedicated digital circuit, a hardware accelerator, or a discrete device that is not integrated, which may run necessary software or is independent of software to perform the above method flows.
Further, a bus interface may also be included in FIG. 3, which may include any number of interconnected buses and bridges, with one or more processors, represented by a processor, and various circuits, represented by a memory, being linked together. The bus interface may also link together various other circuits such as peripherals, voltage regulators, power management circuits, and the like, which are well known in the art, and therefore, will not be described any further herein. The bus interface provides an interface. The transceiver provides a means for communicating with various other apparatus over a transmission medium. The processor is responsible for managing the bus architecture and the usual processing, and the memory may store data used by the processor in performing operations.
When the above modules or units are implemented using software, they may be implemented in whole or in part in the form of a computer program product. The computer program product includes one or more computer instructions. When loaded and executed on a computer, cause the processes or functions described in accordance with the embodiments of the invention to occur, in whole or in part. The computer may be a general purpose computer, a special purpose computer, a network of computers, or other programmable device. The computer instructions may be stored in a computer readable storage medium or transmitted from one computer readable storage medium to another, for example, from one website site, computer, server, or data center to another website site, computer, server, or data center via wired (e.g., coaxial cable, fiber optic, digital Subscriber Line (DSL)) or wireless (e.g., infrared, wireless, microwave, etc.). The computer-readable storage medium can be any available medium that can be accessed by a computer or a data storage device, such as a server, a data center, etc., that incorporates one or more of the available media. The usable medium may be a magnetic medium (e.g., floppy Disk, hard Disk, magnetic tape), an optical medium (e.g., DVD), or a semiconductor medium (e.g., solid State Disk (SSD)), among others.
It should be understood that, in various embodiments of the present invention, the sequence numbers of the processes do not mean the execution sequence, and the execution sequence of the processes should be determined by the functions and the internal logic, and should not constitute any limitation to the implementation process of the embodiments.
All parts of the specification are described in a progressive mode, the same and similar parts of all embodiments can be referred to each other, and each embodiment is mainly introduced to be different from other embodiments. In particular, as to the apparatus and system embodiments, since they are substantially similar to the method embodiments, the description is relatively simple and reference may be made to the description of the method embodiments in relevant places.
While alternative embodiments of the present invention have been described, additional variations and modifications in those embodiments may occur to those skilled in the art once they learn of the basic inventive concepts. Therefore, it is intended that the appended claims be interpreted as including the preferred embodiment and all changes and modifications that fall within the scope of the invention.
The above-mentioned embodiments, objects, technical solutions and advantages of the present invention are further described in detail, it should be understood that the above-mentioned embodiments are only examples of the present invention, and are not intended to limit the scope of the present invention, and any modifications, equivalent substitutions, improvements and the like made on the basis of the technical solutions of the present invention should be included in the scope of the present invention.

Claims (10)

1. A classification model training method, the method comprising:
obtaining a sample set to be trained, wherein the sample set to be trained comprises: at least one sample to be trained and at least one category information;
extracting features of each sample to be trained in the sample set to be trained to obtain a first feature vector, wherein the first feature vector is used for representing features corresponding to the sample to be trained, and each first feature vector corresponds to one category information;
calculating based on the first feature vector corresponding to each sample to be trained to obtain a prototype feature vector corresponding to each category of information;
calling a classification model to be trained, and carrying out similarity comparison on at least one query sample based on the prototype feature vector to obtain at least one first comparison result;
and calling a preset loss adjuster to perform parameter adjustment on the loss function corresponding to the classification model to be trained based on each first comparison result, so that the results of all the loss functions are smaller than or equal to a preset threshold value, and obtaining a target classification model.
2. The method of claim 1, wherein the performing similarity comparisons for at least one query sample based on the prototype feature vectors comprises:
performing feature extraction on the at least one query sample to obtain at least one second feature vector, wherein each second feature vector is used for representing the feature corresponding to each query sample;
calling the at least one second feature vector and the prototype feature vector corresponding to each category information respectively to perform similarity comparison one by one to obtain at least one first comparison result, wherein each comparison result comprises: similarity of each query sample and each prototype feature vector and prediction probability of corresponding category information of each query sample.
3. The method of claim 1, wherein the calculating based on the first feature vector corresponding to each sample to be trained to obtain a prototype feature vector corresponding to each category information comprises:
acquiring all the first characteristic vectors corresponding to each category information;
and performing averaging calculation on all the first feature vectors corresponding to each category of information, and taking the calculation result as the prototype feature vector.
4. The method of claim 1, wherein the invoking of the preset loss adjuster performs parameter adjustment for a loss function corresponding to the classification model to be trained, and comprises:
carrying out similarity comparison on each second feature vector and each prototype feature vector to obtain at least one similarity comparison result;
screening to obtain a first maximum value and a second maximum value of the similarity comparison result;
obtaining a probability adjusting function based on the first maximum value and the second maximum value, wherein the probability adjusting function is used for adjusting an output value of a prediction probability;
if the value range distribution of the probability adjusting function is less than 1, calling the loss adjuster to adjust a first mode aiming at the loss function;
and if the value range distribution of the probability adjusting function is larger than 1, calling the loss adjuster to adjust a second mode aiming at the loss function.
5. The method as claimed in claim 1, wherein after the obtaining of the sample set to be trained, before performing feature extraction on each sample to be trained in the sample set to be trained, further comprising:
performing data enhancement on the sample set to be trained, including: random cropping, random erasure, and sample normalization adjustment.
6. The method of claim 2, wherein each query sample of the at least one query sample corresponds to a category information.
7. The method of claim 1, wherein the loss function for the classification model to be trained is a cross entropy loss function.
8. A classification model training apparatus, characterized in that the apparatus comprises:
the acquisition module acquires a sample set to be trained, wherein the sample set to be trained comprises: at least one sample to be trained and at least one category information;
the characteristic extraction module is used for extracting characteristics of each sample to be trained in the sample set to be trained to obtain a first characteristic vector, the first characteristic vector is used for representing characteristics corresponding to the sample to be trained, and each first characteristic vector corresponds to one category information;
the processing module is used for calculating based on the first feature vector corresponding to each sample to be trained to obtain a prototype feature vector corresponding to each category information;
and the comparison module is used for calling a classification model to be trained, carrying out similarity comparison on at least one query sample based on the prototype feature vector to obtain at least one first comparison result, calling a preset loss adjuster based on each first comparison result to carry out parameter adjustment on a loss function corresponding to the classification model to be trained so as to enable the results of all the loss functions to be smaller than or equal to a preset threshold value, and obtaining a target classification model.
9. An electronic device, comprising: the system comprises a processor, a memory, a communication interface and a communication bus, wherein the processor, the memory and the communication interface complete mutual communication through the communication bus;
the memory is to store executable instructions that when executed cause the processor to perform the classification model training method of any one of claims 1-7.
10. A computer storage medium having stored therein executable instructions that when executed cause a computing device to perform the classification model training method of any one of claims 1-7.
CN202310095677.7A 2023-02-10 2023-02-10 Classification model training method, device and equipment Active CN115795355B (en)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202310095677.7A CN115795355B (en) 2023-02-10 2023-02-10 Classification model training method, device and equipment

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202310095677.7A CN115795355B (en) 2023-02-10 2023-02-10 Classification model training method, device and equipment

Publications (2)

Publication Number Publication Date
CN115795355A true CN115795355A (en) 2023-03-14
CN115795355B CN115795355B (en) 2023-09-12

Family

ID=85430813

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202310095677.7A Active CN115795355B (en) 2023-02-10 2023-02-10 Classification model training method, device and equipment

Country Status (1)

Country Link
CN (1) CN115795355B (en)

Patent Citations (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20190065899A1 (en) * 2017-08-30 2019-02-28 Google Inc. Distance Metric Learning Using Proxies
CN109961089A (en) * 2019-02-26 2019-07-02 中山大学 Small sample and zero sample image classification method based on metric learning and meta learning
CN110472652A (en) * 2019-06-30 2019-11-19 天津大学 A small amount of sample classification method based on semanteme guidance
CN111797893A (en) * 2020-05-26 2020-10-20 华为技术有限公司 Neural network training method, image classification system and related equipment
CN112949740A (en) * 2021-03-17 2021-06-11 重庆邮电大学 Small sample image classification method based on multilevel measurement

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN116935160A (en) * 2023-07-19 2023-10-24 上海交通大学 Training method, sample classification method, electronic equipment and medium
CN116935160B (en) * 2023-07-19 2024-05-10 上海交通大学 Training method, sample classification method, electronic equipment and medium

Also Published As

Publication number Publication date
CN115795355B (en) 2023-09-12

Similar Documents

Publication Publication Date Title
US11487995B2 (en) Method and apparatus for determining image quality
WO2020228525A1 (en) Place recognition method and apparatus, model training method and apparatus for place recognition, and electronic device
CN111931592B (en) Object recognition method, device and storage medium
JP2022141931A (en) Method and device for training living body detection model, method and apparatus for living body detection, electronic apparatus, storage medium, and computer program
WO2021238586A1 (en) Training method and apparatus, device, and computer readable storage medium
CN114329029B (en) Object retrieval method, device, equipment and computer storage medium
CN113222149A (en) Model training method, device, equipment and storage medium
CN113254491A (en) Information recommendation method and device, computer equipment and storage medium
CN111126347A (en) Human eye state recognition method and device, terminal and readable storage medium
CN112668482A (en) Face recognition training method and device, computer equipment and storage medium
CN114611672A (en) Model training method, face recognition method and device
CN116994021A (en) Image detection method, device, computer readable medium and electronic equipment
CN115795355B (en) Classification model training method, device and equipment
CN117291185A (en) Task processing method, entity identification method and task processing data processing method
CN116109834A (en) Small sample image classification method based on local orthogonal feature attention fusion
CN111860054A (en) Convolutional network training method and device
CN115205956A (en) Left and right eye detection model training method, method and device for identifying left and right eyes
CN112487927B (en) Method and system for realizing indoor scene recognition based on object associated attention
CN116310615A (en) Image processing method, device, equipment and medium
CN111539420B (en) Panoramic image saliency prediction method and system based on attention perception features
CN111401112B (en) Face recognition method and device
CN114462526A (en) Classification model training method and device, computer equipment and storage medium
CN113435519A (en) Sample data enhancement method, device, equipment and medium based on antagonistic interpolation
Jain et al. Real-time eyeglass detection using transfer learning for non-standard facial data.
CN111091198A (en) Data processing method and device

Legal Events

Date Code Title Description
PB01 Publication
SE01 Entry into force of request for substantive examination
GR01 Patent grant