WO2023174036A1 - Federated learning model training method, electronic device and storage medium - Google Patents

Federated learning model training method, electronic device and storage medium

Info

Publication number
WO2023174036A1
WO2023174036A1 (PCT/CN2023/078224)
Authority
WO
WIPO (PCT)
Prior art keywords
information
participant
gradient
model
participant device
Prior art date
Application number
PCT/CN2023/078224
Other languages
English (en)
French (fr)
Inventor
鲁云飞
刘洋
郑会钿
王聪
吴烨
Original Assignee
北京字节跳动网络技术有限公司
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by 北京字节跳动网络技术有限公司
Publication of WO2023174036A1


Classifications

    • G - PHYSICS
    • G06 - COMPUTING; CALCULATING OR COUNTING
    • G06N - COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N20/00 - Machine learning
    • G06N20/20 - Ensemble learning
    • G - PHYSICS
    • G06 - COMPUTING; CALCULATING OR COUNTING
    • G06F - ELECTRIC DIGITAL DATA PROCESSING
    • G06F21/00 - Security arrangements for protecting computers, components thereof, programs or data against unauthorised activity
    • G06F21/60 - Protecting data
    • G06F21/602 - Providing cryptographic facilities or services

Definitions

  • the present disclosure relates to the field of artificial intelligence technology, and in particular, to a federated learning model training method, electronic device, and storage medium.
  • With the development of computer technology and artificial intelligence, federated learning has been increasingly widely used.
  • In federated learning, multiple participants with different business data collaborate to complete the training of a federated learning model.
  • In federated learning models, stochastic gradient descent (SGD), Newton's method, or the quasi-Newton method is usually used to optimize the model.
  • However, the convergence speed of stochastic gradient descent is slow, and the second-order derivatives used by Newton's method and the quasi-Newton method are computationally expensive.
  • the purpose of this disclosure is to propose a federated learning model training method, electronic device, and storage medium.
  • this disclosure provides a federated learning model training method, including:
  • Any participant's device conducts joint encryption training with other participants' devices based on its own model parameters and feature information to obtain its own gradient information;
  • Any participant device obtains the model parameter change amount and the gradient information change amount based on the model parameters and the gradient information, and performs a preset number of rounds of interactive calculation with other participant devices based on the model parameter change amount and the gradient information change amount, to obtain its own gradient search direction as a quasi-Newton condition;
  • the target participant device obtains the model loss function, and calculates the step size information based on the gradient search direction and the model loss function; wherein the target participant device is the participant device, among the participant devices, that holds label information, and the model loss function is a convex function;
  • Any participating device updates its own model parameters based on the gradient search direction and the step size information until the federated learning model converges.
  • the any participant device uses a two-way loop recursion method to perform interactive calculations with other participant devices for a preset number of rounds, and obtains the gradient search direction as a quasi-Newton condition, including:
  • Any participant device performs a preset number of interactive calculations with other participant devices based on the model parameter change amount and the gradient information change amount to obtain an intermediate change amount; the intermediate change amount is used to characterize the magnitude of the gradient information;
  • Any participant device performs a preset number of interactive calculations with other participant devices based on the intermediate change amount to obtain the gradient search direction.
  • the any participant device performs a preset number of interactive calculations with other participant devices based on the model parameter change amount and the gradient information change amount to obtain the intermediate change amount, which also includes:
  • Any participant device calculates its own first intermediate value information based on its own model parameter change amount and gradient information change amount, exchanges the first intermediate value information with other participant devices, and calculates a first global intermediate value based on the first intermediate value information of each participant device, so as to calculate the intermediate change amount according to the first global intermediate value.
  • the first intermediate value information is obtained based on the product of the transposed matrix of the gradient information variation and the model parameter variation.
  • the any participant device performs a preset number of interactive calculations with other participant devices based on the intermediate change amount to obtain the gradient search direction, which further includes:
  • Any participant device calculates its own second intermediate value information based on its own intermediate change amount;
  • Any participant device, based on its own second intermediate value information, exchanges the second intermediate value information with other participant devices and calculates a second global intermediate value based on the second intermediate value information of each participant device, so as to calculate the gradient search direction according to the second global intermediate value.
  • the device of any participating party calculates its own second intermediate value information based on its own intermediate change amount, including:
  • Any participant device obtains the first scalar information based on the transposed matrix of its own model parameter change amount and the model parameter change amount, and obtains the second scalar information based on the transposed matrix of its own gradient information change amount and the gradient information change amount;
  • Any participant device interacts with other participant devices to obtain the third scalar information and fourth scalar information of the other participant devices; the third scalar information is obtained based on the transposed matrix of the model parameter change amount of the other participant devices and the model parameter change amount, and the fourth scalar information is obtained based on the transposed matrix of the gradient information change amount of the other participant devices and the gradient information change amount;
  • Any participant device calculates its own second intermediate value information based on the first scalar information, the second scalar information, the third scalar information, the fourth scalar information, and the intermediate change amount.
  • the first global intermediate value is the sum of the first intermediate value information of each participant device, and the second global intermediate value is the sum of the second intermediate value information of each participant device.
  • the target participant device obtains the model loss function, and calculates the step size information based on the gradient search direction and the model loss function, including:
  • the target participant device obtains sample label information, and obtains sample label prediction information based on its own model parameters and feature information and on the first data information of other participant devices; wherein the first data information is obtained based on the model parameters and feature information of the other participant devices;
  • the target participant device calculates the model loss function based on the sample label prediction information and the sample label information
  • the target participant device determines whether the model loss function meets the preset conditions. If so, the current step information is used as the final step information; otherwise, the value of the step information is reduced and the model loss function is recalculated.
  • sample label prediction information is obtained based on the own model parameters, feature information and data information of other participants' devices, including:
  • the target participant device calculates the product of the transposed matrix of the model parameters and the feature information based on its own model parameters and feature information to obtain the second data information;
  • the target participant device interacts with other participant devices based on the second data information to obtain first data information of other participant devices;
  • the target participant device obtains the sample label prediction information based on the first data information, the second data information and the preset model function.
  • the present disclosure also provides an electronic device, including a memory, a processor, and a computer program stored in the memory and executable on the processor.
  • When the processor executes the program, the method as described in any one of the above is implemented.
  • the present disclosure also provides a non-transitory computer-readable storage medium that stores computer instructions, and the computer instructions are used to cause the computer to execute any of the above methods.
  • After each participant device obtains its own gradient information through joint encryption training with other participant devices, it performs joint training with other participant devices based on the model parameter change amount and the gradient information change amount to obtain its own gradient search direction; then, the target participant device calculates the step size information based on the gradient search direction and the model loss function; finally, each participant device updates its own model parameters based on the gradient search direction and the step size information, so that there is no need to calculate the inverse matrix of the Hessian matrix H.
  • Compared with the stochastic gradient descent method, Newton's method and the quasi-Newton method, the method involves less calculation and communication, and can ensure fast convergence.
  • Figure 1 is a schematic flowchart of a federated learning model training method according to an embodiment of the present disclosure
  • Figure 2 is a schematic diagram of the framework of the federated learning model according to the embodiment of the present disclosure
  • Figure 3 is a schematic diagram of sample information of the federated learning model according to the embodiment of the present disclosure.
  • Figure 4 is a schematic flowchart of a method for any participant device to obtain gradient information according to an embodiment of the present disclosure
  • Figure 5 is a schematic flowchart of a method for obtaining the gradient search direction according to an embodiment of the present disclosure
  • FIG. 6 is a schematic structural diagram of an electronic device according to an embodiment of the present disclosure.
  • Artificial Intelligence is a theory, method, technology and application system that uses digital computers or machines controlled by digital computers to simulate, extend and expand human intelligence, perceive the environment, acquire knowledge and use knowledge to obtain the best results.
  • artificial intelligence is a comprehensive technology of computer science that attempts to understand the essence of intelligence and produce a new intelligent machine that can respond in a similar way to human intelligence.
  • Artificial intelligence is the study of the design principles and implementation methods of various intelligent machines, so that the machines have the functions of perception, reasoning and decision-making.
  • Artificial intelligence technology is a comprehensive subject that covers a wide range of fields, including both hardware-level technology and software-level technology.
  • Basic artificial intelligence technologies generally include technologies such as sensors, dedicated artificial intelligence chips, cloud computing, distributed storage, big data processing technology, operation/interaction systems, mechatronics and other technologies.
  • Artificial intelligence software technology mainly includes computer vision technology, speech processing technology, natural language processing technology, machine learning/deep learning, autonomous driving, smart transportation and other major directions.
  • Machine Learning is a multi-field interdisciplinary subject involving probability theory, statistics, approximation theory, convex analysis, algorithm complexity theory and other disciplines. It specializes in studying how computers can simulate or implement human learning behavior to acquire new knowledge or skills, and reorganize existing knowledge structures to continuously improve their performance.
  • Machine learning is the core of artificial intelligence and the fundamental way to make computers intelligent. Its applications cover all fields of artificial intelligence.
  • Machine learning and deep learning usually include artificial neural networks, belief networks, reinforcement learning, transfer learning, inductive learning, teaching learning and other technologies.
  • With the rapid development of machine learning, it can be applied to various fields, such as data mining, computer vision, natural language processing, biometric identification, medical diagnosis, credit card fraud detection, securities market analysis, and DNA sequencing.
  • Compared with traditional machine learning methods, deep neural networks are a newer technology that uses multi-layer network structures to build machine learning models and automatically learns representation features from data. Because they are easy to use and work well in practice, they have been widely used in image recognition, speech recognition, natural language processing, search and recommendation, and other fields.
  • Federated Learning can also be called federated machine learning, joint learning, alliance learning, etc.
  • Federated machine learning is a machine learning framework in which each participant jointly builds a machine learning model and only exchanges intermediate data during training, rather than directly exchanging the business data of each participant.
  • Enterprise A and Enterprise B each establish a task model.
  • a single task can be classification or prediction, and these tasks have been approved by the respective users when obtaining the data.
  • However, the data may be incomplete; for example, company A lacks label data, company B lacks feature data, or the data is insufficient and the sample size is not enough to build a good model; in such cases, the model at each end may fail to be built or may perform poorly.
  • the problem that federated learning needs to solve is how to build a high-quality machine learning model on each side of A and B.
  • The training of this model uses the data of the various enterprises such as A and B, while the own data of each enterprise is not disclosed to the other parties; that is, a common model is established without exchanging one's own data.
  • This shared model is like the optimal model established by aggregating the data of all parties together. In this way, the built model serves only each party's own goals within that party's area.
  • the implementation architecture of federated learning includes at least two participant devices.
  • Each participant device can include different business data, and can participate in the joint training of the model through devices, computers, servers, etc.; each participant device can include at least one of a single server, multiple servers, a cloud computing platform, and a virtualization center.
  • the business data here can be various data such as characters, pictures, voices, animations, videos, etc., for example.
  • the business data contained in each participant's equipment is relevant, and the business parties corresponding to each training member can also be relevant.
  • a single participant device can hold the business data of one business or the business data of multiple business parties.
  • the model can be jointly trained by two or more participant devices.
  • the model here can be used to process business data and obtain corresponding business processing results. Therefore, it can also be called a business model.
  • the specific business data to be processed and the business processing results to be obtained depend on actual needs.
  • the business data can be data related to the user's finance, and the obtained business processing result is the user's financial credit evaluation result.
  • the business data can be customer service data, and the obtained business processing result is the recommendation result of the customer service answer, and so on.
  • Business data can also be in the form of text, pictures, animations, audios, videos, etc.
  • Each participating device can use the trained model to perform local business processing on local business data.
  • federated learning can be divided into horizontal federated learning (feature alignment), vertical federated learning (sample alignment) and federated transfer learning.
  • the implementation architecture provided in this specification is based on vertical federated learning, that is, a federated learning situation in which the sample subjects overlap between the various participant devices, so that partial characteristics of the samples can be provided separately.
  • the sample subject is the subject corresponding to the business data to be processed.
  • the business subject of financial risk assessment is a user or an enterprise.
  • the stochastic gradient descent (SGD) method or Newton's method and quasi-Newton method are usually used to optimize the model.
  • the core idea of the stochastic gradient descent (SGD) method is to use the first-order gradient of the loss function on the model parameters to iteratively optimize the model.
  • However, the existing first-order optimizers use only the first-order gradient of the loss function with respect to the model parameters, so the convergence speed is slow;
  • Newton's method uses the inverse matrix of the second-order derivative Hessian matrix H multiplied by the first-order gradient to guide parameter update.
  • the computational complexity of this method is high; the quasi-Newton method replaces the inverse of the second-order derivative Hessian matrix in Newton's method with an n-order matrix, but the convergence speed of this method is still slow.
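  • For orientation, the parameter update rules of the three optimizers can be contrasted as follows (standard textbook notation, not formulas taken from this publication), where g_k is the gradient, H_k the Hessian, η the learning rate and α_k the step size at iteration k:

        w_{k+1} = w_k - η·g_k                                    (stochastic gradient descent)
        w_{k+1} = w_k - H_k^{-1}·g_k                             (Newton's method: requires inverting H_k)
        w_{k+1} = w_k + α_k·p_k,  p_k ≈ -H_k^{-1}·g_k            (quasi-Newton: p_k is built from s_k = Δw and t_k = Δg without forming H_k^{-1})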
  • the federated learning model training method includes:
  • Step S101 Any participant device performs joint encryption training with other participant devices based on its own model parameters and feature information to obtain its own gradient information.
  • At least two participant devices jointly train the federated learning model, and each participant device can obtain feature information based on the business data on the participant device.
  • each participant device interacts with other participant devices based on encrypted model parameters, feature information and other information, so that each participant device obtains its own gradient information.
  • Step S103: Any participant device obtains the model parameter change amount and the gradient information change amount based on the model parameters and the gradient information, and performs a preset number of rounds of interactive calculation with other participant devices based on the model parameter change amount and the gradient information change amount, to obtain its own gradient search direction as a quasi-Newton condition.
  • any participant device can obtain the gradient search direction of each participant device through a preset number of interactive calculations based on model parameters and gradient information.
  • Step S105 The target participant device obtains the model loss function, and calculates the step size information based on the gradient search direction and the model loss function; wherein the target participant device is a participant device with tag information among any participant devices.
  • The model loss function is a convex function; therefore, its global extremum can be obtained by calculating its local extremum.
  • Based on the gradient search direction of each participant device calculated in step S103, a step size is selected to pre-update the model parameters until the model loss function meets the search stop condition; the model parameters are then updated based on the gradient search direction and the step size information.
  • Step S107 Any participating device updates its own model parameters based on the gradient search direction and the step size information until the federated learning model converges.
  • Here, "any participant device" refers to any one of all the participant devices participating in the federated learning model training, without distinguishing whether the participant device has label information. That is, steps S101, S103 and S107 in this embodiment are steps that can be executed by all participant devices participating in the federated learning model training.
  • the target participant device is a participant device with label information among all the participant devices participating in the federated learning model training. The target participant device not only performs the methods of steps S101, S103, and S107, but also performs the method of step S105.
  • After each participant device obtains its own gradient information through joint encryption training with other participant devices, it conducts joint training with other participant devices based on the model parameter change amount and the gradient information change amount to obtain its own gradient search direction, which is used as a quasi-Newton condition; then, the target participant device calculates the step size information based on the gradient search direction and the model loss function; finally, each participant device updates its own model parameters based on the gradient search direction and the step size information, so that there is no need to calculate the inverse matrix of the Hessian matrix H.
  • Compared with the stochastic gradient descent method, Newton's method and the quasi-Newton method, this requires less calculation and communication, and can ensure fast convergence.
  • the method described in the above embodiment is applied between the target participant device Guest and other participant devices Host other than the target participant device.
  • the target participant device Guest stores first characteristic information and sample label information of multiple samples
  • the other participant device Host stores second characteristic information of multiple samples.
  • Other participant devices may include only one participant device, or may include multiple participant devices.
  • In this embodiment, the case where the other participant devices include only one participant device is taken as an example, and the federated learning model training method is described in detail based on the target participant device Guest and the other participant device Host.
  • data alignment between the target participant device Guest and other participant devices Host is achieved based on information shared by both parties (such as id information).
  • After alignment, the target participant device Guest and the other participant device Host each include multiple samples with ID information 1, 2 and 3.
  • the other participant device Host includes multiple second feature information such as Feature 1, Feature 2, and Feature 3;
  • The target participant device Guest includes multiple pieces of first feature information such as Feature 4 (click), Feature 5, Feature 6, etc., as well as the sample label information (purchase).
  • Assume that the number of samples of the target participant device Guest and the other participant device Host is n. Each piece of first feature information in the target participant device Guest is denoted as x_G, and the first feature information of all n samples in the target participant device Guest is denoted as {x_G^(i)}; the sample label of each sample is denoted as y, and the sample label information of all n samples is denoted as {y^(i)}; each piece of second feature information in the other participant device Host is denoted as x_H, and the second feature information of all n samples in the other participant device Host is denoted as {x_H^(i)}, where i represents the i-th sample among the n samples.
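  • As an illustration of this sample layout (a minimal sketch with made-up IDs, feature values and names, not data from this publication), the aligned Guest and Host data could be organized as follows:

        # Hypothetical aligned sample layout after ID-based intersection.
        # Guest holds the first feature information x_G and the labels y;
        # Host holds the second feature information x_H; both share the IDs 1, 2, 3.
        guest_data = {
            1: {"x_G": [0.7, 0.1, 1.0], "y": 1},   # e.g. Features 4-6 (click, ...) plus label (purchase)
            2: {"x_G": [0.2, 0.9, 0.0], "y": 0},
            3: {"x_G": [0.5, 0.4, 1.0], "y": 1},
        }
        host_data = {
            1: {"x_H": [3.2, 0.8, 1.5]},           # e.g. Features 1-3 held by the other participant
            2: {"x_H": [1.1, 0.3, 0.6]},
            3: {"x_H": [2.7, 0.5, 2.2]},
        }
        # Data alignment: both parties keep only the sample IDs they share.
        shared_ids = sorted(set(guest_data) & set(host_data))
        assert shared_ids == [1, 2, 3]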
  • Step S101 Any participant device performs joint encryption training with other participant devices based on its own model parameters and feature information to obtain its own gradient information.
  • the target participant device Guest includes a first local model built locally on the target participant device Guest, and the first local model includes the first model parameter w G ; correspondingly, other participant devices Host include built on other A second local model local to the participant device Host, where the second local model includes a second model parameter w H .
  • step S101 a homomorphic encryption algorithm or a semi-homomorphic encryption algorithm is used to encrypt the interactive data during the joint encryption training process.
  • For example, the Paillier algorithm can be used for encryption to ensure that the data of the target participant device Guest and the other participant device Host will not be leaked during the joint training process.
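  • As a sketch of how such semi-homomorphic encryption can be used here (assuming the third-party python-paillier package, imported as phe; the variable names and values are illustrative only), encrypted per-sample values can be scaled by plaintext features and summed without the other party ever seeing the plaintexts:

        # Minimal Paillier sketch (assumes the third-party "phe" / python-paillier package).
        from phe import paillier

        public_key, private_key = paillier.generate_paillier_keypair()

        # Guest encrypts per-sample residuals (prediction minus label) before sending them.
        residuals = [0.31, -0.52, 0.08]
        encrypted_residuals = [public_key.encrypt(r) for r in residuals]

        # Host multiplies each ciphertext by its plaintext feature value and sums them
        # (additive homomorphism), obtaining an encrypted gradient component without
        # ever decrypting the residuals.
        host_feature_column = [3.2, 1.1, 2.7]
        encrypted_component = sum(c * x for c, x in zip(encrypted_residuals, host_feature_column))

        # Only the private-key holder can decrypt the (typically masked) aggregate.
        print(private_key.decrypt(encrypted_component))  # ≈ 0.31*3.2 - 0.52*1.1 + 0.08*2.7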
  • step S101 specifically includes the following steps:
  • Step S201 Other participant devices obtain first data information and send it to the target participant device.
  • the first data information is obtained based on the second model parameters and the second feature information.
  • Specifically, the other participant device Host obtains the second model parameter w_H of its second local model, calculates the inner product of the second model parameter w_H and the second feature information to obtain the first data information, and sends the first data information to the target participant device Guest.
  • In this embodiment, the first data information includes the inner product of the transposed matrix of the second model parameter w_H (i.e., w_H^T) with each piece of second feature information x_H; therefore, the first data information includes n pieces of information corresponding to the n samples.
  • In addition, the other participant device Host can also calculate the first regularization term and send it to the target participant device Guest.
  • In this embodiment, the first regularization term is an L2 regularization term, for example of the form (λ/2)·||w_H||², where λ represents the regularization coefficient.
  • In the first update period, the second model parameter w_H is the initial value of the model parameter after initialization; in an intermediate update period, the second model parameter w_H is the model parameter of the second local model updated in the previous update cycle.
  • Step S203 The target participant device obtains second data information, where the second data information is obtained based on the first model parameters and the first feature information.
  • Specifically, the target participant device Guest obtains the first model parameter w_G of the first local model and calculates the inner product of the first model parameter w_G and the first feature information, thereby obtaining the second data information. In this embodiment, the second data information includes the inner product of the transposed matrix of the first model parameter w_G (i.e., w_G^T) with each piece of first feature information x_G.
  • In addition, the target participant device Guest also calculates the second regularization term.
  • In this embodiment, the second regularization term is also an L2 regularization term, for example of the form (λ/2)·||w_G||², where λ represents the regularization coefficient.
  • In the first update period, the first model parameter w_G is the initial value of the model parameter after initialization; in an intermediate update period, the first model parameter w_G is the model parameter of the first local model updated in the previous update cycle.
  • Since the first model parameter w_G and the second model parameter w_H are one-dimensional vectors, the first data information and the second data information obtained from them are results of matrix multiplication (inner products), so the other party cannot recover the original data information from them; therefore, no plaintext information is leaked during the data transmission in steps S201 and S203, ensuring the security of the data of both parties.
  • Step S205 The target participant device obtains sample label prediction information based on the first data information and the second data information, and encrypts the difference between the sample label prediction information and the sample label information to obtain the first encrypted information, Send the first encrypted information to the other participant device.
  • Specifically, the target participant device Guest obtains the sample label prediction information of each sample based on the first data information and the second data information. Based on the sample label prediction information, the probability of the binary classification of a sample can be determined, thereby solving the binary classification problem in the vertical federated model.
  • In this embodiment, the sample label prediction information is defined with the sigmoid function applied to the sum of the second data information and the first data information of each sample, i.e., ŷ^(i) = 1/(1 + exp(-(w_G^T·x_G^(i) + w_H^T·x_H^(i)))).
  • Afterwards, based on the sample label prediction information of each sample and the sample label information y, the target participant device Guest calculates the difference between the sample label prediction information and the sample label information of each sample, and encrypts it to obtain the first encrypted information. Due to the encryption algorithm used, the encrypted information will not leak the original sample label information after being sent to the other participant device Host, ensuring data security.
  • the encryption algorithm used in this step may be the semi-homomorphic encryption algorithm Paillier, or other optional semi-homomorphic encryption algorithms or homomorphic encryption algorithms may also be used, which is not specifically limited in this embodiment.
  • The target participant device Guest then sends the first encrypted information to the other participant device Host.
  • Step S207 Other participant devices obtain second encrypted information based on the first encrypted information, the second feature information and a random number and send it to the target participant device.
  • Specifically, the other participant device Host obtains the second encrypted information based on the sum of the products of the first encrypted information, the second feature information and the random numbers, where y^(i) represents the sample label of the i-th sample, x_H^(i) represents the second feature information of the i-th sample, and a random number is generated for the i-th sample.
  • Step S209 The target participant device decrypts the second encrypted information to obtain third decrypted information, and sends the third decrypted information to the other participant devices.
  • the third decryption information is obtained based on the cumulative sum of the difference between the sample label prediction information and the sample label information of each sample, the second feature information and the random number.
  • Specifically, the decryption algorithm corresponding to the encryption algorithm in step S205 is used: the target participant device Guest decrypts the second encrypted information to obtain the third decryption information, and then sends the third decryption information to the other participant device Host.
  • Step S211 Other participant devices receive the third decryption information, obtain the fourth decryption information based on the random number, and obtain the second gradient information based on the fourth decryption information.
  • That is, the other participant device Host removes the random numbers from the third decryption information to obtain the fourth decryption information, and calculates its own second gradient information based on the fourth decryption information.
  • Step S213 The target participant device calculates fifth plaintext information based on the difference between the sample label prediction information and the sample label information and the first feature information, and obtains the first gradient information based on the fifth plaintext information.
  • Specifically, the target participant device Guest obtains the fifth plaintext information as the sum, over the samples, of the products of the difference between the sample label prediction information and the sample label information of each sample and the first feature information x_G of each sample, and calculates the first gradient information based on the fifth plaintext information.
  • In some embodiments, step S205 also includes: the target participant device calculates the loss function Loss based on the sample label prediction information and the sample label information.
  • Optionally, the loss function Loss can also include the first regularization term and the second regularization term.
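  • To make the quantities exchanged in steps S201-S213 concrete, the following is a plaintext reference sketch of the vertical logistic-regression prediction, loss and per-party gradients that the protocol computes (in the actual protocol the residuals travel Paillier-encrypted and masked; the variable names, the 1/n normalization and the regularization form are illustrative assumptions):

        import numpy as np

        def vertical_lr_quantities(x_G, x_H, y, w_G, w_H, lam=0.0):
            """Plaintext reference for the encrypted exchange of steps S201-S213.

            x_G: (n, d_G) Guest features, x_H: (n, d_H) Host features, y: (n,) labels in {0, 1}.
            """
            n = len(y)
            z = x_G @ w_G + x_H @ w_H          # second data information + first data information
            y_hat = 1.0 / (1.0 + np.exp(-z))   # sigmoid prediction (step S205)
            residual = y_hat - y               # difference that Guest encrypts and sends out
            # Cross-entropy loss plus optional L2 regularization terms (one per party).
            loss = -np.mean(y * np.log(y_hat) + (1 - y) * np.log(1 - y_hat)) \
                   + 0.5 * lam * (w_G @ w_G + w_H @ w_H)
            g_G = x_G.T @ residual / n         # first gradient information (Guest, step S213)
            g_H = x_H.T @ residual / n         # second gradient information (Host, steps S207-S211)
            return y_hat, loss, g_G, g_H       # regularization gradient omitted for brevity

        # Tiny usage example with random data.
        rng = np.random.default_rng(0)
        x_G, x_H = rng.normal(size=(8, 3)), rng.normal(size=(8, 2))
        y = rng.integers(0, 2, size=8).astype(float)
        _, loss, g_G, g_H = vertical_lr_quantities(x_G, x_H, y, np.zeros(3), np.zeros(2))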
  • Step S103: Any participant device obtains the model parameter change amount and the gradient information change amount based on the model parameters and the gradient information, and performs a preset number of rounds of interactive calculation with other participant devices based on the model parameter change amount and the gradient information change amount, to obtain its own gradient search direction as a quasi-Newton condition.
  • Specifically, any participant device uses, for example, a two-way loop recursion method to perform a preset number of rounds of interactive calculations with other participant devices based on the model parameter change amount and the gradient information change amount, to obtain the gradient search direction.
  • That is, in this embodiment, after the target participant device Guest obtains the first gradient information and the other participant device Host obtains the second gradient information, the respective model parameter change amounts and gradient information change amounts are calculated; based on the two-way loop recursion method, a preset number of rounds of interactive calculations is then performed, so that the target participant device Guest obtains the first gradient search direction and the other participant device Host obtains the second gradient search direction.
  • In this process, the data calculated, sent and received by the target participant device Guest and the other participant device Host are all obtained from vector products or scalar products of at least two of the model parameter change amount, the transposed matrix of the model parameter change amount, the gradient information change amount, and the transposed matrix of the gradient information change amount. No operations on large matrices are involved, so the amount of calculation and communication in the whole process is very small, which ensures rapid convergence of the model.
  • step S103 specifically includes:
  • Step S301 The target participant device Guest obtains the first model parameter change amount and the first gradient information change amount, and the other participant device Host obtains the second model parameter change amount and the second gradient information change amount.
  • Let g denote the gradient information, where g_G denotes the first gradient information and g_H denotes the second gradient information; let t denote the change amount Δg of the gradient information g, where t_G denotes the change amount of the first gradient information and t_H denotes the change amount of the second gradient information; let s denote the change amount Δw of the model parameters, where s_G denotes the change amount of the first model parameters and s_H denotes the change amount of the second model parameters.
  • Step S303: Any participant device performs a preset number of rounds of interactive calculation with other participant devices based on the model parameter change amount and the gradient information change amount to obtain an intermediate change amount; the intermediate change amount is used to characterize the magnitude of the gradient information.
  • a two-way loop algorithm can be used to calculate the gradient search direction.
  • the method includes: during the backward loop process, any participant device performs interactive calculations with other participant devices for a preset number of rounds based on the first intermediate information to obtain the intermediate change amount.
  • For example, the preset number of rounds can be 3 to 5, and the backward loop and the forward loop use the same number of rounds.
  • Specifically, the target participant device Guest, which has the first gradient information change amount t_G and the first model parameter change amount s_G, and the other participant device Host, which has the second gradient information change amount t_H and the second model parameter change amount s_H, perform interactive calculations, so that the target participant device Guest obtains its own intermediate change amount q_G and the other participant device Host obtains its own intermediate change amount q_H.
  • In some embodiments, any participant device, based on its own first intermediate value information, exchanges first intermediate value information with other participant devices and calculates the first global intermediate value based on the first intermediate value information of each participant device, so as to calculate the intermediate change amount based on the first global intermediate value.
  • In this embodiment, the first intermediate value information exchanged in the backward loop process includes the scalar values calculated by the target participant device Guest and the other participant device Host, respectively, from their own model parameter change amounts and gradient information change amounts.
  • The first global intermediate value may be the sum of the first intermediate value information of each participant device, or may be set according to requirements; this specification does not limit this.
  • Specifically, the target participant device Guest and the other participant device Host each obtain first intermediate value information based on the product of the transposed matrix of their own gradient information change amount and their own model parameter change amount, and exchange their respective first intermediate value information to obtain the corresponding first global intermediate value; further first intermediate value information is then calculated from this first global intermediate value, the transposed matrix of the model parameter change amount and the gradient information, and exchanged in the same way to obtain the corresponding first global intermediate value, and finally the intermediate change amount of the local party is calculated from these values.
  • Step S403 Iterate L rounds of the following steps, i from L-1 to 0, and j from k-1 to k-L.
  • In each iteration, the target participant device Guest (and, correspondingly, the other participant device Host) calculates the intermediate process variables for that round.
  • The intermediate process variables in each step of step S403 are all computed and exchanged through vector multiplications or scalar multiplications and do not involve calculations on large matrices. Therefore, the amount of calculation and communication during the training process is relatively small, which not only ensures rapid convergence of the model, but also improves the hardware processing rate of the target participant device and the other participant devices.
  • Step S305 Any participant device performs a preset number of interactive calculations with other participant devices based on the intermediate change amount to obtain the gradient search direction.
  • In some embodiments, step S305 further includes: any participant device calculates its own second intermediate value information based on its own intermediate change amount; any participant device, based on its own second intermediate value information, exchanges the second intermediate value information with other participant devices and calculates a second global intermediate value based on the second intermediate value information of each participant device, so as to calculate the gradient search direction according to the second global intermediate value.
  • In some embodiments, a two-way loop algorithm can be used to calculate the gradient search direction. This includes: in the forward loop process, any participant device obtains the second intermediate value information from vector products or scalar products of at least two of the model parameter change amount, the transposed matrix of the model parameter change amount, the gradient information change amount, and the transposed matrix of the gradient information change amount, and, based on the second intermediate value information and the intermediate change amount, performs a preset number of rounds of interactive calculations with other participant devices to obtain the gradient search direction.
  • For example, after 3 to 5 rounds of interactive calculations between the target participant device Guest, which has the intermediate change amount q_G, and the other participant device Host, which has the intermediate change amount q_H, the target participant device Guest obtains its own first gradient search direction p_kG, and the other participant device Host obtains its own second gradient search direction p_kH.
  • Step S501: Any participant device obtains the first scalar information based on the transposed matrix of its own model parameter change amount and the model parameter change amount, and obtains the second scalar information based on the transposed matrix of its own gradient information change amount and the gradient information change amount.
  • For the target participant device Guest, the first scalar information is obtained based on the product s_G^T·s_G of the transposed matrix of the first model parameter change amount s_G and the first model parameter change amount s_G, and the second scalar information is obtained based on the product t_G^T·t_G of the transposed matrix of the first gradient information change amount t_G and the first gradient information change amount t_G.
  • Step S503: Any participant device interacts with other participant devices to obtain the third scalar information and fourth scalar information of the other participant devices; the third scalar information is obtained based on the transposed matrix of the model parameter change amount of the other participant devices and the model parameter change amount, and the fourth scalar information is obtained based on the transposed matrix of the gradient information change amount of the other participant devices and the gradient information change amount.
  • In this embodiment, the third scalar information is obtained based on the product s_H^T·s_H of the transposed matrix of the second model parameter change amount s_H and the second model parameter change amount s_H, and the fourth scalar information is obtained based on the product t_H^T·t_H of the transposed matrix of the second gradient information change amount t_H and the second gradient information change amount t_H.
  • the target participant device Guest exchanges the first scalar information, the second scalar information, the third scalar information and the fourth scalar information with other participant devices Host, so that the target participant device Guest and other participant devices Hosts all have the above information.
  • Step S505: Any participant device calculates its own second intermediate value information based on the first scalar information, the second scalar information, the third scalar information, the fourth scalar information and the intermediate change amounts q_G and q_H, exchanges the second intermediate value information with other participant devices, and calculates the second global intermediate value based on the second intermediate value information of each participant device, so as to calculate the gradient search direction based on the second global intermediate value.
  • Specifically, the second intermediate value information in the forward loop process includes the scalar value calculated by each participant device in each round of the forward loop.
  • After the target participant device Guest and the other participant device Host respectively calculate their own second intermediate value information, they exchange the second intermediate value information of each participant device, thereby obtaining the second global intermediate value.
  • the second global intermediate value can be the sum of the second intermediate value information of each participating device, or can be set according to requirements, and this specification does not limit this.
  • step S505 further includes:
  • Step S601: Based on the first scalar information, the second scalar information, the third scalar information and the fourth scalar information exchanged between the target participant device Guest and the other participant device Host, the required global scalar values are calculated.
  • Step S607: Iterate L rounds, with i from 0 to L-1 and j from k-L to k-1.
  • In the above process, the calculations are all vector multiplications or scalar multiplications and do not involve operations on large matrices, thus reducing the amount of calculation in the model training process; at the same time, the variables exchanged by the two parties are all scalar results of vector inner products, which ensures data security and reduces the communication volume during data transmission. This not only ensures the rapid convergence of the model, but also improves the communication efficiency between the target participant device and the other participant devices.
  • In one example, the federated learning model training method described in the embodiments of the present disclosure only needs 3 iterations to make the model converge, whereas the gradient descent method requires dozens of rounds of iterations to ensure model convergence. Therefore, the federated learning model training method described in the embodiments of the present disclosure can improve the convergence speed of the model.
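  • For reference, the following is a minimal single-party sketch of the classic L-BFGS two-loop recursion that the backward and forward loops above distribute between Guest and Host (here all vectors live in one process; in the protocol the dot products are assembled from the per-party scalar shares of steps S501-S607, and the names are illustrative):

        import numpy as np

        def two_loop_direction(g, s_list, t_list):
            """Classic L-BFGS two-loop recursion (single-party reference sketch).

            g: current gradient; s_list: recent Δw vectors; t_list: matching Δg vectors.
            Returns a search direction p ≈ -H^{-1}·g without ever forming the Hessian.
            """
            q = g.astype(float).copy()
            alphas, rhos = [], []
            # Backward loop: newest (s, t) pair first; uses only scalars such as t^T·s and s^T·q.
            for s, t in zip(reversed(s_list), reversed(t_list)):
                rho = 1.0 / (t @ s)
                alpha = rho * (s @ q)
                q -= alpha * t
                rhos.append(rho)
                alphas.append(alpha)
            # Initial scaling from the most recent pair.
            s, t = s_list[-1], t_list[-1]
            q *= (s @ t) / (t @ t)
            # Forward loop: oldest pair first, reusing the stored scalars.
            for (s, t), rho, alpha in zip(zip(s_list, t_list), reversed(rhos), reversed(alphas)):
                beta = rho * (t @ q)
                q += (alpha - beta) * s
            return -q

        # Usage note: keep only the last 3-5 (s, t) pairs, matching the preset number of rounds.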
  • Step S105 The target participant device obtains the model loss function, and calculates the step size information based on the gradient search direction and the model loss function.
  • step S105 the target participant device obtains the model loss function and calculates the step size information based on the gradient search direction and the model loss function, including:
  • Step S701: The target participant device obtains sample label information, and obtains sample label prediction information based on its own model parameters and feature information and on the first data information of other participant devices; wherein the first data information is obtained based on the model parameters and feature information of the other participant devices.
  • Specifically, the target participant device Guest first obtains the second data information by calculating the product of the transposed matrix of its own model parameters and its own feature information; then, the target participant device Guest interacts with the other participant device Host based on the second data information and obtains the first data information of the other participant device Host; finally, the target participant device Guest obtains the sample label prediction information based on the first data information, the second data information and a preset model function.
  • In this embodiment, the preset model function is the sigmoid function, and the sample label prediction information is defined by applying the sigmoid function to the sum of the second data information and the first data information, as in step S205.
  • Step S703 The target participant device calculates a loss function based on the sample label prediction information and the sample label information.
  • Step S705: The target participant device determines whether the loss function meets the preset condition; if so, the current step size information is used as the final step size information; otherwise, the value of the step size information is reduced and the loss function is recalculated.
  • In some embodiments, the preset condition may be the Armijo condition. Therefore, it can be judged whether the loss function Loss satisfies the Armijo condition: Loss(y, x_H, x_G, w_H + α·p_H, w_G + α·p_G) ≤ Loss(y, x_H, x_G, w_H, w_G) + c_1·α·(g_H^T·p_H + g_G^T·p_G), where α denotes the step size information and c_1 is a hyperparameter (for example, it can take the value 1E-4).
  • If the loss function satisfies the Armijo condition, the current step size information is used as the final step size information α; if the loss function does not satisfy the Armijo condition, the value of the step size information is reduced, for example to 1/2 of its original value, the model parameters of both parties are pre-updated based on the reduced step size, and the loss function is recalculated, until the loss function satisfies the Armijo condition.
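  • A minimal sketch of this backtracking search (illustrative names; in the federated setting the loss evaluation itself is the joint computation of steps S701-S705) could look like:

        def armijo_step(loss_fn, loss_0, directional_derivative, alpha=1.0,
                        c1=1e-4, shrink=0.5, max_backtracks=20):
            """Backtracking line search under the Armijo condition.

            loss_fn(alpha): loss after pre-updating w_G, w_H by alpha*p_G, alpha*p_H.
            directional_derivative: g_G^T p_G + g_H^T p_H (negative for a descent direction).
            """
            for _ in range(max_backtracks):
                if loss_fn(alpha) <= loss_0 + c1 * alpha * directional_derivative:
                    return alpha             # Armijo condition met: accept this step size
                alpha *= shrink              # otherwise shrink the step (e.g. halve it) and retry
            return alpha

        # Toy usage: minimise f(w) = w**2 from w = 3 along the descent direction p = -6.
        w, p, g = 3.0, -6.0, 6.0
        step = armijo_step(lambda a: (w + a * p) ** 2, w ** 2, g * p)  # returns 0.5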
  • the methods in the embodiments of the present disclosure can be executed by a single device, such as a computer or server.
  • the method of this embodiment can also be applied in a distributed scenario, and is completed by multiple devices cooperating with each other.
  • Each device among the multiple devices may perform only one or more steps of the method of the embodiment of the present disclosure, and the multiple devices interact with each other to complete the described method.
  • The present disclosure also provides an electronic device, including a memory, a processor, and a computer program stored in the memory and executable on the processor; when the processor executes the program, the method described in any one of the above embodiments is implemented.
  • FIG. 6 shows a more specific hardware structure diagram of an electronic device provided by this embodiment.
  • the device may include: a processor 1010, a memory 1020, an input/output interface 1030, a communication interface 1040, and a bus 1050.
  • the processor 1010, the memory 1020, the input/output interface 1030 and the communication interface 1040 implement communication connections between each other within the device through the bus 1050.
  • the processor 1010 can be implemented using a general-purpose CPU (Central Processing Unit, central processing unit), a microprocessor, an application specific integrated circuit (Application Specific Integrated Circuit, ASIC), or one or more integrated circuits, and is used to execute related program to implement the technical solutions provided by the embodiments of this specification.
  • the memory 1020 can be implemented in the form of ROM (Read Only Memory), RAM (Random Access Memory), static storage device, dynamic storage device, etc.
  • the memory 1020 can store operating systems and other application programs. When implementing the technical solutions provided by the embodiments of this specification through software or firmware, the relevant program codes are stored in the memory 1020 and called and executed by the processor 1010 .
  • the input/output interface 1030 is used to connect the input/output module to realize information input and output.
  • The input/output module can be configured in the device as a component (not shown in the figure), or can be externally connected to the device to provide corresponding functions.
  • Input devices can include keyboards, mice, touch screens, microphones, various sensors, etc., and output devices can include monitors, speakers, vibrators, indicator lights, etc.
  • the communication interface 1040 is used to connect a communication module (not shown in the figure) to realize communication interaction between this device and other devices.
  • the communication module can realize communication through wired means (such as USB, network cable, etc.) or wireless means (such as mobile network, WIFI, Bluetooth, etc.).
  • Bus 1050 includes a path that carries information between various components of the device (eg, processor 1010, memory 1020, input/output interface 1030, and communication interface 1040).
  • Although the above device only shows the processor 1010, the memory 1020, the input/output interface 1030, the communication interface 1040 and the bus 1050, in specific implementation the device may also include other components necessary for normal operation.
  • the above-mentioned device may only include components necessary to implement the embodiments of this specification, and does not necessarily include all components shown in the drawings.
  • a non-transitory computer-readable storage medium stores computer instructions, and the computer instructions are used to cause the computer to perform the method described in any of the above embodiments.
  • the computer-readable media in this embodiment include permanent and non-permanent, removable and non-removable media, and information storage can be implemented by any method or technology.
  • Information may be computer-readable instructions, data structures, modules of programs, or other data.
  • Examples of computer storage media include, but are not limited to, phase change memory (PRAM), static random access memory (SRAM), dynamic random access memory (DRAM), other types of random access memory (RAM), read-only memory (ROM), electrically erasable programmable read-only memory (EEPROM), flash memory or other memory technology, compact disc read-only memory (CD-ROM), digital versatile disc (DVD) or other optical storage, Magnetic tape cassettes, tape magnetic disk storage or other magnetic storage devices or any other non-transmission medium can be used to store information that can be accessed by a computing device.
  • the computer instructions stored in the storage medium of the above embodiments are used to cause the computer to execute the method described in any of the above embodiments, and have the beneficial effects of the corresponding method embodiments, which will not be described again here.

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Software Systems (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Medical Informatics (AREA)
  • Health & Medical Sciences (AREA)
  • Computing Systems (AREA)
  • Data Mining & Analysis (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Mathematical Physics (AREA)
  • Artificial Intelligence (AREA)
  • Evolutionary Computation (AREA)
  • Bioethics (AREA)
  • General Health & Medical Sciences (AREA)
  • Computer Hardware Design (AREA)
  • Computer Security & Cryptography (AREA)
  • Management, Administration, Business Operations System, And Electronic Commerce (AREA)
  • Information Retrieval, Db Structures And Fs Structures Therefor (AREA)

Abstract

The present disclosure provides a federated learning model training method, an electronic device, and a storage medium. In the federated learning model training method, electronic device, and storage medium provided by the present disclosure, after each participant device obtains its own gradient information through joint encrypted training with the other participant devices, it performs joint training with the other participant devices based on the model parameter change amount and the gradient information change amount to obtain its own gradient search direction; then, the target participant device calculates the step size information based on the gradient search direction and the model loss function; finally, each participant device updates its own model parameters based on the gradient search direction and the step size information, so that the inverse matrix of the Hessian matrix H does not need to be calculated. Compared with the stochastic gradient descent method, Newton's method, and the quasi-Newton method, the amount of calculation and communication is small, and fast convergence can be guaranteed.

Description

Federated learning model training method, electronic device and storage medium
This application claims priority to Chinese invention patent application No. CN202210249166.1, entitled "Federated learning model training method, electronic device and storage medium", filed on March 14, 2022.
Technical Field
The present disclosure relates to the field of artificial intelligence technology, and in particular to a federated learning model training method, an electronic device, and a storage medium.
Background
With the development of computer technology and the progress of artificial intelligence technology, federated learning has been applied more and more widely. In federated learning, multiple participants holding different business data collaborate to complete the training of a federated learning model.
In federated learning models, stochastic gradient descent (SGD), Newton's method, or the quasi-Newton method is usually used to optimize the model. However, the convergence speed of stochastic gradient descent is slow, and the computation of the second-order derivatives used by Newton's method and the quasi-Newton method has high complexity.
Summary
In view of this, the purpose of the present disclosure is to provide a federated learning model training method, an electronic device, and a storage medium.
Based on the above purpose, the present disclosure provides a federated learning model training method, including:
any participant device performs joint encrypted training with other participant devices based on its own model parameters and feature information to obtain its own gradient information;
any participant device obtains a model parameter change amount and a gradient information change amount based on the model parameters and the gradient information, and performs a preset number of rounds of interactive calculation with the other participant devices based on the model parameter change amount and the gradient information change amount to obtain its own gradient search direction as a quasi-Newton condition;
a target participant device obtains a model loss function and calculates step size information based on the gradient search direction and the model loss function; wherein the target participant device is the participant device, among the participant devices, that holds label information, and the model loss function is a convex function;
any participant device updates its own model parameters based on the gradient search direction and the step size information until the federated learning model converges.
The step in which any participant device, based on the model parameter change amount and the gradient information change amount, performs a preset number of rounds of interactive calculation with the other participant devices using a two-way loop recursion method to obtain the gradient search direction as a quasi-Newton condition includes:
any participant device performs a preset number of rounds of interactive calculation with the other participant devices based on the model parameter change amount and the gradient information change amount to obtain an intermediate change amount, where the intermediate change amount is used to characterize the magnitude of the gradient information;
any participant device performs a preset number of rounds of interactive calculation with the other participant devices based on the intermediate change amount to obtain the gradient search direction.
Optionally, the step in which any participant device performs a preset number of rounds of interactive calculation with the other participant devices based on the model parameter change amount and the gradient information change amount to obtain the intermediate change amount further includes:
any participant device calculates its own first intermediate value information based on its own model parameter change amount and gradient information change amount, exchanges the first intermediate value information with the other participant devices, and calculates a first global intermediate value based on the first intermediate value information of each participant device, so as to calculate the intermediate change amount according to the first global intermediate value.
Optionally, the first intermediate value information is obtained based on the product of the transposed matrix of the gradient information change amount and the model parameter change amount.
Optionally, the step in which any participant device performs a preset number of rounds of interactive calculation with the other participant devices based on the intermediate change amount to obtain the gradient search direction further includes:
any participant device calculates its own second intermediate value information based on its own intermediate change amount;
any participant device, based on its own second intermediate value information, exchanges the second intermediate value information with the other participant devices and calculates a second global intermediate value based on the second intermediate value information of each participant device, so as to calculate the gradient search direction according to the second global intermediate value.
Optionally, the step in which any participant device calculates its own second intermediate value information based on its own intermediate change amount includes:
any participant device obtains first scalar information based on the transposed matrix of its own model parameter change amount and the model parameter change amount, and obtains second scalar information based on the transposed matrix of its own gradient information change amount and the gradient information change amount;
any participant device interacts with the other participant devices to obtain third scalar information and fourth scalar information of the other participant devices, where the third scalar information is obtained based on the transposed matrix of the model parameter change amount of the other participant devices and the model parameter change amount, and the fourth scalar information is obtained based on the transposed matrix of the gradient information change amount of the other participant devices and the gradient information change amount;
any participant device calculates its own second intermediate value information based on the first scalar information, the second scalar information, the third scalar information, the fourth scalar information, and the intermediate change amount.
Optionally, the first global intermediate value is the sum of the first intermediate value information of each participant device, and the second global intermediate value is the sum of the second intermediate value information of each participant device.
Optionally, the step in which the target participant device obtains the model loss function and calculates the step size information based on the gradient search direction and the model loss function includes:
the target participant device obtains sample label information, and obtains sample label prediction information based on its own model parameters and feature information and on first data information of the other participant devices, where the first data information is obtained based on the model parameters and feature information of the other participant devices;
the target participant device calculates the model loss function based on the sample label prediction information and the sample label information;
the target participant device determines whether the model loss function meets a preset condition; if so, the current step size information is used as the final step size information; otherwise, the value of the step size information is reduced and the model loss function is recalculated.
Optionally, the step of obtaining the sample label prediction information based on the own model parameters and feature information and the data information of the other participant devices includes:
the target participant device calculates, based on its own model parameters and feature information, the product of the transposed matrix of the model parameters and the feature information to obtain second data information;
the target participant device interacts with the other participant devices based on the second data information to obtain the first data information of the other participant devices;
the target participant device obtains the sample label prediction information based on the first data information, the second data information, and a preset model function.
The present disclosure also provides an electronic device, including a memory, a processor, and a computer program stored in the memory and executable on the processor, where the processor, when executing the program, implements the method described in any one of the above.
The present disclosure also provides a non-transitory computer-readable storage medium storing computer instructions, where the computer instructions are used to cause a computer to execute any one of the methods described above.
As can be seen from the above, with the federated learning model training method, electronic device, and storage medium provided by the present disclosure, after each participant device obtains its own gradient information through joint encrypted training with the other participant devices, it performs joint training with the other participant devices based on the model parameter change amount and the gradient information change amount to obtain its own gradient search direction; then, the target participant device calculates the step size information based on the gradient search direction and the model loss function; finally, each participant device updates its own model parameters based on the gradient search direction and the step size information, so that the inverse matrix of the Hessian matrix H does not need to be calculated. Compared with the stochastic gradient descent method, Newton's method, and the quasi-Newton method, the amount of calculation and communication is small, and fast convergence can be guaranteed.
附图说明
为了更清楚地说明本公开或相关技术中的技术方案,下面将对实施例或相关技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本公开的实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1为本公开实施例所述联邦学习模型训练方法的流程示意图;
图2为本公开实施例所述联邦学习模型的框架示意图;
图3为本公开实施例所述联邦学习模型的样本信息示意图;
图4为本公开实施例任一参与方设备获取梯度信息的方法的流程示意图;
图5为本公开实施例梯度搜索方法的获取方法的流程示意图;
图6为本公开实施例电子设备的结构示意图。
具体实施方式
为使本公开的目的、技术方案和优点更加清楚明白,以下结合具体实施例,并参照附图,对本公开进一步详细说明。
需要说明的是,除非另外定义,本公开实施例使用的技术术语或者科学术语应当为本公开所属领域内具有一般技能的人士所理解的通常意义。本公开实施例中使用的“第一”、“第二”以及类似的词语并不表示任何顺序、数量或者重要性,而只是用来区分不同的组成部分。“包括”或者“包含”等类似的词语意指出现该词前面的元件或者物件涵盖出现在该词后面列举的元件或者物件及其等同,而不排除其他元件或者物件。“连接”或者“相连”等类似的词语并非限定于物理的或者机械的连接,而是可以包括电性的连接,不管是直接的还是间接的。“上”、“下”、“左”、“右”等仅用于表示相对位置关系,当被描述对象的绝对 位置改变后,则该相对位置关系也可能相应地改变。
人工智能(Artificial Intelligence,AI)是利用数字计算机或者数字计算机控制的机器模拟、延伸和扩展人的智能,感知环境、获取知识并使用知识获得最佳结果的理论、方法、技术及应用系统。换句话说,人工智能是计算机科学的一个综合技术,它企图了解智能的实质,并生产出一种新的能以人类智能相似的方式做出反应的智能机器。人工智能也就是研究各种智能机器的设计原理与实现方法,使机器具有感知、推理与决策的功能。
人工智能技术是一门综合学科,涉及领域广泛,既有硬件层面的技术也有软件层面的技术。人工智能基础技术一般包括如传感器、专用人工智能芯片、云计算、分布式存储、大数据处理技术、操作/交互系统、机电一体化等技术。人工智能软件技术主要包括计算机视觉技术、语音处理技术、自然语言处理技术以及机器学习/深度学习、自动驾驶、智慧交通等几大方向。
机器学习(Machine Learning,ML)是一门多领域交叉学科,涉及概率论、统计学、逼近论、凸分析、算法复杂度理论等多门学科。专门研究计算机怎样模拟或实现人类的学习行为,以获取新的知识或技能,重新组织已有的知识结构使之不断改善自身的性能。机器学习是人工智能的核心,是使计算机具有智能的根本途径,其应用遍及人工智能的各个领域。机器学习和深度学习通常包括人工神经网络、置信网络、强化学习、迁移学习、归纳学习、示教学习等技术。
随着机器学习的快速发展,机器学习可应用于各个领域,如数据挖掘、计算机视觉、自然语言处理、生物特征识别、医学诊断、检测信用卡欺诈、证券市场分析和DNA序列测序等。对比传统机器学习方法,深度神经网络是一种较新的技术,用多层的网络结构建立机器学习模型,从数据中自动学习到表示特征。由于易于使用、实践效果好,在图像识别、语音识别、自然语言处理、搜索推荐等领域得到广泛应用。
联邦学习(Federated Learning),又可以称为联邦机器学习、联合学习、联盟学习等。联邦机器学习是一个机器学习框架,各个参与方联合建立机器学习模型,且在训练中只交换中间数据,而不直接交换各参与方的业务数据。
具体地,假设企业A、企业B各自建立一个任务模型,单个任务可以是分类或预测,而这些任务也已经在获得数据时由各自用户的认可。然而,由于数据不完整,例如企业A缺少标签数据、企业B缺少特征数据,或者数据不充 分,样本量不足以建立好的模型,那么在各端的模型有可能无法建立或效果并不理想。联邦学习要解决的问题是如何在A和B各端建立高质量的机器学习模型,该模型的训练兼用A和B等各个企业的数据,并且各个企业的自有数据不被其他方知晓,即在不交换本方数据的情况下,建立一个共有模型。这个共有模型就好像各方把数据聚合在一起建立的最优模型一样。这样,建好的模型在各方的区域仅为自有的目标服务。
联邦学习的实施架构中包括至少两个参与方设备,各个参与方设备分别可以包括不同的业务数据,还可以通过设备、计算机、服务器等参与模型的联合训练;其中,各个参与方设备可以包括一台服务器、多台服务器、云计算平台和虚拟化中心中的至少一种。这里的业务数据例如可以是字符、图片、语音、动画、视频等各种数据。通常,各个参与方设备所包含的业务数据具有相关性,各个训练成员对应的业务方也可以具有相关性。单个参与方设备可以持有一个业务的业务数据,也可以持有多个业务方的业务数据。
在该实施架构下,可以由两个或两个以上的参与方设备共同训练模型。这里的模型可以用于处理业务数据,得到相应的业务处理结果,因此,也可以称之为业务模型。具体处理什么样的业务数据,得到什么样的业务处理结果,根据实际需求而定。例如,业务数据可以是用户金融相关的数据,得到的业务处理结果为用户的金融信用评估结果,再例如,业务数据可以是客服数据,得到的业务处理结果为客服答案的推荐结果,等等。业务数据的形式也可以是文字、图片、动画、音频、视频等各种形式的数据。各个参与方设备分别可以利用训练好的模型对本地业务数据进行本地业务处理。
可以理解,联邦学习可以分为横向联邦学习(特征对齐)、纵向联邦学习(样本对齐)与联邦迁移学习。本说明书提供的实施架构基于纵向联邦学习提出,即,各个参与方设备之间样本主体重叠,从而可以分别提供样本的部分特征的联邦学习情形。样本主体即待处理的业务数据对应的主体,例如金融风险性评估的业务主体为用户或者企业等。
在纵向联邦学习的二分类场景中,通常采用随机梯度下降(SGD)方法或者牛顿法及拟牛顿法来实现模型的优化。其中,随机梯度下降(SGD)方法的核心思想是利用损失函数对模型参数的一阶梯度来迭代优化模型,但是现有的一阶优化器只利用到了损失函数对模型参数的一阶梯度,收敛速度会比较慢;牛顿法是以二阶导数海森(Hessian)矩阵H的逆矩阵乘以一阶梯度来引导参数更新, 而这种方法的计算复杂度较高;拟牛顿方法即是将牛顿法中的二阶导数Hessian矩阵的逆用一个n阶矩阵来代替,但是这种方式的算法收敛速度仍然较慢。
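For reference, the three update rules contrasted in this paragraph can be written compactly as follows. This is standard optimization notation (learning rate η, gradient g, Hessian H) added for readability, not text taken from the filing:

```latex
\text{SGD:}\quad w \leftarrow w - \eta\, g \qquad
\text{Newton:}\quad w \leftarrow w - H^{-1} g \qquad
\text{Quasi-Newton:}\quad w \leftarrow w - B_k\, g,\ \ B_k \approx H^{-1}
```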
有鉴于此,本公开实施例提供一种联邦学习模型训练方法,该方法可以提高纵向联邦学习中模型的收敛速度。如图1所示,所述联邦学习模型训练方法,包括:
步骤S101,任一参与方设备基于本方模型参数和特征信息与其他参与方设备进行联合加密训练,获得本方的梯度信息。
在本实施例中,至少两个参与方设备共同训练联邦学习模型,且各个参与方设备均可基于本参与方设备上的业务数据获得特征信息。在联邦学习模型的训练过程中,各个参与方设备基于加密后的模型参数、特征信息等信息与其他参与方设备进行交互,从而使得各个参与方设备均获得其各自的梯度信息。
步骤S103,任一参与方设备基于模型参数和梯度信息获取模型参数变化量和梯度信息变化量,并基于所述模型参数变化量和所述梯度信息变化量与其他参与方设备进行预设轮数的交互计算,获得本方的梯度搜索方向作为拟牛顿条件。
在本实施例中,任一参与方设备基于模型参数和梯度信息,通过预设轮数的交互计算即可获得各个参与方设备的梯度搜索方向,各个参与方设备所获得的梯度搜索方向相当于牛顿法 $w = w - H^{-1}g$ 中的 $-H^{-1}g$,因此无需直接计算海森矩阵 $H$ 或者海森矩阵的逆矩阵 $H^{-1}$,减小了数据的计算量和交互量。
步骤S105,目标参与方设备获取模型损失函数,并基于所述梯度搜索方向以及模型损失函数计算步长信息;其中,所述目标参与方设备为任一参与方设备中具有标签信息的参与方设备,所述模型损失函数为凸函数。
本实施例中,由于模型损失函数为凸函数,因此基于该模型损失函数的凸性,通过对其局部极值点的计算即可获得其全局极值点。基于步骤S103中所计算出的各个参与方设备的梯度搜索方向,选择一个步长信息对模型参数进行预更新,直至模型损失函数满足搜索停止条件,则基于该梯度搜索方向、步长信息对模型参数进行更新。
步骤S107,任一参与方设备基于所述梯度搜索方向、所述步长信息对本方的模型参数进行更新,直至所述联邦学习模型收敛。
可选的,在上述实施例中,任一参与方设备为参与联邦学习模型训练中的 全部参与方设备中的任意一个,不区分该参与方设备是否具有标签信息。即本实施例中步骤S101、S103以及S107为参与联邦学习模型训练中的全部参与方设备均可执行的步骤。目标参与方设备为参与联邦学习模型训练中的全部参与方设备中具有标签信息的参与方设备,该目标参与方设备不仅执行步骤S101、S103以及S107的方法,也执行步骤S105中的方法。
在本实施例中,各参与方设备通过与其他参与方设备进行联合加密训练获得本方的梯度信息之后,基于模型参数变化量和梯度信息变化量与其他参与方设备进行联合训练从而获得各自的梯度搜索方向作为拟牛顿条件;之后,目标参与方设备基于梯度搜索方向以及模型损失函数计算步长信息;最后,各参与方设备基于梯度搜索方向、步长信息对本方的模型参数进行更新,从而无需计算Hessian矩阵H的逆矩阵,相比于随机梯度下降方法、牛顿法和拟牛顿法其计算量小、通信量少,且可以保证快速收敛。
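A brief note on the term "quasi-Newton condition" used above: in standard quasi-Newton theory it refers to the secant equation linking one iteration's parameter change and gradient change. With the symbols introduced later in this description (s for Δw, t for Δg), it reads as below; the formula is supplied as background and is not quoted from the filing:

```latex
s_k = w_{k+1} - w_k,\qquad t_k = g_{k+1} - g_k,\qquad B_{k+1}\, t_k = s_k,\quad B_{k+1} \approx H^{-1}
```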
如图2所示,上述实施例所述方法应用于目标参与方设备Guest和除了目标参与方设备以外的其他参与方设备Host之间。其中,所述目标参与方设备Guest存储多个样本的第一特征信息和样本标签信息,所述其他参与方设备Host存储多个样本的第二特征信息。其他参与方设备可以仅包括一个参与方设备,也可以包括多个参与方设备,本实施例中以其他参与方设备仅包括一个参与方设备为例,详细说明基于目标参与方设备Guest和其他参与方设备Host的联邦学习模型训练方法。
如图3所示,在一个具体的实施例中,基于双方共有信息(例如id信息)实现目标参与方设备Guest和其他参与方设备Host的数据对齐,对齐后的目标参与方设备Guest和其他参与方设备Host均包括id信息分别为1、2、3的多个样本。其中,其他参与方设备Host包括特征1、特征2以及特征3等多个第二特征信息;目标参与方设备Guest包括特征4(点击)、特征5、特征6等多个第一特征信息以及样本标签信息(购买)。
为了便于本公开实施例的后续表述,令目标参与方设备Guest和其他参与方设备Host的样本的数量为n。目标参与方设备Guest中每一条第一特征信息记为 $x_G$,目标参与方设备Guest中n个样本全部的第一特征信息列记为 $\{x_G^{(i)}\}$;每一个样本的样本标签为 $y$,n个样本全部的样本标签信息列为 $\{y^{(i)}\}$;其他参与方设备Host中每一条第二特征信息记为 $x_H$,其他参与方设备Host中n个样本全部的第二特征信息列为 $\{x_H^{(i)}\}$;其中,$i$ 表示n个样本中的第i个。
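In the vertically partitioned setting just described, each aligned sample's feature vector is split between the two devices and each device holds only its own slice of the model. The following notation is supplied for clarity; it is consistent with the symbols above but is not quoted from the filing:

```latex
x = \begin{bmatrix} x_G \\ x_H \end{bmatrix}, \qquad
w = \begin{bmatrix} w_G \\ w_H \end{bmatrix}, \qquad
w^{\top} x = w_G^{\top} x_G + w_H^{\top} x_H
```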
步骤S101,任一参与方设备基于本方模型参数和特征信息与其他参与方设备进行联合加密训练,获得本方的梯度信息。
在本实施例中,目标参与方设备Guest包括构建在目标参与方设备Guest本地的第一本地模型,第一本地模型包括第一模型参数 $w_G$;相应的,其他参与方设备Host包括构建在其他参与方设备Host本地的第二本地模型,第二本地模型包括第二模型参数 $w_H$。
在一些实施例中,在步骤S101中,采用同态加密算法或半同态加密算法对联合加密训练过程中的交互数据进行加密,例如可采用Paillier算法进行加密,从而保证目标参与方设备Guest和其他参与方设备Host的数据在联合训练的过程中不会泄露。如图4所示,步骤S101具体包括以下步骤:
步骤S201,其他参与方设备获取第一数据信息并发送至目标参与方设备,所述第一数据信息基于第二模型参数与第二特征信息获得。
在本步骤中,其他参与方设备Host获取其他参与方设备本地的第二本地模型的第二模型参数 $w_H$,并计算第二模型参数 $w_H$ 与第二特征信息的內积,从而获得第一数据信息 $w_H^{\top}x_H$,并将第一数据信息发送至目标参与方设备Guest。
可选的,在本实施例中,第一数据信息包括第二模型参数 $w_H$ 的转置矩阵与每一条第二特征信息 $x_H$ 的內积 $w_H^{\top}x_H^{(i)}$,因此第一数据信息包括与n个样本对应的n条信息。
可选的,在步骤S201中,其他参与方设备Host还可以计算第一正则项并发送至目标参与方设备Guest。其中,第一正则项为L2正则项,且第一正则项为 $\frac{\alpha}{2}\|w_H\|_2^2$,α表示正则系数。
可选的,当处于第一次更新周期内时,第二模型参数wH为初始化后的模型参数初始值;当处于中间的更新周期内时,第二模型参数wH为第二本地模型在上一更新周期内更新后的模型参数。
步骤S203,目标参与方设备获取第二数据信息,所述第二数据信息基于第一模型参数与第一特征信息获得。
在本步骤中,目标参与方设备Guest获取第一本地模型的第一模型参数 $w_G$,并计算第一模型参数 $w_G$ 与第一特征信息的內积,从而获得第二数据信息 $w_G^{\top}x_G$。具体的,在本实施例中,第二数据信息包括第一模型参数 $w_G$ 的转置矩阵与每一条第一特征信息 $x_G$ 的內积 $w_G^{\top}x_G^{(i)}$。
可选的,在本实施例中,目标参与方设备Guest还计算第二正则项。其中,第二正则项也为L2正则项,且第二正则项为 $\frac{\alpha}{2}\|w_G\|_2^2$,α表示正则系数。
可选的,当处于第一次更新周期内时,第一模型参数wG为初始化后的模型参数初始值;当处于中间的更新周期内时,第一模型参数wG为第一本地模型在上一更新周期内更新后的模型参数。
在步骤S201与步骤S203中,由于在纵向联邦LR模型中,第一模型参数 $w_G$、第二模型参数 $w_H$ 是一维向量,因此基于 $w_H^{\top}x_H$ 获得的第一数据信息以及基于 $w_G^{\top}x_G$ 获得的第二数据信息均为矩阵相乘后的结果,当第一数据信息和第二数据信息被发送到对方时,对方无法恢复原本的数据信息,从而不会在步骤S201与步骤S203的数据传输过程中泄露明文信息,保证了双方数据的安全。
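A minimal numpy sketch of steps S201–S203 follows. It is an illustration written for this description, not code from the filing: the sample sizes, feature splits, and variable names (x_host, u_host, etc.) are invented, and only the per-sample inner products cross the boundary between the two simulated parties.

```python
import numpy as np

rng = np.random.default_rng(0)

# Vertically partitioned toy data: 5 aligned samples, features split between parties.
n, d_host, d_guest = 5, 3, 4
x_host = rng.normal(size=(n, d_host))    # Host's private features (never sent)
x_guest = rng.normal(size=(n, d_guest))  # Guest's private features (never sent)
w_host = rng.normal(size=d_host) * 0.01  # each party's local model parameters
w_guest = rng.normal(size=d_guest) * 0.01

# Step S201: Host computes its partial linear predictor ("first data information").
u_host = x_host @ w_host                 # one scalar per sample, sent to Guest

# Step S203: Guest computes its own partial predictor ("second data information").
u_guest = x_guest @ w_guest

# Guest combines the two halves; a single scalar per sample does not let it
# recover w_host or x_host, which is the privacy argument made in the text above.
y_hat = 1.0 / (1.0 + np.exp(-(u_host + u_guest)))
print(y_hat.round(3))
```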
步骤S205,目标参与方设备基于所述第一数据信息、所述第二数据信息获得样本标签预测信息,对所述样本标签预测信息与所述样本标签信息的差值加密获得第一加密信息,将所述第一加密信息发送至所述其他参与方设备。
在本步骤中,目标参与方设备Guest基于所述第一数据信息、所述第二数据信息获得每一条样本的样本标签预测信息 $\hat{y}$;其中,基于样本标签预测信息可判断样本的二分类的概率,从而可以解决纵向联邦模型中二分类的问题。可选的,在一些实施例中,样本标签预测信息 $\hat{y}=\mathrm{sigmoid}(w_G^{\top}x_G+w_H^{\top}x_H)$,sigmoid函数定义为 $\mathrm{sigmoid}(z)=\frac{1}{1+e^{-z}}$。
之后,基于每一条样本的样本标签预测信息 $\hat{y}$ 以及样本标签信息 $y$ 计算每一条样本的所述样本标签预测信息与所述样本标签信息的差值 $\hat{y}-y$,并进行加密获得第一加密信息 $[[\hat{y}-y]]$;其中,由于采用了加密算法,加密后的信息在发送至其他参与方设备Host后不会泄露原始的样本标签信息,保证了数据的安全性。
可选的,本步骤中所采用的加密算法可以为半同态加密算法Paillier,或者也可采用其他可选的半同态加密算法或者同态加密算法,本实施例对此不作具体限定。
最后,目标参与方设备Guest将所述第一加密信息发送至所述其他参与方设备Host。
步骤S207,其他参与方设备基于所述第一加密信息、所述第二特征信息以及随机数获取第二加密信息并发送至目标参与方设备。
在本实施例中,其他参与方设备Host基于所述第一加密信息、所述第二特征信息以及随机数的乘积之和获得所述第二加密信息 $[[\sum_{i}((\hat{y}_i-y_i)x_{iH}+\epsilon_i)]]$;其中,$\hat{y}_i$ 表示第i个样本的样本标签预测信息,$y_i$ 表示第i个样本的样本标签,$x_{iH}$ 表示第i个样本的第二特征信息,$\epsilon_i$ 表示第i个样本的随机数。通过随机数的增加,当其他参与方设备Host将第二加密信息发送至目标参与方设备Guest时,目标参与方设备Guest无法还原出 $x_{iH}$ 的明文信息,也无法获得其他参与方设备的第二梯度信息,从而避免了数据的泄露。
步骤S209,目标参与方设备对所述第二加密信息进行解密获得第三解密信息,并将所述第三解密信息发送至所述其他参与方设备。其中,第三解密信息基于每一个样本的样本标签预测信息与样本标签信息的差值、第二特征信息以及随机数的积的累加之和获得。
在本步骤中,采用与S205中的加密算法对应的解密算法,目标参与方设备Guest对第二加密信息进行解密,获得第三解密信息 $\sum_{i}((\hat{y}_i-y_i)x_{iH}+\epsilon_i)$。之后,目标参与方设备Guest将所述第三解密信息发送至所述其他参与方设备Host。
步骤S211,其他参与方设备接收第三解密信息,基于所述随机数获得第四解密信息,并基于所述第四解密信息获得第二梯度信息。
其他参与方设备Host接收第三解密信息后,可去掉随机数 $\epsilon_i$ 获得第四解密信息 $\sum_{i}(\hat{y}_i-y_i)x_{iH}$。由于第四解密信息是累加值,因此即使其他参与方设备Host已知 $x_{iH}$ 也无法解析出每一条 $\hat{y}_i-y_i$,从而避免了数据的泄露。
之后,其他参与方设备Host可以基于第四解密信息计算本方的第二梯度信息 $g_H$。
步骤S213,目标参与方设备根据所述样本标签预测信息与所述样本标签信息的差值以及第一特征信息计算第五明文信息,基于所述第五明文信息获得所述第一梯度信息。
在本步骤中,目标参与方设备Guest基于每一条样本的所述样本标签预测信息与所述样本标签信息的差值以及每一条样本的第一特征信息 $x_G$ 的乘积之和获得第五明文信息 $\sum_{i}(\hat{y}_i-y_i)x_{iG}$,并基于第五明文信息计算第一梯度信息 $g_G$。
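The encrypted exchange in steps S205–S211 can be sketched with the python-paillier (phe) library. This is an illustrative reconstruction under stated assumptions, not the filing's implementation: the choice of library, the per-coordinate mask, the 1/n scaling of the gradient, and all variable names are ours, and the regularization term is omitted. Guest's own gradient (step S213) needs no encryption because every term is already local to Guest.

```python
import numpy as np
from phe import paillier

rng = np.random.default_rng(1)
pub, priv = paillier.generate_paillier_keypair(n_length=1024)  # Guest holds the key pair

n, d_host = 4, 3
x_host = rng.normal(size=(n, d_host))   # Host's private features
residual = rng.normal(size=n)           # Guest's per-sample (y_hat - y), plaintext only at Guest

# S205: Guest encrypts the residuals and sends the ciphertexts to Host.
enc_residual = [pub.encrypt(float(r)) for r in residual]

# S207: Host computes sum_i [[d_i]] * x_iH plus a random mask, one value per coordinate.
# (The text describes per-sample random numbers; a per-coordinate mask is used here
# for simplicity; either way only Host knows the mask and removes it later.)
eps = rng.normal(size=d_host)
enc_masked = []
for j in range(d_host):
    acc = enc_residual[0] * float(x_host[0, j])
    for i in range(1, n):
        acc = acc + enc_residual[i] * float(x_host[i, j])
    enc_masked.append(acc + float(eps[j]))

# S209: Guest decrypts the masked aggregates and returns the plaintext sums.
masked_sums = np.array([priv.decrypt(c) for c in enc_masked])

# S211: Host removes its mask to obtain sum_i d_i * x_iH and forms its gradient.
grad_host = (masked_sums - eps) / n
print(np.allclose(grad_host, residual @ x_host / n))  # sanity check: True
```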
在上述实施例中,步骤S205中还包括:目标参与方设备基于所述样本标签预测信息、所述样本标签信息计算损失函数Loss,例如交叉熵形式 $Loss=-\frac{1}{n}\sum_{i}[y_i\log\hat{y}_i+(1-y_i)\log(1-\hat{y}_i)]$。可选的,损失函数Loss中还可以包括第一正则项和第二正则项。
步骤S103,任一参与方设备基于模型参数和梯度信息获取模型参数变化量和梯度信息变化量,并基于所述模型参数变化量和所述梯度信息变化量与其他参与方设备进行预设轮数的交互计算,获得本方的梯度搜索方向作为拟牛顿条件。
可选的,在本实施例中,任一参与方设备基于所述模型参数变化量和所述梯度信息变化量,采用例如双向循环递归方法与其他参与方设备进行预设轮数的交互计算,获得所述梯度搜索方向。即在本实施例中,目标参与方设备Guest获得第一梯度信息、其他参与方设备Host获得第二梯度信息之后,计算各自的模型参数变化量和梯度信息变化量,并基于双向循环递归方法进行预设轮数的交互计算,从而使得目标参与方设备Guest获得第一梯度搜索方向、其他参与方设备Host获得第二梯度搜索方向。同时,由于在本实施例中,目标参与方设备Guest与其他参与方设备Host所计算、发送以及接收的数据均是基于所述模型参数变化量、所述模型参数变化量的转置矩阵、所述梯度信息变化量、所述梯度信息变化量的转置矩阵中至少两个的向量乘积或标量乘积所获得的,而不涉及大矩阵的运算,因此整个过程中计算量和通信量都很小,从而可以保证模型的快速收敛。
在本实施例中,如图5所示,步骤S103中具体包括:
步骤S301,目标参与方设备Guest获取第一模型参数变化量和第一梯度信息变化量,其他参与方设备Host获取第二模型参数变化量和第二梯度信息变化量。
在本实施例中,为了便于表示,令 $g$ 表示梯度信息,其中,$g_G$ 表示第一梯度信息,$g_H$ 表示第二梯度信息。令 $t$ 表示梯度信息 $g$ 的变化量 $\Delta g$,则 $t_G$ 表示第一梯度信息变化量,$t_H$ 表示第二梯度信息变化量。令 $s$ 表示模型参数变化量 $\Delta w$,则 $s_G$ 表示第一模型参数变化量,$s_H$ 表示第二模型参数变化量。
步骤S303,任一参与方设备基于所述模型参数变化量和所述梯度信息变化量与其他参与方设备进行预设轮数的交互计算,获得中间变化量;所述中间变化量用于表征所述梯度信息的大小。
可选的,在本实施例中,可采用双向循环算法进行梯度搜索方向的计算。其中,包括:在后向循环过程中,任一参与方设备基于所述第一中间值信息与其他参与方设备进行预设轮数的交互计算,获得中间变化量。
其中,预设轮数为3-5中的一个,且后向循环与前向循环的轮数相同。
在本实施例中,具有第一梯度信息变化量 $t_G$ 和第一模型参数变化量 $s_G$ 的目标参与方设备Guest与具有第二梯度信息变化量 $t_H$ 和第二模型参数变化量 $s_H$ 的其他参与方设备Host,进行3-5轮的交互计算后,目标参与方设备Guest获得本方的中间变化量 $q_G$,其他参与方设备Host获得本方的中间变化量 $q_H$。
同时,在后向循环过程中,任一参与方设备基于本方的第一中间值信息,与其他参与方设备交换第一中间值信息并基于各参与方设备的第一中间值信息计算第一全局中间值,以根据所述第一全局中间值计算所述中间变化量。
在本实施例中,后向循环过程中的第一中间值信息包括ρG、ρH和αG、αH,目标参与方设备Guest与其他参与方设备Host分别基于本方的模型参数变化量、梯度信息变化量计算本方的第一中间值信息之后,需交换各参与方设备的第一中间值信息,从而获得第一全局中间值ρ和α。可选的,第一全局中间值可以为各参与方设备的第一中间值信息之和,或者也可以根据需求进行设置,本说明书对此不作限制。
具体的,目标参与方设备Guest与其他参与方设备Host分别基于本方梯度信息变化量的转置矩阵、模型参数变化量的乘积获得第一中间值信息ρG、ρH,交换各自的第一中间值信息ρG、ρH后获得第一全局中间值ρ;再结合该第一全局中间值ρ、模型参数变化量的转置矩阵以及梯度信息计算第一中间值信息αG、αH,再交换第一中间值信息αG、αH后计算第一全局中间值α,最终基于α计算本方的中间变化量。
下面结合具体实施例进一步详述本实施例中后向循环的步骤,包括:
步骤S401,目标参与方设备Guest初始化 $q_G=g_k^G$,其他参与方设备Host初始化 $q_H=g_k^H$。
步骤S403,对以下步骤迭代L轮,i从L-1到0,j从k-1到k-L。其中L表示预设轮数,且L=3~5;k表示当前的循环轮数。
1). 其他参与方设备Host方计算中间过程变量 $\rho_i^H=(t_j^H)^{\top}s_j^H$;
2). 目标参与方设备Guest方计算中间过程变量 $\rho_i^G=(t_j^G)^{\top}s_j^G$;
3). 目标参与方设备Guest和其他参与方设备Host交换ρ值后计算 $\rho_i=\rho_i^H+\rho_i^G$;
4). 其他参与方设备Host方计算中间过程变量 $\alpha_i^H=(s_j^H)^{\top}q_H/\rho_i$;
5). 目标参与方设备Guest方计算中间过程变量 $\alpha_i^G=(s_j^G)^{\top}q_G/\rho_i$;
6). 目标参与方设备Guest和其他参与方设备Host交换α值后计算 $\alpha_i=\alpha_i^H+\alpha_i^G$;
7). 其他参与方设备Host方计算中间变化量 $q_H=q_H-\alpha_i t_j^H$;
8). 目标参与方设备Guest方计算中间变化量 $q_G=q_G-\alpha_i t_j^G$。
在步骤S403中各步骤的各中间过程变量的计算与交换过程中,都是向量乘法或标量乘法的计算与交换,不涉及大矩阵的计算,因此在训练过程中的计算量和通信量都较少,不仅可以保证模型的快速收敛,还可以提高目标参与方设备与其他参与方设备的硬件处理速率。
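A single-process sketch of this backward loop (steps S401–S403) is given below for reference. It is our reconstruction, not code from the filing: in the real protocol each party keeps its own vectors and only the scalars rho and alpha cross the network, and the combination rho = rho_G + rho_H follows the "sum of the parties' first intermediate values" rule stated above.

```python
import numpy as np

def backward_loop(g_G, g_H, s_hist_G, t_hist_G, s_hist_H, t_hist_H):
    """Two-party backward recursion; returns each party's intermediate q and the
    per-step scalars needed by the forward loop.  Histories are lists of the L
    most recent (delta-w, delta-g) slices held by each party."""
    q_G, q_H = g_G.copy(), g_H.copy()          # S401: start from the latest gradients
    L = len(s_hist_G)
    alphas, rhos = [0.0] * L, [0.0] * L
    for j in reversed(range(L)):               # i = L-1 .. 0, j = k-1 .. k-L
        rho_H = float(t_hist_H[j] @ s_hist_H[j])   # Host's share of t_j^T s_j
        rho_G = float(t_hist_G[j] @ s_hist_G[j])   # Guest's share
        rho = rho_H + rho_G                        # exchanged as plain scalars
        alpha_H = float(s_hist_H[j] @ q_H) / rho
        alpha_G = float(s_hist_G[j] @ q_G) / rho
        alpha = alpha_H + alpha_G                  # exchanged as plain scalars
        q_H -= alpha * t_hist_H[j]
        q_G -= alpha * t_hist_G[j]
        alphas[j], rhos[j] = alpha, rho
    return q_G, q_H, alphas, rhos
```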
步骤S305,任一参与方设备基于所述中间变化量与其他参与方设备进行预设轮数的交互计算,获得所述梯度搜索方向。
可选的,步骤S305进一步包括:任一参与方设备基于本方的所述中间变化量计算本方的第二中间值信息;任一参与方设备基于本方的所述第二中间值信息,与其他参与方设备交换第二中间值信息并基于各参与方设备的第二中间值信息计算第二全局中间值,以根据所述第二全局中间值计算所述梯度搜索方向。
在本实施例中,可采用双向循环算法进行梯度搜索方向的计算。其中,包括:在前向循环过程中,任一参与方设备基于所述模型参数变化量、所述模型参数变化量的转置矩阵、所述梯度信息变化量、所述梯度信息变化量的转置矩阵中至少两个的向量乘积或标量乘积获得第二中间值信息,并基于所述第二中间值信息、所述中间变化量与其他参与方设备进行预设轮数的交互计算,获得所述梯度搜索方向。
在本实施例中,具有中间变化量qG的目标参与方设备Guest与具有中间变化量qH的其他参与方设备Host,进行3-5轮的交互计算后,目标参与方设备Guest获得本方的第一梯度搜索方向pkG,其他参与方设备Host获得本方的第二梯度搜索方向pkH
下面结合具体实施例进一步详述本实施例中前向循环的步骤,包括:
步骤S501,任一参与方设备基于本方的所述模型参数变化量的转置矩阵、所述模型参数变化量获得第一标量信息,基于所述梯度信息变化量的转置矩阵、所述梯度信息变化量获得第二标量信息。
在本实施例中,第一标量信息 $s_G^{\top}s_G$ 基于第一模型参数变化量 $s_G$ 的转置矩阵与第一模型参数变化量 $s_G$ 的积获得,第二标量信息 $t_G^{\top}t_G$ 基于第一梯度信息变化量 $t_G$ 的转置矩阵与第一梯度信息变化量 $t_G$ 的积获得。
步骤S503,任一参与方设备与其他参与方设备进行交互以获得其他参与方设备的第三标量信息和第四标量信息;所述第三标量信息基于其他参与方设备的模型参数变化量的转置矩阵、模型参数变化量获得,所述第四标量信息基于其他参与方设备的梯度信息变化量的转置矩阵、梯度信息变化量获得。
在本实施例中,第三标量信息 $s_H^{\top}s_H$ 基于第二模型参数变化量 $s_H$ 的转置矩阵与第二模型参数变化量 $s_H$ 的积获得,第四标量信息 $t_H^{\top}t_H$ 基于第二梯度信息变化量 $t_H$ 的转置矩阵与第二梯度信息变化量 $t_H$ 的积获得。
在本实施例中,目标参与方设备Guest与其他参与方设备Host交换第一标量信息、第二标量信息、第三标量信息以及第四标量信息,从而使得目标参与方设备Guest与其他参与方设备Host均具有上述信息。
步骤S505,任一参与方设备基于所述第一标量信息 $s_G^{\top}s_G$、所述第二标量信息 $t_G^{\top}t_G$、所述第三标量信息 $s_H^{\top}s_H$、所述第四标量信息 $t_H^{\top}t_H$ 以及中间变化量 $q_G$、$q_H$ 计算本方第二中间值信息,并与其他参与方设备交换第二中间值信息并基于各参与方设备的第二中间值信息计算第二全局中间值,以根据所述第二全局中间值计算所述梯度搜索方向。
在本实施例中,前向循环过程中的第二中间值信息包括β,目标参与方设备Guest与其他参与方设备Host分别计算本方的第二中间值信息β之后,需交换各参与方设备的第二中间值信息,从而获得第二全局中间值。可选的,第二全局中间值可以为各参与方设备的第二中间值信息之和,或者也可以根据需求进行设置,本说明书对此不作限制。
可选的,步骤S505进一步包括:
步骤S601,根据目标参与方设备Guest和其他参与方设备Host交换的第一标量信息 $s_G^{\top}s_G$、第二标量信息 $t_G^{\top}t_G$、第三标量信息 $s_H^{\top}s_H$、第四标量信息 $t_H^{\top}t_H$ 的值计算缩放标量 $\gamma_k$。
步骤S603,目标参与方设备Guest和其他参与方设备Host分别计算 $D_0=\gamma_k I$,其中 $I$ 为单位矩阵。
步骤S605,其他参与方设备Host方计算 $z_H=D_0\cdot q_H$,目标参与方设备Guest计算 $z_G=D_0\cdot q_G$。
步骤S607,迭代L轮,i从0到L-1,j从k-L到k-1。其中,L表示预设的循环轮数,且L=3~5中的一个;k表示当前的循环轮数。
1). 其他参与方设备Host方计算 $\beta_i^H=(t_j^H)^{\top}z_H/\rho_i$;
2). 目标参与方设备Guest方计算 $\beta_i^G=(t_j^G)^{\top}z_G/\rho_i$;
3). 目标参与方设备Guest和其他参与方设备Host交换β值后计算 $\beta_i=\beta_i^H+\beta_i^G$;
4). 其他参与方设备Host方计算 $z_H=z_H+(\alpha_i-\beta_i)s_j^H$;
5). 目标参与方设备Guest方计算 $z_G=z_G+(\alpha_i-\beta_i)s_j^G$;
步骤S609,其他参与方设备Host方得到第二梯度搜索方向 $p_k^H=-z_H$,目标参与方设备Guest方得到第一梯度搜索方向 $p_k^G=-z_G$。
在上述实施例中,由于计算过程中除了一次单位矩阵与向量的乘法,其他都是向量乘法或标量乘法,不涉及大矩阵的计算,从而减小了模型训练过程中的计算量;同时,双方的交互变量都是向量內积之后的标量结果,保证了数据的安全性,减小了数据传输过程中的通信量,不仅可以保证模型的快速收敛,还可以提高目标参与方设备与其他参与方设备的硬件处理速率。可选的,在一些具体的实施例中,对于同一份样本数据,在一次更新周期内,本公开实施例所述联邦学习模型训练方法仅需3个循环轮数的迭代,即可使得模型收敛;而采用梯度下降方法则需要数十轮迭代才可保证模型收敛,因此本公开实施例所述联邦学习模型训练方法能够提高模型的收敛速度。
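Continuing the backward_loop sketch above, the forward loop can be written as follows. Again this is our reconstruction: the filing does not spell out the formula for the initial scaling gamma_k, so a value built from the four exchanged scalars (||s|| / ||t||) is assumed here, and only the scalars gamma and beta would cross the network.

```python
import numpy as np

def forward_loop(q_G, q_H, alphas, rhos, s_hist_G, t_hist_G, s_hist_H, t_hist_H):
    """Two-party forward recursion (steps S601-S609); returns the search
    directions p_k^G and p_k^H."""
    L = len(s_hist_G)
    j = L - 1                                   # most recent (s, t) pair
    ss = float(s_hist_G[j] @ s_hist_G[j] + s_hist_H[j] @ s_hist_H[j])  # exchanged scalars
    tt = float(t_hist_G[j] @ t_hist_G[j] + t_hist_H[j] @ t_hist_H[j])
    gamma = np.sqrt(ss / tt)                    # assumed form of gamma_k (not given in the text)
    z_G, z_H = gamma * q_G, gamma * q_H         # D0 = gamma * I, applied locally by each party
    for j in range(L):                          # i = 0 .. L-1, j = k-L .. k-1
        beta_H = float(t_hist_H[j] @ z_H) / rhos[j]
        beta_G = float(t_hist_G[j] @ z_G) / rhos[j]
        beta = beta_H + beta_G                  # exchanged as a plain scalar
        z_H += (alphas[j] - beta) * s_hist_H[j]
        z_G += (alphas[j] - beta) * s_hist_G[j]
    return -z_G, -z_H                           # p_k^G, p_k^H

# Example usage with random two-party histories (purely illustrative):
rng = np.random.default_rng(2)
sG = [rng.normal(size=4) for _ in range(3)]; tG = [rng.normal(size=4) for _ in range(3)]
sH = [rng.normal(size=3) for _ in range(3)]; tH = [rng.normal(size=3) for _ in range(3)]
qG, qH, al, rh = backward_loop(rng.normal(size=4), rng.normal(size=3), sG, tG, sH, tH)
pG, pH = forward_loop(qG, qH, al, rh, sG, tG, sH, tH)
```

As in the description, every quantity exchanged in these two loops is a scalar inner product, never a vector or a matrix, which is where the communication savings come from.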
步骤S105,目标参与方设备获取模型损失函数,并基于所述梯度搜索方向以及模型损失函数计算步长信息。
在一些实施例中,步骤S105中所述目标参与方设备获取模型损失函数,并基于所述梯度搜索方向以及模型损失函数计算步长信息,包括:
步骤S701,目标参与方设备获取样本标签信息,并基于本方模型参数、特征信息以及其他参与方设备的第一数据信息获得样本标签预测信息;其中,所述第一数据信息基于其他参与方设备的模型参数、特征信息获得。
在本实施例中,目标参与方设备Guest首先基于本方模型参数、特征信息计算模型参数的转置矩阵与特征信息的乘积获得第二数据信息 $w_G^{\top}x_G$。之后,目标参与方设备Guest基于所述第二数据信息与其他参与方设备Host进行交互,获得其他参与方设备Host的第一数据信息 $w_H^{\top}x_H$。最后,目标参与方设备Guest基于第一数据信息、第二数据信息以及预设模型函数获得所述样本标签预测信息。
可选的,预设模型函数为sigmoid函数,样本标签预测信息 $\hat{y}=\mathrm{sigmoid}(w_G^{\top}x_G+w_H^{\top}x_H)$,sigmoid函数定义为 $\mathrm{sigmoid}(z)=\frac{1}{1+e^{-z}}$。
步骤S703,目标参与方设备基于所述样本标签预测信息及所述样本标签信息计算损失函数。
在本实施例中,损失函数例如可以为交叉熵损失 $Loss=-\frac{1}{n}\sum_{i}[y_i\log\hat{y}_i+(1-y_i)\log(1-\hat{y}_i)]$,并可叠加上述第一正则项和第二正则项。
步骤S705,目标参与方设备判断所述损失函数是否满足预设条件,若是,则将当前步长信息作为最终的步长信息;否则,减少所述步长信息的值并重新计算所述损失函数。
在本实施例中,预设条件可以为Armijo条件。因此,可判断损失函数Loss是否满足Armijo条件,包括:$Loss(y,x_H,x_G,w_H+\lambda p_H,w_G+\lambda p_G)\leq Loss(y,x_H,x_G,w_H,w_G)+c_1\lambda(g_H^{\top}p_H+g_G^{\top}p_G)$,其中 $c_1$ 为超参数(例如可以取值 $1\times 10^{-4}$)。
若损失函数满足Armijo条件,则将当前步长信息作为最终的步长信息 $\lambda$;若损失函数不满足Armijo条件,则减少所述步长信息的值(例如减小为原来的1/2),并基于减少后的步长信息以及第一梯度搜索方向、第二梯度搜索方向更新双方的模型参数后重新计算损失函数,直至损失函数满足Armijo条件。
之后,可基于获得的步长信息 $\lambda$ 以及第一梯度搜索方向更新第一模型参数,其中,$w_G^{k+1}=w_G^{k}+\lambda p_k^G$;其他参与方设备Host同理基于第二梯度搜索方向更新第二模型参数 $w_H^{k+1}=w_H^{k}+\lambda p_k^H$。
当双方的梯度变化稳定,即 $\|g_k\|\leq\varepsilon$(ε为预设阈值)时,停止训练,模型更新完成。
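A hedged sketch of the backtracking search in steps S701–S705 follows. loss_fn is a stand-in for the joint loss evaluation described above (each trial step needs one more exchange of the w^T x terms between Guest and Host); c1 = 1e-4 matches the hyper-parameter mentioned, while the halving factor and iteration cap are our own defaults.

```python
def armijo_step(loss_fn, w_G, w_H, p_G, p_H, g_G, g_H,
                lam=1.0, c1=1e-4, shrink=0.5, max_halvings=20):
    """Return a step size satisfying the Armijo condition quoted above."""
    base = loss_fn(w_G, w_H)
    slope = float(g_H @ p_H + g_G @ p_G)        # g_H^T p_H + g_G^T p_G
    for _ in range(max_halvings):
        trial = loss_fn(w_G + lam * p_G, w_H + lam * p_H)
        if trial <= base + c1 * lam * slope:    # Armijo condition satisfied
            return lam
        lam *= shrink                           # otherwise shrink the step and retry
    return lam

# Once lambda is found, each party updates only its own parameters:
#   w_G <- w_G + lam * p_G   (Guest)      w_H <- w_H + lam * p_H   (Host)
```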
需要说明的是,本公开实施例的方法可以由单个设备执行,例如一台计算机或服务器等。本实施例的方法也可以应用于分布式场景下,由多台设备相互配合来完成。在这种分布式场景的情况下,这多台设备中的一台设备可以只执行本公开实施例的方法中的某一个或多个步骤,这多台设备相互之间会进行交互以完成所述的方法。
需要说明的是,上述对本公开的一些实施例进行了描述。其它实施例在所附权利要求书的范围内。在一些情况下,在权利要求书中记载的动作或步骤可以按照不同于上述实施例中的顺序来执行并且仍然可以实现期望的结果。另外,在附图中描绘的过程不一定要求示出的特定顺序或者连续顺序才能实现期望的结果。在某些实施方式中,多任务处理和并行处理也是可以的或者可能是有利的。
基于同一发明构思,与上述任意实施例方法相对应的,本公开还提供了一种电子设备,包括存储器、处理器及存储在存储器上并可在处理器上运行的计算机程序,所述处理器执行所述程序时实现如上任意一实施例所述的方法。
图6示出了本实施例所提供的一种更为具体的电子设备硬件结构示意图,该设备可以包括:处理器1010、存储器1020、输入/输出接口1030、通信接口1040和总线1050。其中处理器1010、存储器1020、输入/输出接口1030和通信接口1040通过总线1050实现彼此之间在设备内部的通信连接。
处理器1010可以采用通用的CPU(Central Processing Unit,中央处理器)、微处理器、应用专用集成电路(Application Specific Integrated Circuit,ASIC)、或者一个或多个集成电路等方式实现,用于执行相关程序,以实现本说明书实施例所提供的技术方案。
存储器1020可以采用ROM(Read Only Memory,只读存储器)、RAM(Random Access Memory,随机存取存储器)、静态存储设备,动态存储设备等形式实现。存储器1020可以存储操作系统和其他应用程序,在通过软件或者固件来实现本说明书实施例所提供的技术方案时,相关的程序代码保存在存储器1020中,并由处理器1010来调用执行。
输入/输出接口1030用于连接输入/输出模块,以实现信息输入及输出。输入输出/模块可以作为组件配置在设备中(图中未示出),也可以外接于设备以提供相应功能。其中输入设备可以包括键盘、鼠标、触摸屏、麦克风、各类传感器等,输出设备可以包括显示器、扬声器、振动器、指示灯等。
通信接口1040用于连接通信模块(图中未示出),以实现本设备与其他设备的通信交互。其中通信模块可以通过有线方式(例如USB、网线等)实现通信,也可以通过无线方式(例如移动网络、WIFI、蓝牙等)实现通信。
总线1050包括一通路,在设备的各个组件(例如处理器1010、存储器1020、输入/输出接口1030和通信接口1040)之间传输信息。
需要说明的是,尽管上述设备仅示出了处理器1010、存储器1020、输入/输出接口1030、通信接口1040以及总线1050,但是在具体实施过程中,该设备还可以包括实现正常运行所必需的其他组件。此外,本领域的技术人员可以理解的是,上述设备中也可以仅包含实现本说明书实施例方案所必需的组件,而不必包含图中所示的全部组件。
上述实施例的电子设备用于实现前述任一实施例中相应的方法,并且具有相应的方法实施例的有益效果,在此不再赘述。
基于同一发明构思,与上述任意实施例方法相对应的,本公开还提供了一种非暂态计算机可读存储介质,所述非暂态计算机可读存储介质存储计算机指令,所述计算机指令用于使所述计算机执行如上任一实施例所述的方法。
本实施例的计算机可读介质包括永久性和非永久性、可移动和非可移动媒体,可以由任何方法或技术来实现信息存储。信息可以是计算机可读指令、数据结构、程序的模块或其他数据。计算机的存储介质的例子包括,但不限于相变内存(PRAM)、静态随机存取存储器(SRAM)、动态随机存取存储器(DRAM)、其他类型的随机存取存储器(RAM)、只读存储器(ROM)、电可擦除可编程只读存储器(EEPROM)、快闪记忆体或其他内存技术、只读光盘只读存储器(CD-ROM)、数字多功能光盘(DVD)或其他光学存储、磁盒式磁带、磁带或磁盘存储或其他磁性存储设备,或任何其他非传输介质,可用于存储可以被计算设备访问的信息。
上述实施例的存储介质存储的计算机指令用于使所述计算机执行如上任一实施例所述的方法,并且具有相应的方法实施例的有益效果,在此不再赘述。
所属领域的普通技术人员应当理解:以上任何实施例的讨论仅为示例性的,并非旨在暗示本公开的范围(包括权利要求)被限于这些例子;在本公开的思路下,以上实施例或者不同实施例中的技术特征之间也可以进行组合,步骤可以以任意顺序实现,并存在如上所述的本公开实施例的不同方面的许多其它变化,为了简明它们没有在细节中提供。
另外,为简化说明和讨论,并且为了不会使本公开实施例难以理解,在所提供的附图中可以示出或可以不示出与集成电路(IC)芯片和其它部件的公知的电源/接地连接。此外,可以以框图的形式示出装置,以便避免使本公开实施例难以理解,并且这也考虑了以下事实,即关于这些框图装置的实施方式的细节是高度取决于将要实施本公开实施例的平台的(即,这些细节应当完全处于本领域技术人员的理解范围内)。在阐述了具体细节(例如,电路)以描述本公开的示例性实施例的情况下,对本领域技术人员来说显而易见的是,可以在没有这些具体细节的情况下或者这些具体细节有变化的情况下实施本公开实施例。因此,这些描述应被认为是说明性的而不是限制性的。
尽管已经结合了本公开的具体实施例对本公开进行了描述,但是根据前面的描述,这些实施例的很多替换、修改和变型对本领域普通技术人员来说将是显而易见的。例如,其它存储器架构(例如,动态RAM(DRAM))可以使用所讨论的实施例。
本公开实施例旨在涵盖落入所附权利要求的宽泛范围之内的所有这样的替换、修改和变型。因此,凡在本公开实施例的精神和原则之内,所做的任何省略、修改、等同替换、改进等,均应包含在本公开的保护范围之内。

Claims (11)

  1. 一种联邦学习模型训练方法,其特征在于,包括:
    任一参与方设备基于本方模型参数和特征信息与其他参与方设备进行联合加密训练,获得本方的梯度信息;
    任一参与方设备基于模型参数和梯度信息获取模型参数变化量和梯度信息变化量,并基于所述模型参数变化量和所述梯度信息变化量与其他参与方设备进行预设轮数的交互计算,获得本方的梯度搜索方向作为拟牛顿条件;
    目标参与方设备获取模型损失函数,并基于所述梯度搜索方向以及模型损失函数计算步长信息;其中,所述目标参与方设备为任一参与方设备中具有标签信息的参与方设备,所述模型损失函数为凸函数;
    任一参与方设备基于所述梯度搜索方向、所述步长信息对本方的模型参数进行更新,直至所述联邦学习模型收敛。
  2. 根据权利要求1所述的方法,其特征在于,所述任一参与方设备基于所述模型参数变化量和所述梯度信息变化量,采用双向循环递归方法与其他参与方设备进行预设轮数的交互计算,获得所述梯度搜索方向作为拟牛顿条件,包括:
    任一参与方设备基于所述模型参数变化量和所述梯度信息变化量与其他参与方设备进行预设轮数的交互计算,获得中间变化量;所述中间变化量用于表征所述梯度信息的大小;
    任一参与方设备基于所述中间变化量与其他参与方设备进行预设轮数的交互计算,获得所述梯度搜索方向。
  3. 根据权利要求2所述的方法,其特征在于,所述任一参与方设备基于所述模型参数变化量和所述梯度信息变化量与其他参与方设备进行预设轮数的交互计算,获得中间变化量,还包括:
    任一参与方设备基于本方的所述模型参数变化量和所述梯度信息变化量计算本方第一中间值信息,与其他参与方设备交换第一中间值信息并基于各参与方设备的第一中间值信息计算第一全局中间值,以根据所述第一全局中间值计算所述中间变化量。
  4. 根据权利要求3所述的方法,其特征在于,所述第一中间值信息基于所述梯度信息变化量的转置矩阵与所述模型参数变化量的乘积获得。
  5. 根据权利要求3所述的方法,其特征在于,所述任一参与方设备基于所 述中间变化量与其他参与方设备进行预设轮数的交互计算,获得所述梯度搜索方向,还包括:
    任一参与方设备基于本方的所述中间变化量计算本方的第二中间值信息;
    任一参与方设备基于本方的所述第二中间值信息,与其他参与方设备交换第二中间值信息并基于各参与方设备的第二中间值信息计算第二全局中间值,以根据所述第二全局中间值计算所述梯度搜索方向。
  6. 根据权利要求5所述的方法,其特征在于,所述任一参与方设备基于本方的所述中间变化量计算本方的第二中间值信息,包括:
    任一参与方设备基于本方的所述模型参数变化量的转置矩阵、所述模型参数变化量获得第一标量信息,基于本方的所述梯度信息变化量的转置矩阵、所述梯度信息变化量获得第二标量信息;
    任一参与方设备与其他参与方设备进行交互以获得其他参与方设备的第三标量信息和第四标量信息;所述第三标量信息基于其他参与方设备的模型参数变化量的转置矩阵、模型参数变化量获得,所述第四标量信息基于其他参与方设备的梯度信息变化量的转置矩阵、梯度信息变化量获得;
    任一参与方设备基于所述第一标量信息、所述第二标量信息、所述第三标量信息、所述第四标量信息、所述中间变化量计算本方第二中间值信息。
  7. 根据权利要求6所述的方法,其特征在于,所述第一全局中间值为各参与方设备的第一中间值信息之和,所述第二全局中间值为各参与方设备的第二中间值信息之和。
  8. 根据权利要求1所述的方法,其特征在于,所述目标参与方设备获取模型损失函数,并基于所述梯度搜索方向以及模型损失函数计算步长信息,包括:
    目标参与方设备获取样本标签信息,并基于本方模型参数、特征信息以及其他参与方设备的第一数据信息获得样本标签预测信息;其中,所述第一数据信息基于其他参与方设备的模型参数、特征信息获得;
    目标参与方设备基于所述样本标签预测信息及所述样本标签信息计算所述模型损失函数;
    目标参与方设备判断所述模型损失函数是否满足预设条件,若是,则将当前步长信息作为最终的步长信息;否则,减少所述步长信息的值并重新计算所述模型损失函数。
  9. 根据权利要求8所述的方法,其特征在于,所述基于本方模型参数、特 征信息以及其他参与方设备的数据信息获得样本标签预测信息,包括:
    目标参与方设备基于本方模型参数、特征信息计算模型参数的转置矩阵与特征信息的乘积获得第二数据信息;
    目标参与方设备基于所述第二数据信息与其他参与方设备进行交互,获得其他参与方设备的第一数据信息;
    目标参与方设备基于第一数据信息、第二数据信息以及预设模型函数获得所述样本标签预测信息。
  10. 一种电子设备,包括存储器、处理器及存储在存储器上并可在处理器上运行的计算机程序,其特征在于,所述处理器执行所述程序时实现如权利要求1至9任意一项所述的方法。
  11. 一种非暂态计算机可读存储介质,所述非暂态计算机可读存储介质存储计算机指令,其特征在于,所述计算机指令用于使所述计算机执行权利要求1至9任一所述的方法。
PCT/CN2023/078224 2022-03-14 2023-02-24 联邦学习模型训练方法、电子设备及存储介质 WO2023174036A1 (zh)

Applications Claiming Priority (2)

Application Number Priority Date Filing Date Title
CN202210249166.1 2022-03-14
CN202210249166.1A CN114611720B (zh) 2022-03-14 2022-03-14 联邦学习模型训练方法、电子设备及存储介质

Publications (1)

Publication Number Publication Date
WO2023174036A1 true WO2023174036A1 (zh) 2023-09-21

Family

ID=81863537

Family Applications (1)

Application Number Title Priority Date Filing Date
PCT/CN2023/078224 WO2023174036A1 (zh) 2022-03-14 2023-02-24 联邦学习模型训练方法、电子设备及存储介质

Country Status (2)

Country Link
CN (1) CN114611720B (zh)
WO (1) WO2023174036A1 (zh)

Families Citing this family (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN114611720B (zh) * 2022-03-14 2023-08-08 抖音视界有限公司 联邦学习模型训练方法、电子设备及存储介质
CN115618960B (zh) * 2022-09-21 2024-04-19 清华大学 联邦学习优化方法、装置、电子设备及存储介质
CN116017507B (zh) * 2022-12-05 2023-09-19 上海科技大学 基于无线空中计算和二阶优化的去中心化联邦学习方法
CN116187433B (zh) * 2023-04-28 2023-09-29 蓝象智联(杭州)科技有限公司 基于秘密分享的联邦拟牛顿训练方法、装置及存储介质

Citations (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN108830416A (zh) * 2018-06-13 2018-11-16 四川大学 基于用户行为的广告点击率预测框架及算法
CN109635918A (zh) * 2018-10-30 2019-04-16 银河水滴科技(北京)有限公司 基于云平台和预设模型的神经网络自动训练方法和装置
US11254325B2 (en) * 2018-07-14 2022-02-22 Moove.Ai Vehicle-data analytics
CN114611720A (zh) * 2022-03-14 2022-06-10 北京字节跳动网络技术有限公司 联邦学习模型训练方法、电子设备及存储介质

Family Cites Families (8)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN101315670B (zh) * 2007-06-01 2010-08-11 清华大学 特定被摄体检测装置及其学习装置和学习方法
US11562230B2 (en) * 2017-03-22 2023-01-24 Visa International Service Association Privacy-preserving machine learning
CN109886417B (zh) * 2019-03-01 2024-05-03 深圳前海微众银行股份有限公司 基于联邦学习的模型参数训练方法、装置、设备及介质
CN113688855B (zh) * 2020-05-19 2023-07-28 华为技术有限公司 数据处理方法、联邦学习的训练方法及相关装置、设备
CN113723620A (zh) * 2020-05-25 2021-11-30 株式会社日立制作所 无线联邦学习中的终端调度方法和装置
CN111860829A (zh) * 2020-06-19 2020-10-30 光之树(北京)科技有限公司 联邦学习模型的训练方法及装置
CN112288100B (zh) * 2020-12-29 2021-08-03 支付宝(杭州)信息技术有限公司 一种基于联邦学习进行模型参数更新的方法、系统及装置
CN112733967B (zh) * 2021-03-30 2021-06-29 腾讯科技(深圳)有限公司 联邦学习的模型训练方法、装置、设备及存储介质


Also Published As

Publication number Publication date
CN114611720A (zh) 2022-06-10
CN114611720B (zh) 2023-08-08

Similar Documents

Publication Publication Date Title
WO2023174036A1 (zh) 联邦学习模型训练方法、电子设备及存储介质
CN113688855B (zh) 数据处理方法、联邦学习的训练方法及相关装置、设备
Zhu et al. Federated learning on non-IID data: A survey
WO2022089256A1 (zh) 联邦神经网络模型的训练方法、装置、设备、计算机程序产品及计算机可读存储介质
US11734851B2 (en) Face key point detection method and apparatus, storage medium, and electronic device
CN108615073B (zh) 图像处理方法及装置、计算机可读存储介质、电子设备
US20230078061A1 (en) Model training method and apparatus for federated learning, device, and storage medium
EP3968179A1 (en) Place recognition method and apparatus, model training method and apparatus for place recognition, and electronic device
EP3627759A1 (en) Method and apparatus for encrypting data, method and apparatus for training machine learning model, and electronic device
CN112085159B (zh) 一种用户标签数据预测系统、方法、装置及电子设备
US10719693B2 (en) Method and apparatus for outputting information of object relationship
CN108229280A (zh) 时域动作检测方法和系统、电子设备、计算机存储介质
CN110442758B (zh) 一种图对齐方法、装置和存储介质
CN113435365B (zh) 人脸图像迁移方法及装置
CN112395979A (zh) 基于图像的健康状态识别方法、装置、设备及存储介质
CN111091010A (zh) 相似度确定、网络训练、查找方法及装置和存储介质
CN115563650A (zh) 基于联邦学习实现医疗数据的隐私保护系统
CN113191479A (zh) 联合学习的方法、系统、节点及存储介质
CN114676838A (zh) 联合更新模型的方法及装置
US20210326757A1 (en) Federated Learning with Only Positive Labels
CN114547658A (zh) 数据处理方法、装置、设备及计算机可读存储介质
CN113077383B (zh) 一种模型训练方法及模型训练装置
CN114841361A (zh) 一种模型训练方法及其相关设备
CN113609397A (zh) 用于推送信息的方法和装置
CN113961962A (zh) 一种基于隐私保护的模型训练方法、系统及计算机设备

Legal Events

Date Code Title Description
121 Ep: the epo has been informed by wipo that ep was designated in this application

Ref document number: 23769544

Country of ref document: EP

Kind code of ref document: A1

WWE Wipo information: entry into national phase

Ref document number: 18572935

Country of ref document: US