ナード戦隊データマン

Fighting evil with data science

Word-sense disambiguation

Word-sense disambiguation means picking the right entity for a word that has several possible meanings, based on the surrounding context. This post continues directly from the scripts built in the previous post.

How it works


  1. Build a co-occurrence graph with node = (entity, mention) and edge = (node1, node2).
  2. For a given mention, fetch its list of n candidate entities.
  3. Look up each of the n (entity, mention) pairs (e1,m), (e2,m), ..., (en,m) in the co-occurrence graph.
  4. For each lookup, collect the connected nodes, their count, and their mentions.
  5. Compute a score p from the node counts.
  6. Let q be the cosine similarity between the mean vector of the mentions attached to those nodes and the mean vector of the context words.
  7. For a weight alpha, compute r = alpha * p + (1 - alpha) * q.
  8. Normalize r over the candidates (a minimal sketch of steps 5-8 follows this list).
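
Concretely, steps 5 through 8 boil down to the following minimal sketch. The counts and similarities here are made up; the real values come from the co-occurrence graph and word2vec model used below.

# Step 5: p = each candidate's share of co-occurring nodes (hypothetical counts)
counts = {"Galileo_Galilei": 120, "Galileo_(spacecraft)": 45, "Galileo_(horse)": 15}
total = sum(counts.values())
p = {e: c / total for e, c in counts.items()}

# Step 6: q = cosine similarity against the context (hypothetical values)
q = {"Galileo_Galilei": 0.61, "Galileo_(spacecraft)": 0.34, "Galileo_(horse)": -0.05}

# Steps 7-8: blend with weight alpha, then normalize by the sum
alpha = 0.4
r = {e: alpha * p[e] + (1 - alpha) * q[e] for e in counts}
z = sum(r.values())
r = {e: v / z for e, v in r.items()}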

Execution

Script flow from the previous post:

git clone https://github.com/sugiyamath/entity_types_scripts
cd entity_types_scripts
wget http://downloads.dbpedia.org/2016-10/core/instance_types_en.ttl.bz2
wget https://dumps.wikimedia.org/enwiki/latest/enwiki-latest-pages-articles.xml.bz2
bunzip2 *.bz2
mv enwiki-latest-pages-articles.xml dump
python build_types.py
python extract_mention.py
python json2marisa.py
python mprob.py "Obama"

Run the scripts below only after all of the previous post's scripts have been executed.

python mention_and_graph.py
python pkl2marisa.py
python extract_graph.py
python build_graph.py
python create_index.py
python eprob.py "Galileo" "Japan,TV" 0.4

Example runs

python eprob.py Galileo Japan,TV 0.4
{'Galileo_(1968_film)': 0.08162450203592601,
 'Galileo_(1975_film)': 0.07264921641693314,
 'Galileo_(1994_film)': 0.06521374088908712,
 'Galileo_(TV_series)': 0.18111023032336127,
 'Galileo_(horse)': 0.09737876204706845,
 'Galileo_(operating_system)': 0.04316242003252071,
 'Galileo_(satellite_navigation)': 0.13500750427325464,
 'Galileo_(song)': 0.03378523117010994,
 'Galileo_(spacecraft)': 0.12453115091750815,
 'Galileo_Galilei': 0.07404413908764415,
 'Intel_Galileo': 0.06936555756415227,
 'Life_of_Galileo': 0.022127545242434238}

python eprob.py Galileo person,history 0.4
{'Galileo_(1968_film)': 0.044064635682177,
 'Galileo_(1975_film)': -0.03600631439106292,
 'Galileo_(1994_film)': 0.15265454663089653,
 'Galileo_(TV_series)': 0.0013680732885760023,
 'Galileo_(horse)': 0.10470165517486565,
 'Galileo_(operating_system)': 0.0005596730220754116,
 'Galileo_(satellite_navigation)': 0.0738529599122509,
 'Galileo_(song)': 0.05449777719438314,
 'Galileo_(spacecraft)': 0.16406225919086243,
 'Galileo_Galilei': 0.37058668171965764,
 'Intel_Galileo': 0.07787230258914513,
 'Life_of_Galileo': -0.00821425001382678}

Note that these scores are not true probabilities: q is a raw cosine similarity and can be negative, and calc_entityprob normalizes by a plain sum rather than a softmax, so individual scores may dip below zero, as with Galileo_(1975_film) and Life_of_Galileo here.

Inside the code

mention_and_graph.py

Assigns an ID to each (entity, mention) pair extracted from the dump, keyed by the pair itself.

# coding: utf-8
import re
from tqdm import tqdm
import pickle


def extract_mention_and_entity(exp):
    tmp = exp[2:-2]
    tmp2 = tmp[0].upper() + tmp[1:]
    if "|" in tmp2:
        entity, mention = tmp2.split("|")
        mention = mention.strip()
    else:
        entity = tmp2[:]
        mention = tmp[:]
    entity = entity.strip()
    entity = entity.replace(" ", "_")
    return entity, mention


if __name__ == "__main__":
    reg = re.compile(r"\[\[.+?\]\]")
    out = {}
    counter = 0
    with open("dump", errors='ignore') as f1:
        for line in tqdm(f1):
            for x in re.findall(reg, line):
                try:
                    entity, mention = extract_mention_and_entity(x)
                except Exception:
                    continue
                key = (entity, mention)
                if key in out:
                    continue
                out[key] = counter
                counter += 1

    with open("me2id.pkl", "wb") as f2:
        pickle.dump(out, f2)
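
To see what extract_mention_and_entity returns for the two wiki-link forms (outputs derived from the function's logic, not from the dump):

print(extract_mention_and_entity("[[Barack Obama|Obama]]"))
# -> ('Barack_Obama', 'Obama')
print(extract_mention_and_entity("[[data science]]"))
# -> ('Data_science', 'data science')  # entity capitalized, mention keeps source casing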

pkl2marisa.py

Converts me2id.pkl into marisa tries, one per lookup direction (pair → ID and ID → pair).

from marisa_trie import BytesTrie
import pickle
from tqdm import tqdm

if __name__ == "__main__":
    with open("./me2id.pkl", "rb") as f:
        em2id = pickle.load(f)

    trie = BytesTrie(
        [(str(x[0]), bytes(str(x[1]), 'utf-8')) for x in tqdm(em2id.items())])
    trie.save("em2id.marisa")

    trie = BytesTrie(
        [(str(x[1]), bytes(str(x[0]), 'utf-8'))
         for x in tqdm(em2id.items())])
    trie.save("id2em.marisa")

extract_graph.py

Writes the graph of co-occurring (entity, mention) pairs to a text file, one edge per line.

# coding: utf-8
import re
from tqdm import tqdm
import pickle
from marisa_trie import BytesTrie


def extract_mention_and_entity(exp):
    tmp = exp[2:-2]
    tmp2 = tmp[0].upper() + tmp[1:]
    if "|" in tmp2:
        entity, mention = tmp2.split("|")
        mention = mention.strip()
    else:
        entity = tmp2[:]
        mention = tmp[:]
    entity = entity.strip()
    entity = entity.replace(" ", "_")
    return entity, mention


if __name__ == "__main__":
    trie = BytesTrie()
    trie.load("./types.marisa")
    with open("me2id.pkl", "rb") as f:
        em2id = pickle.load(f)

    reg = re.compile(r"\[\[.+?\]\]")
    out = {}
    with open("dump", errors='ignore') as f1:
        with open("graph.txt", "w") as f2:
            for line in tqdm(f1):
                ents = []
                mentions = []
                for x in re.findall(reg, line):
                    try:
                        entity, mention = extract_mention_and_entity(x)
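                        # trie[entity] is a membership check: it raises KeyError
                        # for entities without a known type, which skips them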
                        trie[entity]
                        ents.append(entity)
                        mentions.append(mention)
                    except Exception:
                        continue
                try:
                    assert len(ents) == len(mentions)
                except AssertionError:
                    continue

                pairs = sorted(list(zip(ents, mentions)))
                for i in range(len(pairs) - 1):
                    for j in range(i + 1, len(pairs)):
                        pair = [em2id[tuple(pairs[i])], em2id[tuple(pairs[j])]]
                        f2.write(str(pair[0]) + "\t" + str(pair[1]) + "\n")
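
The nested loops emit every unordered pair of nodes that share a line; for reference, the same enumeration with itertools.combinations (an equivalent sketch, not the original code):

from itertools import combinations

for a, b in combinations(pairs, 2):
    f2.write(str(em2id[a]) + "\t" + str(em2id[b]) + "\n")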

build_graph.py

Loads the text file into SQLite. Each edge is inserted in both directions, so lookups only ever need to filter on from_id.

from tqdm import tqdm
import sqlite3


def create_table(conn):
    c = conn.cursor()
    sql = """
create table if not exists graph (
    id integer primary key,
    from_id integer NOT NULL,
    to_id integer NOT NULL
);
"""
    c.execute(sql)


def insert_graph(conn, f, t):
    c = conn.cursor()
    sql = "insert into graph(from_id,to_id) values (?,?)"
    c.execute(sql, (f, t))


if __name__ == "__main__":
    conn = sqlite3.connect("db.sqlite3")
    create_table(conn)
    with open("graph.txt") as f:
        for line in tqdm(f):
            line = list(map(int, line.strip().split("\t")))
            insert_graph(conn, line[0], line[1])
            insert_graph(conn, line[1], line[0])
    conn.commit()
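
One execute() call per row is slow for a dump-sized graph; a batched variant using sqlite3's executemany (a sketch, not the original code) would look like:

from itertools import islice

def insert_graph_batched(conn, rows, size=100000):
    # rows is an iterable of (from_id, to_id) tuples, already doubled per direction
    sql = "insert into graph(from_id,to_id) values (?,?)"
    it = iter(rows)
    while True:
        chunk = list(islice(it, size))
        if not chunk:
            break
        conn.executemany(sql, chunk)
    conn.commit()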

create_index.py

Creates indexes on from_id and to_id.

import sqlite3


def create_index(conn):
    c = conn.cursor()
    sql1 = "create index index_from_id_graph on graph(from_id)"
    sql2 = "create index index_to_id_graph on graph(to_id)"
    c.execute(sql1)
    c.execute(sql2)
    conn.commit()


if __name__ == "__main__":
    conn = sqlite3.connect("./db.sqlite3")
    create_index(conn)
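
To confirm that queries on from_id actually use the new index, SQLite's standard EXPLAIN QUERY PLAN can be checked:

c = conn.cursor()
c.execute("explain query plan select to_id from graph where from_id=?", (1,))
print(c.fetchall())  # expect a SEARCH using index_from_id_graph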

comention.py

A helper module used by eprob.py. Given a mention, it looks up the candidate entities from mention_stat.marisa and returns, for each candidate, the list of (entity, mention) nodes connected to it in the graph.

# coding: utf-8
import ast
from marisa_trie import BytesTrie
import sqlite3
import json
from typing import Dict, Tuple, List


def get_comentions(conn: sqlite3.Connection,
                   mention: str,
                   mstat: BytesTrie,
                   em2id: BytesTrie,
                   id2em: BytesTrie) -> Dict[str, List[Tuple[str, str]]]:
    out = {}
    c = conn.cursor()
    sql = "select to_id from graph where from_id=?"
    es = list(json.loads((mstat[mention][0].decode())).keys())
    ids = [int(em2id[str((e, mention))][0]) for e in es]
    for idx, e in zip(ids, es):
        c.execute(sql, (idx, ))
        # IDs were stored as str(id); values decode back to "(entity, mention)" tuples
        tmp = [ast.literal_eval(id2em[str(x[0])][0].decode()) for x in c.fetchall()]
        if tmp:
            out[e] = tmp
    return out


if __name__ == "__main__":
    import sys
    from pprint import pprint
    conn = sqlite3.connect("./db.sqlite3")
    trie = BytesTrie().load("./mention_stat.marisa")
    trie2 = BytesTrie().load("./em2id.marisa")
    trie3 = BytesTrie().load("./id2em.marisa")
    pprint(get_comentions(conn, sys.argv[1], trie, trie2, trie3))
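
The return value maps each candidate entity to the (entity, mention) nodes linked to it in the graph; shape-wise (with made-up entries) it looks like:

# {'Galileo_Galilei': [('Pisa', 'Pisa'), ('Telescope', 'telescope'), ...],
#  'Galileo_(spacecraft)': [('Jupiter', 'Jupiter'), ...]}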

eprob.py

The module proper: it combines the co-occurrence score p and the context-similarity score q into the final ranking.

# coding: utf-8
import sqlite3
import numpy as np
from marisa_trie import BytesTrie
from comention import get_comentions


# Step 5 of "How it works": p is each candidate's share of co-occurring nodes
def calc_prob1(coms):
    out = {}
    total = 0
    for k, v in coms.items():
        val = len(v)
        out[k] = val
        total += val
    return {k: float(v) / float(total) for k, v in out.items()}


def cossim(v1, v2):
    return np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2))


# Alternative normalization; defined here but not used by compute() below
def softmax_dict(d):
    ks, vs = [], []
    for k, v in d.items():
        ks.append(k)
        vs.append(v)
    vs = np.array(vs)
    m = np.max(vs)
    ex = np.exp(vs - m)
    result = ex / ex.sum(axis=0)
    return dict(zip(ks, result.tolist()))


# Step 6: q is the cosine similarity between the mean vector of the co-occurring
# mentions and the mean vector of the context words
def calc_prob2(model, coms, contexts):
    out = {}
    w = []
    for c in contexts:
        try:
            w.append(model.wv[c])
        except Exception:
            continue
    w = np.mean(w, axis=0)
    for k, xss in coms.items():
        v = []
        for xs in xss:
            for x in xs[1].split():
                try:
                    v.append(model.wv[x])
                except Exception:
                    continue
        v = np.mean(v, axis=0)
        out[k] = cossim(v, w)
    return out


# Steps 7-8: r = alpha * p + (1 - alpha) * q, normalized by the sum over candidates
def calc_entityprob(prob1, prob2, alpha=0.5):
    out = {}
    total = 0
    for k, v1 in prob1.items():
        v2 = prob2[k]
        value = alpha * v1 + (1 - alpha) * v2
        out[k] = value
        total += value
    return {k: float(v) / float(total) for k, v in out.items()}


def compute(mention, contexts, model, conn, mstat, em2id, id2em, alpha=0.5):
    coms = get_comentions(conn, mention, mstat, em2id, id2em)
    prob1 = calc_prob1(coms)
    prob2 = calc_prob2(model, coms, contexts)
    return calc_entityprob(prob1, prob2, alpha)


if __name__ == "__main__":
    import sys
    from pprint import pprint
    from gensim.models import KeyedVectors
    mention = sys.argv[1]
    contexts = sys.argv[2].split(",")
    alpha = float(sys.argv[3])
    conn = sqlite3.connect("./db.sqlite3")
    trie = BytesTrie().load("./mention_stat.marisa")
    trie2 = BytesTrie().load("./em2id.marisa")
    trie3 = BytesTrie().load("./id2em.marisa")
    model = KeyedVectors.load("./enwiki_model/word2vec.model", mmap="r")
    pprint(compute(mention, contexts, model, conn,
                   trie, trie2, trie3, alpha=alpha))
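
Called as a library instead of from the command line, the same computation looks like this (paths as in the script; the word2vec model comes from the previous post's setup):

import sqlite3
from marisa_trie import BytesTrie
from gensim.models import KeyedVectors
from eprob import compute

conn = sqlite3.connect("./db.sqlite3")
mstat = BytesTrie().load("./mention_stat.marisa")
em2id = BytesTrie().load("./em2id.marisa")
id2em = BytesTrie().load("./id2em.marisa")
model = KeyedVectors.load("./enwiki_model/word2vec.model", mmap="r")
scores = compute("Galileo", ["Japan", "TV"], model, conn,
                 mstat, em2id, id2em, alpha=0.4)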