ナード戦隊データマン

機械学習と自然言語処理についてのブログ

Wikipediaから拡張固有表現を用いたCoNLL形式のデータを生成

拡張固有表現とは、PER, LOC, ORG, MISCのような少数のタイプではなく、より多くのタイプを定義したものです。

追記: 2019-06-25 11:09

IOBタグの定義を間違えました。 Bは先頭、Iは中間です。

CoNLL形式について

固有表現抽出でよく使われるデータには、CoNLL 2003があります。このデータは以下のような形式になっています:

EU I-ORG
rejects O
German I-MISC
call O
to O
boycott O
British I-MISC
lamb O
. . O O

IOBタグは、抽出対象の語かどうかと、その位置を表します。B(beginning)は語の先頭で、I(inside)は先頭よりも後ろです。Oは抽出対象ではありません。

IOBの後ろの文字列は、固有表現のタイプを表します。例えば、ORGならば組織名、PERなら人名、という具合です。

DBPedia instance types

DBPediaタイプは、エンティティのタイプを表します。以下からダウンロードできます。

http://downloads.dbpedia.org/2016-10/core/instance_types_en.ttl.bz2

今回は、このタイプを「拡張固有表現」のように扱うことを検討します。

事前準備: typeとエンティティの対応付け

ダウンロードしたinstance_types_en.ttl.bz2をjson形式にします。

import re
from tqdm import tqdm
import json

def load(filename):
    out = {}
    out2 = {}
    with open(filename) as f:
        for line in tqdm(f):
            if line.startswith("<") and '__' not in line:
                line = line.split()
                entity = line[0].split("/")[-1][:-1]
                if entity not in out:
                    out[entity] = []
                out[entity].append(line[2].split("/")[-1][:-1])
    for k,vs in tqdm(out.items()):
        out2[k] = list(set(vs))
    return out2


def save(out, filename):
    with open(filename, "w") as f:
        json.dump(out, f, indent=4, sort_keys=True)

if __name__ == "__main__":
    save(load("./instance_types_en.ttl"), "types.json")

WikipediaからCoNLL形式のデータを生成

さて、Wikipediaから拡張固有表現の訓練データを生成するために、以下のスクリプトを実行します:

import re
import json
import nltk
from tqdm import tqdm
 
reg1 = re.compile(r"\[\[(.+?)\]\]")
reg2 = re.compile(r"\{\{.+?\}\}")
reg3 = re.compile(r"\'\'\'(.+?)\'\'\'")
 
 
def fmt_line(line):
    mentions = []
    pairs = {}
    tmp_line = re.sub(reg1, r"\1", line)
    tmp_line = re.sub(reg2, "", tmp_line)
    tmp_line = re.sub(reg3, r"\1", tmp_line)
    tmp_line = tmp_line.replace("\n", "").strip()
    mentions += [x for x in re.findall(reg1, line)]
    mentions += [x for x in re.findall(reg3, line)]
    for k in mentions:
        if "|" in k:
            tmp = k.split("|")
            mention = tmp[1]
            entity = tmp[0]
            tmp_line = tmp_line.replace(k, mention)
        else:
            mention = k
            entity = k[0].upper() + k[1:]
        pairs[mention] = entity.replace(" ", "_")
    return tmp_line, pairs
 
 
def has_illegal_symbols(line):
    flag = False
    syms = list("[]{}|<>") + ["'''"]
    for sym in syms:
        flag = flag or sym in line
        if flag:
            break
    return flag
 
 
def line2conll(line, typedict, tokenize=nltk.word_tokenize):
    if "[[" not in line or "]]" not in line:
        return None, None
    line, pairs = fmt_line(line)
    if has_illegal_symbols(line):
        return None, None
    words = tokenize(line)
 
    prevs = []
    labels = ["O" for _ in range(len(words))]
    ents = ["" for _ in range(len(words))]
    for mention, entity in pairs.items():
        if entity in typedict:
            target_type = typedict[entity][0]
        else:
            continue
        target = tokenize(mention)
        for i, word in enumerate(words):
            if word == target[len(prevs)]:
                prevs.append(i)
            else:
                prevs = []
            if len(prevs) == len(target):
                ind = prevs[0]
                labels[ind] = "B-" + target_type
                ents[ind] = entity
                for prev in prevs[1:]:
                    labels[prev] = "I-" + target_type
                    ents[prev] = entity
                prevs = []
    return words, labels, ents
 
 
def run(datapath="./enwiki.xml",
        typepath="./types.json",
        outpath="enwiki.conll"):
    with open(typepath) as f:
        typedict = json.load(f)
 
    with open(datapath) as f:
        for line in tqdm(f):
            try:
                words, labels, ents = line2conll(line, typedict)
            except Exception:
                continue
            if words is None or len(words) < 2:
                continue
            if len(words) != len(labels):
                continue
            with open(outpath, "a") as f:
                for word, label, entity in zip(words, labels, ents):
                    f.write('\t'.join([word, label, entity]) + '\n')
                f.write('\n')
 
 
if __name__ == "__main__":
    run()

生成されたデータ

生成されるデータは以下のような形式になっています:

Anarchism       I-owl#Thing     Anarchism
is      O
an      O
anti-authoritarian      O
political       B-owl#Thing     Political_philosophy
philosophy      I-owl#Thing     Political_philosophy
that    O
advocates       O
self-managed    O
,       O
self-governed   O
societies       O
based   O
on      O
voluntary       O
,       O

ゴミ削除

前述のスクリプトは、ちゃんとパースしていないのでゴミが含まれてしまうかもしれません。それに備えて、ゴミ文を除去するスクリプトを作ります。

from tqdm import tqdm
import time
 
 
def has_illegal_symbol(sent):
    illegals = [
        "&", ";", "name=", "--", "'", "(", ")", "`"
    ]
    for word, tag, entity in sent:
        for sym in illegals:
            if sym in word:
                return True
    return False
 
 
def fix(datafile="./enwiki.conll", outfile="./enwiki_fixed.conll"):
    with open(datafile) as f:
        sent = []
        for line in tqdm(f):
            if "\t" not in line:
                continue
            try:
                word, tag, entity = line.strip().split("\t")
            except ValueError:
                word, tag, entity = line.strip().split("\t") + [""]
            sent.append([word, tag, entity])
            if word.endswith("."):
                if has_illegal_symbol(sent):
                    sent = []
                    continue
                if word != ".":
                    sent[-1][0] = word[:-1]
                    sent.append([".", "O", ""])
                with open(outfile, "a") as f:
                    for data in sent:
                        f.write('\t'.join(data) + "\n")
                    f.write("\n")
                sent = []
 
 
if __name__ == "__main__":
    fix()