ナード戦隊データマン

機械学習と自然言語処理についてのブログ

tensor2tensorの翻訳モデルをサーバ化して実行

tensor2tensorで訓練した翻訳モデルを使いたいとき、複雑なコードを書かずに実行する方法の一つがserving1です。

実行方法

tensor2tensorですでに訓練済みの翻訳モデルが存在しているとします。

1. serving用にexportする (t2t-exporter)

t2t-exporterを使えば、serving用にエクスポートできます。

t2t-exporter --data_dir=data/data2 --problem=translate_jpen --model=transformer --hparams_set=transformer_base_single_gpu --output_dir=training_result2 --decode_hparams="beam_size=4,alpha=0.6" --t2t_usr_dir=.

2. tensorflow_model_serverを立ち上げる

tensorflow_model_server2は、tensorflowのモデルをサーバ化して使うためのモジュールです。

インストール方法は以下です。

echo "deb [arch=amd64] http://storage.googleapis.com/tensorflow-serving-apt stable tensorflow-model-server tensorflow-model-server-universal" | sudo tee /etc/apt/sources.list.d/tensorflow-serving.list && \
curl https://storage.googleapis.com/tensorflow-serving-apt/tensorflow-serving.release.pub.gpg | sudo apt-key add -
apt-get update && apt-get install tensorflow-model-server

インストールしたら、exportしたt2tモデルを読み込んで指定したポートで立ち上げます。

tensorflow_model_server --port=9000 --model_name=my_model --model_base_path=/root/work/mt/model2/training_result2/export/ &

3. pythonからAPIを使う (query.py)

from __future__ import absolute_import, division, print_function

import os

from six.moves import input
from tensor2tensor.serving import serving_utils
from tensor2tensor.utils import hparam, registry, usr_dir


def make_request_fn(servable_name="my_model", server="localhost:9000"):
    request_fn = serving_utils.make_grpc_request_fn(
        servable_name=servable_name, server=server, timeout_secs=30)
    return request_fn


def translate(text,
              udir="/root/work/mt/model2",
              ddir="/root/work/mt/model2/data/data3",
              servable_name="my_model",
              server="localhost:9000"):
    usr_dir.import_usr_dir(udir)
    problem = registry.problem("translate_jpen")
    hparams = hparam.HParams(
        data_dir=os.path.expanduser(ddir))
    problem.get_hparams(hparams)
    request_fn = make_request_fn(servable_name, server)
    inputs = text
    outputs = serving_utils.predict([inputs], problem, request_fn)
    outputs, = outputs
    output, score = outputs
    return {"input": text, "output": output, "score": score}


def run():
    while True:
        inputs = input(">>")
        result = translate(inputs)
        output = result["output"]
        score = result["score"]
        if len(score.shape) > 0:
            print_str = """
Input:
{inputs}

Output (Scores [{score}]):
{output}
        """
            score_text = ",".join(["{:.3f}".format(s) for s in score])
            print(
                print_str.format(inputs=inputs,
                                 output=output,
                                 score=score_text))
        else:
            print_str = """
Input:
{inputs}

Output (Score {score:.3f}):
{output}
        """
            print(print_str.format(inputs=inputs, output=output, score=score))


if __name__ == "__main__":
    run()

translateという関数を呼び出すことにより翻訳が実行できます。udirはユーザディレクトリで、tensor2tensorのユーザ定義問題を定義したファイルの場所です。ddirは、t2tの実行時に指定したデータディレクトリです。servable_nameはサーバ上で実行しているモデル名で、serverはサーバのアドレスです。

ちなみに、runは対話的に翻訳を実行するためのものですが、本質的に重要なのはtranslate関数です。

参考