最小二乗法(least squares method)
前回は線形回帰とは、y = ax + b のa、bを求めることで未知のxに対し、yを予想できるようになることと書きましたが、そのa、bを求めるための方法が最小二乗法になります。
最小二乗法について勉強したことを以下に纏めます。
独学で勉強しただけなので、書いてあることが誤っていることがあるかもしれません。
なので書いてあることが絶対正しいと思わないで下さい。
y = ax + b を求めるには
サンプルとなる点が2つの場合、それぞれの点を通るy = ax + b を求めればいいので、単純な方程式で解けます。(例)サンプルが2つの場合
サンプル | y | x |
---|---|---|
データA | 2 | 1 |
データB | 3 | 2 |
それぞれをy = ax + b の x, yに代入。
2 = a * 1 + b ・・・ ①
3 = a * 2 + b ・・・ ①
上記の式を解くと、a = 1、b = 1 となるので y = ax + b は y = x + 1 になる
サンプルが増えると
サンプルが増えると上記のように単純にすべての点を通る式を求められなくなります。すべての点を通れないので、すべての点とy = ax + b を出来るだけ近づける(誤差を少なくする)ように a、b を調整します。
(例)サンプルが3つの場合
サンプル | y | x |
---|---|---|
データA | 2 | 1 |
データB | 3 | 2 |
データC | 3.5 | 3 |
①a = 1、b = 1の誤差の合計
a = 1、b = 1とした場合、y = 1 * x + 1 となります。実際のyと計算した値の差(誤差)は以下のようになります。
サンプル | y | x | 計算値(1 * x + 1) | 誤差(y-計算値) |
---|---|---|---|---|
データA | 2 | 1 | 2 | 0 |
データB | 3 | 2 | 3 | 0 |
データC | 3.5 | 3 | 4 | -0.5 |
誤差の合計を求めるので、①の誤差の合計は0 + 0 + -0.5 = -0.5になります。
②a = 1、b = 0.5の誤差の合計
次に、a = 1、b = 0.5とした場合、y = 1 * x + 0.5 となります。
実際のyと計算した値の差(誤差)は以下のようになります。
サンプル | y | x | 計算値(1 * x + 0.5) | 誤差(y-計算値) |
---|---|---|---|---|
データA | 2 | 1 | 1.5 | 0.5 |
データB | 3 | 2 | 2.5 | 0.5 |
データC | 3.5 | 3 | 3.5 | 0 |
誤差の合計を求めるので、②の誤差の合計は0.5 + 0.5 + 0 = 1になります。
誤差は0に近い方が誤差が小さいと言えるので、①の結果(-0.5)と②の結果(1)を比べると①の方が誤差が小さいと言えます。
③a = 0.6、b = 1.5の誤差の合計
次に、a = 0.6、b = 1.5とした場合、y = 0.6 * x + 1.5 となります。
実際のyと計算した値の差(誤差)は以下のようになります。
サンプル | y | x | 計算値(0.6 * x + 1.5) | 誤差(y-計算値) |
---|---|---|---|---|
データA | 2 | 1 | 2.1 | -0.1 |
データB | 3 | 2 | 2.7 | 0.3 |
データC | 3.5 | 3 | 3.3 | 0.2 |
③の誤差の合計を求めようとした場合、単純に足すと -0.1 + 0.3 + 0.2 = 0.4 となりますが、
0に近いほど誤差がないと言えるので-0.1は誤差が減るのではなく、誤差が増えるように計算されないといけません。
なので、誤差の合計としては、0.1 + 0.3 + 0.2 = 0.6 とすべきです。
誤差の一つ一つを絶対値にすれば正しく計算できますが、絶対値にしなくても2乗することで符号の問題は解決できます。
サンプル | y | x | 計算値(0.6 * x + 1.5) | 誤差(y-計算値) | 誤差の2乗 |
---|---|---|---|---|---|
データA | 2 | 1 | 2.1 | -0.1 | 0.01 |
データB | 3 | 2 | 2.7 | 0.3 | 0.09 |
データC | 3.5 | 3 | 3.3 | 0.2 | 0.04 |
誤差の2乗の合計は0.14。
このように誤差を2乗してy = ax + b のa、bを求めることを最小二乗法(least squares method)といいます。
最小二乗法
上記までやってきたことを数式で表すと以下になります。求めたい式(a、bを\( \theta_1、\theta_0 \)に変更しています)
\(y \fallingdotseq h_\theta (x) = \theta_0 + \theta_1 * x \)
→ \(\theta_0、\theta_1\)を求めたい。\(\theta_0、\theta_1\)が求まれば未知のxに対するyがわかるようになります。
上記の\(\theta_0、\theta_1\)を求めるため、予測値と実測値の差を表す式、誤差合計(最小二乗法)は以下のようになります。
\(J(\theta_0, \theta_1) = \frac{1}{2m}\sum_{i=1}^m(h_\theta (x_i) - y_i)^2 \)
→ \(J(\theta_0, \theta_1)\)が誤差の2乗の合計になります。これを出来るだけ小さくなるように\(\theta_0、 \theta_1\)を調整します 。
シグマ記号
\( \sum_{i=1}^m \)はシグマ記号というもので下のi=1からiが上のmになるまで、右にある計算式を計算し、その計算結果を合算していくというものです。
例えば以下の場合、
\( \sum_{i=1}^3 i*2\)
以下の計算をします。
int sum = 0;
for(int i = 1;i <= 3;i++){
sum += i*2;
}
実際に \(\theta_0、\theta_1\) を求めるには次回の最急降下法を使います。
ページのトップへ戻る