# OpenKiwi: Open-Source Machine Translation Quality Estimation
# Copyright (C) 2019 Unbabel <openkiwi@unbabel.com>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
#
import math
import time
from collections import OrderedDict
import numpy as np
import torch
from scipy.stats.stats import pearsonr, spearmanr
from torch import nn
from kiwi import constants as const
from kiwi.metrics.functions import fscore, precision_recall_fscore_support
from kiwi.models.utils import replace_token
class Metric:
    """Base class for metrics that are accumulated batch by batch.

    Subclasses must implement ``reset`` (it is invoked from ``__init__``
    before any attribute below is assigned), ``update`` and ``summarize``.

    Args:
        target_name: Name of the attribute read from the batch and of the key
            read from the model output.
        metric_name: Key under which the metric is reported.
        PAD: Value marking padded positions in the target; ``None`` means
            every position counts.
        STOP: When not ``None``, the first and last target columns are dropped
            and the remainder is passed through ``replace_token``.
        prefix: Optional string prepended (with ``_``) to every reported key.
    """

    def __init__(
        self,
        target_name=None,
        metric_name=None,
        PAD=None,
        STOP=None,
        prefix=None,
    ):
        super().__init__()
        # reset() runs first on purpose: subclasses only initialise their
        # accumulators there and may rely on state set before super().__init__.
        self.reset()
        self.prefix = prefix
        self.target_name = target_name
        self.metric_name = metric_name
        self.PAD = PAD
        self.STOP = STOP

    def update(self, **kwargs):
        """Accumulate statistics for one batch (implemented by subclasses)."""
        raise NotImplementedError

    def reset(self):
        """Clear accumulated statistics (implemented by subclasses)."""
        raise NotImplementedError

    def summarize(self, **kwargs):
        """Return a dict of reported values (implemented by subclasses)."""
        raise NotImplementedError

    def get_name(self):
        """Return ``metric_name`` with the prefix applied, if any."""
        return self._prefix(self.metric_name)

    def _prefix_keys(self, summary):
        """Return ``summary`` with every key prefixed when a prefix is set."""
        if self.prefix:
            summary = OrderedDict(
                {self._prefix(key): value for key, value in summary.items()}
            )
        return summary

    def _prefix(self, key):
        """Return ``'<prefix>_<key>'`` when a prefix is set, else ``key``."""
        if self.prefix:
            return '{}_{}'.format(self.prefix, key)
        return key

    def token_mask(self, batch):
        """Return a mask over the target selecting non-``PAD`` positions.

        Without a ``PAD`` value the mask is all ones (``uint8``) with the
        target's shape and device.
        """
        target = self.get_target(batch)
        if self.PAD is not None:
            return target != self.PAD
        else:
            return torch.ones(
                target.shape, dtype=torch.uint8, device=target.device
            )

    def get_target(self, batch):
        """Read the target tensor named ``target_name`` from ``batch``."""
        target = getattr(batch, self.target_name)
        if self.STOP is not None:
            # Presumably swaps STOP symbols for PAD after trimming the first
            # and last column -- confirm against kiwi.models.utils.
            target = replace_token(target[:, 1:-1], self.STOP, self.PAD)
        return target

    def get_token_indices(self, batch):
        """Return the flat indices of the unmasked target positions.

        The result is always one-dimensional, also when exactly one position
        is selected.
        """
        mask = self.token_mask(batch)
        # nonzero() on a 1-D tensor yields shape (n, 1). Only the trailing
        # axis is removed; a bare squeeze() would turn n == 1 into a 0-d
        # tensor and make downstream ``.tolist()`` return a scalar.
        return mask.view(-1).nonzero().squeeze(-1)

    def get_predictions(self, model_out):
        """Return the model output stored under ``target_name``."""
        predictions = model_out[self.target_name]
        return predictions

    def get_target_flat(self, batch):
        """Return the unmasked target values as a 1-D tensor."""
        target_flat = self.get_target(batch).contiguous().view(-1)
        token_indices = self.get_token_indices(batch)
        return target_flat[token_indices]

    def get_predictions_flat(self, model_out, batch):
        """Return the predictions at the unmasked target positions.

        Predictions are reshaped to ``(-1, last_dim)`` and singleton axes are
        dropped, exactly as before, except in two degenerate cases:

        * a single target item with several scores keeps its row axis
          (``(1, C)``) instead of collapsing to ``(C,)``;
        * a single scalar prediction becomes shape ``(1,)`` instead of 0-d.
        """
        predictions = self.get_predictions(model_out).contiguous()
        flat = predictions.view(-1, predictions.shape[-1])
        rows, cols = flat.shape
        single_item_scores = (
            rows == 1 and cols > 1 and self.get_target(batch).numel() == 1
        )
        if not single_item_scores:
            flat = flat.squeeze()
            if flat.dim() == 0:
                # Keep one axis so that index selection stays valid.
                flat = flat.view(1)
        token_indices = self.get_token_indices(batch)
        return flat[token_indices]

    def get_tokens(self, batch):
        """Return the number of unmasked target positions as a Python int."""
        return self.token_mask(batch).sum().item()
