JP7616396B2 - 訓練装置、訓練方法、及びプログラム - Google Patents
訓練装置、訓練方法、及びプログラム Download PDFInfo
- Publication number
- JP7616396B2 JP7616396B2 JP2023541364A JP2023541364A JP7616396B2 JP 7616396 B2 JP7616396 B2 JP 7616396B2 JP 2023541364 A JP2023541364 A JP 2023541364A JP 2023541364 A JP2023541364 A JP 2023541364A JP 7616396 B2 JP7616396 B2 JP 7616396B2
- Authority
- JP
- Japan
- Prior art keywords
- class
- domain
- target domain
- feature
- values
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Active
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING OR CALCULATING; COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/084—Backpropagation, e.g. using gradient descent
-
- G—PHYSICS
- G06—COMPUTING OR CALCULATING; COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/09—Supervised learning
-
- G—PHYSICS
- G06—COMPUTING OR CALCULATING; COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/096—Transfer learning
-
- G—PHYSICS
- G06—COMPUTING OR CALCULATING; COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/044—Recurrent networks, e.g. Hopfield networks
-
- G—PHYSICS
- G06—COMPUTING OR CALCULATING; COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/0464—Convolutional networks [CNN, ConvNet]
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Physics & Mathematics (AREA)
- Data Mining & Analysis (AREA)
- General Health & Medical Sciences (AREA)
- Biomedical Technology (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Life Sciences & Earth Sciences (AREA)
- Evolutionary Computation (AREA)
- Artificial Intelligence (AREA)
- Molecular Biology (AREA)
- Computing Systems (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Mathematical Physics (AREA)
- Software Systems (AREA)
- Health & Medical Sciences (AREA)
- Information Retrieval, Db Structures And Fs Structures Therefor (AREA)
Description
以下、本発明に係る第1の例示的実施形態について図面を用いて説明する。第1の例示的実施形態は、後続の例示的実施形態の基礎となる実施形態である。
本例示的実施形態に係る訓練装置(学習装置と呼ぶこともある)は、訓練装置に入力されたデータから特徴量を抽出する特徴抽出部を訓練する(学習する、学習させるとも言う)。また、訓練装置は、特徴値(特徴量とも呼ぶ)に基づいて分類を行うクラス予測部を訓練する。
次に、第1の例示的実施形態に係る訓練装置10の構成について、図1を用いて説明する。図1は、訓練装置10の構成を示すブロック図である。図1に示すように、訓練装置10は、特徴抽出部11と、クラス予測部12と、フィルタ部(フィルタリング部とも呼ぶ)13と、更新部14とを備える。第1の例示的実施形態において、特徴抽出部11、クラス予測部12、及びフィルタ部13の個数は1個でもよいし、2個以上でもよい。
第1の例示的実施形態によれば、上述したように、更新部14は、ソースドメイン分類損失及びターゲットドメイン分類損失に加えて、グループ損失も参照して、特徴抽出部11及びクラス予測部12の少なくとも一方を更新する。したがって、第1の例示的実施形態によれば、ターゲットドメインラベル付きデータが少量であっても、特徴抽出部11及びクラス予測部12を適切に訓練することができる。
次に、第1の例示的実施形態に係る訓練装置10の訓練方法について、図2を用いて説明する。図2は、訓練装置10による訓練方法S1の流れを示すフロー図である。図2に示すように、訓練装置10は、特徴抽出ステップS11、クラス予測ステップS12、フィルタステップS13、及び更新ステップS14を実行する。
特徴抽出ステップS11において、特徴抽出部11はソースドメインデータからソースドメイン特徴値を抽出し、ターゲットドメインデータからターゲットドメイン特徴値を抽出する。
クラス予測ステップS12において、クラス予測部12はソース領域特徴値からソースドメインクラス予測値を予測し、ターゲットドメイン特徴値からターゲットドメインクラス予測値を予測する。
フィルタステップS13において、フィルタ部13は、
ソースドメインクラス予測値を参照してソースドメイン特徴値から1つ以上の値をフィルタリング(フィルタ除去)し、
ターゲットドメインクラス予測値を参照してターゲットドメイン特徴値から1つ以上の値をフィルタリング(フィルタ除去)することによって、
フィルタリングされたソースドメイン特徴値及びフィルタリングされたターゲットドメイン特徴値を出力する。
更新ステップS14において、更新部は、
ソースドメインクラス予測値を参照して算出されたソースドメイン分類損失、
ターゲットドメインクラス予測値を参照して算出されたターゲットドメイン分類損失、及び、
フィルタリングされたソースドメイン特徴値及びフィルタリングされたターゲットドメイン特徴値を参照して算出されたグループ損失
を参照して、特徴抽出部11及びクラス予測部12の少なくとも一方を更新する。
以上説明した訓練装置10による訓練方法S1は、少量のターゲットドメインラベル付きデータしか利用できない場合であっても、効率的で安定した訓練プロセスを提供する。
以下、第1の例示的実施形態に係る分類装置20の構成について、図3を用いて実施例する。図3は、分類装置20の構成を示すブロック図である。図3に示すように、分類装置20は、特徴抽出部11と、クラス予測部12とを備える。
ソースドメイン特徴値から1つ以上の値をフィルタリングして得られるフィルタリングされたソースドメイン特徴値と、
ターゲットドメイン特徴値から1つ以上の値をフィルタリングして得られるフィルタリングされたターゲットドメイン特徴値とを参照して訓練されている。
次に、第1の例示的実施形態に係る分類装置20の分類方法について、図4を用いて説明する。図4は、分類装置20が行う分類方法S2を示すフロー図である。図4に示すように、分類装置20は、特徴抽出ステップS11と、クラス予測ステップS12とを実行する。
特徴抽出ステップS11において、特徴抽出部11は、ターゲットドメインデータからターゲットドメイン特徴値を抽出する。訓練装置10について説明したように、特徴抽出部11及びクラス予測部12の少なくとも一方は、
ソースドメイン特徴値から1つ以上の値をフィルタリングして得られるフィルタリングされたソースドメイン特徴値と、
ターゲットドメイン特徴値から1つ以上の値をフィルタリングして得られるフィルタリングされたターゲットドメイン特徴値とを参照して訓練されている。
クラス予測ステップS12において、クラス予測部12は、ターゲットドメイン特徴値からターゲットドメインクラス予測値を予測する。
以下、第2の例示的実施形態について図面を用いて説明する。なお、第1の例示的実施形態で説明した要素と同一の機能を有する要素には同一の符号を付し、その説明は適宜省略する。また、第2の例示的実施形態の概要は、第1の例示的実施形態の概要と同じであるため、ここでは説明しない。
次に、例示的実施形態に係る訓練装置10aの構成について、図5を用いて説明する。図5は、訓練装置10aの構成を示すブロック図である。図5に示すように、訓練装置10aは、第1の特徴抽出部11a、第2の特徴抽出部11b、第1のクラス予測部12a、第2のクラス予測部12b、第1のフィルタ部13a、第2のフィルタ部13b、及び更新部14aを備える。
第1の特徴抽出部11aには、ソースドメインに属する入力データISが入力される。入力データISの具体例は第2の例示的実施形態を限定しないが、入力データISは1つまたは複数の入力画像であり得る。より具体的には一例として、入力データISは複数の領域を有する画像であってもよい。別の例として、入力画像データISは、図6の左側に示されているような一群の画像(batch of images)であってもよい。図6の左側の例ではインプットデータISが10枚の画像の群(batch)を含み、その各々は数字または複数の数字を表している。
同様に、第2の特徴抽出部11bには、ターゲットドメインに属する入力データITが入力される。入力データITの具体例は第2の例示的実施形態を限定しないが、入力データITは1つまたは複数の入力画像であり得る。より具体的には一例として、入力データITは複数の領域を有する画像であってもよい。別の例として、入力画像データITは、図6の右側に示されているような一群の画像であってもよい。図6の右側の例において、入力データITは10個の画像を含み、それぞれが数字または複数の数字を表している。
第1のクラス予測部12aは第1の特徴抽出部11aにより抽出されたソースドメイン特徴量XSから、ソースドメインクラス予測値PSを予測する。
第2のクラス予測部12bは第2の特徴抽出部11bにより抽出されたターゲットドメイン特徴値XTから、ターゲットドメインクラス予測値PTを予測する。
第1のフィルタ部13aは、ソースドメインクラス予測値PSを参照して、ソースドメイン特徴値XSから1つ以上の値をフィルタ除去することにより、フィルタリングされたソースドメイン特徴値X’Sを算出する。
(第2のフィルタ部)
第2のフィルタ部13bは、ターゲットドメインクラス予測値PTを参照して、ターゲットドメイン特徴値XTから1つ以上の値をフィルタ除去することにより、フィルタリングされたターゲットドメイン特徴値X’Tを算出する。
グルーピング部141はフィルタリングされたのソースドメイン特徴値X’Sと、フィルタリングされたのターゲットドメイン特徴値X’Tとから、クラスグループを生成する。ここで各クラスグループは同じクラスラベルを共有する特徴量を含んでいる。
ここで、Gr0は、特徴値が同じクラスラベル0を共有するクラスグループである。Gr1は、特徴値が同じクラスラベル1を共有するクラスグループである。同様に、Gr2、Gr3、Gr4はそれぞれ、特徴値が同じクラスラベル2、3、4を共有するクラスグループである。
グループ損失計算部142は、グルーピング部141が生成したクラスグループを参照して、グループ損失(Loss_grouping)を算出する。
ここで、全てのソースドメイン特徴値と全てのターゲットドメイン特徴値との和集合における各特徴値xについて、その特徴値xの「特徴空間におけるクラス内距離の最大値(maximum of intra-class distance in the feature space)」を、特徴値xと同じクラスグループに由来する他の任意の特徴値との間の最大距離として算出し、「特徴空間におけるクラス間距離の最小値(minimum of inter-class distance in the feature space)」を、特徴値xとは異なるクラスグループに由来する他の任意の特徴値との間の最小距離として算出する。マージン(margin)は、特徴値の最小クラス間距離から特徴値の最大クラス内距離を差し引いた値の許容最小値を示す。以下、(特徴空間におけるクラス間距離の最大値-特徴空間におけるクラス間距離の最小値+マージン)によって特徴値ごとに算出される値を「個々のグループ損失(individual grouping loss)」と呼ぶ。全体的なグループ損失は、各ソースドメイン特徴値および各ターゲットドメイン特徴値に対する個々のグループ損失の平均として計算される。平均は最初に、すべてのソースドメイン特徴値およびすべてのターゲットドメイン特徴値についての個々のグループ損失の合計を計算し、次いで、その合計を、ソースドメイン特徴値の数とターゲットドメイン特徴値の数との和で除算することによって計算される。
第1クラス分類損失計算部143aは、(i)ソースドメインクラス予測値PS、(ii)ソースドメインクラスラベルデータYSを参照して、ソースドメイン分類損失(Loss_classification_S)を算出する。
(第2クラス分類損失計算部)
第2クラス分類損失計算部143bは、(i)ターゲットドメインクラス予測値PT、(ii)ターゲットドメインクラスラベルデータYTを参照して、ターゲットドメイン分類損失ロス(Loss_classification_T)を算出する。
(マージ損失計算部)
マージ損失計算部144は、ソースドメイン分類損失(Loss_classification_S)、ターゲットドメイン分類損失(Loss_classification_T)及びグループ損失(Loss_grouping)を参照して、マージ損失(Loss_merge)を算出する。
モデル更新部145は、マージ損失が収束したか否かを判定する。マージ損失が収束した場合、モデル更新部145は、収束したモデルパラメータを記憶媒体に出力する。モデル更新部145は、マージ損失値が収束していない場合、マージ損失計算部144が算出したマージ損失を参照して、第1の特徴抽出部11a、第2の特徴抽出部11b、第1のクラス予測部12a、及び第2のクラス予測部12bのモデルパラメータを更新する。
第2の例示的実施形態によれば、上述したように、モデル更新部145は、ソースドメイン分類損失及びターゲットドメイン分類損失に加えて、グループ損失を参照してモデルパラメータを更新する。したがって、第2の例示的実施形態によれば、少量のターゲットドメインラベル付きデータしか利用できない場合であっても、第2の特徴抽出部11bおよび第2のクラス予測部12bを学習することができる。
以下、第2の例示的実施形態に係る訓練装置10aの訓練方法について、図8を用いて実施例する。図8は、訓練装置10aによる訓練方法S1aの流れを示すフロー図である。
ステップS100において、訓練装置10aは、初期モデルパラメータを受信する。初期モデルパラメータには、第1の特徴抽出部11a、第2の特徴抽出部11b、第1のクラス予測部12a、第2のクラス予測部12bの初期モデルパラメータが含まれる。このステップで受信された初期モデルパラメータは、第1の特徴抽出部11a、第2の特徴抽出部11b、第1のクラス予測部12a、及び第2のクラス予測部12bに供給される。
ステップS10aにおいて、訓練装置10aは、ソースドメインのデータを受信する。より具体的には、訓練装置10aは、ソースドメイン入力データISおよび入力データISに関連付けられたソースドメインクラスラベルデータYSを受信する。
ステップS10bにおいて、訓練装置10aは、ターゲットドメインのデータを受信する。より具体的には、訓練装置10aは、ターゲットドメイン入力データIT及び入力データITに関連付けられたターゲットドメインクラスラベルデータYTを受信する。
ステップS11aでは、第1の特徴抽出部11aがソースドメインデータISから特徴値XSを抽出する。なお、第1の特徴抽出部11aが行う具体的な処理については、上述したので、ここでは繰り返し説明しない。
ステップS11bでは、第2の特徴抽出部11bがターゲットドメインデータITから特徴値XTを抽出する。なお、第2の特徴抽出部11bが行う具体的な処理については、上述したので、ここでは繰り返し説明しない。
ステップS12aにおいて、第1のクラス予測部12aは、第1の特徴抽出部11aにより抽出されたソースドメイン特徴値XSからソースドメインクラス予測値PSを予測する。なお、第1のクラス予測部12aが行う具体的な処理については、上述したので、ここでは繰り返し説明しない。
ステップS12bにおいて、第2のクラス予測部12bは、第2の特徴抽出部11bが抽出したターゲットドメイン特徴値XTから、ターゲットドメインクラス予測値PTを予測する。なお、第2のクラス予測部12bが行う具体的な処理については、上述したので、ここでは繰り返し説明しない。
ステップS13aにおいて、第1のフィルタ部13aは、ソースドメインクラス予測値PSを参照して、ソースドメイン特徴値XSから1つ以上の値をフィルタ除去することにより、フィルタリングされたソースドメイン特徴値X’Sを算出する。なお、第1のフィルタ部13aによる具体的なプロセスについては、上述したので、ここでは繰り返し説明しない。
ステップS13bでは、第2のフィルタ部13bは、ターゲットドメインクラス予測値PTを参照して、ターゲットドメイン特徴値XTから1つ以上の値をフィルタ除去することにより、フィルタリングされたターゲットドメイン特徴値X’Tを算出する。なお、第2のフィルタ部13bによる具体的なプロセスについては、上述したので、ここでは繰り返し説明しない。
ステップS141では、グルーピング部141がフィルタリングされたソースドメイン特徴値X’Sとフィルタリングされたターゲットドメイン特徴値X’Tとから、クラスグループを生成して出力する。ここで、各クラスグループは、同じクラスラベルを共有する特徴量を含む。なお、グルーピング部141が行う具体的な処理については、上述したので、ここでは繰り返し説明しない。
ステップS142において、グループ損失計算部142は、グルーピング部141が生成したクラスグループを参照して、グループ損失(Loss_grouping)を算出する。なお、グループ損失計算部142が行う具体的な処理については、上述したので、ここでは繰り返し説明しない。
ステップS143aにおいて、第1クラス分類損失計算部143aは、ソースドメインクラス予測値PS及びソースドメインクラスラベルデータYSを用いて、ソースドメイン分類損失(Loss_classification_S)を算出する。なお、第1クラス分類損失計算部143aが行う具体的な処理については、上述したので、ここでは繰り返し説明しない。
ステップS143bにおいて、第2クラス分類損失計算部143bは、ターゲットドメインクラス予測値PTとターゲットドメインクラスラベルデータYTとを用いて、ターゲットドメイン分類損失(Loss_classification_T)を算出する。なお、第2クラス分類損失計算部143bが行う具体的な処理については、上述したので、ここでは繰り返し説明しない。
ステップS144において、マージ損失計算部144は、ソースドメイン分類損失(Loss_classification_S)、ターゲットドメイン分類損失(Loss_classification_T)及びグループ損失(Loss_grouping)を参照して、マージ損失(Loss_merge)を算出する。なお、マージ損失計算部144が行う具体的な処理については、上述したので、ここでは繰り返し説明しない。
ステップS145において、モデル更新部145は、マージ損失が収束したか否かを判定する。マージ損失が収束している場合(ステップS145でYES)、ステップS147に進む。マージ損失が収束していない場合(ステップS145でNO)、ステップS146に進む。
ステップS146において、モデル更新部145は、マージ損失計算部144が算出したマージ損失を参照して、第1の特徴抽出部11a、第2の特徴抽出部11b、第1のクラス予測部12a、及び第2のクラス予測部12bのモデルパラメータを更新する。
ステップS147において、モデル更新部145は、マージ損失計算部144が算出したマージ損失を参照して、第1の特徴抽出部11a、第2の特徴抽出部11b、第1のクラス予測部12a、及び第2のクラス予測部12bのモデルパラメータを記憶媒体に記憶する。
以上説明した訓練装置10aによる訓練方法S1aは、少量のターゲットドメインラベル付きデータしか利用できない場合であっても、効率的で安定した訓練プロセスを提供する。
次に、第2の例示的実施形態に係る分類装置20aの構成について、図9を用いて説明する。図9は、分類装置20aの構成を示すブロック図である。図9に示すように、分類装置20aは、特徴抽出部11bと、クラス予測部12bとを備える。ここで、特徴抽出部11bは、上述した第2の特徴抽出部11bと同様に構成され、クラス予測部12bは、上述した第2のクラス予測部12bと同様に構成される。
(備考1:特徴抽出部について)
一構成例において、第1の特徴抽出部11aと第2の特徴抽出部11bとは、完全に独立していてもよい。すなわち、第1の特徴抽出部11aと第2の特徴抽出部11bとは、モデルパラメータも層(レイヤ)も共有しない。
一構成例において、第1のクラス予測部12aと第2のクラス予測部12bとは、完全に独立していてもよい。すなわち、第1のクラス予測部12aおよび第2のクラス予測部12bは、モデルパラメータおよび層を共用しない。
グループ損失の計算は、クラスグループ内の2つの特徴間の距離または類似性を計算するための任意の方法によって達成され得る。グループ損失は、L1ノルム、L2ノルム、コサイン類似性、または学習などを必要とする何らかの他の尺度であり得る。
この再スケーリングは、以下の問題点に鑑みて行われる。すなわち、高品質の特徴値の場合であっても、損失値が計算される距離は依然として非常に大きくなり得、これは大きな損失値をもたらす。安全な範囲内でグループ損失を再スケーリングするには、単純なクリッピング(Loss_grouping > 1 の場合は1 を返し、それ以外の場合はLoss_grouping を返す)、または重みλ (λloss_grouping)を使用した単純な線形再ウェイト付けなど、さまざまな方法がある。
マージ損失は、ソースドメイン分類損失(Loss_classification_S)、ターゲットドメイン分類損失(Loss_classification_T)、およびグループ損失(Loss_grouping)などのすべてのサブタスク損失の直接和であり得るか、またはサブタスク損失の加重和であり得る。
特徴値をフィルタ除去するかどうかを決定するルールは変化し得る。第1の例として、当該ルールは、第1のクラス予測部12aまたは第2のクラス予測部12bによって与えられる予測の正確さに依存し得る。
以下、本発明の第3の例示的実施形態について、図面を用いて詳しく説明する。なお、上記例示的実施形態で説明した要素と同一の機能を有する要素には同一の符号を付し、その説明は適宜省略する。さらに、第3の例示的実施形態の概要は、前述の実施例の実施形態の概要と同じであるので、ここでは説明しない。
次に、第3の例示的実施形態に係る訓練装置10bの構成について、図10を用いて説明する。図10は、訓練装置10bの構成を示すブロック図である。図10に示すように、第3の例示的実施形態では、第1のフィルタ部13aにソースドメイン分類損失(Loss_classification_S)が入力され、第2のフィルタ部13bにターゲットドメイン分類損失(Loss_classification_T)が入力される点が第2の実施の形態と相違する。
第3の例示的実施形態によれば、誤って分類されたが価値のある特徴を、訓練プロセスにおいて適切に利用することができる。
第3の例示的実施形態に係る分類装置20bは、第2の例示的実施形態に係る分類装置20aと同様の構成を有する。ただし、上述したように、第1のフィルタ部13aは訓練プロセスにおいてソースドメイン分類損失をさらに参照し、第2のフィルタ部13bは、第3の例示的実施形態における訓練プロセスにおけるターゲットドメイン分類損失をさらに参照する。
以下、本発明の第4の例示的実施形態について、図面を用いて詳しく説明する。なお、上記例示的実施形態で説明した要素と同一の機能を有する要素には同一の符号を付し、その説明は適宜省略する。さらに、第4の例示的実施形態の概要は、前述の例示的実施形態の概要と同じであるので、ここでは説明しない。
以下、第4の例示的実施形態に係る訓練装置10cの構成について、図11を用いて説明する。図11は、訓練装置10cの構成を示すブロック図である。図11に示すように、訓練装置10cは、第2の例示的実施形態に係る訓練装置10aに含まれる構成要素に加えて、ドメイン判別部15と、ドメイン損失計算部16とを備える。
ドメイン判別部15は、ターゲットドメインとソースドメインとを判別する判別処理を行う。すなわち、ドメイン判別部15は、特徴がソースドメインからのものであるか、ターゲットドメインからのものであるかを示すドメイン予測を行う。
(マージ損失計算部)
第4の例示的実施形態に係るマージ損失計算部144は、ソースドメイン分類損失(Loss_classification_S)、ターゲットドメイン分類損失(Loss_classification_T)、グループ損失(Loss_grouping)、及びドメイン損失(Loss_domain)を参照して、マージ損失(Loss_merge)を算出する。
ここで、係数α、β、γ、δは、重み係数を示している。これらの重み係数の具体的な値は第4の例示的実施形態を限定するものではない。ここで、ドメイン損失の前の符号はマイナスであることに留意されたい。これは、抽出された特徴がドメイン判別部による結果の正確性を低下させるように、モデル更新部145が、第1の特徴抽出部11a及び第2の特徴抽出部11bのモデルパラメータを更新することを意味する。すなわち、モデル更新部145は抽出された特徴がドメイン判別部15を混乱させるように、第1の特徴抽出部11a及び第2の特徴抽出部11bのモデルパラメータを更新する。
訓練装置10cは訓練の観点から、以下の処理を行う。まず、訓練装置10cは、特徴がソースドメインからのものであるか、ターゲットドメインからのものであるかをドメイン判別部15が判別できるように、ドメイン判別部15を訓練する。第2に、訓練装置10cは、第1の特徴抽出部11a及び第2の特徴抽出部11bを、訓練されたドメイン判別部15が混乱し得る特徴を抽出するように訓練する。
以上説明したように、第4の例示的実施形態によれば、訓練装置10cは、抽出された特徴XS及びXTのドメイン不変性を実現することができる。これは、好ましいターゲットドメインの特性をもたらす。
以下、第4の例示的実施形態に係る訓練装置10cの訓練方法について、図13を用いて説明する。図13は、訓練装置10cによる訓練方法S1cの流れを示すフロー図である。
ステップS15において、ドメイン判別部15は、ターゲットドメインをソースドメインから判別するドメイン判別処理を行う。ドメイン判別部15が行う具体的な処理については、上述したので、ここでは繰り返さない。
ステップS16において、ドメイン損失計算部16は、ドメイン判別部15による判別処理の結果を参照して、ドメイン判別損失を算出して出力する。ドメイン損失計算部16が行う具体的な処理については、上述したので、ここでは繰り返し説明しない。
ステップS144において、第4の例示的実施形態に係るマージ損失計算部144は、(i)ソースドメイン分類損失(Loss_classification_S)、(ii)ターゲットドメイン分類損失(Loss_classification_T)、(iii)グループ損失(Loss_grouping)、(iv)ドメイン損失(Loss_domain)を参照して、マージ損失(Loss_merge)を算出する。なお、マージ損失計算部144が行う具体的な処理については、上述したので、ここでは繰り返し説明しない。
第4の例示的実施形態によれば、訓練方法S1cは、少量のターゲットドメインラベル付きデータしか利用できない場合であっても、効率的で安定した訓練プロセスを提供する。
以下、第4の例示的実施形態に係る分類装置の構成について説明する。第4の例示的実施形態に係る分類装置20cは、図9に示す分類装置20aと同様の構成を有する。
ドメイン損失を計算するために、クラスラベルは必要ない。したがって、訓練装置10cはラベル付けされたターゲットドメインデータのみを使用する代わりに、(クラスラベルの意味で)ラベル付けされていないターゲットデータを使用することもできる。データがターゲットデータセットからのものである限り、訓練装置10cは、データのドメインラベルが「ターゲット」であることを知ることができる。
以下、本発明の第5の例示的実施形態について、図面を用いて詳しく説明する。なお、上記例示的実施形態で説明した要素と同一の機能を有する要素には同一の符号を付し、その説明は適宜省略する。さらに、第5の例示的実施形態の概要は、前述の実施例の実施形態の概要と同じであるので、ここでは説明しない。
次に、第5の例示的実施形態に係る訓練装置10dの構成について、図14を用いて説明する。図14は、訓練装置10dの構成を示すブロック図である。図14に示すように、第5の例示的実施形態に係る訓練装置10dは、第1のフィルタ部13aにソースドメイン分類損失(Loss_classification_S)が入力され、第2のフィルタ部13bにターゲットドメイン分類損(Loss_classification_T)が入力される点で、第4の例示的実施形態に係る訓練装置10cと相違する。
第5の例示的実施形態に係る分類装置20dは、図9に示す分類装置20aと同様の構成を有する。第5の例示的実施形態に係る訓練装置10dは、第3の例示的実施形態で説明した構成と、第4の例示的実施形態で説明した構成との両方を備える。また、第5の例示的実施形態に係る分類装置20dは、訓練装置10dによって訓練された特徴抽出部11bと、クラス予測部12bとを備える。
訓練装置10,10a,10b,10c,10d及び分類装置20,20a,20b,20c,20dの機能の一部又は全部はICチップ(integrated circuit)等のハードウェアで実現してもよいし、ソフトウェアで実現してもよい。
本発明は、前述の例示的実施形態に限定されず、特許請求の範囲内で当業者によって様々な方法で変更され得る。例えば、上記例示的実施形態に開示されている技術的手段を適宜組み合わせて得られる例示的実施形態についても、本発明の技術的範囲に含まれる。
上述の例示的実施形態の全部または一部は、以下のように表現することもできる。ただし、本発明は以下の例示的態様に限定されない。
本発明の態様は、以下のように表すこともできる:
(態様1)
ソースドメインデータからソースドメイン特徴値を抽出し、ターゲットドメインデータからターゲットドメイン特徴値を抽出する1または複数の特徴抽出手段と、
ソースドメイン特徴値からソースドメインクラス予測値を予測し、ターゲットドメイン特徴値からターゲットドメインクラス予測値を予測する1または複数のクラス予測手段と、
ソースドメインクラス予測値を参照してソースドメイン特徴値から1または複数の値をフィルタ除去し、
ターゲットドメインクラス予測値を参照してターゲットドメイン特徴値から1または複数の値をフィルタ除去すること
によって、フィルタリングされたソースドメイン特徴値およびフィルタリングされたターゲットドメイン特徴値を計算する1または複数のフィルタリング手段と、
ソースドメインクラス予測値を参照して計算されたソースドメイン分類損失、
ターゲットドメインクラス予測値を参照して計算されたターゲットドメイン分類損失、および、
フィルタリングされたソースドメイン特徴値およびフィルタリングされたターゲットドメイン特徴値を参照して計算されたグループ損失
を参照して、前記1または複数の特徴抽出手段および前記1または複数のクラス予測手段のうちの少なくとも1つを更新するための更新手段と、
を備える訓練装置。
(態様2)
前記1または複数のフィルタリング手段は、
前記ソースドメインクラス予測値およびソースドメインクラスラベルデータを参照して、前記ソースドメイン特徴値から1または複数の値をフィルタ除去し、
前記ターゲットドメインクラス予測値およびターゲットドメインクラスラベルデータを参照して、前記ターゲットドメイン特徴値から1又は複数の値をフィルタ除去する、態様1に記載の訓練装置。
前記1または複数のフィルタリング手段は、前記ソースドメイン分類損失および前記ターゲットドメイン分類損失をさらに参照する、態様1または2に記載の訓練装置。
前記更新手段は
フィルタリングされたソースドメイン特徴値およびフィルタリングされたターゲットドメイン特徴値から、クラスグループを出力するグループ化手段を備え、
前記クラスグループの各々は同じクラスラベルを共有する特徴値を含む、態様1から3のいずれか1項に記載の訓練装置。
前記更新手段は、前記クラスグループを参照して前記グループ損失を算出するグループ損失計算手段をさらに備える、態様4に記載の訓練装置。
前記グループ損失計算手段は、
同一クラス内の特徴量を参照して決定されたクラス内距離と、
異なるクラス内の特徴量を参照して決定されたクラス間距離と
に基づいて、前記グループ損失を算出することを特徴とする態様5に記載の訓練装置。
前記更新手段は、
前記ソースドメインクラス予測値およびソースドメインクラスラベルデータを参照して前記ソースドメイン分類損失を計算し、
前記ターゲットドメインクラス予測値およびターゲットドメインクラスラベルデータを参照して前記ターゲットドメイン分類損失を計算する
1または複数の分類損失計算手段をさらに備える、態様1から6のいずれか1項に記載の訓練装置。
前記更新手段は、
(i)前記ソースドメイン分類損失、(ii)前記ターゲットドメイン分類損失、および(iii)前記グループ損失を参照して、マージ損失を計算するマージ損失計算手段をさらに備え、
前記更新手段は、
前記マージ損失を参照して、前記1または複数の特徴抽出手段および前記1または複数のクラス予測手段のうちの少なくとも1つを更新する、態様1から7のいずれか1項に記載の訓練装置。
ソースドメインからターゲットドメインを判別するための判別処理を実行する1または複数のドメイン判別手段と、
判別処理の結果としてドメイン判別損失を出力する1または複数のドメイン損失計算手段と
を更に備え、
前記更新手段はドメイン判別損失をさらに参照し、
前記更新手段はドメイン判別手段をさらに更新する
態様1から8のいずれか1項に記載の訓練装置。
ターゲットドメインデータからターゲットドメイン特徴値を抽出する特徴抽出手段と、
ターゲットドメイン特徴値からターゲットドメインクラス予測値を予測するクラス予測手段と、
を備え、
前記特徴抽出手段および前記クラス予測手段のうちの少なくとも1つは、
ソースドメイン特徴値から1または複数の値をフィルタ除去することによって得られるフィルタリングされたソースドメイン特徴値と、
ターゲットドメイン特徴値から1または複数の値をフィルタ除去することによって得られるフィルタリングされたターゲットドメイン特徴値と
を参照して訓練されている
分類装置。
1または複数の特徴抽出手段によって、ソースドメインデータからソースドメイン特徴値を抽出し、ターゲットドメインデータからターゲットドメイン特徴値を抽出することと、
1または複数のクラス予測手段によって、ソースドメイン特徴値からソースドメインクラス予測値を予測し、ターゲットドメイン特徴値からターゲットドメインクラス予測値を予測することと、
ソースドメインクラス予測値を参照してソースドメイン特徴値から1または複数の値をフィルタ除去し、
ターゲットドメインクラス予測値を参照してターゲットドメイン特徴値から1または複数の値をフィルタ除去する
ことによって、フィルタリングされたソースドメイン特徴値およびフィルタリングされたターゲットドメイン特徴値を計算することと、
ソースドメインクラス予測値を参照して計算されたソースドメイン分類損失、
ターゲットドメインクラス予測値を参照して計算されたターゲットドメイン分類損失、および、
フィルタリングされたソースドメイン特徴値およびフィルタリングされたターゲットドメイン特徴値を参照して計算されたグループ損失
を参照して、前記1または複数の特徴抽出手段および前記1または複数のクラス予測手段のうちの少なくとも1つを更新することと、
を含む訓練方法。
特徴抽出手段により、ターゲットドメインデータからターゲットドメイン特徴値を抽出することと、
クラス予測手段により、ターゲットドメイン特徴値からターゲットドメインクラス予測値を予測することと
を含み、
特徴抽出手段およびクラス予測手段のうちの少なくとも1つは、
ソースドメイン特徴値から1または複数の値をフィルタ除去することによって得られるフィルタリングされたソースドメイン特徴値と、
ターゲットドメイン特徴値から1または複数の値をフィルタ除去することによって得られるフィルタリングされたターゲットドメイン特徴値と
を参照して訓練されている
分類方法。
態様1に記載の訓練装置としてコンピュータを機能させるためのプログラムであって、前記コンピュータを、前記特徴抽出手段、前記クラス予測手段、前記フィルタリング手段、および前記更新手段として機能させることを特徴とするプログラム。
態様10に記載の分類装置としてコンピュータを機能させるためのプログラムであって、前記特徴抽出手段及び前記クラス予測手段としてコンピュータを機能させるためのプログラム。
少なくとも1つのプロセッサを備えた訓練装置であって、
前記プロセッサは、
1または複数の特徴抽出手段によって、ソースドメインデータからソースドメイン特徴値を抽出し、ターゲットドメインデータからターゲットドメイン特徴値を抽出することと、
1または複数のクラス予測手段によって、ソースドメイン特徴値からソースドメインクラス予測値を予測し、ターゲットドメイン特徴値からターゲットドメインクラス予測値を予測することと、
ソースドメインクラス予測値を参照してソースドメイン特徴値から1または複数の値をフィルタ除去し、
ターゲットドメインクラス予測値を参照してターゲットドメイン特徴値から1または複数の値をフィルタ除去する
ことによって、フィルタリングされたソースドメイン特徴値およびフィルタリングされたターゲットドメイン特徴値を計算することと、
ソースドメインクラス予測値を参照して計算されたソースドメイン分類損失、
ターゲットドメインクラス予測値を参照して計算されたターゲットドメイン分類損失、および、
フィルタリングされたソースドメイン特徴値およびフィルタリングされたターゲットドメイン特徴値を参照して計算されたグループ損失
を参照して、前記1または複数の特徴抽出手段および前記1または複数のクラス予測手段のうちの少なくとも1つを更新することと
を実行する訓練装置。
少なくとも1つのプロセッサを備えた分類装置であって、
前記プロセッサは、
特徴抽出手段により、ターゲットドメインデータからターゲットドメイン特徴値を抽出することと、
クラス予測手段により、ターゲットドメイン特徴値からターゲットドメインクラス予測値を予測することと
を実行し、
特徴抽出手段およびクラス予測手段のうちの少なくとも1つは、
ソースドメイン特徴値から1または複数の値をフィルタ除去することによって得られるフィルタリングされたソースドメイン特徴値と、
ターゲットドメイン特徴値から1または複数の値をフィルタ除去することによって得られるフィルタリングされたターゲットドメイン特徴値と
を参照して訓練されている
分類装置。
20, 20a、20b、20c、20d 分類装置
11, 11a、11b 特徴抽出部
12, 12a、12b クラス予測部
13, 13a、13b フィルタ部
14, 14a 更新部
141 グルーピング部
142 グループ損失計算部
143a、143b 分類損失計算部
144 マージ損失計算部
145 モデル更新部
15 ドメイン判別部
16 ドメイン損失計算部
Claims (7)
- ソースドメインデータからソースドメイン特徴値を抽出し、ターゲットドメインデータからターゲットドメイン特徴値を抽出する1または複数の特徴抽出手段と、
ソースドメイン特徴値からソースドメインクラス予測値を予測し、ターゲットドメイン特徴値からターゲットドメインクラス予測値を予測する1または複数のクラス予測手段と、
ソースドメインクラス予測値を参照してソースドメイン特徴値から1または複数の値をフィルタ除去し、
ターゲットドメインクラス予測値を参照してターゲットドメイン特徴値から1または複数の値をフィルタ除去することによって、フィルタリングされたソースドメイン特徴値およびフィルタリングされたターゲットドメイン特徴値を計算する1または複数のフィルタリング手段と、
ソースドメインクラス予測値を参照して計算されたソースドメイン分類損失、
ターゲットドメインクラス予測値を参照して計算されたターゲットドメイン分類損失、および、
フィルタリングされたソースドメイン特徴値およびフィルタリングされたターゲットドメイン特徴値を参照して計算されたグループ損失を参照して、前記1または複数の特徴抽出手段および前記1または複数のクラス予測手段のうちの少なくとも1つを更新するための更新手段と、
を備える訓練装置。 - 前記1または複数のフィルタリング手段は、
前記ソースドメインクラス予測値およびソースドメインクラスラベルデータを参照して、前記ソースドメイン特徴値から1または複数の値をフィルタ除去し、
前記ターゲットドメインクラス予測値およびターゲットドメインクラスラベルデータを参照して、前記ターゲットドメイン特徴値から1又は複数の値をフィルタ除去する、請求項1に記載の訓練装置。 - 前記1または複数のフィルタリング手段は、前記ソースドメイン分類損失および前記ターゲットドメイン分類損失をさらに参照する、請求項1または2に記載の訓練装置。
- 前記更新手段は
フィルタリングされたソースドメイン特徴値およびフィルタリングされたターゲットドメイン特徴値から、クラスグループを出力するグループ化手段を備え、
前記クラスグループの各々は同じクラスラベルを共有する特徴値を含む、請求項1から3の何れか1項に記載の訓練装置。 - 前記更新手段は、前記クラスグループを参照して前記グループ損失を算出するグループ損失算出手段をさらに備える、請求項4に記載の訓練装置。
- 1または複数の特徴抽出手段によって、ソースドメインデータからソースドメイン特徴値を抽出し、ターゲットドメインデータからターゲットドメイン特徴値を抽出することと、
1または複数のクラス予測手段によって、ソースドメイン特徴値からソースドメインクラス予測値を予測し、ターゲットドメイン特徴値からターゲットドメインクラス予測値を予測することと、
ソースドメインクラス予測値を参照してソースドメイン特徴値から1または複数の値をフィルタ除去し、
ターゲットドメインクラス予測値を参照してターゲットドメイン特徴値から1または複数の値をフィルタ除去することによって、フィルタリングされたソースドメイン特徴値およびフィルタリングされたターゲットドメイン特徴値を計算することと、
ソースドメインクラス予測値を参照して計算されたソースドメイン分類損失、
ターゲットドメインクラス予測値を参照して計算されたターゲットドメイン分類損失、および、
フィルタリングされたソースドメイン特徴値およびフィルタリングされたターゲットドメイン特徴値を参照して計算されたグループ損失を参照して、前記1または複数の特徴抽出手段および前記1または複数のクラス予測手段のうちの少なくとも1つを更新することと、
を含む訓練方法。 - 請求項1に記載の訓練装置としてコンピュータを機能させるためのプログラムであって、前記コンピュータを、前記特徴抽出手段、前記クラス予測手段、前記フィルタリング手段、および前記更新手段として機能させることを特徴とするプログラム。
Applications Claiming Priority (3)
| Application Number | Priority Date | Filing Date | Title |
|---|---|---|---|
| JP2021003115 | 2021-01-12 | ||
| JP2021003115 | 2021-01-12 | ||
| PCT/JP2021/044388 WO2022153710A1 (en) | 2021-01-12 | 2021-12-03 | Training apparatus, classification apparatus, training method, classification method, and program |
Publications (2)
| Publication Number | Publication Date |
|---|---|
| JP2024502153A JP2024502153A (ja) | 2024-01-17 |
| JP7616396B2 true JP7616396B2 (ja) | 2025-01-17 |
Family
ID=82448377
Family Applications (1)
| Application Number | Title | Priority Date | Filing Date |
|---|---|---|---|
| JP2023541364A Active JP7616396B2 (ja) | 2021-01-12 | 2021-12-03 | 訓練装置、訓練方法、及びプログラム |
Country Status (4)
| Country | Link |
|---|---|
| US (1) | US20240054349A1 (ja) |
| EP (1) | EP4278311A4 (ja) |
| JP (1) | JP7616396B2 (ja) |
| WO (1) | WO2022153710A1 (ja) |
Families Citing this family (3)
| Publication number | Priority date | Publication date | Assignee | Title |
|---|---|---|---|---|
| US12361201B2 (en) * | 2022-05-19 | 2025-07-15 | Salesforce, Inc. | Systems and methods for parameter ensembling for reducing hallucination in abstractive summarization |
| CN118279626A (zh) * | 2022-12-29 | 2024-07-02 | 脸萌有限公司 | 图像处理方法、装置及电子设备 |
| CN116628418B (zh) * | 2023-04-03 | 2026-01-23 | 西南交通大学 | 基于传感器信号与深度迁移学习的易损件失效预测方法 |
Citations (2)
| Publication number | Priority date | Publication date | Assignee | Title |
|---|---|---|---|---|
| WO2020094026A1 (zh) | 2018-11-08 | 2020-05-14 | 腾讯科技(深圳)有限公司 | 组织结节检测及其模型训练方法、装置、设备和系统 |
| WO2020235033A1 (ja) | 2019-05-22 | 2020-11-26 | 日本電気株式会社 | データ変換装置、パターン認識システム、データ変換方法及び非一時的なコンピュータ可読媒体 |
-
2021
- 2021-12-03 JP JP2023541364A patent/JP7616396B2/ja active Active
- 2021-12-03 US US18/270,812 patent/US20240054349A1/en active Pending
- 2021-12-03 WO PCT/JP2021/044388 patent/WO2022153710A1/en not_active Ceased
- 2021-12-03 EP EP21919599.7A patent/EP4278311A4/en active Pending
Patent Citations (2)
| Publication number | Priority date | Publication date | Assignee | Title |
|---|---|---|---|---|
| WO2020094026A1 (zh) | 2018-11-08 | 2020-05-14 | 腾讯科技(深圳)有限公司 | 组织结节检测及其模型训练方法、装置、设备和系统 |
| WO2020235033A1 (ja) | 2019-05-22 | 2020-11-26 | 日本電気株式会社 | データ変換装置、パターン認識システム、データ変換方法及び非一時的なコンピュータ可読媒体 |
Also Published As
| Publication number | Publication date |
|---|---|
| JP2024502153A (ja) | 2024-01-17 |
| EP4278311A1 (en) | 2023-11-22 |
| EP4278311A4 (en) | 2024-06-26 |
| US20240054349A1 (en) | 2024-02-15 |
| WO2022153710A1 (en) | 2022-07-21 |
Similar Documents
| Publication | Publication Date | Title |
|---|---|---|
| JP7414901B2 (ja) | 生体検出モデルのトレーニング方法及び装置、生体検出の方法及び装置、電子機器、記憶媒体、並びにコンピュータプログラム | |
| CN112598643B (zh) | 深度伪造图像检测及模型训练方法、装置、设备、介质 | |
| JP7616396B2 (ja) | 訓練装置、訓練方法、及びプログラム | |
| CN113033622A (zh) | 跨模态检索模型的训练方法、装置、设备和存储介质 | |
| US11727109B2 (en) | Identifying adversarial attacks with advanced subset scanning | |
| CN113536845B (zh) | 人脸属性识别方法、装置、存储介质和智能设备 | |
| US11410327B2 (en) | Location determination apparatus, location determination method and computer program | |
| JP7485226B2 (ja) | 訓練装置、分類装置、訓練方法、分類方法、及びプログラム | |
| Wu et al. | Dual autoencoders generative adversarial network for imbalanced classification problem | |
| Hassanat et al. | Magnetic force classifier: a novel method for big data classification | |
| EP3975071A1 (en) | Identifying and quantifying confounding bias based on expert knowledge | |
| CN116993513A (zh) | 金融风控模型解释方法、装置及计算机设备 | |
| CN110298024A (zh) | 涉密文档的检测方法、装置及存储介质 | |
| CN113420699A (zh) | 一种人脸匹配方法、装置及电子设备 | |
| JP7544254B2 (ja) | 学習装置、学習方法、及びプログラム | |
| CN111385601B (zh) | 一种视频审核的方法、系统及设备 | |
| CN111353554A (zh) | 预测缺失的用户业务属性的方法及装置 | |
| CN115359574A (zh) | 人脸活体检测及相应模型的训练方法、装置及存储介质 | |
| JP7517417B2 (ja) | 敵対的サンプル検知装置、敵対的サンプル検知方法、およびプログラム | |
| CN119004236A (zh) | 一种自适应模型的训练方法、装置及终端设备 | |
| CN115222747B (zh) | 基于拓扑感知的点云分割网络构建方法、分割方法及装置 | |
| WO2025173264A1 (ja) | 学習装置、学習方法、及び学習プログラム | |
| CN118196863A (zh) | 活体检测网络、活体检测模型、活体检测方法和装置 | |
| CN107644251B (zh) | 对象分类方法、装置和系统 | |
| CN107862328A (zh) | 信息元集合生成方法及基于规则引擎的规则执行方法 |
Legal Events
| Date | Code | Title | Description |
|---|---|---|---|
| A521 | Request for written amendment filed |
Free format text: JAPANESE INTERMEDIATE CODE: A523 Effective date: 20230706 |
|
| A621 | Written request for application examination |
Free format text: JAPANESE INTERMEDIATE CODE: A621 Effective date: 20230706 |
|
| A131 | Notification of reasons for refusal |
Free format text: JAPANESE INTERMEDIATE CODE: A131 Effective date: 20240827 |
|
| A521 | Request for written amendment filed |
Free format text: JAPANESE INTERMEDIATE CODE: A523 Effective date: 20241002 |
|
| TRDD | Decision of grant or rejection written | ||
| A01 | Written decision to grant a patent or to grant a registration (utility model) |
Free format text: JAPANESE INTERMEDIATE CODE: A01 Effective date: 20241203 |
|
| A61 | First payment of annual fees (during grant procedure) |
Free format text: JAPANESE INTERMEDIATE CODE: A61 Effective date: 20241216 |
|
| R150 | Certificate of patent or registration of utility model |
Ref document number: 7616396 Country of ref document: JP Free format text: JAPANESE INTERMEDIATE CODE: R150 |
|
| RD04 | Notification of resignation of power of attorney |
Free format text: JAPANESE INTERMEDIATE CODE: R3D04 |















