Source code for kiwi.models.nuqe

#  OpenKiwi: Open-Source Machine Translation Quality Estimation
#  Copyright (C) 2019 Unbabel <openkiwi@unbabel.com>
#
#  This program is free software: you can redistribute it and/or modify
#  it under the terms of the GNU Affero General Public License as published
#  by the Free Software Foundation, either version 3 of the License, or
#  (at your option) any later version.
#
#  This program is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU Affero General Public License for more details.
#
#  You should have received a copy of the GNU Affero General Public License
#  along with this program.  If not, see <https://www.gnu.org/licenses/>.
#

from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F

from kiwi import constants as const
from kiwi.data.fieldsets.quetch import build_fieldset
from kiwi.models.model import Model
from kiwi.models.quetch import QUETCH
from kiwi.models.utils import make_loss_weights


@Model.register_subclass
class NuQE(QUETCH):
    """Neural Quality Estimation (NuQE) model for word-level quality
    estimation."""

    title = 'NuQE'

    def __init__(self, vocabs, **kwargs):
        self.source_emb = None
        self.target_emb = None
        self.linear_1 = None
        self.linear_2 = None
        self.linear_3 = None
        self.linear_4 = None
        self.linear_5 = None
        self.linear_6 = None
        self.linear_out = None
        self.embeddings_dropout = None
        self.dropout_in = None
        self.dropout_out = None
        self.gru_1 = None
        self.gru_2 = None
        self.is_built = False
        super().__init__(vocabs, **kwargs)
    def build(self, source_vectors=None, target_vectors=None):
        nb_classes = self.config.nb_classes
        # FIXME: Remove dependency on magic number
        weight = make_loss_weights(
            nb_classes, const.BAD_ID, self.config.bad_weight
        )

        self._loss = nn.CrossEntropyLoss(
            weight=weight,
            ignore_index=self.config.tags_pad_id,
            reduction='sum',
        )

        # Embeddings layers:
        self._build_embeddings(source_vectors, target_vectors)

        feature_set_size = (
            self.config.source_embeddings_size
            + self.config.target_embeddings_size
        ) * self.config.window_size

        l1_dim = self.config.hidden_sizes[0]
        l2_dim = self.config.hidden_sizes[1]
        l3_dim = self.config.hidden_sizes[2]
        l4_dim = self.config.hidden_sizes[3]
        dropout = self.config.dropout

        # Linear layers
        self.linear_1 = nn.Linear(feature_set_size, l1_dim)
        self.linear_2 = nn.Linear(l1_dim, l1_dim)
        self.linear_3 = nn.Linear(2 * l2_dim, l2_dim)
        self.linear_4 = nn.Linear(l2_dim, l2_dim)
        self.linear_5 = nn.Linear(2 * l2_dim, l3_dim)
        self.linear_6 = nn.Linear(l3_dim, l4_dim)

        # Output layer
        self.linear_out = nn.Linear(l4_dim, nb_classes)

        # Recurrent layers (bidirectional, so each outputs 2 * l2_dim features)
        self.gru_1 = nn.GRU(
            l1_dim, l2_dim, bidirectional=True, batch_first=True
        )
        self.gru_2 = nn.GRU(
            l2_dim, l2_dim, bidirectional=True, batch_first=True
        )

        # Dropout around the linear stack
        self.dropout_in = nn.Dropout(dropout)
        self.dropout_out = nn.Dropout(dropout)

        # Explicit initializations
        nn.init.xavier_uniform_(self.linear_1.weight)
        nn.init.xavier_uniform_(self.linear_2.weight)
        nn.init.xavier_uniform_(self.linear_3.weight)
        nn.init.xavier_uniform_(self.linear_4.weight)
        nn.init.xavier_uniform_(self.linear_5.weight)
        nn.init.xavier_uniform_(self.linear_6.weight)
        # nn.init.xavier_uniform_(self.linear_out.weight)

        nn.init.constant_(self.linear_1.bias, 0.0)
        nn.init.constant_(self.linear_2.bias, 0.0)
        nn.init.constant_(self.linear_3.bias, 0.0)
        nn.init.constant_(self.linear_4.bias, 0.0)
        nn.init.constant_(self.linear_5.bias, 0.0)
        nn.init.constant_(self.linear_6.bias, 0.0)
        # nn.init.constant_(self.linear_out.bias, 0.0)

        self.is_built = True
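    # Note (illustrative, not in the original source): based on the call in
    # build() above, make_loss_weights is expected to return a per-class
    # weight vector for CrossEntropyLoss that is 1.0 everywhere except at
    # const.BAD_ID, where it is bad_weight. With nb_classes=2, BAD_ID=1 and
    # bad_weight=3.0 (hypothetical values), that would amount to
    # torch.tensor([1.0, 3.0]), making mistakes on BAD tokens count three
    # times as much in the summed loss.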
    @staticmethod
    def fieldset(*args, **kwargs):
        return build_fieldset(*args, **kwargs)
    @staticmethod
    def from_options(vocabs, opts):
        model = NuQE(
            vocabs=vocabs,
            predict_target=opts.predict_target,
            predict_gaps=opts.predict_gaps,
            predict_source=opts.predict_source,
            source_embeddings_size=opts.source_embeddings_size,
            target_embeddings_size=opts.target_embeddings_size,
            hidden_sizes=opts.hidden_sizes,
            bad_weight=opts.bad_weight,
            window_size=opts.window_size,
            max_aligned=opts.max_aligned,
            dropout=opts.dropout,
            embeddings_dropout=opts.embeddings_dropout,
            freeze_embeddings=opts.freeze_embeddings,
        )
        return model
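    # Usage sketch (hypothetical option values, not part of the original
    # module): any object exposing the attributes read above will do, e.g.
    #
    #     from types import SimpleNamespace
    #     opts = SimpleNamespace(
    #         predict_target=True, predict_gaps=False, predict_source=False,
    #         source_embeddings_size=50, target_embeddings_size=50,
    #         hidden_sizes=[400, 200, 100, 50], bad_weight=3.0,
    #         window_size=3, max_aligned=5, dropout=0.0,
    #         embeddings_dropout=0.0, freeze_embeddings=False,
    #     )
    #     model = NuQE.from_options(vocabs, opts)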
    def forward(self, batch):
        assert self.is_built

        if self.config.predict_source:
            align_side = const.SOURCE_TAGS
        else:
            align_side = const.TARGET_TAGS

        target_input, source_input, nb_alignments = self.make_input(
            batch, align_side
        )

        #
        # Source branch
        #
        # (bs, ts, aligned, window) -> (bs, ts, aligned, window, emb)
        h_source = self.source_emb(source_input)
        h_source = self.embeddings_dropout(h_source)
        if len(h_source.shape) == 5:
            # Average over aligned words:
            # (bs, ts, aligned, window, emb) -> (bs, ts, window, emb)
            h_source = h_source.sum(2, keepdim=False) / nb_alignments.unsqueeze(
                -1
            ).unsqueeze(-1)
        # (bs, ts, window, emb) -> (bs, ts, window * emb)
        h_source = h_source.view(source_input.size(0), source_input.size(1), -1)

        #
        # Target branch
        #
        # (bs, ts, window) -> (bs, ts, window, emb)
        h_target = self.target_emb(target_input)
        h_target = self.embeddings_dropout(h_target)
        if len(h_target.shape) == 5:
            # (bs, ts, aligned, window, emb) -> (bs, ts, window, emb)
            h_target = h_target.sum(2, keepdim=False) / nb_alignments.unsqueeze(
                -1
            ).unsqueeze(-1)
        # (bs, ts, window, emb) -> (bs, ts, window * emb)
        h_target = h_target.view(target_input.size(0), target_input.size(1), -1)

        feature_set = (h_source, h_target)

        #
        # Merge branches
        #
        # (bs, ts, window * emb) -> (bs, ts, 2 * window * emb)
        h = torch.cat(feature_set, dim=-1)
        h = self.dropout_in(h)

        #
        # First linears
        #
        # (bs, ts, 2 * window * emb) -> (bs, ts, l1_dim)
        h = F.relu(self.linear_1(h))
        # (bs, ts, l1_dim) -> (bs, ts, l1_dim)
        h = F.relu(self.linear_2(h))

        #
        # First recurrent
        #
        # (bs, ts, l1_dim) -> (bs, ts, 2 * l2_dim)
        h, _ = self.gru_1(h)

        #
        # Second linears
        #
        # (bs, ts, 2 * l2_dim) -> (bs, ts, l2_dim)
        h = F.relu(self.linear_3(h))
        # (bs, ts, l2_dim) -> (bs, ts, l2_dim)
        h = F.relu(self.linear_4(h))

        #
        # Second recurrent
        #
        # (bs, ts, l2_dim) -> (bs, ts, 2 * l2_dim)
        h, _ = self.gru_2(h)

        #
        # Third linears
        #
        # (bs, ts, 2 * l2_dim) -> (bs, ts, l3_dim)
        h = F.relu(self.linear_5(h))
        # (bs, ts, l3_dim) -> (bs, ts, l4_dim)
        h = F.relu(self.linear_6(h))
        h = self.dropout_out(h)

        #
        # Output layer
        #
        # (bs, ts, l4_dim) -> (bs, ts, nb_classes)
        h = self.linear_out(h)
        # h = F.log_softmax(h, dim=-1)

        outputs = OrderedDict()
        if self.config.predict_target:
            outputs[const.TARGET_TAGS] = h
        if self.config.predict_gaps:
            outputs[const.GAP_TAGS] = h
        if self.config.predict_source:
            outputs[const.SOURCE_TAGS] = h

        return outputs
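
# The following is a self-contained sketch (not in the original file) that
# mirrors the layer dimensions wired up in build() and the call order in
# forward(), so the tensor shapes of the NuQE column can be checked in
# isolation. All hyperparameter values below are made up for illustration;
# the real model is built via from_options()/build() with proper vocabs and
# batches.
if __name__ == '__main__':
    bs, ts, window, emb = 8, 20, 3, 50
    l1_dim, l2_dim, l3_dim, l4_dim = 400, 200, 100, 50
    nb_classes = 2

    # Concatenated source + target window features, as in forward().
    feature_set_size = (emb + emb) * window
    h = torch.rand(bs, ts, feature_set_size)

    linear_1 = nn.Linear(feature_set_size, l1_dim)
    linear_2 = nn.Linear(l1_dim, l1_dim)
    gru_1 = nn.GRU(l1_dim, l2_dim, bidirectional=True, batch_first=True)
    linear_3 = nn.Linear(2 * l2_dim, l2_dim)
    linear_4 = nn.Linear(l2_dim, l2_dim)
    gru_2 = nn.GRU(l2_dim, l2_dim, bidirectional=True, batch_first=True)
    linear_5 = nn.Linear(2 * l2_dim, l3_dim)
    linear_6 = nn.Linear(l3_dim, l4_dim)
    linear_out = nn.Linear(l4_dim, nb_classes)

    # linear -> linear -> biGRU -> linear -> linear -> biGRU -> linear -> out
    h = F.relu(linear_2(F.relu(linear_1(h))))
    h, _ = gru_1(h)
    h = F.relu(linear_4(F.relu(linear_3(h))))
    h, _ = gru_2(h)
    h = F.relu(linear_6(F.relu(linear_5(h))))
    h = linear_out(h)
    assert h.shape == (bs, ts, nb_classes)  # one score per class per token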