# Source code for kiwi.cli.pipelines.train

#  OpenKiwi: Open-Source Machine Translation Quality Estimation
#  Copyright (C) 2019 Unbabel <openkiwi@unbabel.com>
#
#  This program is free software: you can redistribute it and/or modify
#  it under the terms of the GNU Affero General Public License as published
#  by the Free Software Foundation, either version 3 of the License, or
#  (at your option) any later version.
#
#  This program is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU Affero General Public License for more details.
#
#  You should have received a copy of the GNU Affero General Public License
#  along with this program.  If not, see <https://www.gnu.org/licenses/>.
#

import logging
from distutils.util import strtobool

from kiwi.cli.better_argparse import PipelineParser
from kiwi.cli.models import linear, nuqe, predictor, predictor_estimator, quetch
from kiwi.lib import train

logger = logging.getLogger(__name__)


[docs]def train_opts(parser): # Training loop options group = parser.add_argument_group('training') group.add_argument( '--epochs', type=int, default=50, help='Number of epochs for training.' ) group.add_argument( '--train-batch-size', type=int, default=64, help='Maximum batch size for training.', ) group.add_argument( '--valid-batch-size', type=int, default=64, help='Maximum batch size for evaluating.', ) # Optimization options group = parser.add_argument_group('training-optimization') group.add_argument( '--optimizer', default='adam', choices=['sgd', 'adagrad', 'adadelta', 'adam', 'sparseadam'], help='Optimization method.', ) group.add_argument( '--learning-rate', type=float, default=1.0, help='Starting learning rate. ' 'Recommended settings: sgd = 1, adagrad = 0.1, ' 'adadelta = 1, adam = 0.001', ) group.add_argument( '--learning-rate-decay', type=float, default=1.0, help='Decay learning rate by this factor. ', ) group.add_argument( '--learning-rate-decay-start', type=int, default=0, help='Start decay after this epoch.', ) # Saving and resuming options group = parser.add_argument_group('training-save-load') group.add_argument( '--checkpoint-validation-steps', type=int, default=0, help='Perform validation every X training batches. Saves model' ' if `checkpoint-save` is true.', ) group.add_argument( '--checkpoint-save', type=lambda x: bool(strtobool(x)), nargs='?', const=True, default=True, help='Save a training snapshot when validation is run. 
If false ' 'it will never save the model.', ) group.add_argument( '--checkpoint-keep-only-best', type=int, default=1, help='Keep only n best models according to main metric (F1Mult ' 'by default); 0 will keep all.', ) group.add_argument( '--checkpoint-early-stop-patience', type=int, default=0, help='Stop training if evaluation metrics do not improve after X ' 'validations; 0 disables this.', ) group.add_argument( '--resume', type=lambda x: bool(strtobool(x)), nargs='?', const=True, default=False, help='Resume training a previous run. ' 'If --output-dir is not none, Kiwi will load from a checkpoint folder ' 'in that location. If --output-dir is not specified, ' 'then the --run-uuid option must be set. Files are then searched ' 'under the "runs" directory. If not found, they are ' 'downloaded from the MLflow server ' '(check the --mlflow-tracking-uri option).', )
def build_parser():
    """Build the command-line parser for the ``train`` pipeline.

    Registers one sub-parser per trainable model architecture and
    attaches the shared training options via :func:`train_opts`.
    """
    model_parsers = [
        nuqe.parser_for_pipeline('train'),
        predictor_estimator.parser_for_pipeline('train'),
        predictor.parser_for_pipeline('train'),
        quetch.parser_for_pipeline('train'),
        linear.parser_for_pipeline('train'),
    ]
    return PipelineParser(
        name='train',
        model_parsers=model_parsers,
        options_fn=train_opts,
    )
def main(argv=None):
    """Run the training pipeline from command-line arguments.

    Args:
        argv: optional list of argument strings; ``None`` falls back to
            ``sys.argv`` inside the parser.
    """
    options = build_parser().parse(args=argv)
    train.train_from_options(options)
if __name__ == '__main__':  # pragma: no cover
    main()