# OpenKiwi: Open-Source Machine Translation Quality Estimation
# Copyright (C) 2019 Unbabel <openkiwi@unbabel.com>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
#

from collections import OrderedDict

import torch
from torch import nn

from kiwi import constants as const
from kiwi.metrics import CorrectMetric, ExpectedErrorMetric, PerplexityMetric
from kiwi.models.model import Model, ModelConfig
from kiwi.models.modules.attention import Attention
from kiwi.models.modules.scorer import MLPScorer
from kiwi.models.utils import apply_packed_sequence, replace_token


class PredictorConfig(ModelConfig):
def __init__(
self,
vocabs,
hidden_pred=400,
rnn_layers_pred=3,
dropout_pred=0.0,
share_embeddings=False,
embedding_sizes=0,
target_embeddings_size=200,
source_embeddings_size=200,
out_embeddings_size=200,
predict_inverse=False,
):
"""Predictor Hyperparams.
"""
super().__init__(vocabs)
# Vocabulary
self.target_side = const.TARGET
self.source_side = const.SOURCE
self.predict_inverse = predict_inverse
if self.predict_inverse:
self.source_side, self.target_side = (
self.target_side,
self.source_side,
)
self.target_vocab_size, self.source_vocab_size = (
self.source_vocab_size,
self.target_vocab_size,
)
# Architecture
self.hidden_pred = hidden_pred
self.rnn_layers_pred = rnn_layers_pred
self.dropout_pred = dropout_pred
self.share_embeddings = share_embeddings
if embedding_sizes:
self.target_embeddings_size = embedding_sizes
self.source_embeddings_size = embedding_sizes
self.out_embeddings_size = embedding_sizes
else:
self.target_embeddings_size = target_embeddings_size
self.source_embeddings_size = source_embeddings_size
self.out_embeddings_size = out_embeddings_size


@Model.register_subclass
class Predictor(Model):
"""Bidirectional Conditional Language Model
Implemented after Kim et al 2017, see:
http://www.statmt.org/wmt17/pdf/WMT63.pdf
"""
title = 'PredEst Predictor model (an embedder model)'
def __init__(self, vocabs, **kwargs):
"""
Args:
vocabs: Dictionary Mapping Field Names to Vocabularies.
kwargs:
config: A state dict of a PredictorConfig object.
dropout: LSTM dropout Default 0.0
hidden_pred: LSTM Hidden Size, default 200
rnn_layers: Default 3
embedding_sizes: If set, takes precedence over other embedding params
Default 100
source_embeddings_size: Default 100
target_embeddings_size: Default 100
out_embeddings_size: Output softmax embedding. Default 100
share_embeddings: Tie input and output embeddings for target.
Default False
predict_inverse: Predict from target to source. Default False
"""
super().__init__(vocabs=vocabs, ConfigCls=PredictorConfig, **kwargs)
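        # The attention scorer compares 2 * hidden_pred sized vectors: the
        # bidirectional source contexts and the concatenated forward/backward
        # target contexts both have that size.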
scorer = MLPScorer(
self.config.hidden_pred * 2, self.config.hidden_pred * 2, layers=2
)
self.attention = Attention(scorer)
self.embedding_source = nn.Embedding(
self.config.source_vocab_size,
self.config.source_embeddings_size,
const.PAD_ID,
)
self.embedding_target = nn.Embedding(
self.config.target_vocab_size,
self.config.target_embeddings_size,
const.PAD_ID,
)
self.lstm_source = nn.LSTM(
input_size=self.config.source_embeddings_size,
hidden_size=self.config.hidden_pred,
num_layers=self.config.rnn_layers_pred,
batch_first=True,
dropout=self.config.dropout_pred,
bidirectional=True,
)
self.forward_target = nn.LSTM(
input_size=self.config.target_embeddings_size,
hidden_size=self.config.hidden_pred,
num_layers=self.config.rnn_layers_pred,
batch_first=True,
dropout=self.config.dropout_pred,
bidirectional=False,
)
self.backward_target = nn.LSTM(
input_size=self.config.target_embeddings_size,
hidden_size=self.config.hidden_pred,
num_layers=self.config.rnn_layers_pred,
batch_first=True,
dropout=self.config.dropout_pred,
bidirectional=False,
)
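        # Output-layer parameters: W1 holds the output embeddings (tied to the
        # target input embeddings when share_embeddings is set), W2 projects
        # the maxout output, and V, C, S project the neighboring target
        # embeddings, the attention context and the target contexts (see
        # forward()).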
self.W1 = self.embedding_target
if not self.config.share_embeddings:
self.W1 = nn.Embedding(
self.config.target_vocab_size,
self.config.out_embeddings_size,
const.PAD_ID,
)
self.W2 = nn.Parameter(
torch.zeros(
self.config.out_embeddings_size, self.config.out_embeddings_size
)
)
self.V = nn.Parameter(
torch.zeros(
2 * self.config.target_embeddings_size,
2 * self.config.out_embeddings_size,
)
)
self.C = nn.Parameter(
torch.zeros(
2 * self.config.hidden_pred, 2 * self.config.out_embeddings_size
)
)
self.S = nn.Parameter(
torch.zeros(
2 * self.config.hidden_pred, 2 * self.config.out_embeddings_size
)
)
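        # Xavier-initialize every matrix-shaped parameter (embeddings and
        # projection matrices); 1-D parameters keep their default init.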
for p in self.parameters():
if len(p.shape) > 1:
nn.init.xavier_uniform_(p)
self._loss = nn.CrossEntropyLoss(
reduction='sum', ignore_index=const.PAD_ID
)

    @staticmethod
def fieldset(*args, **kwargs):
from kiwi.data.fieldsets.predictor import build_fieldset
return build_fieldset()

    @staticmethod
def from_options(vocabs, opts):
"""
Args:
vocabs:
opts:
Returns:
"""
model = Predictor(
vocabs,
hidden_pred=opts.hidden_pred,
rnn_layers_pred=opts.rnn_layers_pred,
dropout_pred=opts.dropout_pred,
share_embeddings=opts.share_embeddings,
embedding_sizes=opts.embedding_sizes,
target_embeddings_size=opts.target_embeddings_size,
source_embeddings_size=opts.source_embeddings_size,
out_embeddings_size=opts.out_embeddings_size,
predict_inverse=opts.predict_inverse,
)
return model

    def loss(self, model_out, batch, target_side=None):
if not target_side:
target_side = self.config.target_side
target = getattr(batch, target_side)
# There are no predictions for first/last element
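        # Shorter sequences still contain the stop token inside the sliced
        # span, so it is mapped to PAD, which the loss ignores.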
target = replace_token(target[:, 1:-1], const.STOP_ID, const.PAD_ID)
# Predicted Class must be in dim 1 for xentropyloss
logits = model_out[target_side]
logits = logits.transpose(1, 2)
loss = self._loss(logits, target)
loss_dict = OrderedDict()
loss_dict[target_side] = loss
loss_dict[const.LOSS] = loss
return loss_dict

    def forward(self, batch, source_side=None, target_side=None):
if not source_side:
source_side = self.config.source_side
if not target_side:
target_side = self.config.target_side
source = getattr(batch, source_side)
target = getattr(batch, target_side)
batch_size, target_len = target.shape[:2]
# Remove First and Last Element (Start / Stop Tokens)
source_mask = self.get_mask(batch, source_side)[:, 1:-1]
source_lengths = source_mask.sum(1)
target_lengths = self.get_mask(batch, target_side).sum(1)
source_embeddings = self.embedding_source(source)
target_embeddings = self.embedding_target(target)
# Source Encoding
source_contexts, hidden = apply_packed_sequence(
self.lstm_source, source_embeddings, source_lengths
)
# Target Encoding.
h_forward, h_backward = self._split_hidden(hidden)
forward_contexts, _ = self.forward_target(target_embeddings, h_forward)
target_emb_rev = self._reverse_padded_seq(
target_lengths, target_embeddings
)
backward_contexts, _ = self.backward_target(target_emb_rev, h_backward)
backward_contexts = self._reverse_padded_seq(
target_lengths, backward_contexts
)
# For each position, concatenate left context i-1 and right context i+1
target_contexts = torch.cat(
[forward_contexts[:, :-2], backward_contexts[:, 2:]], dim=-1
)
        # For each position i, concatenate embeddings i-1 and i+1
target_embeddings = torch.cat(
[target_embeddings[:, :-2], target_embeddings[:, 2:]], dim=-1
)
# Get Attention vectors for all positions and stack.
self.attention.set_mask(source_mask.float())
attns = [
self.attention(
target_contexts[:, i], source_contexts, source_contexts
)
for i in range(target_len - 2)
]
attns = torch.stack(attns, dim=1)
# Combine attention, embeddings and target context vectors
C = torch.einsum('bsi,il->bsl', [attns, self.C])
V = torch.einsum('bsj,jl->bsl', [target_embeddings, self.V])
S = torch.einsum('bsk,kl->bsl', [target_contexts, self.S])
t_tilde = C + V + S
# Maxout with pooling size 2
t, _ = torch.max(
t_tilde.view(
t_tilde.shape[0], t_tilde.shape[1], t_tilde.shape[-1] // 2, 2
),
dim=-1,
)
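        # Project the maxout output with W2 and score it against the output
        # embeddings W1 to obtain vocabulary logits.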
f = torch.einsum('oh,bso->bsh', [self.W2, t])
logits = torch.einsum('vh,bsh->bsv', [self.W1.weight, f])
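        # Pre-QE feature vector: output embedding of each observed target
        # token, gated element-wise by the pre-softmax representation f.
        # Post-QE feature vector: concatenated forward/backward target
        # contexts. Both are intended for the downstream estimator model.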
PreQEFV = torch.einsum('bsh,bsh->bsh', [self.W1(target[:, 1:-1]), f])
PostQEFV = torch.cat([forward_contexts, backward_contexts], dim=-1)
return {
target_side: logits,
const.PREQEFV: PreQEFV,
const.POSTQEFV: PostQEFV,
}

    @staticmethod
    def _reverse_padded_seq(lengths, sequence):
        """Reverse a batch of padded sequences of different lengths."""
batch_size, max_length = sequence.shape[:-1]
reversed_idx = []
for i in range(batch_size * max_length):
batch_id = i // max_length
sent_id = i % max_length
if sent_id < lengths[batch_id]:
sent_id_rev = lengths[batch_id] - sent_id - 1
else:
sent_id_rev = sent_id # Padding symbol, don't change order
reversed_idx.append(max_length * batch_id + sent_id_rev)
flat_sequence = sequence.contiguous().view(batch_size * max_length, -1)
reversed_seq = flat_sequence[reversed_idx, :].view(*sequence.shape)
return reversed_seq

    @staticmethod
    def _split_hidden(hidden):
        """Split hidden state into forward/backward parts."""
h, c = hidden
size = h.shape[0]
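        # For a bidirectional LSTM, states are stacked per layer as
        # (forward, backward), so even indices select the forward direction
        # and odd indices the backward direction.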
idx_forward = torch.arange(0, size, 2, dtype=torch.long)
idx_backward = torch.arange(1, size, 2, dtype=torch.long)
hidden_forward = (h[idx_forward], c[idx_forward])
hidden_backward = (h[idx_backward], c[idx_backward])
return hidden_forward, hidden_backward

    def metrics(self):
metrics = []
main_metric = PerplexityMetric(
prefix=self.config.target_side,
target_name=self.config.target_side,
PAD=const.PAD_ID,
STOP=const.STOP_ID,
)
metrics.append(main_metric)
metrics.append(
CorrectMetric(
prefix=self.config.target_side,
target_name=self.config.target_side,
PAD=const.PAD_ID,
STOP=const.STOP_ID,
)
)
metrics.append(
ExpectedErrorMetric(
prefix=self.config.target_side,
target_name=self.config.target_side,
PAD=const.PAD_ID,
STOP=const.STOP_ID,
)
)
return metrics

    def metrics_ordering(self):
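        # The main metric is perplexity, so lower values are better.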
return min