データナード

機械学習と自然言語処理についての備忘録 (旧ナード戦隊データマン)

cc_net的な方法でlangstatsを生成する

cc_netとは、CommonCrawlからモノリンガルコーパスを生成するためのツールです。langstatsを生成するのにも使えそうですが、slurmによる分散実行部分のコードが欠損しており、sbatchにcc_netを渡すには面倒なコードになってしまいます。そのため、cc_net的な方法で実行される単一のpythonスクリプトを作成しました。

github.com

概要

f:id:mathgeekjp:20191224140811j:plain

  1. shardごとにwetファイルをダウンロードする。(1 shard=50 WET)
  2. wetファイルからパラグラフごとのハッシュ値生成。
  3. ハッシュ値をもとに、重複を排除しながら、言語検出もしながら、言語ごとにjson listファイルへデータを出力。
  4. json listファイルを読み込み、言語モデルでスコアリングして最終出力(langstat)を出力。

なぜこんな複雑なことをしなければいけないのか

ccnetは質のよいモノリンガルコーパスを作成するために、重複排除と言語モデルのスコアリングを行っています。今回ほしいのはlangstatだけですが、「質のよいデータが各ドメインにどれくらいあるのか」を判断するために、ccnetが使えると考えます。

json listへ出力する理由は、単に空間効率を上げるためです。これは、ジェネレータの連鎖によって空間効率の向上をしていますが、言語モデルによるスコアリングは言語ごとに分割して実行する必要があり、この処理の部分をメモリ上で行うのが高価であるため、ファイルへ一旦出力するようにしています。

cc_netのslurmが使えない

cc_netのslurmを使う部分のコードが「Facebook内部で使うモジュール」であるため、我々部外者が使うことはできません。このコードは過度にcc_netに統合されすぎていて、sbatchコマンド等で軽く実行するようなことができません。

そのため、cc_netをもっとシンプルに書き換えることを考えました。

cc_netを参考にしたスクリプト

import gc
import sys
import hashlib
from collections import defaultdict
from tqdm import tqdm
import gzip
from multiprocessing.pool import Pool
from multiprocessing import Process
import fasttext
import os
import text_normalizer
import kenlm
import sentencepiece as spm
import json
import random
import string
from functools import partial
 
bin_dir = sys.argv[1]
lid_model = fasttext.load_model(os.path.join(bin_dir, "lid.bin"))
 
ls1 = {x.split(".")[0] for x in os.listdir(os.path.join(bin_dir, "lm_sp"))
       if x.endswith(".arpa.bin")}
ls2 = {x.split(".")[0] for x in os.listdir(os.path.join(bin_dir, "lm_sp"))
       if x.endswith(".sp.model")}
langs = ls1 & ls2
 
lm = None
sp = None
 
 
def _file_loader(fname):
    with gzip.open(fname, "rt") as f:
        for line in f:
            yield line
 
 
def _file_loader_bulk(fnames):
    for fname in fnames:
        for line in _file_loader(fname):
            yield line
 
 
def _corpus_loader(line_generator):
    wstr = "WARC/1.0"
    header_mode = False
    for line in line_generator:
        line = line.strip()
        if not header_mode and not line:
            yield line, None
            continue
        elif not header_mode and line == wstr:
            header_mode = True
        elif header_mode:
            if not line:
                header_mode = False
                yield line, None
                continue
        yield line, header_mode
 
 
def _corpus_loader_dedup(line_generator, hashes):
    wstr = "WARC/1.0"
    ustr = "WARC-Target-URI"
    header_mode = None
    out = []
    url = None
    domain = None
    for line in line_generator:
        line = line.strip()
        if line == wstr:
            header_mode = True
            if out:
                if domain is not None and url is not None:
                    yield {"url": url, "domain": domain, "data": out}
                url = None
                domain = None
                out = []
            continue
        if header_mode:
            if line.startswith(ustr):
                url = line.split(ustr+":")[1].strip()
                domain = url.split("//")[1].split("/")[0]
            if line:
                continue
            else:
                header_mode = False
        else:
            h = hashes[hashlib.sha1(bytes(line.lower(), encoding="utf-8")).digest()]
            if h < 2:
                out.append(line)
 
 
