JP2022042487A - ドメイン適応型ニューラルネットワークの訓練方法 - Google Patents
ドメイン適応型ニューラルネットワークの訓練方法 Download PDFInfo
- Publication number
- JP2022042487A JP2022042487A JP2021136658A JP2021136658A JP2022042487A JP 2022042487 A JP2022042487 A JP 2022042487A JP 2021136658 A JP2021136658 A JP 2021136658A JP 2021136658 A JP2021136658 A JP 2021136658A JP 2022042487 A JP2022042487 A JP 2022042487A
- Authority
- JP
- Japan
- Prior art keywords
- target data
- loss function
- feature
- data
- label
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000000034 method Methods 0.000 title claims abstract description 75
- 238000012549 training Methods 0.000 title claims abstract description 70
- 238000013528 artificial neural network Methods 0.000 title claims abstract description 69
- 230000003044 adaptive effect Effects 0.000 title claims abstract description 39
- 230000006870 function Effects 0.000 claims abstract description 159
- 238000000605 extraction Methods 0.000 claims description 26
- 239000000284 extract Substances 0.000 claims description 21
- 230000006978 adaptation Effects 0.000 description 15
- 238000004821 distillation Methods 0.000 description 9
- 230000008569 process Effects 0.000 description 9
- 238000013140 knowledge distillation Methods 0.000 description 7
- 238000010586 diagram Methods 0.000 description 6
- 238000004088 simulation Methods 0.000 description 6
- 238000012545 processing Methods 0.000 description 4
- 230000011218 segmentation Effects 0.000 description 4
- 238000004422 calculation algorithm Methods 0.000 description 3
- 238000002372 labelling Methods 0.000 description 3
- 230000003287 optical effect Effects 0.000 description 3
- 230000005540 biological transmission Effects 0.000 description 2
- 238000004891 communication Methods 0.000 description 2
- 238000013527 convolutional neural network Methods 0.000 description 2
- 238000012986 modification Methods 0.000 description 2
- 230000004048 modification Effects 0.000 description 2
- 239000004065 semiconductor Substances 0.000 description 2
- 230000009471 action Effects 0.000 description 1
- 238000003915 air pollution Methods 0.000 description 1
- 238000004364 calculation method Methods 0.000 description 1
- 230000008859 change Effects 0.000 description 1
- 239000003086 colorant Substances 0.000 description 1
- 230000000052 comparative effect Effects 0.000 description 1
- 238000013461 design Methods 0.000 description 1
- 230000000694 effects Effects 0.000 description 1
- 230000007613 environmental effect Effects 0.000 description 1
- 238000001914 filtration Methods 0.000 description 1
- 230000014509 gene expression Effects 0.000 description 1
- 238000011068 loading method Methods 0.000 description 1
- 238000004519 manufacturing process Methods 0.000 description 1
- 239000013618 particulate matter Substances 0.000 description 1
- 238000012360 testing method Methods 0.000 description 1
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
- G06F18/241—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
- G06F18/241—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
- G06F18/2415—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches based on parametric or probabilistic models, e.g. based on likelihood ratio or false acceptance rate versus a false rejection rate
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/047—Probabilistic or stochastic networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- Evolutionary Computation (AREA)
- Life Sciences & Earth Sciences (AREA)
- Artificial Intelligence (AREA)
- General Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- Computing Systems (AREA)
- Software Systems (AREA)
- Molecular Biology (AREA)
- Computational Linguistics (AREA)
- Biophysics (AREA)
- Biomedical Technology (AREA)
- Mathematical Physics (AREA)
- General Health & Medical Sciences (AREA)
- Health & Medical Sciences (AREA)
- Probability & Statistics with Applications (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Bioinformatics & Computational Biology (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Evolutionary Biology (AREA)
- Image Analysis (AREA)
Abstract
【課題】ドメイン適応型ニューラルネットワークの訓練方法及び装置を提供する。【解決手段】方法は、ソースデータ及びターゲットデータの特徴に基づいてターゲットデータの第1のラベルを予測し、ソースデータセットの各クラスのクラス中心とターゲットデータの特徴との間の距離に基づいて第2のラベルを決定し、ターゲットデータセットから第1のラベルと第2のラベルとが同一のターゲットデータを選択し、ラベルをターゲットデータの疑似ラベルとし、ターゲットデータに基づいてターゲットデータセットの各クラスのクラス中心を計算し、ターゲットデータセットのクラス中心間の距離に基づいて第1の損失関数を構築し、ターゲットデータ及び疑似ラベルに基づいて第2の損失関数を構築し、ソースデータセットのソースデータ及びターゲットデータについて第3の損失関数を構築し、第1~第3の損失関数に基づいてニューラルネットワークを訓練する。【選択図】図3
Description
本発明は、一般的に、ドメイン適応(domain adaptation)関し、具体的には、教師なしドメイン適応のためのニューラルネットワーク及びその訓練(トレーニング)方法に関する。
教師なしドメイン適応とは、ラベル付けされたソースデータを用いて訓練されたモデルをラベル付けされていないデータのターゲットドメインに移行すると共に、該モデルのターゲットドメインでのパフォーマンスを可能な限り維持することを意味する。ソースドメインとターゲットドメインとの間にデータセットの偏差があり、ターゲットドメインにラベル付きデータが不十分であるため、ラベル付けされたソースデータを用いて訓練されたモデルのパフォーマンスがターゲットドメインで低下する場合がある。教師なしドメイン適応の訓練プロセスは、ソースドメインのラベル付きデータ及びターゲットドメインのラベルなしデータの両方を利用することで、ドメイン間の差異を効果的に軽減し、モデルの堅牢性を向上させることができる。
現在、教師なしドメイン適応の主流の方法は、例えば敵対的訓練などのドメイン不変の特徴(domain-invariant features)の学習方法を含む。典型的な敵対的訓練方法は、ドメイン敵対的ニューラルネットワークである。このニューラルネットワークでは、特徴抽出ネットワークの後にドメイン識別器を追加して、特徴がソースドメインからのものであるか、それともターゲットドメインからのものであるかを判断し、特徴抽出ネットワークとドメイン識別器との間に勾配反転層を追加する。ドメイン識別器の損失関数を最小化する場合、特徴抽出ネットワークは、勾配反転層によりドメイン不変の特徴を学習することができる。
さらに、知識の蒸留(knowledge distillation)は、最近、教師なしドメイン適応に導入され、次のような多くの新しい方法が提案されている。例えば、自己アンサンブリングの教師モデルを使用して、生徒モデルにターゲットドメインのラベルなしデータを学習させる。自己アンサンブリングの教師モデルを利用して、より正確なターゲットデータの疑似ラベルを取得する。ソースデータからターゲットデータに類似するデータを蒸留して、事前訓練モデルを微調整する。セマンティックレベル(クラスレベル)でソースドメインの特徴とターゲットドメインの特徴とのアラインメントを行い、即ち、ソースドメインとターゲットドメインとの同一のクラスの平均特徴(クラス中心)を接近させる。
以下は、これらの従来の方法を簡単に紹介する。
図1は、典型的なドメイン敵対的ニューラルネットワークの構造を示している。図1に示すように、ドメイン敵対的ニューラルネットワークは、特徴抽出器F、分類器Cs、及びドメイン識別器Dを含む。ドメイン識別器Dは、勾配反転層を介して特徴抽出器Fに接続され、勾配反転層は、勾配に特定の負の数を乗算して、特徴抽出器Fに送り返す。Isはラベル付けされたソースデータを表し、Itはラベル付けされていないターゲットデータを表し、両者は何れも特徴抽出器Fに入力される。ソースデータについて特徴抽出器Fにより抽出された特徴は、ソースデータのクラスを予測するために、分類器Csに入力される。また、ソースデータとターゲットデータの両方について特徴抽出器Fにより抽出された特徴が何れもドメイン識別器Dに入力され、ドメイン識別器Dは、入力された特徴に基づいて、現在処理されているデータがソースドメインからのものであるか、それともターゲットドメインからのものであるかを識別する。ドメイン敵対的ニューラルネットワークの訓練では、ソースドメインの分類交差エントロピー損失関数Lc及びドメイン識別のバイナリ交差エントロピー損失関数Ladvを使用して、損失関数Lc及びLadvを最小化するように、標準の逆伝播アルゴリズムに従って訓練を行うことで、特徴抽出器Fにドメイン不変の特徴を学習させる。
図2は自己アンサンブリング型教師モデルの構造を示し、ここで、生徒ネットワークのパラメータの指数移動平均を使用して教師ネットワークを構築する。図2では、xSiはラベル付けされたソースデータを表し、xTiはラベル付けされていないターゲットデータを表し、ySiはソースデータの真のラベルを表し、zTiは生徒ネットワークによるターゲットデータの予測確率を表し、
(外1)
は教師ネットワークによるターゲットデータの予測確率を表す。
(外1)
は教師ネットワークによるターゲットデータの予測確率を表す。
このスキームの前提としては、教師ネットワークの予測正確率が生徒ネットワークの予測正確率よりも高いと仮定する。生徒ネットワークが教師ネットワークの予測確率からターゲットデータの隠れた知識を学習できるため、該スキームは知識の蒸留である。ソースデータxSiについて、生徒ネットワークの予測確率zTiと真のラベルySiに基づく交差エントロピー損失関数が採用されている。ターゲットデータxTiについて、教師ネットワークの予測確率
(外2)
と生徒ネットワークの予測確率zTiの平均二乗誤差が損失関数として使用されている。そして、上記の2つの損失関数に対して重み付け加算を行い、最終的な損失関数を取得する。
(外2)
と生徒ネットワークの予測確率zTiの平均二乗誤差が損失関数として使用されている。そして、上記の2つの損失関数に対して重み付け加算を行い、最終的な損失関数を取得する。
ここで、Xs,kはソースドメインXsにおけるk番目のクラスに属する全てのデータサンプル(真のラベルに基づいて決定される)を表し、Xt,kはターゲットドメインXtにおけるk番目のクラスとしてラベル付けされた全てのデータサンプルを表す(疑似ラベルに基づいて決定される)。λs,kは、ソースドメインにおけるk番目のクラスのクラス中心、即ち、k番目のクラスに属する全てのソースデータの特徴Fの平均値を表す。同様に、λt,kは、ターゲットドメインにおけるk番目のクラスのクラス中心、即ち、k番目のクラスとしてラベル付けされた全てのターゲットデータの特徴Fの平均値を表す。ターゲットデータの疑似ラベルは、分類器を使用してターゲットデータのクラスを予測することで取得される。数式(1)に示すセマンティックアラインメント損失関数La(Xs,Xt)は、ソースドメインとターゲットドメインにおける同一のクラスのクラス中心の間の距離を表す。
上記の方法は良好な結果を達成しているが、まだ改善する必要な幾つかの問題がある。まず、セマンティックアラインメントでは、ターゲットデータの疑似ラベルの正確さは、ターゲットドメインにおけるクラス中心に大きな影響を与える。分類境界近傍にある一部のデータについて、疑似ラベルが誤っていると、クラス中心の計算結果に大きな偏差が発生する。また、比較学習では、誤った疑似ラベルは、クラス内のデータサンプルのクラスタリング及びクラス間のデータサンプルの分離の制約を損なう。さらに、自己アンサンブリングの平均教師モデルでは、指数移動平均に固定の減衰率が使用される場合が多いが、現在のモデルの性能が可変であるたね、固定の減衰率では現在のモデルの性能に応じてアンサンブリングの速度を調整できない。さらに、蒸留データを使用して微調整を行う場合、この方法では2つの段階が必要であり、中間の切り替え動作が追加され、訓練を一括的に完了することができない。
本発明は、ドメイン適応型ニューラルネットワークの訓練方法及び装置を提供する。
本発明の1つの態様では、コンピュータにより実行される、ドメイン適応型ニューラルネットワークを訓練する方法であって、前記ドメイン適応型ニューラルネットワークは、第1の特徴抽出部、第1の分類部、及び識別部を含み、前記コンピュータは、命令が記憶されたメモリ、及びプロセッサを含み、前記命令は、前記プロセッサにより実行される際に、前記プロセッサに前記方法を実行させ、前記方法は、前記第1の特徴抽出部が、ラベル付けされたソースデータセットにおけるソースデータについて第1の特徴を抽出するステップと、前記第1の分類部が、前記第1の特徴に基づいて、前記ソースデータが複数のクラスのうちの各クラスに属する確率を予測するステップと、前記第1の特徴抽出部が、ラベル付けされていないターゲットデータセットにおけるターゲットデータについて第2の特徴を抽出するステップと、前記第1の分類部が、前記第2の特徴に基づいて、前記ターゲットデータが前記各クラスに属する確率を予測し、最大の確率に対応するクラスを前記ターゲットデータの第1のラベルとして決定するステップと、ソースデータセットの前記各クラスについてのクラス中心と前記ターゲットデータの特徴との間の距離を計算し、距離が最も近いクラス中心に対応するクラスを前記ターゲットデータの第2のラベルとして決定するステップと、前記ターゲットデータセットから、決定された前記第1のラベルと前記第2のラベルとが同一であるターゲットデータを選択するステップであって、前記第1のラベル又は前記第2のラベルは、選択されたターゲットデータの疑似ラベルとされる、ステップと、選択されたターゲットデータに基づいて、前記ターゲットデータセットの前記各クラスについてのクラス中心を計算するステップと、前記ソースデータセットのクラス中心と計算されたターゲットデータセットのクラス中心との間の距離に基づいて、第1の損失関数を構築するステップと、選択されたターゲットデータ及びその疑似ラベルに基づいて、第2の損失関数を構築するステップと、前記ソースデータセットにおけるソースデータ及び選択されたターゲットデータについて、第3の損失関数を構築するステップと、前記第1の損失関数、前記第2の損失関数及び前記第3の損失関数に基づいて前記ドメイン適応型ニューラルネットワークを訓練するステップと、を含む、方法を提供する。
本発明のもう1つの態様では、ドメイン適応型ニューラルネットワークを訓練する装置であって、前記ドメイン適応型ニューラルネットワークは、ラベル付けされたソースデータセットにおけるソースデータについて第1の特徴を抽出し、ラベル付けされていないターゲットデータセットにおけるターゲットデータについて第2の特徴を抽出する第1の特徴抽出部と、前記第1の特徴に基づいて、前記ソースデータが複数のクラスのうちの各クラスに属する確率を予測し、前記第2の特徴に基づいて、前記ターゲットデータが前記各クラスに属する確率を予測し、最大の確率に対応するクラスを前記ターゲットデータの第1のラベルとして決定する第1の分類部と、前記第1の特徴及び前記第2の特徴に基づいて、現在入力されたデータがソースデータである確率を決定する識別部と、を含み、前記装置は、プログラムが記憶されたメモリと、1つ又は複数のプロセッサと、を含み、前記プロセッサは、前記プログラムを実行することで、ソースデータセットの前記各クラスについてのクラス中心と前記ターゲットデータの特徴との間の距離を計算し、距離が最も近いクラス中心に対応するクラスを前記ターゲットデータの第2のラベルとして決定するステップと、前記ターゲットデータセットから、決定された前記第1のラベルと前記第2のラベルとが同一であるターゲットデータを選択するステップであって、前記第1のラベル又は前記第2のラベルは、選択されたターゲットデータの疑似ラベルとされる、ステップと、選択されたターゲットデータに基づいて、前記ターゲットデータセットの前記各クラスについてのクラス中心を計算するステップと、前記ソースデータセットのクラス中心と計算されたターゲットデータセットのクラス中心との間の距離に基づいて、第1の損失関数を構築するステップと、選択されたターゲットデータ及びその疑似ラベルに基づいて、第2の損失関数を構築するステップと、前記ソースデータセットにおけるソースデータ及び選択されたターゲットデータについて、第3の損失関数を構築するステップと、前記第1の損失関数、前記第2の損失関数及び前記第3の損失関数に基づいて前記ドメイン適応型ニューラルネットワークを訓練するステップと、を実行する、装置を提供する。
本発明のもう1つの態様では、ドメイン適応型ニューラルネットワークを訓練するプログラムが記憶された記憶媒体であって、前記ドメイン適応型ニューラルネットワークは、第1の特徴抽出部、第1の分類部、及び識別部を含み、前記プログラムがコンピュータにより実行される際に、前記コンピュータに、前記第1の特徴抽出部が、ラベル付けされたソースデータセットにおけるソースデータについて第1の特徴を抽出するステップと、前記第1の分類部が、前記第1の特徴に基づいて、前記ソースデータが複数のクラスのうちの各クラスに属する確率を予測するステップと、前記第1の特徴抽出部が、ラベル付けされていないターゲットデータセットにおけるターゲットデータについて第2の特徴を抽出するステップと、前記第1の分類部が、前記第2の特徴に基づいて、前記ターゲットデータが前記各クラスに属する確率を予測し、最大の確率に対応するクラスを前記ターゲットデータの第1のラベルとして決定するステップと、ソースデータセットの前記各クラスについてのクラス中心と前記ターゲットデータの特徴との間の距離を計算し、距離が最も近いクラス中心に対応するクラスを前記ターゲットデータの第2のラベルとして決定するステップと、前記ターゲットデータセットから、決定された前記第1のラベルと前記第2のラベルとが同一であるターゲットデータを選択するステップであって、前記第1のラベル又は前記第2のラベルは、選択されたターゲットデータの疑似ラベルとされる、ステップと、選択されたターゲットデータに基づいて、前記ターゲットデータセットの前記各クラスについてのクラス中心を計算するステップと、前記ソースデータセットのクラス中心と計算されたターゲットデータセットのクラス中心との間の距離に基づいて、第1の損失関数を構築するステップと、選択されたターゲットデータ及びその疑似ラベルに基づいて、第2の損失関数を構築するステップと、前記ソースデータセットにおけるソースデータ及び選択されたターゲットデータについて、第3の損失関数を構築するステップと、前記第1の損失関数、前記第2の損失関数及び前記第3の損失関数に基づいて前記ドメイン適応型ニューラルネットワークを訓練するステップと、を含む、方法を実行させる、記憶媒体を提供する。
図3は、本発明に係る教師なしドメイン適応のニューラルネットワークの構造を概略的に示している。図3に示すように、該ニューラルネットワークは、図1を参照しながら説明されたドメイン敵対的ニューラルネットワークを含み、該ニューラルネットワークは、第1の特徴抽出器310、第1の分類器320、ドメイン識別器330、及び勾配反転層(図示せず)を含む。また、ニューラルネットワークは、第2の特徴抽出器310_T及び第2の分類器320_Tをさらに含む。なお、既存の技術として、図3における第1の特徴抽出器310、第2の特徴抽出器310_T、第1の分類器320、第2の分類器320_T、及びドメイン識別器330は何れも畳み込みニューラルネットワークにより実現されてもよい。本明細書では、これらのユニットを実現する畳み込みニューラルネットワークの構造について詳細に説明しない。
第1の特徴抽出器310及び第1の分類器320は、生徒ネットワークを構成し、第2の特徴抽出器310_T及び第2の分類器320_Tは、教師ネットワークを構成する。第2の(教師)特徴抽出器310_Tのパラメータは、第1の(生徒)特徴抽出器310のパラメータの指数移動平均であり、第2の(教師)分類器320_Tのパラメータは、第1の(生徒)分類器320のパラメータの指数移動平均である。
ソースデータXs及びターゲットデータXtは、第1の特徴抽出器310及び第2の特徴抽出器310_Tのそれぞれに入力される。第1の特徴抽出器310は、ソースデータXs及びターゲットデータXtについて抽出された特徴を第1の分類器320に入力し、第2の特徴抽出器310_Tは、ソースデータXs及びターゲットデータXtについて抽出された特徴を第2の分類器320_Tに入力する。
図3に示されるドメイン適応型ニューラルネットワークの訓練では、本発明は、複数の損失関数を提案し、以下はその詳細を説明する。
本発明の1つの態様では、ターゲットデータの疑似ラベルの精度を向上させるための投票スキーム(voting scheme)が提案される。投票スキームとは、少なくとも2つの予測方式を使用してターゲットデータの予測ラベルに投票することを意味する。例えば、ターゲットデータ
(外3)
について、分類器を利用してそのクラスラベルを予測して予測結果
(外4)
を取得する。また、クラス中心最近傍アルゴリズムを使用してそのラベルを予測して、以下の数式(2)及び(3)に示すように、予測結果ldを取得する。
(外3)
について、分類器を利用してそのクラスラベルを予測して予測結果
(外4)
を取得する。また、クラス中心最近傍アルゴリズムを使用してそのラベルを予測して、以下の数式(2)及び(3)に示すように、予測結果ldを取得する。
ここで、λs,kは、ソースドメインにおけるk番目のクラスのクラス中心、即ち、k番目のクラスに属する全てのソースデータの特性の平均値を表し、Kは、ソースドメインにおける全てのクラスの数を表し、ldは、ソースドメインの全てのK個のクラス中心のうち、ターゲットデータ
(外5)
に最も近いクラス中心に対応するクラスを表す。
(外5)
に最も近いクラス中心に対応するクラスを表す。
予測ラベルlcと予測ラベルldとが一致する場合、該ターゲットデータ
(外6)
を選択し、予測ラベルlc又はldを該ターゲットデータ
(外7)
の疑似ラベルとする。予測ラベルlcと予測ラベルldとが一致しない場合、該ターゲットデータ
(外8)
を無視する。選択された全てのターゲットデータ
(外9)
は、最適ターゲットデータセット
(外10)
を構成する。分類器予測のみを実行する場合、又はクラス中心最近傍予測のみを実行する場合に比べて、このようにフィルタリングにより選択されたデータセット
(外11)
における各ターゲットデータの疑似ラベルの正確率が高くなる。従って、本発明に係る投票スキームは、予測結果がより正確なターゲットデータを効果的に選別することができる。
(外6)
を選択し、予測ラベルlc又はldを該ターゲットデータ
(外7)
の疑似ラベルとする。予測ラベルlcと予測ラベルldとが一致しない場合、該ターゲットデータ
(外8)
を無視する。選択された全てのターゲットデータ
(外9)
は、最適ターゲットデータセット
(外10)
を構成する。分類器予測のみを実行する場合、又はクラス中心最近傍予測のみを実行する場合に比べて、このようにフィルタリングにより選択されたデータセット
(外11)
における各ターゲットデータの疑似ラベルの正確率が高くなる。従って、本発明に係る投票スキームは、予測結果がより正確なターゲットデータを効果的に選別することができる。
なお、上述した分類器予測及びクラス中心最近傍予測は、単なる少なくとも2つの異なる予測方式の一例であり、本発明はこれに限定されず、当業者は他の適切な予測方式を容易に想到できる。
次に、最適ターゲットデータセット
(外12)
に基づいて、図3に示すニューラルネットワークを訓練するためのセマンティックアラインメント損失関数La(図3には示されていない)を構築し、該セマンティックアラインメント損失関数Laは、第1の損失関数とも称される。具体的には、K個の所定のクラスのうちのk番目のクラスを一例として、まず、数式(2)に従ってソースドメインにおけるk番目のクラスのクラス中心λs,kを計算し、以下の数式(4)に従って最適ターゲットデータセット
(外13)
におけるk番目のクラスのクラス中心λt,kを計算し、数式(5)に従ってクラス中心λs,kとクラス中心λt,kとの間の距離
(外14)
を計算する。このように、全てのk個のクラスについて、ソースドメインのクラス中心とターゲットドメインのクラス中心との間の距離を、セマンティックアラインメント損失関数としてそれぞれ計算する。訓練では、距離
(外15)
を最小化することを目標とする。
(外12)
に基づいて、図3に示すニューラルネットワークを訓練するためのセマンティックアラインメント損失関数La(図3には示されていない)を構築し、該セマンティックアラインメント損失関数Laは、第1の損失関数とも称される。具体的には、K個の所定のクラスのうちのk番目のクラスを一例として、まず、数式(2)に従ってソースドメインにおけるk番目のクラスのクラス中心λs,kを計算し、以下の数式(4)に従って最適ターゲットデータセット
(外13)
におけるk番目のクラスのクラス中心λt,kを計算し、数式(5)に従ってクラス中心λs,kとクラス中心λt,kとの間の距離
(外14)
を計算する。このように、全てのk個のクラスについて、ソースドメインのクラス中心とターゲットドメインのクラス中心との間の距離を、セマンティックアラインメント損失関数としてそれぞれ計算する。訓練では、距離
(外15)
を最小化することを目標とする。
最適ターゲットデータセット
(外16)
におけるターゲットデータの疑似ラベルがより高い精度を有するため、データセット
(外17)
を使用して算出されたターゲットドメインのクラス中心λt,kがより正確であり、セマンティクスアラインメントロス損失関数を向上させることができる。
(外16)
におけるターゲットデータの疑似ラベルがより高い精度を有するため、データセット
(外17)
を使用して算出されたターゲットドメインのクラス中心λt,kがより正確であり、セマンティクスアラインメントロス損失関数を向上させることができる。
さらに、最適ターゲットデータセット
(外18)
におけるターゲットデータ及びその疑似ラベルを使用して、図3に示す第1の分類器320を訓練するための交差エントロピー損失関数(図3における
(外19)
)を構築してもよく、該交差エントロピー損失関数は第2の損失関数とも称され、具体的には、以下の数式(6)に示すものである。
(外18)
におけるターゲットデータ及びその疑似ラベルを使用して、図3に示す第1の分類器320を訓練するための交差エントロピー損失関数(図3における
(外19)
)を構築してもよく、該交差エントロピー損失関数は第2の損失関数とも称され、具体的には、以下の数式(6)に示すものである。
従来技術では、通常、真のラベルを有するソースデータのみを使用して第1の分類器320を訓練するが、本発明では、最適ターゲットデータセット
(外23)
におけるターゲットデータの疑似ラベルの精度が比較的に高いため、本発明は、最適ターゲットデータセット
(外24)
をさらに使用して第1の分類器320を訓練し、これによって、ネットワークモデルのターゲットデータへの認識能力を向上させることができる。
(外23)
におけるターゲットデータの疑似ラベルの精度が比較的に高いため、本発明は、最適ターゲットデータセット
(外24)
をさらに使用して第1の分類器320を訓練し、これによって、ネットワークモデルのターゲットデータへの認識能力を向上させることができる。
さらに、最適ターゲットデータセット
(外25)
におけるターゲットデータをソースデータと共に対比学習に使用することで、次の効果を実現することができる。クラス内の特徴を収束させるように制約すると共に、異なるクラスの特徴間の距離が大きくなるようにクラス間の特徴を離間させる。ここで、例えば数式(7)に示す対比学習損失関数Lcon(図3に示されていない)を構築してもよく、該対比学習損失関数Lconは第3の損失関数とも称される。
(外25)
におけるターゲットデータをソースデータと共に対比学習に使用することで、次の効果を実現することができる。クラス内の特徴を収束させるように制約すると共に、異なるクラスの特徴間の距離が大きくなるようにクラス間の特徴を離間させる。ここで、例えば数式(7)に示す対比学習損失関数Lcon(図3に示されていない)を構築してもよく、該対比学習損失関数Lconは第3の損失関数とも称される。
ここで、xi又はxjは、ソースデータセット及び最適ターゲットデータセット
(外26)
におけるデータサンプルを表し、f(xi)及びf(xj)は、データサンプルの特徴を表す。δijは、指示変数であり、xiとxjが同一のクラスのデータである場合、δijは1であり、xiとxjが異なるクラスのデータである場合、δijは0である。d(f(xi),f(xj))は、データxiの特徴とデータxjとの間の距離を表す。mは定数であり、例えば、m=3。
(外26)
におけるデータサンプルを表し、f(xi)及びf(xj)は、データサンプルの特徴を表す。δijは、指示変数であり、xiとxjが同一のクラスのデータである場合、δijは1であり、xiとxjが異なるクラスのデータである場合、δijは0である。d(f(xi),f(xj))は、データxiの特徴とデータxjとの間の距離を表す。mは定数であり、例えば、m=3。
上述したように、教師なし領域適応に使用される現在の知識蒸留方法は、指数移動平均を使用して教師ネットワークを構築するが、その減衰率は通常固定値に設定されているため、性能が優れた教師ネットワークを取得することは困難である。具体的には、指数移動平均とは、以下の数式(8)に示すように、特定の減衰率に基づいて教師ネットワークのパラメータをゆっくりと更新することを意味する。
ここで、Sは、生徒ネットワークの現在のパラメータを表し、Ttは、教師ネットワークの現在のパラメータ(更新後のパラメータ)を表し、Tt-1は、教師ネットワークの前のパラメータ(未更新のパラメータ)を表し、減衰率decayは、通常0.99に固定的に設定されている。
本発明のもう1つの態様では、本発明は、教師モデルの性能を改善するために、自己学習の減衰率を提案している。「自己学習」とは、減衰率が学習可能なパラメータ又は学習済みネットワーク(learnt network)の出力であることを意味する。本発明では、微分可能な変数を減衰率として使用し、或いは、1つの全結合層の出力を減衰率として使用してもよい。後者の場合、例えば該全結合層が出力層と並列に出力層の直前の層に接続されるように、該全結合層を第2の分類器320_Tの出力層と同一のレベルに設定してもよい。このように設定された減衰率が固定値ではなく、モデルのパフォーマンスの変化に応じてアンサンブリングの速度を調整することができるため、知識蒸留のパフォーマンスを向上させることができる。
また、本発明のもう1つの態様では、本発明は、ドメイン識別器に基づくデータ蒸留が提案されている。具体的には、ソースデータを使用して交差エントロピー損失関数に基づいて分類器を訓練する場合、ソースデータのうちのターゲットデータと類似するソースデータにより高い重みを与える。これによって、ターゲットデータとの類似度が高いソースデータが訓練でより大きな役割を果たすことができるため、訓練により取得された分類器は、ターゲットドメインでより優れたパフォーマンスを実現することができる。
ドメイン識別器の出力を使用して、ターゲットデータとの類似度が高いソースデータであるか否かを判断してもよい。ドメイン識別器は、現在のデータがソースデータである確率を予測できるため、この確率が小さいほど、現在のデータとターゲットデータとの類似度が高いことを意味する。言い換えれば、ドメイン識別器により出力された確率と類似度との間には反比例の関係がある。従って、ドメイン識別器の出力を使用してソースデータを重み付けしてもよい。
この原理に従って、以下の数式(9)又は(10)に示すように、図3に示すニューラルネットワークを訓練するためのデータ蒸留損失関数Ldd(図3に示されていない)を構築してもよく、該データ蒸留損失関数Lddは、第4の損失関数とも称される。
ここで、psは、ソースデータについてラベルを予測する場合、予測結果がその真のラベルである確率を表す。pdは、ドメイン識別器により決定されたソースデータがソースドメインからのものである確率を表し、1-pd又は1/pdは、該ソースデータに割り当てられた重みを表す。
ドメイン識別器により決定された確率pdが比較的に小さい場合(現在のソースデータとターゲットデータとの類似度が比較的に高いことを意味する)、1-pd又は1/pdの値が大きくなるため、現在のソースデータに与えられる重みが大きくなる。従って、(ターゲットデータと類似する)現在のソースデータは、訓練でより大きな役割を果たすことができる。
また、本発明のもう1つの態様では、本発明は、図2に示す自己アンサンブリング型教師モデルの構造をさらに改善する。図4は改善されたネットワーク構造を示している。
図4に示すように、ソースデータxSi及びターゲットデータxTiが生徒ネットワークだけでなく、教師ネットワークにも入力される。これに対して、図2では、ターゲットデータxTiのみが教師ネットワークに入力される。従って、本発明は、ターゲットドメインに対して蒸留学習を行うだけでなく、ソースドメインに対しても蒸留学習を行う。
図4では、ySiは、ソースデータxSiの真のラベルを表し、zTiは、生徒ネットワークによりターゲットデータxTiについて予測された確率(即ち、ターゲットデータxTiが各クラスに属する確率)を表し、
(外27)
は、教師ネットワークによりターゲットデータxTiについて予測された確率を表し、ZSiは、生徒ネットワークによりソースデータxSiについて予測された確率(即ち、ソースデータxSiが各クラスに属する確率)を表し、
(外28)
は、教師ネットワークによりソースデータxSiについて予測された確率を表す。また、図4における生徒ネットワークは、図3に示す第1の特徴抽出器310及び第1の分類器320を含んでもよく、図4における教師ネットワークは、図3に示す第2の特徴抽出器310_T及び第2の分類器320_Tを含んでもよく、上記の各予測確率は、第1の分類器320又は第2の分類器320_Tにより生成されてもよい。
(外27)
は、教師ネットワークによりターゲットデータxTiについて予測された確率を表し、ZSiは、生徒ネットワークによりソースデータxSiについて予測された確率(即ち、ソースデータxSiが各クラスに属する確率)を表し、
(外28)
は、教師ネットワークによりソースデータxSiについて予測された確率を表す。また、図4における生徒ネットワークは、図3に示す第1の特徴抽出器310及び第1の分類器320を含んでもよく、図4における教師ネットワークは、図3に示す第2の特徴抽出器310_T及び第2の分類器320_Tを含んでもよく、上記の各予測確率は、第1の分類器320又は第2の分類器320_Tにより生成されてもよい。
以下の数式(11)に示すように、上記の予測確率に基づいて、図3に示すニューラルネットワークを訓練するための知識蒸留損失関数Lkd(図3におけるLkd-s及びLkd-tを含む)を構築してもよく、該知識蒸留損失関数Lkdは、第5の損失関数とも称される。
ここで、
(外29)
は、ソースデータxSiについて第1の分類器320及び第2の分類器320_Tのそれぞれにより予測された確率の平均二乗誤差を表し、
(外30)
は、ターゲットデータxTiについて第1の分類器320及び第2の分類器320_Tのそれぞれにより予測された確率の平均二乗誤差を表す。nは、ソースデータの数を表し、mは、ターゲットデータの数を表す。
(外29)
は、ソースデータxSiについて第1の分類器320及び第2の分類器320_Tのそれぞれにより予測された確率の平均二乗誤差を表し、
(外30)
は、ターゲットデータxTiについて第1の分類器320及び第2の分類器320_Tのそれぞれにより予測された確率の平均二乗誤差を表す。nは、ソースデータの数を表し、mは、ターゲットデータの数を表す。
ここで、Lc-sは、ソースデータについての分類交差エントロピー損失関数を表し、図1に示す損失関数Lcと同一である。Ladvは、ドメイン識別器のバイナリ交差エントロピー損失関数を表し、図1に示す損失関数Ladvと同一である。損失関数Lc-s及びLadvは、従来技術における既知の損失関数であるため、本明細書ではそれらの詳細な説明を省略する。
また、数式(12)におけるλ1とλ2は、それぞれ第4の損失関数Lkdと第5の損失関数Lddを重み付けするための重みであり、訓練プロセスにおいて第4の損失関数と第5の損失関数の作用の程度を制御してもよい。具体的には、重みλ1は、数式(13)に従って決定されてもよい。
ここで、p=step/totalstep、即ち、現在の反復ステップ数を訓練ステップの総数で除算した商であるため、pは訓練の進行状況を表すことができる。α及びnは、ハイパーパラメータを表し、例えば、α=200、n=10に設定してもよい。図5は、重みλ1の訓練ステップ数の増加に伴って変化する曲線を示す図である(訓練ステップの総数が5000であると仮定する)。
ここで、pは、数式(13)におけるpと同一の意味を持つ。α及びnは、ハイパーパラメータを表し、例えば、α=5、n=10に設定してもよい。図6は、重みλ2の訓練ステップ数の増加に伴って変化する曲線を示す図である(訓練ステップの総数が5000であると仮定する)。
図5及び図6に示すように、訓練の開始段階において、分類器の予測及びドメイン識別器の予測が何れも正確ではないため、好ましくは、λ1及びλ2の値を小さく設定し、訓練の進行に伴って、教師ネットワークの分類器及びドメイン識別器の予測が徐々に正確になるため、λ1及びλ2の値を徐々に増大させてもよい。これによって、知識蒸留損失関数Lkd及びデータ蒸留損失関数Lddはより大きな役割を果たすことができる。
図7は、本発明に係る最適ターゲットデータセットの生成方法を示すフローチャートである。該方法は、図9における最適ターゲットデータセット生成部960により実行されてもよい。
図7に示すように、ステップS710において、第1の特徴抽出器310がソースデータについて特徴を抽出し、第1の分類器320が抽出された特徴に基づいて、ソースデータが複数の所定クラスのうちの各クラスに属する確率を予測する。最大の確率に対応するクラスを該ソースデータのラベルとして決定する。
ステップS720において、第1の特徴抽出器310がターゲットデータについて特徴を抽出し、第1の分類器320が抽出された特徴に基づいて、ターゲットデータが各クラスに属する確率を予測する。最大の確率に対応するクラスは、該ターゲットデータの第1のラベルとして決定される。
ステップS730において、数式(2)及び(3)に従って、クラス中心最近傍アルゴリズムを使用して該ターゲットデータの第2のラベルを決定する。
ステップS740において、決定された第1のラベルと第2のラベルとが同一であるターゲットデータを選択し、該第1のラベル又は第2のラベルは、選択されたターゲットデータの疑似ラベルとされる。そして、全ての選択されたターゲットデータは、最適ターゲットデータセットを構成してもよい。
図8は、本発明に係るドメイン適応型ニューラルネットワークの訓練方法を示すフローチャートであり、図9は、本発明に係るドメイン適応型ニューラルネットワークの訓練装置のモジュール化の構成を示すブロック図である。
図8に示すように、ステップS810において、数式(2)、(4)及び(5)に従って、ソースデータセットのクラス中心と最適ターゲットデータセットのクラス中心との間の距離に基づいて、第1の損失関数La(セマンティックアラインメント損失関数)を構築する。このステップは、図9における第1の損失関数生成部910により実行されてもよい。
ステップS820において、数式(6)に従って、最適ターゲットデータセットにおけるターゲットデータ及びその疑似ラベルに基づいて、第2の損失関数
(外31)
(交差エントロピー損失関数)を構築する。このステップは、図9における第2の損失関数生成部920により実行されてもよい。
(外31)
(交差エントロピー損失関数)を構築する。このステップは、図9における第2の損失関数生成部920により実行されてもよい。
ステップS830において、数式(7)に従って、ソースデータセットにおけるソースデータ及び最適ターゲットデータセットにおけるターゲットデータについて、第3の損失関数Lcon(対比学習損失関数)を構築する。このステップは、図9における第3の損失関数生成部930により実行されてもよい。
図9から分かるように、図7に示す方法により生成された最適ターゲットデータセットは、第1の損失関数~第3の損失関数を構築するために使用される。
次に、ステップS840において、数式(9)又は(10)に従って、ドメイン識別器により出力された確率に基づいて、第4の損失関数Ldd(データ蒸留損失関数)を構築する。このステップは、図9における第4の損失関数生成部940により実行されてもよい。
ステップS850において、第2の(教師)特徴抽出器310_Tがソースデータ及びターゲットデータの特徴を抽出し、第2の(教師)分類器320_Tがソースデータ及びターゲットデータのラベルを予測する。次に、ステップS860において、数式(11)に従って、第1の分類器320の予測結果及び第2の分類器320_Tの予測結果に基づいて、第5の損失関数Lkd(知識蒸留損失関数)を構築する。ステップS860は、図9における第5の損失関数生成部950により実行されてもよい。
次に、ステップS870において、数式(12)に従って、第1の損失関数~第5の損失関数の重み付けされた組み合わせに基づいて、ニューラルネットワークを訓練する。このステップは、図9における訓練部970により実行されてもよい。
なお、本発明の訓練方法は、必ずしも図8に示す順序で実行する必要がない。例えば、第1の損失関数~第5の損失関数の生成の順序は、図示されているものとは異なってもよいし、同時に生成されてもよい。
本発明の発明者は、MNIST、USPS、SVHN(何れも既知の文字データセットである)に基づいてテストを行った。ここで、MNIST→USPS、USPS→MNIST、SVHN→MNISTの3つの方向のドメイン適応を含む。以下の表1は、本発明の解決手段と従来技術(ADDA、DANNなど)とのパフォーマンスの比較を示している。表1における値は、分類の正解率を表し、正解率が高いほど、解決手段のパフォーマンスが良くなる。表1から分かるように、本発明の解決手段のパフォーマンスは、従来技術のパフォーマンスと同等であり、或いはそれよりも優れている。
特に、表1における「source only」は、ターゲットデータを使用せず、ソースデータのみを使用して訓練を行う方式、即ち最も単純な方式を表し、比較の基準とされる。DANN(Domain-Adversarial Training of Neural Networks)は、図1に示すドメイン敵対的ニューラルネットワークを表し、ADDA(Adversarial Discriminative Domain Adaptation)は、敵対的識別ドメイン適応を表す。CAT+RevGradは、「Cluster Alignment with a Teacher for Unsupervised Domain Adaptation[C]」、Deng Z et al.、Proceedings of IEEE International Conference on Computer Vision、2019:9944-9953という技術文献に記載されている。
本発明に係るドメイン適応方法は、幅広いドメインに適用することができ、以下は、代表的な適用シナリオを例示して説明する。
[応用シナリオ1]セマンティックセグメンテーション(semantic segmentation)
セマンティックセグメンテーションとは、画像における異なる物体を表す部分を異なる色でマークすることを意味する。セマンティックセグメンテーションの応用シナリオでは、実世界の画像に対して手動でラベルを付けるコストが非常に高いため、実世界の画像はラベルが殆ど付いていない。この場合、代替的な方法として、シミュレーション環境(例えば3Dゲーム)におけるシーンの画像を用いて訓練を行う。プログラミングによりシミュレーション環境において物体の自動的なラベル付けを容易に実現できるため、ラベル付きデータを簡単に取得することができる。このように、シミュレーション環境において生成されたラベル付きデータを用いてモデルを訓練し、訓練されたモデルを用いて実際の環境の画像を処理する。しかし、シミュレーション環境は、実際の環境と完全に一致するものではないため、シミュレーション環境のデータを用いて訓練されたモデルのパフォーマンスは、実際の環境の画像を処理する場合に大幅に低下してしまう。
セマンティックセグメンテーションとは、画像における異なる物体を表す部分を異なる色でマークすることを意味する。セマンティックセグメンテーションの応用シナリオでは、実世界の画像に対して手動でラベルを付けるコストが非常に高いため、実世界の画像はラベルが殆ど付いていない。この場合、代替的な方法として、シミュレーション環境(例えば3Dゲーム)におけるシーンの画像を用いて訓練を行う。プログラミングによりシミュレーション環境において物体の自動的なラベル付けを容易に実現できるため、ラベル付きデータを簡単に取得することができる。このように、シミュレーション環境において生成されたラベル付きデータを用いてモデルを訓練し、訓練されたモデルを用いて実際の環境の画像を処理する。しかし、シミュレーション環境は、実際の環境と完全に一致するものではないため、シミュレーション環境のデータを用いて訓練されたモデルのパフォーマンスは、実際の環境の画像を処理する場合に大幅に低下してしまう。
この場合、本発明に係るドメイン適応方法を用いて、ラベル付けされたシミュレーション環境のデータ及びラベル付けされていない実際環境のデータに基づいて訓練を行うことができるため、実際の環境の画像を処理する際のモデルのパフォーマンスを向上させることができる。
[応用シナリオ2]手書き文字の認識
手書き文字は、通常、手書きの数字、文字(例えば中国語、日本語)などを含む。手書き文字の認識では、一般的に使用されるラベル付き文字セットは、MNIST、USPS、SVHNなどを含み、通常、これらのラベル付き文字データを用いてモデルを訓練する。しかし、訓練済みモデルを実際(ラベルなし)の手書き文字の認識に適用する場合、その正確率が低下する可能性がある。
手書き文字は、通常、手書きの数字、文字(例えば中国語、日本語)などを含む。手書き文字の認識では、一般的に使用されるラベル付き文字セットは、MNIST、USPS、SVHNなどを含み、通常、これらのラベル付き文字データを用いてモデルを訓練する。しかし、訓練済みモデルを実際(ラベルなし)の手書き文字の認識に適用する場合、その正確率が低下する可能性がある。
この場合、本発明に係るドメイン適応方法を用いることで、ラベル付けされたソースデータ及びラベル付けされていないターゲットデータに基づいて訓練を行うことができるため、ターゲットデータを処理する際のモデルのパフォーマンスを向上させることができる。
[応用シナリオ3]時系列データの分類と予測
時系列データの予測は、例えば、大気汚染指数の予測、ICU患者の入院期間(LOS)の予測、株式市場の予測などを含む。微粒子状物質(PM2.5)指数の時系列データを一例として、ラベル付きの訓練サンプルセットを用いて予測モデルを訓練してもよい。訓練が完了した後、訓練済みのモデルを実際の予測に適用してもよい。例えば、現時点の直前の24時間のデータ(ラベルなしデータ)に基づいて、3日間後のPM2.5指数の範囲を予測してもよい。
時系列データの予測は、例えば、大気汚染指数の予測、ICU患者の入院期間(LOS)の予測、株式市場の予測などを含む。微粒子状物質(PM2.5)指数の時系列データを一例として、ラベル付きの訓練サンプルセットを用いて予測モデルを訓練してもよい。訓練が完了した後、訓練済みのモデルを実際の予測に適用してもよい。例えば、現時点の直前の24時間のデータ(ラベルなしデータ)に基づいて、3日間後のPM2.5指数の範囲を予測してもよい。
このシナリオでは、本発明に係るドメイン適応方法を用いることで、ラベル付けされたデータ及びラベル付けされていないデータに基づいてモデルを訓練することができるため、モデルの予測正確度を向上させることができる。
[応用シナリオ4]テーブルタイプのデータの分類と予測
テーブルタイプのデータは、オンラインローンデータなどの財務データを含んでもよい。この例では、ローンの組み方の返済延滞の可能性があるか否かを予測するために、予測モデルを構築してもよく、本発明に係る方法を用いてモデルを訓練してもよい。
テーブルタイプのデータは、オンラインローンデータなどの財務データを含んでもよい。この例では、ローンの組み方の返済延滞の可能性があるか否かを予測するために、予測モデルを構築してもよく、本発明に係る方法を用いてモデルを訓練してもよい。
[応用シナリオ5]画像認識
画像認識又は画像分類の応用シナリオでは、セマンティックセグメンテーションのシナリオと同様に、実世界の画像データセットにラベルを付けるコストが高いという問題もある。従って、要件を満たすパフォーマンスを有するモデルを得るように、本発明に係るドメイン適応方法を用いて、ラベル付きデータセット(例えばImageNet)をソースデータセットとして選択し、該ソースデータセット及びラベルなしターゲットデータセットに基づいて訓練を行ってもよい。
画像認識又は画像分類の応用シナリオでは、セマンティックセグメンテーションのシナリオと同様に、実世界の画像データセットにラベルを付けるコストが高いという問題もある。従って、要件を満たすパフォーマンスを有するモデルを得るように、本発明に係るドメイン適応方法を用いて、ラベル付きデータセット(例えばImageNet)をソースデータセットとして選択し、該ソースデータセット及びラベルなしターゲットデータセットに基づいて訓練を行ってもよい。
以上は具体的な実施例を参照しながら本発明の実施形態を説明した。上記の実施例に係る方法は、ソフトウェア、ハードウェア、又はソフトウェアとハードウェアとの組み合わせにより実現されてもよい。ソフトウェアに含まれるプログラムは、装置の内部又は外部に設置された記憶媒体に予め記憶されてもよい。一例として、実行中に、これらのプログラムはランダムアクセスメモリ(RAM)に書き込まれ、プロセッサ(例えばCPU)により実行されることで、本明細書で説明された各処理を実現する。
図10は本発明を実現可能なコンピュータのハードウェアの例示的な構成を示すブロック図であり、該コンピュータのハードウェアは、本発明に係るドメイン適応型ニューラルネットワークの訓練装置の一例である。また、本発明に係るドメイン適応型ニューラルネットワークも、該コンピュータハードウェアに基づいて実現されてもよい。
図10に示すように、コンピュータ1000では、中央処理装置(CPU)1001、読み出し専用メモリ(ROM)1002及びランダムアクセスメモリ(RAM)1003がバス1004により相互に接続されている。
入力/出力インターフェース1005は、バス1004にさらに接続されている。入力/出力インターフェース1005には、キーボード、マウス、マイクロフォンなどにより構成された入力部1006、ディスプレイ、スピーカなどにより構成された出力部1007、ハードディスク、不揮発性メモリなどにより構成された記憶部1008、ネットワークインターフェースカード(ローカルエリアネットワーク(LAN)カード、モデムなど)により構成された通信部1009、及び移動可能な媒体1011をドライブするドライバ1010が接続されている。移動可能な媒体1011は、例えば磁気ディスク、光ディスク、光磁気ディスク又は半導体メモリである。
上記の構成を有するコンピュータにおいて、CPU1001は、記憶部1008に記憶されているプログラムを、入力/出力インターフェース1005及びバス1004を介してRAM1003にロードし、プログラムを実行することにより、上記の方法を実行する。
コンピュータ(CPU1001)により実行されるプログラムは、パッケージ媒体である移動可能な媒体1011に記録されてもよい。該パッケージ媒体は、例えば磁気ディスク(フロッピーディスクを含む)、光ディスク(コンパクトディスクリードオンリーメモリ(CD-ROM)、デジタルバーサタイルディスク(DVD)などを含む)、光磁気ディスク、又は半導体メモリにより形成される。また、コンピュータ(CPU1001)により実行されるプログラムは、ローカルエリアネットワーク、インターネット、デジタル衛星放送の有線又は無線の伝送媒体を介して提供されてもよい。
移動可能な媒体1011がドライバ1010にインストールされると、プログラムは、入力/出力インターフェース1005を介して記憶部1008にインストールすることができる。また、プログラムは、有線又は無線の伝送媒体を介して通信部1009で受信され、記憶部1008にインストールされる。或いは、プログラムは、ROM1002又は記憶部1008に予めインストールされてもよい。
コンピュータにより実行されるプログラムは、本明細書で説明する順序に従って処理を実行するプログラムであってもよいし、処理を並列的に実行し、或いは必要に応じて(例えば呼び出しの時に)処理を実行するプログラムであってもよい。
本明細書で説明されている装置又はユニットは論理的なものであり、物理的な装置又はエンティティに限定されない。例えば、本明細書で説明されている各ユニットの機能は複数の物理エンティティにより実現されてもよいし、本明細書で説明される複数のユニットの機能は単一の物理エンティティにより実現されてもよい。また、1つの実施例で説明される特徴、構成要素、要素、ステップなどは、該実施例に限定されず、例えば、他の実施例に適用されてもよく、例えば他の実施例の特定の特徴、構成要素、要素、ステップなどの代わりに用いてもよいし、それと組み合わせてもよい。
本発明の範囲は、本明細書に記載の具体的な実施例に限定されない。当業者により理解できるように、設計要求及び他の要因に応じて、本発明の原理及び要旨から逸脱することなく、本明細書の実施例に対して様々な修正又は変更を行ってもよい。本発明の範囲は、添付の特許請求の範囲及びその均等物により制限される。
また、上述の各実施例を含む実施形態に関し、更に以下の付記を開示するが、これらの付記に限定されない。
(付記1)
コンピュータにより実行される、ドメイン適応型ニューラルネットワークを訓練する方法であって、前記ドメイン適応型ニューラルネットワークは、第1の特徴抽出部、第1の分類部、及び識別部を含み、前記コンピュータは、命令が記憶されたメモリ、及びプロセッサを含み、前記命令は、前記プロセッサにより実行される際に、前記プロセッサに前記方法を実行させ、前記方法は、
前記第1の特徴抽出部が、ラベル付けされたソースデータセットにおけるソースデータについて第1の特徴を抽出するステップと、
前記第1の分類部が、前記第1の特徴に基づいて、前記ソースデータが複数のクラスのうちの各クラスに属する確率を予測するステップと、
前記第1の特徴抽出部が、ラベル付けされていないターゲットデータセットにおけるターゲットデータについて第2の特徴を抽出するステップと、
前記第1の分類部が、前記第2の特徴に基づいて、前記ターゲットデータが前記各クラスに属する確率を予測し、最大の確率に対応するクラスを前記ターゲットデータの第1のラベルとして決定するステップと、
ソースデータセットの前記各クラスについてのクラス中心と前記ターゲットデータの特徴との間の距離を計算し、距離が最も近いクラス中心に対応するクラスを前記ターゲットデータの第2のラベルとして決定するステップと、
前記ターゲットデータセットから、決定された前記第1のラベルと前記第2のラベルとが同一であるターゲットデータを選択するステップであって、前記第1のラベル又は前記第2のラベルは、選択されたターゲットデータの疑似ラベルとされる、ステップと、
選択されたターゲットデータに基づいて、前記ターゲットデータセットの前記各クラスについてのクラス中心を計算するステップと、
前記ソースデータセットのクラス中心と計算されたターゲットデータセットのクラス中心との間の距離に基づいて、第1の損失関数を構築するステップと、
選択されたターゲットデータ及びその疑似ラベルに基づいて、第2の損失関数を構築するステップと、
前記ソースデータセットにおけるソースデータ及び選択されたターゲットデータについて、第3の損失関数を構築するステップと、
前記第1の損失関数、前記第2の損失関数及び前記第3の損失関数に基づいて前記ドメイン適応型ニューラルネットワークを訓練するステップと、を含む、方法。
(付記2)
前記識別部が、前記第1の特徴及び前記第2の特徴に基づいて、現在入力されたデータがソースデータである確率を決定するステップと、
前記識別部により決定された確率に基づいて第4の損失関数を構築するステップと、
前記第4の損失関数に基づいて前記ドメイン適応型ニューラルネットワークを訓練するステップと、を含む、付記1に記載の方法。
(付記3)
前記識別部により決定された確率の逆数、及び1から前記識別部により決定された確率を減算した差のうちの1つに基づいて、前記第4の損失関数を構築する、付記2に記載の方法。
(付記4)
前記ドメイン適応型ニューラルネットワークは、第2の特徴抽出部及び第2の分類部をさらに含み、
前記方法は、
前記第2の特徴抽出部が前記ソースデータについて第3の特徴を抽出するステップと、
前記第2の分類部が、前記第3の特徴に基づいて、前記ソースデータが前記各クラスに属する確率を予測するステップと、
前記第2の特徴抽出部が前記ターゲットデータについて第4の特徴を抽出するステップと、
前記第2の分類部が、前記第4の特徴に基づいて、前記ターゲットデータが前記各クラスに属する確率を予測するステップと、
前記第1の分類部により予測された確率及び前記第2の分類部により予測された確率に基づいて、第5の損失関数を構築するステップと、
前記第5の損失関数に基づいて前記ドメイン適応型ニューラルネットワークを訓練するステップと、を含む、付記2に記載の方法。
(付記5)
前記ソースデータについて前記第1の分類部及び前記第2の分類部のそれぞれにより予測された確率の平均二乗誤差、並びに前記ターゲットデータについて前記第1の分類部及び前記第2の分類部のそれぞれにより予測された確率の平均二乗誤差に基づいて、前記第5の損失関数を構築する、付記4に記載の方法。
(付記6)
前記第2の特徴抽出部のパラメータは、前記第1の特徴抽出部のパラメータの指数移動平均であり、前記第2の分類部のパラメータは、前記第1の分類部のパラメータの指数移動平均であり、
前記指数移動平均で使用される減衰率を取得する際に、微分可能な変数を前記減衰率として使用し、或いは、全結合層を使用して前記減衰率を生成し、
前記全結合層は、前記第2の分類部の出力層と並列に前記出力層の直前の層に接続されるように設定されている、付記4に記載の方法。
(付記7)
前記第1の損失関数、前記第2の損失関数、前記第3の損失関数、前記第4の損失関数及び前記第5の損失関数の重み付けされた組み合わせに基づいて、前記ドメイン適応型ニューラルネットワークを訓練し、
訓練の実行に伴って、前記第4の損失関数及び前記第5の損失関数の重みを徐々に増加させる、付記4に記載の方法。
(付記8)
前記第2の損失関数は、前記第1の分類部を訓練するための交差エントロピー損失関数である、付記1に記載の方法。
(付記9)
前記識別部は、勾配反転部を介して前記第1の特徴抽出部に接続され、
前記識別部と前記第1の特徴抽出部とは、互いに敵対的に動作する、付記1に記載の方法。
(付記10)
前記ドメイン適応型ニューラルネットワークは、画像認識を実行するために使用され、前記ソースデータ及び前記ターゲットデータは、画像データであり、或いは、
前記ドメイン適応型ニューラルネットワークは、財務データを処理するために使用され、前記ソースデータ及び前記ターゲットデータは、テーブルタイプのデータであり、或いは、
前記ドメイン適応型ニューラルネットワークは、環境気象データ又は医療データを処理するために使用され、前記ソースデータ及び前記ターゲットデータは、時系列データである、付記1に記載の方法。
(付記11)
ドメイン適応型ニューラルネットワークを訓練する装置であって、
前記ドメイン適応型ニューラルネットワークは、
ラベル付けされたソースデータセットにおけるソースデータについて第1の特徴を抽出し、ラベル付けされていないターゲットデータセットにおけるターゲットデータについて第2の特徴を抽出する第1の特徴抽出部と、
前記第1の特徴に基づいて、前記ソースデータが複数のクラスのうちの各クラスに属する確率を予測し、前記第2の特徴に基づいて、前記ターゲットデータが前記各クラスに属する確率を予測し、最大の確率に対応するクラスを前記ターゲットデータの第1のラベルとして決定する第1の分類部と、
前記第1の特徴及び前記第2の特徴に基づいて、現在入力されたデータがソースデータである確率を決定する識別部と、を含み、
前記装置は、
プログラムが記憶されたメモリと、
1つ又は複数のプロセッサと、を含み、
前記プロセッサは、前記プログラムを実行することで、
ソースデータセットの前記各クラスについてのクラス中心と前記ターゲットデータの特徴との間の距離を計算し、距離が最も近いクラス中心に対応するクラスを前記ターゲットデータの第2のラベルとして決定するステップと、
前記ターゲットデータセットから、決定された前記第1のラベルと前記第2のラベルとが同一であるターゲットデータを選択するステップであって、前記第1のラベル又は前記第2のラベルは、選択されたターゲットデータの疑似ラベルとされる、ステップと、
選択されたターゲットデータに基づいて、前記ターゲットデータセットの前記各クラスについてのクラス中心を計算するステップと、
前記ソースデータセットのクラス中心と計算されたターゲットデータセットのクラス中心との間の距離に基づいて、第1の損失関数を構築するステップと、
選択されたターゲットデータ及びその疑似ラベルに基づいて、第2の損失関数を構築するステップと、
前記ソースデータセットにおけるソースデータ及び選択されたターゲットデータについて、第3の損失関数を構築するステップと、
前記第1の損失関数、前記第2の損失関数及び前記第3の損失関数に基づいて前記ドメイン適応型ニューラルネットワークを訓練するステップと、を実行する、装置。
(付記12)
ドメイン適応型ニューラルネットワークを訓練する装置であって、
前記ドメイン適応型ニューラルネットワークは、
ラベル付けされたソースデータセットにおけるソースデータについて第1の特徴を抽出し、ラベル付けされていないターゲットデータセットにおけるターゲットデータについて第2の特徴を抽出する第1の特徴抽出部と、
前記第1の特徴に基づいて、前記ソースデータが複数のクラスのうちの各クラスに属する確率を予測し、前記第2の特徴に基づいて、前記ターゲットデータが前記各クラスに属する確率を予測し、最大の確率に対応するクラスを前記ターゲットデータの第1のラベルとして決定する第1の分類部と、
前記第1の特徴及び前記第2の特徴に基づいて、現在入力されたデータがソースデータである確率を決定する識別部と、を含み、
前記装置は、
ソースデータセットの前記各クラスについてのクラス中心と前記ターゲットデータの特徴との間の距離を計算し、距離が最も近いクラス中心に対応するクラスを前記ターゲットデータの第2のラベルとして決定し、前記ターゲットデータセットから、決定された前記第1のラベルと前記第2のラベルとが同一であるターゲットデータを選択し、最適ターゲットデータセットを形成する最適ターゲットデータセット生成部であって、前記第1のラベル又は前記第2のラベルは、選択されたターゲットデータの疑似ラベルとされる、最適ターゲットデータセット生成部と、
前記最適ターゲットデータセットにおけるターゲットデータに基づいて、前記ターゲットデータセットの前記各クラスについてのクラス中心を計算し、前記ソースデータセットのクラス中心と計算されたターゲットデータセットのクラス中心との間の距離に基づいて、第1の損失関数を構築する第1の損失関数生成部と、
前記最適ターゲットデータセットにおけるターゲットデータ及びその疑似ラベルに基づいて、第2の損失関数を構築する第2の損失関数生成部と、
前記ソースデータセットにおけるソースデータ及び前記最適ターゲットデータセットにおけるターゲットデータについて、第3の損失関数を構築する第3の損失関数生成部と、
前記第1の損失関数、前記第2の損失関数及び前記第3の損失関数に基づいて前記ドメイン適応型ニューラルネットワークを訓練する訓練部と、を実行する、装置。
(付記13)
ドメイン適応型ニューラルネットワークを訓練するプログラムが記憶された記憶媒体であって、前記ドメイン適応型ニューラルネットワークは、第1の特徴抽出部、第1の分類部、及び識別部を含み、前記プログラムがコンピュータにより実行される際に、前記コンピュータに、
前記第1の特徴抽出部が、ラベル付けされたソースデータセットにおけるソースデータについて第1の特徴を抽出するステップと、
前記第1の分類部が、前記第1の特徴に基づいて、前記ソースデータが複数のクラスのうちの各クラスに属する確率を予測するステップと、
前記第1の特徴抽出部が、ラベル付けされていないターゲットデータセットにおけるターゲットデータについて第2の特徴を抽出するステップと、
前記第1の分類部が、前記第2の特徴に基づいて、前記ターゲットデータが前記各クラスに属する確率を予測し、最大の確率に対応するクラスを前記ターゲットデータの第1のラベルとして決定するステップと、
ソースデータセットの前記各クラスについてのクラス中心と前記ターゲットデータの特徴との間の距離を計算し、距離が最も近いクラス中心に対応するクラスを前記ターゲットデータの第2のラベルとして決定するステップと、
前記ターゲットデータセットから、決定された前記第1のラベルと前記第2のラベルとが同一であるターゲットデータを選択するステップであって、前記第1のラベル又は前記第2のラベルは、選択されたターゲットデータの疑似ラベルとされる、ステップと、
選択されたターゲットデータに基づいて、前記ターゲットデータセットの前記各クラスについてのクラス中心を計算するステップと、
前記ソースデータセットのクラス中心と計算されたターゲットデータセットのクラス中心との間の距離に基づいて、第1の損失関数を構築するステップと、
選択されたターゲットデータ及びその疑似ラベルに基づいて、第2の損失関数を構築するステップと、
前記ソースデータセットにおけるソースデータ及び選択されたターゲットデータについて、第3の損失関数を構築するステップと、
前記第1の損失関数、前記第2の損失関数及び前記第3の損失関数に基づいて前記ドメイン適応型ニューラルネットワークを訓練するステップと、を含む、方法を実行させる、記憶媒体。
(付記1)
コンピュータにより実行される、ドメイン適応型ニューラルネットワークを訓練する方法であって、前記ドメイン適応型ニューラルネットワークは、第1の特徴抽出部、第1の分類部、及び識別部を含み、前記コンピュータは、命令が記憶されたメモリ、及びプロセッサを含み、前記命令は、前記プロセッサにより実行される際に、前記プロセッサに前記方法を実行させ、前記方法は、
前記第1の特徴抽出部が、ラベル付けされたソースデータセットにおけるソースデータについて第1の特徴を抽出するステップと、
前記第1の分類部が、前記第1の特徴に基づいて、前記ソースデータが複数のクラスのうちの各クラスに属する確率を予測するステップと、
前記第1の特徴抽出部が、ラベル付けされていないターゲットデータセットにおけるターゲットデータについて第2の特徴を抽出するステップと、
前記第1の分類部が、前記第2の特徴に基づいて、前記ターゲットデータが前記各クラスに属する確率を予測し、最大の確率に対応するクラスを前記ターゲットデータの第1のラベルとして決定するステップと、
ソースデータセットの前記各クラスについてのクラス中心と前記ターゲットデータの特徴との間の距離を計算し、距離が最も近いクラス中心に対応するクラスを前記ターゲットデータの第2のラベルとして決定するステップと、
前記ターゲットデータセットから、決定された前記第1のラベルと前記第2のラベルとが同一であるターゲットデータを選択するステップであって、前記第1のラベル又は前記第2のラベルは、選択されたターゲットデータの疑似ラベルとされる、ステップと、
選択されたターゲットデータに基づいて、前記ターゲットデータセットの前記各クラスについてのクラス中心を計算するステップと、
前記ソースデータセットのクラス中心と計算されたターゲットデータセットのクラス中心との間の距離に基づいて、第1の損失関数を構築するステップと、
選択されたターゲットデータ及びその疑似ラベルに基づいて、第2の損失関数を構築するステップと、
前記ソースデータセットにおけるソースデータ及び選択されたターゲットデータについて、第3の損失関数を構築するステップと、
前記第1の損失関数、前記第2の損失関数及び前記第3の損失関数に基づいて前記ドメイン適応型ニューラルネットワークを訓練するステップと、を含む、方法。
(付記2)
前記識別部が、前記第1の特徴及び前記第2の特徴に基づいて、現在入力されたデータがソースデータである確率を決定するステップと、
前記識別部により決定された確率に基づいて第4の損失関数を構築するステップと、
前記第4の損失関数に基づいて前記ドメイン適応型ニューラルネットワークを訓練するステップと、を含む、付記1に記載の方法。
(付記3)
前記識別部により決定された確率の逆数、及び1から前記識別部により決定された確率を減算した差のうちの1つに基づいて、前記第4の損失関数を構築する、付記2に記載の方法。
(付記4)
前記ドメイン適応型ニューラルネットワークは、第2の特徴抽出部及び第2の分類部をさらに含み、
前記方法は、
前記第2の特徴抽出部が前記ソースデータについて第3の特徴を抽出するステップと、
前記第2の分類部が、前記第3の特徴に基づいて、前記ソースデータが前記各クラスに属する確率を予測するステップと、
前記第2の特徴抽出部が前記ターゲットデータについて第4の特徴を抽出するステップと、
前記第2の分類部が、前記第4の特徴に基づいて、前記ターゲットデータが前記各クラスに属する確率を予測するステップと、
前記第1の分類部により予測された確率及び前記第2の分類部により予測された確率に基づいて、第5の損失関数を構築するステップと、
前記第5の損失関数に基づいて前記ドメイン適応型ニューラルネットワークを訓練するステップと、を含む、付記2に記載の方法。
(付記5)
前記ソースデータについて前記第1の分類部及び前記第2の分類部のそれぞれにより予測された確率の平均二乗誤差、並びに前記ターゲットデータについて前記第1の分類部及び前記第2の分類部のそれぞれにより予測された確率の平均二乗誤差に基づいて、前記第5の損失関数を構築する、付記4に記載の方法。
(付記6)
前記第2の特徴抽出部のパラメータは、前記第1の特徴抽出部のパラメータの指数移動平均であり、前記第2の分類部のパラメータは、前記第1の分類部のパラメータの指数移動平均であり、
前記指数移動平均で使用される減衰率を取得する際に、微分可能な変数を前記減衰率として使用し、或いは、全結合層を使用して前記減衰率を生成し、
前記全結合層は、前記第2の分類部の出力層と並列に前記出力層の直前の層に接続されるように設定されている、付記4に記載の方法。
(付記7)
前記第1の損失関数、前記第2の損失関数、前記第3の損失関数、前記第4の損失関数及び前記第5の損失関数の重み付けされた組み合わせに基づいて、前記ドメイン適応型ニューラルネットワークを訓練し、
訓練の実行に伴って、前記第4の損失関数及び前記第5の損失関数の重みを徐々に増加させる、付記4に記載の方法。
(付記8)
前記第2の損失関数は、前記第1の分類部を訓練するための交差エントロピー損失関数である、付記1に記載の方法。
(付記9)
前記識別部は、勾配反転部を介して前記第1の特徴抽出部に接続され、
前記識別部と前記第1の特徴抽出部とは、互いに敵対的に動作する、付記1に記載の方法。
(付記10)
前記ドメイン適応型ニューラルネットワークは、画像認識を実行するために使用され、前記ソースデータ及び前記ターゲットデータは、画像データであり、或いは、
前記ドメイン適応型ニューラルネットワークは、財務データを処理するために使用され、前記ソースデータ及び前記ターゲットデータは、テーブルタイプのデータであり、或いは、
前記ドメイン適応型ニューラルネットワークは、環境気象データ又は医療データを処理するために使用され、前記ソースデータ及び前記ターゲットデータは、時系列データである、付記1に記載の方法。
(付記11)
ドメイン適応型ニューラルネットワークを訓練する装置であって、
前記ドメイン適応型ニューラルネットワークは、
ラベル付けされたソースデータセットにおけるソースデータについて第1の特徴を抽出し、ラベル付けされていないターゲットデータセットにおけるターゲットデータについて第2の特徴を抽出する第1の特徴抽出部と、
前記第1の特徴に基づいて、前記ソースデータが複数のクラスのうちの各クラスに属する確率を予測し、前記第2の特徴に基づいて、前記ターゲットデータが前記各クラスに属する確率を予測し、最大の確率に対応するクラスを前記ターゲットデータの第1のラベルとして決定する第1の分類部と、
前記第1の特徴及び前記第2の特徴に基づいて、現在入力されたデータがソースデータである確率を決定する識別部と、を含み、
前記装置は、
プログラムが記憶されたメモリと、
1つ又は複数のプロセッサと、を含み、
前記プロセッサは、前記プログラムを実行することで、
ソースデータセットの前記各クラスについてのクラス中心と前記ターゲットデータの特徴との間の距離を計算し、距離が最も近いクラス中心に対応するクラスを前記ターゲットデータの第2のラベルとして決定するステップと、
前記ターゲットデータセットから、決定された前記第1のラベルと前記第2のラベルとが同一であるターゲットデータを選択するステップであって、前記第1のラベル又は前記第2のラベルは、選択されたターゲットデータの疑似ラベルとされる、ステップと、
選択されたターゲットデータに基づいて、前記ターゲットデータセットの前記各クラスについてのクラス中心を計算するステップと、
前記ソースデータセットのクラス中心と計算されたターゲットデータセットのクラス中心との間の距離に基づいて、第1の損失関数を構築するステップと、
選択されたターゲットデータ及びその疑似ラベルに基づいて、第2の損失関数を構築するステップと、
前記ソースデータセットにおけるソースデータ及び選択されたターゲットデータについて、第3の損失関数を構築するステップと、
前記第1の損失関数、前記第2の損失関数及び前記第3の損失関数に基づいて前記ドメイン適応型ニューラルネットワークを訓練するステップと、を実行する、装置。
(付記12)
ドメイン適応型ニューラルネットワークを訓練する装置であって、
前記ドメイン適応型ニューラルネットワークは、
ラベル付けされたソースデータセットにおけるソースデータについて第1の特徴を抽出し、ラベル付けされていないターゲットデータセットにおけるターゲットデータについて第2の特徴を抽出する第1の特徴抽出部と、
前記第1の特徴に基づいて、前記ソースデータが複数のクラスのうちの各クラスに属する確率を予測し、前記第2の特徴に基づいて、前記ターゲットデータが前記各クラスに属する確率を予測し、最大の確率に対応するクラスを前記ターゲットデータの第1のラベルとして決定する第1の分類部と、
前記第1の特徴及び前記第2の特徴に基づいて、現在入力されたデータがソースデータである確率を決定する識別部と、を含み、
前記装置は、
ソースデータセットの前記各クラスについてのクラス中心と前記ターゲットデータの特徴との間の距離を計算し、距離が最も近いクラス中心に対応するクラスを前記ターゲットデータの第2のラベルとして決定し、前記ターゲットデータセットから、決定された前記第1のラベルと前記第2のラベルとが同一であるターゲットデータを選択し、最適ターゲットデータセットを形成する最適ターゲットデータセット生成部であって、前記第1のラベル又は前記第2のラベルは、選択されたターゲットデータの疑似ラベルとされる、最適ターゲットデータセット生成部と、
前記最適ターゲットデータセットにおけるターゲットデータに基づいて、前記ターゲットデータセットの前記各クラスについてのクラス中心を計算し、前記ソースデータセットのクラス中心と計算されたターゲットデータセットのクラス中心との間の距離に基づいて、第1の損失関数を構築する第1の損失関数生成部と、
前記最適ターゲットデータセットにおけるターゲットデータ及びその疑似ラベルに基づいて、第2の損失関数を構築する第2の損失関数生成部と、
前記ソースデータセットにおけるソースデータ及び前記最適ターゲットデータセットにおけるターゲットデータについて、第3の損失関数を構築する第3の損失関数生成部と、
前記第1の損失関数、前記第2の損失関数及び前記第3の損失関数に基づいて前記ドメイン適応型ニューラルネットワークを訓練する訓練部と、を実行する、装置。
(付記13)
ドメイン適応型ニューラルネットワークを訓練するプログラムが記憶された記憶媒体であって、前記ドメイン適応型ニューラルネットワークは、第1の特徴抽出部、第1の分類部、及び識別部を含み、前記プログラムがコンピュータにより実行される際に、前記コンピュータに、
前記第1の特徴抽出部が、ラベル付けされたソースデータセットにおけるソースデータについて第1の特徴を抽出するステップと、
前記第1の分類部が、前記第1の特徴に基づいて、前記ソースデータが複数のクラスのうちの各クラスに属する確率を予測するステップと、
前記第1の特徴抽出部が、ラベル付けされていないターゲットデータセットにおけるターゲットデータについて第2の特徴を抽出するステップと、
前記第1の分類部が、前記第2の特徴に基づいて、前記ターゲットデータが前記各クラスに属する確率を予測し、最大の確率に対応するクラスを前記ターゲットデータの第1のラベルとして決定するステップと、
ソースデータセットの前記各クラスについてのクラス中心と前記ターゲットデータの特徴との間の距離を計算し、距離が最も近いクラス中心に対応するクラスを前記ターゲットデータの第2のラベルとして決定するステップと、
前記ターゲットデータセットから、決定された前記第1のラベルと前記第2のラベルとが同一であるターゲットデータを選択するステップであって、前記第1のラベル又は前記第2のラベルは、選択されたターゲットデータの疑似ラベルとされる、ステップと、
選択されたターゲットデータに基づいて、前記ターゲットデータセットの前記各クラスについてのクラス中心を計算するステップと、
前記ソースデータセットのクラス中心と計算されたターゲットデータセットのクラス中心との間の距離に基づいて、第1の損失関数を構築するステップと、
選択されたターゲットデータ及びその疑似ラベルに基づいて、第2の損失関数を構築するステップと、
前記ソースデータセットにおけるソースデータ及び選択されたターゲットデータについて、第3の損失関数を構築するステップと、
前記第1の損失関数、前記第2の損失関数及び前記第3の損失関数に基づいて前記ドメイン適応型ニューラルネットワークを訓練するステップと、を含む、方法を実行させる、記憶媒体。
Claims (10)
- コンピュータにより実行される、ドメイン適応型ニューラルネットワークを訓練する方法であって、前記ドメイン適応型ニューラルネットワークは、第1の特徴抽出部、第1の分類部、及び識別部を含み、前記コンピュータは、命令が記憶されたメモリ、及びプロセッサを含み、前記命令は、前記プロセッサにより実行される際に、前記プロセッサに前記方法を実行させ、前記方法は、
前記第1の特徴抽出部が、ラベル付けされたソースデータセットにおけるソースデータについて第1の特徴を抽出するステップと、
前記第1の分類部が、前記第1の特徴に基づいて、前記ソースデータが複数のクラスのうちの各クラスに属する確率を予測するステップと、
前記第1の特徴抽出部が、ラベル付けされていないターゲットデータセットにおけるターゲットデータについて第2の特徴を抽出するステップと、
前記第1の分類部が、前記第2の特徴に基づいて、前記ターゲットデータが前記各クラスに属する確率を予測し、最大の確率に対応するクラスを前記ターゲットデータの第1のラベルとして決定するステップと、
ソースデータセットの前記各クラスについてのクラス中心と前記ターゲットデータの特徴との間の距離を計算し、距離が最も近いクラス中心に対応するクラスを前記ターゲットデータの第2のラベルとして決定するステップと、
前記ターゲットデータセットから、決定された前記第1のラベルと前記第2のラベルとが同一であるターゲットデータを選択するステップであって、前記第1のラベル又は前記第2のラベルは、選択されたターゲットデータの疑似ラベルとされる、ステップと、
選択されたターゲットデータに基づいて、前記ターゲットデータセットの前記各クラスについてのクラス中心を計算するステップと、
前記ソースデータセットのクラス中心と計算されたターゲットデータセットのクラス中心との間の距離に基づいて、第1の損失関数を構築するステップと、
選択されたターゲットデータ及びその疑似ラベルに基づいて、第2の損失関数を構築するステップと、
前記ソースデータセットにおけるソースデータ及び選択されたターゲットデータについて、第3の損失関数を構築するステップと、
前記第1の損失関数、前記第2の損失関数及び前記第3の損失関数に基づいて前記ドメイン適応型ニューラルネットワークを訓練するステップと、を含む、方法。 - 前記識別部が、前記第1の特徴及び前記第2の特徴に基づいて、現在入力されたデータがソースデータである確率を決定するステップと、
前記識別部により決定された確率に基づいて第4の損失関数を構築するステップと、
前記第4の損失関数に基づいて前記ドメイン適応型ニューラルネットワークを訓練するステップと、を含む、請求項1に記載の方法。 - 前記識別部により決定された確率の逆数、及び1から前記識別部により決定された確率を減算した差のうちの1つに基づいて、前記第4の損失関数を構築する、請求項2に記載の方法。
- 前記ドメイン適応型ニューラルネットワークは、第2の特徴抽出部及び第2の分類部をさらに含み、
前記方法は、
前記第2の特徴抽出部が前記ソースデータについて第3の特徴を抽出するステップと、
前記第2の分類部が、前記第3の特徴に基づいて、前記ソースデータが前記各クラスに属する確率を予測するステップと、
前記第2の特徴抽出部が前記ターゲットデータについて第4の特徴を抽出するステップと、
前記第2の分類部が、前記第4の特徴に基づいて、前記ターゲットデータが前記各クラスに属する確率を予測するステップと、
前記第1の分類部により予測された確率及び前記第2の分類部により予測された確率に基づいて、第5の損失関数を構築するステップと、
前記第5の損失関数に基づいて前記ドメイン適応型ニューラルネットワークを訓練するステップと、を含む、請求項2に記載の方法。 - 前記ソースデータについて前記第1の分類部及び前記第2の分類部のそれぞれにより予測された確率の平均二乗誤差、並びに前記ターゲットデータについて前記第1の分類部及び前記第2の分類部のそれぞれにより予測された確率の平均二乗誤差に基づいて、前記第5の損失関数を構築する、請求項4に記載の方法。
- 前記第2の特徴抽出部のパラメータは、前記第1の特徴抽出部のパラメータの指数移動平均であり、前記第2の分類部のパラメータは、前記第1の分類部のパラメータの指数移動平均であり、
前記指数移動平均で使用される減衰率を取得する際に、微分可能な変数を前記減衰率として使用し、或いは、全結合層を使用して前記減衰率を生成し、
前記全結合層は、前記第2の分類部の出力層と並列に前記出力層の直前の層に接続されるように設定されている、請求項4に記載の方法。 - 前記第1の損失関数、前記第2の損失関数、前記第3の損失関数、前記第4の損失関数及び前記第5の損失関数の重み付けされた組み合わせに基づいて、前記ドメイン適応型ニューラルネットワークを訓練し、
訓練の実行に伴って、前記第4の損失関数及び前記第5の損失関数の重みを徐々に増加させる、請求項4に記載の方法。 - 前記第2の損失関数は、前記第1の分類部を訓練するための交差エントロピー損失関数である、請求項1に記載の方法。
- ドメイン適応型ニューラルネットワークを訓練する装置であって、
前記ドメイン適応型ニューラルネットワークは、
ラベル付けされたソースデータセットにおけるソースデータについて第1の特徴を抽出し、ラベル付けされていないターゲットデータセットにおけるターゲットデータについて第2の特徴を抽出する第1の特徴抽出部と、
前記第1の特徴に基づいて、前記ソースデータが複数のクラスのうちの各クラスに属する確率を予測し、前記第2の特徴に基づいて、前記ターゲットデータが前記各クラスに属する確率を予測し、最大の確率に対応するクラスを前記ターゲットデータの第1のラベルとして決定する第1の分類部と、
前記第1の特徴及び前記第2の特徴に基づいて、現在入力されたデータがソースデータである確率を決定する識別部と、を含み、
前記装置は、
プログラムが記憶されたメモリと、
1つ又は複数のプロセッサと、を含み、
前記プロセッサは、前記プログラムを実行することで、
ソースデータセットの前記各クラスについてのクラス中心と前記ターゲットデータの特徴との間の距離を計算し、距離が最も近いクラス中心に対応するクラスを前記ターゲットデータの第2のラベルとして決定するステップと、
前記ターゲットデータセットから、決定された前記第1のラベルと前記第2のラベルとが同一であるターゲットデータを選択するステップであって、前記第1のラベル又は前記第2のラベルは、選択されたターゲットデータの疑似ラベルとされる、ステップと、
選択されたターゲットデータに基づいて、前記ターゲットデータセットの前記各クラスについてのクラス中心を計算するステップと、
前記ソースデータセットのクラス中心と計算されたターゲットデータセットのクラス中心との間の距離に基づいて、第1の損失関数を構築するステップと、
選択されたターゲットデータ及びその疑似ラベルに基づいて、第2の損失関数を構築するステップと、
前記ソースデータセットにおけるソースデータ及び選択されたターゲットデータについて、第3の損失関数を構築するステップと、
前記第1の損失関数、前記第2の損失関数及び前記第3の損失関数に基づいて前記ドメイン適応型ニューラルネットワークを訓練するステップと、を実行する、装置。 - ドメイン適応型ニューラルネットワークを訓練するプログラムが記憶された記憶媒体であって、前記ドメイン適応型ニューラルネットワークは、第1の特徴抽出部、第1の分類部、及び識別部を含み、前記プログラムがコンピュータにより実行される際に、前記コンピュータに請求項1乃至8の何れかに記載の方法を実行させる、記憶媒体。
Applications Claiming Priority (2)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202010911149.0 | 2020-09-02 | ||
CN202010911149.0A CN114139676A (zh) | 2020-09-02 | 2020-09-02 | 领域自适应神经网络的训练方法 |
Publications (1)
Publication Number | Publication Date |
---|---|
JP2022042487A true JP2022042487A (ja) | 2022-03-14 |
Family
ID=80438142
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
JP2021136658A Pending JP2022042487A (ja) | 2020-09-02 | 2021-08-24 | ドメイン適応型ニューラルネットワークの訓練方法 |
Country Status (2)
Country | Link |
---|---|
JP (1) | JP2022042487A (ja) |
CN (1) | CN114139676A (ja) |
Cited By (5)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN115219180A (zh) * | 2022-07-18 | 2022-10-21 | 西安交通大学 | 基于原型和双重域对抗的旋转机械跨工况故障诊断方法 |
CN116070796A (zh) * | 2023-03-29 | 2023-05-05 | 中国科学技术大学 | 柴油车排放等级评估方法及系统 |
CN116452897A (zh) * | 2023-06-16 | 2023-07-18 | 中国科学技术大学 | 跨域小样本分类方法、系统、设备及存储介质 |
CN117017288A (zh) * | 2023-06-14 | 2023-11-10 | 西南交通大学 | 跨被试情绪识别模型及其训练方法、情绪识别方法、设备 |
WO2024124455A1 (zh) * | 2022-12-14 | 2024-06-20 | 深圳市华大智造软件技术有限公司 | 碱基分类模型的训练方法及系统、碱基分类方法及系统 |
Families Citing this family (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN114399640B (zh) * | 2022-03-24 | 2022-07-15 | 之江实验室 | 一种不确定区域发现与模型改进的道路分割方法及装置 |
CN114445670B (zh) * | 2022-04-11 | 2022-07-12 | 腾讯科技(深圳)有限公司 | 图像处理模型的训练方法、装置、设备及存储介质 |
-
2020
- 2020-09-02 CN CN202010911149.0A patent/CN114139676A/zh active Pending
-
2021
- 2021-08-24 JP JP2021136658A patent/JP2022042487A/ja active Pending
Cited By (8)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN115219180A (zh) * | 2022-07-18 | 2022-10-21 | 西安交通大学 | 基于原型和双重域对抗的旋转机械跨工况故障诊断方法 |
CN115219180B (zh) * | 2022-07-18 | 2024-05-31 | 西安交通大学 | 基于原型和双重域对抗的旋转机械跨工况故障诊断方法 |
WO2024124455A1 (zh) * | 2022-12-14 | 2024-06-20 | 深圳市华大智造软件技术有限公司 | 碱基分类模型的训练方法及系统、碱基分类方法及系统 |
CN116070796A (zh) * | 2023-03-29 | 2023-05-05 | 中国科学技术大学 | 柴油车排放等级评估方法及系统 |
CN117017288A (zh) * | 2023-06-14 | 2023-11-10 | 西南交通大学 | 跨被试情绪识别模型及其训练方法、情绪识别方法、设备 |
CN117017288B (zh) * | 2023-06-14 | 2024-03-19 | 西南交通大学 | 跨被试情绪识别模型及其训练方法、情绪识别方法、设备 |
CN116452897A (zh) * | 2023-06-16 | 2023-07-18 | 中国科学技术大学 | 跨域小样本分类方法、系统、设备及存储介质 |
CN116452897B (zh) * | 2023-06-16 | 2023-10-20 | 中国科学技术大学 | 跨域小样本分类方法、系统、设备及存储介质 |
Also Published As
Publication number | Publication date |
---|---|
CN114139676A (zh) | 2022-03-04 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
JP2022042487A (ja) | ドメイン適応型ニューラルネットワークの訓練方法 | |
CN110472675B (zh) | 图像分类方法、图像分类装置、存储介质与电子设备 | |
CN108694443B (zh) | 基于神经网络的语言模型训练方法和装置 | |
CN113326731B (zh) | 一种基于动量网络指导的跨域行人重识别方法 | |
CN114841257B (zh) | 一种基于自监督对比约束下的小样本目标检测方法 | |
CN113392967A (zh) | 领域对抗神经网络的训练方法 | |
US20190370219A1 (en) | Method and Device for Improved Classification | |
CN110853630B (zh) | 面向边缘计算的轻量级语音识别方法 | |
CN114398855A (zh) | 基于融合预训练的文本抽取方法、系统及介质 | |
CN113255366B (zh) | 一种基于异构图神经网络的方面级文本情感分析方法 | |
CN114022737A (zh) | 对训练数据集进行更新的方法和设备 | |
CN112232395B (zh) | 一种基于联合训练生成对抗网络的半监督图像分类方法 | |
CN112527959A (zh) | 基于无池化卷积嵌入和注意分布神经网络的新闻分类方法 | |
CN115408525A (zh) | 基于多层级标签的信访文本分类方法、装置、设备及介质 | |
CN117648950A (zh) | 神经网络模型的训练方法、装置、电子设备及存储介质 | |
CN115063664A (zh) | 用于工业视觉检测的模型学习方法、训练方法及系统 | |
Deng et al. | Heterogeneous tri-stream clustering network | |
CN114328917A (zh) | 用于确定文本数据的标签的方法和设备 | |
CN116958548B (zh) | 基于类别统计驱动的伪标签自蒸馏语义分割方法 | |
CN117541853A (zh) | 一种基于类别解耦的分类知识蒸馏模型训练方法和装置 | |
CN116645980A (zh) | 一种聚焦样本特征间距的全生命周期语音情感识别方法 | |
CN112951270B (zh) | 语音流利度检测的方法、装置和电子设备 | |
CN114495114A (zh) | 基于ctc解码器的文本序列识别模型校准方法 | |
CN111259860B (zh) | 基于数据自驱动的多阶特征动态融合手语翻译方法 | |
US20230025148A1 (en) | Model optimization method, electronic device, and computer program product |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
A621 | Written request for application examination |
Free format text: JAPANESE INTERMEDIATE CODE: A621 Effective date: 20240509 |