Source code for comet.modules.feedforward

# -*- coding: utf-8 -*-
# Copyright (C) 2020 Unbabel
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""
Feed Forward
==============
    Feed Forward Neural Network module that can be used for classification or regression.
"""
from typing import List, Optional

import torch
from torch import nn


class FeedForward(nn.Module):
    """Feed Forward Neural Network.

    :param in_dim: Number of input features.
    :param out_dim: Number of output features. Default is just a score.
    :param hidden_sizes: List with hidden layer sizes.
    :param activations: Name of the activation function to be used in the
        hidden layers.
    :param final_activation: Name of the final activation function, if any.
    :param dropout: Dropout to be used in the hidden layers.
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int = 1,
        hidden_sizes: List[int] = [3072, 768],
        activations: str = "Sigmoid",
        final_activation: Optional[str] = None,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()
        modules = []
        # Input projection: in_dim -> first hidden size, followed by
        # activation and dropout.
        modules.append(nn.Linear(in_dim, hidden_sizes[0]))
        modules.append(self.build_activation(activations))
        modules.append(nn.Dropout(dropout))

        # Remaining hidden layers, each followed by activation and dropout.
        for i in range(1, len(hidden_sizes)):
            modules.append(nn.Linear(hidden_sizes[i - 1], hidden_sizes[i]))
            modules.append(self.build_activation(activations))
            modules.append(nn.Dropout(dropout))

        # Output projection, with an optional final activation.
        modules.append(nn.Linear(hidden_sizes[-1], int(out_dim)))
        if final_activation is not None:
            modules.append(self.build_activation(final_activation))

        self.ff = nn.Sequential(*modules)

    def build_activation(self, activation: str) -> nn.Module:
        if hasattr(nn, activation):
            return getattr(nn, activation)()
        # Fail fast on unknown activation names instead of returning None.
        raise ValueError(f"{activation} is not a valid activation function.")

    def forward(self, in_features: torch.Tensor) -> torch.Tensor:
        return self.ff(in_features)
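
# Minimal usage sketch (illustrative only, not part of the original module):
# the dimensions, activation names, and batch size below are assumptions
# chosen for demonstration, not values used by COMET.
if __name__ == "__main__":
    # Build a regression head over 1024-dimensional feature vectors; the
    # final Sigmoid squashes the output score into (0, 1).
    ff = FeedForward(
        in_dim=1024,
        out_dim=1,
        hidden_sizes=[512, 256],
        activations="Tanh",
        final_activation="Sigmoid",
        dropout=0.1,
    )
    batch = torch.randn(8, 1024)  # a batch of 8 feature vectors
    scores = ff(batch)
    print(scores.shape)  # torch.Size([8, 1])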