#  OpenKiwi: Open-Source Machine Translation Quality Estimation
#  Copyright (C) 2019 Unbabel <openkiwi@unbabel.com>
#
#  This program is free software: you can redistribute it and/or modify
#  it under the terms of the GNU Affero General Public License as published
#  by the Free Software Foundation, either version 3 of the License, or
#  (at your option) any later version.
#
#  This program is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU Affero General Public License for more details.
#
#  You should have received a copy of the GNU Affero General Public License
#  along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
import warnings
from collections import defaultdict
import torchtext
from kiwi.constants import PAD, START, STOP, UNALIGNED, UNK, UNK_ID
def _default_unk_index():
    """Return the index assigned to tokens that are not in the vocabulary.

    Used as the ``default_factory`` of the ``stoi`` mapping built in
    ``Vocabulary.__init__``, so looking up an unknown token yields ``UNK_ID``.

    NOTE(review): presumably a named module-level function (rather than a
    lambda) so that the ``defaultdict`` stays picklable -- confirm.
    """
    return UNK_ID  # should be zero
class Vocabulary(torchtext.vocab.Vocab):
    """Defines a vocabulary object that will be used to numericalize a field.

    Attributes:
        freqs: A collections.Counter object holding the frequencies of tokens
            in the data used to build the Vocab.
        stoi: A collections.defaultdict instance mapping token strings to
            numerical identifiers. Unknown tokens map to ``UNK_ID``.
        itos: A list of token strings indexed by their numerical identifiers.
        vectors: ``None`` unless pretrained vectors were passed, in which case
            it is whatever ``load_vectors`` sets.
    """

    def __init__(
        self,
        counter,
        max_size=None,
        min_freq=1,
        specials=None,
        vectors=None,
        unk_init=None,
        vectors_cache=None,
        rare_with_vectors=True,
        add_vectors_vocab=False,
    ):
        """Create a Vocab object from a collections.Counter.

        NOTE: the parent initializer is not invoked; every attribute is
        populated here.

        Arguments:
            counter: collections.Counter object holding the frequencies of
                each value found in the data.
            max_size: The maximum number of non-special tokens taken from
                ``counter``, or None for no maximum. Default: None.
                Words added through ``rare_with_vectors`` or
                ``add_vectors_vocab`` are not bounded by this value.
            min_freq: The minimum frequency needed to include a token in the
                vocabulary. Values less than 1 will be set to 1. Default: 1.
            specials: The list of special tokens (e.g., padding or eos) that
                will be prepended to the vocabulary. No unknown token is
                added implicitly; since missing tokens map to ``UNK_ID``,
                the unknown token is expected to sit at that position.
                Default: ['<pad>']
            vectors: One of either the available pretrained vectors
                or custom pretrained vectors (see Vocab.load_vectors);
                or a list of aforementioned vectors
            unk_init (callback): by default, initialize out-of-vocabulary word
                vectors to zero vectors; can be any function that takes in a
                Tensor and returns a Tensor of the same size.
                Default: torch.Tensor.zero_
            vectors_cache: directory for cached vectors.
                Default: '.vector_cache'
            rare_with_vectors: if True and a vectors object is passed, then
                it will add words that appears less than min_freq but are in
                vectors vocabulary. Default: True.
            add_vectors_vocab: by default, the vocabulary is built using only
                words from the provided datasets. If this flag is true, the
                vocabulary will add words that are not in the datasets but are
                in the vectors vocabulary (e.g. words from polyglot vectors).
                Default: False.
        """
        if specials is None:
            specials = ['<pad>']
        self.freqs = counter
        # Work on a copy so that removing the specials below does not alter
        # the frequencies exposed through ``self.freqs``.
        counter = counter.copy()
        min_freq = max(min_freq, 1)

        self.itos = list(specials)
        # frequencies of special tokens are not counted when building vocabulary
        # in frequency order
        for tok in specials:
            del counter[tok]

        # The specials do not count towards the requested maximum size.
        max_size = None if max_size is None else max_size + len(self.itos)

        # sort by frequency (descending), then alphabetically; the second sort
        # is stable, so ties keep their alphabetical order.
        words_and_frequencies = sorted(counter.items(), key=lambda tup: tup[0])
        words_and_frequencies.sort(key=lambda tup: tup[1], reverse=True)

        if not isinstance(vectors, list) and vectors is not None:
            vectors = [vectors]

        # Words rarer than min_freq are kept only if some vectors object
        # knows them. Without that option the first rare word ends the scan,
        # because the list is sorted by descending frequency.
        keep_rare = vectors is not None and rare_with_vectors
        for word, freq in words_and_frequencies:
            if freq < min_freq:
                if not keep_rare:
                    break
                # ``any`` guarantees the word is appended once, even when it
                # appears in several vectors vocabularies.
                if any(word in v.stoi for v in vectors):
                    self.itos.append(word)
            elif len(self.itos) == max_size:
                break
            else:
                self.itos.append(word)

        if add_vectors_vocab:
            if vectors is None:
                warnings.warn(
                    'add_vectors_vocab is set but no vectors were provided; '
                    'nothing to add.'
                )
            else:
                # Collect unseen words in the vectors' own order so that the
                # resulting indices are reproducible across runs.
                known = set(self.itos)
                extra_words = []
                for v in vectors:
                    for word in v.stoi:
                        if word not in known:
                            known.add(word)
                            extra_words.append(word)
                if (
                    max_size is not None
                    and len(self.itos) + len(extra_words) > max_size
                ):
                    warnings.warn(
                        'Adding the vectors vocabulary will make '
                        'len(vocab) > max_vocab_size!'
                    )
                self.itos.extend(extra_words)

        self.stoi = defaultdict(_default_unk_index)
        # stoi is simply a reverse dict for itos
        self.stoi.update({tok: i for i, tok in enumerate(self.itos)})

        self.vectors = None
        if vectors is not None:
            self.load_vectors(vectors, unk_init=unk_init, cache=vectors_cache)
        else:
            # These options are only meaningful together with ``vectors``.
            assert unk_init is None and vectors_cache is None
def merge_vocabularies(vocab_a, vocab_b, max_size=None, vectors=None, **kwargs):
    """Build a new Vocabulary from the combined frequencies of two others.

    Arguments:
        vocab_a: first vocabulary; only its ``freqs`` attribute is read.
        vocab_b: second vocabulary; only its ``freqs`` attribute is read.
        max_size: forwarded to ``Vocabulary``. Default: None.
        vectors: forwarded to ``Vocabulary``. Default: None.
        **kwargs: any other ``Vocabulary`` keyword arguments. ``specials`` is
            set by this function, so passing it here raises a ``TypeError``.

    Returns:
        A ``Vocabulary`` built from ``vocab_a.freqs + vocab_b.freqs`` with
        ``[UNK, PAD, START, STOP, UNALIGNED]`` as special tokens.
    """
    merged = vocab_a.freqs + vocab_b.freqs
    return Vocabulary(
        merged,
        specials=[UNK, PAD, START, STOP, UNALIGNED],
        max_size=max_size,
        vectors=vectors,
        **kwargs,
    )