ナード戦隊データマン

機械学習, 自然言語処理, データサイエンスについてのブログ

type probability: エンティティリンキングに使えそうな特徴量

DBPediaには、各エンティティのタイプ情報を持つデータがあります。今回は、Wikipediaから各メンションがどのエンティティと紐付いているかを統計的に算出し、その上でDBPediaと結びつけることで、語のタイプ確率を求めます。

実行

git clone https://github.com/sugiyamath/entity_types_scripts
cd entity_types_scripts
wget http://downloads.dbpedia.org/2016-10/core/instance_types_en.ttl.bz2
wget https://dumps.wikimedia.org/enwiki/latest/enwiki-latest-pages-articles.xml.bz2
bunzip2 *.bz2
mv enwiki-latest-pages-articles.xml dump
python build_types.py
python extract_mention.py
python json2marisa.py
python mprob.py "Obama"

出力例

python mprob.py "Obama"
{'Animal': 0.007423904974016332,
 'City': 0.05493689680772086,
 'MusicalArtist': 0.0007423904974016332,
 'President': 0.8916109873793615,
 'RailwayLine': 0.0007423904974016332,
 'School': 0.0007423904974016332,
 'SoccerPlayer': 0.011878247958426132,
 'owl#Thing': 0.03192279138827023}

 python mprob.py "Einstein"
{'Album': 0.003913894324853229,
 'ArtificialSatellite': 0.0136986301369863,
 'Congressman': 0.0019569471624266144,
 'Film': 0.0019569471624266144,
 'InformationAppliance': 0.0019569471624266144,
 'Person': 0.009784735812133072,
 'RaceHorse': 0.05675146771037182,
 'School': 0.005870841487279843,
 'Scientist': 0.8473581213307241,
 'Software': 0.009784735812133072,
 'Song': 0.007827788649706457,
 'Station': 0.0136986301369863,
 'TelevisionShow': 0.01761252446183953,
 'University': 0.005870841487279843,
 'owl#Thing': 0.0019569471624266144}

 python mprob.py "Kyoto"
{'AdministrativeRegion': 0.08986175115207373,
 'Award': 0.00030721966205837174,
 'City': 0.8935483870967742,
 'Company': 0.00030721966205837174,
 'Country': 0.00015360983102918587,
 'Diocese': 0.00015360983102918587,
 'GolfTournament': 0.00030721966205837174,
 'HistoricBuilding': 0.00030721966205837174,
 'Museum': 0.0006144393241167435,
 'PublicTransitSystem': 0.0004608294930875576,
 'Racecourse': 0.0015360983102918587,
 'RailwayLine': 0.0009216589861751152,
 'Single': 0.001075268817204301,
 'SoccerClub': 0.00030721966205837174,
 'Song': 0.001228878648233487,
 'Station': 0.0006144393241167435,
 'University': 0.004301075268817204,
 'WorldHeritageSite': 0.0004608294930875576,
 'owl#Thing': 0.003533026113671275}

 python mprob.py "Google"
{'Company': 0.9876835622927522,
 'InformationAppliance': 0.00023685457129322596,
 'Organisation': 7.895152376440866e-05,
 'Software': 0.0007105637138796779,
 'Website': 0.010342649613137533,
 'owl#Thing': 0.0009474182851729038}

python mprob.py "Python"
{'ComedyGroup': 0.005298318359824925,
 'Film': 0.0055286800276434,
 'ProgrammingLanguage': 0.9518544114259387,
 'Reptile': 0.008062658373646624,
 'RollerCoaster': 0.0034554250172771253,
 'Software': 0.000691085003455425,
 'TelevisionShow': 0.000230361667818475,
 'Weapon': 0.00598940336328035,
 'owl#Thing': 0.01888965676111495}

python mprob.py "Andy Hunt"
{'SoccerPlayer': 0.8780487804878049, 'Writer': 0.12195121951219512}

コードの中身

build_types.py

dbpediaのtypesをjsonへ。

import re
from tqdm import tqdm
import json

def load(filename):
    """Parse a DBpedia instance-types .ttl file into {entity: [unique types]}.

    Only lines starting with "<" are RDF triples; lines containing "__"
    (DBpedia intermediate/auxiliary nodes) are skipped.

    :param filename: path to the unpacked instance_types_en.ttl file
    :return: dict mapping entity name -> list of unique ontology type names
    """
    types_by_entity = {}
    with open(filename) as f:
        for line in tqdm(f):
            if line.startswith("<") and '__' not in line:
                parts = line.split()
                # "<http://dbpedia.org/resource/Foo>" -> "Foo"
                # ([:-1] drops the trailing ">")
                entity = parts[0].split("/")[-1][:-1]
                enttype = parts[2].split("/")[-1][:-1]
                # Dedupe on insert with a set instead of appending to a
                # list and running a second dedupe pass afterwards.
                types_by_entity.setdefault(entity, set()).add(enttype)
    return {entity: list(types) for entity, types in types_by_entity.items()}


