# OpenKiwi: Open-Source Machine Translation Quality Estimation
# Copyright (C) 2019 Unbabel <openkiwi@unbabel.com>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
#
import logging
from abc import ABCMeta, abstractmethod
import torch
import torch.nn as nn
import kiwi
from kiwi import constants as const
from kiwi.data import utils
from kiwi.models.utils import load_torch_file
logger = logging.getLogger(__name__)
[docs]class ModelConfig:
__metaclass__ = ABCMeta
def __init__(self, vocabs):
"""Model Configuration Base Class.
Args:
vocabs: Dictionary Mapping Field Names to Vocabularies.
Must contain 'source' and 'target' keys
"""
self.source_vocab_size = len(vocabs[const.SOURCE])
self.target_vocab_size = len(vocabs[const.TARGET])
[docs] @classmethod
def from_dict(cls, config_dict, vocabs):
"""Create config from a saved state_dict.
Args:
config_dict: A dictionary that is the return value of
a call to the `state_dict()` method of `cls`
vocab: See `ModelConfig.__init__`
"""
config = cls(vocabs)
config.update(config_dict)
return config
[docs] def update(self, other_config):
"""Updates the config object with the values of `other_config`
Args:
other_config: The `dict` or `ModelConfig` object to update with.
"""
config_dict = dict()
if isinstance(self, other_config.__class__):
config_dict = other_config.__dict__
elif isinstance(other_config, dict):
config_dict = other_config
self.__dict__.update(config_dict)
[docs] def state_dict(self):
"""Return the __dict__ for serialization.
"""
self.__dict__['__version__'] = kiwi.__version__
return self.__dict__
[docs]class Model(nn.Module):
__metaclass__ = ABCMeta
subclasses = {}
def __init__(self, vocabs, ConfigCls=ModelConfig, config=None, **kwargs):
"""Quality Estimation Base Class.
Args:
vocabs: Dictionary Mapping Field Names to Vocabularies.
ConfigCls: ModelConfig Subclass
config: A State Dict of a ModelConfig subclass.
If set, passing other kwargs will raise an Exception.
"""
super().__init__()
self.vocabs = vocabs
if config is None:
config = ConfigCls(vocabs=vocabs, **kwargs)
else:
config = ConfigCls.from_dict(config_dict=config, vocabs=vocabs)
assert not kwargs
self.config = config
[docs] @classmethod
def register_subclass(cls, subclass):
cls.subclasses[subclass.__name__] = subclass
return subclass
[docs] @abstractmethod
def loss(self, model_out, target):
pass
[docs] @abstractmethod
def forward(self, *args, **kwargs):
pass
[docs] def num_parameters(self):
return sum(p.numel() for p in self.parameters())
[docs] def predict(self, batch, class_name=const.BAD, unmask=True):
model_out = self(batch)
predictions = {}
class_index = torch.tensor([const.LABELS.index(class_name)])
for key in model_out:
if key in [const.TARGET_TAGS, const.SOURCE_TAGS, const.GAP_TAGS]:
# Models are assumed to return logits, not probabilities
logits = model_out[key]
probs = torch.softmax(logits, dim=-1)
class_probs = probs.index_select(
-1, class_index.to(device=probs.device)
)
class_probs = class_probs.squeeze(-1).tolist()
if unmask:
if key == const.SOURCE_TAGS:
input_key = const.SOURCE
else:
input_key = const.TARGET
mask = self.get_mask(batch, input_key)
if key == const.GAP_TAGS:
# Append one extra token
mask = torch.cat(
[mask.new_ones((mask.shape[0], 1)), mask], dim=1
)
lengths = mask.int().sum(dim=-1)
for i, x in enumerate(class_probs):
class_probs[i] = x[: lengths[i]]
predictions[key] = class_probs
elif key == const.SENTENCE_SCORES:
predictions[key] = model_out[key].tolist()
elif key == const.BINARY:
logits = model_out[key]
probs = torch.softmax(logits, dim=-1)
class_probs = probs.index_select(
-1, class_index.to(device=probs.device)
)
predictions[key] = class_probs.tolist()
return predictions
[docs] def predict_raw(self, examples):
batch = self.preprocess(examples)
return self.predict(batch, class_name=const.BAD_ID, unmask=True)
[docs] def preprocess(self, examples):
"""Preprocess Raw Data.
Args:
examples (list of dict): List of examples. Each Example is a dict
with field strings as keys, and
unnumericalized, tokenized data as values.
Return:
A batch object.
"""
raise NotImplementedError
[docs] def get_mask(self, batch, output):
"""Compute Mask of Tokens for side.
Args:
batch: Namespace of tensors
side: String identifier.
"""
side = output
# if output in [const.TARGET_TAGS, const.GAP_TAGS]:
# side = const.TARGET
# elif output == const.SOURCE_TAGS:
# side = const.SOURCE
input_tensor = getattr(batch, side)
if isinstance(input_tensor, tuple) and len(input_tensor) == 2:
input_tensor, lengths = input_tensor
# output_tensor = getattr(batch, output)
# if isinstance(output_tensor, tuple) and len(output_tensor) == 2:
# output_tensor, lengths = output_tensor
mask = torch.ones_like(input_tensor, dtype=torch.uint8)
possible_padding = [const.PAD, const.START, const.STOP]
unk_id = self.vocabs[side].stoi.get(const.UNK)
for pad in possible_padding:
pad_id = self.vocabs[side].stoi.get(pad)
if pad_id is not None and pad_id != unk_id:
mask &= torch.as_tensor(
input_tensor != pad_id,
device=mask.device,
dtype=torch.uint8,
)
return mask
[docs] @staticmethod
def create_from_file(path):
try:
model_dict = load_torch_file(path)
except FileNotFoundError:
# If no model is found
raise FileNotFoundError(
'No valid model data found in {}'.format(path)
)
for model_name in Model.subclasses:
if model_name in model_dict:
model = Model.subclasses[model_name].from_dict(model_dict)
return model
[docs] @classmethod
def from_file(cls, path):
model_dict = torch.load(
str(path), map_location=lambda storage, loc: storage
)
if cls.__name__ not in model_dict:
raise KeyError(
'{} model data not found in {}'.format(cls.__name__, path)
)
return cls.from_dict(model_dict)
[docs] @classmethod
def from_dict(cls, model_dict):
vocabs = utils.deserialize_vocabs(model_dict[const.VOCAB])
class_dict = model_dict[cls.__name__]
model = cls(vocabs=vocabs, config=class_dict[const.CONFIG])
model.load_state_dict(class_dict[const.STATE_DICT])
return model
[docs] def save(self, path):
vocabs = utils.serialize_vocabs(self.vocabs)
model_dict = {
'__version__': kiwi.__version__,
const.VOCAB: vocabs,
self.__class__.__name__: {
const.CONFIG: self.config.state_dict(),
const.STATE_DICT: self.state_dict(),
},
}
torch.save(model_dict, str(path))