WO2020220544A1 - Method, apparatus, device and storage medium for training an unbalanced data classification model - Google Patents
Method, apparatus, device and storage medium for training an unbalanced data classification model
- Publication number
- WO2020220544A1 (PCT/CN2019/103523)
- Authority
- WO
- WIPO (PCT)
- Prior art keywords
- data
- preset
- classification model
- training
- sample
- Prior art date
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/213—Feature extraction, e.g. by transforming the feature space; Summarisation; Mappings, e.g. subspace methods
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
- G06F18/241—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
- G06F18/2411—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches based on the proximity to a decision surface, e.g. support vector machines
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02T—CLIMATE CHANGE MITIGATION TECHNOLOGIES RELATED TO TRANSPORTATION
- Y02T10/00—Road transport of goods or passengers
- Y02T10/10—Internal combustion engine [ICE] based vehicles
- Y02T10/40—Engine management systems
Definitions
- This application relates to the field of information processing, in particular to methods, devices, equipment and storage media for training imbalanced data classification models.
- Unbalanced data refers to data in which the proportions of samples of different categories are uneven during training or classification. For example, in user fraud detection, fraudulent behaviors account for a far smaller proportion than non-fraudulent behaviors. Unbalanced data arises widely in practical applications such as fault detection, defect detection, network intrusion detection, and medical diagnosis.
- the embodiments of the present application provide an imbalanced data classification model training method, device, equipment, and storage medium to solve the problem of inaccurate classification when the classification model obtained by imbalanced data training is used for classification.
- A training method for an imbalanced data classification model, including:
- acquiring unbalanced data from a preset sample library;
- performing dimensionality reduction processing on the unbalanced data according to a preset dimensionality reduction method to obtain low-dimensional data after dimensionality reduction;
- sampling the low-dimensional data according to a preset sampling method to obtain balanced data;
- using the balanced data as training samples, and training the training samples with a preset machine learning algorithm to obtain a classification model.
- An unbalanced data classification model training device includes:
- a data acquisition module, configured to acquire unbalanced data from a preset sample library;
- a dimensionality reduction module, configured to perform dimensionality reduction processing on the unbalanced data according to a preset dimensionality reduction method to obtain low-dimensional data after dimensionality reduction;
- a sampling module, configured to sample the low-dimensional data according to a preset sampling method to obtain balanced data;
- a training module, configured to use the balanced data as training samples and train the training samples with a preset machine learning algorithm to obtain a classification model.
- A computer device, including a memory, a processor, and computer-readable instructions stored in the memory and executable on the processor, where the processor, when executing the computer-readable instructions, implements the above-mentioned unbalanced data classification model training method.
- One or more non-volatile readable storage media storing computer-readable instructions which, when executed by one or more processors, cause the one or more processors to execute the above-mentioned unbalanced data classification model training method.
- FIG. 1 is a schematic diagram of an application environment of the imbalanced data classification model training method in an embodiment of the present application;
- FIG. 2 is a flowchart of the imbalanced data classification model training method in an embodiment of the present application;
- FIG. 3 is a flowchart of step S2 of the imbalanced data classification model training method in an embodiment of the present application;
- FIG. 4 is a flowchart of step S3 of the imbalanced data classification model training method in an embodiment of the present application;
- FIG. 5 is a flowchart of optimizing the parameters of the classification model in an embodiment of the present application;
- FIG. 6 is a flowchart of judging the operation results of the iterative operation in an embodiment of the present application;
- FIG. 7 is a schematic diagram of the unbalanced data classification model training device in an embodiment of the present application;
- FIG. 8 is a schematic diagram of a computer device in an embodiment of the present application.
- The imbalanced data classification model training method provided in this application can be applied in the application environment shown in FIG. 1.
- The application environment includes a server and a preset sample library, where the preset sample library is a database storing imbalanced data and the server is a computer device that trains on the unbalanced data.
- The server can be a single server or a server cluster; the server is connected to the preset sample library through a network, which can be a wired or a wireless network.
- The imbalanced data classification model training method provided in the embodiments of this application is applied to the server.
- a method for training an imbalanced data classification model is provided.
- The specific implementation process includes the following steps:
- S1 Obtain unbalanced data from a preset sample library.
- The preset sample library is a storage platform for storing unbalanced data.
- Specifically, the preset sample library may be a database, including but not limited to various relational or non-relational databases such as MS-SQL, Oracle, MySQL, Sybase, DB2, Redis, MongoDB, HBase, etc.; alternatively, the preset sample library may be a file storing unbalanced data, which is not specifically limited here.
- Unbalanced data means that, in a data set, the proportions of data of different categories are uneven. For example, if the ratio of positive to negative samples in the training samples is 9:1, the training samples are unbalanced data. Understandably, in real classification problems the data to be classified may also be unbalanced. Taking user fraud detection as an example, among the user behaviors to be detected, fraudulent behaviors account for a far smaller proportion than non-fraudulent behaviors, so the user behaviors to be detected are also unbalanced data.
- Specifically, if the preset sample library is a database, the server can obtain the unbalanced data through SQL statements; if the preset sample library is a file, the server can read the file directly to the server's local storage.
- S2 Perform dimensionality reduction processing on the unbalanced data according to a preset dimensionality reduction method to obtain low-dimensional data after dimensionality reduction.
- Dimension refers to the angle from which sample data is described in machine learning; that is, the characteristics of a sample can be reflected through multiple dimensions. The higher the dimensionality of the sample data, the more associated features it contains, and therefore the harder it is to train on.
- Dimensionality reduction processing refers to reducing the dimensionality of sample data and turning high-dimensional data into low-dimensional data. At the same time, the low-dimensional data obtained after dimensionality reduction needs to retain as many sample features as possible.
- By reducing the dimensionality, the data becomes easier to visualize, observe, and explore, and the training and prediction of machine learning models are simplified. For example, after reducing the data to three or fewer dimensions, the data features can be represented in three-dimensional or two-dimensional space, which makes it convenient to discover some data features intuitively.
- The server can use linear dimensionality reduction or nonlinear dimensionality reduction for the dimensionality reduction processing.
- Linear dimensionality reduction includes but is not limited to the PCA (Principal Component Analysis) method.
- Nonlinear dimensionality reduction is mainly divided into kernel-function-based and eigenvalue-based methods, including but not limited to the LLE (Locally Linear Embedding) method.
- Specifically, the server can perform the dimensionality reduction through the dimensionality reduction functions in the sklearn library to obtain the low-dimensional data.
- sklearn, whose full name is scikit-learn, is a third-party machine learning library based on Python.
- For example, if the unbalanced data is 6-dimensional, i.e., the feature vector of the unbalanced data includes 6 components, the low-dimensional data obtained after dimensionality reduction can be a 3-dimensional feature vector; that is, the 3 redundant feature components are discarded.
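- The patent itself contains no source code; purely as an illustrative aside, a minimal sketch of a 6-to-3 reduction with sklearn's PCA (the array X, sample count, and component count are hypothetical, not values from the patent) might look like this:

```python
import numpy as np
from sklearn.decomposition import PCA

# hypothetical 6-dimensional unbalanced data: 1000 samples, 6 feature components
X = np.random.rand(1000, 6)

# keep 3 principal components, discarding the 3 most redundant directions
pca = PCA(n_components=3)
X_low = pca.fit_transform(X)   # low-dimensional data, shape (1000, 3)
```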
- S3 Sampling low-dimensional data according to a preset sampling method to obtain balanced data.
- Balanced data is relative to unbalanced data, that is, in the data set, the proportion of different types of data is balanced.
- The boundary between balanced and unbalanced data can be defined by a critical ratio; that is, the critical ratio determines whether a data set is balanced. Taking the binary classification problem as an example, the critical ratio can be 4:1: if the ratio of positive to negative samples is below 4:1, the data set can be considered balanced; otherwise, the data set is unbalanced data.
- Sampling is to adjust the proportions of different types of data in low-dimensional data to convert unbalanced data into balanced data.
- The server can sample the low-dimensional data using a variety of preset sampling methods. For example, the server can reduce the amount of majority-class data in the low-dimensional data while increasing the amount of minority-class data, so that the ratio between the two reaches balance.
- Specifically, the server can remove a certain number of samples from the majority-class data while adding a certain number of samples to the minority-class data, so that the quantity ratio between the majority-class data and the minority-class data falls below 4:1.
- When reducing the majority-class data, the server can discard majority-class samples at random; when adding minority-class data, the server can copy minority-class samples at random to increase their number.
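- As a hedged sketch of this random discard/copy step (not the patent's own code; the 9:1 starting labels, the sampler settings, and the 2:1 target are assumptions), imbalanced-learn's random samplers can be chained:

```python
import numpy as np
from imblearn.pipeline import Pipeline
from imblearn.over_sampling import RandomOverSampler
from imblearn.under_sampling import RandomUnderSampler

X_low = np.random.rand(100, 3)     # hypothetical low-dimensional data
y = np.array([0] * 90 + [1] * 10)  # hypothetical 9:1 unbalanced labels

# sampling_strategy is the desired minority/majority ratio after each step;
# 0.4 then 0.5 leaves about 2:1 majority:minority, below the 4:1 critical ratio
resample = Pipeline([
    ("over", RandomOverSampler(sampling_strategy=0.4, random_state=0)),   # copy minority samples at random
    ("under", RandomUnderSampler(sampling_strategy=0.5, random_state=0)), # discard majority samples at random
])
X_bal, y_bal = resample.fit_resample(X_low, y)
```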
- S4 Use the balanced data as training samples, and train the training samples with the preset machine learning algorithm to obtain the classification model.
- the preset machine learning algorithm refers to a training method based on supervised learning.
- the preset machine learning algorithms include but are not limited to linear regression algorithms, logistic regression algorithms, naive Bayes algorithms, SVM algorithms, etc.
- the server uses the SVM algorithm for training to obtain the classification model.
- SVM stands for support vector machine, which is a two-class classification model.
- the server can import SVM related functions from the Python-based sklearn library to create an SVM classifier; and then import the balanced data into the SVM classifier for training, thereby obtaining a classification model.
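- A minimal sketch of this training step, assuming balanced arrays X_bal and y_bal produced by the sampling stage (the names, data, and kernel choice are illustrative assumptions, not taken from the patent):

```python
import numpy as np
from sklearn.svm import SVC

X_bal = np.random.rand(200, 3)             # hypothetical balanced low-dimensional data
y_bal = np.random.randint(0, 2, size=200)  # hypothetical binary labels

clf = SVC(kernel="rbf")   # kernel choice is an assumption; the text does not specify one
clf.fit(X_bal, y_bal)     # train the SVM classifier on the balanced training samples

print(clf.predict(X_bal[:5]))   # the resulting classification model in use
```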
- In this embodiment, dimensionality reduction is performed on the unbalanced data to obtain low-dimensional data; that is, the redundant features of the unbalanced data are removed, which speeds up computation, reduces storage space, and helps avoid overfitting when the trained classification model is used for classification.
- Sampling the low-dimensional data according to the preset sampling method to obtain balanced data, i.e., converting the unbalanced data into balanced data, can increase the weight of the minority-class sample data and strengthen its influence in the training process.
- Using the balanced data as training samples and training them with a preset machine learning algorithm yields a classification model that is more sensitive to the minority-class samples in unbalanced data; when the trained classification model is used for classification, the misjudgment rate on minority-class data is reduced, thereby improving the accuracy of classification.
- For step S2, that is, performing dimensionality reduction processing on the unbalanced data according to a preset dimensionality reduction method to obtain low-dimensional data after dimensionality reduction, the following steps are specifically included:
- S21 Build a self-encoding model with a preset number of layers.
- The self-encoding model is a classification model based on a self-encoding network.
- the self-encoding network also known as the autoencoder, is a neural network that aims to reconstruct input information in the field of unsupervised learning.
- the self-encoding network can automatically learn features from unlabeled data, can give a better feature description than the original data, and has a strong feature learning ability.
- Unsupervised learning is the opposite of supervised learning.
- Unsupervised learning trains a classification model on data whose labels or classification results are unknown; supervised learning trains a classification model on data whose labels or classification results are known.
- the preset number of layers refers to the number of network layers except the input and output layers in the self-encoding network, that is, the number of intermediate layers.
- the server can build a simple self-encoding model with three preset layers.
- the server can establish a three-layer self-encoding model based on the built-in functions provided by the TensorFlow framework.
- TensorFlow is an open-source software library for high-performance numerical computation. With TensorFlow, computation can easily be deployed to multiple platforms such as CPUs and GPUs, and to devices including desktops, server clusters, mobile devices, and edge devices.
- TensorFlow was originally developed by researchers and engineers in the Google Brain team. It can provide strong support for machine learning and deep learning, and its flexible numerical computing core is widely used in many other scientific fields.
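- The patent names TensorFlow but shows no model code; one plausible reading of the three-layer self-encoding model, sketched with the Keras API (layer sizes, activations, and training settings are assumptions), is:

```python
import numpy as np
import tensorflow as tf

X = np.random.rand(1000, 6).astype("float32")   # hypothetical 6-dimensional unbalanced data

inputs = tf.keras.Input(shape=(6,))
hidden = tf.keras.layers.Dense(3, activation="relu")(inputs)      # 3-dimensional middle layer
outputs = tf.keras.layers.Dense(6, activation="sigmoid")(hidden)  # reconstruct the input

autoencoder = tf.keras.Model(inputs, outputs)
autoencoder.compile(optimizer="adam", loss="mse")
autoencoder.fit(X, X, epochs=10, batch_size=32, verbose=0)  # unsupervised: target equals input

encoder = tf.keras.Model(inputs, hidden)   # m = f(x): the encoding half of the model
X_low = encoder.predict(X)                 # hidden features used as low-dimensional data
```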
- S22 Use the self-encoding model to perform feature extraction on the unbalanced data to obtain hidden features of the unbalanced data, and use the hidden features as low-dimensional data after dimensionality reduction, where the low-dimensional data is composed of data of different data types.
- The implicit (hidden) features of unbalanced data, like its explicit features, represent characteristics of the unbalanced data; unlike explicit features, implicit features cannot be obtained intuitively from the unbalanced data.
- For example, if the explicit feature of unbalanced data A is a, then a can be obtained by simple analysis and statistics on A, whereas an implicit feature b of A cannot be obtained in the same way.
- In student sample data, for instance, each student's subject exam score is an explicit feature that can be obtained by tallying each student's scores, while each student's learning ability is an implicit feature that cannot simply be derived from exam results.
- Specifically, the server inputs the unbalanced data x into the self-encoding model and obtains output data m, which can be expressed as m = f(x), where f is the encoding function of the self-encoding model and m is related to the number of intermediate layers of the model. The server can then compute the hidden features with the formula c = g(m), where g is the decoding function of the self-encoding model and c is the hidden features of the unbalanced data x. From these two formulas it can be understood that the more layers the self-encoding model has, the more hidden features there are, and vice versa.
- Taking 6-dimensional unbalanced data as an example, if the number of layers of the self-encoding model is set to three, the dimensionality of the low-dimensional data obtained after dimensionality reduction is 3.
- Low-dimensional data consists of data of different data types.
- low-dimensional data may include type A data and type B data; or, low-dimensional data may include type A data, type B data, or type C data, etc.
- In this embodiment, the unbalanced data is input into the established self-encoding model, the hidden features of the unbalanced data are computed according to the model's decoding function, and these hidden features are used as the low-dimensional data, thereby realizing dimensionality reduction of the unbalanced data.
- Because the low-dimensional data is composed of the implicit features of the unbalanced data, the low-dimensional data after dimensionality reduction retains as many effective features of the original data as possible.
- For step S3, that is, sampling the low-dimensional data according to a preset sampling method to obtain balanced data, the following steps are specifically included:
- S31 Calculate the total quantity of low-dimensional data and the sub-quantity of low-dimensional data corresponding to different data types.
- Since the low-dimensional data is composed of data of different data types, the total quantity of the low-dimensional data is the sum of the numbers of data of each data type, and the number of data of each individual data type is called a sub-quantity of the low-dimensional data.
- For example, a binary-classification data set contains two data types, type-A data and type-B data, with 50 type-A data and 20 type-B data; the total quantity of the low-dimensional data is then 70, and the sub-quantities are 50 (the number of type-A data) and 20 (the number of type-B data).
- Specifically, the server can distinguish the different data types in the low-dimensional data according to preset data-type field names and count the data under each field name, thereby obtaining the total quantity of the low-dimensional data and the sub-quantity corresponding to each data type.
- S32 If the ratio of a sub-quantity to the total quantity exceeds a preset threshold, treat the data of the data type corresponding to that sub-quantity as majority-class data, and under-sample the majority-class data to obtain first sample data.
- In unbalanced data, the amounts of data of different data types may differ greatly. The data of the type that accounts for the majority is called majority-class data, and the data of the type that accounts for the minority is called minority-class data.
- Whether data belongs to the majority class or the minority class can be determined by comparing the ratio of its sub-quantity to the total quantity of the low-dimensional data against a preset threshold.
- The preset threshold can be determined from the number of data types in the low-dimensional data: if there are 2 data types, the preset threshold is 1/2; if there are 3, it is 1/3; and so on. The preset threshold is a fraction whose denominator is the number of data types in the low-dimensional data and whose numerator is 1. It should be noted that this is only one implementation of the preset threshold; this application includes but is not limited to this implementation.
- For example, in low-dimensional data containing data types A and B with 50 and 20 data respectively, the preset threshold is 1/2. Since the ratio of type A's sub-quantity of 50 to the total quantity of 70 exceeds 1/2, type-A data is majority-class data and, correspondingly, type-B data is minority-class data.
- Understandably, if the low-dimensional data includes three data types A, B, and C, the preset threshold is 1/3; among the three data types, if the ratio of any type's sub-quantity to the total quantity of the low-dimensional data exceeds 1/3, the data of that type is majority-class data.
- The data types in the unbalanced data can be set in advance; that is, the data in the unbalanced data has already been type-labeled in the preset sample library. For example, in training sample data for network intrusion detection, 90% of the samples are normal traffic data, whose data type is M, and 10% are abnormal traffic data, whose data type is N, where M and N are values of the data-type field of the sample data in the preset sample library.
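- A small sketch of the sub-quantity/threshold rule described above (the labels y and the class names M and N are hypothetical, echoing the intrusion-detection example):

```python
from collections import Counter

y = ["M"] * 90 + ["N"] * 10            # hypothetical type labels for 100 samples

counts = Counter(y)                    # sub-quantity per data type
total = sum(counts.values())           # total quantity of the low-dimensional data
threshold = 1.0 / len(counts)          # preset threshold: 1 / number of data types

majority = [t for t, n in counts.items() if n / total > threshold]   # under-sample these
minority = [t for t, n in counts.items() if n / total <= threshold]  # over-sample these
```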
- Under-sampling, or down-sampling, is a sampling method that discards part of the sample data.
- The first sample data is the data obtained after under-sampling.
- The server can use various clustering algorithms to cluster the majority-class data and then remove the redundant data within each cluster to obtain the first sample data.
- Specifically, the server can call the k-means clustering function in the sklearn library to cluster the majority-class data.
- In each resulting cluster, a centroid and a certain number of samples within the centroid's range are selected; the data outside that range is then discarded, and the remaining data is the first sample data.
- K-Means, also called the K-means algorithm, is a clustering algorithm.
- K-Means takes as input the number of clusters k and a database containing n data objects, and outputs k clusters that satisfy the minimum-variance criterion; the number of samples selected within each centroid's range corresponds to the amount of majority-class data to be discarded in under-sampling.
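- As an illustrative sketch only (the cluster count, samples kept per cluster, and data are assumptions, not values from the patent), the centroid-based under-sampling could be written as:

```python
import numpy as np
from sklearn.cluster import KMeans

X_major = np.random.rand(500, 3)   # hypothetical majority-class low-dimensional data

k, keep_per_cluster = 20, 5
kmeans = KMeans(n_clusters=k, n_init=10, random_state=0).fit(X_major)

kept = []
for c in range(k):
    members = np.where(kmeans.labels_ == c)[0]
    # distance of each cluster member to its centroid
    dist = np.linalg.norm(X_major[members] - kmeans.cluster_centers_[c], axis=1)
    kept.extend(members[np.argsort(dist)[:keep_per_cluster]])  # keep the closest samples

X_first_sample = X_major[kept]   # first sample data after under-sampling
```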
- S33 If the ratio of a sub-quantity to the total quantity does not exceed the preset threshold, treat the data of the data type corresponding to that sub-quantity as minority-class data, and over-sample the minority-class data to obtain second sample data.
- Similarly to determining the majority-class data in step S32, the server can determine the minority-class data according to the ratio of a sub-quantity of the low-dimensional data to its total quantity and the preset threshold.
- Over-sampling, or up-sampling, is a sampling method that increases the number of minority-class samples. The second sample data is the data obtained after over-sampling.
- Specifically, the server can over-sample the minority-class data through the SMOTE algorithm; that is, approximate minority-class data is constructed by interpolation, thereby achieving over-sampling.
- SMOTE (Synthetic Minority Over-sampling Technique) is an improved algorithm based on the random over-sampling algorithm.
- Specifically, the server can implement SMOTE through the built-in functions of the imbalanced-learn library in Python.
- imbalanced-learn is a third-party Python library for sampling unbalanced data.
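- A minimal usage sketch of SMOTE via imbalanced-learn (the package imports as imblearn; the data and 9:1 ratio are hypothetical):

```python
import numpy as np
from imblearn.over_sampling import SMOTE

# hypothetical low-dimensional data: 90 majority-class and 10 minority-class samples
X = np.random.rand(100, 3)
y = np.array([0] * 90 + [1] * 10)

smote = SMOTE(random_state=0)           # interpolates new approximate minority samples
X_res, y_res = smote.fit_resample(X, y)
```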
- S34 Combine the first sample data and the second sample data according to a preset ratio to obtain balanced data.
- The preset ratio is the ratio in which the first sample data and the second sample data are combined. It guarantees that the data composed of the first and second sample data is balanced data; that is, it is a ratio set within the critical-ratio range.
- Specifically, if the critical ratio is 4:1, the server can combine the first sample data and the second sample data at a ratio below the critical ratio to obtain the balanced data.
- In this embodiment, the server under-samples the majority-class data with a clustering algorithm and over-samples the minority-class data with the SMOTE algorithm, then combines the first sample data and the second sample data obtained after sampling at a preset ratio to obtain balanced data. This makes the discarded majority-class data as redundant as possible and the added minority-class data as close as possible to real minority-class data, so the resulting balanced data is diverse, which helps improve the training effect and the accuracy of the classification model's final classification results.
- Further, after step S4, in which the balanced data is used as training samples and the training samples are trained with the preset machine learning algorithm to obtain the classification model, the method further includes step S5, described in detail as follows:
- S5 Optimize the parameters of the classification model according to the preset parameter optimization method to obtain parameter extreme values, where the parameter extreme values are the parameter values of the optimized classification model to improve the classification accuracy of the classification model.
- The parameters of the classification model are mainly hyperparameters and common parameters.
- A hyperparameter is a parameter whose value is set before training; a common parameter is parameter data obtained through training; the parameter extreme values are the optimal values of the hyperparameters and the common parameters.
- Hyperparameters include, but are not limited to, the learning rate, the number of hidden layers, and the depth of the tree in the classification model.
- Common parameters include but are not limited to activation functions, optimization algorithms, regularized parameters, and the number of nodes in the neural network layer.
- For the hyperparameters, the server can first build a binary classification model with the SVM algorithm and then tune the hyperparameters by grid search. Specifically, the server can perform a linear exhaustive search via the kernel functions in the sklearn library, i.e., select subsets one by one from the hyperparameter value space for tuning until the parameter extreme values are found.
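- A hedged sketch of grid-search tuning of an SVM with sklearn (the parameter grid and data are assumptions; the patent only describes exhaustively selecting from the hyperparameter value space):

```python
import numpy as np
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC

X_bal = np.random.rand(200, 3)             # hypothetical balanced training data
y_bal = np.random.randint(0, 2, size=200)

param_grid = {                             # hypothetical hyperparameter value space
    "C": [0.1, 1, 10],
    "kernel": ["linear", "rbf"],
    "gamma": ["scale", 0.1],
}
search = GridSearchCV(SVC(), param_grid, cv=5)   # exhaustively tries every combination
search.fit(X_bal, y_bal)
print(search.best_params_)                        # the "parameter extreme values" found
```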
- the server can use the gradient descent method to optimize, and get the parameter extreme value.
- the gradient descent method is a method for obtaining the extreme value of the objective function by using the first derivative information.
- In this embodiment, the server optimizes the hyperparameters and common parameters of the classification model so that they reach their optimal values, i.e., the parameter extreme values, which helps further improve the classification precision of the classification model and the accuracy of classification.
- Further, step S5, that is, optimizing the parameters of the classification model according to the preset parameter optimization method to obtain the parameter extreme values, specifically includes the following steps:
- S51 Form the preset candidate values of the classification model's parameters into a data set.
- The preset candidate values are the possible values of the classification model's parameters.
- For example, the preset candidate values of the parameter "activation function" include but are not limited to the ReLU, sigmoid, and tanh functions; the preset candidate values of the parameter "optimization algorithm" include but are not limited to the AdaDelta algorithm, the Adam algorithm, and the gradient descent algorithm.
- The data set is a data structure that stores the preset candidate values of the classification model's parameters.
- Specifically, the data set may be a data dictionary, for example one defined as a Python dictionary type.
- A data set stored as a Python dictionary can be expressed as {x: [2:6], y: [1:3], z: [ReLU, sigmoid, tanh]}, where x, y, and z are parameters of the classification model, and [2:6], [1:3], and [ReLU, sigmoid, tanh] indicate that x ranges from 2 to 6, y ranges from 1 to 3, and z takes one of the values ReLU, sigmoid, or tanh.
- S52 According to the cross-validation method, select preset candidate values from the data set to establish at least two deep learning models;
- Cross-validation, also known as rotation estimation, is a practical statistical method of cutting data samples into smaller subsets.
- That is, in a given modeling sample, most of the samples are taken out to build the model, a small portion is held out and predicted with the newly built model, and the prediction errors of that small portion are computed and their sum of squares recorded. For example, the samples may be split into k equal parts, any (k-1) of which are used for modeling and one for validation; k models are thus built, yielding k groups of errors whose sum of squares evaluates the model's precision under the current parameter settings.
- the deep learning model is a neural network model based on deep learning. Among them, deep learning can form a more abstract high-level representation attribute category or feature by combining low-level features to discover the feature representation of the data.
- Select preset candidate values from the data set to build a deep learning model that is, traverse the data in the data set, and select preset candidate values to build a deep learning model.
- For example, taking the data set established in step S51, the server may select one preset candidate value from each of x, y, and z and build a deep learning model with those values.
- Building a deep learning model is a supervised learning process. Understandably, different combinations of preset candidate values construct different deep learning models, i.e., at least two deep learning models.
- Specifically, the server first partitions the balanced data according to the cross-validation method and uses the partitioned balanced data as training samples; it then calls the Sequential function in the sklearn library to build an initial model and trains it with the preset candidate values selected from the data set as the deep learning model's parameters, obtaining at least two deep learning models.
- S53 Use the deep learning model to classify the training samples, and determine the optimal deep learning model according to the classification accuracy of the deep learning model
- Training samples are balanced data obtained after sampling low-dimensional data.
- Since classifying the training samples is a supervised learning process, the server can determine from the classification results which deep learning model has the higher classification accuracy.
- The classification result is the accuracy of classifying the training samples. For example, if the classification accuracies of deep learning models A, B, and C are 90%, 85%, and 80% respectively, then deep learning model A is the optimal deep learning model.
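- A sketch of this select-build-compare loop under stated assumptions (sklearn's MLPClassifier stands in for the patent's deep learning model, and the candidate dictionary and data are hypothetical):

```python
import itertools
import numpy as np
from sklearn.model_selection import cross_val_score
from sklearn.neural_network import MLPClassifier

X_bal = np.random.rand(200, 3)             # hypothetical balanced training samples
y_bal = np.random.randint(0, 2, size=200)

candidates = {                             # hypothetical data set of preset candidate values
    "hidden_layer_sizes": [(16,), (32, 16)],
    "activation": ["relu", "tanh"],
}

best_score, best_params = -1.0, None
for values in itertools.product(*candidates.values()):
    params = dict(zip(candidates.keys(), values))
    model = MLPClassifier(max_iter=300, random_state=0, **params)
    score = cross_val_score(model, X_bal, y_bal, cv=5).mean()  # cross-validated accuracy
    if score > best_score:                 # keep the settings of the most accurate model
        best_score, best_params = score, params

print(best_params)   # candidate values of the optimal model, taken as parameter extreme values
```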
- S54 Use the preset candidate values corresponding to the optimal deep learning model as the parameter extreme values.
- The preset candidate values corresponding to the optimal deep learning model are the preset candidate values selected from the data set when that deep learning model was built in step S52.
- In this embodiment, the server forms the preset candidate values of the classification model's parameters into a data set and, according to the cross-validation method, selects different combinations of preset candidate values from the data set to build multiple deep learning models; it then uses the deep learning models to classify the training samples and determines the optimal deep learning model by classification accuracy, so that the preset candidate values corresponding to the optimal model serve as the classification model's parameter extreme values. That is, by building deep learning models and determining the parameter extreme values from the classification results of different models, the value range of the parameter extreme values is narrowed quickly, making the overall training time of the classification model shorter and training more efficient.
- Further, in an embodiment, the parameters of the classification model further include hyperparameters, and step S5, that is, optimizing the parameters of the classification model according to a preset parameter optimization method to obtain parameter extreme values, further includes the following step:
- S55 Assign values to the hyperparameters by a preset random sampling method, and use the assigned hyperparameters to perform iterative operations on the classification model until the iterative operations are complete, obtaining the parameter extreme values.
- the preset random sampling method refers to randomly selecting several initial values for the hyperparameter within the range of the hyperparameter value.
- the preset random sampling method may adopt a random forest (Random Forest) algorithm.
- An iterative operation means substituting the assigned initial values of the hyperparameters into the classification model for computation so that the classification model tends toward stability.
- Once the classification model stabilizes, the initial values of the hyperparameters at that point are the parameter extreme values.
- Specifically, the server can import the random forest function from the sklearn library, randomly select initial values within the hyperparameter value range through the random forest algorithm, and then iterate until the parameter extreme values are found.
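- The patent's use of a random forest function to draw initial hyperparameter values is unusual; as a stand-in sketch, random sampling over a hyperparameter space is commonly done with sklearn's RandomizedSearchCV (the distributions, model, and data below are assumptions, not the patent's method):

```python
import numpy as np
from scipy.stats import loguniform
from sklearn.model_selection import RandomizedSearchCV
from sklearn.svm import SVC

X_bal = np.random.rand(200, 3)             # hypothetical balanced training data
y_bal = np.random.randint(0, 2, size=200)

param_distributions = {"C": loguniform(1e-2, 1e2), "gamma": loguniform(1e-3, 1e0)}
search = RandomizedSearchCV(
    SVC(), param_distributions, n_iter=20, cv=5, random_state=0  # 20 random initial values
)
search.fit(X_bal, y_bal)
print(search.best_params_)
```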
- In this embodiment, the server selects initial values within the hyperparameter value range through a preset random sampling method. Because the random forest algorithm is used, the probability of selecting initial values that satisfy the parameter-extreme-value condition is improved, which reduces the number of iterative operations and speeds up computation.
- Further, "until the iterative operation is completed" in step S55 specifically includes the following steps:
- S551 Use a preset loss function to detect the Nth operation result of the iterative operation, where N is a positive integer greater than 0.
- A loss function is a function in statistics and statistical decision theory that maps an event, or an element of a sample space, to a real number expressing the economic cost or opportunity cost associated with that event.
- the loss function is a verification function for evaluating whether the iterative operation is completed, that is, when the loss function converges, it is determined that the iterative operation is completed.
- Specifically, the loss function can be the MSE function, i.e., the mean square error function.
- MSE (Mean Square Error) is computed as the sum of squared distances between the predicted values and the true values.
- the server takes the single operation result of the iterative operation as input, substitutes it into the loss function, and finds the limit value of the loss function.
- S552 If the Nth operation result makes the preset loss function converge, determine that the iterative operation is complete; if the Nth operation result does not make the preset loss function converge, use the preset loss function to check the (N+1)th operation result of the iterative operation, until the preset loss function converges.
- Specifically, when the server computes the limit value of the loss function, the loss function has converged; at that point the iterative operation is determined to be complete, and the initial values of the hyperparameters are the parameter extreme values. If the server cannot compute the limit value of the loss function, the loss function has not converged; the server then takes the next operation result of the iterative operation as input and recomputes the limit value of the loss function until the loss function converges.
- In this embodiment, the server judges the operation results of the iterative operation through the loss function to determine whether the iteration is complete; that is, whether the iterative operation is finished is judged indirectly through the convergence of the loss function, which is faster and more efficient than determining whether the classification model is optimized by directly analyzing the classification results.
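- A minimal sketch of this convergence check, assuming an incrementally trained classifier, binary 0/1 labels, and an MSE-style loss (the tolerance, loop bound, model, and data are assumptions):

```python
import numpy as np
from sklearn.linear_model import SGDClassifier

X_bal = np.random.rand(200, 3)                    # hypothetical balanced training data
y_bal = np.random.randint(0, 2, size=200)         # binary 0/1 labels assumed

clf = SGDClassifier(loss="log_loss", random_state=0)  # loss="log" in older sklearn versions
classes = np.unique(y_bal)

prev_loss, tol = None, 1e-4
for n in range(1, 201):                           # the Nth operation of the iteration
    clf.partial_fit(X_bal, y_bal, classes=classes)
    prob = clf.predict_proba(X_bal)[:, 1]
    loss = float(np.mean((y_bal - prob) ** 2))    # MSE-style loss on the predictions
    if prev_loss is not None and abs(prev_loss - loss) < tol:
        break                                     # loss converged: iteration complete
    prev_loss = loss
```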
- an unbalanced data classification model training device is provided, and the unbalanced data classification model training device corresponds to the unbalanced data classification model training method in the foregoing embodiment in a one-to-one correspondence.
- As shown in FIG. 7, the device for training an unbalanced data classification model includes a data acquisition module 61, a dimensionality reduction module 62, a sampling module 63, and a training module 64.
- the detailed description of each functional module is as follows:
- the data acquisition module 61 is used to acquire unbalanced data from a preset sample library
- the dimensionality reduction module 62 is used to perform dimensionality reduction processing on the unbalanced data according to a preset dimensionality reduction method to obtain low-dimensional data after dimensionality reduction;
- the sampling module 63 is configured to sample low-dimensional data according to a preset sampling method to obtain balanced data
- the training module 64 is configured to use the balanced data as training samples, and use a preset machine learning algorithm to train the training samples to obtain a classification model.
- the dimensionality reduction module 62 includes:
- the self-encoding model establishment sub-module 621 is used to establish a self-encoding model with a preset number of layers;
- the feature extraction sub-module 622 is used to perform feature extraction on unbalanced data using the self-encoding model to obtain hidden features of the unbalanced data, and use the hidden features as low-dimensional data after dimensionality reduction.
- sampling module 63 includes:
- the quantity calculation sub-module 631 is used to calculate the total quantity of low-dimensional data and the sub-quantity of low-dimensional data corresponding to different data types;
- the under-sampling sub-module 632 is configured to, if the ratio of a sub-quantity to the total quantity exceeds a preset threshold, treat the data of the data type corresponding to that sub-quantity as majority-class data, and under-sample the majority-class data to obtain first sample data;
- the over-sampling sub-module 633 is configured to, if the ratio of a sub-quantity to the total quantity does not exceed the preset threshold, treat the data of the data type corresponding to that sub-quantity as minority-class data, and over-sample the minority-class data to obtain second sample data;
- the sample combination sub-module 634 is configured to combine the first sample data and the second sample data according to a preset ratio to obtain balanced data.
- the device for training an unbalanced data classification model further includes:
- the parameter optimization module 65 is used to optimize the parameters of the classification model according to the preset parameter optimization method to obtain parameter extreme values, where the parameter extreme values are the parameter values of the optimized classification model to improve the classification accuracy of the classification model .
- parameter optimization module 65 includes:
- the collection setting sub-module 651 is used for composing the preset candidate values of the parameters of the classification model into a data collection
- the deep learning model establishment sub-module 652 is configured to select preset candidate values from the data set to establish at least two deep learning models according to the cross-validation method;
- the classification sub-module 653 is used to classify the training samples using the deep learning model, and determine the optimal deep learning model according to the classification accuracy of the deep learning model;
- the extreme value selection sub-module 654 is used to use the preset candidate value corresponding to the optimal deep learning model as the parameter extreme value.
- parameter optimization module 65 also includes:
- the hyperparameter optimization sub-module 655 is used to assign values to hyperparameters through random sampling, and use the assigned hyperparameters to perform iterative operations on the classification model until the iterative operation is completed and the extreme values of the parameters are obtained.
- hyperparameter optimization sub-module 655 includes:
- the loss function detection unit 6551 is configured to detect the Nth operation result of the iterative operation using a preset loss function, where N is a positive integer greater than 0;
- the convergence calculation unit 6552 is configured to determine that the iterative operation is complete if the Nth operation result makes the preset loss function converge, and, if the Nth operation result does not make the preset loss function converge, to use the preset loss function to check the (N+1)th operation result of the iterative operation until the preset loss function converges.
- Each module in the above-mentioned unbalanced data classification model training device can be implemented in whole or in part by software, hardware, and a combination thereof.
- The foregoing modules may be embedded in hardware form in, or be independent of, the processor of the computer device, or may be stored in software form in the memory of the computer device, so that the processor can invoke and execute the operations corresponding to each module.
- a computer device is provided.
- The computer device may be a server, and its internal structure diagram may be as shown in FIG. 8.
- The computer device includes a processor, a memory, a network interface, and a database connected through a system bus.
- the processor of the computer device is used to provide calculation and control capabilities.
- the memory of the computer device includes a non-volatile storage medium and an internal memory.
- the non-volatile storage medium stores an operating system, computer readable instructions, and a database.
- the internal memory provides an environment for the operation of the operating system and computer-readable instructions in the non-volatile storage medium.
- the network interface of the computer device is used to communicate with an external terminal through a network connection.
- the computer-readable instructions are executed by the processor to realize an imbalanced data classification model training method.
- In an embodiment, a computer device is provided, including a memory, a processor, and computer-readable instructions stored in the memory and executable on the processor. When the processor executes the computer-readable instructions, the steps of the unbalanced data classification model training method in the above embodiments are implemented, for example steps S1 to S4 shown in FIG. 2.
- Alternatively, when the processor executes the computer-readable instructions, the functions of the modules/units of the unbalanced data classification model training device in the foregoing embodiment are realized, for example the functions of modules 61 to 64 shown in FIG. 7. To avoid repetition, details are not repeated here.
- In an embodiment, one or more non-volatile readable storage media are provided, storing computer-readable instructions.
- When the computer-readable instructions are executed by one or more processors, the unbalanced data classification model training method of the above method embodiments is implemented; alternatively, when the computer-readable instructions are executed by one or more processors, the functions of the modules/units of the unbalanced data classification model training device in the foregoing device embodiment are realized. To avoid repetition, details are not repeated here.
- Non-volatile memory may include read only memory (ROM), programmable ROM (PROM), electrically programmable ROM (EPROM), electrically erasable programmable ROM (EEPROM), or flash memory.
- Volatile memory may include random access memory (RAM) or external cache memory.
- RAM is available in many forms, such as static RAM (SRAM), dynamic RAM (DRAM), synchronous DRAM (SDRAM), double data rate SDRAM (DDR SDRAM), enhanced SDRAM (ESDRAM), Synchlink DRAM (SLDRAM), Rambus direct RAM (RDRAM), direct Rambus dynamic RAM (DRDRAM), and Rambus dynamic RAM (RDRAM), etc.
Abstract
A method, apparatus, computer device, and storage medium for training an unbalanced data classification model. The method includes: obtaining unbalanced data from a preset sample library (S1); performing dimensionality reduction processing on the unbalanced data according to a preset dimensionality reduction method to obtain low-dimensional data after dimensionality reduction (S2); sampling the low-dimensional data according to a preset sampling method to obtain balanced data (S3); and using the balanced data as training samples and training the training samples with a preset machine learning algorithm to obtain a classification model (S4). The method can reduce the misjudgment rate on minority-class data in unbalanced data, thereby improving classification accuracy.
Description
This application is based on, and claims priority to, Chinese invention patent application No. 201910351188.7, filed on April 28, 2019 and entitled "Method, apparatus, device and storage medium for training an unbalanced data classification model".
This application relates to the field of information processing, and in particular to a method, apparatus, device, and storage medium for training an unbalanced data classification model.
In practical applications that classify data with machine learning methods, handling unbalanced data has always been a thorny problem. Unbalanced data means that, during training or classification, the proportions of samples of different categories are uneven. For example, in user fraud detection, fraudulent behaviors account for a far smaller proportion than non-fraudulent behaviors. Unbalanced data exists widely in practical applications such as fault detection, defect detection, network intrusion detection, and medical diagnosis.
In unbalanced data, although the samples of the smaller class are few in number, they still have an important influence on the training or classification results and therefore cannot be ignored as noise. However, the inventors realized that if traditional machine learning methods are applied directly to unbalanced data, the resulting classification rules tend to favor the classes with many samples, leaving the rules for the classes that need the most attention few and weak, so the classification model cannot produce effective classification and accurate classification cannot be achieved.
Summary of the Invention
The embodiments of this application provide a method, apparatus, device, and storage medium for training an unbalanced data classification model, to solve the problem of inaccurate classification when a classification model trained on unbalanced data is used for classification.
A method for training an unbalanced data classification model includes:
obtaining unbalanced data from a preset sample library;
performing dimensionality reduction processing on the unbalanced data according to a preset dimensionality reduction method to obtain low-dimensional data after dimensionality reduction;
sampling the low-dimensional data according to a preset sampling method to obtain balanced data;
using the balanced data as training samples, and training the training samples with a preset machine learning algorithm to obtain a classification model.
An apparatus for training an unbalanced data classification model includes:
a data acquisition module, configured to obtain unbalanced data from a preset sample library;
a dimensionality reduction module, configured to perform dimensionality reduction processing on the unbalanced data according to a preset dimensionality reduction method to obtain low-dimensional data after dimensionality reduction;
a sampling module, configured to sample the low-dimensional data according to a preset sampling method to obtain balanced data;
a training module, configured to use the balanced data as training samples and train the training samples with a preset machine learning algorithm to obtain a classification model.
A computer device includes a memory, a processor, and computer-readable instructions stored in the memory and executable on the processor; when the processor executes the computer-readable instructions, the above method for training an unbalanced data classification model is implemented.
One or more non-volatile readable storage media storing computer-readable instructions which, when executed by one or more processors, cause the one or more processors to execute the above method for training an unbalanced data classification model.
The details of one or more embodiments of this application are set forth in the drawings and description below; other features and advantages of this application will become apparent from the specification, drawings, and claims.
To explain the technical solutions of the embodiments of this application more clearly, the drawings needed in the description of the embodiments are briefly introduced below. Obviously, the drawings described below are only some embodiments of this application; for those of ordinary skill in the art, other drawings can be obtained from these drawings without creative labor.
FIG. 1 is a schematic diagram of an application environment of the method for training an unbalanced data classification model in an embodiment of this application;
FIG. 2 is a flowchart of the method for training an unbalanced data classification model in an embodiment of this application;
FIG. 3 is a flowchart of step S2 of the method for training an unbalanced data classification model in an embodiment of this application;
FIG. 4 is a flowchart of step S3 of the method for training an unbalanced data classification model in an embodiment of this application;
FIG. 5 is a flowchart of optimizing the parameters of the classification model in an embodiment of this application;
FIG. 6 is a flowchart of judging the operation results of the iterative operation in an embodiment of this application;
FIG. 7 is a schematic diagram of the apparatus for training an unbalanced data classification model in an embodiment of this application;
FIG. 8 is a schematic diagram of a computer device in an embodiment of this application.
The technical solutions in the embodiments of this application are described clearly and completely below in conjunction with the drawings in the embodiments. Obviously, the described embodiments are some, not all, of the embodiments of this application. Based on the embodiments in this application, all other embodiments obtained by those of ordinary skill in the art without creative labor fall within the protection scope of this application.
The method for training an unbalanced data classification model provided in this application can be applied in the application environment shown in FIG. 1. The application environment includes a server and a preset sample library, where the preset sample library is a database storing unbalanced data and the server is a computer device that trains on the unbalanced data; the server may be a single server or a server cluster. The server and the preset sample library are connected through a network, which may be wired or wireless. The method for training an unbalanced data classification model provided in the embodiments of this application is applied to the server.
In an embodiment, as shown in FIG. 2, a method for training an unbalanced data classification model is provided, and its specific implementation process includes the following steps:
S1: Obtain unbalanced data from a preset sample library.
The preset sample library is a storage platform for storing unbalanced data. Specifically, it may be a database, including but not limited to various relational or non-relational databases such as MS-SQL, Oracle, MySQL, Sybase, DB2, Redis, MongoDB, HBase, etc.; alternatively, the preset sample library may be a file storing unbalanced data, which is not specifically limited here.
Unbalanced data means that, in a data set, the proportions of data of different categories are uneven. For example, if the ratio of positive to negative samples in the training samples is 9:1, the training samples are unbalanced data. Understandably, in real classification problems the data to be classified may also be unbalanced. Taking user fraud detection as an example, among the user behaviors to be detected, fraudulent behaviors account for a far smaller proportion than non-fraudulent behaviors, so the user behaviors to be detected are also unbalanced data.
Specifically, if the preset sample library is a database, the server can obtain the unbalanced data through SQL statements; if the preset sample library is a file, the server can read the file directly to the server's local storage.
S2: Perform dimensionality reduction processing on the unbalanced data according to a preset dimensionality reduction method to obtain low-dimensional data after dimensionality reduction.
Dimension, or dimensionality, refers to the angle from which sample data is described in machine learning; that is, the characteristics of a sample can be reflected through multiple dimensions. The higher the dimensionality of the sample data, the more associated features it contains, and therefore the harder it is to train on.
For example, in mechanical fault detection, the factors causing equipment failure number in the hundreds; in medical diagnosis, the causes of a single condition number in the hundreds or more. Training on such data directly would bring about the curse of dimensionality, not only slowing training but also making the optimal solution hard to find.
Dimensionality reduction processing means reducing the dimensionality of sample data, turning high-dimensional data into low-dimensional data, while the low-dimensional data obtained after reduction retains as many sample features as possible.
Reducing the dimensionality makes it convenient to visualize, observe, and explore the data, and simplifies training and prediction with machine learning models. For example, after reducing the data to three or fewer dimensions, the data features can be represented in three-dimensional or two-dimensional space, making it convenient to discover some data features intuitively.
The server can perform dimensionality reduction with linear or nonlinear methods. Linear dimensionality reduction includes but is not limited to the PCA (Principal Component Analysis) method; nonlinear dimensionality reduction is mainly divided into kernel-function-based and eigenvalue-based methods, including but not limited to the LLE (Locally Linear Embedding) method.
Specifically, the server can perform the reduction through the dimensionality reduction functions in the sklearn library to obtain the low-dimensional data. sklearn, whose full name is scikit-learn, is a third-party machine learning library based on Python.
For example, if the unbalanced data is 6-dimensional, i.e., its feature vector includes 6 components, the low-dimensional data obtained after reduction may be a 3-dimensional feature vector; that is, the 3 redundant feature components are discarded.
S3: Sample the low-dimensional data according to a preset sampling method to obtain balanced data.
Balanced data is defined relative to unbalanced data: in the data set, the proportions of data of different categories are balanced. The boundary between balanced and unbalanced data can be defined by a critical ratio, i.e., the critical ratio determines whether data is balanced. Taking a binary classification problem as an example, the critical ratio may be 4:1: if the ratio of positive to negative samples is below 4:1, the data set can be considered balanced; otherwise, it is unbalanced data.
Sampling means adjusting the proportions of different categories of data in the low-dimensional data so that the unbalanced data is converted into balanced data.
The server can sample the low-dimensional data with a variety of preset sampling methods. For example, the server can reduce the amount of majority-class data in the low-dimensional data while increasing the amount of minority-class data, so that the ratio between the two reaches balance.
Specifically, the server can remove a certain number of samples from the majority-class data while adding a certain number of samples to the minority-class data, so that the quantity ratio between the majority-class and minority-class data falls below 4:1. When reducing the majority-class data, the server can discard majority-class samples at random; when adding minority-class data, it can copy minority-class samples at random to increase their number.
S4: Use the balanced data as training samples, and train the training samples with a preset machine learning algorithm to obtain a classification model.
The preset machine learning algorithm refers to a training method based on supervised learning, including but not limited to the linear regression, logistic regression, naive Bayes, and SVM algorithms.
Preferably, the server trains with the SVM algorithm to obtain the classification model. SVM, i.e., support vector machine, is a binary classification model.
Specifically, the server can import the SVM-related functions from the Python-based sklearn library to create an SVM classifier, then import the balanced data into the SVM classifier for training, thereby obtaining the classification model.
In this embodiment, dimensionality reduction is performed on the unbalanced data to obtain low-dimensional data, i.e., the redundant features of the unbalanced data are removed, which speeds up computation, reduces storage space, and helps avoid overfitting when the trained classification model is used for classification; sampling the low-dimensional data according to the preset sampling method to obtain balanced data, i.e., converting the unbalanced data into balanced data, can increase the weight of the minority-class samples and strengthen their influence during training; using the balanced data as training samples and training them with the preset machine learning algorithm yields a classification model that is more sensitive to the minority-class samples in the unbalanced data, so that when the trained classification model is used for classification, the misjudgment rate on minority-class data is reduced and the accuracy of classification improved.
Further, in an embodiment, as shown in FIG. 3, step S2, i.e., performing dimensionality reduction processing on the unbalanced data according to a preset dimensionality reduction method to obtain low-dimensional data, specifically includes the following steps:
S21: Build a self-encoding model with a preset number of layers.
The self-encoding model is a classification model based on a self-encoding network. A self-encoding network, also called an autoencoder, is a neural network in the field of unsupervised learning whose goal is to reconstruct its input. A self-encoding network can automatically learn features from unlabeled data and can give a better feature description than the original data, so it has strong feature-learning ability. Unsupervised learning is the opposite of supervised learning: unsupervised learning trains a classification model on data whose labels or classification results are unknown, whereas supervised learning trains a classification model on data whose labels or classification results are known.
The preset number of layers refers to the number of network layers in the self-encoding network excluding the input and output layers, i.e., the number of intermediate layers. For example, the server can build a simple self-encoding model with three preset layers.
Specifically, the server can build a three-layer self-encoding model based on the built-in functions provided by the TensorFlow framework. TensorFlow is an open-source software library for high-performance numerical computation. With TensorFlow, computation can easily be deployed to multiple platforms such as CPUs and GPUs, and to devices including desktops, server clusters, mobile devices, and edge devices. TensorFlow was originally developed by researchers and engineers on the Google Brain team; it provides strong support for machine learning and deep learning, and its flexible numerical computing core is widely used in many other scientific fields.
S22: Use the self-encoding model to extract features from the unbalanced data to obtain hidden features of the unbalanced data, and use the hidden features as the low-dimensional data after dimensionality reduction, where the low-dimensional data is composed of data of different data types.
The hidden features of unbalanced data, like its explicit features, represent characteristics of the unbalanced data; unlike explicit features, hidden features cannot be obtained intuitively from the unbalanced data.
For example, if the explicit feature of unbalanced data A is a, then a can be obtained by simple analysis and statistics on A, while a hidden feature b of A cannot be obtained in the same way. In student sample data, for instance, each student's subject exam score is an explicit feature obtainable by tallying each student's scores, whereas each student's learning ability is a hidden feature that cannot simply be derived from exam results.
Specifically, the server inputs the unbalanced data x into the self-encoding model and obtains output data m, which can be expressed as m = f(x), where f is the encoding function of the self-encoding model and m is related to the number of intermediate layers of the model. The server can then compute the hidden features with the formula c = g(m), where g is the decoding function of the self-encoding model and c is the hidden features of the unbalanced data x. From these two formulas it can be understood that the more layers the self-encoding model has, the more hidden features there are, and conversely, the fewer.
Taking 6-dimensional unbalanced data as an example, if the self-encoding model is set to three layers, the low-dimensional data obtained after reduction is 3-dimensional.
The low-dimensional data is composed of data of different data types. For example, it may include type-A and type-B data; or it may include type-A, type-B, or type-C data, and so on.
In this embodiment, the unbalanced data is input into the established self-encoding classification model, the hidden features of the unbalanced data are computed according to the model's decoding function, and these serve as the low-dimensional data, realizing dimensionality reduction of the unbalanced data. Because the low-dimensional data is composed of the hidden features of the unbalanced data, the reduced low-dimensional data retains as many effective features of the original data as possible.
Further, in an embodiment, as shown in FIG. 4, step S3, i.e., sampling the low-dimensional data according to a preset sampling method to obtain balanced data, specifically includes the following steps:
S31: Calculate the total quantity of the low-dimensional data and the sub-quantities of the low-dimensional data corresponding to the different data types.
Since the low-dimensional data is composed of data of different data types, the total quantity of the low-dimensional data is the sum of the numbers of data of each data type, and the number of data of each individual data type is called a sub-quantity of the low-dimensional data.
For example, a binary-classification data set contains two data types, type-A data and type-B data, with 50 type-A data and 20 type-B data; the total quantity of the low-dimensional data is then 70, and the sub-quantities are 50 (the number of type-A data) and 20 (the number of type-B data).
Specifically, the server can distinguish the different data types in the low-dimensional data according to preset data-type field names and count the data under each field name, thereby obtaining the total quantity of the low-dimensional data and the sub-quantity of the low-dimensional data corresponding to each data type.
S32: If the ratio of a sub-quantity to the total quantity exceeds a preset threshold, treat the data of the data type corresponding to that sub-quantity as majority-class data, and under-sample the majority-class data to obtain first sample data.
In unbalanced data, the amounts of data of different data types may differ greatly; the data of the type accounting for the majority is called majority-class data, and the data of the type accounting for the minority is called minority-class data.
Whether data belongs to the majority class or the minority class can be determined by the relationship between the ratio of the low-dimensional data's sub-quantity to its total quantity and a preset threshold. The preset threshold can be determined by the number of data types in the low-dimensional data: if there are 2 data types, the preset threshold is 1/2; if there are 3, it is 1/3; and so on. The preset threshold is a fraction whose denominator is the number of data types in the low-dimensional data and whose numerator is 1. It should be noted that this is only one implementation of the preset threshold, and this application includes but is not limited to this implementation.
For example, low-dimensional data contains two data types A and B, with 50 and 20 data respectively, so the preset threshold is 1/2. Since the ratio of type A's sub-quantity of 50 to the total quantity of 70 exceeds 1/2, type-A data is majority-class data and, correspondingly, type-B data is minority-class data.
Understandably, if the low-dimensional data includes three data types A, B, and C, the preset threshold is 1/3; thus, among the three data types, if the ratio of any type's sub-quantity to the total quantity of the low-dimensional data exceeds 1/3, the data of that type is majority-class data.
The data types in the unbalanced data can be set in advance; that is, the data in the unbalanced data has already been type-labeled in the preset sample library. For example, in training sample data for network intrusion detection, 90% of the samples are normal traffic data whose data type is M, and 10% are abnormal traffic data whose data type is N, where M and N are values of the data-type field of the sample data in the preset sample library.
Under-sampling, or down-sampling, is a sampling method that discards part of the sample data.
The first sample data is the data obtained after under-sampling.
The server can use various clustering algorithms to cluster the majority-class data and then remove the redundant data within each cluster to obtain the first sample data.
Specifically, the server can call the k-means clustering function in the sklearn library to cluster the majority-class data. In each cluster obtained, a centroid and a certain number of samples within the centroid's range are selected; the data outside is discarded, and the remaining data is the first sample data.
K-Means, also called the K-means algorithm, is a clustering algorithm that takes the number of clusters k and a database containing n data objects as input and outputs k clusters satisfying the minimum-variance criterion; the number of samples selected within each centroid's range corresponds to the amount of majority-class data to be discarded in under-sampling.
S33: If the ratio of a sub-quantity to the total quantity does not exceed the preset threshold, treat the data of the data type corresponding to that sub-quantity as minority-class data, and over-sample the minority-class data to obtain second sample data.
Similar to the way the majority-class data is obtained in step S32, the server can determine the minority-class data according to the relationship between the ratio of a sub-quantity of the low-dimensional data to its total quantity and the preset threshold.
Over-sampling, or up-sampling, is a sampling method that increases the number of minority-class samples.
The second sample data is the data obtained after over-sampling.
The server can augment the minority-class data in various ways to obtain the second sample data.
Specifically, the server can over-sample the minority-class data through the SMOTE algorithm; that is, approximate minority-class data is constructed by interpolation, thereby achieving over-sampling.
SMOTE, i.e., Synthetic Minority Over-sampling Technique, is an improved algorithm based on the random over-sampling algorithm.
Specifically, the server can implement SMOTE through the built-in functions of the imbalanced-learn library in Python. imbalanced-learn is a third-party Python library for sampling unbalanced data.
S34: Combine the first sample data and the second sample data according to a preset ratio to obtain balanced data.
The preset ratio is the ratio in which the first sample data and the second sample data are combined.
The preset ratio guarantees that the data composed of the first and second sample data is balanced data; that is, it is a ratio set within the critical-ratio range.
Specifically, if the critical ratio is 4:1, the server can combine the first and second sample data at a ratio below the critical ratio to obtain balanced data.
In this embodiment, the server under-samples the majority-class data with a clustering algorithm and over-samples the minority-class data with the SMOTE algorithm, then combines the first and second sample data obtained after sampling at a preset ratio to obtain balanced data. This makes the discarded majority-class data as redundant as possible and the added minority-class data as close as possible to real minority-class data, so that the resulting balanced data is diverse, which helps improve the training effect and the accuracy of the classification model's final classification results.
Further, in an embodiment, after step S4, i.e., after using the balanced data as training samples and training them with the preset machine learning algorithm to obtain the classification model, the method further includes step S5, detailed as follows:
S5: Optimize the parameters of the classification model according to a preset parameter optimization method to obtain parameter extreme values, where the parameter extreme values are the parameter values of the optimized classification model, used to improve the classification precision of the classification model.
The parameters of the classification model are mainly hyperparameters and common parameters.
A hyperparameter is a parameter whose value is set before training; a common parameter is parameter data obtained through training; the parameter extreme values are the optimal values of the hyperparameters and common parameters. Hyperparameters include but are not limited to the learning rate, the number of hidden layers, and the tree depth in the classification model. Common parameters include but are not limited to the activation function, the optimization algorithm, regularization parameters, and the number of nodes in the neural network layers.
For the hyperparameters, the server can first build a binary classification model with the SVM algorithm and then tune the hyperparameters by grid search. Specifically, the server can perform a linear exhaustive search via the kernel functions in the sklearn library, i.e., select subsets one by one from the hyperparameter value space for tuning until the parameter extreme values are found.
For the common parameters, the server can optimize with the gradient descent method to obtain the parameter extreme values. The gradient descent method is a method that uses first-derivative information to find the extreme value of an objective function.
In this embodiment, the server optimizes the hyperparameters and common parameters of the classification model so that they reach their optimal values, i.e., the parameter extreme values, which helps further improve the classification precision of the classification model and the accuracy of classification.
Further, in an embodiment, as shown in FIG. 5, step S5, i.e., optimizing the parameters of the classification model according to a preset parameter optimization method to obtain parameter extreme values, specifically includes the following steps:
S51: Form the preset candidate values of the classification model's parameters into a data set.
The preset candidate values are the possible values of the classification model's parameters. For example, the preset candidate values of the parameter "activation function" include but are not limited to the ReLU, sigmoid, and tanh functions; the preset candidate values of the parameter "optimization algorithm" include but are not limited to the AdaDelta algorithm, the Adam algorithm, and the gradient descent algorithm.
The data set is a data structure that stores the preset candidate values of the classification model's parameters. Specifically, the data set may be a data dictionary, for example one defined as a Python dictionary type.
Specifically, a data set stored as a Python dictionary can be expressed as {x: [2:6], y: [1:3], z: [ReLU, sigmoid, tanh]}, where x, y, and z are parameters of the classification model, and [2:6], [1:3], and [ReLU, sigmoid, tanh] indicate, respectively, that x ranges from 2 to 6, y ranges from 1 to 3, and z takes a value among ReLU, sigmoid, and tanh.
S52: According to the cross-validation method, select preset candidate values from the data set to build at least two deep learning models.
Cross-validation, also called rotation estimation, is a practical statistical method of cutting data samples into smaller subsets: within a given modeling sample, most of the samples are taken out to build the model, a small portion is held out and forecast with the newly built model, and the forecast errors of that small portion are computed and their sum of squares recorded.
For example, the samples are divided into k equal parts; any (k-1) parts are used for modeling and one part for validation. In this way k models can be built, yielding k groups of errors whose squares are summed, and this sum of squares evaluates the model's precision under the current parameter settings.
A deep learning model is a neural network model built on deep learning. Deep learning can form more abstract high-level representations of attribute categories or features by combining low-level features, so as to discover feature representations of the data.
Selecting preset candidate values from the data set to build a deep learning model means traversing the data in the data set and selecting preset candidate values with which to build the model. For example, taking the data set established in step S51 as an example, the server can select one preset candidate value from each of x, y, and z and build a deep learning model with them. Building a deep learning model is a supervised learning process. Understandably, different combinations of preset candidate values construct different deep learning models, i.e., at least two deep learning models.
Specifically, the server first partitions the balanced data according to the cross-validation method and uses the partitioned balanced data as training samples, then calls the Sequential function in the sklearn library to build an initial model and trains it with the preset candidate values selected from the data set as the deep learning model's parameters, obtaining at least two deep learning models.
S53: Use the deep learning models to classify the training samples, and determine the optimal deep learning model according to the models' classification accuracy.
The training samples are the balanced data obtained after sampling the low-dimensional data.
Since classifying the training samples is a supervised learning process, the server can determine from the classification results which deep learning model has the higher classification accuracy.
The classification result is the accuracy of classifying the training samples. For example, if the classification accuracies of deep learning models A, B, and C are 90%, 85%, and 80% respectively, then deep learning model A is the optimal deep learning model.
S54: Use the preset candidate values corresponding to the optimal deep learning model as the parameter extreme values.
The preset candidate values corresponding to the optimal deep learning model are the preset candidate values selected from the data set when that model was built in step S52.
In this embodiment, the server forms the preset candidate values of the classification model's parameters into a data set and, according to the cross-validation method, selects different combinations of preset candidate values from the data set to build multiple deep learning models; it then uses the deep learning models to classify the training samples and determines the optimal model by classification accuracy, so that the preset candidate values corresponding to the optimal model serve as the classification model's parameter extreme values. That is, by building deep learning models and determining the parameter extreme values from the classification results of different models, the value range of the parameter extreme values is quickly narrowed, so the overall training time of the classification model is shorter and training efficiency higher.
Further, in an embodiment, the parameters of the classification model further include hyperparameters, and step S5, i.e., optimizing the parameters of the classification model according to a preset parameter optimization method to obtain parameter extreme values, further includes the step:
S55: Assign values to the hyperparameters by a preset random sampling method, and use the assigned hyperparameters to perform iterative operations on the classification model until the iterative operations are complete, obtaining the parameter extreme values.
The preset random sampling method means randomly selecting several initial values for the hyperparameters within the hyperparameter value range. Preferably, the preset random sampling method may adopt the random forest (Random Forest) algorithm.
An iterative operation means substituting the assigned initial values of the hyperparameters into the classification model for computation so that the classification model tends toward stability; once the classification model stabilizes, the initial values of the hyperparameters at that point are the parameter extreme values.
Specifically, the server can import the random forest function in the sklearn library, randomly select initial values within the hyperparameter value range via the random forest algorithm, and then iterate until the parameter extreme values are found.
In this embodiment, the server selects initial values within the hyperparameter value range by a preset random sampling method. Because the random forest algorithm is adopted, the accuracy of selecting initial values that satisfy the parameter-extreme-value condition is improved, which can reduce the number of iterative operations and speed up computation.
Further, in an embodiment, "until the iterative operation is completed" in step S55 specifically includes the following steps:
S551: Use a preset loss function to check the Nth operation result of the iterative operation, where N is a positive integer greater than 0.
A loss function is a function in statistics and statistical decision theory that maps an event, or an element of a sample space, to a real number expressing the economic cost or opportunity cost associated with that event.
In this embodiment, the loss function is a verification function for evaluating whether the iterative operation is complete; that is, when the loss function converges, the iterative operation is determined to be complete.
Specifically, the loss function can be the MSE function, i.e., the mean square error function. MSE (Mean Square Error) is computed as the sum of squared distances between the predicted values and the true values.
Specifically, the server takes a single operation result of the iterative operation as input, substitutes it into the loss function, and finds the limit value of the loss function.
S552: If the Nth operation result makes the preset loss function converge, determine that the iterative operation is complete; if the Nth operation result does not make the preset loss function converge, use the preset loss function to check the (N+1)th operation result of the iterative operation, until the preset loss function converges.
Specifically, when the server computes the limit value of the loss function, the loss function has converged; at that point the iterative operation is determined to be complete, and the initial values of the hyperparameters are the parameter extreme values. If the server cannot compute the limit value of the loss function, the loss function cannot converge; in that case, the server takes the next operation result of the iterative operation as input and recomputes the limit value of the loss function until the loss function converges.
In this embodiment, the server judges the operation results of the iterative operation through the loss function to determine whether the iteration is complete; that is, whether the iterative operation is finished is judged indirectly through the convergence of the loss function, which is faster and more efficient than determining whether the classification model is optimized by directly analyzing the classification results.
It should be understood that the sequence numbers of the steps in the above embodiments do not imply an order of execution; the execution order of each process should be determined by its function and internal logic, and should not constitute any limitation on the implementation of the embodiments of this application.
In an embodiment, a training apparatus for an imbalanced-data classification model is provided, corresponding one-to-one with the training method for an imbalanced-data classification model in the above embodiments. As shown in FIG. 6, the apparatus includes a data acquisition module 61, a dimensionality reduction module 62, a sampling module 63, and a training module 64. The functional modules are detailed as follows:
the data acquisition module 61 is configured to obtain imbalanced data from a preset sample library;
the dimensionality reduction module 62 is configured to perform dimensionality reduction on the imbalanced data according to a preset dimension-reduction method, obtaining dimension-reduced low-dimensional data;
the sampling module 63 is configured to sample the low-dimensional data according to a preset sampling method, obtaining balanced data;
the training module 64 is configured to use the balanced data as training samples and train the training samples with a preset machine learning algorithm into a classification model.
Further, the dimensionality reduction module 62 includes:
an autoencoder building submodule 621, configured to build an autoencoder with a preset number of layers;
a feature extraction submodule 622, configured to extract features from the imbalanced data with the autoencoder, obtain the latent features of the imbalanced data, and use the latent features as the dimension-reduced low-dimensional data.
Further, the sampling module 63 includes:
a count calculation submodule 631, configured to compute the total count of the low-dimensional data and the subcounts of the low-dimensional data corresponding to different data types;
an undersampling submodule 632, configured to, if the ratio of a subcount to the total count exceeds a preset threshold, treat the data of the corresponding data type as majority-class data and undersample the majority-class data into first sample data;
an oversampling submodule 633, configured to, if the ratio of a subcount to the total count does not exceed the preset threshold, treat the data of the corresponding data type as minority-class data and oversample the minority-class data into second sample data;
a sample combination submodule 634, configured to combine the first sample data and the second sample data at a preset ratio into balanced data.
Further, the training apparatus for an imbalanced-data classification model also includes:
a parameter optimization module 65, configured to optimize the parameters of the classification model according to a preset parameter optimization method to obtain parameter extrema, where the parameter extrema are the parameter values of the optimized classification model and serve to improve its classification accuracy.
Further, the parameter optimization module 65 includes:
a set assembly submodule 651, configured to assemble the preset candidate values of the classification model's parameters into a data set;
a deep-learning-model building submodule 652, configured to build at least two deep learning models by selecting preset candidate values from the data set according to cross-validation;
a classification submodule 653, configured to classify the training samples with the deep learning models and determine the optimal deep learning model from their classification accuracy;
an extremum selection submodule 654, configured to take the preset candidate values corresponding to the optimal deep learning model as the parameter extrema.
Further, the parameter optimization module 65 also includes:
a hyperparameter optimization submodule 655, configured to assign values to the hyperparameters by a preset random sampling method and run iterative computations on the classification model with the assigned hyperparameters until the iteration completes, obtaining the parameter extrema.
Further, the hyperparameter optimization submodule 655 includes:
a loss-function detection unit 6551, configured to detect the result of the Nth iteration with a preset loss function, where N is a positive integer;
a convergence calculation unit 6552, configured to determine that the iteration is complete if the Nth result makes the preset loss function converge, and otherwise to detect the (N+1)th iteration result with the preset loss function, until the preset loss function converges.
For the specific limitations of the training apparatus for an imbalanced-data classification model, see the limitations of the training method above, which are not repeated here. The modules of the apparatus may be implemented wholly or partly in software, hardware, or a combination of the two; the modules may be embedded in, or independent of, a processor of a computer device in hardware form, or stored in the memory of the computer device in software form, so that the processor can invoke and execute the operations corresponding to each module.
In one embodiment, a computer device is provided; it may be a server, and its internal structure may be as shown in FIG. 7. The computer device includes a processor, a memory, a network interface, and a database connected by a system bus. The processor of the computer device provides computing and control capability. The memory includes a non-volatile storage medium and an internal memory; the non-volatile storage medium stores an operating system, computer-readable instructions, and a database, and the internal memory provides the runtime environment for the operating system and the computer-readable instructions stored in the non-volatile storage medium. The network interface communicates with external terminals over a network. When executed by the processor, the computer-readable instructions implement a training method for an imbalanced-data classification model.
In one embodiment, a computer device is provided, including a memory, a processor, and computer-readable instructions stored in the memory and executable on the processor. When executing the computer-readable instructions, the processor implements the steps of the training method for an imbalanced-data classification model in the above embodiments, for example steps S1 to S4 shown in FIG. 2; alternatively, it implements the functions of the modules/units of the training apparatus in the above embodiments, for example modules 61 to 64 shown in FIG. 6. To avoid repetition, details are not repeated here.
In an embodiment, one or more non-volatile readable storage media are provided, storing computer-readable instructions. When executed by one or more processors, the computer-readable instructions implement the training method for an imbalanced-data classification model of the method embodiments above, or the functions of the modules/units of the training apparatus of the apparatus embodiments above. To avoid repetition, details are not repeated here.
Those of ordinary skill in the art will understand that all or part of the flows of the above method embodiments can be completed by computer-readable instructions directing the relevant hardware; the computer-readable instructions may be stored in one or more non-volatile readable storage media and, when executed, may include the flows of the above method embodiments. Any reference to memory, storage, a database, or another medium in the embodiments of this application may include non-volatile and/or volatile memory. Non-volatile memory may include read-only memory (ROM), programmable ROM (PROM), electrically programmable ROM (EPROM), electrically erasable programmable ROM (EEPROM), or flash memory. Volatile memory may include random access memory (RAM) or an external cache. By way of illustration and not limitation, RAM is available in many forms, such as static RAM (SRAM), dynamic RAM (DRAM), synchronous DRAM (SDRAM), double-data-rate SDRAM (DDR SDRAM), enhanced SDRAM (ESDRAM), Synchlink DRAM (SLDRAM), Rambus direct RAM (RDRAM), direct Rambus dynamic RAM (DRDRAM), and Rambus dynamic RAM (RDRAM).
Those skilled in the art will clearly understand that, for convenience and brevity of description, the division into the above functional units and modules is only an example; in practice, the above functions may be assigned to different functional units or modules as needed, that is, the internal structure of the apparatus may be divided into different functional units or modules to complete all or part of the functions described above.
The above embodiments only illustrate the technical solutions of this application and do not limit them. Although this application has been described in detail with reference to the foregoing embodiments, those of ordinary skill in the art should understand that the technical solutions recorded in the foregoing embodiments may still be modified, or some of their technical features replaced by equivalents, without departing in essence from the spirit and scope of the technical solutions of the embodiments of this application; all such modifications and replacements shall fall within the protection scope of this application.
Claims (20)
- A training method for an imbalanced-data classification model, characterized in that the training method comprises: obtaining, from a preset sample library, imbalanced data comprising fraud-behavior data and non-fraud-behavior data; performing dimensionality reduction separately on the fraud-behavior data and the non-fraud-behavior data in the imbalanced data according to a preset dimension-reduction method, determining latent features in the fraud-behavior data and the non-fraud-behavior data, and obtaining, from the latent features, dimension-reduced low-dimensional data of the fraud-behavior data and the non-fraud-behavior data; sampling the low-dimensional data according to a preset sampling method to obtain balanced data; and using the balanced data as training samples and training the training samples with a preset machine learning algorithm to obtain a classification model.
- The training method of claim 1, characterized in that performing dimensionality reduction separately on the fraud-behavior data and the non-fraud-behavior data in the imbalanced data according to the preset dimension-reduction method, determining the latent features in the fraud-behavior data and the non-fraud-behavior data, and obtaining, from the latent features, the dimension-reduced low-dimensional data of the fraud-behavior data and the non-fraud-behavior data comprises: building an autoencoder with a preset number of layers; and extracting features from the imbalanced data with the autoencoder to obtain the latent features of the imbalanced data, and using the latent features as the feature dimensions of the dimension-reduced low-dimensional data of the fraud-behavior data and the non-fraud-behavior data, wherein the low-dimensional data is composed of data of different data types.
- The training method of claim 2, characterized in that sampling the low-dimensional data according to the preset sampling method to obtain the balanced data comprises: computing the total count of the low-dimensional data and the subcounts of the low-dimensional data corresponding to different data types; if the ratio of a subcount to the total count exceeds a preset threshold, treating the data of the data type corresponding to the subcount as majority-class data and undersampling the majority-class data to obtain first sample data; if the ratio of a subcount to the total count does not exceed the preset threshold, treating the data of the data type corresponding to the subcount as minority-class data and oversampling the minority-class data to obtain second sample data; and combining the first sample data and the second sample data at a preset ratio to obtain the balanced data.
- The training method of claim 1, characterized in that, after using the balanced data as training samples and training the training samples with the preset machine learning algorithm to obtain the classification model, the training method further comprises: optimizing the parameters of the classification model according to a preset parameter optimization method to obtain parameter extrema, wherein the parameter extrema are the parameter values of the optimized classification model and serve to improve the classification accuracy of the classification model.
- The training method of claim 4, characterized in that optimizing the parameters of the classification model according to the preset parameter optimization method to obtain the parameter extrema comprises: assembling the preset candidate values of the parameters of the classification model into a data set; building at least two deep learning models by selecting the preset candidate values from the data set according to cross-validation; classifying the training samples with the deep learning models and determining the optimal deep learning model from the classification accuracy of the deep learning models; and taking the preset candidate values corresponding to the optimal deep learning model as the parameter extrema.
- The training method of claim 4, characterized in that the parameters of the classification model comprise hyperparameters, and optimizing the parameters of the classification model according to the preset parameter optimization method to obtain the parameter extrema further comprises: assigning values to the hyperparameters by a preset random sampling method, and running iterative computations on the classification model with the assigned hyperparameters until the iteration completes, to obtain the parameter extrema.
- The training method of claim 6, characterized in that "until the iteration completes" comprises: detecting the result of the Nth iteration with a preset loss function, where N is a positive integer; if the Nth result makes the preset loss function converge, determining that the iteration is complete; and if the Nth result does not make the preset loss function converge, detecting the (N+1)th iteration result with the preset loss function, until the preset loss function converges.
- A training apparatus for an imbalanced-data classification model, characterized in that the training apparatus comprises: a data acquisition module, configured to obtain, from a preset sample library, imbalanced data comprising fraud-behavior data and non-fraud-behavior data; a dimensionality reduction module, configured to perform dimensionality reduction separately on the fraud-behavior data and the non-fraud-behavior data in the imbalanced data according to a preset dimension-reduction method, determine latent features in the fraud-behavior data and the non-fraud-behavior data, and obtain, from the latent features, dimension-reduced low-dimensional data of the fraud-behavior data and the non-fraud-behavior data; a sampling module, configured to sample the low-dimensional data according to a preset sampling method to obtain balanced data; and a training module, configured to use the balanced data as training samples and train the training samples with a preset machine learning algorithm to obtain a classification model.
- The training apparatus of claim 8, characterized in that the dimensionality reduction module comprises: an autoencoder building submodule, configured to build an autoencoder with a preset number of layers; and a feature extraction submodule, configured to extract features from the imbalanced data with the autoencoder to obtain the latent features of the imbalanced data, and use the latent features as the feature dimensions of the dimension-reduced low-dimensional data of the fraud-behavior data and the non-fraud-behavior data, wherein the low-dimensional data is composed of data of different data types.
- The training apparatus of claim 9, characterized in that the sampling module comprises: a count calculation submodule, configured to compute the total count of the low-dimensional data and the subcounts of the low-dimensional data corresponding to different data types; an undersampling submodule, configured to, if the ratio of a subcount to the total count exceeds a preset threshold, treat the data of the data type corresponding to the subcount as majority-class data and undersample the majority-class data to obtain first sample data; an oversampling submodule, configured to, if the ratio of a subcount to the total count does not exceed the preset threshold, treat the data of the data type corresponding to the subcount as minority-class data and oversample the minority-class data to obtain second sample data; and a sample combination submodule, configured to combine the first sample data and the second sample data at a preset ratio to obtain the balanced data.
- The training apparatus of claim 8, characterized in that the training apparatus further comprises: a parameter optimization module, configured to optimize the parameters of the classification model according to a preset parameter optimization method to obtain parameter extrema, wherein the parameter extrema are the parameter values of the optimized classification model and serve to improve the classification accuracy of the classification model.
- The training apparatus of claim 11, characterized in that the parameter optimization module comprises: a set assembly submodule, configured to assemble the preset candidate values of the parameters of the classification model into a data set; a deep-learning-model building submodule, configured to build at least two deep learning models by selecting the preset candidate values from the data set according to cross-validation; a classification submodule, configured to classify the training samples with the deep learning models and determine the optimal deep learning model from the classification accuracy of the deep learning models; and an extremum selection submodule, configured to take the preset candidate values corresponding to the optimal deep learning model as the parameter extrema.
- The training apparatus of claim 11, characterized in that the parameters of the classification model comprise hyperparameters, and the parameter optimization module further comprises: a hyperparameter optimization submodule, configured to assign values to the hyperparameters by a preset random sampling method and run iterative computations on the classification model with the assigned hyperparameters until the iteration completes, to obtain the parameter extrema.
- The training apparatus of claim 13, characterized in that the hyperparameter optimization submodule comprises: a loss-function detection unit, configured to detect the result of the Nth iteration with a preset loss function, where N is a positive integer; and a convergence calculation unit, configured to determine that the iteration is complete if the Nth result makes the preset loss function converge, and, if the Nth result does not make the preset loss function converge, detect the (N+1)th iteration result with the preset loss function until the preset loss function converges.
- A computer device, comprising a memory, a processor, and computer-readable instructions stored in the memory and executable on the processor, characterized in that, when the processor executes the computer-readable instructions, the following steps are implemented: obtaining, from a preset sample library, imbalanced data comprising fraud-behavior data and non-fraud-behavior data; performing dimensionality reduction separately on the fraud-behavior data and the non-fraud-behavior data in the imbalanced data according to a preset dimension-reduction method, determining latent features in the fraud-behavior data and the non-fraud-behavior data, and obtaining, from the latent features, dimension-reduced low-dimensional data of the fraud-behavior data and the non-fraud-behavior data; sampling the low-dimensional data according to a preset sampling method to obtain balanced data; and using the balanced data as training samples and training the training samples with a preset machine learning algorithm to obtain a classification model.
- The computer device of claim 15, characterized in that performing dimensionality reduction separately on the fraud-behavior data and the non-fraud-behavior data in the imbalanced data according to the preset dimension-reduction method, determining the latent features in the fraud-behavior data and the non-fraud-behavior data, and obtaining, from the latent features, the dimension-reduced low-dimensional data of the fraud-behavior data and the non-fraud-behavior data comprises: building an autoencoder with a preset number of layers; and extracting features from the imbalanced data with the autoencoder to obtain the latent features of the imbalanced data, and using the latent features as the feature dimensions of the dimension-reduced low-dimensional data of the fraud-behavior data and the non-fraud-behavior data, wherein the low-dimensional data is composed of data of different data types.
- The computer device of claim 16, characterized in that sampling the low-dimensional data according to the preset sampling method to obtain the balanced data comprises: computing the total count of the low-dimensional data and the subcounts of the low-dimensional data corresponding to different data types; if the ratio of a subcount to the total count exceeds a preset threshold, treating the data of the data type corresponding to the subcount as majority-class data and undersampling the majority-class data to obtain first sample data; if the ratio of a subcount to the total count does not exceed the preset threshold, treating the data of the data type corresponding to the subcount as minority-class data and oversampling the minority-class data to obtain second sample data; and combining the first sample data and the second sample data at a preset ratio to obtain the balanced data.
- One or more non-volatile readable storage media storing computer-readable instructions, characterized in that, when the computer-readable instructions are executed by one or more processors, the one or more processors perform the following steps: obtaining, from a preset sample library, imbalanced data comprising fraud-behavior data and non-fraud-behavior data; performing dimensionality reduction separately on the fraud-behavior data and the non-fraud-behavior data in the imbalanced data according to a preset dimension-reduction method, determining latent features in the fraud-behavior data and the non-fraud-behavior data, and obtaining, from the latent features, dimension-reduced low-dimensional data of the fraud-behavior data and the non-fraud-behavior data; sampling the low-dimensional data according to a preset sampling method to obtain balanced data; and using the balanced data as training samples and training the training samples with a preset machine learning algorithm to obtain a classification model.
- The non-volatile readable storage media of claim 18, characterized in that performing dimensionality reduction separately on the fraud-behavior data and the non-fraud-behavior data in the imbalanced data according to the preset dimension-reduction method, determining the latent features in the fraud-behavior data and the non-fraud-behavior data, and obtaining, from the latent features, the dimension-reduced low-dimensional data of the fraud-behavior data and the non-fraud-behavior data comprises: building an autoencoder with a preset number of layers; and extracting features from the imbalanced data with the autoencoder to obtain the latent features of the imbalanced data, and using the latent features as the feature dimensions of the dimension-reduced low-dimensional data of the fraud-behavior data and the non-fraud-behavior data, wherein the low-dimensional data is composed of data of different data types.
- The non-volatile readable storage media of claim 19, characterized in that sampling the low-dimensional data according to the preset sampling method to obtain the balanced data comprises: computing the total count of the low-dimensional data and the subcounts of the low-dimensional data corresponding to different data types; if the ratio of a subcount to the total count exceeds a preset threshold, treating the data of the data type corresponding to the subcount as majority-class data and undersampling the majority-class data to obtain first sample data; if the ratio of a subcount to the total count does not exceed the preset threshold, treating the data of the data type corresponding to the subcount as minority-class data and oversampling the minority-class data to obtain second sample data; and combining the first sample data and the second sample data at a preset ratio to obtain the balanced data.
Applications Claiming Priority (2)

| Application Number | Priority Date | Filing Date | Title |
| --- | --- | --- | --- |
| CN201910351188.7 | 2019-04-28 | | |
| CN201910351188.7A (granted as CN110163261B) | 2019-04-28 | 2019-04-28 | Method, apparatus, device and storage medium for training a classification model on imbalanced data |

Publications (1)

| Publication Number | Publication Date |
| --- | --- |
| WO2020220544A1 | 2020-11-05 |
Legal Events

| Date | Code | Title | Description |
| --- | --- | --- | --- |
| | 121 | Ep: the epo has been informed by wipo that ep was designated in this application | Ref document number: 19926898; Country of ref document: EP; Kind code of ref document: A1 |
| | NENP | Non-entry into the national phase | Ref country code: DE |
| | 122 | Ep: pct application non-entry in european phase | Ref document number: 19926898; Country of ref document: EP; Kind code of ref document: A1 |