Source code for kiwi.models.modules.scorer

#  OpenKiwi: Open-Source Machine Translation Quality Estimation
#  Copyright (C) 2019 Unbabel <openkiwi@unbabel.com>
#
#  This program is free software: you can redistribute it and/or modify
#  it under the terms of the GNU Affero General Public License as published
#  by the Free Software Foundation, either version 3 of the License, or
#  (at your option) any later version.
#
#  This program is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU Affero General Public License for more details.
#
#  You should have received a copy of the GNU Affero General Public License
#  along with this program.  If not, see <https://www.gnu.org/licenses/>.
#

import torch
from torch import nn


[docs]class Scorer(nn.Module): """Score function for Attention module. """ def __init__(self): super().__init__()
[docs] def forward(self, query, keys): """Computes Scores for each key given the query. args: query: FloatTensor batch x n keys: FloatTensor batch x seq_length x m ret: scores: FloatTensor batch x seq_length """ raise NotImplementedError
[docs]class MLPScorer(Scorer): """Implements a score function based on a Multilayer Perceptron. """ def __init__(self, query_size, key_size, layers=2, nonlinearity=nn.Tanh): super().__init__() layer_list = [] size = query_size + key_size for i in range(layers): size_next = size // 2 if i < layers - 1 else 1 layer_list.append( nn.Sequential(nn.Linear(size, size_next), nonlinearity()) ) size = size_next self.layers = nn.ModuleList(layer_list)
[docs] def forward(self, query, keys): layer_in = torch.cat([query.unsqueeze(1).expand_as(keys), keys], dim=-1) layer_in = layer_in.reshape(-1, layer_in.size(-1)) for layer in self.layers: layer_in = layer(layer_in) out = layer_in.reshape(keys.size()[:-1]) return out