CN115423000A - Cross-domain small sample identification method based on multi-teacher knowledge distillation - Google Patents

Cross-domain small sample identification method based on multi-teacher knowledge distillation Download PDF

Info

Publication number
CN115423000A
CN115423000A (application CN202211001654.7A)
Authority
CN
China
Prior art keywords
domain
model
small sample
teacher
data
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
CN202211001654.7A
Other languages
Chinese (zh)
Inventor
姜育刚 (Yu-Gang Jiang)
傅宇倩 (Yuqian Fu)
谢宇 (Yu Xie)
付彦伟 (Yanwei Fu)
陈静静 (Jingjing Chen)
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Fudan University
Original Assignee
Fudan University
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Fudan University filed Critical Fudan University
Priority to CN202211001654.7A priority Critical patent/CN115423000A/en
Publication of CN115423000A publication Critical patent/CN115423000A/en
Pending legal-status Critical Current

Classifications

    • G PHYSICS
    • G06 COMPUTING; CALCULATING OR COUNTING
    • G06N COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00 Computing arrangements based on biological models
    • G06N3/02 Neural networks
    • G06N3/08 Learning methods

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • General Health & Medical Sciences (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Computational Linguistics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Evolutionary Computation (AREA)
  • Artificial Intelligence (AREA)
  • Molecular Biology (AREA)
  • Computing Systems (AREA)
  • General Engineering & Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Health & Medical Sciences (AREA)
  • Image Analysis (AREA)

Abstract

The invention belongs to the technical field of computers, and particularly relates to a cross-domain small sample identification method based on multi-teacher knowledge distillation. The method comprises the following steps: constructing a feature extraction network, a small sample classifier and a dynamic domain splitting model; connecting the feature extraction network and the small sample classifier to form a source domain teacher model and, separately, a target domain teacher model; connecting the feature extraction network and the small sample classifier and inserting the dynamic domain splitting model into specific layers of the feature extraction network to form a domain-decomposable student model. The two teacher models are trained on source domain data and target domain data respectively, and knowledge distillation is then used to gradually distill the knowledge of the two teacher models into the student model. Under the condition that there is a large domain gap between the source domain and the target domain and only a small amount of labeled target data is available, the method achieves better category recognition on unknown-category target domain test data.

Description

Cross-domain small sample identification method based on multi-teacher knowledge distillation
Technical Field
The invention belongs to the technical field of computers, and particularly relates to a cross-domain small sample identification method.
Background
Small sample learning aims to transfer knowledge from a source data set to a new target data set that has only one or a few labeled examples per class. In general, small sample learning assumes that the images of the source and target data sets belong to the same data domain. However, such an ideal assumption is not easily satisfied in real-world multimedia applications. For example, as revealed in work [1], a model trained on a data set composed mainly of a large variety of natural images is still unable to recognize novel, fine-grained bird categories. Cross-domain small sample identification therefore aims to solve the problem that the source domain and the target domain are inconsistent in small sample identification.
In recent years, cross-domain small sample identification has been studied extensively in many previous approaches [2, 3, 4, 5, 6]. Most of them [3, 5, 6] use source domain images only for training and mainly aim to improve the generalization ability of the model. Although some progress has been made, significant breakthroughs in performance remain difficult to achieve because of the large domain gap between the source and target data sets. Thus, some works [2, 4] relax this most basic but strictest setting and allow target data to be used during the training phase: STARTUP [4] uses a large amount of unlabeled target data, while Meta-FDMixup [2] uses a small amount of labeled target data. However, for particular categories in reality, such as endangered wildlife or specific buildings, it is not easy to obtain a large number of unlabeled images of those categories. In contrast, it is more realistic to use a small amount of labeled target domain data (e.g., 5 images per class). Therefore, the method of the present invention adopts the same task setting as Meta-FDMixup.
Given a source data set with a large number of annotated images and an auxiliary target domain data set with only a few annotated images, in order for the model to better learn knowledge from these two data sets and migrate it to the target domain data, the main approach of Meta-FDMixup is: (1) a mixup-based [7] data augmentation method is proposed to mix the training data of the source domain and the target domain; (2) a feature decoupling module is proposed to split the overall features into domain-dependent features and domain-independent features.
Unlike Meta-FDMixup, the present invention first identifies two major challenges of this problem: (1) the numbers of labeled images in the source data set and the auxiliary target data set are highly unbalanced; a model learned on such unbalanced training data will be biased towards the source data set and perform much worse on the target data set. (2) Since the source data set and the auxiliary target data set belong to two different domains, a single model may have difficulty learning knowledge from data sets of different domains simultaneously. Therefore, the invention provides a cross-domain small sample identification method based on multi-teacher knowledge distillation.
Disclosure of Invention
The invention aims to provide a cross-domain small sample identification method based on multi-teacher knowledge distillation, that is, a small sample classification model with strong generalization and migration ability for cross-domain small sample identification, a visual task with two difficulties: few labeled samples of the target classes are available, and there is a domain gap between the source domain and the target domain.
The cross-domain small sample identification method based on multi-teacher knowledge distillation provided by the invention involves three modules: a feature extraction network, a small sample classifier and a dynamic domain splitting model, and three network models: a source domain teacher model, a target domain teacher model and a domain-decomposable student model.
The specific steps of the method are as follows.
(1) Three modules are built: a feature extraction network, a small sample classifier and a dynamic domain splitting model;
(1.1) selecting, from the existing models, any deep neural network model capable of extracting high-dimensional image features as the feature extraction network model; in the present invention, the ResNet-10 structure [8] is used; using this feature extraction network, given source domain or target domain data, the corresponding source domain features F_S and target domain features F_T are extracted;
(1.2) selecting, from the existing models, any model capable of classifying the images in a small sample query set according to the images in the small sample support set as the small sample classifier; in the present invention, GNN [9] is used; using this small sample classifier, given any meta-learning task {S, Q}, the probability distribution P over Q is obtained;
(1.3) building a dynamic domain splitting model, whose main function is to dynamically split a specific layer of the network into a source-domain-related part and a target-domain-related part; specifically, a domain gate matrix M is defined, whose dimensionality is consistent with the number of convolution kernels to be split; correspondingly, the value M_i of the i-th element of M represents the probability that the i-th convolution kernel is assigned to the source domain, and 1-M_i represents the probability that this convolution kernel is assigned to the target domain; however, floating-point numbers between 0 and 1 do not meet the ideal expectation of the invention, which is to fully assign a given convolution kernel to either the source domain or the target domain; therefore, the invention further introduces Gumbel softmax [10] to binarize the floating-point M: when M_i outputs 1, the source domain channel is activated and the target domain channel is closed; conversely, when M_i outputs 0, the source domain channel is closed and the target domain channel is activated;
using the dynamic domain splitting model, given the source domain output F_S and the target domain output F_T of a certain layer of the network, the value of the matrix M determines the final source domain output as F_S · M and the target domain output as F_T · (1-M); the matrix M is set as a learnable parameter and is updated together with network training.
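As an illustration of the dynamic domain splitting model of step (1.3), a minimal PyTorch-style sketch follows; the class name DomainGate, the parameterization of M as two logits per kernel, and the temperature value tau are assumptions of this sketch rather than details fixed by the invention.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class DomainGate(nn.Module):
        """Dynamic domain splitting gate (sketch of step (1.3)).

        A learnable domain gate matrix M holds one entry per convolution kernel
        to be split; Gumbel softmax binarizes it so that M_i = 1 activates the
        source domain channel and M_i = 0 activates the target domain channel.
        """

        def __init__(self, num_kernels: int, tau: float = 1.0):
            super().__init__()
            # Two logits per kernel (source vs. target); an assumption about how
            # the scalar gate M_i is parameterized for Gumbel softmax.
            self.logits = nn.Parameter(torch.zeros(num_kernels, 2))
            self.tau = tau

        def forward(self, feat: torch.Tensor, domain: str) -> torch.Tensor:
            # Hard (binarized) gate with straight-through gradients, so M stays
            # learnable and is updated together with network training.
            gate = F.gumbel_softmax(self.logits, tau=self.tau, hard=True)  # (C, 2)
            m = gate[:, 0].view(1, -1, 1, 1)    # M_i: kernel assigned to the source domain
            if domain == "source":
                return feat * m                 # F_S * M
            return feat * (1.0 - m)             # F_T * (1 - M)

Here torch.nn.functional.gumbel_softmax with hard=True provides the binarization of M while keeping M learnable through straight-through gradients, which matches the requirement that M be updated together with network training.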
(2) Three network models are formed based on the three modules: a source domain teacher model, a target domain teacher model and a domain-decomposable student model;
(2.1) connecting the feature extraction network and the small sample classifier to form a source domain teacher model;
(2.2) connecting the feature extraction network and the small sample classifier to form a target domain teacher model;
(2.3) connecting the feature extraction network and the small sample classifier, and inserting the dynamic domain splitting model into specific layers of the feature extraction network, to form a domain-decomposable student model.
(3) Only source domain data is used to train the source domain teacher model. The training method is: randomly sample a meta-learning unit from the source domain data set as network input, pass it sequentially through the feature extraction network and the small sample classifier to obtain the model's predicted probability distribution for each picture category in the query set, and then obtain the training loss function from the distance between the prediction and the correct category.
(4) Only target domain data is used to train the target domain teacher model. The training method is: randomly sample a meta-learning unit from the target domain data set as network input, pass it sequentially through the feature extraction network and the small sample classifier to obtain the model's predicted probability distribution for each picture category in the query set, and then obtain the training loss function from the distance between the prediction and the correct category.
(5) The domain-decomposable student model is trained using data from the source domain and the target domain simultaneously. The training method is: sample one meta-learning unit from the source domain data set and one from the target domain data set, and pass each through the standard path (sequentially through the feature extraction network and the small sample classifier, without splitting the feature extraction network by domain, i.e. all convolution kernels can be activated no matter which domain the data comes from) and the domain path (sequentially through the feature extraction network, the dynamic domain splitting model and the small sample classifier, i.e. a convolution kernel is activated for data of the current domain only when its corresponding domain gate M_i outputs 1) to obtain the respective probability predictions; a sketch of this two-path student model is given after subtask (5.2). Two subtasks are then performed:
(5.1) small sample meta-classification task: the training loss function is obtained from the distance between the prediction and the correct class;
(5.2) knowledge distillation task: the prediction probability distribution of the student model is compared with the probability distribution of the corresponding teacher model to obtain the training loss.
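For concreteness, the sketch below shows one possible wiring of the standard path and the domain path of step (5) in a single student module; the class and attribute names are assumptions, DomainGate refers to the sketch after step (1.3), and for brevity the gate is applied to the backbone output rather than inside the detachable layers.

    import torch.nn as nn

    class DomainDecomposableStudent(nn.Module):
        """Sketch of the domain-decomposable student model of step (5): a shared
        feature extractor plus a small sample classifier, with a DomainGate
        applied to the domain-detachable features."""

        def __init__(self, backbone: nn.Module, gate: nn.Module, classifier: nn.Module):
            super().__init__()
            self.backbone = backbone        # e.g. the ResNet-10 feature extraction network
            self.gate = gate                # DomainGate over the detachable kernels
            self.classifier = classifier    # e.g. the GNN small sample classifier

        def forward(self, support, query, domain=None):
            feat_s, feat_q = self.backbone(support), self.backbone(query)
            if domain is not None:          # domain path: keep only kernels assigned to `domain`
                feat_s, feat_q = self.gate(feat_s, domain), self.gate(feat_q, domain)
            # standard path (domain=None): no gating, all convolution kernels stay active
            return self.classifier(feat_s, feat_q)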
(6) On unknown-category test data of the target domain, the domain-decomposable student model is tested with small sample meta-classification tasks: the data is passed through the standard path and the domain path to obtain two different probability predictions, the two prediction probabilities are averaged as the final prediction result, and the category with the highest score in this probability distribution is the predicted category. This step is repeated several times (e.g., 1000 times) to obtain the final model accuracy.
In step (5), for the small sample meta-classification learning task, a meta-learning unit {support set, query set} is first randomly sampled from a data set as network input and passed sequentially through the feature extraction network, the dynamic domain splitting model (if present) and the small sample classifier to obtain the model's predicted probability distribution for each picture category in the query set. The predicted probability distribution is compared with the correct query set categories to obtain the training loss. In the present invention, the cross entropy loss is used to calculate the small sample meta-classification loss.
In step (5), for the knowledge distillation task, given data of a specific domain input into the student model, e.g. source domain data, the prediction probability distribution of the student model is compared with the probability distribution produced for the same data by the source domain teacher model to obtain the training loss. In the present invention, the divergence distance is used to calculate the knowledge distillation loss.
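A hedged sketch of the two losses named in steps (5.1) and (5.2) follows: cross entropy for the small sample meta-classification loss and a KL-divergence distance for the knowledge distillation loss; the function names and the temperature parameter T are assumptions of this sketch.

    import torch
    import torch.nn.functional as F

    def meta_classification_loss(student_logits: torch.Tensor,
                                 query_labels: torch.Tensor) -> torch.Tensor:
        """Small sample meta-classification loss (step 5.1): cross entropy between
        the student's query-set predictions and the correct categories."""
        return F.cross_entropy(student_logits, query_labels)

    def knowledge_distillation_loss(student_logits: torch.Tensor,
                                    teacher_logits: torch.Tensor,
                                    T: float = 1.0) -> torch.Tensor:
        """Knowledge distillation loss (step 5.2): divergence distance between the
        student's and the matching teacher's probability distributions; KL
        divergence with an optional temperature T is used here as one concrete
        choice of divergence."""
        p_teacher = F.softmax(teacher_logits / T, dim=-1)
        log_p_student = F.log_softmax(student_logits / T, dim=-1)
        return F.kl_div(log_p_student, p_teacher, reduction="batchmean") * (T * T)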
In summary, the innovations of the invention are as follows:
1. The method introduces, for the first time, a multi-teacher mechanism and knowledge distillation into the cross-domain small sample identification task with a small amount of labeled target data, which prevents the model from being learned directly on a data set whose sample labels are extremely unbalanced;
2. The invention proposes a novel dynamic domain splitting module, so that the network automatically learns to decompose the model structure into a source-domain-specific part and a target-domain-specific part; such domain-based structural decomposition has rarely been explored in previous cross-domain small sample work.
Drawings
FIG. 1 is a schematic diagram of the cross-domain small sample task difficulties addressed by the present invention and the corresponding solutions.
FIG. 2 is a schematic diagram of the main steps of the present invention, in which: (a) a teacher model is trained separately for the source domain and the target domain; (b) a domain-decomposable student network is trained by distilling knowledge from the source domain teacher model and the target domain teacher model; (c) in the inference phase, only the student network is used for prediction.
FIG. 3 is a schematic diagram of the dynamic domain splitting module of the present invention. The module learns the domain gate matrix M to control the activation state of the convolution filters; Gumbel softmax is used to binarize the matrix values.
Detailed Description
The cross-domain small sample task difficulties handled by the present invention and its solutions are illustrated in FIG. 1. Given a source domain data set with a sufficient number of samples and a target domain data set with only a small number of samples, the task has the following difficulties: (1) there is a serious data imbalance between the two training data sets; (2) the model needs to learn from different domains simultaneously. Accordingly, the present invention provides the following key solutions: (1) two independent teacher models are trained, and their knowledge is transferred to a student model through knowledge distillation; (2) a new dynamic domain splitting module is proposed, which learns to decompose the network structure of the student model into two domain-related sub-parts.
The invention provides a cross-domain small sample identification method based on multi-teacher knowledge distillation, which comprises the following detailed steps.
Step 1: set up the basic network models: ResNet-10 is used as the feature extractor, GNN is used as the small sample classifier, and the domain gate matrix M of the dynamic domain splitting model is set as a learnable parameter.
Step 2: form the source domain teacher model St-Net and the target domain teacher model Tt-Net from ResNet-10 and GNN.
Step 3: randomly select a meta-learning task {S, Q, y} from the source domain data set to train the source domain teacher model St-Net, where S denotes the support set, Q denotes the query set, and y denotes the correct sample labels corresponding to the query set. Specifically, {S, Q} is input sequentially into ResNet-10 and GNN to obtain the final small sample prediction probability P, and the cross entropy loss between P and y is calculated as the metric loss function of this training round to optimize the model.
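The sketch below illustrates one way to sample a meta-learning task {S, Q, y} and run one training round of a teacher model as described in step 3; the episode layout (15 query images per class), the argument names and the callable interface teacher(S, Q) are assumptions for illustration only.

    import random
    import torch
    import torch.nn.functional as F

    def sample_episode(images_by_class, n_way=5, k_shot=1, q_query=15):
        """Sample one meta-learning task {S, Q, y} (sketch of step 3); the mapping
        `images_by_class` (class id -> list of image tensors) and the query size
        are illustrative assumptions."""
        classes = random.sample(list(images_by_class.keys()), n_way)
        support, query, labels = [], [], []
        for new_label, c in enumerate(classes):
            imgs = random.sample(images_by_class[c], k_shot + q_query)
            support.extend(imgs[:k_shot])
            query.extend(imgs[k_shot:])
            labels.extend([new_label] * q_query)
        return torch.stack(support), torch.stack(query), torch.tensor(labels)

    def teacher_train_round(teacher, optimizer, images_by_class):
        """One training round of a teacher model: {S, Q} is passed through
        ResNet-10 + GNN (wrapped by `teacher`) and the cross entropy between the
        predicted query probabilities and y is optimized."""
        S, Q, y = sample_episode(images_by_class)
        logits = teacher(S, Q)              # predicted class scores for the query set
        loss = F.cross_entropy(logits, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return loss.item()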
Step 4: repeat step 3 for num_train rounds to obtain the final source domain teacher model St-Net, where num_train = 400.
Step 5: randomly select meta-learning tasks {S, Q, y} from the target domain data set to train the target domain teacher model Tt-Net; the training method is the same as in steps 3 and 4, and the final target domain teacher model Tt-Net is obtained.
Step 6: divide ResNet-10 into 4 structural blocks and set the last two blocks as domain-detachable blocks; specifically, count the number of convolution kernels in these two blocks and set the dimensionality of the domain gate matrix M according to this number. On the basis of the ResNet-10 with the learnable domain gate M, the GNN small sample classifier is added to form the domain-decomposable student model ME-D2N.
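A small sketch of how the dimensionality of the domain gate matrix M could be derived from the convolution kernels of the two domain-detachable blocks follows; the function name and the way the blocks are passed in are assumptions of this sketch.

    import torch.nn as nn

    def count_conv_kernels(detachable_blocks) -> int:
        """Count the convolution kernels (output channels) of the domain-detachable
        blocks; this number fixes the dimensionality of the domain gate matrix M
        (sketch of step 6). `detachable_blocks` is assumed to be an iterable of
        the last two ResNet-10 structural blocks."""
        total = 0
        for block in detachable_blocks:
            for module in block.modules():
                if isinstance(module, nn.Conv2d):
                    total += module.out_channels
        return total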
Step 7: randomly sample a meta-learning task {S_src, Q_src, y_src} from the source domain data set and a meta-learning task {S_tgt, Q_tgt, y_tgt} from the target domain data set for training the domain-decomposable student model ME-D2N, where the subscripts src/tgt indicate that the data of the meta-learning task is sampled from the source/target domain. The specific steps are as follows:
Step 7.1: the source domain meta-learning task {S_src, Q_src} is input into the source domain teacher model to obtain the teacher model's prediction output P_src^tea, and the target domain meta-learning task {S_tgt, Q_tgt} is input into the target domain teacher model to obtain the prediction output P_tgt^tea; here P denotes a prediction result, and the superscript tea indicates that the prediction comes from a teacher model.
Step 7.2: the probability prediction distributions are obtained through the standard path (STD). Taking the source domain meta-learning task {S_src, Q_src, y_src} as an example, the standard path directly inputs the support set and the query set into ResNet-10 and GNN in turn, without considering the domain gate M, and obtains the probability distribution P_src^std corresponding to the standard path; in the same way, the probability distribution P_tgt^std corresponding to the standard path is obtained for the target domain meta-learning task {S_tgt, Q_tgt, y_tgt}.
Step 7.3: the probability prediction distributions are obtained through the domain path (DSG). Taking the source domain meta-learning task {S_src, Q_src, y_src} as an example, the domain path is implemented as follows: the activation relation of each convolution kernel to source domain and target domain data is controlled by the output of the domain gate M. Specifically, Gumbel softmax is first applied to the values of the domain gate M to binarize them; the original source domain features obtained through the plain ResNet path are denoted F_src, the value of the matrix M determines the final source domain output as F_src · M, and this feature is input into GNN to obtain the probability distribution P_src^dsg corresponding to the domain path. In the same way, for the target domain meta-learning task {S_tgt, Q_tgt, y_tgt}, the original target domain features F_tgt are obtained first, the value of the matrix M determines the final target domain output as F_tgt · (1-M), and this feature is input into GNN to obtain the probability distribution P_tgt^dsg corresponding to the domain path.
Here, the symbol F denotes visual features.
Step 7.4: from the probability distribution P_src^std of the standard path and the probability distribution P_src^dsg of the domain path for the source domain, the small sample meta-learning task loss L_src^fsl and the knowledge distillation loss L_src^kd are calculated, where the symbol L denotes a loss function. The specific calculation is as follows:
L_src^fsl = k · CE(P_src^std, y_src) + (1 - k) · CE(P_src^dsg, y_src),   (1)
L_src^kd = k · KD(P_src^std, P_src^tea) + (1 - k) · KD(P_src^dsg, P_src^tea),   (2)
where CE(·) denotes the cross entropy loss function, KD(·) denotes the divergence distance, and k denotes a hyper-parameter, specifically k = 0.2.
Step 7.5: from the probability distribution P_tgt^std of the standard path and the probability distribution P_tgt^dsg of the domain path for the target domain, the small sample meta-learning task loss L_tgt^fsl and the knowledge distillation loss L_tgt^kd are calculated in the same way as in step 7.4; the specific formulas are:
L_tgt^fsl = k · CE(P_tgt^std, y_tgt) + (1 - k) · CE(P_tgt^dsg, y_tgt),   (3)
L_tgt^kd = k · KD(P_tgt^std, P_tgt^tea) + (1 - k) · KD(P_tgt^dsg, P_tgt^tea).   (4)
Step 7.6: the final loss function L of one training round is calculated as follows:
L_src = k_1 · L_src^fsl + (1 - k_1) · L_src^kd,   (5)
L_tgt = k_1 · L_tgt^fsl + (1 - k_1) · L_tgt^kd,   (6)
L = k_2 · L_src + (1 - k_2) · L_tgt,   (7)
where k_1 and k_2 denote hyper-parameters, both set to 0.2.
Step 7.7: train the model with the loss L, and repeat step 7 for num_train rounds, with num_train equal to 400.
Step 8: the domain-decomposable ME-D2N student model is tested on unknown-category test data of the target domain with small sample meta-classification tasks {S_test, Q_test}. The data is passed through the standard path and the domain path (as in steps 7.2 and 7.3) to obtain two different probability predictions P^std and P^dsg, the two prediction probabilities are averaged as the final prediction result P = (P^std + P^dsg) / 2, and the category with the highest score in this probability distribution is the predicted category. This step is repeated 1000 times to obtain the final model accuracy.
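A sketch of the test procedure of step 8 follows; the student call signature (a domain keyword selecting the domain path), the episode sampler and the choice of the target gate at test time are assumptions consistent with the earlier sketches, not the patent's exact implementation.

    import torch

    @torch.no_grad()
    def evaluate(student, sample_test_episode, num_episodes=1000):
        """Sketch of step 8: each test episode is passed through both the standard
        path and the domain path, the two softmax outputs are averaged, and the
        arg-max category is taken as the prediction; accuracy is averaged over
        `num_episodes` episodes."""
        accuracies = []
        for _ in range(num_episodes):
            S, Q, y = sample_test_episode()
            p_std = torch.softmax(student(S, Q), dim=-1)                   # standard path
            p_dsg = torch.softmax(student(S, Q, domain="target"), dim=-1)  # domain path (test data is target domain)
            p = 0.5 * (p_std + p_dsg)                                      # final prediction distribution
            accuracies.append((p.argmax(dim=-1) == y).float().mean().item())
        return 100.0 * sum(accuracies) / len(accuracies)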
In this example, the model accuracy is shown in Table 1, where Cub, Cars, Places and Plantae denote unknown-category test data of four different target domains. 5-way 1-shot and 5-way 5-shot denote the small sample meta-learning task settings: 5-way means that the support set S of each meta-learning task contains 5 different categories, and 1-shot / 5-shot mean that each category in S has 1 or 5 support samples, respectively.
TABLE 1 Model accuracy (%)
               Cub          Cars         Places       Plantae      Average
5-way 1-shot   65.05±0.83   49.53±0.79   60.36±0.80   52.89±0.83   56.96
5-way 5-shot   83.17±0.56   69.17±0.68   80.45±0.62   72.87±0.67   76.42
References
[1] Wei-Yu Chen, Yen-Cheng Liu, Zsolt Kira, Yu-Chiang Frank Wang, and Jia-Bin Huang. 2019. A closer look at few-shot classification. arXiv preprint (2019).
[2] Yuqian Fu, Yanwei Fu, and Yu-Gang Jiang. 2021. Meta-FDMixup: Cross-Domain Few-Shot Learning Guided by Labeled Target Data. In ACM Multimedia. 5326–5334.
[3] Yuqian Fu, Yu Xie, Yanwei Fu, Jingjing Chen, and Yu-Gang Jiang. 2022. Wave-SAN: Wavelet based Style Augmentation Network for Cross-Domain Few-Shot Learning. arXiv preprint (2022).
[4] Cheng Perng Phoo and Bharath Hariharan. 2020. Self-training for Few-shot Transfer Across Extreme Task Differences. arXiv preprint (2020).
[5] Jiamei Sun, Sebastian Lapuschkin, Wojciech Samek, Yunqing Zhao, Ngai-Man Cheung, and Alexander Binder. 2020. Explanation-guided training for cross-domain few-shot classification. arXiv preprint (2020).
[6] Hung-Yu Tseng, Hsin-Ying Lee, Jia-Bin Huang, and Ming-Hsuan Yang. 2020. Cross-domain few-shot classification via learned feature-wise transformation. In ICLR.
[7] Hongyi Zhang, Moustapha Cisse, Yann N. Dauphin, and David Lopez-Paz. 2017. mixup: Beyond empirical risk minimization. arXiv preprint (2017).
[8] Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. 2016. Deep residual learning for image recognition. In CVPR.
[9] Victor Garcia and Joan Bruna. 2017. Few-shot learning with graph neural networks. arXiv preprint (2017).
[10] Eric Jang, Shixiang Gu, and Ben Poole. 2016. Categorical reparameterization with gumbel-softmax. arXiv preprint (2016).

Claims (6)

1. A cross-domain small sample identification method based on multi-teacher knowledge distillation is characterized by comprising the following specific steps:
(1) Three modules are built: a feature extraction network, a small sample classifier and a dynamic domain splitting model;
(1.1) adopting ResNet-10 as the feature extraction network model; using the feature extraction network model, given source domain or target domain data, the corresponding source domain features F_S and target domain features F_T are extracted;
(1.2) adopting GNN as the small sample classifier; using the small sample classifier, given any meta-learning task {S, Q}, the probability distribution P over Q is obtained;
(1.3) building a dynamic domain splitting model, which is mainly used to dynamically split a specific layer of the network into a source-domain-related part and a target-domain-related part; specifically, a domain gate matrix M is defined, whose dimensionality is consistent with the number of convolution kernels to be split; correspondingly, the value M_i of the i-th element of M represents the probability that the i-th convolution kernel is assigned to the source domain, and 1-M_i represents the probability that this convolution kernel is assigned to the target domain; Gumbel softmax is further introduced to binarize the floating-point M: when M_i outputs 1, the source domain channel is activated and the target domain channel is closed; conversely, when M_i outputs 0, the source domain channel is closed and the target domain channel is activated;
using the dynamic domain splitting model, given the source domain output F_S and the target domain output F_T of a certain layer of the network, the value of the matrix M determines the final source domain output as F_S · M and the target domain output as F_T · (1-M); the matrix M is set as a learnable parameter and is updated together with network training;
(2) Three network models are formed based on the three modules: a source domain teacher model, a target domain teacher model and a domain-decomposable student model;
(2.1) connecting the feature extraction network and the small sample classifier to form a source domain teacher model;
(2.2) connecting the feature extraction network and the small sample classifier to form a target domain teacher model;
(2.3) connecting the feature extraction network and the small sample classifier, and inserting the dynamic domain splitting model into specific layers of the feature extraction network, to form a domain-decomposable student model;
(3) Only source domain data is used to train the source domain teacher model, the training method being: randomly sampling a meta-learning unit from the source domain data set as network input, passing it sequentially through the feature extraction network and the small sample classifier to obtain the model's predicted probability distribution for each picture category in the query set, and then obtaining the training loss function from the distance between the prediction and the correct category;
(4) Only target domain data is used to train the target domain teacher model, the training method being: randomly sampling a meta-learning unit from the target domain data set as network input, passing it sequentially through the feature extraction network and the small sample classifier to obtain the model's predicted probability distribution for each picture category in the query set, and then obtaining the training loss function from the distance between the prediction and the correct category;
(5) The domain-decomposable student model is trained using data from the source domain and the target domain simultaneously, the training method being: sampling one meta-learning unit from the source domain data set and one from the target domain data set, and passing each of the two units through a standard path and a domain path to obtain the respective probability predictions; wherein:
the standard path passes sequentially through the feature extraction network and the small sample classifier, without splitting the feature extraction network by domain, i.e. all convolution kernels can be activated no matter which domain the data comes from;
the domain path passes sequentially through the feature extraction network, the dynamic domain splitting model and the small sample classifier, splitting the feature extraction network by domain, i.e. a convolution kernel is activated for data of the current domain only when its corresponding domain gate M_i outputs 1;
two subtasks are then performed:
(5.1) small sample meta-classification task: the training loss function is obtained from the distance between the prediction and the correct category;
(5.2) knowledge distillation task: comparing the prediction probability distribution of the student model with the probability distribution of the corresponding teacher model to obtain training loss;
(6) The domain-decomposable student model is tested on unknown-category test data of the target domain with small sample meta-classification tasks: the data is passed through the standard path and the domain path to obtain two different probability predictions, the two prediction probabilities are averaged as the final prediction result, and the category with the highest score in this probability distribution is the predicted category; this step is repeated several times to obtain the final model accuracy.
2. The cross-domain small sample identification method based on multi-teacher knowledge distillation according to claim 1, wherein in step (5), for the small sample meta-classification learning task, a meta-learning unit {support set, query set} is first randomly sampled from a data set as network input and passed sequentially through the feature extraction network, the dynamic domain splitting model and the small sample classifier to obtain the model's predicted probability distribution for each picture category in the query set; the predicted probability distribution is compared with the correct query set categories to obtain the training loss; the cross entropy loss is used to calculate the small sample meta-classification loss.
3. The cross-domain small sample identification method based on multi-teacher knowledge distillation according to claim 1, wherein in step (5), for the knowledge distillation task, given data of a specific domain input into the student model, such as source domain data, the predicted probability distribution of the student model is compared with the probability distribution produced for the same data by the source domain teacher model to obtain the training loss; the divergence distance is used to calculate the knowledge distillation loss.
4. The cross-domain small sample identification method based on multi-teacher knowledge distillation according to claim 1, wherein the feature extraction network model ResNet-10 is divided into 4 structural blocks, and the last two structural blocks are set as domain-detachable blocks; specifically, the number of convolution kernels in these two blocks is counted, and the dimensionality of the domain gate matrix M is set according to this number; on the basis of the ResNet-10 with the learnable domain gate M, the GNN small sample classifier is added to form the domain-decomposable student model ME-D2N.
5. The cross-domain small sample identification method based on multi-teacher knowledge distillation according to claim 4, wherein a meta-learning task {S_src, Q_src, y_src} is randomly sampled from the source domain data set and a meta-learning task {S_tgt, Q_tgt, y_tgt} is sampled from the target domain data set for training the domain-decomposable student model ME-D2N; the specific steps are as follows:
(1) The source domain meta-learning task {S_src, Q_src} is input into the source domain teacher model to obtain the teacher model's prediction output P_src^tea, and the target domain meta-learning task {S_tgt, Q_tgt} is input into the target domain teacher model to obtain the prediction output P_tgt^tea; wherein P represents a prediction result, the superscript tea indicates that the prediction comes from a teacher model, and the subscripts src and tgt respectively indicate that the data of the meta-learning task is sampled from the source domain and the target domain;
(2) The probability prediction distributions are obtained through the standard path; for the source domain meta-learning task {S_src, Q_src, y_src}, the standard path directly inputs the support set and the query set into ResNet-10 and GNN in turn, without considering the domain gate M, and obtains the probability distribution P_src^std corresponding to the standard path; in the same way, the probability distribution P_tgt^std corresponding to the standard path is obtained for the target domain meta-learning task {S_tgt, Q_tgt, y_tgt};
(3) The probability prediction distributions are obtained through the domain path; for the source domain meta-learning task {S_src, Q_src, y_src}, the domain path is implemented as follows: the activation relation of each convolution kernel to source domain data and target domain data is controlled by the output of the domain gate M; specifically, Gumbel softmax is applied to the values of the domain gate M to binarize them; the original source domain features obtained through the plain ResNet path are denoted F_src, the value of the matrix M determines the final source domain output as F_src · M, and this feature is input into GNN to obtain the probability distribution P_src^dsg corresponding to the domain path; in the same way, for the target domain meta-learning task {S_tgt, Q_tgt, y_tgt}, the original target domain features F_tgt are obtained first, the value of the matrix M determines the final target domain output as F_tgt · (1-M), and this feature is input into GNN to obtain the probability distribution P_tgt^dsg corresponding to the domain path; wherein the symbol F represents visual features;
(4) From the probability distribution P_src^std of the standard path and the probability distribution P_src^dsg of the domain path for the source domain, the small sample meta-learning task loss L_src^fsl and the knowledge distillation loss L_src^kd are calculated, where the symbol L denotes a loss function; the specific calculation is as follows:
L_src^fsl = k · CE(P_src^std, y_src) + (1 - k) · CE(P_src^dsg, y_src),   (1)
L_src^kd = k · KD(P_src^std, P_src^tea) + (1 - k) · KD(P_src^dsg, P_src^tea),   (2)
wherein CE(·) represents the cross entropy loss function, KD(·) represents the divergence distance, and k represents a hyper-parameter;
(5) From the probability distribution P_tgt^std of the standard path and the probability distribution P_tgt^dsg of the domain path for the target domain, the small sample meta-learning task loss L_tgt^fsl and the knowledge distillation loss L_tgt^kd are calculated; the specific formulas are as follows:
L_tgt^fsl = k · CE(P_tgt^std, y_tgt) + (1 - k) · CE(P_tgt^dsg, y_tgt),   (3)
L_tgt^kd = k · KD(P_tgt^std, P_tgt^tea) + (1 - k) · KD(P_tgt^dsg, P_tgt^tea);   (4)
(6) The final loss function L of one training round is calculated as follows:
L_src = k_1 · L_src^fsl + (1 - k_1) · L_src^kd,   (5)
L_tgt = k_1 · L_tgt^fsl + (1 - k_1) · L_tgt^kd,   (6)
L = k_2 · L_src + (1 - k_2) · L_tgt,   (7)
wherein k_1 and k_2 represent hyper-parameters;
(7) The model is trained with the loss L, and the above process is repeated for a total of num_train rounds.
6. The cross-domain small sample identification method based on multi-teacher knowledge distillation according to claim 5, wherein the domain-decomposable ME-D2N student model is tested on unknown-category test data of the target domain with small sample meta-classification tasks {S_test, Q_test}: the data is passed through the standard path and the domain path to obtain two different probability predictions P^std and P^dsg, the two prediction probabilities are averaged as the final prediction result P = (P^std + P^dsg) / 2, and the category with the highest score in this probability distribution is the predicted category.
CN202211001654.7A 2022-08-19 2022-08-19 Cross-domain small sample identification method based on multi-teacher knowledge distillation Pending CN115423000A (en)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202211001654.7A CN115423000A (en) 2022-08-19 2022-08-19 Cross-domain small sample identification method based on multi-teacher knowledge distillation

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202211001654.7A CN115423000A (en) 2022-08-19 2022-08-19 Cross-domain small sample identification method based on multi-teacher knowledge distillation

Publications (1)

Publication Number Publication Date
CN115423000A true CN115423000A (en) 2022-12-02

Family

ID=84197760

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202211001654.7A Pending CN115423000A (en) 2022-08-19 2022-08-19 Cross-domain small sample identification method based on multi-teacher knowledge distillation

Country Status (1)

Country Link
CN (1) CN115423000A (en)

Cited By (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN116091895A (en) * 2023-04-04 2023-05-09 之江实验室 Model training method and device oriented to multitask knowledge fusion
CN116091895B (en) * 2023-04-04 2023-07-11 之江实验室 Model training method and device oriented to multitask knowledge fusion
CN116758391A (en) * 2023-04-21 2023-09-15 大连理工大学 Multi-domain remote sensing target generalization identification method for noise suppression distillation
CN116758391B (en) * 2023-04-21 2023-11-21 大连理工大学 Multi-domain remote sensing target generalization identification method for noise suppression distillation
CN118092403A (en) * 2024-04-23 2024-05-28 广汽埃安新能源汽车股份有限公司 Electric control detection model training method, electric control system detection method and device

Similar Documents

Publication Publication Date Title
CN115423000A (en) Cross-domain small sample identification method based on multi-teacher knowledge distillation
Morerio et al. Minimal-entropy correlation alignment for unsupervised deep domain adaptation
Guo et al. Deep clustering with convolutional autoencoders
Zhu et al. Iterative Entity Alignment via Joint Knowledge Embeddings.
Jia et al. Label distribution learning by exploiting label correlations
Xu et al. Weighted multi-view clustering with feature selection
Yang et al. Transfer learning for sequence tagging with hierarchical recurrent networks
Wen et al. A discriminative feature learning approach for deep face recognition
Zhang et al. Balanced knowledge distillation for long-tailed learning
Bai et al. Me-momentum: Extracting hard confident examples from noisily labeled data
Bacciu et al. Edge-based sequential graph generation with recurrent neural networks
Xu et al. Constructing balance from imbalance for long-tailed image recognition
CN111222318A (en) Trigger word recognition method based on two-channel bidirectional LSTM-CRF network
CN112308211A (en) Domain increment method based on meta-learning
CN111259938B (en) Manifold learning and gradient lifting model-based image multi-label classification method
CN114863175A (en) Unsupervised multi-source partial domain adaptive image classification method
Peng et al. Swin transformer-based supervised hashing
Wang et al. Exploiting hierarchical structures for unsupervised feature selection
Yang et al. Boosting the adversarial transferability of surrogate models with dark knowledge
He et al. Multilabel classification by exploiting data‐driven pair‐wise label dependence
Wu et al. Boundaryface: A mining framework with noise label self-correction for face recognition
Wang et al. Improved local-feature-based few-shot learning with Sinkhorn metrics
Li et al. DENA: display name embedding method for Chinese social network alignment
Andleeb et al. ESIDE: A computationally intelligent method to identify earthworm species (E. fetida) from digital images: Application in taxonomy
Xu et al. Training classifiers that are universally robust to all label noise levels

Legal Events

Date Code Title Description
PB01 Publication
SE01 Entry into force of request for substantive examination