class NLLMetric(Metric):
    """Reports the accumulated loss divided by the accumulated token count.

    The value is published under the key ``NLL``.
    """

    def __init__(self, **kwargs):
        super().__init__(metric_name='NLL', **kwargs)

    def reset(self):
        """Zero the running loss and token counters."""
        self.nll = 0.0
        self.tokens = 0

    def update(self, loss, batch, **kwargs):
        """Add this batch's token count and its ``loss[target_name]`` value."""
        self.tokens += self.get_tokens(batch)
        batch_loss = loss[self.target_name]
        self.nll += batch_loss.item()

    def summarize(self):
        """Return ``{metric_name: nll / tokens}`` with prefixed keys."""
        per_token = self.nll / self.tokens
        return self._prefix_keys({self.metric_name: per_token})
class PerplexityMetric(Metric):
    """Reports ``e`` raised to the per-token accumulated loss, as ``PERP``."""

    def __init__(self, **kwargs):
        super().__init__(metric_name='PERP', **kwargs)

    def reset(self):
        """Zero the running token and loss counters."""
        self.tokens = 0
        self.nll = 0.0

    def update(self, loss, batch, **kwargs):
        """Add this batch's token count and its ``loss[target_name]`` value."""
        self.tokens += self.get_tokens(batch)
        batch_loss = loss[self.target_name]
        self.nll += batch_loss.item()

    def summarize(self):
        """Return ``{metric_name: e ** (nll / tokens)}`` with prefixed keys."""
        exponent = self.nll / self.tokens
        return self._prefix_keys({self.metric_name: math.e ** exponent})
class CorrectMetric(Metric):
    """Reports the fraction of unmasked positions whose argmax equals the target.

    The value is published under the key ``CORRECT``.
    """

    def __init__(self, **kwargs):
        super().__init__(metric_name='CORRECT', **kwargs)

    def reset(self):
        """Zero the hit and token counters."""
        self.correct = 0
        self.tokens = 0

    def update(self, model_out, batch, **kwargs):
        """Count tokens and argmax hits for one batch."""
        self.tokens += self.get_tokens(batch)
        scores = self.get_predictions_flat(model_out, batch)
        gold = self.get_target_flat(batch)
        # max over the last axis returns (values, indices); only the indices
        # (predicted classes) are needed.
        predicted = scores.max(-1)[1]
        hits = (gold == predicted).sum().item()
        self.correct += hits

    def summarize(self):
        """Return ``{metric_name: correct / tokens}`` with prefixed keys."""
        accuracy = float(self.correct) / self.tokens
        return self._prefix_keys({self.metric_name: accuracy})
