さえめろ の めも🐰

さえめろの備忘録です。twitter : @sae_mero_

MNISTのスケーリングは255.0で割らないほうがいいらしい

MNIST

今回の事件の舞台。
多分この記事を読んでいる方はMNISTについての説明なんていらないとは思うのですが、一応ざっくりとだけ説明。

MNISTは、0〜9までの手書き文字の画像からなるデータセットです。
そしてそれぞれの画像に対し、そこに書かれている数字をラベルとして持っています。

sklearnではfrom sklearn.datasets import fetch_mldata で使えるようになります。
(今回もsklearnを使用して分類しています)

画像は28px × 28px のモノクロデータであり、各ピクセルが0~255の値を持ちます。
これを784(=28*28)のベクトルとして考え、分類します。


MNISTのスケーリング

今回はSVMを使用しました。
SVMはチューニングが大切な分類器ですが、とりあえずデフォルトで何もせずに分類させるとこうなった。

from sklearn.utils import shuffle
from sklearn.datasets import fetch_mldata
from sklearn.metrics import f1_score
from sklearn.cross_validation import train_test_split
from sklearn import svm
from sklearn.svm import SVC
import numpy as np


mnist = fetch_mldata('MNIST original')
mnist_X,mnist_y = shuffle(mnist.data.astype('float32'),
                         mnist.target.astype('int32'),
                         random_state = 42)

train_X,test_X,train_y,test_y = train_test_split(mnist_X,mnist_y,
                                                test_size = 0.2,
                                                random_state = 43)


# 全部やると重いのでデータを10000に制限しました
clf = svm.SVC()
clf = clf.fit(train_X[:10000],train_y[:10000])

pred_y = clf.predict(test_X[:100])
f1_score(test_y[:len(pred_y[:100])], pred_y, average='macro')


これで精度は約1.8%。
生ゴミのようなものが出来上がりました。

それもそのはずで、MNISTの0〜255の数値では範囲が広すぎて、SVMによる座標上の自然な分割が行われません。

なのでmnist_Xをもっと小さい範囲にスケーリングしてから訓練させることで、劇的に精度が上がります。
そして0〜1の範囲にスケーリングするために、mnist_Xを255.0で除算するのが一般的なようです。

試しに

mnist_X = mnist_X / 255.0

を追加することで、精度は約91.5%まで上がります。劇的。

で、ここまではまあ理解のできる話です。


謎のスケーリング

先ほど mnist_X /= 255.0 と追加し、その前にも255.0で割るのが一般的などと書きました。
が、どうやらそれは最適ではないらしい。
SVMの精度を上げるためにグリッドサーチなどで戦っていたのですが、mnist_Xに対するスケーリングを

mnist_X = mnist_X / 100.0

と何気なく少なくしてみたところ、精度が95%まで上昇しました。
その差4%。でかい。

その後しばらく検証してみたところ、trainとtestがどういうわけかたであってもほぼ確実に(random_stateを0~100まで変更した限りでは100%)精度が向上する結果となりました。

より平均的な精度の高くなる数字は検証はしていませんが、100くらいで割ると精度が上がると言えます。
ちなみに半数の127.5では2%くらいの伸びでした。

多分SVMさんがとても分けやすいように並ぶのかな〜とは思うのですが、何しろ数値だけの上昇であり根拠がわかりません。あとものすごい検証不足です。
どなたかこんな感じじゃないかな〜とかでもいいので、わかる方いらっしゃいましたらご教授くださると幸いです。



追記

ちなみに、現在MLP(これはsklearnではなく手打ちです)でもMNISTを分類していますが、これもやっぱり100くらいで割ると精度が高いです。