CN113077051B - 网络模型训练方法、装置、文本分类模型及网络模型 - Google Patents

网络模型训练方法、装置、文本分类模型及网络模型 Download PDF

Info

Publication number
CN113077051B
CN113077051B CN202110402004.2A CN202110402004A CN113077051B CN 113077051 B CN113077051 B CN 113077051B CN 202110402004 A CN202110402004 A CN 202110402004A CN 113077051 B CN113077051 B CN 113077051B
Authority
CN
China
Prior art keywords
training
branch network
samples
model
branch
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN202110402004.2A
Other languages
English (en)
Other versions
CN113077051A (zh
Inventor
黄深能
赵茜
利啟东
佟博
高玮
叶凯亮
胡盼盼
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Nanjing Lingdong Shuzhi Technology Co ltd
Original Assignee
Nanjing Lingdong Shuzhi Technology Co ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Nanjing Lingdong Shuzhi Technology Co ltd filed Critical Nanjing Lingdong Shuzhi Technology Co ltd
Priority to CN202110402004.2A priority Critical patent/CN113077051B/zh
Publication of CN113077051A publication Critical patent/CN113077051A/zh
Application granted granted Critical
Publication of CN113077051B publication Critical patent/CN113077051B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • Evolutionary Computation (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • General Health & Medical Sciences (AREA)
  • Software Systems (AREA)
  • Molecular Biology (AREA)
  • Computing Systems (AREA)
  • Biophysics (AREA)
  • Biomedical Technology (AREA)
  • Mathematical Physics (AREA)
  • Computational Linguistics (AREA)
  • Health & Medical Sciences (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Evolutionary Biology (AREA)
  • Information Retrieval, Db Structures And Fs Structures Therefor (AREA)

Abstract

本申请涉及一种网络模型训练方法、装置、文本分类模型及网络模型,属于计算机技术领域。该网络模型训练方法包括:获取训练样本集,所述训练样本集包括常见类样本和少见类样本;利用所述训练样本集对双边分支网络模型进行训练,得到训练好的文本分类模型,其中,所述双边分支网络模型的两个分支网络均包括基于多尺度注意力机制模块的编码层。本申请实施例中,通过基于多尺度注意力机制模块的编码层来构建双边分支网络模型,通过将多个信息头的信息进行融合,来获取丰富的语义信息,从而能使模型尽快收敛,以及提高模型的分类精度。

Description

网络模型训练方法、装置、文本分类模型及网络模型
技术领域
本申请属于计算机技术领域,具体涉及一种网络模型训练方法、装置、文本分类模型及网络模型。
背景技术
随着产生的文本数量越来越多,纯靠人力去分类显然是不现实。目前文本分类技术主要分两类,一类是基于传统机器学习的分类方法,主要有贝叶斯,支持向量机(SupportVector Machine,SVM)等;另外一类是基于深度学习的分类方法。传统机器学习方法需要的数据量不是很多,也不过多依靠计算机,但是由于提取特征相对简单,模型的泛化能力相对较弱,分类准确性较差。而深度学习方法对数量依赖比较大,在实际应用中,由于有些数据的采集相对困难,并没有足够的数据进行训练,使得模型的收敛相对较难。另外,有些行业的数据量并不是很均衡,例如,病历文本分类中,一些少见病情对应的文本可能是不够的,这就造成模型对少见类缺乏足够的拟合能力,导致模型的分类精度不高。
发明内容
鉴于此,本申请的目的在于提供一种网络模型训练方法、装置、文本分类模型及网络模型,以改善现有分类模型的收敛相对较慢以及分类精度不高的问题。
本申请的实施例是这样实现的:
第一方面,本申请实施例提供了一种网络模型训练方法,包括:获取训练样本集,所述训练样本集包括常见类样本和少见类样本;利用所述训练样本集对双边分支网络模型进行训练,得到训练好的文本分类模型,其中,所述双边分支网络模型的两个分支网络均包括基于多尺度注意力机制模块的编码层。
本申请实施例中,通过基于多尺度注意力机制模块的编码层来构建双边分支网络模型,通过将多个信息头的信息进行融合,来获取丰富的语义信息,从而能使模型尽快收敛,以及提高模型的分类精度。
结合第一方面实施例的一种可能的实施方式,利用所述训练样本集对双边分支网络模型进行训练,包括:每次迭代训练时,对所述训练样本集中的样本进行随机采样,以及对所述训练样本集中的不同类别样本进行权重采样,在进行权重采样时,类别数量少的样本的采样频率大于类别数量多的样本的采样频率;将随机采样的N个样本输入所述双边分支网络模型中的第一分支网络中,将权重采样的N个样本输入所述双边分支网络模型中的第二分支网络中,对所述双边分支网络模型进行迭代训练,其中,N为正整数,且小于所述训练样本集中的样本数。
本申请实施例中,通过在每次迭代训练时,通过对训练样本集中的样本进行随机采样,将获得的N个样本输入第一分支网络中,对第一分支网络进行训练;通过对训练样本集中的不同类别样本进行权重采样,将获得的N个样本输入第二分支网络中,对第二分支网络进行训练,且在进行权重采样时,类别数量越少的样本,采样的频率越高,这样能解决模型对少见类样本缺乏足够的拟合能力的问题。
结合第一方面实施例的一种可能的实施方式,训练过程中,通过梯度平均模长和训练轮次动态调整两个分支网络对应的自适应权重因子,从而动态调整两个分支网络的特征融合。
本申请实施例中,通过梯度平均模长和训练轮次动态调整两个分支网络对应的自适应权重因子,从而动态调整两个分支网络的特征融合,能够有效防止过度采样少见类样本数据造成模型对少见类样本数据过拟合,同时也能抑制常见类别数据对模型的影响过大的问题。
结合第一方面实施例的一种可能的实施方式,所述双边分支网络模型中的第一分支网络和第二分支网络的自适应权重因子分别为W1和W2;其中,W1=a1*g1,W2=a2*g2,a1=1-T/2*Tmax,a2=T/2*Tmax,T表示当前训练轮次,Tmax表示训练最大轮次,K为最大样本类别,/>和fci分别表示当前训练轮次输入所述第一分支网络的样本中的第i类别样本的真实值和预测值,/>和fri分别表示当前训练轮次输入所述第二分支网络的样本中的第i类别样本的真实值和预测值。
本申请实施例中,在刚开始训练的时候,T比较小,a1初始值接近1,等到训练后期接近0.5。在训练初期第一分支网络占的权重比较大,主要训练第一分支网络,以提高模型的基础特征提取能力;等到了后期第一分支网络的a1值和第二分支网络的a2值相近,此时,主要靠每个支路上的梯度平均模长(g1、g2)动态调节对应的权重,当第一分支网络上的支路梯度平均模长比较大,则占的权重就大,能够有效防止过度采样少见类样本数据造成模型对少见类样本数据过拟合,同时也能抑制常见类别数据对模型的影响过大的问题。
第二方面,本申请实施例还提供了一种文本分类模型,用于处理待分类文本,所述文本分类模型由利用上述第一方面实施例和/或结合第一方面实施例的任一种可能的实施方式提供的网络模型训练方法训练得到。
第三方面,本申请实施例还提供了一种网络模型,包括:第一分支网络、第二分支网络以及合并层;所述第一分支网络和所述第二分支网络均包括基于多尺度注意力机制模块的编码层;合并层,用于将所述第一分支网络和所述第二分支网络各自输出的特征向量进行相加,并作为模型的最终预测值输出。
结合第三方面实施例的一种可能的实施方式,所述编码层中的多尺度注意力机制模块中不同信息头head对应的超参数权重不同。
本申请实施例中,通过将编码层中的多尺度注意力机制模块中不同信息头head对应的超参数权重不同,从而使模型更倾向有重要信息的head,提高模型提取少见类样本的语义信息的能力。
第四方面,本申请实施例还提供了一种网络模型训练装置,包括:获取模块以及训练模块;获取模块,用于获取训练样本集,所述训练样本集包括常见类样本和少见类样本;训练模块,用于利用所述训练样本集对双边分支网络模型进行训练,得到训练好的文本分类模型,其中,所述双边分支网络模型的两个分支网络均包括基于多尺度注意力机制模块的编码层。
第五方面,本申请实施例还提供了一种电子设备,包括:存储器和处理器,所述处理器与所述存储器连接;所述存储器,用于存储程序;所述处理器,用于调用存储于所述存储器中的程序,以执行上述第一方面实施例和/或结合第一方面实施例的任一种可能的实施方式提供的方法。
第六方面,本申请实施例还提供了一种存储介质,其上存储有计算机程序,所述计算机程序被处理器运行时,执行上述第一方面实施例和/或结合第一方面实施例的任一种可能的实施方式提供的方法。
本申请的其他特征和优点将在随后的说明书阐述,并且,部分地从说明书中变得显而易见,或者通过实施本申请实施例而了解。本申请的目的和其他优点可通过在所写的说明书以及附图中所特别指出的结构来实现和获得。
附图说明
为了更清楚地说明本申请实施例或现有技术中的技术方案,下面将对实施例中所需要使用的附图作简单地介绍,显而易见地,下面描述中的附图仅仅是本申请的一些实施例,对于本领域普通技术人员来讲,在不付出创造性劳动的前提下,还可以根据这些附图获得其他的附图。通过附图所示,本申请的上述及其它目的、特征和优势将更加清晰。在全部附图中相同的附图标记指示相同的部分。并未刻意按实际尺寸等比例缩放绘制附图,重点在于示出本申请的主旨。
图1示出了本申请实施例提供的一种双边分支网络模型的结构示意图。
图2为现有的基于注意力机制模块的编码层的原理示意图。
图3a示出了本申请实施例提供的一种基于多尺度注意力机制模块的编码层的原理示意图。
图3b示出了本申请实施例提供的又一种基于多尺度注意力机制模块的编码层的原理示意图
图3c示出了本申请实施例提供的一种基于自适应权重的多尺度注意力机制模块的编码层的原理示意图。
图4示出了本申请实施例提供的一种网络模型训练方法的流程示意图。
图5示出了本申请实施例提供的一种文本分类方法的流程示意图。
图6示出了本申请实施例提供的一种网络模型训练装置的结构框图。
图7示出了本申请实施例提供的一种电子设备的结构示意图。
具体实施方式
下面将结合本申请实施例中的附图,对本申请实施例中的技术方案进行描述。
应注意到:相似的标号和字母在下面的附图中表示类似项,因此,一旦某一项在一个附图中被定义,则在随后的附图中不需要对其进行进一步定义和解释。同时,在本申请的描述中诸如“第一”、“第二”等之类的关系术语仅仅用来将一个实体或者操作与另一个实体或操作区分开来,而不一定要求或者暗示这些实体或操作之间存在任何这种实际的关系或者顺序。而且,术语“包括”、“包含”或者其任何其他变体意在涵盖非排他性的包含,从而使得包括一系列要素的过程、方法、物品或者设备不仅包括那些要素,而且还包括没有明确列出的其他要素,或者是还包括为这种过程、方法、物品或者设备所固有的要素。在没有更多限制的情况下,由语句“包括一个……”限定的要素,并不排除在包括所述要素的过程、方法、物品或者设备中还存在另外的相同要素。
再者,本申请中术语“和/或”,仅仅是一种描述关联对象的关联关系,表示可以存在三种关系,例如,A和/或B,可以表示:单独存在A,同时存在A和B,单独存在B这三种情况。
鉴于现有的文本分类方法的准确性不高,对少见类样本缺乏足够的拟合能力的问题,本申请实施例提供了一种基于多尺度注意力机制的语义信息的提取和双边分支网络(Bilateral-Branch Network,BBN)模型的文本分类方法,能够有效提取语义信息加速模型收敛和提高模型的泛化能力,还能抑制长尾效应,提高模型对少见类样本的拟合能力。
下面将对本申请涉及的包含多尺度注意力机制模块的双边分支网络模型进行说明。如图1所示,该双边分支网络模型包括:第一分支网络、第二分支网络和合并层。其中,第一分支网络和第二分支网络均包括基于多尺度注意力机制模块的编码层。
其中,第一分支网络包括第一主干网络(backbone)、第一编码层(encoder)和第一特征提取层。训练样本X1输入第一主干网络后再经过第一编码层,获得特征向量Fc,特征向量Fc再经过第一特征提取层乘以自适应权重因子W1后,得到特征向量Wc。其中,第一主干网络由多层编码层组成,例如由11层编码层串接而成。第一主干网络和第一编码层均为基于多尺度注意力机制模块的编码层。
第二分支网络包括第二主干网络(backbone)、第二编码层(encoder)和第二特征提取层。训练样本X2输入第二主干网络后再经过第二编码层,获得特征向量Fr,特征向量Fr再经过第二特征提取层乘以自适应权重因子W2后,得到特征向量Wr。其中,第二主干网络由多层编码层组成,例如由11层编码层串接而成。第二主干网络和第二编码层均为基于多尺度注意力机制模块的编码层。
其中,第一主干网络和第二主干网络共享权重值,第一编码层和第二编码层不共享权重值。第一主干网络和第二主干网络用于从输入样本中提取基础特征,第一编码层和第二编码层用于提取语义特征。本申请实施例中采用基于多尺度注意力机制模块的编码层来搭建双边分支网络模型,能够有效提取语义信息加速模型收敛和提高模型的泛化能力。
现有的双边分支网络模型中的编码层中的注意力机制模块(self-Attention)是将上一层的信息头输出的信息进行累加,再输入到下一层中,如图2所示,j层的信息头累加融合到j+1层(其中,表示第j层第0个信息头),该机制只能学习单个信息头之间的关系。所以,self-Attention机制为了获取丰富的语义,需要大量的语料进行长时间的训练,才能使模型收敛。
本申请实施例中,构建基于多尺度注意力机制模块(Scale-Aware Self-Attention)的编码层,通过将多个信息头的信息进行融合,来获取丰富的语义信息。如图3a所示,head1的信息来自信息头head2的信息来自信息头/>head3的信息来自信息头然后再将head1、head2、head3的信息传给下一层的信息头/>下一层编码层的其他信息头的获取和/>类似,例如/>的head1的信息来自信息头/>head2的信息来自信息头/>head3的信息来自信息头/> 如图3b所示。
考虑到由于文本的重要信息的分布相对随机,因此本发明给各个head设置不同的超参数权重,让模型去学习该权重值,如图3c所示。这样模型就能根据输入文本学习得到不同head的权重,从而更倾向有重要信息的head。也即编码层中的多尺度注意力机制模块中不同信息头head对应的超参数权重不同,此时该多尺度注意力机制模块为自适应权重的多尺度注意力机制模块(Weighted Scale-Aware Self-Attention)。图3c中可以看出head1的超参数权重为H1,head2的超参数权重为H2,head3的超参数权重为H3,其中,H1、H2、H3的值不同。
合并层,用于将第一分支网络和第二分支网络各自输出的特征向量进行相加,并作为模型的最终预测值输出。也即将第一分支网络输出的特征向量Wc和第二分支网络输出的特征向量Wr进行相加,并作为模型的最终预测值(Loss)输出。
为了解决长尾效应(某些类别的数据会比较少,而其他类别的数据比较多),造成类别质检的数量失衡,模型主要学习到常见类类别,对少见类类别欠拟合的问题。本申请实施例中,在模型训练过程中通过梯度平均模长和训练轮次动态调整两个分支对应的自适应权重因子,从而动态调整两个分支的特征融合,实现特征的均衡化。为了便于理解,下面将结合图4对本申请实施例提供的网络模型训练方法进行说明。该模型训练方法包括:
步骤S101:获取训练样本集,训练样本集包括常见类样本和少见类样本。
其中,获取训练样本集的过程包括:获取包含至少两种类别的样本,且至少两种类别的样本中常见类样本和少见类样本的数量比为M:1,M为大于等于20的正整数。例如,该训练样本集中样本的数量有10万条(并不限于此),其中少见类样本类别和常见类样本类别之间的数量比为1:20。
步骤S102:利用所述训练样本集对双边分支网络模型进行训练,得到训练好的文本分类模型。
在获取到训练样本集后,利用获取的训练样本集对双边分支网络模型进行训练,便可得到训练好的文本分类模型。其中,该双边分支网络模型即为上述所述的双边分支网络模型,此处不再对模型结构进行说明。
其中,利用训练样本集对双边分支网络模型进行训练的过程包括:每次迭代训练时,对训练样本集中的样本进行随机采样,以及对训练样本集中的不同类别样本进行权重采样,在进行权重采样时,类别数量少的样本的采样频率大于类别数量多的样本的采样频率(也即类别数量越少的样本,采样的频率越高),将随机采样的N个样本输入双边分支网络模型中的第一分支网络中,将权重采样的N个样本输入所述双边分支网络模型中的第二分支网络中,对双边分支网络模型进行迭代训练,其中,N为正整数,且小于训练样本集中的样本数。结合上述的模型结构进行说明,也即在对第一分支网络进行训练的训练样本X1为对训练样本集中的样本进行随机采样获得的N个样本,而对第二分支网络进行训练的训练样本X2为对训练样本集中的不同类别样本进行权重采样而获得的N个样本。且在进行权重采样时,类别数量越少的样本,采样的频率越高,这样能解决模型对少见类样本缺乏足够的拟合能力的问题。
为了防止过度采样少见类样本数据造成模型对少见类样本数据过拟合,同时也能抑制常见类别数据对模型的影响过大的问题,本申请实施例中,训练过程中,通过梯度平均模长和训练轮次动态调整两个分支网络对应的自适应权重因子,从而动态调整两个分支网络的特征融合。
其中,双边分支网络模型中的第一分支网络和第二分支网络的自适应权重因子分别为W1和W2。其中,W1=a1*g1,W2=a2*g2,a1=1-T/2*T max,a2=T/2*T max,T表示当前训练轮次,Tmax表示训练最大轮次(也即迭代总次数),K为最大样本类别,/>和fci分别表示当前训练轮次输入第一分支网络的样本中的第i类别样本的真实值和预测值,/>和fri分别表示当前训练轮次输入第二分支网络的样本中的第i类别样本的真实值和预测值。为了便于理解,举例进行说明,假设训练最大轮次为20,则Tmax=20,若当前为第10次迭代,则当前训练轮次T=10,K为最大样本类别,假设共有10类样本类别,则K=10。其中,上式中的g1和g2分别用于求取第一分支网络和第二分支网络的梯度平均模长的。
通过上述公式可以看出,在刚开始训练的时候,T比较小,a1初始值接近1,等到训练后期接近0.5。在训练初期第一分支网络占的权重比较大,主要训练第一分支网络,以提高模型的基础特征提取能力。等到了后期第一分支网络的a1值和第二分支网络的a2值相近,此时,主要靠每个支路上的梯度平均模长(g1、g2)动态调节对应的权重。例如,当第一分支网络上的支路梯度平均模长比较大,则占的权重就大,能够有效防止过度采样少见类样本数据造成模型对少见类样本数据过拟合,同时也能抑制常见类别数据对模型的影响过大的问题。
本申请通过基于自适应权重的多尺度注意力机制模块(Weighted Scale-AwareSelf-Attention)来构建多注意力机制编码层;然后用该编码层搭建双边分支网络模型,并将通过随机采样的训练样本和通过权重采样的训练样本分别输入模型两个分支进行训练,训练过程中通过梯度平均模长和训练轮次动态调整两个分支网络对应的自适应权重因子,从而动态调整两个分支网络的特征融合,实现特征的均衡化;最后便可利用训练好的模型进行文本分类。
为了便于说明本申请改进后的模型的性能提升,下面将普通注意力机制(Self-Attention)、多尺度注意力机制(Scale-Aware Self-Attention)、自适应权重的多尺度注意力机制(Weighted Scale-Aware Self-Attention)以及本申请的包含自适应权重的多尺度注意力机制的双边分支网络模型(Weighted Scale-Aware Self-Attention+BBN)的模型收敛次数以及准确率进行对比。通过获取包含至少两种类别的样本,且至少两种类别的样本中常见类样本和少见类样本的数量比为20:1的训练样本集(包括训练集和测试集),分别对普通注意力机制(Self-Attention)、多尺度注意力机制(Scale-Aware Self-Attention)、自适应权重的多尺度注意力机制(Weighted Scale-Aware Self-Attention)以及本申请的包含自适应权重的多尺度注意力机制的双边分支网络模型(WeightedScale-Aware Self-Attention+BBN)进行训练和测试,其训练和测试实验结果如表1可知。
表1
通过表1可知,基于自适应权重的多尺度注意力机制的双边分支网络模型能够有效的提高模型的语义提取能力,而且加速模型的收敛和提高分类的准确性。
基于同样的发明构思,本申请实施例还提供了一种文本分类模型,用于处理待分类文本,该用于处理待分类文本由图4所述的网络模型训练方法训练得到。
基于同样的发明构思,下面将结合图5,对本申请实施例提供的文本分类方法进行说明。
步骤S201:获取待分类文本。
步骤S202:利用上述的网络模型训练方法训练的文本分类模型对所述待分类文本进行处理,得到分类结果。
当需要对待分类文本进行分类时,利用上述的网络模型训练方法训练的文本分类模型(包含多尺度注意力机制模块的双边分支网络模型)对待分类文本进行处理,便可得到分类结果。
其中,在进行文本预测时,使用双边分支网络模型中的第二分支网络对待分类文本进行预测和分类。
其中,训练文本分类模型的方法,以及文本分类模型的具体结果请参照前述相同部分即可。
基于同样的发明构思,本申请实施例还提供了一种网络模型训练装置100,如图6所示。该网络模型训练装置100包括:获取模块110、训练模块120。
获取模块110,用于获取训练样本集,所述训练样本集包括常见类样本和少见类样本。
训练模块120,用于利用所述训练样本集对双边分支网络模型进行训练,得到训练好的文本分类模型,其中,所述双边分支网络模型的两个分支网络均包括基于多尺度注意力机制模块的编码层。
在本申请实施例中,训练模块120,具体用于:每次迭代训练时,对所述训练样本集中的样本进行随机采样,以及对所述训练样本集中的不同类别样本进行权重采样,在进行权重采样时,类别数量越少的样本,采样的频率越高;将随机采样的N个样本输入所述双边分支网络模型中的第一分支网络中,将权重采样的N个样本输入所述双边分支网络模型中的第二分支网络中,对所述双边分支网络模型进行迭代训练,其中,N为正整数,且小于所述训练样本集中的样本数。
可选地,在本申请实施例中,训练过程中,通过梯度平均模长和训练轮次动态调整两个分支网络对应的自适应权重因子,从而动态调整两个分支网络的特征融合。
可选地,在本申请实施例中,所述双边分支网络模型中的第一分支网络和第二分支网络的自适应权重因子分别为W1和W2;其中,W1=a1*g1,W2=a2*g2,a1=1-T/2*Tmax,a2=T/2*Tmax, T表示当前训练轮次,Tmax表示训练最大轮次,K为最大样本类别,/>和fci分别表示当前训练轮次输入所述第一分支网络的样本中的第i类别样本的真实值和预测值,/>和fri分别表示当前训练轮次输入所述第二分支网络的样本中的第i类别样本的真实值和预测值。
本申请实施例所提供的网络模型训练装置100,其实现原理及产生的技术效果和前述方法实施例相同,为简要描述,装置实施例部分未提及之处,可参考前述方法实施例中相应内容。
基于同样的发明构思,如图7所示,图7示出了本申请实施例提供的一种电子设备200的结构框图。所述电子设备200包括:收发器210、存储器220、通讯总线230以及处理器240。
所述收发器210、所述存储器220、处理器240各元件相互之间直接或间接地电性连接,以实现数据的传输或交互。例如,这些元件相互之间可通过一条或多条通讯总线230或信号线实现电性连接。其中,收发器210用于收发数据。存储器220用于存储计算机程序,如存储有图6中所示的软件功能模块,即网络模型训练装置100。其中,网络模型训练装置100包括至少一个可以软件或固件(firmware)的形式存储于所述存储器220中或固化在所述电子设备200的操作系统(operating system,OS)中的软件功能模块。所述处理器240,用于执行存储器220中存储的可执行模块,例如网络模型训练装置100包括的软件功能模块或计算机程序。例如,处理器240,用于获取训练样本集,所述训练样本集包括常见类样本和少见类样本;利用所述训练样本集对双边分支网络模型进行训练,得到训练好的文本分类模型,其中,所述双边分支网络模型的两个分支网络均包括基于多尺度注意力机制模块的编码层。
其中,存储器220可以是,但不限于,随机存取存储器(Random Access Memory,RAM),只读存储器(Read Only Memory,ROM),可编程只读存储器(Programmable Read-OnlyMemory,PROM),可擦除只读存储器(Erasable Programmable Read-Only Memory,EPROM),电可擦除只读存储器(Electric Erasable Programmable Read-Only Memory,EEPROM)等。
处理器240可能是一种集成电路芯片,具有信号的处理能力。上述的处理器可以是通用处理器,包括中央处理器(Central Processing Unit,CPU)、网络处理器(NetworkProcessor,NP)等;还可以是数字信号处理器(Digital Signal Processor,DSP)、专用集成电路(Application Specific Integrated Circuit,ASIC)、现场可编程门阵列(FieldProgrammable Gate Array,FPGA)或者其他可编程逻辑器件、分立门或者晶体管逻辑器件、分立硬件组件。可以实现或者执行本申请实施例中的公开的各方法、步骤及逻辑框图。通用处理器可以是微处理器或者该处理器240也可以是任何常规的处理器等。
其中,上述的电子设备200,包括但不限于计算机、服务器等。
本申请实施例还提供了一种非易失性计算机可读取存储介质(以下简称存储介质),该存储介质上存储有计算机程序,该计算机程序被计算机如上述的电子设备200运行时,执行上述所示的网络模型训练方法,或者上述的文本分类方法。
需要说明的是,本说明书中的各个实施例均采用递进的方式描述,每个实施例重点说明的都是与其他实施例的不同之处,各个实施例之间相同相似的部分互相参见即可。
在本申请所提供的几个实施例中,应该理解到,所揭露的装置和方法,也可以通过其它的方式实现。以上所描述的装置实施例仅仅是示意性的,例如,附图中的流程图和框图显示了根据本申请的多个实施例的装置、方法和计算机程序产品的可能实现的体系架构、功能和操作。在这点上,流程图或框图中的每个方框可以代表一个模块、程序段或代码的一部分,所述模块、程序段或代码的一部分包含一个或多个用于实现规定的逻辑功能的可执行指令。也应当注意,在有些作为替换的实现方式中,方框中所标注的功能也可以以不同于附图中所标注的顺序发生。例如,两个连续的方框实际上可以基本并行地执行,它们有时也可以按相反的顺序执行,这依所涉及的功能而定。也要注意的是,框图和/或流程图中的每个方框、以及框图和/或流程图中的方框的组合,可以用执行规定的功能或动作的专用的基于硬件的系统来实现,或者可以用专用硬件与计算机指令的组合来实现。
另外,在本申请各个实施例中的各功能模块可以集成在一起形成一个独立的部分,也可以是各个模块单独存在,也可以两个或两个以上模块集成形成一个独立的部分。
所述功能如果以软件功能模块的形式实现并作为独立的产品销售或使用时,可以存储在一个计算机可读取存储介质中。基于这样的理解,本申请的技术方案本质上或者说对现有技术做出贡献的部分或者该技术方案的部分可以以软件产品的形式体现出来,该计算机软件产品存储在一个存储介质中,包括若干指令用以使得一台计算机设备(可以是个人计算机,笔记本电脑,服务器,或者电子设备等)执行本申请各个实施例所述方法的全部或部分步骤。而前述的存储介质包括:U盘、移动硬盘、只读存储器(Read-Only Memory,ROM)、随机存取存储器(Random Access Memory,RAM)、磁碟或者光盘等各种可以存储程序代码的介质。
以上所述,仅为本申请的具体实施方式,但本申请的保护范围并不局限于此,任何熟悉本技术领域的技术人员在本申请揭露的技术范围内,可轻易想到变化或替换,都应涵盖在本申请的保护范围之内。因此,本申请的保护范围应所述以权利要求的保护范围为准。

Claims (8)

1.一种网络模型训练方法,其特征在于,包括:
获取训练样本集,所述训练样本集包括常见类样本和少见类样本;
利用所述训练样本集对双边分支网络模型进行训练,得到训练好的文本分类模型,其中,所述双边分支网络模型的两个分支网络均包括基于多尺度注意力机制模块的编码层;
其中,训练过程中,通过梯度平均模长和训练轮次动态调整两个分支网络对应的自适应权重因子以动态调整两个分支网络的特征融合;
所述双边分支网络模型中的第一分支网络和第二分支网络的自适应权重因子分别为W1和W2;
其中,W1=a1*g1,W2=a2*g2,a1=1-T/2*Tmax,a2=T/2*Tmax, T表示当前训练轮次,Tmax表示训练最大轮次,K为最大样本类别,/>和fci分别表示当前训练轮次输入所述第一分支网络的样本中的第i类别样本的真实值和预测值,/>和fri分别表示当前训练轮次输入所述第二分支网络的样本中的第i类别样本的真实值和预测值,g1和g2为梯度平均模长。
2.根据权利要求1所述的方法,其特征在于,利用所述训练样本集对双边分支网络模型进行训练,包括:
每次迭代训练时,对所述训练样本集中的样本进行随机采样,以及对所述训练样本集中的不同类别样本进行权重采样,在进行权重采样时,类别数量少的样本的采样频率大于类别数量多的样本的采样频率;
将随机采样的N个样本输入所述双边分支网络模型中的第一分支网络中,将权重采样的N个样本输入所述双边分支网络模型中的第二分支网络中,对所述双边分支网络模型进行迭代训练,其中,N为正整数,且小于所述训练样本集中的样本数。
3.一种文本分类模型,用于处理待分类文本,其特征在于:
所述文本分类模型由如权利要求1-2中任一项所述的网络模型训练方法训练得到。
4.一种网络模型,其特征在于,包括:
第一分支网络、第二分支网络,所述第一分支网络和所述第二分支网络均包括基于多尺度注意力机制模块的编码层;以及,合并层,用于将所述第一分支网络和所述第二分支网络各自输出的特征向量进行相加,并作为模型的最终预测值输出;
其中,训练过程中,通过梯度平均模长和训练轮次动态调整两个分支网络对应的自适应权重因子以动态调整两个分支网络的特征融合;
双边分支网络模型中的第一分支网络和第二分支网络的自适应权重因子分别为W1和W2;
其中,W1=a1*g1,W2=a2*g2,a1=1-T/2*Tmax,a2=T/2*Tmax, T表示当前训练轮次,Tmax表示训练最大轮次,K为最大样本类别,/>和fci分别表示当前训练轮次输入所述第一分支网络的样本中的第i类别样本的真实值和预测值,/>和fri分别表示当前训练轮次输入所述第二分支网络的样本中的第i类别样本的真实值和预测值,g1和g2为梯度平均模长。
5.根据权利要求4所述的模型,其特征在于,所述编码层中的多尺度注意力机制模块中不同信息头head对应的超参数权重不同。
6.一种网络模型训练装置,其特征在于,包括:
获取模块,用于获取训练样本集,所述训练样本集包括常见类样本和少见类样本;
训练模块,用于利用所述训练样本集对双边分支网络模型进行训练,得到训练好的文本分类模型,其中,所述双边分支网络模型的两个分支网络均包括基于多尺度注意力机制模块的编码层;还用于训练过程中,通过梯度平均模长和训练轮次动态调整两个分支网络对应的自适应权重因子以动态调整两个分支网络的特征融合;所述双边分支网络模型中的第一分支网络和第二分支网络的自适应权重因子分别为W1和W2;其中,W1=a1*g1,W2=a2*g2,a1=1-T/2*Tmax,a2=T/2*Tmax,T表示当前训练轮次,Tmax表示训练最大轮次,K为最大样本类别,/>和fci分别表示当前训练轮次输入所述第一分支网络的样本中的第i类别样本的真实值和预测值,/>和fri分别表示当前训练轮次输入所述第二分支网络的样本中的第i类别样本的真实值和预测值,g1和g2为梯度平均模长。
7.一种电子设备,其特征在于,包括:
存储器和处理器,所述处理器与所述存储器连接;
所述存储器,用于存储程序;
所述处理器,用于调用存储于所述存储器中的程序,以执行如权利要求1-2中任一项所述的方法。
8.一种存储介质,其特征在于,其上存储有计算机程序,所述计算机程序被处理器运行时,执行如权利要求1-2中任一项所述的方法。
CN202110402004.2A 2021-04-14 2021-04-14 网络模型训练方法、装置、文本分类模型及网络模型 Active CN113077051B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202110402004.2A CN113077051B (zh) 2021-04-14 2021-04-14 网络模型训练方法、装置、文本分类模型及网络模型

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202110402004.2A CN113077051B (zh) 2021-04-14 2021-04-14 网络模型训练方法、装置、文本分类模型及网络模型

Publications (2)

Publication Number Publication Date
CN113077051A CN113077051A (zh) 2021-07-06
CN113077051B true CN113077051B (zh) 2024-01-26

Family

ID=76618697

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202110402004.2A Active CN113077051B (zh) 2021-04-14 2021-04-14 网络模型训练方法、装置、文本分类模型及网络模型

Country Status (1)

Country Link
CN (1) CN113077051B (zh)

Families Citing this family (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN114240101A (zh) * 2021-12-02 2022-03-25 支付宝(杭州)信息技术有限公司 一种风险识别模型的验证方法、装置以及设备

Citations (7)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
DE102009009904A1 (de) * 2009-02-20 2009-10-15 Daimler Ag Verfahren zur Identifizierung von Objekten
CN109993220A (zh) * 2019-03-23 2019-07-09 西安电子科技大学 基于双路注意力融合神经网络的多源遥感图像分类方法
CN110472665A (zh) * 2019-07-17 2019-11-19 新华三大数据技术有限公司 模型训练方法、文本分类方法及相关装置
CN110489545A (zh) * 2019-07-09 2019-11-22 平安科技(深圳)有限公司 文本分类方法及装置、存储介质、计算机设备
CN111144448A (zh) * 2019-12-09 2020-05-12 江南大学 基于多尺度注意力卷积编码网络的视频弹幕情感分析方法
CN111966831A (zh) * 2020-08-18 2020-11-20 创新奇智(上海)科技有限公司 一种模型训练方法、文本分类方法、装置及网络模型
CN112115826A (zh) * 2020-09-08 2020-12-22 成都奥快科技有限公司 一种基于双边分支网络的人脸活体检测方法及系统

Family Cites Families (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US8611677B2 (en) * 2008-11-19 2013-12-17 Intellectual Ventures Fund 83 Llc Method for event-based semantic classification

Patent Citations (7)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
DE102009009904A1 (de) * 2009-02-20 2009-10-15 Daimler Ag Verfahren zur Identifizierung von Objekten
CN109993220A (zh) * 2019-03-23 2019-07-09 西安电子科技大学 基于双路注意力融合神经网络的多源遥感图像分类方法
CN110489545A (zh) * 2019-07-09 2019-11-22 平安科技(深圳)有限公司 文本分类方法及装置、存储介质、计算机设备
CN110472665A (zh) * 2019-07-17 2019-11-19 新华三大数据技术有限公司 模型训练方法、文本分类方法及相关装置
CN111144448A (zh) * 2019-12-09 2020-05-12 江南大学 基于多尺度注意力卷积编码网络的视频弹幕情感分析方法
CN111966831A (zh) * 2020-08-18 2020-11-20 创新奇智(上海)科技有限公司 一种模型训练方法、文本分类方法、装置及网络模型
CN112115826A (zh) * 2020-09-08 2020-12-22 成都奥快科技有限公司 一种基于双边分支网络的人脸活体检测方法及系统

Non-Patent Citations (1)

* Cited by examiner, † Cited by third party
Title
BBN:Bilateral-Branch Network With Cumulative Learning for Long-Tailed_Visual_Recognition;Boyan Zhou et al;《2020 IEEE/CVF Conference on Computer Vision and Pattern Recognition》;第9716-9725页 *

Also Published As

Publication number Publication date
CN113077051A (zh) 2021-07-06

Similar Documents

Publication Publication Date Title
US11270225B1 (en) Methods and apparatus for asynchronous and interactive machine learning using word embedding within text-based documents and multimodal documents
CN111753092B (zh) 一种数据处理方法、模型训练方法、装置及电子设备
CN111353303B (zh) 词向量构建方法、装置、电子设备及存储介质
CN110659367B (zh) 文本分类号的确定方法、装置以及电子设备
CN102929906B (zh) 基于内容特征和主题特征的文本分组聚类方法
CN112685539B (zh) 基于多任务融合的文本分类模型训练方法和装置
CN112749274A (zh) 基于注意力机制和干扰词删除的中文文本分类方法
CN112948676A (zh) 文本特征提取模型的训练方法、文本推荐方法及装置
CN112364947B (zh) 一种文本相似度计算方法和装置
CN110968692A (zh) 一种文本分类方法及系统
KR20230107558A (ko) 모델 트레이닝, 데이터 증강 방법, 장치, 전자 기기 및 저장 매체
CN113077051B (zh) 网络模型训练方法、装置、文本分类模型及网络模型
CN115146068A (zh) 关系三元组的抽取方法、装置、设备及存储介质
CN115374845A (zh) 商品信息推理方法和装置
CN116227467A (zh) 模型的训练方法、文本处理方法及装置
CN115392357A (zh) 分类模型训练、标注数据样本抽检方法、介质及电子设备
CN112989790B (zh) 基于深度学习的文献表征方法及装置、设备、存储介质
CN112685594B (zh) 基于注意力的弱监督语音检索方法及系统
CN113535912A (zh) 基于图卷积网络和注意力机制的文本关联方法及相关设备
CN112732863A (zh) 电子病历标准化切分方法
CN107748801A (zh) 新闻推荐方法、装置、终端设备及计算机可读存储介质
CN115271045A (zh) 一种基于机器学习的神经网络模型优化方法及系统
CN113590813B (zh) 文本分类方法、推荐方法、装置及电子设备
CN113704466B (zh) 基于迭代网络的文本多标签分类方法、装置及电子设备
CN117235271A (zh) 信息抽取方法、装置、计算机存储介质及电子设备

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
TA01 Transfer of patent application right

Effective date of registration: 20230322

Address after: 528311 no.l203 Country Garden International Club, Beijiao Town, Shunde District, Foshan City, Guangdong Province

Applicant after: Zero Hole Technology Co.,Ltd.

Address before: 528000 a2-05, 2nd floor, building A1, 1 Panpu Road, Biguiyuan community, Beijiao Town, Shunde District, Foshan City, Guangdong Province (for office use only) (residence declaration)

Applicant before: GUANGDONG BOZHILIN ROBOT Co.,Ltd.

TA01 Transfer of patent application right
TA01 Transfer of patent application right

Effective date of registration: 20231226

Address after: Room 01, Floor 9, Xinlihua Center Building, 151 Mount Taishan Road, Jianye District, Nanjing, Jiangsu 210004

Applicant after: Nanjing Lingdong Shuzhi Technology Co.,Ltd.

Address before: 528311 no.l203 Country Garden International Club, Beijiao Town, Shunde District, Foshan City, Guangdong Province

Applicant before: Zero Hole Technology Co.,Ltd.

TA01 Transfer of patent application right
GR01 Patent grant
GR01 Patent grant