class F1Metric(Metric):
    """Collects predicted and gold classes and reports per-label F1 scores.

    ``F1_MULT`` is the product of all F1 values returned by
    ``precision_recall_fscore_support``; each label in ``labels`` is reported
    as ``F1_<label>`` using the F1 value at the label's position.

    Args:
        labels: Sequence of label names, indexed consistently with the F1
            array returned by ``precision_recall_fscore_support``.
    """

    def __init__(self, labels, **kwargs):
        super().__init__(metric_name='F1_MULT', **kwargs)
        self.labels = labels

    def reset(self):
        """Drop all collected gold labels and predictions."""
        self.Y = []
        self.Y_HAT = []

    def update(self, model_out, batch, **kwargs):
        """Append this batch's argmax predictions and gold labels."""
        scores = self.get_predictions_flat(model_out, batch)
        gold = self.get_target_flat(batch)
        predicted = scores.max(-1)[1]
        self.Y_HAT += predicted.tolist()
        self.Y += gold.tolist()

    def summarize(self):
        """Return an ordered summary: the F1 product first, then each label."""
        f1 = precision_recall_fscore_support(self.Y_HAT, self.Y)[2]
        summary = OrderedDict()
        summary[self.metric_name] = np.prod(f1)
        for position, label in enumerate(self.labels):
            summary['F1_' + label] = f1[position]
        return self._prefix_keys(summary)
class PearsonMetric(Metric):
    """Reports the Pearson correlation between predictions and targets.

    All values seen since the last ``reset`` are kept in memory; the
    coefficient is published under the key ``PEARSON``.
    """

    def __init__(self, **kwargs):
        super().__init__(metric_name='PEARSON', **kwargs)

    def reset(self):
        """Drop all collected predictions and targets."""
        self.predictions = []
        self.target = []

    def update(self, model_out, batch, **kwargs):
        """Append this batch's flattened predictions and targets."""
        gold = self.get_target_flat(batch)
        predicted = self.get_predictions_flat(model_out, batch)
        self.predictions += predicted.tolist()
        self.target += gold.tolist()

    def summarize(self):
        """Return ``{metric_name: r}`` where ``r`` is ``pearsonr(...)[0]``."""
        coefficient = pearsonr(self.predictions, self.target)[0]
        return self._prefix_keys({self.metric_name: coefficient})
class SpearmanMetric(Metric):
    """Reports the Spearman rank correlation between predictions and targets.

    All values seen since the last ``reset`` are kept in memory; the
    coefficient is published under the key ``SPEARMAN``.
    """

    def __init__(self, **kwargs):
        super().__init__(metric_name='SPEARMAN', **kwargs)

    def reset(self):
        """Drop all collected predictions and targets."""
        self.predictions = []
        self.target = []

    def update(self, model_out, batch, **kwargs):
        """Append this batch's flattened predictions and targets."""
        gold = self.get_target_flat(batch)
        predicted = self.get_predictions_flat(model_out, batch)
        self.predictions += predicted.tolist()
        self.target += gold.tolist()

    def summarize(self):
        """Return ``{metric_name: rho}`` where ``rho`` is ``spearmanr(...)[0]``."""
        coefficient = spearmanr(self.predictions, self.target)[0]
        return self._prefix_keys({self.metric_name: coefficient})
class ExpectedErrorMetric(Metric):
    """Reports the mean of ``1 - p(gold class)`` over unmasked positions.

    Probabilities come from a softmax over the last axis of the flattened
    predictions; the value is published under the key ``ExpErr``.
    """

    def __init__(self, **kwargs):
        super().__init__(metric_name='ExpErr', **kwargs)

    def reset(self):
        """Zero the accumulated error mass and token counter."""
        self.expected_error = 0.0
        self.tokens = 0

    def update(self, model_out, batch, **kwargs):
        """Accumulate the probability mass not assigned to the gold class."""
        scores = self.get_predictions_flat(model_out, batch)
        gold = self.get_target_flat(batch)
        distribution = nn.functional.softmax(scores, -1)
        # Pick, for every position, the probability of its gold class.
        gold_probs = distribution.gather(-1, gold.unsqueeze(-1)).squeeze()
        missed_mass = 1.0 - gold_probs
        self.tokens += self.get_tokens(batch)
        self.expected_error += missed_mass.sum().item()

    def summarize(self):
        """Return ``{metric_name: expected_error / tokens}`` with prefixed keys."""
        mean_error = self.expected_error / self.tokens
        return self._prefix_keys({self.metric_name: mean_error})