def _detect_lang(batch):
    out = {}
    lscores = []
    for line in batch["data"]:
        tmp_pred = lid_model.predict(line)
        lang = tmp_pred[0][0].split("__")[-1]
        lscores.append(float(tmp_pred[1][0]))
        if lang not in out:
            out[lang] = []
        out[lang].append(line)
    lscore = sum(x/len(lscores) for x in lscores)
    out = max(out.items(), key=lambda x: len(x[1]))
    return {"lang": out[0],
            "language_score": float(lscore),
            "length": len(''.join(out[1])),
            "url": batch["url"],
            "domain": batch["domain"],
            "data": batch["data"]}
 
 
def randomString(stringLength=10):
    """Generate a random string of fixed length """
    letters = string.ascii_lowercase
    return ''.join(random.choice(letters) for i in range(stringLength))
 
 
def _save_to_tmp(result, tmp_dir, fprefix, langs):
    if result["lang"] in langs:
        with open(os.path.join(
                tmp_dir, fprefix+"_{}".format(result["lang"])), "a") as f:
            f.write(json.dumps(result)+"\n")
 
 
def _initializer(lm_s, sp_s):
    global lm, sp
    lm = lm_s
    sp = sp_s
 
 
def _load_lm_bulk(langs):
    global shared_data
    pool = Pool(os.cpu_count(), _initializer, ())
    load_func = partial(_load_lm, bin_dir=bin_dir)
    pool.map(load_func, langs)
    pool.close()
 
 
def _jl_loader(tmp_dir, fprefix, lang):
    with open(os.path.join(
            tmp_dir, fprefix+"_{}".format(lang))) as f:
        for line in f:
            yield line
 
 
def _output(results, score_outpath, langstat_outpath):
    out = {}
    sep = "_____"
    with open(score_outpath, "a") as f:
        for result in tqdm(results):
            d = result["domain"] + sep + result["lang"] 
            if d not in out:
                out[d] = 0
            out[d] += result["length"]
            f.write("{}\t{}\t{}\t{}\t{}\n".format(
                result["url"], result["domain"], result["lang"],
                result["language_score"], result["perplexity"]))
    with open(langstat_outpath, "a") as f:
        for key, value in tqdm(out.items()):
            key = key.split(sep)
            if len(key) == 2:
                f.write('{}\t{}\t{}\n'.format(key[0], key[1], value))
 
 
def _add_lang_score(line):
    global lm, sp
    result = json.loads(line.strip())
    doc_score = 0
    doc_length = 0
    for line in result["data"]:
        line = text_normalizer.normalize(line)
        pieces = ' '.join(sp.EncodeAsPieces(line))
        if len(pieces):
            doc_score += lm.score(' '.join(pieces))
            doc_length += len(pieces)
    if doc_length:
        result["perplexity"] = 10.0**(-doc_score/doc_length)
    else:
        result["perplexity"] = 0.0
    del(result["data"])
    return result
 
 
def _add_lang_score_bulk(line_gen, lang,
                         score_outpath, langstat_outpath, bin_dir):
    lm_s, sp_s = _load_lm(lang, bin_dir)
    pool = Pool(os.cpu_count(), _initializer, (lm_s, sp_s))
    _output(pool.imap_unordered(_add_lang_score, line_gen),
            score_outpath, langstat_outpath)
    pool.close()
    gc.collect()
 
 
def _create_hash(fname):
    hashes = defaultdict(int)
    for line, mode in tqdm(_corpus_loader(_file_loader(fname))):
        if mode is not None and not mode:
            hashes[hashlib.sha1(bytes(line.lower(), encoding="utf-8")).digest()] += 1
    return hashes
 
 
