ナード戦隊データマン

データサイエンスを用いて悪と戦うぞ

カーネル法を用いたSVMを理解する

カーネル法を用いたSVMは、線形サポートベクターマシンを拡張したものです。ここでは、線形SVMを用いた例から初め、カーネル法を理解していきます。

線形モデル

低次元での線形モデルは、直線や平面によって分類するため、制約が強いものとなります。線形モデルに「非線形特徴量」を追加すると、制約を超えることができる場合があります。

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import mglearn
from IPython.display import display
from sklearn.datasets import make_blobs
from sklearn.svm import LinearSVC


X, y = make_blobs(centers=4, random_state=8)
y = y % 2

linear_svm = LinearSVC().fit(X, y)

mglearn.plots.plot_2d_separator(linear_svm, X)
mglearn.discrete_scatter(X[:,0], X[:,1], y)
plt.xlabel("Feature 0")
plt.ylabel("Feature 1")

f:id:mathgeekjp:20170903193936p:plain

直線で分類しようとしても、上記データセットをうまく分類することができません。

ここで、feature1の二乗という非線形特徴量を追加したらどうなるか見てみます。

from mpl_toolkits.mplot3d import Axes3D, axes3d

X_new = np.hstack([X, X[:,1:]**2])
mask = y == 0

linear_svm_3d = LinearSVC().fit(X_new, y)
coef, intercept = linear_svm_3d.coef_.ravel(), linear_svm_3d.intercept_

figure = plt.figure()
ax = Axes3D(figure, elev=-152, azim=-26)
xx = np.linspace(X_new[:, 0].min() -2, X_new[:, 0].max()+2, 50)
yy = np.linspace(X_new[:, 1].min() -2, X_new[:, 1].max()+2, 50)

XX, YY = np.meshgrid(xx,yy)
ZZ = (coef[0]*XX + coef[1]*YY + intercept)/ -coef[2]

ax.plot_surface(XX, YY, ZZ, rstride=8, cstride=8, alpha=0.3)
ax.scatter(X_new[mask, 0], X_new[mask, 1], X_new[mask, 2], c='b', cmap=mglearn.cm2, s=60)
ax.scatter(X_new[~mask, 0], X_new[~mask, 1], X_new[~mask, 2], c='r', marker='^', cmap=mglearn.cm2, s=60)
ax.set_xlabel("feature0")
ax.set_ylabel("feature1")
ax.set_zlabel("feature1 **2")

f:id:mathgeekjp:20170903194357p:plain

このように、平面による分類が可能になります。

カーネル

非線形特徴量を追加すると、線形モデルが柔軟になることがわかりました。しかし、どのような特徴量を追加したら良いのでしょうか。特徴量が1000次元だとして、すべての可能な積を加えるなどしたら、計算量は大きくなってしまいます。

実際には、高次元空間での分類器を学習する方法としてカーネルトリックを用います。方法としては、多項式カーネルやガウシアンカーネル(RBF)を用います。ただし、SVMにおける数学的背景は複雑なので、詳細は割愛します。データポイント間の距離を測るガウシアンカーネルは以下で与えられます。

f:id:mathgeekjp:20170903200309p:plain

SVMでは、個々のデータポイントが2つのクラスの決定境界に対してどの程度重要かを学習し、データポイントの予測の際に距離を測定します。

rbfを使ってみる

SVMのパラメータとして、Cとgammaがあります。Cは正則化パラメータで、gammaはrbfの直径の大きさです。パラメータ調整をして確かめてみます。

from sklearn.svm import SVC
from sklearn.model_selection import train_test_split

X, y = mglearn.tools.make_handcrafted_dataset()
X_train, X_test, y_train, y_test = train_test_split(X,y, test_size=0.33)

fig, axes = plt.subplots(3, 3, figsize=(15, 10))

for ax, C in zip(axes, [-1, 0, 3]):
    for a, gamma in zip(ax, range(-1, 2)):
        mglearn.plots.plot_svm(log_C=C, log_gamma=gamma, ax=a)
        
axes[0,0].legend(["class 0", "class 1", "sv class 0", "sv class 1"], ncol=4, loc=(.9, 1.2))

for C in [0.1, 1, 1000]:
    for gamma in [0.1, 1, 10]:
            svc = SVC(kernel='rbf', C=C, gamma=gamma)
            svc.fit(X_train, y_train)
            print("C={}, gamma={}".format(C, gamma))
            print("Accuracy on training set: {:.2f}".format(svc.score(X_train, y_train)))
            print("Accuracy on test set: {:.2f}".format(svc.score(X_test, y_test)))
            print("\n")

f:id:mathgeekjp:20170903200006p:plain

C=0.1, gamma=0.1
Accuracy on training set: 0.53
Accuracy on test set: 0.44


C=0.1, gamma=1
Accuracy on training set: 0.53
Accuracy on test set: 0.44


C=0.1, gamma=10
Accuracy on training set: 0.53
Accuracy on test set: 0.44


C=1, gamma=0.1
Accuracy on training set: 0.94
Accuracy on test set: 0.89


C=1, gamma=1
Accuracy on training set: 0.94
Accuracy on test set: 0.89


C=1, gamma=10
Accuracy on training set: 1.00
Accuracy on test set: 0.67


C=1000, gamma=0.1
Accuracy on training set: 1.00
Accuracy on test set: 0.89


C=1000, gamma=1
Accuracy on training set: 1.00
Accuracy on test set: 0.89


C=1000, gamma=10
Accuracy on training set: 1.00
Accuracy on test set: 0.78

SVMはパラメータの設定に強く影響をうけていることがわかります。実際には、特徴量のスケール変換により、スケールが同じぐらいになるようにすると良いかもしれません。スケール変換は、それぞれの特徴量の最小値をデータから引き、最大のレンジで割ることで0〜1の範囲に収めます。

まとめ

SVMを用いると、特徴量がわずかでも、複雑な決定境界を生成できます。高次元でも機能しますが、サンプルが多いとコンピュータの性能上の困難があります。

rbfを用いたSVMでは、Cとgammmaをパラメータとしました。どちらもモデルの複雑さを制御しますが、大きくするとモデルは複雑になります。

参考

  1. calculatedcontent.com

  2. github.com

  3. github.com

  4. shop.oreilly.com