# Source code for kiwi.metrics.stats

#  OpenKiwi: Open-Source Machine Translation Quality Estimation
#  Copyright (C) 2019 Unbabel <openkiwi@unbabel.com>
#
#  This program is free software: you can redistribute it and/or modify
#  it under the terms of the GNU Affero General Public License as published
#  by the Free Software Foundation, either version 3 of the License, or
#  (at your option) any later version.
#
#  This program is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU Affero General Public License for more details.
#
#  You should have received a copy of the GNU Affero General Public License
#  along with this program.  If not, see <https://www.gnu.org/licenses/>.
#

import functools
import logging
from collections import OrderedDict

from kiwi.loggers import tracking_logger

logger = logging.getLogger(__name__)


@functools.total_ordering
class StatsSummary(OrderedDict):
    """Ordered mapping of metric names to values, with an optional key prefix.

    When ``prefix`` is set, every key written or read through ``[]``, ``in``
    and ``get`` is transparently turned into ``'<prefix>_<key>'``; e.g. with
    ``prefix='dev'``, ``summary['f1']`` addresses the stored key ``'dev_f1'``.

    Summaries compare by their main metric only, and "greater" means
    "better": with ``ordering=max`` a larger main-metric value is better,
    with any other ordering (e.g. ``min``) a smaller value is better.
    """

    # NOTE(review): OrderedDict already provides every rich comparison, so
    # ``total_ordering`` likely fills nothing in; ``a < b`` and ``a >= b``
    # then work through the reflected ``b.__gt__(a)`` / ``b.__le__(a)``.

    def __init__(self, prefix=None, main_metric=None, ordering=max, **kwargs):
        self.prefix = prefix
        self._main_metric_name = main_metric
        self.ordering = ordering
        # Attributes must exist before the parent initialiser stores
        # ``kwargs``, because those insertions go through ``__setitem__``,
        # which reads ``self.prefix``.
        super().__init__(**kwargs)

    @property
    def main_metric(self):
        """Unprefixed name of the metric used for comparisons.

        Falls back to the first stored metric when no main metric was given,
        with the prefix stripped so the name can be fed back to ``[]`` /
        ``get`` (which re-apply the prefix). Returns ``None`` when the
        summary is empty and no main metric was given.
        """
        if self._main_metric_name:
            return self._main_metric_name
        if not self:
            return None
        first_key = next(iter(self))
        if self.prefix:
            # Stored keys are already prefixed; returning them as-is would
            # make every lookup apply the prefix twice.
            key_prefix = '{}_'.format(self.prefix)
            if first_key.startswith(key_prefix):
                return first_key[len(key_prefix):]
        return first_key

    def main_metric_value(self):
        """Return the value of the main metric (``KeyError`` if absent)."""
        return self[self.main_metric]

    def _make_key(self, key):
        """Return ``key`` with ``'<prefix>_'`` prepended when a prefix is set."""
        if self.prefix:
            key = '{}_{}'.format(self.prefix, key)
        return key

    def __str__(self):
        return ', '.join(['{}: {:0.4f}'.format(k, v) for k, v in self.items()])

    def log(self):
        """Log the summary and send each metric to the tracking logger.

        A carriage return is printed first, presumably to clear an in-place
        progress line on the terminal.
        """
        print('\r', end='\r')
        logger.info(self)
        for k, v in self.items():
            tracking_logger.log_metric(k, v)

    def __setitem__(self, key, value):
        key = self._make_key(key)
        super().__setitem__(key, value)

    def __getitem__(self, key):
        key = self._make_key(key)
        return super().__getitem__(key)

    def __contains__(self, key):
        key = self._make_key(key)
        return super().__contains__(key)

    def get(self, key, default=None):
        """Prefix-aware ``dict.get``."""
        key = self._make_key(key)
        return super().get(key, default)

    def _main_values(self, other):
        """Return this summary's main metric as looked up in self and other.

        Each side applies its own prefix, so e.g. a ``train`` summary and a
        ``dev`` summary are compared on the same unprefixed metric.
        """
        name = self.main_metric
        return self.get(name), other.get(name)

    def __eq__(self, other):
        if not isinstance(other, StatsSummary):
            return False
        mine, theirs = self._main_values(other)
        return mine == theirs

    def __ne__(self, other):
        # OrderedDict defines its own ``__ne__`` (full content comparison),
        # which would otherwise shadow the inverted ``__eq__`` and allow both
        # ``a == b`` and ``a != b`` to be true at the same time.
        return not self.__eq__(other)

    def __le__(self, other):
        """True if self is worse than or as good as ``other``."""
        if not isinstance(other, StatsSummary):
            return False
        mine, theirs = self._main_values(other)
        return mine <= theirs if self.ordering == max else mine >= theirs

    def __gt__(self, other):
        """True if self is strictly better than ``other``."""
        if not isinstance(other, StatsSummary):
            return False
        mine, theirs = self._main_values(other)
        return mine > theirs if self.ordering == max else mine < theirs

    def better_than(self, other):
        """Return True if this summary's main metric strictly beats ``other``'s."""
        return self.__gt__(other)
class Stats:
    """Accumulate a group of metrics over update steps and summarise them.

    Each metric object is used only through ``get_name()``,
    ``update(**kwargs)``, ``summarize()`` and ``reset()``.
    """

    def __init__(
        self,
        metrics,
        main_metric=None,
        main_metric_ordering=max,
        log_interval=0,
    ):
        self.metrics = metrics
        # Without an explicit choice, the first metric drives comparisons.
        chosen = main_metric if main_metric else self.metrics[0]
        self.main_metric_name = chosen.get_name()
        self.main_metric_ordering = main_metric_ordering
        self.log_interval = log_interval
        self.reset()

    def update(self, **kwargs):
        """Count one step, then forward ``kwargs`` to every metric."""
        self.steps += 1
        for tracked in self.metrics:
            tracked.update(**kwargs)

    def summarize(self, prefix=None):
        """Build a ``StatsSummary`` of all metrics without resetting them.

        The summary is left empty when no step has been recorded yet.
        """
        result = StatsSummary(
            prefix=prefix,
            main_metric=self.main_metric_name,
            ordering=self.main_metric_ordering,
        )
        if not self.steps:
            return result
        for tracked in self.metrics:
            result.update(tracked.summarize())
        return result

    def reset(self):
        """Zero the step counter and reset every metric."""
        self.steps = 0
        for tracked in self.metrics:
            tracked.reset()

    def wrap_up(self, prefix=None):
        """Summarise the accumulated stats, then reset them."""
        result = self.summarize(prefix)
        self.reset()
        return result

    def log(self, step=None):
        """Wrap up and log the stats when due.

        With ``step=None`` this always logs. Otherwise it logs only when
        ``log_interval`` is positive and ``step`` is a multiple of it.
        """
        should_log = step is None
        if not should_log and self.log_interval > 0:
            should_log = not step % self.log_interval
        if should_log:
            self.wrap_up().log()