#  OpenKiwi: Open-Source Machine Translation Quality Estimation
#  Copyright (C) 2019 Unbabel <openkiwi@unbabel.com>
#
#  This program is free software: you can redistribute it and/or modify
#  it under the terms of the GNU Affero General Public License as published
#  by the Free Software Foundation, either version 3 of the License, or
#  (at your option) any later version.
#
#  This program is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU Affero General Public License for more details.
#
#  You should have received a copy of the GNU Affero General Public License
#  along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
import logging
from distutils.util import strtobool
from kiwi.cli.better_argparse import PipelineParser
from kiwi.cli.models import linear, nuqe, predictor, predictor_estimator, quetch
from kiwi.lib import train
logger = logging.getLogger(__name__)
[docs]def train_opts(parser):
    # Training loop options
    group = parser.add_argument_group('training')
    group.add_argument(
        '--epochs', type=int, default=50, help='Number of epochs for training.'
    )
    group.add_argument(
        '--train-batch-size',
        type=int,
        default=64,
        help='Maximum batch size for training.',
    )
    group.add_argument(
        '--valid-batch-size',
        type=int,
        default=64,
        help='Maximum batch size for evaluating.',
    )
    # Optimization options
    group = parser.add_argument_group('training-optimization')
    group.add_argument(
        '--optimizer',
        default='adam',
        choices=['sgd', 'adagrad', 'adadelta', 'adam', 'sparseadam'],
        help='Optimization method.',
    )
    group.add_argument(
        '--learning-rate',
        type=float,
        default=1.0,
        help='Starting learning rate. '
        'Recommended settings: sgd = 1, adagrad = 0.1, '
        'adadelta = 1, adam = 0.001',
    )
    group.add_argument(
        '--learning-rate-decay',
        type=float,
        default=1.0,
        help='Decay learning rate by this factor. ',
    )
    group.add_argument(
        '--learning-rate-decay-start',
        type=int,
        default=0,
        help='Start decay after this epoch.',
    )
    # Saving and resuming options
    group = parser.add_argument_group('training-save-load')
    group.add_argument(
        '--checkpoint-validation-steps',
        type=int,
        default=0,
        help='Perform validation every X training batches. Saves model'
        ' if `checkpoint-save` is true.',
    )
    group.add_argument(
        '--checkpoint-save',
        type=lambda x: bool(strtobool(x)),
        nargs='?',
        const=True,
        default=True,
        help='Save a training snapshot when validation is run. If false '
        'it will never save the model.',
    )
    group.add_argument(
        '--checkpoint-keep-only-best',
        type=int,
        default=1,
        help='Keep only n best models according to main metric (F1Mult '
        'by default); 0 will keep all.',
    )
    group.add_argument(
        '--checkpoint-early-stop-patience',
        type=int,
        default=0,
        help='Stop training if evaluation metrics do not improve after X '
        'validations; 0 disables this.',
    )
    group.add_argument(
        '--resume',
        type=lambda x: bool(strtobool(x)),
        nargs='?',
        const=True,
        default=False,
        help='Resume training a previous run. '
        'If --output-dir is not none, Kiwi will load from a checkpoint folder '
        'in that location. If --output-dir is not specified, '
        'then the --run-uuid option must be set. Files are then searched '
        'under the "runs" directory. If not found, they are '
        'downloaded from the MLflow server '
        '(check the --mlflow-tracking-uri option).',
    ) 
[docs]def build_parser():
    return PipelineParser(
        name='train',
        model_parsers=[
            nuqe.parser_for_pipeline('train'),
            predictor_estimator.parser_for_pipeline('train'),
            predictor.parser_for_pipeline('train'),
            quetch.parser_for_pipeline('train'),
            linear.parser_for_pipeline('train'),
        ],
        options_fn=train_opts,
    ) 
[docs]def main(argv=None):
    parser = build_parser()
    options = parser.parse(args=argv)
    train.train_from_options(options) 
if __name__ == '__main__':  # pragma: no cover
    main()  # pragma: no cover