kiwi.trainers package

Submodules

kiwi.trainers.callbacks module

class kiwi.trainers.callbacks.Checkpoint(output_dir, checkpoint_save=False, checkpoint_keep_only_best=0, checkpoint_early_stop_patience=0, checkpoint_validation_steps=0)

Bases: object

Determines when the model should be evaluated and when it should be saved during training.
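The reference above does not spell out the evaluation schedule. The sketch below shows one plausible rule. It assumes that checkpoint_validation_steps=0 means "validate at the end of every epoch" and that a positive value means "validate every N training steps". The standalone function is hypothetical and is not kiwi's must_eval.

```python
def must_eval(epoch=None, step=None, validation_steps=0):
    """Hypothetical rule for when to run validation (not kiwi's implementation).

    Assumption: validation_steps == 0 means "evaluate at every epoch end",
    while a positive value means "evaluate every validation_steps steps".
    """
    if validation_steps > 0:
        # Step-based schedule: fire on each positive multiple of validation_steps.
        return step is not None and step > 0 and step % validation_steps == 0
    # Epoch-based schedule: fire whenever an epoch boundary is reported.
    return epoch is not None
```

For example, with validation_steps=50 this rule fires at steps 50, 100, 150, and so on, and ignores epoch boundaries.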

best_iteration_path()
best_model_path()
best_stats()
best_stats_and_path()
check_in(trainer, stats, epoch=None, step=None)

Saves the stats summary and handles checkpoint saving.

check_out()
static copy_best_model(model_dir, output_dir)
early_stopping()
last_model_path()

Generates the path where the latest model should be saved.

make_output_path(epoch=None, step=None)
must_eval(epoch=None, step=None)
must_save_best(stats)
push_to_heap(stats, output_path)

Pushes the stats and output path onto the heap.
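The checkpoint_keep_only_best option, together with push_to_heap and worst_stats, suggests a bounded heap of snapshots. The sketch below shows that idea. It assumes a higher validation score is better and that a limit of 0 means "keep every snapshot". The BestSnapshots class is hypothetical and only illustrates the data structure.

```python
import heapq


class BestSnapshots:
    """Hypothetical helper that keeps only the `keep_only_best` best snapshots.

    Assumes a higher score is better. heapq is a min-heap, so the root is
    always the worst snapshot currently retained, which makes eviction cheap.
    A limit of 0 keeps everything (assumption mirroring the default of 0).
    """

    def __init__(self, keep_only_best):
        self.keep_only_best = keep_only_best
        self._heap = []  # entries are (score, output_path) tuples

    def push(self, score, output_path):
        """Add a snapshot; return the path that should be deleted, or None."""
        heapq.heappush(self._heap, (score, output_path))
        if self.keep_only_best > 0 and len(self._heap) > self.keep_only_best:
            # Evict the worst snapshot, which sits at the heap root.
            _, path_to_remove = heapq.heappop(self._heap)
            return path_to_remove
        return None

    def worst(self):
        return self._heap[0] if self._heap else None

    def best(self):
        return max(self._heap) if self._heap else None
```

Note that a new snapshot scoring below every retained one is evicted immediately, so the caller would delete it right away, for instance via something like remove_snapshot.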

remove_snapshot(path_to_remove, event=None)

Removes the snapshot locally and from MLflow.

save_latest(trainer, saved_best=False)

Saves the latest checkpoint of the current model. If the model was just saved because it achieved the best validation score, saves a pointer to that copy instead of the full model. Returns the path of the saved model.
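Saving a pointer avoids storing a second full copy of a model that was just written as the best one. The sketch below illustrates that idea under stated assumptions. The function name save_latest_or_pointer, the latest_model directory name, and the .pointer text-file format are all hypothetical and do not describe kiwi's on-disk layout.

```python
import os
import shutil


def save_latest_or_pointer(model_dir, output_dir, best_model_dir=None):
    """Hypothetical sketch: store the latest model, or only a pointer to it.

    If best_model_dir is given, the same weights were just saved as the best
    model, so a small text file naming that directory is written instead of
    a full copy. Returns the path of whatever was written.
    """
    latest_dir = os.path.join(output_dir, "latest_model")
    pointer_path = latest_dir + ".pointer"
    # Clear whatever a previous call left behind so only one form exists.
    if os.path.isdir(latest_dir):
        shutil.rmtree(latest_dir)
    if os.path.isfile(pointer_path):
        os.remove(pointer_path)
    if best_model_dir is not None:
        # The best-model save already holds these weights: record a reference.
        with open(pointer_path, "w") as f:
            f.write(os.path.abspath(best_model_dir))
        return pointer_path
    # No fresh best-model copy exists, so store a full copy.
    shutil.copytree(model_dir, latest_dir)
    return latest_dir
```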

worst_stats()
exception kiwi.trainers.callbacks.EarlyStopException(*args, **kwargs)

Bases: StopIteration
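Because EarlyStopException derives from StopIteration, training loops can treat it as a "stop" signal. The sketch below shows patience-based early stopping. It assumes that patience counts consecutive evaluations without improvement, as the checkpoint_early_stop_patience parameter name suggests, and that 0 disables early stopping. PatienceTracker is a hypothetical helper, and the exception defined here is a local stand-in.

```python
class EarlyStopException(StopIteration):
    """Local stand-in mirroring kiwi's exception, which subclasses StopIteration."""


class PatienceTracker:
    """Hypothetical patience-based early stopping (higher score is better).

    `patience` is the number of consecutive evaluations without improvement
    tolerated before stopping; 0 disables early stopping (assumption).
    """

    def __init__(self, patience):
        self.patience = patience
        self.best = None
        self.bad_evals = 0

    def update(self, score):
        if self.best is None or score > self.best:
            # New best score: reset the counter.
            self.best = score
            self.bad_evals = 0
        else:
            self.bad_evals += 1
        if self.patience > 0 and self.bad_evals >= self.patience:
            raise EarlyStopException(f"no improvement in {self.bad_evals} evaluations")
```

Subclassing StopIteration has one consequence. If such an exception escapes a generator body, Python 3.7 and later convert it into a RuntimeError (PEP 479). It is therefore best raised and caught in ordinary function code.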

kiwi.trainers.linear_word_qe_trainer module

class kiwi.trainers.linear_word_qe_trainer.LinearWordQETrainer(model, optimizer_name, regularization_constant, checkpointer)

Bases: kiwi.models.linear.linear_trainer.LinearTrainer

model

kiwi.trainers.trainer module

class kiwi.trainers.trainer.Trainer(model, optimizer, checkpointer, log_interval=100, scheduler=None)

Bases: object

eval_epoch(valid_iterator, prefix='EVAL')
eval_step(batch)
classmethod from_directory(directory, device_id=None)
make_sub_directory(root_directory, current_epoch, prefix='epoch')
predict(valid_iterator)
classmethod resume(local_path=None, prefix='latest_', device_id=None)
run(train_iterator, valid_iterator, epochs=50)
Parameters:
  • train_iterator – Iterator over the training batches.
  • valid_iterator – Iterator over the validation batches.
  • epochs – Number of epochs for training (default: 50).
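The run method ties the other Trainer methods together. The sketch below outlines such an epoch loop. It assumes that evaluation happens once per epoch and that a checkpoint callback may end training early by raising EarlyStopException. The callables passed in are hypothetical placeholders for train_step, eval_epoch, and Checkpoint.check_in, and they do not follow those methods' real signatures.

```python
class EarlyStopException(StopIteration):
    """Local stand-in for kiwi.trainers.callbacks.EarlyStopException."""


def run(train_step, eval_epoch, check_in, train_iterator, valid_iterator, epochs=50):
    """Hypothetical outline of an epoch loop in the spirit of Trainer.run.

    train_step(batch), eval_epoch(valid_iterator) -> stats, and
    check_in(stats, epoch) are caller-supplied callables (assumptions for
    this sketch). check_in may raise EarlyStopException to end training.
    Returns the number of epochs actually completed.
    """
    completed = 0
    for epoch in range(1, epochs + 1):
        for batch in train_iterator:
            train_step(batch)
        stats = eval_epoch(valid_iterator)
        completed = epoch
        try:
            check_in(stats, epoch)
        except EarlyStopException:
            # The checkpoint logic decided to stop: leave the epoch loop.
            break
    return completed
```

The sketch iterates over train_iterator once per epoch. It therefore needs a re-iterable object, such as a list or a data loader, rather than a one-shot generator.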
save(output_directory)
stats_summary_history
train_epoch(train_iterator, valid_iterator)
train_step(batch)
train_steps(train_iterator, valid_iterator, max_steps)

kiwi.trainers.utils module

kiwi.trainers.utils.optimizer_class(name)
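optimizer_class maps an optimizer name to a class. The sketch below shows one way to perform that lookup. It assumes the targets are PyTorch's torch.optim classes. The name table is hypothetical and may not match the names OpenKiwi actually accepts. The module argument exists only so the lookup can be exercised without PyTorch installed.

```python
import importlib

# Hypothetical name table; the optimizers OpenKiwi actually accepts may differ.
OPTIMIZER_NAMES = {
    "sgd": "SGD",
    "adam": "Adam",
    "adadelta": "Adadelta",
    "adagrad": "Adagrad",
}


def optimizer_class(name, module="torch.optim"):
    """Resolve an optimizer name such as 'adam' to its class.

    The module is imported lazily, so importing this helper does not by
    itself require PyTorch. Unknown names raise ValueError.
    """
    try:
        attr = OPTIMIZER_NAMES[name.lower()]
    except KeyError:
        raise ValueError(f"unknown optimizer: {name!r}") from None
    return getattr(importlib.import_module(module), attr)
```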

Module contents