ナード戦隊データマン

機械学習, 自然言語処理, データサイエンスについてのブログ

逆翻訳をtensor2tensorで実行

逆翻訳とは、ターゲット言語のモノリンガルコーパスを訓練済みNMTで翻訳することによりデータを増やす方法です。

実行フロー

  1. ドメインコーパスでsrc->tgtとtgt-srcモデルを訓練。
  2. tgt-srcモデルを使ってモノリンガルコーパスを翻訳し、ドメインコーパスに追加。
  3. 訓練済みのsrc-tgtモデルに対して、2のコーパスでファインチューニング。

f:id:mathgeekjp:20190827115702j:plain

tensor2tensorでの実行方法

vocabの生成

まず、なんらかのコーパスを用いて1つのvocabファイルをt2t-datagenで生成しておきます。このvocabファイルを4つコピーし、以下に配置します。

data/data_mtnt_jaen/vocab.translate_jaen.32768.subwords
data/data_mtnt_enja/vocab.translate_enja.32768.subwords
data/data_mtnt_jaen_domain/vocab.translate_jaen__backtranslation.32768.subwords
data/data_mtnt_enja_domain/vocab.translate_enja__backtranslation.32768.subwords

これは、ファインチューニングのときも、同一のsubword分割を行うためです。

シェルスクリプト

echo "datagen1"
t2t-datagen \
    --data_dir=data/data_mtnt_jaen \
    --tmp_dir=/tmp/mtnt_jaen_1 \
    --problem=translate_jaen \
    --t2t_usr_dir=.
 
echo "datagen2"
t2t-datagen \
    --data_dir=data/data_mtnt_enja \
    --tmp_dir=/tmp/mtnt_enja_1 \
    --problem=translate_enja \
    --t2t_usr_dir=.
 
echo "trainer1"
t2t-trainer \
    --data_dir=data/data_mtnt_jaen \
    --problem=translate_jaen \
    --model=transformer \
    --hparams_set=transformer_base_single_gpu \
    --train_steps=250000 \
    --batch_size=8000 \
    --output_dir=training_results/training_result_mtnt_jaen_1 \
    --t2t_usr_dir=.
 
echo "trainer2"
t2t-trainer \
    --data_dir=data/data_mtnt_enja \
    --problem=translate_enja \
    --model=transformer \
    --hparams_set=transformer_base_single_gpu \
    --train_steps=250000 \
    --batch_size=8000 \
    --output_dir=training_results/training_result_mtnt_enja_1 \
    --t2t_usr_dir=.
 
 
echo "bt1"
t2t-decoder \
    --data_dir=data/data_mtnt_enja \
    --problem=translate_enja \
    --model=transformer \
    --hparams_set=transformer_base_single_gpu \
    --output_dir=training_results/training_result_mtnt_enja_1 \
    --decode_hparams="beam_size=4,alpha=0.6" \
    --decode_from_file="./data/MTNT/monolingual/train.en.placeholded.dropped" \
    --decode_to_file="./data/MTNT/monolingual/train.ja.bt" --t2t_usr_dir=.
 
echo "bt2"
t2t-decoder \
    --data_dir=data/data_mtnt_jaen \
    --problem=translate_jaen \
    --model=transformer \
    --hparams_set=transformer_base_single_gpu \
    --output_dir=training_results/training_result_mtnt_jaen_1 \
    --decode_hparams="beam_size=4,alpha=0.6" \
    --decode_from_file="./data/MTNT/monolingual/train.ja.placeholded.dropped" \
    --decode_to_file="./data/MTNT/monolingual/train.en.bt" --t2t_usr_dir=.
 
echo "cat1"
cat ./data/MTNT/monolingual/train.ja.bt ./data/MTNT/train/train_mtnt.ja-en.ja.placeholded > ./data/MTNT/train/train.domain.ja-en.ja
cat ./data/MTNT/monolingual/train.en.placeholded.dropped ./data/MTNT/train/train_mtnt.ja-en.en.placeholded > ./data/MTNT/train/train.domain.ja-en.en
 
echo "cat2"
cat ./data/MTNT/monolingual/train.en.bt ./data/MTNT/train/train_mtnt.en-ja.en.placeholded > ./data/MTNT/train/train.domain.en-ja.en
cat ./data/MTNT/monolingual/train.ja.placeholded.dropped ./data/MTNT/train/train_mtnt.en-ja.ja.placeholded > ./data/MTNT/train/train.domain.en-ja.ja
 
