Source code for kiwi.data.fieldsets.fieldset

#  OpenKiwi: Open-Source Machine Translation Quality Estimation
#  Copyright (C) 2019 Unbabel <openkiwi@unbabel.com>
#
#  This program is free software: you can redistribute it and/or modify
#  it under the terms of the GNU Affero General Public License as published
#  by the Free Software Foundation, either version 3 of the License, or
#  (at your option) any later version.
#
#  This program is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU Affero General Public License for more details.
#
#  You should have received a copy of the GNU Affero General Public License
#  along with this program.  If not, see <https://www.gnu.org/licenses/>.
#

from functools import partial

from kiwi.data.vectors import AvailableVectors


class Fieldset:
    """Registry of named fields and the options needed to load them.

    For every field name the fieldset records the field object, the suffix
    of the option holding its file name, the dataset splits for which that
    file is mandatory, an optional file reader, and the option names used
    to build vocabulary arguments and to locate pretrained embeddings.
    """

    # Names of the dataset splits; ``ALL`` marks a field as required for
    # every split.
    ALL = 'all'
    TRAIN = 'train'
    VALID = 'valid'
    TEST = 'test'

    def __init__(self):
        """Create an empty fieldset; entries are registered with ``add``."""
        self._fields = {}
        self._options = {}
        self._required = {}
        self._vocab_options = {}
        self._vocab_vectors = {}
        self._file_reader = {}

    def add(
        self,
        name,
        field,
        file_option_suffix,
        required=ALL,
        vocab_options=None,
        vocab_vectors=None,
        file_reader=None,
    ):
        """Register a field under ``name``.

        Args:
            name: key identifying the field in every lookup of this class.
            field: the field object; stored as-is and returned by
                ``fields`` and ``fields_and_files``.
            file_option_suffix: suffix appended to a set name (e.g.
                ``'train'`` + ``'_source'``) to form the option key that
                holds this field's file name.
            required (str or list or tuple or set or None): set name(s)
                for which a file must be provided; ``Fieldset.ALL`` means
                every set. ``None`` matches no regular set name, making
                the file optional.
            vocab_options (dict or None): maps a vocabulary argument name
                to the option name whose value should be passed for it
                (see ``vocab_kwargs``). Defaults to an empty mapping.
            vocab_vectors: name of the option holding the embeddings
                location for this field, or ``None`` for no embeddings.
            file_reader (callable): stored and handed back under the
                ``'reader'`` key by ``fields_and_files``. The original
                documentation states that ``Corpus.from_files()`` is used
                by default; that fallback is not implemented in this class.
        """
        self._fields[name] = field
        self._options[name] = file_option_suffix

        # Normalise to a list. Tuples and sets are accepted as collections
        # of set names (they used to be wrapped whole and never matched);
        # copying keeps later mutations of the caller's object out.
        if isinstance(required, (list, tuple, set, frozenset)):
            required = list(required)
        else:
            required = [required]
        self._required[name] = required

        self._file_reader[name] = file_reader
        self._vocab_options[name] = (
            {} if vocab_options is None else vocab_options
        )
        self._vocab_vectors[name] = vocab_vectors

    @property
    def fields(self):
        """Mapping from field name to the registered field object."""
        return self._fields

    def is_required(self, name, set_name):
        """Return whether field ``name`` needs a file for ``set_name``.

        Raises:
            KeyError: if ``name`` was never registered.
        """
        required = self._required[name]
        return set_name in required or self.ALL in required

    def fields_and_files(self, set_name, **files_options):
        """Select the fields that have a file configured for ``set_name``.

        Args:
            set_name: prefix of the option keys to look up, e.g.
                ``'train'``.
            **files_options: option values keyed by
                ``set_name + file_option_suffix``.

        Returns:
            A ``(fields, files)`` tuple of dicts keyed by field name.
            ``files[name]`` is ``{'name': file_name, 'reader': reader}``.
            Fields with no file configured (and not required) are omitted
            from both dicts.

        Raises:
            FileNotFoundError: if a field required for ``set_name`` has no
                file configured.
        """
        fields = {}
        files = {}
        for name, file_option_suffix in self._options.items():
            file_option = '{}{}'.format(set_name, file_option_suffix)
            file_name = files_options.get(file_option)
            if file_name:
                files[name] = {
                    'name': file_name,
                    'reader': self._file_reader.get(name),
                }
                fields[name] = self._fields[name]
            elif self.is_required(name, set_name):
                # ``file_name`` is empty here, so report the field name
                # (the message used to read "File None is required").
                raise FileNotFoundError(
                    'File for field "{}" is required (use the {} '
                    'option).'.format(name, file_option.replace('_', '-'))
                )
        return fields, files

    def vocab_kwargs(self, name, **kwargs):
        """Collect the vocabulary arguments configured for field ``name``.

        Each ``argument -> option_name`` pair registered through
        ``vocab_options`` is resolved against ``kwargs``; options that are
        missing or ``None`` are skipped.

        Raises:
            KeyError: if ``name`` was never registered.
        """
        if name not in self._vocab_options:
            raise KeyError(
                'Field named "{}" does not exist in this fieldset'.format(name)
            )
        vkwargs = {}
        for argument, option_name in self._vocab_options[name].items():
            option_value = kwargs.get(option_name)
            if option_value is not None:
                vkwargs[argument] = option_value
        return vkwargs

    def vocab_vectors_loader(
        self,
        name,
        embeddings_format='polyglot',
        embeddings_binary=False,
        **kwargs
    ):
        """Return a zero-argument callable that loads vectors for ``name``.

        Building the callable does not load anything. When the field has
        no embeddings option registered, or that option is absent or empty
        in ``kwargs``, the returned callable simply returns ``None``.
        Otherwise it is a ``functools.partial`` of
        ``AvailableVectors[embeddings_format]`` applied to the option
        value with ``binary=embeddings_binary``.

        Raises:
            KeyError: if ``name`` was never registered.
        """
        if name not in self._vocab_vectors:
            raise KeyError(
                'Field named "{}" does not exist in this fieldset'.format(name)
            )

        def no_vectors_fn():
            return None

        option_name = self._vocab_vectors[name]
        if not option_name:
            return no_vectors_fn

        option_value = kwargs.get(option_name)
        if not option_value:
            return no_vectors_fn

        emb_model = AvailableVectors[embeddings_format]
        return partial(emb_model, option_value, binary=embeddings_binary)

    def vocab_vectors(self, name, **kwargs):
        """Build the loader for ``name`` and call it immediately.

        Returns whatever the loader returns; ``None`` when no embeddings
        are configured for the field.
        """
        vectors_fn = self.vocab_vectors_loader(name, **kwargs)
        return vectors_fn()

    def fields_vocab_options(self, **kwargs):
        """Return the vocabulary options of every registered field.

        The result maps each field name to a dict holding ``vectors_fn``
        (see ``vocab_vectors_loader``) plus the entries produced by
        ``vocab_kwargs`` for that field.
        """
        vocab_options = {}
        for name in self.fields:
            options = {'vectors_fn': self.vocab_vectors_loader(name, **kwargs)}
            # Applied second so configured vocabulary arguments take
            # precedence, as before.
            options.update(self.vocab_kwargs(name, **kwargs))
            vocab_options[name] = options
        return vocab_options