# OpenKiwi: Open-Source Machine Translation Quality Estimation
# Copyright (C) 2019 Unbabel <openkiwi@unbabel.com>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
#
from functools import partial
from torchtext import data
from kiwi import constants as const
from kiwi.data.corpus import Corpus
from kiwi.data.fields.alignment_field import AlignmentField
from kiwi.data.fields.qe_field import QEField
from kiwi.data.fieldsets.fieldset import Fieldset
from kiwi.data.tokenizers import align_tokenizer, tokenizer
def build_fieldset():
    """Build the Fieldset for word-level QE data: source and target
    sentences, word alignments, gold target tags, and several optional
    feature columns (POS, stacked predictions, parses, n-grams).
    """
    fs = Fieldset()

    # Map vocabulary keyword arguments to the names of the configuration
    # options that provide their values.
    source_vocab_options = dict(
        min_freq='source_vocab_min_frequency', max_size='source_vocab_size'
    )
    target_vocab_options = dict(
        min_freq='target_vocab_min_frequency', max_size='target_vocab_size'
    )
    source_field = QEField(tokenize=tokenizer)
    target_field = QEField(tokenize=tokenizer)
    source_pos = QEField(tokenize=tokenizer)
    target_pos = QEField(tokenize=tokenizer)
    target_tags_field = data.Field(
        tokenize=tokenizer, pad_token=None, unk_token=None
    )
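    # Source, target, and alignment files are required for every split.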
    fs.add(
        name=const.SOURCE,
        field=source_field,
        file_option_suffix='_source',
        required=Fieldset.ALL,
        vocab_options=source_vocab_options,
    )
    fs.add(
        name=const.TARGET,
        field=target_field,
        file_option_suffix='_target',
        required=Fieldset.ALL,
        vocab_options=target_vocab_options,
    )
    fs.add(
        name=const.ALIGNMENTS,
        field=AlignmentField(tokenize=align_tokenizer, use_vocab=False),
        file_option_suffix='_alignments',
        required=Fieldset.ALL,
    )
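    # Gold target tags are required only for training and validation data;
    # at test time they are what the model predicts.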
    fs.add(
        name=const.TARGET_TAGS,
        field=target_tags_field,
        file_option_suffix='_target_tags',
        required=[Fieldset.TRAIN, Fieldset.VALID],
    )
    fs.add(
        name=const.SOURCE_POS,
        field=source_pos,
        file_option_suffix='_source_pos',
        required=None,
    )
    fs.add(
        name=const.TARGET_POS,
        field=target_pos,
        file_option_suffix='_target_pos',
        required=None,
    )
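    # The remaining optional fields are read from tab-separated files:
    # stacked predictions from column 1 of the `_target_stacked` file, and
    # parse heads/relations from columns 1 and 2 of the same
    # `_target_parse` file.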
    target_stacked = data.Field(tokenize=tokenizer)
    fs.add(
        name=const.TARGET_STACKED,
        field=target_stacked,
        file_option_suffix='_target_stacked',
        file_reader=partial(Corpus.read_tabular_file, extract_column=1),
        required=None,
    )
    target_parse_heads = data.Field(tokenize=tokenizer, use_vocab=False)
    target_parse_relations = data.Field(tokenize=tokenizer)
    fs.add(
        name=const.TARGET_PARSE_HEADS,
        field=target_parse_heads,
        file_option_suffix='_target_parse',
        file_reader=partial(Corpus.read_tabular_file, extract_column=1),
        required=None,
    )
    fs.add(
        name=const.TARGET_PARSE_RELATIONS,
        field=target_parse_relations,
        file_option_suffix='_target_parse',
        file_reader=partial(Corpus.read_tabular_file, extract_column=2),
        required=None,
    )
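    # Left and right n-gram context columns are likewise split out of a
    # single `_target_ngram` file by column.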
    target_ngram_left = data.Field(tokenize=tokenizer)
    target_ngram_right = data.Field(tokenize=tokenizer)
    fs.add(
        name=const.TARGET_NGRAM_LEFT,
        field=target_ngram_left,
        file_option_suffix='_target_ngram',
        file_reader=partial(Corpus.read_tabular_file, extract_column=1),
        required=None,
    )
    fs.add(
        name=const.TARGET_NGRAM_RIGHT,
        field=target_ngram_right,
        file_option_suffix='_target_ngram',
        file_reader=partial(Corpus.read_tabular_file, extract_column=2),
        required=None,
    )
    return fs
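
# Example usage (a minimal sketch: how the returned Fieldset is consumed
# is defined elsewhere in kiwi, so the wiring below is illustrative):
#
#     fs = build_fieldset()
#     # `fs` now maps field names (const.SOURCE, const.TARGET, ...) to
#     # torchtext fields together with their file-option suffixes
#     # ('_source', '_target', ...) and vocabulary options.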
#
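# NOTE: if re-enabled, the commented-out function below also needs
# QEDataset imported; it is not imported at the top of this module.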
# def build_test_dataset(options):
# source_field = QEField(tokenize=tokenizer)
# target_field = QEField(tokenize=tokenizer)
# source_pos = QEField(tokenize=tokenizer)
# target_pos = QEField(tokenize=tokenizer)
# alignments_field = AlignmentField(
# tokenize=align_tokenizer, use_vocab=False)
# target_tags_field = data.Field(
# tokenize=tokenizer, pad_token=None, unk_token=None
# )
# target_parse_heads = data.Field(tokenize=tokenizer, use_vocab=False)
# target_parse_relations = data.Field(tokenize=tokenizer)
# target_ngram_left = data.Field(tokenize=tokenizer)
# target_ngram_right = data.Field(tokenize=tokenizer)
# target_stacked = data.Field(tokenize=tokenizer)
#
# fields = {
# const.SOURCE: source_field,
# const.TARGET: target_field,
# const.ALIGNMENTS: alignments_field,
# const.TARGET_TAGS: target_tags_field
# }
#
# test_files = {
# const.SOURCE: options.test_source,
# const.TARGET: options.test_target,
# const.TARGET_TAGS: options.test_target_tags,
# const.ALIGNMENTS: options.test_alignments,
# }
#
# if options.test_target_parse:
# parse_fields = {
# const.TARGET_PARSE_HEADS: target_parse_heads,
# const.TARGET_PARSE_RELATIONS: target_parse_relations,
# }
# parse_file_fields = [
# '',
# '',
# '',
# '',
# '',
# const.TARGET_PARSE_HEADS,
# const.TARGET_PARSE_RELATIONS,
# ]
#
# if options.test_target_ngram:
# ngram_fields = {
# const.TARGET_NGRAM_LEFT: target_ngram_left,
# const.TARGET_NGRAM_RIGHT: target_ngram_right,
# }
# ngram_file_fields = [
# '', '', '', '', '', '', '', '', '', '', '', '', '',
# const.TARGET_NGRAM_LEFT,
# const.TARGET_NGRAM_RIGHT,
# ]
#
# if options.test_target_stacked:
# stacked_fields = {const.TARGET_STACKED: target_stacked}
# stacked_file_fields = [const.TARGET_STACKED]
#
# if options.test_source_pos:
# fields[const.SOURCE_POS] = source_pos
# test_files[const.SOURCE_POS] = options.test_source_pos
#
# if options.test_target_pos:
# fields[const.TARGET_POS] = target_pos
# test_files[const.TARGET_POS] = options.test_target_pos
#
# if options.test_target_parse:
# test_target_parse_file = options.test_target_parse
#
# if options.test_target_ngram:
# test_target_ngram_file = options.test_target_ngram
#
# if options.test_target_stacked:
# test_target_stacked_file = options.test_target_stacked
#
# def filter_len(x):
# return (
# options.source_min_length
# <= len(x.source)
# <= options.source_max_length
# ) and (
# options.target_min_length
# <= len(x.target)
# <= options.target_max_length
# )
#
# test_examples = Corpus.from_files(fields=fields, files=test_files)
# if options.test_target_parse:
# test_examples.paste_fields(
# Corpus.from_tabular_file(
# fields=parse_fields,
# file_fields=parse_file_fields,
# file_path=test_target_parse_file,
# )
# )
# if options.test_target_ngram:
# test_examples.paste_fields(
# Corpus.from_tabular_file(
# fields=ngram_fields,
# file_fields=ngram_file_fields,
# file_path=test_target_ngram_file,
# )
# )
# if options.test_target_stacked:
# test_examples.paste_fields(
# Corpus.from_tabular_file(
# fields=stacked_fields,
# file_fields=stacked_file_fields,
# file_path=test_target_stacked_file,
# )
# )
#
# dataset = QEDataset(
# examples=test_examples,
# fields=test_examples.dataset_fields,
# filter_pred=filter_len,
# )
#
# return dataset