def create_hashes(files):
    pool = Pool(os.cpu_count())
    hashes_list = pool.map(_create_hash, files)
    hashes = defaultdict(int)
    for h in hashes_list:
        hashes.update(h)
    pool.close()
    return hashes
 
 
def _load_lm(lang, bin_dir):
    kpath = os.path.join(bin_dir, "lm_sp", lang+".arpa.bin")
    spath = os.path.join(bin_dir, "lm_sp", lang+".sp.model")
    lm = kenlm.Model(kpath)
    sp = spm.SentencePieceProcessor()
    sp.Load(spath)
    return lm, sp
 
 
def _group_n(lis, n):
    for i in range(0, len(lis), n):
        yield lis[i:i+n]
 
 
def _save_bulk(results, tmp_dir, fprefix, langs):
    _save_func = partial(_save_to_tmp,
                         tmp_dir=tmp_dir, fprefix=fprefix, langs=langs)
    for result in results:
        _save_func(result)
    del(results)
    gc.collect()
 
 
class LoaderProxy:
    def __init__(self, loader):
        self._loader = loader
 
    def __iter__(self):
        for result in self._loader:
            yield result
 
 
def _check_process(ps):
    stack = []
    for i, p in enumerate(ps):
        if p.is_alive():
            continue
        else:
            p.close()
            stack.append(i)
    for i in stack[::-1]:
        ps.pop(i)
    return ps
            
            
def _parallel_s(files, hashes, tmp_dir, fprefix, langs):
    loaders = [LoaderProxy((_detect_lang(x)
                            for x in _corpus_loader_dedup(
                                    tqdm(_file_loader(fname)), hashes)))
               for fname in files]
    _save_func = partial(_save_bulk,
                         tmp_dir=tmp_dir,
                         fprefix=fprefix,
                         langs=langs)
    ps = []
    while loaders or ps:
        ps = _check_process(ps)
        while len(ps) < os.cpu_count():
            if loaders:
                loader = loaders.pop(0)
                p = Process(target=_save_func, args=(loader, ))
                p.start()
                ps.append(p)
            else:
                ps = _check_process(ps)
                if not ps:
                    break

    del(loaders)
    gc.collect()
 
 
def _parallel_t(fprefix, score_outpath, langstat_outpath, tmp_dir="./tmp"):
    target_langs = list({x.split("_")[-1] for x in os.listdir(tmp_dir)
                         if x.startswith(fprefix)})
    for lang in target_langs:
        loader = _jl_loader(tmp_dir, fprefix, lang)
        _add_lang_score_bulk(loader,
                             lang,
                             score_outpath,
                             langstat_outpath,
                             bin_dir)
        
    
def main(score_outpath, langstat_outpath, tmp_dir="./tmp"):
    fprefix = randomString(20)
    files = [x.strip() for x in sys.stdin]
    hashes = create_hashes(files)
    _parallel_s(files, hashes, tmp_dir, fprefix, langs)
    _parallel_t(fprefix, score_outpath, langstat_outpath, tmp_dir)
 
 
if __name__ == "__main__":
    main(sys.argv[2], sys.argv[3])

gist: https://gist.github.com/sugiyamath/2fa4ce8a1bd1f02203eee082c241886e

補足

このコードは、主要な部分の処理だけを記述しているため、以下の処理は追加で書く必要があります:

  1. shardごとのWETファイルの並列ダウンロード。
  2. 並列ダウンロードしたファイルのパスのリストを標準入力からlangstat_generatorにわたす部分。
  3. shardごとに出力ファイル名を変えてコマンドライン引数に渡す部分。
  4. AWS parallel clusterで実行するために、sbatchへ渡せるように、1)すべての処理を単一のスクリプトにまとめる, 2)利用している依存関係を解決するために、主要部分のpythonコードをpyinstallerなどを使って単一バイナリに落とし込む。
  5. sbatchに渡すノード数,CPU数,タスク数等。

参考

  1. GitHub - facebookresearch/cc_net: Tools to download and cleanup Common Crawl data