def save(out, filename):
    """Serialize *out* to *filename* as pretty-printed, key-sorted JSON."""
    serialized = json.dumps(out, indent=4, sort_keys=True)
    with open(filename, "w") as f:
        f.write(serialized)

if __name__ == "__main__":
    save(load("./instance_types_en.ttl"), "out.json")

extract_mention.py

wikipediaのダンプからアンカーを取り出して、「表現(mention)」と「エンティティ(Wikipediaページ名)」に分けて、「ある表現があるエンティティである回数」の統計をとる。

# coding: utf-8
import re
from tqdm import tqdm

def extract_mention_and_entity(exp):
    """Split a wiki anchor like "[[Entity|mention]]" into (entity, mention).

    The entity gets its first letter capitalized and spaces replaced by
    underscores (Wikipedia page-name normalization).  Raises ValueError for
    anchors with more than one "|" and IndexError for empty anchors; callers
    are expected to skip those.
    """
    inner = exp[2:-2]
    capitalized = inner[0].upper() + inner[1:]
    if "|" in capitalized:
        # Intentionally no maxsplit: multi-pipe anchors must raise ValueError.
        entity, mention = capitalized.split("|")
        mention = mention.strip()
    else:
        entity, mention = capitalized, inner
    return entity.strip().replace(" ", "_"), mention

if __name__ == "__main__":
    import json
    reg = re.compile(r"\[\[.+?\]\]")
    out = {}
    with open("dump", errors='ignore') as f:
        for line in tqdm(f):
            exps = re.findall(reg, line)
            for exp in exps:
                try:
                    entity, mention = extract_mention_and_entity(exp)
                except:
                    continue
                if mention in out:
                    if entity in out[mention]:
                        out[mention][entity] += 1
                    else:
                        out[mention][entity] = 1
                else:
                    out[mention] = {}

    with open("mention_stat.json", "w") as f:
        json.dump(out, f)

json2marisa.py

生成したjsonデータをmarisa_trieへ変換。

import json
import sys
from marisa_trie import BytesTrie


def _json_to_trie(label, src, dst):
    """Read a str->object JSON mapping from *src*, store it as a BytesTrie at *dst*.

    Each value is re-serialized to UTF-8 JSON bytes, since BytesTrie maps
    unicode keys to byte payloads.
    """
    print("load " + label)
    with open(src) as f:
        mapping = json.load(f)
    print(label + " to trie")
    pairs = [(key, json.dumps(value).encode("utf-8")) for key, value in mapping.items()]
    print("saving...")
    BytesTrie(pairs).save(dst)


if __name__ == "__main__":
    _json_to_trie("types", "./types.json", "types.marisa")
    _json_to_trie("mention_stat", "./mention_stat.json", "mention_stat.marisa")
    print("Done!")

mprob.py

モジュール本体。

# coding: utf-8
from typing import Tuple, Dict
from marisa_trie import BytesTrie
import json

Probs = Dict[str, float]

def typeprob(mention: str, mstat: "BytesTrie", types: "BytesTrie") -> Probs:
    """
    Calculate type probabilities of the mention.
    Returns probabilities as dictionary (keys are type, values are prob).
    Returns an empty dict when no linked entity has type information
    (the original raised ZeroDivisionError in that case).
    pre: len(mention) > 0
    pre: len(mstat[mention]) > 0
    pre: type(json.loads(mstat[mention][0].decode())) == dict
    """
    counts = {}
    total = 0
    # mstat maps mention -> JSON bytes of {entity: link count}.
    stat = json.loads(mstat[mention][0].decode())
    for entity, count in stat.items():
        try:
            enttypes = json.loads(types[entity][0].decode())
        except KeyError:
            # Entity has no DBpedia type entry; narrowed from bare except,
            # which also hid genuinely broken payloads.
            continue
        n = int(count)
        for enttype in enttypes:
            counts[enttype] = counts.get(enttype, 0) + n
            total += n
    if total == 0:
        return {}
    return {t: c / total for t, c in counts.items()}


if __name__ == "__main__":
    import sys
    import pprint
    mstat = BytesTrie()
    types = BytesTrie()
    mstat.load("./mention_stat.marisa")
    types.load("./types.marisa")
    pprint.pprint(typeprob(sys.argv[1], mstat, types))

注意

このプログラムの実行には8G程度のメモリが必要です。