WO2024093426A1 - Model training method and apparatus based on federated machine learning - Google Patents

Model training method and apparatus based on federated machine learning Download PDF

Info

Publication number
WO2024093426A1
WO2024093426A1 PCT/CN2023/112501 CN2023112501W WO2024093426A1 WO 2024093426 A1 WO2024093426 A1 WO 2024093426A1 CN 2023112501 W CN2023112501 W CN 2023112501W WO 2024093426 A1 WO2024093426 A1 WO 2024093426A1
Authority
WO
WIPO (PCT)
Prior art keywords
client
training
cloud server
clients
gradient
Prior art date
Application number
PCT/CN2023/112501
Other languages
English (en)
French (fr)
Inventor
申书恒
傅欣艺
王维强
Original Assignee
支付宝(杭州)信息技术有限公司
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by 支付宝(杭州)信息技术有限公司 filed Critical 支付宝(杭州)信息技术有限公司
Publication of WO2024093426A1 publication Critical patent/WO2024093426A1/zh

Links

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • HELECTRICITY
    • H04ELECTRIC COMMUNICATION TECHNIQUE
    • H04LTRANSMISSION OF DIGITAL INFORMATION, e.g. TELEGRAPHIC COMMUNICATION
    • H04L9/00Cryptographic mechanisms or cryptographic arrangements for secret or secure communications; Network security protocols
    • HELECTRICITY
    • H04ELECTRIC COMMUNICATION TECHNIQUE
    • H04LTRANSMISSION OF DIGITAL INFORMATION, e.g. TELEGRAPHIC COMMUNICATION
    • H04L9/00Cryptographic mechanisms or cryptographic arrangements for secret or secure communications; Network security protocols
    • H04L9/40Network security protocols

Definitions

  • One or more embodiments of this specification relate to computer technology, and more particularly, to a model training method and apparatus based on federated machine learning.
  • Federated machine learning is a distributed machine learning framework with privacy protection effects. It can effectively help multiple clients use data and conduct machine learning modeling while meeting the requirements of privacy protection, data security, and government regulations. As a distributed machine learning paradigm, federated machine learning can effectively solve the problem of data silos, allowing each client to jointly model without sharing local data, achieve intelligent collaboration, and jointly train a global model with better performance.
  • the central cloud server sends the global model to each client.
  • Each client uses private local data to train the gradient of the model parameters, and then passes the gradient trained in this round to the cloud server. After collecting the gradients from all parties, the cloud server calculates the average gradient and uses it to update the global model on the cloud server. In the next round of training, the updated global model is sent to each client.
  • each client needs to send its trained gradient to the cloud server.
  • the gradient information sent by the client to the cloud server can be used to restore the original private data stored locally by the client, resulting in the leakage of private data, the user's privacy cannot be protected, and the security is poor.
  • One or more embodiments of this specification describe a model training method and device based on federated machine learning, which can improve the security of model training.
  • a model training method based on federated machine learning, wherein at least two clients and at least one cloud server participate in the model training based on federated machine learning, and the method is applied to any first client among the at least two clients and comprises: in each round of training, the first client receives a global model sent by the cloud server; the first client trains the gradient of the global model using local private data; the first client encrypts the gradient obtained in this round of training and then sends the encrypted gradient to the cloud server; the first client performs the next round of training until the global model converges.
  • the method further includes: the first client obtains a mask corresponding to the first client; the sum of all masks corresponding to all clients participating in the model training is less than a predetermined value; the first client encrypts the gradient obtained from this round of training, including: the first client adds the gradient obtained from this round of training to the mask corresponding to the first client to obtain an encrypted gradient.
  • the first client obtains a mask corresponding to the first client, including: the first client obtains each sub-mask s(u,vj) generated by the first client and corresponding to each of the other clients among all the clients; the first client obtains each sub-mask s(vj,u) generated by each of the other clients and corresponding to the first client; wherein j is a variable, with a value ranging from 1 to N; N is the number of all clients participating in the model training minus 1; u represents the first client, and vj represents the jth client among all the clients participating in the model training except the first client; the first client calculates the difference between s(u,vj) and s(vj,u) for each variable j, and obtains p(u,vj) according to the difference; the first client calculates ∑j p(u,vj) and uses the calculated result as the mask corresponding to the first client.
  • the obtaining of p(u,v j ) according to the difference includes: directly using the difference as the p(u,v j ); or calculating the difference mod r and using the calculated remainder result as the p(u,v j ); wherein mod is a remainder operation and r is a preset value greater than 1.
  • the r is a prime number of not less than 200 digits.
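A minimal sketch, in Python, of the mask derivation described in the preceding bullets; the helper name and the small modulus are illustrative assumptions of this sketch (the specification calls for r to be a prime of not less than 200 digits).

```python
import secrets

# Illustrative modulus only: a real deployment would use a prime of no fewer
# than 200 digits, as stated above.
R = 2**61 - 1

def derive_mask(own_submasks, received_submasks, r=R):
    """Compute this client's mask from sub-masks.

    own_submasks[j]      -> s(u, v_j), generated by this client for client v_j
    received_submasks[j] -> s(v_j, u), generated by client v_j for this client
    """
    mask = 0
    for j in own_submasks:
        p_uv = (own_submasks[j] - received_submasks[j]) % r   # p(u, v_j)
        mask = (mask + p_uv) % r
    return mask

# Toy usage with two other clients v_1 and v_2
own = {1: secrets.randbelow(R), 2: secrets.randbelow(R)}
recv = {1: secrets.randbelow(R), 2: secrets.randbelow(R)}
print(derive_mask(own, recv))
```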
  • the method further includes: the first client generates a homomorphic encryption key pair corresponding to the first client; the first client sends the public key in the homomorphic encryption key pair corresponding to the first client to the forwarding server; and the first client receives the public key corresponding to each other client among all the clients sent by the forwarding server; accordingly, after the first client obtains the sub-masks s(u,vj) generated by the first client and corresponding to each other client among all the clients, it further includes: for each of the other clients, the first client uses the public key corresponding to the j-th client to encrypt the sub-mask s(u,vj) corresponding to the j-th client, and then sends the encrypted s(u,vj) to the forwarding server; accordingly, the first client obtaining the sub-masks s(vj,u) generated by each of the other clients and corresponding to the first client includes: the first client receives the encrypted sub-masks s(vj,u), generated by each of the other clients and corresponding to the first client, sent by the forwarding server; the first client decrypts each encrypted sub-mask s(vj,u) using the private key in the homomorphic encryption key pair corresponding to the first client to obtain each sub-mask s(vj,u).
  • the forwarding server includes: the cloud server, or a third-party server independent of the cloud server.
  • a model training method based on federated machine learning, wherein at least two clients and at least one cloud server participate in the model training based on federated machine learning, and the method is applied to the cloud server, including: in each round of training, the cloud server sends the latest global model to each client participating in the model training based on federated machine learning; the cloud server receives the encrypted gradient of the global model sent by each client; the cloud server adds the received encrypted gradients of the global model to obtain an aggregated gradient; the cloud server updates the global model using the aggregated gradient; the cloud server executes the next round of training until the global model converges.
  • a model training device based on federated machine learning, wherein at least two clients and at least one cloud server participate in the model training based on federated machine learning, the device is applied to any first client among the at least two clients, and the device includes: a global model acquisition module, configured to receive the global model sent by the cloud server in each round of training; a gradient acquisition module, configured to use local private data to train the gradient of the global model in each round of training; an encryption module, configured to encrypt the gradient obtained in each round of training, and then send the encrypted gradient to the cloud server; each module executes the next round of training until the global model converges.
  • a model training device based on federated machine learning in which at least two clients and at least one cloud server participate in the model training based on federated machine learning, and the device is applied to the cloud server, and the device includes: a global model sending module, configured to send the latest global model to each client participating in the model training based on federated machine learning in each round of training; a gradient receiving module, configured to receive the encrypted gradient of the global model sent by each client in each round of training; a gradient aggregation module, configured to add the received gradients of each encrypted global model in each round of training to obtain an aggregated gradient; a global model updating module, configured to update the global model using the aggregated gradient in each round of training; each module executes the next round of training until the global model converges.
  • a computing device including a memory and a processor, wherein the memory stores executable code, and when the processor executes the executable code, the method described in any embodiment of the present specification is implemented.
  • the client does not directly send the gradient information to the cloud server, but first encrypts the gradient and sends the encrypted information to the cloud server.
  • the cloud server obtains the encrypted information from each client.
  • the cloud server can only obtain the aggregated gradient, not the original gradients; that is, the cloud server cannot obtain the gradient of any individual client, thus improving security.
  • attackers cannot steal the original gradients from the transmission link from the client to the cloud server or from the cloud server, and thus cannot recover the private data in the terminal device where the client is located through means such as generative adversarial networks (GANs).
  • the client can keep its privacy in its own hands, which greatly improves security.
  • the embodiment of this specification adopts homomorphic encryption to encrypt the sub-mask during secret sharing, which can be achieved by relying on a central cloud server or a third-party server as an intermediate third party to convey the sub-mask, thereby avoiding the problem of sub-mask leakage caused by exchanging sub-masks between clients, thereby further improving security.
  • the difference is reduced by a remainder operation, and the remainder is used to obtain the mask corresponding to the client, so as to ensure that the numerical range of the calculated mask will not exceed the maximum value that the protocol can carry, thereby broadening the application scope of the embodiments of this specification. For example, the model training in the embodiments of this specification can still be implemented when the number of clients participating in the model training based on federated machine learning is very large.
  • FIG. 1 is a schematic diagram of a system structure used in an embodiment of the present specification.
  • FIG2 is a flow chart of a model training method based on federated machine learning executed by a client in one embodiment of the present specification.
  • FIG. 3 is a flow chart of a method for a first client to obtain a mask corresponding to the first client in one embodiment of the present specification.
  • FIG4 is a flow chart of a model training method based on federated machine learning executed by a cloud server in one embodiment of the present specification.
  • FIG5 is a flow chart of a model training method based on federated machine learning implemented by the cooperation of a client and a cloud server in one embodiment of this specification.
  • FIG6 is a schematic diagram of the structure of a model training device based on federated machine learning applied to a client in one embodiment of the present specification.
  • FIG7 is a schematic diagram of the structure of a model training device based on federated machine learning applied to a client in one embodiment of the present specification.
  • FIG8 is a schematic diagram of the structure of a model training device based on federated machine learning applied to a cloud server in one embodiment of the present specification.
  • each client needs to send its trained gradient to the cloud server.
  • the attacker can use the gradient information sent by the client to the cloud server to recover the original private data in the terminal device where the client is located, such as by using the generative adversarial network (GAN) and other means to recover the private data.
  • the central cloud server receives the gradient information of each individual client.
  • the central cloud server is reliable, but when the central cloud server loses data unintentionally or colludes with other clients, the client's private data will be leaked. The client cannot keep its privacy in its own hands.
  • the system architecture mainly includes M clients and cloud servers participating in federated machine learning.
  • M is a positive integer greater than 1.
  • each client interacts with the cloud server through a network, and the network can include various connection types, such as wired, wireless communication links or optical fiber cables.
  • the M clients are located in M terminal devices. Each client can be located in any terminal device modeled by federated machine learning, such as bank devices, payment devices, mobile terminals, etc.
  • the cloud server can be located in the cloud.
  • the method of the embodiment of this specification involves the processing of the client and the processing of the cloud server, which are described below.
  • FIG. 2 is a flow chart of a model training method based on federated machine learning executed by a client in one embodiment of this specification.
  • the execution subject of the method is each client participating in the federated machine learning. It can be understood that the method can also be executed by any device, equipment, platform, or device cluster with computing and processing capabilities. Referring to FIG. 2 , the method includes steps 201 to 207.
  • Step 201 In each round of training, the first client receives the global model sent by the cloud server.
  • Step 203 The first client uses local private data to train the gradient of the global model.
  • Step 205 The first client encrypts the gradient obtained from this round of training, and then sends the encrypted gradient to the cloud server.
  • Step 207 The first client performs the next round of training until the global model converges.
  • after obtaining the gradient, the client does not directly send the gradient information to the cloud server, but first encrypts the gradient and sends the encrypted information to the cloud server.
  • the cloud server obtains the encrypted gradient from each client, rather than the original gradient, thereby improving security.
  • an attacker cannot steal the original gradient from the transmission link from the client to the cloud server or from the cloud server, and thus cannot recover the private data in the terminal device where the client is located by means such as the Generative Adversarial Network (GAN).
  • the client can keep privacy in its own hands, which greatly improves security.
  • the method of the embodiments of this specification can be applied to various business scenarios based on federated machine learning for model training, such as Alipay's "Ant Forest” product, QR code scanning image risk control, etc.
  • FIG. 2 Each step in FIG. 2 is described below in conjunction with a specific embodiment.
  • step 201 in each round of training, the first client receives the global model sent by the cloud server.
  • the client executing the model training method in FIG2 is recorded as the first client.
  • the first client is each client participating in the model training based on federated machine learning, that is, each client participating in the model training based on federated machine learning needs to execute the model training method described in conjunction with FIG2.
  • step 203 the first client uses local private data to train the gradient of the global model.
  • step 205 the first client encrypts the gradient obtained from this round of training, and then sends the encrypted gradient to the cloud server.
  • In the method of the embodiments of this specification, two requirements need to be satisfied. (1) Security: the client cannot send the plaintext of the gradient it has trained directly to the cloud server; it sends the ciphertext of the gradient instead. (2) Availability: in order to perform model training, the cloud server needs to obtain the aggregated result of the clients' gradients, and this aggregated result must be equal or close to the aggregation of the original gradients so that model training can proceed properly. In other words, although the cloud server cannot directly obtain the plaintext of each gradient, the gradient aggregation result it obtains must be equal or close to the aggregation of the original gradients. Therefore, the encryption performed by all participating clients must ensure that the values added to the individual gradients cancel, or nearly cancel, when summed. A simple example of the idea: to obtain a result Y, one way of computing it is Y = X1 + X2; another is Y = (X1 + S) + (X2 - S). The method of the embodiments of this specification uses the latter idea to satisfy requirement (2).
  • the method before step 205, the method further includes: step A: the first client obtains a mask corresponding to the first client.
  • the implementation process of this step 205 includes: the first client adds the gradient obtained in this round of training to the mask corresponding to the first client to obtain the encrypted gradient.
  • Each client has its own corresponding mask. For example, if there are 100 clients participating in the model training method based on federated machine learning, then each client will get its own corresponding mask. To further improve security, different clients have different corresponding masks.
  • an implementation process in which the first client in the above step A obtains the mask corresponding to the first client includes steps 301 to 307 .
  • Step 301 A first client obtains each sub-mask s(u,v j ) generated by the first client and corresponding to each other client among all the clients.
  • For example, if there are 100 clients participating in the model training method based on federated machine learning, then the first client generates 99 sub-masks s(u,vj), one corresponding to each of the other 99 clients.
  • s(u,v 1 ) represents the sub-mask generated by the first client corresponding to client 1 among the other 99 clients; similarly, s(u,v 2 ) represents the sub-mask generated by the first client corresponding to client 2 among the other 99 clients; and so on, s(u,v 99 ) represents the sub-mask generated by the first client corresponding to client 99.
  • Step 303 The first client obtains each sub-mask s(vj,u) generated by each of the other clients and corresponding to the first client; wherein j is a variable, with a value from 1 to N; N is the number of all clients participating in the model training minus 1; u represents the first client, and vj represents the jth client among all the clients participating in the model training except the first client.
  • All clients participating in the model training method based on federated machine learning will perform the processing of the above step 301, so each other client will also generate a sub-mask corresponding to the first client.
  • the first client needs to obtain all sub-masks s(v j ,u) generated by each other client and corresponding to the first client.
  • the first client needs to obtain 99 sub-masks s(v j ,u) corresponding to the first client generated by the other 99 clients.
  • s(v 1 ,u) represents the sub-mask generated by client 1 among the other 99 clients and corresponding to the first client
  • s(v 2 ,u) represents the sub-mask generated by client 2 among the other 99 clients and corresponding to the first client
  • s(v 99 ,u) represents the sub-mask generated by client 99 among the other 99 clients and corresponding to the first client.
  • the first client obtains 99 sub-masks generated by itself corresponding to the other 99 clients, and 99 sub-masks generated by the other 99 clients corresponding to the first client, for a total of 198 sub-masks.
  • the first client needs to send all the sub-masks it generates to the cloud server or the third-party server, and the cloud server or the third-party server forwards them to the corresponding clients after receiving them.
  • the cloud server or the third-party server obtains the original text of the sub-mask, it may also cause the problem of obtaining the original text of the gradient according to the sub-mask later. Therefore, in order to further increase security, in one embodiment of the present specification, the sub-mask can be encrypted, and the encrypted sub-mask is sent to the cloud server or the third-party server. In this way, the cloud server or the third-party server not only cannot obtain the original text of the gradient of each client, but also cannot obtain the original text of the sub-mask generated by each client, which greatly improves security.
  • In order to achieve the effect that the cloud server or the third-party server cannot obtain the plaintext of the sub-masks, the method further includes: the first client generates a homomorphic encryption key pair corresponding to the first client; wherein the homomorphic encryption key pair corresponding to the first client is dedicated to the first client rather than shared by all clients, and therefore the homomorphic encryption key pairs corresponding to different clients are different; the first client sends the public key in the homomorphic encryption key pair corresponding to the first client to the forwarding server; and the first client receives the public keys, sent by the forwarding server, corresponding to each of the other clients among all the clients; accordingly, after step 301, the method further includes: for each of the other clients, the first client uses the public key corresponding to the j-th client to encrypt the sub-mask s(u,vj) corresponding to the j-th client, and then sends the encrypted s(u,vj) to the forwarding server, so that the forwarding server sends the encrypted s(u,vj) to the corresponding j-th client; accordingly, step 303 includes: the first client receives the encrypted sub-masks s(vj,u), generated by each of the other clients and corresponding to the first client, sent by the forwarding server; and the first client decrypts each encrypted sub-mask s(vj,u) using the private key in the homomorphic encryption key pair corresponding to the first client to obtain each sub-mask s(vj,u).
  • the forwarding server mentioned above includes: a cloud server, or a third-party server independent of the cloud server.
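Below is a sketch of the encrypted sub-mask exchange through a forwarding server, assuming the python-paillier package (`phe`) as a stand-in for the unspecified homomorphic encryption scheme and a plain dictionary in place of the forwarding server; these substitutions are assumptions of the sketch, not details fixed by the specification.

```python
from phe import paillier
import secrets

R = 2**61 - 1
clients = ["u", "v1", "v2"]

# Each client generates its own dedicated key pair and publishes the public key.
keys = {c: paillier.generate_paillier_keypair() for c in clients}
public_keys = {c: pk for c, (pk, _) in keys.items()}

# Each client generates one sub-mask per peer and encrypts it with the peer's
# public key before handing it to the forwarding server for relay.
forwarding_server = {}  # (sender, receiver) -> encrypted sub-mask
plain_submasks = {}     # kept locally by the sender only
for sender in clients:
    for receiver in clients:
        if sender == receiver:
            continue
        s = secrets.randbelow(R)
        plain_submasks[(sender, receiver)] = s
        forwarding_server[(sender, receiver)] = public_keys[receiver].encrypt(s)

# Each receiver decrypts the sub-masks addressed to it with its own private key;
# the forwarding server only ever sees ciphertexts.
received = {
    (sender, receiver): keys[receiver][1].decrypt(ct)
    for (sender, receiver), ct in forwarding_server.items()
}
assert received[("v1", "u")] == plain_submasks[("v1", "u")]
```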
  • Step 305 The first client calculates the difference between s(u,v j ) and s(v j ,u) for each variable j, and obtains p(u,v j ) according to the difference.
  • step 305 adopts method 1, including: directly using the calculated difference as p(u,v j ).
  • step 305 adopts method 2, including: computing the calculated difference mod r, and taking the remainder as p(u,vj); wherein mod is a remainder operation, and r is a preset value greater than 1.
  • the number of clients participating in model training may be very large, for example, there are 20,000 clients.
  • each client needs to calculate 19,999 differences, and then add the 19,999 differences in step 307.
  • the value of the result after addition will be very large, and it is likely to exceed the maximum value that the protocol can carry.
  • the subsequent cloud server needs to add the 20,000 masks obtained by the 20,000 clients, and each mask is the sum of the above 19,999 differences. Therefore, even if the value of the mask in one client does not exceed the maximum value that the protocol can carry, the value that the subsequent cloud server needs to calculate may also exceed the maximum value that the protocol can carry.
  • the embodiment of this specification can, in step 305, reduce each difference mod r as soon as it is calculated, so that all the differences are kept within a bounded range, thereby ensuring that the values remain within what the protocol can carry.
  • r can be set to as large a value as practical, so as to constrain all the differences as much as possible; for example, r is a prime number of not less than 200 digits.
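One way to obtain such a prime, assuming the sympy library and reading the requirement as 200 decimal digits; both the library choice and the digit interpretation are assumptions of this sketch.

```python
from sympy import randprime

r = randprime(10**199, 10**200)   # a random prime with exactly 200 decimal digits
assert len(str(r)) == 200
```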
  • the modulo processing does not affect the property that the sum of the masks is less than the predetermined value or equal to 0. Whether or not the differences are reduced mod r, that is, whether method 1 or method 2 is adopted, the effect of making the sum of all masks of all clients less than the predetermined value or equal to 0 is the same.
  • Step 307 The first client calculates ∑j p(u,vj), with j ranging from 1 to N, and uses the calculated result as the mask corresponding to the first client.
  • the first client needs to calculate the sum of 99 p(u,v j ) and use the sum as the mask corresponding to the first client.
  • As described above, in step 205, the first client adds the gradient obtained in this round of training to the mask corresponding to the first client to obtain the encrypted gradient. For example, if the gradient obtained by the first client in this round of training is x(u) and the mask corresponding to the first client is the ∑v p(u,v) obtained in step 307, then in step 205 the first client calculates y(u) = x(u) + ∑v p(u,v) and sends y(u) to the cloud server.
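A small numeric check of the masking and cancellation described above, assuming three clients, integer-valued toy gradients, and method 2 (reduction mod r); in practice the gradient is a vector and the same operation is applied element-wise.

```python
import secrets

R = 2**61 - 1
clients = ["a", "b", "c"]

# Pairwise sub-masks s(u, v): every ordered pair of distinct clients gets one.
s = {(u, v): secrets.randbelow(R) for u in clients for v in clients if u != v}

def mask(u):
    return sum((s[(u, v)] - s[(v, u)]) % R for v in clients if v != u) % R

# Toy scalar "gradients" x(u).
x = {"a": 5, "b": 7, "c": 11}
y = {u: (x[u] + mask(u)) % R for u in clients}   # what each client sends

aggregated = sum(y.values()) % R                 # what the cloud server computes
assert aggregated == sum(x.values()) % R         # the masks cancel: 23
```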
  • step 207 the first client performs the next round of training until the global model converges.
  • the following describes the processing of cloud servers in model training based on federated machine learning.
  • FIG4 is a flow chart of a model training method based on federated machine learning performed by a cloud server in one embodiment of the present specification. At least two clients and at least one cloud server participate in the model training based on federated machine learning, and the execution subject of the method is the cloud server participating in the federated machine learning. It can be understood that the method can also be executed by any device, equipment, platform, or device cluster with computing and processing capabilities. Referring to FIG4, the method includes steps 401 to 409.
  • Step 401 In each round of training, the cloud server sends the latest global model to each client participating in the model training based on federated machine learning.
  • Step 403 The cloud server receives the encrypted gradient of the global model sent by each client.
  • Step 405 The cloud server adds the gradients of the received encrypted global models to obtain the aggregated gradient.
  • Step 407 The cloud server updates the global model using the aggregated gradients.
  • Step 409 The cloud server performs the next round of training until the global model converges.
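A compact sketch of one cloud-server round (steps 401 to 409), assuming vector-valued gradients handled with NumPy; averaging the aggregated gradient over the number of clients and the learning rate are assumptions of this sketch rather than details fixed by the text.

```python
import numpy as np

def cloud_server_round(global_model, received_gradients, lr=0.1):
    """One round on the cloud server: aggregate the encrypted gradients y(u)
    received from the clients (step 405) and update the global model (step 407).
    The masks hidden inside the y(u) values cancel in the sum, so the aggregate
    equals the sum of the clients' true gradients."""
    aggregated = np.sum(received_gradients, axis=0)
    return global_model - lr * aggregated / len(received_gradients)

w = np.zeros(4)
y_values = [np.array([1.0, 2.0, 0.0, 1.0]), np.array([0.0, 1.0, 1.0, 1.0])]
w = cloud_server_round(w, y_values)   # sent back to the clients for the next round
```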
  • FIG5 is a flow chart of a model training method based on federated machine learning implemented by the client and the cloud server in one embodiment of this specification. Referring to FIG5, the method includes steps 501 to 527.
  • Step 501 Each client generates a dedicated homomorphic encryption key pair corresponding to the client.
  • Step 503 Each client sends the public key in the homomorphic encryption key pair corresponding to the client to the cloud server.
  • Step 505 After receiving the public keys sent by each client, the cloud server broadcasts them to each client, so that each client obtains the public keys corresponding to all clients participating in model training.
  • Step 507 The first client generates each sub-mask s(u,v j ) corresponding to each of the other clients among all the clients.
  • the process performed by the first client is taken as an example.
  • the process performed by the first client is the process performed by each client participating in the model training.
  • Step 509 For the other N clients, the first client uses the public key corresponding to the j-th client to encrypt s(u,v j ) corresponding to the j-th client, and obtains the encrypted sub-mask corresponding to the j-th client; where j is a variable, taking a value from 1 to N, and N is the number of all clients participating in the model training minus 1, and then all N encrypted sub-masks s(u,v j ) are sent to the cloud server.
  • Step 511 The cloud server sends the encrypted sub-masks corresponding to the i-th client sent by all clients to the i-th client; wherein i is a variable with a value ranging from 1 to M; and M is the number of all clients participating in the model training.
  • Step 513 The first client receives each encrypted sub-mask corresponding to itself, and decrypts each encrypted sub-mask using the private key in the dedicated homomorphic encryption key pair corresponding to the first client to obtain N decrypted s(v j ,u).
  • Step 515 For each variable j, the first client calculates p(u,vj) = [s(u,vj) - s(vj,u)] mod r, obtaining N values p(u,vj).
  • Step 517 The first client calculates ∑j p(u,vj) and uses the calculated result as the mask corresponding to the first client.
  • the process from step 501 to step 517 may be performed once when each client is started, and in each subsequent round of training, N masks p(u,v j ) are directly used, that is, the mask used by the first client in each round of training is the same.
  • the process from step 501 to step 517 may be performed once in each round of training, so that the mask used by the first client in each round of training is different, further improving security.
  • Step 519 In each round of training, the first client receives the global model sent by the cloud server.
  • Step 521 The first client uses local private data to train the gradient of the global model, which is recorded as x(u).
  • Step 523 The first client calculates the encrypted gradient y(u) = x(u) + ∑v p(u,v), and then sends y(u) to the cloud server.
  • Step 525 The cloud server obtains the M values y(u)i sent by all the clients and calculates the aggregated gradient of this round, T = ∑i y(u)i, where i is a variable ranging from 1 to M and M is the number of all clients participating in the model training.
  • Step 527 The cloud server uses the aggregated gradient T obtained in this round of training to update the global model for use by all clients in the next round of training until the global model converges.
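To tie the steps together, the following is a self-contained toy simulation of one reading of steps 501 to 527: it uses method 1 (differences used directly, no modulus), a scalar mask broadcast over the gradient vector, plain relay of sub-masks instead of homomorphic encryption, and a simple quadratic objective; all of these simplifications are assumptions of this sketch, not requirements of the specification.

```python
import numpy as np

rng = np.random.default_rng(0)
M, DIM, ROUNDS, LR = 3, 2, 50, 0.1

# Local private data: each client holds points whose mean drives its gradient
# of the objective 0.5 * ||w - x||^2 averaged over its local samples.
local_data = [rng.normal(loc=i, size=(20, DIM)) for i in range(M)]

# Steps 507-517 (simplified): pairwise sub-masks exchanged in the clear,
# masks computed as sums of differences, so that all masks sum to zero.
s = {(u, v): rng.normal() for u in range(M) for v in range(M) if u != v}
masks = [sum(s[(u, v)] - s[(v, u)] for v in range(M) if v != u) for u in range(M)]
assert abs(sum(masks)) < 1e-9

w = np.zeros(DIM)                                   # global model
for _ in range(ROUNDS):
    encrypted = []
    for u in range(M):                              # steps 519-523 on each client
        grad = w - local_data[u].mean(axis=0)       # local gradient x(u)
        encrypted.append(grad + masks[u])           # y(u) = x(u) + mask(u)
    aggregated = np.sum(encrypted, axis=0)          # step 525 on the cloud server
    w = w - LR * aggregated / M                     # step 527: update and redistribute

print(w)  # approaches the mean of all clients' data without exposing any x(u)
```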
  • the embodiments of this specification also provide a business prediction method, which includes: using the trained global model to perform business prediction, such as identifying risky users.
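A minimal illustration of using the trained global model for business prediction; the logistic form of the model, the feature vector, and the decision threshold are assumptions of this sketch.

```python
import numpy as np

def predict_risky_user(w, features, threshold=0.5):
    """Score a user with the trained global model w (assumed here to be the
    weight vector of a logistic model) and flag it as risky above a threshold."""
    score = 1.0 / (1.0 + np.exp(-np.dot(w, features)))
    return score, score > threshold

w_trained = np.array([0.8, -0.3, 1.2])   # stand-in for the converged global model
user = np.array([1.0, 0.5, 0.2])         # stand-in feature vector for one user
score, is_risky = predict_risky_user(w_trained, user)
```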
  • the embodiment of this specification also proposes a model training device based on federated machine learning, at least two clients and at least one cloud server participate in the model training based on federated machine learning, and the device is applied to any first client among the at least two clients, referring to Figure 6, the device includes: a global model acquisition module 601, configured to receive the global model sent by the cloud server in each round of training; a gradient acquisition module 602, configured to use local private data to train the gradient of the global model in each round of training; an encryption module 603, configured to encrypt the gradient obtained in each round of training, and then send the encrypted gradient to the cloud server; each module executes the next round of training until the global model converges.
  • the device of this specification further comprises: a mask acquisition module 701; the mask acquisition module 701 is configured to obtain a mask corresponding to the first client where the device is located; wherein the sum of all masks corresponding to all clients participating in the model training is less than a predetermined value; the encryption module 603 is configured, when encrypting, to add the gradient obtained in this round of training to the mask corresponding to the first client to obtain an encrypted gradient.
  • the mask acquisition module 701 is configured to execute: obtain each sub-mask s(u,vj) generated by the first client and corresponding to each other client among all the clients; obtain each sub-mask s(vj,u) generated by each other client and corresponding to the first client; wherein j is a variable, and its value ranges from 1 to N; N is the number of all clients participating in the model training minus 1; u represents the first client, and vj represents the jth client among all the clients participating in the model training except the first client; for each variable j, respectively calculate the difference between s(u,vj) and s(vj,u), and obtain p(u,vj) according to the difference; calculate ∑j p(u,vj) and use the calculated result as the mask corresponding to the first client.
  • the mask acquisition module 701 is configured to execute: directly taking the difference as the p(u, v j ); or, calculating the difference mod r, and taking the calculated remainder as the p(u, v j ); wherein mod is a remainder operation, and r is a preset value greater than 1.
  • the mask acquisition module 701 is further configured to execute: generate a homomorphic encryption key pair corresponding to the first client; send the public key in the homomorphic encryption key pair corresponding to the first client to the forwarding server; and receive the public key corresponding to each other client among all the clients sent by the forwarding server; accordingly, the mask acquisition module 701 is configured to execute: after obtaining the sub-masks s(u,vj) generated by the first client and corresponding to each other client among all the clients, for each of the other clients, use the public key corresponding to the j-th client to encrypt the sub-mask s(u,vj) corresponding to the j-th client, and then send the encrypted s(u,vj) to the forwarding server; receive the encrypted sub-masks s(vj,u), generated by each other client and corresponding to the first client, sent by the forwarding server; and use the private key in the homomorphic encryption key pair corresponding to the first client to decrypt each encrypted sub-mask s(vj,u), obtaining each sub-mask s(vj,u).
  • the forwarding server includes: the cloud server, or a third-party server independent of the cloud server.
  • a model training device based on federated machine learning is proposed. At least two clients and at least one cloud server participate in the model training based on federated machine learning.
  • the device is applied to the cloud server. See Figure 8.
  • the device includes: a global model sending module 801, configured to send the latest global model to each client participating in the model training based on federated machine learning in each round of training; a gradient receiving module 802, configured to receive the encrypted gradient of the global model sent by each client in each round of training; a gradient aggregation module 803, configured to add the encrypted gradients of the global model received in each round of training to obtain an aggregated gradient.
  • the global model updating module 804 is configured to update the global model using the aggregated gradient in each round of training; each module performs the next round of training until the global model converges.
  • One embodiment of the present specification provides a computer-readable storage medium having a computer program stored thereon.
  • the computer program When the computer program is executed in a computer, the computer is caused to execute a method in any one of the embodiments of the present specification.
  • An embodiment of the present specification provides a computing device, including a memory and a processor, wherein the memory stores executable code, and when the processor executes the executable code, the method in any embodiment of the present specification is implemented.
  • the structures illustrated in the embodiments of this specification do not constitute specific limitations on the devices of the embodiments of this specification.
  • the above-mentioned device may include more or fewer components than those shown in the figure, or combine certain components, or split certain components, or arrange the components differently.
  • the components shown in the figure may be implemented in hardware, software, or a combination of software and hardware.

Landscapes

  • Engineering & Computer Science (AREA)
  • Computer Security & Cryptography (AREA)
  • Computer Networks & Wireless Communication (AREA)
  • Signal Processing (AREA)
  • Data Mining & Analysis (AREA)
  • Theoretical Computer Science (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Artificial Intelligence (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Evolutionary Biology (AREA)
  • Evolutionary Computation (AREA)
  • Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Computer And Data Communications (AREA)
  • Management, Administration, Business Operations System, And Electronic Commerce (AREA)

Abstract

Embodiments of this specification provide a model training method and apparatus based on federated machine learning. At least two clients and at least one cloud server participate in the model training based on federated machine learning. In each round of training, a first client receives a global model sent by the cloud server; the first client trains a gradient of the global model using local private data; the first client encrypts the gradient obtained in this round of training and then sends the encrypted gradient to the cloud server; the first client performs the next round of training until the global model converges. The embodiments of this specification can improve the security of model training.

Description

Model training method and apparatus based on federated machine learning
Technical Field
One or more embodiments of this specification relate to computer technology, and in particular to a model training method and apparatus based on federated machine learning.
Background Art
Federated machine learning is a distributed machine learning framework with a privacy-protection effect. It can effectively help multiple clients use data and perform machine learning modeling while meeting the requirements of privacy protection, data security, and government regulation. As a distributed machine learning paradigm, federated machine learning can effectively solve the problem of data silos, allowing the clients to model jointly without sharing their local data, achieve intelligent collaboration, and jointly train a global model with good performance.
When training a model based on federated machine learning, in each round of training the central cloud server sends the global model to each client; each client trains the gradients of the model parameters with its private local data and then passes the gradient trained in this round to the cloud server. After collecting the gradients from all parties, the cloud server computes the average gradient and uses it to update the global model on the cloud server; in the next round of training, the updated global model is sent to each client.
It can be seen that, in the training of a global model based on federated machine learning, each client needs to send the gradient it has trained to the cloud server. In many attack scenarios, the gradient information sent by a client to the cloud server can be used to recover the original private data stored locally by that client, leading to leakage of private data; the user's privacy cannot be protected and security is poor.
Summary of the Invention
One or more embodiments of this specification describe a model training method and apparatus based on federated machine learning, which can improve the security of model training.
According to a first aspect, a model training method based on federated machine learning is provided, wherein at least two clients and at least one cloud server participate in the model training based on federated machine learning, and the method is applied to any first client among the at least two clients and includes: in each round of training, the first client receives a global model sent by the cloud server; the first client trains a gradient of the global model using local private data; the first client encrypts the gradient obtained in this round of training and then sends the encrypted gradient to the cloud server; the first client performs the next round of training until the global model converges.
The method further includes: the first client obtains a mask corresponding to the first client, wherein the sum of all masks corresponding to all clients participating in the model training is less than a predetermined value; and the first client encrypting the gradient obtained in this round of training includes: the first client adds the gradient obtained in this round of training to the mask corresponding to the first client to obtain the encrypted gradient.
The sum of all masks corresponding to all of the clients is 0.
The first client obtaining the mask corresponding to the first client includes: the first client obtains each sub-mask s(u,vj) generated by the first client and corresponding to each of the other clients among all the clients; the first client obtains each sub-mask s(vj,u) generated by each of the other clients and corresponding to the first client, wherein j is a variable with a value from 1 to N, N is the number of all clients participating in the model training minus 1, u denotes the first client, and vj denotes the j-th client among all the clients participating in the model training other than the first client; for each variable j, the first client calculates the difference between s(u,vj) and s(vj,u) and obtains p(u,vj) according to the difference; the first client calculates ∑j p(u,vj) and uses the calculated result as the mask corresponding to the first client.
Obtaining p(u,vj) according to the difference includes: directly using the difference as p(u,vj); or calculating the difference mod r and using the remainder as p(u,vj), wherein mod is a remainder operation and r is a preset value greater than 1.
The r is a prime number of not less than 200 digits.
The method further includes: the first client generates a homomorphic encryption key pair corresponding to the first client; the first client sends the public key in the homomorphic encryption key pair corresponding to the first client to a forwarding server; and the first client receives the public keys, sent by the forwarding server, corresponding to each of the other clients among all the clients. Accordingly, after the first client obtains each sub-mask s(u,vj) generated by the first client and corresponding to each of the other clients, the method further includes: for each of the other clients, the first client encrypts the sub-mask s(u,vj) corresponding to the j-th client using the public key corresponding to the j-th client and then sends the encrypted s(u,vj) to the forwarding server. Accordingly, the first client obtaining each sub-mask s(vj,u) generated by each of the other clients and corresponding to the first client includes: the first client receives the encrypted sub-masks s(vj,u), generated by each of the other clients and corresponding to the first client, sent by the forwarding server; and the first client decrypts each encrypted sub-mask s(vj,u) using the private key in the homomorphic encryption key pair corresponding to the first client to obtain each sub-mask s(vj,u).
The forwarding server includes: the cloud server, or a third-party server independent of the cloud server.
According to a second aspect, a model training method based on federated machine learning is provided, wherein at least two clients and at least one cloud server participate in the model training based on federated machine learning, and the method is applied to the cloud server and includes: in each round of training, the cloud server sends the latest global model to each client participating in the model training based on federated machine learning; the cloud server receives the encrypted gradient of the global model sent by each client; the cloud server adds the received encrypted gradients of the global model to obtain an aggregated gradient; the cloud server updates the global model using the aggregated gradient; the cloud server performs the next round of training until the global model converges.
According to a third aspect, a model training apparatus based on federated machine learning is provided, wherein at least two clients and at least one cloud server participate in the model training based on federated machine learning, and the apparatus is applied to any first client among the at least two clients and includes: a global model acquisition module, configured to receive, in each round of training, the global model sent by the cloud server; a gradient acquisition module, configured to train, in each round of training, a gradient of the global model using local private data; and an encryption module, configured to encrypt, in each round of training, the gradient obtained in this round of training and then send the encrypted gradient to the cloud server; the modules perform the next round of training until the global model converges.
According to a fourth aspect, a model training apparatus based on federated machine learning is provided, wherein at least two clients and at least one cloud server participate in the model training based on federated machine learning, and the apparatus is applied to the cloud server and includes: a global model sending module, configured to send, in each round of training, the latest global model to each client participating in the model training based on federated machine learning; a gradient receiving module, configured to receive, in each round of training, the encrypted gradient of the global model sent by each client; a gradient aggregation module, configured to add, in each round of training, the received encrypted gradients of the global model to obtain an aggregated gradient; and a global model updating module, configured to update, in each round of training, the global model using the aggregated gradient; the modules perform the next round of training until the global model converges.
According to a fifth aspect, a computing device is provided, including a memory and a processor, wherein the memory stores executable code, and when the processor executes the executable code, the method described in any embodiment of this specification is implemented.
The methods and apparatuses provided by the embodiments of this specification can, individually or in combination, achieve the following beneficial effects. 1. After obtaining the gradient, the client does not send the gradient information directly to the cloud server; it first encrypts the gradient and sends the encrypted information to the cloud server. In this way, what the cloud server obtains from each client is the encrypted gradient rather than the original gradient; that is, the cloud server can obtain only the aggregated gradient and not the gradient of any individual client, which improves security. For example, an attacker cannot steal the original gradients from the transmission link between a client and the cloud server or from the cloud server itself, and therefore cannot recover the private data in the terminal device where the client is located by means such as generative adversarial networks (GANs). Each client keeps its privacy in its own hands, which greatly improves security.
2. Homomorphic encryption is used to encrypt the sub-masks during secret sharing; that is, each client does not send the plaintext of a sub-mask to the forwarding server, but sends the sub-mask encrypted with the public key of a homomorphic encryption key pair, which further improves security.
3. Compared with a sub-mask acquisition scheme in which clients exchange sub-masks pairwise with each other, the embodiments of this specification encrypt the sub-masks shared during secret sharing with homomorphic encryption and can rely on the central cloud server or a third-party server as an intermediary to relay them, avoiding the sub-mask leakage caused by pairwise exchange between clients and thereby further improving security.
4. When the difference between two sub-masks is calculated, the difference is reduced by a remainder operation, and the remainder is used to obtain the mask corresponding to the client, which ensures that the numerical range of the computed mask does not exceed the maximum value the protocol can carry. This broadens the applicability of the embodiments of this specification; for example, the model training of the embodiments of this specification can still be carried out when the number of clients participating in the model training based on federated machine learning is very large.
Brief Description of the Drawings
To describe the technical solutions of the embodiments of this specification or of the prior art more clearly, the drawings needed in the description of the embodiments or the prior art are briefly introduced below. Obviously, the drawings described below are only some embodiments of this specification, and a person of ordinary skill in the art can obtain other drawings from them without creative effort.
FIG. 1 is a schematic diagram of a system architecture to which an embodiment of this specification is applied.
FIG. 2 is a flow chart of a model training method based on federated machine learning executed by a client in one embodiment of this specification.
FIG. 3 is a flow chart of a method by which a first client obtains the mask corresponding to the first client in one embodiment of this specification.
FIG. 4 is a flow chart of a model training method based on federated machine learning executed by a cloud server in one embodiment of this specification.
FIG. 5 is a flow chart of a model training method based on federated machine learning implemented jointly by a client and a cloud server in one embodiment of this specification.
FIG. 6 is a schematic structural diagram of a model training apparatus based on federated machine learning applied in a client in one embodiment of this specification.
FIG. 7 is a schematic structural diagram of a model training apparatus based on federated machine learning applied in a client in one embodiment of this specification.
FIG. 8 is a schematic structural diagram of a model training apparatus based on federated machine learning applied in a cloud server in one embodiment of this specification.
Detailed Description
As stated above, each client needs to send the gradient it has trained to the cloud server. In many attack scenarios, an attacker can use the gradient information sent by a client to the cloud server to recover the original private data in the terminal device where the client is located, for example by means of generative adversarial networks (GANs). As another example, the central cloud server receives the gradient information of each individual client; generally the central cloud server is considered reliable, but when the central cloud server unintentionally loses data or colludes with other clients, the clients' private data will be leaked. The client cannot keep its privacy in its own hands.
The solutions provided by this specification are described below with reference to the drawings.
To facilitate understanding of this specification, the system architecture to which it is applied is described first. As shown in FIG. 1, the system architecture mainly includes M clients participating in federated machine learning and a cloud server, where M is a positive integer greater than 1. Each client interacts with the cloud server through a network, and the network may include various connection types, such as wired or wireless communication links or optical fiber cables.
The M clients are located in M terminal devices, respectively. Each client may be located in any terminal device that performs modeling through federated machine learning, such as a bank device, a payment device, or a mobile terminal, and the cloud server may be located in the cloud.
The method of the embodiments of this specification involves processing on the client side and processing on the cloud server side, which are described separately below.
The model training method executed in a client is described first.
FIG. 2 is a flow chart of a model training method based on federated machine learning executed by a client in one embodiment of this specification. The execution subject of the method is each client participating in the federated machine learning. It can be understood that the method may also be executed by any apparatus, device, platform, or device cluster with computing and processing capabilities. Referring to FIG. 2, the method includes steps 201 to 207.
Step 201: In each round of training, the first client receives the global model sent by the cloud server.
Step 203: The first client trains a gradient of the global model using local private data.
Step 205: The first client encrypts the gradient obtained in this round of training and then sends the encrypted gradient to the cloud server.
Step 207: The first client performs the next round of training until the global model converges.
As can be seen from the flow shown in FIG. 2, in the method provided by the embodiments of this specification, after obtaining the gradient, the client does not send the gradient information directly to the cloud server; it first encrypts the gradient and sends the encrypted information to the cloud server. In this way, what the cloud server obtains from each client is the encrypted gradient rather than the original gradient, which improves security. For example, an attacker cannot steal the original gradients from the transmission link between a client and the cloud server or from the cloud server, and therefore cannot recover the private data in the terminal device where the client is located by means such as generative adversarial networks (GANs). The client keeps its privacy in its own hands, which greatly improves security.
The method of the embodiments of this specification can be applied to various business scenarios in which models are trained based on federated machine learning, such as Alipay's "Ant Forest" product and QR-code-scanning image risk control.
Each step in FIG. 2 is described below with reference to specific embodiments.
First, step 201: in each round of training, the first client receives the global model sent by the cloud server.
For ease of description and to better distinguish the client currently being processed from the other clients, the client executing the model training method of FIG. 2 is denoted as the first client. It can be understood that, in the embodiments of this specification, the first client is each client participating in the model training based on federated machine learning; that is, each client participating in the model training based on federated machine learning needs to execute the model training method described with reference to FIG. 2.
Next, step 203: the first client trains a gradient of the global model using local private data.
Next, step 205: the first client encrypts the gradient obtained in this round of training and then sends the encrypted gradient to the cloud server.
The method of the embodiments of this specification needs to satisfy the following two requirements. 1. Security: to satisfy this requirement, a client cannot send the plaintext of the gradient it has trained directly to the cloud server; it sends the ciphertext of the gradient instead. 2. Availability: to perform model training, the cloud server needs to obtain the aggregated result of the clients' gradients, and this aggregated result must be equal or close to the aggregation of the original gradients so that model training can proceed properly. In other words, although the cloud server cannot directly obtain the plaintext of each gradient, the gradient aggregation result it obtains must be equal or close to the aggregation of the original gradients. Therefore, the encryption performed by all clients participating in the model training must ensure that the sum of all the masks added to the gradients cancels, or nearly cancels, out. A simple example illustrates the idea: to obtain a result Y, one way of computing it is Y = X1 + X2, and another is Y = (X1 + S) + (X2 - S). To satisfy requirement 2, the method of the embodiments of this specification uses the latter approach.
Accordingly, in one embodiment of this specification, before step 205 the method further includes step A: the first client obtains the mask corresponding to the first client.
It should be noted that the sum of all masks corresponding to all clients participating in the model training is less than a predetermined value. Further, the sum of all masks corresponding to all of the clients is 0. Because the sum of all the masks is less than the predetermined value, and may even be 0, encrypting the gradients with the masks has little effect, or even no effect, on the value of the sum of the clients' gradients. Thus, step 205 is implemented as: the first client adds the gradient obtained in this round of training to the mask corresponding to the first client to obtain the encrypted gradient.
Each client has its own corresponding mask. For example, if 100 clients participate in the model training method based on federated machine learning, each of them obtains its own corresponding mask. To further improve security, different clients have different masks.
In one embodiment of this specification, referring to FIG. 3, one implementation of step A, in which the first client obtains the mask corresponding to the first client, includes steps 301 to 307.
Step 301: The first client obtains each sub-mask s(u,vj) generated by the first client and corresponding to each of the other clients among all the clients.
For example, if 100 clients participate in the model training method based on federated machine learning, the first client generates, for the other 99 clients, 99 sub-masks s(u,vj) corresponding to them respectively. For example, s(u,v1) denotes the sub-mask generated by the first client and corresponding to client 1 among the other 99 clients; likewise, s(u,v2) denotes the sub-mask generated by the first client and corresponding to client 2 among the other 99 clients; and so on, until s(u,v99) denotes the sub-mask generated by the first client and corresponding to client 99.
Step 303: The first client obtains each sub-mask s(vj,u) generated by each of the other clients and corresponding to the first client, where j is a variable with a value from 1 to N, N is the number of all clients participating in the model training minus 1, u denotes the first client, and vj denotes the j-th client among all the clients participating in the model training other than the first client.
All clients participating in the model training method based on federated machine learning perform the processing of step 301, so each of the other clients also generates a sub-mask corresponding to the first client. In step 303, the first client needs to obtain all the sub-masks s(vj,u) generated by the other clients and corresponding to the first client.
For example, if 100 clients participate in the model training method based on federated machine learning, the first client needs to obtain the 99 sub-masks s(vj,u), corresponding to the first client, generated by the other 99 clients. Here, s(v1,u) denotes the sub-mask generated by client 1 among the other 99 clients and corresponding to the first client; s(v2,u) denotes the sub-mask generated by client 2 among the other 99 clients and corresponding to the first client; and so on, until s(v99,u) denotes the sub-mask generated by client 99 among the other 99 clients and corresponding to the first client.
In this example, after step 303 the first client has obtained the 99 sub-masks it generated for the other 99 clients and the 99 sub-masks generated by the other 99 clients for the first client, 198 sub-masks in total.
In order for every client participating in the model training to obtain the sub-masks generated for it by each of the other clients, after step 301 the first client needs to send all the sub-masks it generated to the cloud server or a third-party server, which forwards them to the corresponding clients after receiving them. However, if the cloud server or third-party server obtained the plaintext of the sub-masks, the gradient plaintext might later be derived from the sub-masks. Therefore, to further increase security, in one embodiment of this specification the sub-masks can be encrypted, and only encrypted sub-masks are sent to the cloud server or third-party server. In this way, the cloud server or third-party server can obtain neither the gradient plaintext of any client nor the plaintext of the sub-masks generated by any client, which greatly improves security.
To achieve the effect that the cloud server or third-party server cannot obtain the sub-mask plaintext, the method further includes: the first client generates a homomorphic encryption key pair corresponding to the first client; the homomorphic encryption key pair corresponding to the first client is dedicated to the first client rather than shared by all clients, so different clients have different homomorphic encryption key pairs; the first client sends the public key in the homomorphic encryption key pair corresponding to the first client to the forwarding server; and the first client receives the public keys, sent by the forwarding server, corresponding to each of the other clients. Accordingly, after step 301 the method further includes: for each of the other clients, the first client encrypts the sub-mask s(u,vj) corresponding to the j-th client with the public key corresponding to the j-th client and then sends the encrypted s(u,vj) to the forwarding server, so that the forwarding server can send the encrypted s(u,vj) to the corresponding j-th client. Accordingly, step 303 includes: the first client receives the encrypted sub-masks s(vj,u), generated by each of the other clients and corresponding to the first client, sent by the forwarding server; and the first client decrypts each encrypted sub-mask s(vj,u) with the private key in the homomorphic encryption key pair corresponding to the first client to obtain each sub-mask s(vj,u).
The forwarding server includes: the cloud server, or a third-party server independent of the cloud server.
Step 305: For each variable j, the first client calculates the difference between s(u,vj) and s(vj,u) and obtains p(u,vj) according to the difference.
For example, if 100 clients participate in the model training method based on federated machine learning, i.e. j ranges over 99 values, then 99 differences need to be calculated in step 305: for client 1 among the other 99 clients, the difference between s(u,v1) and s(v1,u); for client 2, the difference between s(u,v2) and s(v2,u); and so on, until for client 99 the difference between s(u,v99) and s(v99,u).
It should be noted that, when calculating the difference between s(u,v1) and s(v1,u), either value may serve as the minuend or the subtrahend, as long as all clients use the same convention when calculating all such differences, for example each client always takes the s(u,vj) it generated as the subtrahend and the s(vj,u) generated by the j-th client as the minuend.
In one embodiment of this specification, step 305 is implemented in manner 1: the calculated difference is used directly as p(u,vj).
Alternatively, in another embodiment of this specification, step 305 is implemented in manner 2: the calculated difference is taken mod r, and the remainder is used as p(u,vj), where mod is a remainder operation and r is a preset value greater than 1.
In an actual service, the number of clients participating in the model training may be very large, for example 20,000 clients. Then, according to step 305, each client needs to compute 19,999 differences and, in step 307, add these 19,999 differences; the resulting value can be very large and may exceed the maximum value the protocol can carry. Moreover, the cloud server subsequently needs to add the 20,000 masks obtained by the 20,000 clients, and each mask is itself the sum of 19,999 differences, so even if the mask value within a single client does not exceed the maximum value the protocol can carry, the value the cloud server subsequently needs to compute may exceed it. Therefore, to further avoid numerical overflow when the number of clients participating in the model training is very large, in step 305 of the embodiments of this specification each difference may be reduced mod r as soon as it is computed, so that all the differences are kept within a bounded range and the values remain within what the protocol can carry. Here, r can be set to as large a value as practical so as to constrain all the differences as much as possible; for example, r is a prime number of not less than 200 digits.
It can be understood that the remainder operation does not affect the property that the sum of the masks is less than the predetermined value or equal to 0. Whether or not the differences are reduced mod r, that is, whether manner 1 or manner 2 is adopted, the effect of making the sum of all masks of all clients less than the predetermined value or equal to 0 is the same.
Step 307: The first client calculates ∑j p(u,vj) and uses the calculated result as the mask corresponding to the first client.
For example, if 100 clients participate in the model training method based on federated machine learning, i.e. j ranges over 99 values, then according to step 307 the first client needs to calculate the sum of the 99 values p(u,vj) and use that sum as the mask corresponding to the first client.
As can be seen from the flow shown in FIG. 3, the mask corresponding to the first client is obtained from the sum of all p(u,vj), and each p(u,vj) is obtained from the difference between s(u,vj) and s(vj,u). Consequently, if the masks p(u,vj) of all clients are added together, the positive and negative mask values cancel each other out, eliminating the effect of encrypting the gradients with the masks.
As described above, in step 205 the first client adds the gradient obtained in this round of training to the mask corresponding to the first client to obtain the encrypted gradient. For example, if in this round of training the gradient obtained by the first client is x(u) and the mask corresponding to the first client is the ∑v p(u,v) obtained in step 307, then in step 205 the first client calculates y(u) = x(u) + ∑v p(u,v) and sends y(u) to the cloud server.
Next, step 207 is performed: the first client performs the next round of training until the global model converges.
The processing of the cloud server in the model training based on federated machine learning is described below.
FIG. 4 is a flow chart of a model training method based on federated machine learning executed by a cloud server in one embodiment of this specification. At least two clients and at least one cloud server participate in the model training based on federated machine learning, and the execution subject of the method is the cloud server participating in the federated machine learning. It can be understood that the method may also be executed by any apparatus, device, platform, or device cluster with computing and processing capabilities. Referring to FIG. 4, the method includes steps 401 to 409.
Step 401: In each round of training, the cloud server sends the latest global model to each client participating in the model training based on federated machine learning.
Step 403: The cloud server receives the encrypted gradient of the global model sent by each client.
Step 405: The cloud server adds the received encrypted gradients of the global model to obtain the aggregated gradient.
Step 407: The cloud server updates the global model using the aggregated gradient.
Step 409: The cloud server performs the next round of training until the global model converges.
For a description of the processing performed by the cloud server, reference may also be made to the descriptions of the embodiments of this specification given in connection with FIG. 2, FIG. 3, and FIG. 5.
The model training method based on federated machine learning in one embodiment of this specification is described below in terms of the combined processing of the clients and the cloud server. FIG. 5 is a flow chart of a model training method based on federated machine learning implemented jointly by the clients and the cloud server in one embodiment of this specification. Referring to FIG. 5, the method includes steps 501 to 527.
Step 501: Each client generates a dedicated homomorphic encryption key pair corresponding to that client.
Step 503: Each client sends the public key in its corresponding homomorphic encryption key pair to the cloud server.
Step 505: After receiving the public keys sent by the clients, the cloud server broadcasts them to all clients, so that each client obtains the public keys corresponding to all clients participating in the model training.
Step 507: The first client generates each sub-mask s(u,vj) corresponding to each of the other clients among all the clients.
In the following steps, for ease of description, the processing performed by the first client is taken as an example; the processing performed by the first client is the processing performed by each client participating in the model training.
Step 509: For the other N clients, the first client encrypts the s(u,vj) corresponding to the j-th client with the public key corresponding to the j-th client, obtaining the encrypted sub-mask corresponding to the j-th client, where j is a variable with a value from 1 to N and N is the number of all clients participating in the model training minus 1, and then sends all N encrypted sub-masks s(u,vj) to the cloud server.
Step 511: The cloud server sends the encrypted sub-masks, sent by all clients and corresponding to the i-th client, to the i-th client, where i is a variable with a value from 1 to M, and M is the number of all clients participating in the model training.
Step 513: The first client receives each encrypted sub-mask corresponding to itself and decrypts each encrypted sub-mask with the private key in the dedicated homomorphic encryption key pair corresponding to the first client, obtaining N decrypted values s(vj,u).
Step 515: For each variable j, the first client calculates p(u,vj) = [s(u,vj) - s(vj,u)] mod r, obtaining N values p(u,vj).
Step 517: The first client calculates ∑j p(u,vj) and uses the calculated result as the mask corresponding to the first client.
The process of steps 501 to 517 may be performed once when each client starts up, with the N values p(u,vj) used directly in every subsequent round of training, that is, the mask used by the first client is the same in every round. Alternatively, the process of steps 501 to 517 may be performed once in every round of training, so that the mask used by the first client differs from round to round, which further improves security.
Step 519: In each round of training, the first client receives the global model sent by the cloud server.
Step 521: The first client trains a gradient of the global model using local private data, denoted x(u).
Step 523: The first client calculates the encrypted gradient y(u) = x(u) + ∑v p(u,v) and then sends y(u) to the cloud server.
Step 525: The cloud server obtains the M values y(u)i sent by all clients and calculates the aggregated gradient of this round, T = ∑i y(u)i, where i is a variable and M is the number of all clients participating in the model training.
Step 527: The cloud server updates the global model using the aggregated gradient T obtained in this round of training, for use by all clients in the next round of training, until the global model converges.
At this point, the global model has been obtained.
The embodiments of this specification further provide a service prediction method, which includes: performing service prediction, such as identifying risky users, using the trained global model.
The embodiments of this specification further provide a model training apparatus based on federated machine learning, wherein at least two clients and at least one cloud server participate in the model training based on federated machine learning, and the apparatus is applied to any first client among the at least two clients. Referring to FIG. 6, the apparatus includes: a global model acquisition module 601, configured to receive, in each round of training, the global model sent by the cloud server; a gradient acquisition module 602, configured to train, in each round of training, a gradient of the global model using local private data; and an encryption module 603, configured to encrypt, in each round of training, the gradient obtained in this round of training and then send the encrypted gradient to the cloud server; the modules perform the next round of training until the global model converges.
In an embodiment of the apparatus of this specification, referring to FIG. 7, the apparatus further includes a mask acquisition module 701, configured to obtain the mask corresponding to the first client where the apparatus is located, wherein the sum of all masks corresponding to all clients participating in the model training is less than a predetermined value; and the encryption module 603 is configured, when encrypting, to add the gradient obtained in this round of training to the mask corresponding to the first client to obtain the encrypted gradient.
In the apparatus embodiments of this specification shown in FIG. 6 and FIG. 7, the sum of all masks corresponding to all of the clients is 0.
In the apparatus embodiment of this specification shown in FIG. 7, the mask acquisition module 701 is configured to: obtain each sub-mask s(u,vj) generated by the first client where the apparatus is located and corresponding to each of the other clients among all the clients; obtain each sub-mask s(vj,u) generated by each of the other clients and corresponding to the first client, where j is a variable with a value from 1 to N, N is the number of all clients participating in the model training minus 1, u denotes the first client, and vj denotes the j-th client among all the clients participating in the model training other than the first client; for each variable j, calculate the difference between s(u,vj) and s(vj,u) and obtain p(u,vj) according to the difference; and calculate ∑j p(u,vj) and use the calculated result as the mask corresponding to the first client.
In the apparatus embodiment of this specification shown in FIG. 7, the mask acquisition module 701 is configured to: use the difference directly as p(u,vj); or calculate the difference mod r and use the remainder as p(u,vj), where mod is a remainder operation and r is a preset value greater than 1.
In the apparatus embodiment of this specification shown in FIG. 7, r is a prime number of not less than 200 digits.
In the apparatus embodiment of this specification shown in FIG. 7, the mask acquisition module 701 is further configured to: generate a homomorphic encryption key pair corresponding to the first client; send the public key in the homomorphic encryption key pair corresponding to the first client to the forwarding server; and receive the public keys, sent by the forwarding server, corresponding to each of the other clients among all the clients. Accordingly, the mask acquisition module 701 is configured to: after obtaining each sub-mask s(u,vj) generated by the first client and corresponding to each of the other clients, encrypt, for each of the other clients, the sub-mask s(u,vj) corresponding to the j-th client with the public key corresponding to the j-th client and then send the encrypted s(u,vj) to the forwarding server; receive the encrypted sub-masks s(vj,u), generated by each of the other clients and corresponding to the first client, sent by the forwarding server; and decrypt each encrypted sub-mask s(vj,u) with the private key in the homomorphic encryption key pair corresponding to the first client to obtain each sub-mask s(vj,u).
The forwarding server includes: the cloud server, or a third-party server independent of the cloud server.
One embodiment of this specification provides a model training apparatus based on federated machine learning, wherein at least two clients and at least one cloud server participate in the model training based on federated machine learning and the apparatus is applied to the cloud server. Referring to FIG. 8, the apparatus includes: a global model sending module 801, configured to send, in each round of training, the latest global model to each client participating in the model training based on federated machine learning; a gradient receiving module 802, configured to receive, in each round of training, the encrypted gradient of the global model sent by each client; a gradient aggregation module 803, configured to add, in each round of training, the received encrypted gradients of the global model to obtain an aggregated gradient; and a global model updating module 804, configured to update, in each round of training, the global model using the aggregated gradient; the modules perform the next round of training until the global model converges.
One embodiment of this specification provides a computer-readable storage medium on which a computer program is stored; when the computer program is executed in a computer, the computer is caused to perform the method of any one of the embodiments of this specification.
One embodiment of this specification provides a computing device, including a memory and a processor, wherein the memory stores executable code, and when the processor executes the executable code, the method of any one of the embodiments of this specification is implemented.
It can be understood that the structures illustrated in the embodiments of this specification do not constitute specific limitations on the apparatuses of the embodiments of this specification. In other embodiments of this specification, the above apparatuses may include more or fewer components than shown in the figures, or combine certain components, or split certain components, or arrange the components differently. The components shown in the figures may be implemented in hardware, software, or a combination of software and hardware.
Since the information exchange between the modules within the above apparatuses and systems and their execution processes are based on the same concept as the method embodiments of this specification, for specific details reference may be made to the descriptions in the method embodiments of this specification, which are not repeated here.
The embodiments in this specification are described in a progressive manner; for identical or similar parts of the embodiments, reference may be made to one another, and each embodiment focuses on its differences from the other embodiments. In particular, the apparatus embodiments are described relatively briefly because they are substantially similar to the method embodiments, and relevant details can be found in the corresponding descriptions of the method embodiments.
Those skilled in the art should be aware that, in one or more of the above examples, the functions described in this application may be implemented in hardware, software, firmware, or any combination thereof. When implemented in software, these functions may be stored in a computer-readable medium or transmitted as one or more instructions or code on a computer-readable medium.
The specific implementations described above further explain the purpose, technical solutions, and beneficial effects of this application in detail. It should be understood that the above are merely implementations of this application and are not intended to limit the scope of protection of this application; any modification, equivalent replacement, improvement, or the like made on the basis of the technical solutions of this application shall fall within the scope of protection of this application.

Claims (12)

  1. A model training method based on federated machine learning, wherein at least two clients and at least one cloud server participate in the model training based on federated machine learning, and the method is applied to any first client among the at least two clients and comprises:
    in each round of training, the first client receives a global model sent by the cloud server;
    the first client trains a gradient of the global model using local private data;
    the first client encrypts the gradient obtained in this round of training and then sends the encrypted gradient to the cloud server;
    the first client performs the next round of training until the global model converges.
  2. The method according to claim 1, wherein the method further comprises: the first client obtains a mask corresponding to the first client, wherein the sum of all masks corresponding to all clients participating in the model training is less than a predetermined value;
    the first client encrypting the gradient obtained in this round of training comprises:
    the first client adds the gradient obtained in this round of training to the mask corresponding to the first client to obtain the encrypted gradient.
  3. The method according to claim 2, wherein the sum of all masks corresponding to all of the clients is 0.
  4. The method according to claim 3, wherein the first client obtaining the mask corresponding to the first client comprises:
    the first client obtains each sub-mask s(u,vj) generated by the first client and corresponding to each of the other clients among all the clients;
    the first client obtains each sub-mask s(vj,u) generated by each of the other clients and corresponding to the first client, wherein j is a variable with a value from 1 to N, N is the number of all clients participating in the model training minus 1, u denotes the first client, and vj denotes the j-th client among all the clients participating in the model training other than the first client;
    for each variable j, the first client calculates the difference between s(u,vj) and s(vj,u) and obtains p(u,vj) according to the difference;
    the first client calculates ∑j p(u,vj) and uses the calculated result as the mask corresponding to the first client.
  5. The method according to claim 4, wherein obtaining p(u,vj) according to the difference comprises:
    using the difference directly as the p(u,vj);
    or,
    calculating the difference mod r and using the remainder as the p(u,vj), wherein mod is a remainder operation and r is a preset value greater than 1.
  6. The method according to claim 5, wherein the r is a prime number of not less than 200 digits.
  7. The method according to claim 4, wherein
    the method further comprises: the first client generates a homomorphic encryption key pair corresponding to the first client; the first client sends the public key in the homomorphic encryption key pair corresponding to the first client to a forwarding server; and the first client receives the public keys, sent by the forwarding server, corresponding to each of the other clients among all the clients;
    accordingly, after the first client obtains each sub-mask s(u,vj) generated by the first client and corresponding to each of the other clients among all the clients, the method further comprises: for each of the other clients, the first client encrypts the sub-mask s(u,vj) corresponding to the j-th client with the public key corresponding to the j-th client and then sends the encrypted s(u,vj) to the forwarding server;
    accordingly, the first client obtaining each sub-mask s(vj,u) generated by each of the other clients and corresponding to the first client comprises:
    the first client receives the encrypted sub-masks s(vj,u), generated by each of the other clients and corresponding to the first client, sent by the forwarding server;
    the first client decrypts each encrypted sub-mask s(vj,u) with the private key in the homomorphic encryption key pair corresponding to the first client to obtain each sub-mask s(vj,u).
  8. The method according to claim 7, wherein the forwarding server comprises: the cloud server, or a third-party server independent of the cloud server.
  9. A model training method based on federated machine learning, wherein at least two clients and at least one cloud server participate in the model training based on federated machine learning, and the method is applied to the cloud server and comprises:
    in each round of training, the cloud server sends the latest global model to each client participating in the model training based on federated machine learning;
    the cloud server receives the encrypted gradient of the global model sent by each client;
    the cloud server adds the received encrypted gradients of the global model to obtain an aggregated gradient;
    the cloud server updates the global model using the aggregated gradient;
    the cloud server performs the next round of training until the global model converges.
  10. A model training apparatus based on federated machine learning, wherein at least two clients and at least one cloud server participate in the model training based on federated machine learning, and the apparatus is applied to any first client among the at least two clients and comprises:
    a global model acquisition module, configured to receive, in each round of training, the global model sent by the cloud server;
    a gradient acquisition module, configured to train, in each round of training, a gradient of the global model using local private data;
    an encryption module, configured to encrypt, in each round of training, the gradient obtained in this round of training and then send the encrypted gradient to the cloud server;
    the modules perform the next round of training until the global model converges.
  11. A model training apparatus based on federated machine learning, wherein at least two clients and at least one cloud server participate in the model training based on federated machine learning, and the apparatus is applied to the cloud server and comprises:
    a global model sending module, configured to send, in each round of training, the latest global model to each client participating in the model training based on federated machine learning;
    a gradient receiving module, configured to receive, in each round of training, the encrypted gradient of the global model sent by each client;
    a gradient aggregation module, configured to add, in each round of training, the received encrypted gradients of the global model to obtain an aggregated gradient;
    a global model updating module, configured to update, in each round of training, the global model using the aggregated gradient;
    the modules perform the next round of training until the global model converges.
  12. A computing device, comprising a memory and a processor, wherein the memory stores executable code, and when the processor executes the executable code, the method according to any one of claims 1 to 9 is implemented.
PCT/CN2023/112501 2022-11-03 2023-08-11 基于联邦机器学习的模型训练方法和装置 WO2024093426A1 (zh)

Applications Claiming Priority (2)

Application Number Priority Date Filing Date Title
CN202211369556.9A CN115883053A (zh) 2022-11-03 2022-11-03 基于联邦机器学习的模型训练方法和装置
CN202211369556.9 2022-11-03

Publications (1)

Publication Number Publication Date
WO2024093426A1 true WO2024093426A1 (zh) 2024-05-10

Family

ID=85759374

Family Applications (1)

Application Number Title Priority Date Filing Date
PCT/CN2023/112501 WO2024093426A1 (zh) 2022-11-03 2023-08-11 基于联邦机器学习的模型训练方法和装置

Country Status (2)

Country Link
CN (1) CN115883053A (zh)
WO (1) WO2024093426A1 (zh)

Families Citing this family (3)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115883053A (zh) * 2022-11-03 2023-03-31 支付宝(杭州)信息技术有限公司 基于联邦机器学习的模型训练方法和装置
CN117390448B (zh) * 2023-10-25 2024-04-26 西安交通大学 一种用于云际联邦学习的客户端模型聚合方法及相关系统
CN117150566B (zh) * 2023-10-31 2024-01-23 清华大学 面向协作学习的鲁棒训练方法及装置

Citations (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN112580821A (zh) * 2020-12-10 2021-03-30 深圳前海微众银行股份有限公司 一种联邦学习方法、装置、设备及存储介质
CN113449872A (zh) * 2020-03-25 2021-09-28 百度在线网络技术(北京)有限公司 基于联邦学习的参数处理方法、装置和系统
CN114817958A (zh) * 2022-04-24 2022-07-29 山东云海国创云计算装备产业创新中心有限公司 一种基于联邦学习的模型训练方法、装置、设备及介质
CN115021905A (zh) * 2022-05-24 2022-09-06 北京交通大学 一种联邦学习本地模型参数聚合方法
CN115883053A (zh) * 2022-11-03 2023-03-31 支付宝(杭州)信息技术有限公司 基于联邦机器学习的模型训练方法和装置

Patent Citations (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113449872A (zh) * 2020-03-25 2021-09-28 百度在线网络技术(北京)有限公司 基于联邦学习的参数处理方法、装置和系统
CN112580821A (zh) * 2020-12-10 2021-03-30 深圳前海微众银行股份有限公司 一种联邦学习方法、装置、设备及存储介质
CN114817958A (zh) * 2022-04-24 2022-07-29 山东云海国创云计算装备产业创新中心有限公司 一种基于联邦学习的模型训练方法、装置、设备及介质
CN115021905A (zh) * 2022-05-24 2022-09-06 北京交通大学 一种联邦学习本地模型参数聚合方法
CN115883053A (zh) * 2022-11-03 2023-03-31 支付宝(杭州)信息技术有限公司 基于联邦机器学习的模型训练方法和装置

Also Published As

Publication number Publication date
CN115883053A (zh) 2023-03-31

Similar Documents

Publication Publication Date Title
WO2024093426A1 (zh) 基于联邦机器学习的模型训练方法和装置
US11128447B2 (en) Cryptographic operation method, working key creation method, cryptographic service platform, and cryptographic service device
CN111431713B (zh) 一种私钥存储方法、装置和相关设备
CN110289968B (zh) 私钥恢复、协同地址的创建、签名方法及装置、存储介质
CN112380578A (zh) 一种基于区块链和可信执行环境的边缘计算框架
CN111371790B (zh) 基于联盟链的数据加密发送方法、相关方法、装置和系统
WO2021228239A1 (zh) 资产类型一致性证据生成、交易、交易验证方法及系统
CN112287377A (zh) 基于联邦学习的模型训练方法、计算机设备及存储介质
CN109741068A (zh) 网银跨行签约方法、装置及系统
CN109361508A (zh) 数据传输方法、电子设备及计算机可读存储介质
CN114143117B (zh) 数据处理方法及设备
CN113643134B (zh) 基于多密钥同态加密的物联网区块链交易方法及系统
CN115495768A (zh) 基于区块链及多方安全计算的涉密信息处理方法及系统
CN109995739A (zh) 一种信息传输方法、客户端、服务器及存储介质
CN116527279A (zh) 工控网络中安全数据聚合的可验证联邦学习装置及方法
CN109361512A (zh) 数据传输方法
CN114301677B (zh) 秘钥协商方法、装置、电子设备及存储介质
CN107104888B (zh) 一种安全的即时通信方法
CN112003690B (zh) 密码服务系统、方法及装置
CN110784318B (zh) 群密钥更新方法、装置、电子设备、存储介质及通信系统
CN115913513B (zh) 支持隐私保护的分布式可信数据交易方法、系统及装置
JP5932709B2 (ja) 送信側装置および受信側装置
CN115001719B (zh) 隐私数据处理系统、方法、装置、计算机设备及存储介质
US11770263B1 (en) Systems and methods for enforcing cryptographically secure actions in public, non-permissioned blockchains using bifurcated self-executing programs comprising shared digital signature requirements
US20240187256A1 (en) Systems and methods for enforcing cryptographically secure actions in public, non-permissioned blockchains using bifurcated self-executing programs comprising shared digital signature requirements