CN116090503A - Method for training neural network model based on knowledge distillation and related products - Google Patents


Info

Publication number
CN116090503A
Authority
CN
China
Prior art keywords
modality
low
model
student model
training
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202211666148.XA
Other languages
Chinese (zh)
Inventor
刘从新
王斌
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Beijing Airdoc Technology Co Ltd
Original Assignee
Beijing Airdoc Technology Co Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Beijing Airdoc Technology Co Ltd
Priority to CN202211666148.XA
Publication of CN116090503A
Current legal status: Pending


Classifications

    • G PHYSICS
    • G06 COMPUTING; CALCULATING OR COUNTING
    • G06N COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00 Computing arrangements based on biological models
    • G06N3/02 Neural networks
    • G06N3/04 Architecture, e.g. interconnection topology
    • G PHYSICS
    • G06 COMPUTING; CALCULATING OR COUNTING
    • G06N COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00 Computing arrangements based on biological models
    • G06N3/02 Neural networks
    • G06N3/08 Learning methods
    • Y GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
    • Y02 TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
    • Y02D CLIMATE CHANGE MITIGATION TECHNOLOGIES IN INFORMATION AND COMMUNICATION TECHNOLOGIES [ICT], I.E. INFORMATION AND COMMUNICATION TECHNOLOGIES AIMING AT THE REDUCTION OF THEIR OWN ENERGY USE
    • Y02D10/00 Energy efficient computing, e.g. low power processors, power management or thermal management

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • General Health & Medical Sciences (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Computational Linguistics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Evolutionary Computation (AREA)
  • Artificial Intelligence (AREA)
  • Molecular Biology (AREA)
  • Computing Systems (AREA)
  • General Engineering & Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Health & Medical Sciences (AREA)
  • Medical Treatment And Welfare Office Work (AREA)

Abstract

The application discloses a method for training a neural network model based on knowledge distillation and related products. The method comprises: acquiring high-modality data and low-modality data and dividing them into a high-modality data set, a low-modality data set, and a mixed data set containing both high-modality data and low-modality data; training a high-modality model and a low-modality model using the high-modality data set and the low-modality data set, respectively, to obtain a corresponding high-modality teacher model and low-modality student model; and migrating knowledge from the high-modality teacher model to the low-modality student model based on knowledge distillation to optimize the training of the low-modality student model. With this scheme, the trained low-modality model can achieve good performance.

Description

Method for training neural network model based on knowledge distillation and related products
Technical Field
The present application relates generally to the field of artificial intelligence technology. More particularly, the present application relates to a method, apparatus, and computer-readable storage medium for training a neural network model based on knowledge distillation.
Background
In medical scenarios, many imaging techniques come in high-end and low-end variants, which correspond to high-modality data and low-modality data, respectively. For example, in fundus imaging, optical coherence tomography (OCT) is high-modality data, while color fundus photography (CFP) is low-modality data; in electrocardiography, the 12-lead electrocardiogram (ECG) is high-modality data, while the single-lead ECG is low-modality data; and in the phonocardiogram scenario, echocardiography (Echo) is high-modality data, while the phonocardiogram (PCG) is generally low-modality data. High-end and low-end images share image features of common diseases (for example, as shown in FIG. 1) and can both be used for disease identification. High-end images carry more accurate, richer, and more comprehensive information, whereas low-end images carry relatively limited information and have lower precision and imaging quality. However, low-end images are inexpensive, easy to acquire, and easy to deploy widely. For example, models trained on low-end modality data can be applied to continuous monitoring and large-scale screening.
Currently, models for low-end modality data are generally trained in a supervised manner, using the labels of the corresponding high-end images as the ground truth. Taking heart valve disease identification as an example, the high-modality data Echo serves as the gold standard, and a low-modality PCG model is trained with the labels of the corresponding Echo data. However, this approach requires collecting paired data from the same patient or case. While high-modality data and low-modality data may each be plentiful, their intersection (paired data) is small, so the trained low-end modality model performs poorly.
In view of the foregoing, it is desirable to provide a scheme for training a neural network model based on knowledge distillation that does not require paired data, makes full use of the existing data resources in medical scenarios, and gives the trained low-modality model better performance.
Disclosure of Invention
To address at least one or more of the technical problems mentioned above, the present application proposes, in various aspects, a solution for training a neural network model based on knowledge distillation.
In a first aspect, the present application provides a method of training a neural network model based on knowledge distillation, comprising: acquiring high-modality data and low-modality data and dividing the high-modality data and the low-modality data into a high-modality data set, a low-modality data set and a mixed data set containing the high-modality data and the low-modality data; training a high-modality model and a low-modality model respectively by using the high-modality data set and the low-modality data set to obtain a corresponding high-modality teacher model and low-modality student model; and migrating the high-modality teacher model to the low-modality student model based on knowledge distillation to optimally train the low-modality student model.
In one embodiment, the high-modality data and the low-modality data are the same type of medical data.
In another embodiment, training the high-modality model and the low-modality model using the high-modality data set and the low-modality data set, respectively, to obtain the corresponding high-modality teacher model and low-modality student model includes: performing supervised training on the high-modality model using the high-modality data set to obtain the high-modality teacher model; and performing supervised or self-supervised training on the low-modality model using the low-modality data set to obtain the low-modality student model.
In yet another embodiment, migrating knowledge from the high-modality teacher model to the low-modality student model based on knowledge distillation to optimize the low-modality student model comprises: performing inference on the mixed data set containing the high-modality data and the low-modality data using the trained high-modality teacher model to output inferred values; migrating the inferred values output by the high-modality teacher model to the low-modality student model; and performing knowledge distillation training that combines the true values, the output values of the low-modality student model, and the inferred values migrated to the low-modality student model, so as to optimize the low-modality student model.
In yet another embodiment, performing knowledge distillation training that combines the true values, the output values of the low-modality student model, and the inferred values migrated to the low-modality student model comprises: setting a temperature parameter for the output values of the low-modality student model and for the inferred values migrated to the low-modality student model, respectively; and performing knowledge distillation training that combines the true values, the temperature-scaled output values of the low-modality student model, and the temperature-scaled inferred values migrated to the low-modality student model, so as to optimize the low-modality student model.
In yet another embodiment, the output values of the low-modality student model and the inferred values migrated to the low-modality student model correspond to a single label or multiple labels, and setting the temperature parameter for them, respectively, comprises: computing, with an activation function, the probability values of the output values of the low-modality student model and of the inferred values migrated to it under the single label or the multiple labels; and setting the temperature parameter for the probability values under the corresponding labels, thereby setting the temperature parameter for the output values of the low-modality student model and for the inferred values migrated to it, respectively.
In yet another embodiment, performing knowledge distillation training that combines the true values, the temperature-scaled output values of the low-modality student model, and the temperature-scaled inferred values migrated to the low-modality student model comprises: calculating a first loss function based on the true values and the output values of the low-modality student model; calculating a second loss function based on the temperature-scaled output values of the low-modality student model and the temperature-scaled inferred values migrated to it; obtaining a total loss function as a weighted sum of the first loss function and the second loss function; and performing knowledge distillation training with the total loss function to optimize the low-modality student model.
In yet another embodiment, the total loss function for the single-label case is expressed as:

$$L_{all} = \alpha T^{2} \cdot \mathrm{CE}(q^{\tau}, p^{\tau}) + (1-\alpha) \cdot \mathrm{CE}(q, y)$$

where $L_{all}$ denotes the total loss function for the single label, $\alpha$ denotes the weighting coefficient, $T$ denotes the temperature parameter, $\mathrm{CE}(q, y)$ denotes the first loss function for the single label, and $\mathrm{CE}(q^{\tau}, p^{\tau})$ denotes the second loss function for the single label; $q$ and $y$ denote the output value and the true value of the low-modality student model under the single label, respectively, and $q^{\tau}$ and $p^{\tau}$ denote, respectively, the temperature-scaled output value of the low-modality student model and the temperature-scaled inferred value migrated to the low-modality student model under the single label.
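The single-label total loss above can be sketched in plain Python. This is a minimal illustration, not the applicant's implementation: the convention CE(pred, target) = -Σ target·log(pred), the small epsilon inside the logarithm, and the default values `alpha=0.5` and `T=4.0` are assumptions chosen for illustration only.

```python
import math


def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax: exp(v_i / T) / sum_j exp(v_j / T)."""
    scaled = [v / temperature for v in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(pred, target, eps=1e-12):
    """CE(pred, target) = -sum_i target_i * log(pred_i)."""
    return -sum(t * math.log(p + eps) for p, t in zip(pred, target))


def single_label_kd_loss(student_logits, teacher_logits, y_onehot, alpha=0.5, T=4.0):
    """L_all = alpha * T^2 * CE(q^tau, p^tau) + (1 - alpha) * CE(q, y).

    alpha and T defaults are illustrative assumptions, not values from the application.
    """
    q = softmax(student_logits)          # student probabilities at T = 1
    q_tau = softmax(student_logits, T)   # temperature-scaled student output q^tau
    p_tau = softmax(teacher_logits, T)   # temperature-scaled teacher inferred value p^tau
    soft_term = cross_entropy(q_tau, p_tau)  # second loss function
    hard_term = cross_entropy(q, y_onehot)   # first loss function
    return alpha * T ** 2 * soft_term + (1 - alpha) * hard_term


# Usage: three diseases, the true label is disease 0.
loss = single_label_kd_loss([2.0, 0.5, -1.0], [3.0, 1.0, -2.0], [1, 0, 0])
```

Setting `alpha=0.0` reduces the loss to ordinary supervised cross-entropy against the true label, which is a convenient sanity check.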
In yet another embodiment, the total loss function further includes a multi-label form, expressed as:

[Formula image not recoverable: total loss function $L_{all}$ for the multi-label case]

where $L_{all}$ denotes the total loss function for the multi-label case, $\alpha$ denotes the weighting coefficient, $i$ denotes the category index, $y_i \log(q_i)$ represents the first loss function for the multi-label case, and the second loss function for the multi-label case is given by a formula image that is not recoverable; $q_i$ and $y_i$ denote the output value and the true value of the low-modality student model under the multi-label case, respectively, and $q_i^{\tau}$ and $p_i^{\tau}$ denote, respectively, the temperature-scaled output value of the low-modality student model and the temperature-scaled inferred value migrated to the low-modality student model under the multi-label case.
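Because the multi-label formula did not survive extraction, the following is only a hedged sketch under an explicit assumption: each term is taken to be a sum of per-category binary cross-entropies over sigmoid outputs, weighted like the single-label case (α·T² on the soft term, 1 - α on the hard term). The actual formula in the application may differ, and the `alpha`/`T` defaults are illustrative.

```python
import math


def sigmoid(x, temperature=1.0):
    """Numerically stable logistic function of x / T."""
    z = x / temperature
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def binary_cross_entropy(pred, target, eps=1e-12):
    """-[target * log(pred) + (1 - target) * log(1 - pred)] for one category."""
    return -(target * math.log(pred + eps) + (1 - target) * math.log(1 - pred + eps))


def multi_label_kd_loss(student_logits, teacher_logits, y, alpha=0.5, T=4.0):
    """ASSUMED form: alpha*T^2*sum_i BCE(q_i^tau, p_i^tau) + (1-alpha)*sum_i BCE(q_i, y_i)."""
    soft_term = 0.0
    hard_term = 0.0
    for z_i, v_i, y_i in zip(student_logits, teacher_logits, y):
        q_i = sigmoid(z_i)          # student probability for category i
        q_i_tau = sigmoid(z_i, T)   # temperature-scaled student output
        p_i_tau = sigmoid(v_i, T)   # temperature-scaled teacher inferred value
        hard_term += binary_cross_entropy(q_i, y_i)
        soft_term += binary_cross_entropy(q_i_tau, p_i_tau)
    return alpha * T ** 2 * soft_term + (1 - alpha) * hard_term


# Usage: three possibly co-occurring diseases; diseases 0 and 2 are present.
loss = multi_label_kd_loss([1.5, -0.5, 0.8], [2.5, -1.5, 1.0], [1, 0, 1])
```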
In a second aspect, the present application provides an apparatus for training a neural network model based on knowledge distillation, comprising: a processor; and a memory storing program instructions for training a neural network model based on knowledge distillation, which, when executed by the processor, cause the apparatus to implement the method according to any of the embodiments of the first aspect.
In a third aspect, the present application provides a computer-readable storage medium storing computer-readable instructions for training a neural network model based on knowledge distillation, which, when executed by one or more processors, implement the method according to any of the embodiments of the first aspect.
With the scheme for training a neural network model based on knowledge distillation described above, the embodiments of the present application train a high-modality model and a low-modality model on the high-modality data set and the low-modality data set, respectively, to serve as the high-modality teacher model and the low-modality student model, and then migrate knowledge from the high-modality teacher model to the low-modality student model through knowledge distillation. In this way, the data resources behind the high-modality model are used to improve and optimize the low-modality model. As a result, the trained low-modality model performs better even without paired data, and existing data resources are used more fully. Further, in the embodiments of the present application, the low-modality student model learns from the soft labels (that is, the inferred values) provided by the high-modality teacher model rather than from hard labels (the true values) alone, obtaining a richer and more accurate supervision signal and avoiding the performance degradation caused by overconfidence of the low-modality student model. In addition, the embodiments of the present application obtain a base model by performing supervised or self-supervised training on the low-modality model with low-modality data, which makes the knowledge transfer more effective.
Drawings
The above, as well as additional purposes, features, and advantages of exemplary embodiments of the present application will become readily apparent from the following detailed description when read in conjunction with the accompanying drawings. Several embodiments of the present application are illustrated by way of example, and not by way of limitation, in the figures of the accompanying drawings and in which like reference numerals refer to similar or corresponding parts and in which:
FIG. 1 is an exemplary schematic diagram illustrating high-end and low-end images and their common disease types in several medical scenarios;
FIG. 2 is an exemplary diagram illustrating how a model for low-end modality data is trained in the prior art;
FIG. 3 is an exemplary diagram showing high-modality data, low-modality data, and their intersection;
FIG. 4 is an exemplary flow diagram illustrating a method of training a neural network model based on knowledge distillation, in accordance with an embodiment of the application;
FIG. 5 is an exemplary diagram illustrating knowledge distillation training in accordance with an embodiment of the application;
FIG. 6 is an exemplary diagram illustrating the overall process of training a neural network model based on knowledge distillation in accordance with an embodiment of the present application; and
FIG. 7 is an exemplary block diagram illustrating an apparatus for training a neural network model based on knowledge distillation, in accordance with an embodiment of the application.
Detailed Description
The technical solutions in the embodiments of the present application will be described clearly and completely below with reference to the accompanying drawings. It should be understood that the embodiments described in this specification are only some of the embodiments of the present application, provided to facilitate a clear understanding of the solution and to meet legal requirements, and are not all of its embodiments. All other embodiments obtained by those skilled in the art based on the embodiments disclosed in this specification without inventive effort fall within the scope of protection of the present application.
As described in the background, in medical scenarios many imaging technologies have high-end and low-end variants, corresponding to high-modality data and low-modality data, respectively. High-end images carry more accurate, richer, and more comprehensive information, whereas low-end images carry relatively limited information and have lower precision and imaging quality. Nevertheless, high-end and low-end images share image features of some common diseases, as shown in FIG. 1.
FIG. 1 is an exemplary schematic diagram showing high-end and low-end images and the diseases they have in common in several medical scenarios. FIG. 1 shows the corresponding high-end and low-end images for fundus imaging, electrocardiography, and phonocardiography, together with their common diseases. In fundus imaging, OCT (optical coherence tomography) is high-modality data, and CFP (color fundus photography) is low-modality data; in electrocardiography, the 12-lead ECG is high-modality data, and the single-lead (1-lead) ECG is low-modality data; in phonocardiography, echocardiography (Echo) is high-modality data, and the phonocardiogram (PCG) is usually low-modality data. FIG. 1 further shows that diseases common to the high-end and low-end fundus images include, for example, epiretinal membrane, macular hole, and choroidal neovascularization. Diseases common to the high-end and low-end images in electrocardiography include, for example, arrhythmia, atrial fibrillation, and premature beats; diseases common to the high-end and low-end images in phonocardiography include, for example, aortic stenosis, aortic regurgitation, mitral stenosis, and mitral regurgitation.
In one application scenario, corresponding models can be trained for continuous monitoring and large-scale screening of diseases. Currently, models for low-end modality data are usually trained in a supervised manner, using the labels of the corresponding high-end images as the true values, as shown, for example, in FIG. 2.
FIG. 2 is an exemplary diagram illustrating how a model for low-end modality data is trained in the prior art. As shown in FIG. 2, high-modality data and low-modality data from the same patient or the same case are first acquired, and training labels are obtained by annotating with the true values of the high-modality data. A model for the low-end modality data is then trained on these training labels and the low-end modality data to obtain a trained model. It follows that this training process requires paired data. The problem is that high-modality data and low-modality data may each be plentiful while their intersection is small, as shown, for example, in FIG. 3.
FIG. 3 is an exemplary schematic diagram showing high-modality data, low-modality data, and their intersection. The left side of FIG. 3 shows the high-modality data $D_H$, the right side shows the low-modality data $D_L$, and the middle shows the intersection (that is, the paired data) $D_C$ of $D_H$ and $D_L$. As the figure shows, there is little paired data $D_C$, so few data resources are available, the model is insufficiently trained, and its performance is poor.
In view of this, the present application provides a scheme for training a neural network model based on knowledge distillation that fully optimizes the training of the low-modality model, so that the trained low-modality model performs better.
Several embodiments of training a neural network model based on knowledge distillation according to the present application are described in detail below with reference to FIGS. 4-7.
FIG. 4 is an exemplary flow diagram illustrating a method 400 of training a neural network model based on knowledge distillation in accordance with an embodiment of the present application. As shown in FIG. 4, at step 401, high-modality data and low-modality data are acquired and divided into a high-modality data set, a low-modality data set, and a mixed data set containing both high-modality data and low-modality data. In one embodiment, the high-modality data and the low-modality data are the same type of medical data: for example, high-modality data (e.g., OCT) and low-modality data (e.g., CFP) of the fundus image type, high-modality data (e.g., 12-lead ECG) and low-modality data (e.g., single-lead ECG) of the electrocardiogram type, or high-modality data (e.g., Echo) and low-modality data (e.g., PCG) of the phonocardiogram type. In some embodiments, the high-modality data and the low-modality data in these three scenarios may be acquired by, for example, a fundus acquisition device, an electrocardiogram acquisition device, and a phonocardiogram acquisition device, respectively. The high-modality data and the low-modality data may then be divided into a high-modality data set (e.g., denoted $D_H$ as above), a low-modality data set (e.g., denoted $D_L$ as above), and a mixed data set containing both (e.g., denoted $D_C$ as above) for model training.
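The partition in step 401 can be sketched as follows. This is a hypothetical illustration, not the applicant's code: it assumes each record carries a case identifier and a modality tag (the `case_id`, `modality`, and `data` keys are invented for the example), that $D_H$ and $D_L$ hold all records of each modality, and that $D_C$ holds the (high, low) pairs of cases that have both, matching the intersection drawn in FIG. 3.

```python
from collections import defaultdict


def partition_by_modality(records):
    """Split records into (D_H, D_L, D_C).

    Each record is a dict with hypothetical keys "case_id", "modality"
    ("high" or "low") and "data". D_C holds (high, low) pairs from cases
    that have both modalities, i.e. the intersection in FIG. 3.
    """
    d_high = [r for r in records if r["modality"] == "high"]
    d_low = [r for r in records if r["modality"] == "low"]

    # Index high-modality records by case so pairing is a dictionary lookup.
    high_by_case = defaultdict(list)
    for r in d_high:
        high_by_case[r["case_id"]].append(r)

    d_mixed = [(h, r) for r in d_low for h in high_by_case.get(r["case_id"], [])]
    return d_high, d_low, d_mixed


records = [
    {"case_id": 1, "modality": "high", "data": "OCT-1"},
    {"case_id": 1, "modality": "low", "data": "CFP-1"},
    {"case_id": 2, "modality": "high", "data": "OCT-2"},
    {"case_id": 3, "modality": "low", "data": "CFP-3"},
]
d_h, d_l, d_c = partition_by_modality(records)
```

Here only case 1 has both an OCT and a CFP record, so $D_C$ contains a single pair while $D_H$ and $D_L$ each keep two records.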
Based on the divided high-modality data set and low-modality data set, at step 402, the high-modality model and the low-modality model are trained using the high-modality data set and the low-modality data set, respectively, to obtain the corresponding high-modality teacher model and low-modality student model. That is, the embodiments of the present application train the high-modality model and the low-modality model on the high-modality data set and the low-modality data set, respectively, and use them as the teacher model and the student model in knowledge distillation. In one embodiment, the high-modality data set may be used for supervised training of the high-modality model to obtain the high-modality teacher model, and the low-modality data set may be used for supervised or self-supervised training of the low-modality model to obtain the low-modality student model.
Specifically, when training the high-modality model, the high-modality data set may first be annotated, for example with disease categories. The high-modality data set is then divided into a training set and a validation set, which are used for supervised training and evaluation of the high-modality model, respectively, to obtain the high-modality teacher model. Similarly, when training the low-modality model, if labels are available, the low-modality data set may be divided into a training set and a validation set, and the low-modality model is trained and evaluated in a supervised manner. If no labels are available, the low-modality model may instead be trained in a self-supervised manner, and the low-modality student model is obtained by adding, for example, a fully connected output layer. The models obtained in this way can subsequently be used, for example, for disease classification, identification, or prediction.
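The training/validation split described above can be illustrated with a small, reproducible helper. The function name `train_val_split` and the 20% validation fraction are hypothetical choices for this sketch; the application does not specify a split ratio.

```python
import random


def train_val_split(items, val_fraction=0.2, seed=0):
    """Shuffle a copy of items reproducibly and split it into (train, val).

    val_fraction is an illustrative assumption, not a value from the application.
    """
    if not 0.0 < val_fraction < 1.0:
        raise ValueError("val_fraction must be between 0 and 1")
    shuffled = list(items)
    random.Random(seed).shuffle(shuffled)  # private RNG: no global state touched
    n_val = max(1, round(len(shuffled) * val_fraction))
    return shuffled[n_val:], shuffled[:n_val]


train, val = train_val_split(range(10), val_fraction=0.2, seed=42)
```

Using a dedicated `random.Random(seed)` keeps the split repeatable across runs without changing the global random state.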
After the high-modality teacher model and the low-modality student model are obtained, at step 403, knowledge is migrated from the high-modality teacher model to the low-modality student model based on knowledge distillation to optimize the low-modality student model. Specifically, information in the high-modality teacher model is transferred to the low-modality student model through knowledge distillation. Knowledge distillation trains a small, lightweight model under the supervision of a larger, better-performing model in order to improve the small model's performance and accuracy. The large model is called the teacher model ("Teacher") and the small model the student model ("Student"). The supervision information output by the teacher model is called knowledge ("Knowledge"), and the process by which the student model learns and absorbs this supervision information from the teacher model is called distillation ("Distillation"). Knowledge distillation is described in detail later with reference to FIG. 5.
In one embodiment, the trained high-modality teacher model first performs inference on the mixed data set containing high-modality data and low-modality data to output inferred values; the inferred values output by the high-modality teacher model are then migrated to the low-modality student model; and knowledge distillation training is performed by combining the true values, the output values of the low-modality student model, and the inferred values migrated to the low-modality student model, so as to optimize the low-modality student model. That is, in the embodiments of the present application, the data set formed by the high-modality data and the low-modality data is first input into the trained high-modality teacher model, which computes (infers) an output result, namely the inferred values. This output is then migrated to the low-modality student model, and knowledge distillation training is performed in combination with the true values (also called true labels) and the output values of the low-modality student model. Through knowledge distillation, the low-modality student model learns from composite supervision information rather than from a single true label, obtaining a richer and more accurate supervision signal and thereby improving its performance.
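One way to picture the data flow in this embodiment is sketched below, under the assumption (not stated verbatim in the application) that the teacher runs on the high-modality side of each paired case in the mixed data set and its logits are attached to the low-modality sample of the same case. `build_distillation_samples`, `teacher_infer`, and the record keys are hypothetical names for this illustration.

```python
def build_distillation_samples(mixed_pairs, teacher_infer):
    """Pair each low-modality input with the teacher's inferred logits.

    mixed_pairs: iterable of (high_record, low_record) tuples from D_C.
    teacher_infer: hypothetical callable standing in for the trained
        high-modality teacher; maps high-modality data to a list of logits.
    Returns (low_modality_data, teacher_logits, true_label) triples; the
    "label" key is a hypothetical field and is None if absent.
    """
    samples = []
    for high, low in mixed_pairs:
        teacher_logits = teacher_infer(high["data"])  # inference on the high-modality side
        samples.append((low["data"], teacher_logits, low.get("label")))
    return samples


# Stand-in teacher that returns fixed logits for two diseases.
pairs = [({"data": "OCT-1"}, {"data": "CFP-1", "label": 0})]
samples = build_distillation_samples(pairs, lambda x: [2.0, -1.0])
```

Each resulting triple carries exactly the three ingredients the distillation loss combines: the student's input, the teacher's inferred value, and the true value.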
In one implementation scenario, a temperature parameter may be set for the output values of the low-modality student model and for the inferred values migrated to it, respectively, and knowledge distillation training is performed by combining the true values, the temperature-scaled output values of the low-modality student model, and the temperature-scaled inferred values migrated to it, so as to optimize the low-modality student model. In some embodiments, the output values of the low-modality student model and the inferred values migrated to it correspond to a single label or multiple labels; that is, the low-modality student model and the high-modality teacher model may output a single disease or multiple co-occurring diseases. An activation function may be used to compute the probability values of the student's output values and of the migrated inferred values under the single label or the multiple labels, and the temperature parameter is set for the probability values under the corresponding labels, thereby setting it for the output values of the low-modality student model and for the inferred values migrated to it, respectively. The output values of the low-modality student model and the inferred values migrated to it contain a logit for each disease.
In one embodiment, the activation function may be, for example, a softmax function or a sigmoid function, and the temperature parameter ("Temperature") is a hyperparameter. Specifically, based on the output values of the low-modality student model and the inferred values migrated to it, the embodiments of the present application compute the probability values for a single label using, for example, a softmax function, or the probability values for multiple labels using a sigmoid function, dividing the logit of each disease by the temperature parameter before the activation function is applied. This sets the temperature parameter for the output values of the low-modality student model and for the migrated inferred values, respectively. The result is a softened probability distribution (that is, soft labels or soft targets), which is intended to reduce classification errors and to avoid introducing unnecessary noise.
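The softening effect of the temperature in the single-label (softmax) case can be seen in a short sketch. The logits and the temperature values 1 and 4 are illustrative assumptions; entropy is used here only as a convenient measure of how soft a distribution is.

```python
import math


def softmax_with_temperature(logits, T=1.0):
    """Divide each logit by T, then apply softmax."""
    scaled = [v / T for v in logits]
    m = max(scaled)  # stabilize exp()
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(probs):
    """Shannon entropy in nats; higher means a softer distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0)


logits = [4.0, 1.0, 0.0]  # illustrative logits for three diseases
hard_like = softmax_with_temperature(logits, T=1.0)
soft = softmax_with_temperature(logits, T=4.0)
```

With these logits, T = 1 gives roughly [0.94, 0.05, 0.02] while T = 4 gives roughly [0.54, 0.26, 0.20]: the ranking of diseases is unchanged, but the non-maximal classes carry more probability mass, which is the extra information a soft target passes to the student.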
As an example, suppose that, under a single label, the inferred value migrated to the low-modality student model covers $K$ diseases in total, with the logit of each disease denoted $v_0, v_1, \ldots, v_{K-1}$, and that the true values are denoted $y_0, y_1, \ldots, y_{K-1}$. In a practical application scenario, the inferred values migrated to the low-modality student model and the true values form a composite supervision signal $(v_0, v_1, \ldots, v_{K-1};\ y_0, y_1, \ldots, y_{K-1})$. For the inferred value migrated to the low-modality student model, the probability value $p_i$ for the single label can be computed with the softmax function, for example as

$$p_i = \frac{\exp(v_i)}{\sum_{j=0}^{K-1} \exp(v_j)}$$

Further, setting the temperature parameter for the inferred value migrated to the low-modality student model gives the temperature-scaled inferred value $p_i^{\tau}$:

$$p_i^{\tau} = \frac{\exp(v_i / T)}{\sum_{j=0}^{K-1} \exp(v_j / T)}$$

where $T$ denotes the temperature parameter.
As an example, for a single label, assume that the output values of the low-modality student model for the K disease species are denoted as $z_0, z_1, \ldots, z_{K-1}$. The corresponding probability value $q_i$ is, for example:

$$q_i = \frac{\exp(z_i)}{\sum_{j=0}^{K-1} \exp(z_j)}$$

Setting the temperature parameter for it gives the softened output value $q_i^{\tau}$:

$$q_i^{\tau} = \frac{\exp(z_i / T)}{\sum_{j=0}^{K-1} \exp(z_j / T)}$$
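The single-label softening described above can be sketched in Python as follows. This is a minimal sketch, not the applicant's code: it assumes the standard temperature-scaled softmax in which each logit is divided by T before normalization, and the function name `softmax_with_temperature` and the example logits are hypothetical. The same function serves for the teacher's inferred values $v_i$ and for the student's outputs $z_i$.

```python
import math

def softmax_with_temperature(logits, T=1.0):
    """Return exp(x_i / T) / sum_j exp(x_j / T) for each logit x_i.

    T = 1 gives the ordinary softmax; T > 1 gives a softened distribution.
    """
    if T <= 0:
        raise ValueError("temperature T must be positive")
    scaled = [x / T for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical logits for K = 3 disease species.
v = [4.0, 1.0, 0.5]                          # teacher inferred values v_i
z = [2.5, 1.5, 0.0]                          # student output values z_i
p = softmax_with_temperature(v)              # p_i
p_tau = softmax_with_temperature(v, T=10.0)  # p_i^tau (soft labels)
q = softmax_with_temperature(z)              # q_i
q_tau = softmax_with_temperature(z, T=10.0)  # q_i^tau
```

Raising T from 1 to 10 keeps the ranking of the disease species but flattens the distribution, so the teacher's relative confidence in the non-top classes carries more weight in the soft label.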
For multiple labels, the probability value $p_i$ of each label category for the inferred values migrated to the low-modality student model can be computed with a sigmoid function, for example:

$$p_i = \frac{1}{1 + \exp(-v_i)}$$

Setting the temperature parameter in the same way gives the softened inferred value $p_i^{\tau}$ migrated to the low-modality student model, written as

$$p_i^{\tau} = \frac{1}{1 + \exp(-v_i / T)}$$
Similarly, the probability value $q_i$ of the low-modality student model for multiple labels can be written as

$$q_i = \frac{1}{1 + \exp(-z_i)}$$

Setting the temperature parameter for it gives the softened output value $q_i^{\tau}$:

$$q_i^{\tau} = \frac{1}{1 + \exp(-z_i / T)}$$
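For multiple labels each category is scored independently, so the sketch below applies a temperature-scaled sigmoid element-wise. As with the single-label case, it assumes the standard form in which each logit is divided by T; the function name and the example values are hypothetical.

```python
import math

def sigmoid_with_temperature(logits, T=1.0):
    """Element-wise 1 / (1 + exp(-x_i / T)); labels are scored independently."""
    if T <= 0:
        raise ValueError("temperature T must be positive")
    out = []
    for x in logits:
        s = x / T
        # Two algebraically equal forms, chosen so exp() never overflows.
        if s >= 0:
            out.append(1.0 / (1.0 + math.exp(-s)))
        else:
            e = math.exp(s)
            out.append(e / (1.0 + e))
    return out

# Hypothetical student logits z_i for 3 independent labels.
z = [3.0, -2.0, 0.0]
q = sigmoid_with_temperature(z)             # q_i
q_tau = sigmoid_with_temperature(z, T=5.0)  # q_i^tau
```

Unlike the softmax case, these values need not sum to 1, and softening pulls each label's probability toward 0.5 instead of spreading mass across classes.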
Knowledge distillation training is then performed by combining the true values, the temperature-scaled output values of the low-modality student model, and the temperature-scaled inferred values migrated to the low-modality student model, so as to perform optimization training on the low-modality student model.
In one embodiment, a first loss function may first be calculated from the true values and the output values of the low-modality student model, and a second loss function is calculated from the temperature-scaled output values of the low-modality student model and the temperature-scaled inferred values migrated to it. In a practical application scenario, the first and second loss functions may be, for example, cross-entropy functions. A total loss function is then obtained as a weighted sum of the first and second loss functions, and knowledge distillation training is performed with the total loss function to optimize the low-modality student model. In one implementation scenario, the total loss function may be expressed as:
$$L_{all} = \alpha T^2 \cdot CE(q^{\tau}, p^{\tau}) + (1 - \alpha) \cdot CE(q, y) \qquad (1)$$
where $L_{all}$ represents the total loss function for the single label, α represents the weighting coefficient, T represents the temperature parameter, $CE(q, y)$ represents the first loss function for the single label, and $CE(q^{\tau}, p^{\tau})$ represents the second loss function for the single label; q and y correspond respectively to the output value and the true value of the low-modality student model under the single label, and $q^{\tau}$ and $p^{\tau}$ correspond respectively to the temperature-scaled output value of the low-modality student model and the temperature-scaled inferred value migrated to the low-modality student model under the single label. CE represents a cross-entropy function, for example

$$CE(q, y) = -\sum_{i=0}^{K-1} y_i \log(q_i)$$
In some embodiments, the temperature parameter T may be selected between 5 and 20, and the weighting coefficient α may be, for example, 0.5.
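Equation (1) can be sketched in plain Python as below. This is an illustrative sketch under stated assumptions: `kd_loss_single_label` is a hypothetical name, CE is taken as the usual cross-entropy $CE(q, t) = -\sum_i t_i \log q_i$ with the target as the second argument, the true value y is a one-hot vector, and the defaults T = 10 and α = 0.5 are simply picked from the ranges mentioned above.

```python
import math

def softmax_t(logits, T=1.0):
    """Temperature-scaled softmax, exp(x_i / T) / sum_j exp(x_j / T)."""
    m = max(x / T for x in logits)
    exps = [math.exp(x / T - m) for x in logits]
    s = sum(exps)
    return [e / s for e in exps]

def cross_entropy(q, target, eps=1e-12):
    """CE(q, target) = -sum_i target_i * log(q_i); eps guards against log(0)."""
    return -sum(t * math.log(max(qi, eps)) for qi, t in zip(q, target))

def kd_loss_single_label(student_logits, teacher_logits, y, T=10.0, alpha=0.5):
    """L_all = alpha * T^2 * CE(q^tau, p^tau) + (1 - alpha) * CE(q, y)."""
    q = softmax_t(student_logits)          # q: student probabilities at T = 1
    q_tau = softmax_t(student_logits, T)   # q^tau: softened student output
    p_tau = softmax_t(teacher_logits, T)   # p^tau: softened teacher soft label
    second = cross_entropy(q_tau, p_tau)   # second loss (distillation term)
    first = cross_entropy(q, y)            # first loss (ground-truth term)
    return alpha * T ** 2 * second + (1 - alpha) * first

# Hypothetical K = 3 example: student logits z, teacher logits v, one-hot truth y.
z = [2.0, 0.5, -1.0]
v = [3.0, 1.0, -0.5]
y = [1.0, 0.0, 0.0]
loss = kd_loss_single_label(z, v, y, T=10.0, alpha=0.5)
```

A common explanation for the $T^2$ factor in the distillation literature is that the gradient of the softened term shrinks roughly as $1/T^2$ for large T, so the factor keeps the two terms on a comparable scale as T changes.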
In another implementation scenario, the above total loss function may also be represented by the following equation:
[Equation (2): image not available]
where $L_{all}$ represents the total loss function for multiple labels, α represents the weighting coefficient, i represents the category index, $y_i \log(q_i)$ represents the first loss function for multiple labels, and $p_i^{\tau} \log(q_i^{\tau})$ represents the second loss function for multiple labels; $q_i$ and $y_i$ correspond respectively to the output value and the true value of the low-modality student model under multiple labels, and $q_i^{\tau}$ and $p_i^{\tau}$ correspond respectively to the temperature-scaled output value of the low-modality student model and the temperature-scaled inferred value migrated to the low-modality student model under multiple labels. In some embodiments, the temperature parameter T may be selected between 5 and 20, and the weighting coefficient α may be, for example, 0.5.
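A multi-label counterpart can be sketched the same way. Because equation (2) survives in the source only as an image, the exact form below is an assumption: it mirrors equation (1) with the $\alpha T^2$ and $(1 - \alpha)$ weights, and it uses the full per-label binary cross-entropy $-[t \log q + (1 - t)\log(1 - q)]$, a common choice for sigmoid outputs, whereas the text names only the $y_i \log(q_i)$ part of the first loss. All names and values are hypothetical.

```python
import math

def sigmoid_t(logits, T=1.0):
    """Element-wise temperature-scaled sigmoid, 1 / (1 + exp(-x_i / T))."""
    out = []
    for x in logits:
        s = x / T
        if s >= 0:
            out.append(1.0 / (1.0 + math.exp(-s)))
        else:
            e = math.exp(s)
            out.append(e / (1.0 + e))
    return out

def binary_cross_entropy(q, target, eps=1e-12):
    """Summed per-label BCE: -sum_i [t_i log q_i + (1 - t_i) log(1 - q_i)]."""
    total = 0.0
    for qi, t in zip(q, target):
        qi = min(max(qi, eps), 1.0 - eps)  # clamp to avoid log(0)
        total -= t * math.log(qi) + (1.0 - t) * math.log(1.0 - qi)
    return total

def kd_loss_multi_label(student_logits, teacher_logits, y, T=10.0, alpha=0.5):
    """ASSUMED multi-label analogue of equation (1), not the patent's equation (2)."""
    q = sigmoid_t(student_logits)          # q_i
    q_tau = sigmoid_t(student_logits, T)   # q_i^tau
    p_tau = sigmoid_t(teacher_logits, T)   # p_i^tau (soft labels)
    return (alpha * T ** 2 * binary_cross_entropy(q_tau, p_tau)
            + (1 - alpha) * binary_cross_entropy(q, y))

z = [2.0, -1.0, 0.5]   # hypothetical student logits, 3 labels
v = [3.0, -2.0, -0.5]  # hypothetical teacher logits
y = [1.0, 0.0, 1.0]    # multi-hot truth: labels 0 and 2 present
loss = kd_loss_multi_label(z, v, y)
```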
As can be seen from the above description, in the embodiments of the present application, a high-modality model and a low-modality model are first trained on the high-modality data set and the low-modality data set, respectively, to serve as the high-modality teacher model and the low-modality student model in knowledge distillation training. The inferred values produced by the high-modality teacher model on the mixed data set containing high-modality and low-modality data are then migrated to the low-modality student model through knowledge distillation and combined with the true values as a composite supervision signal to optimize the training of the low-modality model. In this way, the requirement for paired data and single labels can be avoided, which improves the utilization of existing data resources and yields a well-performing low-modality model. Further, in the embodiments of the present application, supervised or self-supervised training on limited low-modality data produces the base model, which improves the effect of knowledge migration. In addition, the embodiments of the present application set a temperature parameter for the model output values to reduce the error rate of the classification probability and noise, further improving the performance and accuracy of the low-modality model.
Fig. 5 is an exemplary schematic diagram illustrating knowledge distillation training in accordance with an embodiment of the application. Fig. 5 shows a high-modality teacher model 510 and a low-modality student model 520. In one implementation scenario, the high-modality model may be trained with supervision on an annotated high-modality data set to obtain the high-modality teacher model 510. In another implementation scenario, the low-modality model may be trained with supervision on an annotated low-modality data set to obtain the low-modality student model 520. In addition, for unannotated low-modality data, the low-modality student model 520 may also be obtained through unsupervised training. The trained high-modality teacher model 510 can perform inference on a mixed data set containing high-modality data and low-modality data to obtain inferred values. Further, based on knowledge distillation 530, the inferred values output by the high-modality teacher model 510 may be migrated to the low-modality student model 520 as supervisory information. As previously described, the process by which the low-modality student model 520 learns the supervisory information migrated from the high-modality teacher model 510 is referred to as distillation ("Distillation"), and what the low-modality student model 520 learns from the output of the high-modality teacher model 510 is referred to as knowledge ("Knowledge").
It can be understood that knowledge distillation is generally applied only to data of the same modality, whereas the embodiments of the present application apply knowledge distillation across data of different modalities. High-modality data resources are thus fully used to improve the low-modality model, so that the model optimized by the embodiments of the present application has better performance and generalization.
FIG. 6 is an exemplary diagram illustrating the overall process of training a neural network model based on knowledge distillation in accordance with an embodiment of the present application. As shown in fig. 6, at step 601, a high-modality data set $D_H$ is acquired and annotated, and at step 602, a low-modality data set $D_L$ is acquired. In one embodiment, the high-modality data $D_H$ and the low-modality data $D_L$ are medical data of the same type, for example, high-modality data (e.g., OCT) and low-modality data (e.g., CFP) that are both fundus images. Based on the acquired high-modality data set $D_H$ and low-modality data set $D_L$, at steps 603 and 604, supervised training is performed on the high-modality model using the high-modality data set, and supervised or self-supervised training is performed on the low-modality model using the low-modality data set, respectively; at steps 605 and 606, the high-modality teacher model and the low-modality student model are obtained correspondingly.
Further, at step 607, the high-modality teacher model may be used to perform inference on the mixed data set $D_C$ containing high-modality data and low-modality data to obtain the corresponding inferred values. In one implementation scenario, probability values for each disease species may first be calculated from the inferred values, and the temperature parameter is then set for these probability values to generate soft labels; at step 608, the soft labels and the true values are combined into a composite supervisory signal. Next, at step 609, knowledge distillation training is performed on the low-modality student model based on the composite supervisory signal and the output values of the low-modality model, so as to optimize the low-modality student model (i.e., the low-modality model). In one embodiment, the low-modality student model may be trained with a total loss function determined from the composite supervisory signal and the output values of the low-modality model. For a single label and for multiple labels, the total loss function may be calculated using equation (1) and equation (2) above, respectively, yielding the optimized low-modality model at step 610. For further details of optimizing the low-modality model based on knowledge distillation, reference may be made to the description of fig. 4 above, which is not repeated here.
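The optimization in steps 607 to 609 can be illustrated end to end with a toy example. This is a sketch under loudly stated assumptions, not the applicant's training code: the student is a hypothetical linear model on low-modality feature vectors, each mixed-set sample is assumed to carry the teacher's precomputed logits together with a one-hot true value, and plain full-batch gradient descent stands in for whatever optimizer is actually used. The gradient is the analytic derivative of equation (1) with respect to the student logits, $\alpha T (q^{\tau} - p^{\tau}) + (1 - \alpha)(q - y)$.

```python
import math

def softmax_t(logits, T=1.0):
    """Temperature-scaled softmax, exp(x_i / T) / sum_j exp(x_j / T)."""
    m = max(x / T for x in logits)
    exps = [math.exp(x / T - m) for x in logits]
    s = sum(exps)
    return [e / s for e in exps]

def linear_logits(W, b, x):
    """Hypothetical linear student: z = W x + b."""
    return [sum(w * xj for w, xj in zip(row, x)) + bi for row, bi in zip(W, b)]

def kd_loss(z, v, y, T, alpha):
    """Equation (1) with CE(q, t) = -sum_i t_i log q_i."""
    q, q_tau, p_tau = softmax_t(z), softmax_t(z, T), softmax_t(v, T)
    second = -sum(p * math.log(qt) for p, qt in zip(p_tau, q_tau))
    first = -sum(t * math.log(qi) for t, qi in zip(y, q))
    return alpha * T ** 2 * second + (1 - alpha) * first

def kd_grad(z, v, y, T, alpha):
    """dL/dz_i = alpha*T*(q_i^tau - p_i^tau) + (1 - alpha)*(q_i - y_i),
    valid because p^tau and the one-hot y each sum to 1."""
    q, q_tau, p_tau = softmax_t(z), softmax_t(z, T), softmax_t(v, T)
    return [alpha * T * (qt - p) + (1 - alpha) * (qi - t)
            for qt, p, qi, t in zip(q_tau, p_tau, q, y)]

def mean_loss(W, b, data, T, alpha):
    """Average of equation (1) over the (x, v, y) samples."""
    return sum(kd_loss(linear_logits(W, b, x), v, y, T, alpha)
               for x, v, y in data) / len(data)

def train_student(data, K, D, T=10.0, alpha=0.5, lr=0.1, epochs=200):
    """Full-batch gradient descent on the mean of equation (1).
    Each item is (x, v, y): low-modality features x, teacher logits v
    (assumed precomputed on the corresponding high-modality input), one-hot y."""
    W = [[0.0] * D for _ in range(K)]
    b = [0.0] * K
    n = len(data)
    for _ in range(epochs):
        gW = [[0.0] * D for _ in range(K)]
        gb = [0.0] * K
        for x, v, y in data:
            g = kd_grad(linear_logits(W, b, x), v, y, T, alpha)
            for i in range(K):
                gb[i] += g[i] / n
                for j in range(D):
                    gW[i][j] += g[i] * x[j] / n
        for i in range(K):
            b[i] -= lr * gb[i]
            for j in range(D):
                W[i][j] -= lr * gW[i][j]
    return W, b

# Hypothetical mixed-set samples: (low-modality features, teacher logits, one-hot truth).
data = [
    ([1.0, 0.0], [2.0, 0.0, -1.0], [1.0, 0.0, 0.0]),
    ([0.0, 1.0], [-1.0, 2.0, 0.0], [0.0, 1.0, 0.0]),
    ([1.0, 1.0], [0.0, 0.5, 1.5], [0.0, 0.0, 1.0]),
]
W0, b0 = [[0.0] * 2 for _ in range(3)], [0.0] * 3  # the untrained starting point
W, b = train_student(data, K=3, D=2)
```

In this gradient, the $T^2$ in equation (1) cancels against the $1/T$ that the temperature introduces, leaving the factor $\alpha T$ on the softened term.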
Fig. 7 is an exemplary block diagram illustrating an apparatus 700 for training a neural network model based on knowledge distillation, in accordance with an embodiment of the application. It is to be appreciated that the device implementing aspects of the subject application may be a single device (e.g., a computing device) or a multi-function device including various peripheral devices.
As shown in fig. 7, the device of the present application may include a central processing unit ("CPU") 711, which may be a general-purpose CPU, a special-purpose CPU, or another unit for information processing and program execution. Further, the device 700 may also include a mass memory 712 and a read-only memory ("ROM") 713, where the mass memory 712 may be configured to store various types of data, including high-modality data and low-modality data, algorithm data, intermediate results, and the various programs needed to operate the device 700. The ROM 713 may be configured to store the data and instructions needed for the power-on self-test of the device 700, initialization of the functional modules in the system, drivers for basic input/output of the system, and booting the operating system.
Optionally, the device 700 may also include other hardware platforms or components, such as a tensor processing unit ("TPU") 714, a graphics processing unit ("GPU") 715, a field programmable gate array ("FPGA") 716, and a machine learning unit ("MLU") 717, as shown. It will be appreciated that while various hardware platforms or components are shown in device 700, this is by way of example only and not limitation, and that one of skill in the art may add or remove corresponding hardware as desired. For example, device 700 may include only a CPU, associated memory devices, and interface devices to implement the methods of the present application for training neural network models based on knowledge distillation.
In some embodiments, to facilitate the transfer and exchange of data with external networks, the device 700 of the present application further comprises a communication interface 718, through which the device can be connected to a local area network/wireless local area network ("LAN/WLAN") 705, and further to a local server 706 or to the Internet 707 through the LAN/WLAN. Alternatively or additionally, the device 700 of the present application may be connected directly to the Internet or a cellular network via the communication interface 718 using a wireless communication technology, such as third-generation ("3G"), fourth-generation ("4G"), or fifth-generation ("5G") wireless communication technology. In some application scenarios, the device 700 of the present application may also access the server 708 and the database 709 of the external network as needed to obtain various known algorithms, data, and modules, and may store various data remotely, such as various types of data or instructions related to the high-modality data and the low-modality data.
The peripheral devices of the device 700 may include a display device 702, an input device 703, and a data transmission interface 704. In one embodiment, the display device 702 may, for example, include one or more speakers and/or one or more visual displays configured for voice prompts and/or image and video display during the training of a neural network model based on knowledge distillation according to the present application. The input device 703 may include, for example, input buttons or controls such as a keyboard, mouse, microphone, or gesture-capture camera, configured to receive audio data and/or user instructions. The data transmission interface 704 may include, for example, a serial interface, a parallel interface, a universal serial bus ("USB") interface, a small computer system interface ("SCSI"), serial ATA, FireWire, PCI Express, and a high-definition multimedia interface ("HDMI"), configured for data transfer and interaction with other devices or systems. According to aspects of the present application, the data transmission interface 704 may receive high-modality data and low-modality data from, for example, a fundus acquisition device, or high-modality data and low-modality data acquired by, for example, an electrocardiographic acquisition device or a phonocardiogram acquisition device, and transmit data or results of various types, including the high-modality data and the low-modality data, to the device 700.
The above-described CPU 711, mass memory 712, ROM 713, TPU 714, GPU 715, FPGA 716, MLU 717, and communication interface 718 of the device 700 of the present application may be interconnected by a bus 719, and data interaction with peripheral devices may be achieved through the bus. In one embodiment, the CPU 711 may control other hardware components in the device 700 and their peripherals via the bus 719.
A device that may be used to train a neural network model based on knowledge distillation according to the present application is described above in connection with fig. 7. It is to be understood that the device structure or architecture herein is merely exemplary, and that the implementations and implementing entities of the present application are not limited thereto; changes may be made without departing from the spirit of the present application.
Those skilled in the art will also appreciate from the foregoing description, taken in conjunction with the accompanying drawings, that embodiments of the present application may also be implemented in software programs. The present application thus also provides a computer program product. The computer program product may be used to implement the method of training a neural network model based on knowledge distillation described herein in connection with fig. 4-6.
It should be noted that although the operations of the methods of the present application are depicted in the drawings in a particular order, this does not require or imply that the operations must be performed in that particular order or that all of the illustrated operations be performed in order to achieve desirable results. Rather, the steps depicted in the flowcharts may change the order of execution. Additionally or alternatively, certain steps may be omitted, multiple steps combined into one step to perform, and/or one step decomposed into multiple steps to perform.
It should be understood that when the terms "first," "second," "third," and "fourth," etc. are used in the claims, the specification and the drawings of this application, they are used merely to distinguish between different objects and not to describe a particular sequence. The terms "comprises" and "comprising," when used in the specification and claims of this application, specify the presence of stated features, integers, steps, operations, elements, and/or components, but do not preclude the presence or addition of one or more other features, integers, steps, operations, elements, components, and/or groups thereof.
It is also to be understood that the terminology used in the description of the present application is for the purpose of describing particular embodiments only, and is not intended to be limiting of the application. As used in the specification and claims of this application, the singular forms "a", "an" and "the" are intended to include the plural forms as well, unless the context clearly indicates otherwise. It should be further understood that the term "and/or" as used in the present specification and claims refers to any and all possible combinations of one or more of the associated listed items, and includes such combinations.
Although the embodiments of the present application are described above, the content is only an example adopted for understanding the present application, and is not intended to limit the scope and application scenario of the present application. Any person skilled in the art can make any modifications and variations in form and detail without departing from the spirit and scope of the disclosure, but the scope of the disclosure is still subject to the scope of the claims.

Claims (11)

1. A method of training a neural network model based on knowledge distillation, comprising:
acquiring high-modality data and low-modality data and dividing the high-modality data and the low-modality data into a high-modality data set, a low-modality data set and a mixed data set containing the high-modality data and the low-modality data;
training a high-modality model and a low-modality model respectively by using the high-modality data set and the low-modality data set to obtain a corresponding high-modality teacher model and low-modality student model; and
migrating, based on knowledge distillation, the high-modality teacher model to the low-modality student model to perform optimization training on the low-modality student model.
2. The method of claim 1, wherein the high modality data and the low modality data are the same type of medical data.
3. The method of claim 1, wherein training a high modality model and a low modality model, respectively, using the high modality data set and the low modality data set to obtain corresponding high modality teacher model and low modality student model comprises:
performing supervised training on the high modality model using the high modality dataset to obtain a corresponding high modality teacher model; and
performing supervised training or self-supervised training on the low modality model using the low modality dataset to obtain a corresponding low modality student model.
4. The method of claim 1, wherein migrating the high-modality teacher model to the low-modality student model based on knowledge distillation to perform optimization training on the low-modality student model comprises:
performing inference on the mixed data set containing the high-modality data and the low-modality data by using the trained high-modality teacher model to output inferred values;
migrating the inferred values output by the high-modality teacher model to the low-modality student model; and
performing knowledge distillation training by combining the true values, the output values of the low-modality student model, and the inferred values migrated to the low-modality student model, so as to perform optimization training on the low-modality student model.
5. The method of claim 4, wherein combining the true values, the output values of the low-modality student model, and the inferred values migrated to the low-modality student model for knowledge distillation training to perform optimization training on the low-modality student model comprises:
setting a temperature parameter for the output values of the low-modality student model and for the inferred values migrated to the low-modality student model, respectively; and
performing knowledge distillation training by combining the true values, the temperature-scaled output values of the low-modality student model, and the temperature-scaled inferred values migrated to the low-modality student model, so as to perform optimization training on the low-modality student model.
6. The method of claim 5, wherein the output values of the low-modality student model and the inferred values migrated to the low-modality student model correspond to a single label or multiple labels, and setting temperature parameters for the output values of the low-modality student model and the inferred values migrated to the low-modality student model, respectively, comprises:
calculating, by using an activation function, the probability values of the output values of the low-modality student model and of the inferred values migrated to the low-modality student model under the single label or the multiple labels, respectively; and
setting the temperature parameter for the probability values under the corresponding labels, so as to set the temperature parameter for the output values of the low-modality student model and for the inferred values migrated to the low-modality student model, respectively.
7. The method of claim 6, wherein combining the true values, the temperature-scaled output values of the low-modality student model, and the temperature-scaled inferred values migrated to the low-modality student model for knowledge distillation training to perform optimization training on the low-modality student model comprises:
calculating a first loss function based on the true values and the output values of the low-modality student model;
calculating a second loss function based on the temperature-scaled output values of the low-modality student model and the temperature-scaled inferred values migrated to the low-modality student model;
obtaining a total loss function from a weighted sum of the first loss function and the second loss function; and
performing knowledge distillation training by using the total loss function, so as to perform optimization training on the low-modality student model.
8. The method of claim 7, wherein the total loss function is represented by the following formula:
$$L_{all} = \alpha T^2 \cdot CE(q^{\tau}, p^{\tau}) + (1 - \alpha) \cdot CE(q, y)$$
wherein $L_{all}$ represents the total loss function for the single label, α represents the weighting coefficient, T represents the temperature parameter, $CE(q, y)$ represents the first loss function for the single label, $CE(q^{\tau}, p^{\tau})$ represents the second loss function for the single label, q and y correspond respectively to the output value and the true value of the low-modality student model under the single label, and $q^{\tau}$ and $p^{\tau}$ correspond respectively to the temperature-scaled output value of the low-modality student model and the temperature-scaled inferred value migrated to the low-modality student model under the single label.
9. The method of claim 7, wherein the total loss function is further represented by the following formula:
[Formula: image not available]
wherein $L_{all}$ represents the total loss function for multiple labels, α represents the weighting coefficient, i represents the category index, $y_i \log(q_i)$ represents the first loss function for multiple labels, $p_i^{\tau} \log(q_i^{\tau})$ represents the second loss function for multiple labels, $q_i$ and $y_i$ respectively represent the output value and the true value of the low-modality student model under multiple labels, and $q_i^{\tau}$ and $p_i^{\tau}$ correspond respectively to the temperature-scaled output value of the low-modality student model and the temperature-scaled inferred value migrated to the low-modality student model under multiple labels.
10. An apparatus for training a neural network model based on knowledge distillation, comprising:
a processor; and
a memory storing program instructions for training a neural network model based on knowledge distillation, which when executed by the processor, cause the apparatus to implement the method according to any one of claims 1-9.
11. A computer-readable storage medium having stored thereon computer-readable instructions for training a neural network model based on knowledge distillation, which when executed by one or more processors, implement the method of any of claims 1-9.
CN202211666148.XA 2022-12-23 2022-12-23 Method for training neural network model based on knowledge distillation and related products Pending CN116090503A (en)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202211666148.XA CN116090503A (en) 2022-12-23 2022-12-23 Method for training neural network model based on knowledge distillation and related products

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202211666148.XA CN116090503A (en) 2022-12-23 2022-12-23 Method for training neural network model based on knowledge distillation and related products

Publications (1)

Publication Number Publication Date
CN116090503A true CN116090503A (en) 2023-05-09

Family

ID=86198365

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202211666148.XA Pending CN116090503A (en) 2022-12-23 2022-12-23 Method for training neural network model based on knowledge distillation and related products

Country Status (1)

Country Link
CN (1) CN116090503A (en)

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN117421678A (en) * 2023-12-19 2024-01-19 西南石油大学 Single-lead atrial fibrillation recognition system based on knowledge distillation
CN117421678B (en) * 2023-12-19 2024-03-22 西南石油大学 Single-lead atrial fibrillation recognition system based on knowledge distillation

Legal Events

Date Code Title Description
PB01 Publication
SE01 Entry into force of request for substantive examination