CN116778264B - Object classification method, image classification method and related equipment based on class-incremental learning

Info

Publication number
CN116778264B
Authority
CN
China
Prior art keywords
prompt
classification
data
training data
codebook
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN202311070387.3A
Other languages
Chinese (zh)
Other versions
CN116778264A (en)
Inventor
戴勇
洪晓鹏
王耀威
蒋冬梅
王亚斌
马智恒
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Peng Cheng Laboratory
Original Assignee
Peng Cheng Laboratory
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Peng Cheng Laboratory
Priority to CN202311070387.3A
Publication of CN116778264A
Application granted
Publication of CN116778264B
Legal status: Active
Anticipated expiration

Abstract

The embodiments of the application provide an object classification method, an image classification method and related equipment based on class-incremental learning, and belong to the technical field of artificial intelligence. The method comprises: acquiring a prompt codebook and a training data feature representation; inputting the prompt codebook and the training data feature representation into a prompt combination network of an original classification model for prompt prediction to obtain combined prompt data; weighting the combined prompt data and the training data feature representation through a prompt weighting network of the original classification model to obtain weighted prompt information; classifying the weighted prompt information and the training data feature representation through a classification network of the original classification model to obtain a classification prediction result; optimizing the original classification model according to a classification verification result and the classification prediction result to obtain a target classification model; and classifying target data through the target classification model to obtain a target classification result. The application can retain previously learned knowledge when new data is added, improve classification accuracy, and reduce the amount of data that must be stored.

Description

Object classification method, image classification method and related equipment based on class-incremental learning
Technical Field
The application relates to the technical field of artificial intelligence, and in particular to an object classification method, an image classification method and related equipment based on class-incremental learning.
Background
With the development of information technology, the data fed to classification models is generated incrementally. However, conventional classification models are usually trained on a closed data set, so they cannot effectively handle newly added data.
To enable a classification model to process newly added data, the related art provides three types of incremental learning methods: network fine-tuning methods, regularization methods, and replay methods. Most network fine-tuning methods fine-tune only the fully connected layer related to the new task, which causes considerable forgetting of old tasks and increases the network size. Regularization methods keep old knowledge from being overwritten by new knowledge by applying constraints to the loss function of the new task, but they depend heavily on the correlation between the new and old tasks, and training time grows linearly with the number of learned tasks. Replay methods retain a portion of representative old data when training a new task so that the classification model can review old knowledge, but they are sensitive to the selection of the old data, and the computational cost multiplies as the number of tasks increases. Therefore, how to enable a classification model to learn newly added data without degrading previously learned knowledge and without the cost multiplying with the number of tasks is a technical problem to be solved urgently.
Disclosure of Invention
The main purpose of the embodiments of the application is to provide an object classification method, an image classification method and related equipment based on class-incremental learning, so that a classification model can learn newly added data without degrading previously learned knowledge and without the cost multiplying with the number of tasks.
To achieve the above object, a first aspect of the embodiments of the present application provides an object classification method based on class-incremental learning, the method including:
acquiring a training data set of a current incremental task, wherein the training data set comprises a prompt codebook and training data;
performing feature mapping processing on the training data to obtain a training data feature representation;
inputting the prompt codebook and the training data feature representation into a preset original classification model, wherein the original classification model comprises a prompt combination network, a prompt weighting network, and a classification network;
performing prompt prediction on the prompt codebook and the training data feature representation through the prompt combination network to obtain combined prompt data;
weighting the combined prompt data and the training data feature representation through the prompt weighting network to obtain weighted prompt information;
performing object classification on the weighted prompt information and the training data feature representation through the classification network to obtain a classification prediction result;
optimizing the original classification model according to a preset classification verification result and the classification prediction result to obtain a target classification model;
and obtaining target data, and inputting the target data into the target classification model for object classification to obtain a target classification result.
In some embodiments, the performing, through the prompt combination network, prompt prediction on the prompt codebook and the training data feature representation to obtain combined prompt data includes:
performing coefficient prediction on the prompt codebook and the training data feature representation through the prompt combination network to obtain prompt combination coefficients;
and multiplying the prompt combination coefficients by a data matrix of the prompt codebook through the prompt combination network to obtain the combined prompt data.
In some embodiments, the prompt combination network includes at least two correlation measurement layers and a coefficient prediction layer; the performing, through the prompt combination network, coefficient prediction on the prompt codebook and the training data feature representation to obtain prompt combination coefficients includes:
performing correlation measurement processing on the prompt codebook and the training data feature representation through the correlation measurement layers to obtain correlation measurement data;
and performing coefficient prediction on the correlation measurement data through the coefficient prediction layer to obtain the prompt combination coefficients.
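As an illustration only, the coefficient prediction and codebook multiplication described above might be sketched in Python as follows. The dot-product correlation measure, the softmax used as the coefficient prediction, the use of a single correlation step instead of several layers, the flattening of each basic prompt into one vector, and the names `softmax` and `combine_prompts` are all assumptions made for this sketch; the embodiments do not prescribe them.

```python
import math


def softmax(scores):
    """Turn raw scores into coefficients that sum to 1 (assumed predictor)."""
    largest = max(scores)
    exps = [math.exp(s - largest) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def combine_prompts(codebook, feature):
    """Predict one coefficient per basic prompt, then mix the codebook rows.

    codebook: list of basic prompts, each flattened to a vector (assumption)
    feature:  training data feature representation, same length as a prompt
    """
    # Correlation measurement: dot product of each basic prompt with the feature.
    scores = [sum(p * f for p, f in zip(prompt, feature)) for prompt in codebook]
    # Coefficient prediction: softmax over the correlation scores.
    coeffs = softmax(scores)
    # Multiply the coefficients by the codebook's data matrix.
    dim = len(codebook[0])
    combined = [
        sum(c * prompt[j] for c, prompt in zip(coeffs, codebook))
        for j in range(dim)
    ]
    return coeffs, combined


codebook = [[1.0, 0.0], [0.0, 1.0]]  # two basic prompts
feature = [2.0, 0.0]                 # correlates with the first prompt only
coeffs, combined = combine_prompts(codebook, feature)
```

Because the codebook has a fixed number of rows, the combined prompt has the same size regardless of how many incremental tasks have been seen.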
In some embodiments, the weighting the combined prompt data and the training data feature representation through the prompt weighting network to obtain weighted prompt information includes:
performing weight calculation on the combined prompt data and the training data feature representation through the prompt weighting network to obtain a prompt weight;
and constructing the weighted prompt information from the prompt weight and the combined prompt data through the prompt weighting network.
In some embodiments, the prompt weighting network includes a fully connected layer; the performing, through the prompt weighting network, weight calculation on the combined prompt data and the training data feature representation to obtain a prompt weight includes:
mapping the combined prompt data to a preset space through the fully connected layer to obtain combined prompt mapping data, and mapping the training data feature representation to the preset space through the fully connected layer to obtain training data feature mapping data;
performing correlation measurement on the combined prompt mapping data and the training data feature mapping data to obtain information correlation data;
normalizing the information correlation data according to preset length information to obtain target correlation data, wherein the length information is the length of a basic prompt in the prompt codebook;
and performing nonlinear mapping on the target correlation data to obtain the prompt weight.
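The weight calculation above can be pictured with the following minimal Python sketch. It assumes, purely for illustration, that the fully connected layer is a bias-free linear map, that the correlation measurement is a dot product, that normalization means dividing by the basic-prompt length, that the nonlinear mapping is a sigmoid, and that the weighted prompt information is the combined prompt scaled by the weight. The names `linear` and `prompt_weight` are hypothetical.

```python
import math


def linear(weights, vector):
    """Bias-free fully connected layer: one output per weight row (assumed form)."""
    return [sum(w * v for w, v in zip(row, vector)) for row in weights]


def prompt_weight(combined_prompt, feature, w_prompt, w_feature, prompt_length):
    # Map both inputs into the same preset space.
    mapped_prompt = linear(w_prompt, combined_prompt)
    mapped_feature = linear(w_feature, feature)
    # Correlation measurement (dot product is an assumption).
    correlation = sum(p * f for p, f in zip(mapped_prompt, mapped_feature))
    # Normalize by the length of a basic prompt in the codebook.
    normalized = correlation / prompt_length
    # Nonlinear mapping to a weight in (0, 1) (sigmoid is an assumption).
    return 1.0 / (1.0 + math.exp(-normalized))


identity = [[1.0, 0.0], [0.0, 1.0]]
combined_prompt = [1.0, 2.0]
feature = [3.0, 4.0]
weight = prompt_weight(combined_prompt, feature, identity, identity, prompt_length=2)
# One possible way to build the weighted prompt information from the weight.
weighted_prompt = [weight * value for value in combined_prompt]
```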
In some embodiments, the performing, through the classification network, object classification on the weighted prompt information and the training data feature representation to obtain a classification prediction result includes:
setting preset initialization prompt information at a head layer of the classification network, setting the weighted prompt information at a middle layer of the classification network, and setting the output dimension of a tail layer of the classification network according to the total number of categories of the current incremental task;
and inputting the training data feature representation into the structurally adjusted classification network for classification prediction to obtain the classification prediction result.
In some embodiments, the optimizing the original classification model according to the classification verification result and the classification prediction result to obtain a target classification model includes:
performing cross-entropy loss calculation on the classification verification result and the classification prediction result to obtain classification loss data;
and adjusting the parameters of the original classification model according to the classification loss data to obtain the target classification model.
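For readers unfamiliar with the loss named above, a minimal Python sketch of a cross-entropy calculation follows. It assumes that the classification prediction result is a list of raw scores (logits), one per category, and that the classification verification result is the integer index of the correct category; both representations, and the name `cross_entropy`, are assumptions of this sketch.

```python
import math


def cross_entropy(logits, label):
    """Negative log-probability of the correct category under a softmax."""
    largest = max(logits)
    # Log-sum-exp, shifted by the largest logit for numerical stability.
    log_sum = largest + math.log(sum(math.exp(z - largest) for z in logits))
    return log_sum - logits[label]


# Uniform scores over four categories: the model is maximally unsure.
loss_uniform = cross_entropy([0.0, 0.0, 0.0, 0.0], label=2)
# A confident prediction is cheap when correct and expensive when wrong.
loss_correct = cross_entropy([10.0, 0.0, 0.0], label=0)
loss_wrong = cross_entropy([10.0, 0.0, 0.0], label=1)
```

The parameter adjustment step would then lower this value, for example by gradient descent, which is not shown here.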
In some embodiments, after the performing cross-entropy loss calculation on the classification verification result and the classification prediction result to obtain classification loss data, the method further includes:
optimizing the prompt codebook according to the classification loss data.
In some embodiments, the acquiring a training data set of a current incremental task includes:
acquiring training data and a candidate codebook of the current incremental task;
and initializing the candidate codebook according to a preset standard deviation to obtain the prompt codebook, wherein the prompt codebook includes at least two basic prompts.
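A minimal sketch of such an initialization is given below. The zero-mean Gaussian distribution, the default standard deviation of 0.02, the three-level nested-list layout (prompts × prompt length × embedding dimension), and the name `init_prompt_codebook` are assumptions of this sketch; the text specifies only a preset standard deviation and at least two basic prompts.

```python
import random


def init_prompt_codebook(num_prompts, prompt_length, embed_dim, std=0.02, seed=0):
    """Fill a codebook with Gaussian noise of the given standard deviation."""
    if num_prompts < 2:
        raise ValueError("the prompt codebook needs at least two basic prompts")
    rng = random.Random(seed)  # local generator keeps the result reproducible
    return [
        [[rng.gauss(0.0, std) for _ in range(embed_dim)] for _ in range(prompt_length)]
        for _ in range(num_prompts)
    ]


codebook = init_prompt_codebook(num_prompts=4, prompt_length=5, embed_dim=8)
```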
To achieve the above object, a second aspect of the embodiments of the present application provides an image classification method, wherein the target classification model is an image classification model obtained by the above object classification method based on class-incremental learning; the method includes:
acquiring target image data;
and inputting the target image data into the image classification model for image classification to obtain image category information.
To achieve the above object, a third aspect of the embodiments of the present application provides an object classification device based on class-incremental learning, the device including:
a data set acquisition module, configured to acquire a training data set of a current incremental task, wherein the training data set comprises a prompt codebook and training data;
a feature mapping module, configured to perform feature mapping processing on the training data to obtain a training data feature representation;
an input module, configured to input the prompt codebook and the training data feature representation into a preset original classification model, wherein the original classification model comprises a prompt combination network, a prompt weighting network, and a classification network;
a prompt prediction module, configured to perform prompt prediction on the prompt codebook and the training data feature representation through the prompt combination network to obtain combined prompt data;
a weighting processing module, configured to weight the combined prompt data and the training data feature representation through the prompt weighting network to obtain weighted prompt information;
an original classification module, configured to perform object classification on the weighted prompt information and the training data feature representation through the classification network to obtain a classification prediction result;
an optimizing module, configured to optimize the original classification model according to a preset classification verification result and the classification prediction result to obtain a target classification model;
and a target classification module, configured to acquire target data and input the target data into the target classification model for object classification to obtain a target classification result.
To achieve the above object, a fourth aspect of the embodiments of the present application provides an image classification device, wherein the target classification model is an image classification model obtained by the above object classification method based on class-incremental learning; the device includes:
a data acquisition module, configured to acquire target image data;
and an image classification module, configured to input the target image data into the image classification model for image classification to obtain image category information.
To achieve the above object, a fifth aspect of the embodiments of the present application provides an electronic device including a memory and a processor, wherein the memory stores a computer program, and the processor implements the method according to the first aspect or the second aspect when executing the computer program.
To achieve the above object, a sixth aspect of the embodiments of the present application proposes a computer-readable storage medium storing a computer program which, when executed by a processor, implements the method of the first aspect or the second aspect.
The object classification method, the image classification method and the related equipment based on class-incremental learning provided by the application automatically generate the prompt information of each incremental task from the prompt codebook, so that the prompt information serves as a reference during classification. This enables targeted classification, and categories can be predicted accurately even when new categories are added. Meanwhile, the prompt codebook is applicable to incremental data and does not grow as the data grows.
Drawings
FIG. 1 is a flowchart of an object classification method based on class-incremental learning provided by an embodiment of the application;
FIG. 2 is a flowchart of step S101 in FIG. 1;
FIG. 3 is a model structure diagram of the original classification model in the object classification method based on class-incremental learning according to an embodiment of the present application;
FIG. 4 is a flowchart of step S104 in FIG. 1;
FIG. 5 is a flowchart of step S401 in FIG. 4;
FIG. 6 is a network structure diagram of the prompt combination network in the object classification method based on class-incremental learning according to an embodiment of the present application;
FIG. 7 is a flowchart of step S105 in FIG. 1;
FIG. 8 is a flowchart of step S701 in FIG. 7;
FIG. 9 is a network structure diagram of the prompt weighting network in the object classification method based on class-incremental learning according to an embodiment of the present application;
FIG. 10 is a flowchart of step S106 in FIG. 1;
FIG. 11 is a detailed structure diagram of the original classification network in the object classification method based on class-incremental learning according to an embodiment of the present application;
FIG. 12 is a flowchart of step S107 in FIG. 1;
FIG. 13 is a flowchart of an image classification method provided by an embodiment of the present application;
FIG. 14 is a schematic structural diagram of an object classification device based on class-incremental learning according to an embodiment of the present application;
FIG. 15 is a schematic structural diagram of an image classification device according to an embodiment of the present application;
FIG. 16 is a schematic diagram of the hardware structure of an electronic device according to an embodiment of the present application.
Detailed Description
The present application will be described in further detail with reference to the drawings and examples, in order to make the objects, technical solutions and advantages of the present application more apparent. It should be understood that the specific embodiments described herein are for purposes of illustration only and are not intended to limit the scope of the application.
It should be noted that although functional block division is performed in a device diagram and a logic sequence is shown in a flowchart, in some cases, the steps shown or described may be performed in a different order than the block division in the device, or in the flowchart. The terms first, second and the like in the description and in the claims and in the above-described figures, are used for distinguishing between similar elements and not necessarily for describing a particular sequential or chronological order.
Unless defined otherwise, all technical and scientific terms used herein have the same meaning as commonly understood by one of ordinary skill in the art to which this application belongs. The terminology used herein is for the purpose of describing embodiments of the application only and is not intended to be limiting of the application.
First, several terms used in the present application are explained:
Artificial intelligence (AI): a new technical science that studies and develops theories, methods, technologies and application systems for simulating, extending and expanding human intelligence. Artificial intelligence is a branch of computer science that attempts to understand the nature of intelligence and to produce new intelligent machines that can react in a manner similar to human intelligence; research in this field includes robotics, language recognition, image recognition, natural language processing, and expert systems. Artificial intelligence can simulate the information processes of human consciousness and thinking. Artificial intelligence is also the theory, method, technique, and application system that uses a digital computer or a machine controlled by a digital computer to simulate, extend, and expand human intelligence, perceive the environment, acquire knowledge, and use knowledge to obtain optimal results.
Incremental learning: a machine learning method that continuously updates and improves a model by gradually adding new training data, so as to adapt to new situations or tasks. Unlike traditional batch learning, incremental learning can learn new content on the basis of an existing model without retraining the entire model.
Codebook: a table formed by collecting the available codes and numbering them with serial numbers, so that the relevant code can be looked up in the table directly by its serial number.
Vision Transformer (ViT): a deep learning model for image classification and other visual tasks. Unlike conventional convolutional neural networks (CNNs), ViT employs a self-attention mechanism to capture semantic associations between different locations in an image. The core idea of ViT is to divide the input image into a series of image blocks (patches) and then transform each image block into a vector representation by a linear transformation. These vector representations are input into a Transformer model, which learns the relationships between the different image blocks. The Transformer model is composed of multiple encoder layers, each of which contains a self-attention mechanism and a feed-forward neural network.
Cross-entropy loss: a commonly used loss function, typically used for model training in classification problems. Cross-entropy loss is based on concepts from information theory and measures the difference between the prediction of the model and the actual label.
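The patch-splitting step mentioned in the Vision Transformer definition above can be pictured with the short Python sketch below. It handles a single-channel image stored as a nested list and omits the linear transformation that follows; the name `split_into_patches` and the row-major flattening of each patch are assumptions of this sketch.

```python
def split_into_patches(image, patch_size):
    """Cut a 2-D image into non-overlapping square patches, each flattened."""
    height, width = len(image), len(image[0])
    if height % patch_size or width % patch_size:
        raise ValueError("image size must be a multiple of the patch size")
    patches = []
    for top in range(0, height, patch_size):
        for left in range(0, width, patch_size):
            # Flatten the patch row by row into one vector.
            patch = [
                image[row][col]
                for row in range(top, top + patch_size)
                for col in range(left, left + patch_size)
            ]
            patches.append(patch)
    return patches


# A 4x4 image whose pixel value encodes its position (row * 4 + col).
image = [[row * 4 + col for col in range(4)] for row in range(4)]
patches = split_into_patches(image, patch_size=2)
```

In a full model, each flattened patch would next be multiplied by a learned projection matrix to produce the vector representation that enters the Transformer encoder.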
With the rapid development of information technology, data also changes incrementally. Traditional classification models are trained on a closed data set and cannot effectively process newly added data. Incremental learning aims to give a classification model the ability to process continuously arriving new data. In the related art, there are three main incremental learning methods: network fine-tuning methods, regularization methods, and replay methods. Most network fine-tuning methods fine-tune only the fully connected layer related to the new task, which causes considerable forgetting of old tasks and increases the network size. Regularization methods keep old knowledge from being overwritten by applying constraints to the loss function of the new task, but they depend heavily on the correlation between the new and old tasks, and training time grows linearly with the number of learned tasks. Replay methods retain a portion of representative old data when training a new task so that the classification model can review old knowledge, but they are sensitive to the selection of the old data, and selecting old data becomes more complex as the number of tasks increases, requiring additional computing resources and storage space. Beyond these three traditional methods, the current best-performing incremental learning methods are generally prompt-based: a corresponding group of prompts is trained separately for the newly added data, so that a pre-trained model can quickly adapt to new tasks and new data. However, with prompt-based methods the amount of stored prompts grows exponentially as the data grows, and at test time it is very difficult to find the prompts corresponding to a given piece of data, which reduces the classification accuracy of the classification model.
Based on this, the embodiments of the application provide an object classification method, an image classification method and related equipment based on class-incremental learning. Under different incremental tasks, prompt information corresponding to the newly added data is generated by a weighted prompt combination method based on a prompt codebook. The prompt information serves as a classification cue for the classification model, so that the original classification model classifies the prompt information and the training data feature representation in a targeted manner. The original classification model is then trained according to the classification prediction result and the classification verification result to obtain a target classification model, and target data is classified through the target classification model to obtain a more accurate target classification result.
The object classification method, the image classification method and the related equipment based on class-incremental learning provided by the embodiments of the application are described in detail through the following embodiments. The object classification method based on class-incremental learning is described first.
The embodiments of the application can acquire and process the relevant data based on artificial intelligence technology. Artificial intelligence (AI) is the theory, method, technique, and application system that uses a digital computer or a machine controlled by a digital computer to simulate, extend, and expand human intelligence, perceive the environment, acquire knowledge, and use knowledge to obtain optimal results.
Basic artificial intelligence technologies generally include sensors, dedicated artificial intelligence chips, cloud computing, distributed storage, big data processing, operation/interaction systems, mechatronics, and the like. Artificial intelligence software technologies mainly include computer vision, robotics, biometric recognition, speech processing, natural language processing, and machine learning/deep learning.
The embodiments of the application provide an object classification method based on class-incremental learning, which relates to the technical field of artificial intelligence. The method can be applied to a terminal, to a server, or to software running on the terminal or the server. In some embodiments, the terminal may be a smartphone, a tablet, a notebook computer, a desktop computer, or the like. The server may be an independent physical server, a server cluster or distributed system formed by a plurality of physical servers, or a cloud server that provides basic cloud computing services such as cloud services, cloud databases, cloud computing, cloud functions, cloud storage, network services, cloud communication, middleware services, domain name services, security services, CDN, and big data and artificial intelligence platforms. The software may be an application that implements the object classification method based on class-incremental learning, but is not limited to the above forms.
The application is operational with numerous general purpose or special purpose computer system environments or configurations. For example: personal computers, server computers, hand-held or portable devices, tablet devices, multiprocessor systems, microprocessor-based systems, set top boxes, programmable consumer electronics, network PCs, minicomputers, mainframe computers, distributed computing environments that include any of the above systems or devices, and the like. The application may be described in the general context of computer-executable instructions, such as program modules, being executed by a computer. Generally, program modules include routines, programs, objects, components, data structures, etc. that perform particular tasks or implement particular abstract data types. The application may also be practiced in distributed computing environments where tasks are performed by remote processing devices that are linked through a communications network. In a distributed computing environment, program modules may be located in both local and remote computer storage media including memory storage devices.
It should be noted that, in each specific embodiment of the present application, when processing is required that involves user information, user behavior data, user history data, user location information, or other data related to user identity or characteristics, the permission or consent of the user is obtained first, and the collection, use, and processing of the data comply with the relevant laws, regulations, and standards. In addition, when an embodiment of the application needs to acquire sensitive personal information of the user, the separate permission or separate consent of the user is obtained through a pop-up window, a jump to a confirmation page, or the like; only after this separate permission or consent has been explicitly obtained is the user data necessary for the normal operation of the embodiment acquired.
FIG. 1 is an optional flowchart of an object classification method based on class-incremental learning according to an embodiment of the present application. The method in FIG. 1 may include, but is not limited to, steps S101 to S108.
Step S101, acquiring a training data set of a current incremental task, wherein the training data set comprises a prompt codebook and training data;
Step S102, performing feature mapping processing on the training data to obtain a training data feature representation;
Step S103, inputting the prompt codebook and the training data feature representation into a preset original classification model, wherein the original classification model comprises a prompt combination network, a prompt weighting network, and a classification network;
Step S104, performing prompt prediction on the prompt codebook and the training data feature representation through the prompt combination network to obtain combined prompt data;
Step S105, weighting the combined prompt data and the training data feature representation through the prompt weighting network to obtain weighted prompt information;
Step S106, performing object classification on the weighted prompt information and the training data feature representation through the classification network to obtain a classification prediction result;
Step S107, optimizing the original classification model according to a preset classification verification result and the classification prediction result to obtain a target classification model;
Step S108, obtaining target data, and inputting the target data into the target classification model for object classification to obtain a target classification result.
In steps S101 to S108 of the embodiments of the present application, a training data set of the current incremental task is acquired, the training data set including a prompt codebook and training data. The training data is feature-mapped to obtain a training data feature representation, and the training data feature representation and the prompt codebook are input into the original classification model. The prompt combination network in the original classification model performs prompt prediction on the training data feature representation and the prompt codebook to obtain combined prompt data for the training data, and the prompt weighting network weights the combined prompt data and the training data feature representation to obtain weighted prompt information. The classification network then performs object classification according to the weighted prompt information and the training data feature representation to determine the classification of the training data, yielding a classification prediction result. The original classification model is optimized according to the classification verification result and the classification prediction result to obtain a more accurate target classification model, so that target data is classified more accurately through the target classification model. In this way, the prompt information of each incremental task is generated automatically from the prompt codebook and serves as a reference during classification, which enables targeted classification and accurate category prediction even when new categories are added. Meanwhile, the prompt codebook is applicable to incremental data and does not grow as the data grows.
In step S101 of some embodiments, the training data set may be extracted from a preset training library, and different training data sets are selected for different objects. It should be noted that the object may be an image, text, voice, etc., and the specific content of the object is not limited. The training data set comprises a prompt codebook and a plurality of training data; the training data are input into the original classification model one by one for classification, and the model parameters of the original classification model are adjusted after each classification. The prompt codebook is a codebook storing at least two basic prompts, and the basic prompts are used to prompt the training data set of each incremental task so that the original classification model can adapt to new incremental tasks and new data, which is equivalent to prompting the original classification model on how to classify the training data.
For example, suppose the object is an image and the training library contains image data of 100 categories, each category containing 600 pieces of image data, of which 500 are training image data and 100 are test image data. After the categories are randomly shuffled, the image data of 10 categories at a time are taken, in the shuffled order, as the image data of one incremental task, and each incremental task comprises training and verification. The image data set extracted first is set as the base data set, and in each incremental task the corresponding image data set is extracted for training and verifying the original classification model. If the current incremental task is the i-th incremental task, the image data set of the i-th incremental task is acquired as the current image data set, which comprises a training image data set and a verification image data set.
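The task split in the example above can be sketched as follows. This is a minimal illustration only: the function name, the fixed random seed and the use of integer class ids are assumptions made for the sketch and are not specified in the embodiment.

```python
import random

def split_into_incremental_tasks(num_classes=100, classes_per_task=10, seed=0):
    """Shuffle the class ids and cut them into consecutive groups,
    one group per incremental task (seed is a hypothetical parameter)."""
    class_ids = list(range(num_classes))
    random.Random(seed).shuffle(class_ids)
    return [class_ids[i:i + classes_per_task]
            for i in range(0, num_classes, classes_per_task)]

tasks = split_into_incremental_tasks()
base_task = tasks[0]                          # first extracted group: the base data set
train_images_per_task = len(base_task) * 500  # 500 training images per category
test_images_per_task = len(base_task) * 100   # 100 test images per category
```

With the figures from the example, each incremental task covers 10 categories, that is, 5000 training images and 1000 test images.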
Referring to fig. 2, in some embodiments, step S101 may include, but is not limited to, steps S201 to S202:
step S201, training data and a candidate codebook of a current increment task are obtained;
step S202, initializing the candidate codebook according to a preset standard deviation to obtain the prompt codebook; wherein the prompt codebook comprises at least two basic prompts.
In step S201 of some embodiments, after the current incremental task is determined, training data corresponding to the current incremental task is acquired. It should be noted that the current incremental task includes a target data category; training data is extracted from the training library according to the target data category and stored in the training data set. Meanwhile, a candidate codebook is acquired, wherein the candidate codebook is a codebook storing N basic prompts, each of length L, and N and L determine the dimension of the candidate codebook.
In step S202 of some embodiments, the candidate codebook is initialized from a normal distribution with the preset standard deviation to obtain the prompt codebook. In this embodiment, the preset standard deviation is 0.02, that is, the candidate codebook is initialized with a standard deviation of 0.02 to obtain the prompt codebook.
In step S201 to step S202 illustrated in this embodiment, after training data and a candidate codebook of a current incremental task are obtained, the candidate codebook is initialized according to a preset standard deviation to obtain a prompt codebook, so that a training data set is constructed according to the prompt codebook and the training data, and the training data set is easy to obtain.
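As a minimal sketch of the initialization in step S202, the codebook can be drawn from a normal distribution with standard deviation 0.02. The zero mean, the embedding dimension of 768, the example values N = 10 and L = 8, and the three-dimensional layout (N, L, embedding dimension) are assumptions for illustration; the embodiment only fixes the standard deviation.

```python
import numpy as np

def init_prompt_codebook(num_prompts, prompt_len, embed_dim, std=0.02, seed=0):
    """Draw a codebook of num_prompts basic prompts, each prompt_len tokens long,
    from N(0, std^2). The zero mean and the array layout are assumptions."""
    rng = np.random.default_rng(seed)
    return rng.normal(loc=0.0, scale=std, size=(num_prompts, prompt_len, embed_dim))

# Hypothetical sizes: N = 10 basic prompts of length L = 8, embedding dimension 768.
codebook = init_prompt_codebook(num_prompts=10, prompt_len=8, embed_dim=768)
```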
In step S102 of some embodiments, training data features are obtained by feature extraction of the training data, and the training data features are then input to an encoder for encoding processing to obtain the training data feature representation. It should be noted that, once the current training data and the prompt codebook are obtained, the current training data and the prompt codebook are input to the encoding layer Encode for encoding processing to obtain the training data feature representation.
In step S103 of some embodiments, as shown in fig. 3, the original classification model includes a prompt combination network, a prompt weighting network, and a classification network. It should be noted that the current training data is input to the prompt combination network, the prompt weighting network and the classification network, the prompt codebook is input to the prompt combination network, and the output of the prompt combination network is connected to the input of the prompt weighting network. The output of the prompt weighting network is connected to several layers of the classification network, so that the weighted prompt information output by the prompt weighting network alters the structure of the classification network.
Referring to fig. 4, in some embodiments, step S104 may include, but is not limited to, steps S401 to S402:
step S401, coefficient prediction is carried out on the prompt codebook and the training data characteristic representation through a prompt combination network, so as to obtain prompt combination coefficients;
step S402, multiplying the prompt combination coefficients by the data matrix of the prompt codebook through the prompt combination network to obtain the combined prompt data.
In step S401 of some embodiments, coefficient prediction is performed by inputting the prompt codebook and the training data feature representation into the prompt combination network to obtain the prompt combination coefficients, which characterize the correlation between the current training data feature representation and the prompt codebook. The correlation between the training data feature representation and each basic prompt in the prompt codebook can therefore be read directly from the prompt combination coefficients, so that prompt information better matched to the current training data feature representation can be constructed according to the prompt combination coefficients. Note that the prompt combination coefficients are represented by a data matrix.
In step S402 of some embodiments, the combined prompt data is obtained by multiplying the prompt combination coefficients by the data matrix corresponding to the prompt codebook. Because the dimensions of the prompt combination coefficients are determined from the dimensions of the prompt codebook, the product is well defined, and the resulting combined prompt data includes N basic prompts of length L. The combined prompt data is therefore simple to construct from the prompt codebook and the prompt combination coefficients.
In steps S401 to S402 illustrated in this embodiment, the prompt combination coefficients are calculated and characterize the correlation between each basic prompt in the prompt codebook and the training data feature representation; the product of the prompt combination coefficients and the data matrix of the prompt codebook then gives the combined prompt data, which serves as a more targeted prompt for the current training data.
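The product in step S402 can be sketched as below. The shapes are assumptions made for illustration: the codebook is taken as an array of shape (N, L, D) and the prompt combination coefficients as an (N, N) matrix, so that the result again holds N prompts of length L; the embodiment does not state these shapes explicitly, and the function name is hypothetical.

```python
import numpy as np

def combine_prompts(coefficients, codebook):
    """Multiply a coefficient matrix (assumed shape (N, N)) by the codebook's
    data matrix (assumed shape (N, L, D)); each output prompt is a weighted
    sum of the N basic prompts."""
    num_prompts, prompt_len, embed_dim = codebook.shape
    flat = codebook.reshape(num_prompts, prompt_len * embed_dim)  # (N, L*D)
    combined = coefficients @ flat                                # (N, L*D)
    return combined.reshape(coefficients.shape[0], prompt_len, embed_dim)

rng = np.random.default_rng(0)
codebook = rng.normal(scale=0.02, size=(4, 3, 16))  # hypothetical N=4, L=3, D=16
coefficients = rng.normal(size=(4, 4))              # placeholder for predicted coefficients
combined = combine_prompts(coefficients, codebook)
```

Under these assumptions an identity coefficient matrix returns the codebook unchanged, which is a quick sanity check on the shapes.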
Referring to fig. 5, in some embodiments, the prompt combination network includes at least two correlation measurement layers and a coefficient prediction layer; step S401 may include, but is not limited to, steps S501 to S502:
step S501, carrying out correlation measurement processing on the prompt codebook and the training data feature representation through a correlation measurement layer to obtain correlation measurement data;
Step S502, carrying out coefficient prediction on the correlation measurement data through a coefficient prediction layer to obtain a prompt combination coefficient.
It should be noted that fig. 6 shows a schematic structural diagram of the prompt combination network, which includes at least two correlation measurement layers and a coefficient prediction layer. In this embodiment, the correlation measurement layer is also called a block layer and has the same structure as a block of the VIT model; 5 block layers are provided and connected in sequence to successively calculate the correlation between each basic prompt and the training data feature representation.
In step S501 of some embodiments, the correlation measurement layer measures the correlation between each basic prompt in the prompt codebook and the training data feature representation to obtain a correlation metric value for each basic prompt, and the correlation metric values of the plurality of basic prompts are then combined to form the correlation measurement data. For example, the training data feature representation is placed at the token 0 position and the prompt codebook is placed at the remaining token positions, where the training data feature representation includes all feature representations of the training data.
In step S502 of some embodiments, the correlation measurement data is input to the coefficient prediction layer, which is a fully connected layer (FC). As shown in fig. 6, the fully connected layer (FC) performs coefficient prediction on the correlation measurement data to obtain prompt combination coefficients that directly characterize the correlation between the training data and each basic prompt.
In steps S501 to S502 illustrated in the present embodiment, correlation measurement data is obtained by measuring the correlation between each basic prompt in the prompt codebook and the training data feature representation, and the prompt combination coefficients are constructed from the correlation measurement data, so that the prompt combination coefficients directly represent the correlation between each basic prompt in the prompt codebook and the training data.
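The token layout and coefficient prediction of steps S501 to S502 can be illustrated with the rough sketch below. Several simplifications are assumptions and not part of the embodiment: the block layer is replaced by a parameter-free self-attention with a residual connection (a real VIT block also has learned projections, normalization and an MLP), the feature representation is a single vector, the fully connected layer reads the token 0 output, and its output is reshaped into an (N, N) coefficient matrix. All names are hypothetical.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention_block(tokens):
    """Stand-in for one block layer: parameter-free self-attention plus residual."""
    scores = tokens @ tokens.T / np.sqrt(tokens.shape[-1])
    return tokens + softmax(scores) @ tokens

def predict_combination_coefficients(feature, codebook, fc_weight, fc_bias, num_blocks=5):
    num_prompts, prompt_len, dim = codebook.shape
    # Token 0 holds the feature representation; the remaining positions hold the codebook.
    tokens = np.concatenate([feature[None, :], codebook.reshape(-1, dim)], axis=0)
    for _ in range(num_blocks):                 # 5 block layers connected in sequence
        tokens = attention_block(tokens)
    logits = tokens[0] @ fc_weight + fc_bias    # fully connected layer on token 0 (assumption)
    return logits.reshape(num_prompts, num_prompts)

rng = np.random.default_rng(0)
N, L, D = 4, 3, 16                              # hypothetical sizes
feature = rng.normal(size=D)
codebook = rng.normal(scale=0.02, size=(N, L, D))
fc_weight = rng.normal(scale=0.02, size=(D, N * N))
fc_bias = np.zeros(N * N)
coefficients = predict_combination_coefficients(feature, codebook, fc_weight, fc_bias)
```

The sketch is meant only to show where the feature representation and the codebook sit in the token sequence and how a single fully connected layer can turn the result into a coefficient matrix.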
Referring to fig. 7, in some embodiments, step S105 may include, but is not limited to, steps S701 to S702:
step S701, carrying out weighted weight calculation on the combined prompt data and training data characteristic representation through a prompt weighted network to obtain prompt weighted weight;
step S702, the prompt weighted weight and the combined prompt data are constructed into weighted prompt information through the prompt weighted network.
In steps S701 to S702 of some embodiments, the prompt weighting network (Prompt Weighted Module, PWM) determines a prompt weighting weight for each basic prompt according to the combined prompt data, and then constructs the weighted prompt information of the current training data according to the prompt weighting weight and the combined prompt data. Thus the combined prompt data is not used directly as the prompt of the training data; instead, the weighted prompt information obtained after weighting is used, which gives the training data a more representative and distinctive prompt. It should be noted that the weighted prompt information only needs to prompt the subsequent classification operation and does not need to be stored, so that only the classification model and the prompt codebook need to be stored as data is added, and the prompt codebook does not grow as the data increases, thereby saving data storage space.
In steps S701 to S702 illustrated in this embodiment, a weight is further set for each prompt in the combined prompt data to obtain the prompt weighting weight, and the prompt weighting weight is then multiplied by the combined prompt data to obtain the weighted prompt information, so as to construct a more distinctive and representative prompt. Setting more distinctive weighted prompt information for the training data in this way allows the training data to be classified more accurately.
Referring to fig. 8, in some embodiments, the prompt weighting network includes a fully connected layer; step S701 includes, but is not limited to, steps S801 to S804:
step S801, mapping the combined prompt data to a preset space through a full connection layer to obtain combined prompt mapping data, and mapping the training data feature representation to the preset space through the full connection layer to obtain training data feature mapping data;
step S802, carrying out correlation measurement on the combined prompt mapping data and the training data feature mapping data to obtain information correlation data;
step S803, carrying out standardization processing on the information correlation data according to preset length information to obtain target correlation data; the length information is the length of a basic prompt in the prompt codebook;
Step S804, nonlinear mapping processing is carried out on the target correlation data, and prompt weighting is obtained.
In step S801 of some embodiments, fig. 9 shows the network structure of the prompt weighting network. As can be seen from fig. 9, the prompt weighting network is provided with two fully connected layers, which receive the training data feature representation and the combined prompt data respectively; the data output by the two fully connected layers are multiplied and then subjected to normalization, nonlinear mapping and a further multiplication. Thus, the training data feature representation is input to the first fully connected layer, which maps it to a preset space to obtain the training data feature mapping data. At the same time, the combined prompt data is input to the second fully connected layer, which maps it to the preset space to obtain the combined prompt mapping data. It should be noted that the training data feature representation and the combined prompt data are mapped to the same space so that the training data feature mapping data and the combined prompt mapping data have the same dimension, and the correlation between the training data and the combined prompt data can be calculated on that common dimension.
Specifically, the preset space has a fixed data dimension. Mapping the training data feature representation to the preset space yields the training data feature mapping data, and mapping the combined prompt data to the preset space yields the combined prompt mapping data, so that the training data feature mapping data and the combined prompt mapping data have the same data dimension.
In step S802 of some embodiments, after the combined prompt mapping data and the training data feature mapping data with consistent data dimensions are constructed, the correlation between them is calculated. It should be noted that, because the combined prompt mapping data and the training data feature mapping data are each represented by a data matrix, the correlation is calculated entry by entry to obtain the information correlation data as a data matrix. For example, the information correlation data is obtained as the product of the combined prompt mapping data and the training data feature mapping data.
In step S803 of some embodiments, after the information correlation data is determined, it is normalized: the target correlation data is obtained by dividing the information correlation data by the square root of the length L given by the length information. The normalized target correlation data represents the correlation between the training data and the combined prompt data more directly. That is, the target correlation data equals the information correlation data divided by the square root of L.
In step S804 of some embodiments, after normalization is completed and the target correlation data is obtained, nonlinear mapping is performed on the target correlation data according to a preset function to obtain the prompt weighting weight. Specifically, the preset function is the Sigmoid function, so the prompt weighting weight is the Sigmoid of the target correlation data. The prompt weighting weight determines the weight of each prompt in the combined prompt data and directly represents the importance of each prompt in the combined prompt data.
It should be noted that, after the prompt weighting weight is calculated, it is multiplied by the data matrix corresponding to the combined prompt data to obtain the weighted prompt information. Specifically, the product of the prompt weighting weight and the current combined prompt data is the weighted prompt information of the current training data.
In steps S801 to S804 illustrated in this embodiment, the training data feature representation and the combined prompt data, which have different data dimensions, are converted to the same data dimension, the correlation between them is calculated to obtain the information correlation data, and the information correlation data is normalized and then nonlinearly mapped to obtain the prompt weighting weight. A prompt weighting weight is thus set based on the correlation between each training data and the combined prompt data to characterize which prompts in the combined prompt data are more important. Weighted prompt information that represents the training data is then constructed from the prompt weighting weight and the combined prompt data, giving the training data a more distinctive prompt so that it is classified more accurately.
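Steps S801 to S804 and the final product can be sketched as follows. This is a minimal sketch under stated assumptions: the feature representation is a single vector of dimension D, the combined prompt data has shape (N, L, D), the two fully connected layers are plain weight matrices without bias, and one weight is produced per prompt token; none of these shapes or names are given in the embodiment.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def weight_prompts(feature, combined, w_feat, w_prompt):
    """feature: (D,), combined: (N, L, D), w_feat and w_prompt: (D, d)."""
    prompt_len = combined.shape[1]
    feat_mapped = feature @ w_feat              # S801: first fully connected layer, (d,)
    prompt_mapped = combined @ w_prompt         # S801: second fully connected layer, (N, L, d)
    correlation = prompt_mapped @ feat_mapped   # S802: information correlation data, (N, L)
    target = correlation / np.sqrt(prompt_len)  # S803: divide by sqrt(L)
    weights = sigmoid(target)                   # S804: Sigmoid gives the prompt weighting weight
    weighted = weights[..., None] * combined    # weighted prompt information, (N, L, D)
    return weights, weighted

rng = np.random.default_rng(0)
N, L, D, d = 4, 3, 16, 8                        # hypothetical sizes
feature = rng.normal(size=D)
combined = rng.normal(scale=0.02, size=(N, L, D))
w_feat = rng.normal(scale=0.02, size=(D, d))
w_prompt = rng.normal(scale=0.02, size=(D, d))
weights, weighted = weight_prompts(feature, combined, w_feat, w_prompt)
```

Because the weights come from a Sigmoid, each lies strictly between 0 and 1, so the weighted prompt information is a per-token rescaling of the combined prompt data and never needs to be stored between tasks.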
Referring to fig. 10, in some embodiments, step S106 may include, but is not limited to, steps S1001 to S1002:
step S1001, preset initialization prompt information is set at a head layer of a classification network, weighted prompt information is set at a middle layer of the classification network, and output dimensions of a tail layer in the classification network are set according to the total number of categories of the current incremental task;
step S1002, inputting the training data characteristic representation into a classification network with a structure adjusted for classification prediction, and obtaining a classification prediction result.
In step S1001 of some embodiments, the classification network in this embodiment is a frozen pre-trained VIT model, which serves as the backbone of the classification network. Note that the VIT model used in this embodiment is vit_base_patch16_224. The VIT model is split into a head layer, a middle layer and a tail layer, wherein the head layer comprises at least the first block layer, the middle layer comprises at least one intermediate block layer, and the tail layer is the final classification layer. As shown in fig. 11, randomly initialized prompt information is set at the head layer, the weighted prompt information output by the prompt weighting network is set at the middle layer, and the output dimension of the final classification layer is set to the total number of categories of the current incremental task. For example, suppose the head layer is the 1st and 2nd block layers, the middle layer is the 3rd to 5th block layers, and the current incremental task is 10-category image classification. Then the initialization prompt information is set at the 1st and 2nd block layers, the weighted prompt information is set at the 3rd to 5th block layers, and the output dimension of the final classification layer is set to 10, that is, a sequence of 10 probability values is output, each representing the probability of one image category.
In step S1002 of some embodiments, after modifying the classification network based on the weighted hints information, the classification model is a classification network that meets the current incremental task. Therefore, the classification network is used for carrying out classification prediction on the training data characteristic representation to obtain a classification prediction result, so that classification of category pertinence is realized, and classification accuracy is improved.
In steps S1001 to S1002 illustrated in the present embodiment, after modifying the classification network based on the weighted prompt information, a classification network conforming to the current incremental task is constructed, and a prediction conforming to the category can be made. Thus, making a class prediction for the training data feature representation based on the classification network after the structural adjustment is more accurate.
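The structural adjustment in step S1001 can be illustrated with the sketch below, using the layer assignment from the example (blocks 1 and 2 as head layer, blocks 3 to 5 as middle layer, 10 output categories). How a prompt is "set at" a block is not detailed in the embodiment; prepending the prompt tokens to the block's input sequence is an assumption here, as are the prompt length of 8 and all function names. The sketch does not run the VIT blocks; it only shows what each block would receive.

```python
import numpy as np

HEAD_BLOCKS = (1, 2)        # receive the initialization prompt information
MIDDLE_BLOCKS = (3, 4, 5)   # receive the weighted prompt information

def prompt_for_block(block_index, init_prompt, weighted_prompt):
    """Return the prompt assigned to a 1-based block index, or None."""
    if block_index in HEAD_BLOCKS:
        return init_prompt
    if block_index in MIDDLE_BLOCKS:
        return weighted_prompt
    return None

def prepend_prompt(tokens, prompt):
    """Prepend prompt tokens to the token sequence (assumed injection scheme)."""
    return tokens if prompt is None else np.concatenate([prompt, tokens], axis=0)

rng = np.random.default_rng(0)
embed_dim, num_classes = 768, 10                  # vit_base_patch16_224 width; 10 categories
tokens = rng.normal(size=(197, embed_dim))        # 196 patch tokens + 1 class token
init_prompt = rng.normal(scale=0.02, size=(8, embed_dim))
weighted_prompt = rng.normal(scale=0.02, size=(8, embed_dim))

# Sequence length each of the 12 blocks would see under this scheme.
block_input_lengths = [
    prepend_prompt(tokens, prompt_for_block(i, init_prompt, weighted_prompt)).shape[0]
    for i in range(1, 13)
]
head_weight = np.zeros((embed_dim, num_classes))  # tail layer sized to the category count
logits = tokens[0] @ head_weight
```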
Referring to fig. 12, in some embodiments, step S107 may include, but is not limited to, steps S1201 through S1202:
step S1201, cross entropy loss calculation is carried out on the classification verification result and the classification prediction result, so as to obtain classification loss data;
and step S1202, carrying out parameter adjustment on the original classification model according to the classification loss data to obtain a target classification model.
In steps S1201 to S1202 illustrated in the present embodiment, classification loss data is obtained by performing cross entropy loss calculation on the classification verification result and the classification prediction result, and the classification loss data characterizes the classification accuracy of the classification network. Although the classification network can predict the corresponding categories based on the weighted prompt information, this alone does not mean that the prediction accuracy is greatly improved, and the entire original classification model still needs to be trained on training data of the same set of categories. Therefore, the parameters of the original classification model are adjusted based on the classification loss data to obtain a target classification model that classifies more accurately.
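The cross entropy loss of step S1201 can be sketched as below, assuming the classification prediction result is available as unnormalized scores (logits) and the classification verification result as integer category labels; both representations are assumptions made for the sketch.

```python
import numpy as np

def cross_entropy(logits, labels):
    """Mean cross entropy between predicted logits (batch, classes)
    and integer ground-truth labels (batch,)."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()

labels = np.array([0, 3, 5, 9])
uniform_logits = np.zeros((4, 10))             # no preference among 10 categories
loss_uniform = cross_entropy(uniform_logits, labels)

confident_logits = np.zeros((4, 10))
confident_logits[np.arange(4), labels] = 20.0  # strongly favors the correct category
loss_confident = cross_entropy(confident_logits, labels)
```

For 10 categories, a prediction with no preference gives a loss of ln 10, and the loss approaches 0 as the correct category receives more probability; the parameter adjustment in step S1202 moves the model toward the latter.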
In some embodiments, after step S1202, the object classification method based on class-incremental learning further includes, but is not limited to: optimizing the prompt codebook according to the classification loss data.
It should be noted that, when the optimization of the target classification model is completed, the prompt codebook needs to be optimized synchronously so as to make more accurate prompts for the training data, so that the classification prediction is more accurate.
As shown in FIG. 11, an embodiment of the present application obtains the prompt codebook and the current training data, and performs feature mapping on the training data to obtain the training data feature representation. The training data feature representation and the prompt codebook are input to the correlation measurement layers for correlation measurement, with the training data feature representation placed at the token 0 position and the prompt codebook placed at the remaining token positions, to obtain the correlation measurement data. The fully connected layer (FC) performs coefficient prediction on the correlation measurement data to obtain the prompt combination coefficients, and the prompt combination coefficients are multiplied by the data matrix of the prompt codebook to obtain the combined prompt data. Next, the combined prompt data and the training data feature representation are mapped to the preset space to obtain the combined prompt mapping data and the training data feature mapping data. The correlation between the training data feature mapping data and the combined prompt mapping data is then measured to obtain the information correlation data, which is normalized by dividing it by the square root of the length L to obtain the target correlation data. The normalized target correlation data is then nonlinearly mapped with the Sigmoid function to obtain the prompt weighting weight of the training data. Finally, the matrix product of the prompt weighting weight and the combined prompt data gives the weighted prompt information of the training data. After the classification network is modified based on the weighted prompt information, it can predict the categories corresponding to the current incremental task, so the classification prediction result is obtained by classifying the training data feature representation through the classification network.
Then, cross entropy loss calculation is carried out between the classification prediction result and the classification verification result to obtain the classification loss data; the model parameters of the original classification model are adjusted according to the classification loss data to obtain the target classification model, and the prompt codebook is optimized based on the classification loss data, yielding a classification model that classifies accurately and a prompt codebook that prompts the classification model to classify the corresponding categories.
Referring to fig. 13, the embodiment of the application further provides an image classification method applied to image classification, wherein the target classification model is an image classification model obtained by the above object classification method based on class-incremental learning; the image classification method includes, but is not limited to, steps S1301 to S1302:
step S1301, acquiring target image data;
step S1302, inputting the target image data into the image classification model for image classification to obtain image classification information.
The image classification model obtained by the method can accurately make class prediction under class increment, and the size of the set prompt codebook is not increased along with the increase of data, so that the method still has better expandability under the condition of not specially fine-tuning new tasks.
Table 1 compares the accuracy of the class-incremental image classification method of the present application with that of conventional class-incremental image classification methods, which include LwF, L2P and DualPrompt. Image data of 100 categories is preset, each category comprising 500 pieces of training image data and 100 pieces of test image data; after the categories are shuffled, 10 categories at a time are taken as one incremental task, and Table 1 shows the accuracy of the different image classification methods over the 10 incremental tasks. The upper performance bound is the result of training on all categories of the CIFAR-100 dataset with the same pre-trained VIT (vit_base_patch16_224) as the present application:
Table 1. Comparison of incremental task performance (accuracy, %) on the CIFAR-100 dataset
As can be seen from Table 1, the image classification method of the application has a clear performance advantage over the conventional image classification methods; it improves on DualPrompt, currently the best-performing method, in 9 of the incremental tasks, and its performance in those 9 incremental tasks is closer to the upper performance bound.
For the construction of the image classification model, refer to the specific embodiment of the object classification method based on class-incremental learning, which will not be repeated here.
Referring to fig. 14, the embodiment of the present application further provides an object classification device based on class-incremental learning, which can implement the above object classification method based on class-incremental learning, where the device includes:
a data set obtaining module 1401, configured to obtain a training data set of a current incremental task; wherein the training data set comprises: prompting a codebook and training data;
the feature mapping module 1402 is configured to perform feature mapping processing on the training data to obtain a feature representation of the training data;
an input module 1403, configured to input the prompt codebook and the training data feature representation to a preset original classification model; wherein the original classification model comprises: a prompt combination network, a prompt weighting network, and a classification network;
A prompt predicting module 1404, configured to predict, through a prompt combination network, a prompt for the prompt codebook and the training data feature representation, to obtain combined prompt data;
a weighting processing module 1405, configured to perform weighting processing on the combined prompt data and the training data feature representation through a prompt weighting network, so as to obtain weighted prompt information;
the original classification module 1406 is configured to perform object classification on the weighted prompt information and the training data feature representation through a classification network to obtain a classification prediction result;
an optimizing module 1407, configured to optimize the original classification model according to a preset classification verification result and a classification prediction result, so as to obtain a target classification model;
the object classification module 1408 is configured to obtain target data, input the target data to the target classification model for object classification, and obtain a target classification result.
The specific implementation of the object classification device based on class-incremental learning is basically the same as the specific embodiment of the object classification method based on class-incremental learning, and is not described herein.
Referring to fig. 15, an embodiment of the present application further provides an image classification apparatus, which can implement the above image classification method and is applied to image classification, wherein the target classification model is an image classification model obtained by the above object classification method based on class-incremental learning; the image classification device includes:
A data acquisition module 1501 for acquiring target image data;
the image classification module 1502 is configured to input the target image data into the image classification model for image classification, so as to obtain image category information.
The specific implementation of the image classification device is basically the same as the specific embodiment of the image classification method, and will not be described herein.
The embodiment of the application also provides electronic equipment, which comprises a memory and a processor, wherein the memory stores a computer program, and the processor implements the above object classification method based on class-incremental learning when executing the computer program. The electronic equipment can be any intelligent terminal, including a tablet personal computer, a vehicle-mounted computer and the like.
Referring to fig. 16, fig. 16 illustrates a hardware structure of an electronic device according to another embodiment, the electronic device includes:
a processor 1601, which may be implemented by a general-purpose central processing unit (CPU), a microprocessor, an application-specific integrated circuit (ASIC), or one or more integrated circuits, and is used for executing related programs to implement the technical solutions provided by the embodiments of the present application;
a memory 1602, which may be implemented in the form of read-only memory (ROM), static storage, dynamic storage, or random access memory (RAM). The memory 1602 may store an operating system and other application programs; when the technical solutions provided by the embodiments of the present application are implemented by software or firmware, the relevant program code is stored in the memory 1602 and invoked by the processor 1601 to perform the object classification method based on class incremental learning or the image classification method of the embodiments of the present application;
An input/output interface 1603 for implementing information input and output;
a communication interface 1604, configured to implement communication between the device and other devices, either in a wired manner (e.g., USB, network cable) or in a wireless manner (e.g., mobile network, Wi-Fi, Bluetooth);
a bus 1605 for transferring information between various components of the device (e.g., processor 1601, memory 1602, input/output interface 1603, and communication interface 1604);
wherein the processor 1601, the memory 1602, the input/output interface 1603 and the communication interface 1604 enable communication connection with each other inside the device via a bus 1605.
The embodiment of the application also provides a computer-readable storage medium, wherein the computer-readable storage medium stores a computer program, and the computer program, when executed by a processor, implements the object classification method based on class incremental learning or the image classification method.
The memory, as a non-transitory computer readable storage medium, may be used to store non-transitory software programs as well as non-transitory computer executable programs. In addition, the memory may include high-speed random access memory, and may also include non-transitory memory, such as at least one magnetic disk storage device, flash memory device, or other non-transitory solid state storage device. In some embodiments, the memory optionally includes memory remotely located relative to the processor, the remote memory being connectable to the processor through a network. Examples of such networks include, but are not limited to, the internet, intranets, local area networks, mobile communication networks, and combinations thereof.
The embodiments of the application provide an object classification method, an image classification method, and related equipment based on class reinforcement learning. For each incremental task, a weighted combination of the basic prompts in the prompt codebook generates prompt information corresponding to the newly added data. This prompt information guides the classification model, so that the original classification model classifies the prompt information together with the training data feature representation in a targeted manner. The original classification model is then trained according to the classification prediction result and the classification verification result to obtain a target classification model, and the target data is classified by the target classification model to obtain a more accurate target classification result.
The embodiments described in the embodiments of the present application are for more clearly describing the technical solutions of the embodiments of the present application, and do not constitute a limitation on the technical solutions provided by the embodiments of the present application, and those skilled in the art can know that, with the evolution of technology and the appearance of new application scenarios, the technical solutions provided by the embodiments of the present application are equally applicable to similar technical problems.
It will be appreciated by persons skilled in the art that the embodiments of the application are not limited by the illustrations, and that more or fewer steps than those shown may be included, or certain steps may be combined, or different steps may be included.
The above described apparatus embodiments are merely illustrative, wherein the units illustrated as separate components may or may not be physically separate, i.e. may be located in one place, or may be distributed over a plurality of network elements. Some or all of the modules may be selected according to actual needs to achieve the purpose of the solution of this embodiment.
Those of ordinary skill in the art will appreciate that all or some of the steps of the methods, systems, functional modules/units in the devices disclosed above may be implemented as software, firmware, hardware, and suitable combinations thereof.
The terms "first," "second," "third," "fourth," and the like in the description of the application and in the above figures, if any, are used for distinguishing between similar objects and not necessarily for describing a particular sequential or chronological order. It is to be understood that the data so used may be interchanged where appropriate such that the embodiments of the application described herein may be implemented in sequences other than those illustrated or otherwise described herein. Furthermore, the terms "comprises," "comprising," and "having," and any variations thereof, are intended to cover a non-exclusive inclusion, such that a process, method, system, article, or apparatus that comprises a list of steps or elements is not necessarily limited to those steps or elements expressly listed but may include other steps or elements not expressly listed or inherent to such process, method, article, or apparatus.
It should be understood that in the present application, "at least one (item)" means one or more, and "a plurality" means two or more. "And/or" describes an association between associated objects and indicates that three relationships may exist; for example, "A and/or B" may represent: only A, only B, or both A and B, wherein A and B may be singular or plural. The character "/" generally indicates an "or" relationship between the objects before and after it. "At least one of the following" or a similar expression means any combination of these items, including any combination of single or plural items. For example, at least one of a, b, or c may represent: a, b, c, "a and b", "a and c", "b and c", or "a and b and c", wherein a, b, and c may each be single or plural.
In the several embodiments provided by the present application, it should be understood that the disclosed apparatus and method may be implemented in other manners. For example, the above-described apparatus embodiments are merely illustrative: the division of units is merely a logical function division, and another division manner may be used in actual implementation; a plurality of units or components may be combined or integrated into another system, or some features may be omitted or not performed. In addition, the mutual coupling, direct coupling, or communication connection shown or discussed may be an indirect coupling or communication connection via some interfaces, devices, or units, and may be in electrical, mechanical, or other form.
The units described above as separate components may or may not be physically separate, and components shown as units may or may not be physical units, may be located in one place, or may be distributed over a plurality of network units. Some or all of the units may be selected according to actual needs to achieve the purpose of the solution of this embodiment.
In addition, each functional unit in the embodiments of the present application may be integrated in one processing unit, or each unit may exist alone physically, or two or more units may be integrated in one unit. The integrated units may be implemented in hardware or in software functional units.
The integrated units, if implemented in the form of software functional units and sold or used as stand-alone products, may be stored in a computer-readable storage medium. Based on such understanding, the technical solution of the present application, in essence, or the part contributing to the prior art, or all or part of the technical solution, may be embodied in the form of a software product stored in a storage medium and including a plurality of instructions that cause a computer device (which may be a personal computer, a server, or a network device, etc.) to perform all or part of the steps of the methods of the various embodiments of the present application. The aforementioned storage medium includes: a USB flash drive, a removable hard disk, a read-only memory (ROM), a random access memory (RAM), a magnetic disk, an optical disk, or other various media capable of storing a program.
The preferred embodiments of the present application have been described above with reference to the accompanying drawings, and are not thereby limiting the scope of the claims of the embodiments of the present application. Any modifications, equivalent substitutions and improvements made by those skilled in the art without departing from the scope and spirit of the embodiments of the present application shall fall within the scope of the claims of the embodiments of the present application.

Claims (12)

1. An object classification method based on class reinforcement learning, the method comprising:
acquiring a training data set of a current incremental task; wherein the training data set comprises a prompt codebook and training data, the prompt codebook is a codebook storing at least two basic prompts, and the basic prompts are used for prompting the training data set of each incremental task, so that an original classification model can be applied to new incremental tasks and new data;
performing feature mapping processing on the training data to obtain training data feature representation;
inputting the prompt codebook and the training data feature representation into a preset original classification model; wherein the original classification model comprises: a prompt combination network, a prompt weighting network, and a classification network;
and carrying out prompt prediction on the prompt codebook and the training data characteristic representation through the prompt combination network to obtain combined prompt data, wherein the method specifically comprises the following steps of:
performing coefficient prediction on the prompt codebook and the training data feature representation through the prompt combination network to obtain a prompt combination coefficient; wherein the prompt combination coefficient characterizes the correlation between the current training data feature representation and the prompt codebook, and the coefficient prediction is a correlation calculation;
multiplying the prompt combination coefficient by the data matrix of the prompt codebook through the prompt combination network to obtain the combined prompt data; wherein the prompt combination coefficient and the prompt codebook are each represented as a data matrix, and the combined prompt data is the product of the prompt combination coefficient matrix and the prompt codebook matrix;
The combined prompt data and the training data feature representation are weighted through the prompt weighting network to obtain weighted prompt information, and the method specifically comprises the following steps:
performing weight calculation on the combined prompt data and the training data feature representation through the prompt weighting network to obtain prompt weighting weights;
constructing the weighted prompt information from the prompt weighting weights and the combined prompt data through the prompt weighting network;
Performing object classification on the weighted prompt information and the training data characteristic representation through the classification network to obtain a classification prediction result;
optimizing the original classification model according to a preset classification verification result and the classification prediction result to obtain a target classification model;
and obtaining target data, inputting the target data into the target classification model to perform object classification, and obtaining a target classification result.
2. The method of claim 1, wherein the prompt combination network comprises: at least two correlation measurement layers and a coefficient prediction layer; and the performing coefficient prediction on the prompt codebook and the training data feature representation through the prompt combination network to obtain a prompt combination coefficient comprises:
performing relevance measurement processing on the prompt codebook and the training data characteristic representation through the relevance measurement layer to obtain relevance measurement data;
and carrying out coefficient prediction on the correlation measurement data through the coefficient prediction layer to obtain the prompt combination coefficient.
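The coefficient prediction of claim 2 and the matrix product of claim 1 can be illustrated with a minimal sketch. It is not the patented implementation: the dot product used as the correlation measure, the softmax used as the coefficient prediction layer, the single correlation step (the claim recites at least two correlation measurement layers), and the names `combine_prompts`, `codebook` and `feature` are all assumptions made for illustration.

```python
import math

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def softmax(scores):
    # Subtract the maximum for numerical stability.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def combine_prompts(codebook, feature):
    """codebook: N basic prompts, each a length-D vector; feature: length-D vector."""
    # Correlation measurement (assumed: dot product with each basic prompt).
    correlations = [dot(prompt, feature) for prompt in codebook]
    # Coefficient prediction (assumed: softmax, so the coefficients sum to 1).
    coeffs = softmax(correlations)
    # Matrix product: (1 x N) coefficients times (N x D) codebook -> (1 x D) prompt.
    dim = len(codebook[0])
    combined = [sum(c * prompt[j] for c, prompt in zip(coeffs, codebook))
                for j in range(dim)]
    return coeffs, combined

# Hypothetical toy data: three basic prompts of dimension 2.
codebook = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
feature = [2.0, 0.0]
coeffs, combined = combine_prompts(codebook, feature)
```

In this toy case the first and third basic prompts correlate equally with the feature, so they receive equal coefficients and dominate the combined prompt.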
3. The method of claim 2, wherein the prompt weighting network comprises a fully connected layer; and the performing weight calculation on the combined prompt data and the training data feature representation through the prompt weighting network to obtain prompt weighting weights comprises:
mapping the combined prompt data to a preset space through the fully connected layer to obtain combined prompt mapping data, and mapping the training data feature representation to the preset space through the fully connected layer to obtain training data feature mapping data;
carrying out correlation measurement on the combined prompt mapping data and the training data feature mapping data to obtain information correlation data;
carrying out standardized processing on the information correlation data according to preset length information to obtain target correlation data; the length information is the length of a basic prompt in the prompt codebook;
and carrying out nonlinear mapping processing on the target correlation data to obtain the prompt weighting weights.
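The four steps of claim 3 (mapping, correlation measurement, length-based normalization, nonlinear mapping) can be sketched as follows. This is an illustrative reading only: the bias-free linear maps, the dot-product correlation, division by the square root of the prompt length, the sigmoid nonlinearity, the element-wise scaling used to build the weighted prompt, and every name in the code are assumptions that the claims do not specify.

```python
import math

def linear(weights, vector):
    """Fully connected layer without bias; weights is a list of rows."""
    return [sum(w * x for w, x in zip(row, vector)) for row in weights]

def prompt_weight(combined_prompt, feature, w_prompt, w_feature, prompt_length):
    # Map both inputs into the same preset space.
    mapped_prompt = linear(w_prompt, combined_prompt)
    mapped_feature = linear(w_feature, feature)
    # Correlation measurement (assumed: dot product).
    correlation = sum(p * f for p, f in zip(mapped_prompt, mapped_feature))
    # Normalization by the basic prompt length (assumed: divide by its square root).
    target = correlation / math.sqrt(prompt_length)
    # Nonlinear mapping (assumed: sigmoid, giving a weight in (0, 1)).
    return 1.0 / (1.0 + math.exp(-target))

# Hypothetical toy values; identity matrices stand in for learned weights.
identity = [[1.0, 0.0], [0.0, 1.0]]
combined_prompt = [0.5, 0.5]
weight = prompt_weight(combined_prompt, [2.0, 0.0], identity, identity, prompt_length=4)
# Weighted prompt information (assumed: scale the combined prompt by the weight).
weighted_prompt = [weight * p for p in combined_prompt]
```

With these inputs the correlation is 1.0, the normalized value is 0.5, and the resulting weight is the sigmoid of 0.5.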
4. A method according to any one of claims 1 to 3, wherein the performing object classification on the weighted prompt information and the training data feature representation through the classification network to obtain a classification prediction result comprises:
setting preset initialization prompt information at a head layer of the classification network, setting weighted prompt information at a middle layer of the classification network, and setting output dimensions of a tail layer in the classification network according to the total number of categories of the current incremental task;
And inputting the training data characteristic representation to the classification network with the structure adjusted to perform classification prediction, so as to obtain the classification prediction result.
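Setting the output dimension of the tail layer from the total number of categories can be pictured with the sketch below. The claim only states that the output dimension follows the total category count; keeping the rows learned for earlier classes and drawing new rows from a small Gaussian, as well as the names `expand_head` and `logits`, are assumptions made for illustration.

```python
import random

def expand_head(head, total_classes, feature_dim, std=0.02, seed=0):
    """Return a head with total_classes rows, keeping existing rows unchanged."""
    rng = random.Random(seed)
    new_rows = [[rng.gauss(0.0, std) for _ in range(feature_dim)]
                for _ in range(total_classes - len(head))]
    return [row[:] for row in head] + new_rows

def logits(head, feature):
    # One score per class: dot product of each row with the feature.
    return [sum(w * x for w, x in zip(row, feature)) for row in head]

# Hypothetical head covering two classes from earlier tasks.
old_head = [[1.0, 0.0], [0.0, 1.0]]
# The current incremental task brings the total number of categories to four.
head = expand_head(old_head, total_classes=4, feature_dim=2)
scores = logits(head, [0.3, 0.7])
```

The scores for the two earlier classes are unchanged by the expansion, which is one simple way to avoid disturbing what the model has already learned.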
5. A method according to any one of claims 1 to 3, wherein said optimizing said original classification model based on said classification verification result and said classification prediction result to obtain a target classification model comprises:
performing cross entropy loss calculation on the classification verification result and the classification prediction result to obtain classification loss data;
and carrying out parameter adjustment on the original classification model according to the classification loss data to obtain the target classification model.
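The cross entropy loss calculation of claim 5 can be written out in a few lines. The sketch assumes that the classification prediction result is a vector of raw scores (logits) and that the classification verification result is an integer class label; both representations, and the function names, are illustrative assumptions.

```python
import math

def cross_entropy(logits, label):
    """Cross entropy between softmax(logits) and an integer class label."""
    m = max(logits)
    # log-sum-exp computed stably by factoring out the maximum score.
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]

def batch_loss(batch_logits, labels):
    # Mean loss over the batch.
    losses = [cross_entropy(z, y) for z, y in zip(batch_logits, labels)]
    return sum(losses) / len(losses)

# Hypothetical predictions for two samples over three classes.
loss = batch_loss([[2.0, 0.0, 0.0], [0.0, 0.0, 0.0]], [0, 1])
```

The resulting classification loss data would then drive the parameter adjustment, for example through gradient descent in a deep learning framework.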
6. The method of claim 5, wherein after said cross entropy loss calculation of said classification verification result and said classification prediction result, said method further comprises:
and optimizing the prompt codebook according to the classification loss data.
7. A method according to any one of claims 1 to 3, wherein said obtaining a training dataset of a current incremental task comprises:
acquiring training data and a candidate codebook of the current incremental task; the candidate codebook is a codebook storing a plurality of basic prompts, and the basic prompts are used for prompting the training data set of each incremental task, so that the original classification model can be applied to new incremental tasks and new data;
Initializing the candidate codebook according to a preset standard deviation to obtain a prompt codebook; wherein the hint codebook includes at least two basic hints.
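Initializing the candidate codebook according to a preset standard deviation can be sketched as drawing every entry from a zero-mean Gaussian. The zero mean, the value 0.02, the three-level layout (number of prompts, prompt length, embedding dimension) and the name `init_codebook` are assumptions; the claim only fixes that a preset standard deviation is used and that at least two basic prompts result.

```python
import random

def init_codebook(num_prompts, prompt_length, dim, std=0.02, seed=0):
    """Build num_prompts basic prompts, each of shape prompt_length x dim."""
    if num_prompts < 2:
        raise ValueError("the prompt codebook holds at least two basic prompts")
    rng = random.Random(seed)
    return [[[rng.gauss(0.0, std) for _ in range(dim)]
             for _ in range(prompt_length)]
            for _ in range(num_prompts)]

# Hypothetical sizes chosen only for the example.
codebook = init_codebook(num_prompts=4, prompt_length=5, dim=8)
```

Fixing the seed makes the initialization reproducible across runs.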
8. An image classification method, characterized in that it is applied to image classification, a target classification model is an image classification model, and the target classification model is obtained from the object classification method based on class reinforcement learning according to any one of claims 1 to 7; the method comprises the following steps:
acquiring target image data;
and inputting the target image data into the image classification model to perform image classification, so as to obtain image category information.
9. An object classification device based on class reinforcement learning, the device comprising:
the data set acquisition module is used for acquiring a training data set of the current incremental task; wherein the training data set comprises a prompt codebook and training data, the prompt codebook is a codebook storing at least two basic prompts, and the basic prompts are used for prompting the training data set of each incremental task, so that an original classification model can be applied to new incremental tasks and new data;
the feature mapping module is used for carrying out feature mapping processing on the training data to obtain a training data feature representation;
the input module is used for inputting the prompt codebook and the training data feature representation into a preset original classification model; wherein the original classification model comprises: a prompt combination network, a prompt weighting network, and a classification network;
the prompt prediction module is used for carrying out prompt prediction on the prompt codebook and the training data characteristic representation through the prompt combination network to obtain combined prompt data, and specifically comprises the following steps:
performing coefficient prediction on the prompt codebook and the training data feature representation through the prompt combination network to obtain a prompt combination coefficient; wherein the prompt combination coefficient characterizes the correlation between the current training data feature representation and the prompt codebook, and the coefficient prediction is a correlation calculation;
multiplying the prompt combination coefficient by the data matrix of the prompt codebook through the prompt combination network to obtain the combined prompt data; wherein the prompt combination coefficient and the prompt codebook are each represented as a data matrix, and the combined prompt data is the product of the prompt combination coefficient matrix and the prompt codebook matrix;
The weighting processing module is used for carrying out weighting processing on the combined prompt data and the training data characteristic representation through the prompt weighting network to obtain weighted prompt information, and specifically comprises the following steps:
performing weight calculation on the combined prompt data and the training data feature representation through the prompt weighting network to obtain prompt weighting weights;
constructing the weighted prompt information from the prompt weighting weights and the combined prompt data through the prompt weighting network;
the original classification module is used for carrying out object classification on the weighted prompt information and the training data characteristic representation through the classification network to obtain a classification prediction result;
the optimizing module is used for optimizing the original classification model according to a preset classification verification result and the classification prediction result to obtain a target classification model;
the target classification module is used for acquiring target data, inputting the target data into the target classification model for object classification, and obtaining a target classification result.
10. An image classification apparatus characterized by being applied to image classification, a target classification model being an image classification model, and the target classification model being obtained by the class-reinforcement learning-based object classification method according to any one of claims 1 to 7; the device comprises:
The data acquisition module is used for acquiring target image data;
and the image classification module is used for inputting the target image data into the image classification model to perform image classification so as to obtain image category information.
11. An electronic device comprising a memory storing a computer program and a processor, wherein the processor, when executing the computer program, implements the object classification method based on class reinforcement learning according to any one of claims 1 to 7 or the image classification method according to claim 8.
12. A computer-readable storage medium storing a computer program, wherein the computer program, when executed by a processor, implements the object classification method based on class reinforcement learning according to any one of claims 1 to 7, or the image classification method according to claim 8.
CN202311070387.3A 2023-08-24 2023-08-24 Object classification method, image classification method and related equipment based on class reinforcement learning Active CN116778264B (en)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202311070387.3A CN116778264B (en) 2023-08-24 2023-08-24 Object classification method, image classification method and related equipment based on class reinforcement learning

Publications (2)

Publication Number Publication Date
CN116778264A CN116778264A (en) 2023-09-19
CN116778264B true CN116778264B (en) 2023-12-12

Family

ID=87986382

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202311070387.3A Active CN116778264B (en) 2023-08-24 2023-08-24 Object classification method, image classification method and related equipment based on class reinforcement learning

Country Status (1)

Country Link
CN (1) CN116778264B (en)

Citations (7)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN112559784A (en) * 2020-11-02 2021-03-26 浙江智慧视频安防创新中心有限公司 Image classification method and system based on incremental learning
CN114090780A (en) * 2022-01-20 2022-02-25 宏龙科技(杭州)有限公司 Prompt learning-based rapid picture classification method
CN114549894A (en) * 2022-01-20 2022-05-27 北京邮电大学 Small sample image increment classification method and device based on embedded enhancement and self-adaptation
CN116089873A (en) * 2023-02-10 2023-05-09 北京百度网讯科技有限公司 Model training method, data classification and classification method, device, equipment and medium
CN116127065A (en) * 2022-12-15 2023-05-16 四川启睿克科技有限公司 Simple and easy-to-use incremental learning text classification method and system
CN116129219A (en) * 2023-01-16 2023-05-16 西安电子科技大学 SAR target class increment recognition method based on knowledge robust-rebalancing network
CN116310582A (en) * 2023-03-29 2023-06-23 抖音视界有限公司 Classification model training method, image classification method, device, medium and equipment

Family Cites Families (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US11087184B2 (en) * 2018-09-25 2021-08-10 Nec Corporation Network reparameterization for new class categorization
GB2612866A (en) * 2021-11-09 2023-05-17 Samsung Electronics Co Ltd Method and apparatus for class incremental learning

Non-Patent Citations (1)

* Cited by examiner, † Cited by third party
Title
基于CNN和BIRCH聚类算法的类别增量学习;赵璐;何子况;朱秋煜;《电子测量技术》;第43卷(第11期);第79-84页 *

Also Published As

Publication number Publication date
CN116778264A (en) 2023-09-19

Legal Events

Date Code Title Description
PB01 Publication
SE01 Entry into force of request for substantive examination
GR01 Patent grant