ナード戦隊データマン

機械学習と自然言語処理についてのブログ

doccanoでAutoLabelingっぽいものを自前で実装

doccano1とは、@Hironsan が作成しているアノテーションツールです。今回は、いくつかの要件に基づいて自動ラベリングをテストします。

要件

  1. 手作業でアノテーションしたプロジェクトが数件ある。
  2. それらのプロジェクトのラベルは一致している。
  3. 自動ラベリング用のプロジェクトも同様に一致したラベルリストがある。
  4. 手作業でアノテーションしたプロジェクトから分類器を訓練する。
  5. 訓練されたモデルで自動ラベリング用のプロジェクトのテキストを分類する。
  6. 分類したラベルをDBへ突っ込む。

コード

1. ラベルの保存

import sqlite3
import json

def connect(dbpath="/doccano/app/db.sqlite3"):
    return sqlite3.connect(dbpath)


def get_labels(conn, base_project=[1,3], ml_project=2):
    c = conn.cursor()
    out = {}
    out2 = None
    for idx in base_project:
        c.execute("select * from server_label where project_id=?", (idx, ));
        labels = c.fetchall()
        labels = {x[0]:x[1] for x in labels}
        out[idx] = labels
    c.execute("select * from server_label where project_id=?", (ml_project, ));
    labels = c.fetchall()
    labels = {x[0]:x[1] for x in labels}
    out2 = labels
    return out, out2


def test_out(out, out2):
    base = None
    for k,v in out.items():
        if base is None:
            base = sorted(list(v.values()))
        else:
            assert(base == sorted(list(v.values())))
    assert(base == sorted(list(out2.values())))
    return True


def save(out, out2):
    with open("/doccano/labels.json", "w") as f:
        json.dump({"base":out, "ml":out2}, f)
    return True

def run(base_ids):
    conn = connect()
    out, out2 = get_labels(conn, base_ids)
    test_out(out, out2)
    save(out,out2)

if __name__ == "__main__":
    import sys
    base_ids = list(map(int, sys.argv[1].split(",")))
    run(base_ids)

2. 訓練

from sklearn.utils import shuffle
from sklearn.model_selection import train_test_split
from keras.preprocessing.sequence import pad_sequences
from keras.layers import Input, Dense, Embedding, GRU
from keras.layers import SpatialDropout1D
from keras.layers.convolutional import SeparableConv1D, MaxPooling1D
from keras.models import Sequential
from keras.callbacks import ModelCheckpoint
import sentencepiece as spm
import pickle
import numpy as np
import sqlite3
import json
from sklearn.model_selection import train_test_split
 
def build_model(num_class, max_features=15000, dim=200, max_len=300, dropout_rate=0.2, gru_size=100):
    model = Sequential()
    model.add(Embedding(max_features+1, dim, input_length=max_len))
    model.add(SpatialDropout1D(dropout_rate))
    model.add(SeparableConv1D(32, kernel_size=3, padding='same', activation='relu'))
    model.add(MaxPooling1D(pool_size=2))
    model.add(SeparableConv1D(64, kernel_size=3, padding='same', activation='relu'))
    model.add(MaxPooling1D(pool_size=2))
    model.add(GRU(gru_size))
    model.add(Dense(num_class, activation='sigmoid', kernel_initializer='normal'))
    model.compile(loss='binary_crossentropy', optimizer='rmsprop', metrics=['accuracy'])
    return model 

def load_labels(labelfile="/doccano/labels.json"):
    label2ind = []
    with open(labelfile) as f:
        base = json.load(f)["base"]
        for k,v in base.items():
             labels = sorted(list(v.items()), key=lambda x: x[1])
             label2ind += [(int(label[0]),int(i)) for i,label in enumerate(labels)]
        label2ind = dict(label2ind)
    print(label2ind)
    return label2ind

def prepare(
        labelfile="/doccano/labels.json",
        dbfile="/doccano/app/db.sqlite3",
        spfile="/doccano/sp_model/jawiki.model"
):
    label2ind = load_labels(labelfile)
    conn = sqlite3.connect(dbfile)
    sp = spm.SentencePieceProcessor()
    sp.Load(spfile)
    return label2ind, conn, sp


def label2vec(label, label2ind):
    out = np.zeros(len(label2ind.keys()))
    out[label2ind[label]] = 1.0
    return out

def preprocess(data, label2ind):
    X = [x[0] for x in data]
    y = [label2vec(x[1], label2ind) for x in data]
    y = np.array(y)
    X_fix = np.array([sp.EncodeAsIds(str(x)) for x in X])
    X_fix = pad_sequences(X_fix, 300) 
    X, y = shuffle(X_fix, y) 
    X_train, X_val, y_train, y_val = train_test_split(X, y)
    return X_train, X_val, y_train, y_val


def search(conn, project_ids=[1]):
    data = []
    c = conn.cursor()
    for project_id in project_ids:
        c.execute("""
        select text, label_id 
        from server_document 
        inner join server_documentannotation 
        on server_document.id = server_documentannotation.document_id 
        where project_id=?
        """, (project_id,))
        data += c.fetchall()
    return data
   
