Source code for kiwi.data.qe_dataset

#  OpenKiwi: Open-Source Machine Translation Quality Estimation
#  Copyright (C) 2019 Unbabel <openkiwi@unbabel.com>
#
#  This program is free software: you can redistribute it and/or modify
#  it under the terms of the GNU Affero General Public License as published
#  by the Free Software Foundation, either version 3 of the License, or
#  (at your option) any later version.
#
#  This program is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU Affero General Public License for more details.
#
#  You should have received a copy of the GNU Affero General Public License
#  along with this program.  If not, see <https://www.gnu.org/licenses/>.
#

from torchtext import data


class QEDataset(data.Dataset):
    """Dataset of quality estimation examples, based on the WMT 201X data.

    Specialises ``torchtext.data.Dataset`` by sorting examples on their
    source length, accepting any iterable of examples, exposing pickling
    hooks, and making ``split`` hand back ``QEDataset`` instances.
    """

    def __init__(self, examples, fields, filter_pred=None):
        """Build the dataset from Examples and their Fields.

        Arguments:
            examples: Iterable of Examples; it is turned into a list before
                being handed to the parent class.
            fields (List(tuple(str, Field))): Pairs of a field name and the
                Field associated with that name.
            filter_pred (callable or None): When given, only examples for
                which ``filter_pred(example)`` is True are used; with None
                every example is used. Default is None.
        """
        # Materialise the input so the parent never receives a generator.
        materialized = list(examples)
        super().__init__(materialized, fields, filter_pred)

    @staticmethod
    def sort_key(ex):
        """Return the key used to sort an example: its source length."""
        # The interleaved alternative,
        #   data.interleave_keys(len(ex.source), len(ex.target)),
        # doesn't work for pack_padded_sequences.
        return len(ex.source)

    def __getstate__(self):
        """Return the instance ``__dict__`` as the pickled state.

        Copied from OpenNMT-py DatasetBase implementation.
        """
        return self.__dict__

    def __setstate__(self, _d):
        """Restore pickled state by updating the instance ``__dict__``.

        Copied from OpenNMT-py DatasetBase implementation.
        """
        self.__dict__.update(_d)

    def __reduce_ex__(self, proto):
        """Delegate pickling reduction to the parent class.

        Copied from OpenNMT-py DatasetBase implementation.
        """
        return super().__reduce_ex__(proto)

    def split(
        self,
        split_ratio=0.7,
        stratified=False,
        strata_field='label',
        random_state=None,
    ):
        """Split the dataset and return the parts as ``QEDataset`` objects.

        The arguments are forwarded unchanged, in this order, to the parent
        ``split``. Each dataset it produces is rebuilt as a ``QEDataset``
        from that dataset's ``examples`` and ``fields``.

        Returns:
            list: One ``QEDataset`` per dataset returned by the parent.
        """
        parent_splits = super().split(
            split_ratio, stratified, strata_field, random_state
        )
        recast = []
        for subset in parent_splits:
            recast.append(
                QEDataset(examples=subset.examples, fields=subset.fields)
            )
        return recast