WO2022082742A1 - Model training method and device, server, terminal, and storage medium - Google Patents


Info

Publication number
WO2022082742A1
Authority
WO
WIPO (PCT)
Prior art keywords
training
model
model parameters
terminal
terminals
Application number
PCT/CN2020/123292
Other languages
French (fr)
Chinese (zh)
Inventor
牟勤
洪伟
赵中原
熊可欣
Original Assignee
北京小米移动软件有限公司
北京邮电大学
Application filed by 北京小米移动软件有限公司 and 北京邮电大学
Priority to PCT/CN2020/123292 priority Critical patent/WO2022082742A1/en
Priority to CN202080002976.6A priority patent/CN114667523A/en
Publication of WO2022082742A1 publication Critical patent/WO2022082742A1/en

Classifications

    • G - PHYSICS
    • G06 - COMPUTING; CALCULATING OR COUNTING
    • G06N - COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00 - Computing arrangements based on biological models
    • G06N3/02 - Neural networks
    • G06N3/08 - Learning methods

Definitions

  • the present disclosure relates to the field of communication technologies, and in particular, to a model training method, device, server, terminal and storage medium.
  • Meta-learning is a learning method that uses past knowledge and experience to guide the learning of new tasks. Meta-learning has the ability to learn to learn.
  • Centralized meta-learning is a meta-learning scheme whose steps usually include: the server collects data from each terminal and integrates the data to generate a training data set; the server randomly initializes a set of model parameters as global model parameters; in each round of training, a set of tasks is extracted from the training data set, where a set of tasks includes multiple tasks and each task includes a support set and a query set; the support sets of the extracted tasks are used to perform local updates, obtaining locally updated model parameters; the query sets of the extracted tasks are used to test the locally updated model parameters, obtaining test gradients; the average value of the test gradients over the tasks is determined, and the gradient descent method is used to update the global model; the above process is repeated until the model converges, and the resulting meta-model is distributed to the terminals, where each terminal uses local data to fine-tune the model to obtain an adaptive update model. A minimal runnable sketch of this loop is given below.
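The following is a minimal, runnable sketch of the centralized meta-learning procedure just described, assuming a one-dimensional model with a mean-squared loss so the gradients can be written explicitly; the function names, loss, and hyperparameters are illustrative and not part of the disclosure.

```python
import random
import numpy as np

def local_update(theta, support, alpha=0.1):
    # inner update on a task's support set: one gradient step on L = mean((theta - x)^2)
    return theta - alpha * 2 * np.mean(theta - support)

def test_gradient(theta, query):
    # gradient of the same loss evaluated on the task's query set
    return 2 * np.mean(theta - query)

# each task is a (support set, query set) pair drawn around a task-specific center c
tasks = [(np.random.randn(5) + c, np.random.randn(5) + c) for c in range(4)]
theta = np.random.randn()              # randomly initialized global model parameter
for _ in range(100):                   # "repeat until the model converges" in the scheme
    batch = random.sample(tasks, 2)    # extract a set of tasks for this round
    grads = [test_gradient(local_update(theta, s), q) for s, q in batch]
    theta -= 0.05 * np.mean(grads)     # update the global model with the average test gradient
# theta is now the meta-model distributed to the terminals for local fine-tuning
```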
  • the embodiments of the present disclosure provide a model training method, apparatus, server, terminal and storage medium, which can save transmission bandwidth and save computing resources of the server.
  • the technical solution is as follows:
  • a model training method is provided, comprising: receiving data distribution information of multiple terminals, where the data distribution information includes categories of data and the number of samples included in each category; selecting a training data set that conforms to the data distribution information of the multiple terminals; performing model training based on the training data set to obtain model parameters; sending the model parameters to at least some of the terminals; receiving a training result obtained by the at least some terminals training the model parameters; and updating the model parameters based on the training results to obtain global model parameters.
  • a model training method is provided, comprising: sending data distribution information of a terminal, where the data distribution information includes categories of data and the number of samples included in each category; receiving model parameters, where the model parameters are obtained by training a training data set selected by the server based on the data distribution information; training the model parameters to obtain a training result; and sending the training result, where the training result is used to globally update the model parameters to obtain global model parameters.
  • a model training apparatus comprising:
  • a receiving module configured to receive data distribution information of multiple terminals, where the data distribution information includes data categories and the number of samples included in each category;
  • a selection module configured to select a training data set that conforms to the data distribution information of the multiple terminals
  • a model training module configured to perform model training based on the training data set to obtain model parameters
  • a sending module configured to send the model parameters to at least some of the terminals
  • the receiving module is further configured to receive a training result obtained by training the model parameters by the at least part of the terminals;
  • the model training module is further configured to update the model parameters based on the training results of the at least part of the terminals to obtain global model parameters.
  • a model training apparatus comprising:
  • a sending module configured to send data distribution information of the terminal, where the data distribution information includes data categories and the number of samples included in each category;
  • a receiving module configured to receive model parameters, the model parameters are obtained by training a training data set selected by the server based on the data distribution information;
  • a model training module configured to train the model parameters to obtain a training result
  • the sending module is further configured to send the training result, where the training result is used to globally update the model parameters to obtain global model parameters.
  • a server is provided, comprising: a processor; and a memory for storing instructions executable by the processor; wherein the processor is configured to load and execute the executable instructions to implement the aforementioned model training method.
  • a terminal is provided, comprising: a processor; and a memory for storing instructions executable by the processor; wherein the processor is configured to load and execute the executable instructions to implement the aforementioned model training method.
  • a computer-readable storage medium is provided; when the instructions in the computer-readable storage medium are executed by a processor, the aforementioned model training method can be executed.
  • the server receives the data distribution information sent by the terminals, summarizes the data distribution information of each terminal, selects data with the same distribution to form a training set for preliminary training, and then sends the model parameters to the terminals for distributed training; each terminal then uploads its training result to the server for a global update. In this process, the data distribution information, model parameters, and so on are transmitted between the terminal and the server, but the data of the terminal is not directly transmitted, so the bandwidth occupation is small; moreover, through distributed training on the terminals, the consumption of server computing resources is low.
  • FIG. 1 shows a block diagram of a model training system provided by an exemplary embodiment of the present disclosure
  • FIG. 2 is a flowchart of a model training method according to an exemplary embodiment
  • FIG. 3 is a flowchart of a model training method according to an exemplary embodiment
  • FIG. 4 is a flowchart of a model training method according to an exemplary embodiment
  • FIG. 5 is a flowchart showing a connection establishment process according to an exemplary embodiment
  • FIG. 6 is a flow chart of an initialization model parameter training process according to an exemplary embodiment
  • FIG. 7 is a schematic structural diagram of a model training apparatus according to an exemplary embodiment
  • FIG. 8 is a schematic structural diagram of a model training apparatus according to an exemplary embodiment
  • FIG. 9 is a block diagram of a terminal according to an exemplary embodiment.
  • Fig. 10 is a block diagram of a server according to an exemplary embodiment.
  • FIG. 1 shows a block diagram of a model training system provided by an exemplary embodiment of the present disclosure.
  • the model training system may include: a network side 12 and a terminal 13 .
  • the network side 12 includes a server 120, and the server 120 communicates with the terminal 13 through a wireless channel.
  • the server 120 may belong to a functional unit of a network-side device, and the network-side device may be a base station, which is a device deployed in an access network to provide a wireless communication function for a terminal.
  • the terminal 13 is a terminal accessing a network side device, and the network side device coordinates each terminal to participate in distributed collaborative learning.
  • the base station may include various forms of macro base stations, micro base stations, relay stations, access points, and so on. In systems using different radio access technologies, the names of devices with base station functions may be different; for example, in 5G New Radio (NR) systems they are called gNodeBs or gNBs. With the evolution of communication technology, the name "base station" may change. For convenience of description, hereinafter, the above-mentioned apparatuses for providing wireless communication functions for terminals are collectively referred to as network-side devices.
  • the terminal 13 may include various handheld devices with wireless communication functions, vehicle-mounted devices, wearable devices, computing devices or other processing devices connected to the wireless modem, as well as various forms of user equipment, mobile stations (Mobile Station, MS), terminal and so on.
  • the network-side device 120 and the terminal 13 communicate with each other through a certain air interface technology, such as a Uu interface.
  • the server 120 collects the data of each terminal for centralized meta-learning (or called centralized training).
  • on the one hand, this method requires data transmission and occupies a large bandwidth; on the other hand, all training work is completed by the server, so the training time is long and the server's computing resource consumption is large. Moreover, meta-learning does not pursue an optimal global model, but rather aims to train a well-initialized model that can quickly adapt to new tasks. Therefore, the lengthy model convergence period in the centralized meta-learning solution brings little performance improvement, and the model training efficiency is low.
  • edge terminals accessing the network are often unable to upload their data (small-sample data) to the server for training due to data privacy, security issues, and the like.
  • however, the data of these edge terminals often contains a large amount of information that is important for improving model performance, and centralized meta-learning schemes cannot fully utilize the data of edge terminals.
  • the centralized meta-learning scheme indiscriminately uses server data as training data, and the server data is weakly correlated with the terminal, so the trained model has weak generalization ability for the task of the terminal.
  • model training system and business scenarios described in the embodiments of the present disclosure are for the purpose of illustrating the technical solutions of the embodiments of the present disclosure more clearly, and do not constitute a limitation on the technical solutions provided by the embodiments of the present disclosure.
  • the technical solutions provided by the embodiments of the present disclosure are also applicable to similar technical problems.
  • Fig. 2 is a flowchart of a model training method according to an exemplary embodiment. Referring to Figure 2, the method includes the following steps:
  • step 101 the server receives data distribution information of multiple terminals.
  • the data distribution information includes categories of data and the number of samples included in each category.
  • the data category is used to classify the data in the data set. For example, when training a picture classification model, the pictures in the data set can be divided into multiple categories according to the classification requirements, such as people, plants, and landscapes each constituting a category.
  • the number of samples included in each category refers to the amount of data of each category, for example, the number of pictures whose category is a person.
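As an illustration, the data distribution information for the picture-classification example could be encoded as a simple mapping from category to sample count; the structure and numbers below are assumptions for illustration only.

```python
# hypothetical encoding of one terminal's data distribution information
data_distribution_info = {
    "person": 1200,     # category -> number of samples of that category
    "plant": 450,
    "landscape": 800,
}
```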
  • step 102 the server selects a training data set that conforms to the data distribution information of the multiple terminals.
  • after receiving the data distribution information of the multiple terminals, the server selects, according to the data distribution information, data with the same categories and numbers of samples as those included in the data distribution information to form a training data set.
  • step 103 the server performs model training based on the training data set to obtain model parameters.
  • step 104 the server sends the model parameters to at least some of the terminals.
  • the server can select the terminals that meet the distributed training requirements based on the user scheduling information of the terminals, and let these terminals participate in the distributed training.
  • step 105 the server receives a training result obtained by training the model parameters by the at least some terminals.
  • step 106 the server updates the model parameters based on the training results of the at least part of the terminals to obtain global model parameters.
  • the training results are reported to the server, and the server completes the global update based on the training results of each terminal to obtain the global model parameters.
  • the global model parameters, on the one hand, achieve the training target, and on the other hand, because the training results of at least some of the terminals are integrated, are suitable for the at least some terminals mentioned above.
  • the server receives the data distribution information sent by the terminals, summarizes the data distribution information of each terminal, selects data with the same distribution to form a training set for preliminary training, and then sends the model parameters to the terminals for distributed training; the terminals then upload their training results to the server for a global update. In this process, the data distribution information, model parameters, and so on are transmitted between the terminal and the server, but the data of the terminal is not directly transmitted, so the bandwidth occupation is small; moreover, through distributed training on the terminals, the consumption of server computing resources is small, the training period is short, and the training efficiency is high.
  • because the server selects data distributed the same as the terminals' data to form the training set during preliminary training, the correlation between the model and the terminals is strengthened, and the model has strong generalization ability. In addition, the distributed training process is directly participated in by the terminals, so a terminal does not need to upload its own data, and even edge terminals can participate; the data of edge terminals can thus be used to improve the performance of the learning model, ensuring that the training scheme makes full use of edge-terminal data.
  • the solution for distributed collaborative learning using data distribution characteristics provided by the embodiments of the present disclosure is suitable for training meta-models with strong generalization capabilities, such as model training for tasks such as deep learning and image processing.
  • in some embodiments, receiving the data distribution information of multiple terminals includes: receiving the data distribution information transmitted by each of the multiple terminals through Radio Resource Control (RRC) signaling.
  • when transmitting the data distribution information, the terminal and the server may first establish an RRC connection, and in the process of establishing the RRC connection, transmit the above-mentioned data distribution information through RRC signaling. In this way, the uploading process of the data distribution information can be simplified.
  • in some embodiments, selecting a training data set that conforms to the data distribution information of the multiple terminals includes: combining the data distribution information of the multiple terminals to obtain total data distribution information; and extracting, from the local data of the server, data whose distribution conforms to the total data distribution information to obtain the training data set.
  • the data distribution information of terminal 1 includes: ⁇ type A, sample size a1; type B, sample size b ⁇ ; the data distribution information of terminal 2 includes: ⁇ type A, sample size a2; type C, sample size c ⁇ ; Then, the total data distribution information includes: ⁇ type A, sample size a1+a2; type B, sample size b; type C, sample size c ⁇ .
  • the server selects the data type and sample size according to ⁇ type A, sample size a1+a2; type B, sample size b; type C, sample size c ⁇ to form a training data set.
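A minimal sketch of this combination step, matching the {type, sample size} example above (the terminal dictionaries and numbers are illustrative):

```python
from collections import Counter

def combine_distributions(per_terminal_infos):
    # total data distribution information: sample sizes of shared categories are summed
    total = Counter()
    for info in per_terminal_infos:
        total.update(info)
    return dict(total)

terminal_1 = {"A": 100, "B": 40}   # a1 = 100, b = 40
terminal_2 = {"A": 60, "C": 25}    # a2 = 60, c = 25
print(combine_distributions([terminal_1, terminal_2]))  # {'A': 160, 'B': 40, 'C': 25}
```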
  • in some embodiments, the model parameters include initialization model parameters, and performing model training based on the training data set to obtain the model parameters includes: using the training data set to perform model training to obtain the initialization model parameters.
  • in some embodiments, the model parameters include intermediate model parameters, and performing model training based on the training data set to obtain the model parameters includes: using the training data set to perform model training to obtain initialization model parameters; and iteratively updating the initialization model parameters to obtain the intermediate model parameters.
  • on the one hand, the server performs model training according to the training data set to obtain the initialization model parameters and then sends them to the terminals, which can save the training time of the terminals; on the other hand, the server updates the initialization model parameters based on the training results of the terminals to obtain the intermediate model parameters and then sends the intermediate model parameters to the terminals, so that the terminals can train on the basis of the intermediate model parameters, thereby speeding up the entire model training process.
  • the intermediate model parameters are obtained by the server through iterative updating based on the training results uploaded by the terminals.
  • in some embodiments, sending the model parameters to at least some of the multiple terminals includes: receiving user scheduling information of each of the multiple terminals; determining, based on the user scheduling information, whether each of the multiple terminals meets the distributed training requirement; and sending the model parameters to the terminals that meet the distributed training requirement.
  • the user scheduling information includes at least one of the following parameters:
  • the data volume of the data in the terminal, the similarity between the data distribution and the total data distribution information, the communication status, the computing capability, and the performance requirements of the learning model, where the total data distribution information is obtained by combining the data distribution information of the multiple terminals.
  • the data amount of the data in the terminal may be obtained based on the data distribution information uploaded by the terminal, that is, the sum of the sample sizes of various types of data in the data distribution information.
  • the similarity between the data distribution and the total data distribution information refers to the difference between the categories included in a terminal and the categories in the total data distribution information, for example, the ratio of the number of categories included in the terminal to the number of categories in the total data distribution information, and the ratio of the number of samples of a category in the terminal to the number of samples of the corresponding category in the total data distribution information; the similarity is obtained by combining the above two ratios.
  • the communication status may include Channel Quality Indication (CQI).
  • computing capability can include computing speed and the device's surplus computing power. Computing speed refers to the number of operations performed per second, and surplus computing power refers to the percentage of computing power that can be allocated to model training.
  • the server sets a threshold range for meeting the distributed training requirements; when the user scheduling information of a terminal falls within the threshold range, the terminal meets the distributed training requirements, as sketched below.
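A hedged sketch of such a selection rule, where a terminal participates only when every reported scheduling parameter falls inside the server-side threshold range; the field names and threshold values are assumptions, not taken from the disclosure:

```python
# hypothetical threshold ranges set by the server for the distributed training requirements
thresholds = {"data_volume": (500, float("inf")), "cqi": (7, 15), "spare_compute": (0.2, 1.0)}

def meets_training_requirements(scheduling_info, thresholds):
    # the terminal qualifies only if each parameter lies within its threshold range
    return all(lo <= scheduling_info[k] <= hi for k, (lo, hi) in thresholds.items())

candidates = [
    {"id": 1, "data_volume": 900, "cqi": 10, "spare_compute": 0.5},
    {"id": 2, "data_volume": 200, "cqi": 12, "spare_compute": 0.8},  # too little data
]
selected = [t["id"] for t in candidates if meets_training_requirements(t, thresholds)]
print(selected)  # [1]
```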
  • in some embodiments, sending the model parameters to the terminals that meet the distributed training requirements among the multiple terminals includes: determining data transmission parameters based on the data volume of the model parameters and the communication status of the terminals that meet the distributed training requirements; and sending the model parameters to those terminals according to the data transmission parameters.
  • the data transmission parameters include parameters such as modulation mode and code rate.
  • when the data amount of the model parameters differs and the communication status of the terminal differs, different modulation modes and code rates can be selected for transmission, so that the selected modulation mode and code rate match the amount of data to be transmitted and the communication status of the terminal, thereby achieving a better transmission effect.
  • the data volume of the model parameters is related, on the one hand, to the size of the model: the larger the model, the larger the data volume of the model parameters; on the other hand, it is also related to the precision of each model parameter: the higher the precision, the larger the data volume.
  • the precision of a model parameter may refer to the number of digits retained after the decimal point; the higher the precision and the more digits retained after the decimal point, the larger the amount of data occupied by the model parameter.
  • in some embodiments, the training result includes a gradient value, where the gradient value is obtained by the terminal testing the trained model parameters after training them; or the training result includes a model update parameter, where the model update parameter is a model parameter obtained after the terminal trains the model parameters.
  • the training result of a terminal falls into two cases: one is the gradient value obtained by testing after training is completed, and the other is the model update parameters obtained from model training alone, without testing.
  • the reason for these two cases is that the data volume in different terminals differs. For example, when the data volume in a terminal is large, the data in the terminal can form both a support set and a query set; in this case, the terminal can first use the support set for model training and then use the query set for model testing. When the data volume in a terminal is small, the data in the terminal can only form a support set; in this case, the terminal uses the support set for model training, and the model testing is completed by the server.
  • whether the data amount in a terminal counts as large or small can be determined by comparing it with a threshold: larger than the threshold is large, and smaller than the threshold is small.
  • the threshold may be determined based on the data volume of multiple terminals, for example, may be a quantile of the data volume of multiple terminals. For example, if the data volume of 80% of users reaches 1000, the threshold is set to 1000.
  • the threshold may be determined by the server based on the data distribution information of each terminal, and then notified to each terminal.
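For instance, the quantile-based threshold in the example above can be computed as follows; the data volumes are illustrative, and the 20th percentile realizes "the data volume of 80% of users reaches the threshold":

```python
import numpy as np

data_volumes = np.array([3000, 1500, 1200, 1000, 1000, 800, 2500, 1100, 1300, 900])
threshold = np.percentile(data_volumes, 20)   # 80% of terminals are at or above this value
forms_query_set = data_volumes >= threshold   # below the threshold: support set only
```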
  • in some embodiments, the training result of each terminal in the at least some terminals includes gradient values, and updating the model parameters to obtain the global model parameters includes: iteratively updating the model parameters using a gradient descent method based on the average value of the gradient values of the at least some terminals.
  • in this case, each terminal in the at least some terminals has a large amount of data and can form both a support set and a query set; therefore, each of these terminals reports its gradient value to the server, so that the server can complete the iterative update of the model parameters.
  • in some embodiments, the training result of at least one terminal in the at least some terminals includes model update parameters, and updating the model parameters to obtain the global model parameters includes: selecting a query set that conforms to the data distribution information of the first terminal, where the first terminal is a terminal whose training result includes model update parameters; testing the model update parameters of the first terminal based on the query set to obtain gradient values; and iteratively updating the model parameters using a gradient descent method to obtain the global model parameters.
  • in this case, some terminals have a small amount of data and cannot form both a support set and a query set; these terminals only report model update parameters to the server, the server extracts a query set from its local data for model testing, and the gradient values obtained from the test are then used to complete the iterative update of the model parameters.
  • in some embodiments, using a gradient descent method to iteratively update the model parameters to obtain the global model parameters includes: iteratively updating the model parameters based on the average value of the first gradient values of the at least some terminals; determining whether the average value of the first gradient values is within a threshold range; in response to the average value not being within the threshold range, sending the intermediate model parameters resulting from the iterative update to the at least some terminals; and iteratively updating the intermediate model parameters using the average value of the second gradient values of the at least some terminals, where the second gradient values are gradient values obtained by the terminals testing the intermediate model parameters after training them.
  • in some embodiments, the method further includes: in response to the average value of the first gradient values of the at least some terminals being within the threshold range, sending the global model parameters resulting from the iterative update of the model parameters to the at least some terminals, where the global model parameters are used by the terminals for adaptive updating.
  • the updating of the model parameters is usually a multi-round distributed training process: the at least some terminals performing model training once and reporting their respective training results constitutes one round of distributed training, and the server performs a global update based on those training results. The server can determine whether the requirements of distributed training are met based on the average value of the gradient values corresponding to the training results of these terminals. If the requirements are met, the globally updated model requires no further distributed training and can be used by users after adaptive training; if the requirements are not met, the intermediate model parameters are sent to the terminals as the basis for the next round of terminal training, and the next round is performed on the basis of the intermediate model parameters.
  • the server monitors the effect of the distributed model training and stops learning when the model accuracy meets the requirements, without requiring training until the model converges.
  • This training method greatly improves the training efficiency.
  • the global model parameters will subsequently be adaptively updated by each terminal, so that each terminal obtains a more personalized model, which ensures that the model used by the terminal better matches the terminal's task requirements and improves model performance.
  • the adaptive updating of the terminal may refer to the terminal using its local data to update based on the global model parameters, so that the model parameters meet the requirements of the terminal.
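Putting the round structure together, the following is a control-flow sketch of the server side of this multi-round process, assuming a quadratic toy objective so each terminal's reported gradient is simply a noisy copy of the current parameters; the stand-in gradient collection, learning rate, and threshold g0 are all illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)

def collect_gradients(theta, terminals):
    # stand-in for one round of distributed training: each terminal trains locally
    # and reports a test gradient (here, a noisy gradient of L = ||theta||^2 / 2)
    return [theta + 0.001 * rng.standard_normal(theta.shape) for _ in terminals]

theta = rng.standard_normal(4)        # current global/intermediate model parameters
beta, g0 = 0.5, 0.01                  # global learning rate and gradient threshold
for round_ in range(200):             # multi-round distributed training
    grads = collect_gradients(theta, range(5))
    avg_grad = np.mean(grads, axis=0)
    theta -= beta * avg_grad          # server-side global update (gradient descent)
    if np.linalg.norm(avg_grad) < g0: # requirements met: distribute the global parameters
        break                         # otherwise, the intermediate parameters go out again
```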
  • Fig. 3 is a flow chart of a model training method according to an exemplary embodiment. Referring to Figure 3, the method includes the following steps:
  • step 201 the terminal sends data distribution information of the terminal, where the data distribution information includes data types and the number of samples included in each type.
  • the terminal counts the number of local samples of each data category, generates data distribution information, and sends it to the server.
  • step 202 the terminal receives model parameters, where the model parameters are obtained by training a training data set selected by the server based on the data distribution information.
  • the model parameters here can be either initial model parameters or intermediate model parameters.
  • step 203 the terminal trains the model parameters to obtain a training result.
  • step 204 the terminal sends the training result, and the training result is used to globally update the model parameters to obtain global model parameters.
  • the terminal sends its own data distribution information to the server; the server summarizes the data distribution information of each terminal, selects data with the same distribution to form a training set for preliminary training, and then sends the model parameters to the terminals for distributed training; the terminal then uploads its training result to the server for a global update. In this process, the data distribution information, model parameters, and so on are transmitted between the terminal and the server, but the data of the terminal is not directly transmitted, so the bandwidth consumption is small; moreover, through distributed training on the terminals, the server's computing resource consumption is small.
  • in some embodiments, sending the data distribution information of the terminal includes: sending the data distribution information through RRC signaling.
  • in some embodiments, the model parameters include initialization model parameters, and receiving the model parameters includes: receiving the initialization model parameters, where the initialization model parameters are obtained by the server training with the training data set selected based on the data distribution information.
  • in some embodiments, the model parameters include intermediate model parameters, and receiving the model parameters includes: receiving the intermediate model parameters, where the intermediate model parameters are obtained by the server iteratively updating the initialization model parameters.
  • in some embodiments, the training result includes a gradient value, where the gradient value is obtained by testing the trained model parameters after the model parameters are trained; or the training result includes a model update parameter, where the model update parameter is a model parameter obtained after training the model parameters.
  • in some embodiments, the training result includes model update parameters, and sending the training result includes: determining data transmission parameters based on the data volume of the model update parameters and the communication status of the terminal; and sending the model update parameters to the server according to the data transmission parameters.
  • in some embodiments, the method further includes: sending user scheduling information, where the user scheduling information includes at least one of the following parameters: the data volume of the data in the terminal, the similarity between the data distribution and the total data distribution information, the communication status, the computing capability, and the performance requirements of the learning model, where the total data distribution information is obtained by combining the data distribution information of the multiple terminals.
  • in some embodiments, the method further includes: receiving the global model parameters; and adaptively updating the global model parameters.
  • Fig. 4 is a flow chart of a model training method according to an exemplary embodiment. Referring to Figure 4, the method includes the following steps:
  • step 301 the server and the terminal establish an RRC connection.
  • the process of establishing an RRC connection between the server and the terminal may refer to FIG. 5, and the steps are as follows:
  • Step 3011 The terminal sends RRC connection establishment request signaling to the server, where the request signaling is used to request establishment of an RRC connection with the server.
  • Correspondingly, the server receives the RRC connection establishment request signaling.
  • Step 3012 The server sends RRC connection establishment signaling to the terminal, where the RRC connection establishment signaling is used to notify the terminal that the server agrees to establish the RRC connection.
  • Correspondingly, the terminal receives the RRC connection establishment signaling.
  • Step 3013 The terminal sends RRC connection establishment complete signaling to the server, where the signaling is used to notify the server that the RRC connection establishment is complete.
  • Correspondingly, the server receives the RRC connection establishment complete signaling.
  • the signaling transmission and reception in the above-mentioned RRC connection establishment process is performed by the network communication module of the terminal and the network communication module of the server.
  • the network communication modules of the terminal and the server can be composed of two parts: a sending module and a receiving module.
  • step 302 the terminal sends the data distribution information to the server; the server receives the data distribution information sent by the terminal.
  • the data distribution information includes data categories and the number of samples included in each category.
  • step 301 and step 302 may be performed in any order; for example, the data distribution information may be transmitted during the process of establishing the RRC connection between the server and the terminal, that is, the server receives the data distribution information transmitted by the terminal through RRC signaling.
  • for example, the server receives the data distribution information carried by the terminal in the RRC connection establishment complete signaling.
  • step 303 the server combines the data distribution information of the multiple terminals to obtain total data distribution information.
  • the data distribution information of terminal 1 includes: ⁇ type A, sample size a1; type B, sample size b ⁇ ; the data distribution information of terminal 2 includes: ⁇ type A, sample size a2; type C, sample size c ⁇ ; Then, the total data distribution information includes: ⁇ type A, sample size a1+a2; type B, sample size b; type C, sample size c ⁇ .
  • step 304 the server extracts data whose distribution conforms to the total data distribution information from the local data of the server to obtain the training data set.
  • the server selects the data type and sample size according to ⁇ type A, sample size a1+a2; type B, sample size b; type C, sample size c ⁇ obtained by combining in step 303 to form a training data set.
  • step 305 the server uses the training data set to perform model training to obtain the initialization model parameters.
  • Step 3051 The server randomly initializes a set of model parameters.
  • Step 3052 The server extracts a batch of tasks from the training data set, and each task includes a support set and a query set.
  • the total data distribution information is denoted as $P$, and the server's local data is denoted as $D_s$. Data whose distribution conforms to the total data distribution information $P$ is extracted from the local data $D_s$ to generate a training data set, denoted as $\tilde{D}_s$. The server extracts data from the training data set $\tilde{D}_s$ to generate several tasks $T_i$, each containing a support set and a query set, denoted as $S_{T_i}$ and $Q_{T_i}$ respectively.
  • Step 3053 The server uses the support set of each task for training, and calculates the model loss and gradient to obtain the updated model parameters on each task.
  • the server can use the gradient descent method to obtain the updated model parameters, which can be expressed as the following formula (1): $\theta'_i = \theta - \alpha \nabla_{\theta} L_{T_i}(f_{\theta}; S_{T_i})$, where $\theta'_i$ represents the updated model parameters on the $i$-th task, $\theta$ represents the set of initialized model parameters, $\alpha$ represents the learning rate of a single task, $L$ represents the loss function of the model on the support set, $f$ represents the model, and $T_i$ represents the $i$-th task.
  • Step 3054 The server uses the query set of each task to calculate the test loss and gradient for updating the model parameters.
  • Step 3055 The server summarizes the gradients on each task, updates the randomly initialized model parameters, and obtains the initialized model parameters.
  • the server computes, using the query set of each task, the test loss and gradient for updating the model parameters, and sums and averages the gradients over the tasks.
  • the global model parameters are updated with the average gradient value, which can be expressed as the following formula (2): $\theta \leftarrow \theta - \frac{\beta}{N} \nabla_{\theta} \sum_{T_i \sim p(T)} L_{T_i}(f_{\theta'_i}; Q_{T_i})$, where $\beta$ represents the global learning rate, $N$ represents the number of tasks used in this round of training, and $p(T)$ represents the set of tasks used in this round of training.
  • each step can be performed by the model training module of the server; in step 3052, the training data set of the server can be stored in the data processing and storage module of the server, and the model training module can perform signaling interaction with the data processing and storage module to extract a batch of tasks.
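To make formulas (1) and (2) concrete, the following sketch implements the inner and outer updates for a least-squares linear model, using the first-order approximation in which the query gradient is taken at the task-updated parameters; the model choice and the first-order treatment are assumptions, since the disclosure does not fix either.

```python
import numpy as np

def inner_update(theta, S_X, S_y, alpha):
    # formula (1): theta'_i = theta - alpha * grad of the support-set loss at theta
    grad = 2 * S_X.T @ (S_X @ theta - S_y) / len(S_y)
    return theta - alpha * grad

def meta_update(theta, tasks, alpha, beta):
    # formula (2): theta <- theta - (beta / N) * sum of query-set gradients at theta'_i
    grads = []
    for S_X, S_y, Q_X, Q_y in tasks:
        theta_i = inner_update(theta, S_X, S_y, alpha)
        grads.append(2 * Q_X.T @ (Q_X @ theta_i - Q_y) / len(Q_y))
    return theta - beta * np.mean(grads, axis=0)

# toy usage: four related linear tasks sharing a base weight vector
rng = np.random.default_rng(1)
X, w = rng.standard_normal((20, 3)), rng.standard_normal(3)
tasks = [(X, X @ (w + 0.1 * k), X, X @ (w + 0.1 * k)) for k in range(4)]
theta = np.zeros(3)
for _ in range(50):
    theta = meta_update(theta, tasks, alpha=0.05, beta=0.1)
```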
  • step 306 the terminal sends the user scheduling information to the server; the server receives the user scheduling information sent by the terminal.
  • the user scheduling information includes at least one of the following parameters: data volume of data in the terminal, similarity between data distribution and total data distribution information, communication status, computing power, and learning model performance requirements, and the total data distribution information is: It is obtained by combining the data distribution information of the multiple terminals.
  • step 306 and step 302 may be performed simultaneously, that is, the terminal sends the user scheduling information to the server when transmitting the data distribution information, that is, the user scheduling information may also be transmitted through RRC signaling.
  • in some embodiments, the user scheduling information may only include the communication status, the computing capability, and the performance requirements of the learning model; the data volume of the data in the terminal and the similarity between the data distribution and the total data distribution information may be determined by the server based on the data distribution information.
  • each parameter in the user scheduling information may be sent to the server by the terminal together, or may be sent to the server in sequence.
  • the communication condition usually includes the CQI, and the CQI needs to be obtained by the terminal through measurement. Therefore, the method may further include: before step 306, the terminal performs CQI measurement.
  • the user scheduling information is acquired by the user management module in the terminal, and sent to the network communication module of the server through the network communication module of the terminal, and the network communication module of the server transmits it to the user management module of the server.
  • when the network communication module and the user management module in the above-mentioned terminal or server transmit the user scheduling information, a new signaling may be used, the function of which is to carry the user scheduling information.
  • step 307 the server determines whether each of the multiple terminals meets the distributed training requirement based on the user scheduling information of each of the multiple terminals.
  • the server sets a threshold range for meeting the distributed training requirements; when the user scheduling information of a terminal falls within the threshold range, the terminal meets the distributed training requirements.
  • terminals among the multiple terminals other than those selected above as meeting the distributed training requirements do not participate in this round of training.
  • step 308 the server sends the initial model parameters to the terminal that meets the distributed training requirement among the multiple terminals.
  • the terminal receives initial model parameters.
  • if the terminal in steps 301-307 belongs to the terminals that meet the distributed training requirements, the terminal participates in steps 308-314; if it does not, the terminal does not participate in steps 308-314.
  • This embodiment is described by taking as an example that the terminal in step 301 to step 307 belongs to a terminal that meets the distributed training requirement.
  • when transmitting the initialization model parameters, the server first determines the data transmission parameters based on the data volume of the initialization model parameters and the communication status of the terminal, and then sends the initialization model parameters to the terminal according to the data transmission parameters.
  • determining the data transmission parameters may be performed by a transmission control module in the server. After the transmission control module determines the data transmission parameters, it may control the network communication module to send the initialization model parameters according to the above data transmission parameters.
  • the data transmission parameters include parameters such as modulation mode and code rate.
  • when the data amount of the model parameters differs and the communication status of the terminal differs, different modulation modes and code rates can be selected for transmission, so that the selected modulation mode and code rate match the amount of data to be transmitted and the communication status of the terminal, thereby achieving a better transmission effect.
  • the server encapsulates the initialization model parameters according to the above data transmission scheme.
  • the server sends the packaged data packet of initializing model parameters to the terminal.
  • the terminal decapsulates the data packet after receiving it.
  • the terminal confirms the correctness of the received data packet based on the decapsulated data.
  • the terminal feeds back a message to the server, informing the server that the terminal has correctly received the initialization model parameters.
  • verifying the correctness of the data packet and generating the feedback message is performed by the transmission control module in the terminal, and the receiving and sending processes are performed by the network communication module.
  • the data volume of the model parameters is related, on the one hand, to the size of the model: the larger the model, the larger the data volume of the model parameters; on the other hand, it is also related to the precision of each model parameter: the higher the precision, the larger the data volume.
  • the precision of a model parameter may refer to the number of digits retained after the decimal point; the higher the precision and the more digits retained after the decimal point, the larger the amount of data occupied by the model parameter.
  • step 309 the terminal trains the initial model parameters to obtain a training result.
  • the training result of the terminal falls into two cases: one is the gradient value obtained by testing after training is completed, and the other is the model update parameters obtained from model training alone, without testing.
  • the reason for these two cases is that the data volume in different terminals differs. For example, when the data volume in a terminal is large, the data in the terminal can form both a support set and a query set; in this case, the terminal can first use the support set for model training and then use the query set for model testing. When the data volume in a terminal is small, the data in the terminal can only form a support set; in this case, the terminal uses the support set for model training, and the model testing is completed by the server.
  • whether the data amount in a terminal counts as large or small can be determined by comparing it with a threshold: larger than the threshold is large, and smaller than the threshold is small.
  • the threshold may be determined based on the data volume of multiple terminals, for example, may be a quantile of the data volume of multiple terminals. For example, if the data volume of 80% of users reaches 1000, the threshold is set to 1000.
  • the threshold may be determined by the server based on the data distribution information of each terminal, and then notified to each terminal. The terminal can determine whether to generate a query set based on the threshold and its own data volume.
  • the terminal uses the support set to update the initialization model parameters by gradient descent to obtain the model update parameters, which can be expressed as the following formula (3): $\theta_{u_i} = \theta - \alpha \nabla_{\theta} L(f_{\theta}; S_{u_i})$, where $\theta_{u_i}$ represents the model update parameters of the $i$-th terminal, $\theta$ the received initialization model parameters, and $S_{u_i}$ the support set in the $i$-th terminal.
  • the terminal uses the query set to test the model update parameters and calculates the test loss and gradient value, which can be expressed as the following formula (4): $g_{u_i} = \nabla_{\theta_{u_i}} L(f_{\theta_{u_i}}; Q_{u_i})$, where $g_{u_i}$ represents the test gradient of the model update parameters of the $i$-th terminal and $Q_{u_i}$ represents the query set in the training set of the $i$-th terminal.
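A sketch of this terminal-side step under the same toy linear model as before (an assumption): formula (3) updates the received parameters on the support set, and formula (4) computes the test gradient on the query set when one exists, reproducing the two reporting cases described earlier.

```python
import numpy as np

def terminal_train(theta, S_X, S_y, Q_X=None, Q_y=None, alpha=0.01):
    # formula (3): update the received parameters on the support set
    theta_ui = theta - alpha * 2 * S_X.T @ (S_X @ theta - S_y) / len(S_y)
    if Q_X is None:
        return theta_ui, None   # small data volume: report the model update parameters
    # formula (4): test gradient g_ui on the query set
    g_ui = 2 * Q_X.T @ (Q_X @ theta_ui - Q_y) / len(Q_y)
    return theta_ui, g_ui       # large data volume: report the test gradient
```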
  • step 310 the terminal sends the training result to the server; the server receives the training result sent by the terminal.
  • the manner in which the terminal sends the training result can be similar to the manner in which the server sends the initialization model parameters in step 308: the data transmission parameters are first determined, and the training result is then sent according to the data transmission parameters.
  • step 311 the server updates the model parameters based on the training results of at least some of the terminals.
  • when the updated model parameters do not meet the requirements, step 312 is performed; when the updated model parameters meet the requirements, step 313 is performed.
  • At least some of the terminals here refer to the terminals that participate in the training and meet the distributed training requirements.
  • the server may obtain the average value of the gradient values of at least some of the terminals based on their training results. If the average value of the gradient values of the at least some terminals is within the threshold range (e.g., less than the set value), the updated model parameters meet the requirements; otherwise, the updated model parameters do not meet the requirements.
  • step 311 may include:
  • the server uses a gradient descent method to iteratively update the model parameters based on the average value of the gradient values of the at least part of the terminals.
  • step 311 may include:
  • the server selects a query set that conforms to the data distribution information of the first terminal, where the first terminal is a terminal whose training result includes model update parameters;
  • the server tests the model update parameters of the first terminal based on the query set to obtain a gradient value
  • the server uses a gradient descent method to iteratively update the model parameters based on the average value of the gradient values of the at least part of the terminals.
  • the server determines whether a query set needs to be generated for the terminal according to the data volume of each terminal.
  • the server uses the average value of the gradient values of the at least some terminals to update the model parameters by gradient descent, which can be expressed as the following formula (5): $\theta \leftarrow \theta - \beta \cdot \frac{1}{M}\sum_{i=1}^{M} g_{u_i}$, where $M$ represents the number of terminals that meet the distributed training requirements, that is, the number of terminals participating in the distributed training; the average gradient is then compared with the aforementioned threshold value (set value), denoted $g_0$, to decide whether another round is needed.
  • step 311 can be executed by the model update module in the server; during execution, the module needs to interact with the data processing and storage module in the server to obtain the data for generating a query set for the terminal, and a new signaling can be used in the interaction process.
  • step 312 the server sends the intermediate model parameters whose model parameters are iteratively updated to the at least part of the terminals; the terminal receives the intermediate model parameters sent by the server.
  • the terminal After receiving the intermediate model parameters sent by the server, the terminal trains the intermediate model parameters to obtain a training result, and then repeats steps 310 and 311 to iteratively update.
  • step 313 the server sends the global model parameters whose model parameters are iteratively updated to the at least part of the terminals; the terminals receive the global model parameters sent by the server.
  • step 314 the terminal adaptively updates the global model parameters.
  • the terminal uses the support set to test the global model parameters, calculates the test loss and gradient, and performs a gradient descent update to obtain the adaptive model, which can be expressed as the following formula (7): $\theta^{*}_{u_i} = \theta^{*} - \alpha \nabla_{\theta^{*}} L(f_{\theta^{*}}; S_{u_i})$, where $\theta^{*}_{u_i}$ is the adaptive update model of the $i$-th terminal, $\theta^{*}$ is the global model parameters, and $S_{u_i}$ is the local data of the $i$-th terminal used for the adaptive update.
  • the aforementioned steps 309 and 314 may be performed by the model updating module in the terminal, which needs to interact with the data processing and storage module in the terminal during the execution of the above steps to obtain data to generate a support set, a query set, and the like.
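As a final illustration, the adaptive update of step 314 can be sketched as a few local gradient steps from the received global parameters, again under the toy linear model used above (an assumption):

```python
import numpy as np

def adaptive_update(theta_global, X_local, y_local, alpha=0.01, steps=5):
    # formula (7): fine-tune the global parameters on the terminal's local data
    theta = theta_global.copy()
    for _ in range(steps):
        grad = 2 * X_local.T @ (X_local @ theta - y_local) / len(y_local)
        theta -= alpha * grad
    return theta   # the adaptive (personalized) model of this terminal
```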
  • Fig. 7 is a schematic structural diagram of a model training apparatus according to an exemplary embodiment.
  • the apparatus has the function of implementing the server in the above method embodiment, and the function may be implemented by hardware, or by executing corresponding software in hardware.
  • the apparatus includes: a receiving module 501 , a selecting module 502 , a model training module 503 and a sending module 504 .
  • the receiving module 501 is configured to receive data distribution information of multiple terminals, where the data distribution information includes data categories and the number of samples included in each category;
  • a selection module 502 is configured to select a training data set that conforms to the data distribution information of the multiple terminals
  • the model training module 503 is configured to perform model training based on the training data set to obtain model parameters
  • a sending module 504 configured to send the model parameters to at least some of the multiple terminals
  • the receiving module 501 is further configured to receive a training result obtained by training the model parameters by the at least part of the terminals;
  • the model training module 503 is further configured to update the model parameters based on the training results of the at least part of the terminals to obtain global model parameters.
  • the receiving module 501 is configured to receive the data distribution information transmitted by each of the multiple terminals through RRC signaling.
  • the selection module 502 is configured to combine the data distribution information of the multiple terminals to obtain total data distribution information; extract data whose distribution conforms to the total data distribution information from the local data of the server. , to obtain the training data set.
  • in some embodiments, the model parameters include initialization model parameters, and the model training module 503 is configured to use the training data set to perform model training to obtain the initialization model parameters; or the model parameters include intermediate model parameters, and the model training module 503 is configured to perform model training by using the training data set to obtain initialization model parameters, and to iteratively update the initialization model parameters to obtain the intermediate model parameters.
  • the receiving module 501 is further configured to receive user scheduling information of each terminal in the multiple terminals;
  • the apparatus further includes: a determination module 505, configured to determine whether each of the multiple terminals meets the distributed training requirement based on user scheduling information of each of the multiple terminals;
  • the sending module 504 is configured to send the model parameters to a terminal that meets the distributed training requirement among the multiple terminals.
  • the user scheduling information includes at least one of the following parameters:
  • the data volume of the data in the terminal, the similarity between the data distribution and the total data distribution information, the communication status, the computing capability, and the performance requirements of the learning model, where the total data distribution information is obtained by combining the data distribution information of the multiple terminals.
  • the determining module 505 is further configured to determine data transmission parameters based on the data volume of the model parameters and the communication status of the terminals that meet the distributed training requirements;
  • the sending module 504 is configured to send the model parameter to the terminal that meets the distributed training requirement according to the data transmission parameter.
  • the training result includes a gradient value
  • the gradient value is a gradient value obtained by testing the trained model parameters after the terminal trains the model parameters
  • the training result includes a model update parameter
  • the model update parameter is a model parameter obtained after the terminal trains the model parameter.
  • the training result of each terminal in the at least part of the terminals includes gradient values
  • the model training module 503 is configured to use a gradient descent method to iteratively update the model parameters based on the average value of the gradient values of the at least part of the terminals to obtain global model parameters.
  • the training result of at least one terminal in the at least part of the terminals includes model update parameters
  • the selecting module 502 is configured to select a query set that conforms to the data distribution information of the first terminal, where the first terminal is a terminal whose training result includes model update parameters;
  • the model training module 503 is configured to test the model update parameters of the first terminal based on the query set to obtain gradient values, and to iteratively update the model parameters using a gradient descent method to obtain global model parameters.
  • in some embodiments, the model training module 503 is configured to: iteratively update the model parameters using a gradient descent method based on the average value of the first gradient values of the at least some terminals; determine whether the average value of the first gradient values of the at least some terminals is within the threshold range; in response to the average value of the first gradient values not being within the threshold range, send the intermediate model parameters resulting from the iterative update of the model parameters to the at least some terminals; and iteratively update the intermediate model parameters using the average value of the second gradient values of the at least some terminals, where the second gradient values are gradient values obtained by the terminals testing the intermediate model parameters after training them.
  • the sending module 504 is further configured to, in response to the average value of the first gradient values of the at least part of the terminals being within a threshold range, send the global model parameters after the iterative update of the model parameters to all the terminals. at least some of the terminals, the global model parameters are used for adaptive updating of the terminals.
  • Fig. 8 is a schematic structural diagram of a model training apparatus according to an exemplary embodiment.
  • the apparatus has the function of realizing the terminal in the above method embodiment, and the function may be realized by hardware, or by executing corresponding software in hardware.
  • the apparatus includes: a sending module 601, a receiving module 602 and a model training module 603.
  • the sending module 601 is configured to send data distribution information of the terminal, where the data distribution information includes data categories and the number of samples included in each category;
  • the receiving module 602 is configured to receive model parameters, where the model parameters are obtained by training a training data set selected by the server based on the data distribution information;
  • a model training module 603, configured to train the model parameters to obtain a training result
  • the sending module 601 is further configured to send the training result, where the training result is used to globally update the model parameters to obtain global model parameters.
  • the sending module 601 is configured to send the data distribution information through RRC signaling.
  • the model parameters include initialization model parameters
  • the receiving module 602 is configured to receive the initialization model parameters, where the initialization model parameters are obtained by the server through training with the training data set selected based on the data distribution information;
  • the model parameters include intermediate model parameters
  • the receiving module 602 is configured to receive the intermediate model parameters, where the intermediate model parameters are obtained by iteratively updating the initialization model parameters by the server.
  • the training result includes a gradient value
  • the gradient value is a gradient value obtained by testing the trained model parameters after the model parameters are trained
  • the training result includes a model update parameter
  • the model update parameter is a model parameter obtained after training the model parameter.
  • when the training result includes model update parameters, the apparatus further includes: a determining module 604, configured to determine data transmission parameters based on the data volume of the model update parameters and the communication status of the terminal;
  • the sending module 601 is configured to send the model update parameter to the server according to the data transmission parameter.
  • the sending module 601 is further configured to send user scheduling information, where the user scheduling information includes at least one of the following parameters: the data volume of the data in the terminal, the similarity between the data distribution and the total data distribution information, the communication status, the computing capability, and the learning model performance requirements, where the total data distribution information is obtained by combining the data distribution information of the multiple terminals.
  • the receiving module 602 is further configured to receive global model parameters
  • the model training module 603 is further configured to adaptively update the global model parameters.
  • FIG. 9 is a block diagram of a terminal 700 according to an exemplary embodiment.
  • the terminal 700 may include: a processor 701 , a receiver 702 , a transmitter 703 , a memory 704 and a bus 705 .
  • the processor 701 includes one or more processing cores, and the processor 701 executes various functional applications and information processing by running software programs and modules.
  • the receiver 702 and the transmitter 703 may be implemented as a communication component, which may be a communication chip.
  • Memory 704 is connected to processor 701 via bus 705 .
  • the memory 704 may be configured to store at least one instruction, and the processor 701 may be configured to execute the at least one instruction, so as to implement various steps in the foregoing method embodiments.
  • memory 704 may be implemented by any type of volatile or non-volatile storage device, or a combination thereof, including but not limited to a magnetic or optical disk, Electrically Erasable Programmable Read-Only Memory (EEPROM), Erasable Programmable Read-Only Memory (EPROM), Static Random Access Memory (SRAM), Read-Only Memory (ROM), magnetic memory, flash memory, and Programmable Read-Only Memory (PROM).
  • a computer-readable storage medium stores at least one instruction, at least one program, a code set or an instruction set; the at least one instruction, the at least one program, the code set or the instruction set is loaded and executed by the processor to implement the model training method provided by each of the above method embodiments.
  • FIG. 10 is a block diagram of a server 800 according to an exemplary embodiment.
  • the server 800 may include: a processor 801 , a receiver 802 , a transmitter 803 and a memory 804 .
  • the receiver 802, the transmitter 803 and the memory 804 are respectively connected to the processor 801 through a bus.
  • the processor 801 includes one or more processing cores, and the processor 801 executes the method executed by the server in the model training method provided by the embodiment of the present disclosure by running software programs and modules.
  • Memory 804 may be used to store software programs and modules. Specifically, the memory 804 can store the operating system 8041 and an application module 8042 required for at least one function.
  • the receiver 802 is used for receiving communication data sent by other devices, and the transmitter 803 is used for sending communication data to other devices.
  • a computer-readable storage medium stores at least one instruction, at least one program, a code set or an instruction set; the at least one instruction, the at least one program, the code set or the instruction set is loaded and executed by the processor to implement the model training method provided by each of the above method embodiments.
  • An exemplary embodiment of the present disclosure also provides a model training system, where the model training system includes a terminal and a server.
  • the terminal is the terminal provided by the embodiment shown in FIG. 9 .
  • the server is the server provided by the embodiment shown in FIG. 10 .


Abstract

The present disclosure relates to a model training method and device, a server, a terminal, and a storage medium, which belong to the technical field of communications. The method comprises: receiving data distribution information of multiple terminals, the data distribution information comprising data categories and the number of samples included in each category; selecting a training data set corresponding to the data distribution information of the multiple terminals; performing model training on the basis of the training data set, so as to obtain model parameters; sending the model parameters to at least part of the multiple terminals; receiving a training result obtained after the at least part of the multiple terminals have trained the model parameters; and updating the model parameters on the basis of the training result of the at least part of the multiple terminals, so as to obtain global model parameters.

Description

Model training method, device, server, terminal and storage medium

Technical Field
The present disclosure relates to the field of communication technologies, and in particular, to a model training method, device, server, terminal and storage medium.
Background
Meta-learning (Meta Learning) is a learning method that uses past knowledge and experience to guide the learning of new tasks; it has the ability to learn to learn (Learning to learn).

Centralized meta-learning is one meta-learning scheme. Its steps usually include: the server collects data from each terminal and integrates the data to generate a training data set; the server randomly initializes a set of model parameters as the global model parameters; each training round extracts a group of tasks from the training data set, where a group of tasks includes multiple tasks and each task includes a support set and a query set; the support sets of the extracted tasks are used for local updates to obtain locally updated model parameters; the query sets of the extracted tasks are used to test the locally updated model parameters to obtain test gradients; the average of the test gradients over the tasks is determined, and the global model is updated by gradient descent; the above process is repeated until the model converges, and the resulting meta-model is distributed to each terminal; each terminal fine-tunes the model with local data to obtain an adaptively updated model.
Summary of the Invention
The embodiments of the present disclosure provide a model training method, apparatus, server, terminal and storage medium, which can save transmission bandwidth and save computing resources of the server. The technical solutions are as follows:
According to an aspect of the embodiments of the present disclosure, a model training method is provided, the method including:

receiving data distribution information of multiple terminals, the data distribution information including the categories of data and the number of samples contained in each category;

selecting a training data set that conforms to the data distribution information of the multiple terminals;

performing model training based on the training data set to obtain model parameters;

sending the model parameters to at least some of the multiple terminals;

receiving training results obtained by the at least some terminals training the model parameters;

updating the model parameters based on the training results of the at least some terminals to obtain global model parameters.

According to another aspect of the embodiments of the present disclosure, a model training method is provided, the method including:

sending data distribution information of a terminal, the data distribution information including the categories of data and the number of samples contained in each category;

receiving model parameters, the model parameters being obtained by training on a training data set selected by a server based on the data distribution information;

training the model parameters to obtain a training result;

sending the training result, the training result being used to globally update the model parameters to obtain global model parameters.
According to another aspect of the embodiments of the present disclosure, a model training apparatus is provided, the apparatus including:

a receiving module, configured to receive data distribution information of multiple terminals, the data distribution information including the categories of data and the number of samples contained in each category;

a selection module, configured to select a training data set that conforms to the data distribution information of the multiple terminals;

a model training module, configured to perform model training based on the training data set to obtain model parameters;

a sending module, configured to send the model parameters to at least some of the multiple terminals;

the receiving module being further configured to receive training results obtained by the at least some terminals training the model parameters;

the model training module being further configured to update the model parameters based on the training results of the at least some terminals to obtain global model parameters.

According to another aspect of the embodiments of the present disclosure, a model training apparatus is provided, the apparatus including:

a sending module, configured to send data distribution information of a terminal, the data distribution information including the categories of data and the number of samples contained in each category;

a receiving module, configured to receive model parameters, the model parameters being obtained by training on a training data set selected by a server based on the data distribution information;

a model training module, configured to train the model parameters to obtain a training result;

the sending module being further configured to send the training result, the training result being used to globally update the model parameters to obtain global model parameters.
According to another aspect of the embodiments of the present disclosure, a server is provided, the server including: a processor; and a memory for storing instructions executable by the processor; wherein the processor is configured to load and execute the executable instructions to implement the foregoing model training method.

According to another aspect of the embodiments of the present disclosure, a terminal is provided, the terminal including: a processor; and a memory for storing instructions executable by the processor; wherein the processor is configured to load and execute the executable instructions to implement the foregoing model training method.

According to another aspect of the embodiments of the present disclosure, a computer-readable storage medium is provided; when the instructions in the computer-readable storage medium are executed by a processor, the foregoing model training method can be performed.
In the embodiments of the present disclosure, the server receives the data distribution information sent by the terminals, aggregates the data distribution information of each terminal, selects data with the same distribution to form a training set for preliminary training, and then sends the model parameters to the terminals for distributed training; the terminals then upload their training results to the server for a global update. In this process, what is transmitted between the terminals and the server is data distribution information, model parameters and the like, rather than the terminals' data itself, so the bandwidth occupation is small; moreover, through distributed training on the terminals, the server's computing resource consumption is small.
It should be understood that the above general description and the following detailed description are exemplary and explanatory only and do not limit the present disclosure.
Brief Description of the Drawings
The accompanying drawings, which are incorporated in and constitute a part of this specification, illustrate embodiments consistent with the present disclosure and, together with the description, serve to explain the principles of the present disclosure.
FIG. 1 is a block diagram of a model training system provided by an exemplary embodiment of the present disclosure;

FIG. 2 is a flowchart of a model training method according to an exemplary embodiment;

FIG. 3 is a flowchart of a model training method according to an exemplary embodiment;

FIG. 4 is a flowchart of a model training method according to an exemplary embodiment;

FIG. 5 is a flowchart of a connection establishment process according to an exemplary embodiment;

FIG. 6 is a flowchart of an initialization model parameter training process according to an exemplary embodiment;

FIG. 7 is a schematic structural diagram of a model training apparatus according to an exemplary embodiment;

FIG. 8 is a schematic structural diagram of a model training apparatus according to an exemplary embodiment;

FIG. 9 is a block diagram of a terminal according to an exemplary embodiment;

FIG. 10 is a block diagram of a server according to an exemplary embodiment.
Detailed Description
Exemplary embodiments will be described in detail here, examples of which are illustrated in the accompanying drawings. Where the following description refers to the drawings, the same numerals in different drawings denote the same or similar elements unless otherwise indicated. The implementations described in the following exemplary embodiments do not represent all implementations consistent with the present disclosure; rather, they are merely examples of apparatus and methods consistent with some aspects of the present disclosure as recited in the appended claims.
FIG. 1 is a block diagram of a model training system provided by an exemplary embodiment of the present disclosure. As shown in FIG. 1, the model training system may include a network side 12 and a terminal 13.
The network side 12 includes a server 120, and the server 120 communicates with the terminal 13 through a wireless channel. In the embodiments of the present disclosure, the server 120 may be a functional unit of a network-side device, and the network-side device may be a base station, i.e., an apparatus deployed in an access network to provide a wireless communication function for terminals. The terminal 13 is a terminal that accesses the network-side device, and the network-side device coordinates the terminals to participate in distributed collaborative learning.

The base station may take various forms, such as a macro base station, a micro base station, a relay station, or an access point. In systems using different radio access technologies, the name of a device with base station functions may differ; in a 5G New Radio (NR) system it is called a gNodeB or gNB. As communication technology evolves, the name "base station" may change. For convenience of description, the above apparatuses that provide wireless communication functions for terminals are collectively referred to as network-side devices below.

The terminal 13 may include various handheld devices with wireless communication functions, vehicle-mounted devices, wearable devices, computing devices, other processing devices connected to a wireless modem, as well as various forms of user equipment, mobile stations (Mobile Station, MS), terminals, and so on. For convenience of description, the devices mentioned above are collectively referred to as terminals. The access network device 120 and the terminal 13 communicate with each other through an air interface technology, for example a Uu interface.
In the related art, the server 120 collects the data of each terminal for centralized meta-learning (also called centralized training). On the one hand, this approach requires data transmission and occupies a large bandwidth; on the other hand, all the training work is done by the server, so training takes a long time and consumes considerable server computing resources. Moreover, meta-learning does not pursue an optimal global model, but rather a well-initialized model that can quickly adapt to new tasks; therefore, the lengthy model convergence period of the centralized meta-learning scheme brings little performance gain, and the model training efficiency is low.

In addition, in the related art, many edge terminals accessing the network cannot upload their data (small-sample data) to the server for training due to data privacy, security and similar concerns. The data of these edge terminals often contains a large amount of information that is very important for improving the performance of the learning model, yet the centralized meta-learning scheme cannot make full use of the data of edge terminals. Furthermore, the centralized meta-learning scheme indiscriminately uses server data as training data; since the server data is weakly correlated with the terminals, the trained model generalizes poorly to the terminals' tasks.

The model training system and service scenarios described in the embodiments of the present disclosure are intended to explain the technical solutions of the embodiments of the present disclosure more clearly and do not limit them. Those of ordinary skill in the art will appreciate that, with the evolution of the model training system and the emergence of new service scenarios, the technical solutions provided by the embodiments of the present disclosure are equally applicable to similar technical problems.
FIG. 2 is a flowchart of a model training method according to an exemplary embodiment. Referring to FIG. 2, the method includes the following steps:
In step 101, the server receives data distribution information of multiple terminals.

The data distribution information includes the categories of data and the number of samples contained in each category.
The data categories are used to classify the data in a data set. For example, when training a picture classification model, the pictures in the data set can be divided into multiple categories according to the classification requirements, such as people, plants and landscapes, each of which is one category. The number of samples contained in each category refers to the amount of data of that category, for example, the number of pictures whose category is "people".
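For concreteness, a minimal sketch of what such a report could look like, assuming a simple category-to-count mapping (the field names below are illustrative assumptions, not prescribed by the disclosure):

```python
# Hypothetical shape of one terminal's data distribution information:
# each key is a data category, each value is the number of local samples.
data_distribution_info = {
    "people": 1200,
    "plants": 300,
    "landscapes": 450,
}

total_samples = sum(data_distribution_info.values())  # the terminal's data volume
```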
In step 102, the server selects a training data set that conforms to the data distribution information of the multiple terminals.

After receiving the data distribution information of the multiple terminals, the server selects, according to that information, data whose categories and sample counts are the same as those included in the data distribution information, to form the training data set.
In step 103, the server performs model training based on the training data set to obtain model parameters.

In step 104, the server sends the model parameters to at least some of the multiple terminals.

Based on the user scheduling information of the terminals, the server can select the terminals that meet the distributed training requirements and let these terminals participate in distributed training.

In step 105, the server receives training results obtained by the at least some terminals training the model parameters.

In step 106, the server updates the model parameters based on the training results of the at least some terminals to obtain global model parameters.

After training the model parameters provided by the server, a terminal reports its training result to the server, and the server can complete the global update based on the training results of the terminals to obtain the global model parameters. The global model parameters, on the one hand, achieve the training objective and, on the other hand, are suitable for the above at least some terminals because the training results of those terminals have been integrated.

In the embodiments of the present disclosure, the server receives the data distribution information sent by the terminals, aggregates the data distribution information of each terminal, selects data with the same distribution to form a training set for preliminary training, and then sends the model parameters to the terminals for distributed training; the terminals then upload their training results to the server for a global update. In this process, what is transmitted between the terminals and the server is data distribution information, model parameters and the like, rather than the terminals' data itself, so the bandwidth occupation is small. Through distributed training on the terminals, the server's computing resource consumption is small, the training period is short, and the training efficiency is high. Meanwhile, since the server selects data with the same distribution as the terminals to form the training set for preliminary training, the correlation between the model and the terminals is strengthened and the model generalizes well. In addition, the terminals participate directly in the distributed training process without uploading their own data, so even edge terminals can participate, allowing the data of edge terminals to contribute to improving the performance of the learning model and ensuring that the training scheme makes full use of edge terminal data.

The solution provided by the embodiments of the present disclosure, which performs distributed collaborative learning using data distribution characteristics, is suitable for training meta-models with strong generalization ability, for example model training for tasks such as deep learning and image processing.
Optionally, receiving the data distribution information of the multiple terminals includes:

receiving the data distribution information transmitted by each of the multiple terminals through Radio Resource Control (RRC) signaling.

In the embodiments of the present disclosure, when transmitting the data distribution information, the terminal and the server may first establish an RRC connection and transmit the data distribution information through RRC signaling during the RRC connection establishment process. In this way, the uploading of the data distribution information is simplified.
Optionally, selecting a training data set that conforms to the data distribution information of the multiple terminals includes:

combining the data distribution information of the multiple terminals to obtain total data distribution information;

extracting, from the local data of the server, data whose distribution conforms to the total data distribution information to obtain the training data set.

For example, the data distribution information of terminal 1 includes {type A, sample size a1; type B, sample size b}, and the data distribution information of terminal 2 includes {type A, sample size a2; type C, sample size c}; then the total data distribution information includes {type A, sample size a1+a2; type B, sample size b; type C, sample size c}. The server selects data types and sample sizes according to {type A, sample size a1+a2; type B, sample size b; type C, sample size c} to form the training data set.
Optionally, the model parameters include initialization model parameters, and performing model training based on the training data set to obtain model parameters includes:

performing model training with the training data set to obtain the initialization model parameters.

Alternatively, the model parameters include intermediate model parameters, and performing model training based on the training data set to obtain model parameters includes:

performing model training with the training data set to obtain initialization model parameters;

iteratively updating the initialization model parameters to obtain the intermediate model parameters.

In the embodiments of the present disclosure, on the one hand, the server performs model training based on the training data set to obtain initialization model parameters and then delivers the initialization model parameters to the terminals, which saves terminal training time; on the other hand, the server updates the initialization model parameters based on the training results of the terminals to obtain intermediate model parameters and then sends the intermediate model parameters to the terminals, so that the terminals can train on the basis of the intermediate model parameters, which speeds up the whole model training process.

The intermediate model parameters are obtained by the server through iterative updates based on the training results uploaded by the terminals.
Optionally, sending the model parameters to at least some of the multiple terminals includes:

receiving user scheduling information of each of the multiple terminals;

determining, based on the user scheduling information of each of the multiple terminals, whether each of the multiple terminals meets the distributed training requirements;

sending the model parameters to the terminals, among the multiple terminals, that meet the distributed training requirements.
Exemplarily, the user scheduling information includes at least one of the following parameters:

the data volume of the data in the terminal, the similarity between the data distribution and the total data distribution information, the communication status, the computing capability, and the learning model performance requirements, where the total data distribution information is obtained by combining the data distribution information of the multiple terminals.

The data volume of the data in the terminal can be obtained from the data distribution information uploaded by the terminal, i.e., the sum of the sample sizes of the various categories in the data distribution information. The similarity between the data distribution and the total data distribution information refers to the difference between the categories a terminal includes and the categories in the total data distribution information, for example the ratio of the number of categories included in the terminal to the number of categories in the total data distribution information, and the ratio of the number of samples of each category in the terminal to the number of samples of the corresponding category in the total data distribution information; the two ratios are combined to obtain the difference. The communication status may include Channel Quality Indication (CQI) information. The computing capability may include the computing speed and the device's spare computing power: the computing speed refers to the number of computations per second (computations/s), and the spare computing power refers to the percentage of computing power that can be allocated to model training. The learning model performance requirements include a task preference and an accuracy requirement. The task preference can be represented by the probability characteristics of the tasks likely to be executed locally; taking a classification task as an example, it is expressed by the probability of each data category appearing in the task: P = {p(category 1), p(category 2), ...}. Exemplarily, the accuracy requirement may be: model accuracy > 90%.
Exemplarily, for each parameter in the user scheduling information, the server sets a threshold range that meets the distributed training requirements; when every parameter of a terminal falls within the corresponding threshold range, the terminal meets the distributed training requirements.
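As an illustration of this per-parameter check (the parameter names and threshold values below are assumptions for the sketch, not values from the disclosure), the eligibility test reduces to requiring every reported value to fall inside its configured range:

```python
# Hypothetical threshold ranges configured on the server, one per
# user scheduling parameter: (lower bound, upper bound).
REQUIREMENTS = {
    "data_volume": (1000, float("inf")),  # at least 1000 local samples
    "similarity": (0.5, 1.0),             # close enough to the total distribution
    "cqi": (7, 15),                       # acceptable channel quality
    "spare_compute": (0.2, 1.0),          # fraction of compute available
}

def meets_distributed_training_requirements(scheduling_info):
    """A terminal qualifies only if every parameter is within its range."""
    return all(lo <= scheduling_info[name] <= hi
               for name, (lo, hi) in REQUIREMENTS.items())
```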
Optionally, sending the model parameters to the terminals, among the multiple terminals, that meet the distributed training requirements includes:

determining data transmission parameters based on the data volume of the model parameters and the communication status of the terminals that meet the distributed training requirements;

sending the model parameters to the terminals that meet the distributed training requirements according to the data transmission parameters.

Here, the data transmission parameters include parameters such as the modulation scheme and the code rate. When the data volume of the model parameters differs, or the communication status of the terminal differs, different modulation schemes and code rates can be selected for transmission, so that the selected modulation scheme and code rate match the amount of data currently to be transmitted and the communication status of the terminal, thereby achieving a better transmission effect.
The data volume of the model parameters is related, on the one hand, to the model size: the larger the model, the larger the data volume of the model parameters. On the other hand, it is also related to the precision of each model parameter: the higher the precision, the larger the data volume of the model parameters. Here, the precision of a model parameter may refer to the number of digits retained after the decimal point; the higher the precision of the model parameters, the more digits are retained after the decimal point, and the larger the amount of data the model parameters occupy.
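A back-of-the-envelope sketch of the two factors, reading "precision" as the byte width stored per parameter (one plausible interpretation; the disclosure speaks of decimal digits):

```python
def model_payload_bytes(num_params, bytes_per_param):
    """Payload of one parameter transfer: model size x per-parameter precision."""
    return num_params * bytes_per_param

model_payload_bytes(1_000_000, 4)  # ~4 MB at 32-bit precision
model_payload_bytes(1_000_000, 2)  # ~2 MB if the precision is halved to 16-bit
```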
Optionally, the training result includes a gradient value, the gradient value being obtained by the terminal, after training the model parameters, by testing the trained model parameters;

alternatively, the training result includes model update parameters, the model update parameters being the model parameters obtained after the terminal trains the model parameters.

In the embodiments of the present disclosure, there are two possibilities for the training result of a terminal: one is the gradient value obtained by testing after training is completed, and the other is the model update parameters obtained by model training alone, without testing. These two cases exist because the data volumes of the terminals differ. For example, when the data volume in a terminal is large, the data in the terminal can form both a support set and a query set; the terminal can first use the support set for model training and then use the query set for model testing. When the data volume in a terminal is small, the data in the terminal can only form a support set; the terminal uses the support set for model training, while the model testing is done by the server.
Here, whether the data volume in a terminal is large or small can be determined by comparison with a threshold: larger than the threshold means large, smaller means small. The threshold can be determined based on the data volumes of the multiple terminals, for example as a quantile of those data volumes; e.g., if 80% of users have a data volume of at least 1000, the threshold is set to 1000. The threshold can be determined by the server based on the data distribution information of each terminal and then notified to each terminal.
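A small sketch of the quantile rule just described, assuming numpy is available (the 80% figure is the example from the text):

```python
import numpy as np

def data_volume_threshold(volumes, covered_fraction=0.8):
    """Choose the threshold so that covered_fraction of terminals meet it;
    e.g. if 80% of terminals hold at least 1000 samples, this returns ~1000."""
    return float(np.quantile(volumes, 1.0 - covered_fraction))
```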
Optionally, when the training result of each of the at least some terminals includes a gradient value,

updating the model parameters based on the training results of the at least some terminals to obtain global model parameters includes:

iteratively updating the model parameters by gradient descent based on the average of the gradient values of the at least some terminals to obtain the global model parameters.

In this case, each of the at least some terminals has a relatively large amount of data and can form both a support set and a query set; therefore, each of these terminals reports a gradient value to the server, which makes it convenient for the server to complete the iterative update of the model parameters.
Optionally, when the training result of at least one of the at least some terminals includes model update parameters,

updating the model parameters based on the training results of the at least some terminals to obtain global model parameters includes:

selecting a query set that conforms to the data distribution information of a first terminal, the first terminal being a terminal whose training result includes model update parameters;

testing the model update parameters of the first terminal based on the query set to obtain a gradient value;

iteratively updating the model parameters by gradient descent based on the average of the gradient values of the at least some terminals to obtain the global model parameters.
In this case, some terminals have too little data to form both a support set and a query set; therefore, these terminals only report model update parameters to the server, and the server locally extracts a query set for model testing and then uses the gradient values obtained from the test to complete the iterative update of the model parameters.
Optionally, iteratively updating the model parameters by gradient descent based on the average of the gradient values of the at least some terminals to obtain the global model parameters includes:

iteratively updating the model parameters by gradient descent based on the average of the first gradient values of the at least some terminals;

determining whether the average of the first gradient values of the at least some terminals is within a threshold range;

in response to the average of the first gradient values of the at least some terminals not being within the threshold range, sending the intermediate model parameters, obtained by iteratively updating the model parameters, to the at least some terminals;

iteratively updating the intermediate model parameters with the average of the second gradient values of the at least some terminals, where a second gradient value is a gradient value obtained by a terminal, after training the intermediate model parameters, by testing the trained intermediate model parameters.
Optionally, the method further includes:

in response to the average of the first gradient values of the at least some terminals being within the threshold range, sending the global model parameters, obtained by iteratively updating the model parameters, to the at least some terminals, the global model parameters being used by the terminals for adaptive updating.
In the embodiments of the present disclosure, updating the model parameters is usually a multi-round distributed training process; that is, one round of distributed training consists of the at least some terminals each performing one model training pass and reporting their respective training results. After a round of distributed training ends, on the one hand, the server performs a global update based on the training results; on the other hand, based on the average of the gradient values corresponding to these training results, it is determined whether the requirements of distributed training have been met. If the requirements are met, the globally updated model is taken as the global update model, which requires no further distributed training and can be put to use after adaptive training by the users; if the requirements are not met, the globally updated intermediate model parameters are sent to the terminals as the basis for the terminals' next round of training, which proceeds on the basis of these intermediate model parameters.

In this implementation, the server monitors the effect of the distributed model training and stops learning when the model accuracy meets the requirements, without requiring training until the model converges. This training approach greatly improves training efficiency; meanwhile, the global model parameters are subsequently updated adaptively by each terminal, so that each terminal obtains a more personalized model, ensuring that the model a terminal uses better fits its task requirements and guaranteeing model performance.

Here, a terminal performing adaptive updating may mean that, on the basis of the global model parameters, the terminal updates the model with its own local data so that the model parameters meet the terminal's requirements.
It should be noted that the foregoing steps 101 to 106 and the above optional steps can be combined arbitrarily.
FIG. 3 is a flowchart of a model training method according to an exemplary embodiment. Referring to FIG. 3, the method includes the following steps:
In step 201, the terminal sends its data distribution information, the data distribution information including the categories of data and the number of samples contained in each category.

The terminal counts the number of local samples of each data category, generates the data distribution information, and sends it to the server.

In step 202, the terminal receives model parameters, the model parameters being obtained by training on a training data set selected by the server based on the data distribution information.

The model parameters here may be either initialization model parameters or intermediate model parameters.

In step 203, the terminal trains the model parameters to obtain a training result.

In step 204, the terminal sends the training result, the training result being used to globally update the model parameters to obtain global model parameters.
In the embodiments of the present disclosure, the terminal sends its own data distribution information to the server; the server aggregates the data distribution information of each terminal, selects data with the same distribution to form a training set for preliminary training, and then sends the model parameters to the terminals for distributed training; the terminals then upload the training results to the server for a global update. In this process, what is transmitted between the terminal and the server is data distribution information, model parameters and the like rather than the terminal's data itself, so the bandwidth occupation is small; moreover, through distributed training on the terminals, the server's computing resource consumption is small.
Optionally, sending the data distribution information of the terminal includes:

sending the data distribution information through RRC signaling.
Optionally, the model parameters include initialization model parameters, and receiving the model parameters includes:

receiving the initialization model parameters, the initialization model parameters being obtained by the server through training on the training data set selected using the data distribution information;

alternatively, the model parameters include intermediate model parameters, and receiving the model parameters includes:

receiving the intermediate model parameters, the intermediate model parameters being obtained by the server by iteratively updating the initialization model parameters.
Optionally, the training result includes a gradient value, the gradient value being obtained, after the model parameters are trained, by testing the trained model parameters;

alternatively, the training result includes model update parameters, the model update parameters being the model parameters obtained after training the model parameters.
Optionally, when the training result includes model update parameters,

sending the training result includes:

determining data transmission parameters based on the data volume of the model update parameters and the communication status of the terminal;

sending the model update parameters to the server according to the data transmission parameters.
Optionally, the method further includes:

sending user scheduling information, the user scheduling information including at least one of the following parameters: the data volume of the data in the terminal, the similarity between the data distribution and the total data distribution information, the communication status, the computing capability, and the learning model performance requirements, where the total data distribution information is obtained by combining the data distribution information of the multiple terminals.
Optionally, the method further includes:

receiving global model parameters;

adaptively updating the global model parameters.
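A terminal-side sketch of this adaptive update, assuming a PyTorch model and a hypothetical `local_loader` yielding (x, y) batches of the terminal's own data:

```python
import torch
from torch import nn

def adaptive_update(global_model: nn.Module, local_loader, lr=0.01, steps=5):
    """Personalize the received global (meta) parameters with a few
    gradient steps on local data only; nothing is uploaded."""
    optimizer = torch.optim.SGD(global_model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    for _ in range(steps):
        for x, y in local_loader:
            optimizer.zero_grad()
            loss_fn(global_model(x), y).backward()
            optimizer.step()
    return global_model
```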
It should be noted that the foregoing steps 201 to 204 and the above optional steps can be combined arbitrarily.
FIG. 4 is a flowchart of a model training method according to an exemplary embodiment. Referring to FIG. 4, the method includes the following steps:
In step 301, the server and the terminal establish an RRC connection.
Exemplarily, the process of establishing an RRC connection between the server and the terminal may refer to FIG. 5, with the following steps:

Step 3011: the terminal sends RRC connection request signaling to the server, the RRC connection request signaling requesting the establishment of an RRC connection with the server. Correspondingly, the server receives the RRC connection request signaling.

Step 3012: the server sends RRC connection setup signaling to the terminal, the RRC connection setup signaling being used to notify the terminal that the server agrees to establish the RRC connection. Correspondingly, the terminal receives the RRC connection setup signaling.

Step 3013: the terminal sends RRC connection setup complete signaling to the server, the RRC connection setup complete signaling being used to notify the server that the RRC connection establishment is complete. Correspondingly, the server receives the RRC connection setup complete signaling.

The signaling transmission and reception in the above RRC connection establishment process is performed by the network communication module of the terminal and the network communication module of the server; each of these network communication modules can consist of a sending module and a receiving module.
In step 302, the terminal sends data distribution information to the server, and the server receives the data distribution information sent by the terminal.

The data distribution information includes the categories of data and the number of samples contained in each category.

In the embodiments of the present disclosure, step 301 and step 302 need not occur in a fixed order; for example, the data distribution information may be transmitted during the process of establishing the RRC connection between the server and the terminal, that is, the server receives the data distribution information transmitted by the terminal through RRC signaling, for example through the RRC connection setup complete signaling.
In step 303, the server combines the data distribution information of the multiple terminals to obtain total data distribution information.

For example, the data distribution information of terminal 1 includes {type A, sample size a1; type B, sample size b}, and the data distribution information of terminal 2 includes {type A, sample size a2; type C, sample size c}; then the total data distribution information includes {type A, sample size a1+a2; type B, sample size b; type C, sample size c}.

In step 304, the server extracts, from its local data, data whose distribution conforms to the total data distribution information to obtain the training data set.

Exemplarily, the server selects data types and sample sizes according to the {type A, sample size a1+a2; type B, sample size b; type C, sample size c} obtained in step 303 to form the training data set.
In step 305, the server performs model training with the training data set to obtain the initialization model parameters.
Exemplarily, the process by which the server trains to obtain the initialization model parameters may refer to FIG. 6, with the following steps:

Step 3051: the server randomly initializes a set of model parameters.

Step 3052: the server extracts a batch of tasks from the training data set, each task including a support set and a query set.
Exemplarily, denote the total data distribution information as P and the server's local data as $D_s$. Data is extracted from the local data according to the total data distribution information P to generate the training data set, denoted $\hat{D}$. The server extracts data from $\hat{D}$ to generate several tasks, each task containing a support set and a query set, denoted $D_i^{sup}$ and $D_i^{qry}$ respectively.
Step 3053: the server performs training with the support set of each task, and calculates the model loss and gradient, to obtain the updated model parameters on each task.
Exemplarily, the server may obtain the updated model parameters by using the gradient descent method, which may be expressed as the following formula (1):

$$\theta'_i = \theta - \alpha \nabla_\theta L_{T_i}\big(f_\theta; D^{s}_{T_i}\big) \qquad (1)$$

where $\theta'_i$ denotes the updated model parameters on the i-th task, $\theta$ denotes the initialized set of model parameters, $\alpha$ denotes the learning rate of a single task, $\nabla_\theta$ denotes the derivative, $L$ denotes the loss function of the model on the support set, $f$ denotes the model, $T_i$ denotes the i-th task, and $D^{s}_{T_i}$ denotes the support set of the i-th task.
Step 3054: the server calculates the test loss and gradient of the updated model parameters by using the query set of each task.
Step 3055: the server aggregates the gradients on the tasks, updates the randomly initialized model parameters, and obtains the initialization model parameters.
Exemplarily, the server calculates the test loss and gradient of the updated model parameters by using the query set of each task, and sums and averages the gradients over the tasks. The global model parameters are then updated with the average gradient value by using the gradient descent method, which may be expressed as the following formula (2):

$$\theta \leftarrow \theta - \frac{\beta}{N} \sum_{T_i \sim p(T)} \nabla_\theta L_{T_i}\big(f_{\theta'_i}; D^{q}_{T_i}\big) \qquad (2)$$

where $\beta$ denotes the global learning rate, $N$ denotes the number of tasks used in this round of training, $p(T)$ denotes the set of tasks used in this round of training, and $D^{q}_{T_i}$ denotes the query set of the i-th task.
In the above process, each step may be performed by the model training module of the server. In step 3052, the training data set of the server may be stored in the data processing and storage module of the server, and the model training module may perform signaling interaction with the data processing and storage module to extract a batch of tasks.
In step 306, the terminal sends user scheduling information to the server, and the server receives the user scheduling information sent by the terminal.
The user scheduling information includes at least one of the following parameters: the data volume of the data in the terminal, the similarity between the data distribution and the total data distribution information, the communication status, the computing capability, and the learning model performance requirement, where the total data distribution information is obtained by combining the data distribution information of the multiple terminals.
In this embodiment of the present disclosure, step 306 and step 302 may be performed simultaneously, that is, the terminal sends the user scheduling information to the server together with the data distribution information; in other words, the user scheduling information may also be transmitted through RRC signaling.
In one implementation of this embodiment of the present disclosure, the user scheduling information may include only the communication status, the computing capability and the learning model performance requirement, while the data volume of the data in the terminal and the similarity between the data distribution and the total data distribution information may be determined by the server based on the data distribution information.
In this embodiment of the present disclosure, the parameters in the user scheduling information may be sent to the server by the terminal together, or may be sent to the server in sequence.
Among these parameters, the communication status usually includes the CQI, and the CQI needs to be obtained by the terminal through measurement. Therefore, the method may further include: before step 306, the terminal performs CQI measurement.
In this embodiment of the present disclosure, the user scheduling information is acquired by the user management module in the terminal and sent through the network communication module of the terminal to the network communication module of the server, which then transfers it to the user management module of the server. When the network communication module and the user management module in the terminal or the server transfer the user scheduling information, a new signaling may be used, whose sole function is to carry the user scheduling information.
In step 307, the server determines, based on the user scheduling information of each of the multiple terminals, whether each of the multiple terminals meets the distributed training requirement.
Exemplarily, for each parameter in the user scheduling information, the server sets a threshold range corresponding to the distributed training requirement. When every parameter of a terminal falls within the corresponding threshold range, the terminal meets the distributed training requirement.
Terminals among the multiple terminals other than those selected above as meeting the distributed training requirement do not participate in this round of training.
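A minimal sketch of the per-parameter range check in step 307; the parameter keys and the example ranges are assumptions for illustration only:

```python
def eligible(scheduling_info, thresholds):
    """A terminal qualifies only if every reported parameter falls
    inside the server-configured threshold range."""
    return all(lo <= scheduling_info[key] <= hi
               for key, (lo, hi) in thresholds.items())

# Illustrative ranges and reports.
thresholds = {"data_volume": (1000, float("inf")),
              "cqi": (7, 15),
              "similarity": (0.8, 1.0)}
reported = [{"data_volume": 5000, "cqi": 9, "similarity": 0.9},
            {"data_volume": 200, "cqi": 3, "similarity": 0.5}]
participants = [r for r in reported if eligible(r, thresholds)]  # keeps only the first
```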
In step 308, the server sends the initialization model parameters to the terminals, among the multiple terminals, that meet the distributed training requirement, and the terminal receives the initialization model parameters.
In this step, if the terminal in steps 301 to 307 is a terminal that meets the distributed training requirement, the terminal participates in steps 308 to 314; if the terminal in steps 301 to 307 is not a terminal that meets the distributed training requirement, the terminal does not participate in steps 308 to 314. This embodiment is described by taking as an example the case where the terminal in steps 301 to 307 meets the distributed training requirement.
Exemplarily, when transmitting the initialization model parameters, the server first determines data transmission parameters based on the data volume of the initialization model parameters and the communication status of the terminal, and then sends the initialization model parameters to the terminal according to the data transmission parameters. Here, determining the data transmission parameters may be performed by the transmission control module in the server; after the transmission control module determines the data transmission parameters, it may control the network communication module to send the initialization model parameters according to the data transmission parameters.
Here, the data transmission parameters include parameters such as the modulation scheme and the code rate. When the data volume of the model parameters differs or the communication status of the terminal differs, different modulation schemes and code rates may be selected for the transmission, so that the selected modulation scheme and code rate match the amount of data to be transmitted and the communication status of the terminal, thereby achieving a better transmission effect.
For example, the server encapsulates and packages the initialization model parameters according to the above data transmission scheme, and sends the packaged data packet of the initialization model parameters to the terminal. After receiving the data packet, the terminal decapsulates it, confirms the correctness of the received data packet based on the decapsulated data, and then feeds back a message to the server, informing the server that the terminal has correctly received the initialization model parameters.
In the above process, on the terminal side, verifying the correctness of the data packet and generating the feedback message are performed by the transmission control module in the terminal, and the receiving and sending processes are performed by the network communication module.
The data volume of the model parameters is related, on the one hand, to the model size: the larger the model, the larger the data volume of the model parameters. On the other hand, it is also related to the precision of each model parameter: the higher the precision, the larger the data volume. The precision of a model parameter may refer to the number of digits kept after the decimal point; the higher the precision and the more digits kept, the larger the amount of data occupied by the model parameters.
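The following sketch makes both relationships concrete: the payload grows with the parameter count and the per-parameter precision, and a modulation scheme and code rate are then chosen to match the channel. The CQI-to-modulation table is a deployment-specific assumption, not taken from the patent:

```python
def payload_bytes(num_params, bits_per_param):
    # Data volume grows with model size and with per-parameter precision.
    return num_params * bits_per_param // 8

def pick_transmission(cqi):
    # Illustrative CQI-to-(modulation, code rate) mapping.
    if cqi >= 10:
        return "64QAM", 0.75
    if cqi >= 7:
        return "16QAM", 0.50
    return "QPSK", 0.33

size = payload_bytes(num_params=1_000_000, bits_per_param=16)  # 2,000,000 bytes
modulation, code_rate = pick_transmission(cqi=9)               # ("16QAM", 0.50)
```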
In step 309, the terminal trains the initialization model parameters to obtain a training result.
In this embodiment of the present disclosure, the training result of the terminal may take two forms: one is the gradient value obtained through testing after the training is completed, and the other is the model update parameters obtained through model training alone, without testing. These two cases exist because the data volumes in the terminals differ. For example, when the data volume in the terminal is large, the data in the terminal can form both a support set and a query set; in this case the terminal may first perform model training with the support set and then perform model testing with the query set. When the data volume in the terminal is small, the data in the terminal can only form a support set; in this case the terminal performs model training with the support set, and the model testing is completed by the server.
Here, whether the data volume in the terminal is large or small may be determined by comparison with a threshold: larger than the threshold means large, smaller means small. The threshold may be determined based on the data volumes of the multiple terminals, for example, as a quantile of those data volumes: if 80% of the users hold at least 1000 samples, the threshold is set to 1000. The threshold may be determined by the server based on the data distribution information of the terminals and then notified to the terminals. A terminal may compare this threshold with its own data volume to determine whether to generate a query set.
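A minimal sketch of how such a quantile-based threshold could be computed on the server; the 20% quantile mirrors the "80% of users reach 1000" example and is otherwise an assumption:

```python
import numpy as np

def query_set_threshold(data_volumes, quantile=0.2):
    """If 80% of terminals hold at least this many samples, the 20th
    percentile reproduces the '1000' example from the text."""
    return float(np.quantile(data_volumes, quantile))

volumes = [1200, 3000, 900, 1000, 2500]
threshold = query_set_threshold(volumes)
builds_query_set = [v >= threshold for v in volumes]  # per-terminal decision
```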
Exemplarily, the terminal performs a gradient descent update on the initialization model parameters by using the support set, to obtain the model update parameters, which may be expressed as the following formula (3):

$$\theta_{ui} = \theta - \alpha \nabla_\theta L\big(f_\theta; D^{s}_{ui}\big) \qquad (3)$$

where $\theta_{ui}$ denotes the model update parameters of the i-th terminal, and $D^{s}_{ui}$ denotes the support set in the i-th terminal.
If a query set exists in the terminal, the terminal tests the model update parameters with the query set and calculates the test loss and gradient value, which may be expressed as the following formula (4):

$$g_{ui} = \nabla_{\theta_{ui}} L\big(f_{\theta_{ui}}; D^{q}_{ui}\big) \qquad (4)$$

where $g_{ui}$ denotes the test gradient of the model update parameters of the i-th terminal, and $D^{q}_{ui}$ denotes the query set in the training set of the i-th terminal.
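A terminal-side sketch covering both cases of step 309, reusing the illustrative grad_loss callable from the server sketch; with only a support set the terminal returns updated parameters (formula (3)), and with a query set it also returns the test gradient (formula (4)):

```python
def local_training(theta, grad_loss, support, query=None, alpha=0.01):
    """Step 309 sketch. grad_loss(theta, data) returns dL/dtheta."""
    theta_u = theta - alpha * grad_loss(theta, support)  # formula (3)
    if query is None:
        return theta_u, None                 # server performs the testing
    return theta_u, grad_loss(theta_u, query)            # formula (4)
```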
In step 310, the terminal sends the training result to the server, and the server receives the training result sent by the terminal.
When the terminal sends the training result, if the model update parameters are sent, this may be done in the manner in which the server sends the initialization model parameters in step 308, that is, the data transmission parameters are first determined and the transmission is then performed according to the data transmission parameters. By extension, in this embodiment of the present disclosure, whenever model parameters need to be transmitted between the terminal and the server, the transmission is performed by first determining the data transmission parameters and then transmitting according to the data transmission parameters.
In step 311, the server updates the model parameters based on the training results of at least some of the terminals. When the updated model parameters do not meet the requirement, step 312 is performed; when the updated model parameters meet the requirement, step 313 is performed.
Here, the at least some terminals are the terminals that participate in the training and meet the distributed training requirement. The server may obtain, based on the training results of these terminals, the average value of the gradient values of the at least some terminals. If the average value of the gradient values of the at least some terminals is within the threshold range (for example, smaller than a set value), the updated model parameters meet the requirement; otherwise, the updated model parameters do not meet the requirement.
Exemplarily, when the training result of each of the at least some terminals includes a gradient value, step 311 may include:
the server iteratively updates the model parameters by using the gradient descent method based on the average value of the gradient values of the at least some terminals.
Exemplarily, when the training result of at least one of the at least some terminals includes model update parameters, step 311 may include:
the server selects a query set that conforms to the data distribution information of the first terminal, where the first terminal is a terminal whose training result includes model update parameters;
the server tests the model update parameters of the first terminal based on the query set, to obtain a gradient value;
the server iteratively updates the model parameters by using the gradient descent method based on the average value of the gradient values of the at least some terminals.
In this step, the server determines, according to the data volume of each terminal, whether a query set needs to be generated for the terminal.
In this embodiment of the present disclosure, the server performs a gradient descent update on the model parameters by using the average value of the gradient values of the at least some terminals, which may be expressed as the following formula (5):

$$\theta \leftarrow \theta - \frac{\beta}{M} \sum_{i=1}^{M} g_{ui} \qquad (5)$$

where $M$ denotes the number of terminals that meet the distributed training requirement, that is, the number of terminals participating in the distributed training.
Whether the updated model parameters meet the requirement may be judged according to the following formula (6):

$$\Big\| \frac{1}{M} \sum_{i=1}^{M} g_{ui} \Big\| < g_0 \qquad (6)$$

where $g_0$ denotes the aforementioned threshold (set value).
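A minimal sketch of the aggregation in step 311: average the participating terminals' gradients, apply formula (5), and check convergence per formula (6). Names and default values are illustrative:

```python
import numpy as np

def global_update(theta, terminal_grads, beta=0.001, g0=1e-3):
    """terminal_grads holds one gradient array per participating terminal."""
    g_avg = np.mean(terminal_grads, axis=0)
    theta = theta - beta * g_avg             # formula (5)
    converged = np.linalg.norm(g_avg) < g0   # formula (6)
    return theta, converged
```

When converged is False, the updated parameters are sent out as intermediate model parameters (step 312) and another round follows; when True, they are sent as the global model parameters (step 313).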
Step 311 may be performed by the model update module in the server. In the process of performing the above steps, this module needs to interact with the data processing and storage module in the server to acquire the data for generating a query set for a terminal. In this interaction, a newly added signaling may be used to instruct the data processing and storage module to provide the data for generating the query set.
In step 312, the server sends the intermediate model parameters, obtained by iteratively updating the model parameters, to the at least some terminals, and the terminals receive the intermediate model parameters sent by the server.
After receiving the intermediate model parameters sent by the server, the terminal trains the intermediate model parameters to obtain a training result, and then steps 310 and 311 are repeated for iterative updating.
In step 313, the server sends the global model parameters, obtained by iteratively updating the model parameters, to the at least some terminals, and the terminals receive the global model parameters sent by the server.
In the above steps, between the server and the terminal, only the data distribution information, the user scheduling information and the like can be transmitted through RRC signaling; the subsequent model parameters, training results and the like are transmitted as service data due to their large data volume.
In step 314, the terminal adaptively updates the global model parameters.
In this embodiment of the present disclosure, the terminal tests the global model parameters by using its local support set, calculates the test loss and gradient, and performs a gradient descent update to obtain the adaptive model, which may be expressed as the following formula (7):

$$\Phi_{ui}(\theta) = \theta - \alpha \nabla_\theta L\big(f_\theta; D^{q}_{ui}\big) \qquad (7)$$

where $\Phi_{ui}(\theta)$ denotes the adaptive update model of the i-th terminal, and $D^{q}_{ui}$ denotes the query set in the test set of the i-th terminal.
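A one-line sketch of the adaptive update of step 314, under the same illustrative grad_loss assumption as the earlier sketches:

```python
def adaptive_update(theta_global, grad_loss, local_data, alpha=0.01):
    # One local gradient step on the terminal's own data (formula (7)).
    return theta_global - alpha * grad_loss(theta_global, local_data)
```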
The foregoing steps 309 and 314 may be performed by the model update module in the terminal. In the process of performing these steps, this module needs to interact with the data processing and storage module in the terminal to acquire data for generating the support set, the query set, and the like.
Figure 7 is a schematic structural diagram of a model training apparatus according to an exemplary embodiment. The apparatus has the function of implementing the server in the above method embodiments, and the function may be implemented by hardware, or by hardware executing corresponding software. As shown in Figure 7, the apparatus includes: a receiving module 501, a selecting module 502, a model training module 503 and a sending module 504.
The receiving module 501 is configured to receive data distribution information of multiple terminals, where the data distribution information includes the categories of the data and the number of samples included in each category.
The selecting module 502 is configured to select a training data set that conforms to the data distribution information of the multiple terminals.
The model training module 503 is configured to perform model training based on the training data set, to obtain model parameters.
The sending module 504 is configured to send the model parameters to at least some of the multiple terminals.
The receiving module 501 is further configured to receive training results obtained by the at least some terminals training the model parameters.
The model training module 503 is further configured to update the model parameters based on the training results of the at least some terminals, to obtain global model parameters.
Optionally, the receiving module 501 is configured to receive the data distribution information transmitted by each of the multiple terminals through RRC signaling.
Optionally, the selecting module 502 is configured to combine the data distribution information of the multiple terminals to obtain total data distribution information, and to extract, from the local data of the server, data whose distribution conforms to the total data distribution information, to obtain the training data set.
Optionally, the model parameters include initialization model parameters, and the model training module 503 is configured to perform model training by using the training data set, to obtain the initialization model parameters; or, the model parameters include intermediate model parameters, and the model training module 503 is configured to perform model training by using the training data set to obtain initialization model parameters, and to iteratively update the initialization model parameters to obtain the intermediate model parameters.
Optionally, the receiving module 501 is further configured to receive user scheduling information of each of the multiple terminals;
the apparatus further includes: a determining module 505, configured to determine, based on the user scheduling information of each of the multiple terminals, whether each of the multiple terminals meets the distributed training requirement;
the sending module 504 is configured to send the model parameters to the terminals, among the multiple terminals, that meet the distributed training requirement.
Optionally, the user scheduling information includes at least one of the following parameters: the data volume of the data in the terminal, the similarity between the data distribution and the total data distribution information, the communication status, the computing capability, and the learning model performance requirement, where the total data distribution information is obtained by combining the data distribution information of the multiple terminals.
Optionally, the determining module 505 is further configured to determine data transmission parameters based on the data volume of the model parameters and the communication status of the terminals that meet the distributed training requirement;
the sending module 504 is configured to send the model parameters to the terminals that meet the distributed training requirement according to the data transmission parameters.
Optionally, the training result includes a gradient value, where the gradient value is a gradient value obtained by the terminal, after training the model parameters, through testing the trained model parameters; or, the training result includes model update parameters, where the model update parameters are the model parameters obtained after the terminal trains the model parameters.
Optionally, when the training result of each of the at least some terminals includes a gradient value, the model training module 503 is configured to iteratively update the model parameters by using the gradient descent method based on the average value of the gradient values of the at least some terminals, to obtain the global model parameters.
Optionally, when the training result of at least one of the at least some terminals includes model update parameters, the selecting module 502 is configured to select a query set that conforms to the data distribution information of the first terminal, where the first terminal is a terminal whose training result includes model update parameters; and the model training module 503 is configured to test the model update parameters of the first terminal based on the query set, to obtain a gradient value, and to iteratively update the model parameters by using the gradient descent method based on the average value of the gradient values of the at least some terminals, to obtain the global model parameters.
Optionally, the model training module 503 is configured to iteratively update the model parameters by using the gradient descent method based on the average value of first gradient values of the at least some terminals; determine whether the average value of the first gradient values of the at least some terminals is within a threshold range; in response to the average value of the first gradient values of the at least some terminals being not within the threshold range, send the intermediate model parameters, obtained by iteratively updating the model parameters, to the at least some terminals; and iteratively update the intermediate model parameters by using the average value of second gradient values of the at least some terminals, where the second gradient values are gradient values obtained by the terminals, after training the intermediate model parameters, through testing the trained intermediate model parameters.
Optionally, the sending module 504 is further configured to, in response to the average value of the first gradient values of the at least some terminals being within the threshold range, send the global model parameters, obtained by iteratively updating the model parameters, to the at least some terminals, where the global model parameters are used for adaptive updating by the terminals.
Figure 8 is a schematic structural diagram of a model training apparatus according to an exemplary embodiment. The apparatus has the function of implementing the terminal in the above method embodiments, and the function may be implemented by hardware, or by hardware executing corresponding software. As shown in Figure 8, the apparatus includes: a sending module 601, a receiving module 602 and a model training module 603.
The sending module 601 is configured to send data distribution information of the terminal, where the data distribution information includes the categories of the data and the number of samples included in each category.
The receiving module 602 is configured to receive model parameters, where the model parameters are obtained through training on a training data set selected by the server based on the data distribution information.
The model training module 603 is configured to train the model parameters to obtain a training result.
The sending module 601 is further configured to send the training result, where the training result is used to globally update the model parameters to obtain global model parameters.
Optionally, the sending module 601 is configured to send the data distribution information through RRC signaling.
Optionally, the model parameters include initialization model parameters, and the receiving module 602 is configured to receive the initialization model parameters, where the initialization model parameters are obtained by the server through training on the training data set selected by using the data distribution information; or, the model parameters include intermediate model parameters, and the receiving module 602 is configured to receive the intermediate model parameters, where the intermediate model parameters are obtained by the server through iteratively updating the initialization model parameters.
Optionally, the training result includes a gradient value, where the gradient value is a gradient value obtained, after the model parameters are trained, through testing the trained model parameters; or, the training result includes model update parameters, where the model update parameters are the model parameters obtained after training the model parameters.
Optionally, when the training result includes model update parameters, the apparatus further includes: a determining module 604, configured to determine data transmission parameters based on the data volume of the model update parameters and the communication status of the terminal; and the sending module 601 is configured to send the model update parameters to the server according to the data transmission parameters.
Optionally, the sending module 601 is further configured to send user scheduling information, where the user scheduling information includes at least one of the following parameters: the data volume of the data in the terminal, the similarity between the data distribution and the total data distribution information, the communication status, the computing capability, and the learning model performance requirement, where the total data distribution information is obtained by combining the data distribution information of the multiple terminals.
Optionally, the receiving module 602 is further configured to receive global model parameters, and the model training module 603 is further configured to adaptively update the global model parameters.
Figure 9 is a block diagram of a terminal 700 according to an exemplary embodiment. The terminal 700 may include: a processor 701, a receiver 702, a transmitter 703, a memory 704 and a bus 705.
The processor 701 includes one or more processing cores, and the processor 701 executes various functional applications and performs information processing by running software programs and modules.
The receiver 702 and the transmitter 703 may be implemented as one communication component, which may be a communication chip.
The memory 704 is connected to the processor 701 through the bus 705.
The memory 704 may be configured to store at least one instruction, and the processor 701 is configured to execute the at least one instruction to implement the steps in the above method embodiments.
In addition, the memory 704 may be implemented by any type of volatile or non-volatile storage device, or a combination thereof, including but not limited to: a magnetic disk or an optical disc, an electrically erasable programmable read-only memory (EEPROM), an erasable programmable read-only memory (EPROM), a static random access memory (SRAM), a read-only memory (ROM), a magnetic memory, a flash memory, or a programmable read-only memory (PROM).
In an exemplary embodiment, a computer-readable storage medium is further provided. The computer-readable storage medium stores at least one instruction, at least one program, a code set or an instruction set, and the at least one instruction, the at least one program, the code set or the instruction set is loaded and executed by the processor to implement the model training method provided by the above method embodiments.
Figure 10 is a block diagram of a server 800 according to an exemplary embodiment. The server 800 may include: a processor 801, a receiver 802, a transmitter 803 and a memory 804. The receiver 802, the transmitter 803 and the memory 804 are each connected to the processor 801 through a bus.
The processor 801 includes one or more processing cores, and the processor 801 executes, by running software programs and modules, the method performed by the server in the model training method provided by the embodiments of the present disclosure. The memory 804 may be configured to store software programs and modules. Specifically, the memory 804 may store an operating system 8041 and an application program module 8042 required for at least one function. The receiver 802 is configured to receive communication data sent by other devices, and the transmitter 803 is configured to send communication data to other devices.
In an exemplary embodiment, a computer-readable storage medium is further provided. The computer-readable storage medium stores at least one instruction, at least one program, a code set or an instruction set, and the at least one instruction, the at least one program, the code set or the instruction set is loaded and executed by the processor to implement the model training method provided by the above method embodiments.
An exemplary embodiment of the present disclosure further provides a model training system. The model training system includes a terminal and a server. The terminal is the terminal provided by the embodiment shown in Figure 9, and the server is the server provided by the embodiment shown in Figure 10.
Other embodiments of the present disclosure will readily occur to those skilled in the art upon consideration of the specification and practice of the invention disclosed herein. This application is intended to cover any variations, uses, or adaptations of the present disclosure that follow the general principles of the present disclosure and include common knowledge or customary technical means in the technical field not disclosed by the present disclosure. The specification and examples are to be regarded as exemplary only, with the true scope and spirit of the present disclosure being indicated by the following claims.
It should be understood that the present disclosure is not limited to the precise structures described above and illustrated in the accompanying drawings, and that various modifications and changes may be made without departing from the scope thereof. The scope of the present disclosure is limited only by the appended claims.

Claims (41)

  1. A model training method, wherein the method comprises:
    receiving data distribution information of multiple terminals, the data distribution information comprising categories of data and the number of samples included in each category;
    selecting a training data set that conforms to the data distribution information of the multiple terminals;
    performing model training based on the training data set to obtain model parameters;
    sending the model parameters to at least some of the multiple terminals;
    receiving training results obtained by the at least some terminals training the model parameters;
    updating the model parameters based on the training results of the at least some terminals to obtain global model parameters.
  2. The method according to claim 1, wherein receiving data distribution information of multiple terminals comprises:
    receiving the data distribution information transmitted by each of the multiple terminals through RRC signaling.
  3. The method according to claim 1, wherein selecting a training data set that conforms to the data distribution information of the multiple terminals comprises:
    combining the data distribution information of the multiple terminals to obtain total data distribution information;
    extracting, from local data of a server, data whose distribution conforms to the total data distribution information, to obtain the training data set.
  4. The method according to claim 1, wherein the model parameters comprise initialization model parameters, and performing model training based on the training data set to obtain model parameters comprises:
    performing model training by using the training data set to obtain the initialization model parameters;
    or, the model parameters comprise intermediate model parameters, and performing model training based on the training data set to obtain model parameters comprises:
    performing model training by using the training data set to obtain initialization model parameters;
    iteratively updating the initialization model parameters to obtain the intermediate model parameters.
  5. The method according to claim 1, wherein sending the model parameters to at least some of the multiple terminals comprises:
    receiving user scheduling information of each of the multiple terminals;
    determining, based on the user scheduling information of each of the multiple terminals, whether each of the multiple terminals meets a distributed training requirement;
    sending the model parameters to terminals, among the multiple terminals, that meet the distributed training requirement.
  6. The method according to claim 5, wherein the user scheduling information comprises at least one of the following parameters:
    a data volume of data in the terminal, a similarity between the data distribution and total data distribution information, a communication status, a computing capability, and a learning model performance requirement, the total data distribution information being obtained by combining the data distribution information of the multiple terminals.
  7. The method according to claim 5, wherein sending the model parameters to terminals, among the multiple terminals, that meet the distributed training requirement comprises:
    determining data transmission parameters based on a data volume of the model parameters and the communication status of the terminals that meet the distributed training requirement;
    sending the model parameters to the terminals that meet the distributed training requirement according to the data transmission parameters.
  8. The method according to any one of claims 1 to 7, wherein the training result comprises a gradient value, the gradient value being a gradient value obtained by the terminal, after training the model parameters, through testing the trained model parameters;
    or, the training result comprises model update parameters, the model update parameters being model parameters obtained after the terminal trains the model parameters.
  9. The method according to claim 8, wherein, when the training result of each of the at least some terminals comprises a gradient value,
    updating the model parameters based on the training results of the at least some terminals to obtain global model parameters comprises:
    iteratively updating the model parameters by using a gradient descent method based on an average value of the gradient values of the at least some terminals, to obtain the global model parameters.
  10. The method according to claim 8, wherein, when the training result of at least one of the at least some terminals comprises model update parameters,
    updating the model parameters based on the training results of the at least some terminals to obtain global model parameters comprises:
    selecting a query set that conforms to data distribution information of a first terminal, the first terminal being a terminal whose training result comprises model update parameters;
    testing the model update parameters of the first terminal based on the query set to obtain a gradient value;
    iteratively updating the model parameters by using a gradient descent method based on an average value of the gradient values of the at least some terminals, to obtain the global model parameters.
  11. The method according to claim 9 or 10, wherein iteratively updating the model parameters by using a gradient descent method based on an average value of the gradient values of the at least some terminals, to obtain global model parameters, comprises:
    iteratively updating the model parameters by using the gradient descent method based on an average value of first gradient values of the at least some terminals;
    determining whether the average value of the first gradient values of the at least some terminals is within a threshold range;
    in response to the average value of the first gradient values of the at least some terminals being not within the threshold range, sending intermediate model parameters, obtained by iteratively updating the model parameters, to the at least some terminals;
    iteratively updating the intermediate model parameters by using an average value of second gradient values of the at least some terminals, the second gradient values being gradient values obtained by the terminals, after training the intermediate model parameters, through testing the trained intermediate model parameters.
  12. The method according to claim 11, wherein the method further comprises:
    in response to the average value of the first gradient values of the at least some terminals being within the threshold range, sending global model parameters, obtained by iteratively updating the model parameters, to the at least some terminals, the global model parameters being used for adaptive updating by the terminals.
  13. A model training method, wherein the method comprises:
    sending data distribution information of a terminal, the data distribution information comprising categories of data and the number of samples included in each category;
    receiving model parameters, the model parameters being obtained through training on a training data set selected by a server based on the data distribution information;
    training the model parameters to obtain a training result;
    sending the training result, the training result being used to globally update the model parameters to obtain global model parameters.
  14. The method according to claim 13, wherein sending data distribution information of a terminal comprises:
    sending the data distribution information through RRC signaling.
  15. The method according to claim 13, wherein the model parameters comprise initialization model parameters, and receiving model parameters comprises:
    receiving the initialization model parameters, the initialization model parameters being obtained by the server through training on the training data set selected by using the data distribution information;
    or, the model parameters comprise intermediate model parameters, and receiving model parameters comprises:
    receiving the intermediate model parameters, the intermediate model parameters being obtained by the server through iteratively updating initialization model parameters.
  16. The method according to claim 15, wherein the training result comprises a gradient value, the gradient value being a gradient value obtained, after the model parameters are trained, through testing the trained model parameters;
    or, the training result comprises model update parameters, the model update parameters being model parameters obtained after training the model parameters.
  17. The method according to claim 16, wherein, when the training result comprises model update parameters,
    sending the training result comprises:
    determining data transmission parameters based on a data volume of the model update parameters and a communication status of the terminal;
    sending the model update parameters to the server according to the data transmission parameters.
  18. The method according to claim 13, wherein the method further comprises:
    sending user scheduling information, the user scheduling information comprising at least one of the following parameters: a data volume of data in the terminal, a similarity between the data distribution and total data distribution information, a communication status, a computing capability, and a learning model performance requirement, the total data distribution information being obtained by combining the data distribution information of multiple terminals.
  19. The method according to any one of claims 13 to 18, wherein the method further comprises:
    receiving global model parameters;
    adaptively updating the global model parameters.
  20. A model training apparatus, wherein the apparatus comprises:
    a receiving module, configured to receive data distribution information of multiple terminals, the data distribution information comprising categories of data and the number of samples included in each category;
    a selecting module, configured to select a training data set that conforms to the data distribution information of the multiple terminals;
    a model training module, configured to perform model training based on the training data set to obtain model parameters;
    a sending module, configured to send the model parameters to at least some of the multiple terminals;
    the receiving module being further configured to receive training results obtained by the at least some terminals training the model parameters;
    the model training module being further configured to update the model parameters based on the training results of the at least some terminals to obtain global model parameters.
  21. The apparatus according to claim 20, wherein the receiving module is configured to receive the data distribution information transmitted by each of the multiple terminals through RRC signaling.
  22. The apparatus according to claim 20, wherein the selecting module is configured to combine the data distribution information of the multiple terminals to obtain total data distribution information, and to extract, from local data of a server, data whose distribution conforms to the total data distribution information, to obtain the training data set.
  23. The apparatus according to claim 20, wherein the model parameters comprise initialization model parameters, and the model training module is configured to perform model training by using the training data set to obtain the initialization model parameters;
    or, the model parameters comprise intermediate model parameters, and the model training module is configured to perform model training by using the training data set to obtain initialization model parameters, and to iteratively update the initialization model parameters to obtain the intermediate model parameters.
  24. The apparatus according to claim 20, wherein the receiving module is further configured to receive user scheduling information of each of the multiple terminals;
    the apparatus further comprises: a determining module, configured to determine, based on the user scheduling information of each of the multiple terminals, whether each of the multiple terminals meets a distributed training requirement;
    the sending module is configured to send the model parameters to terminals, among the multiple terminals, that meet the distributed training requirement.
  25. The apparatus according to claim 24, wherein the user scheduling information comprises at least one of the following parameters:
    a data volume of data in the terminal, a similarity between the data distribution and total data distribution information, a communication status, a computing capability, and a learning model performance requirement, the total data distribution information being obtained by combining the data distribution information of the multiple terminals.
  26. The apparatus according to claim 24, wherein the determining module is further configured to determine data transmission parameters based on a data volume of the model parameters and the communication status of the terminals that meet the distributed training requirement;
    the sending module is configured to send the model parameters to the terminals that meet the distributed training requirement according to the data transmission parameters.
  27. The apparatus according to any one of claims 20 to 26, wherein the training result includes a gradient value, the gradient value being obtained by the terminal, after training the model parameters, by testing the trained model parameters;
    or, the training result includes model update parameters, the model update parameters being the model parameters obtained after the terminal trains the model parameters.
  28. The apparatus according to claim 27, wherein when the training result of each of the at least some terminals includes a gradient value,
    the model training module is configured to iteratively update the model parameters using a gradient descent method based on the average of the gradient values of the at least some terminals, to obtain the global model parameters.
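A minimal sketch of the global update in claim 28, assuming each terminal's test gradient arrives as a NumPy array of the same shape as the model parameters and that the learning rate is a free hyperparameter (names are illustrative):

```python
import numpy as np

def global_update(model_params, terminal_gradients, lr=0.01):
    # terminal_gradients: one test gradient per terminal.
    avg_grad = np.mean(terminal_gradients, axis=0)  # average over terminals
    return model_params - lr * avg_grad             # one gradient descent step
```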
  29. The apparatus according to claim 27, wherein when the training result of at least one of the at least some terminals includes model update parameters,
    the selection module is configured to select a query set conforming to the data distribution information of a first terminal, the first terminal being a terminal whose training result includes model update parameters;
    the model training module is configured to test the model update parameters of the first terminal based on the query set to obtain a gradient value, and to iteratively update the model parameters using a gradient descent method based on the average of the gradient values of the at least some terminals, to obtain the global model parameters.
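Where a terminal uploads model update parameters instead of a gradient, claim 29 first converts them into a gradient by testing on a query set matching that terminal's distribution. A sketch under the assumption that a loss-gradient callable is available for the model (its exact form is not specified in the text):

```python
import numpy as np

def gradient_from_update(updated_params, query_set, loss_grad_fn):
    # loss_grad_fn(params, sample) -> gradient of the loss on one query sample.
    grads = [loss_grad_fn(updated_params, sample) for sample in query_set]
    return np.mean(grads, axis=0)  # test gradient for this terminal
```

The resulting gradient can then be averaged with the other terminals' gradients and fed into the same global update step sketched after claim 28.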
  30. The apparatus according to claim 28 or 29, wherein the model training module is configured to: iteratively update the model parameters using a gradient descent method based on the average of first gradient values of the at least some terminals; determine whether the average of the first gradient values of the at least some terminals is within a threshold range; and in response to the average of the first gradient values of the at least some terminals not being within the threshold range, send intermediate model parameters, obtained by iteratively updating the model parameters, to the at least some terminals, and iteratively update the intermediate model parameters using the average of second gradient values of the at least some terminals; where a second gradient value is a gradient value obtained by the terminal, after training the intermediate model parameters, by testing the trained intermediate model parameters.
  31. The apparatus according to claim 30, wherein the sending module is further configured to, in response to the average of the first gradient values of the at least some terminals being within the threshold range, send global model parameters, obtained by iteratively updating the model parameters, to the at least some terminals, the global model parameters being used by the terminals for adaptive updating.
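Claims 30 and 31 together describe a threshold-controlled loop: keep distributing intermediate parameters and collecting fresh test gradients until the average gradient falls within the threshold range, then distribute the result as the global model. A sketch that reads "within the threshold range" as a bound on the gradient norm (an assumption; the claims do not fix the test) and stands in for all transport with callables:

```python
import numpy as np

def train_until_converged(params, send_to_terminals, collect_gradients,
                          lr=0.01, threshold=1e-3, max_rounds=100):
    for _ in range(max_rounds):
        send_to_terminals(params)        # distribute current (intermediate) parameters
        grads = collect_gradients()      # terminals train them and return test gradients
        avg_grad = np.mean(grads, axis=0)
        params = params - lr * avg_grad  # gradient descent update
        if np.linalg.norm(avg_grad) <= threshold:
            break                        # within the threshold range: stop iterating
    send_to_terminals(params)            # final global parameters, for adaptive updating
    return params
```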
  32. A model training apparatus, wherein the apparatus includes:
    a sending module configured to send data distribution information of a terminal, the data distribution information including categories of data and the number of samples included in each category;
    a receiving module configured to receive model parameters, the model parameters being obtained by training on a training data set selected by a server based on the data distribution information;
    a model training module configured to train the model parameters to obtain a training result;
    the sending module being further configured to send the training result, the training result being used to globally update the model parameters to obtain global model parameters.
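The terminal-side counterpart in claim 32 is a four-step exchange. A schematic sketch in which every callable is an assumed stand-in for transport and local training code not given in the text:

```python
def terminal_round(report_distribution, receive_params, local_train, send_result):
    report_distribution()         # 1. send the terminal's data distribution information
    params = receive_params()     # 2. receive model parameters trained by the server
    result = local_train(params)  # 3. train locally; yields a gradient or updated parameters
    send_result(result)           # 4. upload the training result for the global update
```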
  33. The apparatus according to claim 32, wherein the sending module is configured to send the data distribution information through RRC signaling.
  34. The apparatus according to claim 32, wherein the model parameters include initialization model parameters, and the receiving module is configured to receive the initialization model parameters, the initialization model parameters being obtained by the server through training on the training data set selected based on the data distribution information;
    or, the model parameters include intermediate model parameters, and the receiving module is configured to receive the intermediate model parameters, the intermediate model parameters being obtained by the server by iteratively updating initialization model parameters.
  35. The apparatus according to claim 34, wherein the training result includes a gradient value, the gradient value being obtained, after the model parameters are trained, by testing the trained model parameters;
    or, the training result includes model update parameters, the model update parameters being the model parameters obtained after training the model parameters.
  36. The apparatus according to claim 35, wherein when the training result includes model update parameters,
    the apparatus further includes a determining module configured to determine data transmission parameters based on the data volume of the model update parameters and the communication status of the terminal;
    the sending module is configured to send the model update parameters to the server according to the data transmission parameters.
  37. The apparatus according to claim 32, wherein the sending module is further configured to send user scheduling information, the user scheduling information including at least one of the following parameters: the data volume of data in the terminal, the similarity between the data distribution and total data distribution information, the communication status, the computing capability, and the learning model performance requirement, where the total data distribution information is obtained by combining the data distribution information of multiple terminals.
  38. The apparatus according to any one of claims 32 to 37, wherein the receiving module is further configured to receive global model parameters;
    the model training module is further configured to adaptively update the global model parameters.
  39. A server, wherein the server includes:
    a processor; and
    a memory for storing processor-executable instructions;
    wherein the processor is configured to load and execute the executable instructions to implement the model training method according to any one of claims 1 to 12.
  40. A terminal, wherein the terminal includes:
    a processor; and
    a memory for storing processor-executable instructions;
    wherein the processor is configured to load and execute the executable instructions to implement the model training method according to any one of claims 13 to 19.
  41. A computer-readable storage medium, wherein when instructions in the computer-readable storage medium are executed by a processor, the model training method according to any one of claims 1 to 12, or the model training method according to any one of claims 13 to 19, can be executed.
PCT/CN2020/123292 2020-10-23 2020-10-23 Model training method and device, server, terminal, and storage medium WO2022082742A1 (en)

Priority Applications (2)

Application Number Priority Date Filing Date Title
PCT/CN2020/123292 WO2022082742A1 (en) 2020-10-23 2020-10-23 Model training method and device, server, terminal, and storage medium
CN202080002976.6A CN114667523A (en) 2020-10-23 2020-10-23 Model training method, device, server, terminal and storage medium

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
PCT/CN2020/123292 WO2022082742A1 (en) 2020-10-23 2020-10-23 Model training method and device, server, terminal, and storage medium

Publications (1)

Publication Number Publication Date
WO2022082742A1 (en)

Family

ID=81291178

Family Applications (1)

Application Number Title Priority Date Filing Date
PCT/CN2020/123292 WO2022082742A1 (en) 2020-10-23 2020-10-23 Model training method and device, server, terminal, and storage medium

Country Status (2)

Country Link
CN (1) CN114667523A (en)
WO (1) WO2022082742A1 (en)

Citations (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20190042956A1 (en) * 2018-02-09 2019-02-07 Intel Corporation Automatic configurable sequence similarity inference system
CN109716346A (en) * 2016-07-18 2019-05-03 河谷生物组学有限责任公司 Distributed machines learning system, device and method
CN110956202A (en) * 2019-11-13 2020-04-03 重庆大学 Image training method, system, medium and intelligent device based on distributed learning
CN111444848A (en) * 2020-03-27 2020-07-24 广州英码信息科技有限公司 Specific scene model upgrading method and system based on federal learning
CN111611610A (en) * 2020-04-12 2020-09-01 西安电子科技大学 Federal learning information processing method, system, storage medium, program, and terminal

Also Published As

Publication number Publication date
CN114667523A (en) 2022-06-24

Similar Documents

Publication Publication Date Title
WO2021243619A1 (en) Information transmission method and apparatus, and communication device and storage medium
KR102630605B1 (en) Communication methods and related devices
Kamoun et al. Joint resource allocation and offloading strategies in cloud enabled cellular networks
CN111869303A (en) Resource scheduling method, device, communication equipment and storage medium
WO2020187004A1 (en) Scheduling method and apparatus in communication system, and storage medium
US20230409962A1 (en) Sampling user equipments for federated learning model collection
CN115208812B (en) Service processing method and device, equipment and computer readable storage medium
WO2022104799A1 (en) Training method, training apparatus, and storage medium
WO2022099512A1 (en) Data processing method and apparatus, communication device, and storage medium
CN114097259A (en) Communication processing method, communication processing device and storage medium
US12041480B2 (en) User equipment and wireless communication method for neural network computation
CN115087036A (en) Background data transmission strategy configuration method and device
WO2023240572A1 (en) Information transmission method and apparatus, and communication device and storage medium
CN113692052A (en) Network edge machine learning training method
WO2022082742A1 (en) Model training method and device, server, terminal, and storage medium
WO2023082280A1 (en) Model updating method for wireless channel processing, apparatus, terminal, and medium
WO2021103947A1 (en) Scheduling method and device
WO2022133689A1 (en) Model transmission method, model transmission device, and storage medium
US11277715B2 (en) Transmit multicast frame
CN109361431B (en) Slice scheduling method and system
CN114270934A (en) Method, device, communication equipment and storage medium for controlling data transmission rate
US20240089742A1 (en) Data transmission method and related apparatus
CN113556247B (en) Multi-layer parameter distributed data transmission method, device and readable medium
US20230354296A1 (en) Node scheduling method and apparatus
US11877168B2 (en) Radio frame analysis system, radio frame analysis method, and program

Legal Events

Date Code Title Description
121 Ep: the epo has been informed by wipo that ep was designated in this application (Ref document number: 20958324; Country of ref document: EP; Kind code of ref document: A1)
NENP Non-entry into the national phase (Ref country code: DE)
122 Ep: pct application non-entry in european phase (Ref document number: 20958324; Country of ref document: EP; Kind code of ref document: A1)