Source code for kiwi.predictors.predictor

#  OpenKiwi: Open-Source Machine Translation Quality Estimation
#  Copyright (C) 2019 Unbabel <>
#  This program is free software: you can redistribute it and/or modify
#  it under the terms of the GNU Affero General Public License as published
#  by the Free Software Foundation, either version 3 of the License, or
#  (at your option) any later version.
#  This program is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU Affero General Public License for more details.
#  You should have received a copy of the GNU Affero General Public License
#  along with this program.  If not, see <https://www.gnu.org/licenses/>.

import logging
from collections import defaultdict

import torch
from import Example

from kiwi import constants as const
from import build_bucket_iterator
from import QEDataset

logger = logging.getLogger(__name__)

class Predicter:
    """Wrap a trained QE model for inference.

    The wrapper remembers the device the model lives on so that batches built
    by :meth:`run` are placed on the same device as the model parameters.
    """

    def __init__(self, model, fields=None):
        """Store the model and the fields used to build online examples.

        Args:
            model (kiwi.models.Model): A trained QE model.
            fields (dict): A dict mapping field names to the field objects
                used to turn raw examples into a dataset. Only required for
                online prediction through :meth:`predict`.
        """
        self.model = model
        self.fields = fields
        # Will break in Multi GPU mode: only the device of the first
        # parameter is recorded.
        self._device = next(model.parameters()).device

    def to(self, device):
        """Move the Predicter (and its model) to another device, e.g. "cuda".

        Args:
            device (str): Device to which the model should be moved.
        """
        self._device = device
        # Keep the model on the same device the batches will be created on;
        # updating only ``_device`` would leave inputs and weights on
        # different devices.
        self.model.to(device)

    def predict(self, examples, batch_size=1):
        """Create predictions for a list of examples.

        Args:
            examples: A dict mapping field names to the list of raw examples
                (strings).
            batch_size: Batch size to use. Default 1.

        Returns:
            A dict mapping prediction levels (word, sentence ..) to the model
            predictions for each example. An empty ``defaultdict(list)`` is
            returned when ``examples`` is empty.

        Raises:
            KeyError: If the `source` or `target` field is missing or empty.
            Exception: If no fields were given to the constructor, or if an
                example has an empty string as `source` or `target` field.

        Example:
            >>> import kiwi
            >>> predictor = kiwi.load_model('tests/toy-data/models/nuqe.torch')
            >>> src = ['a b c', 'd e f g']
            >>> tgt = ['q w e r', 't y']
            >>> align = ['0-0 1-1 1-2', '1-1 3-0']
            >>> examples = {kiwi.constants.SOURCE: src,
            ...             kiwi.constants.TARGET: tgt,
            ...             kiwi.constants.ALIGNMENTS: align}
            >>> predictor.predict(examples)
            {'tags': [[0.4760947525501251,
            0.47569847106933594,
            0.4948718547821045,
            0.5305878520011902],
            [0.5105430483818054, 0.5252899527549744]]}
        """
        if not examples:
            return defaultdict(list)
        if self.fields is None:
            raise Exception('Missing fields object.')

        if not examples.get(const.SOURCE):
            raise KeyError('Missing required field "{}"'.format(const.SOURCE))
        if not examples.get(const.TARGET):
            raise KeyError('Missing required field "{}"'.format(const.TARGET))

        # Whitespace-only sentences count as empty.
        if not all(
            s.strip() for s in examples[const.SOURCE] + examples[const.TARGET]
        ):
            raise Exception(
                'Empty String in {} or {} field found!'.format(
                    const.SOURCE, const.TARGET
                )
            )

        # Pair every provided field name with its field object, in the same
        # order as ``examples.values()`` so the zipped rows line up.
        fields = [(name, self.fields[name]) for name in examples]
        field_examples = [
            Example.fromlist(values, fields)
            for values in zip(*examples.values())
        ]
        dataset = QEDataset(field_examples, fields=fields)

        return self.run(dataset, batch_size)

    def run(self, dataset, batch_size=1):
        """Run the model over a dataset and collect its predictions.

        Args:
            dataset: The dataset to iterate over; it is passed to
                ``build_bucket_iterator`` together with the current device.
            batch_size: Batch size to use. Default 1.

        Returns:
            A plain dict mapping each key returned by ``model.predict`` to the
            list of predictions accumulated over all batches.
        """
        iterator = build_bucket_iterator(
            dataset, self._device, batch_size, is_train=False
        )
        self.model.eval()
        predictions = defaultdict(list)
        with torch.no_grad():
            for batch in iterator:
                model_pred = self.model.predict(batch)
                for key, values in model_pred.items():
                    # List outputs are concatenated across batches; any other
                    # output is stored as one entry per batch.
                    if isinstance(values, list):
                        predictions[key] += values
                    else:
                        predictions[key].append(values)
        return dict(predictions)