if __name__ == "__main__":
    label2ind, conn, sp = prepare()
    data = search(conn)
    X_train, X_val, y_train, y_val = preprocess(data, label2ind)
    model = build_model(len(label2ind.keys()))
    mcp_save = ModelCheckpoint('/doccano/model.h5', save_best_only=True, monitor='val_loss', mode='min')
    model.fit(
        X_train, y_train,
        validation_data=(X_val, y_val), epochs=15, batch_size=50, verbose=0, callbacks=[mcp_save])

3. 分類

# coding: utf-8
from sklearn.utils import shuffle
from sklearn.model_selection import train_test_split
from keras.models import load_model
from keras.preprocessing.sequence import pad_sequences
import sentencepiece as spm
import pickle
import numpy as np
import sqlite3
import json
from sklearn.model_selection import train_test_split
 

def load_labels(labelfile="/doccano/labels.json"):
    with open(labelfile) as f:
        v = json.load(f)["ml"]
        labels = sorted(list(v.items()), key=lambda x: x[1])
        ind2label = {int(i):int(label[0]) for i,label in enumerate(labels)}
    return ind2label

def prepare(
        labelfile="/doccano/labels.json",
        dbfile="/doccano/app/db.sqlite3",
        spfile="/doccano/sp_model/jawiki.model",
        modelfile="/doccano/model.h5"
):
    ind2label = load_labels(labelfile)
    conn = sqlite3.connect(dbfile)
    sp = spm.SentencePieceProcessor()
    sp.Load(spfile)
    model = load_model(modelfile)
    return ind2label, conn, sp, model


def preprocess(data):
    X = [str(x[1]) for x in data]
    X_fix = np.array([sp.EncodeAsIds(str(x)) for x in X])
    X_fix = pad_sequences(X_fix, 300) 
    return np.array(X_fix)


def search(conn, project_id=2):
    c = conn.cursor()
    c.execute("""
        select server_document.id, text 
        from server_document
        where project_id=?
    """, (project_id,))
    data = c.fetchall()
    return data


if __name__ == "__main__":
    ind2label, conn, sp, model = prepare()
    data = search(conn)
    X = preprocess(data)
    preds = model.predict_proba(X)
    out = []
    for d, xs in zip(data, preds):
        i = np.argmax(xs)
        out.append((d[0],ind2label[i],float(xs[i])))
    with open("pred.json", "w") as f:
        json.dump(out,f)

4. DBへ突っ込む

import sqlite3
import json
from datetime import date

def run():
    labdata, conn, preds = prepare()
    label_ids = get_labels(labdata)
    delete_doclab(conn, label_ids)
    update_db(conn, preds)

def prepare(
        dbfile="/doccano/app/db.sqlite3",
        labelfile="/doccano/labels.json",
        predfile="/doccano/pred.json"
):
    with open(labelfile) as f:
        labdata = json.load(f)

    conn = sqlite3.connect(dbfile)

    with open(predfile) as f:
        preds = json.load(f)
        
    return labdata, conn, preds


def get_labels(labdata):
    return list(labdata["ml"].keys())


def delete_doclab(conn, label_ids):
    c = conn.cursor()
    sql = "delete from server_documentannotation where label_id=?"
    for label_id in label_ids:
        c.execute(sql, (label_id,))
        conn.commit()
    return True

def update_db(conn, preds):
    c = conn.cursor()
    sql = "insert into server_documentannotation (document_id,label_id,prob,user_id,created_at,updated_at,manual) values (?,?,?,?,?,?,?)"
    for pred in preds:
        pred = tuple(pred + [1,date.today(),date.today(),0])
        c.execute(sql, pred)
        conn.commit()
    return True


if __name__ == "__main__":
    run()

5. それらのスクリプトを一気に実行

#!/bin/bash
today=`date +%Y-%m-%d.%H:%M:%S`
touch /doccano/$today.log
python /doccano/save_labels.py 1 2>&1 | tee -a /doccano/$today.log
python /doccano/train.py 2>&1 | tee -a /doccano/$today.log
python /doccano/predict.py 2>&1 | tee -a /doccano/$today.log
python /doccano/update_db.py 2>&1 | tee -a /doccano/$today.log

補足

事前に対象となるプロジェクトIDを特定してください。 上記例では、project_id=1を「手作業アノテーションプロジェクト」、project_id=2を自動アノテーションプロジェクトとしています。

これらのコードは単なる実験用なので、もうすこし抽象化したい場合は、run.shに対してプロジェクトIDのリストを渡したり、あるいはプロジェクトIDのリストをjsonから読み込むような仕組みがあるほうが良さそうです。

また、「手作業アノテーションの補助的な目的」で自動アノテーションを使うような場合は、もっと工夫が必要かもしれません。しかし、ツールが用意している機能をHackしたようなコードなので、動作保証はいたしまてん。

その他のことは、コードを読めばわかるので、説明は割愛。