ナード戦隊データマン

Fighting evil with data science

JSON serialization and deserialization of sklearn models

scikit-learn models are usually saved with pickle, but from a portability standpoint that is not an ideal approach. Here, we serialize models to JSON to make them more portable.
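
For reference, the usual pickle-based persistence looks roughly like this (a minimal sketch; the file name model.pkl is arbitrary). The resulting file is a Python-specific binary format tied to the pickled class layout, which is exactly the portability problem addressed below.

import pickle
from sklearn.datasets import load_iris
from sklearn.svm import SVC

d = load_iris()
clf = SVC().fit(d['data'], d['target'])

# Binary, Python-only format; other languages (and sometimes other
# scikit-learn versions) cannot read it back.
with open("model.pkl", "wb") as f:
    pickle.dump(clf, f)

with open("model.pkl", "rb") as f:
    clf2 = pickle.load(f)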

Code

The code is messy, so please bear with me.

from functools import reduce
import operator
import numpy as np
import json

def getFromDict(dataDict, mapList):
    # Walk a nested dict along a list of keys, e.g. getFromDict(d, ["a", "b"]) -> d["a"]["b"]
    return reduce(operator.getitem, mapList, dataDict)

def setInDict(dataDict, mapList, value):
    # Set a value in a nested dict at the path given by mapList
    getFromDict(dataDict, mapList[:-1])[mapList[-1]] = value
    
def model2dict(model):
    # Attribute types that can be written to JSON as-is
    types = [dict, str, int, float, bool, type(None)]
    # Breadth-first traversal over the model's (possibly nested) attributes
    stack = [{"cl": model, "key": []}]
    out = {"_modulename": model.__class__.__module__, "_classname": model.__class__.__name__}
    while stack:
        t = stack.pop(0)
        for x in dir(t['cl']):
            try:
                val = getattr(t["cl"], x)
            except AttributeError:
                continue
            try:
                # Skip methods, dunder attributes and abc internals
                if callable(val) or x.startswith('__') or x.startswith("_abc"):
                    continue
                if type(val) in types:
                    setInDict(out, t['key'] + [x], val)
                elif "numpy" in str(type(val)):
                    # numpy arrays/scalars are stored as a plain list plus their dtype
                    setInDict(out, t['key'] + [x], {"numpy": val.tolist(), "dtype": val.dtype.name})
                elif isinstance(val, (list, tuple)):
                    setInDict(out, t['key'] + [x], {})
                    if type(val[0]) in types:
                        setInDict(out, t['key'] + [x], val)
                    else:
                        for i, v in enumerate(val):
                            if "numpy" in str(type(v)):
                                setInDict(out, t['key'] + [x] + [i], {"numpy": v.tolist(), "dtype": v.dtype.name})
                            else:
                                # Nested object: record its class and traverse it later
                                setInDict(out, t['key'] + [x] + [i], {"_modulename": v.__class__.__module__, "_classname": v.__class__.__name__})
                                stack.append({"cl": v, "key": t['key'] + [x] + [i]})
                else:
                    # Any other object: record its class and traverse it later
                    setInDict(out, t['key'] + [x], {"_modulename": val.__class__.__module__, "_classname": val.__class__.__name__})
                    stack.append({"cl": val, "key": t['key'] + [x]})
            except Exception:
                # On any other error, give up on the remaining attributes of this object
                break
    return out


def dict2model(data):
    import importlib
    # Re-create an (unfitted) instance of the original class, then restore its attributes
    module = importlib.import_module(data['_modulename'])
    try:
        model = getattr(module, data['_classname'])()
    except Exception:
        raise Exception("The model isn't supported yet.", data['_modulename'])
    for k, v in data.items():
        if k == "_classname" or k == "_modulename":
            continue
        if isinstance(v, dict) and "_classname" in v:
            # Nested estimator
            setattr(model, k, dict2model(v))
        elif isinstance(v, dict) and "numpy" in v:
            # numpy array stored as {"numpy": <list>, "dtype": <dtype name>}
            try:
                setattr(model, k, np.array(v["numpy"], dtype=v["dtype"]))
            except Exception:
                pass
        elif isinstance(v, dict) and (0 in v or "0" in v):
            # List of nested objects (integer keys, or string keys after a JSON round trip)
            lst = []
            for ind, d in v.items():
                if "_classname" in d:
                    lst.append(dict2model(d))
                else:
                    lst.append(d)
            setattr(model, k, lst)
        else:
            try:
                setattr(model, k, v)
            except Exception:
                pass
    return model


if __name__ == "__main__":
    from pprint import pprint, pformat
    from sklearn.datasets import load_iris
    from sklearn.linear_model import LogisticRegression
    from sklearn.svm import SVC
    from sklearn.naive_bayes import GaussianNB
    
    d = load_iris()
    X,y = d['data'], d['target']
    #clf = LogisticRegression().fit(X,y)
    clf = SVC().fit(X,y)
    #clf = GaussianNB().fit(X, y)
    data = model2dict(clf)
    print(json.dumps(data, indent=2, sort_keys=True))
    model = dict2model(data) 
    ypred = model.predict(X)
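
As a usage sketch (not part of the original listing), the dictionary can also be round-tripped through an actual JSON file; the file name model.json is arbitrary, and model2dict/dict2model from above are assumed to be in scope:

# Write the dictionary to disk, then restore the model from the file
with open("model.json", "w") as f:
    json.dump(model2dict(clf), f, indent=2, sort_keys=True)

with open("model.json") as f:
    restored = dict2model(json.load(f))

# For the SVC example above this should print True
print((restored.predict(X) == clf.predict(X)).all())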

The JSON output for an SVC trained on the iris data is shown below:

{
  "C": 1.0,
  "_classname": "SVC",
  "_dual_coef_": {
    "dtype": "float64",
    "numpy": [
      [
        0.0,
        0.4386244104464341,
        0.06663115831281492,
        0.002277020749968409,
        0.35843112572827146,
        1.0,
        0.4588582343970687,
        -0.5517100820489875,
        -0.0,
        -0.0,
        -0.0,
        -0.5428293066095111,
        -0.0,
        -0.0,
        -0.0,
        -0.0,
        -0.0,
        -0.0,
        -0.0,
        -0.1664909718151888,
        -0.0,
        -0.06379158916087012,
        -0.0,
        -0.0,
        -0.0,
        -1.0,
        -0.11757608131084718,
        -0.7337655049344587,
        -0.0,
        -0.4272359406177839,
        -0.09993001758512109,
        -0.0,
        -0.0,
        -0.0,
        -0.0,
        -0.0,
        -0.4053395104913362,
        -0.0,
        -0.0,
        -0.0,
        -0.16140841734798533,
        -0.0,
        -0.0,
        -0.0,
        -0.0
      ],
      [
        0.009368831960995326,
        0.9556474901217844,
        0.0,
        0.0,
        0.011617305692425932,
        0.9686218445123266,
        0.0,
        0.0,
        1.0,
        1.0,
        1.0,
        0.0,
        0.40434216978125864,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        0.27391516456770126,
        1.0,
        0.0,
        -0.0,
        -1.0,
        -1.0,
        -0.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -0.8300343547007096,
        -0.13092148302187717,
        -1.0,
        -0.6237238841373084,
        -1.0,
        -0.7730127984352919,
        -0.32056481405377285,
        -1.0,
        -1.0,
        -1.0
      ]
    ]
  },
  "_estimator_type": "classifier",
  "_gamma": 0.25,
  "_impl": "c_svc",
  "_intercept_": {
    "dtype": "float64",
    "numpy": [
      -0.03985691051205832,
      -0.1677745320943027,
      -0.143704687543137
    ]
  },
  "_modulename": "sklearn.svm.classes",
  "_pairwise": false,
  "_sparse": false,
  "_sparse_kernels": [
    "linear",
    "poly",
    "rbf",
    "sigmoid",
    "precomputed"
  ],
  "cache_size": 200,
  "class_weight": null,
  "class_weight_": {
    "dtype": "float64",
    "numpy": [
      1.0,
      1.0,
      1.0
    ]
  },
  "classes_": {
    "dtype": "int64",
    "numpy": [
      0,
      1,
      2
    ]
  },
  "coef0": 0.0,
  "decision_function_shape": "ovr",
  "degree": 3,
  "dual_coef_": {
    "dtype": "float64",
    "numpy": [
      [
        0.0,
        0.4386244104464341,
        0.06663115831281492,
        0.002277020749968409,
        0.35843112572827146,
        1.0,
        0.4588582343970687,
        -0.5517100820489875,
        -0.0,
        -0.0,
        -0.0,
        -0.5428293066095111,
        -0.0,
        -0.0,
        -0.0,
        -0.0,
        -0.0,
        -0.0,
        -0.0,
        -0.1664909718151888,
        -0.0,
        -0.06379158916087012,
        -0.0,
        -0.0,
        -0.0,
        -1.0,
        -0.11757608131084718,
        -0.7337655049344587,
        -0.0,
        -0.4272359406177839,
        -0.09993001758512109,
        -0.0,
        -0.0,
        -0.0,
        -0.0,
        -0.0,
        -0.4053395104913362,
        -0.0,
        -0.0,
        -0.0,
        -0.16140841734798533,
        -0.0,
        -0.0,
        -0.0,
        -0.0
      ],
      [
        0.009368831960995326,
        0.9556474901217844,
        0.0,
        0.0,
        0.011617305692425932,
        0.9686218445123266,
        0.0,
        0.0,
        1.0,
        1.0,
        1.0,
        0.0,
        0.40434216978125864,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        1.0,
        0.27391516456770126,
        1.0,
        0.0,
        -0.0,
        -1.0,
        -1.0,
        -0.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -1.0,
        -0.8300343547007096,
        -0.13092148302187717,
        -1.0,
        -0.6237238841373084,
        -1.0,
        -0.7730127984352919,
        -0.32056481405377285,
        -1.0,
        -1.0,
        -1.0
      ]
    ]
  },
  "epsilon": 0.0,
  "fit_status_": 0,
  "gamma": "auto_deprecated",
  "intercept_": {
    "dtype": "float64",
    "numpy": [
      -0.03985691051205832,
      -0.1677745320943027,
      -0.143704687543137
    ]
  },
  "kernel": "rbf",
  "max_iter": -1,
  "n_support_": {
    "dtype": "int32",
    "numpy": [
      7,
      19,
      19
    ]
  },
  "nu": 0.0,
  "probA_": {
    "dtype": "float64",
    "numpy": []
  },
  "probB_": {
    "dtype": "float64",
    "numpy": []
  },
  "probability": false,
  "random_state": null,
  "shape_fit_": [
    150,
    4
  ],
  "shrinking": true,
  "support_": {
    "dtype": "int32",
    "numpy": [
      13,
      15,
      18,
      23,
      24,
      41,
      44,
      50,
      52,
      54,
      56,
      57,
      60,
      63,
      66,
      68,
      70,
      72,
      76,
      77,
      78,
      83,
      84,
      85,
      86,
      98,
      100,
      106,
      110,
      118,
      119,
      121,
      123,
      126,
      127,
      129,
      131,
      133,
      134,
      138,
      141,
      142,
      146,
      147,
      149
    ]
  },
  "support_vectors_": {
    "dtype": "float64",
    "numpy": [
      [
        4.3,
        3.0,
        1.1,
        0.1
      ],
      [
        5.7,
        4.4,
        1.5,
        0.4
      ],
      [
        5.7,
        3.8,
        1.7,
        0.3
      ],
      [
        5.1,
        3.3,
        1.7,
        0.5
      ],
      [
        4.8,
        3.4,
        1.9,
        0.2
      ],
      [
        4.5,
        2.3,
        1.3,
        0.3
      ],
      [
        5.1,
        3.8,
        1.9,
        0.4
      ],
      [
        7.0,
        3.2,
        4.7,
        1.4
      ],
      [
        6.9,
        3.1,
        4.9,
        1.5
      ],
      [
        6.5,
        2.8,
        4.6,
        1.5
      ],
      [
        6.3,
        3.3,
        4.7,
        1.6
      ],
      [
        4.9,
        2.4,
        3.3,
        1.0
      ],
      [
        5.0,
        2.0,
        3.5,
        1.0
      ],
      [
        6.1,
        2.9,
        4.7,
        1.4
      ],
      [
        5.6,
        3.0,
        4.5,
        1.5
      ],
      [
        6.2,
        2.2,
        4.5,
        1.5
      ],
      [
        5.9,
        3.2,
        4.8,
        1.8
      ],
      [
        6.3,
        2.5,
        4.9,
        1.5
      ],
      [
        6.8,
        2.8,
        4.8,
        1.4
      ],
      [
        6.7,
        3.0,
        5.0,
        1.7
      ],
      [
        6.0,
        2.9,
        4.5,
        1.5
      ],
      [
        6.0,
        2.7,
        5.1,
        1.6
      ],
      [
        5.4,
        3.0,
        4.5,
        1.5
      ],
      [
        6.0,
        3.4,
        4.5,
        1.6
      ],
      [
        6.7,
        3.1,
        4.7,
        1.5
      ],
      [
        5.1,
        2.5,
        3.0,
        1.1
      ],
      [
        6.3,
        3.3,
        6.0,
        2.5
      ],
      [
        4.9,
        2.5,
        4.5,
        1.7
      ],
      [
        6.5,
        3.2,
        5.1,
        2.0
      ],
      [
        7.7,
        2.6,
        6.9,
        2.3
      ],
      [
        6.0,
        2.2,
        5.0,
        1.5
      ],
      [
        5.6,
        2.8,
        4.9,
        2.0
      ],
      [
        6.3,
        2.7,
        4.9,
        1.8
      ],
      [
        6.2,
        2.8,
        4.8,
        1.8
      ],
      [
        6.1,
        3.0,
        4.9,
        1.8
      ],
      [
        7.2,
        3.0,
        5.8,
        1.6
      ],
      [
        7.9,
        3.8,
        6.4,
        2.0
      ],
      [
        6.3,
        2.8,
        5.1,
        1.5
      ],
      [
        6.1,
        2.6,
        5.6,
        1.4
      ],
      [
        6.0,
        3.0,
        4.8,
        1.8
      ],
      [
        6.9,
        3.1,
        5.1,
        2.3
      ],
      [
        5.8,
        2.7,
        5.1,
        1.9
      ],
      [
        6.3,
        2.5,
        5.0,
        1.9
      ],
      [
        6.5,
        3.0,
        5.2,
        2.0
      ],
      [
        5.9,
        3.0,
        5.1,
        1.8
      ]
    ]
  },
  "tol": 0.001,
  "verbose": false
}

What this code could not handle

LogisticRegression, SVC, and GaussianNB could be handled, more or less, but tree-based methods proved difficult. The serialization itself succeeds, but loading the result fails.

One reason is that the Cython Tree class used by the tree-based methods must always be initialized with three constructor arguments; accounting for that would make the code considerably more specialized, as sketched below.
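
To illustrate (a minimal sketch that touches scikit-learn's private sklearn.tree._tree.Tree class directly, which is not part of the code above and whose signature may vary between versions):

import numpy as np
from sklearn.tree._tree import Tree

# The generic getattr(module, classname)() call in dict2model fails here,
# because Tree cannot be constructed without arguments:
try:
    Tree()
except TypeError as e:
    print(e)

# Construction requires n_features, the per-output class counts and n_outputs
# up front, e.g. for the iris data:
tree = Tree(4, np.array([3], dtype=np.intp), 1)

Even with the constructor handled, the fitted node arrays are exposed as read-only attributes and would have to be restored through the Tree's __getstate__/__setstate__ pickling interface, so the plain setattr approach used in dict2model does not carry over.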

Why bother with JSON serialization?

Advantages

  1. If a model can be represented in JSON, it can in principle be used from any programming language.
  2. You can inspect and understand what is inside the model.
  3. It helps avoid version-compatibility problems.
  4. The contents of the file remain meaningful more or less indefinitely.

Disadvantages

  1. Loading may take longer.
  2. You have to write the serialization and deserialization code yourself.