class TokPerSecMetric(Metric):
    """Reports tokens processed per second of wall-clock time since ``reset``.

    The value is published under the key ``TokPerSec``.
    """

    def __init__(self, **kwargs):
        super().__init__(metric_name='TokPerSec', **kwargs)

    def reset(self):
        """Zero the token counter and restart the clock."""
        self.tokens = 0
        self.time = time.time()

    def update(self, batch, **kwargs):
        """Add this batch's token count."""
        self.tokens += self.get_tokens(batch)

    def summarize(self):
        """Return ``{metric_name: tokens / elapsed_seconds}`` with prefixed keys."""
        elapsed = time.time() - self.time
        return self._prefix_keys({self.metric_name: self.tokens / elapsed})
class LogMetric(Metric):
    """Logs averages of values in loss, model or batch.

    Args:
        targets: Sequence of ``(side, target)`` pairs. On every update the
            value ``kwargs[side][target]`` is averaged with ``.mean()`` and
            accumulated under the key ``'<side>_<target>'``.
        metric_name: Reported name; defaults to the key of the first pair.
    """

    def __init__(self, targets, metric_name=None, **kwargs):
        # Must be assigned before the base constructor runs, because the
        # base constructor calls reset(), which iterates over the targets.
        self.targets = targets
        if not metric_name:
            metric_name = self._format(*targets[0])
        super().__init__(metric_name=metric_name, **kwargs)

    def reset(self):
        """Zero every accumulated value and the step counter."""
        self.log = {}
        for side, target in self.targets:
            self.log[self._format(side, target)] = 0.0
        self.steps = 0

    def update(self, **kwargs):
        """Add the mean of each tracked value and count one step."""
        self.steps += 1
        for side, target in self.targets:
            value = kwargs[side][target]
            self.log[self._format(side, target)] += value.mean().item()

    def summarize(self):
        """Return each accumulated value divided by the number of steps."""
        steps = float(self.steps)
        summary = {}
        for key, total in self.log.items():
            summary[key] = total / steps
        return self._prefix_keys(summary)

    def _format(self, side, target):
        """Build the log key ``'<side>_<target>'``."""
        return '{}_{}'.format(side, target)
class RMSEMetric(Metric):
    """Reports the root mean squared error between predictions and targets.

    The value is published under the key ``RMSE``.
    """

    def __init__(self, **kwargs):
        super().__init__(metric_name='RMSE', **kwargs)

    def reset(self):
        """Zero the squared-error sum and the token counter."""
        self.squared_error = 0.0
        self.tokens = 0

    def update(self, batch, model_out, **kwargs):
        """Accumulate this batch's summed squared error and token count."""
        predicted = self.get_predictions_flat(model_out, batch)
        gold = self.get_target_flat(batch)
        residual = predicted - gold
        self.squared_error += (residual ** 2).sum().item()
        self.tokens += self.get_tokens(batch)

    def summarize(self):
        """Return ``{metric_name: sqrt(squared_error / tokens)}``."""
        mean_squared = self.squared_error / self.tokens
        return self._prefix_keys({self.metric_name: math.sqrt(mean_squared)})
class TokenMetric(Metric):
    """Reports the fraction of unmasked target positions equal to one token.

    Args:
        target_token: Token id to count (defaults to ``const.UNK_ID``).
        token_name: Display name of the token. The metric is reported as
            ``'<token_name>S'``, i.e. ``'UNKS'`` for the default. Previously
            this argument was accepted but ignored.
    """

    def __init__(self, target_token=const.UNK_ID, token_name='UNK', **kwargs):
        self.target_token = target_token
        # Derive the reported key from token_name so that metrics counting
        # different tokens do not all collide on the key 'UNKS'.
        super().__init__(metric_name='{}S'.format(token_name), **kwargs)

    def update(self, batch, **kwargs):
        """Count occurrences of the tracked token and the total tokens."""
        target = self.get_target_flat(batch)
        self.targets += (target == self.target_token).sum().item()
        self.tokens += self.get_tokens(batch)

    def summarize(self):
        """Return ``{metric_name: targets / tokens}``, or ``{}`` with no tokens."""
        summary = {}
        if self.tokens:
            summary = {self.metric_name: self.targets / self.tokens}
        return self._prefix_keys(summary)

    def reset(self):
        """Zero the token and match counters."""
        self.tokens = 0
        self.targets = 0
