# Source code for kiwi.models.utils

#  OpenKiwi: Open-Source Machine Translation Quality Estimation
#  Copyright (C) 2019 Unbabel <>
#  This program is free software: you can redistribute it and/or modify
#  it under the terms of the GNU Affero General Public License as published
#  by the Free Software Foundation, either version 3 of the License, or
#  (at your option) any later version.
#  This program is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU Affero General Public License for more details.
#  You should have received a copy of the GNU Affero General Public License
#  along with this program.  If not, see <https://www.gnu.org/licenses/>.

import logging
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
from more_itertools import first, flatten
from torch.autograd import Function
from torch.nn.utils.rnn import pack_padded_sequence as pack
from torch.nn.utils.rnn import pad_packed_sequence as unpack

from kiwi import constants as const

logger = logging.getLogger(__name__)

def unroll(list_of_lists):
    """Flatten one level of nesting when the items are lists or arrays.

    Only the first item is inspected: if it is a ``list`` or an
    ``np.ndarray``, every item is assumed to be iterable and the items are
    concatenated into a single flat list. Otherwise the input is returned
    unchanged.

    :param list_of_lists: a sequence whose items are either lists/arrays
        (flattened one level) or plain values (left as is).
    :return: a flattened ``list`` when the first item is a list/array;
        otherwise ``list_of_lists`` itself. An empty input is returned
        unchanged instead of raising.
    """
    # Probe the first item without failing on empty input.
    head = next(iter(list_of_lists), None)
    if isinstance(head, (np.ndarray, list)):
        return [item for sublist in list_of_lists for item in sublist]
    return list_of_lists
def convolve_tensor(sequences, window_size, pad_value=0):
    """Slide a window over each row of a 2D tensor, padding both ends.

    Each row is padded with ``window_size // 2`` copies of ``pad_value`` on
    the left and on the right. Windows of length ``window_size`` are then
    extracted with stride 1 along dimension 1.

    :param sequences: 2D tensor (rows are sequences, e.g. batch x seq_len).
    :param window_size: length of each window.
    :param pad_value: value used for the padding.
    :return: 3D tensor whose last dimension has size ``window_size``. For odd
        ``window_size`` the middle dimension equals the input length.
        NOTE(review): for even sizes there is one extra window.
    """
    half_width = window_size // 2
    # F.pad with a 2-tuple pads the last dimension, which for a 2D input is
    # the same dimension that unfold walks over below.
    padded = F.pad(sequences, pad=(half_width, half_width), value=pad_value)
    # For 3D inputs one would instead pad dim 1 explicitly, e.g.
    # F.pad(x, (0, 0, 1, 1), value=0).unfold(1, 3, 1)
    return padded.unfold(1, window_size, 1)
# def convolve_sequence(sequence, window_size, pad_value=0): # """Convolve a sequence and apply padding # # :param sequence: list of ids # :param window_size: filter length # :param pad_value: int value used as padding # :return: list of lists with size of window_size # """ # pad = [pad_value for _ in range(window_size // 2)] # pad_sequence = pad + sequence + pad # return list(windowed(pad_sequence, window_size, fillvalue=pad_value))
def align_tensor(
    tensor,
    alignments,
    max_aligned,
    unaligned_idx,
    padding_idx,
    pad_size,
    target_length=None,
):
    """Gather, for every target position, the source rows aligned to it.

    Each sample's ``(source, target)`` alignment pairs are first inverted with
    ``map_alignments_to_target``. Each sample of ``tensor`` is then aligned
    with ``align_source`` and the per-sample results are stacked.

    Args:
        tensor: batch of source tensors; iterated along its first dimension.
            Each sample must be 2D (``align_source`` asserts this).
            Presumably (batch, source_len, window) — TODO confirm.
        alignments: per-sample lists of ``(source_index, target_index)``.
        max_aligned: maximum number of source rows kept per target position.
        unaligned_idx: fill value for target positions with no alignment.
        padding_idx: fill value for empty slots.
        pad_size: number of target positions in the output.
        target_length: forwarded to ``map_alignments_to_target``.

    Returns:
        A tuple ``(aligned_tensor, nb_alignments)``. ``aligned_tensor`` has
        shape (batch, pad_size, max_aligned, window). ``nb_alignments`` has
        shape (batch, pad_size) and holds per-position alignment counts.
    """
    # Invert every sample's alignments up front, before any aligning.
    trg2src_per_sample = [
        map_alignments_to_target(sample_alignments, target_length=target_length)
        for sample_alignments in alignments
    ]

    aligned_samples = []
    alignment_counts = []
    for sample, trg2src in zip(tensor, trg2src_per_sample):
        aligned_sample, count = align_source(
            sample, trg2src, max_aligned, unaligned_idx, padding_idx, pad_size
        )
        aligned_samples.append(aligned_sample)
        alignment_counts.append(count)

    return torch.stack(aligned_samples), torch.stack(alignment_counts)
