CN114548353A - Model training method, electronic device and storage medium - Google Patents


Info

Publication number: CN114548353A
Application number: CN202011341114.4A
Authority: CN (China)
Prior art keywords: network, model, models, sub, network model
Legal status: Pending
Other languages: Chinese (zh)
Inventors: 蒋阳, 豆泽阳, 庞磊, 赵丛
Current Assignee: Gongdadi Innovation Technology Shenzhen Co ltd
Original Assignee: Gongdadi Innovation Technology Shenzhen Co ltd
Application filed by: Gongdadi Innovation Technology Shenzhen Co ltd
Priority to: CN202011341114.4A
Publication of: CN114548353A

Classifications

    • G PHYSICS
    • G06 COMPUTING; CALCULATING OR COUNTING
    • G06N COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00 Computing arrangements based on biological models
    • G06N3/02 Neural networks
    • G06N3/04 Architecture, e.g. interconnection topology
    • G06N3/045 Combinations of networks
    • G06N3/08 Learning methods
    • G06N20/00 Machine learning

Abstract

The application relates to the technical field of machine learning, and particularly discloses a model training method, an electronic device and a storage medium. The method comprises the following steps: acquiring a pre-trained hyper-network model; determining a plurality of target sub-network models from a preset number of sub-network models of the super-network model; acquiring a plurality of mainstream network models trained based on open-source data; splicing each target sub-network model, serving as a first backbone network, with a first branch network to obtain a plurality of first spliced networks, and splicing each mainstream network model, serving as a second backbone network, with a second branch network to obtain a plurality of second spliced networks; fine-tuning and testing the plurality of first spliced networks and second spliced networks to determine a target network model; and performing transfer learning on the target network model to obtain the required model. This makes model training convenient and intelligent and improves the user experience.

Description

Model training method, electronic device and storage medium
Technical Field
The present application relates to the field of machine learning technologies, and in particular, to a model training method, an electronic device, and a storage medium.
Background
Neural Network Architecture Search (NAS) is one of the hotspots in the field of automatic machine learning (AutoML). By designing an economical and efficient search method, a neural network with strong generalization capability and hardware-friendly requirements can be acquired automatically, which saves a large amount of manpower and material resources. The main working principle of NAS is as follows: a search space is defined first, candidate network structures are then found through a search strategy and evaluated, and the next round of search is performed according to the feedback until a target network structure is found; automatic machine learning is then performed based on the target network structure to obtain the required model.
However, the existing NAS search process is long: each time user data is obtained, the network structure must be searched from scratch, the searched structure must then be pre-trained, and transfer learning must finally be performed on the user data before the required model can be obtained. Because a long search waiting time is required, a user-friendly experience cannot be provided.
Disclosure of Invention
The embodiments of the application provide a model training method, an electronic device and a storage medium, which aim to solve the problem of the long search waiting time of neural network architecture search, save time cost, and provide a friendlier experience for users.
In a first aspect, the present application provides a method for training a model, the method comprising:
acquiring a pre-trained hyper-network model, wherein the hyper-network model comprises a preset number of sub-network models;
determining a plurality of target subnetwork models from a preset number of subnetwork models of the super network model;
acquiring a plurality of mainstream network models trained based on open source data;
splicing each target sub-network model, serving as a first backbone network, with a first branch network to obtain a plurality of first spliced networks, and splicing each mainstream network model, serving as a second backbone network, with a second branch network to obtain a plurality of second spliced networks, wherein the first branch networks spliced behind the first backbone networks have the same network structure and shared parameters, and the second branch networks spliced behind the second backbone networks have the same network structure and unshared parameters;
fine-tuning and testing a plurality of the first and second spliced networks to determine a target network model;
and carrying out transfer learning on the target network model to obtain a required model.
In a second aspect, an embodiment of the present application further provides another model training method, where the method includes:
acquiring a pre-trained hyper-network model, wherein the hyper-network model comprises a preset number of sub-network models;
determining a plurality of target subnetwork models from a preset number of subnetwork models of the super network model;
acquiring a plurality of mainstream network models trained based on open source data;
splicing each target sub-network model, serving as a first backbone network, with a first branch network to obtain a plurality of first spliced networks, and splicing each mainstream network model, serving as a second backbone network, with a second branch network to obtain a plurality of second spliced networks;
fine-tuning and testing a plurality of the first and second spliced networks to determine a target network model;
and carrying out transfer learning on the target network model to obtain a required model.
In a third aspect, an embodiment of the present application further provides another model training method, where the method includes:
acquiring a pre-trained hyper-network model, wherein the hyper-network model comprises a preset number of sub-network models;
determining a plurality of target subnetwork models from a preset number of subnetwork models of the super-network model;
acquiring a plurality of mainstream network models trained based on open source data;
determining a target network model according to the plurality of target subnetwork models and the plurality of mainstream network models;
and carrying out transfer learning on the target network model to obtain a required model.
In a fourth aspect, an embodiment of the present application provides an electronic device, including a memory and a processor;
the memory is used for storing a computer program;
the processor is configured to execute the computer program and, when executing the computer program, implement any one of the model training methods provided in the embodiments of the present application.
In a fifth aspect, the present application provides a computer-readable storage medium, which stores a computer program, and when the computer program is executed by a processor, the computer program causes the processor to implement any one of the model training methods provided in the present application.
The model training method, the electronic device and the storage medium provided by the embodiments of the application combine a pre-trained hyper-network model with mainstream network models and then splice backbone networks with branch networks, so that the search time of neural network architecture search (NAS) can be greatly reduced and the accuracy of the model can be improved.
It is to be understood that both the foregoing general description and the following detailed description are exemplary and explanatory only and are not restrictive of the disclosure of the embodiments of the application.
Drawings
In order to more clearly illustrate the technical solutions of the embodiments of the present application, the drawings needed to be used in the description of the embodiments are briefly introduced below, and it is obvious that the drawings in the following description are some embodiments of the present application, and it is obvious for those skilled in the art to obtain other drawings based on these drawings without creative efforts.
FIG. 1 is a schematic flow chart of a training method of a hyper-network model provided by an embodiment of the present application;
FIG. 2 is a schematic diagram of a super network architecture provided by an embodiment of the present application;
FIG. 3 is a schematic flow chart diagram of a model training method provided by an embodiment of the present application;
FIG. 4 is a schematic flow chart of selecting a sub-network model satisfying a predetermined model constraint condition from a super-network model according to an embodiment of the present application;
FIG. 5 is a schematic flow chart diagram illustrating a method for obtaining a plurality of mainstream network models trained based on open source data according to an embodiment of the present application;
FIG. 6 is a schematic view of a scenario when a model training method provided in an embodiment of the present application is applied to a server;
FIG. 7 is a schematic flow chart diagram of another model training method provided by embodiments of the present application;
fig. 8 is a schematic block diagram of an electronic device according to an embodiment of the present application.
Detailed Description
The technical solutions in the embodiments of the present application will be clearly and completely described below with reference to the drawings in the embodiments of the present application, and it is obvious that the described embodiments are some, but not all, embodiments of the present application. All other embodiments, which can be derived by a person skilled in the art from the embodiments given herein without making any creative effort, shall fall within the protection scope of the present application.
The flow diagrams depicted in the figures are merely illustrative and do not necessarily include all of the elements and operations/steps, nor do they necessarily have to be performed in the order depicted. For example, some operations/steps may be decomposed, combined or partially combined, so that the actual execution sequence may be changed according to the actual situation.
The term "and/or" as used in this specification and the appended claims refers to and includes any and all possible combinations of one or more of the associated listed items.
At present, Neural Network Architecture Search (NAS) is one of the hotspots in the field of automatic machine learning (AutoML). By designing an economical and efficient search method, a neural network with strong generalization capability and hardware-friendly requirements can be acquired automatically, which saves a large amount of manpower and material resources.
However, the existing NAS search process is long: each time user data is obtained, the network structure must be searched from scratch, the searched structure must then be pre-trained, and transfer learning must finally be performed on the user data to obtain the required model. Because a long search waiting time is required, a user-friendly experience cannot be provided.
Therefore, the present application provides a model training method, an electronic device, and a storage medium to solve the above problems.
Some embodiments of the present application will be described in detail below with reference to the accompanying drawings. The embodiments described below and the features of the embodiments can be combined with each other without conflict.
Since the model training method provided by the embodiment of the application is realized by the super network model, before the model training method is introduced, the training method of the super network model is introduced first.
Referring to fig. 1, fig. 1 illustrates a training method of a hyper-network model provided in an embodiment of the present application, where the trained hyper-network model can be used to determine multiple target sub-network models.
As shown in fig. 1, the training method of the hyper-network model includes: step S101 to step S104.
And S101, acquiring a source data set.
The open-source data may be a publicly available data set acquired over a network, such as an open-source image set (e.g. ImageNet). Specifically, a corresponding open-source data set can be found in the open-source data according to the model requirement information provided by the user, and the open-source data set is used for training the super-network model. The model requirement information comprises information corresponding to the function, type and application range of the model, that is, a description of what the user wants to do with the model.
Illustratively, the open source data set may also be a data set relating to CV (computer vision) tasks or NLP (natural language processing) tasks, etc.
In some embodiments, the model requirement information includes at least one of: task type, terminal type, application scenario, computational demand. The task type represents a practical scene of a target AI model required by a user, such as tasks required to be processed by the target AI model, such as classification, detection, video, natural language processing, and the like. The terminal type represents a deployment environment of the target AI model, such as a model of a terminal deploying the target AI model, a processor type, a model of the terminal, and the like. For example, the processor type of the terminal may include a CPU and/or a GPU. The application scenario may include at least one of: small sample detection, small object detection, unbalanced sample detection, and the like. The computational power requirements represent the ability of the model to process the task.
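For illustration only, the requirement information described above could be represented as a simple structure; the field names and example values in the sketch below are assumptions made for this example, not part of the disclosure.

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class ModelRequirement:
    task_type: str                      # e.g. "classification", "detection", "nlp"
    terminal_type: str                  # deployment environment, e.g. "cpu" or "gpu"
    application_scenario: Optional[str] = None     # e.g. "small object detection"
    compute_budget_gflops: Optional[float] = None  # computational power requirement

req = ModelRequirement(task_type="detection", terminal_type="gpu",
                       application_scenario="small object detection",
                       compute_budget_gflops=2.0)
```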
S102, acquiring a preset hyper-network, wherein the hyper-network comprises a first number of channels and a second number of layers.
The super network includes a first number of channels and a second number of layers, where the first number and the second number are each greater than a preset number threshold so as to ensure that the neural network is indeed a super network. The preset number threshold is, for example, 100; for instance, the width of the first layer of the model may be 128 channels and the depth of the model may be 101 layers, in which case the neural network is a super network.
It is understood that the number of channels and layers of the super network may be any other number, and is not specifically limited herein.
S103, randomly switching off channels and/or layers of the super network, and training a batch of data for the rest network by using the open source data set.
The channels and/or layers of the super network are randomly switched off, so that a brand-new network structure is formed after some channels and layers are disconnected, and this new network structure is trained on one batch of data from the open-source data set. The remaining network is the new network structure left after some channels and some layers are cut off.
Illustratively, as shown in fig. 2, take a super-network structure with a depth of 4 layers, each layer having a width of 6 channels, as an example (in practical applications the super network is much larger than this): the channel numbers of the 4 layers are 6-6-6-6 in sequence. After channels in the width direction and/or layers in the depth direction are randomly broken off, the resulting new network structure becomes, for example, a structure with a depth of 3 layers and widths of 5-5-6 in sequence. Each layer and each channel can be broken off at random. It should be noted that all of the 'x' marks in fig. 2 denote broken-off layers or channels.
And S104, obtaining a pre-trained hyper-network model until the preset hyper-network converges.
And repeating the steps of randomly switching off the channel and/or layer of the preset hyper-network and training a batch of data for the rest networks by using the open source data set until the preset hyper-network is converged to obtain a pre-trained hyper-network model.
Whether the hyper-network has converged can be judged by an error defined for the back-propagation algorithm: if the agreed error condition is met, convergence of the hyper-network is determined to be successful and a pre-trained hyper-network model is obtained; if the agreed error condition is not met, step S103 and step S104 are executed repeatedly until the preset hyper-network converges and a pre-trained hyper-network model is obtained. It follows that the hyper-network model is a neural network model comprising a preset number of sub-network models. In this way, the efficiency and accuracy of subsequent model training can be improved, and the time cost of training is greatly reduced.
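Illustratively, steps S101 to S104 can be sketched as follows in Python with PyTorch. The toy dimensions, the random masking probabilities and the randomly generated stand-in for a batch of open-source data are assumptions made purely for illustration; a real super network is far wider and deeper, as noted above.

```python
import torch
import torch.nn as nn

class ToySuperNet(nn.Module):
    """Every layer holds the maximum width; a sampled configuration activates
    only the first `w` output channels of a layer and may skip whole layers."""
    def __init__(self, width=6, depth=4, num_classes=10):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(width, width) for _ in range(depth)])
        self.head = nn.Linear(width, num_classes)

    def forward(self, x, widths, keep_layer):
        for layer, w, keep in zip(self.layers, widths, keep_layer):
            if not keep:                     # this layer is "switched off"
                continue
            y = torch.relu(layer(x))
            mask = torch.zeros_like(y)
            mask[:, :w] = 1.0                # channels beyond width w are "switched off"
            x = y * mask
        return self.head(x)

net = ToySuperNet()
opt = torch.optim.SGD(net.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()

for step in range(1000):                     # in practice: repeat until convergence (S104)
    widths = [torch.randint(3, 7, (1,)).item() for _ in range(4)]     # S103: random widths
    keep = [True] + [torch.rand(1).item() > 0.25 for _ in range(3)]   # S103: random layer drop
    x = torch.randn(32, 6)                   # stand-in for one batch of open-source data
    y = torch.randint(0, 10, (32,))
    loss = loss_fn(net(x, widths, keep), y)
    opt.zero_grad(); loss.backward(); opt.step()
```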
In some embodiments, after obtaining the pre-trained hyper-network model, the sub-network model may be directly extracted from the pre-trained hyper-network model, and the sub-network model is sent to the user, so that the user can directly use the sub-network model, which is suitable for the case where the user has low requirements for model accuracy and application scenarios. Therefore, when users with different requirements face, a suitable model can be provided in a targeted manner, unnecessary training is reduced, and user experience is improved.
Through the above training method for the super network, a pre-trained super-network model can be obtained. The model training methods provided by the embodiments of the application are all based on this super-network model, so the model training methods provided by the embodiments of the application are introduced below on this basis.
Referring to fig. 3, fig. 3 is a schematic flowchart of a model training method according to an embodiment of the present application. The model training method can be applied to electronic equipment, can realize efficient model training, improves the convenience of obtaining a required model by a user, reduces the time cost of training the model, and further improves the user experience.
The electronic device is, for example, a terminal device or a server. The terminal device may be an electronic device such as a mobile phone, a tablet computer, a notebook computer, a desktop computer, a personal digital assistant or a wearable device; the server may be an independent server, a server cluster, or an elastic compute service (ECS) instance. In some embodiments, the electronic device includes a GPU for improving the training efficiency of the model.
As shown in fig. 3, the model training method includes steps S201 to S206.
S201, obtaining a pre-trained hyper-network model, wherein the hyper-network model comprises a preset number of sub-network models.
In an embodiment of the present application, the super network model obtained by the super network model training method may be obtained, where the super network model includes a preset number of sub network models, for example, 100 or more, and is not limited herein.
S202, determining a plurality of target sub-network models from the preset number of sub-network models of the hyper-network model.
The super-network model is sampled according to a sampling algorithm, and sub-network models meeting a preset model constraint condition are screened out according to the preset model constraint condition to serve as candidate target sub-network models. Sub-network models are collected continuously until a preset value is reached, so as to determine a plurality of target sub-network models. The preset value is, for example, M, where M can be set by the user according to the actual situation and is not particularly limited herein.
For example, the value range of M is set to 10-50, and specifically, M may be set to be equal to 20.
Wherein the preset sampling algorithm comprises: at least one of a random sampling algorithm, an Evolutionary algorithm-based sampling algorithm (Evolutionary algorithm), and a Gradient-based sampling algorithm (Gradient-based method).
The random sampling algorithm randomly selects sub-network models from the super-network model; the sampling algorithm based on the evolutionary algorithm uses the parallel iteration of the evolutionary algorithm to improve sampling efficiency while ensuring sampling precision; gradient-based sampling algorithms such as gradient descent can adjust the output according to previous sampling results, increasing the probability of obtaining an optimal sub-network. All three sampling methods can be applied in the embodiments of the application, but the latter two generally achieve higher sampling efficiency and accuracy.
In some embodiments, taking a random sampling algorithm as an example, as shown in fig. 4, that is, the step of selecting a subnetwork model satisfying a preset model constraint condition from the subnetwork models specifically includes the following steps:
s2021, randomly selecting a sub-network model from the super-network model;
s2022, determining whether the operand of the sub-network model is smaller than a preset operand threshold value and whether the model parameter of the sub-network model is smaller than a preset parameter threshold value;
s2023, if the operand of the sub-network model is smaller than the preset operand threshold and the model parameter quantity of the sub-network model is smaller than the preset parameter threshold, selecting the sub-network model.
Specifically, randomly selecting a sub-network model from a preset number of sub-network models in the super-network model, obtaining the operand and model parameters of the selected sub-network model, determining whether the operand of the selected sub-network model is smaller than a preset operand (FLOPS) threshold value, and whether the model parameters of the selected sub-network model are smaller than a preset parameter threshold value, and if the operand of the sub-network model is smaller than the preset operand threshold value and the model parameters of the sub-network model are smaller than the preset parameter threshold value, selecting the sub-network model; and if the operand of the sub-network model is greater than or equal to the preset operand threshold, or the model parameter number of the sub-network model is greater than or equal to the preset parameter threshold, discarding the sub-network model.
Specifically, if the computation of the sub-network model is greater than or equal to the preset computation threshold and the model parameter of the sub-network model is less than the preset parameter threshold, the sub-network model is discarded. And if the operand of the sub-network model is smaller than the preset operand threshold value and the model parameter quantity of the sub-network model is larger than or equal to the preset parameter threshold value, discarding the sub-network model.
Determining whether the operand of the selected sub-network model is smaller than the preset operand threshold, where the operand is the number of floating-point operations (FLOPs) of the model and is used for determining whether the computation amount of the sub-network model meets the requirement; and determining whether the model parameter quantity of the sub-network model is smaller than the preset parameter quantity threshold, where the model parameter quantity may include the number of weights and is used for determining whether the parameter quantity of the sub-network model meets the requirement.
In some embodiments, judging whether the number of the collected sub-network models meets a preset value, if so, ending the cycle, and determining the selected sub-network model; and if the number of the collected sub-network models does not meet the preset value, continuously repeating the steps of S2021-S2023 until the number of the collected sub-network models meets the preset value. Therefore, the accuracy of the selected sub-network model can be improved, the optimal model is prevented from being missed due to contingency, and the fault tolerance rate is improved.
The preset value is a preset quantity value, and can be 50 or any quantity.
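A sketch of the sampling loop of steps S2021 to S2023 is given below. The cost model is a stand-in written for illustration only; in a real system the operand and parameter quantity would be measured on the actual sampled sub-network structure, and the thresholds and the preset value M are assumed values.

```python
import random

def cost(cfg):
    """Stand-in cost model: returns (operand, parameter count) for a sampled
    configuration given as (per-layer widths, depth). Illustrative only."""
    widths, depth = cfg
    params = sum(w * w for w in widths[:depth])
    flops = 2 * params                # rough proxy for the operand (FLOPs)
    return flops, params

def sample_target_candidates(m=20, flops_limit=1e6, param_limit=5e5, max_tries=10000):
    """S2021: randomly draw a sub-network; S2022: check both thresholds;
    S2023: keep it only if operand AND parameter quantity are under the limits.
    The loop repeats until the preset value M sub-networks are collected."""
    selected = []
    for _ in range(max_tries):
        cfg = ([random.randint(16, 128) for _ in range(8)], random.randint(4, 8))
        flops, params = cost(cfg)
        if flops < flops_limit and params < param_limit:
            selected.append(cfg)
        if len(selected) >= m:
            break
    return selected
```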
In some embodiments, the number of target sub-network models to be determined may also be determined according to the number of mainstream network models; for example, the number of determined target sub-network models may be equal to or greater than the number of mainstream network models, or the number of mainstream network models may be the same as or proportional to the number of target sub-network models to be determined. Of course, other relationships may also be used, which are not specifically limited herein. By setting the number of mainstream network models appropriately, the accuracy and efficiency of the searched network structure can be improved.
In some embodiments, the number of target sub-network models to be determined may be set first, and the number of mainstream network models to be acquired may then be set accordingly.
In some embodiments, the collected sub-network models may be further subjected to test evaluation to obtain an accuracy of the sub-network models, and test evaluation results are sorted to determine a plurality of target sub-network models. The ranking of the sub-network models can thus be clearly understood, and the optimal model can be selected.
Specifically, the collected multiple sub-network models are tested and evaluated according to a test set to obtain test evaluation results of the multiple sub-network models, and the test evaluation results are used for representing accuracy.
Specifically, a test set is used to perform test evaluation on the plurality of collected sub-network models to obtain the test evaluation results of the plurality of sub-network models, where the test set is test data used for producing the test evaluation results of the sub-network models.
For example, taking an image set as an example, the plurality of sub-network models may be tested with test data for object identification: each sub-network model processes the test data and outputs a test evaluation result, which may specifically be a probability value, so that the accuracy of each sub-network model can be calculated from its test evaluation result.
It is understood that the plurality of sub-network models may perform a plurality of test evaluations and obtain a plurality of test evaluation results and test accuracy.
In some embodiments, the plurality of sub-network models may be further ranked according to the test evaluation results of the plurality of sub-network models, so as to obtain a ranking result of the plurality of sub-network models.
And sequencing the plurality of sub-network models according to the test evaluation results of the plurality of sub-network models on the test data, thereby obtaining the sequencing results of the plurality of sub-network models.
For example, the plurality of sub-network models may be sorted by score, for example in descending order from high to low or in ascending order from low to high.
In some embodiments, the plurality of target sub-network models are determined based on the ranking result and the number of target sub-network models: a corresponding number of target sub-network models is selected from the ranking result according to the required number of target sub-network models.
For example, if the number of acquired mainstream network models is 10, the number of target sub-network models to be determined is also 10; according to the sorting result, for example sorting downwards from the highest-scoring sub-network model, the top 10 sub-network models are selected as the target sub-network models.
In some embodiments, if multiple test evaluations are performed, the test score ratios may be assigned according to the test importance levels, such as 60% for the test score ratio with high importance level and 40% for the test score ratio with low importance level, and finally the composite score is obtained.
Illustratively, if a person identification test and an animal identification test are performed, with the person identification test score weighted at 60% and the animal identification test score weighted at 40%, then if sub-network model A scores 100 points on person identification and 90 points on animal identification, its composite score is 96 points; if sub-network model B scores 90 points on person identification and 100 points on animal identification, its composite score is 94 points, so sub-network model A is ranked higher than sub-network model B.
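The weighted composite score and the resulting ranking in the example above can be reproduced with a few lines; the dictionary of model names and scores below is illustrative only.

```python
def composite_score(scores, weights=(0.6, 0.4)):
    # weighted sum of per-test scores, e.g. person test 60%, animal test 40%
    return sum(s * w for s, w in zip(scores, weights))

test_scores = {"sub-network A": (100, 90), "sub-network B": (90, 100)}
ranking = sorted(test_scores, key=lambda name: composite_score(test_scores[name]),
                 reverse=True)
# composite_score((100, 90)) == 96.0 and composite_score((90, 100)) == 94.0,
# so the ranking is ["sub-network A", "sub-network B"]
```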
S203, acquiring a plurality of mainstream network models trained based on open source data.
The mainstream network models are obtained by training on open-source data; trained network models that are frequently used by users can be obtained from open-source websites and used as the mainstream network models.
In some embodiments, the acquired mainstream network models differ in model complexity, where the model complexity comprises at least one of a model operand and a model parameter quantity. By selecting mainstream network models with different model complexities, the accuracy of the network structure search can be improved, and thus the accuracy of the model finally required by the user can be improved.
In some embodiments, as shown in fig. 5, that is, obtaining a plurality of mainstream network models trained based on open source data specifically includes the following steps:
s2031, determining the type of the pre-trained hyper-network model.
Wherein the type of the hyper-network model, such as detection, classification, identification, etc., indicates the usage of the hyper-network model. For example, if the pre-trained hyper-network model is image recognition, the type of the hyper-network model is recognition.
S2032, according to the type, selecting at least one open source network model matched with the type of the pre-trained hyper-network model from the open source network model set as a seed model.
According to the determined type of the hyper-network model, one or more open-source network models of a matching type are selected from the open-source network model set and used as seed models. The open-source network models have different complexities.
Illustratively, if the determined type of the hyper-network model is detection, a plurality of open-source network models for detection, such as a network model for person detection and a network model for animal detection, are selected from the open-source network model set and used as seed models, for example MobileNetV2 (a lightweight convolutional neural network).
S2033, acquiring the requirement information of the model by the user, wherein the requirement information comprises the accuracy of the model and/or the magnitude of the model.
The user's requirement information for the model is acquired, such as the accuracy of the model and/or the magnitude of the model. Illustratively, the acquired user requirements on the model are an accuracy of up to 95% and a model on the order of 2M (million) parameters.
S2034, determining a conversion processing strategy for the model according to the user's requirement information for the model, wherein the conversion processing strategy comprises at least one of an augmentation processing strategy and a compression processing strategy.
Specifically, if the requirement information places a high demand on the model, such as an accuracy requirement of 90% and a model with a magnitude of 2M (million) parameters, the high-demand requirement information corresponds to the augmentation processing strategy. Conversely, if the requirement information places a low demand on the model, such as an accuracy requirement of 80% and a model with a magnitude of 1M (million) parameters, the low-demand requirement information corresponds to the compression processing strategy.
For example, the augmentation processing strategy may be implemented by adding RBF (radial basis function) or SE (Squeeze-and-Excitation) layers to augment the convolution layers so as to perform feature fusion on local regions, or by designing more channel features.
Illustratively, an attention mechanism can also be added to the model, and the model can then be trained for functions such as text summarization, reading comprehension, language modeling and syntactic analysis.
For example, the compression processing strategy may obtain the mainstream network model by compressing the model parameters.
S2035, according to the conversion processing strategy, the seed model is converted to obtain a plurality of mainstream network models.
Illustratively, taking MobileNetV2 as the seed model, a number of MobileNetV2 variant models with different FLOPS can be obtained through the augmentation processing strategy or the compression processing strategy.
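As a possible sketch of this conversion step, assuming torchvision's MobileNetV2 implementation is used as the seed model, scaling the channel multiplier down corresponds to the compression processing strategy and scaling it up to the augmentation processing strategy; other conversions (adding SE blocks, pruning, and so on) would follow the same pattern. The multiplier values are assumptions.

```python
from torchvision.models import MobileNetV2

def convert_seed(width_multipliers=(0.5, 0.75, 1.0, 1.4)):
    # width_mult < 1.0 -> compression strategy, width_mult > 1.0 -> augmentation strategy
    return [MobileNetV2(width_mult=m) for m in width_multipliers]

mainstream_models = convert_seed()   # variant models with different FLOPS / parameter counts
```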
For example, as shown in fig. 6, fig. 6 is a schematic view of a scenario in which the model training method provided in the embodiments of the present application is applied to a server. Before executing the model training method, the server may obtain the model requirement information, i.e. a description of what the user wants to do with the model, from the terminal device. The server then executes the model training method to generate a target model and sends the generated model to the terminal device, so that the terminal device can perform operations such as model testing or deployment.
Specifically, the client may be an application program (APP). When a user opens the APP, the APP displays a requirement information interface so that the user can fill in requirement information on it; the requirement information filled in by the user is acquired and sent to the server. The server determines the type of the hyper-network model, the open-source network model set, or the accuracy and magnitude of the model according to the requirement information, executes the model training method provided in the embodiments of the present application to perform model training, and sends the trained model to the user so that the user can use it.
In some embodiments, the client may further obtain the voice of the user, for example, the client is provided with a voice button to prompt the user to issue the requirement information through voice, and the model requirement information of the user is obtained by recognizing the voice of the user.
S204, splicing each target subnetwork model serving as a first backbone network with a first branch network to obtain a plurality of first spliced networks, and splicing each mainstream network model serving as a second backbone network with a second branch network to obtain a plurality of second spliced networks.
The plurality of target sub-network models determined from the super-network model are each used as a first backbone network and spliced with a first branch network to obtain a plurality of first spliced networks, and the plurality of mainstream network models trained on open-source data are each used as a second backbone network and spliced with a second branch network to obtain a plurality of second spliced networks. The backbone network refers to a sub-network structure taken from the super-network model or a mainstream network model trained on open-source data; the branch network is a network connected behind the backbone network, and its structure can be connected to any of the backbone networks, that is, the backbone network and the branch network are connected in series, with the branch network spliced behind the backbone network.
In some embodiments, the network structure of each first branch network spliced behind the first backbone network is the same and the parameters are shared, and the network structure of each second branch network spliced behind the second backbone network is the same and the parameters are not shared. The parameters include branch network parameters, which may include a backbone network (backbone), a head network (head), a neck network (neck), a learning rate, weight decay (weight decay), and the like.
Illustratively, the branch network structure and the branch network parameters of each first branch network are the same, and since the first backbone network is extracted from the super network, the parameters of the backbone network structure can be fixed by using the same branch network parameters. The branch network structure of each second branch network is the same, but the branch network parameters are different, and since the second backbone network is obtained from the open source data, different branch network parameters need to be configured to fix the parameters of the backbone network structure. By sharing the parameters of the first branch network, the network structure searching time can be further shortened, and the model training efficiency is further improved.
In some embodiments, the network structure of the second branch network may be set to be the same as that of the first branch network, so that the first spliced networks and the second spliced networks can be conveniently fine-tuned and tested, which can further improve the accuracy of model training.
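A sketch of this splicing step in PyTorch is shown below. The backbone lists `target_subnets` and `mainstream_models` are assumed to come from the earlier steps, and the branch dimensions are placeholders; the key point is that all first spliced networks reference one shared branch instance, while each second spliced network owns its own branch.

```python
import torch.nn as nn

class SplicedNet(nn.Module):
    """A backbone network with a branch network spliced behind it (in series)."""
    def __init__(self, backbone, branch):
        super().__init__()
        self.backbone = backbone
        self.branch = branch

    def forward(self, x):
        return self.branch(self.backbone(x))

def make_branch(feat_dim=6, num_classes=10):
    return nn.Sequential(nn.Linear(feat_dim, 32), nn.ReLU(), nn.Linear(32, num_classes))

# First spliced networks: the SAME branch module is reused, so the structure and
# the parameters of the first branch networks are shared.
shared_branch = make_branch()
first_spliced = [SplicedNet(sub, shared_branch) for sub in target_subnets]

# Second spliced networks: each mainstream model gets its OWN branch instance,
# so the structure is identical but the parameters are not shared.
second_spliced = [SplicedNet(m, make_branch()) for m in mainstream_models]
```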
S205, fine tuning and testing the first splicing networks and the second splicing networks to determine a target network model.
And respectively carrying out fine tuning (finetune) on the first splicing network and the second splicing network, testing, determining a target network model according to a test evaluation result, and carrying out transfer learning on the target network model to obtain a model required by a user. Therefore, the network parameters can be dynamically adjusted to train the target model which is more in line with the requirements of the user.
Illustratively, the first spliced networks and the second spliced networks may be fine-tuned individually, with fine-tuning of branch network parameters such as the head network (head) or the neck network (neck) being performed primarily on the user data.
Illustratively, the first spliced networks and the second spliced networks are fine-tuned for N epochs, where fine-tuning for 1 epoch means that the entire training set has been used once for fine-tuning; N may be any number, generally 3 or 5.
In some embodiments, the fine-tuned first spliced networks and second spliced networks are all subjected to test evaluation, and the test evaluation results are sorted to determine the target network model.
Illustratively, the same test evaluation is performed on both the first and second spliced networks, test evaluation results are obtained, all the test evaluation results are sorted, and the best spliced network is selected to determine the target network model.
For example, if multiple test evaluations are performed, test score weights may be assigned according to test importance, for example 60% for the more important test and 40% for the less important test; a composite score is then obtained, and the spliced network with the highest composite score is selected as the target network model.
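For illustration, the fine-tuning and selection step S205 can be sketched as the loop below, assuming PyTorch data loaders over the user data and the spliced networks built in the earlier sketch; plain accuracy is used here as the test evaluation result, and a weighted composite score could be substituted as described above.

```python
import torch
import torch.nn as nn

def finetune_and_select(candidates, train_loader, test_loader, epochs=3):
    """Fine-tune every spliced network for N epochs, evaluate it on the test
    set, and keep the best-scoring one as the target network model."""
    best_model, best_score = None, float("-inf")
    loss_fn = nn.CrossEntropyLoss()
    for model in candidates:
        opt = torch.optim.SGD(model.parameters(), lr=1e-3)
        for _ in range(epochs):                      # N epochs, e.g. 3 or 5
            for x, y in train_loader:
                opt.zero_grad()
                loss_fn(model(x), y).backward()
                opt.step()
        correct = total = 0
        with torch.no_grad():
            for x, y in test_loader:
                correct += (model(x).argmax(dim=1) == y).sum().item()
                total += y.numel()
        score = correct / total                      # test evaluation result
        if score > best_score:
            best_model, best_score = model, score
    return best_model
```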
S206, carrying out transfer learning on the target network model to obtain a required model.
Specifically, the target network model is subjected to transfer learning by using scene data provided by a user, and the model required by the user can be efficiently obtained through the transfer learning. The transfer learning is to transfer the labeled data or the knowledge structure from the related field, so as to achieve the learning effect of completing or improving the target field or task.
In some embodiments, the transfer learning of the target network model may be based on a feature-mapping transfer learning algorithm or on shared-parameter transfer; effective weight distribution is performed so that the instance distribution of the source domain of the target network model approaches the instance distribution of the target domain, whereby a reliable learning model with high classification accuracy is established in the target field and a more suitable model is obtained.
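As a minimal illustration of the shared-parameter style of transfer, the sketch below assumes the target network model is a spliced network like the one in the earlier sketch: the backbone parameters are frozen and only the branch is trained on the user's scene data. The data loader and hyperparameters are assumptions.

```python
import torch
import torch.nn as nn

def transfer_learn(target_model, scene_loader, epochs=10, lr=1e-3):
    for p in target_model.backbone.parameters():
        p.requires_grad = False                      # keep the shared backbone fixed
    opt = torch.optim.Adam(
        (p for p in target_model.parameters() if p.requires_grad), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    for _ in range(epochs):
        for x, y in scene_loader:                    # user-provided scene data
            opt.zero_grad()
            loss_fn(target_model(x), y).backward()
            opt.step()
    return target_model
```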
In some embodiments, if the target network model is selected from the open-source network model set, steps S201 to S203 of the model training method are executed again: the target sub-network models are re-determined, and the re-determined target sub-network models are spliced with the first branch network, fine-tuned and tested to obtain a re-determined network model. The test evaluation result of the re-determined network model is compared with that of the current target network model, and the target network model is finally determined according to the test evaluation results; this is repeated until the determined target network model originates from the super-network model. Doing so helps improve the applicability and accuracy of the target network model and avoids missing a better network model due to the randomness of sampling.
Referring to fig. 7, fig. 7 is a schematic flowchart of another model training method according to an embodiment of the present application. The model training method can be applied to a server, realizes efficient model training, improves the convenience of obtaining a required model by a user, reduces the time cost of training the model, and improves the user experience.
As shown in fig. 7, the model training method includes steps S301 to S305.
S301, obtaining a pre-trained hyper-network model, wherein the hyper-network model comprises a preset number of sub-network models;
s302, determining a plurality of target sub-network models from the preset number of sub-network models of the hyper-network model;
s303, acquiring a plurality of mainstream network models trained based on open source data;
s304, determining a target network model according to the plurality of target subnetwork models and the plurality of mainstream network models;
s305, performing transfer learning on the target network model to obtain a required model.
In the embodiment of the application, fine-tuning (finetune) can be performed directly on the plurality of target sub-network models and the plurality of mainstream network models, and the target network model can be determined without splicing with a branch network.
Specifically, fine-tuning (finetune) and testing are performed on the plurality of determined target sub-network models and the plurality of mainstream network models trained based on open-source data, and the target network model is determined according to the test evaluation results; transfer learning is then performed on the target network model to obtain the required model.
In some embodiments, the fine-tuned target sub-network models and mainstream network models are all subjected to test evaluation, and the test evaluation results are sorted to determine the target network model.
Illustratively, the same test evaluation is performed on the plurality of target sub-network models and the plurality of mainstream network models, the test evaluation results are obtained and sorted, and the best network model is selected as the target network model.
Referring to fig. 8, fig. 8 is a schematic view of an electronic device 400 according to an embodiment of the present disclosure. The electronic device may be a server or a terminal.
As shown in fig. 8, the electronic device 400 includes a processor 402 and a memory 401 connected by a system bus, wherein the memory may include a nonvolatile storage medium and an internal memory.
The non-volatile storage medium may store an operating system and a computer program. The computer program includes program instructions that, when executed, cause a processor to perform any of the model training methods.
The processor is used for providing calculation and control capability and supporting the operation of the whole electronic equipment.
The internal memory provides an environment for the execution of a computer program on a non-volatile storage medium, which when executed by the processor, causes the processor to perform any of the model training methods.
It will be understood by those skilled in the art that the structure of the electronic device is a block diagram of only a part of the structure related to the present application, and does not constitute a limitation to the electronic device to which the present application is applied, and a specific electronic device may include more or less components than those shown in the drawings, or combine some components, or have different arrangements of components.
It should be understood that the Processor may be a Central Processing Unit (CPU), and the Processor may be other general purpose processors, Digital Signal Processors (DSPs), Application Specific Integrated Circuits (ASICs), Field Programmable Gate Arrays (FPGAs) or other Programmable logic devices, discrete Gate or transistor logic devices, discrete hardware components, etc. Wherein a general purpose processor may be a microprocessor or the processor may be any conventional processor or the like.
Wherein, in some embodiments, the processor is configured to execute a computer program stored in the memory to implement the steps of:
acquiring a pre-trained hyper-network model, wherein the hyper-network model comprises a preset number of sub-network models; determining a plurality of target sub-network models from the preset number of sub-network models of the super-network model; acquiring a plurality of mainstream network models trained based on open-source data; splicing each target sub-network model, serving as a first backbone network, with a first branch network to obtain a plurality of first spliced networks, and splicing each mainstream network model, serving as a second backbone network, with a second branch network to obtain a plurality of second spliced networks, wherein the first branch networks spliced behind the first backbone networks have the same network structure and shared parameters, and the second branch networks spliced behind the second backbone networks have the same network structure and unshared parameters; fine-tuning and testing the plurality of first spliced networks and second spliced networks to determine a target network model; and performing transfer learning on the target network model to obtain a required model.
In some embodiments, when the processor is implemented to obtain the pre-trained hyper-network model, the processor is specifically configured to:
acquiring an open source data set; acquiring a preset hyper network, wherein the hyper network comprises a first number of channels and a second number of layers; randomly switching off the channel and/or layer of the preset super network, and training a batch of data for the rest network by using the open source data set; and repeating the steps of randomly switching off the channel and/or layer of the preset hyper-network and training a batch of data for the rest networks by using the open source data set until the preset hyper-network is converged to obtain a pre-trained hyper-network model.
In some embodiments, the processor is configured to select a sub-network model satisfying a preset model constraint condition from the super-network model, and specifically to:
randomly selecting a sub-network model from the super-network model; and determining whether the operand of the selected sub-network model is smaller than a preset operand threshold value and whether the model parameter quantity of the selected sub-network model is smaller than a preset parameter threshold value.
In some embodiments, the processor, in implementing determining the plurality of target subnetwork models, is specifically configured to:
determining the number of target sub-network models to be determined according to the number of the main-flow network models; testing the collected multiple sub-network models according to a test set to obtain the accuracy of the multiple sub-network models; sequencing the plurality of sub-network models according to the accuracy of the plurality of sub-network models to obtain a sequencing result of the plurality of sub-network models; and determining a plurality of target sub-network models according to the sequencing result and the number of the target sub-network models.
In some embodiments, the processor is specifically configured to, in implementing obtaining a plurality of mainstream network models trained based on open source data:
determining the type of the pre-trained hyper-network model; according to the type, selecting at least one open source network model matched with the type of the pre-trained hyper-network model from a public open source network model set as a seed model; acquiring demand information of a user on a model, wherein the demand information comprises the accuracy of the model and/or the magnitude of the model; determining a conversion processing strategy for the model according to the demand information of the user for the model, wherein the conversion processing strategy at least comprises an increasing processing strategy and a compressing processing strategy; and converting the seed model according to the conversion processing strategy to obtain a plurality of mainstream network models.
The embodiment of the present application further provides a computer-readable storage medium, where a computer program is stored on the computer-readable storage medium, where the computer program includes program instructions, and the program instructions, when executed, implement any one of the model training methods provided in the embodiment of the present application.
The computer-readable storage medium may be an internal storage unit of the electronic device according to the foregoing embodiment, for example, a hard disk or a memory of the electronic device. The computer readable storage medium may also be an external storage device of the electronic device, such as a plug-in hard disk, a Smart Media Card (SMC), a Secure Digital (SD) Card, a Flash memory Card (Flash Card), and the like, provided on the electronic device.
Further, the computer-readable storage medium may mainly include a storage program area and a storage data area, wherein the storage program area may store an operating system, an application program required for at least one function, and the like; the storage data area may store data created according to the use of the blockchain node, and the like.
The blockchain referred to in this application is a novel application mode of computer technologies such as distributed data storage, point-to-point transmission, consensus mechanisms and encryption algorithms. A blockchain (Blockchain) is essentially a decentralized database: a series of data blocks associated with each other by cryptographic methods, where each data block contains information of a batch of network transactions and is used to verify the validity (anti-counterfeiting) of the information and to generate the next block. The blockchain may include a blockchain underlying platform, a platform product service layer, an application service layer, and the like.
While the invention has been described with reference to specific embodiments, the scope of the invention is not limited thereto, and those skilled in the art can easily conceive various equivalent modifications or substitutions within the technical scope of the invention. Therefore, the protection scope of the present application shall be subject to the protection scope of the claims.

Claims (13)

1. A method of model training, the method comprising:
acquiring a pre-trained hyper-network model, wherein the hyper-network model comprises a preset number of sub-network models;
determining a plurality of target subnetwork models from a preset number of subnetwork models of the super network model;
acquiring a plurality of mainstream network models trained based on open source data;
splicing each target sub-network model serving as a first backbone network with a first branch network to obtain a plurality of first spliced networks, and splicing each mainstream network model serving as a second backbone network with a second branch network to obtain a plurality of second spliced networks, wherein the first branch networks spliced behind each first backbone network have the same network structure and shared parameters, and the second branch networks spliced behind each second backbone network have the same network structure and unshared parameters;
fine-tuning and testing a plurality of the first splicing networks and the second splicing networks to determine a target network model;
and carrying out transfer learning on the target network model to obtain a required model.
2. The method of claim 1, further comprising:
acquiring an open source data set;
acquiring a preset hyper network, wherein the hyper network comprises a first number of channels and a second number of layers;
randomly shutting off channels and/or layers of the super network, and training a batch of data for the rest of networks by using the open source data set;
and repeating the steps of randomly switching off the channel and/or layer of the preset hyper-network and training a batch of data for the rest networks by using the open source data set until the hyper-network is converged to obtain a pre-trained hyper-network model.
3. The method of claim 1, wherein determining a plurality of target subnetwork models from a preset number of subnetwork models of the super network model comprises:
based on a preset sampling algorithm, selecting a sub-network model meeting preset model constraint conditions from the super-network model until the number of the collected sub-network models meets a preset value;
and performing test evaluation on the collected multiple sub-network models to determine multiple target sub-network models.
4. The method of claim 3, wherein the selecting the sub-network model satisfying the predetermined model constraint condition from the super-network model based on the predetermined sampling algorithm comprises:
randomly selecting a sub-network model from the super-network model;
determining whether the operand of the sub-network model is smaller than a preset operand threshold value and whether the model parameter number of the sub-network model is smaller than a preset parameter threshold value;
and if the operand of the sub-network model is smaller than the preset operand threshold value and the model parameter quantity of the sub-network model is smaller than the preset parameter threshold value, selecting the sub-network model.
5. The method of claim 1, wherein performing test evaluation on the collected plurality of sub-network models to determine the plurality of target sub-network models comprises:
determining, according to the number of mainstream network models, the number of target sub-network models to be determined;
testing the collected plurality of sub-network models on a test set to obtain the accuracy of each of the plurality of sub-network models;
ranking the plurality of sub-network models according to their accuracy to obtain a ranking result of the plurality of sub-network models; and
determining the plurality of target sub-network models according to the ranking result and the number of target sub-network models to be determined.
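The ranking-and-selection step can be pictured with the following sketch; `evaluate_accuracy` is an assumed helper that measures a sub-network on the test set, and the number of selected models is tied to the number of mainstream network models as in the claim.

```python
# Illustrative ranking and top-k selection of target sub-network models.
def select_target_subnets(subnets, mainstream_models, evaluate_accuracy, test_loader):
    k = len(mainstream_models)                               # number of targets to keep
    scored = [(evaluate_accuracy(s, test_loader), s) for s in subnets]
    scored.sort(key=lambda pair: pair[0], reverse=True)      # rank by accuracy, best first
    return [subnet for _, subnet in scored[:k]]              # top-k target sub-network models
```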
6. The method of claim 1, wherein acquiring a plurality of mainstream network models trained on open-source data comprises:
determining the type of the pre-trained super-network model;
selecting, according to the type, at least one open-source network model matching the type of the pre-trained super-network model from a set of publicly available open-source network models as a seed model;
acquiring requirement information of a user for the model, wherein the requirement information comprises the accuracy of the model and/or the size of the model;
determining a conversion processing strategy for the model according to the requirement information of the user;
and converting the seed model according to the conversion processing strategy to obtain the plurality of mainstream network models.
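Because the claim does not fix a concrete conversion, the following is only a speculative sketch of how requirement information might be mapped to a conversion strategy such as width scaling of the seed model; every strategy value and function name here is an assumption.

```python
# Speculative sketch: mapping user requirements to a seed-model conversion strategy.
def choose_conversion_strategy(required_accuracy=None, required_size=None):
    if required_size == "small":
        return {"width_multiplier": 0.5}      # shrink the seed model
    if required_accuracy is not None and required_accuracy > 0.9:
        return {"width_multiplier": 1.5}      # enlarge the seed model
    return {"width_multiplier": 1.0}          # keep the seed model as-is

def convert_seed_model(build_model_fn, strategy):
    # `build_model_fn(width_multiplier=...)` is a hypothetical factory for the
    # seed architecture (for example, a MobileNet-style constructor).
    return build_model_fn(width_multiplier=strategy["width_multiplier"])
```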
7. The method of claim 1, wherein performing transfer learning on the target network model to obtain the required model comprises:
acquiring a backbone network of the target network model;
determining, according to the target network model, the branch network used in the fine-tuning and testing;
and splicing the backbone network with the branch network, and performing transfer-learning training on scene data provided by the user to obtain the required model.
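A minimal transfer-learning sketch, reusing the SplicedNetwork class from the earlier example and assuming a simple freeze-the-backbone policy and optimizer settings that are not specified in the claim, might look as follows:

```python
# Illustrative transfer-learning sketch; assumes SplicedNetwork from the earlier example.
import torch

def transfer_learn(target_backbone, branch, user_loader, epochs=5, lr=1e-3):
    model = SplicedNetwork(target_backbone, branch)   # splice backbone and branch
    for p in model.backbone.parameters():
        p.requires_grad = False                       # assumed policy: freeze the backbone
    optimizer = torch.optim.Adam(model.branch.parameters(), lr=lr)
    criterion = torch.nn.CrossEntropyLoss()
    for _ in range(epochs):
        for images, labels in user_loader:            # user-provided scene data
            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()
    return model
```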
8. The method according to any one of claims 1 to 7, wherein the network structure of the second branch network is the same as the network structure of the first branch network.
9. The method according to any one of claims 1 to 7, wherein the plurality of mainstream network models differ in model complexity, and the model complexity comprises at least one of a computational cost of the model and a parameter count of the model.
10. A method of model training, the method comprising:
acquiring a pre-trained super-network model, wherein the super-network model comprises a preset number of sub-network models;
determining a plurality of target sub-network models from the preset number of sub-network models of the super-network model;
acquiring a plurality of mainstream network models trained on open-source data;
splicing each target sub-network model, serving as a first backbone network, with a first branch network to obtain a plurality of first spliced networks, and splicing each mainstream network model, serving as a second backbone network, with a second branch network to obtain a plurality of second spliced networks;
fine-tuning and testing the plurality of first spliced networks and second spliced networks to determine a target network model;
and performing transfer learning on the target network model to obtain a required model.
11. A method of model training, the method comprising:
acquiring a pre-trained super-network model, wherein the super-network model comprises a preset number of sub-network models;
determining a plurality of target sub-network models from the preset number of sub-network models of the super-network model;
acquiring a plurality of mainstream network models trained on open-source data;
determining a target network model according to the plurality of target sub-network models and the plurality of mainstream network models;
and performing transfer learning on the target network model to obtain a required model.
12. An electronic device, comprising a memory and a processor;
the memory being configured to store a computer program;
and the processor being configured to execute the computer program and, when executing the computer program, to implement:
the model training method of any one of claims 1 to 11.
13. A computer-readable storage medium, wherein the computer-readable storage medium stores a computer program which, when executed by a processor, causes the processor to implement the model training method according to any one of claims 1 to 11.

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202011341114.4A CN114548353A (en) 2020-11-25 2020-11-25 Model training method, electronic device and storage medium

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202011341114.4A CN114548353A (en) 2020-11-25 2020-11-25 Model training method, electronic device and storage medium

Publications (1)

Publication Number Publication Date
CN114548353A (en) 2022-05-27

Family

ID=81660124

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202011341114.4A Pending CN114548353A (en) 2020-11-25 2020-11-25 Model training method, electronic device and storage medium

Country Status (1)

Country Link
CN (1) CN114548353A (en)

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115860135A (en) * 2022-11-16 2023-03-28 中国人民解放军总医院 Method, apparatus, and medium for solving heterogeneous federated learning using a super network


Similar Documents

Publication Publication Date Title
CN110837550B (en) Knowledge graph-based question answering method and device, electronic equipment and storage medium
Guo et al. Promptfl: Let federated participants cooperatively learn prompts instead of models-federated learning in age of foundation model
CN111310436B (en) Text processing method and device based on artificial intelligence and electronic equipment
WO2022134421A1 (en) Multi-knowledge graph based intelligent reply method and apparatus, computer device and storage medium
CN109062780A (en) The development approach and terminal device of automatic test cases
US11423307B2 (en) Taxonomy construction via graph-based cross-domain knowledge transfer
CN111859986B (en) Semantic matching method, device, equipment and medium based on multi-task twin network
CN110991658A (en) Model training method and device, electronic equipment and computer readable storage medium
CN112035549B (en) Data mining method, device, computer equipment and storage medium
CN111523324A (en) Training method and device for named entity recognition model
CN113435998B (en) Loan overdue prediction method and device, electronic equipment and storage medium
CN112035614B (en) Test set generation method, device, computer equipment and storage medium
CN111666393A (en) Verification method and device of intelligent question-answering system, computer equipment and storage medium
CN113190675A (en) Text abstract generation method and device, computer equipment and storage medium
CN109657125A (en) Data processing method, device, equipment and storage medium based on web crawlers
CN116956896A (en) Text analysis method, system, electronic equipment and medium based on artificial intelligence
CN113360300B (en) Interface call link generation method, device, equipment and readable storage medium
CN114548353A (en) Model training method, electronic device and storage medium
CN110516164A (en) A kind of information recommendation method, device, equipment and storage medium
CN114492742A (en) Neural network structure searching method, model issuing method, electronic device, and storage medium
CN113590786A (en) Data prediction method, device, equipment and storage medium
CN115328786A (en) Automatic testing method and device based on block chain and storage medium
CN111859985B (en) AI customer service model test method and device, electronic equipment and storage medium
CN112765481B (en) Data processing method, device, computer and readable storage medium
CN114610270A (en) AI model generation method, electronic device, and storage medium

Legal Events

Date Code Title Description
PB01 Publication
SE01 Entry into force of request for substantive examination