WO2021204014A1 - A model training method and related apparatus - Google Patents

一种模型训练的方法及相关装置 (A model training method and related apparatus)

Info

Publication number
WO2021204014A1
Authority
WO
WIPO (PCT)
Prior art keywords
sample
trained
model
vector
data set
Prior art date
Application number
PCT/CN2021/083815
Other languages
English (en)
French (fr)
Inventor
谯轶轩
陈浩
高鹏
Original Assignee
平安科技(深圳)有限公司
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by 平安科技(深圳)有限公司 (Ping An Technology (Shenzhen) Co., Ltd.)
Publication of WO2021204014A1

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F16/00Information retrieval; Database structures therefor; File system structures therefor
    • G06F16/30Information retrieval; Database structures therefor; File system structures therefor of unstructured textual data
    • G06F16/33Querying
    • G06F16/3331Query processing
    • G06F16/334Query execution
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/084Backpropagation, e.g. using gradient descent

Definitions

  • The embodiments of the present application relate to the field of artificial intelligence technology, and in particular to an adversarial model training method and related devices.
  • Text matching is a core problem in the field of information retrieval. Text matching can be reduced to matching query items against documents: a text matching model assigns a matching score to each query item and document, and the higher the matching score, the stronger the relevance between the query item and the document.
  • A text matching model based on BM25 (an algorithm for evaluating the relevance between a search term and a document) can only match a query item and a document when they share repeated words; a text matching model based on deep learning, by contrast, can match semantically similar words or phrases.
  • A deep learning model itself has a large number of parameters and requires a large amount of data to be trained adequately. When samples are constructed to train such a deep-learning-based text matching model, the positive samples are the documents that users actually clicked, and the negative samples are randomly selected from all documents.
  • However, as the number of such negative samples increases, the deep-learning-based text matching model's ability to understand the semantics shared by a query item and its corresponding real document degrades, and it assigns widely different matching scores to similar documents; that is, the robustness of the model is reduced.
  • the embodiment of the present application discloses a model training method and related devices. By improving the method of generating samples in model training, the difficulty of model training is increased, thereby enhancing the robustness of the model.
  • In a first aspect, an embodiment of the present application discloses a model training method, including: obtaining a data set to be processed, the samples of which include positive samples and negative samples; obtaining a sample to be trained from the data set to be processed, and using a first model to obtain a vector representing the sample to be trained; when the sample to be trained is a positive sample, inputting the vector representing the sample to be trained into a second model to generate a vector representing an adversarial sample of the sample to be trained; inputting the vector representing the adversarial sample of the sample to be trained into a third model to obtain an output value; determining a sub-loss value of the sample to be trained according to the output value, and summing the sub-loss values of all samples in the data set to be processed to obtain a total loss value; and, when the difference between the total loss values obtained in two consecutive iterations is less than a threshold, determining that the first model, the second model and the third model converge.
  • In a second aspect, an embodiment of the present application discloses a model training device, including:
  • an acquiring unit, configured to acquire a data set to be processed, acquire a sample to be trained from the data set to be processed, and use a first model to obtain a vector representing the sample to be trained, where the samples contained in the data set to be processed include positive samples and negative samples;
  • a processing unit, configured to, when the sample to be trained is a positive sample, input the vector representing the sample to be trained into a second model to generate a vector representing an adversarial sample of the sample to be trained, and to input the vector representing the adversarial sample into a third model to obtain an output value;
  • a calculation unit, configured to determine a sub-loss value of the sample to be trained according to the output value, and to sum the sub-loss values of all samples in the data set to be processed to obtain a total loss value;
  • a determining unit, configured to determine that the first model, the second model and the third model converge when the difference between the total loss values obtained in two consecutive iterations is less than a threshold.
  • In a third aspect, an embodiment of the present application discloses a server, including a processor and a memory, where a computer program is stored in the memory and the processor invokes the computer program stored in the memory to execute the model training method of the first aspect.
  • In a fourth aspect, the embodiments of the present application disclose a computer-readable storage medium storing a computer program; when the computer program runs on one or more processors, the model training method of the first aspect is executed.
  • On the one hand, this application can increase the difficulty of model training and thereby improve the update efficiency of the model's parameters; on the other hand, it can improve the model's ability to handle boundary data and thereby improve the robustness of the model.
  • FIG. 1 is a schematic flowchart of a model training method disclosed in an embodiment of the present application;
  • FIG. 2 is a schematic flowchart of a loss value calculation method disclosed in an embodiment of the present application;
  • FIG. 3 is a schematic flowchart of another model training method disclosed in an embodiment of the present application;
  • FIG. 4 is a schematic flowchart of yet another model training method disclosed in an embodiment of the present application;
  • FIG. 5 is a schematic structural diagram of a model training device disclosed in an embodiment of the present application;
  • FIG. 6 is a schematic structural diagram of a server disclosed in an embodiment of the present application.
  • In this application, "at least one (item)" refers to one or more, and "multiple" refers to two or more.
  • "At least two (items)" refers to two, or to three and more than three.
  • "And/or" describes an association between associated objects and indicates that three relationships may exist; for example, "A and/or B" can mean that only A exists, that only B exists, or that both A and B exist, where A and B may each be singular or plural.
  • The character "/" generally indicates that the associated objects before and after it are in an "or" relationship.
  • "At least one of the following items" or a similar expression refers to any combination of these items; for example, at least one of a, b, or c can mean: a, b, c, "a and b", "a and c", "b and c", or "a and b and c".
  • The technical solution of the present application relates to the field of artificial intelligence and/or big data technology; in particular, it may relate to neural network technology and can be applied to scenarios such as information retrieval.
  • Optionally, the data involved in this application, such as samples, output values, and/or loss values, can be stored in a database or in a blockchain, which is not limited in this application.
  • the embodiment of the present application provides a method for model training, which improves the difficulty of model training by improving the method of generating samples in model training, thereby enhancing the robustness of the model.
  • FIG. 1 is a schematic flowchart of a model training method disclosed in an embodiment of the present application. As shown in Figure 1, the above method includes:
  • S101 Obtain a data set to be processed.
  • A deep-learning-based model requires a large amount of data for training, and the collection of data used to train the model can be called the data set. To make training more effective, the data set needs to be processed to obtain positive and negative samples that are effective for model training. The processing method of this application includes the following.
  • The initial data is obtained from the public data set of Microsoft's document ranking task. The data set can be expressed as M = {(q_1, s_1), (q_2, s_2), ..., (q_i, s_i), ..., (q_n, s_n)}, where q_i is the text searched by the user, that is, the query item; s_i is the result returned by the search engine, that is, the document list; and n is the number of query items in the data set.
  • Any result s_i returned by the search engine can be expressed as s_i = {(d_i1, l_i1), (d_i2, l_i2), ..., (d_ij, l_ij), ..., (d_im, l_im)}, where d_ij is the j-th search result corresponding to the i-th query item, that is, a document item; l_ij is a label item, with l_ij = 1 meaning the user clicked the search result and l_ij = 0 meaning the user did not click it; and m is the number of document items in the search result.
  • Because the amount of data in the data set is large and the document items corresponding to each query item may contain redundant information, the data set needs to be processed to obtain positive and negative samples that are effective for model training.
  • the specific processing steps include:
  • 1. Segment the query items and the document items in the data set. Word segmentation tools that can be used include jieba ("stuttering") segmentation and Pangu segmentation; the jieba segmentation model can be used to segment the query items and the document items. Jieba performs efficient word-graph scanning based on a prefix dictionary, builds a directed acyclic graph of all possible word formations of the Chinese characters in a sentence, then uses dynamic programming to find the maximum-probability path and the maximum segmentation combination based on word frequency. Since jieba is a very typical word segmentation tool, its specific principle is not repeated here.
  • 2. Remove stop words from the segmentation results. Common stop word lists include the Harbin Institute of Technology stop word list, the Baidu stop word list, and the stop word list of the Machine Intelligence Laboratory of Sichuan University. The Harbin Institute of Technology stop word list can be used first to give the segmentation results an initial filtering, and regular expressions and manual screening can then be used to filter out high-frequency internet vocabulary from the segmentation results. Since this way of filtering stop words is a very typical processing step, it is not repeated here.
  • 3. Convert the character forms of the results after stop word removal. The opencc toolkit can be used to convert traditional Chinese characters in the text to simplified Chinese and, at the same time, uppercase English letters are converted to lowercase English letters. Since this is a very typical processing step, it is not repeated here.
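  • As an illustration of these three preprocessing steps, a minimal Python sketch is given below; jieba and opencc are the tools named above, while the stop word file path and the helper names are assumptions made for the example, not part of the original disclosure.

```python
# Minimal preprocessing sketch: segmentation, stop word removal, character conversion.
import jieba                      # "stuttering" word segmentation
from opencc import OpenCC         # traditional -> simplified Chinese conversion

cc = OpenCC("t2s")

def load_stopwords(path="hit_stopwords.txt"):      # hypothetical stop word file
    with open(path, encoding="utf-8") as f:
        return {line.strip() for line in f if line.strip()}

def preprocess(text, stopwords):
    text = cc.convert(text).lower()                 # step 3: character/letter-case conversion
    tokens = jieba.lcut(text)                       # step 1: word segmentation
    return [t for t in tokens if t.strip() and t not in stopwords]   # step 2: stop words

stopwords = load_stopwords()
print(preprocess("深度學習模型需要大量數據進行訓練", stopwords))
```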
  • 4. Filter the samples in the data set. In general, for a specific query item q_i, the document list s_i returned by the search engine contains a large number of documents, many of which are useless results; the data set can therefore be filtered with a text matching model to keep the higher-quality results.
  • For any query item q_i and its corresponding document d_ij, their matching score can be expressed as r_ij = F(q_i, d_ij), where r_ij is the matching score of the query item q_i and the corresponding document d_ij, and F is the text matching model BM25.
  • BM25 is an algorithm used to evaluate the relevance between a search term and a document: the query item is segmented into terms, and the values corresponding to the relevance of each term to the document item are weighted and summed to obtain the matching score between the query item and the document item. Since the BM25 algorithm is a typical algorithm, its specific principle is not repeated here. A threshold can be set for the matching scores, and the query items and document items whose matching scores exceed the threshold are retained as samples of the data set.
  • After the data set is filtered with the BM25 text matching model, each query item q_i still corresponds to multiple documents, but the filtered document list contains fewer documents than the document list before filtering, and the documents it contains are of relatively high quality with respect to the query item. For each retained document d_ij, the triple (q_i, d_ij, l_ij) formed with its query item and click label is taken as one sample of the data set: when l_ij is 1 the sample is a positive sample, and when l_ij is 0 it is a negative sample. The filtered data set containing N samples can be expressed as M = {(q_1, d_1, l_1), (q_2, d_2, l_2), ..., (q_i, d_i, l_i), ..., (q_N, d_N, l_N)}.
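  • As a hedged illustration of this filtering step (the filing does not name a specific BM25 implementation), the rank_bm25 package could be used roughly as follows; the threshold value and variable names are assumptions for the example.

```python
# Sketch of BM25-based filtering: keep (query, document, label) triples whose
# score r_ij = F(q_i, d_ij) exceeds a threshold. The threshold 5.0 is illustrative.
from rank_bm25 import BM25Okapi

def filter_samples(query_tokens, doc_token_lists, labels, threshold=5.0):
    bm25 = BM25Okapi(doc_token_lists)            # corpus: list of tokenised documents
    scores = bm25.get_scores(query_tokens)       # one score per document
    return [(query_tokens, d, l)
            for d, l, r in zip(doc_token_lists, labels, scores) if r > threshold]
```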
  • S102 Obtain a sample to be trained from the foregoing data set to be processed, and use the first model to obtain a vector representing the foregoing sample to be trained.
  • The sample to be trained is any sample in the data set to be processed; corresponding to the processing result of the data set in step 101, the sample to be trained includes a query item, a document item, and a label item.
  • The specific way of using the first model to obtain the vector representing the sample to be trained is to input the query item and the document item of the sample to be trained into the first model, obtaining the vector corresponding to the query item of the sample and the vector corresponding to the document item of the sample; the vector representing the sample to be trained therefore includes the vector corresponding to the query item and the vector corresponding to the document item.
  • The first model includes a Recurrent Neural Network (RNN) model; it should be noted that, besides the RNN model, variants of the RNN model such as the Long Short-Term Memory (LSTM) model and the Gated Recurrent Unit (GRU) model can also be used.
  • The vector representing the sample to be trained, obtained by inputting the query item and the document item into the first model, contains more information than the raw query item and document item before they were input into the first model, which can improve the efficiency of model training.
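  • The filing does not fix a concrete encoder architecture; a minimal PyTorch sketch of such a first model, assuming a GRU over token-id sequences whose final hidden state is used as the vector, could look as follows (vocabulary size and dimensions are illustrative assumptions).

```python
import torch
import torch.nn as nn

class FirstModelEncoder(nn.Module):
    """Encodes a token-id sequence (query item or document item) into one vector."""
    def __init__(self, vocab_size=30000, embed_dim=128, hidden_dim=256):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.rnn = nn.GRU(embed_dim, hidden_dim, batch_first=True)   # RNN variant

    def forward(self, token_ids):              # token_ids: (batch, seq_len)
        emb = self.embedding(token_ids)        # (batch, seq_len, embed_dim)
        _, h_n = self.rnn(emb)                 # final hidden state: (1, batch, hidden_dim)
        return h_n.squeeze(0)                  # (batch, hidden_dim)

encoder = FirstModelEncoder()
query_vec = encoder(torch.randint(0, 30000, (2, 12)))   # vectors for query items
doc_vec = encoder(torch.randint(0, 30000, (2, 60)))     # vectors for document items
```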
  • S103 When the sample to be trained is a positive sample, input the vector representing the sample to be trained into the second model to generate a vector representing the adversarial sample of the sample to be trained.
  • Whether the sample to be trained is a positive sample can be determined from the value of its label item: when the label of the sample to be trained is 1, the sample is a positive sample; when the label is 0, the sample is a negative sample. The category of a sample to be trained is judged in the same way in what follows.
  • Inputting the vector representing the sample to be trained into the second model to generate a vector representing the adversarial sample of the sample to be trained includes: merging the vector of the query item of the sample to be trained and the vector of the document item of the sample to be trained, and then inputting the merged vector into the second model to generate a vector representing the adversarial document corresponding to the document item of the sample, thereby obtaining the vector representing the adversarial sample of the sample to be trained. The vector representing the adversarial sample includes the vector corresponding to the query item of the sample to be trained and the vector of the adversarial document corresponding to the document item of the sample to be trained.
  • The second model includes a Variational Encoder-Decoder (VED) model. It should be noted that, besides the VED model, a Generative Adversarial Network (GAN) model or a Generative Pre-Training (GPT)-series model from natural language processing can also be used as the generator of adversarial samples; however, because the GAN and GPT-series models themselves have large computation and parameter counts and are harder to train, the VED model is the preferred model when the documents of the training samples are relatively short.
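  • The filing specifies only that the second model is a variational encoder-decoder; a compact PyTorch sketch of such a generator, assuming the merged query/document vectors as input and an adversarial document vector as output, might look like this (all layer sizes and the reparameterisation choice are assumptions).

```python
import torch
import torch.nn as nn

class VEDGenerator(nn.Module):
    """Variational encoder-decoder sketch: (query vector, document vector) in,
    adversarial document vector out."""
    def __init__(self, in_dim=512, latent_dim=64, out_dim=256):
        super().__init__()
        self.enc = nn.Linear(in_dim, 128)
        self.mu = nn.Linear(128, latent_dim)
        self.logvar = nn.Linear(128, latent_dim)
        self.dec = nn.Sequential(nn.Linear(latent_dim, 128), nn.ReLU(),
                                 nn.Linear(128, out_dim))

    def forward(self, query_vec, doc_vec):
        x = torch.relu(self.enc(torch.cat([query_vec, doc_vec], dim=-1)))  # merge
        mu, logvar = self.mu(x), self.logvar(x)
        z = mu + torch.randn_like(mu) * torch.exp(0.5 * logvar)  # reparameterisation
        return self.dec(z)                     # vector of the adversarial document

gen = VEDGenerator()
adv_doc_vec = gen(torch.randn(2, 256), torch.randn(2, 256))     # (2, 256)
```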
  • S104 Input the vector representing the adversarial sample of the sample to be trained into the third model to obtain an output value.
  • The third model includes a Deep Neural Network (DNN) model; the output value is a vector whose dimension is greater than or equal to 2. This application does not limit the specific dimension of this vector; setting it as a two-dimensional vector is the preferred way in this application.
  • In particular, when the vector of the adversarial sample of the sample to be trained is input into the third model, the adversarial sample is marked as a negative sample, so as to improve the quality of the negative samples in the data set and thereby improve the efficiency of model training.
  • In particular, when the sample to be trained is a negative sample, the vector representing the sample to be trained is input into the third model to obtain an output value.
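  • As a minimal sketch of such a third model, the following assumes a small feed-forward network that maps the concatenated query/document (or query/adversarial-document) vectors to a two-dimensional output value; the layer sizes are assumptions.

```python
import torch
import torch.nn as nn

class ThirdModelDNN(nn.Module):
    """Feed-forward scorer sketch: concatenated pair vector -> two-dimensional output value."""
    def __init__(self, in_dim=512, hidden_dim=128, out_dim=2):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(in_dim, hidden_dim), nn.ReLU(),
                                 nn.Linear(hidden_dim, out_dim))

    def forward(self, query_vec, doc_or_adv_vec):
        return self.net(torch.cat([query_vec, doc_or_adv_vec], dim=-1))

dnn = ThirdModelDNN()
output_value = dnn(torch.randn(2, 256), torch.randn(2, 256))    # shape (2, 2)
```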
  • S105 Determine the sub-loss value of the sample to be trained according to the output value, and calculate the sum of the sub-loss values of all samples in the data set to be processed to obtain a total loss value.
  • It should be noted that the values output by the third model are collectively referred to as output values, and the sub-loss values determined from these output values are collectively referred to as the sub-loss values of the samples to be trained, with no special distinction as to whether a sample in the data set has undergone adversarial-sample generation.
  • Figure 2 is a schematic flowchart of a loss value calculation method disclosed in an embodiment of the present application. As shown in the figure, the above method includes:
  • S201 Use the vector corresponding to the output value as the first vector, and perform one-hot encoding on the label of the sample to be trained to obtain the second vector.
  • According to the description of step S104, the vector corresponding to the output value is preferably a two-dimensional vector, and it is taken as the first vector. One-hot encoding the label of the sample to be trained yields a two-dimensional vector with the same dimension as the vector corresponding to the output value, which is taken as the second vector. One-hot encoding is used to optimize the distance between discrete features; since it is a common encoding method, its specific principle is not repeated here. In the specific implementation, when the sample to be trained is a positive sample, that is, the label is 1, one-hot encoding yields the vector [1, 0]; when the sample to be trained is a negative sample, that is, the label is 0, one-hot encoding yields the vector [0, 1].
  • S202 Multiply and then add the values of the same dimension in the first vector and the second vector to obtain the sub-loss value of the sample to be trained.
  • The first vector and the second vector have the same dimension and are preferably two-dimensional: the data of the first dimension of the two vectors are multiplied to obtain a first result, the data of the second dimension of the two vectors are multiplied to obtain a second result, and the first result plus the second result is the sub-loss value of the sample to be trained.
  • In particular, in the specific process, after the vector representing the adversarial sample of the sample to be trained is input into the DNN model, a two-dimensional prediction vector is first obtained; this prediction vector is then passed through a softmax (a logistic-regression-style) layer, which maps each value of the two-dimensional prediction vector to a number greater than 0 and less than 1, giving the vector corresponding to the output value, that is, the first vector. For example, if the two-dimensional prediction vector first output by the DNN model is [1, 1], it is converted to [0.5, 0.5] after the softmax layer.
  • Multiplying the data of the same dimension of the first vector and the second vector and then summing the results to obtain the sub-loss value of the sample to be trained includes: first taking the logarithm, preferably the base-10 logarithm, of each value of the first vector; multiplying this logarithm vector with the data of the same dimension of the second vector; summing the results; and taking the opposite (negated) value of the sum as the sub-loss value of the sample to be trained.
  • For example, if the sample to be trained is a positive sample, the vector obtained by one-hot encoding its label, that is, the second vector, is [1, 0]; if the two-dimensional prediction vector first output by the DNN model for the sample is [1, 1], it is converted by the softmax layer to [0.5, 0.5], which is the first vector. Taking the base-10 logarithm of the first vector gives the vector [log 0.5, log 0.5], which has the same dimension as the second vector [1, 0]; multiplying dimension by dimension gives 1 × log 0.5 for the first dimension and 0 × log 0.5 for the second dimension, so the sub-loss value of the sample to be trained is -(1 × log 0.5 + 0 × log 0.5) ≈ 0.301.
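  • A two-line check of this worked example (base-10 logarithms, as preferred in the text):

```python
import math

first = [0.5, 0.5]    # softmax output of the DNN model (the first vector)
second = [1, 0]       # one-hot encoded label of a positive sample (the second vector)
sub_loss = -sum(o * math.log10(p) for p, o in zip(first, second))
print(round(sub_loss, 3))   # 0.301
```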
  • S203 Calculate the sum of the sub-loss values of all samples in the data set to be processed to obtain a total loss value.
  • the total loss value is obtained by summing the sub-loss values of the N samples in the above-mentioned data set to be processed.
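  • Based on this description and on the per-sample sub-loss defined above (the formula image in the original filing is not reproduced here), the total loss value can be written, as a reconstruction, as

    L = -\sum_{i=1}^{N} \sum_{k=1}^{2} \big(ll_i\big)_k \, \log\!\big(\mathrm{softmax}(y_i)\big)_k

    where y_i is the two-dimensional prediction vector output by the DNN model for the i-th sample, (·)_k denotes taking the k-th dimension of a vector, and ll_i is the one-hot encoding of the original label l_i, with ll_i = [1, 0] when l_i = 1 and ll_i = [0, 1] when l_i = 0.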
  • In this embodiment, training is preferably performed with the Adam optimizer and the pytorch framework, iteratively updating the model parameters. In two adjacent training iterations, if the absolute value of the difference between the total loss values is less than the first threshold, it can be determined that the first model, the second model, and the third model converge; the first threshold is a number greater than 0.
  • In general, the first threshold is set to 0.01; to improve the effect of model training it can also be set to a value smaller than 0.01, such as 0.001, which this application does not restrict.
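  • A schematic training loop implementing this convergence test is sketched below; Adam and pytorch are named in the filing, while the models, the data pipeline and the compute_total_loss helper are assumptions of the example.

```python
import torch

# first_model, second_model, third_model, dataset and compute_total_loss are
# assumed to be defined elsewhere; compute_total_loss sums the sub-loss values
# of all N samples in one pass over the data set.
params = (list(first_model.parameters()) + list(second_model.parameters())
          + list(third_model.parameters()))
optimizer = torch.optim.Adam(params)

threshold = 0.01            # the first threshold (0.001 is also possible)
prev_total = None
while True:
    optimizer.zero_grad()
    total_loss = compute_total_loss(dataset)    # sum of per-sample sub-loss values
    total_loss.backward()
    optimizer.step()
    if prev_total is not None and abs(prev_total - total_loss.item()) < threshold:
        break                                   # models are considered converged
    prev_total = total_loss.item()
```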
  • FIG. 3 is a schematic flowchart of another model training method disclosed in an embodiment of the present application. As shown in FIG. 3, the above method includes:
  • S302 Obtain a sample to be trained from the foregoing data set to be processed, and use the first model to obtain a vector representing the foregoing sample to be trained.
  • step 301 and step 302 have been explained in the foregoing, and will not be repeated here.
  • S303 When the sample to be trained is a positive sample, draw a reference value from a random variable that obeys a Bernoulli distribution.
  • Since only positive samples need to undergo adversarial-sample generation, a reference value is drawn from a random variable obeying a Bernoulli distribution when the sample to be trained is a positive sample. The Bernoulli distribution is a discrete probability distribution: if a random variable obeys a Bernoulli distribution with parameter P, the random variable takes the value 1 with probability P and the value 0 with probability 1 - P. In the embodiment of the present application, the parameter P of the Bernoulli distribution is less than a second threshold, and the second threshold is a number greater than 0 and less than 1.
  • S304 When the reference value is 1, input the vector representing the sample to be trained into the second model to generate a vector representing the adversarial sample of the sample to be trained.
  • Since the random variable takes the value 1 with probability P, the positive sample undergoes adversarial-sample generation when the reference value is 1; for any single positive sample, the probability that adversarial-sample generation is needed is therefore P, which, for the set of positive samples of the data set to be processed, is equivalent to drawing 100·P% of the positive samples for adversarial-sample generation. Correspondingly, when the reference value is 0, the vector representing the sample to be trained is input into the third model to obtain an output value.
  • Preferably, the random variable obeys a Bernoulli distribution with parameter 0.5; then, when reference values are drawn for positive samples, any positive sample has a probability of 0.5 of undergoing adversarial-sample generation, which for the set of positive samples of the data set is equivalent to randomly drawing half of the positive samples for adversarial-sample generation. Setting the parameter of the Bernoulli distribution to 0.5 keeps the difficulty of model training moderate and improves the efficiency of model training.
  • It should be noted that, besides drawing a reference value from a Bernoulli distribution and using it to decide whether a positive sample undergoes adversarial-sample generation, other probability distributions can also be used, as long as a condition on the reference value is set according to the experimental requirements. For example, a reference value can be drawn from a standard normal distribution and adversarial-sample generation performed when the value is greater than 0, which likewise subjects half of the positive samples to adversarial-sample generation; or a value can be drawn from a random variable uniformly distributed on [0, 1], performing adversarial-sample generation when the value is between 0.3 and 1 (equivalent to randomly drawing 70% of the positive samples) or when it is between 0.8 and 1 (equivalent to randomly drawing 20% of the positive samples), so as to control the difficulty of model training.
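  • A small sketch of this sampling decision is given below; torch.bernoulli is one way to draw the reference value, and the parameter is the 0.5 preferred in the text.

```python
import torch

P = 0.5   # parameter of the Bernoulli distribution (preferred value in the text)

def reference_value(p=P):
    """Return 1 with probability p and 0 with probability 1 - p."""
    return int(torch.bernoulli(torch.tensor(p)).item())

# For a positive sample: generate an adversarial sample only when the value is 1.
if reference_value() == 1:
    pass   # input the sample vector into the second model (adversarial generation)
else:
    pass   # input the sample vector directly into the third model
```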
  • S306 Determine the sub-loss value of the sample to be trained according to the output value, and calculate the sum of the sub-loss values of all samples in the data set to be processed to obtain a total loss value.
  • step 305, step 306, and step 307 have been explained in the foregoing, and will not be repeated here.
  • In a possible implementation, using the first model to obtain the vector representing the sample to be trained includes: inputting the query item and the document item into the recurrent neural network model to obtain the vector corresponding to the query item as a third vector, and the vector corresponding to the document item as a fourth vector.
  • In a possible implementation, inputting the vector representing the sample to be trained into the third model to obtain an output value includes: merging the third vector and the fourth vector to obtain a fifth vector, and inputting the fifth vector into the deep convolution model, the resulting vector being the output value.
  • In a possible implementation, the vector representing the adversarial sample of the sample to be trained includes the third vector and a vector representing the adversarial document corresponding to the sample to be trained.
  • In a possible implementation, inputting the vector representing the adversarial sample of the sample to be trained into the third model to obtain an output value includes: merging the third vector and the vector representing the adversarial document corresponding to the sample to be trained to obtain a sixth vector, and using the vector obtained by inputting the sixth vector into the third model as the output value.
  • FIG. 4 is a schematic flowchart of yet another model training method disclosed in an embodiment of this application. As shown in the figure, the method includes the following.
  • First, the data set needed for model training is constructed. The samples in the data set include positive samples and negative samples, and each sample includes a query item, a document item, and a label item; the label item indicates the category of the sample: when the label is 1 the sample is a positive sample, and when the label is 0 the sample is a negative sample.
  • Then word segmentation, stop word removal, character-form conversion, and filtering are performed on the data set to obtain a data set with a total of N samples, as shown in step 401, step 402, step 403, and step 404 in FIG. 4.
  • After this processing, the information contained in the query items and document items after step 404 is effective information compared with that contained in the query items and document items before step 401, which benefits the training of the model.
  • For each sample in the data set, only one of three processing paths is taken; which path is taken depends on the category of the sample, that is, whether the sample is a positive sample or a negative sample, and on how many positive samples in the set of positive samples are to undergo adversarial-sample generation. The specific flow for each sample in the data set is as follows.
  • The query item and the document item are input into the RNN model to obtain the vector corresponding to the query item and the vector corresponding to the document item, respectively, as shown in step 405 in FIG. 4.
  • The label of the sample is then examined to determine the category of the sample, as shown in step 406 in FIG. 4. When the sample is a negative sample, that is, the label is 0, no further processing is applied: the vector corresponding to the query item and the vector corresponding to the document item are input directly into the DNN model to obtain the output vector corresponding to the sample.
  • When the sample is a positive sample, that is, the label is 1, the value of a random variable determines whether adversarial-sample generation is performed on the positive sample, so as to control the proportion of adversarial samples generated and thereby the difficulty of model training; in this embodiment, preferably half of the positive samples undergo adversarial-sample generation, that is, a reference value is drawn from a random variable obeying a Bernoulli distribution with parameter 0.5, as shown in step 407 in FIG. 4.
  • When the reference value is 1, the vector corresponding to the query item and the vector corresponding to the document item are input into the VED model to obtain the vector corresponding to the adversarial document of the positive sample, as shown in step 408 in FIG. 4; the vector corresponding to the query item and the vector corresponding to the adversarial document are then input into the DNN model, the positive sample is marked as a negative sample, and the output vector corresponding to the sample is obtained.
  • When the reference value is 0, no further processing is applied: the vector corresponding to the query item and the vector corresponding to the document item are input directly into the DNN model to obtain the output vector corresponding to the sample.
  • Next, according to the output vector, the cross-entropy loss function is used to calculate the sub-loss value corresponding to the sample, and the total loss value of one training pass is calculated from the sub-loss values; in this embodiment, training preferably uses the Adam optimizer and the pytorch framework to iteratively update the model parameters until the models converge.
  • It should be noted that after the first pass over the samples of the data set, the parameters of the models are updated accordingly. In subsequent training, the proportion of positive samples subjected to adversarial-sample generation can be chosen with the same method as in the first pass, or the proportion can be adjusted according to the experimental requirements; for example, 50% of the positive samples in the set of positive samples can undergo adversarial-sample generation in the first pass and 60% in the second pass, with the proportion increased step by step in subsequent passes so as to gradually raise the difficulty of model training (see the sketch below).
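  • As an illustration of such a schedule, the per-pass adversarial proportion could be computed as follows; the 50% and 60% values come from the text, while the step size and the cap are assumptions of the example.

```python
def adversarial_ratio(epoch, start=0.5, step=0.1, cap=1.0):
    """Proportion of positive samples to perturb in a given training pass.

    epoch is 1-based: 50% in the first pass, 60% in the second, then rising by
    an assumed 10% per pass up to an assumed cap of 100%.
    """
    return min(start + step * (epoch - 1), cap)

print([adversarial_ratio(e) for e in range(1, 7)])   # [0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
```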
  • In summary, the model training method proposed in this application generates adversarial samples from the positive samples in the data set, which improves the similarity between the generated adversarial documents and the original documents of the positive samples; using the adversarial samples as negative samples of the data set improves the quality of the negative samples in model training; and training the models on a data set whose negative samples include adversarial samples can, on the one hand, increase the difficulty of model training and thereby improve the update efficiency of the model parameters and, on the other hand, improve the models' ability to handle boundary data and thereby improve the robustness of the models.
  • During model training, the parameters of the VED model are also updated; a fully trained VED model can be extracted on its own and used directly for adversarial-sample generation on given positive samples, thereby improving the efficiency of model training and shortening the life cycle of a project.
  • FIG. 5 is a schematic structural diagram of a model training apparatus disclosed in an embodiment of the application.
  • The model training apparatus may include an acquiring unit 501, a processing unit 502, a calculation unit 503, and a determining unit 504, where each unit is described as follows:
  • the acquiring unit 501 is configured to acquire a data set to be processed, acquire a sample to be trained from the data set to be processed, and use the first model to obtain a vector representing the sample to be trained, where the samples contained in the data set to be processed include positive samples and negative samples;
  • the processing unit 502 is configured to, when the sample to be trained is a positive sample, input the vector representing the sample to be trained into the second model to generate a vector representing the adversarial sample of the sample to be trained, and to input the vector representing the adversarial sample of the sample to be trained into the third model to obtain an output value;
  • the calculation unit 503 is configured to determine the sub-loss value of the sample to be trained according to the output value, and to sum the sub-loss values of all samples in the data set to be processed to obtain a total loss value;
  • the determining unit 504 is configured to determine that the first model, the second model, and the third model converge when the difference between the total loss values obtained in two consecutive iterations is less than the threshold.
  • the foregoing device further includes:
  • the marking unit 505 is configured to mark the above-mentioned adversarial sample as a negative sample of the above-mentioned data set to be processed.
  • The processing unit 502 is further configured to, when the sample to be trained is a positive sample, draw a reference value from a random variable obeying a Bernoulli distribution whose parameter is less than a second threshold, and, when the reference value is 1, to input the vector representing the sample to be trained into the second model to generate a vector representing the adversarial sample of the sample to be trained.
  • the processing unit 502 is further configured to input the vector representing the sample to be trained into the third model when the reference value is 0 to obtain an output value.
  • the processing unit 502 is further configured to input the vector used to represent the sample to be trained into the third model when the sample to be trained is a negative sample to obtain an output value.
  • the sample to be trained includes a query item and a document item;
  • the vector representing the sample to be trained includes: a vector corresponding to the query item and a vector corresponding to the document item;
  • the first model includes a recurrent neural network model, the second model includes a variational encoder-decoder model, and the third model includes a deep neural network model.
  • The processing unit 502 is further configured to input the query item and the document item into the recurrent neural network model to obtain the vector corresponding to the query item as the third vector, and to obtain the vector corresponding to the document item as the fourth vector.
  • the processing unit 502 is further configured to merge the third vector and the fourth vector to obtain a fifth vector, and input the fifth vector into the deep convolution model to obtain the vector as an output value.
  • the vector representing the adversarial sample of the sample to be trained includes the third vector and a vector representing the adversarial document corresponding to the sample to be trained.
  • The processing unit 502 is further configured to merge the third vector and the vector representing the adversarial document corresponding to the sample to be trained to obtain a sixth vector, and to use the vector obtained by inputting the sixth vector into the third model as the output value.
  • the foregoing device further includes:
  • the encoding unit 506 is configured to use the vector corresponding to the output value as a first vector, and perform one-hot encoding on the label of the sample to be trained to obtain a second vector, and the first vector and the second vector have the same vector dimension;
  • The calculation unit 503 is further configured to multiply and then add the values of the same dimension in the first vector and the second vector to obtain the sub-loss value of the sample to be trained, and to sum the sub-loss values of all samples in the data set to be processed to obtain the total loss value.
  • In summary, the model training method proposed in this application generates adversarial samples from the positive samples in the data set, which improves the similarity between the generated adversarial documents and the original documents of the positive samples; using the adversarial samples as negative samples of the data set improves the quality of the negative samples in model training; and training the models on a data set whose negative samples include adversarial samples can, on the one hand, increase the difficulty of model training and thereby improve the update efficiency of the model parameters and, on the other hand, improve the models' ability to handle boundary data and thereby improve the robustness of the models.
  • FIG. 6 is a schematic structural diagram of a server disclosed in an embodiment of the present application.
  • the foregoing server 60 may include a memory 601 and a processor 602. Further optionally, it may also include a communication interface 603 and a bus 604, where the memory 601, the processor 602, and the communication interface 603 implement communication connections between each other through the bus 604.
  • the communication interface 603 is used for data interaction with the spatiotemporal data query device.
  • the memory 601 is used to provide storage space, and the storage space can store data such as an operating system and a computer program.
  • the memory 601 includes but is not limited to random access memory (RAM), read-only memory (ROM), erasable programmable read-only memory (EPROM), or Portable read-only memory (compact disc read-only memory, CD-ROM).
  • The processor 602 is a module that performs arithmetic and logical operations, and can be one of, or a combination of, processing modules such as a central processing unit (CPU), a graphics processing unit (GPU), or a microprocessor unit (MPU).
  • A computer program is stored in the memory 601, and the processor 602 calls the computer program stored in the memory 601 to perform the following operations: obtaining a data set to be processed, the samples of which include positive samples and negative samples; obtaining a sample to be trained from the data set to be processed and using the first model to obtain a vector representing the sample to be trained; when the sample to be trained is a positive sample, inputting the vector representing the sample to be trained into the second model to generate a vector representing the adversarial sample of the sample to be trained; inputting the vector representing the adversarial sample into the third model to obtain an output value; determining the sub-loss value of the sample to be trained according to the output value and summing the sub-loss values of all samples in the data set to be processed to obtain a total loss value; and determining that the first model, the second model and the third model converge when the difference between the total loss values obtained in two consecutive iterations is less than a threshold.
  • It should be noted that the specific implementation of the server 60 may also refer to the corresponding descriptions of the method embodiments shown in FIG. 2, FIG. 3, and FIG. 4.
  • The embodiments of the present application also provide a computer-readable storage medium storing a computer program; when the computer program runs on one or more processors, it can implement the model training methods shown in FIG. 1, FIG. 2, FIG. 3, and FIG. 4.
  • The storage medium involved in this application, such as a computer-readable storage medium, may be non-volatile or volatile.
  • The embodiments of the present application also provide a computer program product; the computer program product includes program instructions which, when executed by a processor, cause the processor to perform some or all of the steps of the methods in the above embodiments, which are not repeated here.
  • In summary, the model training method proposed in this application generates adversarial samples from the positive samples in the data set, which improves the similarity between the generated adversarial documents and the original documents of the positive samples; using the adversarial samples as negative samples of the data set improves the quality of the negative samples in model training; and training the models on a data set whose negative samples include adversarial samples can, on the one hand, increase the difficulty of model training and thereby improve the update efficiency of the model parameters and, on the other hand, improve the models' ability to handle boundary data and thereby improve the robustness of the models.
  • A person of ordinary skill in the art can understand that all or part of the processes in the above method embodiments can be completed by hardware related to a computer program; the computer program can be stored in a computer-readable storage medium and, when executed, may include the processes of the above method embodiments.
  • The aforementioned storage media include media that can store computer program code, such as a read-only memory (ROM), a random access memory (RAM), a magnetic disk, or an optical disc.

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • General Engineering & Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • Computational Linguistics (AREA)
  • Evolutionary Computation (AREA)
  • Molecular Biology (AREA)
  • General Health & Medical Sciences (AREA)
  • Software Systems (AREA)
  • Health & Medical Sciences (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Mathematical Physics (AREA)
  • Computing Systems (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Evolutionary Biology (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Databases & Information Systems (AREA)
  • Machine Translation (AREA)
  • Image Analysis (AREA)
  • Information Retrieval, Db Structures And Fs Structures Therefor (AREA)

Abstract

A model training method and related apparatus. The method includes: obtaining a data set to be processed (101); obtaining a sample to be trained from the data set to be processed and using a first model to obtain a vector representing the sample to be trained (102), where the samples contained in the data set to be processed include positive samples and negative samples; when the sample to be trained is a positive sample, inputting the vector representing the sample to be trained into a second model to generate a vector representing an adversarial sample of the sample to be trained (103); inputting the vector representing the adversarial sample of the sample to be trained into a third model to obtain an output value (104); determining a sub-loss value of the sample to be trained according to the output value, and summing the sub-loss values of all samples in the data set to be processed to obtain a total loss value (105); and, when the difference between the total loss values obtained in two consecutive iterations is less than a first threshold, determining that the first model, the second model and the third model converge (106). By improving how samples are generated during model training, the method raises the difficulty of model training and thereby enhances the robustness of the model.

Description

一种模型训练的方法及相关装置
本申请要求于2020年11月12日提交中国专利局、申请号为202011261109.2,发明名称为“一种模型训练的方法及相关装置”的中国专利申请的优先权,其全部内容通过引用结合在本申请中。
技术领域
本申请实施例涉及人工智能技术领域,具体涉及一种基于对抗的模型训练的方法及相关装置。
背景技术
文本匹配是信息检索领域的核心问题。文本匹配可以归结为查询项和文档的匹配,即通过文本匹配模型对查询项和文档给出匹配分数,匹配分数越高,查询项与文档的相关性越强。
发明人发现,基于BM25(一种用来评价搜索词和文档之间相关性的算法)的文本匹配模型只能在查询项和文档具有重复词的情况下对两者进行匹配;基于深度学习的文本匹配模型则可以将语义相似的词或词组进行匹配。深度学习模型本身参数量大,需要大量数据对模型进行充分训练,对上述基于深度学习的文本匹配模型构建样本进行训练时,正样本为用户真实点击的文档,负样本为所有文档中随机抽取的文档;但是,发明人意识到,随着上述负样本数量的增加,上述基于深度学习的文本匹配模型对查询项和相对应的真实文档之间的语义理解能力下降,对相似的文档给出的匹配分数差别大,即模型的鲁棒性降低。
发明内容
本申请实施例公开了一种模型训练的方法及相关装置,通过改进模型训练中样本的生成方法,提高模型训练的难度,从而增强模型的鲁棒性。
第一方面,本申请实例公开了一种模型训练的方法,包括:
获取待处理数据集,上述待处理数据集包含的样本包括正样本和负样本;
从上述待处理数据集中获取待训练样本,使用第一模型获得用于表示上述待训练样本的向量;
在上述待训练样本为正样本的情况下,将上述表示上述待训练样本的向量输入第二模型生成用于表示上述待训练样本的对抗样本的向量;
将上述用于表示上述待训练样本的对抗样本的向量输入第三模型,得到输出值;
根据上述输出值确定上述待训练样本的分损失值,计算上述待处理数据集中所有样本的分损失值求和,得到总损失值;
在前后两次获得的总损失值之间的差值小于阈值的情况下,确定上述第一模型,上述第二模型和上述第三模型收敛。
第二方面,本申请实施例公开了一种模型训练的装置,包括:
获取单元,用于获取待处理数据集,从上述待处理数据集中获取待训练样本,使用第一模型获得用于表示上述待训练样本的向量,上述待处理数据集包含的样本包括正样本和负样本;
处理单元,用于在上述待训练样本为正样本的情况下,将上述表示上述待训练样本的向量输入第二模型生成用于表示上述待训练样本的对抗样本的向量;将上述用于表示上述待训练样本的对抗样本的向量输入第三模型,得到输出值;
计算单元,用于根据上述输出值确定上述待训练样本的分损失值,计算上述待处理数据集中所有样本的分损失值求和,得到总损失值;
确定单元,用于在前后两次获得的总损失值之间的差值小于阈值的情况下,确定上述第一模型,上述第二模型和上述第三模型收敛。
第三方面,本申请实施例公开了一种服务器,包括:处理器和存储器,其中,上述存 储器中存储有计算机程序,上述处理器调用上述存储器中存储的计算机程序,用于执行以下方法:
获取待处理数据集,上述待处理数据集包含的样本包括正样本和负样本;
从上述待处理数据集中获取待训练样本,使用第一模型获得用于表示上述待训练样本的向量;
在上述待训练样本为正样本的情况下,将上述表示上述待训练样本的向量输入第二模型生成用于表示上述待训练样本的对抗样本的向量;
将上述用于表示上述待训练样本的对抗样本的向量输入第三模型,得到输出值;
根据上述输出值确定上述待训练样本的分损失值,计算上述待处理数据集中所有样本的分损失值求和,得到总损失值;
在前后两次获得的总损失值之间的差值小于阈值的情况下,确定上述第一模型,上述第二模型和上述第三模型收敛。
第四方面,本申请实施例公开了一种计算机可读存储介质,上述计算机可读存储介质中存储有计算机程序,当上述计算机程序在一个或多个处理器上运行时,执行以下方法:
获取待处理数据集,上述待处理数据集包含的样本包括正样本和负样本;
从上述待处理数据集中获取待训练样本,使用第一模型获得用于表示上述待训练样本的向量;
在上述待训练样本为正样本的情况下,将上述表示上述待训练样本的向量输入第二模型生成用于表示上述待训练样本的对抗样本的向量;
将上述用于表示上述待训练样本的对抗样本的向量输入第三模型,得到输出值;
根据上述输出值确定上述待训练样本的分损失值,计算上述待处理数据集中所有样本的分损失值求和,得到总损失值;
在前后两次获得的总损失值之间的差值小于阈值的情况下,确定上述第一模型,上述第二模型和上述第三模型收敛。
本申请一方面可以提高模型训练的难度,从而提升模型的参数的更新效率;另一方面可以提高模型对边界数据的处理能力,从而提高模型的鲁棒性。
附图说明
为了更清楚地说明本申请实施例或背景技术中的技术方案,下面将对本申请实施例或背景技术中所需要使用的附图作简单的介绍。
图1是本申请实施例公开的一种模型训练方法的流程示意图;
图2是本申请实施例公开的一种损失值计算方法的流程示意图;
图3是本申请实施例公开的另一种模型训练方法的流程示意图;
图4是本申请实施例公开的又一种模型训练方法的流程示意图;
图5是本申请实施例公开的一种模型训练的装置的结构示意图;
图6是本申请实施例公开的一种服务器的结构示意图。
具体实施方式
为了使本申请的目的、技术方案和优点更加清楚,下面将结合附图对本申请作进一步地描述。
本申请的说明书、权利要求书及附图中的术语“第一”和“第二”等仅用于区别不同对象,而不是用于描述特定顺序。此外,术语“包括”和“具有”以及它们的任何变形,意图在于覆盖不排他的包含。例如包含了一系列步骤或单元的过程、方法、系统、产品或设备等,没有限定于已列出的步骤或单元,而是可选地还包括没有列出的步骤或单元等,或可选地还包括对于这些过程、方法、产品或设备等固有的其它步骤或单元。
在本文中提及的“实施例”意味着,结合实施例描述的特定特征、结构或特性可以包 含在本申请的至少一个实施例中。在说明书中的各个位置出现上述短语并不一定均是指相同的实施例,也不是与其它实施例互斥的独立的或备选的实施例。本领域技术人员可以显式地和隐式地理解的是,本文所描述的实施例可以与其它实施例相结合。
在本申请中,“至少一个(项)”是指一个或者多个,“多个”是指两个或两个以上,“至少两个(项)”是指两个或三个及三个以上,“和/或”,用于描述关联对象的关联关系,表示可以存在三种关系,例如,“A和/或B”可以表示:只存在A,只存在B以及同时存在A和B三种情况,其中A,B可以是单数或者复数。字符“/”一般表示前后关联对象是一种“或”的关系。“以下至少一项(个)”或其类似表达,是指这些项中的任意组合。例如,a,b或c中的至少一项(个),可以表示:a,b,c,“a和b”,“a和c”,“b和c”,或“a和b和c”。
本申请的技术方案涉及人工智能和/或大数据技术领域,如可具体涉及神经网络技术,可应用于信息检索等场景。可选的,本申请涉及的数据如样本、输出值和/或损失值等可存储于数据库中,或者可以存储于区块链中,本申请不做限定。
本申请实施例提供了一种模型训练的方法,通过改进模型训练中样本的生成方法,提高模型训练的难度,从而增强模型的鲁棒性。为了更清楚地描述本申请的方案,接下来将结合本申请实施例中的附图对本申请实施例进行描述。
请参阅图1,图1是本申请实施例公开的一种模型训练方法的流程示意图。如图1所示,上述方法包括:
S101:获取待处理数据集。
对于基于深度学习的模型来说,需要大量的数据对模型进行训练。上述训练模型的数据的集合可以称为数据集,为了模型的训练更加有效,需要对数据集进行处理,得到对模型的训练有效的正样本和负样本,本申请的处理方法包括:
从微软的文档排序任务公开数据集获取初始数据,上述数据集可以表示为M={(q 1,s 1),(q 2,s 2),…,(q i,s i),…,(q n,s n)},其中,q i表示用户搜索的文本,即查询项;s i表示搜索引擎返回的结果,即文档列表,n表示上述数据集中查询项的个数。对于任意一个由搜索引擎返回的结果s i可以表示为:s i={(d i1,l i1),(d i2,l i2),…,(d ij,l ij),…,(d im,l im)},其中,d ij表示第i个查询项对应的第j个搜索结果,即文档项;l ij为标签项,l ij为1时,表示用户点击了该搜索结果,l ij为0时,表示用户未点击该搜索结果;m表示该搜索结果中文档项的个数。
由于上述数据集中数据量庞大,每个查询项对应的文档项中可能包含了冗余信息,需要对上述数据集进行处理,得到对模型训练有效的正样本和负样本,具体的处理步骤包括:
1、对数据集中的查询项和文档项分别进行分词。
在上述分词部分,可以采用的分词工具包括结巴分词、盘古分词等,可以采用结巴分词模型对上述查询项和文档项分别进行分词,结巴分词基于前缀词典实现高效的词图扫描,生成句子中汉字所有可能成词情况构成的有向无环图,再动态规划查找最大概率路径,找出基于词频的最大切分组合,由于上述结巴分词是一种非常典型的分词工具,具体原理这里不再赘述。
2、去掉上述分词的结果中的停用词。
常见的停用词表包括哈工大停用词表、百度停用词表、四川大学机器智能实验室停用词库等,可以首先采用哈工大停用词表对上述分词的结果进行初步过滤,再根据正则表达式以及人工筛选的方式,过滤掉上述分词结果中高频的网络词汇。由于上述过滤停用词的 方法是非常典型的处理步骤,这里不再赘述。
3、对上述去除停用词的结果进行字体转换。
其中,可以采用opencc工具包将文本中的繁体中文转化为简体,同时,将文本中的大写英文字母转换为小写英文字母。由于上述方法是非常典型的处理步骤,这里不再赘述。
4、对数据集中的样本进行过滤。
一般情况下,对于某个具体的查询项q i,搜索引擎返回的文档列表s i的文档数量较大,并且包含大量无用的结果,可以采用文本匹配模型过滤的方式对数据集进行过滤,筛选出数据集中质量较高的结果。对于任一查询项q i和与之对应的文档d ij,两者的匹配分数可以表示为:
r_ij = F(q_i, d_ij)
其中,r ij表示查询项q i和与之对应的文档d ij的匹配分数,F为文本匹配模型BM25,上述BM25是一种用来评价搜索词和文档之间相关性的算法,通过对查询项进行分词,对每个分词与文档项的相关度对应的值进行加权求和得到查询项与文档项之间的匹配分数,由于上述BM25算法是典型的算法,具体原理这里不再赘述;可以对上述匹配分数设置阈值,将匹配分数超过阈值的查询项和文档项保留,作为数据集的样本。
通过上述文本匹配模型BM25对数据集进行过滤后,每一个查询项q i对应的文档依然包含多个,但是相较于过滤之前,过滤之后的文档列表的文档数量少于过滤之前的文档列表的文档数量,并且,文档列表中的文档相对于查询项为质量较高的文档。
对于通过上述文本匹配模型BM25筛选出来的任一文档d ij,将与之对应的查询项q i和对应的用户点击结果l ij组成的三元组(q i,d ij,l ij)作为数据集中的一个样本,在l ij为1时,表示用户点击了该搜索结果,该样本为正样本,l ij为0时,表示用户未点击该搜索结果,该样本为负样本。过滤之后含有N个样本的数据集可以表示为:
M = {(q_1, d_1, l_1), (q_2, d_2, l_2), …, (q_i, d_i, l_i), …, (q_N, d_N, l_N)}
S102:从上述待处理数据集中获取待训练样本,使用第一模型获得用于表示上述待训练样本的向量。
其中,上述待训练样本为上述待处理数据集中任意一个样本;与步骤101中数据集的处理结果相对应,上述待训练样本包括查询项、文档项、标签项。
使用第一模型获得用于表示上述待训练样本的向量的具体实现方式为,将上述待训练样本的查询项和文档项输入第一模型,分别得到上述待训练样本的查询项对应的向量和上述待训练样本的文档项对应的向量,所以,用于表示上述待训练样本的向量包括上述查询项对应的向量和上述文档项对应的向量;上述第一模型包括循环神经网络(Recurrent Neural Network,RNN)模型,需要说明的是,除了上述RNN模型,还可以采用上述RNN模型变体模型:长短期记忆(Long short-term memory,LSTM)模型和门控循环单元(Gated recurrent unit,GRU)模型等;将上述查询项和文档项输入第一模型得到的用于表示上述待训练样本的向量相比未输入上述第一模型之前的查询项和文档项,包含了更多的信息,可以提高模型训练的效率。
S103:在上述待训练样本为正样本的情况下,将上述表示上述待训练样本的向量输入第二模型生成用于表示上述待训练样本的对抗样本的向量。
其中,可以通过标签项的取值来确定上述待训练样本是否为正样本,在上述待训练样 本为1的情况下,上述待训练样本为正样本;在上述待训练样本为0的情况下,上述待训练样本为负样本,后续对待训练样本的类别判断方法相同。
将上述表示上述待训练样本的向量输入第二模型生成用于表示上述待训练样本的对抗样本的向量,包括将上述待训练样本的查询项的向量和上述待训练样本的文档项的向量合并,再输入第二模型,生成用于表示上述待训练样本的文档项对应的对抗文档的向量,得到用于表示上述待训练样本的对抗样本的向量,其中,用于表示上述待训练样本的对抗样本的向量包括上述待训练样本的查询项对应的向量和上述用于表示上述待训练样本的的文档项对应的对抗文档的向量。
上述第二模型包括变分编解码模型(Variational Encoder-Decoder,VED),需要说明的是,除了上述VED模型用作对抗样本的生成器模型之外,还可以采用生成式对抗网络(Generative Adversarial Network,GAN)模型、自然语言处理领域的生成式的预训练(Generative Pre-Training,GPT)系列的模型,由于上述GAN模型、GPT系列的模型本身的计算量和参数量较大,训练难度较大,对于训练样本的文档相对较短的情况,上述VED模型为优选模型。
S104:将上述用于表示上述待训练样本的对抗样本的向量输入第三模型,得到输出值。
其中,上述第三模型包括深度神经网络(Deep Neural Networks,DNN)模型;上述输出值为一个维数大于或等于2的向量,本申请对上述向量具体的维数不作任何限定,将上述向量设定为二维向量为本申请的优选方式。
特别地,将上述待训练样本的对抗样本的向量输入第三模型时,将上述待训练样本的对抗样本的向量标记为负样本,以提高数据集中负样本的质量,从而提高模型训练的效率。
特别地,在上述待训练的样本为负样本的情况下,将上述用于表示上述待训练样本的向量输入第三模型,得到输出值。
S105:根据上述输出值确定上述待训练样本的分损失值,计算上述待处理数据集中所有样本的分损失值求和,得到总损失值。
需要说明的是,由第三模型输出的值都统一称为输出值,根据上述输出值确定的分损失值都统一称为上述待训练样本的分损失值,不特殊区分上述数据集中样本是否经过对抗样本的生成处理。
上述步骤的具体实现过程请参阅图2,图2是本申请实施例公开的一种损失值计算方法的流程示意图,如图所示,上述方法包括:
S201:将输出值对应的向量作为第一向量,将待训练样本的标签进行独热编码得到第二向量。
根据上述步骤104的描述,输出值对应的向量优选为一个二维向量,即第一向量;将待训练样本的标签进行独热编码,可以得到一个与输出值对应的向量相同维数的二维向量,即第二向量;独热编码用于优化离散型特征之间的距离,由于上述独热编码是一种常见的编码方式,具体原理不再赘述;具体实现过程中,在上述待训练样本为正样本,即标签为1的情况下,经过独热编码得到向量[1,0],在上述待训练样本为负样本,即标签为0的情况下,经过独热编码得到向量[0,1]。
S202:将上述第一向量与上述第二向量中相同维数的值相乘再相加,得到上述待训练样本的分损失值。
由于上述第一向量与上述第二向量的维数相同,且优选为二维向量,将上述两个向量第一维的数据相乘,得到第一结果,将上述两个向量的第二维的数据相乘得到第二结果,上述第一结果加上上述第二结果即为上述待训练样本的分损失值。
特别地,具体过程中上述方法包括,在上述用于表示上述待训练样本的对抗样本的向量输入DNN模型后,首先得到一个二维预测向量,再将上述二维预测向量输入softmax(一 种逻辑回归模型)层将上述二维预测向量中的每个值映射为大于0小于1之间的数,作为上述输出值对应的向量,即上述第一向量。例如上述DNN模型首先输出的二维预测向量为[1,1],经过softmax层之后,上述向量被转换为[0.5,0.5]。
上述将上述第一向量和上述第二向量相同维数的数据相乘,再对上述结果求和得到上述待训练样本的分损失值包括,先将上述第二向量的每个维数对应的数值取对数,优选情况下,取以10为底数的对数,将上述取对数后的向量与上述第一向量相同维数的数据相乘,最后对上述结果求和,将上述求和结果的相反数作为上述待训练样本的分损失值。
例如,上述待训练样本为正样本,那么,对标签进行独热编码后得到的向量为[1,0],上述待训练样本经过上述DNN模型首先输出的二维预测向量为[1,1],经过softmax层的处理之后,上述向量被转换为[0.5,0.5],即上述第二向量为[0.5,0.5],上述第一向量为[1,1];首先对上述第二向量取以10为底数的对数,得到向量[log0.5,log0.5],上述向量[log0.5,log0.5]与上述第一向量[1,1]的维数相同,在上述步骤中,向量中第一维的数据相乘即为1乘以log0.5,向量中第二维的数据相乘即为1乘以log0.5,那么,最后上述待训练样本的分损失值为-(1*log0.5+1*log0.5)。
S203:计算上述待处理数据集中所有样本的分损失值求和,得到总损失值。
将上述数据集中所有样本对应的分损失相加,即为一次训练中得到的总损失值。上述总损失值的计算公式可以为:
L = -\sum_{i=1}^{N} \sum_{k=1}^{2} (ll_i)_k \, \log\big(\mathrm{softmax}(y_i)\big)_k
其中,L表示一次训练中,数据集中所有样本的损失值之和,即数据集的总损失值;N表示数据集中样本总数;y i是上述待训练样本输入上述DNN模型得到二维预测向量;i表示上述待训练样本本为待处理数据集中第i个样本;k表示取向量的第k维的数据,而不是常规的取次方运算,例如,对于向量[1,2,3],[1,2,3] 1表示取向量的第一维的值1,而[1,2,3] 2表示取向量的第二维的值2;ll i是原始标签l i通过独热编码得到二维向量,在l i=1的情况下,ll i=[1,0];在l i=0的情况下,ll i=[0,1]。例如,经过softmax层之后得到的向量为[0.2,0.8],上述向量ll i=[1,0],那么上述待训练样本的分损失值为-(1*log0.2+0*log0.8)。
将上述待处理数据集中N个样本的分损失值求和得到总损失值,本申请实施例中,优选利用亚当优化器(Adam optimizer)和pytorch(一种机器学习库)框架进行训练,迭代更新模型参数,在相邻的两次训练中,上述总损失值之间的差值的绝对值小于第一阈值,即可确定上述第一模型、第二模型、第三模型收敛;上述第一阈值为大于0的数,一般情况下,上述第一阈值取0.01,为了提升模型训练的效果,也可以取比0.01更小的值,比如0.001等,本申请不作任何限制。
在一种可能的实现方式中,对待处理数据集中的正样本进行生成对抗样本处理时,只从正样本集合中随机抽取一部分进行生成对抗样本处理,这样既保证了模型训练的负样本为高质量样本,又可以控制模型训练的难度。请参阅图3,图3是本申请实施例公开的另一种模型训练方法的流程示意图,如图3所示,上述方法包括:
S301:获取待处理数据集。
S302:从上述待处理数据集中获取待训练样本,使用第一模型获得用于表示上述待训练样本的向量。
上述步骤301和步骤302在前文中已经给出解释,这里不再赘述。
S303:在上述待训练样本为正样本的情况下,对服从伯努利分布的随机变量抽取参考值。
由于只需要对正样本进行生成对抗样本处理,那么在上述待训练样本为正样本的情况下,对服从伯努利分布的随机变量抽取参考值。其中,伯努利分布是一种离散型概率分布,如果随机变量服从参数为P的伯努利分布,那么,随机变量分别以概率P取1为值,以概率1-P取0为值;本申请实施例中,伯努利分布服从的参数P小于第二阈值,上述第二阈值为大于0且小于1的数。
S304:在上述参考值为1的情况下,将上述表示上述待训练样本的向量输入第二模型生成用于表示上述待训练样本的对抗样本的向量。
由于上述随机变量分别以概率P取1为值,在上述参考值为1的情况下,将上述表示上述待训练样本的向量输入第二模型生成用于表示上述待训练样本的对抗样本的向量,即在上述参考值为1的情况下,对上述正样本进行生成对抗样本处理;那么,对于任意一个正样本来说,需要进行生成对抗样本处理的概率为P,对于上述待处理数据集的正样本的集合来说,相当于抽取100*P%的正样本进行生成对抗样本处理;相对应的,在上述参考是为0的情况下,将上述表示上述待训练样本的向量输入第三模型,得到输出值。具体步骤在前文已经解释,这里不再赘述。
优选情况下,上述随机变量服从参数为0.5的伯努利分布,那么在上述待训练样本为正样本的情况下,对服从参数为0.5的伯努利分布抽取参考值时,对于任意一个正样本,有0.5的概率需要进行生成对抗样本处理,对于上述待处理数据集的正样本的集合来说,相当于从上述正样本的集合中随机抽取一半的正样本进行生成对抗样本的处理。将伯努利分布服从的参数设置为0.5,可以让模型训练的难度适中,提高模型训练的效率。
需要说明的是,除了上述通过从伯努利分布中抽取参考值,再通过参考值确定是否对正样本进行对抗样本生成处理之外,也可以采用其他的概率分布,只需要根据实验要求对参考值设置条件即可。比如从标准正态分布中抽取参考值,在参考值大于0的情况下对正样本进行生成对抗样本处理,这样的方法同样可以实现将待处理数据集中一半的正样本进行对抗样本生成处理;或者从在0到1上服从均匀分布的随机变量中抽取参数值,在上述参考值大于或等于0.3且小于或等于1的情况下,对正样本进行对抗样本生成处理,这样就相当于从正样本的集合中随机抽取70%的正样本进行对抗样本生成处理,也可以在参考值大于或等于0.8且小于或等于1的情况下,对正样本进行对抗样本生成处理,这样就相当于从正样本的集合中随机抽取20%的正样本进行对抗样本生成处理,以此达到控制模型训练难度的目的。
S305:将上述用于表示上述待训练样本的对抗样本的向量输入第三模型,得到输出值
S306:根据上述输出值确定上述待训练样本的分损失值,计算上述待处理数据集中所有样本的分损失值求和,得到总损失值。
S307:在前后两次获得的总损失值之间的差值小于阈值的情况下,确定上述第一模型,上述第二模型和上述第三模型收敛。
上述步骤305、步骤306、步骤307在前文中已经给出解释,这里不再赘述。
在一种可能的实施方式中,上述用第一模型获得用于表示上述待训练样本的向量,包括:将上述查询项和上述文档项输入循环神经网络模型得到上述查询项对应的向量作为第三向量,以及得到上述文档项对应的向量作为第四向量。
在一种可能的实施方式中,上述将上述用于表示上述待训练样本的向量输入第三模型,得到输出值,包括:将上述第三向量和上述第四向量合并得到第五向量,将上述第五向量输入深度卷积模型得到向量作为输出值。
在一种可能的实施方式中,上述用于表示上述待训练样本的对抗样本的向量包括上述第三向量,以及表示上述待训练样本对应的对抗文档的向量。
在一种可能的实施方式中,上述将上述用于表示上述待训练样本的对抗样本的向量输 入第三模型,得到输出值,包括:将上述第三向量和上述表示上述待训练样本对应的对抗文档的向量合并得到第六向量,将上述第六向量输入第三模型得到的向量作为输出值。
以上对本申请实施提供的方法中各个步骤进行了详细的解释,接下来对本申请实施例提供的方法做整体的介绍,请参阅图4,图4是本申请实施例公开的又一种模型训练方法的流程示意图,如图所示,上述方法包括:
首先构造模型训练需要的数据集,数据集中的样本包括正样本和负样本,每条样本包括查询项,文档项、标签项,标签项用于表示样本的类别,在标签为1的情况下,样本为正样本,在标签为0的情况下,样本为负样本,具体构造步骤请参阅前文步骤101部分的说明。
然后对数据集进行分词、去停用词、字体转换以及过滤处理,得到样本总数为N的数据集,如图4中步骤401、步骤402、步骤403、步骤404,对上述步骤的具体解释清参阅前文步骤101部分的说明,对数据集进行上述处理之后,即步骤404之后的查询项和文档项相较于与步骤401之前的查询项和文档项包含的信息为有效信息,有利于模型的训练。
对于数据集中每一条样本,仅仅采取三种处理方式中的一种,具体采取哪种方式取决于样本的类别,即样本为正样本还是负样本,以及正样本的集合中需要进行对抗样本生成处理的正样本数量。对于数据集中每一条样本具体流程如下:
将查询项和文档项输入RNN模型,分别得到查询项对应的向量和文档项对应的向量,如图4中步骤405。
对上述样本的标签进行判断,确定上述样本的类别,如图4中步骤406。
在上述样本为负样本,即标签为0的情况下,对样本不作任何处理,直接将上述查询项对应的向量和文档项对应的向量输入DNN模型,得到样本对应的输出向量。
在上述样本为正样本,即标签为1的情况下,根据随机变量的取值决定是否对上述正样本进行对抗样本生成处理,从而控制对抗样本生成比例,达到控制模型训练难度的目的。本申请实施例中优选将正样本的集合中,一半的正样本进行对抗样本生成处理,即对服从参数为0.5的伯努利分布的随机变量抽取参考值,如图4中步骤407。
在上述参考值为1的情况下,将上述查询项对应的向量和文档项对应的向量输入VED模型,得到上述正样本的对抗文档对应的向量,如图4中步骤408,再将上述查询项对应的向量和上述对抗文档对应的向量输入DNN模型,并且,将上述正样本标记为负样本,得到样本对应的输出向量。
在上述参考值为0的情况下,对样本不作任何处理,直接将上述查询项对应的向量和文档项对应的向量输入DNN模型,得到样本对应的输出向量。
再根据上述输出向量,利用交叉损失函数计算上述样本对应的分损失值。最后根据上述分损失值计算一次训练中的总损失值,本申请实施例中,优选利用Adam optimizer和pytorch框架进行训练,迭代更新模型参数,直到模型收敛。
需要说明的是,对数据集的样本进行第一次遍历之后,模型的参数会有相应的更新,在后续对模型的训练中,对进行对抗样本生成的正样本的比例的选择可以采取与第一次相同的方法,也可以根据实验要求对上述比例进行调整,比如第一次训练中,对正样本的集合中50%的正样本进行生成对抗样本处理,在第二次训练中,对正样本的集合中60%的正样本进行生成对抗样本处理,后续步骤中依次增加,以此循序渐进增加模型训练的难度。
综上所述,本申请提出的模型训练方法,基于数据集中的正样本生成对抗样本,可以提高生成的对抗文档与正样本原有文档的相似度;将对抗样本作为数据集的负样本,可以提高模型训练中负样本的质量;利用包含以对抗样本为负样本的数据集对模型进行训练,一方面可以提高模型训练的难度,从而提升模型的参数的更新效率;另一方面可以提高模型对边界数据的处理能力,从而提高模型的鲁棒性。
在模型的训练过程中,上述VED模型的参数也会更新,经过充分训练的VED模型可以单数拆解出来,直接用于给定的正样本的对抗样本生成处理,从而提高模型训练的效率,缩短项目的生命周期。
上述详细阐述了本申请实施例的方法,下面提供本申请实施例的装置。
请参阅图5,图5为本申请实施例公开的一种模型训练的装置的结构示意图,上述数据转发的装置110可以包括获取单元501、处理单元502、计算单元503,确定单元504,其中,各个单元的描述如下:
获取单元501,用于获取待处理数据集,从上述待处理数据集中获取待训练样本,使用第一模型获得用于表示上述待训练样本的向量,上述待处理数据集包含的样本包括正样本和负样本;
处理单元502,用于在上述待训练样本为正样本的情况下,将上述表示上述待训练样本的向量输入第二模型生成用于表示上述待训练样本的对抗样本的向量;将上述用于表示上述待训练样本的对抗样本的向量输入第三模型,得到输出值;
计算单元503,用于根据上述输出值确定上述待训练样本的分损失值,计算上述待处理数据集中所有样本的分损失值求和,得到总损失值;
确定单元504,用于在前后两次获得的总损失值之间的差值小于阈值的情况下,确定上述第一模型,上述第二模型和上述第三模型收敛。
在一种可能的实施方式中,上述装置还包括:
标记单元505,用于将上述对抗样本标记为上述待处理数据集的负样本。
在一种可能的实施方式中,上述处理单元502,还用于在上述待训练样本为正样本的情况下,对服从伯努利分布的随机变量抽取参考值,所述伯努利分布服从的参数小于第二阈值;在上述参考值为1的情况下,将上述表示上述待训练样本的向量输入第二模型生成用于表示上述待训练样本的对抗样本的向量。
在一种可能的实施方式中,上述处理单元502,还用于在上述参考值为0的情况下,将上述表示上述待训练样本的向量输入第三模型,得到输出值。
在一种可能的实施方式中,上述处理单元502,还用于在上述待训练样本为负样本的情况下,将上述用于表示上述待训练样本的向量输入第三模型,得到输出值。
在一种可能的实施方式中,上述待训练样本包含查询项和文档项;上述表示上述待训练样本的向量包括:上述查询项对应的向量,以及上述文档项对应的向量;上述第一模型包括循环神经网络模型,上述第二模型包括变分编解码模型,上述第三模型包括深度神经网络模型。
在一种可能的实施方式中,上述处理单502元,还用于将上述查询项和上述文档项输入循环神经网络模型得到上述查询项对应的向量作为第三向量,以及得到上述文档项对应的向量作为第四向量。
在一种可能的实施方式中,上述处理单元502,还用于将上述第三向量和上述第四向量合并得到第五向量,将上述第五向量输入深度卷积模型得到向量作为输出值。
在一种可能的实施方式中,上述用于表示上述待训练样本的对抗样本的向量包括上述第三向量,以及表示上述待训练样本对应的对抗文档的向量。
在一种可能的实施方式中,上述处理单元502,还用于将上述第三向量和上述表示上述待训练样本对应的对抗文档的向量合并得到第六向量,将上述第六向量输入第三模型得到的向量作为输出值。
在一种可能的实施方式中,上述装置还包括:
编码单元506,用于将上述输出值对应的向量作为第一向量,将上述待训练样本的标签进行独热编码得到第二向量,上述第一向量与上述二向量的向量维数相同;
上述计算单元503,还用于将上述第一向量与上述第二向量中相同维数的值相乘再相加,得到上述待训练样本的分损失值;计算上述待处理数据集中所有样本的分损失值求和,得到总损失值。
综上所述,本申请提出的模型训练方法,基于数据集中的正样本生成对抗样本,可以提高生成的对抗文档与正样本原有文档的相似度;将对抗样本作为数据集的负样本,可以提高模型训练中负样本的质量;利用包含以对抗样本为负样本的数据集对模型进行训练,一方面可以提高模型训练的难度,从而提升模型的参数的更新效率;另一方面可以提高模型对边界数据的处理能力,从而提高模型的鲁棒性。
请参阅图6,图6是本申请实施例公开的一种服务器的结构示意图。上述服务器60可以包括存储器601、处理器602。进一步可选的,还可以包含通信接口603以及总线604,其中,存储器601、处理器602以及通信接口603通过总线604实现彼此之间的通信连接。通信接口603用于与时空数据查询装置进行数据交互。
其中,存储器601用于提供存储空间,存储空间中可以存储操作系统和计算机程序等数据。存储器601包括但不限于是随机存储记忆体(random access memory,RAM)、只读存储器(read-only memory,ROM)、可擦除可编程只读存储器(erasable programmable read only memory,EPROM)、或便携式只读存储器(compact disc read-only memory,CD-ROM)。
处理器602是进行算术运算和逻辑运算的模块,可以是中央处理器(central processing unit,CPU)、显卡处理器(graphics processing unit,GPU)或微处理器(microprocessor unit,MPU)等处理模块中的一种或者多种的组合。
存储器601中存储有计算机程序,处理器602调用存储器601中存储的计算机程序,以执行以下操作:
获取待处理数据集,上述待处理数据集包含的样本包括正样本和负样本;
从上述待处理数据集中获取待训练样本,使用第一模型获得用于表示上述待训练样本的向量;
在上述待训练样本为正样本的情况下,将上述表示上述待训练样本的向量输入第二模型生成用于表示上述待训练样本的对抗样本的向量;
将上述用于表示上述待训练样本的对抗样本的向量输入第三模型,得到输出值;
根据上述输出值确定上述待训练样本的分损失值,计算上述待处理数据集中所有样本的分损失值求和,得到总损失值;
在前后两次获得的总损失值之间的差值小于阈值的情况下,确定上述第一模型,上述第二模型和上述第三模型收敛。
需要说明的是,服务器60的具体实现还可以对应参照图2、图3、图4所示的方法实施例的相应描述。
本申请实施例还提供一种计算机可读存储介质,上述计算机可读存储介质中存储有计算机程序,当上述计算机程序在一个或多个处理器上运行时,可以实现图1、图2、图3以及图4所示的模型训练的方法。
可选的,本申请涉及的存储介质如计算机可读存储介质可以是非易失性的,也可以是易失性的。
本申请实施例还提供了一种计算机程序产品,上述计算机程序产品包括程序指令,上述程序指令当被处理器执行时使上述处理器执行上述实施例中方法的部分或全部步骤,此处不赘述。
综上所述,本申请提出的模型训练方法,基于数据集中的正样本生成对抗样本,可以提高生成的对抗文档与正样本原有文档的相似度;将对抗样本作为数据集的负样本,可以提高模型训练中负样本的质量;利用包含以对抗样本为负样本的数据集对模型进行训练, 一方面可以提高模型训练的难度,从而提升模型的参数的更新效率;另一方面可以提高模型对边界数据的处理能力,从而提高模型的鲁棒性。
本领域普通技术人员可以理解实现上述实施例方法中的全部或部分流程,上述流程可以由计算机程序相关的硬件完成,上述计算机程序可存储于计算机可读取存储介质中,上述计算机程序在执行时,可包括如上述各方法实施例的流程。而前述的存储介质包括:只读存储器ROM或随机存储记忆体RAM、磁碟或者光盘等各种可存储计算机程序代码的介质。

Claims (20)

  1. 一种模型训练的方法,包括:
    获取待处理数据集,所述待处理数据集包含的样本包括正样本和负样本;
    从所述待处理数据集中获取待训练样本,使用第一模型获得用于表示所述待训练样本的向量;
    在所述待训练样本为正样本的情况下,将所述表示所述待训练样本的向量输入第二模型生成用于表示所述待训练样本的对抗样本的向量;
    将所述用于表示所述待训练样本的对抗样本的向量输入第三模型,得到输出值;
    根据所述输出值确定所述待训练样本的分损失值,计算所述待处理数据集中所有样本的分损失值求和,得到总损失值;
    在前后两次获得的总损失值之间的差值小于第一阈值的情况下,确定所述第一模型,所述第二模型和所述第三模型收敛。
  2. 根据权利要求1所述的方法,其中,所述方法还包括:
    将所述对抗样本标记为所述待处理数据集的负样本。
  3. 根据权利要求2所述的方法,其中,所述在所述待训练样本为正样本的情况下,将所述表示所述待训练样本的向量输入第二模型生成用于表示所述待训练样本的对抗样本的向量,包括:
    在所述待训练样本为正样本的情况下,对服从伯努利分布的随机变量抽取参考值,所述伯努利分布服从的参数小于第二阈值;
    在所述参考值为1的情况下,将所述表示所述待训练样本的向量输入第二模型生成用于表示所述待训练样本的对抗样本的向量。
  4. 根据权利要求3所述的方法,其中,所述方法还包括:
    在所述参考值为0的情况下,将所述表示所述待训练样本的向量输入第三模型,得到输出值。
  5. 根据权利要求4所述的方法,其中,所述方法还包括:
    在所述待训练样本为负样本的情况下,将所述用于表示所述待训练样本的向量输入第三模型,得到输出值。
  6. 根据权利要求5所述的方法,其中,所述待训练样本包含查询项和文档项;
    所述表示所述待训练样本的向量包括:所述查询项对应的向量,以及所述文档项对应的向量;
    所述第一模型包括循环神经网络模型,所述第二模型包括变分编解码模型,所述第三模型包括深度神经网络模型。
  7. 根据权利要求6所述的方法,其中,所述根据所述输出值确定所述待训练样本的分损失值,计算所述待处理数据集中所有样本的分损失值求和,得到总损失值,包括:
    将所述输出值对应的向量作为第一向量,将所述待训练样本的标签进行独热编码得到第二向量,所述第一向量与所述二向量的向量维数相同;
    将所述第一向量与所述第二向量中相同维数的值相乘再相加,得到所述待训练样本的分损失值;
    计算所述待处理数据集中所有样本的分损失值求和,得到总损失值。
  8. 一种模型训练的装置,其中,所述装置包括:
    获取单元,用于获取待处理数据集,从所述待处理数据集中获取待训练样本,使用第一模型获得用于表示所述待训练样本的向量,所述待处理数据集包含的样本包括正样本和负样本;
    处理单元,用于在所述待训练样本为正样本的情况下,将所述表示所述待训练样本的 向量输入第二模型生成用于表示所述待训练样本的对抗样本的向量;将所述用于表示所述待训练样本的对抗样本的向量输入第三模型,得到输出值;
    计算单元,用于根据所述输出值确定所述待训练样本的分损失值,计算所述待处理数据集中所有样本的分损失值求和,得到总损失值;
    确定单元,用于在前后两次获得的总损失值之间的差值小于阈值的情况下,确定所述第一模型,所述第二模型和所述第三模型收敛。
  9. A server, wherein the server comprises a processor and a memory, wherein a computer program is stored in the memory, and the processor invokes the computer program stored in the memory to perform the following method:
    obtaining a data set to be processed, wherein the samples contained in the data set to be processed include positive samples and negative samples;
    obtaining a sample to be trained from the data set to be processed, and using a first model to obtain a vector representing the sample to be trained;
    in the case that the sample to be trained is a positive sample, inputting the vector representing the sample to be trained into a second model to generate a vector representing an adversarial sample of the sample to be trained;
    inputting the vector representing the adversarial sample of the sample to be trained into a third model to obtain an output value;
    determining a partial loss value of the sample to be trained according to the output value, and summing the partial loss values of all samples in the data set to be processed to obtain a total loss value;
    in the case that the difference between two consecutively obtained total loss values is less than a first threshold, determining that the first model, the second model, and the third model have converged.
  10. The server according to claim 9, wherein the processor is further configured to perform:
    marking the adversarial sample as a negative sample of the data set to be processed;
    wherein the inputting, in the case that the sample to be trained is a positive sample, the vector representing the sample to be trained into the second model to generate the vector representing the adversarial sample of the sample to be trained comprises:
    in the case that the sample to be trained is a positive sample, drawing a reference value from a random variable that follows a Bernoulli distribution, wherein the parameter of the Bernoulli distribution is less than a second threshold;
    in the case that the reference value is 1, inputting the vector representing the sample to be trained into the second model to generate the vector representing the adversarial sample of the sample to be trained.
  11. The server according to claim 10, wherein the processor is further configured to perform:
    in the case that the reference value is 0, inputting the vector representing the sample to be trained into the third model to obtain an output value.
  12. The server according to claim 11, wherein the processor is further configured to perform:
    in the case that the sample to be trained is a negative sample, inputting the vector representing the sample to be trained into the third model to obtain an output value.
  13. The server according to claim 12, wherein the sample to be trained contains a query item and a document item;
    the vector representing the sample to be trained includes a vector corresponding to the query item and a vector corresponding to the document item; and
    the first model includes a recurrent neural network model, the second model includes a variational encoder-decoder model, and the third model includes a deep neural network model.
  14. The server according to claim 13, wherein the determining the partial loss value of the sample to be trained according to the output value and summing the partial loss values of all samples in the data set to be processed to obtain the total loss value comprises:
    taking the vector corresponding to the output value as a first vector, and one-hot encoding the label of the sample to be trained to obtain a second vector, wherein the first vector and the second vector have the same number of dimensions;
    multiplying the values of the same dimension in the first vector and the second vector and then summing the products to obtain the partial loss value of the sample to be trained; and
    summing the partial loss values of all samples in the data set to be processed to obtain the total loss value.
  15. A computer-readable storage medium, wherein a computer program is stored in the computer-readable storage medium, and when the computer program runs on one or more processors, the following method is performed:
    obtaining a data set to be processed, wherein the samples contained in the data set to be processed include positive samples and negative samples;
    obtaining a sample to be trained from the data set to be processed, and using a first model to obtain a vector representing the sample to be trained;
    in the case that the sample to be trained is a positive sample, inputting the vector representing the sample to be trained into a second model to generate a vector representing an adversarial sample of the sample to be trained;
    inputting the vector representing the adversarial sample of the sample to be trained into a third model to obtain an output value;
    determining a partial loss value of the sample to be trained according to the output value, and summing the partial loss values of all samples in the data set to be processed to obtain a total loss value;
    in the case that the difference between two consecutively obtained total loss values is less than a first threshold, determining that the first model, the second model, and the third model have converged.
  16. The computer-readable storage medium according to claim 15, wherein when the computer program runs on the one or more processors, the following is further performed:
    marking the adversarial sample as a negative sample of the data set to be processed;
    wherein the inputting, in the case that the sample to be trained is a positive sample, the vector representing the sample to be trained into the second model to generate the vector representing the adversarial sample of the sample to be trained comprises:
    in the case that the sample to be trained is a positive sample, drawing a reference value from a random variable that follows a Bernoulli distribution, wherein the parameter of the Bernoulli distribution is less than a second threshold;
    in the case that the reference value is 1, inputting the vector representing the sample to be trained into the second model to generate the vector representing the adversarial sample of the sample to be trained.
  17. The computer-readable storage medium according to claim 16, wherein when the computer program runs on the one or more processors, the following is further performed:
    in the case that the reference value is 0, inputting the vector representing the sample to be trained into the third model to obtain an output value.
  18. The computer-readable storage medium according to claim 17, wherein when the computer program runs on the one or more processors, the following is further performed:
    in the case that the sample to be trained is a negative sample, inputting the vector representing the sample to be trained into the third model to obtain an output value.
  19. The computer-readable storage medium according to claim 18, wherein the sample to be trained contains a query item and a document item;
    the vector representing the sample to be trained includes a vector corresponding to the query item and a vector corresponding to the document item; and
    the first model includes a recurrent neural network model, the second model includes a variational encoder-decoder model, and the third model includes a deep neural network model.
  20. The computer-readable storage medium according to claim 19, wherein the determining the partial loss value of the sample to be trained according to the output value and summing the partial loss values of all samples in the data set to be processed to obtain the total loss value comprises:
    taking the vector corresponding to the output value as a first vector, and one-hot encoding the label of the sample to be trained to obtain a second vector, wherein the first vector and the second vector have the same number of dimensions;
    multiplying the values of the same dimension in the first vector and the second vector and then summing the products to obtain the partial loss value of the sample to be trained; and
    summing the partial loss values of all samples in the data set to be processed to obtain the total loss value.
PCT/CN2021/083815 2020-11-12 2021-03-30 一种模型训练的方法及相关装置 WO2021204014A1 (zh)

Applications Claiming Priority (2)

Application Number Priority Date Filing Date Title
CN202011261109.2 2020-11-12
CN202011261109.2A CN112380319B (zh) 2020-11-12 2020-11-12 一种模型训练的方法及相关装置

Publications (1)

Publication Number Publication Date
WO2021204014A1 true WO2021204014A1 (zh) 2021-10-14

Family

ID=74583146

Family Applications (1)

Application Number Title Priority Date Filing Date
PCT/CN2021/083815 WO2021204014A1 (zh) 2020-11-12 2021-03-30 一种模型训练的方法及相关装置

Country Status (2)

Country Link
CN (1) CN112380319B (zh)
WO (1) WO2021204014A1 (zh)

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN114021739A (zh) * 2022-01-06 2022-02-08 北京达佳互联信息技术有限公司 业务处理、业务处理模型训练方法、装置及电子设备
CN116244416A (zh) * 2023-03-03 2023-06-09 北京百度网讯科技有限公司 生成式大语言模型训练方法、基于模型的人机语音交互方法

Families Citing this family (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN112380319B (zh) * 2020-11-12 2023-10-17 平安科技(深圳)有限公司 一种模型训练的方法及相关装置
CN112927012A (zh) * 2021-02-23 2021-06-08 第四范式(北京)技术有限公司 营销数据的处理方法及装置、营销模型的训练方法及装置
CN113012153A (zh) * 2021-04-30 2021-06-22 武汉纺织大学 一种铝型材瑕疵检测方法
CN113656699B (zh) * 2021-08-25 2024-02-13 平安科技(深圳)有限公司 用户特征向量确定方法、相关设备及介质

Citations (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20170323202A1 (en) * 2016-05-06 2017-11-09 Fujitsu Limited Recognition apparatus based on deep neural network, training apparatus and methods thereof
CN109800735A (zh) * 2019-01-31 2019-05-24 中国人民解放军国防科技大学 一种船目标精确检测与分割方法
CN110175615A (zh) * 2019-04-28 2019-08-27 华中科技大学 模型训练方法、域自适应的视觉位置识别方法及装置
CN111046866A (zh) * 2019-12-13 2020-04-21 哈尔滨工程大学 一种结合ctpn和svm的人民币冠字号区域检测方法
CN112380319A (zh) * 2020-11-12 2021-02-19 平安科技(深圳)有限公司 一种模型训练的方法及相关装置

Family Cites Families (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US8051072B2 (en) * 2008-03-31 2011-11-01 Yahoo! Inc. Learning ranking functions incorporating boosted ranking in a regression framework for information retrieval and ranking
JP6678930B2 (ja) * 2015-08-31 2020-04-15 インターナショナル・ビジネス・マシーンズ・コーポレーションInternational Business Machines Corporation 分類モデルを学習する方法、コンピュータ・システムおよびコンピュータ・プログラム
RU2637883C1 (ru) * 2016-06-20 2017-12-07 Общество С Ограниченной Ответственностью "Яндекс" Способ создания обучающего объекта для обучения алгоритма машинного обучения
CN111353554B (zh) * 2020-05-09 2020-08-25 支付宝(杭州)信息技术有限公司 预测缺失的用户业务属性的方法及装置



Also Published As

Publication number Publication date
CN112380319B (zh) 2023-10-17
CN112380319A (zh) 2021-02-19

Similar Documents

Publication Publication Date Title
WO2021204014A1 (zh) 一种模型训练的方法及相关装置
CN111310438B (zh) 基于多粒度融合模型的中文句子语义智能匹配方法及装置
WO2022198868A1 (zh) 开放式实体关系的抽取方法、装置、设备及存储介质
CN109815493B (zh) 一种智能嘻哈音乐歌词生成的建模方法
CN111310439B (zh) 一种基于深度特征变维机制的智能语义匹配方法和装置
CN110781306B (zh) 一种英文文本的方面层情感分类方法及系统
CN112800170A (zh) 问题的匹配方法及装置、问题的回复方法及装置
CN111159485B (zh) 尾实体链接方法、装置、服务器及存储介质
CN108875074A (zh) 基于交叉注意力神经网络的答案选择方法、装置和电子设备
CN110222173B (zh) 基于神经网络的短文本情感分类方法及装置
CN111274267A (zh) 一种数据库查询方法、装置及计算机可读取存储介质
CN111027292B (zh) 一种限定采样文本序列生成方法及其系统
CN112417894A (zh) 一种基于多任务学习的对话意图识别方法及识别系统
CN114298055B (zh) 基于多级语义匹配的检索方法、装置、计算机设备和存储介质
CN111563373B (zh) 聚焦属性相关文本的属性级情感分类方法
CN113609284A (zh) 一种融合多元语义的文本摘要自动生成方法及装置
CN116304748A (zh) 一种文本相似度计算方法、系统、设备及介质
US20220383119A1 (en) Granular neural network architecture search over low-level primitives
CN115759119A (zh) 一种金融文本情感分析方法、系统、介质和设备
CN113220862A (zh) 标准问识别方法、装置及计算机设备及存储介质
CN116258147A (zh) 一种基于异构图卷积的多模态评论情感分析方法及系统
CN116204622A (zh) 一种跨语言稠密检索中的查询表示增强方法
CN116186312A (zh) 用于数据敏感信息发现模型的多模态数据增强方法
CN113268657B (zh) 基于评论和物品描述的深度学习推荐方法及系统
CN115544999A (zh) 一种面向领域的并行大规模文本查重方法

Legal Events

Date Code Title Description
121 Ep: the epo has been informed by wipo that ep was designated in this application

Ref document number: 21784656

Country of ref document: EP

Kind code of ref document: A1

NENP Non-entry into the national phase

Ref country code: DE

122 Ep: pct application non-entry in european phase

Ref document number: 21784656

Country of ref document: EP

Kind code of ref document: A1