CN116975622A - 目标检测模型的训练方法及装置、目标检测方法及装置 - Google Patents
目标检测模型的训练方法及装置、目标检测方法及装置 Download PDFInfo
- Publication number
- CN116975622A CN116975622A CN202310292326.5A CN202310292326A CN116975622A CN 116975622 A CN116975622 A CN 116975622A CN 202310292326 A CN202310292326 A CN 202310292326A CN 116975622 A CN116975622 A CN 116975622A
- Authority
- CN
- China
- Prior art keywords
- sampling
- node
- nodes
- target
- sample
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000012549 training Methods 0.000 title claims abstract description 141
- 238000001514 detection method Methods 0.000 title claims abstract description 115
- 238000000034 method Methods 0.000 title claims abstract description 84
- 238000005070 sampling Methods 0.000 claims abstract description 404
- 238000003062 neural network model Methods 0.000 claims abstract description 101
- 230000004931 aggregating effect Effects 0.000 claims abstract description 19
- 230000002776 aggregation Effects 0.000 claims description 16
- 238000004220 aggregation Methods 0.000 claims description 16
- 238000010276 construction Methods 0.000 claims description 8
- 230000000694 effects Effects 0.000 abstract description 10
- 238000009826 distribution Methods 0.000 description 23
- 238000003860 storage Methods 0.000 description 23
- 230000002159 abnormal effect Effects 0.000 description 22
- 230000006870 function Effects 0.000 description 16
- 238000012545 processing Methods 0.000 description 14
- 238000010586 diagram Methods 0.000 description 10
- 230000008569 process Effects 0.000 description 10
- 238000004364 calculation method Methods 0.000 description 8
- 238000004590 computer program Methods 0.000 description 8
- 239000004973 liquid crystal related substance Substances 0.000 description 6
- 230000007246 mechanism Effects 0.000 description 6
- 239000013598 vector Substances 0.000 description 5
- 238000013473 artificial intelligence Methods 0.000 description 4
- 241001465754 Metazoa Species 0.000 description 3
- 238000013475 authorization Methods 0.000 description 3
- 238000004891 communication Methods 0.000 description 3
- 238000005516 engineering process Methods 0.000 description 3
- 238000007726 management method Methods 0.000 description 3
- 230000001133 acceleration Effects 0.000 description 2
- 230000004913 activation Effects 0.000 description 2
- 230000005540 biological transmission Effects 0.000 description 2
- 238000011161 development Methods 0.000 description 2
- 239000011159 matrix material Substances 0.000 description 2
- 230000003287 optical effect Effects 0.000 description 2
- 238000006116 polymerization reaction Methods 0.000 description 2
- 238000012546 transfer Methods 0.000 description 2
- 230000009466 transformation Effects 0.000 description 2
- 241000282693 Cercopithecidae Species 0.000 description 1
- 241000282326 Felis catus Species 0.000 description 1
- 241000406668 Loxodonta cyclotis Species 0.000 description 1
- 241000009328 Perro Species 0.000 description 1
- 230000009286 beneficial effect Effects 0.000 description 1
- 230000008859 change Effects 0.000 description 1
- 238000007599 discharging Methods 0.000 description 1
- 230000006872 improvement Effects 0.000 description 1
- 238000004519 manufacturing process Methods 0.000 description 1
- 239000004065 semiconductor Substances 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F16/00—Information retrieval; Database structures therefor; File system structures therefor
- G06F16/90—Details of database functions independent of the retrieved data types
- G06F16/901—Indexing; Data structures therefor; Storage structures
- G06F16/9024—Graphs; Linked lists
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/25—Fusion techniques
- G06F18/253—Fusion techniques of extracted features
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/084—Backpropagation, e.g. using gradient descent
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06Q—INFORMATION AND COMMUNICATION TECHNOLOGY [ICT] SPECIALLY ADAPTED FOR ADMINISTRATIVE, COMMERCIAL, FINANCIAL, MANAGERIAL OR SUPERVISORY PURPOSES; SYSTEMS OR METHODS SPECIALLY ADAPTED FOR ADMINISTRATIVE, COMMERCIAL, FINANCIAL, MANAGERIAL OR SUPERVISORY PURPOSES, NOT OTHERWISE PROVIDED FOR
- G06Q40/00—Finance; Insurance; Tax strategies; Processing of corporate or income taxes
- G06Q40/08—Insurance
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Data Mining & Analysis (AREA)
- Physics & Mathematics (AREA)
- General Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- Artificial Intelligence (AREA)
- Evolutionary Computation (AREA)
- Life Sciences & Earth Sciences (AREA)
- Evolutionary Biology (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Bioinformatics & Computational Biology (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Business, Economics & Management (AREA)
- Accounting & Taxation (AREA)
- Databases & Information Systems (AREA)
- Software Systems (AREA)
- Finance (AREA)
- General Business, Economics & Management (AREA)
- Biomedical Technology (AREA)
- Marketing (AREA)
- Strategic Management (AREA)
- Technology Law (AREA)
- Development Economics (AREA)
- Health & Medical Sciences (AREA)
- Economics (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- General Health & Medical Sciences (AREA)
- Molecular Biology (AREA)
- Computing Systems (AREA)
- Mathematical Physics (AREA)
- Image Analysis (AREA)
Abstract
本申请公开了一种目标检测模型的训练方法及装置、目标检测方法及装置。目标检测模型的训练方法通过获取训练样本数据;根据样本对象之间的关联关系构建以样本对象为节点的图网络;基于不同标签类别对应的样本对象的数量在图网络中进行节点采样,得到多个第一采样节点,并在第一采样节点的邻居节点中采样第一采样节点关联的第二采样节点;将第一采样节点与关联的第二采样节点的节点特征进行聚合,得到目标节点特征;以目标节点特征为模型输入,并以对应的标签数据为输出标签对待训练的神经网络模型的模型参数进行调整。该方法可以提升训练得到的模型的效果,进而可以提升目标检测的准确性。
Description
技术领域
本申请涉及人工智能技术领域,具体涉及一种目标检测模型的训练方法及装置、目标检测方法及装置。
背景技术
近年来,随着人工智能技术的不断发展,越来越多的领域开始使用人工智能技术来提升该领域的提升该领域的整体业务处理效率。例如,在医保检测领域,可以通过神经网络模型来检测医保账户是否存在异常,从而判定医保账户是正常账户还是异常账户。通过神经网络模型进行医保检测,可以大大提升医保检测的检测效率。
然而,在一些情况下,采用神经网络模型对医保账户进行异常检测时,存在检测结果不准确的问题。
发明内容
本申请实施例提供一种目标检测模型的训练方法及装置、目标检测方法及装置,该方法可以提升目标检测的准确性。
本申请第一方面提供一种目标检测模型的训练方法,方法包括:
获取训练样本数据,所述训练样本数据包括多个样本对象以及每个样本对象对应的标签数据,所述标签数据包括多个标签类别;
根据所述样本对象之间的关联关系构建以所述样本对象为节点的图网络;
基于不同标签类别对应的样本对象的数量在所述图网络中进行节点采样,得到多个第一采样节点,并在所述第一采样节点的邻居节点中采样所述第一采样节点关联的第二采样节点;
将所述第一采样节点与关联的第二采样节点的节点特征进行聚合,得到目标节点特征;
以所述目标节点特征为模型输入,并以对应的标签数据为输出标签对待训练的神经网络模型的参数进行调整。
相应的,本申请第二方面提供一种目标检测模型的训练装置,所述装置包括:
第一获取单元,用于获取训练样本数据,所述训练样本数据包括多个样本对象以及每个样本对象对应的标签数据,所述标签数据包括多个标签类别;
构建单元,用于根据所述样本对象之间的关联关系构建以所述样本对象为节点的图网络;
采样单元,用于基于不同标签类别对应的样本对象的数量在所述图网络中进行节点采样,得到多个第一采样节点,并在所述第一采样节点的邻居节点中采样所述第一采样节点关联的第二采样节点;
聚合单元,用于将所述第一采样节点与关联的第二采样节点的节点特征进行聚合,得到目标节点特征;
训练单元,用于以所述目标节点特征为模型输入,并以对应的标签数据为输出标签对待训练的神经网络模型的模型参数进行调整。
可选地,在一些实施例中,采样单元,包括:
第一获取子单元,用于获取所述图网络中每一标签类别对应的节点的数量;
第一计算子单元,用于基于每一标签类别对应的节点的数量与所述图网络中的节点总数计算每一节点的第一采样概率;
第一采样子单元,用于基于所述第一采样概率在所述图网络中进行节点采样,得到多个第一采样节点;
第二采样子单元,用于在所述第一采样节点的邻居节点中采样第一采样节点关联的第二采样节点。
可选地,在一些实施例中,第二采样子单元,包括:
第一采样模块,用于获取所述第一采样节点的每一邻居节点的度,并根据每一邻居节点的度在所述第一采样节点的邻居节点中采样第一数量的第二采样节点;
第二采样模块,用于获取所述第一采样节点的每一邻居节点的节点标签,并根据每一邻居节点的节点标签与所述第一采样节点的节点标签之间的关系在所述第一采样节点的邻居节点中采样第二数量的第二采样节点。
可选地,在一些实施例中,第一采样模块,包括:
获取子模块,用于获取所述第一采样节点的每一邻居节点的度;
计算子模块,用于根据每一邻居节点的度计算每一邻居节点对应的第二采样概率;
采样子模块,用于根据所述第二采样概率在所述第一采样节点的邻居节点中采样第一数量的第二采样节点。
可选地,在一些实施例中,构建单元,包括:
第二获取子单元,用于获取每一样本对象的样本数据;
第二计算子单元,用于根据所述样本数据计算样本对象之间的关联关系;
构建子单元,用于以所述样本对象为节点,并以所述样本对象之间的关联关系为边构建图网络。
可选地,在一些实施例中,聚合单元,包括:
第三获取子单元,用于获取每一第二采样节点与所述第一采样节点之间的权重系数;
聚合子单元,用于基于所述权重系数将所述第二采样节点的节点信息聚合到所述第一采样节点中,得到目标节点特征。
可选地,在一些实施例中,训练单元,包括:
输入子单元,用于将所述目标节点特征输入至待训练的神经网络模型,得到所述神经网络模型输出的预测值;
第三计算子单元,用于根据所述预测值与所述目标节点特征对应的标签数据计算损失值,并基于所述损失值确定反传梯度;
更新子单元,用于根据所述反传梯度对所述神经网络模型的参数进行更新。
本申请第三方面提供了一种目标检测方法,该方法包括:
获取待检测的目标的目标数据;
对所述目标数据进行特征编码,得到目标特征;
将所述目标特征输入至神经网络模型中,得到所述神经网络模型输出的预测值,所述神经网络模型为根据第一方面所述的目标检测模型的训练方法训练得到的神经网络模型;
根据所述预测值与预设阈值的比对结果确定对所述目标进行检测的检测结果。
本申请第四方面提供了一种目标检测装置,该装置包括:
第二获取单元,用于获取待检测的目标的目标数据;
编码单元,用于对所述目标数据进行特征编码,得到目标特征;
检测单元,用于将所述目标特征输入至神经网络模型中,得到所述神经网络模型输出的预测值,所述神经网络模型为根据第一方面所述的目标检测模型的训练方法训练得到的神经网络模型;
确定单元,用于根据所述预测值与预设阈值的比对结果确定对所述目标进行检测的检测结果。
本申请第五方面还提供一种计算机可读存储介质,所述计算机可读存储介质存储有多条指令,所述指令适于处理器进行加载,以执行本申请第一方面所提供的目标检测模型的训练方法或第三方面提供的目标检测方法中的步骤。
本申请第六方面提供一种计算机设备,包括存储器、处理器以及存储在所述存储器中并可以在所述处理器上运行的计算机程序,所述处理器执行所述计算机程序时实现本申请第一方面所提供的目标检测模型的训练方法或第三方面提供的目标检测方法中的步骤。
本申请第七方面提供一种计算机程序产品,包括计算机程序/指令,所述计算机程序/指令被处理器执行时实现第一方面所提供的目标检测模型的训练方法或第三方面提供的目标检测方法中的步骤。
本申请实施例提供的目标检测模型的训练方法,通过获取训练样本数据,训练样本数据包括多个样本对象以及每个样本对象对应的标签数据,标签数据包括多个标签类别;根据样本对象之间的关联关系构建以样本对象为节点的图网络;基于不同标签类别对应的样本对象的数量在图网络中进行节点采样,得到多个第一采样节点,并在第一采样节点的邻居节点中采样第一采样节点关联的第二采样节点;将第一采样节点与关联的第二采样节点的节点特征进行聚合,得到目标节点特征;以目标节点特征为模型输入,并以对应的标签数据为输出标签对待训练的神经网络模型的模型参数进行调整。
以此,本申请提供的目标检测模型的训练方法,通过基于样本对象构建图网络,然后对图网络进行节点采样的方法,来避免样本分布不均衡导致的模型训练效果不好的问题。而且,本方法还进一步对采样得到的节点进行邻居节点采样,并且与采样到的邻居节点进行特征聚合,以获得更为准确的节点特征。然后采用更准确的节点特征进行神经网络模型的训练,从而可以大大提升训练得到的神经网络模型的准确性,进而可以提升目标检测的准确性。
附图说明
为了更清楚地说明本申请实施例中的技术方案,下面将对实施例描述中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本申请的一些实施例,对于本领域技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。
图1是本申请中目标检测的一个场景示意图;
图2是本申请提供的目标检测模型的训练方法的一个流程示意图;
图3是本申请提供的目标检测模型的训练方法的另一流程示意图;
图4是本申请中构架的图网络的示意图;
图5是本申请提供的目标检测方法的流程示意图;
图6是本申请提供的目标检测模型的训练装置的结构示意图;
图7是本申请提供的目标检测装置的结构意图;
图8是本申请提供的计算机设备的结构示意图。
具体实施方式
下面将结合本发明实施例中的附图,对本发明实施例中的技术方案进行清楚、完整地描述。显然,所描述的实施例仅仅是本发明一部分实施例,而不是全部的实施例。基于本发明中的实施例,本领域技术人员在没有作出创造性劳动前提下所获得的所有其他实施例,都属于本发明保护的范围。
本发明实施例提供一种目标检测模型的训练方法及装置、目标检测方法及装置。其中,该目标检测模型的训练方法可以使用于目标检测模型的训练装置中。该目标检测模型的训练装置可以集成在计算机设备中,该计算机设备可以是终端也可以是服务器。其中,终端可以为手机、平板电脑、笔记本电脑、智能电视、穿戴式智能设备、个人计算机(PC,Personal Computer)等设备。服务器可以是独立的物理服务器,也可以是多个物理服务器构成的服务器集群或者分布式系统,还可以是提供云服务、云数据库、云计算、云函数、云存储、网络服务、云通信、中间件服务、域名服务、安全服务、网络加速服务(Content DeliveryNetwork,CDN)、以及大数据和人工智能平台等基础云计算服务的云服务器。其中,服务器可以为区块链中的节点。
请参阅图1,为本申请提供的目标检测模型的训练方法的一个场景示意图。如图所示,本申请提供的目标检测模型的训练方法可以应用在装载目标检测模型的训练装置的终端A中,终端A获取训练样本数据,训练样本数据包括多个样本对象以及每个样本对象对应的标签数据,标签数据包括多个标签类别;根据样本对象之间的关联关系构建以样本对象为节点的图网络;基于不同标签类别对应的样本对象的数量在图网络中进行节点采样,得到多个第一采样节点,并在第一采样节点的邻居节点中采样第一采样节点关联的第二采样节点;将第一采样节点与关联的第二采样节点的节点特征进行聚合,得到目标节点特征;以目标节点特征为模型输入,并以对应的标签数据为输出标签对待训练的神经网络模型的模型参数进行调整。
需要说明的是,图1所示的模型训练场景示意图仅仅是一个示例,本申请实施例描述的模型训练场景是为了更加清楚地说明本申请的技术方案,并不构成对于本申请提供的技术方案的限定。本领域普通技术人员可知,随着模型训练场景演变和新业务场景的出现,本申请提供的技术方案对于类似的技术问题,同样适用。
基于上述实施场景以下分别进行详细说明。
本申请实施例将从目标检测模型的训练装置的角度进行描述,该目标检测模型的训练装置可以集成在计算机设备中。其中,计算机设备可以是终端也可以是服务器。其中,终端可以为手机、平板电脑、笔记本电脑、智能电视、穿戴式智能设备、个人计算机(PC,Personal Computer)等设备。服务器可以是独立的物理服务器,也可以是多个物理服务器构成的服务器集群或者分布式系统,还可以是提供云服务、云数据库、云计算、云函数、云存储、网络服务、云通信、中间件服务、域名服务、安全服务、网络加速服务(Content DeliveryNetwork,CDN)、以及大数据和人工智能平台等基础云计算服务的云服务器。如图2所示,为本申请提供的目标检测模型的训练方法的流程示意图,该方法包括:
步骤101,获取训练样本数据。
其中,在相关技术中,为了避免异常医保账户冒充正常医保账户进行使用从而导致医疗资源损失的问题,一般在使用医保账户进行医疗费用支付之前,可以先对医保账户进行检测,以确认当前使用的医保账户是否为正常医保账户。目前,随着人工智能技术的发展,在医保账户检测任务中也引入了医保账户检测模型来对医保账户进行检测。然而,目前医保账户检测模型的检测准确率并不高,主要由于目前医保账户中正常的医保账户较多,而异常医保账户较少,导致用于对医保账户检测模型进行训练的样本数据中存在严重的样本标签分布不均衡的问题,从而影响模型对异常医保账户的识别,进而导致了医保账户检测模型的准确性较低。
为了解决上述对医保账户检测模型进行训练的样本数据中标签分布不均衡导致训练得到的医保账户检测模型不准确的问题,本申请提供了一种目标检测模型的训练方法,以期能够提升训练得到的神经网络模型的准确性。下面对本申请提供的目标检测模型的训练方法进行详细描述。
首先,在对神经网络模型进行训练之前,可以先对神经网络模型的训练样本数据进行获取。该训练样本数据中包括了多个样本对象,以及每个样本对象对应的标签数据,标签数据包括了多个标签类别。其中,此处样本对象可以为图像数据、音视频数据、文本数据以及账户数据等任意对象。样本对象对应的标签数据可以为分类标签数据,具体可以为二分类的标签,也可以为多分类的标签。例如,当样本对象为医保账户数据时,样本对象对应的标签数据便可以为正常或者不正常这两类。当样本对象为动物图像时,样本对象对应的标签数据便可以为猫、狗、猴子或者大象等多种类别。
其中,训练样本数据中不同标签类别的标签数据之间可以为均衡的标签分布,也可以为不均衡的标签分布。当样本对象为医保账户数据时,不同标签类别的标签数据之间分布便可能为不均衡的分布;当样本对象为动物图像时,不同标签类别的标签数据之间便可能为均衡的分布。在本申请实施例中,不同的标签类别的标签数据分布不均衡时,本申请提供的目标检测模型的训练方法可以得到更好的改善效果。
其中,当训练样本数据中样本对象为医保账户时,样本对象则包含了医保账户的相关账户信息,即包括用户的医疗信息等。可以理解的是,在本申请实施例中,相关信息都是在用户授权的基础上进行获取和使用的,用户信息的获取和使用是在遵循相关法律法规的基础上进行的。当训练样本数据中样本对象为医保账户数据时,相关数据可以为从公开的数据库中获取得到的。
步骤102,根据样本对象之间的关联关系构建以样本对象为节点的图网络。
其中,在获取到训练样本数据后,便可以根据训练样本数据获取样本对象之间的关联关系,然后基于该关联关系构建以样本对象为节点的图网络。
其中,在一些实施例中,根据样本对象之间的关联关系构建以样本对象为节点的图网络,包括:
1、获取每一样本对象的样本数据;
2、根据样本数据计算样本对象之间的关联关系;
3、以样本对象为节点,并以样本对象之间的关联关系为边构建图网络。
其中,在本申请实施例中,基于样本对象之间的关联关系构建以样本对象为节点的图网络的具体过程,可以为先对每一样本对象的样本数据进行获取。例如,当样本对象为医保账户时,样本对象的样本数据具体可以为医保账户的用户数据以及该用户的医疗数据。当样本对象为动物图像时,样本对象的样本数据可以为图像的来源数据、图像的分辨率数据等。
在获取到每一样本对象的样本数据后,便可以基于样本对象的样本数据对样本对象进行相似度计算,然后根据计算得到的相似度确定样本对象之间的关联关系。
具体地,当样本对象为医保账户,样本对象对应的样本数据为医保账户的用户数据以及医疗数据时,可以先将该样本对象对应的用户数据和医疗数据转化为相应的文本数据,然后对该文本数据进行特征编码,得到样本对象的样本特征,然后计算样本特征之间的相似度,得到样本对象之间的相似度。其中,样本对象的用户数据以及医疗数据可以包括用户与用户之间的关系、用户与医生之间的关系、用户与医院之间的关系以及医院与医院之间的关系等。其中,将样本对象对应的文本数据进行特征编码得到样本特征的具体过程,可以为采用词嵌入的方式将样本对象的文本数据转化为词向量。如此,计算样本对象对应的样本特征之间的相似度,便可以为计算样本对象对应的词向量之间的余弦相似度。
在计算得到样本对象之间的相似度后,便可以根据样本对象之间的相似度确定样本对象之间的关联关系。当样本对象之间的相似度较高时,则说明样本对象之间的关联关系更紧密,那么在图网络中样本对象之间的距离较近。当样本对象之间的相似度较低时,则说明样本对象之间的关联关系较微弱,那么在图网络中样本对象之间的距离较远。当样本对象之间的相似度小于某一阈值时,则可以认为样本对象之间不存在关联关系,即在图网络中该两个样本对象之间不存在连接边。
进一步地,在确定了样本对象之间的关联关系后,则可以将样本对象作为图网络的节点,并将节点之间的关联关系作为图网络的边构建相应的图网络。可以理解的是,图网络节点中不仅包含了节点的特征信息,还包含了节点的标签信息。当图网络节点为医保账户时,节点的标签信息则表明了该节点对应的医保账户为正常账户或异常账户。
步骤103,基于不同标签类别对应的样本对象的数量在图网络中进行节点采样,得到多个第一采样节点,并在第一采样节点的邻居节点中采样第一采样节点关联的第二采样节点。
其中,如前所述,在较多情况下,获取到的训练样本数据中存在着样本分布不均衡的问题。即当获取到的训练样本数据按照对应的标签数据的标签类别分为多个类别时,不同类别的训练样本数据之间的数量存在着较大的差距。而样本分布的不均衡会导致模型学习到的能力产生偏差,从而导致模型的训练效果下降。对此,在本申请实施例中,可以先基于不同标签类别对应的样本对象的数量,在图网络中进行节点采样,得到类别均衡的采样节点,此处可以称为第一采样节点。如此,可以避免不同类别之间的样本分布不均衡导致模型训练效果较差的问题。
进一步地,为了进一步提升模型的训练效果,可以在采样得到的第一采样节点附近继续进行邻居节点的采样,以便将采样得到的邻居节点的特征与采样得到的第一采样节点的特征进行聚合,得到第一采样节点的更为准确的节点特征。
其中,在一些实施例中,基于不同标签类别对应的样本对象的数量在图网络中进行节点采样,得到多个第一采样节点,并在第一采样节点的邻居节点中采样第一采样节点关联的第二采样节点,包括:
1、获取图网络中每一标签类别对应的节点的数量;
2、基于每一标签类别对应的节点的数量与图网络中的节点总数计算每一节点的第一采样概率;
3、基于第一采样概率在图网络中进行节点采样,得到多个第一采样节点;
4、在第一采样节点的邻居节点中采样第一采样节点关联的第二采样节点。
其中,在本申请实施例中,基于图网络中不同标签类别对应的样本对象的数量在图网络中进行节点采样的具体过程,可以为先根据图网络中每一节点对应的标签数据将节点进行分类,具体可以按照标签数据对应的标签类别来进行分类。
具体地,可以先根据图网络中每一节点对应的标签信息来获取每一标签类别对应的节点的数量,例如当样本对象为医保账户时,可以获取到正常医保账户对应的节点数量为第三数量,以及可以获取到异常医保账户对应的节点数量为第四数量。可以理解的是,一般情况下,正常医保账户的数量远远大于异常医保账户,因此一般情况下第三数量远远大于第四数量。
在获取到每一标签类别对应的节点的数量后,便可以进一步基于不同标签类别对应的节点数量与图网络中节点总数来计算每一节点的采样概率,此处可以称为第一采样概率。具体地,可以采用不同标签类别对应的节点数量与图网络中节点总数的比值来确定对该标签类别对应的节点进行采样的节点概率。例如正常医保账户的采样概率便可以采用前述第三数量与医保账户总数量(第三数量与第四数量之和)的比值来进行确定。可以理解的是,同一标签类别的节点具有相同的采样概率。
在确定了每一节点的采样概率(第一采样概率)后,便可以采用前述第一采样概率对图网络中的节点进行采样,得到多个第一采样节点。可以理解的是,多个第一采样节点中,不同标签类别的节点之间分布相对均衡。
进一步地,在采样得到多个第一采样节点后,还可以进一步在每一第一采样节点的邻居节点中进一步采样多个第二采样节点。
其中,在一些实施例中,在第一采样节点的邻居节点中采样第一采样节点关联的第二采样节点,包括:
4.1、获取第一采样节点的每一邻居节点的度,并根据每一邻居节点的度在第一采样节点的邻居节点中采样第一数量的第二采样节点;
4.2、获取第一采样节点的每一邻居节点的节点标签,并根据每一邻居节点的节点标签与第一采样节点的节点标签之间的关系在第一采样节点的邻居节点中采样第二数量的第二采样节点。
其中,在本申请实施例中,在采样得到多个第一采样节点后,进一步在每一第一采样节点的邻居节点中进行第二采样节点的采样,具体可以通过两种方式进行采样。具体地,可以根据邻居节点的度以及邻居节点的节点标签分别在邻居节点中进行节点采样。对于任一采样得到的第一采样节点,可以获取其每一邻居节点的度,然后根据每一邻居节点的度在第一采样节点的邻居节点中采样第一数量的第二采样节点。此外,还可以获取第一采样节点的每一邻居节点的节点标签,然后可以根据邻居节点的节点标签与第一采样节点的节点标签之间的关系在第一采样节点的邻居节点中采样第二数量个第二采样节点。具体地,根据邻居节点的节点标签与第一采样节点的节点标签之间的关系在第一采样节点的邻居节点中采样第二数量个第二采样节点的具体过程可以为确定邻居节点的节点标签是否与第一采样节点的节点标签相同,当邻居节点的节点标签与第一采样节点的节点标签相同时,则赋予其高的采样概率;当邻居节点的节点标签与第一采样节点的节点标签不同时,则赋予其低的采样概率。然后再分别基于赋予的采样概率进行节点采样,得到第二数量个第二采样节点。
然后,可以将第一数量个第二采样节点和第二数量个第二采样节点都作为在邻居节点中采样得到的第二采样节点。从而确定了每一第一采样节点对应的第二采样节点。
其中,在一些实施例中,获取第一采样节点的每一邻居节点的度,并根据每一邻居节点的度在第一采样节点的邻居节点中采样第一数量的第二采样节点,包括:
4.1.1、获取第一采样节点的每一邻居节点的度;
4.1.2、根据每一邻居节点的度计算每一邻居节点对应的第二采样概率;
4.1.3、根据第二采样概率在第一采样节点的邻居节点中采样第一数量的第二采样节点。
其中,在本申请实施例中,基于第一采样节点的邻居节点的度在第一采样节点的邻居节点中进行节点采样的具体过程,可以为对任一第一采样节点,先获取其每一邻居节点的度。其中,在图网络中,节点的度具体可以为与节点相邻的节点的数量,或者可以为与节点连接的边的数量。在获取到每一邻居节点的度后,便可以根据每一邻居节点的度计算每一邻居节点对应的采样概率,此处可以称为第二采样概率,然后基于第二采样概率对第一采样节点的邻居节点进行节点采样,得到第一数量的第二采样节点。
步骤104,将第一采样节点与关联的第二采样节点的节点特征进行聚合,得到目标节点特征。
其中,在采样得到多个第一采样节点,并对每一采样节点的邻居节点进行进一步的节点采样,得到每一第一采样节点对应的多个第二采样节点后,可以将每一采样节点的节点特征和与其关联的多个第二采样节点的节点特征进行聚合,从而得到更为准确的节点特征,此处可以称为目标节点特征。可以理解的是,目标节点特征与第一采样节点相对应。
其中,在一些实施例中,将第一采样节点与关联的第二采样节点的节点特征进行聚合,得到目标节点特征,包括:
1、获取每一第二采样节点与第一采样节点之间的权重系数;
2、基于权重系数将第二采样节点的节点信息聚合到第一采样节点中,得到目标节点特征。
在本申请实施例中,将每一第一采样节点的节点特征和与其关联的多个第二采样节点的节点特征进行聚合的具体方法,可以为采用不同的权重系数对第一采样节点的节点特征与第二采样节点的节点特征进行加权的方法进行聚合。在本申请实施例中,第一采样节点以及与其关联的第二采样节点的权重系数,可以通过图注意力机制的方法来进行计算,即通过图注意力机制计算每一第一采样节点与关联的多个第二采样节点之间的权重系数,然后再基于计算得到的权重系数将第二采样节点的节点信息聚合到第一采样节点中,从而得到每一第一采样节点的目标节点特征。
步骤105,以目标节点特征为模型输入,并以对应的标签数据为输出标签对待训练的神经网络模型的模型参数进行调整。
其中,在经过图网络构建、节点采样以及特征聚合后,便可以得到多个第一采样节点的目标节点特征。其中,每个第一采样节点还对应了一个节点标签,由于第一采样节点是基于节点标签的标签类别进行不同采样概率的采样得到的,如此得到的模型训练样本的标签分布较为均衡,从而可以提升训练得到的模型的效果。
具体地,在得到多个第一采样节点的目标节点特征后,便可以将目标节点特征作为模型的输入,并将每一目标节点特征对应的标签数据作为模型的输出,对神经网络模型进行训练。其中,此处神经网络模型具体可以为一个常用的分类网络。
其中,在一些实施例中,以目标节点特征为模型输入,并以对应的标签数据为输出标签对待训练的神经网络模型进行训练,包括:
1、将目标节点特征输入至待训练的神经网络模型,得到神经网络模型输出的预测值;
2、根据预测值与目标节点特征对应的标签数据计算损失值,并基于损失值确定反传梯度;
3、根据反传梯度对神经网络模型的参数进行更新。
其中,在本申请实施例中,在将每一第一采样节点的多个第二采样节点的节点信息聚合到第一采样节点中,得到每一第一采样节点的目标节点特征后,便可以将第一采样节点的目标节点特征输入至待训练的神经网络模型中,得到神经网络模型输出的预测值。
然后,可以根据神经网络模型输出的预测值与第一采样节点对应的节点标签计算损失值(即根据预设的损失函数计算相应的数值),并基于损失计算反传梯度。进一步地,可以根据反传梯度对神经网络模型的参数进行更新,得到参数更新后的神经网络模型。然后,可以进一步将另一目标节点特征输入至神经网络模型中,并循环执行上述步骤以对神经网络模型的参数进行迭代更新,直到循环次数达到预设此处或神经网络模型的参数变化小于预设范围,停止对神经网络模型参数的更新,得到训练后的神经网络模型。
根据上述描述可知,本申请实施例提供的目标检测方法,通过获取训练样本数据,训练样本数据包括多个样本对象以及每个样本对象对应的标签数据,标签数据包括多个标签类别;根据样本对象之间的关联关系构建以样本对象为节点的图网络;基于不同标签类别对应的样本对象的数量在图网络中进行节点采样,得到多个第一采样节点,并在第一采样节点的邻居节点中采样第一采样节点关联的第二采样节点;将第一采样节点与关联的第二采样节点的节点特征进行聚合,得到目标节点特征;以目标节点特征为模型输入,并以对应的标签数据为输出标签对待训练的神经网络模型的模型参数进行调整。
以此,本申请提供的目标检测模型的训练方法,通过基于样本对象构建图网络,然后对图网络进行节点采样的方法,来避免样本分布不均衡导致的模型训练效果不好的问题。而且,本方法还进一步对采样得到的节点进行邻居节点采样,并且与采样到的邻居节点进行特征聚合,以获得更为准确的节点特征。然后采用更准确的节点特征进行神经网络模型的训练,从而可以大大提升训练得到的神经网络模型的准确性,进而可以提升目标检测的准确性。
本申请还提供了一种目标检测模型的训练方法,该方法可以使用于计算机设备中,该计算机设备可以为终端也可以为服务器。如图3所示,为本申请提供的目标检测模型的训练方法的另一流程示意图,方法具体包括:
步骤201,计算机设备获取训练样本数据。
其中,在本申请实施例中,具体可以获取用于训练医保检测模型的训练样本数据,该训练样本数据具体可以包括样本对象,即医保账户数据。此外,训练样本数据中还包括了标签数据,该标签数据具体为正常或者异常,一般可以以1代表医保账户正常,以0代表医保账户异常。其中,本申请实施例中,对于训练样本数据中医保账户的获取,具体可以为在公开的数据集中进行获取,即此处的医保账户数据是人为构造或者机器构造的公开的数据,专门用于对各类模型进行训练的数据。并非对真实用户的医保数据进行获取的数据;此处医保账户数据也可以为对真实用户的医保账户数据进行获取的数据,在该情况下,本申请提供的方法是在经过用户授权的基础上,并且是在遵循相关法律法规的基础上对用户的医保账户数据进行获取和使用的。
步骤202,计算机设备计算训练样本数据中样本对象之间的相似度。
其中,在获取到训练样本数据后,可以提取训练样本数据中样本对象数据。如前所述,样本对象数据包括了医保账户数据,医保账户数据中包含了用户数据、医生数据以及医院数据等。获取到样本对象的医保账户数据后,便可以根据用户数据之间的关系、用户与医生之间的关系、用户与医院之间的关系来计算样本对象之间的相似度。
在一些实施例中,在获取到样本对象的医保账户数据后,可以对医保账户数据进行特征编码,得到特征向量,然后计算特征向量之间的相似度,得到样本对象之间的相似度。
步骤203,计算机设备根据样本对象之间的相似度构建图网络。
其中,在计算得到样本对象之间的相似度后,便可以根据样本对象之间的相似度构建图网络G=(V,E)。其中,V表示节点集合,E表示边集合。在该图网络中,节点可以表示样本对象,也可以理解为用户,边表示样本对象之间的关系,也可以理解为用户之间的关系。图中的节点除了表示特征信息之外,还有标签信息,即该节点表示的医保账户为正常账户还是异常账户。
如图4所示,为本申请实施例中构建的图网络的示意图。如图所示,图中包括了十个节点,节点之间通过边连接。十个节点中存在三个异常医保账户对应的节点和七个正常医保账户对应的节点,异常医保账户对应的节点以斜杠底纹标识出,分别为节点v4、节点v5和节点v7。其他节点为正常医保账户对应的节点。
步骤204,计算机设备根据图网络中节点标签进行节点采样,得到多个目标节点。
其中,在构建了以医保账户为节点,医保账户之间的关联关系为边的图网络后,计算机设备便可以在该图网络中进行节点采样,以得到标签分布均衡的模型训练样本。从而避免对模型进行训练的样本分布不均衡导致模型学习不到对异常医保账户进行识别的能力。
在构建的图网络G的节点集合V中,存在数量较多的正常医保账户对应的节点,该类正常医保账户对应的节点对应的集合可以记为Vm。节点集合中还存在少量的异常医保账户对应的节点,该类异常医保账户对应的节点对应的集合可以记为Vn。为了平衡节点的标签类别分布,可以通过对图中的节点进行采样的方法来解决标签类别分布不平衡的问题,同时可以将采样得到的节点称为目标节点。
为了平衡节点标签类别的分布,可以增加数量较少的类别的节点的采样概率,并减少数量较多的类别的节点的采样概率,使得节点的采样概率与所属类别的标签数量成反比。具体地,可以先计算图中节点总数和目标节点所述类别的标签数量之间的第一比值,然后可以确定目标节点的采样概率与该第一比值成正比关系。即目标节点的采样概率可以表示如下:
其中,N表示图中节点的总数,L(vi)表示节点vi所属的标签类别,|L(vi)|表示跟节点vi为相同标签类别的节点数量,即该标签类别的数量。p(vi)表示节点vi的采样概率。∝表示正比符号。该公式的含义为当节点vi所属标签类别的数量越多时,节点vi被采样的概率越低;反之,vi被采样的概率越高,这样就能实现节点标签类别分布的平衡。
步骤205,计算机设备对每一目标节点的邻居节点进行采样,得到每一目标节点的采样邻居节点。
其中,在对图网络中节点按照不同的采样概率进行采样,得到标签分布均衡的目标节点后,由于每个目标节点可能存在多个邻居节点以及多跳邻居节点,每个邻居节点都含有不同的语义信息,为了获取有用的邻居节点信息,还需要对每个目标节点的邻居节点进行采样。
具体地,对于任一目标节点其邻居节点的集合/>该集合中也包括结点vi。首先,可以根据目标节点的邻居节点的度的大小,采样k1个邻居节点。节点的度越大,表明该节点蕴涵了丰富的结构信息,有助于节点之间的信息传递与聚合,学习到更有效的节点表示。即可以先确定邻居节点的度,然后确定邻居节点的采样概率与该邻居节点的度成正比关系。根据邻居节点的度采样邻居节点的概率可以表示如下:
p(vj)∝deg(vj)
其中,为目标节点vi的邻居节点。deg(vj)表示节点vj的度。p(vj)表示节点vj的采样概率。
例如,当目标节点为图4中的节点v1时,那么在其邻居节点中进行采样时,便可以采样到节点v2和节点v9作为采样邻居节点。因为节点v2的度为3,节点v9的度为4,大于其他邻居节点的度,从而他们的采样概率也就越大。
此外,由于图中节点的标签类别分布不均衡,对于任一目标节点vi,它的邻居节点集合中也可能存在标签类别分布不平衡的问题,因此还需要对邻居节点的标签进行采样。因此,可以根据邻居节点的标签类别采样k2个邻居节点,而且,可以尽可能地采样与目标节点标签类别相同的邻居节点,这样才能够有效地传递相同类别节点之间的信息,并防止不同类别节点间的信息传递与聚合。
当目标节点即目标节点的标签属于多数类别时,应该增加标签属于多数类别的邻居节点的采样概率,此时可以定义邻居节点的采样概率与标签数量成正比,这样这类邻居节点的采样概率就会变大。当目标节点/>即目标节点的标签属于少数类别时,应该增加标签属于少数类别的邻居节点的采样概率,此时可以定义邻居节点的采样概率与标签数量成反比,这样这类邻居节点的采样概率就会变大。具体地,当目标节点的标签属于多数类别时,可以先计算邻居节点的数量与节点总数的第二比值,然后确定对邻居节点的采样概率与上述第二比值成正比;当目标节点的标签属于少数类别时,则可以确定对邻居节点的采样概率与上述第二比值成反比。根据标签类别采样邻居节点的概率可以表示如下:
例如,当目标节点为v1时,那么采样得到的邻居节点便可以为v3和v6,因为他们具有相同的标签类别。当目标节点为v7时,那么采样得到的邻居节点便可以为v4和v5,因为他们与目标节点具有相同的标签类别。
至此,可以为每一目标节点采样了k1+k2个采样邻居节点。
步骤206,计算机设备基于图注意力机制计算每一目标节点的多个采样邻居节点的权重系数。
其中,在采样得到多个目标节点,并对每一目标节点采样得到多个采样邻居节点后,便可以进一步将目标节点的节点信息和其对应的采样邻居节点的节点信息进行聚合,得到更准确的节点表示。由于每个目标节点都有多个邻居节点,而不同的邻居节点蕴含了不同的语义信息,并且不同的邻居节点可能也有不同的重要性。因此,可以采用注意力机制学习不同邻居节点的权重,然后采用不同的权重将采样邻居节点的节点信息聚合到目标节点中,得到有效的节点表示来刻画节点特征。其中,对于每一邻居节点,可以分别采用注意力机制计算其对应的权重系数,然后基于每一邻居节点的权重系数对邻居节点的节点特征进行加权,然后再采用激活函数对加权结果进行激活处理。具体地,采用注意力机制来计算不同采样邻居节点的权重的过程表示如下:
/>
其中,αij表示邻居结点vj对目标结点vi的权重。表示l层的变换矩阵。/>表示目标结点vi在l-1层的结点表示。/>表示邻居结点vj在l-1层的结点表示。/>表示注意力向量。σ(·)表示激活函数。||表示拼接变换。·T表示转置变换。/>表示在l层目标结点的结点表示。dl-1和dl分别表示目标结点在l-1层和l层的结点表示维度。
步骤207,计算机设备根据采样邻居节点的权重系数将采样邻居节点的节点特征聚合到目标节点中,得到每一目标节点的目标节点特征。
其中,在计算得到目标节点的每一采样邻居节点的权重系数后,便可以进一步基于权重系数将采样邻居节点的节点特征聚合到目标节点中,从而计算得到每一目标节点的目标节点特征。如此,便可以得到标签均衡而且有效的模型样本。
步骤208,计算机设备以目标节点特征为模型输入,以目标节点的节点标签为模型输出标签训练神经网络模型。
其中,在经过多个信息聚合层聚合得到目标节点的目标节点特征后,可以进一步将其输入至待进行训练的神经网络模型中,此处待进行训练的神经网络模型具体可以为一个多层前馈网络,从而得到模型输出的目标节点vi属于正常账户的概率:
其中,表示多层前馈网络的变换矩阵,/>表示目标节点vi为正常账户的概率。
其中,由于对医保账户进行检测为一个二分类的问题,即检测医保账户为正常账户还是异常账户,因此对该神经网络模型进行训练可以采用交叉熵作为损失函数来对神经网络模型进行训练,该损失函数的构建目标是使得模型输出的节点属于正常账户的概率与该节点对应的标签数据尽量接近。即当目标节点对应的账户为异常账户时,该节点特征经目标检测模型后输出的概率值应接近0;当目标节点对应的账户为正常账户时,则该节点特征经目标检测模型后输出的概率值应该接近。具体地,损失函数构建如下:
其中,表示目标结点vi所属的真实标签类别。N表示图中节点的数量。
然后,可以基于上述损失函数对神经网络模型进行训练,得到训练后的神经网络模型。
以此,本申请提供的目标检测模型的训练方法,通过获取训练样本数据,训练样本数据包括多个样本对象以及每个样本对象对应的标签数据,标签数据包括多个标签类别;根据样本对象之间的关联关系构建以样本对象为节点的图网络;基于不同标签类别对应的样本对象的数量在图网络中进行节点采样,得到多个第一采样节点,并在第一采样节点的邻居节点中采样第一采样节点关联的第二采样节点;将第一采样节点与关联的第二采样节点的节点特征进行聚合,得到目标节点特征;以目标节点特征为模型输入,并以对应的标签数据为输出标签对待训练的神经网络模型的模型参数进行调整。
以此,本申请提供的目标检测模型的训练方法,通过基于样本对象构建图网络,然后对图网络进行节点采样的方法,来避免样本分布不均衡导致的模型训练效果不好的问题。而且,本方法还进一步对采样得到的节点进行邻居节点采样,并且与采样到的邻居节点进行特征聚合,以获得更为准确的节点特征。然后采用更准确的节点特征进行神经网络模型的训练,从而可以大大提升训练得到的神经网络模型的准确性,进而可以提升目标检测的准确性。
本申请另一方面还提供了一种目标检测方法,该方法具体可以应用于目标检测装置中。如图5所示,为本申请提供的目标检测方法的流程示意图,该方法具体包括如下步骤:
步骤301,获取待检测的目标的目标数据。
其中,在本申请实施例中,当根据本申请前述的目标检测模型的训练方法训练得到神经网络模型,并将该神经网络模型部署上线后,便可以根据该神经网络模型进行目标检测。具体地,当该神经网络模型为前述医保检测模型时,可以用来进行医保检测。
具体地,可以先获取待检测的目标的目标数据,例如可以获取待检测的用户的医保数据。其中,对用户的医保数据的获取可以是在提前对用户进行告知,并征得用户的授权的基础上进行获取的。而且,对用户医保数据的处理是完全遵循相关法律法规规定的。
步骤302,对目标数据进行特征编码,得到目标特征。
在获取到目标的目标数据,例如获取到医保数据后,便可以对医保数据进行特征编码,具体可以采用前述的词嵌入的方式对医保数据进行特征编码,得到目标特征。
步骤303,将目标特征输入至神经网络模型中,得到神经网络模型输出的预测值。
在编码得到目标特征后,便可以将目标特征输入至神经网络模型中进行检测,其中,此处的神经网络模型便可以为前述目标检测模型的训练方法中训练得到的神经网络模型。
步骤304,根据预测值与预设阈值的比对结果确定对目标进行检测的检测结果。
其中,神经网络模型在接收到输入的目标特征并对目标特征进行检测后,可以输出预测值。此时可以将神经网络模型输出的预测值与预设的阈值进行比对,当预测值大于上述预设阈值时,则可以确定该目标为正常,否则确认目标异常。当目标为医保账户时,则当预测值大于预设阈值时,确认该医保账户为正常账户,反之则可以确认医保账户为异常账户。
为了更好地实施以上目标检测模型的训练方法,本申请实施例还提供一种目标检测模型的训练装置,该目标检测模型的训练装置可以集成在终端或服务器中。
例如,如图6所示,为本申请实施例提供的目标检测模型的训练装置的结构示意图,装置可以包括获取单元401、构建单元402、采样单元403、聚合单元404以及训练单元405,如下:
第一获取单元401,用于获取训练样本数据,训练样本数据包括多个样本对象以及每个样本对象对应的标签数据,标签数据包括多个标签类别;
构建单元402,用于根据样本对象之间的关联关系构建以样本对象为节点的图网络;
采样单元403,用于基于不同标签类别对应的样本对象的数量在图网络中进行节点采样,得到多个第一采样节点,并在第一采样节点的邻居节点中采样第一采样节点关联的第二采样节点;
聚合单元404,用于将第一采样节点与关联的第二采样节点的节点特征进行聚合,得到目标节点特征;
训练单元405,用于以目标节点特征为模型输入,并以对应的标签数据为输出标签对待训练的神经网络模型进行训练。
可选地,在一些实施例中,采样单元,包括:
第一获取子单元,用于获取图网络中每一标签类别对应的节点的数量;
第一计算子单元,用于基于每一标签类别对应的节点的数量与图网络中的节点总数计算每一节点的第一采样概率;
第一采样子单元,用于基于第一采样概率在图网络中进行节点采样,得到多个第一采样节点;
第二采样子单元,用于在第一采样节点的邻居节点中采样第一采样节点关联的第二采样节点。
可选地,在一些实施例中,第二采样子单元,包括:
第一采样模块,用于获取第一采样节点的每一邻居节点的度,并根据每一邻居节点的度在第一采样节点的邻居节点中采样第一数量的第二采样节点;
第二采样模块,用于获取第一采样节点的每一邻居节点的节点标签,并根据每一邻居节点的节点标签与第一采样节点的节点标签之间的关系在第一采样节点的邻居节点中采样第二数量的第二采样节点。
可选地,在一些实施例中,第一采样模块,包括:
获取子模块,用于获取第一采样节点的每一邻居节点的度;
计算子模块,用于根据每一邻居节点的度计算每一邻居节点对应的第二采样概率;
采样子模块,用于根据第二采样概率在第一采样节点的邻居节点中采样第一数量的第二采样节点。
可选地,在一些实施例中,构建单元,包括:
第二获取子单元,用于获取每一样本对象的样本数据;
第二计算子单元,用于根据样本数据计算样本对象之间的关联关系;
构建子单元,用于以样本对象为节点,并以样本对象之间的关联关系为边构建图网络。
可选地,在一些实施例中,聚合单元,包括:
第三获取子单元,用于获取每一第二采样节点与第一采样节点之间的权重系数;
聚合子单元,用于基于权重系数将第二采样节点的节点信息聚合到第一采样节点中,得到目标节点特征。
可选地,在一些实施例中,训练单元,包括:
输入子单元,用于将目标节点特征输入至待训练的神经网络模型,得到神经网络模型输出的预测值;
第三计算子单元,用于根据预测值与目标节点特征对应的标签数据计算损失值,并基于损失值确定反传梯度;
更新子单元,用于根据反传梯度对神经网络模型的参数进行更新。
具体实施时,以上各个单元可以作为独立的实体来实现,也可以进行任意组合,作为同一或若干个实体来实现,以上各个单元的具体实施可参见前面的方法实施例,在此不再赘述。
根据上述描述可知,本申请实施例提供的目标检测模型的训练装置,通过第一获取单元401获取训练样本数据,训练样本数据包括多个样本对象以及每个样本对象对应的标签数据,标签数据包括多个标签类别;构建单元402根据样本对象之间的关联关系构建以样本对象为节点的图网络;采样单元403基于不同标签类别对应的样本对象的数量在图网络中进行节点采样,得到多个第一采样节点,并在第一采样节点的邻居节点中采样第一采样节点关联的第二采样节点;聚合单元404将第一采样节点与关联的第二采样节点的节点特征进行聚合,得到目标节点特征;训练单元405以目标节点特征为模型输入,并以对应的标签数据为输出标签对待训练的神经网络模型进行训练。
以此,本申请提供的目标检测模型的训练装置,通过基于样本对象构建图网络,然后对图网络进行节点采样的方法,来避免样本分布不均衡导致的模型训练效果不好的问题。而且,本方法还进一步对采样得到的节点进行邻居节点采样,并且与采样到的邻居节点进行特征聚合,以获得更为准确的节点特征。然后采用更准确的节点特征进行神经网络模型的训练,从而可以大大提升训练得到的神经网络模型的准确性,进而可以提升目标检测的准确性。
为了更好地实施以上目标检测方法,本申请实施例还提供一种目标检测装置,该目标检测装置可以集成在终端或服务器中。
如图7所示,为本申请实施例提供的目标检测装置的结构示意图,装置可以包括第二获取单元501、编码单元502、检测单元503以及确定单元504,如下:
第二获取单元501,用于获取待检测的目标的目标数据;
编码单元502,用于对目标数据进行特征编码,得到目标特征;
检测单元503,用于将目标特征输入至神经网络模型中,得到神经网络模型输出的预测值,神经网络模型为根据第一方面的目标检测模型的训练方法训练得到的神经网络模型;
确定单元504,用于根据预测值与预设阈值的比对结果确定对目标进行检测的检测结果。
具体实施时,以上各个单元可以作为独立的实体来实现,也可以进行任意组合,作为同一或若干个实体来实现,以上各个单元的具体实施可参见前面的方法实施例,在此不再赘述。
本申请实施例还提供一种计算机设备,该计算机设备可以为终端或服务器,如图8所示,为本申请提供的计算机设备的结构示意图。具体来讲:
该计算机设备可以包括一个或者一个以上处理核心的处理单元601、一个或一个以上存储介质的存储单元602、电源模块603和输入模块604等部件。本领域技术人员可以理解,图8中示出的计算机设备结构并不构成对计算机设备的限定,可以包括比图示更多或更少的部件,或者组合某些部件,或者不同的部件布置。其中:
处理单元601是该计算机设备的控制中心,利用各种接口和线路连接整个计算机设备的各个部分,通过运行或执行存储在存储单元602内的软件程序和/或模块,以及调用存储在存储单元602内的数据,执行计算机设备的各种功能和处理数据。可选的,处理单元601可包括一个或多个处理核心;优选的,处理单元601可集成应用处理器和调制解调处理器,其中,应用处理器主要处理操作系统、对象界面和应用程序等,调制解调处理器主要处理无线通信。可以理解的是,上述调制解调处理器也可以不集成到处理单元601中。
存储单元602可用于存储软件程序以及模块,处理单元601通过运行存储在存储单元602的软件程序以及模块,从而执行各种功能应用以及目标检测。存储单元602可主要包括存储程序区和存储数据区,其中,存储程序区可存储操作系统、至少一个功能所需的应用程序(比如声音播放功能、图像播放功能以及网页访问等)等;存储数据区可存储根据计算机设备的使用所创建的数据等。此外,存储单元602可以包括高速随机存取存储器,还可以包括非易失性存储器,例如至少一个磁盘存储器件、闪存器件、或其他易失性固态存储器件。相应地,存储单元602还可以包括存储器控制器,以提供处理单元601对存储单元602的访问。
计算机设备还包括给各个部件供电的电源模块603,优选的,电源模块603可以通过电源管理系统与处理单元601逻辑相连,从而通过电源管理系统实现管理充电、放电、以及功耗管理等功能。电源模块603还可以包括一个或一个以上的直流或交流电源、再充电系统、电源故障检测电路、电源转换器或者逆变器、电源状态指示器等任意组件。
该计算机设备还可包括输入模块604,该输入模块604可用于接收输入的数字或字符信息,以及产生与对象设置以及功能控制有关的键盘、鼠标、操作杆、光学或者轨迹球信号输入。
尽管未示出,计算机设备还可以包括显示单元等,在此不再赘述。具体在本实施例中,计算机设备中的处理单元601会按照如下的指令,将一个或一个以上的应用程序的进程对应的可执行文件加载到存储单元602中,并由处理单元601来运行存储在存储单元602中的应用程序,从而实现各种功能,如下:
获取训练样本数据,训练样本数据包括多个样本对象以及每个样本对象对应的标签数据,标签数据包括多个标签类别;根据样本对象之间的关联关系构建以样本对象为节点的图网络;基于不同标签类别对应的样本对象的数量在图网络中进行节点采样,得到多个第一采样节点,并在第一采样节点的邻居节点中采样第一采样节点关联的第二采样节点;将第一采样节点与关联的第二采样节点的节点特征进行聚合,得到目标节点特征;以目标节点特征为模型输入,并以对应的标签数据为输出标签对待训练的神经网络模型的模型参数进行调整。
或者,获取待检测的目标的目标数据;对目标数据进行特征编码,得到目标特征;将目标特征输入至神经网络模型中,得到神经网络模型输出的预测值,神经网络模型为根据本申请提供的目标检测模型的训练方法训练得到的神经网络模型;根据预测值与预设阈值的比对结果确定对目标进行检测的检测结果。
应当说明的是,本申请实施例提供的计算机设备与上文实施例中的方法属于同一构思,以上各个操作的具体实施可参见前面的实施例,在此不作赘述。
本领域普通技术人员可以理解,上述实施例的各种方法中的全部或部分步骤可以通过指令来完成,或通过指令控制相关的硬件来完成,该指令可以存储于一计算机可读存储介质中,并由处理器进行加载和执行。
为此,本发明实施例提供一种计算机可读存储介质,其中存储有多条指令,该指令能够被处理器进行加载,以执行本发明实施例所提供的任一种方法中的步骤。例如,该指令可以执行如下步骤:
获取训练样本数据,训练样本数据包括多个样本对象以及每个样本对象对应的标签数据,标签数据包括多个标签类别;根据样本对象之间的关联关系构建以样本对象为节点的图网络;基于不同标签类别对应的样本对象的数量在图网络中进行节点采样,得到多个第一采样节点,并在第一采样节点的邻居节点中采样第一采样节点关联的第二采样节点;将第一采样节点与关联的第二采样节点的节点特征进行聚合,得到目标节点特征;以目标节点特征为模型输入,并以对应的标签数据为输出标签对待训练的神经网络模型的模型参数进行调整。
或者,获取待检测的目标的目标数据;对目标数据进行特征编码,得到目标特征;将目标特征输入至神经网络模型中,得到神经网络模型输出的预测值,神经网络模型为根据本申请提供的目标检测模型的训练方法训练得到的神经网络模型;根据预测值与预设阈值的比对结果确定对目标进行检测的检测结果。
以上各个操作的具体实施可参见前面的实施例,在此不再赘述。
其中,该计算机可读存储介质可以包括:只读存储器(ROM,Read Only Memory)、随机存取记忆体(RAM,Random Access Memory)、磁盘或光盘等。
由于该计算机可读存储介质中所存储的指令,可以执行本发明实施例所提供的任一种方法中的步骤,因此,可以实现本发明实施例所提供的任一种方法所能实现的有益效果,详见前面的实施例,在此不再赘述。
其中,根据本申请的一个方面,提供了一种计算机程序产品或计算机程序,该计算机程序产品或计算机程序包括计算机指令,该计算机指令存储在存储介质中。计算机设备的处理器从存储介质读取该计算机指令,处理器执行该计算机指令,使得该计算机设备执行上述目标检测方法中各种可选实现方式中提供的方法。
以上对本发明实施例所提供的目标检测模型的训练方法及装置、目标检测方法及装置进行了详细介绍,本文中应用了具体个例对本发明的原理及实施方式进行了阐述,以上实施例的说明只是用于帮助理解本发明的方法及其核心思想;同时,对于本领域的技术人员,依据本发明的思想,在具体实施方式及应用范围上均会有改变之处,综上,本说明书内容不应理解为对本发明的限制。
Claims (10)
1.一种目标检测模型的训练方法,其特征在于,所述方法包括:
获取训练样本数据,所述训练样本数据包括多个样本对象以及每个样本对象对应的标签数据,所述标签数据包括多个标签类别;
根据所述样本对象之间的关联关系构建以所述样本对象为节点的图网络;
基于不同标签类别对应的样本对象的数量在所述图网络中进行节点采样,得到多个第一采样节点,并在所述第一采样节点的邻居节点中采样所述第一采样节点关联的第二采样节点;
将所述第一采样节点与关联的第二采样节点的节点特征进行聚合,得到目标节点特征;
以所述目标节点特征为模型输入,并以对应的标签数据为输出标签对待训练的神经网络模型的模型参数进行调整。
2.根据权利要求1所述的方法,其特征在于,所述基于不同标签类别对应的样本对象的数量在所述图网络中进行节点采样,得到多个第一采样节点,并在所述第一采样节点的邻居节点中采样所述第一采样节点关联的第二采样节点,包括:
获取所述图网络中每一标签类别对应的节点的数量;
基于每一标签类别对应的节点的数量与所述图网络中的节点总数计算每一节点的第一采样概率;
基于所述第一采样概率在所述图网络中进行节点采样,得到多个第一采样节点;
在所述第一采样节点的邻居节点中采样第一采样节点关联的第二采样节点。
3.根据权利要求2所述的方法,其特征在于,所述在所述第一采样节点的邻居节点中采样第一采样节点关联的第二采样节点,包括:
获取所述第一采样节点的每一邻居节点的度,并根据每一邻居节点的度在所述第一采样节点的邻居节点中采样第一数量的第二采样节点;
获取所述第一采样节点的每一邻居节点的节点标签,并根据每一邻居节点的节点标签与所述第一采样节点的节点标签之间的关系在所述第一采样节点的邻居节点中采样第二数量的第二采样节点。
4.根据权利要求3所述的方法,其特征在于,所述获取所述第一采样节点的每一邻居节点的度,并根据每一邻居节点的度在所述第一采样节点的邻居节点中采样第一数量的第二采样节点,包括:
获取所述第一采样节点的每一邻居节点的度;
根据每一邻居节点的度计算每一邻居节点对应的第二采样概率;
根据所述第二采样概率在所述第一采样节点的邻居节点中采样第一数量的第二采样节点。
5.根据权利要求1所述的方法,其特征在于,所述根据所述样本对象之间的关联关系构建以所述样本对象为节点的图网络,包括:
获取每一样本对象的样本数据;
根据所述样本数据计算样本对象之间的关联关系;
以所述样本对象为节点,并以所述样本对象之间的关联关系为边构建图网络。
6.根据权利要求1所述的方法,其特征在于,所述将所述第一采样节点与关联的第二采样节点的节点特征进行聚合,得到目标节点特征,包括:
获取每一第二采样节点与所述第一采样节点之间的权重系数;
基于所述权重系数将所述第二采样节点的节点信息聚合到所述第一采样节点中,得到目标节点特征。
7.根据权利要求1所述的方法,其特征在于,所述以所述目标节点特征为模型输入,并以对应的标签数据为输出标签对待训练的神经网络模型的模型参数进行调整,包括:
将所述目标节点特征输入至待训练的神经网络模型,得到所述神经网络模型输出的预测值;
根据所述预测值与所述目标节点特征对应的标签数据计算损失值,并基于所述损失值确定反传梯度;
根据所述反传梯度对所述神经网络模型的参数进行更新。
8.一种目标检测方法,其特征在于,所述方法包括:
获取待检测的目标的目标数据;
对所述目标数据进行特征编码,得到目标特征;
将所述目标特征输入至神经网络模型中,得到所述神经网络模型输出的预测值,所述神经网络模型为根据权利要求1至7中任一项所述的目标检测模型的训练方法训练得到的神经网络模型;
根据所述预测值与预设阈值的比对结果确定对所述目标进行检测的检测结果。
9.一种目标检测模型的训练装置,其特征在于,所述装置包括:
第一获取单元,用于获取训练样本数据,所述训练样本数据包括多个样本对象以及每个样本对象对应的标签数据,所述标签数据包括多个标签类别;
构建单元,用于根据所述样本对象之间的关联关系构建以所述样本对象为节点的图网络;
采样单元,用于基于不同标签类别对应的样本对象的数量在所述图网络中进行节点采样,得到多个第一采样节点,并在所述第一采样节点的邻居节点中采样所述第一采样节点关联的第二采样节点;
聚合单元,用于将所述第一采样节点与关联的第二采样节点的节点特征进行聚合,得到目标节点特征;
训练单元,用于以所述目标节点特征为模型输入,并以对应的标签数据为输出标签对待训练的神经网络模型的模型参数进行调整。
10.一种目标检测装置,其特征在于,所述装置包括:
第二获取单元,用于获取待检测的目标的目标数据;
编码单元,用于对所述目标数据进行特征编码,得到目标特征;
检测单元,用于将所述目标特征输入至神经网络模型中,得到所述神经网络模型输出的预测值,所述神经网络模型为根据权利要求1至7中任一项所述的目标检测模型的训练方法训练得到的神经网络模型;
确定单元,用于根据所述预测值与预设阈值的比对结果确定对所述目标进行检测的检测结果。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310292326.5A CN116975622A (zh) | 2023-03-20 | 2023-03-20 | 目标检测模型的训练方法及装置、目标检测方法及装置 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310292326.5A CN116975622A (zh) | 2023-03-20 | 2023-03-20 | 目标检测模型的训练方法及装置、目标检测方法及装置 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN116975622A true CN116975622A (zh) | 2023-10-31 |
Family
ID=88470078
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202310292326.5A Pending CN116975622A (zh) | 2023-03-20 | 2023-03-20 | 目标检测模型的训练方法及装置、目标检测方法及装置 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN116975622A (zh) |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117454097A (zh) * | 2023-12-25 | 2024-01-26 | 数据空间研究院 | 一种数据清洗预测方法、装置、电子设备及存储介质 |
-
2023
- 2023-03-20 CN CN202310292326.5A patent/CN116975622A/zh active Pending
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117454097A (zh) * | 2023-12-25 | 2024-01-26 | 数据空间研究院 | 一种数据清洗预测方法、装置、电子设备及存储介质 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN113822494B (zh) | 风险预测方法、装置、设备及存储介质 | |
CN109344884B (zh) | 媒体信息分类方法、训练图片分类模型的方法及装置 | |
CN110196908A (zh) | 数据分类方法、装置、计算机装置及存储介质 | |
CN110046981B (zh) | 一种信用评估方法、装置及存储介质 | |
CN113761250A (zh) | 模型训练方法、商户分类方法及装置 | |
CN111522915A (zh) | 中文事件的抽取方法、装置、设备及存储介质 | |
WO2023185539A1 (zh) | 机器学习模型训练方法、业务数据处理方法、装置及系统 | |
US20210081800A1 (en) | Method, device and medium for diagnosing and optimizing data analysis system | |
CN116975622A (zh) | 目标检测模型的训练方法及装置、目标检测方法及装置 | |
CN115905528A (zh) | 具有时序特征的事件多标签分类方法、装置及电子设备 | |
CN115858886A (zh) | 数据处理方法、装置、设备及可读存储介质 | |
CN113822144A (zh) | 一种目标检测方法、装置、计算机设备和存储介质 | |
CN114418189A (zh) | 水质等级预测方法、系统、终端设备及存储介质 | |
CN113569955A (zh) | 一种模型训练方法、用户画像生成方法、装置及设备 | |
CN113609337A (zh) | 图神经网络的预训练方法、训练方法、装置、设备及介质 | |
US20230214676A1 (en) | Prediction model training method, information prediction method and corresponding device | |
CN114169418B (zh) | 标签推荐模型训练方法及装置、标签获取方法及装置 | |
CN115905293A (zh) | 作业执行引擎的切换方法及装置 | |
CN117523218A (zh) | 标签生成、图像分类模型的训练、图像分类方法及装置 | |
CN116415624A (zh) | 模型训练方法及装置、内容推荐方法及装置 | |
CN114529191A (zh) | 用于风险识别的方法和装置 | |
CN111563191A (zh) | 基于图网络的数据处理系统 | |
CN116992031B (zh) | 数据处理方法、装置、电子设备、存储介质及程序产品 | |
CN113254635B (zh) | 数据处理方法、装置及存储介质 | |
CN115563954A (zh) | 机器阅读理解方法、装置、设备和计算机可读存储介质 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication |