# OpenKiwi: Open-Source Machine Translation Quality Estimation
# Copyright (C) 2019 Unbabel <openkiwi@unbabel.com>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
#
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
from kiwi import constants as const
from kiwi.data.fieldsets.quetch import build_fieldset
from kiwi.metrics import CorrectMetric, F1Metric, LogMetric
from kiwi.models.model import Model, ModelConfig
from kiwi.models.utils import align_tensor, convolve_tensor, make_loss_weights
class QUETCHConfig(ModelConfig):
    """Hyper-parameter container for the QUETCH model.

    Stores embedding sizes (overridden by pre-trained vectors when the
    vocabularies carry them), window/alignment settings, dropout rates,
    the BAD-class loss weight, and which tag sides (target / gaps /
    source) the model predicts.
    """

    def __init__(
        self,
        vocabs,
        predict_target=True,
        predict_gaps=False,
        predict_source=False,
        source_embeddings_size=50,
        target_embeddings_size=50,
        hidden_sizes=None,
        bad_weight=3.0,
        window_size=10,
        max_aligned=5,
        dropout=0.4,
        embeddings_dropout=0.4,
        freeze_embeddings=False,
    ):
        """Create a QUETCH configuration.

        Args:
            vocabs: mapping from field name to vocabulary; the source and
                target vocabs may carry pre-trained ``vectors``.
            predict_target: predict word-level tags on the target side.
            predict_gaps: predict gap tags on the target side.
            predict_source: predict word-level tags on the source side.
            source_embeddings_size: source embedding dim; ignored when
                pre-trained source vectors are present.
            target_embeddings_size: target embedding dim; ignored when
                pre-trained target vectors are present.
            hidden_sizes: hidden layer sizes; defaults to ``[100]``.
            bad_weight: loss weight for the BAD class.
            window_size: number of tokens in each context window.
            max_aligned: maximum number of aligned tokens to consider.
            dropout: dropout applied after the hidden layer.
            embeddings_dropout: dropout applied to the concatenated
                embedding features.
            freeze_embeddings: if True, embedding weights are not updated.
        """
        super().__init__(vocabs)
        # Avoid a mutable default argument for the list.
        if hidden_sizes is None:
            hidden_sizes = [100]

        # Pre-trained vectors, when present, dictate the embedding sizes.
        source_vectors = vocabs[const.SOURCE].vectors
        target_vectors = vocabs[const.TARGET].vectors
        if source_vectors is not None:
            source_embeddings_size = source_vectors.size(1)
        if target_vectors is not None:
            target_embeddings_size = target_vectors.size(1)
        self.source_embeddings_size = source_embeddings_size
        self.target_embeddings_size = target_embeddings_size

        self.bad_weight = bad_weight
        self.dropout = dropout
        self.embeddings_dropout = embeddings_dropout
        self.freeze_embeddings = freeze_embeddings

        # If predicting gap or source tags, the default predict_target=True
        # doesn't make sense (the loss only trains one side, with priority
        # source > gaps > target).  The original code had a no-op here
        # (`predict_target = predict_target`); actually disable it.
        if predict_gaps or predict_source:
            predict_target = False
        self.predict_target = predict_target
        self.predict_gaps = predict_gaps
        self.predict_source = predict_source

        self.window_size = window_size
        self.max_aligned = max_aligned
        self.hidden_sizes = hidden_sizes

        # Padding id of whichever tag field is present in the vocabs,
        # preferring source tags, then gap tags, then target tags.
        if const.SOURCE_TAGS in vocabs:
            self.tags_pad_id = vocabs[const.SOURCE_TAGS].stoi[const.PAD]
        elif const.GAP_TAGS in vocabs:
            self.tags_pad_id = vocabs[const.GAP_TAGS].stoi[const.PAD]
        else:
            self.tags_pad_id = vocabs[const.TARGET_TAGS].stoi[const.PAD]

        # FIXME: this might not correspond to reality (in vocabs)!
        self.nb_classes = len(const.LABELS)
        self.tag_bad_index = const.BAD_ID
        self.pad_token = const.PAD
        self.unaligned_idx = const.UNALIGNED_ID
        self.source_padding_idx = const.PAD_ID
        self.target_padding_idx = const.PAD_ID
@Model.register_subclass
class QUETCH(Model):
    """QUality Estimation from scraTCH (QUETCH) model.

    A window-based feed-forward word-level QE model: each token position
    is represented by the concatenation of windowed source and target
    embeddings, passed through one tanh hidden layer and a linear output
    layer over the tag classes.

    TODO: add references.
    """

    title = "QUETCH"

    def __init__(self, vocabs, **kwargs):
        """Build a QUETCH model from vocabularies and config kwargs."""
        super().__init__(vocabs=vocabs, ConfigCls=QUETCHConfig, **kwargs)

        # Layers are created in build(); declared here for clarity.
        self.source_emb = None
        self.target_emb = None
        self.embeddings_dropout = None
        self.linear = None
        self.dropout = None
        self.linear_out = None

        source_vectors = vocabs[const.SOURCE].vectors
        target_vectors = vocabs[const.TARGET].vectors
        self.build(source_vectors, target_vectors)

    @staticmethod
    def fieldset(*args, **kwargs):
        """Return the data fieldset used to load data for this model."""
        return build_fieldset(*args, **kwargs)

    @staticmethod
    def from_options(vocabs, opts):
        """Instantiate a QUETCH model from parsed command-line options.

        Args:
            vocabs: mapping from field name to vocabulary.
            opts: namespace of parsed options mirroring QUETCHConfig's
                keyword arguments.
        Returns:
            A built QUETCH instance.
        """
        model = QUETCH(
            vocabs=vocabs,
            predict_target=opts.predict_target,
            predict_gaps=opts.predict_gaps,
            predict_source=opts.predict_source,
            source_embeddings_size=opts.source_embeddings_size,
            target_embeddings_size=opts.target_embeddings_size,
            hidden_sizes=opts.hidden_sizes,
            bad_weight=opts.bad_weight,
            window_size=opts.window_size,
            max_aligned=opts.max_aligned,
            dropout=opts.dropout,
            embeddings_dropout=opts.embeddings_dropout,
            freeze_embeddings=opts.freeze_embeddings,
        )
        return model

    def loss(self, model_out, target):
        """Compute weighted cross-entropy loss for the trained tag side.

        Only one output side is trained, chosen with priority
        source > gaps > target.

        Args:
            model_out: dict of output name -> (bs, ts, nb_classes) scores.
            target: batch object carrying the gold tags as attributes.
        Returns:
            dict with the loss under ``const.LOSS``.
        """
        if self.config.predict_source:
            output_name = const.SOURCE_TAGS
        elif self.config.predict_gaps:
            output_name = const.GAP_TAGS
        else:
            output_name = const.TARGET_TAGS
        # (bs*ts, nb_classes)
        probs = model_out[output_name]
        # (bs*ts, )
        y = getattr(target, output_name)
        predicted = probs.view(-1, self.config.nb_classes)
        y = y.view(-1)
        loss = self._loss(predicted, y)
        return {const.LOSS: loss}

    def _build_embeddings(self, source_vectors=None, target_vectors=None):
        """Create the source/target embedding layers.

        Pre-trained vectors, when given, initialize the layers and fix
        their sizes; otherwise sizes come from the config.
        """
        if source_vectors is not None:
            self.source_emb = nn.Embedding(
                num_embeddings=source_vectors.size(0),
                embedding_dim=source_vectors.size(1),
                padding_idx=self.config.source_padding_idx,
                _weight=source_vectors,
            )
        else:
            self.source_emb = nn.Embedding(
                num_embeddings=self.config.source_vocab_size,
                embedding_dim=self.config.source_embeddings_size,
                padding_idx=self.config.source_padding_idx,
            )
        if target_vectors is not None:
            self.target_emb = nn.Embedding(
                num_embeddings=target_vectors.size(0),
                embedding_dim=target_vectors.size(1),
                padding_idx=self.config.target_padding_idx,
                _weight=target_vectors,
            )
        else:
            self.target_emb = nn.Embedding(
                num_embeddings=self.config.target_vocab_size,
                embedding_dim=self.config.target_embeddings_size,
                padding_idx=self.config.target_padding_idx,
            )
        if self.config.freeze_embeddings:
            # nn.Embedding has no bias parameter; the original also froze a
            # non-existent `.bias`, which raised AttributeError whenever
            # freeze_embeddings was enabled.
            self.source_emb.weight.requires_grad = False
            self.target_emb.weight.requires_grad = False
        self.embeddings_dropout = nn.Dropout(self.config.embeddings_dropout)

    def build(self, source_vectors=None, target_vectors=None):
        """Create all layers, the loss, and initialize parameters.

        Args:
            source_vectors: optional pre-trained source embeddings.
            target_vectors: optional pre-trained target embeddings.
        """
        hidden_size = self.config.hidden_sizes[0]
        nb_classes = self.config.nb_classes
        dropout = self.config.dropout

        # Up-weight the BAD class; ignore padded tag positions.
        weight = make_loss_weights(
            nb_classes, const.BAD_ID, self.config.bad_weight
        )
        self._loss = nn.CrossEntropyLoss(
            weight=weight, ignore_index=const.PAD_TAGS_ID
        )

        # Embeddings layers:
        self._build_embeddings(source_vectors, target_vectors)

        # One window of source plus one window of target embeddings,
        # concatenated, feeds the hidden layer.
        feature_set_size = (
            self.config.source_embeddings_size
            + self.config.target_embeddings_size
        ) * self.config.window_size

        self.linear = nn.Linear(feature_set_size, hidden_size)
        self.linear_out = nn.Linear(hidden_size, nb_classes)
        self.dropout = nn.Dropout(dropout)

        torch.nn.init.xavier_uniform_(self.linear.weight)
        torch.nn.init.xavier_uniform_(self.linear_out.weight)
        torch.nn.init.constant_(self.linear.bias, 0.0)
        torch.nn.init.constant_(self.linear_out.bias, 0.0)
        self.is_built = True

    def forward(self, batch):
        """Score each token position over the tag classes.

        Args:
            batch: input batch consumed by ``make_input`` (defined
                elsewhere in the full source — not visible in this chunk),
                which presumably builds windowed/aligned index tensors.
        Returns:
            OrderedDict mapping each predicted tag side to the same
            (bs, ts, nb_classes) score tensor.
        """
        assert self.is_built
        if self.config.predict_source:
            align_side = const.SOURCE_TAGS
        else:
            align_side = const.TARGET_TAGS
        target_input, source_input, nb_alignments = self.make_input(
            batch, align_side
        )
        #
        # Source Branch
        #
        # (bs, ts, aligned, window) -> (bs, ts, aligned, window, emb)
        h_source = self.source_emb(source_input)
        if len(h_source.shape) == 5:
            # Average over the aligned dimension:
            # (bs, ts, aligned, window, emb) -> (bs, ts, window, emb)
            h_source = h_source.sum(2, keepdim=False) / nb_alignments.unsqueeze(
                -1
            ).unsqueeze(-1)
        # (bs, ts, window, emb) -> (bs, ts, window * emb)
        h_source = h_source.view(source_input.size(0), source_input.size(1), -1)
        #
        # Target Branch
        #
        # (bs, ts * window) -> (bs, ts * window, emb)
        h_target = self.target_emb(target_input)
        if len(h_target.shape) == 5:
            # Average over the aligned dimension:
            # (bs, ts, aligned, window, emb) -> (bs, ts, window, emb)
            h_target = h_target.sum(2, keepdim=False) / nb_alignments.unsqueeze(
                -1
            ).unsqueeze(-1)
        # (bs, ts * window, emb) -> (bs, ts, window * emb)
        h_target = h_target.view(target_input.size(0), target_input.size(1), -1)
        #
        # POS tags branches
        #
        feature_set = (h_source, h_target)
        #
        # Merge Branches
        #
        # (bs, ts, window * emb) -> (bs, ts, 2 * window * emb)
        h = torch.cat(feature_set, dim=-1)
        h = self.embeddings_dropout(h)
        # (bs, ts, 2 * window * emb) -> (bs, ts, hs)
        h = torch.tanh(self.linear(h))
        h = self.dropout(h)
        # (bs, ts, hs) -> (bs, ts, 2)
        h = self.linear_out(h)

        # All predicted sides share the same scores tensor.
        outputs = OrderedDict()
        if self.config.predict_target:
            outputs[const.TARGET_TAGS] = h
        if self.config.predict_gaps:
            outputs[const.GAP_TAGS] = h
        if self.config.predict_source:
            outputs[const.SOURCE_TAGS] = h
        return outputs

    @staticmethod
    def _unmask(tensor, mask):
        """Trim each row of `tensor` to its true (unpadded) length."""
        lengths = mask.int().sum(dim=-1)
        return [x[: lengths[i]] for i, x in enumerate(tensor)]

    def metrics(self):
        """Return evaluation metrics for every predicted tag side.

        For each side, an F1 metric and a per-tag accuracy metric are
        produced (in the order target, source, gaps), followed by a loss
        log metric.
        """
        metrics = []
        output_names = []
        if self.config.predict_target:
            output_names.append(const.TARGET_TAGS)
        if self.config.predict_source:
            output_names.append(const.SOURCE_TAGS)
        if self.config.predict_gaps:
            output_names.append(const.GAP_TAGS)
        # The triplicated F1/Correct construction in the original is
        # collapsed into one loop; ordering is preserved.
        for name in output_names:
            metrics.append(
                F1Metric(
                    prefix=name,
                    target_name=name,
                    PAD=const.PAD_TAGS_ID,
                    labels=const.LABELS,
                )
            )
            metrics.append(
                CorrectMetric(
                    prefix=name,
                    target_name=name,
                    PAD=const.PAD_TAGS_ID,
                )
            )
        metrics.append(LogMetric(targets=[(const.LOSS, const.LOSS)]))
        return metrics

    def metrics_ordering(self):
        """Higher metric values are better."""
        return max