Source code for kiwi.data.utils

#  OpenKiwi: Open-Source Machine Translation Quality Estimation
#  Copyright (C) 2019 Unbabel <openkiwi@unbabel.com>
#
#  This program is free software: you can redistribute it and/or modify
#  it under the terms of the GNU Affero General Public License as published
#  by the Free Software Foundation, either version 3 of the License, or
#  (at your option) any later version.
#
#  This program is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU Affero General Public License for more details.
#
#  You should have received a copy of the GNU Affero General Public License
#  along with this program.  If not, see <https://www.gnu.org/licenses/>.
#

import copy
import logging
from collections import defaultdict
from math import ceil
from pathlib import Path

import torch

from kiwi import constants as const
from kiwi.data.fieldsets.fieldset import Fieldset

logger = logging.getLogger(__name__)


def serialize_vocabs(vocabs, include_vectors=False):
    """Make vocab dictionary serializable."""
    serialized_vocabs = []
    for name, vocab in vocabs.items():
        vocab = copy.copy(vocab)
        vocab.stoi = dict(vocab.stoi)
        if not include_vectors:
            vocab.vectors = None
        serialized_vocabs.append((name, vocab))
    return serialized_vocabs


def deserialize_vocabs(vocabs):
    """Restore the defaultdict lost in serialization."""
    vocabs = dict(vocabs)
    for name, vocab in vocabs.items():
        # Hack. Can't pickle defaultdict :(
        vocab.stoi = defaultdict(lambda: const.UNK_ID, vocab.stoi)
    return vocabs


def serialize_fields_to_vocabs(fields):
    """Extract Vocab objects from Field objects and make them serializable.

    From OpenNMT.
    """
    vocabs = fields_to_vocabs(fields)
    vocabs = serialize_vocabs(vocabs)
    return vocabs


def deserialize_fields_from_vocabs(fields, vocabs):
    """Load serialized vocabularies into their fields."""
    # TODO: redundant deserialization
    vocabs = deserialize_vocabs(vocabs)
    return fields_from_vocabs(fields, vocabs)


def fields_from_vocabs(fields, vocabs):
    """Load Field objects from a dict of vocabs.

    From OpenNMT.
    """
    vocabs = deserialize_vocabs(vocabs)
    for name, vocab in vocabs.items():
        if name not in fields:
            logger.debug(
                'No field "{}" for loading vocabulary; ignoring.'.format(name)
            )
        else:
            fields[name].vocab = vocab
    return fields


def fields_to_vocabs(fields):
    """Extract a vocab dictionary from a fields dictionary.

    Args:
        fields: A dict mapping field names to Field objects.

    Returns:
        vocabs: A dict mapping field names to Vocab objects.
    """
    vocabs = {}
    for name, field in fields.items():
        if field is not None and 'vocab' in field.__dict__:
            vocabs[name] = field.vocab
    return vocabs


def save_vocabularies_from_fields(directory, fields, include_vectors=False):
    """Save the Vocab objects in Field objects to a `vocab.pt` file in `directory`.

    From OpenNMT.
    """
    vocabs = serialize_fields_to_vocabs(fields)
    vocab_path = Path(directory, const.VOCAB_FILE)
    torch.save({const.VOCAB: vocabs}, str(vocab_path))
    return vocab_path


def load_vocabularies_to_fields(vocab_path, fields):
    """Load serialized vocabularies from disk into fields."""
    if Path(vocab_path).exists():
        vocabs_dict = torch.load(
            str(vocab_path), map_location=lambda storage, loc: storage
        )
        vocabs = vocabs_dict[const.VOCAB]
        fields = deserialize_fields_from_vocabs(fields, vocabs)
        logger.info('Loaded vocabularies from {}'.format(vocab_path))
        return all(
            vocab_loaded_if_needed(field) for _, field in fields.items()
        )
    return False

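
# Illustrative sketch (not part of the original module): how the vocabulary
# save/load helpers above are meant to round-trip. `run_dir` and
# `train_dataset` are hypothetical placeholders for a run directory and an
# already-built dataset with fields.
#
#     vocab_path = save_vocabularies_from_fields('run_dir', train_dataset.fields)
#     ok = load_vocabularies_to_fields(vocab_path, train_dataset.fields)
#     assert ok  # every field that uses a vocab should now have one attached
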

def load_vocabularies_to_datasets(vocab_path, *datasets):
    fields = {}
    for dataset in datasets:
        fields.update(dataset.fields)
    return load_vocabularies_to_fields(vocab_path, fields)


def vocab_loaded_if_needed(field):
    return not field.use_vocab or (hasattr(field, const.VOCAB) and field.vocab)


def save_vocabularies_from_datasets(directory, *datasets):
    fields = {}
    for dataset in datasets:
        fields.update(dataset.fields)
    return save_vocabularies_from_fields(directory, fields)


def build_vocabulary(fields_vocab_options, *datasets):
    fields = {}
    for dataset in datasets:
        fields.update(dataset.fields)
    for name, field in fields.items():
        if not vocab_loaded_if_needed(field):
            kwargs_vocab = fields_vocab_options[name]
            if 'vectors_fn' in kwargs_vocab:
                vectors_fn = kwargs_vocab['vectors_fn']
                kwargs_vocab['vectors'] = vectors_fn()
                del kwargs_vocab['vectors_fn']
            field.build_vocab(*datasets, **kwargs_vocab)

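
# Illustrative sketch (assumption, not from the original module): the expected
# shape of `fields_vocab_options` is a dict of per-field keyword arguments for
# `Field.build_vocab`, optionally with a zero-argument `vectors_fn` that lazily
# loads pretrained vectors. Field names and option values are hypothetical.
#
#     fields_vocab_options = {
#         'source': {'min_freq': 1, 'max_size': 50000},
#         'target': {'min_freq': 1, 'vectors_fn': lambda: None},
#     }
#     build_vocabulary(fields_vocab_options, train_dataset)
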

def load_datasets(directory, *datasets_names):
    dataset_path = Path(directory, const.DATAFILE)
    dataset_dict = torch.load(
        str(dataset_path), map_location=lambda storage, loc: storage
    )
    datasets = [dataset_dict[name] for name in datasets_names]
    return datasets


def save_datasets(directory, **named_datasets):
    """Pickle datasets to the standard file in `directory`.

    Note that fields cannot be saved as part of a dataset, so they are
    temporarily removed before pickling and restored afterwards.

    Args:
        directory (str or Path): directory where to save the datasets pickle.
        named_datasets (dict): mapping of name to respective dataset.
    """
    # Fields cannot be pickled; stash them in a temporary list.
    dataset_fields_tmp = []
    for dataset in named_datasets.values():
        dataset_fields_tmp.append(dataset.fields)
        dataset.fields = []
    logging.info('Saving preprocessed datasets...')
    dataset_path = Path(directory, const.DATAFILE)
    torch.save(named_datasets, str(dataset_path))
    # Restore dataset.fields from the temporary list.
    for dataset, fields in zip(named_datasets.values(), dataset_fields_tmp):
        dataset.fields = fields


def save_training_datasets(directory, train_dataset, valid_dataset):
    ds_dict = {const.TRAIN: train_dataset, const.EVAL: valid_dataset}
    save_datasets(directory, **ds_dict)


def load_training_datasets(directory, fieldset):
    # FIXME: test if this works. Ideally, fields would be already contained
    #        inside the loaded datasets.
    train_ds, valid_ds = load_datasets(directory, const.TRAIN, const.EVAL)
    # Remove fields not actually loaded (checking whether they are required).
    fields = fieldset.fields
    for field in dict(fields):  # Iterate over a copy so `del` can be used.
        if not hasattr(train_ds.examples[0], field):
            for set_name in [Fieldset.TRAIN, Fieldset.VALID]:
                if fieldset.is_required(field, set_name):
                    raise AttributeError(
                        'Loaded {} dataset does not have a '
                        '{} field.'.format(set_name, field)
                    )
            del fields[field]
    train_ds.fields = fields
    valid_ds.fields = fields
    load_vocabularies_to_fields(
        Path(directory, const.VOCAB_FILE), fieldset.fields
    )
    return train_ds, valid_ds


def cross_split_dataset(dataset, splits):
    examples_per_split = ceil(len(dataset) / splits)
    for split in range(splits):
        held_out_start = examples_per_split * split
        held_out_stop = examples_per_split * (split + 1)
        held_out_examples = dataset[held_out_start:held_out_stop]
        held_in_examples = dataset[:held_out_start] + dataset[held_out_stop:]
        train_split = dataset.__class__(held_in_examples, dataset.fields)
        eval_split = dataset.__class__(held_out_examples, dataset.fields)
        yield train_split, eval_split

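
# Illustrative sketch (not part of the original module): `cross_split_dataset`
# yields one (train, eval) pair per fold, so it can drive a simple k-fold loop.
# `dataset` and `train_model` are hypothetical placeholders.
#
#     for fold, (train_split, eval_split) in enumerate(
#         cross_split_dataset(dataset, splits=5)
#     ):
#         train_model(train_split, eval_split, fold=fold)
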

def save_file(file_path, data, token_sep=' ', example_sep='\n'):
    if data and isinstance(data[0], list):
        data = [token_sep.join(map(str, sentence)) for sentence in data]
    else:
        data = map(str, data)
    example_str = example_sep.join(data) + '\n'
    Path(file_path).write_text(example_str)


def save_predicted_probabilities(directory, predictions, prefix=''):
    directory = Path(directory)
    directory.mkdir(parents=True, exist_ok=True)
    for key, preds in predictions.items():
        if prefix:
            key = '{}.{}'.format(prefix, key)
        output_path = Path(directory, key)
        logger.info('Saving {} predictions to {}'.format(key, output_path))
        save_file(output_path, preds, token_sep=' ', example_sep='\n')


def read_file(path):
    """Read a file into a list of lists of words."""
    with Path(path).open('r', encoding='utf8') as f:
        return [line.strip().split() for line in f]


def hter_to_binary(x):
    """Transform an HTER score into a binary OK/BAD label."""
    return ceil(float(x))

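
# Example (sketch): `ceil` maps HTER 0.0 to OK (0) and any positive HTER to
# BAD (1), assuming scores are given as strings or floats in [0, 1].
#
#     hter_to_binary('0.0')   # -> 0
#     hter_to_binary('0.25')  # -> 1
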

def wmt18_to_target(batch, *args):
    """Extract target word tags from a WMT18-format tags file."""
    return batch[1::2]


def wmt18_to_gaps(batch, *args):
    """Extract gap tags from a WMT18-format tags file."""
    return batch[::2]

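
# Example (sketch): WMT18 tag files interleave gap and word tags, starting and
# ending with a gap tag, so even positions are gap tags and odd positions are
# target-word tags. Tag values below are illustrative.
#
#     tags = ['OK', 'BAD', 'OK', 'OK', 'OK']
#     wmt18_to_gaps(tags)    # -> ['OK', 'OK', 'OK']
#     wmt18_to_target(tags)  # -> ['BAD', 'OK']
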

def project(batch, *args):
    """Project onto the first argument.

    Needed to create a postprocessing pipeline that implements the identity.
    """
    return batch


def filter_len(
    x,
    source_min_length=1,
    source_max_length=float('inf'),
    target_min_length=1,
    target_max_length=float('inf'),
):
    return (source_min_length <= len(x.source) <= source_max_length) and (
        target_min_length <= len(x.target) <= target_max_length
    )
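
# Illustrative sketch (assumption, not from the original module): `filter_len`
# checks that an example's source and target token counts fall inside the given
# bounds, so it can serve as a dataset filter predicate. The `functools.partial`
# wiring below is hypothetical.
#
#     from functools import partial
#     filter_pred = partial(filter_len, source_max_length=50, target_max_length=50)
#     kept = [example for example in examples if filter_pred(example)]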