ナード戦隊データマン

データサイエンスを用いて悪と戦うぞ

BERT苦行録2 - sentencepieceを使って事前訓練

BERTを用いて日本語ツイートの感情分析を試すという記事では、BERTについてファインチューニングと事前訓練を行いました。今回は事前訓練を行う上での注意点を書きます。

1. 語彙数とトーカナイザの問題

MeCabボキャブラリにあわせてBERTを訓練をしようと試みましたが、それは愚かです。そのような方法を使った場合、「語彙数を頻度等で切り捨てる」ようなことをすることになり、カバーできる文が少なくなってしまいます。

@taku910 さんは、sentencepieceの記事で以下のようにおっしゃっています。

単語をそのまま扱うのは実用上の問題点があります。RNNによるテキスト生成では、語彙サイズに依存した計算量が必要となるめ、大規模な語彙を扱えません。高頻度語彙のみに限定することで計算量の問題は回避できますが、低頻度語が捨てられてしまいます。この問題を解決する手法の一つがSentencepieceの土台ともなったサブワードです。

つまり、語彙数に依存したモデルは語彙を圧縮する方法を使ったほうが良さそうです。BERTでは、MeCabを使うよりも、むしろsentencepieceを使ったほうがよいと思います。

sentencepieceを試しに使ったモデルの設定は以下で公開しました。 https://github.com/sugiyamath/bert/tree/master/jamodel

jawiki.model, jawiki.vocabはどちらもsentencepieceのモデルです。

BERTのトーカナイザは、正規化を無効化し、中国文字判定を無効化しておきます。その上で、BERTのモデルに渡す入力データは、毎回事前にsentencepieceでトーカナイズしておきます。これでBERTはsentencepieceに対応できます。

2. 巨大なデータの前処理はデータを分割せよ

jawiki全体を前処理したらOOMが起こりました。この問題をbertのissueで残したら以下の返信がありました:

You should shard the input data (text.txt_00000, text.txt_00001), run the script for each shard (tf_examples.tfrecord_00000, tf_examples.tf_record_00001), and then pass in a file glob (e.g., tf_examples.tfrecord*) to run_pretraining.py. データをtext.txt_00000, text.txt_00001のように分割してそれぞれtf_examples.tfrecord_00000,...でスクリプトを回し、run_pretraining.pyに対してtf_examples.tfrecord*というglobを渡してください。

大体以下のような感じになりました。(マルチプロセスとマルチスレッドを組合せている部分は気休めです。)

from subprocess import check_output
import os
from tqdm import tqdm
from functools import partial
from multiprocessing.pool import ThreadPool

def build_command(input_file, output_file, vocab_file="/root/work/bert/jamodel/vocab.txt"):
    cmd = ["python","/root/work/bert/create_pretraining_data.py",
           "--input_file={}".format(input_file),
           "--output_file={}".format(output_file),
           "--vocab_file={}".format(vocab_file),
       "--do_lower_case=False","--max_seq_length=128","--max_predictions_per_seq=20",
       "--masked_lm_prob=0.15","--random_seed=12345","--dupe_factor=2"]
    result = check_output(cmd)
    return result


def execute(i, input_dir, output_dir, vocab_file):
    try:
        datanum = str(i).rjust(6, '0')
        input_file = os.path.join(input_dir, "text.txt_{}".format(datanum))
        output_file = os.path.join(output_dir, "tf_examples.tf_record_{}".format(datanum))
        if os.path.exists(output_file):
            return None
        else:
            build_command(input_file, output_file, vocab_file)
            print(str(i), end=' ', flush=True)
    except:
        print("Error:"+str(i), end=' ', flush=True)


def execute_them(ds, input_dir, output_dir, vocab_file, poolsize=10):
    try:
        pool = ThreadPool(poolsize)
        func = partial(execute,
            input_dir=input_dir,
            output_dir=output_dir,
            vocab_file=vocab_file)

        pool.map(func, ds)
    except Exception as e:
        print(e)
        print("Error in pool")
        

def main(input_dir, output_dir, vocab_file, datasize=999425):
    from multiprocessing import Pool
    import numpy as np

    poolsize = 5
    targets = np.split(np.array(list(range(datasize))), 5)
    
    func = partial(execute_them,
            input_dir=input_dir,
            output_dir=output_dir,
            vocab_file=vocab_file)
    pool = Pool(5)
    pool.map(func, targets)


if __name__ == "__main__":
    main("/root/work/data/txt_data/", "/root/work/data/records", "/root/work/bert/jamodel/vocab.txt")

ということで、OOMが発生するので分割してください。