echo "domain datagen1"
t2t-datagen \
    --data_dir=data/data_mtnt_jaen_domain \
    --tmp_dir=/tmp/mtnt_jaen_2 \
    --problem=translate_jaen__backtranslation \
    --t2t_usr_dir=.
 
echo "domain datagen2"
t2t-datagen \
    --data_dir=data/data_mtnt_enja_domain \
    --tmp_dir=/tmp/mtnt_enja_2 \
    --problem=translate_enja__backtranslation \
    --t2t_usr_dir=.
 
echo "domain trainer1"
t2t-trainer \
    --data_dir=data/data_mtnt_jaen_domain \
    --problem=translate_jaen__backtranslation \
    --model=transformer \
    --hparams_set=transformer_base_single_gpu \
    --train_steps=50000 \
    --batch_size=8000 \
    --output_dir=training_results/training_result_mtnt_jaen_2 \
    --t2t_usr_dir=. \
    --warm_start_from=training_results/training_result_mtnt_jaen_1/model.ckpt-250000
 
echo "domain trainer2"
t2t-trainer \
    --data_dir=data/data_mtnt_enja_domain \
    --problem=translate_enja__backtranslation \
    --model=transformer \
    --hparams_set=transformer_base_single_gpu \
    --train_steps=50000 \
    --batch_size=8000 \
    --output_dir=training_results/training_result_mtnt_enja_2 \
    --t2t_usr_dir=. \
    --warm_start_from=training_results/training_result_mtnt_enja_1/model.ckpt-250000
 
 
echo "ensembling1"
pushd training_results
python3 avg_checkpoints.py --checkpoints="./training_result_mtnt_jaen_2/model.ckpt-47000,./training_result_mtnt_jaen_2/model.ckpt-48000,./training_result_mtnt_jaen_2/model.ckpt-49000,./training_result_mtnt_jaen_2/model.ckpt-50000" --output_path="./model"
mkdir training_result_mtnt_jaen_3
mv model-0.* training_result_mtnt_jaen_3
mv checkpoint training_result_mtnt_jaen_3
 
echo "ensembling2"
python3 avg_checkpoints.py --checkpoints="./training_result_mtnt_enja_2/model.ckpt-47000,./training_result_mtnt_enja_2/model.ckpt-48000,./training_result_mtnt_enja_2/model.ckpt-49000,./training_result_mtnt_enja_2/model.ckpt-50000" --output_path="./model"
mkdir training_result_mtnt_enja_3
mv model-0.* training_result_mtnt_enja_3
mv checkpoint training_result_mtnt_enja_3
popd
 
echo "training done"

このスクリプトは見たまんまのことをしていますが、backtranslation時には別のproblemを指定しています。この際、warm_start_fromを指定することで訓練済みモデルをファインチューニングします。

myproblem.pyの中身

from tensor2tensor.data_generators import problem
from tensor2tensor.data_generators import text_problems
from tensor2tensor.utils import registry
 
 
@registry.register_problem
class Translate_JAEN(text_problems.Text2TextProblem):
    @property
    def approx_vocab_size(self):
        return 2**15  # 32k
 
    @property
    def is_generate_per_split(self):
        return False
 
    @property
    def dataset_splits(self):
        return [{
            "split": problem.DatasetSplit.TRAIN,
            "shards": 9,
        }, {
            "split": problem.DatasetSplit.EVAL,
            "shards": 1,
        }]
 
    def generate_samples(self, data_dir, tmp_dir, dataset_split):
        filename_jp = "./data/MTNT/train.ja-en.ja"
        filename_en = "./data/MTNT/train.ja-en.en"
        with open(filename_jp) as f_jp, open(filename_en) as f_en:
            for src, tgt in zip(f_jp, f_en):
                src = src.strip()
                tgt = tgt.strip()
                if not src or not tgt:
                    continue
                yield {'inputs': src, 'targets': tgt}
 
 
@registry.register_problem
class Translate_ENJA(text_problems.Text2TextProblem):
    @property
    def approx_vocab_size(self):
        return 2**15  # 32k
 
    @property
    def is_generate_per_split(self):
        return False
 
    @property
    def dataset_splits(self):
        return [{
            "split": problem.DatasetSplit.TRAIN,
            "shards": 9,
        }, {
            "split": problem.DatasetSplit.EVAL,
            "shards": 1,
        }]
 
    def generate_samples(self, data_dir, tmp_dir, dataset_split):
        filename_jp = "./data/MTNT/train.en-ja.ja"
        filename_en = "./data/MTNT/train.en-ja.en"
        with open(filename_jp) as f_jp, open(filename_en) as f_en:
            for tgt, src in zip(f_jp, f_en):
                src = src.strip()
                tgt = tgt.strip()
                if not src or not tgt:
                    continue
                yield {'inputs': src, 'targets': tgt}
 
 