class ThresholdCalibrationMetric(Metric):
    """Reports the F1 product obtained with a threshold tuned on half the data.

    The collected examples are shuffled and split in two. A decision
    threshold on the ``const.BAD_ID`` probability is chosen with ``MovingF1``
    on the first half and then evaluated on the second half. The value is
    published under the key ``F1_CAL``.
    """

    def __init__(self, **kwargs):
        super().__init__(metric_name='F1_CAL', **kwargs)

    def reset(self):
        """Drop all collected scores and gold labels."""
        self.scores = []
        self.Y = []

    def update(self, model_out, batch, **kwargs):
        """Append the ``BAD_ID`` probabilities and gold labels of one batch."""
        logits = self.get_predictions_flat(model_out, batch)
        distribution = nn.functional.softmax(logits, -1)
        bad_probs = distribution[:, const.BAD_ID]
        gold = self.get_target_flat(batch)
        self.scores += bad_probs.tolist()
        self.Y += gold.tolist()

    def summarize(self):
        """Return ``{metric_name: f1_product}``, or ``{}`` with under two examples.

        Note that the stored labels and scores are re-ordered in place by the
        random shuffle.
        """
        half = len(self.Y) // 2
        if not half:
            return self._prefix_keys({})

        order = np.random.permutation(len(self.Y))
        self.Y = [self.Y[position] for position in order]
        self.scores = [self.scores[position] for position in order]

        # Tune the threshold on the first half ...
        tuner = MovingF1()
        curve = tuner.eval(self.scores[:half], self.Y[:half])
        _tuned_f1, threshold = tuner.choose(curve)

        # ... and measure it on the held-out second half.
        predictions = []
        for score in self.scores[half:]:
            if score >= threshold:
                predictions.append(const.BAD_ID)
            else:
                predictions.append(const.OK_ID)
        f1 = precision_recall_fscore_support(predictions, self.Y[half:])[2]
        return self._prefix_keys({self.metric_name: np.prod(f1)})
class MovingMetric:
    """Class to compute the changes in one metric as a function of a second
    metric.

    Example: F1 score vs. Classification Threshold, Quality vs Skips
    """

    def eval(self, scores, labels):
        """Compute the graph metric1 vs metric2.

        Args:
            scores: Model outputs.
            labels: Corresponding labels.

        Returns:
            A list of ``(metric_value, threshold)`` tuples: one entry for the
            initial state (paired with the first sorted score) followed by
            one entry per example, in sorted order.
        """
        self.init(scores, labels)
        ordered_scores, ordered_labels = self.sort(scores, labels)
        curve = [(self.compute(), ordered_scores[0])]
        for pair in zip(ordered_scores, ordered_labels):
            self.update(*pair)
            curve.append((self.compute(), pair[0]))
        return curve

    def init(self, scores, labels):
        """Initialize the metric for threshold < min(scores)."""
        return scores, labels

    def sort(self, scores, labels):
        """Sort scores and labels jointly, as ``(score, label)`` pairs.

        Returns:
            An iterator yielding the sorted scores tuple, then the labels
            tuple.
        """
        paired = sorted(zip(scores, labels))
        return zip(*paired)

    def update(self, score, label):
        """Move the threshold past score."""
        return None

    def compute(self):
        """Compute the current value of the metric."""
        pass

    def choose(self, thresholds):
        """Choose the best (threshold, metric) tuple from an iterable."""
        pass
