CN110796253A - Training method and device for generative adversarial network - Google Patents
- Publication number: CN110796253A
- Application number: CN201911058600.2A
- Authority: CN (China)
- Prior art keywords: model, parameters, current, discriminator, loss function
- Legal status: Pending (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Classifications
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/088—Non-supervised learning, e.g. competitive learning
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- General Health & Medical Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Evolutionary Computation (AREA)
- Artificial Intelligence (AREA)
- Molecular Biology (AREA)
- Computing Systems (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Mathematical Physics (AREA)
- Software Systems (AREA)
- Health & Medical Sciences (AREA)
- Data Exchanges In Wide-Area Networks (AREA)
Abstract
The embodiments of the invention provide a training method and device for a generative adversarial network (GAN). The method comprises: fixing the parameters of the discriminator model and iteratively optimizing the parameters of the generator model through a loss function until the similarity of the generator model reaches a first threshold; fixing the parameters of the generator model and iteratively optimizing the parameters of the discriminator model through the loss function until the discrimination rate of the discriminator model reaches a second threshold, where the values of both thresholds depend on the current alternation count; adding one to the alternation count; repeating these steps until the discriminator model and the generator model reach Nash equilibrium; and determining the trained GAN from the current discriminator model and the current generator model. The embodiments make the iteration time of the generator model and the discriminator model controllable, improve the training efficiency of the GAN, and avoid model collapse.
Description
Technical Field
The embodiments of the invention relate to the technical field of artificial intelligence, and in particular to a training method and device for a generative adversarial network.
Background
A generative adversarial network (GAN) is an unsupervised deep learning model developed in recent years. Its core idea is a two-player game whose two players are a generative model (the generator) and a discriminative model (the discriminator).
In the prior art, a GAN is usually trained with a fixed number of iterations.
However, a fixed number of training iterations makes training time-consuming and prone to model collapse.
Disclosure of Invention
The embodiments of the invention provide a training method and device for a generative adversarial network, which are used to improve training efficiency and avoid model collapse.
In a first aspect, an embodiment of the present invention provides a training method for a generative adversarial network, including:
fixing the parameters of a discriminator model, and iteratively optimizing the parameters of a generator model through a loss function until the similarity of the generator model reaches a first threshold, where the value of the first threshold depends on the current alternation count;
fixing the parameters of the optimized generator model, and iteratively optimizing the parameters of the discriminator model through the loss function until the discrimination rate of the discriminator model reaches a second threshold, where the value of the second threshold depends on the current alternation count;
adding one to the alternation count;
and repeating the above steps until the optimized discriminator model and the optimized generator model reach Nash equilibrium.
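The alternating procedure above can be sketched as an outer loop. This is a minimal sketch under stated assumptions, not the patent's implementation: the two phase functions, the two threshold functions, the equilibrium test and the `max_alternations` safety cap are all hypothetical callables or parameters supplied by the caller.

```python
from typing import Callable


def alternating_training(
    train_generator: Callable[[float], None],
    train_discriminator: Callable[[float], None],
    first_threshold: Callable[[int], float],
    second_threshold: Callable[[int], float],
    at_equilibrium: Callable[[], bool],
    max_alternations: int = 1000,
) -> int:
    """Alternate the two training phases and return the final alternation count.

    train_generator(t1):     optimize G with D frozen until similarity >= t1
    train_discriminator(t2): optimize D with G frozen until discrimination rate >= t2
    first/second_threshold:  map the current alternation count to a threshold
    at_equilibrium():        hypothetical stand-in for the Nash-equilibrium test
    max_alternations:        safety cap (an assumption; the text has none)
    """
    count = 0  # current alternation count
    while count < max_alternations:
        train_generator(first_threshold(count))
        train_discriminator(second_threshold(count))
        count += 1  # add one to the alternation count
        if at_equilibrium():
            break
    return count
```

Passing the thresholds as functions of the alternation count keeps the loop independent of whichever threshold schedule is chosen.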
In one possible design, fixing the parameters of the discriminator model and iteratively optimizing the parameters of the generator model through the loss function until the similarity of the generator model reaches the first threshold includes:
inputting noise data into the current generator model to obtain first false samples, and inputting the first false samples and true samples into the discriminator model to obtain a first discrimination result;
back-propagating the first discrimination result to the current generator model through the loss function, and optimizing the parameters of the current generator model;
determining the similarity of the current generator model from the first discrimination result, and comparing the similarity with the first threshold;
if the similarity is smaller than the first threshold, repeating the steps of generating first false samples from noise data, obtaining the first discrimination result, back-propagating it through the loss function to optimize the parameters of the current generator model, and recomputing and comparing the similarity, until the similarity is greater than or equal to the first threshold.
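The inner loop of this design can be sketched as follows, assuming a hypothetical `step()` callable that performs one forward pass, back-propagation and update of the unfrozen model, and returns the metric computed from that iteration's discrimination result. The `max_iters` cap is an added assumption so that the loop always terminates; the text itself loops until the threshold is met.

```python
from typing import Callable, Tuple


def optimize_until(
    step: Callable[[], float],
    threshold: float,
    max_iters: int = 10_000,
) -> Tuple[int, float]:
    """Repeat one optimization step of the unfrozen model until its metric reaches threshold.

    step() must (1) run the forward pass to get the discrimination result,
    (2) back-propagate it through the loss to update the trainable model only,
    and (3) return the metric (similarity or discrimination rate) computed from
    that same discrimination result.
    max_iters is a safety cap; it is an assumption, not part of the text.
    Returns (number of iterations performed, last metric value).
    """
    value = float("-inf")
    for i in range(1, max_iters + 1):
        value = step()
        if value >= threshold:
            return i, value
    return max_iters, value
```

The same helper would serve the discriminator phase, with `step()` returning the discrimination rate instead of the similarity.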
In one possible design, determining the similarity of the generator model from the first discrimination result includes:
determining, from the first discrimination result, a first number of first false samples judged to be true;
and calculating the ratio of the first number to the total number of first false samples, and taking this ratio as the similarity.
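As a minimal sketch, the similarity can be computed from the discriminator's outputs on the first false samples. It assumes that the discrimination result is a per-sample probability of "true" and that a sample counts as judged true when that probability is at least 0.5; the cutoff is an assumption, not stated in the text.

```python
from typing import Sequence


def similarity(d_fake: Sequence[float], cutoff: float = 0.5) -> float:
    """Fraction of first false samples that the discriminator judges to be true.

    d_fake: discriminator probabilities of "true" for each first false sample.
    cutoff: decision threshold for "judged true" (assumed value).
    """
    first_number = sum(1 for p in d_fake if p >= cutoff)  # false samples judged true
    return first_number / len(d_fake)
```

For example, outputs of 0.9, 0.2, 0.6 and 0.4 give two samples judged true out of four, so the similarity is 0.5.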
In one possible design, fixing the parameters of the optimized generator model and iteratively optimizing the parameters of the discriminator model through the loss function until the discrimination rate of the discriminator model reaches the second threshold includes:
inputting the noise data into the current generator model to obtain second false samples;
inputting the second false samples and the true samples into the current discriminator model to obtain a second discrimination result;
back-propagating the second discrimination result to the current discriminator model through the loss function, and optimizing the parameters of the current discriminator model;
determining the discrimination rate of the current discriminator model from the second discrimination result, and comparing the discrimination rate with the second threshold;
if the discrimination rate is smaller than the second threshold, repeating the steps of inputting the second false samples and the true samples into the current discriminator model to obtain a second discrimination result, back-propagating it through the loss function to optimize the parameters of the current discriminator model, and recomputing and comparing the discrimination rate, until the discrimination rate is greater than or equal to the second threshold.
In one possible design, before fixing the parameters of the discriminator model and iteratively optimizing the parameters of the generator model through the loss function, the method includes:
constructing a loss function;
regularizing the loss function based on the L1/2, L1 or L2 norm to obtain a regularized loss function;
iteratively optimizing the parameters of the generator model through the loss function then includes:
iteratively optimizing the parameters of the generator model through the regularized loss function;
and iteratively optimizing the parameters of the discriminator model through the loss function includes:
iteratively optimizing the parameters of the discriminator model through the regularized loss function.
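The three candidate penalties can be sketched as below. The function names, the weight `lam`, and the use of the squared L2 norm are assumptions for illustration; the text only names the L1/2, L1 and L2 norms. The L1/2 penalty is written as the sum of square roots of the absolute parameter values, its usual form in L1/2 regularization.

```python
from typing import Callable, Dict, Sequence


def l_half_penalty(theta: Sequence[float]) -> float:
    """L1/2 penalty: sum_i |theta_i|^(1/2)."""
    return sum(abs(t) ** 0.5 for t in theta)


def l1_penalty(theta: Sequence[float]) -> float:
    """L1 penalty: sum_i |theta_i|."""
    return sum(abs(t) for t in theta)


def l2_penalty(theta: Sequence[float]) -> float:
    """Squared L2 penalty: sum_i theta_i^2 (squared form assumed)."""
    return sum(t * t for t in theta)


PENALTIES: Dict[str, Callable[[Sequence[float]], float]] = {
    "l1/2": l_half_penalty,
    "l1": l1_penalty,
    "l2": l2_penalty,
}


def regularized_loss(base_loss: float, theta: Sequence[float], lam: float, norm: str = "l1/2") -> float:
    """Base loss plus lam * Omega(theta); lam is a hypothetical weight."""
    return base_loss + lam * PENALTIES[norm](theta)
```

The L1/2 penalty is not convex, which is why it tends to produce sparser parameters than L1 but also makes optimization harder.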
In one possible design, iteratively optimizing the parameters of the generator model through the regularized loss function includes:
iteratively optimizing the parameters of the generator model with a gradient descent algorithm applied to the regularized loss function.
In one possible design, the gradient descent algorithm is the non-monotone Barzilai-Borwein gradient algorithm.
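The text does not detail the non-monotone Barzilai-Borwein method. One common formulation, in the style of Raydan's global Barzilai-Borwein method, pairs the BB step size with a Grippo-Lampariello-Lucidi non-monotone backtracking line search. The sketch below assumes that formulation and applies it to a generic differentiable function rather than to a GAN loss; the memory length, sufficient-decrease constant, initial step and fallback step are assumed values.

```python
from collections import deque
from typing import Callable

import numpy as np


def nonmonotone_bb(
    f: Callable[[np.ndarray], float],
    grad: Callable[[np.ndarray], np.ndarray],
    x0: np.ndarray,
    max_iter: int = 500,
    memory: int = 10,
    gamma: float = 1e-4,
    tol: float = 1e-8,
) -> np.ndarray:
    """Minimize f with Barzilai-Borwein steps and a non-monotone line search."""
    x = np.asarray(x0, dtype=float).copy()
    g = grad(x)
    alpha = 1.0                                  # initial trial step (assumption)
    recent = deque([f(x)], maxlen=memory)        # last `memory` objective values
    for _ in range(max_iter):
        if np.linalg.norm(g) < tol:
            break
        # Grippo-Lampariello-Lucidi test: require sufficient decrease relative to
        # the worst of the recent values, so f may rise on some iterations.
        f_ref = max(recent)
        step, gg = alpha, float(g @ g)
        while f(x - step * g) > f_ref - gamma * step * gg:
            step *= 0.5                          # backtrack
        x_new = x - step * g
        g_new = grad(x_new)
        s, y = x_new - x, g_new - g
        sy = float(s @ y)
        # BB1 step size s^T s / s^T y; fall back to 1.0 if curvature is not positive.
        alpha = float(s @ s) / sy if sy > 0 else 1.0
        x, g = x_new, g_new
        recent.append(f(x))
    return x
```

The non-monotone test lets the objective rise temporarily, which is what allows the large, irregular BB steps to be accepted instead of being cut back by a strict descent condition.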
In a second aspect, an embodiment of the present invention provides a training device for a generative adversarial network, including:
a generator optimization module, configured to fix the parameters of the discriminator model and iteratively optimize the parameters of the generator model through a loss function until the similarity of the generator model reaches a first threshold, where the value of the first threshold depends on the current alternation count;
a discriminator optimization module, configured to fix the parameters of the optimized generator model and iteratively optimize the parameters of the discriminator model through the loss function until the discrimination rate of the discriminator model reaches a second threshold, where the value of the second threshold depends on the current alternation count;
an addition module, configured to add one to the alternation count;
and an execution module, configured to repeat the steps performed by the generator optimization module, the discriminator optimization module and the addition module until the optimized discriminator model and the optimized generator model reach Nash equilibrium.
In a third aspect, an embodiment of the present invention provides a training device for a generative adversarial network, including at least one processor and a memory;
the memory stores computer-executable instructions;
the at least one processor executes the computer-executable instructions stored in the memory, so that the at least one processor performs the method according to the first aspect and its various possible designs.
In a fourth aspect, an embodiment of the present invention provides a computer-readable storage medium storing computer-executable instructions which, when executed by a processor, implement the method according to the first aspect and its various possible designs.
In the training method and device for a generative adversarial network provided by these embodiments, the parameters of the discriminator model are fixed and the parameters of the generator model are iteratively optimized through a loss function until the similarity of the generator model reaches a first threshold; the parameters of the optimized generator model are then fixed and the parameters of the discriminator model are iteratively optimized through the loss function until the discrimination rate of the discriminator model reaches a second threshold; both thresholds depend on the current alternation count. The alternation count is then increased by one, and these steps are repeated until the optimized discriminator model and the optimized generator model reach Nash equilibrium, after which the trained GAN is determined from the current discriminator model and the current generator model. Because the number of optimization iterations within each alternation is governed, for the generator model and the discriminator model respectively, by the first and second thresholds, which depend on the current alternation count, the iteration time of both models is controllable, the training efficiency of the GAN is improved, and model collapse is avoided.
Drawings
In order to more clearly illustrate the embodiments of the present invention or the technical solutions in the prior art, the drawings needed to be used in the description of the embodiments or the prior art will be briefly introduced below, and it is obvious that the drawings in the following description are some embodiments of the present invention, and for those skilled in the art, other drawings can be obtained according to these drawings without creative efforts.
FIG. 1a is a diagram of a system architecture for training and optimizing a generator model according to an embodiment of the present invention;
FIG. 1b is a diagram of a system architecture for training and optimizing a discriminator model according to another embodiment of the present invention;
FIG. 2 is a flowchart of a training method for a generative adversarial network according to an embodiment of the present invention;
FIG. 3 is a diagram of a system architecture for training and optimizing a generator model according to yet another embodiment of the present invention;
FIG. 4 is a diagram of a system architecture for training and optimizing a discriminator model according to yet another embodiment of the present invention;
FIG. 5 is a flowchart of a training method for a generative adversarial network according to another embodiment of the present invention;
FIG. 6 is a schematic structural diagram of a training device for a generative adversarial network according to an embodiment of the present invention;
FIG. 7 is a schematic structural diagram of a training device for a generative adversarial network according to another embodiment of the present invention;
FIG. 8 is a schematic hardware structure diagram of a training device for a generative adversarial network according to an embodiment of the present invention.
Detailed Description
In order to make the objects, technical solutions and advantages of the embodiments of the present invention clearer, the technical solutions in the embodiments of the present invention will be clearly and completely described below with reference to the drawings in the embodiments of the present invention, and it is obvious that the described embodiments are some, but not all, embodiments of the present invention. All other embodiments, which can be derived by a person skilled in the art from the embodiments given herein without making any creative effort, shall fall within the protection scope of the present invention.
FIG. 1a is a diagram of a system architecture for training and optimizing a generator model according to an embodiment of the present invention; FIG. 1b is a diagram of a system architecture for training and optimizing a discriminator model according to another embodiment of the present invention. As shown in FIG. 1a, the generator model optimization system includes a generator model 101, a discriminator model 102 and a loss function 103. The generator model 101 receives noise data, generates first false samples and outputs them to the discriminator model 102; the discriminator model 102 receives the first false samples and true samples and outputs a first discrimination result, and the parameters of the generator model 101 are optimized through the loss function 103. As shown in FIG. 1b, the discriminator model optimization system likewise includes the generator model 101, the discriminator model 102 and the loss function 103. The generator model 101 receives noise data, generates second false samples and outputs them to the discriminator model 102; the discriminator model 102 receives the second false samples and the true samples and outputs a second discrimination result, and the parameters of the discriminator model 102 are optimized through the loss function 103.
In a specific implementation, the generator model and the discriminator model are trained alternately. When training the generator model, the generator model 101 generates first false samples from noise data; the discriminator model 102 receives the first false samples and produces a first discrimination result from them and the true samples; the first discrimination result is back-propagated to the generator model 101 through the loss function 103, and the parameters of the generator model 101 are optimized. This completes the first optimization at the current alternation count. For the second optimization, the current generator model 101 with the updated parameters again generates first false samples from the noise data, a first discrimination result is produced from the first false samples and the true samples, and it is back-propagated to the generator model 101 through the loss function 103 to optimize its parameters, completing the second optimization at the current alternation count. The generator model 101 is optimized several times in this way, after which the discriminator model 102 is optimized. During the discriminator optimization, the generator model 101 is the model obtained after the multiple optimizations at the current alternation count. The generator model 101 generates second false samples from noise data, the discriminator model 102 produces a second discrimination result from the second false samples and the true samples, and the second discrimination result is back-propagated into the discriminator model 102 through the loss function 103 to optimize its parameters; this is the first optimization of the discriminator model 102 at the current alternation count. The second optimization uses the updated discriminator model 102, and so on. After multiple optimizations, the generator model is optimized at the next alternation count.
Therefore, in the alternating iterative optimization, the number of optimizations of the generator model and the discriminator model at each alternation count strongly affects the training time of the GAN. In the prior art, these numbers are always fixed, so the training time of the GAN is long and uncontrollable, and model collapse easily occurs during training. On this basis, this embodiment provides a training method for a GAN that improves training efficiency and avoids model collapse.
The technical solution of the present invention will be described in detail below with specific examples. The following several specific embodiments may be combined with each other, and details of the same or similar concepts or processes may not be repeated in some embodiments.
FIG. 2 is a flowchart of a training method for a generative adversarial network according to an embodiment of the present invention.
As shown in FIG. 2, the method includes:
201. Fix the parameters of the discriminator model and iteratively optimize the parameters of the generator model through a loss function until the similarity of the generator model reaches a first threshold, obtaining an optimized generator model; the value of the first threshold depends on the current alternation count.
Optionally, the similarity is related to the number of false samples judged to be true by the discriminator. Specifically, the similarity may be calculated by determining, from the first discrimination result, a first number of first false samples judged to be true,
and calculating the ratio of the first number to the total number of first false samples, which is taken as the similarity.
Optionally, the value of the first threshold increases with the current alternation count.
Because of the complexity of the GAN model, training is time-consuming and the model easily collapses, with most of the time spent in this alternating stage. The number of iterations chosen for each network in the alternating iteration is therefore particularly important. The training method provided by this embodiment uses a threshold-based adaptive alternating-iteration strategy to address the long training time and frequent collapse of conventional alternating iteration. Its adaptive threshold function is defined by formula (1):
where y is the value of the first threshold and n is the current number of network optimizations; as the number of optimizations increases, the threshold rises and the accuracy of the network approaches 100% asymptotically.
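Formula (1) itself is not reproduced in this text, so its exact form is unknown here. To illustrate only the described behaviour (a threshold that rises with n and approaches 100%), the sketch below uses a hypothetical saturating exponential; `y0`, `y_max` and `tau` are invented parameters, not values from the patent.

```python
import math


def adaptive_threshold(n: int, y0: float = 0.5, y_max: float = 1.0, tau: float = 20.0) -> float:
    """Hypothetical stand-in for formula (1).

    Starts at y0 when n = 0, increases strictly with n, and approaches y_max
    (100% when y_max = 1.0) without reaching it; tau sets how fast it rises.
    """
    return y_max - (y_max - y0) * math.exp(-n / tau)
```

Because a threshold of this shape only approaches 1 asymptotically, late alternations can require many inner iterations, so pairing it with an iteration cap is prudent.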
Optionally, step 201 may specifically include:
2011. Input noise data into the current generator model to obtain first false samples, and input the first false samples and true samples into the discriminator model to obtain a first discrimination result.
2012. Back-propagate the first discrimination result to the current generator model through the loss function, and optimize the parameters of the current generator model.
2013. Determine the similarity of the current generator model from the first discrimination result, and compare the similarity with the first threshold.
2014. If the similarity is smaller than the first threshold, repeat steps 2011 to 2013 until the similarity is greater than or equal to the first threshold.
In practical application, FIG. 3 is a diagram of a system architecture for training and optimizing a generator model according to yet another embodiment of the present invention. As shown in FIG. 3, the generator model 101 generates third false samples from the noise data; the discriminator model 102 receives the third false samples and the noise data and obtains a third discrimination result; the discriminator model 102 also computes a first absolute loss from the third false samples and the true samples; and the loss function 103 optimizes the parameters of the generator model 101 according to the third discrimination result and the first absolute loss. The comparison module 104 determines the similarity of the current generator model from the third discrimination result and compares it with the first threshold. If the similarity is smaller than the first threshold, optimization of the generator parameters continues with the same steps (the generator model 101 generates third false samples from the noise data, the discriminator model 102 obtains a third discrimination result and an absolute loss, the loss function 103 optimizes the parameters of the generator model 101 from them, and the comparison module 104 recomputes the similarity and compares it with the first threshold) until the similarity is greater than or equal to the first threshold. Once the similarity is greater than or equal to the first threshold, the iterative optimization of the generator model at this alternation count ends, and the step of optimizing the discriminator model may be performed.
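The text does not say how the loss function 103 combines the discrimination result with the absolute loss. A minimal sketch, assuming a weighted sum of the non-saturating adversarial generator loss and the mean absolute (L1) difference between paired false and true samples, is shown below; the weight `lam`, the clipping constant `eps` and the pairing of samples are assumptions.

```python
import numpy as np


def generator_loss(d_fake, fake, real, lam: float = 1.0, eps: float = 1e-12) -> float:
    """Adversarial term plus a weighted absolute-loss term (hypothetical combination).

    d_fake: discriminator probabilities that the false samples are true.
    fake, real: paired false and true samples of the same shape (pairing assumed).
    """
    d_fake = np.clip(np.asarray(d_fake, dtype=float), eps, 1.0)
    adversarial = -np.mean(np.log(d_fake))  # non-saturating GAN generator loss
    absolute = np.mean(np.abs(np.asarray(fake, dtype=float) - np.asarray(real, dtype=float)))
    return float(adversarial + lam * absolute)
```

When the discriminator is fully fooled (all outputs 1.0) the adversarial term vanishes and only the absolute loss remains.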
202. Fix the parameters of the optimized generator model and iteratively optimize the parameters of the discriminator model through the loss function until the discrimination rate of the discriminator model reaches a second threshold, obtaining an optimized discriminator model; the value of the second threshold depends on the current alternation count.
Optionally, the discrimination rate is related to the total number of samples judged to be true by the discriminator and the number of input true samples. Specifically, the discrimination rate may be calculated by determining, from the second discrimination result, a second number of samples judged to be true,
and calculating the ratio of the number of true samples to the second number, which is taken as the discrimination rate.
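A literal sketch of this definition follows. It assumes per-sample probabilities of "true" with a 0.5 cutoff, and returns 0.0 when no sample is judged true (an assumption, since the ratio is undefined there). Read literally, the ratio equals 1 when the number of samples judged true matches the number of true samples, and exceeds 1 when fewer samples are judged true than there are true samples.

```python
from typing import Sequence


def discrimination_rate(d_real: Sequence[float], d_fake: Sequence[float], cutoff: float = 0.5) -> float:
    """Number of input true samples divided by the number of samples judged true.

    d_real, d_fake: discriminator probabilities of "true" for the true samples
    and for the second false samples; cutoff is an assumed decision threshold.
    """
    second_number = sum(p >= cutoff for p in d_real) + sum(p >= cutoff for p in d_fake)
    if second_number == 0:
        return 0.0  # assumption: treat "nothing judged true" as a failed discriminator
    return len(d_real) / second_number
```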
Optionally, the value of the second threshold increases with the current alternation count and may follow formula (1).
Optionally, step 202 may specifically include:
2021. Input the noise data into the current generator model to obtain second false samples;
2022. Input the second false samples and the true samples into the current discriminator model to obtain a second discrimination result;
2023. Back-propagate the second discrimination result to the current discriminator model through the loss function, and optimize the parameters of the current discriminator model;
2024. Determine the discrimination rate of the current discriminator model from the second discrimination result, and compare the discrimination rate with the second threshold;
2025. If the discrimination rate is smaller than the second threshold, repeat steps 2022 to 2024 until the discrimination rate is greater than or equal to the second threshold.
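Steps 2021 to 2025 can be illustrated on a toy one-dimensional problem. In this sketch, which is an illustration rather than the patent's model, the frozen generator's second false samples are drawn once from N(0, 1), the true samples come from N(2, 1), the discriminator is a logistic model D(x) = sigmoid(w*x + b) trained by plain gradient descent on binary cross-entropy, and the discrimination rate follows the literal definition given above. The data, learning rate and iteration cap are assumptions.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def train_discriminator_phase(real, fake, threshold, lr=0.5, max_iters=1000):
    """Toy version of steps 2022-2025 with a logistic discriminator.

    real: true samples; fake: second false samples from the frozen generator
    (step 2021, generated once). Returns (w, b, iterations, discrimination rate).
    """
    x = np.concatenate([real, fake])
    y = np.concatenate([np.ones_like(real), np.zeros_like(fake)])
    w, b, rate = 0.0, 0.0, 0.0
    for it in range(1, max_iters + 1):
        p = sigmoid(w * x + b)              # 2022: second discrimination result
        # 2023: binary cross-entropy gradient, applied to D's parameters only
        w -= lr * np.mean((p - y) * x)
        b -= lr * np.mean(p - y)
        # 2024: true-sample count divided by the number of samples judged true
        judged_true = int(np.count_nonzero(p >= 0.5))
        rate = len(real) / judged_true if judged_true else 0.0
        if rate >= threshold:               # 2025: stop once the threshold is met
            break
    return w, b, it, rate


rng = np.random.default_rng(0)
real = rng.normal(2.0, 1.0, 500)
fake = rng.normal(0.0, 1.0, 500)            # frozen generator's output (toy)
w, b, iters, rate = train_discriminator_phase(real, fake, threshold=0.8)
```

At the start every sample is judged true, so the rate is 0.5; it rises as the decision boundary moves toward the midpoint between the two distributions.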
In practical application, FIG. 4 is a diagram of a system architecture for training and optimizing a discriminator model according to yet another embodiment of the present invention. As shown in FIG. 4, the generator model 101 generates fourth false samples from the noise data; the discriminator model 102 receives the fourth false samples and the noise data and obtains a fourth discrimination result; the discriminator model 102 also computes a second absolute loss from the fourth false samples and the true samples; and the loss function 103 optimizes the parameters of the discriminator model 102 according to the fourth discrimination result and the second absolute loss. The comparison module 104 determines the discrimination rate of the current discriminator model 102 from the fourth discrimination result and compares it with the second threshold. If the discrimination rate is smaller than the second threshold, the iterative optimization of the discriminator model continues with the same steps (generating fourth false samples, obtaining the fourth discrimination result and the second absolute loss, optimizing the parameters of the discriminator model 102 through the loss function 103, and recomputing and comparing the discrimination rate) until the discrimination rate is greater than or equal to the second threshold, at which point the iterative optimization of the discriminator model ends.
After the iterative optimization of the generator model and of the discriminator model is finished, the alternation count is increased by one and it is judged whether the generator model and the discriminator model have reached Nash equilibrium. If so, the alternating iterative optimization ends and the final GAN is produced; if not, the iterative optimization of the generator model and the discriminator model continues.
203. Add one to the alternation count.
204. Repeat steps 201 to 203 until the optimized discriminator model and the optimized generator model reach Nash equilibrium.
205. Determine the trained GAN from the current discriminator model and the current generator model.
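The text does not specify how Nash equilibrium is detected. One heuristic, assumed here, follows from GAN theory: at the game's equilibrium the generator distribution matches the data distribution and the optimal discriminator outputs 1/2 everywhere, so the mean discriminator outputs on true and false samples should both be near 0.5. The function name and the tolerance `tol` are arbitrary choices.

```python
from statistics import fmean
from typing import Sequence


def near_nash_equilibrium(d_real: Sequence[float], d_fake: Sequence[float], tol: float = 0.05) -> bool:
    """Heuristic equilibrium test: mean D outputs on true and false samples both near 1/2."""
    return abs(fmean(d_real) - 0.5) <= tol and abs(fmean(d_fake) - 0.5) <= tol
```

This is a necessary-looking symptom rather than a proof of equilibrium: a discriminator that simply outputs 0.5 for everything would also pass, so in practice it would be combined with checks on sample quality.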
In practice, experimental data for the true samples is obtained first (step 1); optionally, the true samples may be taken from the public MNIST dataset, and for uniform formatting all images may be resized to 128 × 128 pixels, although this embodiment is not limited thereto. Next (step 2), noise data following a Gaussian distribution is selected as the basic data from which the generator produces false samples. A loss function is then constructed, and the generator model and the discriminator model are trained alternately and iteratively according to it. In the first stage, the discriminator D is kept unchanged and the noise data from step 2 is fed into the generator G; the output of G and the real data obtained in step 1 are fed into D, and G is updated continuously from the parameters and results returned by D so as to deceive D as much as possible, i.e., to lower the accuracy of D. FIG. 5 is an iterative schematic of the generator and discriminator in this stage. In the second stage, the generator G is kept unchanged, the pseudo data generated by G and the real data are fed into D, and D is updated continuously to improve the accuracy of D. The two stages are executed alternately until D and G both reach the optimal convergence value, i.e., Nash equilibrium, at which point the iteration ends and the trained GAN is obtained.
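The data preparation described above can be sketched with NumPy. The sketch assumes MNIST images have already been loaded as an unsigned 8-bit array of shape (N, 28, 28) (loading is omitted), scales them to [-1, 1] (a common convention for tanh-output generators, not stated in the text), resizes them to 128 × 128 by nearest-neighbour sampling, and draws standard Gaussian noise; `latent_dim=100` and both function names are hypothetical.

```python
import numpy as np


def upscale_nearest(images: np.ndarray, size: int = 128) -> np.ndarray:
    """Nearest-neighbour resize of a batch (N, H, W) to (N, size, size)."""
    _, h, w = images.shape
    rows = (np.arange(size) * h) // size  # source row for each output row
    cols = (np.arange(size) * w) // size  # source column for each output column
    return images[:, rows[:, None], cols[None, :]]


def prepare_batch(images_uint8: np.ndarray, latent_dim: int = 100, rng=None):
    """Return (true samples scaled to [-1, 1] at 128x128, Gaussian noise for G)."""
    rng = np.random.default_rng() if rng is None else rng
    real = upscale_nearest(images_uint8.astype(np.float32) / 127.5 - 1.0)
    noise = rng.standard_normal((len(images_uint8), latent_dim)).astype(np.float32)
    return real, noise
```

Nearest-neighbour sampling is used only to keep the sketch dependency-free; an image library's bilinear resize would give smoother 128 × 128 inputs.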
In the training method for a generative adversarial network provided by this embodiment, the number of optimization iterations within the current alternation is controlled during the iterative optimization of the generator model and of the discriminator model. Two thresholds do this: a first threshold for the generator and a second threshold for the discriminator, both of which depend on the current number of alternations. This makes the iteration time of the generator model and the discriminator model controllable. It also improves the training efficiency of the generative adversarial model and avoids model collapse.
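The alternating scheme of steps 201 to 205 can be sketched as follows, with inner loops bounded by alternation-dependent thresholds. This is a minimal illustration, not the patented implementation.

- `generator_step`, `discriminator_step`, and `is_equilibrium` are hypothetical callables. They stand in for one update of G with D frozen (returning the current similarity), one update of D with G frozen (returning the current discrimination rate), and whatever Nash-equilibrium test is used.
- The `proportional_threshold` schedule follows claim 7, where the first threshold is proportional to the current number of alternations.
- The `cap` and the `max_inner_steps` / `max_alternations` limits are assumptions added here, because a ratio-valued threshold cannot usefully exceed 1.

```python
from typing import Callable, Tuple


def proportional_threshold(k: float, cap: float) -> Callable[[int], float]:
    """Threshold proportional to the current alternation count, clipped at `cap` (cap is an assumption)."""
    return lambda alternation: min(cap, k * alternation)


def train_gan_alternating(
    generator_step: Callable[[], float],       # hypothetical: one G update with D frozen, returns similarity
    discriminator_step: Callable[[], float],   # hypothetical: one D update with G frozen, returns discrimination rate
    first_threshold: Callable[[int], float],   # alternation count -> similarity target
    second_threshold: Callable[[int], float],  # alternation count -> discrimination-rate target
    is_equilibrium: Callable[[], bool],        # hypothetical Nash-equilibrium test
    max_inner_steps: int = 1000,               # safety cap on inner iterations (assumption)
    max_alternations: int = 100,               # safety cap on alternations (assumption)
) -> Tuple[int, bool]:
    """Return (number of completed alternations, whether equilibrium was reached)."""
    alternation = 1  # 1-based index of the current alternation
    while alternation <= max_alternations:
        # Stage 1: D fixed, update G until similarity >= first threshold for this alternation.
        target_g = first_threshold(alternation)
        for _ in range(max_inner_steps):
            if generator_step() >= target_g:
                break
        # Stage 2: G fixed, update D until discrimination rate >= second threshold.
        target_d = second_threshold(alternation)
        for _ in range(max_inner_steps):
            if discriminator_step() >= target_d:
                break
        alternation += 1  # add one to the number of alternations
        if is_equilibrium():
            return alternation - 1, True
    return max_alternations, False
```

Passing callables keeps the control flow independent of any particular deep-learning framework. In a real setup, each step function would wrap a forward pass, the loss computation, and an optimizer update.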
Fig. 5 is a flowchart of a training method for a generative adversarial network according to another embodiment of the present invention. Building on the above embodiment, this embodiment describes the optimization of the loss function in detail. As shown in Fig. 5, the method includes:
501. Construct a loss function.
502. Regularize the loss function based on the L1/2, L1, or L2 norm to obtain a regularized loss function.
Specifically, a loss function is constructed as shown in equations (2) and (3). The loss function is a minimax game objective function:
Transforming equation (2) yields the constructed loss function shown in equation (3):
The objective function in equation (3) is prone to vanishing gradients during optimization, which makes training of the generative adversarial network highly unstable. A regularization penalty term Ω(θ) is therefore added to the objective function. The regularization may be based on the L1 norm, the L2 norm, or the L1/2 norm; the L1/2 norm is used as the example below. In the objective optimization, the generator loss G_LOSS and the discriminator loss D_LOSS are optimized jointly so as to ensure a globally optimal solution of the model. This avoids the problem of a conventional GAN, which processes G_LOSS and D_LOSS independently and therefore reaches only local optima and suffers model collapse. The regularized objective function is as follows:
where μ and λ are non-negative coefficients that adjust the weights of the penalty term and the loss function. The objective thus becomes a constrained L1/2-norm minimization problem.
Unlike the L1 norm, the L1/2 norm yields a non-convex objective function, which increases the complexity of the model.
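The three penalty options named in step 502 can be sketched minimally as follows, assuming the parameters θ are a flat list of floats.

- The L1/2 penalty is taken here as the sum of |θ_i|^(1/2), the usual form in L1/2 regularization.
- The regularized objective is not reproduced in this copy, so the way μ and λ enter it cannot be recovered. The sketch therefore uses hypothetical `loss_weight` and `penalty_weight` parameters instead.
- The midpoint check illustrates the non-convexity noted above: |t|^(1/2) violates the convexity inequality between 0 and 1, whereas |t| and t² do not.

```python
import math
from typing import Callable, Dict, Sequence


def l_half_penalty(theta: Sequence[float]) -> float:
    """L1/2 penalty: sum of |theta_i| ** 0.5 (non-convex)."""
    return sum(math.sqrt(abs(t)) for t in theta)


def l1_penalty(theta: Sequence[float]) -> float:
    """L1 penalty: sum of |theta_i| (convex)."""
    return sum(abs(t) for t in theta)


def l2_penalty(theta: Sequence[float]) -> float:
    """Squared L2 penalty: sum of theta_i ** 2 (convex)."""
    return sum(t * t for t in theta)


PENALTIES: Dict[str, Callable[[Sequence[float]], float]] = {
    "l1/2": l_half_penalty,
    "l1": l1_penalty,
    "l2": l2_penalty,
}


def regularized_objective(loss: float, theta: Sequence[float], norm: str = "l1/2",
                          loss_weight: float = 1.0, penalty_weight: float = 0.1) -> float:
    """Hypothetical weighting: loss_weight * loss + penalty_weight * Omega(theta)."""
    return loss_weight * loss + penalty_weight * PENALTIES[norm](theta)


def violates_midpoint_convexity(f: Callable[[float], float], a: float, b: float) -> bool:
    """True if f((a + b) / 2) > (f(a) + f(b)) / 2, which rules out convexity of f."""
    return f((a + b) / 2) > (f(a) + f(b)) / 2
```

For instance, `violates_midpoint_convexity(lambda t: math.sqrt(abs(t)), 0.0, 1.0)` is `True`, since √0.5 ≈ 0.707 exceeds the chord value 0.5.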
Therefore, in practice, let
Then the model of equation (3) is converted into:
The constrained minimization problem of the objective function in equation (6) is solved with an ordinary projected gradient algorithm. To increase the convergence speed, the problem may, for example, be solved with a non-monotone Barzilai-Borwein gradient algorithm. In this algorithm, a non-monotone search strategy determines a suitable iteration direction, which improves the convergence speed, and a Barzilai-Borwein step size is constructed at each iteration, which ensures global convergence. The core model of the method is:
$$x_{k+1} = x_k + S_k D_k \qquad (7)$$
where $S_k$ is the Barzilai-Borwein iteration step size and $D_k$ is the search direction determined by the non-monotone strategy.
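The combination described above matches the spectral projected gradient (SPG) method: a projected gradient step, a Barzilai-Borwein step size, and a non-monotone line search. It is sketched below as one plausible reading of equation (7); it is not the authors' code. The sketch rests on these assumptions:

- The feasible set is a simple box [lo, hi]; the actual constraint set of equation (6) is not reproduced here.
- The non-monotone test is the Grippo-Lampariello-Lucidi rule over the last `memory` objective values.
- The BB step is the first Barzilai-Borwein formula sᵀs / sᵀy.

In this sketch the BB step sets how far the projected-gradient trial point moves, D_k is that trial point minus x_k, and the line-search factor `alpha` scales D_k. The mapping onto S_k in equation (7) is therefore approximate.

```python
from typing import Callable, List

Vector = List[float]


def project_box(x: Vector, lo: float, hi: float) -> Vector:
    """Euclidean projection onto the box [lo, hi]^n (assumed feasible set)."""
    return [min(max(xi, lo), hi) for xi in x]


def dot(a: Vector, b: Vector) -> float:
    return sum(ai * bi for ai, bi in zip(a, b))


def spg_minimize(f: Callable[[Vector], float], grad: Callable[[Vector], Vector],
                 x0: Vector, lo: float, hi: float, max_iter: int = 200,
                 memory: int = 5, gamma: float = 1e-4, tol: float = 1e-8) -> Vector:
    x = project_box(x0, lo, hi)
    g = grad(x)
    bb_step = 1.0                 # initial BB step (assumption)
    history = [f(x)]              # recent objective values for the non-monotone test
    for _ in range(max_iter):
        # Search direction D_k: projected BB-gradient point minus the current iterate.
        trial = project_box([xi - bb_step * gi for xi, gi in zip(x, g)], lo, hi)
        d = [ti - xi for ti, xi in zip(trial, x)]
        if max(abs(di) for di in d) < tol:
            break                 # x is (approximately) stationary on the box
        # Non-monotone line search: accept relative to the worst of the last `memory` values.
        f_ref = max(history[-memory:])
        gd = dot(g, d)            # negative whenever d != 0
        alpha = 1.0
        x_new = [xi + alpha * di for xi, di in zip(x, d)]
        while f(x_new) > f_ref + gamma * alpha * gd and alpha > 1e-10:
            alpha *= 0.5
            x_new = [xi + alpha * di for xi, di in zip(x, d)]
        g_new = grad(x_new)
        s = [a - b for a, b in zip(x_new, x)]
        y = [a - b for a, b in zip(g_new, g)]
        sy = dot(s, y)
        # Barzilai-Borwein step s^T s / s^T y, safeguarded into [1e-10, 1e10].
        bb_step = min(max(dot(s, s) / sy, 1e-10), 1e10) if sy > 0 else 1e10
        x, g = x_new, g_new
        history.append(f(x))
    return x
```

On the separable quadratic (x0 - 3)² + (x1 + 2)² over the box [0, 1]², the sketch returns the corner [1.0, 0.0]. Because the line search compares against the worst recent value rather than the last one, occasional increases in the objective are tolerated, which is what "non-monotone" refers to.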
503. Fix the parameters of the discriminator model and iteratively optimize the parameters of the generator model through the regularized loss function until the similarity of the generator model reaches a first threshold, to obtain an optimized generator model. The value of the first threshold is related to the current number of alternations.
504. Fix the parameters of the optimized generator model and iteratively optimize the parameters of the discriminator model through the regularized loss function until the discrimination rate of the discriminator model reaches a second threshold, to obtain an optimized discriminator model. The value of the second threshold is related to the current number of alternations.
505. The number of alternations is incremented by one.
506. Repeat the above steps until the optimized discriminator model and the optimized generator model reach Nash equilibrium.
507. Determine the trained generative adversarial network from the current discriminator model and the current generator model.
Steps 503 to 506 in this embodiment are similar to steps 201 to 204 in the above embodiment, and are not described again here.
In the training method for a generative adversarial network provided by this embodiment, regularizing the loss function based on the L1/2, L1, or L2 norm enhances the stability of training. Using the ordinary projected gradient algorithm accelerates the training process, further improving efficiency.
Fig. 6 is a schematic structural diagram of a training device for a generative adversarial network according to an embodiment of the present invention. As shown in Fig. 6, the training device 60 for a generative adversarial network includes an execution module 601 and a determination module 602.
The execution module 601 is configured to: fix the parameters of the discriminator model and iteratively optimize the parameters of the generator model through a loss function until the similarity of the generator model reaches a first threshold, the value of the first threshold being related to the current number of alternations;
fix the parameters of the optimized generator model and iteratively optimize the parameters of the discriminator model through the loss function until the discrimination rate of the discriminator model reaches a second threshold, the value of the second threshold being related to the current number of alternations;
increment the number of alternations by one; and
repeat the above steps until the optimized discriminator model and the optimized generator model reach Nash equilibrium.
The determination module 602 is configured to determine the trained generative adversarial network from the current discriminator model and the current generator model.
The training device for a generative adversarial network provided by this embodiment works as follows. The execution module:

- fixes the parameters of the discriminator model and iteratively optimizes the parameters of the generator model through the loss function until the similarity of the generator model reaches the first threshold, whose value is related to the current number of alternations;
- fixes the parameters of the optimized generator model and iteratively optimizes the parameters of the discriminator model through the loss function until the discrimination rate of the discriminator model reaches the second threshold, whose value is also related to the current number of alternations;
- increments the number of alternations by one; and
- repeats these steps until the optimized discriminator model and the optimized generator model reach Nash equilibrium.

The determination module 602 then determines the trained generative adversarial network from the current discriminator model and the current generator model. The number of optimization iterations within each alternation is controlled by the first and second thresholds, both of which are related to the current number of alternations. As a result, the iteration time of the generator model and the discriminator model is controllable, the training efficiency of the generative adversarial model is improved, and model collapse is avoided.
Fig. 7 is a schematic structural diagram of a training device for a generative adversarial network according to another embodiment of the present invention. As shown in Fig. 7, the training device 60 for a generative adversarial network further includes a construction module 605 and a regularization module 606.
Optionally, the execution module 601 is specifically configured to: input noise data into the current generator model to obtain first false samples, and input the first false samples and true samples into the discriminator model to obtain a first discrimination result;
back-propagate the first discrimination result to the current generator model through the loss function to optimize the parameters of the current generator model;
determine the similarity of the current generator model from the first discrimination result and compare the similarity with the first threshold; and
if the similarity is less than the first threshold, repeat the above steps until the similarity is greater than or equal to the first threshold. These steps are: inputting noise data into the current generator model to obtain first false samples, inputting the first false samples and true samples into the discriminator model to obtain a first discrimination result, back-propagating the first discrimination result through the loss function to optimize the parameters of the current generator model, and determining the similarity from the first discrimination result and comparing it with the first threshold.
Optionally, the execution module 601 is specifically configured to:
determine, from the first discrimination result, a first number of first false samples that are judged to be true; and
calculate the ratio of the first number to the total number of first false samples and use this ratio as the similarity.
Optionally, the execution module 601 is specifically configured to:
input the noise data into the current generator model to obtain second false samples;
input the second false samples and the true samples into the current discriminator model to obtain a second discrimination result;
back-propagate the second discrimination result to the current discriminator model through the loss function to optimize the parameters of the current discriminator model;
determine the discrimination rate of the current discriminator model from the second discrimination result and compare the discrimination rate with the second threshold; and
if the discrimination rate is less than the second threshold, repeat the following steps until the discrimination rate is greater than or equal to the second threshold: inputting the second false samples and the true samples into the current discriminator model to obtain a second discrimination result, back-propagating it through the loss function to optimize the parameters of the current discriminator model, and determining the discrimination rate from the second discrimination result and comparing it with the second threshold.
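The similarity defined above, and the discrimination rate compared with the second threshold, can be computed from discriminator outputs roughly as follows. This rests on assumptions not stated in the source:

- The discriminator outputs a probability that a sample is true.
- A sample counts as "judged true" when that probability is at least 0.5.
- The discrimination rate is the fraction of all true and false samples that the discriminator classifies correctly. The source does not define the discrimination rate precisely, so `discrimination_rate` below is hypothetical.

```python
from typing import Sequence


def similarity(fake_scores: Sequence[float], cutoff: float = 0.5) -> float:
    """First number (false samples judged true) divided by the total number of false samples."""
    if not fake_scores:
        raise ValueError("at least one false sample is required")
    judged_true = sum(1 for score in fake_scores if score >= cutoff)
    return judged_true / len(fake_scores)


def discrimination_rate(real_scores: Sequence[float], fake_scores: Sequence[float],
                        cutoff: float = 0.5) -> float:
    """Hypothetical definition: share of all samples the discriminator labels correctly."""
    total = len(real_scores) + len(fake_scores)
    if total == 0:
        raise ValueError("at least one sample is required")
    correct = (sum(1 for s in real_scores if s >= cutoff)
               + sum(1 for s in fake_scores if s < cutoff))
    return correct / total
```

For example, discriminator scores [0.9, 0.2, 0.6, 0.1] on four false samples give a similarity of 0.5, since two of the four are judged true. The two quantities pull in opposite directions: similarity rises as G fools D more often, while the discrimination rate rises as D gets better at separating true from false samples.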
Optionally, the apparatus further comprises:
a construction module 603 configured to construct a loss function;
a regularization module 604, configured to regularize the loss function based on the L1/2, L1, or L2 norm to obtain a regularized loss function;
the execution module 601 is specifically configured to:
iteratively optimizing parameters of the generator model through the regularized loss function;
and performing iterative optimization on parameters of the discriminator model through the regularized loss function.
Optionally, the execution module 601 is specifically configured to iteratively optimize the parameters of the generator model through the regularized loss function by using an ordinary projected gradient algorithm.
Optionally, the first threshold is proportional to the current number of alternations.
The training device for a generative adversarial network provided in this embodiment of the present invention can be used to execute the above method embodiments. Its implementation principle and technical effect are similar and are not described again here.
Fig. 8 is a schematic diagram of the hardware structure of a training device for a generative adversarial network according to an embodiment of the present invention. As shown in Fig. 8, the training device 80 of this embodiment includes at least one processor 801 and a memory 802, which are connected by a bus 803.
In a specific implementation, the at least one processor 801 executes the computer-executable instructions stored in the memory 802, so that the at least one processor 801 performs the training method for a generative adversarial network executed by the training device 80 as described above.
For a specific implementation process of the processor 801, reference may be made to the above method embodiments, which have similar implementation principles and technical effects, and details of this embodiment are not described herein again.
In the embodiment shown in Fig. 8, it should be understood that the processor may be a central processing unit (CPU), another general-purpose processor, a digital signal processor (DSP), an application-specific integrated circuit (ASIC), or the like. A general-purpose processor may be a microprocessor or any conventional processor. The steps of the method disclosed in connection with the present invention may be performed directly by a hardware processor, or by a combination of hardware and software modules in the processor.
The memory may include high-speed RAM and may also include non-volatile memory (NVM), such as at least one magnetic disk memory.
The bus may be an Industry Standard Architecture (ISA) bus, a Peripheral Component Interconnect (PCI) bus, an Extended ISA (EISA) bus, or the like. The bus may be divided into an address bus, a data bus, a control bus, etc. For ease of illustration, the buses in the figures of the present application are not limited to only one bus or one type of bus.
The present application also provides a computer-readable storage medium storing computer-executable instructions that, when executed by a processor, implement the training method for a generative adversarial network performed by the training device described above.
The computer-readable storage medium may be implemented by any type of volatile or non-volatile memory device or combination thereof, such as Static Random Access Memory (SRAM), electrically erasable programmable read-only memory (EEPROM), erasable programmable read-only memory (EPROM), programmable read-only memory (PROM), read-only memory (ROM), magnetic memory, flash memory, magnetic or optical disk. Readable storage media can be any available media that can be accessed by a general purpose or special purpose computer.
An exemplary readable storage medium is coupled to the processor such that the processor can read information from, and write information to, the readable storage medium. The readable storage medium may also be an integral part of the processor. The processor and the readable storage medium may reside in an application-specific integrated circuit (ASIC), or they may reside in the device as discrete components.
Those of ordinary skill in the art will understand that: all or a portion of the steps of implementing the above-described method embodiments may be performed by hardware associated with program instructions. The program may be stored in a computer-readable storage medium. When executed, the program performs steps comprising the method embodiments described above; and the aforementioned storage medium includes: various media that can store program codes, such as ROM, RAM, magnetic or optical disks.
Finally, it should be noted that: the above embodiments are only used to illustrate the technical solution of the present invention, and not to limit the same; while the invention has been described in detail and with reference to the foregoing embodiments, it will be understood by those skilled in the art that: the technical solutions described in the foregoing embodiments may still be modified, or some or all of the technical features may be equivalently replaced; and the modifications or the substitutions do not make the essence of the corresponding technical solutions depart from the scope of the technical solutions of the embodiments of the present invention.
Claims (10)
1. A training method for a generative adversarial network, comprising:
fixing parameters of a discriminator model, and iteratively optimizing parameters of a generator model through a loss function until a similarity of the generator model reaches a first threshold, to obtain an optimized generator model, wherein a value of the first threshold is related to a current number of alternations;
fixing parameters of the optimized generator model, and iteratively optimizing the parameters of the discriminator model through the loss function until a discrimination rate of the discriminator model reaches a second threshold, to obtain an optimized discriminator model, wherein a value of the second threshold is related to the current number of alternations;
incrementing the number of alternations by one;
repeating the above steps until the optimized discriminator model and the optimized generator model reach Nash equilibrium; and
determining a trained generative adversarial network according to the current discriminator model and the current generator model.
2. The method of claim 1, wherein fixing the parameters of the discriminator model and iteratively optimizing the parameters of the generator model through the loss function until the similarity of the generator model reaches the first threshold comprises:
inputting noise data into a current generator model to obtain first false samples, and inputting the first false samples and true samples into the discriminator model to obtain a first discrimination result;
back-propagating the first discrimination result to the current generator model through the loss function to optimize parameters of the current generator model;
determining the similarity of the current generator model according to the first discrimination result, and comparing the similarity with the first threshold; and
if the similarity is less than the first threshold, repeating the following steps until the similarity is greater than or equal to the first threshold:
inputting noise data into the current generator model to obtain first false samples;
inputting the first false samples and the true samples into the discriminator model to obtain a first discrimination result;
back-propagating the first discrimination result to the current generator model through the loss function to optimize the parameters of the current generator model; and
determining the similarity of the current generator model according to the first discrimination result, and comparing the similarity with the first threshold.
3. The method of claim 2, wherein determining the similarity of the generator model according to the first discrimination result comprises:
determining, according to the first discrimination result, a first number of first false samples that are judged to be true; and
calculating a ratio of the first number to a total number of the first false samples, and using the ratio as the similarity.
4. The method of claim 1, wherein fixing the parameters of the optimized generator model and iteratively optimizing the parameters of the discriminator model through the loss function until the discrimination rate of the discriminator model reaches the second threshold comprises:
inputting the noise data into a current generator model to obtain second false samples;
inputting the second false samples and the true samples into a current discriminator model to obtain a second discrimination result;
back-propagating the second discrimination result to the current discriminator model through the loss function to optimize parameters of the current discriminator model;
determining the discrimination rate of the current discriminator model according to the second discrimination result, and comparing the discrimination rate with the second threshold; and
if the discrimination rate is less than the second threshold, repeating the following steps until the discrimination rate is greater than or equal to the second threshold:
inputting the second false samples and the true samples into the current discriminator model to obtain a second discrimination result;
back-propagating the second discrimination result to the current discriminator model through the loss function to optimize the parameters of the current discriminator model; and
determining the discrimination rate of the current discriminator model according to the second discrimination result, and comparing the discrimination rate with the second threshold.
5. The method of claim 1, wherein, before fixing the parameters of the discriminator model and iteratively optimizing the parameters of the generator model through the loss function, the method further comprises:
constructing the loss function; and
regularizing the loss function based on an L1/2, L1, or L2 norm to obtain a regularized loss function;
wherein iteratively optimizing the parameters of the generator model through the loss function comprises:
iteratively optimizing the parameters of the generator model through the regularized loss function; and
iteratively optimizing the parameters of the discriminator model through the loss function comprises:
iteratively optimizing the parameters of the discriminator model through the regularized loss function.
6. The method of claim 5, wherein iteratively optimizing the parameters of the generator model through the regularized loss function comprises:
iteratively optimizing the parameters of the generator model through the regularized loss function by using an ordinary projected gradient algorithm.
7. The method of any of claims 1-6, wherein the first threshold is proportional to the current number of alternations.
8. A training device for a generative adversarial network, comprising:
an execution module, configured to: fix parameters of a discriminator model, and iteratively optimize parameters of a generator model through a loss function until a similarity of the generator model reaches a first threshold, wherein a value of the first threshold is related to a current number of alternations;
fix parameters of the optimized generator model, and iteratively optimize the parameters of the discriminator model through the loss function until a discrimination rate of the discriminator model reaches a second threshold, wherein a value of the second threshold is related to the current number of alternations;
increment the number of alternations by one; and
repeat the above steps until the optimized discriminator model and the optimized generator model reach Nash equilibrium; and
a determination module, configured to determine a trained generative adversarial network according to the current discriminator model and the current generator model.
9. A training device for a generative adversarial network, comprising at least one processor and a memory;
wherein the memory stores computer-executable instructions; and
the at least one processor executes the computer-executable instructions stored in the memory, so that the at least one processor performs the training method for a generative adversarial network according to any one of claims 1 to 7.
10. A computer-readable storage medium storing computer-executable instructions which, when executed by a processor, implement the training method for a generative adversarial network according to any one of claims 1 to 7.
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN201911058600.2A CN110796253A (en) | 2019-11-01 | 2019-11-01 | Training method and device for generating countermeasure network |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN201911058600.2A CN110796253A (en) | 2019-11-01 | 2019-11-01 | Training method and device for generating countermeasure network |
Publications (1)
Publication Number | Publication Date |
---|---|
CN110796253A true CN110796253A (en) | 2020-02-14 |
Family
ID=69440726
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN201911058600.2A Pending CN110796253A (en) | Training method and device for generating countermeasure network | 2019-11-01 | 2019-11-01 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN110796253A (en) |
- 2019-11-01 CN CN201911058600.2A patent/CN110796253A/en active Pending
Cited By (18)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN111582647A (en) * | 2020-04-09 | 2020-08-25 | 上海淇毓信息科技有限公司 | User data processing method and device and electronic equipment |
CN111783955A (en) * | 2020-06-30 | 2020-10-16 | 北京市商汤科技开发有限公司 | Neural network training method, neural network training device, neural network dialogue generating method, neural network dialogue generating device, and storage medium |
CN111794741B (en) * | 2020-08-11 | 2023-08-18 | 中国石油天然气集团有限公司 | Method for realizing sliding directional drilling simulator |
CN111794741A (en) * | 2020-08-11 | 2020-10-20 | 中国石油天然气集团有限公司 | Method for realizing sliding directional drilling simulator |
CN111914488B (en) * | 2020-08-14 | 2023-09-01 | 贵州东方世纪科技股份有限公司 | Data area hydrologic parameter calibration method based on antagonistic neural network |
CN111914488A (en) * | 2020-08-14 | 2020-11-10 | 贵州东方世纪科技股份有限公司 | Data regional hydrological parameter calibration method based on antagonistic neural network |
CN111738367B (en) * | 2020-08-17 | 2020-11-13 | 成都中轨轨道设备有限公司 | Part classification method based on image recognition |
CN111738367A (en) * | 2020-08-17 | 2020-10-02 | 成都中轨轨道设备有限公司 | Part classification method based on image recognition |
CN111931062B (en) * | 2020-08-28 | 2023-11-24 | 腾讯科技(深圳)有限公司 | Training method and related device of information recommendation model |
CN111931062A (en) * | 2020-08-28 | 2020-11-13 | 腾讯科技(深圳)有限公司 | Training method and related device of information recommendation model |
CN113408808A (en) * | 2021-06-28 | 2021-09-17 | 北京百度网讯科技有限公司 | Training method, data generation method, device, electronic device and storage medium |
CN113408808B (en) * | 2021-06-28 | 2024-01-12 | 北京百度网讯科技有限公司 | Training method, data generation device, electronic equipment and storage medium |
CN114301637B (en) * | 2021-12-11 | 2022-09-02 | 河南大学 | Intrusion detection method and system for medical Internet of things |
CN114301637A (en) * | 2021-12-11 | 2022-04-08 | 河南大学 | Intrusion detection method and system for medical Internet of things |
CN116206622A (en) * | 2023-05-06 | 2023-06-02 | 北京边锋信息技术有限公司 | Training and dialect conversion method and device for generating countermeasure network and electronic equipment |
CN116206622B (en) * | 2023-05-06 | 2023-09-08 | 北京边锋信息技术有限公司 | Training and dialect conversion method and device for generating countermeasure network and electronic equipment |
CN117407784A (en) * | 2023-12-13 | 2024-01-16 | 北京理工大学 | Sensor data abnormality-oriented intelligent fault diagnosis method and system for rotary machine |
CN117407784B (en) * | 2023-12-13 | 2024-03-12 | 北京理工大学 | Sensor data abnormality-oriented intelligent fault diagnosis method and system for rotary machine |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN110796253A (en) | Training method and device for generating countermeasure network | |
US11568258B2 (en) | Operation method | |
CN111414987B (en) | Training method and training device of neural network and electronic equipment | |
US10380479B2 (en) | Acceleration of convolutional neural network training using stochastic perforation | |
JP6965690B2 (en) | Devices and methods for improving the processing speed of neural networks, and their applications | |
US20180260709A1 (en) | Calculating device and method for a sparsely connected artificial neural network | |
CN109410974B (en) | Voice enhancement method, device, equipment and storage medium | |
US11775832B2 (en) | Device and method for artificial neural network operation | |
US20230196202A1 (en) | System and method for automatic building of learning machines using learning machines | |
CN110766044A (en) | Neural network training method based on Gaussian process prior guidance | |
CN111178520A (en) | Data processing method and device of low-computing-capacity processing equipment | |
WO2021051556A1 (en) | Deep learning weight updating method and system, and computer device and storage medium | |
CN114282666A (en) | Structured pruning method and device based on local sparse constraint | |
CN115129386A (en) | Efficient optimization for neural network deployment and execution | |
CN111160531A (en) | Distributed training method and device of neural network model and electronic equipment | |
CN111860364A (en) | Training method and device of face recognition model, electronic equipment and storage medium | |
CN116029359A (en) | Computer-readable recording medium, machine learning method, and information processing apparatus | |
US11886832B2 (en) | Operation device and operation method | |
CN115080139A (en) | Efficient quantization for neural network deployment and execution | |
CN109388784A (en) | Minimum entropy Density Estimator device generation method, device and computer readable storage medium | |
CN112906861A (en) | Neural network optimization method and device | |
US20200134434A1 (en) | Arithmetic processing device, learning program, and learning method | |
US20220405561A1 (en) | Electronic device and controlling method of electronic device | |
US20210012192A1 (en) | Arithmetic processing apparatus, control method, and non-transitory computer-readable recording medium having stored therein control program | |
JP7279507B2 (en) | Information processing device, information processing program and control method |
Legal Events
| Date | Code | Title | Description |
|---|---|---|---|
| | PB01 | Publication | |
| | SE01 | Entry into force of request for substantive examination | |
| | RJ01 | Rejection of invention patent application after publication | Application publication date: 20200214 |