TWI793951B - 模型訓練方法與模型訓練系統 - Google Patents

模型訓練方法與模型訓練系統 Download PDF

Info

Publication number
TWI793951B
TWI793951B TW111100068A TW111100068A TWI793951B TW I793951 B TWI793951 B TW I793951B TW 111100068 A TW111100068 A TW 111100068A TW 111100068 A TW111100068 A TW 111100068A TW I793951 B TWI793951 B TW I793951B
Authority
TW
Taiwan
Prior art keywords
model
hourglass
loss value
network
loss function
Prior art date
Application number
TW111100068A
Other languages
English (en)
Other versions
TW202321993A (zh
Inventor
李蛟
張朝晉
Original Assignee
威盛電子股份有限公司
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by 威盛電子股份有限公司 filed Critical 威盛電子股份有限公司
Application granted granted Critical
Publication of TWI793951B publication Critical patent/TWI793951B/zh
Publication of TW202321993A publication Critical patent/TW202321993A/zh

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/084Backpropagation, e.g. using gradient descent
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/21Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
    • G06F18/214Generating training patterns; Bootstrap methods, e.g. bagging or boosting
    • G06F18/2155Generating training patterns; Bootstrap methods, e.g. bagging or boosting characterised by the incorporation of unlabelled data, e.g. multiple instance learning [MIL], semi-supervised techniques using expectation-maximisation [EM] or naïve labelling
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06FELECTRIC DIGITAL DATA PROCESSING
    • G06F18/00Pattern recognition
    • G06F18/20Analysing
    • G06F18/24Classification techniques
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/045Combinations of networks
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/04Architecture, e.g. interconnection topology
    • G06N3/048Activation functions
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/082Learning methods modifying the architecture, e.g. adding, deleting or silencing nodes or connections

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Artificial Intelligence (AREA)
  • General Physics & Mathematics (AREA)
  • General Engineering & Computer Science (AREA)
  • Evolutionary Computation (AREA)
  • Molecular Biology (AREA)
  • Computational Linguistics (AREA)
  • Software Systems (AREA)
  • Mathematical Physics (AREA)
  • Health & Medical Sciences (AREA)
  • Biomedical Technology (AREA)
  • Biophysics (AREA)
  • Computing Systems (AREA)
  • General Health & Medical Sciences (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Evolutionary Biology (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Image Analysis (AREA)
  • Management, Administration, Business Operations System, And Electronic Commerce (AREA)

Abstract

一種模型訓練方法,包含:將一未標注資料輸入至具有一深度神經網路架構之一模型,其中該模型至少包含一堆疊沙漏網路,該堆疊沙漏網路包含複數個沙漏網路;針對該未標注資料,分別得到該複數個沙漏網路之複數個第一蒸餾特徵層輸出;根據一第一損失函數與該複數個第一蒸餾特徵層輸出,來計算一第一損失值;以及依據該第一損失值來進行反向傳播,以調整該模型的參數。

Description

模型訓練方法與模型訓練系統
本發明係有關於模型訓練,尤指一種可透過未標註資料及蒸餾損失來對模型中的堆疊沙漏網路進行半監督訓練以及可基於硬體平台的模型部署需求來適應性地對模型中的堆疊沙漏網路進行裁剪的方法與系統。
關於採用人工智慧技術來進行分類的應用(例如影像辨識),由於實際環境中光照、相機、遮檔等因素的影響和分類的多樣性,導致實際使用的模型進行訓練所需要的資料量會十分龐大,因而加重人工進行樣本標注的困難。另外,不同的硬體平台會有不同的模型部署需求(例如記憶體容量、運算能力以及演算即時性),故不同的硬體平台往往需要分別部署不同大小的模型。
因此,需要一種採用半監督學習機制且能夠根據硬體平台的模型部署需求來適應性地進行裁剪的模型以及相關的模型訓練方法與系統。
因此,本發明的目的之一在於提出一種可透過未標註資料及蒸餾損失來對模型中的堆疊沙漏網路進行半監督訓練以及可基於硬體平台的模型部署需求來適應性地對模型中的堆疊沙漏網路進行裁剪的方法與系統。
在本發明的一實施例中,揭露一種模型訓練方法。該模型訓練方法 包含:將一未標注資料輸入至具有一深度神經網路架構之一模型,其中該模型至少包含一堆疊沙漏網路,該堆疊沙漏網路包含複數個沙漏網路;針對該未標注資料,分別得到該複數個沙漏網路之複數個第一蒸餾特徵層輸出;根據一第一損失函數與該複數個第一蒸餾特徵層輸出,來計算一第一損失值;以及依據該第一損失值來進行反向傳播,以調整該模型的參數。
在本發明的另一實施例中,揭露一種模型訓練系統。該模型訓練系統包含一儲存裝置以及一處理器。該儲存裝置用以儲存一程式碼。該處理器用以載入並執行該程式碼,以執行以下操作:將一未標注資料輸入至具有一深度神經網路架構之一模型,其中該模型至少包含一堆疊沙漏網路,該堆疊沙漏網路包含複數個沙漏網路;針對該未標注資料,分別得到該複數個沙漏網路之複數個第一蒸餾特徵層輸出;根據一第一損失函數與該複數個第一蒸餾特徵層輸出,來計算一第一損失值;以及依據該第一損失值來進行反向傳播,以調整該模型的參數。
本發明模型訓練方法與系統所採用的半監督學習機制可使用老師-學生知識蒸餾的機制,因此可以有效地減少訓練樣本的標注需求,節省大量人力與時間。另外,本發明模型訓練方法與系統可根據硬體平台的模型部署需求,適應性地裁剪所訓練之模型中的堆疊沙漏網路以產生部署至硬體平台之固化模型,而不需要重新訓練。
100:模型訓練系統
102:處理器
104:儲存裝置
110:硬體平台
202:調整網路
204:堆疊沙漏網路
206_1,206_2,206_3,206_4:沙漏網路
208:Maxout網路
210:逐元素最大值函數層
300:CBL模組
302:卷積層
304:批標準化層
306:帶泄漏修正線性單元層
402,404,406,408,412,414,416,418,421,422,423,424,425,426:模組
Code_TF:程式碼
MD:模型
MD_F:固化模型
D_IN:訓練資料
D_L:標注資料
D_UL:未標注資料
A1,A2,A3,A4:蒸餾特徵層輸出
L1,L2,L3,L4:分類輸出
L_OUT:最終分類輸出
第1圖為根據本發明一實施例之模型訓練系統的示意圖。
第2圖為根據本發明一實施例之模型的示意圖。
第3圖為根據本發明一實施例之CBL模組的示意圖。
第4圖為根據本發明一實施例之沙漏網路結構的示意圖。
第1圖為根據本發明一實施例之模型訓練系統的示意圖。如第1圖所示,模型訓練系統100包含(但不限於)一處理器102以及一儲存裝置104。儲存裝置104用以儲存一程式碼Code_TF,例如儲存裝置104可以是傳統硬碟、固態硬碟、記憶體等等,但本發明並不以此為限。處理器102可載入並執行程式碼Code_TF,來實現本發明模型訓練方法。根據本發明模型訓練方法,會對具有一深度神經網路(deep neural network)架構之一模型MD進行訓練,並且於模型MD訓練達到一定程度之後,對模型MD進行固化(freeze)來產生一固化模型(frozen model)MD_F,以便後續將固化模型MD_F部署至一硬體平台110,例如硬體平台110可以是手機、邊緣裝置(edge device)等等。本實施例中,模型訓練方法會採用半監督學習機制,換言之,模型MD會根據訓練資料D_IN(例如是影像資料)來進行半監督訓練,因此,訓練資料D_IN中少量資料是以人工方式進行標注的標注資料(labeled data)D_L,而訓練資料D_IN中大量資料則是未標注資料(unlabeled data)D_UL。此外,模型訓練方法另會參照硬體平台110的模型部署需求(例如記憶體容量、運算能力及/或演算即時性),適應性地(adaptively)裁剪所訓練之模型MD以產生部署至硬體平台110的固化模型MD_F。
第2圖為根據本發明一實施例之模型的示意圖。本實施例中,利用訓練資料D_IN來進行半監督訓練的模型MD可採用第2圖所示的網路架構。如第2圖所示,模型MD包含(但不限於)一調整網路(resize network)202、一堆疊沙漏網路(stacked hourglass network)204、一Maxout網路208以及一逐元素最大值 (element-wise maximum)函數層210。調整網路202是用來進行維度調整,例如降低調整網路202最後輸出的每一通道(channel)的特徵圖(feature map)的大小,假設輸入影像的尺寸為W*H(單位是像素),則可透過調整網路202來將輸入變成(W/s)*(H/s),例如s=4。
本實施例中,調整網路202可根據設計需求而由一個或多個CBL模組構成。第3圖為根據本發明一實施例之CBL模組的示意圖。CBL模組300包含(但不限於)一卷積(convolution)層302、一批標準化(batch normalization,BN)層304以及一帶泄漏修正線性單元(Leaky Rectified Linear Unit,Leaky ReLU)層306。舉例來說,卷積層302可採用步長(stride)為2,假若輸入影像的尺寸為W*H,則每經過一個卷積層302的處理,便可分別讓W與H減半(亦即W/2與H/2)。批標準化層304可以加快模型的訓練速度,並讓訓練更加穩定。帶泄漏修正線性單元層306是啟動函數層,用以保留正值,並將負值替換為0。
調整網路202可以根據輸入的大小和網路的輸出個數來調節CBL模組的個數。假若輸入為640*360以及檢測目標為8*8的大小,則調整網路202可以使用串接的3個CBL模組300(卷積層302的步長為2),來得到80x44的特徵圖大小。假若使用更小的輸入320*180,此時檢測目標的大小僅有4*4,則調整網路202可以使用串接的2個CBL模組300(卷積層302的步長為2),來得到80x44的特徵圖大小。
堆疊沙漏網路204由多個沙漏網路串接構成,如第2圖所示,本實施例所採用的堆疊沙漏網路204包含4個沙漏網路206_1、206_2、206_3、206_4,其中每個沙漏網路具有相同的網路結構。請注意,堆疊沙漏網路204由4個沙漏 網路構成僅作為範例說明之用,並非作為本發明的限制,實作上,堆疊沙漏網路204可由K個沙漏網路構成,其中K可以是任何不小於2的正整數(亦即K≧2)。 第4圖為根據本發明一實施例之沙漏網路結構的示意圖。於一實施方式中,每一個沙漏網路206_1、206_2、206_3、206_4可採用第4圖所示的網路結構。複數個模組402、404、406、408中的每一個模組代表對輸入特徵圖進行降採樣(down-sample),使得輸出特徵圖的大小會小於輸入特徵圖的大小。複數個模組412、414、416、418中的每一個模組代表對輸入特徵圖進行升採樣(up-sample),使得輸出特徵圖的大小會大於輸入特徵圖的大小。另外,複數個模組421、422、423、424、425、426中的每一個模組則代表輸出特徵圖與輸入特徵圖會具有相同大小。
藉由調整網路202可以得到需要的特徵圖大小,亦即透過CBL模組的適當設計,每一通道(每一類別)的特徵圖會具有所要的大小,為了得到更好的語義資訊,沙漏網路可以繼續降維並且將語義資訊較好的低維特徵融合到最終的輸出特徵圖,如第4圖所示,沙漏網路的輸出特徵圖與輸入特徵圖會具有相同大小,因此理論上可以無限堆疊,這裡的堆疊並不影響輸出的物理意義,但是透過堆疊可以提高模型輸出的準確度。由於堆疊沙漏網路204中串接的沙漏網路越多,則特徵偵測的效果會越好,因此,本發明模型訓練方法所採用的半監督學習機制可使用老師-學生知識蒸餾(teacher-student knowledge distillation),將最後一個沙漏網路206_4作為老師,來指導前面多個沙漏網路206_1、206_2、206_3的學習,換言之,最後一個沙漏網路206_4的輸出結果可作為前面多個沙漏網路206_1、206_2、206_3的目標以進行訓練。
沙漏網路206_1、206_2、206_3、206_4會分別產生蒸餾特徵層輸出 A1、A2、A3、A4,其中每個蒸餾特徵層輸出包含多個通道(多個類別)的相對應特徵圖。Maxout網路208會對蒸餾特徵層輸出A1、A2、A3、A4進行處理來分別產生分類輸出L1、L2、L3、L4,Maxout網路208可視為啟動函數層,它的輸出是一組輸入的最大值,亦即,蒸餾特徵層輸出A1經由啟動函數(亦即Maxout函數)而得到分類輸出L1,蒸餾特徵層輸出A2經由啟動函數(亦即Maxout函數)而得到分類輸出L2,蒸餾特徵層輸出A3經由啟動函數(亦即Maxout函數)而得到分類輸出L3,以及蒸餾特徵層輸出A4經由啟動函數(亦即Maxout函數)而得到分類輸出L4。假設每個蒸餾特徵層輸出包含m個通道(m個類別)的相對應特徵圖,透過特徵圖的Maxout函數處理,則每一個分類輸出會包含分別對應至m個通道(m個類別)的預測概率。逐元素最大值函數層210則是用來針對分類輸出L1、L2、L3、L4進行逐元素取最大值的操作,以產生一最終分類輸出L_OUT,舉例來說,最終分類輸出L_OUT中對應至第1個通道(第1個類別)的預測概率是取分類輸出L1、L2、L3、L4中對應至第1個通道(第1個類別)的所有預測概率
Figure 111100068-A0305-02-0008-7
中的最大值,最終分類輸出L_OUT中對應至第2個通道(第2個類別)的預測概率
Figure 111100068-A0305-02-0008-8
是取分類輸出L1、L2、L3、L4中對應至第2個通道(第2個類別)的所有預測概率
Figure 111100068-A0305-02-0008-9
中的最大值,最終分類輸出L_OUT中對應至第3個通道(第3個類別)的預測概率
Figure 111100068-A0305-02-0008-10
是取分類輸出L1、L2、L3、L4中對應至第3個通道(第3個類別)的所有預測概率
Figure 111100068-A0305-02-0008-11
中的最大值,以此類推。在一實施例中,可以採用最終分類輸出L_OUT中預測概率最大的類別做為輸入影像的分類結果。
如前所述,本發明模型訓練方法採用半監督學習機制來對模型MD進行訓練,因此,會先使用人工建立的少量標注資料D_L來對模型MD進行訓練,當模型MD訓練到一定準確率之後,再透過大量的未標注資料D_L來對模型MD進行訓練。當利用標注資料D_L來對模型MD進行訓練時,本發明模型訓練方法 採用兩種損失函數(loss function)來計算反向傳播(back propagation)所要使用的損失值L,於本實施例中,反向傳播所要使用損失值L主要由兩個部分構成,分別是分類損失值Lclassify以及特徵學習損失值Ldistillation
計算分類損失值Lclassify的損失函數是採用交叉熵(cross entropy)損失函數,並基於分類輸出Li(例如第2圖所示的分類輸出L1、L2、L3、L4)來決定分類損失值Lclassify,如下所示:
Figure 111100068-A0305-02-0009-5
其中M為類別的數量(通道的數量);N為沙漏網路的個數,y ic 為樣本i在c類別上的標籤(label),P ic 為樣本i屬於類別c的預測概率。預測概率就是網路學習的分類分數,例如得分在0.5以上就可以認為是某個類別。
計算特徵學習損失值Ldistillation的損失函數是採用蒸餾損失函數(例如L2損失函數),並基於蒸餾特徵層輸出Ai(例如第2圖所示的蒸餾特徵層輸出A1、A2、A3、A4)來決定特徵學習損失值Ldistillation,如下所示:
Figure 111100068-A0305-02-0009-4
其中D表示平方和函數;An表示n個沙漏網路的蒸餾特徵層輸出,例如N=4時,Ldistillation=[F(A1)-F(A4)]2+[F(A2)-F(A4)]2+[F(A3)-F(A4)]2A ni A n 的第i個通道;S為spatial softmax函數。所有的操作均為逐元素操作,亦即,M個通道輸出的絕對值平方之後相加,並進行空間唯獨上的softmax運算,舉例來說,透過spatial softmax函數,可以得到同一蒸餾特徵層中每個通道的預測概率。
最終損失值L是由分類損失值Lclassify以及特徵學習損失值Ldistillation所組成,如下所示:L=L classify +γL distillation
其中γ為權重係數,一開始可以先將權重係數γ設為1,等到模型MD充分訓練後,可以再根據分類損失值Lclassify以及特徵學習損失值Ldistillation之間的比值來調整權重係數γ,使得分類損失值Lclassify以及特徵學習損失值Ldistillation能維持在同一數量級,例如權重係數γ最終可設為0.2。
當改用未標注資料D_UL來對模型MD進行訓練時,由於缺乏人工建立的標籤,本發明模型訓練方法僅採用單一損失函數來計算反向傳播所需使用的損失值L,本實施例中,反向傳播所需使用的損失值L僅由特徵學習損失值Ldistillation構成,而不包含分類損失值Lclassify,此時,通過最後一個沙漏網路206_4的輸出結果來作為真實標籤,並採用特徵學習損失值Ldistillation(亦即以最後一個沙漏網路206_4的輸出結果作為前面多個沙漏網路206_1、206_2、206_3的目標來計算損失),來指導前面多個沙漏網路206_1、206_2、206_3的學習,達到半監督訓練/學習的效果。
每次訓練時,本發明模型訓練方法可採用Adam優化演算法(Adam optimization)來進行反向傳播,以調整模型MD的參數(例如每一層的權重值),舉例來說,當使用未標注資料D_UL來對模型MD進行訓練時,可判斷特徵學習損失值Ldistillation與一閾值(threshold)的關係,如果大於閾值,則將特徵學習損失值Ldistillation的大小回傳至模型MD中以調整F(An)值的大小,如此不斷迴圈,直到特徵學習損失值Ldistillation小於閾值為止。
堆疊沙漏網路204是由多個沙漏網路206_1、206_2、206_3、206_4串接而構成,本實施例中,這些沙漏網路206_1、206_2、206_3、206_4是由相同的損失函數來進行訓練,因此,本發明模型訓練方法可直接透過剪裁來獲得不同大小的模型,例如控制CBL模組及/或沙漏網路的數量,自我調整模型大小,進而適應具有不同模型部署需求的硬體平台。
假設硬體平台110具有第一模型部署需求(例如記憶體容量M1、運算能力P1以/或演算即時性R1),則在對模型MD進行固化來產生部署至硬體平台110的固化模型MD_F時,可以將沙漏網路206_2、206_3、206_4去掉,只取透過沙漏網路206_1所得到的分類輸出L1,請注意,由於僅有一個分類輸出L1,因此,固化模型MD_F另可省略逐元素最大值函數層210,故硬體平台110實際執行固化模型MD_F所定義的分類網路時,分類輸出L1會直接作為最終分類輸出L_OUT。
假設硬體平台110具有第二模型部署需求(例如記憶體容量M2(M2>M1)、運算能力P2(P2>P1)以/或演算即時性R2(R2<R1)),則在對模型MD進行固化來產生部署至硬體平台110的固化模型MD_F時,可以將沙漏網路206_3、206_4去掉,只取透過沙漏網路206_1、206_2所得到的分類輸出L1、L2,硬體平台110實際執行固化模型MD_F所定義的分類網路時,分類輸出L1、L2後續可透過逐元素最大值函數層210來得到最終分類輸出L_OUT。
假設硬體平台110具有第三模型部署需求(例如記憶體容量M3(M3>M2)、運算能力P3(P3>P2)以/或演算即時性R3(R3<R2)),則在對模型MD進行固化來產生部署至硬體平台110的固化模型MD_F時,可以將沙漏網路206_4 去掉,只取透過沙漏網路206_1、206_2、206_3所得到的分類輸出L1、L2、L3,硬體平台110實際執行固化模型MD_F所定義的分類網路時,分類輸出L1、L2、L3後續可透過逐元素最大值函數層210來得到最終分類輸出L_OUT。
綜上所述,本發明模型訓練方法與系統所採用的半監督學習機制可使用老師-學生知識蒸餾的機制,因此可以有效地減少訓練樣本的標注需求,節省大量人力與時間。另外,本發明模型訓練方法與系統可根據硬體平台的模型部署需求,適應性地裁剪所訓練之模型中的堆疊沙漏網路以產生部署至硬體平台之固化模型,而不需要重新訓練。
以上所述僅為本發明之較佳實施例,凡依本發明申請專利範圍所做之均等變化與修飾,皆應屬本發明之涵蓋範圍。
202:調整網路
204:堆疊沙漏網路
206_1,206_2,206_3,206_4:沙漏網路
208:Maxout網路
210:逐元素最大值函數層
MD:模型
D_IN:訓練資料
A1,A2,A3,A4:蒸餾特徵層輸出
L1,L2,L3,L4:分類輸出
L_OUT:最終分類輸出

Claims (8)

  1. 一種模型訓練方法,包含:將一未標注資料輸入至具有一深度神經網路架構之一模型,其中該模型至少包含一堆疊沙漏網路,該堆疊沙漏網路包含複數個沙漏網路;針對該未標注資料,分別得到該複數個沙漏網路之複數個第一蒸餾特徵層輸出;根據一第一損失函數與該複數個第一蒸餾特徵層輸出,來計算一第一損失值;依據該第一損失值來進行反向傳播,以調整該模型的參數;於該未標注資料輸入至該模型之前:將一標注資料輸入至該模型;針對該標注資料,分別得到該複數個沙漏網路之複數個第二蒸餾特徵層輸出;透過該複數個第二蒸餾特徵層輸出,來分別得到複數個分類輸出;根據該第一損失函數與該複數個第二蒸餾特徵層輸出,來計算一第二損失值;根據一第二損失函數與該複數個分類輸出,來計算一第三損失值,其中該第二損失函數不同於該第一損失函數;依據該第二損失值與該第三損失值來計算一第四損失值;以及依據該第四損失值來進行反向傳播,以調整該模型的參數。
  2. 如請求項1所述之模型訓練方法,其中該第一損失函數採用L2損失函數,其以該複數個沙漏網路中最後一個沙漏網路所產生之第一蒸餾特徵層輸出作為該複數個沙漏網路中其它沙漏網路所產生之第一蒸餾特徵層輸出的 目標來計算該第一損失值。
  3. 如請求項1所述之模型訓練方法,其中該第二損失函數採用交叉熵損失函數。
  4. 如請求項1所述之模型訓練方法,另包含:參照一硬體平台的模型部署需求,適應性地裁剪所訓練之該模型中的該堆疊沙漏網路以產生部署至該硬體平台之一固化模型;其中該固化模型所包含之沙漏網路的個數取決於該硬體平台的模型部署需求。
  5. 一種模型訓練系統,包含:一儲存裝置,用以儲存一程式碼;以及一處理器,用以載入並執行該程式碼,以執行以下操作:將一未標注資料輸入至具有一深度神經網路架構之一模型,其中該模型至少包含一堆疊沙漏網路,該堆疊沙漏網路包含複數個沙漏網路;針對該未標注資料,分別得到該複數個沙漏網路之複數個第一蒸餾特徵層輸出;根據一第一損失函數與該複數個第一蒸餾特徵層輸出,來計算一第一損失值;依據該第一損失值來進行反向傳播,以調整該模型的參數;於該未標注資料輸入至該模型之前:將一標注資料輸入至該模型;針對該標注資料,分別得到該複數個沙漏網路之複數個第二蒸餾特 徵層輸出;透過該複數個第二蒸餾特徵層輸出,來分別得到複數個分類輸出;根據該第一損失函數與該複數個第二蒸餾特徵層輸出,來計算一第二損失值;根據一第二損失函數與該複數個分類輸出,來計算一第三損失值,其中該第二損失函數不同於該第一損失函數;依據該第二損失值與該第三損失值來計算一第四損失值;以及依據該第四損失值來進行反向傳播,以調整該模型的參數。
  6. 如請求項5所述之模型訓練系統,其中該第一損失函數採用L2損失函數,其以該複數個沙漏網路中最後一個沙漏網路所產生之第一蒸餾特徵層輸出作為該複數個沙漏網路中其它沙漏網路所產生之第一蒸餾特徵層輸出的目標來計算該第一損失值。
  7. 如請求項5所述之模型訓練系統,其中該第二損失函數採用交叉熵損失函數。
  8. 如請求項5所述之模型訓練系統,其中該處理器另執行該程式碼,以執行以下操作:參照一硬體平台的模型部署需求,適應性地裁剪所訓練之該模型中的該堆疊沙漏網路以產生部署至該硬體平台之一固化模型;其中該固化模型所包含之沙漏網路的個數取決於該硬體平台的模型部署需求。
TW111100068A 2021-11-24 2022-01-03 模型訓練方法與模型訓練系統 TWI793951B (zh)

Applications Claiming Priority (2)

Application Number Priority Date Filing Date Title
CN202111403181.9A CN113919499A (zh) 2021-11-24 2021-11-24 模型训练方法与模型训练系统
CN202111403181.9 2021-11-24

Publications (2)

Publication Number Publication Date
TWI793951B true TWI793951B (zh) 2023-02-21
TW202321993A TW202321993A (zh) 2023-06-01

Family

ID=79247901

Family Applications (1)

Application Number Title Priority Date Filing Date
TW111100068A TWI793951B (zh) 2021-11-24 2022-01-03 模型訓練方法與模型訓練系統

Country Status (2)

Country Link
CN (1) CN113919499A (zh)
TW (1) TWI793951B (zh)

Citations (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN110298437A (zh) * 2019-06-28 2019-10-01 Oppo广东移动通信有限公司 神经网络的分割计算方法、装置、存储介质及移动终端
CN112579777A (zh) * 2020-12-23 2021-03-30 华南理工大学 一种未标注文本的半监督分类方法
CN112949786A (zh) * 2021-05-17 2021-06-11 腾讯科技(深圳)有限公司 数据分类识别方法、装置、设备及可读存储介质
TW202131219A (zh) * 2020-02-12 2021-08-16 大陸商深圳市商湯科技有限公司 圖像識別方法及圖像識別裝置、電子設備和電腦可讀儲存媒介

Patent Citations (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN110298437A (zh) * 2019-06-28 2019-10-01 Oppo广东移动通信有限公司 神经网络的分割计算方法、装置、存储介质及移动终端
TW202131219A (zh) * 2020-02-12 2021-08-16 大陸商深圳市商湯科技有限公司 圖像識別方法及圖像識別裝置、電子設備和電腦可讀儲存媒介
CN112579777A (zh) * 2020-12-23 2021-03-30 华南理工大学 一种未标注文本的半监督分类方法
CN112949786A (zh) * 2021-05-17 2021-06-11 腾讯科技(深圳)有限公司 数据分类识别方法、装置、设备及可读存储介质

Also Published As

Publication number Publication date
TW202321993A (zh) 2023-06-01
CN113919499A (zh) 2022-01-11

Similar Documents

Publication Publication Date Title
US20210042580A1 (en) Model training method and apparatus for image recognition, network device, and storage medium
Zhang et al. A return-cost-based binary firefly algorithm for feature selection
CN108520206B (zh) 一种基于全卷积神经网络的真菌显微图像识别方法
CN108846826B (zh) 物体检测方法、装置、图像处理设备及存储介质
US10699192B1 (en) Method for optimizing hyperparameters of auto-labeling device which auto-labels training images for use in deep learning network to analyze images with high precision, and optimizing device using the same
WO2019223250A1 (zh) 一种确定剪枝阈值的方法及装置、模型剪枝方法及装置
CN107729999A (zh) 考虑矩阵相关性的深度神经网络压缩方法
CN102567742A (zh) 一种基于自适应核函数选择的支持向量机自动分类方法
WO2022217853A1 (en) Methods, devices and media for improving knowledge distillation using intermediate representations
CN107392919A (zh) 基于自适应遗传算法的灰度阈值获取方法、图像分割方法
CN113837376B (zh) 基于动态编码卷积核融合的神经网络剪枝方法
CN112488209A (zh) 一种基于半监督学习的增量式图片分类方法
Pietron et al. Retrain or not retrain?-efficient pruning methods of deep cnn networks
CN111626328B (zh) 一种基于轻量化深度神经网络的图像识别方法及装置
CN111695640A (zh) 地基云图识别模型训练方法及地基云图识别方法
CN116051388A (zh) 经由语言请求的自动照片编辑
CN112598062A (zh) 一种图像识别方法和装置
CN113205103A (zh) 一种轻量级的文身检测方法
CN112380917A (zh) 一种用于农作物病虫害检测的无人机
CN112597919A (zh) 基于YOLOv3剪枝网络和嵌入式开发板的实时药盒检测方法
CN109284378A (zh) 一种面向知识图谱的关系分类方法
JP6935868B2 (ja) 画像認識装置、画像認識方法、およびプログラム
TWI793951B (zh) 模型訓練方法與模型訓練系統
CN114386565A (zh) 提供神经网络
Gordienko et al. Adaptive iterative pruning for accelerating deep neural networks