@registry.register_problem
class Translate_JAEN_Backtranslation(text_problems.Text2TextProblem):
    @property
    def approx_vocab_size(self):
        return 2**15  # 32k
 
    @property
    def is_generate_per_split(self):
        return False
 
    @property
    def dataset_splits(self):
        return [{
            "split": problem.DatasetSplit.TRAIN,
            "shards": 9,
        }, {
            "split": problem.DatasetSplit.EVAL,
            "shards": 1,
        }]
 
    def generate_samples(self, data_dir, tmp_dir, dataset_split):
        filename_jp = "./data/MTNT/train/train.domain.ja-en.ja"
        filename_en = "./data/MTNT/train/train.domain.ja-en.en"
        with open(filename_jp) as f_jp, open(filename_en) as f_en:
            for src, tgt in zip(f_jp, f_en):
                src = src.strip()
                tgt = tgt.strip()
                if not src or not tgt:
                    continue
                yield {'inputs': src, 'targets': tgt}
 
 
@registry.register_problem
class Translate_ENJA_Backtranslation(text_problems.Text2TextProblem):
    @property
    def approx_vocab_size(self):
        return 2**15  # 32k
 
    @property
    def is_generate_per_split(self):
        return False
 
    @property
    def dataset_splits(self):
        return [{
            "split": problem.DatasetSplit.TRAIN,
            "shards": 9,
        }, {
            "split": problem.DatasetSplit.EVAL,
            "shards": 1,
        }]
 
    def generate_samples(self, data_dir, tmp_dir, dataset_split):
        filename_jp = "./data/MTNT/train/train.domain.en-ja.ja"
        filename_en = "./data/MTNT/train/train.domain.en-ja.en"
        with open(filename_jp) as f_jp, open(filename_en) as f_en:
            for tgt, src in zip(f_jp, f_en):
                src = src.strip()
                tgt = tgt.strip()
                if not src or not tgt:
                    continue
                yield {'inputs': src, 'targets': tgt}

まとめ

  • 逆翻訳は、ドメインのモノリンガルコーパスを翻訳して、データを水増しする。
  • t2tのファインチューニングは、warm_start_fromを指定することで実行する。
  • ファインチューニング時は同じvocabファイルのコピーを使う。

補足

いくつかのデータ補正スクリプトを使う必要があるかもしれません。

placeholder

placeholderに対する補正を事前に行ってしまいます。このフェーズは、何をplaceholderとするかに依存しています。正規表現や固有表現抽出などが使われます。なお、placeholderとして変換されたものは、翻訳の後処理でもとに戻す必要があります。

文の長さに対する補正

subwordの数が128より小さい文だけを保持します。

from tensor2tensor.data_generators import text_encoder
from tqdm import tqdm
 
encoder = text_encoder.SubwordTextEncoder(
    "../../data_mtnt_enja_domain/vocab.translate_enja_backtranslation.32768.subwords"
)
 
 
def drop_them(datafile):
    out = ""
    with open(datafile) as f:
        for line in tqdm(f):
            line = line.strip()
            n_subwords = len(encoder.encode(line))
            if n_subwords < 128:
                out += line + "\n"
    return out
 
 
def run():
    infile = "./train.en.placeholded"
    out = drop_them(infile)
    with open(infile + ".dropped", "w") as f:
        f.write(out)
 
    infile = "./train.ja.placeholded"
    out = drop_them(infile)
    with open(infile + ".dropped", "w") as f:
        f.write(out)
 
 
if __name__ == "__main__":
    run()

言語の確認

言語が正しいことを確認できたデータのみを保持します。

from tqdm import tqdm
from langdetect import detect
 
 
def selecting(srcfile, tgtfile, srclang, tgtlang):
    with open(srcfile) as fs, open(tgtfile) as ft:
        with open(srcfile+".selected", "w") as fsw, \
             open(tgtfile+".selected", "w") as ftw:
            for sline, tline in tqdm(zip(fs, ft)):
                try:
                    sd = detect(sline)
                    td = detect(tline)
                except:
                    continue
                if sd == srclang and td == tgtlang:
                    fsw.write(sline)
                    ftw.write(tline)
 
 
def run():
    selecting("./train.en.placeholded.dropped", "./train.ja.bt", "en", "ja")
    selecting("./train.ja.placeholded.dropped", "./train.en.bt", "ja", "en")
 
 
if __name__ == "__main__":
    run()