TWI781000B - 機器學習裝置以及方法 - Google Patents

機器學習裝置以及方法 Download PDF

Info

Publication number
TWI781000B
TWI781000B TW110144877A TW110144877A TWI781000B TW I781000 B TWI781000 B TW I781000B TW 110144877 A TW110144877 A TW 110144877A TW 110144877 A TW110144877 A TW 110144877A TW I781000 B TWI781000 B TW I781000B
Authority
TW
Taiwan
Prior art keywords
loss
processor
machine learning
classification model
neural network
Prior art date
Application number
TW110144877A
Other languages
English (en)
Other versions
TW202223770A (zh
Inventor
彭宇劭
湯凱富
張智威
Original Assignee
宏達國際電子股份有限公司
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by 宏達國際電子股份有限公司 filed Critical 宏達國際電子股份有限公司
Publication of TW202223770A publication Critical patent/TW202223770A/zh
Application granted granted Critical
Publication of TWI781000B publication Critical patent/TWI781000B/zh

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/084Backpropagation, e.g. using gradient descent

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • Evolutionary Computation (AREA)
  • Molecular Biology (AREA)
  • Computing Systems (AREA)
  • Software Systems (AREA)
  • Mathematical Physics (AREA)
  • Health & Medical Sciences (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Computational Linguistics (AREA)
  • General Health & Medical Sciences (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Evolutionary Biology (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Information Retrieval, Db Structures And Fs Structures Therefor (AREA)
  • Image Analysis (AREA)
  • Control Of Electric Motors In General (AREA)
  • Numerical Control (AREA)
  • Feedback Control In General (AREA)

Abstract

一種機器學習方法,包括:藉由處理器從記憶體中獲取模型參數,並根據模型參數執行分類模型,其中分類模型包括多個神經網路結構層;藉由處理器根據多個訓練樣本計算第一損失以及第二損失,其中第一損失對應於多個神經網路結構層中的輸出層,第二損失對應於多個神經網路結構層中位於輸出層之前的一者;以及藉由處理器根據第一損失以及第二損失對模型參數執行多個更新操作以訓練分類模型。此外,一種機器學習裝置亦在此揭示。

Description

機器學習裝置以及方法
本揭示有關於一種機器學習技術,且特別是有關於消除偽相關的機器學習技術。
例如機器學習、神經網路等技術正廣泛應用於人工智慧技術領域。人工智慧的重要應用之一是辨識物件(例如,人臉、車牌等)或預測資料(例如,股票預測、醫療預測等)。物件檢測以及資料預測可以經由特徵提取以及特徵分類來實現。
然而,用於特徵提取以及特徵分類的特徵之間通常會發生偽相關(Spurious Correlation),並且偽相關總是會導致物件檢測以及資料預測的預測精確度下降。
本揭示的一態樣揭露一種機器學習方法,包括:藉由處理器從記憶體中獲取模型參數,並根據模型參數執行分類模型,其中分類模型包括多個神經網路結構層;藉由處理器根據多個訓練樣本計算第一損失以及第二損失,其 中第一損失對應於多個神經網路結構層中的輸出層,第二損失對應於多個神經網路結構層中位於輸出層之前的一者;以及藉由處理器根據第一損失以及第二損失對模型參數執行多個更新操作以訓練分類模型。
本揭示的另一態樣揭露一種機器學習裝置,包括記憶體以及處理器。記憶體用以儲存多個指令以及模型參數。處理器連接記憶體,其中處理器用以運行分類模型,並執行多個指令以:從記憶體中獲取模型參數,並根據模型參數執行分類模型,其中分類模型包括多個神經網路結構層;根據多個訓練樣本計算第一損失以及第二損失,其中第一損失對應於多個神經網路結構層中的輸出層,第二損失對應於多個神經網路結構層中位於輸出層之前的一者;以及根據第一損失以及第二損失對模型參數執行多個更新操作以訓練分類模型。
100:機器學習裝置
110:處理器
120:記憶體
MP:模型參數
111:分類模型
SL1、SL2、...SLt:神經網路結構層
Figure 110144877-A0305-02-0018-37
:訓練樣本
Figure 110144877-A0305-02-0018-38
:預測標籤
Figure 110144877-A0305-02-0018-39
:訓練標籤
Figure 110144877-A0305-02-0018-40
:提取特徵
L1:第一損失
L2、L3:第二損失
S210~S230、S221~S223、S224A、224B、S231A、S231B、S232~S233、S220’:步驟
第1圖繪示根據本揭示之一實施例中一種機器學習裝置的示意圖。
第2圖繪示根據本揭示之一實施例中一種機器學習方法的示意圖。
第3圖繪示根據本揭示之一實施例中分類模型以及損失的示意圖。
第4圖繪示在一些實施例當中第2圖中一步驟的詳細步 驟的流程圖。
第5圖繪示在另一些實施例當中第2圖中一步驟的詳細步驟的流程圖。
第6圖繪示在一些實施例當中第2圖中另一步驟的詳細步驟的流程圖。
第7圖繪示在一些實施例當中第2圖中額外步驟的流程圖。
第8圖繪示在另一些實施例當中第2圖中另一步驟的詳細步驟的流程圖。
現在將詳細參照本揭示的當前實施例,其例子在圖式中示出。在可能的情況下,在圖式以及描述中使用相同的圖式標號來指出相同或相似的部分。
參照第1圖,第1圖繪示根據本揭示之一實施例中一種機器學習裝置的示意圖。機器學習裝置100包括處理器110以及記憶體120。處理器110以及記憶體120互相連接。
於一些實施例中,機器學習裝置100可由電腦、伺服器或處理中心建立。於一些實施例中,處理器110可由中央處理單元或計算單元實現。於一些實施例中,記憶體120可以利用快閃記憶體、唯讀記憶體、硬碟或任何具相等性的儲存組件來實現。
於一些實施例中,機器學習裝置100並不限於包 括處理器110以及記憶體120,機器學習裝置100可以進一步包括操作以及應用中所需的其他元件,舉例來說,機器學習裝置100可以更包括輸出介面(例如用於顯示資訊的顯示面板)、輸入介面(例如觸控面板、鍵盤、麥克風、掃描器或快閃記憶體讀取器)以及通訊電路(例如WiFi通訊模組、藍芽通訊模組、無線電信網路通訊模組等)。
如第1圖所示,處理器110基於儲存在記憶體120中相應的軟體/韌體指令程序用以運行分類模型111。
於一些實施例中,分類模型111可以對輸入的資料(例如上述的資料增強影像)進行分類,例如檢測輸入的影像當中具有車輛、人臉、車牌、文字、圖騰或是其他影像特徵物件。分類模型111根據分類的結果產生相應的標籤。需特別說明的是,分類模型111在進行分類運作時需參考本身的模型參數MP。
如第1圖所示,記憶體120用以儲存模型參數MP。於一些實施例中,模型參數MP可包括多個權重參數內容。
於本實施例中,分類模型111包括多個神經網路結構層。於一些實施例中,各層神經網絡路結構層可對應於模型參數MP中的一個權重參數內容(用於決定一個神經網路結構層的操作)。另一方面,分類模型111的各神經網路結構層可對應於相互獨立的權重參數內容。換言之,各層神經網路結構層可對應於一個權重值集合,權重值集合可包括多個權重值。
於一些實施例中,神經網路結構層可以是卷積層、池化層、線性整流層、全連接層或其他類型的神經網路結構層。於一些實施例中,分類模型111可相關於神經網路(例如,分類模型111可由深度殘差網絡(ResNet)以及全連接層組成,或者是由EfficentNet以及全連接層組成)。
請一併參照第2圖,其繪示根據本揭示之一實施例中一種機器學習方法的示意圖,第1圖所示的機器學習裝置100可用以執行第2圖中的機器學習方法。
如第2圖所示,首先於步驟S210中,從記憶體120中獲取模型參數MP,並根據模型參數MP執行分類模型111。於一實施例中,記憶體120中的模型參數MP可以是根據過往訓練經驗當中取得的平均值、人工給定的預設值、或是隨機數值。
於步驟S220中,根據多個訓練樣本計算第一損失以及第二損失,其中第一損失對應於多個神經網路結構層中的輸出層,第二損失對應於多個神經網路結構層中位於輸出層之前的一者。於一實施例中,第一損失由處理器110從分類模型111的神經網路結構層的輸出層產生,第二損失由處理器110從輸出層之前的神經網路結構層產生。於一些實施例中,輸出層可包括至少一全連接層。後續將配合具體的例子,進一步說明步驟S220在一些實施例當中的詳細步驟。
於步驟S230中,根據第一損失以及第二損失對模 型參數MP執行多個更新操作以訓練分類模型111。於一實施例中,處理器110在更新操作中根據第一損失以及第二損失更新模型參數MP以產生更新後的模型參數MP,並根據更新後的模型參數MP訓練分類模型111以產生訓練後的分類模型111。後續將配合具體的例子,進一步說明步驟S230在一些實施例當中的詳細步驟。
藉此,訓練後的分類模型111可用以執行後續應用。例如,訓練後的分類模型111可用於輸入圖片、影像、串流當中的物件辨識、入臉辨識、音頻辨識或動態偵測等,或可用於關於股票資料或天氣資訊的資料預測。
請一併參照第3圖以及第4圖,第3圖繪示根據本揭示之一實施例中分類模型以及損失的示意圖。第4圖繪示在一些實施例當中步驟S220的詳細步驟S221至S224A的流程圖。
如第3圖所示,分類模型111包括神經網路結構層SL1、SL2、...SLt。在一些實施例中,t為正整數。一般而言,分類模型111中的總層數可以根據實際應用的需求(例如,分類的精確度、分類目標物的複雜度、輸入影像的差異性)而定。在某些情況下,t的常見範圍可以在16到128之間,但本揭示文件並不以特定層數為限。
舉例而言,神經網路結構層SL1及SL2可以是卷積層,神經網路結構層SL3可以是池化層,神經網路結構層SL4及SL5可以是卷積層,神經網路結構層SL6可以是池化層,神經網路結構層SL7可以是卷積層,神經網路 結構層SL8可以是線性整流層,神經網路結構層SLt可以是全連接層,但本揭示文件並不以此為限。
於一些實施例中,分類模型111可具有多個殘差映射區塊(Residual Mapping Block),藉由使用殘差映射區塊的結構,可以大大降低t。以下以分類模型111的這種結構為例,以進一步說明步驟S221至步驟S224A。
需特別補充的是,為了說明上的方便,第3圖中的分類模型111僅僅是例示性說明而繪示具有殘差映射區塊的模型(例如ResNet模型),但本揭示文件並不以此為限。實際應用中,分類模型111可以是其他類型的卷積神經網路。於一些實施例中,分類模型111可以是EfficentNet模型。
如第3圖以及第4圖所示,於步驟S221中,藉由處理器110根據訓練樣本
Figure 110144877-A0305-02-0009-1
從神經網路結構層SL1、SL2、...SLt的輸出層SLt產生多個預測標籤
Figure 110144877-A0305-02-0009-3
。值得注意的是,n是訓練樣本
Figure 110144877-A0305-02-0009-5
的數量,n也是預測標籤
Figure 110144877-A0305-02-0009-6
的數量,n可以是正整數,且i可以是不大於n的正整數。如圖3所示,當訓練樣本Xi輸入到分類模型111時,經由神經網路結構層SL1、SL2、...SLt的運算,可從分類模型111的神經網路結構層SLt(即輸出層)產生預測標籤
Figure 110144877-A0305-02-0009-7
。以此類推,可以將訓練樣本
Figure 110144877-A0305-02-0009-8
輸入分類模型111以產生預測標籤
Figure 110144877-A0305-02-0009-9
如第3圖以及第4圖所示,於步驟S222中,藉 由處理器110執行比較演算法將預測標籤
Figure 110144877-A0305-02-0010-10
以及訓練樣本
Figure 110144877-A0305-02-0010-11
的多個訓練標籤
Figure 110144877-A0305-02-0010-12
進行比較,以產生第一損失L1。如第3圖所示,將預測標籤
Figure 110144877-A0305-02-0010-41
以及訓練樣本Xi的訓練標籤y i進行比較計算損失。以此類推,藉由處理器110執行比較演算法將預測標籤以及訓練標籤進行比較以計算出多個損失,且藉由處理器110根據這些損失(即,傳統的損失函數(Loss Function))產生第一損失L1。於在一些實施例中,可藉由處理器110對預測標籤
Figure 110144877-A0305-02-0010-13
以及訓練標籤
Figure 110144877-A0305-02-0010-14
執行交叉熵(Cross Entropy)計算以獲得第一損失L1。
如第3圖以及第4圖所示,於步驟S223中,藉由根據訓練樣本
Figure 110144877-A0305-02-0010-15
從分類模型111中生成多個提取特徵
Figure 110144877-A0305-02-0010-16
。如第3圖所示,將訓練樣本Xi輸入到分類模型111後,可經由神經網路結構層SL1、SL2、...SLt-1的操作從分類模型111的神經網路結構層Lt-1的人工神經元(Artificial Neuron)計算提取特徵Hi,1、Hi,2、...Hi,m,其中m可以是正整數並等於人工神經元的數量,且提取特徵Hi,1、Hi,2、…Hi,m分別對應於神經網路結構層Lt-1中的人工神經元。此外,提取特徵Hi,1、Hi,2、…Hi,m也可以分別對應於神經網路結構層Lt-1中的人工神經元。此外,提取特徵Hi,1,Hi,2,…Hi,m也可以分別對應於神經網路結構層Lt-1之前的任一神經網路結構層中的人工神經元。以此類推,可由人工神經元計算與訓練樣本
Figure 110144877-A0305-02-0010-17
對應的提取特徵
Figure 110144877-A0305-02-0010-18
值得注意的是,提取特徵
Figure 110144877-A0305-02-0011-19
以及訓練標籤
Figure 110144877-A0305-02-0011-20
之間可能存在偽相關(Spurious Correlation)。詳細而言,假設第一提取特徵對第二提取特徵以及訓練標籤y i都存在因果相關(Causally Related),但是第二提取特徵以及訓練標籤y i彼此之間沒有因果相關。基於此,可將第二提取特徵以及訓練標籤y i相關連(Associated)。當第二提取特徵的數值隨著標籤的變化線性增加時,第二提取特徵以及訓練標籤y i之間存在偽相關。如果可以觀察到導致偽相關的提取特徵(即,第一提取特徵、第二提取特徵以及訓練標籤y i之間的關係),則偽相關屬於顯性(Explicit)。否則,偽相關可被認為是隱性(Implicit)(即,第二提取特徵以及訓練標籤y i之間的關係)。偽相關將會導致預測標籤
Figure 110144877-A0305-02-0011-21
以及訓練標籤
Figure 110144877-A0305-02-0011-22
的差異更大。
例如,若患者臨床影像具有病灶的細胞組織以及與細胞組織顏色相似的骨骼,則導致骨骼的提取特徵以及病灶的標籤之間的顯性偽相關(Explicit Spurious Correlation)。於另一例子中,患者臨床影像通常具有背景,患者臨床影像中的病灶以及背景相似。因此,這導致了背景提取特徵與病灶標籤之間的隱性偽相關(Implicit Spurious Correlation)。
為了避免偽相關,以下段落進一步描述了使用統計獨立性(Statistical lndependence)消除顯性偽相關以及使用平均處理效應(Average Treatment Effect,ATE)消除隱性偽相關的細節。
如第3圖以及第4圖所示,於步驟S224A中,藉由處理器110根據提取特徵之間的統計獨立性計算第二損失L2,其中提取特徵對應於神經網絡結構層SL1、SL2、...SLt中的一者(即神經網絡結構層SLt-1)。具體而言,隨機變數(Random Variable)的統計獨立性如下公式(1)所示。
E(apbq)=E(ap)E(bq).....(1)
其中E(.)表示隨機變數的期望值,a以及b為隨機變數,p以及q為正整數。根據公式(1),獨立性損失可以用下面的公式(2)表示。
independent loss=-|E(apbq)-E(ap)E(bq)|.....(2)
如第3圖所示,藉由將隨機變數替換為提取特徵
Figure 110144877-A0305-02-0012-23
,公式(2)可以改寫為以下公式(3),其表示第 二損失L2(即,提取特徵
Figure 110144877-A0305-02-0012-24
之間的獨立性損失)。
Figure 110144877-A0305-02-0012-25
其中j以及k為正整數且不大於m。藉由公式(3),根據提取特徵
Figure 110144877-A0305-02-0012-26
計算第二損失L2。於一些實施例中,公式(3)的第二損失可以進一步乘以重要性值(Importance Value)以產生第二損失L2,其中重要性值大於零且是控制獨立性損失的重要性的超參數。
請一併參照第5圖,第5圖繪示在另一些實施例當中步驟S220的詳細步驟S221至S224B的流程圖。
值得注意的是,第4圖與第5圖的區別僅在於步 驟S224B。換言之,除了執行步驟S224A以產生第二損失之外,可選地,也可執行步驟S224B以產生第二損失。因此,以下僅針對步驟S224B進行說明,其餘步驟不再贅述。
如第3圖以及第5圖所示,於步驟S224B中,藉由處理器110根據提取特徵以及訓練樣本的訓練標籤之間的平均處理效應計算第二損失L3,其中提取特徵對應於神經網路結構層SL1、SL2、...SLt中的一者(即,神經網路結構層SLt-1)。詳細而言,隨機變數的平均處理效應(即,因果性)如下公式(4)所示。
Figure 110144877-A0305-02-0013-27
其中p(.)表示隨機變數的機率,Yi和Ti是隨機變數,
Figure 110144877-A0305-02-0013-46
且代表治療,
Figure 110144877-A0305-02-0013-47
且是觀察結果,
Figure 110144877-A0305-02-0013-48
且是共變向量(Covariate Vector),以及
Figure 110144877-A0305-02-0013-28
如第3圖所示,藉由將Yi以及Ti替換為訓練標籤
Figure 110144877-A0305-02-0013-29
以及經由強激活函數(Hard Sigmoid)處理的提 取特徵
Figure 110144877-A0305-02-0013-32
,公式(4)可以改寫為如下公式(5)。
Figure 110144877-A0305-02-0013-33
其中第j個提取特徵的損失是指對應於提取特徵H1,j、H2,j、...Hn,j的因果性損失(即,平均處理效應損 失),
Figure 110144877-A0305-02-0014-49
是指範圍為
Figure 110144877-A0305-02-0014-34
的強激活函數。基於公 式(5),指示提取特徵
Figure 110144877-A0305-02-0014-35
的平均處理效應的第二損失L3如下公式(6)所示。
Figure 110144877-A0305-02-0014-36
藉由公式(6),根據提取特徵以及訓練樣本的訓練標籤計算第二損失L3。於一些實施例中,公式(6)的第二損失還可進一步乘以另一重要性值以產生第二損失L3,其中另一重要性值也大於零且是控制平均處理效應損失的重要性的另一超參數。
請一併參照第6圖,第6圖繪示在一些實施例當中步驟S230的詳細步驟S231A至S233的流程圖。
如第6圖所示,於步驟S231A中,藉由處理器110根據第一損失以及第二損失計算損失差。詳細而言,藉由處理器110執行第一損失以及第二損失之間的差值運算以產生損失差(即,第一損失減去第二損失)。值得注意的是,第二損失可以由第4圖的步驟S224A或第5圖的步驟S224B產生。換言之,可根據第一損失以及獨立損失或根據第一損失以及平均處理效應損失以計算損失差。
此外,還可以同時根據第一損失、第4圖中的步驟S224A產生的第二損失以及第5圖中的步驟S224B產生的第二損失計算損失差(更詳細的內容將在下面的段落中藉由一些例子進行說明)。
於步驟S232中,判斷損失差是否收斂。於一些實 施例中,當損失差收斂時,損失差可接近或等於根據統計實驗結果所產生的差閾值。
於本實施例中,若損失差沒有收斂,則執行步驟S233。於步驟S233中,藉由處理器110根據第一損失以及第二損失對分類模型執行反向傳遞(Backpropagation)操作以更新模型參數MP。換言之,根據基於第一損失以及第二損失的反向傳遞,從模型參數MP產生更新的模型參數。
藉此,繼續重複步驟S233、S220以及S231A,以迭代的方式逐漸更新模型參數MP。如此一來,損失差將逐漸最小化(即,第二損失逐漸最大化),直到損失差接近或等於差閾值。反之,若損失差收斂,則表示機器學習裝置100已完成訓練,訓練後的分類模型111可用於執行後續應用。
基於上述實施例,藉由使用步驟S224A中的第二損失,可以在步驟S230中去除屬於顯性偽相關的提取特徵。此外,藉由使用步驟S224B中的第二損失,可以在步驟S230中去除屬於隱性偽相關的提取特徵。
請一併參照第7圖,第7圖繪示在一些實施例當中在步驟S224A之後的額外步驟的流程圖。
如圖7所示,步驟S220’A以與在步驟S224B中計算第二損失相同的方式計算第三損失。換言之,這表示在藉由處理器110產生第一損失後,產生獨立損失以及平均處理效應損失。由於步驟S220’A以及步驟S224B 類似,此步驟不再贅述。
請一併參照第8圖,第8圖繪示在另一些實施例當中步驟S230的詳細步驟S231B至S233的流程圖。
值得注意的是,第6圖以及第8圖的區別僅在於步驟S231B。換言之,除了執行步驟S231A以產生損失差之外,可選地,也可進行步驟S231B以產生損失差。因此,以下僅針對步驟S231B進行說明,其餘步驟不再贅述。
如第8圖所示,在執行完步驟S220’之後,再執行步驟S231B。於步驟S231B中,藉由處理器110根據第一損失、第二損失以及第三損失計算損失差。詳細而言,藉由處理器110執行第一損失以及第二損失之間的差值運算以產生第一差值,然後在第一差值以及第三個損失之間執行另一差值運算以生成損失差(即,第一損失減去第二損失,然後再減去第三損失)。因此,於步驟S233中,根據基於第一損失、第二損失以及第三損失的反向傳遞從模型參數MP產生更新的模型參數。藉此,同樣繼續重複步驟S233、S220以及S231B,以迭代的方式逐漸更新模型參數MP。如此一來,類似地,損失差也逐漸最小化(即,第二損失以及第三損失逐漸最大化),直到損失差接近或等於差閾值。
基於上述實施例,藉由同時利用步驟S224A中的第二損失以及S220'中的第三損失,可以在步驟S230中去除屬於顯性偽相關以及隱性偽相關的提取特徵。
如第1圖所示,在機器學習裝置100的訓練過程中,根據第一損失以及第二損失更新分類模型111的模型參數MP,以避免提取特徵以及訓練標籤之間的顯性偽相關或隱性偽相關,其中第二損失可以是獨立性損失或平均處理效應損失。此外,藉由利用獨立性損失以及平均處理效應損失調整模型參數MP可以去除顯性偽相關以及隱性偽相關,從而大大提高分類模型111的預測精確度。
在電腦視覺以及電腦預測領域,深度學習的準確度主要依賴於大量標記的訓練資料。隨著訓練資料的質量、數量以及種類的增加,分類模型的性能通常會相對提高。然而,分類模型在提取特徵以及訓練標籤之間總是存在顯性偽相關或隱性偽相關。如果我們能去除顯性偽相關或隱性偽相關,效率會更高且更準確。在本揭示的上述實施例中,提出根據獨立性損失以及平均處理效應損失調整模型,去除分類模型中的顯性偽相關或隱性偽相關。因此,根據獨立性損失以及平均處理效應損失調整模型參數可以提高模型的整體性能。
於應用層面上,本揭示文件的機器學習方法與機器學習系統可以用在各種具有機器視覺、圖像分類、資料預測或是資料分類的領域,舉例而言,此機器學習方法可以用在醫療影像的分類,像是可以分辨正常狀態、患有肺炎、患有支氣管炎、患有心臟疾病的X光影像,或是可以分辨正常胎兒、胎位不正的超音波影像。機器學習方法還可用於預測未來股票資料的上漲或下跌。另一方面,此機器學 習方法也可以用在自動駕駛收集之影像的分類,像是可以分辨正常路面、有障礙物的路面及其他車輛的路況影像。還有其他與此類似的機器學習領域,舉例而言,本揭示文件的機器學習方法與機器學習系統也可以用在音譜辨識、光譜辨識、大數據分析、資料特徵辨識等其他有關機器學習的範疇當中。
雖然本揭示的特定實施例已經揭露有關上述實施例,此些實施例不意欲限制本揭示。各種替代及改良可藉由相關領域中的一般技術人員在本揭示中執行而沒有從本揭示的原理及精神背離。因此,本揭示的保護範圍由所附申請專利範圍確定。
S210~S230:步驟

Claims (20)

  1. 一種機器學習方法,包括:藉由一處理器從一記憶體中獲取一模型參數,並根據該模型參數執行一分類模型,其中該分類模型包括多個神經網路結構層;藉由該處理器根據多個訓練樣本所產生的多個預測標籤計算一第一損失,以及藉由該處理器根據由該多個訓練樣本所產生的多個提取特徵計算一第二損失,其中該第一損失對應於該些神經網路結構層中的一輸出層,該第二損失對應於該些神經網路結構層中位於該輸出層之前的一者;以及藉由該處理器根據該第一損失以及該第二損失對該模型參數執行多個更新操作以訓練該分類模型。
  2. 如請求項1所述之機器學習方法,其中根據該些訓練樣本所產生的該多個預測標籤計算該第一損失的步驟包括:藉由該處理器根據該些訓練樣本從該些神經網路結構層的該輸出層產生該多個預測標籤;以及藉由該處理器對該些預測標籤以及該些訓練樣本的多個訓練標籤進行比較以計算該第一損失。
  3. 如請求項1所述之機器學習方法,其中根據由該些訓練樣本所產生的該多個提取特徵計算該第二損失 的步驟包括:藉由該處理器根據該些訓練樣本從該分類模型中產生該多個提取特徵;以及藉由該處理器根據該些提取特徵之間的統計獨立性計算該第二損失,其中該些提取特徵對應於該些神經網路結構層中的一者。
  4. 如請求項3所述之機器學習方法,其中根據該第一損失以及該第二損失對該模型參數執行該些更新操作以訓練該分類模型包括:藉由該處理器根據該第一損失以及該第二損失計算多個損失差;以及藉由該處理器根據該些損失差對該分類模型執行多個反向傳遞操作以更新該模型參數。
  5. 如請求項3所述之機器學習方法,更包括:藉由該處理器根據該些提取特徵以及該些訓練樣本的多個訓練標籤之間的平均處理效應計算一第三損失。
  6. 如請求項5所述之機器學習方法,其中根據該第一損失以及該第二損失對該模型參數執行該些更新操作以訓練該分類模型包括:藉由該處理器根據該第一損失、該第二損失以及該第三損失計算多個損失差;以及 藉由該處理器根據該些損失差對該分類模型執行多個反向傳遞操作以更新該模型參數。
  7. 如請求項1所述之機器學習方法,其中根據由該些訓練樣本所產生的該多個提取特徵計算該第二損失的步驟包括:藉由該處理器根據該些訓練樣本從該分類模型中產生多個提取特徵;以及藉由該處理器根據該些提取特徵以及該些訓練樣本的多個訓練標籤之間的平均處理效應計算該第二損失,其中該些提取特徵對應於該些神經網路結構中的一者。
  8. 如請求項7所述之機器學習方法,其中根據該第一損失以及該第二損失對該模型參數執行該些更新操作以訓練該分類模型包括:藉由該處理器根據該第一損失以及該第二損失計算多個損失差;以及藉由該處理器根據該些損失差對該分類模型執行多個反向傳遞操作以更新該模型參數。
  9. 如請求項1所述之機器學習方法,其中該輸出層包括至少一全連接層,該些神經網路結構層中的一者包括至少一個卷積層。
  10. 如請求項1所述之機器學習方法,其中該分類模型相關於神經網路。
  11. 一種機器學習裝置,包括:一記憶體,用以儲存多個指令以及一模型參數;一處理器,連接該記憶體,其中該處理器用以運行一分類模型,並執行該些指令以:從該記憶體中獲取該模型參數,並根據該模型參數執行該分類模型,其中該分類模型包括多個神經網路結構層;根據多個訓練樣本所產生的多個預測標籤計算一第一損失以及根據由該多個訓練樣本所產生的多個提取特徵計算一第二損失,其中該第一損失對應於該些神經網路結構層中的一輸出層,該第二損失對應於該些神經網路結構層中位於該輸出層之前的一者;以及根據該第一損失以及該第二損失對該模型參數執行多個更新操作以訓練該分類模型。
  12. 如請求項11所述之機器學習裝置,其中該處理器更用以:根據該些訓練樣本從該些神經網路結構層的該輸出層產生該多個預測標籤;以及對該些預測標籤以及該些訓練樣本的多個訓練標籤進行比較以計算該第一損失。
  13. 如請求項11所述之機器學習裝置,其中該處理器更用以:根據該些訓練樣本從該分類模型中產生該多個提取特徵;以及根據該些提取特徵之間的統計獨立性計算該第二損失,其中該些提取特徵對應於該些神經網路結構層中的一者。
  14. 如請求項13所述之機器學習裝置,其中該處理器更用以:根據該第一損失以及該第二損失計算多個損失差;以及根據該些損失差對該分類模型執行多個反向傳遞操作以更新該模型參數。
  15. 如請求項13所述之機器學習裝置,其中該處理器更用以:根據該些提取特徵以及該些訓練樣本的多個訓練標籤之間的平均處理效應計算一第三損失。
  16. 如請求項15所述之機器學習裝置,其中該處理器更用以:根據該第一損失、該第二損失以及該第三損失計算多個損失差;以及根據該些損失差對該分類模型執行多個反向傳遞操作以 更新該模型參數。
  17. 如請求項11所述之機器學習裝置,其中該處理器更用以:根據該些訓練樣本從該分類模型中產生多個提取特徵;以及根據該些提取特徵以及該些訓練樣本的多個訓練標籤之間的平均處理效應計算該第二損失,其中該些提取特徵對應於該些神經網路結構中的一者。
  18. 如請求項17所述之機器學習裝置,其中該處理器更用以:根據該第一損失以及該第二損失計算多個損失差;以及根據該些損失差對該分類模型執行多個反向傳遞操作以更新該模型參數。
  19. 如請求項11所述之機器學習裝置,其中該輸出層包括至少一全連接層,該些神經網路結構層中的一者包括至少一個卷積層。
  20. 如請求項11所述之機器學習裝置,其中該分類模型相關於神經網路。
TW110144877A 2020-12-02 2021-12-01 機器學習裝置以及方法 TWI781000B (zh)

Applications Claiming Priority (6)

Application Number Priority Date Filing Date Title
US202063120216P 2020-12-02 2020-12-02
US63/120,216 2020-12-02
US202163152348P 2021-02-23 2021-02-23
US63/152,348 2021-02-23
US17/448,711 2021-09-24
US17/448,711 US20220172064A1 (en) 2020-12-02 2021-09-24 Machine learning method and machine learning device for eliminating spurious correlation

Publications (2)

Publication Number Publication Date
TW202223770A TW202223770A (zh) 2022-06-16
TWI781000B true TWI781000B (zh) 2022-10-11

Family

ID=78820691

Family Applications (1)

Application Number Title Priority Date Filing Date
TW110144877A TWI781000B (zh) 2020-12-02 2021-12-01 機器學習裝置以及方法

Country Status (5)

Country Link
US (1) US20220172064A1 (zh)
EP (1) EP4009245A1 (zh)
JP (1) JP7307785B2 (zh)
CN (1) CN114648094A (zh)
TW (1) TWI781000B (zh)

Families Citing this family (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN116976221B (zh) * 2023-08-10 2024-05-17 西安理工大学 基于冲蚀特性的堰塞体溃决峰值流量预测方法及存储介质

Citations (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN108734300A (zh) * 2017-04-24 2018-11-02 英特尔公司 使用自主机器的识别、再标识和安全性增强
US20200012937A1 (en) * 2018-07-06 2020-01-09 Capital One Services, Llc Systems and methods to identify neural network brittleness based on sample data and seed generation
TW202030648A (zh) * 2019-01-31 2020-08-16 大陸商北京市商湯科技開發有限公司 一種目標對象處理方法、裝置、電子設備及儲存介質
TW202034227A (zh) * 2019-03-05 2020-09-16 南韓商三星電子股份有限公司 用於提供旋轉不變神經網路的方法及系統

Family Cites Families (6)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
JP7139625B2 (ja) 2017-08-04 2022-09-21 富士電機株式会社 要因分析システム、要因分析方法およびプログラム
JP2019096006A (ja) 2017-11-21 2019-06-20 キヤノン株式会社 情報処理装置、情報処理方法
US11954881B2 (en) * 2018-08-28 2024-04-09 Apple Inc. Semi-supervised learning using clustering as an additional constraint
JP7095747B2 (ja) 2018-10-29 2022-07-05 日本電信電話株式会社 音響モデル学習装置、モデル学習装置、それらの方法、およびプログラム
JP7086878B2 (ja) 2019-02-20 2022-06-20 株式会社東芝 学習装置、学習方法、プログラムおよび認識装置
CN111476363A (zh) 2020-03-13 2020-07-31 清华大学 区分化变量去相关的稳定学习方法及装置

Patent Citations (5)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN108734300A (zh) * 2017-04-24 2018-11-02 英特尔公司 使用自主机器的识别、再标识和安全性增强
US20200012937A1 (en) * 2018-07-06 2020-01-09 Capital One Services, Llc Systems and methods to identify neural network brittleness based on sample data and seed generation
US20200117998A1 (en) * 2018-07-06 2020-04-16 Capital One Services, Llc Systems and methods to identify neural network brittleness based on sample data and seed generation
TW202030648A (zh) * 2019-01-31 2020-08-16 大陸商北京市商湯科技開發有限公司 一種目標對象處理方法、裝置、電子設備及儲存介質
TW202034227A (zh) * 2019-03-05 2020-09-16 南韓商三星電子股份有限公司 用於提供旋轉不變神經網路的方法及系統

Also Published As

Publication number Publication date
TW202223770A (zh) 2022-06-16
JP7307785B2 (ja) 2023-07-12
EP4009245A1 (en) 2022-06-08
JP2022088341A (ja) 2022-06-14
US20220172064A1 (en) 2022-06-02
CN114648094A (zh) 2022-06-21

Similar Documents

Publication Publication Date Title
CN108133188B (zh) 一种基于运动历史图像与卷积神经网络的行为识别方法
US10452899B2 (en) Unsupervised deep representation learning for fine-grained body part recognition
CN108182394B (zh) 卷积神经网络的训练方法、人脸识别方法及装置
US11836931B2 (en) Target detection method, apparatus and device for continuous images, and storage medium
WO2019228317A1 (zh) 人脸识别方法、装置及计算机可读介质
US11151417B2 (en) Method of and system for generating training images for instance segmentation machine learning algorithm
JP2023549579A (ja) ビデオ行動認識のための時間ボトルネック・アテンション・アーキテクチャ
EP3602424A1 (en) Sensor data processor with update ability
CN115661943B (zh) 一种基于轻量级姿态评估网络的跌倒检测方法
CN113657560B (zh) 基于节点分类的弱监督图像语义分割方法及系统
CN111783997B (zh) 一种数据处理方法、装置及设备
WO2023061102A1 (zh) 视频行为识别方法、装置、计算机设备和存储介质
US11790492B1 (en) Method of and system for customized image denoising with model interpretations
CN112258557B (zh) 一种基于空间注意力特征聚合的视觉跟踪方法
CN114359631A (zh) 基于编码-译码弱监督网络模型的目标分类与定位方法
TWI781000B (zh) 機器學習裝置以及方法
US20220188636A1 (en) Meta pseudo-labels
CN116912568A (zh) 基于自适应类别均衡的含噪声标签图像识别方法
CN117765432A (zh) 一种基于动作边界预测的中学理化生实验动作检测方法
JP2018055287A (ja) 統合装置及びプログラム
CN116486150A (zh) 一种基于不确定性感知的图像分类模型回归误差消减方法
Zhang et al. PCANet: pyramid context-aware network for retinal vessel segmentation
KR102526415B1 (ko) 준지도 학습 방식의 단일 영상 깊이 추정 시스템 및 방법과 이를 위한 컴퓨터 프로그램
US20210365719A1 (en) System and method for few-shot learning
CN115116117A (zh) 一种基于多模态融合网络的学习投入度数据的获取方法

Legal Events

Date Code Title Description
GD4A Issue of patent certificate for granted invention patent