WO2022252458A1 - Classification model training method and apparatus, device, and medium - Google Patents

Classification model training method and apparatus, device, and medium Download PDF

Info

Publication number
WO2022252458A1
Authority
WO
WIPO (PCT)
Prior art keywords
graph
training
vertex
classification model
matrix
Prior art date
Application number
PCT/CN2021/121905
Other languages
French (fr)
Chinese (zh)
Inventor
胡克坤
董刚
赵雅倩
刘海威
徐哲
Original Assignee
苏州浪潮智能科技有限公司
Priority date
Filing date
Publication date
Application filed by 苏州浪潮智能科技有限公司
Publication of WO2022252458A1

Classifications

    • G PHYSICS
    • G06 COMPUTING; CALCULATING OR COUNTING
    • G06F ELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00 Pattern recognition
    • G06F18/20 Analysing
    • G06F18/21 Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214 Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • G PHYSICS
    • G06 COMPUTING; CALCULATING OR COUNTING
    • G06F ELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00 Pattern recognition
    • G06F18/20 Analysing
    • G06F18/24 Classification techniques
    • G PHYSICS
    • G06 COMPUTING; CALCULATING OR COUNTING
    • G06N COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00 Computing arrangements based on biological models
    • G06N3/02 Neural networks
    • G06N3/04 Architecture, e.g. interconnection topology
    • G06N3/045 Combinations of networks
    • G PHYSICS
    • G06 COMPUTING; CALCULATING OR COUNTING
    • G06N COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00 Computing arrangements based on biological models
    • G06N3/02 Neural networks
    • G06N3/08 Learning methods
    • G PHYSICS
    • G06 COMPUTING; CALCULATING OR COUNTING
    • G06N COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00 Computing arrangements based on biological models
    • G06N3/02 Neural networks
    • G06N3/08 Learning methods
    • G06N3/088 Non-supervised learning, e.g. competitive learning

Definitions

  • the present application relates to the technical field of classifiers, and in particular to a classification model training method, device, equipment and medium.
  • a graph neural network usually consists of an input layer, one or more hidden layers, and an output layer.
  • Fig. 1 is a graph neural network structure diagram from the prior art; it shows a typical graph convolutional neural network consisting of an input layer (Input layer), two graph convolution layers (Gconv layers), and an output layer (Output layer).
  • the input layer reads the n*d-dimensional vertex feature matrix
  • the graph convolution layer performs feature extraction on the vertex feature matrix, and the result is passed to the next graph convolution layer after a nonlinear activation function such as ReLU.
  • finally, the output layer, i.e. the task layer, completes specific tasks such as vertex classification or clustering.
  • Figure 1 shows a vertex classification task layer that outputs the category label of each vertex. At present, how to improve the classification accuracy is a problem that needs to be solved.
  • the purpose of the present application is to provide a classification model training method, device, equipment and medium, which can improve the classification accuracy of the classification model.
  • the specific plan is as follows:
  • the present application discloses a classification model training method, including:
  • the vertex label matrix includes label information for each vertex of the graph data set
  • the vertex feature matrix, the adjacency matrix and the vertex label matrix are input to the Teacher graph wavelet neural network in the classification model to carry out supervised training, and determine the corresponding supervised training loss in the training process;
  • the vertex feature matrix and the adjacency matrix are input to the Student graph wavelet neural network in the classification model to carry out unsupervised training, and determine the corresponding unsupervised training loss in the training process;
  • the current classification model is output to obtain the trained classification model.
  • determining the corresponding supervised training loss in the training process includes:
  • the corresponding supervised training loss is determined based on the first vertex label prediction result of the Teacher graph wavelet neural network and the vertex label matrix;
  • the corresponding unsupervised training loss is determined in the training process, including:
  • a corresponding unsupervised training loss is determined based on the second vertex label prediction result of the Student graph wavelet neural network and the first vertex label prediction result.
  • the current vertex label matrix is output to obtain the category prediction result of each vertex without a category label.
  • the method also includes:
  • the Teacher graph wavelet neural network and the Student graph wavelet neural network perform graph convolution operations based on the graph wavelet transform base and the graph wavelet inverse transform base.
  • the method also includes:
  • the calculation formula is a formula defined based on spectral theory.
  • both the Teacher graph wavelet neural network and the Student graph wavelet neural network include an input layer, several graph convolution layers, and an output layer;
  • the graph convolution layer is used to sequentially perform feature transformation and graph convolution operation processing on the input data of the layer during the training process.
  • the method also includes:
  • the convolution kernel of the graph convolution layer obtained through the training of the Teacher graph wavelet neural network is used to determine the convolution kernel of the corresponding graph convolution layer in the Student graph wavelet neural network based on the attention mechanism.
  • a classification model training device including:
  • a training data construction module configured to construct a vertex feature matrix, an adjacency matrix, and a vertex label matrix based on a graph data set; wherein, the vertex label matrix includes label information for each vertex of the graph data set;
  • the classification model training module is used to input the vertex feature matrix, the adjacency matrix, and the vertex label matrix into the Teacher graph wavelet neural network in the classification model for supervised training, and to determine the corresponding supervised training loss during training; to input the vertex feature matrix and the adjacency matrix into the Student graph wavelet neural network in the classification model for unsupervised training, and to determine the corresponding unsupervised training loss during training; to determine the target training loss based on the supervised training loss and the unsupervised training loss; and, when the target training loss converges, to output the current classification model to obtain the trained classification model.
  • an electronic device comprising:
  • a processor configured to execute the computer program, so as to realize the aforementioned classification model training method.
  • the present application discloses a computer-readable storage medium for storing a computer program, and when the computer program is executed by a processor, the aforementioned classification model training method is implemented.
  • the present application first constructs a vertex feature matrix, an adjacency matrix, and a vertex label matrix based on a graph data set, where the vertex label matrix includes the label information of each vertex of the graph data set; then inputs the vertex feature matrix, the adjacency matrix, and the vertex label matrix into the Teacher graph wavelet neural network in the classification model for supervised training, determining the corresponding supervised training loss during training; inputs the vertex feature matrix and the adjacency matrix into the Student graph wavelet neural network in the classification model for unsupervised training, determining the corresponding unsupervised training loss during training; determines the target training loss based on the supervised training loss and the unsupervised training loss; and, when the target training loss converges, outputs the current classification model to obtain the trained classification model.
  • in this way, the vertex feature matrix and adjacency matrix of the graph data set are input into the graph neural network for training, making use of both the graph topology and the vertex features.
  • supervised training and unsupervised training are combined so that the respective advantages of each come into full play, which can improve the classification accuracy of the classification model.
  • Fig. 1 is a graph neural network structure diagram from the prior art
  • Fig. 2 is a flow chart of a classification model training method disclosed in the present application
  • Fig. 3 is a flow chart of a specific classification model training method disclosed in the present application
  • Fig. 4 is a structure diagram of a classification model disclosed in the present application
  • Fig. 5 is a structure diagram of a specific classification model disclosed in the present application
  • Fig. 6 is a flow chart of a specific classification model training method disclosed in the present application
  • Fig. 7 is a schematic structural diagram of a classification model training device disclosed in the present application
  • Fig. 8 is a structure diagram of an electronic device disclosed in the present application
  • the embodiment of the present application discloses a classification model training method, including:
  • Step S11 Construct a vertex feature matrix, an adjacency matrix, and a vertex label matrix based on the graph dataset; wherein, the vertex label matrix includes label information for each vertex of the graph dataset;
  • the label information indicates a corresponding category label or no category label.
  • V represents a vertex set
  • E represents the set of connected edges.
  • each vertex v of G has d features, and the features of all vertices constitute an n*d-dimensional vertex feature matrix, which is denoted as X.
  • the adjacency matrix of G is denoted as A, and the element A_ij represents the weight of the edge between vertices i and j.
  • n = |V| represents the number of vertices in the graph, C represents the number of label categories, and the matrix element Y_ij indicates whether the category label of vertex i is j (j = 1, 2, …, C): when vertex i already has a category label, the corresponding j-th column element of its row is set to 1 and the other column elements to 0, that is, Y_ij = 1 if the category label of vertex i is j, and Y_ij = 0 otherwise; when vertex i has no category label, every column element of its row is set to 0.
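  • as an illustration, the three matrices can be constructed as in the following sketch (Python/NumPy; the function and variable names are illustrative, not from the application):

```python
import numpy as np

def build_training_matrices(n, d, C, edges, features, labels):
    # n: number of vertices, d: features per vertex, C: number of categories.
    # `edges` is an iterable of (i, j, weight) tuples; `labels` maps each
    # labeled vertex index to its category in {0, ..., C-1}; unlabeled
    # vertices are simply absent from it.
    X = np.asarray(features, dtype=np.float32).reshape(n, d)  # vertex feature matrix
    A = np.zeros((n, n), dtype=np.float32)                    # weighted adjacency matrix
    for i, j, w in edges:
        A[i, j] = A[j, i] = w                                 # A_ij: weight of edge (i, j)
    Y = np.zeros((n, C), dtype=np.float32)                    # vertex label matrix
    for i, c in labels.items():
        Y[i, c] = 1.0          # one-hot row for labeled vertices; all-zero row otherwise
    return X, A, Y
```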
  • Step S12 Input the vertex feature matrix, the adjacency matrix and the vertex label matrix into the Teacher graph wavelet neural network in the classification model for supervised training, and determine the corresponding supervised training loss during the training process.
  • Step S13 Input the vertex feature matrix and the adjacency matrix into the Student graph wavelet neural network in the classification model for unsupervised training, and determine the corresponding unsupervised training loss during the training process.
  • the corresponding supervised training loss is determined based on the first vertex label prediction result of the Teacher graph wavelet neural network and the vertex label matrix; the corresponding unsupervised training loss is determined based on the second vertex label prediction result of the Student graph wavelet neural network and the first vertex label prediction result.
  • the first vertex label prediction result is compared with the vertex label matrix to calculate a supervised training loss
  • the second vertex label prediction result is compared with the first vertex label prediction result to calculate an unsupervised learning loss
  • Step S14 Determine a target training loss based on the supervised training loss and the unsupervised training loss.
  • the target training loss is calculated as ls = ls_T + α · ls_S, where:
  • ls_T represents the supervised training loss
  • ls_S represents the unsupervised training loss
  • α is a constant used to adjust the proportion of the unsupervised training loss in the target loss
  • Z_T represents the first vertex label prediction result
  • Z_S represents the second vertex label prediction result.
  • the output layer of the Teacher graph wavelet neural network and of the Student graph wavelet neural network can be defined as Z = softmax(ψ_r F_L ψ_r⁻¹ Q_L), where ψ_r is the graph wavelet transform basis, ψ_r⁻¹ is the graph wavelet inverse transform basis, F_L represents the convolution kernel matrix of the L-th graph convolution layer, and Q_L represents the vertex feature transformation result of the L-th layer; the Teacher graph wavelet neural network and the Student graph wavelet neural network each include L graph convolution layers.
  • the supervised training loss function calculates, based on the cross-entropy principle, the degree of difference between the actual and predicted label probability distributions of the vertices; the unsupervised training loss function calculates the sum of squared differences between the same-coordinate elements of Z_T and Z_S.
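  • a minimal sketch consistent with the description above (cross entropy for the supervised part, a sum of squared differences for the unsupervised part, weighted by α); the function name and the use of all-zero rows of Y to mark unlabeled vertices are illustrative:

```python
import torch

def target_training_loss(Z_T, Z_S, Y, alpha):
    # Z_T, Z_S: n*C matrices of predicted label probabilities (Teacher, Student).
    # Y: n*C vertex label matrix; all-zero rows mark unlabeled vertices.
    labeled = Y.sum(dim=1) > 0
    # supervised loss ls_T: cross entropy between the actual and predicted
    # label probability distributions of the labeled vertices
    ls_T = -(Y[labeled] * torch.log(Z_T[labeled] + 1e-12)).sum(dim=1).mean()
    # unsupervised loss ls_S: sum of squared differences between the
    # same-coordinate elements of Z_T and Z_S
    ls_S = ((Z_T - Z_S) ** 2).sum()
    return ls_T + alpha * ls_S
```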
  • when training of the whole network ends, the output results Z_T and Z_S of the two networks are consistent, or their difference is negligible.
  • the output Z_T of the Teacher graph wavelet neural network can be used as the output of the entire network model.
  • during training, the vertex label matrix is updated using the first vertex label prediction result; specifically, for each vertex without a category label, i.e. for v_i ∈ V_U, the category with the highest probability in the first vertex label prediction result is taken as the latest category of that vertex, and the vertex label matrix is updated accordingly.
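  • a sketch of this update (PyTorch; names are illustrative):

```python
import torch

def update_label_matrix(Y, Z_T, unlabeled):
    # For each vertex in V_U (boolean mask `unlabeled`), take the category
    # with the highest probability in the Teacher prediction Z_T as the
    # vertex's latest category and refresh its one-hot row in Y.
    Y = Y.clone()
    best = Z_T.argmax(dim=1)
    Y[unlabeled] = 0.0
    Y[unlabeled, best[unlabeled]] = 1.0
    return Y
```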
  • Step S15 When the target training loss converges, output the current classification model to obtain the post-training classification model.
  • the current vertex label matrix is output to obtain the category prediction result of each vertex without a category label.
  • when the target training loss reaches a preset threshold or the number of iterations reaches a specified maximum, the target training loss is considered converged and the training ends.
  • the preset threshold is usually a small value; at this point, for a vertex without a category label, the category to which it belongs is obtained from the current vertex label matrix.
  • this application integrates the prediction of unlabeled vertices into the training process: during the training process, the vertex label matrix is updated according to each training result, and the category label of any unlabeled vertex can be obtained after the training is completed.
  • the network parameters of each layer of the graph wavelet neural network may first be initialized according to a specific strategy, such as random initialization with a normal distribution, Xavier initialization, or He initialization.
  • a specific strategy such as SGD (stochastic gradient descent), MGD (momentum gradient descent), Nesterov momentum, AdaGrad (adaptive gradient), RMSProp (root mean square propagation), Adam (adaptive moment estimation), or BGD (batch gradient descent) may then be used to correct and update the network parameters of each layer of the graph wavelet neural network, so as to optimize the value of the loss function.
  • to sum up, a vertex feature matrix, an adjacency matrix, and a vertex label matrix are constructed based on the graph data set, where the vertex label matrix includes the label information of each vertex of the graph data set; the vertex feature matrix, the adjacency matrix, and the vertex label matrix are then input into the Teacher graph wavelet neural network in the classification model for supervised training, and the corresponding supervised training loss is determined during training; the vertex feature matrix and the adjacency matrix are input into the Student graph wavelet neural network in the classification model for unsupervised training, and the corresponding unsupervised training loss is determined during training; the target training loss is determined based on the supervised training loss and the unsupervised training loss; when the target training loss converges, the current classification model is output to obtain the trained classification model.
  • in this way, the vertex feature matrix and adjacency matrix of the graph data set are input into the graph neural network for training, making use of both the graph topology and the vertex features.
  • supervised training and unsupervised training are combined so that the respective advantages of each come into full play, which can improve the classification accuracy of the classification model.
  • the embodiment of the present application discloses a specific classification model training method, including:
  • Step S21 Obtain the calculation formula of the graph wavelet transform basis.
  • the calculation formula is a formula defined based on spectral theory.
  • Step S22 Calculate the graph wavelet transform basis and the graph wavelet inverse transform basis of the graph data set by using Chebyshev polynomials.
  • ψ_r = U H_r U^T, where H_r = diag(h(rλ_1), h(rλ_2), …, h(rλ_n)) is the scaling matrix with scaling scale r, and U and λ_1, …, λ_n are the eigenvector matrix and eigenvalues obtained by eigendecomposing the Laplacian matrix of graph G; the graph wavelet inverse transform basis ψ_r⁻¹ can be obtained by replacing h(rλ_i) in ψ_r with h(-rλ_i).
  • the Teacher graph wavelet neural network and the Student graph wavelet neural network perform graph convolution operations based on the graph wavelet transform base and the graph wavelet inverse transform base.
  • in the prior art, the graph Fourier transform is inefficient in graph convolution operations because the eigenvector matrix of the Laplacian matrix is dense; this embodiment instead performs graph convolution operations based on the graph wavelet transform basis and the graph wavelet inverse transform basis.
  • the graph wavelet transform base and graph wavelet inverse transform base are sparse, so the computational efficiency of graph convolution operations can be improved.
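  • as an illustration, the following sketch (Python/NumPy) approximates the wavelet basis with a truncated Chebyshev expansion, assuming the common heat-kernel scaling function h(x) = e^(-rx); the order K, λ_max = 2, and all names are illustrative assumptions:

```python
import numpy as np

def chebyshev_wavelet_basis(L, r, K=3, lam_max=2.0):
    # Approximate psi_r, whose spectral definition applies the scaling
    # function h to the Laplacian eigenvalues, by a degree-K Chebyshev
    # expansion of h(x) = exp(-r*x); no eigendecomposition is formed.
    # Passing -r instead of r yields the inverse-transform basis.
    n = L.shape[0]
    Ls = (2.0 / lam_max) * L - np.eye(n)        # rescale spectrum to [-1, 1]
    # Chebyshev coefficients of h over [0, lam_max] at the Chebyshev nodes
    x = np.cos(np.pi * (np.arange(K + 1) + 0.5) / (K + 1))
    fx = np.exp(-r * 0.5 * lam_max * (x + 1.0))
    c = [2.0 / (K + 1) * np.sum(fx * np.cos(k * np.arccos(x))) for k in range(K + 1)]
    T_prev, T_curr = np.eye(n), Ls              # T_0 and T_1 of the recurrence
    psi = 0.5 * c[0] * T_prev + c[1] * T_curr
    for k in range(2, K + 1):
        T_prev, T_curr = T_curr, 2.0 * Ls @ T_curr - T_prev
        psi += c[k] * T_curr
    return psi
```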
  • Step S23 Construct a vertex feature matrix, an adjacency matrix, and a vertex label matrix based on the graph data set; wherein, the vertex label matrix includes label information of each vertex of the graph data set, and the label information represents the corresponding category label or none category label.
  • Step S24 Input the vertex feature matrix, the adjacency matrix and the vertex label matrix to the Teacher graph wavelet neural network in the classification model for supervised training, and determine the corresponding supervised training loss during the training process.
  • Step S25 Input the vertex feature matrix and the adjacency matrix into the Student graph wavelet neural network in the classification model for unsupervised training, and determine the corresponding unsupervised training loss during the training process.
  • both the Teacher graph wavelet neural network and the Student graph wavelet neural network include an input layer, several graph convolution layers, and an output layer; during training, the input data of each graph convolution layer is sequentially processed by feature transformation and then by the graph convolution operation.
  • it may include 1 input layer, L (L ≥ 1) graph convolution layers, and an output layer.
  • the graph convolution layer first performs feature transformation on its input data and then the graph convolution operation; splitting the graph convolution layer into the two stages of feature transformation and graph convolution reduces the number of network parameters, thereby reducing the amount of model computation and improving model training efficiency.
  • X represents the vertex feature matrix
  • m represents the ordinal number of the graph convolution layer
  • F is the graph convolution kernel matrix
  • h is the activation function.
  • without this separation, the number of parameters is n*p*q, where n represents the number of vertices in the graph, p represents the vertex feature dimension of the layer's input, and q represents the vertex feature dimension of the layer's output.
  • once the feature transformation is separated from the graph convolution operation, the number of parameters of each graph convolution layer becomes n + p*q.
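  • a minimal PyTorch sketch of such a two-stage layer follows; the ReLU activation, the initialization, and all names are assumptions, and ψ_r, ψ_r⁻¹ are assumed precomputed:

```python
import torch
import torch.nn as nn

class GraphWaveletConv(nn.Module):
    # One graph convolution layer split into the two stages described above:
    # (1) feature transformation Q = X @ W with a p*q weight matrix, then
    # (2) graph convolution H = h(psi @ diag(f) @ psi_inv @ Q),
    # where f is the layer's n-element diagonal convolution kernel, giving
    # n + p*q parameters instead of n*p*q.
    def __init__(self, n, p, q, psi, psi_inv):
        super().__init__()
        self.W = nn.Parameter(torch.randn(p, q) * 0.01)  # feature transformation
        self.f = nn.Parameter(torch.ones(n))             # diagonal convolution kernel
        self.register_buffer("psi", psi)                 # graph wavelet transform basis
        self.register_buffer("psi_inv", psi_inv)         # inverse transform basis

    def forward(self, X):
        Q = X @ self.W                                   # stage 1: feature transformation
        H = self.psi @ (self.f.unsqueeze(1) * (self.psi_inv @ Q))
        return torch.relu(H)                             # stage 2: graph convolution
```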
  • in the training process, the convolution kernels of the graph convolution layers obtained by training the Teacher graph wavelet neural network are used, based on the attention mechanism, to determine the convolution kernels of the corresponding graph convolution layers in the Student graph wavelet neural network.
  • the classification model may include a Teacher graph wavelet neural network, a Student graph wavelet neural network, and an attention network connecting each pair of the Teacher graph wavelet neural network and the Student graph wavelet neural network.
  • let F_l be the graph convolution kernel matrix of layer l, which is a diagonal matrix; from the perspective of signal processing, the elements (f_1, f_2, …, f_n) on its diagonal can be regarded as a filter acting on the n vertex signal components. Denote the layer-l convolution kernel matrices of the Teacher graph wavelet neural network and the Student graph wavelet neural network as T_l and S_l respectively; the Teacher network's convolution kernel t_l and the Student network's convolution kernel s_l are obtained from their diagonals, and both are n-dimensional column vectors.
  • attention transfer can then be performed based on the attention mechanism: each layer of the Teacher graph wavelet neural network transfers its learned convolution kernel to the corresponding layer of the Student graph wavelet neural network, that is, the Student graph wavelet neural network learns from the Teacher graph wavelet neural network, improving the performance of the whole network.
  • here e_l(i) represents the i-th component of e_l, e′_l(i) represents the i-th component of e′_l, and s′_l represents the layer-l convolution kernel learned by the Student graph wavelet neural network from the Teacher graph wavelet neural network.
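  • the following PyTorch sketch is one hypothetical realization of this transfer: an attention parameter vector a_l scores each kernel component (e_l), normalizes the scores (e′_l), and blends the two kernels into s′_l; the exact scoring and blending used by the application may differ:

```python
import torch
import torch.nn as nn

class KernelAttentionTransfer(nn.Module):
    # Hypothetical per-layer attention transfer: the parameter vector a_l
    # scores each of the n kernel components from the Teacher kernel t_l and
    # the Student kernel s_l (scores e_l), the scores are normalized (e'_l),
    # and the Student adopts the blended kernel s'_l.
    def __init__(self):
        super().__init__()
        self.a = nn.Parameter(torch.randn(2) * 0.1)      # attention parameters a_l

    def forward(self, t, s):
        e = torch.stack([t, s], dim=1) @ self.a          # e_l(i): score of component i
        w = torch.sigmoid(e)                             # e'_l(i): normalized score
        return w * t + (1.0 - w) * s                     # s'_l: kernel learned from the Teacher
```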
  • Step S26 Determine a target training loss based on the supervised training loss and the unsupervised training loss.
  • Step S27 When the target training loss converges, output the current classification model to obtain the trained classification model.
  • FIG. 4 is a structure diagram of a classification model disclosed in the embodiment of the present application.
  • Fig. 4 shows the Teacher graph wavelet neural network GWN_T and the Student graph wavelet neural network GWN_S.
  • FIG. 5 is a specific classification model structure diagram disclosed in the embodiment of the present application.
  • the classification model consists of a Teacher graph wavelet neural network GWN_T, a Student graph wavelet neural network GWN_S, and an attention network connecting each pair of graph convolution layers of the two networks.
  • GWN_T performs supervised learning based on the labeled graph vertices, and its prediction accuracy is high; GWN_S uses the unlabeled graph vertices to perform unsupervised learning under the guidance of GWN_T (using its prediction results), so as to obtain a better vertex classification model.
  • the attention network is used by GWN_T to transfer the "knowledge" learned by each layer, that is, the convolution kernel, to the corresponding layer of GWN_S; in other words, GWN_S learns from GWN_T.
  • both GWN_T and GWN_S contain 1 input layer, L graph convolution layers, and 1 output layer.
  • the input layer is mainly used to read the graph data to be classified, including the adjacency matrix A, which represents the topology of the graph, and the vertex feature matrix X.
  • in the graph convolution layer, the graph convolution operation is decomposed into two stages: feature transformation and graph convolution.
  • the output layer is used to output prediction results.
  • the network parameters of each layer include the feature transformation matrices (of both the Teacher graph wavelet neural network and the Student graph wavelet neural network), the convolution kernels (the convolution kernel t_l and the convolution kernel s_l, which are then used to update the convolution kernel matrices F_l), and the attention network parameters a_l.
  • the aforementioned network parameters are initialized, and during the training process, the aforementioned network parameters are updated.
  • the embodiment of the present application discloses a flow chart of a specific classification model training method.
  • for a given graph data set G, its adjacency matrix A, vertex feature matrix X, and vertex label matrix Y are constructed.
  • these are fed into the network for forward propagation, and the prediction results of all vertices belonging to each category are calculated.
  • the loss of the supervised learning part and the loss of the unsupervised learning part are calculated to obtain the total network loss.
  • according to the loss function value, the network parameters of each layer are updated following a certain strategy, until the network error reaches a specified minimum or the number of iterations reaches the specified maximum, at which point the training ends.
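  • putting the earlier sketches together, a minimal training loop consistent with this flow might look as follows; the assumption that `model(X)` returns the pair (Z_T, Z_S) and the choice of Adam (one of the optimizers listed above) are illustrative:

```python
import torch

def train(model, X, Y, alpha, lr=0.01, max_iters=200, min_loss=1e-4):
    # Reuses target_training_loss and update_label_matrix sketched earlier;
    # `model(X)` is assumed to return the Teacher and Student predictions.
    unlabeled = Y.sum(dim=1) == 0                 # V_U: vertices without labels
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    for _ in range(max_iters):
        Z_T, Z_S = model(X)                       # forward propagation
        loss = target_training_loss(Z_T, Z_S, Y, alpha)
        opt.zero_grad()
        loss.backward()
        opt.step()                                # update per-layer parameters
        Y = update_label_matrix(Y, Z_T.detach(), unlabeled)
        if loss.item() < min_loss:                # error reached the specified minimum
            break
    return model, Y
```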
  • the method based on the embodiment of the present application utilizes a collection of scientific papers to train a classification model and predict category labels of unlabeled scientific papers.
  • the network parameters are initialized.
  • each network calculates the output feature matrix of each layer according to the definition of the graph convolution layer, combined with the layer's input feature matrix; according to the definition of the output layer, it calculates the prediction results Z_T or Z_S of all vertices belonging to each category, and computes the supervised learning loss function value and the unsupervised learning loss function value according to the network loss functions defined above, thereby obtaining the loss function value of the entire network; for each unlabeled vertex, the category with the highest probability is taken as the vertex's latest category, and the vertex label matrix Y is updated.
  • this application is not limited to the scientific citation classification problems listed in the examples; it can also be applied to any data classification problem that is conveniently modeled and represented by graphs, such as proteins or graphic images; to studying how infectious diseases, ideas, and opinions spread and diffuse through social networks over time; to studying how groups in social networks form communities around specific interests or affiliations, and the strength of community connections; to discovering people with similar interests, on the principle that "birds of a feather flock together", and suggesting or recommending new links or connections to them; to question-answering systems that direct questions to those with the most relevant experience; and to advertising systems that show ads to the individuals most interested in and willing to receive advertisements on a particular topic, etc.
  • a classification model training device including:
  • a training data construction module 11, configured to construct a vertex feature matrix, an adjacency matrix, and a vertex label matrix based on a graph data set, wherein the vertex label matrix includes the label information of each vertex of the graph data set;
  • a classification model training module 12, configured to input the vertex feature matrix, the adjacency matrix, and the vertex label matrix into the Teacher graph wavelet neural network in the classification model for supervised training and determine the corresponding supervised training loss during training; to input the vertex feature matrix and the adjacency matrix into the Student graph wavelet neural network in the classification model for unsupervised training and determine the corresponding unsupervised training loss during training; to determine the target training loss based on the supervised training loss and the unsupervised training loss; and, when the target training loss converges, to output the current classification model to obtain the trained classification model.
  • it can be seen that a vertex feature matrix, an adjacency matrix, and a vertex label matrix are constructed based on the graph data set, where the vertex label matrix includes the label information of each vertex of the graph data set; the vertex feature matrix, the adjacency matrix, and the vertex label matrix are then input into the Teacher graph wavelet neural network in the classification model for supervised training, and the corresponding supervised training loss is determined during training; the vertex feature matrix and the adjacency matrix are input into the Student graph wavelet neural network in the classification model for unsupervised training, and the corresponding unsupervised training loss is determined during training; the target training loss is determined based on the supervised training loss and the unsupervised training loss; when the target training loss converges, the current classification model is output to obtain the trained classification model.
  • in this way, the vertex feature matrix and adjacency matrix of the graph data set are input into the graph neural network for training, making use of both the graph topology and the vertex features.
  • supervised training and unsupervised training are combined so that the respective advantages of each come into full play, which can improve the classification accuracy of the classification model.
  • the classification model training module 12 is specifically configured, during training, to determine the corresponding supervised training loss based on the first vertex label prediction result of the Teacher graph wavelet neural network and the vertex label matrix, and to determine the corresponding unsupervised training loss based on the second vertex label prediction result of the Student graph wavelet neural network and the first vertex label prediction result.
  • the classification model training module 12 is further configured to update the vertex label matrix using the first vertex label prediction result during training, and, when the target training loss converges, to output the current vertex label matrix to obtain the category prediction result of each vertex without a category label.
  • the device further includes a graph wavelet transform basis calculation module, configured to calculate the graph wavelet transform basis and the graph wavelet inverse transform basis of the graph data set using Chebyshev polynomials; correspondingly, the Teacher graph wavelet neural network and the Student graph wavelet neural network perform graph convolution operations during training based on the graph wavelet transform basis and the graph wavelet inverse transform basis.
  • the device further includes a graph wavelet transform basis formula acquisition module, configured to acquire the calculation formula of the graph wavelet transform basis, wherein the calculation formula is a formula defined based on spectral theory.
  • both the Teacher graph wavelet neural network and the Student graph wavelet neural network include an input layer, several graph convolution layers, and an output layer;
  • the graph convolution layer is used to sequentially perform feature transformation and graph convolution operation processing on the input data of the layer during the training process.
  • the classification model training module 12 is further configured, during training, to determine the convolution kernels of the corresponding graph convolution layers in the Student graph wavelet neural network, based on the attention mechanism, from the convolution kernels of the graph convolution layers trained by the Teacher graph wavelet neural network.
  • the embodiment of the present application discloses an electronic device 20, including a processor 21 and a memory 22, wherein the memory 22 is used to store a computer program, and the processor 21 is used to execute the computer program to implement the classification model training method disclosed in the foregoing embodiments.
  • the memory 22, as a resource storage carrier, may be a read-only memory, a random access memory, a magnetic disk, an optical disk, or the like, and the storage may be temporary or permanent.
  • the electronic device 20 further includes a power supply 23, a communication interface 24, an input/output interface 25, and a communication bus 26; the power supply 23 is used to provide the working voltage for each hardware device of the electronic device 20; the communication interface 24 can create a data transmission channel between the electronic device 20 and an external device, following any communication protocol applicable to the technical solution of the present application, which is not specifically limited here; the input/output interface 25 is used to obtain external input data or to output data externally, and its specific interface type can be selected according to application needs and is not specifically limited here.
  • the embodiment of the present application also discloses a computer-readable storage medium for storing a computer program, wherein, when the computer program is executed by a processor, the classification model training method disclosed in the foregoing embodiments is implemented.
  • each embodiment in this specification is described in a progressive manner, each embodiment focuses on the difference from other embodiments, and the same or similar parts of each embodiment can be referred to each other.
  • the description is relatively simple, and for the related information, please refer to the description of the method part.
  • RAM (random access memory)
  • ROM (read-only memory)
  • EPROM (electrically programmable ROM)
  • EEPROM (electrically erasable programmable ROM)
  • registers, hard disk, removable disk, CD-ROM, or any other known storage medium.

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • Evolutionary Computation (AREA)
  • Biophysics (AREA)
  • Computational Linguistics (AREA)
  • Software Systems (AREA)
  • Mathematical Physics (AREA)
  • Health & Medical Sciences (AREA)
  • Biomedical Technology (AREA)
  • Computing Systems (AREA)
  • Molecular Biology (AREA)
  • General Health & Medical Sciences (AREA)
  • Evolutionary Biology (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Information Retrieval, Db Structures And Fs Structures Therefor (AREA)

Abstract

A classification model training method and apparatus, a device, and a medium. The method comprises: constructing a vertex feature matrix, an adjacency matrix, and a vertex label matrix on the basis of a graph data set, the vertex label matrix comprising the label information of each vertex of the graph data set; inputting the vertex feature matrix, the adjacency matrix, and the vertex label matrix into a Teacher graph wavelet neural network in a classification model for supervised training, and determining a corresponding supervised training loss in the training process; inputting the vertex feature matrix and the adjacency matrix into a Student graph wavelet neural network in the classification model for unsupervised training, and determining a corresponding unsupervised training loss in the training process; determining a target training loss on the basis of the supervised training loss and the unsupervised training loss; and, when the target training loss converges, outputting a current classification model to obtain a trained classification model. In this way, the classification accuracy of the classification model can be improved.

Description

A classification model training method, apparatus, device, and medium
This application claims the priority of the Chinese patent application No. 202110613729.6, filed with the China Patent Office on June 2, 2021 and entitled "A classification model training method, apparatus, device, and medium", the entire contents of which are incorporated herein by reference.
Technical Field
The present application relates to the technical field of classifiers, and in particular to a classification model training method, apparatus, device, and medium.
Background Art
With the rapid development of information technologies such as cloud computing, the Internet of Things, mobile communications, and smart terminals, new applications represented by social networks, communities, and blogs have come into wide use. These applications continuously generate large amounts of data that are conveniently modeled and analyzed as graphs, in which vertices represent individuals or groups and edges represent the connections between them. Vertices usually carry label information describing the modeled object's age, gender, location, hobbies, religious beliefs, and many other possible characteristics. These characteristics reflect an individual's behavioral preferences from various angles; ideally, every social network user would carry all the labels relevant to his or her own characteristics. In reality this is not the case: to protect their personal privacy, more and more social network users are cautious when sharing personal information, so social media can collect only part of a user's information. How to infer the labels of the remaining users from the label information of known users is therefore particularly important and urgent. This is the vertex classification problem.
At present, solving the vertex classification problem with graph neural networks has become a research hotspot. A graph neural network usually consists of an input layer, one or more hidden layers, and an output layer. For example, Fig. 1 is a graph neural network structure diagram from the prior art; it shows a typical graph convolutional neural network composed of an input layer (Input layer), two graph convolution layers (Gconv layers), and an output layer (Output layer). The input layer reads the n*d-dimensional vertex feature matrix; each graph convolution layer performs feature extraction on the vertex feature matrix, and the result is passed to the next graph convolution layer after a nonlinear activation function such as ReLU; finally, the output layer, i.e. the task layer, completes a specific task such as vertex classification or clustering. Fig. 1 shows a vertex classification task layer that outputs the category label of each vertex. At present, how to improve the classification accuracy is a problem that needs to be solved.
Summary of the Invention
In view of this, the purpose of the present application is to provide a classification model training method, apparatus, device, and medium that can improve the classification accuracy of a classification model. The specific scheme is as follows:
In a first aspect, the present application discloses a classification model training method, including:
constructing a vertex feature matrix, an adjacency matrix, and a vertex label matrix based on a graph data set, wherein the vertex label matrix includes the label information of each vertex of the graph data set;
inputting the vertex feature matrix, the adjacency matrix, and the vertex label matrix into a Teacher graph wavelet neural network in the classification model for supervised training, and determining the corresponding supervised training loss during training;
inputting the vertex feature matrix and the adjacency matrix into a Student graph wavelet neural network in the classification model for unsupervised training, and determining the corresponding unsupervised training loss during training;
determining a target training loss based on the supervised training loss and the unsupervised training loss;
when the target training loss converges, outputting the current classification model to obtain the trained classification model.
Optionally, determining the corresponding supervised training loss during training includes:
during training, determining the corresponding supervised training loss based on the first vertex label prediction result of the Teacher graph wavelet neural network and the vertex label matrix;
correspondingly, determining the corresponding unsupervised training loss during training includes:
during training, determining the corresponding unsupervised training loss based on the second vertex label prediction result of the Student graph wavelet neural network and the first vertex label prediction result.
Optionally, the method further includes:
during training, updating the vertex label matrix using the first vertex label prediction result;
when the target training loss converges, outputting the current vertex label matrix to obtain the category prediction result of each vertex without a category label.
Optionally, the method further includes:
calculating the graph wavelet transform basis and the graph wavelet inverse transform basis of the graph data set using Chebyshev polynomials;
correspondingly, the Teacher graph wavelet neural network and the Student graph wavelet neural network performing graph convolution operations during training based on the graph wavelet transform basis and the graph wavelet inverse transform basis.
Optionally, the method further includes:
obtaining the calculation formula of the graph wavelet transform basis,
wherein the calculation formula is a formula defined based on spectral theory.
Optionally, the Teacher graph wavelet neural network and the Student graph wavelet neural network each include an input layer, several graph convolution layers, and an output layer,
wherein the graph convolution layer is used to sequentially perform feature transformation and the graph convolution operation on its input data during training.
Optionally, the method further includes:
during training, determining the convolution kernels of the corresponding graph convolution layers in the Student graph wavelet neural network, based on the attention mechanism, from the convolution kernels of the graph convolution layers obtained by training the Teacher graph wavelet neural network.
In a second aspect, the present application discloses a classification model training apparatus, including:
a training data construction module, configured to construct a vertex feature matrix, an adjacency matrix, and a vertex label matrix based on a graph data set, wherein the vertex label matrix includes the label information of each vertex of the graph data set;
a classification model training module, configured to input the vertex feature matrix, the adjacency matrix, and the vertex label matrix into a Teacher graph wavelet neural network in the classification model for supervised training and determine the corresponding supervised training loss during training; to input the vertex feature matrix and the adjacency matrix into a Student graph wavelet neural network in the classification model for unsupervised training and determine the corresponding unsupervised training loss during training; to determine a target training loss based on the supervised training loss and the unsupervised training loss; and, when the target training loss converges, to output the current classification model to obtain the trained classification model.
In a third aspect, the present application discloses an electronic device, including:
a memory for storing a computer program;
a processor for executing the computer program to implement the aforementioned classification model training method.
In a fourth aspect, the present application discloses a computer-readable storage medium for storing a computer program which, when executed by a processor, implements the aforementioned classification model training method.
It can be seen that the present application first constructs a vertex feature matrix, an adjacency matrix, and a vertex label matrix based on a graph data set, the vertex label matrix including the label information of each vertex of the graph data set; then inputs the vertex feature matrix, the adjacency matrix, and the vertex label matrix into the Teacher graph wavelet neural network in the classification model for supervised training, determining the corresponding supervised training loss during training; inputs the vertex feature matrix and the adjacency matrix into the Student graph wavelet neural network in the classification model for unsupervised training, determining the corresponding unsupervised training loss during training; determines the target training loss based on the supervised training loss and the unsupervised training loss; and, when the target training loss converges, outputs the current classification model to obtain the trained classification model. In this way, the vertex feature matrix and adjacency matrix of the graph data set are input into the graph neural network for training, making use of both the graph topology and the vertex features; supervised training and unsupervised training are combined so that the respective advantages of each come into full play, which can improve the classification accuracy of the classification model.
Brief Description of the Drawings
In order to explain the technical solutions in the embodiments of the present application or in the prior art more clearly, the drawings needed in the description of the embodiments or the prior art are briefly introduced below. Obviously, the drawings in the following description are merely embodiments of the present application; for a person of ordinary skill in the art, other drawings can be obtained from the provided drawings without creative effort.
Fig. 1 is a graph neural network structure diagram from the prior art;
Fig. 2 is a flow chart of a classification model training method disclosed in the present application;
Fig. 3 is a flow chart of a specific classification model training method disclosed in the present application;
Fig. 4 is a structure diagram of a classification model disclosed in the present application;
Fig. 5 is a structure diagram of a specific classification model disclosed in the present application;
Fig. 6 is a flow chart of a specific classification model training method disclosed in the present application;
Fig. 7 is a schematic structural diagram of a classification model training apparatus disclosed in the present application;
Fig. 8 is a structure diagram of an electronic device disclosed in the present application.
Detailed Description
The technical solutions in the embodiments of the present application will be described clearly and completely below with reference to the drawings in the embodiments of the present application. Obviously, the described embodiments are only some, not all, of the embodiments of the present application. Based on the embodiments in this application, all other embodiments obtained by a person of ordinary skill in the art without creative effort fall within the scope of protection of this application.
Referring to Fig. 2, an embodiment of the present application discloses a classification model training method, including:
Step S11: Construct a vertex feature matrix, an adjacency matrix, and a vertex label matrix based on the graph data set, wherein the vertex label matrix includes the label information of each vertex of the graph data set.
Here, the label information indicates either a corresponding category label or the absence of one.
In a specific implementation, assume the graph data set is G = (V, E), where V denotes the vertex set and E denotes the set of edges. V is divided into a small set V_L of vertices with category labels and a large set V_U of vertices without category labels, satisfying V_L ∪ V_U = V and V_L ∩ V_U = ∅. Besides its label, each vertex v of G has d features, and the features of all vertices form the n*d-dimensional vertex feature matrix, denoted X. The adjacency matrix of G is denoted A, and the element A_ij represents the weight of the edge between vertices i and j. From the labeled vertex set V_L, an n*C-dimensional vertex label matrix Y is constructed, where n = |V| is the number of vertices in the graph, C is the number of label categories, and the matrix element Y_ij indicates whether the category label of vertex i is j (j = 1, 2, …, C): when vertex i already has a category label, the corresponding j-th column element of its row is set to 1 and the other column elements to 0, that is,
Y_ij = 1 if the category label of vertex i is j, and Y_ij = 0 otherwise.
When vertex i has no category label, every column element of its row is set to 0.
Step S12: Input the vertex feature matrix, the adjacency matrix, and the vertex label matrix into the Teacher graph wavelet neural network in the classification model for supervised training, and determine the corresponding supervised training loss during training.
Step S13: Input the vertex feature matrix and the adjacency matrix into the Student graph wavelet neural network in the classification model for unsupervised training, and determine the corresponding unsupervised training loss during training.
In a specific implementation, during training, the corresponding supervised training loss is determined based on the first vertex label prediction result of the Teacher graph wavelet neural network and the vertex label matrix, and the corresponding unsupervised training loss is determined based on the second vertex label prediction result of the Student graph wavelet neural network and the first vertex label prediction result.
Specifically, the first vertex label prediction result is compared with the vertex label matrix to calculate the supervised training loss, and the second vertex label prediction result is compared with the first vertex label prediction result to calculate the unsupervised learning loss.
步骤S14:基于所述有监督训练损失以及所述无监督训练损失确定目标训练损失。Step S14: Determine a target training loss based on the supervised training loss and the unsupervised training loss.
In a specific implementation, the target training loss is computed as

$$ls = ls_T + \alpha \cdot ls_S,$$

where $ls_T$ is the supervised training loss, $ls_S$ is the unsupervised training loss, and $\alpha$ is a constant used to adjust the proportion of the unsupervised training loss in the target loss. $Z_T$ denotes the first vertex label prediction result and $Z_S$ denotes the second vertex label prediction result.

Both $Z_T$ and $Z_S$ are n*C-dimensional matrices, and each column vector $z_j$ of $Z_T$ or $Z_S$ gives the probability that every vertex belongs to category j; that is, its i-th element (1 ≤ i ≤ n) is the probability that vertex i belongs to category j (j = 1, 2, …, C).
It should be pointed out that, in the embodiment of the present application, the output layers of the Teacher graph wavelet neural network and the Student graph wavelet neural network can be defined as

$$Z = \mathrm{softmax}\left(\psi_r F^L \psi_r^{-1} Q^L\right),$$

where $\psi_r$ is the graph wavelet transform basis, $\psi_r^{-1}$ is the graph wavelet inverse transform basis, $F^L$ is the convolution kernel matrix of the L-th graph convolution layer, and $Q^L$ is the vertex feature transformation result of the L-th layer; the Teacher graph wavelet neural network and the Student graph wavelet neural network both contain L graph convolution layers.
Moreover, the supervised training loss function is based on the cross-entropy principle and measures the degree of difference between the actual label probability distribution of the vertices and the predicted label probability distribution, while the unsupervised training loss function computes the sum of squared differences between same-coordinate elements of $Z_T$ and $Z_S$.

In this way, when training of the whole network ends, the outputs $Z_T$ and $Z_S$ of the two networks agree, or their difference is negligible; the output $Z_T$ of the Teacher graph wavelet neural network can then be taken as the output of the whole network model.
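The two loss terms and their combination can be sketched in Python as follows; the function names, the value of α, the averaging over labeled vertices, and the detaching of $Z_T$ in the unsupervised term (so that the Student learns from the Teacher rather than the reverse) are assumptions of this sketch.

```python
import torch

def supervised_loss(Z_T, Y, labeled_mask):
    # cross-entropy between the actual label distribution (rows of Y)
    # and the predicted distribution Z_T, over labeled vertices only
    log_pred = torch.log(Z_T[labeled_mask] + 1e-12)
    return -(Y[labeled_mask] * log_pred).sum(dim=1).mean()

def unsupervised_loss(Z_T, Z_S):
    # sum of squared differences between same-coordinate elements
    return ((Z_T.detach() - Z_S) ** 2).sum()

def target_loss(Z_T, Z_S, Y, labeled_mask, alpha=0.5):
    # ls = ls_T + alpha * ls_S
    return supervised_loss(Z_T, Y, labeled_mask) + alpha * unsupervised_loss(Z_T, Z_S)
```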
In this embodiment, the vertex label matrix is updated with the first vertex label prediction result during training. Specifically, for each vertex without a category label, that is, for each $v_i \in V_U$, the category with the highest probability in the first vertex label prediction result is taken as the latest category of that vertex, and the vertex label matrix is updated accordingly.
Step S15: When the target training loss converges, output the current classification model to obtain the trained classification model.

Furthermore, when the target training loss converges, the current vertex label matrix is output, yielding the category prediction result of every vertex without a category label.

In a specific implementation, the target training loss is considered converged, and training ends, when it reaches a preset threshold or the number of iterations reaches a specified maximum; the preset threshold is usually a small value. At that point, the category to which each vertex without a category label should belong is obtained from the current vertex label matrix.

That is, the present application folds the prediction of unlabeled vertices into the training process: during training, the vertex label matrix is updated according to each round's training result, and once training ends, the category label of any unlabeled vertex is immediately available.
In a specific implementation, the network parameters of each layer of the graph wavelet neural networks can first be initialized according to a specific strategy, such as normal-distribution random initialization, Xavier initialization, or He initialization. During training, the network parameters of each layer can be corrected and updated according to a specific strategy, such as SGD (Stochastic Gradient Descent), MGD (Momentum Gradient Descent), Nesterov Momentum, AdaGrad (Adaptive Gradient Algorithm), RMSprop (Root Mean Square Propagation), Adam (Adaptive Moment Estimation), or BGD (Batch Gradient Descent), so as to optimize the loss function value.
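A minimal sketch of these initialization and update strategies using PyTorch built-ins; the standard deviation chosen for the normal-distribution case and the learning rates are illustrative assumptions.

```python
import torch

def init_layer_param(param, strategy="xavier"):
    # initialize one layer's parameter tensor under the chosen strategy
    if strategy == "normal":
        torch.nn.init.normal_(param, mean=0.0, std=0.01)
    elif strategy == "xavier":
        torch.nn.init.xavier_uniform_(param)
    elif strategy == "he":
        torch.nn.init.kaiming_uniform_(param, nonlinearity="relu")
    return param

# any of the listed update strategies can then be plugged in, for example:
# torch.optim.SGD(params, lr=0.01)                        # SGD / BGD
# torch.optim.SGD(params, lr=0.01, momentum=0.9)          # MGD
# torch.optim.SGD(params, lr=0.01, momentum=0.9, nesterov=True)
# torch.optim.Adagrad(params), torch.optim.RMSprop(params), torch.optim.Adam(params)
```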
It can be seen that the embodiment of the present application first constructs a vertex feature matrix, an adjacency matrix, and a vertex label matrix based on the graph data set, where the vertex label matrix contains the label information of each vertex of the graph data set; then inputs the vertex feature matrix, the adjacency matrix, and the vertex label matrix into the Teacher graph wavelet neural network of the classification model for supervised training, determining the corresponding supervised training loss during training; inputs the vertex feature matrix and the adjacency matrix into the Student graph wavelet neural network of the classification model for unsupervised training, determining the corresponding unsupervised training loss during training; determines a target training loss based on the supervised training loss and the unsupervised training loss; and, when the target training loss converges, outputs the current classification model to obtain the trained classification model. In this way, the vertex feature matrix and adjacency matrix of the graph data set are fed into graph neural networks for training, exploiting both the graph topology and the vertex features; and because training combines supervised and unsupervised training, the respective strengths of both are brought into full play, which improves the classification accuracy of the classification model.
Referring to Fig. 3, the embodiment of the present application discloses a specific classification model training method, including:

Step S21: Obtain the calculation formula of the graph wavelet transform basis.

Here, the calculation formula is a formula defined based on spectral theory.

It should be pointed out that the graph convolution operation defined through the Fourier transform has poor locality in the vertex domain; defining the basis of the graph wavelet transform with spectral theory guarantees the locality of the graph convolution computation.

Step S22: Compute the graph wavelet transform basis and the graph wavelet inverse transform basis of the graph data set using Chebyshev polynomials.
In a specific implementation, the graph wavelet transform basis is computed as

$$\psi_r = U H_r U^T,$$

where $\psi_r$ is the graph wavelet transform basis extracted from the graph data set G, and U is the matrix of eigenvectors obtained by the eigendecomposition of the Laplacian matrix $L = D - A$ of G; D is a diagonal matrix whose n main-diagonal elements are the degrees of the n vertices, with all other elements zero. $H_r = \mathrm{diag}(h(r\lambda_1), h(r\lambda_2), \ldots, h(r\lambda_n))$ is the scaling matrix with scaling scale r, where $\lambda_1, \lambda_2, \ldots, \lambda_n$ are the eigenvalues obtained by the eigendecomposition of the Laplacian matrix of G. The graph wavelet inverse transform basis $\psi_r^{-1}$ can be obtained by replacing $h(r\lambda_i)$ in $\psi_r$ with $h(-r\lambda_i)$. Since the eigendecomposition of a matrix is computationally expensive, this cost is avoided by approximating the graph wavelet transform basis and the graph wavelet inverse transform basis with the Chebyshev polynomials

$$T_k(x) = 2x\,T_{k-1}(x) - T_{k-2}(x), \qquad T_0 = 1,\; T_1 = x.$$
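The following dense NumPy sketch shows one way to realize this approximation; the heat kernel $h(x) = e^{-x}$, the spectrum bound $\lambda_{max} \approx 2$ (appropriate for a normalized Laplacian), and the truncation order K are assumptions of the example, and a practical implementation would use sparse matrices.

```python
import numpy as np

def chebyshev_wavelet_basis(L, r, K=10, lam_max=2.0, sign=+1):
    """Approximate psi_r = U diag(h(r*lam_i)) U^T without an explicit
    eigendecomposition, via T_k(x) = 2x T_{k-1}(x) - T_{k-2}(x),
    T_0 = 1, T_1 = x. sign=-1 uses h(-r*lam), i.e. the inverse basis.
    """
    n = L.shape[0]
    # rescale the Laplacian so its spectrum lies in [-1, 1]
    L_tilde = (2.0 / lam_max) * L - np.eye(n)

    # Chebyshev coefficients of g(x) = h(sign * r * lam(x)) on [-1, 1],
    # computed by Gauss-Chebyshev quadrature
    M = 200
    theta = (np.arange(M) + 0.5) * np.pi / M
    x = np.cos(theta)                       # quadrature nodes in [-1, 1]
    lam = (x + 1.0) * lam_max / 2.0         # map back to [0, lam_max]
    g = np.exp(-sign * r * lam)             # heat-kernel filter (assumption)
    coeffs = [2.0 / M * np.sum(g * np.cos(k * theta)) for k in range(K)]
    coeffs[0] /= 2.0

    # accumulate sum_k c_k T_k(L_tilde) with the three-term recurrence
    T_prev, T_cur = np.eye(n), L_tilde
    psi = coeffs[0] * T_prev + coeffs[1] * T_cur
    for k in range(2, K):
        T_next = 2.0 * L_tilde @ T_cur - T_prev
        psi += coeffs[k] * T_next
        T_prev, T_cur = T_cur, T_next
    return psi
```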
Correspondingly, during training the Teacher graph wavelet neural network and the Student graph wavelet neural network perform graph convolution operations based on the graph wavelet transform basis and the graph wavelet inverse transform basis.

It should be pointed out that, in the prior art, the graph Fourier transform used in the graph convolution operation is inefficient because the eigenvector matrix of the Laplacian is dense. This embodiment instead performs the graph convolution operation with the graph wavelet transform basis and the graph wavelet inverse transform basis, which are sparse, so the computational efficiency of the graph convolution operation can be improved.
Step S23: Construct a vertex feature matrix, an adjacency matrix, and a vertex label matrix based on the graph data set, where the vertex label matrix contains the label information of each vertex of the graph data set, and the label information indicates either the corresponding category label or the absence of a category label.

Step S24: Input the vertex feature matrix, the adjacency matrix, and the vertex label matrix into the Teacher graph wavelet neural network of the classification model for supervised training, and determine the corresponding supervised training loss during training.

Step S25: Input the vertex feature matrix and the adjacency matrix into the Student graph wavelet neural network of the classification model for unsupervised training, and determine the corresponding unsupervised training loss during training.

In a specific implementation, the Teacher graph wavelet neural network and the Student graph wavelet neural network each comprise an input layer, several graph convolution layers, and an output layer; specifically, one input layer, L (L ≥ 1) graph convolution layers, and one output layer. Each graph convolution layer is used to apply, in sequence, a feature transformation and a graph convolution operation to that layer's input data during training.

That is, in the embodiment of the present application, a graph convolution layer first applies a feature transformation to its input data and then the graph convolution operation. Splitting the graph convolution layer into these two successive processing stages, feature transformation and graph convolution, reduces the number of network parameters, thereby lowering the computational load of the model and improving training efficiency.
In the l-th (1 ≤ l ≤ L) graph convolution layer:

Feature transformation: $Q^l = H^l \Theta^l$;

Graph convolution: $H^{l+1} = h\left(\psi_r F^l \psi_r^{-1} Q^l\right)$;

where $H^l$ and $H^{l+1}$ are respectively the input and output data of the l-th graph convolution layer, with $H^1 = X$; $\Theta^l$ is the feature transformation matrix of the l-th layer to be trained; $Q^l$ is the feature transformation result of the l-th layer; and the superscript T denotes the matrix transpose operation.
It should be pointed out that prior-art definitions of the graph convolution layer usually do not separate the feature transformation from the convolution operation. Combined with the graph wavelet transform basis of the embodiment of the present application, if the graph convolution layer is not split into the two successive stages of feature transformation and graph convolution, the layer is defined by the following formula:

$$X^{m+1}_{[:,j]} = h\left(\sum_{i=1}^{p} \psi_r F^m_{i,j} \psi_r^{-1} X^m_{[:,i]}\right), \quad j = 1, \ldots, q,$$

where X is the vertex feature matrix, $X^m_{[:,i]}$ denotes the i-th column of the m-th layer's input, m is the ordinal number of the graph convolution layer, F is the graph convolution kernel matrix, and h is the activation function. A graph convolution layer defined this way contains n*p*q parameters, where n is the number of vertices in the graph, p is the vertex feature dimension of the layer's input, and q is the vertex feature dimension of the layer's output. By peeling the feature transformation off the graph convolution operation, the embodiment of the present application reduces the number of parameters of each graph convolution layer to n + p*q.
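A sketch of one separated layer in PyTorch; initializing the diagonal kernel to ones (an identity filter) and treating the pre-computed bases as fixed inputs are assumptions of the example.

```python
import torch
import torch.nn as nn

class GraphWaveletLayer(nn.Module):
    """One graph convolution layer split into two stages:
    feature transformation  Q = H @ Theta                 (p*q parameters)
    graph convolution       H' = h(psi @ F @ psi_inv @ Q) (n parameters in F)
    so the layer holds n + p*q parameters instead of n*p*q.
    """
    def __init__(self, n, p, q, psi, psi_inv, act=torch.relu):
        super().__init__()
        self.theta = nn.Parameter(torch.empty(p, q))
        nn.init.xavier_uniform_(self.theta)
        self.f = nn.Parameter(torch.ones(n))    # diagonal of the kernel matrix F
        self.psi, self.psi_inv = psi, psi_inv   # fixed, pre-computed bases
        self.act = act

    def forward(self, H):
        Q = H @ self.theta                      # stage 1: feature transformation
        # stage 2: graph convolution; F is diagonal, so F @ Z == f[:, None] * Z
        return self.act(self.psi @ (self.f.unsqueeze(1) * (self.psi_inv @ Q)))
```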
In addition, in a specific implementation, during training the convolution kernel of each graph convolution layer of the Student graph wavelet neural network is determined, based on an attention mechanism, from the convolution kernel of the corresponding graph convolution layer obtained by training the Teacher graph wavelet neural network.

Specifically, the classification model may comprise the Teacher graph wavelet neural network, the Student graph wavelet neural network, and an attention network connecting each pair of graph convolution layers of the Teacher and Student graph wavelet neural networks.

It should be noted that $F^l$, the graph convolution kernel matrix of the l-th layer, is a diagonal matrix. From a signal-processing viewpoint, the elements $(f_1, f_2, \ldots, f_n)$ on the diagonal of $F^l$ can be regarded as the frequencies of the graph, expressing the importance of the eigenvector corresponding to each frequency. Denote the convolution kernel matrices of the l-th layers of the Teacher and Student graph wavelet neural networks by $T^l$ and $S^l$ respectively; they are obtained by diagonalizing the convolution kernel $t_l$ of that layer of the Teacher network and the convolution kernel $s_l$ of that layer of the Student network, both of which are n-dimensional column vectors.

In this embodiment, attention transfer can be performed based on the attention mechanism: each layer of the Teacher graph wavelet neural network transfers the convolution kernel it has learned to the corresponding layer of the Student graph wavelet neural network; in other words, the Student network learns from the Teacher network, which boosts the performance of the whole network. Concretely, a single-layer feed-forward neural network can be designed whose input layer reads the l-th-layer convolution kernels $t_l$ and $s_l$ of the Teacher and Student graph wavelet neural networks, and whose hidden layer implements the attention function $a_l: \mathbb{R}^n \times \mathbb{R}^n \to \mathbb{R}$, so as to obtain the attention weight between the two vectors: $e_l = a_l(t_l, s_l)$.
Further, the attention weights $e_l$ are normalized with the softmax function to obtain the normalized attention weights $e'_l$:

$$e'_l(i) = \frac{\exp\left(e_l(i)\right)}{\sum_{k=1}^{n} \exp\left(e_l(k)\right)},$$

where $e_l(i)$ is the i-th component of $e_l$ and $e'_l(i)$ is the i-th component of $e'_l$. Then:

$$s'_l(i) = e'_l(i) \times t_l(i), \quad i \in [1, n],$$

where $s'_l(i)$ denotes the l-th-layer convolution kernel that the Student graph wavelet neural network learns from the Teacher graph wavelet neural network.
It should be pointed out that adding the attention mechanism helps the Student graph wavelet neural network quickly exploit the knowledge mastered by the Teacher graph wavelet neural network, which speeds up training.
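A sketch of the attention transfer for one layer; the patent describes $a_l$ only abstractly, so realizing it as a single nn.Linear producing an n-dimensional score vector is an assumption of this example.

```python
import torch
import torch.nn as nn

class AttentionTransfer(nn.Module):
    """Single-layer feed-forward attention net for layer l: reads the
    Teacher kernel t_l and Student kernel s_l (n-dim vectors), produces
    weights e_l, softmax-normalizes them to e'_l, and forms the kernel
    s'_l(i) = e'_l(i) * t_l(i) used by the Student's l-th layer.
    """
    def __init__(self, n):
        super().__init__()
        self.a = nn.Linear(2 * n, n)  # attention function a_l (an assumption)

    def forward(self, t_l, s_l):
        e_l = self.a(torch.cat([t_l, s_l]))  # attention weights e_l
        e_norm = torch.softmax(e_l, dim=0)   # normalized weights e'_l
        return e_norm * t_l                  # transferred kernel s'_l

# usage: s_prime = AttentionTransfer(n)(t_l, s_l) becomes the kernel that
# is diagonalized into the Student layer's kernel matrix S^l
```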
Step S26: Determine a target training loss based on the supervised training loss and the unsupervised training loss.

Step S27: When the target training loss converges, output the current classification model to obtain the trained classification model.
For example, see Fig. 4, a structure diagram of a classification model disclosed in an embodiment of the present application, showing the Teacher graph wavelet neural network GWN_T and the Student graph wavelet neural network GWN_S. Further, Fig. 5 shows a specific classification model structure diagram disclosed in an embodiment of the present application. The classification model consists of the Teacher graph wavelet neural network GWN_T, the Student graph wavelet neural network GWN_S, and an attention network connecting each pair of graph convolution layers of the two networks. GWN_T performs supervised learning on the labeled graph vertices and achieves high prediction accuracy; GWN_S, under the guidance of GWN_T (using its prediction results), performs unsupervised learning on the unlabeled graph vertices, with the aim of raising prediction accuracy and obtaining a better vertex classification model. The attention network lets GWN_T transfer the "knowledge" learned by each layer, namely the convolution kernels, to the corresponding layer of GWN_S; in other words, GWN_S learns from GWN_T. GWN_T and GWN_S each contain one input layer, L graph convolution layers, and one output layer. The input layer mainly reads the graph data to be classified, including the adjacency matrix A representing the graph topology and the vertex feature matrix X. In each graph convolution layer, the graph convolution operation is decomposed into the two successive stages of feature transformation and graph convolution. The output layer outputs the prediction results.
Moreover, in the whole classification model, the network parameters of each layer comprise the feature transformation matrix $\Theta^l$ (namely $\Theta^l_T$ of the Teacher graph wavelet neural network and $\Theta^l_S$ of the Student graph wavelet neural network), the convolution kernels $t_l$ and $s_l$ (which in turn are used to update the convolution kernel matrix $F^l$), and the attention network parameters $a_l$. These network parameters are set in the initialization phase and updated during training.
For example, see Fig. 6, a flow chart of a specific classification model training method disclosed in an embodiment of the present application. For a given graph data set G, its adjacency matrix A, vertex feature matrix X, and vertex label matrix Y are taken as inputs and fed into the network for forward propagation; the prediction results of all vertices for every category are computed; and, while the prediction result matrix is updated, the loss of the supervised-learning part and the loss of the unsupervised-learning part are computed, yielding the total network loss function value. The network parameters of each layer are then updated according to a given strategy, until the network error reaches a specified small value or the number of iterations reaches the specified maximum, at which point training ends.
For example, the method of the embodiment of the present application can train a classification model on a corpus of scientific papers and predict the category labels of unlabeled scientific papers.

(1) Download the citation network data set Citeseer, which contains 3312 scientific papers divided into six categories and 4732 citation relationships between papers. Using the bag-of-words model, construct a feature vector x for each paper; the feature vectors of all documents form the feature matrix X. Construct the adjacency matrix A from the citation relationships between papers. The goal is to classify every document: 20 instances per category are randomly sampled as labeled data, 1000 instances serve as test data, and the rest are used as unlabeled data; the vertex label matrix Y is constructed accordingly.
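A minimal sketch of assembling X, A, and Y from such a corpus; the function name and the assumed input formats are illustrative choices, not the actual format of the Citeseer distribution.

```python
import numpy as np

def build_inputs(papers, citations, vocab, labels, C):
    """Construct X, A and Y for a citation graph.

    papers:    list of token lists, one per paper (n papers)
    citations: list of (i, j) index pairs, one per citation edge
    vocab:     dict token -> feature column (d columns, bag-of-words)
    labels:    dict paper index -> class index, labeled subset only
    """
    n, d = len(papers), len(vocab)
    X = np.zeros((n, d), dtype=np.float32)
    for i, tokens in enumerate(papers):
        for t in tokens:
            if t in vocab:
                X[i, vocab[t]] += 1.0        # bag-of-words counts
    A = np.zeros((n, n), dtype=np.float32)
    for i, j in citations:
        A[i, j] = A[j, i] = 1.0              # undirected citation edge
    Y = np.zeros((n, C), dtype=np.float32)
    for i, j in labels.items():
        Y[i, j] = 1.0                        # one-hot rows for labeled papers
    return X, A, Y
```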
(2) Define the network structure: define the graph convolution layers, the output layer, and the network loss function based on the foregoing disclosure.

(3) Use Chebyshev polynomials to approximate the graph wavelet transform basis and the graph wavelet inverse transform basis.

(4) Initialize the network parameters according to the regularized initialization method.

(5) Take A, X, and Y as network inputs and feed them in for forward propagation: the Teacher graph wavelet neural network GWN_T takes A, X, and Y as inputs, while the Student graph wavelet neural network GWN_S takes A and X as inputs. Each network computes the output feature matrix of every layer from that layer's input feature matrix according to the definition of the graph convolution layer; following the definition of the output layer, it computes the prediction results Z_T or Z_S of all vertices for every category, and from the network loss function defined above computes the supervised learning loss value and the unsupervised learning loss value, obtaining the loss function value of the whole network. For each unlabeled vertex, the category with the highest probability is taken as that vertex's latest category, and the vertex label matrix Y is updated.

(6) Following the optimization method, compute the gradient of the loss function with respect to the network parameters and back-propagate it so as to optimize the network parameters, until the network prediction error reaches a specified small value or the number of iterations reaches the specified maximum, at which point training ends. At that time, the category of each vertex without a category label can be obtained from the vertex label matrix Y.
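Steps (5) and (6) can be sketched as the following loop; it assumes a `model` returning (Z_T, Z_S) and the `target_loss` sketch given earlier, and the threshold and iteration limit are illustrative values.

```python
import torch

def train(model, A, X, Y, labeled_mask, alpha=0.5, lr=0.01,
          max_iter=200, tol=1e-4):
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    for _ in range(max_iter):
        Z_T, Z_S = model(A, X)                      # forward propagation
        loss = target_loss(Z_T, Z_S, Y, labeled_mask, alpha)
        opt.zero_grad()
        loss.backward()                             # back-propagate gradients
        opt.step()
        # fold the Teacher's predictions for unlabeled vertices back into Y
        with torch.no_grad():
            pred = torch.argmax(Z_T, dim=1)
            Y[~labeled_mask] = torch.eye(Y.shape[1])[pred[~labeled_mask]]
        if loss.item() < tol:                       # convergence threshold
            break
    return torch.argmax(Z_T, dim=1)                 # final category per vertex
```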
Of course, the present application is not limited to the scientific citation classification problem listed in the embodiment; it can be applied to the classification of any data conveniently modeled and represented as a graph, such as proteins and graphic images. It can likewise be used to study how infectious diseases or ideas spread and diffuse through social networks over time; to study how groups in a social network form communities around particular interests or affiliations, and the strength of community connections; to let social networks, following the principle that like attracts like, discover people with similar interests and suggest or recommend new links or contacts to them; to let question-answering systems route questions to the people with the most relevant experience; and to let advertising systems show advertisements to the individuals most interested in and receptive to advertisements on a particular topic.
Referring to Fig. 7, an embodiment of the present application discloses a classification model training apparatus, comprising:

a training data construction module 11 for constructing a vertex feature matrix, an adjacency matrix, and a vertex label matrix based on a graph data set, where the vertex label matrix contains the label information of each vertex of the graph data set;

a classification model training module 12 for inputting the vertex feature matrix, the adjacency matrix, and the vertex label matrix into the Teacher graph wavelet neural network of the classification model for supervised training, determining the corresponding supervised training loss during training; inputting the vertex feature matrix and the adjacency matrix into the Student graph wavelet neural network of the classification model for unsupervised training, determining the corresponding unsupervised training loss during training; determining a target training loss based on the supervised training loss and the unsupervised training loss; and, when the target training loss converges, outputting the current classification model to obtain the trained classification model.
It can be seen that the embodiment of the present application first constructs a vertex feature matrix, an adjacency matrix, and a vertex label matrix based on the graph data set, where the vertex label matrix contains the label information of each vertex of the graph data set; then inputs the vertex feature matrix, the adjacency matrix, and the vertex label matrix into the Teacher graph wavelet neural network of the classification model for supervised training, determining the corresponding supervised training loss during training; inputs the vertex feature matrix and the adjacency matrix into the Student graph wavelet neural network of the classification model for unsupervised training, determining the corresponding unsupervised training loss during training; determines a target training loss based on the supervised training loss and the unsupervised training loss; and, when the target training loss converges, outputs the current classification model to obtain the trained classification model. In this way, the vertex feature matrix and adjacency matrix of the graph data set are fed into graph neural networks for training, exploiting both the graph topology and the vertex features; and because training combines supervised and unsupervised training, the respective strengths of both are brought into full play, which improves the classification accuracy of the classification model.
The classification model training module 12 is specifically configured to determine, during training, the corresponding supervised training loss based on the first vertex label prediction result of the Teacher graph wavelet neural network and the vertex label matrix, and the corresponding unsupervised training loss based on the second vertex label prediction result of the Student graph wavelet neural network and the first vertex label prediction result.

The classification model training module 12 is further configured to update the vertex label matrix with the first vertex label prediction result during training and, when the target training loss converges, to output the current vertex label matrix, obtaining the category prediction result of every vertex without a category label.

The apparatus further comprises a graph wavelet transform basis computation module for computing the graph wavelet transform basis and the graph wavelet inverse transform basis of the graph data set using Chebyshev polynomials; correspondingly, during training the Teacher graph wavelet neural network and the Student graph wavelet neural network perform graph convolution operations based on the graph wavelet transform basis and the graph wavelet inverse transform basis.

The apparatus further comprises a graph wavelet transform basis formula acquisition module for obtaining the calculation formula of the graph wavelet transform basis, where the calculation formula is a formula defined based on spectral theory.

In a specific implementation, the Teacher graph wavelet neural network and the Student graph wavelet neural network each comprise an input layer, several graph convolution layers, and an output layer,

where each graph convolution layer is used to apply, in sequence, a feature transformation and a graph convolution operation to that layer's input data during training.

The classification model training module 12 is further configured, during training, to determine, based on the attention mechanism, the convolution kernel of each corresponding graph convolution layer of the Student graph wavelet neural network from the convolution kernel of the graph convolution layer obtained by training the Teacher graph wavelet neural network.
Referring to Fig. 8, an embodiment of the present application discloses an electronic device 20, comprising a processor 21 and a memory 22, where the memory 22 is used to store a computer program and the processor 21 is used to execute the computer program so as to carry out the classification model training method disclosed in the foregoing embodiments.

For the specific procedure of the above classification model training method, reference may be made to the corresponding content disclosed in the foregoing embodiments, which is not repeated here.

The memory 22, as the carrier of resource storage, may be a read-only memory, a random access memory, a magnetic disk, an optical disc, or the like, and the storage may be transient or permanent.

In addition, the electronic device 20 further comprises a power supply 23, a communication interface 24, an input/output interface 25, and a communication bus 26. The power supply 23 provides the operating voltage for the hardware devices on the electronic device 20; the communication interface 24 can create a data transmission channel between the electronic device 20 and external devices, following any communication protocol applicable to the technical solution of the present application, which is not specifically limited here; the input/output interface 25 is used to obtain input data from, or output data to, the outside, and its specific interface type may be selected according to the needs of the specific application, which is likewise not specifically limited here.
Further, an embodiment of the present application also discloses a computer-readable storage medium for storing a computer program, where the computer program, when executed by a processor, implements the classification model training method disclosed in the foregoing embodiments.

For the specific procedure of the above classification model training method, reference may be made to the corresponding content disclosed in the foregoing embodiments, which is not repeated here.
The embodiments in this specification are described in a progressive manner; each embodiment focuses on its differences from the other embodiments, and the identical or similar parts of the embodiments may be consulted against one another. Since the apparatus disclosed in the embodiments corresponds to the method disclosed in the embodiments, its description is relatively brief, and the relevant points can be found in the description of the method.

The steps of the methods or algorithms described in connection with the embodiments disclosed herein may be implemented directly in hardware, in a software module executed by a processor, or in a combination of the two. A software module may reside in random access memory (RAM), internal memory, read-only memory (ROM), electrically programmable ROM, electrically erasable programmable ROM, registers, a hard disk, a removable disk, a CD-ROM, or any other form of storage medium known in the technical field.

The classification model training method, apparatus, device, and medium provided by the present application have been introduced in detail above. Specific examples are used herein to explain the principles and implementations of the present application, and the descriptions of the above embodiments are only intended to help understand the method of the present application and its core idea. Meanwhile, those of ordinary skill in the art may, following the idea of the present application, make changes to the specific implementation and the scope of application. In summary, the content of this specification should not be construed as limiting the present application.

Claims (10)

  1. A classification model training method, characterized by comprising:
    constructing a vertex feature matrix, an adjacency matrix, and a vertex label matrix based on a graph data set, wherein the vertex label matrix comprises label information of each vertex of the graph data set;
    inputting the vertex feature matrix, the adjacency matrix, and the vertex label matrix into a Teacher graph wavelet neural network of a classification model for supervised training, and determining a corresponding supervised training loss during training;
    inputting the vertex feature matrix and the adjacency matrix into a Student graph wavelet neural network of the classification model for unsupervised training, and determining a corresponding unsupervised training loss during training;
    determining a target training loss based on the supervised training loss and the unsupervised training loss;
    when the target training loss converges, outputting the current classification model to obtain a trained classification model;
    wherein the Teacher graph wavelet neural network and the Student graph wavelet neural network each comprise an input layer, several graph convolution layers, and an output layer;
    and the method further comprises: during training, determining, based on an attention mechanism, the convolution kernel of the corresponding graph convolution layer of the Student graph wavelet neural network from the convolution kernel of the graph convolution layer obtained by training the Teacher graph wavelet neural network.
  2. The classification model training method according to claim 1, characterized in that determining the corresponding supervised training loss during training comprises:
    during training, determining the corresponding supervised training loss based on a first vertex label prediction result of the Teacher graph wavelet neural network and the vertex label matrix;
    and, correspondingly, determining the corresponding unsupervised training loss during training comprises:
    during training, determining the corresponding unsupervised training loss based on a second vertex label prediction result of the Student graph wavelet neural network and the first vertex label prediction result.
  3. The classification model training method according to claim 2, characterized by further comprising:
    during training, updating the vertex label matrix with the first vertex label prediction result;
    when the target training loss converges, outputting the current vertex label matrix to obtain a category prediction result for each vertex without a category label.
  4. The classification model training method according to claim 1, characterized by further comprising:
    computing a graph wavelet transform basis and a graph wavelet inverse transform basis of the graph data set using Chebyshev polynomials;
    correspondingly, the Teacher graph wavelet neural network and the Student graph wavelet neural network performing graph convolution operations based on the graph wavelet transform basis and the graph wavelet inverse transform basis during training.
  5. The classification model training method according to claim 4, characterized by further comprising:
    obtaining the calculation formula of the graph wavelet transform basis;
    wherein the calculation formula is a formula defined based on spectral theory.
  6. The classification model training method according to any one of claims 1 to 5, characterized in that the graph convolution layer is used to apply, in sequence, a feature transformation and a graph convolution operation to that layer's input data during training.
  7. A classification model training apparatus, characterized by comprising:
    a training data construction module for constructing a vertex feature matrix, an adjacency matrix, and a vertex label matrix based on a graph data set, wherein the vertex label matrix comprises label information of each vertex of the graph data set;
    a classification model training module for inputting the vertex feature matrix, the adjacency matrix, and the vertex label matrix into a Teacher graph wavelet neural network of a classification model for supervised training, determining a corresponding supervised training loss during training; inputting the vertex feature matrix and the adjacency matrix into a Student graph wavelet neural network of the classification model for unsupervised training, determining a corresponding unsupervised training loss during training; determining a target training loss based on the supervised training loss and the unsupervised training loss; and, when the target training loss converges, outputting the current classification model to obtain a trained classification model;
    wherein the Teacher graph wavelet neural network and the Student graph wavelet neural network each comprise an input layer, several graph convolution layers, and an output layer;
    and the classification model training module is further configured to determine, during training and based on an attention mechanism, the convolution kernel of the corresponding graph convolution layer of the Student graph wavelet neural network from the convolution kernel of the graph convolution layer obtained by training the Teacher graph wavelet neural network.
  8. The classification model training apparatus according to claim 7, characterized in that the classification model training module is specifically configured to determine, during training, the corresponding supervised training loss based on a first vertex label prediction result of the Teacher graph wavelet neural network and the vertex label matrix, and to determine the corresponding unsupervised training loss based on a second vertex label prediction result of the Student graph wavelet neural network and the first vertex label prediction result.
  9. An electronic device, characterized by comprising:
    a memory for storing a computer program;
    a processor for executing the computer program to implement the classification model training method according to any one of claims 1 to 6.
  10. A computer-readable storage medium, characterized by being used to store a computer program which, when executed by a processor, implements the classification model training method according to any one of claims 1 to 6.
PCT/CN2021/121905 2021-06-02 2021-09-29 Classification model training method and apparatus, device, and medium WO2022252458A1 (en)

Applications Claiming Priority (2)

Application Number Priority Date Filing Date Title
CN202110613729.6 2021-06-02
CN202110613729.6A CN113255798A (en) 2021-06-02 2021-06-02 Classification model training method, device, equipment and medium

Publications (1)

Publication Number Publication Date
WO2022252458A1

Family

ID=77186018

Family Applications (1)

Application Number Title Priority Date Filing Date
PCT/CN2021/121905 WO2022252458A1 (en) 2021-06-02 2021-09-29 Classification model training method and apparatus, device, and medium

Country Status (2)

Country Link
CN (1) CN113255798A (en)
WO (1) WO2022252458A1 (en)

Families Citing this family (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN112364372A (en) * 2020-10-27 2021-02-12 重庆大学 Privacy protection method with supervision matrix completion
CN113255798A (en) * 2021-06-02 2021-08-13 苏州浪潮智能科技有限公司 Classification model training method, device, equipment and medium
CN114048816B (en) * 2021-11-16 2024-04-30 中国人民解放军国防科技大学 Method, device, equipment and storage medium for sampling data of graph neural network
CN114943324B (en) * 2022-05-26 2023-10-13 中国科学院深圳先进技术研究院 Neural network training method, human motion recognition method and device, and storage medium
CN115240037A (en) * 2022-09-23 2022-10-25 卡奥斯工业智能研究院(青岛)有限公司 Model training method, image processing method, device and storage medium

Citations (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20170083829A1 (en) * 2015-09-18 2017-03-23 Samsung Electronics Co., Ltd. Model training method and apparatus, and data recognizing method
CN111552803A (en) * 2020-04-08 2020-08-18 西安工程大学 Text classification method based on graph wavelet network model
CN111639755A (en) * 2020-06-07 2020-09-08 电子科技大学中山学院 Network model training method and device, electronic equipment and storage medium
CN112464057A (en) * 2020-11-18 2021-03-09 苏州浪潮智能科技有限公司 Network data classification method, device, equipment and readable storage medium
CN113255798A (en) * 2021-06-02 2021-08-13 苏州浪潮智能科技有限公司 Classification model training method, device, equipment and medium


Also Published As

Publication number Publication date
CN113255798A (en) 2021-08-13

Similar Documents

Publication Publication Date Title
WO2023000574A1 (en) Model training method, apparatus and device, and readable storage medium
WO2022252458A1 (en) Classification model training method and apparatus, device, and medium
US10248664B1 (en) Zero-shot sketch-based image retrieval techniques using neural networks for sketch-image recognition and retrieval
US11544573B2 (en) Projection neural networks
US9990558B2 (en) Generating image features based on robust feature-learning
Corchado et al. Ibr retrieval method based on topology preserving mappings
Hammer et al. Learning vector quantization for (dis-) similarities
US20160140425A1 (en) Method and apparatus for image classification with joint feature adaptation and classifier learning
CN110796190A (en) Exponential modeling with deep learning features
Duan et al. Separate or joint? Estimation of multiple labels from crowdsourced annotations
CN110377587B (en) Migration data determination method, device, equipment and medium based on machine learning
CN111667022A (en) User data processing method and device, computer equipment and storage medium
WO2022105108A1 (en) Network data classification method, apparatus, and device, and readable storage medium
CN116261731A (en) Relation learning method and system based on multi-hop attention-seeking neural network
Jia et al. A semi-supervised online sequential extreme learning machine method
US20220253722A1 (en) Recommendation system with adaptive thresholds for neighborhood selection
CN112288086A (en) Neural network training method and device and computer equipment
CN116010684A (en) Article recommendation method, device and storage medium
US11816562B2 (en) Digital experience enhancement using an ensemble deep learning model
Muhammadi et al. A unified statistical framework for crowd labeling
CN113392317A (en) Label configuration method, device, equipment and storage medium
Zhang et al. Learning from few samples with memory network
CN112131261A (en) Community query method and device based on community network and computer equipment
US11455512B1 (en) Representing graph edges using neural networks
CN114117048A (en) Text classification method and device, computer equipment and storage medium

Legal Events

Date Code Title Description
121 Ep: the epo has been informed by wipo that ep was designated in this application

Ref document number: 21943811

Country of ref document: EP

Kind code of ref document: A1

NENP Non-entry into the national phase

Ref country code: DE

122 Ep: pct application non-entry in european phase

Ref document number: 21943811

Country of ref document: EP

Kind code of ref document: A1