JP6955155B2 - 学習装置、学習方法及び学習プログラム - Google Patents

学習装置、学習方法及び学習プログラム Download PDF

Info

Publication number
JP6955155B2
JP6955155B2 JP2017200842A JP2017200842A JP6955155B2 JP 6955155 B2 JP6955155 B2 JP 6955155B2 JP 2017200842 A JP2017200842 A JP 2017200842A JP 2017200842 A JP2017200842 A JP 2017200842A JP 6955155 B2 JP6955155 B2 JP 6955155B2
Authority
JP
Japan
Prior art keywords
learning
accuracy
overfitting
data set
amount
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
JP2017200842A
Other languages
English (en)
Other versions
JP2019074947A (ja
Inventor
橋本 鉄太郎
鉄太郎 橋本
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Fujitsu Ltd
Original Assignee
Fujitsu Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Fujitsu Ltd filed Critical Fujitsu Ltd
Priority to JP2017200842A priority Critical patent/JP6955155B2/ja
Publication of JP2019074947A publication Critical patent/JP2019074947A/ja
Application granted granted Critical
Publication of JP6955155B2 publication Critical patent/JP6955155B2/ja
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Landscapes

  • Management, Administration, Business Operations System, And Electronic Commerce (AREA)
  • Image Analysis (AREA)

Description

本発明は,学習装置、学習方法及び学習プログラムに関する。
学習、特に深層学習は、ディープニューラルネットワーク(Deep Neural Network: DNN)の入力層に訓練データの入力データを入力し、DNNの計算を実行して出力層の出力データを算出し、算出された出力データと訓練データの教師データとの差分を小さくするようにDNN内の変数(重み)の値を最適化する学習を繰り返し実行する。
変数の最適化は、例えば、勾配法により行われる。勾配法では、例えば、訓練データの入力データからDNNで算出した出力データの値と、訓練データの教師データの値との差分の二乗和を示す関数が、最小になる変数の値を求める。具体的には、あるサンプル点での変数xiから、関数fの勾配に学習率εを乗じた値を減じることで次のサンプル点xi+1の変数を求める。差分の二乗和が最小になると、出力データの値の精度(Accuracy)は最大になる。
一般に、ある学習率で上記の学習を行い、ある程度学習が進んだら学習率を減衰させ、減衰させた新たな学習率で学習を再開することを繰り返す。そして、学習率をある回数減衰し学習が進まなくなれば学習を終了する。
深層学習の終了方法については、以下の文献に記載されている。
特開2015−11510号公報 特開2017−16414号公報
"Automatic early stopping using cross validation: quantifying the criteria," Neural Networks 11 (1998) 761-767 https://www.tensorflow.org/get_started/monitors#early_stopping_with_validationmonitor https://keras.io/ja/callbacks/#earlystopping
一方、深層学習には、学習期間が長い、つまり学習量が多すぎることに起因して、訓練データセットの特定のランダムな特徴に過剰に適合する過学習(過剰適合(overfitting)ともいう。)の問題がある。過学習になると、例えば、白っぽい自動車を含む訓練データセット(Training set)で学習した結果、黒っぽい自動車を含む検証データセット(Validation set)に対して正しく車種を認識できなくなる。
そのため、学習量が多すぎて過学習になると、変数を最適化した学習モデルが汎化できない状態になり、訓練データセットに対しては精度が向上するが、訓練データに含まれなかった未知のデータである検証データセットに対しては逆に精度が悪くなる。したがって、過学習は無駄な学習といえる。逆に、学習量が少なすぎると、学習量が不十分であり高い精度を得ることができない。
そこで,本開示の第1の側面の目的は,過学習による無駄な学習を減らし、精度を劣化させずに早期に学習を終了する学習装置、学習方法及び学習プログラムを提供することにある。
本開示の第1の側面は,訓練データセットについて学習器で学習を行い、検証データセットについて精度を算出する学習部と、
前記学習部による前記精度に基づいて、過学習状態を検出する検出部と、
前記学習部による前記精度に基づいて、学習の収束状態を判定する判定部と、
前記検出部が前記過学習状態を検出した場合、前記学習部による学習率を変更して再び学習させるとともに、前記判定部が前記学習部による学習が収束したと判定した場合、前記学習部による学習を停止させる制御部と、を有する学習装置である。
第1の側面によれば,精度を劣化させずに早期に学習を終了することができる。
本実施の形態における学習装置の構成例を示す図である。 プロセッサが学習プログラムを実行することで実現されるDNNの一例を示す図である。 深層学習の学習率に対する、学習量と精度の変化の例を示す図である。 過学習を説明する図である。 過学習の別の例を説明する図である。 本実施の形態における学習装置の構成を示す図である。 本実施の形態における学習方法または学習プログラムの処理を示すフローチャート図である。 図7の学習と検証処理S12のフローチャート図である。 過学習を検出するための過学習判定閾値と精度の低下状態との関係を示す。 図7の最高精度のサンプル点imaxの取得について説明する図である。 本実施の形態の学習方法で学習した精度曲線の一例を示す図である。 第1の比較例の精度曲線を示す図である。 第2の比較例の精度曲線を示す図である。 第3の比較例の精度曲線を示す図である。 各学習率での学習量を一定(20万エポック(サンプル点))にして学習した例を示す図である。 本実施の形態により学習した例を示す図である。
図1は、本実施の形態における学習装置の構成例を示す図である。学習装置1は、情報処理装置、またはコンピュータである。学習装置1は、プロセッサであるCPU(Central Processing Unit)10、CPUがアクセス可能なメインメモリ12、グラフィックプロセッサ14、グラフィックプロセッサがアクセスするGPUメモリ16、外部ネットワークとのインターフェース18、内部バス28を有する。グラフィックプロセッサ14は、例えば、画像を入力データとするDNNに多く含まれる畳込み演算で必要な積和演算を並列に且つ高速に実行するプロセッサである。
但し、本実施の形態は、グラフィックプロセッサ14とGPUメモリ16を設けず、グラフィックプロセッサの演算をプロセッサ10で実行し、GPUメモリ16内に記憶するデータをメインメモリ12に記憶する構成であってもよい。
学習装置1は、ハードディスクやSSD(Solid State Device)などの大容量の補助記憶装置20,22,24,26を有し、補助記憶装置には、DNNの学習プログラム20と、学習に使用する訓練データセット22と、学習に使用する検証データセット24と、学習プログラムが実行されて算出される学習モデルの精度データ26とが格納される。訓練データセットと検証データセットは、共に、入力データとDNNが出力すべき教師データとを含む。
学習装置1には、インターネットやイントラネットなどのネットワークNWを介して、クライアント端末装置30,32がアクセス可能である。クライアント端末装置30,32は、学習装置1にアクセスし、クライアントが準備した訓練データセット22と検証データセット24について、プロセッサ10に学習プログラム20を実行させる。
プロセッサ10は、学習プログラムを実行し、訓練データセット22の入力データに対するDNNの出力データを算出し、その出力データと訓練データセットの教師データとの差分が最小になるようにDNNの変数を最適化する。さらに、プロセッサ10は、学習プログラムを実行し、検証データセット24の入力データに対するDNNの出力データを算出し、その出力データと検証データセットの教師データとの差分に基づき、検証データセットに対する精度を出力する。そして、プロセッサは、学習プログラムを実行し、精度データ26に基づいて、後述するように学習方法を制御し、DNNの変数を最適化した学習モデルを、少ない学習量で生成する。
図2は、プロセッサが学習プログラムを実行することで実現されるDNNの一例を示す図である。このDNNは、例えば画像データを入力とする入力層INPUTと、入力層に入力される画像データに対してフィルタの係数(または重み値)に基づいて畳込み演算を行う複数の畳込みCNV1,CNV2と、畳み込み層で算出したデータから局所的なノードの最大値を抽出するマックスプーリング層MP1, MP2と、全結合層FC1, FC2と、全結合層FC2に接続される出力層OUTPUTとを有する。
入力層INPUTは、複数の入力ノードを有し、それぞれの入力ノードに例えば画像データの画素データが入力される。図中、入力層INPUTは画像を模擬的に示す1つの矩形だが、実際は画像データの画素データが入力される複数のノードである。畳込み層CNV1, CNV2やマックスプーリング層MP1, MP2も同様である。
入力層INPUTの複数の入力ノードと、畳込み層CNV1の複数のノードとの間は、それぞれ重み値(DNNの変数)を有するエッジで結ばれる。例えば、複数の入力ノードに入力された画素データと各エッジの重み値とで積和演算され、畳込み層CNV1の複数のノードが有する活性化関数により各ノードの値が出力される。上記の各エッジの重み値は、前述のフィルタの係数(または重み値)に対応する。
畳込み層CNV1の複数のノードと次のマックスプーリング層MP1の複数のノードも、上記と同様に、それぞれ重み値を有するエッジで結ばれ、畳込み層CNV1のノードの値とエッジの重み値との積和演算と、マックスプーリング層MP1の各ノードの活性化関数の演算が実行され、マックスプーリング層MP1の各ノードに値が出力される。他の畳込み層やマックスプーリング層も同様である。
全結合層FC1,FC2は、前の層のノードと自分の層のノードが全てエッジで結ばれる。全結合層のノードの値の計算も、前の層のノードの値とエッジの重み値とによる積和演算と、自分のノードの活性化関数とにより行われる。
出力層OUTPUTの複数のノードには、例えば、入力される画像に認識対象の画像が含まれる確率を示す確率ベクトルが出力される。例えば、認識対象の画像が100種類の場合、出力層は100の出力ノードを有する。そして、各出力ノードには、入力画像内に認識対象画像が含まれている確率が出力される。確率ベクトルは、全確率の合計が1になるベクトルであり、入力画像にどの認識対象画像が含まれるかを示す特徴ベクトルである。
図2のDNNの場合、訓練データセットは、複数の入力画像である入力データと、それぞれの入力画像に対する確率ベクトルである教師データとを有する。また、評価データセットも、訓練データと同様に、複数の入力画像である入力データと、それぞれの入力画像に対する確率ベクトルである教師データとを有する。但し、評価データセットの評価データは、訓練データセットの訓練データと重複しない。これにより、訓練データセットで学習して変数が最適化されたDNNに対し、検証データセットの入力データで算出したDNNの出力データと検証データセットの教師データとの差分で、精度が評価される。
図2のDNNの学習では、学習装置のプロセッサが学習プログラムを実行して、例えば以下の処理を行う。
(1)学習工程
まず、プロセッサが、D個の訓練データの入力データについて、DNNの入力層から出力層に向かってそれぞれ定義された演算を実行し、出力層に出力される出力データを算出する。D個は、例えば学習装置のコンピュータが一度に並列演算できる訓練データの数であり、バッチ数と呼ばれる。
次に、プロセッサが、算出した出力データと訓練データの教師データとの差分の二乗和を算出する。これが、前述した関数fの値である。そして、前述の勾配法により、関数fの傾き(∂f/∂x)に学習率εを乗じた値を現在のDNNの変数(重み値)xiから減じて、新たな変数xi+1を算出する。すなわち、xi+1 = xi - ε*∂f/∂xである。
(2)検証工程
上記の(1)を所定回数(A回)繰り返した後、検証データの入力データについてDNNの演算を実行して出力データを算出し、検証データの教師データとの差分の二乗和に基づいて、精度を算出する。
(3)ある学習率εで上記の(1)(2)を所定回数(B回)繰り返したら、学習率εを減衰させ、再度(1)(2)を所定回数(B回)繰り返す。減衰させた学習率εで(1)(2)を所定回数(B回)繰り返すことを、予め決められた回数(C回)行って、つまり、C個の学習率について繰り返し、学習を終了する。
図2のDNNは、一例であり、本実施の形態が適用される学習モデルは他のDNNでも良い。
図3は、深層学習の学習率に対する、学習量と精度の変化の例を示す図である。横軸が学習量、縦軸が精度に対応する。横軸の目盛は、前述の学習処理の(1)(2)により精度が出力される単位であるエポック(Epoch)を示し、各エポック(またはサンプル点)に対して、精度がプロットされている。学習量は、学習した訓練データの総計である。したがって、上記の(1)での入力データ数がD個であれば、1エポックの学習量はD個*A回となり、1つの学習率εで行う学習量は、B回のエポック分であるので、D個*A回*B回となり、更に、全学習量は、D個*A回*B回*C回となる。
図3の例では、1個の学習率εで行うエポック数は30回(=B回)であり、3個(=C回)の学習率εに対して学習を繰り返し行っている。3個の学習率εは1/10ずつ減少している。上記(1)学習工程の訓練データの個数(D個)と、上記(1)学習工程を繰り返す回数(A回)によって、1エポックの学習量(D個*A回)が異なる。そこで、横軸は、総訓練データ数である学習量とエポックEpochに対応する。
図3の例では、
(a)ある学習率ε=0.01で複数のエポック数分、訓練データによる学習工程(1)の繰り返し(A回)と検証データによる検証工程(2)を繰り返す間に、精度が最初は急上昇し、その後徐々に上昇する。
(b)エポック数が30回(B回)に達すると、図3の例では学習率εを10分の1(1/10)に減少してε=0.001とし、再度学習(1)と検証(2)を繰り返す。ε=0.001での精度は、最初急上昇したあと少し減少している。この減少が過学習状態OFに対応する。
(c)同様に、エポック数が30回(B回)に達すると、学習率εを更に10分の1(1/10)に減少してε=0.0001にし、再度学習(1)と検証(2)を繰り返す。
図3に拡大して左側に示したとおり、学習率εが大きいε=0.01では、精度の変動幅が大きい。図3の左側に拡大して示されている。これは、学習率が大きいため、変数の更新幅が大きくなり、精度の変動幅が大きくなるからである。一方、学習率が減少してε=0.001、0.0001になると、精度の変動幅は小さくなっている。
図4は、過学習を説明する図である。深層学習を含む機械学習では、学習期間(学習量)が長すぎると(多すぎると)、DNNである学習モデルが訓練データセットの特定のランダムな特徴にまで過剰に適合してしまう過学習が発生する。過学習は過剰適合(Overfitting)とも呼ばれている。過学習の例としては、前述したとおり、白っぽい自動車を含む訓練データセット(Training set)で学習した結果、黒っぽい自動車を含む検証データセット(Validation set)に対して正しく車種を認識できなくなるなどである。
過学習状態になると、DNNである学習モデルが汎化できていない状態になる。その結果、図4に示すとおり、訓練データセットの精度は実線のように向上するが、訓練データセットとは異なり、学習モデルには未知のデータである検証データセットに対する精度は、破線のように逆に悪くなる(低下する)。その結果、訓練データセットの精度と検証データセットの精度との差であるロス(Loss)が拡大する。図4の例では、訓練データセットによる学習量がE0の時点で、検証データセットに対する精度が最大になり、その後徐々に低下している。
一方、図4から分かるとおり、学習期間(学習量)が短い場合は、学習回数が不十分のため、訓練データセットの精度も検証データセットの精度も十分に向上していない。
したがって、十分な学習量まで訓練データセットによる学習と検証データセットによる検証を繰り返し、過去の検証データセットの精度の変化をチェックして過学習が検出されれば、過去の検証データセットの精度が最大のエポックでの変数を設定して、学習モデルを完成させるのが望ましい。但し、過学習による精度の低下を見極めるためには、長期にわたり学習と検証を繰り返し、過学習開始直前のサンプル点の変数を最適化変数と判定する必要がある。この場合、過学習状態の学習は無駄な学習になってしまう。
図5は、過学習の別の例を説明する図である。図5にも訓練データセットの精度(実数)と検証データセットの精度(破線)とが示される。過学習は、前述したとおり、訓練データセットでの精度は上昇を続けているが、検証データセットでの精度が上昇から下降に転じて下降し続ける現象である。
しかし、図5に示すとおり、学習量E1では、検証データセットの精度がピークになりその後下降しているが、その後再度上昇し、学習量E2でピークとなっている。さらに、検証データセットの精度が、学習量E2でピークとなった後下降し、その後再度上昇し、学習量E3で再度ピークとなっている。そして、その後は、検証データセットの精度が長期間にわたり下降をし続けている。
このように、検証データセットの精度は、下降と上昇を繰り返す場合があり、過学習を判定するのは単純ではない。図5の例の場合、学習量E4まで学習を継続し、学習量E3からE4まで長期にわたり検証データセットの精度が低下したことで真の過学習に入ったと判断し、学習を終了する。そして、過去の検証データセットの精度が最大ピークとなった学習量E3での変数を設定して、学習モデルを完成するのが望ましい。但し、その場合学習量E4まで学習を継続するため、学習量E3-E4の間の学習は無駄になる。
[本実施の形態の説明]
図3に戻り、同じ学習率εでの学習と検証を一定の学習量行うことを、学習率を減少しながら、繰り返す場合、モデルのDNNの構成と、訓練データセット及び検証データセットに依存して、ある学習率で過学習が発生し始める学習量が異なる。
その結果、次のような現象が想定される。
(1)現象1:各学習率での学習量が多すぎると、それぞれの学習率で過学習が発生してしまい、目標とする検証データセットでの精度に達するまで学習期間(学習量)が長くなる(多くなる)。
(2)現象2:各学習率での学習量が多すぎると、それぞれの学習率で過学習が発生し、検証データセットでの精度が低下したまま、次の学習率での学習が再開され、最終的に達する検証データセットでの精度が、目標とする精度に達しない。
(3)現象3:各学習率での学習量が少なすぎると、それぞれの学習率での検証データセットでの精度が十分に向上する前に、次の学習率での学習が再開され、最終的に達する検証データセットでの精度が、目標とする精度に達しない。
そこで、本実施の形態では、学習装置は、各学習率での学習量を一定にせず、各学習率での学習と検証を繰り返す中で、所定の学習量の間(または所定の学習期間)検証データセットの精度が低下傾向にあることを検出すると、学習率を更新、例えば学習率を減衰させ、その学習率での学習と検証を再開する。所定の学習量の間(または所定の学習期間)検証データセットの精度が低下傾向にあることは、過学習が起こって精度が低下していることを判定することに対応する。
そして、学習装置は、好ましくは、所定の学習量の間(または所定の学習期間)精度が低下傾向にあることを検出するまでの過去の最大の精度の変数から、更新後の学習率での学習と検証を再開する。過学習により精度が低下した学習モデルの変数は適切でないからである。
また、学習装置は、好ましくは、検証データセットの精度が収束したら学習を終了する。この精度の収束の判定は、例えば、精度が低下傾向にあることを検出したタイミングで行う。
さらに、学習装置は、好ましくは、検証データセットの精度の変化量が大きいので、精度曲線を移動平均した移動平均線に変換し、移動平均線について、上記の所定の学習量の間にわたり検証データセットの精度が低下傾向にあることを検出する。特に好ましくは、学習率が大きい場合検証データセットの精度の変化量が大きいので、時間平均することで実質的にローパスフィルタを通過させ、高周波成分の変化を平滑化した移動平均線に変換する。検証データセットの精度の変化量が大きいことは、図3で拡大して示したとおりである。
上記の精度が低下傾向にあることの検出の条件は、例えば、検証データセットの精度の移動平均線における連続N個(Nは複数)の精度の変化量の平均が過学習判定閾値未満になることである。連続N個(Nは複数)の精度の変化量の平均をチェックすることは、精度の傾きをチェックすることである。
連続N個(Nは複数)の精度の変化量の平均が過学習判定閾値未満になることは、長期的に見て、学習により精度がまだ改善(上昇)しているのか、または、過学習が起こって精度が悪化(低下)しているのかを判定することである。過学習判定閾値を正に設定すると、前者の、学習により精度がまだ改善(上昇)しているのかの判定を行うことができ、ゼロまたは負に設定すると、後者の、過学習が起こって精度が悪化(低下)しているかの判定を行うことができる。
さらに、好ましくは、精度が低下傾向にあることの検出の条件に、最終サンプル点での精度の変化量が負であることを加える。
上記の精度の収束の判定の条件は、例えば、検証用データセットの連続L個(Lは複数)の精度の変化量の二乗平均平方根が収束判定閾値未満になることである。L個(Lは複数)の精度の変化量の二乗平均平方根は、精度の変動量に対応する。連続L個(Lは複数)の精度の変化量の二乗平均平方根が収束判定閾値未満になることは、精度が飽和したことをチェックすることである。学習が十分に進み、かつ、学習率が小さくなってくると、精度が飽和し、精度が変動しなくなる。飽和していない間は、精度が上昇したり下降したりを繰り返し、精度の変動量が大きい。
この条件が満たされると、学習装置は、これ以上学習により精度が改善しないと判定し、学習を終了させる。
本実施の形態によれば、各学習率での学習量を最適な量に(学習期間を最適な期間に)適宜制御することができ、短い学習期間(少ない学習量)で目標とする精度に達することができる。
図6は、本実施の形態における学習装置の構成を示す図である。学習装置は、訓練データセット22の訓練データについて学習を行い、検証データセット24の検証データについて精度を算出する学習部41を有する。さらに、学習装置は、学習部41が算出した精度に基づいて、過学習状態を検出する検出部42と、学習部41が算出した精度に基づいて、学習の収束状態を判定する判定部43とを有する。
そして、学習装置は、検出部42が過学習状態を検出した場合、学習部41による学習率を変更して再び学習部に学習させるとともに、判定部43が学習部による学習が収束したと判定した場合、学習部による学習を停止させる制御部40を有する。
図7は、本実施の形態における学習方法または学習プログラムの処理を示すフローチャート図である。学習装置のプロセッサは、学習プログラムを実行して、以下の処理を実行する。
プロセッサは、まず、初期値の設定として、学習率ε、過学習判定閾値Δth、収束判定閾値δthを設定する(S10)。そして、プロセッサは、訓練データセットと検証データセットを利用して深層学習を開始する(S11)。プロセッサは、学習では、訓練データセットによる学習と検証データセットによる検証を実行し(S12)、検証で算出した検証データセットの精度の所定の学習量の期間にわたる低下傾向があるか否かに基づいて、過学習状態を検出する(S13)。
上記の学習と検証工程S12は、前述のエポックEpochの1回分に対応する。
過学習状態が検出されない場合(S13のNO)、プロセッサは、学習と検証工程S12を繰り返す。過学習状態が検出されると(S13のYES)、プロセッサは、学習を一旦停止し(S14)、検証データセットの精度が収束しているか否かを判定する(S15)。
プロセッサは、検証データセットの精度が収束していないと判定すると(S15のNO)、学習率εと過学習判定閾値Δthを減衰して更新する(S16)。さらに、プロセッサは、更新前の学習率での検証データセットの最高精度のサンプル点imaxを取得する(S17)。最高精度のサンプル点imaxとは、更新前の学習率での検証データセットの精度の曲線の複数のエポックEpochの点(サンプル点)のうち、最高精度の点である。そして、プロセッサは、更新した学習率ε、過学習判定閾値Δthを設定し、学習を再開するDNNの変数を工程S17で取得したサンプル点imaxの変数に設定し(S18)、学習を再開する(S11)。
一方、プロセッサは、検証データセットの精度が収束していると判定すると(S15のYES)、最後の学習率での検証データセットの精度の曲線の複数のエポックEpochの点(サンプル点)のうち、最高精度の点imaxの変数を設定して(S19)、学習を終了する。
次に、図7の学習と検証処理S12を説明し、その後、図7の過学習検出処理S13と、精度の収束検出処理S15について詳細に説明する。
[学習と検証処理S12]
図8は、図7の学習と検証処理S12のフローチャート図である。前述のとおり、図8の学習と検証処理S12は、1エポックEpochでの処理に対応する。学習と検証処理では、プロセッサは、学習プログラムを実行して、以下の処理を実行する。
プロセッサは、D個の訓練データセットの入力データについて、DNNの演算を実行し、出力データを算出する(S121)。この出力データは、DNNの現在の変数xiに基づいて算出される。そして、プロセッサは、算出した各出力ノードの出力データと訓練データセットの教師データとの差分の二乗和を算出し、差分の二乗和に基づいてDNNの新たな変数xi+1を算出する。
この新たな変数への更新では、例えば、誤差逆拡散法に従い、各出力ノードの値(出力データの値)と訓練データセットの教師データとの差分をDNNの入力ノードに向かって逆拡散し、各層の複数のノードでの差分を小さくするように前段の層の複数のノードとの間のエッジの変数を最適化する。
上記のDNNの演算では、前段の層の複数のノードの値とエッジの重み値との積和演算と、積和演算結果を入力とする後段の層のノードの活性化関数の演算などが含まれる。そこで、学習装置のGPUによる積和演算能力、例えば並列演算数、に基づいて、GPUが一度に処理できる最大数に前述の訓練データセットの数D個が設定される。このD個はバッチ数とも呼ばれる。
プロセッサは、上記の訓練データセットの入力データに対するDNNの演算S121と、変数の更新S122とを、予め決められたA回繰り返す(S123)。工程S121-S123が1つのエポックでの学習ステップである。したがって、前述のとおり、1つのエポックでの学習量は、バッチ数D個と繰り返し回数A回の積(D*A)である。
次に、プロセッサは、学習ステップで最適化された変数のDNNにより、検証データセットの1つの又は少数の検証データの入力データについて、DNNの演算を実行して出力ノードの出力データを算出する(S124)。そして、プロセッサは、検証データの入力データから算出した出力データの値と検証データの教師データとの差分に基づいて、検証データによる精度を算出する(S125)。
精度は、最大精度1.0から上記の差分の二乗和の平均値の平方根(二乗和平均平方根)を減じて求められる。例えば、前述の入力データを画像の画素データとし、出力ノードの出力データを入力画像に含まれる認識対象画像が存在する確率ベクトルと仮定する。この場合、検証データの入力データから算出した出力データの値は確率値(0.0〜1.0)であり、一方、教師データの値は、入力画像に含まれる認識対象画像の出力ノードでは最大確率値1.0となり、入力画像に含まれない認識対象画像の出力ノードでは最小確率値0.0となる。よって、差分の二乗和平均平方根は、確率の誤差であり0.0〜1.0の値である。そして、精度は、最大精度1.0から差分の二乗和平均平方根を減じることで算出される。
上記の工程S124,S125が検証ステップである。
[過学習検出処理S13]
過学習検出処理S13では、プロセッサは、以下の演算により検証データセットの精度の所定の学習量の期間にわたる低下傾向があるか否かを判定する。
Figure 0006955155
ここで、y(i)はサンプルiでの検証データセットの精度である。
プロセッサは、式1により、現在のサンプルiから過去のM-1個のサンプルでの精度の合計
y(i)+y(i-1)+y(i-2)+…+y(i-(M-1))をサンプル数Mで除して、現在のサンプルiから過去M-1個のサンプルの精度の移動平均値ΦM(i)を算出する。
次に、プロセッサは、式2、式3-1により、検証データセットの精度の移動平均線における連続N個の精度の変化量の平均値(式3-1の左辺)を算出する。すなわち、式2によるΔiがサンプルiとi-1との間の精度の変化量である。さらに、プロセッサは、式3-1の左辺により、サンプルiから過去のN-1個のサンプルでの精度の変化量の合計
Δi+Δi-1+Δi-2+…+Δi-(N-1)をサンプル数Nで除して、検証データセットの精度の移動平均線における連続N個の精度の変化量の平均値を算出する。
そして、プロセッサは、式3-1の不等号式に基づいて、精度の移動平均線における連続N個の精度の変化量の平均が、過学習判定閾値Δth未満か否か判定する。この判定では、上記の連続するN個の精度の変化量の合計が、別の過学習判定閾値未満かを判定してもよい。その場合、過学習判定閾値ΔthはN倍にされる。
過学習判定閾値Δthは、正、負のいずれでもよい。前述のとおり、過学習判定閾値を正に設定すると、精度が未だ上昇過程にあるか否かを判定できる。また、過学習状態では、精度が低下する傾向を示すので、過学習判定閾値Δthを、例えば、ゼロ、または負の値に設定すると、過学習状態にあるか否かを判定できる。
また、過学習を判定するための連続N個の精度の変化量の平均値でのN個は、上記の移動平均を求める場合のM個より十分に大きい。つまり、N>Mである。
図9は、過学習を検出するための過学習判定閾値と精度の低下状態との関係を示す。実線が訓練データセットの精度、破線が検証データセットの精度である。検証データセットの精度は、3種類の過学習状態OF1,OF2,OF3が示される。3種類の過学習状態の傾きはOF1>OF2>OF3の順に大きい。例えば、学習開始時の学習率εが大きい場合は、精度の変動幅が大きくなり、過学習状態での精度の低下の程度が大きくなり、一方、学習の終了時での学習率εが小さい場合は、精度の変動幅が小さく、過学習状態での精度の低下の程度は小さくなる。したがって、学習率が大きい場合、過学習判定閾値をΔth = -Yに、次に学習率が大きい場合、Δth = -X (X<Y)に、学習率が最小の場合、Δth = 0に設定することの好ましい。
上記の理由から、図7のS17では、学習率を減衰するときに同時に過学習判定閾値Δthも減衰させて更新している。
過学習検出処理でのNは、図5に示したように精度が上下した後に低下し続ける過学習状態を検出するために適切な値が選択される。経験的には、訓練データセットのデータ数をNdとすると、学習量が2*NdになるようにNを設定するのが過学習判定に適切な最小のNである。すなわち、全訓練用データセットを少なくとも2回学習した場合の精度の傾向が低下傾向にあれば過学習状態と判定することで、図5の精度が上下した後の低下し続ける過学習状態を検出できる。Nを大きく設定すれば過学習状態を確実に検出できるが、その場合は学習量が多くなり無駄な学習が発生するリスクが有る。
前述したとおり、バッチサイズをD個での学習をA回繰り返す毎に、検証データセットで精度を算出しているので、連続N個のサンプル点での精度の変化量の平均での学習量は、D*A*Nであるので、以下の式を満たす最小Nを設定することが好ましい。
D*A*N≧2*Nd
N≧2*Nd/(D*A)
但し、N>M
上記の代替案として、過学習検出処理S13で、プロセッサは、上記の式3-1に代えて、以下の式3-2で過学習の発生を判定してもよい。
Figure 0006955155
式3-2は、式3-1の条件に、Δi<0の条件を加えている。すなわち、代替の過学習の発生の判定では、精度の移動平均線における連続N個の精度の変化量の平均が、過学習判定閾値Δth未満か否かに加えて、最後のサンプルiでの精度が前サンプルiの精度より低下しているか否かが判定される。この条件を加えることで、精度が再度上昇した場合は過学習の発生が検出されない。または、最後の所定の数(複数)のサンプルでの精度がすべて前サンプルの精度より低下しているか否かの条件を加えるようにしてもよい。
このように、プロセッサは、図7において過学習を検出すると(S13のYES)、精度が収束していなければ(S15のNO)、学習率を下げて学習を再開する。このとき、過学習判定閾値も学習率の減衰の程度に対応して下げる。
[精度の収束検出処理S15]
次に、学習の終わりを判定する精度の収束検出処理S15について詳述する。過学習検出処理S13では、プロセッサは、以下の演算により、検証データセットの精度が収束しているか否かを判定する。
Figure 0006955155
ここで、y(i)は、前述と同様に、サンプルiでの検証データセットの精度である。
プロセッサは、式4により、検証データセットのサンプルiとi-1との間の精度の変化量δiを算出する。さらに、プロセッサは、式5により、現在のサンプルiから連続する過去N-1個のサンプル(i-1)〜(i-(N-1))、つまり連続するN個のサンプル、それぞれの精度の変化量δi〜δi-(N-1)の二乗平均の平方根(式5の左辺)が、収束判定閾値δth未満か否か判定する。
上記式5の左辺のNは、精度の収束の判定の連続N個の精度の変化量の二乗平均平方根を意味するが、このN個は、過学習の判定の連続N個と同じである。但し、精度の収束判定における連続N個は、過学習の判定の連続N個と異なってもよい。
[過学習の判定S13と収束の判定S15のタイミング]
図7によれば、過学習が検出されると(S13のYES)、一旦深層学習を停止し、精度の収束の判定(S15)が行われる。つまり、過学習と収束が同時期に検出されると、学習が終了する。
例えば、学習開始時は、学習率が大きいので、過学習が検出されても、精度の収束が検出されることはない。一方、学習が進捗し、学習率が小さくなると、精度の収束が検出されやすくなる。そのため、学習率が小さいサイクルで、連続N個のサンプルの精度の変化量の平均が過学習判定閾値Δth未満になって過学習が検出されるとともに、同じ連続N個のサンプル点の精度の変化量の二乗平均平方根が収束判定閾値δth未満になって収束が検出されることがある。
具体的に言えば、学習率の減衰が進むにつれて、過学習は検出されるが収束は検出されない状況から、最後は、過学習が検出されると共に収束も検出される状況に変化する。この時、学習が終了する。一般に、収束状態は、過学習が発生する前の最高精度近辺での連続N個のサンプル点で発生すると、その後の過学習が発生している連続N個のサンプル点でも発生する。したがって、本実施の形態では、過学習が検出されてから(S13のYES)、収束を判定している(S15)。
過学習の判定S13と収束の判定S15の両方を、学習と検証の処理S12を実行する度に行っても良い。但し、その場合、学習率の減衰が進んだところで、収束は検出されるが未だ過学習は検出されない状況の後に、収束と過学習が同時に判定される状況になることが予測される。その場合、学習と検証の処理S12の度に行う収束の判定が無駄になる。したがって、図7のように、過学習を検出したときに学習を停止して収束を検出するようにするのが効率的である。
[更新前の学習率での検証データセットの最高精度のサンプル点imaxの取得(S17)]
図10は、図7の最高精度のサンプル点imaxの取得について説明する図である。図10には、3つの学習率ε1、ε2、ε3での検証データセットの精度曲線が示され、それぞれの学習率での学習で過学習OFが検出されている。図4,5などに示したとおり、過学習が発生すると検証データセットの精度曲線が低下傾向を示す。そこで、プロセッサは、過学習が検出された後、過去のサンプルの中で最高精度のサンプルimaxを取得し、その最高精度のサンプルでの変数で学習を再開する。これにより、学習が終了時の精度をできるだけ高くすることができる。
[本実施の形態の精度曲線]
上記の通り、本実施の形態の学習では、プロセッサは、ある学習率εと過学習判定閾値Δthと収束判定閾値δthを設定し、設定した学習率で訓練データセットによる学習と検証データセットによる検証とを繰り返しながら、各サンプル点(各エポック)で過学習状態に入ったか否か判定する。過学習状態に入ったことを検出すると、プロセッサは、学習率と過学習判定閾値とを減衰して更新し、最大精度サンプル点でのDNNの変数で、再度上記の学習と検証を再開する。さらに、過学習状態の検出とは独立して、精度が収束したか否かの判定を行い、収束したと判定されると学修を終了する。
図11は、本実施の形態の学習方法で学習した精度曲線の一例を示す図である。横軸がエポック、縦軸が精度である。これによれば、学習率ε0, ε1, ε2, ε3(ε0>ε1>ε2>ε3)それぞれでの学習と検証工程で、過学習状態になったか否かの判定を行い、学習量に対応するエポックE14で精度AC10に達している。過学習状態になったことを検出すると学習率を減衰させて次の学習と検証の繰り返し工程に移行させるので、各学習率での学習量E11-E10、E12-E11、E13-E12、E14-E13は一定ではない。
図11の例では、過学習状態になったことを検出したら学習率を減少させて学習と検証を再開させているので、各学習率での学習量(エポック数)が適切に決められ、過学習状態により学習終了までの学習量が無駄に長くなることはない。
図12は、第1の比較例の精度曲線を示す図である。第1の比較例は、前述の現象2に対応する。第1の比較例では、各学習率での学習量が多すぎるため、各学習率での学習中に過学習状態が発生し、学習終了までの学習量が無駄に多くなっている。図11での総学習量E14に対して、図12での総学習量はE24と長い。また、第1の比較例での到達精度は、過学習により精度が低下した時の変数で学習率を更新して学習を再開しているため、図11での到達精度AC10より低い。
図13は、第2の比較例の精度曲線を示す図である。第2の比較例は、前述の現象1に対応する。第2の比較例でも、各学習率での学習量が多すぎるため、各学習率での学習中に過学習状態が発生し、学習終了までの学習量が無駄に多くなっている。第2の比較例では、到達精度は図11での到達精度AC10と同程度であるが、学習終了までの学習量が図11の学習量E14より多くなっている。
図14は、第3の比較例の精度曲線を示す図である。第3の比較例は、前述の現象3に対応する。第3の比較例では、各学習率での学習量が少なすぎて、精度が最高になる前に学習率が更新されている。その結果、第3の比較例では、総学習量がE44と図11での総学習量E14より少なくなっているが、最終到達精度は図11での到達精度AC10より低い。
図11〜図14から理解できるとおり、各学習率での学習量を長期にわたって過学習状態が発生する前の適切な量に制御することで、総学習量を抑えつつ目標の到達精度に達することができる。
図15は、各学習率での学習量を一定(20万エポック(サンプル点))にして学習した例を示す図である。各学習率ε0, ε1, ε2, ε3(ε0>ε1>ε2>ε3)での学習量を固定し、エポック数E50, E51,E52, E53でそれぞれ学習率ε0, ε1, ε2, ε3に設定した結果、特に、E52-E53とE53-E54で過学習状態が長期にわたり発生し、総学習量はE54と多くなっている。
図16は、本実施の形態により学習した例を示す図である。この例では、学習率ε2, ε3での学習量E62-E63、E63-E64が、図15の例よりも特に少なくなっている。また、この例では、学習率ε1での学習量E60-E61も、図15の例のE50-E51よりも若干少なくなっている。その結果、図15と図16とでは到達精度は0.5を少し上回り同程度であるが、総学習量は、図15のE54の80万回よりも、図16のE64の50万未満と大幅に少なくなっている。
以上説明したとおり、本実施の形態によれば、各学習率での学習中に過学習が検出されたら学習率を更新して学習を再開するので、各学習率での学習量を少なくでき、無駄な過学習状態の発生を抑制できる。さらに、精度がさらに向上する前に学習率を更新することがなくなり、到達精度を高くできる。
1:学習装置
10:プロセッサ
12:メインメモリ
14:GPU
16:GPUメモリ
20:学習プログラム
22:訓練データセット
24:検証データセット
26:精度データ
40:制御部
41:学習部
42:過学習の検出部
43:収束の判定部
DNN:深層学習モデル、ディープニューロンネットワーク
OF:過学習
ε:学習率
EPOCH:エポック
Δth:過学習判定閾値
δth:収束判定閾値

Claims (9)

  1. 訓練データセットについて学習器で学習を行い、検証データセットについて精度を算出する学習部と、
    前記学習部による前記精度に基づいて、過学習状態を検出する検出部と、
    前記学習部による前記精度に基づいて、学習の収束状態を判定する判定部と、
    前記検出部が前記過学習状態を検出した場合、前記学習部による学習率を変更して再び学習させるとともに、前記判定部が前記学習部による学習が収束したと判定した場合、前記学習部による学習を停止させる制御部と、を有する学習装置。
  2. 前記検出部は、
    前記精度の複数のサンプルの移動平均線において、連続N(Nは複数)個の精度による傾きが負を示す場合、前記過学習状態を検出する、請求項1に記載の学習装置。
  3. 前記判定部は、
    複数のサンプルでの精度において、連続L(Lは複数)個の精度間の変化量が第1の閾値未満になる場合、前記収束状態と判定する、請求項1に記載の学習装置。
  4. 前記学習部はさらに、
    前記精度の複数のサンプルを収集する収集部を有する、請求項1に記載の学習装置。
  5. 前記検出部は、
    前記収集部が収集した精度の連続M個(Mは複数)のサンプルに関する移動平均線において、連続N個の精度の変化量の平均が第2の閾値未満になり、かつ、前記複数のサンプルのうち最終サンプル点での精度の変化量が負である場合、前記過学習状態を検出する、請求項4記載の学習装置。
  6. 前記制御部は、前記検出部が前記過学習状態を検出した場合、前記学習率の変更と共に前記第2の閾値を低下するよう変更して再び学習させる、請求項5に記載の学習装置。
  7. 前記判定部は、
    前記収集部が収集した精度の複数のサンプル間の変化量の二乗平均平方根が第3の閾値未満である場合、前記収束したと判定する、請求項4または5に記載の学習装置。
  8. 訓練データセットについて学習率に基づき学習器で学習を行い、検証データセットについて精度を算出し、
    前記精度に基づいて、過学習状態を検出し、
    前記精度に基づいて、学習の収束状態を判定し、
    前記過学習状態を検出した場合、前記学習率を変更して再び前記学習と前記精度の算出を行い、
    学習の収束状態を判定した場合、前記学習を停止する、処理を有する学習方法。
  9. 訓練データセットについて学習率に基づき学習器で学習を行い、検証データセットについて精度を算出し、
    前記精度に基づいて、過学習状態を検出し、
    前記精度に基づいて、学習の収束状態を判定し、
    前記過学習状態を検出した場合、前記学習率を変更して再び前記学習と前記精度の算出を行い、
    学習の収束状態を判定した場合、前記学習を停止する、処理をコンピュータに実行させ
    る学習プログラム。
JP2017200842A 2017-10-17 2017-10-17 学習装置、学習方法及び学習プログラム Active JP6955155B2 (ja)

Priority Applications (1)

Application Number Priority Date Filing Date Title
JP2017200842A JP6955155B2 (ja) 2017-10-17 2017-10-17 学習装置、学習方法及び学習プログラム

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
JP2017200842A JP6955155B2 (ja) 2017-10-17 2017-10-17 学習装置、学習方法及び学習プログラム

Publications (2)

Publication Number Publication Date
JP2019074947A JP2019074947A (ja) 2019-05-16
JP6955155B2 true JP6955155B2 (ja) 2021-10-27

Family

ID=66544168

Family Applications (1)

Application Number Title Priority Date Filing Date
JP2017200842A Active JP6955155B2 (ja) 2017-10-17 2017-10-17 学習装置、学習方法及び学習プログラム

Country Status (1)

Country Link
JP (1) JP6955155B2 (ja)

Families Citing this family (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
JP7171520B2 (ja) * 2019-07-09 2022-11-15 株式会社日立製作所 機械学習システム
FI20195682A1 (en) 2019-08-15 2021-02-16 Liikennevirta Oy / Virta Ltd CHARGING STATION MONITORING METHOD AND APPARATUS
JP2021081930A (ja) * 2019-11-18 2021-05-27 日本放送協会 学習装置、情報分類装置、及びプログラム
WO2023188286A1 (ja) * 2022-03-31 2023-10-05 日本電気株式会社 学習装置、推定装置、学習方法および記録媒体

Family Cites Families (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US6513023B1 (en) * 1999-10-01 2003-01-28 The United States Of America As Represented By The Administrator Of The National Aeronautics And Space Administration Artificial neural network with hardware training and hardware refresh
US8521670B2 (en) * 2011-05-25 2013-08-27 HGST Netherlands B.V. Artificial neural network application for magnetic core width prediction and modeling for magnetic disk drive manufacture
JP6164639B2 (ja) * 2013-05-23 2017-07-19 国立研究開発法人情報通信研究機構 ディープ・ニューラルネットワークの学習方法、及びコンピュータプログラム
JP5777178B2 (ja) * 2013-11-27 2015-09-09 国立研究開発法人情報通信研究機構 統計的音響モデルの適応方法、統計的音響モデルの適応に適した音響モデルの学習方法、ディープ・ニューラル・ネットワークを構築するためのパラメータを記憶した記憶媒体、及び統計的音響モデルの適応を行なうためのコンピュータプログラム

Also Published As

Publication number Publication date
JP2019074947A (ja) 2019-05-16

Similar Documents

Publication Publication Date Title
JP6955155B2 (ja) 学習装置、学習方法及び学習プログラム
US20190122078A1 (en) Search method and apparatus
JP4223894B2 (ja) Pidパラメータ調整装置
CN112101530A (zh) 神经网络训练方法、装置、设备及存储介质
CN111221375B (zh) Mppt控制方法、装置、光伏发电设备及可读存储介质
US20200265307A1 (en) Apparatus and method with multi-task neural network
CN114861880A (zh) 基于空洞卷积神经网络的工业设备故障预测方法及装置
JP2016018230A (ja) 制御パラメータ適合方法及び制御パラメータ適合支援装置
CN115587545B (zh) 一种用于光刻胶的参数优化方法、装置、设备及存储介质
CN111461329A (zh) 一种模型的训练方法、装置、设备及可读存储介质
CN115346125B (zh) 一种基于深度学习的目标检测方法
JP2021197108A (ja) 学習プログラム、学習方法および情報処理装置
CN117150882A (zh) 发动机油耗预测方法、系统、电子设备及存储介质
JP6560207B2 (ja) 信号を特徴付けるための方法及びデバイス
CN113986700A (zh) 数据采集频率的优化方法、系统、装置及存储介质
CN108920842B (zh) 一种潜艇动力学模型参数在线估计方法及装置
CN113408692A (zh) 网络结构的搜索方法、装置、设备及存储介质
CN117152588B (zh) 一种数据优化方法、系统、装置及介质
JP5436689B2 (ja) 混合微分代数プロセスモデルの状態変数をリアルタイムに計算する方法
JP7436830B2 (ja) 学習プログラム、学習方法、および学習装置
CN116176737B (zh) 一种车辆控制方法、装置、车辆及存储介质
CN110648021B (zh) 一种两级电力负荷预测结果协调方法、装置及设备
US20240185070A1 (en) Training action selection neural networks using look-ahead search
CN115114966B (zh) 模型的操作策略的确定方法、装置、设备及存储介质
US20220253693A1 (en) Computer-readable recording medium storing machine learning program, apparatus, and method

Legal Events

Date Code Title Description
A621 Written request for application examination

Free format text: JAPANESE INTERMEDIATE CODE: A621

Effective date: 20200709

A977 Report on retrieval

Free format text: JAPANESE INTERMEDIATE CODE: A971007

Effective date: 20210428

A131 Notification of reasons for refusal

Free format text: JAPANESE INTERMEDIATE CODE: A131

Effective date: 20210511

A521 Request for written amendment filed

Free format text: JAPANESE INTERMEDIATE CODE: A523

Effective date: 20210708

TRDD Decision of grant or rejection written
A01 Written decision to grant a patent or to grant a registration (utility model)

Free format text: JAPANESE INTERMEDIATE CODE: A01

Effective date: 20210831

A61 First payment of annual fees (during grant procedure)

Free format text: JAPANESE INTERMEDIATE CODE: A61

Effective date: 20210913

R150 Certificate of patent or registration of utility model

Ref document number: 6955155

Country of ref document: JP

Free format text: JAPANESE INTERMEDIATE CODE: R150