JP7287490B2 - 学習装置、学習方法、及び、プログラム - Google Patents

学習装置、学習方法、及び、プログラム Download PDF

Info

Publication number
JP7287490B2
JP7287490B2 JP2021554809A JP2021554809A JP7287490B2 JP 7287490 B2 JP7287490 B2 JP 7287490B2 JP 2021554809 A JP2021554809 A JP 2021554809A JP 2021554809 A JP2021554809 A JP 2021554809A JP 7287490 B2 JP7287490 B2 JP 7287490B2
Authority
JP
Japan
Prior art keywords
class
prediction
classes
grouping
target data
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Active
Application number
JP2021554809A
Other languages
English (en)
Other versions
JPWO2021090518A5 (ja
JPWO2021090518A1 (ja
Inventor
瑛士 金子
あずさ 澤田
和俊 鷺
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
NEC Corp
Original Assignee
NEC Corp
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by NEC Corp filed Critical NEC Corp
Publication of JPWO2021090518A1 publication Critical patent/JPWO2021090518A1/ja
Publication of JPWO2021090518A5 publication Critical patent/JPWO2021090518A5/ja
Application granted granted Critical
Publication of JP7287490B2 publication Critical patent/JP7287490B2/ja
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • G06F18/243Classification techniques relating to the number of classes
    • G06F18/2431Multiple classes
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • G06F18/241Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
    • G06F18/2415Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches based on parametric or probabilistic models, e.g. based on likelihood ratio or false acceptance rate versus a false rejection rate
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N20/00Machine learning
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N7/00Computing arrangements based on specific mathematical models
    • G06N7/01Probabilistic graphical models, e.g. probabilistic networks

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Data Mining & Analysis (AREA)
  • Physics & Mathematics (AREA)
  • General Physics & Mathematics (AREA)
  • Evolutionary Computation (AREA)
  • Artificial Intelligence (AREA)
  • General Engineering & Computer Science (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Software Systems (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Evolutionary Biology (AREA)
  • Computing Systems (AREA)
  • Mathematical Physics (AREA)
  • Probability & Statistics with Applications (AREA)
  • Medical Informatics (AREA)
  • Algebra (AREA)
  • Computational Mathematics (AREA)
  • Mathematical Analysis (AREA)
  • Mathematical Optimization (AREA)
  • Pure & Applied Mathematics (AREA)
  • Information Retrieval, Db Structures And Fs Structures Therefor (AREA)
  • Management, Administration, Business Operations System, And Electronic Commerce (AREA)
  • Image Analysis (AREA)

Description

本発明は、画像に基づいて物体を識別する技術に関する。
近年、深層学習を用いたニューラルネットワークによる物体識別手法が提案されている。物体識別器は、物体識別モデルを用いて画像から対象物を検出し、その対象物が複数のクラスのいずれに該当するかを示す確率をクラス毎に出力する。通常、学習時には、物体識別器が予測した複数のクラスと、予め用意された、正解を示す複数のクラスとを用いて、クラス毎に差を表す指標を算出し、それらの総和に基づいて物体識別モデルのパラメータが更新される。
一方、物体識別モデルが出力した予測確率が上位である複数のクラスに着目して処理を行う手法が提案されている。例えば、特許文献1は、判定モデルによる予測スコアが上位の所定数に属するデータから正解率を算出し、その正解率に基づいて判定モデルの更新が必要であるか否かを決定する学習方法を記載している。
国際公開WO2014/155690号公報
通常の物体識別器は、入力画像から1つのクラスを高い精度で予測するように学習されるが、入力画像の撮影環境などによっては、予測結果を1つのクラスに絞ると精度が低下してしまう場合がある。このような場合、精度が低下してしまうよりは、複数のクラスの中に高い確率で正解が含まれるという予測結果が得られる方がよいことがある。
本発明の1つの目的は、対象物が複数のクラスの中に高い確率で含まれることを示す予測結果を出力するモデルを生成することにある。
本発明の一つの観点では、学習装置は、
予測モデルを用いて入力データを複数のクラスに分類し、クラス毎の予測確率を予測結果として出力する予測手段と、
前記クラス毎の予測確率に基づいて、前記予測確率が上位のk個に含まれるk個のクラスにより構成されるグループ化クラスを生成し、当該グループ化クラスの予測確率を算出するグループ化手段と、
前記グループ化クラスを含む複数のクラスの予測確率に基づいて損失を算出する損失算出手段と、
算出された損失に基づいて、前記予測モデルを更新するモデル更新手段と、
を備える。
本発明の他の観点では、学習方法は、
予測モデルを用いて入力データを複数のクラスに分類し、クラス毎の予測確率を予測結果として出力し、
前記クラス毎の予測確率に基づいて、前記予測確率が上位のk個に含まれるk個のクラスにより構成されるグループ化クラスを生成し、当該グループ化クラスの予測確率を算出し、
前記グループ化クラスを含む複数のクラスの予測確率に基づいて損失を算出し、
算出された損失に基づいて、前記予測モデルを更新する。
本発明の他の観点では、プログラムは、
予測モデルを用いて入力データを複数のクラスに分類し、クラス毎の予測確率を予測結果として出力し、
前記クラス毎の予測確率に基づいて、前記予測確率が上位k個に含まれるk個のクラスにより構成されるグループ化クラスを生成し、当該グループ化クラスの予測確率を算出し、
前記グループ化クラスを含む複数のクラスの予測確率に基づいて損失を算出し、
算出された損失に基づいて、前記予測モデルを更新する処理をコンピュータに実行させる
本発明によれば、対象物が複数のクラスの中に高い確率で含まれることを示す予測結果を出力するモデルを生成することができる。
第1実施形態に係る学習装置のハードウェア構成を示す。 第1実施例に係る学習装置の機能構成を示すブロック図である。 第1実施例による学習処理のフローチャートである。 複数のクラスをグループ化する方法の例を示す。 第2実施例に係る学習装置の機能構成を示すブロック図である。 第2実施例による学習処理のフローチャートである。 第3実施例に係る学習装置の機能構成を示すブロック図である。 第3実施例による学習処理のフローチャートである。 情報統合システムの構成を示すブロック図である。 第2実施形態に係る学習装置の機能構成を示すブロック図である。
以下、図面を参照して、本発明の好適な実施形態について説明する。
[第1実施形態]
(ハードウェア構成)
図1は、第1実施形態に係る学習装置のハードウェア構成を示すブロック図である。図示のように、学習装置100は、入力IF(InterFace)12と、プロセッサ13と、メモリ14と、記録媒体15と、データベース(DB)16と、を備える。
入力IF12は、学習装置100の学習に用いられるデータを入力する。具体的には、後述する訓練用入力データ及び訓練用目標データが入力IF12を通じて入力される。プロセッサ13は、CPU(Central Processing Unit)又はGPU(Graphics Processing Unit)などのコンピュータであり、予め用意されたプログラムを実行することにより、学習装置100の全体を制御する。具体的に、プロセッサ13は、後述する学習処理を実行する。
メモリ14は、ROM(Read Only Memory)、RAM(Random Access Memory)などにより構成される。メモリ14は、プロセッサ13により実行される各種のプログラムを記憶する。また、メモリ14は、プロセッサ13による各種の処理の実行中に作業メモリとしても使用される。
記録媒体15は、ディスク状記録媒体、半導体メモリなどの不揮発性で非一時的な記録媒体であり、学習装置100に対して着脱可能に構成される。記録媒体15は、プロセッサ13が実行する各種のプログラムを記録している。学習装置100が各種の処理を実行する際には、記録媒体15に記録されているプログラムがメモリ14にロードされ、プロセッサ13により実行される。
データベース16は、入力IF12を含む外部装置から入力されるデータを記憶する。具体的には、データベース16には、学習装置100の学習に使用されるデータが記憶される。なお、上記に加えて、学習装置100は、ユーザが指示や入力を行うためのキーボード、マウスなどの入力機器や、表示部を備えていても良い。
(第1実施例)
次に、第1実施形態の第1実施例について説明する。
(1)機能構成
図2は、第1実施例に係る学習装置100の機能構成を示すブロック図である。図示のように、学習装置100は、予測部20と、グループ化部30と、損失算出部40と、モデル更新部50とを備える。学習時には、訓練用入力データ(以下、単に「入力データ」と呼ぶ。)xtrainと、訓練用目標データ(以下、単に「目標データ」と呼ぶ。)ttrainが用意される。入力データxtrainは予測部20に入力され、目標データttrainはグループ化部30に入力される。また、学習の対象となる初期モデルf(winit)はモデル更新部50に入力される。なお、学習の開始時には、初期モデルf(winit)が予測部20に設定されている。
予測部20は、内部に設定されている初期モデルf(winit)を用いて、入力データxtrainの予測を行う。入力データxtrainは画像データであり、予測部20はその画像データから特徴抽出を行い、抽出された特徴量に基づいて画像データに含まれる対象物を予測し、クラス分類を行う。予測部20は、予測結果として予測分類情報yを出力する。予測分類情報yは、入力データxtrainが各クラスである予測確率を出力する。具体的に、予測分類情報yは、以下の式で与えられる。
Figure 0007287490000001
ここで、「N」はクラス数である。なお、添え字「b」は、学習の回数を示す。よって、初期モデルf(winit)に基づいて最初に得られる予測結果は、予測分類情報yとなる。
グループ化部30は、並び替え部31と、変形部32とを備える。並び替え部31には、目標データttrainが入力される。目標データttrainは、以下の式で与えられる。
Figure 0007287490000002
並び替え部31は、予測分類情報yを大きさ順に、即ち予測確率の大きい順に並び替え、以下の予測分類情報y’を求める。
Figure 0007287490000003
また、並び替え部31は、予測分類情報yと同じ順序、即ち、予測分類情報yの大きさ順に目標データttrainを並び替え、以下の目標データt’を生成する。
Figure 0007287490000004
次に、変形部32は、予測確率の上位k個のクラスを1つのクラスにまとめる。具体的に、変形部32は、予測確率が上位のk個のクラスにより1つのクラス(以下、「topkクラス」と呼ぶ。)を作る。そして、変形部32は、以下の式により、予測分類情報y’の上位k個のクラスの予測確率の和をtopkクラスの予測確率y’topkとして算出する。
Figure 0007287490000005
そして、変形部32は、式(3)に示す予測分類情報y’の上位k個のクラスの予測確率を、以下のようにtopkクラスの予測確率y’b,topkに置換する。
Figure 0007287490000006
同様に、変形部32は、以下の式により、予測分類情報y’の上位k個のクラスについて、目標データt’の値の和をtopkクラスの目標データの値t’topkとして算出する。
Figure 0007287490000007
そして、変形部32は、式(4)に示す目標データt’の上位k個のクラスの値を、topkクラスの目標データの値t’topkに置換する。
Figure 0007287490000008
こうして、変形部32は、topkクラスに対応する予測確率を置換した予測分類情報(以下、「グループ化予測分類情報」と呼ぶ。)y’と、topkクラスに対応する値を置換した目標データ(以下、「グループ化目標データ」と呼ぶ。)t’を、グループ化分類情報(y’,t’)として損失算出部40に出力する。
損失算出部40は、グループ化分類情報(y’,t’)を用いて、以下の式により損失Ltopkを算出する。
Figure 0007287490000009
もしくは、損失算出部40は、グループ化分類情報(y’,t’)を用いて、以下の式により損失Ltopkを算出してもよい。
Figure 0007287490000010
モデル更新部50は、損失Ltopkに基づいて、モデル更新部50内に設定されているモデルのパラメータを更新して更新済みモデルf(w)を生成し、これをモデル更新部50及び予測部20に設定する。例えば、最初の更新では、モデル更新部50及び予測部20に設定されている初期モデルf(winit)が、更新済みモデルf(w)に更新される。
モデル更新部50は、所定の終了条件が具備されるまで上記の処理を繰り返し、終了条件が具備されると学習を終了する。終了条件は、例えば、モデルのパラメータが所定回数更新されたこと、用意された所定量の目標データを使用したこと、モデルのパラメータが所定値に収束したことなどとすることができる。そして、学習を終了した時点の更新済みモデルf(w)が、訓練済みモデルf(wtrained)として出力される。
(2)学習処理
図3は、第1実施例による学習処理のフローチャートである。この処理は、図1に示すプロセッサ13が予め用意されたプログラムを実行し、図2に示す各要素として動作することにより実現される。なお、学習処理の開始時には、予測部20及びモデル更新部50には、初期モデルf(winit)が設定されている。
まず、予測部20は、入力データxtrainの予測を行い、予測結果として式(1)に示す予測分類情報yを出力する(ステップS11)。次に、グループ化部30の並び替え部31は、式(3)及び式(4)に示すように、予測分類情報yと、訓練用目標データttrainを並び替える(ステップS12)。
次に、グループ化部30の変形部32は、並び替え後の予測分類情報y’の予測確率の上位k個から、式(5)に示すtopkクラスの予測確率y’topkを算出し、式(6)に示すようにtopkクラスを構成するk個のクラスの予測確率をtopkクラスの予測確率y’b,topkに置き換えてグループ化予測分類情報y’を生成する(ステップS13)。また、変形部32は、式(7)に示すtopkクラスの目標データの値t’topkを算出し、式(8)に示すように目標データt’におけるtopkクラスを構成するk個のクラスの目標データの値をtopkクラスの目標データの値t’topkに置き換えて、グループ化目標データt’を生成する(ステップS14)。
次に、損失算出部40は、グループ化予測分類情報y’と、グループ化目標データt’とを用いて、式(9)又は式(9’)により損失Ltopkを算出する(ステップS15)。次に、モデル更新部50は、損失Ltopkが小さくなるように、モデルのパラメータを更新し、更新済みモデルf(w)を予測部20及びモデル更新部50に設定する(ステップS16)。
次に、モデル更新部50は、所定の終了条件が具備されたか否かを判定する(ステップS17)。終了条件が具備されていない場合(ステップS17:No)、次の入力データxtrain及び目標データttrainを用いて、ステップS11~S16の処理が行われる。一方、終了条件が具備された場合(ステップS17:Yes)、処理は終了する。
以上のように、第1実施例では、予測分類情報yが示す予測確率が上位のk個のクラスをtopkクラスという1つのクラスとみなして損失を算出し、モデルのパラメータを更新する。よって、学習により得られるモデルは、予測確率の上位k個に正解があることを高精度で検出することが可能となる。
(3)グループ化方法
本実施例では、複数のクラスをグループ化する方法としては以下のものが考えられる。以下、グループ化により作成されたクラスを「グループ化クラス」と呼ぶ。
(A)上位k個をグループ化
図4(A)は、予測確率の上位k個をグループ化する方法を示す。この方法で得られたグループ化クラスが上記のtopkクラスである。前述のように、グループ化部30は、予測分類情報yが示す各クラスの予測確率を大きさ順に並び替え、上位k個のクラスをグループ化して1つのグループ化クラスとする。例えば、k=3とすると、予測確率が上位の3クラスによりグループ化クラスが構成される。
(B)(k+1)位以下をグループ化
図4(B)は、予測確率の(k+1)位以下をグループ化する方法を示す。この方法は、予測分類情報yが示す各クラスの予測確率を大きさ順に並び替え、上位k個以外のクラス、即ち、予測確率が上位k+1以下であるクラスをグループ化して1つのグループ化クラスとする。例えば、k=3とすると、予測確率が上位である3クラス以外のクラスによりグループ化クラスが構成される。この場合、グループ化クラスの予測確率は、予測確率の上位k個に正解が含まれない確率を示すものとなる。
(C)上位k個と(k+1)以下の両方をグループ化
上記の上位k個をグループ化する方法と、(k+1)位以下をグループ化する方法を併用してもよい。
(D)1位と上位k個の両方をグループ化
図4(C)は、予測確率の1位と上位k個の両方をグループ化する方法を示す。この方法では、予測分類情報yが示す各クラスの予測確率のうち、1位のクラスと、前述のtopkクラスの両方を使用する。k=3の例では、予測確率が上位3位までのクラスをまとめてtop3クラスを作成し、さらに予測確率が1位のクラス(「top1クラス」と呼ぶ。)をtop3クラスとは別に1つのクラスとして取り扱う。この場合、topkクラスに正解がある確率が高くなると同時に、top1クラスが正解となる確率が高くなるようにモデルの学習が行われる。
上記のグループ化方法では、グループ化するクラス数「k」が予め決まっているものとしているが、その代わりに、グループ化部30がkの値を自動推定するようにしてもよい。この場合の第1の方法では、グループ化部30は、上位k個のクラスの予測確率がいずれも既定値以上になるようにkの値を決める。この方法では、既定値以上の予測確率を有する複数のクラスによりグループ化クラスが構成される。即ち、「k」の値は、規定値以上の予測確率を有するクラス数となる。第2の方法では、グループ化部30は、上位k個のクラスの累積予測確率が既定値以上になるようにkの値を決める。この方法では、例えば、予測確率が1位~4位までのクラスの累積予測確率が既定値以上となる場合、上位4クラスによりグループ化クラスを構成する。
(4)グループ化クラスの予測確率
上記の実施形態では、式(5)に示すように、グループ化クラスに属する複数のクラスの予測確率の和をそのグループ化クラスの予測確率としている。この方法は、1つの入力データがいずれか1つのクラスを持つ場合に使用される。これに対し、1つの入力データが複数の分類結果を同時に持ちうる問題(いわゆるマルチクラス問題)の場合には、グループ化クラスの予測確率は、「k個のどのクラスでもない事象」の背反事象の確率となり、以下の式で与えられる。
Figure 0007287490000011
(第2実施例)
次に、本発明の第2実施例について説明する。第1実施例では、topkクラスについて、予測分類情報y’と目標データt’を変形し、損失を求めている。その代わりに、第2実施例では、topkクラスについて目標データt’のみを変形し、損失を求める。
(1)機能構成
図5は、第2実施例に係る学習装置100xの機能構成を示すブロック図である。図示のように、学習装置100xは、第1実施形態に係る学習装置100におけるグループ化部30の代わりにグループ化部60を備える。グループ化部60は、並び替え部61と、目標変形部62を備える。予測部20から出力される予測分類情報yは、グループ化部60と損失算出部40に入力される。この点以外は、学習装置100xの構成は第1実施形態の学習装置100と同様であるので、共通する部分の説明は行わない。
予測部20は、入力データxtrainの予測を行い、予測分類情報yをグループ化部60及び損失算出部40に出力する。グループ化部60の並び替え部61は、予測分類情報yが示す予測確率の大きさ順にクラスを並べ替え、上記の式(3)及び(4)により並び替え後の予測分類情報y’と目標データt’を算出し、上位のk個のクラスをtopkクラスとして選出する。
目標変形部62は、予測分類情報y’を用いて以下の式により目標データt’を変形し、変形後の目標データ(以下、「変形目標データ」と呼ぶ。)t’’を算出する。
Figure 0007287490000012
ここで、式(11)はtopkクラスに対する変形目標データt’’を示し、式(12)はtopkクラス以外のクラスに対する変形目標データt’’を示す。例えば、目標データt’における正解クラス(値が「1」であるクラス)がtopkクラスに含まれる場合、topkクラスに属する各クラスの値t’’は、値「1」を各クラスの予測確率で各クラスに配分した値となる。この場合、topkクラス以外のクラスの変形目標データt’’の値は全て「0」となる。一方、目標データt’における正解のクラスがtopkクラス以外のクラスに含まれる場合、topkクラスに属する各クラスの値t’’は全て「0」となり、topkクラス以外のクラスの変形目標データt’’の値は変形前の目標データt’と同一となる。即ち、変形前の目標データt’と同じクラスが正解クラス(値が「1」)となる。目標変形部62は、こうして算出した変形目標データt’’を損失算出部40に出力する。
損失算出部40は、変形目標データt’’と、予測分類情報y’とを用いて、以下の式により損失Ltopkを算出する。
Figure 0007287490000013
もしくは、損失算出部40は、変形目標データt’’と、予測分類情報y’とを用いて、以下の式により損失Ltopkを算出してもよい。
Figure 0007287490000014
モデル更新部50は、第1実施例と同様に、損失Ltopkに基づいて、モデル更新部50内に設定されているモデルのパラメータを更新して更新済みモデルf(w)を生成し、これをモデル更新部50及び予測部20に設定する。
(2)学習処理
図6は、第2実施例による学習処理のフローチャートである。この処理は、図1に示すプロセッサ13が予め用意されたプログラムを実行し、図5に示す各要素として動作することにより実現される。なお、学習処理の開始時には、予測部20及びモデル更新部50には、初期モデルf(winit)が設定されている。
まず、予測部20は、入力データxtrainに基づいて予測を行い、予測結果として式(1)に示す予測分類情報yを出力する(ステップS21)。次に、グループ化部60の並び替え部61は、式(3)及び式(4)に示すように、予測分類情報yと、目標データttrainを並び替える(ステップS22)。
次に、グループ化部60の目標変形部62は、予測分類情報y’を用いて式(11)及び(12)により目標データt’を変形し、変形目標データt’’を算出する(ステップS23)。
次に、損失算出部40は、変形目標データt’’と、予測分類情報y’とを用いて、式(13)又は式(13’)により損失Ltopkを算出する(ステップS24)。次に、モデル更新部50は、損失Ltopkが小さくなるように、モデルのパラメータを更新し、更新済みモデルf(w)を予測部20及びモデル更新部50に設定する(ステップS25)。
次に、モデル更新部50は、所定の終了条件が具備されたか否かを判定する(ステップS26)。終了条件が具備されていない場合(ステップS26:No)、次の入力データxtrain及び目標データttrainを用いて、ステップS21~S25の処理が行われる。一方、終了条件が具備された場合(ステップS26:Yes)、処理は終了する。
以上のように、第2実施例では、目標データのみを変形することにより、予測確率の上位k個に正解があることを高精度で検出するモデルを生成することができる。
(3)グループ化方法
第2実施例においても、第1実施形態と同様に、(A)~(D)の方法で複数のクラスをグループ化することができる。
(4)グループ化クラスの目標データ
(A)上位k個をグループ化
この場合の変形目標データt’’は、前述の式(11)及び(12)で与えられる。
(B)(k+1)位以下をグループ化
この場合の変形目標データt’’は以下の式で与えられる。
Figure 0007287490000015
ここで、式(14)は上位k個のクラスに対する変形目標データt’’を示し、式(15)は上位k個のクラス以外に対する変形目標データt’’を示す。式(15)は上位k個のクラスに正解が含まれない場合に「0」以外の値をとるため、関数g(j)の符号をマイナス(-)とし、上位k個のクラスに正解が含まれない場合に損失の値が大きくなるようにしている。
(C)上位k個と(k+1)以下の両方をグループ化
この場合の変形目標データt’’は以下の式で与えられる。
Figure 0007287490000016
ここで、式(16)は上位k個のクラスに対する変形目標データt’’を示し、式(17)は上位k個以外のクラスに対する変形目標データt’’を示す。式(16)では、目標データt’における正解クラスが上位k個のクラスに含まれる場合、上位k個のクラスの値t’’は、正解クラスを示す値「1」を各クラスの予測確率で各クラスに配分した値を2倍したものとなる。式(17)は前述の式(15)と同様である。
(D)1位と上位k個の両方をグループ化
この場合の変形目標データt’’は以下の式で与えられる。
Figure 0007287490000017
ここで、式(18)は1位のクラスに対する変形目標データt’’を示し、式(19)は、上位2位~k位のクラスに対する変形目標データt’’を示す。「w」は、1位と上位k個のうち1位を重視する割合を示す重みであり、「0」~「1」の値に設定される。
なお、上記の各式において、関数g(j)は以下のいずれかを用いることができる。
Figure 0007287490000018
(第3実施例)
次に、本発明の第3実施例について説明する。第1実施例では、topkクラスについて、予測分類情報y’bと目標データt’を変形し、損失を求めている。第3実施例では、代わりに、topkクラスについて、グループ化するクラスの数であるkを変えて、予測分類情報yと目標データt’とを複数組生成し、生成された複数組のグループ化分類情報(y’,t’)を用いて単一の損失を混合損失として求める。
(1)機能構成
図7は、第3実施例に係る学習装置100yの機能構成を示すブロック図である。図示のように、この学習装置100yは、第1実施例に係る学習装置100におけるグループ化部30の代わりに複数グループ化部30yを備え、損失算出部40の代わりに混合損失算出部40yを備える。予測部20、モデル更新部50は、第1実施例と同じである。
複数グループ化部30y部は、第1実施例のグループ化部30と同じ動作を、グループ化するクラスの数であるkをk,k,…,kNkと変えて複数回行い、それぞれのkに対して、グループ化予測分類情報yと、グループ化目標データt’とを生成する。結果として、複数グループ化部30yは、N組のグループ化分類情報(y’,t’)を生成する。
混合損失算出部40yは、複数グループ化部30yが生成した複数組の、グループ化予測分類情報yと、グループ化目標データt’とを用いて混合損失Lmixを算出する。混合損失算出部40yは、例えば、kがある値kのときの、グループ化目標データt’とグループ化予測分類情報yの差異の程度を示す損失関数L(tki’,yki)と、予測結果yや目標データt、学習回数b等に依存する既定の関数αki(y,t,b)を用いた以下の式により算出する。
Figure 0007287490000019
この式(20)は、グループ化予測分類情報yと、グループ化目標データt’とを用いて算出した各kについての損失を合成して混合損失を算出している。
なお、損失関数L(tki’,yki)は、例えば、第1実施例の損失算出部40で算出する損失と同様に、式(9)もしくは式(10)によって算出してもよい。また、既定の関数αは既定の値であってもよい。
また、混合損失算出部40yは、上記の損失関数と既定の関数とを用いた、以下の式により混合損失Lmixを算出してもよい。
Figure 0007287490000020
この式(21)は、グループ化予測分類情報yb’kと、グループ化目標データt’kとを用いて算出した各kについての損失を比較し、最大の値を混合損失としている。なお、既定の関数αは既定の値であってもよい。
また、混合損失算出部40yは、上記の損失関数と既定値a,b,c,dとを用いて、以下の式により混合損失Lmixを算出してもよい。
Figure 0007287490000021
この式(22)は、グループ化目標データt’を既定値a,bを用いて変形した値と、グループ化予測分類情報yを既定値c,dを用いて変形した値とを用いて混合損失を算出している。
また、上記の式(22)を用いて例えば、k={1,m}のとき、
Figure 0007287490000022
として、混合損失Lmixを算出してもよい。
(2)学習処理
図8は、第3実施例による学習処理のフローチャートである。この処理は、図1に示すプロセッサ13が予め用意されたプログラムを実行し、図7に示す各要素として動作することにより実現される。なお、学習処理の開始時には、予測部20及びモデル更新部50には、初期モデルf(winit)が設定されている。
まず、予測部20は、入力データxtrainの予測を行い、予測結果として式(1)に示す予測分類情報yを出力する(ステップS31)。次に、複数グループ化部30yの並び替え部31は、式(3)及び式(4)に示すように、予測分類情報yと、訓練用目標データttrainを並び替える(ステップS32)。
次に、複数グループ化部30yの変形部32は、あるクラス数kについて、並び替え後の予測分類情報y’の予測確率の上位k個から、式(5)に示すtopkクラスの予測確率y’topkを算出し、式(6)に示すようにtopkクラスを構成するk個のクラスの予測確率をtopkクラスの予測確率y’b,topkに置き換えてグループ化予測分類情報y’を生成する(ステップS33)。また、変形部32は、式(7)に示すtopkクラスの目標データの値t’topkを算出し、式(8)に示すように目標データt’におけるtopkクラスを構成するk個のクラスの目標データの値をtopkクラスの目標データの値t’topkに置き換えて、グループ化目標データt’を生成する(ステップS34)。
次に、複数グループ化部30yは、グループ化分類情報(y’,t’)をN組生成したか否かを判定する(ステップS35)。複数グループ化部30yがグループ化分類情報(y’b,t’)をN組生成していない場合(ステップS35:No)、処理はステップS32へ戻り、複数グループ化部30yは次のクラス数kに対してグループ化分類情報(y’,t’)を生成する。
一方、複数グループ化部30yがグループ化分類情報(y’,t’)をN組生成した場合(ステップS35:Yes)、混合損失算出部40yは、前述の式20~22のいずれかを用いて、損失Lmixを算出する(ステップS36)。次に、モデル更新部50は、損失Lmixが小さくなるように、モデルのパラメータを更新し、更新済みモデルf(w)を予測部20及びモデル更新部50に設定する(ステップS37)。
次に、モデル更新部50は、所定の終了条件が具備されたか否かを判定する(ステップS38)。終了条件が具備されていない場合(ステップS38:No)、次の入力データxtrain及び目標データttrainを用いて、ステップS31~S37の処理が行われる。一方、終了条件が具備された場合(ステップS38:Yes)、処理は終了する。
以上のように、第3実施例では、複数組のグループ化分類情報を用いて混合損失を求め、モデルの学習を行うので、複数組のtopkクラスの精度を両立するようにモデルを学習することが可能となる。例えば、k=1、3の2組のグループ化分類情報を用いて混合損失を求めて学習を行なえば、top1クラスの精度とtop3クラスの精度を両立させることが可能なモデルを生成することができる。
(情報統合システム)
次に、第1実施形態に係る情報統合システムについて説明する。図9は、情報統合システム200の構成を示すブロック図である。情報統合システム200は、図示のように、第1実施例に係る学習装置100又は第2実施例に係る学習装置100xと、分類装置210と、関連情報DB220と、情報統合部230とを備える。
学習装置100又は100xは、上述のように、入力データxtrain及び目標データttrainを用いて初期モデルf(winit)を学習し、訓練済みモデルf(wtrained)を生成する。分類装置210は、訓練済みモデルf(wtrained)を用いてクラス分類を行う装置であり、実用入力データxが入力される。実用入力データxは、実際の分類対象となる画像データである。分類装置210は、訓練済みモデルf(wtrained)を用いて実用入力データxの分類を行い、1次分類結果R1を生成して情報統合部230へ出力する。1次分類結果R1は、第1実施例に係る学習装置100又は第2実施例に係る学習装置100xにより生成され、上述のtopkクラスの予測確率、つまり対象物がtopkクラスを構成するいずれかのクラスである確率を含む。言い換えると、分類装置210は、多数の対象物をk個に絞った1次分類結果R1を出力する。
関連情報DBは、関連情報Iを記憶している。関連情報Iは、実用入力データxの分類を行う際に使用される追加情報であり、実用入力データxとは別のルートや手法などにより得た情報である。例えば、実用入力データがカメラによる撮影画像である場合に、レーダやセンサを用いて得たセンサ画像を関連情報Iとして使用することができる。
情報統合部230は、分類装置210から1次分類結果R1を取得すると、その実用入力データxに対応する関連情報Iを関連情報DB220から取得する。そして、情報統合部230は、取得した関連情報Iを用いて、1次分類結果R1が示すk個のクラスから、最終的に1つのクラスを決定して最終分類結果Rfとして出力する。即ち、情報統合部230は、分類装置210が絞り込んだk個のクラスを、さらに1つのクラスに絞り込む処理を行う。なお、情報統合部230は、実用入力データxに関する複数の関連情報Iを用いて最終分類結果Rfを生成してもよい。上記の構成において、分類装置210は本発明の1次分類装置の一例であり、情報統合部230は本発明の2次分類装置の一例である。
上記の情報統合システムにおいては、実用入力データxに対応する関連情報Iが用意されているので、分類装置210は実用入力データxの分類結果を1つのクラスまで絞り込む必要はない。即ち、分類装置210は、実用入力データxが高い確率でtopkクラスに含まれることを検出できればよい。このように、第1実施形態に係る学習装置100及び100xは、上記の情報統合システムのような付加情報を使用できるシステムに好適に適用することができる。
[第2実施形態]
次に、本発明の第2実施形態について説明する。図10は、第2実施形態に係る学習装置の機能構成を示すブロック図である。なお、学習装置80のハードウェア構成は、図1と同様である。図示のように、学習装置80は、予測部81と、グループ化部82と、損失算出部83と、モデル更新部84とを備える。
予測部81は、予測モデルを用いて入力データを複数のクラスに分類し、クラス毎の予測確率を予測結果として出力する。グループ化部82は、クラス毎の予測確率に基づいて、予測確率が上位のk個に含まれるk個のクラスにより構成されるグループ化クラスを生成し、当該グループ化クラスの予測確率を算出する。損失算出部83は、グループ化クラスを含む複数のクラスの予測確率に基づいて損失を算出する。モデル更新部84は、算出された損失に基づいて、予測モデルを更新する。これにより、学習装置80は、予測確率が上位k個のクラスについての予測確率を高精度で出力するモデルを生成することができる。
上記の実施形態の一部又は全部は、以下の付記のようにも記載されうるが、以下には限られない。
(付記1)
予測モデルを用いて入力データを複数のクラスに分類し、クラス毎の予測確率を予測結果として出力する予測部と、
前記クラス毎の予測確率に基づいて、前記予測確率が上位のk個に含まれるk個のクラスにより構成されるグループ化クラスを生成し、当該グループ化クラスの予測確率を算出するグループ化部と、
前記グループ化クラスを含む複数のクラスの予測確率に基づいて損失を算出する損失算出部と、
算出された損失に基づいて、前記予測モデルを更新するモデル更新部と、
を備える学習装置。
(付記2)
前記グループ化クラスの予測確率は、当該グループ化クラスを構成するk個のクラスのいずれかに正解が含まれる確率である付記1に記載の学習装置。
(付記3)
前記グループ化部は、前記予測部が出力したクラス毎の予測確率を大きさ順に並び替え、前記k個のクラスを決定する付記1又は2に記載の学習装置。
(付記4)
前記グループ化部は、前記グループ化クラスを構成するk個のクラスの予測確率を当該グループ化クラスの予測確率に置き換えた変形予測結果と、前記グループ化クラスを構成するk個のクラスの目標データの値を当該グループ化クラスの目標データの値に置き換えた変形目標データと、を生成する変形部を備え、
前記損失算出部は、前記変形予測結果と、前記変形目標データとに基づいて前記損失を計算する付記1乃至3のいずれか一項に記載の学習装置。
(付記5)
前記変形部は、前記グループ化クラスを構成するk個のクラスの予測確率の和を当該グループ化クラスの予測確率とし、前記グループ化クラスを構成するk個のクラスに含まれる目標データの値の和を当該グループ化クラスの目標データの値とする付記4に記載の学習装置。
(付記6)
前記グループ化部は、前記グループ化クラスを構成するk個のクラスの予測確率を用いて目標データを変形して変形目標データを生成する変形部を備え、
前記損失算出部は、前記予測部から出力された予測結果と、前記変形目標データとに基づいて前記損失を計算する付記1乃至3のいずれか一項に記載の学習装置。
(付記7)
前記変形部は、前記グループ化クラスを構成するk個のクラスの目標データの値の和を、当該k個のクラスの予測確率に応じて配分した値を、前記k個のクラス各々の目標データの値とする付記6に記載の学習装置。
(付記8)
前記グループ化部は、前記予測部が出力したクラス毎の予測確率と、既定値とに基づいて前記kの値を決定する付記1乃至7のいずれか一項に記載の学習装置。
(付記9)
前記変形部は、前記kの値を複数用いて、複数組の変形予測結果と変形目標データとを生成し、
前記損失算出部は、前記複数組の変形予測結果と変形目標データとに基づいて、単一の前記損失を算出する付記4又は5に記載の学習装置。
(付記10)
前記損失算出部は、グループ化するクラスの数毎に、前記変形予測結果と、前記変形目標データを用いて算出した損失を合成したものを前記損失とする付記9に記載の学習装置。
(付記11)
前記損失算出部は、グループ化するクラスの数毎に、前記変形予測結果と、前記変形目標データを用いて算出した損失を比較し、最大の値を前記損失とする付記9に記載の学習装置。
(付記12)
前記損失算出部は、グループ化するクラスの数毎に損失を算出する際に、前記変形予測結果の代わりに前記変形予測結果を変形した値を用い、前記変形目標データの代わりに前記変形目標データを変形した値を用いる付記10又は11に記載の学習装置。
(付記13)
付記1乃至12のいずれか一項に記載の学習装置と、
前記学習装置により学習済みの予測モデルを用いて、実用入力データを、前記グループ化クラスを含む複数のクラスに分類する1次分類装置と、
追加情報を用いて、前記実用入力データを、前記グループ化クラスを構成するk個のクラスのいずれかにさらに分類する2次分類装置と、
を備える情報統合システム。
(付記14)
予測モデルを用いて入力データを複数のクラスに分類し、クラス毎の予測確率を予測結果として出力し、
前記クラス毎の予測確率に基づいて、前記予測確率が上位のk個に含まれるk個のクラスにより構成されるグループ化クラスを生成し、当該グループ化クラスの予測確率を算出し、
前記グループ化クラスを含む複数のクラスの予測確率に基づいて損失を算出し、
算出された損失に基づいて、前記予測モデルを更新する学習方法。
(付記15)
予測モデルを用いて入力データを複数のクラスに分類し、クラス毎の予測確率を予測結果として出力し、
前記クラス毎の予測確率に基づいて、前記予測確率が上位k個に含まれるk個のクラスにより構成されるグループ化クラスを生成し、当該グループ化クラスの予測確率を算出し、
前記グループ化クラスを含む複数のクラスの予測確率に基づいて損失を算出し、
算出された損失に基づいて、前記予測モデルを更新する処理をコンピュータに実行させるプログラムを記録した記録媒体。
この出願は、2019年11月8日に出願された国際出願PCT/JP2019/043909を基礎とする優先権を主張し、その開示の全てをここに取り込む。
以上、実施形態及び実施例を参照して本発明を説明したが、本発明は上記実施形態及び実施例に限定されるものではない。本発明の構成や詳細には、本発明のスコープ内で当業者が理解し得る様々な変更をすることができる。
10、100、100x 学習装置
20 予測部
30、60 グループ化部
31、61 並び替え部
32 変形部
40 損失算出部
50 モデル更新部
62 目標変形部
200 情報統合システム
210 分類装置
220 関連情報DB
230 情報統合部

Claims (10)

  1. 予測モデルを用いて入力データを複数のクラスに分類し、クラス毎の予測確率を予測結果として出力する予測手段と、
    前記クラス毎の予測確率に基づいて、前記予測確率が上位のk個に含まれるk個のクラスにより構成されるグループ化クラスを生成し、当該グループ化クラスの予測確率を算出するグループ化手段と、
    前記グループ化クラスを含む複数のクラスの予測確率に基づいて損失を算出する損失算出手段と、
    算出された損失に基づいて、前記予測モデルを更新するモデル更新手段と、
    を備える学習装置。
  2. 前記グループ化クラスの予測確率は、当該グループ化クラスを構成するk個のクラスのいずれかに正解が含まれる確率である請求項1に記載の学習装置。
  3. 前記グループ化手段は、前記予測手段が出力したクラス毎の予測確率を大きさ順に並び替え、前記k個のクラスを決定する請求項1又は2に記載の学習装置。
  4. 前記グループ化手段は、前記グループ化クラスを構成するk個のクラスの予測確率を当該グループ化クラスの予測確率に置き換えた変形予測結果と、前記グループ化クラスを構成するk個のクラスの目標データの値を当該グループ化クラスの目標データの値に置き換えた変形目標データと、を生成する変形手段を備え、
    前記損失算出手段は、前記変形予測結果と、前記変形目標データとに基づいて前記損失を計算する請求項1乃至3のいずれか一項に記載の学習装置。
  5. 前記変形手段は、前記グループ化クラスを構成するk個のクラスの予測確率の和を当該グループ化クラスの予測確率とし、前記グループ化クラスを構成するk個のクラスに含まれる目標データの値の和を当該グループ化クラスの目標データの値とする請求項4に記載の学習装置。
  6. 前記グループ化手段は、前記グループ化クラスを構成するk個のクラスの予測確率を用いて目標データを変形して変形目標データを生成する変形手段を備え、
    前記損失算出手段は、前記予測手段から出力された予測結果と、前記変形目標データとに基づいて前記損失を計算する請求項1乃至3のいずれか一項に記載の学習装置。
  7. 前記変形手段は、前記グループ化クラスを構成するk個のクラスの目標データの値の和を、当該k個のクラスの予測確率に応じて配分した値を、前記k個のクラス各々の目標データの値とする請求項6に記載の学習装置。
  8. 前記グループ化手段は、前記予測手段が出力したクラス毎の予測確率と、既定値とに基づいて前記kの値を決定する請求項1乃至7のいずれか一項に記載の学習装置。
  9. 予測モデルを用いて入力データを複数のクラスに分類し、クラス毎の予測確率を予測結果として出力し、
    前記クラス毎の予測確率に基づいて、前記予測確率が上位のk個に含まれるk個のクラスにより構成されるグループ化クラスを生成し、当該グループ化クラスの予測確率を算出し、
    前記グループ化クラスを含む複数のクラスの予測確率に基づいて損失を算出し、
    算出された損失に基づいて、前記予測モデルを更新する学習方法。
  10. 予測モデルを用いて入力データを複数のクラスに分類し、クラス毎の予測確率を予測結果として出力し、
    前記クラス毎の予測確率に基づいて、前記予測確率が上位k個に含まれるk個のクラスにより構成されるグループ化クラスを生成し、当該グループ化クラスの予測確率を算出し、
    前記グループ化クラスを含む複数のクラスの予測確率に基づいて損失を算出し、
    算出された損失に基づいて、前記予測モデルを更新する処理をコンピュータに実行させるプログラム
JP2021554809A 2019-11-08 2020-03-03 学習装置、学習方法、及び、プログラム Active JP7287490B2 (ja)

Applications Claiming Priority (3)

Application Number Priority Date Filing Date Title
JPPCT/JP2019/043909 2019-11-08
PCT/JP2019/043909 WO2021090484A1 (ja) 2019-11-08 2019-11-08 学習装置、情報統合システム、学習方法、及び、記録媒体
PCT/JP2020/008844 WO2021090518A1 (ja) 2019-11-08 2020-03-03 学習装置、情報統合システム、学習方法、及び、記録媒体

Publications (3)

Publication Number Publication Date
JPWO2021090518A1 JPWO2021090518A1 (ja) 2021-05-14
JPWO2021090518A5 JPWO2021090518A5 (ja) 2022-06-28
JP7287490B2 true JP7287490B2 (ja) 2023-06-06

Family

ID=75848295

Family Applications (1)

Application Number Title Priority Date Filing Date
JP2021554809A Active JP7287490B2 (ja) 2019-11-08 2020-03-03 学習装置、学習方法、及び、プログラム

Country Status (3)

Country Link
US (1) US20220405534A1 (ja)
JP (1) JP7287490B2 (ja)
WO (2) WO2021090484A1 (ja)

Families Citing this family (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113255824B (zh) * 2021-06-15 2023-12-08 京东科技信息技术有限公司 训练分类模型和数据分类的方法和装置

Citations (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
JP2013250809A (ja) 2012-05-31 2013-12-12 Casio Comput Co Ltd 多クラス識別器、方法、およびプログラム

Family Cites Families (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
JP5060224B2 (ja) * 2007-09-12 2012-10-31 株式会社東芝 信号処理装置及びその方法

Patent Citations (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
JP2013250809A (ja) 2012-05-31 2013-12-12 Casio Comput Co Ltd 多クラス識別器、方法、およびプログラム

Also Published As

Publication number Publication date
US20220405534A1 (en) 2022-12-22
WO2021090484A1 (ja) 2021-05-14
JPWO2021090518A1 (ja) 2021-05-14
WO2021090518A1 (ja) 2021-05-14

Similar Documents

Publication Publication Date Title
US7930700B1 (en) Method of ordering operations
JP7232122B2 (ja) 物性予測装置及び物性予測方法
JP7481902B2 (ja) 管理計算機、管理プログラム、及び管理方法
JP7287490B2 (ja) 学習装置、学習方法、及び、プログラム
CN110717537B (zh) 训练用户分类模型、执行用户分类预测的方法及装置
JP2019053491A (ja) ニューラルネットワーク評価装置、ニューラルネットワーク評価方法、およびプログラム
US20210248293A1 (en) Optimization device and optimization method
JP7136217B2 (ja) 決定リスト学習装置、決定リスト学習方法および決定リスト学習プログラム
CN117493920A (zh) 一种数据分类方法及装置
US8495070B2 (en) Logic operation system
US20220366242A1 (en) Information processing apparatus, information processing method, and storage medium
WO2022252694A1 (zh) 神经网络优化方法及其装置
US20040193573A1 (en) Downward hierarchical classification of multivalue data
KR20200052411A (ko) 영상 분류 장치 및 방법
JP2019185121A (ja) 学習装置、学習方法及びプログラム
WO2020218246A1 (ja) 最適化装置、最適化方法、及びプログラム
Christodoulou et al. Improving the performance of classification models with fuzzy cognitive maps
CN114186706A (zh) 基于整数规划的法院案件均衡分配方法、系统及电子设备
JP6463961B2 (ja) 情報処理装置、情報処理方法及びプログラム
JP4202339B2 (ja) 類似事例に基づく予測を行う予測装置および方法
JP7263567B1 (ja) 情報選択システム、情報選択方法及び情報選択プログラム
WO2023209983A1 (ja) パラメータ生成装置、システム、方法およびプログラム
KR102399833B1 (ko) 인공 신경망 기반의 로그 라인을 이용한 시놉시스 제작 서비스 제공 장치 및 그 방법
JP2005157788A (ja) モデル同定装置,モデル同定プログラム及びモデル同定方法
Barai et al. Neuro-fuzzy models for constructability analysis

Legal Events

Date Code Title Description
A521 Request for written amendment filed

Free format text: JAPANESE INTERMEDIATE CODE: A523

Effective date: 20220422

A621 Written request for application examination

Free format text: JAPANESE INTERMEDIATE CODE: A621

Effective date: 20220422

TRDD Decision of grant or rejection written
A01 Written decision to grant a patent or to grant a registration (utility model)

Free format text: JAPANESE INTERMEDIATE CODE: A01

Effective date: 20230425

A61 First payment of annual fees (during grant procedure)

Free format text: JAPANESE INTERMEDIATE CODE: A61

Effective date: 20230508

R151 Written notification of patent or utility model registration

Ref document number: 7287490

Country of ref document: JP

Free format text: JAPANESE INTERMEDIATE CODE: R151