Source code for kiwi.data.vocabulary

#  OpenKiwi: Open-Source Machine Translation Quality Estimation
#  Copyright (C) 2019 Unbabel <openkiwi@unbabel.com>
#
#  This program is free software: you can redistribute it and/or modify
#  it under the terms of the GNU Affero General Public License as published
#  by the Free Software Foundation, either version 3 of the License, or
#  (at your option) any later version.
#
#  This program is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU Affero General Public License for more details.
#
#  You should have received a copy of the GNU Affero General Public License
#  along with this program.  If not, see <https://www.gnu.org/licenses/>.
#

import warnings
from collections import defaultdict

import torchtext

from kiwi.constants import PAD, START, STOP, UNALIGNED, UNK, UNK_ID


def _default_unk_index():
    return UNK_ID  # should be zero


class Vocabulary(torchtext.vocab.Vocab):
    """Defines a vocabulary object that will be used to numericalize a field.

    Attributes:
        freqs: A collections.Counter object holding the frequencies of tokens
            in the data used to build the Vocab.
        stoi: A collections.defaultdict instance mapping token strings to
            numerical identifiers.
        itos: A list of token strings indexed by their numerical identifiers.
    """

    def __init__(
        self,
        counter,
        max_size=None,
        min_freq=1,
        specials=None,
        vectors=None,
        unk_init=None,
        vectors_cache=None,
        rare_with_vectors=True,
        add_vectors_vocab=False,
    ):
        """Create a Vocab object from a collections.Counter.

        Arguments:
            counter: collections.Counter object holding the frequencies of
                each value found in the data.
            max_size: The maximum size of the vocabulary, or None for no
                maximum. Default: None.
            min_freq: The minimum frequency needed to include a token in the
                vocabulary. Values less than 1 will be set to 1. Default: 1.
            specials: The list of special tokens (e.g., padding or eos) that
                will be prepended to the vocabulary in addition to an <unk>
                token. Default: ['<pad>'].
            vectors: One of either the available pretrained vectors or custom
                pretrained vectors (see Vocab.load_vectors); or a list of the
                aforementioned vectors.
            unk_init (callback): by default, initialize out-of-vocabulary word
                vectors to zero vectors; can be any function that takes in a
                Tensor and returns a Tensor of the same size.
                Default: torch.Tensor.zero_.
            vectors_cache: directory for cached vectors.
                Default: '.vector_cache'.
            rare_with_vectors: if True and a vectors object is passed, words
                that appear less often than min_freq but are in the vectors
                vocabulary will still be added. Default: True.
            add_vectors_vocab: by default, the vocabulary is built using only
                words from the provided datasets. If this flag is True, the
                vocabulary will also add words that are not in the datasets
                but are in the vectors vocabulary (e.g. words from polyglot
                vectors). Default: False.
        """
        if specials is None:
            specials = ['<pad>']
        self.freqs = counter
        counter = counter.copy()
        min_freq = max(min_freq, 1)
        self.itos = list(specials)

        # Frequencies of special tokens are not counted when building the
        # vocabulary in frequency order.
        for tok in specials:
            del counter[tok]

        max_size = None if max_size is None else max_size + len(self.itos)

        # Sort by frequency, then alphabetically.
        words_and_frequencies = sorted(counter.items(), key=lambda tup: tup[0])
        words_and_frequencies.sort(key=lambda tup: tup[1], reverse=True)

        if not isinstance(vectors, list) and vectors is not None:
            vectors = [vectors]

        # Add words that appear less often than min_freq but are in the
        # embeddings vocabulary.
        for word, freq in words_and_frequencies:
            if freq < min_freq:
                if vectors is not None and rare_with_vectors:
                    for v in vectors:
                        if word in v.stoi:
                            self.itos.append(word)
                else:
                    break
            elif len(self.itos) == max_size:
                break
            else:
                self.itos.append(word)

        if add_vectors_vocab:
            if (
                max_size is not None
                and sum(len(v.stoi) for v in vectors) + len(self.itos)
                > max_size
            ):
                warnings.warn(
                    'Adding the vectors vocabulary will make '
                    'len(vocab) > max_vocab_size!'
                )
            vset = set()
            for v in vectors:
                vset.update(v.stoi.keys())
            v_itos = vset - set(self.itos)
            self.itos.extend(list(v_itos))

        self.stoi = defaultdict(_default_unk_index)
        # stoi is simply a reverse dict for itos.
        self.stoi.update({tok: i for i, tok in enumerate(self.itos)})

        self.vectors = None
        if vectors is not None:
            self.load_vectors(vectors, unk_init=unk_init, cache=vectors_cache)
        else:
            assert unk_init is None and vectors_cache is None
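

# Illustrative usage sketch (not part of the original module): building a
# small Vocabulary from raw token counts, without pretrained vectors.
# Assumes UNK_ID == 0 so the UNK special token sits at the index returned
# by _default_unk_index; token and variable names below are made up.
if __name__ == '__main__':
    from collections import Counter

    _counts = Counter('the cat sat on the mat'.split())
    _vocab = Vocabulary(_counts, specials=[UNK, PAD])
    print(_vocab.stoi['the'])  # in-vocabulary token -> its assigned id
    print(_vocab.stoi['dog'])  # OOV token -> UNK_ID via _default_unk_index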


def merge_vocabularies(vocab_a, vocab_b, max_size=None, vectors=None, **kwargs):
    """Merge two vocabularies by summing their token frequencies."""
    merged = vocab_a.freqs + vocab_b.freqs
    return Vocabulary(
        merged,
        specials=[UNK, PAD, START, STOP, UNALIGNED],
        max_size=max_size,
        vectors=vectors,
        **kwargs,
    )
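

# Illustrative sketch (not part of the original module): merging the
# vocabularies of two fields (e.g. source and target) into a shared one
# via merge_vocabularies. The toy token strings are made up.
if __name__ == '__main__':
    from collections import Counter

    _vocab_src = Vocabulary(Counter('a b b c'.split()), specials=[UNK, PAD])
    _vocab_tgt = Vocabulary(Counter('b c c d'.split()), specials=[UNK, PAD])
    _shared = merge_vocabularies(_vocab_src, _vocab_tgt)
    print(len(_shared))  # the five specials plus the union {a, b, c, d}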