ナード戦隊データマン

機械学習, 自然言語処理, データサイエンスについてのブログ

ソース文にドメインタグを追加してTransformerを訓練

機械翻訳では、ソース文を編集する手法がいくつかあります。その一つがタグの追加です1

概要

ソース文に以下のような文があるとします。

文1: 11次元超重力理論を実証するためには、北極へ行け。
文2: P対NP問題を俺は解決した。

この文に対して、ドメイン情報があれば、それをタグとしてソース文に追加します。

文1: <Physics>11次元超重力理論を実証するためには、北極へ行け。
文2: <CS>P対NP問題を俺は解決した。

すると、翻訳機はドメインタグを認識し、ドメインに応じた出力をします。これによってBLUEスコアを上げることが期待できます。

とあるデータの例

とあるJA-EN翻訳の例を考えてみます。このデータは、論文の要約文の翻訳と、その論文のドメインの情報を持っています。そこでこのデータを使い、論文のドメインタグの追加がBLUEスコアにどの程度影響するのか見てみます。なお、ここではtensor2tensorを使います。

Note: vocabファイルは事前作成しています。vocabファイル内に、domainタグを挿入しています。

コード

データの展開

import os
from tqdm import tqdm

def extract(datafile, outdir="./data"):
    domains = []
    filename = datafile.split("/")[-1]
    with open(datafile) as f, \
         open(os.path.join(outdir, filename+".ja.fixed"), "w") as f1j, \
         open(os.path.join(outdir, filename+".ja.fixed.withdomain"), "w") as f2j, \
         open(os.path.join(outdir, filename+".en.fixed"), "w") as f3e:
        for line in tqdm(f):
            t = [x.strip() for x in line.split("|||")]
            jaline = t[-2]
            enline = t[-1]
            domain = "<{}>".format(t[-4][0])
            jaline2 = domain + jaline[:]
            if domain not in domains:
                domains.append(domain)
            f1j.write(jaline+"\n")
            f2j.write(jaline2+"\n")
            f3e.write(enline+"\n")
    return domains


def run():
    domains = extract("./data/train/train-1.txt")
    extract("./data/train/train-2.txt")
    extract("./data/train/train-3.txt")
    extract("./data/test/test.txt")
    with open("./data/domain.txt", "w") as f:
        f.write('\n'.join(domains))


if __name__ == "__main__":
    run()

myproblem.py

from tensor2tensor.data_generators import problem
from tensor2tensor.data_generators import text_problems
from tensor2tensor.utils import registry


@registry.register_problem
class Translate_JAEN(text_problems.Text2TextProblem):
    @property
    def approx_vocab_size(self):
        return 2**15  # 32k

    @property
    def is_generate_per_split(self):
        return False

    @property
    def dataset_splits(self):
        return [{
            "split": problem.DatasetSplit.TRAIN,
            "shards": 19,
        }, {
            "split": problem.DatasetSplit.EVAL,
            "shards": 1,
        }]

    def generate_samples(self, data_dir, tmp_dir, dataset_split):
        filename_jp = "../data/train-1.txt.ja.fixed"
        filename_en = "../data/train-1.txt.en.fixed"
        with open(filename_jp) as f_jp, open(filename_en) as f_en:
            for src, tgt in zip(f_jp, f_en):
                src = src.strip()
                tgt = tgt.strip()
                if not src or not tgt:
                    continue
                yield {'inputs': src, 'targets': tgt}


@registry.register_problem
class Translate_JAEN_Domain(text_problems.Text2TextProblem):
    @property
    def approx_vocab_size(self):
        return 2**15  # 32k

    @property
    def is_generate_per_split(self):
        return False

    @property
    def dataset_splits(self):
        return [{
            "split": problem.DatasetSplit.TRAIN,
            "shards": 9,
        }, {
            "split": problem.DatasetSplit.EVAL,
            "shards": 1,
        }]

    def generate_samples(self, data_dir, tmp_dir, dataset_split):
        filename_jp = "../data/train-1.txt.ja.fixed.withdomain"
        filename_en = "../data/train-1.txt.en.fixed"
        with open(filename_jp) as f_jp, open(filename_en) as f_en:
            for src, tgt in zip(f_jp, f_en):
                src = src.strip()
                tgt = tgt.strip()
                if not src or not tgt:
                    continue
                yield {'inputs': src, 'targets': tgt}

訓練

echo "datagen1"
t2t-datagen \
    --data_dir=../model_data/data_jaen_1 \
    --tmp_dir=/tmp/jaen_1 \
    --problem=translate_jaen \
    --t2t_usr_dir=../model

