ナード戦隊データマン

データサイエンスを用いて悪と戦うぞ

モデルの評価と汎化性能: 交差検証, グリッドサーチ, ROC, AUC

汎化性能を検証するために訓練データとテストデータに分ける方法は、訓練データをどのように選ぶかによって結果が左右されてしまいます。そのような場合に、交差検証を行うことができます。しかし、偏りのあるデータを評価する際には、標準の方法では不十分な場合があり、ROC曲線やAUCによる評価のほうが良い場合があります。ここでは、汎化性能を評価し、高める方法を書きます。

k分割交差検証

k分割交差検証では、データをk分割し、kのうちの1つをテストデータ、k-1個を訓練データとします。テストデータを選ぶ組合せはk個あるので、k回精度分析を行います。

注意すべきなのは、分割によってデータの偏りが生まれることです。目的変数が[0,0,0,1,1,1,2,2,2]のデータがあるとき、単純に3分割すると0だけからなるテストデータ、1だけからなるテストデータ、2だけからなるテストデータとなり、スコアが0になってしまいます。

これを避けるには、層化k分割交差検証を用いたり、あるいは単にデータを事前にシャッフルします。普通のk分割交差検証と層化k分割交差検証の違いを表したものが以下です。

f:id:mathgeekjp:20170914152344p:plain

pythonのsklearnでは、cross_val_score関数を用意しています。Kfold分類オブジェクトと組み合わせることで、詳細な制御が可能です。

from sklearn.model_selection import cross_val_score, KFold, StratifiedKFold
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression

iris = load_iris()
clf = LogisticRegression()

kfolds = []
kfolds.append(KFold(n_splits=3)) #ダメな例
kfolds.append(KFold(n_splits=3, shuffle=True, random_state=0))
kfolds.append(StratifiedKFold(n_splits=3))

for kfold in kfolds:
    print("{}".format(cross_val_score(clf, iris.data, iris.target, cv=kfold)))
[ 0.  0.  0.]
[ 0.9   0.96  0.96]
[ 0.96078431  0.92156863  0.95833333]

最初の例は、irisデータを単純に3分割した例ですが、データに偏りがあるため、精度が0となっています。その次の例は、データをシャッフルしたもの、最後のものは層化したものです。

グリッドサーチ

モデルの汎化性能を最大にするパラメータを選ぶ際に、単純なforループを回す方法もありますが、pythonではGridSearchCVを用意しています。SVMの例で見てみましょう。

from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.svm import SVC

param_grid = {'C':[0.001,0.01,0.1,1,10,100], 'gamma':[0.001,0.01,0.1,1,10,100]}
X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, random_state=0)
grid_search = GridSearchCV(SVC(), param_grid, cv=5).fit(X_train, y_train)

print(grid_search.score(X_test, y_test))
print("best_params:{}".format(grid_search.best_params_))
0.973684210526
best_params:{'C': 100, 'gamma': 0.01}

grid_searchの内部では交差検証を行っていて、cvはkの値を指定します。best_params_には汎化性能を最大化するパラメータが格納されています。より詳細なレポートを見たい場合は以下を実行してください。

import pandas as pd
display(pd.DataFrame(grid_search.cv_results_))

また、「ネストした交差検証」を用いることもできます。上記では、train_test_splitで分割していますが、これをcross_val_scoreにします。

print(cross_val_score(grid_search, iris.data, iris.target, cv=5))
[ 0.96666667  1.          0.96666667  0.96666667  1.        ]

評価基準の変更

cross_val_scoreやscoreは、標準では分類精度または回帰ではR2を用いています。しかし、データの偏りがある場合、これらの基準が適していないかもしれません。例えば、陽性データが全体の99%を占めている場合、「陽性を選ぶだけのモデル」を作ったとしても、精度は99%になります。こんなモデルは無意味なのに、高い評価を与えているような印象を抱きます。

ROC曲線は、偽陽性率に対して真陽性率をプロットしたグラフです。最も悪いモデルは偽陽性率=真陽性率となっている直線です。最も良いモデルは、左上に寄っています。一方、AUCとはROC曲線の下の面積です。つまり、偏ったデータでは、R2や精度値を用いるより、AUCを用いたほうがよいといえます。グリッドサーチでAUCを用いる場合は、scoringパラメータに"roc_auc"を指定します。

from sklearn.datasets import load_digits
from sklearn.dummy import DummyClassifier
from sklearn.metrics import roc_curve, roc_auc_score
import matplotlib.pyplot as plt

digits = load_digits()
y = digits.target == 9

X_train, X_test, y_train, y_test = train_test_split(digits.data, y, random_state=0)

dummy_majority = DummyClassifier(strategy='most_frequent').fit(X_train, y_train)
grid_search_2 = GridSearchCV(SVC(), param_grid, cv=5, scoring="roc_auc").fit(X_train, y_train)

print("Dummy not AUC:{}".format(dummy_majority.score(X_test, y_test)))
print("SVC not AUC:{}".format(grid_search_2.score(X_test, y_test)))
print("Dummy AUC:{}".format(roc_auc_score(y_test, dummy_majority.predict_proba(X_test)[:,1])))
print("SVC AUC:{}".format(roc_auc_score(y_test, grid_search_2.decision_function(X_test))))

fpr_grid, tpr_grid, thresholds_grid = roc_curve(y_test, grid_search_2.decision_function(X_test))
fpr_dummy, tpr_dummy, thresholds_dummy = roc_curve(y_test, dummy_majority.predict_proba(X_test)[:,1])

plt.plot(fpr_grid, tpr_grid, label="SVC")
plt.plot(fpr_dummy, tpr_dummy, label="Dummy")
plt.xlabel("FPR")
plt.ylabel("TPR")
plt.legend()
Dummy not AUC:0.8955555555555555
SVC not AUC:0.9993664537247241
Dummy AUC:0.5
SVC AUC:0.9993664537247241

f:id:mathgeekjp:20170914162558p:plain

このように、標準のスコアリング関数では評価することが難しいような偏ったデータでも、ROCを用いれば精度の違いを明らかにすることができます。

参考

1.3.3. Model evaluation: quantifying the quality of predictions — scikit-learn 0.19.0 documentation

2.Introduction to Machine Learning with Python - O'Reilly Media