kiwi.systems.tlm_system
BatchSizeConfig
Base class for all pydantic configs. Used to configure base behaviour of configs.
ModelConfig
TLMSystem
Helper class that provides a standard way to create an ABC using inheritance.
kiwi.systems.tlm_system.logger
Bases: kiwi.utils.io.BaseConfig
train
valid
encoder
tlm_outputs
Bases: kiwi.systems._meta_module.Serializable, pytorch_lightning.LightningModule
Helper class that provides a standard way to create an ABC using inheritance.
Config
System configuration base class.
class_name
load
Load a pretrained Kiwi encoder model. If set, system architecture and vocabulary parameters are ignored.
load_vocabs
model
data_processing
optimizer
batch_size
num_data_workers
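As a rough illustration of how these options fit together, the hypothetical dictionary below mirrors the fields listed above; the nesting and example values are assumptions for illustration, not taken from this page.

# Hypothetical configuration sketch mirroring the Config fields above.
config = {
    'class_name': 'TLMSystem',
    'load': None,                # optional path to a pretrained Kiwi encoder checkpoint
    'load_vocabs': None,
    'model': {
        'encoder': {},           # encoder options (ModelConfig.encoder)
        'tlm_outputs': {},       # TLM output-head options (ModelConfig.tlm_outputs)
    },
    'data_processing': {},
    'optimizer': {},
    'batch_size': {'train': 32, 'valid': 32},   # BatchSizeConfig.train / BatchSizeConfig.valid
    'num_data_workers': 4,
}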
map_name_to_class
check_consistency
check_model_requirement
check_batching
subclasses
set_config_options
prepare_data
Initialize the data sources that the model will use to create the data loaders.
train_dataloader
Return a PyTorch DataLoader for the training set.
Requires calling prepare_data beforehand.
Return type: PyTorch DataLoader
val_dataloader
Return a PyTorch DataLoader for the validation set.
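A minimal usage sketch of these three methods, assuming system is an already-constructed TLMSystem instance; the helper function below is illustrative, not part of the API.

def build_dataloaders(system):
    # Initialize the train/valid data sources before asking for loaders.
    system.prepare_data()
    train_loader = system.train_dataloader()  # PyTorch DataLoader over the training set
    val_loader = system.val_dataloader()      # PyTorch DataLoader over the validation set
    return train_loader, val_loader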
forward
Same as torch.nn.Module.forward().
In Kiwi we use it to glue together the modular parts that constitute a model, e.g., the encoder and a tlm_output.
batch_inputs – Dict containing a batch of data. See kiwi.data.encoders.field_encoders.QEEncoder.batch_encode().
Returns: outputs of the tlm_outputs module.
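A minimal sketch of that glue, assuming the encoder and the TLM output module are callable roughly as shown below (this is not the actual Kiwi source):

def forward(self, batch_inputs):
    # Encode the batch produced by QEEncoder.batch_encode().
    features = self.encoder(batch_inputs)
    # Pass the encoded features to the TLM output heads and return their outputs.
    return self.tlm_outputs(features, batch_inputs)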
training_step
Here you compute and return the training loss and some additional metrics for e.g. the progress bar or logger.
batch (Tensor | (Tensor, …) | [Tensor, …]) – The output of your DataLoader. A tensor, tuple or list.
batch_idx (int) – The index of this batch.
optimizer_idx (int) – When using multiple optimizers, this argument will also be present.
hiddens (Tensor) – Passed in if Trainer.truncated_bptt_steps > 0.
Dict with loss key and optional log or progress bar keys. When implementing training_step(), return whatever you need in that step:
loss -> tensor scalar REQUIRED
progress_bar -> Dict for progress bar display. Must have only tensors
log -> Dict of metrics to add to logger. Must have only tensors (no images, etc)
In this step you’d normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes or something model specific.
Examples
def training_step(self, batch, batch_idx):
    x, y, z = batch

    # implement your own
    out = self(x)
    loss = self.loss(out, x)

    logger_logs = {'training_loss': loss}  # optional (MUST ALL BE TENSORS)

    # if using TestTubeLogger or TensorBoardLogger you can nest scalars
    logger_logs = {'losses': logger_logs}  # optional (MUST ALL BE TENSORS)

    output = {
        'loss': loss,  # required
        'progress_bar': {'training_loss': loss},  # optional (MUST ALL BE TENSORS)
        'log': logger_logs
    }

    # return a dict
    return output
If you define multiple optimizers, this step will be called with an additional optimizer_idx parameter.
# Multiple optimizers (e.g.: GANs)
def training_step(self, batch, batch_idx, optimizer_idx):
    if optimizer_idx == 0:
        # do training_step with encoder
        ...
    if optimizer_idx == 1:
        # do training_step with decoder
        ...
If you add truncated back propagation through time you will also get an additional argument with the hidden states of the previous step.
# Truncated back-propagation through time
def training_step(self, batch, batch_idx, hiddens):
    # hiddens are the hidden states from the previous truncated backprop step
    ...
    out, hiddens = self.lstm(data, hiddens)
    ...
    return {
        "loss": ...,
        "hiddens": hiddens  # remember to detach() this
    }
Notes
The loss value shown in the progress bar is smoothed (averaged) over the last values, so it differs from the actual loss returned in train/validation step.
validation_step
Operates on a single batch of data from the validation set. In this step you might generate examples or calculate anything of interest, such as accuracy.
# the pseudocode for these calls
val_outs = []
for val_batch in val_data:
    out = validation_step(val_batch)
    val_outs.append(out)
validation_epoch_end(val_outs)
batch_idx (int) – The index of this batch
dataloader_idx (int) – The index of the dataloader that produced this batch (only if multiple val datasets used)
Dict or OrderedDict - passed to validation_epoch_end(). If you defined validation_step_end() it will go to that first.
# pseudocode of order
out = validation_step()
if defined('validation_step_end'):
    out = validation_step_end(out)
out = validation_epoch_end(out)
# if you have one val dataloader:
def validation_step(self, batch, batch_idx)

# if you have multiple val dataloaders:
def validation_step(self, batch, batch_idx, dataloader_idx)
# CASE 1: A single validation dataset
def validation_step(self, batch, batch_idx):
    x, y = batch

    # implement your own
    out = self(x)
    loss_val = self.loss(out, y)

    # log 6 example images
    # or generated text... or whatever
    sample_imgs = x[:6]
    grid = torchvision.utils.make_grid(sample_imgs)
    self.logger.experiment.add_image('example_images', grid, 0)

    # calculate acc
    labels_hat = torch.argmax(out, dim=1)
    val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

    # all optional...
    # return whatever you need for the collation function validation_epoch_end
    output = OrderedDict({
        'val_loss': loss_val,
        'val_acc': torch.tensor(val_acc),  # everything must be a tensor
    })

    # return an optional dict
    return output
If you pass in multiple val datasets, validation_step will have an additional argument.
# CASE 2: multiple validation datasets
def validation_step(self, batch, batch_idx, dataset_idx):
    # dataset_idx tells you which dataset this is.
    ...
Note
If you don’t need to validate you don’t need to implement this method.
When the validation_step() is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of validation, the model goes back to training mode and gradients are enabled.
validation_epoch_end
Called at the end of the validation epoch with the outputs of all validation steps.
# the pseudocode for these calls
val_outs = []
for val_batch in val_data:
    out = validation_step(val_batch)
    val_outs.append(out)
validation_epoch_end(val_outs)
outputs – List of outputs you defined in validation_step(), or if there are multiple dataloaders, a list containing a list of outputs for each dataloader.
Dict or OrderedDict. May have the following optional keys:
progress_bar (dict for progress bar display; only tensors)
log (dict of metrics to add to logger; only tensors).
If you didn’t define a validation_step(), this won’t be called.
The outputs here are strictly for logging or progress bar.
If you don’t need to display anything, don’t return anything.
If you want to manually set current step, you can specify the ‘step’ key in the ‘log’ dict.
With a single dataloader:
def validation_epoch_end(self, outputs):
    val_acc_mean = 0
    for output in outputs:
        val_acc_mean += output['val_acc']

    val_acc_mean /= len(outputs)
    tqdm_dict = {'val_acc': val_acc_mean.item()}

    # show val_acc in the progress bar and log it
    results = {
        'progress_bar': tqdm_dict,
        'log': {'val_acc': val_acc_mean.item()}
    }
    return results
With multiple dataloaders, outputs will be a list of lists. The outer list contains one entry per dataloader, while the inner list contains the individual outputs of each validation step for that dataloader.
def validation_epoch_end(self, outputs):
    val_acc_mean = 0
    i = 0
    for dataloader_outputs in outputs:
        for output in dataloader_outputs:
            val_acc_mean += output['val_acc']
            i += 1

    val_acc_mean /= i
    tqdm_dict = {'val_acc': val_acc_mean.item()}

    # show val_acc in the progress bar, log it, and manually set the step to the current epoch
    results = {
        'progress_bar': tqdm_dict,
        'log': {'val_acc': val_acc_mean.item(), 'step': self.current_epoch}
    }
    return results
loss
Compute total model loss.
Returns: loss_dict – Dict[loss_key] = value.
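Purely illustrative example of the returned structure; the keys and values below are made up, not produced by the actual implementation.

import torch

# Each entry maps a loss key to a tensor; 'loss' holds the total used for backprop.
loss_dict = {
    'mlm': torch.tensor(0.73),   # dummy per-output loss
    'loss': torch.tensor(0.73),  # dummy total model loss
}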
metrics_step
metrics_end
main_metric
Configure and retrieve the metric to be used for monitoring.
The first time it is called, the main metric is configured based on the specified metrics in selected_metric or, if not provided, on the first metric in the TLM outputs. Subsequent calls return the configured main metric. If a subsequent call specifies selected_metric, configuration is done again.
selected_metric
Note that the first element might be a concatenation of several metrics in case selected_metric is a list. This is useful for considering more than one metric as the best (metrics_end() will sum over them).
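Purely illustrative sketch of that concatenation-and-sum behaviour; the metric names and values below are made up.

# Combine several monitored metrics into one value to track as 'best'.
epoch_metrics = {'metric_a': 0.25, 'metric_b': 0.5}    # dummy values
selected_metric = ['metric_a', 'metric_b']
main_metric_name = '+'.join(selected_metric)
main_metric_value = sum(epoch_metrics[name] for name in selected_metric)
print(main_metric_name, main_metric_value)  # metric_a+metric_b 0.75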
num_parameters
from_config
from_dict
_load_dict
to_dict
on_save_checkpoint
Called by Lightning when saving a checkpoint to give you a chance to store anything else you might want to save.
checkpoint – Checkpoint to be saved
Example
def on_save_checkpoint(self, checkpoint):
    # 99% of use cases you don't need to implement this method
    checkpoint['something_cool_i_want_to_save'] = my_cool_pickable_object
Lightning saves all aspects of training (epoch, global step, etc…) including amp scaling. There is no need for you to store anything about training.
on_load_checkpoint
Called by Lightning to restore your model. If you saved something with on_save_checkpoint() this is your chance to restore this.
checkpoint – Loaded checkpoint
def on_load_checkpoint(self, checkpoint):
    # 99% of the time you don't need to implement this method
    self.something_cool_i_want_to_save = checkpoint['something_cool_i_want_to_save']
Lightning auto-restores global step, epoch, and train state including amp scaling. There is no need for you to restore anything regarding training.
load_from_checkpoint
Primary way of loading a model from a checkpoint. When Lightning saves a checkpoint, it stores the arguments passed to __init__ in the checkpoint under module_arguments.
Any arguments specified through *args and **kwargs will override args stored in hparams.
checkpoint_path – Path to checkpoint. This can also be a URL.
args – Any positional args needed to init the model.
map_location – If your checkpoint saved a GPU model and you now load on CPUs or a different number of GPUs, use this to map to the new setup. The behaviour is the same as in torch.load().
hparams_file –
Optional path to a .yaml file with hierarchical structure as in this example:
drop_prob: 0.2
dataloader:
    batch_size: 32
You most likely won’t need this since Lightning will always save the hyperparameters to the checkpoint. However, if your checkpoint weights don’t have the hyperparameters saved, use this method to pass in a .yaml file with the hparams you’d like to use. These will be converted into a dict and passed into your LightningModule for use.
If your model’s hparams argument is Namespace and .yaml file has hierarchical structure, you need to refactor your model to treat hparams as dict.
.csv files are acceptable here until v0.9.0; see the tags_csv argument for detailed usage.
tags_csv –
Warning
Deprecated since version 0.7.6.
The tags_csv argument is deprecated in v0.7.6 and will be removed in v0.9.0.
Optional path to a .csv file with two columns (key, value) as in this example:
key,value
drop_prob,0.2
batch_size,32
Use this method to pass in a .csv file with the hparams you’d like to use.
hparam_overrides – A dictionary with keys to override in the hparams
kwargs – Any keyword args needed to init the model.
LightningModule with loaded weights and hyperparameters (if available).
# load weights without mapping ...
MyLightningModule.load_from_checkpoint('path/to/checkpoint.ckpt')

# or load weights mapping all weights from GPU 1 to GPU 0 ...
map_location = {'cuda:1': 'cuda:0'}
MyLightningModule.load_from_checkpoint(
    'path/to/checkpoint.ckpt',
    map_location=map_location
)

# or load weights and hyperparameters from separate files.
MyLightningModule.load_from_checkpoint(
    'path/to/checkpoint.ckpt',
    hparams_file='/path/to/hparams_file.yaml'
)

# override some of the params with new values
MyLightningModule.load_from_checkpoint(
    PATH,
    num_layers=128,
    pretrained_ckpt_path=NEW_PATH,
)

# predict
pretrained_model.eval()
pretrained_model.freeze()
y_hat = pretrained_model(x)
configure_optimizers
Instantiate configured optimizer and LR scheduler.
Returns: a single optimizer, a list or tuple of optimizers, or the optimizers together with learning-rate schedulers.
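A generic sketch of that return contract under the Lightning API; the optimizer and scheduler choices below are illustrative, not what Kiwi actually instantiates from its configuration.

import torch

def configure_optimizers(self):
    # Build the optimizer over the module's parameters (hyper-parameters are illustrative).
    optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
    # Optionally attach a learning-rate scheduler.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)
    # Return the optimizers and learning-rate schedulers as lists.
    return [optimizer], [scheduler]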