kiwi.systems.encoders.bert

Module Contents

Classes

TransformersTextEncoder

Encode a field, handling vocabulary, tokenization and embeddings.

BertEncoder

BERT model as presented in Google’s paper, using Hugging Face’s implementation

kiwi.systems.encoders.bert.logger
class kiwi.systems.encoders.bert.TransformersTextEncoder(tokenizer_name, is_source=False)

Bases: kiwi.data.encoders.field_encoders.TextEncoder

Encode a field, handling vocabulary, tokenization and embeddings.

Heavily inspired by torchtext and torchnlp.

fit_vocab(self, samples, vocab_size=None, vocab_min_freq=0, embeddings_name=None, keep_rare_words_with_embeddings=False, add_embeddings_vocab=False)
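
A minimal, hedged sketch of building the text encoders from the signatures above. The tokenizer name mirrors the default model_name below; the training samples consumed by fit_vocab would come from the surrounding data pipeline and are only indicated here.

    from kiwi.systems.encoders.bert import TransformersTextEncoder

    # Source- and target-side encoders share the same pre-trained tokenizer name.
    source_encoder = TransformersTextEncoder("bert-base-multilingual-cased", is_source=True)
    target_encoder = TransformersTextEncoder("bert-base-multilingual-cased")

    # Vocabulary fitting would then consume training samples from the data
    # pipeline (not shown here; `samples` is assumed):
    # target_encoder.fit_vocab(samples)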
class kiwi.systems.encoders.bert.BertEncoder(vocabs: Dict[str, Vocabulary], config: Config, pre_load_model: bool = True)

Bases: kiwi.systems._meta_module.MetaModule

BERT model as presented in Google’s paper, using Hugging Face’s implementation

References

https://arxiv.org/abs/1810.04805

class Config

Bases: kiwi.utils.io.BaseConfig

Base class for all pydantic configs. Used to configure base behaviour of configs.

model_name :Union[str, Path] = bert-base-multilingual-cased

Pre-trained BERT model to use.

use_mismatch_features :bool = False

Use Alibaba’s mismatch features.

use_predictor_features :bool = False

Use features originally proposed in the Predictor model.

interleave_input :bool = False

Concatenate SOURCE and TARGET without internal padding (111222000 instead of 111002220).

freeze :bool = False

Freeze BERT during training.

use_mlp :bool = True

Apply a linear layer on top of BERT.

hidden_size :int = 100

Size of the linear layer on top of BERT.

scalar_mix_dropout :confloat(ge=0.0, le=1.0) = 0.1
scalar_mix_layer_norm :bool = True
fix_relative_path(cls, v)
no_implementation(cls, v)
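
The fields above map directly onto keyword arguments of the pydantic config. A minimal, hedged sketch of constructing it in Python (in practice the values typically come from the system’s configuration file):

    from kiwi.systems.encoders.bert import BertEncoder

    config = BertEncoder.Config(
        model_name="bert-base-multilingual-cased",  # pre-trained BERT model to use
        interleave_input=False,   # keep SOURCE and TARGET padded separately
        freeze=False,             # fine-tune BERT during training
        use_mlp=True,             # apply a linear layer on top of BERT
        hidden_size=100,          # size of that linear layer
        scalar_mix_dropout=0.1,
        scalar_mix_layer_norm=True,
    )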
load_state_dict(self, state_dict: Union[Dict[str, Tensor], Dict[str, Tensor]], strict: bool = True)

Copies parameters and buffers from state_dict into this module and its descendants. If strict is True, then the keys of state_dict must exactly match the keys returned by this module’s state_dict() function.

Parameters
  • state_dict (dict) – a dict containing parameters and persistent buffers.

  • strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: True

Returns

  • missing_keys is a list of str containing the missing keys

  • unexpected_keys is a list of str containing the unexpected keys

Return type

NamedTuple with missing_keys and unexpected_keys fields
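
A hedged usage sketch of the behaviour described above: with strict=False, mismatched keys are reported in the returned NamedTuple instead of raising. Here `encoder` is assumed to be an already-constructed BertEncoder, and the checkpoint path is illustrative.

    import torch

    # `encoder` is assumed to be a BertEncoder instance; "encoder.ckpt" is illustrative.
    state_dict = torch.load("encoder.ckpt", map_location="cpu")
    result = encoder.load_state_dict(state_dict, strict=False)

    print(result.missing_keys)     # keys the module expected but did not receive
    print(result.unexpected_keys)  # keys in the file the module did not expect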

classmethod input_data_encoders(cls, config: Config)
size(self, field=None)
forward(self, batch_inputs, *args, include_target_logits=False, include_source_logits=False)
static concat_input(source_batch, target_batch, pad_id)

Concatenate the target + source embeddings into one tensor.

Returns

concatenation of embeddings, mask of target (as ones) and source (as zeroes), and concatenation of attention_mask
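
For intuition, a small illustrative sketch of the two padding layouts the encoder distinguishes (token ids and lengths are arbitrary; this mimics the layouts, it is not the module’s internal implementation):

    import torch

    pad_id = 0
    target = torch.tensor([1, 1, 1, 0, 0])  # target tokens, right-padded
    source = torch.tensor([2, 2, 2, 2, 0])  # source tokens, right-padded

    # Default strategy (interleave_input=False): plain concatenation keeps the
    # internal padding between the two segments -> 1 1 1 0 0 2 2 2 2 0
    concatenated = torch.cat([target, source])

    # interleave_input=True: drop the internal padding so the segments sit
    # next to each other, padding only at the end -> 1 1 1 2 2 2 2 0 0 0
    tokens = [t for t in target.tolist() if t != pad_id]
    tokens += [s for s in source.tolist() if s != pad_id]
    interleaved = torch.tensor(
        tokens + [pad_id] * (len(target) + len(source) - len(tokens))
    )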

static split_outputs(features: Tensor, batch_inputs: MultiFieldBatch, interleaved: bool = False) → Dict[str, Tensor]

Split features back into sentences A and B.

Parameters
  • features – BERT’s output: [CLS] target [SEP] source [SEP]. Shape of (bs, 1 + target_len + 1 + source_len + 1, 2)

  • batch_inputs – the regular batch object, containing source and target batches

  • interleaved – whether the concat strategy was interleaved

Returns

dict of tensors for source and target.
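
A hedged call sketch: `features` and `batch_inputs` are assumed to come from the encoder’s forward pass, and the "source"/"target" keys follow the return description above.

    outputs = BertEncoder.split_outputs(features, batch_inputs, interleaved=False)
    target_features = outputs["target"]  # per-token features for the target sentence
    source_features = outputs["source"]  # per-token features for the source sentence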

static interleave_input(source_batch, target_batch, pad_id)

Interleave the source + target embeddings into one tensor.

That is, the input becomes [batch, target [SEP] source].

Returns

interleaved embeddings, mask of target (as zeroes) and source (as ones), and concatenation of attention_mask.

static get_mismatch_features(logits, target, pred)