JP7307785B2 - 機器学習装置及び方法 - Google Patents

機器学習装置及び方法 Download PDF

Info

Publication number
JP7307785B2
JP7307785B2 JP2021195279A JP2021195279A JP7307785B2 JP 7307785 B2 JP7307785 B2 JP 7307785B2 JP 2021195279 A JP2021195279 A JP 2021195279A JP 2021195279 A JP2021195279 A JP 2021195279A JP 7307785 B2 JP7307785 B2 JP 7307785B2
Authority
JP
Japan
Prior art keywords
loss
processor
neural network
machine learning
network structure
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
JP2021195279A
Other languages
English (en)
Other versions
JP2022088341A (ja
Inventor
宇劭 彭
凱富 湯
智威 張
Original Assignee
宏達國際電子股▲ふん▼有限公司
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by 宏達國際電子股▲ふん▼有限公司 filed Critical 宏達國際電子股▲ふん▼有限公司
Publication of JP2022088341A publication Critical patent/JP2022088341A/ja
Application granted granted Critical
Publication of JP7307785B2 publication Critical patent/JP7307785B2/ja
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/084Backpropagation, e.g. using gradient descent
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • Evolutionary Computation (AREA)
  • Molecular Biology (AREA)
  • Computing Systems (AREA)
  • Software Systems (AREA)
  • Mathematical Physics (AREA)
  • Health & Medical Sciences (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Computational Linguistics (AREA)
  • General Health & Medical Sciences (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Evolutionary Biology (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Information Retrieval, Db Structures And Fs Structures Therefor (AREA)
  • Image Analysis (AREA)
  • Feedback Control In General (AREA)
  • Control Of Electric Motors In General (AREA)
  • Numerical Control (AREA)

Description

本開示は、機器学習技術に関し、特に、偽相関を除去する機器学習技術に関する。
例えば、機器学習、ニューラルネットワーク等の技術は、人工知能技術の分野で広く適用されている。人工知能の重要な用途の1つとしては、オブジェクト(例えば、顔、ナンバープレート等)の識別、又はデータの予測(例えば、株価予測、医療予測等)がある。オブジェクト検出及びデータ予測は、特徴抽出及び特徴分類によって実現されることができる。
しかしながら、特徴抽出及び特徴分類に用いられる特徴間には、一般的に、偽相関が発生し、しかも偽相関によりオブジェクト検出及びデータ予測の予測精度が低下してしまう。
本開示の一態様は、プロセッサによってメモリからモデルパラメータを取得して、モデルパラメータに基づいて複数のニューラルネットワーク構造層を含む分類モデルを実行する工程と、プロセッサによって複数のトレーニングサンプルに基づいて、複数のニューラルネットワーク構造層における出力層に対応する第1の損失と、複数のニューラルネットワーク構造層における出力層よりも前に位置する一方に対応する第2の損失を算出する工程と、プロセッサによって、第1の損失及び前記第2の損失に基づいてモデルパラメータに対して複数の更新操作を実行して、分類モデルをトレーニングする工程と、を備える機器学習方法を開示する。
いくつかの実施例において、前記複数のトレーニングサンプルに基づいて前記第1の損失及び前記第2の損失を算出する工程は、前記プロセッサによって前記複数のトレーニングサンプルに基づいて前記複数のニューラルネットワーク構造層の前記出力層から複数の予測ラベルを生成する工程と、前記プロセッサによって前記複数の予測ラベルと前記複数のトレーニングサンプルの複数のトレーニングラベルとを比較して前記第1の損失を算出する工程と、を含む。
いくつかの実施例において、前記複数のトレーニングサンプルに基づいて前記第1の損失及び前記第2の損失を算出する工程は、前記プロセッサによって前記複数のトレーニングサンプルに基づいて前記分類モデルから複数の抽出特徴を生成する工程と、前記プロセッサによって前記複数のニューラルネットワーク構造層における一方に対応する前記複数の抽出特徴の間の統計独立性に基づいて、前記第2の損失を算出する工程と、を含む。
いくつかの実施例において、前記第1の損失及び前記第2の損失に基づいて前記モデルパラメータに対して前記複数の更新操作を実行して前記分類モデルをトレーニングする工程は、
前記プロセッサによって前記第1の損失及び前記第2の損失に基づいて複数の損失差を算出する工程と、前記プロセッサによって前記複数の損失差に基づいて前記分類モデルに対して複数の逆伝搬操作を実行して前記モデルパラメータを更新する工程と、を含む。
いくつかの実施例において、機器学習方法は、前記プロセッサによって前記複数の抽出特徴、及び前記複数のトレーニングサンプルの複数のトレーニングラベルの間の平均処理効果に基づいて第3の損失を算出する工程を更に備える。
いくつかの実施例において、前記第1の損失及び前記第2の損失に基づいて前記モデルパラメータに対して前記複数の更新操作を実行して前記分類モデルをトレーニングする工程は、
前記プロセッサによって前記第1の損失、前記第2の損失、及び前記第3の損失に基づいて、複数の損失差を算出する工程と、前記プロセッサによって前記複数の損失差に基づいて前記分類モデルに対して複数の逆伝搬操作を実行して前記モデルパラメータを更新する工程と、を含む。
いくつかの実施例において、前記複数のトレーニングサンプルに基づいて前記第1の損失及び前記第2の損失を算出する工程は、前記プロセッサによって前記複数のトレーニングサンプルに基づいて前記分類モデルから複数の抽出特徴を生成する工程と、前記プロセッサによって前記複数のニューラルネットワーク構造における一方に対応する前記複数の抽出特徴と、前記複数のトレーニングサンプルの複数のトレーニングラベルの間の平均処理効果とに基づいて前記第2の損失を算出する工程と、を含む。
いくつかの実施例において、前記第1の損失及び前記第2の損失に基づいて前記モデルパラメータに対して前記複数の更新操作を実行して前記分類モデルをトレーニングする工程は、前記プロセッサによって前記第1の損失及び前記第2の損失に基づいて複数の損失差を算出する工程と、前記プロセッサによって前記複数の損失差に基づいて前記分類モデルに対して複数の逆伝搬操作を実行して前記モデルパラメータを更新する工程と、を含む。
いくつかの実施例において、出力層は、少なくとも1つの完全結合層を含み、前記複数のニューラルネットワーク構造層における一方は、少なくとも1つの畳み込み層を含む。
いくつかの実施例において、前記分類モデルは、ニューラルネットワークに関連づけられる。
本開示の別の態様は、複数のコマンド及びモデルパラメータを記憶するためのメモリと、メモリに接続されるプロセッサと、を備える機器学習装置であって、前記プロセッサは、分類モデルを実行するとともに、メモリからモデルパラメータを取得して、モデルパラメータに基づいて複数のニューラルネットワーク構造層を含む分類モデルを実行し、複数のトレーニングサンプルに基づいて、複数のニューラルネットワーク構造層における出力層に対応する第1の損失と、複数のニューラルネットワーク構造層における出力層よりも前に位置する一方に対応する第2の損失を算出し、第1の損失及び第2の損失に基づいて、モデルパラメータに対して複数の更新操作を実行して、分類モデルをトレーニングするように、複数のコマンドを実行するためのものである、機器学習装置を開示する。
いくつかの実施例において、前記プロセッサは、更に、前記複数のトレーニングサンプルに基づいて前記複数のニューラルネットワーク構造層の前記出力層から複数の予測ラベルを生成し、及び前記第1の損失を算出するように、前記複数の予測ラベルと前記複数のトレーニングサンプルの複数のトレーニングラベルとを比較するためのものである。
いくつかの実施例において、前記プロセッサは、更に、前記複数のトレーニングサンプルに基づいて前記分類モデルから複数の抽出特徴を生成し、及び前記複数のニューラルネットワーク構造層における一方に対応する前記複数の抽出特徴の間の統計独立性に基づいて前記第2の損失を算出するためのものである。
いくつかの実施例において、前記プロセッサは、更に、前記第1の損失及び前記第2の損失に基づいて複数の損失差を算出し、及び前記複数の損失差に基づいて前記分類モデルに対して複数の逆伝搬操作を実行して前記モデルパラメータを更新するためのものである。
いくつかの実施例において、前記プロセッサは、更に、前記複数の抽出特徴と、前記複数のトレーニングサンプルの複数のトレーニングラベルの間の平均処理効果とに基づいて第3の損失を算出するためのものである。
いくつかの実施例において、前記プロセッサは、更に、前記第1の損失、前記第2の損失、及び前記第3の損失に基づいて複数の損失差を算出し、及び前記複数の損失差に基づいて前記分類モデルに対して複数の逆伝搬操作を実行して前記モデルパラメータを更新するためのものである。
いくつかの実施例において、前記プロセッサは、更に、前記複数のトレーニングサンプルに基づいて前記分類モデルから複数の抽出特徴を生成し、及び前記複数のニューラルネットワーク構造における一方に対応する前記複数の抽出特徴と、前記複数のトレーニングサンプルの複数のトレーニングラベルの間の平均処理効果とに基づいて前記第2の損失を算出するためのものである。
いくつかの実施例において、前記プロセッサは、更に、前記第1の損失及び前記第2の損失に基づいて複数の損失差を算出し、及び前記複数の損失差に基づいて前記分類モデルに対して複数の逆伝搬操作を実行して前記モデルパラメータを更新するためのものである。
いくつかの実施例において、前記出力層は、少なくとも1つの完全結合層を含み、前記複数のニューラルネットワーク構造層における一方は、少なくとも1つの畳み込み層を含む。
いくつかの実施例において、前記分類モデルは、ニューラルネットワークに関連づけられる。
本開示の一実施例による機器学習装置を示す模式図である。 本開示の一実施例による機器学習方法を示す模式図である。 本開示の一実施例による分類モデル及び損失を示す模式図である。 いくつかの実施例における図2のある工程の細部を示すフローチャートである。 別のいくつかの実施例における図2のある工程の細部を示すフローチャートである。 いくつかの実施例における図2の別の工程の細部を示すフローチャートである。 いくつかの実施例における図2の追加工程を示すフローチャートである。 別のいくつかの実施例における図2の別の工程の細部を示すフローチャートである。
ここで、本開示の現在の実施例を詳細に参照し、その例を図面に示す。可能な場合には、図面及び説明において同一の要素符号を使用して同一の素子を表す。
図1を参照すると、図1は、本開示の一実施例による機器学習装置を示す模式図である。機器学習装置100は、プロセッサ110と、メモリ120と、を備える。プロセッサ110及びメモリ120は、互いに接続される。
いくつかの実施例において、機器学習装置100は、コンピュータ、サーバ、又は処理センターによって構築されてよい。いくつかの実施例において、プロセッサ110は、中央処理部又は演算部によって実現されてよい。いくつかの実施例において、メモリ120は、フラッシュメモリ、リードオンリーメモリ、ハードディスク、又は同等性を有する任意の記憶素子を用いて実現されてよい。
いくつかの実施例において、機器学習装置100は、プロセッサ110とメモリ120とを含むことに限定されず、動作及び適用に必要な他の素子を更に含んでよく、例としては、出力インターフェース(例えば、情報を表示するための表示パネル)、入力インターフェース(例えば、タッチパネル、キーボード、マイク、スキャナ、又はフラッシュリーダ)、及び通信回路(例えば、WiFi通信モジュール、Bluetooth通信モジュール、無線通信ネットワーク通信モジュール等)を更に含んでよい。
図1に示すように、プロセッサ110は、メモリ120に記憶された対応するソフトウェア/ファームウェアコマンドプログラムに基づいて分類モデル111を実行するためのものである。
いくつかの実施例において、分類モデル111は、入力されたデータ(例えば、上記のデータ強調画像)を分類することができ、例えば、入力画像の中に車両、顔、ナンバープレート、文字、トーテムオブジェクト、又はその他の画像特徴を有するオブジェクトを検出することができる。分類モデル111は、分類結果に応じて、対応するラベルを生成する。特に説明すべきなのは、分類モデル111は、分類動作を行う際に、その自体のモデルパラメータMPを参照する必要がある。
図1に示すように、メモリ120は、モデルパラメータMPを記憶するためのものである。いくつかの実施例において、モデルパラメータMPは、複数の重みパラメータ内容を含んでよい。
本実施例において、分類モデル111は、複数のニューラルネットワーク構造層を含む。いくつかの実施例において、各層のニューラルネットワーク構造層は、モデルパラメータMPにおける1つの重みパラメータ内容(1つのニューラルネットワーク構造層の動作を決定するためのものである)に対応してもよい。一方、分類モデル111の各ニューラルネットワーク構造層は、互いに独立した重みパラメータ内容に対応してよい。つまり、各層のニューラルネットワーク構造層は、1つの重み値集合に対応してよく、重み値集合は、複数の重み値を含んでよい。
いくつかの実施例において、ニューラルネットワーク構造層は、畳み込み層、プール層、線形整流層、完全結合層、又は他のタイプのニューラルネットワーク構造層であってよい。いくつかの実施例において、分類モデル111は、ニューラルネットワークに関連づけられてよい(例えば、分類モデル111は、深度残差ネットワーク及び完全結合層から構成され、又はEfficentNet及び完全結合層から構成されてよい)。
本開示の一実施例による機器学習方法を示す模式図である図2を併せて参照すると、図1に示される機器学習装置100は、図2の機器学習方法を実行するために使用されてよい。
図2に示すように、まず、工程S210において、メモリ120からモデルパラメータMPを取得して、モデルパラメータMPに基づいて分類モデル111を実行する。一実施例において、メモリ120におけるモデルパラメータMPは、従来のトレーニング経験から得られた平均値、人工的に与えられたデフォルト値、又は乱数値であってよい。
工程S220において、複数のトレーニングサンプルに基づいて、複数のニューラルネットワーク構造層における出力層に対応する第1の損失と、複数のニューラルネットワーク構造層における出力層よりも前に位置する一方に対応する第2の損失を算出する。一実施例において、第1の損失は、分類モデル111のニューラルネットワーク構造層の出力層からプロセッサ110によって生成され、第2の損失は、出力層よりも前のニューラルネットワーク構造層からプロセッサ110によって生成される。いくつかの実施例において、出力層は、少なくとも1つの完全結合層を含んでよい。以下、具体例に合わせて、いくつかの実施例における工程S220の詳細な工程について更に説明する。
工程S230において、第1の損失及び第2の損失に基づいて、モデルパラメータMPに対して複数の更新操作を実行して、分類モデル111をトレーニングする。一実施例において、トレーニングされたモデルパラメータMPを生成するように、プロセッサ110は、更新操作において、第1の損失及び第2の損失に基づいてモデルパラメータMPを更新し、更に、更新されたモデルパラメータMPに基づいて分類モデル111をトレーニングして、トレーニングされた分類モデル111を生成する。以下、具体例に合わせて、いくつかの実施例における工程S230の詳細な工程を更に説明する。
これにより、トレーニングされた分類モデル111は、後のアプリケーションを実行するために使用されてよい。例えば、トレーニングされた分類モデル111は、画像や映像の入力、ストリームにおけるオブジェクト識別、顔識別、音声識別、又は動的検出等に使用され、又は株価データ又は天気情報に関するデータ予測に使用されてよい。
図3及び図4を併せて参照すると、図3は、本開示の一実施例による分類モデル及び損失を示す模式図である。図4は、いくつかの実施例における工程S220の詳細な工程S221~S224Aを示すフローチャートである。
図3に示すように、分類モデル111は、ニューラルネットワーク構造層SL1、SL2、~SLtを含む。いくつかの実施例において、tは正の整数である。一般的に、分類モデル111における総層数は、実際の適用の要求(例えば、分類の精度、分類対象物の複雑さ、入力映像の相違性)に応じて決定されてよい。場合によって、tの一般的な範囲は16~128であってよいが、本開示は特定の層数に限定されない。
例としては、ニューラルネットワーク構造層SL1及びSL2は畳み込み層であってよく、ニューラルネットワーク構造層SL3はプール層であってよく、ニューラルネットワーク構造層SL4及びSL5は畳み込み層であってよく、ニューラルネットワーク構造層SL6はプール層であってよく、ニューラルネットワーク構造層SL7は畳み込み層であってよく、ニューラルネットワーク構造層SL8は線形整流層であってよく、ニューラルネットワーク構造層SLtは完全結合層であってよいが、本開示はこれらに限定されない。
いくつかの実施例において、分類モデル111は複数の残差マップブロックを有してもよく、残差マップブロックの構造を使用することで、tを大幅に低減することができる。以下、分類モデル111のような構成を例として、工程S221~工程S224Aを更に説明する。
なお、説明の便宜上、図3における分類モデル111は例示的な説明に過ぎず、残差マップブロックを有するモデル(例えば、ResNetモデル)を示すが、本開示はこれに限定されない。実際の適用では、分類モデル111は、他のタイプの畳み込みニューラルネットワークであってよい。いくつかの実施例において、分類モデル111はEfficentNetモデルであってよい。
図3及び図4に示すように、工程S221において、プロセッサ110によって、ニューラルネットワーク構造層SLl、SL2、~SLtの出力層SLtから、トレーニングサンプル
Figure 0007307785000001
に基づいて複数の予測ラベル
Figure 0007307785000002
が生成される。注意すべきなのは、nはトレーニングサンプル
Figure 0007307785000003
の数であり、nは
Figure 0007307785000004
予測ラベルの数であり、nは正の整数であってよく、iはn以下の正の整数であってよい。図3に示すように、トレーニングサンプルXiが分類モデル111に入力されると、ニューラルネットワーク構造層SLl、SL2、~SLtの演算により、分類モデル111のニューラルネットワーク構造層SLt(すなわち、出力層)から予測ラベル
Figure 0007307785000005
を生成することができる。同様に、予測ラベル
Figure 0007307785000006
を生成するように、トレーニングサンプル
Figure 0007307785000007
を分類モデル111に入力してよい。
図3及び図4に示すように、工程S222において、プロセッサ110によって比較アルゴリズムが実行されて予測ラベル
Figure 0007307785000008
とトレーニングサンプル
Figure 0007307785000009
の複数のトレーニングラベル
Figure 0007307785000010
とを比較して、第1の損失Llを生成する。図3に示すように、予測ラベル
Figure 0007307785000011
とトレーニングサンプルXiのトレーニングラベルyとを比較して、損失を算出する。同様に、プロセッサ110によって比較アルゴリズムが実行されて予測ラベルとトレーニングラベルとを比較して複数の損失を算出し、且つ、プロセッサ110によって、これらの損失(すなわち、従来の損失関数)に基づいて第1の損失L1を生成する。いくつかの実施例において、第1の損失L1を得るように、プロセッサ110によって予測ラベル
Figure 0007307785000012
及びトレーニングラベル
Figure 0007307785000013
に対してクロスエントロピー算出を実行してもよい。
図3及び図4に示すように、工程S223において、トレーニングサンプル
Figure 0007307785000014
に基づいて分類モデル111から複数の抽出特徴
Figure 0007307785000015
を生成する。図3に示すように、トレーニングサンプルXiが分類モデル111に入力されると、ニューラルネットワーク構造層SL1、SL2、~SLt-1の操作により分類モデル111のニューラルネットワーク構造層Lt-1の人工ニューロンから抽出特徴Hi,1、Hi,2、~Hi,m(mは正の整数で人工ニューロンの数に等しい)を算出してよく、且つ、抽出特徴 i,1 、H i,2 、…H i,m はそれぞれニューラルネットワーク構造層Lt-1における人工ニューロンに対応する。また、抽出特徴 i,1 、H i,2 、…H i,m は、それぞれニューラルネットワーク構造層Lt-1よりも前の何れのニューラルネットワーク構造層における人工ニューロンに対応してもよい。同様に、人工ニューロンからトレーニングサンプル
Figure 0007307785000016
に対応する抽出特徴
Figure 0007307785000017
算出してよい。
注意すべきなのは、抽出特徴
Figure 0007307785000018
とトレーニングラベル
Figure 0007307785000019
との間に偽相関が存在する可能性がある。詳細には、第1の抽出特徴は、第2の抽出特徴及びトレーニングラベルyの何れに対しても因果関係があるが、第2の抽出特徴とトレーニングラベルy同士の間には因果関係がないものとする。これに基づき、第2の抽出特徴及びトレーニングラベルyを関連付けることができる。第2の抽出特徴の数値がラベルの変化に伴い直線的に増加する場合、第2の抽出特徴とトレーニングラベルyとの間には偽相関が存在する。偽相関を引き起こす抽出特徴(すなわち、第1の抽出特徴、第2の抽出特徴、及びトレーニングラベルyの間の関係)が観察され得る場合、偽相関はドミナントである。そうでない場合、偽相関は、リセッシブ(すなわち、第2の抽出特徴とトレーニングラベルyとの間の関係)であると考えられてよい。偽相関は、予測ラベル
Figure 0007307785000020
とトレーニングラベル
Figure 0007307785000021
との間のより大きな差を引き起こす。
例えば、患者の臨床画像が病巣の細胞組織、及び細胞組織と色が類似した骨を有する場合、骨の抽出特徴と病巣のラベルとの間のドミナントな偽相関を引き起こす。別の例では、患者の臨床画像は、典型的には、バックグラウンドを有し、患者の臨床画像における病巣及びバックグラウンドは類似である。従って、これは、バックグラウンドの抽出特徴と病巣のラベルとの間のリセッシブな偽相関を引き起こす。
偽相関を回避するために、統計的独立性を使用してドミナントな偽相関を除去し、及び平均化効果を使用してリセッシブな偽相関を除去することの細部を、以下の段落で更に説明する。
図3及び図4に示すように、工程S224Aにおいて、プロセッサ110によって、抽出特徴間の統計的独立性に基づいて第2の損失L2を算出し、抽出特徴はニューラルネットワーク構造層SL1、SL2、~SLtにおける1つ(すなわち、ニューラルネットワーク構造層SLt-1)に対応する。具体的には、確率変数の統計的独立性は、以下の式(1)で示される。
E(a)=E(a)E(b) (1)
ここで、E(.)はランダム変数の期待値を表し、a及びbはランダム変数であり、p及びqは正の整数である。式(1)により、独立性損失は、以下の式(2)で表すことができる。
independent loss=-|E(a)-E(a)E(b)| (2)
図3に示すように、ランダム変数を抽出特徴
Figure 0007307785000022
に置き換えることで、式(2)は、第2の損失L2(すなわち、抽出特徴
Figure 0007307785000023
間の独立性損失)を表す以下の式(3)に書き換えることができる。
Figure 0007307785000024
ここで、j及びkは正の整数であり、m以下である。式(3)により、抽出特徴
Figure 0007307785000025
から第2の損失L2を算出する。いくつかの実施例において、式(3)の第2の損失に更に重要度値を乗算して第2の損失L2を生成してもよく、重要度値は、0より大きく且つ独立性損失の重要性を制御するハイパーパラメータである。
別の実施例における工程S220の詳細な工程S221~S224Bを示すフローチャートである図5を併せて参照されたい。
注意すべきなのは、図4と図5との相違点は、工程S224Bのみにある。すなわち、工程S224Aを実行して第2の損失を生成することに加えて、工程S224Bを実行して第2の損失を生成してもよい。従って、以下、工程S224Bについてのみ説明し、残りの工程については繰り返して説明しない。
図3及び図5に示すように、工程S224Bにおいて、プロセッサ110によって、抽出特徴とトレーニングサンプルのトレーニングラベルの間の平均処理効果とに基づいて第2の損失L3を算出し、抽出特徴はニューラルネットワーク構造層SL1、SL2、~SLtにおける1つ(すなわち、ニューラルネットワーク構造層SLt-1)に対応する。詳細には、確率変数の平均処理効果(すなわち、因果性)は、以下の式(4)で示される。
Figure 0007307785000026
ここで、p(.)は確率変数の確率を表し、 及び は確率変数であり、
Figure 0007307785000027
は治療を表し、
Figure 0007307785000028
で且つ観察結果であり、
Figure 0007307785000029
で且つ共変ベクトルであり、及び
Figure 0007307785000030
である。
図3に示すように、 及び をトレーニングラベル
Figure 0007307785000031
及び強活性関数により処理された抽出特徴
Figure 0007307785000032
に置き換えることで、式(4)は以下の式(5)のように書き換えられる。
Figure 0007307785000033
ここで、j番目の抽出特徴の損失とは、抽出特徴H1,j、H2,j、~Hn,jに対応する因果的損失(すなわち、平均処理効果損失)であり、
Figure 0007307785000034
とは範囲が
Figure 0007307785000035
の強活性関数である。式(5)より、抽出特徴
Figure 0007307785000036
の平均処理効果を示す第2の損失L3は、以下の式(6)で示される。
Figure 0007307785000037
式(6)により、抽出特徴とトレーニングサンプルのトレーニングラベルとに基づいて第2の損失L3を算出する。いくつかの実施例において、式(6)の第2の損失に、更に別の重要度値を乗算してもよく、他の重要度値は、0より大きく且つ平均処理効果損失の重要性を制御するハイパーパラメータである。
いくつかの実施例における工程S230の詳細な工程S231A~S233を示すフローチャートである図6を併せて参照されたい。
図6に示すように、工程S231Aにおいて、プロセッサ110によって、第1の損失及び第2の損失に基づいて損失差を算出する。詳細には、プロセッサ110によって第1の損失及び第2の損失の間の差分演算を実行して、損失差(すなわち、第1の損失から第2の損失を引く)を生成する。注意すべきなのは、第2の損失は、図4の工程S224A又は図5の工程S224Bから生成してもよい。つまり、第1の損失及び独立損失、又は第1の損失及び平均処理効果損失に基づいて、損失差を算出してよい。
また、第1の損失、図4の工程S224Aで生成した第2の損失、及び図5の工程S224Bで生成した第2の損失に基づいて損失差を算出してよい(より詳細は、以下の段落でいくつかの例によって説明する)。
工程S232では、損失差が収束したかを判断する。いくつかの実施例において、損失差は、収束すると、統計的実験結果から生じた差閾値に近づくか、又はこれに等しくなってよい。
本実施例において、損失差が収束していなければ、工程S233を実行する。工程S233において、プロセッサ110によって、第1の損失及び第2の損失に基づいて分類モデルに対して逆伝搬操作を実行して、モデルパラメータMPを更新する。つまり、第1の損失及び第2の損失に基づく逆伝搬操作によって、モデルパラメータMPから更新されたモデルパラメータを生成する。
これにより、工程S233、S220及びS231Aを継続的に繰り返して、モデルパラメータMPを繰り返しに徐々に更新する。このように、損失差は、差閾値に近づくか又は等しくなるまで、徐々に最小化する(すなわち、第2の損失が徐々に最大化する)。逆に、損失差が収束する場合、機器学習装置100がトレーニングを完了したことを示し、トレーニングされた分類モデル111は、後のアプリケーションを実行するために使用されてよい。
上記実施例に基づき、工程S224Aにおける第2の損失を用いることで、工程S230においてドミナント的な偽相関に属する抽出特徴を除去することができる。また、工程S224Bにおける第2の損失を用いることで、工程S230においてリセッシブな偽相関に属する抽出特徴を除去することができる。
図7を併せて参照すると、図は、いくつかの実施例における工程S224Aの次の追加工程を示すフローチャートである。
図7に示すように、工程S220’Aは、工程S224Bにおける第2の損失の算出と同様に、第3の損失を算出する。つまり、これは、プロセッサ110によって第1の損失が生成した後に、独立損失及び平均処理効果損失が生成することを意味する。工程S220’A及び工程S224Bは同様であるので、その工程については繰り返して説明しない。
別の実施例における工程S230の詳細な工程S231B~S233を示すフローチャートである図8を併せて参照されたい。
注意すべきなのは、図6と図8との相違点は、工程S231Bのみにある。すなわち、工程S231Aを実行して損失差を生成することに加えて、工程S231Bを実行して損失差を生成してもよい。従って、以下、工程S231Bについてのみ説明し、残りの工程については繰り返して説明しない。
図8に示すように、工程S220’を実行した後、工程S231Bを実行する。工程S231Bにおいて、プロセッサ110によって、第1の損失、第2の損失及び第3の損失に基づいて損失差を算出する。詳細には、プロセッサ110によって、第1の損失と第2の損失との間の差分演算を実行して第1の差分値を生成し、次に第1の差分値と第3の損失との間で別の差分演算を実行して損失差を生成する(すなわち、第1の損失から第2の損失を減算し、その後に第3の損失を減算する)。従って、工程S233において、第1の損失、第2の損失及び第3の損失に基づく逆伝搬によって、モデルパラメータMPから更新されたモデルパラメータを生成する。これにより、工程S233、S220及びS231Bを継続的に繰り返して、モデルパラメータMPを繰り返しに徐々に更新する。このように、損失差も、同様に、損失差が差分閾値に近づくか又は等しくなるまで、徐々に最小化する(すなわち、第2の損失及び第3の損失が徐々に最大化する)。
上記実施例に基づき、工程S224Aにおける第2の損失及びS220’における第3の損失を同時に用いることで、工程S230においてドミナント的な偽相関及びリセッシブな偽相関に属する抽出特徴を除去することができる。
図1に示すように、抽出特徴とトレーニングラベルとの間のドミナント偽相関又はリセッシブな偽相関を回避するように、機器学習装置100のトレーニング過程において、第1の損失及び第2の損失に基づいて分類モデル111のモデルパラメータMPを更新し、第2の損失は、独立性損失又は平均処理効果損失であってよい。また、独立性損失と平均処理効果損失を用いてモデルパラメータMPを調整することで、ドミナント偽相関又はリセッシブな偽相関を除去して、分類モデル111の予測精度を大幅に向上させることができる。
コンピュータビジョン及びコンピュータ予測の分野では、深層学習の正確度は、主に、大量のラベルのトレーニングデータに依存する。トレーニングデータの質、数、及びタイプの増加に伴い、分類モデルの性能は、一般に相対的に向上する。しかしながら、分類モデルは、抽出特徴とトレーニングラベルとの間に、常に、ドミナント偽相関又はリセッシブな偽相関が存在する。ドミナント偽相関又はリセッシブな偽相関を除去できれば、効率はより高く、より正確になる。上記の本開示の実施例において、独立性損失及び平均処理効果損失に基づいてモデルを調整し、分類モデルにおけるドミナント偽相関又はリセッシブな偽相関を除去することが提案される。従って、独立性損失及び平均処理効果損失に基づいてモデルパラメータを調整することで、モデルの全体的な性能を向上させることができる。
適用の点において、本開示の機器学習方法及び機器学習システムは、機器視覚、画像分類、データ予測又はデータ分類を有する各種の分野に用いることができ、例としては、この機器学習方法は、正常状態、肺炎、気管支炎、心臓疾患にかかるX線イメージ、又は正常胎児、胎位不正を識別可能な超音波イメージのような医療イメージの分類に用いることができる。機器学習方法は、将来の株データの上昇又は下降を予測するためにも用いることができる。一方、この機器学習方法は、正常な路面、障害物のある路面、及び他の車両のある路面を識別可能な道路状況画像等の自動運転収集の映像の分類にも用いることができる。また、これに類似する機器学習分野もあり、例としては、本開示の機器学習方法及び機器学習システムは、音声スペクトルの識別、スペクトルの識別、ビッグデータの分析、データ特徴の識別等の他の機器学習関連カテゴリにも用いることができる。
本開示の特定の実施例は、かかる上記の実施例をすでに開示したが、これらの実施例は、本開示を制限することを意図していない。様々な代替例および改良例は、本開示の原理及び趣旨から逸脱することなく、関連技術分野における当業者によって本開示において実施され得る。従って、本開示の保護範囲は、添付の特許請求の範囲によって決定される。
100 機器学習装置
110 プロセッサ
120 メモリ
MP モデルパラメータ
111 分類モデル
SL1、SL2、~SLt ニューラルネットワーク構造層
Figure 0007307785000038
L1 第1の損失
L2、L3 第2の損失
S210~S230、S221~S223、S224A、224B、S231A、S231B、S232~S233、S220’ 工程

Claims (14)

  1. プロセッサによってメモリからモデルパラメータを取得して、前記モデルパラメータに基づいて複数のニューラルネットワーク構造層を含む分類モデルを実行する工程と、
    前記プロセッサによって複数のトレーニングサンプルに基づいて、前記複数のニューラルネットワーク構造層における出力層に対応する第1の損失と、前記複数のニューラルネットワーク構造層における前記出力層よりも前に位置するいずれかの層に対応する第2の損失を算出する工程と、
    プロセッサによって前記第1の損失及び前記第2の損失に基づいて前記モデルパラメータに対して複数の更新操作を実行して前記分類モデルをトレーニングする工程と、
    を備え、
    前記複数のトレーニングサンプルに基づいて前記第1の損失と前記第2の損失を算出する前記工程は、
    前記プロセッサによって、前記複数のトレーニングサンプルに基づいて前記分類モデルから複数の抽出特徴を生成する工程と、
    前記プロセッサによって、前記複数の抽出特徴に基づいて前記第2の損失を算出する工程を含み、
    前記複数の抽出特徴は、前記複数のニューラルネットワーク構造層の前記いずれかの層に対応し、前記第2の損失は、前記複数の抽出特徴の期待値と前記複数の抽出特徴の積との差の値である、
    機器学習方法。
  2. 前記複数のトレーニングサンプルに基づいて前記第1の損失及び前記第2の損失を算出する工程は、
    前記プロセッサによって前記複数のトレーニングサンプルに基づいて前記複数のニューラルネットワーク構造層の前記出力層から複数の予測ラベルを生成する工程と、
    前記プロセッサによって前記複数の予測ラベルと前記複数のトレーニングサンプルの複数のトレーニングラベルとを比較して前記第1の損失を算出する工程と、
    を含む請求項1に記載の機器学習方法。
  3. 前記第1の損失及び前記第2の損失に基づいて前記モデルパラメータに対して前記複数の更新操作を実行して前記分類モデルをトレーニングする工程は、
    前記プロセッサによって前記第1の損失及び前記第2の損失に基づいて複数の損失差を算出する工程と、
    前記プロセッサによって前記複数の損失差に基づいて前記分類モデルに対して複数の逆伝搬操作を実行して前記モデルパラメータを更新する工程と、
    を含む請求項に記載の機器学習方法。
  4. 前記プロセッサによって前記複数の抽出特徴、及び前記複数のトレーニングサンプルの複数のトレーニングラベルの間の第3の損失を算出する工程を更に備え、
    前記複数の抽出特徴は、前記複数のニューラルネットワーク構造層のいずれかの層に対応し、
    前記第3の損失は、
    Figure 0007307785000039
    によって算出され、ここで、
    Figure 0007307785000040
    は、前記複数のトレーニングラベルであり、nとiは、正の整数であり、H i,j は、前記抽出特徴であり、
    Figure 0007307785000041
    は、強活性関数である、
    請求項に記載の機器学習方法。
  5. 前記第1の損失及び前記第2の損失に基づいて前記モデルパラメータに対して前記複数の更新操作を実行して前記分類モデルをトレーニングする工程は、
    前記プロセッサによって前記第1の損失、前記第2の損失、及び前記第3の損失に基づいて、複数の損失差を算出する工程と、
    前記プロセッサによって前記複数の損失差に基づいて前記分類モデルに対して複数の逆伝搬操作を実行して前記モデルパラメータを更新する工程と、
    を含む請求項に記載の機器学習方法。
  6. 前記出力層は、少なくとも1つの完全結合層を含み、前記複数のニューラルネットワーク構造層における前記いずれかの層は、少なくとも1つの畳み込み層を含む請求項1に記載の機器学習方法。
  7. 前記分類モデルは、ニューラルネットワークに関連づけられる請求項1に記載の機器学習方法。
  8. 複数の命令及びモデルパラメータを記憶するためのメモリと、
    前記メモリに接続されるプロセッサと、
    を備える機器学習装置であって、
    前記プロセッサは、分類モデルを実行するとともに、
    前記メモリから前記モデルパラメータを取得して、前記モデルパラメータに基づいて複数のニューラルネットワーク構造層を含む前記分類モデルを実行し、
    複数のトレーニングサンプルに基づいて、前記複数のニューラルネットワーク構造層における出力層に対応する第1の損失と、前記複数のニューラルネットワーク構造層における前記出力層よりも前に位置するいずれかの層に対応する第2の損失を算出し、
    前記第1の損失及び前記第2の損失に基づいて前記モデルパラメータに対して複数の更新操作を実行して前記分類モデルをトレーニングするように、
    前記複数の命令を実行するためのものであり、
    前記プロセッサは、さらに、
    前記複数のトレーニングサンプルに基づいて、前記分類モデルから複数の抽出特徴を生成し、
    前記複数の抽出特徴に基づいて前記第2の損失を算出するように構成されており、
    前記複数の抽出特徴は、前記複数のニューラルネットワーク構造層の前記いずれかの層に対応し、前記第2の損失は、前記複数の抽出特徴の期待値と前記複数の抽出特徴の積との差の値である、
    機器学習装置。
  9. 前記プロセッサは、更に、
    前記複数のトレーニングサンプルに基づいて前記複数のニューラルネットワーク構造層の前記出力層から複数の予測ラベルを生成し、及び
    前記第1の損失を算出するように、前記複数の予測ラベルと前記複数のトレーニングサンプルの複数のトレーニングラベルとを比較するためのものである、
    請求項に記載の機器学習装置。
  10. 前記プロセッサは、更に、
    前記第1の損失及び前記第2の損失に基づいて複数の損失差を算出し、及び
    前記複数の損失差に基づいて前記分類モデルに対して複数の逆伝搬操作を実行して前記モデルパラメータを更新するためのもので
    請求項に記載の機器学習装置。
  11. 前記プロセッサは、更に、
    前記複数の抽出特徴と、前記複数のトレーニングサンプルの複数のトレーニングラベルの間の第3の損失を算出するためのものであり、
    前記複数の抽出特徴は、前記複数のニューラルネットワーク構造層のいずれかの層に対応し、
    前記第3の損失は、
    Figure 0007307785000042
    によって算出され、ここで、
    Figure 0007307785000043
    は、前記複数のトレーニングラベルであり、nとiは、正の整数であり、H i,j は、前記抽出特徴であり、
    Figure 0007307785000044
    は、強活性関数である、
    請求項に記載の機器学習装置。
  12. 前記プロセッサは、更に、
    前記第1の損失、前記第2の損失、及び前記第3の損失に基づいて複数の損失差を算出し、及び
    前記複数の損失差に基づいて前記分類モデルに対して複数の逆伝搬操作を実行して前記モデルパラメータを更新するためのものである、
    請求項11に記載の機器学習装置。
  13. 前記出力層は、少なくとも1つの完全結合層を含み、前記複数のニューラルネットワーク構造層におけるいずれかの層は、少なくとも1つの畳み込み層を含む請求項に記載の機器学習装置。
  14. 前記分類モデルは、ニューラルネットワークに関連づけられる請求項に記載の機器学習装置。
JP2021195279A 2020-12-02 2021-12-01 機器学習装置及び方法 Active JP7307785B2 (ja)

Applications Claiming Priority (6)

Application Number Priority Date Filing Date Title
US202063120216P 2020-12-02 2020-12-02
US63/120,216 2020-12-02
US202163152348P 2021-02-23 2021-02-23
US63/152,348 2021-02-23
US17/448,711 US20220172064A1 (en) 2020-12-02 2021-09-24 Machine learning method and machine learning device for eliminating spurious correlation
US17/448,711 2021-09-24

Publications (2)

Publication Number Publication Date
JP2022088341A JP2022088341A (ja) 2022-06-14
JP7307785B2 true JP7307785B2 (ja) 2023-07-12

Family

ID=78820691

Family Applications (1)

Application Number Title Priority Date Filing Date
JP2021195279A Active JP7307785B2 (ja) 2020-12-02 2021-12-01 機器学習装置及び方法

Country Status (5)

Country Link
US (1) US20220172064A1 (ja)
EP (1) EP4009245A1 (ja)
JP (1) JP7307785B2 (ja)
CN (1) CN114648094A (ja)
TW (1) TWI781000B (ja)

Families Citing this family (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN116976221B (zh) * 2023-08-10 2024-05-17 西安理工大学 基于冲蚀特性的堰塞体溃决峰值流量预测方法及存储介质

Citations (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
JP2019032807A (ja) 2017-08-04 2019-02-28 富士電機株式会社 要因分析システム、要因分析方法およびプログラム
JP2019096006A (ja) 2017-11-21 2019-06-20 キヤノン株式会社 情報処理装置、情報処理方法
WO2020090651A1 (ja) 2018-10-29 2020-05-07 日本電信電話株式会社 音響モデル学習装置、モデル学習装置、それらの方法、およびプログラム
CN111476363A (zh) 2020-03-13 2020-07-31 清华大学 区分化变量去相关的稳定学习方法及装置
JP2020135465A (ja) 2019-02-20 2020-08-31 株式会社東芝 学習装置、学習方法、プログラムおよび認識装置

Family Cites Families (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US10108850B1 (en) * 2017-04-24 2018-10-23 Intel Corporation Recognition, reidentification and security enhancements using autonomous machines
US20200012890A1 (en) * 2018-07-06 2020-01-09 Capital One Services, Llc Systems and methods for data stream simulation
US11954881B2 (en) * 2018-08-28 2024-04-09 Apple Inc. Semi-supervised learning using clustering as an additional constraint
CN109766954B (zh) * 2019-01-31 2020-12-04 北京市商汤科技开发有限公司 一种目标对象处理方法、装置、电子设备及存储介质
US11699070B2 (en) * 2019-03-05 2023-07-11 Samsung Electronics Co., Ltd Method and apparatus for providing rotational invariant neural networks

Patent Citations (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
JP2019032807A (ja) 2017-08-04 2019-02-28 富士電機株式会社 要因分析システム、要因分析方法およびプログラム
JP2019096006A (ja) 2017-11-21 2019-06-20 キヤノン株式会社 情報処理装置、情報処理方法
WO2020090651A1 (ja) 2018-10-29 2020-05-07 日本電信電話株式会社 音響モデル学習装置、モデル学習装置、それらの方法、およびプログラム
JP2020135465A (ja) 2019-02-20 2020-08-31 株式会社東芝 学習装置、学習方法、プログラムおよび認識装置
CN111476363A (zh) 2020-03-13 2020-07-31 清华大学 区分化变量去相关的稳定学习方法及装置

Also Published As

Publication number Publication date
TWI781000B (zh) 2022-10-11
EP4009245A1 (en) 2022-06-08
JP2022088341A (ja) 2022-06-14
US20220172064A1 (en) 2022-06-02
TW202223770A (zh) 2022-06-16
CN114648094A (zh) 2022-06-21

Similar Documents

Publication Publication Date Title
EP3872705B1 (en) Detection model training method and apparatus and terminal device
US11151417B2 (en) Method of and system for generating training images for instance segmentation machine learning algorithm
US11836931B2 (en) Target detection method, apparatus and device for continuous images, and storage medium
CN107403426B (zh) 一种目标物体检测方法及设备
JP2020071883A (ja) モデル訓練方法、データ認識方法及びデータ認識装置
US20230267381A1 (en) Neural trees
US20220188636A1 (en) Meta pseudo-labels
CN113435430B (zh) 基于自适应时空纠缠的视频行为识别方法、系统、设备
CN113826125A (zh) 使用无监督数据增强来训练机器学习模型
CN114692732B (zh) 一种在线标签更新的方法、系统、装置及存储介质
CN113380413A (zh) 一种构建无效再通fr预测模型的方法和装置
JP7226696B2 (ja) 機械学習方法、機械学習システム及び非一時的コンピュータ可読記憶媒体
CN115511069A (zh) 神经网络的训练方法、数据处理方法、设备及存储介质
CN113537630A (zh) 业务预测模型的训练方法及装置
JP7307785B2 (ja) 機器学習装置及び方法
CN116912568A (zh) 基于自适应类别均衡的含噪声标签图像识别方法
CN117765432A (zh) 一种基于动作边界预测的中学理化生实验动作检测方法
CN110909860A (zh) 神经网络参数初始化的方法和装置
CN114462526A (zh) 一种分类模型训练方法、装置、计算机设备及存储介质
Yan Convolutional Neural Networks and Recurrent Neural Networks
CN112686277A (zh) 模型训练的方法和装置
CA3070816A1 (en) Method of and system for generating training images for instance segmentation machine learning algorithm
CN115658307B (zh) 一种基于压缩数据直接计算的智能负载处理方法和系统
CN114708471B (zh) 跨模态图像生成方法、装置、电子设备与存储介质
EP4116874A1 (en) A method for training a machine learning model to recognize a pattern in an input signal

Legal Events

Date Code Title Description
A621 Written request for application examination

Free format text: JAPANESE INTERMEDIATE CODE: A621

Effective date: 20220221

A131 Notification of reasons for refusal

Free format text: JAPANESE INTERMEDIATE CODE: A131

Effective date: 20230124

A977 Report on retrieval

Free format text: JAPANESE INTERMEDIATE CODE: A971007

Effective date: 20230125

RD02 Notification of acceptance of power of attorney

Free format text: JAPANESE INTERMEDIATE CODE: A7422

Effective date: 20230419

A521 Request for written amendment filed

Free format text: JAPANESE INTERMEDIATE CODE: A523

Effective date: 20230420

TRDD Decision of grant or rejection written
A01 Written decision to grant a patent or to grant a registration (utility model)

Free format text: JAPANESE INTERMEDIATE CODE: A01

Effective date: 20230627

A61 First payment of annual fees (during grant procedure)

Free format text: JAPANESE INTERMEDIATE CODE: A61

Effective date: 20230630

R150 Certificate of patent or registration of utility model

Ref document number: 7307785

Country of ref document: JP

Free format text: JAPANESE INTERMEDIATE CODE: R150