CN116129197A - Fish classification method, system, equipment and medium based on reinforcement learning - Google Patents

Fish classification method, system, equipment and medium based on reinforcement learning

Publication number
CN116129197A
Authority
CN
China
Prior art keywords
pruning
network
fish
block
model
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202310347212.6A
Other languages
Chinese (zh)
Inventor
段明
胡少秋
张东旭
刘会东
鲍江辉
段瑞
陆波
姜昊
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Institute of Hydrobiology of CAS
Original Assignee
Institute of Hydrobiology of CAS
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Institute of Hydrobiology of CAS
Priority to CN202310347212.6A
Publication of CN116129197A
Legal status: Pending

Classifications

    • G PHYSICS
    • G06 COMPUTING; CALCULATING OR COUNTING
    • G06V IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00 Arrangements for image or video recognition or understanding
    • G06V10/70 Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/764 Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
    • G PHYSICS
    • G06 COMPUTING; CALCULATING OR COUNTING
    • G06N COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00 Computing arrangements based on biological models
    • G06N3/02 Neural networks
    • G06N3/08 Learning methods
    • G06N3/082 Learning methods modifying the architecture, e.g. adding, deleting or silencing nodes or connections
    • G PHYSICS
    • G06 COMPUTING; CALCULATING OR COUNTING
    • G06V IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00 Arrangements for image or video recognition or understanding
    • G06V10/70 Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/82 Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
    • Y GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
    • Y02 TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
    • Y02A TECHNOLOGIES FOR ADAPTATION TO CLIMATE CHANGE
    • Y02A40/00 Adaptation technologies in agriculture, forestry, livestock or agroalimentary production
    • Y02A40/80 Adaptation technologies in agriculture, forestry, livestock or agroalimentary production in fisheries management
    • Y02A40/81 Aquaculture, e.g. of fish

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Evolutionary Computation (AREA)
  • Physics & Mathematics (AREA)
  • General Physics & Mathematics (AREA)
  • Artificial Intelligence (AREA)
  • Computing Systems (AREA)
  • General Health & Medical Sciences (AREA)
  • Health & Medical Sciences (AREA)
  • Software Systems (AREA)
  • Multimedia (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Medical Informatics (AREA)
  • Databases & Information Systems (AREA)
  • Computational Linguistics (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Data Mining & Analysis (AREA)
  • Molecular Biology (AREA)
  • General Engineering & Computer Science (AREA)
  • Mathematical Physics (AREA)
  • Image Analysis (AREA)
  • Information Retrieval, Db Structures And Fs Structures Therefor (AREA)

Abstract

The invention discloses a fish classification method, system, equipment and medium based on reinforcement learning, and relates to the field of fish classification. A baseline network model is trained with a sample data set, the trained baseline network model is pruned to obtain a fish classification model, and the fish images to be classified are classified with the fish classification model to obtain the types of the fish, which improves classification accuracy and efficiency.

Description

Fish classification method, system, equipment and medium based on reinforcement learning
Technical Field
The invention relates to the field of fish classification, in particular to a fish classification method, system, equipment and medium based on reinforcement learning.
Background
Effective classification of fish data is an important means of studying aquatic ecosystems. In recent years, deep neural networks (DNNs) have been widely used and have achieved notable results in fish data classification tasks. However, because fish data are difficult to acquire, the sample classes are unbalanced, and DNNs have many parameters and a high computational cost, accurately classifying fish data with a traditional deep network model remains a great challenge. At present, a viable approach to this problem is to compress the network model without affecting its accuracy. Network pruning is a common model compression method and offers significant advantages in improving the efficiency of complex network models.
Network pruning removes redundant parameters and structures from a network to obtain a sparser network structure, and can be divided into unstructured pruning and structured pruning. Unstructured pruning achieves higher sparsity of the weight matrix by removing the unimportant weight values of each layer. For example, Song Han et al. proposed a threshold-based pruning method for removing redundant weights, in which a weight whose absolute value is less than the threshold is considered unimportant and is deleted. Implementing unstructured pruning requires the support of specific software and hardware and introduces additional computational cost. In contrast, structured pruning reduces network parameters and computational cost by removing redundant layers, convolution kernels and channels, and has a wider range of application scenarios.
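As an illustration of the threshold-based unstructured pruning mentioned above, the following is a minimal sketch in plain Python. It is not the implementation of Song Han et al.; the function name `threshold_prune`, the example weights and the threshold of 0.1 are assumptions chosen for illustration.

```python
def threshold_prune(weights, threshold):
    """Zero every weight whose absolute value is below `threshold`.

    `weights` is a list of rows, a plain-Python stand-in for a weight matrix.
    Returns the pruned matrix and its sparsity (fraction of zero entries).
    """
    pruned = [[w if abs(w) >= threshold else 0.0 for w in row] for row in weights]
    total = sum(len(row) for row in pruned)
    zeros = sum(1 for row in pruned for w in row if w == 0.0)
    return pruned, zeros / total


# Hypothetical 2x3 weight matrix and threshold.
layer = [[0.8, -0.05, 0.3], [0.01, -0.6, 0.02]]
pruned_layer, sparsity = threshold_prune(layer, threshold=0.1)
print(pruned_layer)  # [[0.8, 0.0, 0.3], [0.0, -0.6, 0.0]]
print(sparsity)      # 0.5
```

The pruned matrix keeps its original shape and only gains zeros, which is why unstructured pruning needs sparse-aware software or hardware before it yields a speed-up.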
Compared with inheriting the important weight parameters of the baseline network, the structure of the pruned network is the key factor that determines the performance of the pruned network model. Network pruning can be regarded as a network architecture search problem: every network that satisfies the search conditions is called a sub-network or candidate network, all sub-networks together form the network search space, and the goal of the search is to find the optimal sub-network in this space.
At present, some network pruning methods prune a network model according to manually specified pruning rates, but manually specified pruning rates lead to low pruning efficiency and easy convergence to a local optimum in actual pruning. In addition, most network pruning methods prune the network layer by layer and cannot fully take inter-layer dependencies into account: they search for a sparse structure of the network one layer at a time without making effective use of global information about the network structure, so this layer-wise strategy often produces suboptimal compression results. Furthermore, network pruning methods depend heavily on labels; most pruning methods rely on labeled data during pruning, which limits their application when data labels cannot be used during pruning. Finally, when network pruning is treated as a neural network architecture search, conventional architecture search methods face a large search space, which makes searching for the optimal sub-network structure difficult.
In summary, when current network pruning methods are used to prune a deep network model for classifying fish data, the resulting classification suffers from low accuracy and low efficiency.
Disclosure of Invention
The invention aims to provide a fish classification method, a system, equipment and a medium based on reinforcement learning so as to improve the accuracy of classifying fish.
In order to achieve the above object, the present invention provides the following solutions:
a fish classification method based on reinforcement learning, comprising:
acquiring an image of fish to be classified;
inputting the fish images to be classified into a fish classification model to obtain classification results; the classification result is the type of fish;
the fish classification model is obtained by training a base line network model by using a sample data set and pruning the trained base line network model; the sample data set comprises a training set, a verification set and a test set; the training set, the verification set and the test set all comprise a plurality of fish images and fish type labels corresponding to the fish images.
Optionally, the construction process of the fish classification model specifically comprises the following steps:
training the base line network model by using the training set to obtain a trained base line network model;
initializing the pruning network model by taking the trained base line network model as the pruning network model to obtain an initial pruning network model;
dividing the trained base line network model and the initial pruning network model into a plurality of base line block networks and a plurality of pruning block networks according to layers;
inputting the training set into the pruning block network and the baseline block network, and determining a metric score of each pruning block network;
determining the pruning rate of each pruning block network by using a reinforcement learning algorithm according to the measurement scores;
pruning each pruning block network according to the pruning rate to obtain a plurality of pruned pruning block networks;
and constructing a fish classification model according to the pruned pruning network based on the verification set and the test set.
Optionally, the constructing a fish classification model according to the pruned pruning network based on the verification set and the test set specifically includes:
inputting the verification set into the pruned pruning block network and the baseline block network respectively to obtain a first output result and a second output result;
calculating a first mean square error of the first output result and the second output result, and calculating a first pruning efficiency metric value of the pruned pruning block network;
performing a trade-off calculation on the first mean square error and the first pruning efficiency metric value to obtain a trade-off value;
selecting a preset number of pruned pruning block networks in descending order of the trade-off value, and constructing an initial fish classification model;
and adjusting parameters of the initial fish classification model by using the test set to obtain a fish classification model.
Optionally, training the base line network model by using the training set to obtain a trained base line network model, which specifically includes:
carrying out data enhancement on the fish images by adopting random scrambling, zero filling and random sampling technologies to obtain a processed training set;
and training the base line network model by using the processed training set to obtain a trained base line network model.
Optionally, the training set is input into the pruning block network and the baseline block network, and the measurement score of each pruning block network is determined, which specifically includes:
respectively inputting the training set into a first-stage pruning block network and a first-stage baseline block network to obtain pruning block network output results and baseline block network output results;
calculating the mean square error of the pruning block network output result and the baseline block network output result;
calculating the accuracy measurement value of the current baseline block network according to the mean square error;
calculating the pruning efficiency metric value of the current pruning block network using the formula
Figure SMS_1
wherein FLOPs(S_i) denotes the FLOPs of the i-th pruning block network and FLOPs(B_i) denotes the FLOPs of the i-th baseline block network;
determining the measurement score of the current pruning block network according to the accuracy measurement value and the pruning efficiency measurement value;
and inputting the baseline block network output result to a next-stage pruning block network and a next-stage baseline block network to obtain a pruning block network output result and a baseline block network output result, and returning to the step of calculating the mean square error of the pruning block network output result and the baseline block network output result to obtain the measurement score of each pruning block network.
Optionally, pruning each baseline block network according to the pruning rate to obtain a plurality of pruned baseline block networks, which specifically includes:
calculating the number of convolution kernels to be pruned of the current layer according to the pruning rate and the number of convolution kernels of each layer of the baseline block network;
calculating importance scores of convolution kernels of each layer of the baseline block network;
pruning is carried out on the convolution kernels of each layer in the baseline block network from small to large according to the importance score and the number of the convolution kernels to be deleted in the current layer, and a baseline block network after pruning is obtained.
Optionally, the calculating the number of convolution kernels to be pruned in the current layer according to the pruning rate and the number of convolution kernels in each layer of the baseline block network specifically includes:
calculating the number of convolution kernels to be pruned of the current layer by using a formula v=o×u; v is the number of convolution kernels to be pruned in the current layer; o is pruning rate of the current layer; u is the number of convolution kernels of the current layer;
when v<u, the number of convolution kernels to be pruned in the current layer is v;
when v=u, the number of convolution kernels to be pruned at the current layer is u-1.
A reinforcement learning based fish classification system comprising:
the data acquisition module is used for acquiring the images of the fishes to be classified;
the classification module is used for inputting the fish images to be classified into a fish classification model to obtain classification results; the classification result is the type of fish;
the fish classification model is obtained by training a base line network model by using a sample data set and pruning the trained base line network model; the sample data set comprises a training set, a verification set and a test set; the training set, the verification set and the test set all comprise a plurality of fish images and fish type labels corresponding to the fish images.
An electronic device, comprising: the fish classification system comprises a memory and a processor, wherein the memory is used for storing a computer program, and the processor runs the computer program to enable the electronic equipment to execute the fish classification method based on reinforcement learning.
A computer readable storage medium storing a computer program which when executed by a processor implements the reinforcement learning-based fish classification method described above.
According to the specific embodiment provided by the invention, the invention discloses the following technical effects:
according to the fish classification method based on reinforcement learning, the base line network model is trained by using the sample data set, then the trained base line network model is pruned to obtain the fish classification model, the fish image to be classified is classified by using the fish classification model, the types of the fishes are obtained, and the classification accuracy and efficiency are improved.
Drawings
In order to more clearly illustrate the embodiments of the present invention or the technical solutions of the prior art, the drawings that are needed in the embodiments will be briefly described below, it being obvious that the drawings in the following description are only some embodiments of the present invention, and that other drawings may be obtained according to these drawings without inventive effort for a person skilled in the art.
FIG. 1 is a flow chart of a fish classification method based on reinforcement learning provided by the invention;
FIG. 2 is a flow chart of a fish classification model construction process of the invention;
FIG. 3 is a block network supervision pruning method framework diagram based on reinforcement learning algorithm of the present invention;
FIG. 4 is a flowchart of a reinforcement learning algorithm according to the present invention;
FIG. 5 is a diagram of a network pruning algorithm framework of the present invention;
FIG. 6 is a graph of accuracy metric values for a ResNet-20 network of the present invention;
FIG. 7 is a comparison of the pruning of the ResNet-20 network of the present invention;
FIG. 8 is a comparison of the pruning of the ResNet-56 network of the present invention;
FIG. 9 is a block diagram of the fish classification system based on reinforcement learning.
Detailed Description
The following description of the embodiments of the present invention will be made clearly and completely with reference to the accompanying drawings, in which it is apparent that the embodiments described are only some embodiments of the present invention, but not all embodiments. All other embodiments, which can be made by those skilled in the art based on the embodiments of the invention without making any inventive effort, are intended to be within the scope of the invention.
The invention aims to provide a fish classification method, a system, equipment and a medium based on reinforcement learning so as to improve the accuracy of classifying fish.
In order that the above-recited objects, features and advantages of the present invention will become more readily apparent, a more particular description of the invention will be rendered by reference to the appended drawings and appended detailed description.
As shown in fig. 1, the fish classification method based on reinforcement learning of the present invention comprises:
Step 101: obtaining an image of the fish to be classified.
Step 102: inputting the fish images to be classified into a fish classification model to obtain classification results; the classification result is the type of fish.
The fish classification model is obtained by training a base line network model by using a sample data set and pruning the trained base line network model; the sample data set comprises a training set, a verification set and a test set; the training set, the verification set and the test set all comprise a plurality of fish images and fish type labels corresponding to the fish images.
Further, as shown in fig. 2, the construction process of the fish classification model specifically includes:
S1: training the baseline network model by using the training set to obtain a trained baseline network model.
Further, the S1 specifically includes:
and carrying out data enhancement on the fish images by adopting random scrambling, zero filling and random sampling technologies to obtain a processed training set.
And training the base line network model by using the processed training set to obtain a trained base line network model.
In practical application, the fish image data is first preprocessed, and the baseline network model is trained with the processed data in order to improve its convergence and generalization ability. First, for categories with fewer than 300 images, the data set is expanded with 5 data enhancement methods (horizontal flip, vertical flip, and rotation by 90°, 180° and 270°), and all images are uniformly scaled to 224×224. The expanded sample data set is then randomly divided into a training set, a verification set and a test set in the ratio 8:1:1. Finally, the baseline network model is trained with the processed training set.
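The five enhancement operations can be sketched as follows. This is only an illustrative sketch in plain Python on images stored as nested lists; the function names, the clockwise rotation direction and the `expand_minority_classes` helper are assumptions, and the scaling to 224×224 is omitted.

```python
def hflip(img):
    """Mirror the image left to right."""
    return [row[::-1] for row in img]


def vflip(img):
    """Mirror the image top to bottom."""
    return [row[:] for row in img[::-1]]


def rot90(img):
    """Rotate the image by 90 degrees clockwise (direction is an assumption)."""
    return [list(row) for row in zip(*img[::-1])]


def augment(img):
    """Return the 5 augmented copies: two flips and three rotations."""
    r90 = rot90(img)
    r180 = rot90(r90)
    r270 = rot90(r180)
    return [hflip(img), vflip(img), r90, r180, r270]


def expand_minority_classes(images_by_class, min_count=300):
    """Augment only the classes that have fewer than `min_count` images."""
    expanded = {}
    for label, images in images_by_class.items():
        extra = []
        if len(images) < min_count:
            for img in images:
                extra.extend(augment(img))
        expanded[label] = list(images) + extra
    return expanded


tiny = [[1, 2], [3, 4]]  # hypothetical 2x2 single-channel image
print(augment(tiny))
```

With this scheme every image of a rare class contributes 5 additional samples, so a class grows to 6 times its original size.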
S2: initializing the pruning network model by taking the trained baseline network model as the pruning network model, to obtain an initial pruning network model. In practical application, because the pruning network model has a large search space, the pruning network model is randomly initialized within a range of floating point operations (FLOPs) compression rates in order to reduce the size of the space searched for the optimal network.
S3: dividing the trained baseline network model and the initial pruning network model into a plurality of baseline block networks and a plurality of pruning block networks by layer. In practical application, the two models are divided into block networks at the same layers. As shown in FIG. 3, in order to improve the efficiency of network pruning, the trained baseline network model B and the initial pruning network model S are divided into block networks by layer, borrowing the idea of knowledge distillation, and each pruning block network learns the knowledge of the corresponding baseline block network: the input of both the i-th pruning block network S_i and the i-th baseline block network B_i is the output of the (i-1)-th baseline block network B_{i-1}.
S4: the training set is input into the pruning block network and the baseline block network, and a metric score of each pruning block network is determined.
Further, the step S4 specifically includes:
and respectively inputting the training set into the first-stage pruning block network and the first-stage baseline block network to obtain pruning block network output results and baseline block network output results.
And calculating the mean square error of the pruning block network output result and the baseline block network output result.
In practical application, the MSE error between the pruning block network and the baseline block network is calculated using the formula
Figure SMS_3
where f(X_i, W) and g(X_i, W') denote the outputs of the i-th baseline block network B_i and the i-th pruning block network S_i, respectively.
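For illustration, a mean square error between two block outputs might be computed as in the sketch below. The exact normalization of the formula above is not reproduced in this text, so averaging over all elements of a flattened feature map is an assumption, as are the function name `block_mse` and the example values.

```python
def block_mse(baseline_out, pruned_out):
    """Mean of squared element-wise differences between two flattened outputs."""
    if len(baseline_out) != len(pruned_out):
        raise ValueError("block outputs must have the same number of elements")
    squared = [(b - p) ** 2 for b, p in zip(baseline_out, pruned_out)]
    return sum(squared) / len(squared)


# Hypothetical flattened outputs of a baseline block and its pruned counterpart.
f_out = [1.0, 2.0, 3.0, 4.0]
g_out = [1.0, 2.5, 2.5, 4.0]
print(block_mse(f_out, g_out))  # 0.125
```

Because the error is computed between the two block outputs only, this step needs no class labels.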
Calculating the accuracy metric value of the current baseline block network according to the mean square error. In practical applications, an accuracy metric defined from the MSE loss is used to evaluate different network structures, as shown in the following formula:
Figure SMS_4
The pruning efficiency metric value of the current pruning block network is calculated using the formula
Figure SMS_5
where FLOPs(S_i) denotes the FLOPs of the i-th pruning block network and FLOPs(B_i) denotes the FLOPs of the i-th baseline block network. In practice, to further distinguish block networks with similar performance but different computational efficiency, the invention uses the FLOPs compression rate of a pruning block network to define the efficiency metric of the model.
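The sketch below shows one way a FLOPs compression rate could be computed for a block of convolution layers. It is not the patent's formula, which is not reproduced in this text: the definition 1 - FLOPs(S_i)/FLOPs(B_i), the multiply-accumulate counting convention and all layer sizes are assumptions made for illustration.

```python
def conv_flops(in_channels, out_channels, kernel_size, out_h, out_w):
    """Multiply-accumulate count of a standard 2-D convolution (bias ignored)."""
    return kernel_size * kernel_size * in_channels * out_channels * out_h * out_w


def flops_compression_rate(pruned_flops, baseline_flops):
    """Assumed definition: fraction of the baseline FLOPs that pruning removes."""
    return 1.0 - pruned_flops / baseline_flops


# Hypothetical block with two 3x3 convolutions on 32x32 feature maps.
baseline_block = conv_flops(16, 16, 3, 32, 32) + conv_flops(16, 16, 3, 32, 32)
# Halve the kernels of the first layer; the second layer then sees 8 input
# channels, while its 16 output channels stay unchanged.
pruned_block = conv_flops(16, 8, 3, 32, 32) + conv_flops(8, 16, 3, 32, 32)

print(baseline_block)                                        # 4718592
print(pruned_block)                                          # 2359296
print(flops_compression_rate(pruned_block, baseline_block))  # 0.5
```

In the example, removing kernels from one layer also reduces the cost of the following layer, because that layer receives fewer input channels.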
Determining the metric score of the current pruning block network according to the accuracy metric value and the pruning efficiency metric value. In practical application, the model performance (accuracy metric value) and the model efficiency metric value are combined into a score that reflects the quality of the pruning network model, as shown in the following formula:
Figure SMS_6
where α is a weight that balances network model performance against efficiency; a higher α preferentially removes more FLOPs. For each block network in the pruning network, the goal is to find the block structure with the highest metric score R.
And inputting the baseline block network output result to a next-stage pruning block network and a next-stage baseline block network to obtain a pruning block network output result and a baseline block network output result, and returning to the step of calculating the mean square error of the pruning block network output result and the baseline block network output result to obtain the measurement score of each pruning block network.
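The stage-by-stage procedure above can be sketched as a simple loop. This is a minimal sketch only: blocks are modeled as plain Python callables on flat lists, and the function name `blockwise_mse` and the toy blocks are hypothetical stand-ins for real network blocks.

```python
def blockwise_mse(baseline_blocks, pruned_blocks, x):
    """Return one MSE per block pair.

    Both blocks of stage i receive the same input, and the input of the next
    stage is always the output of the baseline block, not of the pruned block.
    """
    errors = []
    for base_block, pruned_block in zip(baseline_blocks, pruned_blocks):
        base_out = base_block(x)
        pruned_out = pruned_block(x)
        squared = [(b - p) ** 2 for b, p in zip(base_out, pruned_out)]
        errors.append(sum(squared) / len(squared))
        x = base_out  # feed the baseline output forward
    return errors


# Toy stand-ins for two stages of each network.
baseline = [lambda v: [2 * t for t in v], lambda v: [t + 1 for t in v]]
pruned = [lambda v: [2 * t for t in v], lambda v: [t + 1.5 for t in v]]
print(blockwise_mse(baseline, pruned, [1.0, 2.0]))  # [0.0, 0.25]
```

Feeding the baseline output forward means that an error made by one pruned block does not propagate into the evaluation of the next one, so each block can be scored independently.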
S5: determining the pruning rate of each pruning block network by using a reinforcement learning algorithm according to the metric scores.
In practical applications, a reinforcement learning (RL) algorithm is used to search for the optimal network structure of each pruning block network. As shown in FIG. 4, a reinforcement learning algorithm is a reward-driven optimization algorithm: in essence, the problem to be solved is formulated as a Markov decision process, and the policy is adjusted through iterative learning so as to find the optimal solution at each time step. In the invention, the pruning process of the pruning network is formulated as a Markov decision process in which the characterization information of the pruning network model is the state, the pruning rate of each layer is the action, and the model efficiency and performance form the reward, and a better pruning rate is searched for each layer in each block network.
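To make the state, action and reward roles concrete, the following is a toy environment sketch. It is not the invention's implementation: the class name `BlockPruningEnv`, the state encoding (layer index and kernel count), the reward being paid only at the end of an episode, and the stand-in score function are all assumptions.

```python
class BlockPruningEnv:
    """Toy episode: the agent picks one pruning rate per layer of a block."""

    def __init__(self, kernels_per_layer, score_fn):
        self.kernels_per_layer = kernels_per_layer
        self.score_fn = score_fn  # maps the chosen rates to a scalar score
        self.reset()

    def reset(self):
        self.layer = 0
        self.rates = []
        return self._state()

    def _state(self):
        # Hypothetical state encoding: (layer index, kernel count of that layer).
        if self.layer < len(self.kernels_per_layer):
            return (self.layer, self.kernels_per_layer[self.layer])
        return (self.layer, 0)

    def step(self, rate):
        # Action: pruning rate for the current layer, clipped to [0, 1].
        self.rates.append(min(max(rate, 0.0), 1.0))
        self.layer += 1
        done = self.layer == len(self.kernels_per_layer)
        # Reward: score of the whole block, paid once all layers are decided.
        reward = self.score_fn(self.rates) if done else 0.0
        return self._state(), reward, done


# Stand-in score: mean pruning rate (a real score would mix accuracy and FLOPs).
env = BlockPruningEnv([16, 16, 32], lambda rates: sum(rates) / len(rates))
state = env.reset()
for action in (0.5, 0.25, 1.5):
    state, reward, done = env.step(action)
print(state, done)  # (3, 0) True
```

An agent would interact with such an environment over many episodes and adjust its policy toward the pruning rates that produce the highest reward.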
S6: pruning is carried out on each pruning block network according to the pruning rate, and a plurality of pruning block networks after pruning are obtained.
S7: and constructing a fish classification model according to the pruned pruning network based on the verification set and the test set.
In practical application, pruning is carried out on the pruning block network by using the pruning rate obtained in the step S5, then the pruning block network after pruning is evaluated, and finally the network model with the highest network performance is selected as the network model which is finally searched.
In practical application, the step S7 specifically includes:
and respectively inputting the verification set into the pruned pruning block network and the baseline block network to obtain a first output result and a second output result.
And calculating a first mean square error of the first output result and the second output result, and calculating a first pruning efficiency metric value of the pruned pruning block network.
Performing a trade-off calculation on the first mean square error and the first pruning efficiency metric value to obtain a trade-off value.
Selecting a preset number of pruned pruning block networks in descending order of the trade-off value, and constructing an initial fish classification model.
And adjusting parameters of the initial fish classification model by using the test set to obtain a fish classification model.
Network pruning can be divided into network layer pruning and intra-layer convolution kernel pruning, and the technology only prunes intra-layer convolution kernels. The inventive technique uses the weight L1 norm to prune the network model, and pruning the convolution kernels of each layer is shown in fig. 5.
The specific flow is as follows:
(1) Rank the convolution kernels by importance. In each layer, the importance scores of the convolution kernels or neurons are calculated, and the convolution kernels or neurons are sorted by importance score in ascending order.
(2) Calculate the number of convolution kernels to be deleted in the current layer. If the pruning rate given for the layer is o and the number of convolution kernels is u, the number of convolution kernels to be deleted is v=o×u; if v is not an integer, it is rounded down and only the integer part is kept.
(3) Delete the unimportant convolution kernels of the current layer. If v<u, the first v convolution kernels are deleted directly. If v=u, u-1 convolution kernels are deleted, i.e. at least one convolution kernel is retained: to preserve connectivity between the preceding and following layers, the convolution kernel with the highest importance score in that layer is kept.
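The three steps above can be sketched in plain Python as follows, using the L1 norm of each kernel's weights as the importance score. The function names and the example kernels are illustrative assumptions; kernels are represented as flat lists of weights.

```python
import math


def kernels_to_prune(rate, num_kernels):
    """v = floor(o * u), capped at u - 1 so that one kernel always survives."""
    v = math.floor(rate * num_kernels)
    return v if v < num_kernels else num_kernels - 1


def prune_layer(kernels, rate):
    """Return the indices of the kernels that are kept, in original order."""
    # (1) L1-norm importance score per kernel, ranked in ascending order.
    scores = [sum(abs(w) for w in kernel) for kernel in kernels]
    order = sorted(range(len(kernels)), key=lambda i: scores[i])
    # (2) Number of kernels to delete in this layer.
    v = kernels_to_prune(rate, len(kernels))
    # (3) Delete the v least important kernels.
    removed = set(order[:v])
    return [i for i in range(len(kernels)) if i not in removed]


# Hypothetical layer with 4 flattened kernels (L1 norms 1.0, 0.1, 3.0, about 0.3).
layer = [[0.5, -0.5], [0.1, 0.0], [-1.0, 2.0], [0.2, -0.1]]
print(prune_layer(layer, 0.5))  # [0, 2]
print(prune_layer(layer, 1.0))  # [2]
```

With a pruning rate of 1.0 the cap takes effect and only the kernel with the largest L1 norm is kept.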
In order to verify the compression performance of the invention on a fish classification model, the public Fish4Knowledge dataset is selected for experimental verification on a ResNet-20 network model. The test platform is Ubuntu 18.06, the CPU is AMD 3090X, the GPU is Titan RTX, and the video memory is 24GB.
The Fish4Knowledge dataset consists of fish image data collected at underwater observation stations at Nanwan Strait, Lanyu Island and Houbihu from October 1, 2010 to September 30, 2013. The dataset contains 27,370 images of 23 fish species. The number of images differs widely between categories: the single most common species accounts for approximately 44% of the images, and the 15 most common species account for 97%. Because data imbalance in the training set easily biases the training result, the data are augmented: for categories with fewer than 300 images, 5 data enhancement methods (horizontal flip, vertical flip, and rotation by 90°, 180° and 270°) are used to expand the data set, and all images are uniformly scaled to 224×224 pixels for the subsequent experiments. The data set is randomly shuffled and then divided into a training set, a verification set and a test set in the ratio 8:1:1. This finally gives 29,575 training images, 3,625 test images and 3,625 verification images.
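A random shuffle followed by an 8:1:1 split could look like the sketch below. It is an illustrative sketch only: the function name `split_8_1_1`, the fixed seed and the rule that leftover samples go to the test set are assumptions, and it is not meant to reproduce the exact image counts reported above.

```python
import random


def split_8_1_1(samples, seed=0):
    """Shuffle the samples and split them into train, verification and test sets."""
    items = list(samples)
    random.Random(seed).shuffle(items)  # seeded for reproducibility
    n_train = len(items) * 8 // 10
    n_val = len(items) // 10
    train = items[:n_train]
    val = items[n_train:n_train + n_val]
    test = items[n_train + n_val:]  # leftover samples go to the test set
    return train, val, test


train, val, test = split_8_1_1(range(100))
print(len(train), len(val), len(test))  # 80 10 10
```

Integer arithmetic is used for the split sizes so that rounding never drops or duplicates a sample.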
In practical application, a ResNet network model is selected as the baseline network model. A ResNet network model mainly consists of residual blocks and residual connections, where one residual block contains several convolution layers. For a residual block, the input and output feature maps must have the same size unless there is a shortcut in the block. In order to keep the output channels of each block unchanged, the invention compresses only the convolution layers of each block other than the last one. The parameters for training the baseline network model are set as follows: the number of epochs is 10; the batch size is 32; the learning rate is initialized to 0.001; the optimizer is Adam with a momentum of 0.9 and a weight decay of 5×10^-4.
To verify the formula
Figure SMS_8
pruning training was performed with the ResNet-20 and ResNet-56 networks on the Fish4Knowledge dataset. The ResNet-20 network is divided into three block networks, Block1, Block2 and Block3, and a compression experiment is carried out on each block network. As shown in FIG. 6, R_a gradually decreases as the FLOPs compression rate of each block network increases.
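To make the FLOPs compression rate of a block network concrete, the sketch below counts multiply-accumulate operations for a block before and after pruning. The counting convention (bias and non-convolution layers ignored) and the function names are assumptions made for illustration; the text does not state how FLOPs are counted.

```python
def conv_flops(c_in, c_out, kernel, h_out, w_out):
    """Multiply-accumulate count of one convolution layer (bias ignored)."""
    return c_in * c_out * kernel * kernel * h_out * w_out


def block_flops(layers):
    """Total count for a block given as a list of conv_flops argument tuples."""
    return sum(conv_flops(*layer) for layer in layers)


def flops_compression_rate(baseline_layers, pruned_layers):
    """Fraction of the baseline block's FLOPs removed by pruning."""
    return 1.0 - block_flops(pruned_layers) / block_flops(baseline_layers)


# A block with two 3x3 convolutions, 16 channels, 32x32 output feature maps.
baseline = [(16, 16, 3, 32, 32), (16, 16, 3, 32, 32)]
# Pruning half the kernels of the first layer also halves the input channels
# of the second layer; the last layer keeps its 16 output channels.
pruned = [(16, 8, 3, 32, 32), (8, 16, 3, 32, 32)]
rate = flops_compression_rate(baseline, pruned)
```

Because the last layer of the block is left untouched, the block's output shape is preserved while both layers become cheaper.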
The reinforcement learning algorithm used in the invention is the deep deterministic policy gradient (DDPG) algorithm. DDPG comprises an Actor network and a Critic network, each with 2 hidden layers of 300 neurons. The replay buffer size is set to 600 and the batch size to 32. The learning rate of the Actor network is set to 0.001 and that of the Critic network to 0.002. The soft-update hyperparameter of the target network is τ=0.01, and the number of episodes is set to 600.
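Two of the settings above, the fixed-size buffer and the soft update of the target network, can be illustrated with a short sketch. It assumes network parameters are flat lists of floats and uses hypothetical names (`ReplayBuffer`, `soft_update`); the Actor and Critic networks themselves are not modelled here.

```python
import random
from collections import deque

TAU = 0.01         # soft-update coefficient given in the text
BUFFER_SIZE = 600  # buffer capacity given in the text
BATCH_SIZE = 32    # mini-batch size given in the text


class ReplayBuffer:
    """Fixed-capacity store of transitions; the oldest entries are dropped."""

    def __init__(self, capacity=BUFFER_SIZE):
        self._items = deque(maxlen=capacity)

    def add(self, state, action, reward, next_state):
        self._items.append((state, action, reward, next_state))

    def sample(self, batch_size=BATCH_SIZE, rng=random):
        # Draw without replacement; return everything if the buffer is small.
        items = list(self._items)
        return rng.sample(items, min(batch_size, len(items)))

    def __len__(self):
        return len(self._items)


def soft_update(target_params, online_params, tau=TAU):
    """Return tau * online + (1 - tau) * target, element by element."""
    return [tau * w + (1.0 - tau) * t
            for t, w in zip(target_params, online_params)]
```

With τ=0.01 the target network moves only 1% of the way toward the online network at each update, which keeps the learning targets stable.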
The ResNet-20 network trained on the Fish4Knowledge dataset reaches an accuracy of 98.12%. After pruning, 32.53% of its FLOPs are removed and the accuracy improves by 0.52%. The compression results are shown in FIG. 7, which shows the number of convolution kernels in each layer of ResNet-20 before and after pruning. The experimental results show that the method can find and effectively compress redundant structural parameters of the network model.
To further verify the effectiveness of the method on a more complex network model, a ResNet-56 network was trained on the Fish4Knowledge dataset, reaching a test accuracy of 98.12%. FIG. 8 shows the pruning results for each layer of ResNet-56: 48.43% of the FLOPs are pruned, and the post-pruning accuracy is 99.22%, an improvement of 1.1%. The experimental results show that the method also compresses complex network models effectively.
The invention combines a reinforcement learning algorithm with knowledge distillation to provide a supervised block-network pruning algorithm based on reinforcement learning. The invention has the following advantages:
(1) The invention uses the reinforcement learning algorithm to learn the pruning rate of each layer of the network model, and can dynamically adjust each layer's pruning rate according to the efficiency and performance of the network.
(2) During pruning, the network is not pruned layer by layer; instead, the pruning rates of all layers of the network model are learned.
(3) Drawing on knowledge distillation, the invention supervises the pruning network by minimizing the difference between the output features of the pruning network and those of the baseline network, so no data label information is needed during pruning.
(4) Drawing on the Markov chain Monte Carlo method, the invention divides the baseline network and the pruning network by layers into the same block networks, so that all block networks can be pruned at the same time. This reduces the search space of the network model, compresses the network structure effectively and improves the pruning efficiency of the network model.
Example II
To carry out the method of the above embodiment and achieve the corresponding functions and technical effects, a fish classification system based on reinforcement learning is provided below. As shown in FIG. 9, the system includes:
The data acquisition module 901 is configured to acquire an image of the fish to be classified.
The classification module 902 is configured to input the image of the fish to be classified into a fish classification model to obtain a classification result; the classification result is the type of fish.
The fish classification model is obtained by training a baseline network model with a sample data set and pruning the trained baseline network model; the sample data set comprises a training set, a verification set and a test set; the training set, the verification set and the test set all comprise a plurality of fish images and the fish type labels corresponding to the fish images.
Example III
The invention also provides an electronic device comprising a memory and a processor. The memory stores a computer program, and the processor runs the computer program to cause the electronic device to perform the reinforcement learning-based fish classification method of the first embodiment. The electronic device is a fish classification device.
Example IV
The present invention also provides a computer-readable storage medium storing a computer program which, when executed by a processor, implements the reinforcement learning-based fish classification method of the first embodiment.
In the present specification, the embodiments are described in a progressive manner. Each embodiment focuses on its differences from the other embodiments, and identical or similar parts of the embodiments may be cross-referenced. Because the system disclosed in the embodiments corresponds to the method disclosed in the embodiments, its description is relatively brief; for the relevant points, refer to the description of the method.
The principles and embodiments of the present invention have been described herein with reference to specific examples, which are intended only to assist in understanding the method of the present invention and its core ideas. Those of ordinary skill in the art may also, in light of the ideas of the present invention, make changes to the specific embodiments and the scope of application. In view of the foregoing, this description should not be construed as limiting the invention.

Claims (10)

1. A fish classification method based on reinforcement learning, comprising:
acquiring an image of fish to be classified;
inputting the fish images to be classified into a fish classification model to obtain classification results; the classification result is the type of fish;
the fish classification model is obtained by training a baseline network model with a sample data set and pruning the trained baseline network model; the sample data set comprises a training set, a verification set and a test set; the training set, the verification set and the test set all comprise a plurality of fish images and the fish type labels corresponding to the fish images.
2. The reinforcement learning-based fish classification method of claim 1, wherein the process of constructing the fish classification model specifically comprises:
training the baseline network model by using the training set to obtain a trained baseline network model;
initializing a pruning network model with the trained baseline network model to obtain an initial pruning network model;
dividing the trained baseline network model and the initial pruning network model by layers into a plurality of baseline block networks and a plurality of pruning block networks, respectively;
inputting the training set into the pruning block network and the baseline block network, and determining a metric score of each pruning block network;
determining the pruning rate of each pruning block network by using a reinforcement learning algorithm according to the measurement scores;
pruning is carried out on each pruning block network according to the pruning rate, so that a plurality of pruned pruning networks are obtained;
and constructing a fish classification model according to the pruned pruning network based on the verification set and the test set.
3. The reinforcement learning-based fish classification method according to claim 2, wherein the constructing a fish classification model from the pruned pruning network based on the verification set and the test set specifically comprises:
inputting the verification set into the pruned pruning block network and the baseline block network respectively to obtain a first output result and a second output result;
calculating a first mean square error of the first output result and the second output result, and calculating a first pruning efficiency metric value of the pruned pruning block network;
performing a trade-off calculation on the first mean square error and the first pruning efficiency metric value to obtain a trade-off value;
selecting a preset number of pruned pruning block networks in descending order of the trade-off value, and constructing an initial fish classification model;
and adjusting parameters of the initial fish classification model by using the test set to obtain a fish classification model.
4. The reinforcement learning-based fish classification method of claim 2, wherein the training the baseline network model by using the training set to obtain a trained baseline network model specifically comprises:
performing data enhancement on the fish images by random shuffling, zero padding and random sampling to obtain a processed training set;
and training the baseline network model by using the processed training set to obtain a trained baseline network model.
5. The reinforcement learning based fish classification method of claim 2, wherein said inputting said training set into said pruning block network and said baseline block network, determining a metric score for each of said pruning block networks, comprises in particular:
respectively inputting the training set into a first-stage pruning block network and a first-stage baseline block network to obtain pruning block network output results and baseline block network output results;
calculating the mean square error of the pruning block network output result and the baseline block network output result;
calculating the accuracy measurement value of the current baseline block network according to the mean square error;
calculating a pruning efficiency metric value of the current pruning block network by using the formula
Figure QLYQS_1
wherein FLOPs(S_i) represents the FLOPs of the i-th pruning block network, and FLOPs(B_i) represents the FLOPs of the i-th baseline block network;
determining the measurement score of the current pruning block network according to the accuracy measurement value and the pruning efficiency measurement value;
and inputting the baseline block network output result to a next-stage pruning block network and a next-stage baseline block network to obtain a pruning block network output result and a baseline block network output result, and returning to the step of calculating the mean square error of the pruning block network output result and the baseline block network output result to obtain the measurement score of each pruning block network.
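The stage-by-stage evaluation in this claim can be sketched as a loop in which both blocks of a stage receive the baseline output of the previous stage. The sketch assumes each block is a callable on a flat list of features and uses hypothetical names; it stops at the per-block mean square error, because the mapping from that error to the accuracy metric and the pruning efficiency formula appear only as figure placeholders in the text and are not reproduced here.

```python
def mse(a, b):
    """Mean square error between two equal-length lists of features."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def cascade_block_errors(pruned_blocks, baseline_blocks, inputs):
    """Return the per-stage MSE between pruned and baseline block outputs.

    Each block is a callable mapping a flat list of features to a flat list.
    No labels are used: the baseline block output is the only supervision.
    """
    errors = []
    features = inputs
    for pruned, baseline in zip(pruned_blocks, baseline_blocks):
        pruned_out = pruned(features)
        baseline_out = baseline(features)
        errors.append(mse(pruned_out, baseline_out))
        # The next stage is fed the baseline output, so an error made by one
        # pruned block does not leak into the score of the following block.
        features = baseline_out
    return errors
```

Feeding every stage from the baseline path is what allows each pruning block network to be scored, and therefore pruned, independently of the others.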
6. The reinforcement learning-based fish classification method according to claim 2, wherein the pruning each of the baseline block networks according to the pruning rate to obtain a plurality of pruned baseline block networks specifically comprises:
calculating the number of convolution kernels to be pruned of the current layer according to the pruning rate and the number of convolution kernels of each layer of the baseline block network;
calculating importance scores of convolution kernels of each layer of the baseline block network;
pruning the convolution kernels of each layer of the baseline block network in ascending order of importance score, according to the number of convolution kernels to be pruned in the current layer, to obtain a pruned baseline block network.
7. The reinforcement learning-based fish classification method according to claim 6, wherein the calculating the number of convolution kernels to be pruned in the current layer according to the pruning rate and the number of convolution kernels in each layer of the baseline block network specifically comprises:
calculating the number of convolution kernels to be pruned of the current layer by using a formula v=o×u; v is the number of convolution kernels to be pruned in the current layer; o is pruning rate of the current layer; u is the number of convolution kernels of the current layer;
when
Figure QLYQS_2
the number of convolution kernels to be pruned in the current layer is v;
when v=u, the number of convolution kernels to be pruned at the current layer is u-1.
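The kernel count of claim 7 and the ascending-importance selection of claim 6 can be sketched together. Two points are assumptions, since the claims do not fix them: the product o×u is rounded down to an integer, and the importance score of a kernel is taken to be the L1 norm of its weights. The function names are hypothetical.

```python
def kernels_to_prune(pruning_rate, num_kernels):
    """Number of kernels to remove from a layer: v = o * u.

    Rounding down is an assumption. When v equals u, one kernel is kept,
    so the layer is never emptied.
    """
    v = int(pruning_rate * num_kernels)
    return num_kernels - 1 if v == num_kernels else v


def select_kernels_to_prune(kernels, pruning_rate):
    """Indices of the kernels to remove, least important first.

    `kernels` is a list of flat weight lists, one per convolution kernel.
    The L1 norm is used as the importance score (an assumption).
    """
    count = kernels_to_prune(pruning_rate, len(kernels))
    scores = [sum(abs(w) for w in kernel) for kernel in kernels]
    order = sorted(range(len(kernels)), key=lambda i: scores[i])
    return sorted(order[:count])
```

For a layer with 16 kernels, a pruning rate of 0.5 removes 8 kernels, while a rate of 1.0 removes 15 and leaves the single most important kernel in place.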
8. A reinforcement learning-based fish classification system, comprising:
the data acquisition module is used for acquiring the images of the fishes to be classified;
the classification module is used for inputting the fish images to be classified into a fish classification model to obtain classification results; the classification result is the type of fish;
the fish classification model is obtained by training a baseline network model with a sample data set and pruning the trained baseline network model; the sample data set comprises a training set, a verification set and a test set; the training set, the verification set and the test set all comprise a plurality of fish images and the fish type labels corresponding to the fish images.
9. An electronic device, comprising: a memory for storing a computer program, and a processor that runs the computer program to cause the electronic device to perform the reinforcement learning-based fish classification method of any one of claims 1-7.
10. A computer readable storage medium, characterized in that the computer readable storage medium stores a computer program which, when executed by a processor, implements the reinforcement learning based fish classification method of any one of claims 1-7.
CN202310347212.6A 2023-04-04 2023-04-04 Fish classification method, system, equipment and medium based on reinforcement learning Pending CN116129197A (en)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202310347212.6A CN116129197A (en) 2023-04-04 2023-04-04 Fish classification method, system, equipment and medium based on reinforcement learning

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202310347212.6A CN116129197A (en) 2023-04-04 2023-04-04 Fish classification method, system, equipment and medium based on reinforcement learning

Publications (1)

Publication Number Publication Date
CN116129197A true CN116129197A (en) 2023-05-16

Family

ID=86303034

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202310347212.6A Pending CN116129197A (en) 2023-04-04 2023-04-04 Fish classification method, system, equipment and medium based on reinforcement learning

Country Status (1)

Country Link
CN (1) CN116129197A (en)

Citations (10)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN111340227A (en) * 2020-05-15 2020-06-26 支付宝(杭州)信息技术有限公司 Method and device for compressing business prediction model through reinforcement learning model
CN111600851A (en) * 2020-04-27 2020-08-28 浙江工业大学 Feature filtering defense method for deep reinforcement learning model
CN112686382A (en) * 2020-12-30 2021-04-20 中山大学 Convolution model lightweight method and system
CN112766496A (en) * 2021-01-28 2021-05-07 浙江工业大学 Deep learning model security guarantee compression method and device based on reinforcement learning
CN113011588A (en) * 2021-04-21 2021-06-22 华侨大学 Pruning method, device, equipment and medium for convolutional neural network
US20210397965A1 (en) * 2020-06-22 2021-12-23 Nokia Technologies Oy Graph Diffusion for Structured Pruning of Neural Networks
CN114118402A (en) * 2021-10-12 2022-03-01 重庆科技学院 Self-adaptive pruning model compression algorithm based on grouping attention mechanism
CN115527106A (en) * 2022-10-21 2022-12-27 深圳大学 Imaging identification method and device based on quantitative fish identification neural network model
CN115600650A (en) * 2022-11-02 2023-01-13 华侨大学(Cn) Automatic convolution neural network quantitative pruning method and equipment based on reinforcement learning and storage medium
CN115829022A (en) * 2022-11-16 2023-03-21 西安交通大学 CNN network pruning rate automatic search method and system based on reinforcement learning

Non-Patent Citations (3)

* Cited by examiner, † Cited by third party
Title
MANAS GUPTA: "Learning to Prune Deep Neural Networks via Reinforcement Learning", 《ARXIV.ORG/ABS/2007.04756》, pages 1 - 11 *
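Liu Huidong: "Block-wise compression learning pruning algorithm", Journal of Chinese Computer Systems, vol. 44, no. 02, pages 3 *
Liu Huidong: "Label-free network pruning based on reinforcement learning", Pattern Recognition and Artificial Intelligence, vol. 34, no. 03, pages 2 *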

Legal Events

Date Code Title Description
PB01 Publication
SE01 Entry into force of request for substantive examination
RJ01 Rejection of invention patent application after publication

Application publication date: 20230516