class MovingF1(MovingMetric):
    """Tracks the product of both classes' F1 while a threshold sweeps the data."""

    def init(self, scores, labels, class_idx=1):
        """Compute F1 Mult for all decision thresholds over (scores, labels).

        Initialize the threshold s.t. all examples are classified as
        `class_idx`.

        Args:
            scores: Likelihood scores for class index.
            labels: Gold truth classes in {0, 1}.
            class_idx: ID of the class every example starts out assigned to.
        """
        # -1 if class_idx == 0, 1 if class_idx == 1
        self.sign = 2 * class_idx - 1
        ones = sum(labels)
        zeros = len(labels) - ones
        other_idx = 1 - class_idx
        # With everything assigned to class_idx, that class owns all
        # positives (true or false); the other class has none.
        self.fp_zero = other_idx * ones
        self.tp_zero = other_idx * zeros
        self.fp_one = class_idx * zeros
        self.tp_one = class_idx * ones

    def update(self, score, label):
        """Move the decision threshold past one example.

        The example's contribution is shifted between the counts of the two
        classes; ``score`` itself is not used.
        """
        is_zero = 1 - label
        self.tp_zero += self.sign * is_zero
        self.fp_zero += self.sign * label
        self.tp_one -= self.sign * label
        self.fp_one -= self.sign * is_zero

    def compute(self):
        """Return the product of the F1 scores of class one and class zero."""
        # The false positives of one class are passed as the third argument
        # of the other class' fscore call.
        f1_zero = fscore(self.tp_zero, self.fp_zero, self.fp_one)
        f1_one = fscore(self.tp_one, self.fp_one, self.fp_zero)
        return f1_one * f1_zero

    def choose(self, thresholds):
        """Return the largest ``(f1, threshold)`` tuple."""
        return max(thresholds)
class MovingSkipsAtQuality(MovingMetric):
    """Computes Quality of skipped examples vs fraction of skips."""

    def __init__(
        self, scores_higher_is_better=False, labels_higher_is_better=False
    ):
        """
        Args:
            scores_higher_is_better:
                If True, higher model outputs indicate higher quality.
            labels_higher_is_better:
                If True, higher label values indicate higher quality.
        """
        self.scores_higher_is_better = scores_higher_is_better
        self.labels_higher_is_better = labels_higher_is_better

    def eval(self, scores, labels):
        """
        Args:
            scores: Model output quality or error scores. If quality scores
                are provided, pass scores_higher_is_better=True.
            labels: Ground truth quality or error scores. If quality scores
                are provided, pass labels_higher_is_better=True.

        Returns:
            A list of ``((skip_fraction, avg_quality), threshold)`` tuples.
        """
        return super().eval(scores, labels)

    def init(self, scores, labels):
        """Reset the running totals before a sweep.

        Args:
            scores: Model output quality or error scores. If quality scores
                are provided, pass scores_higher_is_better=True.
            labels: Ground truth quality or error scores. If quality scores
                are provided, pass labels_higher_is_better=True.
        """
        self.cumulative_qual = 0.0
        self.skipped = 0
        self.data_size = len(scores)

    def update(self, score, label):
        """Mark one more example as skipped and add its label to the total."""
        self.cumulative_qual += label
        self.skipped += 1

    def compute(self):
        """Return ``(skip_fraction, avg_quality_of_skipped)``.

        Before anything has been skipped this is ``(None, 0.0)``.
        """
        if not self.skipped:
            return None, 0.0
        return (
            self.skipped / self.data_size,
            self.cumulative_qual / self.skipped,
        )

    def choose(self, thresholds, target_qual):
        """Pick the admissible point whose quality is closest to target_qual.

        A point is admissible when its average quality is at least
        ``target_qual`` (if labels_higher_is_better) or at most
        ``target_qual`` (otherwise). On ties the earliest point wins.

        Args:
            thresholds: Iterable of ``((skip, qual), threshold)`` tuples as
                returned by ``eval``.
            target_qual: Quality level to match.

        Returns:
            The selected ``((skip, qual), threshold)`` tuple, or ``None`` when
            no point is admissible.
        """
        best = None
        sign = 1 if self.labels_higher_is_better else -1
        for (skip, qual), t in thresholds:
            if sign * (qual - target_qual) < 0:
                # Quality at threshold t violates the target.
                continue
            if best is None:
                best = ((skip, qual), t)
                continue
            # Compare against the QUALITY of the current best (index 1).
            # Index 0 is the skip fraction, which is None for the first
            # curve point and is not comparable to a quality value.
            last_best = abs(best[0][1] - target_qual)
            if abs(qual - target_qual) < last_best:
                best = ((skip, qual), t)
        return best

    def sort(self, scores, labels):
        """Sort pairs by score; descending when higher scores mean better."""
        return zip(
            *sorted(zip(scores, labels), reverse=self.scores_higher_is_better)
        )