# OpenKiwi: Open-Source Machine Translation Quality Estimation
# Copyright (C) 2019 Unbabel <openkiwi@unbabel.com>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
#
import logging
from collections import defaultdict
from pathlib import Path
import torch
from tqdm import tqdm
import kiwi
from kiwi import constants as const
from kiwi.loggers import tracking_logger
from kiwi.metrics.stats import Stats
from kiwi.models.model import Model
from kiwi.models.utils import load_torch_file
from kiwi.trainers.callbacks import EarlyStopException
from kiwi.trainers.utils import optimizer_class
# Module-scoped logger: records are emitted under this module's dotted name.
logger = logging.getLogger(name=__name__)
class Trainer:
    """Train, evaluate, predict with and (de)serialize a kiwi Model.

    Progress is tracked in ``self._step`` (batches processed) and
    ``self._epoch`` (epochs completed). Both are written by :meth:`save`, so
    training can be continued through :meth:`from_directory` /
    :meth:`resume`.
    """

    def __init__(
        self, model, optimizer, checkpointer, log_interval=100, scheduler=None
    ):
        """
        Args:
            model: A kiwi.Model to train.
            optimizer: An optimizer over the model parameters (must provide
                ``step``, ``state_dict`` and ``load_state_dict``).
            checkpointer: A Checkpointer object. It is called after every
                training batch and every epoch and may raise
                ``EarlyStopException`` to stop training.
            log_interval: Log train stats every /n/ batches. Default 100.
            scheduler: An optional learning rate scheduler. Its state is saved
                alongside the optimizer. Note that this class never calls
                ``scheduler.step()`` itself.
        """
        self.model = model
        self.stats = Stats(
            metrics=model.metrics(),
            main_metric_ordering=model.metrics_ordering(),
            log_interval=log_interval,
        )
        self.optimizer = optimizer
        self.checkpointer = checkpointer
        self.scheduler = scheduler
        self._step = 0
        self._epoch = 0

    @property
    def stats_summary_history(self):
        """Proxy for ``self.checkpointer.stats_summary_history``."""
        return self.checkpointer.stats_summary_history

    def run(self, train_iterator, valid_iterator, epochs=50):
        """Train for whole epochs, continuing after the last finished one.

        Args:
            train_iterator: Iterator over training batches.
            valid_iterator: Iterator over validation batches, passed on to
                the checkpointer.
            epochs: Total number of epochs to train for. Epochs already
                recorded in ``self._epoch`` count towards this total.
        """
        for epoch in range(self._epoch + 1, epochs + 1):
            logger.info('Epoch {} of {}'.format(epoch, epochs))
            stopped_early = self.train_epoch(train_iterator, valid_iterator)
            self.stats.log()
            try:
                self.checkpointer(self, valid_iterator, epoch=epoch)
            except EarlyStopException as e:
                logger.info(e)
                break
            # An early stop raised by a step-level checkpoint inside the
            # epoch must end training, not just the current epoch.
            if stopped_early:
                break
        self.checkpointer.check_out()

    def train_epoch(self, train_iterator, valid_iterator):
        """Train on one full pass over ``train_iterator``.

        The checkpointer is called after every batch. If it raises
        ``EarlyStopException``, the epoch ends immediately.

        Args:
            train_iterator: Iterator over training batches (must support
                ``len``).
            valid_iterator: Iterator over validation batches, passed on to
                the checkpointer.

        Returns:
            True if the checkpointer requested an early stop during this
            epoch, False otherwise.
        """
        self.model.train()
        stopped_early = False
        for batch in tqdm(
            train_iterator,
            total=len(train_iterator),
            desc='Batches',
            unit=' batches',
            ncols=80,
        ):
            self._step += 1
            outputs = self.train_step(batch)
            self.stats.update(batch=batch, **outputs)
            self.stats.log(step=self._step)
            try:
                self.checkpointer(self, valid_iterator, step=self._step)
            except EarlyStopException as e:
                logger.info(e)
                stopped_early = True
                break
        self._epoch += 1
        return stopped_early

    def train_steps(self, train_iterator, valid_iterator, max_steps):
        """Train for ``max_steps`` batches, then evaluate and save a snapshot.

        ``train_iterator.repeat`` is set to True so batches keep flowing past
        the end of the data. It is reset to False afterwards, also when an
        error occurs. The snapshot is written to
        ``checkpointer.output_directory / 'step_<global step>'``.

        Args:
            train_iterator: Iterator over training batches (must expose a
                ``repeat`` attribute).
            valid_iterator: Iterator over validation batches.
            max_steps: Number of batches to train on in this call.
        """
        train_iterator.repeat = True
        try:
            self.model.train()
            for step, batch in tqdm(
                enumerate(train_iterator, 1),
                total=max_steps,
                desc='Steps',
                unit=' batches',
                ncols=80,
            ):
                self._step += 1
                outputs = self.train_step(batch)
                self.stats.update(batch=batch, **outputs)
                self.stats.log(step=self._step)
                try:
                    self.checkpointer(self, valid_iterator, step=self._step)
                except EarlyStopException as e:
                    logger.info(e)
                    break
                # Stop right after the max_steps-th batch. The previous `>`
                # check trained one extra batch.
                if step >= max_steps:
                    break
            eval_stats_summary = self.eval_epoch(valid_iterator)
            eval_stats_summary.log()
            sub_path = Path('step_{}'.format(self._step))
            self.save(self.checkpointer.output_directory / sub_path)
        finally:
            train_iterator.repeat = False

    def train_step(self, batch):
        """Run one optimization step on ``batch``.

        Returns:
            dict with ``loss`` (the loss dict from ``model.loss``, whose
            ``const.LOSS`` entry is back-propagated) and ``model_out``.
        """
        self.model.zero_grad()
        model_out = self.model(batch)
        loss_dict = self.model.loss(model_out, batch)
        loss_dict[const.LOSS].backward()
        self.optimizer.step()
        return dict(loss=loss_dict, model_out=model_out)

    def eval_epoch(self, valid_iterator, prefix='EVAL'):
        """Evaluate on ``valid_iterator`` without gradient tracking.

        ``self.stats`` is reset first, so the summary covers only this
        evaluation. The model's previous train/eval mode is restored
        afterwards, even if evaluation raises.

        Args:
            valid_iterator: Iterator over validation batches.
            prefix: Prefix passed to ``self.stats.wrap_up``.

        Returns:
            The summary returned by ``self.stats.wrap_up(prefix=prefix)``.
        """
        # Models without a `training` flag keep the old behaviour (train()).
        was_training = getattr(self.model, 'training', True)
        self.model.eval()
        self.stats.reset()
        try:
            with torch.no_grad():
                for batch in valid_iterator:
                    outputs = self.eval_step(batch)
                    self.stats.update(batch=batch, **outputs)
            stats_summary = self.stats.wrap_up(prefix=prefix)
        finally:
            if was_training:
                self.model.train()
        return stats_summary

    def eval_step(self, batch):
        """Compute model outputs and losses for ``batch`` (no parameter update).

        Returns:
            dict with ``loss`` (the model's loss dict) and ``model_out``.
        """
        model_out = self.model(batch)
        loss_dict = self.model.loss(model_out, batch)
        return dict(loss=loss_dict, model_out=model_out)

    def predict(self, valid_iterator):
        """Collect ``model.predict`` outputs over all batches.

        The model's previous train/eval mode is restored afterwards, even if
        prediction raises.

        Args:
            valid_iterator: Iterator over batches to predict on.

        Returns:
            A ``defaultdict(list)`` mapping each key returned by
            ``model.predict`` to the concatenation of its per-batch values.
        """
        # Models without a `training` flag keep the old behaviour (train()).
        was_training = getattr(self.model, 'training', True)
        self.model.eval()
        try:
            with torch.no_grad():
                predictions = defaultdict(list)
                for batch in valid_iterator:
                    model_pred = self.model.predict(batch)
                    for key, values in model_pred.items():
                        predictions[key] += values
        finally:
            if was_training:
                self.model.train()
        return predictions

    def make_sub_directory(self, root_directory, current_epoch, prefix='epoch'):
        """Create (if needed) and return ``root_directory/<prefix>_<current_epoch>``.

        Missing parent directories are created as well.
        """
        output_path = Path(root_directory) / '{}_{}'.format(
            prefix, current_epoch
        )
        output_path.mkdir(parents=True, exist_ok=True)
        return output_path

    def save(self, output_directory):
        """Save the model, optimizer (+ scheduler) and trainer state.

        Writes ``const.MODEL_FILE``, ``const.OPTIMIZER`` and
        ``const.TRAINER`` into ``output_directory``, creating it and any
        missing parents. Then uploads the directory as artifacts if the
        tracking logger requests it.

        Returns:
            The value of ``tracking_logger.log_artifacts`` when artifacts
            are logged, otherwise None.
        """
        output_directory = Path(output_directory)
        output_directory.mkdir(parents=True, exist_ok=True)
        logger.info('Saving training state to {}'.format(output_directory))

        model_path = output_directory / const.MODEL_FILE
        self.model.save(str(model_path))

        optimizer_path = output_directory / const.OPTIMIZER
        scheduler_dict = None
        if self.scheduler:
            scheduler_dict = {
                'name': type(self.scheduler).__name__.lower(),
                'state_dict': self.scheduler.state_dict(),
            }
        optimizer_dict = {
            'name': type(self.optimizer).__name__.lower(),
            'state_dict': self.optimizer.state_dict(),
            'scheduler_dict': scheduler_dict,
        }
        torch.save(optimizer_dict, str(optimizer_path))

        # The checkpointer object itself is pickled with the trainer state.
        state = {
            '__version__': kiwi.__version__,
            '_epoch': self._epoch,
            '_step': self._step,
            'checkpointer': self.checkpointer,
        }
        state_path = output_directory / const.TRAINER
        torch.save(state, str(state_path))

        # Send to MLflow
        event = None
        if tracking_logger.should_log_artifacts():
            logger.info('Logging artifacts to {}'.format(output_directory))
            event = tracking_logger.log_artifacts(
                str(output_directory), artifact_path=str(output_directory.name)
            )
        return event

    @classmethod
    def from_directory(cls, directory, device_id=None):
        """Rebuild a Trainer from a directory written by :meth:`save`.

        Args:
            directory: Directory containing the saved model, optimizer and
                trainer files.
            device_id: If given, the model is moved there (``model.to``)
                before the optimizer is built.

        Returns:
            A Trainer whose ``_epoch``, ``_step`` and ``checkpointer`` come
            from the saved state.
        """
        logger.info('Loading training state from {}'.format(directory))
        root_path = Path(directory)

        model_path = root_path / const.MODEL_FILE
        model = Model.create_from_file(model_path)
        if device_id is not None:
            model.to(device_id)

        optimizer_path = root_path / const.OPTIMIZER
        optimizer_dict = load_torch_file(str(optimizer_path))
        # lr=0.0 is a placeholder; the real value comes from the state dict.
        optimizer = optimizer_class(optimizer_dict['name'])(
            model.parameters(), lr=0.0
        )
        optimizer.load_state_dict(optimizer_dict['state_dict'])
        # NOTE(review): optimizer_dict['scheduler_dict'] is saved but not
        # restored here, so a resumed trainer has no scheduler.

        trainer = cls(model, optimizer, checkpointer=None)
        trainer_path = root_path / const.TRAINER
        state = load_torch_file(str(trainer_path))
        # Restores _epoch, _step and checkpointer (and also sets a
        # `__version__` attribute from the saved state).
        trainer.__dict__.update(state)
        return trainer

    @classmethod
    def resume(cls, local_path=None, prefix='latest_', device_id=None):
        """Load the most recent training snapshot.

        Snapshots are searched in ``local_path`` if given, otherwise in the
        tracking logger's artifact URI. ``<prefix>epoch`` is preferred. If it
        is missing, the ``epoch_<N>`` directory with the largest N is used.

        Raises:
            FileNotFoundError: If no snapshot directory is found.
        """
        if local_path:
            artifacts_uri = Path(local_path)
        else:
            # NOTE(review): assumes the artifact URI is a local filesystem
            # path (no `file://` scheme) — confirm with tracking_logger.
            artifacts_uri = Path(tracking_logger.get_artifact_uri())

        # Look relative to artifacts_uri; local_path may be None here.
        snapshot_dir = artifacts_uri / (prefix + 'epoch')
        if not snapshot_dir.exists():
            logger.info(
                'Latest epoch not found. Looking for other checkpoints'
            )
            snapshot_dir = cls._last_epoch_checkpoint(artifacts_uri)
        logger.info('Resuming training from: {}'.format(snapshot_dir))
        return cls.from_directory(snapshot_dir, device_id=device_id)

    @staticmethod
    def _last_epoch_checkpoint(artifacts_uri, prefix='epoch_'):
        """Return the ``<prefix><N>`` sub-directory with the largest integer N.

        Entries that are not directories, or whose suffix after ``prefix``
        is not an integer (e.g. ``epoch_best``), are ignored.

        Raises:
            FileNotFoundError: If no matching directory exists.
        """
        artifacts_uri = Path(artifacts_uri)
        saved_epochs = []
        for path in artifacts_uri.glob('{}*'.format(prefix)):
            suffix = path.name[len(prefix):]
            if path.is_dir() and suffix.isdecimal():
                saved_epochs.append(int(suffix))
        if not saved_epochs:
            raise FileNotFoundError(
                "Couldn't load trainer from: {}".format(
                    artifacts_uri / (prefix + '*')
                )
            )
        return artifacts_uri / '{}{}'.format(prefix, max(saved_epochs))