CN113222035B - 基于强化学习和知识蒸馏的多类别不平衡故障分类方法 - Google Patents

基于强化学习和知识蒸馏的多类别不平衡故障分类方法 Download PDF

Info

Publication number
CN113222035B
CN113222035B CN202110549644.6A CN202110549644A CN113222035B CN 113222035 B CN113222035 B CN 113222035B CN 202110549644 A CN202110549644 A CN 202110549644A CN 113222035 B CN113222035 B CN 113222035B
Authority
CN
China
Prior art keywords
class
sample
cluster
samples
network
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
CN202110549644.6A
Other languages
English (en)
Other versions
CN113222035A (zh
Inventor
张新民
范赛特
魏驰航
宋执环
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Zhejiang University ZJU
Original Assignee
Zhejiang University ZJU
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Zhejiang University ZJU filed Critical Zhejiang University ZJU
Priority to CN202110549644.6A priority Critical patent/CN113222035B/zh
Publication of CN113222035A publication Critical patent/CN113222035A/zh
Application granted granted Critical
Publication of CN113222035B publication Critical patent/CN113222035B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/22Matching criteria, e.g. proximity measures
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/23Clustering techniques
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • G06F18/241Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • G06F18/241Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
    • G06F18/2415Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches based on parametric or probabilistic models, e.g. based on likelihood ratio or false acceptance rate versus a false rejection rate
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods

Abstract

本发明公开了一种基于强化学习和知识蒸馏的多类别不平衡故障分类方法,该方法结合层次聚类、知识蒸馏和强化学习等算法,用来解决多类别不平衡故障分类问题。对于多类别故障分类问题,首先针对不平衡问题中同质类别样本之间存在相似性、异质类样本之间存在较大差异的特点使用层次聚类将多类别聚类为几个簇类,根据不同簇类分别建立学生网络进行细粒度化分类,再用知识蒸馏方法兼顾全局信息,最后结合强化学习迭代学习样本权重,从而提高不平衡故障分类效果。在此过程中,需要设计合理的奖励函数配合细粒度知识蒸馏分类器去优化样本权重。相比其他对比方法,本发明的方法有良好的效果和适用性。

Description

基于强化学习和知识蒸馏的多类别不平衡故障分类方法
技术领域
本发明属于工业过程监测领域,尤其涉及一种基于强化学习和知识蒸馏的多类别不平衡故障分类方法。
背景技术
在机器学习或深度学习分类中,类别样本数量不平衡是一个非常普遍的问题,广泛存在于各个领域,例如生物信息学,智能电网,医学成像,故障诊断。大多数现有的分类方法都基于以下假设:观测数据的基本分布是相对均衡的。但是,实际工业数据集通常会违反此假设,并呈现出偏斜的分布甚至是极度不平衡的类别样本数量分布。例如,数据驱动的故障分类是工业过程监测的重要组成部分,由于故障发生的频率不同,它们表现出不平衡的偏斜分布。在这种情况下,如果假定所有类别都具有同等的重要性,则分类器会倾向于分对频繁(多数)类别的样本而不是不频繁(少数)类别的样本。因此,迫切需要提出恰当的方法来消除不平衡的类别分布的负面影响,而又不过度牺牲任何多数类别或少数类别的准确性。
发明内容
本发明的目的在于提供一种基于强化学习和知识蒸馏的多类别不平衡故障分类方法,其能对多数类不平衡的分类问题,获得较好的故障分类结果。具体技术方案如下:
一种基于强化学习和知识蒸馏的多类别不平衡故障分类方法,包括以下步骤:
S1:离线建模
S1.1:收集K个类别的历史离线工业过程数据样本,其中包含故障数据和正常数据;
S1.2:计算每个类别特征中心点
Figure BDA0003074902770000011
S1.3:通过基于Ward-Linkage的层次聚类,将同质类的类别特征中心分配在一个簇类中,最终将所有类别特征中心uk分配到C个簇类中;然后根据类别特征中心的聚类结果分配每个类别的所有样本到对应簇类中;
S1.4:使用高斯伯努利限制玻尔兹曼机,分别基于所有样本以及每个簇类中样本进行训练,其中,所有样本训练得到的高斯伯努利限制玻尔兹曼机参数为教师网络的预训练参数;基于每个簇类中样本训练得到的高斯伯努利限制玻尔兹曼机参数为对应的学生网络的预训练参数;
S1.5:基于所述的教师网络的预训练参数,采用所有样本,通过微调技术,训练多类别不平衡的教师网络,得到的logit作为所有学生网络的软目标;
S1.6:训练完教师网络之后,所有学生网络都通过综合交叉熵损失一起训练;根据包含所述软目标和硬目标的综合损失,采用每个簇类中样本,通过微调技术进行训练,将所有学生网络得到的logit拼接在一起,组成学生网络的综合logit;各个学生网络拼接的每个logit中值的位置对应于原先类别顺序;所述硬目标为样本的真实标签;
S1.7:使用强化学习结合知识蒸馏的输出来学习样本权重,并结合学习后的样本权重、教师网络和各个学生网络的输出构建损失函数;
S1.8:重复S1.5~S1.7,进行强化学习模型和知识蒸馏模型迭代训练,直到模型收敛;
S2:在线应用测试
S2.1:获取在线样本;
S2.2:将在线样本分类到S1.3层次聚类得到的C个簇类的其中一个簇类中;
S2.3:基于S1.8训练得到的知识蒸馏模型中的教师网络和各个学生网络,计算在线样本经过所在的簇类对应的学生网络得到的logit,和通过强化学习模型得到样本权重wt,并用加权的softmax函数计算属于各个类别的概率,选取概率最大的类别作为在线样本的类别。
进一步地,所述S1.2中的特征中心点计算具体为:
Figure BDA0003074902770000021
其中,uk为类别k的特征中心点,xi表示第i个样本,gk表示类别k的所有样本的集合,|gk|表示类别k的样本数量。
进一步地,所述S1.3具体为:
基于Ward-Linkage进行层次聚类,直到最后所有样本都聚成一个簇类。主要有以下步骤:
①在初始化过程中,将每个样本独立的归为一个簇类中;计算每两个簇类中心之间的相似度;
②找到两个最近的簇类,并将它们归为一个簇类,因此簇类总数减少1个;
③重新计算新生成簇类的中心与每个旧簇类中心之间的相似度;所述簇类的中心为一个簇类的所有样本的平均值;
④重复②和③,直到所有样本归为一个簇类,聚类算法结束;
⑤选择所需的最终聚类后的簇类数,作为最终的簇类数,即C的值。
进一步地,所述S1.4中的高斯伯努利限制玻尔兹曼机具有两层全连接的结构,分为可见单元
Figure BDA0003074902770000022
和隐藏单元
Figure BDA0003074902770000023
p和d分别为可见单元和隐藏单元的数量;联合配置v,h的能量函数表示为:
Figure BDA0003074902770000031
其中vi∈{0,1},hj∈{0,1};θ={W,a,b}是高斯伯努利限制玻尔兹曼机的结构参数;wij是连接可见单元i和隐藏单元j的对称权重;ai和bj分别是可见偏差和隐藏偏差;σi是可见单元i的高斯噪声的标准差;
所述高斯伯努利限制玻尔兹曼机的目标函数为:
Figure BDA0003074902770000032
其中,xi为第i维的输入数据,p(xi,h|θ)为xi和h的联合概率密度函数;
通过随机梯度上升方法最大化以找到最佳θ,完成对所述高斯伯努利限制玻尔兹曼机的训练:
Figure BDA0003074902770000033
其中,θ中的w和b用作知识蒸馏神经网络第一层的初始参数。
进一步地,所述S1.5通过梯度下降法训练教师网络,其中,教师网络的交叉熵损失函数如下:
Figure BDA0003074902770000034
其中
Figure BDA0003074902770000035
Figure BDA0003074902770000036
是教师网络的输入样本xi的输出logit。
进一步地,所述S1.6通过梯度下降法训练学生网络,其中,学生网络的交叉熵损失函数如下:
Figure BDA0003074902770000037
其中
Figure BDA0003074902770000038
Figure BDA0003074902770000039
是学生网络的输入样本xi的输出logit。
进一步地,所述S1.7具体为:
设定πθ为一种带参数θ的参数化随机平稳策略。策略中包含动作。这个平稳策略用来把每个状态st映射到动作概率分布at。rt为第t次迭代所获得的奖励。强化学习迭代用于学习样本权重的过程如下所示:
(1)初始化样本权重(动作):
Figure BDA00030749027700000310
a0=w0=[w0,0,…,w0,b]
其中w0,i是样本i在第一次迭代时的样本权重。
Figure BDA0003074902770000041
是样本i所在类别k的不平衡率。Nmax是样本最多的类别的样本数量。Nk是类别k的样本数量。a0为πθ的初始化动作。样本权重的更新过程可以形式化成一个序列决策问题。通过迭代t=1,2,…,T次(T为最大迭代次数),wt,i为样本i在第t次迭代时的样本权重。
(2)计算老师-学生网络的加权交叉熵损失:
Figure BDA0003074902770000042
Figure BDA0003074902770000043
其中
Figure BDA0003074902770000044
Figure BDA0003074902770000045
分别是在第t次迭代时教师网络和学生网络的损失。
Figure BDA0003074902770000046
Figure BDA0003074902770000047
Figure BDA0003074902770000048
Figure BDA0003074902770000049
Figure BDA00030749027700000410
分别表示类别k的样本i经过教师网络和学生网络的输出logits。
Figure BDA00030749027700000411
C是簇类数量。K是类别数量。b是批次大小。
(3)计算奖励rt。设计的奖励如下:
rt(st,at)=F1(t),
其中F1(t)代表学生网络在第t次迭代的F1得分值。状态、动作(每次迭代的样本权重)和相应的奖励都存储在经验回放中。
(4)获取状态st。状态st是三部分的拼接。第一部分是样本部分xt。第二部分是教师网络的
Figure BDA00030749027700000412
第三部分是学生网络的
Figure BDA00030749027700000413
状态st表示如下:
Figure BDA00030749027700000414
(5)更新策略πθ。样本权重wt是策略πθ的动作。策略πθ采用了策略梯度损失函数。损失函数定义为:
Figure BDA00030749027700000415
Figure BDA00030749027700000416
用梯度下降进行求解。样本权重wt不断更新直到奖励收敛。
进一步地,将在线样本分到对应簇类中,所述S2.2具体为:
在线样本分类到对应的簇类中,其公式如下:
Figure BDA0003074902770000051
其中c为在线样本的簇类类,
Figure BDA0003074902770000052
为簇类c的特征中心,xonline为在线样本。
进一步地,所述S2.3具体为:
用强化学习学习到的πθ得到在线样本的权重:
wonline=πθ(xonline)
在线样本经过学生网络得到的输出为:
logit=wonlineft(xonline)
其中,ft(·)表示学生网络;对输出进行softmax得到每个类别的概率,再取最大概率所对应的类别为分类类别:
Figure BDA0003074902770000053
本发明的有益效果如下:
本发明对于多类别的不平衡故障分类问题具有独特的效果,由于同质类别样本之间存在相似性、异质类样本之间存在较大差异的特点,使得本发明在通过聚类方法得到的簇类的基础上,更加细粒度的通过多个学生网络来解决不平衡的故障分类问题。同时通过教师网络的引导使得各个学生网络不仅能学习到簇类中同质类别的决策边界,也能学习到总体的数据分布信息。不仅如此,进一步结合强化学习,不断结合识蒸馏网络进行迭代,结合样本类别数量与样本在分布中的作用,获取样本权重,增加少数类样本的权重,减少多数类样本的权重,使得故障分类效果更好,准确率更高。
附图说明
图1为本发明方法采用的基础方法的结构图;
图2为本发明方法的结构图;
图3为使用的数据集生成的工艺流程图;
图4为使用的数据集样本数量分布示意图;
图5为本发明方法训练奖励和测试G-mean的曲线图;
图6为通过层次聚类得到的树状图;
图7为所有对比方法10次运行后绘制的箱线图;
图8为分类最后一层隐层的数据通过t-SNE降维后的2D映射图。(a)为MLP最后一层隐层输出的2D映射图;(b)为SMOTE-MLP最后一层隐层输出的2D映射图;(c)为Cosen-MLP最后一层隐层输出的2D映射图;(d)为CSDBN-DE最后一层隐层输出的2D映射图;(e)为TU-MLP最后一层隐层输出的2D映射图;(f)为KD最后一层隐层输出的2D映射图;(g)为本发明最后一层隐层输出的2D映射图。
具体实施方式
下面根据附图和优选实施例详细描述本发明,本发明的目的和效果将变得更加明白,应当理解,此处所描述的具体实施例仅仅用以解释本发明,并不用于限定本发明。
针对多类别的不平衡分布问题,本发明提出了一种新的基于强化学习和知识蒸馏的多类别不平衡故障分类方法。
本发明针对多类别的不平衡分布下的故障分类问题,划定离线和在线数据集,首先使用知识蒸馏方法进行分类或识别故障的类别。再针对不平衡问题中同质类别样本之间存在相似性、异质类样本之间存在较大差异的特点,采用层次聚类方法,根据类别中心点的聚类结果,将所有类别样本进行聚类,从而获得细粒度簇类。最后针对每个簇类进行细粒度故障分类。因此,对于某个簇类中,都将建立一个学生模型,最后进行拼接,进行多学生模型一起优化。在教师模型的全局信息的指导下,并结合多学生模型细粒度的进行故障分类。不仅如此,进一步结合强化学习,不断结合识蒸馏网络进行迭代,结合样本类别数量与样本在分布中的作用,获取样本权重,增加少数类样本的权重,减少多数类样本的权重,使得故障分类效果更好,相比其他现存方法,本发明的方法有良好的效果和适用性。
如图1和2所示,本发明的基于强化学习和知识蒸馏的多类别不平衡故障分类方法,包括以下步骤:
S1:离线建模
S1.1:收集K个类别的历史离线工业过程数据样本,其中包含故障数据样本和正常数据样本;
S1.2:计算每个类别特征中心点
Figure BDA0003074902770000061
Figure BDA0003074902770000062
其中,gk表示类别k的所有样本的集合,|gk|表示类别k的样本数量。
S1.3:通过基于Ward-Linkage的层次聚类,将所有类别特征中心uk分配到C个簇类中。同质类的类别特征中心将被分配在一个簇类中。根据类别特征中心的聚类结果分配每个类别的所有样本到对应簇类中。层次聚类使用以逐次聚合的方式(AgglomerativeClustering),将样本分类,直到最后所有样本都聚成一个簇类。主要有以下步骤:
①在初始化过程中,将每个样本独立的归为一个簇类中。计算每两个簇类中心之间的距离(也称为相似度);
②找到两个最近的簇类,并将它们归为一个簇类,因此簇类总数减少1个;
③重新计算新生成簇类的中心与每个旧簇类中心之间的相似度(一个簇类的所有样本的平均值代表该簇类的中心);
④重复②和③,直到所有样本归为一个簇类,聚类算法结束;
⑤选择所需的最终聚类后的簇类数,作为最终簇类数,即C的值。
整个聚类过程实际上是在构建一棵树。在构建过程中,第②步将设置一个阈值。当两个最近的簇类中心之间的距离大于此阈值时,则认为迭代已终止。另一个关键步骤是第三步,有很多方法可以确定两个聚类之间的相似性。常用的相似性度量包括Ward Linkage,Single Linkage、Complete Linkage和Average Linkage策略。在发明中,由于WardLinkage策略通常提供较高的聚类性能,因此采用Ward Linkage策略。Ward Linkage由两个聚类之间的平方误差和ESS计算得出,其目标函数是每次合并后ESS的最小增量,ESS定义如下:
Figure BDA0003074902770000071
S1.4:使用高斯伯努利限制玻尔兹曼机,分别基于所有样本以及每个聚类中样本进行训练。其中,所有样本训练得到的高斯伯努利限制玻尔兹曼机参数为教师网络的预训练参数;基于每个簇类中样本训练得到的高斯伯努利限制玻尔兹曼机参数为对应的学生网络的预训练参数。
高斯伯努利限制玻尔兹曼机具有两层全连接的结构,分为可见单元(或数据变量)
Figure BDA0003074902770000072
和隐藏单元(或潜在变量)
Figure BDA0003074902770000073
p和d分别为可见单元和隐藏单元的数量。高斯伯努利限制玻尔兹曼机既是生成模型,也是基于能量的模型。联合配置v,h的能量函数表示为:
Figure BDA0003074902770000074
其中vi∈{0,1},hj∈{0,1}。θ={W,a,b}是高斯伯努利限制玻尔兹曼机的结构参数。wij是连接可见单元i和隐藏单元j的对称权重;ai和bj分别是可见偏差和隐藏偏差。σi是可见单元i的高斯噪声的标准差。v和h的联合概率采用以下形式:
Figure BDA0003074902770000081
通常,将导致高(低)能量的配置(v,h)分别设置为低(高)概率计算的一部分。所有可见单元或隐藏单元都是有条件的独立单元。因此,高斯伯努利限制玻尔兹曼机的可见节点和隐藏节点的概率分布可以由下式给出:
Figure BDA0003074902770000082
Figure BDA0003074902770000083
其中σ(x)是逻辑斯蒂sigmoid函数
Figure BDA0003074902770000084
N(·|μ,σ2)是均值为μ,方差为σ2高斯概率密度函数。高斯伯努利限制玻尔兹曼机的优化目标是最大程度的适应数据分布。因此,目标函数是通过输入数据集
Figure BDA0003074902770000085
获得(d是输入数据的特征维度,m是类别数目),如下所示:
Figure BDA0003074902770000086
现有的大多数基于高斯伯努利限制玻尔兹曼机的模型都是通过对比差异(CD)学习策略来处理数据非线性的,该策略将实值数据映射到隐特征空间。对数似然估计值可通过随机梯度上升方法最大化以找到最佳θ:
Figure BDA0003074902770000087
通过迭代获得高斯伯努利限制玻尔兹曼机的最优参数θ。θ中的w和b用作知识蒸馏神经网络第一层的初始参数。
S1.5:基于所述的教师网络的预训练参数,采用所有样本,通过微调技术,通过梯度下降法训练多类别不平衡的教师网络,得到的logit作为所有学生网络的软目标。计算教师网络ft的交叉熵损失函数如下:
Figure BDA0003074902770000088
其中
Figure BDA0003074902770000089
Figure BDA00030749027700000810
是教师网络的输入样本xi的输出logit。
S1.6:训练完教师网络之后,所有的学生网络都通过综合交叉熵损失一起训练。根据包含了软目标(教师网络的logit)和硬目标(真实标签)的综合损失,采用每个簇类中样本,通过微调技术,通过梯度下降法进行训练所有学生网络。学生网络的综合logit由所有学生网络的logit拼接在一起。各个学生网络拼接的每个logit中值的位置对应于原先类别顺序。学生网络ft的综合损失,含了软目标(教师网络的logit)和硬目标(真实标签)的综合损失,定义如下:
Figure BDA0003074902770000091
其中
Figure BDA0003074902770000092
Figure BDA0003074902770000093
是学生网络的输入样本xi的输出logit。
S1.7:使用强化学习结合知识蒸馏的输出来学习样本权重,并结合学习后的样本权重、教师网络和各个学生网络的输出构建损失函数;
S1.8:重复S1.5~S1.7,进行强化学习模型和知识蒸馏模型迭代训练,直到模型收敛;
设定πθ为一种带参数θ的参数化随机平稳策略(由动作组成)。这个平稳策略用来把每个状态st映射到动作概率分布at。rt为第t次迭代所获得的奖励。强化学习迭代用于学习样本权重的过程如下所示:
(1)初始化样本权重(动作):
Figure BDA0003074902770000094
a0=w0=[w0,0,…,w0,b],
其中w0,i是样本i在第一次迭代时的样本权重。
Figure BDA0003074902770000095
是样本i所在类别k的不平衡率。Nmax是样本最多的类别的样本数量。Nk是类别k的样本数量。a0为πθ的初始化动作。样本权重的更新过程可以形式化成一个序列决策问题。通过迭代t=1,2,…,T次(T为最大迭代次数),wt,i为样本i在第t次迭代时的样本权重。
(2)计算老师-学生网络的加权交叉熵损失:
Figure BDA0003074902770000096
Figure BDA0003074902770000097
其中
Figure BDA0003074902770000098
Figure BDA0003074902770000099
分别是在第t次迭代时教师网络和学生网络的损失。
Figure BDA00030749027700000910
Figure BDA00030749027700000911
Figure BDA00030749027700000912
Figure BDA00030749027700000913
分别表示类别k的样本i经过教师网络和学生网络的输出logits。
Figure BDA00030749027700000914
C是簇类数量。K是类别数量。b是批次大小。
(3)计算奖励rt。设计的奖励如下:
rt(st,at)=F1(t),
其中F1(t)代表学生网络在第t次迭代的F1得分值。状态、动作(每次迭代的样本权重)和相应的奖励都存储在经验回放中。
(4)获取状态st。状态st是三部分的拼接。第一部分是样本部分xt。第二部分是教师网络的
Figure BDA0003074902770000101
第三部分是学生网络的
Figure BDA0003074902770000108
状态st表示如下:
Figure BDA0003074902770000102
(5)更新策略πθ。样本权重wt是策略πθ的动作。策略πθ采用了策略梯度损失函数。损失函数定义为:
Figure BDA0003074902770000103
Figure BDA0003074902770000104
用梯度下降进行求解。样本权重w不断更新直到奖励收敛。
S2:在线应用测试
S2.1:获取在线样本;
S2.2:基于S1.3层次聚类得到的簇类信息,将在线样本分类到对应的簇类中。在线样本分类到对应的簇类中,其公式如下:
Figure BDA0003074902770000105
其中c为在线样本的簇类类,
Figure BDA0003074902770000106
为簇类c的特征中心,xonline为在线样本。
S2.3:基于S1.8训练得到的知识蒸馏模型中的教师网络和各个学生网络,计算在线样本经过所在的簇类对应的学生网络得到的logit,和通过强化学习模型得到样本权重wt,并用加权的softmax函数计算属于各个类别的概率,选取概率最大的类别作为在线样本的类别,具体为:
用强化学习学习到的πθ得到在线样本的权重:
wonline=πθ(xonline),
在线样本经过学生网络得到的输出为:
logit=Wonlineft(xonline),
其中,ft(·)表示学生网络;对输出进行softmax得到每个类别的概率,再取最大概率所对应的类别为分类类别:
Figure BDA0003074902770000107
以下结合一个具体的工业例子来说明本发明的有效性。使用田纳西州伊士曼(TE)工业基准来评估所提出的方法。TE过程是由伊士曼化学公司根据实际化学过程开发的工业仿真平台,已广泛用于测试过程监控和故障诊断方法的有效性。TE过程的流程如图3所示。
表1:每个故障类别TE过程训练样本数量设定
故障 训练数据 故障 训练数据 故障 训练数据 故障 训练数据
IDV1 7239 IDV8 3595 IDV15 1785 IDV22 886
IDV2 6550 IDV9 3253 IDV16 1615 IDV23 802
IDV3 5927 IDV10 2943 IDV17 1461 IDV24 726
IDV4 5363 IDV11 2663 IDV18 1322 IDV25 657
IDV5 4852 IDV12 2410 IDV19 1197 IDV26 594
IDV6 4390 IDV13 2180 IDV20 1083 IDV27 538
IDV7 3973 IDV14 1973 IDV21 980 IDV28 486
TE数据中正常样本数量为8000。表1为每个故障类别TE过程训练样本数量设定,测试样本数量设定为2000。TE数据的过程变量由34维,故障类别有28个,如图4所示。选取对比方法有MLP(多层感知机)、SMOTE-MLP(合成少数类过采样技术的MLP)、CoSen-MLP(代价敏感MLP)、CSDBN-DE(差分演化的代价敏感深度信念网络)、TU-MLP(可训练的降采样器结合MLP)、KD(知识蒸馏)和本发明(基于强化学习和知识蒸馏的多类别不平衡故障分类方法)。
通过基于强化学习和知识蒸馏的多类别不平衡故障分类方法在TE过程训练样本上训练得到各个学生模型。通过离线训练得到的学生模型对在线样本(测试集)进行预测,得到的结果如表2所示:
表2:在TE过程数据上各个对比方法的分类性能
Figure BDA0003074902770000111
Figure BDA0003074902770000121
从表2中可以看出,所提出的基于强化学习和知识蒸馏的多类别不平衡故障分类方法的F1随着不平衡率的上升在更多的类别上优于对比方法,且随着不平衡程度的提高,本发明相比其他对比方法的优势越明显。综合所有对比方法在所有类别上的结果,本发明提出的方法可以在最终的Macro-F1和Gmean指标上明显优于其他方法。
图5为本发明的训练奖励和测试G-mean的曲线图,可以看出算法收敛较为稳定,并能够根据设定达到较优的性能。图6为用层次聚类方法得到的树状图,虚线为决策线,总共分为3个簇类。图7为所有对比方法10次运行后绘制的箱线图,本发明相对其他对比方法性能更好,更稳定。
为了方法优越性更加直观和明显,绘制了各个分类模型最后一层隐藏的输出经过t-SNE后得到的2D图,如图8所示。图8(g)为本发明的2D映射图,能够从图中看出,经过基于强化学习和知识蒸馏的多类别不平衡故障分类方法,获得降维2D图中的各个类别的边界更加明显,这充分体现了算法的分类性能得到了提高。
如上所述,本发明中所提的基于强化学习和知识蒸馏的多类别不平衡故障分类方法,具有令人满意的分类效果。

Claims (9)

1.一种基于强化学习和知识蒸馏的多类别不平衡故障分类方法,其特征在于,包括以下步骤:
S1:离线建模
S1.1:收集K个类别的历史离线工业过程数据样本,其中包含故障数据和正常数据;
S1.2:计算每个类别特征中心点
Figure FDA0003356692720000011
S1.3:通过基于Ward-Linkage的层次聚类,将同质类的类别特征中心分配在一个簇类中,最终将所有类别特征中心uk分配到C个簇类中;然后根据类别特征中心的聚类结果分配每个类别的所有样本到对应簇类中;
S1.4:使用高斯伯努利限制玻尔兹曼机,分别基于所有样本以及每个簇类中样本进行训练,其中,所有样本训练得到的高斯伯努利限制玻尔兹曼机参数为教师网络的预训练参数;基于每个簇类中样本训练得到的高斯伯努利限制玻尔兹曼机参数为对应的学生网络的预训练参数;所述预训练参数作为首次迭代的初始参数;
S1.5:基于所述的教师网络的上一次训练参数,采用所有样本,通过微调技术,训练多类别不平衡的教师网络,得到的logit作为所有学生网络的软目标;
S1.6:所有学生网络都通过综合交叉熵损失一起训练;根据包含所述软目标和硬目标的综合损失,采用每个簇类中样本,通过微调技术进行训练,将所有学生网络得到的logit拼接在一起,组成学生网络的综合logit;各个学生网络拼接的每个logit中值的位置对应于原先类别顺序;所述硬目标为样本的真实标签;
S1.7:使用强化学习结合知识蒸馏中的教师网络和各个学生网络的输出来学习样本权重,并结合学习后的样本权重、教师网络和各个学生网络的输出构建知识蒸馏的损失函数;
S1.8:重复S1.5~S1.7,进行强化学习模型和知识蒸馏模型迭代训练,直到模型收敛;
S2:在线应用测试
S2.1:获取在线样本;
S2.2:将在线样本分类到S1.3层次聚类得到的C个簇类的其中一个簇类中;
S2.3:基于S1.8训练得到的知识蒸馏模型中的教师网络和各个学生网络,计算在线样本经过所在的簇类对应的学生网络得到的logit,和通过强化学习模型得到样本权重wt,并用加权的softmax函数计算属于各个类别的概率,选取概率最大的类别作为在线样本的类别。
2.根据权利要求1所述基于强化学习和知识蒸馏的多类别不平衡故障分类方法,其特征在于,所述S1.2中的特征中心点计算具体为:
Figure FDA0003356692720000021
其中,uk为类别k的特征中心点,xi表示第i个样本,gk表示类别k的所有样本的集合,|gk|表示类别k的样本数量。
3.根据权利要求1所述基于强化学习和知识蒸馏的多类别不平衡故障分类方法,其特征在于,所述S1.3具体为:
基于Ward-Linkage进行层次聚类,直到最后所有样本都聚成一个簇类,有以下步骤:
①在初始化过程中,将每个样本独立的归为一个簇类中;计算每两个簇类中心之间的相似度;
②找到两个最近的簇类,并将它们归为一个簇类,因此簇类总数减少1个;
③重新计算新生成簇类的中心与每个旧簇类中心之间的相似度;所述簇类的中心为一个簇类的所有样本的平均值;
④重复②和③,直到所有样本归为一个簇类,聚类算法结束;
⑤选择所需的最终聚类后的簇类数,即C的值。
4.根据权利要求1所述基于强化学习和知识蒸馏的多类别不平衡故障分类方法,其特征在于,所述S1.4中的高斯伯努利限制玻尔兹曼机具有两层全连接的结构,分为可见单元
Figure FDA0003356692720000022
和隐藏单元
Figure FDA0003356692720000023
p和d分别为可见单元和隐藏单元的数量;联合配置v,h的能量函数表示为:
Figure FDA0003356692720000024
其中vi∈{0,1},hj∈{0,1};θ={W,a,b}是高斯伯努利限制玻尔兹曼机的结构参数;wij是连接可见单元i和隐藏单元j的对称权重;ai和bj分别是可见偏差和隐藏偏差;σi是可见单元i的高斯噪声的标准差;
所述高斯伯努利限制玻尔兹曼机的目标函数为:
Figure FDA0003356692720000025
其中,xi为第i维的输入数据,p(xi,h|θ)为xi和h的联合概率密度函数;
通过随机梯度上升方法最大化以找到最佳θ,完成对所述高斯伯努利限制玻尔兹曼机的训练:
Figure FDA0003356692720000031
其中,θ中的w和b用作知识蒸馏神经网络第一层的初始参数。
5.根据权利要求1所述基于强化学习和知识蒸馏的多类别不平衡故障分类方法,其特征在于,所述S1.5通过梯度下降法训练教师网络,其中,教师网络的交叉熵损失函数如下:
Figure FDA0003356692720000032
其中
Figure FDA0003356692720000033
Figure FDA0003356692720000034
是教师网络的输入样本xi的输出logit。
6.根据权利要求1所述基于强化学习和知识蒸馏的多类别不平衡故障分类方法,其特征在于,所述S1.6通过梯度下降法训练学生网络,其中,学生网络的交叉熵损失函数如下:
Figure FDA0003356692720000035
其中
Figure FDA0003356692720000036
Figure FDA0003356692720000037
是学生网络的输入样本xi的输出logit。
7.根据权利要求1所述基于强化学习和知识蒸馏的多类别不平衡故障分类方法,其特征在于,所述S1.7具体为:
设定πθ为一种带参数θ的参数化随机平稳策略,策略中包含动作;这个平稳策略用来把每个状态st映射到动作概率分布at;rt为第t次迭代所获得的奖励;强化学习迭代用于学习样本权重的过程如下所示:
(1)初始化样本权重:
Figure FDA0003356692720000038
a0=w0=[w0,0,…,w0,b],
其中w0,i是样本i在第一次迭代时的样本权重;
Figure FDA0003356692720000039
是样本i所在类别k的不平衡率;Nmax是样本最多的类别的样本数量;Nk是类别k的样本数量;a0为πθ的初始化动作;样本权重的更新过程形式化成一个序列决策问题;通过迭代t=1,2,…,T次,T为最大迭代次数,wt,i为样本i在第t次迭代时的样本权重;
(2)计算教师-学生网络的加权交叉熵损失:
Figure FDA0003356692720000041
Figure FDA0003356692720000042
其中,
Figure FDA0003356692720000043
Figure FDA0003356692720000044
分别是在第t次迭代时教师网络和学生网络的损失;
Figure FDA0003356692720000045
Figure FDA0003356692720000046
Figure FDA0003356692720000047
Figure FDA0003356692720000048
分别表示类别k的样本i经过教师网络和学生网络的输出logit;
Figure FDA0003356692720000049
C是簇类数量;K是类别数量;b是批次大小;
(3)计算奖励rt,设计的奖励如下:
rt(st,at)=F1(t),
其中F1(t)代表学生网络在第t次迭代的F1得分值,状态、每次迭代的样本权重和相应的奖励都存储在经验回放中;
(4)获取状态st;状态st是三部分的拼接;第一部分是样本部分xt;第二部分是教师网络的
Figure FDA00033566927200000410
第三部分是学生网络的
Figure FDA00033566927200000411
状态st表示如下:
Figure FDA00033566927200000412
(5)更新策略πθ;样本权重wt是策略πθ的动作,策略πθ采用了策略梯度损失函数,损失函数定义为:
Figure FDA00033566927200000413
Figure FDA00033566927200000414
用梯度下降进行求解,样本权重wt不断更新直到奖励收敛。
8.根据权利要求1所述基于强化学习和知识蒸馏的多类别不平衡故障分类方法,其特征在于,所述S2.2中将在线样本分类到对应的簇类中的计算公式如下:
Figure FDA00033566927200000415
其中,c为在线样本的簇类类别,
Figure FDA00033566927200000416
为簇类c的特征中心,xonline为在线样本。
9.根据权利要求1所述基于强化学习和知识蒸馏的多类别不平衡故障分类方法,其特征在于,所述S2.3具体为:
用强化学习学习到的πθ得到在线样本的权重:
Wonline=πθ(xonline).
在线样本经过学生网络得到的输出为:
logit=wonlineft(xonline),
其中,ft(·)表示学生网络;对输出进行softmax得到每个类别的概率,再取最大概率所对应的类别为分类类别:
Figure FDA0003356692720000051
CN202110549644.6A 2021-05-20 2021-05-20 基于强化学习和知识蒸馏的多类别不平衡故障分类方法 Active CN113222035B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202110549644.6A CN113222035B (zh) 2021-05-20 2021-05-20 基于强化学习和知识蒸馏的多类别不平衡故障分类方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202110549644.6A CN113222035B (zh) 2021-05-20 2021-05-20 基于强化学习和知识蒸馏的多类别不平衡故障分类方法

Publications (2)

Publication Number Publication Date
CN113222035A CN113222035A (zh) 2021-08-06
CN113222035B true CN113222035B (zh) 2021-12-31

Family

ID=77093626

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202110549644.6A Active CN113222035B (zh) 2021-05-20 2021-05-20 基于强化学习和知识蒸馏的多类别不平衡故障分类方法

Country Status (1)

Country Link
CN (1) CN113222035B (zh)

Families Citing this family (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN114638336B (zh) * 2021-12-26 2023-09-22 海南大学 聚焦于陌生样本的不平衡学习
CN115908955B (zh) * 2023-03-06 2023-06-20 之江实验室 基于梯度蒸馏的少样本学习的鸟类分类系统、方法与装置

Citations (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN108875772A (zh) * 2018-03-30 2018-11-23 浙江大学 一种基于堆叠稀疏高斯伯努利受限玻尔兹曼机和强化学习的故障分类模型及方法
CN108875771A (zh) * 2018-03-30 2018-11-23 浙江大学 一种基于稀疏高斯伯努利受限玻尔兹曼机和循环神经网络的故障分类模型及方法
CN111598216A (zh) * 2020-04-16 2020-08-28 北京百度网讯科技有限公司 学生网络模型的生成方法、装置、设备及存储介质
CN112819159A (zh) * 2021-02-24 2021-05-18 清华大学深圳国际研究生院 一种深度强化学习训练方法及计算机可读存储介质

Family Cites Families (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20170032276A1 (en) * 2015-07-29 2017-02-02 Agt International Gmbh Data fusion and classification with imbalanced datasets

Patent Citations (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN108875772A (zh) * 2018-03-30 2018-11-23 浙江大学 一种基于堆叠稀疏高斯伯努利受限玻尔兹曼机和强化学习的故障分类模型及方法
CN108875771A (zh) * 2018-03-30 2018-11-23 浙江大学 一种基于稀疏高斯伯努利受限玻尔兹曼机和循环神经网络的故障分类模型及方法
CN111598216A (zh) * 2020-04-16 2020-08-28 北京百度网讯科技有限公司 学生网络模型的生成方法、装置、设备及存储介质
CN112819159A (zh) * 2021-02-24 2021-05-18 清华大学深圳国际研究生院 一种深度强化学习训练方法及计算机可读存储介质

Non-Patent Citations (3)

* Cited by examiner, † Cited by third party
Title
《Diversity-driven knowledge distillation for financial trading using Deep Reinforcement Learning》;Avraam Tsantekidis,et al;《Neural Networks》;20210317;第193-202页 *
《Periodic Intra-Ensemble Knowledge Distillation for Reinforcement Learning》;Zhang Wei Hong,et al;《arXiv:2002.00149v1》;20200201;第1-8页 *
《地铁车辆轮对踏面故障诊断系统研究与开发》;王锋涛;《中国优秀博硕士学位论文全文数据库(硕士) 工程科技Ⅱ辑》;20190115(第01期);第C033-555页 *

Also Published As

Publication number Publication date
CN113222035A (zh) 2021-08-06

Similar Documents

Publication Publication Date Title
CN110263227B (zh) 基于图神经网络的团伙发现方法和系统
Basirat et al. The quest for the golden activation function
Sarikaya et al. Deep belief nets for natural language call-routing
Hruschka et al. Extracting rules from multilayer perceptrons in classification problems: A clustering-based approach
CN102520341B (zh) 一种基于Bayes-KFCM算法的模拟电路故障诊断方法
Luo et al. Species-based particle swarm optimizer enhanced by memory for dynamic optimization
CN113222035B (zh) 基于强化学习和知识蒸馏的多类别不平衡故障分类方法
CN111079926B (zh) 基于深度学习的具有自适应学习率的设备故障诊断方法
Li et al. A bilevel learning model and algorithm for self-organizing feed-forward neural networks for pattern classification
Perez-Godoy et al. CO 2 RBFN: an evolutionary cooperative–competitive RBFN design algorithm for classification problems
CN112685504A (zh) 一种面向生产过程的分布式迁移图学习方法
Asadi et al. A bi-objective optimization method to produce a near-optimal number of classifiers and increase diversity in Bagging
CN107153837A (zh) 深度结合K‑means和PSO的聚类方法
CN113627471A (zh) 一种数据分类方法、系统、设备及信息数据处理终端
Yan et al. Trustworthiness evaluation and retrieval-based revision method for case-based reasoning classifiers
Li et al. Automatic design of machine learning via evolutionary computation: A survey
Wang et al. A novel restricted Boltzmann machine training algorithm with fast Gibbs sampling policy
Urgun et al. Composite power system reliability evaluation using importance sampling and convolutional neural networks
Poczeta et al. Analysis of fuzzy cognitive maps with multi-step learning algorithms in valuation of owner-occupied homes
KR20080078292A (ko) 영역 밀도 표현에 기반한 점진적 패턴 분류 방법
CN113222034B (zh) 基于知识蒸馏的细粒度多类别不平衡故障分类方法
CN108446718B (zh) 一种动态深度置信网络分析方法
CN115906959A (zh) 基于de-bp算法的神经网络模型的参数训练方法
Liu et al. Prediction of share price trend using FCM neural network classifier
Baruque et al. Hybrid classification ensemble using topology-preserving clustering

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant