Train your own Metric

To train your own metric, we recommend installing COMET directly from source:

   git clone https://github.com/Unbabel/COMET.git
   cd COMET
   poetry install

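Once the repository is installed locally, you can train your own model/metric with the following command:

   comet-train --cfg configs/models/{your_model_config}.yaml

You can then score translations with the resulting checkpoint:

   comet-score -s src.de -t hyp1.en -r ref.en --model PATH/TO/CHECKPOINT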

You can also upload your model to the Hugging Face Hub, using Unbabel/wmt22-comet-da as an example. You can then use your model directly from the Hub.
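As a minimal sketch of scoring from Python instead of the command line, the snippet below assumes the comet package exposes download_model and load_from_checkpoint, and that predict accepts a list of dictionaries with src, mt and ref keys, as shown in the project README. The helper names build_samples and score are hypothetical and exist only for this illustration.

```python
from typing import Dict, List


def build_samples(src: List[str], mt: List[str], ref: List[str]) -> List[Dict[str, str]]:
    """Zip parallel lists of segments into a list of src/mt/ref dictionaries."""
    if not (len(src) == len(mt) == len(ref)):
        raise ValueError("src, mt and ref must have the same number of segments")
    return [{"src": s, "mt": m, "ref": r} for s, m, r in zip(src, mt, ref)]


def score(model_name_or_path: str, samples: List[Dict[str, str]], from_hub: bool = True):
    """Load a model and score the samples (requires the comet package)."""
    # Imported lazily so build_samples works without COMET installed.
    from comet import download_model, load_from_checkpoint

    # download_model fetches a checkpoint from the Hub and returns its local
    # path; a local checkpoint path is passed through unchanged.
    path = download_model(model_name_or_path) if from_hub else model_name_or_path
    model = load_from_checkpoint(path)
    return model.predict(samples, batch_size=8, gpus=0)  # gpus=0 runs on CPU


samples = build_samples(
    ["isto é um exemplo"], ["this is a example"], ["this is an example"]
)
# Uncomment to download the model and score (needs network access):
# output = score("Unbabel/wmt22-comet-da", samples)
# print(output.system_score)
```

To score with a locally trained checkpoint instead, pass its path and set from_hub=False.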

Config Files

COMET uses PyTorch Lightning to train models, so YAML files are used to initialize the various Lightning objects.

Config files for Lightning classes:

After setting up these Lightning classes, you can set up your model architecture. There are 4 different model architectures:

For each class you can find a config example in configs/models/. The init_args will then be used to initialize your model/metric.
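To illustrate how init_args end up as constructor arguments, here is a minimal sketch of the class path plus init_args pattern. It is not COMET's or Lightning's actual loading code: the instantiate helper is hypothetical, the class_path key is an assumption about the config layout, and a standard-library class stands in for a model class so the snippet runs without COMET installed.

```python
import importlib
from typing import Any, Dict


def instantiate(config: Dict[str, Any]) -> Any:
    """Import the class named by class_path and call it with init_args."""
    module_name, _, class_name = config["class_path"].rpartition(".")
    cls = getattr(importlib.import_module(module_name), class_name)
    return cls(**config.get("init_args", {}))


# Stand-in config: fractions.Fraction plays the role of a model class.
config = {
    "class_path": "fractions.Fraction",
    "init_args": {"numerator": 3, "denominator": 4},
}
obj = instantiate(config)
print(obj)
```

In a real config the class path would point at one of the model classes and init_args would carry its hyperparameters.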

Input Data

To train your models you need to pass a train set and a validation set using the training_data and validation_data arguments respectively.

Depending on the underlying model, your data needs to be formatted differently. RegressionMetric models expect the following format:

src                 mt                  ref                  score
isto é um exemplo   this is a example   this is an example   0.2
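As a rough sketch of producing data in this layout, the snippet below writes the example row with Python's csv module. It assumes the data is stored as a comma-separated file with a header row, which this section does not state explicitly; the to_csv helper is hypothetical.

```python
import csv
import io
from typing import Dict, List

COLUMNS = ["src", "mt", "ref", "score"]

rows = [
    {
        "src": "isto é um exemplo",
        "mt": "this is a example",
        "ref": "this is an example",
        "score": 0.2,
    }
]


def to_csv(rows: List[Dict[str, object]]) -> str:
    """Serialize rows to CSV text with a src,mt,ref,score header."""
    buffer = io.StringIO()
    writer = csv.DictWriter(buffer, fieldnames=COLUMNS)
    writer.writeheader()
    writer.writerows(rows)
    return buffer.getvalue()


csv_text = to_csv(rows)
print(csv_text)
```

Using csv.DictWriter rather than joining strings by hand keeps segments that contain commas or quotes correctly escaped.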

For ReferencelessRegression you can drop the ref column; if it is passed, it is ignored.

Finally, ranking metrics expect two contrastive examples. For example:

src                 neg                 pos                  ref
isto é um exemplo   this is a example   this is an example   this is an example

where the pos column contains a positive sample and the neg column a negative sample.
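The column requirements described above can be summarized as a small pre-flight check. This is only a sketch: the dictionary keys and the missing_columns helper are hypothetical names chosen for illustration, and the check is not part of COMET itself.

```python
from typing import Dict, Iterable, List

# Required columns per model family, as described in this section.
REQUIRED_COLUMNS: Dict[str, List[str]] = {
    "RegressionMetric": ["src", "mt", "ref", "score"],
    "ReferencelessRegression": ["src", "mt", "score"],  # ref is optional
    "RankingMetric": ["src", "neg", "pos", "ref"],
}


def missing_columns(model_type: str, header: Iterable[str]) -> List[str]:
    """Return the required columns that are absent from the file header."""
    present = set(header)
    return [col for col in REQUIRED_COLUMNS[model_type] if col not in present]


header = ["src", "mt", "ref", "score"]
print(missing_columns("RegressionMetric", header))
print(missing_columns("RankingMetric", header))
```

Running such a check before training catches a mismatched file early instead of failing partway through data loading.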

You can check the available data from previous WMT editions here.

Available Encoders

All COMET models depend on an underlying encoder. We currently support the following encoders:

  • BERT

  • XLM-RoBERTa

  • MiniLM

  • XLM-RoBERTa-XL

  • RemBERT

You can change the underlying encoder architecture using the encoder_model argument in your config file. You can then select any compatible model from Hugging Face Transformers using the pretrained_model argument.
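As a small sketch of how these two arguments fit together, the helper below checks the encoder name against the list above and returns the pair of settings as a dictionary. The encoder_args function is hypothetical, and xlm-roberta-base is used only as an example of a compatible Hugging Face checkpoint.

```python
from typing import Dict

# Encoder names taken from the list of supported encoders above.
SUPPORTED_ENCODERS = ("BERT", "XLM-RoBERTa", "MiniLM", "XLM-RoBERTa-XL", "RemBERT")


def encoder_args(encoder_model: str, pretrained_model: str) -> Dict[str, str]:
    """Build the encoder-related settings, rejecting unknown encoder names."""
    if encoder_model not in SUPPORTED_ENCODERS:
        raise ValueError(
            f"Unsupported encoder {encoder_model!r}; choose one of {SUPPORTED_ENCODERS}"
        )
    return {"encoder_model": encoder_model, "pretrained_model": pretrained_model}


settings = encoder_args("XLM-RoBERTa", "xlm-roberta-base")
print(settings)
```

The helper does not verify that the chosen checkpoint matches the encoder architecture; that compatibility still has to be confirmed by hand.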