データナード

機械学習と自然言語処理についての備忘録 (旧ナード戦隊データマン)

パラレルコーパスフィルタリングのMarginベーススコアリング

Marginベーススコアリング手法は、近傍探索を類似度に組み合わせた手法で、Facebook AI Researchによって論文が公開されています。

実装したいもの

Margin-based Parallel Corpus Mining with Multilingual Sentence Embeddings https://arxiv.org/pdf/1811.01136.pdf

簡易実装

import numpy as np


class Similarity:
    def __init__(self, src_index, tgt_index, k):
        self._index_s = src_index
        self._index_t = tgt_index
        self._k = k

    def calc(self, vs1, vs2):
        """Calculate similarity between two vectors"""
        return self._margin(vs1, vs2)

    def _cosine(self, v1, v2):
        top = np.dot(v1, v2)
        return top

    def _margin_single(self, v1, v2, m1, m2):
        return 2 * self._cosine(v1, v2) / (m1+m2)

    def _margin(self, vs1, vs2):
        ms1 = self._nncos(vs1, "src")
        ms2 = self._nncos(vs2, "tgt")
        for i, (v1, v2, m1, m2) in enumerate(zip(vs1, vs2, ms1, ms2)):
            yield self._margin_single(v1, v2, m1, m2)

    def _nncos(self, vs, v_type="src"):
        if v_type == "src":
            D, _ = self._index_t.search(np.array(vs, dtype=np.float32),
                                        self._k)
        elif v_type == "tgt":
            D, _ = self._index_s.search(np.array(vs, dtype=np.float32),
                                        self._k)
        else:
            raise Exception("Wrong v_type: {}".format(v_type))
        return D.mean(axis=1)

使い方

import numpy as np
import faiss
from similarity import Similarity

vs1 = np.load("vs1.npy")
vs2 = np.load("vs2.npy")
idx1 = faiss.read_index("idx1.faiss")
idx2 = faiss.read_index("idx2.faiss")

sim = Similarity(idx1, idx2, k=5)
scores = list(sim.calc(vs1, vs2))

課題

faissは高速な類似検索ですが、上記のデータが膨大な場合、実行に何日もかかってしまう恐れがあります。膨大なクエリに対処するためには以下の2点を考慮するのがよいかもしれません。

  • GPU版のfaissを使う。
  • IndexIVFFlatを使う。
method search time 1-R@1 index size index build time
Flat-CPU 9.100 s 1.0000 512 MB 0 s
Flat-GPU (Titan X) 0.753 s 0.9935 512 MB 0 s
IVF16384,Flat 0.538 s 0.8980 512 + 8 MB 240 s
IVF16384,Flat (Titan X) 0.059 s 0.8145 512 + 8 MB 5 s

時間計算量が高すぎるので、Margin-based手法を使うよりは、ドット積を使ったほうが高速に結果が得られます。この場合、速度と精度はトレードオフにあります。

参考

  1. Indexing 1M vectors · facebookresearch/faiss Wiki · GitHub
  2. https://arxiv.org/pdf/1811.01136.pdf