def map_alignments_to_target(src2tgt_alignments, target_length=None):
    """Invert source->target alignment pairs into a per-target lookup.

    Args:
        src2tgt_alignments (list): ``(source_index, target_index)`` pairs.
        target_length: number of target positions to allocate. If None, it is
            one more than the largest target index in the alignments, or 0
            when there are no alignments.

    Returns:
        A list of length ``target_length``. Entry i is the list of source
        indexes aligned to target position i, in input order, or ``None`` if
        nothing is aligned to that position. A target index outside the
        allocated range raises ``IndexError``.
    """
    if target_length is None:
        target_length = (
            1 + max(tgt_idx for _, tgt_idx in src2tgt_alignments)
            if src2tgt_alignments
            else 0
        )

    trg2src = [None] * target_length
    for src_idx, tgt_idx in src2tgt_alignments:
        bucket = trg2src[tgt_idx]
        if bucket is None:
            # First source aligned to this target: create its list lazily.
            bucket = trg2src[tgt_idx] = []
        bucket.append(src_idx)
    return trg2src
def align_source(
    source,
    trg2src_alignments,
    max_aligned,
    unaligned_idx,
    padding_idx,
    pad_size,
):
    """Lay out, for each target position, the source rows aligned to it.

    Args:
        source: 2D tensor (source_len x window_size) of a single sample.
        trg2src_alignments: per-target-position lists of source indexes
            (a falsy entry means the position is unaligned). Must have at most
            ``pad_size`` entries.
        max_aligned: only the first ``max_aligned`` source indexes of each
            position are kept.
        unaligned_idx: value written into slot 0 of unaligned positions.
        padding_idx: value used for every slot that is not filled.
        pad_size: number of target positions in the output.

    Returns:
        A tuple ``(aligned_source, nb_alignments)``. ``aligned_source`` has
        shape (pad_size, max_aligned, window_size). ``nb_alignments`` is a
        float tensor of shape (pad_size,): the number of kept source rows
        for aligned positions, and 1 everywhere else.
    """
    assert len(source.shape) == 2
    window_size = source.shape[1]
    assert len(trg2src_alignments) <= pad_size

    aligned_source = source.new_full(
        (pad_size, max_aligned, window_size), padding_idx
    )
    unaligned_row = source.new_full((window_size,), unaligned_idx)
    # Defaults to 1 so unaligned/padded positions never have a zero count.
    nb_alignments = source.new_ones(pad_size, dtype=torch.float)

    for position, source_positions in enumerate(trg2src_alignments):
        if source_positions:
            index = torch.tensor(
                source_positions[:max_aligned], device=source.device
            )
            rows = source.index_select(0, index)
            aligned_source[position, : len(rows)] = rows
            # counts how many source tokens this target token is aligned to
            nb_alignments[position] = len(rows)
        else:
            aligned_source[position, 0] = unaligned_row
    return aligned_source, nb_alignments
