CN114090780A - 一种基于提示学习的快速图片分类方法 - Google Patents

一种基于提示学习的快速图片分类方法 Download PDF

Info

Publication number
CN114090780A
CN114090780A CN202210062188.7A CN202210062188A CN114090780A CN 114090780 A CN114090780 A CN 114090780A CN 202210062188 A CN202210062188 A CN 202210062188A CN 114090780 A CN114090780 A CN 114090780A
Authority
CN
China
Prior art keywords
vector
training
prompt
category
picture
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Granted
Application number
CN202210062188.7A
Other languages
English (en)
Other versions
CN114090780B (zh
Inventor
赵天成
陆骁鹏
刘鹏
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Honglong Technology Hangzhou Co ltd
Original Assignee
Honglong Technology Hangzhou Co ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Honglong Technology Hangzhou Co ltd filed Critical Honglong Technology Hangzhou Co ltd
Priority to CN202210062188.7A priority Critical patent/CN114090780B/zh
Publication of CN114090780A publication Critical patent/CN114090780A/zh
Application granted granted Critical
Publication of CN114090780B publication Critical patent/CN114090780B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F16/00Information retrieval; Database structures therefor; File system structures therefor
    • G06F16/30Information retrieval; Database structures therefor; File system structures therefor of unstructured textual data
    • G06F16/35Clustering; Classification
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F16/00Information retrieval; Database structures therefor; File system structures therefor
    • G06F16/50Information retrieval; Database structures therefor; File system structures therefor of still image data
    • G06F16/55Clustering; Classification
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F16/00Information retrieval; Database structures therefor; File system structures therefor
    • G06F16/50Information retrieval; Database structures therefor; File system structures therefor of still image data
    • G06F16/58Retrieval characterised by using metadata, e.g. metadata not derived from the content or metadata generated manually
    • G06F16/583Retrieval characterised by using metadata, e.g. metadata not derived from the content or metadata generated manually using metadata automatically derived from the content
    • G06F16/5846Retrieval characterised by using metadata, e.g. metadata not derived from the content or metadata generated manually using metadata automatically derived from the content using extracted text
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F40/00Handling natural language data
    • G06F40/10Text processing
    • G06F40/12Use of codes for handling textual entities
    • G06F40/126Character encoding
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F40/00Handling natural language data
    • G06F40/10Text processing
    • G06F40/166Editing, e.g. inserting or deleting
    • G06F40/186Templates
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F40/00Handling natural language data
    • G06F40/10Text processing
    • G06F40/194Calculation of difference between files
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F40/00Handling natural language data
    • G06F40/20Natural language analysis
    • G06F40/279Recognition of textual entities
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/084Backpropagation, e.g. using gradient descent

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Artificial Intelligence (AREA)
  • General Health & Medical Sciences (AREA)
  • Computational Linguistics (AREA)
  • Health & Medical Sciences (AREA)
  • Audiology, Speech & Language Pathology (AREA)
  • Data Mining & Analysis (AREA)
  • Databases & Information Systems (AREA)
  • Library & Information Science (AREA)
  • Evolutionary Computation (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Molecular Biology (AREA)
  • Computing Systems (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Image Analysis (AREA)
  • Information Retrieval, Db Structures And Fs Structures Therefor (AREA)

Abstract

本发明公开了一种基于提示学习的快速图片分类方法,其包括以下步骤:S1、提示初始化;S2、提示学习与模型训练;S3、使用获得的模型进行图片分类。本方案通过提示学习和图文多模态预训练模型来提高图片分类任务的性能,并且减少对于人工标注的数据量的需求,仅需要几十张标注数据就可达到相当高的准确率,适用于计算机图片处理领域。

Description

一种基于提示学习的快速图片分类方法
技术领域
本发明涉及计算机视觉领域,尤其是涉及一种基于提示学习的快速图片分类方法。
背景技术
图片分类是计算机视觉领域常见的任务之一,通过人工标注的固定类别的分类图片数据训练模型,使模型学会将图片分到对应类别中。提示(prompt)是一种为了更好的使用预训练语言模型的知识,采用在输入段添加额外的文本的技术。提示学习就是在模型训练时加入提示模板,给预训练语言模型的一个线索/提示,帮助它可以更好的理解人类的问题。提示学习通常运用在自然语言处理领域的预训练语言模型相关任务上。
快速图片分类技术有着很多重要的商业应用,例如对于缺少大规模训练数据的长尾场景,快速图片分类可以通过使用小样本数据进行训练就达到一个相当可观的准确率,解决缺少数据无法训练分类模型的痛点。
发明内容
本发明主要是提供一种基于提示学习的快速图片分类方法,能够通过提示学习和图文多模态预训练模型相结合,在小样本训练数据上快速实现高性能图片分类。
本发明的主要方案是:一种基于提示学习的快速图片分类方法,包括以下步骤:
S1、提示初始化;
S2、提示学习与模型训练;
S3、使用获得的模型进行图片分类;
所述步骤S1具体为:
构建一个M×N维的数组向量,数组向量包括M个N维的数组,N是文本编码器的输入向量维度,M为数组个数,然后对数组向量进行初始化作为提示向量;
将提示向量与类别名称向量拼接,类别名称向量为要训练的分类任务中的具体类别名称的输入向量;类别名称向量通过文本编码器(即文本预训练模型)对预设文本分词后将分词字符转换为对应的特征向量数组得到;预设文本是和具体训练任务相关的,比如说这次的训练任务是训练模型分类猫和狗,那么训练数据中的每一个图片数据都要被标注是猫或者狗,而这次训练任务中生成类别名称向量的预设文本也就是“猫”和“狗”,总的来说这个预设文本就是分类任务中的各个类别名称;
所述步骤S2具体为:
S201、将拼接后的向量输入到文本编码器中,得到每个类别的文本向量;通过图片编码器将训练图片转换成图片向量,训练图片为标注好的训练数据中各个类别的图片;
S202、通过以下公式进行图片分类计算:
Figure 96783DEST_PATH_IMAGE001
其中g(xi)表示第i个类别加上提示向量后经过编码器生成的文本向量,f是图片编码器生成的图片向量;K表示此次分类任务的类别总数,j表达总类别中的第j个类,y表示模型的预测结果,y=i即为模型预测结果为第i类,而p(y=i|x)就表示模型预测图片为第i类的概率;
通过上述公式计算每个类别文本向量与图片向量的相似度之后取相似度最大的类别作为模型预测的类别;
S203、最后通过交叉熵损失函数来和真实类别计算损失,并且固定住图文多模态预训练模型(文本编码器和图片编码器)的参数,只通过反向传播损失来训练提示向量。
提示向量经过初始化之后并不是固定不变,它的向量参数是会在模型使用训练数据的训练迭代过程中,根据反向传播时损失函数计算出的损失值进行动态调整的。
作为优选,M为2的倍数,M随着训练任务的复杂度增加而增大,M的最大值不大于文本编码器的上下文长度。
作为优选,向量初始化通过随机的方式进行:
在固定范围为每一个维度随机取一个浮点数。具体可以通过标准差为0.02,平均值为0的正态分布来随机取值。
作为优选,向量初始化通过已有的提示文本模板生成:
设定一句提示文本,经过文本编码器的预处理后转换为已经预训练过的M×N的数组向量。预处理为对文本分词后将分词字符转换为对应的特征向量数组。
作为优选,提示向量中添加预训练的负样本提示向量和/或干扰提示向量来帮助模型学习如何区分干扰项。
这里的实现方法是通过前述中的已有提示文本模板来实现的。具体分为两类:
1. 预设负样本提示向量:这个是和训练任务相关的需要根据训练任务预设。例如:和人相关的分类任务中,图片中的人不是完整的人,只是身体的一部分,这样会影响分类任务的效果,因为无法根据不完整的人判断具体类别。这时候预设一个“不完整的人”文本模板,然后通过已有的提示文本模板生成负样本提示向量;
2. 干扰提示向量:这个提示向量生成方法和1一样,但是与具体训练任务不相关,它会用一个“其他”文本模板,生成提示向量。作用在于只有训练数据只有正样本类别时,加入“其他”这个干扰类,帮助模型学会在遇到训练时未学过但不属于任何一个正样本类别的数据时,分到其他中,而不是强行分给某一个正样本类别。
生成以上的负样本或干扰提示向量后,相当于多了一个类别,那就需要给这个类提供对应的训练数据。这里生成训练数据的方法是从现有训练数据中,随机取一部分数据。数量等于各类别数据中最少的那一类的数量。然后从随机选中的每一个训练图片中随机剪裁出一个矩形部分作为负样本或干扰提示向量类的训练数据。具体方法为,以图片长和宽为最大值,随机生成一个矩形进行裁剪,矩形长和宽不大于图片长宽的70%, 不小于图片长宽的20%。
作为优选,提示向量和类别名称向量按以下两种方式中的任意一种拼接:
x=[V]1[V]2…[V]M[CLASS]
x=[V]1[V]2…[V]M/2[CLASS] [V](M/2)+1…[V]M
其中,[V]m(m∈{1,2,…,M})是提示向量,[CLASS]是要训练的分类任务中的具体类别名称的输入向量。
第一种将类别名称向量拼接到提示向量后面,第二种是将类别名称向量拼接到提示向量中间。
作为优选,当训练数据及类别均少于各自的阈值的时候,所有图片类别共享同一个提示向量;当训练数据量和类别中的任意一个大于或等于各自的阈值的时候,每个图片类别分别生成并训练各自的提示向量。训练数据的阈值为100个,类别的阈值为5个。
作为优选,所述步骤S203具体为:
对于二分类任务,交叉熵损失函数的具体形式如下:
Figure 341819DEST_PATH_IMAGE002
对于多分类任务, 交叉熵损失函数的具体形式如下:
Figure 387136DEST_PATH_IMAGE003
L表示N个训练数据的损失,也就是每个数据的损失Li除以N得到的平均数;具体单个训练数据的损失如后半段公式,yi表示第i个数据的真实类别,正类为1,负类为0,pi表示模型预测该数据为正类的概率;在多分类任务的损失函数中,C表示多分类任务的类别数,也就是在二分类损失函数的基础上,对模型在每个类别的预测结果上进行了一个求和来计算损失;在反向传播时,通过对损失函数求导计算梯度,并将梯度传递到之前的网络结构的函数中,来通过调整函数中的权重参数实现降低损失,从而使模型学会预测正确的类别;训练的结束条件为以下两种中的任意一种:(1)设定为迭代固定次数的全量数据训练过程后停止;(2)设定为损失值降低到损失阈值或者K个训练迭代后损失值未降低时停止训练,K为预设参数,一般为20。
二分类任务只有一个类别,模型只需要预测是或者不是,也就是正类或负类。而多分类任务是多个类别,也就是说模型需要对每个类别预测是或者不是,就相当于在二分类任务的基础上对所有类别的损失进行了一个求和。
作为优选,所述步骤S3具体为:
使用上述训练最终保存下来的提示向量,通过多模态模型的文本编码器生成各类别文本向量,然后与图片编码器生成的图片向量计算相似度,相似度最高的类别就是该图片的预测类别。
本发明带来的实质性效果是,提出了一种新颖的基于提示学习的快速图片分类技术,通过一种提示学习和图文多模态预训练模型来提高图片分类任务的性能,并且减少对于人工标注的数据量的需求,仅需要几十张标注数据就可达到相当高的准确率。
附图说明
图1是本发明的一种流程图。
具体实施方式
下面通过实施例,并结合附图,对本发明的技术方案作进一步具体的说明。
实施例:一种基于提示学习的快速图片分类方法,如图1所示,包括以下步骤:
S1、提示初始化;
S2、提示学习与模型训练;
S3、使用获得的模型进行图片分类。
所述步骤S1具体为:
构建一个M×N维的数组向量,数组向量包括M个N维的数组,N是文本编码器的输入向量维度,M为数组个数,然后对数组向量进行初始化作为提示向量;本方案中,提示为多个拥有可学习参数的上下文向量;例如Roberta预训练语言模型,则N=512维,那就是一个长度为512维的数字数组;
将提示向量与类别名称向量拼接,类别名称向量为要训练的分类任务中的具体类别名称的输入向量;类别名称向量通过文本预训练模型对预设文本分词后将分词字符转换为对应的特征向量数组得到;预设文本是和具体训练任务相关的,比如说这次的训练任务是训练模型分类猫和狗,那么训练数据中的每一个图片数据都要被标注是猫或者狗,而这次训练任务中生成类别名称向量的预设文本也就是“猫”和“狗”,总的来说这个预设文本就是分类任务中的各个类别名称;
所述步骤S2具体为:
S201、将拼接后的向量输入到文本编码器中,得到每个类别的文本向量;通过图片编码器将训练图片转换成图片向量,训练图片为标注好的训练数据中各个类别的图片;这里文本编码器使用的是transformer架构的NLP模型,可以是bert、roberta、ernie等等,包括其他自定义的文本预训练模型也可作为替换;这里对图片编码器的模型结构也不进行限制,作为优选,这里使用的是残差网络ResNet的多层预训练模型,或者Vision Transformer架构的多层预训练模型,但也替换为其他能提取图片特征的图片预训练模型;
S202、通过以下公式进行图片分类计算:
Figure 608776DEST_PATH_IMAGE001
其中g(xi)表示第i个类别加上提示向量后经过编码器生成的文本向量,f是图片编码器生成的图片向量;K表示此次分类任务的类别总数,j表达总类别中的第j个类,y表示模型的预测结果,y=i即为模型预测结果为第i类,而p(y=i|x)就表示模型预测图片为第i类的概率;
通过上述公式(softmax函数,softmax函数计算每一类的相似度然后除以所有类相似度之和,从而把相似度之和控制为1,而每一个类的相似度也就可以百分之多少来表示)计算每个类别文本向量与图片向量的相似度之后取相似度最大的类别作为模型预测的类别;
S203、最后通过交叉熵损失函数来和真实类别计算损失,并且固定住图文多模态预训练模型的参数,只通过反向传播损失来训练提示向量。
提示向量经过初始化之后并不是固定不变,它的向量参数是会在模型使用训练数据的训练迭代过程中,根据反向传播时损失函数计算出的损失值进行动态调整的。
M为2的倍数,M随着训练任务的复杂度增加而增大,M的最大值不大于文本编码器的上下文长度。
向量初始化通过随机的方式进行:
在固定范围为每一个维度随机取一个浮点数。具体可以通过标准差为0.02,平均值为0的正态分布来随机取值。
向量初始化通过已有的提示文本模板生成:
设定一句提示文本,经过文本编码器的预处理后转换为已经预训练过的M×N的数组向量。预处理为对文本分词后将分词字符转换为对应的特征向量数组。
可以使用多种初始化方法来进行对比训练,选择准确率更高的一种。
提示向量也可以不仅仅是针对分类任务中的类别,可以添加预训练的负样本提示向量、干扰提示向量或其他未知类等来帮助模型学习如何区分干扰项,找到准确的类别,减少误报。
这里的实现方法是通过前述中的已有提示文本模板来实现的。具体分为两类:
1. 预设负样本提示向量:这个是和训练任务相关的需要根据训练任务预设。例如:和人相关的分类任务中,图片中的人不是完整的人,只是身体的一部分,这样会影响分类任务的效果,因为无法根据不完整的人判断具体类别。这时候预设一个“不完整的人”文本模板,然后通过已有的提示文本模板生成负样本提示向量;
2. 干扰提示向量:这个提示向量生成方法和1一样,但是与具体训练任务不相关,它会用一个“其他”文本模板,生成提示向量。作用在于只有训练数据只有正样本类别时,加入“其他”这个干扰类,帮助模型学会在遇到训练时未学过但不属于任何一个正样本类别的数据时,分到其他中,而不是强行分给某一个正样本类别。
生成以上的负样本或干扰提示向量后,相当于多了一个类别,那就需要给这个类提供对应的训练数据。这里生成训练数据的方法是从现有训练数据中,随机取一部分数据。数量等于各类别数据中最少的那一类的数量。然后从随机选中的每一个训练图片中随机剪裁出一个矩形部分作为负样本或干扰提示向量类的训练数据。具体方法为,以图片长和宽为最大值,随机生成一个矩形进行裁剪,矩形长和宽不大于图片长宽的70%, 不小于图片长宽的20%。
提示向量和类别名称向量按以下两种方式中的任意一种拼接:
x=[V]1[V]2…[V]M[CLASS]
x=[V]1[V]2…[V]M/2[CLASS] [V](M/2)+1…[V]M
其中,[V]m(m∈{1,2,…,M})是提示向量,[CLASS]是要训练的分类任务中的具体类别名称的输入向量。
第一种将类别名称向量拼接到提示向量后面,第二种是将类别名称向量拼接到提示向量中间。设提示向量为M•N维的数组向量,而类别名称向量是一个T•N的数组向量,那么拼接完蛋模型输入向量就是一个(M+T)•N的向量。
当训练数据及类别均少于各自的阈值的时候,所有图片类别共享同一个提示向量;当训练数据量和类别中的任意一个大于或等于各自的阈值的时候,每个图片类别分别生成并训练各自的提示向量。训练数据的阈值为100个,类别的阈值为5个。
所述步骤S203具体为:
对于二分类任务,交叉熵损失函数的具体形式如下:
Figure 688728DEST_PATH_IMAGE002
对于多分类任务, 交叉熵损失函数的具体形式如下:
Figure 42349DEST_PATH_IMAGE003
L表示N个训练数据的损失,也就是每个数据的损失Li除以N得到的平均数;具体单个训练数据的损失如后半段公式,yi表示第i个数据的真实类别,正类为1,负类为0,pi表示模型预测该数据为正类的概率;因此,如果数据的真实类别为正类,那么模型预测为正类时的概率越高,计算出的损失值就越低;C表示多分类任务的类别数;在反向传播时,通过对损失函数求导计算梯度,并将梯度传递到之前的网络结构的函数中,来通过调整函数中的权重参数实现降低损失,从而使模型学会预测正确的类别;本方案在训练过程中会选择固定住图文多模态预训练模型中的参数,因为预训练模型的参数是经过海量数据训练调整过的,不许再使用当前少量的分类训练数据去影响其中参数了。只需要使最外层初始化过的提示向量参数随训练进行调整就可以了。训练的结束条件为以下两种中的任意一种:(1)设定为迭代固定次数的全量数据训练过程后停止;(2)设定为损失值降低到损失阈值或者K个训练迭代后损失值未降低时停止训练,K为预设参数,一般为20。
二分类任务只有一个类别,模型只需要预测是或者不是,也就是正类或负类。而多分类任务是多个类别,也就是说模型需要对每个类别预测是或者不是,就相当于在二分类任务的基础上对所有类别的损失进行了一个求和。
所述步骤S3具体为:
使用上述训练最终保存下来的提示向量,通过多模态模型的文本编码器生成各类别文本向量,然后与图片编码器生成的图片向量计算相似度,相似度最高的类别就是该图片的预测类别。
经验证,在平均每类300的7分类人物行为数据上进行小样本训练后,模型对分类准确率达到90%以上。
本文中所描述的具体实施例仅仅是对本发明精神作举例说明。本发明所属技术领域的技术人员可以对所描述的具体实施例做各种各样的修改或补充或采用类似的方式替代,但并不会偏离本发明的精神或者超越所附权利要求书所定义的范围。
尽管本文较多地使用了提示学习、多模态预训练模型、损失函数等术语,但并不排除使用其它术语的可能性。使用这些术语仅仅是为了更方便地描述和解释本发明的本质;把它们解释成任何一种附加的限制都是与本发明精神相违背的。

Claims (9)

1.一种基于提示学习的快速图片分类方法,其特征在于,包括以下步骤:
S1、提示初始化;
S2、提示学习与模型训练;
S3、使用获得的模型进行图片分类;
所述步骤S1具体为:
构建一个M×N维的数组向量,数组向量包括M个N维的数组,N是文本编码器的输入向量维度,M为数组个数,然后对数组向量进行初始化作为提示向量;
将提示向量与类别名称向量拼接,类别名称向量为要训练的分类任务中的具体类别名称的输入向量;类别名称向量通过文本编码器对预设文本分词后将分词字符转换为对应的特征向量数组得到;
所述步骤S2具体为:
S201、将拼接后的向量输入到文本编码器中,得到每个类别的文本向量;通过图片编码器将训练图片转换成图片向量,训练图片为标注好的训练数据中各个类别的图片;
S202、通过以下公式进行图片分类计算:
Figure 794720DEST_PATH_IMAGE001
其中g(xi)表示第i个类别加上提示向量后经过编码器生成的文本向量,f是图片编码器生成的图片向量;K表示此次分类任务的类别总数,j表达总类别中的第j个类,y表示模型的预测结果,y=i即为模型预测结果为第i类,而p(y=i|x)就表示模型预测图片为第i类的概率;
S203、最后通过交叉熵损失函数来和真实类别计算损失,并且固定住图文多模态预训练模型的参数,只通过反向传播损失来训练提示向量。
2.根据权利要求1所述的一种基于提示学习的快速图片分类方法,其特征在于,M为2的倍数,M随着训练任务的复杂度增加而增大,M的最大值不大于文本编码器的上下文长度。
3.根据权利要求2所述的一种基于提示学习的快速图片分类方法,其特征在于,向量初始化通过随机的方式进行:
在固定范围为每一个维度随机取一个浮点数。
4.根据权利要求2所述的一种基于提示学习的快速图片分类方法,其特征在于,向量初始化通过已有的提示文本模板生成:
设定一句提示文本,经过文本编码器的预处理后转换为已经预训练过的M×N的数组向量。
5.根据权利要求3或4所述的一种基于提示学习的快速图片分类方法,其特征在于,提示向量中添加预训练的负样本提示向量和/或干扰提示向量来帮助模型学习如何区分干扰项。
6.根据权利要求5所述的一种基于提示学习的快速图片分类方法,其特征在于,提示向量和类别名称向量按以下两种方式中的任意一种拼接:
x=[V]1[V]2…[V]M[CLASS]
x=[V]1[V]2…[V]M/2[CLASS] [V](M/2)+1…[V]M
其中,[V]m(m∈{1,2,…,M})是提示向量,[CLASS]是要训练的分类任务中的具体类别名称的输入向量。
7.根据权利要求6所述的一种基于提示学习的快速图片分类方法,其特征在于,当训练数据及类别均少于各自的阈值的时候,所有图片类别共享同一个提示向量;当训练数据量和类别中的任意一个大于或等于各自的阈值的时候,每个图片类别分别生成并训练各自的提示向量。
8.根据权利要求1所述的一种基于提示学习的快速图片分类方法,其特征在于,所述步骤S203具体为:
对于二分类任务,交叉熵损失函数的具体形式如下:
Figure 39756DEST_PATH_IMAGE002
对于多分类任务, 交叉熵损失函数的具体形式如下:
Figure 350652DEST_PATH_IMAGE003
L表示N个训练数据的损失,也就是每个数据的损失Li除以N得到的平均数;具体单个训练数据的损失如后半段公式,yi表示第i个数据的真实类别,正类为1,负类为0,pi表示模型预测该数据为正类的概率;在多分类任务的损失函数中,C表示多分类任务的类别数,也就是在二分类损失函数的基础上,对模型在每个类别的预测结果上进行了一个求和来计算损失;在反向传播时,通过对损失函数求导计算梯度,并将梯度传递到之前的网络结构的函数中,来通过调整函数中的权重参数实现降低损失,从而使模型学会预测正确的类别;训练的结束条件为以下两种中的任意一种:(1)设定为迭代固定次数的全量数据训练过程后停止;(2)设定为损失值降低到损失阈值或者K个训练迭代后损失值未降低时停止训练,K为预设参数。
9.根据权利要求8所述的一种基于提示学习的快速图片分类方法,其特征在于,所述步骤S3具体为:
使用上述训练最终保存下来的提示向量,通过多模态模型的文本编码器生成各类别文本向量,然后与图片编码器生成的图片向量计算相似度,相似度最高的类别就是该图片的预测类别。
CN202210062188.7A 2022-01-20 2022-01-20 一种基于提示学习的快速图片分类方法 Active CN114090780B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202210062188.7A CN114090780B (zh) 2022-01-20 2022-01-20 一种基于提示学习的快速图片分类方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202210062188.7A CN114090780B (zh) 2022-01-20 2022-01-20 一种基于提示学习的快速图片分类方法

Publications (2)

Publication Number Publication Date
CN114090780A true CN114090780A (zh) 2022-02-25
CN114090780B CN114090780B (zh) 2022-05-31

Family

ID=80308663

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202210062188.7A Active CN114090780B (zh) 2022-01-20 2022-01-20 一种基于提示学习的快速图片分类方法

Country Status (1)

Country Link
CN (1) CN114090780B (zh)

Cited By (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN114912522A (zh) * 2022-05-11 2022-08-16 北京百度网讯科技有限公司 信息分类方法和装置
CN116304066A (zh) * 2023-05-23 2023-06-23 中国人民解放军国防科技大学 一种基于提示学习的异质信息网络节点分类方法
CN116778264A (zh) * 2023-08-24 2023-09-19 鹏城实验室 基于类增学习的对象分类方法、图像分类方法及相关设备
CN116844161A (zh) * 2023-09-04 2023-10-03 深圳市大数据研究院 一种基于分组提示学习的细胞检测分类方法及系统
CN117689961A (zh) * 2024-02-02 2024-03-12 深圳大学 视觉识别模型训练、视觉识别方法、系统、终端及介质

Citations (12)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN106991197A (zh) * 2017-05-30 2017-07-28 海南大学 一种基于知识图谱的目标驱动的学习点和学习路径推荐方法
CN109190680A (zh) * 2018-08-11 2019-01-11 复旦大学 基于深度学习的医疗药品图像的检测与分类方法
CN109934261A (zh) * 2019-01-31 2019-06-25 中山大学 一种知识驱动参数传播模型及其少样本学习方法
CN110046656A (zh) * 2019-03-28 2019-07-23 南京邮电大学 基于深度学习的多模态场景识别方法
JP2020173415A (ja) * 2019-04-09 2020-10-22 株式会社スプリングボード 教材提示システム及び教材提示方法
CN112149564A (zh) * 2020-09-23 2020-12-29 上海交通大学烟台信息技术研究院 一种基于小样本学习的面容分类识别系统
CN112633419A (zh) * 2021-03-09 2021-04-09 浙江宇视科技有限公司 小样本学习方法、装置、电子设备和存储介质
CN113313084A (zh) * 2021-07-28 2021-08-27 中国航空油料集团有限公司 一种基于深度学习的睡岗检测方法
CN113570512A (zh) * 2021-02-01 2021-10-29 腾讯科技(深圳)有限公司 一种图像数据处理方法、计算机及可读存储介质
CN113673242A (zh) * 2021-08-20 2021-11-19 之江实验室 一种基于k邻近结点算法和对比学习的文本分类方法
CN113837309A (zh) * 2021-02-08 2021-12-24 宏龙科技(杭州)有限公司 一种基于变分自编码器的文本分类方法
CN113887627A (zh) * 2021-09-30 2022-01-04 北京百度网讯科技有限公司 噪音样本的识别方法、装置、电子设备以及存储介质

Patent Citations (12)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN106991197A (zh) * 2017-05-30 2017-07-28 海南大学 一种基于知识图谱的目标驱动的学习点和学习路径推荐方法
CN109190680A (zh) * 2018-08-11 2019-01-11 复旦大学 基于深度学习的医疗药品图像的检测与分类方法
CN109934261A (zh) * 2019-01-31 2019-06-25 中山大学 一种知识驱动参数传播模型及其少样本学习方法
CN110046656A (zh) * 2019-03-28 2019-07-23 南京邮电大学 基于深度学习的多模态场景识别方法
JP2020173415A (ja) * 2019-04-09 2020-10-22 株式会社スプリングボード 教材提示システム及び教材提示方法
CN112149564A (zh) * 2020-09-23 2020-12-29 上海交通大学烟台信息技术研究院 一种基于小样本学习的面容分类识别系统
CN113570512A (zh) * 2021-02-01 2021-10-29 腾讯科技(深圳)有限公司 一种图像数据处理方法、计算机及可读存储介质
CN113837309A (zh) * 2021-02-08 2021-12-24 宏龙科技(杭州)有限公司 一种基于变分自编码器的文本分类方法
CN112633419A (zh) * 2021-03-09 2021-04-09 浙江宇视科技有限公司 小样本学习方法、装置、电子设备和存储介质
CN113313084A (zh) * 2021-07-28 2021-08-27 中国航空油料集团有限公司 一种基于深度学习的睡岗检测方法
CN113673242A (zh) * 2021-08-20 2021-11-19 之江实验室 一种基于k邻近结点算法和对比学习的文本分类方法
CN113887627A (zh) * 2021-09-30 2022-01-04 北京百度网讯科技有限公司 噪音样本的识别方法、装置、电子设备以及存储介质

Non-Patent Citations (2)

* Cited by examiner, † Cited by third party
Title
MOHD. FARHAN ISRAK SOUMIK ET AL.: ""Improved Transfer Learning Based Deep Learning Model For Breast Cancer Histopathological Image Classification"", 《INTERNATIONAL CONFERENCE ON AUTOMATION, CONTROL AND MECHATRONICS FOR INDUSTRY 4.0》 *
汪旗 等: ""基于多示例学习的图像分类算法"", 《计算机技术与发展》 *

Cited By (10)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN114912522A (zh) * 2022-05-11 2022-08-16 北京百度网讯科技有限公司 信息分类方法和装置
CN114912522B (zh) * 2022-05-11 2024-04-05 北京百度网讯科技有限公司 信息分类方法和装置
CN116304066A (zh) * 2023-05-23 2023-06-23 中国人民解放军国防科技大学 一种基于提示学习的异质信息网络节点分类方法
CN116304066B (zh) * 2023-05-23 2023-08-22 中国人民解放军国防科技大学 一种基于提示学习的异质信息网络节点分类方法
CN116778264A (zh) * 2023-08-24 2023-09-19 鹏城实验室 基于类增学习的对象分类方法、图像分类方法及相关设备
CN116778264B (zh) * 2023-08-24 2023-12-12 鹏城实验室 基于类增学习的对象分类方法、图像分类方法及相关设备
CN116844161A (zh) * 2023-09-04 2023-10-03 深圳市大数据研究院 一种基于分组提示学习的细胞检测分类方法及系统
CN116844161B (zh) * 2023-09-04 2024-03-05 深圳市大数据研究院 一种基于分组提示学习的细胞检测分类方法及系统
CN117689961A (zh) * 2024-02-02 2024-03-12 深圳大学 视觉识别模型训练、视觉识别方法、系统、终端及介质
CN117689961B (zh) * 2024-02-02 2024-05-07 深圳大学 视觉识别模型训练、视觉识别方法、系统、终端及介质

Also Published As

Publication number Publication date
CN114090780B (zh) 2022-05-31

Similar Documents

Publication Publication Date Title
CN114090780B (zh) 一种基于提示学习的快速图片分类方法
CN110609891B (zh) 一种基于上下文感知图神经网络的视觉对话生成方法
CN109308318B (zh) 跨领域文本情感分类模型的训练方法、装置、设备及介质
CN113656570B (zh) 基于深度学习模型的视觉问答方法及装置、介质、设备
CN111708882B (zh) 基于Transformer的中文文本信息缺失的补全方法
CN110321563B (zh) 基于混合监督模型的文本情感分析方法
CN114298158A (zh) 一种基于图文线性组合的多模态预训练方法
CN110188195B (zh) 一种基于深度学习的文本意图识别方法、装置及设备
CN111639186B (zh) 动态嵌入投影门控的多类别多标签文本分类模型及装置
CN110795549B (zh) 短文本对话方法、装置、设备及存储介质
CN113268609A (zh) 基于知识图谱的对话内容推荐方法、装置、设备及介质
WO2023137911A1 (zh) 基于小样本语料的意图分类方法、装置及计算机设备
CN111966812A (zh) 一种基于动态词向量的自动问答方法和存储介质
CN111274790A (zh) 基于句法依存图的篇章级事件嵌入方法及装置
CN114020906A (zh) 基于孪生神经网络的中文医疗文本信息匹配方法及系统
CN116049387A (zh) 一种基于图卷积的短文本分类方法、装置、介质
CN112988970A (zh) 一种服务于智能问答系统的文本匹配算法
CN114417872A (zh) 一种合同文本命名实体识别方法及系统
CN117313728A (zh) 实体识别方法、模型训练方法、装置、设备和存储介质
CN113869005A (zh) 一种基于语句相似度的预训练模型方法和系统
US20240037335A1 (en) Methods, systems, and media for bi-modal generation of natural languages and neural architectures
Wakchaure et al. A scheme of answer selection in community question answering using machine learning techniques
KR102458783B1 (ko) 일반화된 제로샷 객체 인식 장치 및 일반화된 제로샷 객체 인식 방법
CN114357166A (zh) 一种基于深度学习的文本分类方法
CN114003773A (zh) 一种基于自构建多场景的对话追踪方法

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant