CN113902256A - 训练标签预测模型的方法、标签预测方法和装置 - Google Patents
训练标签预测模型的方法、标签预测方法和装置 Download PDFInfo
- Publication number
- CN113902256A CN113902256A CN202111059586.5A CN202111059586A CN113902256A CN 113902256 A CN113902256 A CN 113902256A CN 202111059586 A CN202111059586 A CN 202111059586A CN 113902256 A CN113902256 A CN 113902256A
- Authority
- CN
- China
- Prior art keywords
- sample
- label
- support
- query
- predicted
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000000034 method Methods 0.000 title claims abstract description 63
- 238000012549 training Methods 0.000 title claims abstract description 57
- 239000013598 vector Substances 0.000 claims description 51
- 238000012545 processing Methods 0.000 claims description 34
- 230000007246 mechanism Effects 0.000 claims description 33
- 238000013507 mapping Methods 0.000 claims description 20
- 230000006399 behavior Effects 0.000 claims description 8
- 238000004458 analytical method Methods 0.000 claims description 2
- 238000013528 artificial neural network Methods 0.000 description 10
- 238000010586 diagram Methods 0.000 description 10
- 230000006870 function Effects 0.000 description 8
- 230000008569 process Effects 0.000 description 8
- 238000003860 storage Methods 0.000 description 8
- 238000010606 normalization Methods 0.000 description 5
- 239000000126 substance Substances 0.000 description 5
- 230000000694 effects Effects 0.000 description 4
- 238000004590 computer program Methods 0.000 description 3
- 238000013527 convolutional neural network Methods 0.000 description 3
- 238000005516 engineering process Methods 0.000 description 3
- 230000002776 aggregation Effects 0.000 description 2
- 238000004220 aggregation Methods 0.000 description 2
- 230000008901 benefit Effects 0.000 description 2
- 238000004364 calculation method Methods 0.000 description 2
- 238000009826 distribution Methods 0.000 description 2
- 238000012423 maintenance Methods 0.000 description 2
- 238000005259 measurement Methods 0.000 description 2
- 230000003287 optical effect Effects 0.000 description 2
- 238000005070 sampling Methods 0.000 description 2
- 230000009471 action Effects 0.000 description 1
- 238000013459 approach Methods 0.000 description 1
- 238000013473 artificial intelligence Methods 0.000 description 1
- 210000004556 brain Anatomy 0.000 description 1
- 230000008859 change Effects 0.000 description 1
- 238000004883 computer application Methods 0.000 description 1
- 238000011161 development Methods 0.000 description 1
- 238000011478 gradient descent method Methods 0.000 description 1
- 230000006872 improvement Effects 0.000 description 1
- 238000002372 labelling Methods 0.000 description 1
- 238000012986 modification Methods 0.000 description 1
- 230000004048 modification Effects 0.000 description 1
- 238000003058 natural language processing Methods 0.000 description 1
- 238000003062 neural network model Methods 0.000 description 1
- 239000013307 optical fiber Substances 0.000 description 1
- 238000005457 optimization Methods 0.000 description 1
- 230000000750 progressive effect Effects 0.000 description 1
- 230000000135 prohibitive effect Effects 0.000 description 1
- 230000000306 recurrent effect Effects 0.000 description 1
- 239000004065 semiconductor Substances 0.000 description 1
- 238000006467 substitution reaction Methods 0.000 description 1
- 230000009466 transformation Effects 0.000 description 1
- 238000012795 verification Methods 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06Q—INFORMATION AND COMMUNICATION TECHNOLOGY [ICT] SPECIALLY ADAPTED FOR ADMINISTRATIVE, COMMERCIAL, FINANCIAL, MANAGERIAL OR SUPERVISORY PURPOSES; SYSTEMS OR METHODS SPECIALLY ADAPTED FOR ADMINISTRATIVE, COMMERCIAL, FINANCIAL, MANAGERIAL OR SUPERVISORY PURPOSES, NOT OTHERWISE PROVIDED FOR
- G06Q10/00—Administration; Management
- G06Q10/06—Resources, workflows, human or project management; Enterprise or organisation planning; Enterprise or organisation modelling
- G06Q10/063—Operations research, analysis or management
- G06Q10/0635—Risk analysis of enterprise or organisation activities
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/22—Matching criteria, e.g. proximity measures
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/044—Recurrent networks, e.g. Hopfield networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/047—Probabilistic or stochastic networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/084—Backpropagation, e.g. using gradient descent
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06Q—INFORMATION AND COMMUNICATION TECHNOLOGY [ICT] SPECIALLY ADAPTED FOR ADMINISTRATIVE, COMMERCIAL, FINANCIAL, MANAGERIAL OR SUPERVISORY PURPOSES; SYSTEMS OR METHODS SPECIALLY ADAPTED FOR ADMINISTRATIVE, COMMERCIAL, FINANCIAL, MANAGERIAL OR SUPERVISORY PURPOSES, NOT OTHERWISE PROVIDED FOR
- G06Q10/00—Administration; Management
- G06Q10/04—Forecasting or optimisation specially adapted for administrative or management purposes, e.g. linear programming or "cutting stock problem"
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06Q—INFORMATION AND COMMUNICATION TECHNOLOGY [ICT] SPECIALLY ADAPTED FOR ADMINISTRATIVE, COMMERCIAL, FINANCIAL, MANAGERIAL OR SUPERVISORY PURPOSES; SYSTEMS OR METHODS SPECIALLY ADAPTED FOR ADMINISTRATIVE, COMMERCIAL, FINANCIAL, MANAGERIAL OR SUPERVISORY PURPOSES, NOT OTHERWISE PROVIDED FOR
- G06Q10/00—Administration; Management
- G06Q10/06—Resources, workflows, human or project management; Enterprise or organisation planning; Enterprise or organisation modelling
- G06Q10/063—Operations research, analysis or management
- G06Q10/0639—Performance analysis of employees; Performance analysis of enterprise or organisation operations
- G06Q10/06393—Score-carding, benchmarking or key performance indicator [KPI] analysis
Landscapes
- Engineering & Computer Science (AREA)
- Business, Economics & Management (AREA)
- Physics & Mathematics (AREA)
- Theoretical Computer Science (AREA)
- Human Resources & Organizations (AREA)
- General Physics & Mathematics (AREA)
- Strategic Management (AREA)
- Economics (AREA)
- Data Mining & Analysis (AREA)
- Evolutionary Computation (AREA)
- Artificial Intelligence (AREA)
- Entrepreneurship & Innovation (AREA)
- Life Sciences & Earth Sciences (AREA)
- General Engineering & Computer Science (AREA)
- Health & Medical Sciences (AREA)
- Development Economics (AREA)
- Software Systems (AREA)
- Biophysics (AREA)
- Mathematical Physics (AREA)
- Molecular Biology (AREA)
- General Health & Medical Sciences (AREA)
- Computing Systems (AREA)
- Computational Linguistics (AREA)
- Biomedical Technology (AREA)
- Marketing (AREA)
- Game Theory and Decision Science (AREA)
- Operations Research (AREA)
- Quality & Reliability (AREA)
- Tourism & Hospitality (AREA)
- General Business, Economics & Management (AREA)
- Educational Administration (AREA)
- Probability & Statistics with Applications (AREA)
- Evolutionary Biology (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Bioinformatics & Computational Biology (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Information Retrieval, Db Structures And Fs Structures Therefor (AREA)
Abstract
本说明书实施例提供了一种训练标签预测模型的方法、标签预测方法和装置。首先获取样本集合,所述样本集合中的各样本包括对象的特征数据以及对该对象标注的标签;然后从所述样本集合中确定支持集合和查询集合;再利用所述支持集合和查询集合训练标签预测模型;其中,将所述支持集合和查询集合中的查询样本输入所述标签预测模型,由所述标签预测模型利用输入的查询样本与支持集合中各支持样本之间的特征相似度以及各支持样本的标签,预测输入的查询样本的标签;训练目标为最小化预测结果与查询样本被标注的标签之间的差异。
Description
技术领域
本说明书一个或多个实施例涉及计算机应用技术中的人工智能技术领域,特别涉及训练标签预测模型的方法、标签预测方法和装置。
背景技术
在很多业务场景下,很多识别类型通常只具有少量的样本。例如在风控技术场景下,攻击手法已经越来越多的呈现出多样化、小批量和频繁突发的趋势。许多风险类型在实际业务中样本数量非常少。因此就亟需基于小样本进行类别标签学习和预测的方法。
发明内容
本说明书一个或多个实施例描述了一种训练标签预测模型的方法、标签预测方法和装置,以便于实现基于小样本的标签学习和预测。
根据第一方面,提供了一种训练标签预测模型的方法,包括:
获取样本集合,所述样本集合中的各样本包括对象的特征数据以及对该对象标注的标签;
从所述样本集合中确定支持集合和查询集合;
利用所述支持集合和查询集合训练标签预测模型;其中,将所述支持集合和查询集合中的查询样本输入所述标签预测模型,由所述标签预测模型利用输入的查询样本与支持集合中各支持样本之间的特征相似度以及各支持样本的标签,预测输入的查询样本的标签;训练目标为最小化预测结果与查询样本被标注的标签之间的差异。
在一个实施例中,所述样本集合包括一个以上类型的标签预测任务的样本集合;
利用所述支持集合和查询集合训练标签预测模型包括:
交替、顺序或随机选取各标签预测任务,利用所选取标签预测任务的支持集合和查询集合对所述标签预测模型进行迭代更新,直至达到预设的训练停止条件。
在另一个实施例中,所述标签预测模型包括Transformer网络和预测网络;
所述Transformer网络用以通过注意力机制对输入的各样本中的特征数据进行处理,得到各样本的特征向量表示;
所述预测网络用以将所述输入的查询样本与支持集合中各支持样本之间的特征相似度分别作为各支持样本的映射权重,利用各支持样本的映射权重对各支持样本的标签进行加权处理,得到对该查询样本标签的预测结果;其中,查询样本与支持样本之间的特征相似度由查询样本的特征向量表示和支持样本的特征向量表示之间的距离确定。
在一个实施例中,所述注意力机制包括多头注意力机制。
在另一个实施例中,该方法应用于风险识别;
各样本包括用户的行为特征数据以及对该用户标注的风险信息标签;其中所述风险信息标签包括:是否具有预设类型风险的标签、预设类型风险的等级标签或者风险类型标签。
根据第二方面,提供了一种标签预测方法,包括:
获取待预测对象的特征数据以及确定支持集合,所述支持集合中的各支持样本包括样本对象的特征数据以及对该样本对象标注的标签;
将所述待预测对象的特征数据输入标签预测模型,由所述标签预测模型利用所述待预测对象与所述支持集合中各支持样本之间的特征相似度以及各支持样本的标签,预测所述待预测对象的标签。
在一个实施例中,所述确定支持集合包括:
确定本次预测任务的类型,确定与本次预测任务的类型对应的支持集合。
在另一个实施例中,所述标签预测模型包括Transformer网络和预测网络;
所述Transformer网络用以通过注意力机制对所述待预测对象的特征数据以及各支持样本中的特征数据进行处理,得到待预测对象的特征向量表示以及各支持样本的特征向量表示;
所述预测网络用以利用待预测对象的特征向量表示和各支持样本的特征向量表示之间的距离,确定待预测对象与各支持样本之间的特征相似度;利用待预测对象与各支持样本之间的特征相似度分别确定各支持样本的映射权重;利用各支持样本的映射权重对各支持样本的标签进行加权处理,得到对所述待预测对象标签的预测结果。
在一个实施例中,所述注意力机制包括多头注意力机制。
在另一个实施例中,该方法应用于风险识别;
所述待识别对象的特征数据包括用户的行为特征数据;
所述待预测对象的标签包括用户的风险信息标签;其中所述风险信息标签包括是否具有预设类型风险的标签、预设类型风险的等级标签或者风险类型标签。
根据第三方面,提供了一种训练标签预测模型的装置,包括:
样本获取单元,被配置为获取样本集合,所述样本集合中的各样本包括对象的特征数据以及对该对象标注的标签;
样本确定单元,被配置为从所述样本集合中确定支持集合和查询集合;
模型训练单元,被配置为利用所述支持集合和查询集合训练标签预测模型;其中,将所述支持集合和查询集合中的查询样本输入所述标签预测模型,由所述标签预测模型利用输入的查询样本与支持集合中各支持样本之间的特征相似度以及各支持样本的标签,预测输入的查询样本的标签;训练目标为最小化预测结果与查询样本被标注的标签之间的差异。
在一个实施例中,所述样本集合包括一个以上类型的标签预测任务的样本集合;
所述模型训练单元,具体被配置为交替、顺序或随机选取各标签预测任务,利用所选取标签预测任务的支持集合和查询集合对所述标签预测模型进行迭代更新,直至达到预设的训练停止条件。
在另一个实施例中,所述标签预测模型包括Transformer网络和预测网络;
所述Transformer网络用以通过注意力机制对输入的各样本中的特征数据进行处理,得到各样本的特征向量表示;
所述预测网络用以将所述输入的查询样本与支持集合中各支持样本之间的特征相似度分别作为各支持样本的映射权重,利用各支持样本的映射权重对各支持样本的标签进行加权处理,得到对该查询样本标签的预测结果;其中,查询样本与支持样本之间的特征相似度由查询样本的特征向量表示和支持样本的特征向量表示之间的距离确定。
在一个实施例中,所述注意力机制包括多头注意力机制。
在另一个实施例中,该装置应用于风险识别;
各样本包括用户的行为特征数据以及对该用户标注的风险信息标签;其中所述风险信息标签包括:是否具有预设类型风险的标签、预设类型风险的等级标签或者风险类型标签。
根据第四方面,提供了一种标签预测装置,包括:
数据获取单元,被配置为获取待预测对象的特征数据;
集合确定单元,被配置为确定支持集合,所述支持集合中的各支持样本包括样本对象的特征数据以及对该样本对象标注的标签;
标签预测单元,被配置为将所述待预测对象的特征数据输入标签预测模型,由所述标签预测模型利用所述待预测对象与所述支持集合中各支持样本之间的特征相似度以及各支持样本的标签,预测所述待预测对象的标签。
在一个实施例中,所述集合确定单元,具体被配置为确定本次预测任务的类型,确定与本次预测任务的类型对应的支持集合。
在另一个实施例中,所述标签预测模型包括Transformer网络和预测网络;
所述Transformer网络用以通过注意力机制对所述待预测对象的特征数据以及各支持样本中的特征数据进行处理,得到待预测对象的特征向量表示以及各支持样本的特征向量表示;
所述预测网络用以利用待预测对象的特征向量表示和各支持样本的特征向量表示之间的距离,确定待预测对象与各支持样本之间的特征相似度;利用待预测对象与各支持样本之间的特征相似度分别确定各支持样本的映射权重;利用各支持样本的映射权重对各支持样本的标签进行加权处理,得到对所述待预测对象标签的预测结果。
在一个实施例中,所述注意力机制包括多头注意力机制。
根据第五方面,提供了一种计算设备,包括存储器和处理器,所述存储器中存储有可执行代码,所述处理器执行所述可执行代码时,实现第一方面的方法。
根本说明书实施例提供的方法和装置适用于小样本的标签学习,具有较好的模型效果。
附图说明
为了更清楚地说明本发明实施例或现有技术中的技术方案,下面将对实施例或现有技术描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图是本发明的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1示出了根据一个实施例的训练标签预测模型的方法流程图;
图2示出了根据一个实施例的标签预测模型的示意图;
图3示出了根据一个实施例的Transformer网络的示意图;
图4示出了根据一个实施例的标签预测方法的流程图;
图5示出了根据另一个实施例的标签预测模型的示意图;
图6示出了根据一个实施例的训练标签预测模型的装置结构图;
图7示出了根据一个实施例的标签预测装置结构图。
具体实施方式
下面结合附图,对本说明书提供的方案进行描述。
传统基于小样本的标签预测模型的训练方法主要采用以下两种:
第一种是在大规模样本上进行预训练,然后针对小样本采用fine-tune(精调)方式进行优化。但这种方法需要依赖大规模样本来学习最优模型,小样本上模型学习到的变化较小,对新类型的拟合较差。
第二种是在大规模样本上训练得到的模型基础上,重新针对各新类型训练新的模型。但这种方式对已有类型的数据信息遗忘较多,并且必须与旧版模型(即在大规模样本上训练得到的模型)配合使用。
显然,传统的两种方式均存在针对小样本学习的模型效果差的问题。而本说明书提供的方式则与传统实现方式的思路完全不同,引入支持集合(support set)的概念,基于度量学习的方式实现针对小样本学习的标签预测模型。下面描述本说明书所提供方法的具体实现方式。
图1示出了根据一个实施例的训练标签预测模型的方法流程图。可以理解,该方法可以通过任何具有计算、处理能力的装置、设备、平台、设备集群来执行。如图1所示,该方法可以包括:
步骤101,获取样本集合,该样本集合中的各样本包括对象的特征数据以及对该对象标注的标签。
步骤103,从样本集合中确定支持集合和查询集合。
步骤105,利用支持集合和查询集合训练标签预测模型;其中,将支持集合和查询集合中的查询样本输入标签预测模型,由标签预测模型利用输入的查询样本与支持集合中各支持样本之间的特征相似度以及各支持样本的标签,预测输入的查询样本的标签;训练目标为最小化预测结果与查询样本被标注的标签之间的差异。
在图1所示的方法能够从小样本中确定支持集合,利用度量学习的方式学习查询样本和支持样本之间特征相似度对标签预测的影响,适用于小样本的标签学习,具有较好的模型效果。
下面描述图1所示的各个步骤的执行方式。首先结合实施例对上述步骤101即“获取样本集合”进行详细描述。
在很多标签预测的实际业务场景中,可能存在N种不同类型的标签预测,但每一种类型的样本数量很少即小样本。举个例子,在风险识别场景下存在N种不同类型的风险,每种类型的黑样本量为m,其中黑样本即风险标签的样本数量少,白样本数量相对较多。
如果N的取值为1,则直接获取该类型下的样本集合即可,样本集合T表示为:
其中,xi表示对象的特征数据,例如在上述风险识别场景下,对象可以为用户,对象的特征数据可以为用户的行为特征数据。yi表示对该对象标注的标签,例如可以采用标签1表示有风险,是风险用户;采用标签0表示无风险,是无风险用户。除了采用标签来表示是否具有风险之外,还可以采用标签来表示标签等级,例如高风险、中风险、低风险和无风险等,甚至还可以是风险类型标签。
但在很多场景下,上述N的取值大于1,风险类型可能还很多即N较大。如果为每种不同类型的标签预测任务均单独构建模型,则成本过高。本申请的建模目标是希望在不改变模型结构和参数的条件下,使得模型在多种风险类型的标签预测上都适用。这种情况下,整个样本集合T就可以表示为:
T={T1,T2,…,TN} (2)
其中,T1表示第一种类型的标签预测任务对应的样本集合,T2表示第二中类型的标签预测任务对应的样本集合,依次类推。每一个样本集合中均包含表示为(xi,yi)的样本。
另外,以风险识别为例,为了均衡样本集合中的黑白样本,可以对大量的无风险样本集即白样本进行合理范围内的降采样,例如选择黑白样本比例在1:100左右,或者使用更少的白样本。
下面结合实施例对上述步骤103即“从样本集合中确定支持集合和查询集合”进行详细描述。
在本步骤中,可以将样本集合中的一部分作为支持集合,另一部分作为查询集合。
作为一种优选的实施方式,可以基于专家经验从样本集合中筛选出高质量的样本构成支持集合S表示为:
其中支持集合中支持样本的数量k可以依赖人为经验设置,也可以依据一定的比例设置。
若存在N种不同类型,则可以从各类型的样本集合中分别确定各类型的支持集合和查询集合。
下面结合实施例对上述步骤105即“利用支持集合和查询集合训练标签预测模型”进行详细描述。
首先对标签预测模型的实现机理进行描述。本说明书实施例中提供的标签预测模型的每次输入为一个查询样本(表示为)以及支持集合S。标签预测模型利用查询样本与支持集合中各支持样本之间的特征相似度以及各支持样本的标签,预测输入的查询样本的标签。训练目标为最小化预训练结果与查询样本标注的标签之间的差异。
作为一种优选的实施方式,上述标签预测模型的结构可以如图2中所示,主要包括Transformer网络和预测网络。
其中,Transformer网络用以通过注意力机制对输入的各样本中的特征数据进行处理,得到各样本的特征向量表示。
Transformer是Google Brain团队在2017年提出的神经网络模型,其主要用于解决自然语言处理相关问题。相比于CNN(Convolutional Neural Networks,卷积神经网络)、RNN(Recurrent Neural Network,循环神经网络)等其他神经网络而言,Transformer使用注意力(Attention)机制,不受序列结构的限制,训练推理过程的并行化程度更高。在本说明书实施例中优选Transformer来得到各样本的特征向量表示,但除了Transformer之外,也可以采用诸如CNN、RNN等其他神经网络进行特征向量表示的提取。
Transformer网络是目前已经广泛采用的比较成熟的网络,如图3中所示,主要由注意力网络和前馈神经网络两个主要模块组成,在模块之间由归一化与残差模块连接。
其中,Transformer网络首先对输入的样本的特征数据进行embedding(嵌入)处理,可以包括诸如字符嵌入、位置嵌入、段落嵌入等等。目的是提取出特征的原始特征表示。该部分不做详述。
其中,Wi Q、Wi K、Wi V是对X进行变换编码的模型参数,编码后每一个样本都有其对应的query、key、value向量,分别表示为Q、K和V。上述的Attention()可以表示为:
但作为一种优选的实施方式,上述注意力网络可以采用多头注意力机制,多头注意力机制能够使得任意两个样本之间都进行关联聚合,且为不同的特征维度应用不同的聚合方式。这种情况下,注意力网络的作用可以表示为:
Multihead(Q,K,V)=concat(head1,head2,...headh) (5)
其中,Concat()表示进行拼接。WO是多头注意力处理中的一个模型参数。h为注意力头的数量,可以采用经验值。
headi=Attention(XWi Q,XWi K,XWi V) (6)
通常单个样本独立提取特征向量的表示随机性强,存在偏差,这对于数量本身就少的小样本场景而言更不可接受。而本说明书通过Transformer网络采用上述注意力网络,考虑支持集合内所有样本的表示之间的相对关系,调整向量以减小偏差。
注意力网络的输出经过归一化与残差模块的归一化处理,再利用原始特征表示进行残差处理,得到经过注意力网络处理后的各样本的特征向量表示,经过前馈神经网络的非线性变换,将各样本的特征向量表示映射到通用的度量空间。前馈神经网络的处理可以表示为:
FFN(z)=max(0,zW1+b1)W2+b2 (7)
其中,z代表输入前馈神经网络的各样本的特征向量表示,W1、W2、b1和b2为前馈神经网络的模型参数。
经过前馈神经网络处理后,各样本的特征向量表示再经过归一化与残差模块的归一化处理,再利用输入前馈神经网络的特征向量表示进行残差处理,得到经过Transformer网络处理后的各样本的特征向量表示。为了简化上述各公式,将经过Transformer网络处理后得到的查询样本的特征向量表示记为将经过Transformer网络处理后得到的支持样本的特征向量表示记为
下面来描述标签预测模型中的预测网络。预测网络用以将输入的查询样本与支持集合中各支持样本之间的特征相似度分别作为各支持样本的映射权重,利用各支持样本的映射权重对各支持样本的标签进行加权处理,得到对该查询样本标签的预测结果;其中,查询样本与支持样本之间的特征相似度由查询样本的特征向量表示和支持样本的特征向量表示之间的距离确定。
预测网络所采用的机理可以表示为以下公式:
在实际训练的过程中,训练目标为最小化标签预测模型对查询样本的预测结果与查询样本被标注的标签之间的差异。即达到的目标是使得尽量与一致。可以依据该训练目标涉及损失函数,每一轮迭代利用损失函数的取值更新标签预测模型的模型参数,直至满足训练停止条件。其中训练停止条件可以包括诸如损失函数的值小于或等于预设阈值,迭代次数达到预设的次数阈值,等等。训练过程中可以采用诸如梯度下降法对标签预测模型进行更新。
以风险识别场景为例,可以将不同类型的风险识别任务分成两组:其中,一组任务作为训练集,用于训练标签预测模型,另一组任务作为验证集,用来验证标签预测模型的泛化能力。在每一个任务Ti中都包含支持集合和查询集合。在每一轮训练过程中,从中进行随机采样选择风险识别任务Ti,利用该风险识别任务的支持集合Si对查询集合Bi中的样本进行标签预测,计算损失函数。训练过程可以表示为如下公式:
可以看出,最终训练得到的标签预测模型是在所有类型的风险识别任务上共享的,并未针对各类型分别训练独立的模型。极大地降低了后续模型维护和运营的成本。
在上述实施例训练得到标签预测模型的基础上,图4示出了根据一个实施例的标签预测方法的流程图,可以理解,该方法可以通过任何具有计算、处理能力的装置、设备、平台、设备集群来执行。如图4所示,该方法可以包括:
步骤401:获取待预测对象的特征数据以及确定支持集合,支持集合中的各支持样本包括样本对象的特征数据以及对该样本对象标注的标签。
由于本说明书实施例提供的标签预测模型是所有类型的标签预测任务所共享的,因此,在实际对待预测对象进行标签预测的过程中,需要依据本次预测任务的类型,确定与本次预测任务的类型对应的支持集合,在标签预测模型中切换到该支持集合,即使得标签预测模型采用该支持集合对待预测对象进行标签预测。
步骤403:将待预测对象的特征数据输入标签预测模型,由标签预测模型利用待预测对象与支持集合中各支持样本之间的特征相似度以及各支持样本的标签,预测待预测对象的标签。
标签预测模型的结构可以参见上述模型训练实施例中的相关描述,在此仅做简单描述,如图5中所示,包括Transformer网络和预测网络。
在进行标签预测的过程中,Transformer网络用以通过注意力机制对待预测对象的特征数据以及各支持样本中的特征数据进行处理,得到待预测对象的特征向量表示以及各支持样本的特征向量表示。
其中,Transformer网络可以采用普通注意力机制,也可以采用多头注意力机制。
假设待预测对象的特征数据表示为对应的支持集合S表示为 经过Transformer网络处理后待预测对象的特征向量表示记为经过Transformer网络处理后得到的支持样本的特征向量表示记为其中具体的处理原理参见模型训练实施例中的相关记载,在此不做赘述。
预测网络用以利用待预测对象的特征向量表示和各支持样本的特征向量表示之间的距离,确定待预测对象与各支持样本之间的特征相似度;利用待预测对象与各支持样本之间的特征相似度分别确定各支持样本的映射权重;利用各支持样本的映射权重对各支持样本的标签进行加权处理,得到对待预测对象标签的预测结果。
上述标签预测方法可以应用于多种应用场景,例如应用于风险识别、图像识别等等。
以风险识别为例,在利用各种风险类型的小样本数据训练得到标签预测模型后,当需要对某个用户进行风险识别时,上述待识别对象的特征数据包括用户的行为特征数据,上述待预测对象的标签可以包括用户的风险信息标签。其中上述风险信息标签包括是否具有预设类型风险的标签、预设类型风险的等级标签或者风险类型标签等。
上述对本说明书特定实施例进行了描述。其它实施例在所附权利要求书的范围内。在一些情况下,在权利要求书中记载的动作或步骤可以按照不同于实施例中的顺序来执行并且仍然可以实现期望的结果。另外,在附图中描绘的过程不一定要求示出的特定顺序或者连续顺序才能实现期望的结果。在某些实施方式中,多任务处理和并行处理也是可以的或者可能是有利的。
以上是对公开所提供方法进行的详细描述,下面结合实施例对本公开所提供的装置进行详细描述。
图6示出了根据一个实施例的训练标签预测模型的装置结构图,如图6中所示,该装置600可以包括:样本获取单元601、样本确定单元602和模型训练单元603。其中各组成单元的主要功能如下:
样本获取单元601,被配置为获取样本集合,样本集合中的各样本包括对象的特征数据以及对该对象标注的标签。
样本确定单元602,被配置为从样本集合中确定支持集合和查询集合。
模型训练单元603,被配置为利用支持集合和查询集合训练标签预测模型;其中,将支持集合和查询集合中的查询样本输入标签预测模型,由标签预测模型利用输入的查询样本与支持集合中各支持样本之间的特征相似度以及各支持样本的标签,预测输入的查询样本的标签;训练目标为最小化预测结果与查询样本被标注的标签之间的差异。
其中,上述样本集合可以包括一个以上类型的标签预测任务的样本集合。
相应地,模型训练单元603,可以具体被配置为交替、顺序或随机选取各标签预测任务,利用所选取标签预测任务的支持集合和查询集合对标签预测模型进行迭代更新,直至达到预设的训练停止条件。
作为一种优选的实施方式,上述标签预测模型可以包括Transformer网络和预测网络。
其中,Transformer网络用以通过注意力机制对输入的各样本中的特征数据进行处理,得到各样本的特征向量表示。
预测网络用以将输入的查询样本与支持集合中各支持样本之间的特征相似度分别作为各支持样本的映射权重,利用各支持样本的映射权重对各支持样本的标签进行加权处理,得到对该查询样本标签的预测结果;其中,查询样本与支持样本之间的特征相似度由查询样本的特征向量表示和支持样本的特征向量表示之间的距离确定。
优选地,上述注意力机制包括多头注意力机制。
作为一种典型的应用场景,该装置可以应用于风险识别。这种情况下,上述的各样本包括用户的行为特征数据以及对该用户标注的风险信息标签;其中风险信息标签包括:是否具有预设类型风险的标签、预设类型风险的等级标签或者风险类型标签。
图7示出了根据一个实施例的标签预测装置结构图,如图7中所示,该装置700可以包括:数据获取单元701、集合确定单元702和标签预测单元703。
其中,各组成单元的主要功能如下:
数据获取单元701,被配置为获取待预测对象的特征数据;
集合确定单元702,被配置为确定支持集合,支持集合中的各支持样本包括样本对象的特征数据以及对该样本对象标注的标签;
标签预测单元703,被配置为将待预测对象的特征数据输入标签预测模型,由标签预测模型利用待预测对象与支持集合中各支持样本之间的特征相似度以及各支持样本的标签,预测待预测对象的标签。
其中,集合确定单元702,可以具体被配置为确定本次预测任务的类型,确定与本次预测任务的类型对应的支持集合。
作为一种优选的实施方式,上述标签预测模型可以包括Transformer网络和预测网络。
其中,Transformer网络用以通过注意力机制对待预测对象的特征数据以及各支持样本中的特征数据进行处理,得到待预测对象的特征向量表示以及各支持样本的特征向量表示。
预测网络用以利用待预测对象的特征向量表示和各支持样本的特征向量表示之间的距离,确定待预测对象与各支持样本之间的特征相似度;利用待预测对象与各支持样本之间的特征相似度分别确定各支持样本的映射权重;利用各支持样本的映射权重对各支持样本的标签进行加权处理,得到对待预测对象标签的预测结果。
优选地,上述注意力机制包括多头注意力机制。
根据另一方面的实施例,还提供一种计算机可读存储介质,其上存储有计算机程序,当所述计算机程序在计算机中执行时,令计算机执行结合图1或图4所描述的方法。
根据再一方面的实施例,还提供一种计算设备,包括存储器和处理器,所述存储器中存储有可执行代码,所述处理器执行所述可执行代码时,实现结合图1或图4所述的方法。
随着时间、技术的发展,计算机可读存储介质含义越来越广泛,计算机程序的传播途径不再受限于有形介质,还可以直接从网络下载等。可以采用一个或多个计算机可读存储介质的任意组合。计算机可读存储介质例如可以是——但不限于——电、磁、光、电磁、红外线、或半导体的系统、装置或器件,或者任意以上的组合。计算机可读存储介质的更具体的例子(非穷举的列表)包括:具有一个或多个导线的电连接、便携式计算机磁盘、硬盘、随机存取存储器(RAM)、只读存储器(ROM)、可擦式可编程只读存储器(EPROM或闪存)、光纤、便携式紧凑磁盘只读存储器(CD-ROM)、光存储器件、磁存储器件、或者上述的任意合适的组合。在本说明书中,计算机可读存储介质可以是任何包含或存储程序的有形介质,该程序可以被指令执行系统、装置或者器件使用或者与其结合使用。
上述的处理器可包括一个或多个单核处理器或多核处理器。处理器可包括任何一般用途处理器或专用处理器(如图像处理器、应用处理器基带处理器等)的组合。
本说明书中的各个实施例均采用递进的方式描述,各个实施例之间相同相似的部分互相参见即可,每个实施例重点说明的都是与其他实施例的不同之处。尤其,对于装置实施例而言,由于其基本相似于方法实施例,所以描述的比较简单,相关之处参见方法实施例的部分说明即可。
上述实施例所提供的技术方案可以具备以下优点:
1)通过将小样本构成支持集合,利用度量学习的方式学习查询样本和支持样本之间特征相似度对标签预测的影响,适用于小样本的标签学习,具有较好的模型效果。
2)无需针对各类型的标签预测任务分别训练独立的模型,而是训练统一的模型,在实际预测时仅需要依据待预测对象的标签预测任务类型切换支持集合即可,降低了模型的维护和运营成本。
3)将支持集合引入transformer网络,利用transformer网络的注意力机制在合理度量空间进行“快速记忆”,使模型对新型、历史风险案件都有较强的识别预测能力。
本领域技术人员应该可以意识到,在上述一个或多个示例中,本发明所描述的功能可以用硬件、软件、固件或它们的任意组合来实现。当使用软件实现时,可以将这些功能存储在计算机可读介质中或者作为计算机可读介质上的一个或多个指令或代码进行传输。
以上所述的具体实施方式,对本发明的目的、技术方案和有益效果进行了进一步详细说明,所应理解的是,以上所述仅为本发明的具体实施方式而已,并不用于限定本发明的保护范围,凡在本发明的技术方案的基础之上,所做的任何修改、等同替换、改进等,均应包括在本发明的保护范围之内。
Claims (13)
1.训练标签预测模型的方法,包括:
获取样本集合,所述样本集合中的各样本包括对象的特征数据以及对该对象标注的标签;
从所述样本集合中确定支持集合和查询集合;
利用所述支持集合和查询集合训练标签预测模型;其中,将所述支持集合和查询集合中的查询样本输入所述标签预测模型,由所述标签预测模型利用输入的查询样本与支持集合中各支持样本之间的特征相似度以及各支持样本的标签,预测输入的查询样本的标签;训练目标为最小化预测结果与查询样本被标注的标签之间的差异。
2.根据权利要求1所述的方法,其中,所述样本集合包括一个以上类型的标签预测任务的样本集合;
利用所述支持集合和查询集合训练标签预测模型包括:
交替、顺序或随机选取各标签预测任务,利用所选取标签预测任务的支持集合和查询集合对所述标签预测模型进行迭代更新,直至达到预设的训练停止条件。
3.根据权利要求1所述的方法,其中,所述标签预测模型包括Transformer网络和预测网络;
所述Transformer网络用以通过注意力机制对输入的各样本中的特征数据进行处理,得到各样本的特征向量表示;
所述预测网络用以将所述输入的查询样本与支持集合中各支持样本之间的特征相似度分别作为各支持样本的映射权重,利用各支持样本的映射权重对各支持样本的标签进行加权处理,得到对该查询样本标签的预测结果;其中,查询样本与支持样本之间的特征相似度由查询样本的特征向量表示和支持样本的特征向量表示之间的距离确定。
4.根据权利要求1所述的方法,其中,所述注意力机制包括多头注意力机制。
5.根据权利要求1至4中任一项所述的方法,该方法应用于风险识别;
各样本包括用户的行为特征数据以及对该用户标注的风险信息标签;其中所述风险信息标签包括:是否具有预设类型风险的标签、预设类型风险的等级标签或者风险类型标签。
6.标签预测方法,包括:
获取待预测对象的特征数据以及确定支持集合,所述支持集合中的各支持样本包括样本对象的特征数据以及对该样本对象标注的标签;
将所述待预测对象的特征数据输入标签预测模型,由所述标签预测模型利用所述待预测对象与所述支持集合中各支持样本之间的特征相似度以及各支持样本的标签,预测所述待预测对象的标签。
7.根据权利要求6所述的方法,其中,所述确定支持集合包括:
确定本次预测任务的类型,确定与本次预测任务的类型对应的支持集合。
8.根据权利要求6所述的方法,其中,所述标签预测模型包括Transformer网络和预测网络;
所述Transformer网络用以通过注意力机制对所述待预测对象的特征数据以及各支持样本中的特征数据进行处理,得到待预测对象的特征向量表示以及各支持样本的特征向量表示;
所述预测网络用以利用待预测对象的特征向量表示和各支持样本的特征向量表示之间的距离,确定待预测对象与各支持样本之间的特征相似度;利用待预测对象与各支持样本之间的特征相似度分别确定各支持样本的映射权重;利用各支持样本的映射权重对各支持样本的标签进行加权处理,得到对所述待预测对象标签的预测结果。
9.根据权利要求8所述的方法,其中,所述注意力机制包括多头注意力机制。
10.根据权利要求6至9中任一项所述的方法,该方法应用于风险识别;
所述待识别对象的特征数据包括用户的行为特征数据;
所述待预测对象的标签包括用户的风险信息标签;其中所述风险信息标签包括是否具有预设类型风险的标签、预设类型风险的等级标签或者风险类型标签。
11.训练标签预测模型的装置,包括:
样本获取单元,被配置为获取样本集合,所述样本集合中的各样本包括对象的特征数据以及对该对象标注的标签;
样本确定单元,被配置为从所述样本集合中确定支持集合和查询集合;
模型训练单元,被配置为利用所述支持集合和查询集合训练标签预测模型;其中,将所述支持集合和查询集合中的查询样本输入所述标签预测模型,由所述标签预测模型利用输入的查询样本与支持集合中各支持样本之间的特征相似度以及各支持样本的标签,预测输入的查询样本的标签;训练目标为最小化预测结果与查询样本被标注的标签之间的差异。
12.标签预测装置,包括:
数据获取单元,被配置为获取待预测对象的特征数据;
集合确定单元,被配置为确定支持集合,所述支持集合中的各支持样本包括样本对象的特征数据以及对该样本对象标注的标签;
标签预测单元,被配置为将所述待预测对象的特征数据输入标签预测模型,由所述标签预测模型利用所述待预测对象与所述支持集合中各支持样本之间的特征相似度以及各支持样本的标签,预测所述待预测对象的标签。
13.一种计算设备,包括存储器和处理器,其特征在于,所述存储器中存储有可执行代码,所述处理器执行所述可执行代码时,实现权利要求1-10中任一项所述的方法。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202111059586.5A CN113902256A (zh) | 2021-09-10 | 2021-09-10 | 训练标签预测模型的方法、标签预测方法和装置 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202111059586.5A CN113902256A (zh) | 2021-09-10 | 2021-09-10 | 训练标签预测模型的方法、标签预测方法和装置 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN113902256A true CN113902256A (zh) | 2022-01-07 |
Family
ID=79027551
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202111059586.5A Pending CN113902256A (zh) | 2021-09-10 | 2021-09-10 | 训练标签预测模型的方法、标签预测方法和装置 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN113902256A (zh) |
Cited By (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN114373484A (zh) * | 2022-03-22 | 2022-04-19 | 南京邮电大学 | 语音驱动的帕金森病多症状特征参数的小样本学习方法 |
CN115100731A (zh) * | 2022-08-10 | 2022-09-23 | 北京万里红科技有限公司 | 一种质量评价模型训练方法、装置、电子设备及存储介质 |
CN115965817A (zh) * | 2023-01-05 | 2023-04-14 | 北京百度网讯科技有限公司 | 图像分类模型的训练方法、装置及电子设备 |
CN116188995A (zh) * | 2023-04-13 | 2023-05-30 | 国家基础地理信息中心 | 一种遥感图像特征提取模型训练方法、检索方法及装置 |
-
2021
- 2021-09-10 CN CN202111059586.5A patent/CN113902256A/zh active Pending
Cited By (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN114373484A (zh) * | 2022-03-22 | 2022-04-19 | 南京邮电大学 | 语音驱动的帕金森病多症状特征参数的小样本学习方法 |
CN115100731A (zh) * | 2022-08-10 | 2022-09-23 | 北京万里红科技有限公司 | 一种质量评价模型训练方法、装置、电子设备及存储介质 |
CN115965817A (zh) * | 2023-01-05 | 2023-04-14 | 北京百度网讯科技有限公司 | 图像分类模型的训练方法、装置及电子设备 |
CN116188995A (zh) * | 2023-04-13 | 2023-05-30 | 国家基础地理信息中心 | 一种遥感图像特征提取模型训练方法、检索方法及装置 |
CN116188995B (zh) * | 2023-04-13 | 2023-08-15 | 国家基础地理信息中心 | 一种遥感图像特征提取模型训练方法、检索方法及装置 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
Paul et al. | Robust visual tracking by segmentation | |
CN111191791B (zh) | 基于机器学习模型的图片分类方法、装置及设备 | |
CN109145781B (zh) | 用于处理图像的方法和装置 | |
CN113902256A (zh) | 训练标签预测模型的方法、标签预测方法和装置 | |
CN114241282B (zh) | 一种基于知识蒸馏的边缘设备场景识别方法及装置 | |
CN111105008A (zh) | 模型训练方法、数据识别方法和数据识别装置 | |
CN111797893A (zh) | 一种神经网络的训练方法、图像分类系统及相关设备 | |
CN111340221B (zh) | 神经网络结构的采样方法和装置 | |
US20220067588A1 (en) | Transforming a trained artificial intelligence model into a trustworthy artificial intelligence model | |
CN111127364B (zh) | 图像数据增强策略选择方法及人脸识别图像数据增强方法 | |
KR101828215B1 (ko) | Long Short Term Memory 기반 순환형 상태 전이 모델의 학습 방법 및 장치 | |
CN113128478B (zh) | 模型训练方法、行人分析方法、装置、设备及存储介质 | |
CN111052128B (zh) | 用于检测和定位视频中的对象的描述符学习方法 | |
CN113688890A (zh) | 异常检测方法、装置、电子设备及计算机可读存储介质 | |
CN114974397A (zh) | 蛋白质结构预测模型的训练方法和蛋白质结构预测方法 | |
JP7331937B2 (ja) | ロバスト学習装置、ロバスト学習方法、プログラム及び記憶装置 | |
CN113920583A (zh) | 细粒度行为识别模型构建方法及系统 | |
CN111161238A (zh) | 图像质量评价方法及装置、电子设备、存储介质 | |
CN111260074A (zh) | 一种超参数确定的方法、相关装置、设备及存储介质 | |
CN113111996A (zh) | 模型生成方法和装置 | |
KR102413588B1 (ko) | 학습 데이터에 따른 객체 인식 모델 추천 방법, 시스템 및 컴퓨터 프로그램 | |
CN116208399A (zh) | 一种基于元图的网络恶意行为检测方法及设备 | |
CN113516182B (zh) | 视觉问答模型训练、视觉问答方法和装置 | |
Chenxin et al. | Searching parameterized AP loss for object detection | |
US20240020531A1 (en) | System and Method for Transforming a Trained Artificial Intelligence Model Into a Trustworthy Artificial Intelligence Model |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
RJ01 | Rejection of invention patent application after publication |
Application publication date: 20220107 |
|
RJ01 | Rejection of invention patent application after publication |