def apply_packed_sequence(rnn, embedding, lengths):
    """Runs a forward pass of embeddings through an rnn using packed sequence.

    Args:
        rnn: The RNN to run. Its final state may be an ``(h_n, c_n)`` tuple
            (LSTM) or a single ``h_n`` tensor (GRU / vanilla RNN).
        embedding (FloatTensor b x seq x dim): A batch of sequence embeddings.
        lengths (LongTensor batch): The length of each sequence in the batch.
            It may live on any device; a CPU copy is used for packing.

    Returns:
        output: The output of the RNN `rnn` with input `embedding`. It is in
            the original batch order and zero-padded up to ``max(lengths)``.
        state: ``(hidden, cell)`` when the RNN returns a tuple state,
            otherwise ``hidden``. In both cases dim 1 (batch) is restored to
            the original batch order.
    """
    # pack_padded_sequence (enforce_sorted default) needs decreasing lengths.
    lengths_sorted, permutation = torch.sort(lengths, descending=True)
    embedding_sorted = embedding[permutation]

    # Packed-sequence lengths must be a CPU tensor in recent PyTorch, even
    # when the data lives on the GPU.
    embedding_packed = pack(
        embedding_sorted, lengths_sorted.cpu(), batch_first=True
    )
    outputs_packed, final_state = rnn(embedding_packed)
    outputs_sorted, _ = unpack(outputs_packed, batch_first=True)

    # Sorting the permutation yields its inverse: restore original order.
    _, permutation_rev = torch.sort(permutation, descending=False)
    outputs = outputs_sorted[permutation_rev]

    if isinstance(final_state, (tuple, list)):
        # LSTM-style state: (h_n, c_n), batch on dim 1.
        hidden, cell = final_state
        return outputs, (hidden[:, permutation_rev], cell[:, permutation_rev])
    # GRU / vanilla RNN: a single h_n tensor, batch on dim 1.
    return outputs, final_state[:, permutation_rev]
def replace_token(target, old, new):
    """Return a copy of ``target`` with every ``old`` token set to ``new``.

    args:
        target (LongTensor): token ids; left unmodified (out-of-place fill).
        old (int): The token to be replaced by new
        new (int): The token used to replace old
    """
    is_old = target.eq(old)
    return target.masked_fill(is_old, new)
def make_loss_weights(nb_classes, target_idx, weight):
    """Build a per-class weight vector for ``nn.CrossEntropyLoss``.

    args:
        nb_classes: Number of classes
        target_idx: ID of the class to reweight
        weight: Weight given to that class

    returns:
        weights (FloatTensor): tensor of shape ``(nb_classes,)`` with
        ``weights[target_idx] == weight`` and 1.0 for every other class.
    """
    weights = torch.full((nb_classes,), 1.0)
    weights[target_idx] = weight
    return weights
def load_torch_file(file_path):
    """Load a torch-serialized object, following "link" files.

    A link file is a torch file whose payload is itself a ``pathlib.Path``
    pointing at another torch file; such links are resolved recursively.

    If the stored target no longer exists, a fallback is tried. The fallback
    is ``<dir containing the link>/<target's parent dir name>/MODEL_FILE``,
    which presumably covers run directories that were moved (TODO confirm).

    :param file_path: path (``str`` or ``Path``) of the torch file to load.
    :return: the deserialized object, with storages mapped to CPU.
    :raises FileNotFoundError: if ``file_path`` does not exist, or if a link
        target does not exist and the fallback cannot be found either.
    """
    file_path = Path(file_path)
    if not file_path.exists():
        raise FileNotFoundError('Torch file not found: {}'.format(file_path))

    # NOTE(review): torch.load unpickles arbitrary objects; only load files
    # from trusted sources.
    # The identity map_location keeps every storage on the CPU.
    file_dict = torch.load(
        str(file_path), map_location=lambda storage, loc: storage
    )
    if isinstance(file_dict, Path):
        # Resolve cases where file is just a link to another torch file
        linked_path = file_dict
        target_dir_name = linked_path.parent.name
        # with_name('') would raise ValueError, so skip the fallback then.
        if not linked_path.exists() and target_dir_name:
            relative_path = (
                file_path.with_name(target_dir_name) / const.MODEL_FILE
            )
            if relative_path.exists():
                linked_path = relative_path
        return load_torch_file(linked_path)
    return file_dict
class GradientMul(Function):
    """Identity in the forward pass; scales the gradient in the backward pass.

    ``forward`` returns its input unchanged and remembers ``constant``.
    ``backward`` multiplies the incoming gradient by that constant. With the
    default ``constant=0`` the gradient is therefore blocked entirely.
    """

    @staticmethod
    def forward(ctx, x, constant=0):
        ctx.constant = constant
        return x

    @staticmethod
    def backward(ctx, grad):
        # One gradient per forward input: x gets the scaled gradient,
        # the non-tensor `constant` gets None.
        grad_x = ctx.constant * grad
        return grad_x, None


gradient_mul = GradientMul.apply