echo "datagen2"
t2t-datagen \
    --data_dir=../model_data/data_jaen_domain_1 \
    --tmp_dir=/tmp/jaen_domain_1 \
    --problem=translate_jaen__domain \
    --t2t_usr_dir=../model

echo "trainer1"
t2t-trainer \
    --data_dir=../model_data/data_jaen_1 \
    --problem=translate_jaen \
    --model=transformer \
    --hparams_set=transformer_base_single_gpu \
    --train_steps=250000 \
    --batch_size=8000 \
    --output_dir=../training_results/training_result_jaen_1 \
    --t2t_usr_dir=../model

echo "trainer2"
t2t-trainer \
    --data_dir=../model_data/data_jaen_domain_1 \
    --problem=translate_jaen__domain \
    --model=transformer \
    --hparams_set=transformer_base_single_gpu \
    --train_steps=250000 \
    --batch_size=8000 \
    --output_dir=../training_results/training_result_jaen_domain_1 \
    --t2t_usr_dir=../model

avg_checkpoint

pushd ../training_results
python3 avg_checkpoints.py \
    --checkpoints="./training_result_jaen_domain_1/model.ckpt-250000,./training_result_jaen_domain_1/model.ckpt-249000,./training_result_jaen_domain_1/model.ckpt-248000,./training_result_jaen_domain_1/model.ckpt-247000,./training_result_jaen_domain_1/model.ckpt-246000,./training_result_jaen_domain_1/model.ckpt-245000,./training_result_jaen_domain_1/model.ckpt-244000,./training_result_jaen_domain_1/model.ckpt-243000,./training_result_jaen_domain_1/model.ckpt-242000,./training_result_jaen_domain_1/model.ckpt-241000," --output_path="./model"
mv model-0.* training_result_jaen_domain_2
mv checkpoint training_result_jaen_domain_2
popd

デコード

echo "JAEN 1"
t2t-decoder \
    --data_dir ../model_data/data_jaen_1 \
    --problem translate_jaen \
    --model transformer \
    --hparams_set=transformer_base_single_gpu \
    --output_dir=../training_results/training_result_jaen_1 \
    --decode_hparams="beam_size=4,alpha=0.6" \
    --decode_from_file=../data/test.txt.ja.fixed \
    --decode_to_file=../data/test_out/test.hyp.1.en \
    --t2t_usr_dir=../model


echo "JAEN2"
t2t-decoder \
    --data_dir ../model_data/data_jaen_domain_1 \
    --problem translate_jaen__domain \
    --model transformer \
    --hparams_set=transformer_base_single_gpu \
    --output_dir=../training_results/training_result_jaen_domain_1 \
    --decode_hparams="beam_size=4,alpha=0.6" \
    --decode_from_file=../data/test.txt.ja.fixed.withdomain \
    --decode_to_file=../data/test_out/test.hyp.2.en \
    --t2t_usr_dir=../model

echo "JAEN3"
t2t-decoder \
    --data_dir ../model_data/data_jaen_domain_1 \
    --problem translate_jaen__domain \
    --model transformer \
    --hparams_set=transformer_base_single_gpu \
    --output_dir=../training_results/training_result_jaen_domain_2 \
    --decode_hparams="beam_size=4,alpha=0.6" \
    --decode_from_file=../data/test.txt.ja.fixed.withdomain \
    --decode_to_file=../data/test_out/test.hyp.3.en \
    --t2t_usr_dir=../model

評価

echo "baseline"
t2t-bleu \
    --translation=../data/test_out/test.hyp.1.en \
    --reference=../data/test.txt.en.fixed 
echo ""
echo "with domain tag"
t2t-bleu \
    --translation=../data/test_out/test.hyp.2.en \
    --reference=../data/test.txt.en.fixed 
echo ""
echo "ensemble"
t2t-bleu \
    --translation=../data/test_out/test.hyp.3.en \
    --reference=../data/test.txt.en.fixed 

結果

baseline
BLEU_uncased =  28.04
BLEU_cased =  26.99

with domain tag
BLEU_uncased =  28.81
BLEU_cased =  27.73

ensemble
BLEU_uncased =  29.29
BLEU_cased =  28.20

考察

このように、ドメインタグを追加すると、BLUEスコアが上がるということがわかります。さらに、checkpointのアンサンブルを作成すると、よりスコアが上がります。ドメインタグを追加するだけではなく、タスクタグを追加する方法もあるようです2。目的のタスクとは直接関係しないタスクを同時学習することにより、BLUEは上がるようです。

Domain ControlによってBLUEスコアが上がる理由としては、ドメインに固有の単語、文法、スタイルなどが使われているためだと考えられます。一つのデータセットが複数のドメインに分類できるとき、これらのドメインを表すタグを付与することによってドメインを認識できるようです。

参考