# -*- coding: utf-8 -*-
r"""
BERT Encoder
==============
Pretrained BERT encoder from Hugging Face.
"""
from argparse import Namespace
from typing import Dict
import torch
from transformers import AutoModel
from comet.models.encoders.encoder_base import Encoder
from comet.tokenizers import HFTextEncoder
from torchnlp.utils import lengths_to_mask
class BERTEncoder(Encoder):
    """ BERT encoder.
    :param tokenizer: BERT text encoder.
    :param hparams: Namespace with the model hyperparameters
        (must define ``pretrained_model``).
    Check the available models here:
    https://huggingface.co/transformers/pretrained_models.html
    """
    def __init__(self, tokenizer: HFTextEncoder, hparams: Namespace) -> None:
        super().__init__(tokenizer)
        self.model = AutoModel.from_pretrained(hparams.pretrained_model)
        self._output_units = self.model.config.hidden_size
        # +1 accounts for the embedding layer in addition to the transformer layers.
        self._n_layers = self.model.config.num_hidden_layers + 1
        self._max_pos = self.model.config.max_position_embeddings
    @classmethod
    def from_pretrained(cls, hparams: Namespace) -> Encoder:
        """ Loads a pretrained encoder from Hugging Face.
        :param hparams: Namespace (only ``pretrained_model`` is used here).
        Returns:
            - Encoder model
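
        Example (minimal sketch; ``bert-base-uncased`` is just an illustrative
        checkpoint name, and any model accepted by ``AutoModel.from_pretrained``
        can be used instead):

            >>> hparams = Namespace(pretrained_model="bert-base-uncased")
            >>> encoder = BERTEncoder.from_pretrained(hparams)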
"""
tokenizer = HFTextEncoder(model=hparams.pretrained_model)
model = BERTEncoder(tokenizer=tokenizer, hparams=hparams)
return model
    def freeze_embeddings(self) -> None:
        """ Freezes the embedding layer of the network to save some memory while training. """
        for param in self.model.embeddings.parameters():
            param.requires_grad = False
    def layerwise_lr(self, lr: float, decay: float):
        """
        Returns the model parameters grouped with a layer-wise decaying learning rate:
        the top transformer layers stay close to ``lr`` while lower layers and the
        embeddings are scaled by increasing powers of ``decay``.
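
        Example (illustrative values; any PyTorch optimizer that accepts parameter
        groups can consume the result):

            >>> param_groups = encoder.layerwise_lr(lr=1e-5, decay=0.95)
            >>> optimizer = torch.optim.Adam(param_groups)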
"""
opt_parameters = [
{
"params": self.model.embeddings.parameters(),
"lr": lr * decay ** (self.num_layers),
}
]
opt_parameters += [
{
"params": self.model.encoder.layer[l].parameters(),
"lr": lr * decay ** (self.num_layers - 1 - l),
}
for l in range(self.num_layers - 1)
]
return opt_parameters
    def forward(
        self, tokens: torch.Tensor, lengths: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        """
        Encodes a batch of sequences.
        :param tokens: Torch tensor with the input sequences [batch_size x seq_len].
        :param lengths: Torch tensor with the length of each sequence [batch_size].
        Returns:
            - 'sentemb': tensor [batch_size x hidden_size] with the sentence encoding
              (pooler output).
            - 'wordemb': tensor [batch_size x seq_len x hidden_size] with the word-level
              embeddings from the last layer.
            - 'mask': boolean tensor [batch_size x seq_len] marking the non-padded positions.
            - 'all_layers': list with the word embeddings returned by each layer.
            - 'extra': tuple with the last_hidden_state [batch_size x seq_len x hidden_size],
              the pooler_output representing the entire sentence, and the word embeddings for
              all BERT layers (list of tensors [batch_size x seq_len x hidden_size]).
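
        Example (sketch; assumes the base ``Encoder`` keeps the tokenizer as
        ``self.tokenizer`` and that it exposes a torchnlp-style ``batch_encode``
        returning padded token ids and their lengths):

            >>> tokens, lengths = encoder.tokenizer.batch_encode(["Hello world!", "COMET"])
            >>> output = encoder(tokens, lengths)
            >>> sorted(output.keys())
            ['all_layers', 'extra', 'mask', 'sentemb', 'wordemb']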
"""
mask = lengths_to_mask(lengths, device=tokens.device)
last_hidden_states, pooler_output, all_layers = self.model(tokens, mask, output_hidden_states=True)
return {
"sentemb": pooler_output,
"wordemb": last_hidden_states,
"all_layers": all_layers,
"mask": mask,
"extra": (last_hidden_states, pooler_output, all_layers),
}