kiwi.modules.common.attention

Module Contents

Classes

Attention

Generic Attention Implementation.

class kiwi.modules.common.attention.Attention(scorer, dropout=0)

Bases: torch.nn.Module

Generic Attention Implementation.

  1. Use query and keys to compute scores (energies)

  2. Apply softmax to get attention probabilities

  3. Perform a dot product between the probabilities and the values to get the outputs
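
To make the flow concrete, here is a minimal sketch of these three steps. It is illustrative, not the library's implementation: DotProductScorer is a hypothetical stand-in, since the actual kiwi.modules.common.Scorer interface is not documented here.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class DotProductScorer(nn.Module):
        """Hypothetical scorer: scaled dot product between query and keys."""

        def forward(self, query, keys):
            # (..., target_len, hidden) x (..., hidden, source_len)
            #   -> scores of shape (..., target_len, source_len)
            return query @ keys.transpose(-1, -2) / keys.size(-1) ** 0.5

    class SketchAttention(nn.Module):
        """Illustrative re-implementation of the three steps above."""

        def __init__(self, scorer, dropout=0.0):
            super().__init__()
            self.scorer = scorer
            self.dropout = nn.Dropout(dropout)

        def forward(self, query, keys, values=None, mask=None):
            if values is None:
                values = keys                     # keys double as values
            scores = self.scorer(query, keys)     # 1. scores (energies)
            if mask is not None:
                # invalid positions get -inf so softmax assigns them zero
                # (mask is assumed broadcastable against scores)
                scores = scores.masked_fill(mask == 0, float('-inf'))
            probs = F.softmax(scores, dim=-1)     # 2. attention probabilities
            probs = self.dropout(probs)
            return probs @ values, probs          # 3. weighted combination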

Parameters
  • scorer (kiwi.modules.common.Scorer) – a scorer object

  • dropout (float) – dropout rate after softmax (default: 0.)

forward(self, query, keys, values=None, mask=None)

Compute the attention between query, keys and values.

Parameters
  • query (torch.Tensor) – set of query vectors with shape of (batch_size, …, target_len, hidden_size)

  • keys (torch.Tensor) – set of key vectors with shape of (batch_size, …, source_len, hidden_size)

  • values (torch.Tensor, optional) – set of value vectors with shape of (batch_size, …, source_len, hidden_size). If None, keys are treated as values. Default: None

  • mask (torch.ByteTensor, optional) – Tensor representing valid positions. If None, all positions are considered valid. Shape of (batch_size, target_len)

Returns

A tuple of two tensors:

  • combination of values and attention probabilities, with shape of (batch_size, …, target_len, hidden_size)

  • attention probabilities between query and keys, with shape of (batch_size, …, target_len, source_len)

Return type

Tuple[torch.Tensor, torch.Tensor]
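
A short usage sketch, reusing the hypothetical SketchAttention and DotProductScorer defined above; an Attention instance built with a compatible scorer should behave analogously.

    batch_size, target_len, source_len, hidden_size = 4, 7, 9, 32

    attention = SketchAttention(DotProductScorer(), dropout=0.1)

    query = torch.randn(batch_size, target_len, hidden_size)
    keys = torch.randn(batch_size, source_len, hidden_size)

    # values is None, so keys are treated as values
    combined, probs = attention(query, keys)

    print(combined.shape)  # torch.Size([4, 7, 32]) -> (batch_size, target_len, hidden_size)
    print(probs.shape)     # torch.Size([4, 7, 9])  -> (batch_size, target_len, source_len)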