CN116935160A - A training method, sample classification method, electronic device and medium - Google Patents


Info

Publication number
CN116935160A
Authority
CN
China
Prior art keywords
category
channel
data samples
sample
query set
Legal status
Granted
Application number
CN202310892144.1A
Other languages
Chinese (zh)
Other versions
CN116935160B (en)
Inventor
黄维然
郑凯鹏
Current Assignee
Shanghai Jiao Tong University
Original Assignee
Shanghai Jiao Tong University
Application filed by Shanghai Jiao Tong University
Priority to CN202310892144.1A
Publication of CN116935160A
Application granted
Publication of CN116935160B
Current legal status: Active


Classifications

    • G PHYSICS
      • G06 COMPUTING; CALCULATING OR COUNTING
        • G06V IMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
          • G06V10/00 Arrangements for image or video recognition or understanding
            • G06V10/70 Arrangements for image or video recognition or understanding using pattern recognition or machine learning
              • G06V10/77 Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
                • G06V10/774 Generating sets of training patterns; Bootstrap methods, e.g. bagging or boosting
              • G06V10/74 Image or video pattern matching; Proximity measures in feature spaces
              • G06V10/764 Arrangements for image or video recognition or understanding using pattern recognition or machine learning using classification, e.g. of video objects
              • G06V10/82 Arrangements for image or video recognition or understanding using pattern recognition or machine learning using neural networks
        • G06N COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
          • G06N3/00 Computing arrangements based on biological models
            • G06N3/02 Neural networks
              • G06N3/08 Learning methods

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Evolutionary Computation (AREA)
  • Physics & Mathematics (AREA)
  • General Physics & Mathematics (AREA)
  • Artificial Intelligence (AREA)
  • Computing Systems (AREA)
  • General Health & Medical Sciences (AREA)
  • Health & Medical Sciences (AREA)
  • Software Systems (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Multimedia (AREA)
  • Medical Informatics (AREA)
  • Databases & Information Systems (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Computational Linguistics (AREA)
  • Data Mining & Analysis (AREA)
  • Molecular Biology (AREA)
  • General Engineering & Computer Science (AREA)
  • Mathematical Physics (AREA)
  • Information Retrieval, Db Structures And Fs Structures Therefor (AREA)
  • Image Analysis (AREA)

Abstract

The application provides a training method, a sample classification method, an electronic device and a medium. The training method includes: acquiring training data and using a deep learning model to extract the features of the data samples in a support set and a query set; processing the features of the support-set data samples to obtain a prototype vector for each class in the support set and the differences between the channels of each prototype vector; processing the features of the query-set data samples to obtain the differences between the channels of each query-sample feature; using the inter-channel differences of the prototype vectors and of the query-sample features to obtain a differentiable Kendall correlation coefficient between each query-sample feature and the prototype vector of each class; calculating a loss function from the differentiable Kendall correlation coefficients; and training the deep learning model on the training data with the loss function. The training method can yield a better-performing deep learning model.

Description

A training method, sample classification method, electronic device and medium

Technical field

This application belongs to the field of machine learning and relates to training methods, in particular to a training method, a sample classification method, an electronic device and a medium.

Background

Deep learning is an important technology in artificial intelligence that learns from and processes data by mimicking the neural networks of the human brain. Deep learning models have succeeded in many fields; however, when facing a new task, a large amount of data must be collected to train the model before it performs well. Model training is a key step in deep learning: training on large amounts of data improves the model's accuracy and performance, and gradient-descent-based optimization algorithms are the main tools used.

Deep learning model training still faces many challenges. For example, imbalanced training data can cause the model to overfit or underfit; hyperparameters are hard to choose, which affects model performance; and long training times and heavy consumption of computing resources also constrain the development of model training. There is therefore still a need for a training method that yields better-performing deep learning models.

Summary of the invention

In view of the above shortcomings of the prior art, the purpose of this application is to provide a training method, a sample classification method, an electronic device and a medium that address the lack, in the prior art, of a training method yielding a better-performing deep learning model.

In a first aspect, this application provides a training method. The training method includes: obtaining training data, the training data including a support set and a query set, each of which includes multiple images; using a deep learning model to obtain the features of the data samples in the support set and the query set; processing the features of the support-set data samples to obtain a prototype vector for each class of data samples in the support set and the differences between the channels of each prototype vector; processing the features of the query-set data samples to obtain the differences between the channels of each query-sample feature; using the inter-channel differences of the prototype vector of each class and the inter-channel differences of the query-sample features to obtain the differentiable Kendall correlation coefficient between the query-sample features and the prototype vector of each class; calculating a loss function from the differentiable Kendall correlation coefficient; and training the deep learning model on the training data using the loss function.

In this application, the differentiable Kendall correlation coefficient is obtained from the inter-channel differences of the prototype vectors and of the query-sample features, and the deep learning model is then trained on the data samples, with its parameters adjusted using a loss function computed from the differentiable Kendall correlation coefficient. This training method establishes an association between the channels of the prototype vectors and the channels of the query-sample features and thereby yields a better-performing deep learning model.

In an implementation of the first aspect, processing the features of the support-set data samples to obtain the prototype vector of each class and the differences between the channels of the prototype vector includes: obtaining the prototype vector of each class of data samples in the support set from the features of the support-set data samples; obtaining the value of each channel of the prototype vector; and obtaining the differences between the channels of the prototype vector from these channel values.

In an implementation of the first aspect, processing the features of the query-set data samples to obtain the differences between the channels of the query-sample features includes: obtaining the value of each channel of each query-sample feature; and obtaining the differences between the channels of each query-sample feature from these channel values.
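The channel-difference computation in the two implementations above can be sketched in plain Python as follows, assuming each feature or prototype vector is a list of floats. The function name `channel_differences` is hypothetical, and indexing channel pairs as (i, j) with i < j is our own convention; the application does not fix one. Applying the same routine to a prototype vector and to a query-sample feature gives differences indexed by the same channel pairs, which is what allows them to be compared pair by pair.

```python
from itertools import combinations


def channel_differences(vec):
    """Return {(i, j): vec[i] - vec[j]} for every channel pair with i < j.

    The same routine serves prototype vectors (support set) and query-sample
    features, so both dictionaries share the same channel-pair keys.
    """
    return {(i, j): vec[i] - vec[j] for i, j in combinations(range(len(vec)), 2)}


if __name__ == "__main__":
    prototype = [0.5, 0.2, 0.9]  # hypothetical 3-channel prototype vector
    print(channel_differences(prototype))
```

For n channels the dictionary has n(n − 1)/2 entries, so the cost grows quadratically with the feature dimension.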

In an implementation of the first aspect, obtaining the prototype vector of each class in the support set includes: taking the mean of the features of the support-set data samples of each class as the prototype vector of that class.
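A minimal sketch of the class-mean prototype, assuming features arrive as lists of floats with one label per support sample; `class_prototypes` is a hypothetical name, not an interface from the application.

```python
from collections import defaultdict


def class_prototypes(features, labels):
    """Average the support-set features of each class, channel by channel.

    features: list of equal-length feature vectors (lists of floats)
    labels:   class label of each feature vector
    returns:  {label: prototype vector}
    """
    sums = {}
    counts = defaultdict(int)
    for feat, label in zip(features, labels):
        if label not in sums:
            sums[label] = [0.0] * len(feat)
        for c, value in enumerate(feat):
            sums[label][c] += value
        counts[label] += 1
    return {label: [s / counts[label] for s in total] for label, total in sums.items()}
```

In an N-way, K-shot support set every class contributes K features, so each prototype is the mean of K vectors.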

In an implementation of the first aspect, the loss function is calculated by: obtaining, from the differentiable Kendall correlation coefficients between a query-sample feature and the prototype vector of each class, the probability that the query sample belongs to each class; and obtaining the loss function from these probabilities and the number of data samples in the query set.
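The application states only that class probabilities are derived from the correlation coefficients and that the loss is obtained from those probabilities and the number of query samples. The sketch below assumes, as is common in prototype-based few-shot learning, a softmax over the coefficients scaled by a constant t, and a loss equal to the mean negative log-probability of each query sample's true class. `class_probabilities` and `query_loss` are hypothetical names.

```python
import math


def class_probabilities(coefficients, t):
    """Softmax over temperature-scaled correlation coefficients (one per class)."""
    logits = [t * c for c in coefficients]
    m = max(logits)  # subtract the max before exponentiating, for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def query_loss(coefficient_rows, true_classes, t):
    """Mean negative log-probability of the true class over the |Q| query samples.

    coefficient_rows[q][j]: correlation between query sample q and prototype j
    true_classes[q]:        index of the class that query sample q belongs to
    """
    total = 0.0
    for row, y in zip(coefficient_rows, true_classes):
        total -= math.log(class_probabilities(row, t)[y])
    return total / len(coefficient_rows)
```

Because correlation coefficients lie in [-1, 1], the constant t sets how sharply the softmax separates classes; with t = 1 the probabilities stay close to uniform.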

In an implementation of the first aspect, the loss function is defined by a formula in which L is the loss function; |Q| is the number of samples in the query set; x is an input query-set data sample whose class is y; c_y and c_j are the prototype vectors of the y-th and j-th classes of the support set; t is a constant; N is the total number of classes in the support set; a function of two input vectors gives the differentiable Kendall correlation coefficient between them; f_θ is the deep learning model; and θ denotes the parameters of the deep learning model.
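One loss consistent with the symbols defined above, and with the probability-based computation described in the preceding implementation, is the temperature-scaled softmax cross-entropy below, where $\tilde{\tau}(\cdot,\cdot)$ denotes the differentiable Kendall correlation coefficient. This is a reconstruction under that assumption, not a verbatim copy of the application's formula.

```latex
L = -\frac{1}{|Q|} \sum_{(x,\,y) \in Q}
    \log \frac{\exp\!\big(t \, \tilde{\tau}(f_\theta(x),\, c_y)\big)}
              {\sum_{j=1}^{N} \exp\!\big(t \, \tilde{\tau}(f_\theta(x),\, c_j)\big)}
```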

In an implementation of the first aspect, training the deep learning model with the loss function includes adjusting the value of at least one parameter of the deep learning model.

In an implementation of the first aspect, the deep learning model is the convolutional neural network ResNet or the Transformer-based neural network ViT.

In a second aspect, this application provides a sample classification method. The sample classification method includes: obtaining at least one reference sample of each class and a target sample to be classified; using a deep learning model to obtain the features of the reference samples and the target sample, the deep learning model having been trained with the training method according to any implementation of the first aspect; processing the features of the reference samples to obtain the prototype vector of each class; processing the feature of the target sample and the prototype vectors to obtain a channel-importance ranking of the target-sample feature and a channel-importance ranking of each class's prototype vector; using these channel-importance rankings to obtain the Kendall correlation coefficient between the target-sample feature and the prototype vector of each class; and using the Kendall correlation coefficients to obtain the classification result of the target sample.
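The application does not spell out here how a channel-importance ranking is formed. A natural reading, assumed in the sketch below, is that a channel with a larger feature value is more important; `channel_ranking` is a hypothetical helper that returns each channel's rank (0 = most important), breaking ties by channel index.

```python
def channel_ranking(vec):
    """Rank channels by value: rank 0 is the largest (most important) channel.

    Ties are broken by channel index so the ranking is deterministic.
    """
    order = sorted(range(len(vec)), key=lambda c: (-vec[c], c))
    ranks = [0] * len(vec)
    for rank, channel in enumerate(order):
        ranks[channel] = rank
    return ranks


if __name__ == "__main__":
    print(channel_ranking([0.1, 0.7, 0.3]))  # hypothetical 3-channel feature
```

Only the order of the values enters the ranking, so any positive rescaling of a feature vector leaves its ranking unchanged.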

In an implementation of the second aspect, obtaining the Kendall correlation coefficient includes: obtaining the channel-importance rankings of the target-sample feature and of the prototype vector; comparing, for every corresponding channel pair, the importance order in the target-sample feature with that in the prototype vector, the comparison yielding channel pairs whose importance orders agree (concordant pairs) and channel pairs whose importance orders disagree (discordant pairs); and obtaining the Kendall correlation coefficient from the number of concordant pairs, the number of discordant pairs and the total number of channel pairs.
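The counting procedure above is Kendall's rank correlation. The sketch below computes it as (concordant − discordant) / total pairs, which is the tau-a variant; treating tied pairs as neither concordant nor discordant, and classifying by the highest coefficient, are assumptions on our part. `kendall_tau` and `classify` are hypothetical names.

```python
from itertools import combinations


def kendall_tau(u, v):
    """Kendall's rank correlation (tau-a) between two equal-length vectors.

    A channel pair (i, j) is concordant when u and v order the two channels the
    same way and discordant when they order them oppositely; tied pairs count
    as neither. tau = (concordant - discordant) / total number of pairs.
    """
    n = len(u)
    concordant = discordant = 0
    for i, j in combinations(range(n), 2):
        s = (u[i] - u[j]) * (v[i] - v[j])
        if s > 0:
            concordant += 1
        elif s < 0:
            discordant += 1
    return (concordant - discordant) / (n * (n - 1) / 2)


def classify(target_feature, prototypes):
    """Return the class whose prototype has the highest Kendall correlation."""
    return max(prototypes, key=lambda label: kendall_tau(target_feature, prototypes[label]))
```

Comparing the sign of the value differences of a channel pair is equivalent to comparing the pair's importance ranks, so explicit rankings are not needed inside `kendall_tau`.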

In a third aspect, this application provides an electronic device. The electronic device includes: a memory configured to store a computer program; and a processor configured to execute the computer program stored in the memory, so that the electronic device performs the training method according to any implementation of the first aspect and/or the sample classification method according to the second aspect.

In a fourth aspect, this application provides a computer-readable storage medium storing a computer program which, when executed by a processor, implements the training method according to any implementation of the first aspect and/or the sample classification method according to the second aspect.

Brief description of the drawings

FIG. 1A is a schematic diagram of an application scenario of the training method according to an embodiment of the present application.

FIG. 1B is a schematic diagram of another application scenario of the training method according to an embodiment of the present application.

FIG. 2 is a schematic flow chart of the training method according to an embodiment of the present application.

FIG. 3 is a schematic flow chart of the training method according to an embodiment of the present application.

FIG. 4 is a schematic flow chart of the training method according to an embodiment of the present application.

FIG. 5 is a schematic flow chart of the method for calculating the loss function according to an embodiment of the present application.

FIG. 6 is a schematic flow chart of the sample classification method according to an embodiment of the present application.

FIG. 7 is a schematic flow chart of the sample classification method according to an embodiment of the present application.

FIG. 8 is a schematic structural diagram of the electronic device according to an embodiment of the present application.

Description of reference numerals

1 Medical image classification system

11 Image acquisition device

12 Local processor

13 Display terminal

2 Terminal-cloud interaction system

20 Terminal

21 Cloud server

800 Electronic device

810 Memory

820 Processor

830 Display

S11-S17 Steps

S131-S133 Steps

S141-S142 Steps

S21-S22 Steps

S31-S36 Steps

S351-S353 Steps

Detailed description of the embodiments

The embodiments of the present application are described below through specific examples; those skilled in the art can readily understand other advantages and effects of the present application from the content disclosed in this specification. The present application can also be implemented or applied through other, different embodiments, and the details in this specification can be modified or changed in various ways based on different viewpoints and applications without departing from the spirit of the present application. Where no conflict arises, the following embodiments and the features in them can be combined with each other.

It should be noted that the drawings provided in the following embodiments illustrate the basic concept of the present application only schematically: they show only the components related to the present application, not the number, shape and size of the components in an actual implementation. In an actual implementation, the type, quantity and proportion of each component may be changed arbitrarily, and the component layout may be more complex.

In few-shot learning, the classes of the samples in the base dataset used for model training do not overlap with the classes of the samples in the few-shot task, so during training the model never sees the classes of the few-shot task data. As a result, the features extracted for data samples of a few-shot task differ markedly from the sample features extracted by traditional supervised learning. One important property of these features is that their values are distributed more flatly: on most channels the values are small and very close to one another. The negative Euclidean distance and cosine similarity used in existing methods classify by directly computing the geometric similarity between features, so the classification result is in practice dominated by a handful of channels with large values, while the small-valued channels, which make up most of the feature, can hardly be distinguished by their importance for classification. Some of these small-valued channels may nevertheless carry key discriminative features of a class; if their importance is not distinguished well, the classification result will be wrong.
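To make the argument above concrete, the sketch below compares cosine similarity with Kendall's rank correlation on two feature vectors. The numbers are invented for illustration and are not taken from the application: one dominant channel is shared, and the three small, nearly equal channels appear in reversed order.

```python
import math
from itertools import combinations


def cosine_similarity(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def kendall_tau(u, v):
    """(concordant - discordant) / total channel pairs; ties count as neither."""
    n = len(u)
    score = 0
    for i, j in combinations(range(n), 2):
        s = (u[i] - u[j]) * (v[i] - v[j])
        score += (s > 0) - (s < 0)  # +1 concordant, -1 discordant, 0 tie
    return score / (n * (n - 1) / 2)


# Hypothetical features: one large channel, three small channels in opposite order.
prototype = [10.0, 0.3, 0.2, 0.1]
query = [10.0, 0.1, 0.2, 0.3]

print(f"cosine similarity: {cosine_similarity(prototype, query):.4f}")
print(f"Kendall tau:       {kendall_tau(prototype, query):.4f}")
```

Running it prints a cosine similarity of 0.9996 but a Kendall tau of 0.0000: the single large channel makes the vectors look almost identical geometrically, while the rank-based measure registers that every pair among the small channels is ordered oppositely.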

To address at least the above problems, embodiments of the present application provide a training method. The training method includes: obtaining training data, the training data including a support set and a query set, each of which includes multiple images; using a deep learning model to obtain the features of the data samples in the support set and the query set; processing the features of the support-set data samples to obtain a prototype vector for each class of data samples in the support set and the differences between the channels of each prototype vector; processing the features of the query-set data samples to obtain the differences between the channels of each query-sample feature; using the inter-channel differences of each prototype vector and of the query-sample features to obtain the differentiable Kendall correlation coefficient between the query-sample features and each prototype vector; calculating a loss function from the differentiable Kendall correlation coefficient; and training the deep learning model on the training data using the loss function.

In the embodiments of the present application, the differentiable Kendall correlation coefficient is obtained from the inter-channel differences of the prototype vectors and of the query-sample features, and the deep learning model is then trained on the data samples, with its parameters adjusted using a loss function computed from the differentiable Kendall correlation coefficient. This training method establishes an association between the channels of the prototype vectors and the channels of the query-sample features and thereby yields a better-performing deep learning model.

FIG. 1A is a schematic diagram of an application scenario of the training method according to an embodiment of the present application. The medical image classification system 1 can be used to implement the training method provided by the embodiments of the present application, but the application scenarios of the method are not limited to the medical image classification system 1 shown in FIG. 1A. As shown in FIG. 1A, the medical image classification system 1 includes an image acquisition device 11, a local processor 12 and a display terminal 13, and the training method can be applied to the local processor 12.

The image acquisition device 11 is used to collect medical imaging data samples. Optionally, it may be a CT scanner, an X-ray machine, a magnetic resonance imaging device or the like; the present application is not limited thereto.

The local processor 12 is used to process and analyze the collected medical imaging data samples and to obtain classification results from their features. The local processor 12 may be a single processor or a cluster of multiple processors, which is not limited here.

The display terminal 13 may be any device with a display function and is used to display the classification results for analysis by doctors. Optionally, the display terminal 13 may be a medical display.

It should be noted that although FIG. 1A shows only one image acquisition device 11, one local processor 12 and one display terminal 13, the example in FIG. 1A is intended only to aid understanding of this solution; the actual number of each kind of device should be determined flexibly according to the actual situation.

In other possible implementations, the training method described in this application can be applied to terminal-cloud interaction scenarios. FIG. 1B is a schematic diagram of another application scenario of the training method according to an embodiment of the present application. As shown in FIG. 1B, the terminal-cloud interaction system 2 includes a terminal 20 and a cloud server 21 that can communicate with each other over a wired or wireless connection, and the training method described in this application can be applied to the cloud server 21.

The terminal 20 may be mobile or fixed. For example, it may be a wireless terminal or a wired terminal; a wireless terminal is a device with wireless transceiver functions and may be deployed in a data processing center or a medical laboratory. The terminal may be a medical processor, a mobile phone, a tablet computer (Pad), a notebook computer or the like, which is not limited here.

The cloud server 21 may include one or more servers, one or more processing nodes, or one or more virtual machines running on servers. The cloud server 21 may also be called a server cluster, a management platform, a data processing center or the like, which is not limited in the embodiments of the present application.

The technical solutions in the embodiments of the present application are described in detail below with reference to the accompanying drawings.

The following embodiments of the present application provide a training method, which can be implemented, for example, by the local processor 12 shown in FIG. 1A or the cloud server 21 shown in FIG. 1B. FIG. 2 is a schematic flow chart of the training method according to an embodiment of the present application. As shown in FIG. 2, the training method includes steps S11 to S17.

Step S11: obtain training data comprising a support set and a query set, each of which contains multiple images. The support set is constructed in N-way, K-shot form, that is, it contains images of N classes with K images per class.
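Episodic training of this kind is usually driven by a sampler that draws a fresh N-way, K-shot task at each iteration. The sketch below is one such sampler under stated assumptions: the query set is drawn from the same N classes as the support set, with `n_query` images per class that do not overlap the support images. The application does not specify the query-set size, so `n_query` and the function name `sample_episode` are hypothetical.

```python
import random


def sample_episode(images_by_class, n_way, k_shot, n_query, rng=random):
    """Build one N-way, K-shot episode from a labelled pool of images.

    images_by_class: {class_name: list of images}
    Returns (support, query) as lists of (image, episode_label) pairs, where
    episode_label in 0..n_way-1 indexes the sampled classes.
    """
    classes = rng.sample(sorted(images_by_class), n_way)
    support, query = [], []
    for label, name in enumerate(classes):
        # Draw K + n_query distinct images so support and query never overlap.
        picked = rng.sample(images_by_class[name], k_shot + n_query)
        support += [(img, label) for img in picked[:k_shot]]
        query += [(img, label) for img in picked[k_shot:]]
    return support, query
```

Relabelling the sampled classes as 0..N-1 inside each episode matches the per-episode indexing of prototypes c_1, ..., c_N; passing a seeded `random.Random` makes episodes reproducible.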

Step S12: use the deep learning model to obtain the features of the data samples in the support set and the query set.

Step S13: process the features of the support-set data samples to obtain the prototype vector of each class of data samples in the support set and the differences between the channels of each prototype vector.

步骤S14,利用查询集的数据样本的特征进行处理,获取所述查询集的数据样本的特征的各通道之间的差值。Step S14: Use the characteristics of the data samples in the query set to perform processing, and obtain the differences between the channels of the characteristics of the data samples in the query set.

步骤S15,利用各所述原型向量的各通道之间的差值和所述查询集的数据样本的特征的各通道之间的差值,获取所述查询集的数据样本的特征和各所述原型向量之间的可微肯德尔相关系数。其中,所述可微肯德尔相关系数用来度量两个有序变量之间单调关系强弱的相关程度,相关系数的绝对值越大,表示相关度越强,相关系数越小,表示相关度越弱。Step S15, using the difference between the channels of each prototype vector and the difference between the channels of the characteristics of the data samples of the query set, obtain the characteristics of the data samples of the query set and each of the Differentiable Kendall correlation coefficients between prototype vectors. Among them, the differentiable Kendall correlation coefficient is used to measure the degree of correlation between the strength of the monotonic relationship between two ordered variables. The greater the absolute value of the correlation coefficient, the stronger the correlation. The smaller the correlation coefficient, the degree of correlation. The weaker.

在一些可能的实现方式中,所述可微肯德尔相关系数的计算公式为:In some possible implementations, the calculation formula of the differentiable Kendall correlation coefficient is:

where u = (u_1, u_2, ..., u_n) is the feature of a query-set data sample, v = (v_1, v_2, ..., v_n) is a prototype vector, n is the total number of channels, u_i and v_i are the values of u and v on the i-th channel, and α is a parameter.
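The exact smoothing is not reproduced in this text. A common way to make Kendall's coefficient differentiable, assumed here rather than taken from the source, is to replace the sign of each pairwise channel difference with tanh(α·δ) and average the products over all n(n-1)/2 channel pairs, so that α controls how closely the smooth version tracks the hard one. A minimal sketch with the hypothetical name `diff_kendall`:

```python
import math

def diff_kendall(u, v, alpha=1.0):
    """Smoothed Kendall correlation between two equal-length vectors.

    Assumption: sign(u_i - u_j) * sign(v_i - v_j) is replaced by
    tanh(alpha * (u_i - u_j)) * tanh(alpha * (v_i - v_j)).
    """
    n = len(u)
    if n != len(v) or n < 2:
        raise ValueError("u and v must have the same number (>= 2) of channels")
    total = 0.0
    for i in range(1, n):
        for j in range(i):
            # Smooth surrogate for the concordance of channel pair (i, j).
            total += math.tanh(alpha * (u[i] - u[j])) * math.tanh(alpha * (v[i] - v[j]))
    return total / (n * (n - 1) / 2)

print(round(diff_kendall([1, 2, 3], [1, 3, 2], alpha=100.0), 6))  # 0.333333
```

As α grows, tanh(α·δ) approaches sign(δ) for any nonzero δ, so the value approaches the ordinary Kendall coefficient; a smaller α gives a smoother function with more useful gradients.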

Step S16: Calculate a loss function based on the differentiable Kendall correlation coefficient.

Step S17: Train the deep learning model with the loss function on the training data.

In this embodiment, the differentiable Kendall correlation coefficient is obtained from the inter-channel differences of the prototype vectors and of the query-set features, and the deep learning model is then trained on the data samples with a loss function that takes the differentiable Kendall correlation coefficient as its input. This training method relates the channels of the prototype vectors to the channels of the query-set features and thereby yields a better deep learning model.

Figure 3 is a flow chart of the training method according to an embodiment of this application. As shown in Figure 3, obtaining the prototype vector of each category in the support set and the inter-channel differences of the prototype vectors from the support-set features includes steps S131 to S133.

Step S131: Obtain the prototype vector of each category in the support set from the features of the support-set data samples.

Step S132: Obtain the value of each channel of the prototype vector.

Step S133: Obtain the differences between the channels of the prototype vector from its channel values.

In some possible implementations, if u_i and u_j are the values of a prototype vector u on the i-th and j-th channels, the difference between its i-th and j-th channels is δ = u_i - u_j.
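Steps S13 and S14 both need this difference for every channel pair of a vector. A minimal sketch that enumerates the n(n-1)/2 pairs; the helper name `channel_differences` and the 0-based, i > j indexing are assumptions:

```python
def channel_differences(u):
    """Map each channel pair (i, j) with i > j (0-based) to u[i] - u[j]."""
    return {(i, j): u[i] - u[j] for i in range(1, len(u)) for j in range(i)}

print(channel_differences([3, 1, 2]))  # {(1, 0): -2, (2, 0): -1, (2, 1): 1}
```

The same function applies unchanged to a prototype vector and to a query-set feature. Note that the number of pairs grows quadratically with the channel count: a 512-channel feature, for example, has 130,816 pairs.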

Figure 4 is a flow chart of the training method according to an embodiment of this application. As shown in Figure 4, obtaining the differences between the channels of the query-set features from those features includes steps S141 and S142.

Step S141: Obtain the value of each channel of the feature of each query-set data sample.

Step S142: Obtain the differences between the channels of each query-set feature from its channel values.

In an embodiment of this application, obtaining the prototype vector of each category of the support set includes step S134: obtain the prototype vector of each category as the average of the features of that category's data samples in the support set.
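Step S134 amounts to a per-category, per-channel mean. A minimal sketch in plain Python; the name `class_prototypes` and the list-of-lists feature format are assumptions:

```python
from collections import defaultdict

def class_prototypes(features, labels):
    """Average the support-set features of each category, channel by channel."""
    sums, counts = {}, defaultdict(int)
    for feature, label in zip(features, labels):
        if label not in sums:
            sums[label] = [0.0] * len(feature)
        sums[label] = [s + x for s, x in zip(sums[label], feature)]
        counts[label] += 1
    return {label: [s / counts[label] for s in sums[label]] for label in sums}

print(class_prototypes([[1.0, 2.0], [3.0, 4.0], [10.0, 0.0]], ["a", "a", "b"]))
# {'a': [2.0, 3.0], 'b': [10.0, 0.0]}
```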

Figure 5 is a flow chart of the method for calculating the loss function according to an embodiment of this application. As shown in Figure 5, the calculation includes steps S21 and S22.

Step S21: Obtain the probability that each query-set data sample belongs to each category from the differentiable Kendall correlation coefficients between its feature and the prototype vector of each category.

Step S22: Obtain the loss function from the probabilities that the query-set data samples belong to each category and from the number of data samples in the query set.

In some possible implementations, the support set is constructed in 5-way, 5-shot form, and the five samples of each category in the support set are averaged to form that category's prototype vector. The probability that a query-set data sample belongs to each category is then obtained from the differentiable Kendall correlation coefficients between its feature and the category prototype vectors, using the formula:

where c_k and c_j are the prototype vectors of the k-th and j-th categories, sim is the similarity measure, N is the total number of categories, t is a constant, x is an input query-set data sample, y is the category to which it belongs, f_θ is the deep learning model, and θ denotes the parameters of the deep learning model.
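The formula itself is not reproduced in this text. Given the listed symbols, a plausible reading, assumed here, is a softmax over t·sim(f_θ(x), c_k) across the N categories. A minimal sketch that takes the N similarity scores of one query sample as input (the name `class_probabilities` is hypothetical):

```python
import math

def class_probabilities(scores, t):
    """Softmax of t * scores[k], where scores[k] = sim(f_theta(x), c_k)."""
    logits = [t * s for s in scores]
    m = max(logits)  # subtract the maximum logit for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Five categories, matching the 5-way setting; the scores are made-up values.
probs = class_probabilities([0.8, 0.2, -0.1, 0.0, 0.5], t=10.0)
print(max(range(len(probs)), key=probs.__getitem__))  # 0
```

Because a Kendall-type coefficient lies in [-1, 1], the constant t acts as a scale (inverse temperature) that keeps the softmax from being nearly uniform.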

In an embodiment of this application, the loss function is calculated as:

where L is the loss function, |Q| is the number of samples in the query set, x is an input query-set data sample whose category is y, c_y and c_j are the prototype vectors of the y-th and j-th categories of the support set, t is a constant, N is the total number of categories in the support set, the similarity function in the formula is the differentiable Kendall correlation coefficient between its two input vectors, f_θ is the deep learning model, and θ denotes the parameters of the deep learning model.
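Assuming the loss is the mean negative log-probability of the true category y over the query set, with the differentiable Kendall coefficient as the similarity inside a softmax scaled by t (an assumption, since the formula is not reproduced here), the arithmetic can be sketched as follows. In practice this would be written in an autodiff framework so that gradients reach θ; the plain-Python version below, with the hypothetical name `kendall_prototype_loss`, only shows the computation.

```python
import math

def kendall_prototype_loss(score_rows, labels, t):
    """Mean over the query set of -log p(y | x) under a softmax of t * scores.

    score_rows[q][j]: similarity (e.g. differentiable Kendall coefficient)
    between the feature of query sample q and prototype c_j; labels[q]: its y.
    """
    if not labels or len(score_rows) != len(labels):
        raise ValueError("need one row of scores per labelled query sample")
    total = 0.0
    for scores, y in zip(score_rows, labels):
        logits = [t * s for s in scores]
        m = max(logits)
        log_norm = m + math.log(sum(math.exp(z - m) for z in logits))  # log-sum-exp
        total += log_norm - logits[y]  # equals -log p(y | x)
    return total / len(labels)  # divide by |Q|

# With identical scores for all 5 categories, the loss equals log(5).
print(round(kendall_prototype_loss([[0.0] * 5], [2], t=10.0), 6))  # 1.609438
```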

In an embodiment of this application, training the deep learning model with the loss function includes adjusting at least one parameter value of the deep learning model during training.

In an embodiment of this application, the deep learning model is the convolutional neural network ResNet (residual neural network) or the Transformer-based neural network ViT (Vision Transformer).

Figure 6 is a flow chart of the sample classification method according to an embodiment of this application. As shown in Figure 6, the sample classification method includes steps S31 to S36.

Step S31: Obtain at least one reference sample of each category and a target sample to be classified. Optionally, the reference samples are samples whose categories are already known.

Step S32: Use a deep learning model, trained by the training method according to any embodiment of this application, to obtain the features of the reference samples and the target sample. Optionally, the trained deep learning model is used as a feature extractor to extract the features of the target samples (the query-set samples) and of the reference samples (the support-set samples).

Step S33: Process the features of the reference samples to obtain the prototype vector of each category.

Step S34: Process the feature of the target sample and the prototype vector of each category to obtain a channel importance ranking of the target sample's feature and a channel importance ranking of each category's prototype vector. Optionally, the channels of a feature are sorted by value, and the resulting ranking represents the importance of each channel.

In some possible implementations, a feature with 5 channels has channel values (0.1, 0.11, 0.13, 0.12, 0.14); sorting them, with rank 1 assigned to the smallest value, yields the channel importance ranking (1, 2, 4, 3, 5).
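The ranking in the example above can be sketched as follows, assuming rank 1 goes to the smallest value (which reproduces (1, 2, 4, 3, 5)) and that ties, which the text does not address, are broken by channel index. The function name `channel_ranks` is hypothetical.

```python
def channel_ranks(feature):
    """1-based rank of each channel, with rank 1 for the smallest value."""
    order = sorted(range(len(feature)), key=lambda i: feature[i])  # stable sort
    ranks = [0] * len(feature)
    for rank, channel in enumerate(order, start=1):
        ranks[channel] = rank
    return ranks

print(channel_ranks([0.1, 0.11, 0.13, 0.12, 0.14]))  # [1, 2, 4, 3, 5]
```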

Step S35: Use the channel importance ranking of the target sample's feature and the channel importance ranking of each category's prototype vector to obtain the Kendall correlation coefficient between the target sample's feature and each category's prototype vector.

Step S36: Use the Kendall correlation coefficients to obtain the classification result of the target sample, namely the category whose prototype vector has the largest Kendall correlation coefficient with the target sample's feature.
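Steps S33 to S36 can be sketched end to end as below. The sketch assumes the Kendall coefficient is (N_con - N_dis)/N_total over all channel pairs, and it compares raw channel values rather than explicit ranks; for features without ties this gives the same result, because ranking preserves the order within every channel pair. The names `kendall_tau` and `classify`, and the dictionary of prototypes, are hypothetical.

```python
from itertools import combinations

def _sign(x):
    return (x > 0) - (x < 0)

def kendall_tau(a, b):
    """Kendall coefficient over all channel pairs: (N_con - N_dis) / N_total."""
    if len(a) != len(b) or len(a) < 2:
        raise ValueError("features must have the same number (>= 2) of channels")
    pairs = list(combinations(range(len(a)), 2))
    # +1 for a consistent pair, -1 for an inconsistent one, 0 for a tie.
    agreement = sum(_sign(a[i] - a[j]) * _sign(b[i] - b[j]) for i, j in pairs)
    return agreement / len(pairs)

def classify(target_feature, prototypes):
    """Return the category whose prototype has the largest Kendall coefficient."""
    return max(prototypes, key=lambda c: kendall_tau(target_feature, prototypes[c]))

prototypes = {"cat": [0.1, 0.2, 0.3, 0.4], "dog": [0.4, 0.3, 0.2, 0.1]}
print(classify([1.0, 2.0, 4.0, 3.0], prototypes))  # cat
```

At inference time no gradient is needed, so the hard (non-differentiable) coefficient can replace the smoothed one used during training.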

Figure 7 is a flow chart of the sample classification method according to an embodiment of this application. As shown in Figure 7, obtaining the Kendall correlation coefficient includes steps S351 to S353.

Step S351: Obtain the channel importance rankings of the target sample's feature and of the prototype vector.

Step S352: Match the importance rankings of every corresponding channel pair between the target sample's feature and the prototype vector; the matching result consists of channel pairs whose importance rankings agree and channel pairs whose importance rankings disagree. Optionally, a channel pair is any set of two channels, and its importance ranking result is the relative ranking of those two channels. Optionally, a pair that fails to match is counted as a channel pair with inconsistent importance rankings, and a pair that matches is counted as a channel pair with consistent importance rankings.

Step S353: Obtain the Kendall correlation coefficient from the number of channel pairs with consistent importance rankings, the number of channel pairs with inconsistent importance rankings, and the total number of channel pairs.

In some possible implementations, the Kendall correlation coefficient is calculated as:

$$\tau = \frac{N_{con} - N_{dis}}{N_{total}}$$

where N_con is the number of channel pairs with consistent importance rankings, N_dis is the number of channel pairs with inconsistent importance rankings, and N_total is the total number of channel pairs.

The calculation of the Kendall correlation coefficient is illustrated below with a concrete example. For two features x and y, let the channel importance rankings be x = (1, 2, 3) and y = (1, 3, 2). For channels 1 and 2, feature x has (1, 2) and feature y has (1, 3), so the importance rankings of this channel pair are consistent. For channels 1 and 3, feature x has (1, 3) and feature y has (1, 2), so the rankings of this pair are consistent. For channels 2 and 3, feature x has (2, 3) and feature y has (3, 2), so the rankings of this pair are inconsistent. Hence the number of channel pairs with consistent rankings is N_con = 2, the number with inconsistent rankings is N_dis = 1, and the total number of channel pairs is N_total = 3, which gives a Kendall correlation coefficient of 1/3.
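The matching in step S352 and the count in step S353 can be checked against this example with a short sketch. It assumes the coefficient is (N_con - N_dis)/N_total, which reproduces the 1/3 above; the function names `match_channel_pairs` and `kendall_from_ranks` are hypothetical, and channels are numbered from 1 as in the example.

```python
from itertools import combinations

def match_channel_pairs(rank_x, rank_y):
    """Step S352: split all channel pairs into consistent and inconsistent ones.

    Channels are reported 1-based, as in the worked example.
    """
    consistent, inconsistent = [], []
    for i, j in combinations(range(len(rank_x)), 2):
        same_order = (rank_x[i] < rank_x[j]) == (rank_y[i] < rank_y[j])
        (consistent if same_order else inconsistent).append((i + 1, j + 1))
    return consistent, inconsistent

def kendall_from_ranks(rank_x, rank_y):
    """Step S353: (N_con - N_dis) / N_total (assumed formula)."""
    if len(rank_x) != len(rank_y) or len(rank_x) < 2:
        raise ValueError("rankings must cover the same number (>= 2) of channels")
    consistent, inconsistent = match_channel_pairs(rank_x, rank_y)
    n_total = len(consistent) + len(inconsistent)
    return (len(consistent) - len(inconsistent)) / n_total

con, dis = match_channel_pairs((1, 2, 3), (1, 3, 2))
print(con, dis)  # [(1, 2), (1, 3)] [(2, 3)]
print(kendall_from_ranks((1, 2, 3), (1, 3, 2)))  # 0.3333333333333333
```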

It should be noted that the above are merely possible implementations of the embodiments of this application, and this application is not limited thereto.

In the several embodiments provided in this application, it should be understood that the disclosed system, device, or method may be implemented in other ways. For example, the device embodiments described above are merely illustrative: the division into modules/units is only a division by logical function, and other divisions are possible in an actual implementation. For example, multiple modules or units may be combined or integrated into another system, or some features may be omitted or not performed. In addition, the mutual couplings, direct couplings, or communication connections shown or discussed may be indirect couplings or communication connections through some interfaces, devices, modules, or units, and may be electrical, mechanical, or in other forms.

Modules/units described as separate components may or may not be physically separate, and components shown as modules/units may or may not be physical modules; that is, they may be located in one place or distributed across multiple network units. Some or all of the modules/units may be selected according to actual needs to achieve the purpose of the embodiments of this application. For example, the functional modules/units in the embodiments of this application may be integrated into one processing module, each module/unit may exist physically alone, or two or more modules/units may be integrated into one module/unit.

A person of ordinary skill in the art will further appreciate that the example units and algorithm steps described in connection with the embodiments disclosed herein can be implemented by electronic hardware, computer software, or a combination of the two. To clearly illustrate the interchangeability of hardware and software, the composition and steps of each example have been described above in general terms of their functions. Whether these functions are performed in hardware or software depends on the specific application and the design constraints of the technical solution. A skilled person may implement the described functions in different ways for each specific application, but such implementations should not be considered to go beyond the scope of this application.

An embodiment of this application further provides an electronic device. Figure 8 is a schematic structural diagram of an electronic device 800 according to an embodiment of this application. As shown in Figure 8, the electronic device 800 in this embodiment includes a memory 810 and a processor 820.

The memory 810 is configured to store a computer program. Optionally, the memory 810 includes any medium capable of storing program code, such as a ROM, a RAM, a magnetic disk, a USB flash drive, a memory card, or an optical disc.

Specifically, the memory 810 may include computer-system-readable media in the form of volatile memory, such as random access memory (RAM) and/or cache memory. The electronic device 800 may further include other removable/non-removable, volatile/non-volatile computer system storage media. The memory 810 may include at least one program product having a set of (for example, at least one) program modules configured to perform the functions of the embodiments of this application.

The processor 820 is connected to the memory 810 and is configured to execute the computer program stored in the memory 810, so that the electronic device 800 performs the training method according to any embodiment of this application and/or the sample classification method according to the embodiments of this application.

Optionally, the processor 820 may be a general-purpose processor, such as a central processing unit (CPU) or a network processor (NP); it may also be a digital signal processor (DSP), an application-specific integrated circuit (ASIC), a field-programmable gate array (FPGA) or other programmable logic device, a discrete gate or transistor logic device, or a discrete hardware component.

Optionally, the electronic device 800 in this embodiment may further include a display 830. The display 830 is communicatively connected to the memory 810 and the processor 820 and is configured to display a graphical user interface (GUI) related to the training method according to any embodiment of this application and/or the sample classification method according to the embodiments of this application.

An embodiment of this application further provides a computer-readable storage medium storing a computer program which, when executed by a processor, implements the training method according to any embodiment of this application and/or the sample classification method according to the embodiments of this application.

The descriptions of the processes and structures corresponding to the drawings above each have a different emphasis; for any part not described in detail for one process or structure, refer to the related descriptions of the other processes or structures.

The above embodiments merely illustrate the principles and effects of this application and are not intended to limit it. Anyone familiar with this technology may modify or change the above embodiments without departing from the spirit and scope of this application. Therefore, all equivalent modifications or changes made by a person of ordinary skill in the art without departing from the spirit and technical ideas disclosed in this application shall still be covered by the claims of this application.

Claims (12)

1. A training method, comprising: obtaining training data, the training data comprising a support set and a query set, the support set comprising a plurality of images and the query set comprising a plurality of images; using a deep learning model to obtain features of the data samples in the support set and the query set; processing the features of the data samples of the support set to obtain a prototype vector of each category of data samples in the support set and differences between the channels of each prototype vector; processing the features of the data samples of the query set to obtain differences between the channels of the features of the data samples of the query set; using the differences between the channels of the prototype vector of each category and the differences between the channels of the features of the data samples of the query set to obtain a differentiable Kendall correlation coefficient between the features of the data samples of the query set and the prototype vector of each category; calculating a loss function according to the differentiable Kendall correlation coefficient; and training the deep learning model with the loss function according to the training data.

2. The training method according to claim 1, wherein processing the features of the data samples of the support set to obtain the prototype vector of each category of data samples in the support set and the differences between the channels of the prototype vector comprises: obtaining the prototype vector of each category of data samples in the support set according to the features of the data samples of the support set; obtaining the value of each channel of the prototype vector; and obtaining the differences between the channels of the prototype vector according to the channel values of the prototype vector.

3. The training method according to claim 1, wherein processing the features of the data samples of the query set to obtain the differences between the channels of the features of the data samples of the query set comprises: obtaining the value of each channel of the features of the data samples of the query set according to those features; and obtaining the differences between the channels of the features of the data samples of the query set using those channel values.

4. The training method according to claim 1, wherein obtaining the prototype vector of each category of the support set comprises: obtaining the prototype vector of each category of data samples in the support set according to the average of the features of the data samples of that category.

5. The training method according to claim 1, wherein the loss function is calculated by: obtaining, according to the differentiable Kendall correlation coefficient between the features of the data samples of the query set and the prototype vector of each category, the probability that each data sample of the query set belongs to each category; and obtaining the loss function using the probabilities that the data samples of the query set belong to each category and the number of data samples in the query set.

6. The training method according to claim 5, wherein the loss function is: where L is the loss function, |Q| is the number of samples in the query set, x is an input query-set data sample whose category is y, c_y and c_j are the prototype vectors of the y-th and j-th categories of the support set, t is a constant, N is the total number of categories in the support set, the similarity function is the differentiable Kendall correlation coefficient between its two input vectors, f_θ is the deep learning model, and θ denotes the parameters of the deep learning model.

7. The training method according to claim 1, wherein training the deep learning model with the loss function comprises: adjusting at least one parameter value of the deep learning model when training the deep learning model with the loss function.

8. The training method according to claim 1, wherein the deep learning model is the convolutional neural network ResNet or the Transformer-based neural network ViT.

9. A sample classification method, comprising: obtaining at least one reference sample of each category and a target sample to be classified; using a deep learning model to obtain features of the reference samples and the target sample, the deep learning model being trained by the training method according to any one of claims 1 to 8; processing the features of the reference samples to obtain a prototype vector of each category; processing the features of the target sample and the prototype vector of each category to obtain a channel importance ranking of the features of the target sample and a channel importance ranking of the prototype vector of each category; using the channel importance ranking of the features of the target sample and the channel importance ranking of the prototype vector of each category to obtain a Kendall correlation coefficient between the features of the target sample and the prototype vector of each category; and obtaining a classification result of the target sample using the Kendall correlation coefficient.

10. The sample classification method according to claim 9, wherein obtaining the Kendall correlation coefficient comprises: obtaining the channel importance rankings of the features of the target sample and of the prototype vector; matching the importance rankings of all corresponding channel pairs between the features of the target sample and the prototype vector, the matching result comprising channel pairs with consistent importance rankings and channel pairs with inconsistent importance rankings; and obtaining the Kendall correlation coefficient according to the number of channel pairs with consistent importance rankings, the number of channel pairs with inconsistent importance rankings, and the total number of channel pairs.

11. An electronic device, comprising: a memory configured to store a computer program; and a processor configured to execute the computer program stored in the memory, so that the electronic device performs the training method according to any one of claims 1 to 8 and/or the sample classification method according to any one of claims 9 to 10.

12. A computer-readable storage medium storing a computer program, wherein, when the program is executed by a processor, the training method according to any one of claims 1 to 8 and/or the sample classification method according to any one of claims 9 to 10 is implemented.

Priority Applications (1)

Application Number | Priority Date | Filing Date | Title
CN202310892144.1A (CN116935160B) | 2023-07-19 | 2023-07-19 | Training method, sample classification method, electronic equipment and medium

Publications (2)

Publication Number | Publication Date
CN116935160A | 2023-10-24
CN116935160B | 2024-05-10

Family

ID=88378587

Family Applications (1)

Application Number | Status | Priority Date | Filing Date | Title
CN202310892144.1A (CN116935160B) | Active | 2023-07-19 | 2023-07-19 | Training method, sample classification method, electronic equipment and medium

Country Status (1)

Country | Link
CN (1) | CN116935160B

Citations (6)

* Cited by examiner, † Cited by third party
Publication Number | Priority Date | Publication Date | Assignee | Title
CN112287954A * | 2019-07-24 | 2021-01-29 | Huawei Technologies Co., Ltd. | Image classification method, training method of image classification model and device thereof
CN112949740A * | 2021-03-17 | 2021-06-11 | Chongqing University of Posts and Telecommunications | Small sample image classification method based on multilevel measurement
CN114298290A * | 2021-12-08 | 2022-04-08 | Chongqing University of Posts and Telecommunications | Neural network coding method and coder based on self-supervision learning
WO2022243337A2 * | 2021-05-17 | 2022-11-24 | Deep Safety GmbH | System for detection and management of uncertainty in perception systems, for new object detection and for situation anticipation
CN115512202A * | 2022-09-27 | 2022-12-23 | Hunan Langguo Visual Recognition Research Institute Co., Ltd. | Small sample target detection method, system and storage medium based on metric learning
CN115795355A * | 2023-02-10 | 2023-03-14 | Institute of Automation, Chinese Academy of Sciences | A classification model training method, device and equipment

Non-Patent Citations (4)

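Title
Xu Luo et al.: "Channel Importance Matters in Few-Shot Image Classification", arXiv *
Liu Ning et al.: "Application of Generative Adversarial Networks in Depression Classification" (生成式对抗网络在抑郁症分类中的应用), Computer Applications and Software, no. 06 *
Wang Yang et al.: "A Fusion Algorithm of Multiple Query-Dependent Ranking Support Vector Machines" (多查询相关的排序支持向量机融合算法), Journal of Computer Research and Development, no. 04 *
Hu Chunjian: "Significance Test of the Kendall τ Correlation Coefficient for Small Samples" (小样本下Kendall τ相关系数的显著性检验), Control Engineering of China, vol. 20, no. 06 *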

Also Published As

Publication Number | Publication Date
CN116935160B | 2024-05-10

Similar Documents

Publication | Title
US12061989B2 | Machine learning artificial intelligence system for identifying vehicles
US11847540B2 | Graph model build and scoring engine
CN108280477B | Method and apparatus for clustering images
Hossain et al. | Improving consumer satisfaction in smart cities using edge computing and caching: A case study of date fruits classification
CN114332984B | Training data processing method, device and storage medium
CN108287857B | Expression picture recommendation method and device
WO2019015246A1 | Image feature acquisition
CN114565807B | Method and device for training target image retrieval model
CN109711228A | An image processing method and device for realizing image recognition, and electronic equipment
WO2023020214A1 | Retrieval model training method and apparatus, retrieval method and apparatus, device and medium
CN117216362A | Content recommendation method, device, apparatus, medium and program product
CN113139540A | Backboard detection method and equipment
CN111488479A | Hypergraph construction method, hypergraph construction device, computer system and medium
CN116935160B | Training method, sample classification method, electronic equipment and medium
CN114708449B | Similar video determination method, and training method and device of example characterization model
CN110428012A | Brain network model establishment method, brain image classification method, device and electronic equipment
CN115658942A | A joint credit information intelligent data retrieval method for financial scenarios
Kanwal et al. | Evaluation method, dataset size or dataset content: how to evaluate algorithms for image matching?
CN115905846A | A sample selection method and device
CN116257676A | Data processing method, model training method, device, equipment and storage medium
CN110298400A | An image classification method, device, equipment and storage medium
CN118366478B | Generated audio identification and generated region localization method based on phoneme interval sequence
HK40071514B | Determining method of similar videos, training method of example representation model and equipment
HK40071514A | Determining method of similar videos, training method of example representation model and equipment
Campobello et al. | An efficient algorithm for parallel distributed unsupervised learning

Legal Events

Code | Title
PB01 | Publication
SE01 | Entry into force of request for substantive examination
GR01 | Patent grant