JP2022042487A - ドメイン適応型ニューラルネットワークの訓練方法 - Google Patents

ドメイン適応型ニューラルネットワークの訓練方法 Download PDF

Info

Publication number
JP2022042487A
JP2022042487A JP2021136658A JP2021136658A JP2022042487A JP 2022042487 A JP2022042487 A JP 2022042487A JP 2021136658 A JP2021136658 A JP 2021136658A JP 2021136658 A JP2021136658 A JP 2021136658A JP 2022042487 A JP2022042487 A JP 2022042487A
Authority
JP
Japan
Prior art keywords
target data
loss function
feature
data
label
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Pending
Application number
JP2021136658A
Other languages
English (en)
Inventor
ワン・ジエ
Jie Wang
ジョオン・チャオリアン
Ciao-Lien Zheng
フォン・チョン
Cheng Feng
俊 孫
Shun Son
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
Fujitsu Ltd
Original Assignee
Fujitsu Ltd
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by Fujitsu Ltd filed Critical Fujitsu Ltd
Publication of JP2022042487A publication Critical patent/JP2022042487A/ja
Pending legal-status Critical Current

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • G06F18/241Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • G06F18/241Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
    • G06F18/2415Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches based on parametric or probabilistic models, e.g. based on likelihood ratio or false acceptance rate versus a false rejection rate
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/047Probabilistic or stochastic networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • Evolutionary Computation (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • Computing Systems (AREA)
  • Software Systems (AREA)
  • Molecular Biology (AREA)
  • Computational Linguistics (AREA)
  • Biophysics (AREA)
  • Biomedical Technology (AREA)
  • Mathematical Physics (AREA)
  • General Health & Medical Sciences (AREA)
  • Health & Medical Sciences (AREA)
  • Probability & Statistics with Applications (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Evolutionary Biology (AREA)
  • Image Analysis (AREA)

Abstract

【課題】ドメイン適応型ニューラルネットワークの訓練方法及び装置を提供する。【解決手段】方法は、ソースデータ及びターゲットデータの特徴に基づいてターゲットデータの第1のラベルを予測し、ソースデータセットの各クラスのクラス中心とターゲットデータの特徴との間の距離に基づいて第2のラベルを決定し、ターゲットデータセットから第1のラベルと第2のラベルとが同一のターゲットデータを選択し、ラベルをターゲットデータの疑似ラベルとし、ターゲットデータに基づいてターゲットデータセットの各クラスのクラス中心を計算し、ターゲットデータセットのクラス中心間の距離に基づいて第1の損失関数を構築し、ターゲットデータ及び疑似ラベルに基づいて第2の損失関数を構築し、ソースデータセットのソースデータ及びターゲットデータについて第3の損失関数を構築し、第1~第3の損失関数に基づいてニューラルネットワークを訓練する。【選択図】図3

Description

本発明は、一般的に、ドメイン適応(domain adaptation)関し、具体的には、教師なしドメイン適応のためのニューラルネットワーク及びその訓練(トレーニング)方法に関する。
教師なしドメイン適応とは、ラベル付けされたソースデータを用いて訓練されたモデルをラベル付けされていないデータのターゲットドメインに移行すると共に、該モデルのターゲットドメインでのパフォーマンスを可能な限り維持することを意味する。ソースドメインとターゲットドメインとの間にデータセットの偏差があり、ターゲットドメインにラベル付きデータが不十分であるため、ラベル付けされたソースデータを用いて訓練されたモデルのパフォーマンスがターゲットドメインで低下する場合がある。教師なしドメイン適応の訓練プロセスは、ソースドメインのラベル付きデータ及びターゲットドメインのラベルなしデータの両方を利用することで、ドメイン間の差異を効果的に軽減し、モデルの堅牢性を向上させることができる。
現在、教師なしドメイン適応の主流の方法は、例えば敵対的訓練などのドメイン不変の特徴(domain-invariant features)の学習方法を含む。典型的な敵対的訓練方法は、ドメイン敵対的ニューラルネットワークである。このニューラルネットワークでは、特徴抽出ネットワークの後にドメイン識別器を追加して、特徴がソースドメインからのものであるか、それともターゲットドメインからのものであるかを判断し、特徴抽出ネットワークとドメイン識別器との間に勾配反転層を追加する。ドメイン識別器の損失関数を最小化する場合、特徴抽出ネットワークは、勾配反転層によりドメイン不変の特徴を学習することができる。
さらに、知識の蒸留(knowledge distillation)は、最近、教師なしドメイン適応に導入され、次のような多くの新しい方法が提案されている。例えば、自己アンサンブリングの教師モデルを使用して、生徒モデルにターゲットドメインのラベルなしデータを学習させる。自己アンサンブリングの教師モデルを利用して、より正確なターゲットデータの疑似ラベルを取得する。ソースデータからターゲットデータに類似するデータを蒸留して、事前訓練モデルを微調整する。セマンティックレベル(クラスレベル)でソースドメインの特徴とターゲットドメインの特徴とのアラインメントを行い、即ち、ソースドメインとターゲットドメインとの同一のクラスの平均特徴(クラス中心)を接近させる。
以下は、これらの従来の方法を簡単に紹介する。
図1は、典型的なドメイン敵対的ニューラルネットワークの構造を示している。図1に示すように、ドメイン敵対的ニューラルネットワークは、特徴抽出器F、分類器C、及びドメイン識別器Dを含む。ドメイン識別器Dは、勾配反転層を介して特徴抽出器Fに接続され、勾配反転層は、勾配に特定の負の数を乗算して、特徴抽出器Fに送り返す。Iはラベル付けされたソースデータを表し、Iはラベル付けされていないターゲットデータを表し、両者は何れも特徴抽出器Fに入力される。ソースデータについて特徴抽出器Fにより抽出された特徴は、ソースデータのクラスを予測するために、分類器Cに入力される。また、ソースデータとターゲットデータの両方について特徴抽出器Fにより抽出された特徴が何れもドメイン識別器Dに入力され、ドメイン識別器Dは、入力された特徴に基づいて、現在処理されているデータがソースドメインからのものであるか、それともターゲットドメインからのものであるかを識別する。ドメイン敵対的ニューラルネットワークの訓練では、ソースドメインの分類交差エントロピー損失関数L及びドメイン識別のバイナリ交差エントロピー損失関数Ladvを使用して、損失関数L及びLadvを最小化するように、標準の逆伝播アルゴリズムに従って訓練を行うことで、特徴抽出器Fにドメイン不変の特徴を学習させる。
図2は自己アンサンブリング型教師モデルの構造を示し、ここで、生徒ネットワークのパラメータの指数移動平均を使用して教師ネットワークを構築する。図2では、xSiはラベル付けされたソースデータを表し、xTiはラベル付けされていないターゲットデータを表し、ySiはソースデータの真のラベルを表し、zTiは生徒ネットワークによるターゲットデータの予測確率を表し、
(外1)
Figure 2022042487000002
は教師ネットワークによるターゲットデータの予測確率を表す。
このスキームの前提としては、教師ネットワークの予測正確率が生徒ネットワークの予測正確率よりも高いと仮定する。生徒ネットワークが教師ネットワークの予測確率からターゲットデータの隠れた知識を学習できるため、該スキームは知識の蒸留である。ソースデータxSiについて、生徒ネットワークの予測確率zTiと真のラベルySiに基づく交差エントロピー損失関数が採用されている。ターゲットデータxTiについて、教師ネットワークの予測確率
(外2)
Figure 2022042487000003
と生徒ネットワークの予測確率zTiの平均二乗誤差が損失関数として使用されている。そして、上記の2つの損失関数に対して重み付け加算を行い、最終的な損失関数を取得する。
さらに、セマンティックレベルでの特徴のアラインメントに関して、以下の損失関数が提案されている。
Figure 2022042487000004
Figure 2022042487000005
ここで、Xs,kはソースドメインXにおけるk番目のクラスに属する全てのデータサンプル(真のラベルに基づいて決定される)を表し、Xt,kはターゲットドメインXにおけるk番目のクラスとしてラベル付けされた全てのデータサンプルを表す(疑似ラベルに基づいて決定される)。λs,kは、ソースドメインにおけるk番目のクラスのクラス中心、即ち、k番目のクラスに属する全てのソースデータの特徴Fの平均値を表す。同様に、λt,kは、ターゲットドメインにおけるk番目のクラスのクラス中心、即ち、k番目のクラスとしてラベル付けされた全てのターゲットデータの特徴Fの平均値を表す。ターゲットデータの疑似ラベルは、分類器を使用してターゲットデータのクラスを予測することで取得される。数式(1)に示すセマンティックアラインメント損失関数L(X,X)は、ソースドメインとターゲットドメインにおける同一のクラスのクラス中心の間の距離を表す。
上記の方法は良好な結果を達成しているが、まだ改善する必要な幾つかの問題がある。まず、セマンティックアラインメントでは、ターゲットデータの疑似ラベルの正確さは、ターゲットドメインにおけるクラス中心に大きな影響を与える。分類境界近傍にある一部のデータについて、疑似ラベルが誤っていると、クラス中心の計算結果に大きな偏差が発生する。また、比較学習では、誤った疑似ラベルは、クラス内のデータサンプルのクラスタリング及びクラス間のデータサンプルの分離の制約を損なう。さらに、自己アンサンブリングの平均教師モデルでは、指数移動平均に固定の減衰率が使用される場合が多いが、現在のモデルの性能が可変であるたね、固定の減衰率では現在のモデルの性能に応じてアンサンブリングの速度を調整できない。さらに、蒸留データを使用して微調整を行う場合、この方法では2つの段階が必要であり、中間の切り替え動作が追加され、訓練を一括的に完了することができない。
本発明は、ドメイン適応型ニューラルネットワークの訓練方法及び装置を提供する。
本発明の1つの態様では、コンピュータにより実行される、ドメイン適応型ニューラルネットワークを訓練する方法であって、前記ドメイン適応型ニューラルネットワークは、第1の特徴抽出部、第1の分類部、及び識別部を含み、前記コンピュータは、命令が記憶されたメモリ、及びプロセッサを含み、前記命令は、前記プロセッサにより実行される際に、前記プロセッサに前記方法を実行させ、前記方法は、前記第1の特徴抽出部が、ラベル付けされたソースデータセットにおけるソースデータについて第1の特徴を抽出するステップと、前記第1の分類部が、前記第1の特徴に基づいて、前記ソースデータが複数のクラスのうちの各クラスに属する確率を予測するステップと、前記第1の特徴抽出部が、ラベル付けされていないターゲットデータセットにおけるターゲットデータについて第2の特徴を抽出するステップと、前記第1の分類部が、前記第2の特徴に基づいて、前記ターゲットデータが前記各クラスに属する確率を予測し、最大の確率に対応するクラスを前記ターゲットデータの第1のラベルとして決定するステップと、ソースデータセットの前記各クラスについてのクラス中心と前記ターゲットデータの特徴との間の距離を計算し、距離が最も近いクラス中心に対応するクラスを前記ターゲットデータの第2のラベルとして決定するステップと、前記ターゲットデータセットから、決定された前記第1のラベルと前記第2のラベルとが同一であるターゲットデータを選択するステップであって、前記第1のラベル又は前記第2のラベルは、選択されたターゲットデータの疑似ラベルとされる、ステップと、選択されたターゲットデータに基づいて、前記ターゲットデータセットの前記各クラスについてのクラス中心を計算するステップと、前記ソースデータセットのクラス中心と計算されたターゲットデータセットのクラス中心との間の距離に基づいて、第1の損失関数を構築するステップと、選択されたターゲットデータ及びその疑似ラベルに基づいて、第2の損失関数を構築するステップと、前記ソースデータセットにおけるソースデータ及び選択されたターゲットデータについて、第3の損失関数を構築するステップと、前記第1の損失関数、前記第2の損失関数及び前記第3の損失関数に基づいて前記ドメイン適応型ニューラルネットワークを訓練するステップと、を含む、方法を提供する。
本発明のもう1つの態様では、ドメイン適応型ニューラルネットワークを訓練する装置であって、前記ドメイン適応型ニューラルネットワークは、ラベル付けされたソースデータセットにおけるソースデータについて第1の特徴を抽出し、ラベル付けされていないターゲットデータセットにおけるターゲットデータについて第2の特徴を抽出する第1の特徴抽出部と、前記第1の特徴に基づいて、前記ソースデータが複数のクラスのうちの各クラスに属する確率を予測し、前記第2の特徴に基づいて、前記ターゲットデータが前記各クラスに属する確率を予測し、最大の確率に対応するクラスを前記ターゲットデータの第1のラベルとして決定する第1の分類部と、前記第1の特徴及び前記第2の特徴に基づいて、現在入力されたデータがソースデータである確率を決定する識別部と、を含み、前記装置は、プログラムが記憶されたメモリと、1つ又は複数のプロセッサと、を含み、前記プロセッサは、前記プログラムを実行することで、ソースデータセットの前記各クラスについてのクラス中心と前記ターゲットデータの特徴との間の距離を計算し、距離が最も近いクラス中心に対応するクラスを前記ターゲットデータの第2のラベルとして決定するステップと、前記ターゲットデータセットから、決定された前記第1のラベルと前記第2のラベルとが同一であるターゲットデータを選択するステップであって、前記第1のラベル又は前記第2のラベルは、選択されたターゲットデータの疑似ラベルとされる、ステップと、選択されたターゲットデータに基づいて、前記ターゲットデータセットの前記各クラスについてのクラス中心を計算するステップと、前記ソースデータセットのクラス中心と計算されたターゲットデータセットのクラス中心との間の距離に基づいて、第1の損失関数を構築するステップと、選択されたターゲットデータ及びその疑似ラベルに基づいて、第2の損失関数を構築するステップと、前記ソースデータセットにおけるソースデータ及び選択されたターゲットデータについて、第3の損失関数を構築するステップと、前記第1の損失関数、前記第2の損失関数及び前記第3の損失関数に基づいて前記ドメイン適応型ニューラルネットワークを訓練するステップと、を実行する、装置を提供する。
本発明のもう1つの態様では、ドメイン適応型ニューラルネットワークを訓練するプログラムが記憶された記憶媒体であって、前記ドメイン適応型ニューラルネットワークは、第1の特徴抽出部、第1の分類部、及び識別部を含み、前記プログラムがコンピュータにより実行される際に、前記コンピュータに、前記第1の特徴抽出部が、ラベル付けされたソースデータセットにおけるソースデータについて第1の特徴を抽出するステップと、前記第1の分類部が、前記第1の特徴に基づいて、前記ソースデータが複数のクラスのうちの各クラスに属する確率を予測するステップと、前記第1の特徴抽出部が、ラベル付けされていないターゲットデータセットにおけるターゲットデータについて第2の特徴を抽出するステップと、前記第1の分類部が、前記第2の特徴に基づいて、前記ターゲットデータが前記各クラスに属する確率を予測し、最大の確率に対応するクラスを前記ターゲットデータの第1のラベルとして決定するステップと、ソースデータセットの前記各クラスについてのクラス中心と前記ターゲットデータの特徴との間の距離を計算し、距離が最も近いクラス中心に対応するクラスを前記ターゲットデータの第2のラベルとして決定するステップと、前記ターゲットデータセットから、決定された前記第1のラベルと前記第2のラベルとが同一であるターゲットデータを選択するステップであって、前記第1のラベル又は前記第2のラベルは、選択されたターゲットデータの疑似ラベルとされる、ステップと、選択されたターゲットデータに基づいて、前記ターゲットデータセットの前記各クラスについてのクラス中心を計算するステップと、前記ソースデータセットのクラス中心と計算されたターゲットデータセットのクラス中心との間の距離に基づいて、第1の損失関数を構築するステップと、選択されたターゲットデータ及びその疑似ラベルに基づいて、第2の損失関数を構築するステップと、前記ソースデータセットにおけるソースデータ及び選択されたターゲットデータについて、第3の損失関数を構築するステップと、前記第1の損失関数、前記第2の損失関数及び前記第3の損失関数に基づいて前記ドメイン適応型ニューラルネットワークを訓練するステップと、を含む、方法を実行させる、記憶媒体を提供する。
従来のドメイン敵対的ニューラルネットワークの構造を概略的に示す図である。 従来の自己アンサンブリング型教師モデルの構造を概略的に示す図である。 本発明に係るドメイン適応型ニューラルネットワークの構造を概略的に示す図である。 本発明に係る自己アンサンブリング型教師モデルの構造を概略的に示す図である。 重みλの曲線を示す図である。 重みλの曲線を示す図である。 本発明に係る最適ターゲットデータセットの生成方法を示すフローチャートである。 本発明に係るドメイン適応型ニューラルネットワークの訓練方法を示すフローチャートである。 本発明に係るドメイン適応型ニューラルネットワークの訓練装置のモジュール化の構成を示すブロック図である。 本発明を実現可能なコンピュータのハードウェアの例示的な構成を示すブロック図である。
図3は、本発明に係る教師なしドメイン適応のニューラルネットワークの構造を概略的に示している。図3に示すように、該ニューラルネットワークは、図1を参照しながら説明されたドメイン敵対的ニューラルネットワークを含み、該ニューラルネットワークは、第1の特徴抽出器310、第1の分類器320、ドメイン識別器330、及び勾配反転層(図示せず)を含む。また、ニューラルネットワークは、第2の特徴抽出器310_T及び第2の分類器320_Tをさらに含む。なお、既存の技術として、図3における第1の特徴抽出器310、第2の特徴抽出器310_T、第1の分類器320、第2の分類器320_T、及びドメイン識別器330は何れも畳み込みニューラルネットワークにより実現されてもよい。本明細書では、これらのユニットを実現する畳み込みニューラルネットワークの構造について詳細に説明しない。
第1の特徴抽出器310及び第1の分類器320は、生徒ネットワークを構成し、第2の特徴抽出器310_T及び第2の分類器320_Tは、教師ネットワークを構成する。第2の(教師)特徴抽出器310_Tのパラメータは、第1の(生徒)特徴抽出器310のパラメータの指数移動平均であり、第2の(教師)分類器320_Tのパラメータは、第1の(生徒)分類器320のパラメータの指数移動平均である。
ソースデータX及びターゲットデータXは、第1の特徴抽出器310及び第2の特徴抽出器310_Tのそれぞれに入力される。第1の特徴抽出器310は、ソースデータX及びターゲットデータXについて抽出された特徴を第1の分類器320に入力し、第2の特徴抽出器310_Tは、ソースデータX及びターゲットデータXについて抽出された特徴を第2の分類器320_Tに入力する。
図3に示されるドメイン適応型ニューラルネットワークの訓練では、本発明は、複数の損失関数を提案し、以下はその詳細を説明する。
本発明の1つの態様では、ターゲットデータの疑似ラベルの精度を向上させるための投票スキーム(voting scheme)が提案される。投票スキームとは、少なくとも2つの予測方式を使用してターゲットデータの予測ラベルに投票することを意味する。例えば、ターゲットデータ
(外3)
Figure 2022042487000006
について、分類器を利用してそのクラスラベルを予測して予測結果
(外4)
Figure 2022042487000007
を取得する。また、クラス中心最近傍アルゴリズムを使用してそのラベルを予測して、以下の数式(2)及び(3)に示すように、予測結果lを取得する。
Figure 2022042487000008
Figure 2022042487000009
ここで、λs,kは、ソースドメインにおけるk番目のクラスのクラス中心、即ち、k番目のクラスに属する全てのソースデータの特性の平均値を表し、Kは、ソースドメインにおける全てのクラスの数を表し、lは、ソースドメインの全てのK個のクラス中心のうち、ターゲットデータ
(外5)
Figure 2022042487000010
に最も近いクラス中心に対応するクラスを表す。
予測ラベルlと予測ラベルlとが一致する場合、該ターゲットデータ
(外6)
Figure 2022042487000011
を選択し、予測ラベルl又はlを該ターゲットデータ
(外7)
Figure 2022042487000012
の疑似ラベルとする。予測ラベルlと予測ラベルlとが一致しない場合、該ターゲットデータ
(外8)
Figure 2022042487000013
を無視する。選択された全てのターゲットデータ
(外9)
Figure 2022042487000014
は、最適ターゲットデータセット
(外10)
Figure 2022042487000015
を構成する。分類器予測のみを実行する場合、又はクラス中心最近傍予測のみを実行する場合に比べて、このようにフィルタリングにより選択されたデータセット
(外11)
Figure 2022042487000016
における各ターゲットデータの疑似ラベルの正確率が高くなる。従って、本発明に係る投票スキームは、予測結果がより正確なターゲットデータを効果的に選別することができる。
なお、上述した分類器予測及びクラス中心最近傍予測は、単なる少なくとも2つの異なる予測方式の一例であり、本発明はこれに限定されず、当業者は他の適切な予測方式を容易に想到できる。
次に、最適ターゲットデータセット
(外12)
Figure 2022042487000017
に基づいて、図3に示すニューラルネットワークを訓練するためのセマンティックアラインメント損失関数L(図3には示されていない)を構築し、該セマンティックアラインメント損失関数Lは、第1の損失関数とも称される。具体的には、K個の所定のクラスのうちのk番目のクラスを一例として、まず、数式(2)に従ってソースドメインにおけるk番目のクラスのクラス中心λs,kを計算し、以下の数式(4)に従って最適ターゲットデータセット
(外13)
Figure 2022042487000018
におけるk番目のクラスのクラス中心λt,kを計算し、数式(5)に従ってクラス中心λs,kとクラス中心λt,kとの間の距離
(外14)
Figure 2022042487000019
を計算する。このように、全てのk個のクラスについて、ソースドメインのクラス中心とターゲットドメインのクラス中心との間の距離を、セマンティックアラインメント損失関数としてそれぞれ計算する。訓練では、距離
(外15)
Figure 2022042487000020
を最小化することを目標とする。
Figure 2022042487000021
Figure 2022042487000022
最適ターゲットデータセット
(外16)
Figure 2022042487000023
におけるターゲットデータの疑似ラベルがより高い精度を有するため、データセット
(外17)
Figure 2022042487000024
を使用して算出されたターゲットドメインのクラス中心λt,kがより正確であり、セマンティクスアラインメントロス損失関数を向上させることができる。
さらに、最適ターゲットデータセット
(外18)
Figure 2022042487000025
におけるターゲットデータ及びその疑似ラベルを使用して、図3に示す第1の分類器320を訓練するための交差エントロピー損失関数(図3における
(外19)
Figure 2022042487000026
)を構築してもよく、該交差エントロピー損失関数は第2の損失関数とも称され、具体的には、以下の数式(6)に示すものである。
Figure 2022042487000027
ここで、
(外20)
Figure 2022042487000028
が最適ターゲットデータセット
(外21)
Figure 2022042487000029
におけるターゲットデータ
(外22)
Figure 2022042487000030
に対してラベルを予測することを表す場合、予測結果は、その疑似ラベルの確率である。
従来技術では、通常、真のラベルを有するソースデータのみを使用して第1の分類器320を訓練するが、本発明では、最適ターゲットデータセット
(外23)
Figure 2022042487000031
におけるターゲットデータの疑似ラベルの精度が比較的に高いため、本発明は、最適ターゲットデータセット
(外24)
Figure 2022042487000032
をさらに使用して第1の分類器320を訓練し、これによって、ネットワークモデルのターゲットデータへの認識能力を向上させることができる。
さらに、最適ターゲットデータセット
(外25)
Figure 2022042487000033
におけるターゲットデータをソースデータと共に対比学習に使用することで、次の効果を実現することができる。クラス内の特徴を収束させるように制約すると共に、異なるクラスの特徴間の距離が大きくなるようにクラス間の特徴を離間させる。ここで、例えば数式(7)に示す対比学習損失関数Lcon(図3に示されていない)を構築してもよく、該対比学習損失関数Lconは第3の損失関数とも称される。
Figure 2022042487000034
ここで、x又はxは、ソースデータセット及び最適ターゲットデータセット
(外26)
Figure 2022042487000035
におけるデータサンプルを表し、f(x)及びf(x)は、データサンプルの特徴を表す。δijは、指示変数であり、xとxが同一のクラスのデータである場合、δijは1であり、xとxが異なるクラスのデータである場合、δijは0である。d(f(x),f(x))は、データxの特徴とデータxとの間の距離を表す。mは定数であり、例えば、m=3。
上述したように、教師なし領域適応に使用される現在の知識蒸留方法は、指数移動平均を使用して教師ネットワークを構築するが、その減衰率は通常固定値に設定されているため、性能が優れた教師ネットワークを取得することは困難である。具体的には、指数移動平均とは、以下の数式(8)に示すように、特定の減衰率に基づいて教師ネットワークのパラメータをゆっくりと更新することを意味する。
Figure 2022042487000036
ここで、Sは、生徒ネットワークの現在のパラメータを表し、Tは、教師ネットワークの現在のパラメータ(更新後のパラメータ)を表し、Tt-1は、教師ネットワークの前のパラメータ(未更新のパラメータ)を表し、減衰率decayは、通常0.99に固定的に設定されている。
本発明のもう1つの態様では、本発明は、教師モデルの性能を改善するために、自己学習の減衰率を提案している。「自己学習」とは、減衰率が学習可能なパラメータ又は学習済みネットワーク(learnt network)の出力であることを意味する。本発明では、微分可能な変数を減衰率として使用し、或いは、1つの全結合層の出力を減衰率として使用してもよい。後者の場合、例えば該全結合層が出力層と並列に出力層の直前の層に接続されるように、該全結合層を第2の分類器320_Tの出力層と同一のレベルに設定してもよい。このように設定された減衰率が固定値ではなく、モデルのパフォーマンスの変化に応じてアンサンブリングの速度を調整することができるため、知識蒸留のパフォーマンスを向上させることができる。
また、本発明のもう1つの態様では、本発明は、ドメイン識別器に基づくデータ蒸留が提案されている。具体的には、ソースデータを使用して交差エントロピー損失関数に基づいて分類器を訓練する場合、ソースデータのうちのターゲットデータと類似するソースデータにより高い重みを与える。これによって、ターゲットデータとの類似度が高いソースデータが訓練でより大きな役割を果たすことができるため、訓練により取得された分類器は、ターゲットドメインでより優れたパフォーマンスを実現することができる。
ドメイン識別器の出力を使用して、ターゲットデータとの類似度が高いソースデータであるか否かを判断してもよい。ドメイン識別器は、現在のデータがソースデータである確率を予測できるため、この確率が小さいほど、現在のデータとターゲットデータとの類似度が高いことを意味する。言い換えれば、ドメイン識別器により出力された確率と類似度との間には反比例の関係がある。従って、ドメイン識別器の出力を使用してソースデータを重み付けしてもよい。
この原理に従って、以下の数式(9)又は(10)に示すように、図3に示すニューラルネットワークを訓練するためのデータ蒸留損失関数Ldd(図3に示されていない)を構築してもよく、該データ蒸留損失関数Lddは、第4の損失関数とも称される。
Figure 2022042487000037
Figure 2022042487000038
ここで、pは、ソースデータについてラベルを予測する場合、予測結果がその真のラベルである確率を表す。pは、ドメイン識別器により決定されたソースデータがソースドメインからのものである確率を表し、1-p又は1/pは、該ソースデータに割り当てられた重みを表す。
ドメイン識別器により決定された確率pが比較的に小さい場合(現在のソースデータとターゲットデータとの類似度が比較的に高いことを意味する)、1-p又は1/pの値が大きくなるため、現在のソースデータに与えられる重みが大きくなる。従って、(ターゲットデータと類似する)現在のソースデータは、訓練でより大きな役割を果たすことができる。
また、本発明のもう1つの態様では、本発明は、図2に示す自己アンサンブリング型教師モデルの構造をさらに改善する。図4は改善されたネットワーク構造を示している。
図4に示すように、ソースデータxSi及びターゲットデータxTiが生徒ネットワークだけでなく、教師ネットワークにも入力される。これに対して、図2では、ターゲットデータxTiのみが教師ネットワークに入力される。従って、本発明は、ターゲットドメインに対して蒸留学習を行うだけでなく、ソースドメインに対しても蒸留学習を行う。
図4では、ySiは、ソースデータxSiの真のラベルを表し、zTiは、生徒ネットワークによりターゲットデータxTiについて予測された確率(即ち、ターゲットデータxTiが各クラスに属する確率)を表し、
(外27)
Figure 2022042487000039
は、教師ネットワークによりターゲットデータxTiについて予測された確率を表し、ZSiは、生徒ネットワークによりソースデータxSiについて予測された確率(即ち、ソースデータxSiが各クラスに属する確率)を表し、
(外28)
Figure 2022042487000040
は、教師ネットワークによりソースデータxSiについて予測された確率を表す。また、図4における生徒ネットワークは、図3に示す第1の特徴抽出器310及び第1の分類器320を含んでもよく、図4における教師ネットワークは、図3に示す第2の特徴抽出器310_T及び第2の分類器320_Tを含んでもよく、上記の各予測確率は、第1の分類器320又は第2の分類器320_Tにより生成されてもよい。
以下の数式(11)に示すように、上記の予測確率に基づいて、図3に示すニューラルネットワークを訓練するための知識蒸留損失関数Lkd(図3におけるLkd-s及びLkd-tを含む)を構築してもよく、該知識蒸留損失関数Lkdは、第5の損失関数とも称される。
Figure 2022042487000041
ここで、
(外29)
Figure 2022042487000042
は、ソースデータxSiについて第1の分類器320及び第2の分類器320_Tのそれぞれにより予測された確率の平均二乗誤差を表し、
(外30)
Figure 2022042487000043
は、ターゲットデータxTiについて第1の分類器320及び第2の分類器320_Tのそれぞれにより予測された確率の平均二乗誤差を表す。nは、ソースデータの数を表し、mは、ターゲットデータの数を表す。
数式(12)に示すように、上述した第1の損失関数~第5の損失関数に基づいて、図3に示すニューラルネットワークを訓練するための最終損失関数Lを構築してもよい。
Figure 2022042487000044
ここで、Lc-sは、ソースデータについての分類交差エントロピー損失関数を表し、図1に示す損失関数Lと同一である。Ladvは、ドメイン識別器のバイナリ交差エントロピー損失関数を表し、図1に示す損失関数Ladvと同一である。損失関数Lc-s及びLadvは、従来技術における既知の損失関数であるため、本明細書ではそれらの詳細な説明を省略する。
また、数式(12)におけるλとλは、それぞれ第4の損失関数Lkdと第5の損失関数Lddを重み付けするための重みであり、訓練プロセスにおいて第4の損失関数と第5の損失関数の作用の程度を制御してもよい。具体的には、重みλは、数式(13)に従って決定されてもよい。
Figure 2022042487000045
ここで、p=step/totalstep、即ち、現在の反復ステップ数を訓練ステップの総数で除算した商であるため、pは訓練の進行状況を表すことができる。α及びnは、ハイパーパラメータを表し、例えば、α=200、n=10に設定してもよい。図5は、重みλの訓練ステップ数の増加に伴って変化する曲線を示す図である(訓練ステップの総数が5000であると仮定する)。
重みλ2は、数式(14)に従って決定されてもよい。
Figure 2022042487000046
ここで、pは、数式(13)におけるpと同一の意味を持つ。α及びnは、ハイパーパラメータを表し、例えば、α=5、n=10に設定してもよい。図6は、重みλの訓練ステップ数の増加に伴って変化する曲線を示す図である(訓練ステップの総数が5000であると仮定する)。
図5及び図6に示すように、訓練の開始段階において、分類器の予測及びドメイン識別器の予測が何れも正確ではないため、好ましくは、λ及びλの値を小さく設定し、訓練の進行に伴って、教師ネットワークの分類器及びドメイン識別器の予測が徐々に正確になるため、λ及びλの値を徐々に増大させてもよい。これによって、知識蒸留損失関数Lkd及びデータ蒸留損失関数Lddはより大きな役割を果たすことができる。
図7は、本発明に係る最適ターゲットデータセットの生成方法を示すフローチャートである。該方法は、図9における最適ターゲットデータセット生成部960により実行されてもよい。
図7に示すように、ステップS710において、第1の特徴抽出器310がソースデータについて特徴を抽出し、第1の分類器320が抽出された特徴に基づいて、ソースデータが複数の所定クラスのうちの各クラスに属する確率を予測する。最大の確率に対応するクラスを該ソースデータのラベルとして決定する。
ステップS720において、第1の特徴抽出器310がターゲットデータについて特徴を抽出し、第1の分類器320が抽出された特徴に基づいて、ターゲットデータが各クラスに属する確率を予測する。最大の確率に対応するクラスは、該ターゲットデータの第1のラベルとして決定される。
ステップS730において、数式(2)及び(3)に従って、クラス中心最近傍アルゴリズムを使用して該ターゲットデータの第2のラベルを決定する。
ステップS740において、決定された第1のラベルと第2のラベルとが同一であるターゲットデータを選択し、該第1のラベル又は第2のラベルは、選択されたターゲットデータの疑似ラベルとされる。そして、全ての選択されたターゲットデータは、最適ターゲットデータセットを構成してもよい。
図8は、本発明に係るドメイン適応型ニューラルネットワークの訓練方法を示すフローチャートであり、図9は、本発明に係るドメイン適応型ニューラルネットワークの訓練装置のモジュール化の構成を示すブロック図である。
図8に示すように、ステップS810において、数式(2)、(4)及び(5)に従って、ソースデータセットのクラス中心と最適ターゲットデータセットのクラス中心との間の距離に基づいて、第1の損失関数L(セマンティックアラインメント損失関数)を構築する。このステップは、図9における第1の損失関数生成部910により実行されてもよい。
ステップS820において、数式(6)に従って、最適ターゲットデータセットにおけるターゲットデータ及びその疑似ラベルに基づいて、第2の損失関数
(外31)
Figure 2022042487000047
(交差エントロピー損失関数)を構築する。このステップは、図9における第2の損失関数生成部920により実行されてもよい。
ステップS830において、数式(7)に従って、ソースデータセットにおけるソースデータ及び最適ターゲットデータセットにおけるターゲットデータについて、第3の損失関数Lcon(対比学習損失関数)を構築する。このステップは、図9における第3の損失関数生成部930により実行されてもよい。
図9から分かるように、図7に示す方法により生成された最適ターゲットデータセットは、第1の損失関数~第3の損失関数を構築するために使用される。
次に、ステップS840において、数式(9)又は(10)に従って、ドメイン識別器により出力された確率に基づいて、第4の損失関数Ldd(データ蒸留損失関数)を構築する。このステップは、図9における第4の損失関数生成部940により実行されてもよい。
ステップS850において、第2の(教師)特徴抽出器310_Tがソースデータ及びターゲットデータの特徴を抽出し、第2の(教師)分類器320_Tがソースデータ及びターゲットデータのラベルを予測する。次に、ステップS860において、数式(11)に従って、第1の分類器320の予測結果及び第2の分類器320_Tの予測結果に基づいて、第5の損失関数Lkd(知識蒸留損失関数)を構築する。ステップS860は、図9における第5の損失関数生成部950により実行されてもよい。
次に、ステップS870において、数式(12)に従って、第1の損失関数~第5の損失関数の重み付けされた組み合わせに基づいて、ニューラルネットワークを訓練する。このステップは、図9における訓練部970により実行されてもよい。
なお、本発明の訓練方法は、必ずしも図8に示す順序で実行する必要がない。例えば、第1の損失関数~第5の損失関数の生成の順序は、図示されているものとは異なってもよいし、同時に生成されてもよい。
本発明の発明者は、MNIST、USPS、SVHN(何れも既知の文字データセットである)に基づいてテストを行った。ここで、MNIST→USPS、USPS→MNIST、SVHN→MNISTの3つの方向のドメイン適応を含む。以下の表1は、本発明の解決手段と従来技術(ADDA、DANNなど)とのパフォーマンスの比較を示している。表1における値は、分類の正解率を表し、正解率が高いほど、解決手段のパフォーマンスが良くなる。表1から分かるように、本発明の解決手段のパフォーマンスは、従来技術のパフォーマンスと同等であり、或いはそれよりも優れている。
Figure 2022042487000048
特に、表1における「source only」は、ターゲットデータを使用せず、ソースデータのみを使用して訓練を行う方式、即ち最も単純な方式を表し、比較の基準とされる。DANN(Domain-Adversarial Training of Neural Networks)は、図1に示すドメイン敵対的ニューラルネットワークを表し、ADDA(Adversarial Discriminative Domain Adaptation)は、敵対的識別ドメイン適応を表す。CAT+RevGradは、「Cluster Alignment with a Teacher for Unsupervised Domain Adaptation[C]」、Deng Z et al.、Proceedings of IEEE International Conference on Computer Vision、2019:9944-9953という技術文献に記載されている。
本発明に係るドメイン適応方法は、幅広いドメインに適用することができ、以下は、代表的な適用シナリオを例示して説明する。
[応用シナリオ1]セマンティックセグメンテーション(semantic segmentation)
セマンティックセグメンテーションとは、画像における異なる物体を表す部分を異なる色でマークすることを意味する。セマンティックセグメンテーションの応用シナリオでは、実世界の画像に対して手動でラベルを付けるコストが非常に高いため、実世界の画像はラベルが殆ど付いていない。この場合、代替的な方法として、シミュレーション環境(例えば3Dゲーム)におけるシーンの画像を用いて訓練を行う。プログラミングによりシミュレーション環境において物体の自動的なラベル付けを容易に実現できるため、ラベル付きデータを簡単に取得することができる。このように、シミュレーション環境において生成されたラベル付きデータを用いてモデルを訓練し、訓練されたモデルを用いて実際の環境の画像を処理する。しかし、シミュレーション環境は、実際の環境と完全に一致するものではないため、シミュレーション環境のデータを用いて訓練されたモデルのパフォーマンスは、実際の環境の画像を処理する場合に大幅に低下してしまう。
この場合、本発明に係るドメイン適応方法を用いて、ラベル付けされたシミュレーション環境のデータ及びラベル付けされていない実際環境のデータに基づいて訓練を行うことができるため、実際の環境の画像を処理する際のモデルのパフォーマンスを向上させることができる。
[応用シナリオ2]手書き文字の認識
手書き文字は、通常、手書きの数字、文字(例えば中国語、日本語)などを含む。手書き文字の認識では、一般的に使用されるラベル付き文字セットは、MNIST、USPS、SVHNなどを含み、通常、これらのラベル付き文字データを用いてモデルを訓練する。しかし、訓練済みモデルを実際(ラベルなし)の手書き文字の認識に適用する場合、その正確率が低下する可能性がある。
この場合、本発明に係るドメイン適応方法を用いることで、ラベル付けされたソースデータ及びラベル付けされていないターゲットデータに基づいて訓練を行うことができるため、ターゲットデータを処理する際のモデルのパフォーマンスを向上させることができる。
[応用シナリオ3]時系列データの分類と予測
時系列データの予測は、例えば、大気汚染指数の予測、ICU患者の入院期間(LOS)の予測、株式市場の予測などを含む。微粒子状物質(PM2.5)指数の時系列データを一例として、ラベル付きの訓練サンプルセットを用いて予測モデルを訓練してもよい。訓練が完了した後、訓練済みのモデルを実際の予測に適用してもよい。例えば、現時点の直前の24時間のデータ(ラベルなしデータ)に基づいて、3日間後のPM2.5指数の範囲を予測してもよい。
このシナリオでは、本発明に係るドメイン適応方法を用いることで、ラベル付けされたデータ及びラベル付けされていないデータに基づいてモデルを訓練することができるため、モデルの予測正確度を向上させることができる。
[応用シナリオ4]テーブルタイプのデータの分類と予測
テーブルタイプのデータは、オンラインローンデータなどの財務データを含んでもよい。この例では、ローンの組み方の返済延滞の可能性があるか否かを予測するために、予測モデルを構築してもよく、本発明に係る方法を用いてモデルを訓練してもよい。
[応用シナリオ5]画像認識
画像認識又は画像分類の応用シナリオでは、セマンティックセグメンテーションのシナリオと同様に、実世界の画像データセットにラベルを付けるコストが高いという問題もある。従って、要件を満たすパフォーマンスを有するモデルを得るように、本発明に係るドメイン適応方法を用いて、ラベル付きデータセット(例えばImageNet)をソースデータセットとして選択し、該ソースデータセット及びラベルなしターゲットデータセットに基づいて訓練を行ってもよい。
以上は具体的な実施例を参照しながら本発明の実施形態を説明した。上記の実施例に係る方法は、ソフトウェア、ハードウェア、又はソフトウェアとハードウェアとの組み合わせにより実現されてもよい。ソフトウェアに含まれるプログラムは、装置の内部又は外部に設置された記憶媒体に予め記憶されてもよい。一例として、実行中に、これらのプログラムはランダムアクセスメモリ(RAM)に書き込まれ、プロセッサ(例えばCPU)により実行されることで、本明細書で説明された各処理を実現する。
図10は本発明を実現可能なコンピュータのハードウェアの例示的な構成を示すブロック図であり、該コンピュータのハードウェアは、本発明に係るドメイン適応型ニューラルネットワークの訓練装置の一例である。また、本発明に係るドメイン適応型ニューラルネットワークも、該コンピュータハードウェアに基づいて実現されてもよい。
図10に示すように、コンピュータ1000では、中央処理装置(CPU)1001、読み出し専用メモリ(ROM)1002及びランダムアクセスメモリ(RAM)1003がバス1004により相互に接続されている。
入力/出力インターフェース1005は、バス1004にさらに接続されている。入力/出力インターフェース1005には、キーボード、マウス、マイクロフォンなどにより構成された入力部1006、ディスプレイ、スピーカなどにより構成された出力部1007、ハードディスク、不揮発性メモリなどにより構成された記憶部1008、ネットワークインターフェースカード(ローカルエリアネットワーク(LAN)カード、モデムなど)により構成された通信部1009、及び移動可能な媒体1011をドライブするドライバ1010が接続されている。移動可能な媒体1011は、例えば磁気ディスク、光ディスク、光磁気ディスク又は半導体メモリである。
上記の構成を有するコンピュータにおいて、CPU1001は、記憶部1008に記憶されているプログラムを、入力/出力インターフェース1005及びバス1004を介してRAM1003にロードし、プログラムを実行することにより、上記の方法を実行する。
コンピュータ(CPU1001)により実行されるプログラムは、パッケージ媒体である移動可能な媒体1011に記録されてもよい。該パッケージ媒体は、例えば磁気ディスク(フロッピーディスクを含む)、光ディスク(コンパクトディスクリードオンリーメモリ(CD-ROM)、デジタルバーサタイルディスク(DVD)などを含む)、光磁気ディスク、又は半導体メモリにより形成される。また、コンピュータ(CPU1001)により実行されるプログラムは、ローカルエリアネットワーク、インターネット、デジタル衛星放送の有線又は無線の伝送媒体を介して提供されてもよい。
移動可能な媒体1011がドライバ1010にインストールされると、プログラムは、入力/出力インターフェース1005を介して記憶部1008にインストールすることができる。また、プログラムは、有線又は無線の伝送媒体を介して通信部1009で受信され、記憶部1008にインストールされる。或いは、プログラムは、ROM1002又は記憶部1008に予めインストールされてもよい。
コンピュータにより実行されるプログラムは、本明細書で説明する順序に従って処理を実行するプログラムであってもよいし、処理を並列的に実行し、或いは必要に応じて(例えば呼び出しの時に)処理を実行するプログラムであってもよい。
本明細書で説明されている装置又はユニットは論理的なものであり、物理的な装置又はエンティティに限定されない。例えば、本明細書で説明されている各ユニットの機能は複数の物理エンティティにより実現されてもよいし、本明細書で説明される複数のユニットの機能は単一の物理エンティティにより実現されてもよい。また、1つの実施例で説明される特徴、構成要素、要素、ステップなどは、該実施例に限定されず、例えば、他の実施例に適用されてもよく、例えば他の実施例の特定の特徴、構成要素、要素、ステップなどの代わりに用いてもよいし、それと組み合わせてもよい。
本発明の範囲は、本明細書に記載の具体的な実施例に限定されない。当業者により理解できるように、設計要求及び他の要因に応じて、本発明の原理及び要旨から逸脱することなく、本明細書の実施例に対して様々な修正又は変更を行ってもよい。本発明の範囲は、添付の特許請求の範囲及びその均等物により制限される。
また、上述の各実施例を含む実施形態に関し、更に以下の付記を開示するが、これらの付記に限定されない。
(付記1)
コンピュータにより実行される、ドメイン適応型ニューラルネットワークを訓練する方法であって、前記ドメイン適応型ニューラルネットワークは、第1の特徴抽出部、第1の分類部、及び識別部を含み、前記コンピュータは、命令が記憶されたメモリ、及びプロセッサを含み、前記命令は、前記プロセッサにより実行される際に、前記プロセッサに前記方法を実行させ、前記方法は、
前記第1の特徴抽出部が、ラベル付けされたソースデータセットにおけるソースデータについて第1の特徴を抽出するステップと、
前記第1の分類部が、前記第1の特徴に基づいて、前記ソースデータが複数のクラスのうちの各クラスに属する確率を予測するステップと、
前記第1の特徴抽出部が、ラベル付けされていないターゲットデータセットにおけるターゲットデータについて第2の特徴を抽出するステップと、
前記第1の分類部が、前記第2の特徴に基づいて、前記ターゲットデータが前記各クラスに属する確率を予測し、最大の確率に対応するクラスを前記ターゲットデータの第1のラベルとして決定するステップと、
ソースデータセットの前記各クラスについてのクラス中心と前記ターゲットデータの特徴との間の距離を計算し、距離が最も近いクラス中心に対応するクラスを前記ターゲットデータの第2のラベルとして決定するステップと、
前記ターゲットデータセットから、決定された前記第1のラベルと前記第2のラベルとが同一であるターゲットデータを選択するステップであって、前記第1のラベル又は前記第2のラベルは、選択されたターゲットデータの疑似ラベルとされる、ステップと、
選択されたターゲットデータに基づいて、前記ターゲットデータセットの前記各クラスについてのクラス中心を計算するステップと、
前記ソースデータセットのクラス中心と計算されたターゲットデータセットのクラス中心との間の距離に基づいて、第1の損失関数を構築するステップと、
選択されたターゲットデータ及びその疑似ラベルに基づいて、第2の損失関数を構築するステップと、
前記ソースデータセットにおけるソースデータ及び選択されたターゲットデータについて、第3の損失関数を構築するステップと、
前記第1の損失関数、前記第2の損失関数及び前記第3の損失関数に基づいて前記ドメイン適応型ニューラルネットワークを訓練するステップと、を含む、方法。
(付記2)
前記識別部が、前記第1の特徴及び前記第2の特徴に基づいて、現在入力されたデータがソースデータである確率を決定するステップと、
前記識別部により決定された確率に基づいて第4の損失関数を構築するステップと、
前記第4の損失関数に基づいて前記ドメイン適応型ニューラルネットワークを訓練するステップと、を含む、付記1に記載の方法。
(付記3)
前記識別部により決定された確率の逆数、及び1から前記識別部により決定された確率を減算した差のうちの1つに基づいて、前記第4の損失関数を構築する、付記2に記載の方法。
(付記4)
前記ドメイン適応型ニューラルネットワークは、第2の特徴抽出部及び第2の分類部をさらに含み、
前記方法は、
前記第2の特徴抽出部が前記ソースデータについて第3の特徴を抽出するステップと、
前記第2の分類部が、前記第3の特徴に基づいて、前記ソースデータが前記各クラスに属する確率を予測するステップと、
前記第2の特徴抽出部が前記ターゲットデータについて第4の特徴を抽出するステップと、
前記第2の分類部が、前記第4の特徴に基づいて、前記ターゲットデータが前記各クラスに属する確率を予測するステップと、
前記第1の分類部により予測された確率及び前記第2の分類部により予測された確率に基づいて、第5の損失関数を構築するステップと、
前記第5の損失関数に基づいて前記ドメイン適応型ニューラルネットワークを訓練するステップと、を含む、付記2に記載の方法。
(付記5)
前記ソースデータについて前記第1の分類部及び前記第2の分類部のそれぞれにより予測された確率の平均二乗誤差、並びに前記ターゲットデータについて前記第1の分類部及び前記第2の分類部のそれぞれにより予測された確率の平均二乗誤差に基づいて、前記第5の損失関数を構築する、付記4に記載の方法。
(付記6)
前記第2の特徴抽出部のパラメータは、前記第1の特徴抽出部のパラメータの指数移動平均であり、前記第2の分類部のパラメータは、前記第1の分類部のパラメータの指数移動平均であり、
前記指数移動平均で使用される減衰率を取得する際に、微分可能な変数を前記減衰率として使用し、或いは、全結合層を使用して前記減衰率を生成し、
前記全結合層は、前記第2の分類部の出力層と並列に前記出力層の直前の層に接続されるように設定されている、付記4に記載の方法。
(付記7)
前記第1の損失関数、前記第2の損失関数、前記第3の損失関数、前記第4の損失関数及び前記第5の損失関数の重み付けされた組み合わせに基づいて、前記ドメイン適応型ニューラルネットワークを訓練し、
訓練の実行に伴って、前記第4の損失関数及び前記第5の損失関数の重みを徐々に増加させる、付記4に記載の方法。
(付記8)
前記第2の損失関数は、前記第1の分類部を訓練するための交差エントロピー損失関数である、付記1に記載の方法。
(付記9)
前記識別部は、勾配反転部を介して前記第1の特徴抽出部に接続され、
前記識別部と前記第1の特徴抽出部とは、互いに敵対的に動作する、付記1に記載の方法。
(付記10)
前記ドメイン適応型ニューラルネットワークは、画像認識を実行するために使用され、前記ソースデータ及び前記ターゲットデータは、画像データであり、或いは、
前記ドメイン適応型ニューラルネットワークは、財務データを処理するために使用され、前記ソースデータ及び前記ターゲットデータは、テーブルタイプのデータであり、或いは、
前記ドメイン適応型ニューラルネットワークは、環境気象データ又は医療データを処理するために使用され、前記ソースデータ及び前記ターゲットデータは、時系列データである、付記1に記載の方法。
(付記11)
ドメイン適応型ニューラルネットワークを訓練する装置であって、
前記ドメイン適応型ニューラルネットワークは、
ラベル付けされたソースデータセットにおけるソースデータについて第1の特徴を抽出し、ラベル付けされていないターゲットデータセットにおけるターゲットデータについて第2の特徴を抽出する第1の特徴抽出部と、
前記第1の特徴に基づいて、前記ソースデータが複数のクラスのうちの各クラスに属する確率を予測し、前記第2の特徴に基づいて、前記ターゲットデータが前記各クラスに属する確率を予測し、最大の確率に対応するクラスを前記ターゲットデータの第1のラベルとして決定する第1の分類部と、
前記第1の特徴及び前記第2の特徴に基づいて、現在入力されたデータがソースデータである確率を決定する識別部と、を含み、
前記装置は、
プログラムが記憶されたメモリと、
1つ又は複数のプロセッサと、を含み、
前記プロセッサは、前記プログラムを実行することで、
ソースデータセットの前記各クラスについてのクラス中心と前記ターゲットデータの特徴との間の距離を計算し、距離が最も近いクラス中心に対応するクラスを前記ターゲットデータの第2のラベルとして決定するステップと、
前記ターゲットデータセットから、決定された前記第1のラベルと前記第2のラベルとが同一であるターゲットデータを選択するステップであって、前記第1のラベル又は前記第2のラベルは、選択されたターゲットデータの疑似ラベルとされる、ステップと、
選択されたターゲットデータに基づいて、前記ターゲットデータセットの前記各クラスについてのクラス中心を計算するステップと、
前記ソースデータセットのクラス中心と計算されたターゲットデータセットのクラス中心との間の距離に基づいて、第1の損失関数を構築するステップと、
選択されたターゲットデータ及びその疑似ラベルに基づいて、第2の損失関数を構築するステップと、
前記ソースデータセットにおけるソースデータ及び選択されたターゲットデータについて、第3の損失関数を構築するステップと、
前記第1の損失関数、前記第2の損失関数及び前記第3の損失関数に基づいて前記ドメイン適応型ニューラルネットワークを訓練するステップと、を実行する、装置。
(付記12)
ドメイン適応型ニューラルネットワークを訓練する装置であって、
前記ドメイン適応型ニューラルネットワークは、
ラベル付けされたソースデータセットにおけるソースデータについて第1の特徴を抽出し、ラベル付けされていないターゲットデータセットにおけるターゲットデータについて第2の特徴を抽出する第1の特徴抽出部と、
前記第1の特徴に基づいて、前記ソースデータが複数のクラスのうちの各クラスに属する確率を予測し、前記第2の特徴に基づいて、前記ターゲットデータが前記各クラスに属する確率を予測し、最大の確率に対応するクラスを前記ターゲットデータの第1のラベルとして決定する第1の分類部と、
前記第1の特徴及び前記第2の特徴に基づいて、現在入力されたデータがソースデータである確率を決定する識別部と、を含み、
前記装置は、
ソースデータセットの前記各クラスについてのクラス中心と前記ターゲットデータの特徴との間の距離を計算し、距離が最も近いクラス中心に対応するクラスを前記ターゲットデータの第2のラベルとして決定し、前記ターゲットデータセットから、決定された前記第1のラベルと前記第2のラベルとが同一であるターゲットデータを選択し、最適ターゲットデータセットを形成する最適ターゲットデータセット生成部であって、前記第1のラベル又は前記第2のラベルは、選択されたターゲットデータの疑似ラベルとされる、最適ターゲットデータセット生成部と、
前記最適ターゲットデータセットにおけるターゲットデータに基づいて、前記ターゲットデータセットの前記各クラスについてのクラス中心を計算し、前記ソースデータセットのクラス中心と計算されたターゲットデータセットのクラス中心との間の距離に基づいて、第1の損失関数を構築する第1の損失関数生成部と、
前記最適ターゲットデータセットにおけるターゲットデータ及びその疑似ラベルに基づいて、第2の損失関数を構築する第2の損失関数生成部と、
前記ソースデータセットにおけるソースデータ及び前記最適ターゲットデータセットにおけるターゲットデータについて、第3の損失関数を構築する第3の損失関数生成部と、
前記第1の損失関数、前記第2の損失関数及び前記第3の損失関数に基づいて前記ドメイン適応型ニューラルネットワークを訓練する訓練部と、を実行する、装置。
(付記13)
ドメイン適応型ニューラルネットワークを訓練するプログラムが記憶された記憶媒体であって、前記ドメイン適応型ニューラルネットワークは、第1の特徴抽出部、第1の分類部、及び識別部を含み、前記プログラムがコンピュータにより実行される際に、前記コンピュータに、
前記第1の特徴抽出部が、ラベル付けされたソースデータセットにおけるソースデータについて第1の特徴を抽出するステップと、
前記第1の分類部が、前記第1の特徴に基づいて、前記ソースデータが複数のクラスのうちの各クラスに属する確率を予測するステップと、
前記第1の特徴抽出部が、ラベル付けされていないターゲットデータセットにおけるターゲットデータについて第2の特徴を抽出するステップと、
前記第1の分類部が、前記第2の特徴に基づいて、前記ターゲットデータが前記各クラスに属する確率を予測し、最大の確率に対応するクラスを前記ターゲットデータの第1のラベルとして決定するステップと、
ソースデータセットの前記各クラスについてのクラス中心と前記ターゲットデータの特徴との間の距離を計算し、距離が最も近いクラス中心に対応するクラスを前記ターゲットデータの第2のラベルとして決定するステップと、
前記ターゲットデータセットから、決定された前記第1のラベルと前記第2のラベルとが同一であるターゲットデータを選択するステップであって、前記第1のラベル又は前記第2のラベルは、選択されたターゲットデータの疑似ラベルとされる、ステップと、
選択されたターゲットデータに基づいて、前記ターゲットデータセットの前記各クラスについてのクラス中心を計算するステップと、
前記ソースデータセットのクラス中心と計算されたターゲットデータセットのクラス中心との間の距離に基づいて、第1の損失関数を構築するステップと、
選択されたターゲットデータ及びその疑似ラベルに基づいて、第2の損失関数を構築するステップと、
前記ソースデータセットにおけるソースデータ及び選択されたターゲットデータについて、第3の損失関数を構築するステップと、
前記第1の損失関数、前記第2の損失関数及び前記第3の損失関数に基づいて前記ドメイン適応型ニューラルネットワークを訓練するステップと、を含む、方法を実行させる、記憶媒体。

Claims (10)

  1. コンピュータにより実行される、ドメイン適応型ニューラルネットワークを訓練する方法であって、前記ドメイン適応型ニューラルネットワークは、第1の特徴抽出部、第1の分類部、及び識別部を含み、前記コンピュータは、命令が記憶されたメモリ、及びプロセッサを含み、前記命令は、前記プロセッサにより実行される際に、前記プロセッサに前記方法を実行させ、前記方法は、
    前記第1の特徴抽出部が、ラベル付けされたソースデータセットにおけるソースデータについて第1の特徴を抽出するステップと、
    前記第1の分類部が、前記第1の特徴に基づいて、前記ソースデータが複数のクラスのうちの各クラスに属する確率を予測するステップと、
    前記第1の特徴抽出部が、ラベル付けされていないターゲットデータセットにおけるターゲットデータについて第2の特徴を抽出するステップと、
    前記第1の分類部が、前記第2の特徴に基づいて、前記ターゲットデータが前記各クラスに属する確率を予測し、最大の確率に対応するクラスを前記ターゲットデータの第1のラベルとして決定するステップと、
    ソースデータセットの前記各クラスについてのクラス中心と前記ターゲットデータの特徴との間の距離を計算し、距離が最も近いクラス中心に対応するクラスを前記ターゲットデータの第2のラベルとして決定するステップと、
    前記ターゲットデータセットから、決定された前記第1のラベルと前記第2のラベルとが同一であるターゲットデータを選択するステップであって、前記第1のラベル又は前記第2のラベルは、選択されたターゲットデータの疑似ラベルとされる、ステップと、
    選択されたターゲットデータに基づいて、前記ターゲットデータセットの前記各クラスについてのクラス中心を計算するステップと、
    前記ソースデータセットのクラス中心と計算されたターゲットデータセットのクラス中心との間の距離に基づいて、第1の損失関数を構築するステップと、
    選択されたターゲットデータ及びその疑似ラベルに基づいて、第2の損失関数を構築するステップと、
    前記ソースデータセットにおけるソースデータ及び選択されたターゲットデータについて、第3の損失関数を構築するステップと、
    前記第1の損失関数、前記第2の損失関数及び前記第3の損失関数に基づいて前記ドメイン適応型ニューラルネットワークを訓練するステップと、を含む、方法。
  2. 前記識別部が、前記第1の特徴及び前記第2の特徴に基づいて、現在入力されたデータがソースデータである確率を決定するステップと、
    前記識別部により決定された確率に基づいて第4の損失関数を構築するステップと、
    前記第4の損失関数に基づいて前記ドメイン適応型ニューラルネットワークを訓練するステップと、を含む、請求項1に記載の方法。
  3. 前記識別部により決定された確率の逆数、及び1から前記識別部により決定された確率を減算した差のうちの1つに基づいて、前記第4の損失関数を構築する、請求項2に記載の方法。
  4. 前記ドメイン適応型ニューラルネットワークは、第2の特徴抽出部及び第2の分類部をさらに含み、
    前記方法は、
    前記第2の特徴抽出部が前記ソースデータについて第3の特徴を抽出するステップと、
    前記第2の分類部が、前記第3の特徴に基づいて、前記ソースデータが前記各クラスに属する確率を予測するステップと、
    前記第2の特徴抽出部が前記ターゲットデータについて第4の特徴を抽出するステップと、
    前記第2の分類部が、前記第4の特徴に基づいて、前記ターゲットデータが前記各クラスに属する確率を予測するステップと、
    前記第1の分類部により予測された確率及び前記第2の分類部により予測された確率に基づいて、第5の損失関数を構築するステップと、
    前記第5の損失関数に基づいて前記ドメイン適応型ニューラルネットワークを訓練するステップと、を含む、請求項2に記載の方法。
  5. 前記ソースデータについて前記第1の分類部及び前記第2の分類部のそれぞれにより予測された確率の平均二乗誤差、並びに前記ターゲットデータについて前記第1の分類部及び前記第2の分類部のそれぞれにより予測された確率の平均二乗誤差に基づいて、前記第5の損失関数を構築する、請求項4に記載の方法。
  6. 前記第2の特徴抽出部のパラメータは、前記第1の特徴抽出部のパラメータの指数移動平均であり、前記第2の分類部のパラメータは、前記第1の分類部のパラメータの指数移動平均であり、
    前記指数移動平均で使用される減衰率を取得する際に、微分可能な変数を前記減衰率として使用し、或いは、全結合層を使用して前記減衰率を生成し、
    前記全結合層は、前記第2の分類部の出力層と並列に前記出力層の直前の層に接続されるように設定されている、請求項4に記載の方法。
  7. 前記第1の損失関数、前記第2の損失関数、前記第3の損失関数、前記第4の損失関数及び前記第5の損失関数の重み付けされた組み合わせに基づいて、前記ドメイン適応型ニューラルネットワークを訓練し、
    訓練の実行に伴って、前記第4の損失関数及び前記第5の損失関数の重みを徐々に増加させる、請求項4に記載の方法。
  8. 前記第2の損失関数は、前記第1の分類部を訓練するための交差エントロピー損失関数である、請求項1に記載の方法。
  9. ドメイン適応型ニューラルネットワークを訓練する装置であって、
    前記ドメイン適応型ニューラルネットワークは、
    ラベル付けされたソースデータセットにおけるソースデータについて第1の特徴を抽出し、ラベル付けされていないターゲットデータセットにおけるターゲットデータについて第2の特徴を抽出する第1の特徴抽出部と、
    前記第1の特徴に基づいて、前記ソースデータが複数のクラスのうちの各クラスに属する確率を予測し、前記第2の特徴に基づいて、前記ターゲットデータが前記各クラスに属する確率を予測し、最大の確率に対応するクラスを前記ターゲットデータの第1のラベルとして決定する第1の分類部と、
    前記第1の特徴及び前記第2の特徴に基づいて、現在入力されたデータがソースデータである確率を決定する識別部と、を含み、
    前記装置は、
    プログラムが記憶されたメモリと、
    1つ又は複数のプロセッサと、を含み、
    前記プロセッサは、前記プログラムを実行することで、
    ソースデータセットの前記各クラスについてのクラス中心と前記ターゲットデータの特徴との間の距離を計算し、距離が最も近いクラス中心に対応するクラスを前記ターゲットデータの第2のラベルとして決定するステップと、
    前記ターゲットデータセットから、決定された前記第1のラベルと前記第2のラベルとが同一であるターゲットデータを選択するステップであって、前記第1のラベル又は前記第2のラベルは、選択されたターゲットデータの疑似ラベルとされる、ステップと、
    選択されたターゲットデータに基づいて、前記ターゲットデータセットの前記各クラスについてのクラス中心を計算するステップと、
    前記ソースデータセットのクラス中心と計算されたターゲットデータセットのクラス中心との間の距離に基づいて、第1の損失関数を構築するステップと、
    選択されたターゲットデータ及びその疑似ラベルに基づいて、第2の損失関数を構築するステップと、
    前記ソースデータセットにおけるソースデータ及び選択されたターゲットデータについて、第3の損失関数を構築するステップと、
    前記第1の損失関数、前記第2の損失関数及び前記第3の損失関数に基づいて前記ドメイン適応型ニューラルネットワークを訓練するステップと、を実行する、装置。
  10. ドメイン適応型ニューラルネットワークを訓練するプログラムが記憶された記憶媒体であって、前記ドメイン適応型ニューラルネットワークは、第1の特徴抽出部、第1の分類部、及び識別部を含み、前記プログラムがコンピュータにより実行される際に、前記コンピュータに請求項1乃至8の何れかに記載の方法を実行させる、記憶媒体。
JP2021136658A 2020-09-02 2021-08-24 ドメイン適応型ニューラルネットワークの訓練方法 Pending JP2022042487A (ja)

Applications Claiming Priority (2)

Application Number Priority Date Filing Date Title
CN202010911149.0 2020-09-02
CN202010911149.0A CN114139676A (zh) 2020-09-02 2020-09-02 领域自适应神经网络的训练方法

Publications (1)

Publication Number Publication Date
JP2022042487A true JP2022042487A (ja) 2022-03-14

Family

ID=80438142

Family Applications (1)

Application Number Title Priority Date Filing Date
JP2021136658A Pending JP2022042487A (ja) 2020-09-02 2021-08-24 ドメイン適応型ニューラルネットワークの訓練方法

Country Status (2)

Country Link
JP (1) JP2022042487A (ja)
CN (1) CN114139676A (ja)

Cited By (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115219180A (zh) * 2022-07-18 2022-10-21 西安交通大学 基于原型和双重域对抗的旋转机械跨工况故障诊断方法
CN116070796A (zh) * 2023-03-29 2023-05-05 中国科学技术大学 柴油车排放等级评估方法及系统
CN116452897A (zh) * 2023-06-16 2023-07-18 中国科学技术大学 跨域小样本分类方法、系统、设备及存储介质
CN117017288A (zh) * 2023-06-14 2023-11-10 西南交通大学 跨被试情绪识别模型及其训练方法、情绪识别方法、设备
WO2024124455A1 (zh) * 2022-12-14 2024-06-20 深圳市华大智造软件技术有限公司 碱基分类模型的训练方法及系统、碱基分类方法及系统

Families Citing this family (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN114399640B (zh) * 2022-03-24 2022-07-15 之江实验室 一种不确定区域发现与模型改进的道路分割方法及装置
CN114445670B (zh) * 2022-04-11 2022-07-12 腾讯科技(深圳)有限公司 图像处理模型的训练方法、装置、设备及存储介质

Cited By (8)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN115219180A (zh) * 2022-07-18 2022-10-21 西安交通大学 基于原型和双重域对抗的旋转机械跨工况故障诊断方法
CN115219180B (zh) * 2022-07-18 2024-05-31 西安交通大学 基于原型和双重域对抗的旋转机械跨工况故障诊断方法
WO2024124455A1 (zh) * 2022-12-14 2024-06-20 深圳市华大智造软件技术有限公司 碱基分类模型的训练方法及系统、碱基分类方法及系统
CN116070796A (zh) * 2023-03-29 2023-05-05 中国科学技术大学 柴油车排放等级评估方法及系统
CN117017288A (zh) * 2023-06-14 2023-11-10 西南交通大学 跨被试情绪识别模型及其训练方法、情绪识别方法、设备
CN117017288B (zh) * 2023-06-14 2024-03-19 西南交通大学 跨被试情绪识别模型及其训练方法、情绪识别方法、设备
CN116452897A (zh) * 2023-06-16 2023-07-18 中国科学技术大学 跨域小样本分类方法、系统、设备及存储介质
CN116452897B (zh) * 2023-06-16 2023-10-20 中国科学技术大学 跨域小样本分类方法、系统、设备及存储介质

Also Published As

Publication number Publication date
CN114139676A (zh) 2022-03-04

Similar Documents

Publication Publication Date Title
JP2022042487A (ja) ドメイン適応型ニューラルネットワークの訓練方法
CN110472675B (zh) 图像分类方法、图像分类装置、存储介质与电子设备
CN108694443B (zh) 基于神经网络的语言模型训练方法和装置
CN113326731B (zh) 一种基于动量网络指导的跨域行人重识别方法
CN114841257B (zh) 一种基于自监督对比约束下的小样本目标检测方法
CN113392967A (zh) 领域对抗神经网络的训练方法
US20190370219A1 (en) Method and Device for Improved Classification
CN110853630B (zh) 面向边缘计算的轻量级语音识别方法
CN114398855A (zh) 基于融合预训练的文本抽取方法、系统及介质
CN113255366B (zh) 一种基于异构图神经网络的方面级文本情感分析方法
CN114022737A (zh) 对训练数据集进行更新的方法和设备
CN112232395B (zh) 一种基于联合训练生成对抗网络的半监督图像分类方法
CN112527959A (zh) 基于无池化卷积嵌入和注意分布神经网络的新闻分类方法
CN115408525A (zh) 基于多层级标签的信访文本分类方法、装置、设备及介质
CN117648950A (zh) 神经网络模型的训练方法、装置、电子设备及存储介质
CN115063664A (zh) 用于工业视觉检测的模型学习方法、训练方法及系统
Deng et al. Heterogeneous tri-stream clustering network
CN114328917A (zh) 用于确定文本数据的标签的方法和设备
CN116958548B (zh) 基于类别统计驱动的伪标签自蒸馏语义分割方法
CN117541853A (zh) 一种基于类别解耦的分类知识蒸馏模型训练方法和装置
CN116645980A (zh) 一种聚焦样本特征间距的全生命周期语音情感识别方法
CN112951270B (zh) 语音流利度检测的方法、装置和电子设备
CN114495114A (zh) 基于ctc解码器的文本序列识别模型校准方法
CN111259860B (zh) 基于数据自驱动的多阶特征动态融合手语翻译方法
US20230025148A1 (en) Model optimization method, electronic device, and computer program product

Legal Events

Date Code Title Description
A621 Written request for application examination

Free format text: JAPANESE INTERMEDIATE CODE: A621

Effective date: 20240509