JPWO2020161935A1 - 学習装置、学習方法、及び、プログラム - Google Patents

学習装置、学習方法、及び、プログラム Download PDF

Info

Publication number
JPWO2020161935A1
JPWO2020161935A1 JP2020570350A JP2020570350A JPWO2020161935A1 JP WO2020161935 A1 JPWO2020161935 A1 JP WO2020161935A1 JP 2020570350 A JP2020570350 A JP 2020570350A JP 2020570350 A JP2020570350 A JP 2020570350A JP WO2020161935 A1 JPWO2020161935 A1 JP WO2020161935A1
Authority
JP
Japan
Prior art keywords
estimation
temperature
estimation unit
unit
learning
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Granted
Application number
JP2020570350A
Other languages
English (en)
Other versions
JP7180697B2 (ja
Inventor
遊哉 石井
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
NEC Corp
Original Assignee
NEC Corp
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by NEC Corp filed Critical NEC Corp
Publication of JPWO2020161935A1 publication Critical patent/JPWO2020161935A1/ja
Application granted granted Critical
Publication of JP7180697B2 publication Critical patent/JP7180697B2/ja
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N20/00Machine learning
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06VIMAGE OR VIDEO RECOGNITION OR UNDERSTANDING
    • G06V10/00Arrangements for image or video recognition or understanding
    • G06V10/70Arrangements for image or video recognition or understanding using pattern recognition or machine learning
    • G06V10/77Processing image or video features in feature spaces; using data integration or data reduction, e.g. principal component analysis [PCA] or independent component analysis [ICA] or self-organising maps [SOM]; Blind source separation
    • G06V10/776Validation; Performance evaluation

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Software Systems (AREA)
  • Evolutionary Computation (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • Computing Systems (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Health & Medical Sciences (AREA)
  • General Health & Medical Sciences (AREA)
  • Medical Informatics (AREA)
  • Multimedia (AREA)
  • Databases & Information Systems (AREA)
  • Mathematical Physics (AREA)
  • Data Mining & Analysis (AREA)
  • General Engineering & Computer Science (AREA)
  • Biomedical Technology (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Biophysics (AREA)
  • Computational Linguistics (AREA)
  • Molecular Biology (AREA)
  • Feedback Control In General (AREA)
  • Air Conditioning Control Device (AREA)

Abstract

【課題】蒸留を用いた深層学習において、適切な温度を自動的に求めることを可能とする。【解決手段】学習装置は、生徒モデルに基づき、温度パラメータを用いて推定を行う第1の推定部と、教師モデルに基づき、温度パラメータを用いて推定を行う第2の推定部と、前記第1の推定部及び前記第2の推定部が生成する推定情報に基づいて、前記温度パラメータを計算する温度計算部と、を備える。

Description

本発明は、深層学習に関し、特に蒸留に関する。
多層に接続されたニューロンを有するDNN(Deep Neural Network)を利用して言語認識や画像認識の学習を行うディープラーニングに関する技術が知られている。しかし、一般にDNNはパラメータ数と推定精度とに比例関係があり、軽量な計算量で動作するDNNは精度が悪いという問題があった。
これに対し、蒸留と呼ばれる学習方法を用いることで軽量なDNNでも推定精度を向上させることができることが知られている(例えば、特許文献1、非特許文献1を参照)。蒸留では、学習データの正解ラベルに加えて、既に学習済みのDNN(「教師モデル」と呼ぶ。)の推定結果も正解レベルとして用いる。教師モデルは一般に、これから学習させたいDNN(「生徒モデル」と呼ぶ。)よりもパラメータ数が多く高精度なDNNが選ばれる。蒸留の1つとして、「温度」と呼ばれるハイパーパラメータTを導入することにより、教師モデルの出力を生徒モデルの学習に効果的に反映させることが可能となる。
米国特許公開公報 US2015/0356461 A1
Distilling the Knowledge in a Neural Network,Geoffrey Hinton,Oriol Vinyals,Jeff Dean
しかしながら、温度の値によっては、学習後の生徒モデルの精度が蒸留を用いない場合と比べて大差なかったり、むしろ下がってしまったりすることが知られている。そのため、いくつかの異なる温度に設定して学習した生徒モデル同士の精度を比較して、最も高精度となったものを採用する、「グリッドサーチ」と呼ばれる手法も提案されている。しかし、これには多くの学習時間を要するという問題のほか、離散的な値を設定せざるを得ないために最高精度を与える温度を省いてしまう可能性があるという問題もある。最高精度を与える適切な温度は、教師モデルや生徒モデルの構造や大きさ、解く問題の種類や難易度、学習に用いる他のハイパーパラメータによって様々に変化することが実験的に知られているが、それらを定量的に評価したり、自動的に適切な温度を求めたりする手法は公開されていない。
本発明の目的は、蒸留を用いた深層学習において、適切な温度を自動的に求めることを可能とすることにある。
本発明の1つの観点では、学習装置は、生徒モデルに基づき、温度パラメータを用いて推定を行う第1の推定部と、教師モデルに基づき、温度パラメータを用いて推定を行う第2の推定部と、前記第1の推定部及び前記第2の推定部が生成する推定情報に基づいて、前記温度パラメータを計算する温度計算部と、を備える。
本発明の他の観点では、学習装置により実行される学習方法は、生徒モデルに基づき、温度パラメータを用いて推定を行い、教師モデルに基づき、温度パラメータを用いて推定を行い、前記第1の推定部及び前記第2の推定部が生成する推定情報に基づいて、前記温度パラメータを計算する。
本発明の更に他の観点では、コンピュータを備える学習装置により実行されるプログラムは、生徒モデルに基づき、温度パラメータを用いて推定を行う第1の推定部、教師モデルに基づき、温度パラメータを用いて推定を行う第2の推定部、前記第1の推定部及び前記第2の推定部が生成する推定情報に基づいて、前記温度パラメータを計算する温度計算部、として前記コンピュータを機能させる。
実施形態に係る学習装置のハードウェア構成を示すブロック図である。 実施形態に係る学習装置の基本的機能構成を示すブロック図である。 第1実施形態に係る学習装置の機能構成を示すブロック図である。 第1実施形態による学習処理のフローチャートである。 第2実施形態に係る学習装置の機能構成を示すブロック図である。 第2実施形態による学習処理のフローチャートである。 第3実施形態に係る学習装置の機能構成を示すブロック図である。 第3実施形態による学習処理のフローチャートである。
以下、図面を参照して本発明の好適な実施形態について説明する。
[装置構成]
図1は、実施形態に係る学習装置のハードウェア構成を示すブロック図である。学習装置1は、コンピュータにより構成され、プロセッサ1と、メモリ2とを備える。学習装置1には、正解ラベル付きの学習データが入力される。学習装置1は蒸留を用いた学習を行う装置であり、学習データ及び教師モデルを用いて生徒モデルの学習を行う。
詳細には、メモリ2は、生徒モデル及び教師モデルを構成する内部パラメータ、並びに、蒸留に用いる温度パラメータを記憶する。プロセッサ1は、メモリ2に記憶されている各パラメータに基づき、生徒モデル及び教師モデルに学習データを適用し、各パラメータを更新することにより学習を行う。さらに、プロセッサ1は、生徒モデル及び教師モデルによる推定処理により得られる情報(以下、「推定情報」と呼ぶ。)に基づいて、温度パラメータを自動的に更新する。
[基本的構成]
図2は、実施形態に係る学習装置の基本的な機能構成を示すブロック図である。図示のように、学習装置1は、推定部5と、推定部6と、温度計算部7とを備える。推定部5は、生徒モデルを構成する推定部であり、温度パラメータTを用いて学習データの推定処理を行って推定情報を出力する。一方、推定部6は、教師モデルを構成する推定部であり、温度パラメータTを用いて学習データの推定処理を行って推定情報を出力する。温度計算部7は、推定部5が生成する推定情報、及び、推定部6が生成する推定情報に基づいて、温度パラメータTを計算し、推定部5及び6に供給する。こうして、推定部5及び6が用いる温度パラメータTが自動的に更新される。一例では、推定情報として、生徒モデル及び教師モデルにより生成されるロジットもしくは推定結果が用いられる。他の例では、推定情報として、生徒モデルの推定結果の正解ラベルに対する損失、及び、生徒モデルの推定結果の教師モデルの推定結果に対する損失が用いられる。
[第1実施形態]
次に、第1実施形態について説明する。第1実施形態は、生徒モデルにより生成される推定結果の正解ラベルに対する損失、及び、生徒モデルにより生成される推定結果の教師モデルにより生成される推定結果に対する損失に基づいて温度パラメータを更新するものである。損失は本発明の推定情報の一例である。
(機能構成)
図3は、第1実施形態に係る学習装置100の機能構成を示すブロック図である。学習装置100には、学習データが入力される。学習データは、文字列や画像など、学習すべき特徴を示す学習データXと、その正解ラベルYとの組からなる。例えば、画像に写っているのが犬か猫かを判定するDNNを学習させる場合、学習データXは犬の画像であり、正解ラベルYは(犬クラス1,猫クラス0)といった、その画像の属する正解クラスに「1」を、その他の不正解クラスに「0」を持つバイナリ値のベクトルである。
学習装置100は、大別して推定部10aと、推定部10bと、最適化部30とを備える。推定部10aは、生徒モデルを構成するDNNに相当し、ロジット計算器11aと、活性化器12aとを備える。生徒モデルは、学習の対象となるモデルである。推定部10aは、入力された学習データXに対して推定結果Yaを出力する。詳細には、ロジット計算器11aは、内部パラメータPAに基づき、入力された学習データXをロジットyaに変換する。活性化器12aは、式(1)に示すソフトマックス関数などの予め定められた関数に基づいて、ロジットyaを確率分布である推定結果Yaに変換する。
Figure 2020161935
なお、活性化器12aは、温度パラメータTを用いて推定結果Yaを計算するが、その詳細は後述する。推定結果Yaは、一般的に(犬クラス0.8,猫クラス0.2)といった、不正解クラスにも「0」以外の値を持つベクトルである。推定結果Yaは、最適化部30へ供給される。
一方、推定部10bは、教師モデルを構成するDNNに相当し、ロジット計算器11bと、活性化器12bとを備える。教師モデルは、既に学習済みのモデルであり、一般的に生徒モデルよりもパラメータ数が多く高精度なDNNが用いられる。推定部10bは、入力された学習データXに対して推定結果Ybを出力する。詳細には、ロジット計算器11bは、内部パラメータPBに基づき、入力された学習データXをロジットybに変換する。活性化器12bは、上記の式(1)に示すソフトマックス関数などの予め定められた関数に基づいて、ロジットybを確率分布である推定結果Ybに変換する。なお、活性化器12bも温度Tを用いて推定結果Ybを計算するが、その詳細は後述する。推定結果Ybは、推定結果Yaと同様に不正解クラスにも「0」以外の値を持つベクトルである。推定結果Ybは、最適化部30へ供給される。
最適化部30は、損失計算器31と、加重平均計算器32と、パラメータ更新器33と、温度更新器34とを備える。損失計算器31には、推定部10aで生成された推定結果Yaと、推定部10bで生成された推定結果Ybと、学習データの正解ラベルYとが入力される。損失計算器31は、式(2)に示すカテゴリカルクロスエントロピーなどの予め定められた関数に基づいて、推定結果Yaの正解ラベルYに対する損失La、及び、推定結果Yaの推定結果Ybに対する損失Lbを計算する。「損失」は、推定結果Yaと正解ラベルY、又は、推定結果Yaと推定結果Ybがどれくらい離れているかを表す。
Figure 2020161935
学習データの数がクラス間で偏っている場合、損失LaやLbに偏りを補正する効果を加えてもよい。具体的には、式(3)により、クラスごとの損失に補正のための係数wをかける。例えば、wは、式(4)又は(5)により算出される。
Figure 2020161935
ここに、nはi番目のクラスに属する学習データの総数である。他の例では、wは、学習データの総数nの代わりに、教師モデルが全学習データもしくはミニバッチにおいて算出した推定結果Ybのクラスごとの総和を用いてもよい。
加重平均計算器32は、予め決められた重みを用いて、損失Laと損失Lbの加重平均Lavを計算する。パラメータ更新器33は、式(6)に示す勾配降下法などの予め定められた関数に基づいて、加重平均された損失Lavが小さくなるようにロジット計算器11aの内部パラメータPAを更新し、ロジット計算器11aに供給する。
Figure 2020161935
これにより、ロジット計算器11aの内部パラメータPAが更新される。なお、推定部10bのロジット計算器11bの内部パラメータPBは更新されない。
温度更新器34は、推定部10aの活性化器12a及び推定部10bの活性化器12bが用いる温度パラメータTを更新する。ここで、活性化器12a、12bによる、温度Tを用いた蒸留処理について詳しく説明する。推定部10a、10bはそれぞれ推定結果Ya、Ybを出力するが、推定結果Ybは学習済みの高精度な教師モデルによる出力であるから、例えば(犬クラス0.999,猫クラス0.001)のようにベクトルの要素がバイナリ値に近くなってしまい、(犬クラス1,猫クラス0)のような正解ラベルYのみを用いた場合とほとんど変わりがない学習工程となってしまう。そこで、新たに「温度パラメータ」と呼ばれるハイパーパラメータTを導入する。温度パラメータTは、式(7)に示すように、活性化器の活性化関数内に導入されており、一般にT≧1を用いる。
Figure 2020161935
こうすることで、実際は(犬クラス0.999,猫クラス0.001)である推定結果Ybが、(犬クラス0.9,猫クラス0.1)のようになだらかな分布となり、「猫クラス0.1」の数値が学習工程に影響を与え、生徒モデルを効率的に学習させることが可能となる。即ち、温度パラメータTは、教師モデルが出力する正解ラベルのうち、正解クラスと不正解クラスの桁の不整合を調整するために導入される。
パラメータ更新器33が損失を用いて生徒モデルの内部パラメータPAを最適化するのと同様に、温度更新器34は、損失を用いて温度パラメータTを逐次的に最適化する。具体的には、温度更新器34は、加重平均された損失Lavが小さくなるように温度パラメータTを更新する。温度更新器34は、例えば式(8)により温度パラメータTを更新することができる。
Figure 2020161935
式(8)において、「ε」は式(6)の学習率「ε」と同様の働きをする定数であり、学習の間中同じ値を用いてもよいし、プラトー(plateau)のように学習の進行に応じて変化させてもよい。また、「ε」は式(6)と同じ値を用いてもよいし、異なる値を用いてもよい。なお、式(8)は一般的な勾配降下法を表すが、モデルの内部パラメータPのときと同様に、重み減衰や重み上限などの正則化手法を合わせて用いてもよい。また、温度更新器34が温度パラメータTの更新に用いる最適化手法は式(8)の勾配降下法には限られず、momentum、adagrad、adadelta、adamなどの他の最適化手法を用いてもよい。
こうして、最適化部30は、推定部10aの内部パラメータPA及び温度パラメータTの更新を、更新量が十分に小さくなるまで繰り返し行う。更新量が所定値より小さくなった場合に、学習は終了する。なお、学習が終了した生徒モデルを推論に用いる際には、温度パラメータをT=1とする。即ち、温度パラメータは、学習時にT≧1とし、推論時にT=1とする。
(処理フロー)
次に、第1実施形態による学習処理の流れについて説明する。図4は、第1実施形態による学習処理のフローチャートである。この処理は、図1に示すプロセッサ1が予め用意されたプログラムを実行することにより実現される。
学習装置100に学習データXが入力されると、生徒モデルに相当する推定部10aにおいて、ロジット計算器11aは学習データXのロジットyaを計算する(ステップS11)。次に、活性化器12aは、前述のように温度パラメータTを用いてロジットyaから推定結果Yaを計算する(ステップS12)。同様に、教師モデルに相当する推定部10bにおいては、ロジット計算器11bは学習データXのロジットybを計算する(ステップS13)。次に、活性化器12bは、前述のように温度パラメータTを用いてロジットybから推定結果Ybを計算する(ステップS14)。
次に、最適化部30において、損失計算器31は、正解ラベルYに対する推定結果Yaの損失La、及び、推定結果Ybに対する推定結果Yaの損失Lbを計算する(ステップS15)。次に、加重平均計算器32は、損失Laと損失Lbの加重平均Lavを計算する(ステップS16)。
次に、パラメータ更新器33は、加重平均された損失Lavに基づいて、ロジット計算器11aの内部パラメータPAを更新する(ステップS17)。次に、温度更新器34は、加重平均された損失Lavに基づいて、活性化器12a及び12bが使用する温度パラメータTを更新する(ステップS18)。
学習装置100は、パラメータ更新器33による内部パラメータPAの更新量が所定値より小さくなるまで処理を繰り返し、更新量が所定値より小さくなったときに処理を終了する。処理が終了したときにロジット計算器11aに設定されている内部パラメータPA及び活性化器12aに設定されている温度パラメータTにより生徒モデルが規定される。そして、推論時には、この生徒モデルにおいて温度パラメータT=1として推論処理が実行される。
(変形例)
上記の第1実施形態では、温度更新器34は、正解ラベルYに対する推定結果Yaの損失La、及び、推定結果Ybに対する推定結果Yaの損失Lbの両方を用いて温度パラメータTを更新している。その代わりに、温度更新器34は、正解ラベルYに対する推定結果Yaの損失La、及び、推定結果Ybに対する推定結果Yaの損失Lbのいずれか一方を用いて温度パラメータTを更新しても良い。
(実施例)
第1実施形態の1つの実施例では、温度パラメータTの初期値を「1」に設定する。温度更新器34は、パラメータ更新器33がロジット計算器11aの内部パラメータPAを更新するのと同じタイミングで、温度パラメータTを更新する。温度更新器34は、確率的勾配降下法を用いて温度パラメータTを更新し、学習率はε=1.0×10−4とする。
以上の工程を、同様の学習データの1つのサンプルもしくは複数のサンプルのセット(「ミニバッチ」という。)で繰り返し行う。繰り返しは、温度パラメータTの更新量もしくはロジット計算器11aの内部パラメータPAの更新量が十分小さくなるか、既定の繰り返し回数(例えば100回)に達した段階で終了する。その後、別の学習データのサンプルもしくはミニバッチで同様の工程を行い、ロジット計算器11aの内部パラメータPAの更新量が十分小さくなるまで繰り返す。このとき、サンプルもしくはミニバッチが変更されるたびに温度パラメータTの初期値は「1」に設定する。
[第2実施形態]
次に、第2実施形態について説明する。第2実施形態は、推定部内のロジット計算器により計算されるロジットに基づいて温度パラメータを更新するものである。ロジットは、本発明の推定情報の一例である。なお、第2実施形態に係る学習装置200のハードウェア構成は、図1に示す第1実施形態の学習装置100と同様である。
(機能構成)
図5は、第2実施形態に係る学習装置200の機能構成を示すブロック図である。図3と比較するとわかるように、学習装置200は、基本的構成は第1実施形態の学習装置100と同様であり、推定部10aと、推定部10bと、最適化部30とを備える。但し、第2実施形態の学習装置200は、第1実施形態における温度更新器34の代わりに、温度計算器21を備える。
温度計算器21には、推定部10aのロジット計算器11aが出力するロジットyaと、推定部10bのロジット計算器11bが出力するロジットybとが入力される。温度計算器21は、予め定められた関数に基づいて、ロジットya及びybを用いて適切な温度パラメータTを計算する。前述のように、温度パラメータTは、教師モデルが出力する正解ラベルのうち、正解クラスと不正解クラスの桁の不整合を調整するために導入されている。学習データによって不整合の強さは変化するため、それに応じて温度パラメータTを変化させる。具体的には、温度計算器21は、ロジットyaとybの不整合が大きい場合には温度パラメータTに大きな値を設定し、不整合が小さい場合には温度パラメータTに「1」に近い値を設定する。これにより、どのような学習データも平等に学習することが可能となる。
温度計算器21は、計算した温度パラメータTを活性化器12a及び12bに供給する。活性化器12aは、ロジット計算器11aから入力されたロジットyaと、温度パラメータTとに基づいて推定結果Yaを計算し、最適化部30の損失計算器31へ供給する。また、活性化器12bは、ロジット計算器11bから入力されたロジットybと、温度パラメータTとに基づいて推定結果Ybを計算し、最適化部30の損失計算器31へ供給する。
最適化部30において、損失計算器31は入力された推定結果Yaの正解ラベルYに対する損失La、及び、推定結果Yaの推定結果Ybに対する損失Lbを計算し、加重平均計算器32は損失La、Lbの加重平均された損失Lavを計算する。そして、パラメータ更新器33は、加重平均された損失Lavに基づいて、推定部10a内のロジット計算器11aの内部パラメータPAを更新する。
学習データの数がクラス間で偏っている場合、損失LaやLbに偏りを補正する効果を加えてもよい。具体的には、前述の式(3)により、クラスごとの損失に補正のための係数wをかける。例えば、wは、前述の式(4)又は(5)により算出される。ここに、nはi番目のクラスに属する学習データの総数である。他の例では、wは、学習データの総数nの代わりに、教師モデルが全学習データもしくはミニバッチにおいて算出した推定結果Ybのクラスごとの総和を用いてもよい。
(処理フロー)
次に、第2実施形態による学習処理の流れについて説明する。図6は、第2実施形態による学習処理のフローチャートである。この処理は、図1に示すプロセッサ1が予め用意されたプログラムを実行することにより実現される。
学習装置200に学習データXが入力されると、生徒モデルに相当する推定部10aにおいて、ロジット計算器11aは学習データXのロジットyaを計算する(ステップS21)。教師モデルに相当する推定部10bにおいては、ロジット計算器11bは学習データXのロジットybを計算する(ステップS22)。次に、温度計算器21は、予め定められた関数に基づき、ロジットya及びybから温度パラメータTを決定する(ステップS23)。決定された温度パラメータTは、活性化器12a及び12bに供給される。
次に、活性化器12aは、温度パラメータTを用いてロジットyaから推定結果Yaを計算し、活性化器12bは、温度パラメータTを用いてロジットybから推定結果Ybを計算する(ステップS24)。
次に、最適化部30において、損失計算器31は、正解ラベルYに対する推定結果Yaの損失La、及び、推定結果Ybに対する推定結果Yaの損失Lbを計算する(ステップS25)。次に、加重平均計算器32は、損失Laと損失Lbの加重平均Lavを計算する(ステップS26)。次に、パラメータ更新器33は、加重平均された損失Lavに基づいて、ロジット計算器11aの内部パラメータPAを更新する(ステップS27)。
学習装置200は、パラメータ更新器33による内部パラメータPAの更新量が所定値より小さくなるまで上記の処理を繰り返し、更新量が所定値より小さくなったときに処理を終了する。処理が終了したときにロジット計算器11aに設定されている内部パラメータPA及び活性化器12aに設定されている温度パラメータTにより生徒モデルが規定される。そして、推論時には、この生徒モデルにおいて温度パラメータT=1として推論処理が実行される。
(変形例)
上記の第2実施形態では、温度計算器21は、生徒モデルに相当する推定部10aが生成するロジットya、及び、教師モデルに相当する推定部10bが生成するロジットybの両方を用いて温度パラメータTを計算している。その代わりに、温度計算器21は、生徒モデルに相当する推定部10aが生成するロジットya、及び、教師モデルに相当する推定部10bが生成するロジットybのいずれか一方を用いて温度パラメータTを計算しても良い。
(実施例)
第2実施形態の1つの実施例では、温度パラメータTは、生徒モデルに相当する推定部10a又は教師モデルに相当する推定部10bが出力するロジットの値が正のときに正の値に設定され、負のときに負の値に設定される。例えば、ロジットの値が負の場合には温度パラメータT=−5とし、ロジットの値が正の場合には温度パラメータT=1とする。好適には、温度パラメータTは、−100〜100の範囲で決定される。例えば、温度計算器21は、生徒モデルに相当する推定部10aが出力するロジットが正のときに温度パラメータTを0〜100の値に設定し、負のときに温度パラメータTを−100〜0の値に設定する。活性化器12a及び12bの活性化関数にはシグモイド関数を用いる。
[第3実施形態]
次に、第3実施形態について説明する。第3実施形態は、推定部により計算される推定結果に基づいて温度パラメータを更新するものである。推定結果は、本発明の推定情報の一例である。なお、第3実施形態に係る学習装置300のハードウェア構成は、図1に示す第1実施形態の学習装置100と同様である。
(機能構成)
図7は、第3実施形態に係る学習装置300の機能構成を示すブロック図である。図5と比較するとわかるように、学習装置300は、第2実施形態の学習装置200に加えて、活性化器22a及び22bを備える。なお、推定部10a及び10b、並びに最適化部30の構成は第2実施形態の学習装置200と同様である。
温度計算器21は、予め定められた関数に基づき、推定部10a及び10bから供給された推定結果Ya及びYbから温度パラメータTを計算し、これを活性化器22a及び22bに供給する。活性化器22aは、予め定められた関数及び温度パラメータTに基づき、ロジット計算器11aが出力するロジットyaから推定結果Y’aを計算し、最適化部30の損失計算器31へ供給する。同様に、活性化器22bは、予め定められた関数及び温度パラメータTに基づき、ロジット計算器11bが出力するロジットybから推定結果Y’bを計算し、最適化部30の損失計算器31へ供給する。
最適化部30の構成は、第2実施形態と同様である。即ち、損失計算器31は入力された推定結果Y’aの正解ラベルYに対する損失La、及び、推定結果Y’aの推定結果Y’bに対する損失Lbを計算し、加重平均計算器32は損失La、Lbの加重平均された損失Lavを計算する。そして、パラメータ更新器33は、加重平均された損失Lavに基づいて、推定部10a内のロジット計算器11aの内部パラメータPAを更新する。
学習データの数がクラス間で偏っている場合、損失LaやLbに偏りを補正する効果を加えてもよい。具体的には、前述の式(3)により、クラスごとの損失に補正のための係数wをかける。例えば、wは、前述の式(4)又は(5)により算出される。ここに、nはi番目のクラスに属する学習データの総数である。他の例では、wは、学習データの総数nの代わりに、教師モデルが全学習データもしくはミニバッチにおいて算出した推定結果Ybのクラスごとの総和を用いてもよい。
(処理フロー)
次に、第3実施形態による学習処理の流れについて説明する。図8は、第3実施形態による学習処理のフローチャートである。この処理は、図1に示すプロセッサ1が予め用意されたプログラムを実行することにより実現される。
学習装置300に学習データXが入力されると、生徒モデルに相当する推定部10aは学習データXから推定結果Yaを計算する(ステップS31)。推定結果Yaを計算する際、推定部10aの活性化器12aは温度パラメータT=1を用いる。教師モデルに相当する推定部10bは、学習データXから推定結果Ybを計算する(ステップS32)。推定結果Ybを計算する際、推定部10bの活性化器12bは温度パラメータT=1を用いる。
次に、温度計算器21は、予め定められた関数に基づき、推定結果Ya及びYbから温度パラメータTを決定する(ステップS33)。決定された温度パラメータTは、活性化器22a及び22bに供給される。
次に、活性化器22aは、温度パラメータTを用いて、ロジット計算器11aが出力したロジットyaから推定結果Y’aを計算する。また、活性化器22bは、温度パラメータTを用いて、ロジット計算器11bが出力したロジットybから推定結果Y’bを計算する(ステップS34)。
次に、最適化部30において、損失計算器31は、正解ラベルYに対する推定結果Y’aの損失La、及び、推定結果Y’bに対する推定結果Y’aの損失Lbを計算する(ステップS35)。次に、加重平均計算器32は、損失Laと損失Lbの加重平均Lavを計算する(ステップS36)。そして、パラメータ更新器33は、加重平均された損失Lavに基づいて、ロジット計算器11aの内部パラメータPAを更新する(ステップS37)。
学習装置300は、パラメータ更新器33による内部パラメータPAの更新量が所定値より小さくなるまで上記の処理を繰り返し、更新量が所定値より小さくなったときに処理を終了する。処理が終了したときにロジット計算器11aに設定されている内部パラメータPA及び活性化器22aに設定されている温度パラメータTにより生徒モデルが規定される。そして、推論時には、この生徒モデルにおいて温度パラメータT=1として推論処理が実行される。
(変形例)
上記の第3実施形態では、温度計算器21は、生徒モデルに相当する推定部10aが生成する推定結果Ya、及び、教師モデルに相当する推定部10bが生成する推定結果Ybの両方を用いて温度パラメータTを計算している。その代わりに、温度計算器21は、生徒モデルに相当する推定部10aが生成する推定結果Ya、及び、教師モデルに相当する推定部10bが生成する推定結果Ybのいずれか一方を用いて温度パラメータTを計算しても良い。
(実施例)
第3実施形態の第1の実施例では、温度計算器21は、推定部10bから出力される推定結果Ybの各クラスの確率推定値の大小バランスを補正するように、具体的には、各クラスの推定結果のオーダーが揃うように温度パラメータTを設定する。例えば、温度計算器21は、確率推定値の最大値/最小値が10以上の場合には温度パラメータT=5とし、それ以下の場合は温度パラメータT=1とする。好適には、温度パラメータTは1〜100の範囲で決定される。活性化器22a及び22bの活性化関数にはソフトマックス関数を用いる。
第3実施形態の第2の実施例では、教師モデルに相当する推定部10bから出力される推定結果Ybのうち、最も確率の高いクラスの値を「p1」、次に確率の高いクラスの値を「p2」としたときに、温度計算器21は、それらの値の比:r=p1/p2に基づいて温度パラメータTを決定する。一例では、温度計算器21は、r<4のときT=1とし、4≦r<5のときT=2とし、r≧5のときT=3とする。他の例では、温度計算器21は、より連続的にrに関するシグモイド関数で温度パラメータTを定める。例えば、温度計算器21は、rが十分小さいときにT=1となり、rが十分大きいときにT=3となり、r=4のときにT=2となるように、温度パラメータTを
T=2/(1+e(−r+4))+1
と定める。
第3実施形態の第3の実施例では、教師モデルに相当する推定部10bから出力される推定結果YbのエントロピーEに基づいて温度パラメータTを決定する。エントロピーEは、i番目のクラスの推定結果をYbとしたとき、以下の式(9)により与えられる。
Figure 2020161935
温度パラメータTは、推定結果YbのエントロピーEに関する単調減少関数として与えられる。一例では、Eが最小値のときT=10となり、Eが最大値のときT=1となるように、温度パラメータTを式(10)のように定める。ここに、Nは教師モデルが分類するクラス数である。
Figure 2020161935
[他の実施形態]
上記の第1実施形態と、第2実施形態又は第3実施形態とを組み合わせてもよい。例えば、第1実施形態と第2実施形態とを組み合わせた場合、第2実施形態の手法によりロジットに基づいて温度パラメータTの初期値を決定し、その後は第1実施形態の手法により損失に基づいて温度パラメータTを更新すればよい。また、第1実施形態と第3実施形態とを組み合わせた場合、第3実施形態の手法により推定結果に基づいて温度パラメータTの初期値を決定し、その後は第1実施形態の手法により損失に基づいて温度パラメータTを更新すればよい。
上記の実施形態の一部又は全部は、以下の付記のようにも記載されうるが、以下には限られない。
(付記1)
生徒モデルに基づき、温度パラメータを用いて推定を行う第1の推定部と、
教師モデルに基づき、温度パラメータを用いて推定を行う第2の推定部と、
前記第1の推定部及び前記第2の推定部が生成する推定情報に基づいて、前記温度パラメータを計算する温度計算部と、
を備える学習装置。
(付記2)
前記温度計算部は、正解ラベルに対する前記第1の推定部が生成する第1の推定結果の損失、及び、前記第2の推定部が生成する第2の推定結果に対する前記第1の推定結果の損失の少なくとも一方に基づいて前記温度パラメータを計算する付記1に記載の学習装置。
(付記3)
前記温度計算部は、正解ラベルに対する前記第1の推定部が生成する第1の推定結果の損失と、前記第2の推定部が生成する第2の推定結果に対する前記第1の推定結果の損失との加重平均に基づいて前記温度パラメータを計算する付記1に記載の学習装置。
(付記4)
前記温度計算部は、前記第1の推定部が生成するロジット、及び、前記第2の推定部が生成するロジットの少なくとも一方に基づいて、前記温度パラメータを計算する付記1に記載の学習装置。
(付記5)
前記温度計算部は、前記第1の推定部又は第2の推定部が生成するロジットの正負に応じて前記温度パラメータを計算する付記1に記載の学習装置。
(付記6)
前記温度計算部は、前記第1の推定部が生成する推定結果、及び、前記第2の推定部が生成する推定結果の少なくとも一方に基づいて前記温度パラメータを計算する付記1に記載の学習装置。
(付記7)
前記温度計算部は、前記第2の推定部が生成する推定結果のエントロピーに基づいて前記温度パラメータを計算する付記6に記載の学習装置。
(付記8)
前記温度計算部は、前記第2の推定部が生成する推定結果のエントロピーに関する単調減少関数として温度パラメータを計算する付記6に記載の学習装置。
(付記9)
前記温度計算部は、前記第1の推定部又は前記第2の推定部が生成する各クラスの推定結果の大小バランスが補正されるように前記温度パラメータを計算する付記6に記載の学習装置。
(付記10)
前記温度計算部は、前記第1の推定部又は前記第2の推定部が生成する各クラスの推定結果のオーダーが揃うように前記温度パラメータを計算する付記6に記載の学習装置。
(付記11)
前記温度計算部は、前記第2の推定部が生成する各クラスの推定結果のうち、最も確率の高いクラスの値と、次に確率の高いクラスの値との比に基づいて、前記温度パラメータを計算する付記6に記載の学習装置。
(付記12)
学習装置により実行される学習方法であって、
生徒モデルに基づき、温度パラメータを用いて推定を行い、
教師モデルに基づき、温度パラメータを用いて推定を行い、
前記第1の推定部及び前記第2の推定部が生成する推定情報に基づいて、前記温度パラメータを計算する学習方法。
(付記13)
コンピュータを備える学習装置により実行されるプログラムであって、
生徒モデルに基づき、温度パラメータを用いて推定を行う第1の推定部、
教師モデルに基づき、温度パラメータを用いて推定を行う第2の推定部、
前記第1の推定部及び前記第2の推定部が生成する推定情報に基づいて、前記温度パラメータを計算する温度計算部、
として前記コンピュータを機能させるプログラム。
以上、実施形態及び実施例を参照して本願発明を説明したが、本願発明は上記実施形態及び実施例に限定されるものではない。本願発明の構成や詳細には、本願発明のスコープ内で当業者が理解し得る様々な変更をすることができる。
1 プロセッサ
2 メモリ
5、6、10a、10b 推定部
7、21 温度計算器
11a、11b ロジット計算器
12a、12b、22a、22b 活性化器
30 最適化部
31 損失計算器
32 加重平均計算器
33 パラメータ更新器
34 温度更新器
本発明の1つの観点では、学習装置は、生徒モデルに基づき、温度パラメータを用いて推定を行う第1の推定手段と、教師モデルに基づき、温度パラメータを用いて推定を行う第2の推定手段と、前記第1の推定手段及び前記第2の推定手段が生成する推定情報に基づいて、前記温度パラメータを計算する温度計算手段と、を備える。
本発明の他の観点では、学習装置により実行される学習方法は、生徒モデルに基づき、温度パラメータを用いて第1の推定を行い、教師モデルに基づき、温度パラメータを用いて第2の推定を行い、前記第1の推定及び前記第2の推定により生成された推定情報に基づいて、前記温度パラメータを計算する。
本発明の更に他の観点では、コンピュータを備える学習装置により実行されるプログラムは、生徒モデルに基づき、温度パラメータを用いて推定を行う第1の推定手段、教師モデルに基づき、温度パラメータを用いて推定を行う第2の推定手段、前記第1の推定手段及び前記第2の推定手段が生成する推定情報に基づいて、前記温度パラメータを計算する温度計算手段、として前記コンピュータを機能させる。

Claims (13)

  1. 生徒モデルに基づき、温度パラメータを用いて推定を行う第1の推定部と、
    教師モデルに基づき、温度パラメータを用いて推定を行う第2の推定部と、
    前記第1の推定部及び前記第2の推定部が生成する推定情報に基づいて、前記温度パラメータを計算する温度計算部と、
    を備える学習装置。
  2. 前記温度計算部は、正解ラベルに対する前記第1の推定部が生成する第1の推定結果の損失、及び、前記第2の推定部が生成する第2の推定結果に対する前記第1の推定結果の損失の少なくとも一方に基づいて前記温度パラメータを計算する請求項1に記載の学習装置。
  3. 前記温度計算部は、正解ラベルに対する前記第1の推定部が生成する第1の推定結果の損失と、前記第2の推定部が生成する第2の推定結果に対する前記第1の推定結果の損失との加重平均に基づいて前記温度パラメータを計算する請求項1に記載の学習装置。
  4. 前記温度計算部は、前記第1の推定部が生成するロジット、及び、前記第2の推定部が生成するロジットの少なくとも一方に基づいて、前記温度パラメータを計算する請求項1に記載の学習装置。
  5. 前記温度計算部は、前記第1の推定部又は第2の推定部が生成するロジットの正負に応じて前記温度パラメータを計算する請求項1に記載の学習装置。
  6. 前記温度計算部は、前記第1の推定部が生成する推定結果、及び、前記第2の推定部が生成する推定結果の少なくとも一方に基づいて前記温度パラメータを計算する請求項1に記載の学習装置。
  7. 前記温度計算部は、前記第2の推定部が生成する推定結果のエントロピーに基づいて前記温度パラメータを計算する請求項6に記載の学習装置。
  8. 前記温度計算部は、前記第2の推定部が生成する推定結果のエントロピーに関する単調減少関数として温度パラメータを計算する請求項6に記載の学習装置。
  9. 前記温度計算部は、前記第1の推定部又は前記第2の推定部が生成する各クラスの推定結果の大小バランスが補正されるように前記温度パラメータを計算する請求項6に記載の学習装置。
  10. 前記温度計算部は、前記第1の推定部又は前記第2の推定部が生成する各クラスの推定結果のオーダーが揃うように前記温度パラメータを計算する請求項6に記載の学習装置。
  11. 前記温度計算部は、前記第2の推定部が生成する各クラスの推定結果のうち、最も確率の高いクラスの値と、次に確率の高いクラスの値との比に基づいて、前記温度パラメータを計算する請求項6に記載の学習装置。
  12. 学習装置により実行される学習方法であって、
    生徒モデルに基づき、温度パラメータを用いて推定を行い、
    教師モデルに基づき、温度パラメータを用いて推定を行い、
    前記第1の推定部及び前記第2の推定部が生成する推定情報に基づいて、前記温度パラメータを計算する学習方法。
  13. コンピュータを備える学習装置により実行されるプログラムであって、
    生徒モデルに基づき、温度パラメータを用いて推定を行う第1の推定部、
    教師モデルに基づき、温度パラメータを用いて推定を行う第2の推定部、
    前記第1の推定部及び前記第2の推定部が生成する推定情報に基づいて、前記温度パラメータを計算する温度計算部、
    として前記コンピュータを機能させるプログラム。
JP2020570350A 2019-02-05 2019-07-04 学習装置、学習方法、及び、プログラム Active JP7180697B2 (ja)

Applications Claiming Priority (3)

Application Number Priority Date Filing Date Title
PCT/JP2019/004032 WO2020161797A1 (ja) 2019-02-05 2019-02-05 学習装置、学習方法、及び、プログラム
JPPCT/JP2019/004032 2019-02-05
PCT/JP2019/026672 WO2020161935A1 (ja) 2019-02-05 2019-07-04 学習装置、学習方法、及び、プログラム

Publications (2)

Publication Number Publication Date
JPWO2020161935A1 true JPWO2020161935A1 (ja) 2021-11-25
JP7180697B2 JP7180697B2 (ja) 2022-11-30

Family

ID=71947438

Family Applications (1)

Application Number Title Priority Date Filing Date
JP2020570350A Active JP7180697B2 (ja) 2019-02-05 2019-07-04 学習装置、学習方法、及び、プログラム

Country Status (3)

Country Link
US (1) US20220122349A1 (ja)
JP (1) JP7180697B2 (ja)
WO (2) WO2020161797A1 (ja)

Families Citing this family (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US11551083B2 (en) 2019-12-17 2023-01-10 Soundhound, Inc. Neural network training from private data
CN112556107B (zh) * 2020-11-23 2022-10-28 北京新欧绿色建筑设计院有限公司 一种宜温宜湿宜氧室内环境智能控制系统

Citations (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20150356461A1 (en) * 2014-06-06 2015-12-10 Google Inc. Training distilled machine learning models
WO2018051841A1 (ja) * 2016-09-16 2018-03-22 日本電信電話株式会社 モデル学習装置、その方法、及びプログラム
CN107977707A (zh) * 2017-11-23 2018-05-01 厦门美图之家科技有限公司 一种对抗蒸馏神经网络模型的方法及计算设备
US20180268292A1 (en) * 2017-03-17 2018-09-20 Nec Laboratories America, Inc. Learning efficient object detection models with knowledge distillation

Family Cites Families (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
FR2996374B1 (fr) * 2012-10-03 2016-10-28 Valeo Systemes De Controle Moteur Reseau electrique pour vehicule automobile
JP6712644B2 (ja) * 2016-09-30 2020-06-24 日本電信電話株式会社 音響モデル学習装置、その方法、及びプログラム

Patent Citations (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20150356461A1 (en) * 2014-06-06 2015-12-10 Google Inc. Training distilled machine learning models
WO2018051841A1 (ja) * 2016-09-16 2018-03-22 日本電信電話株式会社 モデル学習装置、その方法、及びプログラム
US20180268292A1 (en) * 2017-03-17 2018-09-20 Nec Laboratories America, Inc. Learning efficient object detection models with knowledge distillation
CN107977707A (zh) * 2017-11-23 2018-05-01 厦门美图之家科技有限公司 一种对抗蒸馏神经网络模型的方法及计算设备

Non-Patent Citations (1)

* Cited by examiner, † Cited by third party
Title
MOSNER, LADISLAV ET AL.: "IMPROVING NOISE ROBUSTNESS OF AUTOMATIC SPEECH RECOGNITION VIA PARALLEL DATA AND TEACHER-STUDENT LEA", ARXIV, JPN6019015624, 11 January 2019 (2019-01-11), ISSN: 0004846094 *

Also Published As

Publication number Publication date
WO2020161797A1 (ja) 2020-08-13
US20220122349A1 (en) 2022-04-21
WO2020161935A1 (ja) 2020-08-13
JP7180697B2 (ja) 2022-11-30

Similar Documents

Publication Publication Date Title
US11544573B2 (en) Projection neural networks
US10929744B2 (en) Fixed-point training method for deep neural networks based on dynamic fixed-point conversion scheme
US20190034796A1 (en) Fixed-point training method for deep neural networks based on static fixed-point conversion scheme
JP3743247B2 (ja) ニューラルネットワークによる予測装置
Bai et al. Prediction of SARS epidemic by BP neural networks with online prediction strategy
CN107729999A (zh) 考虑矩阵相关性的深度神经网络压缩方法
JP6055058B1 (ja) 機械学習器及び組み立て・試験器を備えた生産設備
JP2017037392A (ja) ニューラルネットワーク学習装置
US20210027147A1 (en) Forward propagation of secondary objective for deep learning
CN113361685B (zh) 一种基于学习者知识状态演化表示的知识追踪方法及系统
EP3502978A1 (en) Meta-learning system
CN109934330A (zh) 基于多样化种群的果蝇优化算法来构建预测模型的方法
CN109885667A (zh) 文本生成方法、装置、计算机设备及介质
JP7180697B2 (ja) 学習装置、学習方法、及び、プログラム
CN111159419A (zh) 基于图卷积的知识追踪数据处理方法、系统和存储介质
JP2023519770A (ja) マルチタスク向けの予めトレーニング言語モデルの自動圧縮方法及びプラットフォーム
CN114971066A (zh) 融合遗忘因素和学习能力的知识追踪方法及系统
CN112381591A (zh) 基于lstm深度学习模型的销售预测优化方法
JP7256378B2 (ja) 最適化システムおよび最適化システムの制御方法
CN113868113B (zh) 一种基于Actor-Critic算法的类集成测试序列生成方法
JP2020119108A (ja) データ処理装置、データ処理方法、データ処理プログラム
CN111563548B (zh) 一种基于强化学习的数据预处理方法、系统及相关设备
US12055934B2 (en) Machine learning for trajectory planning
CN115952838B (zh) 一种基于自适应学习推荐系统生成方法及系统
US20240086678A1 (en) Method and information processing apparatus for performing transfer learning while suppressing occurrence of catastrophic forgetting

Legal Events

Date Code Title Description
A521 Request for written amendment filed

Free format text: JAPANESE INTERMEDIATE CODE: A523

Effective date: 20210716

A621 Written request for application examination

Free format text: JAPANESE INTERMEDIATE CODE: A621

Effective date: 20210716

A131 Notification of reasons for refusal

Free format text: JAPANESE INTERMEDIATE CODE: A131

Effective date: 20220809

A521 Request for written amendment filed

Free format text: JAPANESE INTERMEDIATE CODE: A523

Effective date: 20221005

TRDD Decision of grant or rejection written
A01 Written decision to grant a patent or to grant a registration (utility model)

Free format text: JAPANESE INTERMEDIATE CODE: A01

Effective date: 20221018

A61 First payment of annual fees (during grant procedure)

Free format text: JAPANESE INTERMEDIATE CODE: A61

Effective date: 20221031

R151 Written notification of patent or utility model registration

Ref document number: 7180697

Country of ref document: JP

Free format text: JAPANESE INTERMEDIATE CODE: R151