Disclosure of Invention
The invention aims to provide a method for detecting GAN attacks in joint deep learning.
The technical solution for realizing the purpose of the invention is as follows: a method for detecting generative adversarial network (GAN) attacks in joint deep learning, comprising the following specific steps:
step 1, initializing the joint deep learning model: the server and the participants initiate a joint deep learning model training task; the server and the participants jointly determine the architecture, the target, the labels and the like of the joint deep learning model, completing the initialization of the joint deep learning model;
step 2, the server simulates a GAN attack and acquires sample data: the server simulates the training process of the joint deep learning model and launches a GAN attack on the simulation model, obtaining the update gradient data produced during the training of the simulation model;
step 3, the server constructs a GAN attack detection classifier through a deep neural network and trains the GAN attack detection classifier;
step 4: the server extracts features from the update gradients in the training process of the joint deep learning model through the layer classifiers, inputs the extracted features into the GAN attack detection total classifier for prediction, and obtains the probability that an update is malicious data containing erroneous classification information uploaded by a participant.
Preferably, the participants include attackers who launch GAN attacks and normal participants who do not launch attacks.
Preferably, the simulation model and the joint deep learning model have the same architecture and initial values.
Preferably, the specific steps by which the server simulates the training process of the joint deep learning model are as follows:
the server constructs an auxiliary training set according to a training target and a label of the joint deep learning model;
the server divides the auxiliary training set into a malicious data set and a normal data set;
and the server uses the malicious data set and the normal data set to carry out simulation model training and initiates a GAN attack, obtaining correctly classified simulation model update gradient data and malicious simulation model update gradient data containing erroneous classification information.
Preferably, the specific steps of constructing and training the GAN attack detection classifier are as follows:
step 3.1: the server marks the update gradient data obtained in the simulated GAN attack of step 2 with different labels according to whether it is normal update gradient data or malicious gradient data, and trains a binary classification task to distinguish attackers from normal participants;
step 3.2: normalizing the update gradients of the simulation models uploaded after training of normal participants and attackers;
step 3.3: constructing a layer classifier for the weight parameter of each layer of the simulation model updating gradient, wherein the input dimension of the layer classifier is the dimension of the weight parameter of the layer corresponding to the simulation model updating gradient, and the output of the layer classifier is a score which represents the probability that the simulation model updating gradient is malicious data containing wrong classification information uploaded by a participant;
step 3.4: for the weight parameter of each layer of the simulation model update gradient, feature extraction is performed using the layer classifier trained in step 3.3, and the output of the layer classifier is taken as the weight feature of the corresponding layer of the simulation model update gradient; the bias of the corresponding layer of the simulation model update gradient is taken as the bias feature;
the aggregation simulation model updates the weight characteristics and the bias characteristics of each layer of the gradient to obtain the overall characteristics of the updated gradient;
constructing a total classifier for detecting GAN attacks by using a neural network, wherein the input dimension of the total classifier is the dimension of the overall feature and the output of the total classifier is a score representing the probability that a model update parameter is malicious data containing erroneous classification information uploaded by a participant;
and inputting the overall features of the update gradients and the corresponding labels into the total classifier to train the total classifier.
Preferably, the training process of the joint deep learning model specifically includes:
the participant downloads the latest parameters of the joint deep learning model from the server, carries out local training and uploads the updated gradient loss parameters; the server extracts features from the update gradients of the joint deep learning model training process through the layer classifiers, inputs them into the trained GAN attack detection total classifier for prediction, and obtains the probability that an update is malicious data containing erroneous classification information uploaded by a participant.
Compared with the prior art, the invention has the following remarkable advantages: 1) the invention defends against the GAN attack by means of an attack detection method for the first time, protecting the privacy of participants and the security of the model while ensuring the normal training of joint learning; 2) the GAN attack detection method is an active defense that can identify the identity of an attacker and actively limit the attacker's behavior; 3) the method retains the distributed and parallel characteristics of joint learning, does not increase the computation or communication overhead of participants, and does not affect the accuracy or convergence speed of the model; 4) when training the GAN attack detection classifier, the invention combines supervised learning and unsupervised learning, improving the accuracy of the classifier.
The present invention is described in further detail below with reference to the attached drawings.
Detailed Description
A method for detecting generative adversarial network (GAN) attacks in joint deep learning, as shown in figure 1, comprises the following specific steps:
step 1, initializing the joint deep learning model: the server and the participants initiate a joint deep learning model training task, jointly determine the architecture, the target, the labels and the like of the joint deep learning model, and complete the initialization of the model.
The joint deep learning model training task is a training task under the white-box condition, namely, the participants know the specific details of the model, including the parameter number of each layer of the neural network, the setting of an activation function, a loss function and the like.
Step 2, the server simulates a GAN attack and acquires sample data: the server simulates the training process of the joint deep learning model (the simulation model), launches a GAN attack on the simulation model, and acquires the update gradient data of the participants during the training of the simulation model;
specifically, the participants trained by the simulation model include an "attacker" who initiates a GAN attack and a "normal participant" who does not initiate an attack. The server extracts the updated gradient data of the 'attacker' and the 'normal participant' in the simulation model training process, and the updated gradient data is used as a sample for detecting and generating the network attack resistance. The simulation model and the joint deep learning model in step 1 have the same architecture and initial values.
Specifically, the training process of the simulated joint deep learning model (simulation model) on the server comprises the following steps:
and the server constructs an auxiliary training set Data _ aux according to the training targets and the labels of the joint deep learning model. The auxiliary data set may be obtained from a common data set by means of sampling.
The server divides the auxiliary training set into two parts: a malicious data set, owned by the "attacker", whose labels are maliciously modified; the "attacker" participates in the training of the simulation model and uploads simulation model update gradients with erroneous classification; and the normal data set Data_P, owned by the "normal participant", which is not modified in any way; the "normal participant" participates in the training of the simulation model and uploads correctly classified simulation model update gradients.
The server performs simulation model training using the malicious data set and the normal data set and initiates a GAN attack, obtaining correctly classified simulation model update gradient data uploaded by the "normal participant" and malicious simulation model update gradient data containing erroneous classification information uploaded by the "attacker".
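The construction of the malicious and normal data sets can be sketched as follows. This is an illustrative sketch only, not the patent's implementation: the function name `make_simulation_datasets`, the 50/50 split, and label flipping as the concrete way of injecting erroneous classification information are all assumptions.

```python
import random

def make_simulation_datasets(data_aux, num_classes, mal_fraction=0.5, seed=0):
    """Split the auxiliary training set Data_aux into a malicious part,
    whose labels are flipped to a wrong class (erroneous classification
    information), and a normal part that is left unmodified."""
    rng = random.Random(seed)
    shuffled = list(data_aux)
    rng.shuffle(shuffled)
    cut = int(len(shuffled) * mal_fraction)
    malicious, normal = shuffled[:cut], shuffled[cut:]
    # The "attacker" maliciously modifies every label in its share.
    flipped = [(x, rng.choice([c for c in range(num_classes) if c != y]))
               for x, y in malicious]
    return flipped, normal
```

Training the simulation model on `flipped` then yields malicious update gradients, while `normal` yields correctly classified ones.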
Specifically, in the joint training process of the simulation model, the update gradients uploaded by the "attacker" and the "normal participant" contain different features of the training data: for a training set (x, y), where x is the input data of the neural network and y is the label, the neural network first propagates forward to compute the loss function Loss(f(x; w), y), where f( ) is the neural network model and w is the parameter of the model; the neural network then tries to minimize the empirical expectation E of the loss function in order to obtain correct prediction results. In back propagation, the neural network updates the parameter w by calculating the gradient of the loss with respect to all parameters, ∂Loss/∂w. (x, y) is therefore represented in ∂Loss/∂w in a certain form, i.e. information about (x, y) can be deduced from ∂Loss/∂w. The invention uses the update gradients uploaded by participants to infer whether the participants' training data set (x, y) has been injected with erroneous classification information.
In this step, the simulated training yields the update gradient samples of the "attacker" and the "normal participant", which are used as the training sample data of the GAN attack detection classifier, (f1, f2, ..., fn), where fi represents a gradient loss parameter uploaded by the "attacker" or the "normal participant" during the simulated training. Each sample fi carries a corresponding label: "attacker" or "normal participant" P.
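The fact that label information is encoded in the uploaded update gradient can be illustrated on a minimal single-neuron model. This is an illustrative sketch, not the patent's model: the logistic-regression form and the example numbers are assumptions.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def grad_logistic(x, y, w):
    """Gradient of the cross-entropy loss of f(x; w) = sigmoid(w . x)
    with respect to w, which equals (f(x; w) - y) * x."""
    p = sigmoid(sum(wi * xi for wi, xi in zip(w, x)))
    return [(p - y) * xi for xi in x]

w = [0.5, -0.3]                   # current model parameters
x = [1.0, 2.0]                    # a participant's training input
g_true = grad_logistic(x, 1, w)   # gradient with the correct label y = 1
g_flip = grad_logistic(x, 0, w)   # gradient with a flipped label y = 0
# The two uploads point in opposite directions, so whether the label
# was tampered with leaves a detectable trace in the update gradient.
```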
Step 3, constructing and training the GAN attack detection classifier: the server constructs GAN attack detection layer classifiers according to the parameters of each layer of the neural network, and then constructs the GAN attack detection total classifier from the layer classifiers.
as shown in fig. 2, in a further embodiment, the specific steps of constructing the GAN attack detection classifier are as follows:
step 3.1: the server marks the update gradient data obtained in the simulated GAN attack of step 2 with different labels: the normal update gradient data uploaded by the normal participant is labeled P, and the malicious gradient data containing erroneous classification information uploaded by the attacker is given a different label;
a binary classification task is then trained through a neural network so as to distinguish the attacker from the normal participant;
step 3.2: sample data preprocessing: the simulation model update gradients uploaded after training by the normal participants and the attackers are normalized, reducing the influence of differences in data volume and in the number of local training rounds on the magnitude of the gradient update.
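The preprocessing of step 3.2 can be sketched as follows; this is illustrative only, and L2 normalization is an assumption, since the patent does not specify the normalization method.

```python
import math

def normalize_update(grad, eps=1e-12):
    """L2-normalize a flattened update gradient so that differences in
    local data volume and in the number of local training rounds do not
    change the magnitude of the uploaded update."""
    norm = math.sqrt(sum(g * g for g in grad))
    return [g / (norm + eps) for g in grad]
```

For example, `normalize_update([3.0, 4.0])` yields a unit-length vector, approximately [0.6, 0.8].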
Step 3.3: constructing a GAN attack detection layer classifier and training the classifier: constructing a layer classifier for each layer parameter of the simulation model updating gradient; the layer classifier is a neural network and comprises an input layer, a plurality of hidden layers and an output layer; and each layer of the simulation model updating gradient comprises a weight (Wieght, W) and a Bias (Bias, B), and a GAN attack detection layer classifier of the layer is constructed according to the weight parameter of each layer of the simulation model updating gradient. The input dimension of the layer classifier is the dimension of the layer weight parameter, the output of the layer classifier is a score, and the score represents the probability that the model updating parameter is malicious data containing wrong classification information uploaded by a participant;
the weights of each layer of the simulation model update gradients and the corresponding labels are respectively input into the corresponding GAN attack detection layer classifier to train the GAN attack detection layer classifier;
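A minimal sketch of one such layer classifier follows. It is illustrative only: the patent specifies a multi-hidden-layer network, while this stand-in uses logistic regression trained by plain gradient descent to keep the example short; the class name and hyperparameters are assumptions.

```python
import math

def _sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

class LayerClassifier:
    """Stand-in for a per-layer classifier: scores the flattened weight
    part of one layer's update gradient; the output is the probability
    that the update is malicious data with erroneous classification
    information (label 1 = attacker, label 0 = normal participant)."""
    def __init__(self, dim):
        self.w = [0.0] * dim
        self.b = 0.0

    def score(self, layer_weights):
        z = sum(wi * xi for wi, xi in zip(self.w, layer_weights)) + self.b
        return _sigmoid(z)

    def train(self, samples, labels, lr=0.5, epochs=200):
        # Gradient descent on the cross-entropy loss.
        for _ in range(epochs):
            for x, y in zip(samples, labels):
                err = self.score(x) - y
                self.w = [wi - lr * err * xi for wi, xi in zip(self.w, x)]
                self.b -= lr * err
```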
step 3.4: constructing and training the GAN attack detection total classifier: for the weight parameter of each layer of the simulation model update gradient, feature extraction is performed using the layer classifier trained in step 3.3, and the output of the last hidden layer of the layer classifier is used as the weight feature extraction result of that layer; the bias of the corresponding layer of the simulation model update gradient is used as the bias feature of that layer.
Aggregating the weight features and bias features of each layer yields the overall feature of the update gradient: L = (L1 + ΔB1′ + L2 + ΔB2′ + ... + Ln + ΔBn′), where Li is the weight feature extracted for layer i and ΔBi′ is the normalized bias feature of layer i.
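The aggregation can be sketched as follows. Reading the "+" in the formula for L as concatenation of the per-layer features is an interpretation on my part; summation would be an alternative reading.

```python
def aggregate_features(weight_feats, bias_feats):
    """Combine per-layer weight features Li and normalized bias features
    dBi' into the overall feature L of one update gradient. The '+' in
    L = (L1 + dB1' + ... + Ln + dBn') is read here as concatenation."""
    overall = []
    for li, dbi in zip(weight_feats, bias_feats):
        overall.extend(li)   # weight feature of layer i
        overall.extend(dbi)  # bias feature of layer i
    return overall
```

The dimension of the resulting list is then the input dimension of the total classifier.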
Constructing a total classifier for detecting GAN attacks by using a neural network comprising an input layer, a plurality of hidden layers and an output layer;
the input dimension of the total classifier is the dimension of the overall feature L, and its output is a score representing the probability that the model update parameter is malicious data containing erroneous classification information uploaded by a participant.
The overall features of the update gradients and the corresponding labels are input into the total classifier to train the total classifier.
Step 4: the server extracts features from the update gradients in the training process of the joint deep learning model through the layer classifiers, inputs the extracted features into the GAN attack detection total classifier for prediction, and obtains the probability that an update is malicious data containing erroneous classification information uploaded by a participant. If the gradient is malicious data uploaded by an attacker, the model is not updated with it; if the gradient is data uploaded by a normal participant, the model is updated. After an attacker is identified, the frequency with which the attacker obtains the global model may be limited, or the attacker may be excluded from the joint learning process.
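The server-side decision of step 4 can be sketched as follows. This is illustrative only: `score_fn`, the layout of the update tuples, and the 0.5 threshold are assumptions, not details given by the patent.

```python
def filter_updates(score_fn, updates, threshold=0.5):
    """Apply the trained total classifier (score_fn maps an overall
    feature to a malicious probability): accept updates scoring below
    the threshold, flag the uploaders of the rest as attackers."""
    accepted, attackers = [], []
    for participant_id, overall_feature, update in updates:
        if score_fn(overall_feature) < threshold:
            accepted.append(update)           # normal: model is updated
        else:
            attackers.append(participant_id)  # attacker: update discarded
    return accepted, attackers
```

Flagged participants can then be rate-limited or excluded from subsequent rounds, as described above.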
In a further embodiment, the training process of the joint deep learning model specifically comprises the following steps:
the participant downloads the latest parameters of the joint deep learning model from the server, trains locally, and uploads the updated gradient loss parameters. The server extracts features from the update gradients of the joint deep learning model training process through the layer classifiers, inputs the extracted features into the trained GAN attack detection total classifier for prediction, and obtains the probability that an update is malicious data containing erroneous classification information uploaded by a participant.
The invention realizes a detection method for the GAN attack that can identify the identity of the GAN attacker, thereby protecting the privacy of the participants in normal model training and the security of the global model. The method recognizes the identity of the attacker while the malicious gradient is being uploaded, preventing further leakage of the victim's private information and degradation of model accuracy. The detection method provided by the invention is an active defense: after the identity of an attacker is identified, the server can actively limit or punish the attacker, avoiding further consumption of server resources by the attacker. The training and prediction of the GAN attack detection classifier are carried out at the server end, do not add computation or communication overhead for participants, and do not affect the accuracy or convergence speed of the model.
The method and the system perform feature extraction on sample data, normalize the update gradient of the model uploaded by normal participants and attackers, and reduce the influence of the difference of data volume and the number of rounds of local training on the gradient update amplitude. The feature extraction method performs dimensionality reduction on the weights with large quantities in the sample data, and improves training and prediction efficiency of the classifier. In the training stage of the classifier, the invention combines supervised learning and unsupervised learning. In the initial stage of classifier training, a GAN attack process is simulated, and gradient data marked as normal participants and attackers are obtained and used for supervised learning with labels. In the normal training stage of the joint deep learning, the invention randomly selects some unmarked gradients for unsupervised learning. Meanwhile, after certain training, supervised learning is carried out again to ensure the correctness of the classifier.
In summary, the present invention has the following features:
(1) Detecting GAN attacks and protecting participants' privacy
The invention is the first to defend against the GAN attack from the direction of attack detection. The invention takes all update gradients as a training data set, extracts different features from them, and then constructs a classifier to filter out update gradients containing erroneous classification information, thereby protecting the privacy of participants.
(2) Securing a global model
GAN attackers may inject incorrect classification information into the model, resulting in a reduction in the accuracy of the global model. According to the invention, through the GAN detection classifier, the attacker is identified, so that the safety of the model is protected. The detection method provided by the invention is an active defense mode, and after the identity of the attacker is identified, the server can actively limit or punish the attacker, so that the further consumption of server resources by the attacker is avoided.
(3) Non-ciphertext operation and preserved model performance
Most existing defense strategies are passive defenses based on differential privacy, secure multi-party computation and other cryptographic techniques. The method retains the distributed, parallel, non-ciphertext characteristics of joint learning; the training and prediction of the attack detection classifier are carried out at the server end, do not increase the computation or communication overhead of participants, and do not affect the accuracy or convergence speed of the model.
(4) Supervised learning and unsupervised learning
When the invention trains the GAN attack detection classifier, the supervised learning and the unsupervised learning are combined, thereby improving the accuracy of the classifier.