CN113222035A - 基于强化学习和知识蒸馏的多类别不平衡故障分类方法 - Google Patents
基于强化学习和知识蒸馏的多类别不平衡故障分类方法 Download PDFInfo
- Publication number
- CN113222035A CN113222035A CN202110549644.6A CN202110549644A CN113222035A CN 113222035 A CN113222035 A CN 113222035A CN 202110549644 A CN202110549644 A CN 202110549644A CN 113222035 A CN113222035 A CN 113222035A
- Authority
- CN
- China
- Prior art keywords
- class
- cluster
- sample
- samples
- network
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Granted
Links
- 238000000034 method Methods 0.000 title claims abstract description 78
- 238000013140 knowledge distillation Methods 0.000 title claims abstract description 41
- 230000002787 reinforcement Effects 0.000 title claims abstract description 39
- 230000006870 function Effects 0.000 claims abstract description 28
- 230000008569 process Effects 0.000 claims abstract description 22
- 238000012549 training Methods 0.000 claims description 43
- 238000009826 distribution Methods 0.000 claims description 16
- 230000009471 action Effects 0.000 claims description 13
- 238000013507 mapping Methods 0.000 claims description 10
- 238000004519 manufacturing process Methods 0.000 claims description 8
- 238000012360 testing method Methods 0.000 claims description 8
- 238000005516 engineering process Methods 0.000 claims description 6
- 238000011478 gradient descent method Methods 0.000 claims description 6
- 238000004422 calculation algorithm Methods 0.000 claims description 5
- 238000009499 grossing Methods 0.000 claims description 5
- 238000004364 calculation method Methods 0.000 claims description 4
- 238000013528 artificial neural network Methods 0.000 claims description 3
- 230000000694 effects Effects 0.000 abstract description 10
- 238000010586 diagram Methods 0.000 description 5
- 238000012733 comparative method Methods 0.000 description 4
- 238000012544 monitoring process Methods 0.000 description 3
- 238000003745 diagnosis Methods 0.000 description 2
- 238000004821 distillation Methods 0.000 description 2
- 238000005457 optimization Methods 0.000 description 2
- 238000004220 aggregation Methods 0.000 description 1
- 230000002776 aggregation Effects 0.000 description 1
- 230000009286 beneficial effect Effects 0.000 description 1
- 238000001311 chemical methods and process Methods 0.000 description 1
- 238000013145 classification model Methods 0.000 description 1
- 238000010276 construction Methods 0.000 description 1
- 238000013135 deep learning Methods 0.000 description 1
- 238000002059 diagnostic imaging Methods 0.000 description 1
- 238000010801 machine learning Methods 0.000 description 1
- 238000003825 pressing Methods 0.000 description 1
- 230000009467 reduction Effects 0.000 description 1
- 238000004088 simulation Methods 0.000 description 1
- 230000002194 synthesizing effect Effects 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/22—Matching criteria, e.g. proximity measures
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/214—Generating training patterns; Bootstrap methods, e.g. bagging or boosting
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/23—Clustering techniques
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
- G06F18/241—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
- G06F18/241—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
- G06F18/2415—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches based on parametric or probabilistic models, e.g. based on likelihood ratio or false acceptance rate versus a false rejection rate
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Data Mining & Analysis (AREA)
- Physics & Mathematics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Artificial Intelligence (AREA)
- General Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- Evolutionary Computation (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Evolutionary Biology (AREA)
- Bioinformatics & Computational Biology (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Computational Linguistics (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Health & Medical Sciences (AREA)
- General Health & Medical Sciences (AREA)
- Molecular Biology (AREA)
- Computing Systems (AREA)
- Mathematical Physics (AREA)
- Software Systems (AREA)
- Probability & Statistics with Applications (AREA)
- Information Retrieval, Db Structures And Fs Structures Therefor (AREA)
Abstract
本发明公开了一种基于强化学习和知识蒸馏的多类别不平衡故障分类方法,该方法结合层次聚类、知识蒸馏和强化学习等算法,用来解决多类别不平衡故障分类问题。对于多类别故障分类问题,首先针对不平衡问题中同质类别样本之间存在相似性、异质类样本之间存在较大差异的特点使用层次聚类将多类别聚类为几个簇类,根据不同簇类分别建立学生网络进行细粒度化分类,再用知识蒸馏方法兼顾全局信息,最后结合强化学习迭代学习样本权重,从而提高不平衡故障分类效果。在此过程中,需要设计合理的奖励函数配合细粒度知识蒸馏分类器去优化样本权重。相比其他对比方法,本发明的方法有良好的效果和适用性。
Description
技术领域
本发明属于工业过程监测领域,尤其涉及一种基于强化学习和知识蒸馏的多类别不平衡故障分类方法。
背景技术
在机器学习或深度学习分类中,类别样本数量不平衡是一个非常普遍的问题,广泛存在于各个领域,例如生物信息学,智能电网,医学成像,故障诊断。大多数现有的分类方法都基于以下假设:观测数据的基本分布是相对均衡的。但是,实际工业数据集通常会违反此假设,并呈现出偏斜的分布甚至是极度不平衡的类别样本数量分布。例如,数据驱动的故障分类是工业过程监测的重要组成部分,由于故障发生的频率不同,它们表现出不平衡的偏斜分布。在这种情况下,如果假定所有类别都具有同等的重要性,则分类器会倾向于分对频繁(多数)类别的样本而不是不频繁(少数)类别的样本。因此,迫切需要提出恰当的方法来消除不平衡的类别分布的负面影响,而又不过度牺牲任何多数类别或少数类别的准确性。
发明内容
本发明的目的在于提供一种基于强化学习和知识蒸馏的多类别不平衡故障分类方法,其能对多数类不平衡的分类问题,获得较好的故障分类结果。具体技术方案如下:
一种基于强化学习和知识蒸馏的多类别不平衡故障分类方法,包括以下步骤:
S1:离线建模
S1.1:收集K个类别的历史离线工业过程数据样本,其中包含故障数据和正常数据;
S1.3:通过基于Ward-Linkage的层次聚类,将同质类的类别特征中心分配在一个簇类中,最终将所有类别特征中心uk分配到C个簇类中;然后根据类别特征中心的聚类结果分配每个类别的所有样本到对应簇类中;
S1.4:使用高斯伯努利限制玻尔兹曼机,分别基于所有样本以及每个簇类中样本进行训练,其中,所有样本训练得到的高斯伯努利限制玻尔兹曼机参数为教师网络的预训练参数;基于每个簇类中样本训练得到的高斯伯努利限制玻尔兹曼机参数为对应的学生网络的预训练参数;
S1.5:基于所述的教师网络的预训练参数,采用所有样本,通过微调技术,训练多类别不平衡的教师网络,得到的logit作为所有学生网络的软目标;
S1.6:训练完教师网络之后,所有学生网络都通过综合交叉熵损失一起训练;根据包含所述软目标和硬目标的综合损失,采用每个簇类中样本,通过微调技术进行训练,将所有学生网络得到的logit拼接在一起,组成学生网络的综合logit;各个学生网络拼接的每个logit中值的位置对应于原先类别顺序;所述硬目标为样本的真实标签;
S1.7:使用强化学习结合知识蒸馏的输出来学习样本权重,并结合学习后的样本权重、教师网络和各个学生网络的输出构建损失函数;
S1.8:重复S1.5~S1.7,进行强化学习模型和知识蒸馏模型迭代训练,直到模型收敛;
S2:在线应用测试
S2.1:获取在线样本;
S2.2:将在线样本分类到S1.3层次聚类得到的C个簇类的其中一个簇类中;
S2.3:基于S1.8训练得到的知识蒸馏模型中的教师网络和各个学生网络,计算在线样本经过所在的簇类对应的学生网络得到的logit,和通过强化学习模型得到样本权重wt,并用加权的softmax函数计算属于各个类别的概率,选取概率最大的类别作为在线样本的类别。
进一步地,所述S1.2中的特征中心点计算具体为:
其中,uk为类别k的特征中心点,xi表示第i个样本,gk表示类别k的所有样本的集合,|gk|表示类别k的样本数量。
进一步地,所述S1.3具体为:
基于Ward-Linkage进行层次聚类,直到最后所有样本都聚成一个簇类。主要有以下步骤:
①在初始化过程中,将每个样本独立的归为一个簇类中;计算每两个簇类中心之间的相似度;
②找到两个最近的簇类,并将它们归为一个簇类,因此簇类总数减少1个;
③重新计算新生成簇类的中心与每个旧簇类中心之间的相似度;所述簇类的中心为一个簇类的所有样本的平均值;
④重复②和③,直到所有样本归为一个簇类,聚类算法结束;
⑤选择所需的最终聚类后的簇类数,作为最终的簇类数,即C的值。
其中vi∈{0,1},hj∈{0,1};θ={W,a,b}是高斯伯努利限制玻尔兹曼机的结构参数;wij是连接可见单元i和隐藏单元j的对称权重;ai和bj分别是可见偏差和隐藏偏差;σi是可见单元i的高斯噪声的标准差;
所述高斯伯努利限制玻尔兹曼机的目标函数为:
其中,xi为第i维的输入数据,p(xi,h|θ)为xi和h的联合概率密度函数;
通过随机梯度上升方法最大化以找到最佳θ,完成对所述高斯伯努利限制玻尔兹曼机的训练:
其中,θ中的w和b用作知识蒸馏神经网络第一层的初始参数。
进一步地,所述S1.5通过梯度下降法训练教师网络,其中,教师网络的交叉熵损失函数如下:
进一步地,所述S1.6通过梯度下降法训练学生网络,其中,学生网络的交叉熵损失函数如下:
进一步地,所述S1.7具体为:
设定πθ为一种带参数θ的参数化随机平稳策略。策略中包含动作。这个平稳策略用来把每个状态st映射到动作概率分布at。rt为第t次迭代所获得的奖励。强化学习迭代用于学习样本权重的过程如下所示:
(1)初始化样本权重(动作):
a0=w0=[w0,0,…,w0,b]
其中w0,i是样本i在第一次迭代时的样本权重。是样本i所在类别k的不平衡率。Nmax是样本最多的类别的样本数量。Nk是类别k的样本数量。a0为πθ的初始化动作。样本权重的更新过程可以形式化成一个序列决策问题。通过迭代t=1,2,…,T次(T为最大迭代次数),wt,i为样本i在第t次迭代时的样本权重。
(2)计算老师-学生网络的加权交叉熵损失:
(3)计算奖励rt。设计的奖励如下:
rt(st,at)=F1(t),
其中F1(t)代表学生网络在第t次迭代的F1得分值。状态、动作(每次迭代的样本权重)和相应的奖励都存储在经验回放中。
(5)更新策略πθ。样本权重wt是策略πθ的动作。策略πθ采用了策略梯度损失函数。损失函数定义为:
进一步地,将在线样本分到对应簇类中,所述S2.2具体为:
在线样本分类到对应的簇类中,其公式如下:
进一步地,所述S2.3具体为:
用强化学习学习到的πθ得到在线样本的权重:
wonline=πθ(xonline)
在线样本经过学生网络得到的输出为:
logit=wonlineft(xonline)
其中,ft(·)表示学生网络;对输出进行softmax得到每个类别的概率,再取最大概率所对应的类别为分类类别:
本发明的有益效果如下:
本发明对于多类别的不平衡故障分类问题具有独特的效果,由于同质类别样本之间存在相似性、异质类样本之间存在较大差异的特点,使得本发明在通过聚类方法得到的簇类的基础上,更加细粒度的通过多个学生网络来解决不平衡的故障分类问题。同时通过教师网络的引导使得各个学生网络不仅能学习到簇类中同质类别的决策边界,也能学习到总体的数据分布信息。不仅如此,进一步结合强化学习,不断结合识蒸馏网络进行迭代,结合样本类别数量与样本在分布中的作用,获取样本权重,增加少数类样本的权重,减少多数类样本的权重,使得故障分类效果更好,准确率更高。
附图说明
图1为本发明方法采用的基础方法的结构图;
图2为本发明方法的结构图;
图3为使用的数据集生成的工艺流程图;
图4为使用的数据集样本数量分布示意图;
图5为本发明方法训练奖励和测试G-mean的曲线图;
图6为通过层次聚类得到的树状图;
图7为所有对比方法10次运行后绘制的箱线图;
图8为分类最后一层隐层的数据通过t-SNE降维后的2D映射图。(a)为MLP最后一层隐层输出的2D映射图;(b)为SMOTE-MLP最后一层隐层输出的2D映射图;(c)为Cosen-MLP最后一层隐层输出的2D映射图;(d)为CSDBN-DE最后一层隐层输出的2D映射图;(e)为TU-MLP最后一层隐层输出的2D映射图;(f)为KD最后一层隐层输出的2D映射图;(g)为本发明最后一层隐层输出的2D映射图。
具体实施方式
下面根据附图和优选实施例详细描述本发明,本发明的目的和效果将变得更加明白,应当理解,此处所描述的具体实施例仅仅用以解释本发明,并不用于限定本发明。
针对多类别的不平衡分布问题,本发明提出了一种新的基于强化学习和知识蒸馏的多类别不平衡故障分类方法。
本发明针对多类别的不平衡分布下的故障分类问题,划定离线和在线数据集,首先使用知识蒸馏方法进行分类或识别故障的类别。再针对不平衡问题中同质类别样本之间存在相似性、异质类样本之间存在较大差异的特点,采用层次聚类方法,根据类别中心点的聚类结果,将所有类别样本进行聚类,从而获得细粒度簇类。最后针对每个簇类进行细粒度故障分类。因此,对于某个簇类中,都将建立一个学生模型,最后进行拼接,进行多学生模型一起优化。在教师模型的全局信息的指导下,并结合多学生模型细粒度的进行故障分类。不仅如此,进一步结合强化学习,不断结合识蒸馏网络进行迭代,结合样本类别数量与样本在分布中的作用,获取样本权重,增加少数类样本的权重,减少多数类样本的权重,使得故障分类效果更好,相比其他现存方法,本发明的方法有良好的效果和适用性。
如图1和2所示,本发明的基于强化学习和知识蒸馏的多类别不平衡故障分类方法,包括以下步骤:
S1:离线建模
S1.1:收集K个类别的历史离线工业过程数据样本,其中包含故障数据样本和正常数据样本;
其中,gk表示类别k的所有样本的集合,|gk|表示类别k的样本数量。
S1.3:通过基于Ward-Linkage的层次聚类,将所有类别特征中心uk分配到C个簇类中。同质类的类别特征中心将被分配在一个簇类中。根据类别特征中心的聚类结果分配每个类别的所有样本到对应簇类中。层次聚类使用以逐次聚合的方式(AgglomerativeClustering),将样本分类,直到最后所有样本都聚成一个簇类。主要有以下步骤:
①在初始化过程中,将每个样本独立的归为一个簇类中。计算每两个簇类中心之间的距离(也称为相似度);
②找到两个最近的簇类,并将它们归为一个簇类,因此簇类总数减少1个;
③重新计算新生成簇类的中心与每个旧簇类中心之间的相似度(一个簇类的所有样本的平均值代表该簇类的中心);
④重复②和③,直到所有样本归为一个簇类,聚类算法结束;
⑤选择所需的最终聚类后的簇类数,作为最终簇类数,即C的值。
整个聚类过程实际上是在构建一棵树。在构建过程中,第②步将设置一个阈值。当两个最近的簇类中心之间的距离大于此阈值时,则认为迭代已终止。另一个关键步骤是第三步,有很多方法可以确定两个聚类之间的相似性。常用的相似性度量包括Ward Linkage,Single Linkage、Complete Linkage和Average Linkage策略。在发明中,由于WardLinkage策略通常提供较高的聚类性能,因此采用Ward Linkage策略。Ward Linkage由两个聚类之间的平方误差和ESS计算得出,其目标函数是每次合并后ESS的最小增量,ESS定义如下:
S1.4:使用高斯伯努利限制玻尔兹曼机,分别基于所有样本以及每个聚类中样本进行训练。其中,所有样本训练得到的高斯伯努利限制玻尔兹曼机参数为教师网络的预训练参数;基于每个簇类中样本训练得到的高斯伯努利限制玻尔兹曼机参数为对应的学生网络的预训练参数。
高斯伯努利限制玻尔兹曼机具有两层全连接的结构,分为可见单元(或数据变量)和隐藏单元(或潜在变量)p和d分别为可见单元和隐藏单元的数量。高斯伯努利限制玻尔兹曼机既是生成模型,也是基于能量的模型。联合配置v,h的能量函数表示为:
其中vi∈{0,1},hj∈{0,1}。θ={W,a,b}是高斯伯努利限制玻尔兹曼机的结构参数。wij是连接可见单元i和隐藏单元j的对称权重;ai和bj分别是可见偏差和隐藏偏差。σi是可见单元i的高斯噪声的标准差。v和h的联合概率采用以下形式:
通常,将导致高(低)能量的配置(v,h)分别设置为低(高)概率计算的一部分。所有可见单元或隐藏单元都是有条件的独立单元。因此,高斯伯努利限制玻尔兹曼机的可见节点和隐藏节点的概率分布可以由下式给出:
其中σ(x)是逻辑斯蒂sigmoid函数N(·|μ,σ2)是均值为μ,方差为σ2高斯概率密度函数。高斯伯努利限制玻尔兹曼机的优化目标是最大程度的适应数据分布。因此,目标函数是通过输入数据集获得(d是输入数据的特征维度,m是类别数目),如下所示:
现有的大多数基于高斯伯努利限制玻尔兹曼机的模型都是通过对比差异(CD)学习策略来处理数据非线性的,该策略将实值数据映射到隐特征空间。对数似然估计值可通过随机梯度上升方法最大化以找到最佳θ:
通过迭代获得高斯伯努利限制玻尔兹曼机的最优参数θ。θ中的w和b用作知识蒸馏神经网络第一层的初始参数。
S1.5:基于所述的教师网络的预训练参数,采用所有样本,通过微调技术,通过梯度下降法训练多类别不平衡的教师网络,得到的logit作为所有学生网络的软目标。计算教师网络ft的交叉熵损失函数如下:
S1.6:训练完教师网络之后,所有的学生网络都通过综合交叉熵损失一起训练。根据包含了软目标(教师网络的logit)和硬目标(真实标签)的综合损失,采用每个簇类中样本,通过微调技术,通过梯度下降法进行训练所有学生网络。学生网络的综合logit由所有学生网络的logit拼接在一起。各个学生网络拼接的每个logit中值的位置对应于原先类别顺序。学生网络ft的综合损失,含了软目标(教师网络的logit)和硬目标(真实标签)的综合损失,定义如下:
S1.7:使用强化学习结合知识蒸馏的输出来学习样本权重,并结合学习后的样本权重、教师网络和各个学生网络的输出构建损失函数;
S1.8:重复S1.5~S1.7,进行强化学习模型和知识蒸馏模型迭代训练,直到模型收敛;
设定πθ为一种带参数θ的参数化随机平稳策略(由动作组成)。这个平稳策略用来把每个状态st映射到动作概率分布at。rt为第t次迭代所获得的奖励。强化学习迭代用于学习样本权重的过程如下所示:
(1)初始化样本权重(动作):
a0=w0=[w0,0,…,w0,b],
其中w0,i是样本i在第一次迭代时的样本权重。是样本i所在类别k的不平衡率。Nmax是样本最多的类别的样本数量。Nk是类别k的样本数量。a0为πθ的初始化动作。样本权重的更新过程可以形式化成一个序列决策问题。通过迭代t=1,2,…,T次(T为最大迭代次数),wt,i为样本i在第t次迭代时的样本权重。
(2)计算老师-学生网络的加权交叉熵损失:
(3)计算奖励rt。设计的奖励如下:
rt(st,at)=F1(t),
其中F1(t)代表学生网络在第t次迭代的F1得分值。状态、动作(每次迭代的样本权重)和相应的奖励都存储在经验回放中。
(5)更新策略πθ。样本权重wt是策略πθ的动作。策略πθ采用了策略梯度损失函数。损失函数定义为:
S2:在线应用测试
S2.1:获取在线样本;
S2.2:基于S1.3层次聚类得到的簇类信息,将在线样本分类到对应的簇类中。在线样本分类到对应的簇类中,其公式如下:
S2.3:基于S1.8训练得到的知识蒸馏模型中的教师网络和各个学生网络,计算在线样本经过所在的簇类对应的学生网络得到的logit,和通过强化学习模型得到样本权重wt,并用加权的softmax函数计算属于各个类别的概率,选取概率最大的类别作为在线样本的类别,具体为:
用强化学习学习到的πθ得到在线样本的权重:
wonline=πθ(xonline),
在线样本经过学生网络得到的输出为:
logit=Wonlineft(xonline),
其中,ft(·)表示学生网络;对输出进行softmax得到每个类别的概率,再取最大概率所对应的类别为分类类别:
以下结合一个具体的工业例子来说明本发明的有效性。使用田纳西州伊士曼(TE)工业基准来评估所提出的方法。TE过程是由伊士曼化学公司根据实际化学过程开发的工业仿真平台,已广泛用于测试过程监控和故障诊断方法的有效性。TE过程的流程如图3所示。
表1:每个故障类别TE过程训练样本数量设定
故障 | 训练数据 | 故障 | 训练数据 | 故障 | 训练数据 | 故障 | 训练数据 |
IDV1 | 7239 | IDV8 | 3595 | IDV15 | 1785 | IDV22 | 886 |
IDV2 | 6550 | IDV9 | 3253 | IDV16 | 1615 | IDV23 | 802 |
IDV3 | 5927 | IDV10 | 2943 | IDV17 | 1461 | IDV24 | 726 |
IDV4 | 5363 | IDV11 | 2663 | IDV18 | 1322 | IDV25 | 657 |
IDV5 | 4852 | IDV12 | 2410 | IDV19 | 1197 | IDV26 | 594 |
IDV6 | 4390 | IDV13 | 2180 | IDV20 | 1083 | IDV27 | 538 |
IDV7 | 3973 | IDV14 | 1973 | IDV21 | 980 | IDV28 | 486 |
TE数据中正常样本数量为8000。表1为每个故障类别TE过程训练样本数量设定,测试样本数量设定为2000。TE数据的过程变量由34维,故障类别有28个,如图4所示。选取对比方法有MLP(多层感知机)、SMOTE-MLP(合成少数类过采样技术的MLP)、CoSen-MLP(代价敏感MLP)、CSDBN-DE(差分演化的代价敏感深度信念网络)、TU-MLP(可训练的降采样器结合MLP)、KD(知识蒸馏)和本发明(基于强化学习和知识蒸馏的多类别不平衡故障分类方法)。
通过基于强化学习和知识蒸馏的多类别不平衡故障分类方法在TE过程训练样本上训练得到各个学生模型。通过离线训练得到的学生模型对在线样本(测试集)进行预测,得到的结果如表2所示:
表2:在TE过程数据上各个对比方法的分类性能
从表2中可以看出,所提出的基于强化学习和知识蒸馏的多类别不平衡故障分类方法的F1随着不平衡率的上升在更多的类别上优于对比方法,且随着不平衡程度的提高,本发明相比其他对比方法的优势越明显。综合所有对比方法在所有类别上的结果,本发明提出的方法可以在最终的Macro-F1和Gmean指标上明显优于其他方法。
图5为本发明的训练奖励和测试G-mean的曲线图,可以看出算法收敛较为稳定,并能够根据设定达到较优的性能。图6为用层次聚类方法得到的树状图,虚线为决策线,总共分为3个簇类。图7为所有对比方法10次运行后绘制的箱线图,本发明相对其他对比方法性能更好,更稳定。
为了方法优越性更加直观和明显,绘制了各个分类模型最后一层隐藏的输出经过t-SNE后得到的2D图,如图8所示。图8(g)为本发明的2D映射图,能够从图中看出,经过基于强化学习和知识蒸馏的多类别不平衡故障分类方法,获得降维2D图中的各个类别的边界更加明显,这充分体现了算法的分类性能得到了提高。
如上所述,本发明中所提的基于强化学习和知识蒸馏的多类别不平衡故障分类方法,具有令人满意的分类效果。
Claims (9)
1.一种基于强化学习和知识蒸馏的多类别不平衡故障分类方法,其特征在于,包括以下步骤:
S1:离线建模
S1.1:收集K个类别的历史离线工业过程数据样本,其中包含故障数据和正常数据;
S1.3:通过基于Ward-Linkage的层次聚类,将同质类的类别特征中心分配在一个簇类中,最终将所有类别特征中心uk分配到C个簇类中;然后根据类别特征中心的聚类结果分配每个类别的所有样本到对应簇类中;
S1.4:使用高斯伯努利限制玻尔兹曼机,分别基于所有样本以及每个簇类中样本进行训练,其中,所有样本训练得到的高斯伯努利限制玻尔兹曼机参数为教师网络的预训练参数;基于每个簇类中样本训练得到的高斯伯努利限制玻尔兹曼机参数为对应的学生网络的预训练参数;所述预训练参数作为首次迭代的初始参数;
S1.5:基于所述的教师网络的上一次训练参数,采用所有样本,通过微调技术,训练多类别不平衡的教师网络,得到的logit作为所有学生网络的软目标;
S1.6:所有学生网络都通过综合交叉熵损失一起训练;根据包含所述软目标和硬目标的综合损失,采用每个簇类中样本,通过微调技术进行训练,将所有学生网络得到的logit拼接在一起,组成学生网络的综合logit;各个学生网络拼接的每个logit中值的位置对应于原先类别顺序;所述硬目标为样本的真实标签;
S1.7:使用强化学习结合知识蒸馏中的教师网络和各个学生网络的输出来学习样本权重,并结合学习后的样本权重、教师网络和各个学生网络的输出构建知识蒸馏的损失函数;
S1.8:重复S1.5~S1.7,进行强化学习模型和知识蒸馏模型迭代训练,直到模型收敛;
S2:在线应用测试
S2.1:获取在线样本;
S2.2:将在线样本分类到S1.3层次聚类得到的C个簇类的其中一个簇类中;
S2.3:基于S1.8训练得到的知识蒸馏模型中的教师网络和各个学生网络,计算在线样本经过所在的簇类对应的学生网络得到的logit,和通过强化学习模型得到样本权重wt,并用加权的softmax函数计算属于各个类别的概率,选取概率最大的类别作为在线样本的类别。
3.根据权利要求1所述基于强化学习和知识蒸馏的多类别不平衡故障分类方法,其特征在于,所述S1.3具体为:
基于Ward-Linkage进行层次聚类,直到最后所有样本都聚成一个簇类。主要有以下步骤:
①在初始化过程中,将每个样本独立的归为一个簇类中;计算每两个簇类中心之间的相似度;
②找到两个最近的簇类,并将它们归为一个簇类,因此簇类总数减少1个;
③重新计算新生成簇类的中心与每个旧簇类中心之间的相似度;所述簇类的中心为一个簇类的所有样本的平均值;
④重复②和③,直到所有样本归为一个簇类,聚类算法结束;
⑤选择所需的最终聚类后的簇类数,即C的值。
4.根据权利要求1所述基于强化学习和知识蒸馏的多类别不平衡故障分类方法,其特征在于,所述S1.4中的高斯伯努利限制玻尔兹曼机具有两层全连接的结构,分为可见单元和隐藏单元p和d分别为可见单元和隐藏单元的数量;联合配置v,h的能量函数表示为:
其中vi∈{0,1},hj∈{0,1};θ={W,a,b}是高斯伯努利限制玻尔兹曼机的结构参数;wij是连接可见单元i和隐藏单元j的对称权重;ai和bj分别是可见偏差和隐藏偏差;σi是可见单元i的高斯噪声的标准差;
所述高斯伯努利限制玻尔兹曼机的目标函数为:
其中,xi为第i维的输入数据,p(xi,h|θ)为xi和h的联合概率密度函数;
通过随机梯度上升方法最大化以找到最佳θ,完成对所述高斯伯努利限制玻尔兹曼机的训练:
其中,θ中的w和b用作知识蒸馏神经网络第一层的初始参数。
7.根据权利要求1所述基于强化学习和知识蒸馏的多类别不平衡故障分类方法,其特征在于,所述S1.7具体为:
设定πθ为一种带参数θ的参数化随机平稳策略,策略中包含动作。这个平稳策略用来把每个状态st映射到动作概率分布at。rt为第t次迭代所获得的奖励。强化学习迭代用于学习样本权重的过程如下所示:
(1)初始化样本权重:
a0=w0=[w0,0,…,w0,b],
其中w0,i是样本i在第一次迭代时的样本权重。是样本i所在类别k的不平衡率。Nmax是样本最多的类别的样本数量。Nk是类别k的样本数量。a0为πθ的初始化动作。样本权重的更新过程形式化成一个序列决策问题。通过迭代t=1,2,…,T次,T为最大迭代次数,wt,i为样本i在第t次迭代时的样本权重。
(2)计算教师-学生网络的加权交叉熵损失:
(3)计算奖励rt,设计的奖励如下:
rt(st,at)=F1(t),
其中F1(t)代表学生网络在第t次迭代的F1得分值,状态、每次迭代的样本权重和相应的奖励都存储在经验回放中;
(5)更新策略πθ。样本权重wt是策略πθ的动作。策略πθ采用了策略梯度损失函数。损失函数定义为:
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202110549644.6A CN113222035B (zh) | 2021-05-20 | 2021-05-20 | 基于强化学习和知识蒸馏的多类别不平衡故障分类方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202110549644.6A CN113222035B (zh) | 2021-05-20 | 2021-05-20 | 基于强化学习和知识蒸馏的多类别不平衡故障分类方法 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN113222035A true CN113222035A (zh) | 2021-08-06 |
CN113222035B CN113222035B (zh) | 2021-12-31 |
Family
ID=77093626
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202110549644.6A Active CN113222035B (zh) | 2021-05-20 | 2021-05-20 | 基于强化学习和知识蒸馏的多类别不平衡故障分类方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN113222035B (zh) |
Cited By (3)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN114219146A (zh) * | 2021-12-13 | 2022-03-22 | 广西电网有限责任公司北海供电局 | 一种电力调度故障处理操作量预测方法 |
CN114638336A (zh) * | 2021-12-26 | 2022-06-17 | 海南大学 | 聚焦于陌生样本的不平衡学习 |
CN115908955A (zh) * | 2023-03-06 | 2023-04-04 | 之江实验室 | 基于梯度蒸馏的少样本学习的鸟类分类系统、方法与装置 |
Citations (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US20170032276A1 (en) * | 2015-07-29 | 2017-02-02 | Agt International Gmbh | Data fusion and classification with imbalanced datasets |
CN108875772A (zh) * | 2018-03-30 | 2018-11-23 | 浙江大学 | 一种基于堆叠稀疏高斯伯努利受限玻尔兹曼机和强化学习的故障分类模型及方法 |
CN108875771A (zh) * | 2018-03-30 | 2018-11-23 | 浙江大学 | 一种基于稀疏高斯伯努利受限玻尔兹曼机和循环神经网络的故障分类模型及方法 |
CN111598216A (zh) * | 2020-04-16 | 2020-08-28 | 北京百度网讯科技有限公司 | 学生网络模型的生成方法、装置、设备及存储介质 |
CN112819159A (zh) * | 2021-02-24 | 2021-05-18 | 清华大学深圳国际研究生院 | 一种深度强化学习训练方法及计算机可读存储介质 |
-
2021
- 2021-05-20 CN CN202110549644.6A patent/CN113222035B/zh active Active
Patent Citations (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US20170032276A1 (en) * | 2015-07-29 | 2017-02-02 | Agt International Gmbh | Data fusion and classification with imbalanced datasets |
CN108875772A (zh) * | 2018-03-30 | 2018-11-23 | 浙江大学 | 一种基于堆叠稀疏高斯伯努利受限玻尔兹曼机和强化学习的故障分类模型及方法 |
CN108875771A (zh) * | 2018-03-30 | 2018-11-23 | 浙江大学 | 一种基于稀疏高斯伯努利受限玻尔兹曼机和循环神经网络的故障分类模型及方法 |
CN111598216A (zh) * | 2020-04-16 | 2020-08-28 | 北京百度网讯科技有限公司 | 学生网络模型的生成方法、装置、设备及存储介质 |
CN112819159A (zh) * | 2021-02-24 | 2021-05-18 | 清华大学深圳国际研究生院 | 一种深度强化学习训练方法及计算机可读存储介质 |
Non-Patent Citations (3)
Title |
---|
AVRAAM TSANTEKIDIS,ET AL: "《Diversity-driven knowledge distillation for financial trading using Deep Reinforcement Learning》", 《NEURAL NETWORKS》 * |
ZHANG WEI HONG,ET AL: "《Periodic Intra-Ensemble Knowledge Distillation for Reinforcement Learning》", 《ARXIV:2002.00149V1》 * |
王锋涛: "《地铁车辆轮对踏面故障诊断系统研究与开发》", 《中国优秀博硕士学位论文全文数据库(硕士) 工程科技Ⅱ辑》 * |
Cited By (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN114219146A (zh) * | 2021-12-13 | 2022-03-22 | 广西电网有限责任公司北海供电局 | 一种电力调度故障处理操作量预测方法 |
CN114638336A (zh) * | 2021-12-26 | 2022-06-17 | 海南大学 | 聚焦于陌生样本的不平衡学习 |
CN114638336B (zh) * | 2021-12-26 | 2023-09-22 | 海南大学 | 聚焦于陌生样本的不平衡学习 |
CN115908955A (zh) * | 2023-03-06 | 2023-04-04 | 之江实验室 | 基于梯度蒸馏的少样本学习的鸟类分类系统、方法与装置 |
Also Published As
Publication number | Publication date |
---|---|
CN113222035B (zh) | 2021-12-31 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN113222035B (zh) | 基于强化学习和知识蒸馏的多类别不平衡故障分类方法 | |
CN110263227B (zh) | 基于图神经网络的团伙发现方法和系统 | |
Basirat et al. | The quest for the golden activation function | |
Hruschka et al. | Extracting rules from multilayer perceptrons in classification problems: A clustering-based approach | |
CN112685504B (zh) | 一种面向生产过程的分布式迁移图学习方法 | |
CN108446214B (zh) | 基于dbn的测试用例进化生成方法 | |
Li et al. | A bilevel learning model and algorithm for self-organizing feed-forward neural networks for pattern classification | |
Indira et al. | Image segmentation using artificial neural network and genetic algorithm: a comparative analysis | |
CN111079926B (zh) | 基于深度学习的具有自适应学习率的设备故障诊断方法 | |
Perez-Godoy et al. | CO 2 RBFN: an evolutionary cooperative–competitive RBFN design algorithm for classification problems | |
CN107153837A (zh) | 深度结合K‑means和PSO的聚类方法 | |
CN113627471A (zh) | 一种数据分类方法、系统、设备及信息数据处理终端 | |
Urgun et al. | Composite power system reliability evaluation using importance sampling and convolutional neural networks | |
de Campos Souza et al. | Online active learning for an evolving fuzzy neural classifier based on data density and specificity | |
Poczeta et al. | Analysis of fuzzy cognitive maps with multi-step learning algorithms in valuation of owner-occupied homes | |
CN113222034B (zh) | 基于知识蒸馏的细粒度多类别不平衡故障分类方法 | |
KR20080078292A (ko) | 영역 밀도 표현에 기반한 점진적 패턴 분류 방법 | |
Liu et al. | Prediction of share price trend using FCM neural network classifier | |
CN115906959A (zh) | 基于de-bp算法的神经网络模型的参数训练方法 | |
CN108446718B (zh) | 一种动态深度置信网络分析方法 | |
Jamsandekar et al. | Self generated fuzzy membership function using ANN clustering technique | |
Baruque et al. | Hybrid classification ensemble using topology-preserving clustering | |
Bhowan | Genetic programming for classification with unbalanced data | |
Chouikhi et al. | Adaptive extreme learning machine for recurrent beta-basis function neural network training | |
Mousavi | A New Clustering Method Using Evolutionary Algorithms for Determining Initial States, and Diverse Pairwise Distances for Clustering |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |