WO2021176734A1 - 学習装置、学習方法、学習プログラム、推定装置、推定方法及び推定プログラム - Google Patents
学習装置、学習方法、学習プログラム、推定装置、推定方法及び推定プログラム Download PDFInfo
- Publication number
- WO2021176734A1 WO2021176734A1 PCT/JP2020/009878 JP2020009878W WO2021176734A1 WO 2021176734 A1 WO2021176734 A1 WO 2021176734A1 JP 2020009878 W JP2020009878 W JP 2020009878W WO 2021176734 A1 WO2021176734 A1 WO 2021176734A1
- Authority
- WO
- WIPO (PCT)
- Prior art keywords
- model
- estimation
- estimation result
- learning
- lightweight
- Prior art date
Links
Images
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N20/00—Machine learning
- G06N20/20—Ensemble learning
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/217—Validation; Performance evaluation; Active pattern learning techniques
- G06F18/2178—Validation; Performance evaluation; Active pattern learning techniques based on feedback of a supervisor
- G06F18/2185—Validation; Performance evaluation; Active pattern learning techniques based on feedback of a supervisor the supervisor being an automated module, e.g. intelligent oracle
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/21—Design or setup of recognition systems or techniques; Extraction of features in feature space; Blind source separation
- G06F18/217—Validation; Performance evaluation; Active pattern learning techniques
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/24—Classification techniques
- G06F18/241—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches
- G06F18/2413—Classification techniques relating to the classification model, e.g. parametric or non-parametric approaches based on distances to training or reference patterns
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F18/00—Pattern recognition
- G06F18/20—Analysing
- G06F18/285—Selection of pattern recognition techniques, e.g. of classifiers in a multi-classifier system
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/045—Combinations of networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/04—Architecture, e.g. interconnection topology
- G06N3/0495—Quantised networks; Sparse networks; Compressed networks
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/084—Backpropagation, e.g. using gradient descent
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/06—Physical realisation, i.e. hardware implementation of neural networks, neurons or parts of neurons
- G06N3/063—Physical realisation, i.e. hardware implementation of neural networks, neurons or parts of neurons using electronic means
Definitions
- the present invention relates to a learning device, a learning method, a learning program, an estimation device, an estimation method, and an estimation program.
- DNN deep neural network
- Such real-time applications are required to process a large number of queries in real time with limited resources while maintaining the accuracy of DNN. Therefore, a technique called a model cascade has been proposed, which can speed up inference processing with less deterioration in accuracy by using a lightweight model with high speed and low accuracy and a high precision model with low speed and high accuracy.
- model cascade multiple models including a lightweight model and a high-precision model are used.
- estimation is first performed with a lightweight model, and if the result is credible, the result is adopted and processing is terminated. On the other hand, if the estimation result of the lightweight model is unreliable, the inference is continuously performed by the high-precision model and the result is adopted.
- IDK Cascade see, for example, Non-Patent Document 1 in which an IDK (I Don't Know) classifier is introduced to determine whether or not the estimation result of the lightweight model can be trusted is known.
- the conventional model cascade has a problem that calculation cost and calculation resource overhead may occur.
- the number of models is increased by one, resulting in computational cost and computational resource overhead.
- the learning device inputs the training data into the first model that outputs the estimation result based on the input data, and acquires the first estimation result.
- a model that outputs the estimation result based on the estimation unit, the correctness and certainty of the first estimation result, and the input data, and the processing speed is slower than that of the first model, or the first
- the first model and the second model are included based on the correctness of the second estimation result obtained by inputting the training data into the second model having higher estimation accuracy than the model of. It is characterized by having an update unit that updates the parameters of the first model so that the model cascade is optimized.
- FIG. 1 is a diagram illustrating a model cascade.
- FIG. 2 is a diagram showing a configuration example of the learning device according to the first embodiment.
- FIG. 3 is a diagram showing an example of loss for each case.
- FIG. 4 is a flowchart showing the flow of learning processing of the high-precision model.
- FIG. 5 is a flowchart showing the flow of learning processing of the lightweight model.
- FIG. 6 is a diagram showing a configuration example of the estimation system according to the second embodiment.
- FIG. 7 is a flowchart showing the flow of the estimation process.
- FIG. 8 is a diagram showing the experimental results.
- FIG. 9 is a diagram showing the experimental results.
- FIG. 10 is a diagram showing the experimental results.
- FIG. 11 is a diagram showing the experimental results.
- FIG. 1 is a diagram illustrating a model cascade.
- FIG. 2 is a diagram showing a configuration example of the learning device according to the first embodiment.
- FIG. 3 is a diagram showing
- FIG. 12 is a diagram showing the experimental results.
- FIG. 13 is a diagram showing a configuration example of the estimation device according to the third embodiment.
- FIG. 14 is a diagram illustrating a model cascade including three or more models.
- FIG. 15 is a flowchart showing the flow of learning processing of three or more models.
- FIG. 16 is a flowchart showing the flow of estimation processing by three or more models.
- FIG. 17 is a diagram showing an example of a computer that executes a learning program.
- the learning device the learning method, the learning program, the estimation device, the estimation method, and the embodiments of the estimation program according to the present application will be described in detail based on the drawings.
- the present invention is not limited to the embodiments described below.
- the learning device learns a high-precision model and a lightweight model using the input learning data. Then, the learning device outputs information about the trained high-precision model and information about the trained lightweight model. For example, the learning device outputs the parameters required to build each model.
- the high-precision model and the lightweight model are models that output estimation results based on the input data.
- the high-precision model and the lightweight model are multi-class classification models that take an image as an input and estimate the probability of each class of the object appearing in the image.
- the high-precision model and the lightweight model are not limited to such a multi-class classification model, and may be any model to which machine learning can be applied.
- the high-precision model has a slower processing speed and higher estimation accuracy than the lightweight model.
- the high-precision model may be known to have a slower processing speed than the lightweight model. In this case, the high-precision model is expected to have higher estimation accuracy than the lightweight model. Further, the high-precision model may be known to have higher estimation accuracy than the lightweight model. In this case, the lightweight model is expected to be faster than the high-precision model.
- FIG. 1 is a diagram illustrating a model cascade.
- the lightweight model outputs the probabilities of each class for the objects appearing in the input image. For example, the lightweight model outputs the probability that the object in the image is cat as about 0.5. In addition, the lightweight model outputs the probability that the object in the image is a dog as about 0.35.
- the estimation result is adopted. That is, the estimation result of the lightweight model is output as the final estimation result of the model cascade.
- the estimation result obtained by inputting the same image into the high-precision model is output as the final estimation result of the model cascade.
- the high-precision model like the lightweight model, outputs the probabilities of each class for the objects appearing in the input image. For example, the condition is that the maximum value of the probability output by the lightweight model exceeds the threshold.
- the high-precision model is ResNet18, which runs on a server or the like.
- the lightweight model is MobileNet V2, which operates on IoT devices and various terminal devices.
- the high-precision model and the lightweight model may operate on the same computer.
- FIG. 2 is a diagram showing a configuration example of the learning device according to the first embodiment.
- the learning device 10 accepts the input of the learning data and outputs the trained high-precision model information and the trained lightweight model information. Further, the learning device 10 has a high-precision model learning unit 11 and a lightweight model learning unit 12.
- the high-precision model learning unit 11 has an estimation unit 111, a loss calculation unit 112, and an update unit 113. Further, the high-precision model learning unit 11 stores the high-precision model information 114.
- the high-precision model information 114 is information such as parameters for constructing a high-precision model. It is assumed that the training data has a known label. For example, the training data is a combination of an image and a label (correct class).
- the estimation unit 111 inputs learning data into the high-precision model constructed based on the high-precision model information 114, and acquires the estimation result.
- the estimation unit 111 receives the input of the learning data and outputs the estimation result.
- the loss calculation unit 112 calculates the loss based on the estimation result acquired by the estimation unit 111.
- the loss calculation unit 112 receives the input of the estimation result and the label, and outputs the loss.
- the loss calculation unit 112 calculates the loss so that the smaller the certainty of the label is, the larger the loss is in the estimation result acquired by the estimation unit 111.
- certainty is the degree of certainty that the estimation result is correct.
- the certainty may be the probability output by the above-mentioned multiclass classification model.
- the loss calculation unit 112 can calculate the softmax cross entropy described later as a loss.
- the update unit 113 updates the parameters of the high-precision model so that the loss is optimized. For example, if the high-precision model is a neural network, the update unit 113 updates the parameters of the high-precision model by an error backpropagation method or the like. Specifically, the update unit 113 updates the high-precision model information 114. The update unit 113 receives the input of the loss calculated by the loss calculation unit 112, and outputs the updated model information.
- the lightweight model learning unit 12 has an estimation unit 121, a loss calculation unit 122, and an update unit 123. Further, the lightweight model learning unit 12 stores the lightweight model information 124.
- the lightweight model information 124 is information such as parameters for constructing a lightweight model.
- the estimation unit 121 inputs learning data into the lightweight model constructed based on the lightweight model information 124, and acquires the estimation result.
- the estimation unit 121 receives the input of the learning data and outputs the estimation result.
- the high-precision model learning unit 11 learns the high-precision model based on the output of the high-precision model.
- the lightweight model learning unit 12 learns the lightweight model based on the outputs of both the high-precision model and the lightweight model.
- the loss calculation unit 122 calculates the loss based on the estimation result acquired by the estimation unit.
- the loss calculation unit 122 receives the estimation result by the high-precision model, the estimation result by the lightweight model, and the input of the label, and outputs the loss.
- the estimation result by the high-precision model may be an estimation result obtained by further inputting learning data into the high-precision model after learning by the high-precision model learning unit 11.
- the lightweight model learning unit 12 accepts an input as to whether or not the estimation result by the high-precision model is correct. For example, if the class with the highest probability output by the high-precision model matches the label, the estimation result is correct.
- the loss calculation unit 122 calculates the loss for the purpose of maximizing the profit when the model cascade is configured, in addition to maximizing the estimation accuracy of the lightweight model alone.
- the profit increases as the estimation accuracy increases, and increases as the calculation cost decreases.
- the high-precision model has a feature that the estimation accuracy is high but the calculation cost is high.
- the lightweight model is characterized in that the estimation accuracy is low but the calculation cost is low. Therefore, the loss calculation unit 122 calculates the loss Loss as in Eq. (1).
- w is a weight, which is a preset parameter.
- L classifier is the softmax entropy in the multi-class classification model.
- the L classifier is an example of the first term in which the smaller the certainty of the correct answer in the estimation result by the lightweight model, the larger the certainty.
- the L classifier is expressed as in Eq. (2).
- N is the number of samples.
- k is the number of classes.
- y is a label representing the correct class.
- q is the probability output by the lightweight model.
- i is the number that identifies the sample.
- j is a number that identifies the class.
- the labels y i and j are 1 if the j-th class is correct and 0 if the j-th class is incorrect in the i-th sample.
- L cascade is a term for maximizing profit when constructing a model cascade.
- the L cascade represents the loss for each sample when the estimation results of the high-precision model and the lightweight model are adopted based on the certainty of the lightweight model.
- the loss includes a penalty for improper certainty and the cost of using a precision model.
- the loss is divided into four patterns according to the combination of whether or not the estimation result of the high-precision model is correct and whether or not the estimation result of the lightweight model is correct. Details will be described later, but if the estimation of the high-precision model is incorrect and the certainty of the lightweight model is low, the penalty will be large. On the other hand, if the estimation of the lightweight model is correct and the certainty of the lightweight model is high, the penalty is small.
- L cascade is expressed by Eq. (3).
- 1 fast is an indicator function that returns 0 if the estimation result of the lightweight model is correct and 1 if the estimation result of the lightweight model is incorrect.
- 1 acc is an indicator function that returns 0 if the estimation result of the high-precision model is correct and 1 if the estimation result of the high-precision model is incorrect.
- COST acc is the cost of making an estimate with a high-precision model and is a preset parameter.
- max j q i, j is the maximum value of the probability output by the lightweight model and is an example of certainty. If the estimation result is correct, it can be said that the higher the certainty, the higher the estimation accuracy. On the other hand, if the estimation result is incorrect, it can be said that the higher the certainty, the lower the estimation accuracy.
- Equation (3) max j q i, j 1 fast is an example of the second term, which increases as the certainty of the estimation result by the lightweight model increases when the estimation result by the lightweight model is incorrect.
- (1-max j q i, j ) 1 acc in Eq. (3) becomes larger as the certainty of the estimation result by the lightweight model is smaller when the estimation result by the high-precision model is incorrect. It is an example of the term.
- (1-max j q i, j ) COST acc in Eq. (3) is an example of the fourth term, which increases as the certainty of the estimation result by the lightweight model decreases. In this case, the minimization of the loss by the update unit 123 corresponds to the optimization of the loss.
- the update unit 123 updates the parameters of the lightweight model so that the loss is optimized. That is, the update unit 123 is a model that outputs the estimation result based on the estimation result by the lightweight model and the input data, and is a training data for a high-precision model that has a slower processing speed and a higher estimation accuracy than the lightweight model. Based on the estimation results obtained by inputting, the parameters of the lightweight model are updated so that the model cascade including the lightweight model and the high-precision model is optimized. The update unit 123 receives the input of the loss calculated by the loss calculation unit 122, and outputs the updated model information.
- FIG. 3 is a diagram showing an example of loss for each case.
- the vertical axis is the value of L cascade.
- the horizontal axis is the value of max j q i, j.
- set COST acc 0.5.
- max j q i, j is the certainty of the estimation result by the lightweight model, and is simply called the certainty here.
- “ ⁇ ” in FIG. 3 is the value of L cascade for the certainty when the estimation results of both the lightweight model and the high-precision model are correct.
- the higher the certainty the smaller the value of L cascade. This is because if the estimation result by the lightweight model is correct, the greater the certainty, the easier it is for the lightweight model to be adopted.
- “ ⁇ ” in FIG. 3 is the value of L cascade for the certainty when the estimation result of the lightweight model is correct and the estimation result of the high-precision model is incorrect.
- the higher the certainty the smaller the value of L cascade.
- the maximum value and the degree of decrease of L cascade are larger than those of " ⁇ ". This is because if the estimation result by the high-precision model is incorrect and the estimation result by the lightweight model is correct, the greater the certainty, the more likely it is that the lightweight model will be adopted.
- “ ⁇ ” in FIG. 3 is the value of L cascade for the certainty when the estimation result of the lightweight model is incorrect and the estimation result of the high-precision model is correct.
- the higher the certainty the larger the value of L cascade. This is because even if the estimation result of the lightweight model is incorrect, the smaller the certainty, the more difficult it is to adopt the estimation result.
- “ ⁇ ” in FIG. 3 is the value of L cascade for the certainty when the estimation results of both the lightweight model and the high-precision model are incorrect. In this case, the higher the certainty, the smaller the value of L cascade. However, the value of L cascade is larger than that of " ⁇ ". This is because the estimation results of both models are always incorrect and the loss is always large, and in such a situation, the lightweight model should be able to make an accurate estimation.
- FIG. 4 is a flowchart showing the flow of learning processing of the high-precision model.
- the estimation unit 111 estimates a class of learning data using a high-precision model (step S101).
- the loss calculation unit 112 calculates the loss based on the estimation result of the high-precision model (step S102). Then, the update unit 113 updates the parameters of the high-precision model so that the loss is optimized (step S103).
- the learning device 10 may repeat the processes from step S101 to step S103 until the end condition is satisfied.
- the end condition may be that the process is repeated a predetermined number of times, or that the update width of the parameter has converged.
- FIG. 5 is a flowchart showing the flow of learning processing of the lightweight model. As shown in FIG. 5, first, the estimation unit 121 estimates a class of learning data using a lightweight model (step S201).
- the loss calculation unit 122 calculates the loss based on the estimation result of the lightweight model, the estimation result of the high-precision model, and the cost of estimation by the high-precision model (step S202). Then, the update unit 123 updates the parameters of the lightweight model so that the loss is optimized (step S203). The learning device 10 may repeat the processes from step S201 to step S203 until the end condition is satisfied.
- the estimation unit 121 inputs the learning data to the lightweight model that outputs the estimation result based on the input data, and acquires the first estimation result.
- the update unit 123 is a model that outputs an estimation result based on the first estimation result and the input data, and is a training data for a high-precision model that has a slower processing speed and a higher estimation accuracy than the lightweight model. Based on the second estimation result obtained by inputting, the parameters of the lightweight model are updated so that the model cascade including the lightweight model and the high-precision model is optimized.
- the lightweight model in the model cascade composed of the lightweight model and the high-precision model, the lightweight model can perform estimation suitable for the model cascade without providing a model such as an IDK classifier. Therefore, the performance of the model cascade can be improved. As a result, according to the first embodiment, not only the accuracy of the model cascade can be improved, but also the calculation cost and the overhead of the calculation resource can be suppressed. Further, in the first embodiment, since the loss function is changed, it is not necessary to change the model architecture, and there is no limitation on the model to be applied and the optimization method.
- the update unit 123 has a first term that increases as the certainty of the correct answer in the first estimation result decreases, and increases as the certainty of the first estimation result increases when the first estimation result is incorrect.
- the parameters of the lightweight model are updated so that the loss calculated based on the loss function including the fourth term is minimized.
- the estimation system 2 includes a high-precision estimation device 20 and a lightweight estimation device 30. Further, the high-precision estimation device 20 and the lightweight estimation device 30 are connected via the network N.
- the network N is, for example, the Internet.
- the high-precision estimation device 20 may be a server provided in a cloud environment.
- the lightweight estimation device 30 may be an IoT device and various terminal devices.
- the high-precision estimation device 20 stores the high-precision model information 201.
- the high-precision model information 201 is information such as parameters of the trained high-precision model. Further, the high-precision estimation device 20 has an estimation unit 202.
- the estimation unit 202 inputs estimation data into the high-precision model constructed based on the high-precision model information 201, and acquires the estimation result.
- the estimation unit 202 receives the input of the estimation data and outputs the estimation result. It is assumed that the estimation data has an unknown label. For example, the estimation data is an image.
- the high-precision estimation device 20 and the lightweight estimation device 30 form a model cascade. Therefore, the estimation unit 202 does not always estimate the estimation data. When it is determined that the estimation result of the lightweight model is not adopted, the estimation unit 202 performs estimation by the high-precision model.
- the lightweight estimation device 30 stores the lightweight model information 301.
- the lightweight model information 301 is information such as parameters of the learned lightweight model. Further, the lightweight estimation device 30 has an estimation unit 302 and a determination unit 303.
- the estimation unit 302 is a model that outputs an estimation result obtained by inputting training data into a lightweight model that outputs an estimation result based on the input data and an estimation result based on the input data. , Pre-training so that the model cascade including the lightweight model and the high-precision model is optimized based on the estimation result obtained by inputting the training data into the high-precision model with higher estimation accuracy than the lightweight model.
- the estimation result is acquired by inputting the estimation data into the lightweight model in which the set parameters are set.
- the estimation unit 302 receives the input of the estimation data and outputs the estimation result.
- the determination unit 303 determines whether or not the estimation result by the lightweight model satisfies a predetermined condition regarding the estimation accuracy. For example, the determination unit 303 determines that the estimation result by the lightweight model satisfies the condition when the certainty level is equal to or higher than the threshold value. In that case, the estimation system 2 adopts the estimation result of the lightweight model.
- the estimation unit 202 of the high-precision estimation device 20 inputs the estimation data into the high-precision model and outputs the estimation result. get. In that case, the estimation system 2 adopts the estimation result of the high-precision model.
- FIG. 7 is a flowchart showing a flow of estimation processing. As shown in FIG. 7, first, the estimation unit 302 estimates a class of estimation data using a lightweight model (step S301).
- the determination unit 303 determines whether or not the estimation result satisfies the condition (step S302).
- the estimation system 2 outputs the estimation result of the lightweight model (step S303).
- the estimation unit 202 estimates the class of estimation data using the high-precision model (step S304). Then, the estimation system 2 outputs the estimation result of the high-precision model (step S305).
- the estimation unit 302 estimates based on the estimation result obtained by inputting the training data into the lightweight model that outputs the estimation result based on the input data and the input data.
- a model cascade that includes a lightweight model and a high-precision model based on the estimation results obtained by inputting training data into a high-precision model that outputs results and has higher estimation accuracy than the lightweight model.
- Data for estimation is input to a lightweight model in which parameters trained in advance to be optimized are set, and the estimation result is acquired.
- the determination unit 303 determines whether or not the estimation result by the lightweight model satisfies a predetermined condition regarding the estimation accuracy.
- the estimation unit 202 inputs the estimation data into the high-precision model and acquires the estimation result.
- the estimation unit 202 inputs the estimation data into the high-precision model and acquires the estimation result.
- the estimation system 2 can be expressed as follows. That is, the estimation system 2 has a high-precision estimation device 20 and a lightweight estimation device 30.
- the lightweight estimation device 30 is a model that outputs an estimation result obtained by inputting training data into a lightweight model that outputs an estimation result based on the input data and an estimation result based on the input data. Based on the estimation results obtained by inputting the training data into the high-precision model, which has a slower processing speed than the lightweight model or higher estimation accuracy than the lightweight model, the lightweight model and the high-precision model are selected.
- the estimation unit 302 which inputs data for estimation and acquires the first estimation result, and the first estimation result are in a lightweight model in which parameters trained in advance are set so that the including model cascade is optimized.
- the high-precision estimation device 20 inputs estimation data into the high-precision model and acquires a second estimation result. It has a unit 202. Further, the high-precision estimation device 20 may acquire estimation data from the lightweight estimation device 30.
- the estimation unit 202 performs estimation according to the estimation result by the lightweight estimation device 30. That is, the estimation unit 202 is a model that outputs the estimation result obtained by inputting the training data into the lightweight model that outputs the estimation result based on the input data and the estimation result based on the input data. Therefore, based on the estimation results obtained by inputting the training data to the high-precision model, which is slower than the lightweight model or has higher estimation accuracy than the lightweight model, the lightweight model and the high-precision model are selected. For estimation, according to the first estimation result acquired by the lightweight estimation device 30 by inputting data for estimation into a lightweight model in which parameters trained in advance are set so that the including model cascade is optimized. Input the data into the precision model to get the second estimation result.
- FIGS. 9 and 10 show the relationship between the number of offloads and the accuracy when the test data is estimated by adopting the threshold value that has the highest accuracy in the above validation data. From this, it can be seen that according to the second embodiment, the number of offloads is most reduced while maintaining the accuracy of the high-precision model.
- FIGS. 11 and 12 show the relationship between the number of offloads and the accuracy when the number of offloads is the smallest while maintaining the accuracy of the high-precision model in the test data. From this, it can be seen that the number of offloads is most reduced according to the second embodiment.
- FIG. 13 is a diagram showing a configuration example of the estimation device according to the third embodiment.
- the estimation device 2a has the same function as the estimation system 2 of the second embodiment.
- the high-precision estimation unit 20a has the same function as the high-precision estimation device 20 of the second embodiment.
- the lightweight estimation unit 30a has the same function as the lightweight estimation device 30 of the second embodiment. Unlike the second embodiment, since the estimation unit 202 and the determination unit 303 are in the same device, data exchange via the network does not occur in the estimation process.
- FIG. 14 is a diagram illustrating a model cascade including three or more models.
- M M> 3 models.
- the m + 1st model M-1 ⁇ m ⁇ 1
- the relationship between the m + 1st model and the mth model is the same as the relationship between the high-precision model and the lightweight model.
- the Mth model is the most accurate model
- the first model is the lightest model.
- the estimation process using three or more models can be realized by using the estimation system 2 described in the second embodiment.
- the estimation system 2 replaces the high-precision model information 201 with the information of the second model and the lightweight model information 301 with the information of the first model. Then, the estimation system 2 executes the same estimation process as in the second embodiment.
- the estimation system 2 replaces the high-precision model information 201 with the information of the third model.
- the lightweight model information 301 is replaced with the information of the second model, and the estimation process is further executed.
- the estimation system 2 repeats this process until an estimation result satisfying the conditions is obtained or the estimation process by the Mth model is completed. Note that the same processing can be realized only by the lightweight estimation device 30 by replacing the lightweight model information 301.
- the learning process of three or more models can be realized by using the learning device 10 described in the first embodiment.
- the learning device 10 extracts two models having consecutive numbers from the M models, and executes the learning process using the information of those models.
- the learning device 10 replaces the high-precision model information 114 with the information of the Mth model, and replaces the lightweight model information 124 with the information of the M-1st model.
- the learning device 10 executes the same learning process as in the first embodiment.
- the learning device 10 replaces the high-precision model information 114 with the information of the m-th model, replaces the lightweight model information 124 with the information of the m-1st model, and is the same as in the first embodiment. Executes the learning process of.
- FIG. 15 is a flowchart showing the flow of learning processing of three or more models.
- the learning device 10 of the first embodiment performs the learning process.
- the learning device 10 sets M as the initial value of m (step S401).
- the estimation unit 121 estimates a class of training data using the m-1st model (step S402).
- the loss calculation unit 122 calculates the loss based on the estimation result of the m-1st model, the estimation result of the mth model, and the estimation cost of the mth model (step S403). Then, the update unit 123 updates the parameters of the m-1st model so that the loss is optimized (step S404).
- the learning device 10 reduces m by 1 (step S405). When m reaches 1, the learning device 10 ends the process. On the other hand, when m has not reached 1 (step S406, No), the learning device 10 returns to step S402 and repeats the process.
- FIG. 16 is a flowchart showing the flow of estimation processing by three or more models.
- the lightweight estimation device 30 of the second embodiment performs the estimation process.
- the lightweight estimation device 30 sets 1 as the initial value of m (step S501).
- the estimation unit 302 estimates the class of estimation data using the m-th model (step S502).
- the determination unit 303 determines whether or not the estimation result satisfies the condition and whether or not m reaches M (step S503).
- the lightweight estimation device 30 outputs the estimation result of the m-th model (step S504).
- the estimation unit 302 causes the lightweight estimation device 30 to increase m by 1 (step S505), and step S502. Return to and repeat the process.
- the number of IDK classifiers increases, and the calculation cost and the overhead of calculation resources increase.
- the fourth embodiment even if the number of models constituting the model cascade is increased to three or more, the problem of increasing such overhead does not occur.
- each component of each of the illustrated devices is a functional concept, and does not necessarily have to be physically configured as shown in the figure. That is, the specific form of distribution and integration of each device is not limited to the one shown in the figure, and all or part of the device is functionally or physically dispersed or physically distributed in arbitrary units according to various loads and usage conditions. Can be integrated and configured. Further, each processing function performed by each device may be realized by a CPU and a program analyzed and executed by the CPU, or may be realized as hardware by wired logic.
- the learning device 10 and the lightweight estimation device 30 can be implemented by installing a program that executes the above learning process or estimation process as package software or online software on a desired computer.
- the information processing device can function as the learning device 10 or the lightweight estimation device 30.
- the information processing device referred to here includes a desktop type or notebook type personal computer.
- information processing devices include smartphones, mobile communication terminals such as mobile phones and PHS (Personal Handyphone System), and slate terminals such as PDAs (Personal Digital Assistants).
- the learning device 10 and the lightweight estimation device 30 can be implemented as a server device in which the terminal device used by the user is a client and the service related to the learning process or the estimation process is provided to the client.
- the server device is implemented as a server device that provides a service that inputs training data and outputs training model information.
- the server device may be implemented as a Web server, or may be implemented as a cloud that provides services related to the above processing by outsourcing.
- FIG. 17 is a diagram showing an example of a computer that executes a learning program.
- the estimation program may also be executed by a similar computer.
- the computer 1000 has, for example, a memory 1010 and a processor 1020.
- the computer 1000 also has a hard disk drive interface 1030, a disk drive interface 1040, a serial port interface 1050, a video adapter 1060, and a network interface 1070. Each of these parts is connected by a bus 1080.
- the memory 1010 includes a ROM (Read Only Memory) 1011 and a RAM 1012.
- the ROM 1011 stores, for example, a boot program such as a BIOS (BASIC Input Output System).
- the processor 1020 includes a CPU 1021 and a GPU (Graphics Processing Unit) 1022.
- the hard disk drive interface 1030 is connected to the hard disk drive 1090.
- the disk drive interface 1040 is connected to the disk drive 1100. For example, a removable storage medium such as a magnetic disk or an optical disk is inserted into the disk drive 1100.
- the serial port interface 1050 is connected to, for example, a mouse 1110 and a keyboard 1120.
- the video adapter 1060 is connected to, for example, the display 1130.
- the hard disk drive 1090 stores, for example, OS1091, application program 1092, program module 1093, and program data 1094. That is, the program that defines each process of the learning device 10 is implemented as a program module 1093 in which a code that can be executed by a computer is described.
- the program module 1093 is stored in, for example, the hard disk drive 1090.
- the program module 1093 for executing the same processing as the functional configuration in the learning device 10 is stored in the hard disk drive 1090.
- the hard disk drive 1090 may be replaced by an SSD.
- the setting data used in the processing of the above-described embodiment is stored as program data 1094 in, for example, a memory 1010 or a hard disk drive 1090. Then, the CPU 1020 reads the program module 1093 and the program data 1094 stored in the memory 1010 and the hard disk drive 1090 into the RAM 1012 as needed, and executes the processing of the above-described embodiment.
- the program module 1093 and the program data 1094 are not limited to those stored in the hard disk drive 1090, but may be stored in, for example, a removable storage medium and read by the CPU 1020 via the disk drive 1100 or the like. Alternatively, the program module 1093 and the program data 1094 may be stored in another computer connected via a network (LAN (Local Area Network), WAN (Wide Area Network), etc.). Then, the program module 1093 and the program data 1094 may be read by the CPU 1020 from another computer via the network interface 1070.
- LAN Local Area Network
- WAN Wide Area Network
Landscapes
- Engineering & Computer Science (AREA)
- Theoretical Computer Science (AREA)
- Data Mining & Analysis (AREA)
- Physics & Mathematics (AREA)
- Artificial Intelligence (AREA)
- General Physics & Mathematics (AREA)
- General Engineering & Computer Science (AREA)
- Evolutionary Computation (AREA)
- Life Sciences & Earth Sciences (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Software Systems (AREA)
- Mathematical Physics (AREA)
- Bioinformatics & Computational Biology (AREA)
- Bioinformatics & Cheminformatics (AREA)
- Evolutionary Biology (AREA)
- Computing Systems (AREA)
- Biophysics (AREA)
- Health & Medical Sciences (AREA)
- Biomedical Technology (AREA)
- Computational Linguistics (AREA)
- General Health & Medical Sciences (AREA)
- Molecular Biology (AREA)
- Medical Informatics (AREA)
- Image Analysis (AREA)
Abstract
推定部は、入力されたデータを基に推定結果を出力する軽量モデルに学習用データを入力し、第1の推定結果を取得する。また、更新部は、第1の推定結果と、入力されたデータを基に推定結果を出力するモデルであって、軽量モデルよりも処理速度が遅く推定精度が高い高精度モデルに学習用データを入力して得られた第2の推定結果と、を基に、軽量モデルと高精度モデルを含むモデルカスケードが最適化されるように、軽量モデルのパラメータを更新する。
Description
本発明は、学習装置、学習方法、学習プログラム、推定装置、推定方法及び推定プログラムに関する。
昨今、ビデオ監視、音声アシスタント、自動運転といった、DNN(deep neural network)を使ったリアルタイムアプリケーションが登場している。このようなリアルタイムアプリケーションには、DNNの精度を保ちつつ限られたリソースで多量のクエリをリアルタイムに処理することが求められる。そこで、高速かつ低精度な軽量モデルと低速かつ高精度な高精度モデルを使って、精度劣化少なく推論処理を高速化可能なモデルカスケードという技術が提案されている。
モデルカスケードでは軽量モデル及び高精度モデルを含む複数のモデルが用いられる。モデルカスケードによる推論を行う際は、まず軽量モデルで推定を行い、その結果が信用できる場合にはその結果を採用して処理を終了する。一方、軽量モデルの推定結果が信用できない場合には、続けて高精度モデルで推論を行い、その結果を採用する。例えば、軽量モデルの推定結果を信用できるか否かを判定するためにIDK(I Don’t Know)分類器を導入したIDK Cascade(例えば、非特許文献1を参照)が知られている。
Wang, Xin, et al. "Idk cascades: Fast deep learning by learning not to overthink." arXiv preprint arXiv:1706.00885 (2017).
しかしながら、従来のモデルカスケードには、計算コスト及び計算リソースのオーバーヘッドが生じる場合があるという問題がある。例えば、非特許文献1の技術では、軽量分類器及び高精度分類器に加え、IDK分類器を設ける必要がある。このため、モデルが1つ増えることになり、計算コスト及び計算リソースのオーバーヘッドが生じる。
上述した課題を解決し、目的を達成するために、学習装置は、入力されたデータを基に推定結果を出力する第1のモデルに学習用データを入力し、第1の推定結果を取得する推定部と、前記第1の推定結果の正否及び確信度と、入力されたデータを基に推定結果を出力するモデルであって、前記第1のモデルよりも処理速度が遅い、又は前記第1のモデルよりも推定精度が高い第2のモデルに前記学習用データを入力して得られた第2の推定結果の正否と、を基に、前記第1のモデルと前記第2のモデルを含むモデルカスケードが最適化されるように、前記第1のモデルのパラメータを更新する更新部と、を有することを特徴とする。
本発明によれば、モデルカスケードの計算コスト及び計算リソースのオーバーヘッドを抑止することができる。
以下に、本願に係る学習装置、学習方法、学習プログラム、推定装置、推定方法及び推定プログラムの実施形態を図面に基づいて詳細に説明する。なお、本発明は、以下に説明する実施形態により限定されるものではない。
[第1の実施形態]
第1の実施形態に係る学習装置は、入力された学習用データを用いて、高精度モデル及び軽量モデルの学習を行う。そして、学習装置は、学習済みの高精度モデルに関する情報、及び学習済みの軽量モデルに関する情報を出力する。例えば、学習装置は、各モデルを構築するために必要なパラメータを出力する。
第1の実施形態に係る学習装置は、入力された学習用データを用いて、高精度モデル及び軽量モデルの学習を行う。そして、学習装置は、学習済みの高精度モデルに関する情報、及び学習済みの軽量モデルに関する情報を出力する。例えば、学習装置は、各モデルを構築するために必要なパラメータを出力する。
高精度モデル及び軽量モデルは、入力されたデータを基に推定結果を出力するモデルである。第1の実施形態において、高精度モデル及び軽量モデルは、画像を入力とし、当該画像に写る物体のクラスごとの確率を推定する多クラス分類モデルであるものとする。ただし、高精度モデル及び軽量モデルは、そのような多クラス分類モデルに限定されるものではなく、機械学習が適用可能なあらゆるモデルであってよい。
高精度モデルは、軽量モデルと比べて処理速度が遅く推定精度が高いものとする。なお、高精度モデルは、単に軽量モデルより処理速度が遅いことが既知のものであってもよい。この場合、高精度モデルの方が軽量モデルよりも推定精度が高いことが期待される。また、高精度モデルは、単に軽量モデルより推定精度が高いことが既知のものであってもよい。この場合、軽量モデルの方が高精度モデルよりも処理速度が速いことが期待される。
高精度モデル及び軽量モデルは、モデルカスケードを構成する。図1は、モデルカスケードについて説明する図である。説明のため、図1には2つの画像を表示しているが、いずれも同じ画像である。図1に示すように、軽量モデルは、入力された画像に写る物体について各クラスの確率を出力する。例えば、軽量モデルは、画像に写る物体がcatである確率を約0.5と出力する。また、軽量モデルは、画像に写る物体がdogである確率を約0.35と出力する。
ここで、軽量モデルの出力、すなわち推定結果が条件を満たす場合、当該推定結果が採用される。つまり、軽量モデルの推定結果が、モデルカスケードの最終的な推定結果として出力される。一方で、軽量モデルの推定結果が条件を満たさない場合、同一の画像を高精度モデルに入力して得られた推定結果が、モデルカスケードの最終的な推定結果として出力される。ただし、高精度モデルは、軽量モデルと同様に、入力された画像に写る物体について各クラスの確率を出力する。例えば、条件は、軽量モデルが出力した確率の最大値が閾値を超えていることである。
例えば、高精度モデルはResNet18であって、サーバ等で動作する。また、例えば、軽量モデルはMobileNetV2であって、IoT機器及び各種端末装置で動作する。なお、高精度モデル及び軽量モデルは、同一のコンピュータで動作するものであってもよい。
[第1の実施形態の構成]
図2は、第1の実施形態に係る学習装置の構成例を示す図である。図2に示すように、学習装置10は、学習用データの入力を受け付け、学習済み高精度モデル情報及び学習済み軽量モデル情報を出力する。また、学習装置10は、高精度モデル学習部11及び軽量モデル学習部12を有する。
図2は、第1の実施形態に係る学習装置の構成例を示す図である。図2に示すように、学習装置10は、学習用データの入力を受け付け、学習済み高精度モデル情報及び学習済み軽量モデル情報を出力する。また、学習装置10は、高精度モデル学習部11及び軽量モデル学習部12を有する。
高精度モデル学習部11は、推定部111、損失計算部112、更新部113を有する。また、高精度モデル学習部11は、高精度モデル情報114を記憶する。高精度モデル情報114は、高精度モデルを構築するためのパラメータ等の情報である。学習用データは、ラベルが既知のデータであるものとする。例えば、学習用データは、画像とラベル(正解のクラス)の組み合わせである。
推定部111は、高精度モデル情報114を基に構築された高精度モデルに学習用データを入力し、推定結果を取得する。推定部111は、学習用データの入力を受け付け、推定結果を出力する。
損失計算部112は、推定部111によって取得された推定結果を基に損失を計算する。損失計算部112は、推定結果及びラベルの入力を受け付け、損失を出力する。例えば、損失計算部112は、推定部111によって取得された推定結果において、ラベルに対する確信度が小さいほど大きくなるように損失を計算する。例えば、確信度は、推定結果が正解であることの確からしさの度合いである。例えば、確信度は、前述の多クラス分類モデルが出力した確率であってもよい。具体的には、損失計算部112は、後述するソフトマックスクロスエントロピーを損失として計算することができる。
更新部113は、損失が最適化されるように、高精度モデルのパラメータを更新する。例えば、高精度モデルがニューラルネットワークであれば、更新部113は、誤差逆伝播法等により高精度モデルのパラメータを更新する。具体的には、更新部113は、高精度モデル情報114を更新する。更新部113は、損失計算部112によって計算された損失の入力を受け付け、更新済みのモデルの情報を出力する。
軽量モデル学習部12は、推定部121、損失計算部122、更新部123を有する。また、軽量モデル学習部12は、軽量モデル情報124を記憶する。軽量モデル情報124は、軽量モデルを構築するためのパラメータ等の情報である。
推定部121は、軽量モデル情報124を基に構築された軽量モデルに学習用データを入力し、推定結果を取得する。推定部121は、学習用データの入力を受け付け、推定結果を出力する。
ここで、高精度モデル学習部11は、高精度モデルの出力を基に、高精度モデルの学習を行うものであった。一方で、軽量モデル学習部12は、高精度モデル及び軽量モデルの両方の出力を基に、軽量モデルの学習を行う。
損失計算部122は、推定部によって取得された推定結果を基に損失を計算する。損失計算部122は、高精度モデルによる推定結果、軽量モデルによる推定結果及びラベルの入力を受け付け、損失を出力する。高精度モデルによる推定結果は、高精度モデル学習部11による学習が行われた後の高精度モデルに、さらに学習用データを入力して得られた推定結果であってよい。さらに具体的には、軽量モデル学習部12は、高精度モデルによる推定結果が正解であったか否かの入力を受け付ける。例えば、高精度モデルが出力した確率が最大であったクラスがラベルと一致していれば、その推定結果は正解である。
損失計算部122は、軽量モデル単体での推定精度の最大化に加え、モデルカスケードを構成した場合の利益の最大化を目的として損失を計算する。ここで、利益は、推定精度が高いほど大きくなり、計算コストが小さいほど大きくなるものとする。
例えば、高精度モデルには、推定精度は高いが計算コストが大きいという特徴がある。また、また、例えば、軽量モデルには、推定精度は低いが計算コストが小さいという特徴がある。そこで、損失計算部122は、(1)式のように損失Lossを計算する。ただし、wは重みであり、事前に設定されるパラメータである。
ここで、Lclassifierは、多クラス分類モデルにおけるソフトマックスエントロピーである。また、Lclassifierは、軽量モデルによる推定結果における正解に対する確信度が小さいほど大きくなる第1の項の一例である。Lclassifierは、(2)式のように表される。ただし、Nはサンプル数である。また、kはクラス数である。また、yは正解のクラスを表すラベルである。また、qは軽量モデルによって出力された確率である。iはサンプルを識別する番号である。また、jはクラスを識別する番号である。ラベルyi,jは、i番目のサンプルにおいて、j番目のクラスが正解であれば1になり、不正解であれば0になる。
また、Lcascadeは、モデルカスケードを構成した場合の利益の最大化のための項である。Lcascadeは、各サンプルについて、軽量モデルの確信度に基づいて高精度モデル及び軽量モデルの推定結果を採用した場合の損失を表している。ここで、損失は、不適切な確信度へのペナルティと高精度モデルを用いるコストを含む。また、損失は高精度モデルの推定結果が正解か否かと、軽量モデルの推定結果が正解か否かとの組み合わせで4パターンに分けられる。詳細は後述するが、高精度モデルの推定が不正解、かつ軽量モデルの確信度が低い場合は、ペナルティは大きくなる。一方、軽量モデルの推定が正解、かつ軽量モデルの確信度が高い場合は、ペナルティは小さくなる。Lcascadeは、(3)式のように表される。
1fastは、軽量モデルの推定結果が正解であれば0、軽量モデルの推定結果が不正解であれば1を返す指示関数である。また、1accは、高精度モデルの推定結果が正解であれば0、高精度モデルの推定結果が不正解であれば1を返す指示関数である。COSTaccは、高精度モデルによる推定を行うことにかかるコストであり、事前に設定されるパラメータである。
maxjqi,jは、軽量モデルが出力する確率の最大値であり、確信度の一例である。推定結果が正解であれば、確信度が大きいほど推定精度は高いといえる。一方、推定結果が不正解であれば、確信度が大きいほど推定精度は低いといえる。
(3)式のmaxjqi,j1fastは、軽量モデルによる推定結果が不正解である場合に軽量モデルによる推定結果の確信度が大きいほど大きくなる第2の項の一例である。また、(3)式の(1-maxjqi,j)1accは、高精度モデルによる推定結果が不正解である場合に軽量モデルによる推定結果の確信度が小さいほど大きくなる第3の項の一例である。また、(3)式の(1-maxjqi,j)COSTaccは、軽量モデルによる推定結果の確信度が小さいほど大きくなる第4の項の一例である。この場合、更新部123による損失の最小化が、損失の最適化に相当する。
更新部123は、損失が最適化されるように、軽量モデルのパラメータを更新する。つまり、更新部123は、軽量モデルによる推定結果と、入力されたデータを基に推定結果を出力するモデルであって、軽量モデルよりも処理速度が遅く推定精度が高い高精度モデルに学習用データを入力して得られた推定結果と、を基に、軽量モデルと高精度モデルを含むモデルカスケードが最適化されるように、軽量モデルのパラメータを更新する。更新部123は、損失計算部122によって計算された損失の入力を受け付け、更新済みのモデルの情報を出力する。
図3は、ケースごとの損失の一例を示す図である。縦軸はLcascadeの値である。また、横軸は、maxjqi,jの値である。また、COSTacc=0.5とする。maxjqi,jは、軽量モデルによる推定結果の確信度であり、ここでは単に確信度と呼ぶ。
図3の「□」は、軽量モデル及び高精度モデルの両方の推定結果が正解である場合の、確信度に対するLcascadeの値である。この場合、確信度が大きいほどLcascadeの値は小さくなる。これは、軽量モデルによる推定結果が正解であれば、確信度が大きいほど軽量モデルが採用されやすくなるためである。
図3の「◇」は、軽量モデルの推定結果が正解であり、高精度モデルの推定結果が不正解である場合の、確信度に対するLcascadeの値である。この場合、確信度が大きいほどLcascadeの値は小さくなる。また、「□」の場合と比べて、Lcascadeの最大値及び小さくなる度合いが大きい。これは、高精度モデルによる推定結果が不正解であって、軽量モデルによる推定結果が正解であれば、確信度が大きいほど軽量モデルが採用されやすくなる傾向がさらに大きくなるためである。
図3の「■」は、軽量モデルの推定結果が不正解であり、高精度モデルの推定結果が正解である場合の、確信度に対するLcascadeの値である。この場合、確信度が大きいほどLcascadeの値は大きくなる。これは、軽量モデルの推定結果が不正解である場合も、確信度が小さいほど推定結果が採用されにくくなるためである。
図3の「◆」は、軽量モデル及び高精度モデルの両方の推定結果が不正解である場合の、確信度に対するLcascadeの値である。この場合、確信度が大きいほどLcascadeの値は小さくなる。ただし、「□」の場合と比べて、Lcascadeの値は大きい。これは、両方のモデルの推定結果が不正解であることから常に損失が大きく、そのような状況では軽量モデルで正確な推定ができるようにすべきであるためである。
[第1の実施形態の処理]
図4は、高精度モデルの学習処理の流れを示すフローチャートである。図4に示すように、まず、推定部111は、高精度モデルを用いて学習用データのクラスを推定する(ステップS101)。
図4は、高精度モデルの学習処理の流れを示すフローチャートである。図4に示すように、まず、推定部111は、高精度モデルを用いて学習用データのクラスを推定する(ステップS101)。
次に、損失計算部112は、高精度モデルの推定結果を基に損失を計算する(ステップS102)。そして、更新部113は、損失が最適化されるように高精度モデルのパラメータを更新する(ステップS103)。なお、学習装置10は、終了条件が満たされるまで、ステップS101からステップS103までの処理を繰り返してもよい。終了条件は、既定の回数だけ処理が繰り返されたことであってもよいし、パラメータの更新幅が収束したことであってもよい。
図5は、軽量モデルの学習処理の流れを示すフローチャートである。図5に示すように、まず、推定部121は、軽量モデルを用いて学習用データのクラスを推定する(ステップS201)。
次に、損失計算部122は、軽量モデルの推定結果、及び高精度モデルの推定結果及び高精度モデルによる推定のコストを基に損失を計算する(ステップS202)。そして、更新部123は、損失が最適化されるように軽量モデルのパラメータを更新する(ステップS203)。なお、学習装置10は、終了条件が満たされるまで、ステップS201からステップS203までの処理を繰り返してもよい。
[第1の実施形態の効果]
これまで説明してきたように、推定部121は、入力されたデータを基に推定結果を出力する軽量モデルに学習用データを入力し、第1の推定結果を取得する。また、更新部123は、第1の推定結果と、入力されたデータを基に推定結果を出力するモデルであって、軽量モデルよりも処理速度が遅く推定精度が高い高精度モデルに学習用データを入力して得られた第2の推定結果と、を基に、軽量モデルと高精度モデルを含むモデルカスケードが最適化されるように、軽量モデルのパラメータを更新する。このように、第1の実施形態では、軽量モデルと高精度モデルによって構成されるモデルカスケードにおいて、IDK分類器等のモデルを設けることなく、軽量モデルがモデルカスケードに適した推定を行えるようにすることで、モデルカスケードの性能を向上させることができる。その結果、第1の実施形態によれば、モデルカスケードの精度が向上するだけでなく、計算コスト及び計算リソースのオーバーヘッドを抑止することができる。さらに、第1の実施形態では、損失関数に変更を加えるものであるため、モデルアーキテクチャの変更が不要であり、適用するモデルや最適化手法に制限がない。
これまで説明してきたように、推定部121は、入力されたデータを基に推定結果を出力する軽量モデルに学習用データを入力し、第1の推定結果を取得する。また、更新部123は、第1の推定結果と、入力されたデータを基に推定結果を出力するモデルであって、軽量モデルよりも処理速度が遅く推定精度が高い高精度モデルに学習用データを入力して得られた第2の推定結果と、を基に、軽量モデルと高精度モデルを含むモデルカスケードが最適化されるように、軽量モデルのパラメータを更新する。このように、第1の実施形態では、軽量モデルと高精度モデルによって構成されるモデルカスケードにおいて、IDK分類器等のモデルを設けることなく、軽量モデルがモデルカスケードに適した推定を行えるようにすることで、モデルカスケードの性能を向上させることができる。その結果、第1の実施形態によれば、モデルカスケードの精度が向上するだけでなく、計算コスト及び計算リソースのオーバーヘッドを抑止することができる。さらに、第1の実施形態では、損失関数に変更を加えるものであるため、モデルアーキテクチャの変更が不要であり、適用するモデルや最適化手法に制限がない。
更新部123は、第1の推定結果における正解に対する確信度が小さいほど大きくなる第1の項と、第1の推定結果が不正解である場合に第1の推定結果の確信度が大きいほど大きくなる第2の項と、第2の推定結果が不正解である場合に第1の推定結果の確信度が小さいほど大きくなる第3の項と、第1の推定結果の確信度が小さいほど大きくなる第4の項と、を含む損失関数を基に計算される損失が最小化されるように、軽量モデルのパラメータを更新する。この結果、第1の実施形態では、軽量モデルと高精度モデルによって構成されるモデルカスケードにおいて、高精度モデルの推定結果を採用する場合のコストを考慮した上で、モデルカスケードの推定精度を向上させることができる。
[第2の実施形態]
[第2の実施形態の構成]
第2の実施形態では、学習済みの高精度モデル及び軽量モデルを使って推定を行う推定システムについて説明する。第2の実施形態の推定システムによれば、IDK分類器等を設けることなく、モデルカスケードによる推定を精度良く行うことができる。また、以降の実施形態の説明においては、説明済みの実施形態と同様の機能を有する部には同じ符号を付し、適宜説明を省略する。
[第2の実施形態の構成]
第2の実施形態では、学習済みの高精度モデル及び軽量モデルを使って推定を行う推定システムについて説明する。第2の実施形態の推定システムによれば、IDK分類器等を設けることなく、モデルカスケードによる推定を精度良く行うことができる。また、以降の実施形態の説明においては、説明済みの実施形態と同様の機能を有する部には同じ符号を付し、適宜説明を省略する。
図6に示すように、推定システム2は、高精度推定装置20及び軽量推定装置30を有する。また、高精度推定装置20及び軽量推定装置30は、ネットワークNを介して接続される。ネットワークNは、例えばインターネットである。その場合、高精度推定装置20は、クラウド環境に設けられたサーバであってもよい。また、軽量推定装置30は、IoT機器及び各種端末装置であってもよい。
図6に示すように、高精度推定装置20は、高精度モデル情報201を記憶する。高精度モデル情報201は、学習済みの高精度モデルのパラメータ等の情報である。また、高精度推定装置20は、推定部202を有する。
推定部202は、高精度モデル情報201を基に構築された高精度モデルに推定用データを入力し、推定結果を取得する。推定部202は、推定用データの入力を受け付け、推定結果を出力する。推定用データは、ラベルが未知のデータであるものとする。例えば、推定用データは、画像である。
ここで、高精度推定装置20及び軽量推定装置30は、モデルカスケードを構成する。このため、推定部202は、常に推定用データについての推定を行うわけではない。推定部202は、軽量モデルの推定結果を採用しないという判断がされた場合に、高精度モデルによる推定を行う。
軽量推定装置30は、軽量モデル情報301を記憶する。軽量モデル情報301は、学習済みの軽量モデルのパラメータ等の情報である。また、軽量推定装置30は、推定部302及び判定部303を有する。
推定部302は、入力されたデータを基に推定結果を出力する軽量モデルに学習用データを入力して得られた推定結果と、入力されたデータを基に推定結果を出力するモデルであって、軽量モデルよりも推定精度が高い高精度モデルに学習用データを入力して得られた推定結果と、を基に、軽量モデルと高精度モデルを含むモデルカスケードが最適化されるように予め学習されたパラメータが設定された軽量モデルに、推定用のデータを入力して推定結果を取得する。推定部302は、推定用データの入力を受け付け、推定結果を出力する。
また、判定部303は、軽量モデルによる推定結果が、推定精度に関する所定の条件を満たすか否かを判定する。例えば、判定部303は、確信度が閾値以上である場合に、軽量モデルによる推定結果が条件を満たすと判定する。その場合、推定システム2は、軽量モデルの推定結果を採用する。
また、高精度推定装置20の推定部202は、判定部303によって、軽量モデルによる推定結果が条件を満たさないと判定された場合、高精度モデルに、推定用のデータを入力して推定結果を取得する。その場合、推定システム2は、高精度モデルの推定結果を採用する。
[第2の実施形態の処理]
図7は、図7は、推定処理の流れを示すフローチャートである。図7に示すように、まず、推定部302は、軽量モデルを用いて推定用データのクラスを推定する(ステップS301)。
図7は、図7は、推定処理の流れを示すフローチャートである。図7に示すように、まず、推定部302は、軽量モデルを用いて推定用データのクラスを推定する(ステップS301)。
ここで、判定部303は、推定結果が条件を満たすか否かを判定する(ステップS302)。推定結果が条件を満たす場合(ステップS302、Yes)、推定システム2は軽量モデルの推定結果を出力する(ステップS303)。
一方、推定結果が条件を満たさない場合(ステップS302、No)、推定部202は、高精度モデルを用いて推定用データのクラスを推定する(ステップS304)。そして、推定システム2は高精度モデルの推定結果を出力する(ステップS305)。
[第2の実施形態の効果]
これまで説明してきたように、推定部302は、入力されたデータを基に推定結果を出力する軽量モデルに学習用データを入力して得られた推定結果と、入力されたデータを基に推定結果を出力するモデルであって、軽量モデルよりも推定精度が高い高精度モデルに学習用データを入力して得られた推定結果と、を基に、軽量モデルと高精度モデルを含むモデルカスケードが最適化されるように予め学習されたパラメータが設定された軽量モデルに、推定用のデータを入力して推定結果を取得する。また、判定部303は、軽量モデルによる推定結果が、推定精度に関する所定の条件を満たすか否かを判定する。この結果、第2の実施形態では、軽量モデルと高精度モデルによって構成されるモデルカスケードにおいて、オーバーヘッドの発生を抑止しつつ高精度な推定を行うことができる。
これまで説明してきたように、推定部302は、入力されたデータを基に推定結果を出力する軽量モデルに学習用データを入力して得られた推定結果と、入力されたデータを基に推定結果を出力するモデルであって、軽量モデルよりも推定精度が高い高精度モデルに学習用データを入力して得られた推定結果と、を基に、軽量モデルと高精度モデルを含むモデルカスケードが最適化されるように予め学習されたパラメータが設定された軽量モデルに、推定用のデータを入力して推定結果を取得する。また、判定部303は、軽量モデルによる推定結果が、推定精度に関する所定の条件を満たすか否かを判定する。この結果、第2の実施形態では、軽量モデルと高精度モデルによって構成されるモデルカスケードにおいて、オーバーヘッドの発生を抑止しつつ高精度な推定を行うことができる。
推定部202は、判定部303によって、軽量モデルによる推定結果が条件を満たさないと判定された場合、高精度モデルに、推定用のデータを入力して推定結果を取得する。これにより、第2の実施形態によれば、軽量モデルによる推定結果が採用できない場合であっても、高精度の推定結果を得ることができる。
ここで、第2の実施形態に係る推定システム2は、以下のように表現することができる。すなわち、推定システム2は、高精度推定装置20及び軽量推定装置30を有する。軽量推定装置30は、入力されたデータを基に推定結果を出力する軽量モデルに学習用データを入力して得られた推定結果と、入力されたデータを基に推定結果を出力するモデルであって、軽量モデルよりも処理速度が遅い、又は軽量のモデルよりも推定精度が高い高精度モデルに学習用データを入力して得られた推定結果と、を基に、軽量モデルと高精度モデルを含むモデルカスケードが最適化されるように予め学習されたパラメータが設定された軽量モデルに、推定用のデータを入力して第1の推定結果を取得する推定部302と、第1の推定結果が、推定精度に関する所定の条件を満たすか否かを判定する判定部303と、を有する。高精度推定装置20は、判定部303によって、第1の推定結果が条件を満たさないと判定された場合、高精度モデルに、推定用のデータを入力して第2の推定結果を取得する推定部202を有する。また、高精度推定装置20は、推定用データを軽量推定装置30から取得してもよい。
推定部202は、軽量推定装置30による推定の結果に応じて推定を行う。すなわち、推定部202は、入力されたデータを基に推定結果を出力する軽量モデルに学習用データを入力して得られた推定結果と、入力されたデータを基に推定結果を出力するモデルであって、軽量モデルよりも処理速度が遅い、又は軽量モデルよりも推定精度が高い高精度モデルに学習用データを入力して得られた推定結果と、を基に、軽量モデルと高精度モデルを含むモデルカスケードが最適化されるように予め学習されたパラメータが設定された軽量モデルに、軽量推定装置30が推定用のデータを入力して取得する第1の推定結果に応じて、推定用のデータを高精度モデルに入力して第2の推定結果を取得する。
[実験]
ここで、実施形態の効果を確認するために行った実験とその結果について説明する。図8から図9は、実験結果を示す図である。実験では、第2の実施形態における判定部303が、確信度が閾値を超えているか否かを判定するものとする。実験における各設定は下記の通りである。
データセット:CIFAR100
train:45000, validation:5000, test:10000
軽量モデル:MobileNetV2
高精度モデル:ResNet18
モデルの学習方法
Momentum SGD
lr=0.01, momentum=0.9, weight decay=5e-4
lrは60,120,160エポックで0.2倍
batch size:128
比較手法(各5回ずつ実験)
・Base:クラス確率の最大値を利用
・IDK Cascades(非特許文献1を参照)
・ConfNet(参考文献1を参照)
・Temperature Scaling(参考文献2を参照)
・第2の実施形態
精度:モデルカスケード構成で推論を行った際の精度
オフロード数:高精度モデルで推論を行った回数
(参考文献1)Wan, Sheng, et al. "Confnet: Predict with Confidence." 2018 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP). IEEE, 2018.
(参考文献2)Guo, Chuan, et al. "On calibration of modern neural networks." Proceedings of the 34th International Conference on Machine Learning-Volume 70. JMLR. org, 2017.
ここで、実施形態の効果を確認するために行った実験とその結果について説明する。図8から図9は、実験結果を示す図である。実験では、第2の実施形態における判定部303が、確信度が閾値を超えているか否かを判定するものとする。実験における各設定は下記の通りである。
データセット:CIFAR100
train:45000, validation:5000, test:10000
軽量モデル:MobileNetV2
高精度モデル:ResNet18
モデルの学習方法
Momentum SGD
lr=0.01, momentum=0.9, weight decay=5e-4
lrは60,120,160エポックで0.2倍
batch size:128
比較手法(各5回ずつ実験)
・Base:クラス確率の最大値を利用
・IDK Cascades(非特許文献1を参照)
・ConfNet(参考文献1を参照)
・Temperature Scaling(参考文献2を参照)
・第2の実施形態
精度:モデルカスケード構成で推論を行った際の精度
オフロード数:高精度モデルで推論を行った回数
(参考文献1)Wan, Sheng, et al. "Confnet: Predict with Confidence." 2018 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP). IEEE, 2018.
(参考文献2)Guo, Chuan, et al. "On calibration of modern neural networks." Proceedings of the 34th International Conference on Machine Learning-Volume 70. JMLR. org, 2017.
上記のtestデータを用いて、第2の実施形態を含む各手法で実際に推定を行い、閾値を0から1まで0.01刻みで変化させた際のオフロード数と精度の関係を図8に示す。図8に示すように、実施形態の手法(proposed)は、他の手法と比べ、オフロード数が減った場合であっても高い精度を示した。
また、上記のvalidationデータで最も精度が高くなる閾値を採用して、testデータの推定を行った際のオフロード数と精度の関係を図9及び図10に示す。これより、第2の実施形態によれば、高精度モデルの精度を維持しつつ最もオフロード数が削減されていることがわかる。
さらに、testデータで高精度モデルの精度を維持しつつ最もオフロードが少なかった際のオフロード数と精度の関係を図11及び図12に示す。これより、第2の実施形態によれば最もオフロード数が削減されていることがわかる。
[第3の実施形態]
第2の実施形態では、軽量モデルによる推定を行う装置と、高精度モデルによる推定を行う装置が別々である場合の例について説明した。一方で、軽量モデルによる推定と高精度モデルによる推定は同じ装置で行われてもよい。
第2の実施形態では、軽量モデルによる推定を行う装置と、高精度モデルによる推定を行う装置が別々である場合の例について説明した。一方で、軽量モデルによる推定と高精度モデルによる推定は同じ装置で行われてもよい。
図13は、第3の実施形態に係る推定装置の構成例を示す図である。推定装置2aは、第2の実施形態の推定システム2と同様の機能を有する。また、高精度推定部20aは、第2の実施形態の高精度推定装置20と同様の機能を有する。また、軽量推定部30aは、第2の実施形態の軽量推定装置30と同様の機能を有する。第2の実施形態と異なり、推定部202と判定部303は同じ装置内にあるため、推定処理において、ネットワークを介したデータのやり取りは発生しない。
[第4の実施形態]
これまで、モデルが軽量モデル及び高精度モデルの2つである場合の実施形態について説明した。一方で、これまでに説明した実施形態は、モデルが3つ以上の場合に拡張することができる。
これまで、モデルが軽量モデル及び高精度モデルの2つである場合の実施形態について説明した。一方で、これまでに説明した実施形態は、モデルが3つ以上の場合に拡張することができる。
図14は、3つ以上のモデルを含むモデルカスケードについて説明する図である。ここでは、M個(M>3)のモデルがあるものとする。m+1番目(M-1≧m≧1)のモデルは、m番目のモデルよりも処理速度が遅い、又はm番目のモデルよりも推定精度が高いものとする。つまり、m+1番目のモデルとm番目のモデルとの関係は、高精度モデルと軽量モデルとの関係と同様である。さらに、M番目のモデルは最も高精度なモデルであり、1番目のモデルは最も軽量なモデルということができる。
第4の実施形態では、第2の実施形態で説明した推定システム2を使って、3つ以上のモデルによる推定処理を実現することができる。まず、推定システム2は、高精度モデル情報201を2番目のモデルの情報に置き換え、軽量モデル情報301を1番目のモデルの情報に置き換える。そして、推定システム2は、第2の実施形態と同様の推定処理を実行する。
その後、1番目のモデルの推定結果が条件と満たさず、かつ、2番目のモデルの推定結果が条件を満たさない場合、推定システム2は、高精度モデル情報201を3番目のモデルの情報に置き換え、軽量モデル情報301を2番目のモデルの情報に置き換えて推定処理をさらに実行する。推定システム2は、条件を満たす推定結果が得られるか、又はM番目のモデルによる推定処理が終わるまでこの処理を繰り返す。なお、同様の処理は、軽量モデル情報301を置き換えていくことにより、軽量推定装置30のみでも実現可能である。
さらに、第4の実施形態では、第1の実施形態で説明した学習装置10を使って、3つ以上のモデルの学習処理を実現することができる。学習装置10は、M個のモデルから番号が連続する2つのモデルを抽出し、それらのモデルの情報を用いて学習処理を実行する。まず、学習装置10は、高精度モデル情報114をM番目のモデルの情報に置き換え、軽量モデル情報124をM-1番目のモデルの情報に置き換える。そして、学習装置10は、第1の実施形態と同様の学習処理を実行する。一般化すると、学習装置10は、高精度モデル情報114をm番目のモデルの情報に置き換え、軽量モデル情報124をm-1番目のモデルの情報に置き換えた上で、第1の実施形態と同様の学習処理を実行する。
図15は、3つ以上のモデルの学習処理の流れを示すフローチャートである。ここでは、第1の実施形態の学習装置10が学習処理を行うものとする。図15に示すように、まず、学習装置10は、mの初期値としてMを設定する(ステップS401)。推定部121は、m-1番目のモデルを用いて学習用データのクラスを推定する(ステップS402)。
次に、損失計算部122は、m-1番目のモデルの推定結果、及びm番目のモデルの推定結果及びm番目のモデルによる推定のコストを基に損失を計算する(ステップS403)。そして、更新部123は、損失が最適化されるようにm-1番目のモデルのパラメータを更新する(ステップS404)。
ここで、学習装置10は、mを1だけ減少させる(ステップS405)。mが1に達した場合(ステップS406、Yes)、学習装置10は処理を終了する。一方、mが1に達していない場合(ステップS406、No)、学習装置10はステップS402に戻り処理を繰り返す。
図16は、3つ以上のモデルによる推定処理の流れを示すフローチャートである。ここでは、第2の実施形態の軽量推定装置30が推定処理を行うものとする。図16に示すように、まず、軽量推定装置30は、mの初期値として1を設定する(ステップS501)。推定部302は、m番目のモデルを用いて推定用データのクラスを推定する(ステップS502)。
ここで、判定部303は、推定結果が条件を満たすか否か、及びmがMに達しているか否かを判定する(ステップS503)。推定結果が条件を満たすか、又はmがMに達している場合(ステップS503、Yes)、軽量推定装置30はm番目のモデルの推定結果を出力する(ステップS504)。
一方、推定結果が条件を満たさず、かつmがMに達していない場合(ステップS503、No)、推定部302は、軽量推定装置30は、mを1だけ増加させ(ステップS505)、ステップS502に戻り処理を繰り返す。
例えば、従来の技術では、モデルが増加するのに従いIDK分類器の数も増加し、計算コスト及び計算リソースのオーバーヘッドが拡大する。一方で、第4の実施形態によれば、モデルカスケードを構成するモデルの数が3つ以上に増加したとしても、そのようなオーバーヘッドが拡大する問題は生じない。
[システム構成等]
また、図示した各装置の各構成要素は機能概念的なものであり、必ずしも物理的に図示のように構成されていることを要しない。すなわち、各装置の分散及び統合の具体的形態は図示のものに限られず、その全部又は一部を、各種の負荷や使用状況等に応じて、任意の単位で機能的又は物理的に分散又は統合して構成することができる。さらに、各装置にて行われる各処理機能は、その全部又は任意の一部が、CPU及び当該CPUにて解析実行されるプログラムにて実現され、あるいは、ワイヤードロジックによるハードウェアとして実現され得る。
また、図示した各装置の各構成要素は機能概念的なものであり、必ずしも物理的に図示のように構成されていることを要しない。すなわち、各装置の分散及び統合の具体的形態は図示のものに限られず、その全部又は一部を、各種の負荷や使用状況等に応じて、任意の単位で機能的又は物理的に分散又は統合して構成することができる。さらに、各装置にて行われる各処理機能は、その全部又は任意の一部が、CPU及び当該CPUにて解析実行されるプログラムにて実現され、あるいは、ワイヤードロジックによるハードウェアとして実現され得る。
また、本実施形態において説明した各処理のうち、自動的に行われるものとして説明した処理の全部又は一部を手動的に行うこともでき、あるいは、手動的に行われるものとして説明した処理の全部又は一部を公知の方法で自動的に行うこともできる。この他、上記文書中や図面中で示した処理手順、制御手順、具体的名称、各種のデータやパラメータを含む情報については、特記する場合を除いて任意に変更することができる。
[プログラム]
一実施形態として、学習装置10及び軽量推定装置30は、パッケージソフトウェアやオンラインソフトウェアとして上記の学習処理又は推定処理を実行するプログラムを所望のコンピュータにインストールさせることによって実装できる。例えば、上記のプログラムを情報処理装置に実行させることにより、情報処理装置を学習装置10又は軽量推定装置30として機能させることができる。ここで言う情報処理装置には、デスクトップ型又はノート型のパーソナルコンピュータが含まれる。また、その他にも、情報処理装置にはスマートフォン、携帯電話機やPHS(Personal Handyphone System)等の移動体通信端末、さらには、PDA(Personal Digital Assistant)等のスレート端末等がその範疇に含まれる。
一実施形態として、学習装置10及び軽量推定装置30は、パッケージソフトウェアやオンラインソフトウェアとして上記の学習処理又は推定処理を実行するプログラムを所望のコンピュータにインストールさせることによって実装できる。例えば、上記のプログラムを情報処理装置に実行させることにより、情報処理装置を学習装置10又は軽量推定装置30として機能させることができる。ここで言う情報処理装置には、デスクトップ型又はノート型のパーソナルコンピュータが含まれる。また、その他にも、情報処理装置にはスマートフォン、携帯電話機やPHS(Personal Handyphone System)等の移動体通信端末、さらには、PDA(Personal Digital Assistant)等のスレート端末等がその範疇に含まれる。
また、学習装置10及び軽量推定装置30は、ユーザが使用する端末装置をクライアントとし、当該クライアントに上記の学習処理又は推定処理に関するサービスを提供するサーバ装置として実装することもできる。例えば、サーバ装置は、学習用のデータを入力とし、学習済みのモデルの情報を出力とするサービスを提供するサーバ装置として実装される。この場合、サーバ装置は、Webサーバとして実装することとしてもよいし、アウトソーシングによって上記の処理に関するサービスを提供するクラウドとして実装することとしてもかまわない。
図17は、学習プログラムを実行するコンピュータの一例を示す図である。なお、推定プログラムについても同様のコンピュータによって実行されてもよい。コンピュータ1000は、例えば、メモリ1010、プロセッサ1020を有する。また、コンピュータ1000は、ハードディスクドライブインタフェース1030、ディスクドライブインタフェース1040、シリアルポートインタフェース1050、ビデオアダプタ1060、ネットワークインタフェース1070を有する。これらの各部は、バス1080によって接続される。
メモリ1010は、ROM(Read Only Memory)1011及びRAM1012を含む。ROM1011は、例えば、BIOS(BASIC Input Output System)等のブートプログラムを記憶する。プロセッサ1020は、CPU1021及びGPU(Graphics Processing Unit)1022を含む。ハードディスクドライブインタフェース1030は、ハードディスクドライブ1090に接続される。ディスクドライブインタフェース1040は、ディスクドライブ1100に接続される。例えば磁気ディスクや光ディスク等の着脱可能な記憶媒体が、ディスクドライブ1100に挿入される。シリアルポートインタフェース1050は、例えばマウス1110、キーボード1120に接続される。ビデオアダプタ1060は、例えばディスプレイ1130に接続される。
ハードディスクドライブ1090は、例えば、OS1091、アプリケーションプログラム1092、プログラムモジュール1093、プログラムデータ1094を記憶する。すなわち、学習装置10の各処理を規定するプログラムは、コンピュータにより実行可能なコードが記述されたプログラムモジュール1093として実装される。プログラムモジュール1093は、例えばハードディスクドライブ1090に記憶される。例えば、学習装置10における機能構成と同様の処理を実行するためのプログラムモジュール1093が、ハードディスクドライブ1090に記憶される。なお、ハードディスクドライブ1090は、SSDにより代替されてもよい。
また、上述した実施形態の処理で用いられる設定データは、プログラムデータ1094として、例えばメモリ1010やハードディスクドライブ1090に記憶される。そして、CPU1020は、メモリ1010やハードディスクドライブ1090に記憶されたプログラムモジュール1093やプログラムデータ1094を必要に応じてRAM1012に読み出して、上述した実施形態の処理を実行する。
なお、プログラムモジュール1093やプログラムデータ1094は、ハードディスクドライブ1090に記憶される場合に限らず、例えば着脱可能な記憶媒体に記憶され、ディスクドライブ1100等を介してCPU1020によって読み出されてもよい。あるいは、プログラムモジュール1093及びプログラムデータ1094は、ネットワーク(LAN(Local Area Network)、WAN(Wide Area Network)等)を介して接続された他のコンピュータに記憶されてもよい。そして、プログラムモジュール1093及びプログラムデータ1094は、他のコンピュータから、ネットワークインタフェース1070を介してCPU1020によって読み出されてもよい。
2 推定システム
2a 推定装置
10 学習装置
11 高精度モデル学習部
12 軽量モデル学習部
20 高精度推定装置
20a 高精度推定部
30 軽量推定装置
30a 軽量推定部
111、121、202、302 推定部
112、122 損失計算部
113、123 更新部
114、201 高精度モデル情報
124、301 軽量モデル情報
303 判定部
2a 推定装置
10 学習装置
11 高精度モデル学習部
12 軽量モデル学習部
20 高精度推定装置
20a 高精度推定部
30 軽量推定装置
30a 軽量推定部
111、121、202、302 推定部
112、122 損失計算部
113、123 更新部
114、201 高精度モデル情報
124、301 軽量モデル情報
303 判定部
Claims (8)
- 入力されたデータを基に推定結果を出力する第1のモデルに学習用データを入力し、第1の推定結果を取得する推定部と、
前記第1の推定結果と、入力されたデータを基に推定結果を出力するモデルであって、前記第1のモデルよりも処理速度が遅い、又は前記第1のモデルよりも推定精度が高い第2のモデルに前記学習用データを入力して得られた第2の推定結果と、を基に、前記第1のモデルと前記第2のモデルを含むモデルカスケードが最適化されるように、前記第1のモデルのパラメータを更新する更新部と、
を有することを特徴とする学習装置。 - 前記更新部は、
前記第1の推定結果における正解に対する確信度が小さいほど大きくなる第1の項と、前記第1の推定結果が不正解である場合に前記第1の推定結果の確信度が大きいほど大きくなる第2の項と、前記第2の推定結果が不正解である場合に前記第1の推定結果の確信度が小さいほど大きくなる第3の項と、前記第1の推定結果の確信度が小さいほど大きくなる第4の項と、を含む損失関数を基に計算される損失が最適化されるように、前記第1のモデルのパラメータを更新することを特徴とする請求項1に記載の学習装置。 - 学習装置によって実行される学習方法であって、
入力されたデータを基に推定結果を出力する第1のモデルに学習用データを入力し、第1の推定結果を取得する推定工程と、
前記第1の推定結果と、入力されたデータを基に推定結果を出力するモデルであって、前記第1のモデルよりも処理速度が遅い、又は前記第1のモデルよりも推定精度が高い第2のモデルに前記学習用データを入力して得られた第2の推定結果と、を基に、前記第1のモデルと前記第2のモデルを含むモデルカスケードが最適化されるように、前記第1のモデルのパラメータを更新する更新工程と、
を含むことを特徴とする学習方法。 - コンピュータを、請求項1又は2に記載の学習装置として機能させるための学習プログラム。
- 入力されたデータを基に推定結果を出力する第1のモデルに学習用データを入力して得られた推定結果と、入力されたデータを基に推定結果を出力するモデルであって、前記第1のモデルよりも処理速度が遅い、又は前記第1のモデルよりも推定精度が高い第2のモデルに前記学習用データを入力して得られた推定結果と、を基に、前記第1のモデルと前記第2のモデルを含むモデルカスケードが最適化されるように予め学習されたパラメータが設定された前記第1のモデルに、推定用のデータを入力して第1の推定結果を取得する第1の推定部と、
前記第1の推定結果が、推定精度に関する所定の条件を満たすか否かを判定する判定部と、
を有することを特徴とする推定装置。 - 入力されたデータを基に推定結果を出力する第1のモデルに学習用データを入力して得られた推定結果と、入力されたデータを基に推定結果を出力するモデルであって、前記第1のモデルよりも処理速度が遅い、又は前記第1のモデルよりも推定精度が高い第2のモデルに前記学習用データを入力して得られた推定結果と、を基に、前記第1のモデルと前記第2のモデルを含むモデルカスケードが最適化されるように予め学習されたパラメータが設定された前記第1のモデルに、他の推定装置が推定用のデータを入力して取得する第1の推定結果に応じて、前記推定用のデータを前記第2のモデルに入力して第2の推定結果を取得する第2の推定部を有することを特徴とする推定装置。
- 推定装置によって実行される推定方法であって、
入力されたデータを基に推定結果を出力する第1のモデルに学習用データを入力して得られた推定結果と、入力されたデータを基に推定結果を出力するモデルであって、前記第1のモデルよりも処理速度が遅い、又は前記第1のモデルよりも推定精度が高い第2のモデルに前記学習用データを入力して得られた推定結果と、を基に、前記第1のモデルと前記第2のモデルを含むモデルカスケードが最適化されるように予め学習されたパラメータが設定された前記第1のモデルに、推定用のデータを入力して第1の推定結果を取得する第1の推定工程と、
前記第1の推定結果が、推定精度に関する所定の条件を満たすか否かを判定する判定工程と、
を含むことを特徴とする推定方法。 - コンピュータを、請求項5又は6に記載の推定装置として機能させるための推定プログラム。
Priority Applications (4)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
PCT/JP2020/009878 WO2021176734A1 (ja) | 2020-03-06 | 2020-03-06 | 学習装置、学習方法、学習プログラム、推定装置、推定方法及び推定プログラム |
JP2022504953A JP7447985B2 (ja) | 2020-03-06 | 2020-03-06 | 学習装置、学習方法、学習プログラム、推定装置、推定方法及び推定プログラム |
US17/801,272 US20230112076A1 (en) | 2020-03-06 | 2020-03-06 | Learning device, learning method, learning program, estimation device, estimation method, and estimation program |
JP2024029580A JP2024051136A (ja) | 2020-03-06 | 2024-02-29 | 学習装置、学習方法、学習プログラム、推定装置、推定方法及び推定プログラム |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
PCT/JP2020/009878 WO2021176734A1 (ja) | 2020-03-06 | 2020-03-06 | 学習装置、学習方法、学習プログラム、推定装置、推定方法及び推定プログラム |
Publications (1)
Publication Number | Publication Date |
---|---|
WO2021176734A1 true WO2021176734A1 (ja) | 2021-09-10 |
Family
ID=77614024
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
PCT/JP2020/009878 WO2021176734A1 (ja) | 2020-03-06 | 2020-03-06 | 学習装置、学習方法、学習プログラム、推定装置、推定方法及び推定プログラム |
Country Status (3)
Country | Link |
---|---|
US (1) | US20230112076A1 (ja) |
JP (2) | JP7447985B2 (ja) |
WO (1) | WO2021176734A1 (ja) |
Families Citing this family (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US20220156524A1 (en) * | 2020-11-16 | 2022-05-19 | Google Llc | Efficient Neural Networks via Ensembles and Cascades |
Citations (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
WO2018170175A1 (en) * | 2017-03-15 | 2018-09-20 | Salesforce.Com, Inc. | Probability-based guider |
-
2020
- 2020-03-06 WO PCT/JP2020/009878 patent/WO2021176734A1/ja active Application Filing
- 2020-03-06 JP JP2022504953A patent/JP7447985B2/ja active Active
- 2020-03-06 US US17/801,272 patent/US20230112076A1/en active Pending
-
2024
- 2024-02-29 JP JP2024029580A patent/JP2024051136A/ja active Pending
Patent Citations (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
WO2018170175A1 (en) * | 2017-03-15 | 2018-09-20 | Salesforce.Com, Inc. | Probability-based guider |
Non-Patent Citations (1)
Title |
---|
HANYA, YOSHINORI: "Distillation of Knowledge in Deep Learning", CODE CRAFT HOUSE, 12 January 2018 (2018-01-12), pages 1 - 17, XP055852891, Retrieved from the Internet <URL:http://codecrafthouse.jp/p/2018/01/knowledge-distillation> [retrieved on 20200717] * |
Also Published As
Publication number | Publication date |
---|---|
JP2024051136A (ja) | 2024-04-10 |
JPWO2021176734A1 (ja) | 2021-09-10 |
JP7447985B2 (ja) | 2024-03-12 |
US20230112076A1 (en) | 2023-04-13 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
US10936949B2 (en) | Training machine learning models using task selection policies to increase learning progress | |
US10546066B2 (en) | End-to-end learning of dialogue agents for information access | |
US11449744B2 (en) | End-to-end memory networks for contextual language understanding | |
US11941527B2 (en) | Population based training of neural networks | |
KR20190068255A (ko) | 고정 소수점 뉴럴 네트워크를 생성하는 방법 및 장치 | |
WO2022110640A1 (zh) | 一种模型优化方法、装置、计算机设备及存储介质 | |
JP7293504B2 (ja) | 強化学習を用いたデータ評価 | |
CN111406264A (zh) | 神经架构搜索 | |
US10783452B2 (en) | Learning apparatus and method for learning a model corresponding to a function changing in time series | |
US11449731B2 (en) | Update of attenuation coefficient for a model corresponding to time-series input data | |
US11914672B2 (en) | Method of neural architecture search using continuous action reinforcement learning | |
CN111598253A (zh) | 使用教师退火来训练机器学习模型 | |
JP2024051136A (ja) | 学習装置、学習方法、学習プログラム、推定装置、推定方法及び推定プログラム | |
WO2022068934A1 (en) | Method of neural architecture search using continuous action reinforcement learning | |
CN111008689A (zh) | 使用softmax近似来减少神经网络推理时间 | |
US20220391765A1 (en) | Systems and Methods for Semi-Supervised Active Learning | |
JP6928346B2 (ja) | 予測装置、予測方法および予測プログラム | |
CN114970732A (zh) | 分类模型的后验校准方法、装置、计算机设备及介质 | |
WO2022070343A1 (ja) | 学習装置、学習方法及び学習プログラム | |
WO2022070342A1 (ja) | 学習装置、学習方法及び学習プログラム | |
JP7571878B2 (ja) | 学習装置、学習方法及び学習プログラム | |
US20230127832A1 (en) | Bnn training with mini-batch particle flow | |
JP7557438B2 (ja) | 自然言語処理モデル取得装置、自然言語処理装置、自然言語処理モデル取得方法、自然言語処理方法及びプログラム | |
US20240338595A1 (en) | Testing membership in distributional simplex | |
US20230368015A1 (en) | Entropy-based anti-modeling for machine learning applications |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
121 | Ep: the epo has been informed by wipo that ep was designated in this application |
Ref document number: 20923008 Country of ref document: EP Kind code of ref document: A1 |
|
ENP | Entry into the national phase |
Ref document number: 2022504953 Country of ref document: JP Kind code of ref document: A |
|
NENP | Non-entry into the national phase |
Ref country code: DE |
|
122 | Ep: pct application non-entry in european phase |
Ref document number: 20923008 Country of ref document: EP Kind code of ref document: A1 |